edda_framework-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edda/__init__.py +56 -0
- edda/activity.py +505 -0
- edda/app.py +996 -0
- edda/compensation.py +326 -0
- edda/context.py +489 -0
- edda/events.py +505 -0
- edda/exceptions.py +64 -0
- edda/hooks.py +284 -0
- edda/locking.py +322 -0
- edda/outbox/__init__.py +15 -0
- edda/outbox/relayer.py +274 -0
- edda/outbox/transactional.py +112 -0
- edda/pydantic_utils.py +316 -0
- edda/replay.py +799 -0
- edda/retry.py +207 -0
- edda/serialization/__init__.py +9 -0
- edda/serialization/base.py +83 -0
- edda/serialization/json.py +102 -0
- edda/storage/__init__.py +9 -0
- edda/storage/models.py +194 -0
- edda/storage/protocol.py +737 -0
- edda/storage/sqlalchemy_storage.py +1809 -0
- edda/viewer_ui/__init__.py +20 -0
- edda/viewer_ui/app.py +1399 -0
- edda/viewer_ui/components.py +1105 -0
- edda/viewer_ui/data_service.py +880 -0
- edda/visualizer/__init__.py +11 -0
- edda/visualizer/ast_analyzer.py +383 -0
- edda/visualizer/mermaid_generator.py +355 -0
- edda/workflow.py +218 -0
- edda_framework-0.1.0.dist-info/METADATA +748 -0
- edda_framework-0.1.0.dist-info/RECORD +35 -0
- edda_framework-0.1.0.dist-info/WHEEL +4 -0
- edda_framework-0.1.0.dist-info/entry_points.txt +2 -0
- edda_framework-0.1.0.dist-info/licenses/LICENSE +21 -0
edda/storage/sqlalchemy_storage.py
@@ -0,0 +1,1809 @@
+"""
+SQLAlchemy storage implementation for Edda framework.
+
+This module provides a SQLAlchemy-based implementation of the StorageProtocol,
+supporting SQLite, PostgreSQL, and MySQL with database-based exclusive control
+and transactional outbox pattern.
+"""
+
+import json
+import logging
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from contextvars import ContextVar
+from dataclasses import dataclass, field
+from datetime import UTC, datetime, timedelta
+from typing import Any
+
+from sqlalchemy import (
+    CheckConstraint,
+    Column,
+    DateTime,
+    ForeignKeyConstraint,
+    Index,
+    Integer,
+    LargeBinary,
+    String,
+    Text,
+    UniqueConstraint,
+    and_,
+    delete,
+    func,
+    or_,
+    select,
+    text,
+    update,
+)
+from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
+from sqlalchemy.orm import declarative_base
+
+logger = logging.getLogger(__name__)
+
+# Declarative base for ORM models
+Base = declarative_base()
+
+
+# ============================================================================
+# SQLAlchemy ORM Models
+# ============================================================================
+
+
+class SchemaVersion(Base):  # type: ignore[valid-type, misc]
+    """Schema version tracking."""
+
+    __tablename__ = "schema_version"
+
+    version = Column(Integer, primary_key=True)
+    applied_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+    description = Column(Text, nullable=False)
+
+
+class WorkflowDefinition(Base):  # type: ignore[valid-type, misc]
+    """Workflow definition (source code storage)."""
+
+    __tablename__ = "workflow_definitions"
+
+    workflow_name = Column(String(255), nullable=False, primary_key=True)
+    source_hash = Column(String(64), nullable=False, primary_key=True)
+    source_code = Column(Text, nullable=False)
+    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+
+    __table_args__ = (
+        Index("idx_definitions_name", "workflow_name"),
+        Index("idx_definitions_hash", "source_hash"),
+    )
+
+
+class WorkflowInstance(Base):  # type: ignore[valid-type, misc]
+    """Workflow instance with distributed locking support."""
+
+    __tablename__ = "workflow_instances"
+
+    instance_id = Column(String(255), primary_key=True)
+    workflow_name = Column(String(255), nullable=False)
+    source_hash = Column(String(64), nullable=False)
+    owner_service = Column(String(255), nullable=False)
+    status = Column(
+        String(50),
+        nullable=False,
+        server_default=text("'running'"),
+    )
+    current_activity_id = Column(String(255), nullable=True)
+    started_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+    updated_at = Column(
+        DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
+    )
+    input_data = Column(Text, nullable=False)  # JSON
+    output_data = Column(Text, nullable=True)  # JSON
+    locked_by = Column(String(255), nullable=True)
+    locked_at = Column(DateTime(timezone=True), nullable=True)
+    lock_timeout_seconds = Column(Integer, nullable=True)  # None = use global default (300s)
+    lock_expires_at = Column(DateTime(timezone=True), nullable=True)  # Absolute expiry time
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["workflow_name", "source_hash"],
+            ["workflow_definitions.workflow_name", "workflow_definitions.source_hash"],
+        ),
+        CheckConstraint(
+            "status IN ('running', 'completed', 'failed', 'waiting_for_event', "
+            "'waiting_for_timer', 'compensating', 'cancelled')",
+            name="valid_status",
+        ),
+        Index("idx_instances_status", "status"),
+        Index("idx_instances_workflow", "workflow_name"),
+        Index("idx_instances_owner", "owner_service"),
+        Index("idx_instances_locked", "locked_by", "locked_at"),
+        Index("idx_instances_lock_expires", "lock_expires_at"),
+        Index("idx_instances_updated", "updated_at"),
+        Index("idx_instances_hash", "source_hash"),
+    )
+
+
+class WorkflowHistory(Base):  # type: ignore[valid-type, misc]
+    """Workflow execution history (for deterministic replay)."""
+
+    __tablename__ = "workflow_history"
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    instance_id = Column(
+        String(255),
+        nullable=False,
+    )
+    activity_id = Column(String(255), nullable=False)
+    event_type = Column(String(100), nullable=False)
+    data_type = Column(String(10), nullable=False)  # 'json' or 'binary'
+    event_data = Column(Text, nullable=True)  # JSON (when data_type='json')
+    event_data_binary = Column(LargeBinary, nullable=True)  # Binary (when data_type='binary')
+    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["instance_id"],
+            ["workflow_instances.instance_id"],
+            ondelete="CASCADE",
+        ),
+        CheckConstraint(
+            "data_type IN ('json', 'binary')",
+            name="valid_data_type",
+        ),
+        CheckConstraint(
+            "(data_type = 'json' AND event_data IS NOT NULL AND event_data_binary IS NULL) OR "
+            "(data_type = 'binary' AND event_data IS NULL AND event_data_binary IS NOT NULL)",
+            name="data_type_consistency",
+        ),
+        UniqueConstraint("instance_id", "activity_id", name="unique_instance_activity"),
+        Index("idx_history_instance", "instance_id", "activity_id"),
+        Index("idx_history_created", "created_at"),
+    )
+
+
+class WorkflowCompensation(Base):  # type: ignore[valid-type, misc]
+    """Compensation transactions (LIFO stack for Saga pattern)."""
+
+    __tablename__ = "workflow_compensations"
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    instance_id = Column(
+        String(255),
+        nullable=False,
+    )
+    activity_id = Column(String(255), nullable=False)
+    activity_name = Column(String(255), nullable=False)
+    args = Column(Text, nullable=False)  # JSON
+    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["instance_id"],
+            ["workflow_instances.instance_id"],
+            ondelete="CASCADE",
+        ),
+        Index("idx_compensations_instance", "instance_id", "created_at"),
+    )
+
+
+class WorkflowEventSubscription(Base):  # type: ignore[valid-type, misc]
+    """Event subscriptions (for wait_event)."""
+
+    __tablename__ = "workflow_event_subscriptions"
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    instance_id = Column(
+        String(255),
+        nullable=False,
+    )
+    event_type = Column(String(255), nullable=False)
+    activity_id = Column(String(255), nullable=True)
+    timeout_at = Column(DateTime(timezone=True), nullable=True)
+    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["instance_id"],
+            ["workflow_instances.instance_id"],
+            ondelete="CASCADE",
+        ),
+        UniqueConstraint("instance_id", "event_type", name="unique_instance_event"),
+        Index("idx_subscriptions_event", "event_type"),
+        Index("idx_subscriptions_timeout", "timeout_at"),
+        Index("idx_subscriptions_instance", "instance_id"),
+    )
+
+
+class WorkflowTimerSubscription(Base):  # type: ignore[valid-type, misc]
+    """Timer subscriptions (for wait_timer)."""
+
+    __tablename__ = "workflow_timer_subscriptions"
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    instance_id = Column(
+        String(255),
+        nullable=False,
+    )
+    timer_id = Column(String(255), nullable=False)
+    expires_at = Column(DateTime(timezone=True), nullable=False)
+    activity_id = Column(String(255), nullable=True)
+    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["instance_id"],
+            ["workflow_instances.instance_id"],
+            ondelete="CASCADE",
+        ),
+        UniqueConstraint("instance_id", "timer_id", name="unique_instance_timer"),
+        Index("idx_timer_subscriptions_expires", "expires_at"),
+        Index("idx_timer_subscriptions_instance", "instance_id"),
+    )
+
+
+class OutboxEvent(Base):  # type: ignore[valid-type, misc]
+    """Transactional outbox pattern events."""
+
+    __tablename__ = "outbox_events"
+
+    event_id = Column(String(255), primary_key=True)
+    event_type = Column(String(255), nullable=False)
+    event_source = Column(String(255), nullable=False)
+    data_type = Column(String(10), nullable=False)  # 'json' or 'binary'
+    event_data = Column(Text, nullable=True)  # JSON (when data_type='json')
+    event_data_binary = Column(LargeBinary, nullable=True)  # Binary (when data_type='binary')
+    content_type = Column(String(100), nullable=False, server_default=text("'application/json'"))
+    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+    published_at = Column(DateTime(timezone=True), nullable=True)
+    status = Column(String(50), nullable=False, server_default=text("'pending'"))
+    retry_count = Column(Integer, nullable=False, server_default=text("0"))
+    last_error = Column(Text, nullable=True)
+
+    __table_args__ = (
+        CheckConstraint(
+            "status IN ('pending', 'processing', 'published', 'failed', 'invalid', 'expired')",
+            name="valid_outbox_status",
+        ),
+        CheckConstraint(
+            "data_type IN ('json', 'binary')",
+            name="valid_outbox_data_type",
+        ),
+        CheckConstraint(
+            "(data_type = 'json' AND event_data IS NOT NULL AND event_data_binary IS NULL) OR "
+            "(data_type = 'binary' AND event_data IS NULL AND event_data_binary IS NOT NULL)",
+            name="outbox_data_type_consistency",
+        ),
+        Index("idx_outbox_status", "status", "created_at"),
+        Index("idx_outbox_retry", "status", "retry_count"),
+        Index("idx_outbox_published", "published_at"),
+    )
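
The OutboxEvent model above is the persistence half of the transactional outbox: events are written in the same database transaction as workflow state, and a separate relayer (edda/outbox/relayer.py in the file list) later publishes pending rows. As a rough illustration only — not code from this package; the publish hook, batch size, and retry limit are invented for the sketch — a consuming loop over this table could look like:

    # Illustrative sketch, NOT from the package: drain pending outbox rows.
    async def relay_pending_events(session_factory, publish) -> None:
        async with session_factory() as session:
            result = await session.execute(
                select(OutboxEvent)
                .where(OutboxEvent.status == "pending")
                .order_by(OutboxEvent.created_at)
                .limit(100)  # assumed batch size
            )
            for event in result.scalars():
                try:
                    await publish(event)  # assumed broker hook
                    event.status = "published"
                    event.published_at = datetime.now(UTC)
                except Exception as exc:
                    event.retry_count += 1
                    event.last_error = str(exc)
                    event.status = "failed" if event.retry_count >= 5 else "pending"
            await session.commit()
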
+
+
+# Current schema version
+CURRENT_SCHEMA_VERSION = 1
+
+
+# ============================================================================
+# Transaction Context
+# ============================================================================
+
+
+@dataclass
+class TransactionContext:
+    """
+    Transaction context for managing nested transactions.
+
+    Uses savepoints for nested transaction support across all databases.
+    """
+
+    depth: int = 0
+    """Current transaction depth (0 = not in transaction, 1+ = in transaction)"""
+
+    savepoint_stack: list[Any] = field(default_factory=list)
+    """Stack of nested transaction objects for savepoint support"""
+
+    session: "AsyncSession | None" = None
+    """The actual session for this transaction"""
+
+
+# Context variable for transaction state (asyncio-safe)
+_transaction_context: ContextVar[TransactionContext | None] = ContextVar(
+    "_transaction_context", default=None
+)
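
Because the context lives in a ContextVar, each asyncio task sees its own independent transaction stack. A minimal sketch of how depth and savepoint_stack evolve under the begin/commit/rollback methods defined further down (assuming `storage` is an initialized SQLAlchemyStorage):

    # Illustrative sketch, NOT from the package.
    await storage.begin_transaction()     # depth 0 -> 1: session.begin()
    await storage.begin_transaction()     # depth 1 -> 2: begin_nested() pushes a SAVEPOINT
    await storage.rollback_transaction()  # depth 2 -> 1: rolls back to the savepoint only
    await storage.commit_transaction()    # depth 1 -> 0: commits outer work, clears context
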
+
+
+# ============================================================================
+# SQLAlchemyStorage
+# ============================================================================
+
+
+class SQLAlchemyStorage:
+    """
+    SQLAlchemy implementation of StorageProtocol.
+
+    Supports SQLite, PostgreSQL, and MySQL with database-based exclusive control
+    and transactional outbox pattern.
+
+    Transaction Architecture:
+    - Lock operations: Always use separate session (isolated transactions)
+    - History/outbox operations: Use transaction context session when available
+    - Automatic transaction management via @activity decorator
+    """
+
+    def __init__(self, engine: AsyncEngine):
+        """
+        Initialize SQLAlchemy storage.
+
+        Args:
+            engine: SQLAlchemy AsyncEngine instance
+        """
+        self.engine = engine
+
+    async def initialize(self) -> None:
+        """Initialize database connection and create tables."""
+        # Create all tables and indexes
+        async with self.engine.begin() as conn:
+            await conn.run_sync(Base.metadata.create_all)
+
+        # Initialize schema version
+        await self._initialize_schema_version()
+
+    async def close(self) -> None:
+        """Close database connection."""
+        await self.engine.dispose()
+
+    async def _initialize_schema_version(self) -> None:
+        """Initialize schema version for a fresh database."""
+        async with AsyncSession(self.engine) as session:
+            # Check if schema_version table is empty
+            result = await session.execute(select(func.count()).select_from(SchemaVersion))
+            count = result.scalar()
+
+            # If empty, insert current version
+            if count == 0:
+                version = SchemaVersion(
+                    version=CURRENT_SCHEMA_VERSION,
+                    description="Initial schema with workflow_definitions",
+                )
+                session.add(version)
+                await session.commit()
+                logger.info(f"Initialized schema version to {CURRENT_SCHEMA_VERSION}")
+
+    def _get_session_for_operation(self, is_lock_operation: bool = False) -> AsyncSession:
+        """
+        Get the appropriate session for an operation.
+
+        Lock operations ALWAYS use a new session (separate transactions).
+        Other operations prefer: transaction session > new session.
+
+        Args:
+            is_lock_operation: True if this is a lock acquisition/release operation
+
+        Returns:
+            AsyncSession to use for the operation
+        """
+        if is_lock_operation:
+            # Lock operations always use new session
+            return AsyncSession(self.engine, expire_on_commit=False)
+
+        # Check for transaction context session
+        ctx = _transaction_context.get()
+        if ctx is not None and ctx.session is not None:
+            return ctx.session
+
+        # Otherwise create new session
+        return AsyncSession(self.engine, expire_on_commit=False)
+
+    def _is_managed_session(self, session: AsyncSession) -> bool:
+        """Check if session is managed by transaction context."""
+        ctx = _transaction_context.get()
+        return ctx is not None and ctx.session == session
+
+    @asynccontextmanager
+    async def _session_scope(self, session: AsyncSession) -> AsyncIterator[AsyncSession]:
+        """
+        Context manager for session usage.
+
+        If session is managed (transaction context), use it directly without closing.
+        If session is new, manage its lifecycle (commit/rollback/close).
+        """
+        if self._is_managed_session(session):
+            # Managed session: yield without lifecycle management
+            yield session
+        else:
+            # New session: full lifecycle management
+            try:
+                yield session
+                await self._commit_if_not_in_transaction(session)
+            except Exception:
+                await session.rollback()
+                raise
+            finally:
+                await session.close()
+
+    def _get_current_time_expr(self) -> Any:
+        """
+        Get database-specific current time SQL expression.
+
+        Returns:
+            SQLAlchemy function for current time in SQL queries.
+            - SQLite: datetime('now') - returns UTC datetime string
+            - PostgreSQL/MySQL: NOW() - returns timezone-aware datetime
+
+        This method enables cross-database datetime comparisons in SQL queries.
+        """
+        if self.engine.dialect.name == "sqlite":
+            # SQLite: datetime('now') returns UTC datetime as string
+            return func.datetime("now")
+        else:
+            # PostgreSQL/MySQL: NOW() returns timezone-aware datetime
+            return func.now()
+
+    def _make_datetime_comparable(self, column: Any) -> Any:
+        """
+        Make datetime column comparable with current time in SQL queries.
+
+        For SQLite, wraps column in datetime() function to ensure proper comparison.
+        For PostgreSQL/MySQL, returns column as-is (already timezone-aware).
+
+        Args:
+            column: SQLAlchemy Column expression representing a datetime field
+
+        Returns:
+            SQLAlchemy expression suitable for datetime comparison
+
+        Example:
+            >>> # SQLite: datetime(timeout_at) <= datetime('now')
+            >>> # PostgreSQL/MySQL: timeout_at <= NOW()
+            >>> self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
+            >>> <= self._get_current_time_expr()
+        """
+        if self.engine.dialect.name == "sqlite":
+            # SQLite: wrap in datetime() for proper comparison
+            return func.datetime(column)
+        else:
+            # PostgreSQL/MySQL: column is already timezone-aware
+            return column
+
|
+
# -------------------------------------------------------------------------
|
|
466
|
+
# Transaction Management Methods
|
|
467
|
+
# -------------------------------------------------------------------------
|
|
468
|
+
|
|
469
|
+
async def begin_transaction(self) -> None:
|
|
470
|
+
"""
|
|
471
|
+
Begin a new transaction.
|
|
472
|
+
|
|
473
|
+
If a transaction is already in progress, creates a nested transaction
|
|
474
|
+
using savepoints. This is asyncio-safe using ContextVar.
|
|
475
|
+
"""
|
|
476
|
+
ctx = _transaction_context.get()
|
|
477
|
+
|
|
478
|
+
if ctx is None:
|
|
479
|
+
# First transaction - create new context with session
|
|
480
|
+
session = AsyncSession(self.engine, expire_on_commit=False)
|
|
481
|
+
ctx = TransactionContext(session=session)
|
|
482
|
+
_transaction_context.set(ctx)
|
|
483
|
+
|
|
484
|
+
ctx.depth += 1
|
|
485
|
+
|
|
486
|
+
if ctx.depth == 1:
|
|
487
|
+
# Top-level transaction - begin the session transaction
|
|
488
|
+
logger.debug("Beginning top-level transaction")
|
|
489
|
+
await ctx.session.begin() # type: ignore[union-attr]
|
|
490
|
+
else:
|
|
491
|
+
# Nested transaction - use SQLAlchemy's begin_nested() (creates SAVEPOINT)
|
|
492
|
+
nested_tx = await ctx.session.begin_nested() # type: ignore[union-attr]
|
|
493
|
+
ctx.savepoint_stack.append(nested_tx)
|
|
494
|
+
logger.debug(f"Created nested transaction (savepoint) at depth={ctx.depth}")
|
|
495
|
+
|
|
496
|
+
async def commit_transaction(self) -> None:
|
|
497
|
+
"""
|
|
498
|
+
Commit the current transaction.
|
|
499
|
+
|
|
500
|
+
For nested transactions, releases the savepoint.
|
|
501
|
+
For top-level transactions, commits to the database.
|
|
502
|
+
"""
|
|
503
|
+
ctx = _transaction_context.get()
|
|
504
|
+
if ctx is None or ctx.depth == 0:
|
|
505
|
+
raise RuntimeError("Not in a transaction")
|
|
506
|
+
|
|
507
|
+
if ctx.depth == 1:
|
|
508
|
+
# Top-level transaction - commit the session
|
|
509
|
+
logger.debug("Committing top-level transaction")
|
|
510
|
+
await ctx.session.commit() # type: ignore[union-attr]
|
|
511
|
+
await ctx.session.close() # type: ignore[union-attr]
|
|
512
|
+
else:
|
|
513
|
+
# Nested transaction - commit the savepoint
|
|
514
|
+
nested_tx = ctx.savepoint_stack.pop()
|
|
515
|
+
await nested_tx.commit()
|
|
516
|
+
logger.debug(f"Committed nested transaction (savepoint) at depth={ctx.depth}")
|
|
517
|
+
|
|
518
|
+
ctx.depth -= 1
|
|
519
|
+
|
|
520
|
+
if ctx.depth == 0:
|
|
521
|
+
# All transactions completed - clear context
|
|
522
|
+
_transaction_context.set(None)
|
|
523
|
+
|
|
524
|
+
async def rollback_transaction(self) -> None:
|
|
525
|
+
"""
|
|
526
|
+
Rollback the current transaction.
|
|
527
|
+
|
|
528
|
+
For nested transactions, rolls back to the savepoint.
|
|
529
|
+
For top-level transactions, rolls back all changes.
|
|
530
|
+
"""
|
|
531
|
+
ctx = _transaction_context.get()
|
|
532
|
+
if ctx is None or ctx.depth == 0:
|
|
533
|
+
raise RuntimeError("Not in a transaction")
|
|
534
|
+
|
|
535
|
+
if ctx.depth == 1:
|
|
536
|
+
# Top-level transaction - rollback the session
|
|
537
|
+
logger.debug("Rolling back top-level transaction")
|
|
538
|
+
await ctx.session.rollback() # type: ignore[union-attr]
|
|
539
|
+
await ctx.session.close() # type: ignore[union-attr]
|
|
540
|
+
else:
|
|
541
|
+
# Nested transaction - rollback the savepoint
|
|
542
|
+
nested_tx = ctx.savepoint_stack.pop()
|
|
543
|
+
await nested_tx.rollback()
|
|
544
|
+
logger.debug(f"Rolled back nested transaction (savepoint) at depth={ctx.depth}")
|
|
545
|
+
|
|
546
|
+
ctx.depth -= 1
|
|
547
|
+
|
|
548
|
+
if ctx.depth == 0:
|
|
549
|
+
# All transactions rolled back - clear context
|
|
550
|
+
_transaction_context.set(None)
|
|
551
|
+
|
|
552
|
+
def in_transaction(self) -> bool:
|
|
553
|
+
"""
|
|
554
|
+
Check if currently in a transaction.
|
|
555
|
+
|
|
556
|
+
Returns:
|
|
557
|
+
True if in a transaction, False otherwise.
|
|
558
|
+
"""
|
|
559
|
+
ctx = _transaction_context.get()
|
|
560
|
+
return ctx is not None and ctx.depth > 0
|
|
561
|
+
|
|
562
|
+
async def _commit_if_not_in_transaction(self, session: AsyncSession) -> None:
|
|
563
|
+
"""
|
|
564
|
+
Commit session if not in a transaction (auto-commit mode).
|
|
565
|
+
|
|
566
|
+
This helper method ensures that operations outside of explicit transactions
|
|
567
|
+
are still committed, while operations inside transactions are deferred
|
|
568
|
+
until the transaction is committed.
|
|
569
|
+
|
|
570
|
+
Args:
|
|
571
|
+
session: Database session
|
|
572
|
+
"""
|
|
573
|
+
# If this is a transaction context session, don't commit (will be done by commit_transaction)
|
|
574
|
+
ctx = _transaction_context.get()
|
|
575
|
+
if ctx is not None and ctx.session == session:
|
|
576
|
+
return
|
|
577
|
+
|
|
578
|
+
# If not in transaction, commit
|
|
579
|
+
if not self.in_transaction():
|
|
580
|
+
await session.commit()
|
|
581
|
+
|
|
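
The class docstring notes that the @activity decorator normally drives these calls; a hand-rolled equivalent, sketched here with assumed IDs and payloads, shows the auto-commit-versus-deferred behavior that _commit_if_not_in_transaction implements:

    # Illustrative sketch, NOT from the package.
    await storage.begin_transaction()
    try:
        await storage.append_history("wf-1", "a1", "ActivityCompleted", {"ok": True})
        await storage.push_compensation("wf-1", "a1", "undo_reserve", {"order_id": 42})
        await storage.commit_transaction()    # both rows become visible together
    except Exception:
        await storage.rollback_transaction()  # neither row is persisted
        raise
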
+    # -------------------------------------------------------------------------
+    # Workflow Definition Methods
+    # -------------------------------------------------------------------------
+
+    async def upsert_workflow_definition(
+        self,
+        workflow_name: str,
+        source_hash: str,
+        source_code: str,
+    ) -> None:
+        """Insert or update a workflow definition."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Check if exists
+            result = await session.execute(
+                select(WorkflowDefinition).where(
+                    and_(
+                        WorkflowDefinition.workflow_name == workflow_name,
+                        WorkflowDefinition.source_hash == source_hash,
+                    )
+                )
+            )
+            existing = result.scalar_one_or_none()
+
+            if existing:
+                # Update
+                existing.source_code = source_code  # type: ignore[assignment]
+            else:
+                # Insert
+                definition = WorkflowDefinition(
+                    workflow_name=workflow_name,
+                    source_hash=source_hash,
+                    source_code=source_code,
+                )
+                session.add(definition)
+
+            await self._commit_if_not_in_transaction(session)
+
+    async def get_workflow_definition(
+        self,
+        workflow_name: str,
+        source_hash: str,
+    ) -> dict[str, Any] | None:
+        """Get a workflow definition by name and hash."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(WorkflowDefinition).where(
+                    and_(
+                        WorkflowDefinition.workflow_name == workflow_name,
+                        WorkflowDefinition.source_hash == source_hash,
+                    )
+                )
+            )
+            definition = result.scalar_one_or_none()
+
+            if definition is None:
+                return None
+
+            return {
+                "workflow_name": definition.workflow_name,
+                "source_hash": definition.source_hash,
+                "source_code": definition.source_code,
+                "created_at": definition.created_at.isoformat(),
+            }
+
+    async def get_current_workflow_definition(
+        self,
+        workflow_name: str,
+    ) -> dict[str, Any] | None:
+        """Get the most recent workflow definition by name."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(WorkflowDefinition)
+                .where(WorkflowDefinition.workflow_name == workflow_name)
+                .order_by(WorkflowDefinition.created_at.desc())
+                .limit(1)
+            )
+            definition = result.scalar_one_or_none()
+
+            if definition is None:
+                return None
+
+            return {
+                "workflow_name": definition.workflow_name,
+                "source_hash": definition.source_hash,
+                "source_code": definition.source_code,
+                "created_at": definition.created_at.isoformat(),
+            }
+
+    # -------------------------------------------------------------------------
+    # Workflow Instance Methods
+    # -------------------------------------------------------------------------
+
+    async def create_instance(
+        self,
+        instance_id: str,
+        workflow_name: str,
+        source_hash: str,
+        owner_service: str,
+        input_data: dict[str, Any],
+        lock_timeout_seconds: int | None = None,
+    ) -> None:
+        """Create a new workflow instance."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            instance = WorkflowInstance(
+                instance_id=instance_id,
+                workflow_name=workflow_name,
+                source_hash=source_hash,
+                owner_service=owner_service,
+                input_data=json.dumps(input_data),
+                lock_timeout_seconds=lock_timeout_seconds,
+            )
+            session.add(instance)
+
+    async def get_instance(self, instance_id: str) -> dict[str, Any] | None:
+        """Get workflow instance metadata with its definition."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Join with workflow_definitions to get source_code
+            result = await session.execute(
+                select(WorkflowInstance, WorkflowDefinition.source_code)
+                .join(
+                    WorkflowDefinition,
+                    and_(
+                        WorkflowInstance.workflow_name == WorkflowDefinition.workflow_name,
+                        WorkflowInstance.source_hash == WorkflowDefinition.source_hash,
+                    ),
+                )
+                .where(WorkflowInstance.instance_id == instance_id)
+            )
+            row = result.one_or_none()
+
+            if row is None:
+                return None
+
+            instance, source_code = row
+
+            return {
+                "instance_id": instance.instance_id,
+                "workflow_name": instance.workflow_name,
+                "source_hash": instance.source_hash,
+                "owner_service": instance.owner_service,
+                "status": instance.status,
+                "current_activity_id": instance.current_activity_id,
+                "started_at": instance.started_at.isoformat(),
+                "updated_at": instance.updated_at.isoformat(),
+                "input_data": json.loads(instance.input_data),
+                "source_code": source_code,
+                "output_data": json.loads(instance.output_data) if instance.output_data else None,
+                "locked_by": instance.locked_by,
+                "locked_at": instance.locked_at.isoformat() if instance.locked_at else None,
+                "lock_timeout_seconds": instance.lock_timeout_seconds,
+            }
+
+    async def update_instance_status(
+        self,
+        instance_id: str,
+        status: str,
+        output_data: dict[str, Any] | None = None,
+    ) -> None:
+        """Update workflow instance status."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            stmt = (
+                update(WorkflowInstance)
+                .where(WorkflowInstance.instance_id == instance_id)
+                .values(
+                    status=status,
+                    updated_at=func.now(),
+                )
+            )
+
+            if output_data is not None:
+                stmt = stmt.values(output_data=json.dumps(output_data))
+
+            await session.execute(stmt)
+            await self._commit_if_not_in_transaction(session)
+
+    async def update_instance_activity(self, instance_id: str, activity_id: str) -> None:
+        """Update the current activity ID for a workflow instance."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                update(WorkflowInstance)
+                .where(WorkflowInstance.instance_id == instance_id)
+                .values(current_activity_id=activity_id, updated_at=func.now())
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def list_instances(
+        self,
+        limit: int = 50,
+        status_filter: str | None = None,
+    ) -> list[dict[str, Any]]:
+        """List workflow instances with optional filtering."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            stmt = (
+                select(WorkflowInstance, WorkflowDefinition.source_code)
+                .join(
+                    WorkflowDefinition,
+                    and_(
+                        WorkflowInstance.workflow_name == WorkflowDefinition.workflow_name,
+                        WorkflowInstance.source_hash == WorkflowDefinition.source_hash,
+                    ),
+                )
+                .order_by(WorkflowInstance.started_at.desc())
+                .limit(limit)
+            )
+
+            if status_filter:
+                stmt = stmt.where(WorkflowInstance.status == status_filter)
+
+            result = await session.execute(stmt)
+            rows = result.all()
+
+            return [
+                {
+                    "instance_id": instance.instance_id,
+                    "workflow_name": instance.workflow_name,
+                    "source_hash": instance.source_hash,
+                    "owner_service": instance.owner_service,
+                    "status": instance.status,
+                    "current_activity_id": instance.current_activity_id,
+                    "started_at": instance.started_at.isoformat(),
+                    "updated_at": instance.updated_at.isoformat(),
+                    "input_data": json.loads(instance.input_data),
+                    "source_code": source_code,
+                    "output_data": (
+                        json.loads(instance.output_data) if instance.output_data else None
+                    ),
+                    "locked_by": instance.locked_by,
+                    "locked_at": instance.locked_at.isoformat() if instance.locked_at else None,
+                    "lock_timeout_seconds": instance.lock_timeout_seconds,
+                }
+                for instance, source_code in rows
+            ]
+
+    # -------------------------------------------------------------------------
+    # Distributed Locking Methods (ALWAYS use separate session/transaction)
+    # -------------------------------------------------------------------------
+
+    async def try_acquire_lock(
+        self,
+        instance_id: str,
+        worker_id: str,
+        timeout_seconds: int = 300,
+    ) -> bool:
+        """
+        Try to acquire lock using SELECT FOR UPDATE.
+
+        This implements distributed locking with automatic stale lock detection.
+        Returns True if lock was acquired, False if already locked by another worker.
+        Can acquire locks that have timed out.
+
+        Note: ALWAYS uses separate session (not external session).
+        """
+        session = self._get_session_for_operation(is_lock_operation=True)
+        async with self._session_scope(session) as session:
+            # Calculate timeout threshold and current time
+            # Use UTC time consistently (timezone-aware to match DateTime(timezone=True) columns)
+            current_time = datetime.now(UTC)
+
+            # SELECT FOR UPDATE SKIP LOCKED to prevent blocking (PostgreSQL/MySQL)
+            # SKIP LOCKED: If row is already locked, return None immediately (no blocking)
+            result = await session.execute(
+                select(WorkflowInstance)
+                .where(WorkflowInstance.instance_id == instance_id)
+                .with_for_update(skip_locked=True)
+            )
+            instance = result.scalar_one_or_none()
+
+            if instance is None:
+                # Instance doesn't exist
+                await session.commit()
+                return False
+
+            # Determine actual timeout (priority: instance > parameter > default)
+            actual_timeout = int(
+                instance.lock_timeout_seconds
+                if instance.lock_timeout_seconds is not None
+                else timeout_seconds
+            )
+
+            # Check if we can acquire the lock
+            # Lock is available if: not locked OR lock has expired
+            # Note: SQLite stores datetime without timezone, add UTC timezone
+            if instance.locked_by is None:
+                can_acquire = True
+            elif instance.lock_expires_at is not None:
+                lock_expires_at_utc = (
+                    instance.lock_expires_at.replace(tzinfo=UTC)
+                    if instance.lock_expires_at.tzinfo is None
+                    else instance.lock_expires_at
+                )
+                can_acquire = lock_expires_at_utc < current_time
+            else:
+                can_acquire = False
+
+            # Debug logging
+            logger.debug(
+                f"Lock acquisition check: instance_id={instance_id}, "
+                f"locked_by={instance.locked_by}, lock_expires_at={instance.lock_expires_at}, "
+                f"current_time={current_time}, can_acquire={can_acquire}"
+            )
+
+            if not can_acquire:
+                # Already locked by another worker
+                logger.debug(f"Lock acquisition failed: already locked by {instance.locked_by}")
+                await session.commit()
+                return False
+
+            # Acquire the lock and set expiry time
+            lock_expires_at = current_time + timedelta(seconds=actual_timeout)
+            await session.execute(
+                update(WorkflowInstance)
+                .where(WorkflowInstance.instance_id == instance_id)
+                .values(
+                    locked_by=worker_id,
+                    locked_at=current_time,
+                    lock_expires_at=lock_expires_at,
+                    updated_at=current_time,
+                )
+            )
+
+            await session.commit()
+            return True
+
|
+
async def release_lock(self, instance_id: str, worker_id: str) -> None:
|
|
914
|
+
"""
|
|
915
|
+
Release lock only if we own it.
|
|
916
|
+
|
|
917
|
+
Note: ALWAYS uses separate session (not external session).
|
|
918
|
+
"""
|
|
919
|
+
session = self._get_session_for_operation(is_lock_operation=True)
|
|
920
|
+
async with self._session_scope(session) as session:
|
|
921
|
+
# Use Python datetime for consistency (timezone-aware)
|
|
922
|
+
current_time = datetime.now(UTC)
|
|
923
|
+
|
|
924
|
+
await session.execute(
|
|
925
|
+
update(WorkflowInstance)
|
|
926
|
+
.where(
|
|
927
|
+
and_(
|
|
928
|
+
WorkflowInstance.instance_id == instance_id,
|
|
929
|
+
WorkflowInstance.locked_by == worker_id,
|
|
930
|
+
)
|
|
931
|
+
)
|
|
932
|
+
.values(
|
|
933
|
+
locked_by=None,
|
|
934
|
+
locked_at=None,
|
|
935
|
+
lock_expires_at=None,
|
|
936
|
+
updated_at=current_time,
|
|
937
|
+
)
|
|
938
|
+
)
|
|
939
|
+
await session.commit()
|
|
940
|
+
|
|
941
|
+
async def refresh_lock(
|
|
942
|
+
self, instance_id: str, worker_id: str, timeout_seconds: int = 300
|
|
943
|
+
) -> bool:
|
|
944
|
+
"""
|
|
945
|
+
Refresh lock timestamp and expiry time.
|
|
946
|
+
|
|
947
|
+
Args:
|
|
948
|
+
instance_id: Workflow instance ID
|
|
949
|
+
worker_id: Worker ID that currently owns the lock
|
|
950
|
+
timeout_seconds: Default timeout (used if instance doesn't have custom timeout)
|
|
951
|
+
|
|
952
|
+
Note: ALWAYS uses separate session (not external session).
|
|
953
|
+
"""
|
|
954
|
+
session = self._get_session_for_operation(is_lock_operation=True)
|
|
955
|
+
async with self._session_scope(session) as session:
|
|
956
|
+
# Use Python datetime for consistency with try_acquire_lock() (timezone-aware)
|
|
957
|
+
current_time = datetime.now(UTC)
|
|
958
|
+
|
|
959
|
+
# First, get the instance to determine actual timeout
|
|
960
|
+
result = await session.execute(
|
|
961
|
+
select(WorkflowInstance).where(
|
|
962
|
+
and_(
|
|
963
|
+
WorkflowInstance.instance_id == instance_id,
|
|
964
|
+
WorkflowInstance.locked_by == worker_id,
|
|
965
|
+
)
|
|
966
|
+
)
|
|
967
|
+
)
|
|
968
|
+
instance = result.scalar_one_or_none()
|
|
969
|
+
|
|
970
|
+
if instance is None:
|
|
971
|
+
# Instance doesn't exist or not locked by us
|
|
972
|
+
await session.commit()
|
|
973
|
+
return False
|
|
974
|
+
|
|
975
|
+
# Determine actual timeout (priority: instance > parameter > default)
|
|
976
|
+
actual_timeout = int(
|
|
977
|
+
instance.lock_timeout_seconds
|
|
978
|
+
if instance.lock_timeout_seconds is not None
|
|
979
|
+
else timeout_seconds
|
|
980
|
+
)
|
|
981
|
+
|
|
982
|
+
# Calculate new expiry time
|
|
983
|
+
lock_expires_at = current_time + timedelta(seconds=actual_timeout)
|
|
984
|
+
|
|
985
|
+
# Update lock timestamp and expiry
|
|
986
|
+
result = await session.execute(
|
|
987
|
+
update(WorkflowInstance)
|
|
988
|
+
.where(
|
|
989
|
+
and_(
|
|
990
|
+
WorkflowInstance.instance_id == instance_id,
|
|
991
|
+
WorkflowInstance.locked_by == worker_id,
|
|
992
|
+
)
|
|
993
|
+
)
|
|
994
|
+
.values(
|
|
995
|
+
locked_at=current_time,
|
|
996
|
+
lock_expires_at=lock_expires_at,
|
|
997
|
+
updated_at=current_time,
|
|
998
|
+
)
|
|
999
|
+
)
|
|
1000
|
+
await session.commit()
|
|
1001
|
+
return bool(result.rowcount and result.rowcount > 0) # type: ignore[attr-defined]
|
|
1002
|
+
|
|
1003
|
+
async def cleanup_stale_locks(self) -> list[dict[str, str]]:
|
|
1004
|
+
"""
|
|
1005
|
+
Clean up locks that have expired (based on lock_expires_at column).
|
|
1006
|
+
|
|
1007
|
+
Returns list of workflows with status='running' or 'compensating' that need auto-resume.
|
|
1008
|
+
|
|
1009
|
+
Workflows with status='compensating' crashed during compensation execution
|
|
1010
|
+
and need special handling to complete compensations.
|
|
1011
|
+
|
|
1012
|
+
Note: Uses lock_expires_at column for efficient SQL-side filtering.
|
|
1013
|
+
Note: ALWAYS uses separate session (not external session).
|
|
1014
|
+
"""
|
|
1015
|
+
session = self._get_session_for_operation(is_lock_operation=True)
|
|
1016
|
+
async with self._session_scope(session) as session:
|
|
1017
|
+
# Use timezone-aware datetime to match DateTime(timezone=True) columns
|
|
1018
|
+
current_time = datetime.now(UTC)
|
|
1019
|
+
|
|
1020
|
+
# SQL-side filtering: Find all instances with expired locks
|
|
1021
|
+
# Use database abstraction layer for cross-database compatibility
|
|
1022
|
+
result = await session.execute(
|
|
1023
|
+
select(WorkflowInstance).where(
|
|
1024
|
+
and_(
|
|
1025
|
+
WorkflowInstance.locked_by.isnot(None),
|
|
1026
|
+
WorkflowInstance.lock_expires_at.isnot(None),
|
|
1027
|
+
self._make_datetime_comparable(WorkflowInstance.lock_expires_at)
|
|
1028
|
+
< self._get_current_time_expr(),
|
|
1029
|
+
)
|
|
1030
|
+
)
|
|
1031
|
+
)
|
|
1032
|
+
instances = result.scalars().all()
|
|
1033
|
+
|
|
1034
|
+
stale_instance_ids = []
|
|
1035
|
+
workflows_to_resume = []
|
|
1036
|
+
|
|
1037
|
+
# Collect instance IDs and workflows to resume
|
|
1038
|
+
for instance in instances:
|
|
1039
|
+
stale_instance_ids.append(instance.instance_id)
|
|
1040
|
+
|
|
1041
|
+
# Add to resume list if status is 'running' or 'compensating'
|
|
1042
|
+
if instance.status in ["running", "compensating"]:
|
|
1043
|
+
workflows_to_resume.append(
|
|
1044
|
+
{
|
|
1045
|
+
"instance_id": str(instance.instance_id),
|
|
1046
|
+
"workflow_name": str(instance.workflow_name),
|
|
1047
|
+
"source_hash": str(instance.source_hash),
|
|
1048
|
+
"status": str(instance.status),
|
|
1049
|
+
}
|
|
1050
|
+
)
|
|
1051
|
+
|
|
1052
|
+
# Release all stale locks in one UPDATE statement
|
|
1053
|
+
if stale_instance_ids:
|
|
1054
|
+
await session.execute(
|
|
1055
|
+
update(WorkflowInstance)
|
|
1056
|
+
.where(WorkflowInstance.instance_id.in_(stale_instance_ids))
|
|
1057
|
+
.values(
|
|
1058
|
+
locked_by=None,
|
|
1059
|
+
locked_at=None,
|
|
1060
|
+
lock_expires_at=None,
|
|
1061
|
+
updated_at=current_time,
|
|
1062
|
+
)
|
|
1063
|
+
)
|
|
1064
|
+
|
|
1065
|
+
await session.commit()
|
|
1066
|
+
return workflows_to_resume
|
|
1067
|
+
|
|
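
Callers are expected to invoke this periodically and hand the returned workflows back to the scheduler. A sketch of such a reaper task (the 30-second interval and resume hook are assumptions):

    # Illustrative sketch, NOT from the package.
    import asyncio

    async def reap_stale_locks(storage, resume_workflow, interval: float = 30.0) -> None:
        while True:
            # Each entry carries instance_id, workflow_name, source_hash, status.
            for wf in await storage.cleanup_stale_locks():
                await resume_workflow(wf)  # hypothetical resume hook
            await asyncio.sleep(interval)
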
+    # -------------------------------------------------------------------------
+    # History Methods (prefer external session)
+    # -------------------------------------------------------------------------
+
+    async def append_history(
+        self,
+        instance_id: str,
+        activity_id: str,
+        event_type: str,
+        event_data: dict[str, Any] | bytes,
+    ) -> None:
+        """Append an event to workflow execution history."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Determine data type and storage columns
+            if isinstance(event_data, bytes):
+                data_type = "binary"
+                event_data_json = None
+                event_data_bin = event_data
+            else:
+                data_type = "json"
+                event_data_json = json.dumps(event_data)
+                event_data_bin = None
+
+            history = WorkflowHistory(
+                instance_id=instance_id,
+                activity_id=activity_id,
+                event_type=event_type,
+                data_type=data_type,
+                event_data=event_data_json,
+                event_data_binary=event_data_bin,
+            )
+            session.add(history)
+            await self._commit_if_not_in_transaction(session)
+
+    async def get_history(self, instance_id: str) -> list[dict[str, Any]]:
+        """
+        Get workflow execution history in order.
+
+        Returns history events ordered by creation time.
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(WorkflowHistory)
+                .where(WorkflowHistory.instance_id == instance_id)
+                .order_by(WorkflowHistory.created_at.asc())
+            )
+            rows = result.scalars().all()
+
+            return [
+                {
+                    "id": row.id,
+                    "instance_id": row.instance_id,
+                    "activity_id": row.activity_id,
+                    "event_type": row.event_type,
+                    "event_data": (
+                        row.event_data_binary
+                        if row.data_type == "binary"
+                        else json.loads(row.event_data)  # type: ignore[arg-type]
+                    ),
+                    "created_at": row.created_at.isoformat(),
+                }
+                for row in rows
+            ]
+
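
The json/binary split made in append_history round-trips through get_history: dict payloads come back as dicts, bytes come back as bytes. A small illustrative example (instance and payloads invented):

    # Illustrative sketch, NOT from the package.
    await storage.append_history("wf-1", "a1", "ActivityCompleted", {"total": 3})
    await storage.append_history("wf-1", "a2", "ActivityCompleted", b"\x80\x04raw")
    events = await storage.get_history("wf-1")
    assert isinstance(events[0]["event_data"], dict)   # stored with data_type='json'
    assert isinstance(events[1]["event_data"], bytes)  # stored with data_type='binary'
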
1134
|
+
# -------------------------------------------------------------------------
|
|
1135
|
+
# Compensation Methods (prefer external session)
|
|
1136
|
+
# -------------------------------------------------------------------------
|
|
1137
|
+
|
|
1138
|
+
async def push_compensation(
|
|
1139
|
+
self,
|
|
1140
|
+
instance_id: str,
|
|
1141
|
+
activity_id: str,
|
|
1142
|
+
activity_name: str,
|
|
1143
|
+
args: dict[str, Any],
|
|
1144
|
+
) -> None:
|
|
1145
|
+
"""Push a compensation to the stack."""
|
|
1146
|
+
session = self._get_session_for_operation()
|
|
1147
|
+
async with self._session_scope(session) as session:
|
|
1148
|
+
compensation = WorkflowCompensation(
|
|
1149
|
+
instance_id=instance_id,
|
|
1150
|
+
activity_id=activity_id,
|
|
1151
|
+
activity_name=activity_name,
|
|
1152
|
+
args=json.dumps(args),
|
|
1153
|
+
)
|
|
1154
|
+
session.add(compensation)
|
|
1155
|
+
await self._commit_if_not_in_transaction(session)
|
|
1156
|
+
|
|
1157
|
+
async def get_compensations(self, instance_id: str) -> list[dict[str, Any]]:
|
|
1158
|
+
"""Get compensations in LIFO order (most recent first)."""
|
|
1159
|
+
session = self._get_session_for_operation()
|
|
1160
|
+
async with self._session_scope(session) as session:
|
|
1161
|
+
result = await session.execute(
|
|
1162
|
+
select(WorkflowCompensation)
|
|
1163
|
+
.where(WorkflowCompensation.instance_id == instance_id)
|
|
1164
|
+
.order_by(WorkflowCompensation.created_at.desc())
|
|
1165
|
+
)
|
|
1166
|
+
rows = result.scalars().all()
|
|
1167
|
+
|
|
1168
|
+
return [
|
|
1169
|
+
{
|
|
1170
|
+
"id": row.id,
|
|
1171
|
+
"instance_id": row.instance_id,
|
|
1172
|
+
"activity_id": row.activity_id,
|
|
1173
|
+
"activity_name": row.activity_name,
|
|
1174
|
+
"args": json.loads(row.args), # type: ignore[arg-type]
|
|
1175
|
+
"created_at": row.created_at.isoformat(),
|
|
1176
|
+
}
|
|
1177
|
+
for row in rows
|
|
1178
|
+
]
|
|
1179
|
+
|
|
1180
|
+
async def clear_compensations(self, instance_id: str) -> None:
|
|
1181
|
+
"""Clear all compensations for a workflow instance."""
|
|
1182
|
+
session = self._get_session_for_operation()
|
|
1183
|
+
async with self._session_scope(session) as session:
|
|
1184
|
+
await session.execute(
|
|
1185
|
+
delete(WorkflowCompensation).where(WorkflowCompensation.instance_id == instance_id)
|
|
1186
|
+
)
|
|
1187
|
+
await self._commit_if_not_in_transaction(session)
|
|
1188
|
+
|
|
1189
|
+
# -------------------------------------------------------------------------
|
|
1190
|
+
# Event Subscription Methods (prefer external session for registration)
|
|
1191
|
+
# -------------------------------------------------------------------------
|
|
1192
|
+
|
|
1193
|
+
async def add_event_subscription(
|
|
1194
|
+
self,
|
|
1195
|
+
instance_id: str,
|
|
1196
|
+
event_type: str,
|
|
1197
|
+
timeout_at: datetime | None = None,
|
|
1198
|
+
) -> None:
|
|
1199
|
+
"""Register an event wait subscription."""
|
|
1200
|
+
session = self._get_session_for_operation()
|
|
1201
|
+
async with self._session_scope(session) as session:
|
|
1202
|
+
subscription = WorkflowEventSubscription(
|
|
1203
|
+
instance_id=instance_id,
|
|
1204
|
+
event_type=event_type,
|
|
1205
|
+
timeout_at=timeout_at,
|
|
1206
|
+
)
|
|
1207
|
+
session.add(subscription)
|
|
1208
|
+
await self._commit_if_not_in_transaction(session)
|
|
1209
|
+
|
|
1210
|
+
async def find_waiting_instances(self, event_type: str) -> list[dict[str, Any]]:
|
|
1211
|
+
"""Find workflow instances waiting for a specific event type."""
|
|
1212
|
+
session = self._get_session_for_operation()
|
|
1213
|
+
async with self._session_scope(session) as session:
|
|
1214
|
+
result = await session.execute(
|
|
1215
|
+
select(WorkflowEventSubscription).where(
|
|
1216
|
+
and_(
|
|
1217
|
+
WorkflowEventSubscription.event_type == event_type,
|
|
1218
|
+
or_(
|
|
1219
|
+
WorkflowEventSubscription.timeout_at.is_(None),
|
|
1220
|
+
self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
|
|
1221
|
+
> self._get_current_time_expr(),
|
|
1222
|
+
),
|
|
1223
|
+
)
|
|
1224
|
+
)
|
|
1225
|
+
)
|
|
1226
|
+
rows = result.scalars().all()
|
|
1227
|
+
|
|
1228
|
+
return [
|
|
1229
|
+
{
|
|
1230
|
+
"id": row.id,
|
|
1231
|
+
"instance_id": row.instance_id,
|
|
1232
|
+
"event_type": row.event_type,
|
|
1233
|
+
"activity_id": row.activity_id,
|
|
1234
|
+
"timeout_at": row.timeout_at.isoformat() if row.timeout_at else None,
|
|
1235
|
+
"created_at": row.created_at.isoformat(),
|
|
1236
|
+
}
|
|
1237
|
+
for row in rows
|
|
1238
|
+
]
|
|
1239
|
+
|
|
1240
|
+
async def remove_event_subscription(
|
|
1241
|
+
self,
|
|
1242
|
+
instance_id: str,
|
|
1243
|
+
event_type: str,
|
|
1244
|
+
) -> None:
|
|
1245
|
+
"""Remove event subscription after the event is received."""
|
|
1246
|
+
session = self._get_session_for_operation()
|
|
1247
|
+
async with self._session_scope(session) as session:
|
|
1248
|
+
await session.execute(
|
|
1249
|
+
delete(WorkflowEventSubscription).where(
|
|
1250
|
+
and_(
|
|
1251
|
+
WorkflowEventSubscription.instance_id == instance_id,
|
|
1252
|
+
WorkflowEventSubscription.event_type == event_type,
|
|
1253
|
+
)
|
|
1254
|
+
)
|
|
1255
|
+
)
|
|
1256
|
+
await self._commit_if_not_in_transaction(session)
|
|
1257
|
+
|
|
1258
|
+
async def cleanup_expired_subscriptions(self) -> int:
|
|
1259
|
+
"""Clean up event subscriptions that have timed out."""
|
|
1260
|
+
session = self._get_session_for_operation()
|
|
1261
|
+
async with self._session_scope(session) as session:
|
|
1262
|
+
result = await session.execute(
|
|
1263
|
+
delete(WorkflowEventSubscription).where(
|
|
1264
|
+
and_(
|
|
1265
|
+
WorkflowEventSubscription.timeout_at.isnot(None),
|
|
1266
|
+
self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
|
|
1267
|
+
<= self._get_current_time_expr(),
|
|
1268
|
+
)
|
|
1269
|
+
)
|
|
1270
|
+
)
|
|
1271
|
+
await self._commit_if_not_in_transaction(session)
|
|
1272
|
+
return result.rowcount or 0 # type: ignore[attr-defined]
|
|
1273
|
+
|
|
1274
|
+
async def find_expired_event_subscriptions(self) -> list[dict[str, Any]]:
|
|
1275
|
+
"""Find event subscriptions that have timed out."""
|
|
1276
|
+
session = self._get_session_for_operation()
|
|
1277
|
+
async with self._session_scope(session) as session:
|
|
1278
|
+
result = await session.execute(
|
|
1279
|
+
select(
|
|
1280
|
+
WorkflowEventSubscription.instance_id,
|
|
1281
|
+
WorkflowEventSubscription.event_type,
|
|
1282
|
+
WorkflowEventSubscription.activity_id,
|
|
1283
|
+
WorkflowEventSubscription.timeout_at,
|
|
1284
|
+
WorkflowEventSubscription.created_at,
|
|
1285
|
+
).where(
|
|
1286
|
+
and_(
|
|
1287
|
+
WorkflowEventSubscription.timeout_at.isnot(None),
|
|
1288
|
+
self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
|
|
1289
|
+
<= self._get_current_time_expr(),
|
|
1290
|
+
)
|
|
1291
|
+
)
|
|
1292
|
+
)
|
|
1293
|
+
rows = result.all()
|
|
1294
|
+
|
|
1295
|
+
return [
|
|
1296
|
+
{
|
|
1297
|
+
"instance_id": row[0],
|
|
1298
|
+
"event_type": row[1],
|
|
1299
|
+
"activity_id": row[2],
|
|
1300
|
+
"timeout_at": row[3].isoformat() if row[3] else None,
|
|
1301
|
+
"created_at": row[4].isoformat() if row[4] else None,
|
|
1302
|
+
}
|
|
1303
|
+
for row in rows
|
|
1304
|
+
]
|
|
1305
|
+
+    async def register_event_subscription_and_release_lock(
+        self,
+        instance_id: str,
+        worker_id: str,
+        event_type: str,
+        timeout_at: datetime | None = None,
+        activity_id: str | None = None,
+    ) -> None:
+        """
+        Atomically register event subscription and release workflow lock.
+
+        This performs THREE operations in a SINGLE transaction:
+        1. Register event subscription
+        2. Update current activity
+        3. Release lock
+
+        This ensures distributed coroutines work correctly - when a workflow
+        calls wait_event(), the subscription is registered and the lock is
+        released atomically, so ANY worker can resume the workflow when the
+        event arrives.
+
+        Note: Uses LOCK operation session (separate from external session).
+        """
+        session = self._get_session_for_operation(is_lock_operation=True)
+        async with self._session_scope(session) as session, session.begin():
+            # 1. Verify we hold the lock (sanity check)
+            result = await session.execute(
+                select(WorkflowInstance.locked_by).where(
+                    WorkflowInstance.instance_id == instance_id
+                )
+            )
+            row = result.one_or_none()
+
+            if row is None:
+                raise RuntimeError(f"Workflow instance {instance_id} not found")
+
+            current_lock_holder = row[0]
+            if current_lock_holder != worker_id:
+                raise RuntimeError(
+                    f"Cannot release lock: worker {worker_id} does not hold lock "
+                    f"for {instance_id} (held by: {current_lock_holder})"
+                )
+
+            # 2. Register event subscription (INSERT OR REPLACE equivalent)
+            # First delete existing
+            await session.execute(
+                delete(WorkflowEventSubscription).where(
+                    and_(
+                        WorkflowEventSubscription.instance_id == instance_id,
+                        WorkflowEventSubscription.event_type == event_type,
+                    )
+                )
+            )
+
+            # Then insert new
+            subscription = WorkflowEventSubscription(
+                instance_id=instance_id,
+                event_type=event_type,
+                activity_id=activity_id,
+                timeout_at=timeout_at,
+            )
+            session.add(subscription)
+
+            # 3. Update current activity (if provided)
+            if activity_id is not None:
+                await session.execute(
+                    update(WorkflowInstance)
+                    .where(WorkflowInstance.instance_id == instance_id)
+                    .values(current_activity_id=activity_id, updated_at=func.now())
+                )
+
+            # 4. Release lock
+            await session.execute(
+                update(WorkflowInstance)
+                .where(
+                    and_(
+                        WorkflowInstance.instance_id == instance_id,
+                        WorkflowInstance.locked_by == worker_id,
+                    )
+                )
+                .values(
+                    locked_by=None,
+                    locked_at=None,
+                    updated_at=func.now(),
+                )
+            )
+
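
A minimal sketch of the wait_event() handoff this method enables (not from the package; `storage`, `instance_id`, and `worker_id` are assumed to come from a worker that currently holds the instance lock, and the event type and activity id are hypothetical):

    from datetime import UTC, datetime, timedelta

    async def park_until_payment(storage, instance_id: str, worker_id: str) -> None:
        # Hypothetical caller: park the workflow until "payment.completed"
        # arrives, with a one-hour timeout. Once this returns, the lock is
        # free, so any worker that observes the event may resume the instance.
        await storage.register_event_subscription_and_release_lock(
            instance_id=instance_id,
            worker_id=worker_id,
            event_type="payment.completed",
            timeout_at=datetime.now(UTC) + timedelta(hours=1),
            activity_id="wait-payment",
        )
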
+    async def register_timer_subscription_and_release_lock(
+        self,
+        instance_id: str,
+        worker_id: str,
+        timer_id: str,
+        expires_at: datetime,
+        activity_id: str | None = None,
+    ) -> None:
+        """
+        Atomically register timer subscription and release workflow lock.
+
+        This performs FOUR operations in a SINGLE transaction:
+        1. Register timer subscription
+        2. Update current activity
+        3. Update status to 'waiting_for_timer'
+        4. Release lock
+
+        This ensures distributed coroutines work correctly - when a workflow
+        calls wait_timer(), the subscription is registered and the lock is
+        released atomically, so ANY worker can resume the workflow when the
+        timer expires.
+
+        Note: Uses LOCK operation session (separate from external session).
+        """
+        session = self._get_session_for_operation(is_lock_operation=True)
+        async with self._session_scope(session) as session, session.begin():
+            # 1. Verify we hold the lock (sanity check)
+            result = await session.execute(
+                select(WorkflowInstance.locked_by).where(
+                    WorkflowInstance.instance_id == instance_id
+                )
+            )
+            row = result.one_or_none()
+
+            if row is None:
+                raise RuntimeError(f"Workflow instance {instance_id} not found")
+
+            current_lock_holder = row[0]
+            if current_lock_holder != worker_id:
+                raise RuntimeError(
+                    f"Cannot release lock: worker {worker_id} does not hold lock "
+                    f"for {instance_id} (held by: {current_lock_holder})"
+                )
+
+            # 2. Register timer subscription (with conflict handling)
+            # Check if exists
+            result = await session.execute(
+                select(WorkflowTimerSubscription).where(
+                    and_(
+                        WorkflowTimerSubscription.instance_id == instance_id,
+                        WorkflowTimerSubscription.timer_id == timer_id,
+                    )
+                )
+            )
+            existing = result.scalar_one_or_none()
+
+            if not existing:
+                # Insert new subscription
+                subscription = WorkflowTimerSubscription(
+                    instance_id=instance_id,
+                    timer_id=timer_id,
+                    expires_at=expires_at,
+                    activity_id=activity_id,
+                )
+                session.add(subscription)
+
+            # 3. Update current activity (if provided)
+            if activity_id is not None:
+                await session.execute(
+                    update(WorkflowInstance)
+                    .where(WorkflowInstance.instance_id == instance_id)
+                    .values(current_activity_id=activity_id, updated_at=func.now())
+                )
+
+            # 4. Update status to 'waiting_for_timer' and release lock
+            await session.execute(
+                update(WorkflowInstance)
+                .where(
+                    and_(
+                        WorkflowInstance.instance_id == instance_id,
+                        WorkflowInstance.locked_by == worker_id,
+                    )
+                )
+                .values(
+                    status="waiting_for_timer",
+                    locked_by=None,
+                    locked_at=None,
+                    updated_at=func.now(),
+                )
+            )
+
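
The timer variant can be exercised the same way; a sketch (not from the package; `storage`, the identifiers, and the 30-second delay are hypothetical):

    from datetime import UTC, datetime, timedelta

    async def durable_sleep(storage, instance_id: str, worker_id: str) -> None:
        # Hypothetical caller: a durable 30-second sleep. The instance is moved
        # to 'waiting_for_timer' and unlocked in the same transaction.
        await storage.register_timer_subscription_and_release_lock(
            instance_id=instance_id,
            worker_id=worker_id,
            timer_id="sleep-1",
            expires_at=datetime.now(UTC) + timedelta(seconds=30),
            activity_id="sleep-1",
        )
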
+    async def find_expired_timers(self) -> list[dict[str, Any]]:
+        """Find timer subscriptions that have expired."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(
+                    WorkflowTimerSubscription.instance_id,
+                    WorkflowTimerSubscription.timer_id,
+                    WorkflowTimerSubscription.expires_at,
+                    WorkflowTimerSubscription.activity_id,
+                    WorkflowInstance.workflow_name,
+                )
+                .join(
+                    WorkflowInstance,
+                    WorkflowTimerSubscription.instance_id == WorkflowInstance.instance_id,
+                )
+                .where(
+                    and_(
+                        self._make_datetime_comparable(WorkflowTimerSubscription.expires_at)
+                        <= self._get_current_time_expr(),
+                        WorkflowInstance.status == "waiting_for_timer",
+                    )
+                )
+            )
+            rows = result.all()
+
+            return [
+                {
+                    "instance_id": row[0],
+                    "timer_id": row[1],
+                    "expires_at": row[2].isoformat(),
+                    "activity_id": row[3],
+                    "workflow_name": row[4],
+                }
+                for row in rows
+            ]
+
+    async def remove_timer_subscription(
+        self,
+        instance_id: str,
+        timer_id: str,
+    ) -> None:
+        """Remove timer subscription after the timer expires."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                delete(WorkflowTimerSubscription).where(
+                    and_(
+                        WorkflowTimerSubscription.instance_id == instance_id,
+                        WorkflowTimerSubscription.timer_id == timer_id,
+                    )
+                )
+            )
+            await self._commit_if_not_in_transaction(session)
+
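
Together, find_expired_timers() and remove_timer_subscription() support a simple timer poller; a sketch (not from the package; `resume` is a hypothetical coroutine that re-acquires the lock and continues the workflow):

    import asyncio

    async def fire_expired_timers(storage, resume, interval: float = 1.0) -> None:
        # Hypothetical poller: fire each expired timer once, then remove it.
        while True:
            for timer in await storage.find_expired_timers():
                await resume(timer["instance_id"], timer["workflow_name"], timer["timer_id"])
                await storage.remove_timer_subscription(timer["instance_id"], timer["timer_id"])
            await asyncio.sleep(interval)
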
+    # -------------------------------------------------------------------------
+    # Transactional Outbox Methods (prefer external session)
+    # -------------------------------------------------------------------------
+
+    async def add_outbox_event(
+        self,
+        event_id: str,
+        event_type: str,
+        event_source: str,
+        event_data: dict[str, Any] | bytes,
+        content_type: str = "application/json",
+    ) -> None:
+        """Add an event to the transactional outbox."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Determine data type and storage columns
+            if isinstance(event_data, bytes):
+                data_type = "binary"
+                event_data_json = None
+                event_data_bin = event_data
+            else:
+                data_type = "json"
+                event_data_json = json.dumps(event_data)
+                event_data_bin = None
+
+            event = OutboxEvent(
+                event_id=event_id,
+                event_type=event_type,
+                event_source=event_source,
+                data_type=data_type,
+                event_data=event_data_json,
+                event_data_binary=event_data_bin,
+                content_type=content_type,
+            )
+            session.add(event)
+            await self._commit_if_not_in_transaction(session)
+
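
The point of add_outbox_event() is that it can run on the same session as the business write (the "prefer external session" note above), so both commit or roll back together. A sketch under that assumption; `orders_repo` and `order` are hypothetical:

    import uuid

    async def place_order(storage, orders_repo, order) -> None:
        # Hypothetical: assumes orders_repo.save() and add_outbox_event()
        # resolve to the same bound session/transaction, giving atomicity
        # between the order row and its outbox event.
        await orders_repo.save(order)
        await storage.add_outbox_event(
            event_id=str(uuid.uuid4()),
            event_type="order.placed",
            event_source="orders-service",
            event_data={"order_id": order.id},
        )
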
+    async def get_pending_outbox_events(self, limit: int = 10) -> list[dict[str, Any]]:
+        """
+        Get pending/failed outbox events for publishing (with row-level locking).
+
+        This method uses SELECT FOR UPDATE SKIP LOCKED to safely fetch events
+        in a multi-worker environment. It fetches both 'pending' and 'failed'
+        events (for retry). Fetched events are immediately marked as 'processing'
+        to prevent duplicate processing by other workers.
+
+        Args:
+            limit: Maximum number of events to fetch
+
+        Returns:
+            List of event dictionaries with 'processing' status
+        """
+        # Use new session for lock operation (SKIP LOCKED requires separate transactions)
+        session = self._get_session_for_operation(is_lock_operation=True)
+        # Explicitly begin transaction before SELECT FOR UPDATE
+        # This ensures proper transaction isolation for SKIP LOCKED
+        async with self._session_scope(session) as session, session.begin():
+            # 1. SELECT FOR UPDATE to lock rows (both 'pending' and 'failed' for retry)
+            result = await session.execute(
+                select(OutboxEvent)
+                .where(OutboxEvent.status.in_(["pending", "failed"]))
+                .order_by(OutboxEvent.created_at.asc())
+                .limit(limit)
+                .with_for_update(skip_locked=True)
+            )
+            rows = result.scalars().all()
+
+            # 2. Mark as 'processing' to prevent duplicate fetches
+            if rows:
+                event_ids = [row.event_id for row in rows]
+                await session.execute(
+                    update(OutboxEvent)
+                    .where(OutboxEvent.event_id.in_(event_ids))
+                    .values(status="processing")
+                )
+
+            # 3. Return events (now with status='processing')
+            return [
+                {
+                    "event_id": row.event_id,
+                    "event_type": row.event_type,
+                    "event_source": row.event_source,
+                    "event_data": (
+                        row.event_data_binary
+                        if row.data_type == "binary"
+                        else json.loads(row.event_data)  # type: ignore[arg-type]
+                    ),
+                    "content_type": row.content_type,
+                    "created_at": row.created_at.isoformat(),
+                    "status": "processing",  # Always 'processing' after update
+                    "retry_count": row.retry_count,
+                    "last_error": row.last_error,
+                }
+                for row in rows
+            ]
+
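
A minimal relayer pass built on this method and the mark_* helpers below (not from the package; `publish` is a hypothetical coroutine that delivers one event to the broker):

    async def relay_once(storage, publish) -> None:
        # Events come back already marked 'processing', so concurrent workers
        # running the same pass will not pick them up twice.
        for event in await storage.get_pending_outbox_events(limit=10):
            try:
                await publish(event)
            except Exception as exc:
                await storage.mark_outbox_failed(event["event_id"], str(exc))
            else:
                await storage.mark_outbox_published(event["event_id"])
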
+    async def mark_outbox_published(self, event_id: str) -> None:
+        """Mark outbox event as successfully published."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                update(OutboxEvent)
+                .where(OutboxEvent.event_id == event_id)
+                .values(status="published", published_at=func.now())
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def mark_outbox_failed(self, event_id: str, error: str) -> None:
+        """
+        Mark event as failed and increment retry count.
+
+        The event status is changed to 'failed' so it can be retried later.
+        get_pending_outbox_events() will fetch both 'pending' and 'failed' events.
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                update(OutboxEvent)
+                .where(OutboxEvent.event_id == event_id)
+                .values(
+                    status="failed",
+                    retry_count=OutboxEvent.retry_count + 1,
+                    last_error=error,
+                )
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def mark_outbox_permanently_failed(self, event_id: str, error: str) -> None:
+        """Mark outbox event as permanently failed (sets status to 'failed')."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                update(OutboxEvent)
+                .where(OutboxEvent.event_id == event_id)
+                .values(
+                    status="failed",
+                    last_error=error,
+                )
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def mark_outbox_invalid(self, event_id: str, error: str) -> None:
+        """Mark outbox event as invalid (sets status to 'invalid')."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                update(OutboxEvent)
+                .where(OutboxEvent.event_id == event_id)
+                .values(
+                    status="invalid",
+                    last_error=error,
+                )
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def mark_outbox_expired(self, event_id: str, error: str) -> None:
+        """Mark outbox event as expired (sets status to 'expired')."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                update(OutboxEvent)
+                .where(OutboxEvent.event_id == event_id)
+                .values(
+                    status="expired",
+                    last_error=error,
+                )
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def cleanup_published_events(self, older_than_hours: int = 24) -> int:
+        """Clean up successfully published events older than threshold."""
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            threshold = datetime.now(UTC) - timedelta(hours=older_than_hours)
+
+            result = await session.execute(
+                delete(OutboxEvent).where(
+                    and_(
+                        OutboxEvent.status == "published",
+                        OutboxEvent.published_at < threshold,
+                    )
+                )
+            )
+            await self._commit_if_not_in_transaction(session)
+            return result.rowcount or 0  # type: ignore[attr-defined]
+
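
cleanup_published_events() pairs naturally with cleanup_expired_subscriptions() in a periodic maintenance task; a sketch (not from the package; the hourly interval is an assumption):

    import asyncio

    async def housekeeping(storage, interval: float = 3600.0) -> None:
        # Hypothetical maintenance loop using the module-level logger.
        while True:
            events = await storage.cleanup_published_events(older_than_hours=24)
            subs = await storage.cleanup_expired_subscriptions()
            logger.debug("cleanup: %d events, %d subscriptions removed", events, subs)
            await asyncio.sleep(interval)
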
+    # -------------------------------------------------------------------------
+    # Workflow Cancellation Methods
+    # -------------------------------------------------------------------------
+
+    async def cancel_instance(self, instance_id: str, cancelled_by: str) -> bool:
+        """
+        Cancel a workflow instance.
+
+        Only running, waiting_for_event, waiting_for_timer, or compensating
+        workflows can be cancelled.
+        This method atomically:
+        1. Checks current status
+        2. Updates status to 'cancelled' if allowed
+        3. Clears locks
+        4. Records cancellation metadata
+        5. Removes event subscriptions (if waiting for event)
+        6. Removes timer subscriptions (if waiting for timer)
+
+        Args:
+            instance_id: Workflow instance to cancel
+            cancelled_by: Who/what triggered the cancellation
+
+        Returns:
+            True if successfully cancelled, False otherwise
+
+        Note: Uses LOCK operation session (separate from external session).
+        """
+        session = self._get_session_for_operation(is_lock_operation=True)
+        async with self._session_scope(session) as session, session.begin():
+            # Get current instance status
+            result = await session.execute(
+                select(WorkflowInstance.status).where(WorkflowInstance.instance_id == instance_id)
+            )
+            row = result.one_or_none()
+
+            if row is None:
+                # Instance not found
+                return False
+
+            current_status = row[0]
+
+            # Only allow cancellation of running, waiting, or compensating workflows;
+            # compensating workflows can be marked as cancelled after compensation completes
+            if current_status not in (
+                "running",
+                "waiting_for_event",
+                "waiting_for_timer",
+                "compensating",
+            ):
+                # Already completed, failed, or cancelled
+                return False
+
+            # Update status to cancelled and record metadata
+            cancellation_metadata = {
+                "cancelled_by": cancelled_by,
+                "cancelled_at": datetime.now(UTC).isoformat(),
+                "previous_status": current_status,
+            }
+
+            await session.execute(
+                update(WorkflowInstance)
+                .where(WorkflowInstance.instance_id == instance_id)
+                .values(
+                    status="cancelled",
+                    output_data=json.dumps(cancellation_metadata),
+                    locked_by=None,
+                    locked_at=None,
+                    updated_at=func.now(),
+                )
+            )
+
+            # Remove event subscriptions if waiting for event
+            if current_status == "waiting_for_event":
+                await session.execute(
+                    delete(WorkflowEventSubscription).where(
+                        WorkflowEventSubscription.instance_id == instance_id
+                    )
+                )
+
+            # Remove timer subscriptions if waiting for timer
+            if current_status == "waiting_for_timer":
+                await session.execute(
+                    delete(WorkflowTimerSubscription).where(
+                        WorkflowTimerSubscription.instance_id == instance_id
+                    )
+                )
+
+            return True