edda-framework 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edda/channels.py ADDED
@@ -0,0 +1,992 @@
+ """
+ Channel-based Message Queue System for Edda.
+
+ This module provides Erlang/Elixir mailbox-style messaging with support for:
+ - Broadcast mode: All subscribers receive all messages (fan-out pattern)
+ - Competing mode: Each message is processed by only one subscriber (producer-consumer pattern)
+
+ Key concepts:
+ - Channel: A named message queue with persistent storage
+ - Message: A data payload published to a channel
+ - Subscription: A workflow's interest in receiving messages from a channel
+
+ The channel system solves the "mailbox problem" where messages sent before
+ `receive()` is called would be lost. Messages are always queued and persist
+ until consumed.
+
+ Example:
+     >>> from edda.channels import subscribe, receive, publish, ChannelMessage
+     >>>
+     >>> @workflow
+     ... async def worker(ctx: WorkflowContext, id: str):
+     ...     # Subscribe to a channel
+     ...     await subscribe(ctx, "tasks", mode="competing")
+     ...
+     ...     while True:
+     ...         # Receive messages (blocks until message available)
+     ...         msg = await receive(ctx, "tasks")
+     ...         await process(ctx, msg.data, activity_id=f"process:{msg.id}")
+     ...         await ctx.recur()
+     >>>
+     >>> @workflow
+     ... async def producer(ctx: WorkflowContext, task_data: dict):
+     ...     await publish(ctx, "tasks", task_data)
+ """
+
+ from __future__ import annotations
+
+ import uuid
+ from dataclasses import dataclass, field
+ from datetime import UTC, datetime, timedelta
+ from typing import TYPE_CHECKING, Any, overload
+
+ if TYPE_CHECKING:
+     from edda.context import WorkflowContext
+     from edda.storage.protocol import StorageProtocol
+
+
+ # =============================================================================
+ # Data Classes
+ # =============================================================================
+
+
+ @dataclass(frozen=True)
+ class ChannelMessage:
+     """
+     A message received from a channel.
+
+     Attributes:
+         id: Unique message identifier
+         channel: Channel name this message was received on
+         data: Message payload (dict or bytes)
+         metadata: Optional metadata (source, timestamp, etc.)
+         published_at: When the message was published
+     """
+
+     id: str
+     channel: str
+     data: dict[str, Any] | bytes
+     metadata: dict[str, Any] = field(default_factory=dict)
+     published_at: datetime = field(default_factory=lambda: datetime.now(UTC))
+
+
+ # =============================================================================
+ # Exceptions
+ # =============================================================================
+
+
+ class WaitForChannelMessageException(Exception):
+     """
+     Raised to pause workflow execution until a channel message arrives.
+
+     This exception is caught by the ReplayEngine to:
+     1. Register the workflow as waiting for a channel message
+     2. Release the workflow lock
+     3. Update workflow status to 'waiting_for_message'
+
+     The workflow will be resumed when a message is delivered to the channel.
+     """
+
+     def __init__(
+         self,
+         channel: str,
+         timeout_seconds: int | None,
+         activity_id: str,
+     ) -> None:
+         self.channel = channel
+         self.timeout_seconds = timeout_seconds
+         self.activity_id = activity_id
+         # Calculate absolute timeout if specified
+         self.timeout_at: datetime | None = None
+         if timeout_seconds is not None:
+             self.timeout_at = datetime.now(UTC) + timedelta(seconds=timeout_seconds)
+         super().__init__(f"Waiting for message on channel: {channel}")
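For orientation, here is a minimal sketch of the engine-side handling that this docstring describes. The names `run_workflow_step`, `register_channel_waiter`, `release_lock`, and `set_status` are illustrative assumptions, not edda's actual internals:

    # Hypothetical engine loop fragment (all method names below are assumptions).
    try:
        await run_workflow_step(instance)
    except WaitForChannelMessageException as exc:
        # 1. Register the workflow as waiting for a channel message
        await storage.register_channel_waiter(
            instance.id, exc.channel, exc.activity_id, timeout_at=exc.timeout_at
        )
        # 2. Release the workflow lock
        await storage.release_lock(instance.id)
        # 3. Update workflow status to 'waiting_for_message'
        await storage.set_status(instance.id, "waiting_for_message")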
+
+
+ class WaitForTimerException(Exception):
+     """
+     Raised to pause workflow execution until a timer expires.
+
+     This exception is caught by the ReplayEngine to:
+     1. Register a timer subscription in the database
+     2. Release the workflow lock
+     3. Update workflow status to 'waiting_for_timer'
+
+     The workflow will be resumed when the timer expires.
+     """
+
+     def __init__(
+         self,
+         duration_seconds: int,
+         expires_at: datetime,
+         timer_id: str,
+         activity_id: str,
+     ) -> None:
+         self.duration_seconds = duration_seconds
+         self.expires_at = expires_at
+         self.timer_id = timer_id
+         self.activity_id = activity_id
+         super().__init__(f"Waiting for timer: {timer_id}")
+
+
+ # =============================================================================
+ # Subscription Functions
+ # =============================================================================
+
+
+ async def subscribe(
+     ctx: WorkflowContext,
+     channel: str,
+     mode: str = "broadcast",
+ ) -> None:
+     """
+     Subscribe to a channel for receiving messages.
+
+     Args:
+         ctx: Workflow context
+         channel: Channel name to subscribe to
+         mode: Subscription mode - "broadcast" (all subscribers receive all messages)
+             or "competing" (each message goes to only one subscriber)
+
+     Example:
+         >>> @workflow
+         ... async def event_handler(ctx: WorkflowContext, id: str):
+         ...     # Subscribe to order events (all handlers receive all events)
+         ...     await subscribe(ctx, "order.events", mode="broadcast")
+         ...
+         ...     while True:
+         ...         event = await receive(ctx, "order.events")
+         ...         await handle_event(ctx, event.data, activity_id=f"handle:{event.id}")
+         ...         await ctx.recur()
+
+         >>> @workflow
+         ... async def job_worker(ctx: WorkflowContext, worker_id: str):
+         ...     # Subscribe to job queue (each job processed by one worker)
+         ...     await subscribe(ctx, "jobs", mode="competing")
+         ...
+         ...     while True:
+         ...         job = await receive(ctx, "jobs")
+         ...         await execute_job(ctx, job.data, activity_id=f"job:{job.id}")
+         ...         await ctx.recur()
+     """
+     if mode not in ("broadcast", "competing"):
+         raise ValueError(f"Invalid subscription mode: {mode}. Must be 'broadcast' or 'competing'")
+
+     await ctx.storage.subscribe_to_channel(ctx.instance_id, channel, mode)
+
+
+ async def unsubscribe(
+     ctx: WorkflowContext,
+     channel: str,
+ ) -> None:
+     """
+     Unsubscribe from a channel.
+
+     Note: Workflows are automatically unsubscribed from all channels when they
+     complete, fail, or are cancelled. Explicit unsubscribe is usually not necessary.
+
+     Args:
+         ctx: Workflow context
+         channel: Channel name to unsubscribe from
+     """
+     await ctx.storage.unsubscribe_from_channel(ctx.instance_id, channel)
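Explicit unsubscribe is mostly useful when a workflow only cares about a channel during one phase. A minimal sketch, assuming a hypothetical `collect_bid` activity:

    @workflow
    async def auction(ctx: WorkflowContext, auction_id: str):
        await subscribe(ctx, "bids", mode="broadcast")
        bid = await receive(ctx, "bids", timeout_seconds=300)
        await collect_bid(ctx, bid.data, activity_id=f"bid:{bid.id}")
        # Bidding phase is over - stop receiving further bid messages.
        await unsubscribe(ctx, "bids")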
+
+
+ # =============================================================================
+ # Message Receiving
+ # =============================================================================
+
+
+ async def receive(
+     ctx: WorkflowContext,
+     channel: str,
+     timeout_seconds: int | None = None,
+     message_id: str | None = None,
+ ) -> ChannelMessage:
+     """
+     Receive a message from a channel.
+
+     This function blocks (pauses the workflow) until a message is available
+     on the channel. Messages are queued persistently, so messages published
+     before this function is called will still be received.
+
+     Args:
+         ctx: Workflow context
+         channel: Channel name to receive from
+         timeout_seconds: Optional timeout in seconds
+         message_id: Optional ID for concurrent waiting (auto-generated if not provided)
+
+     Returns:
+         ChannelMessage object containing data and metadata
+
+     Raises:
+         WaitForChannelMessageException: Raised to pause workflow (caught by ReplayEngine)
+         TimeoutError: If timeout expires before message arrives
+
+     Example:
+         >>> @workflow
+         ... async def consumer(ctx: WorkflowContext, id: str):
+         ...     await subscribe(ctx, "tasks", mode="competing")
+         ...
+         ...     while True:
+         ...         msg = await receive(ctx, "tasks")
+         ...         await process(ctx, msg.data, activity_id=f"process:{msg.id}")
+         ...         await ctx.recur()
+     """
+     # Generate activity ID
+     if message_id is None:
+         activity_id = ctx._generate_activity_id(f"receive_{channel}")
+     else:
+         activity_id = message_id
+
+     ctx._record_activity_id(activity_id)
+
+     # During replay, return cached message
+     if ctx.is_replaying:
+         found, cached_result = ctx._get_cached_result(activity_id)
+         if found:
+             # Check for cached error
+             if isinstance(cached_result, dict) and cached_result.get("_error"):
+                 error_type = cached_result.get("error_type", "Exception")
+                 error_message = cached_result.get("error_message", "Unknown error")
+                 if error_type == "TimeoutError":
+                     raise TimeoutError(error_message)
+                 raise Exception(f"{error_type}: {error_message}")
+             # Return cached ChannelMessage
+             if isinstance(cached_result, ChannelMessage):
+                 return cached_result
+             # Convert dict to ChannelMessage (from history)
+             if isinstance(cached_result, dict):
+                 raw_data = cached_result.get("data", cached_result.get("payload", {}))
+                 data: dict[str, Any] | bytes = (
+                     raw_data if isinstance(raw_data, (dict, bytes)) else {}
+                 )
+                 published_at_str = cached_result.get("published_at")
+                 published_at = (
+                     datetime.fromisoformat(published_at_str)
+                     if published_at_str
+                     else datetime.now(UTC)
+                 )
+                 return ChannelMessage(
+                     id=cached_result.get("id", "") or "",
+                     channel=cached_result.get("channel", channel) or channel,
+                     data=data,
+                     metadata=cached_result.get("metadata") or {},
+                     published_at=published_at,
+                 )
+             raise RuntimeError(f"Unexpected cached result type: {type(cached_result)}")
+
+     # Check for pending messages in the queue
+     pending = await ctx.storage.get_pending_channel_messages(ctx.instance_id, channel)
+     if pending:
+         # Get the first pending message
+         msg_dict = pending[0]
+         msg_id = msg_dict["message_id"]
+
+         # For competing mode, try to claim the message
+         subscription = await _get_subscription(ctx, channel)
+         if subscription and subscription.get("mode") == "competing":
+             claimed = await ctx.storage.claim_channel_message(msg_id, ctx.instance_id)
+             if not claimed:
+                 # Another subscriber claimed this message first. For simplicity,
+                 # raise the wait exception so the workflow pauses and retries
+                 # on the next delivery rather than scanning the queue here.
+                 raise WaitForChannelMessageException(
+                     channel=channel,
+                     timeout_seconds=timeout_seconds,
+                     activity_id=activity_id,
+                 )
+             # Delete the message after claiming (competing mode)
+             await ctx.storage.delete_channel_message(msg_id)
+         else:
+             # Broadcast mode - update cursor
+             await ctx.storage.update_delivery_cursor(channel, ctx.instance_id, msg_dict["id"])
+
+         # Build the message
+         raw_data = msg_dict.get("data")
+         data = raw_data if isinstance(raw_data, (dict, bytes)) else {}
+         published_at_str = msg_dict.get("published_at")
+         published_at = (
+             datetime.fromisoformat(published_at_str)
+             if isinstance(published_at_str, str)
+             else (published_at_str if isinstance(published_at_str, datetime) else datetime.now(UTC))
+         )
+
+         message = ChannelMessage(
+             id=msg_id,
+             channel=channel,
+             data=data,
+             metadata=msg_dict.get("metadata") or {},
+             published_at=published_at,
+         )
+
+         # Record in history for replay
+         await ctx.storage.append_history(
+             ctx.instance_id,
+             activity_id,
+             "ChannelMessageReceived",
+             {
+                 "id": message.id,
+                 "channel": message.channel,
+                 "data": message.data,
+                 "metadata": message.metadata,
+                 "published_at": message.published_at.isoformat(),
+             },
+         )
+
+         return message
+
+     # No pending messages, raise exception to pause workflow
+     raise WaitForChannelMessageException(
+         channel=channel,
+         timeout_seconds=timeout_seconds,
+         activity_id=activity_id,
+     )
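Because the docstring documents a `TimeoutError` on expiry, a consumer can bound its wait. A sketch, assuming a hypothetical `handle_idle` activity:

    @workflow
    async def patient_consumer(ctx: WorkflowContext, id: str):
        await subscribe(ctx, "tasks", mode="competing")
        try:
            msg = await receive(ctx, "tasks", timeout_seconds=60)
        except TimeoutError:
            # No task within a minute - do housekeeping instead of blocking forever.
            await handle_idle(ctx, activity_id="idle:1")
            return
        await process(ctx, msg.data, activity_id=f"process:{msg.id}")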
+
+
+ async def _get_subscription(ctx: WorkflowContext, channel: str) -> dict[str, Any] | None:
+     """Get the subscription info for a channel."""
+     return await ctx.storage.get_channel_subscription(ctx.instance_id, channel)
+
+
+ # =============================================================================
+ # Message Publishing
+ # =============================================================================
+
+
+ @overload
+ async def publish(
+     ctx_or_storage: WorkflowContext,
+     channel: str,
+     data: dict[str, Any] | bytes,
+     metadata: dict[str, Any] | None = None,
+     *,
+     target_instance_id: str | None = None,
+     worker_id: str | None = None,
+ ) -> str: ...
+
+
+ @overload
+ async def publish(
+     ctx_or_storage: StorageProtocol,
+     channel: str,
+     data: dict[str, Any] | bytes,
+     metadata: dict[str, Any] | None = None,
+     *,
+     target_instance_id: str | None = None,
+     worker_id: str | None = None,
+ ) -> str: ...
+
+
+ async def publish(
+     ctx_or_storage: WorkflowContext | StorageProtocol,
+     channel: str,
+     data: dict[str, Any] | bytes,
+     metadata: dict[str, Any] | None = None,
+     *,
+     target_instance_id: str | None = None,
+     worker_id: str | None = None,
+ ) -> str:
+     """
+     Publish a message to a channel.
+
+     Can be called from within a workflow (with WorkflowContext) or from
+     external code (with StorageProtocol directly).
+
+     Args:
+         ctx_or_storage: Workflow context or storage backend
+         channel: Channel name to publish to
+         data: Message payload (dict or bytes)
+         metadata: Optional metadata
+         target_instance_id: If provided, only deliver to this specific instance
+             (Point-to-Point delivery). If None, deliver to all
+             waiting subscribers (Pub/Sub delivery).
+         worker_id: Optional worker ID for Lock-First pattern (required for
+             CloudEvents HTTP handler)
+
+     Returns:
+         Message ID of the published message
+
+     Example:
+         >>> # From within a workflow
+         >>> @workflow
+         ... async def order_processor(ctx: WorkflowContext, order_id: str):
+         ...     result = await process_order(ctx, order_id, activity_id="process:1")
+         ...     await publish(ctx, "order.completed", {"order_id": order_id})
+         ...     return result
+
+         >>> # From external code (e.g., HTTP handler)
+         >>> async def api_handler(request):
+         ...     message_id = await publish(app.storage, "jobs", {"task": "process"})
+         ...     return {"message_id": message_id}
+
+         >>> # Point-to-Point delivery (CloudEvents with eddainstanceid)
+         >>> await publish(
+         ...     storage, "payment.completed", {"amount": 100},
+         ...     target_instance_id="order-123", worker_id="worker-1"
+         ... )
+     """
+     # Determine if we have a context or direct storage
+     from edda.context import WorkflowContext as WfCtx
+
+     if isinstance(ctx_or_storage, WfCtx):
+         storage = ctx_or_storage.storage
+         # Add source metadata
+         full_metadata = metadata.copy() if metadata else {}
+         full_metadata.setdefault("source_instance_id", ctx_or_storage.instance_id)
+         full_metadata.setdefault("published_at", datetime.now(UTC).isoformat())
+         effective_worker_id = worker_id or ctx_or_storage.worker_id
+     else:
+         storage = ctx_or_storage
+         full_metadata = metadata.copy() if metadata else {}
+         full_metadata.setdefault("published_at", datetime.now(UTC).isoformat())
+         effective_worker_id = worker_id or f"publisher-{uuid.uuid4()}"
+
+     # Publish to channel
+     message_id = await storage.publish_to_channel(channel, data, full_metadata)
+
+     # Wake up waiting subscribers
+     await _wake_waiting_subscribers(
+         storage,
+         channel,
+         message_id,
+         data,
+         full_metadata,
+         target_instance_id=target_instance_id,
+         worker_id=effective_worker_id,
+     )
+
+     return message_id
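Outside a workflow, any object satisfying StorageProtocol works; how that object is constructed is deployment-specific, so the `make_storage()` factory in this sketch is a placeholder assumption:

    # Enqueue a competing-mode job from a standalone script (sketch).
    import asyncio

    async def enqueue_nightly_job() -> None:
        storage = await make_storage()  # placeholder: obtain a StorageProtocol impl
        message_id = await publish(
            storage,
            "jobs",
            {"task": "rebuild-index"},
            metadata={"origin": "cron"},
            worker_id="cron-runner",
        )
        print(f"enqueued {message_id}")

    asyncio.run(enqueue_nightly_job())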
+
+
+ async def _wake_waiting_subscribers(
+     storage: StorageProtocol,
+     channel: str,
+     message_id: str,
+     data: dict[str, Any] | bytes,
+     metadata: dict[str, Any],
+     *,
+     target_instance_id: str | None = None,
+     worker_id: str,
+ ) -> None:
+     """
+     Wake up subscribers waiting on a channel.
+
+     Args:
+         storage: Storage backend
+         channel: Channel name
+         message_id: Message ID
+         data: Message payload
+         metadata: Message metadata
+         target_instance_id: If provided, only wake this specific instance
+             (Point-to-Point delivery)
+         worker_id: Worker ID for Lock-First pattern
+     """
+     if target_instance_id:
+         # Point-to-Point delivery: deliver only to specific instance
+         await storage.deliver_channel_message(
+             instance_id=target_instance_id,
+             channel=channel,
+             message_id=message_id,
+             data=data,
+             metadata=metadata,
+             worker_id=worker_id,
+         )
+         return
+
+     # Pub/Sub delivery: deliver to all waiting subscribers
+     waiting = await storage.get_channel_subscribers_waiting(channel)
+
+     for sub in waiting:
+         instance_id = sub["instance_id"]
+         mode = sub["mode"]
+
+         if mode == "competing":
+             # For competing mode, only wake one subscriber
+             # Use Lock-First pattern
+             result = await storage.deliver_channel_message(
+                 instance_id=instance_id,
+                 channel=channel,
+                 message_id=message_id,
+                 data=data,
+                 metadata=metadata,
+                 worker_id=worker_id,
+             )
+             if result:
+                 # Successfully woke one subscriber, stop
+                 break
+         else:
+             # For broadcast mode, wake all subscribers
+             await storage.deliver_channel_message(
+                 instance_id=instance_id,
+                 channel=channel,
+                 message_id=message_id,
+                 data=data,
+                 metadata=metadata,
+                 worker_id=worker_id,
+             )
+
+
+ # =============================================================================
+ # Direct Messaging (Instance-to-Instance)
+ # =============================================================================
+
+
+ async def send_to(
+     ctx: WorkflowContext,
+     instance_id: str,
+     data: dict[str, Any] | bytes,
+     channel: str = "__direct__",
+     metadata: dict[str, Any] | None = None,
+ ) -> bool:
+     """
+     Send a message directly to a specific workflow instance.
+
+     This is useful for workflow-to-workflow communication where the target
+     instance ID is known.
+
+     Args:
+         ctx: Workflow context (source workflow)
+         instance_id: Target workflow instance ID
+         data: Message payload
+         channel: Channel name (defaults to "__direct__" for direct messages)
+         metadata: Optional metadata
+
+     Returns:
+         True if delivered, False if no workflow waiting
+
+     Example:
+         >>> @workflow
+         ... async def approver(ctx: WorkflowContext, request_id: str):
+         ...     decision = await review(ctx, request_id, activity_id="review:1")
+         ...     await send_to(ctx, instance_id=request_id, data={"approved": decision})
+     """
+     full_metadata = metadata.copy() if metadata else {}
+     full_metadata.setdefault("source_instance_id", ctx.instance_id)
+     full_metadata.setdefault("sent_at", datetime.now(UTC).isoformat())
+
+     # Publish to a direct channel for the target instance
+     direct_channel = f"{channel}:{instance_id}"
+     message_id = await ctx.storage.publish_to_channel(direct_channel, data, full_metadata)
+
+     # Try to deliver
+     result = await ctx.storage.deliver_channel_message(
+         instance_id=instance_id,
+         channel=direct_channel,
+         message_id=message_id,
+         data=data,
+         metadata=full_metadata,
+         worker_id=ctx.worker_id,
+     )
+
+     return result is not None
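send_to() pairs naturally with receive() on the derived channel name. A request/reply sketch, under the assumption that a direct delivery to an instance is matched by that instance calling receive() on `f"__direct__:{ctx.instance_id}"`:

    @workflow
    async def requester(ctx: WorkflowContext, approver_id: str):
        # Ask another instance for a decision, then wait for its direct reply.
        await send_to(ctx, instance_id=approver_id, data={"request": ctx.instance_id})
        reply = await receive(ctx, f"__direct__:{ctx.instance_id}", timeout_seconds=120)
        return reply.data.get("approved", False)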
+
+
+ # =============================================================================
+ # Timer Functions
+ # =============================================================================
+
+
+ async def sleep(
+     ctx: WorkflowContext,
+     seconds: int,
+     timer_id: str | None = None,
+ ) -> None:
+     """
+     Pause workflow execution for a specified duration.
+
+     This is a durable sleep - the workflow will be resumed after the specified
+     time even if the worker restarts.
+
+     Args:
+         ctx: Workflow context
+         seconds: Duration to sleep in seconds
+         timer_id: Optional unique ID for this timer (auto-generated if not provided)
+
+     Example:
+         >>> @workflow
+         ... async def order_workflow(ctx: WorkflowContext, order_id: str):
+         ...     await create_order(ctx, order_id, activity_id="create:1")
+         ...     await sleep(ctx, 60)  # Wait 60 seconds for payment
+         ...     await check_payment(ctx, order_id, activity_id="check:1")
+     """
+     # Generate activity ID
+     if timer_id is None:
+         activity_id = ctx._generate_activity_id("sleep")
+         timer_id = activity_id
+     else:
+         activity_id = timer_id
+
+     ctx._record_activity_id(activity_id)
+
+     # During replay, return immediately
+     if ctx.is_replaying:
+         found, cached_result = ctx._get_cached_result(activity_id)
+         if found:
+             return
+
+     # Calculate expiry time (deterministic - calculated once)
+     expires_at = datetime.now(UTC) + timedelta(seconds=seconds)
+
+     # Raise exception to pause workflow
+     raise WaitForTimerException(
+         duration_seconds=seconds,
+         expires_at=expires_at,
+         timer_id=timer_id,
+         activity_id=activity_id,
+     )
+
+
+ async def sleep_until(
+     ctx: WorkflowContext,
+     target_time: datetime,
+     timer_id: str | None = None,
+ ) -> None:
+     """
+     Pause workflow execution until a specific time.
+
+     This is a durable sleep - the workflow will be resumed at the specified
+     time even if the worker restarts.
+
+     Args:
+         ctx: Workflow context
+         target_time: Absolute time to wake up (must be timezone-aware)
+         timer_id: Optional unique ID for this timer (auto-generated if not provided)
+
+     Example:
+         >>> from datetime import datetime, timedelta, UTC
+         >>>
+         >>> @workflow
+         ... async def scheduled_report(ctx: WorkflowContext, report_id: str):
+         ...     # Schedule for tomorrow at 9 AM
+         ...     tomorrow_9am = datetime.now(UTC).replace(hour=9, minute=0, second=0)
+         ...     tomorrow_9am += timedelta(days=1)
+         ...     await sleep_until(ctx, tomorrow_9am)
+         ...     await generate_report(ctx, report_id, activity_id="generate:1")
+     """
+     if target_time.tzinfo is None:
+         raise ValueError("target_time must be timezone-aware")
+
+     # Generate activity ID
+     if timer_id is None:
+         activity_id = ctx._generate_activity_id("sleep_until")
+         timer_id = activity_id
+     else:
+         activity_id = timer_id
+
+     ctx._record_activity_id(activity_id)
+
+     # During replay, return immediately
+     if ctx.is_replaying:
+         found, cached_result = ctx._get_cached_result(activity_id)
+         if found:
+             return
+
+     # Calculate seconds until target
+     now = datetime.now(UTC)
+     delta = target_time - now
+     seconds = max(0, int(delta.total_seconds()))
+
+     # Raise exception to pause workflow
+     raise WaitForTimerException(
+         duration_seconds=seconds,
+         expires_at=target_time,
+         timer_id=timer_id,
+         activity_id=activity_id,
+     )
+
+
+ # =============================================================================
+ # CloudEvents Integration
+ # =============================================================================
+
+
+ @dataclass(frozen=True)
+ class ReceivedEvent:
+     """
+     Represents a CloudEvent received by a workflow.
+
+     This class provides structured access to both the event payload (data)
+     and CloudEvents metadata (type, source, id, time, etc.).
+
+     Attributes:
+         data: The event payload (JSON dict or Pydantic model)
+         type: CloudEvent type (e.g., "payment.completed")
+         source: CloudEvent source (e.g., "payment-service")
+         id: Unique event identifier
+         time: Event timestamp (ISO 8601 format)
+         datacontenttype: Content type of the data (typically "application/json")
+         subject: Subject of the event (optional CloudEvents extension)
+         extensions: Additional CloudEvents extension attributes
+
+     Example:
+         >>> # Without Pydantic model
+         >>> event = await wait_event(ctx, "payment.completed")
+         >>> amount = event.data["amount"]
+         >>> order_id = event.data["order_id"]
+         >>>
+         >>> # With Pydantic model (type-safe)
+         >>> event = await wait_event(ctx, "payment.completed", model=PaymentCompleted)
+         >>> amount = event.data.amount  # Type-safe access
+         >>> order_id = event.data.order_id  # IDE completion
+         >>>
+         >>> # Access CloudEvents metadata
+         >>> event_source = event.source
+         >>> event_time = event.time
+         >>> event_id = event.id
+     """
+
+     # Event payload (JSON dict or Pydantic model)
+     data: dict[str, Any] | Any  # Any to support Pydantic models
+
+     # CloudEvents standard attributes
+     type: str
+     source: str
+     id: str
+     time: str | None = None
+     datacontenttype: str | None = None
+     subject: str | None = None
+
+     # CloudEvents extension attributes
+     extensions: dict[str, Any] = field(default_factory=dict)
+
+
+ class EventTimeoutError(Exception):
+     """
+     Exception raised when wait_event() times out.
+
+     This exception is raised when an event does not arrive within the
+     specified timeout period. The workflow can catch this exception to
+     handle timeout scenarios gracefully.
+
+     Example:
+         try:
+             event = await wait_event(ctx, "payment.completed", timeout_seconds=60)
+         except EventTimeoutError:
+             # Handle timeout - maybe send reminder or cancel order
+             await send_notification("Payment timeout")
+     """
+
+     def __init__(self, event_type: str, timeout_seconds: int):
+         self.event_type = event_type
+         self.timeout_seconds = timeout_seconds
+         super().__init__(f"Event '{event_type}' did not arrive within {timeout_seconds} seconds")
+
+
+ def _convert_channel_message_to_received_event(
+     msg: ChannelMessage,
+     event_type: str,
+     model: type[Any] | None = None,
+ ) -> ReceivedEvent:
+     """
+     Convert a ChannelMessage to a ReceivedEvent.
+
+     CloudEvents metadata is extracted from the message's metadata field
+     where it was stored with 'ce_' prefix.
+
+     Args:
+         msg: ChannelMessage received from receive()
+         event_type: The event type that was waited for
+         model: Optional Pydantic model to convert data to
+
+     Returns:
+         ReceivedEvent with CloudEvents metadata
+     """
+     from edda.pydantic_utils import from_json_dict
+
+     data: dict[str, Any] | Any
+     if model is not None and isinstance(msg.data, dict):
+         data = from_json_dict(msg.data, model)
+     elif isinstance(msg.data, dict):
+         data = msg.data
+     else:
+         # bytes data - wrap in dict for ReceivedEvent compatibility
+         data = {"_binary": msg.data}
+
+     return ReceivedEvent(
+         data=data,
+         type=event_type,
+         source=msg.metadata.get("ce_source", "unknown"),
+         id=msg.metadata.get("ce_id", msg.id),
+         time=msg.metadata.get("ce_time"),
+         datacontenttype=msg.metadata.get("ce_datacontenttype"),
+         subject=msg.metadata.get("ce_subject"),
+         extensions=msg.metadata.get("ce_extensions", {}),
+     )
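The `ce_` keys above imply that whatever ingests CloudEvents stores the envelope attributes into message metadata before publishing. A sketch of that producing side (the handler and its wiring are assumptions, not edda's actual HTTP handler):

    # Hypothetical ingress: publish an incoming CloudEvent into the channel system.
    async def ingest_cloudevent(storage: StorageProtocol, event: Any) -> str:
        metadata = {
            "ce_source": event["source"],
            "ce_id": event["id"],
            "ce_time": event.get("time"),
            "ce_datacontenttype": event.get("datacontenttype"),
            "ce_subject": event.get("subject"),
        }
        # Channel name is the CloudEvent type, matching wait_event()'s convention.
        return await publish(storage, event["type"], event.data, metadata)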
+
+
+ async def wait_event(
+     ctx: WorkflowContext,
+     event_type: str,
+     timeout_seconds: int | None = None,
+     model: type[Any] | None = None,
+     event_id: str | None = None,
+ ) -> ReceivedEvent:
+     """
+     Wait for a CloudEvent to arrive.
+
+     This function pauses the workflow execution until a matching CloudEvent is received.
+     During replay, it returns the cached event data and metadata.
+
+     Internally, this uses the Channel-based Message Queue with event_type as the channel name.
+     CloudEvents metadata is preserved in the message metadata.
+
+     Args:
+         ctx: Workflow context
+         event_type: CloudEvent type to wait for (e.g., "payment.completed")
+         timeout_seconds: Optional timeout in seconds
+         model: Optional Pydantic model class to convert event data to
+         event_id: Optional event identifier (auto-generated if not provided)
+
+     Returns:
+         ReceivedEvent object containing event data and CloudEvents metadata.
+         If model is provided, ReceivedEvent.data will be a Pydantic model instance.
+
+     Note:
+         Events are delivered to workflows that are subscribed to the event_type channel.
+         Use subscribe(ctx, event_type) before calling wait_event() or let it auto-subscribe.
+
+     Raises:
+         WaitForChannelMessageException: During normal execution to pause the workflow
+         EventTimeoutError: If timeout is reached
+
+     Example:
+         >>> # Without Pydantic (dict access)
+         >>> @workflow
+         ... async def order_workflow(ctx: WorkflowContext, order_id: str):
+         ...     await subscribe(ctx, "payment.completed", mode="broadcast")
+         ...     payment_event = await wait_event(ctx, "payment.completed")
+         ...     amount = payment_event.data["amount"]
+         ...     order_id = payment_event.data["order_id"]
+         ...
+         >>> # With Pydantic (type-safe access)
+         >>> @workflow
+         ... async def order_workflow_typed(ctx: WorkflowContext, order_id: str):
+         ...     await subscribe(ctx, "payment.completed", mode="broadcast")
+         ...     payment_event = await wait_event(
+         ...         ctx,
+         ...         event_type="payment.completed",
+         ...         model=PaymentCompleted
+         ...     )
+         ...     # Type-safe access with IDE completion
+         ...     amount = payment_event.data.amount
+     """
+     # Auto-subscribe to the event_type channel if not already subscribed
+     subscription = await _get_subscription(ctx, event_type)
+     if subscription is None:
+         await subscribe(ctx, event_type, mode="broadcast")
+
+     # Use receive() with event_type as channel
+     msg = await receive(
+         ctx,
+         channel=event_type,
+         timeout_seconds=timeout_seconds,
+         message_id=event_id,
+     )
+
+     # Convert ChannelMessage to ReceivedEvent with CloudEvents metadata
+     return _convert_channel_message_to_received_event(msg, event_type, model)
+
+
+ # Backward compatibility aliases
+ wait_timer = sleep
+ wait_until = sleep_until
+
+
+ async def send_event(
+     event_type: str,
+     source: str,
+     data: dict[str, Any] | Any,
+     broker_url: str = "http://broker-ingress.knative-eventing.svc.cluster.local",
+     datacontenttype: str | None = None,
+ ) -> None:
+     """
+     Send a CloudEvent to Knative Broker.
+
+     Args:
+         event_type: CloudEvent type (e.g., "order.created")
+         source: CloudEvent source (e.g., "order-service")
+         data: Event payload (JSON dict or Pydantic model)
+         broker_url: Knative Broker URL
+         datacontenttype: Content type (defaults to "application/json")
+
+     Raises:
+         httpx.HTTPError: If the HTTP request fails
+
+     Example:
+         >>> # With dict
+         >>> await send_event("order.created", "order-service", {"order_id": "123"})
+         >>>
+         >>> # With Pydantic model (automatically converted to JSON)
+         >>> order = OrderCreated(order_id="123", amount=99.99)
+         >>> await send_event("order.created", "order-service", order)
+     """
+     import httpx
+     from cloudevents.conversion import to_structured
+     from cloudevents.http import CloudEvent
+
+     from edda.pydantic_utils import is_pydantic_instance, to_json_dict
+
+     # Convert Pydantic model to JSON dict
+     data_dict: dict[str, Any]
+     if is_pydantic_instance(data):
+         data_dict = to_json_dict(data)
+     elif isinstance(data, dict):
+         data_dict = data
+     else:
+         data_dict = {"_data": data}
+
+     # Create CloudEvent attributes
+     attributes: dict[str, Any] = {
+         "type": event_type,
+         "source": source,
+         "id": str(uuid.uuid4()),
+     }
+
+     # Set datacontenttype if specified
+     if datacontenttype:
+         attributes["datacontenttype"] = datacontenttype
+
+     event = CloudEvent(attributes, data_dict)
+
+     # Convert to structured format (HTTP)
+     headers, body = to_structured(event)
+
+     # Send to Knative Broker via HTTP POST
+     async with httpx.AsyncClient() as client:
+         response = await client.post(
+             broker_url,
+             headers=headers,
+             content=body,
+             timeout=10.0,
+         )
+         response.raise_for_status()
+
+
+ # =============================================================================
+ # Utility Functions
+ # =============================================================================
+
+
+ async def get_channel_stats(
+     _storage: StorageProtocol,
+     channel: str,
+ ) -> dict[str, Any]:
+     """
+     Get statistics about a channel.
+
+     Args:
+         _storage: Storage backend
+         channel: Channel name
+
+     Returns:
+         Dictionary with channel statistics
+     """
+     # TODO: Implement actual statistics retrieval using _storage
+     # - Query ChannelMessage table for message_count
+     # - Query ChannelSubscription table for subscriber_count
+     # For now, return placeholder values
+     return {
+         "channel": channel,
+         "message_count": 0,
+         "subscriber_count": 0,
+     }
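A sketch of where the TODO might land once the storage layer grows counting queries; the two `count_*` methods are assumed helpers, not part of StorageProtocol today:

    # Hypothetical implementation of get_channel_stats (method names are assumptions).
    async def get_channel_stats_impl(storage: StorageProtocol, channel: str) -> dict[str, Any]:
        return {
            "channel": channel,
            "message_count": await storage.count_channel_messages(channel),
            "subscriber_count": await storage.count_channel_subscribers(channel),
        }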