python-saga-orchestrator 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,81 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import timedelta
4
+ from typing import Any, Generic, TypeVar
5
+ from uuid import UUID
6
+
7
+ from pydantic import BaseModel
8
+ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
9
+
10
+ from ..domain.mixins import SagaStateMixin
11
+ from ..domain.models import SagaDefinition, SagaSnapshot
12
+ from .engine import SagaEngine
13
+ from .repository import SagaRepository
14
+
15
# Concrete saga state model handled by the orchestrator; must use SagaStateMixin.
ModelT = TypeVar("ModelT", bound=SagaStateMixin)
16
+
17
+
18
class SagaOrchestrator(Generic[ModelT]):
    """Public runtime facade over a ``SagaEngine``.

    The orchestrator builds one engine in ``__init__`` and forwards every
    operation to it unchanged; it holds no other state.
    """

    def __init__(
        self,
        *,
        model_class: type[ModelT],
        session_maker: async_sessionmaker[AsyncSession],
        execution_lease: timedelta = timedelta(minutes=5),
    ) -> None:
        """Create the engine this facade delegates to.

        Args:
            model_class: Mapped saga state class (uses ``SagaStateMixin``).
            session_maker: Factory for the async sessions the engine opens.
            execution_lease: Passed straight through to ``SagaEngine``.
        """
        engine = SagaEngine(
            execution_lease=execution_lease,
            session_maker=session_maker,
            model_class=model_class,
        )
        self._engine = engine

    @property
    def engine(self) -> SagaEngine[ModelT]:
        """The ``SagaEngine`` instance created in ``__init__``."""
        engine = self._engine
        return engine

    @property
    def repository(self) -> SagaRepository[ModelT]:
        """The repository exposed by the underlying engine."""
        engine = self._engine
        return engine.repository

    def register(self, name: str, saga_definition: SagaDefinition) -> None:
        """Make *saga_definition* available under the runtime name *name*."""
        engine = self._engine
        engine.register(name, saga_definition)

    async def start(
        self,
        *,
        saga_name: str,
        initial_data: BaseModel | dict[str, Any] | Any,
        aggregation_id: str,
        trace_id: str | None = None,
    ) -> UUID:
        """Start a new instance of the saga registered as *saga_name*.

        Returns:
            The id the engine assigns to the new saga instance.
        """
        engine = self._engine
        saga_id = await engine.start(
            trace_id=trace_id,
            aggregation_id=aggregation_id,
            initial_data=initial_data,
            saga_name=saga_name,
        )
        return saga_id

    async def notify(
        self, *, saga_id: UUID, token: UUID, event: Any | None = None
    ) -> bool:
        """Deliver *event* to a suspended saga guarded by an execution *token*.

        Returns:
            The engine's result; per the engine contract, whether the token
            matched and the saga was resumed.
        """
        engine = self._engine
        accepted = await engine.notify(event=event, token=token, saga_id=saga_id)
        return accepted

    async def run_due(self, *, limit: int = 100) -> int:
        """Ask the engine to process up to *limit* due running, suspended and
        compensating sagas, returning the engine's count."""
        engine = self._engine
        processed = await engine.run_due(limit=limit)
        return processed

    async def get_snapshot(self, saga_id: UUID) -> SagaSnapshot:
        """Fetch the read-only snapshot of saga *saga_id* from the engine."""
        engine = self._engine
        snapshot = await engine.get_snapshot(saga_id)
        return snapshot

    async def resume(self, saga_id: UUID) -> None:
        """Ask the engine to continue forward execution of saga *saga_id*."""
        engine = self._engine
        await engine.resume(saga_id)
@@ -0,0 +1,166 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime
4
+ from typing import Generic, TypeVar
5
+ from uuid import UUID
6
+
7
+ from sqlalchemy import Select, select
8
+ from sqlalchemy.ext.asyncio import AsyncSession
9
+
10
+ from ..domain.exceptions import ActiveSagaAlreadyExistsError, SagaNotFoundError
11
+ from ..domain.mixins import SagaStateMixin
12
+ from ..domain.models.enums import SagaStatus
13
+
14
# Concrete saga state model stored by the repository; must use SagaStateMixin.
ModelT = TypeVar("ModelT", bound=SagaStateMixin)
15
+
16
+
17
class SagaRepository(Generic[ModelT]):
    """Provide persistence operations for saga state rows.

    Every method runs inside the caller's ``AsyncSession`` and none of them
    commits (``create`` only flushes). The caller therefore owns the
    transaction, and with it how long any row locks taken here are held.
    """

    # Statuses that count as "active" when checking for a conflicting saga on
    # the same aggregation id.
    ACTIVE_STATUSES: tuple[SagaStatus, ...] = (
        SagaStatus.RUNNING,
        SagaStatus.SUSPENDED,
        SagaStatus.COMPENSATING,
    )

    def __init__(self, model_class: type[ModelT]) -> None:
        """Initialize the repository for one saga state model.

        Args:
            model_class: Mapped class (using ``SagaStateMixin``) whose table
                stores the saga rows.
        """
        self.model_class = model_class

    async def get(self, session: AsyncSession, saga_id: UUID) -> ModelT:
        """Return one saga row by id without locking it.

        Raises:
            SagaNotFoundError: If no row has *saga_id*.
        """
        return await self._get_by_id(session, saga_id, for_update=False)

    async def get_for_update(self, session: AsyncSession, saga_id: UUID) -> ModelT:
        """Return one saga row by id and lock it for update.

        The lock is a blocking ``SELECT ... FOR UPDATE`` (``nowait=False``):
        a concurrent lock holder makes this call wait instead of failing.

        Raises:
            SagaNotFoundError: If no row has *saga_id*.
        """
        return await self._get_by_id(session, saga_id, for_update=True)

    async def _get_by_id(
        self,
        session: AsyncSession,
        saga_id: UUID,
        *,
        for_update: bool,
    ) -> ModelT:
        """Load one saga by primary key, optionally with a blocking row lock."""
        stmt: Select[tuple[ModelT]] = select(self.model_class).where(
            self.model_class.id == saga_id
        )
        if for_update:
            stmt = stmt.with_for_update(nowait=False)
        result = await session.execute(stmt)
        saga = result.scalar_one_or_none()
        if saga is None:
            raise SagaNotFoundError(f"Saga '{saga_id}' not found")
        return saga

    async def get_active_by_aggregation_id_for_update(
        self,
        session: AsyncSession,
        aggregation_id: str,
    ) -> ModelT | None:
        """Return an active saga for the aggregation id and lock it for update.

        Returns ``None`` when no saga in ``ACTIVE_STATUSES`` uses
        *aggregation_id*. If several do, the oldest one (by ``created_at``,
        then ``id``) is returned rather than raising ``MultipleResultsFound``.
        """
        stmt: Select[tuple[ModelT]] = (
            select(self.model_class)
            .where(
                self.model_class.aggregation_id == aggregation_id,
                self.model_class.status.in_(self.ACTIVE_STATUSES),
            )
            # Nothing in this class enforces a single active saga per
            # aggregation id (see ensure_no_active_aggregation_conflict), so
            # duplicates are possible. Pick one deterministically instead of
            # letting scalar_one_or_none() raise on them.
            .order_by(self.model_class.created_at.asc(), self.model_class.id.asc())
            .limit(1)
            .with_for_update(nowait=False)
        )
        result = await session.execute(stmt)
        return result.scalars().first()

    async def ensure_no_active_aggregation_conflict(
        self,
        session: AsyncSession,
        aggregation_id: str,
    ) -> None:
        """Raise when an active saga exists for the aggregation id.

        Raises:
            ActiveSagaAlreadyExistsError: If a saga in ``ACTIVE_STATUSES``
                already uses *aggregation_id*.

        NOTE(review): ``SELECT ... FOR UPDATE`` only locks rows that already
        exist. Two concurrent transactions that both find no active saga can
        both pass this check and insert. Preventing that needs a database
        constraint (e.g. a partial unique index on ``aggregation_id`` for the
        active statuses) or an explicit lock; confirm what the schema provides.
        """
        existing = await self.get_active_by_aggregation_id_for_update(
            session,
            aggregation_id,
        )
        if existing is not None:
            raise ActiveSagaAlreadyExistsError(
                "Active saga already exists for aggregation_id "
                f"'{aggregation_id}' (saga_id={existing.id})"
            )

    async def create(self, session: AsyncSession, saga: ModelT) -> ModelT:
        """Add a new saga row to the session and flush it (no commit).

        Flushing makes the INSERT run now, so database errors surface here.
        It also populates server-side defaults on the returned object.
        """
        session.add(saga)
        await session.flush()
        return saga

    async def due_suspended(
        self,
        session: AsyncSession,
        now: datetime,
        limit: int,
    ) -> list[ModelT]:
        """Return up to *limit* locked SUSPENDED sagas with ``deadline_at <= now``."""
        return await self._due_by_status(
            session=session,
            status=SagaStatus.SUSPENDED,
            now=now,
            limit=limit,
        )

    async def due_running(
        self,
        session: AsyncSession,
        now: datetime,
        limit: int,
    ) -> list[ModelT]:
        """Return up to *limit* locked RUNNING sagas with ``deadline_at <= now``."""
        return await self._due_by_status(
            session=session,
            status=SagaStatus.RUNNING,
            now=now,
            limit=limit,
        )

    async def due_compensating(
        self,
        session: AsyncSession,
        now: datetime,
        limit: int,
    ) -> list[ModelT]:
        """Return up to *limit* locked COMPENSATING sagas with ``deadline_at <= now``."""
        return await self._due_by_status(
            session=session,
            status=SagaStatus.COMPENSATING,
            now=now,
            limit=limit,
        )

    async def _due_by_status(
        self,
        *,
        session: AsyncSession,
        status: SagaStatus,
        now: datetime,
        limit: int,
    ) -> list[ModelT]:
        """Return due saga rows for one status, earliest deadline first.

        Rows are locked FOR UPDATE. On PostgreSQL the lock uses
        ``SKIP LOCKED``, so rows already locked elsewhere are skipped. A
        non-positive *limit* returns ``[]`` without touching the database.
        """
        if limit <= 0:
            # Nothing requested; this also avoids sending a negative LIMIT,
            # which PostgreSQL rejects.
            return []

        stmt: Select[tuple[ModelT]] = (
            select(self.model_class)
            .where(
                self.model_class.status == status,
                self.model_class.deadline_at.is_not(None),
                self.model_class.deadline_at <= now,
            )
            .order_by(self.model_class.deadline_at.asc())
            .limit(limit)
        )
        if self._supports_skip_locked(session):
            # Concurrent pollers claim disjoint batches instead of queueing
            # behind each other's locks.
            stmt = stmt.with_for_update(skip_locked=True)
        else:
            stmt = stmt.with_for_update(nowait=False)

        result = await session.execute(stmt)
        return list(result.scalars().all())

    @staticmethod
    def _supports_skip_locked(session: AsyncSession) -> bool:
        """Return whether the session's dialect is treated as supporting ``SKIP LOCKED``.

        Only the PostgreSQL dialect is recognised; every other dialect gets a
        blocking ``FOR UPDATE`` in ``_due_by_status``.
        """
        bind = session.get_bind()
        return bind is not None and bind.dialect.name == "postgresql"
@@ -0,0 +1 @@
1
+ """Domain module."""
@@ -0,0 +1,17 @@
1
+ """Domain exceptions module."""
2
+
3
+ from .saga import (
4
+ ActiveSagaAlreadyExistsError,
5
+ SagaDefinitionError,
6
+ SagaNotFoundError,
7
+ SagaStateError,
8
+ TypeValidationError,
9
+ )
10
+
11
# Public exception names re-exported from the ``saga`` submodule.
__all__ = [
    "ActiveSagaAlreadyExistsError",
    "SagaDefinitionError",
    "TypeValidationError",
    "SagaNotFoundError",
    "SagaStateError",
]
@@ -0,0 +1,22 @@
1
class SagaError(Exception):
    """Common base class for every exception defined in this module."""


class SagaNotFoundError(SagaError):
    """A requested saga instance does not exist in persistence."""


class ActiveSagaAlreadyExistsError(SagaError):
    """Another active saga is already bound to the given aggregation key."""


class SagaStateError(SagaError):
    """An operation or transition is not valid for the saga's current state."""


class SagaDefinitionError(SagaError):
    """A saga definition or its registration is invalid."""


class TypeValidationError(SagaDefinitionError):
    """Step type annotations are missing or do not line up."""
@@ -0,0 +1,7 @@
1
+ """Domain mixins module."""
2
+
3
+ from .saga_state import SagaStateMixin
4
+
5
# Public API of the mixins package.
__all__ = [
    "SagaStateMixin",
]
@@ -0,0 +1,57 @@
1
+ from __future__ import annotations
2
+
3
+ import uuid
4
+ from datetime import datetime
5
+ from typing import Any
6
+
7
+ from sqlalchemy import JSON, DateTime, Enum, Integer, String, Text, func
8
+ from sqlalchemy.dialects.postgresql import JSONB, UUID
9
+ from sqlalchemy.ext.mutable import MutableDict, MutableList
10
+ from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column
11
+
12
+ from ..models.enums import SagaStatus
13
+
14
+
15
def _json_type() -> JSON:
    """Build a portable JSON column type that becomes JSONB on PostgreSQL."""
    portable = JSON()
    return portable.with_variant(JSONB, "postgresql")
17
+
18
+
19
@declarative_mixin
class SagaStateMixin:
    """Declarative mixin with the persisted state columns of one saga instance.

    Mix into a mapped class to define a saga state table. ``SagaRepository``
    filters on ``id``, ``aggregation_id``, ``status`` and ``deadline_at``.
    """

    # Primary key; Python-side default generates a uuid4 when none is given.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    # Business key. Indexed but NOT unique, so the schema alone does not stop
    # several active sagas from sharing one aggregation id.
    aggregation_id: Mapped[str] = mapped_column(String(255), index=True)
    # NOT NULL (Mapped[str] without Optional). NOTE(review): the orchestrator
    # accepts trace_id=None, presumably replaced by the engine; confirm.
    trace_id: Mapped[str] = mapped_column(String(255), index=True)
    # Stored through SQLAlchemy's Enum type; new rows default to RUNNING.
    status: Mapped[SagaStatus] = mapped_column(
        Enum(SagaStatus),
        default=SagaStatus.RUNNING,
        index=True,
    )
    current_step_index: Mapped[int] = mapped_column(Integer, default=0)
    # Nullable token; presumably the value that notify() must match — confirm.
    step_execution_token: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        nullable=True,
    )
    # JSON (JSONB on PostgreSQL). MutableDict only tracks top-level key
    # changes; nested in-place edits need reassignment or flag_modified().
    context: Mapped[dict[str, Any]] = mapped_column(
        MutableDict.as_mutable(_json_type()), default=dict
    )
    # JSON list; MutableList tracks appends/removals on the list itself only.
    step_history: Mapped[list[dict[str, Any]]] = mapped_column(
        MutableList.as_mutable(_json_type()), default=list
    )
    # Timezone-aware due time; rows with deadline_at <= now are picked up by
    # SagaRepository's due_* queries. NULL means never due.
    deadline_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        index=True,
    )
    retry_counter: Mapped[int] = mapped_column(Integer, default=0)
    last_error: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Set by the database on INSERT.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.now()
    )
    # Set by the database on INSERT; refreshed by SQLAlchemy on ORM UPDATEs.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
        onupdate=func.now(),
    )
@@ -0,0 +1,21 @@
1
+ """Domain models module."""
2
+
3
+ from .builder import SagaDefinition
4
+ from .retry import ExponentialRetry, FixedRetry, NoRetry, RetryPolicy
5
+ from .saga_snapshot import SagaAdminSnapshot, SagaSnapshot
6
+ from .step import BaseStep, InputContext, StepDefinition, StepInputMap, StepRef
7
+
8
# Public API of the domain models package (re-exported from its submodules).
__all__ = [
    "SagaDefinition",
    "RetryPolicy",
    "NoRetry",
    "FixedRetry",
    "ExponentialRetry",
    "SagaAdminSnapshot",
    "SagaSnapshot",
    "StepRef",
    "InputContext",
    "StepInputMap",
    "StepDefinition",
    "BaseStep",
]
@@ -0,0 +1,12 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ from .step import StepDefinition
7
+
8
+
9
@dataclass(eq=True, frozen=True)
class SagaDefinition:
    """Immutable blueprint of a saga.

    Attributes:
        steps: Step definitions, kept in the given order (presumably the
            execution order — see the engine).
        compensate_on_failure: Failure-handling flag read by the engine;
            ``True`` unless overridden.
    """

    steps: tuple[StepDefinition[Any, Any], ...]
    compensate_on_failure: bool = True
@@ -0,0 +1,7 @@
1
+ """Domain enum models module."""
2
+
3
+ from .saga_status import SagaStatus
4
+
5
# Public API of the enums package.
__all__ = [
    "SagaStatus",
]
@@ -0,0 +1,13 @@
1
+ from enum import Enum
2
+
3
+
4
class SagaStatus(str, Enum):
    """Lifecycle status of a saga instance; each value equals its member name."""

    RUNNING = "RUNNING"
    SUSPENDED = "SUSPENDED"
    FAILED = "FAILED"
    COMPENSATING = "COMPENSATING"
    COMPLETED = "COMPLETED"

    @property
    def is_terminal(self) -> bool:
        """``True`` only for the terminal statuses FAILED and COMPLETED."""
        return self is SagaStatus.FAILED or self is SagaStatus.COMPLETED
@@ -0,0 +1,50 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from datetime import timedelta
5
+
6
+
7
@dataclass(frozen=True)
class RetryPolicy:
    """Base retry policy: decides whether to retry and how long to wait.

    Attributes:
        max_attempts: Largest attempt number that still gets a delay;
            ``next_delay`` returns ``None`` for anything higher.
    """

    max_attempts: int

    def next_delay(self, attempt_number: int) -> timedelta | None:
        """Return the wait before *attempt_number*, or ``None`` to give up.

        Args:
            attempt_number: Attempt being scheduled. ``ExponentialRetry``
                treats 1 as the first attempt (it waits ``base_delay``);
                presumably callers count from 1 — TODO confirm in the engine.

        Returns:
            ``None`` when ``attempt_number > max_attempts``, otherwise the
            subclass-specific delay.
        """
        if attempt_number > self.max_attempts:
            return None
        return self._delay_for_attempt(attempt_number)

    def _delay_for_attempt(self, attempt_number: int) -> timedelta:
        """Compute the delay for an allowed attempt; subclasses override."""
        raise NotImplementedError


@dataclass(frozen=True)
class NoRetry(RetryPolicy):
    """Policy that never retries: ``max_attempts`` is fixed to 0."""

    def __init__(self) -> None:
        # @dataclass keeps an explicitly defined __init__, so NoRetry() takes
        # no arguments; the generated base __init__ sets the frozen field.
        super().__init__(max_attempts=0)

    def _delay_for_attempt(self, attempt_number: int) -> timedelta:
        # Not reached through next_delay for attempt numbers >= 1 (they all
        # exceed max_attempts=0); defined so the policy is concrete.
        return timedelta(seconds=0)


@dataclass(frozen=True)
class FixedRetry(RetryPolicy):
    """Policy that waits the same ``delay`` before every allowed attempt."""

    delay: timedelta

    def _delay_for_attempt(self, attempt_number: int) -> timedelta:
        return self.delay


@dataclass(frozen=True)
class ExponentialRetry(RetryPolicy):
    """Policy with geometric back-off.

    The delay is ``base_delay * multiplier ** (attempt_number - 1)``, with the
    exponent clamped at 0. When ``max_delay`` is set, the delay is capped at
    ``max_delay``.
    """

    base_delay: timedelta
    multiplier: float = 2.0
    max_delay: timedelta | None = None

    def _delay_for_attempt(self, attempt_number: int) -> timedelta:
        base_seconds = self.base_delay.total_seconds()
        exponent = max(attempt_number - 1, 0)
        try:
            seconds = base_seconds * (self.multiplier**exponent)
        except OverflowError:
            # float ** int raises instead of returning inf. Without a cap, or
            # for a non-positive base, keep the original error. Otherwise the
            # value is simply "larger than any cap".
            if self.max_delay is None or base_seconds <= 0:
                raise
            seconds = float("inf")
        # Cap in float seconds *before* building a timedelta. An uncapped
        # late-attempt value can exceed timedelta's range and raise
        # OverflowError even though max_delay would have bounded it.
        if self.max_delay is not None and seconds >= self.max_delay.total_seconds():
            return self.max_delay
        return timedelta(seconds=seconds)
@@ -0,0 +1,36 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from datetime import datetime
5
+ from typing import Any
6
+ from uuid import UUID
7
+
8
+ from .enums import SagaStatus
9
+
10
+
11
@dataclass(frozen=True, eq=True)
class SagaSnapshot:
    """Immutable, public view of one saga's progress.

    The field names mirror the ``SagaStateMixin`` columns of the same name
    (presumably copied from a saga row by the engine — confirm). Field order
    is part of the positional constructor and must not change.
    """

    id: UUID
    aggregation_id: str
    status: SagaStatus
    current_step_index: int
    retry_counter: int
    deadline_at: datetime | None
    trace_id: str
    step_execution_token: UUID | None
    last_error: str | None
22
+
23
+
24
@dataclass(frozen=True, eq=True)
class SagaAdminSnapshot:
    """Immutable, full view of one saga, including its stored context and history.

    Unlike ``SagaSnapshot``, this also carries the ``context`` dict and the
    ``step_history`` list. The field order (which differs from
    ``SagaSnapshot``) is part of the positional constructor.
    """

    id: UUID
    aggregation_id: str
    trace_id: str
    status: SagaStatus
    current_step_index: int
    step_execution_token: UUID | None
    retry_counter: int
    deadline_at: datetime | None
    last_error: str | None
    # Mutable containers: frozen only prevents reassignment, not mutation.
    context: dict[str, Any]
    step_history: list[dict[str, Any]]
@@ -0,0 +1,87 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from collections.abc import Callable
5
+ from dataclasses import dataclass
6
+ from datetime import timedelta
7
+ from typing import Any, Generic, TypeAlias, TypeVar, get_type_hints
8
+
9
+ from pydantic import BaseModel
10
+
11
+ from ..exceptions import TypeValidationError
12
+ from .retry import RetryPolicy
13
+
14
# Type parameters for step input/output models; all bound to pydantic BaseModel.
InputModelT = TypeVar("InputModelT", bound=BaseModel)
OutputModelT = TypeVar("OutputModelT", bound=BaseModel)
# NOTE(review): DepModelT is not referenced in this module; presumably kept for
# external use — confirm before removing.
DepModelT = TypeVar("DepModelT", bound=BaseModel)
17
+
18
+
19
@dataclass(eq=True, frozen=True)
class StepRef(Generic[OutputModelT]):
    """Immutable, typed reference to another step and its output model.

    Used as ``StepDefinition.depends_on``.

    Attributes:
        step_id: Identifier of the referenced step.
        output_model: Model class that the referenced step produces.
    """

    step_id: str
    output_model: type[OutputModelT]
23
+
24
+
25
@dataclass(eq=True)
class InputContext:
    """Data handed to a root input-map callable (see ``RootInputMap``).

    Attributes:
        initial_data: Data the saga was started with.
        context: Saga context dictionary.
        step_outputs: Outputs collected so far (presumably keyed by step id).
        latest_event: Most recent notify event, if any.
        events: Events received so far, or ``None``.
    """

    initial_data: Any
    context: dict[str, Any]
    step_outputs: dict[str, Any]
    latest_event: Any | None = None
    events: list[Any] | None = None
32
+
33
+
34
# Input mapper that receives the full InputContext and returns either an input
# model instance or a plain dict (presumably validated into the step's input
# model by the engine — TODO confirm).
RootInputMap: TypeAlias = Callable[[InputContext], InputModelT | dict[str, Any]]
# Either a root mapper or a mapper taking one untyped argument (presumably the
# output of the step referenced by ``depends_on`` — TODO confirm in the engine).
StepInputMap: TypeAlias = RootInputMap | Callable[[Any], InputModelT | dict[str, Any]]
36
+
37
+
38
@dataclass(eq=True)
class StepDefinition(Generic[InputModelT, OutputModelT]):
    """Wiring for one step inside a saga definition.

    Attributes:
        step_id: Identifier of this step within the saga.
        step: Step implementation; its declared models are exposed below.
        input_map: Callable that builds the step's input.
        timeout: Optional time limit, or ``None``.
        retry_policy: Policy consulted for retries.
        depends_on: Optional reference to the step whose output this one uses.
    """

    step_id: str
    step: BaseStep[InputModelT, OutputModelT]
    input_map: StepInputMap[InputModelT]
    timeout: timedelta | None
    retry_policy: RetryPolicy
    depends_on: StepRef[Any] | None = None

    @property
    def input_model(self) -> type[InputModelT]:
        """Input model class declared by the wrapped step."""
        wrapped = self.step
        return wrapped.input_model

    @property
    def output_model(self) -> type[OutputModelT]:
        """Output model class declared by the wrapped step."""
        wrapped = self.step
        return wrapped.output_model
54
+
55
+
56
class BaseStep(Generic[InputModelT, OutputModelT]):
    """Base class for saga steps whose models come from ``execute``'s hints.

    A subclass implements ``execute(self, inp: In) -> Out`` and, optionally,
    ``compensate``. When the subclass is created, the ``inp`` and return
    annotations are resolved and must both be pydantic ``BaseModel``
    subclasses; they are stored as ``input_model`` / ``output_model``.

    Validation runs for every subclass, including intermediate bases. A
    subclass that does not override ``execute`` is rejected, because the
    inherited hints resolve to type variables, not model classes.
    """

    # Populated on each subclass by __init_subclass__.
    input_model: type[InputModelT]
    output_model: type[OutputModelT]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Validate ``execute``'s annotations and record the model classes.

        Raises:
            TypeValidationError: If ``execute`` lacks an ``inp`` or return
                annotation, if an annotation cannot be resolved, or if either
                one is not a pydantic ``BaseModel`` subclass.
        """
        # Forward class keyword arguments so cooperative multiple inheritance
        # keeps working (PEP 487).
        super().__init_subclass__(**kwargs)

        try:
            hints = get_type_hints(cls.execute)
        except NameError as err:
            # String annotations are resolved against the defining module's
            # globals, so a model defined in a local scope cannot be found.
            raise TypeValidationError(
                f"Step '{cls.__name__}' has unresolvable execute() annotations: {err}"
            ) from err

        if "inp" not in hints or "return" not in hints:
            raise TypeValidationError(
                f"Step '{cls.__name__}' must type annotate execute(inp) and return type"
            )
        input_model = hints["inp"]
        output_model = hints["return"]
        if not (inspect.isclass(input_model) and issubclass(input_model, BaseModel)):
            raise TypeValidationError(
                f"Step '{cls.__name__}' input must inherit from pydantic BaseModel"
            )
        if not (inspect.isclass(output_model) and issubclass(output_model, BaseModel)):
            raise TypeValidationError(
                f"Step '{cls.__name__}' output must inherit from pydantic BaseModel"
            )
        cls.input_model = input_model
        cls.output_model = output_model

    async def execute(self, inp: InputModelT) -> OutputModelT:
        """Run the step's forward action; subclasses must override."""
        raise NotImplementedError

    async def compensate(self, inp: InputModelT, out: OutputModelT) -> None:
        """Undo the effects of ``execute`` given its input and output.

        The base implementation raises ``NotImplementedError``.
        """
        raise NotImplementedError