grasp_agents 0.5.9__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/llm.py CHANGED
@@ -1,7 +1,8 @@
1
1
  import logging
2
2
  from abc import ABC, abstractmethod
3
- from collections.abc import AsyncIterator, Mapping, Sequence
4
- from typing import Any, Generic, TypeVar, cast
3
+ from collections.abc import AsyncIterator, Mapping
4
+ from dataclasses import dataclass, field
5
+ from typing import Any, Generic, TypeVar
5
6
  from uuid import uuid4
6
7
 
7
8
  from pydantic import BaseModel
@@ -66,71 +67,24 @@ SettingsT_co = TypeVar("SettingsT_co", bound=LLMSettings, covariant=True)
66
67
  ConvertT_co = TypeVar("ConvertT_co", bound=Converters, covariant=True)
67
68
 
68
69
 
70
+ @dataclass(frozen=True)
69
71
  class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
70
- @abstractmethod
71
- def __init__(
72
+ model_name: str
73
+ converters: ConvertT_co
74
+ llm_settings: SettingsT_co | None = None
75
+ model_id: str = field(default_factory=lambda: str(uuid4())[:8])
76
+
77
+ def _validate_response(
72
78
  self,
73
- converters: ConvertT_co,
74
- model_name: str | None = None,
75
- model_id: str | None = None,
76
- llm_settings: SettingsT_co | None = None,
77
- tools: Sequence[BaseTool[BaseModel, Any, Any]] | None = None,
78
- response_schema: Any | None = None,
79
- response_schema_by_xml_tag: Mapping[str, Any] | None = None,
80
- **kwargs: Any,
79
+ completion: Completion,
80
+ response_schema: Any | None,
81
+ response_schema_by_xml_tag: Mapping[str, Any] | None,
81
82
  ) -> None:
82
- super().__init__()
83
-
84
- self._converters = converters
85
- self._model_id = model_id or str(uuid4())[:8]
86
- self._model_name = model_name
87
- self._tools = {t.name: t for t in tools} if tools else None
88
- self._llm_settings: SettingsT_co = llm_settings or cast("SettingsT_co", {})
89
-
90
83
  if response_schema and response_schema_by_xml_tag:
91
84
  raise ValueError(
92
85
  "Only one of response_schema and response_schema_by_xml_tag can be "
93
86
  "provided, but not both."
94
87
  )
95
- self._response_schema = response_schema
96
- self._response_schema_by_xml_tag = response_schema_by_xml_tag
97
-
98
- @property
99
- def model_id(self) -> str:
100
- return self._model_id
101
-
102
- @property
103
- def model_name(self) -> str | None:
104
- return self._model_name
105
-
106
- @property
107
- def llm_settings(self) -> SettingsT_co:
108
- return self._llm_settings
109
-
110
- @property
111
- def response_schema(self) -> Any | None:
112
- return self._response_schema
113
-
114
- @response_schema.setter
115
- def response_schema(self, response_schema: Any | None) -> None:
116
- self._response_schema = response_schema
117
-
118
- @property
119
- def response_schema_by_xml_tag(self) -> Mapping[str, Any] | None:
120
- return self._response_schema_by_xml_tag
121
-
122
- @property
123
- def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
124
- return self._tools
125
-
126
- @tools.setter
127
- def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
128
- self._tools = {t.name: t for t in tools} if tools else None
129
-
130
- def __repr__(self) -> str:
131
- return f"{type(self).__name__}[{self.model_id}]; model_name={self._model_name})"
132
-
133
- def _validate_response(self, completion: Completion) -> None:
134
88
  parsing_params = {
135
89
  "from_substring": False,
136
90
  "strip_language_markdown": True,
@@ -138,17 +92,17 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
138
92
  try:
139
93
  for message in completion.messages:
140
94
  if not message.tool_calls:
141
- if self._response_schema:
95
+ if response_schema:
142
96
  validate_obj_from_json_or_py_string(
143
97
  message.content or "",
144
- schema=self._response_schema,
98
+ schema=response_schema,
145
99
  **parsing_params,
146
100
  )
147
101
 
148
- elif self._response_schema_by_xml_tag:
102
+ elif response_schema_by_xml_tag:
149
103
  validate_tagged_objs_from_json_or_py_string(
150
104
  message.content or "",
151
- schema_by_xml_tag=self._response_schema_by_xml_tag,
105
+ schema_by_xml_tag=response_schema_by_xml_tag,
152
106
  **parsing_params,
153
107
  )
154
108
  except JSONSchemaValidationError as exc:
@@ -156,7 +110,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
156
110
  exc.s, exc.schema, message=str(exc)
157
111
  ) from exc
158
112
 
159
- def _validate_tool_calls(self, completion: Completion) -> None:
113
+ def _validate_tool_calls(
114
+ self, completion: Completion, tools: Mapping[str, BaseTool[BaseModel, Any, Any]]
115
+ ) -> None:
160
116
  parsing_params = {
161
117
  "from_substring": False,
162
118
  "strip_language_markdown": True,
@@ -167,15 +123,15 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
167
123
  tool_name = tool_call.tool_name
168
124
  tool_arguments = tool_call.tool_arguments
169
125
 
170
- available_tool_names = list(self.tools) if self.tools else []
171
- if tool_name not in available_tool_names or not self.tools:
126
+ available_tool_names = list(tools) if tools else []
127
+ if tool_name not in available_tool_names or not tools:
172
128
  raise LLMToolCallValidationError(
173
129
  tool_name,
174
130
  tool_arguments,
175
131
  message=f"Tool '{tool_name}' is not available in the LLM "
176
132
  f"tools (available: {available_tool_names})",
177
133
  )
178
- tool = self.tools[tool_name]
134
+ tool = tools[tool_name]
179
135
  try:
180
136
  validate_obj_from_json_or_py_string(
181
137
  tool_arguments, schema=tool.in_type, **parsing_params
@@ -296,6 +252,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
296
252
  self,
297
253
  conversation: Messages,
298
254
  *,
255
+ tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
256
+ response_schema: Any | None = None,
257
+ response_schema_by_xml_tag: Mapping[str, Any] | None = None,
299
258
  tool_choice: ToolChoice | None = None,
300
259
  n_choices: int | None = None,
301
260
  proc_name: str | None = None,
@@ -308,6 +267,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
308
267
  self,
309
268
  conversation: Messages,
310
269
  *,
270
+ tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
271
+ response_schema: Any | None = None,
272
+ response_schema_by_xml_tag: Mapping[str, Any] | None = None,
311
273
  tool_choice: ToolChoice | None = None,
312
274
  n_choices: int | None = None,
313
275
  proc_name: str | None = None,
@@ -318,5 +280,10 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
318
280
  pass
319
281
 
320
282
  @abstractmethod
321
- def combine_completion_chunks(self, completion_chunks: list[Any]) -> Any:
283
+ def combine_completion_chunks(
284
+ self,
285
+ completion_chunks: list[Any],
286
+ response_schema: Any | None = None,
287
+ tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
288
+ ) -> Any:
322
289
  raise NotImplementedError
grasp_agents/llm_agent.py CHANGED
@@ -1,4 +1,4 @@
1
- from collections.abc import AsyncIterator, Sequence
1
+ from collections.abc import AsyncIterator, Mapping, Sequence
2
2
  from pathlib import Path
3
3
  from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
4
4
 
@@ -41,7 +41,7 @@ class OutputParser(Protocol[_InT_contra, _OutT_co, CtxT]):
41
41
  conversation: Messages,
42
42
  *,
43
43
  in_args: _InT_contra | None,
44
- ctx: RunContext[CtxT] | None,
44
+ ctx: RunContext[CtxT],
45
45
  ) -> _OutT_co: ...
46
46
 
47
47
 
@@ -68,16 +68,19 @@ class LLMAgent(
68
68
  # System prompt template
69
69
  sys_prompt: LLMPrompt | None = None,
70
70
  sys_prompt_path: str | Path | None = None,
71
+ # LLM response validation
72
+ response_schema: Any | None = None,
73
+ response_schema_by_xml_tag: Mapping[str, Any] | None = None,
71
74
  # Agent loop settings
72
75
  max_turns: int = 100,
73
76
  react_mode: bool = False,
74
77
  final_answer_as_tool_call: bool = False,
75
78
  # Agent memory management
76
79
  reset_memory_on_run: bool = False,
77
- # Retries
80
+ # Agent run retries
78
81
  max_retries: int = 0,
79
82
  # Multi-agent routing
80
- recipients: list[ProcName] | None = None,
83
+ recipients: Sequence[ProcName] | None = None,
81
84
  ) -> None:
82
85
  super().__init__(name=name, recipients=recipients, max_retries=max_retries)
83
86
 
@@ -96,15 +99,6 @@ class LLMAgent(
96
99
 
97
100
  # LLM policy executor
98
101
 
99
- self._used_default_llm_response_schema: bool = False
100
- if (
101
- llm.response_schema is None
102
- and tools is None
103
- and not hasattr(type(self), "output_parser")
104
- ):
105
- llm.response_schema = self.out_type
106
- self._used_default_llm_response_schema = True
107
-
108
102
  if issubclass(self._out_type, BaseModel):
109
103
  final_answer_type = self._out_type
110
104
  elif not final_answer_as_tool_call:
@@ -115,10 +109,21 @@ class LLMAgent(
115
109
  "final_answer_as_tool_call is True."
116
110
  )
117
111
 
112
+ self._used_default_llm_response_schema: bool = False
113
+ if (
114
+ response_schema is None
115
+ and tools is None
116
+ and not hasattr(type(self), "output_parser")
117
+ ):
118
+ response_schema = self.out_type
119
+ self._used_default_llm_response_schema = True
120
+
118
121
  self._policy_executor: LLMPolicyExecutor[CtxT] = LLMPolicyExecutor[CtxT](
119
122
  agent_name=self.name,
120
123
  llm=llm,
121
124
  tools=tools,
125
+ response_schema=response_schema,
126
+ response_schema_by_xml_tag=response_schema_by_xml_tag,
122
127
  max_turns=max_turns,
123
128
  react_mode=react_mode,
124
129
  final_answer_type=final_answer_type,
@@ -160,9 +165,10 @@ class LLMAgent(
160
165
  def _prepare_memory(
161
166
  self,
162
167
  memory: LLMAgentMemory,
168
+ *,
163
169
  in_args: InT | None = None,
164
170
  sys_prompt: LLMPrompt | None = None,
165
- ctx: RunContext[Any] | None = None,
171
+ ctx: RunContext[Any],
166
172
  ) -> None:
167
173
  if self.memory_preparator:
168
174
  return self.memory_preparator(
@@ -172,9 +178,10 @@ class LLMAgent(
172
178
  def _memorize_inputs(
173
179
  self,
174
180
  memory: LLMAgentMemory,
181
+ *,
175
182
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
176
183
  in_args: InT | None = None,
177
- ctx: RunContext[CtxT] | None = None,
184
+ ctx: RunContext[CtxT],
178
185
  ) -> tuple[SystemMessage | None, UserMessage | None]:
179
186
  formatted_sys_prompt = self._prompt_builder.build_system_prompt(ctx=ctx)
180
187
 
@@ -201,7 +208,7 @@ class LLMAgent(
201
208
  conversation: Messages,
202
209
  *,
203
210
  in_args: InT | None = None,
204
- ctx: RunContext[CtxT] | None = None,
211
+ ctx: RunContext[CtxT],
205
212
  ) -> OutT:
206
213
  return validate_obj_from_json_or_py_string(
207
214
  str(conversation[-1].content or ""),
@@ -215,7 +222,7 @@ class LLMAgent(
215
222
  conversation: Messages,
216
223
  *,
217
224
  in_args: InT | None = None,
218
- ctx: RunContext[CtxT] | None = None,
225
+ ctx: RunContext[CtxT],
219
226
  ) -> OutT:
220
227
  if self.output_parser:
221
228
  return self.output_parser(
@@ -233,7 +240,7 @@ class LLMAgent(
233
240
  in_args: InT | None = None,
234
241
  memory: LLMAgentMemory,
235
242
  call_id: str,
236
- ctx: RunContext[CtxT] | None = None,
243
+ ctx: RunContext[CtxT],
237
244
  ) -> OutT:
238
245
  system_message, input_message = self._memorize_inputs(
239
246
  memory=memory,
@@ -259,7 +266,7 @@ class LLMAgent(
259
266
  in_args: InT | None = None,
260
267
  memory: LLMAgentMemory,
261
268
  call_id: str,
262
- ctx: RunContext[CtxT] | None = None,
269
+ ctx: RunContext[CtxT],
263
270
  ) -> AsyncIterator[Event[Any]]:
264
271
  system_message, input_message = self._memorize_inputs(
265
272
  memory=memory,
@@ -292,7 +299,7 @@ class LLMAgent(
292
299
  self,
293
300
  messages: Sequence[Message],
294
301
  call_id: str,
295
- ctx: RunContext[CtxT] | None = None,
302
+ ctx: RunContext[CtxT],
296
303
  ) -> None:
297
304
  if ctx and ctx.printer:
298
305
  ctx.printer.print_messages(messages, agent_name=self.name, call_id=call_id)
@@ -321,24 +328,18 @@ class LLMAgent(
321
328
  if cur_cls.memory_manager is not base_cls.memory_manager:
322
329
  self._policy_executor.memory_manager = self.memory_manager
323
330
 
324
- def system_prompt_builder(self, ctx: RunContext[CtxT] | None = None) -> str | None:
331
+ def system_prompt_builder(self, ctx: RunContext[CtxT]) -> str | None:
325
332
  if self._prompt_builder.system_prompt_builder is not None:
326
333
  return self._prompt_builder.system_prompt_builder(ctx=ctx)
327
334
  raise NotImplementedError("System prompt builder is not implemented.")
328
335
 
329
- def input_content_builder(
330
- self, in_args: InT | None = None, *, ctx: RunContext[CtxT] | None = None
331
- ) -> Content:
336
+ def input_content_builder(self, in_args: InT, ctx: RunContext[CtxT]) -> Content:
332
337
  if self._prompt_builder.input_content_builder is not None:
333
338
  return self._prompt_builder.input_content_builder(in_args=in_args, ctx=ctx)
334
339
  raise NotImplementedError("Input content builder is not implemented.")
335
340
 
336
341
  def tool_call_loop_terminator(
337
- self,
338
- conversation: Messages,
339
- *,
340
- ctx: RunContext[CtxT] | None = None,
341
- **kwargs: Any,
342
+ self, conversation: Messages, *, ctx: RunContext[CtxT], **kwargs: Any
342
343
  ) -> bool:
343
344
  if self._policy_executor.tool_call_loop_terminator is not None:
344
345
  return self._policy_executor.tool_call_loop_terminator(
@@ -347,11 +348,7 @@ class LLMAgent(
347
348
  raise NotImplementedError("Tool call loop terminator is not implemented.")
348
349
 
349
350
  def memory_manager(
350
- self,
351
- memory: LLMAgentMemory,
352
- *,
353
- ctx: RunContext[CtxT] | None = None,
354
- **kwargs: Any,
351
+ self, memory: LLMAgentMemory, *, ctx: RunContext[CtxT], **kwargs: Any
355
352
  ) -> None:
356
353
  if self._policy_executor.memory_manager is not None:
357
354
  return self._policy_executor.memory_manager(
@@ -391,12 +388,11 @@ class LLMAgent(
391
388
  self, func: OutputParser[InT, OutT, CtxT]
392
389
  ) -> OutputParser[InT, OutT, CtxT]:
393
390
  if self._used_default_llm_response_schema:
394
- self._policy_executor.llm.response_schema = None
391
+ self._policy_executor.response_schema = None
395
392
  self.output_parser = func
396
393
 
397
394
  return func
398
395
 
399
396
  def add_memory_preparator(self, func: MemoryPreparator) -> MemoryPreparator:
400
397
  self.memory_preparator = func
401
-
402
398
  return func
@@ -13,14 +13,15 @@ class MemoryPreparator(Protocol):
13
13
  def __call__(
14
14
  self,
15
15
  memory: "LLMAgentMemory",
16
+ *,
16
17
  in_args: Any | None,
17
18
  sys_prompt: LLMPrompt | None,
18
- ctx: RunContext[Any] | None,
19
+ ctx: RunContext[Any],
19
20
  ) -> None: ...
20
21
 
21
22
 
22
23
  class LLMAgentMemory(Memory):
23
- _message_history: Messages = PrivateAttr(default_factory=list) # type: ignore
24
+ _message_history: Messages = PrivateAttr(default_factory=Messages)
24
25
  _sys_prompt: LLMPrompt | None = PrivateAttr(default=None)
25
26
 
26
27
  def __init__(self, sys_prompt: LLMPrompt | None = None) -> None:
@@ -1,6 +1,6 @@
1
1
  import asyncio
2
2
  import json
3
- from collections.abc import AsyncIterator, Coroutine, Sequence
3
+ from collections.abc import AsyncIterator, Coroutine, Mapping, Sequence
4
4
  from itertools import starmap
5
5
  from logging import getLogger
6
6
  from typing import Any, Generic, Protocol, final
@@ -36,7 +36,7 @@ class ToolCallLoopTerminator(Protocol[CtxT]):
36
36
  self,
37
37
  conversation: Messages,
38
38
  *,
39
- ctx: RunContext[CtxT] | None,
39
+ ctx: RunContext[CtxT],
40
40
  **kwargs: Any,
41
41
  ) -> bool: ...
42
42
 
@@ -46,7 +46,7 @@ class MemoryManager(Protocol[CtxT]):
46
46
  self,
47
47
  memory: LLMAgentMemory,
48
48
  *,
49
- ctx: RunContext[CtxT] | None,
49
+ ctx: RunContext[CtxT],
50
50
  **kwargs: Any,
51
51
  ) -> None: ...
52
52
 
@@ -54,9 +54,12 @@ class MemoryManager(Protocol[CtxT]):
54
54
  class LLMPolicyExecutor(Generic[CtxT]):
55
55
  def __init__(
56
56
  self,
57
+ *,
57
58
  agent_name: str,
58
59
  llm: LLM[LLMSettings, Converters],
59
60
  tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
61
+ response_schema: Any | None = None,
62
+ response_schema_by_xml_tag: Mapping[str, Any] | None = None,
60
63
  max_turns: int,
61
64
  react_mode: bool = False,
62
65
  final_answer_type: type[BaseModel] = BaseModel,
@@ -70,12 +73,15 @@ class LLMPolicyExecutor(Generic[CtxT]):
70
73
  self._final_answer_as_tool_call = final_answer_as_tool_call
71
74
  self._final_answer_tool = self.get_final_answer_tool()
72
75
 
73
- _tools: list[BaseTool[BaseModel, Any, CtxT]] | None = tools
76
+ tools_list: list[BaseTool[BaseModel, Any, CtxT]] | None = tools
74
77
  if tools and final_answer_as_tool_call:
75
- _tools = tools + [self._final_answer_tool]
78
+ tools_list = tools + [self._final_answer_tool]
79
+ self._tools = {t.name: t for t in tools_list} if tools_list else None
80
+
81
+ self._response_schema = response_schema
82
+ self._response_schema_by_xml_tag = response_schema_by_xml_tag
76
83
 
77
84
  self._llm = llm
78
- self._llm.tools = _tools
79
85
 
80
86
  self._max_turns = max_turns
81
87
  self._react_mode = react_mode
@@ -91,9 +97,21 @@ class LLMPolicyExecutor(Generic[CtxT]):
91
97
  def llm(self) -> LLM[LLMSettings, Converters]:
92
98
  return self._llm
93
99
 
100
+ @property
101
+ def response_schema(self) -> Any | None:
102
+ return self._response_schema
103
+
104
+ @response_schema.setter
105
+ def response_schema(self, value: Any | None) -> None:
106
+ self._response_schema = value
107
+
108
+ @property
109
+ def response_schema_by_xml_tag(self) -> Mapping[str, Any] | None:
110
+ return self._response_schema_by_xml_tag
111
+
94
112
  @property
95
113
  def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
96
- return self._llm.tools or {}
114
+ return self._tools or {}
97
115
 
98
116
  @property
99
117
  def max_turns(self) -> int:
@@ -104,7 +122,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
104
122
  self,
105
123
  conversation: Messages,
106
124
  *,
107
- ctx: RunContext[CtxT] | None = None,
125
+ ctx: RunContext[CtxT],
108
126
  **kwargs: Any,
109
127
  ) -> bool:
110
128
  if self.tool_call_loop_terminator:
@@ -117,7 +135,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
117
135
  self,
118
136
  memory: LLMAgentMemory,
119
137
  *,
120
- ctx: RunContext[CtxT] | None = None,
138
+ ctx: RunContext[CtxT],
121
139
  **kwargs: Any,
122
140
  ) -> None:
123
141
  if self.memory_manager:
@@ -126,12 +144,16 @@ class LLMPolicyExecutor(Generic[CtxT]):
126
144
  async def generate_message(
127
145
  self,
128
146
  memory: LLMAgentMemory,
147
+ *,
129
148
  call_id: str,
130
149
  tool_choice: ToolChoice | None = None,
131
- ctx: RunContext[CtxT] | None = None,
150
+ ctx: RunContext[CtxT],
132
151
  ) -> AssistantMessage:
133
152
  completion = await self.llm.generate_completion(
134
153
  memory.message_history,
154
+ response_schema=self.response_schema,
155
+ response_schema_by_xml_tag=self.response_schema_by_xml_tag,
156
+ tools=self.tools,
135
157
  tool_choice=tool_choice,
136
158
  n_choices=1,
137
159
  proc_name=self.agent_name,
@@ -147,9 +169,10 @@ class LLMPolicyExecutor(Generic[CtxT]):
147
169
  async def generate_message_stream(
148
170
  self,
149
171
  memory: LLMAgentMemory,
172
+ *,
150
173
  call_id: str,
151
174
  tool_choice: ToolChoice | None = None,
152
- ctx: RunContext[CtxT] | None = None,
175
+ ctx: RunContext[CtxT],
153
176
  ) -> AsyncIterator[
154
177
  CompletionChunkEvent[CompletionChunk]
155
178
  | CompletionEvent
@@ -160,6 +183,9 @@ class LLMPolicyExecutor(Generic[CtxT]):
160
183
 
161
184
  llm_event_stream = self.llm.generate_completion_stream(
162
185
  memory.message_history,
186
+ response_schema=self.response_schema,
187
+ response_schema_by_xml_tag=self.response_schema_by_xml_tag,
188
+ tools=self.tools,
163
189
  tool_choice=tool_choice,
164
190
  n_choices=1,
165
191
  proc_name=self.agent_name,
@@ -189,14 +215,14 @@ class LLMPolicyExecutor(Generic[CtxT]):
189
215
  calls: Sequence[ToolCall],
190
216
  memory: LLMAgentMemory,
191
217
  call_id: str,
192
- ctx: RunContext[CtxT] | None = None,
218
+ ctx: RunContext[CtxT],
193
219
  ) -> Sequence[ToolMessage]:
194
220
  # TODO: Add image support
195
221
  corouts: list[Coroutine[Any, Any, BaseModel]] = []
196
222
  for call in calls:
197
223
  tool = self.tools[call.tool_name]
198
224
  args = json.loads(call.tool_arguments)
199
- corouts.append(tool(ctx=ctx, **args))
225
+ corouts.append(tool(call_id=call_id, ctx=ctx, **args))
200
226
 
201
227
  outs = await asyncio.gather(*corouts)
202
228
  tool_messages = list(
@@ -217,7 +243,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
217
243
  calls: Sequence[ToolCall],
218
244
  memory: LLMAgentMemory,
219
245
  call_id: str,
220
- ctx: RunContext[CtxT] | None = None,
246
+ ctx: RunContext[CtxT],
221
247
  ) -> AsyncIterator[ToolMessageEvent]:
222
248
  tool_messages = await self.call_tools(
223
249
  calls, memory=memory, call_id=call_id, ctx=ctx
@@ -245,7 +271,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
245
271
  return final_answer_message
246
272
 
247
273
  async def _generate_final_answer(
248
- self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
274
+ self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
249
275
  ) -> AssistantMessage:
250
276
  user_message = UserMessage.from_text(
251
277
  "Exceeded the maximum number of turns: provide a final answer now!"
@@ -268,7 +294,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
268
294
  return final_answer_message
269
295
 
270
296
  async def _generate_final_answer_stream(
271
- self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
297
+ self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
272
298
  ) -> AsyncIterator[Event[Any]]:
273
299
  user_message = UserMessage.from_text(
274
300
  "Exceeded the maximum number of turns: provide a final answer now!",
@@ -296,7 +322,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
296
322
  )
297
323
 
298
324
  async def execute(
299
- self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
325
+ self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
300
326
  ) -> AssistantMessage | Sequence[AssistantMessage]:
301
327
  # 1. Generate the first message:
302
328
  # In ReAct mode, we generate the first message without tool calls
@@ -379,7 +405,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
379
405
  self,
380
406
  memory: LLMAgentMemory,
381
407
  call_id: str,
382
- ctx: RunContext[CtxT] | None = None,
408
+ ctx: RunContext[CtxT],
383
409
  ) -> AsyncIterator[Event[Any]]:
384
410
  tool_choice: ToolChoice = "none" if self._react_mode else "auto"
385
411
  gen_message: AssistantMessage | None = None
@@ -464,7 +490,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
464
490
  )
465
491
 
466
492
  async def run(
467
- self, inp: BaseModel, ctx: RunContext[Any] | None = None
493
+ self,
494
+ inp: BaseModel,
495
+ *,
496
+ call_id: str | None = None,
497
+ ctx: RunContext[Any] | None = None,
468
498
  ) -> None:
469
499
  return None
470
500
 
@@ -473,22 +503,22 @@ class LLMPolicyExecutor(Generic[CtxT]):
473
503
  def _process_completion(
474
504
  self,
475
505
  completion: Completion,
506
+ *,
476
507
  call_id: str,
477
508
  print_messages: bool = False,
478
- ctx: RunContext[CtxT] | None = None,
509
+ ctx: RunContext[CtxT],
479
510
  ) -> None:
480
- if ctx is not None:
481
- ctx.completions[self.agent_name].append(completion)
482
- ctx.usage_tracker.update(
511
+ ctx.completions[self.agent_name].append(completion)
512
+ ctx.usage_tracker.update(
513
+ agent_name=self.agent_name,
514
+ completions=[completion],
515
+ model_name=self.llm.model_name,
516
+ )
517
+ if ctx.printer and print_messages:
518
+ usages = [None] * (len(completion.messages) - 1) + [completion.usage]
519
+ ctx.printer.print_messages(
520
+ completion.messages,
521
+ usages=usages,
483
522
  agent_name=self.agent_name,
484
- completions=[completion],
485
- model_name=self.llm.model_name,
523
+ call_id=call_id,
486
524
  )
487
- if ctx.printer and print_messages:
488
- usages = [None] * (len(completion.messages) - 1) + [completion.usage]
489
- ctx.printer.print_messages(
490
- completion.messages,
491
- usages=usages,
492
- agent_name=self.agent_name,
493
- call_id=call_id,
494
- )
@@ -96,8 +96,10 @@ class OpenAIConverters(Converters):
96
96
  return from_api_tool_message(raw_message, name=name, **kwargs)
97
97
 
98
98
  @staticmethod
99
- def to_tool(tool: BaseTool[BaseModel, Any, Any], **kwargs: Any) -> OpenAIToolParam:
100
- return to_api_tool(tool, **kwargs)
99
+ def to_tool(
100
+ tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None, **kwargs: Any
101
+ ) -> OpenAIToolParam:
102
+ return to_api_tool(tool, strict=strict, **kwargs)
101
103
 
102
104
  @staticmethod
103
105
  def to_tool_choice(