penguiflow 2.0.0__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of penguiflow might be problematic. Click here for more details.

penguiflow/remote.py ADDED
@@ -0,0 +1,486 @@
1
+ """Remote transport protocol and helper node for PenguiFlow."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import json
7
+ import logging
8
+ import time
9
+ from collections.abc import AsyncIterator, Mapping
10
+ from dataclasses import dataclass
11
+ from typing import TYPE_CHECKING, Any, Protocol
12
+
13
+ from pydantic import BaseModel
14
+
15
+ from .core import TraceCancelled
16
+ from .node import Node, NodePolicy
17
+ from .state import RemoteBinding
18
+ from .types import Message
19
+
20
+ if TYPE_CHECKING: # pragma: no cover - import for typing only
21
+ from .core import Context, PenguiFlow
22
+
23
+
24
+ @dataclass(slots=True)
25
+ class RemoteCallRequest:
26
+ """Input to :class:`RemoteTransport` implementations."""
27
+
28
+ message: Message
29
+ skill: str
30
+ agent_url: str
31
+ agent_card: Mapping[str, Any] | None = None
32
+ metadata: Mapping[str, Any] | None = None
33
+ timeout_s: float | None = None
34
+
35
+
36
+ @dataclass(slots=True)
37
+ class RemoteCallResult:
38
+ """Return value for :meth:`RemoteTransport.send`."""
39
+
40
+ result: Any
41
+ context_id: str | None = None
42
+ task_id: str | None = None
43
+ agent_url: str | None = None
44
+ meta: Mapping[str, Any] | None = None
45
+
46
+
47
+ @dataclass(slots=True)
48
+ class RemoteStreamEvent:
49
+ """Streaming event yielded by :meth:`RemoteTransport.stream`."""
50
+
51
+ text: str | None = None
52
+ done: bool = False
53
+ meta: Mapping[str, Any] | None = None
54
+ context_id: str | None = None
55
+ task_id: str | None = None
56
+ agent_url: str | None = None
57
+ result: Any | None = None
58
+
59
+
60
class RemoteTransport(Protocol):
    """Structural interface a transport must satisfy to back a :func:`RemoteNode`.

    Any object providing these three methods conforms; no base class is needed.
    """

    async def send(self, request: RemoteCallRequest) -> RemoteCallResult:
        """Execute ``request`` remotely and return its single result."""
        ...

    def stream(self, request: RemoteCallRequest) -> AsyncIterator[RemoteStreamEvent]:
        """Execute ``request`` remotely, yielding events as they arrive.

        RemoteNode consumes the return value with ``async for``.
        """
        ...

    async def cancel(self, *, agent_url: str, task_id: str) -> None:
        """Cancel the remote task ``task_id`` hosted at ``agent_url``."""
        ...
71
+
72
+
73
def _json_default(value: Any) -> Any:
    """Fallback encoder passed as ``default=`` to :func:`json.dumps`.

    Pydantic models are dumped in JSON mode. ``bytes`` are decoded as UTF-8,
    with undecodable sequences replaced. Any other value is represented by
    its ``repr``.
    """
    if isinstance(value, BaseModel):
        return value.model_dump(mode="json")
    if not isinstance(value, bytes):
        return repr(value)
    return value.decode("utf-8", errors="replace")
79
+
80
+
81
+ def _estimate_bytes(value: Any) -> int | None:
82
+ """Best-effort size estimation for observability metrics."""
83
+
84
+ if value is None:
85
+ return None
86
+ try:
87
+ if isinstance(value, BaseModel):
88
+ payload = value.model_dump(mode="json")
89
+ else:
90
+ payload = value
91
+ encoded = json.dumps(payload, default=_json_default).encode("utf-8")
92
+ except Exception:
93
+ try:
94
+ encoded = str(value).encode("utf-8")
95
+ except Exception:
96
+ return None
97
+ return len(encoded)
98
+
99
+
100
+ def _text_bytes(text: str | None) -> int:
101
+ if text is None:
102
+ return 0
103
+ return len(text.encode("utf-8"))
104
+
105
+
106
+ def _merge_remote_extra(
107
+ base: Mapping[str, Any],
108
+ *,
109
+ agent_url: str | None,
110
+ context_id: str | None,
111
+ task_id: str | None,
112
+ additional: Mapping[str, Any] | None = None,
113
+ ) -> dict[str, Any]:
114
+ extra = dict(base)
115
+ if agent_url is not None:
116
+ extra["remote_agent_url"] = agent_url
117
+ if context_id is not None:
118
+ extra["remote_context_id"] = context_id
119
+ if task_id is not None:
120
+ extra["remote_task_id"] = task_id
121
+ if additional:
122
+ for key, value in additional.items():
123
+ if value is not None:
124
+ extra[key] = value
125
+ return extra
126
+
127
+
128
def RemoteNode(
    *,
    transport: RemoteTransport,
    skill: str,
    agent_url: str,
    name: str,
    agent_card: Mapping[str, Any] | None = None,
    policy: NodePolicy | None = None,
    streaming: bool = False,
    record_binding: bool = True,
) -> Node:
    """Create a node that proxies work to a remote agent via ``transport``.

    The returned :class:`Node` forwards each incoming :class:`Message` to
    ``skill`` at ``agent_url``.

    * ``streaming=False``: calls ``transport.send`` and returns the call's
      ``result``.
    * ``streaming=True``: iterates ``transport.stream``, emits every text
      event as a chunk via ``ctx.emit_chunk``, and returns the last
      non-``None`` ``result`` seen in the stream.

    The remote context id and task id may arrive in different stream events.
    Once both have been reported, the node:

    * optionally persists a :class:`RemoteBinding`;
    * starts a background task that forwards trace cancellation to
      ``transport.cancel``.

    Start, per-stream-event, success, error and cancellation events are
    reported through ``runtime.record_remote_event``.

    Args:
        transport: Remote transport implementation.
        skill: Remote skill name placed on every request.
        agent_url: Default remote agent URL.
        name: Node name.
        agent_card: Optional agent card forwarded on every request.
        policy: Node policy; ``policy.timeout_s`` is forwarded as the request
            timeout. Defaults to ``NodePolicy()``.
        streaming: Use ``transport.stream`` instead of ``transport.send``.
        record_binding: Persist the binding via
            ``runtime.save_remote_binding``. Cancellation mirroring is set up
            regardless of this flag.

    Returns:
        A :class:`Node` wrapping the remote call.
    """

    node_policy = policy or NodePolicy()

    async def _record_binding(
        *,
        runtime: PenguiFlow,
        context: Context,
        node_owner: Node,
        trace_id: str,
        context_id: str | None,
        task_id: str | None,
        agent_url_override: str | None,
        base_extra: Mapping[str, Any],
    ) -> tuple[asyncio.Task[None], asyncio.Event] | None:
        """Persist the remote binding and start mirroring trace cancellation.

        Returns:
            The mirroring task and the trace's cancel event. Returns ``None``
            when either remote identifier is unknown.

        Raises:
            TraceCancelled: The trace was already cancelled. The remote task
                is cancelled before raising.
        """
        if context_id is None or task_id is None:
            return None

        agent_ref = agent_url_override or agent_url

        if record_binding:
            binding = RemoteBinding(
                trace_id=trace_id,
                context_id=context_id,
                task_id=task_id,
                agent_url=agent_ref,
            )
            await runtime.save_remote_binding(binding)

        cancel_event = runtime.ensure_trace_event(trace_id)

        async def _issue_cancel(reason: str) -> None:
            """Ask the transport to cancel the remote task and record the outcome."""
            start_cancel = time.perf_counter()
            try:
                await transport.cancel(agent_url=agent_ref, task_id=task_id)
            except Exception as exc:  # pragma: no cover - defensive logging
                latency = (time.perf_counter() - start_cancel) * 1000
                extra = _merge_remote_extra(
                    base_extra,
                    agent_url=agent_ref,
                    context_id=context_id,
                    task_id=task_id,
                    additional={
                        "remote_cancel_reason": reason,
                        "remote_error": repr(exc),
                        "remote_status": "cancel_error",
                    },
                )
                await runtime.record_remote_event(
                    event="remote_cancel_error",
                    node=node_owner,
                    context=context,
                    trace_id=trace_id,
                    latency_ms=latency,
                    level=logging.ERROR,
                    extra=extra,
                )
                return

            latency = (time.perf_counter() - start_cancel) * 1000
            extra = _merge_remote_extra(
                base_extra,
                agent_url=agent_ref,
                context_id=context_id,
                task_id=task_id,
                additional={
                    "remote_cancel_reason": reason,
                    "remote_status": "cancelled",
                },
            )
            await runtime.record_remote_event(
                event="remote_call_cancelled",
                node=node_owner,
                context=context,
                trace_id=trace_id,
                latency_ms=latency,
                level=logging.INFO,
                extra=extra,
            )

        if cancel_event.is_set():
            # The trace is already cancelled: cancel the remote task right away
            # and abort this node.
            await _issue_cancel("pre_cancelled")
            raise TraceCancelled(trace_id)

        async def _mirror_cancel() -> None:
            """Wait for trace cancellation, then forward it to the remote task."""
            try:
                await cancel_event.wait()
            except asyncio.CancelledError:
                return
            await _issue_cancel("trace_cancel")

        cancel_task = asyncio.create_task(_mirror_cancel())
        runtime.register_external_task(trace_id, cancel_task)
        return cancel_task, cancel_event

    async def _remote_impl(message: Message, ctx: Context) -> Any:
        """Forward ``message`` to the remote agent and return its result.

        Raises:
            TypeError: ``message`` is not a :class:`Message`.
            RuntimeError: ``ctx`` is not bound to a running flow.
            TraceCancelled: The trace was cancelled before the remote binding
                was registered.
        """
        if not isinstance(message, Message):
            raise TypeError("Remote nodes require penguiflow.types.Message inputs")

        runtime = ctx.runtime
        if runtime is None:
            raise RuntimeError("Context is not bound to a running PenguiFlow")

        owner = ctx.owner
        if not isinstance(owner, Node):  # pragma: no cover - defensive safety
            raise RuntimeError("Remote context owner must be a Node")

        trace_id = message.trace_id
        cancel_task: asyncio.Task[None] | None = None
        cancel_event: asyncio.Event | None = None
        binding_registered = False

        # Latest remote identifiers reported by the transport. They are
        # accumulated across stream events, which may report them separately.
        remote_context_id: str | None = None
        remote_task_id: str | None = None
        remote_agent_url_final = agent_url
        response_bytes = 0
        stream_events = 0

        base_extra: dict[str, Any] = {
            "remote_skill": skill,
            "remote_transport": type(transport).__name__,
            "remote_streaming": streaming,
        }
        request_bytes = _estimate_bytes(message)
        if request_bytes is not None:
            base_extra["remote_request_bytes"] = request_bytes

        request = RemoteCallRequest(
            message=message,
            skill=skill,
            agent_url=agent_url,
            agent_card=agent_card,
            metadata=message.meta,
            timeout_s=node_policy.timeout_s,
        )

        async def _ensure_binding(
            *,
            context_id: str | None,
            task_id: str | None,
            agent_url_override: str | None,
        ) -> None:
            """Fold newly reported remote ids into state; bind once both are known."""
            nonlocal cancel_task, cancel_event, binding_registered
            nonlocal remote_context_id, remote_task_id, remote_agent_url_final
            if context_id is not None:
                remote_context_id = context_id
            if task_id is not None:
                remote_task_id = task_id
            if agent_url_override is not None:
                remote_agent_url_final = agent_url_override
            # Decide from the accumulated ids rather than only this event's.
            # A stream may report the context id and the task id in different
            # events; checking only the current event would then never bind
            # and never mirror cancellation.
            if binding_registered or remote_context_id is None or remote_task_id is None:
                return
            record = await _record_binding(
                runtime=runtime,
                context=ctx,
                node_owner=owner,
                trace_id=trace_id,
                context_id=remote_context_id,
                task_id=remote_task_id,
                agent_url_override=remote_agent_url_final,
                base_extra=base_extra,
            )
            if record is None:
                return
            cancel_task, cancel_event = record
            binding_registered = True

        async def _cleanup_cancel_task() -> None:
            """Settle the cancellation-mirroring task before the node finishes.

            If the trace was cancelled, wait for the mirrored remote cancel to
            complete. Otherwise cancel the still-waiting watcher. Errors raised
            while settling the task are ignored.
            """
            if cancel_task is None:
                return
            try:
                if cancel_event is not None and cancel_event.is_set():
                    await cancel_task
                    return
                if not cancel_task.done():
                    cancel_task.cancel()
                await cancel_task
            except BaseException:  # pragma: no cover - cleanup guard
                pass

        async def _run_stream() -> Any | None:
            """Consume ``transport.stream``; return the last non-``None`` result."""
            nonlocal response_bytes, stream_events
            final_result: Any | None = None
            stream_idx = 0
            async for event in transport.stream(request):
                stream_events = stream_idx + 1
                # Also updates remote_agent_url_final when the event reports one.
                await _ensure_binding(
                    context_id=event.context_id,
                    task_id=event.task_id,
                    agent_url_override=event.agent_url,
                )

                chunk_bytes = 0
                if event.text is not None:
                    meta = dict(event.meta) if event.meta is not None else None
                    chunk_bytes += _text_bytes(event.text)
                    meta_bytes = _estimate_bytes(event.meta)
                    if meta_bytes is not None:
                        chunk_bytes += meta_bytes
                    await ctx.emit_chunk(
                        parent=message,
                        text=event.text,
                        done=event.done,
                        meta=meta,
                    )

                meta_keys = sorted(event.meta.keys()) if event.meta else None
                extra = _merge_remote_extra(
                    base_extra,
                    agent_url=remote_agent_url_final,
                    context_id=remote_context_id,
                    task_id=remote_task_id,
                    additional={
                        "remote_stream_seq": stream_idx,
                        "remote_chunk_bytes": chunk_bytes if chunk_bytes else None,
                        "remote_chunk_done": event.done,
                        "remote_chunk_meta_keys": meta_keys,
                    },
                )
                await runtime.record_remote_event(
                    event="remote_stream_event",
                    node=owner,
                    context=ctx,
                    trace_id=trace_id,
                    latency_ms=(time.perf_counter() - call_start) * 1000,
                    level=logging.DEBUG,
                    extra=extra,
                )

                response_bytes += chunk_bytes

                if event.result is not None:
                    result_bytes = _estimate_bytes(event.result)
                    if result_bytes is not None:
                        response_bytes += result_bytes
                    final_result = event.result

                stream_idx += 1

            return final_result

        call_start = time.perf_counter()

        await runtime.record_remote_event(
            event="remote_call_start",
            node=owner,
            context=ctx,
            trace_id=trace_id,
            latency_ms=0.0,
            level=logging.DEBUG,
            extra=_merge_remote_extra(
                base_extra,
                agent_url=remote_agent_url_final,
                context_id=None,
                task_id=None,
            ),
        )

        try:
            if streaming:
                result_payload = await _run_stream()
            else:
                result = await transport.send(request)
                # Records the reported ids/agent URL and binds when both ids exist.
                await _ensure_binding(
                    context_id=result.context_id,
                    task_id=result.task_id,
                    agent_url_override=result.agent_url,
                )
                result_payload = result.result
                response_size = _estimate_bytes(result_payload)
                if response_size is not None:
                    response_bytes += response_size
        except (TraceCancelled, asyncio.CancelledError):
            # Cancellation is reported by the cancel path, not as a call error.
            raise
        except Exception as exc:
            latency = (time.perf_counter() - call_start) * 1000
            extra = _merge_remote_extra(
                base_extra,
                agent_url=remote_agent_url_final,
                context_id=remote_context_id,
                task_id=remote_task_id,
                additional={
                    "remote_error": repr(exc),
                    "remote_status": "error",
                },
            )
            await runtime.record_remote_event(
                event="remote_call_error",
                node=owner,
                context=ctx,
                trace_id=trace_id,
                latency_ms=latency,
                level=logging.ERROR,
                extra=extra,
            )
            raise
        else:
            latency = (time.perf_counter() - call_start) * 1000
            extra = _merge_remote_extra(
                base_extra,
                agent_url=remote_agent_url_final,
                context_id=remote_context_id,
                task_id=remote_task_id,
                additional={
                    "remote_response_bytes": response_bytes,
                    "remote_stream_events": stream_events,
                    "remote_status": "success",
                },
            )
            await runtime.record_remote_event(
                event="remote_call_success",
                node=owner,
                context=ctx,
                trace_id=trace_id,
                latency_ms=latency,
                level=logging.INFO,
                extra=extra,
            )
            return result_payload
        finally:
            await _cleanup_cancel_task()

    return Node(_remote_impl, name=name, policy=node_policy)
478
+
479
+
480
# Names exported by ``from penguiflow.remote import *``.
__all__ = [
    "RemoteCallRequest",
    "RemoteCallResult",
    "RemoteStreamEvent",
    "RemoteTransport",
    "RemoteNode",
]
penguiflow/state.py ADDED
@@ -0,0 +1,64 @@
1
+ """State store protocol and helpers for PenguiFlow."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Mapping, Sequence
6
+ from dataclasses import dataclass
7
+ from typing import Any, Protocol
8
+
9
+ from .metrics import FlowEvent
10
+
11
+
12
+ @dataclass(slots=True)
13
+ class StoredEvent:
14
+ """Representation of a runtime event persisted by a state store."""
15
+
16
+ trace_id: str | None
17
+ ts: float
18
+ kind: str
19
+ node_name: str | None
20
+ node_id: str | None
21
+ payload: Mapping[str, Any]
22
+
23
+ @classmethod
24
+ def from_flow_event(cls, event: FlowEvent) -> StoredEvent:
25
+ """Create a stored representation from a :class:`FlowEvent`."""
26
+
27
+ return cls(
28
+ trace_id=event.trace_id,
29
+ ts=event.ts,
30
+ kind=event.event_type,
31
+ node_name=event.node_name,
32
+ node_id=event.node_id,
33
+ payload=event.to_payload(),
34
+ )
35
+
36
+
37
+ @dataclass(slots=True)
38
+ class RemoteBinding:
39
+ """Association between a trace and a remote worker/agent."""
40
+
41
+ trace_id: str
42
+ context_id: str
43
+ task_id: str
44
+ agent_url: str
45
+
46
+
47
class StateStore(Protocol):
    """Structural interface for durable state adapters used by PenguiFlow."""

    async def save_event(self, event: StoredEvent) -> None:
        """Persist one runtime event.

        Any storage backend (Postgres, Redis, etc.) may be used. The method
        must be idempotent because retries can emit duplicate events.
        """
        ...

    async def load_history(self, trace_id: str) -> Sequence[StoredEvent]:
        """Return the ordered event history for ``trace_id``."""
        ...

    async def save_remote_binding(self, binding: RemoteBinding) -> None:
        """Persist the association between a trace and an external worker."""
        ...
62
+
63
+
64
# Names exported by ``from penguiflow.state import *``.
__all__ = [
    "StateStore",
    "StoredEvent",
    "RemoteBinding",
]
penguiflow/testkit.py CHANGED
@@ -21,9 +21,15 @@ from weakref import WeakKeyDictionary
21
21
  from .core import PenguiFlow
22
22
  from .errors import FlowErrorCode
23
23
  from .metrics import FlowEvent
24
- from .types import Message
24
+ from .types import Headers, Message
25
25
 
26
- __all__ = ["run_one", "assert_node_sequence", "simulate_error"]
26
+ __all__ = [
27
+ "run_one",
28
+ "assert_node_sequence",
29
+ "get_recorded_events",
30
+ "simulate_error",
31
+ "assert_preserves_message_envelope",
32
+ ]
27
33
 
28
34
 
29
35
  _MAX_TRACE_HISTORY = 64
@@ -102,6 +108,20 @@ class _Recorder:
102
108
  await self._state.record(event)
103
109
 
104
110
 
111
+ class _StubContext:
112
+ async def emit(self, *_args: Any, **_kwargs: Any) -> None:
113
+ return None
114
+
115
+ def emit_nowait(self, *_args: Any, **_kwargs: Any) -> None:
116
+ return None
117
+
118
+ async def emit_chunk(self, *_args: Any, **_kwargs: Any) -> Any:
119
+ raise RuntimeError(
120
+ "FlowTestKit stub context does not support emit_chunk; provide a custom"
121
+ " context via the 'ctx' parameter"
122
+ )
123
+
124
+
105
125
  def _get_state(flow: PenguiFlow) -> _RecorderState:
106
126
  state = _RECORDER_STATE.get(flow)
107
127
  if state is None:
@@ -150,6 +170,78 @@ async def run_one(
150
170
  return result
151
171
 
152
172
 
173
async def assert_preserves_message_envelope(
    node: Callable[[Message, Any], Awaitable[Any]] | Any,
    *,
    message: Message | None = None,
    ctx: Any | None = None,
) -> Message:
    """Execute ``node`` and assert it preserves the ``Message`` envelope.

    The input's headers and trace_id are snapshotted *before* the node runs.
    A node that mutates the incoming message in place and returns it is
    therefore detected, rather than being compared against its own mutated
    state.

    Parameters
    ----------
    node:
        Either a bare async callable or a :class:`penguiflow.node.Node` whose
        first parameter is a :class:`~penguiflow.types.Message` instance.
    message:
        Optional sample message. When omitted, a minimal envelope is
        synthesised.
    ctx:
        Optional context object passed to the node. By default a stub context
        is used that simply no-ops ``emit``/``emit_nowait``.

    Returns
    -------
    Message
        The resulting message from the node, allowing additional assertions.

    Raises
    ------
    AssertionError
        If the node does not return a ``Message`` or mutates core envelope
        fields (headers or trace_id).
    TypeError
        If ``node`` is not awaitable.
    """

    import copy

    from .node import Node  # Local import to avoid circular dependency

    if isinstance(node, Node):
        func = node.func
        node_name = node.name or node.func.__name__
    else:
        func = node
        node_name = getattr(node, "__name__", "<anonymous>")

    if not inspect.iscoroutinefunction(func):
        raise TypeError("assert_preserves_message_envelope expects an async node")

    sample = (
        message
        if message is not None
        else Message(payload={}, headers=Headers(tenant="test"))
    )
    context = ctx if ctx is not None else _StubContext()

    # Snapshot the envelope up front. Comparing against ``sample`` after the
    # call would miss in-place mutations of the very object the node returns.
    expected_headers = copy.deepcopy(sample.headers)
    expected_trace_id = sample.trace_id

    result = await func(sample, context)
    if not isinstance(result, Message):
        produced = type(result).__name__
        raise AssertionError(
            "Node "
            f"'{node_name}' must return a Message but produced {produced}"
        )

    mismatches: list[str] = []
    if result.headers != expected_headers:
        mismatches.append("headers")
    if result.trace_id != expected_trace_id:
        mismatches.append("trace_id")

    if mismatches:
        joined = ", ".join(mismatches)
        raise AssertionError(
            f"Node '{node_name}' altered Message {joined}; preserve the envelope"
        )

    return result
243
+
244
+
153
245
  def assert_node_sequence(trace_id: str, expected: Sequence[str]) -> None:
154
246
  """Assert that ``expected`` matches the recorded node start order."""
155
247
 
@@ -175,6 +267,19 @@ def assert_node_sequence(trace_id: str, expected: Sequence[str]) -> None:
175
267
  )
176
268
 
177
269
 
270
def get_recorded_events(trace_id: str) -> tuple[FlowEvent, ...]:
    """Return an immutable snapshot of the :class:`FlowEvent` history for ``trace_id``.

    The FlowTestKit recorder keeps a bounded cache of per-trace histories.
    Tests should use this accessor, e.g. to inspect ``node_failed`` payloads
    or retry attempts, instead of reading the private cache. A trace with no
    recorded history yields an empty tuple.
    """
    return tuple(_TRACE_HISTORY.get(trace_id, ()))
281
+
282
+
178
283
  class _ErrorSimulation:
179
284
  def __init__(
180
285
  self,