grasp_agents 0.5.10__py3-none-any.whl → 0.5.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/cloud_llm.py +0 -1
- grasp_agents/llm_agent.py +62 -34
- grasp_agents/llm_agent_memory.py +1 -0
- grasp_agents/llm_policy_executor.py +36 -30
- grasp_agents/openai/openai_llm.py +1 -1
- grasp_agents/processors/base_processor.py +1 -1
- grasp_agents/processors/parallel_processor.py +2 -2
- grasp_agents/processors/processor.py +2 -2
- grasp_agents/prompt_builder.py +23 -7
- grasp_agents/typing/tool.py +5 -3
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.11.dist-info}/METADATA +4 -5
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.11.dist-info}/RECORD +14 -14
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.11.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.11.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/cloud_llm.py
CHANGED
@@ -61,7 +61,6 @@ LLMRateLimiter = RateLimiterC[
 
 @dataclass(frozen=True)
 class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
-    # Make this field keyword-only to avoid ordering issues with inherited defaulted fields
     api_provider: APIProvider | None = None
     llm_settings: SettingsT_co | None = None
     rate_limiter: LLMRateLimiter | None = None
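The deleted comment referred to Python's dataclass ordering rule: a field without a default cannot follow inherited fields that have defaults. Every CloudLLM field above now carries a default, so the note was stale. A minimal sketch of the underlying pitfall, with hypothetical field names:

from dataclasses import dataclass, field


@dataclass(frozen=True)
class Base:
    retries: int = 3  # inherited field with a default


# Without a default this would raise:
#   TypeError: non-default argument 'api_key' follows default argument
# Marking the field keyword-only (or giving it a default) sidesteps the rule:
@dataclass(frozen=True)
class Child(Base):
    api_key: str | None = field(default=None, kw_only=True)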
grasp_agents/llm_agent.py
CHANGED
@@ -42,6 +42,7 @@ class OutputParser(Protocol[_InT_contra, _OutT_co, CtxT]):
         *,
         in_args: _InT_contra | None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> _OutT_co: ...
 
 
@@ -169,10 +170,15 @@ class LLMAgent(
         in_args: InT | None = None,
         sys_prompt: LLMPrompt | None = None,
         ctx: RunContext[Any],
+        call_id: str,
     ) -> None:
         if self.memory_preparator:
             return self.memory_preparator(
-                memory=memory, in_args=in_args, sys_prompt=sys_prompt, ctx=ctx
+                memory=memory,
+                in_args=in_args,
+                sys_prompt=sys_prompt,
+                ctx=ctx,
+                call_id=call_id,
             )
 
     def _memorize_inputs(
@@ -182,8 +188,11 @@ class LLMAgent(
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         in_args: InT | None = None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> tuple[SystemMessage | None, UserMessage | None]:
-        formatted_sys_prompt = self._prompt_builder.build_system_prompt(ctx=ctx)
+        formatted_sys_prompt = self._prompt_builder.build_system_prompt(
+            ctx=ctx, call_id=call_id
+        )
 
         system_message: SystemMessage | None = None
         if self._reset_memory_on_run or memory.is_empty:
@@ -192,24 +201,22 @@ class LLMAgent(
             system_message = cast("SystemMessage", memory.message_history[0])
         else:
             self._prepare_memory(
-                memory=memory, in_args=in_args, sys_prompt=formatted_sys_prompt, ctx=ctx
+                memory=memory,
+                in_args=in_args,
+                sys_prompt=formatted_sys_prompt,
+                ctx=ctx,
+                call_id=call_id,
             )
 
         input_message = self._prompt_builder.build_input_message(
-            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
+            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx, call_id=call_id
         )
         if input_message:
             memory.update([input_message])
 
         return system_message, input_message
 
-    def _parse_output_default(
-        self,
-        conversation: Messages,
-        *,
-        in_args: InT | None = None,
-        ctx: RunContext[CtxT],
-    ) -> OutT:
+    def parse_output_default(self, conversation: Messages) -> OutT:
         return validate_obj_from_json_or_py_string(
             str(conversation[-1].content or ""),
             schema=self._out_type,
@@ -223,15 +230,14 @@ class LLMAgent(
         *,
         in_args: InT | None = None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> OutT:
         if self.output_parser:
             return self.output_parser(
-                conversation=conversation, in_args=in_args, ctx=ctx
+                conversation=conversation, in_args=in_args, ctx=ctx, call_id=call_id
             )
 
-        return self._parse_output_default(
-            conversation=conversation, in_args=in_args, ctx=ctx
-        )
+        return self.parse_output_default(conversation)
 
     async def _process(
         self,
@@ -239,24 +245,28 @@ class LLMAgent(
         *,
         in_args: InT | None = None,
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> OutT:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
             chat_inputs=chat_inputs,
             in_args=in_args,
             ctx=ctx,
+            call_id=call_id,
         )
         if system_message:
-            self._print_messages([system_message], call_id=call_id, ctx=ctx)
+            self._print_messages([system_message], ctx=ctx, call_id=call_id)
         if input_message:
-            self._print_messages([input_message], call_id=call_id, ctx=ctx)
+            self._print_messages([input_message], ctx=ctx, call_id=call_id)
 
-        await self._policy_executor.execute(memory, call_id=call_id, ctx=ctx)
+        await self._policy_executor.execute(memory, ctx=ctx, call_id=call_id)
 
         return self._parse_output(
-            conversation=memory.message_history, in_args=in_args, ctx=ctx
+            conversation=memory.message_history,
+            in_args=in_args,
+            ctx=ctx,
+            call_id=call_id,
         )
 
     async def _process_stream(
@@ -265,41 +275,45 @@ class LLMAgent(
         *,
         in_args: InT | None = None,
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AsyncIterator[Event[Any]]:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
             chat_inputs=chat_inputs,
             in_args=in_args,
             ctx=ctx,
+            call_id=call_id,
         )
         if system_message:
-            self._print_messages([system_message], call_id=call_id, ctx=ctx)
+            self._print_messages([system_message], ctx=ctx, call_id=call_id)
             yield SystemMessageEvent(
                 data=system_message, proc_name=self.name, call_id=call_id
             )
         if input_message:
-            self._print_messages([input_message], call_id=call_id, ctx=ctx)
+            self._print_messages([input_message], ctx=ctx, call_id=call_id)
             yield UserMessageEvent(
                 data=input_message, proc_name=self.name, call_id=call_id
             )
 
         async for event in self._policy_executor.execute_stream(
-            memory, call_id=call_id, ctx=ctx
+            memory, ctx=ctx, call_id=call_id
         ):
             yield event
 
         output = self._parse_output(
-            conversation=memory.message_history, in_args=in_args, ctx=ctx
+            conversation=memory.message_history,
+            in_args=in_args,
+            ctx=ctx,
+            call_id=call_id,
        )
         yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
 
     def _print_messages(
         self,
         messages: Sequence[Message],
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> None:
         if ctx and ctx.printer:
             ctx.printer.print_messages(messages, agent_name=self.name, call_id=call_id)
@@ -328,31 +342,45 @@ class LLMAgent(
         if cur_cls.memory_manager is not base_cls.memory_manager:
             self._policy_executor.memory_manager = self.memory_manager
 
-    def system_prompt_builder(self, ctx: RunContext[CtxT]) -> str | None:
+    def system_prompt_builder(self, ctx: RunContext[CtxT], call_id: str) -> str | None:
         if self._prompt_builder.system_prompt_builder is not None:
-            return self._prompt_builder.system_prompt_builder(ctx=ctx)
+            return self._prompt_builder.system_prompt_builder(ctx=ctx, call_id=call_id)
         raise NotImplementedError("System prompt builder is not implemented.")
 
-    def input_content_builder(self, in_args: InT, ctx: RunContext[CtxT]) -> Content:
+    def input_content_builder(
+        self, in_args: InT, ctx: RunContext[CtxT], call_id: str
+    ) -> Content:
         if self._prompt_builder.input_content_builder is not None:
-            return self._prompt_builder.input_content_builder(in_args=in_args, ctx=ctx)
+            return self._prompt_builder.input_content_builder(
+                in_args=in_args, ctx=ctx, call_id=call_id
+            )
         raise NotImplementedError("Input content builder is not implemented.")
 
     def tool_call_loop_terminator(
-        self, conversation: Messages, *, ctx: RunContext[CtxT], **kwargs: Any
+        self,
+        conversation: Messages,
+        *,
+        ctx: RunContext[CtxT],
+        call_id: str,
+        **kwargs: Any,
     ) -> bool:
         if self._policy_executor.tool_call_loop_terminator is not None:
             return self._policy_executor.tool_call_loop_terminator(
-                conversation=conversation, ctx=ctx, **kwargs
+                conversation=conversation, ctx=ctx, call_id=call_id, **kwargs
             )
         raise NotImplementedError("Tool call loop terminator is not implemented.")
 
     def memory_manager(
-        self, memory: LLMAgentMemory, *, ctx: RunContext[CtxT], **kwargs: Any
+        self,
+        memory: LLMAgentMemory,
+        *,
+        ctx: RunContext[CtxT],
+        call_id: str,
+        **kwargs: Any,
    ) -> None:
         if self._policy_executor.memory_manager is not None:
             return self._policy_executor.memory_manager(
-                memory=memory, ctx=ctx, **kwargs
+                memory=memory, ctx=ctx, call_id=call_id, **kwargs
             )
         raise NotImplementedError("Memory manager is not implemented.")
 
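Taken together, the llm_agent.py changes thread the per-run call_id into every user-overridable hook (output parser, memory preparator, prompt builders) and reduce the default parsing path to the now-public parse_output_default. A minimal sketch of an output parser under the new OutputParser protocol; import paths are inferred from the wheel's module layout:

from typing import Any

from grasp_agents.run_context import RunContext
from grasp_agents.typing.message import Messages


def my_output_parser(
    conversation: Messages,
    *,
    in_args: Any | None,
    ctx: RunContext[Any],
    call_id: str,
) -> str:
    # call_id identifies the processor run that produced this conversation,
    # useful for logging and tracing.
    return str(conversation[-1].content or "")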
grasp_agents/llm_agent_memory.py
CHANGED
grasp_agents/llm_policy_executor.py
CHANGED
@@ -37,6 +37,7 @@ class ToolCallLoopTerminator(Protocol[CtxT]):
         conversation: Messages,
         *,
         ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> bool: ...
 
@@ -47,6 +48,7 @@ class MemoryManager(Protocol[CtxT]):
         memory: LLMAgentMemory,
         *,
         ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> None: ...
 
@@ -123,10 +125,13 @@ class LLMPolicyExecutor(Generic[CtxT]):
         conversation: Messages,
         *,
         ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> bool:
         if self.tool_call_loop_terminator:
-            return self.tool_call_loop_terminator(conversation, ctx=ctx, **kwargs)
+            return self.tool_call_loop_terminator(
+                conversation, ctx=ctx, call_id=call_id, **kwargs
+            )
 
         return False
 
@@ -136,18 +141,19 @@ class LLMPolicyExecutor(Generic[CtxT]):
         memory: LLMAgentMemory,
         *,
         ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> None:
         if self.memory_manager:
-            self.memory_manager(memory=memory, ctx=ctx, **kwargs)
+            self.memory_manager(memory=memory, ctx=ctx, call_id=call_id, **kwargs)
 
     async def generate_message(
         self,
         memory: LLMAgentMemory,
         *,
-        call_id: str,
         tool_choice: ToolChoice | None = None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AssistantMessage:
         completion = await self.llm.generate_completion(
             memory.message_history,
@@ -161,7 +167,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )
         memory.update(completion.messages)
         self._process_completion(
-            completion, call_id=call_id, print_messages=True, ctx=ctx
+            completion, ctx=ctx, call_id=call_id, print_messages=True
         )
 
         return completion.messages[0]
@@ -170,9 +176,9 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         memory: LLMAgentMemory,
         *,
-        call_id: str,
         tool_choice: ToolChoice | None = None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AsyncIterator[
         CompletionChunkEvent[CompletionChunk]
         | CompletionEvent
@@ -207,22 +213,22 @@ class LLMPolicyExecutor(Generic[CtxT]):
         memory.update(completion.messages)
 
         self._process_completion(
-            completion, call_id=call_id, print_messages=True, ctx=ctx
+            completion, print_messages=True, ctx=ctx, call_id=call_id
         )
 
     async def call_tools(
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> Sequence[ToolMessage]:
         # TODO: Add image support
         corouts: list[Coroutine[Any, Any, BaseModel]] = []
         for call in calls:
             tool = self.tools[call.tool_name]
             args = json.loads(call.tool_arguments)
-            corouts.append(tool(call_id=call_id, ctx=ctx, **args))
+            corouts.append(tool(ctx=ctx, call_id=call_id, **args))
 
         outs = await asyncio.gather(*corouts)
         tool_messages = list(
@@ -242,11 +248,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AsyncIterator[ToolMessageEvent]:
         tool_messages = await self.call_tools(
-            calls, memory=memory, call_id=call_id, ctx=ctx
+            calls, memory=memory, ctx=ctx, call_id=call_id
        )
         for tool_message, call in zip(tool_messages, calls, strict=True):
             yield ToolMessageEvent(
@@ -271,7 +277,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message
 
     async def _generate_final_answer(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT], call_id: str
     ) -> AssistantMessage:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!"
@@ -284,7 +290,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
         tool_choice = NamedToolChoice(name=self._final_answer_tool.name)
         await self.generate_message(
-            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+            memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
         )
 
         final_answer_message = self._extract_final_answer_from_tool_calls(memory=memory)
@@ -294,7 +300,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message
 
     async def _generate_final_answer_stream(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT], call_id: str
     ) -> AsyncIterator[Event[Any]]:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!",
@@ -310,7 +316,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
         tool_choice = NamedToolChoice(name=self._final_answer_tool.name)
         async for event in self.generate_message_stream(
-            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+            memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
         ):
             yield event
 
@@ -322,7 +328,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )
 
     async def execute(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT], call_id: str
     ) -> AssistantMessage | Sequence[AssistantMessage]:
         # 1. Generate the first message:
         # In ReAct mode, we generate the first message without tool calls
@@ -332,7 +338,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         if self.tools:
             tool_choice = "none" if self._react_mode else "auto"
             gen_message = await self.generate_message(
-                memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+                memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
             )
         if not self.tools:
             return gen_message
@@ -345,7 +351,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
             # If a final answer is not provided via a tool call, we use
             # _terminate_tool_call_loop to determine whether to exit the loop.
             if not self._final_answer_as_tool_call and self._terminate_tool_call_loop(
-                memory.message_history, ctx=ctx, num_turns=turns
+                memory.message_history, ctx=ctx, call_id=call_id, num_turns=turns
            ):
                 return gen_message
 
@@ -364,7 +370,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
             # Otherwise, we simply return the last generated message.
             if self._final_answer_as_tool_call:
                 final_answer = await self._generate_final_answer(
-                    memory, call_id=call_id, ctx=ctx
+                    memory, ctx=ctx, call_id=call_id
                 )
             else:
                 final_answer = gen_message
@@ -377,11 +383,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
             if gen_message.tool_calls:
                 await self.call_tools(
-                    gen_message.tool_calls, memory=memory, call_id=call_id, ctx=ctx
+                    gen_message.tool_calls, memory=memory, ctx=ctx, call_id=call_id
                 )
 
             # Apply memory management (e.g. compacting or pruning memory)
-            self._manage_memory(memory, ctx=ctx, num_turns=turns)
+            self._manage_memory(memory, ctx=ctx, call_id=call_id, num_turns=turns)
 
             # 4. Generate the next message based on the updated memory.
             # In ReAct mode, we set tool_choice to "none" if we just called tools,
@@ -396,7 +402,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
                 tool_choice = "required"
 
             gen_message = await self.generate_message(
-                memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+                memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
             )
 
             turns += 1
@@ -404,13 +410,13 @@ class LLMPolicyExecutor(Generic[CtxT]):
     async def execute_stream(
         self,
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AsyncIterator[Event[Any]]:
         tool_choice: ToolChoice = "none" if self._react_mode else "auto"
         gen_message: AssistantMessage | None = None
         async for event in self.generate_message_stream(
-            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+            memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
         ):
             if isinstance(event, GenMessageEvent):
                 gen_message = event.data
@@ -425,7 +431,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
         while True:
             if not self._final_answer_as_tool_call and self._terminate_tool_call_loop(
-                memory.message_history, ctx=ctx, num_turns=turns
+                memory.message_history, ctx=ctx, call_id=call_id, num_turns=turns
             ):
                 return
 
@@ -444,7 +450,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
             if turns >= self.max_turns:
                 if self._final_answer_as_tool_call:
                     async for event in self._generate_final_answer_stream(
-                        memory, call_id=call_id, ctx=ctx
+                        memory, ctx=ctx, call_id=call_id
                     ):
                         yield event
                     logger.info(
@@ -459,11 +465,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
             )
 
             async for event in self.call_tools_stream(
-                gen_message.tool_calls, memory=memory, call_id=call_id, ctx=ctx
+                gen_message.tool_calls, memory=memory, ctx=ctx, call_id=call_id
            ):
                 yield event
 
-            self._manage_memory(memory, ctx=ctx, num_turns=turns)
+            self._manage_memory(memory, ctx=ctx, call_id=call_id, num_turns=turns)
 
             if self._react_mode and gen_message.tool_calls:
                 tool_choice = "none"
@@ -473,7 +479,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
                 tool_choice = "required"
 
             async for event in self.generate_message_stream(
-                memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+                memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
             ):
                 yield event
                 if isinstance(event, GenMessageEvent):
@@ -493,8 +499,8 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         inp: BaseModel,
         *,
-        call_id: str | None = None,
         ctx: RunContext[Any] | None = None,
+        call_id: str | None = None,
     ) -> None:
         return None
 
@@ -504,9 +510,9 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         completion: Completion,
         *,
-        call_id: str,
         print_messages: bool = False,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> None:
         ctx.completions[self.agent_name].append(completion)
         ctx.usage_tracker.update(
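All LLMPolicyExecutor entry points now take ctx before call_id and forward call_id into the tool-call loop terminator, the memory manager, and the tools themselves. A sketch of the two hook callbacks under the updated protocols (function names and bodies are illustrative; import paths inferred from the wheel's layout):

from typing import Any

from grasp_agents.llm_agent_memory import LLMAgentMemory
from grasp_agents.run_context import RunContext
from grasp_agents.typing.message import Messages


def stop_after_five_turns(
    conversation: Messages,
    *,
    ctx: RunContext[Any],
    call_id: str,
    **kwargs: Any,
) -> bool:
    # The executor passes num_turns through **kwargs (see _terminate_tool_call_loop).
    return kwargs.get("num_turns", 0) >= 5


def prune_memory(
    memory: LLMAgentMemory,
    *,
    ctx: RunContext[Any],
    call_id: str,
    **kwargs: Any,
) -> None:
    # Compact or prune memory.message_history here; the executor calls this
    # once per loop iteration, with num_turns available in **kwargs.
    ...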
grasp_agents/openai/openai_llm.py
CHANGED
@@ -60,7 +60,7 @@ def get_openai_compatible_providers() -> list[APIProvider]:
 
 
 class OpenAILLMSettings(CloudLLMSettings, total=False):
-    reasoning_effort: Literal["low", "medium", "high"] | None
+    reasoning_effort: Literal["disable", "minimal", "low", "medium", "high"] | None
 
     parallel_tool_calls: bool
 
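OpenAILLMSettings widens reasoning_effort with the newer "disable" and "minimal" values. Since the settings class is declared with total=False (a TypedDict), opting in is a plain dict literal; a sketch:

from grasp_agents.openai.openai_llm import OpenAILLMSettings

settings: OpenAILLMSettings = {
    "reasoning_effort": "minimal",  # newly accepted, alongside "disable"
    "parallel_tool_calls": False,
}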
grasp_agents/processors/base_processor.py
CHANGED
@@ -106,7 +106,7 @@ _OutT_contra = TypeVar("_OutT_contra", contravariant=True)
 
 class RecipientSelector(Protocol[_OutT_contra, CtxT]):
     def __call__(
-        self, output: _OutT_contra, ctx: RunContext[CtxT]
+        self, output: _OutT_contra, *, ctx: RunContext[CtxT]
     ) -> Sequence[ProcName] | None: ...
 
 
grasp_agents/processors/parallel_processor.py
CHANGED
@@ -114,7 +114,7 @@ class ParallelProcessor(
         ctx: RunContext[CtxT] | None = None,
     ) -> Packet[OutT]:
         call_id = self._generate_call_id(call_id)
-        ctx = RunContext[CtxT](state=None)
+        ctx = ctx or RunContext[CtxT](state=None)  # type: ignore
 
         val_in_args = self._validate_inputs(
             call_id=call_id,
@@ -223,7 +223,7 @@ class ParallelProcessor(
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
         call_id = self._generate_call_id(call_id)
-        ctx = RunContext[CtxT](state=None)
+        ctx = ctx or RunContext[CtxT](state=None)  # type: ignore
 
         val_in_args = self._validate_inputs(
             call_id=call_id,
grasp_agents/processors/processor.py
CHANGED
@@ -105,7 +105,7 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
         call_id: str | None = None,
         ctx: RunContext[CtxT] | None = None,
     ) -> Packet[OutT]:
-        ctx = RunContext[CtxT](state=None)
+        ctx = ctx or RunContext[CtxT](state=None)  # type: ignore
 
         val_in_args, memory, call_id = self._preprocess(
             chat_inputs=chat_inputs,
@@ -136,7 +136,7 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
         call_id: str | None = None,
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
-        ctx = RunContext[CtxT](state=None)
+        ctx = ctx or RunContext[CtxT](state=None)  # type: ignore
 
         val_in_args, memory, call_id = self._preprocess(
             chat_inputs=chat_inputs,
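In both Processor and ParallelProcessor, run() and its streaming counterpart previously overwrote a caller-supplied RunContext with a fresh one; the fix keeps an explicit ctx when given, so completions and usage tracked across calls accumulate in the caller's context. A sketch, assuming some already-constructed processor proc:

from grasp_agents.run_context import RunContext


async def main(proc) -> None:  # proc: a Processor instance (construction elided)
    ctx = RunContext[None](state=None)
    # Before 0.5.11 each call silently replaced ctx; now the same object is
    # threaded through both runs, accumulating usage and completions.
    await proc.run(chat_inputs="Hello", ctx=ctx)
    await proc.run(chat_inputs="And again", ctx=ctx)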
grasp_agents/prompt_builder.py
CHANGED
@@ -15,11 +15,22 @@ _InT_contra = TypeVar("_InT_contra", contravariant=True)
 
 
 class SystemPromptBuilder(Protocol[CtxT]):
-    def __call__(self, *, ctx: RunContext[CtxT]) -> str | None: ...
+    def __call__(
+        self,
+        *,
+        ctx: RunContext[CtxT],
+        call_id: str,
+    ) -> str | None: ...
 
 
 class InputContentBuilder(Protocol[_InT_contra, CtxT]):
-    def __call__(self, in_args: _InT_contra, *, ctx: RunContext[CtxT]) -> Content: ...
+    def __call__(
+        self,
+        in_args: _InT_contra,
+        *,
+        ctx: RunContext[CtxT],
+        call_id: str,
+    ) -> Content: ...
 
 
 PromptArgumentType: TypeAlias = str | bool | int | ImageData
@@ -43,9 +54,9 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
         self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
 
     @final
-    def build_system_prompt(self, ctx: RunContext[CtxT]) -> str | None:
+    def build_system_prompt(self, ctx: RunContext[CtxT], call_id: str) -> str | None:
         if self.system_prompt_builder:
-            return self.system_prompt_builder(ctx=ctx)
+            return self.system_prompt_builder(ctx=ctx, call_id=call_id)
 
         return self.sys_prompt
 
@@ -71,7 +82,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
 
     @final
     def _build_input_content(
-        self, in_args: InT | None, ctx: RunContext[CtxT]
+        self, in_args: InT | None, ctx: RunContext[CtxT], call_id: str
     ) -> Content:
         if in_args is None and self._in_type is not type(None):
             raise InputPromptBuilderError(
@@ -83,7 +94,9 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
 
         val_in_args = self._validate_input_args(in_args)
         if self.input_content_builder:
-            return self.input_content_builder(in_args=val_in_args, ctx=ctx)
+            return self.input_content_builder(
+                in_args=val_in_args, ctx=ctx, call_id=call_id
+            )
 
         if issubclass(self._in_type, BaseModel) and isinstance(val_in_args, BaseModel):
             val_in_args_map = self._format_pydantic_prompt_args(val_in_args)
@@ -102,6 +115,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         *,
         in_args: InT | None = None,
+        call_id: str,
         ctx: RunContext[CtxT],
     ) -> UserMessage | None:
         if chat_inputs is not None:
@@ -116,7 +130,9 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
             return UserMessage.from_content_parts(chat_inputs, name=self._agent_name)
 
         return UserMessage(
-            content=self._build_input_content(in_args=in_args, ctx=ctx),
+            content=self._build_input_content(
+                in_args=in_args, ctx=ctx, call_id=call_id
+            ),
             name=self._agent_name,
         )
 
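The same call_id threading applies to user-supplied prompt builders. A sketch of the two callables under the updated SystemPromptBuilder and InputContentBuilder protocols (the Question model and builder bodies are illustrative; a plain string is assumed to be acceptable Content):

from typing import Any

from pydantic import BaseModel

from grasp_agents.run_context import RunContext


class Question(BaseModel):
    text: str


def my_system_prompt_builder(*, ctx: RunContext[Any], call_id: str) -> str | None:
    return f"You are a helpful assistant (run {call_id})."


def my_input_content_builder(
    in_args: Question, *, ctx: RunContext[Any], call_id: str
) -> str:
    return in_args.text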
grasp_agents/typing/tool.py
CHANGED
@@ -64,20 +64,22 @@ class BaseTool(
         self,
         inp: _InT,
         *,
-        call_id: str | None = None,
         ctx: RunContext[CtxT] | None = None,
+        call_id: str | None = None,
     ) -> _OutT_co:
         pass
 
     async def __call__(
         self,
         *,
-        call_id: str | None = None,
         ctx: RunContext[CtxT] | None = None,
+        call_id: str | None = None,
         **kwargs: Any,
     ) -> _OutT_co:
+        # NOTE: validation is probably redundant here when tool inputs have been
+        # validated by the LLM already
         input_args = TypeAdapter(self._in_type).validate_python(kwargs)
-        output = await self.run(input_args, call_id=call_id, ctx=ctx)
+        output = await self.run(input_args, ctx=ctx, call_id=call_id)
 
         return TypeAdapter(self._out_type).validate_python(output)
 
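BaseTool.run and BaseTool.__call__ keep both parameters optional but reorder them, and __call__ now forwards them as run(input_args, ctx=ctx, call_id=call_id). A sketch of a conforming tool, mirroring the AskStudentTool example from the package README (tool name and fields illustrative):

from typing import Any

from pydantic import BaseModel

from grasp_agents.run_context import RunContext
from grasp_agents.typing.tool import BaseTool


class AddArgs(BaseModel):
    a: int
    b: int


class AddTool(BaseTool[AddArgs, int, Any]):
    name: str = "add"
    description: str = "Add two integers."

    async def run(
        self,
        inp: AddArgs,
        *,
        ctx: RunContext[Any] | None = None,
        call_id: str | None = None,
    ) -> int:
        # Inputs arrive already validated against AddArgs by __call__.
        return inp.a + inp.b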
{grasp_agents-0.5.10.dist-info → grasp_agents-0.5.11.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: grasp_agents
-Version: 0.5.10
+Version: 0.5.11
 Summary: Grasp Agents Library
 License-File: LICENSE.md
 Requires-Python: <4,>=3.11.4
@@ -166,9 +166,7 @@ class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, Any]):
     name: str = "ask_student"
     description: str = ask_student_tool_description
 
-    async def run(
-        self, inp: TeacherQuestion, ctx: RunContext[Any] | None = None
-    ) -> StudentReply:
+    async def run(self, inp: TeacherQuestion, **kwargs: Any) -> StudentReply:
         return input(inp.question)
 
 
@@ -180,7 +178,8 @@ teacher = LLMAgent[None, Problem, None](
     name="teacher",
     llm=LiteLLM(
         model_name="gpt-4.1",
-
+        # model_name="claude-sonnet-4-20250514",
+        # llm_settings=LiteLLMSettings(reasoning_effort="low"),
     ),
     tools=[AskStudentTool()],
     react_mode=True,
{grasp_agents-0.5.10.dist-info → grasp_agents-0.5.11.dist-info}/RECORD
CHANGED
@@ -1,19 +1,19 @@
 grasp_agents/__init__.py,sha256=Z3a_j2Etiap9H6lvE8-PQP_OIGMUcHNPeJAJO12B8kY,1031
-grasp_agents/cloud_llm.py,sha256=
+grasp_agents/cloud_llm.py,sha256=vwI6gpLOsFqN4KtaTOo75xw8t7uRtdVrYGjopEDmQBw,13091
 grasp_agents/costs_dict.yaml,sha256=2MFNWtkv5W5WSCcv1Cj13B1iQLVv5Ot9pS_KW2Gu2DA,2510
 grasp_agents/errors.py,sha256=K-22TCM1Klhsej47Rg5eTqnGiGPaXgKOpdOZZ7cPipw,4633
 grasp_agents/generics_utils.py,sha256=5Pw3I9dlnKC2VGqYKC4ZZUO3Z_vTNT-NPFovNfPkl6I,6542
 grasp_agents/grasp_logging.py,sha256=H1GYhXdQvVkmauFDZ-KDwvVmPQHZUUm9sRqX_ObK2xI,1111
 grasp_agents/http_client.py,sha256=Es8NXGDkp4Nem7g24-jW0KFGA9Hp_o2Cv3cOvjup-iU,859
 grasp_agents/llm.py,sha256=IeV2QpR4AldVP3THzSETEnsaDx3DYz5HM6dkikSpy4o,10684
-grasp_agents/llm_agent.py,sha256=
-grasp_agents/llm_agent_memory.py,sha256=
-grasp_agents/llm_policy_executor.py,sha256=
+grasp_agents/llm_agent.py,sha256=F_ou0pfdztqZzd2yU1jZZZVzcyhsLXfE_i0c4y2fZIQ,14123
+grasp_agents/llm_agent_memory.py,sha256=XmOT2G8RG5AHd0LR3WuK7VbD-KFFfThmJnuZK2iU3Fs,1856
+grasp_agents/llm_policy_executor.py,sha256=r0UxwjnVzTBQqLlwvZZ_JL0wl6ZebCgxkcz6I4GdmrM,18136
 grasp_agents/memory.py,sha256=keHuNEZNSxHT9FKpMohHOCNi7UAz_oRIc91IQEuzaWE,1162
 grasp_agents/packet.py,sha256=EmE-W4ZSMVZoqClECGFe7OGqrT4FSJ8IVGICrdjtdEY,1462
 grasp_agents/packet_pool.py,sha256=AF7ZMYY1U6ppNLEn6o0R8QXyWmcLQGcju7_TYQpAudg,4443
 grasp_agents/printer.py,sha256=wVNCaR9mbFKyzYdT8YpYD1JQqRqHdLtdfiZrwYxaM6Y,11132
-grasp_agents/prompt_builder.py,sha256=
+grasp_agents/prompt_builder.py,sha256=wNPphkW8RL8501jV4Z7ncsN_sxBDR9Ax7eILLHr-OYg,6110
 grasp_agents/run_context.py,sha256=7qVs0T5rLvINmtlXqOoyy2Hu9xPzuFDbcVR6R93NF-0,951
 grasp_agents/runner.py,sha256=JL2wSKahbPYVd56NRB09cwco43sjhZPI4XYFCZyOXOA,5173
 grasp_agents/usage_tracker.py,sha256=ZQfVUUpG0C89hyPWT_JgXnjQOxoYmumcQ9t-aCfcMo8,3561
@@ -30,11 +30,11 @@ grasp_agents/openai/completion_converters.py,sha256=UlDeQSl0AEFUS-QI5e8rrjfmXZoj
 grasp_agents/openai/content_converters.py,sha256=sMsZhoatuL_8t0IdVaGWIVZLB4nyi1ajD61GewQmeY4,2503
 grasp_agents/openai/converters.py,sha256=RKOfMbIJmfFQ7ot0RGR6wrdMbR6_L7PB0UZwxwgM88g,4691
 grasp_agents/openai/message_converters.py,sha256=fhSN81uK51EGbLyM2-f0MvPX_UBrMy7SF3JQPo-dkXg,4686
-grasp_agents/openai/openai_llm.py,sha256=
+grasp_agents/openai/openai_llm.py,sha256=QjxrZ4fM_FX3ncBjehUjWPCCiI62u_W2XDi7nth1WrY,9737
 grasp_agents/openai/tool_converters.py,sha256=rNH5t2Wir9nuy8Ei0jaxNuzDaXGqTLmLz3VyrnJhyn0,1196
-grasp_agents/processors/base_processor.py,sha256=
-grasp_agents/processors/parallel_processor.py,sha256=
-grasp_agents/processors/processor.py,sha256=
+grasp_agents/processors/base_processor.py,sha256=BQ2k8dJY0jTMmidXZdK7JLO2YIQkmkp5boF1fT1o6uQ,10838
+grasp_agents/processors/parallel_processor.py,sha256=BOXRlPaZ-hooz0hHctqiW_5ldR-yDPYjFxuP7fAbZCI,7911
+grasp_agents/processors/processor.py,sha256=35MtYKrKtCZZMhV-U1DXBXtCNbCvZGaiiXo_5a3tI6s,5249
 grasp_agents/rate_limiting/__init__.py,sha256=KRgtF_E7R3YfA2cpYcFcZ7wycV0pWVJ0xRQC7YhiIEQ,158
 grasp_agents/rate_limiting/rate_limiter_chunked.py,sha256=BPgkUXvhmZhTpZs2T6uujNFuxH_kYHiISuf6_-eNhUc,5544
 grasp_agents/rate_limiting/types.py,sha256=PbnNhEAcYedQdIpPJWud8HUVcxa_xZS2RDZu4c5jr40,1003
@@ -47,12 +47,12 @@ grasp_agents/typing/converters.py,sha256=VrsqjuC_1IMj9rTOAMPBJ1N0hHY3Z9fx0zySo4Z
 grasp_agents/typing/events.py,sha256=vFq6qRGofY8NuxOG9ZIN2_CnhAqsAodYLD4b4KtAq2U,12620
 grasp_agents/typing/io.py,sha256=MGEoUjAwKH1AHYglFkKNpHiielw-NFf13Epg3B4Q7Iw,139
 grasp_agents/typing/message.py,sha256=o7bN84AgrC5Fm3Wx20gqL9ArAMcEtYvnHnXbb04ngCs,3224
-grasp_agents/typing/tool.py,sha256=
+grasp_agents/typing/tool.py,sha256=qwC5baRratcyJWLMQ923IMGHH1hmj9eUtYLnNBcbwUU,2033
 grasp_agents/workflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 grasp_agents/workflow/looped_workflow.py,sha256=WHp9O3Za2sBVfY_BLOdvPvtY20XsjZQaWSO2-oAFvOY,6806
 grasp_agents/workflow/sequential_workflow.py,sha256=e3BIWzy_2novmEWNwIteyMbrzvl1-evHrTBE3r3SpU8,3648
 grasp_agents/workflow/workflow_processor.py,sha256=DwHz70UOTp9dkbtzH9KE5LkGcT1RdHV7Hdiby0Bu9tw,3535
-grasp_agents-0.5.10.dist-info/METADATA,sha256=
-grasp_agents-0.5.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-grasp_agents-0.5.10.dist-info/licenses/LICENSE.md,sha256=-nNNdWqGB8gJ2O-peFQ2Irshv5tW5pHKyTcYkwvH7CE,1201
-grasp_agents-0.5.10.dist-info/RECORD,,
+grasp_agents-0.5.11.dist-info/METADATA,sha256=BkVyEN63RzGsCIJCnm5S38EI2ua9NcbPmr3lRCmWPGs,7021
+grasp_agents-0.5.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+grasp_agents-0.5.11.dist-info/licenses/LICENSE.md,sha256=-nNNdWqGB8gJ2O-peFQ2Irshv5tW5pHKyTcYkwvH7CE,1201
+grasp_agents-0.5.11.dist-info/RECORD,,
{grasp_agents-0.5.10.dist-info → grasp_agents-0.5.11.dist-info}/WHEEL
File without changes
{grasp_agents-0.5.10.dist-info → grasp_agents-0.5.11.dist-info}/licenses/LICENSE.md
File without changes