penguiflow 1.0.3__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of penguiflow might be problematic. Click here for more details.

penguiflow/core.py CHANGED
@@ -10,19 +10,25 @@ import asyncio
10
10
  import logging
11
11
  import time
12
12
  from collections import deque
13
- from collections.abc import Callable, Sequence
13
+ from collections.abc import Awaitable, Callable, Mapping, Sequence
14
+ from contextlib import suppress
14
15
  from dataclasses import dataclass
15
- from typing import Any
16
+ from typing import Any, cast
16
17
 
18
+ from .bus import BusEnvelope, MessageBus
19
+ from .errors import FlowError, FlowErrorCode
20
+ from .metrics import FlowEvent
17
21
  from .middlewares import Middleware
18
22
  from .node import Node, NodePolicy
19
23
  from .registry import ModelRegistry
20
- from .types import WM, FinalAnswer, Message
24
+ from .state import RemoteBinding, StateStore, StoredEvent
25
+ from .types import WM, FinalAnswer, Message, StreamChunk
21
26
 
22
27
  logger = logging.getLogger("penguiflow.core")
23
28
 
24
29
  BUDGET_EXCEEDED_TEXT = "Hop budget exhausted"
25
30
  DEADLINE_EXCEEDED_TEXT = "Deadline exceeded"
31
+ TOKEN_BUDGET_EXCEEDED_TEXT = "Token budget exhausted"
26
32
 
27
33
  DEFAULT_QUEUE_MAXSIZE = 64
28
34
 
@@ -59,21 +65,46 @@ class Floe:
59
65
  self.queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=maxsize)
60
66
 
61
67
 
68
+ class TraceCancelled(Exception):
69
+ """Raised when work for a specific trace_id is cancelled."""
70
+
71
+ def __init__(self, trace_id: str | None) -> None:
72
+ super().__init__(f"trace cancelled: {trace_id}")
73
+ self.trace_id = trace_id
74
+
75
+
62
76
  class Context:
63
77
  """Provides fetch/emit helpers for a node within a flow."""
64
78
 
65
- __slots__ = ("_owner", "_incoming", "_outgoing", "_buffer")
79
+ __slots__ = (
80
+ "_owner",
81
+ "_incoming",
82
+ "_outgoing",
83
+ "_buffer",
84
+ "_stream_seq",
85
+ "_runtime",
86
+ )
66
87
 
67
- def __init__(self, owner: Node | Endpoint) -> None:
88
+ def __init__(
89
+ self, owner: Node | Endpoint, runtime: PenguiFlow | None = None
90
+ ) -> None:
68
91
  self._owner = owner
69
92
  self._incoming: dict[Node | Endpoint, Floe] = {}
70
93
  self._outgoing: dict[Node | Endpoint, Floe] = {}
71
94
  self._buffer: deque[Any] = deque()
95
+ self._stream_seq: dict[str, int] = {}
96
+ self._runtime = runtime
72
97
 
73
98
  @property
74
99
  def owner(self) -> Node | Endpoint:
75
100
  return self._owner
76
101
 
102
+ @property
103
+ def runtime(self) -> PenguiFlow | None:
104
+ """Return the runtime this context is attached to, if any."""
105
+
106
+ return self._runtime
107
+
77
108
  def add_incoming_floe(self, floe: Floe) -> None:
78
109
  if floe.source is None:
79
110
  return
@@ -111,14 +142,79 @@ class Context:
111
142
  async def emit(
112
143
  self, msg: Any, to: Node | Endpoint | Sequence[Node | Endpoint] | None = None
113
144
  ) -> None:
145
+ if self._runtime is None:
146
+ raise RuntimeError("Context is not attached to a running flow")
114
147
  for floe in self._resolve_targets(to, self._outgoing):
115
- await floe.queue.put(msg)
148
+ await self._runtime._send_to_floe(floe, msg)
116
149
 
117
150
  def emit_nowait(
118
151
  self, msg: Any, to: Node | Endpoint | Sequence[Node | Endpoint] | None = None
119
152
  ) -> None:
153
+ if self._runtime is None:
154
+ raise RuntimeError("Context is not attached to a running flow")
120
155
  for floe in self._resolve_targets(to, self._outgoing):
121
- floe.queue.put_nowait(msg)
156
+ self._runtime._send_to_floe_nowait(floe, msg)
157
+
158
+ async def emit_chunk(
159
+ self,
160
+ *,
161
+ parent: Message,
162
+ text: str,
163
+ stream_id: str | None = None,
164
+ seq: int | None = None,
165
+ done: bool = False,
166
+ meta: dict[str, Any] | None = None,
167
+ to: Node | Endpoint | Sequence[Node | Endpoint] | None = None,
168
+ ) -> StreamChunk:
169
+ """Emit a streaming chunk that inherits routing metadata from ``parent``.
170
+
171
+ The helper manages monotonically increasing sequence numbers per
172
+ ``stream_id`` (defaulting to the parent's trace id) unless an explicit
173
+ ``seq`` is provided. It returns the emitted ``StreamChunk`` for
174
+ introspection in tests or downstream logic.
175
+ """
176
+
177
+ sid = stream_id or parent.trace_id
178
+ first_chunk = sid not in self._stream_seq
179
+ if seq is None:
180
+ next_seq = self._stream_seq.get(sid, -1) + 1
181
+ else:
182
+ next_seq = seq
183
+ self._stream_seq[sid] = next_seq
184
+
185
+ meta_dict = dict(meta) if meta else {}
186
+
187
+ chunk = StreamChunk(
188
+ stream_id=sid,
189
+ seq=next_seq,
190
+ text=text,
191
+ done=done,
192
+ meta=meta_dict,
193
+ )
194
+
195
+ message_meta = dict(parent.meta)
196
+
197
+ message = Message(
198
+ payload=chunk,
199
+ headers=parent.headers,
200
+ trace_id=parent.trace_id,
201
+ deadline_s=parent.deadline_s,
202
+ meta=message_meta,
203
+ )
204
+
205
+ runtime = self._runtime
206
+ if runtime is None:
207
+ raise RuntimeError("Context is not attached to a running flow")
208
+
209
+ if not first_chunk:
210
+ await runtime._await_trace_capacity(sid, offset=1)
211
+
212
+ await self.emit(message, to=to)
213
+
214
+ if done:
215
+ self._stream_seq.pop(sid, None)
216
+
217
+ return chunk
122
218
 
123
219
  def fetch_nowait(
124
220
  self, from_: Node | Endpoint | Sequence[Node | Endpoint] | None = None
@@ -175,6 +271,25 @@ class Context:
175
271
  def queue_depth_in(self) -> int:
176
272
  return sum(floe.queue.qsize() for floe in self._incoming.values())
177
273
 
274
+ def queue_depth_out(self) -> int:
275
+ return sum(floe.queue.qsize() for floe in self._outgoing.values())
276
+
277
+ async def call_playbook(
278
+ self,
279
+ playbook: PlaybookFactory,
280
+ parent_msg: Message,
281
+ *,
282
+ timeout: float | None = None,
283
+ ) -> Any:
284
+ """Launch a subflow playbook using the current runtime for propagation."""
285
+
286
+ return await call_playbook(
287
+ playbook,
288
+ parent_msg,
289
+ timeout=timeout,
290
+ runtime=self._runtime,
291
+ )
292
+
178
293
 
179
294
  class PenguiFlow:
180
295
  """Coordinates node execution and message routing."""
@@ -185,6 +300,9 @@ class PenguiFlow:
185
300
  queue_maxsize: int = DEFAULT_QUEUE_MAXSIZE,
186
301
  allow_cycles: bool = False,
187
302
  middlewares: Sequence[Middleware] | None = None,
303
+ emit_errors_to_rookery: bool = False,
304
+ state_store: StateStore | None = None,
305
+ message_bus: MessageBus | None = None,
188
306
  ) -> None:
189
307
  self._queue_maxsize = queue_maxsize
190
308
  self._allow_cycles = allow_cycles
@@ -196,6 +314,16 @@ class PenguiFlow:
196
314
  self._running = False
197
315
  self._registry: Any | None = None
198
316
  self._middlewares: list[Middleware] = list(middlewares or [])
317
+ self._trace_counts: dict[str, int] = {}
318
+ self._trace_events: dict[str, asyncio.Event] = {}
319
+ self._trace_invocations: dict[str, set[asyncio.Task[Any]]] = {}
320
+ self._external_tasks: dict[str, set[asyncio.Future[Any]]] = {}
321
+ self._trace_capacity_waiters: dict[str, list[asyncio.Event]] = {}
322
+ self._latest_wm_hops: dict[str, int] = {}
323
+ self._emit_errors_to_rookery = emit_errors_to_rookery
324
+ self._state_store = state_store
325
+ self._message_bus = message_bus
326
+ self._bus_tasks: set[asyncio.Task[None]] = set()
199
327
 
200
328
  self._build_graph(adjacencies)
201
329
 
@@ -219,9 +347,9 @@ class PenguiFlow:
219
347
 
220
348
  # create contexts for nodes and endpoints
221
349
  for node in self._nodes:
222
- self._contexts[node] = Context(node)
223
- self._contexts[OPEN_SEA] = Context(OPEN_SEA)
224
- self._contexts[ROOKERY] = Context(ROOKERY)
350
+ self._contexts[node] = Context(node, self)
351
+ self._contexts[OPEN_SEA] = Context(OPEN_SEA, self)
352
+ self._contexts[ROOKERY] = Context(ROOKERY, self)
225
353
 
226
354
  incoming: dict[Node, set[Node | Endpoint]] = {
227
355
  node: set() for node in self._nodes
@@ -303,7 +431,49 @@ class PenguiFlow:
303
431
  while True:
304
432
  try:
305
433
  message = await context.fetch()
306
- await self._execute_with_reliability(node, context, message)
434
+ trace_id = self._get_trace_id(message)
435
+ if self._deadline_expired(message):
436
+ await self._emit_event(
437
+ event="deadline_skip",
438
+ node=node,
439
+ context=context,
440
+ trace_id=trace_id,
441
+ attempt=0,
442
+ latency_ms=None,
443
+ level=logging.INFO,
444
+ extra={"deadline_s": getattr(message, "deadline_s", None)},
445
+ )
446
+ if isinstance(message, Message):
447
+ await self._handle_deadline_expired(context, message)
448
+ await self._finalize_message(message)
449
+ continue
450
+ if trace_id is not None and self._is_trace_cancelled(trace_id):
451
+ await self._emit_event(
452
+ event="trace_cancel_drop",
453
+ node=node,
454
+ context=context,
455
+ trace_id=trace_id,
456
+ attempt=0,
457
+ latency_ms=None,
458
+ level=logging.INFO,
459
+ )
460
+ await self._finalize_message(message)
461
+ continue
462
+
463
+ try:
464
+ await self._execute_with_reliability(node, context, message)
465
+ except TraceCancelled:
466
+ await self._emit_event(
467
+ event="node_trace_cancelled",
468
+ node=node,
469
+ context=context,
470
+ trace_id=trace_id,
471
+ attempt=0,
472
+ latency_ms=None,
473
+ level=logging.INFO,
474
+ )
475
+ finally:
476
+ await self._finalize_message(message)
307
477
  except asyncio.CancelledError:
308
478
  await self._emit_event(
309
479
  event="node_cancelled",
@@ -323,19 +493,166 @@ class PenguiFlow:
323
493
  task.cancel()
324
494
  await asyncio.gather(*self._tasks, return_exceptions=True)
325
495
  self._tasks.clear()
496
+ if self._trace_invocations:
497
+ pending: list[asyncio.Task[Any]] = []
498
+ for invocation_tasks in self._trace_invocations.values():
499
+ for task in invocation_tasks:
500
+ if not task.done():
501
+ task.cancel()
502
+ pending.append(task)
503
+ if pending:
504
+ await asyncio.gather(*pending, return_exceptions=True)
505
+ self._trace_invocations.clear()
506
+ if self._external_tasks:
507
+ pending_ext: list[asyncio.Future[Any]] = []
508
+ for external_tasks in self._external_tasks.values():
509
+ for external_task in external_tasks:
510
+ if not external_task.done():
511
+ external_task.cancel()
512
+ pending_ext.append(external_task)
513
+ if pending_ext:
514
+ await asyncio.gather(*pending_ext, return_exceptions=True)
515
+ self._external_tasks.clear()
516
+ if self._bus_tasks:
517
+ await asyncio.gather(*self._bus_tasks, return_exceptions=True)
518
+ self._bus_tasks.clear()
519
+ self._trace_counts.clear()
520
+ self._trace_events.clear()
521
+ self._trace_invocations.clear()
522
+ for waiters in self._trace_capacity_waiters.values():
523
+ for waiter in waiters:
524
+ waiter.set()
525
+ self._trace_capacity_waiters.clear()
326
526
  self._running = False
327
527
 
328
528
  async def emit(self, msg: Any, to: Node | Sequence[Node] | None = None) -> None:
529
+ if isinstance(msg, Message):
530
+ payload = msg.payload
531
+ if isinstance(payload, WM):
532
+ last = self._latest_wm_hops.get(msg.trace_id)
533
+ if last is not None and payload.hops == last:
534
+ return
329
535
  await self._contexts[OPEN_SEA].emit(msg, to)
330
536
 
331
537
  def emit_nowait(self, msg: Any, to: Node | Sequence[Node] | None = None) -> None:
538
+ if isinstance(msg, Message):
539
+ payload = msg.payload
540
+ if isinstance(payload, WM):
541
+ last = self._latest_wm_hops.get(msg.trace_id)
542
+ if last is not None and payload.hops == last:
543
+ return
332
544
  self._contexts[OPEN_SEA].emit_nowait(msg, to)
333
545
 
546
+ async def emit_chunk(
547
+ self,
548
+ *,
549
+ parent: Message,
550
+ text: str,
551
+ stream_id: str | None = None,
552
+ seq: int | None = None,
553
+ done: bool = False,
554
+ meta: dict[str, Any] | None = None,
555
+ to: Node | Sequence[Node] | None = None,
556
+ ) -> StreamChunk:
557
+ """Emit a streaming chunk from outside a node via OpenSea context."""
558
+
559
+ return await self._contexts[OPEN_SEA].emit_chunk(
560
+ parent=parent,
561
+ text=text,
562
+ stream_id=stream_id,
563
+ seq=seq,
564
+ done=done,
565
+ meta=meta,
566
+ to=to,
567
+ )
568
+
334
569
  async def fetch(self, from_: Node | Sequence[Node] | None = None) -> Any:
335
- return await self._contexts[ROOKERY].fetch(from_)
570
+ result = await self._contexts[ROOKERY].fetch(from_)
571
+ await self._finalize_message(result)
572
+ return result
336
573
 
337
574
  async def fetch_any(self, from_: Node | Sequence[Node] | None = None) -> Any:
338
- return await self._contexts[ROOKERY].fetch_any(from_)
575
+ result = await self._contexts[ROOKERY].fetch_any(from_)
576
+ await self._finalize_message(result)
577
+ return result
578
+
579
+ async def load_history(self, trace_id: str) -> Sequence[StoredEvent]:
580
+ """Return the persisted history for ``trace_id`` from the state store."""
581
+
582
+ if self._state_store is None:
583
+ raise RuntimeError("PenguiFlow was created without a state_store")
584
+ return await self._state_store.load_history(trace_id)
585
+
586
+ def ensure_trace_event(self, trace_id: str) -> asyncio.Event:
587
+ """Return (and create if needed) the cancellation event for ``trace_id``."""
588
+
589
+ return self._trace_events.setdefault(trace_id, asyncio.Event())
590
+
591
+ def register_external_task(self, trace_id: str, task: asyncio.Future[Any]) -> None:
592
+ """Track an externally created task for cancellation bookkeeping."""
593
+
594
+ if trace_id is None:
595
+ return
596
+ tasks = self._external_tasks.get(trace_id)
597
+ if tasks is None:
598
+ tasks = set[asyncio.Future[Any]]()
599
+ self._external_tasks[trace_id] = tasks
600
+ tasks.add(task)
601
+
602
+ def _cleanup(finished: asyncio.Future[Any]) -> None:
603
+ remaining = self._external_tasks.get(trace_id)
604
+ if remaining is None:
605
+ return
606
+ remaining.discard(finished)
607
+ if not remaining:
608
+ self._external_tasks.pop(trace_id, None)
609
+
610
+ task.add_done_callback(_cleanup)
611
+
612
+ async def save_remote_binding(self, binding: RemoteBinding) -> None:
613
+ """Persist a remote binding if a state store is configured."""
614
+
615
+ if self._state_store is None:
616
+ return
617
+ try:
618
+ await self._state_store.save_remote_binding(binding)
619
+ except Exception as exc: # pragma: no cover - defensive logging
620
+ logger.exception(
621
+ "state_store_binding_failed",
622
+ extra={
623
+ "event": "state_store_binding_failed",
624
+ "trace_id": binding.trace_id,
625
+ "context_id": binding.context_id,
626
+ "task_id": binding.task_id,
627
+ "agent_url": binding.agent_url,
628
+ "exception": repr(exc),
629
+ },
630
+ )
631
+
632
+ async def record_remote_event(
633
+ self,
634
+ *,
635
+ event: str,
636
+ node: Node,
637
+ context: Context,
638
+ trace_id: str | None,
639
+ latency_ms: float | None,
640
+ level: int = logging.INFO,
641
+ extra: Mapping[str, Any] | None = None,
642
+ ) -> None:
643
+ """Emit a structured :class:`FlowEvent` for remote transport activity."""
644
+
645
+ payload = dict(extra or {})
646
+ await self._emit_event(
647
+ event=event,
648
+ node=node,
649
+ context=context,
650
+ trace_id=trace_id,
651
+ attempt=0,
652
+ latency_ms=latency_ms,
653
+ level=level,
654
+ extra=payload,
655
+ )
339
656
 
340
657
  async def _execute_with_reliability(
341
658
  self,
@@ -347,6 +664,9 @@ class PenguiFlow:
347
664
  attempt = 0
348
665
 
349
666
  while True:
667
+ if trace_id is not None and self._is_trace_cancelled(trace_id):
668
+ raise TraceCancelled(trace_id)
669
+
350
670
  start = time.perf_counter()
351
671
  await self._emit_event(
352
672
  event="node_start",
@@ -359,22 +679,42 @@ class PenguiFlow:
359
679
  )
360
680
 
361
681
  try:
362
- invocation = node.invoke(message, context, registry=self._registry)
363
- if node.policy.timeout_s is not None:
364
- result = await asyncio.wait_for(invocation, node.policy.timeout_s)
365
- else:
366
- result = await invocation
682
+ result = await self._invoke_node(
683
+ node,
684
+ context,
685
+ message,
686
+ trace_id,
687
+ )
367
688
 
368
689
  if result is not None:
369
- destination, prepared, targets = self._controller_postprocess(
690
+ (
691
+ destination,
692
+ prepared,
693
+ targets,
694
+ deliver_rookery,
695
+ ) = self._controller_postprocess(
370
696
  node, context, message, result
371
697
  )
372
698
 
699
+ if deliver_rookery:
700
+ rookery_msg = (
701
+ prepared.model_copy(deep=True)
702
+ if isinstance(prepared, Message)
703
+ else prepared
704
+ )
705
+ await self._emit_to_rookery(
706
+ rookery_msg, source=context.owner
707
+ )
708
+
373
709
  if destination == "skip":
374
710
  continue
711
+
375
712
  if destination == "rookery":
376
- await context.emit(prepared, to=[ROOKERY])
713
+ await self._emit_to_rookery(
714
+ prepared, source=context.owner
715
+ )
377
716
  continue
717
+
378
718
  await context.emit(prepared, to=targets)
379
719
 
380
720
  latency = (time.perf_counter() - start) * 1000
@@ -388,6 +728,8 @@ class PenguiFlow:
388
728
  level=logging.INFO,
389
729
  )
390
730
  return
731
+ except TraceCancelled:
732
+ raise
391
733
  except asyncio.CancelledError:
392
734
  raise
393
735
  except TimeoutError as exc:
@@ -403,15 +745,28 @@ class PenguiFlow:
403
745
  extra={"exception": repr(exc)},
404
746
  )
405
747
  if attempt >= node.policy.max_retries:
406
- await self._emit_event(
407
- event="node_failed",
748
+ timeout_message: str | None = None
749
+ if node.policy.timeout_s is not None:
750
+ timeout_message = (
751
+ f"Node '{node.name}' timed out after"
752
+ f" {node.policy.timeout_s:.2f}s"
753
+ )
754
+ flow_error = self._create_flow_error(
408
755
  node=node,
409
- context=context,
410
756
  trace_id=trace_id,
757
+ code=FlowErrorCode.NODE_TIMEOUT,
758
+ exc=exc,
411
759
  attempt=attempt,
412
760
  latency_ms=latency,
413
- level=logging.ERROR,
414
- extra={"exception": repr(exc)},
761
+ message=timeout_message,
762
+ metadata={"timeout_s": node.policy.timeout_s},
763
+ )
764
+ await self._handle_flow_error(
765
+ node=node,
766
+ context=context,
767
+ flow_error=flow_error,
768
+ latency=latency,
769
+ attempt=attempt,
415
770
  )
416
771
  return
417
772
  attempt += 1
@@ -441,15 +796,24 @@ class PenguiFlow:
441
796
  extra={"exception": repr(exc)},
442
797
  )
443
798
  if attempt >= node.policy.max_retries:
444
- await self._emit_event(
445
- event="node_failed",
799
+ flow_error = self._create_flow_error(
446
800
  node=node,
447
- context=context,
448
801
  trace_id=trace_id,
802
+ code=FlowErrorCode.NODE_EXCEPTION,
803
+ exc=exc,
449
804
  attempt=attempt,
450
805
  latency_ms=latency,
451
- level=logging.ERROR,
452
- extra={"exception": repr(exc)},
806
+ message=(
807
+ f"Node '{node.name}' raised {type(exc).__name__}: {exc}"
808
+ ),
809
+ metadata={"exception_repr": repr(exc)},
810
+ )
811
+ await self._handle_flow_error(
812
+ node=node,
813
+ context=context,
814
+ flow_error=flow_error,
815
+ latency=latency,
816
+ attempt=attempt,
453
817
  )
454
818
  return
455
819
  attempt += 1
@@ -473,13 +837,386 @@ class PenguiFlow:
473
837
  delay = min(delay, policy.max_backoff)
474
838
  return delay
475
839
 
840
+ def _create_flow_error(
841
+ self,
842
+ *,
843
+ node: Node,
844
+ trace_id: str | None,
845
+ code: FlowErrorCode,
846
+ exc: BaseException,
847
+ attempt: int,
848
+ latency_ms: float | None,
849
+ message: str | None = None,
850
+ metadata: Mapping[str, Any] | None = None,
851
+ ) -> FlowError:
852
+ node_name = node.name
853
+ assert node_name is not None
854
+ meta: dict[str, Any] = {"attempt": attempt}
855
+ if latency_ms is not None:
856
+ meta["latency_ms"] = latency_ms
857
+ if metadata:
858
+ meta.update(metadata)
859
+ return FlowError.from_exception(
860
+ trace_id=trace_id,
861
+ node_name=node_name,
862
+ node_id=node.node_id,
863
+ exc=exc,
864
+ code=code,
865
+ message=message,
866
+ metadata=meta,
867
+ )
868
+
869
+ async def _handle_flow_error(
870
+ self,
871
+ *,
872
+ node: Node,
873
+ context: Context,
874
+ flow_error: FlowError,
875
+ latency: float | None,
876
+ attempt: int,
877
+ ) -> None:
878
+ original = flow_error.unwrap()
879
+ exception_repr = repr(original) if original is not None else flow_error.message
880
+ extra = {
881
+ "exception": exception_repr,
882
+ "flow_error": flow_error.to_payload(),
883
+ }
884
+ await self._emit_event(
885
+ event="node_failed",
886
+ node=node,
887
+ context=context,
888
+ trace_id=flow_error.trace_id,
889
+ attempt=attempt,
890
+ latency_ms=latency,
891
+ level=logging.ERROR,
892
+ extra=extra,
893
+ )
894
+ if self._emit_errors_to_rookery and flow_error.trace_id is not None:
895
+ await self._emit_to_rookery(flow_error, source=context.owner)
896
+
897
+ async def _invoke_node(
898
+ self,
899
+ node: Node,
900
+ context: Context,
901
+ message: Any,
902
+ trace_id: str | None,
903
+ ) -> Any:
904
+ invocation = node.invoke(message, context, registry=self._registry)
905
+ timeout = node.policy.timeout_s
906
+
907
+ if trace_id is None:
908
+ if timeout is None:
909
+ return await invocation
910
+ return await asyncio.wait_for(invocation, timeout)
911
+
912
+ return await self._await_invocation(node, invocation, trace_id, timeout)
913
+
914
+ def _register_invocation_task(
915
+ self, trace_id: str, task: asyncio.Task[Any]
916
+ ) -> None:
917
+ tasks = self._trace_invocations.get(trace_id)
918
+ if tasks is None:
919
+ tasks = set[asyncio.Task[Any]]()
920
+ self._trace_invocations[trace_id] = tasks
921
+ tasks.add(task)
922
+
923
+ def _cleanup(finished: asyncio.Future[Any]) -> None:
924
+ remaining = self._trace_invocations.get(trace_id)
925
+ if remaining is None:
926
+ return
927
+ remaining.discard(cast(asyncio.Task[Any], finished))
928
+ if not remaining:
929
+ self._trace_invocations.pop(trace_id, None)
930
+
931
+ task.add_done_callback(_cleanup)
932
+
933
+ async def _await_invocation(
934
+ self,
935
+ node: Node,
936
+ invocation: Awaitable[Any],
937
+ trace_id: str,
938
+ timeout: float | None,
939
+ ) -> Any:
940
+ invocation_task = cast(asyncio.Task[Any], asyncio.ensure_future(invocation))
941
+ self._register_invocation_task(trace_id, invocation_task)
942
+
943
+ cancel_event = self._trace_events.get(trace_id)
944
+ cancel_waiter: asyncio.Future[Any] | None = None
945
+ if cancel_event is not None:
946
+ cancel_waiter = asyncio.ensure_future(cancel_event.wait())
947
+
948
+ timeout_task: asyncio.Future[Any] | None = None
949
+ if timeout is not None:
950
+ timeout_task = asyncio.ensure_future(asyncio.sleep(timeout))
951
+
952
+ wait_tasks: set[asyncio.Future[Any]] = {invocation_task}
953
+ if cancel_waiter is not None:
954
+ wait_tasks.add(cancel_waiter)
955
+ if timeout_task is not None:
956
+ wait_tasks.add(timeout_task)
957
+
958
+ pending: set[asyncio.Future[Any]] = set()
959
+ try:
960
+ done, pending = await asyncio.wait(
961
+ wait_tasks, return_when=asyncio.FIRST_COMPLETED
962
+ )
963
+
964
+ if invocation_task in done:
965
+ if invocation_task.cancelled():
966
+ raise TraceCancelled(trace_id)
967
+ return invocation_task.result()
968
+
969
+ if cancel_waiter is not None and cancel_waiter in done:
970
+ invocation_task.cancel()
971
+ await asyncio.gather(invocation_task, return_exceptions=True)
972
+ raise TraceCancelled(trace_id)
973
+
974
+ if timeout_task is not None and timeout_task in done:
975
+ invocation_task.cancel()
976
+ await asyncio.gather(invocation_task, return_exceptions=True)
977
+ raise TimeoutError
978
+
979
+ raise RuntimeError("node invocation wait exited without result")
980
+ except asyncio.CancelledError:
981
+ invocation_task.cancel()
982
+ await asyncio.gather(invocation_task, return_exceptions=True)
983
+ if cancel_waiter is not None:
984
+ cancel_waiter.cancel()
985
+ if timeout_task is not None:
986
+ timeout_task.cancel()
987
+ await asyncio.gather(
988
+ *(task for task in (cancel_waiter, timeout_task) if task is not None),
989
+ return_exceptions=True,
990
+ )
991
+ raise
992
+ finally:
993
+ for task in pending:
994
+ task.cancel()
995
+ watchers = [
996
+ task for task in (cancel_waiter, timeout_task) if task is not None
997
+ ]
998
+ for watcher in watchers:
999
+ watcher.cancel()
1000
+ if watchers:
1001
+ await asyncio.gather(*watchers, return_exceptions=True)
1002
+
1003
+ def _get_trace_id(self, message: Any) -> str | None:
1004
+ return getattr(message, "trace_id", None)
1005
+
1006
+ def _is_trace_cancelled(self, trace_id: str) -> bool:
1007
+ event = self._trace_events.get(trace_id)
1008
+ return event.is_set() if event is not None else False
1009
+
1010
+ def _on_message_enqueued(self, message: Any) -> None:
1011
+ trace_id = self._get_trace_id(message)
1012
+ if trace_id is None:
1013
+ return
1014
+ self._trace_counts[trace_id] = self._trace_counts.get(trace_id, 0) + 1
1015
+ self._trace_events.setdefault(trace_id, asyncio.Event())
1016
+
1017
+ def _node_label(self, node: Node | Endpoint | None) -> str | None:
1018
+ if node is None:
1019
+ return None
1020
+ name = getattr(node, "name", None)
1021
+ if name:
1022
+ return name
1023
+ return getattr(node, "node_id", None)
1024
+
1025
+ def _build_bus_envelope(
1026
+ self,
1027
+ source: Node | Endpoint | None,
1028
+ target: Node | Endpoint | None,
1029
+ message: Any,
1030
+ ) -> BusEnvelope:
1031
+ source_name = self._node_label(source)
1032
+ target_name = self._node_label(target)
1033
+ edge = f"{source_name or '*'}->{target_name or '*'}"
1034
+ headers: Mapping[str, Any] | None = None
1035
+ meta: Mapping[str, Any] | None = None
1036
+ if isinstance(message, Message):
1037
+ headers = message.headers.model_dump()
1038
+ meta = dict(message.meta)
1039
+ return BusEnvelope(
1040
+ edge=edge,
1041
+ source=source_name,
1042
+ target=target_name,
1043
+ trace_id=self._get_trace_id(message),
1044
+ payload=message,
1045
+ headers=headers,
1046
+ meta=meta,
1047
+ )
1048
+
1049
+ async def _publish_to_bus(
1050
+ self,
1051
+ source: Node | Endpoint | None,
1052
+ target: Node | Endpoint | None,
1053
+ message: Any,
1054
+ ) -> None:
1055
+ if self._message_bus is None:
1056
+ return
1057
+ envelope = self._build_bus_envelope(source, target, message)
1058
+ try:
1059
+ await self._message_bus.publish(envelope)
1060
+ except Exception as exc:
1061
+ logger.exception(
1062
+ "message_bus_publish_failed",
1063
+ extra={
1064
+ "event": "message_bus_publish_failed",
1065
+ "edge": envelope.edge,
1066
+ "trace_id": envelope.trace_id,
1067
+ "exception": repr(exc),
1068
+ },
1069
+ )
1070
+
1071
+ def _schedule_bus_publish(
1072
+ self,
1073
+ source: Node | Endpoint | None,
1074
+ target: Node | Endpoint | None,
1075
+ message: Any,
1076
+ ) -> None:
1077
+ if self._message_bus is None:
1078
+ return
1079
+ loop = asyncio.get_running_loop()
1080
+ task = loop.create_task(self._publish_to_bus(source, target, message))
1081
+ self._bus_tasks.add(task)
1082
+
1083
+ def _cleanup(done: asyncio.Task[None]) -> None:
1084
+ self._bus_tasks.discard(done)
1085
+
1086
+ task.add_done_callback(_cleanup)
1087
+
1088
+ async def _send_to_floe(self, floe: Floe, message: Any) -> None:
1089
+ self._on_message_enqueued(message)
1090
+ if self._message_bus is not None:
1091
+ await self._publish_to_bus(floe.source, floe.target, message)
1092
+ await floe.queue.put(message)
1093
+
1094
+ def _send_to_floe_nowait(self, floe: Floe, message: Any) -> None:
1095
+ self._on_message_enqueued(message)
1096
+ if self._message_bus is not None:
1097
+ self._schedule_bus_publish(floe.source, floe.target, message)
1098
+ floe.queue.put_nowait(message)
1099
+
1100
+ async def _finalize_message(self, message: Any) -> None:
1101
+ trace_id = self._get_trace_id(message)
1102
+ if trace_id is None:
1103
+ return
1104
+
1105
+ remaining = self._trace_counts.get(trace_id)
1106
+ if remaining is None:
1107
+ return
1108
+
1109
+ remaining -= 1
1110
+ if remaining <= 0:
1111
+ self._trace_counts.pop(trace_id, None)
1112
+ event = self._trace_events.pop(trace_id, None)
1113
+ if event is not None and event.is_set():
1114
+ await self._emit_event(
1115
+ event="trace_cancel_finish",
1116
+ node=ROOKERY,
1117
+ context=self._contexts[ROOKERY],
1118
+ trace_id=trace_id,
1119
+ attempt=0,
1120
+ latency_ms=None,
1121
+ level=logging.INFO,
1122
+ )
1123
+ self._notify_trace_capacity(trace_id)
1124
+ self._latest_wm_hops.pop(trace_id, None)
1125
+ else:
1126
+ self._trace_counts[trace_id] = remaining
1127
+ if self._queue_maxsize <= 0 or remaining <= self._queue_maxsize:
1128
+ self._notify_trace_capacity(trace_id)
1129
+
1130
+ async def _drop_trace_from_floe(self, floe: Floe, trace_id: str) -> None:
1131
+ queue = floe.queue
1132
+ retained: list[Any] = []
1133
+
1134
+ while True:
1135
+ try:
1136
+ item = queue.get_nowait()
1137
+ except asyncio.QueueEmpty:
1138
+ break
1139
+
1140
+ if self._get_trace_id(item) == trace_id:
1141
+ await self._finalize_message(item)
1142
+ continue
1143
+
1144
+ retained.append(item)
1145
+
1146
+ for item in retained:
1147
+ queue.put_nowait(item)
1148
+
1149
+ async def cancel(self, trace_id: str) -> bool:
1150
+ if not self._running:
1151
+ raise RuntimeError("PenguiFlow is not running")
1152
+
1153
+ active = trace_id in self._trace_counts or trace_id in self._trace_invocations
1154
+ if not active:
1155
+ return False
1156
+
1157
+ event = self._trace_events.setdefault(trace_id, asyncio.Event())
1158
+ if not event.is_set():
1159
+ event.set()
1160
+ await self._emit_event(
1161
+ event="trace_cancel_start",
1162
+ node=OPEN_SEA,
1163
+ context=self._contexts[OPEN_SEA],
1164
+ trace_id=trace_id,
1165
+ attempt=0,
1166
+ latency_ms=None,
1167
+ level=logging.INFO,
1168
+ extra={"pending": self._trace_counts.get(trace_id, 0)},
1169
+ )
1170
+ else:
1171
+ event.set()
1172
+
1173
+ for floe in list(self._floes):
1174
+ await self._drop_trace_from_floe(floe, trace_id)
1175
+
1176
+ tasks = list(self._trace_invocations.get(trace_id, set()))
1177
+ for task in tasks:
1178
+ task.cancel()
1179
+
1180
+ return True
1181
+
1182
+ async def _await_trace_capacity(self, trace_id: str, *, offset: int = 0) -> None:
1183
+ if self._queue_maxsize <= 0:
1184
+ return
1185
+
1186
+ while True:
1187
+ pending = self._trace_counts.get(trace_id, 0)
1188
+ effective = pending - offset if pending > offset else 0
1189
+ if effective < self._queue_maxsize:
1190
+ return
1191
+ waiter = asyncio.Event()
1192
+ waiters = self._trace_capacity_waiters.setdefault(trace_id, [])
1193
+ waiters.append(waiter)
1194
+ try:
1195
+ await waiter.wait()
1196
+ finally:
1197
+ remaining_waiters = self._trace_capacity_waiters.get(trace_id)
1198
+ if remaining_waiters is not None:
1199
+ try:
1200
+ remaining_waiters.remove(waiter)
1201
+ except ValueError:
1202
+ pass
1203
+ if not remaining_waiters:
1204
+ self._trace_capacity_waiters.pop(trace_id, None)
1205
+
1206
+ def _notify_trace_capacity(self, trace_id: str) -> None:
1207
+ waiters = self._trace_capacity_waiters.pop(trace_id, None)
1208
+ if not waiters:
1209
+ return
1210
+ for waiter in waiters:
1211
+ waiter.set()
1212
+
476
1213
  def _controller_postprocess(
477
1214
  self,
478
1215
  node: Node,
479
1216
  context: Context,
480
1217
  incoming: Any,
481
1218
  result: Any,
482
- ) -> tuple[str, Any, list[Node] | None]:
1219
+ ) -> tuple[str, Any, list[Node] | None, bool]:
483
1220
  if isinstance(result, Message):
484
1221
  payload = result.payload
485
1222
  if isinstance(payload, WM):
@@ -487,27 +1224,102 @@ class PenguiFlow:
487
1224
  if result.deadline_s is not None and now > result.deadline_s:
488
1225
  final = FinalAnswer(text=DEADLINE_EXCEEDED_TEXT)
489
1226
  final_msg = result.model_copy(update={"payload": final})
490
- return "rookery", final_msg, None
1227
+ return "rookery", final_msg, None, False
491
1228
 
492
- if payload.hops + 1 >= payload.budget_hops:
1229
+ if (
1230
+ payload.budget_tokens is not None
1231
+ and payload.tokens_used >= payload.budget_tokens
1232
+ ):
1233
+ final = FinalAnswer(text=TOKEN_BUDGET_EXCEEDED_TEXT)
1234
+ final_msg = result.model_copy(update={"payload": final})
1235
+ return "rookery", final_msg, None, False
1236
+
1237
+ incoming_hops: int | None = None
1238
+ if (
1239
+ isinstance(incoming, Message)
1240
+ and isinstance(incoming.payload, WM)
1241
+ ):
1242
+ incoming_hops = incoming.payload.hops
1243
+
1244
+ current_hops = payload.hops
1245
+ if incoming_hops is not None and current_hops <= incoming_hops:
1246
+ next_hops = incoming_hops + 1
1247
+ else:
1248
+ next_hops = current_hops
1249
+
1250
+ if (
1251
+ payload.budget_hops is not None
1252
+ and next_hops >= payload.budget_hops
1253
+ ):
493
1254
  final = FinalAnswer(text=BUDGET_EXCEEDED_TEXT)
494
1255
  final_msg = result.model_copy(update={"payload": final})
495
- return "rookery", final_msg, None
1256
+ return "rookery", final_msg, None, False
1257
+
1258
+ if next_hops != current_hops:
1259
+ updated_payload = payload.model_copy(update={"hops": next_hops})
1260
+ prepared = result.model_copy(update={"payload": updated_payload})
1261
+ else:
1262
+ prepared = result
496
1263
 
497
- updated_payload = payload.model_copy(update={"hops": payload.hops + 1})
498
- updated_msg = result.model_copy(update={"payload": updated_payload})
499
- return "context", updated_msg, [node]
1264
+ stream_updates = (
1265
+ payload.budget_hops is None
1266
+ and payload.budget_tokens is None
1267
+ )
1268
+ return "context", prepared, [node], stream_updates
500
1269
 
501
1270
  if isinstance(payload, FinalAnswer):
502
- return "rookery", result, None
1271
+ return "rookery", result, None, False
1272
+
1273
+ return "context", result, None, False
1274
+
1275
def _deadline_expired(self, message: Any) -> bool:
    """Tell whether ``message`` carries a deadline that has already passed.

    Only ``Message`` instances with a non-``None`` ``deadline_s`` can be
    expired; ``deadline_s`` is compared against ``time.time()``. Anything
    else yields ``False``.
    """
    if not isinstance(message, Message):
        return False
    if message.deadline_s is None:
        return False
    return time.time() > message.deadline_s
1279
+
1280
async def _handle_deadline_expired(
    self, context: Context, message: Message
) -> None:
    """Send an expired ``message`` to the Rookery as a final answer.

    A payload that is already a ``FinalAnswer`` is kept as-is; any other
    payload is replaced by ``FinalAnswer(text=DEADLINE_EXCEEDED_TEXT)``. The
    message is copied with that payload and emitted through
    ``_emit_to_rookery`` with ``context.owner`` as the source.
    """
    current = message.payload
    final_payload = (
        current
        if isinstance(current, FinalAnswer)
        else FinalAnswer(text=DEADLINE_EXCEEDED_TEXT)
    )
    expired_msg = message.model_copy(update={"payload": final_payload})
    await self._emit_to_rookery(expired_msg, source=context.owner)
1288
+
1289
async def _emit_to_rookery(
    self, message: Any, *, source: Node | Endpoint | None = None
) -> None:
    """Route ``message`` to the Rookery sink regardless of graph edges.

    Delivery target, in order of preference:

    1. the floe registered for ``source`` in the Rookery context's
       ``_incoming`` mapping (only looked up when ``source`` is given);
    2. otherwise the first floe in that mapping, if it has any;
    3. otherwise the Rookery context's ``_buffer`` directly, after calling
       ``_on_message_enqueued`` and, when a message bus is configured,
       ``_publish_to_bus``.

    Afterwards, if ``message`` is a ``Message`` carrying a ``WM`` payload, its
    hop count is stored in ``self._latest_wm_hops`` under the message's
    ``trace_id``.
    """
    sink = self._contexts[ROOKERY]
    inbound = sink._incoming

    target: Floe | None = inbound.get(source) if source is not None else None
    if target is None and inbound:
        # No floe for this source: fall back to the first inbound floe.
        target = next(iter(inbound.values()))

    if target is None:
        # No inbound floe at all: append straight to the sink's buffer.
        self._on_message_enqueued(message)
        if self._message_bus is not None:
            await self._publish_to_bus(source, ROOKERY, message)
        sink._buffer.append(message)
    else:
        await self._send_to_floe(target, message)

    wm = message.payload if isinstance(message, Message) else None
    if isinstance(wm, WM):
        self._latest_wm_hops[message.trace_id] = wm.hops
505
1317
 
506
1318
  async def _emit_event(
507
1319
  self,
508
1320
  *,
509
1321
  event: str,
510
- node: Node,
1322
+ node: Node | Endpoint,
511
1323
  context: Context,
512
1324
  trace_id: str | None,
513
1325
  attempt: int,
@@ -515,31 +1327,65 @@ class PenguiFlow:
515
1327
  level: int,
516
1328
  extra: dict[str, Any] | None = None,
517
1329
  ) -> None:
518
- payload: dict[str, Any] = {
519
- "ts": time.time(),
520
- "event": event,
521
- "node_name": node.name,
522
- "node_id": node.node_id,
523
- "trace_id": trace_id,
524
- "latency_ms": latency_ms,
525
- "q_depth_in": context.queue_depth_in(),
526
- "attempt": attempt,
527
- }
528
- if extra:
529
- payload.update(extra)
530
-
531
- logger.log(level, event, extra=payload)
1330
+ node_name = getattr(node, "name", None)
1331
+ node_id = getattr(node, "node_id", node_name)
1332
+ queue_depth_in = context.queue_depth_in()
1333
+ queue_depth_out = context.queue_depth_out()
1334
+ outgoing = context.outgoing_count()
1335
+
1336
+ trace_pending: int | None = None
1337
+ trace_inflight = 0
1338
+ trace_cancelled = False
1339
+ if trace_id is not None:
1340
+ trace_pending = self._trace_counts.get(trace_id, 0)
1341
+ trace_inflight = len(self._trace_invocations.get(trace_id, set()))
1342
+ trace_cancelled = self._is_trace_cancelled(trace_id)
1343
+
1344
+ event_obj = FlowEvent(
1345
+ event_type=event,
1346
+ ts=time.time(),
1347
+ node_name=node_name,
1348
+ node_id=node_id,
1349
+ trace_id=trace_id,
1350
+ attempt=attempt,
1351
+ latency_ms=latency_ms,
1352
+ queue_depth_in=queue_depth_in,
1353
+ queue_depth_out=queue_depth_out,
1354
+ outgoing_edges=outgoing,
1355
+ queue_maxsize=self._queue_maxsize,
1356
+ trace_pending=trace_pending,
1357
+ trace_inflight=trace_inflight,
1358
+ trace_cancelled=trace_cancelled,
1359
+ extra=extra or {},
1360
+ )
1361
+
1362
+ logger.log(level, event, extra=event_obj.to_payload())
1363
+
1364
+ if self._state_store is not None:
1365
+ stored_event = StoredEvent.from_flow_event(event_obj)
1366
+ try:
1367
+ await self._state_store.save_event(stored_event)
1368
+ except Exception as exc:
1369
+ logger.exception(
1370
+ "state_store_save_failed",
1371
+ extra={
1372
+ "event": "state_store_save_failed",
1373
+ "trace_id": stored_event.trace_id,
1374
+ "kind": stored_event.kind,
1375
+ "exception": repr(exc),
1376
+ },
1377
+ )
532
1378
 
533
1379
  for middleware in list(self._middlewares):
534
1380
  try:
535
- await middleware(event, payload)
1381
+ await middleware(event_obj)
536
1382
  except Exception as exc: # pragma: no cover - defensive
537
1383
  logger.exception(
538
1384
  "middleware_error",
539
1385
  extra={
540
1386
  "event": "middleware_error",
541
- "node_name": node.name,
542
- "node_id": node.node_id,
1387
+ "node_name": node_name,
1388
+ "node_id": node_id,
543
1389
  "exception": exc,
544
1390
  },
545
1391
  )
@@ -572,13 +1418,38 @@ async def call_playbook(
572
1418
  playbook: PlaybookFactory,
573
1419
  parent_msg: Message,
574
1420
  timeout: float | None = None,
1421
+ *,
1422
+ runtime: PenguiFlow | None = None,
575
1423
  ) -> Any:
576
1424
  """Execute a subflow playbook and return the first Rookery payload."""
577
1425
 
578
1426
  flow, registry = playbook()
579
1427
  flow.run(registry=registry)
580
1428
 
1429
+ trace_id = getattr(parent_msg, "trace_id", None)
1430
+ cancel_watch: asyncio.Task[None] | None = None
1431
+ pre_cancelled = False
1432
+
1433
+ if runtime is not None and trace_id is not None:
1434
+ parent_event = runtime._trace_events.setdefault(trace_id, asyncio.Event())
1435
+
1436
+ if parent_event.is_set():
1437
+ pre_cancelled = True
1438
+ else:
1439
+
1440
+ async def _mirror_cancel() -> None:
1441
+ try:
1442
+ await parent_event.wait()
1443
+ except asyncio.CancelledError:
1444
+ return
1445
+ with suppress(Exception):
1446
+ await flow.cancel(trace_id)
1447
+
1448
+ cancel_watch = asyncio.create_task(_mirror_cancel())
1449
+
581
1450
  try:
1451
+ if pre_cancelled:
1452
+ raise TraceCancelled(trace_id)
582
1453
  await flow.emit(parent_msg)
583
1454
  fetch_coro = flow.fetch()
584
1455
  if timeout is not None:
@@ -588,7 +1459,20 @@ async def call_playbook(
588
1459
  if isinstance(result_msg, Message):
589
1460
  return result_msg.payload
590
1461
  return result_msg
1462
+ except TraceCancelled:
1463
+ if trace_id is not None and not pre_cancelled:
1464
+ with suppress(Exception):
1465
+ await flow.cancel(trace_id)
1466
+ raise
1467
+ except asyncio.CancelledError:
1468
+ if trace_id is not None:
1469
+ with suppress(Exception):
1470
+ await flow.cancel(trace_id)
1471
+ raise
591
1472
  finally:
1473
+ if cancel_watch is not None:
1474
+ cancel_watch.cancel()
1475
+ await asyncio.gather(cancel_watch, return_exceptions=True)
592
1476
  await asyncio.shield(flow.stop())
593
1477
 
594
1478