grasp_agents 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/cloud_llm.py CHANGED
@@ -1,7 +1,8 @@
 import logging
 from abc import abstractmethod
-from collections.abc import AsyncIterator, Mapping, Sequence
+from collections.abc import AsyncIterator, Mapping
 from copy import deepcopy
+from dataclasses import dataclass
 from typing import Any, Generic, Required, cast

 import httpx
@@ -58,111 +59,53 @@ LLMRateLimiter = RateLimiterC[
 ]


+@dataclass(frozen=True)
 class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
-    def __init__(
-        self,
-        # Base LLM args
-        model_name: str,
-        api_provider: APIProvider,
-        converters: ConvertT_co,
-        llm_settings: SettingsT_co | None = None,
-        tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
-        response_schema: Any | None = None,
-        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
-        apply_response_schema_via_provider: bool = True,
-        model_id: str | None = None,
-        # Connection settings
-        async_http_client: httpx.AsyncClient | None = None,
-        async_http_client_params: (
-            dict[str, Any] | AsyncHTTPClientParams | None
-        ) = None,
-        max_client_retries: int = 2,
-        # Rate limiting
-        rate_limiter: LLMRateLimiter | None = None,
-        # LLM response retries: try to regenerate to pass validation
-        max_response_retries: int = 0,
-        **kwargs: Any,
-    ) -> None:
-        self.llm_settings: CloudLLMSettings | None
-
-        super().__init__(
-            model_name=model_name,
-            llm_settings=llm_settings,
-            converters=converters,
-            model_id=model_id,
-            tools=tools,
-            response_schema=response_schema,
-            response_schema_by_xml_tag=response_schema_by_xml_tag,
-            **kwargs,
-        )
-
-        self._model_name = model_name
-        self._api_provider = api_provider
-        self._apply_response_schema_via_provider = apply_response_schema_via_provider
-
-        if (
-            apply_response_schema_via_provider
-            and response_schema_by_xml_tag is not None
-        ):
-            raise ValueError(
-                "Response schema by XML tag is not supported "
-                "when apply_response_schema_via_provider is True."
-            )
+    # Make this field keyword-only to avoid ordering issues with inherited defaulted fields
+    api_provider: APIProvider | None = None
+    llm_settings: SettingsT_co | None = None
+    rate_limiter: LLMRateLimiter | None = None
+    max_client_retries: int = 2  # HTTP client retries for network errors
+    max_response_retries: int = (
+        0  # LLM response retries: try to regenerate to pass validation
+    )
+    apply_response_schema_via_provider: bool = False
+    async_http_client: httpx.AsyncClient | None = None
+    async_http_client_params: dict[str, Any] | AsyncHTTPClientParams | None = None

-        self._rate_limiter: LLMRateLimiter | None = None
-        if rate_limiter is not None:
-            self._rate_limiter = rate_limiter
+    def __post_init__(self) -> None:
+        if self.rate_limiter is not None:
             logger.info(
-                f"[{self.__class__.__name__}] Set rate limit to {rate_limiter.rpm} RPM"
+                f"[{self.__class__.__name__}] Set rate limit to "
+                f"{self.rate_limiter.rpm} RPM"
             )

-        self._async_http_client: httpx.AsyncClient | None = None
-        if async_http_client is not None:
-            self._async_http_client = async_http_client
-        elif async_http_client_params is not None:
-            self._async_http_client = create_simple_async_httpx_client(
-                async_http_client_params
+        if self.async_http_client is None and self.async_http_client_params is not None:
+            object.__setattr__(
+                self,
+                "async_http_client",
+                create_simple_async_httpx_client(self.async_http_client_params),
             )

-        self.max_client_retries = max_client_retries
-        self.max_response_retries = max_response_retries
-
-    @property
-    def api_provider(self) -> APIProvider:
-        return self._api_provider
-
-    @property
-    def rate_limiter(self) -> LLMRateLimiter | None:
-        return self._rate_limiter
-
-    @property
-    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
-        return self._tools
-
-    @tools.setter
-    def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
-        if not tools:
-            self._tools = None
-            return
-        strict_value = True if self._apply_response_schema_via_provider else None
-        for t in tools:
-            t.strict = strict_value
-        self._tools = {t.name: t for t in tools}
-
     def _make_completion_kwargs(
         self,
         conversation: Messages,
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
     ) -> dict[str, Any]:
-        api_messages = [self._converters.to_message(m) for m in conversation]
+        api_messages = [self.converters.to_message(m) for m in conversation]

         api_tools = None
         api_tool_choice = None
-        if self.tools:
-            api_tools = [self._converters.to_tool(t) for t in self.tools.values()]
+        if tools:
+            strict = True if self.apply_response_schema_via_provider else None
+            api_tools = [
+                self.converters.to_tool(t, strict=strict) for t in tools.values()
+            ]
         if tool_choice is not None:
-            api_tool_choice = self._converters.to_tool_choice(tool_choice)
+            api_tool_choice = self.converters.to_tool_choice(tool_choice)

         api_llm_settings = deepcopy(self.llm_settings or {})
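CloudLLM's hand-written `__init__` is gone: configuration now lives in frozen dataclass fields, and derived state (the rate-limit log line, the lazily built HTTP client) moves into `__post_init__`. Because frozen dataclasses reject ordinary attribute assignment, the diff installs that state with `object.__setattr__`. A minimal sketch of the same pattern, using a hypothetical `HTTPConfig` class rather than the real grasp_agents types:

from dataclasses import dataclass

import httpx


@dataclass(frozen=True)
class HTTPConfig:
    # Hypothetical example of the frozen-dataclass + __post_init__ pattern above.
    base_url: str
    timeout: float = 30.0
    client: httpx.AsyncClient | None = None

    def __post_init__(self) -> None:
        # Plain `self.client = ...` would raise FrozenInstanceError on a frozen
        # dataclass, so the derived field is installed via object.__setattr__.
        if self.client is None:
            object.__setattr__(
                self,
                "client",
                httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout),
            )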
 
@@ -170,7 +113,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             api_messages=api_messages,
             api_tools=api_tools,
             api_tool_choice=api_tool_choice,
-            api_response_schema=self._response_schema,
+            api_response_schema=response_schema,
             n_choices=n_choices,
             **api_llm_settings,
         )
@@ -206,24 +149,34 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         conversation: Messages,
         *,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
     ) -> Completion:
         completion_kwargs = self._make_completion_kwargs(
-            conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
+            conversation=conversation,
+            response_schema=response_schema,
+            tools=tools,
+            tool_choice=tool_choice,
+            n_choices=n_choices,
         )

-        if not self._apply_response_schema_via_provider:
+        if not self.apply_response_schema_via_provider:
             completion_kwargs.pop("api_response_schema", None)
         api_completion = await self._get_completion(**completion_kwargs)

-        completion = self._converters.from_completion(
-            api_completion, name=self.model_id
-        )
+        completion = self.converters.from_completion(api_completion, name=self.model_id)

-        if not self._apply_response_schema_via_provider:
-            self._validate_response(completion)
-            self._validate_tool_calls(completion)
+        if not self.apply_response_schema_via_provider:
+            self._validate_response(
+                completion,
+                response_schema=response_schema,
+                response_schema_by_xml_tag=response_schema_by_xml_tag,
+            )
+            if tools is not None:
+                self._validate_tool_calls(completion, tools=tools)

         return completion
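With this change, `response_schema`, `response_schema_by_xml_tag`, and `tools` are per-call arguments rather than constructor state, so a single `CloudLLM` instance can serve different schemas and tool sets across requests. A hedged usage sketch: the public wrapper name `generate_completion` and the tool variable are assumptions, only the keyword arguments come from the diff above:

from pydantic import BaseModel


class WeatherReport(BaseModel):
    city: str
    temperature_c: float


async def ask_weather(llm, conversation, weather_tool):
    # Assumed public wrapper; the keyword arguments mirror
    # _generate_completion_once in the hunk above.
    return await llm.generate_completion(
        conversation,
        response_schema=WeatherReport,            # validated per call, not per instance
        tools={weather_tool.name: weather_tool},  # Mapping[str, BaseTool[...]]
        tool_choice=None,
    )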
 
@@ -231,6 +184,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         conversation: Messages,
         *,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -241,6 +197,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
            try:
                return await self._generate_completion_once(
                    conversation,  # type: ignore[return]
+                    response_schema=response_schema,
+                    response_schema_by_xml_tag=response_schema_by_xml_tag,
+                    tools=tools,
                    tool_choice=tool_choice,
                    n_choices=n_choices,
                )
@@ -263,7 +222,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                )

        return make_refusal_completion(
-            self._model_name,
+            self.model_name,
            Exception("Unexpected error: retry loop exited without returning"),
        )

@@ -272,15 +231,22 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         conversation: Messages,
         *,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
         call_id: str | None = None,
     ) -> AsyncIterator[CompletionChunkEvent[CompletionChunk] | CompletionEvent]:
         completion_kwargs = self._make_completion_kwargs(
-            conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
+            conversation=conversation,
+            response_schema=response_schema,
+            tools=tools,
+            tool_choice=tool_choice,
+            n_choices=n_choices,
         )
-        if not self._apply_response_schema_via_provider:
+        if not self.apply_response_schema_via_provider:
             completion_kwargs.pop("api_response_schema", None)

         api_stream = self._get_completion_stream(**completion_kwargs)
@@ -293,7 +259,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co

            async for api_completion_chunk in api_stream:
                api_completion_chunks.append(api_completion_chunk)
-                completion_chunk = self._converters.from_completion_chunk(
+                completion_chunk = self.converters.from_completion_chunk(
                    api_completion_chunk, name=self.model_id
                )

@@ -301,16 +267,23 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                    data=completion_chunk, proc_name=proc_name, call_id=call_id
                )

-            api_completion = self.combine_completion_chunks(api_completion_chunks)
-            completion = self._converters.from_completion(
+            api_completion = self.combine_completion_chunks(
+                api_completion_chunks, response_schema=response_schema, tools=tools
+            )
+            completion = self.converters.from_completion(
                api_completion, name=self.model_id
            )

            yield CompletionEvent(data=completion, proc_name=proc_name, call_id=call_id)

-            if not self._apply_response_schema_via_provider:
-                self._validate_response(completion)
-                self._validate_tool_calls(completion)
+            if not self.apply_response_schema_via_provider:
+                self._validate_response(
+                    completion,
+                    response_schema=response_schema,
+                    response_schema_by_xml_tag=response_schema_by_xml_tag,
+                )
+                if tools is not None:
+                    self._validate_tool_calls(completion, tools=tools)

        return iterator()
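The streaming path keeps the same per-call validation but wraps everything in a nested async generator returned from the method: chunk events are yielded as they arrive, and the combined completion is validated only at the end. A self-contained sketch of that shape, with plain strings standing in for the chunk and completion types:

from collections.abc import AsyncIterator


class StreamingShape:
    # Illustrative stand-in for the CloudLLM streaming structure above,
    # not the actual grasp_agents implementation.
    async def _get_chunks(self) -> AsyncIterator[str]:
        for piece in ("Hel", "lo"):
            yield piece

    async def stream(self) -> AsyncIterator[str]:
        chunks: list[str] = []

        async def iterator() -> AsyncIterator[str]:
            async for chunk in self._get_chunks():
                chunks.append(chunk)
                yield chunk               # analogue of CompletionChunkEvent
            combined = "".join(chunks)    # analogue of combine_completion_chunks
            # per-call validation of the combined result would run here
            yield combined                # analogue of the final CompletionEvent

        return iterator()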
 
@@ -318,6 +291,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         conversation: Messages,
         *,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -330,6 +306,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
            try:
                async for event in await self._generate_completion_stream_once(  # type: ignore[return]
                    conversation,  # type: ignore[arg-type]
+                    response_schema=response_schema,
+                    response_schema_by_xml_tag=response_schema_by_xml_tag,
+                    tools=tools,
                    tool_choice=tool_choice,
                    n_choices=n_choices,
                    proc_name=proc_name,
@@ -339,7 +318,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                return
            except (LLMResponseValidationError, LLMToolCallValidationError) as err:
                err_data = LLMStreamingErrorData(
-                    error=err, model_name=self._model_name, model_id=self.model_id
+                    error=err, model_name=self.model_name, model_id=self.model_id
                )
                yield LLMStreamingErrorEvent(
                    data=err_data, proc_name=proc_name, call_id=call_id
@@ -355,7 +334,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                    f"retrying:\n{err}"
                )
                refusal_completion = make_refusal_completion(
-                    self._model_name, err
+                    self.model_name, err
                )
                yield CompletionEvent(
                    data=refusal_completion,
@@ -118,8 +118,10 @@ class LiteLLMConverters(Converters):
         return from_api_tool_message(raw_message, name=name, **kwargs)

     @staticmethod
-    def to_tool(tool: BaseTool[BaseModel, Any, Any], **kwargs: Any) -> OpenAIToolParam:
-        return to_api_tool(tool, **kwargs)
+    def to_tool(
+        tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None, **kwargs: Any
+    ) -> OpenAIToolParam:
+        return to_api_tool(tool, strict=strict, **kwargs)

     @staticmethod
     def to_tool_choice(
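`strict` now flows into the converter at call time: `_make_completion_kwargs` passes `strict=True` only when `apply_response_schema_via_provider` is enabled, instead of mutating `tool.strict` up front. For orientation, a rough, assumption-laden sketch of what a strict OpenAI-style function-tool param tends to look like; the real construction happens in `to_api_tool`, which this diff does not show:

from typing import Any

from pydantic import BaseModel


def sketch_tool_param(
    name: str, args_model: type[BaseModel], strict: bool | None
) -> dict[str, Any]:
    # Hypothetical sketch (not grasp_agents' to_api_tool): a function tool param
    # carrying the JSON schema of its arguments, with "strict" set only when
    # provider-side schema enforcement is requested.
    function: dict[str, Any] = {
        "name": name,
        "parameters": args_model.model_json_schema(),
    }
    if strict is not None:
        function["strict"] = strict
    return {"type": "function", "function": function}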
@@ -1,5 +1,8 @@
 import logging
+from collections import defaultdict
 from collections.abc import AsyncIterator, Mapping
+from copy import deepcopy
+from dataclasses import dataclass, field
 from typing import Any, cast

 import litellm
@@ -19,7 +22,7 @@ from litellm.utils import (
 # from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
 from pydantic import BaseModel

-from ..cloud_llm import APIProvider, CloudLLM, LLMRateLimiter
+from ..cloud_llm import APIProvider, CloudLLM
 from ..openai.openai_llm import OpenAILLMSettings
 from ..typing.tool import BaseTool
 from . import (
@@ -38,107 +41,101 @@ class LiteLLMSettings(OpenAILLMSettings, total=False):
     thinking: AnthropicThinkingParam | None


+@dataclass(frozen=True)
 class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
-    def __init__(
-        self,
-        # Base LLM args
-        model_name: str,
-        model_id: str | None = None,
-        llm_settings: LiteLLMSettings | None = None,
-        tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
-        response_schema: Any | None = None,
-        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
-        apply_response_schema_via_provider: bool = False,
-        # LLM provider
-        api_provider: APIProvider | None = None,
-        # deployment_id: str | None = None,
-        # api_version: str | None = None,
-        # Connection settings
-        timeout: float | None = None,
-        max_client_retries: int = 2,
-        # Rate limiting
-        rate_limiter: LLMRateLimiter | None = None,
-        # Drop unsupported LLM settings
-        drop_params: bool = True,
-        additional_drop_params: list[str] | None = None,
-        allowed_openai_params: list[str] | None = None,
-        # Mock LLM response for testing
-        mock_response: str | None = None,
-        # LLM response retries: try to regenerate to pass validation
-        max_response_retries: int = 1,
-    ) -> None:
-        self._lite_llm_completion_params: dict[str, Any] = {
-            "max_retries": max_client_retries,
-            "timeout": timeout,
-            "drop_params": drop_params,
-            "additional_drop_params": additional_drop_params,
-            "allowed_openai_params": allowed_openai_params,
-            "mock_response": mock_response,
-            # "deployment_id": deployment_id,
-            # "api_version": api_version,
-        }
-
-        if model_name in litellm.get_valid_models():  # type: ignore[no-untyped-call]
-            _, provider_name, _, _ = litellm.get_llm_provider(model_name)  # type: ignore[no-untyped-call]
-            api_provider = APIProvider(name=provider_name)
-        elif api_provider is not None:
-            self._lite_llm_completion_params["api_key"] = api_provider.get("api_key")
-            self._lite_llm_completion_params["api_base"] = api_provider.get("api_base")
-        elif api_provider is None:
+    converters: LiteLLMConverters = field(default_factory=LiteLLMConverters)
+
+    timeout: float | None = None
+    # Drop unsupported LLM settings
+    drop_params: bool = True
+    additional_drop_params: list[str] | None = None
+    allowed_openai_params: list[str] | None = None
+    # Mock LLM response for testing
+    mock_response: str | None = None
+
+    _lite_llm_completion_params: dict[str, Any] = field(
+        default_factory=dict[str, Any], init=False, repr=False, compare=False
+    )
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+
+        self._lite_llm_completion_params.update(
+            {
+                "max_retries": self.max_client_retries,
+                "timeout": self.timeout,
+                "drop_params": self.drop_params,
+                "additional_drop_params": self.additional_drop_params,
+                "allowed_openai_params": self.allowed_openai_params,
+                "mock_response": self.mock_response,
+                # "deployment_id": deployment_id,
+                # "api_version": api_version,
+            }
+        )
+
+        _api_provider = self.api_provider
+
+        if self.model_name in litellm.get_valid_models():  # type: ignore[no-untyped-call]
+            _, provider_name, _, _ = litellm.get_llm_provider(self.model_name)  # type: ignore[no-untyped-call]
+            _api_provider = APIProvider(name=provider_name)
+        elif self.api_provider is not None:
+            self._lite_llm_completion_params["api_key"] = self.api_provider.get(
+                "api_key"
+            )
+            self._lite_llm_completion_params["api_base"] = self.api_provider.get(
+                "api_base"
+            )
+        elif self.api_provider is None:
             raise ValueError(
-                f"Model '{model_name}' is not supported by LiteLLM and no API provider "
+                f"Model '{self.model_name}' is not supported by LiteLLM and no API provider "
                 "was specified. Please provide a valid API provider or use a different "
                 "model."
             )
-        super().__init__(
-            model_name=model_name,
-            model_id=model_id,
-            llm_settings=llm_settings,
-            converters=LiteLLMConverters(),
-            tools=tools,
-            response_schema=response_schema,
-            response_schema_by_xml_tag=response_schema_by_xml_tag,
-            apply_response_schema_via_provider=apply_response_schema_via_provider,
-            api_provider=api_provider,
-            rate_limiter=rate_limiter,
-            max_client_retries=max_client_retries,
-            max_response_retries=max_response_retries,
-        )

-        if self._apply_response_schema_via_provider:
-            if self._tools:
-                for tool in self._tools.values():
-                    tool.strict = True
-            if not self.supports_response_schema:
-                raise ValueError(
-                    f"Model '{self._model_name}' does not support response schema "
-                    "natively. Please set `apply_response_schema_via_provider=False`"
-                )
+        if self.llm_settings is not None:
+            stream_options = self.llm_settings.get("stream_options") or {}
+            stream_options["include_usage"] = True
+            _llm_settings = deepcopy(self.llm_settings)
+            _llm_settings["stream_options"] = stream_options
+        else:
+            _llm_settings = LiteLLMSettings(stream_options={"include_usage": True})
+
+        if (
+            self.apply_response_schema_via_provider
+            and not self.supports_response_schema
+        ):
+            raise ValueError(
+                f"Model '{self.model_name}' does not support response schema "
+                "natively. Please set `apply_response_schema_via_provider=False`"
+            )
+
+        object.__setattr__(self, "api_provider", _api_provider)
+        object.__setattr__(self, "llm_settings", _llm_settings)

     def get_supported_openai_params(self) -> list[Any] | None:
         return get_supported_openai_params(  # type: ignore[no-untyped-call]
-            model=self._model_name, request_type="chat_completion"
+            model=self.model_name, request_type="chat_completion"
         )

     @property
     def supports_reasoning(self) -> bool:
-        return supports_reasoning(model=self._model_name)
+        return supports_reasoning(model=self.model_name)

     @property
     def supports_parallel_function_calling(self) -> bool:
-        return supports_parallel_function_calling(model=self._model_name)
+        return supports_parallel_function_calling(model=self.model_name)

     @property
     def supports_prompt_caching(self) -> bool:
-        return supports_prompt_caching(model=self._model_name)
+        return supports_prompt_caching(model=self.model_name)

     @property
     def supports_response_schema(self) -> bool:
-        return supports_response_schema(model=self._model_name)
+        return supports_response_schema(model=self.model_name)

     @property
     def supports_tool_choice(self) -> bool:
-        return supports_tool_choice(model=self._model_name)
+        return supports_tool_choice(model=self.model_name)

     # # client
     # model_list: Optional[list] = (None,)  # pass in a list of api_base,keys, etc.
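LiteLLM follows the same dataclass conversion: `converters` defaults through `field(default_factory=LiteLLMConverters)`, the private `_lite_llm_completion_params` dict becomes an `init=False` field populated in `__post_init__`, and `stream_options["include_usage"]` is forced on so streamed completions report usage. A hedged construction sketch; the import path and the model string are assumptions, and `model_name`/`model_id` come from the base `LLM` dataclass, which this diff does not show:

from grasp_agents.litellm import LiteLLM  # import path assumed

llm = LiteLLM(
    model_name="openai/gpt-4o-mini",  # provider resolved via litellm.get_llm_provider
    apply_response_schema_via_provider=False,
    max_response_retries=2,
    timeout=30.0,
)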
@@ -153,7 +150,7 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
         **api_llm_settings: Any,
     ) -> LiteLLMCompletion:
         completion = await litellm.acompletion(  # type: ignore[no-untyped-call]
-            model=self._model_name,
+            model=self.model_name,
             messages=api_messages,
             tools=api_tools,
             tool_choice=api_tool_choice,  # type: ignore[arg-type]
@@ -180,7 +177,7 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
         **api_llm_settings: Any,
     ) -> AsyncIterator[LiteLLMCompletionChunk]:
         stream = await litellm.acompletion(  # type: ignore[no-untyped-call]
-            model=self._model_name,
+            model=self.model_name,
             messages=api_messages,
             tools=api_tools,
             tool_choice=api_tool_choice,  # type: ignore[arg-type]
@@ -192,11 +189,24 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
         )
         stream = cast("CustomStreamWrapper", stream)

+        tc_indices: dict[int, set[int]] = defaultdict(set)
+
         async for completion_chunk in stream:
+            # Fix tool call indices to be unique within each choice
+            for n, choice in enumerate(completion_chunk.choices):
+                for tc in choice.delta.tool_calls or []:
+                    # Tool call ID is not None only when it is a new tool call
+                    if tc.id and tc.index in tc_indices[n]:
+                        tc.index = max(tc_indices[n]) + 1
+                    tc_indices[n].add(tc.index)
+
             yield completion_chunk

     def combine_completion_chunks(
-        self, completion_chunks: list[LiteLLMCompletionChunk]
+        self,
+        completion_chunks: list[LiteLLMCompletionChunk],
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
     ) -> LiteLLMCompletion:
         combined_chunk = cast(
             "LiteLLMCompletion",