diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index b0154c8d6..5f965560d 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -260,11 +260,20 @@ async def _handle_task_modification_event( ) if self.message_to_save is not None: - updated_task = self.active_task._task_manager.update_with_message( - self.message_to_save, - updated_task, + message_already_saved = any( + message.message_id == self.message_to_save.message_id + for message in updated_task.history ) - await self.active_task._task_manager.save_task_event(updated_task) + if not message_already_saved: + updated_task = ( + self.active_task._task_manager.update_with_message( + self.message_to_save, + updated_task, + ) + ) + await self.active_task._task_manager.save_task_event( + updated_task + ) self.message_to_save = None self.active_task._task_manager.context_id = event.context_id @@ -551,12 +560,15 @@ async def _run_producer(self) -> None: 'Producer[%s]: Execution failed', self._task_id, ) - # Create task and mark as failed. + # Persist the failure directly instead of relying on the closing + # event queue to carry a final status update. if request_context: - await self._task_manager.ensure_task_id( + task = await self._task_manager.ensure_task_id( self._task_id, request_context.context_id or '', ) + task.status.state = TaskState.TASK_STATE_FAILED + await self._task_manager.save_task_event(task) self._task_created.set() await self._event_queue_agent.enqueue_event(cast('Event', e)) diff --git a/src/a2a/server/agent_execution/active_task_registry.py b/src/a2a/server/agent_execution/active_task_registry.py index 9c1299ab3..4f5822f68 100644 --- a/src/a2a/server/agent_execution/active_task_registry.py +++ b/src/a2a/server/agent_execution/active_task_registry.py @@ -11,6 +11,7 @@ from a2a.server.context import ServerCallContext from a2a.server.tasks.push_notification_sender import PushNotificationSender from a2a.server.tasks.task_store import TaskStore + from a2a.types.a2a_pb2 import Message from a2a.server.agent_execution.active_task import ActiveTask from a2a.server.tasks.task_manager import TaskManager @@ -41,6 +42,7 @@ async def get_or_create( call_context: ServerCallContext, context_id: str | None = None, create_task_if_missing: bool = False, + initial_message: Message | None = None, ) -> ActiveTask: """Retrieves an existing ActiveTask or creates a new one.""" async with self._lock: @@ -51,7 +53,7 @@ async def get_or_create( task_id=task_id, context_id=context_id, task_store=self._task_store, - initial_message=None, + initial_message=initial_message, context=call_context, ) diff --git a/src/a2a/server/request_handlers/default_request_handler_v2.py b/src/a2a/server/request_handlers/default_request_handler_v2.py index 30304609a..599e38ce5 100644 --- a/src/a2a/server/request_handlers/default_request_handler_v2.py +++ b/src/a2a/server/request_handlers/default_request_handler_v2.py @@ -222,6 +222,7 @@ async def _setup_active_task( context_id=context_id, call_context=call_context, create_task_if_missing=True, + initial_message=request_context.message, ) return active_task, request_context diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py index caaa4f88e..e49ed5de4 100644 --- a/tests/server/request_handlers/test_default_request_handler_v2.py +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -21,6 +21,7 @@ from a2a.server.agent_execution.active_task_registry import ActiveTaskRegistry from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue +from a2a.server.events.event_queue_v2 import EventQueueSource from a2a.server.request_handlers import DefaultRequestHandlerV2 from a2a.server.tasks import ( InMemoryPushNotificationConfigStore, @@ -306,6 +307,26 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue): pass +class EarlyFailingAgentExecutor(AgentExecutor): + async def execute(self, context: RequestContext, event_queue: EventQueue): + raise RuntimeError('early producer failure') + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +async def send_message_with_early_failure( + request_handler: DefaultRequestHandlerV2, + params: SendMessageRequest, + context: ServerCallContext, +) -> Message | Task | None: + try: + return await request_handler.on_message_send(params, context) + except RuntimeError as e: + assert str(e) == 'early producer failure' + return None + + @pytest.mark.asyncio async def test_on_get_task_limit_history(): task_store = InMemoryTaskStore() @@ -1126,6 +1147,69 @@ async def test_on_message_send_limit_history(): assert task.history is not None and len(task.history) > 1 +@pytest.mark.asyncio +async def test_on_message_send_early_producer_exception_marks_task_failed_and_preserves_originating_message(): + task_store = InMemoryTaskStore() + request_handler = DefaultRequestHandlerV2( + agent_executor=HelloAgentExecutor(), + task_store=task_store, + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_early_failure_state', + parts=[Part(text='Hi')], + ) + ) + context = create_server_call_context() + original_enqueue_event = EventQueueSource.enqueue_event + + async def fail_before_request_started(self, event): + if type(event).__name__ == '_RequestStarted': + raise RuntimeError('early producer failure') + return await original_enqueue_event(self, event) + + with patch.object( + EventQueueSource, 'enqueue_event', fail_before_request_started + ): + await send_message_with_early_failure(request_handler, params, context) + + stored_task = await task_store.get(params.message.task_id, context) + assert stored_task is not None + assert stored_task.status.state == TaskState.TASK_STATE_FAILED + assert len(stored_task.history) == 1 + assert stored_task.history[0].message_id == 'msg_early_failure_state' + assert stored_task.history[0].parts[0].text == 'Hi' + + +@pytest.mark.asyncio +async def test_on_message_send_early_producer_exception_preserves_originating_message(): + task_store = InMemoryTaskStore() + request_handler = DefaultRequestHandlerV2( + agent_executor=EarlyFailingAgentExecutor(), + task_store=task_store, + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_early_failure_history', + parts=[Part(text='Hi')], + ) + ) + context = create_server_call_context() + + await send_message_with_early_failure(request_handler, params, context) + + stored_task = await task_store.get(params.message.task_id, context) + assert stored_task is not None + assert stored_task.history is not None + assert len(stored_task.history) == 1 + assert stored_task.history[0].message_id == 'msg_early_failure_history' + assert stored_task.history[0].parts[0].text == 'Hi' + + @pytest.mark.asyncio async def test_on_message_send_stream_task_id_mismatch(): mock_task_store = AsyncMock(spec=TaskStore)