Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions src/a2a/server/agent_execution/active_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,20 @@ async def _handle_task_modification_event(
)

if self.message_to_save is not None:
updated_task = self.active_task._task_manager.update_with_message(
self.message_to_save,
updated_task,
message_already_saved = any(
message.message_id == self.message_to_save.message_id
for message in updated_task.history
)
Comment thread
AXEG0 marked this conversation as resolved.
await self.active_task._task_manager.save_task_event(updated_task)
if not message_already_saved:
updated_task = (
self.active_task._task_manager.update_with_message(
self.message_to_save,
updated_task,
)
)
await self.active_task._task_manager.save_task_event(
updated_task
)
self.message_to_save = None

self.active_task._task_manager.context_id = event.context_id
Expand Down Expand Up @@ -551,12 +560,15 @@ async def _run_producer(self) -> None:
'Producer[%s]: Execution failed',
self._task_id,
)
# Create task and mark as failed.
# Persist the failure directly instead of relying on the closing
# event queue to carry a final status update.
if request_context:
await self._task_manager.ensure_task_id(
task = await self._task_manager.ensure_task_id(
self._task_id,
request_context.context_id or '',
)
task.status.state = TaskState.TASK_STATE_FAILED
await self._task_manager.save_task_event(task)
self._task_created.set()
await self._event_queue_agent.enqueue_event(cast('Event', e))

Expand Down
4 changes: 3 additions & 1 deletion src/a2a/server/agent_execution/active_task_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from a2a.server.context import ServerCallContext
from a2a.server.tasks.push_notification_sender import PushNotificationSender
from a2a.server.tasks.task_store import TaskStore
from a2a.types.a2a_pb2 import Message

from a2a.server.agent_execution.active_task import ActiveTask
from a2a.server.tasks.task_manager import TaskManager
Expand Down Expand Up @@ -41,6 +42,7 @@ async def get_or_create(
call_context: ServerCallContext,
context_id: str | None = None,
create_task_if_missing: bool = False,
initial_message: Message | None = None,
) -> ActiveTask:
"""Retrieves an existing ActiveTask or creates a new one."""
async with self._lock:
Expand All @@ -51,7 +53,7 @@ async def get_or_create(
task_id=task_id,
context_id=context_id,
task_store=self._task_store,
initial_message=None,
initial_message=initial_message,
context=call_context,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ async def _setup_active_task(
context_id=context_id,
call_context=call_context,
create_task_if_missing=True,
initial_message=request_context.message,
)

return active_task, request_context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from a2a.server.agent_execution.active_task_registry import ActiveTaskRegistry
from a2a.server.context import ServerCallContext
from a2a.server.events import EventQueue
from a2a.server.events.event_queue_v2 import EventQueueSource
from a2a.server.request_handlers import DefaultRequestHandlerV2
from a2a.server.tasks import (
InMemoryPushNotificationConfigStore,
Expand Down Expand Up @@ -306,6 +307,26 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue):
pass


class EarlyFailingAgentExecutor(AgentExecutor):
async def execute(self, context: RequestContext, event_queue: EventQueue):
raise RuntimeError('early producer failure')

async def cancel(self, context: RequestContext, event_queue: EventQueue):
pass


async def send_message_with_early_failure(
request_handler: DefaultRequestHandlerV2,
params: SendMessageRequest,
context: ServerCallContext,
) -> Message | Task | None:
try:
return await request_handler.on_message_send(params, context)
except RuntimeError as e:
assert str(e) == 'early producer failure'
return None


@pytest.mark.asyncio
async def test_on_get_task_limit_history():
task_store = InMemoryTaskStore()
Expand Down Expand Up @@ -1126,6 +1147,69 @@ async def test_on_message_send_limit_history():
assert task.history is not None and len(task.history) > 1


@pytest.mark.asyncio
async def test_on_message_send_early_producer_exception_marks_task_failed_and_preserves_originating_message():
task_store = InMemoryTaskStore()
request_handler = DefaultRequestHandlerV2(
agent_executor=HelloAgentExecutor(),
task_store=task_store,
agent_card=create_default_agent_card(),
)
params = SendMessageRequest(
message=Message(
role=Role.ROLE_USER,
message_id='msg_early_failure_state',
parts=[Part(text='Hi')],
)
)
context = create_server_call_context()
original_enqueue_event = EventQueueSource.enqueue_event

async def fail_before_request_started(self, event):
if type(event).__name__ == '_RequestStarted':
raise RuntimeError('early producer failure')
return await original_enqueue_event(self, event)

with patch.object(
EventQueueSource, 'enqueue_event', fail_before_request_started
):
await send_message_with_early_failure(request_handler, params, context)

stored_task = await task_store.get(params.message.task_id, context)
assert stored_task is not None
assert stored_task.status.state == TaskState.TASK_STATE_FAILED
assert len(stored_task.history) == 1
assert stored_task.history[0].message_id == 'msg_early_failure_state'
assert stored_task.history[0].parts[0].text == 'Hi'


@pytest.mark.asyncio
async def test_on_message_send_early_producer_exception_preserves_originating_message():
task_store = InMemoryTaskStore()
request_handler = DefaultRequestHandlerV2(
agent_executor=EarlyFailingAgentExecutor(),
task_store=task_store,
agent_card=create_default_agent_card(),
)
params = SendMessageRequest(
message=Message(
role=Role.ROLE_USER,
message_id='msg_early_failure_history',
parts=[Part(text='Hi')],
)
)
context = create_server_call_context()

await send_message_with_early_failure(request_handler, params, context)

stored_task = await task_store.get(params.message.task_id, context)
assert stored_task is not None
assert stored_task.history is not None
assert len(stored_task.history) == 1
assert stored_task.history[0].message_id == 'msg_early_failure_history'
assert stored_task.history[0].parts[0].text == 'Hi'


@pytest.mark.asyncio
async def test_on_message_send_stream_task_id_mismatch():
mock_task_store = AsyncMock(spec=TaskStore)
Expand Down
Loading