From 43daadbd332f49a35bfdb7d3171fb5fd38efc7d9 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 26 May 2026 09:44:51 +0000 Subject: [PATCH 1/3] fix: release producer reference on non-terminal exit to prevent stale registry entries Handles cases when producer finishes without reaching terminal task state. --- src/a2a/server/agent_execution/active_task.py | 7 +- tests/integration/test_scenarios.py | 86 +++++++++++++------ 2 files changed, 63 insertions(+), 30 deletions(-) diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index fe8fca38c..a74c418f5 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -302,10 +302,6 @@ async def _handle_terminal_state(self, updated_task: Task) -> None: self.active_task._task_id, updated_task.status.state, ) - if not self.active_task._is_finished.is_set(): - async with self.active_task._lock: - self.active_task._reference_count -= 1 - self.active_task._is_finished.set() self.active_task._request_queue.shutdown(immediate=True) @@ -569,6 +565,9 @@ async def _run_producer(self) -> None: await self._event_queue_agent.close(immediate=False) await self._event_queue_subscribers.close(immediate=False) logger.debug('Producer[%s]: Completed', self._task_id) + async with self._lock: + self._reference_count -= 1 + await self._maybe_cleanup() async def _run_consumer(self) -> None: """Consumes events from the agent and updates system state.""" diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index 3f2383fae..3145d90d4 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -101,6 +101,34 @@ def build(self, request: Any) -> ServerCallContext: ) +class ResumingAgent(AgentExecutor): + """Test agent: 'start' -> INPUT_REQUIRED, 'here is input' -> COMPLETED.""" + + async def execute( + self, context: RequestContext, event_queue: EventQueue + ) -> None: + message = context.message + if message and message.parts and message.parts[0].text == 'start': + task = new_task_from_user_message(message) + task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED + await event_queue.enqueue_event(task) + elif ( + message + and message.parts + and message.parts[0].text == 'here is input' + ): + task = new_task_from_user_message(message) + task.status.state = TaskState.TASK_STATE_COMPLETED + await event_queue.enqueue_event(task) + else: + raise ValueError('Unexpected message') + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ) -> None: + pass + + def agent_card(): return AgentCard( name='Test Agent', @@ -464,7 +492,7 @@ async def cancel( if streaming: with pytest.raises( InvalidParamsError, - match='Task .* is already completed', + match='Task .* is in terminal state', ): await client.subscribe( SubscribeToTaskRequest(id=task.id) @@ -1087,31 +1115,6 @@ async def cancel( 'streaming', [False, True], ids=['blocking', 'streaming'] ) async def test_scenario_resumption_from_interrupted(use_legacy, streaming): - class ResumingAgent(AgentExecutor): - async def execute( - self, context: RequestContext, event_queue: EventQueue - ): - message = context.message - if message and message.parts and message.parts[0].text == 'start': - task = new_task_from_user_message(message) - task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED - await event_queue.enqueue_event(task) - elif ( - message - and message.parts - and message.parts[0].text == 'here is input' - ): - task = new_task_from_user_message(message) - task.status.state = TaskState.TASK_STATE_COMPLETED - await event_queue.enqueue_event(task) - else: - raise ValueError('Unexpected message') - - async def cancel( - self, context: RequestContext, event_queue: EventQueue - ): - pass - handler = create_handler(ResumingAgent(), use_legacy) client = await create_client( handler, agent_card=agent_card(), streaming=streaming @@ -1157,6 +1160,37 @@ async def cancel( ] +def test_input_required_followup_across_per_rpc_event_loops(): + handler = create_handler(ResumingAgent(), use_legacy=False) + call_context = ServerCallContext(user=MockUser()) + + msg1 = Message( + message_id='msg-start', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + req1 = SendMessageRequest( + message=msg1, + configuration=SendMessageConfiguration(return_immediately=False), + ) + result1 = asyncio.run(handler.on_message_send(req1, call_context)) + assert isinstance(result1, Task) + assert result1.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + + msg2 = Message( + task_id=result1.id, + context_id=result1.context_id, + message_id='msg-resume', + role=Role.ROLE_USER, + parts=[Part(text='here is input')], + ) + req2 = SendMessageRequest( + message=msg2, + configuration=SendMessageConfiguration(return_immediately=False), + ) + result2 = asyncio.run(handler.on_message_send(req2, call_context)) + assert isinstance(result2, Task) + assert result2.status.state == TaskState.TASK_STATE_COMPLETED + + # Scenario: Auth required and side channel unblocking # Migrated from: test_workflow_auth_required_side_channel in test_handler_comparison @pytest.mark.timeout(2.0) From 81d80493d0816dec5a6b2f414e97c33cc1323f39 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 26 May 2026 10:34:37 +0000 Subject: [PATCH 2/3] Cosmetics --- tests/integration/test_scenarios.py | 59 ++++++++++------------------- 1 file changed, 19 insertions(+), 40 deletions(-) diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index 3145d90d4..762270a9a 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -101,8 +101,8 @@ def build(self, request: Any) -> ServerCallContext: ) -class ResumingAgent(AgentExecutor): - """Test agent: 'start' -> INPUT_REQUIRED, 'here is input' -> COMPLETED.""" +class InputRequiredAgent(AgentExecutor): + """Test agent: 'start' -> INPUT_REQUIRED, anything else -> COMPLETED.""" async def execute( self, context: RequestContext, event_queue: EventQueue @@ -112,16 +112,14 @@ async def execute( task = new_task_from_user_message(message) task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED await event_queue.enqueue_event(task) - elif ( - message - and message.parts - and message.parts[0].text == 'here is input' - ): - task = new_task_from_user_message(message) - task.status.state = TaskState.TASK_STATE_COMPLETED - await event_queue.enqueue_event(task) else: - raise ValueError('Unexpected message') + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) async def cancel( self, context: RequestContext, event_queue: EventQueue @@ -1115,7 +1113,7 @@ async def cancel( 'streaming', [False, True], ids=['blocking', 'streaming'] ) async def test_scenario_resumption_from_interrupted(use_legacy, streaming): - handler = create_handler(ResumingAgent(), use_legacy) + handler = create_handler(InputRequiredAgent(), use_legacy) client = await create_client( handler, agent_card=agent_card(), streaming=streaming ) @@ -1136,8 +1134,8 @@ async def test_scenario_resumption_from_interrupted(use_legacy, streaming): assert [get_state(event) for event in events1] == [ TaskState.TASK_STATE_INPUT_REQUIRED, ] - task_id = events1[0].status_update.task_id - context_id = events1[0].status_update.context_id + task_id = get_task_id(events1[0]) + context_id = get_task_context_id(events1[0]) # Now send another message to resume msg2 = Message( @@ -1161,7 +1159,7 @@ async def test_scenario_resumption_from_interrupted(use_legacy, streaming): def test_input_required_followup_across_per_rpc_event_loops(): - handler = create_handler(ResumingAgent(), use_legacy=False) + handler = create_handler(InputRequiredAgent(), use_legacy=False) call_context = ServerCallContext(user=MockUser()) msg1 = Message( @@ -1845,31 +1843,10 @@ async def cancel( async def test_restore_task_input_required_state( use_legacy, streaming, subscribe_mode ): - class InputAgent(AgentExecutor): - async def execute( - self, context: RequestContext, event_queue: EventQueue - ): - message = context.message - if message and message.parts and message.parts[0].text == 'start': - task = new_task_from_user_message(message) - task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED - await event_queue.enqueue_event(task) - elif message and message.parts and message.parts[0].text == 'input': - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ) - - async def cancel( - self, context: RequestContext, event_queue: EventQueue - ): - pass - task_store = InMemoryTaskStore() - handler1 = create_handler(InputAgent(), use_legacy, task_store=task_store) + handler1 = create_handler( + InputRequiredAgent(), use_legacy, task_store=task_store + ) client1 = await create_client( handler1, agent_card=agent_card(), streaming=streaming ) @@ -1893,7 +1870,9 @@ async def cancel( ) # Restore task in a new handler (simulating server restart) - handler2 = create_handler(InputAgent(), use_legacy, task_store=task_store) + handler2 = create_handler( + InputRequiredAgent(), use_legacy, task_store=task_store + ) client2 = await create_client( handler2, agent_card=agent_card(), streaming=streaming ) From a19fa5e1e6b41d97c58122f7b6abca726e701d0e Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 26 May 2026 12:58:47 +0000 Subject: [PATCH 3/3] Move to _run_consumer --- src/a2a/server/agent_execution/active_task.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index a74c418f5..cb1e0a061 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -565,9 +565,6 @@ async def _run_producer(self) -> None: await self._event_queue_agent.close(immediate=False) await self._event_queue_subscribers.close(immediate=False) logger.debug('Producer[%s]: Completed', self._task_id) - async with self._lock: - self._reference_count -= 1 - await self._maybe_cleanup() async def _run_consumer(self) -> None: """Consumes events from the agent and updates system state.""" @@ -577,7 +574,8 @@ async def _run_consumer(self) -> None: self._is_finished.set() self._request_queue.shutdown(immediate=True) await self._event_queue_agent.close(immediate=True) - + async with self._lock: + self._reference_count -= 1 logger.debug('Consumer[%s]: Finishing', self._task_id) await self._maybe_cleanup()