penguiflow 2.0.0__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

penguiflow/core.py CHANGED
@@ -9,17 +9,20 @@ from __future__ import annotations
  import asyncio
  import logging
  import time
+ import warnings
  from collections import deque
  from collections.abc import Awaitable, Callable, Mapping, Sequence
  from contextlib import suppress
  from dataclasses import dataclass
- from typing import Any
+ from typing import Any, cast

+ from .bus import BusEnvelope, MessageBus
  from .errors import FlowError, FlowErrorCode
  from .metrics import FlowEvent
  from .middlewares import Middleware
  from .node import Node, NodePolicy
  from .registry import ModelRegistry
+ from .state import RemoteBinding, StateStore, StoredEvent
  from .types import WM, FinalAnswer, Message, StreamChunk

  logger = logging.getLogger("penguiflow.core")
@@ -143,8 +146,7 @@ class Context:
  if self._runtime is None:
  raise RuntimeError("Context is not attached to a running flow")
  for floe in self._resolve_targets(to, self._outgoing):
- self._runtime._on_message_enqueued(msg)
- await floe.queue.put(msg)
+ await self._runtime._send_to_floe(floe, msg)

  def emit_nowait(
  self, msg: Any, to: Node | Endpoint | Sequence[Node | Endpoint] | None = None
@@ -152,8 +154,7 @@
  if self._runtime is None:
  raise RuntimeError("Context is not attached to a running flow")
  for floe in self._resolve_targets(to, self._outgoing):
- self._runtime._on_message_enqueued(msg)
- floe.queue.put_nowait(msg)
+ self._runtime._send_to_floe_nowait(floe, msg)

  async def emit_chunk(
  self,
@@ -301,6 +302,8 @@ class PenguiFlow:
  allow_cycles: bool = False,
  middlewares: Sequence[Middleware] | None = None,
  emit_errors_to_rookery: bool = False,
+ state_store: StateStore | None = None,
+ message_bus: MessageBus | None = None,
  ) -> None:
  self._queue_maxsize = queue_maxsize
  self._allow_cycles = allow_cycles
@@ -314,10 +317,14 @@ class PenguiFlow:
  self._middlewares: list[Middleware] = list(middlewares or [])
  self._trace_counts: dict[str, int] = {}
  self._trace_events: dict[str, asyncio.Event] = {}
- self._trace_invocations: dict[str, set[asyncio.Future[Any]]] = {}
+ self._trace_invocations: dict[str, set[asyncio.Task[Any]]] = {}
+ self._external_tasks: dict[str, set[asyncio.Future[Any]]] = {}
  self._trace_capacity_waiters: dict[str, list[asyncio.Event]] = {}
  self._latest_wm_hops: dict[str, int] = {}
  self._emit_errors_to_rookery = emit_errors_to_rookery
+ self._state_store = state_store
+ self._message_bus = message_bus
+ self._bus_tasks: set[asyncio.Task[None]] = set()

  self._build_graph(adjacencies)

@@ -487,6 +494,29 @@ class PenguiFlow:
  task.cancel()
  await asyncio.gather(*self._tasks, return_exceptions=True)
  self._tasks.clear()
+ if self._trace_invocations:
+ pending: list[asyncio.Task[Any]] = []
+ for invocation_tasks in self._trace_invocations.values():
+ for task in invocation_tasks:
+ if not task.done():
+ task.cancel()
+ pending.append(task)
+ if pending:
+ await asyncio.gather(*pending, return_exceptions=True)
+ self._trace_invocations.clear()
+ if self._external_tasks:
+ pending_ext: list[asyncio.Future[Any]] = []
+ for external_tasks in self._external_tasks.values():
+ for external_task in external_tasks:
+ if not external_task.done():
+ external_task.cancel()
+ pending_ext.append(external_task)
+ if pending_ext:
+ await asyncio.gather(*pending_ext, return_exceptions=True)
+ self._external_tasks.clear()
+ if self._bus_tasks:
+ await asyncio.gather(*self._bus_tasks, return_exceptions=True)
+ self._bus_tasks.clear()
  self._trace_counts.clear()
  self._trace_events.clear()
  self._trace_invocations.clear()
@@ -547,6 +577,84 @@ class PenguiFlow:
  await self._finalize_message(result)
  return result

+ async def load_history(self, trace_id: str) -> Sequence[StoredEvent]:
+ """Return the persisted history for ``trace_id`` from the state store."""
+
+ if self._state_store is None:
+ raise RuntimeError("PenguiFlow was created without a state_store")
+ return await self._state_store.load_history(trace_id)
+
+ def ensure_trace_event(self, trace_id: str) -> asyncio.Event:
+ """Return (and create if needed) the cancellation event for ``trace_id``."""
+
+ return self._trace_events.setdefault(trace_id, asyncio.Event())
+
+ def register_external_task(self, trace_id: str, task: asyncio.Future[Any]) -> None:
+ """Track an externally created task for cancellation bookkeeping."""
+
+ if trace_id is None:
+ return
+ tasks = self._external_tasks.get(trace_id)
+ if tasks is None:
+ tasks = set[asyncio.Future[Any]]()
+ self._external_tasks[trace_id] = tasks
+ tasks.add(task)
+
+ def _cleanup(finished: asyncio.Future[Any]) -> None:
+ remaining = self._external_tasks.get(trace_id)
+ if remaining is None:
+ return
+ remaining.discard(finished)
+ if not remaining:
+ self._external_tasks.pop(trace_id, None)
+
+ task.add_done_callback(_cleanup)
+
+ async def save_remote_binding(self, binding: RemoteBinding) -> None:
+ """Persist a remote binding if a state store is configured."""
+
+ if self._state_store is None:
+ return
+ try:
+ await self._state_store.save_remote_binding(binding)
+ except Exception as exc: # pragma: no cover - defensive logging
+ logger.exception(
+ "state_store_binding_failed",
+ extra={
+ "event": "state_store_binding_failed",
+ "trace_id": binding.trace_id,
+ "context_id": binding.context_id,
+ "task_id": binding.task_id,
+ "agent_url": binding.agent_url,
+ "exception": repr(exc),
+ },
+ )
+
+ async def record_remote_event(
+ self,
+ *,
+ event: str,
+ node: Node,
+ context: Context,
+ trace_id: str | None,
+ latency_ms: float | None,
+ level: int = logging.INFO,
+ extra: Mapping[str, Any] | None = None,
+ ) -> None:
+ """Emit a structured :class:`FlowEvent` for remote transport activity."""
+
+ payload = dict(extra or {})
+ await self._emit_event(
+ event=event,
+ node=node,
+ context=context,
+ trace_id=trace_id,
+ attempt=0,
+ latency_ms=latency_ms,
+ level=level,
+ extra=payload,
+ )
+
  async def _execute_with_reliability(
  self,
  node: Node,
@@ -579,6 +687,21 @@ class PenguiFlow:
  trace_id,
  )

+ if (
+ result is not None
+ and self._expects_message_output(node)
+ and not isinstance(result, Message)
+ ):
+ node_name = node.name or node.node_id
+ warning_msg = (
+ "Node "
+ f"'{node_name}' is registered for Message -> Message outputs "
+ f"but returned {type(result).__name__}. "
+ "Return a penguiflow.types.Message to preserve headers, "
+ "trace_id, and meta."
+ )
+ warnings.warn(warning_msg, RuntimeWarning, stacklevel=2)
+
  if result is not None:
  (
  destination,
@@ -805,16 +928,19 @@ class PenguiFlow:
  return await self._await_invocation(node, invocation, trace_id, timeout)

  def _register_invocation_task(
- self, trace_id: str, task: asyncio.Future[Any]
+ self, trace_id: str, task: asyncio.Task[Any]
  ) -> None:
- tasks = self._trace_invocations.setdefault(trace_id, set())
+ tasks = self._trace_invocations.get(trace_id)
+ if tasks is None:
+ tasks = set[asyncio.Task[Any]]()
+ self._trace_invocations[trace_id] = tasks
  tasks.add(task)

  def _cleanup(finished: asyncio.Future[Any]) -> None:
  remaining = self._trace_invocations.get(trace_id)
  if remaining is None:
  return
- remaining.discard(finished)
+ remaining.discard(cast(asyncio.Task[Any], finished))
  if not remaining:
  self._trace_invocations.pop(trace_id, None)

@@ -827,7 +953,7 @@
  trace_id: str,
  timeout: float | None,
  ) -> Any:
- invocation_task = asyncio.ensure_future(invocation)
+ invocation_task = cast(asyncio.Task[Any], asyncio.ensure_future(invocation))
  self._register_invocation_task(trace_id, invocation_task)

  cancel_event = self._trace_events.get(trace_id)
@@ -904,6 +1030,89 @@ class PenguiFlow:
  self._trace_counts[trace_id] = self._trace_counts.get(trace_id, 0) + 1
  self._trace_events.setdefault(trace_id, asyncio.Event())

+ def _node_label(self, node: Node | Endpoint | None) -> str | None:
+ if node is None:
+ return None
+ name = getattr(node, "name", None)
+ if name:
+ return name
+ return getattr(node, "node_id", None)
+
+ def _build_bus_envelope(
+ self,
+ source: Node | Endpoint | None,
+ target: Node | Endpoint | None,
+ message: Any,
+ ) -> BusEnvelope:
+ source_name = self._node_label(source)
+ target_name = self._node_label(target)
+ edge = f"{source_name or '*'}->{target_name or '*'}"
+ headers: Mapping[str, Any] | None = None
+ meta: Mapping[str, Any] | None = None
+ if isinstance(message, Message):
+ headers = message.headers.model_dump()
+ meta = dict(message.meta)
+ return BusEnvelope(
+ edge=edge,
+ source=source_name,
+ target=target_name,
+ trace_id=self._get_trace_id(message),
+ payload=message,
+ headers=headers,
+ meta=meta,
+ )
+
+ async def _publish_to_bus(
+ self,
+ source: Node | Endpoint | None,
+ target: Node | Endpoint | None,
+ message: Any,
+ ) -> None:
+ if self._message_bus is None:
+ return
+ envelope = self._build_bus_envelope(source, target, message)
+ try:
+ await self._message_bus.publish(envelope)
+ except Exception as exc:
+ logger.exception(
+ "message_bus_publish_failed",
+ extra={
+ "event": "message_bus_publish_failed",
+ "edge": envelope.edge,
+ "trace_id": envelope.trace_id,
+ "exception": repr(exc),
+ },
+ )
+
+ def _schedule_bus_publish(
+ self,
+ source: Node | Endpoint | None,
+ target: Node | Endpoint | None,
+ message: Any,
+ ) -> None:
+ if self._message_bus is None:
+ return
+ loop = asyncio.get_running_loop()
+ task = loop.create_task(self._publish_to_bus(source, target, message))
+ self._bus_tasks.add(task)
+
+ def _cleanup(done: asyncio.Task[None]) -> None:
+ self._bus_tasks.discard(done)
+
+ task.add_done_callback(_cleanup)
+
+ async def _send_to_floe(self, floe: Floe, message: Any) -> None:
+ self._on_message_enqueued(message)
+ if self._message_bus is not None:
+ await self._publish_to_bus(floe.source, floe.target, message)
+ await floe.queue.put(message)
+
+ def _send_to_floe_nowait(self, floe: Floe, message: Any) -> None:
+ self._on_message_enqueued(message)
+ if self._message_bus is not None:
+ self._schedule_bus_publish(floe.source, floe.target, message)
+ floe.queue.put_nowait(message)
+
  async def _finalize_message(self, message: Any) -> None:
  trace_id = self._get_trace_id(message)
  if trace_id is None:
@@ -1017,6 +1226,29 @@ class PenguiFlow:
  for waiter in waiters:
  waiter.set()

+ def _expects_message_output(self, node: Node) -> bool:
+ registry = self._registry
+ if registry is None:
+ return False
+
+ models = getattr(registry, "models", None)
+ if models is None:
+ return False
+
+ node_name = node.name
+ if not node_name:
+ return False
+
+ try:
+ _in_model, out_model = models(node_name)
+ except Exception: # pragma: no cover - registry without entry
+ return False
+
+ try:
+ return issubclass(out_model, Message)
+ except TypeError:
+ return False
+
  def _controller_postprocess(
  self,
  node: Node,
@@ -1107,11 +1339,12 @@ class PenguiFlow:
  if floe is None and incoming:
  floe = next(iter(incoming.values()))

- self._on_message_enqueued(message)
-
  if floe is not None:
- await floe.queue.put(message)
+ await self._send_to_floe(floe, message)
  else:
+ self._on_message_enqueued(message)
+ if self._message_bus is not None:
+ await self._publish_to_bus(source, ROOKERY, message)
  buffer = rookery_context._buffer
  buffer.append(message)

@@ -1167,6 +1400,21 @@ class PenguiFlow:

  logger.log(level, event, extra=event_obj.to_payload())

+ if self._state_store is not None:
+ stored_event = StoredEvent.from_flow_event(event_obj)
+ try:
+ await self._state_store.save_event(stored_event)
+ except Exception as exc:
+ logger.exception(
+ "state_store_save_failed",
+ extra={
+ "event": "state_store_save_failed",
+ "trace_id": stored_event.trace_id,
+ "kind": stored_event.kind,
+ "exception": repr(exc),
+ },
+ )
+
  for middleware in list(self._middlewares):
  try:
  await middleware(event_obj)
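
For orientation, here is a minimal sketch (not shipped with the package) of objects satisfying the calls this diff makes on the new state_store and message_bus hooks: save_event, load_history, save_remote_binding, and publish. The class names below are illustrative; real implementations should follow the StateStore and MessageBus protocols from penguiflow.state and penguiflow.bus.

    from collections import defaultdict
    from typing import Any


    class InMemoryStateStore:
        def __init__(self) -> None:
            self._events: dict[Any, list[Any]] = defaultdict(list)
            self._bindings: list[Any] = []

        async def save_event(self, event: Any) -> None:
            # core.py calls this from _emit_event for every FlowEvent it persists as a StoredEvent
            self._events[event.trace_id].append(event)

        async def load_history(self, trace_id: str) -> list[Any]:
            # surfaced through the new PenguiFlow.load_history(trace_id)
            return list(self._events[trace_id])

        async def save_remote_binding(self, binding: Any) -> None:
            self._bindings.append(binding)


    class InMemoryBus:
        def __init__(self) -> None:
            self.envelopes: list[Any] = []

        async def publish(self, envelope: Any) -> None:
            # receives a BusEnvelope for every edge hop via _send_to_floe / _send_to_floe_nowait
            self.envelopes.append(envelope)

Both objects would be handed to the runtime through the new state_store= and message_bus= keyword arguments shown in the constructor hunk above.
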
penguiflow/debug.py ADDED
@@ -0,0 +1,30 @@
+ """Developer-facing debugging helpers for PenguiFlow."""
+
+ from __future__ import annotations
+
+ from collections.abc import Mapping
+ from typing import Any
+
+ from .metrics import FlowEvent
+
+
+ def format_flow_event(event: FlowEvent) -> dict[str, Any]:
+ """Return a structured payload ready for logging.
+
+ The returned dictionary mirrors :meth:`FlowEvent.to_payload` and flattens any
+ embedded ``FlowError`` payload so that log aggregators can index the error
+ metadata (``flow_error_code``, ``flow_error_message``, ...).
+ """
+
+ payload = dict(event.to_payload())
+ error_payload: Mapping[str, Any] | None = event.error_payload
+ if error_payload is not None:
+ # Preserve the original payload for downstream consumers.
+ payload["flow_error"] = dict(error_payload)
+ for key, value in error_payload.items():
+ payload[f"flow_error_{key}"] = value
+ return payload
+
+
+ __all__ = ["format_flow_event"]
+
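
A short usage sketch for the new helper; the logger name and middleware wiring below are illustrative, not part of the package:

    import logging

    from penguiflow.debug import format_flow_event
    from penguiflow.metrics import FlowEvent

    log = logging.getLogger("myapp.flow_debug")


    async def debug_middleware(event: FlowEvent) -> None:
        payload = format_flow_event(event)
        # flow_error_code / flow_error_message keys appear only when the event carries a FlowError
        log.info(event.event_type, extra=payload)

Registered as a middleware, this mirrors the runtime's own structured logging while keeping the error metadata as flat, indexable fields.
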
penguiflow/metrics.py CHANGED
@@ -31,6 +31,15 @@ class FlowEvent:
  def __post_init__(self) -> None:
  object.__setattr__(self, "extra", MappingProxyType(dict(self.extra)))

+ @property
+ def error_payload(self) -> Mapping[str, Any] | None:
+ """Return the structured ``FlowError`` payload if present."""
+
+ raw_payload = self.extra.get("flow_error")
+ if isinstance(raw_payload, Mapping):
+ return MappingProxyType(dict(raw_payload))
+ return None
+
  @property
  def queue_depth(self) -> int:
  """Return the combined depth of incoming and outgoing queues."""
penguiflow/middlewares.py CHANGED
@@ -2,10 +2,14 @@

  from __future__ import annotations

+ import logging
+ from collections.abc import Callable
  from typing import Protocol

  from .metrics import FlowEvent

+ LatencyCallback = Callable[[str, float, FlowEvent], None]
+

  class Middleware(Protocol):
  """Base middleware signature receiving :class:`FlowEvent` objects."""
@@ -13,4 +17,71 @@ class Middleware(Protocol):
  async def __call__(self, event: FlowEvent) -> None: ...


- __all__ = ["Middleware", "FlowEvent"]
+ def log_flow_events(
+ logger: logging.Logger | None = None,
+ *,
+ start_level: int = logging.INFO,
+ success_level: int = logging.INFO,
+ error_level: int = logging.ERROR,
+ latency_callback: LatencyCallback | None = None,
+ ) -> Middleware:
+ """Return middleware that emits structured node lifecycle logs.
+
+ Parameters
+ ----------
+ logger:
+ Optional :class:`logging.Logger` instance. When omitted a logger named
+ ``"penguiflow.flow"`` is used.
+ start_level, success_level, error_level:
+ Logging levels for ``node_start``, ``node_success``, and
+ ``node_error`` events respectively.
+ latency_callback:
+ Optional callable invoked with ``(event_type, latency_ms, event)`` for
+ ``node_success`` and ``node_error`` events. Use this hook to connect the
+ middleware to histogram-based metrics backends without
+ re-implementing timing logic.
+ """
+
+ log = logger or logging.getLogger("penguiflow.flow")
+
+ async def _middleware(event: FlowEvent) -> None:
+ if event.event_type not in {"node_start", "node_success", "node_error"}:
+ return
+
+ payload = event.to_payload()
+ log_level = start_level
+
+ if event.event_type == "node_start":
+ log_level = start_level
+ elif event.event_type == "node_success":
+ log_level = success_level
+ else:
+ log_level = error_level
+ if event.error_payload is not None:
+ payload = dict(payload)
+ payload["error_payload"] = dict(event.error_payload)
+
+ log.log(log_level, event.event_type, extra=payload)
+
+ if (
+ latency_callback is not None
+ and event.event_type in {"node_success", "node_error"}
+ and event.latency_ms is not None
+ ):
+ try:
+ latency_callback(event.event_type, float(event.latency_ms), event)
+ except Exception:
+ log.exception(
+ "log_flow_events_latency_callback_error",
+ extra={
+ "event": "log_flow_events_latency_callback_error",
+ "node_name": event.node_name,
+ "node_id": event.node_id,
+ "trace_id": event.trace_id,
+ },
+ )
+
+ return _middleware
+
+
+ __all__ = ["Middleware", "FlowEvent", "log_flow_events", "LatencyCallback"]
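
A usage sketch for the new factory; the latency sink shown here (a plain print) stands in for whatever histogram backend you use:

    import logging

    from penguiflow.metrics import FlowEvent
    from penguiflow.middlewares import log_flow_events


    def observe_latency(event_type: str, latency_ms: float, event: FlowEvent) -> None:
        # e.g. node_latency_histogram.labels(event.node_name, event_type).observe(latency_ms)
        print(f"{event.node_name} {event_type}: {latency_ms:.1f} ms")


    middleware = log_flow_events(
        logging.getLogger("myapp.penguiflow"),
        error_level=logging.WARNING,
        latency_callback=observe_latency,
    )

The returned middleware is passed to the flow via the middlewares= constructor argument shown in the core.py hunks above.
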
penguiflow/registry.py CHANGED
@@ -15,6 +15,8 @@ ModelT = TypeVar("ModelT", bound=BaseModel)
  class RegistryEntry:
  in_adapter: TypeAdapter[Any]
  out_adapter: TypeAdapter[Any]
+ in_model: type[BaseModel]
+ out_model: type[BaseModel]


  class ModelRegistry:
@@ -36,6 +38,8 @@ class ModelRegistry:
  self._entries[node_name] = RegistryEntry(
  TypeAdapter(in_model),
  TypeAdapter(out_model),
+ in_model,
+ out_model,
  )

  def adapters(self, node_name: str) -> tuple[TypeAdapter[Any], TypeAdapter[Any]]:
@@ -45,5 +49,22 @@
  raise KeyError(f"Node '{node_name}' not registered") from exc
  return entry.in_adapter, entry.out_adapter

+ def models(
+ self, node_name: str
+ ) -> tuple[type[BaseModel], type[BaseModel]]:
+ """Return the registered models for ``node_name``.
+
+ Raises
+ ------
+ KeyError
+ If the node has not been registered.
+ """
+
+ try:
+ entry = self._entries[node_name]
+ except KeyError as exc:
+ raise KeyError(f"Node '{node_name}' not registered") from exc
+ return entry.in_model, entry.out_model
+

  __all__ = ["ModelRegistry"]
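
A sketch of the new accessor next to the existing adapters(); this assumes the registration entry point is register(node_name, in_model, out_model), and QueryIn/QueryOut are placeholder pydantic models:

    from pydantic import BaseModel

    from penguiflow.registry import ModelRegistry


    class QueryIn(BaseModel):
        text: str


    class QueryOut(BaseModel):
        answer: str


    registry = ModelRegistry()
    registry.register("retriever", QueryIn, QueryOut)

    in_model, out_model = registry.models("retriever")  # raises KeyError for unknown nodes
    assert (in_model, out_model) == (QueryIn, QueryOut)

core.py's new _expects_message_output check relies on this accessor to decide whether a node was registered to return a Message.
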