grasp_agents-0.5.9-py3-none-any.whl → grasp_agents-0.5.11-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- grasp_agents/cloud_llm.py +87 -109
- grasp_agents/litellm/converters.py +4 -2
- grasp_agents/litellm/lite_llm.py +72 -83
- grasp_agents/llm.py +35 -68
- grasp_agents/llm_agent.py +76 -52
- grasp_agents/llm_agent_memory.py +4 -2
- grasp_agents/llm_policy_executor.py +91 -55
- grasp_agents/openai/converters.py +4 -2
- grasp_agents/openai/openai_llm.py +61 -88
- grasp_agents/openai/tool_converters.py +6 -4
- grasp_agents/processors/base_processor.py +18 -10
- grasp_agents/processors/parallel_processor.py +8 -6
- grasp_agents/processors/processor.py +10 -6
- grasp_agents/prompt_builder.py +38 -28
- grasp_agents/run_context.py +1 -1
- grasp_agents/runner.py +1 -1
- grasp_agents/typing/converters.py +3 -1
- grasp_agents/typing/tool.py +15 -5
- grasp_agents/workflow/workflow_processor.py +4 -4
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.11.dist-info}/METADATA +4 -5
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.11.dist-info}/RECORD +23 -23
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.11.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.11.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm.py
CHANGED
@@ -1,7 +1,8 @@
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Mapping, Sequence
-from typing import Any, Generic, TypeVar, cast
+from collections.abc import AsyncIterator, Mapping
+from dataclasses import dataclass, field
+from typing import Any, Generic, TypeVar
 from uuid import uuid4
 
 from pydantic import BaseModel
@@ -66,71 +67,24 @@ SettingsT_co = TypeVar("SettingsT_co", bound=LLMSettings, covariant=True)
 ConvertT_co = TypeVar("ConvertT_co", bound=Converters, covariant=True)
 
 
+@dataclass(frozen=True)
 class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
-
-    def __init__(
+    model_name: str
+    converters: ConvertT_co
+    llm_settings: SettingsT_co | None = None
+    model_id: str = field(default_factory=lambda: str(uuid4())[:8])
+
+    def _validate_response(
         self,
-        converters: ConvertT_co,
-        model_name: str | None = None,
-        model_id: str | None = None,
-        llm_settings: SettingsT_co | None = None,
-        tools: Sequence[BaseTool[BaseModel, Any, Any]] | None = None,
-        response_schema: Any | None = None,
-        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
-        **kwargs: Any,
+        completion: Completion,
+        response_schema: Any | None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None,
     ) -> None:
-        super().__init__()
-
-        self._converters = converters
-        self._model_id = model_id or str(uuid4())[:8]
-        self._model_name = model_name
-        self._tools = {t.name: t for t in tools} if tools else None
-        self._llm_settings: SettingsT_co = llm_settings or cast("SettingsT_co", {})
-
         if response_schema and response_schema_by_xml_tag:
             raise ValueError(
                 "Only one of response_schema and response_schema_by_xml_tag can be "
                 "provided, but not both."
             )
-        self._response_schema = response_schema
-        self._response_schema_by_xml_tag = response_schema_by_xml_tag
-
-    @property
-    def model_id(self) -> str:
-        return self._model_id
-
-    @property
-    def model_name(self) -> str | None:
-        return self._model_name
-
-    @property
-    def llm_settings(self) -> SettingsT_co:
-        return self._llm_settings
-
-    @property
-    def response_schema(self) -> Any | None:
-        return self._response_schema
-
-    @response_schema.setter
-    def response_schema(self, response_schema: Any | None) -> None:
-        self._response_schema = response_schema
-
-    @property
-    def response_schema_by_xml_tag(self) -> Mapping[str, Any] | None:
-        return self._response_schema_by_xml_tag
-
-    @property
-    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
-        return self._tools
-
-    @tools.setter
-    def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
-        self._tools = {t.name: t for t in tools} if tools else None
-
-    def __repr__(self) -> str:
-        return f"{type(self).__name__}[{self.model_id}]; model_name={self._model_name})"
-
-    def _validate_response(self, completion: Completion) -> None:
         parsing_params = {
             "from_substring": False,
             "strip_language_markdown": True,
@@ -138,17 +92,17 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         try:
             for message in completion.messages:
                 if not message.tool_calls:
-                    if self._response_schema:
+                    if response_schema:
                         validate_obj_from_json_or_py_string(
                             message.content or "",
-                            schema=self._response_schema,
+                            schema=response_schema,
                             **parsing_params,
                         )
 
-                    elif self._response_schema_by_xml_tag:
+                    elif response_schema_by_xml_tag:
                         validate_tagged_objs_from_json_or_py_string(
                             message.content or "",
-                            schema_by_xml_tag=self._response_schema_by_xml_tag,
+                            schema_by_xml_tag=response_schema_by_xml_tag,
                             **parsing_params,
                         )
         except JSONSchemaValidationError as exc:
@@ -156,7 +110,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
                 exc.s, exc.schema, message=str(exc)
             ) from exc
 
-    def _validate_tool_calls(self, completion: Completion) -> None:
+    def _validate_tool_calls(
+        self, completion: Completion, tools: Mapping[str, BaseTool[BaseModel, Any, Any]]
+    ) -> None:
         parsing_params = {
             "from_substring": False,
             "strip_language_markdown": True,
@@ -167,15 +123,15 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
             tool_name = tool_call.tool_name
             tool_arguments = tool_call.tool_arguments
 
-            available_tool_names = list(self._tools) if self._tools else []
-            if tool_name not in available_tool_names or not self._tools:
+            available_tool_names = list(tools) if tools else []
+            if tool_name not in available_tool_names or not tools:
                 raise LLMToolCallValidationError(
                     tool_name,
                     tool_arguments,
                     message=f"Tool '{tool_name}' is not available in the LLM "
                     f"tools (available: {available_tool_names})",
                 )
-            tool = self._tools[tool_name]
+            tool = tools[tool_name]
            try:
                 validate_obj_from_json_or_py_string(
                     tool_arguments, schema=tool.in_type, **parsing_params
@@ -296,6 +252,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         self,
         conversation: Messages,
         *,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -308,6 +267,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         self,
         conversation: Messages,
         *,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -318,5 +280,10 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         pass
 
     @abstractmethod
-    def combine_completion_chunks(self, completion_chunks: list[Any]) -> Any:
+    def combine_completion_chunks(
+        self,
+        completion_chunks: list[Any],
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
+    ) -> Any:
         raise NotImplementedError
grasp_agents/llm_agent.py
CHANGED
@@ -1,4 +1,4 @@
-from collections.abc import AsyncIterator, Sequence
+from collections.abc import AsyncIterator, Mapping, Sequence
 from pathlib import Path
 from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
 
@@ -41,7 +41,8 @@ class OutputParser(Protocol[_InT_contra, _OutT_co, CtxT]):
         conversation: Messages,
         *,
         in_args: _InT_contra | None,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
+        call_id: str,
     ) -> _OutT_co: ...
 
 
@@ -68,16 +69,19 @@ class LLMAgent(
         # System prompt template
         sys_prompt: LLMPrompt | None = None,
         sys_prompt_path: str | Path | None = None,
+        # LLM response validation
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         # Agent loop settings
         max_turns: int = 100,
         react_mode: bool = False,
         final_answer_as_tool_call: bool = False,
         # Agent memory management
         reset_memory_on_run: bool = False,
-        #
+        # Agent run retries
         max_retries: int = 0,
         # Multi-agent routing
-        recipients:
+        recipients: Sequence[ProcName] | None = None,
     ) -> None:
         super().__init__(name=name, recipients=recipients, max_retries=max_retries)
 
@@ -96,15 +100,6 @@ class LLMAgent(
 
         # LLM policy executor
 
-        self._used_default_llm_response_schema: bool = False
-        if (
-            llm.response_schema is None
-            and tools is None
-            and not hasattr(type(self), "output_parser")
-        ):
-            llm.response_schema = self.out_type
-            self._used_default_llm_response_schema = True
-
         if issubclass(self._out_type, BaseModel):
             final_answer_type = self._out_type
         elif not final_answer_as_tool_call:
@@ -115,10 +110,21 @@ class LLMAgent(
                 "final_answer_as_tool_call is True."
             )
 
+        self._used_default_llm_response_schema: bool = False
+        if (
+            response_schema is None
+            and tools is None
+            and not hasattr(type(self), "output_parser")
+        ):
+            response_schema = self.out_type
+            self._used_default_llm_response_schema = True
+
         self._policy_executor: LLMPolicyExecutor[CtxT] = LLMPolicyExecutor[CtxT](
             agent_name=self.name,
             llm=llm,
             tools=tools,
+            response_schema=response_schema,
+            response_schema_by_xml_tag=response_schema_by_xml_tag,
             max_turns=max_turns,
             react_mode=react_mode,
             final_answer_type=final_answer_type,
@@ -160,23 +166,33 @@ class LLMAgent(
     def _prepare_memory(
         self,
         memory: LLMAgentMemory,
+        *,
         in_args: InT | None = None,
         sys_prompt: LLMPrompt | None = None,
-        ctx: RunContext[Any] | None = None,
+        ctx: RunContext[Any],
+        call_id: str,
     ) -> None:
         if self.memory_preparator:
             return self.memory_preparator(
-                memory=memory, in_args=in_args, sys_prompt=sys_prompt, ctx=ctx
+                memory=memory,
+                in_args=in_args,
+                sys_prompt=sys_prompt,
+                ctx=ctx,
+                call_id=call_id,
             )
 
     def _memorize_inputs(
         self,
         memory: LLMAgentMemory,
+        *,
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         in_args: InT | None = None,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
+        call_id: str,
     ) -> tuple[SystemMessage | None, UserMessage | None]:
-        formatted_sys_prompt = self._prompt_builder.build_system_prompt(ctx=ctx)
+        formatted_sys_prompt = self._prompt_builder.build_system_prompt(
+            ctx=ctx, call_id=call_id
+        )
 
         system_message: SystemMessage | None = None
         if self._reset_memory_on_run or memory.is_empty:
@@ -185,24 +201,22 @@ class LLMAgent(
             system_message = cast("SystemMessage", memory.message_history[0])
         else:
             self._prepare_memory(
-                memory=memory, in_args=in_args, sys_prompt=formatted_sys_prompt, ctx=ctx
+                memory=memory,
+                in_args=in_args,
+                sys_prompt=formatted_sys_prompt,
+                ctx=ctx,
+                call_id=call_id,
             )
 
         input_message = self._prompt_builder.build_input_message(
-            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
+            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx, call_id=call_id
         )
         if input_message:
             memory.update([input_message])
 
         return system_message, input_message
 
-    def _parse_output_default(
-        self,
-        conversation: Messages,
-        *,
-        in_args: InT | None = None,
-        ctx: RunContext[CtxT] | None = None,
-    ) -> OutT:
+    def parse_output_default(self, conversation: Messages) -> OutT:
         return validate_obj_from_json_or_py_string(
             str(conversation[-1].content or ""),
             schema=self._out_type,
|
|
215
229
|
conversation: Messages,
|
216
230
|
*,
|
217
231
|
in_args: InT | None = None,
|
218
|
-
ctx: RunContext[CtxT]
|
232
|
+
ctx: RunContext[CtxT],
|
233
|
+
call_id: str,
|
219
234
|
) -> OutT:
|
220
235
|
if self.output_parser:
|
221
236
|
return self.output_parser(
|
222
|
-
conversation=conversation, in_args=in_args, ctx=ctx
|
237
|
+
conversation=conversation, in_args=in_args, ctx=ctx, call_id=call_id
|
223
238
|
)
|
224
239
|
|
225
|
-
return self.
|
226
|
-
conversation=conversation, in_args=in_args, ctx=ctx
|
227
|
-
)
|
240
|
+
return self.parse_output_default(conversation)
|
228
241
|
|
229
242
|
async def _process(
|
230
243
|
self,
|
@@ -232,24 +245,28 @@ class LLMAgent(
         *,
         in_args: InT | None = None,
         memory: LLMAgentMemory,
+        ctx: RunContext[CtxT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
     ) -> OutT:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
             chat_inputs=chat_inputs,
             in_args=in_args,
             ctx=ctx,
+            call_id=call_id,
         )
         if system_message:
-            self._print_messages([system_message], call_id=call_id, ctx=ctx)
+            self._print_messages([system_message], ctx=ctx, call_id=call_id)
         if input_message:
-            self._print_messages([input_message], call_id=call_id, ctx=ctx)
+            self._print_messages([input_message], ctx=ctx, call_id=call_id)
 
-        await self._policy_executor.execute(memory, call_id=call_id, ctx=ctx)
+        await self._policy_executor.execute(memory, ctx=ctx, call_id=call_id)
 
         return self._parse_output(
-            conversation=memory.message_history, in_args=in_args, ctx=ctx
+            conversation=memory.message_history,
+            in_args=in_args,
+            ctx=ctx,
+            call_id=call_id,
         )
 
     async def _process_stream(
@@ -258,41 +275,45 @@ class LLMAgent(
         *,
         in_args: InT | None = None,
         memory: LLMAgentMemory,
+        ctx: RunContext[CtxT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
             chat_inputs=chat_inputs,
             in_args=in_args,
             ctx=ctx,
+            call_id=call_id,
         )
         if system_message:
-            self._print_messages([system_message], call_id=call_id, ctx=ctx)
+            self._print_messages([system_message], ctx=ctx, call_id=call_id)
             yield SystemMessageEvent(
                 data=system_message, proc_name=self.name, call_id=call_id
             )
         if input_message:
-            self._print_messages([input_message], call_id=call_id, ctx=ctx)
+            self._print_messages([input_message], ctx=ctx, call_id=call_id)
             yield UserMessageEvent(
                 data=input_message, proc_name=self.name, call_id=call_id
             )
 
         async for event in self._policy_executor.execute_stream(
-            memory, call_id=call_id, ctx=ctx
+            memory, ctx=ctx, call_id=call_id
         ):
             yield event
 
         output = self._parse_output(
-            conversation=memory.message_history, in_args=in_args, ctx=ctx
+            conversation=memory.message_history,
+            in_args=in_args,
+            ctx=ctx,
+            call_id=call_id,
         )
         yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
 
     def _print_messages(
         self,
         messages: Sequence[Message],
+        ctx: RunContext[CtxT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
     ) -> None:
         if ctx and ctx.printer:
             ctx.printer.print_messages(messages, agent_name=self.name, call_id=call_id)
@@ -321,28 +342,31 @@ class LLMAgent(
         if cur_cls.memory_manager is not base_cls.memory_manager:
             self._policy_executor.memory_manager = self.memory_manager
 
-    def system_prompt_builder(self, ctx: RunContext[CtxT]) -> str | None:
+    def system_prompt_builder(self, ctx: RunContext[CtxT], call_id: str) -> str | None:
         if self._prompt_builder.system_prompt_builder is not None:
-            return self._prompt_builder.system_prompt_builder(ctx=ctx)
+            return self._prompt_builder.system_prompt_builder(ctx=ctx, call_id=call_id)
         raise NotImplementedError("System prompt builder is not implemented.")
 
     def input_content_builder(
-        self, in_args: InT
+        self, in_args: InT, ctx: RunContext[CtxT], call_id: str
     ) -> Content:
         if self._prompt_builder.input_content_builder is not None:
-            return self._prompt_builder.input_content_builder(in_args=in_args)
+            return self._prompt_builder.input_content_builder(
+                in_args=in_args, ctx=ctx, call_id=call_id
+            )
         raise NotImplementedError("Input content builder is not implemented.")
 
     def tool_call_loop_terminator(
         self,
         conversation: Messages,
         *,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> bool:
         if self._policy_executor.tool_call_loop_terminator is not None:
             return self._policy_executor.tool_call_loop_terminator(
-                conversation=conversation, ctx=ctx, **kwargs
+                conversation=conversation, ctx=ctx, call_id=call_id, **kwargs
             )
         raise NotImplementedError("Tool call loop terminator is not implemented.")
 
@@ -350,12 +374,13 @@ class LLMAgent(
         self,
         memory: LLMAgentMemory,
         *,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> None:
         if self._policy_executor.memory_manager is not None:
             return self._policy_executor.memory_manager(
-                memory=memory, ctx=ctx, **kwargs
+                memory=memory, ctx=ctx, call_id=call_id, **kwargs
             )
         raise NotImplementedError("Memory manager is not implemented.")
 
@@ -391,12 +416,11 @@ class LLMAgent(
         self, func: OutputParser[InT, OutT, CtxT]
     ) -> OutputParser[InT, OutT, CtxT]:
         if self._used_default_llm_response_schema:
-            self._policy_executor.llm.response_schema = None
+            self._policy_executor.response_schema = None
         self.output_parser = func
 
         return func
 
     def add_memory_preparator(self, func: MemoryPreparator) -> MemoryPreparator:
         self.memory_preparator = func
-
         return func
grasp_agents/llm_agent_memory.py
CHANGED
@@ -13,14 +13,16 @@ class MemoryPreparator(Protocol):
     def __call__(
         self,
         memory: "LLMAgentMemory",
+        *,
         in_args: Any | None,
         sys_prompt: LLMPrompt | None,
-        ctx: RunContext[Any]
+        ctx: RunContext[Any],
+        call_id: str,
     ) -> None: ...
 
 
 class LLMAgentMemory(Memory):
-    _message_history: Messages = PrivateAttr(default_factory=list)
+    _message_history: Messages = PrivateAttr(default_factory=Messages)
     _sys_prompt: LLMPrompt | None = PrivateAttr(default=None)
 
     def __init__(self, sys_prompt: LLMPrompt | None = None) -> None:
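For completeness, a `MemoryPreparator` written against the updated protocol above: the arguments after `memory` are now keyword-only, and `call_id` is new. The `sys_prompt` and `ctx` annotations are loosened to `Any` here (their real types are `LLMPrompt | None` and `RunContext[Any]`), and the list-like use of `message_history` matches its indexing elsewhere in this diff:

```python
from typing import Any

from grasp_agents.llm_agent_memory import LLMAgentMemory


def my_memory_preparator(
    memory: LLMAgentMemory,
    *,
    in_args: Any | None,
    sys_prompt: Any | None,  # LLMPrompt | None in the real protocol
    ctx: Any,                # RunContext[Any] in the real protocol
    call_id: str,            # new in 0.5.11
) -> None:
    # e.g. inspect the history before the agent appends new input for this call
    print(f"[{call_id}] history length: {len(memory.message_history)}")
```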
|