penguiflow 1.0.3__py3-none-any.whl → 2.0.0__py3-none-any.whl

penguiflow/core.py CHANGED
@@ -10,19 +10,23 @@ import asyncio
  import logging
  import time
  from collections import deque
- from collections.abc import Callable, Sequence
+ from collections.abc import Awaitable, Callable, Mapping, Sequence
+ from contextlib import suppress
  from dataclasses import dataclass
  from typing import Any

+ from .errors import FlowError, FlowErrorCode
+ from .metrics import FlowEvent
  from .middlewares import Middleware
  from .node import Node, NodePolicy
  from .registry import ModelRegistry
- from .types import WM, FinalAnswer, Message
+ from .types import WM, FinalAnswer, Message, StreamChunk

  logger = logging.getLogger("penguiflow.core")

  BUDGET_EXCEEDED_TEXT = "Hop budget exhausted"
  DEADLINE_EXCEEDED_TEXT = "Deadline exceeded"
+ TOKEN_BUDGET_EXCEEDED_TEXT = "Token budget exhausted"

  DEFAULT_QUEUE_MAXSIZE = 64

@@ -59,21 +63,46 @@ class Floe:
          self.queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=maxsize)


+ class TraceCancelled(Exception):
+     """Raised when work for a specific trace_id is cancelled."""
+
+     def __init__(self, trace_id: str | None) -> None:
+         super().__init__(f"trace cancelled: {trace_id}")
+         self.trace_id = trace_id
+
+
  class Context:
      """Provides fetch/emit helpers for a node within a flow."""

-     __slots__ = ("_owner", "_incoming", "_outgoing", "_buffer")
+     __slots__ = (
+         "_owner",
+         "_incoming",
+         "_outgoing",
+         "_buffer",
+         "_stream_seq",
+         "_runtime",
+     )

-     def __init__(self, owner: Node | Endpoint) -> None:
+     def __init__(
+         self, owner: Node | Endpoint, runtime: PenguiFlow | None = None
+     ) -> None:
          self._owner = owner
          self._incoming: dict[Node | Endpoint, Floe] = {}
          self._outgoing: dict[Node | Endpoint, Floe] = {}
          self._buffer: deque[Any] = deque()
+         self._stream_seq: dict[str, int] = {}
+         self._runtime = runtime

      @property
      def owner(self) -> Node | Endpoint:
          return self._owner

+     @property
+     def runtime(self) -> PenguiFlow | None:
+         """Return the runtime this context is attached to, if any."""
+
+         return self._runtime
+
      def add_incoming_floe(self, floe: Floe) -> None:
          if floe.source is None:
              return
@@ -111,15 +140,82 @@ class Context:
      async def emit(
          self, msg: Any, to: Node | Endpoint | Sequence[Node | Endpoint] | None = None
      ) -> None:
+         if self._runtime is None:
+             raise RuntimeError("Context is not attached to a running flow")
          for floe in self._resolve_targets(to, self._outgoing):
+             self._runtime._on_message_enqueued(msg)
              await floe.queue.put(msg)

      def emit_nowait(
          self, msg: Any, to: Node | Endpoint | Sequence[Node | Endpoint] | None = None
      ) -> None:
+         if self._runtime is None:
+             raise RuntimeError("Context is not attached to a running flow")
          for floe in self._resolve_targets(to, self._outgoing):
+             self._runtime._on_message_enqueued(msg)
              floe.queue.put_nowait(msg)

+     async def emit_chunk(
+         self,
+         *,
+         parent: Message,
+         text: str,
+         stream_id: str | None = None,
+         seq: int | None = None,
+         done: bool = False,
+         meta: dict[str, Any] | None = None,
+         to: Node | Endpoint | Sequence[Node | Endpoint] | None = None,
+     ) -> StreamChunk:
+         """Emit a streaming chunk that inherits routing metadata from ``parent``.
+
+         The helper manages monotonically increasing sequence numbers per
+         ``stream_id`` (defaulting to the parent's trace id) unless an explicit
+         ``seq`` is provided. It returns the emitted ``StreamChunk`` for
+         introspection in tests or downstream logic.
+         """
+
+         sid = stream_id or parent.trace_id
+         first_chunk = sid not in self._stream_seq
+         if seq is None:
+             next_seq = self._stream_seq.get(sid, -1) + 1
+         else:
+             next_seq = seq
+         self._stream_seq[sid] = next_seq
+
+         meta_dict = dict(meta) if meta else {}
+
+         chunk = StreamChunk(
+             stream_id=sid,
+             seq=next_seq,
+             text=text,
+             done=done,
+             meta=meta_dict,
+         )
+
+         message_meta = dict(parent.meta)
+
+         message = Message(
+             payload=chunk,
+             headers=parent.headers,
+             trace_id=parent.trace_id,
+             deadline_s=parent.deadline_s,
+             meta=message_meta,
+         )
+
+         runtime = self._runtime
+         if runtime is None:
+             raise RuntimeError("Context is not attached to a running flow")
+
+         if not first_chunk:
+             await runtime._await_trace_capacity(sid, offset=1)
+
+         await self.emit(message, to=to)
+
+         if done:
+             self._stream_seq.pop(sid, None)
+
+         return chunk
+
      def fetch_nowait(
          self, from_: Node | Endpoint | Sequence[Node | Endpoint] | None = None
      ) -> Any:
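
A minimal usage sketch of the new helper (the node signature and names here are assumptions, not part of this diff): a node streams three chunks derived from its input message, and emit_chunk assigns seq 0, 1, 2 because no explicit seq is passed.

    async def streamer(msg: Message, ctx: Context) -> None:
        parts = ["Hello", " ", "world"]
        for i, part in enumerate(parts):
            # done=True on the last chunk pops the per-stream counter
            await ctx.emit_chunk(parent=msg, text=part, done=(i == len(parts) - 1))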
@@ -175,6 +271,25 @@ class Context:
      def queue_depth_in(self) -> int:
          return sum(floe.queue.qsize() for floe in self._incoming.values())

+     def queue_depth_out(self) -> int:
+         return sum(floe.queue.qsize() for floe in self._outgoing.values())
+
+     async def call_playbook(
+         self,
+         playbook: PlaybookFactory,
+         parent_msg: Message,
+         *,
+         timeout: float | None = None,
+     ) -> Any:
+         """Launch a subflow playbook using the current runtime for propagation."""
+
+         return await call_playbook(
+             playbook,
+             parent_msg,
+             timeout=timeout,
+             runtime=self._runtime,
+         )
+

  class PenguiFlow:
      """Coordinates node execution and message routing."""
185
300
  queue_maxsize: int = DEFAULT_QUEUE_MAXSIZE,
186
301
  allow_cycles: bool = False,
187
302
  middlewares: Sequence[Middleware] | None = None,
303
+ emit_errors_to_rookery: bool = False,
188
304
  ) -> None:
189
305
  self._queue_maxsize = queue_maxsize
190
306
  self._allow_cycles = allow_cycles
@@ -196,6 +312,12 @@ class PenguiFlow:
196
312
  self._running = False
197
313
  self._registry: Any | None = None
198
314
  self._middlewares: list[Middleware] = list(middlewares or [])
315
+ self._trace_counts: dict[str, int] = {}
316
+ self._trace_events: dict[str, asyncio.Event] = {}
317
+ self._trace_invocations: dict[str, set[asyncio.Future[Any]]] = {}
318
+ self._trace_capacity_waiters: dict[str, list[asyncio.Event]] = {}
319
+ self._latest_wm_hops: dict[str, int] = {}
320
+ self._emit_errors_to_rookery = emit_errors_to_rookery
199
321
 
200
322
  self._build_graph(adjacencies)
201
323
 
@@ -219,9 +341,9 @@ class PenguiFlow:

          # create contexts for nodes and endpoints
          for node in self._nodes:
-             self._contexts[node] = Context(node)
-         self._contexts[OPEN_SEA] = Context(OPEN_SEA)
-         self._contexts[ROOKERY] = Context(ROOKERY)
+             self._contexts[node] = Context(node, self)
+         self._contexts[OPEN_SEA] = Context(OPEN_SEA, self)
+         self._contexts[ROOKERY] = Context(ROOKERY, self)

          incoming: dict[Node, set[Node | Endpoint]] = {
              node: set() for node in self._nodes
@@ -303,7 +425,49 @@ class PenguiFlow:
          while True:
              try:
                  message = await context.fetch()
-                 await self._execute_with_reliability(node, context, message)
+                 trace_id = self._get_trace_id(message)
+                 if self._deadline_expired(message):
+                     await self._emit_event(
+                         event="deadline_skip",
+                         node=node,
+                         context=context,
+                         trace_id=trace_id,
+                         attempt=0,
+                         latency_ms=None,
+                         level=logging.INFO,
+                         extra={"deadline_s": getattr(message, "deadline_s", None)},
+                     )
+                     if isinstance(message, Message):
+                         await self._handle_deadline_expired(context, message)
+                     await self._finalize_message(message)
+                     continue
+                 if trace_id is not None and self._is_trace_cancelled(trace_id):
+                     await self._emit_event(
+                         event="trace_cancel_drop",
+                         node=node,
+                         context=context,
+                         trace_id=trace_id,
+                         attempt=0,
+                         latency_ms=None,
+                         level=logging.INFO,
+                     )
+                     await self._finalize_message(message)
+                     continue
+
+                 try:
+                     await self._execute_with_reliability(node, context, message)
+                 except TraceCancelled:
+                     await self._emit_event(
+                         event="node_trace_cancelled",
+                         node=node,
+                         context=context,
+                         trace_id=trace_id,
+                         attempt=0,
+                         latency_ms=None,
+                         level=logging.INFO,
+                     )
+                 finally:
+                     await self._finalize_message(message)
              except asyncio.CancelledError:
                  await self._emit_event(
                      event="node_cancelled",
@@ -323,19 +487,65 @@ class PenguiFlow:
              task.cancel()
          await asyncio.gather(*self._tasks, return_exceptions=True)
          self._tasks.clear()
+         self._trace_counts.clear()
+         self._trace_events.clear()
+         self._trace_invocations.clear()
+         for waiters in self._trace_capacity_waiters.values():
+             for waiter in waiters:
+                 waiter.set()
+         self._trace_capacity_waiters.clear()
          self._running = False

      async def emit(self, msg: Any, to: Node | Sequence[Node] | None = None) -> None:
+         if isinstance(msg, Message):
+             payload = msg.payload
+             if isinstance(payload, WM):
+                 last = self._latest_wm_hops.get(msg.trace_id)
+                 if last is not None and payload.hops == last:
+                     return
          await self._contexts[OPEN_SEA].emit(msg, to)

      def emit_nowait(self, msg: Any, to: Node | Sequence[Node] | None = None) -> None:
+         if isinstance(msg, Message):
+             payload = msg.payload
+             if isinstance(payload, WM):
+                 last = self._latest_wm_hops.get(msg.trace_id)
+                 if last is not None and payload.hops == last:
+                     return
          self._contexts[OPEN_SEA].emit_nowait(msg, to)

+     async def emit_chunk(
+         self,
+         *,
+         parent: Message,
+         text: str,
+         stream_id: str | None = None,
+         seq: int | None = None,
+         done: bool = False,
+         meta: dict[str, Any] | None = None,
+         to: Node | Sequence[Node] | None = None,
+     ) -> StreamChunk:
+         """Emit a streaming chunk from outside a node via OpenSea context."""
+
+         return await self._contexts[OPEN_SEA].emit_chunk(
+             parent=parent,
+             text=text,
+             stream_id=stream_id,
+             seq=seq,
+             done=done,
+             meta=meta,
+             to=to,
+         )
+
      async def fetch(self, from_: Node | Sequence[Node] | None = None) -> Any:
-         return await self._contexts[ROOKERY].fetch(from_)
+         result = await self._contexts[ROOKERY].fetch(from_)
+         await self._finalize_message(result)
+         return result

      async def fetch_any(self, from_: Node | Sequence[Node] | None = None) -> Any:
-         return await self._contexts[ROOKERY].fetch_any(from_)
+         result = await self._contexts[ROOKERY].fetch_any(from_)
+         await self._finalize_message(result)
+         return result

      async def _execute_with_reliability(
          self,
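
A short sketch of the flow-level variant (request_msg is an assumed Message already emitted into the flow): chunks injected from outside any node get the same per-stream sequence numbering as Context.emit_chunk.

    chunk = await flow.emit_chunk(parent=request_msg, text="partial...")
    assert chunk.seq == 0  # first chunk for this trace's fresh stream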
@@ -347,6 +557,9 @@ class PenguiFlow:
          attempt = 0

          while True:
+             if trace_id is not None and self._is_trace_cancelled(trace_id):
+                 raise TraceCancelled(trace_id)
+
              start = time.perf_counter()
              await self._emit_event(
                  event="node_start",
@@ -359,22 +572,42 @@ class PenguiFlow:
              )

              try:
-                 invocation = node.invoke(message, context, registry=self._registry)
-                 if node.policy.timeout_s is not None:
-                     result = await asyncio.wait_for(invocation, node.policy.timeout_s)
-                 else:
-                     result = await invocation
+                 result = await self._invoke_node(
+                     node,
+                     context,
+                     message,
+                     trace_id,
+                 )

                  if result is not None:
-                     destination, prepared, targets = self._controller_postprocess(
+                     (
+                         destination,
+                         prepared,
+                         targets,
+                         deliver_rookery,
+                     ) = self._controller_postprocess(
                          node, context, message, result
                      )

+                     if deliver_rookery:
+                         rookery_msg = (
+                             prepared.model_copy(deep=True)
+                             if isinstance(prepared, Message)
+                             else prepared
+                         )
+                         await self._emit_to_rookery(
+                             rookery_msg, source=context.owner
+                         )
+
                      if destination == "skip":
                          continue
+
                      if destination == "rookery":
-                         await context.emit(prepared, to=[ROOKERY])
+                         await self._emit_to_rookery(
+                             prepared, source=context.owner
+                         )
                          continue
+
                      await context.emit(prepared, to=targets)

                  latency = (time.perf_counter() - start) * 1000
@@ -388,6 +621,8 @@ class PenguiFlow:
                      level=logging.INFO,
                  )
                  return
+             except TraceCancelled:
+                 raise
              except asyncio.CancelledError:
                  raise
              except TimeoutError as exc:
@@ -403,15 +638,28 @@ class PenguiFlow:
                      extra={"exception": repr(exc)},
                  )
                  if attempt >= node.policy.max_retries:
-                     await self._emit_event(
-                         event="node_failed",
+                     timeout_message: str | None = None
+                     if node.policy.timeout_s is not None:
+                         timeout_message = (
+                             f"Node '{node.name}' timed out after"
+                             f" {node.policy.timeout_s:.2f}s"
+                         )
+                     flow_error = self._create_flow_error(
                          node=node,
-                         context=context,
                          trace_id=trace_id,
+                         code=FlowErrorCode.NODE_TIMEOUT,
+                         exc=exc,
                          attempt=attempt,
                          latency_ms=latency,
-                         level=logging.ERROR,
-                         extra={"exception": repr(exc)},
+                         message=timeout_message,
+                         metadata={"timeout_s": node.policy.timeout_s},
+                     )
+                     await self._handle_flow_error(
+                         node=node,
+                         context=context,
+                         flow_error=flow_error,
+                         latency=latency,
+                         attempt=attempt,
                      )
                      return
                  attempt += 1
@@ -441,15 +689,24 @@ class PenguiFlow:
                      extra={"exception": repr(exc)},
                  )
                  if attempt >= node.policy.max_retries:
-                     await self._emit_event(
-                         event="node_failed",
+                     flow_error = self._create_flow_error(
                          node=node,
-                         context=context,
                          trace_id=trace_id,
+                         code=FlowErrorCode.NODE_EXCEPTION,
+                         exc=exc,
                          attempt=attempt,
                          latency_ms=latency,
-                         level=logging.ERROR,
-                         extra={"exception": repr(exc)},
+                         message=(
+                             f"Node '{node.name}' raised {type(exc).__name__}: {exc}"
+                         ),
+                         metadata={"exception_repr": repr(exc)},
+                     )
+                     await self._handle_flow_error(
+                         node=node,
+                         context=context,
+                         flow_error=flow_error,
+                         latency=latency,
+                         attempt=attempt,
                      )
                      return
                  attempt += 1
@@ -473,13 +730,300 @@ class PenguiFlow:
          delay = min(delay, policy.max_backoff)
          return delay

+     def _create_flow_error(
+         self,
+         *,
+         node: Node,
+         trace_id: str | None,
+         code: FlowErrorCode,
+         exc: BaseException,
+         attempt: int,
+         latency_ms: float | None,
+         message: str | None = None,
+         metadata: Mapping[str, Any] | None = None,
+     ) -> FlowError:
+         node_name = node.name
+         assert node_name is not None
+         meta: dict[str, Any] = {"attempt": attempt}
+         if latency_ms is not None:
+             meta["latency_ms"] = latency_ms
+         if metadata:
+             meta.update(metadata)
+         return FlowError.from_exception(
+             trace_id=trace_id,
+             node_name=node_name,
+             node_id=node.node_id,
+             exc=exc,
+             code=code,
+             message=message,
+             metadata=meta,
+         )
+
+     async def _handle_flow_error(
+         self,
+         *,
+         node: Node,
+         context: Context,
+         flow_error: FlowError,
+         latency: float | None,
+         attempt: int,
+     ) -> None:
+         original = flow_error.unwrap()
+         exception_repr = repr(original) if original is not None else flow_error.message
+         extra = {
+             "exception": exception_repr,
+             "flow_error": flow_error.to_payload(),
+         }
+         await self._emit_event(
+             event="node_failed",
+             node=node,
+             context=context,
+             trace_id=flow_error.trace_id,
+             attempt=attempt,
+             latency_ms=latency,
+             level=logging.ERROR,
+             extra=extra,
+         )
+         if self._emit_errors_to_rookery and flow_error.trace_id is not None:
+             await self._emit_to_rookery(flow_error, source=context.owner)
+
+     async def _invoke_node(
+         self,
+         node: Node,
+         context: Context,
+         message: Any,
+         trace_id: str | None,
+     ) -> Any:
+         invocation = node.invoke(message, context, registry=self._registry)
+         timeout = node.policy.timeout_s
+
+         if trace_id is None:
+             if timeout is None:
+                 return await invocation
+             return await asyncio.wait_for(invocation, timeout)
+
+         return await self._await_invocation(node, invocation, trace_id, timeout)
+
+     def _register_invocation_task(
+         self, trace_id: str, task: asyncio.Future[Any]
+     ) -> None:
+         tasks = self._trace_invocations.setdefault(trace_id, set())
+         tasks.add(task)
+
+         def _cleanup(finished: asyncio.Future[Any]) -> None:
+             remaining = self._trace_invocations.get(trace_id)
+             if remaining is None:
+                 return
+             remaining.discard(finished)
+             if not remaining:
+                 self._trace_invocations.pop(trace_id, None)
+
+         task.add_done_callback(_cleanup)
+
+     async def _await_invocation(
+         self,
+         node: Node,
+         invocation: Awaitable[Any],
+         trace_id: str,
+         timeout: float | None,
+     ) -> Any:
+         invocation_task = asyncio.ensure_future(invocation)
+         self._register_invocation_task(trace_id, invocation_task)
+
+         cancel_event = self._trace_events.get(trace_id)
+         cancel_waiter: asyncio.Future[Any] | None = None
+         if cancel_event is not None:
+             cancel_waiter = asyncio.ensure_future(cancel_event.wait())
+
+         timeout_task: asyncio.Future[Any] | None = None
+         if timeout is not None:
+             timeout_task = asyncio.ensure_future(asyncio.sleep(timeout))
+
+         wait_tasks: set[asyncio.Future[Any]] = {invocation_task}
+         if cancel_waiter is not None:
+             wait_tasks.add(cancel_waiter)
+         if timeout_task is not None:
+             wait_tasks.add(timeout_task)
+
+         pending: set[asyncio.Future[Any]] = set()
+         try:
+             done, pending = await asyncio.wait(
+                 wait_tasks, return_when=asyncio.FIRST_COMPLETED
+             )
+
+             if invocation_task in done:
+                 if invocation_task.cancelled():
+                     raise TraceCancelled(trace_id)
+                 return invocation_task.result()
+
+             if cancel_waiter is not None and cancel_waiter in done:
+                 invocation_task.cancel()
+                 await asyncio.gather(invocation_task, return_exceptions=True)
+                 raise TraceCancelled(trace_id)
+
+             if timeout_task is not None and timeout_task in done:
+                 invocation_task.cancel()
+                 await asyncio.gather(invocation_task, return_exceptions=True)
+                 raise TimeoutError
+
+             raise RuntimeError("node invocation wait exited without result")
+         except asyncio.CancelledError:
+             invocation_task.cancel()
+             await asyncio.gather(invocation_task, return_exceptions=True)
+             if cancel_waiter is not None:
+                 cancel_waiter.cancel()
+             if timeout_task is not None:
+                 timeout_task.cancel()
+             await asyncio.gather(
+                 *(task for task in (cancel_waiter, timeout_task) if task is not None),
+                 return_exceptions=True,
+             )
+             raise
+         finally:
+             for task in pending:
+                 task.cancel()
+             watchers = [
+                 task for task in (cancel_waiter, timeout_task) if task is not None
+             ]
+             for watcher in watchers:
+                 watcher.cancel()
+             if watchers:
+                 await asyncio.gather(*watchers, return_exceptions=True)
+
+     def _get_trace_id(self, message: Any) -> str | None:
+         return getattr(message, "trace_id", None)
+
+     def _is_trace_cancelled(self, trace_id: str) -> bool:
+         event = self._trace_events.get(trace_id)
+         return event.is_set() if event is not None else False
+
+     def _on_message_enqueued(self, message: Any) -> None:
+         trace_id = self._get_trace_id(message)
+         if trace_id is None:
+             return
+         self._trace_counts[trace_id] = self._trace_counts.get(trace_id, 0) + 1
+         self._trace_events.setdefault(trace_id, asyncio.Event())
+
+     async def _finalize_message(self, message: Any) -> None:
+         trace_id = self._get_trace_id(message)
+         if trace_id is None:
+             return
+
+         remaining = self._trace_counts.get(trace_id)
+         if remaining is None:
+             return
+
+         remaining -= 1
+         if remaining <= 0:
+             self._trace_counts.pop(trace_id, None)
+             event = self._trace_events.pop(trace_id, None)
+             if event is not None and event.is_set():
+                 await self._emit_event(
+                     event="trace_cancel_finish",
+                     node=ROOKERY,
+                     context=self._contexts[ROOKERY],
+                     trace_id=trace_id,
+                     attempt=0,
+                     latency_ms=None,
+                     level=logging.INFO,
+                 )
+             self._notify_trace_capacity(trace_id)
+             self._latest_wm_hops.pop(trace_id, None)
+         else:
+             self._trace_counts[trace_id] = remaining
+             if self._queue_maxsize <= 0 or remaining <= self._queue_maxsize:
+                 self._notify_trace_capacity(trace_id)
+
+     async def _drop_trace_from_floe(self, floe: Floe, trace_id: str) -> None:
+         queue = floe.queue
+         retained: list[Any] = []
+
+         while True:
+             try:
+                 item = queue.get_nowait()
+             except asyncio.QueueEmpty:
+                 break
+
+             if self._get_trace_id(item) == trace_id:
+                 await self._finalize_message(item)
+                 continue
+
+             retained.append(item)
+
+         for item in retained:
+             queue.put_nowait(item)
+
+     async def cancel(self, trace_id: str) -> bool:
+         if not self._running:
+             raise RuntimeError("PenguiFlow is not running")
+
+         active = trace_id in self._trace_counts or trace_id in self._trace_invocations
+         if not active:
+             return False
+
+         event = self._trace_events.setdefault(trace_id, asyncio.Event())
+         if not event.is_set():
+             event.set()
+             await self._emit_event(
+                 event="trace_cancel_start",
+                 node=OPEN_SEA,
+                 context=self._contexts[OPEN_SEA],
+                 trace_id=trace_id,
+                 attempt=0,
+                 latency_ms=None,
+                 level=logging.INFO,
+                 extra={"pending": self._trace_counts.get(trace_id, 0)},
+             )
+         else:
+             event.set()
+
+         for floe in list(self._floes):
+             await self._drop_trace_from_floe(floe, trace_id)
+
+         tasks = list(self._trace_invocations.get(trace_id, set()))
+         for task in tasks:
+             task.cancel()
+
+         return True
+
+     async def _await_trace_capacity(self, trace_id: str, *, offset: int = 0) -> None:
+         if self._queue_maxsize <= 0:
+             return
+
+         while True:
+             pending = self._trace_counts.get(trace_id, 0)
+             effective = pending - offset if pending > offset else 0
+             if effective < self._queue_maxsize:
+                 return
+             waiter = asyncio.Event()
+             waiters = self._trace_capacity_waiters.setdefault(trace_id, [])
+             waiters.append(waiter)
+             try:
+                 await waiter.wait()
+             finally:
+                 remaining_waiters = self._trace_capacity_waiters.get(trace_id)
+                 if remaining_waiters is not None:
+                     try:
+                         remaining_waiters.remove(waiter)
+                     except ValueError:
+                         pass
+                     if not remaining_waiters:
+                         self._trace_capacity_waiters.pop(trace_id, None)
+
+     def _notify_trace_capacity(self, trace_id: str) -> None:
+         waiters = self._trace_capacity_waiters.pop(trace_id, None)
+         if not waiters:
+             return
+         for waiter in waiters:
+             waiter.set()
+
      def _controller_postprocess(
          self,
          node: Node,
          context: Context,
          incoming: Any,
          result: Any,
-     ) -> tuple[str, Any, list[Node] | None]:
+     ) -> tuple[str, Any, list[Node] | None, bool]:
          if isinstance(result, Message):
              payload = result.payload
              if isinstance(payload, WM):
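
A hedged usage sketch of the new cancellation API (request_msg is an assumed in-flight Message): queued work for the trace is dropped, in-flight node invocations are cancelled, and other traces keep running.

    was_active = await flow.cancel(request_msg.trace_id)
    if not was_active:
        print("trace already drained or unknown")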
@@ -487,27 +1031,101 @@ class PenguiFlow:
              if result.deadline_s is not None and now > result.deadline_s:
                  final = FinalAnswer(text=DEADLINE_EXCEEDED_TEXT)
                  final_msg = result.model_copy(update={"payload": final})
-                 return "rookery", final_msg, None
+                 return "rookery", final_msg, None, False

-             if payload.hops + 1 >= payload.budget_hops:
+             if (
+                 payload.budget_tokens is not None
+                 and payload.tokens_used >= payload.budget_tokens
+             ):
+                 final = FinalAnswer(text=TOKEN_BUDGET_EXCEEDED_TEXT)
+                 final_msg = result.model_copy(update={"payload": final})
+                 return "rookery", final_msg, None, False
+
+             incoming_hops: int | None = None
+             if (
+                 isinstance(incoming, Message)
+                 and isinstance(incoming.payload, WM)
+             ):
+                 incoming_hops = incoming.payload.hops
+
+             current_hops = payload.hops
+             if incoming_hops is not None and current_hops <= incoming_hops:
+                 next_hops = incoming_hops + 1
+             else:
+                 next_hops = current_hops
+
+             if (
+                 payload.budget_hops is not None
+                 and next_hops >= payload.budget_hops
+             ):
                  final = FinalAnswer(text=BUDGET_EXCEEDED_TEXT)
                  final_msg = result.model_copy(update={"payload": final})
-                 return "rookery", final_msg, None
+                 return "rookery", final_msg, None, False

-             updated_payload = payload.model_copy(update={"hops": payload.hops + 1})
-             updated_msg = result.model_copy(update={"payload": updated_payload})
-             return "context", updated_msg, [node]
+             if next_hops != current_hops:
+                 updated_payload = payload.model_copy(update={"hops": next_hops})
+                 prepared = result.model_copy(update={"payload": updated_payload})
+             else:
+                 prepared = result
+
+             stream_updates = (
+                 payload.budget_hops is None
+                 and payload.budget_tokens is None
+             )
+             return "context", prepared, [node], stream_updates

          if isinstance(payload, FinalAnswer):
-             return "rookery", result, None
+             return "rookery", result, None, False
+
+         return "context", result, None, False
+
+     def _deadline_expired(self, message: Any) -> bool:
+         if isinstance(message, Message) and message.deadline_s is not None:
+             return time.time() > message.deadline_s
+         return False
+
+     async def _handle_deadline_expired(
+         self, context: Context, message: Message
+     ) -> None:
+         payload = message.payload
+         if not isinstance(payload, FinalAnswer):
+             payload = FinalAnswer(text=DEADLINE_EXCEEDED_TEXT)
+         final_msg = message.model_copy(update={"payload": payload})
+         await self._emit_to_rookery(final_msg, source=context.owner)
+
+     async def _emit_to_rookery(
+         self, message: Any, *, source: Node | Endpoint | None = None
+     ) -> None:
+         """Route ``message`` to the Rookery sink regardless of graph edges."""
+
+         rookery_context = self._contexts[ROOKERY]
+         incoming = rookery_context._incoming
+
+         floe: Floe | None = None
+         if source is not None:
+             floe = incoming.get(source)
+         if floe is None and incoming:
+             floe = next(iter(incoming.values()))
+
+         self._on_message_enqueued(message)
+
+         if floe is not None:
+             await floe.queue.put(message)
+         else:
+             buffer = rookery_context._buffer
+             buffer.append(message)

-         return "context", result, None
+         if isinstance(message, Message):
+             payload = message.payload
+             if isinstance(payload, WM):
+                 trace_id = message.trace_id
+                 self._latest_wm_hops[trace_id] = payload.hops

      async def _emit_event(
          self,
          *,
          event: str,
-         node: Node,
+         node: Node | Endpoint,
          context: Context,
          trace_id: str | None,
          attempt: int,
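
A worked example of the revised hop accounting (only WM fields that appear in this diff are assumed): suppose a controller loops a WM payload with hops=0 and budget_hops=2, returning the message unchanged each pass.

    # pass 1: incoming hops=0, result hops=0 -> next_hops = 0 + 1 = 1
    #         1 < 2, so the message loops back to the node with hops=1
    # pass 2: incoming hops=1, result hops=1 -> next_hops = 1 + 1 = 2
    #         2 >= 2, so FinalAnswer("Hop budget exhausted") goes to the Rookery

Deriving next_hops from the incoming payload means a controller that advances hops itself is not double-counted, while one that never touches hops can no longer loop past its budget.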
@@ -515,31 +1133,50 @@ class PenguiFlow:
          level: int,
          extra: dict[str, Any] | None = None,
      ) -> None:
-         payload: dict[str, Any] = {
-             "ts": time.time(),
-             "event": event,
-             "node_name": node.name,
-             "node_id": node.node_id,
-             "trace_id": trace_id,
-             "latency_ms": latency_ms,
-             "q_depth_in": context.queue_depth_in(),
-             "attempt": attempt,
-         }
-         if extra:
-             payload.update(extra)
-
-         logger.log(level, event, extra=payload)
+         node_name = getattr(node, "name", None)
+         node_id = getattr(node, "node_id", node_name)
+         queue_depth_in = context.queue_depth_in()
+         queue_depth_out = context.queue_depth_out()
+         outgoing = context.outgoing_count()
+
+         trace_pending: int | None = None
+         trace_inflight = 0
+         trace_cancelled = False
+         if trace_id is not None:
+             trace_pending = self._trace_counts.get(trace_id, 0)
+             trace_inflight = len(self._trace_invocations.get(trace_id, set()))
+             trace_cancelled = self._is_trace_cancelled(trace_id)
+
+         event_obj = FlowEvent(
+             event_type=event,
+             ts=time.time(),
+             node_name=node_name,
+             node_id=node_id,
+             trace_id=trace_id,
+             attempt=attempt,
+             latency_ms=latency_ms,
+             queue_depth_in=queue_depth_in,
+             queue_depth_out=queue_depth_out,
+             outgoing_edges=outgoing,
+             queue_maxsize=self._queue_maxsize,
+             trace_pending=trace_pending,
+             trace_inflight=trace_inflight,
+             trace_cancelled=trace_cancelled,
+             extra=extra or {},
+         )
+
+         logger.log(level, event, extra=event_obj.to_payload())

          for middleware in list(self._middlewares):
              try:
-                 await middleware(event, payload)
+                 await middleware(event_obj)
              except Exception as exc:  # pragma: no cover - defensive
                  logger.exception(
                      "middleware_error",
                      extra={
                          "event": "middleware_error",
-                         "node_name": node.name,
-                         "node_id": node.node_id,
+                         "node_name": node_name,
+                         "node_id": node_id,
                          "exception": exc,
                      },
                  )
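
A sketch of the new middleware contract, which is incompatible with 1.x middlewares: callables now receive a single FlowEvent instead of an (event, payload) pair. The attributes used below are the ones constructed above; treating them as attribute access assumes FlowEvent is a record-style type.

    async def latency_logger(event: FlowEvent) -> None:
        if event.latency_ms is not None:
            print(f"{event.event_type} {event.node_name}: {event.latency_ms:.1f} ms")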
@@ -572,13 +1209,38 @@ async def call_playbook(
      playbook: PlaybookFactory,
      parent_msg: Message,
      timeout: float | None = None,
+     *,
+     runtime: PenguiFlow | None = None,
  ) -> Any:
      """Execute a subflow playbook and return the first Rookery payload."""

      flow, registry = playbook()
      flow.run(registry=registry)

+     trace_id = getattr(parent_msg, "trace_id", None)
+     cancel_watch: asyncio.Task[None] | None = None
+     pre_cancelled = False
+
+     if runtime is not None and trace_id is not None:
+         parent_event = runtime._trace_events.setdefault(trace_id, asyncio.Event())
+
+         if parent_event.is_set():
+             pre_cancelled = True
+         else:
+
+             async def _mirror_cancel() -> None:
+                 try:
+                     await parent_event.wait()
+                 except asyncio.CancelledError:
+                     return
+                 with suppress(Exception):
+                     await flow.cancel(trace_id)
+
+             cancel_watch = asyncio.create_task(_mirror_cancel())
+
      try:
+         if pre_cancelled:
+             raise TraceCancelled(trace_id)
          await flow.emit(parent_msg)
          fetch_coro = flow.fetch()
          if timeout is not None:
@@ -588,7 +1250,20 @@ async def call_playbook(
          if isinstance(result_msg, Message):
              return result_msg.payload
          return result_msg
+     except TraceCancelled:
+         if trace_id is not None and not pre_cancelled:
+             with suppress(Exception):
+                 await flow.cancel(trace_id)
+         raise
+     except asyncio.CancelledError:
+         if trace_id is not None:
+             with suppress(Exception):
+                 await flow.cancel(trace_id)
+         raise
      finally:
+         if cancel_watch is not None:
+             cancel_watch.cancel()
+             await asyncio.gather(cancel_watch, return_exceptions=True)
          await asyncio.shield(flow.stop())

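
A hedged sketch of cancellation propagating through subflows (parent_flow and make_subflow are assumed names): passing runtime= lets a cancel of the parent trace mirror into the playbook's flow, and call_playbook then raises TraceCancelled instead of returning a payload.

    try:
        payload = await call_playbook(make_subflow, parent_msg, runtime=parent_flow)
    except TraceCancelled:
        pass  # the parent trace was cancelled while the subflow was running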