Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,17 @@
},
"path": "./src/plugin/execution_with_plugin.py"
},
{
"name": "Plugin Wait",
"description": "Test plugin hook emission for a wait completed during suspend",
"handler": "execution_with_wait_plugin.handler",
"integration": true,
"durableConfig": {
"RetentionPeriodInDays": 7,
"ExecutionTimeout": 300
},
"path": "./src/plugin/execution_with_wait_plugin.py"
},
{
"name": "Otel Plugin",
"description": "Test Otel plugin",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Demonstrates plugin hooks for a wait that completes while suspended."""

from typing import Any, ClassVar

from aws_durable_execution_sdk_python.config import Duration
from aws_durable_execution_sdk_python.context import DurableContext
from aws_durable_execution_sdk_python.execution import durable_execution
from aws_durable_execution_sdk_python.plugin import (
DurableInstrumentationPlugin,
InvocationStartInfo,
OperationEndInfo,
)


class RecordingWaitPlugin(DurableInstrumentationPlugin):
operation_end_infos: ClassVar[list[dict[str, Any]]] = []

@classmethod
def reset(cls) -> None:
cls.operation_end_infos.clear()

@classmethod
def get_wait_end_infos(cls) -> list[dict[str, Any]]:
return [
info
for info in cls.operation_end_infos
if info["operation_type"] == "WAIT" and info["name"] == "plugin-wait"
]

def on_invocation_start(self, _info: InvocationStartInfo) -> None:
self.reset()

def on_operation_end(self, info: OperationEndInfo) -> None:
self.operation_end_infos.append(
{
"operation_type": info.operation_type.value,
"name": info.name,
"status": info.status.value,
"is_replayed": info.is_replayed,
"has_end_time": info.end_time is not None,
}
)


@durable_execution(plugins=[RecordingWaitPlugin()])
def handler(_event: Any, context: DurableContext) -> dict[str, Any]:
context.wait(Duration.from_seconds(1), name="plugin-wait")
return {
"message": "Plugin wait completed",
"wait_end_infos": RecordingWaitPlugin.get_wait_end_infos(),
}
18 changes: 18 additions & 0 deletions packages/aws-durable-execution-sdk-python-examples/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,24 @@
}
}
},
"ExecutionWithWaitPlugin": {
"Type": "AWS::Serverless::Function",
"Properties": {
"CodeUri": "build/",
"Handler": "execution_with_wait_plugin.handler",
"Description": "Test plugin hook emission for a wait completed during suspend",
"Role": {
"Fn::GetAtt": [
"DurableFunctionRole",
"Arn"
]
},
"DurableConfig": {
"RetentionPeriodInDays": 7,
"ExecutionTimeout": 300
}
}
},
"ExecutionWithOtel": {
"Type": "AWS::Serverless::Function",
"Properties": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@

import pytest
from aws_durable_execution_sdk_python.execution import InvocationStatus
from aws_durable_execution_sdk_python.lambda_service import (
OperationStatus,
OperationType,
)

from src.plugin import execution_with_plugin
from src.plugin import execution_with_plugin, execution_with_wait_plugin
from test.conftest import deserialize_operation_payload


Expand All @@ -22,3 +26,32 @@ def test_plugin(durable_runner):

step_result = result.get_step("add-result-to-2")
assert deserialize_operation_payload(step_result.result) == 12


@pytest.mark.example
@pytest.mark.durable_execution(
handler=execution_with_wait_plugin.handler,
lambda_function_name="Plugin Wait",
)
def test_plugin_on_operation_end_called_for_wait_completed_during_suspend(
durable_runner, monkeypatch
):
monkeypatch.setenv("DURABLE_EXECUTION_TIME_SCALE", "0.01")

with durable_runner:
result = durable_runner.run(input=None, timeout=30)

assert result.status is InvocationStatus.SUCCEEDED
result_data = deserialize_operation_payload(result.result)
assert result_data["message"] == "Plugin wait completed"

wait_op = result.get_wait("plugin-wait")
assert wait_op.status is OperationStatus.SUCCEEDED

wait_end_infos = result_data["wait_end_infos"]

assert len(wait_end_infos) == 1
assert wait_end_infos[0]["operation_type"] == OperationType.WAIT.value
assert wait_end_infos[0]["status"] == OperationStatus.SUCCEEDED.value
assert wait_end_infos[0]["is_replayed"] is False
assert wait_end_infos[0]["has_end_time"] is True
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def on_operation_start(self, info: OperationStartInfo) -> None:
attributes=attributes,
start_time=info.start_time,
parent_span=parent_span,
existed=info.is_replayed,
)

def on_operation_end(self, info: OperationEndInfo) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,51 @@ def test_operation_end_without_start_emits_continuation_span_with_link():
)


def test_replayed_operation_start_emits_continuation_span_with_link():
"""Replayed operation spans should not reuse the original deterministic span ID."""
plugin, exporter = _create_plugin()
plugin.on_invocation_start(_invocation_start_info())
operation_id = "wait-replayed"
random_span_id = int("abcdef1234567890", 16)
plugin._id_generator._fallback_id_generator.generate_span_id = lambda: (
random_span_id
)

plugin.on_operation_start(
OperationStartInfo(
operation_id=operation_id,
operation_type=OperationType.WAIT,
sub_type=OperationSubType.WAIT,
name="replayed-wait",
parent_id=None,
start_time=START_TIME,
is_replayed=True,
status=OperationStatus.SUCCEEDED,
)
)
plugin.on_operation_end(
OperationEndInfo(
operation_id=operation_id,
operation_type=OperationType.WAIT,
sub_type=OperationSubType.WAIT,
name="replayed-wait",
parent_id=None,
start_time=START_TIME,
is_replayed=True,
status=OperationStatus.SUCCEEDED,
end_time=END_TIME,
error=None,
)
)

span = exporter.get_finished_spans()[0]
assert span.name == "replayed-wait"
assert span.context.span_id == random_span_id
assert span.links[0].context.span_id == operation_id_to_span_id(
EXECUTION_ARN, operation_id
)


def test_step_operation_span_parents_attempt_span():
"""STEP operations have a logical span with attempt spans beneath it."""
plugin, exporter = _create_plugin()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
self.operations: list[Operation] = operations
self.updates: list[OperationUpdate] = []
self.invocation_completions: list[InvocationCompletedDetails] = []
self.updated_operation_ids: list[str] = []
self.used_tokens: set[str] = set()
# TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store
self._token_sequence: int = 0
Expand Down Expand Up @@ -108,6 +109,7 @@ def to_json_dict(self) -> dict[str, Any]:
"InvocationCompletions": [
completion.to_json_dict() for completion in self.invocation_completions
],
"UpdatedOperationIds": self.updated_operation_ids,
"UsedTokens": list(self.used_tokens),
"TokenSequence": self._token_sequence,
"IsComplete": self.is_complete,
Expand Down Expand Up @@ -142,6 +144,7 @@ def from_json_dict(cls, data: dict[str, Any]) -> Execution:
InvocationCompletedDetails.from_json_dict(item)
for item in data.get("InvocationCompletions", [])
]
execution.updated_operation_ids = list(data.get("UpdatedOperationIds", []))
execution.used_tokens = set(data["UsedTokens"])
execution._token_sequence = data["TokenSequence"] # noqa: SLF001
execution.is_complete = data["IsComplete"]
Expand Down Expand Up @@ -239,6 +242,12 @@ def record_invocation_completion(
request_id=request_id,
)
)
self.updated_operation_ids = []

def _record_updated_operation(self, operation_id: str) -> None:
"""Remember an operation changed outside the last invocation."""
if operation_id not in self.updated_operation_ids:
self.updated_operation_ids.append(operation_id)

def complete_success(self, result: str | None) -> None:
"""Complete execution successfully (DecisionType.COMPLETE_WORKFLOW_EXECUTION)."""
Expand Down Expand Up @@ -319,6 +328,7 @@ def complete_wait(self, operation_id: str) -> Operation:
status=OperationStatus.SUCCEEDED,
end_timestamp=datetime.now(UTC),
)
self._record_updated_operation(operation_id)
return self.operations[index]

def complete_retry(self, operation_id: str) -> Operation:
Expand Down Expand Up @@ -352,6 +362,7 @@ def complete_retry(self, operation_id: str) -> Operation:

# Assign
self.operations[index] = updated_operation
self._record_updated_operation(operation_id)
return updated_operation

def complete_callback_success(
Expand All @@ -378,6 +389,7 @@ def complete_callback_success(
end_timestamp=datetime.now(UTC),
callback_details=updated_callback_details,
)
self._record_updated_operation(operation.operation_id)
return self.operations[index]

def complete_callback_failure(
Expand All @@ -404,6 +416,7 @@ def complete_callback_failure(
end_timestamp=datetime.now(UTC),
callback_details=updated_callback_details,
)
self._record_updated_operation(operation.operation_id)
return self.operations[index]

def complete_callback_timeout(
Expand All @@ -430,6 +443,7 @@ def complete_callback_timeout(
end_timestamp=datetime.now(UTC),
callback_details=updated_callback_details,
)
self._record_updated_operation(operation.operation_id)
return self.operations[index]

def _end_execution(self, status: OperationStatus) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def create_invocation_input(
operations=execution.operations,
next_marker="",
),
updated_operation_ids=list(execution.updated_operation_ids),
service_client=self.service_client,
)

Expand Down Expand Up @@ -215,6 +216,7 @@ def create_invocation_input(
operations=execution.operations,
next_marker="",
),
updated_operation_ids=list(execution.updated_operation_ids),
)

def invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_execution_init():
assert execution.start_input == start_input
assert execution.operations == operations
assert execution.updates == []
assert execution.updated_operation_ids == []
assert execution.used_tokens == set()
assert execution.token_sequence == 0
assert execution.is_complete is False
Expand Down Expand Up @@ -164,6 +165,54 @@ def test_get_new_checkpoint_token():
assert token1 != token2


def test_complete_wait_records_updated_operation_id():
"""Wait completion happens outside the invocation and is reported on the next input."""
start_input = StartDurableExecutionInput(
account_id="123456789012",
function_name="test-function",
function_qualifier="$LATEST",
execution_name="test-execution",
execution_timeout_seconds=300,
execution_retention_period_days=7,
invocation_id="invocation-id",
)
execution = Execution(
"test-arn",
start_input,
[
Operation(
operation_id="wait-1",
operation_type=OperationType.WAIT,
status=OperationStatus.STARTED,
)
],
)

execution.complete_wait("wait-1")

assert execution.updated_operation_ids == ["wait-1"]


def test_record_invocation_completion_clears_updated_operation_ids():
"""Updated IDs are scoped to the next completed invocation."""
start_input = StartDurableExecutionInput(
account_id="123456789012",
function_name="test-function",
function_qualifier="$LATEST",
execution_name="test-execution",
execution_timeout_seconds=300,
execution_retention_period_days=7,
invocation_id="invocation-id",
)
execution = Execution("test-arn", start_input, [])
execution.updated_operation_ids = ["wait-1"]

now = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
execution.record_invocation_completion(now, now, "request-1")

assert execution.updated_operation_ids == []


def test_get_navigable_operations():
"""Test get_navigable_operations method."""
start_input = StartDurableExecutionInput(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,15 @@ def test_in_process_invoker_create_invocation_input():
invocation_id="test-invocation-id",
)
execution = Execution.new(input_data)
execution.updated_operation_ids = ["wait-1"]

invocation_input = invoker.create_invocation_input(execution)

assert isinstance(invocation_input, DurableExecutionInvocationInputWithClient)
assert invocation_input.durable_execution_arn == execution.durable_execution_arn
assert invocation_input.checkpoint_token is not None
assert isinstance(invocation_input.initial_execution_state, InitialExecutionState)
assert invocation_input.updated_operation_ids == ["wait-1"]
assert invocation_input.service_client is service_client


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -509,12 +509,18 @@ def _replay_aware(self):
next_exists
and next_checkpoint.operation.operation_type is OperationType.STEP
)

# While replaying, an operation that already has a checkpoint was
# observed in a prior invocation — notify plugins (once per op). State
# observed in a prior invocation. If the backend says the operation
# changed since the last invocation, notify plugins as an update rather
# than replayed history; otherwise notify as replayed history. State
# owns the dedup; the context owns the "only while replaying" gate.
if was_replaying and next_exists:
self.state.emit_operation_replay_hook(next_checkpoint.operation)
if self.state.is_operation_updated_since_last_invocation(
next_checkpoint.operation.operation_id
):
self.state.emit_operation_update_hook(next_checkpoint.operation)
else:
self.state.emit_operation_replay_hook(next_checkpoint.operation)
# Deferred flip applies only to non-step resume points. For step ops we
# flip before instead, so don't defer.
flip_after: bool = (
Expand Down
Loading