a2a-adapter 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff shows the changes between the contents of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,22 +1,44 @@
1
1
  """
2
2
  LangChain adapter for A2A Protocol.
3
3
 
4
- This adapter enables LangChain runnables (chains, agents) to be exposed
5
- as A2A-compliant agents with support for both streaming and non-streaming modes.
4
+ This adapter enables LangChain runnables (chains, agents, RAG pipelines) to be
5
+ exposed as A2A-compliant agents with support for both streaming and non-streaming modes.
6
6
  """
7
7
 
8
8
  import json
9
+ import logging
10
+ import uuid
9
11
  from typing import Any, AsyncIterator, Dict
10
12
 
11
- from a2a.types import Message, MessageSendParams, Task, TextPart
13
+ from a2a.types import (
14
+ Message,
15
+ MessageSendParams,
16
+ Task,
17
+ TextPart,
18
+ Role,
19
+ Part,
20
+ )
21
+ from ..adapter import BaseAgentAdapter
12
22
 
23
+ logger = logging.getLogger(__name__)
13
24
 
14
- class LangChainAgentAdapter:
25
+
26
+ class LangChainAgentAdapter(BaseAgentAdapter):
15
27
  """
16
28
  Adapter for integrating LangChain runnables as A2A agents.
17
-
18
- This adapter works with any LangChain Runnable (chains, agents, etc.)
29
+
30
+ This adapter works with any LangChain Runnable (chains, agents, RAG pipelines)
19
31
  and supports both streaming and non-streaming execution modes.
32
+
33
+ Example:
34
+ >>> from langchain_openai import ChatOpenAI
35
+ >>> from langchain_core.prompts import ChatPromptTemplate
36
+ >>>
37
+ >>> llm = ChatOpenAI(model="gpt-4o-mini")
38
+ >>> prompt = ChatPromptTemplate.from_template("Answer: {input}")
39
+ >>> chain = prompt | llm
40
+ >>>
41
+ >>> adapter = LangChainAgentAdapter(runnable=chain, input_key="input")
20
42
  """
21
43
 
22
44
  def __init__(
@@ -27,145 +49,281 @@ class LangChainAgentAdapter:
27
49
  ):
28
50
  """
29
51
  Initialize the LangChain adapter.
30
-
52
+
31
53
  Args:
32
54
  runnable: A LangChain Runnable instance (chain, agent, etc.)
33
55
  input_key: The key name for passing input to the runnable (default: "input")
34
- output_key: Optional key to extract from runnable output. If None, uses the entire output.
56
+ output_key: Optional key to extract from runnable output. If None,
57
+ the adapter will attempt to extract text intelligently.
35
58
  """
36
59
  self.runnable = runnable
37
60
  self.input_key = input_key
38
61
  self.output_key = output_key
39
62
 
40
- async def handle(self, params: MessageSendParams) -> Message | Task:
41
- """Handle a non-streaming A2A message request."""
42
- framework_input = await self.to_framework(params)
43
- framework_output = await self.call_framework(framework_input, params)
44
- return await self.from_framework(framework_output, params)
45
-
46
- async def handle_stream(
47
- self, params: MessageSendParams
48
- ) -> AsyncIterator[Dict[str, Any]]:
49
- """
50
- Handle a streaming A2A message request.
51
-
52
- Yields Server-Sent Events compatible dictionaries with streaming chunks.
53
- """
54
- framework_input = await self.to_framework(params)
55
-
56
- # Stream from LangChain runnable
57
- async for chunk in self.runnable.astream(framework_input):
58
- # Extract text from chunk
59
- if hasattr(chunk, "content"):
60
- text = chunk.content
61
- elif isinstance(chunk, dict):
62
- text = chunk.get(self.output_key or "output", str(chunk))
63
- else:
64
- text = str(chunk)
65
-
66
- # Yield SSE-compatible event
67
- if text:
68
- yield {
69
- "event": "message",
70
- "data": json.dumps({
71
- "type": "content",
72
- "content": text,
73
- }),
74
- }
75
-
76
- # Send completion event
77
- yield {
78
- "event": "done",
79
- "data": json.dumps({"status": "completed"}),
80
- }
81
-
82
- def supports_streaming(self) -> bool:
83
- """Check if the runnable supports streaming."""
84
- return hasattr(self.runnable, "astream")
63
+ # ---------- Input mapping ----------
85
64
 
86
65
  async def to_framework(self, params: MessageSendParams) -> Dict[str, Any]:
87
66
  """
88
67
  Convert A2A message parameters to LangChain runnable input.
89
-
68
+
69
+ Extracts the user's message text and formats it for the runnable.
70
+
90
71
  Args:
91
72
  params: A2A message parameters
92
-
73
+
93
74
  Returns:
94
75
  Dictionary with runnable input data
95
76
  """
96
- # Extract text from the last user message
97
77
  user_message = ""
98
- if params.messages:
99
- last_message = params.messages[-1]
100
- if hasattr(last_message, "content"):
101
- if isinstance(last_message.content, list):
102
- # Extract text from content blocks
103
- text_parts = [
104
- item.text
105
- for item in last_message.content
106
- if hasattr(item, "text")
107
- ]
108
- user_message = " ".join(text_parts)
109
- elif isinstance(last_message.content, str):
110
- user_message = last_message.content
78
+
79
+ # Extract message from A2A params (new format with message.parts)
80
+ if hasattr(params, "message") and params.message:
81
+ msg = params.message
82
+ if hasattr(msg, "parts") and msg.parts:
83
+ text_parts = []
84
+ for part in msg.parts:
85
+ # Handle Part(root=TextPart(...)) structure
86
+ if hasattr(part, "root") and hasattr(part.root, "text"):
87
+ text_parts.append(part.root.text)
88
+ # Handle direct TextPart
89
+ elif hasattr(part, "text"):
90
+ text_parts.append(part.text)
91
+ user_message = self._join_text_parts(text_parts)
92
+
93
+ # Legacy support for messages array (deprecated)
94
+ elif getattr(params, "messages", None):
95
+ last = params.messages[-1]
96
+ content = getattr(last, "content", "")
97
+ if isinstance(content, str):
98
+ user_message = content.strip()
99
+ elif isinstance(content, list):
100
+ text_parts = []
101
+ for item in content:
102
+ txt = getattr(item, "text", None)
103
+ if txt and isinstance(txt, str) and txt.strip():
104
+ text_parts.append(txt.strip())
105
+ user_message = self._join_text_parts(text_parts)
111
106
 
112
107
  # Build runnable input
113
108
  return {
114
109
  self.input_key: user_message,
115
110
  }
116
111
 
112
+ @staticmethod
113
+ def _join_text_parts(parts: list[str]) -> str:
114
+ """Join text parts into a single string."""
115
+ if not parts:
116
+ return ""
117
+ text = " ".join(p.strip() for p in parts if p)
118
+ return text.strip()
119
+
120
+ # ---------- Framework call ----------
121
+
117
122
  async def call_framework(
118
123
  self, framework_input: Dict[str, Any], params: MessageSendParams
119
124
  ) -> Any:
120
125
  """
121
126
  Execute the LangChain runnable with the provided input.
122
-
127
+
123
128
  Args:
124
129
  framework_input: Input dictionary for the runnable
125
130
  params: Original A2A parameters (for context)
126
-
131
+
127
132
  Returns:
128
133
  Runnable execution output
129
-
134
+
130
135
  Raises:
131
136
  Exception: If runnable execution fails
132
137
  """
138
+ logger.debug("Invoking LangChain runnable with input: %s", framework_input)
133
139
  result = await self.runnable.ainvoke(framework_input)
140
+ logger.debug("LangChain runnable returned: %s", type(result).__name__)
134
141
  return result
135
142
 
143
+ # ---------- Output mapping ----------
144
+
136
145
  async def from_framework(
137
146
  self, framework_output: Any, params: MessageSendParams
138
147
  ) -> Message | Task:
139
148
  """
140
149
  Convert LangChain runnable output to A2A Message.
141
-
150
+
151
+ Handles various LangChain output types:
152
+ - AIMessage: Extract content attribute
153
+ - Dict: Extract using output_key or serialize
154
+ - String: Use directly
155
+
142
156
  Args:
143
157
  framework_output: Output from runnable execution
144
158
  params: Original A2A parameters
145
-
159
+
146
160
  Returns:
147
161
  A2A Message with the runnable's response
148
162
  """
149
- # Extract output based on type
163
+ response_text = self._extract_output_text(framework_output)
164
+
165
+ # Preserve context_id from the request for multi-turn conversation tracking
166
+ context_id = self._extract_context_id(params)
167
+
168
+ return Message(
169
+ role=Role.agent,
170
+ message_id=str(uuid.uuid4()),
171
+ context_id=context_id,
172
+ parts=[Part(root=TextPart(text=response_text))],
173
+ )
174
+
175
+ def _extract_output_text(self, framework_output: Any) -> str:
176
+ """
177
+ Extract text content from LangChain runnable output.
178
+
179
+ Args:
180
+ framework_output: Output from the runnable
181
+
182
+ Returns:
183
+ Extracted text string
184
+ """
185
+ # AIMessage or similar with content attribute
150
186
  if hasattr(framework_output, "content"):
151
- # AIMessage or similar
152
- response_text = framework_output.content
153
- elif isinstance(framework_output, dict):
154
- # Dictionary output - extract using output_key or serialize
187
+ content = framework_output.content
188
+ if isinstance(content, str):
189
+ return content
190
+ elif isinstance(content, list):
191
+ # Handle list of content blocks (multimodal)
192
+ text_parts = []
193
+ for item in content:
194
+ if isinstance(item, str):
195
+ text_parts.append(item)
196
+ elif hasattr(item, "text"):
197
+ text_parts.append(item.text)
198
+ elif isinstance(item, dict) and "text" in item:
199
+ text_parts.append(item["text"])
200
+ return " ".join(text_parts)
201
+ return str(content)
202
+
203
+ # Dictionary output - extract using output_key or serialize
204
+ if isinstance(framework_output, dict):
155
205
  if self.output_key and self.output_key in framework_output:
156
- response_text = str(framework_output[self.output_key])
157
- else:
158
- response_text = json.dumps(framework_output, indent=2)
159
- else:
160
- # String or other type - convert to string
161
- response_text = str(framework_output)
206
+ return str(framework_output[self.output_key])
207
+ # Try common output keys
208
+ for key in ["output", "result", "answer", "response", "text"]:
209
+ if key in framework_output:
210
+ return str(framework_output[key])
211
+ # Fallback: serialize as JSON
212
+ return json.dumps(framework_output, indent=2)
162
213
 
163
- return Message(
164
- role="assistant",
165
- content=[TextPart(type="text", text=response_text)],
214
+ # String or other type - convert to string
215
+ return str(framework_output)
216
+
217
+ def _extract_context_id(self, params: MessageSendParams) -> str | None:
218
+ """Extract context_id from MessageSendParams."""
219
+ if hasattr(params, "message") and params.message:
220
+ return getattr(params.message, "context_id", None)
221
+ return None
222
+
223
+ # ---------- Streaming support ----------
224
+
225
+ async def handle_stream(
226
+ self, params: MessageSendParams
227
+ ) -> AsyncIterator[Dict[str, Any]]:
228
+ """
229
+ Handle a streaming A2A message request.
230
+
231
+ Uses LangChain's astream() method to yield tokens as they are generated.
232
+
233
+ Args:
234
+ params: A2A message parameters
235
+
236
+ Yields:
237
+ Server-Sent Events compatible dictionaries with streaming chunks
238
+ """
239
+ framework_input = await self.to_framework(params)
240
+ context_id = self._extract_context_id(params)
241
+ message_id = str(uuid.uuid4())
242
+
243
+ logger.debug("Starting LangChain stream with input: %s", framework_input)
244
+
245
+ accumulated_text = ""
246
+
247
+ # Stream from LangChain runnable
248
+ async for chunk in self.runnable.astream(framework_input):
249
+ # Extract text from chunk
250
+ text = self._extract_chunk_text(chunk)
251
+
252
+ if text:
253
+ accumulated_text += text
254
+ # Yield SSE-compatible event
255
+ yield {
256
+ "event": "message",
257
+ "data": json.dumps({
258
+ "type": "content",
259
+ "content": text,
260
+ }),
261
+ }
262
+
263
+ # Send final message with complete response
264
+ final_message = Message(
265
+ role=Role.agent,
266
+ message_id=message_id,
267
+ context_id=context_id,
268
+ parts=[Part(root=TextPart(text=accumulated_text))],
166
269
  )
167
270
 
271
+ # Send completion event
272
+ yield {
273
+ "event": "done",
274
+ "data": json.dumps({
275
+ "status": "completed",
276
+ "message": final_message.model_dump() if hasattr(final_message, "model_dump") else str(final_message),
277
+ }),
278
+ }
279
+
280
+ logger.debug("LangChain stream completed")
281
+
282
+ def _extract_chunk_text(self, chunk: Any) -> str:
283
+ """
284
+ Extract text from a streaming chunk.
285
+
286
+ Args:
287
+ chunk: A streaming chunk from LangChain
288
+
289
+ Returns:
290
+ Extracted text string
291
+ """
292
+ # AIMessageChunk or similar
293
+ if hasattr(chunk, "content"):
294
+ content = chunk.content
295
+ if isinstance(content, str):
296
+ return content
297
+ elif isinstance(content, list):
298
+ text_parts = []
299
+ for item in content:
300
+ if isinstance(item, str):
301
+ text_parts.append(item)
302
+ elif hasattr(item, "text"):
303
+ text_parts.append(item.text)
304
+ return "".join(text_parts)
305
+ return str(content) if content else ""
306
+
307
+ # Dictionary chunk
308
+ if isinstance(chunk, dict):
309
+ if self.output_key and self.output_key in chunk:
310
+ return str(chunk[self.output_key])
311
+ for key in ["output", "result", "content", "text"]:
312
+ if key in chunk:
313
+ return str(chunk[key])
314
+ return ""
315
+
316
+ # String chunk
317
+ if isinstance(chunk, str):
318
+ return chunk
319
+
320
+ return ""
321
+
168
322
  def supports_streaming(self) -> bool:
169
- """Check if this adapter supports streaming responses."""
170
- return False
323
+ """
324
+ Check if the runnable supports streaming.
171
325
 
326
+ Returns:
327
+ True if the runnable has an astream method
328
+ """
329
+ return hasattr(self.runnable, "astream")