penguiflow 1.0.2__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of penguiflow might be problematic. Click here for more details.

penguiflow/testkit.py ADDED
@@ -0,0 +1,269 @@
1
+ """Utilities for writing concise PenguiFlow tests.
2
+
3
+ The helpers in this module provide a minimal harness around ``PenguiFlow`` so
4
+ unit tests can focus on the behaviour of their nodes instead of the runtime
5
+ plumbing. Each helper intentionally works with the public runtime surface to
6
+ avoid relying on private attributes, keeping the harness forward compatible
7
+ with the v1 API.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import asyncio
13
+ import inspect
14
+ from collections import OrderedDict
15
+ from collections.abc import Awaitable, Callable, Iterable, Sequence
16
+ from dataclasses import dataclass, field
17
+ from itertools import groupby
18
+ from typing import Any
19
+ from weakref import WeakKeyDictionary
20
+
21
+ from .core import PenguiFlow
22
+ from .errors import FlowErrorCode
23
+ from .metrics import FlowEvent
24
+ from .types import Message
25
+
26
__all__ = ["run_one", "assert_node_sequence", "simulate_error"]


# Bounded LRU history of FlowEvent buckets keyed by trace_id, so
# assert_node_sequence can inspect a run after the flow has been stopped.
_MAX_TRACE_HISTORY = 64
_TRACE_HISTORY: OrderedDict[str, list[FlowEvent]] = OrderedDict()
# Recorder state per flow instance; weakly keyed so a garbage-collected
# flow does not keep its recorded events alive.
_RECORDER_STATE: WeakKeyDictionary[PenguiFlow, _RecorderState] = (
    WeakKeyDictionary()
)
34
+
35
+
36
def _register_trace_history(trace_id: str, events: list[FlowEvent]) -> None:
    """Install *events* as the history bucket for *trace_id*.

    Empty trace ids are ignored.  Re-registering an id refreshes its LRU
    position; once the history exceeds ``_MAX_TRACE_HISTORY`` entries the
    least recently used traces are evicted.
    """
    if not trace_id:
        return
    history = _TRACE_HISTORY
    if trace_id in history:
        # Touch the entry so it becomes most-recently-used before overwrite.
        history.move_to_end(trace_id)
    history[trace_id] = events
    # Trim from the oldest end until we are back under the cap.
    while len(history) > _MAX_TRACE_HISTORY:
        history.popitem(last=False)
44
+
45
+
46
@dataclass(slots=True)
class _RunLog:
    """Events captured during one harness run.

    ``events`` holds every FlowEvent in arrival order, ``traces`` groups the
    same events per trace id, and ``active_traces`` remembers the ids that
    were registered up-front via ``_RecorderState.begin``.
    """

    events: list[FlowEvent] = field(default_factory=list)
    traces: dict[str, list[FlowEvent]] = field(default_factory=dict)
    active_traces: set[str] = field(default_factory=set)
51
+
52
+
53
class _RecorderState:
    """Per-flow recording state: an event log plus the middleware feeding it."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._log = _RunLog()
        self._middleware = _Recorder(self)

    @property
    def middleware(self) -> _Recorder:
        """The callable middleware that forwards FlowEvents into this state."""
        return self._middleware

    def begin(self, traces: Iterable[str] | None = None) -> None:
        """Start a fresh log, pre-registering empty buckets for *traces*."""
        tracked = set(traces or [])
        self._log = _RunLog(active_traces=tracked)
        for tid in tracked:
            events: list[FlowEvent] = []
            self._log.traces[tid] = events
            _register_trace_history(tid, events)

    async def record(self, event: FlowEvent) -> None:
        """Append *event* to the log and to its trace bucket, if it has one."""
        async with self._lock:
            self._log.events.append(event)
            tid = event.trace_id
            if tid is None:
                return
            events = self._log.traces.get(tid)
            if events is None:
                # First event for an untracked trace: create and publish the
                # bucket so later history lookups see all of its events.
                events = []
                self._log.traces[tid] = events
                _register_trace_history(tid, events)
            events.append(event)

    def node_sequence(self, trace_id: str) -> list[str]:
        """Return node names in ``node_start`` order for *trace_id*."""
        events = self._log.traces.get(trace_id)
        if events is None:
            events = _TRACE_HISTORY.get(trace_id, [])
        return [
            event.node_name or event.node_id or "<anonymous>"
            for event in events
            if event.event_type == "node_start"
        ]
95
+
96
+
97
class _Recorder:
    """Middleware adapter that forwards every FlowEvent to its owner state."""

    def __init__(self, state: _RecorderState) -> None:
        self._state = state

    async def __call__(self, event: FlowEvent) -> None:
        # Delegate; the state serialises appends behind its own lock.
        await self._state.record(event)
103
+
104
+
105
def _get_state(flow: PenguiFlow) -> _RecorderState:
    """Return (creating if needed) the recorder state attached to *flow*.

    The recorder middleware is appended to the flow's middleware list at most
    once; membership is checked by identity so no ``__eq__`` is ever invoked.

    Raises ``AttributeError`` if the flow exposes no middleware list.
    """
    state = _RECORDER_STATE.get(flow)
    if state is None:
        state = _RecorderState()
        _RECORDER_STATE[flow] = state
    hooks = getattr(flow, "_middlewares", None)
    if hooks is None:
        raise AttributeError("PenguiFlow instance is missing middleware hooks")
    recorder = state.middleware
    for existing in hooks:
        if existing is recorder:
            break
    else:
        hooks.append(recorder)
    return state
117
+
118
+
119
async def run_one(
    flow: PenguiFlow,
    message: Message,
    *,
    registry: Any | None = None,
    timeout_s: float | None = 1.0,
) -> Any:
    """Run ``message`` through ``flow`` and return the first Rookery payload.

    The flow is started and stopped on the caller's behalf.  The message's
    ``trace_id`` is registered with the recorder so that
    :func:`assert_node_sequence` can inspect the execution order afterwards.

    Raises ``TypeError`` for non-``Message`` inputs and propagates
    ``asyncio.TimeoutError`` when ``timeout_s`` elapses before a result.
    """

    if not isinstance(message, Message):
        raise TypeError("run_one expects a penguiflow.types.Message instance")

    # Reset the recorder, scoping this run to the message's trace.
    _get_state(flow).begin([message.trace_id])

    flow.run(registry=registry)
    try:
        await flow.emit(message)
        pending = flow.fetch()
        if timeout_s is None:
            return await pending
        return await asyncio.wait_for(pending, timeout_s)
    finally:
        # Always stop the flow, even on timeout or emit failure.
        await flow.stop()
151
+
152
+
153
def assert_node_sequence(trace_id: str, expected: Sequence[str]) -> None:
    """Assert that ``expected`` matches the recorded node start order.

    Consecutive repeats of the same node (retries re-entering a node) are
    collapsed to a single entry before comparison.
    # NOTE(review): _RecorderState.node_sequence does NOT collapse repeats —
    # confirm the asymmetry is intentional.
    """

    wanted = list(expected)
    recorded = _TRACE_HISTORY.get(trace_id, [])
    if not recorded:
        raise AssertionError(
            "No recorded events for trace_id="
            f"{trace_id!r}; run a flow with run_one first."
        )

    started = [
        event.node_name or event.node_id or "<anonymous>"
        for event in recorded
        if event.event_type == "node_start"
    ]
    # Collapse runs of identical names so a retried node counts once.
    started = [name for name, _ in groupby(started)]
    if started != wanted:
        raise AssertionError(
            "Node sequence mismatch:\n"
            f" expected: {wanted}\n"
            f" actual: {started}"
        )
176
+
177
+
178
class _ErrorSimulation:
    """Stateful async callable: raises for the first N calls, then succeeds."""

    def __init__(
        self,
        *,
        node_name: str,
        code: str,
        fail_times: int,
        exception_factory: Callable[[str], Exception],
        result_factory: Callable[[Any], Awaitable[Any] | Any] | None,
    ) -> None:
        self._node_name = node_name
        self._code = code
        self._fail_times = fail_times
        self._exception_factory = exception_factory
        self._result_factory = result_factory
        self._attempts = 0

    @property
    def attempts(self) -> int:
        # Total invocations so far, failing or not.
        return self._attempts

    @property
    def failures(self) -> int:
        # Invocations that raised, capped at the configured failure budget.
        return min(self._attempts, self._fail_times)

    async def __call__(self, message: Any, _ctx: Any) -> Any:
        self._attempts += 1
        if self._attempts <= self._fail_times:
            text = (
                f"[{self._code}] simulated failure in {self._node_name}"
                f" (attempt {self._attempts})"
            )
            raise self._exception_factory(text)

        factory = self._result_factory
        if factory is None:
            # Default success behaviour: echo the incoming message untouched.
            return message

        produced = factory(message)
        if inspect.isawaitable(produced):
            return await produced
        return produced
219
+
220
+
221
def simulate_error(
    node_name: str,
    code: FlowErrorCode | str,
    *,
    fail_times: int = 1,
    result: Any | None = None,
    result_factory: Callable[[Any], Awaitable[Any] | Any] | None = None,
    exception_type: type[Exception] = RuntimeError,
) -> Callable[[Any, Any], Awaitable[Any]]:
    """Return an async callable that fails ``fail_times`` before succeeding.

    The returned coroutine function can be wrapped in
    :class:`~penguiflow.node.Node` and is aimed at retry-centric tests.  Once
    the simulated failures are exhausted it echoes the incoming message, unless
    ``result`` or ``result_factory`` overrides the successful return value.

    Raises ``ValueError`` when ``fail_times`` is below 1 or when both
    ``result`` and ``result_factory`` are supplied.
    """

    if fail_times < 1:
        raise ValueError("fail_times must be >= 1")
    if result is not None and result_factory is not None:
        raise ValueError("Specify only one of result or result_factory")

    # Normalise the code to its string form for the failure message.
    resolved = code.value if isinstance(code, FlowErrorCode) else str(code)

    def _build_exception(text: str) -> Exception:
        return exception_type(text)

    factory = result_factory
    if factory is None and result is not None:
        async def _constant(_: Any) -> Any:
            return result

        factory = _constant

    simulation = _ErrorSimulation(
        node_name=node_name,
        code=resolved,
        fail_times=fail_times,
        exception_factory=_build_exception,
        result_factory=factory,
    )

    async def _runner(message: Any, ctx: Any) -> Any:
        return await simulation(message, ctx)

    # Expose the simulation object so tests can inspect attempts/failures
    # without the internal class leaking into the public API.
    _runner.simulation = simulation  # type: ignore[attr-defined]
    return _runner
269
+
penguiflow/types.py CHANGED
@@ -21,6 +21,17 @@ class Message(BaseModel):
21
21
  trace_id: str = Field(default_factory=lambda: uuid.uuid4().hex)
22
22
  ts: float = Field(default_factory=time.time)
23
23
  deadline_s: float | None = None
24
+ meta: dict[str, Any] = Field(default_factory=dict)
25
+
26
+
27
+ class StreamChunk(BaseModel):
28
+ """Represents a chunk of streamed output."""
29
+
30
+ stream_id: str
31
+ seq: int
32
+ text: str
33
+ done: bool = False
34
+ meta: dict[str, Any] = Field(default_factory=dict)
24
35
 
25
36
 
26
37
  class PlanStep(BaseModel):
@@ -39,7 +50,9 @@ class WM(BaseModel):
39
50
  query: str
40
51
  facts: list[Any] = Field(default_factory=list)
41
52
  hops: int = 0
42
- budget_hops: int = 8
53
+ budget_hops: int | None = 8
54
+ tokens_used: int = 0
55
+ budget_tokens: int | None = None
43
56
  confidence: float = 0.0
44
57
 
45
58
 
@@ -51,6 +64,7 @@ class FinalAnswer(BaseModel):
51
64
  __all__ = [
52
65
  "Headers",
53
66
  "Message",
67
+ "StreamChunk",
54
68
  "PlanStep",
55
69
  "Thought",
56
70
  "WM",
penguiflow/viz.py CHANGED
@@ -2,4 +2,184 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __all__: list[str] = []
5
+ import re
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING
8
+
9
+ from .core import Endpoint
10
+ from .node import Node
11
+
12
+ if TYPE_CHECKING: # pragma: no cover - type checking only
13
+ from .core import PenguiFlow
14
+
15
+ __all__ = ["flow_to_mermaid", "flow_to_dot"]
16
+
17
+
18
@dataclass
class _VisualNode:
    """A vertex of the rendered graph: stable id, display label, style classes."""

    identifier: str
    label: str
    classes: list[str]
+
24
+
25
@dataclass
class _VisualEdge:
    """A directed edge between node identifiers, optionally labelled
    (``"ingress"``, ``"egress"`` or ``"loop"``)."""

    source: str
    target: str
    label: str | None
+
31
+
32
def flow_to_mermaid(flow: PenguiFlow, *, direction: str = "TD") -> str:
    """Render the flow graph as a Mermaid diagram string.

    Parameters
    ----------
    flow:
        The :class:`PenguiFlow` instance to visualize.
    direction:
        Mermaid graph direction (``"TD"``, ``"LR"``, etc.). Defaults to top-down.
    """

    nodes, edges = _collect_graph(flow)

    styles = {
        "endpoint": "fill:#e0f2fe,stroke:#0369a1,stroke-width:1px",
        "controller_loop": "fill:#fef3c7,stroke:#b45309,stroke-width:1px",
    }

    out: list[str] = [f"graph {direction}"]

    # Declare nodes, collecting every style class actually in use.
    seen_classes: set[str] = set()
    for node in nodes:
        out.append(f" {node.identifier}[\"{_escape_label(node.label)}\"]")
        seen_classes.update(node.classes)

    # Emit classDef lines only for classes that are both used and styled.
    for name in sorted(seen_classes):
        definition = styles.get(name)
        if definition:
            out.append(f" classDef {name} {definition}")

    # Attach the classes to their nodes.
    for node in nodes:
        if node.classes:
            out.append(f" class {node.identifier} {' '.join(node.classes)}")

    # Edges, with optional |label| annotations.
    for edge in edges:
        suffix = f"|{edge.label}|" if edge.label else ""
        out.append(f" {edge.source} -->{suffix} {edge.target}")

    return "\n".join(out)
74
+
75
+
76
def flow_to_dot(flow: PenguiFlow, *, rankdir: str = "TB") -> str:
    """Render the flow graph as a Graphviz DOT string.

    Parameters
    ----------
    flow:
        The :class:`PenguiFlow` instance to visualize.
    rankdir:
        Graph orientation (``"TB"``, ``"LR"``, etc.). Defaults to top-bottom.
    """

    nodes, edges = _collect_graph(flow)

    lines: list[str] = ["digraph PenguiFlow {", f" rankdir={rankdir}"]
    lines.append(" node [shape=box, style=rounded]")

    for node in nodes:
        # BUGFIX: escape the node label; edge labels below are escaped, and an
        # unescaped double quote in a node name would produce invalid DOT.
        attributes: list[str] = [f'label="{_escape_label(node.label)}"']
        if "endpoint" in node.classes:
            attributes.append('shape=oval')
            attributes.append('style="filled"')
            attributes.append('fillcolor="#e0f2fe"')
        elif "controller_loop" in node.classes:
            attributes.append('style="rounded,filled"')
            attributes.append('fillcolor="#fef3c7"')
        attr_str = ", ".join(attributes)
        lines.append(f" {node.identifier} [{attr_str}]")

    for edge in edges:
        if edge.label:
            edge_label = _escape_label(edge.label)
            lines.append(
                f" {edge.source} -> {edge.target} [label=\"{edge_label}\"]"
            )
        else:
            lines.append(f" {edge.source} -> {edge.target}")

    lines.append("}")
    return "\n".join(lines)
115
+
116
+
117
def _collect_graph(flow: PenguiFlow) -> tuple[list[_VisualNode], list[_VisualEdge]]:
    """Walk the flow's floes and build the visual node and edge lists."""

    registry: dict[object, _VisualNode] = {}
    edges: list[_VisualEdge] = []
    taken_ids: set[str] = set()
    self_loop_entities: set[object] = set()

    def ensure_node(entity: object) -> _VisualNode:
        # Return the existing visual node for *entity*, or create one with a
        # unique identifier and its style classes.
        existing = registry.get(entity)
        if existing is not None:
            return existing
        label = _display_label(entity)
        identifier = _unique_id(label, taken_ids)
        taken_ids.add(identifier)
        classes: list[str] = []
        if isinstance(entity, Endpoint):
            classes.append("endpoint")
        if isinstance(entity, Node) and entity.allow_cycle:
            classes.append("controller_loop")
        created = _VisualNode(identifier=identifier, label=label, classes=classes)
        registry[entity] = created
        return created

    for floe in flow._floes:  # noqa: SLF001 - visualization inspects internals
        source, target = floe.source, floe.target
        if source is None or target is None:
            continue
        src = ensure_node(source)
        tgt = ensure_node(target)
        if source is target:
            self_loop_entities.add(source)
            label = "loop"
        elif isinstance(source, Endpoint):
            label = "ingress"
        elif isinstance(target, Endpoint):
            label = "egress"
        else:
            label = None
        edges.append(_VisualEdge(src.identifier, tgt.identifier, label))

    # Any node with a self-edge is rendered as a controller loop.
    for entity in self_loop_entities:
        node = registry[entity]
        if "controller_loop" not in node.classes:
            node.classes.append("controller_loop")

    return list(registry.values()), edges
163
+
164
+
165
def _display_label(entity: object) -> str:
    """Return a human-readable label: node name (falling back to its id),
    endpoint name, or ``str(entity)`` for anything else."""
    if isinstance(entity, Node):
        return entity.name or entity.node_id
    if isinstance(entity, Endpoint):
        return entity.name
    return str(entity)
171
+
172
+
173
def _unique_id(label: str, used: set[str]) -> str:
    """Derive a graph-safe identifier from *label*, unique within *used*.

    Non-alphanumeric characters become underscores and an empty result falls
    back to ``"node"``.  A leading digit is prefixed with ``n_`` because bare
    Graphviz DOT identifiers (and Mermaid node ids) may not start with a
    digit.  Collisions are resolved with ``_2``, ``_3``, ... suffixes.
    """
    base = re.sub(r"[^0-9A-Za-z_]", "_", label) or "node"
    if base[0].isdigit():
        # BUGFIX: "2nd_stage" would otherwise emit an invalid bare DOT ID.
        base = f"n_{base}"
    candidate = base
    index = 1
    while candidate in used:
        index += 1
        candidate = f"{base}_{index}"
    return candidate
+ return candidate
181
+
182
+
183
def _escape_label(label: str) -> str:
    """Escape *label* for embedding in a double-quoted DOT/Mermaid string.

    Backslashes are doubled before quotes are escaped: in DOT double-quoted
    strings a lone backslash starts an escape sequence, and a trailing
    backslash would otherwise neutralise the quote escape.
    # NOTE(review): Mermaid renders backslashes mostly literally — confirm
    # doubling does not visibly change existing Mermaid labels.
    """
    return label.replace("\\", "\\\\").replace("\"", "\\\"")
185
+