grasp_agents 0.5.9__py3-none-any.whl → 0.5.11__py3-none-any.whl

This diff shows the changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
```diff
@@ -1,6 +1,6 @@
 import asyncio
 import json
-from collections.abc import AsyncIterator, Coroutine, Sequence
+from collections.abc import AsyncIterator, Coroutine, Mapping, Sequence
 from itertools import starmap
 from logging import getLogger
 from typing import Any, Generic, Protocol, final
```
```diff
@@ -36,7 +36,8 @@ class ToolCallLoopTerminator(Protocol[CtxT]):
         self,
         conversation: Messages,
         *,
-        ctx: RunContext[CtxT] | None,
+        ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> bool: ...

@@ -46,7 +47,8 @@ class MemoryManager(Protocol[CtxT]):
         self,
         memory: LLMAgentMemory,
         *,
-        ctx: RunContext[CtxT] | None,
+        ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> None: ...
```
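Both protocol signatures now require a non-optional `ctx` plus a new `call_id` keyword. A minimal sketch of conforming callbacks, assuming the `Messages`, `RunContext`, and `LLMAgentMemory` types referenced above; the bodies are illustrative, not part of the package:

```python
from typing import Any


def stop_when_long(
    conversation: "Messages",
    *,
    ctx: "RunContext[Any]",
    call_id: str,
    **kwargs: Any,
) -> bool:
    # Illustrative policy: end the tool-call loop once the conversation
    # exceeds an arbitrary length.
    return len(conversation) > 50


def log_memory(
    memory: "LLMAgentMemory",
    *,
    ctx: "RunContext[Any]",
    call_id: str,
    **kwargs: Any,
) -> None:
    # Illustrative no-op manager: a real one might compact or prune
    # memory.message_history here.
    print(f"[{call_id}] history length: {len(memory.message_history)}")
```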
```diff
@@ -54,9 +56,12 @@ class MemoryManager(Protocol[CtxT]):
 class LLMPolicyExecutor(Generic[CtxT]):
     def __init__(
         self,
+        *,
         agent_name: str,
         llm: LLM[LLMSettings, Converters],
         tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         max_turns: int,
         react_mode: bool = False,
         final_answer_type: type[BaseModel] = BaseModel,
@@ -70,12 +75,15 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self._final_answer_as_tool_call = final_answer_as_tool_call
         self._final_answer_tool = self.get_final_answer_tool()

-        _tools: list[BaseTool[BaseModel, Any, CtxT]] | None = tools
+        tools_list: list[BaseTool[BaseModel, Any, CtxT]] | None = tools
         if tools and final_answer_as_tool_call:
-            _tools = tools + [self._final_answer_tool]
+            tools_list = tools + [self._final_answer_tool]
+        self._tools = {t.name: t for t in tools_list} if tools_list else None
+
+        self._response_schema = response_schema
+        self._response_schema_by_xml_tag = response_schema_by_xml_tag

         self._llm = llm
-        self._llm.tools = _tools

         self._max_turns = max_turns
         self._react_mode = react_mode
```
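The executor is now constructed with keyword-only arguments, owns its tool registry (`self._tools`), and carries the response schemas itself instead of assigning them onto the LLM. A hypothetical construction, where `my_llm` is an existing `LLM[LLMSettings, Converters]` instance:

```python
from pydantic import BaseModel


class FinalAnswer(BaseModel):
    answer: str


# Sketch: every argument must be passed by keyword after this change.
executor = LLMPolicyExecutor(
    agent_name="assistant",
    llm=my_llm,
    tools=None,                   # or a list of BaseTool instances
    response_schema=FinalAnswer,  # now forwarded per request, see below
    max_turns=10,
    react_mode=False,
)
```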
```diff
@@ -91,9 +99,21 @@ class LLMPolicyExecutor(Generic[CtxT]):
     def llm(self) -> LLM[LLMSettings, Converters]:
         return self._llm

+    @property
+    def response_schema(self) -> Any | None:
+        return self._response_schema
+
+    @response_schema.setter
+    def response_schema(self, value: Any | None) -> None:
+        self._response_schema = value
+
+    @property
+    def response_schema_by_xml_tag(self) -> Mapping[str, Any] | None:
+        return self._response_schema_by_xml_tag
+
     @property
     def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
-        return self._llm.tools or {}
+        return self._tools or {}

     @property
     def max_turns(self) -> int:
@@ -104,11 +124,14 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         conversation: Messages,
         *,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> bool:
         if self.tool_call_loop_terminator:
-            return self.tool_call_loop_terminator(conversation, ctx=ctx, **kwargs)
+            return self.tool_call_loop_terminator(
+                conversation, ctx=ctx, call_id=call_id, **kwargs
+            )

         return False

@@ -117,21 +140,26 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         memory: LLMAgentMemory,
         *,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> None:
         if self.memory_manager:
-            self.memory_manager(memory=memory, ctx=ctx, **kwargs)
+            self.memory_manager(memory=memory, ctx=ctx, call_id=call_id, **kwargs)

     async def generate_message(
         self,
         memory: LLMAgentMemory,
-        call_id: str,
+        *,
         tool_choice: ToolChoice | None = None,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AssistantMessage:
         completion = await self.llm.generate_completion(
             memory.message_history,
+            response_schema=self.response_schema,
+            response_schema_by_xml_tag=self.response_schema_by_xml_tag,
+            tools=self.tools,
             tool_choice=tool_choice,
             n_choices=1,
             proc_name=self.agent_name,
@@ -139,7 +167,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )
         memory.update(completion.messages)
         self._process_completion(
-            completion, call_id=call_id, ctx=ctx, print_messages=True
+            completion, ctx=ctx, call_id=call_id, print_messages=True
         )

         return completion.messages[0]
@@ -147,9 +175,10 @@ class LLMPolicyExecutor(Generic[CtxT]):
     async def generate_message_stream(
         self,
         memory: LLMAgentMemory,
-        call_id: str,
+        *,
         tool_choice: ToolChoice | None = None,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AsyncIterator[
         CompletionChunkEvent[CompletionChunk]
         | CompletionEvent
@@ -160,6 +189,9 @@ class LLMPolicyExecutor(Generic[CtxT]):

         llm_event_stream = self.llm.generate_completion_stream(
             memory.message_history,
+            response_schema=self.response_schema,
+            response_schema_by_xml_tag=self.response_schema_by_xml_tag,
+            tools=self.tools,
             tool_choice=tool_choice,
             n_choices=1,
             proc_name=self.agent_name,
```
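Because the schema and tools are now forwarded on every `generate_completion`/`generate_completion_stream` call rather than stored on the LLM, the new `response_schema` setter lets a caller retarget structured output between turns. A hypothetical sketch, inside an async function and reusing `executor`, `FinalAnswer`, `memory`, and `ctx` from the earlier sketch:

```python
# Hypothetical: switch the structured-output schema between requests.
executor.response_schema = FinalAnswer
message = await executor.generate_message(memory, ctx=ctx, call_id="call-1")

executor.response_schema = None  # back to free-form text
message = await executor.generate_message(memory, ctx=ctx, call_id="call-2")
```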
```diff
@@ -181,22 +213,22 @@ class LLMPolicyExecutor(Generic[CtxT]):
         memory.update(completion.messages)

         self._process_completion(
-            completion, call_id=call_id, print_messages=True, ctx=ctx
+            completion, print_messages=True, ctx=ctx, call_id=call_id
         )

     async def call_tools(
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
+        ctx: RunContext[CtxT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
     ) -> Sequence[ToolMessage]:
         # TODO: Add image support
         corouts: list[Coroutine[Any, Any, BaseModel]] = []
         for call in calls:
             tool = self.tools[call.tool_name]
             args = json.loads(call.tool_arguments)
-            corouts.append(tool(ctx=ctx, **args))
+            corouts.append(tool(ctx=ctx, call_id=call_id, **args))

         outs = await asyncio.gather(*corouts)
         tool_messages = list(
@@ -216,11 +248,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
+        ctx: RunContext[CtxT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[ToolMessageEvent]:
         tool_messages = await self.call_tools(
-            calls, memory=memory, call_id=call_id, ctx=ctx
+            calls, memory=memory, ctx=ctx, call_id=call_id
         )
         for tool_message, call in zip(tool_messages, calls, strict=True):
             yield ToolMessageEvent(
```
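Tool invocations now receive `call_id` alongside `ctx`, so tool `__call__` signatures must accept it. A minimal sketch of a compatible tool; the real base class is `BaseTool` from `grasp_agents.typing.tool`, and the argument model here is invented:

```python
from typing import Any

from pydantic import BaseModel


class AddArgs(BaseModel):
    a: int
    b: int


class AddResult(BaseModel):
    total: int


class AddTool:  # sketch; the package's tools subclass BaseTool
    name = "add"
    description = "Add two integers."

    async def __call__(self, *, ctx: Any, call_id: str, **kwargs: Any) -> AddResult:
        # call_tools splats the parsed JSON arguments into kwargs.
        args = AddArgs(**kwargs)
        return AddResult(total=args.a + args.b)
```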
```diff
@@ -245,7 +277,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message

     async def _generate_final_answer(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT], call_id: str
     ) -> AssistantMessage:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!"
@@ -258,7 +290,7 @@ class LLMPolicyExecutor(Generic[CtxT]):

         tool_choice = NamedToolChoice(name=self._final_answer_tool.name)
         await self.generate_message(
-            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+            memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
         )

         final_answer_message = self._extract_final_answer_from_tool_calls(memory=memory)
@@ -268,7 +300,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message

     async def _generate_final_answer_stream(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT], call_id: str
     ) -> AsyncIterator[Event[Any]]:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!",
@@ -284,7 +316,7 @@ class LLMPolicyExecutor(Generic[CtxT]):

         tool_choice = NamedToolChoice(name=self._final_answer_tool.name)
         async for event in self.generate_message_stream(
-            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+            memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
         ):
             yield event

@@ -296,7 +328,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )

     async def execute(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT], call_id: str
     ) -> AssistantMessage | Sequence[AssistantMessage]:
         # 1. Generate the first message:
         # In ReAct mode, we generate the first message without tool calls
@@ -306,7 +338,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         if self.tools:
             tool_choice = "none" if self._react_mode else "auto"
         gen_message = await self.generate_message(
-            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+            memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
         )
         if not self.tools:
             return gen_message
@@ -319,7 +351,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
             # If a final answer is not provided via a tool call, we use
             # _terminate_tool_call_loop to determine whether to exit the loop.
             if not self._final_answer_as_tool_call and self._terminate_tool_call_loop(
-                memory.message_history, ctx=ctx, num_turns=turns
+                memory.message_history, ctx=ctx, call_id=call_id, num_turns=turns
             ):
                 return gen_message

@@ -338,7 +370,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
             # Otherwise, we simply return the last generated message.
             if self._final_answer_as_tool_call:
                 final_answer = await self._generate_final_answer(
-                    memory, call_id=call_id, ctx=ctx
+                    memory, ctx=ctx, call_id=call_id
                 )
             else:
                 final_answer = gen_message
@@ -351,11 +383,11 @@ class LLMPolicyExecutor(Generic[CtxT]):

             if gen_message.tool_calls:
                 await self.call_tools(
-                    gen_message.tool_calls, memory=memory, call_id=call_id, ctx=ctx
+                    gen_message.tool_calls, memory=memory, ctx=ctx, call_id=call_id
                 )

             # Apply memory management (e.g. compacting or pruning memory)
-            self._manage_memory(memory, ctx=ctx, num_turns=turns)
+            self._manage_memory(memory, ctx=ctx, call_id=call_id, num_turns=turns)

             # 4. Generate the next message based on the updated memory.
             # In ReAct mode, we set tool_choice to "none" if we just called tools,
@@ -370,7 +402,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
                 tool_choice = "required"

             gen_message = await self.generate_message(
-                memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+                memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
             )

             turns += 1
@@ -378,13 +410,13 @@ class LLMPolicyExecutor(Generic[CtxT]):
     async def execute_stream(
         self,
         memory: LLMAgentMemory,
+        ctx: RunContext[CtxT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
         tool_choice: ToolChoice = "none" if self._react_mode else "auto"
         gen_message: AssistantMessage | None = None
         async for event in self.generate_message_stream(
-            memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+            memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
         ):
             if isinstance(event, GenMessageEvent):
                 gen_message = event.data
@@ -399,7 +431,7 @@ class LLMPolicyExecutor(Generic[CtxT]):

         while True:
             if not self._final_answer_as_tool_call and self._terminate_tool_call_loop(
-                memory.message_history, ctx=ctx, num_turns=turns
+                memory.message_history, ctx=ctx, call_id=call_id, num_turns=turns
             ):
                 return

@@ -418,7 +450,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
             if turns >= self.max_turns:
                 if self._final_answer_as_tool_call:
                     async for event in self._generate_final_answer_stream(
-                        memory, call_id=call_id, ctx=ctx
+                        memory, ctx=ctx, call_id=call_id
                     ):
                         yield event
                 logger.info(
@@ -433,11 +465,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
             )

             async for event in self.call_tools_stream(
-                gen_message.tool_calls, memory=memory, call_id=call_id, ctx=ctx
+                gen_message.tool_calls, memory=memory, ctx=ctx, call_id=call_id
             ):
                 yield event

-            self._manage_memory(memory, ctx=ctx, num_turns=turns)
+            self._manage_memory(memory, ctx=ctx, call_id=call_id, num_turns=turns)

             if self._react_mode and gen_message.tool_calls:
                 tool_choice = "none"
@@ -447,7 +479,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
                 tool_choice = "required"

             async for event in self.generate_message_stream(
-                memory, tool_choice=tool_choice, call_id=call_id, ctx=ctx
+                memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
             ):
                 yield event
                 if isinstance(event, GenMessageEvent):
@@ -464,7 +496,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )

     async def run(
-        self, inp: BaseModel, ctx: RunContext[Any] | None = None
+        self,
+        inp: BaseModel,
+        *,
+        ctx: RunContext[Any] | None = None,
+        call_id: str | None = None,
     ) -> None:
         return None
```
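Across `execute`, `execute_stream`, and `run`, `ctx` now precedes `call_id`, and both are required on the two executor entry points (`run` keeps them optional keywords). Driving the executor might look like this sketch, inside an async function and assuming `executor`, `memory`, and `ctx` from the earlier examples:

```python
# Sketch: both entry points now take ctx and call_id explicitly.
final = await executor.execute(memory, ctx=ctx, call_id="run-1")

async for event in executor.execute_stream(memory, ctx=ctx, call_id="run-2"):
    ...  # handle GenMessageEvent, ToolMessageEvent, completion events
```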
```diff
@@ -473,22 +509,22 @@ class LLMPolicyExecutor(Generic[CtxT]):
     def _process_completion(
         self,
         completion: Completion,
-        call_id: str,
+        *,
         print_messages: bool = False,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
+        call_id: str,
     ) -> None:
-        if ctx is not None:
-            ctx.completions[self.agent_name].append(completion)
-            ctx.usage_tracker.update(
+        ctx.completions[self.agent_name].append(completion)
+        ctx.usage_tracker.update(
+            agent_name=self.agent_name,
+            completions=[completion],
+            model_name=self.llm.model_name,
+        )
+        if ctx.printer and print_messages:
+            usages = [None] * (len(completion.messages) - 1) + [completion.usage]
+            ctx.printer.print_messages(
+                completion.messages,
+                usages=usages,
                 agent_name=self.agent_name,
-                completions=[completion],
-                model_name=self.llm.model_name,
+                call_id=call_id,
             )
-            if ctx.printer and print_messages:
-                usages = [None] * (len(completion.messages) - 1) + [completion.usage]
-                ctx.printer.print_messages(
-                    completion.messages,
-                    usages=usages,
-                    agent_name=self.agent_name,
-                    call_id=call_id,
-                )
```
```diff
@@ -96,8 +96,10 @@ class OpenAIConverters(Converters):
         return from_api_tool_message(raw_message, name=name, **kwargs)

     @staticmethod
-    def to_tool(tool: BaseTool[BaseModel, Any, Any], **kwargs: Any) -> OpenAIToolParam:
-        return to_api_tool(tool, **kwargs)
+    def to_tool(
+        tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None, **kwargs: Any
+    ) -> OpenAIToolParam:
+        return to_api_tool(tool, strict=strict, **kwargs)

     @staticmethod
     def to_tool_choice(
```
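`to_tool` now threads an explicit `strict` flag into the API conversion, so strict JSON-schema mode is requested per call instead of mutating `tool.strict` on the instance (the old behavior removed from `OpenAILLM` below). A hypothetical call, where `my_tool` is any `BaseTool`:

```python
# Hypothetical: request strict schema validation at conversion time.
api_tool_param = OpenAIConverters.to_tool(my_tool, strict=True)
```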
```diff
@@ -3,9 +3,9 @@ import logging
 import os
 from collections.abc import AsyncIterator, Iterable, Mapping
 from copy import deepcopy
+from dataclasses import dataclass, field
 from typing import Any, Literal

-import httpx
 from openai import AsyncOpenAI, AsyncStream
 from openai._types import NOT_GIVEN  # type: ignore[import]
 from openai.lib.streaming.chat import (
@@ -15,8 +15,7 @@ from openai.lib.streaming.chat import ChatCompletionStreamState
 from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
 from pydantic import BaseModel

-from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings, LLMRateLimiter
-from ..http_client import AsyncHTTPClientParams
+from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
 from ..typing.tool import BaseTool
 from . import (
     OpenAICompletion,
@@ -61,7 +60,7 @@ def get_openai_compatible_providers() -> list[APIProvider]:


 class OpenAILLMSettings(CloudLLMSettings, total=False):
-    reasoning_effort: Literal["low", "medium", "high"] | None
+    reasoning_effort: Literal["disable", "minimal", "low", "medium", "high"] | None

     parallel_tool_calls: bool

@@ -90,105 +89,75 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
     # TODO: support audio


+@dataclass(frozen=True)
 class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
-    def __init__(
-        self,
-        # Base LLM args
-        model_name: str,
-        llm_settings: OpenAILLMSettings | None = None,
-        tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
-        response_schema: Any | None = None,
-        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
-        apply_response_schema_via_provider: bool = False,
-        model_id: str | None = None,
-        # Custom LLM provider
-        api_provider: APIProvider | None = None,
-        # Connection settings
-        max_client_retries: int = 2,
-        async_http_client: httpx.AsyncClient | None = None,
-        async_http_client_params: (
-            dict[str, Any] | AsyncHTTPClientParams | None
-        ) = None,
-        async_openai_client_params: dict[str, Any] | None = None,
-        # Rate limiting
-        rate_limiter: LLMRateLimiter | None = None,
-        # LLM response retries: try to regenerate to pass validation
-        max_response_retries: int = 1,
-    ) -> None:
+    converters: OpenAIConverters = field(default_factory=OpenAIConverters)
+    async_openai_client_params: dict[str, Any] | None = None
+    client: AsyncOpenAI = field(init=False)
+
+    def __post_init__(self):
+        super().__post_init__()
+
         openai_compatible_providers = get_openai_compatible_providers()

-        model_name_parts = model_name.split("/", 1)
-        if api_provider is not None:
-            provider_model_name = model_name
+        _api_provider = self.api_provider
+
+        model_name_parts = self.model_name.split("/", 1)
+        if _api_provider is not None:
+            _model_name = self.model_name
         elif len(model_name_parts) == 2:
             compat_providers_map = {
                 provider["name"]: provider for provider in openai_compatible_providers
             }
-            provider_name, provider_model_name = model_name_parts
+            provider_name, _model_name = model_name_parts
             if provider_name not in compat_providers_map:
                 raise ValueError(
                     f"API provider '{provider_name}' is not a supported OpenAI "
                     f"compatible provider. Supported providers are: "
                     f"{', '.join(compat_providers_map.keys())}"
                 )
-            api_provider = compat_providers_map[provider_name]
+            _api_provider = compat_providers_map[provider_name]
         else:
             raise ValueError(
                 "Model name must be in the format 'provider/model_name' or "
                 "you must provide an 'api_provider' argument."
             )

-        if llm_settings is not None:
-            stream_options = llm_settings.get("stream_options") or {}
+        if self.llm_settings is not None:
+            stream_options = self.llm_settings.get("stream_options") or {}
             stream_options["include_usage"] = True
-            _llm_settings = deepcopy(llm_settings)
+            _llm_settings = deepcopy(self.llm_settings)
             _llm_settings["stream_options"] = stream_options
         else:
             _llm_settings = OpenAILLMSettings(stream_options={"include_usage": True})

-        super().__init__(
-            model_name=provider_model_name,
-            model_id=model_id,
-            llm_settings=_llm_settings,
-            converters=OpenAIConverters(),
-            tools=tools,
-            response_schema=response_schema,
-            response_schema_by_xml_tag=response_schema_by_xml_tag,
-            apply_response_schema_via_provider=apply_response_schema_via_provider,
-            api_provider=api_provider,
-            async_http_client=async_http_client,
-            async_http_client_params=async_http_client_params,
-            rate_limiter=rate_limiter,
-            max_client_retries=max_client_retries,
-            max_response_retries=max_response_retries,
-        )
-
         response_schema_support: bool = any(
-            fnmatch.fnmatch(self._model_name, pat)
-            for pat in api_provider.get("response_schema_support") or []
+            fnmatch.fnmatch(_model_name, pat)
+            for pat in _api_provider.get("response_schema_support") or []
         )
-        if apply_response_schema_via_provider:
-            if self._tools:
-                for tool in self._tools.values():
-                    tool.strict = True
-            if not response_schema_support:
-                raise ValueError(
-                    "Native response schema validation is not supported for model "
-                    f"'{self._model_name}' by the API provider. Please set "
-                    "apply_response_schema_via_provider=False."
-                )
+        if self.apply_response_schema_via_provider and not response_schema_support:
+            raise ValueError(
+                "Native response schema validation is not supported for model "
+                f"'{_model_name}' by the API provider. Please set "
+                "apply_response_schema_via_provider=False."
+            )

-        _async_openai_client_params = deepcopy(async_openai_client_params or {})
-        if self._async_http_client is not None:
-            _async_openai_client_params["http_client"] = self._async_http_client
+        _async_openai_client_params = deepcopy(self.async_openai_client_params or {})
+        if self.async_http_client is not None:
+            _async_openai_client_params["http_client"] = self.async_http_client

-        self._client: AsyncOpenAI = AsyncOpenAI(
-            base_url=self.api_provider.get("base_url"),
-            api_key=self.api_provider.get("api_key"),
-            max_retries=max_client_retries,
+        _client = AsyncOpenAI(
+            base_url=_api_provider.get("base_url"),
+            api_key=_api_provider.get("api_key"),
+            max_retries=self.max_client_retries,
             **_async_openai_client_params,
         )

+        object.__setattr__(self, "model_name", _model_name)
+        object.__setattr__(self, "api_provider", _api_provider)
+        object.__setattr__(self, "llm_settings", _llm_settings)
+        object.__setattr__(self, "client", _client)
+
     async def _get_completion(
         self,
         api_messages: Iterable[OpenAIMessageParam],
```
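`OpenAILLM` becomes a frozen dataclass: `__init__` is generated, validation and client setup move into `__post_init__`, and the normalized values are written back with `object.__setattr__` because the instance is immutable. Assuming the `CloudLLM` base is likewise a dataclass (as the `super().__post_init__()` call implies) and that `openai` is among the registered compatible providers, construction might look like:

```python
# Sketch: the "provider/model" prefix selects an OpenAI-compatible
# provider; __post_init__ strips it and builds the AsyncOpenAI client.
llm = OpenAILLM(
    model_name="openai/gpt-4o",
    llm_settings=OpenAILLMSettings(reasoning_effort="low"),
)
# After __post_init__, llm.model_name == "gpt-4o" and
# llm.llm_settings["stream_options"] == {"include_usage": True}.
```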
```diff
@@ -203,9 +172,9 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         response_format = api_response_schema or NOT_GIVEN
         n = n_choices or NOT_GIVEN

-        if self._apply_response_schema_via_provider:
-            return await self._client.beta.chat.completions.parse(
-                model=self._model_name,
+        if self.apply_response_schema_via_provider:
+            return await self.client.beta.chat.completions.parse(
+                model=self.model_name,
                 messages=api_messages,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -214,8 +183,8 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
             **api_llm_settings,
         )

-        return await self._client.chat.completions.create(
-            model=self._model_name,
+        return await self.client.chat.completions.create(
+            model=self.model_name,
             messages=api_messages,
             tools=tools,
             tool_choice=tool_choice,
@@ -238,10 +207,10 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         response_format = api_response_schema or NOT_GIVEN
         n = n_choices or NOT_GIVEN

-        if self._apply_response_schema_via_provider:
+        if self.apply_response_schema_via_provider:
             stream_manager: OpenAIAsyncChatCompletionStreamManager[Any] = (
-                self._client.beta.chat.completions.stream(
-                    model=self._model_name,
+                self.client.beta.chat.completions.stream(
+                    model=self.model_name,
                     messages=api_messages,
                     tools=tools,
                     tool_choice=tool_choice,
@@ -257,8 +226,8 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         else:
             stream_generator: AsyncStream[
                 OpenAICompletionChunk
-            ] = await self._client.chat.completions.create(
-                model=self._model_name,
+            ] = await self.client.chat.completions.create(
+                model=self.model_name,
                 messages=api_messages,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -271,16 +240,20 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
             yield completion_chunk

     def combine_completion_chunks(
-        self, completion_chunks: list[OpenAICompletionChunk]
+        self,
+        completion_chunks: list[OpenAICompletionChunk],
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
     ) -> OpenAICompletion:
         response_format = NOT_GIVEN
         input_tools = NOT_GIVEN
-        if self._apply_response_schema_via_provider:
-            if self._response_schema:
-                response_format = self._response_schema
-            if self._tools:
+        if self.apply_response_schema_via_provider:
+            if response_schema:
+                response_format = response_schema
+            if tools:
                 input_tools = [
-                    self._converters.to_tool(tool) for tool in self._tools.values()
+                    self.converters.to_tool(tool, strict=True)
+                    for tool in tools.values()
                 ]
         state = ChatCompletionStreamState[Any](
             input_tools=input_tools, response_format=response_format
```
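`combine_completion_chunks` likewise stops reading `self._response_schema` and `self._tools`; the caller now supplies the schema and tools that were in effect for that particular stream. A hypothetical call, reusing names from the sketches above:

```python
# Hypothetical: combine streamed chunks under the request's own
# schema and tool set rather than instance-level state.
completion = llm.combine_completion_chunks(
    chunks,                    # list[OpenAICompletionChunk] from a stream
    response_schema=FinalAnswer,
    tools={"add": AddTool()},  # mapping of tool name -> BaseTool
)
```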