aip-agents-binary 0.5.25b9__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,14 +18,14 @@ from collections.abc import Awaitable, Callable, Sequence
  from dataclasses import asdict, dataclass
  from functools import reduce
  from textwrap import dedent
- from typing import TYPE_CHECKING, Annotated, Any
+ from typing import TYPE_CHECKING, Annotated, Any, cast

- from deprecated import deprecated
+ from deprecated import deprecated  # type: ignore[import-untyped]

  if TYPE_CHECKING:
      from aip_agents.guardrails.manager import GuardrailManager
-     from gllm_core.event import EventEmitter
-     from gllm_core.schema import Chunk
+     from gllm_core.event import EventEmitter  # type: ignore[import-untyped]
+     from gllm_core.schema import Chunk  # type: ignore[import-untyped]
      from langchain_core.language_models import BaseChatModel
      from langchain_core.messages import (
          AIMessage,
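
The scoped `# type: ignore[import-untyped]` comments silence mypy's complaint that these third-party packages ship without type stubs, without hiding other errors on the same line. A minimal sketch of the pattern, using a hypothetical untyped package name:

```python
# mypy reports error code "import-untyped" for packages that ship no py.typed
# marker or stub files; the bracketed form suppresses only that error code.
import some_untyped_package  # type: ignore[import-untyped]  # hypothetical name
```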
@@ -212,6 +212,12 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):
              **kwargs,
          )

+         if self.model is None and self.lm_invoker is None:
+             logger.warning(
+                 "Agent '%s': Model and LM invoker are both unset. Calls that require a model will fail.",
+                 self.name,
+             )
+
          # Handle tool output management
          self.tool_output_manager = tool_output_manager
          self._pii_handlers_by_thread: dict[str, ToolPIIHandler] = {}
@@ -256,7 +262,7 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):

          # Auto-configure TodoListMiddleware if planning enabled
          if planning:
-             middleware_list.append(TodoListMiddleware())
+             middleware_list.append(cast(AgentMiddleware, TodoListMiddleware()))

          # Auto-configure GuardrailMiddleware if guardrail provided
          if guardrail:
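
`typing.cast` performs no runtime conversion or check; it only asserts a type to the static checker, which is useful when a concrete middleware class is not declared as a subtype of the list's element type. A minimal sketch with stand-in classes:

```python
from typing import cast

class AgentMiddleware: ...      # stand-in for the real base class
class TodoListMiddleware: ...   # stand-in; not declared as a subtype here

middleware_list: list[AgentMiddleware] = []
# cast() is a no-op at runtime: it returns its second argument unchanged
# and merely tells the type checker to treat it as AgentMiddleware.
middleware_list.append(cast(AgentMiddleware, TodoListMiddleware()))
```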
@@ -732,7 +738,7 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):
          pending_artifacts: list[dict[str, Any]] = state.get("artifacts") or []
          reference_updates: list[Chunk] = []
          tool_map = {tool.name: tool for tool in self.resolved_tools}
-         pii_mapping = {}
+         pii_mapping: dict[str, str] = {}

          aggregated_metadata_delta: dict[str, Any] = {}
          total_tools_token_usage: list[UsageMetadata] = []
@@ -756,7 +762,8 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):
              ),
          )

-         tasks = [asyncio.create_task(run_tool(tc)) for tc in last_message.tool_calls]
+         normalized_tool_calls = [self._normalize_tool_call(tc) for tc in last_message.tool_calls]
+         tasks = [asyncio.create_task(run_tool(tc)) for tc in normalized_tool_calls]

          for coro in asyncio.as_completed(tasks):
              tool_result = await coro
@@ -779,6 +786,31 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):
              pii_mapping,
          )

+     def _normalize_tool_call(self, tool_call: Any) -> dict[str, Any]:
+         """Normalize tool call inputs into a dict with required keys."""
+         if isinstance(tool_call, dict):
+             normalized = dict(tool_call)
+         elif hasattr(tool_call, "model_dump"):
+             normalized = tool_call.model_dump()
+         elif hasattr(tool_call, "dict"):
+             normalized = tool_call.dict()
+         elif hasattr(tool_call, "name") and hasattr(tool_call, "args"):
+             normalized = {
+                 "id": getattr(tool_call, "id", None),
+                 "name": getattr(tool_call, "name", None),
+                 "args": getattr(tool_call, "args", None),
+             }
+         else:
+             raise TypeError("Tool call must be a dict-like object or ToolCall instance.")
+
+         if not isinstance(normalized, dict):
+             raise TypeError("Tool call normalization did not produce a dict.")
+
+         if "name" not in normalized or "args" not in normalized:
+             raise TypeError("Tool call must include 'name' and 'args' fields.")
+
+         return normalized
+
      def _accumulate_tool_result(  # noqa: PLR0913
          self,
          tool_result: Any,
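
The new `_normalize_tool_call` accepts three input shapes: plain dicts, Pydantic models (v2 `model_dump` or v1 `dict`), and plain attribute-style objects. A condensed, standalone sketch of how each shape normalizes, with hypothetical inputs:

```python
from typing import Any

class AttrStyleToolCall:
    """Hypothetical attribute-style tool call (no Pydantic methods)."""
    id = "call_1"
    name = "search"
    args = {"query": "weather"}

def normalize(tool_call: Any) -> dict[str, Any]:
    # Condensed version of the branching above.
    if isinstance(tool_call, dict):
        return dict(tool_call)
    if hasattr(tool_call, "model_dump"):   # Pydantic v2
        return tool_call.model_dump()
    if hasattr(tool_call, "dict"):         # Pydantic v1
        return tool_call.dict()
    return {
        "id": getattr(tool_call, "id", None),
        "name": getattr(tool_call, "name", None),
        "args": getattr(tool_call, "args", None),
    }

print(normalize({"name": "search", "args": {"query": "weather"}}))
print(normalize(AttrStyleToolCall()))  # both yield dicts with 'name' and 'args'
```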
@@ -787,7 +819,7 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):
          aggregated_metadata_delta: dict[str, Any],
          reference_updates: list[Chunk],
          total_tools_token_usage: list[UsageMetadata],
-         pii_mapping: dict[str, str] | None,
+         pii_mapping: dict[str, str],
      ) -> None:  # noqa: PLR0913
          """Accumulate results from a single tool call.

@@ -1233,13 +1265,16 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):

          # Create enhanced tool configuration with output management
          tool_config = self._create_enhanced_tool_config(config, state, tool_call["name"], tool_call_id)
+         if not isinstance(tool_config, dict):
+             raise TypeError("Tool configuration must be a dictionary.")
+         tool_config_runnable = tool_config

          arun_streaming_method = getattr(tool, TOOL_RUN_STREAMING_METHOD, None)

          if arun_streaming_method and callable(arun_streaming_method):
              tool_output = await self._execute_tool_with_streaming(tool, tool_call, tool_config)
          else:
-             tool_output = await tool.ainvoke(resolved_args, tool_config)
+             tool_output = await tool.ainvoke(resolved_args, tool_config_runnable)

          references = extract_references_from_tool(tool, tool_output)

@@ -1513,7 +1548,7 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):
          tool_call: dict[str, Any],
          execution_time: float,
          pending_artifacts: list[dict[str, Any]],
-     ) -> tuple[list[BaseMessage], list[dict[str, Any]], dict[str, Any]]:
+     ) -> tuple[list[ToolMessage], list[dict[str, Any]], dict[str, Any]]:
          """Process tool output into messages, artifacts, and metadata.

          Args:
@@ -1541,7 +1576,7 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):

      def _handle_command_output(
          self, tool_output: Command, tool_call: dict[str, Any], execution_time: float, metadata_delta: dict[str, Any]
-     ) -> tuple[list[BaseMessage], list[dict[str, Any]], dict[str, Any]]:
+     ) -> tuple[list[ToolMessage], list[dict[str, Any]], dict[str, Any]]:
          """Handle Command type tool outputs.

          Args:
@@ -1570,7 +1605,7 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):

      def _handle_string_output(
          self, tool_output: str, tool_call: dict[str, Any], execution_time: float
-     ) -> tuple[list[BaseMessage], list[dict[str, Any]], dict[str, Any]]:
+     ) -> tuple[list[ToolMessage], list[dict[str, Any]], dict[str, Any]]:
          """Handle string type tool outputs.

          Args:
@@ -1596,7 +1631,7 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):
          execution_time: float,
          pending_artifacts: list[dict[str, Any]],
          metadata_delta: dict[str, Any],
-     ) -> tuple[list[BaseMessage], list[dict[str, Any]], dict[str, Any]]:
+     ) -> tuple[list[ToolMessage], list[dict[str, Any]], dict[str, Any]]:
          """Handle legacy dict and other tool outputs.

          Args:
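
These four handlers all narrow their declared return type from `list[BaseMessage]` to `list[ToolMessage]`. Since `ToolMessage` subclasses `BaseMessage` in LangChain, callers can now use `ToolMessage`-specific fields such as `tool_call_id` without casts or isinstance checks. A hedged sketch with stand-in classes mirroring that hierarchy:

```python
class BaseMessage:              # stand-ins mirroring the LangChain hierarchy
    content: str = ""

class ToolMessage(BaseMessage):
    tool_call_id: str = ""

def handle_output() -> list[ToolMessage]:
    return [ToolMessage()]

# A list[ToolMessage] still satisfies readers expecting base messages, while
# callers gain direct access to tool_call_id.
first = handle_output()[0]
print(first.tool_call_id)
```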
@@ -1694,8 +1729,11 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):
              self._emit_default_tool_call_event(writer, tool_name, tool_call_id, tool_args)

          streaming_kwargs = self._build_streaming_kwargs(tool_args, tool_config)
+         arun_streaming_method = getattr(tool, TOOL_RUN_STREAMING_METHOD, None)
+         if not callable(arun_streaming_method):
+             raise RuntimeError(f"Tool '{tool_name}' does not implement streaming.")

-         async for chunk in tool.arun_streaming(**streaming_kwargs):
+         async for chunk in arun_streaming_method(**streaming_kwargs):
              final_output, saw_tool_result = self._handle_streaming_chunk(
                  chunk=chunk,
                  writer=writer,
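
Looking the streaming method up via `getattr` and checking `callable` before iterating keeps the dispatch duck-typed while failing loudly for tools that lack streaming. A self-contained sketch, assuming `TOOL_RUN_STREAMING_METHOD` holds the attribute name string (`"arun_streaming"`, per the replaced call above):

```python
import asyncio
from collections.abc import AsyncIterator

TOOL_RUN_STREAMING_METHOD = "arun_streaming"  # assumed value of the constant

class StreamingTool:
    async def arun_streaming(self, query: str) -> AsyncIterator[str]:
        for part in ("hel", "lo"):
            yield part

async def run(tool: object) -> None:
    method = getattr(tool, TOOL_RUN_STREAMING_METHOD, None)
    if not callable(method):
        raise RuntimeError("Tool does not implement streaming.")
    # The bound method is invoked directly, so static checkers don't need
    # the tool's class to declare arun_streaming.
    async for chunk in method(query="hi"):
        print(chunk, end="")

asyncio.run(run(StreamingTool()))
```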
@@ -2125,6 +2163,9 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):

          effective_event_emitter = state.get("event_emitter") or self.event_emitter

+         if self.lm_invoker is None:
+             raise RuntimeError("LM invoker is required for this execution path.")
+
          if self.resolved_tools:
              self.lm_invoker.set_tools(self.resolved_tools)

@@ -2183,6 +2224,9 @@ class LangGraphReactAgent(LangGraphHitLMixin, BaseLangGraphAgent):
          ):
              langchain_prompt = [SystemMessage(content=enhanced_instruction)] + list(current_messages)

+         if self.model is None:
+             raise RuntimeError("Model is required for this execution path.")
+
          model_with_tools = self.model.bind_tools(self.resolved_tools) if self.resolved_tools else self.model

          ai_message = await model_with_tools.ainvoke(langchain_prompt, config)
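
These guards pair with the constructor warning added earlier: since `model` and `lm_invoker` may each be `None`, an explicit check raises a clear error at the point of use and also narrows the attribute's type for static checkers. A minimal sketch with a stand-in model class:

```python
import asyncio

class FakeModel:
    """Stand-in for a chat model with an ainvoke-style API."""

    async def ainvoke(self, prompt: str) -> str:
        return f"echo: {prompt}"

class Agent:
    def __init__(self, model: FakeModel | None = None) -> None:
        self.model = model

    async def call(self, prompt: str) -> str:
        if self.model is None:
            # Fail fast with a targeted message instead of a later AttributeError.
            raise RuntimeError("Model is required for this execution path.")
        # After the guard, type checkers narrow self.model from Optional.
        return await self.model.ainvoke(prompt)

print(asyncio.run(Agent(FakeModel()).call("hi")))  # -> echo: hi
```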
@@ -8,10 +8,10 @@ This script demonstrates:

  Prerequisites:
      Start the server first:
-         uv run python -m aip_agents.examples.compare_streaming_server
+         poetry run python -m aip_agents.examples.compare_streaming_server

      Then run this client:
-         uv run python -m aip_agents.examples.compare_streaming_client
+         poetry run python -m aip_agents.examples.compare_streaming_client

  Authors:
      AI Agent Platform Team
@@ -6,7 +6,7 @@ This server provides an agent with:

  To run this server:
      cd libs/aip_agents
-     uv run python -m aip_agents.examples.compare_streaming_server
+     poetry run python -m aip_agents.examples.compare_streaming_server

  It will listen on http://localhost:18999 by default.

@@ -39,6 +39,15 @@ async def main():
              print(chunk["content"], end="", flush=True)
              if chunk.get("metadata"):
                  print(f"\nMetadata: {chunk['metadata']}", end="\n\n", flush=True)
+             tool_info = chunk.get("metadata", {}).get("tool_info") if isinstance(chunk.get("metadata"), dict) else None
+             if isinstance(tool_info, dict):
+                 for tool_call in tool_info.get("tool_calls", []):
+                     if tool_call.get("name") == "data_visualizer":
+                         data_source = tool_call.get("args", {}).get("data_source")
+                         if not (isinstance(data_source, str) and data_source.startswith("$tool_output.")):
+                             raise RuntimeError(
+                                 "Tool output sharing failed: expected data_source to reference $tool_output.<call_id>."
+                             )
          print("\n")

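The client-side check asserts the tool-output-sharing convention visible in the code: a downstream tool's `data_source` argument should be a `$tool_output.<call_id>` reference to an earlier call's result rather than inlined data. An illustrative payload, with a hypothetical call id:

```python
# Shape inferred from the validation above; "call_abc123" is hypothetical.
tool_call = {
    "name": "data_visualizer",
    "args": {"data_source": "$tool_output.call_abc123"},
}

data_source = tool_call["args"]["data_source"]
assert isinstance(data_source, str) and data_source.startswith("$tool_output.")
```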
@@ -1,10 +1,10 @@
  """A2A client for the planning LangGraphReactAgent.

  Run the planning server first:
-     uv run python -m aip_agents.examples.todolist_planning_a2a_langgraph_server
+     poetry run python -m aip_agents.examples.todolist_planning_a2a_langgraph_server

  Then run this client:
-     uv run python -m aip_agents.examples.todolist_planning_a2a_langchain_client
+     poetry run python -m aip_agents.examples.todolist_planning_a2a_langchain_client

  You should see streaming output, including when write_todos_tool is called.
  """
@@ -1,7 +1,7 @@
  """A2A server exposing a LangGraphReactAgent with planning (TodoListMiddleware).

  Run:
-     uv run python -m aip_agents.examples.todolist_planning_a2a_langgraph_server \
+     poetry run python -m aip_agents.examples.todolist_planning_a2a_langgraph_server \
          --host localhost --port 8002

  Then connect with the matching A2A client to observe write_todos_tool calls.
@@ -39,7 +39,7 @@ class GuardrailEngine(Protocol):
          Returns:
              GuardrailResult indicating if content is safe
          """
-         ...
+         ...  # pragma: no cover

      @abstractmethod
      async def check_output(self, content: str) -> GuardrailResult:
@@ -51,12 +51,12 @@ class GuardrailEngine(Protocol):
          Returns:
              GuardrailResult indicating if content is safe
          """
-         ...
+         ...  # pragma: no cover

      @abstractmethod
      def model_dump(self) -> dict:
          """Serialize engine configuration into a JSON-compatible dictionary."""
-         ...
+         ...  # pragma: no cover


  class BaseGuardrailEngine(ABC):
@@ -77,14 +77,14 @@ class BaseGuardrailEngine(ABC):
      @abstractmethod
      async def check_input(self, content: str) -> GuardrailResult:
          """Check user input content for safety violations."""
-         ...
+         ...  # pragma: no cover

      @abstractmethod
      async def check_output(self, content: str) -> GuardrailResult:
          """Check AI output content for safety violations."""
-         ...
+         ...  # pragma: no cover

      @abstractmethod
      def model_dump(self) -> dict:
          """Serialize engine configuration into a JSON-compatible dictionary."""
-         ...
+         ...  # pragma: no cover
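
`# pragma: no cover` is coverage.py's default exclusion marker, so the `...` bodies of abstract and Protocol method stubs stop registering as uncovered lines. A minimal sketch of the same pattern on a fresh ABC:

```python
from abc import ABC, abstractmethod

class ContentChecker(ABC):
    """Illustrative ABC; only the coverage-exclusion pattern matters here."""

    @abstractmethod
    def check(self, content: str) -> bool:
        """Return True if content is safe."""
        ...  # pragma: no cover  # stub body is never executed by tests
```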
@@ -60,6 +60,9 @@ class MCPConnectionManager:
      async def start(self) -> tuple[Any, Any]:
          """Start connection in background task.

+         For HTTP/SSE transports, establishes connection directly to avoid anyio context issues.
+         For stdio transport, uses background task to manage subprocess lifecycle.
+
          Returns:
              tuple[Any, Any]: Tuple of (read_stream, write_stream) for ClientSession

@@ -67,6 +70,17 @@ class MCPConnectionManager:
              Exception: If connection establishment fails
          """
          logger.debug(f"Starting connection manager for {self.server_name}")
+
+         # Determine transport type first
+         self.transport_type = self._get_transport_type()
+
+         # For HTTP/SSE: connect directly (no background task needed)
+         # This avoids anyio.BrokenResourceError when streams cross task boundaries
+         if self.transport_type in (TransportType.HTTP, TransportType.SSE):
+             await self._establish_connection()
+             return self._connection
+
+         # For stdio: use background task to manage subprocess
          self._task = asyncio.create_task(self._connection_task())
          await self._ready_event.wait()

@@ -78,6 +92,20 @@ class MCPConnectionManager:
      async def stop(self) -> None:
          """Stop connection gracefully."""
          logger.debug(f"Stopping connection manager for {self.server_name}")
+
+         # For HTTP/SSE (no background task), just close transport
+         if self.transport_type in (TransportType.HTTP, TransportType.SSE):
+             if self._transport:
+                 try:
+                     close_result = self._transport.close()
+                     if inspect.isawaitable(close_result):
+                         await close_result
+                 except Exception as exc:
+                     logger.warning(f"Failed to close transport cleanly for {self.server_name}: {exc}")
+             self._connection = None
+             return
+
+         # For stdio (with background task), wait for task to finish
          if self._task and not self._task.done():
              self._stop_event.set()
              try:
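
The `inspect.isawaitable` guard lets `stop()` handle transports whose `close()` is either synchronous or a coroutine, awaiting only when there is something to await. A self-contained sketch of the same pattern:

```python
import asyncio
import inspect

class SyncTransport:
    def close(self) -> None:
        print("closed synchronously")

class AsyncTransport:
    async def close(self) -> None:
        print("closed asynchronously")

async def shutdown(transport) -> None:
    # close() may be sync or async depending on the transport implementation;
    # awaiting only when the result is awaitable handles both shapes.
    result = transport.close()
    if inspect.isawaitable(result):
        await result

asyncio.run(shutdown(SyncTransport()))
asyncio.run(shutdown(AsyncTransport()))
```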
@@ -94,6 +122,11 @@ class MCPConnectionManager:
          Returns:
              bool: True if connected, False otherwise
          """
+         # For HTTP/SSE (no background task), just check if connection exists
+         if self.transport_type in (TransportType.HTTP, TransportType.SSE):
+             return self._connection is not None
+
+         # For stdio (with background task), check task status too
          return (
              self._connection is not None
              and self._task is not None
@@ -144,7 +177,9 @@ class MCPConnectionManager:
          Raises:
              ConnectionError: If all connection attempts fail
          """
-         self.transport_type = self._get_transport_type()
+         # transport_type may already be set by start() for HTTP/SSE
+         if not self.transport_type:
+             self.transport_type = self._get_transport_type()
          details = f"URL: {self.config.get('url', 'N/A')}, Command: {self.config.get('command', 'N/A')}"
          logger.info(f"Establishing connection to {self.server_name} via {self.transport_type} ({details})")

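Taken together, the `MCPConnectionManager` changes split `start()` into two lifecycles: HTTP/SSE connects in the caller's task (so anyio streams never cross task boundaries), while stdio keeps a background task that owns the subprocess and signals readiness via an event. An illustrative, self-contained sketch; names mirror the diff, the bodies are placeholders:

```python
import asyncio

class ConnectionManager:
    """Illustrative sketch of the start() split; not the real implementation."""

    def __init__(self, transport_type: str) -> None:
        self.transport_type = transport_type
        self._ready_event = asyncio.Event()
        self._connection = None
        self._task = None

    async def _establish_connection(self) -> None:
        self._connection = ("read_stream", "write_stream")  # placeholder streams

    async def _connection_task(self) -> None:
        # Background task owns the connection for its whole lifetime (stdio).
        await self._establish_connection()
        self._ready_event.set()    # signal start() that streams are usable
        await asyncio.sleep(3600)  # stand-in for watching the subprocess

    async def start(self):
        if self.transport_type in ("http", "sse"):
            # Connect in the caller's task so streams never cross task
            # boundaries (the anyio issue the diff describes).
            await self._establish_connection()
            return self._connection
        self._task = asyncio.create_task(self._connection_task())
        await self._ready_event.wait()
        return self._connection

async def demo() -> None:
    print(await ConnectionManager("http").start())

asyncio.run(demo())
```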
@@ -31,6 +31,9 @@ class MCPConnectionManager:
      async def start(self) -> tuple[Any, Any]:
          """Start connection in background task.

+         For HTTP/SSE transports, establishes connection directly to avoid anyio context issues.
+         For stdio transport, uses background task to manage subprocess lifecycle.
+
          Returns:
              tuple[Any, Any]: Tuple of (read_stream, write_stream) for ClientSession