grasp_agents 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
grasp_agents/llm.py CHANGED
@@ -1,7 +1,8 @@
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Mapping, Sequence
-from typing import Any, Generic, TypeVar, cast
+from collections.abc import AsyncIterator, Mapping
+from dataclasses import dataclass, field
+from typing import Any, Generic, TypeVar
 from uuid import uuid4
 
 from pydantic import BaseModel
@@ -25,7 +26,7 @@ from .typing.events import (
     AnnotationsEndEvent,
     AnnotationsStartEvent,
     CompletionChunkEvent,
-    CompletionEndEvent,
+    # CompletionEndEvent,
     CompletionEvent,
     CompletionStartEvent,
     LLMStateChangeEvent,
@@ -66,71 +67,24 @@ SettingsT_co = TypeVar("SettingsT_co", bound=LLMSettings, covariant=True)
 ConvertT_co = TypeVar("ConvertT_co", bound=Converters, covariant=True)
 
 
+@dataclass(frozen=True)
 class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
-    @abstractmethod
-    def __init__(
+    model_name: str
+    converters: ConvertT_co
+    llm_settings: SettingsT_co | None = None
+    model_id: str = field(default_factory=lambda: str(uuid4())[:8])
+
+    def _validate_response(
         self,
-        converters: ConvertT_co,
-        model_name: str | None = None,
-        model_id: str | None = None,
-        llm_settings: SettingsT_co | None = None,
-        tools: Sequence[BaseTool[BaseModel, Any, Any]] | None = None,
-        response_schema: Any | None = None,
-        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
-        **kwargs: Any,
+        completion: Completion,
+        response_schema: Any | None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None,
     ) -> None:
-        super().__init__()
-
-        self._converters = converters
-        self._model_id = model_id or str(uuid4())[:8]
-        self._model_name = model_name
-        self._tools = {t.name: t for t in tools} if tools else None
-        self._llm_settings: SettingsT_co = llm_settings or cast("SettingsT_co", {})
-
         if response_schema and response_schema_by_xml_tag:
             raise ValueError(
                 "Only one of response_schema and response_schema_by_xml_tag can be "
                 "provided, but not both."
             )
-        self._response_schema = response_schema
-        self._response_schema_by_xml_tag = response_schema_by_xml_tag
-
-    @property
-    def model_id(self) -> str:
-        return self._model_id
-
-    @property
-    def model_name(self) -> str | None:
-        return self._model_name
-
-    @property
-    def llm_settings(self) -> SettingsT_co:
-        return self._llm_settings
-
-    @property
-    def response_schema(self) -> Any | None:
-        return self._response_schema
-
-    @response_schema.setter
-    def response_schema(self, response_schema: Any | None) -> None:
-        self._response_schema = response_schema
-
-    @property
-    def response_schema_by_xml_tag(self) -> Mapping[str, Any] | None:
-        return self._response_schema_by_xml_tag
-
-    @property
-    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
-        return self._tools
-
-    @tools.setter
-    def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
-        self._tools = {t.name: t for t in tools} if tools else None
-
-    def __repr__(self) -> str:
-        return f"{type(self).__name__}[{self.model_id}]; model_name={self._model_name})"
-
-    def _validate_response(self, completion: Completion) -> None:
         parsing_params = {
             "from_substring": False,
             "strip_language_markdown": True,
@@ -138,17 +92,17 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         try:
             for message in completion.messages:
                 if not message.tool_calls:
-                    if self._response_schema:
+                    if response_schema:
                         validate_obj_from_json_or_py_string(
                             message.content or "",
-                            schema=self._response_schema,
+                            schema=response_schema,
                             **parsing_params,
                         )
-                    elif self._response_schema_by_xml_tag:
+                    elif response_schema_by_xml_tag:
                         validate_tagged_objs_from_json_or_py_string(
                             message.content or "",
-                            schema_by_xml_tag=self._response_schema_by_xml_tag,
+                            schema_by_xml_tag=response_schema_by_xml_tag,
                             **parsing_params,
                         )
         except JSONSchemaValidationError as exc:
@@ -156,7 +110,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
                 exc.s, exc.schema, message=str(exc)
             ) from exc
 
-    def _validate_tool_calls(self, completion: Completion) -> None:
+    def _validate_tool_calls(
+        self, completion: Completion, tools: Mapping[str, BaseTool[BaseModel, Any, Any]]
+    ) -> None:
         parsing_params = {
            "from_substring": False,
            "strip_language_markdown": True,
@@ -167,15 +123,15 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
                 tool_name = tool_call.tool_name
                 tool_arguments = tool_call.tool_arguments
 
-                available_tool_names = list(self.tools) if self.tools else []
-                if tool_name not in available_tool_names or not self.tools:
+                available_tool_names = list(tools) if tools else []
+                if tool_name not in available_tool_names or not tools:
                     raise LLMToolCallValidationError(
                         tool_name,
                         tool_arguments,
                         message=f"Tool '{tool_name}' is not available in the LLM "
                         f"tools (available: {available_tool_names})",
                     )
-                tool = self.tools[tool_name]
+                tool = tools[tool_name]
                 try:
                     validate_obj_from_json_or_py_string(
                         tool_arguments, schema=tool.in_type, **parsing_params
@@ -196,7 +152,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         annotations_op_evt: AnnotationsChunkEvent | None = None
         tool_calls_op_evt: ToolCallChunkEvent | None = None
 
-        def _close_open_events() -> list[LLMStateChangeEvent[Any]]:
+        def _close_open_events(
+            _event: CompletionChunkEvent[CompletionChunk] | None = None,
+        ) -> list[LLMStateChangeEvent[Any]]:
             nonlocal \
                 chunk_op_evt, \
                 thinking_op_evt, \
@@ -206,26 +164,21 @@
 
             events: list[LLMStateChangeEvent[Any]] = []
 
-            if tool_calls_op_evt:
+            if not isinstance(_event, ThinkingChunkEvent) and thinking_op_evt:
+                events.append(ThinkingEndEvent.from_chunk_event(thinking_op_evt))
+                thinking_op_evt = None
+
+            if not isinstance(_event, ToolCallChunkEvent) and tool_calls_op_evt:
                 events.append(ToolCallEndEvent.from_chunk_event(tool_calls_op_evt))
+                tool_calls_op_evt = None
 
-            if response_op_evt:
+            if not isinstance(_event, ResponseChunkEvent) and response_op_evt:
                 events.append(ResponseEndEvent.from_chunk_event(response_op_evt))
+                response_op_evt = None
 
-            if thinking_op_evt:
-                events.append(ThinkingEndEvent.from_chunk_event(thinking_op_evt))
-
-            if annotations_op_evt:
+            if not isinstance(_event, AnnotationsChunkEvent) and annotations_op_evt:
                 events.append(AnnotationsEndEvent.from_chunk_event(annotations_op_evt))
-
-            if chunk_op_evt:
-                events.append(CompletionEndEvent.from_chunk_event(chunk_op_evt))
-
-            chunk_op_evt = None
-            thinking_op_evt = None
-            tool_calls_op_evt = None
-            response_op_evt = None
-            annotations_op_evt = None
+                annotations_op_evt = None
 
             return events
 
@@ -252,14 +205,14 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
             sub_events = event.split_into_specialized()
 
             for sub_event in sub_events:
+                for close_event in _close_open_events(sub_event):
+                    yield close_event
+
                 if isinstance(sub_event, ThinkingChunkEvent):
                     if not thinking_op_evt:
                         thinking_op_evt = sub_event
                         yield ThinkingStartEvent.from_chunk_event(sub_event)
                     yield sub_event
-                elif thinking_op_evt:
-                    yield ThinkingEndEvent.from_chunk_event(thinking_op_evt)
-                    thinking_op_evt = None
 
                 if isinstance(sub_event, ToolCallChunkEvent):
                     tc = sub_event.data.tool_call
@@ -273,27 +226,18 @@
                         tool_calls_op_evt = sub_event
                         yield ToolCallStartEvent.from_chunk_event(sub_event)
                     yield sub_event
-                elif tool_calls_op_evt:
-                    yield ToolCallEndEvent.from_chunk_event(tool_calls_op_evt)
-                    tool_calls_op_evt = None
 
                 if isinstance(sub_event, ResponseChunkEvent):
                     if not response_op_evt:
                         response_op_evt = sub_event
                         yield ResponseStartEvent.from_chunk_event(sub_event)
                     yield sub_event
-                elif response_op_evt:
-                    yield ResponseEndEvent.from_chunk_event(response_op_evt)
-                    response_op_evt = None
 
                 if isinstance(sub_event, AnnotationsChunkEvent):
                     if not annotations_op_evt:
                         annotations_op_evt = sub_event
                         yield AnnotationsStartEvent.from_chunk_event(sub_event)
                     yield sub_event
-                elif annotations_op_evt:
-                    yield AnnotationsEndEvent.from_chunk_event(annotations_op_evt)
-                    annotations_op_evt = None
 
             prev_completion_id = chunk.id
 
@@ -308,6 +252,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         self,
         conversation: Messages,
         *,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -320,6 +267,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         self,
         conversation: Messages,
         *,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -330,5 +280,10 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         pass
 
     @abstractmethod
-    def combine_completion_chunks(self, completion_chunks: list[Any]) -> Any:
+    def combine_completion_chunks(
+        self,
+        completion_chunks: list[Any],
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
+    ) -> Any:
         raise NotImplementedError
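
Note: in 0.5.10, LLM becomes a frozen dataclass that carries only model identity, converters, and settings; tools and response schemas are no longer instance state and are instead passed per call (see the new keyword arguments and combine_completion_chunks above). A minimal sketch of the new call shape; OpenAILLM, OpenAIConverters, weather_tool, and WeatherReport are illustrative stand-ins that do not appear in this diff, and the method name generate_completion is an assumption since the diff hides the def lines:

    # Sketch only: classes and tool are hypothetical stand-ins; the keyword
    # arguments mirror the signatures in the diff above.
    llm = OpenAILLM(model_name="gpt-4.1", converters=OpenAIConverters())
    # model_id defaults to the first 8 characters of a uuid4 (see field(...) above).

    completion = await llm.generate_completion(  # assumed abstract method name
        conversation,
        tools={"get_weather": weather_tool},     # passed per call, not LLM state
        response_schema=WeatherReport,           # mutually exclusive with
        # response_schema_by_xml_tag=...,        # the XML-tag variant
    )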
grasp_agents/llm_agent.py CHANGED
@@ -1,4 +1,4 @@
-from collections.abc import AsyncIterator, Sequence
+from collections.abc import AsyncIterator, Mapping, Sequence
 from pathlib import Path
 from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
 
@@ -41,7 +41,7 @@ class OutputParser(Protocol[_InT_contra, _OutT_co, CtxT]):
         conversation: Messages,
         *,
         in_args: _InT_contra | None,
-        ctx: RunContext[CtxT] | None,
+        ctx: RunContext[CtxT],
     ) -> _OutT_co: ...
 
 
@@ -68,16 +68,19 @@ class LLMAgent(
         # System prompt template
         sys_prompt: LLMPrompt | None = None,
         sys_prompt_path: str | Path | None = None,
+        # LLM response validation
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         # Agent loop settings
         max_turns: int = 100,
         react_mode: bool = False,
         final_answer_as_tool_call: bool = False,
         # Agent memory management
         reset_memory_on_run: bool = False,
-        # Retries
+        # Agent run retries
         max_retries: int = 0,
         # Multi-agent routing
-        recipients: list[ProcName] | None = None,
+        recipients: Sequence[ProcName] | None = None,
     ) -> None:
         super().__init__(name=name, recipients=recipients, max_retries=max_retries)
 
@@ -96,15 +99,6 @@ class LLMAgent(
 
         # LLM policy executor
 
-        self._used_default_llm_response_schema: bool = False
-        if (
-            llm.response_schema is None
-            and tools is None
-            and not hasattr(type(self), "output_parser")
-        ):
-            llm.response_schema = self.out_type
-            self._used_default_llm_response_schema = True
-
         if issubclass(self._out_type, BaseModel):
             final_answer_type = self._out_type
         elif not final_answer_as_tool_call:
@@ -115,10 +109,21 @@ class LLMAgent(
                 "final_answer_as_tool_call is True."
             )
 
+        self._used_default_llm_response_schema: bool = False
+        if (
+            response_schema is None
+            and tools is None
+            and not hasattr(type(self), "output_parser")
+        ):
+            response_schema = self.out_type
+            self._used_default_llm_response_schema = True
+
         self._policy_executor: LLMPolicyExecutor[CtxT] = LLMPolicyExecutor[CtxT](
             agent_name=self.name,
             llm=llm,
             tools=tools,
+            response_schema=response_schema,
+            response_schema_by_xml_tag=response_schema_by_xml_tag,
             max_turns=max_turns,
             react_mode=react_mode,
             final_answer_type=final_answer_type,
@@ -160,9 +165,10 @@ class LLMAgent(
     def _prepare_memory(
         self,
         memory: LLMAgentMemory,
+        *,
         in_args: InT | None = None,
         sys_prompt: LLMPrompt | None = None,
-        ctx: RunContext[Any] | None = None,
+        ctx: RunContext[Any],
     ) -> None:
         if self.memory_preparator:
             return self.memory_preparator(
@@ -172,9 +178,10 @@ class LLMAgent(
     def _memorize_inputs(
         self,
         memory: LLMAgentMemory,
+        *,
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         in_args: InT | None = None,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> tuple[SystemMessage | None, UserMessage | None]:
         formatted_sys_prompt = self._prompt_builder.build_system_prompt(ctx=ctx)
 
@@ -201,7 +208,7 @@ class LLMAgent(
         conversation: Messages,
         *,
         in_args: InT | None = None,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> OutT:
         return validate_obj_from_json_or_py_string(
             str(conversation[-1].content or ""),
@@ -215,7 +222,7 @@ class LLMAgent(
         conversation: Messages,
         *,
         in_args: InT | None = None,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> OutT:
         if self.output_parser:
             return self.output_parser(
@@ -233,7 +240,7 @@ class LLMAgent(
         in_args: InT | None = None,
         memory: LLMAgentMemory,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> OutT:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
@@ -259,7 +266,7 @@ class LLMAgent(
         in_args: InT | None = None,
         memory: LLMAgentMemory,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
@@ -292,7 +299,7 @@ class LLMAgent(
         self,
         messages: Sequence[Message],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> None:
         if ctx and ctx.printer:
             ctx.printer.print_messages(messages, agent_name=self.name, call_id=call_id)
@@ -321,24 +328,18 @@ class LLMAgent(
         if cur_cls.memory_manager is not base_cls.memory_manager:
             self._policy_executor.memory_manager = self.memory_manager
 
-    def system_prompt_builder(self, ctx: RunContext[CtxT] | None = None) -> str | None:
+    def system_prompt_builder(self, ctx: RunContext[CtxT]) -> str | None:
         if self._prompt_builder.system_prompt_builder is not None:
             return self._prompt_builder.system_prompt_builder(ctx=ctx)
         raise NotImplementedError("System prompt builder is not implemented.")
 
-    def input_content_builder(
-        self, in_args: InT | None = None, *, ctx: RunContext[CtxT] | None = None
-    ) -> Content:
+    def input_content_builder(self, in_args: InT, ctx: RunContext[CtxT]) -> Content:
         if self._prompt_builder.input_content_builder is not None:
             return self._prompt_builder.input_content_builder(in_args=in_args, ctx=ctx)
         raise NotImplementedError("Input content builder is not implemented.")
 
     def tool_call_loop_terminator(
-        self,
-        conversation: Messages,
-        *,
-        ctx: RunContext[CtxT] | None = None,
-        **kwargs: Any,
+        self, conversation: Messages, *, ctx: RunContext[CtxT], **kwargs: Any
     ) -> bool:
         if self._policy_executor.tool_call_loop_terminator is not None:
             return self._policy_executor.tool_call_loop_terminator(
@@ -347,11 +348,7 @@ class LLMAgent(
         raise NotImplementedError("Tool call loop terminator is not implemented.")
 
     def memory_manager(
-        self,
-        memory: LLMAgentMemory,
-        *,
-        ctx: RunContext[CtxT] | None = None,
-        **kwargs: Any,
+        self, memory: LLMAgentMemory, *, ctx: RunContext[CtxT], **kwargs: Any
     ) -> None:
         if self._policy_executor.memory_manager is not None:
             return self._policy_executor.memory_manager(
@@ -391,12 +388,11 @@ class LLMAgent(
         self, func: OutputParser[InT, OutT, CtxT]
     ) -> OutputParser[InT, OutT, CtxT]:
         if self._used_default_llm_response_schema:
-            self._policy_executor.llm.response_schema = None
+            self._policy_executor.response_schema = None
         self.output_parser = func
 
         return func
 
     def add_memory_preparator(self, func: MemoryPreparator) -> MemoryPreparator:
         self.memory_preparator = func
-
         return func
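
Note: the LLMAgent customization hooks and parsers now take a required ctx: RunContext[CtxT] rather than RunContext[CtxT] | None, so None can no longer be passed. A sketch of an output parser matching the updated OutputParser protocol; the WeatherReport Pydantic model is an illustrative stand-in:

    # ctx must always be supplied under the new protocol signature.
    def parse_forecast(
        conversation: Messages,
        *,
        in_args: Any | None,
        ctx: RunContext[Any],
    ) -> WeatherReport:
        # Validate the final message content against the expected output model.
        return WeatherReport.model_validate_json(str(conversation[-1].content or ""))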
grasp_agents/llm_agent_memory.py CHANGED
@@ -13,14 +13,15 @@ class MemoryPreparator(Protocol):
     def __call__(
         self,
         memory: "LLMAgentMemory",
+        *,
         in_args: Any | None,
         sys_prompt: LLMPrompt | None,
-        ctx: RunContext[Any] | None,
+        ctx: RunContext[Any],
     ) -> None: ...
 
 
 class LLMAgentMemory(Memory):
-    _message_history: Messages = PrivateAttr(default_factory=list)  # type: ignore
+    _message_history: Messages = PrivateAttr(default_factory=Messages)
     _sys_prompt: LLMPrompt | None = PrivateAttr(default=None)
 
     def __init__(self, sys_prompt: LLMPrompt | None = None) -> None:
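
Note: taken together, these changes move response-schema configuration from the LLM constructor to the agent, which forwards it to its policy executor. A sketch of the 0.5.10 wiring, reusing the illustrative llm and WeatherReport from the earlier notes:

    agent = LLMAgent(
        name="forecaster",
        llm=llm,                        # frozen-dataclass LLM; holds no schema or tools
        response_schema=WeatherReport,  # new constructor argument, validated per run
        max_retries=2,                  # agent run retries
    )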