edda-framework 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edda/__init__.py +39 -5
- edda/app.py +383 -223
- edda/channels.py +992 -0
- edda/compensation.py +22 -22
- edda/context.py +77 -51
- edda/integrations/opentelemetry/hooks.py +7 -2
- edda/locking.py +130 -67
- edda/replay.py +312 -82
- edda/storage/models.py +165 -24
- edda/storage/protocol.py +557 -118
- edda/storage/sqlalchemy_storage.py +1968 -314
- edda/viewer_ui/app.py +6 -1
- edda/viewer_ui/data_service.py +19 -22
- edda/workflow.py +43 -0
- {edda_framework-0.7.0.dist-info → edda_framework-0.8.0.dist-info}/METADATA +165 -9
- {edda_framework-0.7.0.dist-info → edda_framework-0.8.0.dist-info}/RECORD +19 -19
- edda/events.py +0 -505
- {edda_framework-0.7.0.dist-info → edda_framework-0.8.0.dist-info}/WHEEL +0 -0
- {edda_framework-0.7.0.dist-info → edda_framework-0.8.0.dist-info}/entry_points.txt +0 -0
- {edda_framework-0.7.0.dist-info → edda_framework-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -19,6 +19,7 @@ from sqlalchemy import (
     CheckConstraint,
     Column,
     DateTime,
+    ForeignKey,
     ForeignKeyConstraint,
     Index,
     Integer,
@@ -29,18 +30,21 @@ from sqlalchemy import (
     and_,
     delete,
     func,
+    inspect,
     or_,
     select,
     text,
     update,
 )
 from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
-from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 
 logger = logging.getLogger(__name__)
 
+
 # Declarative base for ORM models
-Base = declarative_base()
+class Base(DeclarativeBase):
+    pass
 
 
 # ============================================================================
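
The import swap above is the SQLAlchemy 2.0 migration in a nutshell: `declarative_base()` gives way to a `DeclarativeBase` subclass, and columns become typed `Mapped[...]` attributes. A minimal, self-contained sketch of that pattern (the `Example` model is illustrative, not part of edda):

```python
from datetime import datetime

from sqlalchemy import DateTime, Integer, String, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Example(Base):
    __tablename__ = "example"

    # Mapped[...] carries the Python type; Mapped[str | None] implies nullable,
    # which is why the old "# type: ignore[valid-type, misc]" comments disappear.
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str | None] = mapped_column(String(255), nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
```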
@@ -48,25 +52,25 @@ Base = declarative_base()
 # ============================================================================
 
 
-class SchemaVersion(Base):  # type: ignore[valid-type, misc]
+class SchemaVersion(Base):
     """Schema version tracking."""
 
     __tablename__ = "schema_version"
 
-    version =
-    applied_at =
-    description =
+    version: Mapped[int] = mapped_column(Integer, primary_key=True)
+    applied_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+    description: Mapped[str] = mapped_column(Text)
 
 
-class WorkflowDefinition(Base):  # type: ignore[valid-type, misc]
+class WorkflowDefinition(Base):
     """Workflow definition (source code storage)."""
 
     __tablename__ = "workflow_definitions"
 
-    workflow_name =
-    source_hash =
-    source_code =
-    created_at =
+    workflow_name: Mapped[str] = mapped_column(String(255), primary_key=True)
+    source_hash: Mapped[str] = mapped_column(String(64), primary_key=True)
+    source_code: Mapped[str] = mapped_column(Text)
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
 
     __table_args__ = (
         Index("idx_definitions_name", "workflow_name"),
@@ -74,31 +78,34 @@ class WorkflowDefinition(Base): # type: ignore[valid-type, misc]
     )
 
 
-class WorkflowInstance(Base):  # type: ignore[valid-type, misc]
+class WorkflowInstance(Base):
     """Workflow instance with distributed locking support."""
 
     __tablename__ = "workflow_instances"
 
-    instance_id =
-    workflow_name =
-    source_hash =
-    owner_service =
-    status =
-
-
-
+    instance_id: Mapped[str] = mapped_column(String(255), primary_key=True)
+    workflow_name: Mapped[str] = mapped_column(String(255))
+    source_hash: Mapped[str] = mapped_column(String(64))
+    owner_service: Mapped[str] = mapped_column(String(255))
+    status: Mapped[str] = mapped_column(String(50), server_default=text("'running'"))
+    current_activity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    continued_from: Mapped[str | None] = mapped_column(
+        String(255), ForeignKey("workflow_instances.instance_id"), nullable=True
     )
-
-
-
-        DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
+    started_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
     )
-    input_data =
-    output_data =
-    locked_by =
-    locked_at =
-    lock_timeout_seconds
-
+    input_data: Mapped[str] = mapped_column(Text)  # JSON
+    output_data: Mapped[str | None] = mapped_column(Text, nullable=True)  # JSON
+    locked_by: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    locked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+    lock_timeout_seconds: Mapped[int | None] = mapped_column(
+        Integer, nullable=True
+    )  # None = use global default (300s)
+    lock_expires_at: Mapped[datetime | None] = mapped_column(
+        DateTime(timezone=True), nullable=True
+    )  # Absolute expiry time
 
     __table_args__ = (
         ForeignKeyConstraint(
@@ -107,7 +114,7 @@ class WorkflowInstance(Base): # type: ignore[valid-type, misc]
         ),
         CheckConstraint(
             "status IN ('running', 'completed', 'failed', 'waiting_for_event', "
-            "'waiting_for_timer', 'compensating', 'cancelled')",
+            "'waiting_for_timer', 'waiting_for_message', 'compensating', 'cancelled', 'recurred')",
             name="valid_status",
         ),
         Index("idx_instances_status", "status"),
@@ -117,25 +124,29 @@ class WorkflowInstance(Base): # type: ignore[valid-type, misc]
         Index("idx_instances_lock_expires", "lock_expires_at"),
         Index("idx_instances_updated", "updated_at"),
         Index("idx_instances_hash", "source_hash"),
+        Index("idx_instances_continued_from", "continued_from"),
+        # Composite index for find_resumable_workflows(): WHERE status='running' AND locked_by IS NULL
+        Index("idx_instances_resumable", "status", "locked_by"),
     )
 
 
-class WorkflowHistory(Base):  # type: ignore[valid-type, misc]
+class WorkflowHistory(Base):
     """Workflow execution history (for deterministic replay)."""
 
     __tablename__ = "workflow_history"
 
-    id =
-    instance_id =
-
-
-    )
-
-
-
-
-
-
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    activity_id: Mapped[str] = mapped_column(String(255))
+    event_type: Mapped[str] = mapped_column(String(100))
+    data_type: Mapped[str] = mapped_column(String(10))  # 'json' or 'binary'
+    event_data: Mapped[str | None] = mapped_column(
+        Text, nullable=True
+    )  # JSON (when data_type='json')
+    event_data_binary: Mapped[bytes | None] = mapped_column(
+        LargeBinary, nullable=True
+    )  # Binary (when data_type='binary')
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
 
     __table_args__ = (
         ForeignKeyConstraint(
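
`WorkflowHistory` now splits payloads across `event_data` (JSON text) and `event_data_binary`, discriminated by `data_type`. A hedged sketch of how a row would be populated under that scheme (the helper and payloads are illustrative, not edda API):

```python
import json


def make_history_row(instance_id: str, activity_id: str, event_type: str, payload):
    """Illustrative helper: route a payload to the right history column."""
    if isinstance(payload, bytes):
        # Binary payloads land in event_data_binary; event_data stays NULL.
        return WorkflowHistory(
            instance_id=instance_id,
            activity_id=activity_id,
            event_type=event_type,
            data_type="binary",
            event_data=None,
            event_data_binary=payload,
        )
    # Anything JSON-serializable is stored as text in event_data.
    return WorkflowHistory(
        instance_id=instance_id,
        activity_id=activity_id,
        event_type=event_type,
        data_type="json",
        event_data=json.dumps(payload),
        event_data_binary=None,
    )
```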
@@ -158,20 +169,20 @@ class WorkflowHistory(Base): # type: ignore[valid-type, misc]
     )
 
 
-class WorkflowCompensation(Base):  # type: ignore[valid-type, misc]
-    """
+class WorkflowHistoryArchive(Base):
+    """Archived workflow execution history (for recur pattern)."""
 
-    __tablename__ = "
+    __tablename__ = "workflow_history_archive"
 
-    id =
-    instance_id =
-
-
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    activity_id: Mapped[str] = mapped_column(String(255))
+    event_type: Mapped[str] = mapped_column(String(100))
+    event_data: Mapped[str] = mapped_column(Text)  # JSON (includes both types for archive)
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
+    archived_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now()
     )
-    activity_id = Column(String(255), nullable=False)
-    activity_name = Column(String(255), nullable=False)
-    args = Column(Text, nullable=False)  # JSON
-    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
 
     __table_args__ = (
         ForeignKeyConstraint(
@@ -179,24 +190,22 @@ class WorkflowCompensation(Base): # type: ignore[valid-type, misc]
             ["workflow_instances.instance_id"],
             ondelete="CASCADE",
         ),
-        Index("
+        Index("idx_history_archive_instance", "instance_id"),
+        Index("idx_history_archive_archived", "archived_at"),
     )
 
 
-class WorkflowEventSubscription(Base):  # type: ignore[valid-type, misc]
-    """
+class WorkflowCompensation(Base):
+    """Compensation transactions (LIFO stack for Saga pattern)."""
 
-    __tablename__ = "
+    __tablename__ = "workflow_compensations"
 
-    id =
-    instance_id =
-
-
-    )
-
-    activity_id = Column(String(255), nullable=True)
-    timeout_at = Column(DateTime(timezone=True), nullable=True)
-    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    activity_id: Mapped[str] = mapped_column(String(255))
+    activity_name: Mapped[str] = mapped_column(String(255))
+    args: Mapped[str] = mapped_column(Text)  # JSON
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
 
     __table_args__ = (
         ForeignKeyConstraint(
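
The new composite index `idx_compensations_instance (instance_id, created_at)` is what makes reading a Saga's compensation stack cheap. A hedged sketch of a LIFO read under that index (session handling is illustrative; edda's own accessors may differ):

```python
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession


async def load_compensation_stack(session: AsyncSession, instance_id: str):
    # Newest-first ordering gives the LIFO pop order the Saga pattern needs;
    # the (instance_id, created_at) index serves this query directly.
    result = await session.execute(
        select(WorkflowCompensation)
        .where(WorkflowCompensation.instance_id == instance_id)
        .order_by(WorkflowCompensation.created_at.desc())
    )
    return result.scalars().all()
```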
@@ -204,27 +213,21 @@ class WorkflowEventSubscription(Base): # type: ignore[valid-type, misc]
             ["workflow_instances.instance_id"],
             ondelete="CASCADE",
         ),
-
-        Index("idx_subscriptions_event", "event_type"),
-        Index("idx_subscriptions_timeout", "timeout_at"),
-        Index("idx_subscriptions_instance", "instance_id"),
+        Index("idx_compensations_instance", "instance_id", "created_at"),
     )
 
 
-class WorkflowTimerSubscription(Base):  # type: ignore[valid-type, misc]
+class WorkflowTimerSubscription(Base):
     """Timer subscriptions (for wait_timer)."""
 
     __tablename__ = "workflow_timer_subscriptions"
 
-    id =
-    instance_id =
-
-
-    )
-
-    expires_at = Column(DateTime(timezone=True), nullable=False)
-    activity_id = Column(String(255), nullable=True)
-    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    timer_id: Mapped[str] = mapped_column(String(255))
+    expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
+    activity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
 
     __table_args__ = (
         ForeignKeyConstraint(
@@ -238,23 +241,29 @@ class WorkflowTimerSubscription(Base): # type: ignore[valid-type, misc]
     )
 
 
-class OutboxEvent(Base):  # type: ignore[valid-type, misc]
+class OutboxEvent(Base):
     """Transactional outbox pattern events."""
 
     __tablename__ = "outbox_events"
 
-    event_id =
-    event_type =
-    event_source =
-    data_type =
-    event_data
-
-
-
-
-
-
-
+    event_id: Mapped[str] = mapped_column(String(255), primary_key=True)
+    event_type: Mapped[str] = mapped_column(String(255))
+    event_source: Mapped[str] = mapped_column(String(255))
+    data_type: Mapped[str] = mapped_column(String(10))  # 'json' or 'binary'
+    event_data: Mapped[str | None] = mapped_column(
+        Text, nullable=True
+    )  # JSON (when data_type='json')
+    event_data_binary: Mapped[bytes | None] = mapped_column(
+        LargeBinary, nullable=True
+    )  # Binary (when data_type='binary')
+    content_type: Mapped[str] = mapped_column(
+        String(100), server_default=text("'application/json'")
+    )
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+    published_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+    status: Mapped[str] = mapped_column(String(50), server_default=text("'pending'"))
+    retry_count: Mapped[int] = mapped_column(Integer, server_default=text("0"))
+    last_error: Mapped[str | None] = mapped_column(Text, nullable=True)
 
     __table_args__ = (
         CheckConstraint(
@@ -276,6 +285,203 @@ class OutboxEvent(Base): # type: ignore[valid-type, misc]
     )
 
 
+class WorkflowMessageSubscription(Base):
+    """Message subscriptions (for wait_message)."""
+
+    __tablename__ = "workflow_message_subscriptions"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    channel: Mapped[str] = mapped_column(String(255))
+    activity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    timeout_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["instance_id"],
+            ["workflow_instances.instance_id"],
+            ondelete="CASCADE",
+        ),
+        UniqueConstraint("instance_id", "channel", name="unique_instance_channel"),
+        Index("idx_message_subscriptions_channel", "channel"),
+        Index("idx_message_subscriptions_timeout", "timeout_at"),
+        Index("idx_message_subscriptions_instance", "instance_id"),
+    )
+
+
+class WorkflowGroupMembership(Base):
+    """Group memberships (Erlang pg style for broadcast messaging)."""
+
+    __tablename__ = "workflow_group_memberships"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    group_name: Mapped[str] = mapped_column(String(255))
+    joined_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["instance_id"],
+            ["workflow_instances.instance_id"],
+            ondelete="CASCADE",
+        ),
+        UniqueConstraint("instance_id", "group_name", name="unique_instance_group"),
+        Index("idx_group_memberships_group", "group_name"),
+        Index("idx_group_memberships_instance", "instance_id"),
+    )
+
+
+# =============================================================================
+# Channel-based Message Queue Models
+# =============================================================================
+
+
+class ChannelMessage(Base):
+    """Channel message queue (persistent message storage)."""
+
+    __tablename__ = "channel_messages"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    channel: Mapped[str] = mapped_column(String(255))
+    message_id: Mapped[str] = mapped_column(String(255), unique=True)
+    data_type: Mapped[str] = mapped_column(String(10))  # 'json' or 'binary'
+    data: Mapped[str | None] = mapped_column(Text, nullable=True)  # JSON (when data_type='json')
+    data_binary: Mapped[bytes | None] = mapped_column(
+        LargeBinary, nullable=True
+    )  # Binary (when data_type='binary')
+    message_metadata: Mapped[str | None] = mapped_column(
+        "metadata", Text, nullable=True
+    )  # JSON - renamed to avoid SQLAlchemy reserved name
+    published_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now()
+    )
+
+    __table_args__ = (
+        CheckConstraint(
+            "data_type IN ('json', 'binary')",
+            name="channel_valid_data_type",
+        ),
+        CheckConstraint(
+            "(data_type = 'json' AND data IS NOT NULL AND data_binary IS NULL) OR "
+            "(data_type = 'binary' AND data IS NULL AND data_binary IS NOT NULL)",
+            name="channel_data_type_consistency",
+        ),
+        Index("idx_channel_messages_channel", "channel", "published_at"),
+        Index("idx_channel_messages_id", "id"),
+    )
+
+
+class ChannelSubscription(Base):
+    """Channel subscriptions for workflow instances."""
+
+    __tablename__ = "channel_subscriptions"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    channel: Mapped[str] = mapped_column(String(255))
+    mode: Mapped[str] = mapped_column(String(20))  # 'broadcast' or 'competing'
+    activity_id: Mapped[str | None] = mapped_column(
+        String(255), nullable=True
+    )  # Set when waiting for message
+    cursor_message_id: Mapped[int | None] = mapped_column(
+        Integer, nullable=True
+    )  # Last received message id (broadcast)
+    timeout_at: Mapped[datetime | None] = mapped_column(
+        DateTime(timezone=True), nullable=True
+    )  # Timeout deadline
+    subscribed_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now()
+    )
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["instance_id"],
+            ["workflow_instances.instance_id"],
+            ondelete="CASCADE",
+        ),
+        CheckConstraint(
+            "mode IN ('broadcast', 'competing')",
+            name="channel_valid_mode",
+        ),
+        UniqueConstraint("instance_id", "channel", name="unique_channel_instance_channel"),
+        Index("idx_channel_subscriptions_channel", "channel"),
+        Index("idx_channel_subscriptions_instance", "instance_id"),
+        Index("idx_channel_subscriptions_waiting", "channel", "activity_id"),
+        Index("idx_channel_subscriptions_timeout", "timeout_at"),
+    )
+
+
+class ChannelDeliveryCursor(Base):
+    """Channel delivery cursors (broadcast mode: track who read what)."""
+
+    __tablename__ = "channel_delivery_cursors"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    channel: Mapped[str] = mapped_column(String(255))
+    instance_id: Mapped[str] = mapped_column(String(255))
+    last_delivered_id: Mapped[int] = mapped_column(Integer)
+    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["instance_id"],
+            ["workflow_instances.instance_id"],
+            ondelete="CASCADE",
+        ),
+        UniqueConstraint("channel", "instance_id", name="unique_channel_delivery_cursor"),
+        Index("idx_channel_delivery_cursors_channel", "channel"),
+    )
+
+
+class ChannelMessageClaim(Base):
+    """Channel message claims (competing mode: who is processing what)."""
+
+    __tablename__ = "channel_message_claims"
+
+    message_id: Mapped[str] = mapped_column(String(255), primary_key=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    claimed_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["message_id"],
+            ["channel_messages.message_id"],
+            ondelete="CASCADE",
+        ),
+        ForeignKeyConstraint(
+            ["instance_id"],
+            ["workflow_instances.instance_id"],
+            ondelete="CASCADE",
+        ),
+        Index("idx_channel_message_claims_instance", "instance_id"),
+    )
+
+
+# =============================================================================
+# System-level Lock Models (for background task coordination)
+# =============================================================================
+
+
+class SystemLock(Base):
+    """System-level locks for coordinating background tasks across pods.
+
+    Used to prevent duplicate execution of operational tasks like:
+    - cleanup_stale_locks_periodically()
+    - auto_resume_stale_workflows_periodically()
+    - _cleanup_old_messages_periodically()
+    """
+
+    __tablename__ = "system_locks"
+
+    lock_name: Mapped[str] = mapped_column(String(255), primary_key=True)
+    locked_by: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    locked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+    lock_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+
+    __table_args__ = (Index("idx_system_locks_expires", "lock_expires_at"),)
+
+
 # Current schema version
 CURRENT_SCHEMA_VERSION = 1
 
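
The channel tables encode two delivery modes: broadcast readers advance a per-instance cursor (`ChannelDeliveryCursor`), while competing consumers race to claim a message (`ChannelMessageClaim`, whose `message_id` primary key allows at most one claimant). A hedged sketch of rows that satisfy the CHECK constraints above (the uuid usage and payload are illustrative):

```python
import json
import uuid

# channel_data_type_consistency requires exactly one of data / data_binary.
msg = ChannelMessage(
    channel="orders",
    message_id=str(uuid.uuid4()),
    data_type="json",
    data=json.dumps({"order_id": 42}),
    data_binary=None,  # must stay NULL when data_type='json'
)

# Competing mode: the primary key on message_id means the first INSERT wins
# and every other claimant gets an integrity error.
claim = ChannelMessageClaim(message_id=msg.message_id, instance_id="wf-123")
```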
@@ -337,11 +543,22 @@ class SQLAlchemyStorage:
         self.engine = engine
 
     async def initialize(self) -> None:
-        """Initialize database connection and create tables."""
+        """Initialize database connection and create tables.
+
+        This method creates all tables if they don't exist, and then performs
+        automatic schema migration to add any missing columns and update CHECK
+        constraints. This ensures compatibility when upgrading Edda versions.
+        """
         # Create all tables and indexes
         async with self.engine.begin() as conn:
             await conn.run_sync(Base.metadata.create_all)
 
+        # Auto-migrate schema (add missing columns)
+        await self._auto_migrate_schema()
+
+        # Auto-migrate CHECK constraints
+        await self._auto_migrate_check_constraints()
+
         # Initialize schema version
         await self._initialize_schema_version()
 
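
`initialize()` now bundles table creation, column auto-migration, CHECK-constraint migration, and schema-version bookkeeping into one call. A hedged bootstrap sketch, assuming the constructor simply takes the `AsyncEngine` (implied by `self.engine = engine` above, but not shown in full here):

```python
import asyncio

from sqlalchemy.ext.asyncio import create_async_engine


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///edda.db")
    storage = SQLAlchemyStorage(engine)  # assumed constructor signature
    # create_all, then column/constraint auto-migration, then version row
    await storage.initialize()


asyncio.run(main())
```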
@@ -366,6 +583,198 @@ class SQLAlchemyStorage:
                 await session.commit()
             logger.info(f"Initialized schema version to {CURRENT_SCHEMA_VERSION}")
 
+    async def _auto_migrate_schema(self) -> None:
+        """
+        Automatically add missing columns to existing tables.
+
+        This method compares the ORM model definitions with the actual database
+        schema and adds any missing columns using ALTER TABLE ADD COLUMN.
+
+        Note: This only handles column additions, not removals or type changes.
+        For complex migrations, use Alembic.
+        """
+
+        def _get_column_type_sql(column: Column, dialect_name: str) -> str:  # type: ignore[type-arg]
+            """Get SQL type string for a column based on dialect."""
+            col_type = column.type
+
+            # Map SQLAlchemy types to SQL types
+            if isinstance(col_type, String):
+                length = col_type.length or 255
+                return f"VARCHAR({length})"
+            elif isinstance(col_type, Text):
+                return "TEXT"
+            elif isinstance(col_type, Integer):
+                return "INTEGER"
+            elif isinstance(col_type, DateTime):
+                if dialect_name == "postgresql":
+                    return "TIMESTAMP WITH TIME ZONE" if col_type.timezone else "TIMESTAMP"
+                elif dialect_name == "mysql":
+                    return "DATETIME" if not col_type.timezone else "DATETIME"
+                else:  # sqlite
+                    return "DATETIME"
+            elif isinstance(col_type, LargeBinary):
+                if dialect_name == "postgresql":
+                    return "BYTEA"
+                elif dialect_name == "mysql":
+                    return "LONGBLOB"
+                else:  # sqlite
+                    return "BLOB"
+            else:
+                # Fallback to compiled type
+                return str(col_type.compile(dialect=self.engine.dialect))
+
+        def _get_default_sql(column: Column, _dialect_name: str) -> str | None:  # type: ignore[type-arg]
+            """Get DEFAULT clause for a column if applicable."""
+            if column.server_default is not None:
+                # Handle text() server defaults - try to get the arg attribute
+                server_default = column.server_default
+                if hasattr(server_default, "arg"):
+                    default_val = server_default.arg
+                    if hasattr(default_val, "text"):
+                        return f"DEFAULT {default_val.text}"
+                    return f"DEFAULT {default_val}"
+            return None
+
+        def _run_migration(conn: Any) -> None:
+            """Run migration in sync context."""
+            dialect_name = self.engine.dialect.name
+            inspector = inspect(conn)
+
+            # Iterate through all ORM tables
+            for table in Base.metadata.tables.values():
+                table_name = table.name
+
+                # Check if table exists
+                if not inspector.has_table(table_name):
+                    logger.debug(f"Table {table_name} does not exist, skipping migration")
+                    continue
+
+                # Get existing columns
+                existing_columns = {col["name"] for col in inspector.get_columns(table_name)}
+
+                # Check each column in the ORM model
+                for column in table.columns:
+                    if column.name not in existing_columns:
+                        # Column is missing, generate ALTER TABLE
+                        col_type_sql = _get_column_type_sql(column, dialect_name)
+                        nullable = "NULL" if column.nullable else "NOT NULL"
+
+                        # Build ALTER TABLE statement
+                        alter_sql = (
+                            f'ALTER TABLE "{table_name}" ADD COLUMN "{column.name}" {col_type_sql}'
+                        )
+
+                        # Add nullable constraint (only if NOT NULL and has default)
+                        default_sql = _get_default_sql(column, dialect_name)
+                        if not column.nullable and default_sql:
+                            alter_sql += f" {default_sql} {nullable}"
+                        elif column.nullable:
+                            alter_sql += f" {nullable}"
+                        elif default_sql:
+                            alter_sql += f" {default_sql}"
+                        # For NOT NULL without default, just add the column as nullable
+                        # (PostgreSQL requires default or nullable for existing rows)
+                        else:
+                            alter_sql += " NULL"
+
+                        logger.info(f"Auto-migrating: Adding column {column.name} to {table_name}")
+                        logger.debug(f"Executing: {alter_sql}")
+
+                        try:
+                            conn.execute(text(alter_sql))
+                        except Exception as e:
+                            logger.warning(
+                                f"Failed to add column {column.name} to {table_name}: {e}"
+                            )
+
+        async with self.engine.begin() as conn:
+            await conn.run_sync(_run_migration)
+
+    async def _auto_migrate_check_constraints(self) -> None:
+        """
+        Automatically update CHECK constraints for workflow status.
+
+        This method ensures the valid_status CHECK constraint includes all
+        required status values (including 'waiting_for_message').
+        """
+        dialect_name = self.engine.dialect.name
+
+        # SQLite doesn't support ALTER CONSTRAINT easily, and SQLAlchemy create_all
+        # handles it correctly for new databases. For existing SQLite databases,
+        # the constraint is more lenient (CHECK is not enforced in many SQLite versions).
+        if dialect_name == "sqlite":
+            return
+
+        # Expected status values (must match WorkflowInstance model)
+        expected_statuses = (
+            "'running', 'completed', 'failed', 'waiting_for_event', "
+            "'waiting_for_timer', 'waiting_for_message', 'compensating', 'cancelled', 'recurred'"
+        )
+
+        def _run_constraint_migration(conn: Any) -> None:
+            """Run CHECK constraint migration in sync context."""
+            inspector = inspect(conn)
+
+            # Check if workflow_instances table exists
+            if not inspector.has_table("workflow_instances"):
+                return
+
+            # Get existing CHECK constraints
+            try:
+                constraints = inspector.get_check_constraints("workflow_instances")
+            except NotImplementedError:
+                # Some databases don't support get_check_constraints
+                logger.debug("Database doesn't support get_check_constraints inspection")
+                constraints = []
+
+            # Find the valid_status constraint
+            valid_status_constraint = None
+            for constraint in constraints:
+                if constraint.get("name") == "valid_status":
+                    valid_status_constraint = constraint
+                    break
+
+            # Check if constraint exists and needs updating
+            if valid_status_constraint:
+                sqltext = valid_status_constraint.get("sqltext", "")
+                # Check if 'waiting_for_message' is already in the constraint
+                if "waiting_for_message" in sqltext:
+                    logger.debug("valid_status constraint already includes waiting_for_message")
+                    return
+
+                # Need to update the constraint
+                logger.info("Updating valid_status CHECK constraint to include waiting_for_message")
+                try:
+                    if dialect_name == "postgresql":
+                        conn.execute(
+                            text("ALTER TABLE workflow_instances DROP CONSTRAINT valid_status")
+                        )
+                        conn.execute(
+                            text(
+                                f"ALTER TABLE workflow_instances ADD CONSTRAINT valid_status "
+                                f"CHECK (status IN ({expected_statuses}))"
+                            )
+                        )
+                    elif dialect_name == "mysql":
+                        # MySQL uses DROP CHECK and ADD CONSTRAINT CHECK syntax
+                        conn.execute(text("ALTER TABLE workflow_instances DROP CHECK valid_status"))
+                        conn.execute(
+                            text(
+                                f"ALTER TABLE workflow_instances ADD CONSTRAINT valid_status "
+                                f"CHECK (status IN ({expected_statuses}))"
+                            )
+                        )
+                    logger.info("Successfully updated valid_status CHECK constraint")
+                except Exception as e:
+                    logger.warning(f"Failed to update valid_status CHECK constraint: {e}")
+            else:
+                # Constraint doesn't exist (shouldn't happen with create_all, but handle it)
+                logger.debug("valid_status constraint not found, will be created by create_all")
+
+        async with self.engine.begin() as conn:
+            await conn.run_sync(_run_constraint_migration)
+
     def _get_session_for_operation(self, is_lock_operation: bool = False) -> AsyncSession:
         """
         Get the appropriate session for an operation.
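
The migration inspects the live schema from inside `run_sync`, since an `AsyncEngine` exposes no synchronous `Inspector` directly. The core of that pattern, reduced to a standalone sketch (the table name is illustrative):

```python
from sqlalchemy import inspect
from sqlalchemy.ext.asyncio import AsyncEngine


async def existing_columns(engine: AsyncEngine, table: str) -> set[str]:
    def _collect(conn) -> set[str]:
        # inspect() on the sync connection yields a dialect-aware Inspector.
        inspector = inspect(conn)
        if not inspector.has_table(table):
            return set()
        return {col["name"] for col in inspector.get_columns(table)}

    async with engine.connect() as conn:
        return await conn.run_sync(_collect)
```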
@@ -452,7 +861,7 @@ class SQLAlchemyStorage:
         Example:
             >>> # SQLite: datetime(timeout_at) <= datetime('now')
             >>> # PostgreSQL/MySQL: timeout_at <= NOW()
-            >>> self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
+            >>> self._make_datetime_comparable(WorkflowMessageSubscription.timeout_at)
             >>> <= self._get_current_time_expr()
         """
         if self.engine.dialect.name == "sqlite":
@@ -605,7 +1014,7 @@ class SQLAlchemyStorage:
 
         if existing:
             # Update
-            existing.source_code = source_code
+            existing.source_code = source_code
         else:
             # Insert
             definition = WorkflowDefinition(
@@ -682,6 +1091,7 @@ class SQLAlchemyStorage:
         owner_service: str,
         input_data: dict[str, Any],
         lock_timeout_seconds: int | None = None,
+        continued_from: str | None = None,
     ) -> None:
         """Create a new workflow instance."""
         session = self._get_session_for_operation()
@@ -693,6 +1103,7 @@ class SQLAlchemyStorage:
                 owner_service=owner_service,
                 input_data=json.dumps(input_data),
                 lock_timeout_seconds=lock_timeout_seconds,
+                continued_from=continued_from,
             )
             session.add(instance)
 
@@ -1165,6 +1576,103 @@ class SQLAlchemyStorage:
             await session.commit()
             return workflows_to_resume
 
+    # -------------------------------------------------------------------------
+    # System-level Locking Methods (for background task coordination)
+    # -------------------------------------------------------------------------
+
+    async def try_acquire_system_lock(
+        self,
+        lock_name: str,
+        worker_id: str,
+        timeout_seconds: int = 60,
+    ) -> bool:
+        """
+        Try to acquire a system-level lock for coordinating background tasks.
+
+        Uses INSERT ON CONFLICT pattern to handle race conditions:
+        1. Try to INSERT new lock record
+        2. If exists, check if expired or unlocked
+        3. If available, acquire lock; otherwise return False
+
+        Note: ALWAYS uses separate session (not external session).
+        """
+        session = self._get_session_for_operation(is_lock_operation=True)
+        async with self._session_scope(session) as session:
+            current_time = datetime.now(UTC)
+            lock_expires_at = current_time + timedelta(seconds=timeout_seconds)
+
+            # Try to get existing lock
+            result = await session.execute(
+                select(SystemLock).where(SystemLock.lock_name == lock_name)
+            )
+            lock = result.scalar_one_or_none()
+
+            if lock is None:
+                # No lock exists - create new one
+                lock = SystemLock(
+                    lock_name=lock_name,
+                    locked_by=worker_id,
+                    locked_at=current_time,
+                    lock_expires_at=lock_expires_at,
+                )
+                session.add(lock)
+                await session.commit()
+                return True
+
+            # Lock exists - check if available
+            if lock.locked_by is None:
+                # Unlocked - acquire
+                lock.locked_by = worker_id
+                lock.locked_at = current_time
+                lock.lock_expires_at = lock_expires_at
+                await session.commit()
+                return True
+
+            # Check if expired (use SQL-side comparison for cross-DB compatibility)
+            if lock.lock_expires_at is not None:
+                # Handle timezone-naive datetime from SQLite
+                lock_expires = (
+                    lock.lock_expires_at.replace(tzinfo=UTC)
+                    if lock.lock_expires_at.tzinfo is None
+                    else lock.lock_expires_at
+                )
+                if lock_expires <= current_time:
+                    # Expired - acquire
+                    lock.locked_by = worker_id
+                    lock.locked_at = current_time
+                    lock.lock_expires_at = lock_expires_at
+                    await session.commit()
+                    return True
+
+            # Already locked by another worker
+            return False
+
+    async def release_system_lock(self, lock_name: str, worker_id: str) -> None:
+        """
+        Release a system-level lock.
+
+        Only releases the lock if it's held by the specified worker.
+
+        Note: ALWAYS uses separate session (not external session).
+        """
+        session = self._get_session_for_operation(is_lock_operation=True)
+        async with self._session_scope(session) as session:
+            await session.execute(
+                update(SystemLock)
+                .where(
+                    and_(
+                        SystemLock.lock_name == lock_name,
+                        SystemLock.locked_by == worker_id,
+                    )
+                )
+                .values(
+                    locked_by=None,
+                    locked_at=None,
+                    lock_expires_at=None,
+                )
+            )
+            await session.commit()
+
     # -------------------------------------------------------------------------
     # History Methods (prefer external session)
     # -------------------------------------------------------------------------
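
Together these two methods give background loops a lease-style mutex: acquire with a timeout, do the work, release in a `finally`. A hedged sketch of the intended call pattern (the lock name and cleanup body are illustrative):

```python
async def run_exclusive_cleanup(storage, worker_id: str) -> None:
    # Only one pod wins the lease per cycle; losers skip quietly. An expired
    # lease can be stolen, so timeout_seconds should exceed the work duration.
    if not await storage.try_acquire_system_lock(
        "cleanup_stale_locks", worker_id, timeout_seconds=60
    ):
        return
    try:
        ...  # the actual cleanup work goes here
    finally:
        await storage.release_system_lock("cleanup_stale_locks", worker_id)
```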
@@ -1231,6 +1739,107 @@ class SQLAlchemyStorage:
                 for row in rows
             ]
 
+    async def archive_history(self, instance_id: str) -> int:
+        """
+        Archive workflow history for the recur pattern.
+
+        Moves all history entries from workflow_history to workflow_history_archive.
+        Binary data is converted to base64 for JSON storage in the archive.
+
+        Returns:
+            Number of history entries archived
+        """
+        import base64
+
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Get all history entries for this instance
+            result = await session.execute(
+                select(WorkflowHistory)
+                .where(WorkflowHistory.instance_id == instance_id)
+                .order_by(WorkflowHistory.created_at.asc())
+            )
+            history_rows = result.scalars().all()
+
+            if not history_rows:
+                return 0
+
+            # Archive each history entry
+            for row in history_rows:
+                # Convert event_data to JSON string for archive
+                event_data_json: str | None
+                if row.data_type == "binary" and row.event_data_binary is not None:
+                    # Convert binary to base64 for JSON storage
+                    event_data_json = json.dumps(
+                        {
+                            "_binary": True,
+                            "data": base64.b64encode(row.event_data_binary).decode("ascii"),
+                        }
+                    )
+                else:
+                    # Already JSON, use as-is
+                    event_data_json = row.event_data
+
+                archive_entry = WorkflowHistoryArchive(
+                    instance_id=row.instance_id,
+                    activity_id=row.activity_id,
+                    event_type=row.event_type,
+                    event_data=event_data_json,
+                    created_at=row.created_at,
+                )
+                session.add(archive_entry)
+
+            # Delete original history entries
+            await session.execute(
+                delete(WorkflowHistory).where(WorkflowHistory.instance_id == instance_id)
+            )
+
+            await self._commit_if_not_in_transaction(session)
+            return len(history_rows)
+
+    async def find_first_cancellation_event(self, instance_id: str) -> dict[str, Any] | None:
+        """
+        Find the first cancellation event in workflow history.
+
+        Uses LIMIT 1 optimization to avoid loading all history events.
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Query for cancellation events using LIMIT 1
+            result = await session.execute(
+                select(WorkflowHistory)
+                .where(
+                    and_(
+                        WorkflowHistory.instance_id == instance_id,
+                        or_(
+                            WorkflowHistory.event_type == "WorkflowCancelled",
+                            func.lower(WorkflowHistory.event_type).contains("cancel"),
+                        ),
+                    )
+                )
+                .order_by(WorkflowHistory.created_at.asc())
+                .limit(1)
+            )
+            row = result.scalars().first()
+
+            if row is None:
+                return None
+
+            # Parse event_data based on data_type
+            if row.data_type == "binary" and row.event_data_binary is not None:
+                event_data: dict[str, Any] | bytes = row.event_data_binary
+            else:
+                event_data = json.loads(row.event_data) if row.event_data else {}
+
+            return {
+                "id": row.id,
+                "instance_id": row.instance_id,
+                "activity_id": row.activity_id,
+                "event_type": row.event_type,
+                "event_data": event_data,
+                "created_at": row.created_at,
+            }
+
     # -------------------------------------------------------------------------
     # Compensation Methods (prefer external session)
     # -------------------------------------------------------------------------
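
Because `archive_history()` wraps binary payloads in a `{"_binary": true, "data": <base64>}` envelope, readers of the archive need the inverse transform. A hedged decoding sketch grounded in that envelope format (the helper itself is illustrative, not edda API):

```python
import base64
import json


def decode_archived_event(event_data: str | None):
    """Illustrative inverse of the base64 envelope used by archive_history()."""
    if event_data is None:
        return None
    obj = json.loads(event_data)
    if isinstance(obj, dict) and obj.get("_binary"):
        return base64.b64decode(obj["data"])  # restore the original bytes
    return obj
```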
@@ -1271,7 +1880,7 @@ class SQLAlchemyStorage:
                     "instance_id": row.instance_id,
                     "activity_id": row.activity_id,
                     "activity_name": row.activity_name,
-                    "args": json.loads(row.args),
+                    "args": json.loads(row.args) if row.args else [],
                     "created_at": row.created_at.isoformat(),
                 }
                 for row in rows
@@ -1287,218 +1896,19 @@ class SQLAlchemyStorage:
             await self._commit_if_not_in_transaction(session)
 
     # -------------------------------------------------------------------------
-    #
+    # Timer Subscription Methods
     # -------------------------------------------------------------------------
 
-    async def
+    async def register_timer_subscription_and_release_lock(
         self,
         instance_id: str,
+        worker_id: str,
+        timer_id: str,
+        expires_at: datetime,
+        activity_id: str | None = None,
     ) -> None:
-        """
-        async with self._session_scope(session) as session:
-            subscription = WorkflowEventSubscription(
-                instance_id=instance_id,
-                event_type=event_type,
-                timeout_at=timeout_at,
-            )
-            session.add(subscription)
-            await self._commit_if_not_in_transaction(session)
-
-    async def find_waiting_instances(self, event_type: str) -> list[dict[str, Any]]:
-        """Find workflow instances waiting for a specific event type."""
-        session = self._get_session_for_operation()
-        async with self._session_scope(session) as session:
-            result = await session.execute(
-                select(WorkflowEventSubscription).where(
-                    and_(
-                        WorkflowEventSubscription.event_type == event_type,
-                        or_(
-                            WorkflowEventSubscription.timeout_at.is_(None),
-                            self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
-                            > self._get_current_time_expr(),
-                        ),
-                    )
-                )
-            )
-            rows = result.scalars().all()
-
-            return [
-                {
-                    "id": row.id,
-                    "instance_id": row.instance_id,
-                    "event_type": row.event_type,
-                    "activity_id": row.activity_id,
-                    "timeout_at": row.timeout_at.isoformat() if row.timeout_at else None,
-                    "created_at": row.created_at.isoformat(),
-                }
-                for row in rows
-            ]
-
-    async def remove_event_subscription(
-        self,
-        instance_id: str,
-        event_type: str,
-    ) -> None:
-        """Remove event subscription after the event is received."""
-        session = self._get_session_for_operation()
-        async with self._session_scope(session) as session:
-            await session.execute(
-                delete(WorkflowEventSubscription).where(
-                    and_(
-                        WorkflowEventSubscription.instance_id == instance_id,
-                        WorkflowEventSubscription.event_type == event_type,
-                    )
-                )
-            )
-            await self._commit_if_not_in_transaction(session)
-
-    async def cleanup_expired_subscriptions(self) -> int:
-        """Clean up event subscriptions that have timed out."""
-        session = self._get_session_for_operation()
-        async with self._session_scope(session) as session:
-            result = await session.execute(
-                delete(WorkflowEventSubscription).where(
-                    and_(
-                        WorkflowEventSubscription.timeout_at.isnot(None),
-                        self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
-                        <= self._get_current_time_expr(),
-                    )
-                )
-            )
-            await self._commit_if_not_in_transaction(session)
-            return result.rowcount or 0  # type: ignore[attr-defined]
-
-    async def find_expired_event_subscriptions(self) -> list[dict[str, Any]]:
-        """Find event subscriptions that have timed out."""
-        session = self._get_session_for_operation()
-        async with self._session_scope(session) as session:
-            result = await session.execute(
-                select(
-                    WorkflowEventSubscription.instance_id,
-                    WorkflowEventSubscription.event_type,
-                    WorkflowEventSubscription.activity_id,
-                    WorkflowEventSubscription.timeout_at,
-                    WorkflowEventSubscription.created_at,
-                ).where(
-                    and_(
-                        WorkflowEventSubscription.timeout_at.isnot(None),
-                        self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
-                        <= self._get_current_time_expr(),
-                    )
-                )
-            )
-            rows = result.all()
-
-            return [
-                {
-                    "instance_id": row[0],
-                    "event_type": row[1],
-                    "activity_id": row[2],
-                    "timeout_at": row[3].isoformat() if row[3] else None,
-                    "created_at": row[4].isoformat() if row[4] else None,
-                }
-                for row in rows
-            ]
-
-    async def register_event_subscription_and_release_lock(
-        self,
-        instance_id: str,
-        worker_id: str,
-        event_type: str,
-        timeout_at: datetime | None = None,
-        activity_id: str | None = None,
-    ) -> None:
-        """
-        Atomically register event subscription and release workflow lock.
-
-        This performs THREE operations in a SINGLE transaction:
-        1. Register event subscription
-        2. Update current activity
-        3. Release lock
-
-        This ensures distributed coroutines work correctly - when a workflow
-        calls wait_event(), the subscription is registered and lock is released
-        atomically, so ANY worker can resume the workflow when the event arrives.
-
-        Note: Uses LOCK operation session (separate from external session).
-        """
-        session = self._get_session_for_operation(is_lock_operation=True)
-        async with self._session_scope(session) as session, session.begin():
-            # 1. Verify we hold the lock (sanity check)
-            result = await session.execute(
-                select(WorkflowInstance.locked_by).where(
-                    WorkflowInstance.instance_id == instance_id
-                )
-            )
-            row = result.one_or_none()
-
-            if row is None:
-                raise RuntimeError(f"Workflow instance {instance_id} not found")
-
-            current_lock_holder = row[0]
-            if current_lock_holder != worker_id:
-                raise RuntimeError(
-                    f"Cannot release lock: worker {worker_id} does not hold lock "
-                    f"for {instance_id} (held by: {current_lock_holder})"
-                )
-
-            # 2. Register event subscription (INSERT OR REPLACE equivalent)
-            # First delete existing
-            await session.execute(
-                delete(WorkflowEventSubscription).where(
-                    and_(
-                        WorkflowEventSubscription.instance_id == instance_id,
-                        WorkflowEventSubscription.event_type == event_type,
-                    )
-                )
-            )
-
-            # Then insert new
-            subscription = WorkflowEventSubscription(
-                instance_id=instance_id,
-                event_type=event_type,
-                activity_id=activity_id,
-                timeout_at=timeout_at,
-            )
-            session.add(subscription)
-
-            # 3. Update current activity (if provided)
-            if activity_id is not None:
-                await session.execute(
-                    update(WorkflowInstance)
-                    .where(WorkflowInstance.instance_id == instance_id)
-                    .values(current_activity_id=activity_id, updated_at=func.now())
-                )
-
-            # 4. Release lock
-            await session.execute(
-                update(WorkflowInstance)
-                .where(
-                    and_(
-                        WorkflowInstance.instance_id == instance_id,
-                        WorkflowInstance.locked_by == worker_id,
-                    )
-                )
-                .values(
-                    locked_by=None,
-                    locked_at=None,
-                    updated_at=func.now(),
-                )
-            )
-
-    async def register_timer_subscription_and_release_lock(
-        self,
-        instance_id: str,
-        worker_id: str,
-        timer_id: str,
-        expires_at: datetime,
-        activity_id: str | None = None,
-    ) -> None:
-        """
-        Atomically register timer subscription and release workflow lock.
+        """
+        Atomically register timer subscription and release workflow lock.
 
         This performs FOUR operations in a SINGLE transaction:
         1. Register timer subscription
@@ -1580,7 +1990,11 @@ class SQLAlchemyStorage:
             )
 
     async def find_expired_timers(self) -> list[dict[str, Any]]:
-        """Find timer subscriptions that have expired."""
+        """Find timer subscriptions that have expired.
+
+        Returns timer info including workflow status to avoid N+1 queries.
+        The SQL query already filters by status='waiting_for_timer'.
+        """
        session = self._get_session_for_operation()
         async with self._session_scope(session) as session:
             result = await session.execute(
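
Since the join already constrains status to 'waiting_for_timer', callers can consume the returned dicts without a follow-up status query per instance. A hedged sketch of that consumption (only keys visible in this diff are used; the logging is illustrative):

```python
async def fire_expired_timers(storage) -> None:
    for timer in await storage.find_expired_timers():
        # "status" is included purely to avoid the N+1 lookup; the WHERE
        # clause guarantees its value.
        assert timer["status"] == "waiting_for_timer"
        print(f'{timer["workflow_name"]}: timer expired at {timer["expires_at"]}')
```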
@@ -1590,6 +2004,7 @@ class SQLAlchemyStorage:
                     WorkflowTimerSubscription.expires_at,
                     WorkflowTimerSubscription.activity_id,
                     WorkflowInstance.workflow_name,
+                    WorkflowInstance.status,  # Include status to avoid N+1 query
                 )
                 .join(
                     WorkflowInstance,
@@ -1612,6 +2027,7 @@ class SQLAlchemyStorage:
                     "expires_at": row[2].isoformat(),
                     "activity_id": row[3],
                     "workflow_name": row[4],
+                    "status": row[5],  # Always 'waiting_for_timer' due to WHERE clause
                 }
                 for row in rows
             ]
@@ -1866,6 +2282,7 @@ class SQLAlchemyStorage:
             "running",
             "waiting_for_event",
             "waiting_for_timer",
+            "waiting_for_message",
             "compensating",
         ):
             # Already completed, failed, or cancelled
@@ -1890,14 +2307,6 @@ class SQLAlchemyStorage:
             )
         )
 
-        # Remove event subscriptions if waiting for event
-        if current_status == "waiting_for_event":
-            await session.execute(
-                delete(WorkflowEventSubscription).where(
-                    WorkflowEventSubscription.instance_id == instance_id
-                )
-            )
-
         # Remove timer subscriptions if waiting for timer
         if current_status == "waiting_for_timer":
             await session.execute(
@@ -1906,4 +2315,1249 @@ class SQLAlchemyStorage:
|
|
|
1906
2315
|
)
|
|
1907
2316
|
)
|
|
1908
2317
|
|
|
2318
|
+
# Remove message subscriptions if waiting for message
|
|
2319
|
+
await session.execute(
|
|
2320
|
+
delete(WorkflowMessageSubscription).where(
|
|
2321
|
+
WorkflowMessageSubscription.instance_id == instance_id
|
|
2322
|
+
)
|
|
2323
|
+
)
|
|
2324
|
+
|
|
1909
2325
|
return True
|
|
2326
|
+
|
|
2327
|
+
# -------------------------------------------------------------------------
|
|
2328
|
+
# Message Subscription Methods
|
|
2329
|
+
# -------------------------------------------------------------------------
|
|
2330
|
+
|
|
2331
|
+
+    async def register_message_subscription_and_release_lock(
+        self,
+        instance_id: str,
+        worker_id: str,
+        channel: str,
+        timeout_at: datetime | None = None,
+        activity_id: str | None = None,
+    ) -> None:
+        """
+        Atomically register a message subscription and release the workflow lock.
+
+        This is called when a workflow executes wait_message() and needs to:
+        1. Verify lock ownership (RuntimeError if mismatch - full rollback)
+        2. Register a subscription for the channel
+        3. Update the workflow status to waiting_for_message
+        4. Release the lock
+
+        All operations happen in a single transaction for atomicity.
+
+        Args:
+            instance_id: Workflow instance ID
+            worker_id: Worker ID that must hold the current lock (verified before release)
+            channel: Channel name to subscribe to
+            timeout_at: Optional absolute timeout time
+            activity_id: Activity ID for the wait operation
+
+        Raises:
+            RuntimeError: If the worker does not hold the lock (entire operation rolls back)
+        """
+        session = self._get_session_for_operation(is_lock_operation=True)
+        async with self._session_scope(session) as session, session.begin():
+            # 1. Verify we hold the lock (sanity check - fail fast if not)
+            result = await session.execute(
+                select(WorkflowInstance.locked_by).where(
+                    WorkflowInstance.instance_id == instance_id
+                )
+            )
+            row = result.one_or_none()
+
+            if row is None:
+                raise RuntimeError(f"Workflow instance {instance_id} not found")
+
+            current_lock_holder = row[0]
+            if current_lock_holder != worker_id:
+                raise RuntimeError(
+                    f"Cannot release lock: worker {worker_id} does not hold lock "
+                    f"for {instance_id} (held by: {current_lock_holder})"
+                )
+
+            # 2. Register subscription (delete then insert for cross-database compatibility)
+            await session.execute(
+                delete(WorkflowMessageSubscription).where(
+                    and_(
+                        WorkflowMessageSubscription.instance_id == instance_id,
+                        WorkflowMessageSubscription.channel == channel,
+                    )
+                )
+            )
+
+            subscription = WorkflowMessageSubscription(
+                instance_id=instance_id,
+                channel=channel,
+                activity_id=activity_id,
+                timeout_at=timeout_at,
+            )
+            session.add(subscription)
+
+            # 3. Update workflow status and activity
+            await session.execute(
+                update(WorkflowInstance)
+                .where(WorkflowInstance.instance_id == instance_id)
+                .values(
+                    status="waiting_for_message",
+                    current_activity_id=activity_id,
+                    updated_at=func.now(),
+                )
+            )
+
+            # 4. Release the lock (with ownership check for extra safety)
+            await session.execute(
+                update(WorkflowInstance)
+                .where(
+                    and_(
+                        WorkflowInstance.instance_id == instance_id,
+                        WorkflowInstance.locked_by == worker_id,
+                    )
+                )
+                .values(
+                    locked_by=None,
+                    locked_at=None,
+                    lock_expires_at=None,
+                )
+            )
+
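The method above is the message-channel analogue of the timer registration it sits beside: one transaction covers the lock check, the subscription, the status flip, and the lock release. A caller sketch, with the channel name and activity id invented for illustration (`storage` is assumed to be an initialized `SQLAlchemyStorage`):

```python
from datetime import UTC, datetime, timedelta


async def park_on_channel(storage, instance_id: str, worker_id: str) -> None:
    # All four steps commit together; a RuntimeError means nothing was written.
    await storage.register_message_subscription_and_release_lock(
        instance_id=instance_id,
        worker_id=worker_id,
        channel="orders",                                     # illustrative
        timeout_at=datetime.now(UTC) + timedelta(minutes=5),  # optional deadline
        activity_id="wait_message:1",                         # illustrative
    )
```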
+    async def find_waiting_instances_by_channel(self, channel: str) -> list[dict[str, Any]]:
+        """
+        Find all workflow instances waiting on a specific channel.
+
+        Args:
+            channel: Channel name to search for
+
+        Returns:
+            List of subscription info dicts with instance_id, channel, activity_id, timeout_at
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(WorkflowMessageSubscription).where(
+                    WorkflowMessageSubscription.channel == channel
+                )
+            )
+            subscriptions = result.scalars().all()
+            return [
+                {
+                    "instance_id": sub.instance_id,
+                    "channel": sub.channel,
+                    "activity_id": sub.activity_id,
+                    "timeout_at": sub.timeout_at.isoformat() if sub.timeout_at else None,
+                    "created_at": sub.created_at.isoformat() if sub.created_at else None,
+                }
+                for sub in subscriptions
+            ]
+
+    async def remove_message_subscription(
+        self,
+        instance_id: str,
+        channel: str,
+    ) -> None:
+        """
+        Remove a message subscription.
+
+        This method removes from the legacy WorkflowMessageSubscription table
+        and clears waiting state from the ChannelSubscription table.
+
+        Args:
+            instance_id: Workflow instance ID
+            channel: Channel name
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Remove from legacy WorkflowMessageSubscription table
+            await session.execute(
+                delete(WorkflowMessageSubscription).where(
+                    and_(
+                        WorkflowMessageSubscription.instance_id == instance_id,
+                        WorkflowMessageSubscription.channel == channel,
+                    )
+                )
+            )
+
+            # Clear waiting state from ChannelSubscription table
+            # Don't delete the subscription - just clear the waiting state
+            await session.execute(
+                update(ChannelSubscription)
+                .where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+                .values(activity_id=None, timeout_at=None)
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def deliver_message(
+        self,
+        instance_id: str,
+        channel: str,
+        data: dict[str, Any] | bytes,
+        metadata: dict[str, Any],
+        worker_id: str | None = None,
+    ) -> dict[str, Any] | None:
+        """
+        Deliver a message to a waiting workflow using Lock-First pattern.
+
+        This method:
+        1. Checks if there's a subscription for this instance/channel
+        2. Acquires lock (Lock-First pattern) - if worker_id provided
+        3. Records the message in history
+        4. Removes the subscription
+        5. Updates status to 'running'
+        6. Releases lock
+
+        Args:
+            instance_id: Target workflow instance ID
+            channel: Channel name
+            data: Message payload (dict or bytes)
+            metadata: Message metadata
+            worker_id: Worker ID for locking. If None, skip locking (legacy mode).
+
+        Returns:
+            Dict with delivery info if successful, None otherwise
+        """
+        import uuid
+
+        # Step 1: Check if subscription exists (without lock)
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(WorkflowMessageSubscription).where(
+                    and_(
+                        WorkflowMessageSubscription.instance_id == instance_id,
+                        WorkflowMessageSubscription.channel == channel,
+                    )
+                )
+            )
+            subscription = result.scalar_one_or_none()
+
+            if subscription is None:
+                return None
+
+            activity_id = subscription.activity_id
+
+        # Step 2: Acquire lock (Lock-First pattern) if worker_id provided
+        lock_acquired = False
+        if worker_id is not None:
+            lock_acquired = await self.try_acquire_lock(instance_id, worker_id)
+            if not lock_acquired:
+                # Another worker is processing this workflow
+                return None
+
+        try:
+            # Step 3-5: Deliver message atomically
+            session = self._get_session_for_operation()
+            async with self._session_scope(session) as session:
+                # Re-check subscription (may have been removed by another worker)
+                result = await session.execute(
+                    select(WorkflowMessageSubscription).where(
+                        and_(
+                            WorkflowMessageSubscription.instance_id == instance_id,
+                            WorkflowMessageSubscription.channel == channel,
+                        )
+                    )
+                )
+                subscription = result.scalar_one_or_none()
+
+                if subscription is None:
+                    # Already delivered by another worker
+                    return None
+
+                activity_id = subscription.activity_id
+
+                # Get workflow info for return value
+                instance_result = await session.execute(
+                    select(WorkflowInstance).where(WorkflowInstance.instance_id == instance_id)
+                )
+                instance = instance_result.scalar_one_or_none()
+                workflow_name = instance.workflow_name if instance else "unknown"
+
+                # Build message data for history
+                message_id = str(uuid.uuid4())
+                message_data = {
+                    "id": message_id,
+                    "channel": channel,
+                    "data": data if isinstance(data, dict) else None,
+                    "metadata": metadata,
+                }
+
+                # Handle binary data
+                if isinstance(data, bytes):
+                    data_type = "binary"
+                    event_data_json = None
+                    event_data_binary = data
+                else:
+                    data_type = "json"
+                    event_data_json = json.dumps(message_data)
+                    event_data_binary = None
+
+                # Record in history
+                history_entry = WorkflowHistory(
+                    instance_id=instance_id,
+                    activity_id=activity_id,
+                    event_type="ChannelMessageReceived",
+                    data_type=data_type,
+                    event_data=event_data_json,
+                    event_data_binary=event_data_binary,
+                )
+                session.add(history_entry)
+
+                # Remove subscription
+                await session.execute(
+                    delete(WorkflowMessageSubscription).where(
+                        and_(
+                            WorkflowMessageSubscription.instance_id == instance_id,
+                            WorkflowMessageSubscription.channel == channel,
+                        )
+                    )
+                )
+
+                # Update status to 'running' (ready for resumption)
+                await session.execute(
+                    update(WorkflowInstance)
+                    .where(WorkflowInstance.instance_id == instance_id)
+                    .values(status="running", updated_at=func.now())
+                )
+
+                await self._commit_if_not_in_transaction(session)
+
+                return {
+                    "instance_id": instance_id,
+                    "workflow_name": workflow_name,
+                    "activity_id": activity_id,
+                }
+
+        finally:
+            # Step 6: Release lock if we acquired it
+            if lock_acquired and worker_id is not None:
+                await self.release_lock(instance_id, worker_id)
+
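Note the double subscription check in `deliver_message()`: once cheaply before locking, then again after the lock is held, so a racing worker that already delivered simply observes `None`. An illustrative caller (channel and metadata values are made up; only the method and its return shape come from this diff):

```python
async def try_deliver(storage, instance_id: str, worker_id: str, payload: dict) -> bool:
    result = await storage.deliver_message(
        instance_id=instance_id,
        channel="orders",            # illustrative
        data=payload,                # dict -> recorded as JSON history
        metadata={"source": "api"},  # illustrative
        worker_id=worker_id,         # pass None to skip locking (legacy mode)
    )
    # result carries instance_id / workflow_name / activity_id on success.
    return result is not None
```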
+    async def find_expired_message_subscriptions(self) -> list[dict[str, Any]]:
+        """
+        Find all message subscriptions that have timed out.
+
+        This method queries both the legacy WorkflowMessageSubscription table and
+        the ChannelSubscription table for expired subscriptions.
+        JOINs with WorkflowInstance to ensure instance exists and avoid N+1 queries.
+
+        Returns:
+            List of dicts with instance_id, channel, activity_id, timeout_at, created_at, workflow_name
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            results: list[dict[str, Any]] = []
+
+            # Query legacy WorkflowMessageSubscription table with JOIN to verify instance exists
+            legacy_result = await session.execute(
+                select(
+                    WorkflowMessageSubscription.instance_id,
+                    WorkflowMessageSubscription.channel,
+                    WorkflowMessageSubscription.activity_id,
+                    WorkflowMessageSubscription.timeout_at,
+                    WorkflowMessageSubscription.created_at,
+                    WorkflowInstance.workflow_name,
+                )
+                .join(
+                    WorkflowInstance,
+                    WorkflowMessageSubscription.instance_id == WorkflowInstance.instance_id,
+                )
+                .where(
+                    and_(
+                        WorkflowMessageSubscription.timeout_at.isnot(None),
+                        self._make_datetime_comparable(WorkflowMessageSubscription.timeout_at)
+                        <= self._get_current_time_expr(),
+                    )
+                )
+            )
+            legacy_rows = legacy_result.all()
+            results.extend(
+                [
+                    {
+                        "instance_id": row[0],
+                        "channel": row[1],
+                        "activity_id": row[2],
+                        "timeout_at": row[3],
+                        "created_at": row[4],
+                        "workflow_name": row[5],
+                    }
+                    for row in legacy_rows
+                ]
+            )
+
+            # Query ChannelSubscription table with JOIN
+            channel_result = await session.execute(
+                select(
+                    ChannelSubscription.instance_id,
+                    ChannelSubscription.channel,
+                    ChannelSubscription.activity_id,
+                    ChannelSubscription.timeout_at,
+                    ChannelSubscription.subscribed_at,
+                    WorkflowInstance.workflow_name,
+                )
+                .join(
+                    WorkflowInstance,
+                    ChannelSubscription.instance_id == WorkflowInstance.instance_id,
+                )
+                .where(
+                    and_(
+                        ChannelSubscription.timeout_at.isnot(None),
+                        ChannelSubscription.activity_id.isnot(None),  # Only waiting subscriptions
+                        self._make_datetime_comparable(ChannelSubscription.timeout_at)
+                        <= self._get_current_time_expr(),
+                    )
+                )
+            )
+            channel_rows = channel_result.all()
+            results.extend(
+                [
+                    {
+                        "instance_id": row[0],
+                        "channel": row[1],
+                        "activity_id": row[2],
+                        "timeout_at": row[3],
+                        "created_at": row[4],  # subscribed_at as created_at for compatibility
+                        "workflow_name": row[5],
+                    }
+                    for row in channel_rows
+                ]
+            )
+
+            return results
+
+    # -------------------------------------------------------------------------
+    # Group Membership Methods (Erlang pg style)
+    # -------------------------------------------------------------------------
+
+    async def join_group(self, instance_id: str, group_name: str) -> None:
+        """
+        Add a workflow instance to a group.
+
+        Groups provide loose coupling for broadcast messaging.
+        Idempotent - joining a group the instance is already in is a no-op.
+
+        Args:
+            instance_id: Workflow instance ID
+            group_name: Group to join
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Check if already a member (for idempotency)
+            result = await session.execute(
+                select(WorkflowGroupMembership).where(
+                    and_(
+                        WorkflowGroupMembership.instance_id == instance_id,
+                        WorkflowGroupMembership.group_name == group_name,
+                    )
+                )
+            )
+            if result.scalar_one_or_none() is not None:
+                # Already a member, nothing to do
+                return
+
+            # Add membership
+            membership = WorkflowGroupMembership(
+                instance_id=instance_id,
+                group_name=group_name,
+            )
+            session.add(membership)
+            await self._commit_if_not_in_transaction(session)
+
+    async def leave_group(self, instance_id: str, group_name: str) -> None:
+        """
+        Remove a workflow instance from a group.
+
+        Args:
+            instance_id: Workflow instance ID
+            group_name: Group to leave
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                delete(WorkflowGroupMembership).where(
+                    and_(
+                        WorkflowGroupMembership.instance_id == instance_id,
+                        WorkflowGroupMembership.group_name == group_name,
+                    )
+                )
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def get_group_members(self, group_name: str) -> list[str]:
+        """
+        Get all workflow instances in a group.
+
+        Args:
+            group_name: Group name
+
+        Returns:
+            List of instance IDs in the group
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(WorkflowGroupMembership.instance_id).where(
+                    WorkflowGroupMembership.group_name == group_name
+                )
+            )
+            return [row[0] for row in result.fetchall()]
+
+    async def leave_all_groups(self, instance_id: str) -> None:
+        """
+        Remove a workflow instance from all groups.
+
+        Called automatically when a workflow completes or fails.
+
+        Args:
+            instance_id: Workflow instance ID
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                delete(WorkflowGroupMembership).where(
+                    WorkflowGroupMembership.instance_id == instance_id
+                )
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    # -------------------------------------------------------------------------
+    # Workflow Resumption Methods
+    # -------------------------------------------------------------------------
+
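With join/leave/list in place, a broadcast over a group reduces to membership resolution plus per-member delivery. A sketch under stated assumptions (the `group:` channel-naming scheme is invented here, not an edda convention established by this diff):

```python
async def notify_group(storage, group_name: str, payload: dict) -> int:
    members = await storage.get_group_members(group_name)
    for instance_id in members:
        # Legacy-mode delivery (worker_id omitted) keeps the sketch short.
        await storage.deliver_message(
            instance_id=instance_id,
            channel=f"group:{group_name}",  # hypothetical naming scheme
            data=payload,
            metadata={"group": group_name},
        )
    return len(members)
```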
+    async def find_resumable_workflows(self) -> list[dict[str, Any]]:
+        """
+        Find workflows that are ready to be resumed.
+
+        Returns workflows with status='running' that don't have an active lock.
+        Used for immediate resumption after message delivery.
+
+        Returns:
+            List of resumable workflows with instance_id and workflow_name.
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(
+                    WorkflowInstance.instance_id,
+                    WorkflowInstance.workflow_name,
+                ).where(
+                    and_(
+                        WorkflowInstance.status == "running",
+                        WorkflowInstance.locked_by.is_(None),
+                    )
+                )
+            )
+            return [
+                {
+                    "instance_id": row.instance_id,
+                    "workflow_name": row.workflow_name,
+                }
+                for row in result.fetchall()
+            ]
+
+    # -------------------------------------------------------------------------
+    # Subscription Cleanup Methods (for recur())
+    # -------------------------------------------------------------------------
+
+    async def cleanup_instance_subscriptions(self, instance_id: str) -> None:
+        """
+        Remove all subscriptions for a workflow instance.
+
+        Called during recur() to clean up event/timer/message subscriptions
+        before archiving the history.
+
+        Args:
+            instance_id: Workflow instance ID to clean up
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Remove timer subscriptions
+            await session.execute(
+                delete(WorkflowTimerSubscription).where(
+                    WorkflowTimerSubscription.instance_id == instance_id
+                )
+            )
+
+            # Remove message subscriptions (legacy)
+            await session.execute(
+                delete(WorkflowMessageSubscription).where(
+                    WorkflowMessageSubscription.instance_id == instance_id
+                )
+            )
+
+            # Remove channel subscriptions
+            await session.execute(
+                delete(ChannelSubscription).where(ChannelSubscription.instance_id == instance_id)
+            )
+
+            # Remove channel message claims
+            await session.execute(
+                delete(ChannelMessageClaim).where(ChannelMessageClaim.instance_id == instance_id)
+            )
+
+            await self._commit_if_not_in_transaction(session)
+
+    # -------------------------------------------------------------------------
+    # Channel-based Message Queue Methods
+    # -------------------------------------------------------------------------
+
+    async def publish_to_channel(
+        self,
+        channel: str,
+        data: dict[str, Any] | bytes,
+        metadata: dict[str, Any] | None = None,
+    ) -> str:
+        """
+        Publish a message to a channel.
+
+        Messages are persisted to channel_messages and available for subscribers.
+
+        Args:
+            channel: Channel name
+            data: Message payload (dict or bytes)
+            metadata: Optional message metadata
+
+        Returns:
+            Generated message_id (UUID)
+        """
+        import uuid
+
+        message_id = str(uuid.uuid4())
+
+        # Determine data type and serialize
+        if isinstance(data, bytes):
+            data_type = "binary"
+            data_json = None
+            data_binary = data
+        else:
+            data_type = "json"
+            data_json = json.dumps(data)
+            data_binary = None
+
+        metadata_json = json.dumps(metadata) if metadata else None
+
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            msg = ChannelMessage(
+                channel=channel,
+                message_id=message_id,
+                data_type=data_type,
+                data=data_json,
+                data_binary=data_binary,
+                message_metadata=metadata_json,
+            )
+            session.add(msg)
+            await self._commit_if_not_in_transaction(session)
+
+        return message_id
+
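Publishing decides the storage shape from the payload type: dicts land in the JSON column, bytes in the binary one. For example (the channel name and payloads are illustrative):

```python
async def publish_examples(storage) -> tuple[str, str]:
    # Dict payload -> data_type="json", serialized with json.dumps().
    json_id = await storage.publish_to_channel("metrics", {"cpu": 0.42}, metadata={"v": 1})
    # Bytes payload -> data_type="binary", stored untouched.
    blob_id = await storage.publish_to_channel("metrics", b"\x00\x01\x02")
    return json_id, blob_id  # both are UUID strings
```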
+    async def subscribe_to_channel(
+        self,
+        instance_id: str,
+        channel: str,
+        mode: str,
+    ) -> None:
+        """
+        Subscribe a workflow instance to a channel.
+
+        Args:
+            instance_id: Workflow instance ID
+            channel: Channel name
+            mode: 'broadcast' or 'competing'
+
+        Raises:
+            ValueError: If mode is invalid
+        """
+        if mode not in ("broadcast", "competing"):
+            raise ValueError(f"Invalid mode: {mode}. Must be 'broadcast' or 'competing'")
+
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Check if already subscribed
+            result = await session.execute(
+                select(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+            )
+            existing = result.scalar_one_or_none()
+
+            if existing is not None:
+                # Already subscribed, update mode if different
+                if existing.mode != mode:
+                    existing.mode = mode
+                    await self._commit_if_not_in_transaction(session)
+                return
+
+            # For broadcast mode, set cursor to current max message id
+            # So subscriber only sees messages published after subscription
+            cursor_message_id = None
+            if mode == "broadcast":
+                result = await session.execute(
+                    select(func.max(ChannelMessage.id)).where(ChannelMessage.channel == channel)
+                )
+                max_id = result.scalar()
+                cursor_message_id = max_id if max_id is not None else 0
+
+            # Create subscription
+            subscription = ChannelSubscription(
+                instance_id=instance_id,
+                channel=channel,
+                mode=mode,
+                cursor_message_id=cursor_message_id,
+            )
+            session.add(subscription)
+            await self._commit_if_not_in_transaction(session)
+
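The cursor initialization above is what gives the two modes their semantics: a broadcast subscriber starts at the channel's current max id and never replays history, while a competing subscriber has no cursor and competes for the whole unclaimed backlog. Sketch (channel names illustrative):

```python
async def subscribe_both(storage, instance_id: str) -> None:
    # Broadcast: only messages published after this call are visible.
    await storage.subscribe_to_channel(instance_id, "announcements", "broadcast")
    # Competing: eligible for any currently unclaimed message on the channel.
    await storage.subscribe_to_channel(instance_id, "jobs", "competing")
```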
+    async def unsubscribe_from_channel(
+        self,
+        instance_id: str,
+        channel: str,
+    ) -> None:
+        """
+        Unsubscribe a workflow instance from a channel.
+
+        Args:
+            instance_id: Workflow instance ID
+            channel: Channel name
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                delete(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def get_channel_subscription(
+        self,
+        instance_id: str,
+        channel: str,
+    ) -> dict[str, Any] | None:
+        """
+        Get the subscription info for a workflow instance on a channel.
+
+        Args:
+            instance_id: Workflow instance ID
+            channel: Channel name
+
+        Returns:
+            Subscription info dict with: mode, activity_id, cursor_message_id
+            or None if not subscribed
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+            )
+            subscription = result.scalar_one_or_none()
+
+            if subscription is None:
+                return None
+
+            return {
+                "mode": subscription.mode,
+                "activity_id": subscription.activity_id,
+                "cursor_message_id": subscription.cursor_message_id,
+            }
+
+    async def register_channel_receive_and_release_lock(
+        self,
+        instance_id: str,
+        worker_id: str,
+        channel: str,
+        activity_id: str | None = None,
+        timeout_seconds: int | None = None,
+    ) -> None:
+        """
+        Atomically register that workflow is waiting for channel message and release lock.
+
+        Args:
+            instance_id: Workflow instance ID
+            worker_id: Worker ID that currently holds the lock
+            channel: Channel name being waited on
+            activity_id: Current activity ID to record
+            timeout_seconds: Optional timeout in seconds for the message wait
+
+        Raises:
+            RuntimeError: If the worker doesn't hold the lock
+            ValueError: If workflow is not subscribed to the channel
+        """
+        async with self.engine.begin() as conn:
+            session = AsyncSession(bind=conn, expire_on_commit=False)
+
+            # Verify lock ownership
+            result = await session.execute(
+                select(WorkflowInstance).where(WorkflowInstance.instance_id == instance_id)
+            )
+            instance = result.scalar_one_or_none()
+
+            if instance is None:
+                raise RuntimeError(f"Instance not found: {instance_id}")
+
+            if instance.locked_by != worker_id:
+                raise RuntimeError(
+                    f"Worker {worker_id} does not hold lock for {instance_id}. "
+                    f"Locked by: {instance.locked_by}"
+                )
+
+            # Verify subscription exists
+            sub_result = await session.execute(
+                select(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+            )
+            subscription: ChannelSubscription | None = sub_result.scalar_one_or_none()
+
+            if subscription is None:
+                raise ValueError(f"Instance {instance_id} is not subscribed to channel {channel}")
+
+            # Update subscription to mark as waiting
+            current_time = datetime.now(UTC)
+            subscription.activity_id = activity_id
+            # Calculate timeout_at if timeout_seconds is provided
+            if timeout_seconds is not None:
+                subscription.timeout_at = current_time + timedelta(seconds=timeout_seconds)
+            else:
+                subscription.timeout_at = None
+
+            # Update instance: set activity, status, release lock
+            await session.execute(
+                update(WorkflowInstance)
+                .where(WorkflowInstance.instance_id == instance_id)
+                .values(
+                    current_activity_id=activity_id,
+                    status="waiting_for_message",
+                    locked_by=None,
+                    locked_at=None,
+                    lock_expires_at=None,
+                    updated_at=current_time,
+                )
+            )
+
+            await session.commit()
+
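Unlike `register_message_subscription_and_release_lock()`, which takes an absolute `timeout_at`, this variant takes relative `timeout_seconds` and derives the deadline itself; it also requires an existing subscription (`ValueError` otherwise). Caller sketch with invented identifiers:

```python
async def wait_on_channel(storage, instance_id: str, worker_id: str) -> None:
    # Requires a prior subscribe_to_channel() for this instance/channel pair.
    await storage.register_channel_receive_and_release_lock(
        instance_id=instance_id,
        worker_id=worker_id,
        channel="jobs",           # illustrative
        activity_id="receive:1",  # illustrative
        timeout_seconds=30,       # stored as timeout_at = now + 30s
    )
```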
+    async def get_pending_channel_messages(
+        self,
+        instance_id: str,
+        channel: str,
+    ) -> list[dict[str, Any]]:
+        """
+        Get pending messages for a subscriber on a channel.
+
+        For broadcast mode: messages with id > cursor_message_id
+        For competing mode: unclaimed messages
+
+        Args:
+            instance_id: Workflow instance ID
+            channel: Channel name
+
+        Returns:
+            List of pending messages
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Get subscription info
+            sub_result = await session.execute(
+                select(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+            )
+            subscription = sub_result.scalar_one_or_none()
+
+            if subscription is None:
+                return []
+
+            if subscription.mode == "broadcast":
+                # Get messages after cursor
+                cursor = subscription.cursor_message_id or 0
+                msg_result = await session.execute(
+                    select(ChannelMessage)
+                    .where(
+                        and_(
+                            ChannelMessage.channel == channel,
+                            ChannelMessage.id > cursor,
+                        )
+                    )
+                    .order_by(ChannelMessage.published_at.asc())
+                )
+            else:  # competing
+                # Get unclaimed messages (not in channel_message_claims)
+                subquery = select(ChannelMessageClaim.message_id)
+                msg_result = await session.execute(
+                    select(ChannelMessage)
+                    .where(
+                        and_(
+                            ChannelMessage.channel == channel,
+                            ChannelMessage.message_id.not_in(subquery),
+                        )
+                    )
+                    .order_by(ChannelMessage.published_at.asc())
+                )
+
+            messages = msg_result.scalars().all()
+            return [
+                {
+                    "id": msg.id,
+                    "message_id": msg.message_id,
+                    "channel": msg.channel,
+                    "data": (
+                        msg.data_binary
+                        if msg.data_type == "binary"
+                        else json.loads(msg.data) if msg.data else {}
+                    ),
+                    "metadata": json.loads(msg.message_metadata) if msg.message_metadata else {},
+                    "published_at": msg.published_at.isoformat() if msg.published_at else None,
+                }
+                for msg in messages
+            ]
+
+    async def claim_channel_message(
+        self,
+        message_id: str,
+        instance_id: str,
+    ) -> bool:
+        """
+        Claim a message for competing consumption.
+
+        Uses INSERT with conflict check to ensure only one subscriber claims.
+
+        Args:
+            message_id: Message ID to claim
+            instance_id: Workflow instance claiming the message
+
+        Returns:
+            True if claim succeeded, False if already claimed
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            try:
+                # Check if already claimed
+                result = await session.execute(
+                    select(ChannelMessageClaim).where(ChannelMessageClaim.message_id == message_id)
+                )
+                if result.scalar_one_or_none() is not None:
+                    return False  # Already claimed
+
+                claim = ChannelMessageClaim(
+                    message_id=message_id,
+                    instance_id=instance_id,
+                )
+                session.add(claim)
+                await self._commit_if_not_in_transaction(session)
+                return True
+            except Exception:
+                return False
+
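`get_pending_channel_messages()` and `claim_channel_message()` together form the competing-consumer handshake: enumerate candidates, then race to claim; only the first claimant gets `True`. A consumer-side sketch (the channel name is illustrative):

```python
async def take_next_job(storage, instance_id: str) -> dict | None:
    for msg in await storage.get_pending_channel_messages(instance_id, "jobs"):
        # First claimant wins; a False here means another instance got it.
        if await storage.claim_channel_message(msg["message_id"], instance_id):
            return msg
    return None
```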
+    async def delete_channel_message(self, message_id: str) -> None:
+        """
+        Delete a message from the channel queue.
+
+        Args:
+            message_id: Message ID to delete
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Delete claim first (foreign key)
+            await session.execute(
+                delete(ChannelMessageClaim).where(ChannelMessageClaim.message_id == message_id)
+            )
+            # Delete message
+            await session.execute(
+                delete(ChannelMessage).where(ChannelMessage.message_id == message_id)
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def update_delivery_cursor(
+        self,
+        channel: str,
+        instance_id: str,
+        message_id: int,
+    ) -> None:
+        """
+        Update the delivery cursor for broadcast mode.
+
+        Args:
+            channel: Channel name
+            instance_id: Subscriber instance ID
+            message_id: Last delivered message's internal ID
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Update subscription cursor
+            await session.execute(
+                update(ChannelSubscription)
+                .where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+                .values(cursor_message_id=message_id)
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def get_channel_subscribers_waiting(
+        self,
+        channel: str,
+    ) -> list[dict[str, Any]]:
+        """
+        Get channel subscribers that are waiting (activity_id is set).
+
+        Args:
+            channel: Channel name
+
+        Returns:
+            List of waiting subscribers
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.channel == channel,
+                        ChannelSubscription.activity_id.isnot(None),
+                    )
+                )
+            )
+            subscriptions = result.scalars().all()
+            return [
+                {
+                    "instance_id": sub.instance_id,
+                    "channel": sub.channel,
+                    "mode": sub.mode,
+                    "activity_id": sub.activity_id,
+                }
+                for sub in subscriptions
+            ]
+
+    async def clear_channel_waiting_state(
+        self,
+        instance_id: str,
+        channel: str,
+    ) -> None:
+        """
+        Clear the waiting state for a channel subscription.
+
+        Args:
+            instance_id: Workflow instance ID
+            channel: Channel name
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                update(ChannelSubscription)
+                .where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+                .values(activity_id=None)
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def deliver_channel_message(
+        self,
+        instance_id: str,
+        channel: str,
+        message_id: str,
+        data: dict[str, Any] | bytes,
+        metadata: dict[str, Any],
+        worker_id: str,
+    ) -> dict[str, Any] | None:
+        """
+        Deliver a channel message to a waiting workflow.
+
+        Uses Lock-First pattern for distributed safety.
+
+        Args:
+            instance_id: Target workflow instance ID
+            channel: Channel name
+            message_id: Message ID being delivered
+            data: Message payload
+            metadata: Message metadata
+            worker_id: Worker ID for locking
+
+        Returns:
+            Delivery info if successful, None if failed
+        """
+        try:
+            # Try to acquire lock
+            if not await self.try_acquire_lock(instance_id, worker_id):
+                logger.debug(f"Failed to acquire lock for {instance_id}")
+                return None
+
+            try:
+                async with self.engine.begin() as conn:
+                    session = AsyncSession(bind=conn, expire_on_commit=False)
+
+                    # Get subscription info
+                    result = await session.execute(
+                        select(ChannelSubscription).where(
+                            and_(
+                                ChannelSubscription.instance_id == instance_id,
+                                ChannelSubscription.channel == channel,
+                            )
+                        )
+                    )
+                    subscription = result.scalar_one_or_none()
+
+                    if subscription is None or subscription.activity_id is None:
+                        logger.debug(f"No waiting subscription for {instance_id} on {channel}")
+                        return None
+
+                    activity_id = subscription.activity_id
+
+                    # Get instance info for return value
+                    result = await session.execute(
+                        select(WorkflowInstance.workflow_name).where(
+                            WorkflowInstance.instance_id == instance_id
+                        )
+                    )
+                    row = result.one_or_none()
+                    if row is None:
+                        return None
+                    workflow_name = row[0]
+
+                    # Prepare message data for history
+                    # Use "id" key to match what context.py expects when loading history
+                    current_time = datetime.now(UTC)
+                    message_result = {
+                        "id": message_id,
+                        "channel": channel,
+                        "data": data if isinstance(data, dict) else None,
+                        "metadata": metadata,
+                        "published_at": current_time.isoformat(),
+                    }
+
+                    # Record to history
+                    if isinstance(data, bytes):
+                        history = WorkflowHistory(
+                            instance_id=instance_id,
+                            activity_id=activity_id,
+                            event_type="ChannelMessageReceived",
+                            data_type="binary",
+                            event_data=None,
+                            event_data_binary=data,
+                        )
+                    else:
+                        history = WorkflowHistory(
+                            instance_id=instance_id,
+                            activity_id=activity_id,
+                            event_type="ChannelMessageReceived",
+                            data_type="json",
+                            event_data=json.dumps(message_result),
+                            event_data_binary=None,
+                        )
+                    session.add(history)
+
+                    # Handle mode-specific logic
+                    if subscription.mode == "broadcast":
+                        # Get message internal id to update cursor
+                        result = await session.execute(
+                            select(ChannelMessage.id).where(ChannelMessage.message_id == message_id)
+                        )
+                        msg_row = result.one_or_none()
+                        if msg_row:
+                            subscription.cursor_message_id = msg_row[0]
+                    else:  # competing
+                        # Claim and delete the message
+                        claim = ChannelMessageClaim(
+                            message_id=message_id,
+                            instance_id=instance_id,
+                        )
+                        session.add(claim)
+
+                        # Delete the message (competing mode consumes it)
+                        await session.execute(
+                            delete(ChannelMessage).where(ChannelMessage.message_id == message_id)
+                        )
+
+                    # Clear waiting state
+                    subscription.activity_id = None
+
+                    # Update instance status to running
+                    current_time = datetime.now(UTC)
+                    await session.execute(
+                        update(WorkflowInstance)
+                        .where(WorkflowInstance.instance_id == instance_id)
+                        .values(
+                            status="running",
+                            updated_at=current_time,
+                        )
+                    )
+
+                    await session.commit()
+
+                    return {
+                        "instance_id": instance_id,
+                        "workflow_name": workflow_name,
+                        "activity_id": activity_id,
+                    }
+
+            finally:
+                # Always release lock
+                await self.release_lock(instance_id, worker_id)
+
+        except Exception as e:
+            logger.error(f"Error delivering channel message: {e}")
+            return None
+
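After the shared history write, the two modes diverge: broadcast advances the subscriber's cursor and leaves the message in place; competing claims the message and deletes it. A router might fan a published message out like this (the surrounding loop and `msg` shape are assumptions; the storage calls come from this diff):

```python
async def fan_out(storage, worker_id: str, channel: str, msg: dict) -> None:
    for sub in await storage.get_channel_subscribers_waiting(channel):
        delivered = await storage.deliver_channel_message(
            instance_id=sub["instance_id"],
            channel=channel,
            message_id=msg["message_id"],
            data=msg["data"],
            metadata=msg["metadata"],
            worker_id=worker_id,
        )
        if delivered is not None and sub["mode"] == "competing":
            break  # a competing delivery consumes the message; stop here
```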
+    async def cleanup_old_channel_messages(self, older_than_days: int = 7) -> int:
+        """
+        Clean up old messages from channel queues.
+
+        Args:
+            older_than_days: Message retention period in days
+
+        Returns:
+            Number of messages deleted
+        """
+        cutoff_time = datetime.now(UTC) - timedelta(days=older_than_days)
+
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # First delete claims for old messages
+            await session.execute(
+                delete(ChannelMessageClaim).where(
+                    ChannelMessageClaim.message_id.in_(
+                        select(ChannelMessage.message_id).where(
+                            self._make_datetime_comparable(ChannelMessage.published_at)
+                            < self._get_current_time_expr()
+                        )
+                    )
+                )
+            )
+
+            # Delete old messages
+            result = await session.execute(
+                delete(ChannelMessage)
+                .where(ChannelMessage.published_at < cutoff_time)
+                .returning(ChannelMessage.id)
+            )
+            deleted_ids = result.fetchall()
+            await self._commit_if_not_in_transaction(session)
+
+            return len(deleted_ids)
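A periodic retention task could wrap this cleanup; only `cleanup_old_channel_messages()` comes from this diff, the loop itself is illustrative:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def retention_loop(storage, interval_seconds: float = 3600.0) -> None:
    while True:
        deleted = await storage.cleanup_old_channel_messages(older_than_days=7)
        logger.info("channel retention: deleted %d old messages", deleted)
        await asyncio.sleep(interval_seconds)
```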