edda-framework 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1809 @@
1
+ """
2
+ SQLAlchemy storage implementation for Edda framework.
3
+
4
+ This module provides a SQLAlchemy-based implementation of the StorageProtocol,
5
+ supporting SQLite, PostgreSQL, and MySQL with database-based exclusive control
6
+ and transactional outbox pattern.
7
+ """
8
+
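+ # A minimal construction sketch, assuming an asyncio driver such as aiosqlite or
+ # asyncpg is installed (the URLs below are placeholders):
+ #
+ #     from sqlalchemy.ext.asyncio import create_async_engine
+ #
+ #     engine = create_async_engine("sqlite+aiosqlite:///edda.db")
+ #     # or: create_async_engine("postgresql+asyncpg://user:pass@host/edda")
+ #     storage = SQLAlchemyStorage(engine)
+ #     await storage.initialize()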
9
+ import json
10
+ import logging
11
+ from collections.abc import AsyncIterator
12
+ from contextlib import asynccontextmanager
13
+ from contextvars import ContextVar
14
+ from dataclasses import dataclass, field
15
+ from datetime import UTC, datetime, timedelta
16
+ from typing import Any
17
+
18
+ from sqlalchemy import (
19
+ CheckConstraint,
20
+ Column,
21
+ DateTime,
22
+ ForeignKeyConstraint,
23
+ Index,
24
+ Integer,
25
+ LargeBinary,
26
+ String,
27
+ Text,
28
+ UniqueConstraint,
29
+ and_,
30
+ delete,
31
+ func,
32
+ or_,
33
+ select,
34
+ text,
35
+ update,
36
+ )
37
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
38
+ from sqlalchemy.orm import declarative_base
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+ # Declarative base for ORM models
43
+ Base = declarative_base()
44
+
45
+
46
+ # ============================================================================
47
+ # SQLAlchemy ORM Models
48
+ # ============================================================================
49
+
50
+
51
+ class SchemaVersion(Base): # type: ignore[valid-type, misc]
52
+ """Schema version tracking."""
53
+
54
+ __tablename__ = "schema_version"
55
+
56
+ version = Column(Integer, primary_key=True)
57
+ applied_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
58
+ description = Column(Text, nullable=False)
59
+
60
+
61
+ class WorkflowDefinition(Base): # type: ignore[valid-type, misc]
62
+ """Workflow definition (source code storage)."""
63
+
64
+ __tablename__ = "workflow_definitions"
65
+
66
+ workflow_name = Column(String(255), nullable=False, primary_key=True)
67
+ source_hash = Column(String(64), nullable=False, primary_key=True)
68
+ source_code = Column(Text, nullable=False)
69
+ created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
70
+
71
+ __table_args__ = (
72
+ Index("idx_definitions_name", "workflow_name"),
73
+ Index("idx_definitions_hash", "source_hash"),
74
+ )
75
+
76
+
77
+ class WorkflowInstance(Base): # type: ignore[valid-type, misc]
78
+ """Workflow instance with distributed locking support."""
79
+
80
+ __tablename__ = "workflow_instances"
81
+
82
+ instance_id = Column(String(255), primary_key=True)
83
+ workflow_name = Column(String(255), nullable=False)
84
+ source_hash = Column(String(64), nullable=False)
85
+ owner_service = Column(String(255), nullable=False)
86
+ status = Column(
87
+ String(50),
88
+ nullable=False,
89
+ server_default=text("'running'"),
90
+ )
91
+ current_activity_id = Column(String(255), nullable=True)
92
+ started_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
93
+ updated_at = Column(
94
+ DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
95
+ )
96
+ input_data = Column(Text, nullable=False) # JSON
97
+ output_data = Column(Text, nullable=True) # JSON
98
+ locked_by = Column(String(255), nullable=True)
99
+ locked_at = Column(DateTime(timezone=True), nullable=True)
100
+ lock_timeout_seconds = Column(Integer, nullable=True) # None = use global default (300s)
101
+ lock_expires_at = Column(DateTime(timezone=True), nullable=True) # Absolute expiry time
102
+
103
+ __table_args__ = (
104
+ ForeignKeyConstraint(
105
+ ["workflow_name", "source_hash"],
106
+ ["workflow_definitions.workflow_name", "workflow_definitions.source_hash"],
107
+ ),
108
+ CheckConstraint(
109
+ "status IN ('running', 'completed', 'failed', 'waiting_for_event', "
110
+ "'waiting_for_timer', 'compensating', 'cancelled')",
111
+ name="valid_status",
112
+ ),
113
+ Index("idx_instances_status", "status"),
114
+ Index("idx_instances_workflow", "workflow_name"),
115
+ Index("idx_instances_owner", "owner_service"),
116
+ Index("idx_instances_locked", "locked_by", "locked_at"),
117
+ Index("idx_instances_lock_expires", "lock_expires_at"),
118
+ Index("idx_instances_updated", "updated_at"),
119
+ Index("idx_instances_hash", "source_hash"),
120
+ )
121
+
122
+
123
+ class WorkflowHistory(Base): # type: ignore[valid-type, misc]
124
+ """Workflow execution history (for deterministic replay)."""
125
+
126
+ __tablename__ = "workflow_history"
127
+
128
+ id = Column(Integer, primary_key=True, autoincrement=True)
129
+ instance_id = Column(
130
+ String(255),
131
+ nullable=False,
132
+ )
133
+ activity_id = Column(String(255), nullable=False)
134
+ event_type = Column(String(100), nullable=False)
135
+ data_type = Column(String(10), nullable=False) # 'json' or 'binary'
136
+ event_data = Column(Text, nullable=True) # JSON (when data_type='json')
137
+ event_data_binary = Column(LargeBinary, nullable=True) # Binary (when data_type='binary')
138
+ created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
139
+
140
+ __table_args__ = (
141
+ ForeignKeyConstraint(
142
+ ["instance_id"],
143
+ ["workflow_instances.instance_id"],
144
+ ondelete="CASCADE",
145
+ ),
146
+ CheckConstraint(
147
+ "data_type IN ('json', 'binary')",
148
+ name="valid_data_type",
149
+ ),
150
+ CheckConstraint(
151
+ "(data_type = 'json' AND event_data IS NOT NULL AND event_data_binary IS NULL) OR "
152
+ "(data_type = 'binary' AND event_data IS NULL AND event_data_binary IS NOT NULL)",
153
+ name="data_type_consistency",
154
+ ),
155
+ UniqueConstraint("instance_id", "activity_id", name="unique_instance_activity"),
156
+ Index("idx_history_instance", "instance_id", "activity_id"),
157
+ Index("idx_history_created", "created_at"),
158
+ )
159
+
160
+
161
+ class WorkflowCompensation(Base): # type: ignore[valid-type, misc]
162
+ """Compensation transactions (LIFO stack for Saga pattern)."""
163
+
164
+ __tablename__ = "workflow_compensations"
165
+
166
+ id = Column(Integer, primary_key=True, autoincrement=True)
167
+ instance_id = Column(
168
+ String(255),
169
+ nullable=False,
170
+ )
171
+ activity_id = Column(String(255), nullable=False)
172
+ activity_name = Column(String(255), nullable=False)
173
+ args = Column(Text, nullable=False) # JSON
174
+ created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
175
+
176
+ __table_args__ = (
177
+ ForeignKeyConstraint(
178
+ ["instance_id"],
179
+ ["workflow_instances.instance_id"],
180
+ ondelete="CASCADE",
181
+ ),
182
+ Index("idx_compensations_instance", "instance_id", "created_at"),
183
+ )
184
+
185
+
186
+ class WorkflowEventSubscription(Base): # type: ignore[valid-type, misc]
187
+ """Event subscriptions (for wait_event)."""
188
+
189
+ __tablename__ = "workflow_event_subscriptions"
190
+
191
+ id = Column(Integer, primary_key=True, autoincrement=True)
192
+ instance_id = Column(
193
+ String(255),
194
+ nullable=False,
195
+ )
196
+ event_type = Column(String(255), nullable=False)
197
+ activity_id = Column(String(255), nullable=True)
198
+ timeout_at = Column(DateTime(timezone=True), nullable=True)
199
+ created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
200
+
201
+ __table_args__ = (
202
+ ForeignKeyConstraint(
203
+ ["instance_id"],
204
+ ["workflow_instances.instance_id"],
205
+ ondelete="CASCADE",
206
+ ),
207
+ UniqueConstraint("instance_id", "event_type", name="unique_instance_event"),
208
+ Index("idx_subscriptions_event", "event_type"),
209
+ Index("idx_subscriptions_timeout", "timeout_at"),
210
+ Index("idx_subscriptions_instance", "instance_id"),
211
+ )
212
+
213
+
214
+ class WorkflowTimerSubscription(Base): # type: ignore[valid-type, misc]
215
+ """Timer subscriptions (for wait_timer)."""
216
+
217
+ __tablename__ = "workflow_timer_subscriptions"
218
+
219
+ id = Column(Integer, primary_key=True, autoincrement=True)
220
+ instance_id = Column(
221
+ String(255),
222
+ nullable=False,
223
+ )
224
+ timer_id = Column(String(255), nullable=False)
225
+ expires_at = Column(DateTime(timezone=True), nullable=False)
226
+ activity_id = Column(String(255), nullable=True)
227
+ created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
228
+
229
+ __table_args__ = (
230
+ ForeignKeyConstraint(
231
+ ["instance_id"],
232
+ ["workflow_instances.instance_id"],
233
+ ondelete="CASCADE",
234
+ ),
235
+ UniqueConstraint("instance_id", "timer_id", name="unique_instance_timer"),
236
+ Index("idx_timer_subscriptions_expires", "expires_at"),
237
+ Index("idx_timer_subscriptions_instance", "instance_id"),
238
+ )
239
+
240
+
241
+ class OutboxEvent(Base): # type: ignore[valid-type, misc]
242
+ """Transactional outbox pattern events."""
243
+
244
+ __tablename__ = "outbox_events"
245
+
246
+ event_id = Column(String(255), primary_key=True)
247
+ event_type = Column(String(255), nullable=False)
248
+ event_source = Column(String(255), nullable=False)
249
+ data_type = Column(String(10), nullable=False) # 'json' or 'binary'
250
+ event_data = Column(Text, nullable=True) # JSON (when data_type='json')
251
+ event_data_binary = Column(LargeBinary, nullable=True) # Binary (when data_type='binary')
252
+ content_type = Column(String(100), nullable=False, server_default=text("'application/json'"))
253
+ created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
254
+ published_at = Column(DateTime(timezone=True), nullable=True)
255
+ status = Column(String(50), nullable=False, server_default=text("'pending'"))
256
+ retry_count = Column(Integer, nullable=False, server_default=text("0"))
257
+ last_error = Column(Text, nullable=True)
258
+
259
+ __table_args__ = (
260
+ CheckConstraint(
261
+ "status IN ('pending', 'processing', 'published', 'failed', 'invalid', 'expired')",
262
+ name="valid_outbox_status",
263
+ ),
264
+ CheckConstraint(
265
+ "data_type IN ('json', 'binary')",
266
+ name="valid_outbox_data_type",
267
+ ),
268
+ CheckConstraint(
269
+ "(data_type = 'json' AND event_data IS NOT NULL AND event_data_binary IS NULL) OR "
270
+ "(data_type = 'binary' AND event_data IS NULL AND event_data_binary IS NOT NULL)",
271
+ name="outbox_data_type_consistency",
272
+ ),
273
+ Index("idx_outbox_status", "status", "created_at"),
274
+ Index("idx_outbox_retry", "status", "retry_count"),
275
+ Index("idx_outbox_published", "published_at"),
276
+ )
277
+
278
+
279
+ # Current schema version
280
+ CURRENT_SCHEMA_VERSION = 1
281
+
282
+
283
+ # ============================================================================
284
+ # Transaction Context
285
+ # ============================================================================
286
+
287
+
288
+ @dataclass
289
+ class TransactionContext:
290
+ """
291
+ Transaction context for managing nested transactions.
292
+
293
+ Uses savepoints for nested transaction support across all databases.
294
+ """
295
+
296
+ depth: int = 0
297
+ """Current transaction depth (0 = not in transaction, 1+ = in transaction)"""
298
+
299
+ savepoint_stack: list[Any] = field(default_factory=list)
300
+ """Stack of nested transaction objects for savepoint support"""
301
+
302
+ session: "AsyncSession | None" = None
303
+ """The actual session for this transaction"""
304
+
305
+
306
+ # Context variable for transaction state (asyncio-safe)
307
+ _transaction_context: ContextVar[TransactionContext | None] = ContextVar(
308
+ "_transaction_context", default=None
309
+ )
310
+
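+ # How the transaction context is meant to compose (a sketch; `storage` stands for
+ # a SQLAlchemyStorage instance defined below):
+ #
+ #     await storage.begin_transaction()      # depth 1: session-level transaction
+ #     await storage.begin_transaction()      # depth 2: SAVEPOINT via begin_nested()
+ #     await storage.rollback_transaction()   # rolls back to the savepoint only
+ #     await storage.commit_transaction()     # commits the outer transaction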
311
+
312
+ # ============================================================================
313
+ # SQLAlchemyStorage
314
+ # ============================================================================
315
+
316
+
317
+ class SQLAlchemyStorage:
318
+ """
319
+ SQLAlchemy implementation of StorageProtocol.
320
+
321
+ Supports SQLite, PostgreSQL, and MySQL with database-based exclusive control
322
+ and transactional outbox pattern.
323
+
324
+ Transaction Architecture:
325
+ - Lock operations: Always use separate session (isolated transactions)
326
+ - History/outbox operations: Use transaction context session when available
327
+ - Automatic transaction management via @activity decorator
328
+ """
329
+
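+ # Illustrative lifecycle sketch (engine setup and error handling omitted):
+ #
+ #     storage = SQLAlchemyStorage(engine)
+ #     await storage.initialize()   # create tables and the schema version row
+ #     try:
+ #         ...                      # create, lock and resume workflow instances
+ #     finally:
+ #         await storage.close()    # dispose of the engine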
330
+ def __init__(self, engine: AsyncEngine):
331
+ """
332
+ Initialize SQLAlchemy storage.
333
+
334
+ Args:
335
+ engine: SQLAlchemy AsyncEngine instance
336
+ """
337
+ self.engine = engine
338
+
339
+ async def initialize(self) -> None:
340
+ """Initialize database connection and create tables."""
341
+ # Create all tables and indexes
342
+ async with self.engine.begin() as conn:
343
+ await conn.run_sync(Base.metadata.create_all)
344
+
345
+ # Initialize schema version
346
+ await self._initialize_schema_version()
347
+
348
+ async def close(self) -> None:
349
+ """Close database connection."""
350
+ await self.engine.dispose()
351
+
352
+ async def _initialize_schema_version(self) -> None:
353
+ """Initialize schema version for a fresh database."""
354
+ async with AsyncSession(self.engine) as session:
355
+ # Check if schema_version table is empty
356
+ result = await session.execute(select(func.count()).select_from(SchemaVersion))
357
+ count = result.scalar()
358
+
359
+ # If empty, insert current version
360
+ if count == 0:
361
+ version = SchemaVersion(
362
+ version=CURRENT_SCHEMA_VERSION,
363
+ description="Initial schema with workflow_definitions",
364
+ )
365
+ session.add(version)
366
+ await session.commit()
367
+ logger.info(f"Initialized schema version to {CURRENT_SCHEMA_VERSION}")
368
+
369
+ def _get_session_for_operation(self, is_lock_operation: bool = False) -> AsyncSession:
370
+ """
371
+ Get the appropriate session for an operation.
372
+
373
+ Lock operations ALWAYS use a new session (separate transactions).
374
+ Other operations prefer: transaction session > new session.
375
+
376
+ Args:
377
+ is_lock_operation: True if this is a lock acquisition/release operation
378
+
379
+ Returns:
380
+ AsyncSession to use for the operation
381
+ """
382
+ if is_lock_operation:
383
+ # Lock operations always use new session
384
+ return AsyncSession(self.engine, expire_on_commit=False)
385
+
386
+ # Check for transaction context session
387
+ ctx = _transaction_context.get()
388
+ if ctx is not None and ctx.session is not None:
389
+ return ctx.session
390
+
391
+ # Otherwise create new session
392
+ return AsyncSession(self.engine, expire_on_commit=False)
393
+
394
+ def _is_managed_session(self, session: AsyncSession) -> bool:
395
+ """Check if session is managed by transaction context."""
396
+ ctx = _transaction_context.get()
397
+ return ctx is not None and ctx.session == session
398
+
399
+ @asynccontextmanager
400
+ async def _session_scope(self, session: AsyncSession) -> AsyncIterator[AsyncSession]:
401
+ """
402
+ Context manager for session usage.
403
+
404
+ If session is managed (transaction context), use it directly without closing.
405
+ If session is new, manage its lifecycle (commit/rollback/close).
406
+ """
407
+ if self._is_managed_session(session):
408
+ # Managed session: yield without lifecycle management
409
+ yield session
410
+ else:
411
+ # New session: full lifecycle management
412
+ try:
413
+ yield session
414
+ await self._commit_if_not_in_transaction(session)
415
+ except Exception:
416
+ await session.rollback()
417
+ raise
418
+ finally:
419
+ await session.close()
420
+
421
+ def _get_current_time_expr(self) -> Any:
422
+ """
423
+ Get database-specific current time SQL expression.
424
+
425
+ Returns:
426
+ SQLAlchemy function for current time in SQL queries.
427
+ - SQLite: datetime('now') - returns UTC datetime string
428
+ - PostgreSQL/MySQL: NOW() - returns timezone-aware datetime
429
+
430
+ This method enables cross-database datetime comparisons in SQL queries.
431
+ """
432
+ if self.engine.dialect.name == "sqlite":
433
+ # SQLite: datetime('now') returns UTC datetime as string
434
+ return func.datetime("now")
435
+ else:
436
+ # PostgreSQL/MySQL: NOW() returns timezone-aware datetime
437
+ return func.now()
438
+
439
+ def _make_datetime_comparable(self, column: Any) -> Any:
440
+ """
441
+ Make datetime column comparable with current time in SQL queries.
442
+
443
+ For SQLite, wraps column in datetime() function to ensure proper comparison.
444
+ For PostgreSQL/MySQL, returns column as-is (already timezone-aware).
445
+
446
+ Args:
447
+ column: SQLAlchemy Column expression representing a datetime field
448
+
449
+ Returns:
450
+ SQLAlchemy expression suitable for datetime comparison
451
+
452
+ Example:
453
+ >>> # SQLite: datetime(timeout_at) <= datetime('now')
454
+ >>> # PostgreSQL/MySQL: timeout_at <= NOW()
455
+ >>> (self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
456
+ ... <= self._get_current_time_expr())
457
+ """
458
+ if self.engine.dialect.name == "sqlite":
459
+ # SQLite: wrap in datetime() for proper comparison
460
+ return func.datetime(column)
461
+ else:
462
+ # PostgreSQL/MySQL: column is already timezone-aware
463
+ return column
464
+
465
+ # -------------------------------------------------------------------------
466
+ # Transaction Management Methods
467
+ # -------------------------------------------------------------------------
468
+
469
+ async def begin_transaction(self) -> None:
470
+ """
471
+ Begin a new transaction.
472
+
473
+ If a transaction is already in progress, creates a nested transaction
474
+ using savepoints. This is asyncio-safe using ContextVar.
475
+ """
476
+ ctx = _transaction_context.get()
477
+
478
+ if ctx is None:
479
+ # First transaction - create new context with session
480
+ session = AsyncSession(self.engine, expire_on_commit=False)
481
+ ctx = TransactionContext(session=session)
482
+ _transaction_context.set(ctx)
483
+
484
+ ctx.depth += 1
485
+
486
+ if ctx.depth == 1:
487
+ # Top-level transaction - begin the session transaction
488
+ logger.debug("Beginning top-level transaction")
489
+ await ctx.session.begin() # type: ignore[union-attr]
490
+ else:
491
+ # Nested transaction - use SQLAlchemy's begin_nested() (creates SAVEPOINT)
492
+ nested_tx = await ctx.session.begin_nested() # type: ignore[union-attr]
493
+ ctx.savepoint_stack.append(nested_tx)
494
+ logger.debug(f"Created nested transaction (savepoint) at depth={ctx.depth}")
495
+
496
+ async def commit_transaction(self) -> None:
497
+ """
498
+ Commit the current transaction.
499
+
500
+ For nested transactions, releases the savepoint.
501
+ For top-level transactions, commits to the database.
502
+ """
503
+ ctx = _transaction_context.get()
504
+ if ctx is None or ctx.depth == 0:
505
+ raise RuntimeError("Not in a transaction")
506
+
507
+ if ctx.depth == 1:
508
+ # Top-level transaction - commit the session
509
+ logger.debug("Committing top-level transaction")
510
+ await ctx.session.commit() # type: ignore[union-attr]
511
+ await ctx.session.close() # type: ignore[union-attr]
512
+ else:
513
+ # Nested transaction - commit the savepoint
514
+ nested_tx = ctx.savepoint_stack.pop()
515
+ await nested_tx.commit()
516
+ logger.debug(f"Committed nested transaction (savepoint) at depth={ctx.depth}")
517
+
518
+ ctx.depth -= 1
519
+
520
+ if ctx.depth == 0:
521
+ # All transactions completed - clear context
522
+ _transaction_context.set(None)
523
+
524
+ async def rollback_transaction(self) -> None:
525
+ """
526
+ Rollback the current transaction.
527
+
528
+ For nested transactions, rolls back to the savepoint.
529
+ For top-level transactions, rolls back all changes.
530
+ """
531
+ ctx = _transaction_context.get()
532
+ if ctx is None or ctx.depth == 0:
533
+ raise RuntimeError("Not in a transaction")
534
+
535
+ if ctx.depth == 1:
536
+ # Top-level transaction - rollback the session
537
+ logger.debug("Rolling back top-level transaction")
538
+ await ctx.session.rollback() # type: ignore[union-attr]
539
+ await ctx.session.close() # type: ignore[union-attr]
540
+ else:
541
+ # Nested transaction - rollback the savepoint
542
+ nested_tx = ctx.savepoint_stack.pop()
543
+ await nested_tx.rollback()
544
+ logger.debug(f"Rolled back nested transaction (savepoint) at depth={ctx.depth}")
545
+
546
+ ctx.depth -= 1
547
+
548
+ if ctx.depth == 0:
549
+ # All transactions rolled back - clear context
550
+ _transaction_context.set(None)
551
+
552
+ def in_transaction(self) -> bool:
553
+ """
554
+ Check if currently in a transaction.
555
+
556
+ Returns:
557
+ True if in a transaction, False otherwise.
558
+ """
559
+ ctx = _transaction_context.get()
560
+ return ctx is not None and ctx.depth > 0
561
+
562
+ async def _commit_if_not_in_transaction(self, session: AsyncSession) -> None:
563
+ """
564
+ Commit session if not in a transaction (auto-commit mode).
565
+
566
+ This helper method ensures that operations outside of explicit transactions
567
+ are still committed, while operations inside transactions are deferred
568
+ until the transaction is committed.
569
+
570
+ Args:
571
+ session: Database session
572
+ """
573
+ # If this is a transaction context session, don't commit (will be done by commit_transaction)
574
+ ctx = _transaction_context.get()
575
+ if ctx is not None and ctx.session == session:
576
+ return
577
+
578
+ # If not in transaction, commit
579
+ if not self.in_transaction():
580
+ await session.commit()
581
+
582
+ # -------------------------------------------------------------------------
583
+ # Workflow Definition Methods
584
+ # -------------------------------------------------------------------------
585
+
586
+ async def upsert_workflow_definition(
587
+ self,
588
+ workflow_name: str,
589
+ source_hash: str,
590
+ source_code: str,
591
+ ) -> None:
592
+ """Insert or update a workflow definition."""
593
+ session = self._get_session_for_operation()
594
+ async with self._session_scope(session) as session:
595
+ # Check if exists
596
+ result = await session.execute(
597
+ select(WorkflowDefinition).where(
598
+ and_(
599
+ WorkflowDefinition.workflow_name == workflow_name,
600
+ WorkflowDefinition.source_hash == source_hash,
601
+ )
602
+ )
603
+ )
604
+ existing = result.scalar_one_or_none()
605
+
606
+ if existing:
607
+ # Update
608
+ existing.source_code = source_code # type: ignore[assignment]
609
+ else:
610
+ # Insert
611
+ definition = WorkflowDefinition(
612
+ workflow_name=workflow_name,
613
+ source_hash=source_hash,
614
+ source_code=source_code,
615
+ )
616
+ session.add(definition)
617
+
618
+ await self._commit_if_not_in_transaction(session)
619
+
620
+ async def get_workflow_definition(
621
+ self,
622
+ workflow_name: str,
623
+ source_hash: str,
624
+ ) -> dict[str, Any] | None:
625
+ """Get a workflow definition by name and hash."""
626
+ session = self._get_session_for_operation()
627
+ async with self._session_scope(session) as session:
628
+ result = await session.execute(
629
+ select(WorkflowDefinition).where(
630
+ and_(
631
+ WorkflowDefinition.workflow_name == workflow_name,
632
+ WorkflowDefinition.source_hash == source_hash,
633
+ )
634
+ )
635
+ )
636
+ definition = result.scalar_one_or_none()
637
+
638
+ if definition is None:
639
+ return None
640
+
641
+ return {
642
+ "workflow_name": definition.workflow_name,
643
+ "source_hash": definition.source_hash,
644
+ "source_code": definition.source_code,
645
+ "created_at": definition.created_at.isoformat(),
646
+ }
647
+
648
+ async def get_current_workflow_definition(
649
+ self,
650
+ workflow_name: str,
651
+ ) -> dict[str, Any] | None:
652
+ """Get the most recent workflow definition by name."""
653
+ session = self._get_session_for_operation()
654
+ async with self._session_scope(session) as session:
655
+ result = await session.execute(
656
+ select(WorkflowDefinition)
657
+ .where(WorkflowDefinition.workflow_name == workflow_name)
658
+ .order_by(WorkflowDefinition.created_at.desc())
659
+ .limit(1)
660
+ )
661
+ definition = result.scalar_one_or_none()
662
+
663
+ if definition is None:
664
+ return None
665
+
666
+ return {
667
+ "workflow_name": definition.workflow_name,
668
+ "source_hash": definition.source_hash,
669
+ "source_code": definition.source_code,
670
+ "created_at": definition.created_at.isoformat(),
671
+ }
672
+
673
+ # -------------------------------------------------------------------------
674
+ # Workflow Instance Methods
675
+ # -------------------------------------------------------------------------
676
+
677
+ async def create_instance(
678
+ self,
679
+ instance_id: str,
680
+ workflow_name: str,
681
+ source_hash: str,
682
+ owner_service: str,
683
+ input_data: dict[str, Any],
684
+ lock_timeout_seconds: int | None = None,
685
+ ) -> None:
686
+ """Create a new workflow instance."""
687
+ session = self._get_session_for_operation()
688
+ async with self._session_scope(session) as session:
689
+ instance = WorkflowInstance(
690
+ instance_id=instance_id,
691
+ workflow_name=workflow_name,
692
+ source_hash=source_hash,
693
+ owner_service=owner_service,
694
+ input_data=json.dumps(input_data),
695
+ lock_timeout_seconds=lock_timeout_seconds,
696
+ )
697
+ session.add(instance)
698
+
699
+ async def get_instance(self, instance_id: str) -> dict[str, Any] | None:
700
+ """Get workflow instance metadata with its definition."""
701
+ session = self._get_session_for_operation()
702
+ async with self._session_scope(session) as session:
703
+ # Join with workflow_definitions to get source_code
704
+ result = await session.execute(
705
+ select(WorkflowInstance, WorkflowDefinition.source_code)
706
+ .join(
707
+ WorkflowDefinition,
708
+ and_(
709
+ WorkflowInstance.workflow_name == WorkflowDefinition.workflow_name,
710
+ WorkflowInstance.source_hash == WorkflowDefinition.source_hash,
711
+ ),
712
+ )
713
+ .where(WorkflowInstance.instance_id == instance_id)
714
+ )
715
+ row = result.one_or_none()
716
+
717
+ if row is None:
718
+ return None
719
+
720
+ instance, source_code = row
721
+
722
+ return {
723
+ "instance_id": instance.instance_id,
724
+ "workflow_name": instance.workflow_name,
725
+ "source_hash": instance.source_hash,
726
+ "owner_service": instance.owner_service,
727
+ "status": instance.status,
728
+ "current_activity_id": instance.current_activity_id,
729
+ "started_at": instance.started_at.isoformat(),
730
+ "updated_at": instance.updated_at.isoformat(),
731
+ "input_data": json.loads(instance.input_data),
732
+ "source_code": source_code,
733
+ "output_data": json.loads(instance.output_data) if instance.output_data else None,
734
+ "locked_by": instance.locked_by,
735
+ "locked_at": instance.locked_at.isoformat() if instance.locked_at else None,
736
+ "lock_timeout_seconds": instance.lock_timeout_seconds,
737
+ }
738
+
739
+ async def update_instance_status(
740
+ self,
741
+ instance_id: str,
742
+ status: str,
743
+ output_data: dict[str, Any] | None = None,
744
+ ) -> None:
745
+ """Update workflow instance status."""
746
+ session = self._get_session_for_operation()
747
+ async with self._session_scope(session) as session:
748
+ stmt = (
749
+ update(WorkflowInstance)
750
+ .where(WorkflowInstance.instance_id == instance_id)
751
+ .values(
752
+ status=status,
753
+ updated_at=func.now(),
754
+ )
755
+ )
756
+
757
+ if output_data is not None:
758
+ stmt = stmt.values(output_data=json.dumps(output_data))
759
+
760
+ await session.execute(stmt)
761
+ await self._commit_if_not_in_transaction(session)
762
+
763
+ async def update_instance_activity(self, instance_id: str, activity_id: str) -> None:
764
+ """Update the current activity ID for a workflow instance."""
765
+ session = self._get_session_for_operation()
766
+ async with self._session_scope(session) as session:
767
+ await session.execute(
768
+ update(WorkflowInstance)
769
+ .where(WorkflowInstance.instance_id == instance_id)
770
+ .values(current_activity_id=activity_id, updated_at=func.now())
771
+ )
772
+ await self._commit_if_not_in_transaction(session)
773
+
774
+ async def list_instances(
775
+ self,
776
+ limit: int = 50,
777
+ status_filter: str | None = None,
778
+ ) -> list[dict[str, Any]]:
779
+ """List workflow instances with optional filtering."""
780
+ session = self._get_session_for_operation()
781
+ async with self._session_scope(session) as session:
782
+ stmt = (
783
+ select(WorkflowInstance, WorkflowDefinition.source_code)
784
+ .join(
785
+ WorkflowDefinition,
786
+ and_(
787
+ WorkflowInstance.workflow_name == WorkflowDefinition.workflow_name,
788
+ WorkflowInstance.source_hash == WorkflowDefinition.source_hash,
789
+ ),
790
+ )
791
+ .order_by(WorkflowInstance.started_at.desc())
792
+ .limit(limit)
793
+ )
794
+
795
+ if status_filter:
796
+ stmt = stmt.where(WorkflowInstance.status == status_filter)
797
+
798
+ result = await session.execute(stmt)
799
+ rows = result.all()
800
+
801
+ return [
802
+ {
803
+ "instance_id": instance.instance_id,
804
+ "workflow_name": instance.workflow_name,
805
+ "source_hash": instance.source_hash,
806
+ "owner_service": instance.owner_service,
807
+ "status": instance.status,
808
+ "current_activity_id": instance.current_activity_id,
809
+ "started_at": instance.started_at.isoformat(),
810
+ "updated_at": instance.updated_at.isoformat(),
811
+ "input_data": json.loads(instance.input_data),
812
+ "source_code": source_code,
813
+ "output_data": (
814
+ json.loads(instance.output_data) if instance.output_data else None
815
+ ),
816
+ "locked_by": instance.locked_by,
817
+ "locked_at": instance.locked_at.isoformat() if instance.locked_at else None,
818
+ "lock_timeout_seconds": instance.lock_timeout_seconds,
819
+ }
820
+ for instance, source_code in rows
821
+ ]
822
+
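+ # Quick monitoring sketch built on this method (field names match the dict above):
+ #
+ #     running = await storage.list_instances(limit=20, status_filter="running")
+ #     for inst in running:
+ #         print(inst["instance_id"], inst["status"], inst["updated_at"])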
823
+ # -------------------------------------------------------------------------
824
+ # Distributed Locking Methods (ALWAYS use separate session/transaction)
825
+ # -------------------------------------------------------------------------
826
+
827
+ async def try_acquire_lock(
828
+ self,
829
+ instance_id: str,
830
+ worker_id: str,
831
+ timeout_seconds: int = 300,
832
+ ) -> bool:
833
+ """
834
+ Try to acquire lock using SELECT FOR UPDATE.
835
+
836
+ This implements distributed locking with automatic stale lock detection.
837
+ Returns True if the lock was acquired, False otherwise (the instance does not
839
+ exist or the lock is held by another worker). Expired locks can be taken over.
839
+
840
+ Note: ALWAYS uses separate session (not external session).
841
+ """
842
+ session = self._get_session_for_operation(is_lock_operation=True)
843
+ async with self._session_scope(session) as session:
844
+ # Calculate timeout threshold and current time
845
+ # Use UTC time consistently (timezone-aware to match DateTime(timezone=True) columns)
846
+ current_time = datetime.now(UTC)
847
+
848
+ # SELECT FOR UPDATE SKIP LOCKED to prevent blocking (PostgreSQL/MySQL)
849
+ # SKIP LOCKED: If row is already locked, return None immediately (no blocking)
850
+ result = await session.execute(
851
+ select(WorkflowInstance)
852
+ .where(WorkflowInstance.instance_id == instance_id)
853
+ .with_for_update(skip_locked=True)
854
+ )
855
+ instance = result.scalar_one_or_none()
856
+
857
+ if instance is None:
858
+ # Instance doesn't exist
859
+ await session.commit()
860
+ return False
861
+
862
+ # Determine actual timeout (priority: instance > parameter > default)
863
+ actual_timeout = int(
864
+ instance.lock_timeout_seconds
865
+ if instance.lock_timeout_seconds is not None
866
+ else timeout_seconds
867
+ )
868
+
869
+ # Check if we can acquire the lock
870
+ # Lock is available if: not locked OR lock has expired
871
+ # Note: SQLite stores datetimes without timezone info, so add UTC before comparing
872
+ if instance.locked_by is None:
873
+ can_acquire = True
874
+ elif instance.lock_expires_at is not None:
875
+ lock_expires_at_utc = (
876
+ instance.lock_expires_at.replace(tzinfo=UTC)
877
+ if instance.lock_expires_at.tzinfo is None
878
+ else instance.lock_expires_at
879
+ )
880
+ can_acquire = lock_expires_at_utc < current_time
881
+ else:
882
+ can_acquire = False
883
+
884
+ # Debug logging
885
+ logger.debug(
886
+ f"Lock acquisition check: instance_id={instance_id}, "
887
+ f"locked_by={instance.locked_by}, lock_expires_at={instance.lock_expires_at}, "
888
+ f"current_time={current_time}, can_acquire={can_acquire}"
889
+ )
890
+
891
+ if not can_acquire:
892
+ # Already locked by another worker
893
+ logger.debug(f"Lock acquisition failed: already locked by {instance.locked_by}")
894
+ await session.commit()
895
+ return False
896
+
897
+ # Acquire the lock and set expiry time
898
+ lock_expires_at = current_time + timedelta(seconds=actual_timeout)
899
+ await session.execute(
900
+ update(WorkflowInstance)
901
+ .where(WorkflowInstance.instance_id == instance_id)
902
+ .values(
903
+ locked_by=worker_id,
904
+ locked_at=current_time,
905
+ lock_expires_at=lock_expires_at,
906
+ updated_at=current_time,
907
+ )
908
+ )
909
+
910
+ await session.commit()
911
+ return True
912
+
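+ # Typical worker-side usage sketch (worker_id is whatever identifier the caller
+ # chooses for this process):
+ #
+ #     if await storage.try_acquire_lock(instance_id, worker_id, timeout_seconds=300):
+ #         try:
+ #             ...  # run activities, calling refresh_lock() periodically for long work
+ #         finally:
+ #             await storage.release_lock(instance_id, worker_id)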
913
+ async def release_lock(self, instance_id: str, worker_id: str) -> None:
914
+ """
915
+ Release lock only if we own it.
916
+
917
+ Note: ALWAYS uses separate session (not external session).
918
+ """
919
+ session = self._get_session_for_operation(is_lock_operation=True)
920
+ async with self._session_scope(session) as session:
921
+ # Use Python datetime for consistency (timezone-aware)
922
+ current_time = datetime.now(UTC)
923
+
924
+ await session.execute(
925
+ update(WorkflowInstance)
926
+ .where(
927
+ and_(
928
+ WorkflowInstance.instance_id == instance_id,
929
+ WorkflowInstance.locked_by == worker_id,
930
+ )
931
+ )
932
+ .values(
933
+ locked_by=None,
934
+ locked_at=None,
935
+ lock_expires_at=None,
936
+ updated_at=current_time,
937
+ )
938
+ )
939
+ await session.commit()
940
+
941
+ async def refresh_lock(
942
+ self, instance_id: str, worker_id: str, timeout_seconds: int = 300
943
+ ) -> bool:
944
+ """
945
+ Refresh lock timestamp and expiry time.
946
+
947
+ Args:
948
+ instance_id: Workflow instance ID
949
+ worker_id: Worker ID that currently owns the lock
950
+ timeout_seconds: Default timeout in seconds (used if the instance has no custom timeout)
951
+
952
+ Note: ALWAYS uses separate session (not external session).
953
+ """
954
+ session = self._get_session_for_operation(is_lock_operation=True)
955
+ async with self._session_scope(session) as session:
956
+ # Use Python datetime for consistency with try_acquire_lock() (timezone-aware)
957
+ current_time = datetime.now(UTC)
958
+
959
+ # First, get the instance to determine actual timeout
960
+ result = await session.execute(
961
+ select(WorkflowInstance).where(
962
+ and_(
963
+ WorkflowInstance.instance_id == instance_id,
964
+ WorkflowInstance.locked_by == worker_id,
965
+ )
966
+ )
967
+ )
968
+ instance = result.scalar_one_or_none()
969
+
970
+ if instance is None:
971
+ # Instance doesn't exist or not locked by us
972
+ await session.commit()
973
+ return False
974
+
975
+ # Determine actual timeout (priority: instance > parameter > default)
976
+ actual_timeout = int(
977
+ instance.lock_timeout_seconds
978
+ if instance.lock_timeout_seconds is not None
979
+ else timeout_seconds
980
+ )
981
+
982
+ # Calculate new expiry time
983
+ lock_expires_at = current_time + timedelta(seconds=actual_timeout)
984
+
985
+ # Update lock timestamp and expiry
986
+ result = await session.execute(
987
+ update(WorkflowInstance)
988
+ .where(
989
+ and_(
990
+ WorkflowInstance.instance_id == instance_id,
991
+ WorkflowInstance.locked_by == worker_id,
992
+ )
993
+ )
994
+ .values(
995
+ locked_at=current_time,
996
+ lock_expires_at=lock_expires_at,
997
+ updated_at=current_time,
998
+ )
999
+ )
1000
+ await session.commit()
1001
+ return bool(result.rowcount and result.rowcount > 0) # type: ignore[attr-defined]
1002
+
1003
+ async def cleanup_stale_locks(self) -> list[dict[str, str]]:
1004
+ """
1005
+ Clean up locks that have expired (based on lock_expires_at column).
1006
+
1007
+ Returns list of workflows with status='running' or 'compensating' that need auto-resume.
1008
+
1009
+ Workflows with status='compensating' crashed during compensation execution
1010
+ and need special handling to complete compensations.
1011
+
1012
+ Note: Uses lock_expires_at column for efficient SQL-side filtering.
1013
+ Note: ALWAYS uses separate session (not external session).
1014
+ """
1015
+ session = self._get_session_for_operation(is_lock_operation=True)
1016
+ async with self._session_scope(session) as session:
1017
+ # Use timezone-aware datetime to match DateTime(timezone=True) columns
1018
+ current_time = datetime.now(UTC)
1019
+
1020
+ # SQL-side filtering: Find all instances with expired locks
1021
+ # Use database abstraction layer for cross-database compatibility
1022
+ result = await session.execute(
1023
+ select(WorkflowInstance).where(
1024
+ and_(
1025
+ WorkflowInstance.locked_by.isnot(None),
1026
+ WorkflowInstance.lock_expires_at.isnot(None),
1027
+ self._make_datetime_comparable(WorkflowInstance.lock_expires_at)
1028
+ < self._get_current_time_expr(),
1029
+ )
1030
+ )
1031
+ )
1032
+ instances = result.scalars().all()
1033
+
1034
+ stale_instance_ids = []
1035
+ workflows_to_resume = []
1036
+
1037
+ # Collect instance IDs and workflows to resume
1038
+ for instance in instances:
1039
+ stale_instance_ids.append(instance.instance_id)
1040
+
1041
+ # Add to resume list if status is 'running' or 'compensating'
1042
+ if instance.status in ["running", "compensating"]:
1043
+ workflows_to_resume.append(
1044
+ {
1045
+ "instance_id": str(instance.instance_id),
1046
+ "workflow_name": str(instance.workflow_name),
1047
+ "source_hash": str(instance.source_hash),
1048
+ "status": str(instance.status),
1049
+ }
1050
+ )
1051
+
1052
+ # Release all stale locks in one UPDATE statement
1053
+ if stale_instance_ids:
1054
+ await session.execute(
1055
+ update(WorkflowInstance)
1056
+ .where(WorkflowInstance.instance_id.in_(stale_instance_ids))
1057
+ .values(
1058
+ locked_by=None,
1059
+ locked_at=None,
1060
+ lock_expires_at=None,
1061
+ updated_at=current_time,
1062
+ )
1063
+ )
1064
+
1065
+ await session.commit()
1066
+ return workflows_to_resume
1067
+
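+ # A background janitor sketch: reclaim expired locks, then re-dispatch the
+ # returned workflows (resume_workflow is a hypothetical dispatch helper):
+ #
+ #     for wf in await storage.cleanup_stale_locks():
+ #         # wf carries instance_id, workflow_name, source_hash and status
+ #         await resume_workflow(wf["instance_id"])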
1068
+ # -------------------------------------------------------------------------
1069
+ # History Methods (prefer external session)
1070
+ # -------------------------------------------------------------------------
1071
+
1072
+ async def append_history(
1073
+ self,
1074
+ instance_id: str,
1075
+ activity_id: str,
1076
+ event_type: str,
1077
+ event_data: dict[str, Any] | bytes,
1078
+ ) -> None:
1079
+ """Append an event to workflow execution history."""
1080
+ session = self._get_session_for_operation()
1081
+ async with self._session_scope(session) as session:
1082
+ # Determine data type and storage columns
1083
+ if isinstance(event_data, bytes):
1084
+ data_type = "binary"
1085
+ event_data_json = None
1086
+ event_data_bin = event_data
1087
+ else:
1088
+ data_type = "json"
1089
+ event_data_json = json.dumps(event_data)
1090
+ event_data_bin = None
1091
+
1092
+ history = WorkflowHistory(
1093
+ instance_id=instance_id,
1094
+ activity_id=activity_id,
1095
+ event_type=event_type,
1096
+ data_type=data_type,
1097
+ event_data=event_data_json,
1098
+ event_data_binary=event_data_bin,
1099
+ )
1100
+ session.add(history)
1101
+ await self._commit_if_not_in_transaction(session)
1102
+
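+ # Both payload shapes route to the matching column; the event type strings and
+ # activity ids below are placeholders:
+ #
+ #     await storage.append_history(iid, "a1", "activity_completed", {"result": 42})
+ #     await storage.append_history(iid, "a2", "activity_completed", b"\x00\x01")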
1103
+ async def get_history(self, instance_id: str) -> list[dict[str, Any]]:
1104
+ """
1105
+ Get workflow execution history in order.
1106
+
1107
+ Returns history events ordered by creation time.
1108
+ """
1109
+ session = self._get_session_for_operation()
1110
+ async with self._session_scope(session) as session:
1111
+ result = await session.execute(
1112
+ select(WorkflowHistory)
1113
+ .where(WorkflowHistory.instance_id == instance_id)
1114
+ .order_by(WorkflowHistory.created_at.asc())
1115
+ )
1116
+ rows = result.scalars().all()
1117
+
1118
+ return [
1119
+ {
1120
+ "id": row.id,
1121
+ "instance_id": row.instance_id,
1122
+ "activity_id": row.activity_id,
1123
+ "event_type": row.event_type,
1124
+ "event_data": (
1125
+ row.event_data_binary
1126
+ if row.data_type == "binary"
1127
+ else json.loads(row.event_data) # type: ignore[arg-type]
1128
+ ),
1129
+ "created_at": row.created_at.isoformat(),
1130
+ }
1131
+ for row in rows
1132
+ ]
1133
+
1134
+ # -------------------------------------------------------------------------
1135
+ # Compensation Methods (prefer external session)
1136
+ # -------------------------------------------------------------------------
1137
+
1138
+ async def push_compensation(
1139
+ self,
1140
+ instance_id: str,
1141
+ activity_id: str,
1142
+ activity_name: str,
1143
+ args: dict[str, Any],
1144
+ ) -> None:
1145
+ """Push a compensation to the stack."""
1146
+ session = self._get_session_for_operation()
1147
+ async with self._session_scope(session) as session:
1148
+ compensation = WorkflowCompensation(
1149
+ instance_id=instance_id,
1150
+ activity_id=activity_id,
1151
+ activity_name=activity_name,
1152
+ args=json.dumps(args),
1153
+ )
1154
+ session.add(compensation)
1155
+ await self._commit_if_not_in_transaction(session)
1156
+
1157
+ async def get_compensations(self, instance_id: str) -> list[dict[str, Any]]:
1158
+ """Get compensations in LIFO order (most recent first)."""
1159
+ session = self._get_session_for_operation()
1160
+ async with self._session_scope(session) as session:
1161
+ result = await session.execute(
1162
+ select(WorkflowCompensation)
1163
+ .where(WorkflowCompensation.instance_id == instance_id)
1164
+ .order_by(WorkflowCompensation.created_at.desc())
1165
+ )
1166
+ rows = result.scalars().all()
1167
+
1168
+ return [
1169
+ {
1170
+ "id": row.id,
1171
+ "instance_id": row.instance_id,
1172
+ "activity_id": row.activity_id,
1173
+ "activity_name": row.activity_name,
1174
+ "args": json.loads(row.args), # type: ignore[arg-type]
1175
+ "created_at": row.created_at.isoformat(),
1176
+ }
1177
+ for row in rows
1178
+ ]
1179
+
1180
+ async def clear_compensations(self, instance_id: str) -> None:
1181
+ """Clear all compensations for a workflow instance."""
1182
+ session = self._get_session_for_operation()
1183
+ async with self._session_scope(session) as session:
1184
+ await session.execute(
1185
+ delete(WorkflowCompensation).where(WorkflowCompensation.instance_id == instance_id)
1186
+ )
1187
+ await self._commit_if_not_in_transaction(session)
1188
+
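+ # Saga-style sketch: push a compensation after each successful step, and on
+ # failure replay the LIFO stack (run_compensation is a hypothetical executor):
+ #
+ #     await storage.push_compensation(iid, "a1", "release_inventory", {"sku": "X"})
+ #     ...
+ #     for comp in await storage.get_compensations(iid):
+ #         await run_compensation(comp["activity_name"], comp["args"])
+ #     await storage.clear_compensations(iid)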
1189
+ # -------------------------------------------------------------------------
1190
+ # Event Subscription Methods (prefer external session for registration)
1191
+ # -------------------------------------------------------------------------
1192
+
1193
+ async def add_event_subscription(
1194
+ self,
1195
+ instance_id: str,
1196
+ event_type: str,
1197
+ timeout_at: datetime | None = None,
1198
+ ) -> None:
1199
+ """Register an event wait subscription."""
1200
+ session = self._get_session_for_operation()
1201
+ async with self._session_scope(session) as session:
1202
+ subscription = WorkflowEventSubscription(
1203
+ instance_id=instance_id,
1204
+ event_type=event_type,
1205
+ timeout_at=timeout_at,
1206
+ )
1207
+ session.add(subscription)
1208
+ await self._commit_if_not_in_transaction(session)
1209
+
1210
+ async def find_waiting_instances(self, event_type: str) -> list[dict[str, Any]]:
1211
+ """Find workflow instances waiting for a specific event type."""
1212
+ session = self._get_session_for_operation()
1213
+ async with self._session_scope(session) as session:
1214
+ result = await session.execute(
1215
+ select(WorkflowEventSubscription).where(
1216
+ and_(
1217
+ WorkflowEventSubscription.event_type == event_type,
1218
+ or_(
1219
+ WorkflowEventSubscription.timeout_at.is_(None),
1220
+ self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
1221
+ > self._get_current_time_expr(),
1222
+ ),
1223
+ )
1224
+ )
1225
+ )
1226
+ rows = result.scalars().all()
1227
+
1228
+ return [
1229
+ {
1230
+ "id": row.id,
1231
+ "instance_id": row.instance_id,
1232
+ "event_type": row.event_type,
1233
+ "activity_id": row.activity_id,
1234
+ "timeout_at": row.timeout_at.isoformat() if row.timeout_at else None,
1235
+ "created_at": row.created_at.isoformat(),
1236
+ }
1237
+ for row in rows
1238
+ ]
1239
+
1240
+ async def remove_event_subscription(
1241
+ self,
1242
+ instance_id: str,
1243
+ event_type: str,
1244
+ ) -> None:
1245
+ """Remove event subscription after the event is received."""
1246
+ session = self._get_session_for_operation()
1247
+ async with self._session_scope(session) as session:
1248
+ await session.execute(
1249
+ delete(WorkflowEventSubscription).where(
1250
+ and_(
1251
+ WorkflowEventSubscription.instance_id == instance_id,
1252
+ WorkflowEventSubscription.event_type == event_type,
1253
+ )
1254
+ )
1255
+ )
1256
+ await self._commit_if_not_in_transaction(session)
1257
+
1258
+ async def cleanup_expired_subscriptions(self) -> int:
1259
+ """Clean up event subscriptions that have timed out."""
1260
+ session = self._get_session_for_operation()
1261
+ async with self._session_scope(session) as session:
1262
+ result = await session.execute(
1263
+ delete(WorkflowEventSubscription).where(
1264
+ and_(
1265
+ WorkflowEventSubscription.timeout_at.isnot(None),
1266
+ self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
1267
+ <= self._get_current_time_expr(),
1268
+ )
1269
+ )
1270
+ )
1271
+ await self._commit_if_not_in_transaction(session)
1272
+ return result.rowcount or 0 # type: ignore[attr-defined]
1273
+
1274
+ async def find_expired_event_subscriptions(self) -> list[dict[str, Any]]:
1275
+ """Find event subscriptions that have timed out."""
1276
+ session = self._get_session_for_operation()
1277
+ async with self._session_scope(session) as session:
1278
+ result = await session.execute(
1279
+ select(
1280
+ WorkflowEventSubscription.instance_id,
1281
+ WorkflowEventSubscription.event_type,
1282
+ WorkflowEventSubscription.activity_id,
1283
+ WorkflowEventSubscription.timeout_at,
1284
+ WorkflowEventSubscription.created_at,
1285
+ ).where(
1286
+ and_(
1287
+ WorkflowEventSubscription.timeout_at.isnot(None),
1288
+ self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
1289
+ <= self._get_current_time_expr(),
1290
+ )
1291
+ )
1292
+ )
1293
+ rows = result.all()
1294
+
1295
+ return [
1296
+ {
1297
+ "instance_id": row[0],
1298
+ "event_type": row[1],
1299
+ "activity_id": row[2],
1300
+ "timeout_at": row[3].isoformat() if row[3] else None,
1301
+ "created_at": row[4].isoformat() if row[4] else None,
1302
+ }
1303
+ for row in rows
1304
+ ]
1305
+
1306
+ async def register_event_subscription_and_release_lock(
1307
+ self,
1308
+ instance_id: str,
1309
+ worker_id: str,
1310
+ event_type: str,
1311
+ timeout_at: datetime | None = None,
1312
+ activity_id: str | None = None,
1313
+ ) -> None:
1314
+ """
1315
+ Atomically register event subscription and release workflow lock.
1316
+
1317
+ This performs THREE operations in a SINGLE transaction:
1318
+ 1. Register event subscription
1319
+ 2. Update current activity
1320
+ 3. Release lock
1321
+
1322
+ This ensures distributed coroutines work correctly: when a workflow
1323
+ calls wait_event(), the subscription is registered and the lock is released
1324
+ atomically, so ANY worker can resume the workflow when the event arrives.
1325
+
1326
+ Note: Uses LOCK operation session (separate from external session).
1327
+ """
1328
+ session = self._get_session_for_operation(is_lock_operation=True)
1329
+ async with self._session_scope(session) as session, session.begin():
1330
+ # 1. Verify we hold the lock (sanity check)
1331
+ result = await session.execute(
1332
+ select(WorkflowInstance.locked_by).where(
1333
+ WorkflowInstance.instance_id == instance_id
1334
+ )
1335
+ )
1336
+ row = result.one_or_none()
1337
+
1338
+ if row is None:
1339
+ raise RuntimeError(f"Workflow instance {instance_id} not found")
1340
+
1341
+ current_lock_holder = row[0]
1342
+ if current_lock_holder != worker_id:
1343
+ raise RuntimeError(
1344
+ f"Cannot release lock: worker {worker_id} does not hold lock "
1345
+ f"for {instance_id} (held by: {current_lock_holder})"
1346
+ )
1347
+
1348
+ # 2. Register event subscription (INSERT OR REPLACE equivalent)
1349
+ # First delete existing
1350
+ await session.execute(
1351
+ delete(WorkflowEventSubscription).where(
1352
+ and_(
1353
+ WorkflowEventSubscription.instance_id == instance_id,
1354
+ WorkflowEventSubscription.event_type == event_type,
1355
+ )
1356
+ )
1357
+ )
1358
+
1359
+ # Then insert new
1360
+ subscription = WorkflowEventSubscription(
1361
+ instance_id=instance_id,
1362
+ event_type=event_type,
1363
+ activity_id=activity_id,
1364
+ timeout_at=timeout_at,
1365
+ )
1366
+ session.add(subscription)
1367
+
1368
+ # 3. Update current activity (if provided)
1369
+ if activity_id is not None:
1370
+ await session.execute(
1371
+ update(WorkflowInstance)
1372
+ .where(WorkflowInstance.instance_id == instance_id)
1373
+ .values(current_activity_id=activity_id, updated_at=func.now())
1374
+ )
1375
+
1376
+ # 4. Release lock
1377
+ await session.execute(
1378
+ update(WorkflowInstance)
1379
+ .where(
1380
+ and_(
1381
+ WorkflowInstance.instance_id == instance_id,
1382
+ WorkflowInstance.locked_by == worker_id,
1383
+ )
1384
+ )
1385
+ .values(
1386
+ locked_by=None,
1387
+ locked_at=None,
1388
+ updated_at=func.now(),
1389
+ )
1390
+ )
1391
+
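+ # End-to-end wait_event sketch (event names are placeholders): the waiting side
+ # registers and unlocks atomically; whichever worker later receives the event can
+ # look up and resume the waiting instances:
+ #
+ #     await storage.register_event_subscription_and_release_lock(
+ #         iid, worker_id, "payment.completed", timeout_at=None, activity_id="a7"
+ #     )
+ #     # ... later, in any worker that receives the event ...
+ #     for sub in await storage.find_waiting_instances("payment.completed"):
+ #         if await storage.try_acquire_lock(sub["instance_id"], worker_id):
+ #             await storage.remove_event_subscription(sub["instance_id"], "payment.completed")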
1392
+ async def register_timer_subscription_and_release_lock(
1393
+ self,
1394
+ instance_id: str,
1395
+ worker_id: str,
1396
+ timer_id: str,
1397
+ expires_at: datetime,
1398
+ activity_id: str | None = None,
1399
+ ) -> None:
1400
+ """
1401
+ Atomically register timer subscription and release workflow lock.
1402
+
1403
+ This performs FOUR operations in a SINGLE transaction:
1404
+ 1. Register timer subscription
1405
+ 2. Update current activity
1406
+ 3. Update status to 'waiting_for_timer'
1407
+ 4. Release lock
1408
+
1409
+ This ensures distributed coroutines work correctly: when a workflow
1410
+ calls wait_timer(), the subscription is registered and the lock is released
1411
+ atomically, so ANY worker can resume the workflow when the timer expires.
1412
+
1413
+ Note: Uses LOCK operation session (separate from external session).
1414
+ """
1415
+ session = self._get_session_for_operation(is_lock_operation=True)
1416
+ async with self._session_scope(session) as session, session.begin():
1417
+ # 1. Verify we hold the lock (sanity check)
1418
+ result = await session.execute(
1419
+ select(WorkflowInstance.locked_by).where(
1420
+ WorkflowInstance.instance_id == instance_id
1421
+ )
1422
+ )
1423
+ row = result.one_or_none()
1424
+
1425
+ if row is None:
1426
+ raise RuntimeError(f"Workflow instance {instance_id} not found")
1427
+
1428
+ current_lock_holder = row[0]
1429
+ if current_lock_holder != worker_id:
1430
+ raise RuntimeError(
1431
+ f"Cannot release lock: worker {worker_id} does not hold lock "
1432
+ f"for {instance_id} (held by: {current_lock_holder})"
1433
+ )
1434
+
1435
+ # 2. Register timer subscription (with conflict handling)
1436
+ # Check if exists
1437
+ result = await session.execute(
1438
+ select(WorkflowTimerSubscription).where(
1439
+ and_(
1440
+ WorkflowTimerSubscription.instance_id == instance_id,
1441
+ WorkflowTimerSubscription.timer_id == timer_id,
1442
+ )
1443
+ )
1444
+ )
1445
+ existing = result.scalar_one_or_none()
1446
+
1447
+ if not existing:
1448
+ # Insert new subscription
1449
+ subscription = WorkflowTimerSubscription(
1450
+ instance_id=instance_id,
1451
+ timer_id=timer_id,
1452
+ expires_at=expires_at,
1453
+ activity_id=activity_id,
1454
+ )
1455
+ session.add(subscription)
1456
+
1457
+ # 3. Update current activity (if provided)
1458
+ if activity_id is not None:
1459
+ await session.execute(
1460
+ update(WorkflowInstance)
1461
+ .where(WorkflowInstance.instance_id == instance_id)
1462
+ .values(current_activity_id=activity_id, updated_at=func.now())
1463
+ )
1464
+
1465
+ # 4. Update status to 'waiting_for_timer' and release lock
1466
+ await session.execute(
1467
+ update(WorkflowInstance)
1468
+ .where(
1469
+ and_(
1470
+ WorkflowInstance.instance_id == instance_id,
1471
+ WorkflowInstance.locked_by == worker_id,
1472
+ )
1473
+ )
1474
+ .values(
1475
+ status="waiting_for_timer",
1476
+ locked_by=None,
1477
+ locked_at=None,
1478
+ updated_at=func.now(),
1479
+ )
1480
+ )
1481
+
1482
+ async def find_expired_timers(self) -> list[dict[str, Any]]:
1483
+ """Find timer subscriptions that have expired."""
1484
+ session = self._get_session_for_operation()
1485
+ async with self._session_scope(session) as session:
1486
+ result = await session.execute(
1487
+ select(
1488
+ WorkflowTimerSubscription.instance_id,
1489
+ WorkflowTimerSubscription.timer_id,
1490
+ WorkflowTimerSubscription.expires_at,
1491
+ WorkflowTimerSubscription.activity_id,
1492
+ WorkflowInstance.workflow_name,
1493
+ )
1494
+ .join(
1495
+ WorkflowInstance,
1496
+ WorkflowTimerSubscription.instance_id == WorkflowInstance.instance_id,
1497
+ )
1498
+ .where(
1499
+ and_(
1500
+ self._make_datetime_comparable(WorkflowTimerSubscription.expires_at)
1501
+ <= self._get_current_time_expr(),
1502
+ WorkflowInstance.status == "waiting_for_timer",
1503
+ )
1504
+ )
1505
+ )
1506
+ rows = result.all()
1507
+
1508
+ return [
1509
+ {
1510
+ "instance_id": row[0],
1511
+ "timer_id": row[1],
1512
+ "expires_at": row[2].isoformat(),
1513
+ "activity_id": row[3],
1514
+ "workflow_name": row[4],
1515
+ }
1516
+ for row in rows
1517
+ ]
1518
+
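+ # Timer poller sketch built from the calls around it: find what is due, lock it,
+ # drop the subscription, then resume the workflow (dispatch is up to the caller):
+ #
+ #     for t in await storage.find_expired_timers():
+ #         if await storage.try_acquire_lock(t["instance_id"], worker_id):
+ #             await storage.remove_timer_subscription(t["instance_id"], t["timer_id"])
+ #             ...  # resume the workflow named t["workflow_name"]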
1519
+ async def remove_timer_subscription(
1520
+ self,
1521
+ instance_id: str,
1522
+ timer_id: str,
1523
+ ) -> None:
1524
+ """Remove timer subscription after the timer expires."""
1525
+ session = self._get_session_for_operation()
1526
+ async with self._session_scope(session) as session:
1527
+ await session.execute(
1528
+ delete(WorkflowTimerSubscription).where(
1529
+ and_(
1530
+ WorkflowTimerSubscription.instance_id == instance_id,
1531
+ WorkflowTimerSubscription.timer_id == timer_id,
1532
+ )
1533
+ )
1534
+ )
1535
+ await self._commit_if_not_in_transaction(session)
1536
+
1537
+ # -------------------------------------------------------------------------
1538
+ # Transactional Outbox Methods (prefer external session)
1539
+ # -------------------------------------------------------------------------
1540
+
1541
+ async def add_outbox_event(
1542
+ self,
1543
+ event_id: str,
1544
+ event_type: str,
1545
+ event_source: str,
1546
+ event_data: dict[str, Any] | bytes,
1547
+ content_type: str = "application/json",
1548
+ ) -> None:
1549
+ """Add an event to the transactional outbox."""
1550
+ session = self._get_session_for_operation()
1551
+ async with self._session_scope(session) as session:
1552
+ # Determine data type and storage columns
1553
+ if isinstance(event_data, bytes):
1554
+ data_type = "binary"
1555
+ event_data_json = None
1556
+ event_data_bin = event_data
1557
+ else:
1558
+ data_type = "json"
1559
+ event_data_json = json.dumps(event_data)
1560
+ event_data_bin = None
1561
+
1562
+ event = OutboxEvent(
1563
+ event_id=event_id,
1564
+ event_type=event_type,
1565
+ event_source=event_source,
1566
+ data_type=data_type,
1567
+ event_data=event_data_json,
1568
+ event_data_binary=event_data_bin,
1569
+ content_type=content_type,
1570
+ )
1571
+ session.add(event)
1572
+ await self._commit_if_not_in_transaction(session)
1573
+
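+ # Outbox writes are meant to share a transaction with the state change they
+ # announce (the event id, type and source values below are placeholders):
+ #
+ #     await storage.begin_transaction()
+ #     try:
+ #         await storage.update_instance_status(iid, "completed", output_data={"ok": True})
+ #         await storage.add_outbox_event(event_id, "workflow.completed", "order-service", {"id": iid})
+ #         await storage.commit_transaction()
+ #     except Exception:
+ #         await storage.rollback_transaction()
+ #         raise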
1574
+ async def get_pending_outbox_events(self, limit: int = 10) -> list[dict[str, Any]]:
1575
+ """
1576
+ Get pending/failed outbox events for publishing (with row-level locking).
1577
+
1578
+ This method uses SELECT FOR UPDATE SKIP LOCKED to safely fetch events
1579
+ in a multi-worker environment. It fetches both 'pending' and 'failed'
1580
+ events (for retry). Fetched events are immediately marked as 'processing'
1581
+ to prevent duplicate processing by other workers.
1582
+
1583
+ Args:
1584
+ limit: Maximum number of events to fetch
1585
+
1586
+ Returns:
1587
+ List of event dictionaries with 'processing' status
1588
+ """
1589
+ # Use new session for lock operation (SKIP LOCKED requires separate transactions)
1590
+ session = self._get_session_for_operation(is_lock_operation=True)
1591
+ # Explicitly begin transaction before SELECT FOR UPDATE
1592
+ # This ensures proper transaction isolation for SKIP LOCKED
1593
+ async with self._session_scope(session) as session, session.begin():
1594
+ # 1. SELECT FOR UPDATE to lock rows (both 'pending' and 'failed' for retry)
1595
+ result = await session.execute(
1596
+ select(OutboxEvent)
1597
+ .where(OutboxEvent.status.in_(["pending", "failed"]))
1598
+ .order_by(OutboxEvent.created_at.asc())
1599
+ .limit(limit)
1600
+ .with_for_update(skip_locked=True)
1601
+ )
1602
+ rows = result.scalars().all()
1603
+
1604
+ # 2. Mark as 'processing' to prevent duplicate fetches
1605
+ if rows:
1606
+ event_ids = [row.event_id for row in rows]
1607
+ await session.execute(
1608
+ update(OutboxEvent)
1609
+ .where(OutboxEvent.event_id.in_(event_ids))
1610
+ .values(status="processing")
1611
+ )
1612
+
1613
+ # 3. Return events (now with status='processing')
1614
+ return [
1615
+ {
1616
+ "event_id": row.event_id,
1617
+ "event_type": row.event_type,
1618
+ "event_source": row.event_source,
1619
+ "event_data": (
1620
+ row.event_data_binary
1621
+ if row.data_type == "binary"
1622
+ else json.loads(row.event_data) # type: ignore[arg-type]
1623
+ ),
1624
+ "content_type": row.content_type,
1625
+ "created_at": row.created_at.isoformat(),
1626
+ "status": "processing", # Always 'processing' after update
1627
+ "retry_count": row.retry_count,
1628
+ "last_error": row.last_error,
1629
+ }
1630
+ for row in rows
1631
+ ]
1632
+
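
A minimal relay-worker sketch for the consuming side of this method, assuming a publish() coroutine (for example a message-broker client) and a storage handle, neither of which is defined in this module; events come back already marked 'processing' and are then resolved to 'published' or 'failed'.

import asyncio

async def relay_outbox(storage, publish, batch_size: int = 10) -> None:
    while True:
        # Rows are locked with SKIP LOCKED and returned with status='processing',
        # so concurrent relay workers never see the same event twice.
        events = await storage.get_pending_outbox_events(limit=batch_size)
        for event in events:
            try:
                await publish(
                    event["event_type"],
                    event["event_data"],
                    content_type=event["content_type"],
                )
            except Exception as exc:  # broad catch is acceptable for a sketch
                # 'failed' events become eligible for retry on the next poll.
                await storage.mark_outbox_failed(event["event_id"], str(exc))
            else:
                await storage.mark_outbox_published(event["event_id"])
        if not events:
            await asyncio.sleep(1.0)  # back off while the outbox is empty
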
1633
+ async def mark_outbox_published(self, event_id: str) -> None:
1634
+ """Mark outbox event as successfully published."""
1635
+ session = self._get_session_for_operation()
1636
+ async with self._session_scope(session) as session:
1637
+ await session.execute(
1638
+ update(OutboxEvent)
1639
+ .where(OutboxEvent.event_id == event_id)
1640
+ .values(status="published", published_at=func.now())
1641
+ )
1642
+ await self._commit_if_not_in_transaction(session)
1643
+
1644
+ async def mark_outbox_failed(self, event_id: str, error: str) -> None:
1645
+ """
1646
+ Mark event as failed and increment retry count.
1647
+
1648
+ The event status is changed to 'failed' so it can be retried later.
1649
+ get_pending_outbox_events() will fetch both 'pending' and 'failed' events.
1650
+ """
1651
+ session = self._get_session_for_operation()
1652
+ async with self._session_scope(session) as session:
1653
+ await session.execute(
1654
+ update(OutboxEvent)
1655
+ .where(OutboxEvent.event_id == event_id)
1656
+ .values(
1657
+ status="failed",
1658
+ retry_count=OutboxEvent.retry_count + 1,
1659
+ last_error=error,
1660
+ )
1661
+ )
1662
+ await self._commit_if_not_in_transaction(session)
1663
+
1664
+ async def mark_outbox_permanently_failed(self, event_id: str, error: str) -> None:
1665
+ """Mark outbox event as permanently failed (sets status to 'failed')."""
1666
+ session = self._get_session_for_operation()
1667
+ async with self._session_scope(session) as session:
1668
+ await session.execute(
1669
+ update(OutboxEvent)
1670
+ .where(OutboxEvent.event_id == event_id)
1671
+ .values(
1672
+ status="failed",
1673
+ last_error=error,
1674
+ )
1675
+ )
1676
+ await self._commit_if_not_in_transaction(session)
1677
+
1678
+ async def mark_outbox_invalid(self, event_id: str, error: str) -> None:
1679
+ """Mark outbox event as invalid (sets status to 'invalid')."""
1680
+ session = self._get_session_for_operation()
1681
+ async with self._session_scope(session) as session:
1682
+ await session.execute(
1683
+ update(OutboxEvent)
1684
+ .where(OutboxEvent.event_id == event_id)
1685
+ .values(
1686
+ status="invalid",
1687
+ last_error=error,
1688
+ )
1689
+ )
1690
+ await self._commit_if_not_in_transaction(session)
1691
+
1692
+ async def mark_outbox_expired(self, event_id: str, error: str) -> None:
1693
+ """Mark outbox event as expired (sets status to 'expired')."""
1694
+ session = self._get_session_for_operation()
1695
+ async with self._session_scope(session) as session:
1696
+ await session.execute(
1697
+ update(OutboxEvent)
1698
+ .where(OutboxEvent.event_id == event_id)
1699
+ .values(
1700
+ status="expired",
1701
+ last_error=error,
1702
+ )
1703
+ )
1704
+ await self._commit_if_not_in_transaction(session)
1705
+
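
The four failure markers above differ only in the status they record. One possible policy for choosing between them is sketched below; the retry limit, the TTL, the malformed flag, and the storage handle are assumptions for illustration, not values or APIs defined by this module.

from datetime import UTC, datetime, timedelta
from typing import Any

MAX_RETRIES = 5                   # assumed retry budget
EVENT_TTL = timedelta(hours=12)   # assumed delivery window

async def resolve_failed_event(
    storage, event: dict[str, Any], error: str, *, malformed: bool
) -> None:
    # created_at is returned as an ISO string; assumes timezone-aware timestamps.
    created_at = datetime.fromisoformat(event["created_at"])
    if malformed:
        # A payload that can never publish successfully is parked as 'invalid'.
        await storage.mark_outbox_invalid(event["event_id"], error)
    elif datetime.now(UTC) - created_at > EVENT_TTL:
        # Too old to be worth delivering.
        await storage.mark_outbox_expired(event["event_id"], error)
    elif event["retry_count"] >= MAX_RETRIES:
        # Give up; status becomes 'failed' but the retry counter stops growing.
        await storage.mark_outbox_permanently_failed(event["event_id"], error)
    else:
        # Keep the event retryable and increment retry_count.
        await storage.mark_outbox_failed(event["event_id"], error)
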
1706
+ async def cleanup_published_events(self, older_than_hours: int = 24) -> int:
1707
+ """Clean up successfully published events older than threshold."""
1708
+ session = self._get_session_for_operation()
1709
+ async with self._session_scope(session) as session:
1710
+ threshold = datetime.now(UTC) - timedelta(hours=older_than_hours)
1711
+
1712
+ result = await session.execute(
1713
+ delete(OutboxEvent).where(
1714
+ and_(
1715
+ OutboxEvent.status == "published",
1716
+ OutboxEvent.published_at < threshold,
1717
+ )
1718
+ )
1719
+ )
1720
+ await self._commit_if_not_in_transaction(session)
1721
+ return result.rowcount or 0 # type: ignore[attr-defined]
1722
+
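
cleanup_published_events is typically driven on a schedule; here is a minimal sketch, where the hourly cadence and the 24-hour retention window are arbitrary choices and storage stands in for an instance of this class.

import asyncio
import logging

cleanup_logger = logging.getLogger("outbox.cleanup")

async def purge_published_events(storage) -> None:
    while True:
        deleted = await storage.cleanup_published_events(older_than_hours=24)
        if deleted:
            cleanup_logger.info("Removed %d published outbox events", deleted)
        await asyncio.sleep(3600)  # run once per hour
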
1723
+ # -------------------------------------------------------------------------
1724
+ # Workflow Cancellation Methods
1725
+ # -------------------------------------------------------------------------
1726
+
1727
+ async def cancel_instance(self, instance_id: str, cancelled_by: str) -> bool:
1728
+ """
1729
+ Cancel a workflow instance.
1730
+
1731
+ Only running, waiting (for event or timer), or compensating workflows can be cancelled.
1732
+ This method atomically:
1733
+ 1. Checks current status
1734
+ 2. Updates status to 'cancelled' if allowed
1735
+ 3. Clears locks
1736
+ 4. Records cancellation metadata
1737
+ 5. Removes event subscriptions (if waiting for event)
1738
+ 6. Removes timer subscriptions (if waiting for timer)
1739
+
1740
+ Args:
1741
+ instance_id: Workflow instance to cancel
1742
+ cancelled_by: Who/what triggered the cancellation
1743
+
1744
+ Returns:
1745
+ True if successfully cancelled, False otherwise
1746
+
1747
+ Note: Uses LOCK operation session (separate from external session).
1748
+ """
1749
+ session = self._get_session_for_operation(is_lock_operation=True)
1750
+ async with self._session_scope(session) as session, session.begin():
1751
+ # Get current instance status
1752
+ result = await session.execute(
1753
+ select(WorkflowInstance.status).where(WorkflowInstance.instance_id == instance_id)
1754
+ )
1755
+ row = result.one_or_none()
1756
+
1757
+ if row is None:
1758
+ # Instance not found
1759
+ return False
1760
+
1761
+ current_status = row[0]
1762
+
1763
+ # Only allow cancellation of running, waiting, or compensating workflows
1764
+ # compensating workflows can be marked as cancelled after compensation completes
1765
+ if current_status not in (
1766
+ "running",
1767
+ "waiting_for_event",
1768
+ "waiting_for_timer",
1769
+ "compensating",
1770
+ ):
1771
+ # Already completed, failed, or cancelled
1772
+ return False
1773
+
1774
+ # Update status to cancelled and record metadata
1775
+ cancellation_metadata = {
1776
+ "cancelled_by": cancelled_by,
1777
+ "cancelled_at": datetime.now(UTC).isoformat(),
1778
+ "previous_status": current_status,
1779
+ }
1780
+
1781
+ await session.execute(
1782
+ update(WorkflowInstance)
1783
+ .where(WorkflowInstance.instance_id == instance_id)
1784
+ .values(
1785
+ status="cancelled",
1786
+ output_data=json.dumps(cancellation_metadata),
1787
+ locked_by=None,
1788
+ locked_at=None,
1789
+ updated_at=func.now(),
1790
+ )
1791
+ )
1792
+
1793
+ # Remove event subscriptions if waiting for event
1794
+ if current_status == "waiting_for_event":
1795
+ await session.execute(
1796
+ delete(WorkflowEventSubscription).where(
1797
+ WorkflowEventSubscription.instance_id == instance_id
1798
+ )
1799
+ )
1800
+
1801
+ # Remove timer subscriptions if waiting for timer
1802
+ if current_status == "waiting_for_timer":
1803
+ await session.execute(
1804
+ delete(WorkflowTimerSubscription).where(
1805
+ WorkflowTimerSubscription.instance_id == instance_id
1806
+ )
1807
+ )
1808
+
1809
+ return True
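
Finally, a hedged example of how a caller might drive cancel_instance, for instance from an HTTP handler; only cancel_instance itself is defined in this module, and the handler shape and storage handle are illustrative.

async def handle_cancel_request(storage, instance_id: str, user: str) -> str:
    cancelled = await storage.cancel_instance(instance_id, cancelled_by=user)
    if cancelled:
        # Status is now 'cancelled', locks are cleared, and any event or timer
        # subscriptions for the instance have been removed.
        return f"workflow {instance_id} cancelled"
    # False means the instance was not found or had already
    # completed, failed, or been cancelled.
    return f"workflow {instance_id} could not be cancelled"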