edda-framework 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,7 +8,7 @@ and transactional outbox pattern.
 
  import json
  import logging
- from collections.abc import AsyncIterator
+ from collections.abc import AsyncIterator, Awaitable, Callable
  from contextlib import asynccontextmanager
  from contextvars import ContextVar
  from dataclasses import dataclass, field
@@ -19,6 +19,7 @@ from sqlalchemy import (
      CheckConstraint,
      Column,
      DateTime,
+     ForeignKey,
      ForeignKeyConstraint,
      Index,
      Integer,
@@ -29,18 +30,21 @@ from sqlalchemy import (
      and_,
      delete,
      func,
+     inspect,
      or_,
      select,
      text,
      update,
  )
  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
- from sqlalchemy.orm import declarative_base
+ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 
  logger = logging.getLogger(__name__)
 
+
  # Declarative base for ORM models
- Base = declarative_base()
+ class Base(DeclarativeBase):
+     pass
 
 
  # ============================================================================
@@ -48,25 +52,25 @@ Base = declarative_base()
  # ============================================================================
 
 
- class SchemaVersion(Base):  # type: ignore[valid-type, misc]
+ class SchemaVersion(Base):
      """Schema version tracking."""
 
      __tablename__ = "schema_version"
 
-     version = Column(Integer, primary_key=True)
-     applied_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
-     description = Column(Text, nullable=False)
+     version: Mapped[int] = mapped_column(Integer, primary_key=True)
+     applied_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+     description: Mapped[str] = mapped_column(Text)
 
 
- class WorkflowDefinition(Base):  # type: ignore[valid-type, misc]
+ class WorkflowDefinition(Base):
      """Workflow definition (source code storage)."""
 
      __tablename__ = "workflow_definitions"
 
-     workflow_name = Column(String(255), nullable=False, primary_key=True)
-     source_hash = Column(String(64), nullable=False, primary_key=True)
-     source_code = Column(Text, nullable=False)
-     created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+     workflow_name: Mapped[str] = mapped_column(String(255), primary_key=True)
+     source_hash: Mapped[str] = mapped_column(String(64), primary_key=True)
+     source_code: Mapped[str] = mapped_column(Text)
+     created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
 
      __table_args__ = (
          Index("idx_definitions_name", "workflow_name"),
@@ -74,31 +78,34 @@ class WorkflowDefinition(Base): # type: ignore[valid-type, misc]
  )
 
 
- class WorkflowInstance(Base):  # type: ignore[valid-type, misc]
+ class WorkflowInstance(Base):
      """Workflow instance with distributed locking support."""
 
      __tablename__ = "workflow_instances"
 
-     instance_id = Column(String(255), primary_key=True)
-     workflow_name = Column(String(255), nullable=False)
-     source_hash = Column(String(64), nullable=False)
-     owner_service = Column(String(255), nullable=False)
-     status = Column(
-         String(50),
-         nullable=False,
-         server_default=text("'running'"),
+     instance_id: Mapped[str] = mapped_column(String(255), primary_key=True)
+     workflow_name: Mapped[str] = mapped_column(String(255))
+     source_hash: Mapped[str] = mapped_column(String(64))
+     owner_service: Mapped[str] = mapped_column(String(255))
+     status: Mapped[str] = mapped_column(String(50), server_default=text("'running'"))
+     current_activity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
+     continued_from: Mapped[str | None] = mapped_column(
+         String(255), ForeignKey("workflow_instances.instance_id"), nullable=True
      )
-     current_activity_id = Column(String(255), nullable=True)
-     started_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
-     updated_at = Column(
-         DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
+     started_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+     updated_at: Mapped[datetime] = mapped_column(
+         DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
      )
-     input_data = Column(Text, nullable=False)  # JSON
-     output_data = Column(Text, nullable=True)  # JSON
-     locked_by = Column(String(255), nullable=True)
-     locked_at = Column(DateTime(timezone=True), nullable=True)
-     lock_timeout_seconds = Column(Integer, nullable=True)  # None = use global default (300s)
-     lock_expires_at = Column(DateTime(timezone=True), nullable=True)  # Absolute expiry time
+     input_data: Mapped[str] = mapped_column(Text)  # JSON
+     output_data: Mapped[str | None] = mapped_column(Text, nullable=True)  # JSON
+     locked_by: Mapped[str | None] = mapped_column(String(255), nullable=True)
+     locked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+     lock_timeout_seconds: Mapped[int | None] = mapped_column(
+         Integer, nullable=True
+     )  # None = use global default (300s)
+     lock_expires_at: Mapped[datetime | None] = mapped_column(
+         DateTime(timezone=True), nullable=True
+     )  # Absolute expiry time
 
      __table_args__ = (
          ForeignKeyConstraint(
@@ -107,7 +114,7 @@ class WorkflowInstance(Base): # type: ignore[valid-type, misc]
          ),
          CheckConstraint(
              "status IN ('running', 'completed', 'failed', 'waiting_for_event', "
-             "'waiting_for_timer', 'compensating', 'cancelled')",
+             "'waiting_for_timer', 'waiting_for_message', 'compensating', 'cancelled', 'recurred')",
              name="valid_status",
          ),
          Index("idx_instances_status", "status"),
@@ -117,25 +124,29 @@ class WorkflowInstance(Base): # type: ignore[valid-type, misc]
          Index("idx_instances_lock_expires", "lock_expires_at"),
          Index("idx_instances_updated", "updated_at"),
          Index("idx_instances_hash", "source_hash"),
+         Index("idx_instances_continued_from", "continued_from"),
+         # Composite index for find_resumable_workflows(): WHERE status='running' AND locked_by IS NULL
+         Index("idx_instances_resumable", "status", "locked_by"),
      )
 
 
- class WorkflowHistory(Base):  # type: ignore[valid-type, misc]
+ class WorkflowHistory(Base):
      """Workflow execution history (for deterministic replay)."""
 
      __tablename__ = "workflow_history"
 
-     id = Column(Integer, primary_key=True, autoincrement=True)
-     instance_id = Column(
-         String(255),
-         nullable=False,
-     )
-     activity_id = Column(String(255), nullable=False)
-     event_type = Column(String(100), nullable=False)
-     data_type = Column(String(10), nullable=False)  # 'json' or 'binary'
-     event_data = Column(Text, nullable=True)  # JSON (when data_type='json')
-     event_data_binary = Column(LargeBinary, nullable=True)  # Binary (when data_type='binary')
-     created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+     id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+     instance_id: Mapped[str] = mapped_column(String(255))
+     activity_id: Mapped[str] = mapped_column(String(255))
+     event_type: Mapped[str] = mapped_column(String(100))
+     data_type: Mapped[str] = mapped_column(String(10))  # 'json' or 'binary'
+     event_data: Mapped[str | None] = mapped_column(
+         Text, nullable=True
+     )  # JSON (when data_type='json')
+     event_data_binary: Mapped[bytes | None] = mapped_column(
+         LargeBinary, nullable=True
+     )  # Binary (when data_type='binary')
+     created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
 
      __table_args__ = (
          ForeignKeyConstraint(
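The comment in the hunk above spells out the query that the new composite index serves. A sketch of that query shape (not part of the diff), using the models defined in this file:

    from sqlalchemy import select

    # WHERE status='running' AND locked_by IS NULL - a prefix match on
    # idx_instances_resumable ("status", "locked_by")
    resumable = select(WorkflowInstance.instance_id).where(
        WorkflowInstance.status == "running",
        WorkflowInstance.locked_by.is_(None),
    )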
@@ -158,20 +169,20 @@ class WorkflowHistory(Base): # type: ignore[valid-type, misc]
      )
 
 
- class WorkflowCompensation(Base):  # type: ignore[valid-type, misc]
-     """Compensation transactions (LIFO stack for Saga pattern)."""
+ class WorkflowHistoryArchive(Base):
+     """Archived workflow execution history (for recur pattern)."""
 
-     __tablename__ = "workflow_compensations"
+     __tablename__ = "workflow_history_archive"
 
-     id = Column(Integer, primary_key=True, autoincrement=True)
-     instance_id = Column(
-         String(255),
-         nullable=False,
+     id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+     instance_id: Mapped[str] = mapped_column(String(255))
+     activity_id: Mapped[str] = mapped_column(String(255))
+     event_type: Mapped[str] = mapped_column(String(100))
+     event_data: Mapped[str] = mapped_column(Text)  # JSON (includes both types for archive)
+     created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
+     archived_at: Mapped[datetime] = mapped_column(
+         DateTime(timezone=True), server_default=func.now()
      )
-     activity_id = Column(String(255), nullable=False)
-     activity_name = Column(String(255), nullable=False)
-     args = Column(Text, nullable=False)  # JSON
-     created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
 
      __table_args__ = (
          ForeignKeyConstraint(
@@ -179,24 +190,22 @@ class WorkflowCompensation(Base): # type: ignore[valid-type, misc]
              ["workflow_instances.instance_id"],
              ondelete="CASCADE",
          ),
-         Index("idx_compensations_instance", "instance_id", "created_at"),
+         Index("idx_history_archive_instance", "instance_id"),
+         Index("idx_history_archive_archived", "archived_at"),
      )
 
 
- class WorkflowEventSubscription(Base):  # type: ignore[valid-type, misc]
-     """Event subscriptions (for wait_event)."""
+ class WorkflowCompensation(Base):
+     """Compensation transactions (LIFO stack for Saga pattern)."""
 
-     __tablename__ = "workflow_event_subscriptions"
+     __tablename__ = "workflow_compensations"
 
-     id = Column(Integer, primary_key=True, autoincrement=True)
-     instance_id = Column(
-         String(255),
-         nullable=False,
-     )
-     event_type = Column(String(255), nullable=False)
-     activity_id = Column(String(255), nullable=True)
-     timeout_at = Column(DateTime(timezone=True), nullable=True)
-     created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+     id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+     instance_id: Mapped[str] = mapped_column(String(255))
+     activity_id: Mapped[str] = mapped_column(String(255))
+     activity_name: Mapped[str] = mapped_column(String(255))
+     args: Mapped[str] = mapped_column(Text)  # JSON
+     created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
 
      __table_args__ = (
          ForeignKeyConstraint(
@@ -204,27 +213,21 @@ class WorkflowEventSubscription(Base): # type: ignore[valid-type, misc]
              ["workflow_instances.instance_id"],
              ondelete="CASCADE",
          ),
-         UniqueConstraint("instance_id", "event_type", name="unique_instance_event"),
-         Index("idx_subscriptions_event", "event_type"),
-         Index("idx_subscriptions_timeout", "timeout_at"),
-         Index("idx_subscriptions_instance", "instance_id"),
+         Index("idx_compensations_instance", "instance_id", "created_at"),
      )
 
 
- class WorkflowTimerSubscription(Base):  # type: ignore[valid-type, misc]
+ class WorkflowTimerSubscription(Base):
      """Timer subscriptions (for wait_timer)."""
 
      __tablename__ = "workflow_timer_subscriptions"
 
-     id = Column(Integer, primary_key=True, autoincrement=True)
-     instance_id = Column(
-         String(255),
-         nullable=False,
-     )
-     timer_id = Column(String(255), nullable=False)
-     expires_at = Column(DateTime(timezone=True), nullable=False)
-     activity_id = Column(String(255), nullable=True)
-     created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+     id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+     instance_id: Mapped[str] = mapped_column(String(255))
+     timer_id: Mapped[str] = mapped_column(String(255))
+     expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
+     activity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
+     created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
 
      __table_args__ = (
          ForeignKeyConstraint(
@@ -238,23 +241,29 @@ class WorkflowTimerSubscription(Base): # type: ignore[valid-type, misc]
      )
 
 
- class OutboxEvent(Base):  # type: ignore[valid-type, misc]
+ class OutboxEvent(Base):
      """Transactional outbox pattern events."""
 
      __tablename__ = "outbox_events"
 
-     event_id = Column(String(255), primary_key=True)
-     event_type = Column(String(255), nullable=False)
-     event_source = Column(String(255), nullable=False)
-     data_type = Column(String(10), nullable=False)  # 'json' or 'binary'
-     event_data = Column(Text, nullable=True)  # JSON (when data_type='json')
-     event_data_binary = Column(LargeBinary, nullable=True)  # Binary (when data_type='binary')
-     content_type = Column(String(100), nullable=False, server_default=text("'application/json'"))
-     created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
-     published_at = Column(DateTime(timezone=True), nullable=True)
-     status = Column(String(50), nullable=False, server_default=text("'pending'"))
-     retry_count = Column(Integer, nullable=False, server_default=text("0"))
-     last_error = Column(Text, nullable=True)
+     event_id: Mapped[str] = mapped_column(String(255), primary_key=True)
+     event_type: Mapped[str] = mapped_column(String(255))
+     event_source: Mapped[str] = mapped_column(String(255))
+     data_type: Mapped[str] = mapped_column(String(10))  # 'json' or 'binary'
+     event_data: Mapped[str | None] = mapped_column(
+         Text, nullable=True
+     )  # JSON (when data_type='json')
+     event_data_binary: Mapped[bytes | None] = mapped_column(
+         LargeBinary, nullable=True
+     )  # Binary (when data_type='binary')
+     content_type: Mapped[str] = mapped_column(
+         String(100), server_default=text("'application/json'")
+     )
+     created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+     published_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+     status: Mapped[str] = mapped_column(String(50), server_default=text("'pending'"))
+     retry_count: Mapped[int] = mapped_column(Integer, server_default=text("0"))
+     last_error: Mapped[str | None] = mapped_column(Text, nullable=True)
 
      __table_args__ = (
          CheckConstraint(
@@ -276,6 +285,178 @@ class OutboxEvent(Base): # type: ignore[valid-type, misc]
      )
 
 
+ class WorkflowGroupMembership(Base):
+     """Group memberships (Erlang pg style for broadcast messaging)."""
+
+     __tablename__ = "workflow_group_memberships"
+
+     id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+     instance_id: Mapped[str] = mapped_column(String(255))
+     group_name: Mapped[str] = mapped_column(String(255))
+     joined_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+
+     __table_args__ = (
+         ForeignKeyConstraint(
+             ["instance_id"],
+             ["workflow_instances.instance_id"],
+             ondelete="CASCADE",
+         ),
+         UniqueConstraint("instance_id", "group_name", name="unique_instance_group"),
+         Index("idx_group_memberships_group", "group_name"),
+         Index("idx_group_memberships_instance", "instance_id"),
+     )
+
+
+ # =============================================================================
+ # Channel-based Message Queue Models
+ # =============================================================================
+
+
+ class ChannelMessage(Base):
+     """Channel message queue (persistent message storage)."""
+
+     __tablename__ = "channel_messages"
+
+     id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+     channel: Mapped[str] = mapped_column(String(255))
+     message_id: Mapped[str] = mapped_column(String(255), unique=True)
+     data_type: Mapped[str] = mapped_column(String(10))  # 'json' or 'binary'
+     data: Mapped[str | None] = mapped_column(Text, nullable=True)  # JSON (when data_type='json')
+     data_binary: Mapped[bytes | None] = mapped_column(
+         LargeBinary, nullable=True
+     )  # Binary (when data_type='binary')
+     message_metadata: Mapped[str | None] = mapped_column(
+         "metadata", Text, nullable=True
+     )  # JSON - renamed to avoid SQLAlchemy reserved name
+     published_at: Mapped[datetime] = mapped_column(
+         DateTime(timezone=True), server_default=func.now()
+     )
+
+     __table_args__ = (
+         CheckConstraint(
+             "data_type IN ('json', 'binary')",
+             name="channel_valid_data_type",
+         ),
+         CheckConstraint(
+             "(data_type = 'json' AND data IS NOT NULL AND data_binary IS NULL) OR "
+             "(data_type = 'binary' AND data IS NULL AND data_binary IS NOT NULL)",
+             name="channel_data_type_consistency",
+         ),
+         Index("idx_channel_messages_channel", "channel", "published_at"),
+         Index("idx_channel_messages_id", "id"),
+     )
+
+
+ class ChannelSubscription(Base):
+     """Channel subscriptions for workflow instances."""
+
+     __tablename__ = "channel_subscriptions"
+
+     id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+     instance_id: Mapped[str] = mapped_column(String(255))
+     channel: Mapped[str] = mapped_column(String(255))
+     mode: Mapped[str] = mapped_column(String(20))  # 'broadcast' or 'competing'
+     activity_id: Mapped[str | None] = mapped_column(
+         String(255), nullable=True
+     )  # Set when waiting for message
+     cursor_message_id: Mapped[int | None] = mapped_column(
+         Integer, nullable=True
+     )  # Last received message id (broadcast)
+     timeout_at: Mapped[datetime | None] = mapped_column(
+         DateTime(timezone=True), nullable=True
+     )  # Timeout deadline
+     subscribed_at: Mapped[datetime] = mapped_column(
+         DateTime(timezone=True), server_default=func.now()
+     )
+
+     __table_args__ = (
+         ForeignKeyConstraint(
+             ["instance_id"],
+             ["workflow_instances.instance_id"],
+             ondelete="CASCADE",
+         ),
+         CheckConstraint(
+             "mode IN ('broadcast', 'competing')",
+             name="channel_valid_mode",
+         ),
+         UniqueConstraint("instance_id", "channel", name="unique_channel_instance_channel"),
+         Index("idx_channel_subscriptions_channel", "channel"),
+         Index("idx_channel_subscriptions_instance", "instance_id"),
+         Index("idx_channel_subscriptions_waiting", "channel", "activity_id"),
+         Index("idx_channel_subscriptions_timeout", "timeout_at"),
+     )
+
+
+ class ChannelDeliveryCursor(Base):
+     """Channel delivery cursors (broadcast mode: track who read what)."""
+
+     __tablename__ = "channel_delivery_cursors"
+
+     id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+     channel: Mapped[str] = mapped_column(String(255))
+     instance_id: Mapped[str] = mapped_column(String(255))
+     last_delivered_id: Mapped[int] = mapped_column(Integer)
+     updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+
+     __table_args__ = (
+         ForeignKeyConstraint(
+             ["instance_id"],
+             ["workflow_instances.instance_id"],
+             ondelete="CASCADE",
+         ),
+         UniqueConstraint("channel", "instance_id", name="unique_channel_delivery_cursor"),
+         Index("idx_channel_delivery_cursors_channel", "channel"),
+     )
+
+
+ class ChannelMessageClaim(Base):
+     """Channel message claims (competing mode: who is processing what)."""
+
+     __tablename__ = "channel_message_claims"
+
+     message_id: Mapped[str] = mapped_column(String(255), primary_key=True)
+     instance_id: Mapped[str] = mapped_column(String(255))
+     claimed_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+
+     __table_args__ = (
+         ForeignKeyConstraint(
+             ["message_id"],
+             ["channel_messages.message_id"],
+             ondelete="CASCADE",
+         ),
+         ForeignKeyConstraint(
+             ["instance_id"],
+             ["workflow_instances.instance_id"],
+             ondelete="CASCADE",
+         ),
+         Index("idx_channel_message_claims_instance", "instance_id"),
+     )
+
+
+ # =============================================================================
+ # System-level Lock Models (for background task coordination)
+ # =============================================================================
+
+
+ class SystemLock(Base):
+     """System-level locks for coordinating background tasks across pods.
+
+     Used to prevent duplicate execution of operational tasks like:
+     - cleanup_stale_locks_periodically()
+     - auto_resume_stale_workflows_periodically()
+     - _cleanup_old_messages_periodically()
+     """
+
+     __tablename__ = "system_locks"
+
+     lock_name: Mapped[str] = mapped_column(String(255), primary_key=True)
+     locked_by: Mapped[str | None] = mapped_column(String(255), nullable=True)
+     locked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+     lock_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+
+     __table_args__ = (Index("idx_system_locks_expires", "lock_expires_at"),)
+
+
  # Current schema version
  CURRENT_SCHEMA_VERSION = 1
 
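The channel_data_type_consistency CHECK above admits exactly one of data/data_binary, matching data_type. A sketch of rows that satisfy it (not part of the diff; the channel name is illustrative):

    import json
    import uuid

    json_msg = ChannelMessage(
        channel="orders",
        message_id=str(uuid.uuid4()),
        data_type="json",
        data=json.dumps({"order_id": 42}),  # data set, data_binary stays NULL
    )
    binary_msg = ChannelMessage(
        channel="orders",
        message_id=str(uuid.uuid4()),
        data_type="binary",
        data_binary=b"\x00\x01",  # data_binary set, data stays NULL
    )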
@@ -302,6 +483,9 @@ class TransactionContext:
      session: "AsyncSession | None" = None
      """The actual session for this transaction"""
 
+     post_commit_callbacks: list[Callable[[], Awaitable[None]]] = field(default_factory=list)
+     """Callbacks to execute after successful top-level commit"""
+
 
  # Context variable for transaction state (asyncio-safe)
  _transaction_context: ContextVar[TransactionContext | None] = ContextVar(
@@ -337,11 +521,22 @@ class SQLAlchemyStorage:
          self.engine = engine
 
      async def initialize(self) -> None:
-         """Initialize database connection and create tables."""
+         """Initialize database connection and create tables.
+
+         This method creates all tables if they don't exist, and then performs
+         automatic schema migration to add any missing columns and update CHECK
+         constraints. This ensures compatibility when upgrading Edda versions.
+         """
          # Create all tables and indexes
          async with self.engine.begin() as conn:
              await conn.run_sync(Base.metadata.create_all)
 
+         # Auto-migrate schema (add missing columns)
+         await self._auto_migrate_schema()
+
+         # Auto-migrate CHECK constraints
+         await self._auto_migrate_check_constraints()
+
          # Initialize schema version
          await self._initialize_schema_version()
 
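Calling initialize() on 0.9.0 therefore upgrades a 0.7.0 database in place. A usage sketch (not part of the diff), assuming SQLAlchemyStorage is constructed directly from an AsyncEngine as the __init__ shown above suggests:

    from sqlalchemy.ext.asyncio import create_async_engine

    engine = create_async_engine("sqlite+aiosqlite:///edda.db")  # illustrative URL
    storage = SQLAlchemyStorage(engine)
    await storage.initialize()  # create_all, column/constraint auto-migration, schema version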
@@ -366,6 +561,198 @@ class SQLAlchemyStorage:
              await session.commit()
              logger.info(f"Initialized schema version to {CURRENT_SCHEMA_VERSION}")
 
+     async def _auto_migrate_schema(self) -> None:
+         """
+         Automatically add missing columns to existing tables.
+
+         This method compares the ORM model definitions with the actual database
+         schema and adds any missing columns using ALTER TABLE ADD COLUMN.
+
+         Note: This only handles column additions, not removals or type changes.
+         For complex migrations, use Alembic.
+         """
+
+         def _get_column_type_sql(column: Column, dialect_name: str) -> str:  # type: ignore[type-arg]
+             """Get SQL type string for a column based on dialect."""
+             col_type = column.type
+
+             # Map SQLAlchemy types to SQL types
+             if isinstance(col_type, String):
+                 length = col_type.length or 255
+                 return f"VARCHAR({length})"
+             elif isinstance(col_type, Text):
+                 return "TEXT"
+             elif isinstance(col_type, Integer):
+                 return "INTEGER"
+             elif isinstance(col_type, DateTime):
+                 if dialect_name == "postgresql":
+                     return "TIMESTAMP WITH TIME ZONE" if col_type.timezone else "TIMESTAMP"
+                 elif dialect_name == "mysql":
+                     return "DATETIME" if not col_type.timezone else "DATETIME"
+                 else:  # sqlite
+                     return "DATETIME"
+             elif isinstance(col_type, LargeBinary):
+                 if dialect_name == "postgresql":
+                     return "BYTEA"
+                 elif dialect_name == "mysql":
+                     return "LONGBLOB"
+                 else:  # sqlite
+                     return "BLOB"
+             else:
+                 # Fallback to compiled type
+                 return str(col_type.compile(dialect=self.engine.dialect))
+
+         def _get_default_sql(column: Column, _dialect_name: str) -> str | None:  # type: ignore[type-arg]
+             """Get DEFAULT clause for a column if applicable."""
+             if column.server_default is not None:
+                 # Handle text() server defaults - try to get the arg attribute
+                 server_default = column.server_default
+                 if hasattr(server_default, "arg"):
+                     default_val = server_default.arg
+                     if hasattr(default_val, "text"):
+                         return f"DEFAULT {default_val.text}"
+                     return f"DEFAULT {default_val}"
+             return None
+
+         def _run_migration(conn: Any) -> None:
+             """Run migration in sync context."""
+             dialect_name = self.engine.dialect.name
+             inspector = inspect(conn)
+
+             # Iterate through all ORM tables
+             for table in Base.metadata.tables.values():
+                 table_name = table.name
+
+                 # Check if table exists
+                 if not inspector.has_table(table_name):
+                     logger.debug(f"Table {table_name} does not exist, skipping migration")
+                     continue
+
+                 # Get existing columns
+                 existing_columns = {col["name"] for col in inspector.get_columns(table_name)}
+
+                 # Check each column in the ORM model
+                 for column in table.columns:
+                     if column.name not in existing_columns:
+                         # Column is missing, generate ALTER TABLE
+                         col_type_sql = _get_column_type_sql(column, dialect_name)
+                         nullable = "NULL" if column.nullable else "NOT NULL"
+
+                         # Build ALTER TABLE statement
+                         alter_sql = (
+                             f'ALTER TABLE "{table_name}" ADD COLUMN "{column.name}" {col_type_sql}'
+                         )
+
+                         # Add nullable constraint (only if NOT NULL and has default)
+                         default_sql = _get_default_sql(column, dialect_name)
+                         if not column.nullable and default_sql:
+                             alter_sql += f" {default_sql} {nullable}"
+                         elif column.nullable:
+                             alter_sql += f" {nullable}"
+                         elif default_sql:
+                             alter_sql += f" {default_sql}"
+                         # For NOT NULL without default, just add the column as nullable
+                         # (PostgreSQL requires default or nullable for existing rows)
+                         else:
+                             alter_sql += " NULL"
+
+                         logger.info(f"Auto-migrating: Adding column {column.name} to {table_name}")
+                         logger.debug(f"Executing: {alter_sql}")
+
+                         try:
+                             conn.execute(text(alter_sql))
+                         except Exception as e:
+                             logger.warning(
+                                 f"Failed to add column {column.name} to {table_name}: {e}"
+                             )
+
+         async with self.engine.begin() as conn:
+             await conn.run_sync(_run_migration)
+
+     async def _auto_migrate_check_constraints(self) -> None:
+         """
+         Automatically update CHECK constraints for workflow status.
+
+         This method ensures the valid_status CHECK constraint includes all
+         required status values (including 'waiting_for_message').
+         """
+         dialect_name = self.engine.dialect.name
+
+         # SQLite doesn't support ALTER CONSTRAINT easily, and SQLAlchemy create_all
+         # handles it correctly for new databases. For existing SQLite databases,
+         # the constraint is more lenient (CHECK is not enforced in many SQLite versions).
+         if dialect_name == "sqlite":
+             return
+
+         # Expected status values (must match WorkflowInstance model)
+         expected_statuses = (
+             "'running', 'completed', 'failed', 'waiting_for_event', "
+             "'waiting_for_timer', 'waiting_for_message', 'compensating', 'cancelled', 'recurred'"
+         )
+
+         def _run_constraint_migration(conn: Any) -> None:
+             """Run CHECK constraint migration in sync context."""
+             inspector = inspect(conn)
+
+             # Check if workflow_instances table exists
+             if not inspector.has_table("workflow_instances"):
+                 return
+
+             # Get existing CHECK constraints
+             try:
+                 constraints = inspector.get_check_constraints("workflow_instances")
+             except NotImplementedError:
+                 # Some databases don't support get_check_constraints
+                 logger.debug("Database doesn't support get_check_constraints inspection")
+                 constraints = []
+
+             # Find the valid_status constraint
+             valid_status_constraint = None
+             for constraint in constraints:
+                 if constraint.get("name") == "valid_status":
+                     valid_status_constraint = constraint
+                     break
+
+             # Check if constraint exists and needs updating
+             if valid_status_constraint:
+                 sqltext = valid_status_constraint.get("sqltext", "")
+                 # Check if 'waiting_for_message' is already in the constraint
+                 if "waiting_for_message" in sqltext:
+                     logger.debug("valid_status constraint already includes waiting_for_message")
+                     return
+
+                 # Need to update the constraint
+                 logger.info("Updating valid_status CHECK constraint to include waiting_for_message")
+                 try:
+                     if dialect_name == "postgresql":
+                         conn.execute(
+                             text("ALTER TABLE workflow_instances DROP CONSTRAINT valid_status")
+                         )
+                         conn.execute(
+                             text(
+                                 f"ALTER TABLE workflow_instances ADD CONSTRAINT valid_status "
+                                 f"CHECK (status IN ({expected_statuses}))"
+                             )
+                         )
+                     elif dialect_name == "mysql":
+                         # MySQL uses DROP CHECK and ADD CONSTRAINT CHECK syntax
+                         conn.execute(text("ALTER TABLE workflow_instances DROP CHECK valid_status"))
+                         conn.execute(
+                             text(
+                                 f"ALTER TABLE workflow_instances ADD CONSTRAINT valid_status "
+                                 f"CHECK (status IN ({expected_statuses}))"
+                             )
+                         )
+                     logger.info("Successfully updated valid_status CHECK constraint")
+                 except Exception as e:
+                     logger.warning(f"Failed to update valid_status CHECK constraint: {e}")
+             else:
+                 # Constraint doesn't exist (shouldn't happen with create_all, but handle it)
+                 logger.debug("valid_status constraint not found, will be created by create_all")
+
+         async with self.engine.begin() as conn:
+             await conn.run_sync(_run_constraint_migration)
+
      def _get_session_for_operation(self, is_lock_operation: bool = False) -> AsyncSession:
          """
          Get the appropriate session for an operation.
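For a concrete sense of what _run_migration emits: on a 0.7.0 database missing the new nullable continued_from column, the String -> VARCHAR mapping and the column.nullable branch above would produce roughly the following (a derived example, not part of the diff):

    expected_sql = (
        'ALTER TABLE "workflow_instances" '
        'ADD COLUMN "continued_from" VARCHAR(255) NULL'
    )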
@@ -452,7 +839,7 @@ class SQLAlchemyStorage:
          Example:
              >>> # SQLite: datetime(timeout_at) <= datetime('now')
              >>> # PostgreSQL/MySQL: timeout_at <= NOW()
-             >>> self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
+             >>> self._make_datetime_comparable(ChannelSubscription.timeout_at)
              >>> <= self._get_current_time_expr()
          """
          if self.engine.dialect.name == "sqlite":
@@ -504,11 +891,16 @@ class SQLAlchemyStorage:
          if ctx is None or ctx.depth == 0:
              raise RuntimeError("Not in a transaction")
 
+         # Capture callbacks before any state changes
+         callbacks_to_execute: list[Callable[[], Awaitable[None]]] = []
+
          if ctx.depth == 1:
              # Top-level transaction - commit the session
              logger.debug("Committing top-level transaction")
              await ctx.session.commit()  # type: ignore[union-attr]
              await ctx.session.close()  # type: ignore[union-attr]
+             # Capture callbacks to execute after clearing context
+             callbacks_to_execute = ctx.post_commit_callbacks.copy()
          else:
              # Nested transaction - commit the savepoint
              nested_tx = ctx.savepoint_stack.pop()
@@ -520,6 +912,12 @@ class SQLAlchemyStorage:
          if ctx.depth == 0:
              # All transactions completed - clear context
              _transaction_context.set(None)
+             # Execute post-commit callbacks after successful top-level commit
+             for callback in callbacks_to_execute:
+                 try:
+                     await callback()
+                 except Exception as e:
+                     logger.error(f"Post-commit callback failed: {e}")
 
      async def rollback_transaction(self) -> None:
          """
@@ -559,6 +957,26 @@ class SQLAlchemyStorage:
          ctx = _transaction_context.get()
          return ctx is not None and ctx.depth > 0
 
+     def register_post_commit_callback(self, callback: Callable[[], Awaitable[None]]) -> None:
+         """
+         Register a callback to be executed after the current transaction commits.
+
+         The callback will be executed after the top-level transaction commits successfully.
+         If the transaction is rolled back, the callback will NOT be executed.
+         If not in a transaction, the callback will be executed immediately.
+
+         Args:
+             callback: An async function to call after commit.
+
+         Raises:
+             RuntimeError: If not in a transaction.
+         """
+         ctx = _transaction_context.get()
+         if ctx is None or ctx.depth == 0:
+             raise RuntimeError("Cannot register post-commit callback: not in a transaction")
+         ctx.post_commit_callbacks.append(callback)
+         logger.debug(f"Registered post-commit callback: {callback}")
+
      async def _commit_if_not_in_transaction(self, session: AsyncSession) -> None:
          """
          Commit session if not in a transaction (auto-commit mode).
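A usage sketch for the new hook (not part of the diff; the callback is hypothetical, and begin_transaction is assumed as the counterpart to commit_transaction). One wrinkle worth flagging: the docstring above says a callback runs immediately when not in a transaction, but the implementation shown raises RuntimeError in that case, so register callbacks only while a transaction is open:

    async def publish_outbox_batch() -> None:  # hypothetical post-commit work
        ...

    await storage.begin_transaction()  # assumed counterpart of commit_transaction()
    try:
        ...  # transactional writes
        storage.register_post_commit_callback(publish_outbox_batch)
        await storage.commit_transaction()  # callback fires after the top-level commit
    except Exception:
        await storage.rollback_transaction()  # callback is dropped with the transaction
        raise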
@@ -605,7 +1023,7 @@ class SQLAlchemyStorage:
 
              if existing:
                  # Update
-                 existing.source_code = source_code  # type: ignore[assignment]
+                 existing.source_code = source_code
              else:
                  # Insert
                  definition = WorkflowDefinition(
@@ -682,6 +1100,7 @@ class SQLAlchemyStorage:
          owner_service: str,
          input_data: dict[str, Any],
          lock_timeout_seconds: int | None = None,
+         continued_from: str | None = None,
      ) -> None:
          """Create a new workflow instance."""
          session = self._get_session_for_operation()
@@ -693,6 +1112,7 @@ class SQLAlchemyStorage:
                  owner_service=owner_service,
                  input_data=json.dumps(input_data),
                  lock_timeout_seconds=lock_timeout_seconds,
+                 continued_from=continued_from,
              )
              session.add(instance)
 
@@ -1165,6 +1585,103 @@ class SQLAlchemyStorage:
              await session.commit()
              return workflows_to_resume
 
+     # -------------------------------------------------------------------------
+     # System-level Locking Methods (for background task coordination)
+     # -------------------------------------------------------------------------
+
+     async def try_acquire_system_lock(
+         self,
+         lock_name: str,
+         worker_id: str,
+         timeout_seconds: int = 60,
+     ) -> bool:
+         """
+         Try to acquire a system-level lock for coordinating background tasks.
+
+         Uses INSERT ON CONFLICT pattern to handle race conditions:
+         1. Try to INSERT new lock record
+         2. If exists, check if expired or unlocked
+         3. If available, acquire lock; otherwise return False
+
+         Note: ALWAYS uses separate session (not external session).
+         """
+         session = self._get_session_for_operation(is_lock_operation=True)
+         async with self._session_scope(session) as session:
+             current_time = datetime.now(UTC)
+             lock_expires_at = current_time + timedelta(seconds=timeout_seconds)
+
+             # Try to get existing lock
+             result = await session.execute(
+                 select(SystemLock).where(SystemLock.lock_name == lock_name)
+             )
+             lock = result.scalar_one_or_none()
+
+             if lock is None:
+                 # No lock exists - create new one
+                 lock = SystemLock(
+                     lock_name=lock_name,
+                     locked_by=worker_id,
+                     locked_at=current_time,
+                     lock_expires_at=lock_expires_at,
+                 )
+                 session.add(lock)
+                 await session.commit()
+                 return True
+
+             # Lock exists - check if available
+             if lock.locked_by is None:
+                 # Unlocked - acquire
+                 lock.locked_by = worker_id
+                 lock.locked_at = current_time
+                 lock.lock_expires_at = lock_expires_at
+                 await session.commit()
+                 return True
+
+             # Check if expired (use SQL-side comparison for cross-DB compatibility)
+             if lock.lock_expires_at is not None:
+                 # Handle timezone-naive datetime from SQLite
+                 lock_expires = (
+                     lock.lock_expires_at.replace(tzinfo=UTC)
+                     if lock.lock_expires_at.tzinfo is None
+                     else lock.lock_expires_at
+                 )
+                 if lock_expires <= current_time:
+                     # Expired - acquire
+                     lock.locked_by = worker_id
+                     lock.locked_at = current_time
+                     lock.lock_expires_at = lock_expires_at
+                     await session.commit()
+                     return True
+
+             # Already locked by another worker
+             return False
+
+     async def release_system_lock(self, lock_name: str, worker_id: str) -> None:
+         """
+         Release a system-level lock.
+
+         Only releases the lock if it's held by the specified worker.
+
+         Note: ALWAYS uses separate session (not external session).
+         """
+         session = self._get_session_for_operation(is_lock_operation=True)
+         async with self._session_scope(session) as session:
+             await session.execute(
+                 update(SystemLock)
+                 .where(
+                     and_(
+                         SystemLock.lock_name == lock_name,
+                         SystemLock.locked_by == worker_id,
+                     )
+                 )
+                 .values(
+                     locked_by=None,
+                     locked_at=None,
+                     lock_expires_at=None,
+                 )
+             )
+             await session.commit()
+
      # -------------------------------------------------------------------------
      # History Methods (prefer external session)
      # -------------------------------------------------------------------------
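A sketch of the intended call pattern for these two methods (not part of the diff; the loop body and sleep interval are illustrative, the task name is taken from the SystemLock docstring):

    import asyncio

    async def cleanup_stale_locks_periodically(storage: SQLAlchemyStorage, worker_id: str) -> None:
        # Only one pod at a time wins the lock; the others skip this round.
        while True:
            if await storage.try_acquire_system_lock("cleanup_stale_locks", worker_id, timeout_seconds=60):
                try:
                    ...  # the actual cleanup work
                finally:
                    await storage.release_system_lock("cleanup_stale_locks", worker_id)
            await asyncio.sleep(30)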
@@ -1231,6 +1748,107 @@ class SQLAlchemyStorage:
                  for row in rows
              ]
 
+     async def archive_history(self, instance_id: str) -> int:
+         """
+         Archive workflow history for the recur pattern.
+
+         Moves all history entries from workflow_history to workflow_history_archive.
+         Binary data is converted to base64 for JSON storage in the archive.
+
+         Returns:
+             Number of history entries archived
+         """
+         import base64
+
+         session = self._get_session_for_operation()
+         async with self._session_scope(session) as session:
+             # Get all history entries for this instance
+             result = await session.execute(
+                 select(WorkflowHistory)
+                 .where(WorkflowHistory.instance_id == instance_id)
+                 .order_by(WorkflowHistory.created_at.asc())
+             )
+             history_rows = result.scalars().all()
+
+             if not history_rows:
+                 return 0
+
+             # Archive each history entry
+             for row in history_rows:
+                 # Convert event_data to JSON string for archive
+                 event_data_json: str | None
+                 if row.data_type == "binary" and row.event_data_binary is not None:
+                     # Convert binary to base64 for JSON storage
+                     event_data_json = json.dumps(
+                         {
+                             "_binary": True,
+                             "data": base64.b64encode(row.event_data_binary).decode("ascii"),
+                         }
+                     )
+                 else:
+                     # Already JSON, use as-is
+                     event_data_json = row.event_data
+
+                 archive_entry = WorkflowHistoryArchive(
+                     instance_id=row.instance_id,
+                     activity_id=row.activity_id,
+                     event_type=row.event_type,
+                     event_data=event_data_json,
+                     created_at=row.created_at,
+                 )
+                 session.add(archive_entry)
+
+             # Delete original history entries
+             await session.execute(
+                 delete(WorkflowHistory).where(WorkflowHistory.instance_id == instance_id)
+             )
+
+             await self._commit_if_not_in_transaction(session)
+             return len(history_rows)
+
+     async def find_first_cancellation_event(self, instance_id: str) -> dict[str, Any] | None:
+         """
+         Find the first cancellation event in workflow history.
+
+         Uses LIMIT 1 optimization to avoid loading all history events.
+         """
+         session = self._get_session_for_operation()
+         async with self._session_scope(session) as session:
+             # Query for cancellation events using LIMIT 1
+             result = await session.execute(
+                 select(WorkflowHistory)
+                 .where(
+                     and_(
+                         WorkflowHistory.instance_id == instance_id,
+                         or_(
+                             WorkflowHistory.event_type == "WorkflowCancelled",
+                             func.lower(WorkflowHistory.event_type).contains("cancel"),
+                         ),
+                     )
+                 )
+                 .order_by(WorkflowHistory.created_at.asc())
+                 .limit(1)
+             )
+             row = result.scalars().first()
+
+             if row is None:
+                 return None
+
+             # Parse event_data based on data_type
+             if row.data_type == "binary" and row.event_data_binary is not None:
+                 event_data: dict[str, Any] | bytes = row.event_data_binary
+             else:
+                 event_data = json.loads(row.event_data) if row.event_data else {}
+
+             return {
+                 "id": row.id,
+                 "instance_id": row.instance_id,
+                 "activity_id": row.activity_id,
+                 "event_type": row.event_type,
+                 "event_data": event_data,
+                 "created_at": row.created_at,
+             }
+
      # -------------------------------------------------------------------------
      # Compensation Methods (prefer external session)
      # -------------------------------------------------------------------------
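Since archive_history() wraps binary payloads in a JSON envelope, a reader of workflow_history_archive has to unwrap them again. A decoding sketch (not part of the diff):

    import base64
    import json
    from typing import Any

    def decode_archived_event_data(event_data: str) -> dict[str, Any] | bytes:
        payload = json.loads(event_data)
        if isinstance(payload, dict) and payload.get("_binary"):
            # Envelope written by archive_history(): {"_binary": true, "data": <base64>}
            return base64.b64decode(payload["data"])
        return payload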
@@ -1271,7 +1889,7 @@ class SQLAlchemyStorage:
                      "instance_id": row.instance_id,
                      "activity_id": row.activity_id,
                      "activity_name": row.activity_name,
-                     "args": json.loads(row.args),  # type: ignore[arg-type]
+                     "args": json.loads(row.args) if row.args else [],
                      "created_at": row.created_at.isoformat(),
                  }
                  for row in rows
@@ -1287,224 +1905,25 @@ class SQLAlchemyStorage:
              await self._commit_if_not_in_transaction(session)
 
      # -------------------------------------------------------------------------
-     # Event Subscription Methods (prefer external session for registration)
+     # Timer Subscription Methods
      # -------------------------------------------------------------------------
 
-     async def add_event_subscription(
+     async def register_timer_subscription_and_release_lock(
          self,
          instance_id: str,
-         event_type: str,
-         timeout_at: datetime | None = None,
+         worker_id: str,
+         timer_id: str,
+         expires_at: datetime,
+         activity_id: str | None = None,
      ) -> None:
-         """Register an event wait subscription."""
-         session = self._get_session_for_operation()
-         async with self._session_scope(session) as session:
-             subscription = WorkflowEventSubscription(
-                 instance_id=instance_id,
-                 event_type=event_type,
-                 timeout_at=timeout_at,
-             )
-             session.add(subscription)
-             await self._commit_if_not_in_transaction(session)
+         """
+         Atomically register timer subscription and release workflow lock.
 
-     async def find_waiting_instances(self, event_type: str) -> list[dict[str, Any]]:
-         """Find workflow instances waiting for a specific event type."""
-         session = self._get_session_for_operation()
-         async with self._session_scope(session) as session:
-             result = await session.execute(
-                 select(WorkflowEventSubscription).where(
-                     and_(
-                         WorkflowEventSubscription.event_type == event_type,
-                         or_(
-                             WorkflowEventSubscription.timeout_at.is_(None),
-                             self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
-                             > self._get_current_time_expr(),
-                         ),
-                     )
-                 )
-             )
-             rows = result.scalars().all()
-
-             return [
-                 {
-                     "id": row.id,
-                     "instance_id": row.instance_id,
-                     "event_type": row.event_type,
-                     "activity_id": row.activity_id,
-                     "timeout_at": row.timeout_at.isoformat() if row.timeout_at else None,
-                     "created_at": row.created_at.isoformat(),
-                 }
-                 for row in rows
-             ]
-
-     async def remove_event_subscription(
-         self,
-         instance_id: str,
-         event_type: str,
-     ) -> None:
-         """Remove event subscription after the event is received."""
-         session = self._get_session_for_operation()
-         async with self._session_scope(session) as session:
-             await session.execute(
-                 delete(WorkflowEventSubscription).where(
-                     and_(
-                         WorkflowEventSubscription.instance_id == instance_id,
-                         WorkflowEventSubscription.event_type == event_type,
-                     )
-                 )
-             )
-             await self._commit_if_not_in_transaction(session)
-
-     async def cleanup_expired_subscriptions(self) -> int:
-         """Clean up event subscriptions that have timed out."""
-         session = self._get_session_for_operation()
-         async with self._session_scope(session) as session:
-             result = await session.execute(
-                 delete(WorkflowEventSubscription).where(
-                     and_(
-                         WorkflowEventSubscription.timeout_at.isnot(None),
-                         self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
-                         <= self._get_current_time_expr(),
-                     )
-                 )
-             )
-             await self._commit_if_not_in_transaction(session)
-             return result.rowcount or 0  # type: ignore[attr-defined]
-
-     async def find_expired_event_subscriptions(self) -> list[dict[str, Any]]:
-         """Find event subscriptions that have timed out."""
-         session = self._get_session_for_operation()
-         async with self._session_scope(session) as session:
-             result = await session.execute(
-                 select(
-                     WorkflowEventSubscription.instance_id,
-                     WorkflowEventSubscription.event_type,
-                     WorkflowEventSubscription.activity_id,
-                     WorkflowEventSubscription.timeout_at,
-                     WorkflowEventSubscription.created_at,
-                 ).where(
-                     and_(
-                         WorkflowEventSubscription.timeout_at.isnot(None),
-                         self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
-                         <= self._get_current_time_expr(),
-                     )
-                 )
-             )
-             rows = result.all()
-
-             return [
-                 {
-                     "instance_id": row[0],
-                     "event_type": row[1],
-                     "activity_id": row[2],
-                     "timeout_at": row[3].isoformat() if row[3] else None,
-                     "created_at": row[4].isoformat() if row[4] else None,
-                 }
-                 for row in rows
-             ]
-
-     async def register_event_subscription_and_release_lock(
-         self,
-         instance_id: str,
-         worker_id: str,
-         event_type: str,
-         timeout_at: datetime | None = None,
-         activity_id: str | None = None,
-     ) -> None:
-         """
-         Atomically register event subscription and release workflow lock.
-
-         This performs THREE operations in a SINGLE transaction:
-         1. Register event subscription
-         2. Update current activity
-         3. Release lock
-
-         This ensures distributed coroutines work correctly - when a workflow
-         calls wait_event(), the subscription is registered and lock is released
-         atomically, so ANY worker can resume the workflow when the event arrives.
-
-         Note: Uses LOCK operation session (separate from external session).
-         """
-         session = self._get_session_for_operation(is_lock_operation=True)
-         async with self._session_scope(session) as session, session.begin():
-             # 1. Verify we hold the lock (sanity check)
-             result = await session.execute(
-                 select(WorkflowInstance.locked_by).where(
-                     WorkflowInstance.instance_id == instance_id
-                 )
-             )
-             row = result.one_or_none()
-
-             if row is None:
-                 raise RuntimeError(f"Workflow instance {instance_id} not found")
-
-             current_lock_holder = row[0]
-             if current_lock_holder != worker_id:
-                 raise RuntimeError(
-                     f"Cannot release lock: worker {worker_id} does not hold lock "
-                     f"for {instance_id} (held by: {current_lock_holder})"
-                 )
-
-             # 2. Register event subscription (INSERT OR REPLACE equivalent)
-             # First delete existing
-             await session.execute(
-                 delete(WorkflowEventSubscription).where(
-                     and_(
-                         WorkflowEventSubscription.instance_id == instance_id,
-                         WorkflowEventSubscription.event_type == event_type,
-                     )
-                 )
-             )
-
-             # Then insert new
-             subscription = WorkflowEventSubscription(
-                 instance_id=instance_id,
-                 event_type=event_type,
-                 activity_id=activity_id,
-                 timeout_at=timeout_at,
-             )
-             session.add(subscription)
-
-             # 3. Update current activity (if provided)
-             if activity_id is not None:
-                 await session.execute(
-                     update(WorkflowInstance)
-                     .where(WorkflowInstance.instance_id == instance_id)
-                     .values(current_activity_id=activity_id, updated_at=func.now())
-                 )
-
-             # 4. Release lock
-             await session.execute(
-                 update(WorkflowInstance)
-                 .where(
-                     and_(
-                         WorkflowInstance.instance_id == instance_id,
-                         WorkflowInstance.locked_by == worker_id,
-                     )
-                 )
-                 .values(
-                     locked_by=None,
-                     locked_at=None,
-                     updated_at=func.now(),
-                 )
-             )
-
-     async def register_timer_subscription_and_release_lock(
-         self,
-         instance_id: str,
-         worker_id: str,
-         timer_id: str,
-         expires_at: datetime,
-         activity_id: str | None = None,
-     ) -> None:
-         """
-         Atomically register timer subscription and release workflow lock.
-
-         This performs FOUR operations in a SINGLE transaction:
-         1. Register timer subscription
-         2. Update current activity
-         3. Update status to 'waiting_for_timer'
-         4. Release lock
+         This performs FOUR operations in a SINGLE transaction:
+         1. Register timer subscription
+         2. Update current activity
+         3. Update status to 'waiting_for_timer'
+         4. Release lock
 
          This ensures distributed coroutines work correctly - when a workflow
          calls wait_timer(), the subscription is registered and lock is released
@@ -1580,7 +1999,11 @@ class SQLAlchemyStorage:
              )
 
      async def find_expired_timers(self) -> list[dict[str, Any]]:
-         """Find timer subscriptions that have expired."""
+         """Find timer subscriptions that have expired.
+
+         Returns timer info including workflow status to avoid N+1 queries.
+         The SQL query already filters by status='waiting_for_timer'.
+         """
          session = self._get_session_for_operation()
          async with self._session_scope(session) as session:
              result = await session.execute(
@@ -1590,6 +2013,7 @@ class SQLAlchemyStorage:
                      WorkflowTimerSubscription.expires_at,
                      WorkflowTimerSubscription.activity_id,
                      WorkflowInstance.workflow_name,
+                     WorkflowInstance.status,  # Include status to avoid N+1 query
                  )
                  .join(
                      WorkflowInstance,
@@ -1612,6 +2036,7 @@ class SQLAlchemyStorage:
                      "expires_at": row[2].isoformat(),
                      "activity_id": row[3],
                      "workflow_name": row[4],
+                     "status": row[5],  # Always 'waiting_for_timer' due to WHERE clause
                  }
                  for row in rows
              ]
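With the extra column in place, a consumer of find_expired_timers() no longer needs a per-instance status lookup (the N+1 this change removes); the join's WHERE clause pins the value. A consumption sketch (not part of the diff):

    for timer in await storage.find_expired_timers():
        assert timer["status"] == "waiting_for_timer"  # guaranteed by the query
        ...  # resume timer["instance_id"] at timer["activity_id"]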
@@ -1831,12 +2256,17 @@ class SQLAlchemyStorage:
          Only running or waiting_for_event workflows can be cancelled.
          This method atomically:
          1. Checks current status
-         2. Updates status to 'cancelled' if allowed
+         2. Updates status to 'cancelled' if allowed (with atomic status check)
          3. Clears locks
          4. Records cancellation metadata
          5. Removes event subscriptions (if waiting for event)
          6. Removes timer subscriptions (if waiting for timer)
 
+         The UPDATE includes a status condition in WHERE clause to prevent
+         TOCTOU (time-of-check to time-of-use) race conditions. If the status
+         changes between SELECT and UPDATE, the UPDATE will affect 0 rows
+         and the cancellation will fail safely.
+
          Args:
              instance_id: Workflow instance to cancel
              cancelled_by: Who/what triggered the cancellation
@@ -1846,6 +2276,14 @@ class SQLAlchemyStorage:
 
          Note: Uses LOCK operation session (separate from external session).
          """
+         cancellable_statuses = (
+             "running",
+             "waiting_for_event",
+             "waiting_for_timer",
+             "waiting_for_message",
+             "compensating",
+         )
+
          session = self._get_session_for_operation(is_lock_operation=True)
          async with self._session_scope(session) as session, session.begin():
              # Get current instance status
@@ -1862,41 +2300,42 @@ class SQLAlchemyStorage:
 
              # Only allow cancellation of running, waiting, or compensating workflows
              # compensating workflows can be marked as cancelled after compensation completes
-             if current_status not in (
-                 "running",
-                 "waiting_for_event",
-                 "waiting_for_timer",
-                 "compensating",
-             ):
+             if current_status not in cancellable_statuses:
                  # Already completed, failed, or cancelled
                  return False
 
              # Update status to cancelled and record metadata
+             # IMPORTANT: Include status condition in WHERE clause to prevent TOCTOU race
+             # If another worker changed the status between SELECT and UPDATE,
+             # this UPDATE will affect 0 rows and we'll return False
              cancellation_metadata = {
                  "cancelled_by": cancelled_by,
                  "cancelled_at": datetime.now(UTC).isoformat(),
                  "previous_status": current_status,
              }
 
-             await session.execute(
+             update_result = await session.execute(
                  update(WorkflowInstance)
-                 .where(WorkflowInstance.instance_id == instance_id)
+                 .where(
+                     and_(
+                         WorkflowInstance.instance_id == instance_id,
+                         WorkflowInstance.status == current_status,  # Atomic check
+                     )
+                 )
                  .values(
                      status="cancelled",
                      output_data=json.dumps(cancellation_metadata),
                      locked_by=None,
                      locked_at=None,
+                     lock_expires_at=None,
                      updated_at=func.now(),
                  )
              )
 
-             # Remove event subscriptions if waiting for event
-             if current_status == "waiting_for_event":
-                 await session.execute(
-                     delete(WorkflowEventSubscription).where(
-                         WorkflowEventSubscription.instance_id == instance_id
-                     )
-                 )
+             if update_result.rowcount == 0:  # type: ignore[attr-defined]
+                 # Status changed between SELECT and UPDATE (race condition)
+                 # Another worker may have resumed/modified the workflow
+                 return False
 
              # Remove timer subscriptions if waiting for timer
              if current_status == "waiting_for_timer":
@@ -1906,4 +2345,1103 @@ class SQLAlchemyStorage:
1906
2345
  )
1907
2346
  )
1908
2347
 
2348
+ # Clear channel subscriptions if waiting for event/message
2349
+ if current_status in ("waiting_for_event", "waiting_for_message"):
2350
+ await session.execute(
2351
+ update(ChannelSubscription)
2352
+ .where(ChannelSubscription.instance_id == instance_id)
2353
+ .values(activity_id=None, timeout_at=None)
2354
+ )
2355
+
1909
2356
  return True
2357
+
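
For readers new to the compare-and-set idiom above, here is a minimal standalone sketch of the same TOCTOU guard. It assumes the WorkflowInstance model defined earlier in this module and an async SQLAlchemy session; it is an illustration, not part of the package:

    from sqlalchemy import and_, update

    async def cancel_if_unchanged(session, instance_id: str, observed_status: str) -> bool:
        # The WHERE clause re-checks the status we observed in the prior SELECT,
        # so a concurrent status change makes this UPDATE affect 0 rows.
        result = await session.execute(
            update(WorkflowInstance)
            .where(
                and_(
                    WorkflowInstance.instance_id == instance_id,
                    WorkflowInstance.status == observed_status,
                )
            )
            .values(status="cancelled")
        )
        return result.rowcount > 0  # 0 rows -> lost the race, fail safely
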
2358
+ # -------------------------------------------------------------------------
2359
+ # Message Subscription Methods
2360
+ # -------------------------------------------------------------------------
2361
+
2362
+ async def find_waiting_instances_by_channel(self, channel: str) -> list[dict[str, Any]]:
2363
+ """
2364
+ Find all workflow instances waiting on a specific channel.
2365
+
2366
+ Args:
2367
+ channel: Channel name to search for
2368
+
2369
+ Returns:
2370
+ List of subscription info dicts with instance_id, channel, activity_id, timeout_at, created_at
2371
+ """
2372
+ session = self._get_session_for_operation()
2373
+ async with self._session_scope(session) as session:
2374
+ # Query ChannelSubscription table for waiting instances
2375
+ result = await session.execute(
2376
+ select(ChannelSubscription).where(
2377
+ and_(
2378
+ ChannelSubscription.channel == channel,
2379
+ ChannelSubscription.activity_id.isnot(None), # Only waiting subscriptions
2380
+ )
2381
+ )
2382
+ )
2383
+ subscriptions = result.scalars().all()
2384
+ return [
2385
+ {
2386
+ "instance_id": sub.instance_id,
2387
+ "channel": sub.channel,
2388
+ "activity_id": sub.activity_id,
2389
+ "timeout_at": sub.timeout_at.isoformat() if sub.timeout_at else None,
2390
+ "created_at": sub.subscribed_at.isoformat() if sub.subscribed_at else None,
2391
+ }
2392
+ for sub in subscriptions
2393
+ ]
2394
+
2395
+ async def remove_message_subscription(
2396
+ self,
2397
+ instance_id: str,
2398
+ channel: str,
2399
+ ) -> None:
2400
+ """
2401
+ Clear the waiting state of a message subscription.
2402
+
2403
+ This method clears waiting state from the ChannelSubscription table.
2404
+
2405
+ Args:
2406
+ instance_id: Workflow instance ID
2407
+ channel: Channel name
2408
+ """
2409
+ session = self._get_session_for_operation()
2410
+ async with self._session_scope(session) as session:
2411
+ # Clear waiting state from ChannelSubscription table
2412
+ # Don't delete the subscription - just clear the waiting state
2413
+ await session.execute(
2414
+ update(ChannelSubscription)
2415
+ .where(
2416
+ and_(
2417
+ ChannelSubscription.instance_id == instance_id,
2418
+ ChannelSubscription.channel == channel,
2419
+ )
2420
+ )
2421
+ .values(activity_id=None, timeout_at=None)
2422
+ )
2423
+ await self._commit_if_not_in_transaction(session)
2424
+
2425
+ async def deliver_message(
2426
+ self,
2427
+ instance_id: str,
2428
+ channel: str,
2429
+ data: dict[str, Any] | bytes,
2430
+ metadata: dict[str, Any],
2431
+ worker_id: str | None = None,
2432
+ ) -> dict[str, Any] | None:
2433
+ """
2434
+ Deliver a message to a waiting workflow using Lock-First pattern.
2435
+
2436
+ This method:
2437
+ 1. Checks if there's a subscription for this instance/channel
2438
+ 2. Acquires the lock (Lock-First pattern) if worker_id is provided
2439
+ 3. Records the message in history
2440
+ 4. Clears the subscription's waiting state (the row itself is kept)
2441
+ 5. Updates status to 'running'
2442
+ 6. Releases lock
2443
+
2444
+ Args:
2445
+ instance_id: Target workflow instance ID
2446
+ channel: Channel name
2447
+ data: Message payload (dict or bytes)
2448
+ metadata: Message metadata
2449
+ worker_id: Worker ID for locking. If None, locking is skipped (legacy mode).
2450
+
2451
+ Returns:
2452
+ Dict with delivery info if successful, None otherwise
2453
+ """
2454
+ import uuid
2455
+
2456
+ # Step 1: Check if subscription exists (without lock)
2457
+ session = self._get_session_for_operation()
2458
+ async with self._session_scope(session) as session:
2459
+ result = await session.execute(
2460
+ select(ChannelSubscription).where(
2461
+ and_(
2462
+ ChannelSubscription.instance_id == instance_id,
2463
+ ChannelSubscription.channel == channel,
2464
+ ChannelSubscription.activity_id.isnot(None), # Only waiting subscriptions
2465
+ )
2466
+ )
2467
+ )
2468
+ subscription = result.scalar_one_or_none()
2469
+
2470
+ if subscription is None:
2471
+ return None
2472
+
2473
+ activity_id = subscription.activity_id
2474
+
2475
+ # Step 2: Acquire lock (Lock-First pattern) if worker_id provided
2476
+ lock_acquired = False
2477
+ if worker_id is not None:
2478
+ lock_acquired = await self.try_acquire_lock(instance_id, worker_id)
2479
+ if not lock_acquired:
2480
+ # Another worker is processing this workflow
2481
+ return None
2482
+
2483
+ try:
2484
+ # Step 3-5: Deliver message atomically
2485
+ session = self._get_session_for_operation()
2486
+ async with self._session_scope(session) as session:
2487
+ # Re-check subscription (may have been removed by another worker)
2488
+ result = await session.execute(
2489
+ select(ChannelSubscription).where(
2490
+ and_(
2491
+ ChannelSubscription.instance_id == instance_id,
2492
+ ChannelSubscription.channel == channel,
2493
+ ChannelSubscription.activity_id.isnot(
2494
+ None
2495
+ ), # Only waiting subscriptions
2496
+ )
2497
+ )
2498
+ )
2499
+ subscription = result.scalar_one_or_none()
2500
+
2501
+ if subscription is None:
2502
+ # Already delivered by another worker
2503
+ return None
2504
+
2505
+ activity_id = subscription.activity_id
2506
+
2507
+ # Get workflow info for return value
2508
+ instance_result = await session.execute(
2509
+ select(WorkflowInstance).where(WorkflowInstance.instance_id == instance_id)
2510
+ )
2511
+ instance = instance_result.scalar_one_or_none()
2512
+ workflow_name = instance.workflow_name if instance else "unknown"
2513
+
2514
+ # Build message data for history
2515
+ message_id = str(uuid.uuid4())
2516
+ message_data = {
2517
+ "id": message_id,
2518
+ "channel": channel,
2519
+ "data": data if isinstance(data, dict) else None,
2520
+ "metadata": metadata,
2521
+ }
2522
+
2523
+ # Handle binary data
2524
+ if isinstance(data, bytes):
2525
+ data_type = "binary"
2526
+ event_data_json = None
2527
+ event_data_binary = data
2528
+ else:
2529
+ data_type = "json"
2530
+ event_data_json = json.dumps(message_data)
2531
+ event_data_binary = None
2532
+
2533
+ # Record in history
2534
+ history_entry = WorkflowHistory(
2535
+ instance_id=instance_id,
2536
+ activity_id=activity_id,
2537
+ event_type="ChannelMessageReceived",
2538
+ data_type=data_type,
2539
+ event_data=event_data_json,
2540
+ event_data_binary=event_data_binary,
2541
+ )
2542
+ session.add(history_entry)
2543
+
2544
+ # Clear waiting state from subscription (don't delete)
2545
+ await session.execute(
2546
+ update(ChannelSubscription)
2547
+ .where(
2548
+ and_(
2549
+ ChannelSubscription.instance_id == instance_id,
2550
+ ChannelSubscription.channel == channel,
2551
+ )
2552
+ )
2553
+ .values(activity_id=None, timeout_at=None)
2554
+ )
2555
+
2556
+ # Update status to 'running' (ready for resumption)
2557
+ await session.execute(
2558
+ update(WorkflowInstance)
2559
+ .where(WorkflowInstance.instance_id == instance_id)
2560
+ .values(status="running", updated_at=func.now())
2561
+ )
2562
+
2563
+ await self._commit_if_not_in_transaction(session)
2564
+
2565
+ return {
2566
+ "instance_id": instance_id,
2567
+ "workflow_name": workflow_name,
2568
+ "activity_id": activity_id,
2569
+ }
2570
+
2571
+ finally:
2572
+ # Step 6: Release lock if we acquired it
2573
+ if lock_acquired and worker_id is not None:
2574
+ await self.release_lock(instance_id, worker_id)
2575
+
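
As a usage sketch (not part of the package), a dispatcher could pair find_waiting_instances_by_channel() with deliver_message(); here `storage` is an already-constructed SQLAlchemyStorage and the channel name is made up:

    async def dispatch(storage, worker_id: str) -> None:
        # Deliver one payload to every instance currently waiting on "orders".
        for sub in await storage.find_waiting_instances_by_channel("orders"):
            delivered = await storage.deliver_message(
                instance_id=sub["instance_id"],
                channel="orders",
                data={"order_id": 42},
                metadata={"source": "example"},
                worker_id=worker_id,  # Lock-First: skips instances another worker holds
            )
            if delivered is not None:
                print(f"resumed {delivered['workflow_name']} ({delivered['instance_id']})")
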
2576
+ async def find_expired_message_subscriptions(self) -> list[dict[str, Any]]:
2577
+ """
2578
+ Find all message subscriptions that have timed out.
2579
+
2580
+ JOINs with WorkflowInstance to ensure the instance exists and to avoid N+1 queries.
2581
+
2582
+ Returns:
2583
+ List of dicts with instance_id, channel, activity_id, timeout_at, created_at, workflow_name
2584
+ """
2585
+ session = self._get_session_for_operation()
2586
+ async with self._session_scope(session) as session:
2587
+ # Query ChannelSubscription table with JOIN
2588
+ result = await session.execute(
2589
+ select(
2590
+ ChannelSubscription.instance_id,
2591
+ ChannelSubscription.channel,
2592
+ ChannelSubscription.activity_id,
2593
+ ChannelSubscription.timeout_at,
2594
+ ChannelSubscription.subscribed_at,
2595
+ WorkflowInstance.workflow_name,
2596
+ )
2597
+ .join(
2598
+ WorkflowInstance,
2599
+ ChannelSubscription.instance_id == WorkflowInstance.instance_id,
2600
+ )
2601
+ .where(
2602
+ and_(
2603
+ ChannelSubscription.timeout_at.isnot(None),
2604
+ ChannelSubscription.activity_id.isnot(None), # Only waiting subscriptions
2605
+ self._make_datetime_comparable(ChannelSubscription.timeout_at)
2606
+ <= self._get_current_time_expr(),
2607
+ )
2608
+ )
2609
+ )
2610
+ rows = result.all()
2611
+ return [
2612
+ {
2613
+ "instance_id": row[0],
2614
+ "channel": row[1],
2615
+ "activity_id": row[2],
2616
+ "timeout_at": row[3],
2617
+ "created_at": row[4], # subscribed_at as created_at for compatibility
2618
+ "workflow_name": row[5],
2619
+ }
2620
+ for row in rows
2621
+ ]
2622
+
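
A hedged sketch of a timeout sweep built on this query; whether to fail, retry, or merely log an expired wait is the caller's policy, and the clean-up call shown is one plausible choice:

    async def sweep_expired(storage) -> None:
        for sub in await storage.find_expired_message_subscriptions():
            print(f"{sub['workflow_name']}/{sub['instance_id']} timed out on {sub['channel']}")
            # Clear the waiting state so the subscription no longer reports as expired
            await storage.remove_message_subscription(sub["instance_id"], sub["channel"])
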
2623
+ # -------------------------------------------------------------------------
2624
+ # Group Membership Methods (Erlang pg style)
2625
+ # -------------------------------------------------------------------------
2626
+
2627
+ async def join_group(self, instance_id: str, group_name: str) -> None:
2628
+ """
2629
+ Add a workflow instance to a group.
2630
+
2631
+ Groups provide loose coupling for broadcast messaging.
2632
+ Idempotent - joining a group the instance is already in is a no-op.
2633
+
2634
+ Args:
2635
+ instance_id: Workflow instance ID
2636
+ group_name: Group to join
2637
+ """
2638
+ session = self._get_session_for_operation()
2639
+ async with self._session_scope(session) as session:
2640
+ # Check if already a member (for idempotency)
2641
+ result = await session.execute(
2642
+ select(WorkflowGroupMembership).where(
2643
+ and_(
2644
+ WorkflowGroupMembership.instance_id == instance_id,
2645
+ WorkflowGroupMembership.group_name == group_name,
2646
+ )
2647
+ )
2648
+ )
2649
+ if result.scalar_one_or_none() is not None:
2650
+ # Already a member, nothing to do
2651
+ return
2652
+
2653
+ # Add membership
2654
+ membership = WorkflowGroupMembership(
2655
+ instance_id=instance_id,
2656
+ group_name=group_name,
2657
+ )
2658
+ session.add(membership)
2659
+ await self._commit_if_not_in_transaction(session)
2660
+
2661
+ async def leave_group(self, instance_id: str, group_name: str) -> None:
2662
+ """
2663
+ Remove a workflow instance from a group.
2664
+
2665
+ Args:
2666
+ instance_id: Workflow instance ID
2667
+ group_name: Group to leave
2668
+ """
2669
+ session = self._get_session_for_operation()
2670
+ async with self._session_scope(session) as session:
2671
+ await session.execute(
2672
+ delete(WorkflowGroupMembership).where(
2673
+ and_(
2674
+ WorkflowGroupMembership.instance_id == instance_id,
2675
+ WorkflowGroupMembership.group_name == group_name,
2676
+ )
2677
+ )
2678
+ )
2679
+ await self._commit_if_not_in_transaction(session)
2680
+
2681
+ async def get_group_members(self, group_name: str) -> list[str]:
2682
+ """
2683
+ Get all workflow instances in a group.
2684
+
2685
+ Args:
2686
+ group_name: Group name
2687
+
2688
+ Returns:
2689
+ List of instance IDs in the group
2690
+ """
2691
+ session = self._get_session_for_operation()
2692
+ async with self._session_scope(session) as session:
2693
+ result = await session.execute(
2694
+ select(WorkflowGroupMembership.instance_id).where(
2695
+ WorkflowGroupMembership.group_name == group_name
2696
+ )
2697
+ )
2698
+ return [row[0] for row in result.fetchall()]
2699
+
2700
+ async def leave_all_groups(self, instance_id: str) -> None:
2701
+ """
2702
+ Remove a workflow instance from all groups.
2703
+
2704
+ Called automatically when a workflow completes or fails.
2705
+
2706
+ Args:
2707
+ instance_id: Workflow instance ID
2708
+ """
2709
+ session = self._get_session_for_operation()
2710
+ async with self._session_scope(session) as session:
2711
+ await session.execute(
2712
+ delete(WorkflowGroupMembership).where(
2713
+ WorkflowGroupMembership.instance_id == instance_id
2714
+ )
2715
+ )
2716
+ await self._commit_if_not_in_transaction(session)
2717
+
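
A minimal sketch of group usage, assuming an existing SQLAlchemyStorage named `storage`; the instance, group, and channel names are hypothetical:

    async def notify_group(storage, worker_id: str) -> None:
        await storage.join_group("wf-1", "order-processors")  # idempotent
        # Broadcast by fanning a message out to every current member
        for instance_id in await storage.get_group_members("order-processors"):
            await storage.deliver_message(
                instance_id=instance_id,
                channel="orders",
                data={"kind": "ping"},
                metadata={},
                worker_id=worker_id,
            )
        await storage.leave_all_groups("wf-1")  # e.g. when the workflow completes
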
2718
+ # -------------------------------------------------------------------------
2719
+ # Workflow Resumption Methods
2720
+ # -------------------------------------------------------------------------
2721
+
2722
+ async def find_resumable_workflows(self) -> list[dict[str, Any]]:
2723
+ """
2724
+ Find workflows that are ready to be resumed.
2725
+
2726
+ Returns workflows with status='running' that don't have an active lock.
2727
+ Used for immediate resumption after message delivery.
2728
+
2729
+ Returns:
2730
+ List of resumable workflows with instance_id and workflow_name.
2731
+ """
2732
+ session = self._get_session_for_operation()
2733
+ async with self._session_scope(session) as session:
2734
+ result = await session.execute(
2735
+ select(
2736
+ WorkflowInstance.instance_id,
2737
+ WorkflowInstance.workflow_name,
2738
+ ).where(
2739
+ and_(
2740
+ WorkflowInstance.status == "running",
2741
+ WorkflowInstance.locked_by.is_(None),
2742
+ )
2743
+ )
2744
+ )
2745
+ return [
2746
+ {
2747
+ "instance_id": row.instance_id,
2748
+ "workflow_name": row.workflow_name,
2749
+ }
2750
+ for row in result.fetchall()
2751
+ ]
2752
+
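
A sketch of how a worker might poll this method; the lock calls are the try_acquire_lock()/release_lock() pair used elsewhere in this class, and the actual replay step is elided:

    async def resume_ready(storage, worker_id: str) -> None:
        for wf in await storage.find_resumable_workflows():
            if await storage.try_acquire_lock(wf["instance_id"], worker_id):
                print(f"would replay {wf['workflow_name']} ({wf['instance_id']})")
                await storage.release_lock(wf["instance_id"], worker_id)
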
2753
+ # -------------------------------------------------------------------------
2754
+ # Subscription Cleanup Methods (for recur())
2755
+ # -------------------------------------------------------------------------
2756
+
2757
+ async def cleanup_instance_subscriptions(self, instance_id: str) -> None:
2758
+ """
2759
+ Remove all subscriptions for a workflow instance.
2760
+
2761
+ Called during recur() to clean up event/timer/message subscriptions
2762
+ before archiving the history.
2763
+
2764
+ Args:
2765
+ instance_id: Workflow instance ID to clean up
2766
+ """
2767
+ session = self._get_session_for_operation()
2768
+ async with self._session_scope(session) as session:
2769
+ # Remove timer subscriptions
2770
+ await session.execute(
2771
+ delete(WorkflowTimerSubscription).where(
2772
+ WorkflowTimerSubscription.instance_id == instance_id
2773
+ )
2774
+ )
2775
+
2776
+ # Remove channel subscriptions
2777
+ await session.execute(
2778
+ delete(ChannelSubscription).where(ChannelSubscription.instance_id == instance_id)
2779
+ )
2780
+
2781
+ # Remove channel message claims
2782
+ await session.execute(
2783
+ delete(ChannelMessageClaim).where(ChannelMessageClaim.instance_id == instance_id)
2784
+ )
2785
+
2786
+ await self._commit_if_not_in_transaction(session)
2787
+
2788
+ # -------------------------------------------------------------------------
2789
+ # Channel-based Message Queue Methods
2790
+ # -------------------------------------------------------------------------
2791
+
2792
+ async def publish_to_channel(
2793
+ self,
2794
+ channel: str,
2795
+ data: dict[str, Any] | bytes,
2796
+ metadata: dict[str, Any] | None = None,
2797
+ ) -> str:
2798
+ """
2799
+ Publish a message to a channel.
2800
+
2801
+ Messages are persisted to channel_messages and available for subscribers.
2802
+
2803
+ Args:
2804
+ channel: Channel name
2805
+ data: Message payload (dict or bytes)
2806
+ metadata: Optional message metadata
2807
+
2808
+ Returns:
2809
+ Generated message_id (UUID)
2810
+ """
2811
+ import uuid
2812
+
2813
+ message_id = str(uuid.uuid4())
2814
+
2815
+ # Determine data type and serialize
2816
+ if isinstance(data, bytes):
2817
+ data_type = "binary"
2818
+ data_json = None
2819
+ data_binary = data
2820
+ else:
2821
+ data_type = "json"
2822
+ data_json = json.dumps(data)
2823
+ data_binary = None
2824
+
2825
+ metadata_json = json.dumps(metadata) if metadata else None
2826
+
2827
+ session = self._get_session_for_operation()
2828
+ async with self._session_scope(session) as session:
2829
+ msg = ChannelMessage(
2830
+ channel=channel,
2831
+ message_id=message_id,
2832
+ data_type=data_type,
2833
+ data=data_json,
2834
+ data_binary=data_binary,
2835
+ message_metadata=metadata_json,
2836
+ )
2837
+ session.add(msg)
2838
+ await self._commit_if_not_in_transaction(session)
2839
+
2840
+ return message_id
2841
+
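
A usage sketch, assuming a constructed `storage`; dict payloads are JSON-serialized while bytes are stored verbatim, and both calls return a generated UUID string:

    async def publish_examples(storage) -> None:
        json_id = await storage.publish_to_channel(
            "orders", {"order_id": 42}, metadata={"source": "api"}
        )
        blob_id = await storage.publish_to_channel("thumbnails", b"\x89PNG...")
        print(json_id, blob_id)
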
2842
+ async def subscribe_to_channel(
2843
+ self,
2844
+ instance_id: str,
2845
+ channel: str,
2846
+ mode: str,
2847
+ ) -> None:
2848
+ """
2849
+ Subscribe a workflow instance to a channel.
2850
+
2851
+ Args:
2852
+ instance_id: Workflow instance ID
2853
+ channel: Channel name
2854
+ mode: 'broadcast' or 'competing'
2855
+
2856
+ Raises:
2857
+ ValueError: If mode is invalid
2858
+ """
2859
+ if mode not in ("broadcast", "competing"):
2860
+ raise ValueError(f"Invalid mode: {mode}. Must be 'broadcast' or 'competing'")
2861
+
2862
+ session = self._get_session_for_operation()
2863
+ async with self._session_scope(session) as session:
2864
+ # Check if already subscribed
2865
+ result = await session.execute(
2866
+ select(ChannelSubscription).where(
2867
+ and_(
2868
+ ChannelSubscription.instance_id == instance_id,
2869
+ ChannelSubscription.channel == channel,
2870
+ )
2871
+ )
2872
+ )
2873
+ existing = result.scalar_one_or_none()
2874
+
2875
+ if existing is not None:
2876
+ # Already subscribed, update mode if different
2877
+ if existing.mode != mode:
2878
+ existing.mode = mode
2879
+ await self._commit_if_not_in_transaction(session)
2880
+ return
2881
+
2882
+ # For broadcast mode, set cursor to current max message id
2883
+ # so the subscriber only sees messages published after subscribing
2884
+ cursor_message_id = None
2885
+ if mode == "broadcast":
2886
+ result = await session.execute(
2887
+ select(func.max(ChannelMessage.id)).where(ChannelMessage.channel == channel)
2888
+ )
2889
+ max_id = result.scalar()
2890
+ cursor_message_id = max_id if max_id is not None else 0
2891
+
2892
+ # Create subscription
2893
+ subscription = ChannelSubscription(
2894
+ instance_id=instance_id,
2895
+ channel=channel,
2896
+ mode=mode,
2897
+ cursor_message_id=cursor_message_id,
2898
+ )
2899
+ session.add(subscription)
2900
+ await self._commit_if_not_in_transaction(session)
2901
+
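
A short sketch of the two modes, with hypothetical instance and channel names. In broadcast mode the cursor starts at the channel's current max id, so only later messages are seen; in competing mode each message goes to exactly one claimant:

    async def subscribe_examples(storage) -> None:
        await storage.subscribe_to_channel("wf-1", "announcements", mode="broadcast")
        await storage.subscribe_to_channel("wf-2", "jobs", mode="competing")
        info = await storage.get_channel_subscription("wf-1", "announcements")
        print(info)  # {'mode': 'broadcast', 'activity_id': None, 'cursor_message_id': ...}
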
2902
+ async def unsubscribe_from_channel(
2903
+ self,
2904
+ instance_id: str,
2905
+ channel: str,
2906
+ ) -> None:
2907
+ """
2908
+ Unsubscribe a workflow instance from a channel.
2909
+
2910
+ Args:
2911
+ instance_id: Workflow instance ID
2912
+ channel: Channel name
2913
+ """
2914
+ session = self._get_session_for_operation()
2915
+ async with self._session_scope(session) as session:
2916
+ await session.execute(
2917
+ delete(ChannelSubscription).where(
2918
+ and_(
2919
+ ChannelSubscription.instance_id == instance_id,
2920
+ ChannelSubscription.channel == channel,
2921
+ )
2922
+ )
2923
+ )
2924
+ await self._commit_if_not_in_transaction(session)
2925
+
2926
+ async def get_channel_subscription(
2927
+ self,
2928
+ instance_id: str,
2929
+ channel: str,
2930
+ ) -> dict[str, Any] | None:
2931
+ """
2932
+ Get the subscription info for a workflow instance on a channel.
2933
+
2934
+ Args:
2935
+ instance_id: Workflow instance ID
2936
+ channel: Channel name
2937
+
2938
+ Returns:
2939
+ Subscription info dict with: mode, activity_id, cursor_message_id
2940
+ or None if not subscribed
2941
+ """
2942
+ session = self._get_session_for_operation()
2943
+ async with self._session_scope(session) as session:
2944
+ result = await session.execute(
2945
+ select(ChannelSubscription).where(
2946
+ and_(
2947
+ ChannelSubscription.instance_id == instance_id,
2948
+ ChannelSubscription.channel == channel,
2949
+ )
2950
+ )
2951
+ )
2952
+ subscription = result.scalar_one_or_none()
2953
+
2954
+ if subscription is None:
2955
+ return None
2956
+
2957
+ return {
2958
+ "mode": subscription.mode,
2959
+ "activity_id": subscription.activity_id,
2960
+ "cursor_message_id": subscription.cursor_message_id,
2961
+ }
2962
+
2963
+ async def register_channel_receive_and_release_lock(
2964
+ self,
2965
+ instance_id: str,
2966
+ worker_id: str,
2967
+ channel: str,
2968
+ activity_id: str | None = None,
2969
+ timeout_seconds: int | None = None,
2970
+ ) -> None:
2971
+ """
2972
+ Atomically register that workflow is waiting for channel message and release lock.
2973
+
2974
+ Args:
2975
+ instance_id: Workflow instance ID
2976
+ worker_id: Worker ID that currently holds the lock
2977
+ channel: Channel name being waited on
2978
+ activity_id: Current activity ID to record
2979
+ timeout_seconds: Optional timeout in seconds for the message wait
2980
+
2981
+ Raises:
2982
+ RuntimeError: If the worker doesn't hold the lock
2983
+ ValueError: If workflow is not subscribed to the channel
2984
+ """
2985
+ async with self.engine.begin() as conn:
2986
+ session = AsyncSession(bind=conn, expire_on_commit=False)
2987
+
2988
+ # Verify lock ownership
2989
+ result = await session.execute(
2990
+ select(WorkflowInstance).where(WorkflowInstance.instance_id == instance_id)
2991
+ )
2992
+ instance = result.scalar_one_or_none()
2993
+
2994
+ if instance is None:
2995
+ raise RuntimeError(f"Instance not found: {instance_id}")
2996
+
2997
+ if instance.locked_by != worker_id:
2998
+ raise RuntimeError(
2999
+ f"Worker {worker_id} does not hold lock for {instance_id}. "
3000
+ f"Locked by: {instance.locked_by}"
3001
+ )
3002
+
3003
+ # Verify subscription exists
3004
+ sub_result = await session.execute(
3005
+ select(ChannelSubscription).where(
3006
+ and_(
3007
+ ChannelSubscription.instance_id == instance_id,
3008
+ ChannelSubscription.channel == channel,
3009
+ )
3010
+ )
3011
+ )
3012
+ subscription: ChannelSubscription | None = sub_result.scalar_one_or_none()
3013
+
3014
+ if subscription is None:
3015
+ raise ValueError(f"Instance {instance_id} is not subscribed to channel {channel}")
3016
+
3017
+ # Update subscription to mark as waiting
3018
+ current_time = datetime.now(UTC)
3019
+ subscription.activity_id = activity_id
3020
+ # Calculate timeout_at if timeout_seconds is provided
3021
+ if timeout_seconds is not None:
3022
+ subscription.timeout_at = current_time + timedelta(seconds=timeout_seconds)
3023
+ else:
3024
+ subscription.timeout_at = None
3025
+
3026
+ # Update instance: set activity, status, release lock
3027
+ await session.execute(
3028
+ update(WorkflowInstance)
3029
+ .where(WorkflowInstance.instance_id == instance_id)
3030
+ .values(
3031
+ current_activity_id=activity_id,
3032
+ status="waiting_for_message",
3033
+ locked_by=None,
3034
+ locked_at=None,
3035
+ lock_expires_at=None,
3036
+ updated_at=current_time,
3037
+ )
3038
+ )
3039
+
3040
+ await session.commit()
3041
+
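
A sketch of parking a locked workflow on a channel wait; the worker must already hold the lock, and the 60-second timeout feeds the expiry sweep shown earlier:

    async def park_on_channel(storage, worker_id: str) -> None:
        await storage.register_channel_receive_and_release_lock(
            instance_id="wf-1",
            worker_id=worker_id,   # must currently hold the lock
            channel="orders",
            activity_id="activity-3",
            timeout_seconds=60,    # timeout_at = now + 60s; None means wait indefinitely
        )
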
3042
+ async def get_pending_channel_messages(
3043
+ self,
3044
+ instance_id: str,
3045
+ channel: str,
3046
+ ) -> list[dict[str, Any]]:
3047
+ """
3048
+ Get pending messages for a subscriber on a channel.
3049
+
3050
+ For broadcast mode: messages with id > cursor_message_id
3051
+ For competing mode: unclaimed messages
3052
+
3053
+ Args:
3054
+ instance_id: Workflow instance ID
3055
+ channel: Channel name
3056
+
3057
+ Returns:
3058
+ List of pending messages
3059
+ """
3060
+ session = self._get_session_for_operation()
3061
+ async with self._session_scope(session) as session:
3062
+ # Get subscription info
3063
+ sub_result = await session.execute(
3064
+ select(ChannelSubscription).where(
3065
+ and_(
3066
+ ChannelSubscription.instance_id == instance_id,
3067
+ ChannelSubscription.channel == channel,
3068
+ )
3069
+ )
3070
+ )
3071
+ subscription = sub_result.scalar_one_or_none()
3072
+
3073
+ if subscription is None:
3074
+ return []
3075
+
3076
+ if subscription.mode == "broadcast":
3077
+ # Get messages after cursor
3078
+ cursor = subscription.cursor_message_id or 0
3079
+ msg_result = await session.execute(
3080
+ select(ChannelMessage)
3081
+ .where(
3082
+ and_(
3083
+ ChannelMessage.channel == channel,
3084
+ ChannelMessage.id > cursor,
3085
+ )
3086
+ )
3087
+ .order_by(ChannelMessage.published_at.asc())
3088
+ )
3089
+ else: # competing
3090
+ # Get unclaimed messages (not in channel_message_claims)
3091
+ subquery = select(ChannelMessageClaim.message_id)
3092
+ msg_result = await session.execute(
3093
+ select(ChannelMessage)
3094
+ .where(
3095
+ and_(
3096
+ ChannelMessage.channel == channel,
3097
+ ChannelMessage.message_id.not_in(subquery),
3098
+ )
3099
+ )
3100
+ .order_by(ChannelMessage.published_at.asc())
3101
+ )
3102
+
3103
+ messages = msg_result.scalars().all()
3104
+ return [
3105
+ {
3106
+ "id": msg.id,
3107
+ "message_id": msg.message_id,
3108
+ "channel": msg.channel,
3109
+ "data": (
3110
+ msg.data_binary
3111
+ if msg.data_type == "binary"
3112
+ else json.loads(msg.data) if msg.data else {}
3113
+ ),
3114
+ "metadata": json.loads(msg.message_metadata) if msg.message_metadata else {},
3115
+ "published_at": msg.published_at.isoformat() if msg.published_at else None,
3116
+ }
3117
+ for msg in messages
3118
+ ]
3119
+
3120
+ async def claim_channel_message(
3121
+ self,
3122
+ message_id: str,
3123
+ instance_id: str,
3124
+ ) -> bool:
3125
+ """
3126
+ Claim a message for competing consumption.
3127
+
3128
+ Checks for an existing claim before inserting; a conflicting insert is caught and returns False.
3129
+
3130
+ Args:
3131
+ message_id: Message ID to claim
3132
+ instance_id: Workflow instance claiming the message
3133
+
3134
+ Returns:
3135
+ True if claim succeeded, False if already claimed
3136
+ """
3137
+ session = self._get_session_for_operation()
3138
+ async with self._session_scope(session) as session:
3139
+ try:
3140
+ # Check if already claimed
3141
+ result = await session.execute(
3142
+ select(ChannelMessageClaim).where(ChannelMessageClaim.message_id == message_id)
3143
+ )
3144
+ if result.scalar_one_or_none() is not None:
3145
+ return False # Already claimed
3146
+
3147
+ claim = ChannelMessageClaim(
3148
+ message_id=message_id,
3149
+ instance_id=instance_id,
3150
+ )
3151
+ session.add(claim)
3152
+ await self._commit_if_not_in_transaction(session)
3153
+ return True
3154
+ except Exception:
3155
+ return False
3156
+
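
A sketch of competing consumption, combining this method with get_pending_channel_messages() and delete_channel_message() (defined just below); a False return simply means another subscriber won the race:

    async def consume_one(storage, instance_id: str) -> None:
        for msg in await storage.get_pending_channel_messages(instance_id, "jobs"):
            if await storage.claim_channel_message(msg["message_id"], instance_id):
                print("processing", msg["data"])
                await storage.delete_channel_message(msg["message_id"])
                break
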
3157
+ async def delete_channel_message(self, message_id: str) -> None:
3158
+ """
3159
+ Delete a message from the channel queue.
3160
+
3161
+ Args:
3162
+ message_id: Message ID to delete
3163
+ """
3164
+ session = self._get_session_for_operation()
3165
+ async with self._session_scope(session) as session:
3166
+ # Delete claim first (foreign key)
3167
+ await session.execute(
3168
+ delete(ChannelMessageClaim).where(ChannelMessageClaim.message_id == message_id)
3169
+ )
3170
+ # Delete message
3171
+ await session.execute(
3172
+ delete(ChannelMessage).where(ChannelMessage.message_id == message_id)
3173
+ )
3174
+ await self._commit_if_not_in_transaction(session)
3175
+
3176
+ async def update_delivery_cursor(
3177
+ self,
3178
+ channel: str,
3179
+ instance_id: str,
3180
+ message_id: int,
3181
+ ) -> None:
3182
+ """
3183
+ Update the delivery cursor for broadcast mode.
3184
+
3185
+ Args:
3186
+ channel: Channel name
3187
+ instance_id: Subscriber instance ID
3188
+ message_id: Last delivered message's internal ID
3189
+ """
3190
+ session = self._get_session_for_operation()
3191
+ async with self._session_scope(session) as session:
3192
+ # Update subscription cursor
3193
+ await session.execute(
3194
+ update(ChannelSubscription)
3195
+ .where(
3196
+ and_(
3197
+ ChannelSubscription.instance_id == instance_id,
3198
+ ChannelSubscription.channel == channel,
3199
+ )
3200
+ )
3201
+ .values(cursor_message_id=message_id)
3202
+ )
3203
+ await self._commit_if_not_in_transaction(session)
3204
+
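
A sketch of broadcast consumption: read past the cursor, then advance it so nothing is redelivered. Note the cursor tracks the internal integer id, not the UUID message_id:

    async def drain_broadcast(storage, instance_id: str) -> None:
        for msg in await storage.get_pending_channel_messages(instance_id, "announcements"):
            print("saw", msg["message_id"], msg["data"])
            await storage.update_delivery_cursor("announcements", instance_id, msg["id"])
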
3205
+ async def get_channel_subscribers_waiting(
3206
+ self,
3207
+ channel: str,
3208
+ ) -> list[dict[str, Any]]:
3209
+ """
3210
+ Get channel subscribers that are waiting (activity_id is set).
3211
+
3212
+ Args:
3213
+ channel: Channel name
3214
+
3215
+ Returns:
3216
+ List of waiting subscribers
3217
+ """
3218
+ session = self._get_session_for_operation()
3219
+ async with self._session_scope(session) as session:
3220
+ result = await session.execute(
3221
+ select(ChannelSubscription).where(
3222
+ and_(
3223
+ ChannelSubscription.channel == channel,
3224
+ ChannelSubscription.activity_id.isnot(None),
3225
+ )
3226
+ )
3227
+ )
3228
+ subscriptions = result.scalars().all()
3229
+ return [
3230
+ {
3231
+ "instance_id": sub.instance_id,
3232
+ "channel": sub.channel,
3233
+ "mode": sub.mode,
3234
+ "activity_id": sub.activity_id,
3235
+ }
3236
+ for sub in subscriptions
3237
+ ]
3238
+
3239
+ async def clear_channel_waiting_state(
3240
+ self,
3241
+ instance_id: str,
3242
+ channel: str,
3243
+ ) -> None:
3244
+ """
3245
+ Clear the waiting state for a channel subscription.
3246
+
3247
+ Args:
3248
+ instance_id: Workflow instance ID
3249
+ channel: Channel name
3250
+ """
3251
+ session = self._get_session_for_operation()
3252
+ async with self._session_scope(session) as session:
3253
+ await session.execute(
3254
+ update(ChannelSubscription)
3255
+ .where(
3256
+ and_(
3257
+ ChannelSubscription.instance_id == instance_id,
3258
+ ChannelSubscription.channel == channel,
3259
+ )
3260
+ )
3261
+ .values(activity_id=None)
3262
+ )
3263
+ await self._commit_if_not_in_transaction(session)
3264
+
3265
+ async def deliver_channel_message(
3266
+ self,
3267
+ instance_id: str,
3268
+ channel: str,
3269
+ message_id: str,
3270
+ data: dict[str, Any] | bytes,
3271
+ metadata: dict[str, Any],
3272
+ worker_id: str,
3273
+ ) -> dict[str, Any] | None:
3274
+ """
3275
+ Deliver a channel message to a waiting workflow.
3276
+
3277
+ Uses Lock-First pattern for distributed safety.
3278
+
3279
+ Args:
3280
+ instance_id: Target workflow instance ID
3281
+ channel: Channel name
3282
+ message_id: Message ID being delivered
3283
+ data: Message payload
3284
+ metadata: Message metadata
3285
+ worker_id: Worker ID for locking
3286
+
3287
+ Returns:
3288
+ Delivery info if successful, None if failed
3289
+ """
3290
+ try:
3291
+ # Try to acquire lock
3292
+ if not await self.try_acquire_lock(instance_id, worker_id):
3293
+ logger.debug(f"Failed to acquire lock for {instance_id}")
3294
+ return None
3295
+
3296
+ try:
3297
+ async with self.engine.begin() as conn:
3298
+ session = AsyncSession(bind=conn, expire_on_commit=False)
3299
+
3300
+ # Get subscription info
3301
+ result = await session.execute(
3302
+ select(ChannelSubscription).where(
3303
+ and_(
3304
+ ChannelSubscription.instance_id == instance_id,
3305
+ ChannelSubscription.channel == channel,
3306
+ )
3307
+ )
3308
+ )
3309
+ subscription = result.scalar_one_or_none()
3310
+
3311
+ if subscription is None or subscription.activity_id is None:
3312
+ logger.debug(f"No waiting subscription for {instance_id} on {channel}")
3313
+ return None
3314
+
3315
+ activity_id = subscription.activity_id
3316
+
3317
+ # Get instance info for return value
3318
+ result = await session.execute(
3319
+ select(WorkflowInstance.workflow_name).where(
3320
+ WorkflowInstance.instance_id == instance_id
3321
+ )
3322
+ )
3323
+ row = result.one_or_none()
3324
+ if row is None:
3325
+ return None
3326
+ workflow_name = row[0]
3327
+
3328
+ # Prepare message data for history
3329
+ # Use "id" key to match what context.py expects when loading history
3330
+ current_time = datetime.now(UTC)
3331
+ message_result = {
3332
+ "id": message_id,
3333
+ "channel": channel,
3334
+ "data": data if isinstance(data, dict) else None,
3335
+ "metadata": metadata,
3336
+ "published_at": current_time.isoformat(),
3337
+ }
3338
+
3339
+ # Record to history
3340
+ if isinstance(data, bytes):
3341
+ history = WorkflowHistory(
3342
+ instance_id=instance_id,
3343
+ activity_id=activity_id,
3344
+ event_type="ChannelMessageReceived",
3345
+ data_type="binary",
3346
+ event_data=None,
3347
+ event_data_binary=data,
3348
+ )
3349
+ else:
3350
+ history = WorkflowHistory(
3351
+ instance_id=instance_id,
3352
+ activity_id=activity_id,
3353
+ event_type="ChannelMessageReceived",
3354
+ data_type="json",
3355
+ event_data=json.dumps(message_result),
3356
+ event_data_binary=None,
3357
+ )
3358
+ session.add(history)
3359
+
3360
+ # Handle mode-specific logic
3361
+ if subscription.mode == "broadcast":
3362
+ # Get message internal id to update cursor
3363
+ result = await session.execute(
3364
+ select(ChannelMessage.id).where(ChannelMessage.message_id == message_id)
3365
+ )
3366
+ msg_row = result.one_or_none()
3367
+ if msg_row:
3368
+ subscription.cursor_message_id = msg_row[0]
3369
+ else: # competing
3370
+ # Claim and delete the message
3371
+ claim = ChannelMessageClaim(
3372
+ message_id=message_id,
3373
+ instance_id=instance_id,
3374
+ )
3375
+ session.add(claim)
3376
+
3377
+ # Delete the message (competing mode consumes it)
3378
+ await session.execute(
3379
+ delete(ChannelMessage).where(ChannelMessage.message_id == message_id)
3380
+ )
3381
+
3382
+ # Clear waiting state
3383
+ subscription.activity_id = None
3384
+
3385
+ # Update instance status to running
3386
+ current_time = datetime.now(UTC)
3387
+ await session.execute(
3388
+ update(WorkflowInstance)
3389
+ .where(WorkflowInstance.instance_id == instance_id)
3390
+ .values(
3391
+ status="running",
3392
+ updated_at=current_time,
3393
+ )
3394
+ )
3395
+
3396
+ await session.commit()
3397
+
3398
+ return {
3399
+ "instance_id": instance_id,
3400
+ "workflow_name": workflow_name,
3401
+ "activity_id": activity_id,
3402
+ }
3403
+
3404
+ finally:
3405
+ # Always release lock
3406
+ await self.release_lock(instance_id, worker_id)
3407
+
3408
+ except Exception as e:
3409
+ logger.error(f"Error delivering channel message: {e}")
3410
+ return None
3411
+
3412
+ async def cleanup_old_channel_messages(self, older_than_days: int = 7) -> int:
3413
+ """
3414
+ Clean up old messages from channel queues.
3415
+
3416
+ Args:
3417
+ older_than_days: Message retention period in days
3418
+
3419
+ Returns:
3420
+ Number of messages deleted
3421
+ """
3422
+ cutoff_time = datetime.now(UTC) - timedelta(days=older_than_days)
3423
+
3424
+ session = self._get_session_for_operation()
3425
+ async with self._session_scope(session) as session:
3426
+ # First delete claims for old messages
3427
+ await session.execute(
3428
+ delete(ChannelMessageClaim).where(
3429
+ ChannelMessageClaim.message_id.in_(
3430
+ select(ChannelMessage.message_id).where(
3431
+ ChannelMessage.published_at
3432
+ < cutoff_time
3433
+ )
3434
+ )
3435
+ )
3436
+ )
3437
+
3438
+ # Delete old messages
3439
+ result = await session.execute(
3440
+ delete(ChannelMessage)
3441
+ .where(ChannelMessage.published_at < cutoff_time)
3442
+ .returning(ChannelMessage.id)
3443
+ )
3444
+ deleted_ids = result.fetchall()
3445
+ await self._commit_if_not_in_transaction(session)
3446
+
3447
+ return len(deleted_ids)
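
A closing sketch of a retention job built on this method; scheduling (cron, a periodic asyncio task, etc.) is left to the application:

    async def retention_job(storage) -> None:
        deleted = await storage.cleanup_old_channel_messages(older_than_days=30)
        print(f"deleted {deleted} old channel messages")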