penguiflow 2.0.0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of penguiflow might be problematic. Click here for more details.

@@ -0,0 +1,695 @@
1
+ """Expose PenguiFlow runs through an A2A-compliant HTTP surface."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import contextvars
7
+ import json
8
+ import uuid
9
+ from collections.abc import AsyncIterator, Mapping, Sequence
10
+ from contextlib import asynccontextmanager, suppress
11
+ from dataclasses import dataclass
12
+ from types import MethodType
13
+ from typing import Any
14
+
15
+ from pydantic import BaseModel, ConfigDict, Field, ValidationError
16
+
17
+ from penguiflow.core import PenguiFlow, TraceCancelled
18
+ from penguiflow.errors import FlowError
19
+ from penguiflow.state import RemoteBinding
20
+ from penguiflow.streaming import format_sse_event
21
+ from penguiflow.types import Headers, Message, StreamChunk
22
+
23
+ _QUEUE_SHUTDOWN = object()
24
+ _TRACE_CONTEXT: contextvars.ContextVar[str | None] = contextvars.ContextVar(
25
+ "penguiflow_a2a_trace", default=None
26
+ )
27
+
28
+
29
+ @dataclass(slots=True)
30
+ class RookeryResult:
31
+ trace_id: str
32
+ value: Any
33
+
34
+
35
class A2ASkill(BaseModel):
    """One capability advertised in an agent card.

    Keys beyond the declared fields are kept (``extra="allow"``), so callers
    can attach additional skill metadata.
    """

    model_config = ConfigDict(extra="allow")

    name: str
    description: str
    # Free-form string; the field description lists the intended values.
    mode: str = Field(
        default="both",
        description="Whether the skill supports message/send, message/stream, or both.",
    )
    inputs: dict[str, Any] = Field(default_factory=dict)
    outputs: dict[str, Any] = Field(default_factory=dict)
48
+
49
+
50
class A2AAgentCard(BaseModel):
    """Agent Card returned by ``GET /agent``.

    Holds the agent's name, description, version, tags, capabilities, skills
    and optional contact/documentation links. Undeclared keys are retained
    because the model uses ``extra="allow"``.
    """

    model_config = ConfigDict(extra="allow")

    name: str
    description: str
    version: str = "1.0.0"
    schema_version: str = Field(default="1.0")
    tags: list[str] = Field(default_factory=list)
    capabilities: list[str] = Field(default_factory=list)
    skills: list[A2ASkill] = Field(default_factory=list)
    contact_url: str | None = None
    documentation_url: str | None = None

    def to_payload(self) -> dict[str, Any]:
        """Dump the card, including any extra keys, to a plain dictionary."""
        dumped: dict[str, Any] = self.model_dump()
        return dumped
69
+
70
+
71
class A2AMessagePayload(BaseModel):
    """Request body shared by ``message/send`` and ``message/stream``.

    Identifier fields travel as camelCase aliases (``traceId``, ``contextId``,
    ``taskId``, ``deadlineSeconds``). Because ``populate_by_name`` is enabled,
    the snake_case field names are accepted as well.
    """

    model_config = ConfigDict(populate_by_name=True)

    payload: Any  # becomes the flow Message payload
    headers: Mapping[str, Any] = Field(default_factory=dict)  # merged over adapter defaults
    meta: dict[str, Any] = Field(default_factory=dict)  # copied into Message.meta
    trace_id: str | None = Field(default=None, alias="traceId")
    context_id: str | None = Field(default=None, alias="contextId")
    task_id: str | None = Field(default=None, alias="taskId")
    deadline_s: float | None = Field(default=None, alias="deadlineSeconds")
83
+
84
+
85
class A2ATaskCancelRequest(BaseModel):
    """Request body for ``tasks/cancel``.

    The task id is sent as ``taskId``; ``task_id`` is also accepted because
    ``populate_by_name`` is enabled.
    """

    model_config = ConfigDict(populate_by_name=True)

    task_id: str = Field(alias="taskId")
91
+
92
+
93
class A2ARequestError(Exception):
    """Adapter failure that carries an HTTP status code.

    The FastAPI routes built by ``create_a2a_app`` re-raise it as an
    ``HTTPException`` with the same ``status_code`` and ``detail``.
    """

    def __init__(self, *, status_code: int, detail: str) -> None:
        self.status_code = status_code
        self.detail = detail
        super().__init__(detail)
100
+
101
+
102
+ class A2AServerAdapter:
103
+ """Bridge between PenguiFlow and the A2A HTTP surface."""
104
+
105
+ def __init__(
106
+ self,
107
+ flow: PenguiFlow,
108
+ *,
109
+ agent_card: A2AAgentCard | Mapping[str, Any],
110
+ agent_url: str,
111
+ target: Sequence[Any] | Any | None = None,
112
+ registry: Any | None = None,
113
+ default_headers: Mapping[str, Any] | None = None,
114
+ ) -> None:
115
+ self._flow = flow
116
+ self._registry = registry
117
+ self._target = target
118
+ self._default_headers = dict(default_headers or {})
119
+ self.agent_card = (
120
+ agent_card
121
+ if isinstance(agent_card, A2AAgentCard)
122
+ else A2AAgentCard.model_validate(agent_card)
123
+ )
124
+ self.agent_url = agent_url
125
+ self._flow_started = False
126
+ self._tasks: dict[str, str] = {}
127
+ self._contexts: dict[str, str] = {}
128
+ self._lock = asyncio.Lock()
129
+ self._queue_lock = asyncio.Lock()
130
+ self._trace_queues: dict[str, asyncio.Queue[Any]] = {}
131
+ self._pending_results: dict[str, list[Any]] = {}
132
+ self._cancel_watchers: dict[str, asyncio.Task[None]] = {}
133
+ self._dispatcher_task: asyncio.Task[None] | None = None
134
+ self._message_traces: dict[int, str] = {}
135
+ self._patch_flow()
136
+
137
+ async def start(self) -> None:
138
+ """Start the underlying flow if it is not running."""
139
+
140
+ if self._flow_started:
141
+ return
142
+ self._flow.run(registry=self._registry)
143
+ self._flow_started = True
144
+ self._ensure_dispatcher_task()
145
+
146
+ async def stop(self) -> None:
147
+ """Gracefully stop the underlying flow."""
148
+
149
+ if not self._flow_started:
150
+ return
151
+ dispatcher = self._dispatcher_task
152
+ self._dispatcher_task = None
153
+ if dispatcher is not None:
154
+ dispatcher.cancel()
155
+ await self._flow.stop()
156
+ if dispatcher is not None:
157
+ with suppress(asyncio.CancelledError):
158
+ await dispatcher
159
+ self._flow_started = False
160
+ async with self._queue_lock:
161
+ queues = list(self._trace_queues.values())
162
+ cancel_watchers = list(self._cancel_watchers.values())
163
+ self._trace_queues.clear()
164
+ self._pending_results.clear()
165
+ self._cancel_watchers.clear()
166
+ for watcher in cancel_watchers:
167
+ watcher.cancel()
168
+ with suppress(asyncio.CancelledError):
169
+ await watcher
170
+ for queue in queues:
171
+ queue.put_nowait(_QUEUE_SHUTDOWN)
172
+
173
+ def _ensure_started(self) -> None:
174
+ if not self._flow_started:
175
+ raise A2ARequestError(status_code=503, detail="flow is not running")
176
+
177
+ async def handle_send(self, request: A2AMessagePayload) -> dict[str, Any]:
178
+ """Execute ``message/send`` and return the final artifact."""
179
+
180
+ self._ensure_started()
181
+ message, task_id, context_id = self._prepare_message(request)
182
+ await self._register_task(task_id, message.trace_id, context_id)
183
+ await self._persist_binding(message.trace_id, context_id, task_id)
184
+ result_queue = await self._acquire_trace_queue(message.trace_id)
185
+
186
+ try:
187
+ await self._flow.emit(message, to=self._target)
188
+ while True:
189
+ item = await result_queue.get()
190
+ if item is _QUEUE_SHUTDOWN:
191
+ raise A2ARequestError(
192
+ status_code=503, detail="flow is shutting down"
193
+ )
194
+ if isinstance(item, TraceCancelled):
195
+ raise item
196
+ if isinstance(item, FlowError):
197
+ raise item
198
+ if isinstance(item, Exception): # pragma: no cover - defensive
199
+ raise item
200
+ if isinstance(item, RookeryResult):
201
+ payload_candidate = item.value
202
+ else:
203
+ payload_candidate = getattr(item, "payload", item)
204
+ if isinstance(payload_candidate, StreamChunk):
205
+ continue
206
+ result = item
207
+ break
208
+ if isinstance(result, RookeryResult):
209
+ payload = result.value
210
+ else:
211
+ payload = getattr(result, "payload", result)
212
+ response: dict[str, Any] = {
213
+ "status": "succeeded",
214
+ "taskId": task_id,
215
+ "contextId": context_id,
216
+ "traceId": message.trace_id,
217
+ "output": self._to_jsonable(payload),
218
+ }
219
+ meta = getattr(result, "meta", None)
220
+ if meta:
221
+ response["meta"] = dict(meta)
222
+ return response
223
+ except TraceCancelled:
224
+ return {
225
+ "status": "cancelled",
226
+ "taskId": task_id,
227
+ "contextId": context_id,
228
+ "traceId": message.trace_id,
229
+ }
230
+ except FlowError as exc:
231
+ error_payload = exc.to_payload()
232
+ error_payload.setdefault("trace_id", message.trace_id)
233
+ return {
234
+ "status": "failed",
235
+ "taskId": task_id,
236
+ "contextId": context_id,
237
+ "traceId": message.trace_id,
238
+ "error": error_payload,
239
+ }
240
+ except A2ARequestError:
241
+ raise
242
+ except Exception as exc: # pragma: no cover - defensive fallback
243
+ raise A2ARequestError(
244
+ status_code=500,
245
+ detail=f"flow execution failed: {exc}",
246
+ ) from exc
247
+ finally:
248
+ await self._release_task(task_id)
249
+ await self._release_trace_queue(message.trace_id)
250
+
251
+ async def stream(
252
+ self, request: A2AMessagePayload
253
+ ) -> tuple[AsyncIterator[bytes], str, str]:
254
+ """Execute ``message/stream`` and return an SSE iterator."""
255
+
256
+ self._ensure_started()
257
+ message, task_id, context_id = self._prepare_message(request)
258
+ await self._register_task(task_id, message.trace_id, context_id)
259
+ await self._persist_binding(message.trace_id, context_id, task_id)
260
+ await self._acquire_trace_queue(message.trace_id)
261
+ generator = self._stream_generator(message, task_id, context_id)
262
+ return generator, task_id, context_id
263
+
264
+ async def cancel(self, request: A2ATaskCancelRequest) -> dict[str, Any]:
265
+ """Cancel an active task."""
266
+
267
+ self._ensure_started()
268
+ task_id = request.task_id
269
+ async with self._lock:
270
+ trace_id = self._tasks.get(task_id)
271
+ context_id = self._contexts.get(task_id)
272
+ if trace_id is None:
273
+ return {"taskId": task_id, "cancelled": False}
274
+ cancelled = await self._flow.cancel(trace_id)
275
+ response = {
276
+ "taskId": task_id,
277
+ "cancelled": cancelled,
278
+ "traceId": trace_id,
279
+ }
280
+ if context_id is not None:
281
+ response["contextId"] = context_id
282
+ return response
283
+
284
+ def _prepare_message(
285
+ self, request: A2AMessagePayload
286
+ ) -> tuple[Message, str, str]:
287
+ headers_data = {**self._default_headers, **dict(request.headers)}
288
+ try:
289
+ headers = Headers(**headers_data)
290
+ except ValidationError as exc: # pragma: no cover - pydantic formats nicely
291
+ raise A2ARequestError(status_code=422, detail=str(exc)) from exc
292
+
293
+ kwargs: dict[str, Any] = {}
294
+ if request.trace_id is not None:
295
+ kwargs["trace_id"] = request.trace_id
296
+ if request.deadline_s is not None:
297
+ kwargs["deadline_s"] = request.deadline_s
298
+ message = Message(payload=request.payload, headers=headers, **kwargs)
299
+ message.meta.update(request.meta)
300
+
301
+ context_id = request.context_id or message.trace_id
302
+ task_id = request.task_id or message.trace_id or uuid.uuid4().hex
303
+ return message, task_id, context_id
304
+
305
+ async def _register_task(
306
+ self, task_id: str, trace_id: str, context_id: str
307
+ ) -> None:
308
+ async with self._lock:
309
+ if task_id in self._tasks:
310
+ raise A2ARequestError(
311
+ status_code=409, detail=f"task {task_id!r} already active"
312
+ )
313
+ self._tasks[task_id] = trace_id
314
+ self._contexts[task_id] = context_id
315
+
316
+ async def _release_task(self, task_id: str) -> None:
317
+ async with self._lock:
318
+ self._tasks.pop(task_id, None)
319
+ self._contexts.pop(task_id, None)
320
+
321
+ async def _persist_binding(
322
+ self, trace_id: str, context_id: str, task_id: str
323
+ ) -> None:
324
+ binding = RemoteBinding(
325
+ trace_id=trace_id,
326
+ context_id=context_id,
327
+ task_id=task_id,
328
+ agent_url=self.agent_url,
329
+ )
330
+ await self._flow.save_remote_binding(binding)
331
+
332
+ async def _stream_generator(
333
+ self, message: Message, task_id: str, context_id: str
334
+ ) -> AsyncIterator[bytes]:
335
+ result_queue = await self._get_trace_queue(message.trace_id)
336
+ try:
337
+ await self._flow.emit(message, to=self._target)
338
+ yield self._format_event(
339
+ "status",
340
+ {
341
+ "status": "accepted",
342
+ "taskId": task_id,
343
+ "contextId": context_id,
344
+ },
345
+ )
346
+ while True:
347
+ item = await result_queue.get()
348
+ if item is _QUEUE_SHUTDOWN:
349
+ raise A2ARequestError(
350
+ status_code=503, detail="flow is shutting down"
351
+ )
352
+ if isinstance(item, TraceCancelled):
353
+ raise item
354
+ if isinstance(item, FlowError):
355
+ raise item
356
+ if isinstance(item, Exception): # pragma: no cover - defensive
357
+ raise item
358
+ if isinstance(item, RookeryResult):
359
+ payload = item.value
360
+ else:
361
+ payload = getattr(item, "payload", item)
362
+ if isinstance(payload, StreamChunk):
363
+ yield self._format_chunk_event(payload, task_id, context_id)
364
+ continue
365
+ yield self._format_event(
366
+ "artifact",
367
+ {
368
+ "taskId": task_id,
369
+ "contextId": context_id,
370
+ "output": self._to_jsonable(payload),
371
+ },
372
+ )
373
+ break
374
+ yield self._format_event(
375
+ "done", {"taskId": task_id, "contextId": context_id}
376
+ )
377
+ except TraceCancelled:
378
+ yield self._format_event(
379
+ "error",
380
+ {
381
+ "taskId": task_id,
382
+ "contextId": context_id,
383
+ "code": "TRACE_CANCELLED",
384
+ "message": "Trace cancelled",
385
+ },
386
+ )
387
+ yield self._format_event(
388
+ "done", {"taskId": task_id, "contextId": context_id}
389
+ )
390
+ except FlowError as exc:
391
+ payload = exc.to_payload()
392
+ payload.update({"taskId": task_id, "contextId": context_id})
393
+ yield self._format_event("error", payload)
394
+ yield self._format_event(
395
+ "done", {"taskId": task_id, "contextId": context_id}
396
+ )
397
+ except Exception as exc: # pragma: no cover - defensive fallback
398
+ yield self._format_event(
399
+ "error",
400
+ {
401
+ "taskId": task_id,
402
+ "contextId": context_id,
403
+ "code": "INTERNAL_ERROR",
404
+ "message": str(exc) or exc.__class__.__name__,
405
+ },
406
+ )
407
+ yield self._format_event(
408
+ "done", {"taskId": task_id, "contextId": context_id}
409
+ )
410
+ finally:
411
+ await self._release_task(task_id)
412
+ await self._release_trace_queue(message.trace_id)
413
+
414
+ def _format_event(self, event: str, data: Mapping[str, Any]) -> bytes:
415
+ payload = json.dumps(data, ensure_ascii=False)
416
+ return f"event: {event}\ndata: {payload}\n\n".encode()
417
+
418
+ def _format_chunk_event(
419
+ self, chunk: StreamChunk, task_id: str, context_id: str
420
+ ) -> bytes:
421
+ meta = dict(chunk.meta)
422
+ meta.setdefault("taskId", task_id)
423
+ meta.setdefault("contextId", context_id)
424
+ enriched = chunk.model_copy(update={"meta": meta})
425
+ return format_sse_event(enriched).encode("utf-8")
426
+
427
+ def _to_jsonable(self, value: Any) -> Any:
428
+ if isinstance(value, BaseModel):
429
+ return value.model_dump()
430
+ if isinstance(value, Message):
431
+ return {
432
+ "payload": self._to_jsonable(value.payload),
433
+ "headers": value.headers.model_dump(),
434
+ "trace_id": value.trace_id,
435
+ "meta": dict(value.meta),
436
+ }
437
+ if isinstance(value, RookeryResult):
438
+ return self._to_jsonable(value.value)
439
+ if isinstance(value, Mapping):
440
+ return {k: self._to_jsonable(v) for k, v in value.items()}
441
+ if isinstance(value, list | tuple | set):
442
+ return [self._to_jsonable(item) for item in value]
443
+ return value
444
+
445
+ def _patch_flow(self) -> None:
446
+ flow = self._flow
447
+ if getattr(flow, "_a2a_adapter_patched", False):
448
+ return
449
+ required = (
450
+ "_emit_to_rookery",
451
+ "_execute_with_reliability",
452
+ "_on_message_enqueued",
453
+ )
454
+ if not all(hasattr(flow, name) for name in required):
455
+ return
456
+
457
+ original_emit = flow._emit_to_rookery
458
+ original_execute = flow._execute_with_reliability
459
+ original_on_enqueue = flow._on_message_enqueued
460
+
461
+ async def emit_with_trace(
462
+ flow_self: PenguiFlow,
463
+ message: Any,
464
+ *,
465
+ source: Any | None = None,
466
+ ) -> None:
467
+ trace_id = getattr(message, "trace_id", None)
468
+ if trace_id is None:
469
+ context_trace = _TRACE_CONTEXT.get()
470
+ if context_trace is not None:
471
+ self._message_traces[id(message)] = context_trace
472
+ message = RookeryResult(trace_id=context_trace, value=message)
473
+ await original_emit(message, source=source)
474
+
475
+ async def execute_with_trace(
476
+ flow_self: PenguiFlow,
477
+ node: Any,
478
+ context: Any,
479
+ message: Any,
480
+ ) -> None:
481
+ trace_id = getattr(message, "trace_id", None)
482
+ token = _TRACE_CONTEXT.set(trace_id)
483
+ try:
484
+ return await original_execute(node, context, message)
485
+ finally:
486
+ _TRACE_CONTEXT.reset(token)
487
+
488
+ def on_enqueue_with_trace(flow_self: PenguiFlow, message: Any) -> None:
489
+ trace_id = flow_self._get_trace_id(message)
490
+ if trace_id is None:
491
+ context_trace = _TRACE_CONTEXT.get()
492
+ if context_trace is not None:
493
+ self._message_traces[id(message)] = context_trace
494
+ original_on_enqueue(message)
495
+
496
+ object.__setattr__(flow, "_emit_to_rookery", MethodType(emit_with_trace, flow))
497
+ object.__setattr__(
498
+ flow,
499
+ "_execute_with_reliability",
500
+ MethodType(execute_with_trace, flow),
501
+ )
502
+ object.__setattr__(
503
+ flow,
504
+ "_on_message_enqueued",
505
+ MethodType(on_enqueue_with_trace, flow),
506
+ )
507
+ object.__setattr__(flow, "_a2a_adapter_patched", True)
508
+
509
+ def _ensure_dispatcher_task(self) -> None:
510
+ if self._dispatcher_task is not None and not self._dispatcher_task.done():
511
+ return
512
+ loop = asyncio.get_running_loop()
513
+ self._dispatcher_task = loop.create_task(self._dispatch_results())
514
+
515
+ async def _dispatch_results(self) -> None:
516
+ try:
517
+ while True:
518
+ counts_before = await self._snapshot_trace_counts()
519
+ item = await self._flow.fetch()
520
+ trace_id = getattr(item, "trace_id", None)
521
+ if trace_id is None:
522
+ trace_id = self._message_traces.pop(id(item), None)
523
+ counts_after = await self._snapshot_trace_counts()
524
+ if trace_id is None:
525
+ trace_id = self._infer_trace_from_counts(
526
+ counts_before, counts_after
527
+ )
528
+ if trace_id is None:
529
+ async with self._queue_lock:
530
+ active_traces = list(self._trace_queues.keys())
531
+ if len(active_traces) == 1:
532
+ trace_id = active_traces[0]
533
+ if trace_id is None:
534
+ raise RuntimeError("unable to determine trace for rookery payload")
535
+ async with self._queue_lock:
536
+ queue = self._trace_queues.get(trace_id)
537
+ if queue is None:
538
+ pending = self._pending_results.setdefault(trace_id, [])
539
+ pending.append(item)
540
+ continue
541
+ await queue.put(item)
542
+ except asyncio.CancelledError:
543
+ raise
544
+
545
+ async def _acquire_trace_queue(self, trace_id: str) -> asyncio.Queue[Any]:
546
+ self._ensure_dispatcher_task()
547
+ queue: asyncio.Queue[Any] = asyncio.Queue()
548
+ cancel_event = self._flow.ensure_trace_event(trace_id)
549
+ watcher = asyncio.create_task(
550
+ self._wait_for_cancellation(trace_id, cancel_event)
551
+ )
552
+ async with self._queue_lock:
553
+ if trace_id in self._trace_queues:
554
+ watcher.cancel()
555
+ with suppress(asyncio.CancelledError):
556
+ await watcher
557
+ raise A2ARequestError(
558
+ status_code=409, detail=f"trace {trace_id!r} already active"
559
+ )
560
+ self._trace_queues[trace_id] = queue
561
+ self._cancel_watchers[trace_id] = watcher
562
+ pending = self._pending_results.pop(trace_id, [])
563
+ for item in pending:
564
+ await queue.put(item)
565
+ return queue
566
+
567
+ async def _get_trace_queue(self, trace_id: str) -> asyncio.Queue[Any]:
568
+ async with self._queue_lock:
569
+ queue = self._trace_queues.get(trace_id)
570
+ if queue is None:
571
+ raise A2ARequestError(status_code=503, detail="trace queue missing")
572
+ return queue
573
+
574
+ async def _release_trace_queue(self, trace_id: str) -> None:
575
+ async with self._queue_lock:
576
+ queue = self._trace_queues.pop(trace_id, None)
577
+ self._pending_results.pop(trace_id, None)
578
+ watcher = self._cancel_watchers.pop(trace_id, None)
579
+ if watcher is not None:
580
+ watcher.cancel()
581
+ with suppress(asyncio.CancelledError):
582
+ await watcher
583
+ if queue is not None:
584
+ while not queue.empty():
585
+ queue.get_nowait()
586
+
587
+ async def _wait_for_cancellation(
588
+ self, trace_id: str, event: asyncio.Event
589
+ ) -> None:
590
+ try:
591
+ await event.wait()
592
+ async with self._queue_lock:
593
+ queue = self._trace_queues.get(trace_id)
594
+ if queue is not None:
595
+ await queue.put(TraceCancelled(trace_id))
596
+ except asyncio.CancelledError:
597
+ raise
598
+
599
+ async def _snapshot_trace_counts(self) -> dict[str, int]:
600
+ async with self._queue_lock:
601
+ active = list(self._trace_queues.keys())
602
+ return {trace: self._flow._trace_counts.get(trace, 0) for trace in active}
603
+
604
+ def _infer_trace_from_counts(
605
+ self, before: Mapping[str, int], after: Mapping[str, int]
606
+ ) -> str | None:
607
+ candidates: list[str] = []
608
+ for trace_id, before_count in before.items():
609
+ after_count = after.get(trace_id)
610
+ if after_count is None or after_count < before_count:
611
+ candidates.append(trace_id)
612
+ if candidates:
613
+ if len(candidates) == 1:
614
+ return candidates[0]
615
+ return None
616
+ new_traces = [trace_id for trace_id in after.keys() if trace_id not in before]
617
+ if len(new_traces) == 1:
618
+ return new_traces[0]
619
+ return None
620
+
621
+
622
def create_a2a_app(
    adapter: A2AServerAdapter, *, include_docs: bool = True
):  # pragma: no cover - exercised via tests
    """Build a FastAPI application serving *adapter* over the A2A routes.

    Routes:

    * ``GET /agent``: the agent card.
    * ``POST /message/send``.
    * ``POST /message/stream``: server-sent events. The response carries
      ``X-A2A-Task-Id`` / ``X-A2A-Context-Id`` headers.
    * ``POST /tasks/cancel``.

    The adapter is started and stopped by the app lifespan. An
    ``A2ARequestError`` raised by the adapter becomes an ``HTTPException``
    with the same status code and detail. When *include_docs* is false, both
    ``/docs`` and ``/openapi.json`` are disabled.

    Raises:
        RuntimeError: FastAPI is not installed.
    """
    try:
        from fastapi import FastAPI, HTTPException
        from fastapi.responses import StreamingResponse
    except ModuleNotFoundError as exc:  # pragma: no cover - optional extra
        raise RuntimeError(
            "FastAPI is required for the A2A server adapter."
            " Install penguiflow[a2a-server]."
        ) from exc

    def as_http_error(error: A2ARequestError) -> HTTPException:
        # Adapter errors already carry an HTTP status; surface it verbatim.
        return HTTPException(status_code=error.status_code, detail=error.detail)

    @asynccontextmanager
    async def adapter_lifespan(_):  # pragma: no cover - executed in tests via router context
        await adapter.start()
        try:
            yield
        finally:
            await adapter.stop()

    card = adapter.agent_card
    app = FastAPI(
        title=card.name,
        description=card.description,
        version=card.version,
        docs_url="/docs" if include_docs else None,
        openapi_url="/openapi.json" if include_docs else None,
        lifespan=adapter_lifespan,
    )

    @app.get("/agent")
    async def get_agent() -> dict[str, Any]:
        # Read the card on each request so later reassignment is reflected.
        return adapter.agent_card.to_payload()

    @app.post("/message/send")
    async def message_send(payload: A2AMessagePayload) -> dict[str, Any]:
        try:
            return await adapter.handle_send(payload)
        except A2ARequestError as exc:
            raise as_http_error(exc) from exc

    @app.post("/message/stream")
    async def message_stream(payload: A2AMessagePayload):
        try:
            generator, task_id, context_id = await adapter.stream(payload)
        except A2ARequestError as exc:
            raise as_http_error(exc) from exc
        response = StreamingResponse(generator, media_type="text/event-stream")
        for header_name, header_value in (
            ("Cache-Control", "no-cache"),
            ("X-A2A-Task-Id", task_id),
            ("X-A2A-Context-Id", context_id),
        ):
            response.headers[header_name] = header_value
        return response

    @app.post("/tasks/cancel")
    async def cancel_task(payload: A2ATaskCancelRequest) -> dict[str, Any]:
        try:
            return await adapter.cancel(payload)
        except A2ARequestError as exc:
            raise as_http_error(exc) from exc

    return app
687
+
688
+
689
# Public API of this module. ``A2ARequestError`` is raised by the public
# adapter methods and ``A2ATaskCancelRequest`` is the argument of
# ``A2AServerAdapter.cancel``, so both are exported alongside the rest.
__all__ = [
    "A2AAgentCard",
    "A2AMessagePayload",
    "A2ARequestError",
    "A2AServerAdapter",
    "A2ASkill",
    "A2ATaskCancelRequest",
    "create_a2a_app",
]