penguiflow 1.0.3__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of penguiflow might be problematic; review the release details in the registry advisory before upgrading.

penguiflow/testkit.py ADDED
@@ -0,0 +1,269 @@
1
+ """Utilities for writing concise PenguiFlow tests.
2
+
3
+ The helpers in this module provide a minimal harness around ``PenguiFlow`` so
4
+ unit tests can focus on the behaviour of their nodes instead of the runtime
5
+ plumbing. Each helper intentionally works with the public runtime surface to
6
+ avoid relying on private attributes, keeping the harness forward compatible
7
+ with the v1 API.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import asyncio
13
+ import inspect
14
+ from collections import OrderedDict
15
+ from collections.abc import Awaitable, Callable, Iterable, Sequence
16
+ from dataclasses import dataclass, field
17
+ from itertools import groupby
18
+ from typing import Any
19
+ from weakref import WeakKeyDictionary
20
+
21
+ from .core import PenguiFlow
22
+ from .errors import FlowErrorCode
23
+ from .metrics import FlowEvent
24
+ from .types import Message
25
+
26
+ __all__ = ["run_one", "assert_node_sequence", "simulate_error"]
27
+
28
+
29
# Upper bound on how many trace histories are kept module-wide; the oldest
# entries are evicted first (see _register_trace_history).
_MAX_TRACE_HISTORY = 64
# LRU-ordered registry mapping trace_id -> the FlowEvents recorded for it.
_TRACE_HISTORY: OrderedDict[str, list[FlowEvent]] = OrderedDict()
# Per-flow recorder state. A WeakKeyDictionary is used so that a
# garbage-collected flow does not keep its recorder state alive.
_RECORDER_STATE: WeakKeyDictionary[PenguiFlow, _RecorderState] = (
    WeakKeyDictionary()
)
34
+
35
+
36
def _register_trace_history(trace_id: str, events: list[FlowEvent]) -> None:
    """Store ``events`` as the history bucket for ``trace_id`` (LRU-bounded).

    Empty trace ids are ignored. Re-registering a known id refreshes its
    recency before the bucket is replaced; once the registry exceeds
    ``_MAX_TRACE_HISTORY`` entries, the least recently used ones are dropped.
    """
    if not trace_id:
        return
    if trace_id in _TRACE_HISTORY:
        # Refresh recency so this trace survives eviction the longest.
        _TRACE_HISTORY.move_to_end(trace_id)
    _TRACE_HISTORY[trace_id] = events
    overflow = len(_TRACE_HISTORY) - _MAX_TRACE_HISTORY
    for _ in range(overflow):
        # Evict from the least-recently-used end.
        _TRACE_HISTORY.popitem(last=False)
44
+
45
+
46
@dataclass(slots=True)
class _RunLog:
    """Mutable record of everything observed during one recorded run."""

    # All events observed, in arrival order.
    events: list[FlowEvent] = field(default_factory=list)
    # Events grouped per trace_id (only events that carry a trace_id).
    traces: dict[str, list[FlowEvent]] = field(default_factory=dict)
    # Trace ids explicitly registered up front for this run.
    active_traces: set[str] = field(default_factory=set)
51
+
52
+
53
class _RecorderState:
    """Per-flow bookkeeping that records ``FlowEvent``s grouped by trace."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._log = _RunLog()
        self._middleware = _Recorder(self)

    @property
    def middleware(self) -> _Recorder:
        """The middleware callable that feeds events into this state."""
        return self._middleware

    def begin(self, traces: Iterable[str] | None = None) -> None:
        """Reset the log and pre-register history buckets for ``traces``."""
        tracked = set(traces) if traces else set()
        self._log = _RunLog(active_traces=tracked)
        for tid in tracked:
            events: list[FlowEvent] = []
            self._log.traces[tid] = events
            _register_trace_history(tid, events)

    async def record(self, event: FlowEvent) -> None:
        """Append ``event`` to the run log and to its trace bucket, if any."""
        async with self._lock:
            self._log.events.append(event)
            tid = event.trace_id
            if tid is None:
                return
            bucket = self._log.traces.get(tid)
            if bucket is None:
                # First event for this trace: create and publish its bucket.
                bucket = []
                self._log.traces[tid] = bucket
                _register_trace_history(tid, bucket)
            bucket.append(event)

    def node_sequence(self, trace_id: str) -> list[str]:
        """Node names that emitted ``node_start`` for ``trace_id``, in order."""
        bucket = self._log.traces.get(trace_id)
        if bucket is None:
            # Fall back to the module-level history for past runs.
            bucket = _TRACE_HISTORY.get(trace_id, [])
        return [
            event.node_name or event.node_id or "<anonymous>"
            for event in bucket
            if event.event_type == "node_start"
        ]
95
+
96
+
97
class _Recorder:
    """Middleware adapter that forwards each FlowEvent to its recorder state."""

    def __init__(self, state: _RecorderState) -> None:
        self._state = state

    async def __call__(self, event: FlowEvent) -> None:
        # Delegates straight to _RecorderState.record (lock-protected there).
        await self._state.record(event)
103
+
104
+
105
def _get_state(flow: PenguiFlow) -> _RecorderState:
    """Return the recorder state for ``flow``, creating and attaching it once.

    Raises ``AttributeError`` when the flow exposes no ``_middlewares`` hook
    list; the recorder middleware is appended at most once (identity check).
    """
    try:
        state = _RECORDER_STATE[flow]
    except KeyError:
        state = _RECORDER_STATE[flow] = _RecorderState()
    hooks = getattr(flow, "_middlewares", None)
    if hooks is None:
        raise AttributeError("PenguiFlow instance is missing middleware hooks")
    recorder = state.middleware
    attached = any(existing is recorder for existing in hooks)
    if not attached:
        hooks.append(recorder)
    return state
117
+
118
+
119
async def run_one(
    flow: PenguiFlow,
    message: Message,
    *,
    registry: Any | None = None,
    timeout_s: float | None = 1.0,
) -> Any:
    """Run ``message`` through ``flow`` and return the first Rookery payload.

    The flow is started and stopped on behalf of the caller, and the
    message's ``trace_id`` is registered so :func:`assert_node_sequence`
    can inspect the execution order afterwards.

    Parameters
    ----------
    flow:
        The flow to drive for a single round trip.
    message:
        The message injected into the flow; must be a ``Message``.
    registry:
        Optional registry forwarded to ``flow.run``.
    timeout_s:
        Maximum seconds to wait for the fetched result; ``None`` waits
        indefinitely.
    """
    if not isinstance(message, Message):
        raise TypeError("run_one expects a penguiflow.types.Message instance")

    # Track this trace so the recorded node order can be asserted later.
    _get_state(flow).begin([message.trace_id])

    flow.run(registry=registry)
    try:
        await flow.emit(message)
        pending = flow.fetch()
        if timeout_s is None:
            return await pending
        return await asyncio.wait_for(pending, timeout_s)
    finally:
        # Always stop the flow, even on timeout or failure.
        await flow.stop()
151
+
152
+
153
def assert_node_sequence(trace_id: str, expected: Sequence[str]) -> None:
    """Assert that ``expected`` matches the recorded node start order.

    Consecutive repeats of the same node name (e.g. from retries) are
    collapsed before comparison. Raises ``AssertionError`` when nothing was
    recorded for ``trace_id`` or the observed order differs.
    """
    wanted = list(expected)
    recorded = _TRACE_HISTORY.get(trace_id, [])
    if not recorded:
        raise AssertionError(
            "No recorded events for trace_id="
            f"{trace_id!r}; run a flow with run_one first."
        )

    observed: list[str] = []
    for event in recorded:
        if event.event_type != "node_start":
            continue
        name = event.node_name or event.node_id or "<anonymous>"
        # Collapse adjacent duplicates (equivalent to itertools.groupby).
        if not observed or observed[-1] != name:
            observed.append(name)

    if observed != wanted:
        raise AssertionError(
            "Node sequence mismatch:\n"
            f" expected: {wanted}\n"
            f" actual: {observed}"
        )
176
+
177
+
178
class _ErrorSimulation:
    """Async callable that raises for its first ``fail_times`` invocations.

    After the simulated failures are exhausted it either echoes the incoming
    message or delegates to ``result_factory`` (sync or async).
    """

    def __init__(
        self,
        *,
        node_name: str,
        code: str,
        fail_times: int,
        exception_factory: Callable[[str], Exception],
        result_factory: Callable[[Any], Awaitable[Any] | Any] | None,
    ) -> None:
        self._node_name = node_name
        self._code = code
        self._fail_times = fail_times
        self._exception_factory = exception_factory
        self._result_factory = result_factory
        self._attempts = 0

    @property
    def attempts(self) -> int:
        """Total number of invocations so far."""
        return self._attempts

    @property
    def failures(self) -> int:
        """How many invocations so far ended in a simulated failure."""
        return min(self._attempts, self._fail_times)

    async def __call__(self, message: Any, _ctx: Any) -> Any:
        self._attempts += 1
        if self._attempts <= self._fail_times:
            text = (
                f"[{self._code}] simulated failure in {self._node_name}"
                f" (attempt {self._attempts})"
            )
            raise self._exception_factory(text)

        factory = self._result_factory
        if factory is None:
            # Default success behaviour: echo the incoming message.
            return message
        outcome = factory(message)
        if inspect.isawaitable(outcome):
            outcome = await outcome
        return outcome
219
+
220
+
221
def simulate_error(
    node_name: str,
    code: FlowErrorCode | str,
    *,
    fail_times: int = 1,
    result: Any | None = None,
    result_factory: Callable[[Any], Awaitable[Any] | Any] | None = None,
    exception_type: type[Exception] = RuntimeError,
) -> Callable[[Any, Any], Awaitable[Any]]:
    """Return an async callable that fails ``fail_times`` times, then succeeds.

    The returned coroutine is suitable for wrapping in
    :class:`~penguiflow.node.Node` and is especially useful for retry-centric
    tests. By default the callable echoes the incoming ``message`` once the
    simulated failures are exhausted; ``result``/``result_factory`` override
    the successful return value (at most one of the two may be given).
    """
    if fail_times < 1:
        raise ValueError("fail_times must be >= 1")
    if result is not None and result_factory is not None:
        raise ValueError("Specify only one of result or result_factory")

    # Accept either the enum or a plain string error code.
    if isinstance(code, FlowErrorCode):
        resolved_code = code.value
    else:
        resolved_code = str(code)

    def _make_exception(text: str) -> Exception:
        return exception_type(text)

    success_factory = result_factory
    if success_factory is None and result is not None:
        async def _fixed_result(_: Any) -> Any:
            return result

        success_factory = _fixed_result

    simulation = _ErrorSimulation(
        node_name=node_name,
        code=resolved_code,
        fail_times=fail_times,
        exception_factory=_make_exception,
        result_factory=success_factory,
    )

    async def _runner(message: Any, ctx: Any) -> Any:
        return await simulation(message, ctx)

    # Expose the simulation for introspection in tests without leaking the
    # internal class into the public API.
    _runner.simulation = simulation  # type: ignore[attr-defined]
    return _runner
269
+
penguiflow/types.py CHANGED
@@ -21,6 +21,17 @@ class Message(BaseModel):
21
21
  trace_id: str = Field(default_factory=lambda: uuid.uuid4().hex)
22
22
  ts: float = Field(default_factory=time.time)
23
23
  deadline_s: float | None = None
24
+ meta: dict[str, Any] = Field(default_factory=dict)
25
+
26
+
27
class StreamChunk(BaseModel):
    """Represents a chunk of streamed output."""

    # Identifier of the logical stream this chunk belongs to.
    stream_id: str
    # Sequence number of the chunk within its stream.
    seq: int
    # Textual payload carried by this chunk.
    text: str
    # Marks the final chunk of the stream — presumably; confirm with producers.
    done: bool = False
    # Free-form metadata attached to the chunk.
    meta: dict[str, Any] = Field(default_factory=dict)
24
35
 
25
36
 
26
37
  class PlanStep(BaseModel):
@@ -39,7 +50,9 @@ class WM(BaseModel):
39
50
  query: str
40
51
  facts: list[Any] = Field(default_factory=list)
41
52
  hops: int = 0
42
- budget_hops: int = 8
53
+ budget_hops: int | None = 8
54
+ tokens_used: int = 0
55
+ budget_tokens: int | None = None
43
56
  confidence: float = 0.0
44
57
 
45
58
 
@@ -51,6 +64,7 @@ class FinalAnswer(BaseModel):
51
64
  __all__ = [
52
65
  "Headers",
53
66
  "Message",
67
+ "StreamChunk",
54
68
  "PlanStep",
55
69
  "Thought",
56
70
  "WM",
penguiflow/viz.py CHANGED
@@ -3,6 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import re
6
+ from dataclasses import dataclass
6
7
  from typing import TYPE_CHECKING
7
8
 
8
9
  from .core import Endpoint
@@ -11,7 +12,21 @@ from .node import Node
11
12
  if TYPE_CHECKING: # pragma: no cover - type checking only
12
13
  from .core import PenguiFlow
13
14
 
14
- __all__ = ["flow_to_mermaid"]
15
+ __all__ = ["flow_to_mermaid", "flow_to_dot"]
16
+
17
+
18
@dataclass
class _VisualNode:
    """A renderer-agnostic graph node shared by the Mermaid and DOT emitters."""

    # Sanitized unique identifier used in the rendered graph source.
    identifier: str
    # Human-readable label displayed for the node.
    label: str
    # Style class names (e.g. "endpoint", "controller_loop") for the node.
    classes: list[str]
23
+
24
+
25
@dataclass
class _VisualEdge:
    """A renderer-agnostic directed edge between two visual nodes."""

    # Identifier of the source node.
    source: str
    # Identifier of the target node.
    target: str
    # Optional label ("loop", "ingress", "egress") or None for plain edges.
    label: str | None
15
30
 
16
31
 
17
32
def flow_to_mermaid(flow: PenguiFlow, *, direction: str = "TD") -> str:
    """Render the flow graph as a Mermaid diagram string.

    Parameters
    ----------
    flow:
        The :class:`PenguiFlow` instance to visualize.
    direction:
        Mermaid graph direction (``"TD"``, ``"LR"``, etc.). Defaults to top-down.
    """
    nodes, edges = _collect_graph(flow)

    out: list[str] = [f"graph {direction}"]
    # Styling presets for the node classes produced by _collect_graph.
    class_defs = {
        "endpoint": "fill:#e0f2fe,stroke:#0369a1,stroke-width:1px",
        "controller_loop": "fill:#fef3c7,stroke:#b45309,stroke-width:1px",
    }

    # Declare each node and remember which style classes are in use.
    seen_classes: set[str] = set()
    for node in nodes:
        escaped = _escape_label(node.label)
        out.append(f' {node.identifier}["{escaped}"]')
        seen_classes.update(node.classes)

    # Emit classDef lines only for classes actually referenced by some node.
    for name in sorted(seen_classes):
        style = class_defs.get(name)
        if style:
            out.append(f" classDef {name} {style}")

    # Attach the classes to their nodes.
    for node in nodes:
        if node.classes:
            joined = " ".join(node.classes)
            out.append(f" class {node.identifier} {joined}")

    # Finally, emit the edges (optionally labelled).
    for edge in edges:
        suffix = f"|{edge.label}|" if edge.label else ""
        out.append(f" {edge.source} -->{suffix} {edge.target}")

    return "\n".join(out)
74
+
75
+
76
def flow_to_dot(flow: PenguiFlow, *, rankdir: str = "TB") -> str:
    """Render the flow graph as a Graphviz DOT string.

    Parameters
    ----------
    flow:
        The :class:`PenguiFlow` instance to visualize.
    rankdir:
        Graph orientation (``"TB"``, ``"LR"``, etc.). Defaults to top-bottom.
    """

    nodes, edges = _collect_graph(flow)

    lines: list[str] = ["digraph PenguiFlow {", f" rankdir={rankdir}"]
    lines.append(" node [shape=box, style=rounded]")

    for node in nodes:
        # Escape embedded quotes in node labels so a label such as say "hi"
        # cannot break the generated DOT (edge labels below are already
        # escaped; node labels previously were interpolated verbatim).
        node_label = _escape_label(node.label)
        attributes: list[str] = [f'label="{node_label}"']
        if "endpoint" in node.classes:
            attributes.append('shape=oval')
            attributes.append('style="filled"')
            attributes.append('fillcolor="#e0f2fe"')
        elif "controller_loop" in node.classes:
            attributes.append('style="rounded,filled"')
            attributes.append('fillcolor="#fef3c7"')
        attr_str = ", ".join(attributes)
        lines.append(f" {node.identifier} [{attr_str}]")

    for edge in edges:
        if edge.label:
            edge_label = _escape_label(edge.label)
            lines.append(
                f" {edge.source} -> {edge.target} [label=\"{edge_label}\"]"
            )
        else:
            lines.append(f" {edge.source} -> {edge.target}")

    lines.append("}")
    return "\n".join(lines)
30
115
 
31
- for floe in flow._floes: # noqa: SLF001 - visualization accesses internals by design
32
- if floe.source is not None:
33
- nodes.add(floe.source)
34
- if floe.target is not None:
35
- nodes.add(floe.target)
36
116
 
37
- id_lookup: dict[object, str] = {}
117
def _collect_graph(flow: PenguiFlow) -> tuple[list[_VisualNode], list[_VisualEdge]]:
    """Walk ``flow``'s floes and build renderer-agnostic node and edge lists."""
    registry: dict[object, _VisualNode] = {}
    edges: list[_VisualEdge] = []
    used_ids: set[str] = set()
    self_loop_entities: set[object] = set()

    def ensure_node(entity: object) -> _VisualNode:
        # Return the visual node for *entity*, creating it on first sight.
        existing = registry.get(entity)
        if existing is not None:
            return existing
        label = _display_label(entity)
        identifier = _unique_id(label, used_ids)
        used_ids.add(identifier)
        classes: list[str] = []
        if isinstance(entity, Endpoint):
            classes.append("endpoint")
        if isinstance(entity, Node) and entity.allow_cycle:
            classes.append("controller_loop")
        created = _VisualNode(identifier=identifier, label=label, classes=classes)
        registry[entity] = created
        return created

    for floe in flow._floes:  # noqa: SLF001 - visualization inspects internals
        source = floe.source
        target = floe.target
        if source is None or target is None:
            continue
        src = ensure_node(source)
        tgt = ensure_node(target)
        if source is target:
            self_loop_entities.add(source)
            edge_label: str | None = "loop"
        elif isinstance(source, Endpoint):
            edge_label = "ingress"
        elif isinstance(target, Endpoint):
            edge_label = "egress"
        else:
            edge_label = None
        edges.append(_VisualEdge(src.identifier, tgt.identifier, edge_label))

    # Entities that feed themselves render as controller loops even when the
    # Node object did not advertise allow_cycle.
    for entity, visual in registry.items():
        if entity in self_loop_entities and "controller_loop" not in visual.classes:
            visual.classes.append("controller_loop")

    return list(registry.values()), edges
59
163
 
60
164
 
61
165
  def _display_label(entity: object) -> str:
@@ -74,3 +178,8 @@ def _unique_id(label: str, used: set[str]) -> str:
74
178
  index += 1
75
179
  candidate = f"{base}_{index}"
76
180
  return candidate
181
+
182
+
183
def _escape_label(label: str) -> str:
    """Escape embedded double quotes so ``label`` is safe inside a quoted string."""
    return label.replace('"', '\\"')
185
+