edda-framework 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,6 +19,7 @@ from sqlalchemy import (
19
19
  CheckConstraint,
20
20
  Column,
21
21
  DateTime,
22
+ ForeignKey,
22
23
  ForeignKeyConstraint,
23
24
  Index,
24
25
  Integer,
@@ -29,18 +30,21 @@ from sqlalchemy import (
29
30
  and_,
30
31
  delete,
31
32
  func,
33
+ inspect,
32
34
  or_,
33
35
  select,
34
36
  text,
35
37
  update,
36
38
  )
37
39
  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
38
- from sqlalchemy.orm import declarative_base
40
+ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
39
41
 
40
42
  logger = logging.getLogger(__name__)
41
43
 
44
+
42
45
  # Declarative base for ORM models
43
- Base = declarative_base()
46
+ class Base(DeclarativeBase):
47
+ pass
44
48
 
45
49
 
46
50
  # ============================================================================
@@ -48,25 +52,25 @@ Base = declarative_base()
48
52
  # ============================================================================
49
53
 
50
54
 
51
- class SchemaVersion(Base): # type: ignore[valid-type, misc]
55
+ class SchemaVersion(Base):
52
56
  """Schema version tracking."""
53
57
 
54
58
  __tablename__ = "schema_version"
55
59
 
56
- version = Column(Integer, primary_key=True)
57
- applied_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
58
- description = Column(Text, nullable=False)
60
+ version: Mapped[int] = mapped_column(Integer, primary_key=True)
61
+ applied_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
62
+ description: Mapped[str] = mapped_column(Text)
59
63
 
60
64
 
61
- class WorkflowDefinition(Base): # type: ignore[valid-type, misc]
65
+ class WorkflowDefinition(Base):
62
66
  """Workflow definition (source code storage)."""
63
67
 
64
68
  __tablename__ = "workflow_definitions"
65
69
 
66
- workflow_name = Column(String(255), nullable=False, primary_key=True)
67
- source_hash = Column(String(64), nullable=False, primary_key=True)
68
- source_code = Column(Text, nullable=False)
69
- created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
70
+ workflow_name: Mapped[str] = mapped_column(String(255), primary_key=True)
71
+ source_hash: Mapped[str] = mapped_column(String(64), primary_key=True)
72
+ source_code: Mapped[str] = mapped_column(Text)
73
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
70
74
 
71
75
  __table_args__ = (
72
76
  Index("idx_definitions_name", "workflow_name"),
@@ -74,31 +78,34 @@ class WorkflowDefinition(Base): # type: ignore[valid-type, misc]
74
78
  )
75
79
 
76
80
 
77
- class WorkflowInstance(Base): # type: ignore[valid-type, misc]
81
+ class WorkflowInstance(Base):
78
82
  """Workflow instance with distributed locking support."""
79
83
 
80
84
  __tablename__ = "workflow_instances"
81
85
 
82
- instance_id = Column(String(255), primary_key=True)
83
- workflow_name = Column(String(255), nullable=False)
84
- source_hash = Column(String(64), nullable=False)
85
- owner_service = Column(String(255), nullable=False)
86
- status = Column(
87
- String(50),
88
- nullable=False,
89
- server_default=text("'running'"),
86
+ instance_id: Mapped[str] = mapped_column(String(255), primary_key=True)
87
+ workflow_name: Mapped[str] = mapped_column(String(255))
88
+ source_hash: Mapped[str] = mapped_column(String(64))
89
+ owner_service: Mapped[str] = mapped_column(String(255))
90
+ status: Mapped[str] = mapped_column(String(50), server_default=text("'running'"))
91
+ current_activity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
92
+ continued_from: Mapped[str | None] = mapped_column(
93
+ String(255), ForeignKey("workflow_instances.instance_id"), nullable=True
90
94
  )
91
- current_activity_id = Column(String(255), nullable=True)
92
- started_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
93
- updated_at = Column(
94
- DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
95
+ started_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
96
+ updated_at: Mapped[datetime] = mapped_column(
97
+ DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
95
98
  )
96
- input_data = Column(Text, nullable=False) # JSON
97
- output_data = Column(Text, nullable=True) # JSON
98
- locked_by = Column(String(255), nullable=True)
99
- locked_at = Column(DateTime(timezone=True), nullable=True)
100
- lock_timeout_seconds = Column(Integer, nullable=True) # None = use global default (300s)
101
- lock_expires_at = Column(DateTime(timezone=True), nullable=True) # Absolute expiry time
99
+ input_data: Mapped[str] = mapped_column(Text) # JSON
100
+ output_data: Mapped[str | None] = mapped_column(Text, nullable=True) # JSON
101
+ locked_by: Mapped[str | None] = mapped_column(String(255), nullable=True)
102
+ locked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
103
+ lock_timeout_seconds: Mapped[int | None] = mapped_column(
104
+ Integer, nullable=True
105
+ ) # None = use global default (300s)
106
+ lock_expires_at: Mapped[datetime | None] = mapped_column(
107
+ DateTime(timezone=True), nullable=True
108
+ ) # Absolute expiry time
102
109
 
103
110
  __table_args__ = (
104
111
  ForeignKeyConstraint(
@@ -107,7 +114,7 @@ class WorkflowInstance(Base): # type: ignore[valid-type, misc]
107
114
  ),
108
115
  CheckConstraint(
109
116
  "status IN ('running', 'completed', 'failed', 'waiting_for_event', "
110
- "'waiting_for_timer', 'compensating', 'cancelled')",
117
+ "'waiting_for_timer', 'waiting_for_message', 'compensating', 'cancelled', 'recurred')",
111
118
  name="valid_status",
112
119
  ),
113
120
  Index("idx_instances_status", "status"),
@@ -117,25 +124,29 @@ class WorkflowInstance(Base): # type: ignore[valid-type, misc]
117
124
  Index("idx_instances_lock_expires", "lock_expires_at"),
118
125
  Index("idx_instances_updated", "updated_at"),
119
126
  Index("idx_instances_hash", "source_hash"),
127
+ Index("idx_instances_continued_from", "continued_from"),
128
+ # Composite index for find_resumable_workflows(): WHERE status='running' AND locked_by IS NULL
129
+ Index("idx_instances_resumable", "status", "locked_by"),
120
130
  )
121
131
 
122
132
 
123
- class WorkflowHistory(Base): # type: ignore[valid-type, misc]
133
+ class WorkflowHistory(Base):
124
134
  """Workflow execution history (for deterministic replay)."""
125
135
 
126
136
  __tablename__ = "workflow_history"
127
137
 
128
- id = Column(Integer, primary_key=True, autoincrement=True)
129
- instance_id = Column(
130
- String(255),
131
- nullable=False,
132
- )
133
- activity_id = Column(String(255), nullable=False)
134
- event_type = Column(String(100), nullable=False)
135
- data_type = Column(String(10), nullable=False) # 'json' or 'binary'
136
- event_data = Column(Text, nullable=True) # JSON (when data_type='json')
137
- event_data_binary = Column(LargeBinary, nullable=True) # Binary (when data_type='binary')
138
- created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
138
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
139
+ instance_id: Mapped[str] = mapped_column(String(255))
140
+ activity_id: Mapped[str] = mapped_column(String(255))
141
+ event_type: Mapped[str] = mapped_column(String(100))
142
+ data_type: Mapped[str] = mapped_column(String(10)) # 'json' or 'binary'
143
+ event_data: Mapped[str | None] = mapped_column(
144
+ Text, nullable=True
145
+ ) # JSON (when data_type='json')
146
+ event_data_binary: Mapped[bytes | None] = mapped_column(
147
+ LargeBinary, nullable=True
148
+ ) # Binary (when data_type='binary')
149
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
139
150
 
140
151
  __table_args__ = (
141
152
  ForeignKeyConstraint(
@@ -158,20 +169,20 @@ class WorkflowHistory(Base): # type: ignore[valid-type, misc]
158
169
  )
159
170
 
160
171
 
161
- class WorkflowCompensation(Base): # type: ignore[valid-type, misc]
162
- """Compensation transactions (LIFO stack for Saga pattern)."""
172
+ class WorkflowHistoryArchive(Base):
173
+ """Archived workflow execution history (for recur pattern)."""
163
174
 
164
- __tablename__ = "workflow_compensations"
175
+ __tablename__ = "workflow_history_archive"
165
176
 
166
- id = Column(Integer, primary_key=True, autoincrement=True)
167
- instance_id = Column(
168
- String(255),
169
- nullable=False,
177
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
178
+ instance_id: Mapped[str] = mapped_column(String(255))
179
+ activity_id: Mapped[str] = mapped_column(String(255))
180
+ event_type: Mapped[str] = mapped_column(String(100))
181
+ event_data: Mapped[str] = mapped_column(Text) # JSON (includes both types for archive)
182
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
183
+ archived_at: Mapped[datetime] = mapped_column(
184
+ DateTime(timezone=True), server_default=func.now()
170
185
  )
171
- activity_id = Column(String(255), nullable=False)
172
- activity_name = Column(String(255), nullable=False)
173
- args = Column(Text, nullable=False) # JSON
174
- created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
175
186
 
176
187
  __table_args__ = (
177
188
  ForeignKeyConstraint(
@@ -179,24 +190,22 @@ class WorkflowCompensation(Base): # type: ignore[valid-type, misc]
179
190
  ["workflow_instances.instance_id"],
180
191
  ondelete="CASCADE",
181
192
  ),
182
- Index("idx_compensations_instance", "instance_id", "created_at"),
193
+ Index("idx_history_archive_instance", "instance_id"),
194
+ Index("idx_history_archive_archived", "archived_at"),
183
195
  )
184
196
 
185
197
 
186
- class WorkflowEventSubscription(Base): # type: ignore[valid-type, misc]
187
- """Event subscriptions (for wait_event)."""
198
+ class WorkflowCompensation(Base):
199
+ """Compensation transactions (LIFO stack for Saga pattern)."""
188
200
 
189
- __tablename__ = "workflow_event_subscriptions"
201
+ __tablename__ = "workflow_compensations"
190
202
 
191
- id = Column(Integer, primary_key=True, autoincrement=True)
192
- instance_id = Column(
193
- String(255),
194
- nullable=False,
195
- )
196
- event_type = Column(String(255), nullable=False)
197
- activity_id = Column(String(255), nullable=True)
198
- timeout_at = Column(DateTime(timezone=True), nullable=True)
199
- created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
203
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
204
+ instance_id: Mapped[str] = mapped_column(String(255))
205
+ activity_id: Mapped[str] = mapped_column(String(255))
206
+ activity_name: Mapped[str] = mapped_column(String(255))
207
+ args: Mapped[str] = mapped_column(Text) # JSON
208
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
200
209
 
201
210
  __table_args__ = (
202
211
  ForeignKeyConstraint(
@@ -204,27 +213,21 @@ class WorkflowEventSubscription(Base): # type: ignore[valid-type, misc]
204
213
  ["workflow_instances.instance_id"],
205
214
  ondelete="CASCADE",
206
215
  ),
207
- UniqueConstraint("instance_id", "event_type", name="unique_instance_event"),
208
- Index("idx_subscriptions_event", "event_type"),
209
- Index("idx_subscriptions_timeout", "timeout_at"),
210
- Index("idx_subscriptions_instance", "instance_id"),
216
+ Index("idx_compensations_instance", "instance_id", "created_at"),
211
217
  )
212
218
 
213
219
 
214
- class WorkflowTimerSubscription(Base): # type: ignore[valid-type, misc]
220
+ class WorkflowTimerSubscription(Base):
215
221
  """Timer subscriptions (for wait_timer)."""
216
222
 
217
223
  __tablename__ = "workflow_timer_subscriptions"
218
224
 
219
- id = Column(Integer, primary_key=True, autoincrement=True)
220
- instance_id = Column(
221
- String(255),
222
- nullable=False,
223
- )
224
- timer_id = Column(String(255), nullable=False)
225
- expires_at = Column(DateTime(timezone=True), nullable=False)
226
- activity_id = Column(String(255), nullable=True)
227
- created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
225
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
226
+ instance_id: Mapped[str] = mapped_column(String(255))
227
+ timer_id: Mapped[str] = mapped_column(String(255))
228
+ expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
229
+ activity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
230
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
228
231
 
229
232
  __table_args__ = (
230
233
  ForeignKeyConstraint(
@@ -238,23 +241,29 @@ class WorkflowTimerSubscription(Base): # type: ignore[valid-type, misc]
238
241
  )
239
242
 
240
243
 
241
- class OutboxEvent(Base): # type: ignore[valid-type, misc]
244
+ class OutboxEvent(Base):
242
245
  """Transactional outbox pattern events."""
243
246
 
244
247
  __tablename__ = "outbox_events"
245
248
 
246
- event_id = Column(String(255), primary_key=True)
247
- event_type = Column(String(255), nullable=False)
248
- event_source = Column(String(255), nullable=False)
249
- data_type = Column(String(10), nullable=False) # 'json' or 'binary'
250
- event_data = Column(Text, nullable=True) # JSON (when data_type='json')
251
- event_data_binary = Column(LargeBinary, nullable=True) # Binary (when data_type='binary')
252
- content_type = Column(String(100), nullable=False, server_default=text("'application/json'"))
253
- created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
254
- published_at = Column(DateTime(timezone=True), nullable=True)
255
- status = Column(String(50), nullable=False, server_default=text("'pending'"))
256
- retry_count = Column(Integer, nullable=False, server_default=text("0"))
257
- last_error = Column(Text, nullable=True)
249
+ event_id: Mapped[str] = mapped_column(String(255), primary_key=True)
250
+ event_type: Mapped[str] = mapped_column(String(255))
251
+ event_source: Mapped[str] = mapped_column(String(255))
252
+ data_type: Mapped[str] = mapped_column(String(10)) # 'json' or 'binary'
253
+ event_data: Mapped[str | None] = mapped_column(
254
+ Text, nullable=True
255
+ ) # JSON (when data_type='json')
256
+ event_data_binary: Mapped[bytes | None] = mapped_column(
257
+ LargeBinary, nullable=True
258
+ ) # Binary (when data_type='binary')
259
+ content_type: Mapped[str] = mapped_column(
260
+ String(100), server_default=text("'application/json'")
261
+ )
262
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
263
+ published_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
264
+ status: Mapped[str] = mapped_column(String(50), server_default=text("'pending'"))
265
+ retry_count: Mapped[int] = mapped_column(Integer, server_default=text("0"))
266
+ last_error: Mapped[str | None] = mapped_column(Text, nullable=True)
258
267
 
259
268
  __table_args__ = (
260
269
  CheckConstraint(
@@ -276,6 +285,203 @@ class OutboxEvent(Base): # type: ignore[valid-type, misc]
276
285
  )
277
286
 
278
287
 
288
+ class WorkflowMessageSubscription(Base):
289
+ """Message subscriptions (for wait_message)."""
290
+
291
+ __tablename__ = "workflow_message_subscriptions"
292
+
293
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
294
+ instance_id: Mapped[str] = mapped_column(String(255))
295
+ channel: Mapped[str] = mapped_column(String(255))
296
+ activity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
297
+ timeout_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
298
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
299
+
300
+ __table_args__ = (
301
+ ForeignKeyConstraint(
302
+ ["instance_id"],
303
+ ["workflow_instances.instance_id"],
304
+ ondelete="CASCADE",
305
+ ),
306
+ UniqueConstraint("instance_id", "channel", name="unique_instance_channel"),
307
+ Index("idx_message_subscriptions_channel", "channel"),
308
+ Index("idx_message_subscriptions_timeout", "timeout_at"),
309
+ Index("idx_message_subscriptions_instance", "instance_id"),
310
+ )
311
+
312
+
313
+ class WorkflowGroupMembership(Base):
314
+ """Group memberships (Erlang pg style for broadcast messaging)."""
315
+
316
+ __tablename__ = "workflow_group_memberships"
317
+
318
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
319
+ instance_id: Mapped[str] = mapped_column(String(255))
320
+ group_name: Mapped[str] = mapped_column(String(255))
321
+ joined_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
322
+
323
+ __table_args__ = (
324
+ ForeignKeyConstraint(
325
+ ["instance_id"],
326
+ ["workflow_instances.instance_id"],
327
+ ondelete="CASCADE",
328
+ ),
329
+ UniqueConstraint("instance_id", "group_name", name="unique_instance_group"),
330
+ Index("idx_group_memberships_group", "group_name"),
331
+ Index("idx_group_memberships_instance", "instance_id"),
332
+ )
333
+
334
+
335
+ # =============================================================================
336
+ # Channel-based Message Queue Models
337
+ # =============================================================================
338
+
339
+
340
+ class ChannelMessage(Base):
341
+ """Channel message queue (persistent message storage)."""
342
+
343
+ __tablename__ = "channel_messages"
344
+
345
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
346
+ channel: Mapped[str] = mapped_column(String(255))
347
+ message_id: Mapped[str] = mapped_column(String(255), unique=True)
348
+ data_type: Mapped[str] = mapped_column(String(10)) # 'json' or 'binary'
349
+ data: Mapped[str | None] = mapped_column(Text, nullable=True) # JSON (when data_type='json')
350
+ data_binary: Mapped[bytes | None] = mapped_column(
351
+ LargeBinary, nullable=True
352
+ ) # Binary (when data_type='binary')
353
+ message_metadata: Mapped[str | None] = mapped_column(
354
+ "metadata", Text, nullable=True
355
+ ) # JSON - renamed to avoid SQLAlchemy reserved name
356
+ published_at: Mapped[datetime] = mapped_column(
357
+ DateTime(timezone=True), server_default=func.now()
358
+ )
359
+
360
+ __table_args__ = (
361
+ CheckConstraint(
362
+ "data_type IN ('json', 'binary')",
363
+ name="channel_valid_data_type",
364
+ ),
365
+ CheckConstraint(
366
+ "(data_type = 'json' AND data IS NOT NULL AND data_binary IS NULL) OR "
367
+ "(data_type = 'binary' AND data IS NULL AND data_binary IS NOT NULL)",
368
+ name="channel_data_type_consistency",
369
+ ),
370
+ Index("idx_channel_messages_channel", "channel", "published_at"),
371
+ Index("idx_channel_messages_id", "id"),
372
+ )
373
+
374
+
375
+ class ChannelSubscription(Base):
376
+ """Channel subscriptions for workflow instances."""
377
+
378
+ __tablename__ = "channel_subscriptions"
379
+
380
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
381
+ instance_id: Mapped[str] = mapped_column(String(255))
382
+ channel: Mapped[str] = mapped_column(String(255))
383
+ mode: Mapped[str] = mapped_column(String(20)) # 'broadcast' or 'competing'
384
+ activity_id: Mapped[str | None] = mapped_column(
385
+ String(255), nullable=True
386
+ ) # Set when waiting for message
387
+ cursor_message_id: Mapped[int | None] = mapped_column(
388
+ Integer, nullable=True
389
+ ) # Last received message id (broadcast)
390
+ timeout_at: Mapped[datetime | None] = mapped_column(
391
+ DateTime(timezone=True), nullable=True
392
+ ) # Timeout deadline
393
+ subscribed_at: Mapped[datetime] = mapped_column(
394
+ DateTime(timezone=True), server_default=func.now()
395
+ )
396
+
397
+ __table_args__ = (
398
+ ForeignKeyConstraint(
399
+ ["instance_id"],
400
+ ["workflow_instances.instance_id"],
401
+ ondelete="CASCADE",
402
+ ),
403
+ CheckConstraint(
404
+ "mode IN ('broadcast', 'competing')",
405
+ name="channel_valid_mode",
406
+ ),
407
+ UniqueConstraint("instance_id", "channel", name="unique_channel_instance_channel"),
408
+ Index("idx_channel_subscriptions_channel", "channel"),
409
+ Index("idx_channel_subscriptions_instance", "instance_id"),
410
+ Index("idx_channel_subscriptions_waiting", "channel", "activity_id"),
411
+ Index("idx_channel_subscriptions_timeout", "timeout_at"),
412
+ )
413
+
414
+
415
+ class ChannelDeliveryCursor(Base):
416
+ """Channel delivery cursors (broadcast mode: track who read what)."""
417
+
418
+ __tablename__ = "channel_delivery_cursors"
419
+
420
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
421
+ channel: Mapped[str] = mapped_column(String(255))
422
+ instance_id: Mapped[str] = mapped_column(String(255))
423
+ last_delivered_id: Mapped[int] = mapped_column(Integer)
424
+ updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
425
+
426
+ __table_args__ = (
427
+ ForeignKeyConstraint(
428
+ ["instance_id"],
429
+ ["workflow_instances.instance_id"],
430
+ ondelete="CASCADE",
431
+ ),
432
+ UniqueConstraint("channel", "instance_id", name="unique_channel_delivery_cursor"),
433
+ Index("idx_channel_delivery_cursors_channel", "channel"),
434
+ )
435
+
436
+
437
+ class ChannelMessageClaim(Base):
438
+ """Channel message claims (competing mode: who is processing what)."""
439
+
440
+ __tablename__ = "channel_message_claims"
441
+
442
+ message_id: Mapped[str] = mapped_column(String(255), primary_key=True)
443
+ instance_id: Mapped[str] = mapped_column(String(255))
444
+ claimed_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
445
+
446
+ __table_args__ = (
447
+ ForeignKeyConstraint(
448
+ ["message_id"],
449
+ ["channel_messages.message_id"],
450
+ ondelete="CASCADE",
451
+ ),
452
+ ForeignKeyConstraint(
453
+ ["instance_id"],
454
+ ["workflow_instances.instance_id"],
455
+ ondelete="CASCADE",
456
+ ),
457
+ Index("idx_channel_message_claims_instance", "instance_id"),
458
+ )
459
+
460
+
461
+ # =============================================================================
462
+ # System-level Lock Models (for background task coordination)
463
+ # =============================================================================
464
+
465
+
466
+ class SystemLock(Base):
467
+ """System-level locks for coordinating background tasks across pods.
468
+
469
+ Used to prevent duplicate execution of operational tasks like:
470
+ - cleanup_stale_locks_periodically()
471
+ - auto_resume_stale_workflows_periodically()
472
+ - _cleanup_old_messages_periodically()
473
+ """
474
+
475
+ __tablename__ = "system_locks"
476
+
477
+ lock_name: Mapped[str] = mapped_column(String(255), primary_key=True)
478
+ locked_by: Mapped[str | None] = mapped_column(String(255), nullable=True)
479
+ locked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
480
+ lock_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
481
+
482
+ __table_args__ = (Index("idx_system_locks_expires", "lock_expires_at"),)
483
+
484
+
279
485
  # Current schema version
280
486
  CURRENT_SCHEMA_VERSION = 1
281
487
 
@@ -337,11 +543,22 @@ class SQLAlchemyStorage:
337
543
  self.engine = engine
338
544
 
339
545
  async def initialize(self) -> None:
340
- """Initialize database connection and create tables."""
546
+ """Initialize database connection and create tables.
547
+
548
+ This method creates all tables if they don't exist, and then performs
549
+ automatic schema migration to add any missing columns and update CHECK
550
+ constraints. This ensures compatibility when upgrading Edda versions.
551
+ """
341
552
  # Create all tables and indexes
342
553
  async with self.engine.begin() as conn:
343
554
  await conn.run_sync(Base.metadata.create_all)
344
555
 
556
+ # Auto-migrate schema (add missing columns)
557
+ await self._auto_migrate_schema()
558
+
559
+ # Auto-migrate CHECK constraints
560
+ await self._auto_migrate_check_constraints()
561
+
345
562
  # Initialize schema version
346
563
  await self._initialize_schema_version()
347
564
 
@@ -366,6 +583,198 @@ class SQLAlchemyStorage:
366
583
  await session.commit()
367
584
  logger.info(f"Initialized schema version to {CURRENT_SCHEMA_VERSION}")
368
585
 
586
+ async def _auto_migrate_schema(self) -> None:
587
+ """
588
+ Automatically add missing columns to existing tables.
589
+
590
+ This method compares the ORM model definitions with the actual database
591
+ schema and adds any missing columns using ALTER TABLE ADD COLUMN.
592
+
593
+ Note: This only handles column additions, not removals or type changes.
594
+ For complex migrations, use Alembic.
595
+ """
596
+
597
+ def _get_column_type_sql(column: Column, dialect_name: str) -> str: # type: ignore[type-arg]
598
+ """Get SQL type string for a column based on dialect."""
599
+ col_type = column.type
600
+
601
+ # Map SQLAlchemy types to SQL types
602
+ if isinstance(col_type, String):
603
+ length = col_type.length or 255
604
+ return f"VARCHAR({length})"
605
+ elif isinstance(col_type, Text):
606
+ return "TEXT"
607
+ elif isinstance(col_type, Integer):
608
+ return "INTEGER"
609
+ elif isinstance(col_type, DateTime):
610
+ if dialect_name == "postgresql":
611
+ return "TIMESTAMP WITH TIME ZONE" if col_type.timezone else "TIMESTAMP"
612
+ elif dialect_name == "mysql":
613
+ return "DATETIME" if not col_type.timezone else "DATETIME"
614
+ else: # sqlite
615
+ return "DATETIME"
616
+ elif isinstance(col_type, LargeBinary):
617
+ if dialect_name == "postgresql":
618
+ return "BYTEA"
619
+ elif dialect_name == "mysql":
620
+ return "LONGBLOB"
621
+ else: # sqlite
622
+ return "BLOB"
623
+ else:
624
+ # Fallback to compiled type
625
+ return str(col_type.compile(dialect=self.engine.dialect))
626
+
627
+ def _get_default_sql(column: Column, _dialect_name: str) -> str | None: # type: ignore[type-arg]
628
+ """Get DEFAULT clause for a column if applicable."""
629
+ if column.server_default is not None:
630
+ # Handle text() server defaults - try to get the arg attribute
631
+ server_default = column.server_default
632
+ if hasattr(server_default, "arg"):
633
+ default_val = server_default.arg
634
+ if hasattr(default_val, "text"):
635
+ return f"DEFAULT {default_val.text}"
636
+ return f"DEFAULT {default_val}"
637
+ return None
638
+
639
+ def _run_migration(conn: Any) -> None:
640
+ """Run migration in sync context."""
641
+ dialect_name = self.engine.dialect.name
642
+ inspector = inspect(conn)
643
+
644
+ # Iterate through all ORM tables
645
+ for table in Base.metadata.tables.values():
646
+ table_name = table.name
647
+
648
+ # Check if table exists
649
+ if not inspector.has_table(table_name):
650
+ logger.debug(f"Table {table_name} does not exist, skipping migration")
651
+ continue
652
+
653
+ # Get existing columns
654
+ existing_columns = {col["name"] for col in inspector.get_columns(table_name)}
655
+
656
+ # Check each column in the ORM model
657
+ for column in table.columns:
658
+ if column.name not in existing_columns:
659
+ # Column is missing, generate ALTER TABLE
660
+ col_type_sql = _get_column_type_sql(column, dialect_name)
661
+ nullable = "NULL" if column.nullable else "NOT NULL"
662
+
663
+ # Build ALTER TABLE statement
664
+ alter_sql = (
665
+ f'ALTER TABLE "{table_name}" ADD COLUMN "{column.name}" {col_type_sql}'
666
+ )
667
+
668
+ # Add nullable constraint (only if NOT NULL and has default)
669
+ default_sql = _get_default_sql(column, dialect_name)
670
+ if not column.nullable and default_sql:
671
+ alter_sql += f" {default_sql} {nullable}"
672
+ elif column.nullable:
673
+ alter_sql += f" {nullable}"
674
+ elif default_sql:
675
+ alter_sql += f" {default_sql}"
676
+ # For NOT NULL without default, just add the column as nullable
677
+ # (PostgreSQL requires default or nullable for existing rows)
678
+ else:
679
+ alter_sql += " NULL"
680
+
681
+ logger.info(f"Auto-migrating: Adding column {column.name} to {table_name}")
682
+ logger.debug(f"Executing: {alter_sql}")
683
+
684
+ try:
685
+ conn.execute(text(alter_sql))
686
+ except Exception as e:
687
+ logger.warning(
688
+ f"Failed to add column {column.name} to {table_name}: {e}"
689
+ )
690
+
691
+ async with self.engine.begin() as conn:
692
+ await conn.run_sync(_run_migration)
693
+
694
+ async def _auto_migrate_check_constraints(self) -> None:
695
+ """
696
+ Automatically update CHECK constraints for workflow status.
697
+
698
+ This method ensures the valid_status CHECK constraint includes all
699
+ required status values (including 'waiting_for_message').
700
+ """
701
+ dialect_name = self.engine.dialect.name
702
+
703
+ # SQLite doesn't support ALTER CONSTRAINT easily, and SQLAlchemy create_all
704
+ # handles it correctly for new databases. For existing SQLite databases,
705
+ # the constraint is more lenient (CHECK is not enforced in many SQLite versions).
706
+ if dialect_name == "sqlite":
707
+ return
708
+
709
+ # Expected status values (must match WorkflowInstance model)
710
+ expected_statuses = (
711
+ "'running', 'completed', 'failed', 'waiting_for_event', "
712
+ "'waiting_for_timer', 'waiting_for_message', 'compensating', 'cancelled', 'recurred'"
713
+ )
714
+
715
+ def _run_constraint_migration(conn: Any) -> None:
716
+ """Run CHECK constraint migration in sync context."""
717
+ inspector = inspect(conn)
718
+
719
+ # Check if workflow_instances table exists
720
+ if not inspector.has_table("workflow_instances"):
721
+ return
722
+
723
+ # Get existing CHECK constraints
724
+ try:
725
+ constraints = inspector.get_check_constraints("workflow_instances")
726
+ except NotImplementedError:
727
+ # Some databases don't support get_check_constraints
728
+ logger.debug("Database doesn't support get_check_constraints inspection")
729
+ constraints = []
730
+
731
+ # Find the valid_status constraint
732
+ valid_status_constraint = None
733
+ for constraint in constraints:
734
+ if constraint.get("name") == "valid_status":
735
+ valid_status_constraint = constraint
736
+ break
737
+
738
+ # Check if constraint exists and needs updating
739
+ if valid_status_constraint:
740
+ sqltext = valid_status_constraint.get("sqltext", "")
741
+ # Check if 'waiting_for_message' is already in the constraint
742
+ if "waiting_for_message" in sqltext:
743
+ logger.debug("valid_status constraint already includes waiting_for_message")
744
+ return
745
+
746
+ # Need to update the constraint
747
+ logger.info("Updating valid_status CHECK constraint to include waiting_for_message")
748
+ try:
749
+ if dialect_name == "postgresql":
750
+ conn.execute(
751
+ text("ALTER TABLE workflow_instances DROP CONSTRAINT valid_status")
752
+ )
753
+ conn.execute(
754
+ text(
755
+ f"ALTER TABLE workflow_instances ADD CONSTRAINT valid_status "
756
+ f"CHECK (status IN ({expected_statuses}))"
757
+ )
758
+ )
759
+ elif dialect_name == "mysql":
760
+ # MySQL uses DROP CHECK and ADD CONSTRAINT CHECK syntax
761
+ conn.execute(text("ALTER TABLE workflow_instances DROP CHECK valid_status"))
762
+ conn.execute(
763
+ text(
764
+ f"ALTER TABLE workflow_instances ADD CONSTRAINT valid_status "
765
+ f"CHECK (status IN ({expected_statuses}))"
766
+ )
767
+ )
768
+ logger.info("Successfully updated valid_status CHECK constraint")
769
+ except Exception as e:
770
+ logger.warning(f"Failed to update valid_status CHECK constraint: {e}")
771
+ else:
772
+ # Constraint doesn't exist (shouldn't happen with create_all, but handle it)
773
+ logger.debug("valid_status constraint not found, will be created by create_all")
774
+
775
+ async with self.engine.begin() as conn:
776
+ await conn.run_sync(_run_constraint_migration)
777
+
369
778
  def _get_session_for_operation(self, is_lock_operation: bool = False) -> AsyncSession:
370
779
  """
371
780
  Get the appropriate session for an operation.
@@ -452,7 +861,7 @@ class SQLAlchemyStorage:
452
861
  Example:
453
862
  >>> # SQLite: datetime(timeout_at) <= datetime('now')
454
863
  >>> # PostgreSQL/MySQL: timeout_at <= NOW()
455
- >>> self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
864
+ >>> self._make_datetime_comparable(WorkflowMessageSubscription.timeout_at)
456
865
  >>> <= self._get_current_time_expr()
457
866
  """
458
867
  if self.engine.dialect.name == "sqlite":
@@ -605,7 +1014,7 @@ class SQLAlchemyStorage:
605
1014
 
606
1015
  if existing:
607
1016
  # Update
608
- existing.source_code = source_code # type: ignore[assignment]
1017
+ existing.source_code = source_code
609
1018
  else:
610
1019
  # Insert
611
1020
  definition = WorkflowDefinition(
@@ -682,6 +1091,7 @@ class SQLAlchemyStorage:
682
1091
  owner_service: str,
683
1092
  input_data: dict[str, Any],
684
1093
  lock_timeout_seconds: int | None = None,
1094
+ continued_from: str | None = None,
685
1095
  ) -> None:
686
1096
  """Create a new workflow instance."""
687
1097
  session = self._get_session_for_operation()
@@ -693,6 +1103,7 @@ class SQLAlchemyStorage:
693
1103
  owner_service=owner_service,
694
1104
  input_data=json.dumps(input_data),
695
1105
  lock_timeout_seconds=lock_timeout_seconds,
1106
+ continued_from=continued_from,
696
1107
  )
697
1108
  session.add(instance)
698
1109
 
@@ -1165,6 +1576,103 @@ class SQLAlchemyStorage:
1165
1576
  await session.commit()
1166
1577
  return workflows_to_resume
1167
1578
 
1579
+ # -------------------------------------------------------------------------
1580
+ # System-level Locking Methods (for background task coordination)
1581
+ # -------------------------------------------------------------------------
1582
+
1583
+ async def try_acquire_system_lock(
1584
+ self,
1585
+ lock_name: str,
1586
+ worker_id: str,
1587
+ timeout_seconds: int = 60,
1588
+ ) -> bool:
1589
+ """
1590
+ Try to acquire a system-level lock for coordinating background tasks.
1591
+
1592
+ Uses INSERT ON CONFLICT pattern to handle race conditions:
1593
+ 1. Try to INSERT new lock record
1594
+ 2. If exists, check if expired or unlocked
1595
+ 3. If available, acquire lock; otherwise return False
1596
+
1597
+ Note: ALWAYS uses separate session (not external session).
1598
+ """
1599
+ session = self._get_session_for_operation(is_lock_operation=True)
1600
+ async with self._session_scope(session) as session:
1601
+ current_time = datetime.now(UTC)
1602
+ lock_expires_at = current_time + timedelta(seconds=timeout_seconds)
1603
+
1604
+ # Try to get existing lock
1605
+ result = await session.execute(
1606
+ select(SystemLock).where(SystemLock.lock_name == lock_name)
1607
+ )
1608
+ lock = result.scalar_one_or_none()
1609
+
1610
+ if lock is None:
1611
+ # No lock exists - create new one
1612
+ lock = SystemLock(
1613
+ lock_name=lock_name,
1614
+ locked_by=worker_id,
1615
+ locked_at=current_time,
1616
+ lock_expires_at=lock_expires_at,
1617
+ )
1618
+ session.add(lock)
1619
+ await session.commit()
1620
+ return True
1621
+
1622
+ # Lock exists - check if available
1623
+ if lock.locked_by is None:
1624
+ # Unlocked - acquire
1625
+ lock.locked_by = worker_id
1626
+ lock.locked_at = current_time
1627
+ lock.lock_expires_at = lock_expires_at
1628
+ await session.commit()
1629
+ return True
1630
+
1631
+ # Check if expired (use SQL-side comparison for cross-DB compatibility)
1632
+ if lock.lock_expires_at is not None:
1633
+ # Handle timezone-naive datetime from SQLite
1634
+ lock_expires = (
1635
+ lock.lock_expires_at.replace(tzinfo=UTC)
1636
+ if lock.lock_expires_at.tzinfo is None
1637
+ else lock.lock_expires_at
1638
+ )
1639
+ if lock_expires <= current_time:
1640
+ # Expired - acquire
1641
+ lock.locked_by = worker_id
1642
+ lock.locked_at = current_time
1643
+ lock.lock_expires_at = lock_expires_at
1644
+ await session.commit()
1645
+ return True
1646
+
1647
+ # Already locked by another worker
1648
+ return False
1649
+
1650
+ async def release_system_lock(self, lock_name: str, worker_id: str) -> None:
1651
+ """
1652
+ Release a system-level lock.
1653
+
1654
+ Only releases the lock if it's held by the specified worker.
1655
+
1656
+ Note: ALWAYS uses separate session (not external session).
1657
+ """
1658
+ session = self._get_session_for_operation(is_lock_operation=True)
1659
+ async with self._session_scope(session) as session:
1660
+ await session.execute(
1661
+ update(SystemLock)
1662
+ .where(
1663
+ and_(
1664
+ SystemLock.lock_name == lock_name,
1665
+ SystemLock.locked_by == worker_id,
1666
+ )
1667
+ )
1668
+ .values(
1669
+ locked_by=None,
1670
+ locked_at=None,
1671
+ lock_expires_at=None,
1672
+ )
1673
+ )
1674
+ await session.commit()
1675
+
1168
1676
  # -------------------------------------------------------------------------
1169
1677
  # History Methods (prefer external session)
1170
1678
  # -------------------------------------------------------------------------
@@ -1231,6 +1739,107 @@ class SQLAlchemyStorage:
1231
1739
  for row in rows
1232
1740
  ]
1233
1741
 
1742
+ async def archive_history(self, instance_id: str) -> int:
1743
+ """
1744
+ Archive workflow history for the recur pattern.
1745
+
1746
+ Moves all history entries from workflow_history to workflow_history_archive.
1747
+ Binary data is converted to base64 for JSON storage in the archive.
1748
+
1749
+ Returns:
1750
+ Number of history entries archived
1751
+ """
1752
+ import base64
1753
+
1754
+ session = self._get_session_for_operation()
1755
+ async with self._session_scope(session) as session:
1756
+ # Get all history entries for this instance
1757
+ result = await session.execute(
1758
+ select(WorkflowHistory)
1759
+ .where(WorkflowHistory.instance_id == instance_id)
1760
+ .order_by(WorkflowHistory.created_at.asc())
1761
+ )
1762
+ history_rows = result.scalars().all()
1763
+
1764
+ if not history_rows:
1765
+ return 0
1766
+
1767
+ # Archive each history entry
1768
+ for row in history_rows:
1769
+ # Convert event_data to JSON string for archive
1770
+ event_data_json: str | None
1771
+ if row.data_type == "binary" and row.event_data_binary is not None:
1772
+ # Convert binary to base64 for JSON storage
1773
+ event_data_json = json.dumps(
1774
+ {
1775
+ "_binary": True,
1776
+ "data": base64.b64encode(row.event_data_binary).decode("ascii"),
1777
+ }
1778
+ )
1779
+ else:
1780
+ # Already JSON, use as-is
1781
+ event_data_json = row.event_data
1782
+
1783
+ archive_entry = WorkflowHistoryArchive(
1784
+ instance_id=row.instance_id,
1785
+ activity_id=row.activity_id,
1786
+ event_type=row.event_type,
1787
+ event_data=event_data_json,
1788
+ created_at=row.created_at,
1789
+ )
1790
+ session.add(archive_entry)
1791
+
1792
+ # Delete original history entries
1793
+ await session.execute(
1794
+ delete(WorkflowHistory).where(WorkflowHistory.instance_id == instance_id)
1795
+ )
1796
+
1797
+ await self._commit_if_not_in_transaction(session)
1798
+ return len(history_rows)
1799
+
1800
+ async def find_first_cancellation_event(self, instance_id: str) -> dict[str, Any] | None:
1801
+ """
1802
+ Find the first cancellation event in workflow history.
1803
+
1804
+ Uses LIMIT 1 optimization to avoid loading all history events.
1805
+ """
1806
+ session = self._get_session_for_operation()
1807
+ async with self._session_scope(session) as session:
1808
+ # Query for cancellation events using LIMIT 1
1809
+ result = await session.execute(
1810
+ select(WorkflowHistory)
1811
+ .where(
1812
+ and_(
1813
+ WorkflowHistory.instance_id == instance_id,
1814
+ or_(
1815
+ WorkflowHistory.event_type == "WorkflowCancelled",
1816
+ func.lower(WorkflowHistory.event_type).contains("cancel"),
1817
+ ),
1818
+ )
1819
+ )
1820
+ .order_by(WorkflowHistory.created_at.asc())
1821
+ .limit(1)
1822
+ )
1823
+ row = result.scalars().first()
1824
+
1825
+ if row is None:
1826
+ return None
1827
+
1828
+ # Parse event_data based on data_type
1829
+ if row.data_type == "binary" and row.event_data_binary is not None:
1830
+ event_data: dict[str, Any] | bytes = row.event_data_binary
1831
+ else:
1832
+ event_data = json.loads(row.event_data) if row.event_data else {}
1833
+
1834
+ return {
1835
+ "id": row.id,
1836
+ "instance_id": row.instance_id,
1837
+ "activity_id": row.activity_id,
1838
+ "event_type": row.event_type,
1839
+ "event_data": event_data,
1840
+ "created_at": row.created_at,
1841
+ }
1842
+
1234
1843
  # -------------------------------------------------------------------------
1235
1844
  # Compensation Methods (prefer external session)
1236
1845
  # -------------------------------------------------------------------------
@@ -1271,7 +1880,7 @@ class SQLAlchemyStorage:
1271
1880
  "instance_id": row.instance_id,
1272
1881
  "activity_id": row.activity_id,
1273
1882
  "activity_name": row.activity_name,
1274
- "args": json.loads(row.args), # type: ignore[arg-type]
1883
+ "args": json.loads(row.args) if row.args else [],
1275
1884
  "created_at": row.created_at.isoformat(),
1276
1885
  }
1277
1886
  for row in rows
@@ -1287,218 +1896,19 @@ class SQLAlchemyStorage:
1287
1896
  await self._commit_if_not_in_transaction(session)
1288
1897
 
1289
1898
  # -------------------------------------------------------------------------
1290
- # Event Subscription Methods (prefer external session for registration)
1899
+ # Timer Subscription Methods
1291
1900
  # -------------------------------------------------------------------------
1292
1901
 
1293
- async def add_event_subscription(
1902
+ async def register_timer_subscription_and_release_lock(
1294
1903
  self,
1295
1904
  instance_id: str,
1296
- event_type: str,
1297
- timeout_at: datetime | None = None,
1905
+ worker_id: str,
1906
+ timer_id: str,
1907
+ expires_at: datetime,
1908
+ activity_id: str | None = None,
1298
1909
  ) -> None:
1299
- """Register an event wait subscription."""
1300
- session = self._get_session_for_operation()
1301
- async with self._session_scope(session) as session:
1302
- subscription = WorkflowEventSubscription(
1303
- instance_id=instance_id,
1304
- event_type=event_type,
1305
- timeout_at=timeout_at,
1306
- )
1307
- session.add(subscription)
1308
- await self._commit_if_not_in_transaction(session)
1309
-
1310
- async def find_waiting_instances(self, event_type: str) -> list[dict[str, Any]]:
1311
- """Find workflow instances waiting for a specific event type."""
1312
- session = self._get_session_for_operation()
1313
- async with self._session_scope(session) as session:
1314
- result = await session.execute(
1315
- select(WorkflowEventSubscription).where(
1316
- and_(
1317
- WorkflowEventSubscription.event_type == event_type,
1318
- or_(
1319
- WorkflowEventSubscription.timeout_at.is_(None),
1320
- self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
1321
- > self._get_current_time_expr(),
1322
- ),
1323
- )
1324
- )
1325
- )
1326
- rows = result.scalars().all()
1327
-
1328
- return [
1329
- {
1330
- "id": row.id,
1331
- "instance_id": row.instance_id,
1332
- "event_type": row.event_type,
1333
- "activity_id": row.activity_id,
1334
- "timeout_at": row.timeout_at.isoformat() if row.timeout_at else None,
1335
- "created_at": row.created_at.isoformat(),
1336
- }
1337
- for row in rows
1338
- ]
1339
-
1340
- async def remove_event_subscription(
1341
- self,
1342
- instance_id: str,
1343
- event_type: str,
1344
- ) -> None:
1345
- """Remove event subscription after the event is received."""
1346
- session = self._get_session_for_operation()
1347
- async with self._session_scope(session) as session:
1348
- await session.execute(
1349
- delete(WorkflowEventSubscription).where(
1350
- and_(
1351
- WorkflowEventSubscription.instance_id == instance_id,
1352
- WorkflowEventSubscription.event_type == event_type,
1353
- )
1354
- )
1355
- )
1356
- await self._commit_if_not_in_transaction(session)
1357
-
1358
- async def cleanup_expired_subscriptions(self) -> int:
1359
- """Clean up event subscriptions that have timed out."""
1360
- session = self._get_session_for_operation()
1361
- async with self._session_scope(session) as session:
1362
- result = await session.execute(
1363
- delete(WorkflowEventSubscription).where(
1364
- and_(
1365
- WorkflowEventSubscription.timeout_at.isnot(None),
1366
- self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
1367
- <= self._get_current_time_expr(),
1368
- )
1369
- )
1370
- )
1371
- await self._commit_if_not_in_transaction(session)
1372
- return result.rowcount or 0 # type: ignore[attr-defined]
1373
-
1374
- async def find_expired_event_subscriptions(self) -> list[dict[str, Any]]:
1375
- """Find event subscriptions that have timed out."""
1376
- session = self._get_session_for_operation()
1377
- async with self._session_scope(session) as session:
1378
- result = await session.execute(
1379
- select(
1380
- WorkflowEventSubscription.instance_id,
1381
- WorkflowEventSubscription.event_type,
1382
- WorkflowEventSubscription.activity_id,
1383
- WorkflowEventSubscription.timeout_at,
1384
- WorkflowEventSubscription.created_at,
1385
- ).where(
1386
- and_(
1387
- WorkflowEventSubscription.timeout_at.isnot(None),
1388
- self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
1389
- <= self._get_current_time_expr(),
1390
- )
1391
- )
1392
- )
1393
- rows = result.all()
1394
-
1395
- return [
1396
- {
1397
- "instance_id": row[0],
1398
- "event_type": row[1],
1399
- "activity_id": row[2],
1400
- "timeout_at": row[3].isoformat() if row[3] else None,
1401
- "created_at": row[4].isoformat() if row[4] else None,
1402
- }
1403
- for row in rows
1404
- ]
1405
-
1406
- async def register_event_subscription_and_release_lock(
1407
- self,
1408
- instance_id: str,
1409
- worker_id: str,
1410
- event_type: str,
1411
- timeout_at: datetime | None = None,
1412
- activity_id: str | None = None,
1413
- ) -> None:
1414
- """
1415
- Atomically register event subscription and release workflow lock.
1416
-
1417
- This performs THREE operations in a SINGLE transaction:
1418
- 1. Register event subscription
1419
- 2. Update current activity
1420
- 3. Release lock
1421
-
1422
- This ensures distributed coroutines work correctly - when a workflow
1423
- calls wait_event(), the subscription is registered and lock is released
1424
- atomically, so ANY worker can resume the workflow when the event arrives.
1425
-
1426
- Note: Uses LOCK operation session (separate from external session).
1427
- """
1428
- session = self._get_session_for_operation(is_lock_operation=True)
1429
- async with self._session_scope(session) as session, session.begin():
1430
- # 1. Verify we hold the lock (sanity check)
1431
- result = await session.execute(
1432
- select(WorkflowInstance.locked_by).where(
1433
- WorkflowInstance.instance_id == instance_id
1434
- )
1435
- )
1436
- row = result.one_or_none()
1437
-
1438
- if row is None:
1439
- raise RuntimeError(f"Workflow instance {instance_id} not found")
1440
-
1441
- current_lock_holder = row[0]
1442
- if current_lock_holder != worker_id:
1443
- raise RuntimeError(
1444
- f"Cannot release lock: worker {worker_id} does not hold lock "
1445
- f"for {instance_id} (held by: {current_lock_holder})"
1446
- )
1447
-
1448
- # 2. Register event subscription (INSERT OR REPLACE equivalent)
1449
- # First delete existing
1450
- await session.execute(
1451
- delete(WorkflowEventSubscription).where(
1452
- and_(
1453
- WorkflowEventSubscription.instance_id == instance_id,
1454
- WorkflowEventSubscription.event_type == event_type,
1455
- )
1456
- )
1457
- )
1458
-
1459
- # Then insert new
1460
- subscription = WorkflowEventSubscription(
1461
- instance_id=instance_id,
1462
- event_type=event_type,
1463
- activity_id=activity_id,
1464
- timeout_at=timeout_at,
1465
- )
1466
- session.add(subscription)
1467
-
1468
- # 3. Update current activity (if provided)
1469
- if activity_id is not None:
1470
- await session.execute(
1471
- update(WorkflowInstance)
1472
- .where(WorkflowInstance.instance_id == instance_id)
1473
- .values(current_activity_id=activity_id, updated_at=func.now())
1474
- )
1475
-
1476
- # 4. Release lock
1477
- await session.execute(
1478
- update(WorkflowInstance)
1479
- .where(
1480
- and_(
1481
- WorkflowInstance.instance_id == instance_id,
1482
- WorkflowInstance.locked_by == worker_id,
1483
- )
1484
- )
1485
- .values(
1486
- locked_by=None,
1487
- locked_at=None,
1488
- updated_at=func.now(),
1489
- )
1490
- )
1491
-
1492
- async def register_timer_subscription_and_release_lock(
1493
- self,
1494
- instance_id: str,
1495
- worker_id: str,
1496
- timer_id: str,
1497
- expires_at: datetime,
1498
- activity_id: str | None = None,
1499
- ) -> None:
1500
- """
1501
- Atomically register timer subscription and release workflow lock.
1910
+ """
1911
+ Atomically register timer subscription and release workflow lock.
1502
1912
 
1503
1913
  This performs FOUR operations in a SINGLE transaction:
1504
1914
  1. Register timer subscription
@@ -1580,7 +1990,11 @@ class SQLAlchemyStorage:
1580
1990
  )
1581
1991
 
1582
1992
  async def find_expired_timers(self) -> list[dict[str, Any]]:
1583
- """Find timer subscriptions that have expired."""
1993
+ """Find timer subscriptions that have expired.
1994
+
1995
+ Returns timer info including workflow status to avoid N+1 queries.
1996
+ The SQL query already filters by status='waiting_for_timer'.
1997
+ """
1584
1998
  session = self._get_session_for_operation()
1585
1999
  async with self._session_scope(session) as session:
1586
2000
  result = await session.execute(
@@ -1590,6 +2004,7 @@ class SQLAlchemyStorage:
1590
2004
  WorkflowTimerSubscription.expires_at,
1591
2005
  WorkflowTimerSubscription.activity_id,
1592
2006
  WorkflowInstance.workflow_name,
2007
+ WorkflowInstance.status, # Include status to avoid N+1 query
1593
2008
  )
1594
2009
  .join(
1595
2010
  WorkflowInstance,
@@ -1612,6 +2027,7 @@ class SQLAlchemyStorage:
1612
2027
  "expires_at": row[2].isoformat(),
1613
2028
  "activity_id": row[3],
1614
2029
  "workflow_name": row[4],
2030
+ "status": row[5], # Always 'waiting_for_timer' due to WHERE clause
1615
2031
  }
1616
2032
  for row in rows
1617
2033
  ]
@@ -1866,6 +2282,7 @@ class SQLAlchemyStorage:
1866
2282
  "running",
1867
2283
  "waiting_for_event",
1868
2284
  "waiting_for_timer",
2285
+ "waiting_for_message",
1869
2286
  "compensating",
1870
2287
  ):
1871
2288
  # Already completed, failed, or cancelled
@@ -1890,14 +2307,6 @@ class SQLAlchemyStorage:
1890
2307
  )
1891
2308
  )
1892
2309
 
1893
- # Remove event subscriptions if waiting for event
1894
- if current_status == "waiting_for_event":
1895
- await session.execute(
1896
- delete(WorkflowEventSubscription).where(
1897
- WorkflowEventSubscription.instance_id == instance_id
1898
- )
1899
- )
1900
-
1901
2310
  # Remove timer subscriptions if waiting for timer
1902
2311
  if current_status == "waiting_for_timer":
1903
2312
  await session.execute(
@@ -1906,4 +2315,1249 @@ class SQLAlchemyStorage:
1906
2315
  )
1907
2316
  )
1908
2317
 
2318
+ # Remove message subscriptions if waiting for message
2319
+ await session.execute(
2320
+ delete(WorkflowMessageSubscription).where(
2321
+ WorkflowMessageSubscription.instance_id == instance_id
2322
+ )
2323
+ )
2324
+
1909
2325
  return True
2326
+
2327
+ # -------------------------------------------------------------------------
2328
+ # Message Subscription Methods
2329
+ # -------------------------------------------------------------------------
2330
+
2331
+ async def register_message_subscription_and_release_lock(
2332
+ self,
2333
+ instance_id: str,
2334
+ worker_id: str,
2335
+ channel: str,
2336
+ timeout_at: datetime | None = None,
2337
+ activity_id: str | None = None,
2338
+ ) -> None:
2339
+ """
2340
+ Atomically register a message subscription and release the workflow lock.
2341
+
2342
+ This is called when a workflow executes wait_message() and needs to:
2343
+ 1. Verify lock ownership (RuntimeError if mismatch - full rollback)
2344
+ 2. Register a subscription for the channel
2345
+ 3. Update the workflow status to waiting_for_message
2346
+ 4. Release the lock
2347
+
2348
+ All operations happen in a single transaction for atomicity.
2349
+
2350
+ Args:
2351
+ instance_id: Workflow instance ID
2352
+ worker_id: Worker ID that must hold the current lock (verified before release)
2353
+ channel: Channel name to subscribe to
2354
+ timeout_at: Optional absolute timeout time
2355
+ activity_id: Activity ID for the wait operation
2356
+
2357
+ Raises:
2358
+ RuntimeError: If the worker does not hold the lock (entire operation rolls back)
2359
+ """
2360
+ session = self._get_session_for_operation(is_lock_operation=True)
2361
+ async with self._session_scope(session) as session, session.begin():
2362
+ # 1. Verify we hold the lock (sanity check - fail fast if not)
2363
+ result = await session.execute(
2364
+ select(WorkflowInstance.locked_by).where(
2365
+ WorkflowInstance.instance_id == instance_id
2366
+ )
2367
+ )
2368
+ row = result.one_or_none()
2369
+
2370
+ if row is None:
2371
+ raise RuntimeError(f"Workflow instance {instance_id} not found")
2372
+
2373
+ current_lock_holder = row[0]
2374
+ if current_lock_holder != worker_id:
2375
+ raise RuntimeError(
2376
+ f"Cannot release lock: worker {worker_id} does not hold lock "
2377
+ f"for {instance_id} (held by: {current_lock_holder})"
2378
+ )
2379
+
2380
+ # 2. Register subscription (delete then insert for cross-database compatibility)
2381
+ await session.execute(
2382
+ delete(WorkflowMessageSubscription).where(
2383
+ and_(
2384
+ WorkflowMessageSubscription.instance_id == instance_id,
2385
+ WorkflowMessageSubscription.channel == channel,
2386
+ )
2387
+ )
2388
+ )
2389
+
2390
+ subscription = WorkflowMessageSubscription(
2391
+ instance_id=instance_id,
2392
+ channel=channel,
2393
+ activity_id=activity_id,
2394
+ timeout_at=timeout_at,
2395
+ )
2396
+ session.add(subscription)
2397
+
2398
+ # 3. Update workflow status and activity
2399
+ await session.execute(
2400
+ update(WorkflowInstance)
2401
+ .where(WorkflowInstance.instance_id == instance_id)
2402
+ .values(
2403
+ status="waiting_for_message",
2404
+ current_activity_id=activity_id,
2405
+ updated_at=func.now(),
2406
+ )
2407
+ )
2408
+
2409
+ # 4. Release the lock (with ownership check for extra safety)
2410
+ await session.execute(
2411
+ update(WorkflowInstance)
2412
+ .where(
2413
+ and_(
2414
+ WorkflowInstance.instance_id == instance_id,
2415
+ WorkflowInstance.locked_by == worker_id,
2416
+ )
2417
+ )
2418
+ .values(
2419
+ locked_by=None,
2420
+ locked_at=None,
2421
+ lock_expires_at=None,
2422
+ )
2423
+ )
2424
+
2425
+ async def find_waiting_instances_by_channel(self, channel: str) -> list[dict[str, Any]]:
2426
+ """
2427
+ Find all workflow instances waiting on a specific channel.
2428
+
2429
+ Args:
2430
+ channel: Channel name to search for
2431
+
2432
+ Returns:
2433
+ List of subscription info dicts with instance_id, channel, activity_id, timeout_at
2434
+ """
2435
+ session = self._get_session_for_operation()
2436
+ async with self._session_scope(session) as session:
2437
+ result = await session.execute(
2438
+ select(WorkflowMessageSubscription).where(
2439
+ WorkflowMessageSubscription.channel == channel
2440
+ )
2441
+ )
2442
+ subscriptions = result.scalars().all()
2443
+ return [
2444
+ {
2445
+ "instance_id": sub.instance_id,
2446
+ "channel": sub.channel,
2447
+ "activity_id": sub.activity_id,
2448
+ "timeout_at": sub.timeout_at.isoformat() if sub.timeout_at else None,
2449
+ "created_at": sub.created_at.isoformat() if sub.created_at else None,
2450
+ }
2451
+ for sub in subscriptions
2452
+ ]
2453
+
2454
+ async def remove_message_subscription(
2455
+ self,
2456
+ instance_id: str,
2457
+ channel: str,
2458
+ ) -> None:
2459
+ """
2460
+ Remove a message subscription.
2461
+
2462
+ This method removes from the legacy WorkflowMessageSubscription table
2463
+ and clears waiting state from the ChannelSubscription table.
2464
+
2465
+ Args:
2466
+ instance_id: Workflow instance ID
2467
+ channel: Channel name
2468
+ """
2469
+ session = self._get_session_for_operation()
2470
+ async with self._session_scope(session) as session:
2471
+ # Remove from legacy WorkflowMessageSubscription table
2472
+ await session.execute(
2473
+ delete(WorkflowMessageSubscription).where(
2474
+ and_(
2475
+ WorkflowMessageSubscription.instance_id == instance_id,
2476
+ WorkflowMessageSubscription.channel == channel,
2477
+ )
2478
+ )
2479
+ )
2480
+
2481
+ # Clear waiting state from ChannelSubscription table
2482
+ # Don't delete the subscription - just clear the waiting state
2483
+ await session.execute(
2484
+ update(ChannelSubscription)
2485
+ .where(
2486
+ and_(
2487
+ ChannelSubscription.instance_id == instance_id,
2488
+ ChannelSubscription.channel == channel,
2489
+ )
2490
+ )
2491
+ .values(activity_id=None, timeout_at=None)
2492
+ )
2493
+ await self._commit_if_not_in_transaction(session)
2494
+
2495
    async def deliver_message(
        self,
        instance_id: str,
        channel: str,
        data: dict[str, Any] | bytes,
        metadata: dict[str, Any],
        worker_id: str | None = None,
    ) -> dict[str, Any] | None:
        """
        Deliver a message to a waiting workflow using Lock-First pattern.

        This method:
        1. Checks if there's a subscription for this instance/channel
        2. Acquires lock (Lock-First pattern) - if worker_id provided
        3. Records the message in history
        4. Removes the subscription
        5. Updates status to 'running'
        6. Releases lock

        Args:
            instance_id: Target workflow instance ID
            channel: Channel name
            data: Message payload (dict or bytes)
            metadata: Message metadata
            worker_id: Worker ID for locking. If None, skip locking (legacy mode).

        Returns:
            Dict with delivery info if successful, None otherwise

        NOTE(review): for bytes payloads only the raw bytes are persisted
        (event_data_binary); the JSON envelope carrying id/channel/metadata is
        dropped - confirm consumers reconstruct that context elsewhere.
        """
        import uuid

        # Step 1: Check if subscription exists (without lock)
        session = self._get_session_for_operation()
        async with self._session_scope(session) as session:
            result = await session.execute(
                select(WorkflowMessageSubscription).where(
                    and_(
                        WorkflowMessageSubscription.instance_id == instance_id,
                        WorkflowMessageSubscription.channel == channel,
                    )
                )
            )
            subscription = result.scalar_one_or_none()

            if subscription is None:
                # Nobody is waiting on this channel for this instance
                return None

            activity_id = subscription.activity_id

        # Step 2: Acquire lock (Lock-First pattern) if worker_id provided
        lock_acquired = False
        if worker_id is not None:
            lock_acquired = await self.try_acquire_lock(instance_id, worker_id)
            if not lock_acquired:
                # Another worker is processing this workflow
                return None

        try:
            # Step 3-5: Deliver message atomically
            session = self._get_session_for_operation()
            async with self._session_scope(session) as session:
                # Re-check subscription (may have been removed by another worker)
                result = await session.execute(
                    select(WorkflowMessageSubscription).where(
                        and_(
                            WorkflowMessageSubscription.instance_id == instance_id,
                            WorkflowMessageSubscription.channel == channel,
                        )
                    )
                )
                subscription = result.scalar_one_or_none()

                if subscription is None:
                    # Already delivered by another worker
                    return None

                activity_id = subscription.activity_id

                # Get workflow info for return value
                instance_result = await session.execute(
                    select(WorkflowInstance).where(WorkflowInstance.instance_id == instance_id)
                )
                instance = instance_result.scalar_one_or_none()
                workflow_name = instance.workflow_name if instance else "unknown"

                # Build message data for history
                message_id = str(uuid.uuid4())
                message_data = {
                    "id": message_id,
                    "channel": channel,
                    "data": data if isinstance(data, dict) else None,
                    "metadata": metadata,
                }

                # Handle binary data (message_data envelope is only used for
                # the JSON case - see NOTE(review) in the docstring)
                if isinstance(data, bytes):
                    data_type = "binary"
                    event_data_json = None
                    event_data_binary = data
                else:
                    data_type = "json"
                    event_data_json = json.dumps(message_data)
                    event_data_binary = None

                # Record in history
                history_entry = WorkflowHistory(
                    instance_id=instance_id,
                    activity_id=activity_id,
                    event_type="ChannelMessageReceived",
                    data_type=data_type,
                    event_data=event_data_json,
                    event_data_binary=event_data_binary,
                )
                session.add(history_entry)

                # Remove subscription
                await session.execute(
                    delete(WorkflowMessageSubscription).where(
                        and_(
                            WorkflowMessageSubscription.instance_id == instance_id,
                            WorkflowMessageSubscription.channel == channel,
                        )
                    )
                )

                # Update status to 'running' (ready for resumption)
                await session.execute(
                    update(WorkflowInstance)
                    .where(WorkflowInstance.instance_id == instance_id)
                    .values(status="running", updated_at=func.now())
                )

                await self._commit_if_not_in_transaction(session)

                return {
                    "instance_id": instance_id,
                    "workflow_name": workflow_name,
                    "activity_id": activity_id,
                }

        finally:
            # Step 6: Release lock if we acquired it
            if lock_acquired and worker_id is not None:
                await self.release_lock(instance_id, worker_id)
2639
+
2640
+ async def find_expired_message_subscriptions(self) -> list[dict[str, Any]]:
2641
+ """
2642
+ Find all message subscriptions that have timed out.
2643
+
2644
+ This method queries both the legacy WorkflowMessageSubscription table and
2645
+ the ChannelSubscription table for expired subscriptions.
2646
+ JOINs with WorkflowInstance to ensure instance exists and avoid N+1 queries.
2647
+
2648
+ Returns:
2649
+ List of dicts with instance_id, channel, activity_id, timeout_at, created_at, workflow_name
2650
+ """
2651
+ session = self._get_session_for_operation()
2652
+ async with self._session_scope(session) as session:
2653
+ results: list[dict[str, Any]] = []
2654
+
2655
+ # Query legacy WorkflowMessageSubscription table with JOIN to verify instance exists
2656
+ legacy_result = await session.execute(
2657
+ select(
2658
+ WorkflowMessageSubscription.instance_id,
2659
+ WorkflowMessageSubscription.channel,
2660
+ WorkflowMessageSubscription.activity_id,
2661
+ WorkflowMessageSubscription.timeout_at,
2662
+ WorkflowMessageSubscription.created_at,
2663
+ WorkflowInstance.workflow_name,
2664
+ )
2665
+ .join(
2666
+ WorkflowInstance,
2667
+ WorkflowMessageSubscription.instance_id == WorkflowInstance.instance_id,
2668
+ )
2669
+ .where(
2670
+ and_(
2671
+ WorkflowMessageSubscription.timeout_at.isnot(None),
2672
+ self._make_datetime_comparable(WorkflowMessageSubscription.timeout_at)
2673
+ <= self._get_current_time_expr(),
2674
+ )
2675
+ )
2676
+ )
2677
+ legacy_rows = legacy_result.all()
2678
+ results.extend(
2679
+ [
2680
+ {
2681
+ "instance_id": row[0],
2682
+ "channel": row[1],
2683
+ "activity_id": row[2],
2684
+ "timeout_at": row[3],
2685
+ "created_at": row[4],
2686
+ "workflow_name": row[5],
2687
+ }
2688
+ for row in legacy_rows
2689
+ ]
2690
+ )
2691
+
2692
+ # Query ChannelSubscription table with JOIN
2693
+ channel_result = await session.execute(
2694
+ select(
2695
+ ChannelSubscription.instance_id,
2696
+ ChannelSubscription.channel,
2697
+ ChannelSubscription.activity_id,
2698
+ ChannelSubscription.timeout_at,
2699
+ ChannelSubscription.subscribed_at,
2700
+ WorkflowInstance.workflow_name,
2701
+ )
2702
+ .join(
2703
+ WorkflowInstance,
2704
+ ChannelSubscription.instance_id == WorkflowInstance.instance_id,
2705
+ )
2706
+ .where(
2707
+ and_(
2708
+ ChannelSubscription.timeout_at.isnot(None),
2709
+ ChannelSubscription.activity_id.isnot(None), # Only waiting subscriptions
2710
+ self._make_datetime_comparable(ChannelSubscription.timeout_at)
2711
+ <= self._get_current_time_expr(),
2712
+ )
2713
+ )
2714
+ )
2715
+ channel_rows = channel_result.all()
2716
+ results.extend(
2717
+ [
2718
+ {
2719
+ "instance_id": row[0],
2720
+ "channel": row[1],
2721
+ "activity_id": row[2],
2722
+ "timeout_at": row[3],
2723
+ "created_at": row[4], # subscribed_at as created_at for compatibility
2724
+ "workflow_name": row[5],
2725
+ }
2726
+ for row in channel_rows
2727
+ ]
2728
+ )
2729
+
2730
+ return results
2731
+
2732
+ # -------------------------------------------------------------------------
2733
+ # Group Membership Methods (Erlang pg style)
2734
+ # -------------------------------------------------------------------------
2735
+
2736
+ async def join_group(self, instance_id: str, group_name: str) -> None:
2737
+ """
2738
+ Add a workflow instance to a group.
2739
+
2740
+ Groups provide loose coupling for broadcast messaging.
2741
+ Idempotent - joining a group the instance is already in is a no-op.
2742
+
2743
+ Args:
2744
+ instance_id: Workflow instance ID
2745
+ group_name: Group to join
2746
+ """
2747
+ session = self._get_session_for_operation()
2748
+ async with self._session_scope(session) as session:
2749
+ # Check if already a member (for idempotency)
2750
+ result = await session.execute(
2751
+ select(WorkflowGroupMembership).where(
2752
+ and_(
2753
+ WorkflowGroupMembership.instance_id == instance_id,
2754
+ WorkflowGroupMembership.group_name == group_name,
2755
+ )
2756
+ )
2757
+ )
2758
+ if result.scalar_one_or_none() is not None:
2759
+ # Already a member, nothing to do
2760
+ return
2761
+
2762
+ # Add membership
2763
+ membership = WorkflowGroupMembership(
2764
+ instance_id=instance_id,
2765
+ group_name=group_name,
2766
+ )
2767
+ session.add(membership)
2768
+ await self._commit_if_not_in_transaction(session)
2769
+
2770
+ async def leave_group(self, instance_id: str, group_name: str) -> None:
2771
+ """
2772
+ Remove a workflow instance from a group.
2773
+
2774
+ Args:
2775
+ instance_id: Workflow instance ID
2776
+ group_name: Group to leave
2777
+ """
2778
+ session = self._get_session_for_operation()
2779
+ async with self._session_scope(session) as session:
2780
+ await session.execute(
2781
+ delete(WorkflowGroupMembership).where(
2782
+ and_(
2783
+ WorkflowGroupMembership.instance_id == instance_id,
2784
+ WorkflowGroupMembership.group_name == group_name,
2785
+ )
2786
+ )
2787
+ )
2788
+ await self._commit_if_not_in_transaction(session)
2789
+
2790
+ async def get_group_members(self, group_name: str) -> list[str]:
2791
+ """
2792
+ Get all workflow instances in a group.
2793
+
2794
+ Args:
2795
+ group_name: Group name
2796
+
2797
+ Returns:
2798
+ List of instance IDs in the group
2799
+ """
2800
+ session = self._get_session_for_operation()
2801
+ async with self._session_scope(session) as session:
2802
+ result = await session.execute(
2803
+ select(WorkflowGroupMembership.instance_id).where(
2804
+ WorkflowGroupMembership.group_name == group_name
2805
+ )
2806
+ )
2807
+ return [row[0] for row in result.fetchall()]
2808
+
2809
+ async def leave_all_groups(self, instance_id: str) -> None:
2810
+ """
2811
+ Remove a workflow instance from all groups.
2812
+
2813
+ Called automatically when a workflow completes or fails.
2814
+
2815
+ Args:
2816
+ instance_id: Workflow instance ID
2817
+ """
2818
+ session = self._get_session_for_operation()
2819
+ async with self._session_scope(session) as session:
2820
+ await session.execute(
2821
+ delete(WorkflowGroupMembership).where(
2822
+ WorkflowGroupMembership.instance_id == instance_id
2823
+ )
2824
+ )
2825
+ await self._commit_if_not_in_transaction(session)
2826
+
2827
+ # -------------------------------------------------------------------------
2828
+ # Workflow Resumption Methods
2829
+ # -------------------------------------------------------------------------
2830
+
2831
+ async def find_resumable_workflows(self) -> list[dict[str, Any]]:
2832
+ """
2833
+ Find workflows that are ready to be resumed.
2834
+
2835
+ Returns workflows with status='running' that don't have an active lock.
2836
+ Used for immediate resumption after message delivery.
2837
+
2838
+ Returns:
2839
+ List of resumable workflows with instance_id and workflow_name.
2840
+ """
2841
+ session = self._get_session_for_operation()
2842
+ async with self._session_scope(session) as session:
2843
+ result = await session.execute(
2844
+ select(
2845
+ WorkflowInstance.instance_id,
2846
+ WorkflowInstance.workflow_name,
2847
+ ).where(
2848
+ and_(
2849
+ WorkflowInstance.status == "running",
2850
+ WorkflowInstance.locked_by.is_(None),
2851
+ )
2852
+ )
2853
+ )
2854
+ return [
2855
+ {
2856
+ "instance_id": row.instance_id,
2857
+ "workflow_name": row.workflow_name,
2858
+ }
2859
+ for row in result.fetchall()
2860
+ ]
2861
+
2862
+ # -------------------------------------------------------------------------
2863
+ # Subscription Cleanup Methods (for recur())
2864
+ # -------------------------------------------------------------------------
2865
+
2866
+ async def cleanup_instance_subscriptions(self, instance_id: str) -> None:
2867
+ """
2868
+ Remove all subscriptions for a workflow instance.
2869
+
2870
+ Called during recur() to clean up event/timer/message subscriptions
2871
+ before archiving the history.
2872
+
2873
+ Args:
2874
+ instance_id: Workflow instance ID to clean up
2875
+ """
2876
+ session = self._get_session_for_operation()
2877
+ async with self._session_scope(session) as session:
2878
+ # Remove timer subscriptions
2879
+ await session.execute(
2880
+ delete(WorkflowTimerSubscription).where(
2881
+ WorkflowTimerSubscription.instance_id == instance_id
2882
+ )
2883
+ )
2884
+
2885
+ # Remove message subscriptions (legacy)
2886
+ await session.execute(
2887
+ delete(WorkflowMessageSubscription).where(
2888
+ WorkflowMessageSubscription.instance_id == instance_id
2889
+ )
2890
+ )
2891
+
2892
+ # Remove channel subscriptions
2893
+ await session.execute(
2894
+ delete(ChannelSubscription).where(ChannelSubscription.instance_id == instance_id)
2895
+ )
2896
+
2897
+ # Remove channel message claims
2898
+ await session.execute(
2899
+ delete(ChannelMessageClaim).where(ChannelMessageClaim.instance_id == instance_id)
2900
+ )
2901
+
2902
+ await self._commit_if_not_in_transaction(session)
2903
+
2904
+ # -------------------------------------------------------------------------
2905
+ # Channel-based Message Queue Methods
2906
+ # -------------------------------------------------------------------------
2907
+
2908
+ async def publish_to_channel(
2909
+ self,
2910
+ channel: str,
2911
+ data: dict[str, Any] | bytes,
2912
+ metadata: dict[str, Any] | None = None,
2913
+ ) -> str:
2914
+ """
2915
+ Publish a message to a channel.
2916
+
2917
+ Messages are persisted to channel_messages and available for subscribers.
2918
+
2919
+ Args:
2920
+ channel: Channel name
2921
+ data: Message payload (dict or bytes)
2922
+ metadata: Optional message metadata
2923
+
2924
+ Returns:
2925
+ Generated message_id (UUID)
2926
+ """
2927
+ import uuid
2928
+
2929
+ message_id = str(uuid.uuid4())
2930
+
2931
+ # Determine data type and serialize
2932
+ if isinstance(data, bytes):
2933
+ data_type = "binary"
2934
+ data_json = None
2935
+ data_binary = data
2936
+ else:
2937
+ data_type = "json"
2938
+ data_json = json.dumps(data)
2939
+ data_binary = None
2940
+
2941
+ metadata_json = json.dumps(metadata) if metadata else None
2942
+
2943
+ session = self._get_session_for_operation()
2944
+ async with self._session_scope(session) as session:
2945
+ msg = ChannelMessage(
2946
+ channel=channel,
2947
+ message_id=message_id,
2948
+ data_type=data_type,
2949
+ data=data_json,
2950
+ data_binary=data_binary,
2951
+ message_metadata=metadata_json,
2952
+ )
2953
+ session.add(msg)
2954
+ await self._commit_if_not_in_transaction(session)
2955
+
2956
+ return message_id
2957
+
2958
+ async def subscribe_to_channel(
2959
+ self,
2960
+ instance_id: str,
2961
+ channel: str,
2962
+ mode: str,
2963
+ ) -> None:
2964
+ """
2965
+ Subscribe a workflow instance to a channel.
2966
+
2967
+ Args:
2968
+ instance_id: Workflow instance ID
2969
+ channel: Channel name
2970
+ mode: 'broadcast' or 'competing'
2971
+
2972
+ Raises:
2973
+ ValueError: If mode is invalid
2974
+ """
2975
+ if mode not in ("broadcast", "competing"):
2976
+ raise ValueError(f"Invalid mode: {mode}. Must be 'broadcast' or 'competing'")
2977
+
2978
+ session = self._get_session_for_operation()
2979
+ async with self._session_scope(session) as session:
2980
+ # Check if already subscribed
2981
+ result = await session.execute(
2982
+ select(ChannelSubscription).where(
2983
+ and_(
2984
+ ChannelSubscription.instance_id == instance_id,
2985
+ ChannelSubscription.channel == channel,
2986
+ )
2987
+ )
2988
+ )
2989
+ existing = result.scalar_one_or_none()
2990
+
2991
+ if existing is not None:
2992
+ # Already subscribed, update mode if different
2993
+ if existing.mode != mode:
2994
+ existing.mode = mode
2995
+ await self._commit_if_not_in_transaction(session)
2996
+ return
2997
+
2998
+ # For broadcast mode, set cursor to current max message id
2999
+ # So subscriber only sees messages published after subscription
3000
+ cursor_message_id = None
3001
+ if mode == "broadcast":
3002
+ result = await session.execute(
3003
+ select(func.max(ChannelMessage.id)).where(ChannelMessage.channel == channel)
3004
+ )
3005
+ max_id = result.scalar()
3006
+ cursor_message_id = max_id if max_id is not None else 0
3007
+
3008
+ # Create subscription
3009
+ subscription = ChannelSubscription(
3010
+ instance_id=instance_id,
3011
+ channel=channel,
3012
+ mode=mode,
3013
+ cursor_message_id=cursor_message_id,
3014
+ )
3015
+ session.add(subscription)
3016
+ await self._commit_if_not_in_transaction(session)
3017
+
3018
+ async def unsubscribe_from_channel(
3019
+ self,
3020
+ instance_id: str,
3021
+ channel: str,
3022
+ ) -> None:
3023
+ """
3024
+ Unsubscribe a workflow instance from a channel.
3025
+
3026
+ Args:
3027
+ instance_id: Workflow instance ID
3028
+ channel: Channel name
3029
+ """
3030
+ session = self._get_session_for_operation()
3031
+ async with self._session_scope(session) as session:
3032
+ await session.execute(
3033
+ delete(ChannelSubscription).where(
3034
+ and_(
3035
+ ChannelSubscription.instance_id == instance_id,
3036
+ ChannelSubscription.channel == channel,
3037
+ )
3038
+ )
3039
+ )
3040
+ await self._commit_if_not_in_transaction(session)
3041
+
3042
+ async def get_channel_subscription(
3043
+ self,
3044
+ instance_id: str,
3045
+ channel: str,
3046
+ ) -> dict[str, Any] | None:
3047
+ """
3048
+ Get the subscription info for a workflow instance on a channel.
3049
+
3050
+ Args:
3051
+ instance_id: Workflow instance ID
3052
+ channel: Channel name
3053
+
3054
+ Returns:
3055
+ Subscription info dict with: mode, activity_id, cursor_message_id
3056
+ or None if not subscribed
3057
+ """
3058
+ session = self._get_session_for_operation()
3059
+ async with self._session_scope(session) as session:
3060
+ result = await session.execute(
3061
+ select(ChannelSubscription).where(
3062
+ and_(
3063
+ ChannelSubscription.instance_id == instance_id,
3064
+ ChannelSubscription.channel == channel,
3065
+ )
3066
+ )
3067
+ )
3068
+ subscription = result.scalar_one_or_none()
3069
+
3070
+ if subscription is None:
3071
+ return None
3072
+
3073
+ return {
3074
+ "mode": subscription.mode,
3075
+ "activity_id": subscription.activity_id,
3076
+ "cursor_message_id": subscription.cursor_message_id,
3077
+ }
3078
+
3079
    async def register_channel_receive_and_release_lock(
        self,
        instance_id: str,
        worker_id: str,
        channel: str,
        activity_id: str | None = None,
        timeout_seconds: int | None = None,
    ) -> None:
        """
        Atomically register that workflow is waiting for channel message and release lock.

        Runs inside a single engine-level transaction (``engine.begin``) so the
        waiting-state update and the lock release commit together: other workers
        never observe the lock released without the subscription marked waiting.

        Args:
            instance_id: Workflow instance ID
            worker_id: Worker ID that currently holds the lock
            channel: Channel name being waited on
            activity_id: Current activity ID to record
            timeout_seconds: Optional timeout in seconds for the message wait

        Raises:
            RuntimeError: If the worker doesn't hold the lock
            ValueError: If workflow is not subscribed to the channel
        """
        async with self.engine.begin() as conn:
            # Session bound to the transaction's connection; the commit at the
            # bottom finishes the same transaction engine.begin() opened.
            session = AsyncSession(bind=conn, expire_on_commit=False)

            # Verify lock ownership before mutating anything — raising here
            # rolls the whole transaction back.
            result = await session.execute(
                select(WorkflowInstance).where(WorkflowInstance.instance_id == instance_id)
            )
            instance = result.scalar_one_or_none()

            if instance is None:
                raise RuntimeError(f"Instance not found: {instance_id}")

            if instance.locked_by != worker_id:
                raise RuntimeError(
                    f"Worker {worker_id} does not hold lock for {instance_id}. "
                    f"Locked by: {instance.locked_by}"
                )

            # Verify subscription exists
            sub_result = await session.execute(
                select(ChannelSubscription).where(
                    and_(
                        ChannelSubscription.instance_id == instance_id,
                        ChannelSubscription.channel == channel,
                    )
                )
            )
            subscription: ChannelSubscription | None = sub_result.scalar_one_or_none()

            if subscription is None:
                raise ValueError(f"Instance {instance_id} is not subscribed to channel {channel}")

            # Update subscription to mark as waiting
            current_time = datetime.now(UTC)
            subscription.activity_id = activity_id
            # Calculate timeout_at if timeout_seconds is provided; otherwise
            # explicitly clear any stale timeout left from an earlier wait.
            if timeout_seconds is not None:
                subscription.timeout_at = current_time + timedelta(seconds=timeout_seconds)
            else:
                subscription.timeout_at = None

            # Update instance: set activity, mark as waiting, and release the
            # lock (all three lock columns cleared) in one statement.
            await session.execute(
                update(WorkflowInstance)
                .where(WorkflowInstance.instance_id == instance_id)
                .values(
                    current_activity_id=activity_id,
                    status="waiting_for_message",
                    locked_by=None,
                    locked_at=None,
                    lock_expires_at=None,
                    updated_at=current_time,
                )
            )

            await session.commit()
3157
+
3158
+ async def get_pending_channel_messages(
3159
+ self,
3160
+ instance_id: str,
3161
+ channel: str,
3162
+ ) -> list[dict[str, Any]]:
3163
+ """
3164
+ Get pending messages for a subscriber on a channel.
3165
+
3166
+ For broadcast mode: messages with id > cursor_message_id
3167
+ For competing mode: unclaimed messages
3168
+
3169
+ Args:
3170
+ instance_id: Workflow instance ID
3171
+ channel: Channel name
3172
+
3173
+ Returns:
3174
+ List of pending messages
3175
+ """
3176
+ session = self._get_session_for_operation()
3177
+ async with self._session_scope(session) as session:
3178
+ # Get subscription info
3179
+ sub_result = await session.execute(
3180
+ select(ChannelSubscription).where(
3181
+ and_(
3182
+ ChannelSubscription.instance_id == instance_id,
3183
+ ChannelSubscription.channel == channel,
3184
+ )
3185
+ )
3186
+ )
3187
+ subscription = sub_result.scalar_one_or_none()
3188
+
3189
+ if subscription is None:
3190
+ return []
3191
+
3192
+ if subscription.mode == "broadcast":
3193
+ # Get messages after cursor
3194
+ cursor = subscription.cursor_message_id or 0
3195
+ msg_result = await session.execute(
3196
+ select(ChannelMessage)
3197
+ .where(
3198
+ and_(
3199
+ ChannelMessage.channel == channel,
3200
+ ChannelMessage.id > cursor,
3201
+ )
3202
+ )
3203
+ .order_by(ChannelMessage.published_at.asc())
3204
+ )
3205
+ else: # competing
3206
+ # Get unclaimed messages (not in channel_message_claims)
3207
+ subquery = select(ChannelMessageClaim.message_id)
3208
+ msg_result = await session.execute(
3209
+ select(ChannelMessage)
3210
+ .where(
3211
+ and_(
3212
+ ChannelMessage.channel == channel,
3213
+ ChannelMessage.message_id.not_in(subquery),
3214
+ )
3215
+ )
3216
+ .order_by(ChannelMessage.published_at.asc())
3217
+ )
3218
+
3219
+ messages = msg_result.scalars().all()
3220
+ return [
3221
+ {
3222
+ "id": msg.id,
3223
+ "message_id": msg.message_id,
3224
+ "channel": msg.channel,
3225
+ "data": (
3226
+ msg.data_binary
3227
+ if msg.data_type == "binary"
3228
+ else json.loads(msg.data) if msg.data else {}
3229
+ ),
3230
+ "metadata": json.loads(msg.message_metadata) if msg.message_metadata else {},
3231
+ "published_at": msg.published_at.isoformat() if msg.published_at else None,
3232
+ }
3233
+ for msg in messages
3234
+ ]
3235
+
3236
    async def claim_channel_message(
        self,
        message_id: str,
        instance_id: str,
    ) -> bool:
        """
        Claim a message for competing consumption.

        Uses INSERT with conflict check to ensure only one subscriber claims.

        NOTE(review): this is a check-then-insert, so two workers can both pass
        the SELECT and race to INSERT; the broad ``except`` below converts the
        loser's constraint violation into a ``False`` return. Presumably
        ``channel_message_claims.message_id`` carries a unique constraint that
        makes this safe — confirm against the table definition.

        Args:
            message_id: Message ID to claim
            instance_id: Workflow instance claiming the message

        Returns:
            True if claim succeeded, False if already claimed
        """
        session = self._get_session_for_operation()
        async with self._session_scope(session) as session:
            try:
                # Check if already claimed (fast path; also avoids an expected
                # constraint failure in the common already-claimed case)
                result = await session.execute(
                    select(ChannelMessageClaim).where(ChannelMessageClaim.message_id == message_id)
                )
                if result.scalar_one_or_none() is not None:
                    return False  # Already claimed

                claim = ChannelMessageClaim(
                    message_id=message_id,
                    instance_id=instance_id,
                )
                session.add(claim)
                await self._commit_if_not_in_transaction(session)
                return True
            except Exception:
                # Deliberate best-effort: any failure (e.g. a concurrent claim
                # winning the race) is reported as "not claimed".
                return False
3272
+
3273
+ async def delete_channel_message(self, message_id: str) -> None:
3274
+ """
3275
+ Delete a message from the channel queue.
3276
+
3277
+ Args:
3278
+ message_id: Message ID to delete
3279
+ """
3280
+ session = self._get_session_for_operation()
3281
+ async with self._session_scope(session) as session:
3282
+ # Delete claim first (foreign key)
3283
+ await session.execute(
3284
+ delete(ChannelMessageClaim).where(ChannelMessageClaim.message_id == message_id)
3285
+ )
3286
+ # Delete message
3287
+ await session.execute(
3288
+ delete(ChannelMessage).where(ChannelMessage.message_id == message_id)
3289
+ )
3290
+ await self._commit_if_not_in_transaction(session)
3291
+
3292
+ async def update_delivery_cursor(
3293
+ self,
3294
+ channel: str,
3295
+ instance_id: str,
3296
+ message_id: int,
3297
+ ) -> None:
3298
+ """
3299
+ Update the delivery cursor for broadcast mode.
3300
+
3301
+ Args:
3302
+ channel: Channel name
3303
+ instance_id: Subscriber instance ID
3304
+ message_id: Last delivered message's internal ID
3305
+ """
3306
+ session = self._get_session_for_operation()
3307
+ async with self._session_scope(session) as session:
3308
+ # Update subscription cursor
3309
+ await session.execute(
3310
+ update(ChannelSubscription)
3311
+ .where(
3312
+ and_(
3313
+ ChannelSubscription.instance_id == instance_id,
3314
+ ChannelSubscription.channel == channel,
3315
+ )
3316
+ )
3317
+ .values(cursor_message_id=message_id)
3318
+ )
3319
+ await self._commit_if_not_in_transaction(session)
3320
+
3321
+ async def get_channel_subscribers_waiting(
3322
+ self,
3323
+ channel: str,
3324
+ ) -> list[dict[str, Any]]:
3325
+ """
3326
+ Get channel subscribers that are waiting (activity_id is set).
3327
+
3328
+ Args:
3329
+ channel: Channel name
3330
+
3331
+ Returns:
3332
+ List of waiting subscribers
3333
+ """
3334
+ session = self._get_session_for_operation()
3335
+ async with self._session_scope(session) as session:
3336
+ result = await session.execute(
3337
+ select(ChannelSubscription).where(
3338
+ and_(
3339
+ ChannelSubscription.channel == channel,
3340
+ ChannelSubscription.activity_id.isnot(None),
3341
+ )
3342
+ )
3343
+ )
3344
+ subscriptions = result.scalars().all()
3345
+ return [
3346
+ {
3347
+ "instance_id": sub.instance_id,
3348
+ "channel": sub.channel,
3349
+ "mode": sub.mode,
3350
+ "activity_id": sub.activity_id,
3351
+ }
3352
+ for sub in subscriptions
3353
+ ]
3354
+
3355
+ async def clear_channel_waiting_state(
3356
+ self,
3357
+ instance_id: str,
3358
+ channel: str,
3359
+ ) -> None:
3360
+ """
3361
+ Clear the waiting state for a channel subscription.
3362
+
3363
+ Args:
3364
+ instance_id: Workflow instance ID
3365
+ channel: Channel name
3366
+ """
3367
+ session = self._get_session_for_operation()
3368
+ async with self._session_scope(session) as session:
3369
+ await session.execute(
3370
+ update(ChannelSubscription)
3371
+ .where(
3372
+ and_(
3373
+ ChannelSubscription.instance_id == instance_id,
3374
+ ChannelSubscription.channel == channel,
3375
+ )
3376
+ )
3377
+ .values(activity_id=None)
3378
+ )
3379
+ await self._commit_if_not_in_transaction(session)
3380
+
3381
    async def deliver_channel_message(
        self,
        instance_id: str,
        channel: str,
        message_id: str,
        data: dict[str, Any] | bytes,
        metadata: dict[str, Any],
        worker_id: str,
    ) -> dict[str, Any] | None:
        """
        Deliver a channel message to a waiting workflow.

        Uses Lock-First pattern for distributed safety: the instance lock is
        acquired before any state is read or written, and released in the
        ``finally`` regardless of outcome. All delivery writes (history row,
        cursor/claim update, status change) commit in one transaction.

        Args:
            instance_id: Target workflow instance ID
            channel: Channel name
            message_id: Message ID being delivered
            data: Message payload
            metadata: Message metadata
            worker_id: Worker ID for locking

        Returns:
            Delivery info if successful, None if failed
        """
        try:
            # Try to acquire lock; losing the race is a normal outcome, not an
            # error, so just report non-delivery.
            if not await self.try_acquire_lock(instance_id, worker_id):
                logger.debug(f"Failed to acquire lock for {instance_id}")
                return None

            try:
                async with self.engine.begin() as conn:
                    session = AsyncSession(bind=conn, expire_on_commit=False)

                    # Get subscription info
                    result = await session.execute(
                        select(ChannelSubscription).where(
                            and_(
                                ChannelSubscription.instance_id == instance_id,
                                ChannelSubscription.channel == channel,
                            )
                        )
                    )
                    subscription = result.scalar_one_or_none()

                    # activity_id is None means the subscriber exists but is not
                    # currently blocked waiting for a message.
                    if subscription is None or subscription.activity_id is None:
                        logger.debug(f"No waiting subscription for {instance_id} on {channel}")
                        return None

                    activity_id = subscription.activity_id

                    # Get instance info for return value
                    result = await session.execute(
                        select(WorkflowInstance.workflow_name).where(
                            WorkflowInstance.instance_id == instance_id
                        )
                    )
                    row = result.one_or_none()
                    if row is None:
                        return None
                    workflow_name = row[0]

                    # Prepare message data for history
                    # Use "id" key to match what context.py expects when loading history
                    current_time = datetime.now(UTC)
                    message_result = {
                        "id": message_id,
                        "channel": channel,
                        # Binary payloads are stored separately below, so the
                        # JSON envelope carries None for them.
                        "data": data if isinstance(data, dict) else None,
                        "metadata": metadata,
                        "published_at": current_time.isoformat(),
                    }

                    # Record to history (binary payload goes to the blob
                    # column, JSON payload into event_data)
                    if isinstance(data, bytes):
                        history = WorkflowHistory(
                            instance_id=instance_id,
                            activity_id=activity_id,
                            event_type="ChannelMessageReceived",
                            data_type="binary",
                            event_data=None,
                            event_data_binary=data,
                        )
                    else:
                        history = WorkflowHistory(
                            instance_id=instance_id,
                            activity_id=activity_id,
                            event_type="ChannelMessageReceived",
                            data_type="json",
                            event_data=json.dumps(message_result),
                            event_data_binary=None,
                        )
                    session.add(history)

                    # Handle mode-specific logic
                    if subscription.mode == "broadcast":
                        # Broadcast leaves the message in place; advance this
                        # subscriber's cursor past the message's internal id.
                        result = await session.execute(
                            select(ChannelMessage.id).where(ChannelMessage.message_id == message_id)
                        )
                        msg_row = result.one_or_none()
                        if msg_row:
                            subscription.cursor_message_id = msg_row[0]
                    else:  # competing
                        # Claim and delete the message
                        claim = ChannelMessageClaim(
                            message_id=message_id,
                            instance_id=instance_id,
                        )
                        session.add(claim)

                        # Delete the message (competing mode consumes it)
                        await session.execute(
                            delete(ChannelMessage).where(ChannelMessage.message_id == message_id)
                        )

                    # Clear waiting state
                    subscription.activity_id = None

                    # Update instance status to running
                    current_time = datetime.now(UTC)
                    await session.execute(
                        update(WorkflowInstance)
                        .where(WorkflowInstance.instance_id == instance_id)
                        .values(
                            status="running",
                            updated_at=current_time,
                        )
                    )

                    await session.commit()

                    return {
                        "instance_id": instance_id,
                        "workflow_name": workflow_name,
                        "activity_id": activity_id,
                    }

            finally:
                # Always release lock
                await self.release_lock(instance_id, worker_id)

        except Exception as e:
            # Best-effort delivery: callers treat None as "not delivered" and
            # may retry; the error is logged, not propagated.
            logger.error(f"Error delivering channel message: {e}")
            return None
3527
+
3528
+ async def cleanup_old_channel_messages(self, older_than_days: int = 7) -> int:
3529
+ """
3530
+ Clean up old messages from channel queues.
3531
+
3532
+ Args:
3533
+ older_than_days: Message retention period in days
3534
+
3535
+ Returns:
3536
+ Number of messages deleted
3537
+ """
3538
+ cutoff_time = datetime.now(UTC) - timedelta(days=older_than_days)
3539
+
3540
+ session = self._get_session_for_operation()
3541
+ async with self._session_scope(session) as session:
3542
+ # First delete claims for old messages
3543
+ await session.execute(
3544
+ delete(ChannelMessageClaim).where(
3545
+ ChannelMessageClaim.message_id.in_(
3546
+ select(ChannelMessage.message_id).where(
3547
+ self._make_datetime_comparable(ChannelMessage.published_at)
3548
+ < self._get_current_time_expr()
3549
+ )
3550
+ )
3551
+ )
3552
+ )
3553
+
3554
+ # Delete old messages
3555
+ result = await session.execute(
3556
+ delete(ChannelMessage)
3557
+ .where(ChannelMessage.published_at < cutoff_time)
3558
+ .returning(ChannelMessage.id)
3559
+ )
3560
+ deleted_ids = result.fetchall()
3561
+ await self._commit_if_not_in_transaction(session)
3562
+
3563
+ return len(deleted_ids)