grasp_agents 0.5.10__py3-none-any.whl → 0.5.12__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
- grasp_agents/__init__.py +3 -0
- grasp_agents/cloud_llm.py +15 -15
- grasp_agents/generics_utils.py +1 -1
- grasp_agents/litellm/lite_llm.py +3 -0
- grasp_agents/llm_agent.py +63 -38
- grasp_agents/llm_agent_memory.py +1 -0
- grasp_agents/llm_policy_executor.py +40 -45
- grasp_agents/openai/openai_llm.py +4 -1
- grasp_agents/printer.py +153 -136
- grasp_agents/processors/base_processor.py +5 -3
- grasp_agents/processors/parallel_processor.py +2 -2
- grasp_agents/processors/processor.py +2 -2
- grasp_agents/prompt_builder.py +23 -7
- grasp_agents/run_context.py +2 -9
- grasp_agents/typing/tool.py +5 -3
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.12.dist-info}/METADATA +7 -20
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.12.dist-info}/RECORD +19 -19
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.12.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.12.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/__init__.py
CHANGED
@@ -6,6 +6,7 @@ from .llm_agent import LLMAgent
 from .llm_agent_memory import LLMAgentMemory
 from .memory import Memory
 from .packet import Packet
+from .printer import Printer, print_event_stream
 from .processors.base_processor import BaseProcessor
 from .processors.parallel_processor import ParallelProcessor
 from .processors.processor import Processor
@@ -33,9 +34,11 @@ __all__ = [
     "Packet",
     "Packet",
     "ParallelProcessor",
+    "Printer",
     "ProcName",
     "Processor",
     "RunContext",
     "SystemMessage",
     "UserMessage",
+    "print_event_stream",
 ]
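Both printing helpers are now re-exported at the package root. A minimal before/after import sketch (usage only; nothing beyond the re-export itself is implied):

    # 0.5.10: the deep import was required
    from grasp_agents.printer import Printer, print_event_stream

    # 0.5.12: the same names are importable from the package root
    from grasp_agents import Printer, print_event_stream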
grasp_agents/cloud_llm.py
CHANGED
@@ -61,7 +61,6 @@ LLMRateLimiter = RateLimiterC[
 
 @dataclass(frozen=True)
 class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
-    # Make this field keyword-only to avoid ordering issues with inherited defaulted fields
     api_provider: APIProvider | None = None
     llm_settings: SettingsT_co | None = None
     rate_limiter: LLMRateLimiter | None = None
@@ -70,6 +69,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         0  # LLM response retries: try to regenerate to pass validation
     )
     apply_response_schema_via_provider: bool = False
+    apply_tool_call_schema_via_provider: bool = False
     async_http_client: httpx.AsyncClient | None = None
     async_http_client_params: dict[str, Any] | AsyncHTTPClientParams | None = None
 
@@ -80,6 +80,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                 f"{self.rate_limiter.rpm} RPM"
             )
 
+        if self.apply_response_schema_via_provider:
+            object.__setattr__(self, "apply_tool_call_schema_via_provider", True)
+
         if self.async_http_client is None and self.async_http_client_params is not None:
             object.__setattr__(
                 self,
@@ -100,7 +103,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         api_tools = None
         api_tool_choice = None
         if tools:
-            strict = True if self.apply_response_schema_via_provider else None
+            strict = True if self.apply_tool_call_schema_via_provider else None
             api_tools = [
                 self.converters.to_tool(t, strict=strict) for t in tools.values()
             ]
@@ -175,8 +178,8 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             response_schema=response_schema,
             response_schema_by_xml_tag=response_schema_by_xml_tag,
         )
-
-
+        if not self.apply_tool_call_schema_via_provider and tools is not None:
+            self._validate_tool_calls(completion, tools=tools)
 
         return completion
 
@@ -208,17 +211,16 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
 
         if n_attempt > self.max_response_retries:
             if n_attempt == 1:
-                logger.warning(f"\nCloudLLM completion
+                logger.warning(f"\nCloudLLM completion failed:\n{err}")
             if n_attempt > 1:
                 logger.warning(
-                    f"\nCloudLLM completion
+                    f"\nCloudLLM completion failed after retrying:\n{err}"
                 )
             raise err
             # return make_refusal_completion(self._model_name, err)
 
         logger.warning(
-            f"\nCloudLLM completion failed (retry attempt {n_attempt}):"
-            f"\n{err}"
+            f"\nCloudLLM completion failed (retry attempt {n_attempt}):\n{err}"
         )
 
         return make_refusal_completion(
@@ -282,8 +284,8 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             response_schema=response_schema,
             response_schema_by_xml_tag=response_schema_by_xml_tag,
         )
-
-
+        if not self.apply_tool_call_schema_via_provider and tools is not None:
+            self._validate_tool_calls(completion, tools=tools)
 
         return iterator()
 
@@ -327,11 +329,10 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             n_attempt += 1
             if n_attempt > self.max_response_retries:
                 if n_attempt == 1:
-                    logger.warning(f"\nCloudLLM completion
+                    logger.warning(f"\nCloudLLM completion failed:\n{err}")
                 if n_attempt > 1:
                     logger.warning(
-                        "\nCloudLLM completion failed after "
-                        f"retrying:\n{err}"
+                        f"\nCloudLLM completion failed after retrying:\n{err}"
                     )
                 refusal_completion = make_refusal_completion(
                     self.model_name, err
@@ -345,6 +346,5 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             # return
 
         logger.warning(
-            "\nCloudLLM completion failed "
-            f"(retry attempt {n_attempt}):\n{err}"
+            f"\nCloudLLM completion failed (retry attempt {n_attempt}):\n{err}"
         )
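Two related behaviors appear in this file: __post_init__ forces apply_tool_call_schema_via_provider on whenever apply_response_schema_via_provider is set (via object.__setattr__, since the dataclass is frozen), and tool calls are validated client-side with _validate_tool_calls only when the provider is not enforcing the tool-call schema. A self-contained sketch of the flag coupling, with CloudLLM reduced to just the two fields:

    from dataclasses import dataclass

    # Sketch of the coupling added in __post_init__ above; everything except
    # the two flag fields is stripped away.
    @dataclass(frozen=True)
    class SchemaFlags:
        apply_response_schema_via_provider: bool = False
        apply_tool_call_schema_via_provider: bool = False

        def __post_init__(self) -> None:
            # Frozen dataclasses forbid normal assignment, hence object.__setattr__.
            if self.apply_response_schema_via_provider:
                object.__setattr__(self, "apply_tool_call_schema_via_provider", True)

    flags = SchemaFlags(apply_response_schema_via_provider=True)
    assert flags.apply_tool_call_schema_via_provider  # implied by the response flag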
grasp_agents/generics_utils.py
CHANGED
@@ -159,7 +159,7 @@ class AutoInstanceAttributesMixin:
             attr_type = resolved_attr_types[attr_name]
             # attr_type = None if _attr_type is type(None) else _attr_type
         else:
-            attr_type =
+            attr_type = object
 
         if attr_name in pyd_private:
             pyd_private[attr_name] = attr_type
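A plausible reading of this one-liner: object is the maximally permissive fallback for a type parameter that cannot be resolved, since every Python value is an instance of it. A tiny illustration, reusing the variable name from the diff:

    # Every value passes an isinstance check against `object`, so an
    # unresolved generic attribute stays maximally permissive downstream.
    attr_type = object
    assert isinstance("text", attr_type)
    assert isinstance(42, attr_type)
    assert isinstance(None, attr_type)  # even None passes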
grasp_agents/litellm/lite_llm.py
CHANGED
@@ -149,6 +149,9 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
         n_choices: int | None = None,
         **api_llm_settings: Any,
     ) -> LiteLLMCompletion:
+        if api_llm_settings and api_llm_settings.get("stream_options"):
+            api_llm_settings.pop("stream_options")
+
         completion = await litellm.acompletion(  # type: ignore[no-untyped-call]
             model=self.model_name,
             messages=api_messages,
grasp_agents/llm_agent.py
CHANGED
@@ -42,6 +42,7 @@ class OutputParser(Protocol[_InT_contra, _OutT_co, CtxT]):
         *,
         in_args: _InT_contra | None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> _OutT_co: ...
 
 
@@ -169,10 +170,15 @@ class LLMAgent(
         in_args: InT | None = None,
         sys_prompt: LLMPrompt | None = None,
         ctx: RunContext[Any],
+        call_id: str,
     ) -> None:
         if self.memory_preparator:
             return self.memory_preparator(
-                memory=memory, in_args=in_args, sys_prompt=sys_prompt, ctx=ctx
+                memory=memory,
+                in_args=in_args,
+                sys_prompt=sys_prompt,
+                ctx=ctx,
+                call_id=call_id,
             )
 
     def _memorize_inputs(
@@ -182,8 +188,11 @@ class LLMAgent(
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         in_args: InT | None = None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> tuple[SystemMessage | None, UserMessage | None]:
-        formatted_sys_prompt = self._prompt_builder.build_system_prompt(ctx=ctx)
+        formatted_sys_prompt = self._prompt_builder.build_system_prompt(
+            ctx=ctx, call_id=call_id
+        )
 
         system_message: SystemMessage | None = None
         if self._reset_memory_on_run or memory.is_empty:
@@ -192,24 +201,22 @@ class LLMAgent(
             system_message = cast("SystemMessage", memory.message_history[0])
         else:
             self._prepare_memory(
-                memory=memory, in_args=in_args, sys_prompt=formatted_sys_prompt, ctx=ctx
+                memory=memory,
+                in_args=in_args,
+                sys_prompt=formatted_sys_prompt,
+                ctx=ctx,
+                call_id=call_id,
             )
 
         input_message = self._prompt_builder.build_input_message(
-            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
+            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx, call_id=call_id
         )
         if input_message:
             memory.update([input_message])
 
         return system_message, input_message
 
-    def _parse_output_default(
-        self,
-        conversation: Messages,
-        *,
-        in_args: InT | None = None,
-        ctx: RunContext[CtxT],
-    ) -> OutT:
+    def parse_output_default(self, conversation: Messages) -> OutT:
         return validate_obj_from_json_or_py_string(
             str(conversation[-1].content or ""),
             schema=self._out_type,
@@ -223,15 +230,14 @@ class LLMAgent(
         *,
         in_args: InT | None = None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> OutT:
         if self.output_parser:
             return self.output_parser(
-                conversation=conversation, in_args=in_args, ctx=ctx
+                conversation=conversation, in_args=in_args, ctx=ctx, call_id=call_id
             )
 
-        return self._parse_output_default(
-            conversation=conversation, in_args=in_args, ctx=ctx
-        )
+        return self.parse_output_default(conversation)
 
     async def _process(
         self,
@@ -239,24 +245,28 @@ class LLMAgent(
         *,
         in_args: InT | None = None,
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> OutT:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
             chat_inputs=chat_inputs,
             in_args=in_args,
             ctx=ctx,
+            call_id=call_id,
         )
         if system_message:
-            self._print_messages([system_message], call_id=call_id, ctx=ctx)
+            self._print_messages([system_message], ctx=ctx, call_id=call_id)
         if input_message:
-            self._print_messages([input_message], call_id=call_id, ctx=ctx)
+            self._print_messages([input_message], ctx=ctx, call_id=call_id)
 
-        await self._policy_executor.execute(memory, call_id=call_id, ctx=ctx)
+        await self._policy_executor.execute(memory, ctx=ctx, call_id=call_id)
 
         return self._parse_output(
-            conversation=memory.message_history, in_args=in_args, ctx=ctx
+            conversation=memory.message_history,
+            in_args=in_args,
+            ctx=ctx,
+            call_id=call_id,
         )
 
     async def _process_stream(
@@ -265,43 +275,44 @@ class LLMAgent(
         *,
         in_args: InT | None = None,
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AsyncIterator[Event[Any]]:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
             chat_inputs=chat_inputs,
             in_args=in_args,
             ctx=ctx,
+            call_id=call_id,
        )
         if system_message:
-            self._print_messages([system_message], call_id=call_id, ctx=ctx)
+            self._print_messages([system_message], ctx=ctx, call_id=call_id)
             yield SystemMessageEvent(
                 data=system_message, proc_name=self.name, call_id=call_id
             )
         if input_message:
-            self._print_messages([input_message], call_id=call_id, ctx=ctx)
+            self._print_messages([input_message], ctx=ctx, call_id=call_id)
             yield UserMessageEvent(
                 data=input_message, proc_name=self.name, call_id=call_id
             )
 
         async for event in self._policy_executor.execute_stream(
-            memory, call_id=call_id, ctx=ctx
+            memory, ctx=ctx, call_id=call_id
         ):
             yield event
 
         output = self._parse_output(
-            conversation=memory.message_history, in_args=in_args, ctx=ctx
+            conversation=memory.message_history,
+            in_args=in_args,
+            ctx=ctx,
+            call_id=call_id,
         )
         yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
 
     def _print_messages(
-        self,
-        messages: Sequence[Message],
-        call_id: str,
-        ctx: RunContext[CtxT],
+        self, messages: Sequence[Message], ctx: RunContext[CtxT], call_id: str
     ) -> None:
-        if ctx and ctx.printer:
+        if ctx.printer:
             ctx.printer.print_messages(messages, agent_name=self.name, call_id=call_id)
 
     # -- Override these methods in subclasses if needed --
@@ -328,31 +339,45 @@ class LLMAgent(
         if cur_cls.memory_manager is not base_cls.memory_manager:
             self._policy_executor.memory_manager = self.memory_manager
 
-    def system_prompt_builder(self, ctx: RunContext[CtxT]) -> str | None:
+    def system_prompt_builder(self, ctx: RunContext[CtxT], call_id: str) -> str | None:
         if self._prompt_builder.system_prompt_builder is not None:
-            return self._prompt_builder.system_prompt_builder(ctx=ctx)
+            return self._prompt_builder.system_prompt_builder(ctx=ctx, call_id=call_id)
         raise NotImplementedError("System prompt builder is not implemented.")
 
-    def input_content_builder(self, in_args: InT, ctx: RunContext[CtxT]) -> Content:
+    def input_content_builder(
+        self, in_args: InT, ctx: RunContext[CtxT], call_id: str
+    ) -> Content:
         if self._prompt_builder.input_content_builder is not None:
-            return self._prompt_builder.input_content_builder(in_args=in_args, ctx=ctx)
+            return self._prompt_builder.input_content_builder(
+                in_args=in_args, ctx=ctx, call_id=call_id
+            )
         raise NotImplementedError("Input content builder is not implemented.")
 
     def tool_call_loop_terminator(
-        self, conversation: Messages, *, ctx: RunContext[CtxT], **kwargs: Any
+        self,
+        conversation: Messages,
+        *,
+        ctx: RunContext[CtxT],
+        call_id: str,
+        **kwargs: Any,
     ) -> bool:
         if self._policy_executor.tool_call_loop_terminator is not None:
             return self._policy_executor.tool_call_loop_terminator(
-                conversation=conversation, ctx=ctx, **kwargs
+                conversation=conversation, ctx=ctx, call_id=call_id, **kwargs
             )
         raise NotImplementedError("Tool call loop terminator is not implemented.")
 
     def memory_manager(
-        self, memory: LLMAgentMemory, *, ctx: RunContext[CtxT], **kwargs: Any
+        self,
+        memory: LLMAgentMemory,
+        *,
+        ctx: RunContext[CtxT],
+        call_id: str,
+        **kwargs: Any,
     ) -> None:
         if self._policy_executor.memory_manager is not None:
             return self._policy_executor.memory_manager(
-                memory=memory, ctx=ctx, **kwargs
+                memory=memory, ctx=ctx, call_id=call_id, **kwargs
             )
         raise NotImplementedError("Memory manager is not implemented.")
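The running theme in this file is that call_id now travels alongside ctx through every overridable hook (memory_preparator, the prompt builders, output_parser, tool_call_loop_terminator, memory_manager), and the default parser becomes the public parse_output_default with a pared-down signature. Since the file's own comment marks these methods as "Override these methods in subclasses if needed", here is a sketch of a 0.5.12-compatible override; the subclass name, generic parameters, and prompt text are assumptions:

    from grasp_agents import LLMAgent

    # Hypothetical subclass; in practice LLMAgent is parameterized, e.g.
    # LLMAgent[InT, OutT, CtxT].
    class EchoAgent(LLMAgent):
        def system_prompt_builder(self, ctx, call_id: str) -> str | None:
            # call_id is new in 0.5.12: hooks written against 0.5.10 must
            # accept it, or the framework's keyword call will raise TypeError.
            return f"[run {call_id}] You are a helpful assistant."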