grasp_agents 0.5.10__py3-none-any.whl → 0.5.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/__init__.py CHANGED
@@ -6,6 +6,7 @@ from .llm_agent import LLMAgent
 from .llm_agent_memory import LLMAgentMemory
 from .memory import Memory
 from .packet import Packet
+from .printer import Printer, print_event_stream
 from .processors.base_processor import BaseProcessor
 from .processors.parallel_processor import ParallelProcessor
 from .processors.processor import Processor
@@ -33,9 +34,11 @@ __all__ = [
     "Packet",
     "Packet",
     "ParallelProcessor",
+    "Printer",
     "ProcName",
     "Processor",
     "RunContext",
     "SystemMessage",
     "UserMessage",
+    "print_event_stream",
 ]
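Version 0.5.12 promotes the printer utilities to the public API. A minimal usage sketch follows; only the import line is confirmed by this diff, while the run_stream() method name and the exact signature of print_event_stream are assumptions:

    import asyncio

    from grasp_agents import Printer, print_event_stream  # new public exports

    async def main(agent, ctx) -> None:
        # Assumption: print_event_stream consumes the agent's async event
        # stream (SystemMessageEvent, UserMessageEvent, ... as seen below).
        await print_event_stream(agent.run_stream("Hello", ctx=ctx))

    # asyncio.run(main(agent, ctx))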
grasp_agents/cloud_llm.py CHANGED
@@ -61,7 +61,6 @@ LLMRateLimiter = RateLimiterC[
 
 @dataclass(frozen=True)
 class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
-    # Make this field keyword-only to avoid ordering issues with inherited defaulted fields
     api_provider: APIProvider | None = None
     llm_settings: SettingsT_co | None = None
     rate_limiter: LLMRateLimiter | None = None
@@ -70,6 +69,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         0  # LLM response retries: try to regenerate to pass validation
     )
     apply_response_schema_via_provider: bool = False
+    apply_tool_call_schema_via_provider: bool = False
     async_http_client: httpx.AsyncClient | None = None
     async_http_client_params: dict[str, Any] | AsyncHTTPClientParams | None = None
 
@@ -80,6 +80,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                 f"{self.rate_limiter.rpm} RPM"
             )
 
+        if self.apply_response_schema_via_provider:
+            object.__setattr__(self, "apply_tool_call_schema_via_provider", True)
+
         if self.async_http_client is None and self.async_http_client_params is not None:
             object.__setattr__(
                 self,
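The new flag is coupled to the old one in __post_init__: requesting provider-side response schemas now also implies provider-side tool-call schemas. A runnable sketch of that coupling, distilled from the hunk above (the class name is illustrative):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class SchemaFlags:  # stand-in for the two CloudLLM fields
        apply_response_schema_via_provider: bool = False
        apply_tool_call_schema_via_provider: bool = False

        def __post_init__(self) -> None:
            # Frozen dataclasses forbid plain assignment, hence
            # object.__setattr__, exactly as in CloudLLM.__post_init__.
            if self.apply_response_schema_via_provider:
                object.__setattr__(self, "apply_tool_call_schema_via_provider", True)

    assert SchemaFlags(apply_response_schema_via_provider=True).apply_tool_call_schema_via_provider
    assert not SchemaFlags().apply_tool_call_schema_via_provider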
@@ -100,7 +103,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         api_tools = None
         api_tool_choice = None
         if tools:
-            strict = True if self.apply_response_schema_via_provider else None
+            strict = True if self.apply_tool_call_schema_via_provider else None
             api_tools = [
                 self.converters.to_tool(t, strict=strict) for t in tools.values()
             ]
@@ -175,8 +178,8 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             response_schema=response_schema,
             response_schema_by_xml_tag=response_schema_by_xml_tag,
         )
-        if tools is not None:
-            self._validate_tool_calls(completion, tools=tools)
+        if not self.apply_tool_call_schema_via_provider and tools is not None:
+            self._validate_tool_calls(completion, tools=tools)
 
         return completion
 
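Combined with the strict= change above, the net behavior is: when the provider enforces the tool-call schema, tools are converted with strict=True and the client-side re-validation is skipped. A small sketch of that guard (the function name is illustrative):

    def should_validate_client_side(
        apply_tool_call_schema_via_provider: bool, tools: dict | None
    ) -> bool:
        # Mirrors the guard used in both the regular and streaming paths.
        return not apply_tool_call_schema_via_provider and tools is not None

    assert should_validate_client_side(False, {"search": ...})      # validate locally
    assert not should_validate_client_side(True, {"search": ...})   # provider enforces it
    assert not should_validate_client_side(False, None)             # no tools at all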
@@ -208,17 +211,16 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
 
         if n_attempt > self.max_response_retries:
             if n_attempt == 1:
-                logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                logger.warning(f"\nCloudLLM completion failed:\n{err}")
             if n_attempt > 1:
                 logger.warning(
-                    f"\nCloudLLM completion request failed after retrying:\n{err}"
+                    f"\nCloudLLM completion failed after retrying:\n{err}"
                 )
             raise err
             # return make_refusal_completion(self._model_name, err)
 
         logger.warning(
-            f"\nCloudLLM completion request failed (retry attempt {n_attempt}):"
-            f"\n{err}"
+            f"\nCloudLLM completion failed (retry attempt {n_attempt}):\n{err}"
         )
 
         return make_refusal_completion(
@@ -282,8 +284,8 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             response_schema=response_schema,
             response_schema_by_xml_tag=response_schema_by_xml_tag,
         )
-        if tools is not None:
-            self._validate_tool_calls(completion, tools=tools)
+        if not self.apply_tool_call_schema_via_provider and tools is not None:
+            self._validate_tool_calls(completion, tools=tools)
 
         return iterator()
 
@@ -327,11 +329,10 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             n_attempt += 1
             if n_attempt > self.max_response_retries:
                 if n_attempt == 1:
-                    logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                    logger.warning(f"\nCloudLLM completion failed:\n{err}")
                 if n_attempt > 1:
                     logger.warning(
-                        "\nCloudLLM completion request failed after "
-                        f"retrying:\n{err}"
+                        f"\nCloudLLM completion failed after retrying:\n{err}"
                     )
                 refusal_completion = make_refusal_completion(
                     self.model_name, err
@@ -345,6 +346,5 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             # return
 
             logger.warning(
-                "\nCloudLLM completion request failed "
-                f"(retry attempt {n_attempt}):\n{err}"
+                f"\nCloudLLM completion failed (retry attempt {n_attempt}):\n{err}"
             )
@@ -159,7 +159,7 @@ class AutoInstanceAttributesMixin:
             attr_type = resolved_attr_types[attr_name]
             # attr_type = None if _attr_type is type(None) else _attr_type
         else:
-            attr_type = Any
+            attr_type = object
 
         if attr_name in pyd_private:
             pyd_private[attr_name] = attr_type
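The fallback attribute type changes from Any to object. One plausible motivation, not stated in the diff: typing.Any is a special form rather than a runtime class, so it fails wherever a real class is required, whereas object is a genuine base class that accepts every value. A quick illustration:

    from typing import Any

    print(isinstance("x", object))  # True: object is a real base class
    try:
        isinstance("x", Any)  # raises: Any cannot be used with isinstance()
    except TypeError as err:
        print(err)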
@@ -149,6 +149,9 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
         n_choices: int | None = None,
         **api_llm_settings: Any,
     ) -> LiteLLMCompletion:
+        if api_llm_settings and api_llm_settings.get("stream_options"):
+            api_llm_settings.pop("stream_options")
+
         completion = await litellm.acompletion(  # type: ignore[no-untyped-call]
             model=self.model_name,
             messages=api_messages,
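The non-streaming LiteLLM path now strips stream_options before forwarding settings, since OpenAI-compatible APIs typically reject stream_options unless stream=True is also set. The same logic in isolation:

    # Drop a streaming-only option from settings forwarded to a
    # non-streaming litellm.acompletion() call.
    api_llm_settings = {"temperature": 0.2, "stream_options": {"include_usage": True}}
    if api_llm_settings and api_llm_settings.get("stream_options"):
        api_llm_settings.pop("stream_options")
    assert "stream_options" not in api_llm_settings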
grasp_agents/llm_agent.py CHANGED
@@ -42,6 +42,7 @@ class OutputParser(Protocol[_InT_contra, _OutT_co, CtxT]):
         *,
         in_args: _InT_contra | None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> _OutT_co: ...
 
 
@@ -169,10 +170,15 @@ class LLMAgent(
         in_args: InT | None = None,
         sys_prompt: LLMPrompt | None = None,
         ctx: RunContext[Any],
+        call_id: str,
     ) -> None:
         if self.memory_preparator:
             return self.memory_preparator(
-                memory=memory, in_args=in_args, sys_prompt=sys_prompt, ctx=ctx
+                memory=memory,
+                in_args=in_args,
+                sys_prompt=sys_prompt,
+                ctx=ctx,
+                call_id=call_id,
             )
 
     def _memorize_inputs(
@@ -182,8 +188,11 @@
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         in_args: InT | None = None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> tuple[SystemMessage | None, UserMessage | None]:
-        formatted_sys_prompt = self._prompt_builder.build_system_prompt(ctx=ctx)
+        formatted_sys_prompt = self._prompt_builder.build_system_prompt(
+            ctx=ctx, call_id=call_id
+        )
 
         system_message: SystemMessage | None = None
         if self._reset_memory_on_run or memory.is_empty:
@@ -192,24 +201,22 @@
             system_message = cast("SystemMessage", memory.message_history[0])
         else:
             self._prepare_memory(
-                memory=memory, in_args=in_args, sys_prompt=formatted_sys_prompt, ctx=ctx
+                memory=memory,
+                in_args=in_args,
+                sys_prompt=formatted_sys_prompt,
+                ctx=ctx,
+                call_id=call_id,
             )
 
         input_message = self._prompt_builder.build_input_message(
-            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
+            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx, call_id=call_id
         )
         if input_message:
             memory.update([input_message])
 
         return system_message, input_message
 
-    def _parse_output_default(
-        self,
-        conversation: Messages,
-        *,
-        in_args: InT | None = None,
-        ctx: RunContext[CtxT],
-    ) -> OutT:
+    def parse_output_default(self, conversation: Messages) -> OutT:
         return validate_obj_from_json_or_py_string(
             str(conversation[-1].content or ""),
             schema=self._out_type,
@@ -223,15 +230,14 @@
         *,
         in_args: InT | None = None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> OutT:
         if self.output_parser:
             return self.output_parser(
-                conversation=conversation, in_args=in_args, ctx=ctx
+                conversation=conversation, in_args=in_args, ctx=ctx, call_id=call_id
             )
 
-        return self._parse_output_default(
-            conversation=conversation, in_args=in_args, ctx=ctx
-        )
+        return self.parse_output_default(conversation)
 
     async def _process(
         self,
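Because call_id is now threaded through the whole parsing path, custom parsers must accept it as a required keyword. A sketch of a callable conforming to the updated OutputParser protocol (the body, the output type, and the wiring comment are illustrative):

    from typing import Any

    def my_output_parser(
        conversation,  # Messages
        *,
        in_args: Any | None,
        ctx,  # RunContext[Any]
        call_id: str,  # new required keyword in 0.5.12
    ) -> str:
        # Illustrative: tag the final message content with the originating call.
        return f"[{call_id}] {conversation[-1].content}"

    # agent = LLMAgent(..., output_parser=my_output_parser)  # assumed wiring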
@@ -239,24 +245,28 @@
         *,
         in_args: InT | None = None,
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> OutT:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
             chat_inputs=chat_inputs,
             in_args=in_args,
             ctx=ctx,
+            call_id=call_id,
         )
         if system_message:
-            self._print_messages([system_message], call_id=call_id, ctx=ctx)
+            self._print_messages([system_message], ctx=ctx, call_id=call_id)
         if input_message:
-            self._print_messages([input_message], call_id=call_id, ctx=ctx)
+            self._print_messages([input_message], ctx=ctx, call_id=call_id)
 
-        await self._policy_executor.execute(memory, call_id=call_id, ctx=ctx)
+        await self._policy_executor.execute(memory, ctx=ctx, call_id=call_id)
 
         return self._parse_output(
-            conversation=memory.message_history, in_args=in_args, ctx=ctx
+            conversation=memory.message_history,
+            in_args=in_args,
+            ctx=ctx,
+            call_id=call_id,
         )
 
     async def _process_stream(
@@ -265,43 +275,44 @@
         *,
         in_args: InT | None = None,
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AsyncIterator[Event[Any]]:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
             chat_inputs=chat_inputs,
             in_args=in_args,
             ctx=ctx,
+            call_id=call_id,
         )
         if system_message:
-            self._print_messages([system_message], call_id=call_id, ctx=ctx)
+            self._print_messages([system_message], ctx=ctx, call_id=call_id)
             yield SystemMessageEvent(
                 data=system_message, proc_name=self.name, call_id=call_id
             )
         if input_message:
-            self._print_messages([input_message], call_id=call_id, ctx=ctx)
+            self._print_messages([input_message], ctx=ctx, call_id=call_id)
             yield UserMessageEvent(
                 data=input_message, proc_name=self.name, call_id=call_id
             )
 
         async for event in self._policy_executor.execute_stream(
-            memory, call_id=call_id, ctx=ctx
+            memory, ctx=ctx, call_id=call_id
         ):
             yield event
 
         output = self._parse_output(
-            conversation=memory.message_history, in_args=in_args, ctx=ctx
+            conversation=memory.message_history,
+            in_args=in_args,
+            ctx=ctx,
+            call_id=call_id,
        )
         yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
 
     def _print_messages(
-        self,
-        messages: Sequence[Message],
-        call_id: str,
-        ctx: RunContext[CtxT],
+        self, messages: Sequence[Message], ctx: RunContext[CtxT], call_id: str
     ) -> None:
-        if ctx and ctx.printer:
+        if ctx.printer:
             ctx.printer.print_messages(messages, agent_name=self.name, call_id=call_id)
 
     # -- Override these methods in subclasses if needed --
@@ -328,31 +339,45 @@ class LLMAgent(
         if cur_cls.memory_manager is not base_cls.memory_manager:
             self._policy_executor.memory_manager = self.memory_manager
 
-    def system_prompt_builder(self, ctx: RunContext[CtxT]) -> str | None:
+    def system_prompt_builder(self, ctx: RunContext[CtxT], call_id: str) -> str | None:
         if self._prompt_builder.system_prompt_builder is not None:
-            return self._prompt_builder.system_prompt_builder(ctx=ctx)
+            return self._prompt_builder.system_prompt_builder(ctx=ctx, call_id=call_id)
         raise NotImplementedError("System prompt builder is not implemented.")
 
-    def input_content_builder(self, in_args: InT, ctx: RunContext[CtxT]) -> Content:
+    def input_content_builder(
+        self, in_args: InT, ctx: RunContext[CtxT], call_id: str
+    ) -> Content:
         if self._prompt_builder.input_content_builder is not None:
-            return self._prompt_builder.input_content_builder(in_args=in_args, ctx=ctx)
+            return self._prompt_builder.input_content_builder(
+                in_args=in_args, ctx=ctx, call_id=call_id
+            )
         raise NotImplementedError("Input content builder is not implemented.")
 
     def tool_call_loop_terminator(
-        self, conversation: Messages, *, ctx: RunContext[CtxT], **kwargs: Any
+        self,
+        conversation: Messages,
+        *,
+        ctx: RunContext[CtxT],
+        call_id: str,
+        **kwargs: Any,
     ) -> bool:
         if self._policy_executor.tool_call_loop_terminator is not None:
             return self._policy_executor.tool_call_loop_terminator(
-                conversation=conversation, ctx=ctx, **kwargs
+                conversation=conversation, ctx=ctx, call_id=call_id, **kwargs
             )
         raise NotImplementedError("Tool call loop terminator is not implemented.")
 
     def memory_manager(
-        self, memory: LLMAgentMemory, *, ctx: RunContext[CtxT], **kwargs: Any
+        self,
+        memory: LLMAgentMemory,
+        *,
+        ctx: RunContext[CtxT],
+        call_id: str,
+        **kwargs: Any,
     ) -> None:
         if self._policy_executor.memory_manager is not None:
             return self._policy_executor.memory_manager(
-                memory=memory, ctx=ctx, **kwargs
+                memory=memory, ctx=ctx, call_id=call_id, **kwargs
             )
         raise NotImplementedError("Memory manager is not implemented.")
 
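Every overridable hook gains the same required call_id keyword, so subclasses written against the 0.5.10 signatures need updating. A sketch of an override under the new signature (the type parameters and the termination rule are illustrative):

    from typing import Any

    class MyAgent(LLMAgent[Any, str, Any]):
        def tool_call_loop_terminator(
            self, conversation, *, ctx, call_id: str, **kwargs: Any
        ) -> bool:
            # Illustrative rule: stop the tool-call loop once the
            # conversation exceeds 20 messages.
            return len(conversation) > 20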
@@ -17,6 +17,7 @@ class MemoryPreparator(Protocol):
         in_args: Any | None,
         sys_prompt: LLMPrompt | None,
         ctx: RunContext[Any],
+        call_id: str,
     ) -> None: ...
 
 
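Likewise for MemoryPreparator: implementations must now accept call_id. A conforming sketch (parameter annotations abbreviated; the body is illustrative):

    def my_memory_preparator(
        memory,  # LLMAgentMemory
        *,
        in_args,  # Any | None
        sys_prompt,  # LLMPrompt | None
        ctx,  # RunContext[Any]
        call_id: str,  # new required keyword in 0.5.12
    ) -> None:
        # Illustrative: trace which call prepared this memory.
        print(f"preparing memory (call {call_id}, {len(memory.message_history)} messages)")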