grasp_agents 0.5.9.tar.gz → 0.5.10.tar.gz

This diff shows the changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only.
Files changed (59)
  1. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/PKG-INFO +1 -1
  2. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/pyproject.toml +1 -1
  3. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/cloud_llm.py +88 -109
  4. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/litellm/converters.py +4 -2
  5. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/litellm/lite_llm.py +72 -83
  6. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/llm.py +35 -68
  7. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/llm_agent.py +32 -36
  8. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/llm_agent_memory.py +3 -2
  9. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/llm_policy_executor.py +63 -33
  10. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/openai/converters.py +4 -2
  11. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/openai/openai_llm.py +60 -87
  12. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/openai/tool_converters.py +6 -4
  13. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/processors/base_processor.py +18 -10
  14. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/processors/parallel_processor.py +8 -6
  15. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/processors/processor.py +10 -6
  16. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/prompt_builder.py +22 -28
  17. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/run_context.py +1 -1
  18. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/runner.py +1 -1
  19. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/typing/converters.py +3 -1
  20. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/typing/tool.py +13 -5
  21. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/workflow/workflow_processor.py +4 -4
  22. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/.gitignore +0 -0
  23. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/LICENSE.md +0 -0
  24. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/README.md +0 -0
  25. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/__init__.py +0 -0
  26. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/costs_dict.yaml +0 -0
  27. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/errors.py +0 -0
  28. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/generics_utils.py +0 -0
  29. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/grasp_logging.py +0 -0
  30. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/http_client.py +0 -0
  31. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/litellm/__init__.py +0 -0
  32. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/litellm/completion_chunk_converters.py +0 -0
  33. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/litellm/completion_converters.py +0 -0
  34. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/litellm/message_converters.py +0 -0
  35. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/memory.py +0 -0
  36. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/openai/__init__.py +0 -0
  37. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/openai/completion_chunk_converters.py +0 -0
  38. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/openai/completion_converters.py +0 -0
  39. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/openai/content_converters.py +0 -0
  40. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/openai/message_converters.py +0 -0
  41. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/packet.py +0 -0
  42. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/packet_pool.py +0 -0
  43. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/printer.py +0 -0
  44. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/rate_limiting/__init__.py +0 -0
  45. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py +0 -0
  46. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/rate_limiting/types.py +0 -0
  47. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/rate_limiting/utils.py +0 -0
  48. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/typing/__init__.py +0 -0
  49. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/typing/completion.py +0 -0
  50. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/typing/completion_chunk.py +0 -0
  51. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/typing/content.py +0 -0
  52. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/typing/events.py +0 -0
  53. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/typing/io.py +0 -0
  54. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/typing/message.py +0 -0
  55. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/usage_tracker.py +0 -0
  56. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/utils.py +0 -0
  57. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/workflow/__init__.py +0 -0
  58. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/workflow/looped_workflow.py +0 -0
  59. {grasp_agents-0.5.9 → grasp_agents-0.5.10}/src/grasp_agents/workflow/sequential_workflow.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: grasp_agents
-Version: 0.5.9
+Version: 0.5.10
 Summary: Grasp Agents Library
 License-File: LICENSE.md
 Requires-Python: <4,>=3.11.4
@@ -1,6 +1,6 @@
 [project]
 name = "grasp_agents"
-version = "0.5.9"
+version = "0.5.10"
 description = "Grasp Agents Library"
 readme = "README.md"
 requires-python = ">=3.11.4,<4"
@@ -1,7 +1,8 @@
 import logging
 from abc import abstractmethod
-from collections.abc import AsyncIterator, Mapping, Sequence
+from collections.abc import AsyncIterator, Mapping
 from copy import deepcopy
+from dataclasses import dataclass
 from typing import Any, Generic, Required, cast
 
 import httpx
@@ -58,111 +59,53 @@ LLMRateLimiter = RateLimiterC[
 ]
 
 
+@dataclass(frozen=True)
 class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
-    def __init__(
-        self,
-        # Base LLM args
-        model_name: str,
-        api_provider: APIProvider,
-        converters: ConvertT_co,
-        llm_settings: SettingsT_co | None = None,
-        tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
-        response_schema: Any | None = None,
-        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
-        apply_response_schema_via_provider: bool = True,
-        model_id: str | None = None,
-        # Connection settings
-        async_http_client: httpx.AsyncClient | None = None,
-        async_http_client_params: (
-            dict[str, Any] | AsyncHTTPClientParams | None
-        ) = None,
-        max_client_retries: int = 2,
-        # Rate limiting
-        rate_limiter: LLMRateLimiter | None = None,
-        # LLM response retries: try to regenerate to pass validation
-        max_response_retries: int = 0,
-        **kwargs: Any,
-    ) -> None:
-        self.llm_settings: CloudLLMSettings | None
-
-        super().__init__(
-            model_name=model_name,
-            llm_settings=llm_settings,
-            converters=converters,
-            model_id=model_id,
-            tools=tools,
-            response_schema=response_schema,
-            response_schema_by_xml_tag=response_schema_by_xml_tag,
-            **kwargs,
-        )
-
-        self._model_name = model_name
-        self._api_provider = api_provider
-        self._apply_response_schema_via_provider = apply_response_schema_via_provider
-
-        if (
-            apply_response_schema_via_provider
-            and response_schema_by_xml_tag is not None
-        ):
-            raise ValueError(
-                "Response schema by XML tag is not supported "
-                "when apply_response_schema_via_provider is True."
-            )
+    # Make this field keyword-only to avoid ordering issues with inherited defaulted fields
+    api_provider: APIProvider | None = None
+    llm_settings: SettingsT_co | None = None
+    rate_limiter: LLMRateLimiter | None = None
+    max_client_retries: int = 2  # HTTP client retries for network errors
+    max_response_retries: int = (
+        0  # LLM response retries: try to regenerate to pass validation
+    )
+    apply_response_schema_via_provider: bool = False
+    async_http_client: httpx.AsyncClient | None = None
+    async_http_client_params: dict[str, Any] | AsyncHTTPClientParams | None = None
 
-        self._rate_limiter: LLMRateLimiter | None = None
-        if rate_limiter is not None:
-            self._rate_limiter = rate_limiter
+    def __post_init__(self) -> None:
+        if self.rate_limiter is not None:
             logger.info(
-                f"[{self.__class__.__name__}] Set rate limit to {rate_limiter.rpm} RPM"
+                f"[{self.__class__.__name__}] Set rate limit to "
+                f"{self.rate_limiter.rpm} RPM"
             )
 
-        self._async_http_client: httpx.AsyncClient | None = None
-        if async_http_client is not None:
-            self._async_http_client = async_http_client
-        elif async_http_client_params is not None:
-            self._async_http_client = create_simple_async_httpx_client(
-                async_http_client_params
+        if self.async_http_client is None and self.async_http_client_params is not None:
+            object.__setattr__(
+                self,
+                "async_http_client",
+                create_simple_async_httpx_client(self.async_http_client_params),
            )
 
-        self.max_client_retries = max_client_retries
-        self.max_response_retries = max_response_retries
-
-    @property
-    def api_provider(self) -> APIProvider:
-        return self._api_provider
-
-    @property
-    def rate_limiter(self) -> LLMRateLimiter | None:
-        return self._rate_limiter
-
-    @property
-    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
-        return self._tools
-
-    @tools.setter
-    def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
-        if not tools:
-            self._tools = None
-            return
-        strict_value = True if self._apply_response_schema_via_provider else None
-        for t in tools:
-            t.strict = strict_value
-        self._tools = {t.name: t for t in tools}
-
     def _make_completion_kwargs(
         self,
         conversation: Messages,
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
     ) -> dict[str, Any]:
-        api_messages = [self._converters.to_message(m) for m in conversation]
+        api_messages = [self.converters.to_message(m) for m in conversation]
 
         api_tools = None
         api_tool_choice = None
-        if self.tools:
-            api_tools = [self._converters.to_tool(t) for t in self.tools.values()]
+        if tools:
+            strict = True if self.apply_response_schema_via_provider else None
+            api_tools = [
+                self.converters.to_tool(t, strict=strict) for t in tools.values()
+            ]
         if tool_choice is not None:
-            api_tool_choice = self._converters.to_tool_choice(tool_choice)
+            api_tool_choice = self.converters.to_tool_choice(tool_choice)
 
         api_llm_settings = deepcopy(self.llm_settings or {})
 
@@ -170,7 +113,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             api_messages=api_messages,
             api_tools=api_tools,
             api_tool_choice=api_tool_choice,
-            api_response_schema=self._response_schema,
+            api_response_schema=response_schema,
             n_choices=n_choices,
             **api_llm_settings,
         )
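
The CloudLLM hunks above replace the hand-written __init__ with a frozen dataclass: configuration becomes declared fields, and derived state is filled in __post_init__ via object.__setattr__, since frozen instances reject ordinary attribute assignment. Below is a minimal, self-contained sketch of that pattern; the HttpClient, make_client, and Service names are illustrative stand-ins, not part of grasp_agents.

from dataclasses import dataclass
from typing import Any


class HttpClient:
    """Stand-in for httpx.AsyncClient in this sketch."""

    def __init__(self, **params: Any) -> None:
        self.params = params


def make_client(params: dict[str, Any]) -> HttpClient:
    # Stand-in for create_simple_async_httpx_client.
    return HttpClient(**params)


@dataclass(frozen=True)
class Service:
    # Plain configuration becomes declared, defaulted fields.
    name: str = "default"
    client: HttpClient | None = None
    client_params: dict[str, Any] | None = None

    def __post_init__(self) -> None:
        # A frozen dataclass forbids `self.client = ...`, so derived state
        # is written with object.__setattr__, as in the diff above.
        if self.client is None and self.client_params is not None:
            object.__setattr__(self, "client", make_client(self.client_params))


svc = Service(name="demo", client_params={"timeout": 10})
assert svc.client is not None  # derived in __post_init__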
@@ -206,24 +149,34 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         conversation: Messages,
         *,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
     ) -> Completion:
         completion_kwargs = self._make_completion_kwargs(
-            conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
+            conversation=conversation,
+            response_schema=response_schema,
+            tools=tools,
+            tool_choice=tool_choice,
+            n_choices=n_choices,
         )
 
-        if not self._apply_response_schema_via_provider:
+        if not self.apply_response_schema_via_provider:
             completion_kwargs.pop("api_response_schema", None)
         api_completion = await self._get_completion(**completion_kwargs)
 
-        completion = self._converters.from_completion(
-            api_completion, name=self.model_id
-        )
+        completion = self.converters.from_completion(api_completion, name=self.model_id)
 
-        if not self._apply_response_schema_via_provider:
-            self._validate_response(completion)
-            self._validate_tool_calls(completion)
+        if not self.apply_response_schema_via_provider:
+            self._validate_response(
+                completion,
+                response_schema=response_schema,
+                response_schema_by_xml_tag=response_schema_by_xml_tag,
+            )
+            if tools is not None:
+                self._validate_tool_calls(completion, tools=tools)
 
         return completion
 
@@ -231,6 +184,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         conversation: Messages,
         *,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -241,6 +197,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             try:
                 return await self._generate_completion_once(
                     conversation,  # type: ignore[return]
+                    response_schema=response_schema,
+                    response_schema_by_xml_tag=response_schema_by_xml_tag,
+                    tools=tools,
                     tool_choice=tool_choice,
                     n_choices=n_choices,
                 )
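
In the retry hunks above, response_schema and tools are now threaded through each call rather than stored on the LLM instance, and the loop regenerates the completion until validation succeeds (up to max_response_retries). A rough sketch of that validate-and-retry idea using Pydantic for the schema check follows; generate and Answer are hypothetical stand-ins, not the grasp_agents API.

import json

from pydantic import BaseModel, ValidationError


class Answer(BaseModel):
    city: str
    population: int


def generate(prompt: str, attempt: int) -> str:
    # Hypothetical stand-in for one LLM call; returns malformed JSON first.
    return "not json" if attempt == 0 else '{"city": "Paris", "population": 2100000}'


def generate_validated(prompt: str, max_response_retries: int = 1) -> Answer:
    last_err: Exception | None = None
    for attempt in range(max_response_retries + 1):
        raw = generate(prompt, attempt)
        try:
            # Per-call schema: the validator is chosen by the caller,
            # not fixed on the client instance.
            return Answer.model_validate(json.loads(raw))
        except (json.JSONDecodeError, ValidationError) as err:
            last_err = err  # regenerate and try again
    raise RuntimeError(f"validation failed after retries: {last_err}")


print(generate_validated("Largest French city?", max_response_retries=1))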
@@ -263,7 +222,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
        )
 
        return make_refusal_completion(
-            self._model_name,
+            self.model_name,
            Exception("Unexpected error: retry loop exited without returning"),
        )
 
@@ -272,15 +231,22 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         conversation: Messages,
         *,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
         call_id: str | None = None,
     ) -> AsyncIterator[CompletionChunkEvent[CompletionChunk] | CompletionEvent]:
         completion_kwargs = self._make_completion_kwargs(
-            conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
+            conversation=conversation,
+            response_schema=response_schema,
+            tools=tools,
+            tool_choice=tool_choice,
+            n_choices=n_choices,
         )
-        if not self._apply_response_schema_via_provider:
+        if not self.apply_response_schema_via_provider:
             completion_kwargs.pop("api_response_schema", None)
 
         api_stream = self._get_completion_stream(**completion_kwargs)
@@ -293,7 +259,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
 
             async for api_completion_chunk in api_stream:
                 api_completion_chunks.append(api_completion_chunk)
-                completion_chunk = self._converters.from_completion_chunk(
+                completion_chunk = self.converters.from_completion_chunk(
                     api_completion_chunk, name=self.model_id
                 )
 
@@ -301,16 +267,23 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                     data=completion_chunk, proc_name=proc_name, call_id=call_id
                 )
 
-            api_completion = self.combine_completion_chunks(api_completion_chunks)
-            completion = self._converters.from_completion(
+            api_completion = self.combine_completion_chunks(
+                api_completion_chunks, response_schema=response_schema, tools=tools
+            )
+            completion = self.converters.from_completion(
                 api_completion, name=self.model_id
             )
 
             yield CompletionEvent(data=completion, proc_name=proc_name, call_id=call_id)
 
-            if not self._apply_response_schema_via_provider:
-                self._validate_response(completion)
-                self._validate_tool_calls(completion)
+            if not self.apply_response_schema_via_provider:
+                self._validate_response(
+                    completion,
+                    response_schema=response_schema,
+                    response_schema_by_xml_tag=response_schema_by_xml_tag,
+                )
+                if tools is not None:
+                    self._validate_tool_calls(completion, tools=tools)
 
         return iterator()
 
@@ -318,6 +291,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         conversation: Messages,
         *,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -330,6 +306,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             try:
                 async for event in await self._generate_completion_stream_once(  # type: ignore[return]
                     conversation,  # type: ignore[arg-type]
+                    response_schema=response_schema,
+                    response_schema_by_xml_tag=response_schema_by_xml_tag,
+                    tools=tools,
                     tool_choice=tool_choice,
                     n_choices=n_choices,
                     proc_name=proc_name,
@@ -339,7 +318,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                 return
             except (LLMResponseValidationError, LLMToolCallValidationError) as err:
                 err_data = LLMStreamingErrorData(
-                    error=err, model_name=self._model_name, model_id=self.model_id
+                    error=err, model_name=self.model_name, model_id=self.model_id
                 )
                 yield LLMStreamingErrorEvent(
                     data=err_data, proc_name=proc_name, call_id=call_id
@@ -355,7 +334,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                         f"retrying:\n{err}"
                     )
                    refusal_completion = make_refusal_completion(
-                        self._model_name, err
+                        self.model_name, err
                    )
                    yield CompletionEvent(
                        data=refusal_completion,
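
The streaming hunks wrap the provider stream in an inner async generator that yields one event per chunk, then a final event carrying the combined completion, with validation only at the end. A simplified sketch of that wrapping pattern under assumed names (ChunkEvent, DoneEvent, and fake_provider_stream are invented for illustration):

import asyncio
from collections.abc import AsyncIterator
from dataclasses import dataclass


@dataclass
class ChunkEvent:
    text: str


@dataclass
class DoneEvent:
    text: str


async def fake_provider_stream() -> AsyncIterator[str]:
    # Stand-in for the provider's chunk stream.
    for piece in ["Hel", "lo", " world"]:
        yield piece


async def stream_events() -> AsyncIterator[ChunkEvent | DoneEvent]:
    # Yield an event per chunk, accumulate, then yield one final event,
    # the same shape as the inner iterator() in the diff above.
    pieces: list[str] = []
    async for piece in fake_provider_stream():
        pieces.append(piece)
        yield ChunkEvent(text=piece)
    combined = "".join(pieces)
    # Validation of the combined completion would happen here.
    yield DoneEvent(text=combined)


async def main() -> None:
    async for event in stream_events():
        print(event)


asyncio.run(main())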
@@ -118,8 +118,10 @@ class LiteLLMConverters(Converters):
         return from_api_tool_message(raw_message, name=name, **kwargs)
 
     @staticmethod
-    def to_tool(tool: BaseTool[BaseModel, Any, Any], **kwargs: Any) -> OpenAIToolParam:
-        return to_api_tool(tool, **kwargs)
+    def to_tool(
+        tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None, **kwargs: Any
+    ) -> OpenAIToolParam:
+        return to_api_tool(tool, strict=strict, **kwargs)
 
     @staticmethod
     def to_tool_choice(
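
to_tool now accepts an explicit strict flag (derived from apply_response_schema_via_provider in _make_completion_kwargs) instead of mutating tool.strict on the instance. For orientation, here is a hedged sketch of what a strict OpenAI-style function tool payload built from a Pydantic model can look like; the get_weather tool and to_api_tool_sketch helper are invented for illustration.

import json

from pydantic import BaseModel


class GetWeatherArgs(BaseModel):
    city: str
    unit: str = "celsius"


def to_api_tool_sketch(
    name: str,
    description: str,
    args_model: type[BaseModel],
    strict: bool | None = None,
) -> dict:
    # OpenAI chat-completions function tools accept an optional "strict" flag;
    # when set, the provider enforces the JSON schema of the arguments.
    # (Strict mode additionally expects additionalProperties: false and all
    # properties listed as required in the schema.)
    function: dict = {
        "name": name,
        "description": description,
        "parameters": args_model.model_json_schema(),
    }
    if strict is not None:
        function["strict"] = strict
    return {"type": "function", "function": function}


print(json.dumps(
    to_api_tool_sketch("get_weather", "Look up current weather", GetWeatherArgs, strict=True),
    indent=2,
))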
@@ -2,6 +2,7 @@ import logging
 from collections import defaultdict
 from collections.abc import AsyncIterator, Mapping
 from copy import deepcopy
+from dataclasses import dataclass, field
 from typing import Any, cast
 
 import litellm
@@ -21,7 +22,7 @@ from litellm.utils import (
 # from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
 from pydantic import BaseModel
 
-from ..cloud_llm import APIProvider, CloudLLM, LLMRateLimiter
+from ..cloud_llm import APIProvider, CloudLLM
 from ..openai.openai_llm import OpenAILLMSettings
 from ..typing.tool import BaseTool
 from . import (
@@ -40,116 +41,101 @@ class LiteLLMSettings(OpenAILLMSettings, total=False):
     thinking: AnthropicThinkingParam | None
 
 
+@dataclass(frozen=True)
 class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
-    def __init__(
-        self,
-        # Base LLM args
-        model_name: str,
-        model_id: str | None = None,
-        llm_settings: LiteLLMSettings | None = None,
-        tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
-        response_schema: Any | None = None,
-        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
-        apply_response_schema_via_provider: bool = False,
-        # LLM provider
-        api_provider: APIProvider | None = None,
-        # deployment_id: str | None = None,
-        # api_version: str | None = None,
-        # Connection settings
-        timeout: float | None = None,
-        max_client_retries: int = 2,
-        # Rate limiting
-        rate_limiter: LLMRateLimiter | None = None,
-        # Drop unsupported LLM settings
-        drop_params: bool = True,
-        additional_drop_params: list[str] | None = None,
-        allowed_openai_params: list[str] | None = None,
-        # Mock LLM response for testing
-        mock_response: str | None = None,
-        # LLM response retries: try to regenerate to pass validation
-        max_response_retries: int = 1,
-    ) -> None:
-        self._lite_llm_completion_params: dict[str, Any] = {
-            "max_retries": max_client_retries,
-            "timeout": timeout,
-            "drop_params": drop_params,
-            "additional_drop_params": additional_drop_params,
-            "allowed_openai_params": allowed_openai_params,
-            "mock_response": mock_response,
-            # "deployment_id": deployment_id,
-            # "api_version": api_version,
-        }
-
-        if model_name in litellm.get_valid_models():  # type: ignore[no-untyped-call]
-            _, provider_name, _, _ = litellm.get_llm_provider(model_name)  # type: ignore[no-untyped-call]
-            api_provider = APIProvider(name=provider_name)
-        elif api_provider is not None:
-            self._lite_llm_completion_params["api_key"] = api_provider.get("api_key")
-            self._lite_llm_completion_params["api_base"] = api_provider.get("api_base")
-        elif api_provider is None:
+    converters: LiteLLMConverters = field(default_factory=LiteLLMConverters)
+
+    timeout: float | None = None
+    # Drop unsupported LLM settings
+    drop_params: bool = True
+    additional_drop_params: list[str] | None = None
+    allowed_openai_params: list[str] | None = None
+    # Mock LLM response for testing
+    mock_response: str | None = None
+
+    _lite_llm_completion_params: dict[str, Any] = field(
+        default_factory=dict[str, Any], init=False, repr=False, compare=False
+    )
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+
+        self._lite_llm_completion_params.update(
+            {
+                "max_retries": self.max_client_retries,
+                "timeout": self.timeout,
+                "drop_params": self.drop_params,
+                "additional_drop_params": self.additional_drop_params,
+                "allowed_openai_params": self.allowed_openai_params,
+                "mock_response": self.mock_response,
+                # "deployment_id": deployment_id,
+                # "api_version": api_version,
+            }
+        )
+
+        _api_provider = self.api_provider
+
+        if self.model_name in litellm.get_valid_models():  # type: ignore[no-untyped-call]
+            _, provider_name, _, _ = litellm.get_llm_provider(self.model_name)  # type: ignore[no-untyped-call]
+            _api_provider = APIProvider(name=provider_name)
+        elif self.api_provider is not None:
+            self._lite_llm_completion_params["api_key"] = self.api_provider.get(
+                "api_key"
+            )
+            self._lite_llm_completion_params["api_base"] = self.api_provider.get(
+                "api_base"
+            )
+        elif self.api_provider is None:
             raise ValueError(
-                f"Model '{model_name}' is not supported by LiteLLM and no API provider "
+                f"Model '{self.model_name}' is not supported by LiteLLM and no API provider "
                 "was specified. Please provide a valid API provider or use a different "
                 "model."
             )
 
-        if llm_settings is not None:
-            stream_options = llm_settings.get("stream_options") or {}
+        if self.llm_settings is not None:
+            stream_options = self.llm_settings.get("stream_options") or {}
             stream_options["include_usage"] = True
-            _llm_settings = deepcopy(llm_settings)
+            _llm_settings = deepcopy(self.llm_settings)
             _llm_settings["stream_options"] = stream_options
         else:
             _llm_settings = LiteLLMSettings(stream_options={"include_usage": True})
 
-        super().__init__(
-            model_name=model_name,
-            model_id=model_id,
-            llm_settings=_llm_settings,
-            converters=LiteLLMConverters(),
-            tools=tools,
-            response_schema=response_schema,
-            response_schema_by_xml_tag=response_schema_by_xml_tag,
-            apply_response_schema_via_provider=apply_response_schema_via_provider,
-            api_provider=api_provider,
-            rate_limiter=rate_limiter,
-            max_client_retries=max_client_retries,
-            max_response_retries=max_response_retries,
-        )
+        if (
+            self.apply_response_schema_via_provider
+            and not self.supports_response_schema
+        ):
+            raise ValueError(
+                f"Model '{self.model_name}' does not support response schema "
+                "natively. Please set `apply_response_schema_via_provider=False`"
+            )
 
-        if self._apply_response_schema_via_provider:
-            if self._tools:
-                for tool in self._tools.values():
-                    tool.strict = True
-            if not self.supports_response_schema:
-                raise ValueError(
-                    f"Model '{self._model_name}' does not support response schema "
-                    "natively. Please set `apply_response_schema_via_provider=False`"
-                )
+        object.__setattr__(self, "api_provider", _api_provider)
+        object.__setattr__(self, "llm_settings", _llm_settings)
 
     def get_supported_openai_params(self) -> list[Any] | None:
         return get_supported_openai_params(  # type: ignore[no-untyped-call]
-            model=self._model_name, request_type="chat_completion"
+            model=self.model_name, request_type="chat_completion"
         )
 
     @property
     def supports_reasoning(self) -> bool:
-        return supports_reasoning(model=self._model_name)
+        return supports_reasoning(model=self.model_name)
 
     @property
     def supports_parallel_function_calling(self) -> bool:
-        return supports_parallel_function_calling(model=self._model_name)
+        return supports_parallel_function_calling(model=self.model_name)
 
     @property
     def supports_prompt_caching(self) -> bool:
-        return supports_prompt_caching(model=self._model_name)
+        return supports_prompt_caching(model=self.model_name)
 
     @property
     def supports_response_schema(self) -> bool:
-        return supports_response_schema(model=self._model_name)
+        return supports_response_schema(model=self.model_name)
 
     @property
     def supports_tool_choice(self) -> bool:
-        return supports_tool_choice(model=self._model_name)
+        return supports_tool_choice(model=self.model_name)
 
     # # client
     # model_list: Optional[list] = (None,)  # pass in a list of api_base,keys, etc.
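
LiteLLM follows the same dataclass conversion as CloudLLM: a converters field with a default_factory, a non-init private dict populated in __post_init__, and a super().__post_init__() call so the base setup still runs. A small stand-alone sketch of that subclassing pattern (class names here are generic, not the library's):

from dataclasses import dataclass, field
from typing import Any


@dataclass(frozen=True)
class Base:
    retries: int = 2

    def __post_init__(self) -> None:
        print(f"base configured with retries={self.retries}")


@dataclass(frozen=True)
class Child(Base):
    timeout: float | None = None
    # Mutable default via default_factory; excluded from __init__, repr, and eq.
    _params: dict[str, Any] = field(
        default_factory=dict, init=False, repr=False, compare=False
    )

    def __post_init__(self) -> None:
        super().__post_init__()  # keep the base class initialization
        # The dict itself is mutable, so it can be filled even on a frozen instance.
        self._params.update({"retries": self.retries, "timeout": self.timeout})


c = Child(retries=3, timeout=10.0)
print(c._params)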
@@ -164,7 +150,7 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
         **api_llm_settings: Any,
     ) -> LiteLLMCompletion:
         completion = await litellm.acompletion(  # type: ignore[no-untyped-call]
-            model=self._model_name,
+            model=self.model_name,
             messages=api_messages,
             tools=api_tools,
             tool_choice=api_tool_choice,  # type: ignore[arg-type]
@@ -191,7 +177,7 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
         **api_llm_settings: Any,
     ) -> AsyncIterator[LiteLLMCompletionChunk]:
         stream = await litellm.acompletion(  # type: ignore[no-untyped-call]
-            model=self._model_name,
+            model=self.model_name,
             messages=api_messages,
             tools=api_tools,
             tool_choice=api_tool_choice,  # type: ignore[arg-type]
@@ -217,7 +203,10 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
             yield completion_chunk
 
     def combine_completion_chunks(
-        self, completion_chunks: list[LiteLLMCompletionChunk]
+        self,
+        completion_chunks: list[LiteLLMCompletionChunk],
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
     ) -> LiteLLMCompletion:
         combined_chunk = cast(
             "LiteLLMCompletion",