Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/a2a/server/agent_execution/active_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,6 @@ async def _handle_terminal_state(self, updated_task: Task) -> None:
self.active_task._task_id,
updated_task.status.state,
)
if not self.active_task._is_finished.is_set():
async with self.active_task._lock:
self.active_task._reference_count -= 1

self.active_task._is_finished.set()
self.active_task._request_queue.shutdown(immediate=True)

Expand Down Expand Up @@ -578,7 +574,8 @@ async def _run_consumer(self) -> None:
self._is_finished.set()
self._request_queue.shutdown(immediate=True)
await self._event_queue_agent.close(immediate=True)

async with self._lock:
self._reference_count -= 1
logger.debug('Consumer[%s]: Finishing', self._task_id)
await self._maybe_cleanup()

Expand Down
121 changes: 67 additions & 54 deletions tests/integration/test_scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,32 @@ def build(self, request: Any) -> ServerCallContext:
)


class InputRequiredAgent(AgentExecutor):
"""Test agent: 'start' -> INPUT_REQUIRED, anything else -> COMPLETED."""

async def execute(
self, context: RequestContext, event_queue: EventQueue
) -> None:
message = context.message
if message and message.parts and message.parts[0].text == 'start':
task = new_task_from_user_message(message)
task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED
await event_queue.enqueue_event(task)
else:
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
task_id=context.task_id,
context_id=context.context_id,
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
)
)

async def cancel(
self, context: RequestContext, event_queue: EventQueue
) -> None:
pass


def agent_card():
return AgentCard(
name='Test Agent',
Expand Down Expand Up @@ -464,7 +490,7 @@ async def cancel(
if streaming:
with pytest.raises(
InvalidParamsError,
match='Task .* is already completed',
match='Task .* is in terminal state',
):
await client.subscribe(
SubscribeToTaskRequest(id=task.id)
Expand Down Expand Up @@ -1087,32 +1113,7 @@ async def cancel(
'streaming', [False, True], ids=['blocking', 'streaming']
)
async def test_scenario_resumption_from_interrupted(use_legacy, streaming):
class ResumingAgent(AgentExecutor):
async def execute(
self, context: RequestContext, event_queue: EventQueue
):
message = context.message
if message and message.parts and message.parts[0].text == 'start':
task = new_task_from_user_message(message)
task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED
await event_queue.enqueue_event(task)
elif (
message
and message.parts
and message.parts[0].text == 'here is input'
):
task = new_task_from_user_message(message)
task.status.state = TaskState.TASK_STATE_COMPLETED
await event_queue.enqueue_event(task)
else:
raise ValueError('Unexpected message')

async def cancel(
self, context: RequestContext, event_queue: EventQueue
):
pass

handler = create_handler(ResumingAgent(), use_legacy)
handler = create_handler(InputRequiredAgent(), use_legacy)
client = await create_client(
handler, agent_card=agent_card(), streaming=streaming
)
Expand All @@ -1133,8 +1134,8 @@ async def cancel(
assert [get_state(event) for event in events1] == [
TaskState.TASK_STATE_INPUT_REQUIRED,
]
task_id = events1[0].status_update.task_id
context_id = events1[0].status_update.context_id
task_id = get_task_id(events1[0])
context_id = get_task_context_id(events1[0])

# Now send another message to resume
msg2 = Message(
Expand All @@ -1157,6 +1158,37 @@ async def cancel(
]


def test_input_required_followup_across_per_rpc_event_loops():
handler = create_handler(InputRequiredAgent(), use_legacy=False)
call_context = ServerCallContext(user=MockUser())

msg1 = Message(
message_id='msg-start', role=Role.ROLE_USER, parts=[Part(text='start')]
)
req1 = SendMessageRequest(
message=msg1,
configuration=SendMessageConfiguration(return_immediately=False),
)
result1 = asyncio.run(handler.on_message_send(req1, call_context))
assert isinstance(result1, Task)
assert result1.status.state == TaskState.TASK_STATE_INPUT_REQUIRED

msg2 = Message(
task_id=result1.id,
context_id=result1.context_id,
message_id='msg-resume',
role=Role.ROLE_USER,
parts=[Part(text='here is input')],
)
req2 = SendMessageRequest(
message=msg2,
configuration=SendMessageConfiguration(return_immediately=False),
)
result2 = asyncio.run(handler.on_message_send(req2, call_context))
assert isinstance(result2, Task)
assert result2.status.state == TaskState.TASK_STATE_COMPLETED


# Scenario: Auth required and side channel unblocking
# Migrated from: test_workflow_auth_required_side_channel in test_handler_comparison
@pytest.mark.timeout(2.0)
Expand Down Expand Up @@ -1811,31 +1843,10 @@ async def cancel(
async def test_restore_task_input_required_state(
use_legacy, streaming, subscribe_mode
):
class InputAgent(AgentExecutor):
async def execute(
self, context: RequestContext, event_queue: EventQueue
):
message = context.message
if message and message.parts and message.parts[0].text == 'start':
task = new_task_from_user_message(message)
task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED
await event_queue.enqueue_event(task)
elif message and message.parts and message.parts[0].text == 'input':
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
task_id=context.task_id,
context_id=context.context_id,
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
)
)

async def cancel(
self, context: RequestContext, event_queue: EventQueue
):
pass

task_store = InMemoryTaskStore()
handler1 = create_handler(InputAgent(), use_legacy, task_store=task_store)
handler1 = create_handler(
InputRequiredAgent(), use_legacy, task_store=task_store
)
client1 = await create_client(
handler1, agent_card=agent_card(), streaming=streaming
)
Expand All @@ -1859,7 +1870,9 @@ async def cancel(
)

# Restore task in a new handler (simulating server restart)
handler2 = create_handler(InputAgent(), use_legacy, task_store=task_store)
handler2 = create_handler(
InputRequiredAgent(), use_legacy, task_store=task_store
)
client2 = await create_client(
handler2, agent_card=agent_card(), streaming=streaming
)
Expand Down
Loading