penguiflow 1.0.3__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of penguiflow might be problematic. Click here for more details.
- penguiflow/__init__.py +45 -3
- penguiflow/admin.py +174 -0
- penguiflow/bus.py +30 -0
- penguiflow/core.py +941 -57
- penguiflow/errors.py +113 -0
- penguiflow/metrics.py +105 -0
- penguiflow/middlewares.py +6 -7
- penguiflow/patterns.py +47 -5
- penguiflow/policies.py +149 -0
- penguiflow/remote.py +486 -0
- penguiflow/state.py +64 -0
- penguiflow/streaming.py +142 -0
- penguiflow/testkit.py +269 -0
- penguiflow/types.py +15 -1
- penguiflow/viz.py +133 -24
- penguiflow-2.1.0.dist-info/METADATA +646 -0
- penguiflow-2.1.0.dist-info/RECORD +25 -0
- penguiflow-2.1.0.dist-info/entry_points.txt +2 -0
- penguiflow-2.1.0.dist-info/top_level.txt +2 -0
- penguiflow_a2a/__init__.py +19 -0
- penguiflow_a2a/server.py +695 -0
- penguiflow-1.0.3.dist-info/METADATA +0 -425
- penguiflow-1.0.3.dist-info/RECORD +0 -13
- penguiflow-1.0.3.dist-info/top_level.txt +0 -1
- {penguiflow-1.0.3.dist-info → penguiflow-2.1.0.dist-info}/WHEEL +0 -0
- {penguiflow-1.0.3.dist-info → penguiflow-2.1.0.dist-info}/licenses/LICENSE +0 -0
penguiflow_a2a/server.py
ADDED
|
@@ -0,0 +1,695 @@
|
|
|
1
|
+
"""Expose PenguiFlow runs through an A2A-compliant HTTP surface."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import contextvars
|
|
7
|
+
import json
|
|
8
|
+
import uuid
|
|
9
|
+
from collections.abc import AsyncIterator, Mapping, Sequence
|
|
10
|
+
from contextlib import asynccontextmanager, suppress
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from types import MethodType
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
|
16
|
+
|
|
17
|
+
from penguiflow.core import PenguiFlow, TraceCancelled
|
|
18
|
+
from penguiflow.errors import FlowError
|
|
19
|
+
from penguiflow.state import RemoteBinding
|
|
20
|
+
from penguiflow.streaming import format_sse_event
|
|
21
|
+
from penguiflow.types import Headers, Message, StreamChunk
|
|
22
|
+
|
|
23
|
+
# Sentinel object pushed onto every per-trace result queue during shutdown so
# blocked consumers can tell "adapter is stopping" apart from a real result.
_QUEUE_SHUTDOWN = object()

# Trace id of the message currently executing in a node.  Set around
# `_execute_with_reliability` (see `_patch_flow`) so payloads emitted without
# their own `trace_id` can be re-associated with the run that produced them.
_TRACE_CONTEXT: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "penguiflow_a2a_trace", default=None
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(slots=True)
class RookeryResult:
    """Pairs a rookery payload with the trace that produced it.

    Used by the patched ``_emit_to_rookery`` (see ``_patch_flow``) to tag
    payloads that lack their own ``trace_id`` attribute.
    """

    # Trace the wrapped value belongs to.
    trace_id: str
    # The raw payload emitted to the rookery.
    value: Any
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class A2ASkill(BaseModel):
    """Description of a single capability exposed by an agent."""

    # Human-readable identifier of the skill.
    name: str
    # Free-text description of what the skill does.
    description: str
    mode: str = Field(
        default="both",
        description="Whether the skill supports message/send, message/stream, or both.",
    )
    # JSON-schema-like descriptions of the skill's inputs/outputs; free-form.
    inputs: dict[str, Any] = Field(default_factory=dict)
    outputs: dict[str, Any] = Field(default_factory=dict)

    # Allow callers to attach extra vendor-specific fields without validation
    # errors; they round-trip through model_dump().
    model_config = ConfigDict(extra="allow")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class A2AAgentCard(BaseModel):
    """Lightweight Agent Card surfaced at ``GET /agent``."""

    name: str
    description: str
    # Version of the agent itself (also used as the FastAPI app version).
    version: str = "1.0.0"
    # Version of the card schema, not of the agent.
    schema_version: str = Field(default="1.0")
    tags: list[str] = Field(default_factory=list)
    capabilities: list[str] = Field(default_factory=list)
    skills: list[A2ASkill] = Field(default_factory=list)
    contact_url: str | None = None
    documentation_url: str | None = None

    # Extra fields are preserved so deployments can extend the card freely.
    model_config = ConfigDict(extra="allow")

    def to_payload(self) -> dict[str, Any]:
        """Return a serialisable dictionary representation."""

        return self.model_dump()
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class A2AMessagePayload(BaseModel):
    """Request payload accepted by ``message/send`` and ``message/stream``."""

    # Arbitrary payload forwarded as Message.payload; not validated here.
    payload: Any
    # Extra headers merged over the adapter's default headers.
    headers: Mapping[str, Any] = Field(default_factory=dict)
    # Free-form metadata merged into Message.meta.
    meta: dict[str, Any] = Field(default_factory=dict)
    # Optional client-chosen identifiers; camelCase aliases match the A2A
    # wire format while snake_case names remain usable in Python callers.
    trace_id: str | None = Field(default=None, alias="traceId")
    context_id: str | None = Field(default=None, alias="contextId")
    task_id: str | None = Field(default=None, alias="taskId")
    deadline_s: float | None = Field(default=None, alias="deadlineSeconds")

    # Accept both the alias and the field name on input.
    model_config = ConfigDict(populate_by_name=True)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class A2ATaskCancelRequest(BaseModel):
    """JSON body accepted by ``tasks/cancel``."""

    # Task to cancel; ``taskId`` on the wire, ``task_id`` in Python.
    task_id: str = Field(alias="taskId")

    model_config = ConfigDict(populate_by_name=True)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class A2ARequestError(Exception):
    """Adapter-level request failure.

    Raised by :class:`A2AServerAdapter` methods and translated into a
    FastAPI ``HTTPException`` with the same status code and detail string
    by the route handlers in :func:`create_a2a_app`.
    """

    def __init__(self, *, status_code: int, detail: str) -> None:
        # Store the HTTP mapping first, then let Exception keep the detail
        # as its message so str(exc) is meaningful outside the HTTP layer.
        self.status_code = status_code
        self.detail = detail
        super().__init__(detail)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class A2AServerAdapter:
    """Bridge between PenguiFlow and the A2A HTTP surface."""

    def __init__(
        self,
        flow: PenguiFlow,
        *,
        agent_card: A2AAgentCard | Mapping[str, Any],
        agent_url: str,
        target: Sequence[Any] | Any | None = None,
        registry: Any | None = None,
        default_headers: Mapping[str, Any] | None = None,
    ) -> None:
        # `target` is forwarded verbatim to `flow.emit(..., to=target)`;
        # `registry` is forwarded to `flow.run(registry=...)`.
        self._flow = flow
        self._registry = registry
        self._target = target
        # Copy so later caller mutation of the mapping cannot leak in.
        self._default_headers = dict(default_headers or {})
        # Accept either a prepared card or a raw mapping and normalise.
        self.agent_card = (
            agent_card
            if isinstance(agent_card, A2AAgentCard)
            else A2AAgentCard.model_validate(agent_card)
        )
        self.agent_url = agent_url
        self._flow_started = False
        # Active-task bookkeeping: task id -> trace id / context id.
        self._tasks: dict[str, str] = {}
        self._contexts: dict[str, str] = {}
        # Guards _tasks/_contexts.
        self._lock = asyncio.Lock()
        # Guards the per-trace queue bookkeeping below.
        self._queue_lock = asyncio.Lock()
        # trace id -> queue receiving that trace's rookery items.
        self._trace_queues: dict[str, asyncio.Queue[Any]] = {}
        # Items that arrived before a queue existed for their trace.
        self._pending_results: dict[str, list[Any]] = {}
        # trace id -> watcher task waiting on the trace's cancel event.
        self._cancel_watchers: dict[str, asyncio.Task[None]] = {}
        # Single background task draining flow.fetch() into trace queues.
        self._dispatcher_task: asyncio.Task[None] | None = None
        # id(message) -> trace id, for payloads emitted without a trace_id.
        # NOTE(review): entries are only removed in _dispatch_results, so
        # messages that never reach the rookery may leave stale entries.
        self._message_traces: dict[int, str] = {}
        # Monkey-patch flow internals so untagged results stay traceable.
        self._patch_flow()
|
|
136
|
+
|
|
137
|
+
    async def start(self) -> None:
        """Start the underlying flow if it is not running.

        Idempotent: a second call is a no-op.  Also (re)spawns the rookery
        dispatcher task on the current event loop.
        """

        if self._flow_started:
            return
        self._flow.run(registry=self._registry)
        self._flow_started = True
        self._ensure_dispatcher_task()
|
|
145
|
+
|
|
146
|
+
    async def stop(self) -> None:
        """Gracefully stop the underlying flow.

        Order matters: the dispatcher is cancelled before ``flow.stop()``
        so it cannot fetch from a stopping flow, then awaited afterwards;
        finally every live per-trace queue gets the shutdown sentinel so
        blocked ``handle_send``/stream consumers wake up with a 503.
        """

        if not self._flow_started:
            return
        dispatcher = self._dispatcher_task
        self._dispatcher_task = None
        if dispatcher is not None:
            dispatcher.cancel()
        await self._flow.stop()
        if dispatcher is not None:
            with suppress(asyncio.CancelledError):
                await dispatcher
        self._flow_started = False
        # Snapshot and clear the bookkeeping under the lock, then do the
        # awaiting work outside it to avoid holding the lock across awaits.
        async with self._queue_lock:
            queues = list(self._trace_queues.values())
            cancel_watchers = list(self._cancel_watchers.values())
            self._trace_queues.clear()
            self._pending_results.clear()
            self._cancel_watchers.clear()
        for watcher in cancel_watchers:
            watcher.cancel()
            with suppress(asyncio.CancelledError):
                await watcher
        for queue in queues:
            queue.put_nowait(_QUEUE_SHUTDOWN)
|
|
172
|
+
|
|
173
|
+
def _ensure_started(self) -> None:
|
|
174
|
+
if not self._flow_started:
|
|
175
|
+
raise A2ARequestError(status_code=503, detail="flow is not running")
|
|
176
|
+
|
|
177
|
+
    async def handle_send(self, request: A2AMessagePayload) -> dict[str, Any]:
        """Execute ``message/send`` and return the final artifact.

        Registers the task, persists the trace binding, emits the message
        into the flow, then blocks on this trace's result queue until a
        non-streaming item arrives.  Cancellation and ``FlowError`` are
        reported as structured JSON statuses rather than HTTP errors;
        anything else becomes a 500 ``A2ARequestError``.
        """

        self._ensure_started()
        message, task_id, context_id = self._prepare_message(request)
        await self._register_task(task_id, message.trace_id, context_id)
        await self._persist_binding(message.trace_id, context_id, task_id)
        result_queue = await self._acquire_trace_queue(message.trace_id)

        try:
            await self._flow.emit(message, to=self._target)
            while True:
                item = await result_queue.get()
                if item is _QUEUE_SHUTDOWN:
                    # stop() pushed the sentinel: the adapter is going away.
                    raise A2ARequestError(
                        status_code=503, detail="flow is shutting down"
                    )
                if isinstance(item, TraceCancelled):
                    raise item
                if isinstance(item, FlowError):
                    raise item
                if isinstance(item, Exception):  # pragma: no cover - defensive
                    raise item
                # Unwrap RookeryResult only to inspect the payload type; the
                # original item is kept so its meta survives to the response.
                if isinstance(item, RookeryResult):
                    payload_candidate = item.value
                else:
                    payload_candidate = getattr(item, "payload", item)
                if isinstance(payload_candidate, StreamChunk):
                    # message/send ignores intermediate chunks and waits for
                    # the final (non-chunk) artifact.
                    continue
                result = item
                break
            if isinstance(result, RookeryResult):
                payload = result.value
            else:
                payload = getattr(result, "payload", result)
            response: dict[str, Any] = {
                "status": "succeeded",
                "taskId": task_id,
                "contextId": context_id,
                "traceId": message.trace_id,
                "output": self._to_jsonable(payload),
            }
            # Only attach meta when the result carries a non-empty one.
            meta = getattr(result, "meta", None)
            if meta:
                response["meta"] = dict(meta)
            return response
        except TraceCancelled:
            return {
                "status": "cancelled",
                "taskId": task_id,
                "contextId": context_id,
                "traceId": message.trace_id,
            }
        except FlowError as exc:
            error_payload = exc.to_payload()
            # Make sure the error always names its trace.
            error_payload.setdefault("trace_id", message.trace_id)
            return {
                "status": "failed",
                "taskId": task_id,
                "contextId": context_id,
                "traceId": message.trace_id,
                "error": error_payload,
            }
        except A2ARequestError:
            raise
        except Exception as exc:  # pragma: no cover - defensive fallback
            raise A2ARequestError(
                status_code=500,
                detail=f"flow execution failed: {exc}",
            ) from exc
        finally:
            # Always free task bookkeeping and the trace queue, even on error.
            await self._release_task(task_id)
            await self._release_trace_queue(message.trace_id)
|
|
250
|
+
|
|
251
|
+
    async def stream(
        self, request: A2AMessagePayload
    ) -> tuple[AsyncIterator[bytes], str, str]:
        """Execute ``message/stream`` and return an SSE iterator.

        Performs the same registration/binding steps as ``handle_send``,
        but returns immediately with a generator (plus task/context ids for
        response headers).  The message is only emitted once the generator
        is first iterated — see ``_stream_generator``.
        """

        self._ensure_started()
        message, task_id, context_id = self._prepare_message(request)
        await self._register_task(task_id, message.trace_id, context_id)
        await self._persist_binding(message.trace_id, context_id, task_id)
        # Acquire the queue eagerly so early results are buffered for the
        # generator; the generator looks it up again via _get_trace_queue.
        await self._acquire_trace_queue(message.trace_id)
        generator = self._stream_generator(message, task_id, context_id)
        return generator, task_id, context_id
|
|
263
|
+
|
|
264
|
+
async def cancel(self, request: A2ATaskCancelRequest) -> dict[str, Any]:
|
|
265
|
+
"""Cancel an active task."""
|
|
266
|
+
|
|
267
|
+
self._ensure_started()
|
|
268
|
+
task_id = request.task_id
|
|
269
|
+
async with self._lock:
|
|
270
|
+
trace_id = self._tasks.get(task_id)
|
|
271
|
+
context_id = self._contexts.get(task_id)
|
|
272
|
+
if trace_id is None:
|
|
273
|
+
return {"taskId": task_id, "cancelled": False}
|
|
274
|
+
cancelled = await self._flow.cancel(trace_id)
|
|
275
|
+
response = {
|
|
276
|
+
"taskId": task_id,
|
|
277
|
+
"cancelled": cancelled,
|
|
278
|
+
"traceId": trace_id,
|
|
279
|
+
}
|
|
280
|
+
if context_id is not None:
|
|
281
|
+
response["contextId"] = context_id
|
|
282
|
+
return response
|
|
283
|
+
|
|
284
|
+
    def _prepare_message(
        self, request: A2AMessagePayload
    ) -> tuple[Message, str, str]:
        """Build the flow ``Message`` plus task/context identifiers.

        Request headers win over the adapter's defaults.  Invalid headers
        surface as a 422 ``A2ARequestError``.
        """
        headers_data = {**self._default_headers, **dict(request.headers)}
        try:
            headers = Headers(**headers_data)
        except ValidationError as exc:  # pragma: no cover - pydantic formats nicely
            raise A2ARequestError(status_code=422, detail=str(exc)) from exc

        # Only forward optional fields the client actually supplied so
        # Message keeps its own defaults (e.g. a generated trace_id).
        kwargs: dict[str, Any] = {}
        if request.trace_id is not None:
            kwargs["trace_id"] = request.trace_id
        if request.deadline_s is not None:
            kwargs["deadline_s"] = request.deadline_s
        message = Message(payload=request.payload, headers=headers, **kwargs)
        message.meta.update(request.meta)

        # Fall back to the trace id for both identifiers; the uuid branch
        # only fires if Message ever yields a falsy trace_id.
        context_id = request.context_id or message.trace_id
        task_id = request.task_id or message.trace_id or uuid.uuid4().hex
        return message, task_id, context_id
|
|
304
|
+
|
|
305
|
+
async def _register_task(
|
|
306
|
+
self, task_id: str, trace_id: str, context_id: str
|
|
307
|
+
) -> None:
|
|
308
|
+
async with self._lock:
|
|
309
|
+
if task_id in self._tasks:
|
|
310
|
+
raise A2ARequestError(
|
|
311
|
+
status_code=409, detail=f"task {task_id!r} already active"
|
|
312
|
+
)
|
|
313
|
+
self._tasks[task_id] = trace_id
|
|
314
|
+
self._contexts[task_id] = context_id
|
|
315
|
+
|
|
316
|
+
async def _release_task(self, task_id: str) -> None:
|
|
317
|
+
async with self._lock:
|
|
318
|
+
self._tasks.pop(task_id, None)
|
|
319
|
+
self._contexts.pop(task_id, None)
|
|
320
|
+
|
|
321
|
+
    async def _persist_binding(
        self, trace_id: str, context_id: str, task_id: str
    ) -> None:
        """Persist the trace/context/task association on the flow.

        NOTE(review): presumably this lets a restarted flow re-associate
        in-flight remote work with its public A2A identifiers — confirm
        against ``penguiflow.state.RemoteBinding``.
        """
        binding = RemoteBinding(
            trace_id=trace_id,
            context_id=context_id,
            task_id=task_id,
            agent_url=self.agent_url,
        )
        await self._flow.save_remote_binding(binding)
|
|
331
|
+
|
|
332
|
+
    async def _stream_generator(
        self, message: Message, task_id: str, context_id: str
    ) -> AsyncIterator[bytes]:
        """Drive one streaming run and yield SSE frames.

        Frame protocol: a ``status`` event once the message is accepted,
        zero or more chunk events, exactly one ``artifact`` event, then a
        ``done`` event.  Cancellation and errors replace the artifact with
        an ``error`` event but still end with ``done``.
        """
        result_queue = await self._get_trace_queue(message.trace_id)
        try:
            # The message is emitted lazily, on first iteration of this
            # generator (i.e. when the HTTP response starts streaming).
            await self._flow.emit(message, to=self._target)
            yield self._format_event(
                "status",
                {
                    "status": "accepted",
                    "taskId": task_id,
                    "contextId": context_id,
                },
            )
            while True:
                item = await result_queue.get()
                if item is _QUEUE_SHUTDOWN:
                    raise A2ARequestError(
                        status_code=503, detail="flow is shutting down"
                    )
                if isinstance(item, TraceCancelled):
                    raise item
                if isinstance(item, FlowError):
                    raise item
                if isinstance(item, Exception):  # pragma: no cover - defensive
                    raise item
                if isinstance(item, RookeryResult):
                    payload = item.value
                else:
                    payload = getattr(item, "payload", item)
                if isinstance(payload, StreamChunk):
                    # Intermediate chunk: forward it and keep waiting for
                    # the final artifact.
                    yield self._format_chunk_event(payload, task_id, context_id)
                    continue
                yield self._format_event(
                    "artifact",
                    {
                        "taskId": task_id,
                        "contextId": context_id,
                        "output": self._to_jsonable(payload),
                    },
                )
                break
            yield self._format_event(
                "done", {"taskId": task_id, "contextId": context_id}
            )
        except TraceCancelled:
            yield self._format_event(
                "error",
                {
                    "taskId": task_id,
                    "contextId": context_id,
                    "code": "TRACE_CANCELLED",
                    "message": "Trace cancelled",
                },
            )
            yield self._format_event(
                "done", {"taskId": task_id, "contextId": context_id}
            )
        except FlowError as exc:
            payload = exc.to_payload()
            payload.update({"taskId": task_id, "contextId": context_id})
            yield self._format_event("error", payload)
            yield self._format_event(
                "done", {"taskId": task_id, "contextId": context_id}
            )
        except Exception as exc:  # pragma: no cover - defensive fallback
            yield self._format_event(
                "error",
                {
                    "taskId": task_id,
                    "contextId": context_id,
                    "code": "INTERNAL_ERROR",
                    # Fall back to the class name for empty messages.
                    "message": str(exc) or exc.__class__.__name__,
                },
            )
            yield self._format_event(
                "done", {"taskId": task_id, "contextId": context_id}
            )
        finally:
            # Runs even if the client disconnects mid-stream (GeneratorExit).
            await self._release_task(task_id)
            await self._release_trace_queue(message.trace_id)
|
|
413
|
+
|
|
414
|
+
def _format_event(self, event: str, data: Mapping[str, Any]) -> bytes:
|
|
415
|
+
payload = json.dumps(data, ensure_ascii=False)
|
|
416
|
+
return f"event: {event}\ndata: {payload}\n\n".encode()
|
|
417
|
+
|
|
418
|
+
def _format_chunk_event(
|
|
419
|
+
self, chunk: StreamChunk, task_id: str, context_id: str
|
|
420
|
+
) -> bytes:
|
|
421
|
+
meta = dict(chunk.meta)
|
|
422
|
+
meta.setdefault("taskId", task_id)
|
|
423
|
+
meta.setdefault("contextId", context_id)
|
|
424
|
+
enriched = chunk.model_copy(update={"meta": meta})
|
|
425
|
+
return format_sse_event(enriched).encode("utf-8")
|
|
426
|
+
|
|
427
|
+
def _to_jsonable(self, value: Any) -> Any:
|
|
428
|
+
if isinstance(value, BaseModel):
|
|
429
|
+
return value.model_dump()
|
|
430
|
+
if isinstance(value, Message):
|
|
431
|
+
return {
|
|
432
|
+
"payload": self._to_jsonable(value.payload),
|
|
433
|
+
"headers": value.headers.model_dump(),
|
|
434
|
+
"trace_id": value.trace_id,
|
|
435
|
+
"meta": dict(value.meta),
|
|
436
|
+
}
|
|
437
|
+
if isinstance(value, RookeryResult):
|
|
438
|
+
return self._to_jsonable(value.value)
|
|
439
|
+
if isinstance(value, Mapping):
|
|
440
|
+
return {k: self._to_jsonable(v) for k, v in value.items()}
|
|
441
|
+
if isinstance(value, list | tuple | set):
|
|
442
|
+
return [self._to_jsonable(item) for item in value]
|
|
443
|
+
return value
|
|
444
|
+
|
|
445
|
+
    def _patch_flow(self) -> None:
        """Monkey-patch flow internals so results remain trace-attributable.

        Wraps three private ``PenguiFlow`` hooks:

        * ``_execute_with_reliability`` — records the executing message's
          trace id in ``_TRACE_CONTEXT`` for the duration of the call.
        * ``_emit_to_rookery`` — wraps payloads that lack a ``trace_id``
          in :class:`RookeryResult` tagged with the contextvar's trace.
        * ``_on_message_enqueued`` — remembers ``id(message) -> trace`` for
          untagged enqueued messages so the dispatcher can recover it.

        Idempotent via the ``_a2a_adapter_patched`` marker, and a silent
        no-op if the flow lacks any of the private hooks (older/newer
        penguiflow versions).  NOTE(review): this reaches into private
        API — keep in sync with ``penguiflow.core``.
        """
        flow = self._flow
        if getattr(flow, "_a2a_adapter_patched", False):
            return
        required = (
            "_emit_to_rookery",
            "_execute_with_reliability",
            "_on_message_enqueued",
        )
        if not all(hasattr(flow, name) for name in required):
            return

        # Capture the bound originals so the wrappers delegate to them.
        original_emit = flow._emit_to_rookery
        original_execute = flow._execute_with_reliability
        original_on_enqueue = flow._on_message_enqueued

        async def emit_with_trace(
            flow_self: PenguiFlow,
            message: Any,
            *,
            source: Any | None = None,
        ) -> None:
            trace_id = getattr(message, "trace_id", None)
            if trace_id is None:
                context_trace = _TRACE_CONTEXT.get()
                if context_trace is not None:
                    self._message_traces[id(message)] = context_trace
                    message = RookeryResult(trace_id=context_trace, value=message)
            await original_emit(message, source=source)

        async def execute_with_trace(
            flow_self: PenguiFlow,
            node: Any,
            context: Any,
            message: Any,
        ) -> None:
            trace_id = getattr(message, "trace_id", None)
            # Scope the contextvar to this node execution only.
            token = _TRACE_CONTEXT.set(trace_id)
            try:
                return await original_execute(node, context, message)
            finally:
                _TRACE_CONTEXT.reset(token)

        def on_enqueue_with_trace(flow_self: PenguiFlow, message: Any) -> None:
            trace_id = flow_self._get_trace_id(message)
            if trace_id is None:
                context_trace = _TRACE_CONTEXT.get()
                if context_trace is not None:
                    self._message_traces[id(message)] = context_trace
            original_on_enqueue(message)

        # object.__setattr__ bypasses any __slots__/__setattr__ guard the
        # flow class may define.
        object.__setattr__(flow, "_emit_to_rookery", MethodType(emit_with_trace, flow))
        object.__setattr__(
            flow,
            "_execute_with_reliability",
            MethodType(execute_with_trace, flow),
        )
        object.__setattr__(
            flow,
            "_on_message_enqueued",
            MethodType(on_enqueue_with_trace, flow),
        )
        object.__setattr__(flow, "_a2a_adapter_patched", True)
|
|
508
|
+
|
|
509
|
+
def _ensure_dispatcher_task(self) -> None:
|
|
510
|
+
if self._dispatcher_task is not None and not self._dispatcher_task.done():
|
|
511
|
+
return
|
|
512
|
+
loop = asyncio.get_running_loop()
|
|
513
|
+
self._dispatcher_task = loop.create_task(self._dispatch_results())
|
|
514
|
+
|
|
515
|
+
    async def _dispatch_results(self) -> None:
        """Background loop routing ``flow.fetch()`` items to trace queues.

        Trace attribution tries, in order: the item's own ``trace_id``;
        the ``id(item)`` mapping recorded by the ``_patch_flow`` wrappers;
        a before/after diff of the flow's per-trace counts; and finally
        "the only active trace".  If all fail the loop dies with a
        ``RuntimeError`` (NOTE(review): that kills dispatching for every
        subsequent request until the adapter restarts — confirm intended).
        Items for traces without a queue yet are buffered in
        ``_pending_results``.
        """
        try:
            while True:
                counts_before = await self._snapshot_trace_counts()
                item = await self._flow.fetch()
                trace_id = getattr(item, "trace_id", None)
                if trace_id is None:
                    trace_id = self._message_traces.pop(id(item), None)
                counts_after = await self._snapshot_trace_counts()
                if trace_id is None:
                    trace_id = self._infer_trace_from_counts(
                        counts_before, counts_after
                    )
                if trace_id is None:
                    # Last resort: unambiguous only when one trace is live.
                    async with self._queue_lock:
                        active_traces = list(self._trace_queues.keys())
                    if len(active_traces) == 1:
                        trace_id = active_traces[0]
                if trace_id is None:
                    raise RuntimeError("unable to determine trace for rookery payload")
                async with self._queue_lock:
                    queue = self._trace_queues.get(trace_id)
                    if queue is None:
                        # No consumer yet: buffer until the queue is acquired.
                        pending = self._pending_results.setdefault(trace_id, [])
                        pending.append(item)
                        continue
                await queue.put(item)
        except asyncio.CancelledError:
            raise
|
|
544
|
+
|
|
545
|
+
    async def _acquire_trace_queue(self, trace_id: str) -> asyncio.Queue[Any]:
        """Create and register the result queue for a new trace.

        Also spawns a watcher on the flow's cancel event for this trace
        and replays any results buffered before the queue existed.  A
        second concurrent acquisition of the same trace is a 409 (the
        speculative watcher is torn down first).
        """
        self._ensure_dispatcher_task()
        queue: asyncio.Queue[Any] = asyncio.Queue()
        cancel_event = self._flow.ensure_trace_event(trace_id)
        # Created before taking the lock; must be cancelled on the 409 path.
        watcher = asyncio.create_task(
            self._wait_for_cancellation(trace_id, cancel_event)
        )
        async with self._queue_lock:
            if trace_id in self._trace_queues:
                watcher.cancel()
                with suppress(asyncio.CancelledError):
                    await watcher
                raise A2ARequestError(
                    status_code=409, detail=f"trace {trace_id!r} already active"
                )
            self._trace_queues[trace_id] = queue
            self._cancel_watchers[trace_id] = watcher
            # Flush results the dispatcher buffered before we registered.
            pending = self._pending_results.pop(trace_id, [])
            for item in pending:
                await queue.put(item)
        return queue
|
|
566
|
+
|
|
567
|
+
async def _get_trace_queue(self, trace_id: str) -> asyncio.Queue[Any]:
|
|
568
|
+
async with self._queue_lock:
|
|
569
|
+
queue = self._trace_queues.get(trace_id)
|
|
570
|
+
if queue is None:
|
|
571
|
+
raise A2ARequestError(status_code=503, detail="trace queue missing")
|
|
572
|
+
return queue
|
|
573
|
+
|
|
574
|
+
async def _release_trace_queue(self, trace_id: str) -> None:
|
|
575
|
+
async with self._queue_lock:
|
|
576
|
+
queue = self._trace_queues.pop(trace_id, None)
|
|
577
|
+
self._pending_results.pop(trace_id, None)
|
|
578
|
+
watcher = self._cancel_watchers.pop(trace_id, None)
|
|
579
|
+
if watcher is not None:
|
|
580
|
+
watcher.cancel()
|
|
581
|
+
with suppress(asyncio.CancelledError):
|
|
582
|
+
await watcher
|
|
583
|
+
if queue is not None:
|
|
584
|
+
while not queue.empty():
|
|
585
|
+
queue.get_nowait()
|
|
586
|
+
|
|
587
|
+
async def _wait_for_cancellation(
|
|
588
|
+
self, trace_id: str, event: asyncio.Event
|
|
589
|
+
) -> None:
|
|
590
|
+
try:
|
|
591
|
+
await event.wait()
|
|
592
|
+
async with self._queue_lock:
|
|
593
|
+
queue = self._trace_queues.get(trace_id)
|
|
594
|
+
if queue is not None:
|
|
595
|
+
await queue.put(TraceCancelled(trace_id))
|
|
596
|
+
except asyncio.CancelledError:
|
|
597
|
+
raise
|
|
598
|
+
|
|
599
|
+
async def _snapshot_trace_counts(self) -> dict[str, int]:
|
|
600
|
+
async with self._queue_lock:
|
|
601
|
+
active = list(self._trace_queues.keys())
|
|
602
|
+
return {trace: self._flow._trace_counts.get(trace, 0) for trace in active}
|
|
603
|
+
|
|
604
|
+
def _infer_trace_from_counts(
|
|
605
|
+
self, before: Mapping[str, int], after: Mapping[str, int]
|
|
606
|
+
) -> str | None:
|
|
607
|
+
candidates: list[str] = []
|
|
608
|
+
for trace_id, before_count in before.items():
|
|
609
|
+
after_count = after.get(trace_id)
|
|
610
|
+
if after_count is None or after_count < before_count:
|
|
611
|
+
candidates.append(trace_id)
|
|
612
|
+
if candidates:
|
|
613
|
+
if len(candidates) == 1:
|
|
614
|
+
return candidates[0]
|
|
615
|
+
return None
|
|
616
|
+
new_traces = [trace_id for trace_id in after.keys() if trace_id not in before]
|
|
617
|
+
if len(new_traces) == 1:
|
|
618
|
+
return new_traces[0]
|
|
619
|
+
return None
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
def create_a2a_app(
    adapter: A2AServerAdapter, *, include_docs: bool = True
):  # pragma: no cover - exercised via tests
    """Create a FastAPI application exposing the A2A surface.

    Routes: ``GET /agent`` (agent card), ``POST /message/send``,
    ``POST /message/stream`` (SSE), ``POST /tasks/cancel``.  The adapter's
    flow is started/stopped via the app lifespan.  FastAPI is imported
    lazily so it stays an optional dependency.
    """

    try:
        from fastapi import FastAPI, HTTPException
        from fastapi.responses import StreamingResponse
    except ModuleNotFoundError as exc:  # pragma: no cover - optional extra
        raise RuntimeError(
            "FastAPI is required for the A2A server adapter."
            " Install penguiflow[a2a-server]."
        ) from exc

    # include_docs=False hides both the Swagger UI and the OpenAPI schema.
    docs_url = "/docs" if include_docs else None
    openapi_url = "/openapi.json" if include_docs else None

    @asynccontextmanager
    async def lifespan(_app):  # pragma: no cover - executed in tests via router context
        await adapter.start()
        try:
            yield
        finally:
            await adapter.stop()

    app = FastAPI(
        title=adapter.agent_card.name,
        description=adapter.agent_card.description,
        version=adapter.agent_card.version,
        docs_url=docs_url,
        openapi_url=openapi_url,
        lifespan=lifespan,
    )

    @app.get("/agent")
    async def get_agent() -> dict[str, Any]:
        return adapter.agent_card.to_payload()

    @app.post("/message/send")
    async def message_send(payload: A2AMessagePayload) -> dict[str, Any]:
        try:
            return await adapter.handle_send(payload)
        except A2ARequestError as exc:
            # Translate adapter errors into proper HTTP responses.
            raise HTTPException(status_code=exc.status_code, detail=exc.detail) from exc

    @app.post("/message/stream")
    async def message_stream(payload: A2AMessagePayload):
        try:
            generator, task_id, context_id = await adapter.stream(payload)
        except A2ARequestError as exc:
            raise HTTPException(status_code=exc.status_code, detail=exc.detail) from exc
        response = StreamingResponse(generator, media_type="text/event-stream")
        # SSE responses must not be cached; expose the ids via headers since
        # the body is an event stream.
        response.headers["Cache-Control"] = "no-cache"
        response.headers["X-A2A-Task-Id"] = task_id
        response.headers["X-A2A-Context-Id"] = context_id
        return response

    @app.post("/tasks/cancel")
    async def cancel_task(payload: A2ATaskCancelRequest) -> dict[str, Any]:
        try:
            return await adapter.cancel(payload)
        except A2ARequestError as exc:
            raise HTTPException(status_code=exc.status_code, detail=exc.detail) from exc

    return app
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
# Public API of this module.  A2ARequestError and A2ATaskCancelRequest are
# part of the public surface (callers catch the error and build the cancel
# body), so they are exported alongside the models and the app factory.
__all__ = [
    "A2AAgentCard",
    "A2AMessagePayload",
    "A2ARequestError",
    "A2AServerAdapter",
    "A2ASkill",
    "A2ATaskCancelRequest",
    "create_a2a_app",
]
|