From 37add3f06565b1b0bdf3070cc5eee9048d3628c8 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Thu, 2 Jul 2026 11:24:38 -0700 Subject: [PATCH] fix: support updated operation IDs in replay --- .../examples-catalog.json | 11 +++ .../src/plugin/execution_with_wait_plugin.py | 51 ++++++++++ .../template.yaml | 18 ++++ .../test/plugin/test_plugin.py | 35 ++++++- .../plugin.py | 1 + .../tests/test_plugin.py | 45 +++++++++ .../execution.py | 14 +++ .../invoker.py | 2 + .../tests/execution_test.py | 49 ++++++++++ .../tests/invoker_test.py | 2 + .../context.py | 12 ++- .../execution.py | 9 +- .../aws_durable_execution_sdk_python/state.py | 25 +++++ .../tests/context_test.py | 96 ++++++++++++++++++- .../tests/execution_test.py | 12 +++ 15 files changed, 376 insertions(+), 6 deletions(-) create mode 100644 packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_wait_plugin.py diff --git a/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json b/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json index 079eed5d..18a2083d 100644 --- a/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json +++ b/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json @@ -644,6 +644,17 @@ }, "path": "./src/plugin/execution_with_plugin.py" }, + { + "name": "Plugin Wait", + "description": "Test plugin hook emission for a wait completed during suspend", + "handler": "execution_with_wait_plugin.handler", + "integration": true, + "durableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + }, + "path": "./src/plugin/execution_with_wait_plugin.py" + }, { "name": "Otel Plugin", "description": "Test Otel plugin", diff --git a/packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_wait_plugin.py b/packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_wait_plugin.py new file mode 100644 index 00000000..fb13cc81 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_wait_plugin.py @@ -0,0 +1,51 @@ +"""Demonstrates plugin hooks for a wait that completes while suspended.""" + +from typing import Any, ClassVar + +from aws_durable_execution_sdk_python.config import Duration +from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.execution import durable_execution +from aws_durable_execution_sdk_python.plugin import ( + DurableInstrumentationPlugin, + InvocationStartInfo, + OperationEndInfo, +) + + +class RecordingWaitPlugin(DurableInstrumentationPlugin): + operation_end_infos: ClassVar[list[dict[str, Any]]] = [] + + @classmethod + def reset(cls) -> None: + cls.operation_end_infos.clear() + + @classmethod + def get_wait_end_infos(cls) -> list[dict[str, Any]]: + return [ + info + for info in cls.operation_end_infos + if info["operation_type"] == "WAIT" and info["name"] == "plugin-wait" + ] + + def on_invocation_start(self, _info: InvocationStartInfo) -> None: + self.reset() + + def on_operation_end(self, info: OperationEndInfo) -> None: + self.operation_end_infos.append( + { + "operation_type": info.operation_type.value, + "name": info.name, + "status": info.status.value, + "is_replayed": info.is_replayed, + "has_end_time": info.end_time is not None, + } + ) + + +@durable_execution(plugins=[RecordingWaitPlugin()]) +def handler(_event: Any, context: DurableContext) -> dict[str, Any]: + context.wait(Duration.from_seconds(1), name="plugin-wait") + return { + "message": "Plugin wait completed", + "wait_end_infos": RecordingWaitPlugin.get_wait_end_infos(), + } diff --git a/packages/aws-durable-execution-sdk-python-examples/template.yaml b/packages/aws-durable-execution-sdk-python-examples/template.yaml index 073757e6..fa1220a7 100644 --- a/packages/aws-durable-execution-sdk-python-examples/template.yaml +++ b/packages/aws-durable-execution-sdk-python-examples/template.yaml @@ -1050,6 +1050,24 @@ } } }, + "ExecutionWithWaitPlugin": { + "Type": "AWS::Serverless::Function", + "Properties": { + "CodeUri": "build/", + "Handler": "execution_with_wait_plugin.handler", + "Description": "Test plugin hook emission for a wait completed during suspend", + "Role": { + "Fn::GetAtt": [ + "DurableFunctionRole", + "Arn" + ] + }, + "DurableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + } + } + }, "ExecutionWithOtel": { "Type": "AWS::Serverless::Function", "Properties": { diff --git a/packages/aws-durable-execution-sdk-python-examples/test/plugin/test_plugin.py b/packages/aws-durable-execution-sdk-python-examples/test/plugin/test_plugin.py index 5e21ba6e..28d041b0 100644 --- a/packages/aws-durable-execution-sdk-python-examples/test/plugin/test_plugin.py +++ b/packages/aws-durable-execution-sdk-python-examples/test/plugin/test_plugin.py @@ -2,8 +2,12 @@ import pytest from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import ( + OperationStatus, + OperationType, +) -from src.plugin import execution_with_plugin +from src.plugin import execution_with_plugin, execution_with_wait_plugin from test.conftest import deserialize_operation_payload @@ -22,3 +26,32 @@ def test_plugin(durable_runner): step_result = result.get_step("add-result-to-2") assert deserialize_operation_payload(step_result.result) == 12 + + +@pytest.mark.example +@pytest.mark.durable_execution( + handler=execution_with_wait_plugin.handler, + lambda_function_name="Plugin Wait", +) +def test_plugin_on_operation_end_called_for_wait_completed_during_suspend( + durable_runner, monkeypatch +): + monkeypatch.setenv("DURABLE_EXECUTION_TIME_SCALE", "0.01") + + with durable_runner: + result = durable_runner.run(input=None, timeout=30) + + assert result.status is InvocationStatus.SUCCEEDED + result_data = deserialize_operation_payload(result.result) + assert result_data["message"] == "Plugin wait completed" + + wait_op = result.get_wait("plugin-wait") + assert wait_op.status is OperationStatus.SUCCEEDED + + wait_end_infos = result_data["wait_end_infos"] + + assert len(wait_end_infos) == 1 + assert wait_end_infos[0]["operation_type"] == OperationType.WAIT.value + assert wait_end_infos[0]["status"] == OperationStatus.SUCCEEDED.value + assert wait_end_infos[0]["is_replayed"] is False + assert wait_end_infos[0]["has_end_time"] is True diff --git a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py index b44fa1d6..edac4e7b 100644 --- a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py +++ b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py @@ -367,6 +367,7 @@ def on_operation_start(self, info: OperationStartInfo) -> None: attributes=attributes, start_time=info.start_time, parent_span=parent_span, + existed=info.is_replayed, ) def on_operation_end(self, info: OperationEndInfo) -> None: diff --git a/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py b/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py index 642dfaa8..dc25bb1a 100644 --- a/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py +++ b/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py @@ -247,6 +247,51 @@ def test_operation_end_without_start_emits_continuation_span_with_link(): ) +def test_replayed_operation_start_emits_continuation_span_with_link(): + """Replayed operation spans should not reuse the original deterministic span ID.""" + plugin, exporter = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + operation_id = "wait-replayed" + random_span_id = int("abcdef1234567890", 16) + plugin._id_generator._fallback_id_generator.generate_span_id = lambda: ( + random_span_id + ) + + plugin.on_operation_start( + OperationStartInfo( + operation_id=operation_id, + operation_type=OperationType.WAIT, + sub_type=OperationSubType.WAIT, + name="replayed-wait", + parent_id=None, + start_time=START_TIME, + is_replayed=True, + status=OperationStatus.SUCCEEDED, + ) + ) + plugin.on_operation_end( + OperationEndInfo( + operation_id=operation_id, + operation_type=OperationType.WAIT, + sub_type=OperationSubType.WAIT, + name="replayed-wait", + parent_id=None, + start_time=START_TIME, + is_replayed=True, + status=OperationStatus.SUCCEEDED, + end_time=END_TIME, + error=None, + ) + ) + + span = exporter.get_finished_spans()[0] + assert span.name == "replayed-wait" + assert span.context.span_id == random_span_id + assert span.links[0].context.span_id == operation_id_to_span_id( + EXECUTION_ARN, operation_id + ) + + def test_step_operation_span_parents_attempt_span(): """STEP operations have a logical span with attempt spans beneath it.""" plugin, exporter = _create_plugin() diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/execution.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/execution.py index a1096f1d..feabeb33 100644 --- a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/execution.py +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/execution.py @@ -60,6 +60,7 @@ def __init__( self.operations: list[Operation] = operations self.updates: list[OperationUpdate] = [] self.invocation_completions: list[InvocationCompletedDetails] = [] + self.updated_operation_ids: list[str] = [] self.used_tokens: set[str] = set() # TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store self._token_sequence: int = 0 @@ -108,6 +109,7 @@ def to_json_dict(self) -> dict[str, Any]: "InvocationCompletions": [ completion.to_json_dict() for completion in self.invocation_completions ], + "UpdatedOperationIds": self.updated_operation_ids, "UsedTokens": list(self.used_tokens), "TokenSequence": self._token_sequence, "IsComplete": self.is_complete, @@ -142,6 +144,7 @@ def from_json_dict(cls, data: dict[str, Any]) -> Execution: InvocationCompletedDetails.from_json_dict(item) for item in data.get("InvocationCompletions", []) ] + execution.updated_operation_ids = list(data.get("UpdatedOperationIds", [])) execution.used_tokens = set(data["UsedTokens"]) execution._token_sequence = data["TokenSequence"] # noqa: SLF001 execution.is_complete = data["IsComplete"] @@ -239,6 +242,12 @@ def record_invocation_completion( request_id=request_id, ) ) + self.updated_operation_ids = [] + + def _record_updated_operation(self, operation_id: str) -> None: + """Remember an operation changed outside the last invocation.""" + if operation_id not in self.updated_operation_ids: + self.updated_operation_ids.append(operation_id) def complete_success(self, result: str | None) -> None: """Complete execution successfully (DecisionType.COMPLETE_WORKFLOW_EXECUTION).""" @@ -319,6 +328,7 @@ def complete_wait(self, operation_id: str) -> Operation: status=OperationStatus.SUCCEEDED, end_timestamp=datetime.now(UTC), ) + self._record_updated_operation(operation_id) return self.operations[index] def complete_retry(self, operation_id: str) -> Operation: @@ -352,6 +362,7 @@ def complete_retry(self, operation_id: str) -> Operation: # Assign self.operations[index] = updated_operation + self._record_updated_operation(operation_id) return updated_operation def complete_callback_success( @@ -378,6 +389,7 @@ def complete_callback_success( end_timestamp=datetime.now(UTC), callback_details=updated_callback_details, ) + self._record_updated_operation(operation.operation_id) return self.operations[index] def complete_callback_failure( @@ -404,6 +416,7 @@ def complete_callback_failure( end_timestamp=datetime.now(UTC), callback_details=updated_callback_details, ) + self._record_updated_operation(operation.operation_id) return self.operations[index] def complete_callback_timeout( @@ -430,6 +443,7 @@ def complete_callback_timeout( end_timestamp=datetime.now(UTC), callback_details=updated_callback_details, ) + self._record_updated_operation(operation.operation_id) return self.operations[index] def _end_execution(self, status: OperationStatus) -> None: diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/invoker.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/invoker.py index 26143c9e..c77d01e9 100644 --- a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/invoker.py +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/invoker.py @@ -121,6 +121,7 @@ def create_invocation_input( operations=execution.operations, next_marker="", ), + updated_operation_ids=list(execution.updated_operation_ids), service_client=self.service_client, ) @@ -215,6 +216,7 @@ def create_invocation_input( operations=execution.operations, next_marker="", ), + updated_operation_ids=list(execution.updated_operation_ids), ) def invoke( diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/execution_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/execution_test.py index 5bde7471..408a0f02 100644 --- a/packages/aws-durable-execution-sdk-python-testing/tests/execution_test.py +++ b/packages/aws-durable-execution-sdk-python-testing/tests/execution_test.py @@ -43,6 +43,7 @@ def test_execution_init(): assert execution.start_input == start_input assert execution.operations == operations assert execution.updates == [] + assert execution.updated_operation_ids == [] assert execution.used_tokens == set() assert execution.token_sequence == 0 assert execution.is_complete is False @@ -164,6 +165,54 @@ def test_get_new_checkpoint_token(): assert token1 != token2 +def test_complete_wait_records_updated_operation_id(): + """Wait completion happens outside the invocation and is reported on the next input.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-id", + ) + execution = Execution( + "test-arn", + start_input, + [ + Operation( + operation_id="wait-1", + operation_type=OperationType.WAIT, + status=OperationStatus.STARTED, + ) + ], + ) + + execution.complete_wait("wait-1") + + assert execution.updated_operation_ids == ["wait-1"] + + +def test_record_invocation_completion_clears_updated_operation_ids(): + """Updated IDs are scoped to the next completed invocation.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-id", + ) + execution = Execution("test-arn", start_input, []) + execution.updated_operation_ids = ["wait-1"] + + now = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + execution.record_invocation_completion(now, now, "request-1") + + assert execution.updated_operation_ids == [] + + def test_get_navigable_operations(): """Test get_navigable_operations method.""" start_input = StartDurableExecutionInput( diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/invoker_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/invoker_test.py index 1270f501..f0be76fa 100644 --- a/packages/aws-durable-execution-sdk-python-testing/tests/invoker_test.py +++ b/packages/aws-durable-execution-sdk-python-testing/tests/invoker_test.py @@ -74,6 +74,7 @@ def test_in_process_invoker_create_invocation_input(): invocation_id="test-invocation-id", ) execution = Execution.new(input_data) + execution.updated_operation_ids = ["wait-1"] invocation_input = invoker.create_invocation_input(execution) @@ -81,6 +82,7 @@ def test_in_process_invoker_create_invocation_input(): assert invocation_input.durable_execution_arn == execution.durable_execution_arn assert invocation_input.checkpoint_token is not None assert isinstance(invocation_input.initial_execution_state, InitialExecutionState) + assert invocation_input.updated_operation_ids == ["wait-1"] assert invocation_input.service_client is service_client diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py index 20c48619..0beb1e4c 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py @@ -509,12 +509,18 @@ def _replay_aware(self): next_exists and next_checkpoint.operation.operation_type is OperationType.STEP ) - # While replaying, an operation that already has a checkpoint was - # observed in a prior invocation — notify plugins (once per op). State + # observed in a prior invocation. If the backend says the operation + # changed since the last invocation, notify plugins as an update rather + # than replayed history; otherwise notify as replayed history. State # owns the dedup; the context owns the "only while replaying" gate. if was_replaying and next_exists: - self.state.emit_operation_replay_hook(next_checkpoint.operation) + if self.state.is_operation_updated_since_last_invocation( + next_checkpoint.operation.operation_id + ): + self.state.emit_operation_update_hook(next_checkpoint.operation) + else: + self.state.emit_operation_replay_hook(next_checkpoint.operation) # Deferred flip applies only to non-step resume points. For step ops we # flip before instead, so don't defer. flip_after: bool = ( diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py index 95d5e756..112bf115 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py @@ -6,7 +6,7 @@ import logging import warnings from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any from aws_durable_execution_sdk_python.context import DurableContext @@ -94,6 +94,7 @@ class DurableExecutionInvocationInput: durable_execution_arn: str checkpoint_token: str initial_execution_state: InitialExecutionState + updated_operation_ids: list[str] = field(default_factory=list, kw_only=True) @staticmethod def from_dict( @@ -105,6 +106,7 @@ def from_dict( initial_execution_state=InitialExecutionState.from_dict( input_dict.get("InitialExecutionState", {}) ), + updated_operation_ids=list(input_dict.get("UpdatedOperationIds", [])), ) @staticmethod @@ -117,6 +119,7 @@ def from_json_dict( initial_execution_state=InitialExecutionState.from_json_dict( input_dict.get("InitialExecutionState", {}) ), + updated_operation_ids=list(input_dict.get("UpdatedOperationIds", [])), ) def to_dict(self) -> MutableMapping[str, Any]: @@ -124,6 +127,7 @@ def to_dict(self) -> MutableMapping[str, Any]: "DurableExecutionArn": self.durable_execution_arn, "CheckpointToken": self.checkpoint_token, "InitialExecutionState": self.initial_execution_state.to_dict(), + "UpdatedOperationIds": self.updated_operation_ids, } def to_json_dict(self) -> MutableMapping[str, Any]: @@ -131,6 +135,7 @@ def to_json_dict(self) -> MutableMapping[str, Any]: "DurableExecutionArn": self.durable_execution_arn, "CheckpointToken": self.checkpoint_token, "InitialExecutionState": self.initial_execution_state.to_json_dict(), + "UpdatedOperationIds": self.updated_operation_ids, } @@ -152,6 +157,7 @@ def from_durable_execution_invocation_input( durable_execution_arn=invocation_input.durable_execution_arn, checkpoint_token=invocation_input.checkpoint_token, initial_execution_state=invocation_input.initial_execution_state, + updated_operation_ids=invocation_input.updated_operation_ids, service_client=service_client, ) @@ -228,6 +234,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: operations={}, service_client=service_client, plugin_executor=plugin_executor, + updated_operation_ids=invocation_input.updated_operation_ids, ) try: diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py index df28dc5b..52536314 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py @@ -267,6 +267,7 @@ def __init__( service_client: DurableServiceClient, plugin_executor: PluginExecutor, batcher_config: CheckpointBatcherConfig | None = None, + updated_operation_ids: list[str] | None = None, ): self.durable_execution_arn: str = durable_execution_arn self._current_checkpoint_token: str = initial_checkpoint_token @@ -303,6 +304,13 @@ def __init__( self._replayed_operation_hooks: set[str] = set() self._replayed_operation_hooks_lock: Lock = Lock() + # Operations changed by the backend since the last successful + # invocation, such as waits, callbacks, invokes, or retry timers that + # completed while the Lambda was suspended. These are not "replayed" + # completions: plugins should observe them as operation updates when the + # replay reaches the operation. + self._updated_operation_ids: set[str] = set(updated_operation_ids or []) + @property def operations(self) -> dict[str, Operation]: """Return a point-in-time snapshot copy of the operations map. @@ -454,6 +462,23 @@ def emit_operation_replay_hook(self, operation: Operation) -> None: self._plugin_executor.on_operation_replay(operation) + def is_operation_updated_since_last_invocation(self, operation_id: str) -> bool: + """Return True if an operation changed while this execution was suspended.""" + return operation_id in self._updated_operation_ids + + def emit_operation_update_hook(self, operation: Operation) -> None: + """Fire the plugin update hook for an operation changed during suspend. + + This method is safe to call for any operation. It emits only for + operations listed in UpdatedOperationIds. + """ + if not self.is_operation_updated_since_last_invocation(operation.operation_id): + return + if operation.operation_type is OperationType.EXECUTION: + return + + self._plugin_executor.on_operation_update(operation) + def create_checkpoint( self, operation_update: OperationUpdate | None = None, diff --git a/packages/aws-durable-execution-sdk-python/tests/context_test.py b/packages/aws-durable-execution-sdk-python/tests/context_test.py index a4aaf136..d0061aa9 100644 --- a/packages/aws-durable-execution-sdk-python/tests/context_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/context_test.py @@ -38,7 +38,10 @@ OperationSubType, OperationType, ) -from aws_durable_execution_sdk_python.plugin import PluginExecutor +from aws_durable_execution_sdk_python.plugin import ( + DurableInstrumentationPlugin, + PluginExecutor, +) from aws_durable_execution_sdk_python.state import ( CheckpointedResult, ExecutionState, @@ -2357,6 +2360,15 @@ def _wait_op(operation_id: str, status: OperationStatus) -> Operation: ) +def _callback_op(operation_id: str, status: OperationStatus) -> Operation: + return Operation( + operation_id=operation_id, + operation_type=OperationType.CALLBACK, + status=status, + callback_details=CallbackDetails(callback_id="callback-1"), + ) + + def test_is_replaying_defaults_to_new_for_fresh_context(): """A context created without a replay seed is not replaying.""" ctx = create_test_context(state=_replay_state({})) @@ -2693,4 +2705,86 @@ def test_replay_aware_does_not_emit_replay_hook_when_not_replaying(): assert emitted == [] +def test_replay_aware_emits_update_hook_for_operation_updated_since_last_invocation(): + """Updated terminal operations emit operation_end, not replay start+end.""" + captured: list[tuple[str, str, bool, OperationStatus]] = [] + + class _CapturingPlugin(DurableInstrumentationPlugin): + def on_operation_start(self, info): + captured.append(("start", info.operation_id, info.is_replayed, info.status)) + + def on_operation_end(self, info): + captured.append(("end", info.operation_id, info.is_replayed, info.status)) + + plugin_executor = PluginExecutor(plugins=[_CapturingPlugin()]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="arn", + initial_checkpoint_token="token", # noqa: S106 + operations={}, + service_client=Mock(), + plugin_executor=plugin_executor, + updated_operation_ids=[], + ) + ctx = DurableContext( + state=state, + execution_context=ExecutionContext(durable_execution_arn="arn"), + replay_status=ReplayStatus.REPLAY, + ) + next_id = ctx._peek_next_operation_id() # noqa: SLF001 + state._operations[next_id] = _wait_op( # noqa: SLF001 + next_id, OperationStatus.SUCCEEDED + ) + state._updated_operation_ids.add(next_id) # noqa: SLF001 + + with ctx._replay_aware(): # noqa: SLF001 + ctx._create_step_id() # noqa: SLF001 + + assert captured == [("end", next_id, False, OperationStatus.SUCCEEDED)] + + +def test_replay_aware_updated_callback_with_following_op_stays_replaying(): + """A completed callback is not itself the replay boundary when later replayed ops exist.""" + captured: list[tuple[str, str, bool, OperationStatus]] = [] + + class _CapturingPlugin(DurableInstrumentationPlugin): + def on_operation_start(self, info): + captured.append(("start", info.operation_id, info.is_replayed, info.status)) + + def on_operation_end(self, info): + captured.append(("end", info.operation_id, info.is_replayed, info.status)) + + plugin_executor = PluginExecutor(plugins=[_CapturingPlugin()]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="arn", + initial_checkpoint_token="token", # noqa: S106 + operations={}, + service_client=Mock(), + plugin_executor=plugin_executor, + updated_operation_ids=[], + ) + ctx = DurableContext( + state=state, + execution_context=ExecutionContext(durable_execution_arn="arn"), + replay_status=ReplayStatus.REPLAY, + ) + callback_id = ctx._create_step_id_for_logical_step(1) # noqa: SLF001 + following_id = ctx._create_step_id_for_logical_step(2) # noqa: SLF001 + state._operations[callback_id] = _callback_op( # noqa: SLF001 + callback_id, OperationStatus.SUCCEEDED + ) + state._operations[following_id] = _step_op( # noqa: SLF001 + following_id, OperationStatus.SUCCEEDED + ) + state._updated_operation_ids.add(callback_id) # noqa: SLF001 + + with ctx._replay_aware(): # noqa: SLF001 + ctx._create_step_id() # noqa: SLF001 + + assert ctx.is_replaying() is True + + assert captured == [("end", callback_id, False, OperationStatus.SUCCEEDED)] + + # endregion per-context replay status diff --git a/packages/aws-durable-execution-sdk-python/tests/execution_test.py b/packages/aws-durable-execution-sdk-python/tests/execution_test.py index 26b235d8..a03fe7f8 100644 --- a/packages/aws-durable-execution-sdk-python/tests/execution_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/execution_test.py @@ -74,6 +74,7 @@ def test_durable_execution_invocation_input_from_dict(): ], "NextMarker": "", }, + "UpdatedOperationIds": ["op-1"], } result = DurableExecutionInvocationInput.from_dict(input_dict) @@ -90,6 +91,7 @@ def test_durable_execution_invocation_input_from_dict(): result.initial_execution_state.operations[0].operation_id == "9692ca80-399d-4f52-8d0a-41acc9cd0492" ) + assert result.updated_operation_ids == ["op-1"] def test_initial_execution_state_from_dict_minimal(): @@ -182,6 +184,7 @@ def test_durable_execution_invocation_input_to_dict(): "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": initial_state.to_dict(), + "UpdatedOperationIds": [], } assert result == expected @@ -201,6 +204,7 @@ def test_durable_execution_invocation_input_to_dict_not_local(): "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": initial_state.to_dict(), + "UpdatedOperationIds": [], } assert result == expected @@ -224,10 +228,12 @@ def test_durable_execution_invocation_input_with_client_inheritance(): "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": initial_state.to_dict(), + "UpdatedOperationIds": [], } assert result == expected assert invocation_input.service_client == mock_client + assert invocation_input.updated_operation_ids == [] def test_durable_execution_invocation_input_with_client_from_parent(): @@ -248,6 +254,7 @@ def test_durable_execution_invocation_input_with_client_from_parent(): assert with_client.durable_execution_arn == parent_input.durable_execution_arn assert with_client.checkpoint_token == parent_input.checkpoint_token assert with_client.initial_execution_state == parent_input.initial_execution_state + assert with_client.updated_operation_ids == parent_input.updated_operation_ids assert with_client.service_client == mock_client @@ -2219,6 +2226,7 @@ def test_durable_execution_invocation_input_to_json_dict_minimal(): "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": initial_state.to_json_dict(), + "UpdatedOperationIds": [], } assert result == expected @@ -2276,6 +2284,7 @@ def test_durable_execution_invocation_input_to_json_dict_empty_operations(): "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": {"Operations": [], "NextMarker": ""}, + "UpdatedOperationIds": [], } assert result == expected @@ -2306,6 +2315,7 @@ def test_durable_execution_invocation_input_from_json_dict_minimal(): assert len(result.initial_execution_state.operations) == 1 assert result.initial_execution_state.next_marker == "test_marker" assert result.initial_execution_state.operations[0].operation_id == "exec1" + assert result.updated_operation_ids == [] def test_durable_execution_invocation_input_from_json_dict_with_timestamps(): @@ -2408,6 +2418,7 @@ def test_durable_execution_invocation_input_json_roundtrip(): durable_execution_arn="arn:test:execution:12345", checkpoint_token="token123456", # noqa: S106 initial_execution_state=initial_state, + updated_operation_ids=["step1"], ) # Convert to JSON dict and back @@ -2417,6 +2428,7 @@ def test_durable_execution_invocation_input_json_roundtrip(): # Verify all top-level fields are preserved assert restored.durable_execution_arn == original.durable_execution_arn assert restored.checkpoint_token == original.checkpoint_token + assert restored.updated_operation_ids == original.updated_operation_ids # Verify initial execution state is preserved assert len(restored.initial_execution_state.operations) == len(