grasp_agents 0.2.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/LICENSE.md +1 -1
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/PKG-INFO +1 -1
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/pyproject.toml +1 -1
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/cloud_llm.py +1 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/llm_agent.py +70 -21
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/llm_agent_state.py +9 -7
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/prompt_builder.py +11 -4
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/tool_orchestrator.py +8 -8
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/content.py +2 -2
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/io.py +3 -2
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/message.py +2 -2
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/tool.py +0 -7
- grasp_agents-0.2.1/src/grasp_agents/utils.py +194 -0
- grasp_agents-0.2.0/src/grasp_agents/utils.py +0 -187
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/.gitignore +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/README.md +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/agent_message.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/agent_message_pool.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/base_agent.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/comm_agent.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/costs_dict.yaml +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/generics_utils.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/grasp_logging.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/http_client.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/llm.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/memory.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/__init__.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/completion_converters.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/content_converters.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/converters.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/message_converters.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/openai_llm.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/tool_converters.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/printer.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/rate_limiting/__init__.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/rate_limiting/types.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/rate_limiting/utils.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/run_context.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/__init__.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/completion.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/converters.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/usage_tracker.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/workflow/__init__.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/workflow/looped_agent.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/workflow/sequential_agent.py +0 -0
- {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/workflow/workflow_agent.py +0 -0
@@ -8,6 +8,6 @@ Package production dependencies are licensed under the following terms:
|
|
8
8
|
| dotenv | 0.9.9 | BSD-3-Clause license | https://github.com/pedroburon/dotenv |
|
9
9
|
| httpx | 0.28.1 | BSD License | https://github.com/encode/httpx |
|
10
10
|
| openai | 1.77.0 | Apache Software License | https://github.com/openai/openai-python |
|
11
|
-
| tenacity |
|
11
|
+
| tenacity | 8.5.0 | Apache Software License | https://github.com/jd/tenacity |
|
12
12
|
| termcolor | 2.5.0 | MIT License | https://github.com/termcolor/termcolor |
|
13
13
|
| tqdm | 4.67.1 | MIT License; Mozilla Public License 2.0 (MPL 2.0) | https://tqdm.github.io |
|
@@ -10,7 +10,7 @@ from .comm_agent import CommunicatingAgent
|
|
10
10
|
from .llm import LLM, LLMSettings
|
11
11
|
from .llm_agent_state import (
|
12
12
|
LLMAgentState,
|
13
|
-
|
13
|
+
SetAgentState,
|
14
14
|
SetAgentStateStrategy,
|
15
15
|
)
|
16
16
|
from .prompt_builder import (
|
@@ -25,7 +25,11 @@ from .run_context import (
|
|
25
25
|
SystemRunArgs,
|
26
26
|
UserRunArgs,
|
27
27
|
)
|
28
|
-
from .tool_orchestrator import
|
28
|
+
from .tool_orchestrator import (
|
29
|
+
ExitToolCallLoopHandler,
|
30
|
+
ManageAgentStateHandler,
|
31
|
+
ToolOrchestrator,
|
32
|
+
)
|
29
33
|
from .typing.content import ImageData
|
30
34
|
from .typing.converters import Converters
|
31
35
|
from .typing.io import (
|
@@ -43,9 +47,14 @@ from .typing.tool import BaseTool
|
|
43
47
|
from .utils import get_prompt, validate_obj_from_json_or_py_string
|
44
48
|
|
45
49
|
|
46
|
-
class ParseOutputHandler(Protocol[OutT, CtxT]):
|
50
|
+
class ParseOutputHandler(Protocol[InT, OutT, CtxT]):
|
47
51
|
def __call__(
|
48
|
-
self,
|
52
|
+
self,
|
53
|
+
conversation: Conversation,
|
54
|
+
*,
|
55
|
+
rcv_args: InT | None,
|
56
|
+
batch_idx: int,
|
57
|
+
ctx: RunContextWrapper[CtxT] | None,
|
49
58
|
) -> OutT: ...
|
50
59
|
|
51
60
|
|
@@ -91,9 +100,15 @@ class LLMAgent(
|
|
91
100
|
# Agent state
|
92
101
|
self._state: LLMAgentState = LLMAgentState()
|
93
102
|
self.set_state_strategy: SetAgentStateStrategy = set_state_strategy
|
94
|
-
self.
|
103
|
+
self._set_agent_state_impl: SetAgentState | None = None
|
95
104
|
|
96
105
|
# Tool orchestrator
|
106
|
+
|
107
|
+
self._using_default_llm_response_format: bool = False
|
108
|
+
if llm.response_format is None and tools is None:
|
109
|
+
llm.response_format = self.out_type
|
110
|
+
self._using_default_llm_response_format = True
|
111
|
+
|
97
112
|
self._tool_orchestrator: ToolOrchestrator[CtxT] = ToolOrchestrator[CtxT](
|
98
113
|
agent_id=self.agent_id,
|
99
114
|
llm=llm,
|
@@ -152,16 +167,28 @@ class LLMAgent(
|
|
152
167
|
conversation: Conversation,
|
153
168
|
*,
|
154
169
|
rcv_args: InT | None = None,
|
170
|
+
batch_idx: int = 0,
|
155
171
|
ctx: RunContextWrapper[CtxT] | None = None,
|
156
|
-
**kwargs: Any,
|
157
172
|
) -> OutT:
|
158
173
|
if self._parse_output_impl:
|
174
|
+
if self._using_default_llm_response_format:
|
175
|
+
# When using custom output parsing, the required LLM response format
|
176
|
+
# can differ from the final agent output type ->
|
177
|
+
# set it back to None unless it was specified explicitly at init.
|
178
|
+
self._tool_orchestrator.llm.response_format = None
|
179
|
+
self._using_default_llm_response_format = False
|
180
|
+
|
159
181
|
return self._parse_output_impl(
|
160
|
-
conversation=conversation,
|
182
|
+
conversation=conversation,
|
183
|
+
rcv_args=rcv_args,
|
184
|
+
batch_idx=batch_idx,
|
185
|
+
ctx=ctx,
|
161
186
|
)
|
162
187
|
|
163
188
|
return validate_obj_from_json_or_py_string(
|
164
|
-
str(conversation[-1].content),
|
189
|
+
str(conversation[-1].content),
|
190
|
+
adapter=self._out_type_adapter,
|
191
|
+
from_substring=True,
|
165
192
|
)
|
166
193
|
|
167
194
|
@final
|
@@ -210,7 +237,7 @@ class LLMAgent(
|
|
210
237
|
rcv_state=rcv_state,
|
211
238
|
sys_prompt=formatted_sys_prompt,
|
212
239
|
strategy=self.set_state_strategy,
|
213
|
-
|
240
|
+
set_agent_state_impl=self._set_agent_state_impl,
|
214
241
|
ctx=ctx,
|
215
242
|
)
|
216
243
|
|
@@ -234,7 +261,7 @@ class LLMAgent(
|
|
234
261
|
else:
|
235
262
|
# 4. Run tool call loop (new messages are added to the message
|
236
263
|
# history inside the loop)
|
237
|
-
await self._tool_orchestrator.run_loop(
|
264
|
+
await self._tool_orchestrator.run_loop(state=state, ctx=ctx)
|
238
265
|
|
239
266
|
# 5. Parse outputs
|
240
267
|
batch_size = state.message_history.batch_size
|
@@ -324,15 +351,13 @@ class LLMAgent(
|
|
324
351
|
return func
|
325
352
|
|
326
353
|
def parse_output_handler(
|
327
|
-
self, func: ParseOutputHandler[OutT, CtxT]
|
328
|
-
) -> ParseOutputHandler[OutT, CtxT]:
|
354
|
+
self, func: ParseOutputHandler[InT, OutT, CtxT]
|
355
|
+
) -> ParseOutputHandler[InT, OutT, CtxT]:
|
329
356
|
self._parse_output_impl = func
|
330
357
|
|
331
358
|
return func
|
332
359
|
|
333
|
-
def
|
334
|
-
self, func: MakeCustomAgentState
|
335
|
-
) -> MakeCustomAgentState:
|
360
|
+
def set_agent_state_handler(self, func: SetAgentState) -> SetAgentState:
|
336
361
|
self._make_custom_agent_state_impl = func
|
337
362
|
|
338
363
|
return func
|
@@ -344,6 +369,13 @@ class LLMAgent(
|
|
344
369
|
|
345
370
|
return func
|
346
371
|
|
372
|
+
def manage_agent_state_handler(
|
373
|
+
self, func: ManageAgentStateHandler[CtxT]
|
374
|
+
) -> ManageAgentStateHandler[CtxT]:
|
375
|
+
self._tool_orchestrator.manage_agent_state_impl = func
|
376
|
+
|
377
|
+
return func
|
378
|
+
|
347
379
|
# -- Override these methods in subclasses if needed --
|
348
380
|
|
349
381
|
def _register_overridden_handlers(self) -> None:
|
@@ -356,15 +388,18 @@ class LLMAgent(
|
|
356
388
|
if cur_cls._format_inp_args is not base_cls._format_inp_args: # noqa: SLF001
|
357
389
|
self._prompt_builder.format_inp_args_impl = self._format_inp_args
|
358
390
|
|
359
|
-
if cur_cls.
|
360
|
-
self.
|
391
|
+
if cur_cls._set_agent_state is not base_cls._set_agent_state: # noqa: SLF001
|
392
|
+
self._set_agent_state_impl = self._set_agent_state
|
393
|
+
|
394
|
+
if cur_cls._manage_agent_state is not base_cls._manage_agent_state: # noqa: SLF001
|
395
|
+
self._tool_orchestrator.manage_agent_state_impl = self._manage_agent_state
|
361
396
|
|
362
397
|
if (
|
363
398
|
cur_cls._tool_call_loop_exit is not base_cls._tool_call_loop_exit # noqa: SLF001
|
364
399
|
):
|
365
400
|
self._tool_orchestrator.exit_tool_call_loop_impl = self._tool_call_loop_exit
|
366
401
|
|
367
|
-
self._parse_output_impl: ParseOutputHandler[OutT, CtxT] | None = None
|
402
|
+
self._parse_output_impl: ParseOutputHandler[InT, OutT, CtxT] | None = None
|
368
403
|
|
369
404
|
def _format_sys_args(
|
370
405
|
self,
|
@@ -381,21 +416,24 @@ class LLMAgent(
|
|
381
416
|
self,
|
382
417
|
usr_args: LLMPromptArgs,
|
383
418
|
rcv_args: InT,
|
419
|
+
*,
|
420
|
+
batch_idx: int = 0,
|
384
421
|
ctx: RunContextWrapper[CtxT] | None = None,
|
385
422
|
) -> LLMFormattedArgs:
|
386
423
|
raise NotImplementedError(
|
387
424
|
"LLMAgent._format_inp_args must be overridden by a subclass"
|
388
425
|
)
|
389
426
|
|
390
|
-
def
|
427
|
+
def _set_agent_state(
|
391
428
|
self,
|
392
|
-
cur_state: LLMAgentState
|
429
|
+
cur_state: LLMAgentState,
|
430
|
+
*,
|
393
431
|
rcv_state: AgentState | None,
|
394
432
|
sys_prompt: LLMPrompt | None,
|
395
433
|
ctx: RunContextWrapper[Any] | None,
|
396
434
|
) -> LLMAgentState:
|
397
435
|
raise NotImplementedError(
|
398
|
-
"LLMAgent.
|
436
|
+
"LLMAgent._set_agent_state_handler must be overridden by a subclass"
|
399
437
|
)
|
400
438
|
|
401
439
|
def _tool_call_loop_exit(
|
@@ -408,3 +446,14 @@ class LLMAgent(
|
|
408
446
|
raise NotImplementedError(
|
409
447
|
"LLMAgent._tool_call_loop_exit must be overridden by a subclass"
|
410
448
|
)
|
449
|
+
|
450
|
+
def _manage_agent_state(
|
451
|
+
self,
|
452
|
+
state: LLMAgentState,
|
453
|
+
*,
|
454
|
+
ctx: RunContextWrapper[CtxT] | None = None,
|
455
|
+
**kwargs: Any,
|
456
|
+
) -> None:
|
457
|
+
raise NotImplementedError(
|
458
|
+
"LLMAgent._manage_agent_state must be overridden by a subclass"
|
459
|
+
)
|
@@ -10,10 +10,11 @@ from .typing.io import AgentState, LLMPrompt
|
|
10
10
|
SetAgentStateStrategy = Literal["keep", "reset", "from_sender", "custom"]
|
11
11
|
|
12
12
|
|
13
|
-
class
|
13
|
+
class SetAgentState(Protocol):
|
14
14
|
def __call__(
|
15
15
|
self,
|
16
|
-
cur_state:
|
16
|
+
cur_state: "LLMAgentState",
|
17
|
+
*,
|
17
18
|
rcv_state: AgentState | None,
|
18
19
|
sys_prompt: LLMPrompt | None,
|
19
20
|
ctx: RunContextWrapper[Any] | None,
|
@@ -30,11 +31,12 @@ class LLMAgentState(AgentState):
|
|
30
31
|
@classmethod
|
31
32
|
def from_cur_and_rcv_states(
|
32
33
|
cls,
|
33
|
-
cur_state:
|
34
|
+
cur_state: "LLMAgentState",
|
35
|
+
*,
|
34
36
|
rcv_state: Optional["AgentState"] = None,
|
35
37
|
sys_prompt: LLMPrompt | None = None,
|
36
38
|
strategy: SetAgentStateStrategy = "from_sender",
|
37
|
-
|
39
|
+
set_agent_state_impl: SetAgentState | None = None,
|
38
40
|
ctx: RunContextWrapper[Any] | None = None,
|
39
41
|
) -> "LLMAgentState":
|
40
42
|
upd_mh = cur_state.message_history if cur_state else None
|
@@ -59,10 +61,10 @@ class LLMAgentState(AgentState):
|
|
59
61
|
upd_mh.reset(sys_prompt)
|
60
62
|
|
61
63
|
elif strategy == "custom":
|
62
|
-
assert
|
63
|
-
"
|
64
|
+
assert set_agent_state_impl is not None, (
|
65
|
+
"Agent state setter implementation is not provided."
|
64
66
|
)
|
65
|
-
return
|
67
|
+
return set_agent_state_impl(
|
66
68
|
cur_state=cur_state,
|
67
69
|
rcv_state=rcv_state,
|
68
70
|
sys_prompt=sys_prompt,
|
@@ -35,6 +35,8 @@ class FormatInputArgsHandler(Protocol[InT, CtxT]):
|
|
35
35
|
self,
|
36
36
|
usr_args: LLMPromptArgs,
|
37
37
|
rcv_args: InT,
|
38
|
+
*,
|
39
|
+
batch_idx: int,
|
38
40
|
ctx: RunContextWrapper[CtxT] | None,
|
39
41
|
) -> LLMFormattedArgs: ...
|
40
42
|
|
@@ -77,11 +79,13 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
77
79
|
self,
|
78
80
|
usr_args: LLMPromptArgs,
|
79
81
|
rcv_args: InT,
|
82
|
+
*,
|
83
|
+
batch_idx: int = 0,
|
80
84
|
ctx: RunContextWrapper[CtxT] | None = None,
|
81
85
|
) -> LLMFormattedArgs:
|
82
86
|
if self.format_inp_args_impl:
|
83
87
|
return self.format_inp_args_impl(
|
84
|
-
usr_args=usr_args, rcv_args=rcv_args, ctx=ctx
|
88
|
+
usr_args=usr_args, rcv_args=rcv_args, batch_idx=batch_idx, ctx=ctx
|
85
89
|
)
|
86
90
|
|
87
91
|
if not isinstance(rcv_args, BaseModel) and rcv_args is not None:
|
@@ -100,6 +104,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
100
104
|
def make_sys_prompt(
|
101
105
|
self,
|
102
106
|
sys_args: LLMPromptArgs,
|
107
|
+
*,
|
103
108
|
ctx: RunContextWrapper[CtxT] | None,
|
104
109
|
) -> LLMPrompt | None:
|
105
110
|
if self.sys_prompt is None:
|
@@ -151,9 +156,11 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
151
156
|
]
|
152
157
|
|
153
158
|
formatted_inp_args_batch = [
|
154
|
-
self._format_inp_args(
|
155
|
-
|
156
|
-
|
159
|
+
self._format_inp_args(
|
160
|
+
usr_args=val_usr_args, rcv_args=val_rcv_args, batch_idx=i, ctx=ctx
|
161
|
+
)
|
162
|
+
for i, (val_usr_args, val_rcv_args) in enumerate(
|
163
|
+
zip(val_usr_args_batch_, val_rcv_args_batch_, strict=False)
|
157
164
|
)
|
158
165
|
]
|
159
166
|
|
@@ -29,7 +29,7 @@ class ExitToolCallLoopHandler(Protocol[CtxT]):
|
|
29
29
|
class ManageAgentStateHandler(Protocol[CtxT]):
|
30
30
|
def __call__(
|
31
31
|
self,
|
32
|
-
|
32
|
+
state: LLMAgentState,
|
33
33
|
*,
|
34
34
|
ctx: RunContextWrapper[CtxT] | None,
|
35
35
|
**kwargs: Any,
|
@@ -93,13 +93,13 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
93
93
|
|
94
94
|
def _manage_agent_state(
|
95
95
|
self,
|
96
|
-
|
96
|
+
state: LLMAgentState,
|
97
97
|
*,
|
98
98
|
ctx: RunContextWrapper[CtxT] | None = None,
|
99
99
|
**kwargs: Any,
|
100
100
|
) -> None:
|
101
101
|
if self.manage_agent_state_impl:
|
102
|
-
self.manage_agent_state_impl(
|
102
|
+
self.manage_agent_state_impl(state=state, ctx=ctx, **kwargs)
|
103
103
|
|
104
104
|
async def generate_once(
|
105
105
|
self,
|
@@ -119,10 +119,10 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
119
119
|
|
120
120
|
async def run_loop(
|
121
121
|
self,
|
122
|
-
|
122
|
+
state: LLMAgentState,
|
123
123
|
ctx: RunContextWrapper[CtxT] | None = None,
|
124
124
|
) -> None:
|
125
|
-
message_history =
|
125
|
+
message_history = state.message_history
|
126
126
|
assert message_history.batch_size == 1, (
|
127
127
|
"Batch size must be 1 for tool call loop"
|
128
128
|
)
|
@@ -131,13 +131,13 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
131
131
|
|
132
132
|
tool_choice = "none" if self._react_mode else "auto"
|
133
133
|
gen_message_batch = await self.generate_once(
|
134
|
-
|
134
|
+
state, tool_choice=tool_choice, ctx=ctx
|
135
135
|
)
|
136
136
|
|
137
137
|
turns = 0
|
138
138
|
|
139
139
|
while True:
|
140
|
-
self._manage_agent_state(
|
140
|
+
self._manage_agent_state(state=state, ctx=ctx, num_turns=turns)
|
141
141
|
|
142
142
|
if self._exit_tool_call_loop(
|
143
143
|
message_history.batched_conversations[0], ctx=ctx, num_turns=turns
|
@@ -156,7 +156,7 @@ class ToolOrchestrator(Generic[CtxT]):
|
|
156
156
|
|
157
157
|
tool_choice = "none" if (self._react_mode and msg.tool_calls) else "auto"
|
158
158
|
gen_message_batch = await self.generate_once(
|
159
|
-
|
159
|
+
state, tool_choice=tool_choice, ctx=ctx
|
160
160
|
)
|
161
161
|
|
162
162
|
turns += 1
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import base64
|
2
2
|
import re
|
3
|
-
from collections.abc import Iterable
|
3
|
+
from collections.abc import Iterable, Mapping
|
4
4
|
from enum import StrEnum
|
5
5
|
from pathlib import Path
|
6
6
|
from typing import Annotated, Any, Literal, TypeAlias
|
@@ -66,7 +66,7 @@ class Content(BaseModel):
|
|
66
66
|
def from_formatted_prompt(
|
67
67
|
cls,
|
68
68
|
prompt_template: str,
|
69
|
-
prompt_args:
|
69
|
+
prompt_args: Mapping[str, str | int | bool | ImageData] | None = None,
|
70
70
|
) -> "Content":
|
71
71
|
prompt_args = prompt_args or {}
|
72
72
|
image_args = {
|
@@ -1,3 +1,4 @@
|
|
1
|
+
from collections.abc import Mapping
|
1
2
|
from typing import TypeAlias, TypeVar
|
2
3
|
|
3
4
|
from pydantic import BaseModel
|
@@ -20,5 +21,5 @@ OutT = TypeVar("OutT", covariant=True) # noqa: PLC0105
|
|
20
21
|
StateT = TypeVar("StateT", bound=AgentState, covariant=True) # noqa: PLC0105
|
21
22
|
|
22
23
|
LLMPrompt: TypeAlias = str
|
23
|
-
LLMFormattedSystemArgs: TypeAlias =
|
24
|
-
LLMFormattedArgs: TypeAlias =
|
24
|
+
LLMFormattedSystemArgs: TypeAlias = Mapping[str, str | int | bool]
|
25
|
+
LLMFormattedArgs: TypeAlias = Mapping[str, str | int | bool | ImageData]
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import json
|
2
|
-
from collections.abc import Hashable, Sequence
|
2
|
+
from collections.abc import Hashable, Mapping, Sequence
|
3
3
|
from enum import StrEnum
|
4
4
|
from typing import Annotated, Any, Literal, TypeAlias
|
5
5
|
from uuid import uuid4
|
@@ -79,7 +79,7 @@ class UserMessage(MessageBase):
|
|
79
79
|
def from_formatted_prompt(
|
80
80
|
cls,
|
81
81
|
prompt_template: str,
|
82
|
-
prompt_args:
|
82
|
+
prompt_args: Mapping[str, str | int | bool | ImageData] | None = None,
|
83
83
|
model_id: str | None = None,
|
84
84
|
) -> "UserMessage":
|
85
85
|
content = Content.from_formatted_prompt(
|
@@ -1,8 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import asyncio
|
4
3
|
from abc import ABC, abstractmethod
|
5
|
-
from collections.abc import Sequence
|
6
4
|
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar
|
7
5
|
|
8
6
|
from pydantic import BaseModel, PrivateAttr, TypeAdapter
|
@@ -60,11 +58,6 @@ class BaseTool(
|
|
60
58
|
) -> _ToolOutT:
|
61
59
|
pass
|
62
60
|
|
63
|
-
async def run_batch(
|
64
|
-
self, inp_batch: Sequence[_ToolInT], ctx: RunContextWrapper[CtxT] | None = None
|
65
|
-
) -> Sequence[_ToolOutT]:
|
66
|
-
return await asyncio.gather(*[self.run(inp, ctx=ctx) for inp in inp_batch])
|
67
|
-
|
68
61
|
async def __call__(
|
69
62
|
self, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
|
70
63
|
) -> _ToolOutT:
|
@@ -0,0 +1,194 @@
|
|
1
|
+
import ast
|
2
|
+
import asyncio
|
3
|
+
import json
|
4
|
+
import re
|
5
|
+
from collections.abc import Coroutine, Mapping
|
6
|
+
from datetime import datetime
|
7
|
+
from logging import getLogger
|
8
|
+
from pathlib import Path
|
9
|
+
from typing import Any, TypeVar
|
10
|
+
|
11
|
+
from pydantic import (
|
12
|
+
GetCoreSchemaHandler,
|
13
|
+
TypeAdapter,
|
14
|
+
ValidationError,
|
15
|
+
)
|
16
|
+
from pydantic_core import core_schema
|
17
|
+
from tqdm.autonotebook import tqdm
|
18
|
+
|
19
|
+
logger = getLogger(__name__)
|
20
|
+
|
21
|
+
_JSON_START_RE = re.compile(r"[{\[]")
|
22
|
+
|
23
|
+
T = TypeVar("T")
|
24
|
+
|
25
|
+
|
26
|
+
def extract_json_substring(text: str) -> str | None:
|
27
|
+
decoder = json.JSONDecoder()
|
28
|
+
for match in _JSON_START_RE.finditer(text):
|
29
|
+
start = match.start()
|
30
|
+
try:
|
31
|
+
_, end = decoder.raw_decode(text, idx=start)
|
32
|
+
return text[start:end]
|
33
|
+
except ValueError:
|
34
|
+
continue
|
35
|
+
|
36
|
+
return None
|
37
|
+
|
38
|
+
|
39
|
+
def parse_json_or_py_string(
|
40
|
+
s: str, return_none_on_failure: bool = False
|
41
|
+
) -> dict[str, Any] | list[Any] | None:
|
42
|
+
s_fmt = re.sub(r"```[a-zA-Z0-9]*\n|```", "", s).strip()
|
43
|
+
try:
|
44
|
+
return ast.literal_eval(s_fmt)
|
45
|
+
except (ValueError, SyntaxError):
|
46
|
+
try:
|
47
|
+
return json.loads(s_fmt)
|
48
|
+
except json.JSONDecodeError as exc:
|
49
|
+
if return_none_on_failure:
|
50
|
+
return None
|
51
|
+
raise ValueError(
|
52
|
+
"Invalid JSON/Python string - Both ast.literal_eval and json.loads "
|
53
|
+
f"failed to parse the following response:\n{s}"
|
54
|
+
) from exc
|
55
|
+
|
56
|
+
|
57
|
+
def parse_json_or_py_substring(
|
58
|
+
json_str: str, return_none_on_failure: bool = False
|
59
|
+
) -> dict[str, Any] | list[Any] | None:
|
60
|
+
return parse_json_or_py_string(
|
61
|
+
extract_json_substring(json_str) or "", return_none_on_failure
|
62
|
+
)
|
63
|
+
|
64
|
+
|
65
|
+
def validate_obj_from_json_or_py_string(
|
66
|
+
s: str, adapter: TypeAdapter[T], from_substring: bool = False
|
67
|
+
) -> T:
|
68
|
+
try:
|
69
|
+
if from_substring:
|
70
|
+
parsed = parse_json_or_py_substring(s, return_none_on_failure=True)
|
71
|
+
else:
|
72
|
+
parsed = parse_json_or_py_string(s, return_none_on_failure=True)
|
73
|
+
if parsed is None:
|
74
|
+
parsed = s
|
75
|
+
return adapter.validate_python(parsed)
|
76
|
+
except (json.JSONDecodeError, ValidationError) as exc:
|
77
|
+
raise ValueError(
|
78
|
+
f"Invalid JSON or Python string:\n{s}\nExpected type: {adapter._type}", # type: ignore[arg-type]
|
79
|
+
) from exc
|
80
|
+
|
81
|
+
|
82
|
+
def extract_xml_list(text: str) -> list[str]:
|
83
|
+
pattern = re.compile(r"<(chunk_\d+)>(.*?)</\1>", re.DOTALL)
|
84
|
+
|
85
|
+
chunks: list[str] = []
|
86
|
+
for match in pattern.finditer(text):
|
87
|
+
content = match.group(2).strip()
|
88
|
+
chunks.append(content)
|
89
|
+
return chunks
|
90
|
+
|
91
|
+
|
92
|
+
def build_marker_json_parser_type(
|
93
|
+
marker_to_model: Mapping[str, type],
|
94
|
+
) -> type:
|
95
|
+
"""
|
96
|
+
Return a Pydantic-compatible *type* that, when given a **str**, searches for
|
97
|
+
the first marker substring and validates the JSON that follows with the
|
98
|
+
corresponding Pydantic model.
|
99
|
+
|
100
|
+
If no marker is found, the raw string is returned unchanged.
|
101
|
+
|
102
|
+
Example:
|
103
|
+
-------
|
104
|
+
>>> Todo = build_marker_json_parser_type({'```json': MyModel})
|
105
|
+
>>> Todo.validate('```json {"a": 1}')
|
106
|
+
MyModel(a=1)
|
107
|
+
|
108
|
+
"""
|
109
|
+
|
110
|
+
class MarkerParsedOutput:
|
111
|
+
"""String → (Model | str) parser generated by build_marker_json_parser_type."""
|
112
|
+
|
113
|
+
@classmethod
|
114
|
+
def __get_pydantic_core_schema__(
|
115
|
+
cls,
|
116
|
+
_source_type: Any,
|
117
|
+
_handler: GetCoreSchemaHandler,
|
118
|
+
) -> core_schema.CoreSchema:
|
119
|
+
def _validate(value: Any) -> Any:
|
120
|
+
if not isinstance(value, str):
|
121
|
+
raise TypeError("MarkerParsedOutput expects a string")
|
122
|
+
|
123
|
+
for marker, model in marker_to_model.items():
|
124
|
+
if marker in value:
|
125
|
+
adapter = TypeAdapter[Any](model)
|
126
|
+
return validate_obj_from_json_or_py_string(
|
127
|
+
value, adapter=adapter, from_substring=True
|
128
|
+
)
|
129
|
+
|
130
|
+
return value
|
131
|
+
|
132
|
+
return core_schema.no_info_after_validator_function(
|
133
|
+
_validate, core_schema.any_schema()
|
134
|
+
)
|
135
|
+
|
136
|
+
@classmethod
|
137
|
+
def __get_pydantic_json_schema__(
|
138
|
+
cls,
|
139
|
+
schema: core_schema.CoreSchema,
|
140
|
+
handler: GetCoreSchemaHandler,
|
141
|
+
):
|
142
|
+
return handler(schema)
|
143
|
+
|
144
|
+
unique_suffix = "_".join(sorted(marker_to_model))[:40]
|
145
|
+
MarkerParsedOutput.__name__ = f"MarkerParsedOutput_{unique_suffix}"
|
146
|
+
|
147
|
+
return MarkerParsedOutput
|
148
|
+
|
149
|
+
|
150
|
+
def read_txt(file_path: str | Path, encoding: str = "utf-8") -> str:
|
151
|
+
return Path(file_path).read_text(encoding=encoding)
|
152
|
+
|
153
|
+
|
154
|
+
def read_contents_from_file(
|
155
|
+
file_path: str | Path,
|
156
|
+
binary_mode: bool = False,
|
157
|
+
) -> str | bytes:
|
158
|
+
try:
|
159
|
+
if binary_mode:
|
160
|
+
return Path(file_path).read_bytes()
|
161
|
+
return Path(file_path).read_text()
|
162
|
+
except FileNotFoundError:
|
163
|
+
logger.error(f"File {file_path} not found.")
|
164
|
+
return ""
|
165
|
+
|
166
|
+
|
167
|
+
def get_prompt(prompt_text: str | None, prompt_path: str | Path | None) -> str | None:
|
168
|
+
if prompt_text is None:
|
169
|
+
return read_contents_from_file(prompt_path) if prompt_path is not None else None # type: ignore[arg-type]
|
170
|
+
|
171
|
+
return prompt_text
|
172
|
+
|
173
|
+
|
174
|
+
async def asyncio_gather_with_pbar(
|
175
|
+
*corouts: Coroutine[Any, Any, Any],
|
176
|
+
no_tqdm: bool = False,
|
177
|
+
desc: str | None = None,
|
178
|
+
) -> list[Any]:
|
179
|
+
pbar = tqdm(total=len(corouts), desc=desc, disable=no_tqdm)
|
180
|
+
|
181
|
+
async def run_and_update(coro: Coroutine[Any, Any, Any]) -> Any:
|
182
|
+
result = await coro
|
183
|
+
pbar.update(1)
|
184
|
+
return result
|
185
|
+
|
186
|
+
wrapped_tasks = [run_and_update(c) for c in corouts]
|
187
|
+
results = await asyncio.gather(*wrapped_tasks)
|
188
|
+
pbar.close()
|
189
|
+
|
190
|
+
return results
|
191
|
+
|
192
|
+
|
193
|
+
def get_timestamp() -> str:
|
194
|
+
return datetime.now().strftime("%Y%m%d_%H%M%S")
|
@@ -1,187 +0,0 @@
|
|
1
|
-
import ast
|
2
|
-
import asyncio
|
3
|
-
import json
|
4
|
-
import re
|
5
|
-
from collections.abc import Coroutine
|
6
|
-
from datetime import datetime
|
7
|
-
from logging import getLogger
|
8
|
-
from pathlib import Path
|
9
|
-
from typing import Any, TypeVar
|
10
|
-
|
11
|
-
from pydantic import (
|
12
|
-
BaseModel,
|
13
|
-
GetCoreSchemaHandler,
|
14
|
-
TypeAdapter,
|
15
|
-
ValidationError,
|
16
|
-
)
|
17
|
-
from pydantic_core import core_schema
|
18
|
-
from tqdm.autonotebook import tqdm
|
19
|
-
|
20
|
-
logger = getLogger(__name__)
|
21
|
-
|
22
|
-
T = TypeVar("T")
|
23
|
-
|
24
|
-
|
25
|
-
def filter_fields(data: dict[str, Any], model: type[BaseModel]) -> dict[str, Any]:
|
26
|
-
return {key: data[key] for key in model.model_fields if key in data}
|
27
|
-
|
28
|
-
|
29
|
-
def read_txt(file_path: str) -> str:
|
30
|
-
return Path(file_path).read_text()
|
31
|
-
|
32
|
-
|
33
|
-
def format_json_string(text: str) -> str:
|
34
|
-
decoder = json.JSONDecoder()
|
35
|
-
text = text.replace("\n", "")
|
36
|
-
length = len(text)
|
37
|
-
i = 0
|
38
|
-
while i < length:
|
39
|
-
ch = text[i]
|
40
|
-
if ch in "{[":
|
41
|
-
try:
|
42
|
-
_, end = decoder.raw_decode(text[i:])
|
43
|
-
return text[i : i + end]
|
44
|
-
except ValueError:
|
45
|
-
pass
|
46
|
-
i += 1
|
47
|
-
|
48
|
-
return text
|
49
|
-
|
50
|
-
|
51
|
-
def parse_json_or_py_string(
|
52
|
-
json_str: str, return_none_on_failure: bool = False
|
53
|
-
) -> dict[str, Any] | list[Any] | None:
|
54
|
-
try:
|
55
|
-
json_response = ast.literal_eval(json_str)
|
56
|
-
except (ValueError, SyntaxError):
|
57
|
-
try:
|
58
|
-
json_response = json.loads(json_str)
|
59
|
-
except json.JSONDecodeError as exc:
|
60
|
-
if return_none_on_failure:
|
61
|
-
return None
|
62
|
-
raise ValueError(
|
63
|
-
"Invalid JSON - Both ast.literal_eval and json.loads "
|
64
|
-
f"failed to parse the following response:\n{json_str}"
|
65
|
-
) from exc
|
66
|
-
|
67
|
-
return json_response
|
68
|
-
|
69
|
-
|
70
|
-
def extract_json(
|
71
|
-
json_str: str, return_none_on_failure: bool = False
|
72
|
-
) -> dict[str, Any] | list[Any] | None:
|
73
|
-
return parse_json_or_py_string(format_json_string(json_str), return_none_on_failure)
|
74
|
-
|
75
|
-
|
76
|
-
def validate_obj_from_json_or_py_string(s: str, adapter: TypeAdapter[T]) -> T:
|
77
|
-
s_fmt = re.sub(r"```[a-zA-Z0-9]*\n|```", "", s).strip()
|
78
|
-
try:
|
79
|
-
parsed = json.loads(s_fmt)
|
80
|
-
return adapter.validate_python(parsed)
|
81
|
-
except (json.JSONDecodeError, ValidationError):
|
82
|
-
try:
|
83
|
-
return adapter.validate_python(s_fmt)
|
84
|
-
except ValidationError as exc:
|
85
|
-
raise ValueError(
|
86
|
-
f"Invalid JSON or Python string:\n{s}\nExpected type: {adapter._type}", # type: ignore[arg-type]
|
87
|
-
) from exc
|
88
|
-
|
89
|
-
|
90
|
-
def extract_xml_list(text: str) -> list[str]:
|
91
|
-
pattern = re.compile(r"<(chunk_\d+)>(.*?)</\1>", re.DOTALL)
|
92
|
-
|
93
|
-
chunks: list[str] = []
|
94
|
-
for match in pattern.finditer(text):
|
95
|
-
content = match.group(2).strip()
|
96
|
-
chunks.append(content)
|
97
|
-
return chunks
|
98
|
-
|
99
|
-
|
100
|
-
def make_conditional_parsed_output_type(
|
101
|
-
response_format: type, marker: str = "<DONE>"
|
102
|
-
) -> type:
|
103
|
-
class ParsedOutput:
|
104
|
-
"""
|
105
|
-
* Accepts any **str**.
|
106
|
-
* If the string contains `marker`, it must contain a valid JSON for
|
107
|
-
`response_format` → we return that a response_format instance.
|
108
|
-
* Otherwise we leave the string untouched.
|
109
|
-
"""
|
110
|
-
|
111
|
-
@classmethod
|
112
|
-
def __get_pydantic_core_schema__(
|
113
|
-
cls,
|
114
|
-
_source_type: Any,
|
115
|
-
_handler: GetCoreSchemaHandler,
|
116
|
-
) -> core_schema.CoreSchema:
|
117
|
-
def validator(v: Any) -> Any:
|
118
|
-
if isinstance(v, str) and marker in v:
|
119
|
-
v_json_str = format_json_string(v)
|
120
|
-
response_format_adapter = TypeAdapter[Any](response_format)
|
121
|
-
|
122
|
-
return response_format_adapter.validate_json(v_json_str)
|
123
|
-
|
124
|
-
return v
|
125
|
-
|
126
|
-
return core_schema.no_info_after_validator_function(
|
127
|
-
validator, core_schema.any_schema()
|
128
|
-
)
|
129
|
-
|
130
|
-
@classmethod
|
131
|
-
def __get_pydantic_json_schema__(
|
132
|
-
cls, core_schema: core_schema.CoreSchema, handler: GetCoreSchemaHandler
|
133
|
-
):
|
134
|
-
return handler(core_schema)
|
135
|
-
|
136
|
-
return ParsedOutput
|
137
|
-
|
138
|
-
|
139
|
-
def read_contents_from_file(
|
140
|
-
file_path: str | Path,
|
141
|
-
binary_mode: bool = False,
|
142
|
-
) -> str | bytes:
|
143
|
-
"""Reads and returns contents of file"""
|
144
|
-
try:
|
145
|
-
if binary_mode:
|
146
|
-
with open(file_path, "rb") as file:
|
147
|
-
return file.read()
|
148
|
-
else:
|
149
|
-
with open(file_path) as file:
|
150
|
-
return file.read()
|
151
|
-
except FileNotFoundError:
|
152
|
-
logger.error(f"File {file_path} not found.")
|
153
|
-
return ""
|
154
|
-
|
155
|
-
|
156
|
-
def get_prompt(prompt_text: str | None, prompt_path: str | Path | None) -> str | None:
|
157
|
-
if prompt_text is None:
|
158
|
-
prompt = (
|
159
|
-
read_contents_from_file(prompt_path) if prompt_path is not None else None
|
160
|
-
)
|
161
|
-
else:
|
162
|
-
prompt = prompt_text
|
163
|
-
|
164
|
-
return prompt # type: ignore[assignment]
|
165
|
-
|
166
|
-
|
167
|
-
async def asyncio_gather_with_pbar(
|
168
|
-
*corouts: Coroutine[Any, Any, Any],
|
169
|
-
no_tqdm: bool = False,
|
170
|
-
desc: str | None = None,
|
171
|
-
) -> list[Any]:
|
172
|
-
pbar = tqdm(total=len(corouts), desc=desc, disable=no_tqdm)
|
173
|
-
|
174
|
-
async def run_and_update(coro: Coroutine[Any, Any, Any]) -> Any:
|
175
|
-
result = await coro
|
176
|
-
pbar.update(1)
|
177
|
-
return result
|
178
|
-
|
179
|
-
wrapped_tasks = [run_and_update(c) for c in corouts]
|
180
|
-
results = await asyncio.gather(*wrapped_tasks)
|
181
|
-
pbar.close()
|
182
|
-
|
183
|
-
return results
|
184
|
-
|
185
|
-
|
186
|
-
def get_timestamp() -> str:
|
187
|
-
return datetime.now().strftime("%Y%m%d_%H%M%S")
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|