grasp_agents 0.5.10__py3-none-any.whl → 0.5.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +3 -0
- grasp_agents/cloud_llm.py +15 -15
- grasp_agents/generics_utils.py +1 -1
- grasp_agents/litellm/lite_llm.py +3 -0
- grasp_agents/llm_agent.py +63 -38
- grasp_agents/llm_agent_memory.py +1 -0
- grasp_agents/llm_policy_executor.py +40 -45
- grasp_agents/openai/openai_llm.py +4 -1
- grasp_agents/printer.py +153 -136
- grasp_agents/processors/base_processor.py +5 -3
- grasp_agents/processors/parallel_processor.py +2 -2
- grasp_agents/processors/processor.py +2 -2
- grasp_agents/prompt_builder.py +23 -7
- grasp_agents/run_context.py +2 -9
- grasp_agents/typing/tool.py +5 -3
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.12.dist-info}/METADATA +7 -20
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.12.dist-info}/RECORD +19 -19
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.12.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.12.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm_policy_executor.py

```diff
@@ -37,6 +37,7 @@ class ToolCallLoopTerminator(Protocol[CtxT]):
         conversation: Messages,
         *,
         ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> bool: ...
 
@@ -47,6 +48,7 @@ class MemoryManager(Protocol[CtxT]):
         memory: LLMAgentMemory,
         *,
         ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> None: ...
 
```
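Both callback protocols now require a `call_id` keyword. A minimal sketch of user-supplied callbacks under the new signatures (only the parameter lists come from the diff; the function bodies, names, and the `num_turns` handling are illustrative assumptions):

```python
from typing import Any


def stop_after_five_turns(
    conversation,  # Messages
    *,
    ctx,           # RunContext[CtxT]
    call_id: str,
    **kwargs: Any,
) -> bool:
    # Hypothetical terminator: end the tool-call loop after enough turns.
    # The executor forwards extra kwargs such as num_turns (see below).
    return kwargs.get("num_turns", 0) >= 5


def keep_recent_history(
    memory,        # LLMAgentMemory
    *,
    ctx,           # RunContext[CtxT]
    call_id: str,
    **kwargs: Any,
) -> None:
    # Hypothetical memory manager: compact or prune memory.message_history here.
    ...
```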
```diff
@@ -123,10 +125,13 @@ class LLMPolicyExecutor(Generic[CtxT]):
         conversation: Messages,
         *,
         ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> bool:
         if self.tool_call_loop_terminator:
-            return self.tool_call_loop_terminator(
+            return self.tool_call_loop_terminator(
+                conversation, ctx=ctx, call_id=call_id, **kwargs
+            )
 
         return False
 
@@ -136,18 +141,19 @@ class LLMPolicyExecutor(Generic[CtxT]):
         memory: LLMAgentMemory,
         *,
         ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> None:
         if self.memory_manager:
-            self.memory_manager(memory=memory, ctx=ctx, **kwargs)
+            self.memory_manager(memory=memory, ctx=ctx, call_id=call_id, **kwargs)
 
     async def generate_message(
         self,
         memory: LLMAgentMemory,
         *,
-        call_id: str,
         tool_choice: ToolChoice | None = None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AssistantMessage:
         completion = await self.llm.generate_completion(
             memory.message_history,
@@ -155,14 +161,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
             response_schema_by_xml_tag=self.response_schema_by_xml_tag,
             tools=self.tools,
             tool_choice=tool_choice,
-            n_choices=1,
             proc_name=self.agent_name,
             call_id=call_id,
         )
         memory.update(completion.messages)
-        self._process_completion(
-            completion, call_id=call_id, ctx=ctx, print_messages=True
-        )
+        self._process_completion(completion, ctx=ctx, call_id=call_id)
 
         return completion.messages[0]
 
@@ -170,9 +173,9 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         memory: LLMAgentMemory,
         *,
-        call_id: str,
         tool_choice: ToolChoice | None = None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AsyncIterator[
         CompletionChunkEvent[CompletionChunk]
         | CompletionEvent
@@ -187,7 +190,6 @@ class LLMPolicyExecutor(Generic[CtxT]):
             response_schema_by_xml_tag=self.response_schema_by_xml_tag,
             tools=self.tools,
             tool_choice=tool_choice,
-            n_choices=1,
             proc_name=self.agent_name,
             call_id=call_id,
         )
```
```diff
@@ -206,23 +208,21 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
         memory.update(completion.messages)
 
-        self._process_completion(
-            completion, call_id=call_id, print_messages=True, ctx=ctx
-        )
+        self._process_completion(completion, ctx=ctx, call_id=call_id)
 
     async def call_tools(
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> Sequence[ToolMessage]:
         # TODO: Add image support
         corouts: list[Coroutine[Any, Any, BaseModel]] = []
         for call in calls:
             tool = self.tools[call.tool_name]
             args = json.loads(call.tool_arguments)
-            corouts.append(tool(
+            corouts.append(tool(ctx=ctx, call_id=call_id, **args))
 
         outs = await asyncio.gather(*corouts)
         tool_messages = list(
```
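Tools are now invoked as `tool(ctx=ctx, call_id=call_id, **args)`, so a tool's call signature must accept those keywords alongside its own arguments. A rough sketch of such a tool (the `EchoTool` class is hypothetical; only the call convention is taken from the diff):

```python
import asyncio
import json


class EchoTool:
    """Hypothetical callable tool; grasp_agents' real tool base class is not shown in this diff."""

    name = "echo"

    async def __call__(self, *, ctx=None, call_id: str | None = None, text: str = ""):
        # A real tool could use ctx (RunContext) and call_id for logging or tracing.
        return {"echo": text, "handled_in": call_id}


async def main() -> None:
    tool = EchoTool()
    args = json.loads('{"text": "hello"}')  # stands in for call.tool_arguments
    print(await tool(ctx=None, call_id="demo-call", **args))


asyncio.run(main())
```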
```diff
@@ -231,7 +231,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
         memory.update(tool_messages)
 
-        if ctx
+        if ctx.printer:
             ctx.printer.print_messages(
                 tool_messages, agent_name=self.agent_name, call_id=call_id
             )
@@ -242,11 +242,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AsyncIterator[ToolMessageEvent]:
         tool_messages = await self.call_tools(
-            calls, memory=memory,
+            calls, memory=memory, ctx=ctx, call_id=call_id
         )
         for tool_message, call in zip(tool_messages, calls, strict=True):
             yield ToolMessageEvent(
@@ -271,20 +271,20 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message
 
     async def _generate_final_answer(
-        self, memory: LLMAgentMemory,
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT], call_id: str
     ) -> AssistantMessage:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!"
         )
         memory.update([user_message])
-        if ctx
+        if ctx.printer:
             ctx.printer.print_messages(
                 [user_message], agent_name=self.agent_name, call_id=call_id
             )
 
         tool_choice = NamedToolChoice(name=self._final_answer_tool.name)
         await self.generate_message(
-            memory, tool_choice=tool_choice,
+            memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
         )
 
         final_answer_message = self._extract_final_answer_from_tool_calls(memory=memory)
@@ -294,7 +294,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message
 
     async def _generate_final_answer_stream(
-        self, memory: LLMAgentMemory,
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT], call_id: str
     ) -> AsyncIterator[Event[Any]]:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!",
@@ -303,14 +303,14 @@ class LLMPolicyExecutor(Generic[CtxT]):
         yield UserMessageEvent(
             proc_name=self.agent_name, call_id=call_id, data=user_message
         )
-        if ctx
+        if ctx.printer:
             ctx.printer.print_messages(
                 [user_message], agent_name=self.agent_name, call_id=call_id
             )
 
         tool_choice = NamedToolChoice(name=self._final_answer_tool.name)
         async for event in self.generate_message_stream(
-            memory, tool_choice=tool_choice,
+            memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
         ):
             yield event
 
@@ -322,7 +322,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )
 
     async def execute(
-        self, memory: LLMAgentMemory,
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT], call_id: str
     ) -> AssistantMessage | Sequence[AssistantMessage]:
         # 1. Generate the first message:
         # In ReAct mode, we generate the first message without tool calls
```
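`execute` and `execute_stream` (continued below) now take `ctx` ahead of `call_id`. A minimal usage sketch, assuming `executor`, `memory`, `ctx`, and `call_id` are constructed elsewhere; only the argument order comes from the diff:

```python
# Sketch of the updated calling convention; all objects are assumed to exist.
async def run_policy(executor, memory, ctx, call_id: str):
    # Non-streaming: returns the final assistant message(s).
    final_answer = await executor.execute(memory, ctx=ctx, call_id=call_id)

    # Streaming: yields completion-chunk, generated-message, and tool-message events.
    async for event in executor.execute_stream(memory, ctx=ctx, call_id=call_id):
        print(type(event).__name__)

    return final_answer
```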
```diff
@@ -332,7 +332,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         if self.tools:
             tool_choice = "none" if self._react_mode else "auto"
             gen_message = await self.generate_message(
-                memory, tool_choice=tool_choice,
+                memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
             )
         if not self.tools:
             return gen_message
@@ -345,7 +345,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
             # If a final answer is not provided via a tool call, we use
             # _terminate_tool_call_loop to determine whether to exit the loop.
             if not self._final_answer_as_tool_call and self._terminate_tool_call_loop(
-                memory.message_history, ctx=ctx, num_turns=turns
+                memory.message_history, ctx=ctx, call_id=call_id, num_turns=turns
            ):
                 return gen_message
 
@@ -364,7 +364,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
                 # Otherwise, we simply return the last generated message.
                 if self._final_answer_as_tool_call:
                     final_answer = await self._generate_final_answer(
-                        memory,
+                        memory, ctx=ctx, call_id=call_id
                     )
                 else:
                     final_answer = gen_message
@@ -377,11 +377,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
             if gen_message.tool_calls:
                 await self.call_tools(
-                    gen_message.tool_calls, memory=memory,
+                    gen_message.tool_calls, memory=memory, ctx=ctx, call_id=call_id
                 )
 
             # Apply memory management (e.g. compacting or pruning memory)
-            self._manage_memory(memory, ctx=ctx, num_turns=turns)
+            self._manage_memory(memory, ctx=ctx, call_id=call_id, num_turns=turns)
 
             # 4. Generate the next message based on the updated memory.
             # In ReAct mode, we set tool_choice to "none" if we just called tools,
@@ -396,7 +396,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
                 tool_choice = "required"
 
             gen_message = await self.generate_message(
-                memory, tool_choice=tool_choice,
+                memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
             )
 
             turns += 1
@@ -404,13 +404,13 @@ class LLMPolicyExecutor(Generic[CtxT]):
     async def execute_stream(
         self,
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AsyncIterator[Event[Any]]:
         tool_choice: ToolChoice = "none" if self._react_mode else "auto"
         gen_message: AssistantMessage | None = None
         async for event in self.generate_message_stream(
-            memory, tool_choice=tool_choice,
+            memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
         ):
             if isinstance(event, GenMessageEvent):
                 gen_message = event.data
@@ -425,7 +425,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
         while True:
             if not self._final_answer_as_tool_call and self._terminate_tool_call_loop(
-                memory.message_history, ctx=ctx, num_turns=turns
+                memory.message_history, ctx=ctx, call_id=call_id, num_turns=turns
             ):
                 return
 
@@ -444,7 +444,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
             if turns >= self.max_turns:
                 if self._final_answer_as_tool_call:
                     async for event in self._generate_final_answer_stream(
-                        memory,
+                        memory, ctx=ctx, call_id=call_id
                     ):
                         yield event
                     logger.info(
@@ -459,11 +459,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
             )
 
             async for event in self.call_tools_stream(
-                gen_message.tool_calls, memory=memory,
+                gen_message.tool_calls, memory=memory, ctx=ctx, call_id=call_id
             ):
                 yield event
 
-            self._manage_memory(memory, ctx=ctx, num_turns=turns)
+            self._manage_memory(memory, ctx=ctx, call_id=call_id, num_turns=turns)
 
             if self._react_mode and gen_message.tool_calls:
                 tool_choice = "none"
@@ -473,7 +473,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
                 tool_choice = "required"
 
             async for event in self.generate_message_stream(
-                memory, tool_choice=tool_choice,
+                memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
             ):
                 yield event
                 if isinstance(event, GenMessageEvent):
@@ -493,20 +493,15 @@ class LLMPolicyExecutor(Generic[CtxT]):
                 self,
                 inp: BaseModel,
                 *,
-                call_id: str | None = None,
                 ctx: RunContext[Any] | None = None,
+                call_id: str | None = None,
             ) -> None:
                 return None
 
         return FinalAnswerTool()
 
     def _process_completion(
-        self,
-        completion: Completion,
-        *,
-        call_id: str,
-        print_messages: bool = False,
-        ctx: RunContext[CtxT],
+        self, completion: Completion, *, ctx: RunContext[CtxT], call_id: str
     ) -> None:
         ctx.completions[self.agent_name].append(completion)
         ctx.usage_tracker.update(
@@ -514,7 +509,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
             completions=[completion],
             model_name=self.llm.model_name,
         )
-        if ctx.printer
+        if ctx.printer:
             usages = [None] * (len(completion.messages) - 1) + [completion.usage]
             ctx.printer.print_messages(
                 completion.messages,
```
grasp_agents/openai/openai_llm.py

```diff
@@ -60,7 +60,7 @@ def get_openai_compatible_providers() -> list[APIProvider]:
 
 
 class OpenAILLMSettings(CloudLLMSettings, total=False):
-    reasoning_effort: Literal["low", "medium", "high"] | None
+    reasoning_effort: Literal["disable", "minimal", "low", "medium", "high"] | None
 
     parallel_tool_calls: bool
 
```
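The settings TypedDict now also accepts "disable" and "minimal" for reasoning_effort. A small sketch of constructing such settings (the import path follows the module location shown above; the surrounding usage is an assumption):

```python
from grasp_agents.openai.openai_llm import OpenAILLMSettings

# OpenAILLMSettings is a total=False TypedDict, so any subset of keys is valid.
settings: OpenAILLMSettings = {
    "reasoning_effort": "minimal",  # "disable" and "minimal" are newly allowed values
    "parallel_tool_calls": False,
}
```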
```diff
@@ -172,6 +172,9 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         response_format = api_response_schema or NOT_GIVEN
         n = n_choices or NOT_GIVEN
 
+        if api_llm_settings and api_llm_settings.get("stream_options"):
+            api_llm_settings.pop("stream_options")
+
         if self.apply_response_schema_via_provider:
             return await self.client.beta.chat.completions.parse(
                 model=self.model_name,
```