grasp_agents 0.5.9__py3-none-any.whl → 0.5.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/llm.py CHANGED
@@ -1,7 +1,8 @@
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Mapping, Sequence
-from typing import Any, Generic, TypeVar, cast
+from collections.abc import AsyncIterator, Mapping
+from dataclasses import dataclass, field
+from typing import Any, Generic, TypeVar
 from uuid import uuid4
 
 from pydantic import BaseModel
@@ -66,71 +67,24 @@ SettingsT_co = TypeVar("SettingsT_co", bound=LLMSettings, covariant=True)
 ConvertT_co = TypeVar("ConvertT_co", bound=Converters, covariant=True)
 
 
+@dataclass(frozen=True)
 class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
-    @abstractmethod
-    def __init__(
+    model_name: str
+    converters: ConvertT_co
+    llm_settings: SettingsT_co | None = None
+    model_id: str = field(default_factory=lambda: str(uuid4())[:8])
+
+    def _validate_response(
         self,
-        converters: ConvertT_co,
-        model_name: str | None = None,
-        model_id: str | None = None,
-        llm_settings: SettingsT_co | None = None,
-        tools: Sequence[BaseTool[BaseModel, Any, Any]] | None = None,
-        response_schema: Any | None = None,
-        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
-        **kwargs: Any,
+        completion: Completion,
+        response_schema: Any | None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None,
     ) -> None:
-        super().__init__()
-
-        self._converters = converters
-        self._model_id = model_id or str(uuid4())[:8]
-        self._model_name = model_name
-        self._tools = {t.name: t for t in tools} if tools else None
-        self._llm_settings: SettingsT_co = llm_settings or cast("SettingsT_co", {})
-
         if response_schema and response_schema_by_xml_tag:
             raise ValueError(
                 "Only one of response_schema and response_schema_by_xml_tag can be "
                 "provided, but not both."
             )
-        self._response_schema = response_schema
-        self._response_schema_by_xml_tag = response_schema_by_xml_tag
-
-    @property
-    def model_id(self) -> str:
-        return self._model_id
-
-    @property
-    def model_name(self) -> str | None:
-        return self._model_name
-
-    @property
-    def llm_settings(self) -> SettingsT_co:
-        return self._llm_settings
-
-    @property
-    def response_schema(self) -> Any | None:
-        return self._response_schema
-
-    @response_schema.setter
-    def response_schema(self, response_schema: Any | None) -> None:
-        self._response_schema = response_schema
-
-    @property
-    def response_schema_by_xml_tag(self) -> Mapping[str, Any] | None:
-        return self._response_schema_by_xml_tag
-
-    @property
-    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
-        return self._tools
-
-    @tools.setter
-    def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
-        self._tools = {t.name: t for t in tools} if tools else None
-
-    def __repr__(self) -> str:
-        return f"{type(self).__name__}[{self.model_id}]; model_name={self._model_name})"
-
-    def _validate_response(self, completion: Completion) -> None:
         parsing_params = {
             "from_substring": False,
             "strip_language_markdown": True,
@@ -138,17 +92,17 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         try:
             for message in completion.messages:
                 if not message.tool_calls:
-                    if self._response_schema:
+                    if response_schema:
                         validate_obj_from_json_or_py_string(
                             message.content or "",
-                            schema=self._response_schema,
+                            schema=response_schema,
                             **parsing_params,
                         )
 
-                    elif self._response_schema_by_xml_tag:
+                    elif response_schema_by_xml_tag:
                         validate_tagged_objs_from_json_or_py_string(
                             message.content or "",
-                            schema_by_xml_tag=self._response_schema_by_xml_tag,
+                            schema_by_xml_tag=response_schema_by_xml_tag,
                             **parsing_params,
                         )
         except JSONSchemaValidationError as exc:
@@ -156,7 +110,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
                 exc.s, exc.schema, message=str(exc)
             ) from exc
 
-    def _validate_tool_calls(self, completion: Completion) -> None:
+    def _validate_tool_calls(
+        self, completion: Completion, tools: Mapping[str, BaseTool[BaseModel, Any, Any]]
+    ) -> None:
         parsing_params = {
             "from_substring": False,
             "strip_language_markdown": True,
@@ -167,15 +123,15 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
             tool_name = tool_call.tool_name
             tool_arguments = tool_call.tool_arguments
 
-            available_tool_names = list(self.tools) if self.tools else []
-            if tool_name not in available_tool_names or not self.tools:
+            available_tool_names = list(tools) if tools else []
+            if tool_name not in available_tool_names or not tools:
                 raise LLMToolCallValidationError(
                     tool_name,
                     tool_arguments,
                     message=f"Tool '{tool_name}' is not available in the LLM "
                     f"tools (available: {available_tool_names})",
                 )
-            tool = self.tools[tool_name]
+            tool = tools[tool_name]
             try:
                 validate_obj_from_json_or_py_string(
                     tool_arguments, schema=tool.in_type, **parsing_params
@@ -296,6 +252,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         self,
         conversation: Messages,
         *,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -308,6 +267,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         self,
         conversation: Messages,
         *,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -318,5 +280,10 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         pass
 
     @abstractmethod
-    def combine_completion_chunks(self, completion_chunks: list[Any]) -> Any:
+    def combine_completion_chunks(
+        self,
+        completion_chunks: list[Any],
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
+    ) -> Any:
         raise NotImplementedError
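
Note: taken together, the hunks above replace LLM's hand-written __init__ and property boilerplate with a frozen dataclass, and move tools, response_schema, and response_schema_by_xml_tag from instance state into per-call keyword arguments. A minimal, self-contained sketch of the frozen-dataclass pattern itself (ExampleLLM and "demo-model" are hypothetical names, not the package's API):

    from dataclasses import dataclass, field
    from uuid import uuid4


    @dataclass(frozen=True)
    class ExampleLLM:
        # Mirrors the new field layout in the diff: a required model name,
        # optional settings, and an auto-generated 8-character model_id.
        model_name: str
        llm_settings: dict | None = None
        model_id: str = field(default_factory=lambda: str(uuid4())[:8])


    llm = ExampleLLM(model_name="demo-model")
    print(llm.model_id)          # e.g. "3f2a9c1d": first 8 chars of a uuid4
    # llm.model_name = "other"   # raises dataclasses.FrozenInstanceError

Because instances are now immutable, state that was previously mutated through the old response_schema and tools setters instead travels with each call, which is what the new keyword parameters on the abstract completion methods and combine_completion_chunks provide.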
grasp_agents/llm_agent.py CHANGED
@@ -1,4 +1,4 @@
-from collections.abc import AsyncIterator, Sequence
+from collections.abc import AsyncIterator, Mapping, Sequence
 from pathlib import Path
 from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
 
@@ -41,7 +41,8 @@ class OutputParser(Protocol[_InT_contra, _OutT_co, CtxT]):
         conversation: Messages,
         *,
         in_args: _InT_contra | None,
-        ctx: RunContext[CtxT] | None,
+        ctx: RunContext[CtxT],
+        call_id: str,
     ) -> _OutT_co: ...
 
 
@@ -68,16 +69,19 @@ class LLMAgent(
         # System prompt template
         sys_prompt: LLMPrompt | None = None,
         sys_prompt_path: str | Path | None = None,
+        # LLM response validation
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         # Agent loop settings
         max_turns: int = 100,
         react_mode: bool = False,
         final_answer_as_tool_call: bool = False,
         # Agent memory management
         reset_memory_on_run: bool = False,
-        # Retries
+        # Agent run retries
         max_retries: int = 0,
         # Multi-agent routing
-        recipients: list[ProcName] | None = None,
+        recipients: Sequence[ProcName] | None = None,
     ) -> None:
         super().__init__(name=name, recipients=recipients, max_retries=max_retries)
 
@@ -96,15 +100,6 @@ class LLMAgent(
 
         # LLM policy executor
 
-        self._used_default_llm_response_schema: bool = False
-        if (
-            llm.response_schema is None
-            and tools is None
-            and not hasattr(type(self), "output_parser")
-        ):
-            llm.response_schema = self.out_type
-            self._used_default_llm_response_schema = True
-
         if issubclass(self._out_type, BaseModel):
             final_answer_type = self._out_type
         elif not final_answer_as_tool_call:
@@ -115,10 +110,21 @@ class LLMAgent(
                 "final_answer_as_tool_call is True."
             )
 
+        self._used_default_llm_response_schema: bool = False
+        if (
+            response_schema is None
+            and tools is None
+            and not hasattr(type(self), "output_parser")
+        ):
+            response_schema = self.out_type
+            self._used_default_llm_response_schema = True
+
         self._policy_executor: LLMPolicyExecutor[CtxT] = LLMPolicyExecutor[CtxT](
             agent_name=self.name,
             llm=llm,
             tools=tools,
+            response_schema=response_schema,
+            response_schema_by_xml_tag=response_schema_by_xml_tag,
             max_turns=max_turns,
             react_mode=react_mode,
             final_answer_type=final_answer_type,
@@ -160,23 +166,33 @@ class LLMAgent(
     def _prepare_memory(
         self,
         memory: LLMAgentMemory,
+        *,
         in_args: InT | None = None,
         sys_prompt: LLMPrompt | None = None,
-        ctx: RunContext[Any] | None = None,
+        ctx: RunContext[Any],
+        call_id: str,
     ) -> None:
         if self.memory_preparator:
             return self.memory_preparator(
-                memory=memory, in_args=in_args, sys_prompt=sys_prompt, ctx=ctx
+                memory=memory,
+                in_args=in_args,
+                sys_prompt=sys_prompt,
+                ctx=ctx,
+                call_id=call_id,
             )
 
     def _memorize_inputs(
         self,
         memory: LLMAgentMemory,
+        *,
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         in_args: InT | None = None,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
+        call_id: str,
     ) -> tuple[SystemMessage | None, UserMessage | None]:
-        formatted_sys_prompt = self._prompt_builder.build_system_prompt(ctx=ctx)
+        formatted_sys_prompt = self._prompt_builder.build_system_prompt(
+            ctx=ctx, call_id=call_id
+        )
 
         system_message: SystemMessage | None = None
         if self._reset_memory_on_run or memory.is_empty:
@@ -185,24 +201,22 @@ class LLMAgent(
             system_message = cast("SystemMessage", memory.message_history[0])
         else:
             self._prepare_memory(
-                memory=memory, in_args=in_args, sys_prompt=formatted_sys_prompt, ctx=ctx
+                memory=memory,
+                in_args=in_args,
+                sys_prompt=formatted_sys_prompt,
+                ctx=ctx,
+                call_id=call_id,
             )
 
         input_message = self._prompt_builder.build_input_message(
-            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
+            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx, call_id=call_id
         )
         if input_message:
             memory.update([input_message])
 
         return system_message, input_message
 
-    def _parse_output_default(
-        self,
-        conversation: Messages,
-        *,
-        in_args: InT | None = None,
-        ctx: RunContext[CtxT] | None = None,
-    ) -> OutT:
+    def parse_output_default(self, conversation: Messages) -> OutT:
         return validate_obj_from_json_or_py_string(
             str(conversation[-1].content or ""),
             schema=self._out_type,
@@ -215,16 +229,15 @@ class LLMAgent(
         conversation: Messages,
         *,
         in_args: InT | None = None,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
+        call_id: str,
     ) -> OutT:
         if self.output_parser:
             return self.output_parser(
-                conversation=conversation, in_args=in_args, ctx=ctx
+                conversation=conversation, in_args=in_args, ctx=ctx, call_id=call_id
             )
 
-        return self._parse_output_default(
-            conversation=conversation, in_args=in_args, ctx=ctx
-        )
+        return self.parse_output_default(conversation)
 
     async def _process(
         self,
@@ -232,24 +245,28 @@ class LLMAgent(
         *,
         in_args: InT | None = None,
         memory: LLMAgentMemory,
+        ctx: RunContext[CtxT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
     ) -> OutT:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
             chat_inputs=chat_inputs,
             in_args=in_args,
             ctx=ctx,
+            call_id=call_id,
         )
         if system_message:
-            self._print_messages([system_message], call_id=call_id, ctx=ctx)
+            self._print_messages([system_message], ctx=ctx, call_id=call_id)
         if input_message:
-            self._print_messages([input_message], call_id=call_id, ctx=ctx)
+            self._print_messages([input_message], ctx=ctx, call_id=call_id)
 
-        await self._policy_executor.execute(memory, call_id=call_id, ctx=ctx)
+        await self._policy_executor.execute(memory, ctx=ctx, call_id=call_id)
 
         return self._parse_output(
-            conversation=memory.message_history, in_args=in_args, ctx=ctx
+            conversation=memory.message_history,
+            in_args=in_args,
+            ctx=ctx,
+            call_id=call_id,
         )
 
     async def _process_stream(
@@ -258,41 +275,45 @@ class LLMAgent(
         *,
         in_args: InT | None = None,
         memory: LLMAgentMemory,
+        ctx: RunContext[CtxT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
             chat_inputs=chat_inputs,
             in_args=in_args,
             ctx=ctx,
+            call_id=call_id,
         )
         if system_message:
-            self._print_messages([system_message], call_id=call_id, ctx=ctx)
+            self._print_messages([system_message], ctx=ctx, call_id=call_id)
             yield SystemMessageEvent(
                 data=system_message, proc_name=self.name, call_id=call_id
             )
         if input_message:
-            self._print_messages([input_message], call_id=call_id, ctx=ctx)
+            self._print_messages([input_message], ctx=ctx, call_id=call_id)
             yield UserMessageEvent(
                 data=input_message, proc_name=self.name, call_id=call_id
             )
 
         async for event in self._policy_executor.execute_stream(
-            memory, call_id=call_id, ctx=ctx
+            memory, ctx=ctx, call_id=call_id
         ):
             yield event
 
         output = self._parse_output(
-            conversation=memory.message_history, in_args=in_args, ctx=ctx
+            conversation=memory.message_history,
+            in_args=in_args,
+            ctx=ctx,
+            call_id=call_id,
         )
         yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
 
     def _print_messages(
         self,
         messages: Sequence[Message],
+        ctx: RunContext[CtxT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
     ) -> None:
         if ctx and ctx.printer:
             ctx.printer.print_messages(messages, agent_name=self.name, call_id=call_id)
@@ -321,28 +342,31 @@ class LLMAgent(
         if cur_cls.memory_manager is not base_cls.memory_manager:
             self._policy_executor.memory_manager = self.memory_manager
 
-    def system_prompt_builder(self, ctx: RunContext[CtxT] | None = None) -> str | None:
+    def system_prompt_builder(self, ctx: RunContext[CtxT], call_id: str) -> str | None:
         if self._prompt_builder.system_prompt_builder is not None:
-            return self._prompt_builder.system_prompt_builder(ctx=ctx)
+            return self._prompt_builder.system_prompt_builder(ctx=ctx, call_id=call_id)
         raise NotImplementedError("System prompt builder is not implemented.")
 
     def input_content_builder(
-        self, in_args: InT | None = None, *, ctx: RunContext[CtxT] | None = None
+        self, in_args: InT, ctx: RunContext[CtxT], call_id: str
     ) -> Content:
         if self._prompt_builder.input_content_builder is not None:
-            return self._prompt_builder.input_content_builder(in_args=in_args, ctx=ctx)
+            return self._prompt_builder.input_content_builder(
+                in_args=in_args, ctx=ctx, call_id=call_id
+            )
         raise NotImplementedError("Input content builder is not implemented.")
 
     def tool_call_loop_terminator(
         self,
         conversation: Messages,
         *,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> bool:
         if self._policy_executor.tool_call_loop_terminator is not None:
             return self._policy_executor.tool_call_loop_terminator(
-                conversation=conversation, ctx=ctx, **kwargs
+                conversation=conversation, ctx=ctx, call_id=call_id, **kwargs
             )
         raise NotImplementedError("Tool call loop terminator is not implemented.")
 
@@ -350,12 +374,13 @@ class LLMAgent(
         self,
         memory: LLMAgentMemory,
         *,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> None:
         if self._policy_executor.memory_manager is not None:
             return self._policy_executor.memory_manager(
-                memory=memory, ctx=ctx, **kwargs
+                memory=memory, ctx=ctx, call_id=call_id, **kwargs
             )
         raise NotImplementedError("Memory manager is not implemented.")
 
@@ -391,12 +416,11 @@ class LLMAgent(
         self, func: OutputParser[InT, OutT, CtxT]
     ) -> OutputParser[InT, OutT, CtxT]:
         if self._used_default_llm_response_schema:
-            self._policy_executor.llm.response_schema = None
+            self._policy_executor.response_schema = None
         self.output_parser = func
 
         return func
 
     def add_memory_preparator(self, func: MemoryPreparator) -> MemoryPreparator:
         self.memory_preparator = func
-
         return func
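
Note: the recurring change in this file is that every user-supplied hook (output parser, prompt builders, memory preparator, tool-call loop terminator, memory manager) now receives a required RunContext plus the run's call_id, where the context was previously optional. A hedged sketch of a callback matching the new OutputParser protocol shown above (the Any annotations and the parsing logic are illustrative only, not package code):

    from typing import Any


    def parse_last_message(
        conversation: Any,  # Messages in the real protocol
        *,
        in_args: Any | None,
        ctx: Any,           # RunContext[CtxT]: required in 0.5.11, no longer optional
        call_id: str,       # new in 0.5.11: identifies the agent run
    ) -> str:
        # Mirrors parse_output_default: only the final message's content is used.
        return str(conversation[-1].content or "")

(The final hunk below is missing its file header in this extract; its content covers the MemoryPreparator protocol and LLMAgentMemory rather than llm_agent.py.)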
@@ -13,14 +13,16 @@ class MemoryPreparator(Protocol):
     def __call__(
         self,
         memory: "LLMAgentMemory",
+        *,
         in_args: Any | None,
         sys_prompt: LLMPrompt | None,
-        ctx: RunContext[Any] | None,
+        ctx: RunContext[Any],
+        call_id: str,
     ) -> None: ...
 
 
 class LLMAgentMemory(Memory):
-    _message_history: Messages = PrivateAttr(default_factory=list)  # type: ignore
+    _message_history: Messages = PrivateAttr(default_factory=Messages)
     _sys_prompt: LLMPrompt | None = PrivateAttr(default=None)
 
     def __init__(self, sys_prompt: LLMPrompt | None = None) -> None:
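
Note: several of the rewritten signatures also gain a bare *, which makes every parameter after it keyword-only, so call sites that passed ctx positionally now fail loudly instead of silently binding to the wrong parameter. A toy illustration of the mechanism (hypothetical function, not package code):

    def prepare(memory: list, *, in_args: str | None = None, call_id: str = "") -> None:
        # Everything after the bare * must be passed by keyword.
        memory.append((in_args, call_id))


    prepare([], in_args="x", call_id="abc123")  # OK
    # prepare([], "x", "abc123")                # TypeError: takes 1 positional argument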