edda-framework 0.9.1__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edda/app.py +419 -26
- edda/integrations/mirascope/__init__.py +78 -0
- edda/integrations/mirascope/agent.py +467 -0
- edda/integrations/mirascope/call.py +166 -0
- edda/integrations/mirascope/decorator.py +163 -0
- edda/integrations/mirascope/types.py +268 -0
- edda/outbox/relayer.py +21 -2
- edda/storage/__init__.py +8 -0
- edda/storage/notify_base.py +162 -0
- edda/storage/pg_notify.py +325 -0
- edda/storage/protocol.py +9 -1
- edda/storage/sqlalchemy_storage.py +193 -13
- edda/viewer_ui/app.py +26 -0
- edda/viewer_ui/data_service.py +4 -0
- {edda_framework-0.9.1.dist-info → edda_framework-0.11.0.dist-info}/METADATA +17 -1
- {edda_framework-0.9.1.dist-info → edda_framework-0.11.0.dist-info}/RECORD +19 -12
- {edda_framework-0.9.1.dist-info → edda_framework-0.11.0.dist-info}/WHEEL +0 -0
- {edda_framework-0.9.1.dist-info → edda_framework-0.11.0.dist-info}/entry_points.txt +0 -0
- {edda_framework-0.9.1.dist-info → edda_framework-0.11.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
"""PostgreSQL LISTEN/NOTIFY implementation using asyncpg.
|
|
2
|
+
|
|
3
|
+
This module provides a dedicated listener for PostgreSQL's LISTEN/NOTIFY
|
|
4
|
+
mechanism, enabling near-instant notification delivery for workflow events.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
from collections.abc import Awaitable, Callable
|
|
13
|
+
from contextlib import suppress
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Signature expected for subscriber callbacks: an async callable that
# receives the raw NOTIFY payload string and returns nothing.
NotifyCallback = Callable[[str], Awaitable[None]]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class PostgresNotifyListener:
|
|
23
|
+
"""PostgreSQL LISTEN/NOTIFY listener using asyncpg.
|
|
24
|
+
|
|
25
|
+
This class maintains a dedicated connection for LISTEN/NOTIFY operations.
|
|
26
|
+
It provides:
|
|
27
|
+
- Automatic reconnection on connection loss
|
|
28
|
+
- Channel subscription management
|
|
29
|
+
- Callback dispatch for notifications
|
|
30
|
+
|
|
31
|
+
Example:
|
|
32
|
+
>>> listener = PostgresNotifyListener(dsn="postgresql://localhost/db")
|
|
33
|
+
>>> await listener.start()
|
|
34
|
+
>>> await listener.subscribe("my_channel", handle_notification)
|
|
35
|
+
>>> # ... later
|
|
36
|
+
>>> await listener.stop()
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
def __init__(
|
|
40
|
+
self,
|
|
41
|
+
dsn: str,
|
|
42
|
+
reconnect_interval: float = 5.0,
|
|
43
|
+
max_reconnect_attempts: int | None = None,
|
|
44
|
+
) -> None:
|
|
45
|
+
"""Initialize the PostgreSQL notify listener.
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
dsn: PostgreSQL connection string (postgresql://user:pass@host/db).
|
|
49
|
+
reconnect_interval: Seconds to wait between reconnection attempts.
|
|
50
|
+
max_reconnect_attempts: Maximum number of reconnection attempts.
|
|
51
|
+
None means unlimited.
|
|
52
|
+
"""
|
|
53
|
+
self._dsn = dsn
|
|
54
|
+
self._reconnect_interval = reconnect_interval
|
|
55
|
+
self._max_reconnect_attempts = max_reconnect_attempts
|
|
56
|
+
|
|
57
|
+
self._connection: Any = None # asyncpg.Connection
|
|
58
|
+
self._callbacks: dict[str, list[NotifyCallback]] = {}
|
|
59
|
+
self._channel_handlers: dict[str, Callable[..., None]] = {}
|
|
60
|
+
self._running = False
|
|
61
|
+
self._reconnect_task: asyncio.Task[None] | None = None
|
|
62
|
+
self._lock = asyncio.Lock()
|
|
63
|
+
|
|
64
|
+
async def start(self) -> None:
|
|
65
|
+
"""Start the notification listener.
|
|
66
|
+
|
|
67
|
+
Establishes the connection and begins listening for notifications.
|
|
68
|
+
Starts the automatic reconnection task.
|
|
69
|
+
|
|
70
|
+
Raises:
|
|
71
|
+
ImportError: If asyncpg is not installed.
|
|
72
|
+
"""
|
|
73
|
+
if self._running:
|
|
74
|
+
logger.warning("PostgresNotifyListener already running")
|
|
75
|
+
return
|
|
76
|
+
|
|
77
|
+
self._running = True
|
|
78
|
+
await self._establish_connection()
|
|
79
|
+
|
|
80
|
+
# Start reconnection monitor
|
|
81
|
+
self._reconnect_task = asyncio.create_task(self._reconnect_loop())
|
|
82
|
+
logger.info("PostgresNotifyListener started")
|
|
83
|
+
|
|
84
|
+
async def stop(self) -> None:
|
|
85
|
+
"""Stop the notification listener.
|
|
86
|
+
|
|
87
|
+
Closes the connection and stops the reconnection task.
|
|
88
|
+
"""
|
|
89
|
+
self._running = False
|
|
90
|
+
|
|
91
|
+
# Cancel reconnection task
|
|
92
|
+
if self._reconnect_task is not None:
|
|
93
|
+
self._reconnect_task.cancel()
|
|
94
|
+
with suppress(asyncio.CancelledError):
|
|
95
|
+
await self._reconnect_task
|
|
96
|
+
self._reconnect_task = None
|
|
97
|
+
|
|
98
|
+
# Close connection
|
|
99
|
+
await self._close_connection()
|
|
100
|
+
self._callbacks.clear()
|
|
101
|
+
logger.info("PostgresNotifyListener stopped")
|
|
102
|
+
|
|
103
|
+
async def subscribe(self, channel: str, callback: NotifyCallback) -> None:
|
|
104
|
+
"""Subscribe to notifications on a channel.
|
|
105
|
+
|
|
106
|
+
Args:
|
|
107
|
+
channel: The PostgreSQL channel name to listen on.
|
|
108
|
+
callback: Async function called when a notification arrives.
|
|
109
|
+
|
|
110
|
+
Note:
|
|
111
|
+
Channel names must be valid PostgreSQL identifiers (max 63 chars).
|
|
112
|
+
Multiple callbacks can be registered for the same channel.
|
|
113
|
+
"""
|
|
114
|
+
async with self._lock:
|
|
115
|
+
is_new_channel = channel not in self._callbacks
|
|
116
|
+
|
|
117
|
+
if channel not in self._callbacks:
|
|
118
|
+
self._callbacks[channel] = []
|
|
119
|
+
self._callbacks[channel].append(callback)
|
|
120
|
+
|
|
121
|
+
# Register listener if this is a new channel and we're connected
|
|
122
|
+
if is_new_channel and self._connection is not None:
|
|
123
|
+
try:
|
|
124
|
+
await self._connection.add_listener(
|
|
125
|
+
channel, self._create_notification_handler(channel)
|
|
126
|
+
)
|
|
127
|
+
logger.debug(f"Subscribed to channel: {channel}")
|
|
128
|
+
except Exception as e:
|
|
129
|
+
logger.error(f"Failed to LISTEN on channel {channel}: {e}")
|
|
130
|
+
|
|
131
|
+
async def unsubscribe(self, channel: str) -> None:
|
|
132
|
+
"""Unsubscribe from notifications on a channel.
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
channel: The PostgreSQL channel name to stop listening on.
|
|
136
|
+
"""
|
|
137
|
+
async with self._lock:
|
|
138
|
+
if channel in self._callbacks:
|
|
139
|
+
del self._callbacks[channel]
|
|
140
|
+
|
|
141
|
+
# Remove listener if we're connected
|
|
142
|
+
if self._connection is not None:
|
|
143
|
+
try:
|
|
144
|
+
await self._connection.remove_listener(
|
|
145
|
+
channel, self._create_notification_handler(channel)
|
|
146
|
+
)
|
|
147
|
+
logger.debug(f"Unsubscribed from channel: {channel}")
|
|
148
|
+
except Exception as e:
|
|
149
|
+
logger.error(f"Failed to UNLISTEN on channel {channel}: {e}")
|
|
150
|
+
|
|
151
|
+
async def notify(self, channel: str, payload: str) -> None:
|
|
152
|
+
"""Send a notification on a channel.
|
|
153
|
+
|
|
154
|
+
Args:
|
|
155
|
+
channel: The PostgreSQL channel name.
|
|
156
|
+
payload: The payload string (max ~7500 bytes recommended).
|
|
157
|
+
|
|
158
|
+
Note:
|
|
159
|
+
This uses the existing connection pool from SQLAlchemy,
|
|
160
|
+
not the dedicated listener connection.
|
|
161
|
+
"""
|
|
162
|
+
if self._connection is None:
|
|
163
|
+
logger.warning("Cannot send NOTIFY: not connected")
|
|
164
|
+
return
|
|
165
|
+
|
|
166
|
+
try:
|
|
167
|
+
# Use pg_notify function to properly escape the payload
|
|
168
|
+
await self._connection.execute("SELECT pg_notify($1, $2)", channel, payload)
|
|
169
|
+
except Exception as e:
|
|
170
|
+
logger.warning(f"Failed to send NOTIFY on channel {channel}: {e}")
|
|
171
|
+
|
|
172
|
+
@property
|
|
173
|
+
def is_connected(self) -> bool:
|
|
174
|
+
"""Check if the listener is currently connected."""
|
|
175
|
+
return self._connection is not None and not self._connection.is_closed()
|
|
176
|
+
|
|
177
|
+
async def _establish_connection(self) -> None:
|
|
178
|
+
"""Establish connection to PostgreSQL."""
|
|
179
|
+
try:
|
|
180
|
+
import asyncpg
|
|
181
|
+
except ImportError as e:
|
|
182
|
+
raise ImportError(
|
|
183
|
+
"asyncpg is required for PostgreSQL LISTEN/NOTIFY support. "
|
|
184
|
+
"Install it with: pip install edda[postgres-notify]"
|
|
185
|
+
) from e
|
|
186
|
+
|
|
187
|
+
try:
|
|
188
|
+
self._connection = await asyncpg.connect(self._dsn)
|
|
189
|
+
|
|
190
|
+
# Re-subscribe to all channels (this also registers listeners)
|
|
191
|
+
await self._resubscribe_all()
|
|
192
|
+
|
|
193
|
+
logger.info("PostgresNotifyListener connected to database")
|
|
194
|
+
except Exception as e:
|
|
195
|
+
logger.error(f"Failed to connect to PostgreSQL: {e}")
|
|
196
|
+
self._connection = None
|
|
197
|
+
raise
|
|
198
|
+
|
|
199
|
+
async def _close_connection(self) -> None:
|
|
200
|
+
"""Close the database connection."""
|
|
201
|
+
if self._connection is not None:
|
|
202
|
+
try:
|
|
203
|
+
await self._connection.close()
|
|
204
|
+
except Exception as e:
|
|
205
|
+
logger.warning(f"Error closing connection: {e}")
|
|
206
|
+
finally:
|
|
207
|
+
self._connection = None
|
|
208
|
+
|
|
209
|
+
async def _resubscribe_all(self) -> None:
|
|
210
|
+
"""Re-subscribe to all registered channels after reconnection."""
|
|
211
|
+
if self._connection is None:
|
|
212
|
+
return
|
|
213
|
+
|
|
214
|
+
for channel in self._callbacks:
|
|
215
|
+
try:
|
|
216
|
+
# Register listener for this channel
|
|
217
|
+
await self._connection.add_listener(
|
|
218
|
+
channel, self._create_notification_handler(channel)
|
|
219
|
+
)
|
|
220
|
+
logger.debug(f"Re-subscribed to channel: {channel}")
|
|
221
|
+
except Exception as e:
|
|
222
|
+
logger.error(f"Failed to re-subscribe to channel {channel}: {e}")
|
|
223
|
+
|
|
224
|
+
def _create_notification_handler(self, channel: str) -> Callable[..., None]:
|
|
225
|
+
"""Create or retrieve a notification handler for a channel.
|
|
226
|
+
|
|
227
|
+
Args:
|
|
228
|
+
channel: The channel name.
|
|
229
|
+
|
|
230
|
+
Returns:
|
|
231
|
+
A handler function that can be passed to add_listener/remove_listener.
|
|
232
|
+
"""
|
|
233
|
+
if channel not in self._channel_handlers:
|
|
234
|
+
|
|
235
|
+
def handler(_connection: Any, _pid: int, ch: str, payload: str) -> None:
|
|
236
|
+
"""Handle incoming notification from PostgreSQL."""
|
|
237
|
+
callbacks = self._callbacks.get(ch, [])
|
|
238
|
+
for callback in callbacks:
|
|
239
|
+
asyncio.create_task(self._safe_callback(callback, payload, ch))
|
|
240
|
+
|
|
241
|
+
self._channel_handlers[channel] = handler
|
|
242
|
+
|
|
243
|
+
return self._channel_handlers[channel]
|
|
244
|
+
|
|
245
|
+
async def _safe_callback(self, callback: NotifyCallback, payload: str, channel: str) -> None:
|
|
246
|
+
"""Execute callback with error handling."""
|
|
247
|
+
try:
|
|
248
|
+
await callback(payload)
|
|
249
|
+
except Exception as e:
|
|
250
|
+
logger.error(
|
|
251
|
+
f"Error in notification callback for channel {channel}: {e}",
|
|
252
|
+
exc_info=True,
|
|
253
|
+
)
|
|
254
|
+
|
|
255
|
+
async def _reconnect_loop(self) -> None:
|
|
256
|
+
"""Monitor connection and reconnect on failure."""
|
|
257
|
+
attempt = 0
|
|
258
|
+
|
|
259
|
+
with suppress(asyncio.CancelledError):
|
|
260
|
+
while self._running:
|
|
261
|
+
await asyncio.sleep(1) # Check every second
|
|
262
|
+
|
|
263
|
+
if self._connection is None or self._connection.is_closed():
|
|
264
|
+
attempt += 1
|
|
265
|
+
if (
|
|
266
|
+
self._max_reconnect_attempts is not None
|
|
267
|
+
and attempt > self._max_reconnect_attempts
|
|
268
|
+
):
|
|
269
|
+
logger.error(
|
|
270
|
+
f"Max reconnection attempts ({self._max_reconnect_attempts}) "
|
|
271
|
+
"exceeded, giving up"
|
|
272
|
+
)
|
|
273
|
+
break
|
|
274
|
+
|
|
275
|
+
logger.info(
|
|
276
|
+
f"Connection lost, attempting reconnection " f"(attempt {attempt})..."
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
try:
|
|
280
|
+
await self._close_connection()
|
|
281
|
+
await self._establish_connection()
|
|
282
|
+
attempt = 0 # Reset on success
|
|
283
|
+
logger.info("Reconnection successful")
|
|
284
|
+
except Exception as e:
|
|
285
|
+
logger.error(
|
|
286
|
+
f"Reconnection failed: {e}, " f"retrying in {self._reconnect_interval}s"
|
|
287
|
+
)
|
|
288
|
+
await asyncio.sleep(self._reconnect_interval)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def get_notify_channel_for_message(channel: str) -> str:
    """Map an Edda channel name onto a PostgreSQL NOTIFY channel name.

    PostgreSQL identifiers are limited to 63 characters, so the Edda name is
    replaced by a fixed-length digest: ``edda_msg_`` followed by the first 16
    hex characters of the SHA-256 of the (UTF-8 encoded) name, 25 characters
    in total.

    Args:
        channel: The Edda channel name.

    Returns:
        PostgreSQL-safe channel name.
    """
    import hashlib

    digest = hashlib.sha256(channel.encode())
    short_hash = digest.hexdigest()[:16]
    return "edda_msg_" + short_hash
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def make_notify_payload(data: dict[str, Any]) -> str:
    """Serialize ``data`` into a compact JSON string for use as a NOTIFY payload.

    Serializations longer than 7500 characters are not sent as-is. A warning
    is logged, and the payload is rebuilt from only the ``wf_id`` and ``ts``
    entries of ``data`` (whichever are present, in their original order).
    Payloads without either key therefore collapse to ``{}``.

    Args:
        data: Dictionary to serialize as JSON.

    Returns:
        Compact JSON string.
    """
    compact_separators = (",", ":")
    payload = json.dumps(data, separators=compact_separators)
    if len(payload) <= 7500:
        return payload

    logger.warning(
        f"NOTIFY payload exceeds recommended size ({len(payload)} > 7500 bytes), truncating"
    )
    # Fall back to the identifying fields only.
    essentials = {key: value for key, value in data.items() if key in {"wf_id", "ts"}}
    return json.dumps(essentials, separators=compact_separators)
|
edda/storage/protocol.py
CHANGED
|
@@ -262,6 +262,7 @@ class StorageProtocol(Protocol):
|
|
|
262
262
|
instance_id_filter: str | None = None,
|
|
263
263
|
started_after: datetime | None = None,
|
|
264
264
|
started_before: datetime | None = None,
|
|
265
|
+
input_filters: dict[str, Any] | None = None,
|
|
265
266
|
) -> dict[str, Any]:
|
|
266
267
|
"""
|
|
267
268
|
List workflow instances with cursor-based pagination and filtering.
|
|
@@ -277,6 +278,9 @@ class StorageProtocol(Protocol):
|
|
|
277
278
|
instance_id_filter: Optional instance ID filter (partial match, case-insensitive)
|
|
278
279
|
started_after: Filter instances started after this datetime (inclusive)
|
|
279
280
|
started_before: Filter instances started before this datetime (inclusive)
|
|
281
|
+
input_filters: Filter by input data values. Keys are JSON paths
|
|
282
|
+
(e.g., "order_id" or "customer.email"), values are expected
|
|
283
|
+
values (exact match). All filters are AND-combined.
|
|
280
284
|
|
|
281
285
|
Returns:
|
|
282
286
|
Dictionary containing:
|
|
@@ -860,7 +864,7 @@ class StorageProtocol(Protocol):
|
|
|
860
864
|
# Workflow Resumption Methods
|
|
861
865
|
# -------------------------------------------------------------------------
|
|
862
866
|
|
|
863
|
-
async def find_resumable_workflows(self) -> list[dict[str, Any]]:
|
|
867
|
+
async def find_resumable_workflows(self, limit: int | None = None) -> list[dict[str, Any]]:
|
|
864
868
|
"""
|
|
865
869
|
Find workflows that are ready to be resumed.
|
|
866
870
|
|
|
@@ -873,6 +877,10 @@ class StorageProtocol(Protocol):
|
|
|
873
877
|
This allows immediate resumption after message delivery rather than
|
|
874
878
|
waiting for the stale lock cleanup cycle (60+ seconds).
|
|
875
879
|
|
|
880
|
+
Args:
|
|
881
|
+
limit: Optional maximum number of workflows to return.
|
|
882
|
+
If None, returns all resumable workflows.
|
|
883
|
+
|
|
876
884
|
Returns:
|
|
877
885
|
List of resumable workflows.
|
|
878
886
|
Each item contains: instance_id, workflow_name
|
|
@@ -8,6 +8,7 @@ and transactional outbox pattern.
|
|
|
8
8
|
|
|
9
9
|
import json
|
|
10
10
|
import logging
|
|
11
|
+
import re
|
|
11
12
|
from collections.abc import AsyncIterator, Awaitable, Callable
|
|
12
13
|
from contextlib import asynccontextmanager
|
|
13
14
|
from contextvars import ContextVar
|
|
@@ -511,14 +512,43 @@ class SQLAlchemyStorage:
|
|
|
511
512
|
- Automatic transaction management via @activity decorator
|
|
512
513
|
"""
|
|
513
514
|
|
|
514
|
-
def __init__(
|
|
515
|
+
def __init__(self, engine: AsyncEngine, notify_listener: Any | None = None):
    """
    Set up the storage around an existing SQLAlchemy async engine.

    Args:
        engine: SQLAlchemy AsyncEngine instance used for all database access.
        notify_listener: Optional notify listener for PostgreSQL LISTEN/NOTIFY.
                         NOTIFY messages are only sent when this is provided
                         and the engine's dialect is PostgreSQL.
    """
    # The listener may also be attached later through set_notify_listener().
    self._notify_listener = notify_listener
    self.engine = engine
|
|
531
|
+
|
|
532
|
+
@property
def _is_postgresql(self) -> bool:
    """Return True when the configured engine uses the PostgreSQL dialect."""
    dialect_name = self.engine.dialect.name
    return dialect_name == "postgresql"
|
|
536
|
+
|
|
537
|
+
@property
def _notify_enabled(self) -> bool:
    """Return True only for a PostgreSQL engine with a notify listener attached."""
    if not self._is_postgresql:
        return False
    return self._notify_listener is not None
|
|
541
|
+
|
|
542
|
+
def set_notify_listener(self, listener: Any) -> None:
    """Attach (or replace) the notify listener after construction.

    Useful when the listener is created after the storage, e.g. in
    dependency-injection setups where EddaApp builds the storage first.

    Args:
        listener: NotifyProtocol implementation (PostgresNotifyListener or NoopNotifyListener)
    """
    # Overwrites whatever was passed to __init__ (possibly None).
    self._notify_listener = listener
|
|
522
552
|
|
|
523
553
|
async def initialize(self) -> None:
|
|
524
554
|
"""Initialize database connection and create tables.
|
|
@@ -544,6 +574,39 @@ class SQLAlchemyStorage:
|
|
|
544
574
|
"""Close database connection."""
|
|
545
575
|
await self.engine.dispose()
|
|
546
576
|
|
|
577
|
+
async def _send_notify(
    self,
    channel: str,
    payload: dict[str, Any],
) -> None:
    """Send a PostgreSQL NOTIFY message (best effort).

    Sends a notification on ``channel`` with ``payload`` serialized as
    compact JSON. It is a no-op if NOTIFY is not enabled (non-PostgreSQL
    engine or no listener). Failures are logged and never raised.

    Args:
        channel: PostgreSQL NOTIFY channel name (max 63 chars).
        payload: Dictionary to serialize as JSON payload (max ~7500 bytes).
    """
    if not self._notify_enabled:
        return

    try:
        payload_str = json.dumps(payload, separators=(",", ":"))

        # Use a separate connection for NOTIFY to avoid transaction issues.
        # NOTE(review): because this is a separate connection, a caller that is
        # still inside an uncommitted outer transaction may wake listeners
        # before its rows are visible to them; polling is the fallback.
        async with self.engine.connect() as conn:
            await conn.execute(
                text("SELECT pg_notify(:channel, :payload)"),
                {"channel": channel, "payload": payload_str},
            )
            # NOTIFY is transactional: it is delivered only once this commits.
            await conn.commit()
    except Exception as e:
        # Log but don't fail - polling will catch it as backup
        logger.warning("Failed to send NOTIFY on channel %s: %s", channel, e)
|
|
609
|
+
|
|
547
610
|
async def _initialize_schema_version(self) -> None:
|
|
548
611
|
"""Initialize schema version for a fresh database."""
|
|
549
612
|
async with AsyncSession(self.engine) as session:
|
|
@@ -849,6 +912,60 @@ class SQLAlchemyStorage:
|
|
|
849
912
|
# PostgreSQL/MySQL: column is already timezone-aware
|
|
850
913
|
return column
|
|
851
914
|
|
|
915
|
+
def _validate_json_path(self, json_path: str) -> bool:
    """
    Validate JSON path to prevent SQL injection.

    A valid path is one or more non-empty segments separated by single dots,
    using only ASCII letters, digits and underscores. The first segment must
    start with a letter or underscore; later segments may start with a digit
    (e.g. "items.0").

    Args:
        json_path: JSON path string (e.g., "order_id" or "customer.email")

    Returns:
        True if valid, False otherwise
    """
    # fullmatch instead of match + "$": "$" also matches before a trailing
    # newline, so "order_id\n" used to pass. Requiring non-empty segments
    # rejects "a.", "a..b" and ".a", which yield malformed JSON paths.
    return re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z0-9_]+)*", json_path) is not None
|
|
928
|
+
|
|
929
|
+
def _build_json_extract_expr(self, column: Any, json_path: str) -> Any:
    """
    Build a database-agnostic JSON extraction expression.

    The extracted value is returned as text (SQLite's json_extract returns
    the native SQL value instead).

    Args:
        column: SQLAlchemy column containing JSON text
        json_path: Dot-notation path (e.g., "order_id" or "customer.email")

    Returns:
        SQLAlchemy column expression that extracts the value at json_path

    Raises:
        ValueError: If json_path is invalid or dialect is unsupported
    """
    if not self._validate_json_path(json_path):
        raise ValueError(f"Invalid JSON path: {json_path}")

    full_path = f"$.{json_path}"
    dialect = self.engine.dialect.name

    if dialect == "sqlite":
        # SQLite: json_extract(column, '$.key')
        return func.json_extract(column, full_path)
    elif dialect == "postgresql":
        # PostgreSQL: json_extract_path_text(CAST(column AS JSON), 'k1', 'k2')
        # is the function form of the #>> operator and covers both simple and
        # nested paths. Building it from ``column`` (rather than a text()
        # snippet hard-coding "input_data") honours the column argument and
        # yields an expression that supports ==, IS NULL, etc.
        from sqlalchemy import JSON, cast

        return func.json_extract_path_text(cast(column, JSON), *json_path.split("."))
    elif dialect == "mysql":
        # MySQL: JSON_UNQUOTE(JSON_EXTRACT(column, '$.key'))
        return func.JSON_UNQUOTE(func.JSON_EXTRACT(column, full_path))
    else:
        raise ValueError(f"Unsupported database dialect: {dialect}")
|
|
968
|
+
|
|
852
969
|
# -------------------------------------------------------------------------
|
|
853
970
|
# Transaction Management Methods
|
|
854
971
|
# -------------------------------------------------------------------------
|
|
@@ -1200,6 +1317,7 @@ class SQLAlchemyStorage:
|
|
|
1200
1317
|
instance_id_filter: str | None = None,
|
|
1201
1318
|
started_after: datetime | None = None,
|
|
1202
1319
|
started_before: datetime | None = None,
|
|
1320
|
+
input_filters: dict[str, Any] | None = None,
|
|
1203
1321
|
) -> dict[str, Any]:
|
|
1204
1322
|
"""List workflow instances with cursor-based pagination and filtering."""
|
|
1205
1323
|
session = self._get_session_for_operation()
|
|
@@ -1292,6 +1410,42 @@ class SQLAlchemyStorage:
|
|
|
1292
1410
|
started_before_comparable = started_before
|
|
1293
1411
|
stmt = stmt.where(started_at_comparable <= started_before_comparable)
|
|
1294
1412
|
|
|
1413
|
+
# Apply input data filters (JSON field matching)
|
|
1414
|
+
if input_filters:
|
|
1415
|
+
for json_path, expected_value in input_filters.items():
|
|
1416
|
+
dialect = self.engine.dialect.name
|
|
1417
|
+
if dialect == "postgresql":
|
|
1418
|
+
# PostgreSQL: use text() for the entire condition
|
|
1419
|
+
if not self._validate_json_path(json_path):
|
|
1420
|
+
raise ValueError(f"Invalid JSON path: {json_path}")
|
|
1421
|
+
if "." in json_path:
|
|
1422
|
+
path_array = "{" + json_path.replace(".", ",") + "}"
|
|
1423
|
+
json_sql = f"(input_data)::json #>> '{path_array}'"
|
|
1424
|
+
else:
|
|
1425
|
+
json_sql = f"(input_data)::json->>'{json_path}'"
|
|
1426
|
+
if expected_value is None:
|
|
1427
|
+
stmt = stmt.where(text(f"({json_sql} IS NULL OR {json_sql} = 'null')"))
|
|
1428
|
+
else:
|
|
1429
|
+
# Escape single quotes in value
|
|
1430
|
+
safe_value = str(expected_value).replace("'", "''")
|
|
1431
|
+
stmt = stmt.where(text(f"{json_sql} = '{safe_value}'"))
|
|
1432
|
+
else:
|
|
1433
|
+
# SQLite and MySQL: use func-based approach
|
|
1434
|
+
json_expr = self._build_json_extract_expr(
|
|
1435
|
+
WorkflowInstance.input_data, json_path
|
|
1436
|
+
)
|
|
1437
|
+
if expected_value is None:
|
|
1438
|
+
stmt = stmt.where(or_(json_expr.is_(None), json_expr == "null"))
|
|
1439
|
+
elif isinstance(expected_value, bool):
|
|
1440
|
+
stmt = stmt.where(json_expr == str(expected_value).lower())
|
|
1441
|
+
elif isinstance(expected_value, (int, float)):
|
|
1442
|
+
if dialect == "sqlite":
|
|
1443
|
+
stmt = stmt.where(json_expr == expected_value)
|
|
1444
|
+
else:
|
|
1445
|
+
stmt = stmt.where(json_expr == str(expected_value))
|
|
1446
|
+
else:
|
|
1447
|
+
stmt = stmt.where(json_expr == str(expected_value))
|
|
1448
|
+
|
|
1295
1449
|
# Fetch limit+1 to determine if there are more pages
|
|
1296
1450
|
stmt = stmt.limit(limit + 1)
|
|
1297
1451
|
|
|
@@ -2096,6 +2250,12 @@ class SQLAlchemyStorage:
|
|
|
2096
2250
|
session.add(event)
|
|
2097
2251
|
await self._commit_if_not_in_transaction(session)
|
|
2098
2252
|
|
|
2253
|
+
# Send NOTIFY for new outbox event
|
|
2254
|
+
await self._send_notify(
|
|
2255
|
+
"edda_outbox_pending",
|
|
2256
|
+
{"evt_id": event_id, "evt_type": event_type},
|
|
2257
|
+
)
|
|
2258
|
+
|
|
2099
2259
|
async def get_pending_outbox_events(self, limit: int = 10) -> list[dict[str, Any]]:
|
|
2100
2260
|
"""
|
|
2101
2261
|
Get pending/failed outbox events for publishing (with row-level locking).
|
|
@@ -2719,29 +2879,34 @@ class SQLAlchemyStorage:
|
|
|
2719
2879
|
# Workflow Resumption Methods
|
|
2720
2880
|
# -------------------------------------------------------------------------
|
|
2721
2881
|
|
|
2722
|
-
async def find_resumable_workflows(self) -> list[dict[str, Any]]:
|
|
2882
|
+
async def find_resumable_workflows(self, limit: int | None = None) -> list[dict[str, Any]]:
|
|
2723
2883
|
"""
|
|
2724
2884
|
Find workflows that are ready to be resumed.
|
|
2725
2885
|
|
|
2726
2886
|
Returns workflows with status='running' that don't have an active lock.
|
|
2727
2887
|
Used for immediate resumption after message delivery.
|
|
2728
2888
|
|
|
2889
|
+
Args:
|
|
2890
|
+
limit: Optional maximum number of workflows to return.
|
|
2891
|
+
If None, returns all resumable workflows.
|
|
2892
|
+
|
|
2729
2893
|
Returns:
|
|
2730
2894
|
List of resumable workflows with instance_id and workflow_name.
|
|
2731
2895
|
"""
|
|
2732
2896
|
session = self._get_session_for_operation()
|
|
2733
2897
|
async with self._session_scope(session) as session:
|
|
2734
|
-
|
|
2735
|
-
|
|
2736
|
-
|
|
2737
|
-
|
|
2738
|
-
|
|
2739
|
-
|
|
2740
|
-
|
|
2741
|
-
|
|
2742
|
-
|
|
2743
|
-
|
|
2744
|
-
|
|
2898
|
+
query = select(
|
|
2899
|
+
WorkflowInstance.instance_id,
|
|
2900
|
+
WorkflowInstance.workflow_name,
|
|
2901
|
+
).where(
|
|
2902
|
+
and_(
|
|
2903
|
+
WorkflowInstance.status == "running",
|
|
2904
|
+
WorkflowInstance.locked_by.is_(None),
|
|
2905
|
+
)
|
|
2906
|
+
)
|
|
2907
|
+
if limit is not None:
|
|
2908
|
+
query = query.limit(limit)
|
|
2909
|
+
result = await session.execute(query)
|
|
2745
2910
|
return [
|
|
2746
2911
|
{
|
|
2747
2912
|
"instance_id": row.instance_id,
|
|
@@ -2837,6 +3002,15 @@ class SQLAlchemyStorage:
|
|
|
2837
3002
|
session.add(msg)
|
|
2838
3003
|
await self._commit_if_not_in_transaction(session)
|
|
2839
3004
|
|
|
3005
|
+
# Send NOTIFY for message published (channel-specific)
|
|
3006
|
+
import hashlib
|
|
3007
|
+
|
|
3008
|
+
channel_hash = hashlib.sha256(channel.encode()).hexdigest()[:16]
|
|
3009
|
+
await self._send_notify(
|
|
3010
|
+
f"edda_msg_{channel_hash}",
|
|
3011
|
+
{"ch": channel, "msg_id": message_id},
|
|
3012
|
+
)
|
|
3013
|
+
|
|
2840
3014
|
return message_id
|
|
2841
3015
|
|
|
2842
3016
|
async def subscribe_to_channel(
|
|
@@ -3395,6 +3569,12 @@ class SQLAlchemyStorage:
|
|
|
3395
3569
|
|
|
3396
3570
|
await session.commit()
|
|
3397
3571
|
|
|
3572
|
+
# Send NOTIFY for workflow resumable
|
|
3573
|
+
await self._send_notify(
|
|
3574
|
+
"edda_workflow_resumable",
|
|
3575
|
+
{"wf_id": instance_id, "wf_name": workflow_name},
|
|
3576
|
+
)
|
|
3577
|
+
|
|
3398
3578
|
return {
|
|
3399
3579
|
"instance_id": instance_id,
|
|
3400
3580
|
"workflow_name": workflow_name,
|