penguiflow 1.0.3__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of penguiflow might be problematic. Click here for more details.
- penguiflow/__init__.py +26 -3
- penguiflow/core.py +729 -54
- penguiflow/errors.py +113 -0
- penguiflow/metrics.py +105 -0
- penguiflow/middlewares.py +6 -7
- penguiflow/patterns.py +47 -5
- penguiflow/policies.py +149 -0
- penguiflow/streaming.py +142 -0
- penguiflow/testkit.py +269 -0
- penguiflow/types.py +15 -1
- penguiflow/viz.py +133 -24
- {penguiflow-1.0.3.dist-info → penguiflow-2.0.0.dist-info}/METADATA +161 -20
- penguiflow-2.0.0.dist-info/RECORD +18 -0
- penguiflow-1.0.3.dist-info/RECORD +0 -13
- {penguiflow-1.0.3.dist-info → penguiflow-2.0.0.dist-info}/WHEEL +0 -0
- {penguiflow-1.0.3.dist-info → penguiflow-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {penguiflow-1.0.3.dist-info → penguiflow-2.0.0.dist-info}/top_level.txt +0 -0
penguiflow/core.py
CHANGED
|
@@ -10,19 +10,23 @@ import asyncio
|
|
|
10
10
|
import logging
|
|
11
11
|
import time
|
|
12
12
|
from collections import deque
|
|
13
|
-
from collections.abc import Callable, Sequence
|
|
13
|
+
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
|
14
|
+
from contextlib import suppress
|
|
14
15
|
from dataclasses import dataclass
|
|
15
16
|
from typing import Any
|
|
16
17
|
|
|
18
|
+
from .errors import FlowError, FlowErrorCode
|
|
19
|
+
from .metrics import FlowEvent
|
|
17
20
|
from .middlewares import Middleware
|
|
18
21
|
from .node import Node, NodePolicy
|
|
19
22
|
from .registry import ModelRegistry
|
|
20
|
-
from .types import WM, FinalAnswer, Message
|
|
23
|
+
from .types import WM, FinalAnswer, Message, StreamChunk
|
|
21
24
|
|
|
22
25
|
logger = logging.getLogger("penguiflow.core")
|
|
23
26
|
|
|
24
27
|
BUDGET_EXCEEDED_TEXT = "Hop budget exhausted"
|
|
25
28
|
DEADLINE_EXCEEDED_TEXT = "Deadline exceeded"
|
|
29
|
+
TOKEN_BUDGET_EXCEEDED_TEXT = "Token budget exhausted"
|
|
26
30
|
|
|
27
31
|
DEFAULT_QUEUE_MAXSIZE = 64
|
|
28
32
|
|
|
@@ -59,21 +63,46 @@ class Floe:
|
|
|
59
63
|
self.queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=maxsize)
|
|
60
64
|
|
|
61
65
|
|
|
66
|
+
class TraceCancelled(Exception):
|
|
67
|
+
"""Raised when work for a specific trace_id is cancelled."""
|
|
68
|
+
|
|
69
|
+
def __init__(self, trace_id: str | None) -> None:
|
|
70
|
+
super().__init__(f"trace cancelled: {trace_id}")
|
|
71
|
+
self.trace_id = trace_id
|
|
72
|
+
|
|
73
|
+
|
|
62
74
|
class Context:
|
|
63
75
|
"""Provides fetch/emit helpers for a node within a flow."""
|
|
64
76
|
|
|
65
|
-
__slots__ = (
|
|
77
|
+
__slots__ = (
|
|
78
|
+
"_owner",
|
|
79
|
+
"_incoming",
|
|
80
|
+
"_outgoing",
|
|
81
|
+
"_buffer",
|
|
82
|
+
"_stream_seq",
|
|
83
|
+
"_runtime",
|
|
84
|
+
)
|
|
66
85
|
|
|
67
|
-
def __init__(
|
|
86
|
+
def __init__(
|
|
87
|
+
self, owner: Node | Endpoint, runtime: PenguiFlow | None = None
|
|
88
|
+
) -> None:
|
|
68
89
|
self._owner = owner
|
|
69
90
|
self._incoming: dict[Node | Endpoint, Floe] = {}
|
|
70
91
|
self._outgoing: dict[Node | Endpoint, Floe] = {}
|
|
71
92
|
self._buffer: deque[Any] = deque()
|
|
93
|
+
self._stream_seq: dict[str, int] = {}
|
|
94
|
+
self._runtime = runtime
|
|
72
95
|
|
|
73
96
|
@property
|
|
74
97
|
def owner(self) -> Node | Endpoint:
|
|
75
98
|
return self._owner
|
|
76
99
|
|
|
100
|
+
@property
|
|
101
|
+
def runtime(self) -> PenguiFlow | None:
|
|
102
|
+
"""Return the runtime this context is attached to, if any."""
|
|
103
|
+
|
|
104
|
+
return self._runtime
|
|
105
|
+
|
|
77
106
|
def add_incoming_floe(self, floe: Floe) -> None:
|
|
78
107
|
if floe.source is None:
|
|
79
108
|
return
|
|
@@ -111,15 +140,82 @@ class Context:
|
|
|
111
140
|
async def emit(
|
|
112
141
|
self, msg: Any, to: Node | Endpoint | Sequence[Node | Endpoint] | None = None
|
|
113
142
|
) -> None:
|
|
143
|
+
if self._runtime is None:
|
|
144
|
+
raise RuntimeError("Context is not attached to a running flow")
|
|
114
145
|
for floe in self._resolve_targets(to, self._outgoing):
|
|
146
|
+
self._runtime._on_message_enqueued(msg)
|
|
115
147
|
await floe.queue.put(msg)
|
|
116
148
|
|
|
117
149
|
def emit_nowait(
|
|
118
150
|
self, msg: Any, to: Node | Endpoint | Sequence[Node | Endpoint] | None = None
|
|
119
151
|
) -> None:
|
|
152
|
+
if self._runtime is None:
|
|
153
|
+
raise RuntimeError("Context is not attached to a running flow")
|
|
120
154
|
for floe in self._resolve_targets(to, self._outgoing):
|
|
155
|
+
self._runtime._on_message_enqueued(msg)
|
|
121
156
|
floe.queue.put_nowait(msg)
|
|
122
157
|
|
|
158
|
+
async def emit_chunk(
|
|
159
|
+
self,
|
|
160
|
+
*,
|
|
161
|
+
parent: Message,
|
|
162
|
+
text: str,
|
|
163
|
+
stream_id: str | None = None,
|
|
164
|
+
seq: int | None = None,
|
|
165
|
+
done: bool = False,
|
|
166
|
+
meta: dict[str, Any] | None = None,
|
|
167
|
+
to: Node | Endpoint | Sequence[Node | Endpoint] | None = None,
|
|
168
|
+
) -> StreamChunk:
|
|
169
|
+
"""Emit a streaming chunk that inherits routing metadata from ``parent``.
|
|
170
|
+
|
|
171
|
+
The helper manages monotonically increasing sequence numbers per
|
|
172
|
+
``stream_id`` (defaulting to the parent's trace id) unless an explicit
|
|
173
|
+
``seq`` is provided. It returns the emitted ``StreamChunk`` for
|
|
174
|
+
introspection in tests or downstream logic.
|
|
175
|
+
"""
|
|
176
|
+
|
|
177
|
+
sid = stream_id or parent.trace_id
|
|
178
|
+
first_chunk = sid not in self._stream_seq
|
|
179
|
+
if seq is None:
|
|
180
|
+
next_seq = self._stream_seq.get(sid, -1) + 1
|
|
181
|
+
else:
|
|
182
|
+
next_seq = seq
|
|
183
|
+
self._stream_seq[sid] = next_seq
|
|
184
|
+
|
|
185
|
+
meta_dict = dict(meta) if meta else {}
|
|
186
|
+
|
|
187
|
+
chunk = StreamChunk(
|
|
188
|
+
stream_id=sid,
|
|
189
|
+
seq=next_seq,
|
|
190
|
+
text=text,
|
|
191
|
+
done=done,
|
|
192
|
+
meta=meta_dict,
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
message_meta = dict(parent.meta)
|
|
196
|
+
|
|
197
|
+
message = Message(
|
|
198
|
+
payload=chunk,
|
|
199
|
+
headers=parent.headers,
|
|
200
|
+
trace_id=parent.trace_id,
|
|
201
|
+
deadline_s=parent.deadline_s,
|
|
202
|
+
meta=message_meta,
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
runtime = self._runtime
|
|
206
|
+
if runtime is None:
|
|
207
|
+
raise RuntimeError("Context is not attached to a running flow")
|
|
208
|
+
|
|
209
|
+
if not first_chunk:
|
|
210
|
+
await runtime._await_trace_capacity(sid, offset=1)
|
|
211
|
+
|
|
212
|
+
await self.emit(message, to=to)
|
|
213
|
+
|
|
214
|
+
if done:
|
|
215
|
+
self._stream_seq.pop(sid, None)
|
|
216
|
+
|
|
217
|
+
return chunk
|
|
218
|
+
|
|
123
219
|
def fetch_nowait(
|
|
124
220
|
self, from_: Node | Endpoint | Sequence[Node | Endpoint] | None = None
|
|
125
221
|
) -> Any:
|
|
@@ -175,6 +271,25 @@ class Context:
|
|
|
175
271
|
def queue_depth_in(self) -> int:
|
|
176
272
|
return sum(floe.queue.qsize() for floe in self._incoming.values())
|
|
177
273
|
|
|
274
|
+
def queue_depth_out(self) -> int:
|
|
275
|
+
return sum(floe.queue.qsize() for floe in self._outgoing.values())
|
|
276
|
+
|
|
277
|
+
async def call_playbook(
|
|
278
|
+
self,
|
|
279
|
+
playbook: PlaybookFactory,
|
|
280
|
+
parent_msg: Message,
|
|
281
|
+
*,
|
|
282
|
+
timeout: float | None = None,
|
|
283
|
+
) -> Any:
|
|
284
|
+
"""Launch a subflow playbook using the current runtime for propagation."""
|
|
285
|
+
|
|
286
|
+
return await call_playbook(
|
|
287
|
+
playbook,
|
|
288
|
+
parent_msg,
|
|
289
|
+
timeout=timeout,
|
|
290
|
+
runtime=self._runtime,
|
|
291
|
+
)
|
|
292
|
+
|
|
178
293
|
|
|
179
294
|
class PenguiFlow:
|
|
180
295
|
"""Coordinates node execution and message routing."""
|
|
@@ -185,6 +300,7 @@ class PenguiFlow:
|
|
|
185
300
|
queue_maxsize: int = DEFAULT_QUEUE_MAXSIZE,
|
|
186
301
|
allow_cycles: bool = False,
|
|
187
302
|
middlewares: Sequence[Middleware] | None = None,
|
|
303
|
+
emit_errors_to_rookery: bool = False,
|
|
188
304
|
) -> None:
|
|
189
305
|
self._queue_maxsize = queue_maxsize
|
|
190
306
|
self._allow_cycles = allow_cycles
|
|
@@ -196,6 +312,12 @@ class PenguiFlow:
|
|
|
196
312
|
self._running = False
|
|
197
313
|
self._registry: Any | None = None
|
|
198
314
|
self._middlewares: list[Middleware] = list(middlewares or [])
|
|
315
|
+
self._trace_counts: dict[str, int] = {}
|
|
316
|
+
self._trace_events: dict[str, asyncio.Event] = {}
|
|
317
|
+
self._trace_invocations: dict[str, set[asyncio.Future[Any]]] = {}
|
|
318
|
+
self._trace_capacity_waiters: dict[str, list[asyncio.Event]] = {}
|
|
319
|
+
self._latest_wm_hops: dict[str, int] = {}
|
|
320
|
+
self._emit_errors_to_rookery = emit_errors_to_rookery
|
|
199
321
|
|
|
200
322
|
self._build_graph(adjacencies)
|
|
201
323
|
|
|
@@ -219,9 +341,9 @@ class PenguiFlow:
|
|
|
219
341
|
|
|
220
342
|
# create contexts for nodes and endpoints
|
|
221
343
|
for node in self._nodes:
|
|
222
|
-
self._contexts[node] = Context(node)
|
|
223
|
-
self._contexts[OPEN_SEA] = Context(OPEN_SEA)
|
|
224
|
-
self._contexts[ROOKERY] = Context(ROOKERY)
|
|
344
|
+
self._contexts[node] = Context(node, self)
|
|
345
|
+
self._contexts[OPEN_SEA] = Context(OPEN_SEA, self)
|
|
346
|
+
self._contexts[ROOKERY] = Context(ROOKERY, self)
|
|
225
347
|
|
|
226
348
|
incoming: dict[Node, set[Node | Endpoint]] = {
|
|
227
349
|
node: set() for node in self._nodes
|
|
@@ -303,7 +425,49 @@ class PenguiFlow:
|
|
|
303
425
|
while True:
|
|
304
426
|
try:
|
|
305
427
|
message = await context.fetch()
|
|
306
|
-
|
|
428
|
+
trace_id = self._get_trace_id(message)
|
|
429
|
+
if self._deadline_expired(message):
|
|
430
|
+
await self._emit_event(
|
|
431
|
+
event="deadline_skip",
|
|
432
|
+
node=node,
|
|
433
|
+
context=context,
|
|
434
|
+
trace_id=trace_id,
|
|
435
|
+
attempt=0,
|
|
436
|
+
latency_ms=None,
|
|
437
|
+
level=logging.INFO,
|
|
438
|
+
extra={"deadline_s": getattr(message, "deadline_s", None)},
|
|
439
|
+
)
|
|
440
|
+
if isinstance(message, Message):
|
|
441
|
+
await self._handle_deadline_expired(context, message)
|
|
442
|
+
await self._finalize_message(message)
|
|
443
|
+
continue
|
|
444
|
+
if trace_id is not None and self._is_trace_cancelled(trace_id):
|
|
445
|
+
await self._emit_event(
|
|
446
|
+
event="trace_cancel_drop",
|
|
447
|
+
node=node,
|
|
448
|
+
context=context,
|
|
449
|
+
trace_id=trace_id,
|
|
450
|
+
attempt=0,
|
|
451
|
+
latency_ms=None,
|
|
452
|
+
level=logging.INFO,
|
|
453
|
+
)
|
|
454
|
+
await self._finalize_message(message)
|
|
455
|
+
continue
|
|
456
|
+
|
|
457
|
+
try:
|
|
458
|
+
await self._execute_with_reliability(node, context, message)
|
|
459
|
+
except TraceCancelled:
|
|
460
|
+
await self._emit_event(
|
|
461
|
+
event="node_trace_cancelled",
|
|
462
|
+
node=node,
|
|
463
|
+
context=context,
|
|
464
|
+
trace_id=trace_id,
|
|
465
|
+
attempt=0,
|
|
466
|
+
latency_ms=None,
|
|
467
|
+
level=logging.INFO,
|
|
468
|
+
)
|
|
469
|
+
finally:
|
|
470
|
+
await self._finalize_message(message)
|
|
307
471
|
except asyncio.CancelledError:
|
|
308
472
|
await self._emit_event(
|
|
309
473
|
event="node_cancelled",
|
|
@@ -323,19 +487,65 @@ class PenguiFlow:
|
|
|
323
487
|
task.cancel()
|
|
324
488
|
await asyncio.gather(*self._tasks, return_exceptions=True)
|
|
325
489
|
self._tasks.clear()
|
|
490
|
+
self._trace_counts.clear()
|
|
491
|
+
self._trace_events.clear()
|
|
492
|
+
self._trace_invocations.clear()
|
|
493
|
+
for waiters in self._trace_capacity_waiters.values():
|
|
494
|
+
for waiter in waiters:
|
|
495
|
+
waiter.set()
|
|
496
|
+
self._trace_capacity_waiters.clear()
|
|
326
497
|
self._running = False
|
|
327
498
|
|
|
328
499
|
async def emit(self, msg: Any, to: Node | Sequence[Node] | None = None) -> None:
|
|
500
|
+
if isinstance(msg, Message):
|
|
501
|
+
payload = msg.payload
|
|
502
|
+
if isinstance(payload, WM):
|
|
503
|
+
last = self._latest_wm_hops.get(msg.trace_id)
|
|
504
|
+
if last is not None and payload.hops == last:
|
|
505
|
+
return
|
|
329
506
|
await self._contexts[OPEN_SEA].emit(msg, to)
|
|
330
507
|
|
|
331
508
|
def emit_nowait(self, msg: Any, to: Node | Sequence[Node] | None = None) -> None:
|
|
509
|
+
if isinstance(msg, Message):
|
|
510
|
+
payload = msg.payload
|
|
511
|
+
if isinstance(payload, WM):
|
|
512
|
+
last = self._latest_wm_hops.get(msg.trace_id)
|
|
513
|
+
if last is not None and payload.hops == last:
|
|
514
|
+
return
|
|
332
515
|
self._contexts[OPEN_SEA].emit_nowait(msg, to)
|
|
333
516
|
|
|
517
|
+
async def emit_chunk(
|
|
518
|
+
self,
|
|
519
|
+
*,
|
|
520
|
+
parent: Message,
|
|
521
|
+
text: str,
|
|
522
|
+
stream_id: str | None = None,
|
|
523
|
+
seq: int | None = None,
|
|
524
|
+
done: bool = False,
|
|
525
|
+
meta: dict[str, Any] | None = None,
|
|
526
|
+
to: Node | Sequence[Node] | None = None,
|
|
527
|
+
) -> StreamChunk:
|
|
528
|
+
"""Emit a streaming chunk from outside a node via OpenSea context."""
|
|
529
|
+
|
|
530
|
+
return await self._contexts[OPEN_SEA].emit_chunk(
|
|
531
|
+
parent=parent,
|
|
532
|
+
text=text,
|
|
533
|
+
stream_id=stream_id,
|
|
534
|
+
seq=seq,
|
|
535
|
+
done=done,
|
|
536
|
+
meta=meta,
|
|
537
|
+
to=to,
|
|
538
|
+
)
|
|
539
|
+
|
|
334
540
|
async def fetch(self, from_: Node | Sequence[Node] | None = None) -> Any:
|
|
335
|
-
|
|
541
|
+
result = await self._contexts[ROOKERY].fetch(from_)
|
|
542
|
+
await self._finalize_message(result)
|
|
543
|
+
return result
|
|
336
544
|
|
|
337
545
|
async def fetch_any(self, from_: Node | Sequence[Node] | None = None) -> Any:
|
|
338
|
-
|
|
546
|
+
result = await self._contexts[ROOKERY].fetch_any(from_)
|
|
547
|
+
await self._finalize_message(result)
|
|
548
|
+
return result
|
|
339
549
|
|
|
340
550
|
async def _execute_with_reliability(
|
|
341
551
|
self,
|
|
@@ -347,6 +557,9 @@ class PenguiFlow:
|
|
|
347
557
|
attempt = 0
|
|
348
558
|
|
|
349
559
|
while True:
|
|
560
|
+
if trace_id is not None and self._is_trace_cancelled(trace_id):
|
|
561
|
+
raise TraceCancelled(trace_id)
|
|
562
|
+
|
|
350
563
|
start = time.perf_counter()
|
|
351
564
|
await self._emit_event(
|
|
352
565
|
event="node_start",
|
|
@@ -359,22 +572,42 @@ class PenguiFlow:
|
|
|
359
572
|
)
|
|
360
573
|
|
|
361
574
|
try:
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
575
|
+
result = await self._invoke_node(
|
|
576
|
+
node,
|
|
577
|
+
context,
|
|
578
|
+
message,
|
|
579
|
+
trace_id,
|
|
580
|
+
)
|
|
367
581
|
|
|
368
582
|
if result is not None:
|
|
369
|
-
|
|
583
|
+
(
|
|
584
|
+
destination,
|
|
585
|
+
prepared,
|
|
586
|
+
targets,
|
|
587
|
+
deliver_rookery,
|
|
588
|
+
) = self._controller_postprocess(
|
|
370
589
|
node, context, message, result
|
|
371
590
|
)
|
|
372
591
|
|
|
592
|
+
if deliver_rookery:
|
|
593
|
+
rookery_msg = (
|
|
594
|
+
prepared.model_copy(deep=True)
|
|
595
|
+
if isinstance(prepared, Message)
|
|
596
|
+
else prepared
|
|
597
|
+
)
|
|
598
|
+
await self._emit_to_rookery(
|
|
599
|
+
rookery_msg, source=context.owner
|
|
600
|
+
)
|
|
601
|
+
|
|
373
602
|
if destination == "skip":
|
|
374
603
|
continue
|
|
604
|
+
|
|
375
605
|
if destination == "rookery":
|
|
376
|
-
await
|
|
606
|
+
await self._emit_to_rookery(
|
|
607
|
+
prepared, source=context.owner
|
|
608
|
+
)
|
|
377
609
|
continue
|
|
610
|
+
|
|
378
611
|
await context.emit(prepared, to=targets)
|
|
379
612
|
|
|
380
613
|
latency = (time.perf_counter() - start) * 1000
|
|
@@ -388,6 +621,8 @@ class PenguiFlow:
|
|
|
388
621
|
level=logging.INFO,
|
|
389
622
|
)
|
|
390
623
|
return
|
|
624
|
+
except TraceCancelled:
|
|
625
|
+
raise
|
|
391
626
|
except asyncio.CancelledError:
|
|
392
627
|
raise
|
|
393
628
|
except TimeoutError as exc:
|
|
@@ -403,15 +638,28 @@ class PenguiFlow:
|
|
|
403
638
|
extra={"exception": repr(exc)},
|
|
404
639
|
)
|
|
405
640
|
if attempt >= node.policy.max_retries:
|
|
406
|
-
|
|
407
|
-
|
|
641
|
+
timeout_message: str | None = None
|
|
642
|
+
if node.policy.timeout_s is not None:
|
|
643
|
+
timeout_message = (
|
|
644
|
+
f"Node '{node.name}' timed out after"
|
|
645
|
+
f" {node.policy.timeout_s:.2f}s"
|
|
646
|
+
)
|
|
647
|
+
flow_error = self._create_flow_error(
|
|
408
648
|
node=node,
|
|
409
|
-
context=context,
|
|
410
649
|
trace_id=trace_id,
|
|
650
|
+
code=FlowErrorCode.NODE_TIMEOUT,
|
|
651
|
+
exc=exc,
|
|
411
652
|
attempt=attempt,
|
|
412
653
|
latency_ms=latency,
|
|
413
|
-
|
|
414
|
-
|
|
654
|
+
message=timeout_message,
|
|
655
|
+
metadata={"timeout_s": node.policy.timeout_s},
|
|
656
|
+
)
|
|
657
|
+
await self._handle_flow_error(
|
|
658
|
+
node=node,
|
|
659
|
+
context=context,
|
|
660
|
+
flow_error=flow_error,
|
|
661
|
+
latency=latency,
|
|
662
|
+
attempt=attempt,
|
|
415
663
|
)
|
|
416
664
|
return
|
|
417
665
|
attempt += 1
|
|
@@ -441,15 +689,24 @@ class PenguiFlow:
|
|
|
441
689
|
extra={"exception": repr(exc)},
|
|
442
690
|
)
|
|
443
691
|
if attempt >= node.policy.max_retries:
|
|
444
|
-
|
|
445
|
-
event="node_failed",
|
|
692
|
+
flow_error = self._create_flow_error(
|
|
446
693
|
node=node,
|
|
447
|
-
context=context,
|
|
448
694
|
trace_id=trace_id,
|
|
695
|
+
code=FlowErrorCode.NODE_EXCEPTION,
|
|
696
|
+
exc=exc,
|
|
449
697
|
attempt=attempt,
|
|
450
698
|
latency_ms=latency,
|
|
451
|
-
|
|
452
|
-
|
|
699
|
+
message=(
|
|
700
|
+
f"Node '{node.name}' raised {type(exc).__name__}: {exc}"
|
|
701
|
+
),
|
|
702
|
+
metadata={"exception_repr": repr(exc)},
|
|
703
|
+
)
|
|
704
|
+
await self._handle_flow_error(
|
|
705
|
+
node=node,
|
|
706
|
+
context=context,
|
|
707
|
+
flow_error=flow_error,
|
|
708
|
+
latency=latency,
|
|
709
|
+
attempt=attempt,
|
|
453
710
|
)
|
|
454
711
|
return
|
|
455
712
|
attempt += 1
|
|
@@ -473,13 +730,300 @@ class PenguiFlow:
|
|
|
473
730
|
delay = min(delay, policy.max_backoff)
|
|
474
731
|
return delay
|
|
475
732
|
|
|
733
|
+
def _create_flow_error(
|
|
734
|
+
self,
|
|
735
|
+
*,
|
|
736
|
+
node: Node,
|
|
737
|
+
trace_id: str | None,
|
|
738
|
+
code: FlowErrorCode,
|
|
739
|
+
exc: BaseException,
|
|
740
|
+
attempt: int,
|
|
741
|
+
latency_ms: float | None,
|
|
742
|
+
message: str | None = None,
|
|
743
|
+
metadata: Mapping[str, Any] | None = None,
|
|
744
|
+
) -> FlowError:
|
|
745
|
+
node_name = node.name
|
|
746
|
+
assert node_name is not None
|
|
747
|
+
meta: dict[str, Any] = {"attempt": attempt}
|
|
748
|
+
if latency_ms is not None:
|
|
749
|
+
meta["latency_ms"] = latency_ms
|
|
750
|
+
if metadata:
|
|
751
|
+
meta.update(metadata)
|
|
752
|
+
return FlowError.from_exception(
|
|
753
|
+
trace_id=trace_id,
|
|
754
|
+
node_name=node_name,
|
|
755
|
+
node_id=node.node_id,
|
|
756
|
+
exc=exc,
|
|
757
|
+
code=code,
|
|
758
|
+
message=message,
|
|
759
|
+
metadata=meta,
|
|
760
|
+
)
|
|
761
|
+
|
|
762
|
+
async def _handle_flow_error(
|
|
763
|
+
self,
|
|
764
|
+
*,
|
|
765
|
+
node: Node,
|
|
766
|
+
context: Context,
|
|
767
|
+
flow_error: FlowError,
|
|
768
|
+
latency: float | None,
|
|
769
|
+
attempt: int,
|
|
770
|
+
) -> None:
|
|
771
|
+
original = flow_error.unwrap()
|
|
772
|
+
exception_repr = repr(original) if original is not None else flow_error.message
|
|
773
|
+
extra = {
|
|
774
|
+
"exception": exception_repr,
|
|
775
|
+
"flow_error": flow_error.to_payload(),
|
|
776
|
+
}
|
|
777
|
+
await self._emit_event(
|
|
778
|
+
event="node_failed",
|
|
779
|
+
node=node,
|
|
780
|
+
context=context,
|
|
781
|
+
trace_id=flow_error.trace_id,
|
|
782
|
+
attempt=attempt,
|
|
783
|
+
latency_ms=latency,
|
|
784
|
+
level=logging.ERROR,
|
|
785
|
+
extra=extra,
|
|
786
|
+
)
|
|
787
|
+
if self._emit_errors_to_rookery and flow_error.trace_id is not None:
|
|
788
|
+
await self._emit_to_rookery(flow_error, source=context.owner)
|
|
789
|
+
|
|
790
|
+
async def _invoke_node(
|
|
791
|
+
self,
|
|
792
|
+
node: Node,
|
|
793
|
+
context: Context,
|
|
794
|
+
message: Any,
|
|
795
|
+
trace_id: str | None,
|
|
796
|
+
) -> Any:
|
|
797
|
+
invocation = node.invoke(message, context, registry=self._registry)
|
|
798
|
+
timeout = node.policy.timeout_s
|
|
799
|
+
|
|
800
|
+
if trace_id is None:
|
|
801
|
+
if timeout is None:
|
|
802
|
+
return await invocation
|
|
803
|
+
return await asyncio.wait_for(invocation, timeout)
|
|
804
|
+
|
|
805
|
+
return await self._await_invocation(node, invocation, trace_id, timeout)
|
|
806
|
+
|
|
807
|
+
def _register_invocation_task(
|
|
808
|
+
self, trace_id: str, task: asyncio.Future[Any]
|
|
809
|
+
) -> None:
|
|
810
|
+
tasks = self._trace_invocations.setdefault(trace_id, set())
|
|
811
|
+
tasks.add(task)
|
|
812
|
+
|
|
813
|
+
def _cleanup(finished: asyncio.Future[Any]) -> None:
|
|
814
|
+
remaining = self._trace_invocations.get(trace_id)
|
|
815
|
+
if remaining is None:
|
|
816
|
+
return
|
|
817
|
+
remaining.discard(finished)
|
|
818
|
+
if not remaining:
|
|
819
|
+
self._trace_invocations.pop(trace_id, None)
|
|
820
|
+
|
|
821
|
+
task.add_done_callback(_cleanup)
|
|
822
|
+
|
|
823
|
+
async def _await_invocation(
|
|
824
|
+
self,
|
|
825
|
+
node: Node,
|
|
826
|
+
invocation: Awaitable[Any],
|
|
827
|
+
trace_id: str,
|
|
828
|
+
timeout: float | None,
|
|
829
|
+
) -> Any:
|
|
830
|
+
invocation_task = asyncio.ensure_future(invocation)
|
|
831
|
+
self._register_invocation_task(trace_id, invocation_task)
|
|
832
|
+
|
|
833
|
+
cancel_event = self._trace_events.get(trace_id)
|
|
834
|
+
cancel_waiter: asyncio.Future[Any] | None = None
|
|
835
|
+
if cancel_event is not None:
|
|
836
|
+
cancel_waiter = asyncio.ensure_future(cancel_event.wait())
|
|
837
|
+
|
|
838
|
+
timeout_task: asyncio.Future[Any] | None = None
|
|
839
|
+
if timeout is not None:
|
|
840
|
+
timeout_task = asyncio.ensure_future(asyncio.sleep(timeout))
|
|
841
|
+
|
|
842
|
+
wait_tasks: set[asyncio.Future[Any]] = {invocation_task}
|
|
843
|
+
if cancel_waiter is not None:
|
|
844
|
+
wait_tasks.add(cancel_waiter)
|
|
845
|
+
if timeout_task is not None:
|
|
846
|
+
wait_tasks.add(timeout_task)
|
|
847
|
+
|
|
848
|
+
pending: set[asyncio.Future[Any]] = set()
|
|
849
|
+
try:
|
|
850
|
+
done, pending = await asyncio.wait(
|
|
851
|
+
wait_tasks, return_when=asyncio.FIRST_COMPLETED
|
|
852
|
+
)
|
|
853
|
+
|
|
854
|
+
if invocation_task in done:
|
|
855
|
+
if invocation_task.cancelled():
|
|
856
|
+
raise TraceCancelled(trace_id)
|
|
857
|
+
return invocation_task.result()
|
|
858
|
+
|
|
859
|
+
if cancel_waiter is not None and cancel_waiter in done:
|
|
860
|
+
invocation_task.cancel()
|
|
861
|
+
await asyncio.gather(invocation_task, return_exceptions=True)
|
|
862
|
+
raise TraceCancelled(trace_id)
|
|
863
|
+
|
|
864
|
+
if timeout_task is not None and timeout_task in done:
|
|
865
|
+
invocation_task.cancel()
|
|
866
|
+
await asyncio.gather(invocation_task, return_exceptions=True)
|
|
867
|
+
raise TimeoutError
|
|
868
|
+
|
|
869
|
+
raise RuntimeError("node invocation wait exited without result")
|
|
870
|
+
except asyncio.CancelledError:
|
|
871
|
+
invocation_task.cancel()
|
|
872
|
+
await asyncio.gather(invocation_task, return_exceptions=True)
|
|
873
|
+
if cancel_waiter is not None:
|
|
874
|
+
cancel_waiter.cancel()
|
|
875
|
+
if timeout_task is not None:
|
|
876
|
+
timeout_task.cancel()
|
|
877
|
+
await asyncio.gather(
|
|
878
|
+
*(task for task in (cancel_waiter, timeout_task) if task is not None),
|
|
879
|
+
return_exceptions=True,
|
|
880
|
+
)
|
|
881
|
+
raise
|
|
882
|
+
finally:
|
|
883
|
+
for task in pending:
|
|
884
|
+
task.cancel()
|
|
885
|
+
watchers = [
|
|
886
|
+
task for task in (cancel_waiter, timeout_task) if task is not None
|
|
887
|
+
]
|
|
888
|
+
for watcher in watchers:
|
|
889
|
+
watcher.cancel()
|
|
890
|
+
if watchers:
|
|
891
|
+
await asyncio.gather(*watchers, return_exceptions=True)
|
|
892
|
+
|
|
893
|
+
def _get_trace_id(self, message: Any) -> str | None:
|
|
894
|
+
return getattr(message, "trace_id", None)
|
|
895
|
+
|
|
896
|
+
def _is_trace_cancelled(self, trace_id: str) -> bool:
|
|
897
|
+
event = self._trace_events.get(trace_id)
|
|
898
|
+
return event.is_set() if event is not None else False
|
|
899
|
+
|
|
900
|
+
def _on_message_enqueued(self, message: Any) -> None:
|
|
901
|
+
trace_id = self._get_trace_id(message)
|
|
902
|
+
if trace_id is None:
|
|
903
|
+
return
|
|
904
|
+
self._trace_counts[trace_id] = self._trace_counts.get(trace_id, 0) + 1
|
|
905
|
+
self._trace_events.setdefault(trace_id, asyncio.Event())
|
|
906
|
+
|
|
907
|
+
async def _finalize_message(self, message: Any) -> None:
|
|
908
|
+
trace_id = self._get_trace_id(message)
|
|
909
|
+
if trace_id is None:
|
|
910
|
+
return
|
|
911
|
+
|
|
912
|
+
remaining = self._trace_counts.get(trace_id)
|
|
913
|
+
if remaining is None:
|
|
914
|
+
return
|
|
915
|
+
|
|
916
|
+
remaining -= 1
|
|
917
|
+
if remaining <= 0:
|
|
918
|
+
self._trace_counts.pop(trace_id, None)
|
|
919
|
+
event = self._trace_events.pop(trace_id, None)
|
|
920
|
+
if event is not None and event.is_set():
|
|
921
|
+
await self._emit_event(
|
|
922
|
+
event="trace_cancel_finish",
|
|
923
|
+
node=ROOKERY,
|
|
924
|
+
context=self._contexts[ROOKERY],
|
|
925
|
+
trace_id=trace_id,
|
|
926
|
+
attempt=0,
|
|
927
|
+
latency_ms=None,
|
|
928
|
+
level=logging.INFO,
|
|
929
|
+
)
|
|
930
|
+
self._notify_trace_capacity(trace_id)
|
|
931
|
+
self._latest_wm_hops.pop(trace_id, None)
|
|
932
|
+
else:
|
|
933
|
+
self._trace_counts[trace_id] = remaining
|
|
934
|
+
if self._queue_maxsize <= 0 or remaining <= self._queue_maxsize:
|
|
935
|
+
self._notify_trace_capacity(trace_id)
|
|
936
|
+
|
|
937
|
+
async def _drop_trace_from_floe(self, floe: Floe, trace_id: str) -> None:
|
|
938
|
+
queue = floe.queue
|
|
939
|
+
retained: list[Any] = []
|
|
940
|
+
|
|
941
|
+
while True:
|
|
942
|
+
try:
|
|
943
|
+
item = queue.get_nowait()
|
|
944
|
+
except asyncio.QueueEmpty:
|
|
945
|
+
break
|
|
946
|
+
|
|
947
|
+
if self._get_trace_id(item) == trace_id:
|
|
948
|
+
await self._finalize_message(item)
|
|
949
|
+
continue
|
|
950
|
+
|
|
951
|
+
retained.append(item)
|
|
952
|
+
|
|
953
|
+
for item in retained:
|
|
954
|
+
queue.put_nowait(item)
|
|
955
|
+
|
|
956
|
+
async def cancel(self, trace_id: str) -> bool:
|
|
957
|
+
if not self._running:
|
|
958
|
+
raise RuntimeError("PenguiFlow is not running")
|
|
959
|
+
|
|
960
|
+
active = trace_id in self._trace_counts or trace_id in self._trace_invocations
|
|
961
|
+
if not active:
|
|
962
|
+
return False
|
|
963
|
+
|
|
964
|
+
event = self._trace_events.setdefault(trace_id, asyncio.Event())
|
|
965
|
+
if not event.is_set():
|
|
966
|
+
event.set()
|
|
967
|
+
await self._emit_event(
|
|
968
|
+
event="trace_cancel_start",
|
|
969
|
+
node=OPEN_SEA,
|
|
970
|
+
context=self._contexts[OPEN_SEA],
|
|
971
|
+
trace_id=trace_id,
|
|
972
|
+
attempt=0,
|
|
973
|
+
latency_ms=None,
|
|
974
|
+
level=logging.INFO,
|
|
975
|
+
extra={"pending": self._trace_counts.get(trace_id, 0)},
|
|
976
|
+
)
|
|
977
|
+
else:
|
|
978
|
+
event.set()
|
|
979
|
+
|
|
980
|
+
for floe in list(self._floes):
|
|
981
|
+
await self._drop_trace_from_floe(floe, trace_id)
|
|
982
|
+
|
|
983
|
+
tasks = list(self._trace_invocations.get(trace_id, set()))
|
|
984
|
+
for task in tasks:
|
|
985
|
+
task.cancel()
|
|
986
|
+
|
|
987
|
+
return True
|
|
988
|
+
|
|
989
|
+
async def _await_trace_capacity(self, trace_id: str, *, offset: int = 0) -> None:
|
|
990
|
+
if self._queue_maxsize <= 0:
|
|
991
|
+
return
|
|
992
|
+
|
|
993
|
+
while True:
|
|
994
|
+
pending = self._trace_counts.get(trace_id, 0)
|
|
995
|
+
effective = pending - offset if pending > offset else 0
|
|
996
|
+
if effective < self._queue_maxsize:
|
|
997
|
+
return
|
|
998
|
+
waiter = asyncio.Event()
|
|
999
|
+
waiters = self._trace_capacity_waiters.setdefault(trace_id, [])
|
|
1000
|
+
waiters.append(waiter)
|
|
1001
|
+
try:
|
|
1002
|
+
await waiter.wait()
|
|
1003
|
+
finally:
|
|
1004
|
+
remaining_waiters = self._trace_capacity_waiters.get(trace_id)
|
|
1005
|
+
if remaining_waiters is not None:
|
|
1006
|
+
try:
|
|
1007
|
+
remaining_waiters.remove(waiter)
|
|
1008
|
+
except ValueError:
|
|
1009
|
+
pass
|
|
1010
|
+
if not remaining_waiters:
|
|
1011
|
+
self._trace_capacity_waiters.pop(trace_id, None)
|
|
1012
|
+
|
|
1013
|
+
def _notify_trace_capacity(self, trace_id: str) -> None:
|
|
1014
|
+
waiters = self._trace_capacity_waiters.pop(trace_id, None)
|
|
1015
|
+
if not waiters:
|
|
1016
|
+
return
|
|
1017
|
+
for waiter in waiters:
|
|
1018
|
+
waiter.set()
|
|
1019
|
+
|
|
476
1020
|
def _controller_postprocess(
|
|
477
1021
|
self,
|
|
478
1022
|
node: Node,
|
|
479
1023
|
context: Context,
|
|
480
1024
|
incoming: Any,
|
|
481
1025
|
result: Any,
|
|
482
|
-
) -> tuple[str, Any, list[Node] | None]:
|
|
1026
|
+
) -> tuple[str, Any, list[Node] | None, bool]:
|
|
483
1027
|
if isinstance(result, Message):
|
|
484
1028
|
payload = result.payload
|
|
485
1029
|
if isinstance(payload, WM):
|
|
@@ -487,27 +1031,101 @@ class PenguiFlow:
|
|
|
487
1031
|
if result.deadline_s is not None and now > result.deadline_s:
|
|
488
1032
|
final = FinalAnswer(text=DEADLINE_EXCEEDED_TEXT)
|
|
489
1033
|
final_msg = result.model_copy(update={"payload": final})
|
|
490
|
-
return "rookery", final_msg, None
|
|
1034
|
+
return "rookery", final_msg, None, False
|
|
491
1035
|
|
|
492
|
-
if
|
|
1036
|
+
if (
|
|
1037
|
+
payload.budget_tokens is not None
|
|
1038
|
+
and payload.tokens_used >= payload.budget_tokens
|
|
1039
|
+
):
|
|
1040
|
+
final = FinalAnswer(text=TOKEN_BUDGET_EXCEEDED_TEXT)
|
|
1041
|
+
final_msg = result.model_copy(update={"payload": final})
|
|
1042
|
+
return "rookery", final_msg, None, False
|
|
1043
|
+
|
|
1044
|
+
incoming_hops: int | None = None
|
|
1045
|
+
if (
|
|
1046
|
+
isinstance(incoming, Message)
|
|
1047
|
+
and isinstance(incoming.payload, WM)
|
|
1048
|
+
):
|
|
1049
|
+
incoming_hops = incoming.payload.hops
|
|
1050
|
+
|
|
1051
|
+
current_hops = payload.hops
|
|
1052
|
+
if incoming_hops is not None and current_hops <= incoming_hops:
|
|
1053
|
+
next_hops = incoming_hops + 1
|
|
1054
|
+
else:
|
|
1055
|
+
next_hops = current_hops
|
|
1056
|
+
|
|
1057
|
+
if (
|
|
1058
|
+
payload.budget_hops is not None
|
|
1059
|
+
and next_hops >= payload.budget_hops
|
|
1060
|
+
):
|
|
493
1061
|
final = FinalAnswer(text=BUDGET_EXCEEDED_TEXT)
|
|
494
1062
|
final_msg = result.model_copy(update={"payload": final})
|
|
495
|
-
return "rookery", final_msg, None
|
|
1063
|
+
return "rookery", final_msg, None, False
|
|
496
1064
|
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
1065
|
+
if next_hops != current_hops:
|
|
1066
|
+
updated_payload = payload.model_copy(update={"hops": next_hops})
|
|
1067
|
+
prepared = result.model_copy(update={"payload": updated_payload})
|
|
1068
|
+
else:
|
|
1069
|
+
prepared = result
|
|
1070
|
+
|
|
1071
|
+
stream_updates = (
|
|
1072
|
+
payload.budget_hops is None
|
|
1073
|
+
and payload.budget_tokens is None
|
|
1074
|
+
)
|
|
1075
|
+
return "context", prepared, [node], stream_updates
|
|
500
1076
|
|
|
501
1077
|
if isinstance(payload, FinalAnswer):
|
|
502
|
-
return "rookery", result, None
|
|
1078
|
+
return "rookery", result, None, False
|
|
1079
|
+
|
|
1080
|
+
return "context", result, None, False
|
|
1081
|
+
|
|
1082
|
+
def _deadline_expired(self, message: Any) -> bool:
|
|
1083
|
+
if isinstance(message, Message) and message.deadline_s is not None:
|
|
1084
|
+
return time.time() > message.deadline_s
|
|
1085
|
+
return False
|
|
1086
|
+
|
|
1087
|
+
async def _handle_deadline_expired(
    self, context: Context, message: Message
) -> None:
    """Send an expired ``message`` to the Rookery as a final answer.

    A payload that is already a ``FinalAnswer`` is kept as-is; any other
    payload is replaced by ``FinalAnswer(text=DEADLINE_EXCEEDED_TEXT)``.
    The message is copied (never mutated) and emitted with
    ``context.owner`` as the source.
    """
    current = message.payload
    final_payload = (
        current
        if isinstance(current, FinalAnswer)
        else FinalAnswer(text=DEADLINE_EXCEEDED_TEXT)
    )
    expired_msg = message.model_copy(update={"payload": final_payload})
    await self._emit_to_rookery(expired_msg, source=context.owner)
|
|
1095
|
+
|
|
1096
|
+
async def _emit_to_rookery(
    self, message: Any, *, source: Node | Endpoint | None = None
) -> None:
    """Deliver ``message`` to the Rookery sink, bypassing graph edges.

    The floe registered for ``source`` in the Rookery context is preferred;
    when there is none, the first incoming floe is used. If the Rookery has
    no incoming floes at all, the message is appended to the context buffer.
    """

    sink = self._contexts[ROOKERY]
    inbound = sink._incoming

    target: Floe | None = None
    if source is not None:
        target = inbound.get(source)
    if target is None and inbound:
        # No floe for this source: fall back to the first registered one.
        target = next(iter(inbound.values()))

    # ``_on_message_enqueued`` is called before the message is delivered.
    self._on_message_enqueued(message)

    if target is None:
        sink._buffer.append(message)
    else:
        await target.queue.put(message)

    # Record the hop count of working-memory payloads under their trace id.
    if isinstance(message, Message):
        wm = message.payload
        if isinstance(wm, WM):
            self._latest_wm_hops[message.trace_id] = wm.hops
|
|
505
1123
|
|
|
506
1124
|
async def _emit_event(
|
|
507
1125
|
self,
|
|
508
1126
|
*,
|
|
509
1127
|
event: str,
|
|
510
|
-
node: Node,
|
|
1128
|
+
node: Node | Endpoint,
|
|
511
1129
|
context: Context,
|
|
512
1130
|
trace_id: str | None,
|
|
513
1131
|
attempt: int,
|
|
@@ -515,31 +1133,50 @@ class PenguiFlow:
|
|
|
515
1133
|
level: int,
|
|
516
1134
|
extra: dict[str, Any] | None = None,
|
|
517
1135
|
) -> None:
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
1136
|
+
node_name = getattr(node, "name", None)
|
|
1137
|
+
node_id = getattr(node, "node_id", node_name)
|
|
1138
|
+
queue_depth_in = context.queue_depth_in()
|
|
1139
|
+
queue_depth_out = context.queue_depth_out()
|
|
1140
|
+
outgoing = context.outgoing_count()
|
|
1141
|
+
|
|
1142
|
+
trace_pending: int | None = None
|
|
1143
|
+
trace_inflight = 0
|
|
1144
|
+
trace_cancelled = False
|
|
1145
|
+
if trace_id is not None:
|
|
1146
|
+
trace_pending = self._trace_counts.get(trace_id, 0)
|
|
1147
|
+
trace_inflight = len(self._trace_invocations.get(trace_id, set()))
|
|
1148
|
+
trace_cancelled = self._is_trace_cancelled(trace_id)
|
|
1149
|
+
|
|
1150
|
+
event_obj = FlowEvent(
|
|
1151
|
+
event_type=event,
|
|
1152
|
+
ts=time.time(),
|
|
1153
|
+
node_name=node_name,
|
|
1154
|
+
node_id=node_id,
|
|
1155
|
+
trace_id=trace_id,
|
|
1156
|
+
attempt=attempt,
|
|
1157
|
+
latency_ms=latency_ms,
|
|
1158
|
+
queue_depth_in=queue_depth_in,
|
|
1159
|
+
queue_depth_out=queue_depth_out,
|
|
1160
|
+
outgoing_edges=outgoing,
|
|
1161
|
+
queue_maxsize=self._queue_maxsize,
|
|
1162
|
+
trace_pending=trace_pending,
|
|
1163
|
+
trace_inflight=trace_inflight,
|
|
1164
|
+
trace_cancelled=trace_cancelled,
|
|
1165
|
+
extra=extra or {},
|
|
1166
|
+
)
|
|
1167
|
+
|
|
1168
|
+
logger.log(level, event, extra=event_obj.to_payload())
|
|
532
1169
|
|
|
533
1170
|
for middleware in list(self._middlewares):
|
|
534
1171
|
try:
|
|
535
|
-
await middleware(
|
|
1172
|
+
await middleware(event_obj)
|
|
536
1173
|
except Exception as exc: # pragma: no cover - defensive
|
|
537
1174
|
logger.exception(
|
|
538
1175
|
"middleware_error",
|
|
539
1176
|
extra={
|
|
540
1177
|
"event": "middleware_error",
|
|
541
|
-
"node_name":
|
|
542
|
-
"node_id":
|
|
1178
|
+
"node_name": node_name,
|
|
1179
|
+
"node_id": node_id,
|
|
543
1180
|
"exception": exc,
|
|
544
1181
|
},
|
|
545
1182
|
)
|
|
@@ -572,13 +1209,38 @@ async def call_playbook(
|
|
|
572
1209
|
playbook: PlaybookFactory,
|
|
573
1210
|
parent_msg: Message,
|
|
574
1211
|
timeout: float | None = None,
|
|
1212
|
+
*,
|
|
1213
|
+
runtime: PenguiFlow | None = None,
|
|
575
1214
|
) -> Any:
|
|
576
1215
|
"""Execute a subflow playbook and return the first Rookery payload."""
|
|
577
1216
|
|
|
578
1217
|
flow, registry = playbook()
|
|
579
1218
|
flow.run(registry=registry)
|
|
580
1219
|
|
|
1220
|
+
trace_id = getattr(parent_msg, "trace_id", None)
|
|
1221
|
+
cancel_watch: asyncio.Task[None] | None = None
|
|
1222
|
+
pre_cancelled = False
|
|
1223
|
+
|
|
1224
|
+
if runtime is not None and trace_id is not None:
|
|
1225
|
+
parent_event = runtime._trace_events.setdefault(trace_id, asyncio.Event())
|
|
1226
|
+
|
|
1227
|
+
if parent_event.is_set():
|
|
1228
|
+
pre_cancelled = True
|
|
1229
|
+
else:
|
|
1230
|
+
|
|
1231
|
+
async def _mirror_cancel() -> None:
|
|
1232
|
+
try:
|
|
1233
|
+
await parent_event.wait()
|
|
1234
|
+
except asyncio.CancelledError:
|
|
1235
|
+
return
|
|
1236
|
+
with suppress(Exception):
|
|
1237
|
+
await flow.cancel(trace_id)
|
|
1238
|
+
|
|
1239
|
+
cancel_watch = asyncio.create_task(_mirror_cancel())
|
|
1240
|
+
|
|
581
1241
|
try:
|
|
1242
|
+
if pre_cancelled:
|
|
1243
|
+
raise TraceCancelled(trace_id)
|
|
582
1244
|
await flow.emit(parent_msg)
|
|
583
1245
|
fetch_coro = flow.fetch()
|
|
584
1246
|
if timeout is not None:
|
|
@@ -588,7 +1250,20 @@ async def call_playbook(
|
|
|
588
1250
|
if isinstance(result_msg, Message):
|
|
589
1251
|
return result_msg.payload
|
|
590
1252
|
return result_msg
|
|
1253
|
+
except TraceCancelled:
|
|
1254
|
+
if trace_id is not None and not pre_cancelled:
|
|
1255
|
+
with suppress(Exception):
|
|
1256
|
+
await flow.cancel(trace_id)
|
|
1257
|
+
raise
|
|
1258
|
+
except asyncio.CancelledError:
|
|
1259
|
+
if trace_id is not None:
|
|
1260
|
+
with suppress(Exception):
|
|
1261
|
+
await flow.cancel(trace_id)
|
|
1262
|
+
raise
|
|
591
1263
|
finally:
|
|
1264
|
+
if cancel_watch is not None:
|
|
1265
|
+
cancel_watch.cancel()
|
|
1266
|
+
await asyncio.gather(cancel_watch, return_exceptions=True)
|
|
592
1267
|
await asyncio.shield(flow.stop())
|
|
593
1268
|
|
|
594
1269
|
|