grasp_agents 0.5.9__py3-none-any.whl → 0.5.10__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- grasp_agents/cloud_llm.py +88 -109
- grasp_agents/litellm/converters.py +4 -2
- grasp_agents/litellm/lite_llm.py +72 -83
- grasp_agents/llm.py +35 -68
- grasp_agents/llm_agent.py +32 -36
- grasp_agents/llm_agent_memory.py +3 -2
- grasp_agents/llm_policy_executor.py +63 -33
- grasp_agents/openai/converters.py +4 -2
- grasp_agents/openai/openai_llm.py +60 -87
- grasp_agents/openai/tool_converters.py +6 -4
- grasp_agents/processors/base_processor.py +18 -10
- grasp_agents/processors/parallel_processor.py +8 -6
- grasp_agents/processors/processor.py +10 -6
- grasp_agents/prompt_builder.py +22 -28
- grasp_agents/run_context.py +1 -1
- grasp_agents/runner.py +1 -1
- grasp_agents/typing/converters.py +3 -1
- grasp_agents/typing/tool.py +13 -5
- grasp_agents/workflow/workflow_processor.py +4 -4
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.10.dist-info}/METADATA +1 -1
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.10.dist-info}/RECORD +23 -23
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.10.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.10.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm.py
CHANGED
@@ -1,7 +1,8 @@
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Mapping
-from
+from collections.abc import AsyncIterator, Mapping
+from dataclasses import dataclass, field
+from typing import Any, Generic, TypeVar
 from uuid import uuid4
 
 from pydantic import BaseModel
@@ -66,71 +67,24 @@ SettingsT_co = TypeVar("SettingsT_co", bound=LLMSettings, covariant=True)
 ConvertT_co = TypeVar("ConvertT_co", bound=Converters, covariant=True)
 
 
+@dataclass(frozen=True)
 class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
-
-
+    model_name: str
+    converters: ConvertT_co
+    llm_settings: SettingsT_co | None = None
+    model_id: str = field(default_factory=lambda: str(uuid4())[:8])
+
+    def _validate_response(
         self,
-
-
-
-        llm_settings: SettingsT_co | None = None,
-        tools: Sequence[BaseTool[BaseModel, Any, Any]] | None = None,
-        response_schema: Any | None = None,
-        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
-        **kwargs: Any,
+        completion: Completion,
+        response_schema: Any | None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None,
     ) -> None:
-        super().__init__()
-
-        self._converters = converters
-        self._model_id = model_id or str(uuid4())[:8]
-        self._model_name = model_name
-        self._tools = {t.name: t for t in tools} if tools else None
-        self._llm_settings: SettingsT_co = llm_settings or cast("SettingsT_co", {})
-
         if response_schema and response_schema_by_xml_tag:
             raise ValueError(
                 "Only one of response_schema and response_schema_by_xml_tag can be "
                 "provided, but not both."
             )
-        self._response_schema = response_schema
-        self._response_schema_by_xml_tag = response_schema_by_xml_tag
-
-    @property
-    def model_id(self) -> str:
-        return self._model_id
-
-    @property
-    def model_name(self) -> str | None:
-        return self._model_name
-
-    @property
-    def llm_settings(self) -> SettingsT_co:
-        return self._llm_settings
-
-    @property
-    def response_schema(self) -> Any | None:
-        return self._response_schema
-
-    @response_schema.setter
-    def response_schema(self, response_schema: Any | None) -> None:
-        self._response_schema = response_schema
-
-    @property
-    def response_schema_by_xml_tag(self) -> Mapping[str, Any] | None:
-        return self._response_schema_by_xml_tag
-
-    @property
-    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
-        return self._tools
-
-    @tools.setter
-    def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
-        self._tools = {t.name: t for t in tools} if tools else None
-
-    def __repr__(self) -> str:
-        return f"{type(self).__name__}[{self.model_id}]; model_name={self._model_name})"
-
-    def _validate_response(self, completion: Completion) -> None:
         parsing_params = {
             "from_substring": False,
             "strip_language_markdown": True,
@@ -138,17 +92,17 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
        try:
             for message in completion.messages:
                 if not message.tool_calls:
-                    if
+                    if response_schema:
                         validate_obj_from_json_or_py_string(
                             message.content or "",
-                            schema=
+                            schema=response_schema,
                             **parsing_params,
                         )
 
-                    elif
+                    elif response_schema_by_xml_tag:
                         validate_tagged_objs_from_json_or_py_string(
                             message.content or "",
-                            schema_by_xml_tag=
+                            schema_by_xml_tag=response_schema_by_xml_tag,
                             **parsing_params,
                         )
         except JSONSchemaValidationError as exc:
@@ -156,7 +110,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
                 exc.s, exc.schema, message=str(exc)
             ) from exc
 
-    def _validate_tool_calls(
+    def _validate_tool_calls(
+        self, completion: Completion, tools: Mapping[str, BaseTool[BaseModel, Any, Any]]
+    ) -> None:
         parsing_params = {
             "from_substring": False,
             "strip_language_markdown": True,
@@ -167,15 +123,15 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
            tool_name = tool_call.tool_name
            tool_arguments = tool_call.tool_arguments
 
-            available_tool_names = list(
-            if tool_name not in available_tool_names or not
+            available_tool_names = list(tools) if tools else []
+            if tool_name not in available_tool_names or not tools:
                raise LLMToolCallValidationError(
                    tool_name,
                    tool_arguments,
                    message=f"Tool '{tool_name}' is not available in the LLM "
                    f"tools (available: {available_tool_names})",
                )
-            tool =
+            tool = tools[tool_name]
            try:
                validate_obj_from_json_or_py_string(
                    tool_arguments, schema=tool.in_type, **parsing_params
@@ -296,6 +252,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         self,
         conversation: Messages,
         *,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -308,6 +267,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         self,
         conversation: Messages,
         *,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -318,5 +280,10 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         pass
 
     @abstractmethod
-    def combine_completion_chunks(
+    def combine_completion_chunks(
+        self,
+        completion_chunks: list[Any],
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
+    ) -> Any:
         raise NotImplementedError
grasp_agents/llm_agent.py
CHANGED
@@ -1,4 +1,4 @@
-from collections.abc import AsyncIterator, Sequence
+from collections.abc import AsyncIterator, Mapping, Sequence
 from pathlib import Path
 from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
 
@@ -41,7 +41,7 @@ class OutputParser(Protocol[_InT_contra, _OutT_co, CtxT]):
         conversation: Messages,
         *,
         in_args: _InT_contra | None,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> _OutT_co: ...
 
 
@@ -68,16 +68,19 @@ class LLMAgent(
         # System prompt template
         sys_prompt: LLMPrompt | None = None,
         sys_prompt_path: str | Path | None = None,
+        # LLM response validation
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         # Agent loop settings
         max_turns: int = 100,
         react_mode: bool = False,
         final_answer_as_tool_call: bool = False,
         # Agent memory management
         reset_memory_on_run: bool = False,
-        #
+        # Agent run retries
         max_retries: int = 0,
         # Multi-agent routing
-        recipients:
+        recipients: Sequence[ProcName] | None = None,
     ) -> None:
         super().__init__(name=name, recipients=recipients, max_retries=max_retries)
 
@@ -96,15 +99,6 @@ class LLMAgent(
 
         # LLM policy executor
 
-        self._used_default_llm_response_schema: bool = False
-        if (
-            llm.response_schema is None
-            and tools is None
-            and not hasattr(type(self), "output_parser")
-        ):
-            llm.response_schema = self.out_type
-            self._used_default_llm_response_schema = True
-
         if issubclass(self._out_type, BaseModel):
             final_answer_type = self._out_type
         elif not final_answer_as_tool_call:
@@ -115,10 +109,21 @@ class LLMAgent(
                 "final_answer_as_tool_call is True."
             )
 
+        self._used_default_llm_response_schema: bool = False
+        if (
+            response_schema is None
+            and tools is None
+            and not hasattr(type(self), "output_parser")
+        ):
+            response_schema = self.out_type
+            self._used_default_llm_response_schema = True
+
         self._policy_executor: LLMPolicyExecutor[CtxT] = LLMPolicyExecutor[CtxT](
             agent_name=self.name,
             llm=llm,
             tools=tools,
+            response_schema=response_schema,
+            response_schema_by_xml_tag=response_schema_by_xml_tag,
             max_turns=max_turns,
             react_mode=react_mode,
             final_answer_type=final_answer_type,
@@ -160,9 +165,10 @@ class LLMAgent(
     def _prepare_memory(
         self,
         memory: LLMAgentMemory,
+        *,
         in_args: InT | None = None,
         sys_prompt: LLMPrompt | None = None,
-        ctx: RunContext[Any]
+        ctx: RunContext[Any],
     ) -> None:
         if self.memory_preparator:
             return self.memory_preparator(
@@ -172,9 +178,10 @@ class LLMAgent(
     def _memorize_inputs(
         self,
         memory: LLMAgentMemory,
+        *,
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         in_args: InT | None = None,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> tuple[SystemMessage | None, UserMessage | None]:
         formatted_sys_prompt = self._prompt_builder.build_system_prompt(ctx=ctx)
 
@@ -201,7 +208,7 @@ class LLMAgent(
         conversation: Messages,
         *,
         in_args: InT | None = None,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> OutT:
         return validate_obj_from_json_or_py_string(
             str(conversation[-1].content or ""),
@@ -215,7 +222,7 @@ class LLMAgent(
         conversation: Messages,
         *,
         in_args: InT | None = None,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> OutT:
         if self.output_parser:
             return self.output_parser(
@@ -233,7 +240,7 @@ class LLMAgent(
         in_args: InT | None = None,
         memory: LLMAgentMemory,
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> OutT:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
@@ -259,7 +266,7 @@ class LLMAgent(
         in_args: InT | None = None,
         memory: LLMAgentMemory,
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
@@ -292,7 +299,7 @@ class LLMAgent(
         self,
         messages: Sequence[Message],
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> None:
         if ctx and ctx.printer:
             ctx.printer.print_messages(messages, agent_name=self.name, call_id=call_id)
@@ -321,24 +328,18 @@ class LLMAgent(
         if cur_cls.memory_manager is not base_cls.memory_manager:
             self._policy_executor.memory_manager = self.memory_manager
 
-    def system_prompt_builder(self, ctx: RunContext[CtxT]
+    def system_prompt_builder(self, ctx: RunContext[CtxT]) -> str | None:
         if self._prompt_builder.system_prompt_builder is not None:
             return self._prompt_builder.system_prompt_builder(ctx=ctx)
         raise NotImplementedError("System prompt builder is not implemented.")
 
-    def input_content_builder(
-        self, in_args: InT | None = None, *, ctx: RunContext[CtxT] | None = None
-    ) -> Content:
+    def input_content_builder(self, in_args: InT, ctx: RunContext[CtxT]) -> Content:
         if self._prompt_builder.input_content_builder is not None:
             return self._prompt_builder.input_content_builder(in_args=in_args, ctx=ctx)
         raise NotImplementedError("Input content builder is not implemented.")
 
     def tool_call_loop_terminator(
-        self,
-        conversation: Messages,
-        *,
-        ctx: RunContext[CtxT] | None = None,
-        **kwargs: Any,
+        self, conversation: Messages, *, ctx: RunContext[CtxT], **kwargs: Any
     ) -> bool:
         if self._policy_executor.tool_call_loop_terminator is not None:
             return self._policy_executor.tool_call_loop_terminator(
@@ -347,11 +348,7 @@ class LLMAgent(
         raise NotImplementedError("Tool call loop terminator is not implemented.")
 
     def memory_manager(
-        self,
-        memory: LLMAgentMemory,
-        *,
-        ctx: RunContext[CtxT] | None = None,
-        **kwargs: Any,
+        self, memory: LLMAgentMemory, *, ctx: RunContext[CtxT], **kwargs: Any
     ) -> None:
         if self._policy_executor.memory_manager is not None:
             return self._policy_executor.memory_manager(
@@ -391,12 +388,11 @@ class LLMAgent(
         self, func: OutputParser[InT, OutT, CtxT]
     ) -> OutputParser[InT, OutT, CtxT]:
         if self._used_default_llm_response_schema:
-            self._policy_executor.
+            self._policy_executor.response_schema = None
         self.output_parser = func
 
         return func
 
     def add_memory_preparator(self, func: MemoryPreparator) -> MemoryPreparator:
         self.memory_preparator = func
-
         return func
grasp_agents/llm_agent_memory.py
CHANGED
@@ -13,14 +13,15 @@ class MemoryPreparator(Protocol):
     def __call__(
         self,
         memory: "LLMAgentMemory",
+        *,
         in_args: Any | None,
         sys_prompt: LLMPrompt | None,
-        ctx: RunContext[Any]
+        ctx: RunContext[Any],
     ) -> None: ...
 
 
 class LLMAgentMemory(Memory):
-    _message_history: Messages = PrivateAttr(default_factory=
+    _message_history: Messages = PrivateAttr(default_factory=Messages)
     _sys_prompt: LLMPrompt | None = PrivateAttr(default=None)
 
     def __init__(self, sys_prompt: LLMPrompt | None = None) -> None:
grasp_agents/llm_policy_executor.py
CHANGED
@@ -1,6 +1,6 @@
 import asyncio
 import json
-from collections.abc import AsyncIterator, Coroutine, Sequence
+from collections.abc import AsyncIterator, Coroutine, Mapping, Sequence
 from itertools import starmap
 from logging import getLogger
 from typing import Any, Generic, Protocol, final
@@ -36,7 +36,7 @@ class ToolCallLoopTerminator(Protocol[CtxT]):
         self,
         conversation: Messages,
         *,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
         **kwargs: Any,
     ) -> bool: ...
 
@@ -46,7 +46,7 @@ class MemoryManager(Protocol[CtxT]):
         self,
         memory: LLMAgentMemory,
         *,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
         **kwargs: Any,
     ) -> None: ...
 
@@ -54,9 +54,12 @@ class MemoryManager(Protocol[CtxT]):
 class LLMPolicyExecutor(Generic[CtxT]):
     def __init__(
         self,
+        *,
         agent_name: str,
         llm: LLM[LLMSettings, Converters],
         tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         max_turns: int,
         react_mode: bool = False,
         final_answer_type: type[BaseModel] = BaseModel,
@@ -70,12 +73,15 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self._final_answer_as_tool_call = final_answer_as_tool_call
         self._final_answer_tool = self.get_final_answer_tool()
 
-
+        tools_list: list[BaseTool[BaseModel, Any, CtxT]] | None = tools
         if tools and final_answer_as_tool_call:
-
+            tools_list = tools + [self._final_answer_tool]
+        self._tools = {t.name: t for t in tools_list} if tools_list else None
+
+        self._response_schema = response_schema
+        self._response_schema_by_xml_tag = response_schema_by_xml_tag
 
         self._llm = llm
-        self._llm.tools = _tools
 
         self._max_turns = max_turns
         self._react_mode = react_mode
@@ -91,9 +97,21 @@ class LLMPolicyExecutor(Generic[CtxT]):
     def llm(self) -> LLM[LLMSettings, Converters]:
         return self._llm
 
+    @property
+    def response_schema(self) -> Any | None:
+        return self._response_schema
+
+    @response_schema.setter
+    def response_schema(self, value: Any | None) -> None:
+        self._response_schema = value
+
+    @property
+    def response_schema_by_xml_tag(self) -> Mapping[str, Any] | None:
+        return self._response_schema_by_xml_tag
+
     @property
     def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
-        return self.
+        return self._tools or {}
 
     @property
     def max_turns(self) -> int:
@@ -104,7 +122,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         conversation: Messages,
         *,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
         **kwargs: Any,
     ) -> bool:
         if self.tool_call_loop_terminator:
@@ -117,7 +135,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         memory: LLMAgentMemory,
         *,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
         **kwargs: Any,
     ) -> None:
         if self.memory_manager:
@@ -126,12 +144,16 @@ class LLMPolicyExecutor(Generic[CtxT]):
     async def generate_message(
         self,
         memory: LLMAgentMemory,
+        *,
         call_id: str,
         tool_choice: ToolChoice | None = None,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> AssistantMessage:
         completion = await self.llm.generate_completion(
             memory.message_history,
+            response_schema=self.response_schema,
+            response_schema_by_xml_tag=self.response_schema_by_xml_tag,
+            tools=self.tools,
             tool_choice=tool_choice,
             n_choices=1,
             proc_name=self.agent_name,
@@ -147,9 +169,10 @@ class LLMPolicyExecutor(Generic[CtxT]):
     async def generate_message_stream(
         self,
         memory: LLMAgentMemory,
+        *,
         call_id: str,
         tool_choice: ToolChoice | None = None,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[
         CompletionChunkEvent[CompletionChunk]
         | CompletionEvent
@@ -160,6 +183,9 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
         llm_event_stream = self.llm.generate_completion_stream(
             memory.message_history,
+            response_schema=self.response_schema,
+            response_schema_by_xml_tag=self.response_schema_by_xml_tag,
+            tools=self.tools,
             tool_choice=tool_choice,
             n_choices=1,
             proc_name=self.agent_name,
@@ -189,14 +215,14 @@ class LLMPolicyExecutor(Generic[CtxT]):
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> Sequence[ToolMessage]:
         # TODO: Add image support
         corouts: list[Coroutine[Any, Any, BaseModel]] = []
         for call in calls:
             tool = self.tools[call.tool_name]
             args = json.loads(call.tool_arguments)
-            corouts.append(tool(ctx=ctx, **args))
+            corouts.append(tool(call_id=call_id, ctx=ctx, **args))
 
         outs = await asyncio.gather(*corouts)
         tool_messages = list(
@@ -217,7 +243,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[ToolMessageEvent]:
         tool_messages = await self.call_tools(
             calls, memory=memory, call_id=call_id, ctx=ctx
@@ -245,7 +271,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message
 
     async def _generate_final_answer(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
     ) -> AssistantMessage:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!"
@@ -268,7 +294,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message
 
     async def _generate_final_answer_stream(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
     ) -> AsyncIterator[Event[Any]]:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!",
@@ -296,7 +322,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )
 
     async def execute(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
     ) -> AssistantMessage | Sequence[AssistantMessage]:
         # 1. Generate the first message:
         # In ReAct mode, we generate the first message without tool calls
@@ -379,7 +405,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         memory: LLMAgentMemory,
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         tool_choice: ToolChoice = "none" if self._react_mode else "auto"
         gen_message: AssistantMessage | None = None
@@ -464,7 +490,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )
 
     async def run(
-        self,
+        self,
+        inp: BaseModel,
+        *,
+        call_id: str | None = None,
+        ctx: RunContext[Any] | None = None,
    ) -> None:
         return None
 
@@ -473,22 +503,22 @@ class LLMPolicyExecutor(Generic[CtxT]):
     def _process_completion(
         self,
         completion: Completion,
+        *,
         call_id: str,
         print_messages: bool = False,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> None:
-
-
-
+        ctx.completions[self.agent_name].append(completion)
+        ctx.usage_tracker.update(
+            agent_name=self.agent_name,
+            completions=[completion],
+            model_name=self.llm.model_name,
+        )
+        if ctx.printer and print_messages:
+            usages = [None] * (len(completion.messages) - 1) + [completion.usage]
+            ctx.printer.print_messages(
+                completion.messages,
+                usages=usages,
                 agent_name=self.agent_name,
-
-                model_name=self.llm.model_name,
+                call_id=call_id,
             )
-        if ctx.printer and print_messages:
-            usages = [None] * (len(completion.messages) - 1) + [completion.usage]
-            ctx.printer.print_messages(
-                completion.messages,
-                usages=usages,
-                agent_name=self.agent_name,
-                call_id=call_id,
-            )
grasp_agents/openai/converters.py
CHANGED
@@ -96,8 +96,10 @@ class OpenAIConverters(Converters):
         return from_api_tool_message(raw_message, name=name, **kwargs)
 
     @staticmethod
-    def to_tool(
-
+    def to_tool(
+        tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None, **kwargs: Any
+    ) -> OpenAIToolParam:
+        return to_api_tool(tool, strict=strict, **kwargs)
 
     @staticmethod
     def to_tool_choice(