a2a-adapter 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- a2a_adapter/__init__.py +1 -1
- a2a_adapter/integrations/__init__.py +5 -1
- a2a_adapter/integrations/callable.py +204 -90
- a2a_adapter/integrations/crewai.py +489 -46
- a2a_adapter/integrations/langchain.py +248 -90
- a2a_adapter/integrations/langgraph.py +756 -0
- a2a_adapter/loader.py +71 -28
- {a2a_adapter-0.1.2.dist-info → a2a_adapter-0.1.4.dist-info}/METADATA +96 -43
- a2a_adapter-0.1.4.dist-info/RECORD +15 -0
- {a2a_adapter-0.1.2.dist-info → a2a_adapter-0.1.4.dist-info}/WHEEL +1 -1
- a2a_adapter-0.1.2.dist-info/RECORD +0 -14
- {a2a_adapter-0.1.2.dist-info → a2a_adapter-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {a2a_adapter-0.1.2.dist-info → a2a_adapter-0.1.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,756 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LangGraph adapter for A2A Protocol.
|
|
3
|
+
|
|
4
|
+
This adapter enables LangGraph compiled workflows to be exposed as A2A-compliant
|
|
5
|
+
agents with support for both streaming and non-streaming modes.
|
|
6
|
+
|
|
7
|
+
Non-streaming requests (handle()) support two modes:
|
|
8
|
+
- Synchronous (default): Blocks until workflow completes, returns Message
|
|
9
|
+
- Async Task Mode: Returns Task immediately, processes in background, supports polling
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import asyncio
|
|
13
|
+
import json
|
|
14
|
+
import logging
|
|
15
|
+
import uuid
|
|
16
|
+
from datetime import datetime, timezone
|
|
17
|
+
from typing import Any, AsyncIterator, Dict
|
|
18
|
+
|
|
19
|
+
from a2a.types import (
|
|
20
|
+
Message,
|
|
21
|
+
MessageSendParams,
|
|
22
|
+
Task,
|
|
23
|
+
TaskState,
|
|
24
|
+
TaskStatus,
|
|
25
|
+
TextPart,
|
|
26
|
+
Role,
|
|
27
|
+
Part,
|
|
28
|
+
)
|
|
29
|
+
from ..adapter import BaseAgentAdapter
|
|
30
|
+
|
|
31
|
+
# Optional import: TaskStore / InMemoryTaskStore are only needed for async task
# mode, so their absence is tolerated at import time. LangGraphAgentAdapter
# raises ImportError from __init__ when async_mode=True and they are missing.
try:
    from a2a.server.tasks import InMemoryTaskStore, TaskStore
except ImportError:
    # Placeholders keep the names bound; annotations that mention TaskStore
    # are written as strings, so nothing dereferences these at import time.
    TaskStore = None  # type: ignore
    InMemoryTaskStore = None  # type: ignore
    _HAS_TASK_STORE = False
else:
    _HAS_TASK_STORE = True

logger = logging.getLogger(__name__)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class LangGraphAgentAdapter(BaseAgentAdapter):
    """
    Adapter for integrating LangGraph compiled workflows as A2A agents.

    This adapter works with LangGraph's CompiledGraph objects (the result of
    calling .compile() on a StateGraph) and supports both streaming and
    non-streaming execution modes.

    Supports three execution patterns:

    1. **Synchronous Mode** (default):
       - Blocks until the workflow completes
       - Returns a Message with the final result
       - Best for quick workflows (< 30 seconds)

    2. **Streaming Mode**:
       - Streams intermediate results as they're produced
       - Uses LangGraph's astream() method, requesting full-state
         ("values") snapshots so the same extraction rules as the
         non-streaming path apply
       - Best for real-time feedback during execution

    3. **Async Task Mode** (async_mode=True):
       - Returns a Task with state="working" immediately
       - Processes the workflow in the background
       - Clients can poll get_task() for status updates
       - Best for long-running workflows

    Example:
        >>> from langgraph.graph import StateGraph
        >>> from typing import TypedDict
        >>>
        >>> class State(TypedDict):
        ...     messages: list
        ...     output: str
        >>>
        >>> def process(state: State) -> State:
        ...     return {"output": f"Processed: {state['messages'][-1]}"}
        >>>
        >>> builder = StateGraph(State)
        >>> builder.add_node("process", process)
        >>> builder.set_entry_point("process")
        >>> builder.set_finish_point("process")
        >>> graph = builder.compile()
        >>>
        >>> adapter = LangGraphAgentAdapter(graph=graph)
    """

    # Task states after which a task may no longer change. Compared by string
    # value so the check does not depend on which members the installed
    # a2a-sdk's TaskState enum defines.
    _TERMINAL_STATES = frozenset({"completed", "canceled", "failed", "rejected"})

    def __init__(
        self,
        graph: Any,  # Type: CompiledGraph (avoiding hard dependency)
        input_key: str = "messages",
        output_key: str | None = None,
        state_key: str | None = None,
        async_mode: bool = False,
        task_store: "TaskStore | None" = None,
        async_timeout: int = 300,
    ):
        """
        Initialize the LangGraph adapter.

        Args:
            graph: A LangGraph CompiledGraph instance (result of StateGraph.compile())
            input_key: The key in the state dict for input messages (default: "messages").
                      Set to "input" for simple string input workflows.
            output_key: Optional key to extract from final state. If None, the adapter
                       will try common keys like "output", "response", "messages".
            state_key: Optional key to use when extracting state for streaming events.
                      If None, uses output_key or auto-detection.
            async_mode: If True, return Task immediately and process in background.
                       If False (default), block until workflow completes.
            task_store: Optional TaskStore for persisting task state. If not provided
                       and async_mode is True, uses InMemoryTaskStore.
            async_timeout: Timeout for async task execution in seconds (default: 300).

        Raises:
            ImportError: If async_mode is True but the A2A SDK task store classes
                could not be imported.
        """
        self.graph = graph
        self.input_key = input_key
        self.output_key = output_key
        self.state_key = state_key or output_key

        # Async task mode configuration
        self.async_mode = async_mode
        self.async_timeout = async_timeout
        # task_id -> running background asyncio.Task (removed when it finishes)
        self._background_tasks: Dict[str, "asyncio.Task[None]"] = {}
        # task_ids whose cancellation was requested; consulted by the background
        # coroutine so it does not overwrite the canceled state.
        self._cancelled_tasks: set[str] = set()

        # Initialize task store for async mode
        if async_mode:
            if not _HAS_TASK_STORE:
                raise ImportError(
                    "Async task mode requires the A2A SDK with task support. "
                    "Install with: pip install a2a-sdk"
                )
            self.task_store: "TaskStore" = task_store or InMemoryTaskStore()
        else:
            self.task_store = task_store  # type: ignore

    async def handle(self, params: MessageSendParams) -> Message | Task:
        """
        Handle a non-streaming A2A message request.

        In sync mode (default): Blocks until workflow completes, returns Message.
        In async mode: Returns Task immediately, processes in background.
        """
        if self.async_mode:
            return await self._handle_async(params)
        return await self._handle_sync(params)

    async def _handle_sync(self, params: MessageSendParams) -> Message:
        """Handle request synchronously - blocks until workflow completes."""
        framework_input = await self.to_framework(params)
        framework_output = await self.call_framework(framework_input, params)
        result = await self.from_framework(framework_output, params)

        # In sync mode, always return Message (from_framework may be overridden
        # to return a Task; unwrap it defensively).
        if isinstance(result, Task):
            if result.status and result.status.message:
                return result.status.message
            return Message(
                role=Role.agent,
                message_id=str(uuid.uuid4()),
                context_id=result.context_id,
                parts=[Part(root=TextPart(text="Workflow completed"))],
            )
        return result

    async def _handle_async(self, params: MessageSendParams) -> Task:
        """
        Handle request asynchronously - returns Task immediately, processes in background.
        """
        # Generate IDs
        task_id = str(uuid.uuid4())
        context_id = self._extract_context_id(params) or str(uuid.uuid4())

        # Create initial task with "working" state
        now = datetime.now(timezone.utc).isoformat()
        task = Task(
            id=task_id,
            context_id=context_id,
            status=TaskStatus(
                state=TaskState.working,
                timestamp=now,
            ),
            history=self._initial_history(params) or None,
        )

        # Save initial task state
        await self.task_store.save(task)
        logger.debug("Created async task %s with state=working", task_id)

        # Start background processing with timeout
        bg_task = asyncio.create_task(
            self._execute_workflow_with_timeout(task_id, context_id, params)
        )
        self._background_tasks[task_id] = bg_task

        # Clean up background task reference when done
        def _on_task_done(t: "asyncio.Task[None]") -> None:
            self._background_tasks.pop(task_id, None)
            self._cancelled_tasks.discard(task_id)
            if not t.cancelled():
                exc = t.exception()
                if exc:
                    logger.error(
                        "Unhandled exception in background task %s: %s",
                        task_id,
                        exc,
                        exc_info=exc,
                    )

        bg_task.add_done_callback(_on_task_done)

        return task

    async def _execute_workflow_with_timeout(
        self,
        task_id: str,
        context_id: str,
        params: MessageSendParams,
    ) -> None:
        """
        Execute the workflow with a timeout wrapper.

        On timeout, the task is saved as failed (keeping the initial message in
        its history) unless a cancellation was requested for it.
        """
        try:
            await asyncio.wait_for(
                self._execute_workflow_background(task_id, context_id, params),
                timeout=self.async_timeout,
            )
        except asyncio.TimeoutError:
            if task_id in self._cancelled_tasks:
                logger.debug("Task %s was cancelled, not marking as failed", task_id)
                return

            logger.error("Task %s timed out after %s seconds", task_id, self.async_timeout)
            now = datetime.now(timezone.utc).isoformat()
            error_message = Message(
                role=Role.agent,
                message_id=str(uuid.uuid4()),
                context_id=context_id,
                parts=[Part(root=TextPart(text=f"Workflow timed out after {self.async_timeout} seconds"))],
            )

            timeout_task = Task(
                id=task_id,
                context_id=context_id,
                status=TaskStatus(
                    state=TaskState.failed,
                    message=error_message,
                    timestamp=now,
                ),
                # Keep the request in history, consistent with other task states.
                history=self._initial_history(params) or None,
            )
            await self.task_store.save(timeout_task)

    async def _execute_workflow_background(
        self,
        task_id: str,
        context_id: str,
        params: MessageSendParams,
    ) -> None:
        """
        Execute the LangGraph workflow in the background and update task state.

        Saves the task as completed (with request + response history) on
        success, or as failed on an exception, unless cancellation was requested.
        CancelledError is re-raised so asyncio sees the task as cancelled.
        """
        try:
            logger.debug("Starting background execution for task %s", task_id)

            # Execute the workflow
            framework_input = await self.to_framework(params)
            framework_output = await self.call_framework(framework_input, params)

            # Check if task was cancelled during execution
            if task_id in self._cancelled_tasks:
                logger.debug("Task %s was cancelled during execution", task_id)
                return

            # Convert to message
            response_text = self._extract_output_text(framework_output)
            response_message = Message(
                role=Role.agent,
                message_id=str(uuid.uuid4()),
                context_id=context_id,
                parts=[Part(root=TextPart(text=response_text))],
            )

            # Build history: original request (if any) followed by the response
            history = self._initial_history(params)
            history.append(response_message)

            # Update task to completed state
            now = datetime.now(timezone.utc).isoformat()
            completed_task = Task(
                id=task_id,
                context_id=context_id,
                status=TaskStatus(
                    state=TaskState.completed,
                    message=response_message,
                    timestamp=now,
                ),
                history=history,
            )

            await self.task_store.save(completed_task)
            logger.debug("Task %s completed successfully", task_id)

        except asyncio.CancelledError:
            logger.debug("Task %s was cancelled", task_id)
            raise

        except Exception as e:
            if task_id in self._cancelled_tasks:
                logger.debug("Task %s was cancelled, not marking as failed", task_id)
                return

            logger.error("Task %s failed: %s", task_id, e, exc_info=True)
            now = datetime.now(timezone.utc).isoformat()
            error_message = Message(
                role=Role.agent,
                message_id=str(uuid.uuid4()),
                context_id=context_id,
                parts=[Part(root=TextPart(text=f"Workflow failed: {str(e)}"))],
            )

            failed_task = Task(
                id=task_id,
                context_id=context_id,
                status=TaskStatus(
                    state=TaskState.failed,
                    message=error_message,
                    timestamp=now,
                ),
                # Keep the request in history, consistent with other task states.
                history=self._initial_history(params) or None,
            )

            await self.task_store.save(failed_task)

    # ---------- Input mapping ----------

    async def to_framework(self, params: MessageSendParams) -> Dict[str, Any]:
        """
        Convert A2A message parameters to LangGraph state input.

        Supports two common input patterns:
        1. Messages-based: {"messages": [{"role": "user", "content": "..."}]}
        2. Simple input: {"input": "user message text"}

        Args:
            params: A2A message parameters

        Returns:
            Dictionary with graph state input
        """
        user_message = ""

        # Extract message from A2A params (new format with message.parts)
        if hasattr(params, "message") and params.message:
            msg = params.message
            if hasattr(msg, "parts") and msg.parts:
                text_parts = []
                for part in msg.parts:
                    # Handle Part(root=TextPart(...)) structure
                    if hasattr(part, "root") and hasattr(part.root, "text"):
                        text_parts.append(part.root.text)
                    # Handle direct TextPart
                    elif hasattr(part, "text"):
                        text_parts.append(part.text)
                    # Non-text parts (files, data) are ignored.
                user_message = self._join_text_parts(text_parts)

        # Legacy support for messages array (deprecated)
        elif getattr(params, "messages", None):
            last = params.messages[-1]
            content = getattr(last, "content", "")
            if isinstance(content, str):
                user_message = content.strip()
            elif isinstance(content, list):
                text_parts = []
                for item in content:
                    txt = getattr(item, "text", None)
                    if txt and isinstance(txt, str) and txt.strip():
                        text_parts.append(txt.strip())
                user_message = self._join_text_parts(text_parts)

        # Build graph input based on input_key
        if self.input_key == "messages":
            # LangGraph message format (for chat-like workflows)
            # Try to use LangChain message format if available
            try:
                from langchain_core.messages import HumanMessage
                return {"messages": [HumanMessage(content=user_message)]}
            except ImportError:
                # Fallback to dict format
                return {"messages": [{"role": "user", "content": user_message}]}
        # Simple input key (e.g., "input", "query", etc.)
        return {self.input_key: user_message}

    @staticmethod
    def _join_text_parts(parts: list[str]) -> str:
        """
        Join text parts into a single space-separated string.

        Each part is stripped; parts that are empty or whitespace-only are
        dropped so they do not introduce doubled spaces.
        """
        stripped = (p.strip() for p in parts if p)
        return " ".join(s for s in stripped if s)

    def _extract_context_id(self, params: MessageSendParams) -> str | None:
        """Extract context_id from MessageSendParams."""
        if hasattr(params, "message") and params.message:
            return getattr(params.message, "context_id", None)
        return None

    @staticmethod
    def _initial_history(params: MessageSendParams) -> list[Message]:
        """Return a new list holding the incoming request message, or [] if absent."""
        message = getattr(params, "message", None)
        return [message] if message else []

    # ---------- Framework call ----------

    async def call_framework(
        self, framework_input: Dict[str, Any], params: MessageSendParams
    ) -> Dict[str, Any]:
        """
        Execute the LangGraph workflow with the provided input.

        Args:
            framework_input: Input state dictionary for the graph
            params: Original A2A parameters (for context)

        Returns:
            Final state from the graph execution

        Raises:
            Exception: If graph execution fails
        """
        logger.debug("Invoking LangGraph with input: %s", framework_input)
        result = await self.graph.ainvoke(framework_input)
        logger.debug("LangGraph returned state with keys: %s", list(result.keys()) if isinstance(result, dict) else type(result).__name__)
        return result

    # ---------- Output mapping ----------

    async def from_framework(
        self, framework_output: Dict[str, Any], params: MessageSendParams
    ) -> Message | Task:
        """
        Convert LangGraph final state to A2A Message.

        Args:
            framework_output: Final state from graph execution
            params: Original A2A parameters

        Returns:
            A2A Message with the workflow's response
        """
        response_text = self._extract_output_text(framework_output)
        context_id = self._extract_context_id(params)

        return Message(
            role=Role.agent,
            message_id=str(uuid.uuid4()),
            context_id=context_id,
            parts=[Part(root=TextPart(text=response_text))],
        )

    def _extract_output_text(self, framework_output: Any) -> str:
        """
        Extract text content from LangGraph state.

        Handles various output patterns:
        - output_key specified: extract that key
        - "messages" key: extract last message content
        - Common keys: "output", "response", "result", "answer", "text", "content"
        - Otherwise: JSON dump of the state without "_"-prefixed keys

        Args:
            framework_output: Final state from the graph

        Returns:
            Extracted text string
        """
        if not isinstance(framework_output, dict):
            return str(framework_output)

        # Use output_key if specified
        if self.output_key and self.output_key in framework_output:
            return self._extract_value_text(framework_output[self.output_key])

        # Try "messages" key (common in chat workflows)
        if "messages" in framework_output:
            messages = framework_output["messages"]
            if messages:
                return self._extract_message_content(messages[-1])

        # Try common output keys
        for key in ("output", "response", "result", "answer", "text", "content"):
            if key in framework_output:
                return self._extract_value_text(framework_output[key])

        # Fallback: serialize entire state (excluding internal keys)
        clean_state = {k: v for k, v in framework_output.items() if not k.startswith("_")}
        return json.dumps(clean_state, indent=2, default=str)

    def _extract_value_text(self, value: Any) -> str:
        """Extract text from a value (handles strings, dicts, lists)."""
        if isinstance(value, str):
            return value
        if isinstance(value, dict):
            # Try common text keys in dict
            for key in ("text", "content", "output"):
                if key in value:
                    return str(value[key])
            return json.dumps(value, indent=2, default=str)
        if isinstance(value, list):
            # Join list items
            return "\n".join(self._extract_value_text(item) for item in value)
        return str(value)

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """
        Flatten a message ``content`` value to plain text.

        Strings are returned as-is. Lists are treated as content blocks: plain
        strings, dict blocks carrying a string "text" entry (e.g.
        ``{"type": "text", "text": "..."}``), and objects with a ``text``
        attribute contribute their text, joined with single spaces; blocks
        without text (e.g. tool-use or image blocks) are skipped. Anything
        else is converted with ``str()``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            text_parts = []
            for item in content:
                if isinstance(item, str):
                    text_parts.append(item)
                elif isinstance(item, dict):
                    text = item.get("text")
                    if isinstance(text, str):
                        text_parts.append(text)
                elif hasattr(item, "text"):
                    text_parts.append(item.text)
            return " ".join(text_parts)
        return str(content)

    def _extract_message_content(self, message: Any) -> str:
        """Extract content from a message object (LangChain or dict)."""
        # LangChain message with content attribute
        if hasattr(message, "content"):
            return self._content_to_text(message.content)

        # Dict message
        if isinstance(message, dict):
            if "content" in message:
                return self._content_to_text(message["content"])
            if "text" in message:
                return str(message["text"])

        return str(message)

    # ---------- Streaming support ----------

    async def handle_stream(
        self, params: MessageSendParams
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Handle a streaming A2A message request.

        Uses LangGraph's astream() to yield intermediate results as the
        workflow executes. Full-state snapshots (stream_mode="values") are
        requested so intermediate and final text are extracted from the same
        state shape that ainvoke() returns.

        Args:
            params: A2A message parameters

        Yields:
            Server-Sent Events compatible dictionaries: zero or more
            ``{"event": "message"}`` content deltas, then one
            ``{"event": "done"}`` event carrying the final message.
        """
        framework_input = await self.to_framework(params)
        context_id = self._extract_context_id(params)
        message_id = str(uuid.uuid4())

        logger.debug("Starting LangGraph stream with input: %s", framework_input)

        # Graphs built by StateGraph.compile() default to stream_mode="updates",
        # which yields {node_name: partial_update} chunks that the state-based
        # extraction below cannot read. Ask for full-state snapshots instead and
        # fall back to the default for astream() implementations that do not
        # accept a stream_mode argument.
        try:
            stream = self.graph.astream(framework_input, stream_mode="values")
        except TypeError:
            stream = self.graph.astream(framework_input)

        # In "values" mode the first snapshot is typically the input state;
        # remember its text so the user's own message is not echoed back.
        input_text = self._extract_streaming_text(framework_input)

        accumulated_text = ""
        last_state = None

        # Stream from LangGraph
        async for state in stream:
            last_state = state

            # Extract text from current state
            text = self._extract_streaming_text(state)

            if not accumulated_text and text == input_text:
                continue

            if text and text != accumulated_text:
                # Calculate the new chunk (delta)
                new_content = text[len(accumulated_text):] if text.startswith(accumulated_text) else text
                accumulated_text = text

                if new_content:
                    yield {
                        "event": "message",
                        "data": json.dumps({
                            "type": "content",
                            "content": new_content,
                        }),
                    }

        # Use final state if we have it, otherwise use accumulated
        final_text = self._extract_output_text(last_state) if last_state else accumulated_text

        # Send final message with complete response
        final_message = Message(
            role=Role.agent,
            message_id=message_id,
            context_id=context_id,
            parts=[Part(root=TextPart(text=final_text))],
        )

        # Send completion event
        yield {
            "event": "done",
            "data": json.dumps({
                "status": "completed",
                "message": final_message.model_dump() if hasattr(final_message, "model_dump") else str(final_message),
            }),
        }

        logger.debug("LangGraph stream completed")

    def _extract_streaming_text(self, state: Any) -> str:
        """
        Extract text from intermediate streaming state.

        Args:
            state: Intermediate state from astream()

        Returns:
            Current text content, or "" if no recognised key is present
        """
        if not isinstance(state, dict):
            return str(state)

        # Use state_key if specified
        if self.state_key and self.state_key in state:
            return self._extract_value_text(state[self.state_key])

        # Try messages (for chat workflows)
        if "messages" in state:
            messages = state["messages"]
            if messages:
                # Get content from last message
                return self._extract_message_content(messages[-1])

        # Try common keys
        for key in ("output", "response", "text", "content"):
            if key in state:
                return self._extract_value_text(state[key])

        return ""

    def supports_streaming(self) -> bool:
        """
        Check if the graph supports streaming.

        Returns:
            True if the graph has an astream method
        """
        return hasattr(self.graph, "astream")

    # ---------- Async Task Support ----------

    def supports_async_tasks(self) -> bool:
        """Check if this adapter supports async task execution."""
        return self.async_mode

    async def get_task(self, task_id: str) -> Task | None:
        """
        Get the current status of a task by ID.

        Args:
            task_id: The ID of the task to retrieve

        Returns:
            The Task object with current status, or None if not found

        Raises:
            RuntimeError: If async mode is not enabled
        """
        if not self.async_mode:
            raise RuntimeError(
                "get_task() is only available in async mode. "
                "Initialize adapter with async_mode=True"
            )

        task = await self.task_store.get(task_id)
        if task:
            logger.debug("Retrieved task %s with state=%s", task_id, task.status.state)
        else:
            logger.debug("Task %s not found", task_id)
        return task

    async def cancel_task(self, task_id: str) -> Task | None:
        """
        Attempt to cancel a running task.

        Args:
            task_id: The ID of the task to cancel

        Returns:
            The updated Task object with state="canceled"; the stored Task
            unchanged if it had already reached a terminal state (completed,
            failed, canceled, rejected); or None if not found.

        Raises:
            RuntimeError: If async mode is not enabled
        """
        if not self.async_mode:
            raise RuntimeError(
                "cancel_task() is only available in async mode. "
                "Initialize adapter with async_mode=True"
            )

        # Flag first so the background coroutine does not record a
        # completed/failed state after the cancellation request.
        self._cancelled_tasks.add(task_id)
        try:
            # Cancel the background task if still running
            bg_task = self._background_tasks.get(task_id)
            if bg_task is not None and not bg_task.done():
                bg_task.cancel()
                logger.debug("Cancelling background task for %s", task_id)
                # asyncio.wait() waits for completion without re-raising the
                # task's CancelledError, so cancellation of *this* coroutine
                # still propagates to its caller.
                await asyncio.wait({bg_task})

            task = await self.task_store.get(task_id)
            if task is None:
                return None

            # A finished task must not be rewritten as canceled.
            state = getattr(task.status.state, "value", task.status.state)
            if state in self._TERMINAL_STATES:
                logger.debug("Task %s already in terminal state=%s; not canceling", task_id, state)
                return task

            # Update task state to canceled
            now = datetime.now(timezone.utc).isoformat()
            canceled_task = Task(
                id=task_id,
                context_id=task.context_id,
                status=TaskStatus(
                    state=TaskState.canceled,
                    timestamp=now,
                ),
                history=task.history,
            )
            await self.task_store.save(canceled_task)
            logger.debug("Task %s marked as canceled", task_id)
            return canceled_task
        finally:
            # Once no background work is left for this ID the flag is no longer
            # needed; without this, cancelling an unknown or already-finished
            # task would leave its ID in the set forever.
            remaining = self._background_tasks.get(task_id)
            if remaining is None or remaining.done():
                self._cancelled_tasks.discard(task_id)

    # ---------- Lifecycle ----------

    async def close(self) -> None:
        """
        Cancel pending background tasks and wait for them to finish.

        NOTE(review): task states in the store are not updated here, so tasks
        interrupted by close() remain recorded as "working".
        """
        for task_id in self._background_tasks:
            self._cancelled_tasks.add(task_id)

        tasks_to_cancel = []
        for task_id, bg_task in list(self._background_tasks.items()):
            if not bg_task.done():
                bg_task.cancel()
                tasks_to_cancel.append(bg_task)
                logger.debug("Cancelling background task %s during close", task_id)

        if tasks_to_cancel:
            await asyncio.gather(*tasks_to_cancel, return_exceptions=True)

        self._background_tasks.clear()
        self._cancelled_tasks.clear()

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: cancels pending background tasks."""
        await self.close()
|