grasp_agents 0.1.14__tar.gz → 0.1.16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/PKG-INFO +3 -4
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/README.md +0 -2
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/pyproject.toml +3 -3
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/agent_message.py +0 -1
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/base_agent.py +1 -1
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/cloud_llm.py +79 -36
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/comm_agent.py +40 -49
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/llm.py +6 -6
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/llm_agent.py +81 -63
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/memory.py +0 -6
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/openai/completion_converters.py +4 -3
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/openai/converters.py +2 -8
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/openai/message_converters.py +1 -6
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/openai/openai_llm.py +1 -3
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/openai/tool_converters.py +1 -1
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/tool_orchestrator.py +2 -2
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/typing/converters.py +2 -10
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/typing/io.py +1 -4
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/typing/message.py +5 -3
- grasp_agents-0.1.16/src/grasp_agents/typing/tool.py +59 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/utils.py +117 -67
- grasp_agents-0.1.14/src/grasp_agents/typing/tool.py +0 -52
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/.gitignore +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/LICENSE.md +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/agent_message_pool.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/costs_dict.yaml +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/data_retrieval/__init__.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/data_retrieval/rate_limiter_chunked.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/data_retrieval/types.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/data_retrieval/utils.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/grasp_logging.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/http_client.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/llm_agent_state.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/openai/__init__.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/openai/content_converters.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/printer.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/prompt_builder.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/run_context.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/typing/__init__.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/typing/completion.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/typing/content.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/usage_tracker.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/workflow/__init__.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/workflow/looped_agent.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/workflow/sequential_agent.py +0 -0
- {grasp_agents-0.1.14 → grasp_agents-0.1.16}/src/grasp_agents/workflow/workflow_agent.py +0 -0
@@ -1,12 +1,13 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: grasp_agents
|
3
|
-
Version: 0.1.14
|
3
|
+
Version: 0.1.16
|
4
4
|
Summary: Grasp Agents Library
|
5
5
|
License-File: LICENSE.md
|
6
|
-
Requires-Python: <
|
6
|
+
Requires-Python: <4,>=3.11.4
|
7
7
|
Requires-Dist: dotenv>=0.9.9
|
8
8
|
Requires-Dist: httpx<1,>=0.27.0
|
9
9
|
Requires-Dist: openai<2,>=1.68.2
|
10
|
+
Requires-Dist: pydantic>=2
|
10
11
|
Requires-Dist: pyyaml>=6.0.2
|
11
12
|
Requires-Dist: tenacity>=9.1.2
|
12
13
|
Requires-Dist: termcolor<3,>=2.4.0
|
@@ -53,8 +54,6 @@ Description-Content-Type: text/markdown
|
|
53
54
|
|
54
55
|
## Quickstart & Installation Variants (UV Package manager)
|
55
56
|
|
56
|
-
### Option 1: UV Package Manager Project
|
57
|
-
|
58
57
|
> **Note:** You can check this sample project code in the [src/grasp_agents/examples/demo/uv](src/grasp_agents/examples/demo/uv) folder. Feel free to copy and paste the code from there to a separate project. There are also [examples](src/grasp_agents/examples/demo/) for other package managers.
|
59
58
|
|
60
59
|
#### 1. Prerequisites
|
@@ -38,8 +38,6 @@
|
|
38
38
|
|
39
39
|
## Quickstart & Installation Variants (UV Package manager)
|
40
40
|
|
41
|
-
### Option 1: UV Package Manager Project
|
42
|
-
|
43
41
|
> **Note:** You can check this sample project code in the [src/grasp_agents/examples/demo/uv](src/grasp_agents/examples/demo/uv) folder. Feel free to copy and paste the code from there to a separate project. There are also [examples](src/grasp_agents/examples/demo/) for other package managers.
|
44
42
|
|
45
43
|
#### 1. Prerequisites
|
@@ -1,9 +1,9 @@
|
|
1
1
|
[project]
|
2
2
|
name = "grasp_agents"
|
3
|
-
version = "0.1.14"
|
3
|
+
version = "0.1.16"
|
4
4
|
description = "Grasp Agents Library"
|
5
5
|
readme = "README.md"
|
6
|
-
requires-python = ">=3.11.4,<
|
6
|
+
requires-python = ">=3.11.4,<4"
|
7
7
|
dependencies = [
|
8
8
|
"httpx>=0.27.0,<1",
|
9
9
|
"openai>=1.68.2,<2",
|
@@ -12,6 +12,7 @@ dependencies = [
|
|
12
12
|
"tqdm>=4.66.2,<5",
|
13
13
|
"dotenv>=0.9.9",
|
14
14
|
"pyyaml>=6.0.2",
|
15
|
+
"pydantic>=2",
|
15
16
|
]
|
16
17
|
|
17
18
|
[dependency-groups]
|
@@ -22,7 +23,6 @@ dev = [
|
|
22
23
|
"ipywidgets>=8.0.4,<9",
|
23
24
|
"widgetsnbextension>=4.0.5,<5",
|
24
25
|
"types-cachetools>=5.0.1,<6",
|
25
|
-
"pydantic>=2",
|
26
26
|
"pre-commit-uv>=4.1.4",
|
27
27
|
"twine>=5.1.1,<6",
|
28
28
|
"ruff>=0.11.8",
|
@@ -3,6 +3,7 @@ import logging
|
|
3
3
|
import os
|
4
4
|
from abc import abstractmethod
|
5
5
|
from collections.abc import AsyncIterator, Sequence
|
6
|
+
from copy import deepcopy
|
6
7
|
from typing import Any, Generic, Literal
|
7
8
|
|
8
9
|
import httpx
|
@@ -19,7 +20,6 @@ from .data_retrieval.rate_limiter_chunked import ( # type: ignore
|
|
19
20
|
RateLimiterC,
|
20
21
|
limit_rate_chunked,
|
21
22
|
)
|
22
|
-
|
23
23
|
from .http_client import AsyncHTTPClientParams, create_async_http_client
|
24
24
|
from .llm import LLM, ConvertT, LLMSettings, SettingsT
|
25
25
|
from .memory import MessageHistory
|
@@ -38,7 +38,7 @@ class APIProviderInfo(TypedDict):
|
|
38
38
|
name: APIProvider
|
39
39
|
base_url: str
|
40
40
|
api_key: str | None
|
41
|
-
struct_output_support:
|
41
|
+
struct_output_support: tuple[str, ...]
|
42
42
|
|
43
43
|
|
44
44
|
PROVIDERS: dict[APIProvider, APIProviderInfo] = {
|
@@ -46,19 +46,19 @@ PROVIDERS: dict[APIProvider, APIProviderInfo] = {
|
|
46
46
|
name="openai",
|
47
47
|
base_url="https://api.openai.com/v1",
|
48
48
|
api_key=os.getenv("OPENAI_API_KEY"),
|
49
|
-
struct_output_support=
|
49
|
+
struct_output_support=("*",),
|
50
50
|
),
|
51
51
|
"openrouter": APIProviderInfo(
|
52
52
|
name="openrouter",
|
53
53
|
base_url="https://openrouter.ai/api/v1",
|
54
54
|
api_key=os.getenv("OPENROUTER_API_KEY"),
|
55
|
-
struct_output_support=
|
55
|
+
struct_output_support=(),
|
56
56
|
),
|
57
57
|
"google_ai_studio": APIProviderInfo(
|
58
58
|
name="google_ai_studio",
|
59
59
|
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
|
60
60
|
api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY"),
|
61
|
-
struct_output_support=
|
61
|
+
struct_output_support=("*",),
|
62
62
|
),
|
63
63
|
}
|
64
64
|
|
@@ -92,6 +92,7 @@ class CloudLLMSettings(LLMSettings, total=False):
|
|
92
92
|
temperature: float | None
|
93
93
|
top_p: float | None
|
94
94
|
seed: int | None
|
95
|
+
use_structured_outputs: bool
|
95
96
|
|
96
97
|
|
97
98
|
class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
|
@@ -102,7 +103,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
|
|
102
103
|
converters: ConvertT,
|
103
104
|
llm_settings: SettingsT | None = None,
|
104
105
|
model_id: str | None = None,
|
105
|
-
tools: list[BaseTool[BaseModel,
|
106
|
+
tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
|
106
107
|
response_format: type | None = None,
|
107
108
|
# Connection settings
|
108
109
|
api_provider: APIProvider = "openai",
|
@@ -135,13 +136,21 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
|
|
135
136
|
self._model_name = model_name
|
136
137
|
self._api_provider: APIProvider = api_provider
|
137
138
|
|
138
|
-
patterns = PROVIDERS[api_provider]["struct_output_support"]
|
139
139
|
self._struct_output_support: bool = any(
|
140
|
-
fnmatch.fnmatch(self._model_name, pat)
|
140
|
+
fnmatch.fnmatch(self._model_name, pat)
|
141
|
+
for pat in PROVIDERS[api_provider]["struct_output_support"]
|
141
142
|
)
|
142
143
|
self._response_format_pyd: TypeAdapter[Any] | None = (
|
143
144
|
TypeAdapter(self._response_format) if response_format else None
|
144
145
|
)
|
146
|
+
if (
|
147
|
+
self._llm_settings.get("use_structured_outputs")
|
148
|
+
and not self._struct_output_support
|
149
|
+
):
|
150
|
+
raise ValueError(
|
151
|
+
f"Model {api_provider}:{self._model_name} does "
|
152
|
+
"not support structured outputs."
|
153
|
+
)
|
145
154
|
|
146
155
|
self._rate_limiter: RateLimiterC[Conversation, AssistantMessage] | None = (
|
147
156
|
self._get_rate_limiter(
|
@@ -181,8 +190,8 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
|
|
181
190
|
def _make_completion_kwargs(
|
182
191
|
self, conversation: Conversation, tool_choice: ToolChoice | None = None
|
183
192
|
) -> dict[str, Any]:
|
184
|
-
api_llm_settings = self.llm_settings or {}
|
185
193
|
api_messages = [self._converters.to_message(m) for m in conversation]
|
194
|
+
|
186
195
|
api_tools = None
|
187
196
|
api_tool_choice = None
|
188
197
|
if self.tools:
|
@@ -190,6 +199,9 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
|
|
190
199
|
if tool_choice is not None:
|
191
200
|
api_tool_choice = self._converters.to_tool_choice(tool_choice)
|
192
201
|
|
202
|
+
api_llm_settings = deepcopy(self.llm_settings or {})
|
203
|
+
api_llm_settings.pop("use_structured_outputs", None)
|
204
|
+
|
193
205
|
return dict(
|
194
206
|
api_messages=api_messages,
|
195
207
|
api_tools=api_tools,
|
@@ -216,6 +228,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
|
|
216
228
|
*,
|
217
229
|
api_tools: list[Any] | None = None,
|
218
230
|
api_tool_choice: Any | None = None,
|
231
|
+
api_response_format: type | None = None,
|
219
232
|
**api_llm_settings: Any,
|
220
233
|
) -> Any:
|
221
234
|
pass
|
@@ -242,7 +255,11 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
|
|
242
255
|
conversation=conversation, tool_choice=tool_choice
|
243
256
|
)
|
244
257
|
|
245
|
-
if
|
258
|
+
if (
|
259
|
+
self._response_format is None
|
260
|
+
or (not self._struct_output_support)
|
261
|
+
or (not self._llm_settings.get("use_structured_outputs"))
|
262
|
+
):
|
246
263
|
completion_kwargs.pop("api_response_format", None)
|
247
264
|
api_completion = await self._get_completion(**completion_kwargs, **kwargs)
|
248
265
|
else:
|
@@ -250,7 +267,23 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
|
|
250
267
|
**completion_kwargs, **kwargs
|
251
268
|
)
|
252
269
|
|
253
|
-
|
270
|
+
completion = self._converters.from_completion(
|
271
|
+
api_completion, model_id=self.model_id
|
272
|
+
)
|
273
|
+
|
274
|
+
for choice in completion.choices:
|
275
|
+
message = choice.message
|
276
|
+
if (
|
277
|
+
self._response_format_pyd is not None
|
278
|
+
and not self._llm_settings.get("use_structured_outputs")
|
279
|
+
and not message.tool_calls
|
280
|
+
):
|
281
|
+
message_json = extract_json(
|
282
|
+
message.content, return_none_on_failure=True
|
283
|
+
)
|
284
|
+
self._response_format_pyd.validate_python(message_json)
|
285
|
+
|
286
|
+
return completion
|
254
287
|
|
255
288
|
async def generate_completion_stream(
|
256
289
|
self,
|
@@ -271,63 +304,73 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
|
|
271
304
|
api_completion_chunk_iterator, model_id=self.model_id
|
272
305
|
)
|
273
306
|
|
274
|
-
async def
|
275
|
-
self,
|
276
|
-
conversation: Conversation,
|
277
|
-
*,
|
278
|
-
tool_choice: ToolChoice | None = None,
|
279
|
-
**kwargs: Any,
|
280
|
-
) -> AssistantMessage:
|
281
|
-
completion = await self.generate_completion(
|
282
|
-
conversation, tool_choice=tool_choice, **kwargs
|
283
|
-
)
|
284
|
-
message = completion.choices[0].message
|
285
|
-
if self._response_format_pyd is not None and not self._struct_output_support:
|
286
|
-
self._response_format_pyd.validate_python(extract_json(message.content))
|
287
|
-
|
288
|
-
return message
|
289
|
-
|
290
|
-
async def _generate_message_with_retry(
|
307
|
+
async def _generate_completion_with_retry(
|
291
308
|
self,
|
292
309
|
conversation: Conversation,
|
293
310
|
*,
|
294
311
|
tool_choice: ToolChoice | None = None,
|
295
312
|
**kwargs: Any,
|
296
|
-
) ->
|
313
|
+
) -> Completion:
|
297
314
|
wrapped_func = retry(
|
298
315
|
wait=wait_random_exponential(min=1, max=8),
|
299
316
|
stop=stop_after_attempt(self.num_generation_retries + 1),
|
300
317
|
before=retry_before_callback,
|
301
318
|
retry_error_callback=retry_error_callback,
|
302
|
-
)(self.__class__.
|
319
|
+
)(self.__class__.generate_completion)
|
303
320
|
|
304
321
|
return await wrapped_func(self, conversation, tool_choice=tool_choice, **kwargs)
|
305
322
|
|
306
323
|
@limit_rate_chunked # type: ignore
|
307
|
-
async def
|
324
|
+
async def _generate_completion_batch_with_retry_and_rate_lim(
|
308
325
|
self,
|
309
326
|
conversation: Conversation,
|
310
327
|
*,
|
311
328
|
tool_choice: ToolChoice | None = None,
|
312
329
|
**kwargs: Any,
|
313
|
-
) ->
|
314
|
-
return await self.
|
330
|
+
) -> Completion:
|
331
|
+
return await self._generate_completion_with_retry(
|
315
332
|
conversation, tool_choice=tool_choice, **kwargs
|
316
333
|
)
|
317
334
|
|
318
|
-
async def
|
335
|
+
async def generate_completion_batch(
|
319
336
|
self,
|
320
337
|
message_history: MessageHistory,
|
321
338
|
*,
|
322
339
|
tool_choice: ToolChoice | None = None,
|
323
340
|
**kwargs: Any,
|
324
|
-
) -> Sequence[
|
325
|
-
return await self.
|
341
|
+
) -> Sequence[Completion]:
|
342
|
+
return await self._generate_completion_batch_with_retry_and_rate_lim(
|
326
343
|
list(message_history.batched_conversations), # type: ignore
|
327
344
|
tool_choice=tool_choice,
|
328
345
|
**kwargs,
|
329
346
|
)
|
330
347
|
|
348
|
+
async def generate_message(
|
349
|
+
self,
|
350
|
+
conversation: Conversation,
|
351
|
+
*,
|
352
|
+
tool_choice: ToolChoice | None = None,
|
353
|
+
**kwargs: Any,
|
354
|
+
) -> AssistantMessage:
|
355
|
+
completion = await self.generate_completion(
|
356
|
+
conversation, tool_choice=tool_choice, **kwargs
|
357
|
+
)
|
358
|
+
|
359
|
+
return completion.choices[0].message
|
360
|
+
|
361
|
+
async def generate_message_batch(
|
362
|
+
self,
|
363
|
+
message_history: MessageHistory,
|
364
|
+
*,
|
365
|
+
tool_choice: ToolChoice | None = None,
|
366
|
+
**kwargs: Any,
|
367
|
+
) -> Sequence[AssistantMessage]:
|
368
|
+
completion_batch = await self.generate_completion_batch(
|
369
|
+
message_history, tool_choice=tool_choice, **kwargs
|
370
|
+
)
|
371
|
+
|
372
|
+
return [completion.choices[0].message for completion in completion_batch]
|
373
|
+
|
331
374
|
def _get_rate_limiter(
|
332
375
|
self,
|
333
376
|
rate_limiter: RateLimiterC[Conversation, AssistantMessage] | None = None,
|
@@ -4,6 +4,7 @@ from collections.abc import Sequence
|
|
4
4
|
from typing import Any, Generic, Protocol, TypeVar, cast, final
|
5
5
|
|
6
6
|
from pydantic import BaseModel
|
7
|
+
from pydantic.json_schema import SkipJsonSchema
|
7
8
|
|
8
9
|
from .agent_message import AgentMessage
|
9
10
|
from .agent_message_pool import AgentMessagePool
|
@@ -14,6 +15,11 @@ from .typing.tool import BaseTool
|
|
14
15
|
|
15
16
|
logger = logging.getLogger(__name__)
|
16
17
|
|
18
|
+
|
19
|
+
class DCommAgentPayload(AgentPayload):
|
20
|
+
selected_recipient_ids: SkipJsonSchema[Sequence[AgentID]]
|
21
|
+
|
22
|
+
|
17
23
|
_EH_OutT = TypeVar("_EH_OutT", bound=AgentPayload, contravariant=True) # noqa: PLC0105
|
18
24
|
_EH_StateT = TypeVar("_EH_StateT", bound=AgentState, contravariant=True) # noqa: PLC0105
|
19
25
|
|
@@ -22,7 +28,6 @@ class ExitHandler(Protocol[_EH_OutT, _EH_StateT, CtxT]):
|
|
22
28
|
def __call__(
|
23
29
|
self,
|
24
30
|
output_message: AgentMessage[_EH_OutT, _EH_StateT],
|
25
|
-
agent_state: _EH_StateT,
|
26
31
|
ctx: RunContextWrapper[CtxT] | None,
|
27
32
|
) -> bool: ...
|
28
33
|
|
@@ -38,14 +43,11 @@ class CommunicatingAgent(
|
|
38
43
|
rcv_args_schema: type[InT] = AgentPayload,
|
39
44
|
recipient_ids: Sequence[AgentID] | None = None,
|
40
45
|
message_pool: AgentMessagePool[CtxT] | None = None,
|
41
|
-
dynamic_routing: bool = False,
|
42
46
|
**kwargs: Any,
|
43
47
|
) -> None:
|
44
48
|
super().__init__(agent_id=agent_id, out_schema=out_schema, **kwargs)
|
45
49
|
self._message_pool = message_pool or AgentMessagePool()
|
46
50
|
|
47
|
-
self._dynamic_routing = dynamic_routing
|
48
|
-
|
49
51
|
self._is_listening = False
|
50
52
|
self._exit_impl: ExitHandler[OutT, StateT, CtxT] | None = None
|
51
53
|
|
@@ -56,10 +58,6 @@ class CommunicatingAgent(
|
|
56
58
|
def rcv_args_schema(self) -> type[InT]: # type: ignore[reportInvalidTypeVarUse]
|
57
59
|
return self._rcv_args_schema
|
58
60
|
|
59
|
-
@property
|
60
|
-
def dynamic_routing(self) -> bool:
|
61
|
-
return self._dynamic_routing
|
62
|
-
|
63
61
|
def _parse_output(
|
64
62
|
self,
|
65
63
|
*args: Any,
|
@@ -72,41 +70,36 @@ class CommunicatingAgent(
|
|
72
70
|
|
73
71
|
return self._out_schema()
|
74
72
|
|
75
|
-
def
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
73
|
+
def _validate_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
|
74
|
+
if all(isinstance(p, DCommAgentPayload) for p in payloads):
|
75
|
+
payloads_ = cast("Sequence[DCommAgentPayload]", payloads)
|
76
|
+
selected_recipient_ids_per_payload = [
|
77
|
+
set(p.selected_recipient_ids or []) for p in payloads_
|
78
|
+
]
|
79
|
+
assert all(
|
80
|
+
x == selected_recipient_ids_per_payload[0]
|
81
|
+
for x in selected_recipient_ids_per_payload
|
82
|
+
), "All payloads must have the same recipient IDs for dynamic routing"
|
83
|
+
|
84
|
+
assert payloads_[0].selected_recipient_ids is not None
|
85
|
+
selected_recipient_ids = payloads_[0].selected_recipient_ids
|
86
|
+
|
87
|
+
assert all(rid in self.recipient_ids for rid in selected_recipient_ids), (
|
88
|
+
"Dynamic routing is enabled, but recipient IDs are not in "
|
89
|
+
"the allowed agent's recipient IDs"
|
90
|
+
)
|
90
91
|
|
91
|
-
|
92
|
-
"Dynamic routing is enabled, but recipient IDs are not in "
|
93
|
-
"the allowed agent's recipient IDs"
|
94
|
-
)
|
92
|
+
return selected_recipient_ids
|
95
93
|
|
96
|
-
|
94
|
+
if all((not isinstance(p, DCommAgentPayload)) for p in payloads):
|
95
|
+
return self.recipient_ids
|
97
96
|
|
98
|
-
|
99
|
-
|
100
|
-
"Dynamic routing is not enabled, but some payloads have recipient IDs"
|
97
|
+
raise ValueError(
|
98
|
+
"All payloads must be either DCommAgentPayload or not DCommAgentPayload"
|
101
99
|
)
|
102
100
|
|
103
|
-
return self.recipient_ids
|
104
|
-
|
105
101
|
async def post_message(self, message: AgentMessage[OutT, StateT]) -> None:
|
106
|
-
|
107
|
-
self._validate_dynamic_routing(message.payloads)
|
108
|
-
else:
|
109
|
-
self._validate_static_routing(message.payloads)
|
102
|
+
self._validate_routing(message.payloads)
|
110
103
|
|
111
104
|
await self._message_pool.post(message)
|
112
105
|
|
@@ -144,9 +137,7 @@ class CommunicatingAgent(
|
|
144
137
|
ctx: RunContextWrapper[CtxT] | None,
|
145
138
|
) -> bool:
|
146
139
|
if self._exit_impl:
|
147
|
-
return self._exit_impl(
|
148
|
-
output_message=output_message, agent_state=self.state, ctx=ctx
|
149
|
-
)
|
140
|
+
return self._exit_impl(output_message=output_message, ctx=ctx)
|
150
141
|
|
151
142
|
return False
|
152
143
|
|
@@ -190,28 +181,28 @@ class CommunicatingAgent(
|
|
190
181
|
|
191
182
|
@final
|
192
183
|
def as_tool(
|
193
|
-
self,
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
184
|
+
self,
|
185
|
+
tool_name: str,
|
186
|
+
tool_description: str,
|
187
|
+
tool_strict: bool = True,
|
188
|
+
) -> BaseTool[Any, Any, Any]:
|
199
189
|
agent_instance = self
|
200
190
|
|
201
|
-
class AgentTool(BaseTool[
|
191
|
+
class AgentTool(BaseTool[Any, Any, Any]):
|
202
192
|
name: str = tool_name
|
203
193
|
description: str = tool_description
|
204
194
|
in_schema: type[BaseModel] = agent_instance.rcv_args_schema
|
205
|
-
out_schema:
|
195
|
+
out_schema: Any = agent_instance.out_schema
|
206
196
|
|
207
197
|
strict: bool | None = tool_strict
|
208
198
|
|
209
199
|
async def run(
|
210
200
|
self,
|
211
|
-
inp:
|
201
|
+
inp: InT,
|
212
202
|
ctx: RunContextWrapper[CtxT] | None = None,
|
213
203
|
) -> OutT:
|
214
204
|
rcv_args = agent_instance.rcv_args_schema.model_validate(inp)
|
205
|
+
|
215
206
|
rcv_message = AgentMessage( # type: ignore[arg-type]
|
216
207
|
payloads=[rcv_args],
|
217
208
|
sender_id="<tool_user>",
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import logging
|
2
2
|
from abc import ABC, abstractmethod
|
3
3
|
from collections.abc import AsyncIterator, Sequence
|
4
|
-
from typing import Any, Generic, TypeVar
|
4
|
+
from typing import Any, Generic, TypeVar, cast
|
5
5
|
from uuid import uuid4
|
6
6
|
|
7
7
|
from pydantic import BaseModel
|
@@ -32,7 +32,7 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
|
|
32
32
|
model_name: str | None = None,
|
33
33
|
model_id: str | None = None,
|
34
34
|
llm_settings: SettingsT | None = None,
|
35
|
-
tools: list[BaseTool[BaseModel,
|
35
|
+
tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
|
36
36
|
response_format: type | None = None,
|
37
37
|
**kwargs: Any,
|
38
38
|
) -> None:
|
@@ -41,9 +41,9 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
|
|
41
41
|
self._converters = converters
|
42
42
|
self._model_id = model_id or str(uuid4())[:8]
|
43
43
|
self._model_name = model_name
|
44
|
-
self._llm_settings = llm_settings
|
45
44
|
self._tools = {t.name: t for t in tools} if tools else None
|
46
45
|
self._response_format = response_format
|
46
|
+
self._llm_settings: SettingsT = llm_settings or cast("SettingsT", {})
|
47
47
|
|
48
48
|
@property
|
49
49
|
def model_id(self) -> str:
|
@@ -54,11 +54,11 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
|
|
54
54
|
return self._model_name
|
55
55
|
|
56
56
|
@property
|
57
|
-
def llm_settings(self) -> SettingsT
|
57
|
+
def llm_settings(self) -> SettingsT:
|
58
58
|
return self._llm_settings
|
59
59
|
|
60
60
|
@property
|
61
|
-
def tools(self) -> dict[str, BaseTool[BaseModel,
|
61
|
+
def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
|
62
62
|
return self._tools
|
63
63
|
|
64
64
|
@property
|
@@ -66,7 +66,7 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
|
|
66
66
|
return self._response_format
|
67
67
|
|
68
68
|
@tools.setter
|
69
|
-
def tools(self, tools: list[BaseTool[BaseModel,
|
69
|
+
def tools(self, tools: list[BaseTool[BaseModel, Any, Any]] | None) -> None:
|
70
70
|
self._tools = {t.name: t for t in tools} if tools else None
|
71
71
|
|
72
72
|
def __repr__(self) -> str:
|