python-saga-orchestrator 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,753 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import uuid
5
+ from datetime import UTC, datetime, timedelta
6
+ from typing import Any, Generic, TypeVar
7
+ from uuid import UUID
8
+
9
+ from loguru import logger
10
+ from pydantic import BaseModel
11
+ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
12
+
13
+ from ..domain.exceptions import SagaDefinitionError, SagaStateError
14
+ from ..domain.mixins import SagaStateMixin
15
+ from ..domain.models import (
16
+ InputContext,
17
+ SagaAdminSnapshot,
18
+ SagaDefinition,
19
+ SagaSnapshot,
20
+ StepDefinition,
21
+ )
22
+ from ..domain.models.enums import SagaStatus
23
+ from .repository import SagaRepository
24
+
25
+ ModelT = TypeVar("ModelT", bound=SagaStateMixin)  # ORM model type handled by the engine; must derive from SagaStateMixin
26
+
27
+
28
+ class SagaEngine(Generic[ModelT]):
29
+ """Execute, resume, recover, and administrate saga instances."""
30
+
31
def __init__(
    self,
    *,
    model_class: type[ModelT],
    session_maker: async_sessionmaker[AsyncSession],
    execution_lease: timedelta = timedelta(minutes=5),
) -> None:
    """Store the engine collaborators and start with an empty saga registry.

    Args:
        model_class: ORM class instantiated for every new saga and handed to
            the repository.
        session_maker: Factory producing the async sessions used for each
            unit of work.
        execution_lease: Fallback time window added to "now" whenever a
            deadline is needed and the step defines no timeout of its own.
    """
    # Definitions are keyed by the runtime name passed to ``register``.
    self._registry: dict[str, SagaDefinition] = {}
    self._execution_lease = execution_lease
    self._session_maker = session_maker
    self._model_class = model_class
    self._repository = SagaRepository(model_class)
44
+
45
+ @property
46
+ def repository(self) -> SagaRepository[ModelT]:
47
+ """Return the repository used by the engine."""
48
+ return self._repository
49
+
50
+ def register(self, name: str, saga_definition: SagaDefinition) -> None:
51
+ """Register a saga definition under a runtime name."""
52
+ if name in self._registry:
53
+ raise SagaDefinitionError(f"Saga '{name}' is already registered")
54
+ self._registry[name] = saga_definition
55
+
56
async def start(
    self,
    *,
    saga_name: str,
    initial_data: BaseModel | dict[str, Any] | Any,
    aggregation_id: str,
    trace_id: str | None = None,
) -> UUID:
    """Persist a brand-new saga in RUNNING state and drive it forward.

    Args:
        saga_name: Name the definition was registered under.
        initial_data: Payload stored (after serialization) as
            ``context["initial_data"]``.
        aggregation_id: Identifier checked by the repository for an active
            conflicting saga before the new row is created.
        trace_id: Optional trace identifier; a random UUID string is used
            when it is missing or empty.

    Returns:
        The id generated for the new saga.

    Raises:
        SagaDefinitionError: If ``saga_name`` has not been registered.
    """
    if saga_name not in self._registry:
        raise SagaDefinitionError(f"Saga '{saga_name}' is not registered")

    context = {
        "saga_name": saga_name,
        "initial_data": self._serialize_value(initial_data),
        "step_outputs": {},
    }
    effective_trace = trace_id if trace_id else str(uuid.uuid4())
    new_id = uuid.uuid4()
    first_step = self._registry[saga_name].steps[0]
    first_deadline = self._running_deadline_for_step(
        first_step,
        now=datetime.now(UTC),
    )

    # One transaction covers both the conflict check and the insert.
    async with self._session_maker() as session, session.begin():
        await self._repository.ensure_no_active_aggregation_conflict(
            session,
            aggregation_id,
        )
        record = self._model_class(
            id=new_id,
            aggregation_id=aggregation_id,
            trace_id=effective_trace,
            status=SagaStatus.RUNNING,
            current_step_index=0,
            step_execution_token=uuid.uuid4(),
            context=context,
            step_history=[],
            deadline_at=first_deadline,
            retry_counter=0,
        )
        await self._repository.create(session, record)

    # Execution starts only after the creating transaction has finished.
    await self._drive(new_id)
    return new_id
103
+
104
async def notify(
    self, *, saga_id: UUID, token: UUID, event: Any | None = None
) -> bool:
    """Resume a suspended saga when the provided execution token matches.

    Args:
        saga_id: Identifier of the saga to wake up.
        token: Execution token the caller believes is current; a mismatch
            means the notification is stale and is ignored.
        event: Optional payload. When given, its serialized form is appended
            to ``context["events"]`` and stored as ``context["latest_event"]``.

    Returns:
        ``True`` when the saga was switched to RUNNING and driven forward,
        ``False`` when it was not suspended or the token did not match.
    """
    async with self._session_maker() as session:
        async with session.begin():
            saga = await self._repository.get_for_update(session, saga_id)
            if saga.status != SagaStatus.SUSPENDED:
                return False
            if saga.step_execution_token != token:
                # loguru interpolates positional arguments with str.format
                # braces; a "%s" placeholder would be logged literally.
                logger.info("Ignoring stale notify for saga_id={}", saga_id)
                return False
            if event is not None:
                events = saga.context.setdefault("events", [])
                events.append(self._serialize_value(event))
                # Serialized separately so the history list and the
                # "latest" slot never share one mutable object.
                saga.context["latest_event"] = self._serialize_value(event)
            saga.status = SagaStatus.RUNNING
            step_def = self._registry[saga.context["saga_name"]].steps[
                saga.current_step_index
            ]
            saga.deadline_at = self._running_deadline_for_step(
                step_def,
                now=datetime.now(UTC),
            )
            # Rotate the token so the one just consumed cannot be replayed.
            saga.step_execution_token = uuid.uuid4()

    await self._drive(saga_id)
    return True
132
+
133
async def run_due(self, *, limit: int = 100) -> int:
    """Lease every due saga (up to ``limit``) and execute it.

    Due running sagas are fetched first, then due suspended ones, then due
    compensating ones; each later query only receives what is left of
    ``limit``. Forward sagas are driven before compensations are run.

    Returns:
        The number of sagas that were leased by this call.
    """
    now = datetime.now(UTC)
    lease_deadline = now + self._execution_lease
    forward_ids: list[UUID] = []
    rollback_ids: list[UUID] = []

    async with self._session_maker() as session, session.begin():
        running = await self._repository.due_running(
            session,
            now=now,
            limit=limit,
        )
        budget = max(limit - len(running), 0)
        suspended = await self._repository.due_suspended(
            session,
            now=now,
            limit=budget,
        )
        budget -= len(suspended)
        compensating = await self._repository.due_compensating(
            session,
            now=now,
            limit=max(budget, 0),
        )

        # A fresh token and lease mark each saga as claimed by this run.
        for saga in (*running, *suspended):
            saga.status = SagaStatus.RUNNING
            saga.step_execution_token = uuid.uuid4()
            saga.deadline_at = lease_deadline
            forward_ids.append(saga.id)
        for saga in compensating:
            saga.status = SagaStatus.COMPENSATING
            saga.step_execution_token = uuid.uuid4()
            saga.deadline_at = lease_deadline
            rollback_ids.append(saga.id)

    for pending_id in forward_ids:
        await self._drive(pending_id)
    for pending_id in rollback_ids:
        await self._run_compensation(pending_id)

    return len(forward_ids) + len(rollback_ids)
175
+
176
async def get_snapshot(self, saga_id: UUID) -> SagaSnapshot:
    """Return the snapshot view of one saga.

    This is a read-only lookup, so it uses the repository's plain ``get``
    inside an explicit transaction, exactly like ``get_admin_snapshot``.
    The mutating paths of the engine use ``get_for_update`` instead.
    NOTE(review): ``get_for_update`` presumably takes a row lock; confirm
    against the repository implementation.

    Args:
        saga_id: Identifier of the saga to load.

    Returns:
        A ``SagaSnapshot`` built from the loaded saga row.
    """
    async with self._session_maker() as session:
        async with session.begin():
            saga = await self._repository.get(session, saga_id)
            return self._to_snapshot(saga)
181
+
182
async def get_admin_snapshot(self, saga_id: UUID) -> SagaAdminSnapshot:
    """Load one saga and return its full administrative representation.

    Unlike the plain snapshot, the admin view also carries the saga context
    and its step history.
    """
    async with self._session_maker() as session, session.begin():
        record = await self._repository.get(session, saga_id)
        fields = {
            "id": record.id,
            "aggregation_id": record.aggregation_id,
            "trace_id": record.trace_id,
            "status": record.status,
            "current_step_index": record.current_step_index,
            "step_execution_token": record.step_execution_token,
            "retry_counter": record.retry_counter,
            "deadline_at": record.deadline_at,
            "last_error": record.last_error,
            "context": record.context,
            "step_history": record.step_history,
        }
        return SagaAdminSnapshot(**fields)
200
+
201
async def resume(self, saga_id: UUID) -> None:
    """Continue forward execution of the given saga by delegating to ``_drive``."""
    return await self._drive(saga_id)
204
+
205
async def retry_step(self, saga_id: UUID) -> None:
    """Put a suspended or failed saga back to RUNNING and re-run its current step.

    The retry counter and last error are cleared and a new execution token
    is issued before the saga is driven again.

    Raises:
        SagaStateError: If the saga is neither suspended nor failed, if it
            has no current step left, or if it failed after compensation
            entries were already written to its history.
    """
    retryable = (SagaStatus.SUSPENDED, SagaStatus.FAILED)

    async with self._session_maker() as session, session.begin():
        saga = await self._repository.get_for_update(session, saga_id)
        if saga.status not in retryable:
            raise SagaStateError(
                f"Cannot retry step when saga status is {saga.status.value}"
            )

        steps = self._registry[saga.context["saga_name"]].steps
        if saga.current_step_index >= len(steps):
            raise SagaStateError(
                "Cannot retry saga because there is no current step to resume"
            )

        rolled_back = (
            saga.status == SagaStatus.FAILED
            and self._has_compensation_history(saga.step_history)
        )
        if rolled_back:
            raise SagaStateError(
                "Cannot retry saga after compensation has already started"
            )

        current_step = steps[saga.current_step_index]
        saga.status = SagaStatus.RUNNING
        saga.retry_counter = 0
        saga.last_error = None
        saga.step_execution_token = uuid.uuid4()
        saga.deadline_at = self._running_deadline_for_step(
            current_step,
            now=datetime.now(UTC),
        )

    await self._drive(saga_id)
239
+
240
async def compensate_step(self, saga_id: UUID) -> None:
    """Move a saga into COMPENSATING state and run its rollback.

    Raises:
        SagaStateError: If the saga is not suspended, failed or already
            compensating, or if its step index shows nothing to roll back.
    """
    allowed = (
        SagaStatus.SUSPENDED,
        SagaStatus.FAILED,
        SagaStatus.COMPENSATING,
    )

    async with self._session_maker() as session, session.begin():
        saga = await self._repository.get_for_update(session, saga_id)
        if saga.status not in allowed:
            raise SagaStateError(
                "Cannot start compensation unless saga is suspended, failed, "
                f"or already compensating (status={saga.status.value})"
            )
        nothing_completed = saga.current_step_index <= 0
        if nothing_completed:
            raise SagaStateError(
                "Cannot compensate saga because there is no completed step to roll back"
            )

        saga.status = SagaStatus.COMPENSATING
        saga.retry_counter = 0
        saga.step_execution_token = uuid.uuid4()
        saga.deadline_at = datetime.now(UTC) + self._execution_lease

    await self._run_compensation(saga_id)
265
+
266
async def abort(self, saga_id: UUID) -> None:
    """Force a saga into FAILED state.

    The deadline is cleared, an existing error message is kept (otherwise
    "Aborted by admin" is recorded), and the execution token is replaced so
    that results carrying the previous token no longer match.
    """
    async with self._session_maker() as session, session.begin():
        saga = await self._repository.get_for_update(session, saga_id)
        previous_error = saga.last_error
        saga.status = SagaStatus.FAILED
        saga.deadline_at = None
        saga.last_error = previous_error if previous_error else "Aborted by admin"
        saga.step_execution_token = uuid.uuid4()
275
+
276
async def skip_step(
    self,
    saga_id: UUID,
    mock_output: BaseModel | dict[str, Any] | None = None,
) -> None:
    """Mark the current step as successful without running it, then continue.

    Args:
        saga_id: Identifier of the suspended saga.
        mock_output: Output to record for the skipped step. ``None`` is
            treated as an empty payload; a model is dumped to JSON form; a
            dict is used as given. The payload is validated against the
            step's output model.

    Raises:
        SagaStateError: If the saga is not suspended or has no current step.
    """
    should_resume = False

    async with self._session_maker() as session:
        async with session.begin():
            saga = await self._repository.get_for_update(session, saga_id)
            if saga.status != SagaStatus.SUSPENDED:
                raise SagaStateError(
                    "Cannot skip step unless saga is suspended on the current step "
                    f"(status={saga.status.value})"
                )

            saga_name = saga.context["saga_name"]
            definition = self._registry[saga_name]
            if saga.current_step_index >= len(definition.steps):
                raise SagaStateError("No step available for skipping")

            step_def = definition.steps[saga.current_step_index]
            if mock_output is None:
                output_payload: dict[str, Any] = {}
            elif isinstance(mock_output, BaseModel):
                output_payload = mock_output.model_dump(mode="json")
            else:
                output_payload = mock_output

            output_model = step_def.output_model.model_validate(output_payload)
            token = saga.step_execution_token or uuid.uuid4()

            saga.step_history.append(
                {
                    # Same "timestamp" key as the records built by
                    # _history_entry, so every history row can be ordered.
                    "timestamp": datetime.now(UTC).isoformat(),
                    "phase": "execute",
                    "status": "SUCCESS",
                    "step_id": step_def.step_id,
                    "step_name": type(step_def.step).__name__,
                    "attempt": 0,
                    "token": str(token),
                    # Placeholder: the step never ran, so it has no real input.
                    "input": {"_admin": "skip_step"},
                    "output": output_model.model_dump(mode="json"),
                    "error": None,
                    "skipped": True,
                }
            )
            outputs = saga.context.setdefault("step_outputs", {})
            outputs[step_def.step_id] = output_model.model_dump(mode="json")
            saga.current_step_index += 1
            saga.retry_counter = 0
            saga.last_error = None
            saga.step_execution_token = uuid.uuid4()

            if saga.current_step_index >= len(definition.steps):
                # The skipped step was the last one.
                saga.status = SagaStatus.COMPLETED
                saga.deadline_at = None
            else:
                next_step = definition.steps[saga.current_step_index]
                saga.status = SagaStatus.RUNNING
                saga.deadline_at = self._running_deadline_for_step(
                    next_step,
                    now=datetime.now(UTC),
                )
                should_resume = True

    # Drive only after the skip has been committed.
    if should_resume:
        await self._drive(saga_id)
344
+
345
async def _drive(self, saga_id: UUID) -> None:
    """Run forward steps one after another until the saga cannot progress.

    Each iteration prepares the current step, executes it (under
    ``asyncio.wait_for`` when the step defines a timeout), and persists the
    outcome. The loop ends when nothing can be prepared, when finalization
    reports that execution must not continue, or when the step raised.
    """
    while (prepared := await self._prepare_step(saga_id)) is not None:
        step_def = prepared["step_def"]
        step_token = prepared["step_token"]
        step_input = prepared["step_input"]
        attempt_number = prepared["attempt_number"]

        result: BaseModel | None = None
        failure: Exception | None = None

        try:
            pending = step_def.step.execute(step_input)
            if step_def.timeout is None:
                result = await pending
            else:
                result = await asyncio.wait_for(
                    pending,
                    timeout=step_def.timeout.total_seconds(),
                )
        except Exception as exc:  # noqa: BLE001
            # Any step failure is recorded by _finalize_step, not raised.
            failure = exc

        keep_going = await self._finalize_step(
            saga_id=saga_id,
            step_def=step_def,
            token=step_token,
            step_input=step_input,
            step_output=result,
            error=failure,
            attempt_number=attempt_number,
        )
        if failure is not None or not keep_going:
            return
386
+
387
async def _prepare_step(self, saga_id: UUID) -> dict[str, Any] | None:
    """Fetch the saga and assemble what is needed to run its current step.

    Returns:
        ``None`` when the saga is not RUNNING or has no step left (in which
        case it is marked COMPLETED); otherwise a dict with the keys
        ``step_def``, ``step_token``, ``step_input`` and ``attempt_number``.

    Raises:
        SagaDefinitionError: If the saga context names no registered saga.
    """
    async with self._session_maker() as session, session.begin():
        saga = await self._repository.get_for_update(session, saga_id)
        saga_name = saga.context.get("saga_name")
        is_known = bool(saga_name) and saga_name in self._registry
        if not is_known:
            raise SagaDefinitionError(
                f"Unknown saga registered name in state: '{saga_name}'"
            )

        if saga.status != SagaStatus.RUNNING:
            return None

        steps = self._registry[saga_name].steps
        if saga.current_step_index >= len(steps):
            # Every step has run: close the saga.
            saga.status = SagaStatus.COMPLETED
            saga.deadline_at = None
            saga.last_error = None
            return None

        current_step = steps[saga.current_step_index]
        # Keep the existing token when there is one; otherwise mint one.
        active_token = saga.step_execution_token or uuid.uuid4()
        saga.step_execution_token = active_token
        saga.deadline_at = self._running_deadline_for_step(
            current_step,
            now=datetime.now(UTC),
        )
        prepared = {
            "step_def": current_step,
            "step_token": active_token,
            "step_input": self._build_step_input(current_step, saga.context),
            "attempt_number": saga.retry_counter + 1,
        }

    return prepared
423
+
424
async def _finalize_step(
    self,
    *,
    saga_id: UUID,
    step_def: StepDefinition[Any, Any],
    token: UUID,
    step_input: BaseModel,
    step_output: BaseModel | None,
    error: Exception | None,
    attempt_number: int,
) -> bool:
    """Persist one forward step result and return whether execution continues.

    Args:
        saga_id: Saga the result belongs to.
        step_def: Definition of the step that was executed.
        token: Execution token issued when the step was prepared; the result
            is discarded if the stored token has changed since.
        step_input: Input the step ran with (recorded in history).
        step_output: Value returned by the step, or ``None`` on failure.
        error: Exception raised by the step, or ``None`` on success.
        attempt_number: 1-based attempt number recorded in history.

    Returns:
        ``True`` only when the step succeeded and the saga is still RUNNING
        with further steps to execute; ``False`` in every other case.
    """
    async with self._session_maker() as session:
        async with session.begin():
            saga = await self._repository.get_for_update(session, saga_id)
            if (
                saga.step_execution_token != token
                or saga.status != SagaStatus.RUNNING
            ):
                # loguru interpolates with str.format braces, not "%s".
                logger.info("Stale step result ignored for saga_id={}", saga_id)
                return False

            if error is None and step_output is not None:
                saga.step_history.append(
                    self._history_entry(
                        phase="execute",
                        status="SUCCESS",
                        step_def=step_def,
                        token=token,
                        attempt=attempt_number,
                        step_input=step_input,
                        step_output=step_output,
                        error=None,
                    )
                )
                outputs = saga.context.setdefault("step_outputs", {})
                outputs[step_def.step_id] = self._serialize_value(step_output)
                # The event that woke this step has been consumed.
                saga.context.pop("latest_event", None)
                saga.current_step_index += 1
                saga.retry_counter = 0
                # NOTE(review): the saga stays RUNNING with no deadline
                # until the next _prepare_step sets one; confirm that
                # due_running can recover a row in this state.
                saga.deadline_at = None
                saga.last_error = None
                saga.step_execution_token = uuid.uuid4()

                saga_name = saga.context["saga_name"]
                definition = self._registry[saga_name]
                if saga.current_step_index >= len(definition.steps):
                    saga.status = SagaStatus.COMPLETED
                return saga.status == SagaStatus.RUNNING

            if error is None:
                # The step finished without raising but produced no output.
                # Record that as an ordinary step failure instead of relying
                # on an assert, which raises AssertionError out of the
                # transaction and is removed entirely under ``python -O``.
                error = SagaStateError(
                    f"Step '{step_def.step_id}' completed without returning an output"
                )

            saga.step_history.append(
                self._history_entry(
                    phase="execute",
                    status="ERROR",
                    step_def=step_def,
                    token=token,
                    attempt=attempt_number,
                    step_input=step_input,
                    step_output=None,
                    error=error,
                )
            )

            saga.last_error = repr(error)
            next_attempt = saga.retry_counter + 1
            delay = step_def.retry_policy.next_delay(next_attempt)
            saga.retry_counter = next_attempt
            if delay is not None:
                # The retry policy grants another attempt: park the saga
                # until the delay elapses.
                saga.status = SagaStatus.SUSPENDED
                saga.deadline_at = datetime.now(UTC) + delay
                saga.step_execution_token = uuid.uuid4()
                return False

            saga_name = saga.context["saga_name"]
            definition = self._registry[saga_name]
            if definition.compensate_on_failure and saga.current_step_index > 0:
                saga.status = SagaStatus.COMPENSATING
                saga.deadline_at = datetime.now(UTC) + self._execution_lease
            else:
                saga.status = SagaStatus.FAILED
                saga.deadline_at = None
                return False

    # Reached only on the COMPENSATING branch, after the commit above.
    await self._run_compensation(saga_id)
    return False
510
+
511
async def _run_compensation(self, saga_id: UUID) -> None:
    """Roll back completed steps one at a time until rollback stops.

    Each pass prepares the next compensation, calls the step's
    ``compensate`` with the recorded input and output, and persists the
    outcome. The loop ends when nothing can be prepared or when
    finalization reports that rollback must not continue.
    """
    while (plan := await self._prepare_compensation(saga_id)) is not None:
        step_def = plan["step_def"]
        token = plan["token"]
        original_input = plan["original_input"]
        original_output = plan["original_output"]

        failure: Exception | None = None
        try:
            await step_def.step.compensate(original_input, original_output)
        except Exception as exc:  # noqa: BLE001
            # Recorded by _finalize_compensation rather than raised.
            failure = exc

        proceed = await self._finalize_compensation(
            saga_id=saga_id,
            step_def=step_def,
            token=token,
            original_input=original_input,
            original_output=original_output,
            error=failure,
        )
        if not proceed:
            break
539
+
540
async def _prepare_compensation(self, saga_id: UUID) -> dict[str, Any] | None:
    """Load saga state and prepare the next compensation step.

    The step to roll back is the one just before ``current_step_index``.
    Its input and output are rebuilt from the most recent successful
    "execute" history entry for that step.

    Returns:
        ``None`` when the saga is not COMPENSATING or rollback cannot
        proceed (the saga is then marked FAILED); otherwise a dict with the
        keys ``step_def``, ``token``, ``original_input`` and
        ``original_output``.
    """
    async with self._session_maker() as session:
        async with session.begin():
            saga = await self._repository.get_for_update(session, saga_id)
            if saga.status != SagaStatus.COMPENSATING:
                return None

            saga_name = saga.context["saga_name"]
            definition = self._registry[saga_name]
            if saga.current_step_index <= 0:
                # Nothing left to undo. Clear the deadline, as every other
                # transition into FAILED in this engine does.
                saga.status = SagaStatus.FAILED
                saga.deadline_at = None
                return None

            step_idx = saga.current_step_index - 1
            step_def = definition.steps[step_idx]

            # Newest matching entry wins.
            execution_entry = next(
                (
                    entry
                    for entry in reversed(saga.step_history)
                    if entry.get("phase") == "execute"
                    and entry.get("status") == "SUCCESS"
                    and entry.get("step_id") == step_def.step_id
                ),
                None,
            )

            if execution_entry is None:
                saga.status = SagaStatus.FAILED
                saga.deadline_at = None
                saga.last_error = (
                    "Missing successful execution entry for step "
                    f"'{step_def.step_id}'"
                )
                return None

            try:
                original_input = step_def.input_model.model_validate(
                    execution_entry["input"]
                )
                original_output = step_def.output_model.model_validate(
                    execution_entry["output"]
                )
            except ValueError as exc:
                # pydantic's ValidationError derives from ValueError. An
                # entry written by skip_step stores a placeholder input
                # that may not satisfy the step's input model; without this
                # branch the error escapes the transaction and the saga
                # stays COMPENSATING, failing the same way on every attempt.
                saga.status = SagaStatus.FAILED
                saga.deadline_at = None
                saga.last_error = (
                    "Cannot rebuild compensation payload for step "
                    f"'{step_def.step_id}': {exc!r}"
                )
                return None

            token = uuid.uuid4()
            saga.step_execution_token = token
            saga.deadline_at = datetime.now(UTC) + self._execution_lease
            return {
                "step_def": step_def,
                "token": token,
                "original_input": original_input,
                "original_output": original_output,
            }
588
+
589
async def _finalize_compensation(
    self,
    *,
    saga_id: UUID,
    step_def: StepDefinition[Any, Any],
    token: UUID,
    original_input: BaseModel,
    original_output: BaseModel,
    error: Exception | None,
) -> bool:
    """Record the outcome of one compensation and say whether to keep rolling back.

    Returns:
        ``True`` when the compensation succeeded and earlier steps remain;
        ``False`` when the result is stale, the compensation failed, or the
        rollback has reached the first step (the saga is then FAILED).
    """
    async with self._session_maker() as session, session.begin():
        saga = await self._repository.get_for_update(session, saga_id)
        still_current = (
            saga.status == SagaStatus.COMPENSATING
            and saga.step_execution_token == token
        )
        if not still_current:
            return False

        failed = error is not None
        saga.step_history.append(
            self._history_entry(
                phase="compensate",
                status="ERROR" if failed else "SUCCESS",
                step_def=step_def,
                token=token,
                attempt=1,
                step_input=original_input,
                step_output=original_output,
                error=error,
            )
        )

        if failed:
            # A failed compensation ends the rollback immediately.
            saga.last_error = (
                f"Compensation failed for step '{step_def.step_id}': {error!r}"
            )
            saga.status = SagaStatus.FAILED
            saga.deadline_at = None
            return False

        saga.current_step_index -= 1
        saga.step_execution_token = uuid.uuid4()
        if saga.current_step_index > 0:
            saga.deadline_at = datetime.now(UTC) + self._execution_lease
            return True

        # Every completed step has been undone.
        saga.status = SagaStatus.FAILED
        saga.deadline_at = None
        return False
649
+
650
@staticmethod
def _build_step_input(
    step_def: StepDefinition[Any, Any],
    context: dict[str, Any],
) -> BaseModel:
    """Produce the validated input model for ``step_def``.

    When the step depends on another step, ``input_map`` receives that
    step's stored output re-validated into its output model. Otherwise it
    receives an ``InputContext`` built from the saga context.

    Raises:
        SagaStateError: If the dependency output is missing, or if
            ``input_map`` returns neither the input model nor a dict.
    """
    dependency = step_def.depends_on
    if dependency is None:
        source = InputContext(
            initial_data=context.get("initial_data"),
            context=context,
            step_outputs=context.get("step_outputs", {}),
            latest_event=context.get("latest_event"),
            events=context.get("events"),
        )
    else:
        stored = context.get("step_outputs", {}).get(dependency.step_id)
        if stored is None:
            raise SagaStateError(
                f"Missing dependency output for step '{step_def.depends_on.step_id}'"
            )
        source = dependency.output_model.model_validate(stored)

    mapped = step_def.input_map(source)

    if isinstance(mapped, step_def.input_model):
        return mapped
    if not isinstance(mapped, dict):
        raise SagaStateError(
            f"input_map for step '{step_def.step_id}' "
            f"must return {step_def.input_model.__name__} or dict"
        )
    return step_def.input_model.model_validate(mapped)
685
+
686
def _serialize_value(self, value: Any) -> Any:
    """Convert values into JSON-serializable structures.

    Pydantic models are dumped in JSON mode. Dicts are rebuilt with their
    values converted recursively (keys are left untouched). Lists, tuples,
    sets and frozensets become lists of converted items, so models nested
    in any of them are converted too. ``UUID`` becomes its string form and
    ``datetime`` its ISO 8601 string. Anything else is returned unchanged.

    Args:
        value: Arbitrary value to convert.

    Returns:
        The converted value.
    """
    if isinstance(value, BaseModel):
        return value.model_dump(mode="json")
    if isinstance(value, dict):
        return {key: self._serialize_value(val) for key, val in value.items()}
    if isinstance(value, (list, tuple, set, frozenset)):
        return [self._serialize_value(item) for item in value]
    if isinstance(value, UUID):
        return str(value)
    if isinstance(value, datetime):
        return value.isoformat()
    return value
695
+
696
def _history_entry(
    self,
    *,
    phase: str,
    status: str,
    step_def: StepDefinition[Any, Any],
    token: UUID,
    attempt: int,
    step_input: BaseModel,
    step_output: BaseModel | None,
    error: Exception | None,
) -> dict[str, Any]:
    """Build the dict appended to ``step_history`` for one step attempt.

    The record holds a UTC ISO timestamp, the phase and status labels, the
    step id and the step object's class name, the attempt number, the token
    as a string, the serialized input, the serialized output (or ``None``),
    and ``repr(error)`` (or ``None``).
    """
    record: dict[str, Any] = {"timestamp": datetime.now(UTC).isoformat()}
    record["phase"] = phase
    record["status"] = status
    record["step_id"] = step_def.step_id
    record["step_name"] = type(step_def.step).__name__
    record["attempt"] = attempt
    record["token"] = str(token)
    record["input"] = self._serialize_value(step_input)
    if step_output is None:
        record["output"] = None
    else:
        record["output"] = self._serialize_value(step_output)
    record["error"] = None if error is None else repr(error)
    return record
723
+
724
+ @staticmethod
725
+ def _has_compensation_history(step_history: list[dict[str, Any]]) -> bool:
726
+ """Return whether step history contains a compensation entry."""
727
+ return any(entry.get("phase") == "compensate" for entry in step_history)
728
+
729
+ def _running_deadline_for_step(
730
+ self,
731
+ step_def: StepDefinition[Any, Any],
732
+ *,
733
+ now: datetime,
734
+ ) -> datetime:
735
+ """Return the deadline for the current step execution."""
736
+ if step_def.timeout is not None:
737
+ return now + step_def.timeout
738
+ return now + self._execution_lease
739
+
740
@staticmethod
def _to_snapshot(saga: ModelT) -> SagaSnapshot:
    """Copy the snapshot fields from a saga ORM object into a ``SagaSnapshot``."""
    fields = {
        "id": saga.id,
        "aggregation_id": saga.aggregation_id,
        "status": saga.status,
        "current_step_index": saga.current_step_index,
        "retry_counter": saga.retry_counter,
        "deadline_at": saga.deadline_at,
        "trace_id": saga.trace_id,
        "step_execution_token": saga.step_execution_token,
        "last_error": saga.last_error,
    }
    return SagaSnapshot(**fields)