grasp_agents 0.1.18__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/PKG-INFO +37 -33
  2. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/README.md +36 -32
  3. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/pyproject.toml +1 -1
  4. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/agent_message.py +2 -2
  5. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/agent_message_pool.py +6 -8
  6. grasp_agents-0.2.0/src/grasp_agents/base_agent.py +51 -0
  7. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/cloud_llm.py +9 -6
  8. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/comm_agent.py +39 -43
  9. grasp_agents-0.2.0/src/grasp_agents/generics_utils.py +159 -0
  10. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/llm.py +4 -0
  11. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/llm_agent.py +68 -37
  12. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/llm_agent_state.py +9 -5
  13. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/prompt_builder.py +45 -25
  14. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py +49 -48
  15. grasp_agents-0.2.0/src/grasp_agents/rate_limiting/types.py +36 -0
  16. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/rate_limiting/utils.py +24 -27
  17. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/run_context.py +2 -15
  18. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/tool_orchestrator.py +30 -8
  19. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/typing/converters.py +3 -1
  20. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/typing/io.py +4 -9
  21. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/typing/tool.py +26 -7
  22. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/utils.py +26 -39
  23. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/workflow/looped_agent.py +12 -9
  24. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/workflow/sequential_agent.py +9 -6
  25. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/workflow/workflow_agent.py +16 -11
  26. grasp_agents-0.1.18/src/grasp_agents/base_agent.py +0 -72
  27. grasp_agents-0.1.18/src/grasp_agents/rate_limiting/types.py +0 -57
  28. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/.gitignore +0 -0
  29. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/LICENSE.md +0 -0
  30. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/costs_dict.yaml +0 -0
  31. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/grasp_logging.py +0 -0
  32. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/http_client.py +0 -0
  33. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/memory.py +0 -0
  34. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/openai/__init__.py +0 -0
  35. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/openai/completion_converters.py +0 -0
  36. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/openai/content_converters.py +0 -0
  37. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/openai/converters.py +0 -0
  38. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/openai/message_converters.py +0 -0
  39. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/openai/openai_llm.py +0 -0
  40. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/openai/tool_converters.py +0 -0
  41. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/printer.py +0 -0
  42. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/rate_limiting/__init__.py +0 -0
  43. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/typing/__init__.py +0 -0
  44. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/typing/completion.py +0 -0
  45. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/typing/content.py +0 -0
  46. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/typing/message.py +0 -0
  47. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/usage_tracker.py +0 -0
  48. {grasp_agents-0.1.18 → grasp_agents-0.2.0}/src/grasp_agents/workflow/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: grasp_agents
3
- Version: 0.1.18
3
+ Version: 0.2.0
4
4
  Summary: Grasp Agents Library
5
5
  License-File: LICENSE.md
6
6
  Requires-Python: <4,>=3.11.4
@@ -17,7 +17,10 @@ Description-Content-Type: text/markdown
17
17
  # Grasp Agents
18
18
 
19
19
  <br/>
20
- <img src="./.assets/grasp.svg" alt="Grasp Agents" width="320" />
20
+ <picture>
21
+ <source srcset="./.assets/grasp-dark.svg" media="(prefers-color-scheme: dark)">
22
+ <img src="./.assets/grasp.svg" alt="Grasp Agents"/>
23
+ </picture>
21
24
  <br/>
22
25
  <br/>
23
26
 
@@ -34,14 +37,14 @@ Description-Content-Type: text/markdown
34
37
  ## Features
35
38
 
36
39
  - Clean formulation of agents as generic entities over:
37
- * I/O schemas
38
- * Agent state
39
- * Shared context
40
+ - I/O schemas
41
+ - Agent state
42
+ - Shared context
40
43
  - Transparent implementation of common agentic patterns:
41
44
  * Single-agent loops with an optional "ReAct mode" to enforce reasoning between the tool calls
42
45
  * Workflows (static communication topology), including loops
43
46
  * Agents-as-tools for task delegation
44
- * Freeform A2A communication via in-process Actor model
47
+ * Freeform A2A communication via the in-process actor model
45
48
  - Batch processing support outside of agentic loops
46
49
  - Simple logging and usage/cost tracking
47
50
 
@@ -54,7 +57,7 @@ Description-Content-Type: text/markdown
54
57
  - `prompt_builder.py`: Tools for constructing prompts.
55
58
  - `workflow/`: Modules for defining and managing agent workflows.
56
59
  - `cloud_llm.py`, `llm.py`: LLM integration and base LLM functionalities.
57
- - `openai/`: Modules specific to OpenAI API integration.
60
+ - `openai/`: Modules specific to OpenAI API integration.
58
61
  - `memory.py`: Memory management for agents (currently only message history).
59
62
  - `run_context.py`: Context management for agent runs.
60
63
  - `usage_tracker.py`: Tracking of API usage and costs.
@@ -107,18 +110,20 @@ GOOGLE_AI_STUDIO_API_KEY=your_google_ai_studio_api_key
107
110
  Create a script, e.g., `problem_recommender.py`:
108
111
 
109
112
  ```python
113
+ import asyncio
110
114
  import re
111
- from typing import Any
112
115
  from pathlib import Path
113
- from pydantic import BaseModel, Field
116
+ from typing import Any
117
+
114
118
  from dotenv import load_dotenv
115
- from grasp_agents.typing.tool import BaseTool
116
- from grasp_agents.typing.io import AgentPayload
117
- from grasp_agents.run_context import RunContextWrapper
118
- from grasp_agents.openai.openai_llm import OpenAILLM, OpenAILLMSettings
119
- from grasp_agents.llm_agent import LLMAgent
119
+ from pydantic import BaseModel, Field
120
+
120
121
  from grasp_agents.grasp_logging import setup_logging
122
+ from grasp_agents.llm_agent import LLMAgent
123
+ from grasp_agents.openai.openai_llm import OpenAILLM, OpenAILLMSettings
124
+ from grasp_agents.run_context import RunContextWrapper
121
125
  from grasp_agents.typing.message import Conversation
126
+ from grasp_agents.typing.tool import BaseTool
122
127
 
123
128
  load_dotenv()
124
129
 
@@ -130,8 +135,8 @@ setup_logging(
130
135
  )
131
136
 
132
137
  sys_prompt_react = """
133
- Your task is to suggest an exciting stats problem to a student.
134
- Ask the student about their education, interests, and preferences, then suggest a problem tailored to them.
138
+ Your task is to suggest an exciting stats problem to a student.
139
+ Ask the student about their education, interests, and preferences, then suggest a problem tailored to them.
135
140
 
136
141
  # Instructions
137
142
  * Ask questions one by one.
@@ -143,14 +148,13 @@ Ask the student about their education, interests, and preferences, then suggest
143
148
  class TeacherQuestion(BaseModel):
144
149
  question: str = Field(..., description="The question to ask the student.")
145
150
 
151
+
146
152
  StudentReply = str
147
153
 
148
154
 
149
155
  class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, Any]):
150
156
  name: str = "ask_student_tool"
151
157
  description: str = "Ask the student a question and get their reply."
152
- in_schema: type[TeacherQuestion] = TeacherQuestion
153
- out_schema: type[StudentReply] = StudentReply
154
158
 
155
159
  async def run(
156
160
  self, inp: TeacherQuestion, ctx: RunContextWrapper[Any] | None = None
@@ -158,11 +162,10 @@ class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, Any]):
158
162
  return input(inp.question)
159
163
 
160
164
 
161
- class FinalResponse(AgentPayload):
162
- problem: str
165
+ Problem = str
163
166
 
164
167
 
165
- teacher = LLMAgent[Any, FinalResponse, None](
168
+ teacher = LLMAgent[Any, Problem, None](
166
169
  agent_id="teacher",
167
170
  llm=OpenAILLM(
168
171
  model_name="gpt-4.1",
@@ -173,30 +176,31 @@ teacher = LLMAgent[Any, FinalResponse, None](
173
176
  max_turns=20,
174
177
  react_mode=True,
175
178
  sys_prompt=sys_prompt_react,
176
- out_schema=FinalResponse,
177
179
  set_state_strategy="reset",
178
180
  )
179
181
 
180
182
 
181
- @teacher.tool_call_loop_exit_handler
182
- def exit_tool_call_loop(conversation: Conversation, ctx, **kwargs) -> None:
183
- message_text = conversation[-1].content
184
-
185
- return re.search(r"<PROBLEM>", message_text)
183
+ @teacher.exit_tool_call_loop_handler
184
+ def exit_tool_call_loop(
185
+ conversation: Conversation, ctx: RunContextWrapper[Any] | None, **kwargs: Any
186
+ ) -> bool:
187
+ return r"<PROBLEM>" in str(conversation[-1].content)
186
188
 
187
189
 
188
190
  @teacher.parse_output_handler
189
- def parse_output(conversation: Conversation, ctx, **kwargs) -> FinalResponse:
190
- message_text = conversation[-1].content
191
- matches = re.findall(r"<PROBLEM>(.*?)</PROBLEM>", message_text, re.DOTALL)
191
+ def parse_output(
192
+ conversation: Conversation, ctx: RunContextWrapper[Any] | None, **kwargs: Any
193
+ ) -> Problem:
194
+ message = str(conversation[-1].content)
195
+ matches = re.findall(r"<PROBLEM>(.*?)</PROBLEM>", message, re.DOTALL)
192
196
 
193
- return FinalResponse(problem=matches[0])
197
+ return matches[0]
194
198
 
195
199
 
196
200
  async def main():
197
- ctx = RunContextWrapper(print_messages=True)
201
+ ctx = RunContextWrapper[None](print_messages=True)
198
202
  out = await teacher.run(ctx=ctx)
199
- print(out.payloads[0].problem)
203
+ print(out.payloads[0])
200
204
  print(ctx.usage_tracker.total_usage)
201
205
 
202
206
 
@@ -1,7 +1,10 @@
1
1
  # Grasp Agents
2
2
 
3
3
  <br/>
4
- <img src="./.assets/grasp.svg" alt="Grasp Agents" width="320" />
4
+ <picture>
5
+ <source srcset="./.assets/grasp-dark.svg" media="(prefers-color-scheme: dark)">
6
+ <img src="./.assets/grasp.svg" alt="Grasp Agents"/>
7
+ </picture>
5
8
  <br/>
6
9
  <br/>
7
10
 
@@ -18,14 +21,14 @@
18
21
  ## Features
19
22
 
20
23
  - Clean formulation of agents as generic entities over:
21
- * I/O schemas
22
- * Agent state
23
- * Shared context
24
+ - I/O schemas
25
+ - Agent state
26
+ - Shared context
24
27
  - Transparent implementation of common agentic patterns:
25
28
  * Single-agent loops with an optional "ReAct mode" to enforce reasoning between the tool calls
26
29
  * Workflows (static communication topology), including loops
27
30
  * Agents-as-tools for task delegation
28
- * Freeform A2A communication via in-process Actor model
31
+ * Freeform A2A communication via the in-process actor model
29
32
  - Batch processing support outside of agentic loops
30
33
  - Simple logging and usage/cost tracking
31
34
 
@@ -38,7 +41,7 @@
38
41
  - `prompt_builder.py`: Tools for constructing prompts.
39
42
  - `workflow/`: Modules for defining and managing agent workflows.
40
43
  - `cloud_llm.py`, `llm.py`: LLM integration and base LLM functionalities.
41
- - `openai/`: Modules specific to OpenAI API integration.
44
+ - `openai/`: Modules specific to OpenAI API integration.
42
45
  - `memory.py`: Memory management for agents (currently only message history).
43
46
  - `run_context.py`: Context management for agent runs.
44
47
  - `usage_tracker.py`: Tracking of API usage and costs.
@@ -91,18 +94,20 @@ GOOGLE_AI_STUDIO_API_KEY=your_google_ai_studio_api_key
91
94
  Create a script, e.g., `problem_recommender.py`:
92
95
 
93
96
  ```python
97
+ import asyncio
94
98
  import re
95
- from typing import Any
96
99
  from pathlib import Path
97
- from pydantic import BaseModel, Field
100
+ from typing import Any
101
+
98
102
  from dotenv import load_dotenv
99
- from grasp_agents.typing.tool import BaseTool
100
- from grasp_agents.typing.io import AgentPayload
101
- from grasp_agents.run_context import RunContextWrapper
102
- from grasp_agents.openai.openai_llm import OpenAILLM, OpenAILLMSettings
103
- from grasp_agents.llm_agent import LLMAgent
103
+ from pydantic import BaseModel, Field
104
+
104
105
  from grasp_agents.grasp_logging import setup_logging
106
+ from grasp_agents.llm_agent import LLMAgent
107
+ from grasp_agents.openai.openai_llm import OpenAILLM, OpenAILLMSettings
108
+ from grasp_agents.run_context import RunContextWrapper
105
109
  from grasp_agents.typing.message import Conversation
110
+ from grasp_agents.typing.tool import BaseTool
106
111
 
107
112
  load_dotenv()
108
113
 
@@ -114,8 +119,8 @@ setup_logging(
114
119
  )
115
120
 
116
121
  sys_prompt_react = """
117
- Your task is to suggest an exciting stats problem to a student.
118
- Ask the student about their education, interests, and preferences, then suggest a problem tailored to them.
122
+ Your task is to suggest an exciting stats problem to a student.
123
+ Ask the student about their education, interests, and preferences, then suggest a problem tailored to them.
119
124
 
120
125
  # Instructions
121
126
  * Ask questions one by one.
@@ -127,14 +132,13 @@ Ask the student about their education, interests, and preferences, then suggest
127
132
  class TeacherQuestion(BaseModel):
128
133
  question: str = Field(..., description="The question to ask the student.")
129
134
 
135
+
130
136
  StudentReply = str
131
137
 
132
138
 
133
139
  class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, Any]):
134
140
  name: str = "ask_student_tool"
135
141
  description: str = "Ask the student a question and get their reply."
136
- in_schema: type[TeacherQuestion] = TeacherQuestion
137
- out_schema: type[StudentReply] = StudentReply
138
142
 
139
143
  async def run(
140
144
  self, inp: TeacherQuestion, ctx: RunContextWrapper[Any] | None = None
@@ -142,11 +146,10 @@ class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, Any]):
142
146
  return input(inp.question)
143
147
 
144
148
 
145
- class FinalResponse(AgentPayload):
146
- problem: str
149
+ Problem = str
147
150
 
148
151
 
149
- teacher = LLMAgent[Any, FinalResponse, None](
152
+ teacher = LLMAgent[Any, Problem, None](
150
153
  agent_id="teacher",
151
154
  llm=OpenAILLM(
152
155
  model_name="gpt-4.1",
@@ -157,30 +160,31 @@ teacher = LLMAgent[Any, FinalResponse, None](
157
160
  max_turns=20,
158
161
  react_mode=True,
159
162
  sys_prompt=sys_prompt_react,
160
- out_schema=FinalResponse,
161
163
  set_state_strategy="reset",
162
164
  )
163
165
 
164
166
 
165
- @teacher.tool_call_loop_exit_handler
166
- def exit_tool_call_loop(conversation: Conversation, ctx, **kwargs) -> None:
167
- message_text = conversation[-1].content
168
-
169
- return re.search(r"<PROBLEM>", message_text)
167
+ @teacher.exit_tool_call_loop_handler
168
+ def exit_tool_call_loop(
169
+ conversation: Conversation, ctx: RunContextWrapper[Any] | None, **kwargs: Any
170
+ ) -> bool:
171
+ return r"<PROBLEM>" in str(conversation[-1].content)
170
172
 
171
173
 
172
174
  @teacher.parse_output_handler
173
- def parse_output(conversation: Conversation, ctx, **kwargs) -> FinalResponse:
174
- message_text = conversation[-1].content
175
- matches = re.findall(r"<PROBLEM>(.*?)</PROBLEM>", message_text, re.DOTALL)
175
+ def parse_output(
176
+ conversation: Conversation, ctx: RunContextWrapper[Any] | None, **kwargs: Any
177
+ ) -> Problem:
178
+ message = str(conversation[-1].content)
179
+ matches = re.findall(r"<PROBLEM>(.*?)</PROBLEM>", message, re.DOTALL)
176
180
 
177
- return FinalResponse(problem=matches[0])
181
+ return matches[0]
178
182
 
179
183
 
180
184
  async def main():
181
- ctx = RunContextWrapper(print_messages=True)
185
+ ctx = RunContextWrapper[None](print_messages=True)
182
186
  out = await teacher.run(ctx=ctx)
183
- print(out.payloads[0].problem)
187
+ print(out.payloads[0])
184
188
  print(ctx.usage_tracker.total_usage)
185
189
 
186
190
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "grasp_agents"
3
- version = "0.1.18"
3
+ version = "0.2.0"
4
4
  description = "Grasp Agents Library"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11.4,<4"
@@ -4,9 +4,9 @@ from uuid import uuid4
4
4
 
5
5
  from pydantic import BaseModel, ConfigDict, Field
6
6
 
7
- from .typing.io import AgentID, AgentPayload, AgentState
7
+ from .typing.io import AgentID, AgentState
8
8
 
9
- _PayloadT = TypeVar("_PayloadT", bound=AgentPayload, covariant=True) # noqa: PLC0105
9
+ _PayloadT = TypeVar("_PayloadT", covariant=True) # noqa: PLC0105
10
10
  _StateT = TypeVar("_StateT", bound=AgentState, covariant=True) # noqa: PLC0105
11
11
 
12
12
 
@@ -4,12 +4,12 @@ from typing import Any, Generic, Protocol, TypeVar
4
4
 
5
5
  from .agent_message import AgentMessage
6
6
  from .run_context import CtxT, RunContextWrapper
7
- from .typing.io import AgentID, AgentPayload, AgentState
7
+ from .typing.io import AgentID, AgentState
8
8
 
9
9
  logger = logging.getLogger(__name__)
10
10
 
11
11
 
12
- _MH_PayloadT = TypeVar("_MH_PayloadT", bound=AgentPayload, contravariant=True) # noqa: PLC0105
12
+ _MH_PayloadT = TypeVar("_MH_PayloadT", contravariant=True) # noqa: PLC0105
13
13
  _MH_StateT = TypeVar("_MH_StateT", bound=AgentState, contravariant=True) # noqa: PLC0105
14
14
 
15
15
 
@@ -24,15 +24,13 @@ class MessageHandler(Protocol[_MH_PayloadT, _MH_StateT, CtxT]):
24
24
 
25
25
  class AgentMessagePool(Generic[CtxT]):
26
26
  def __init__(self) -> None:
27
- self._queues: dict[
28
- AgentID, asyncio.Queue[AgentMessage[AgentPayload, AgentState]]
29
- ] = {}
27
+ self._queues: dict[AgentID, asyncio.Queue[AgentMessage[Any, AgentState]]] = {}
30
28
  self._message_handlers: dict[
31
- AgentID, MessageHandler[AgentPayload, AgentState, CtxT]
29
+ AgentID, MessageHandler[Any, AgentState, CtxT]
32
30
  ] = {}
33
31
  self._tasks: dict[AgentID, asyncio.Task[None]] = {}
34
32
 
35
- async def post(self, message: AgentMessage[AgentPayload, AgentState]) -> None:
33
+ async def post(self, message: AgentMessage[Any, AgentState]) -> None:
36
34
  for recipient_id in message.recipient_ids:
37
35
  queue = self._queues.setdefault(recipient_id, asyncio.Queue())
38
36
  await queue.put(message)
@@ -40,7 +38,7 @@ class AgentMessagePool(Generic[CtxT]):
40
38
  def register_message_handler(
41
39
  self,
42
40
  agent_id: AgentID,
43
- handler: MessageHandler[AgentPayload, AgentState, CtxT],
41
+ handler: MessageHandler[Any, AgentState, CtxT],
44
42
  ctx: RunContextWrapper[CtxT] | None = None,
45
43
  **run_kwargs: Any,
46
44
  ) -> None:
@@ -0,0 +1,51 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, ClassVar, Generic
3
+
4
+ from pydantic import TypeAdapter
5
+
6
+ from .generics_utils import AutoInstanceAttributesMixin
7
+ from .run_context import CtxT, RunContextWrapper
8
+ from .typing.io import AgentID, OutT, StateT
9
+ from .typing.tool import BaseTool
10
+
11
+
12
+ class BaseAgent(AutoInstanceAttributesMixin, ABC, Generic[OutT, StateT, CtxT]):
13
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_out_type"}
14
+
15
+ @abstractmethod
16
+ def __init__(self, agent_id: AgentID, **kwargs: Any) -> None:
17
+ self._out_type: type[OutT]
18
+ self._state: StateT
19
+
20
+ super().__init__()
21
+
22
+ self._agent_id = agent_id
23
+ self._out_type_adapter: TypeAdapter[OutT] = TypeAdapter(self._out_type)
24
+
25
+ @property
26
+ def out_type(self) -> type[OutT]:
27
+ return self._out_type
28
+
29
+ @property
30
+ def agent_id(self) -> AgentID:
31
+ return self._agent_id
32
+
33
+ @property
34
+ def state(self) -> StateT:
35
+ return self._state
36
+
37
+ @abstractmethod
38
+ async def run(
39
+ self,
40
+ inp_items: Any,
41
+ *,
42
+ ctx: RunContextWrapper[CtxT] | None = None,
43
+ **kwargs: Any,
44
+ ) -> Any:
45
+ pass
46
+
47
+ @abstractmethod
48
+ def as_tool(
49
+ self, tool_name: str, tool_description: str, tool_strict: bool = True
50
+ ) -> BaseTool[Any, OutT, CtxT]:
51
+ pass
@@ -26,7 +26,7 @@ from .rate_limiting.rate_limiter_chunked import ( # type: ignore
26
26
  from .typing.completion import Completion, CompletionChunk
27
27
  from .typing.message import AssistantMessage, Conversation
28
28
  from .typing.tool import BaseTool, ToolChoice
29
- from .utils import extract_json
29
+ from .utils import validate_obj_from_json_or_py_string
30
30
 
31
31
  logger = logging.getLogger(__name__)
32
32
 
@@ -271,6 +271,11 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
271
271
  api_completion, model_id=self.model_id
272
272
  )
273
273
 
274
+ self._validate_completion(completion)
275
+
276
+ return completion
277
+
278
+ def _validate_completion(self, completion: Completion) -> None:
274
279
  for choice in completion.choices:
275
280
  message = choice.message
276
281
  if (
@@ -278,12 +283,10 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
278
283
  and not self._llm_settings.get("use_structured_outputs")
279
284
  and not message.tool_calls
280
285
  ):
281
- message_json = extract_json(
282
- message.content, return_none_on_failure=True
286
+ validate_obj_from_json_or_py_string(
287
+ message.content,
288
+ adapter=self._response_format_pyd,
283
289
  )
284
- self._response_format_pyd.validate_python(message_json)
285
-
286
- return completion
287
290
 
288
291
  async def generate_completion_stream(
289
292
  self,
@@ -1,26 +1,26 @@
1
1
  import logging
2
2
  from abc import abstractmethod
3
3
  from collections.abc import Sequence
4
- from typing import Any, Generic, Protocol, TypeVar, cast, final
4
+ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
5
5
 
6
- from pydantic import BaseModel
6
+ from pydantic import BaseModel, TypeAdapter
7
7
  from pydantic.json_schema import SkipJsonSchema
8
8
 
9
9
  from .agent_message import AgentMessage
10
10
  from .agent_message_pool import AgentMessagePool
11
11
  from .base_agent import BaseAgent
12
12
  from .run_context import CtxT, RunContextWrapper
13
- from .typing.io import AgentID, AgentPayload, AgentState, InT, OutT, StateT
13
+ from .typing.io import AgentID, AgentState, InT, OutT, StateT
14
14
  from .typing.tool import BaseTool
15
15
 
16
16
  logger = logging.getLogger(__name__)
17
17
 
18
18
 
19
- class DCommAgentPayload(AgentPayload):
19
+ class DynCommPayload(BaseModel):
20
20
  selected_recipient_ids: SkipJsonSchema[Sequence[AgentID]]
21
21
 
22
22
 
23
- _EH_OutT = TypeVar("_EH_OutT", bound=AgentPayload, contravariant=True) # noqa: PLC0105
23
+ _EH_OutT = TypeVar("_EH_OutT", contravariant=True) # noqa: PLC0105
24
24
  _EH_StateT = TypeVar("_EH_StateT", bound=AgentState, contravariant=True) # noqa: PLC0105
25
25
 
26
26
 
@@ -35,44 +35,37 @@ class ExitHandler(Protocol[_EH_OutT, _EH_StateT, CtxT]):
35
35
  class CommunicatingAgent(
36
36
  BaseAgent[OutT, StateT, CtxT], Generic[InT, OutT, StateT, CtxT]
37
37
  ):
38
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
39
+ 0: "_in_type",
40
+ 1: "_out_type",
41
+ }
42
+
38
43
  def __init__(
39
44
  self,
40
45
  agent_id: AgentID,
41
46
  *,
42
- out_schema: type[OutT] = AgentPayload,
43
- rcv_args_schema: type[InT] = AgentPayload,
44
47
  recipient_ids: Sequence[AgentID] | None = None,
45
48
  message_pool: AgentMessagePool[CtxT] | None = None,
46
49
  **kwargs: Any,
47
50
  ) -> None:
48
- super().__init__(agent_id=agent_id, out_schema=out_schema, **kwargs)
49
- self._message_pool = message_pool or AgentMessagePool()
51
+ self._in_type: type[InT]
52
+ super().__init__(agent_id=agent_id, **kwargs)
53
+
54
+ self._rcv_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
55
+ self.recipient_ids = recipient_ids or []
50
56
 
57
+ self._message_pool = message_pool or AgentMessagePool()
51
58
  self._is_listening = False
52
59
  self._exit_impl: ExitHandler[OutT, StateT, CtxT] | None = None
53
60
 
54
- self._rcv_args_schema = rcv_args_schema
55
- self.recipient_ids = recipient_ids or []
56
-
57
61
  @property
58
- def rcv_args_schema(self) -> type[InT]: # type: ignore[reportInvalidTypeVarUse]
59
- return self._rcv_args_schema
60
-
61
- def _parse_output(
62
- self,
63
- *args: Any,
64
- rcv_args: InT | None = None,
65
- ctx: RunContextWrapper[CtxT] | None = None,
66
- **kwargs: Any,
67
- ) -> OutT:
68
- if self._parse_output_impl:
69
- return self._parse_output_impl(*args, rcv_args=rcv_args, ctx=ctx, **kwargs)
70
-
71
- return self._out_schema()
62
+ def in_type(self) -> type[InT]: # type: ignore
63
+ # Exposing the type of a contravariant variable only, should be safe
64
+ return self._in_type
72
65
 
73
66
  def _validate_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
74
- if all(isinstance(p, DCommAgentPayload) for p in payloads):
75
- payloads_ = cast("Sequence[DCommAgentPayload]", payloads)
67
+ if all(isinstance(p, DynCommPayload) for p in payloads):
68
+ payloads_ = cast("Sequence[DynCommPayload]", payloads)
76
69
  selected_recipient_ids_per_payload = [
77
70
  set(p.selected_recipient_ids or []) for p in payloads_
78
71
  ]
@@ -91,7 +84,7 @@ class CommunicatingAgent(
91
84
 
92
85
  return selected_recipient_ids
93
86
 
94
- if all((not isinstance(p, DCommAgentPayload)) for p in payloads):
87
+ if all((not isinstance(p, DynCommPayload)) for p in payloads):
95
88
  return self.recipient_ids
96
89
 
97
90
  raise ValueError(
@@ -109,7 +102,7 @@ class CommunicatingAgent(
109
102
  inp_items: Any | None = None,
110
103
  *,
111
104
  ctx: RunContextWrapper[CtxT] | None = None,
112
- rcv_message: AgentMessage[InT, StateT] | None = None,
105
+ rcv_message: AgentMessage[InT, AgentState] | None = None,
113
106
  entry_point: bool = False,
114
107
  forbid_state_change: bool = False,
115
108
  **kwargs: Any,
@@ -143,11 +136,11 @@ class CommunicatingAgent(
143
136
 
144
137
  async def _message_handler(
145
138
  self,
146
- message: AgentMessage[AgentPayload, AgentState],
139
+ message: AgentMessage[Any, AgentState],
147
140
  ctx: RunContextWrapper[CtxT] | None = None,
148
141
  **run_kwargs: Any,
149
142
  ) -> None:
150
- rcv_message = cast("AgentMessage[InT, StateT]", message)
143
+ rcv_message = cast("AgentMessage[InT, AgentState]", message)
151
144
  out_message = await self.run(ctx=ctx, rcv_message=rcv_message, **run_kwargs)
152
145
 
153
146
  if self._exit_condition(output_message=out_message, ctx=ctx):
@@ -185,15 +178,20 @@ class CommunicatingAgent(
185
178
  tool_name: str,
186
179
  tool_description: str,
187
180
  tool_strict: bool = True,
188
- ) -> BaseTool[Any, Any, Any]:
181
+ ) -> BaseTool[InT, OutT, Any]: # type: ignore[override]
182
+ # Will check if InT is a BaseModel at runtime
189
183
  agent_instance = self
184
+ in_type = agent_instance.in_type
185
+ out_type = agent_instance.out_type
186
+ if not issubclass(in_type, BaseModel):
187
+ raise TypeError(
188
+ "Cannot create a tool from an agent with "
189
+ f"non-BaseModel input type: {in_type}"
190
+ )
190
191
 
191
- class AgentTool(BaseTool[Any, Any, Any]):
192
+ class AgentTool(BaseTool[in_type, out_type, Any]):
192
193
  name: str = tool_name
193
194
  description: str = tool_description
194
- in_schema: type[BaseModel] = agent_instance.rcv_args_schema
195
- out_schema: Any = agent_instance.out_schema
196
-
197
195
  strict: bool | None = tool_strict
198
196
 
199
197
  async def run(
@@ -201,16 +199,14 @@ class CommunicatingAgent(
201
199
  inp: InT,
202
200
  ctx: RunContextWrapper[CtxT] | None = None,
203
201
  ) -> OutT:
204
- rcv_args = agent_instance.rcv_args_schema.model_validate(inp)
205
-
206
- rcv_message = AgentMessage( # type: ignore[arg-type]
202
+ rcv_args = in_type.model_validate(inp)
203
+ rcv_message = AgentMessage[in_type, AgentState](
207
204
  payloads=[rcv_args],
208
205
  sender_id="<tool_user>",
209
206
  recipient_ids=[agent_instance.agent_id],
210
207
  )
211
-
212
208
  agent_result = await agent_instance.run(
213
- rcv_message=rcv_message, # type: ignore[arg-type]
209
+ rcv_message=rcv_message,
214
210
  entry_point=False,
215
211
  forbid_state_change=True,
216
212
  ctx=ctx,
@@ -218,4 +214,4 @@ class CommunicatingAgent(
218
214
 
219
215
  return agent_result.payloads[0]
220
216
 
221
- return AgentTool()
217
+ return AgentTool() # type: ignore[return-value]