grasp_agents 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/agent_message.py +0 -1
- grasp_agents/base_agent.py +1 -1
- grasp_agents/cloud_llm.py +79 -36
- grasp_agents/comm_agent.py +40 -49
- grasp_agents/llm.py +6 -6
- grasp_agents/llm_agent.py +81 -63
- grasp_agents/memory.py +0 -6
- grasp_agents/openai/completion_converters.py +4 -3
- grasp_agents/openai/converters.py +2 -8
- grasp_agents/openai/message_converters.py +1 -6
- grasp_agents/openai/openai_llm.py +1 -3
- grasp_agents/openai/tool_converters.py +1 -1
- grasp_agents/tool_orchestrator.py +2 -2
- grasp_agents/typing/converters.py +2 -10
- grasp_agents/typing/io.py +1 -4
- grasp_agents/typing/message.py +5 -3
- grasp_agents/typing/tool.py +18 -11
- grasp_agents/utils.py +117 -67
- {grasp_agents-0.1.14.dist-info → grasp_agents-0.1.16.dist-info}/METADATA +3 -4
- {grasp_agents-0.1.14.dist-info → grasp_agents-0.1.16.dist-info}/RECORD +22 -22
- {grasp_agents-0.1.14.dist-info → grasp_agents-0.1.16.dist-info}/WHEEL +0 -0
- {grasp_agents-0.1.14.dist-info → grasp_agents-0.1.16.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/agent_message.py
CHANGED
grasp_agents/base_agent.py
CHANGED
grasp_agents/cloud_llm.py
CHANGED
@@ -3,6 +3,7 @@ import logging
 import os
 from abc import abstractmethod
 from collections.abc import AsyncIterator, Sequence
+from copy import deepcopy
 from typing import Any, Generic, Literal

 import httpx
@@ -19,7 +20,6 @@ from .data_retrieval.rate_limiter_chunked import (  # type: ignore
     RateLimiterC,
     limit_rate_chunked,
 )
-
 from .http_client import AsyncHTTPClientParams, create_async_http_client
 from .llm import LLM, ConvertT, LLMSettings, SettingsT
 from .memory import MessageHistory
@@ -38,7 +38,7 @@ class APIProviderInfo(TypedDict):
     name: APIProvider
     base_url: str
     api_key: str | None
-    struct_output_support:
+    struct_output_support: tuple[str, ...]


 PROVIDERS: dict[APIProvider, APIProviderInfo] = {
@@ -46,19 +46,19 @@ PROVIDERS: dict[APIProvider, APIProviderInfo] = {
         name="openai",
         base_url="https://api.openai.com/v1",
         api_key=os.getenv("OPENAI_API_KEY"),
-        struct_output_support=
+        struct_output_support=("*",),
     ),
     "openrouter": APIProviderInfo(
         name="openrouter",
         base_url="https://openrouter.ai/api/v1",
         api_key=os.getenv("OPENROUTER_API_KEY"),
-        struct_output_support=
+        struct_output_support=(),
     ),
     "google_ai_studio": APIProviderInfo(
         name="google_ai_studio",
         base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
         api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY"),
-        struct_output_support=
+        struct_output_support=("*",),
     ),
 }

@@ -92,6 +92,7 @@ class CloudLLMSettings(LLMSettings, total=False):
     temperature: float | None
     top_p: float | None
     seed: int | None
+    use_structured_outputs: bool


 class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
@@ -102,7 +103,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
         converters: ConvertT,
         llm_settings: SettingsT | None = None,
         model_id: str | None = None,
-        tools: list[BaseTool[BaseModel,
+        tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
         response_format: type | None = None,
         # Connection settings
         api_provider: APIProvider = "openai",
@@ -135,13 +136,21 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
         self._model_name = model_name
         self._api_provider: APIProvider = api_provider

-        patterns = PROVIDERS[api_provider]["struct_output_support"]
         self._struct_output_support: bool = any(
-            fnmatch.fnmatch(self._model_name, pat)
+            fnmatch.fnmatch(self._model_name, pat)
+            for pat in PROVIDERS[api_provider]["struct_output_support"]
         )
         self._response_format_pyd: TypeAdapter[Any] | None = (
             TypeAdapter(self._response_format) if response_format else None
         )
+        if (
+            self._llm_settings.get("use_structured_outputs")
+            and not self._struct_output_support
+        ):
+            raise ValueError(
+                f"Model {api_provider}:{self._model_name} does "
+                "not support structured outputs."
+            )

         self._rate_limiter: RateLimiterC[Conversation, AssistantMessage] | None = (
             self._get_rate_limiter(
@@ -181,8 +190,8 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
     def _make_completion_kwargs(
         self, conversation: Conversation, tool_choice: ToolChoice | None = None
     ) -> dict[str, Any]:
-        api_llm_settings = self.llm_settings or {}
         api_messages = [self._converters.to_message(m) for m in conversation]
+
         api_tools = None
         api_tool_choice = None
         if self.tools:
@@ -190,6 +199,9 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
             if tool_choice is not None:
                 api_tool_choice = self._converters.to_tool_choice(tool_choice)

+        api_llm_settings = deepcopy(self.llm_settings or {})
+        api_llm_settings.pop("use_structured_outputs", None)
+
         return dict(
             api_messages=api_messages,
             api_tools=api_tools,
@@ -216,6 +228,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
         *,
         api_tools: list[Any] | None = None,
         api_tool_choice: Any | None = None,
+        api_response_format: type | None = None,
         **api_llm_settings: Any,
     ) -> Any:
         pass
@@ -242,7 +255,11 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
             conversation=conversation, tool_choice=tool_choice
         )

-        if
+        if (
+            self._response_format is None
+            or (not self._struct_output_support)
+            or (not self._llm_settings.get("use_structured_outputs"))
+        ):
             completion_kwargs.pop("api_response_format", None)
             api_completion = await self._get_completion(**completion_kwargs, **kwargs)
         else:
@@ -250,7 +267,23 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
                 **completion_kwargs, **kwargs
             )

-
+        completion = self._converters.from_completion(
+            api_completion, model_id=self.model_id
+        )
+
+        for choice in completion.choices:
+            message = choice.message
+            if (
+                self._response_format_pyd is not None
+                and not self._llm_settings.get("use_structured_outputs")
+                and not message.tool_calls
+            ):
+                message_json = extract_json(
+                    message.content, return_none_on_failure=True
+                )
+                self._response_format_pyd.validate_python(message_json)
+
+        return completion

     async def generate_completion_stream(
         self,
@@ -271,63 +304,73 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
             api_completion_chunk_iterator, model_id=self.model_id
         )

-    async def
-        self,
-        conversation: Conversation,
-        *,
-        tool_choice: ToolChoice | None = None,
-        **kwargs: Any,
-    ) -> AssistantMessage:
-        completion = await self.generate_completion(
-            conversation, tool_choice=tool_choice, **kwargs
-        )
-        message = completion.choices[0].message
-        if self._response_format_pyd is not None and not self._struct_output_support:
-            self._response_format_pyd.validate_python(extract_json(message.content))
-
-        return message
-
-    async def _generate_message_with_retry(
+    async def _generate_completion_with_retry(
         self,
         conversation: Conversation,
         *,
         tool_choice: ToolChoice | None = None,
         **kwargs: Any,
-    ) ->
+    ) -> Completion:
         wrapped_func = retry(
             wait=wait_random_exponential(min=1, max=8),
             stop=stop_after_attempt(self.num_generation_retries + 1),
             before=retry_before_callback,
             retry_error_callback=retry_error_callback,
-        )(self.__class__.
+        )(self.__class__.generate_completion)

         return await wrapped_func(self, conversation, tool_choice=tool_choice, **kwargs)

     @limit_rate_chunked  # type: ignore
-    async def
+    async def _generate_completion_batch_with_retry_and_rate_lim(
         self,
         conversation: Conversation,
         *,
         tool_choice: ToolChoice | None = None,
         **kwargs: Any,
-    ) ->
-        return await self.
+    ) -> Completion:
+        return await self._generate_completion_with_retry(
             conversation, tool_choice=tool_choice, **kwargs
         )

-    async def
+    async def generate_completion_batch(
         self,
         message_history: MessageHistory,
         *,
         tool_choice: ToolChoice | None = None,
         **kwargs: Any,
-    ) -> Sequence[
-        return await self.
+    ) -> Sequence[Completion]:
+        return await self._generate_completion_batch_with_retry_and_rate_lim(
             list(message_history.batched_conversations),  # type: ignore
             tool_choice=tool_choice,
             **kwargs,
         )

+    async def generate_message(
+        self,
+        conversation: Conversation,
+        *,
+        tool_choice: ToolChoice | None = None,
+        **kwargs: Any,
+    ) -> AssistantMessage:
+        completion = await self.generate_completion(
+            conversation, tool_choice=tool_choice, **kwargs
+        )
+
+        return completion.choices[0].message
+
+    async def generate_message_batch(
+        self,
+        message_history: MessageHistory,
+        *,
+        tool_choice: ToolChoice | None = None,
+        **kwargs: Any,
+    ) -> Sequence[AssistantMessage]:
+        completion_batch = await self.generate_completion_batch(
+            message_history, tool_choice=tool_choice, **kwargs
+        )
+
+        return [completion.choices[0].message for completion in completion_batch]
+
     def _get_rate_limiter(
         self,
         rate_limiter: RateLimiterC[Conversation, AssistantMessage] | None = None,
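The new `use_structured_outputs` setting gates whether `CloudLLM` requests native structured outputs (only for providers whose `struct_output_support` patterns match the model name) or falls back to extracting and validating JSON from the completion content. A minimal sketch of how a caller might opt in; the import path and the assumption that `OpenAILLMSettings` is a TypedDict like `CloudLLMSettings` are not shown in this diff:

```python
from pydantic import BaseModel

# Assumed import path; only the constructor signature appears in this diff.
from grasp_agents.openai.openai_llm import OpenAILLM, OpenAILLMSettings


class Answer(BaseModel):
    text: str
    confidence: float


# With use_structured_outputs=True, __init__ raises ValueError when the chosen
# provider/model does not match the provider's struct_output_support patterns.
llm = OpenAILLM(
    model_name="gpt-4o-mini",  # hypothetical model name
    llm_settings=OpenAILLMSettings(use_structured_outputs=True),
    response_format=Answer,
    api_provider="openai",
)
```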
grasp_agents/comm_agent.py
CHANGED
@@ -4,6 +4,7 @@ from collections.abc import Sequence
 from typing import Any, Generic, Protocol, TypeVar, cast, final

 from pydantic import BaseModel
+from pydantic.json_schema import SkipJsonSchema

 from .agent_message import AgentMessage
 from .agent_message_pool import AgentMessagePool
@@ -14,6 +15,11 @@ from .typing.tool import BaseTool

 logger = logging.getLogger(__name__)

+
+class DCommAgentPayload(AgentPayload):
+    selected_recipient_ids: SkipJsonSchema[Sequence[AgentID]]
+
+
 _EH_OutT = TypeVar("_EH_OutT", bound=AgentPayload, contravariant=True)  # noqa: PLC0105
 _EH_StateT = TypeVar("_EH_StateT", bound=AgentState, contravariant=True)  # noqa: PLC0105

@@ -22,7 +28,6 @@ class ExitHandler(Protocol[_EH_OutT, _EH_StateT, CtxT]):
     def __call__(
         self,
         output_message: AgentMessage[_EH_OutT, _EH_StateT],
-        agent_state: _EH_StateT,
         ctx: RunContextWrapper[CtxT] | None,
     ) -> bool: ...

@@ -38,14 +43,11 @@ class CommunicatingAgent(
         rcv_args_schema: type[InT] = AgentPayload,
         recipient_ids: Sequence[AgentID] | None = None,
         message_pool: AgentMessagePool[CtxT] | None = None,
-        dynamic_routing: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(agent_id=agent_id, out_schema=out_schema, **kwargs)
         self._message_pool = message_pool or AgentMessagePool()

-        self._dynamic_routing = dynamic_routing
-
         self._is_listening = False
         self._exit_impl: ExitHandler[OutT, StateT, CtxT] | None = None

@@ -56,10 +58,6 @@ class CommunicatingAgent(
     def rcv_args_schema(self) -> type[InT]:  # type: ignore[reportInvalidTypeVarUse]
         return self._rcv_args_schema

-    @property
-    def dynamic_routing(self) -> bool:
-        return self._dynamic_routing
-
     def _parse_output(
         self,
         *args: Any,
@@ -72,41 +70,36 @@ class CommunicatingAgent(

         return self._out_schema()

-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def _validate_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
+        if all(isinstance(p, DCommAgentPayload) for p in payloads):
+            payloads_ = cast("Sequence[DCommAgentPayload]", payloads)
+            selected_recipient_ids_per_payload = [
+                set(p.selected_recipient_ids or []) for p in payloads_
+            ]
+            assert all(
+                x == selected_recipient_ids_per_payload[0]
+                for x in selected_recipient_ids_per_payload
+            ), "All payloads must have the same recipient IDs for dynamic routing"
+
+            assert payloads_[0].selected_recipient_ids is not None
+            selected_recipient_ids = payloads_[0].selected_recipient_ids
+
+            assert all(rid in self.recipient_ids for rid in selected_recipient_ids), (
+                "Dynamic routing is enabled, but recipient IDs are not in "
+                "the allowed agent's recipient IDs"
+            )

-
-                "Dynamic routing is enabled, but recipient IDs are not in "
-                "the allowed agent's recipient IDs"
-            )
+            return selected_recipient_ids

-
+        if all((not isinstance(p, DCommAgentPayload)) for p in payloads):
+            return self.recipient_ids

-
-
-                "Dynamic routing is not enabled, but some payloads have recipient IDs"
+        raise ValueError(
+            "All payloads must be either DCommAgentPayload or not DCommAgentPayload"
         )

-        return self.recipient_ids
-
     async def post_message(self, message: AgentMessage[OutT, StateT]) -> None:
-
-            self._validate_dynamic_routing(message.payloads)
-        else:
-            self._validate_static_routing(message.payloads)
+        self._validate_routing(message.payloads)

         await self._message_pool.post(message)

@@ -144,9 +137,7 @@ class CommunicatingAgent(
         ctx: RunContextWrapper[CtxT] | None,
     ) -> bool:
         if self._exit_impl:
-            return self._exit_impl(
-                output_message=output_message, agent_state=self.state, ctx=ctx
-            )
+            return self._exit_impl(output_message=output_message, ctx=ctx)

         return False

@@ -190,28 +181,28 @@ class CommunicatingAgent(

     @final
     def as_tool(
-        self,
-
-
-
-
-
+        self,
+        tool_name: str,
+        tool_description: str,
+        tool_strict: bool = True,
+    ) -> BaseTool[Any, Any, Any]:
         agent_instance = self

-        class AgentTool(BaseTool[
+        class AgentTool(BaseTool[Any, Any, Any]):
             name: str = tool_name
             description: str = tool_description
             in_schema: type[BaseModel] = agent_instance.rcv_args_schema
-            out_schema:
+            out_schema: Any = agent_instance.out_schema

             strict: bool | None = tool_strict

             async def run(
                 self,
-                inp:
+                inp: InT,
                 ctx: RunContextWrapper[CtxT] | None = None,
             ) -> OutT:
                 rcv_args = agent_instance.rcv_args_schema.model_validate(inp)
+
                 rcv_message = AgentMessage(  # type: ignore[arg-type]
                     payloads=[rcv_args],
                     sender_id="<tool_user>",
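The new `DCommAgentPayload` replaces the removed `dynamic_routing` flag: routing is now inferred from the payload type, and `_validate_routing` checks that all payloads agree on a recipient set contained in the agent's configured `recipient_ids`. A hedged sketch of a payload that opts into dynamic routing; the `RouterDecision` model and import path are illustrative, not part of the package:

```python
from grasp_agents.comm_agent import DCommAgentPayload  # assumed import path


class RouterDecision(DCommAgentPayload):
    # selected_recipient_ids is inherited from DCommAgentPayload; SkipJsonSchema
    # keeps it out of the JSON schema exposed to the LLM.
    verdict: str


# Payloads of this type take the dynamic branch of _validate_routing; plain
# AgentPayload subclasses fall back to the statically configured recipient_ids.
decision = RouterDecision(verdict="escalate", selected_recipient_ids=["reviewer"])
```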
grasp_agents/llm.py
CHANGED
@@ -1,7 +1,7 @@
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Sequence
-from typing import Any, Generic, TypeVar
+from typing import Any, Generic, TypeVar, cast
 from uuid import uuid4

 from pydantic import BaseModel
@@ -32,7 +32,7 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
         model_name: str | None = None,
         model_id: str | None = None,
         llm_settings: SettingsT | None = None,
-        tools: list[BaseTool[BaseModel,
+        tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
         response_format: type | None = None,
         **kwargs: Any,
     ) -> None:
@@ -41,9 +41,9 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
         self._converters = converters
         self._model_id = model_id or str(uuid4())[:8]
         self._model_name = model_name
-        self._llm_settings = llm_settings
         self._tools = {t.name: t for t in tools} if tools else None
         self._response_format = response_format
+        self._llm_settings: SettingsT = llm_settings or cast("SettingsT", {})

     @property
     def model_id(self) -> str:
@@ -54,11 +54,11 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
         return self._model_name

     @property
-    def llm_settings(self) -> SettingsT
+    def llm_settings(self) -> SettingsT:
         return self._llm_settings

     @property
-    def tools(self) -> dict[str, BaseTool[BaseModel,
+    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
         return self._tools

     @property
@@ -66,7 +66,7 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
         return self._response_format

     @tools.setter
-    def tools(self, tools: list[BaseTool[BaseModel,
+    def tools(self, tools: list[BaseTool[BaseModel, Any, Any]] | None) -> None:
         self._tools = {t.name: t for t in tools} if tools else None

     def __repr__(self) -> str:
grasp_agents/llm_agent.py
CHANGED
@@ -33,6 +33,8 @@ from .typing.io import (
     AgentPayload,
     AgentState,
     InT,
+    LLMFormattedArgs,
+    LLMFormattedSystemArgs,
     LLMPrompt,
     LLMPromptArgs,
     OutT,
@@ -67,7 +69,7 @@ class LLMAgent(
         # Output schema
         out_schema: type[OutT] = cast("type[OutT]", AgentPayload),
         # Tools
-        tools: list[BaseTool[
+        tools: list[BaseTool[Any, Any, CtxT]] | None = None,
         max_turns: int = 1000,
         react_mode: bool = False,
         # Agent state management
@@ -75,7 +77,6 @@ class LLMAgent(
         # Multi-agent routing
         message_pool: AgentMessagePool[CtxT] | None = None,
         recipient_ids: list[AgentID] | None = None,
-        dynamic_routing: bool = False,
     ) -> None:
         super().__init__(
             agent_id=agent_id,
@@ -83,7 +84,6 @@ class LLMAgent(
             rcv_args_schema=rcv_args_schema,
             message_pool=message_pool,
             recipient_ids=recipient_ids,
-            dynamic_routing=dynamic_routing,
         )

         # Agent state
@@ -114,12 +114,24 @@ class LLMAgent(

         self.no_tqdm = getattr(llm, "no_tqdm", False)

+        if type(self)._format_sys_args is not LLMAgent[Any, Any, Any]._format_sys_args:  # noqa: SLF001
+            self._prompt_builder.format_sys_args_impl = self._format_sys_args
+
+        if type(self)._format_inp_args is not LLMAgent[Any, Any, Any]._format_inp_args:  # noqa: SLF001
+            self._prompt_builder.format_inp_args_impl = self._format_inp_args
+
+        if (
+            type(self)._tool_call_loop_exit  # noqa: SLF001
+            is not LLMAgent[Any, Any, Any]._tool_call_loop_exit  # noqa: SLF001
+        ):
+            self._tool_orchestrator.tool_call_loop_exit_impl = self._tool_call_loop_exit
+
     @property
     def llm(self) -> LLM[LLMSettings, Converters]:
         return self._tool_orchestrator.llm

     @property
-    def tools(self) -> dict[str, BaseTool[BaseModel,
+    def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
         return self._tool_orchestrator.tools

     @property
@@ -142,37 +154,10 @@ class LLMAgent(
     def inp_prompt(self) -> LLMPrompt | None:
         return self._prompt_builder.inp_prompt

-    def format_sys_args_handler(
-        self, func: FormatSystemArgsHandler[CtxT]
-    ) -> FormatSystemArgsHandler[CtxT]:
-        self._prompt_builder.format_sys_args_impl = func
-
-        return func
-
-    def format_inp_args_handler(
-        self, func: FormatInputArgsHandler[InT, CtxT]
-    ) -> FormatInputArgsHandler[InT, CtxT]:
-        self._prompt_builder.format_inp_args_impl = func
-
-        return func
-
-    def make_custom_agent_state_handler(
-        self, func: MakeCustomAgentState
-    ) -> MakeCustomAgentState:
-        self._make_custom_agent_state_impl = func
-
-        return func
-
-    def tool_call_loop_exit_handler(
-        self, func: ToolCallLoopExitHandler[CtxT]
-    ) -> ToolCallLoopExitHandler[CtxT]:
-        self._tool_orchestrator.tool_call_loop_exit_impl = func
-
-        return func
-
     def _parse_output(
         self,
         conversation: Conversation,
+        *,
         rcv_args: InT | None = None,
         ctx: RunContextWrapper[CtxT] | None = None,
         **kwargs: Any,
@@ -274,10 +259,7 @@ class LLMAgent(

         # 6. Write interaction history to context

-
-            recipient_ids = self._validate_dynamic_routing(val_output_batch)
-        else:
-            recipient_ids = self._validate_static_routing(val_output_batch)
+        recipient_ids = self._validate_routing(val_output_batch)

         if ctx:
             interaction_record = InteractionRecord(
@@ -332,30 +314,66 @@ class LLMAgent(
     ):
         self._print_msgs([state.message_history[0][0]], ctx=ctx)

-    #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # -- Handlers for custom implementations --
+
+    def format_sys_args_handler(
+        self, func: FormatSystemArgsHandler[CtxT]
+    ) -> FormatSystemArgsHandler[CtxT]:
+        self._prompt_builder.format_sys_args_impl = func
+
+        return func
+
+    def format_inp_args_handler(
+        self, func: FormatInputArgsHandler[InT, CtxT]
+    ) -> FormatInputArgsHandler[InT, CtxT]:
+        self._prompt_builder.format_inp_args_impl = func
+
+        return func
+
+    def make_custom_agent_state_handler(
+        self, func: MakeCustomAgentState
+    ) -> MakeCustomAgentState:
+        self._make_custom_agent_state_impl = func
+
+        return func
+
+    def tool_call_loop_exit_handler(
+        self, func: ToolCallLoopExitHandler[CtxT]
+    ) -> ToolCallLoopExitHandler[CtxT]:
+        self._tool_orchestrator.tool_call_loop_exit_impl = func
+
+        return func
+
+    # -- Override these methods in subclasses if needed --
+
+    def _format_sys_args(
+        self,
+        sys_args: LLMPromptArgs,
+        *,
+        ctx: RunContextWrapper[CtxT] | None = None,
+    ) -> LLMFormattedSystemArgs:
+        raise NotImplementedError(
+            "LLMAgent._format_sys_args must be overridden by a subclass "
+            "if it's intended to be used as the system arguments formatter."
+        )
+
+    def _format_inp_args(
+        self,
+        usr_args: LLMPromptArgs,
+        rcv_args: InT,
+        ctx: RunContextWrapper[CtxT] | None = None,
+    ) -> LLMFormattedArgs:
+        raise NotImplementedError(
+            "LLMAgent._format_inp_args must be overridden by a subclass"
+        )
+
+    def _tool_call_loop_exit(
+        self,
+        conversation: Conversation,
+        *,
+        ctx: RunContextWrapper[CtxT] | None = None,
+        **kwargs: Any,
+    ) -> bool:
+        raise NotImplementedError(
+            "LLMAgent._tool_call_loop_exit must be overridden by a subclass"
+        )
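With the override detection added in `__init__`, a subclass can customize prompt formatting by overriding the private hooks directly; the decorator-style `*_handler` methods remain for the functional style. A rough sketch, assuming the import paths below and that the return value matches whatever shape `LLMFormattedSystemArgs` has (not visible in this diff):

```python
from typing import Any

from grasp_agents.llm_agent import LLMAgent  # assumed import path
from grasp_agents.typing.io import LLMPromptArgs  # assumed import path


class PlannerAgent(LLMAgent[Any, Any, Any]):
    # Because this override differs from the base implementation, __init__ wires
    # it into the prompt builder automatically, without format_sys_args_handler.
    def _format_sys_args(self, sys_args: LLMPromptArgs, *, ctx: Any = None):
        # Return shape is assumed; the diff only fixes the hook's name/signature.
        return {"objective": "plan the next step"}
```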
grasp_agents/memory.py
CHANGED
@@ -142,9 +142,3 @@ class MessageHistory:

     def erase(self) -> None:
         self._batched_conversations = [[]]
-
-    # def get_batch(self, batch_id: int) -> list[Message]:
-    #     return self._batched_conversations[batch_id]
-
-    # def iterate_conversations(self) -> Iterator[list[Message]]:
-    #     return iter(self._batched_conversations)
grasp_agents/openai/completion_converters.py
CHANGED
@@ -13,13 +13,14 @@ def from_api_completion(
     api_completion: ChatCompletion, model_id: str | None = None
 ) -> Completion:
     choices: list[CompletionChoice] = []
-    # TODO: add custom error type
     if api_completion.choices is None:  # type: ignore
-
+        # Choices can sometimes be None for some providers using the OpenAI API
+        # TODO: add custom error types
+        raise RuntimeError(
             f"Completion API error: {getattr(api_completion, 'error', None)}"
         )
     for api_choice in api_completion.choices:
-        # TODO: no way to assign individual message usages when len(choices) > 1
+        # TODO: currently no way to assign individual message usages when len(choices) > 1
         message = from_api_assistant_message(
             api_choice.message, api_completion.usage, model_id=model_id
         )
grasp_agents/openai/converters.py
CHANGED
@@ -6,12 +6,7 @@ from pydantic import BaseModel
 from ..typing.completion import Completion, CompletionChunk
 from ..typing.content import Content
 from ..typing.converters import Converters
-from ..typing.message import
-    AssistantMessage,
-    SystemMessage,
-    ToolMessage,
-    UserMessage,
-)
+from ..typing.message import AssistantMessage, SystemMessage, ToolMessage, UserMessage
 from ..typing.tool import BaseTool, ToolChoice
 from . import (
     ChatCompletion,
@@ -19,7 +14,6 @@ from . import (
     ChatCompletionAsyncStream,  # type: ignore[import]
     ChatCompletionChunk,
     ChatCompletionContentPartParam,
-    # ChatCompletionDeveloperMessageParam,
     ChatCompletionMessage,
     ChatCompletionSystemMessageParam,
     ChatCompletionToolChoiceOptionParam,
@@ -110,7 +104,7 @@ class OpenAIConverters(Converters):

     @staticmethod
     def to_tool(
-        tool: BaseTool[BaseModel,
+        tool: BaseTool[BaseModel, Any, Any], **kwargs: Any
     ) -> ChatCompletionToolParam:
         return to_api_tool(tool, **kwargs)

grasp_agents/openai/message_converters.py
CHANGED
@@ -132,9 +132,6 @@ def to_api_system_message(
     message: SystemMessage,
 ) -> ChatCompletionSystemMessageParam:
     return ChatCompletionSystemMessageParam(role="system", content=message.content)
-    # return ChatCompletionSystemMessageParam(
-    #     role="system", content=message.content
-    # )


 def from_api_tool_message(
@@ -149,7 +146,5 @@ def from_api_tool_message(

 def to_api_tool_message(message: ToolMessage) -> ChatCompletionToolMessageParam:
     return ChatCompletionToolMessageParam(
-        role="tool",
-        content=message.content,
-        tool_call_id=message.tool_call_id,
+        role="tool", content=message.content, tool_call_id=message.tool_call_id
     )
grasp_agents/openai/openai_llm.py
CHANGED
@@ -67,7 +67,7 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         model_name: str,
         model_id: str | None = None,
         llm_settings: OpenAILLMSettings | None = None,
-        tools: list[BaseTool[BaseModel,
+        tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
         response_format: type | None = None,
         # Connection settings
         api_provider: APIProvider = "openai",
@@ -113,8 +113,6 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
             base_url=self._base_url,
             api_key=self._api_key,
             **async_openai_client_params_,
-            # timeout=10.0,
-            # max_retries=3,
         )

     async def _get_completion(
grasp_agents/tool_orchestrator.py
CHANGED
@@ -31,7 +31,7 @@ class ToolOrchestrator(Generic[CtxT]):
         self,
         agent_id: str,
         llm: LLM[LLMSettings, Converters],
-        tools: list[BaseTool[BaseModel,
+        tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
         max_turns: int,
         react_mode: bool = False,
     ) -> None:
@@ -55,7 +55,7 @@ class ToolOrchestrator(Generic[CtxT]):
         return self._llm

     @property
-    def tools(self) -> dict[str, BaseTool[BaseModel,
+    def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
         return self._llm.tools or {}

     @property
grasp_agents/typing/converters.py
CHANGED
@@ -2,17 +2,9 @@ from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator
 from typing import Any

-from pydantic import BaseModel
-
 from .completion import Completion, CompletionChunk
 from .content import Content
-from .message import
-    AssistantMessage,
-    Message,
-    SystemMessage,
-    ToolMessage,
-    UserMessage,
-)
+from .message import AssistantMessage, Message, SystemMessage, ToolMessage, UserMessage
 from .tool import BaseTool, ToolChoice


@@ -72,7 +64,7 @@ class Converters(ABC):

     @staticmethod
     @abstractmethod
-    def to_tool(tool: BaseTool[
+    def to_tool(tool: BaseTool[Any, Any, Any], **kwargs: Any) -> Any:
         pass

     @staticmethod
grasp_agents/typing/io.py
CHANGED
@@ -1,8 +1,6 @@
-from collections.abc import Sequence
 from typing import TypeAlias, TypeVar

 from pydantic import BaseModel
-from pydantic.json_schema import SkipJsonSchema

 from .content import ImageData

@@ -10,8 +8,7 @@ AgentID: TypeAlias = str


 class AgentPayload(BaseModel):
-
-    selected_recipient_ids: SkipJsonSchema[Sequence[AgentID] | None] = None
+    pass


 class AgentState(BaseModel):
grasp_agents/typing/message.py
CHANGED
@@ -1,9 +1,11 @@
+import json
 from collections.abc import Hashable, Sequence
 from enum import StrEnum
-from typing import Annotated, Literal, TypeAlias
+from typing import Annotated, Any, Literal, TypeAlias
 from uuid import uuid4

 from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt
+from pydantic.json import pydantic_encoder

 from .content import Content, ImageData
 from .tool import ToolCall
@@ -110,13 +112,13 @@ class ToolMessage(MessageBase):
     @classmethod
     def from_tool_output(
         cls,
-        tool_output:
+        tool_output: Any,
         tool_call: ToolCall,
         model_id: str | None = None,
         indent: int = 2,
     ) -> "ToolMessage":
         return cls(
-            content=
+            content=json.dumps(tool_output, default=pydantic_encoder, indent=indent),
             tool_call_id=tool_call.id,
             model_id=model_id,
         )
grasp_agents/typing/tool.py
CHANGED
@@ -1,9 +1,11 @@
 from __future__ import annotations

+import asyncio
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar

-from pydantic import BaseModel
+from pydantic import BaseModel, TypeAdapter

 if TYPE_CHECKING:
     from ..run_context import CtxT, RunContextWrapper
@@ -14,8 +16,8 @@ else:
     """Runtime placeholder so RunContextWrapper[CtxT] works"""

-
-
+_ToolInT = TypeVar("_ToolInT", bound=BaseModel, contravariant=True)  # noqa: PLC0105
+_ToolOutT = TypeVar("_ToolOutT", covariant=True)  # noqa: PLC0105


 class ToolCall(BaseModel):
@@ -24,29 +26,34 @@ class ToolCall(BaseModel):
     tool_arguments: str


-class BaseTool(BaseModel, ABC, Generic[
+class BaseTool(BaseModel, ABC, Generic[_ToolInT, _ToolOutT, CtxT]):
     name: str
     description: str
-    in_schema: type[
-    out_schema: type[
+    in_schema: type[_ToolInT]
+    out_schema: type[_ToolOutT]

     # Supported by OpenAI API
     strict: bool | None = None

     @abstractmethod
     async def run(
-        self, inp:
-    ) ->
+        self, inp: _ToolInT, ctx: RunContextWrapper[CtxT] | None = None
+    ) -> _ToolOutT:
         pass

+    async def run_batch(
+        self, inp_batch: Sequence[_ToolInT], ctx: RunContextWrapper[CtxT] | None = None
+    ) -> Sequence[_ToolOutT]:
+        return await asyncio.gather(*[self.run(inp, ctx=ctx) for inp in inp_batch])
+
     async def __call__(
         self, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
-    ) ->
+    ) -> _ToolOutT:
         result = await self.run(self.in_schema(**kwargs), ctx=ctx)

-        return self.out_schema.
+        return TypeAdapter(self.out_schema).validate_python(result)


 ToolChoice: TypeAlias = (
-    Literal["none", "auto", "required"] | BaseTool[BaseModel,
+    Literal["none", "auto", "required"] | BaseTool[BaseModel, Any, Any]
 )
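The reworked `BaseTool` is now generic over an arbitrary (non-pydantic) output type and gains a default `run_batch` that fans inputs out with `asyncio.gather`. A minimal sketch of a concrete tool under the new type parameters; the `AddTool` example and the import path are illustrative only:

```python
import asyncio
from typing import Any

from pydantic import BaseModel

from grasp_agents.typing.tool import BaseTool  # assumed import path


class AddArgs(BaseModel):
    a: int
    b: int


class AddTool(BaseTool[AddArgs, int, Any]):
    name: str = "add"
    description: str = "Add two integers."
    in_schema: type[AddArgs] = AddArgs
    out_schema: type[int] = int

    async def run(self, inp: AddArgs, ctx: Any = None) -> int:
        return inp.a + inp.b


async def main() -> None:
    tool = AddTool()
    # run_batch is inherited: it awaits run() concurrently for each input.
    print(await tool.run_batch([AddArgs(a=1, b=2), AddArgs(a=3, b=4)]))  # [3, 7]


asyncio.run(main())
```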
grasp_agents/utils.py
CHANGED
@@ -1,57 +1,39 @@
 import ast
+import functools
 import asyncio
 from datetime import datetime
-import functools
 import json
 import re
-from
+from logging import getLogger
+from collections.abc import Callable
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, TypeVar, cast
-
-from pydantic import BaseModel, create_model
-from pydantic.fields import FieldInfo
+from typing import Any, TypeVar, cast, Coroutine
 from tqdm.autonotebook import tqdm

-
-
-
-    binary_mode: bool = False,
-) -> str:
-    """Reads and returns contents of file"""
-    try:
-        if binary_mode:
-            with open(file_path, "rb") as file:
-                return file.read()
-        else:
-            with open(file_path) as file:
-                return file.read()
-    except FileNotFoundError:
-        print(f"File {file_path} not found.")
-        return ""
+from pydantic import BaseModel, GetCoreSchemaHandler, TypeAdapter, create_model
+from pydantic.fields import FieldInfo
+from pydantic_core import core_schema


-
-    *corouts: Coroutine[Any, Any, Any],
-    no_tqdm: bool = False,
-    desc: str | None = None,
-) -> list[Any]:
-    pbar = tqdm(total=len(corouts), desc=desc, disable=no_tqdm)
+logger = getLogger(__name__)

-    async def run_and_update(coro: Coroutine[Any, Any, Any]) -> Any:
-        result = await coro
-        pbar.update(1)
-        return result

-
-
-
+def merge_pydantic_models(*models: type[BaseModel]) -> type[BaseModel]:
+    fields_dict: dict[str, FieldInfo] = {}
+    for model in models:
+        for field_name, field_info in model.model_fields.items():
+            if field_name in fields_dict:
+                raise ValueError(
+                    f"Field conflict detected: '{field_name}' exists in multiple models"
+                )
+            fields_dict[field_name] = field_info

-    return
+    return create_model("MergedModel", __module__=__name__, **fields_dict)  # type: ignore


-def
-    return
+def filter_fields(data: dict[str, Any], model: type[BaseModel]) -> dict[str, Any]:
+    return {key: data[key] for key in model.model_fields if key in data}


 def read_txt(file_path: str) -> str:
@@ -73,26 +55,32 @@ def format_json_string(text: str) -> str:
             pass
         i += 1

-    return
+    return text


-def read_json_string(
+def read_json_string(
+    json_str: str, return_none_on_failure: bool = False
+) -> dict[str, Any] | list[Any] | None:
     try:
         json_response = ast.literal_eval(json_str)
     except (ValueError, SyntaxError):
         try:
             json_response = json.loads(json_str)
         except json.JSONDecodeError as exc:
+            if return_none_on_failure:
+                return None
             raise ValueError(
                 "Invalid JSON - Both ast.literal_eval and json.loads "
                 f"failed to parse the following response:\n{json_str}"
             ) from exc

-        return json_response
+    return json_response


-def extract_json(
-
+def extract_json(
+    json_str: str, return_none_on_failure: bool = False
+) -> dict[str, Any] | list[Any] | None:
+    return read_json_string(format_json_string(json_str), return_none_on_failure)


 def extract_xml_list(text: str) -> list[str]:
@@ -105,32 +93,43 @@ def extract_xml_list(text: str) -> list[str]:
     return chunks


-def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    )
-
-
-
+def make_conditional_parsed_output_type(
+    response_format: type, marker: str = "<DONE>"
+) -> type:
+    class ParsedOutput:
+        """
+        * Accepts any **str**.
+        * If the string contains `marker`, it must contain a valid JSON for
+          `response_format` → we return that a response_format instance.
+        * Otherwise we leave the string untouched.
+        """
+
+        @classmethod
+        def __get_pydantic_core_schema__(
+            cls,
+            _source_type: Any,
+            _handler: GetCoreSchemaHandler,
+        ) -> core_schema.CoreSchema:
+            def validator(v: Any) -> Any:
+                if isinstance(v, str) and marker in v:
+                    v_json_str = format_json_string(v)
+                    response_format_adapter = TypeAdapter[Any](response_format)
+
+                    return response_format_adapter.validate_json(v_json_str)
+
+                return v
+
+            return core_schema.no_info_after_validator_function(
+                validator, core_schema.any_schema()
+            )

+        @classmethod
+        def __get_pydantic_json_schema__(
+            cls, core_schema: core_schema.CoreSchema, handler: GetCoreSchemaHandler
+        ):
+            return handler(core_schema)

-
-    return {key: data[key] for key in model.model_fields if key in data}
+    return ParsedOutput


 T = TypeVar("T", bound=Callable[..., Any])
@@ -149,3 +148,54 @@ def forbid_state_change(method: T) -> T:
         return result

     return cast("T", wrapper)
+
+
+def read_contents_from_file(
+    file_path: str | Path,
+    binary_mode: bool = False,
+) -> str | bytes:
+    """Reads and returns contents of file"""
+    try:
+        if binary_mode:
+            with open(file_path, "rb") as file:
+                return file.read()
+        else:
+            with open(file_path) as file:
+                return file.read()
+    except FileNotFoundError:
+        logger.error(f"File {file_path} not found.")
+        return ""
+
+
+def get_prompt(prompt_text: str | None, prompt_path: str | Path | None) -> str | None:
+    if prompt_text is None:
+        prompt = (
+            read_contents_from_file(prompt_path) if prompt_path is not None else None
+        )
+    else:
+        prompt = prompt_text
+
+    return prompt  # type: ignore[assignment]
+
+
+async def asyncio_gather_with_pbar(
+    *corouts: Coroutine[Any, Any, Any],
+    no_tqdm: bool = False,
+    desc: str | None = None,
+) -> list[Any]:
+    pbar = tqdm(total=len(corouts), desc=desc, disable=no_tqdm)
+
+    async def run_and_update(coro: Coroutine[Any, Any, Any]) -> Any:
+        result = await coro
+        pbar.update(1)
+        return result
+
+    wrapped_tasks = [run_and_update(c) for c in corouts]
+    results = await asyncio.gather(*wrapped_tasks)
+    pbar.close()
+
+    return results
+
+
+def get_timestamp() -> str:
+    return datetime.now().strftime("%Y%m%d_%H%M%S")
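utils.py also grows a forgiving JSON path: `read_json_string` and `extract_json` can now return `None` instead of raising when parsing fails, which is what `CloudLLM` relies on when validating free-form completions. A small usage sketch, assuming `format_json_string` passes already well-formed JSON through unchanged:

```python
from grasp_agents.utils import extract_json  # assumed import path

# A clean JSON payload parses as before (expected output shown in the comment).
print(extract_json('{"status": "ok", "items": [1, 2, 3]}'))
# {'status': 'ok', 'items': [1, 2, 3]}

# Malformed content no longer has to raise: with return_none_on_failure=True the
# helper yields None, so callers can skip validation instead of crashing.
print(extract_json("the model rambled instead of emitting JSON", return_none_on_failure=True))
# None
```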
{grasp_agents-0.1.14.dist-info → grasp_agents-0.1.16.dist-info}/METADATA
CHANGED
@@ -1,12 +1,13 @@
 Metadata-Version: 2.4
 Name: grasp_agents
-Version: 0.1.14
+Version: 0.1.16
 Summary: Grasp Agents Library
 License-File: LICENSE.md
-Requires-Python: <
+Requires-Python: <4,>=3.11.4
 Requires-Dist: dotenv>=0.9.9
 Requires-Dist: httpx<1,>=0.27.0
 Requires-Dist: openai<2,>=1.68.2
+Requires-Dist: pydantic>=2
 Requires-Dist: pyyaml>=6.0.2
 Requires-Dist: tenacity>=9.1.2
 Requires-Dist: termcolor<3,>=2.4.0
@@ -53,8 +54,6 @@ Description-Content-Type: text/markdown

 ## Quickstart & Installation Variants (UV Package manager)

-### Option 1: UV Package Manager Project
-
 > **Note:** You can check this sample project code in the [src/grasp_agents/examples/demo/uv](src/grasp_agents/examples/demo/uv) folder. Feel free to copy and paste the code from there to a separate project. There are also [examples](src/grasp_agents/examples/demo/) for other package managers.

 #### 1. Prerequisites
{grasp_agents-0.1.14.dist-info → grasp_agents-0.1.16.dist-info}/RECORD
CHANGED
@@ -1,44 +1,44 @@
-grasp_agents/agent_message.py,sha256=
+grasp_agents/agent_message.py,sha256=Z-czIHNSe7eAo5r6q2Zu10HFpUjQjVbayyglGyyj4Lw,911
 grasp_agents/agent_message_pool.py,sha256=4O4xz-aZ_I5m3iLAUiMAQcVn5AGRP4daUM8XME03Xsw,3250
-grasp_agents/base_agent.py,sha256=
-grasp_agents/cloud_llm.py,sha256=
-grasp_agents/comm_agent.py,sha256=
+grasp_agents/base_agent.py,sha256=4mNMyWdt9MTzg64JhrVjcU_FSYDj0ED_AKpRZXCRVb4,1870
+grasp_agents/cloud_llm.py,sha256=q72bDyC9bxQdiO552LiOYr58B9T3bMAJu0TGUZH4Lfk,13016
+grasp_agents/comm_agent.py,sha256=ukkT0zXgE8-yt63fyi3MQ_3Q0QjwIDHKz4MbHt9FBxw,7323
 grasp_agents/costs_dict.yaml,sha256=EW6XxRXLZobMwQEEiUNYALbDzfbZFb2zEVCaTSAqYjw,2334
 grasp_agents/grasp_logging.py,sha256=H1GYhXdQvVkmauFDZ-KDwvVmPQHZUUm9sRqX_ObK2xI,1111
 grasp_agents/http_client.py,sha256=KZva2MjJjuI5ohUeU8RdTAImUnQYaqBrV2jDH8smbJw,738
-grasp_agents/llm.py,sha256=
-grasp_agents/llm_agent.py,sha256=
+grasp_agents/llm.py,sha256=RZY25UNJEdPUmqOHifUrrQTgfoDCAVtZN9WQvvxnLC4,3004
+grasp_agents/llm_agent.py,sha256=_D16rA8LDhALZKYYuWxB-E75g3uBgzM04EcJyuRMKZA,12619
 grasp_agents/llm_agent_state.py,sha256=91K1-8Uodbe-t_I6nu0xBzHfQjssZYCHjMuDbu5aCr0,2327
-grasp_agents/memory.py,sha256=
+grasp_agents/memory.py,sha256=X1YtVX8XxP5KnGPMW8BqjID8QK4hTG2obxoyhnnZ4pU,5575
 grasp_agents/printer.py,sha256=Jk6OJExio53gbKBod5Dd8Y3CWYrVb4K5q4UJ8i9cQvo,5024
 grasp_agents/prompt_builder.py,sha256=rYVIY4adJwBitjrTYvpEh5x8C7cLbIiXxT1F-vQuvEM,7393
 grasp_agents/run_context.py,sha256=hyETO3-p0azPFws75kX6rrUDLf58Ar6jmyt6TQ5Po78,2589
-grasp_agents/tool_orchestrator.py,sha256=
+grasp_agents/tool_orchestrator.py,sha256=6VX8FGYeUHiSt9GpUjOga7KGP55Kzs1jEiNcOZysIAo,5501
 grasp_agents/usage_tracker.py,sha256=5YuN6hpg6HASdg-hOylgWzhCiORmDMnZuQtbISfhm_4,3378
-grasp_agents/utils.py,sha256=
+grasp_agents/utils.py,sha256=tacUTUnPqz8qUvJxGpGhcdXOyvZJOzfblCrtzqfkCj8,5870
 grasp_agents/data_retrieval/__init__.py,sha256=KRgtF_E7R3YfA2cpYcFcZ7wycV0pWVJ0xRQC7YhiIEQ,158
 grasp_agents/data_retrieval/rate_limiter_chunked.py,sha256=NPqYrWwKTx1lim_zlhWar5wDwFz1cA-b6JOzT10kOtE,5843
 grasp_agents/data_retrieval/types.py,sha256=JbLYJC-gmzcHH_4-YNTz9IcIwVpcpDyDGvljxNznf5k,1389
 grasp_agents/data_retrieval/utils.py,sha256=D3Bkq6-9gF7ubearjZwZaTt_u2-sM3JDlGQD9HmJ3rQ,1917
 grasp_agents/openai/__init__.py,sha256=qN8HMAatSJKOsA6v-JwakMYguwkswCVHqrmK1gFy9wI,3096
-grasp_agents/openai/completion_converters.py,sha256=
+grasp_agents/openai/completion_converters.py,sha256=lX9h1kaGAo5ttsl-4V7l4x8IpjxJaJJtyU2cKu3-EOc,1871
 grasp_agents/openai/content_converters.py,sha256=6GI0D7xJalzsiawAJOyCUzTJTo0NQdpv87YKmfN0LYQ,2631
-grasp_agents/openai/converters.py,sha256=
-grasp_agents/openai/message_converters.py,sha256=
-grasp_agents/openai/openai_llm.py,sha256=
-grasp_agents/openai/tool_converters.py,sha256=
+grasp_agents/openai/converters.py,sha256=DBXBxow9oRG6pc8inpZBLiuUqHzVfpscmHFpN9bAdvc,5276
+grasp_agents/openai/message_converters.py,sha256=KjF6FbXzwlWdM-1YT3cswUV-74sjiwOhLFPMY4sJ5Xk,4593
+grasp_agents/openai/openai_llm.py,sha256=rHeix_lSH8mogPfHjgrsaYMrHBMt5q9k-4x33_-zwyA,6149
+grasp_agents/openai/tool_converters.py,sha256=KhWRETkjhjocISUo_HBZ8QfBiyTOoC5WurPNAR4BYxc,1027
 grasp_agents/typing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 grasp_agents/typing/completion.py,sha256=_KDLx3Gtz7o-pEZrvAFgCZwDmkr2oQkxrL-2LSXHHsw,657
 grasp_agents/typing/content.py,sha256=13nLNZqZgtpo9sM0vCRQmZ4bQjjZqUSElMQOwjL7bO8,3651
-grasp_agents/typing/converters.py,sha256=
-grasp_agents/typing/io.py,sha256=
-grasp_agents/typing/message.py,sha256=
-grasp_agents/typing/tool.py,sha256=
+grasp_agents/typing/converters.py,sha256=EbGur_Ngx-wINfyOEHa7JqmKnMAxZ5vJjhAPoqBf_AM,3048
+grasp_agents/typing/io.py,sha256=M294tNROa9gJ_I38c9pLekEbUK2p3qirYdE_QGpuw1c,624
+grasp_agents/typing/message.py,sha256=oCpqD_CV2Da-M-l-e5liFJSwK8267fxfcU68LIc7C1E,3801
+grasp_agents/typing/tool.py,sha256=t40vr3ljQY_Qx0f0KZIc151bYlHUF7Bf7JfOkiJOi2c,1657
 grasp_agents/workflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 grasp_agents/workflow/looped_agent.py,sha256=YBOgOIvy3_NwKvEoGgzQJ2fY9SNG66MQk6obSBGWvCc,3896
 grasp_agents/workflow/sequential_agent.py,sha256=yDt2nA-b1leVByD8jsKrWD6bHe0o9z33jrOJGOLwbyk,2004
 grasp_agents/workflow/workflow_agent.py,sha256=9U94IQ39Vb1W_5u8aoqHb65ikdarEhEJkexDz8xwHD4,2294
-grasp_agents-0.1.
-grasp_agents-0.1.
-grasp_agents-0.1.
-grasp_agents-0.1.
+grasp_agents-0.1.16.dist-info/METADATA,sha256=s8QElgMOrmr0ZXlfWgRvPbz-kyfDtUBPnCu2EPCUsBY,4254
+grasp_agents-0.1.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+grasp_agents-0.1.16.dist-info/licenses/LICENSE.md,sha256=Kfeo0gdlLS6tLQiWwO9UWhjp9-f93a5kShSiBp2FG-c,1201
+grasp_agents-0.1.16.dist-info/RECORD,,
{grasp_agents-0.1.14.dist-info → grasp_agents-0.1.16.dist-info}/WHEEL
File without changes
{grasp_agents-0.1.14.dist-info → grasp_agents-0.1.16.dist-info}/licenses/LICENSE.md
File without changes