python-saga-orchestrator 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- python_saga_orchestrator-0.1.0.dist-info/METADATA +341 -0
- python_saga_orchestrator-0.1.0.dist-info/RECORD +25 -0
- python_saga_orchestrator-0.1.0.dist-info/WHEEL +5 -0
- python_saga_orchestrator-0.1.0.dist-info/licenses/LICENSE +21 -0
- python_saga_orchestrator-0.1.0.dist-info/top_level.txt +1 -0
- saga_orchestrator/__init__.py +57 -0
- saga_orchestrator/admin/__init__.py +7 -0
- saga_orchestrator/admin/api.py +47 -0
- saga_orchestrator/core/__init__.py +13 -0
- saga_orchestrator/core/builder.py +106 -0
- saga_orchestrator/core/engine.py +753 -0
- saga_orchestrator/core/orchestrator.py +81 -0
- saga_orchestrator/core/repository.py +166 -0
- saga_orchestrator/domain/__init__.py +1 -0
- saga_orchestrator/domain/exceptions/__init__.py +17 -0
- saga_orchestrator/domain/exceptions/saga.py +22 -0
- saga_orchestrator/domain/mixins/__init__.py +7 -0
- saga_orchestrator/domain/mixins/saga_state.py +57 -0
- saga_orchestrator/domain/models/__init__.py +21 -0
- saga_orchestrator/domain/models/builder.py +12 -0
- saga_orchestrator/domain/models/enums/__init__.py +7 -0
- saga_orchestrator/domain/models/enums/saga_status.py +13 -0
- saga_orchestrator/domain/models/retry.py +50 -0
- saga_orchestrator/domain/models/saga_snapshot.py +36 -0
- saga_orchestrator/domain/models/step.py +87 -0
|
@@ -0,0 +1,753 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import uuid
|
|
5
|
+
from datetime import UTC, datetime, timedelta
|
|
6
|
+
from typing import Any, Generic, TypeVar
|
|
7
|
+
from uuid import UUID
|
|
8
|
+
|
|
9
|
+
from loguru import logger
|
|
10
|
+
from pydantic import BaseModel
|
|
11
|
+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
12
|
+
|
|
13
|
+
from ..domain.exceptions import SagaDefinitionError, SagaStateError
|
|
14
|
+
from ..domain.mixins import SagaStateMixin
|
|
15
|
+
from ..domain.models import (
|
|
16
|
+
InputContext,
|
|
17
|
+
SagaAdminSnapshot,
|
|
18
|
+
SagaDefinition,
|
|
19
|
+
SagaSnapshot,
|
|
20
|
+
StepDefinition,
|
|
21
|
+
)
|
|
22
|
+
from ..domain.models.enums import SagaStatus
|
|
23
|
+
from .repository import SagaRepository
|
|
24
|
+
|
|
25
|
+
# Type variable for the ORM model class the engine works with; any concrete
# type substituted for it must derive from SagaStateMixin.
ModelT = TypeVar("ModelT", bound=SagaStateMixin)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class SagaEngine(Generic[ModelT]):
|
|
29
|
+
"""Execute, resume, recover, and administrate saga instances."""
|
|
30
|
+
|
|
31
|
+
def __init__(
    self,
    *,
    model_class: type[ModelT],
    session_maker: async_sessionmaker[AsyncSession],
    execution_lease: timedelta = timedelta(minutes=5),
) -> None:
    """Store the collaborators the engine needs and start with no sagas.

    Args:
        model_class: ORM class used to create saga rows; it is also handed
            to the repository built here.
        session_maker: Factory for the async sessions opened per unit of
            work.
        execution_lease: Span added to "now" whenever a deadline is needed
            and the step involved declares no timeout of its own.
    """
    # Definitions are added later through register().
    self._registry: dict[str, SagaDefinition] = {}
    self._repository = SagaRepository(model_class)
    self._execution_lease = execution_lease
    self._session_maker = session_maker
    self._model_class = model_class
|
|
44
|
+
|
|
45
|
+
@property
def repository(self) -> SagaRepository[ModelT]:
    """Expose the repository instance this engine was built with."""
    repo = self._repository
    return repo
|
|
49
|
+
|
|
50
|
+
def register(self, name: str, saga_definition: SagaDefinition) -> None:
    """Add a saga definition to the registry under ``name``.

    Raises:
        SagaDefinitionError: If ``name`` is already taken.
    """
    registry = self._registry
    if name not in registry:
        registry[name] = saga_definition
        return
    raise SagaDefinitionError(f"Saga '{name}' is already registered")
|
|
55
|
+
|
|
56
|
+
async def start(
    self,
    *,
    saga_name: str,
    initial_data: BaseModel | dict[str, Any] | Any,
    aggregation_id: str,
    trace_id: str | None = None,
) -> UUID:
    """Persist a fresh saga instance and immediately drive it forward.

    Args:
        saga_name: Registry name of the definition to run.
        initial_data: Payload stored (after serialization) in the saga
            context under ``initial_data``.
        aggregation_id: Identifier passed to the repository's
            active-aggregation conflict check.
        trace_id: Optional trace identifier; a random UUID string is used
            when it is missing or empty.

    Returns:
        The id of the newly created saga.

    Raises:
        SagaDefinitionError: If ``saga_name`` has not been registered.
    """
    if saga_name not in self._registry:
        raise SagaDefinitionError(f"Saga '{saga_name}' is not registered")
    definition = self._registry[saga_name]

    payload = self._serialize_value(initial_data)
    effective_trace_id = trace_id if trace_id else str(uuid.uuid4())
    saga_id = uuid.uuid4()
    first_deadline = self._running_deadline_for_step(
        definition.steps[0],
        now=datetime.now(UTC),
    )
    context = {
        "saga_name": saga_name,
        "initial_data": payload,
        "step_outputs": {},
    }

    async with self._session_maker() as session, session.begin():
        await self._repository.ensure_no_active_aggregation_conflict(
            session,
            aggregation_id,
        )
        saga = self._model_class(
            id=saga_id,
            aggregation_id=aggregation_id,
            trace_id=effective_trace_id,
            status=SagaStatus.RUNNING,
            current_step_index=0,
            step_execution_token=uuid.uuid4(),
            context=context,
            step_history=[],
            deadline_at=first_deadline,
            retry_counter=0,
        )
        await self._repository.create(session, saga)

    # The row is committed at this point; execution runs in its own sessions.
    await self._drive(saga_id)
    return saga_id
|
|
103
|
+
|
|
104
|
+
async def notify(
    self, *, saga_id: UUID, token: UUID, event: Any | None = None
) -> bool:
    """Resume a suspended saga when the provided execution token matches.

    Args:
        saga_id: Saga to resume.
        token: Execution token the caller believes is current.
        event: Optional payload; when given it is serialized, appended to
            the context ``events`` list and stored as ``latest_event``.

    Returns:
        ``True`` when the saga was switched back to RUNNING and driven,
        ``False`` when it is not suspended or the token is stale.
    """
    async with self._session_maker() as session:
        async with session.begin():
            saga = await self._repository.get_for_update(session, saga_id)
            if saga.status != SagaStatus.SUSPENDED:
                return False
            if saga.step_execution_token != token:
                # loguru interpolates positional arguments with str.format
                # braces; a "%s" placeholder would be logged literally.
                logger.info("Ignoring stale notify for saga_id={}", saga_id)
                return False
            if event is not None:
                events = saga.context.setdefault("events", [])
                events.append(self._serialize_value(event))
                saga.context["latest_event"] = self._serialize_value(event)
            saga.status = SagaStatus.RUNNING
            step_def = self._registry[saga.context["saga_name"]].steps[
                saga.current_step_index
            ]
            saga.deadline_at = self._running_deadline_for_step(
                step_def,
                now=datetime.now(UTC),
            )
            # Rotate the token so the one just consumed cannot be replayed.
            saga.step_execution_token = uuid.uuid4()

    await self._drive(saga_id)
    return True
|
|
132
|
+
|
|
133
|
+
async def run_due(self, *, limit: int = 100) -> int:
    """Pick up sagas whose deadline has passed and continue them.

    Running and suspended sagas are driven forward; compensating sagas
    continue their rollback. ``limit`` is shared across the three
    repository queries, in that order.

    Returns:
        The number of sagas that were picked up.
    """
    now = datetime.now(UTC)
    lease_deadline = now + self._execution_lease
    forward_ids: list[UUID] = []
    rollback_ids: list[UUID] = []

    async with self._session_maker() as session, session.begin():
        repo = self._repository
        due_running = await repo.due_running(session, now=now, limit=limit)
        budget = max(limit - len(due_running), 0)
        due_suspended = await repo.due_suspended(
            session,
            now=now,
            limit=budget,
        )
        budget -= len(due_suspended)
        due_compensating = await repo.due_compensating(
            session,
            now=now,
            limit=max(budget, 0),
        )

        # Each claimed saga gets a fresh token and a new lease.
        for saga in (*due_running, *due_suspended):
            saga.status = SagaStatus.RUNNING
            saga.step_execution_token = uuid.uuid4()
            saga.deadline_at = lease_deadline
            forward_ids.append(saga.id)
        for saga in due_compensating:
            saga.status = SagaStatus.COMPENSATING
            saga.step_execution_token = uuid.uuid4()
            saga.deadline_at = lease_deadline
            rollback_ids.append(saga.id)

    # Claims are committed; the actual work runs in separate sessions.
    for forward_id in forward_ids:
        await self._drive(forward_id)
    for rollback_id in rollback_ids:
        await self._run_compensation(rollback_id)

    return len(forward_ids) + len(rollback_ids)
|
|
175
|
+
|
|
176
|
+
async def get_snapshot(self, saga_id: UUID) -> SagaSnapshot:
    """Return the snapshot view of one saga.

    Args:
        saga_id: Saga to load.

    Returns:
        A ``SagaSnapshot`` built from the stored saga row.
    """
    async with self._session_maker() as session:
        async with session.begin():
            # Read-only path: use the plain getter inside a transaction, as
            # get_admin_snapshot does. get_for_update is presumably a
            # locking read (verify against the repository), which a
            # snapshot that never writes does not need.
            saga = await self._repository.get(session, saga_id)
            return self._to_snapshot(saga)
|
|
181
|
+
|
|
182
|
+
async def get_admin_snapshot(self, saga_id: UUID) -> SagaAdminSnapshot:
    """Load one saga and return its full administrative view.

    The admin snapshot copies the listed attributes one-to-one from the
    stored saga, including its context and step history.
    """
    copied_fields = (
        "id",
        "aggregation_id",
        "trace_id",
        "status",
        "current_step_index",
        "step_execution_token",
        "retry_counter",
        "deadline_at",
        "last_error",
        "context",
        "step_history",
    )
    async with self._session_maker() as session, session.begin():
        saga = await self._repository.get(session, saga_id)
        return SagaAdminSnapshot(
            **{name: getattr(saga, name) for name in copied_fields}
        )
|
|
200
|
+
|
|
201
|
+
async def resume(self, saga_id: UUID) -> None:
    """Continue forward execution of the given saga.

    This only delegates to the forward driver; it changes no state itself.
    """
    return await self._drive(saga_id)
|
|
204
|
+
|
|
205
|
+
async def retry_step(self, saga_id: UUID) -> None:
    """Put a suspended or failed saga back on its current step and run it.

    The retry counter and last error are cleared, a new execution token is
    issued and the deadline is recomputed for the current step.

    Raises:
        SagaStateError: If the saga is neither suspended nor failed, has no
            current step left, or failed after compensation had started.
    """
    retryable = (SagaStatus.SUSPENDED, SagaStatus.FAILED)

    async with self._session_maker() as session, session.begin():
        saga = await self._repository.get_for_update(session, saga_id)
        if saga.status not in retryable:
            raise SagaStateError(
                f"Cannot retry step when saga status is {saga.status.value}"
            )

        definition = self._registry[saga.context["saga_name"]]
        if saga.current_step_index >= len(definition.steps):
            raise SagaStateError(
                "Cannot retry saga because there is no current step to resume"
            )
        if saga.status == SagaStatus.FAILED:
            # Once rollback entries exist, the forward path is closed.
            if self._has_compensation_history(saga.step_history):
                raise SagaStateError(
                    "Cannot retry saga after compensation has already started"
                )
        current_step = definition.steps[saga.current_step_index]

        saga.deadline_at = self._running_deadline_for_step(
            current_step,
            now=datetime.now(UTC),
        )
        saga.step_execution_token = uuid.uuid4()
        saga.last_error = None
        saga.retry_counter = 0
        saga.status = SagaStatus.RUNNING

    await self._drive(saga_id)
|
|
239
|
+
|
|
240
|
+
async def compensate_step(self, saga_id: UUID) -> None:
    """Move a saga into the compensating state and run its rollback.

    Raises:
        SagaStateError: If the saga is not suspended, failed or already
            compensating, or if its step index shows nothing to roll back.
    """
    allowed = (
        SagaStatus.SUSPENDED,
        SagaStatus.FAILED,
        SagaStatus.COMPENSATING,
    )

    async with self._session_maker() as session, session.begin():
        saga = await self._repository.get_for_update(session, saga_id)
        if saga.status not in allowed:
            raise SagaStateError(
                "Cannot start compensation unless saga is suspended, failed, "
                f"or already compensating (status={saga.status.value})"
            )
        if saga.current_step_index < 1:
            raise SagaStateError(
                "Cannot compensate saga because there is no completed step to roll back"
            )

        saga.deadline_at = datetime.now(UTC) + self._execution_lease
        saga.step_execution_token = uuid.uuid4()
        saga.retry_counter = 0
        saga.status = SagaStatus.COMPENSATING

    await self._run_compensation(saga_id)
|
|
265
|
+
|
|
266
|
+
async def abort(self, saga_id: UUID) -> None:
    """Force a saga into the failed state.

    The deadline is cleared, an existing error message is kept (otherwise
    "Aborted by admin" is recorded) and the execution token is replaced.
    """
    async with self._session_maker() as session, session.begin():
        saga = await self._repository.get_for_update(session, saga_id)
        previous_error = saga.last_error
        saga.last_error = previous_error if previous_error else "Aborted by admin"
        saga.step_execution_token = uuid.uuid4()
        saga.deadline_at = None
        saga.status = SagaStatus.FAILED
|
|
275
|
+
|
|
276
|
+
async def skip_step(
    self,
    saga_id: UUID,
    mock_output: BaseModel | dict[str, Any] | None = None,
) -> None:
    """Mark the current step as successful without running it and continue.

    Args:
        saga_id: Saga whose current step is skipped.
        mock_output: Output recorded for the skipped step. A model is dumped
            to JSON form, a dict is used as is, and ``None`` means an empty
            payload. The payload is validated against the step's output
            model.

    Raises:
        SagaStateError: If the saga is not suspended or has no current step.
    """
    should_resume = False

    async with self._session_maker() as session:
        async with session.begin():
            saga = await self._repository.get_for_update(session, saga_id)
            if saga.status != SagaStatus.SUSPENDED:
                raise SagaStateError(
                    "Cannot skip step unless saga is suspended on the current step "
                    f"(status={saga.status.value})"
                )

            saga_name = saga.context["saga_name"]
            definition = self._registry[saga_name]
            if saga.current_step_index >= len(definition.steps):
                raise SagaStateError("No step available for skipping")

            step_def = definition.steps[saga.current_step_index]
            if mock_output is None:
                output_payload: dict[str, Any] = {}
            elif isinstance(mock_output, BaseModel):
                output_payload = mock_output.model_dump(mode="json")
            else:
                output_payload = mock_output

            output_model = step_def.output_model.model_validate(output_payload)
            token = saga.step_execution_token or uuid.uuid4()

            saga.step_history.append(
                {
                    # Same leading key as entries built by _history_entry, so
                    # every history record carries a timestamp.
                    "timestamp": datetime.now(UTC).isoformat(),
                    "phase": "execute",
                    "status": "SUCCESS",
                    "step_id": step_def.step_id,
                    "step_name": type(step_def.step).__name__,
                    "attempt": 0,
                    "token": str(token),
                    # Placeholder: the step never ran, so it has no real input.
                    "input": {"_admin": "skip_step"},
                    "output": output_model.model_dump(mode="json"),
                    "error": None,
                    "skipped": True,
                }
            )
            outputs = saga.context.setdefault("step_outputs", {})
            outputs[step_def.step_id] = output_model.model_dump(mode="json")
            # Mirror the normal success path: an event delivered for the
            # skipped step must not be visible to the next step.
            saga.context.pop("latest_event", None)
            saga.current_step_index += 1
            saga.retry_counter = 0
            saga.last_error = None
            saga.step_execution_token = uuid.uuid4()

            if saga.current_step_index >= len(definition.steps):
                saga.status = SagaStatus.COMPLETED
                saga.deadline_at = None
            else:
                next_step = definition.steps[saga.current_step_index]
                saga.status = SagaStatus.RUNNING
                saga.deadline_at = self._running_deadline_for_step(
                    next_step,
                    now=datetime.now(UTC),
                )
                should_resume = True

    if should_resume:
        await self._drive(saga_id)
|
|
344
|
+
|
|
345
|
+
async def _drive(self, saga_id: UUID) -> None:
    """Run forward steps one after another until the saga stops advancing.

    Each iteration prepares the current step, executes it outside any
    database transaction (with the step timeout when one is set), and
    persists the outcome. The loop ends when nothing is left to prepare,
    when finalization says to stop, or when the step raised.
    """
    while (prep := await self._prepare_step(saga_id)) is not None:
        step_def = prep["step_def"]
        step_token = prep["step_token"]
        step_input = prep["step_input"]
        attempt_number = prep["attempt_number"]

        failure: Exception | None = None
        result: BaseModel | None = None
        succeeded = False

        try:
            if step_def.timeout is not None:
                result = await asyncio.wait_for(
                    step_def.step.execute(step_input),
                    timeout=step_def.timeout.total_seconds(),
                )
            else:
                result = await step_def.step.execute(step_input)
            succeeded = True
        except Exception as exc:  # noqa: BLE001
            # Any step failure is recorded rather than propagated.
            failure = exc

        keep_going = await self._finalize_step(
            saga_id=saga_id,
            step_def=step_def,
            token=step_token,
            step_input=step_input,
            step_output=result,
            error=failure,
            attempt_number=attempt_number,
        )
        if not (keep_going and succeeded):
            return
|
|
386
|
+
|
|
387
|
+
async def _prepare_step(self, saga_id: UUID) -> dict[str, Any] | None:
    """Load the saga and set up its current forward step for execution.

    Returns:
        ``None`` when the saga is not running or has no steps left (in
        which case it is marked completed); otherwise a dict with the keys
        ``step_def``, ``step_token``, ``step_input`` and ``attempt_number``.

    Raises:
        SagaDefinitionError: If the saga name stored in the context is
            missing or not registered with this engine.
    """
    async with self._session_maker() as session, session.begin():
        saga = await self._repository.get_for_update(session, saga_id)
        saga_name = saga.context.get("saga_name")
        if not (saga_name and saga_name in self._registry):
            raise SagaDefinitionError(
                f"Unknown saga registered name in state: '{saga_name}'"
            )

        if saga.status != SagaStatus.RUNNING:
            return None

        steps = self._registry[saga_name].steps
        if saga.current_step_index >= len(steps):
            # Every step has already run: close the saga.
            saga.status = SagaStatus.COMPLETED
            saga.deadline_at = None
            saga.last_error = None
            return None

        step_def = steps[saga.current_step_index]
        # Keep the stored token when there is one, otherwise issue a new one.
        token = saga.step_execution_token or uuid.uuid4()
        saga.step_execution_token = token
        saga.deadline_at = self._running_deadline_for_step(
            step_def,
            now=datetime.now(UTC),
        )

        return {
            "step_def": step_def,
            "step_token": token,
            "step_input": self._build_step_input(step_def, saga.context),
            "attempt_number": saga.retry_counter + 1,
        }
|
|
423
|
+
|
|
424
|
+
async def _finalize_step(
    self,
    *,
    saga_id: UUID,
    step_def: StepDefinition[Any, Any],
    token: UUID,
    step_input: BaseModel,
    step_output: BaseModel | None,
    error: Exception | None,
    attempt_number: int,
) -> bool:
    """Persist one forward step result and return whether execution continues.

    Args:
        saga_id: Saga the result belongs to.
        step_def: Definition of the step that was executed.
        token: Execution token issued when the step was prepared.
        step_input: Input the step was called with.
        step_output: Output of the step, ``None`` when it failed.
        error: Exception raised by the step, ``None`` on success.
        attempt_number: 1-based attempt recorded in the history entry.

    Returns:
        ``True`` only when the step succeeded and further steps remain.
        ``False`` when the result is stale, the saga completed, a retry was
        scheduled, or the saga failed (after compensation, if enabled).
    """
    if error is None and step_output is None:
        # A step that returns nothing cannot be recorded as a success. Turn
        # it into an ordinary step failure instead of relying on an assert,
        # which is stripped under ``python -O``.
        error = SagaStateError(f"Step '{step_def.step_id}' returned no output")

    async with self._session_maker() as session:
        async with session.begin():
            saga = await self._repository.get_for_update(session, saga_id)
            if (
                saga.step_execution_token != token
                or saga.status != SagaStatus.RUNNING
            ):
                # loguru interpolates with str.format braces, not "%s".
                logger.info("Stale step result ignored for saga_id={}", saga_id)
                return False

            if error is None:
                saga.step_history.append(
                    self._history_entry(
                        phase="execute",
                        status="SUCCESS",
                        step_def=step_def,
                        token=token,
                        attempt=attempt_number,
                        step_input=step_input,
                        step_output=step_output,
                        error=None,
                    )
                )
                outputs = saga.context.setdefault("step_outputs", {})
                outputs[step_def.step_id] = self._serialize_value(step_output)
                saga.context.pop("latest_event", None)
                saga.current_step_index += 1
                saga.retry_counter = 0
                saga.last_error = None
                saga.step_execution_token = uuid.uuid4()

                saga_name = saga.context["saga_name"]
                definition = self._registry[saga_name]
                if saga.current_step_index >= len(definition.steps):
                    saga.status = SagaStatus.COMPLETED
                    saga.deadline_at = None
                    return False

                # The saga stays RUNNING, so it keeps a deadline for the
                # step about to start (as skip_step does). NOTE(review):
                # assumes the due-saga recovery queries select on
                # deadline_at - confirm against the repository.
                saga.deadline_at = self._running_deadline_for_step(
                    definition.steps[saga.current_step_index],
                    now=datetime.now(UTC),
                )
                return True

            saga.step_history.append(
                self._history_entry(
                    phase="execute",
                    status="ERROR",
                    step_def=step_def,
                    token=token,
                    attempt=attempt_number,
                    step_input=step_input,
                    step_output=None,
                    error=error,
                )
            )

            saga.last_error = repr(error)
            next_attempt = saga.retry_counter + 1
            delay = step_def.retry_policy.next_delay(next_attempt)
            saga.retry_counter = next_attempt
            if delay is not None:
                # The retry policy granted another attempt: suspend until the
                # delay has elapsed.
                saga.status = SagaStatus.SUSPENDED
                saga.deadline_at = datetime.now(UTC) + delay
                saga.step_execution_token = uuid.uuid4()
                return False

            saga_name = saga.context["saga_name"]
            definition = self._registry[saga_name]
            if definition.compensate_on_failure and saga.current_step_index > 0:
                saga.status = SagaStatus.COMPENSATING
                saga.deadline_at = datetime.now(UTC) + self._execution_lease
            else:
                saga.status = SagaStatus.FAILED
                saga.deadline_at = None
                return False

    # Reached only on the COMPENSATING path, after the transaction above has
    # committed; rollback opens its own sessions.
    await self._run_compensation(saga_id)
    return False
|
|
510
|
+
|
|
511
|
+
async def _run_compensation(self, saga_id: UUID) -> None:
    """Roll back steps one at a time until compensation stops.

    Each round prepares the next step to undo, calls its ``compensate``
    hook outside any transaction and persists the outcome. The loop ends
    when nothing is left to prepare or finalization reports a stop.
    """
    while (comp := await self._prepare_compensation(saga_id)) is not None:
        step_def, token = comp["step_def"], comp["token"]
        original_input = comp["original_input"]
        original_output = comp["original_output"]

        failure: Exception | None = None
        try:
            await step_def.step.compensate(original_input, original_output)
        except Exception as exc:  # noqa: BLE001
            # Recorded in the history by the finalizer instead of raised.
            failure = exc

        keep_going = await self._finalize_compensation(
            saga_id=saga_id,
            step_def=step_def,
            token=token,
            original_input=original_input,
            original_output=original_output,
            error=failure,
        )
        if not keep_going:
            break
|
|
539
|
+
|
|
540
|
+
async def _prepare_compensation(self, saga_id: UUID) -> dict[str, Any] | None:
    """Load saga state and prepare the next compensation step.

    The step to undo is the one just before ``current_step_index``; its
    original input and output are rebuilt from the latest successful
    "execute" entry in the step history.

    Returns:
        ``None`` when the saga is not compensating or rollback cannot
        proceed (the saga is then marked failed); otherwise a dict with the
        keys ``step_def``, ``token``, ``original_input`` and
        ``original_output``.
    """
    async with self._session_maker() as session:
        async with session.begin():
            saga = await self._repository.get_for_update(session, saga_id)
            if saga.status != SagaStatus.COMPENSATING:
                return None

            saga_name = saga.context["saga_name"]
            definition = self._registry[saga_name]
            if saga.current_step_index <= 0:
                # Nothing left to roll back. Clear the lease as every other
                # FAILED transition in the engine does.
                saga.status = SagaStatus.FAILED
                saga.deadline_at = None
                return None

            step_idx = saga.current_step_index - 1
            step_def = definition.steps[step_idx]

            # Latest successful execution of the step being rolled back.
            execution_entry = next(
                (
                    entry
                    for entry in reversed(saga.step_history)
                    if entry.get("phase") == "execute"
                    and entry.get("status") == "SUCCESS"
                    and entry.get("step_id") == step_def.step_id
                ),
                None,
            )

            if execution_entry is None:
                saga.status = SagaStatus.FAILED
                saga.deadline_at = None
                saga.last_error = (
                    "Missing successful execution entry for step "
                    f"'{step_def.step_id}'"
                )
                return None

            try:
                original_input = step_def.input_model.model_validate(
                    execution_entry["input"]
                )
                original_output = step_def.output_model.model_validate(
                    execution_entry["output"]
                )
            except ValueError as exc:
                # pydantic's ValidationError derives from ValueError. History
                # that no longer validates (e.g. the placeholder input stored
                # by skip_step) would otherwise raise on every attempt and
                # leave the saga stuck in COMPENSATING; fail it explicitly.
                saga.status = SagaStatus.FAILED
                saga.deadline_at = None
                saga.last_error = (
                    "Cannot rebuild compensation arguments for step "
                    f"'{step_def.step_id}': {exc!r}"
                )
                return None

            token = uuid.uuid4()
            saga.step_execution_token = token
            saga.deadline_at = datetime.now(UTC) + self._execution_lease
            return {
                "step_def": step_def,
                "token": token,
                "original_input": original_input,
                "original_output": original_output,
            }
|
|
588
|
+
|
|
589
|
+
async def _finalize_compensation(
    self,
    *,
    saga_id: UUID,
    step_def: StepDefinition[Any, Any],
    token: UUID,
    original_input: BaseModel,
    original_output: BaseModel,
    error: Exception | None,
) -> bool:
    """Record the outcome of one compensation call.

    Returns:
        ``True`` when the rollback succeeded and earlier steps remain to be
        undone; ``False`` when the result is stale, the compensation
        failed, or the rollback has reached the first step (the saga is
        then marked failed).
    """
    async with self._session_maker() as session, session.begin():
        saga = await self._repository.get_for_update(session, saga_id)
        still_current = (
            saga.status == SagaStatus.COMPENSATING
            and saga.step_execution_token == token
        )
        if not still_current:
            return False

        failed = error is not None
        saga.step_history.append(
            self._history_entry(
                phase="compensate",
                status="ERROR" if failed else "SUCCESS",
                step_def=step_def,
                token=token,
                attempt=1,
                step_input=original_input,
                step_output=original_output,
                error=error,
            )
        )

        if failed:
            saga.last_error = (
                f"Compensation failed for step '{step_def.step_id}': {error!r}"
            )
            saga.status = SagaStatus.FAILED
            saga.deadline_at = None
            return False

        saga.current_step_index -= 1
        saga.step_execution_token = uuid.uuid4()
        if saga.current_step_index > 0:
            # More steps to undo: extend the lease and keep going.
            saga.deadline_at = datetime.now(UTC) + self._execution_lease
            return True

        saga.status = SagaStatus.FAILED
        saga.deadline_at = None
        return False
|
|
649
|
+
|
|
650
|
+
@staticmethod
|
|
651
|
+
def _build_step_input(
|
|
652
|
+
step_def: StepDefinition[Any, Any],
|
|
653
|
+
context: dict[str, Any],
|
|
654
|
+
) -> BaseModel:
|
|
655
|
+
"""Build the input model for one step from saga context."""
|
|
656
|
+
if step_def.depends_on is not None:
|
|
657
|
+
dep_payload = context.get("step_outputs", {}).get(
|
|
658
|
+
step_def.depends_on.step_id
|
|
659
|
+
)
|
|
660
|
+
if dep_payload is None:
|
|
661
|
+
raise SagaStateError(
|
|
662
|
+
f"Missing dependency output for step '{step_def.depends_on.step_id}'"
|
|
663
|
+
)
|
|
664
|
+
dep_model = step_def.depends_on.output_model.model_validate(dep_payload)
|
|
665
|
+
mapped = step_def.input_map(dep_model)
|
|
666
|
+
else:
|
|
667
|
+
mapped = step_def.input_map(
|
|
668
|
+
InputContext(
|
|
669
|
+
initial_data=context.get("initial_data"),
|
|
670
|
+
context=context,
|
|
671
|
+
step_outputs=context.get("step_outputs", {}),
|
|
672
|
+
latest_event=context.get("latest_event"),
|
|
673
|
+
events=context.get("events"),
|
|
674
|
+
)
|
|
675
|
+
)
|
|
676
|
+
|
|
677
|
+
if isinstance(mapped, step_def.input_model):
|
|
678
|
+
return mapped
|
|
679
|
+
if isinstance(mapped, dict):
|
|
680
|
+
return step_def.input_model.model_validate(mapped)
|
|
681
|
+
raise SagaStateError(
|
|
682
|
+
f"input_map for step '{step_def.step_id}' "
|
|
683
|
+
f"must return {step_def.input_model.__name__} or dict"
|
|
684
|
+
)
|
|
685
|
+
|
|
686
|
+
def _serialize_value(self, value: Any) -> Any:
    """Convert values into JSON-serializable structures.

    Pydantic models are dumped in JSON mode. Dicts, lists and tuples are
    walked recursively so models nested inside them are dumped too; tuples
    come back as lists, which is the form they take in JSON anyway. Any
    other value is returned unchanged.
    """
    if isinstance(value, BaseModel):
        return value.model_dump(mode="json")
    if isinstance(value, dict):
        return {key: self._serialize_value(val) for key, val in value.items()}
    if isinstance(value, (list, tuple)):
        # Tuples used to pass through untouched, leaving any model inside
        # them unserialized.
        return [self._serialize_value(item) for item in value]
    return value
|
|
695
|
+
|
|
696
|
+
def _history_entry(
|
|
697
|
+
self,
|
|
698
|
+
*,
|
|
699
|
+
phase: str,
|
|
700
|
+
status: str,
|
|
701
|
+
step_def: StepDefinition[Any, Any],
|
|
702
|
+
token: UUID,
|
|
703
|
+
attempt: int,
|
|
704
|
+
step_input: BaseModel,
|
|
705
|
+
step_output: BaseModel | None,
|
|
706
|
+
error: Exception | None,
|
|
707
|
+
) -> dict[str, Any]:
|
|
708
|
+
"""Return one step history record."""
|
|
709
|
+
return {
|
|
710
|
+
"timestamp": datetime.now(UTC).isoformat(),
|
|
711
|
+
"phase": phase,
|
|
712
|
+
"status": status,
|
|
713
|
+
"step_id": step_def.step_id,
|
|
714
|
+
"step_name": type(step_def.step).__name__,
|
|
715
|
+
"attempt": attempt,
|
|
716
|
+
"token": str(token),
|
|
717
|
+
"input": self._serialize_value(step_input),
|
|
718
|
+
"output": (
|
|
719
|
+
self._serialize_value(step_output) if step_output is not None else None
|
|
720
|
+
),
|
|
721
|
+
"error": repr(error) if error is not None else None,
|
|
722
|
+
}
|
|
723
|
+
|
|
724
|
+
@staticmethod
|
|
725
|
+
def _has_compensation_history(step_history: list[dict[str, Any]]) -> bool:
|
|
726
|
+
"""Return whether step history contains a compensation entry."""
|
|
727
|
+
return any(entry.get("phase") == "compensate" for entry in step_history)
|
|
728
|
+
|
|
729
|
+
def _running_deadline_for_step(
|
|
730
|
+
self,
|
|
731
|
+
step_def: StepDefinition[Any, Any],
|
|
732
|
+
*,
|
|
733
|
+
now: datetime,
|
|
734
|
+
) -> datetime:
|
|
735
|
+
"""Return the deadline for the current step execution."""
|
|
736
|
+
if step_def.timeout is not None:
|
|
737
|
+
return now + step_def.timeout
|
|
738
|
+
return now + self._execution_lease
|
|
739
|
+
|
|
740
|
+
@staticmethod
def _to_snapshot(saga: ModelT) -> SagaSnapshot:
    """Build a ``SagaSnapshot`` by copying the listed fields from the saga."""
    copied_fields = (
        "id",
        "aggregation_id",
        "status",
        "current_step_index",
        "retry_counter",
        "deadline_at",
        "trace_id",
        "step_execution_token",
        "last_error",
    )
    return SagaSnapshot(**{name: getattr(saga, name) for name in copied_fields})
|