edda_framework-0.7.0-py3-none-any.whl → edda_framework-0.9.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edda/__init__.py +39 -5
- edda/app.py +383 -223
- edda/channels.py +1017 -0
- edda/compensation.py +22 -22
- edda/context.py +105 -52
- edda/integrations/opentelemetry/hooks.py +7 -2
- edda/locking.py +130 -67
- edda/replay.py +312 -82
- edda/storage/models.py +142 -24
- edda/storage/protocol.py +539 -118
- edda/storage/sqlalchemy_storage.py +1866 -328
- edda/viewer_ui/app.py +6 -1
- edda/viewer_ui/data_service.py +19 -22
- edda/workflow.py +43 -0
- {edda_framework-0.7.0.dist-info → edda_framework-0.9.0.dist-info}/METADATA +109 -9
- {edda_framework-0.7.0.dist-info → edda_framework-0.9.0.dist-info}/RECORD +19 -19
- edda/events.py +0 -505
- {edda_framework-0.7.0.dist-info → edda_framework-0.9.0.dist-info}/WHEEL +0 -0
- {edda_framework-0.7.0.dist-info → edda_framework-0.9.0.dist-info}/entry_points.txt +0 -0
- {edda_framework-0.7.0.dist-info → edda_framework-0.9.0.dist-info}/licenses/LICENSE +0 -0
@@ -8,7 +8,7 @@ and transactional outbox pattern.
 
 import json
 import logging
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Awaitable, Callable
 from contextlib import asynccontextmanager
 from contextvars import ContextVar
 from dataclasses import dataclass, field
@@ -19,6 +19,7 @@ from sqlalchemy import (
     CheckConstraint,
     Column,
     DateTime,
+    ForeignKey,
     ForeignKeyConstraint,
     Index,
     Integer,
@@ -29,18 +30,21 @@ from sqlalchemy import (
     and_,
     delete,
     func,
+    inspect,
     or_,
     select,
     text,
     update,
 )
 from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
-from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 
 logger = logging.getLogger(__name__)
 
+
 # Declarative base for ORM models
-Base = declarative_base()
+class Base(DeclarativeBase):
+    pass
 
 
 # ============================================================================
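
A note on the pattern in the next few hunks: the models move from the legacy `declarative_base()` style to SQLAlchemy 2.0's typed declarative mapping, which is what removes the `# type: ignore[valid-type, misc]` comments. A minimal sketch of the two styles, with an illustrative table that is not part of Edda:

```python
from datetime import datetime

from sqlalchemy import DateTime, Integer, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Example(Base):
    __tablename__ = "example"

    # 2.0 style: the Python type comes from the Mapped[] annotation, so type
    # checkers see `int`/`datetime` directly instead of opaque Column objects.
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
```
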
@@ -48,25 +52,25 @@ Base = declarative_base()
 # ============================================================================
 
 
-class SchemaVersion(Base):  # type: ignore[valid-type, misc]
+class SchemaVersion(Base):
     """Schema version tracking."""
 
     __tablename__ = "schema_version"
 
-    version =
-    applied_at =
-    description =
+    version: Mapped[int] = mapped_column(Integer, primary_key=True)
+    applied_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+    description: Mapped[str] = mapped_column(Text)
 
 
-class WorkflowDefinition(Base):  # type: ignore[valid-type, misc]
+class WorkflowDefinition(Base):
     """Workflow definition (source code storage)."""
 
     __tablename__ = "workflow_definitions"
 
-    workflow_name =
-    source_hash =
-    source_code =
-    created_at =
+    workflow_name: Mapped[str] = mapped_column(String(255), primary_key=True)
+    source_hash: Mapped[str] = mapped_column(String(64), primary_key=True)
+    source_code: Mapped[str] = mapped_column(Text)
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
 
     __table_args__ = (
         Index("idx_definitions_name", "workflow_name"),
@@ -74,31 +78,34 @@ class WorkflowDefinition(Base): # type: ignore[valid-type, misc]
     )
 
 
-class WorkflowInstance(Base):  # type: ignore[valid-type, misc]
+class WorkflowInstance(Base):
     """Workflow instance with distributed locking support."""
 
     __tablename__ = "workflow_instances"
 
-    instance_id =
-    workflow_name =
-    source_hash =
-    owner_service =
-    status =
-
-
-
+    instance_id: Mapped[str] = mapped_column(String(255), primary_key=True)
+    workflow_name: Mapped[str] = mapped_column(String(255))
+    source_hash: Mapped[str] = mapped_column(String(64))
+    owner_service: Mapped[str] = mapped_column(String(255))
+    status: Mapped[str] = mapped_column(String(50), server_default=text("'running'"))
+    current_activity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    continued_from: Mapped[str | None] = mapped_column(
+        String(255), ForeignKey("workflow_instances.instance_id"), nullable=True
     )
-
-
-
-        DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
+    started_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
     )
-    input_data =
-    output_data =
-    locked_by =
-    locked_at =
-    lock_timeout_seconds
-
+    input_data: Mapped[str] = mapped_column(Text)  # JSON
+    output_data: Mapped[str | None] = mapped_column(Text, nullable=True)  # JSON
+    locked_by: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    locked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+    lock_timeout_seconds: Mapped[int | None] = mapped_column(
+        Integer, nullable=True
+    )  # None = use global default (300s)
+    lock_expires_at: Mapped[datetime | None] = mapped_column(
+        DateTime(timezone=True), nullable=True
+    )  # Absolute expiry time
 
     __table_args__ = (
         ForeignKeyConstraint(
@@ -107,7 +114,7 @@ class WorkflowInstance(Base): # type: ignore[valid-type, misc]
         ),
         CheckConstraint(
             "status IN ('running', 'completed', 'failed', 'waiting_for_event', "
-            "'waiting_for_timer', 'compensating', 'cancelled')",
+            "'waiting_for_timer', 'waiting_for_message', 'compensating', 'cancelled', 'recurred')",
             name="valid_status",
         ),
         Index("idx_instances_status", "status"),
@@ -117,25 +124,29 @@ class WorkflowInstance(Base): # type: ignore[valid-type, misc]
         Index("idx_instances_lock_expires", "lock_expires_at"),
         Index("idx_instances_updated", "updated_at"),
         Index("idx_instances_hash", "source_hash"),
+        Index("idx_instances_continued_from", "continued_from"),
+        # Composite index for find_resumable_workflows(): WHERE status='running' AND locked_by IS NULL
+        Index("idx_instances_resumable", "status", "locked_by"),
     )
 
 
-class WorkflowHistory(Base):  # type: ignore[valid-type, misc]
+class WorkflowHistory(Base):
     """Workflow execution history (for deterministic replay)."""
 
     __tablename__ = "workflow_history"
 
-    id =
-    instance_id =
-
-
-    )
-
-
-
-
-
-
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    activity_id: Mapped[str] = mapped_column(String(255))
+    event_type: Mapped[str] = mapped_column(String(100))
+    data_type: Mapped[str] = mapped_column(String(10))  # 'json' or 'binary'
+    event_data: Mapped[str | None] = mapped_column(
+        Text, nullable=True
+    )  # JSON (when data_type='json')
+    event_data_binary: Mapped[bytes | None] = mapped_column(
+        LargeBinary, nullable=True
+    )  # Binary (when data_type='binary')
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
 
     __table_args__ = (
         ForeignKeyConstraint(
@@ -158,20 +169,20 @@ class WorkflowHistory(Base): # type: ignore[valid-type, misc]
     )
 
 
-class
-    """
+class WorkflowHistoryArchive(Base):
+    """Archived workflow execution history (for recur pattern)."""
 
-    __tablename__ = "
+    __tablename__ = "workflow_history_archive"
 
-    id =
-    instance_id =
-
-
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    activity_id: Mapped[str] = mapped_column(String(255))
+    event_type: Mapped[str] = mapped_column(String(100))
+    event_data: Mapped[str] = mapped_column(Text)  # JSON (includes both types for archive)
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
+    archived_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now()
     )
-    activity_id = Column(String(255), nullable=False)
-    activity_name = Column(String(255), nullable=False)
-    args = Column(Text, nullable=False)  # JSON
-    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
 
     __table_args__ = (
         ForeignKeyConstraint(
@@ -179,24 +190,22 @@ class WorkflowCompensation(Base): # type: ignore[valid-type, misc]
             ["workflow_instances.instance_id"],
             ondelete="CASCADE",
         ),
-        Index("
+        Index("idx_history_archive_instance", "instance_id"),
+        Index("idx_history_archive_archived", "archived_at"),
     )
 
 
-class
-    """
+class WorkflowCompensation(Base):
+    """Compensation transactions (LIFO stack for Saga pattern)."""
 
-    __tablename__ = "
+    __tablename__ = "workflow_compensations"
 
-    id =
-    instance_id =
-
-
-    )
-
-    activity_id = Column(String(255), nullable=True)
-    timeout_at = Column(DateTime(timezone=True), nullable=True)
-    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    activity_id: Mapped[str] = mapped_column(String(255))
+    activity_name: Mapped[str] = mapped_column(String(255))
+    args: Mapped[str] = mapped_column(Text)  # JSON
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
 
     __table_args__ = (
         ForeignKeyConstraint(
@@ -204,27 +213,21 @@ class WorkflowEventSubscription(Base): # type: ignore[valid-type, misc]
             ["workflow_instances.instance_id"],
             ondelete="CASCADE",
         ),
-
-        Index("idx_subscriptions_event", "event_type"),
-        Index("idx_subscriptions_timeout", "timeout_at"),
-        Index("idx_subscriptions_instance", "instance_id"),
+        Index("idx_compensations_instance", "instance_id", "created_at"),
     )
 
 
-class WorkflowTimerSubscription(Base):  # type: ignore[valid-type, misc]
+class WorkflowTimerSubscription(Base):
     """Timer subscriptions (for wait_timer)."""
 
    __tablename__ = "workflow_timer_subscriptions"
 
-    id =
-    instance_id =
-
-
-    )
-
-    expires_at = Column(DateTime(timezone=True), nullable=False)
-    activity_id = Column(String(255), nullable=True)
-    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    timer_id: Mapped[str] = mapped_column(String(255))
+    expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
+    activity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
 
     __table_args__ = (
         ForeignKeyConstraint(
@@ -238,23 +241,29 @@ class WorkflowTimerSubscription(Base): # type: ignore[valid-type, misc]
     )
 
 
-class OutboxEvent(Base):  # type: ignore[valid-type, misc]
+class OutboxEvent(Base):
     """Transactional outbox pattern events."""
 
     __tablename__ = "outbox_events"
 
-    event_id =
-    event_type =
-    event_source =
-    data_type =
-    event_data
-
-
-
-
-
-
-
+    event_id: Mapped[str] = mapped_column(String(255), primary_key=True)
+    event_type: Mapped[str] = mapped_column(String(255))
+    event_source: Mapped[str] = mapped_column(String(255))
+    data_type: Mapped[str] = mapped_column(String(10))  # 'json' or 'binary'
+    event_data: Mapped[str | None] = mapped_column(
+        Text, nullable=True
+    )  # JSON (when data_type='json')
+    event_data_binary: Mapped[bytes | None] = mapped_column(
+        LargeBinary, nullable=True
+    )  # Binary (when data_type='binary')
+    content_type: Mapped[str] = mapped_column(
+        String(100), server_default=text("'application/json'")
+    )
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+    published_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+    status: Mapped[str] = mapped_column(String(50), server_default=text("'pending'"))
+    retry_count: Mapped[int] = mapped_column(Integer, server_default=text("0"))
+    last_error: Mapped[str | None] = mapped_column(Text, nullable=True)
 
     __table_args__ = (
         CheckConstraint(
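
The `outbox_events` columns (`status`, `published_at`, `retry_count`, `last_error`) support the transactional outbox pattern named in the docstring. A sketch of the usual shape with illustrative SQL, not Edda's actual publisher:

```python
# Written in the SAME transaction as the workflow-state change it announces,
# so the event row exists if and only if the state change committed.
INSERT_EVENT = """
INSERT INTO outbox_events (event_id, event_type, event_source, data_type, event_data)
VALUES (:event_id, :event_type, :event_source, 'json', :event_data)
"""

# Run later by a relay loop once the broker accepts the event; retry_count and
# last_error record what happened on failed attempts in between.
MARK_PUBLISHED = """
UPDATE outbox_events
SET status = 'published', published_at = :now
WHERE event_id = :event_id
"""
```
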
@@ -276,6 +285,178 @@ class OutboxEvent(Base): # type: ignore[valid-type, misc]
     )
 
 
+class WorkflowGroupMembership(Base):
+    """Group memberships (Erlang pg style for broadcast messaging)."""
+
+    __tablename__ = "workflow_group_memberships"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    group_name: Mapped[str] = mapped_column(String(255))
+    joined_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["instance_id"],
+            ["workflow_instances.instance_id"],
+            ondelete="CASCADE",
+        ),
+        UniqueConstraint("instance_id", "group_name", name="unique_instance_group"),
+        Index("idx_group_memberships_group", "group_name"),
+        Index("idx_group_memberships_instance", "instance_id"),
+    )
+
+
+# =============================================================================
+# Channel-based Message Queue Models
+# =============================================================================
+
+
+class ChannelMessage(Base):
+    """Channel message queue (persistent message storage)."""
+
+    __tablename__ = "channel_messages"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    channel: Mapped[str] = mapped_column(String(255))
+    message_id: Mapped[str] = mapped_column(String(255), unique=True)
+    data_type: Mapped[str] = mapped_column(String(10))  # 'json' or 'binary'
+    data: Mapped[str | None] = mapped_column(Text, nullable=True)  # JSON (when data_type='json')
+    data_binary: Mapped[bytes | None] = mapped_column(
+        LargeBinary, nullable=True
+    )  # Binary (when data_type='binary')
+    message_metadata: Mapped[str | None] = mapped_column(
+        "metadata", Text, nullable=True
+    )  # JSON - renamed to avoid SQLAlchemy reserved name
+    published_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now()
+    )
+
+    __table_args__ = (
+        CheckConstraint(
+            "data_type IN ('json', 'binary')",
+            name="channel_valid_data_type",
+        ),
+        CheckConstraint(
+            "(data_type = 'json' AND data IS NOT NULL AND data_binary IS NULL) OR "
+            "(data_type = 'binary' AND data IS NULL AND data_binary IS NOT NULL)",
+            name="channel_data_type_consistency",
+        ),
+        Index("idx_channel_messages_channel", "channel", "published_at"),
+        Index("idx_channel_messages_id", "id"),
+    )
+
+
+class ChannelSubscription(Base):
+    """Channel subscriptions for workflow instances."""
+
+    __tablename__ = "channel_subscriptions"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    channel: Mapped[str] = mapped_column(String(255))
+    mode: Mapped[str] = mapped_column(String(20))  # 'broadcast' or 'competing'
+    activity_id: Mapped[str | None] = mapped_column(
+        String(255), nullable=True
+    )  # Set when waiting for message
+    cursor_message_id: Mapped[int | None] = mapped_column(
+        Integer, nullable=True
+    )  # Last received message id (broadcast)
+    timeout_at: Mapped[datetime | None] = mapped_column(
+        DateTime(timezone=True), nullable=True
+    )  # Timeout deadline
+    subscribed_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now()
+    )
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["instance_id"],
+            ["workflow_instances.instance_id"],
+            ondelete="CASCADE",
+        ),
+        CheckConstraint(
+            "mode IN ('broadcast', 'competing')",
+            name="channel_valid_mode",
+        ),
+        UniqueConstraint("instance_id", "channel", name="unique_channel_instance_channel"),
+        Index("idx_channel_subscriptions_channel", "channel"),
+        Index("idx_channel_subscriptions_instance", "instance_id"),
+        Index("idx_channel_subscriptions_waiting", "channel", "activity_id"),
+        Index("idx_channel_subscriptions_timeout", "timeout_at"),
+    )
+
+
+class ChannelDeliveryCursor(Base):
+    """Channel delivery cursors (broadcast mode: track who read what)."""
+
+    __tablename__ = "channel_delivery_cursors"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    channel: Mapped[str] = mapped_column(String(255))
+    instance_id: Mapped[str] = mapped_column(String(255))
+    last_delivered_id: Mapped[int] = mapped_column(Integer)
+    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["instance_id"],
+            ["workflow_instances.instance_id"],
+            ondelete="CASCADE",
+        ),
+        UniqueConstraint("channel", "instance_id", name="unique_channel_delivery_cursor"),
+        Index("idx_channel_delivery_cursors_channel", "channel"),
+    )
+
+
+class ChannelMessageClaim(Base):
+    """Channel message claims (competing mode: who is processing what)."""
+
+    __tablename__ = "channel_message_claims"
+
+    message_id: Mapped[str] = mapped_column(String(255), primary_key=True)
+    instance_id: Mapped[str] = mapped_column(String(255))
+    claimed_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["message_id"],
+            ["channel_messages.message_id"],
+            ondelete="CASCADE",
+        ),
+        ForeignKeyConstraint(
+            ["instance_id"],
+            ["workflow_instances.instance_id"],
+            ondelete="CASCADE",
+        ),
+        Index("idx_channel_message_claims_instance", "instance_id"),
+    )
+
+
+# =============================================================================
+# System-level Lock Models (for background task coordination)
+# =============================================================================
+
+
+class SystemLock(Base):
+    """System-level locks for coordinating background tasks across pods.
+
+    Used to prevent duplicate execution of operational tasks like:
+    - cleanup_stale_locks_periodically()
+    - auto_resume_stale_workflows_periodically()
+    - _cleanup_old_messages_periodically()
+    """
+
+    __tablename__ = "system_locks"
+
+    lock_name: Mapped[str] = mapped_column(String(255), primary_key=True)
+    locked_by: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    locked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+    lock_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+
+    __table_args__ = (Index("idx_system_locks_expires", "lock_expires_at"),)
+
+
 # Current schema version
 CURRENT_SCHEMA_VERSION = 1
 
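
The channel tables split delivery state by mode: a broadcast subscriber advances its own row in `channel_delivery_cursors`, while competing subscribers race to insert into `channel_message_claims`. A hypothetical sketch of the two access patterns; the SQL is illustrative, not lifted from edda/channels.py:

```python
# Broadcast: each subscriber reads everything after its own cursor, so every
# subscriber eventually sees every message on the channel.
BROADCAST_FETCH = """
SELECT m.*
FROM channel_messages m
JOIN channel_delivery_cursors c
  ON c.channel = m.channel AND c.instance_id = :instance_id
WHERE m.channel = :channel AND m.id > c.last_delivered_id
ORDER BY m.id
"""

# Competing: message_id is the primary key of the claims table, so the unique
# constraint guarantees exactly one subscriber wins each message.
COMPETING_CLAIM = """
INSERT INTO channel_message_claims (message_id, instance_id)
VALUES (:message_id, :instance_id)
"""
```
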
@@ -302,6 +483,9 @@ class TransactionContext:
     session: "AsyncSession | None" = None
     """The actual session for this transaction"""
 
+    post_commit_callbacks: list[Callable[[], Awaitable[None]]] = field(default_factory=list)
+    """Callbacks to execute after successful top-level commit"""
+
 
 # Context variable for transaction state (asyncio-safe)
 _transaction_context: ContextVar[TransactionContext | None] = ContextVar(
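
The new `post_commit_callbacks` field defers side effects (such as waking message pollers) until the outer transaction is known to be durable, so a rollback can never leak them. A generic sketch of the pattern, not the Edda API:

```python
import asyncio
from collections.abc import Awaitable, Callable

pending: list[Callable[[], Awaitable[None]]] = []


async def commit_and_flush() -> None:
    # ... commit the database transaction here; a rollback path would just
    # call pending.clear() instead ...
    callbacks = pending.copy()
    pending.clear()
    for cb in callbacks:
        try:
            await cb()  # runs only once the commit is durable
        except Exception as exc:  # a failed callback must not undo the commit
            print(f"post-commit callback failed: {exc}")


async def main() -> None:
    async def notify() -> None:
        print("transaction committed; safe to notify other workers")

    pending.append(notify)
    await commit_and_flush()


asyncio.run(main())
```
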
@@ -337,11 +521,22 @@ class SQLAlchemyStorage:
         self.engine = engine
 
     async def initialize(self) -> None:
-        """Initialize database connection and create tables."""
+        """Initialize database connection and create tables.
+
+        This method creates all tables if they don't exist, and then performs
+        automatic schema migration to add any missing columns and update CHECK
+        constraints. This ensures compatibility when upgrading Edda versions.
+        """
         # Create all tables and indexes
         async with self.engine.begin() as conn:
             await conn.run_sync(Base.metadata.create_all)
 
+        # Auto-migrate schema (add missing columns)
+        await self._auto_migrate_schema()
+
+        # Auto-migrate CHECK constraints
+        await self._auto_migrate_check_constraints()
+
         # Initialize schema version
         await self._initialize_schema_version()
 
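
With the auto-migration folded into `initialize()`, an upgrade is a single call at service startup. A sketch of the wiring, assuming the constructor takes an `AsyncEngine` (as the `self.engine = engine` assignment above suggests); the database URL is illustrative:

```python
import asyncio

from sqlalchemy.ext.asyncio import create_async_engine

from edda.storage.sqlalchemy_storage import SQLAlchemyStorage


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///edda.db")
    storage = SQLAlchemyStorage(engine)
    # create_all for new tables, then column + CHECK-constraint auto-migration
    await storage.initialize()


asyncio.run(main())
```
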
@@ -366,6 +561,198 @@ class SQLAlchemyStorage:
             await session.commit()
             logger.info(f"Initialized schema version to {CURRENT_SCHEMA_VERSION}")
 
+    async def _auto_migrate_schema(self) -> None:
+        """
+        Automatically add missing columns to existing tables.
+
+        This method compares the ORM model definitions with the actual database
+        schema and adds any missing columns using ALTER TABLE ADD COLUMN.
+
+        Note: This only handles column additions, not removals or type changes.
+        For complex migrations, use Alembic.
+        """
+
+        def _get_column_type_sql(column: Column, dialect_name: str) -> str:  # type: ignore[type-arg]
+            """Get SQL type string for a column based on dialect."""
+            col_type = column.type
+
+            # Map SQLAlchemy types to SQL types
+            if isinstance(col_type, String):
+                length = col_type.length or 255
+                return f"VARCHAR({length})"
+            elif isinstance(col_type, Text):
+                return "TEXT"
+            elif isinstance(col_type, Integer):
+                return "INTEGER"
+            elif isinstance(col_type, DateTime):
+                if dialect_name == "postgresql":
+                    return "TIMESTAMP WITH TIME ZONE" if col_type.timezone else "TIMESTAMP"
+                elif dialect_name == "mysql":
+                    return "DATETIME" if not col_type.timezone else "DATETIME"
+                else:  # sqlite
+                    return "DATETIME"
+            elif isinstance(col_type, LargeBinary):
+                if dialect_name == "postgresql":
+                    return "BYTEA"
+                elif dialect_name == "mysql":
+                    return "LONGBLOB"
+                else:  # sqlite
+                    return "BLOB"
+            else:
+                # Fallback to compiled type
+                return str(col_type.compile(dialect=self.engine.dialect))
+
+        def _get_default_sql(column: Column, _dialect_name: str) -> str | None:  # type: ignore[type-arg]
+            """Get DEFAULT clause for a column if applicable."""
+            if column.server_default is not None:
+                # Handle text() server defaults - try to get the arg attribute
+                server_default = column.server_default
+                if hasattr(server_default, "arg"):
+                    default_val = server_default.arg
+                    if hasattr(default_val, "text"):
+                        return f"DEFAULT {default_val.text}"
+                    return f"DEFAULT {default_val}"
+            return None
+
+        def _run_migration(conn: Any) -> None:
+            """Run migration in sync context."""
+            dialect_name = self.engine.dialect.name
+            inspector = inspect(conn)
+
+            # Iterate through all ORM tables
+            for table in Base.metadata.tables.values():
+                table_name = table.name
+
+                # Check if table exists
+                if not inspector.has_table(table_name):
+                    logger.debug(f"Table {table_name} does not exist, skipping migration")
+                    continue
+
+                # Get existing columns
+                existing_columns = {col["name"] for col in inspector.get_columns(table_name)}
+
+                # Check each column in the ORM model
+                for column in table.columns:
+                    if column.name not in existing_columns:
+                        # Column is missing, generate ALTER TABLE
+                        col_type_sql = _get_column_type_sql(column, dialect_name)
+                        nullable = "NULL" if column.nullable else "NOT NULL"
+
+                        # Build ALTER TABLE statement
+                        alter_sql = (
+                            f'ALTER TABLE "{table_name}" ADD COLUMN "{column.name}" {col_type_sql}'
+                        )
+
+                        # Add nullable constraint (only if NOT NULL and has default)
+                        default_sql = _get_default_sql(column, dialect_name)
+                        if not column.nullable and default_sql:
+                            alter_sql += f" {default_sql} {nullable}"
+                        elif column.nullable:
+                            alter_sql += f" {nullable}"
+                        elif default_sql:
+                            alter_sql += f" {default_sql}"
+                        # For NOT NULL without default, just add the column as nullable
+                        # (PostgreSQL requires default or nullable for existing rows)
+                        else:
+                            alter_sql += " NULL"
+
+                        logger.info(f"Auto-migrating: Adding column {column.name} to {table_name}")
+                        logger.debug(f"Executing: {alter_sql}")
+
+                        try:
+                            conn.execute(text(alter_sql))
+                        except Exception as e:
+                            logger.warning(
+                                f"Failed to add column {column.name} to {table_name}: {e}"
+                            )
+
+        async with self.engine.begin() as conn:
+            await conn.run_sync(_run_migration)
+
+    async def _auto_migrate_check_constraints(self) -> None:
+        """
+        Automatically update CHECK constraints for workflow status.
+
+        This method ensures the valid_status CHECK constraint includes all
+        required status values (including 'waiting_for_message').
+        """
+        dialect_name = self.engine.dialect.name
+
+        # SQLite doesn't support ALTER CONSTRAINT easily, and SQLAlchemy create_all
+        # handles it correctly for new databases. For existing SQLite databases,
+        # the constraint is more lenient (CHECK is not enforced in many SQLite versions).
+        if dialect_name == "sqlite":
+            return
+
+        # Expected status values (must match WorkflowInstance model)
+        expected_statuses = (
+            "'running', 'completed', 'failed', 'waiting_for_event', "
+            "'waiting_for_timer', 'waiting_for_message', 'compensating', 'cancelled', 'recurred'"
+        )
+
+        def _run_constraint_migration(conn: Any) -> None:
+            """Run CHECK constraint migration in sync context."""
+            inspector = inspect(conn)
+
+            # Check if workflow_instances table exists
+            if not inspector.has_table("workflow_instances"):
+                return
+
+            # Get existing CHECK constraints
+            try:
+                constraints = inspector.get_check_constraints("workflow_instances")
+            except NotImplementedError:
+                # Some databases don't support get_check_constraints
+                logger.debug("Database doesn't support get_check_constraints inspection")
+                constraints = []
+
+            # Find the valid_status constraint
+            valid_status_constraint = None
+            for constraint in constraints:
+                if constraint.get("name") == "valid_status":
+                    valid_status_constraint = constraint
+                    break
+
+            # Check if constraint exists and needs updating
+            if valid_status_constraint:
+                sqltext = valid_status_constraint.get("sqltext", "")
+                # Check if 'waiting_for_message' is already in the constraint
+                if "waiting_for_message" in sqltext:
+                    logger.debug("valid_status constraint already includes waiting_for_message")
+                    return
+
+                # Need to update the constraint
+                logger.info("Updating valid_status CHECK constraint to include waiting_for_message")
+                try:
+                    if dialect_name == "postgresql":
+                        conn.execute(
+                            text("ALTER TABLE workflow_instances DROP CONSTRAINT valid_status")
+                        )
+                        conn.execute(
+                            text(
+                                f"ALTER TABLE workflow_instances ADD CONSTRAINT valid_status "
+                                f"CHECK (status IN ({expected_statuses}))"
+                            )
+                        )
+                    elif dialect_name == "mysql":
+                        # MySQL uses DROP CHECK and ADD CONSTRAINT CHECK syntax
+                        conn.execute(text("ALTER TABLE workflow_instances DROP CHECK valid_status"))
+                        conn.execute(
+                            text(
+                                f"ALTER TABLE workflow_instances ADD CONSTRAINT valid_status "
+                                f"CHECK (status IN ({expected_statuses}))"
+                            )
+                        )
+                    logger.info("Successfully updated valid_status CHECK constraint")
+                except Exception as e:
+                    logger.warning(f"Failed to update valid_status CHECK constraint: {e}")
+            else:
+                # Constraint doesn't exist (shouldn't happen with create_all, but handle it)
+                logger.debug("valid_status constraint not found, will be created by create_all")
+
+        async with self.engine.begin() as conn:
+            await conn.run_sync(_run_constraint_migration)
+
     def _get_session_for_operation(self, is_lock_operation: bool = False) -> AsyncSession:
         """
         Get the appropriate session for an operation.
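
The migration above walks `Base.metadata` and compares it with the live schema through SQLAlchemy's runtime inspection. A stand-alone illustration of that inspection pattern, using a sync engine and an in-memory SQLite database for brevity (the real code runs the same logic via `conn.run_sync` on the async engine):

```python
from sqlalchemy import Column, Integer, MetaData, Table, create_engine, inspect, text

engine = create_engine("sqlite:///:memory:")
metadata = MetaData()
demo = Table("demo", metadata, Column("id", Integer, primary_key=True))
metadata.create_all(engine)

# Pretend the model later gained a column the database doesn't have yet.
demo.append_column(Column("retries", Integer, nullable=True))

with engine.begin() as conn:
    existing = {col["name"] for col in inspect(conn).get_columns("demo")}
    for col in demo.columns:
        if col.name not in existing:
            # Same shape of statement the method above generates per dialect.
            conn.execute(text(f'ALTER TABLE "demo" ADD COLUMN "{col.name}" INTEGER NULL'))
```
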
@@ -452,7 +839,7 @@ class SQLAlchemyStorage:
         Example:
             >>> # SQLite: datetime(timeout_at) <= datetime('now')
             >>> # PostgreSQL/MySQL: timeout_at <= NOW()
-            >>> self._make_datetime_comparable(
+            >>> self._make_datetime_comparable(ChannelSubscription.timeout_at)
             >>> <= self._get_current_time_expr()
         """
         if self.engine.dialect.name == "sqlite":
@@ -504,11 +891,16 @@ class SQLAlchemyStorage:
         if ctx is None or ctx.depth == 0:
             raise RuntimeError("Not in a transaction")
 
+        # Capture callbacks before any state changes
+        callbacks_to_execute: list[Callable[[], Awaitable[None]]] = []
+
         if ctx.depth == 1:
             # Top-level transaction - commit the session
             logger.debug("Committing top-level transaction")
             await ctx.session.commit()  # type: ignore[union-attr]
             await ctx.session.close()  # type: ignore[union-attr]
+            # Capture callbacks to execute after clearing context
+            callbacks_to_execute = ctx.post_commit_callbacks.copy()
         else:
             # Nested transaction - commit the savepoint
             nested_tx = ctx.savepoint_stack.pop()
@@ -520,6 +912,12 @@ class SQLAlchemyStorage:
         if ctx.depth == 0:
             # All transactions completed - clear context
             _transaction_context.set(None)
+            # Execute post-commit callbacks after successful top-level commit
+            for callback in callbacks_to_execute:
+                try:
+                    await callback()
+                except Exception as e:
+                    logger.error(f"Post-commit callback failed: {e}")
 
     async def rollback_transaction(self) -> None:
         """
@@ -559,6 +957,26 @@ class SQLAlchemyStorage:
         ctx = _transaction_context.get()
         return ctx is not None and ctx.depth > 0
 
+    def register_post_commit_callback(self, callback: Callable[[], Awaitable[None]]) -> None:
+        """
+        Register a callback to be executed after the current transaction commits.
+
+        The callback will be executed after the top-level transaction commits successfully.
+        If the transaction is rolled back, the callback will NOT be executed.
+        If not in a transaction, the callback will be executed immediately.
+
+        Args:
+            callback: An async function to call after commit.
+
+        Raises:
+            RuntimeError: If not in a transaction.
+        """
+        ctx = _transaction_context.get()
+        if ctx is None or ctx.depth == 0:
+            raise RuntimeError("Cannot register post-commit callback: not in a transaction")
+        ctx.post_commit_callbacks.append(callback)
+        logger.debug(f"Registered post-commit callback: {callback}")
+
     async def _commit_if_not_in_transaction(self, session: AsyncSession) -> None:
         """
         Commit session if not in a transaction (auto-commit mode).
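
Hypothetical usage of `register_post_commit_callback`; `begin_transaction` is assumed here as the counterpart of the `commit_transaction` and `rollback_transaction` methods visible in this diff:

```python
async def notify() -> None:
    print("rows are durably committed; safe to publish follow-up work")


async def do_work(storage) -> None:
    await storage.begin_transaction()  # assumed API, not shown in this diff
    try:
        # ... write workflow rows through the shared transaction session ...
        storage.register_post_commit_callback(notify)
        await storage.commit_transaction()  # notify() runs only after this commit
    except Exception:
        await storage.rollback_transaction()  # notify() is dropped on rollback
        raise
```
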
@@ -605,7 +1023,7 @@ class SQLAlchemyStorage:
 
             if existing:
                 # Update
-                existing.source_code = source_code
+                existing.source_code = source_code
             else:
                 # Insert
                 definition = WorkflowDefinition(
@@ -682,6 +1100,7 @@ class SQLAlchemyStorage:
         owner_service: str,
         input_data: dict[str, Any],
         lock_timeout_seconds: int | None = None,
+        continued_from: str | None = None,
     ) -> None:
         """Create a new workflow instance."""
         session = self._get_session_for_operation()
@@ -693,6 +1112,7 @@ class SQLAlchemyStorage:
                 owner_service=owner_service,
                 input_data=json.dumps(input_data),
                 lock_timeout_seconds=lock_timeout_seconds,
+                continued_from=continued_from,
             )
             session.add(instance)
 
@@ -1165,6 +1585,103 @@ class SQLAlchemyStorage:
             await session.commit()
             return workflows_to_resume
 
+    # -------------------------------------------------------------------------
+    # System-level Locking Methods (for background task coordination)
+    # -------------------------------------------------------------------------
+
+    async def try_acquire_system_lock(
+        self,
+        lock_name: str,
+        worker_id: str,
+        timeout_seconds: int = 60,
+    ) -> bool:
+        """
+        Try to acquire a system-level lock for coordinating background tasks.
+
+        Uses INSERT ON CONFLICT pattern to handle race conditions:
+        1. Try to INSERT new lock record
+        2. If exists, check if expired or unlocked
+        3. If available, acquire lock; otherwise return False
+
+        Note: ALWAYS uses separate session (not external session).
+        """
+        session = self._get_session_for_operation(is_lock_operation=True)
+        async with self._session_scope(session) as session:
+            current_time = datetime.now(UTC)
+            lock_expires_at = current_time + timedelta(seconds=timeout_seconds)
+
+            # Try to get existing lock
+            result = await session.execute(
+                select(SystemLock).where(SystemLock.lock_name == lock_name)
+            )
+            lock = result.scalar_one_or_none()
+
+            if lock is None:
+                # No lock exists - create new one
+                lock = SystemLock(
+                    lock_name=lock_name,
+                    locked_by=worker_id,
+                    locked_at=current_time,
+                    lock_expires_at=lock_expires_at,
+                )
+                session.add(lock)
+                await session.commit()
+                return True
+
+            # Lock exists - check if available
+            if lock.locked_by is None:
+                # Unlocked - acquire
+                lock.locked_by = worker_id
+                lock.locked_at = current_time
+                lock.lock_expires_at = lock_expires_at
+                await session.commit()
+                return True
+
+            # Check if expired (use SQL-side comparison for cross-DB compatibility)
+            if lock.lock_expires_at is not None:
+                # Handle timezone-naive datetime from SQLite
+                lock_expires = (
+                    lock.lock_expires_at.replace(tzinfo=UTC)
+                    if lock.lock_expires_at.tzinfo is None
+                    else lock.lock_expires_at
+                )
+                if lock_expires <= current_time:
+                    # Expired - acquire
+                    lock.locked_by = worker_id
+                    lock.locked_at = current_time
+                    lock.lock_expires_at = lock_expires_at
+                    await session.commit()
+                    return True
+
+            # Already locked by another worker
+            return False
+
+    async def release_system_lock(self, lock_name: str, worker_id: str) -> None:
+        """
+        Release a system-level lock.
+
+        Only releases the lock if it's held by the specified worker.
+
+        Note: ALWAYS uses separate session (not external session).
+        """
+        session = self._get_session_for_operation(is_lock_operation=True)
+        async with self._session_scope(session) as session:
+            await session.execute(
+                update(SystemLock)
+                .where(
+                    and_(
+                        SystemLock.lock_name == lock_name,
+                        SystemLock.locked_by == worker_id,
+                    )
+                )
+                .values(
+                    locked_by=None,
+                    locked_at=None,
+                    lock_expires_at=None,
+                )
+            )
+            await session.commit()
+
     # -------------------------------------------------------------------------
     # History Methods (prefer external session)
     # -------------------------------------------------------------------------
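
A sketch of how each pod's periodic task might wrap itself in the system lock so only one pod does the work per cycle; the lock name and intervals are illustrative:

```python
import asyncio


async def cleanup_loop(storage, worker_id: str) -> None:
    while True:
        acquired = await storage.try_acquire_system_lock(
            "cleanup_stale_locks", worker_id, timeout_seconds=60
        )
        if acquired:
            try:
                ...  # do the actual cleanup work here
            finally:
                await storage.release_system_lock("cleanup_stale_locks", worker_id)
        # Pods that lost the race (or finished) just sleep and try again;
        # an expired lock from a crashed pod is reclaimed automatically.
        await asyncio.sleep(30)
```
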
@@ -1231,6 +1748,107 @@ class SQLAlchemyStorage:
                 for row in rows
             ]
 
+    async def archive_history(self, instance_id: str) -> int:
+        """
+        Archive workflow history for the recur pattern.
+
+        Moves all history entries from workflow_history to workflow_history_archive.
+        Binary data is converted to base64 for JSON storage in the archive.
+
+        Returns:
+            Number of history entries archived
+        """
+        import base64
+
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Get all history entries for this instance
+            result = await session.execute(
+                select(WorkflowHistory)
+                .where(WorkflowHistory.instance_id == instance_id)
+                .order_by(WorkflowHistory.created_at.asc())
+            )
+            history_rows = result.scalars().all()
+
+            if not history_rows:
+                return 0
+
+            # Archive each history entry
+            for row in history_rows:
+                # Convert event_data to JSON string for archive
+                event_data_json: str | None
+                if row.data_type == "binary" and row.event_data_binary is not None:
+                    # Convert binary to base64 for JSON storage
+                    event_data_json = json.dumps(
+                        {
+                            "_binary": True,
+                            "data": base64.b64encode(row.event_data_binary).decode("ascii"),
+                        }
+                    )
+                else:
+                    # Already JSON, use as-is
+                    event_data_json = row.event_data
+
+                archive_entry = WorkflowHistoryArchive(
+                    instance_id=row.instance_id,
+                    activity_id=row.activity_id,
+                    event_type=row.event_type,
+                    event_data=event_data_json,
+                    created_at=row.created_at,
+                )
+                session.add(archive_entry)
+
+            # Delete original history entries
+            await session.execute(
+                delete(WorkflowHistory).where(WorkflowHistory.instance_id == instance_id)
+            )
+
+            await self._commit_if_not_in_transaction(session)
+            return len(history_rows)
+
+    async def find_first_cancellation_event(self, instance_id: str) -> dict[str, Any] | None:
+        """
+        Find the first cancellation event in workflow history.
+
+        Uses LIMIT 1 optimization to avoid loading all history events.
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Query for cancellation events using LIMIT 1
+            result = await session.execute(
+                select(WorkflowHistory)
+                .where(
+                    and_(
+                        WorkflowHistory.instance_id == instance_id,
+                        or_(
+                            WorkflowHistory.event_type == "WorkflowCancelled",
+                            func.lower(WorkflowHistory.event_type).contains("cancel"),
+                        ),
+                    )
+                )
+                .order_by(WorkflowHistory.created_at.asc())
+                .limit(1)
+            )
+            row = result.scalars().first()
+
+            if row is None:
+                return None
+
+            # Parse event_data based on data_type
+            if row.data_type == "binary" and row.event_data_binary is not None:
+                event_data: dict[str, Any] | bytes = row.event_data_binary
+            else:
+                event_data = json.loads(row.event_data) if row.event_data else {}
+
+            return {
+                "id": row.id,
+                "instance_id": row.instance_id,
+                "activity_id": row.activity_id,
+                "event_type": row.event_type,
+                "event_data": event_data,
+                "created_at": row.created_at,
+            }
+
     # -------------------------------------------------------------------------
     # Compensation Methods (prefer external session)
     # -------------------------------------------------------------------------
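
`archive_history` stores binary events as a small base64 JSON envelope so the archive table can stay text-only. A reader can reverse it like this sketch (the `_binary`/`data` keys match the method above):

```python
import base64
import json
from typing import Any


def decode_archived_event(event_data: str) -> bytes | Any:
    payload = json.loads(event_data)
    if isinstance(payload, dict) and payload.get("_binary"):
        return base64.b64decode(payload["data"])  # original raw bytes
    return payload  # ordinary JSON event, stored as-is


envelope = json.dumps(
    {"_binary": True, "data": base64.b64encode(b"\x00\x01").decode("ascii")}
)
assert decode_archived_event(envelope) == b"\x00\x01"
```
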
@@ -1271,7 +1889,7 @@ class SQLAlchemyStorage:
                     "instance_id": row.instance_id,
                     "activity_id": row.activity_id,
                     "activity_name": row.activity_name,
-                    "args": json.loads(row.args),
+                    "args": json.loads(row.args) if row.args else [],
                     "created_at": row.created_at.isoformat(),
                 }
                 for row in rows
@@ -1287,224 +1905,25 @@ class SQLAlchemyStorage:
             await self._commit_if_not_in_transaction(session)
 
     # -------------------------------------------------------------------------
-    #
+    # Timer Subscription Methods
     # -------------------------------------------------------------------------
 
-    async def
+    async def register_timer_subscription_and_release_lock(
         self,
         instance_id: str,
-
-
+        worker_id: str,
+        timer_id: str,
+        expires_at: datetime,
+        activity_id: str | None = None,
     ) -> None:
-        """
-
-        async with self._session_scope(session) as session:
-            subscription = WorkflowEventSubscription(
-                instance_id=instance_id,
-                event_type=event_type,
-                timeout_at=timeout_at,
-            )
-            session.add(subscription)
-            await self._commit_if_not_in_transaction(session)
+        """
+        Atomically register timer subscription and release workflow lock.
 
-
-
-
-
-
-                select(WorkflowEventSubscription).where(
-                    and_(
-                        WorkflowEventSubscription.event_type == event_type,
-                        or_(
-                            WorkflowEventSubscription.timeout_at.is_(None),
-                            self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
-                            > self._get_current_time_expr(),
-                        ),
-                    )
-                )
-            )
-            rows = result.scalars().all()
-
-            return [
-                {
-                    "id": row.id,
-                    "instance_id": row.instance_id,
-                    "event_type": row.event_type,
-                    "activity_id": row.activity_id,
-                    "timeout_at": row.timeout_at.isoformat() if row.timeout_at else None,
-                    "created_at": row.created_at.isoformat(),
-                }
-                for row in rows
-            ]
-
-    async def remove_event_subscription(
-        self,
-        instance_id: str,
-        event_type: str,
-    ) -> None:
-        """Remove event subscription after the event is received."""
-        session = self._get_session_for_operation()
-        async with self._session_scope(session) as session:
-            await session.execute(
-                delete(WorkflowEventSubscription).where(
-                    and_(
-                        WorkflowEventSubscription.instance_id == instance_id,
-                        WorkflowEventSubscription.event_type == event_type,
-                    )
-                )
-            )
-            await self._commit_if_not_in_transaction(session)
-
-    async def cleanup_expired_subscriptions(self) -> int:
-        """Clean up event subscriptions that have timed out."""
-        session = self._get_session_for_operation()
-        async with self._session_scope(session) as session:
-            result = await session.execute(
-                delete(WorkflowEventSubscription).where(
-                    and_(
-                        WorkflowEventSubscription.timeout_at.isnot(None),
-                        self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
-                        <= self._get_current_time_expr(),
-                    )
-                )
-            )
-            await self._commit_if_not_in_transaction(session)
-            return result.rowcount or 0  # type: ignore[attr-defined]
-
-    async def find_expired_event_subscriptions(self) -> list[dict[str, Any]]:
-        """Find event subscriptions that have timed out."""
-        session = self._get_session_for_operation()
-        async with self._session_scope(session) as session:
-            result = await session.execute(
-                select(
-                    WorkflowEventSubscription.instance_id,
-                    WorkflowEventSubscription.event_type,
-                    WorkflowEventSubscription.activity_id,
-                    WorkflowEventSubscription.timeout_at,
-                    WorkflowEventSubscription.created_at,
-                ).where(
-                    and_(
-                        WorkflowEventSubscription.timeout_at.isnot(None),
-                        self._make_datetime_comparable(WorkflowEventSubscription.timeout_at)
-                        <= self._get_current_time_expr(),
-                    )
-                )
-            )
-            rows = result.all()
-
-            return [
-                {
-                    "instance_id": row[0],
-                    "event_type": row[1],
-                    "activity_id": row[2],
-                    "timeout_at": row[3].isoformat() if row[3] else None,
-                    "created_at": row[4].isoformat() if row[4] else None,
-                }
-                for row in rows
-            ]
-
-    async def register_event_subscription_and_release_lock(
-        self,
-        instance_id: str,
-        worker_id: str,
-        event_type: str,
-        timeout_at: datetime | None = None,
-        activity_id: str | None = None,
-    ) -> None:
-        """
-        Atomically register event subscription and release workflow lock.
-
-        This performs THREE operations in a SINGLE transaction:
-        1. Register event subscription
-        2. Update current activity
-        3. Release lock
-
-        This ensures distributed coroutines work correctly - when a workflow
-        calls wait_event(), the subscription is registered and lock is released
-        atomically, so ANY worker can resume the workflow when the event arrives.
-
-        Note: Uses LOCK operation session (separate from external session).
-        """
-        session = self._get_session_for_operation(is_lock_operation=True)
-        async with self._session_scope(session) as session, session.begin():
-            # 1. Verify we hold the lock (sanity check)
-            result = await session.execute(
-                select(WorkflowInstance.locked_by).where(
-                    WorkflowInstance.instance_id == instance_id
-                )
-            )
-            row = result.one_or_none()
-
-            if row is None:
-                raise RuntimeError(f"Workflow instance {instance_id} not found")
-
-            current_lock_holder = row[0]
-            if current_lock_holder != worker_id:
-                raise RuntimeError(
-                    f"Cannot release lock: worker {worker_id} does not hold lock "
-                    f"for {instance_id} (held by: {current_lock_holder})"
-                )
-
-            # 2. Register event subscription (INSERT OR REPLACE equivalent)
-            # First delete existing
-            await session.execute(
-                delete(WorkflowEventSubscription).where(
-                    and_(
-                        WorkflowEventSubscription.instance_id == instance_id,
-                        WorkflowEventSubscription.event_type == event_type,
-                    )
-                )
-            )
-
-            # Then insert new
-            subscription = WorkflowEventSubscription(
-                instance_id=instance_id,
-                event_type=event_type,
-                activity_id=activity_id,
-                timeout_at=timeout_at,
-            )
-            session.add(subscription)
-
-            # 3. Update current activity (if provided)
-            if activity_id is not None:
-                await session.execute(
-                    update(WorkflowInstance)
-                    .where(WorkflowInstance.instance_id == instance_id)
-                    .values(current_activity_id=activity_id, updated_at=func.now())
-                )
-
-            # 4. Release lock
-            await session.execute(
-                update(WorkflowInstance)
-                .where(
-                    and_(
-                        WorkflowInstance.instance_id == instance_id,
-                        WorkflowInstance.locked_by == worker_id,
-                    )
-                )
-                .values(
-                    locked_by=None,
-                    locked_at=None,
-                    updated_at=func.now(),
-                )
-            )
-
-    async def register_timer_subscription_and_release_lock(
-        self,
-        instance_id: str,
-        worker_id: str,
-        timer_id: str,
-        expires_at: datetime,
-        activity_id: str | None = None,
-    ) -> None:
-        """
-        Atomically register timer subscription and release workflow lock.
-
-        This performs FOUR operations in a SINGLE transaction:
-        1. Register timer subscription
-        2. Update current activity
-        3. Update status to 'waiting_for_timer'
-        4. Release lock
+        This performs FOUR operations in a SINGLE transaction:
+        1. Register timer subscription
+        2. Update current activity
+        3. Update status to 'waiting_for_timer'
+        4. Release lock
 
         This ensures distributed coroutines work correctly - when a workflow
         calls wait_timer(), the subscription is registered and lock is released
@@ -1580,7 +1999,11 @@ class SQLAlchemyStorage:
         )
 
     async def find_expired_timers(self) -> list[dict[str, Any]]:
-        """Find timer subscriptions that have expired."""
+        """Find timer subscriptions that have expired.
+
+        Returns timer info including workflow status to avoid N+1 queries.
+        The SQL query already filters by status='waiting_for_timer'.
+        """
         session = self._get_session_for_operation()
         async with self._session_scope(session) as session:
             result = await session.execute(
@@ -1590,6 +2013,7 @@ class SQLAlchemyStorage:
|
|
|
1590
2013
|
WorkflowTimerSubscription.expires_at,
|
|
1591
2014
|
WorkflowTimerSubscription.activity_id,
|
|
1592
2015
|
WorkflowInstance.workflow_name,
|
|
2016
|
+
WorkflowInstance.status, # Include status to avoid N+1 query
|
|
1593
2017
|
)
|
|
1594
2018
|
.join(
|
|
1595
2019
|
WorkflowInstance,
|
|
@@ -1612,6 +2036,7 @@ class SQLAlchemyStorage:
|
|
|
1612
2036
|
"expires_at": row[2].isoformat(),
|
|
1613
2037
|
"activity_id": row[3],
|
|
1614
2038
|
"workflow_name": row[4],
|
|
2039
|
+
"status": row[5], # Always 'waiting_for_timer' due to WHERE clause
|
|
1615
2040
|
}
|
|
1616
2041
|
for row in rows
|
|
1617
2042
|
]
|
|
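
The added `status` column is what removes the follow-up query per expired timer. A minimal sketch of the consuming side, under assumptions: the poller function and the `storage` object are hypothetical, and the `instance_id`/`workflow_name` keys are taken from the return shape shown in this hunk.

```python
# Hypothetical poller sketch: everything it needs arrives in one SELECT,
# instead of one extra status lookup per expired timer (the old N+1 shape).
async def fire_expired_timers(storage) -> None:
    for timer in await storage.find_expired_timers():
        # "status" is already in the row; the SQL guarantees 'waiting_for_timer'.
        print(timer["instance_id"], timer["workflow_name"], timer["status"])
```
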
@@ -1831,12 +2256,17 @@ class SQLAlchemyStorage:
         Only running or waiting_for_event workflows can be cancelled.
         This method atomically:
         1. Checks current status
-        2. Updates status to 'cancelled' if allowed
+        2. Updates status to 'cancelled' if allowed (with atomic status check)
         3. Clears locks
         4. Records cancellation metadata
         5. Removes event subscriptions (if waiting for event)
         6. Removes timer subscriptions (if waiting for timer)
 
+        The UPDATE includes a status condition in WHERE clause to prevent
+        TOCTOU (time-of-check to time-of-use) race conditions. If the status
+        changes between SELECT and UPDATE, the UPDATE will affect 0 rows
+        and the cancellation will fail safely.
+
         Args:
             instance_id: Workflow instance to cancel
             cancelled_by: Who/what triggered the cancellation
@@ -1846,6 +2276,14 @@ class SQLAlchemyStorage:
 
         Note: Uses LOCK operation session (separate from external session).
         """
+        cancellable_statuses = (
+            "running",
+            "waiting_for_event",
+            "waiting_for_timer",
+            "waiting_for_message",
+            "compensating",
+        )
+
         session = self._get_session_for_operation(is_lock_operation=True)
         async with self._session_scope(session) as session, session.begin():
             # Get current instance status
@@ -1862,41 +2300,42 @@ class SQLAlchemyStorage:
 
             # Only allow cancellation of running, waiting, or compensating workflows
             # compensating workflows can be marked as cancelled after compensation completes
-            if current_status not in (
-                "running",
-                "waiting_for_event",
-                "waiting_for_timer",
-                "compensating",
-            ):
+            if current_status not in cancellable_statuses:
                 # Already completed, failed, or cancelled
                 return False
 
             # Update status to cancelled and record metadata
+            # IMPORTANT: Include status condition in WHERE clause to prevent TOCTOU race
+            # If another worker changed the status between SELECT and UPDATE,
+            # this UPDATE will affect 0 rows and we'll return False
             cancellation_metadata = {
                 "cancelled_by": cancelled_by,
                 "cancelled_at": datetime.now(UTC).isoformat(),
                 "previous_status": current_status,
             }
 
-            await session.execute(
+            update_result = await session.execute(
                 update(WorkflowInstance)
-                .where(
+                .where(
+                    and_(
+                        WorkflowInstance.instance_id == instance_id,
+                        WorkflowInstance.status == current_status,  # Atomic check
+                    )
+                )
                 .values(
                     status="cancelled",
                     output_data=json.dumps(cancellation_metadata),
                     locked_by=None,
                     locked_at=None,
+                    lock_expires_at=None,
                     updated_at=func.now(),
                 )
             )
 
-            # Remove event subscriptions if waiting for event
-            if current_status == "waiting_for_event":
-                await session.execute(
-                    delete(WorkflowEventSubscription).where(
-                        WorkflowEventSubscription.instance_id == instance_id
-                    )
-                )
+            if update_result.rowcount == 0:  # type: ignore[attr-defined]
+                # Status changed between SELECT and UPDATE (race condition)
+                # Another worker may have resumed/modified the workflow
+                return False
 
             # Remove timer subscriptions if waiting for timer
             if current_status == "waiting_for_timer":
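
The TOCTOU fix in this hunk is a plain compare-and-set expressed in SQL. A standalone sketch of the pattern, under assumptions: the import path and session wiring are hypothetical; only the WHERE-clause re-check and the `rowcount` test mirror the change above.

```python
# Compare-and-set UPDATE: re-verify the status you read inside the UPDATE itself.
from sqlalchemy import and_, update
from sqlalchemy.ext.asyncio import AsyncSession

from edda.storage.sqlalchemy_storage import WorkflowInstance  # assumed import path

async def cas_cancel(session: AsyncSession, instance_id: str, seen_status: str) -> bool:
    result = await session.execute(
        update(WorkflowInstance)
        .where(
            and_(
                WorkflowInstance.instance_id == instance_id,
                WorkflowInstance.status == seen_status,  # atomic re-check
            )
        )
        .values(status="cancelled")
    )
    # rowcount == 0 means another worker changed the status between SELECT and UPDATE.
    return result.rowcount > 0
```
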
@@ -1906,4 +2345,1103 @@ class SQLAlchemyStorage:
                 )
             )
 
+            # Clear channel subscriptions if waiting for event/message
+            if current_status in ("waiting_for_event", "waiting_for_message"):
+                await session.execute(
+                    update(ChannelSubscription)
+                    .where(ChannelSubscription.instance_id == instance_id)
+                    .values(activity_id=None, timeout_at=None)
+                )
+
             return True
+
+    # -------------------------------------------------------------------------
+    # Message Subscription Methods
+    # -------------------------------------------------------------------------
+
+    async def find_waiting_instances_by_channel(self, channel: str) -> list[dict[str, Any]]:
+        """
+        Find all workflow instances waiting on a specific channel.
+
+        Args:
+            channel: Channel name to search for
+
+        Returns:
+            List of subscription info dicts with instance_id, channel, activity_id, timeout_at
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Query ChannelSubscription table for waiting instances
+            result = await session.execute(
+                select(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.channel == channel,
+                        ChannelSubscription.activity_id.isnot(None),  # Only waiting subscriptions
+                    )
+                )
+            )
+            subscriptions = result.scalars().all()
+            return [
+                {
+                    "instance_id": sub.instance_id,
+                    "channel": sub.channel,
+                    "activity_id": sub.activity_id,
+                    "timeout_at": sub.timeout_at.isoformat() if sub.timeout_at else None,
+                    "created_at": sub.subscribed_at.isoformat() if sub.subscribed_at else None,
+                }
+                for sub in subscriptions
+            ]
+
+    async def remove_message_subscription(
+        self,
+        instance_id: str,
+        channel: str,
+    ) -> None:
+        """
+        Remove a message subscription.
+
+        This method clears waiting state from the ChannelSubscription table.
+
+        Args:
+            instance_id: Workflow instance ID
+            channel: Channel name
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Clear waiting state from ChannelSubscription table
+            # Don't delete the subscription - just clear the waiting state
+            await session.execute(
+                update(ChannelSubscription)
+                .where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+                .values(activity_id=None, timeout_at=None)
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def deliver_message(
+        self,
+        instance_id: str,
+        channel: str,
+        data: dict[str, Any] | bytes,
+        metadata: dict[str, Any],
+        worker_id: str | None = None,
+    ) -> dict[str, Any] | None:
+        """
+        Deliver a message to a waiting workflow using Lock-First pattern.
+
+        This method:
+        1. Checks if there's a subscription for this instance/channel
+        2. Acquires lock (Lock-First pattern) - if worker_id provided
+        3. Records the message in history
+        4. Removes the subscription
+        5. Updates status to 'running'
+        6. Releases lock
+
+        Args:
+            instance_id: Target workflow instance ID
+            channel: Channel name
+            data: Message payload (dict or bytes)
+            metadata: Message metadata
+            worker_id: Worker ID for locking. If None, skip locking (legacy mode).
+
+        Returns:
+            Dict with delivery info if successful, None otherwise
+        """
+        import uuid
+
+        # Step 1: Check if subscription exists (without lock)
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                        ChannelSubscription.activity_id.isnot(None),  # Only waiting subscriptions
+                    )
+                )
+            )
+            subscription = result.scalar_one_or_none()
+
+            if subscription is None:
+                return None
+
+            activity_id = subscription.activity_id
+
+        # Step 2: Acquire lock (Lock-First pattern) if worker_id provided
+        lock_acquired = False
+        if worker_id is not None:
+            lock_acquired = await self.try_acquire_lock(instance_id, worker_id)
+            if not lock_acquired:
+                # Another worker is processing this workflow
+                return None
+
+        try:
+            # Step 3-5: Deliver message atomically
+            session = self._get_session_for_operation()
+            async with self._session_scope(session) as session:
+                # Re-check subscription (may have been removed by another worker)
+                result = await session.execute(
+                    select(ChannelSubscription).where(
+                        and_(
+                            ChannelSubscription.instance_id == instance_id,
+                            ChannelSubscription.channel == channel,
+                            ChannelSubscription.activity_id.isnot(
+                                None
+                            ),  # Only waiting subscriptions
+                        )
+                    )
+                )
+                subscription = result.scalar_one_or_none()
+
+                if subscription is None:
+                    # Already delivered by another worker
+                    return None
+
+                activity_id = subscription.activity_id
+
+                # Get workflow info for return value
+                instance_result = await session.execute(
+                    select(WorkflowInstance).where(WorkflowInstance.instance_id == instance_id)
+                )
+                instance = instance_result.scalar_one_or_none()
+                workflow_name = instance.workflow_name if instance else "unknown"
+
+                # Build message data for history
+                message_id = str(uuid.uuid4())
+                message_data = {
+                    "id": message_id,
+                    "channel": channel,
+                    "data": data if isinstance(data, dict) else None,
+                    "metadata": metadata,
+                }
+
+                # Handle binary data
+                if isinstance(data, bytes):
+                    data_type = "binary"
+                    event_data_json = None
+                    event_data_binary = data
+                else:
+                    data_type = "json"
+                    event_data_json = json.dumps(message_data)
+                    event_data_binary = None
+
+                # Record in history
+                history_entry = WorkflowHistory(
+                    instance_id=instance_id,
+                    activity_id=activity_id,
+                    event_type="ChannelMessageReceived",
+                    data_type=data_type,
+                    event_data=event_data_json,
+                    event_data_binary=event_data_binary,
+                )
+                session.add(history_entry)
+
+                # Clear waiting state from subscription (don't delete)
+                await session.execute(
+                    update(ChannelSubscription)
+                    .where(
+                        and_(
+                            ChannelSubscription.instance_id == instance_id,
+                            ChannelSubscription.channel == channel,
+                        )
+                    )
+                    .values(activity_id=None, timeout_at=None)
+                )
+
+                # Update status to 'running' (ready for resumption)
+                await session.execute(
+                    update(WorkflowInstance)
+                    .where(WorkflowInstance.instance_id == instance_id)
+                    .values(status="running", updated_at=func.now())
+                )
+
+                await self._commit_if_not_in_transaction(session)
+
+                return {
+                    "instance_id": instance_id,
+                    "workflow_name": workflow_name,
+                    "activity_id": activity_id,
+                }
+
+        finally:
+            # Step 6: Release lock if we acquired it
+            if lock_acquired and worker_id is not None:
+                await self.release_lock(instance_id, worker_id)
+
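
A caller's view of the Lock-First handoff, as a minimal sketch: the router function and worker id are assumptions, while `deliver_message()` and its return shape come from the method above.

```python
# Hypothetical router sketch: a None result means "not deliverable right now"
# (no waiting subscription, or another worker holds the lock) - retry later.
async def route_one(storage, instance_id: str, channel: str, payload: dict) -> None:
    delivered = await storage.deliver_message(
        instance_id,
        channel,
        payload,
        metadata={"source": "router"},
        worker_id="worker-1",  # enables the Lock-First path
    )
    if delivered is None:
        return
    print(f"woke {delivered['workflow_name']} at {delivered['activity_id']}")
```
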
+    async def find_expired_message_subscriptions(self) -> list[dict[str, Any]]:
+        """
+        Find all message subscriptions that have timed out.
+
+        JOINs with WorkflowInstance to ensure instance exists and avoid N+1 queries.
+
+        Returns:
+            List of dicts with instance_id, channel, activity_id, timeout_at, created_at, workflow_name
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Query ChannelSubscription table with JOIN
+            result = await session.execute(
+                select(
+                    ChannelSubscription.instance_id,
+                    ChannelSubscription.channel,
+                    ChannelSubscription.activity_id,
+                    ChannelSubscription.timeout_at,
+                    ChannelSubscription.subscribed_at,
+                    WorkflowInstance.workflow_name,
+                )
+                .join(
+                    WorkflowInstance,
+                    ChannelSubscription.instance_id == WorkflowInstance.instance_id,
+                )
+                .where(
+                    and_(
+                        ChannelSubscription.timeout_at.isnot(None),
+                        ChannelSubscription.activity_id.isnot(None),  # Only waiting subscriptions
+                        self._make_datetime_comparable(ChannelSubscription.timeout_at)
+                        <= self._get_current_time_expr(),
+                    )
+                )
+            )
+            rows = result.all()
+            return [
+                {
+                    "instance_id": row[0],
+                    "channel": row[1],
+                    "activity_id": row[2],
+                    "timeout_at": row[3],
+                    "created_at": row[4],  # subscribed_at as created_at for compatibility
+                    "workflow_name": row[5],
+                }
+                for row in rows
+            ]
+
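
A sketch of the timeout sweep this enables, pairing the finder with `remove_message_subscription()` from earlier in this diff; the sweeper function itself and its scheduling are assumptions.

```python
# Hypothetical sweeper: clears the waiting state for every timed-out wait so
# the workflow can be resumed with a timeout result instead of hanging forever.
async def sweep_message_timeouts(storage) -> int:
    expired = await storage.find_expired_message_subscriptions()
    for sub in expired:
        await storage.remove_message_subscription(sub["instance_id"], sub["channel"])
    return len(expired)
```
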
+    # -------------------------------------------------------------------------
+    # Group Membership Methods (Erlang pg style)
+    # -------------------------------------------------------------------------
+
+    async def join_group(self, instance_id: str, group_name: str) -> None:
+        """
+        Add a workflow instance to a group.
+
+        Groups provide loose coupling for broadcast messaging.
+        Idempotent - joining a group the instance is already in is a no-op.
+
+        Args:
+            instance_id: Workflow instance ID
+            group_name: Group to join
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Check if already a member (for idempotency)
+            result = await session.execute(
+                select(WorkflowGroupMembership).where(
+                    and_(
+                        WorkflowGroupMembership.instance_id == instance_id,
+                        WorkflowGroupMembership.group_name == group_name,
+                    )
+                )
+            )
+            if result.scalar_one_or_none() is not None:
+                # Already a member, nothing to do
+                return
+
+            # Add membership
+            membership = WorkflowGroupMembership(
+                instance_id=instance_id,
+                group_name=group_name,
+            )
+            session.add(membership)
+            await self._commit_if_not_in_transaction(session)
+
+    async def leave_group(self, instance_id: str, group_name: str) -> None:
+        """
+        Remove a workflow instance from a group.
+
+        Args:
+            instance_id: Workflow instance ID
+            group_name: Group to leave
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                delete(WorkflowGroupMembership).where(
+                    and_(
+                        WorkflowGroupMembership.instance_id == instance_id,
+                        WorkflowGroupMembership.group_name == group_name,
+                    )
+                )
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def get_group_members(self, group_name: str) -> list[str]:
+        """
+        Get all workflow instances in a group.
+
+        Args:
+            group_name: Group name
+
+        Returns:
+            List of instance IDs in the group
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(WorkflowGroupMembership.instance_id).where(
+                    WorkflowGroupMembership.group_name == group_name
+                )
+            )
+            return [row[0] for row in result.fetchall()]
+
+    async def leave_all_groups(self, instance_id: str) -> None:
+        """
+        Remove a workflow instance from all groups.
+
+        Called automatically when a workflow completes or fails.
+
+        Args:
+            instance_id: Workflow instance ID
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                delete(WorkflowGroupMembership).where(
+                    WorkflowGroupMembership.instance_id == instance_id
+                )
+            )
+            await self._commit_if_not_in_transaction(session)
+
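
A fan-out sketch in the pg style these methods suggest: `get_group_members()` and `deliver_message()` are from this diff, while the helper itself, the worker id, and the `group:`-prefixed channel naming are assumptions for illustration.

```python
# Hypothetical group broadcast: resolve members, then deliver to each one.
async def broadcast_to_group(storage, group_name: str, payload: dict) -> None:
    for instance_id in await storage.get_group_members(group_name):
        await storage.deliver_message(
            instance_id,
            channel=f"group:{group_name}",  # assumed naming convention
            data=payload,
            metadata={"group": group_name},
            worker_id="worker-1",
        )
```
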
+    # -------------------------------------------------------------------------
+    # Workflow Resumption Methods
+    # -------------------------------------------------------------------------
+
+    async def find_resumable_workflows(self) -> list[dict[str, Any]]:
+        """
+        Find workflows that are ready to be resumed.
+
+        Returns workflows with status='running' that don't have an active lock.
+        Used for immediate resumption after message delivery.
+
+        Returns:
+            List of resumable workflows with instance_id and workflow_name.
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(
+                    WorkflowInstance.instance_id,
+                    WorkflowInstance.workflow_name,
+                ).where(
+                    and_(
+                        WorkflowInstance.status == "running",
+                        WorkflowInstance.locked_by.is_(None),
+                    )
+                )
+            )
+            return [
+                {
+                    "instance_id": row.instance_id,
+                    "workflow_name": row.workflow_name,
+                }
+                for row in result.fetchall()
+            ]
+
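
A resumption pass might look like the sketch below; `find_resumable_workflows()` is from this diff, and `resume_workflow` stands in for whatever replay entry point the application wires up.

```python
# Hypothetical resumption pass: status='running' with no lock holder means
# the instance was woken (e.g. by a delivered message) and is safe to pick up.
async def resume_pass(storage, resume_workflow) -> None:
    for wf in await storage.find_resumable_workflows():
        await resume_workflow(wf["workflow_name"], wf["instance_id"])
```
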
+    # -------------------------------------------------------------------------
+    # Subscription Cleanup Methods (for recur())
+    # -------------------------------------------------------------------------
+
+    async def cleanup_instance_subscriptions(self, instance_id: str) -> None:
+        """
+        Remove all subscriptions for a workflow instance.
+
+        Called during recur() to clean up event/timer/message subscriptions
+        before archiving the history.
+
+        Args:
+            instance_id: Workflow instance ID to clean up
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Remove timer subscriptions
+            await session.execute(
+                delete(WorkflowTimerSubscription).where(
+                    WorkflowTimerSubscription.instance_id == instance_id
+                )
+            )
+
+            # Remove channel subscriptions
+            await session.execute(
+                delete(ChannelSubscription).where(ChannelSubscription.instance_id == instance_id)
+            )
+
+            # Remove channel message claims
+            await session.execute(
+                delete(ChannelMessageClaim).where(ChannelMessageClaim.instance_id == instance_id)
+            )
+
+            await self._commit_if_not_in_transaction(session)
+
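
The docstring's ordering ("before archiving the history") suggests the shape below; `cleanup_instance_subscriptions()` is from this diff, while the `recur_cleanup` wrapper and `archive_history` step are assumptions about the caller.

```python
# Hypothetical recur() cleanup order: drop subscriptions first so no timer or
# message fires against an instance whose history is being archived.
async def recur_cleanup(storage, instance_id: str, archive_history) -> None:
    await storage.cleanup_instance_subscriptions(instance_id)
    await archive_history(instance_id)  # assumed archival step
```
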
+    # -------------------------------------------------------------------------
+    # Channel-based Message Queue Methods
+    # -------------------------------------------------------------------------
+
+    async def publish_to_channel(
+        self,
+        channel: str,
+        data: dict[str, Any] | bytes,
+        metadata: dict[str, Any] | None = None,
+    ) -> str:
+        """
+        Publish a message to a channel.
+
+        Messages are persisted to channel_messages and available for subscribers.
+
+        Args:
+            channel: Channel name
+            data: Message payload (dict or bytes)
+            metadata: Optional message metadata
+
+        Returns:
+            Generated message_id (UUID)
+        """
+        import uuid
+
+        message_id = str(uuid.uuid4())
+
+        # Determine data type and serialize
+        if isinstance(data, bytes):
+            data_type = "binary"
+            data_json = None
+            data_binary = data
+        else:
+            data_type = "json"
+            data_json = json.dumps(data)
+            data_binary = None
+
+        metadata_json = json.dumps(metadata) if metadata else None
+
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            msg = ChannelMessage(
+                channel=channel,
+                message_id=message_id,
+                data_type=data_type,
+                data=data_json,
+                data_binary=data_binary,
+                message_metadata=metadata_json,
+            )
+            session.add(msg)
+            await self._commit_if_not_in_transaction(session)
+
+        return message_id
+
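
Publishing handles both payload kinds the method accepts; a minimal sketch, assuming a hypothetical `storage` instance and channel names:

```python
# JSON payloads are serialized into "data" (data_type='json'); raw bytes land
# in "data_binary" (data_type='binary'). Both return the generated message_id.
async def publish_examples(storage) -> None:
    json_id = await storage.publish_to_channel(
        "orders", {"order_id": 42}, metadata={"priority": "high"}
    )
    blob_id = await storage.publish_to_channel("thumbnails", b"\x89PNG...")
    print(json_id, blob_id)
```
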
+    async def subscribe_to_channel(
+        self,
+        instance_id: str,
+        channel: str,
+        mode: str,
+    ) -> None:
+        """
+        Subscribe a workflow instance to a channel.
+
+        Args:
+            instance_id: Workflow instance ID
+            channel: Channel name
+            mode: 'broadcast' or 'competing'
+
+        Raises:
+            ValueError: If mode is invalid
+        """
+        if mode not in ("broadcast", "competing"):
+            raise ValueError(f"Invalid mode: {mode}. Must be 'broadcast' or 'competing'")
+
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Check if already subscribed
+            result = await session.execute(
+                select(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+            )
+            existing = result.scalar_one_or_none()
+
+            if existing is not None:
+                # Already subscribed, update mode if different
+                if existing.mode != mode:
+                    existing.mode = mode
+                    await self._commit_if_not_in_transaction(session)
+                return
+
+            # For broadcast mode, set cursor to current max message id
+            # So subscriber only sees messages published after subscription
+            cursor_message_id = None
+            if mode == "broadcast":
+                result = await session.execute(
+                    select(func.max(ChannelMessage.id)).where(ChannelMessage.channel == channel)
+                )
+                max_id = result.scalar()
+                cursor_message_id = max_id if max_id is not None else 0
+
+            # Create subscription
+            subscription = ChannelSubscription(
+                instance_id=instance_id,
+                channel=channel,
+                mode=mode,
+                cursor_message_id=cursor_message_id,
+            )
+            session.add(subscription)
+            await self._commit_if_not_in_transaction(session)
+
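
The cursor initialization above is what gives broadcast subscribers "only messages after I joined" semantics. A minimal sketch of that behavior, assuming a hypothetical `storage` instance and using `get_pending_channel_messages()`, which is defined later in this diff:

```python
# Cursor semantics: a broadcast subscriber never sees messages that were
# already in the channel when it subscribed.
async def broadcast_cursor_demo(storage) -> None:
    await storage.publish_to_channel("news", {"n": 1})   # before subscribing
    await storage.subscribe_to_channel("wf-1", "news", mode="broadcast")
    await storage.publish_to_channel("news", {"n": 2})   # after subscribing
    pending = await storage.get_pending_channel_messages("wf-1", "news")
    print([m["data"] for m in pending])  # expected: only {"n": 2}
```
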
+    async def unsubscribe_from_channel(
+        self,
+        instance_id: str,
+        channel: str,
+    ) -> None:
+        """
+        Unsubscribe a workflow instance from a channel.
+
+        Args:
+            instance_id: Workflow instance ID
+            channel: Channel name
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                delete(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def get_channel_subscription(
+        self,
+        instance_id: str,
+        channel: str,
+    ) -> dict[str, Any] | None:
+        """
+        Get the subscription info for a workflow instance on a channel.
+
+        Args:
+            instance_id: Workflow instance ID
+            channel: Channel name
+
+        Returns:
+            Subscription info dict with: mode, activity_id, cursor_message_id
+            or None if not subscribed
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+            )
+            subscription = result.scalar_one_or_none()
+
+            if subscription is None:
+                return None
+
+            return {
+                "mode": subscription.mode,
+                "activity_id": subscription.activity_id,
+                "cursor_message_id": subscription.cursor_message_id,
+            }
+
+    async def register_channel_receive_and_release_lock(
+        self,
+        instance_id: str,
+        worker_id: str,
+        channel: str,
+        activity_id: str | None = None,
+        timeout_seconds: int | None = None,
+    ) -> None:
+        """
+        Atomically register that workflow is waiting for channel message and release lock.
+
+        Args:
+            instance_id: Workflow instance ID
+            worker_id: Worker ID that currently holds the lock
+            channel: Channel name being waited on
+            activity_id: Current activity ID to record
+            timeout_seconds: Optional timeout in seconds for the message wait
+
+        Raises:
+            RuntimeError: If the worker doesn't hold the lock
+            ValueError: If workflow is not subscribed to the channel
+        """
+        async with self.engine.begin() as conn:
+            session = AsyncSession(bind=conn, expire_on_commit=False)
+
+            # Verify lock ownership
+            result = await session.execute(
+                select(WorkflowInstance).where(WorkflowInstance.instance_id == instance_id)
+            )
+            instance = result.scalar_one_or_none()
+
+            if instance is None:
+                raise RuntimeError(f"Instance not found: {instance_id}")
+
+            if instance.locked_by != worker_id:
+                raise RuntimeError(
+                    f"Worker {worker_id} does not hold lock for {instance_id}. "
+                    f"Locked by: {instance.locked_by}"
+                )
+
+            # Verify subscription exists
+            sub_result = await session.execute(
+                select(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+            )
+            subscription: ChannelSubscription | None = sub_result.scalar_one_or_none()
+
+            if subscription is None:
+                raise ValueError(f"Instance {instance_id} is not subscribed to channel {channel}")
+
+            # Update subscription to mark as waiting
+            current_time = datetime.now(UTC)
+            subscription.activity_id = activity_id
+            # Calculate timeout_at if timeout_seconds is provided
+            if timeout_seconds is not None:
+                subscription.timeout_at = current_time + timedelta(seconds=timeout_seconds)
+            else:
+                subscription.timeout_at = None
+
+            # Update instance: set activity, status, release lock
+            await session.execute(
+                update(WorkflowInstance)
+                .where(WorkflowInstance.instance_id == instance_id)
+                .values(
+                    current_activity_id=activity_id,
+                    status="waiting_for_message",
+                    locked_by=None,
+                    locked_at=None,
+                    lock_expires_at=None,
+                    updated_at=current_time,
+                )
+            )
+
+            await session.commit()
+
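
A minimal sketch of the wait-for-message handoff from the worker's side, assuming the worker already holds the instance lock; both storage calls and the precondition (the subscription must exist first, or `ValueError` is raised) come from the methods above, while the function, channel, and activity id are illustrative.

```python
# Hypothetical handoff: subscribe first, then register the wait and release
# the lock in ONE transaction, so a concurrently delivered message can never
# slip in between the two steps.
async def wait_for_channel(storage, instance_id: str, worker_id: str) -> None:
    await storage.subscribe_to_channel(instance_id, "payments", mode="competing")
    await storage.register_channel_receive_and_release_lock(
        instance_id,
        worker_id,
        channel="payments",
        activity_id="act-7",   # hypothetical activity id
        timeout_seconds=300,   # optional: becomes timeout_at
    )
```
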
+    async def get_pending_channel_messages(
+        self,
+        instance_id: str,
+        channel: str,
+    ) -> list[dict[str, Any]]:
+        """
+        Get pending messages for a subscriber on a channel.
+
+        For broadcast mode: messages with id > cursor_message_id
+        For competing mode: unclaimed messages
+
+        Args:
+            instance_id: Workflow instance ID
+            channel: Channel name
+
+        Returns:
+            List of pending messages
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Get subscription info
+            sub_result = await session.execute(
+                select(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+            )
+            subscription = sub_result.scalar_one_or_none()
+
+            if subscription is None:
+                return []
+
+            if subscription.mode == "broadcast":
+                # Get messages after cursor
+                cursor = subscription.cursor_message_id or 0
+                msg_result = await session.execute(
+                    select(ChannelMessage)
+                    .where(
+                        and_(
+                            ChannelMessage.channel == channel,
+                            ChannelMessage.id > cursor,
+                        )
+                    )
+                    .order_by(ChannelMessage.published_at.asc())
+                )
+            else:  # competing
+                # Get unclaimed messages (not in channel_message_claims)
+                subquery = select(ChannelMessageClaim.message_id)
+                msg_result = await session.execute(
+                    select(ChannelMessage)
+                    .where(
+                        and_(
+                            ChannelMessage.channel == channel,
+                            ChannelMessage.message_id.not_in(subquery),
+                        )
+                    )
+                    .order_by(ChannelMessage.published_at.asc())
+                )
+
+            messages = msg_result.scalars().all()
+            return [
+                {
+                    "id": msg.id,
+                    "message_id": msg.message_id,
+                    "channel": msg.channel,
+                    "data": (
+                        msg.data_binary
+                        if msg.data_type == "binary"
+                        else json.loads(msg.data) if msg.data else {}
+                    ),
+                    "metadata": json.loads(msg.message_metadata) if msg.message_metadata else {},
+                    "published_at": msg.published_at.isoformat() if msg.published_at else None,
+                }
+                for msg in messages
+            ]
+
+    async def claim_channel_message(
+        self,
+        message_id: str,
+        instance_id: str,
+    ) -> bool:
+        """
+        Claim a message for competing consumption.
+
+        Uses INSERT with conflict check to ensure only one subscriber claims.
+
+        Args:
+            message_id: Message ID to claim
+            instance_id: Workflow instance claiming the message
+
+        Returns:
+            True if claim succeeded, False if already claimed
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            try:
+                # Check if already claimed
+                result = await session.execute(
+                    select(ChannelMessageClaim).where(ChannelMessageClaim.message_id == message_id)
+                )
+                if result.scalar_one_or_none() is not None:
+                    return False  # Already claimed
+
+                claim = ChannelMessageClaim(
+                    message_id=message_id,
+                    instance_id=instance_id,
+                )
+                session.add(claim)
+                await self._commit_if_not_in_transaction(session)
+                return True
+            except Exception:
+                return False
+
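
Together with `get_pending_channel_messages()`, this yields the usual competing-consumer loop; a minimal sketch, with the consumer function itself as the only assumption:

```python
# Competing consumption: first successful claim wins; a False return means
# another subscriber beat us to that message, so we try the next one.
async def consume_one(storage, instance_id: str, channel: str) -> dict | None:
    for msg in await storage.get_pending_channel_messages(instance_id, channel):
        if await storage.claim_channel_message(msg["message_id"], instance_id):
            return msg  # this instance now owns the message
    return None
```
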
+    async def delete_channel_message(self, message_id: str) -> None:
+        """
+        Delete a message from the channel queue.
+
+        Args:
+            message_id: Message ID to delete
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Delete claim first (foreign key)
+            await session.execute(
+                delete(ChannelMessageClaim).where(ChannelMessageClaim.message_id == message_id)
+            )
+            # Delete message
+            await session.execute(
+                delete(ChannelMessage).where(ChannelMessage.message_id == message_id)
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def update_delivery_cursor(
+        self,
+        channel: str,
+        instance_id: str,
+        message_id: int,
+    ) -> None:
+        """
+        Update the delivery cursor for broadcast mode.
+
+        Args:
+            channel: Channel name
+            instance_id: Subscriber instance ID
+            message_id: Last delivered message's internal ID
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # Update subscription cursor
+            await session.execute(
+                update(ChannelSubscription)
+                .where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+                .values(cursor_message_id=message_id)
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def get_channel_subscribers_waiting(
+        self,
+        channel: str,
+    ) -> list[dict[str, Any]]:
+        """
+        Get channel subscribers that are waiting (activity_id is set).
+
+        Args:
+            channel: Channel name
+
+        Returns:
+            List of waiting subscribers
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            result = await session.execute(
+                select(ChannelSubscription).where(
+                    and_(
+                        ChannelSubscription.channel == channel,
+                        ChannelSubscription.activity_id.isnot(None),
+                    )
+                )
+            )
+            subscriptions = result.scalars().all()
+            return [
+                {
+                    "instance_id": sub.instance_id,
+                    "channel": sub.channel,
+                    "mode": sub.mode,
+                    "activity_id": sub.activity_id,
+                }
+                for sub in subscriptions
+            ]
+
+    async def clear_channel_waiting_state(
+        self,
+        instance_id: str,
+        channel: str,
+    ) -> None:
+        """
+        Clear the waiting state for a channel subscription.
+
+        Args:
+            instance_id: Workflow instance ID
+            channel: Channel name
+        """
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            await session.execute(
+                update(ChannelSubscription)
+                .where(
+                    and_(
+                        ChannelSubscription.instance_id == instance_id,
+                        ChannelSubscription.channel == channel,
+                    )
+                )
+                .values(activity_id=None)
+            )
+            await self._commit_if_not_in_transaction(session)
+
+    async def deliver_channel_message(
+        self,
+        instance_id: str,
+        channel: str,
+        message_id: str,
+        data: dict[str, Any] | bytes,
+        metadata: dict[str, Any],
+        worker_id: str,
+    ) -> dict[str, Any] | None:
+        """
+        Deliver a channel message to a waiting workflow.
+
+        Uses Lock-First pattern for distributed safety.
+
+        Args:
+            instance_id: Target workflow instance ID
+            channel: Channel name
+            message_id: Message ID being delivered
+            data: Message payload
+            metadata: Message metadata
+            worker_id: Worker ID for locking
+
+        Returns:
+            Delivery info if successful, None if failed
+        """
+        try:
+            # Try to acquire lock
+            if not await self.try_acquire_lock(instance_id, worker_id):
+                logger.debug(f"Failed to acquire lock for {instance_id}")
+                return None
+
+            try:
+                async with self.engine.begin() as conn:
+                    session = AsyncSession(bind=conn, expire_on_commit=False)
+
+                    # Get subscription info
+                    result = await session.execute(
+                        select(ChannelSubscription).where(
+                            and_(
+                                ChannelSubscription.instance_id == instance_id,
+                                ChannelSubscription.channel == channel,
+                            )
+                        )
+                    )
+                    subscription = result.scalar_one_or_none()
+
+                    if subscription is None or subscription.activity_id is None:
+                        logger.debug(f"No waiting subscription for {instance_id} on {channel}")
+                        return None
+
+                    activity_id = subscription.activity_id
+
+                    # Get instance info for return value
+                    result = await session.execute(
+                        select(WorkflowInstance.workflow_name).where(
+                            WorkflowInstance.instance_id == instance_id
+                        )
+                    )
+                    row = result.one_or_none()
+                    if row is None:
+                        return None
+                    workflow_name = row[0]
+
+                    # Prepare message data for history
+                    # Use "id" key to match what context.py expects when loading history
+                    current_time = datetime.now(UTC)
+                    message_result = {
+                        "id": message_id,
+                        "channel": channel,
+                        "data": data if isinstance(data, dict) else None,
+                        "metadata": metadata,
+                        "published_at": current_time.isoformat(),
+                    }
+
+                    # Record to history
+                    if isinstance(data, bytes):
+                        history = WorkflowHistory(
+                            instance_id=instance_id,
+                            activity_id=activity_id,
+                            event_type="ChannelMessageReceived",
+                            data_type="binary",
+                            event_data=None,
+                            event_data_binary=data,
+                        )
+                    else:
+                        history = WorkflowHistory(
+                            instance_id=instance_id,
+                            activity_id=activity_id,
+                            event_type="ChannelMessageReceived",
+                            data_type="json",
+                            event_data=json.dumps(message_result),
+                            event_data_binary=None,
+                        )
+                    session.add(history)
+
+                    # Handle mode-specific logic
+                    if subscription.mode == "broadcast":
+                        # Get message internal id to update cursor
+                        result = await session.execute(
+                            select(ChannelMessage.id).where(ChannelMessage.message_id == message_id)
+                        )
+                        msg_row = result.one_or_none()
+                        if msg_row:
+                            subscription.cursor_message_id = msg_row[0]
+                    else:  # competing
+                        # Claim and delete the message
+                        claim = ChannelMessageClaim(
+                            message_id=message_id,
+                            instance_id=instance_id,
+                        )
+                        session.add(claim)
+
+                        # Delete the message (competing mode consumes it)
+                        await session.execute(
+                            delete(ChannelMessage).where(ChannelMessage.message_id == message_id)
+                        )
+
+                    # Clear waiting state
+                    subscription.activity_id = None
+
+                    # Update instance status to running
+                    current_time = datetime.now(UTC)
+                    await session.execute(
+                        update(WorkflowInstance)
+                        .where(WorkflowInstance.instance_id == instance_id)
+                        .values(
+                            status="running",
+                            updated_at=current_time,
+                        )
+                    )
+
+                    await session.commit()
+
+                    return {
+                        "instance_id": instance_id,
+                        "workflow_name": workflow_name,
+                        "activity_id": activity_id,
+                    }
+
+            finally:
+                # Always release lock
+                await self.release_lock(instance_id, worker_id)
+
+        except Exception as e:
+            logger.error(f"Error delivering channel message: {e}")
+            return None
+
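
An end-to-end sketch of pushing a freshly published message to waiting subscribers, using only methods from this diff; the helper function and worker id are assumptions. Note the mode split the delivery method implements: broadcast advances the subscriber's cursor and keeps the message, while competing claims and deletes it.

```python
# Hypothetical push: publish once, then attempt delivery to each waiter.
async def push_to_waiter(storage, channel: str, payload: dict) -> None:
    message_id = await storage.publish_to_channel(channel, payload)
    for sub in await storage.get_channel_subscribers_waiting(channel):
        done = await storage.deliver_channel_message(
            sub["instance_id"], channel, message_id, payload,
            metadata={}, worker_id="worker-1",
        )
        if done and sub["mode"] == "competing":
            break  # competing: the message is consumed, stop after first delivery
```
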
+    async def cleanup_old_channel_messages(self, older_than_days: int = 7) -> int:
+        """
+        Clean up old messages from channel queues.
+
+        Args:
+            older_than_days: Message retention period in days
+
+        Returns:
+            Number of messages deleted
+        """
+        cutoff_time = datetime.now(UTC) - timedelta(days=older_than_days)
+
+        session = self._get_session_for_operation()
+        async with self._session_scope(session) as session:
+            # First delete claims for old messages
+            await session.execute(
+                delete(ChannelMessageClaim).where(
+                    ChannelMessageClaim.message_id.in_(
+                        select(ChannelMessage.message_id).where(
+                            self._make_datetime_comparable(ChannelMessage.published_at)
+                            < self._get_current_time_expr()
+                        )
+                    )
+                )
+            )
+
+            # Delete old messages
+            result = await session.execute(
+                delete(ChannelMessage)
+                .where(ChannelMessage.published_at < cutoff_time)
+                .returning(ChannelMessage.id)
+            )
+            deleted_ids = result.fetchall()
+            await self._commit_if_not_in_transaction(session)
+
+            return len(deleted_ids)
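
The retention method is designed to be called periodically; a minimal sketch of one way to schedule it, where the loop itself is an assumption and only `cleanup_old_channel_messages()` comes from the diff:

```python
# Hypothetical retention job: purge week-old channel messages once a day.
import asyncio

async def retention_loop(storage) -> None:
    while True:
        deleted = await storage.cleanup_old_channel_messages(older_than_days=7)
        print(f"channel cleanup: {deleted} messages removed")
        await asyncio.sleep(24 * 3600)
```
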