grasp_agents 0.5.9__tar.gz → 0.5.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/PKG-INFO +4 -5
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/README.md +3 -4
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/pyproject.toml +1 -1
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/cloud_llm.py +87 -109
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/litellm/converters.py +4 -2
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/litellm/lite_llm.py +72 -83
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/llm.py +35 -68
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/llm_agent.py +76 -52
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/llm_agent_memory.py +4 -2
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/llm_policy_executor.py +91 -55
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/openai/converters.py +4 -2
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/openai/openai_llm.py +61 -88
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/openai/tool_converters.py +6 -4
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/processors/base_processor.py +18 -10
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/processors/parallel_processor.py +8 -6
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/processors/processor.py +10 -6
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/prompt_builder.py +38 -28
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/run_context.py +1 -1
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/runner.py +1 -1
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/typing/converters.py +3 -1
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/typing/tool.py +15 -5
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/workflow/workflow_processor.py +4 -4
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/.gitignore +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/LICENSE.md +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/__init__.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/costs_dict.yaml +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/errors.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/generics_utils.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/grasp_logging.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/http_client.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/litellm/__init__.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/litellm/completion_chunk_converters.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/litellm/completion_converters.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/litellm/message_converters.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/memory.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/openai/__init__.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/openai/completion_chunk_converters.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/openai/completion_converters.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/openai/content_converters.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/openai/message_converters.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/packet.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/packet_pool.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/printer.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/rate_limiting/__init__.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/rate_limiting/types.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/rate_limiting/utils.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/typing/__init__.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/typing/completion.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/typing/completion_chunk.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/typing/content.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/typing/events.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/typing/io.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/typing/message.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/usage_tracker.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/utils.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/workflow/__init__.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/workflow/looped_workflow.py +0 -0
- {grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/workflow/sequential_workflow.py +0 -0
{grasp_agents-0.5.9 → grasp_agents-0.5.11}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: grasp_agents
-Version: 0.5.9
+Version: 0.5.11
 Summary: Grasp Agents Library
 License-File: LICENSE.md
 Requires-Python: <4,>=3.11.4
@@ -166,9 +166,7 @@ class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, Any]):
     name: str = "ask_student"
     description: str = ask_student_tool_description

-    async def run(
-        self, inp: TeacherQuestion, ctx: RunContext[Any] | None = None
-    ) -> StudentReply:
+    async def run(self, inp: TeacherQuestion, **kwargs: Any) -> StudentReply:
         return input(inp.question)


@@ -180,7 +178,8 @@ teacher = LLMAgent[None, Problem, None](
     name="teacher",
     llm=LiteLLM(
         model_name="gpt-4.1",
-
+        # model_name="claude-sonnet-4-20250514",
+        # llm_settings=LiteLLMSettings(reasoning_effort="low"),
     ),
     tools=[AskStudentTool()],
     react_mode=True,
{grasp_agents-0.5.9 → grasp_agents-0.5.11}/README.md

@@ -149,9 +149,7 @@ class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, Any]):
     name: str = "ask_student"
     description: str = ask_student_tool_description

-    async def run(
-        self, inp: TeacherQuestion, ctx: RunContext[Any] | None = None
-    ) -> StudentReply:
+    async def run(self, inp: TeacherQuestion, **kwargs: Any) -> StudentReply:
         return input(inp.question)


@@ -163,7 +161,8 @@ teacher = LLMAgent[None, Problem, None](
     name="teacher",
     llm=LiteLLM(
         model_name="gpt-4.1",
-
+        # model_name="claude-sonnet-4-20250514",
+        # llm_settings=LiteLLMSettings(reasoning_effort="low"),
     ),
     tools=[AskStudentTool()],
     react_mode=True,
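The two README hunks above (duplicated in PKG-INFO) carry the main tool-facing API change in 0.5.11: `BaseTool.run` now receives any extra call context through `**kwargs` instead of an explicit `ctx: RunContext[Any] | None` parameter. A minimal sketch of a tool written against the new signature, based on the README example shown in this diff; the `TeacherQuestion`/`StudentReply` definitions, the description string, and the import path are assumptions, not taken from the package:

```python
from typing import Any

from pydantic import BaseModel

from grasp_agents.typing.tool import BaseTool  # import path assumed from the file list


# Assumed input/output types mirroring the README example above.
class TeacherQuestion(BaseModel):
    question: str


StudentReply = str


class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, Any]):
    name: str = "ask_student"
    description: str = "Ask the student a question and return their reply."

    # 0.5.11 signature: no explicit `ctx` argument; whatever run context the
    # executor forwards arrives via **kwargs.
    async def run(self, inp: TeacherQuestion, **kwargs: Any) -> StudentReply:
        return input(inp.question)
```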
{grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/cloud_llm.py

@@ -1,7 +1,8 @@
 import logging
 from abc import abstractmethod
-from collections.abc import AsyncIterator, Mapping
+from collections.abc import AsyncIterator, Mapping
 from copy import deepcopy
+from dataclasses import dataclass
 from typing import Any, Generic, Required, cast

 import httpx
@@ -58,111 +59,52 @@ LLMRateLimiter = RateLimiterC[
 ]


+@dataclass(frozen=True)
 class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
-
-
-
-
-
-
-
-
-
-
-        apply_response_schema_via_provider: bool = True,
-        model_id: str | None = None,
-        # Connection settings
-        async_http_client: httpx.AsyncClient | None = None,
-        async_http_client_params: (
-            dict[str, Any] | AsyncHTTPClientParams | None
-        ) = None,
-        max_client_retries: int = 2,
-        # Rate limiting
-        rate_limiter: LLMRateLimiter | None = None,
-        # LLM response retries: try to regenerate to pass validation
-        max_response_retries: int = 0,
-        **kwargs: Any,
-    ) -> None:
-        self.llm_settings: CloudLLMSettings | None
-
-        super().__init__(
-            model_name=model_name,
-            llm_settings=llm_settings,
-            converters=converters,
-            model_id=model_id,
-            tools=tools,
-            response_schema=response_schema,
-            response_schema_by_xml_tag=response_schema_by_xml_tag,
-            **kwargs,
-        )
-
-        self._model_name = model_name
-        self._api_provider = api_provider
-        self._apply_response_schema_via_provider = apply_response_schema_via_provider
-
-        if (
-            apply_response_schema_via_provider
-            and response_schema_by_xml_tag is not None
-        ):
-            raise ValueError(
-                "Response schema by XML tag is not supported "
-                "when apply_response_schema_via_provider is True."
-            )
+    api_provider: APIProvider | None = None
+    llm_settings: SettingsT_co | None = None
+    rate_limiter: LLMRateLimiter | None = None
+    max_client_retries: int = 2  # HTTP client retries for network errors
+    max_response_retries: int = (
+        0  # LLM response retries: try to regenerate to pass validation
+    )
+    apply_response_schema_via_provider: bool = False
+    async_http_client: httpx.AsyncClient | None = None
+    async_http_client_params: dict[str, Any] | AsyncHTTPClientParams | None = None

-
-        if rate_limiter is not None:
-            self._rate_limiter = rate_limiter
+    def __post_init__(self) -> None:
+        if self.rate_limiter is not None:
             logger.info(
-                f"[{self.__class__.__name__}] Set rate limit to
+                f"[{self.__class__.__name__}] Set rate limit to "
+                f"{self.rate_limiter.rpm} RPM"
             )

-        self.
-
-
-
-
-                async_http_client_params
+        if self.async_http_client is None and self.async_http_client_params is not None:
+            object.__setattr__(
+                self,
+                "async_http_client",
+                create_simple_async_httpx_client(self.async_http_client_params),
            )

-        self.max_client_retries = max_client_retries
-        self.max_response_retries = max_response_retries
-
-    @property
-    def api_provider(self) -> APIProvider:
-        return self._api_provider
-
-    @property
-    def rate_limiter(self) -> LLMRateLimiter | None:
-        return self._rate_limiter
-
-    @property
-    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
-        return self._tools
-
-    @tools.setter
-    def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
-        if not tools:
-            self._tools = None
-            return
-        strict_value = True if self._apply_response_schema_via_provider else None
-        for t in tools:
-            t.strict = strict_value
-        self._tools = {t.name: t for t in tools}
-
     def _make_completion_kwargs(
         self,
         conversation: Messages,
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
     ) -> dict[str, Any]:
-        api_messages = [self.
+        api_messages = [self.converters.to_message(m) for m in conversation]

         api_tools = None
         api_tool_choice = None
-        if
-
+        if tools:
+            strict = True if self.apply_response_schema_via_provider else None
+            api_tools = [
+                self.converters.to_tool(t, strict=strict) for t in tools.values()
+            ]
         if tool_choice is not None:
-            api_tool_choice = self.
+            api_tool_choice = self.converters.to_tool_choice(tool_choice)

         api_llm_settings = deepcopy(self.llm_settings or {})

@@ -170,7 +112,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             api_messages=api_messages,
             api_tools=api_tools,
             api_tool_choice=api_tool_choice,
-            api_response_schema=
+            api_response_schema=response_schema,
             n_choices=n_choices,
             **api_llm_settings,
         )
@@ -206,24 +148,34 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         conversation: Messages,
         *,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
     ) -> Completion:
         completion_kwargs = self._make_completion_kwargs(
-            conversation=conversation,
+            conversation=conversation,
+            response_schema=response_schema,
+            tools=tools,
+            tool_choice=tool_choice,
+            n_choices=n_choices,
         )

-        if not self.
+        if not self.apply_response_schema_via_provider:
             completion_kwargs.pop("api_response_schema", None)
         api_completion = await self._get_completion(**completion_kwargs)

-        completion = self.
-            api_completion, name=self.model_id
-        )
+        completion = self.converters.from_completion(api_completion, name=self.model_id)

-        if not self.
-            self._validate_response(
-
+        if not self.apply_response_schema_via_provider:
+            self._validate_response(
+                completion,
+                response_schema=response_schema,
+                response_schema_by_xml_tag=response_schema_by_xml_tag,
+            )
+        if tools is not None:
+            self._validate_tool_calls(completion, tools=tools)

         return completion

@@ -231,6 +183,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         conversation: Messages,
         *,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -241,6 +196,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         try:
             return await self._generate_completion_once(
                 conversation,  # type: ignore[return]
+                response_schema=response_schema,
+                response_schema_by_xml_tag=response_schema_by_xml_tag,
+                tools=tools,
                 tool_choice=tool_choice,
                 n_choices=n_choices,
             )
@@ -263,7 +221,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
            )

         return make_refusal_completion(
-            self.
+            self.model_name,
             Exception("Unexpected error: retry loop exited without returning"),
         )

@@ -272,15 +230,22 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         conversation: Messages,
         *,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
         call_id: str | None = None,
     ) -> AsyncIterator[CompletionChunkEvent[CompletionChunk] | CompletionEvent]:
         completion_kwargs = self._make_completion_kwargs(
-            conversation=conversation,
+            conversation=conversation,
+            response_schema=response_schema,
+            tools=tools,
+            tool_choice=tool_choice,
+            n_choices=n_choices,
         )
-        if not self.
+        if not self.apply_response_schema_via_provider:
             completion_kwargs.pop("api_response_schema", None)

         api_stream = self._get_completion_stream(**completion_kwargs)
@@ -293,7 +258,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co

             async for api_completion_chunk in api_stream:
                 api_completion_chunks.append(api_completion_chunk)
-                completion_chunk = self.
+                completion_chunk = self.converters.from_completion_chunk(
                     api_completion_chunk, name=self.model_id
                 )

@@ -301,16 +266,23 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                     data=completion_chunk, proc_name=proc_name, call_id=call_id
                 )

-            api_completion = self.combine_completion_chunks(
-
+            api_completion = self.combine_completion_chunks(
+                api_completion_chunks, response_schema=response_schema, tools=tools
+            )
+            completion = self.converters.from_completion(
                 api_completion, name=self.model_id
             )

             yield CompletionEvent(data=completion, proc_name=proc_name, call_id=call_id)

-            if not self.
-                self._validate_response(
-
+            if not self.apply_response_schema_via_provider:
+                self._validate_response(
+                    completion,
+                    response_schema=response_schema,
+                    response_schema_by_xml_tag=response_schema_by_xml_tag,
+                )
+            if tools is not None:
+                self._validate_tool_calls(completion, tools=tools)

         return iterator()

@@ -318,6 +290,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         conversation: Messages,
         *,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
         proc_name: str | None = None,
@@ -330,6 +305,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         try:
             async for event in await self._generate_completion_stream_once(  # type: ignore[return]
                 conversation,  # type: ignore[arg-type]
+                response_schema=response_schema,
+                response_schema_by_xml_tag=response_schema_by_xml_tag,
+                tools=tools,
                 tool_choice=tool_choice,
                 n_choices=n_choices,
                 proc_name=proc_name,
@@ -339,7 +317,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                 return
         except (LLMResponseValidationError, LLMToolCallValidationError) as err:
             err_data = LLMStreamingErrorData(
-                error=err, model_name=self.
+                error=err, model_name=self.model_name, model_id=self.model_id
             )
             yield LLMStreamingErrorEvent(
                 data=err_data, proc_name=proc_name, call_id=call_id
@@ -355,7 +333,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                    f"retrying:\n{err}"
                )
                refusal_completion = make_refusal_completion(
-                    self.
+                    self.model_name, err
                )
                yield CompletionEvent(
                    data=refusal_completion,
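Taken together, the cloud_llm.py hunks replace the old `__init__`-plus-properties `CloudLLM` with a frozen dataclass: constructor keyword arguments become declared fields, derived state is installed in `__post_init__` via `object.__setattr__` (frozen dataclasses block ordinary attribute assignment), and `response_schema`, `response_schema_by_xml_tag`, and `tools` move from instance state to per-call arguments of the generation methods. A small self-contained sketch of that frozen-dataclass pattern, illustrative only and not the actual `CloudLLM` class:

```python
from dataclasses import dataclass
from typing import Any

import httpx  # same HTTP client dependency cloud_llm.py uses


@dataclass(frozen=True)
class FrozenClientConfig:
    """Toy stand-in showing the pattern CloudLLM adopts in 0.5.11."""

    model_name: str
    max_client_retries: int = 2
    async_http_client: httpx.AsyncClient | None = None
    async_http_client_params: dict[str, Any] | None = None

    def __post_init__(self) -> None:
        # A frozen dataclass rejects `self.x = ...`, so derived state is
        # written with object.__setattr__, exactly as the diff does for
        # the lazily created HTTP client.
        if self.async_http_client is None and self.async_http_client_params is not None:
            object.__setattr__(
                self,
                "async_http_client",
                httpx.AsyncClient(**self.async_http_client_params),
            )


config = FrozenClientConfig(model_name="gpt-4.1", async_http_client_params={"timeout": 30})
```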
{grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/litellm/converters.py

@@ -118,8 +118,10 @@ class LiteLLMConverters(Converters):
         return from_api_tool_message(raw_message, name=name, **kwargs)

     @staticmethod
-    def to_tool(
-
+    def to_tool(
+        tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None, **kwargs: Any
+    ) -> OpenAIToolParam:
+        return to_api_tool(tool, strict=strict, **kwargs)

     @staticmethod
     def to_tool_choice(
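With this change, `strict` travels with the `to_tool` call (mirroring how `CloudLLM._make_completion_kwargs` now computes `strict` per request) instead of being written onto each tool instance by the removed `tools` setter. A hedged usage sketch; the import paths are assumed from the package layout and `EchoTool` is a stand-in, not part of the library:

```python
from typing import Any

from pydantic import BaseModel

from grasp_agents.litellm.converters import LiteLLMConverters  # path assumed
from grasp_agents.typing.tool import BaseTool  # path assumed


class EchoInput(BaseModel):
    text: str


class EchoTool(BaseTool[EchoInput, str, Any]):
    name: str = "echo"
    description: str = "Echo the input text back."

    async def run(self, inp: EchoInput, **kwargs: Any) -> str:
        return inp.text


# strict=True requests a provider-enforced (strict JSON schema) tool definition.
api_tool_param = LiteLLMConverters.to_tool(EchoTool(), strict=True)
```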
{grasp_agents-0.5.9 → grasp_agents-0.5.11}/src/grasp_agents/litellm/lite_llm.py

@@ -2,6 +2,7 @@ import logging
 from collections import defaultdict
 from collections.abc import AsyncIterator, Mapping
 from copy import deepcopy
+from dataclasses import dataclass, field
 from typing import Any, cast

 import litellm
@@ -21,7 +22,7 @@ from litellm.utils import (
 # from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
 from pydantic import BaseModel

-from ..cloud_llm import APIProvider, CloudLLM
+from ..cloud_llm import APIProvider, CloudLLM
 from ..openai.openai_llm import OpenAILLMSettings
 from ..typing.tool import BaseTool
 from . import (
@@ -40,116 +41,101 @@ class LiteLLMSettings(OpenAILLMSettings, total=False):
     thinking: AnthropicThinkingParam | None


+@dataclass(frozen=True)
 class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        elif api_provider is
-            self._lite_llm_completion_params["api_key"] = api_provider.get("api_key")
-            self._lite_llm_completion_params["api_base"] = api_provider.get("api_base")
-        elif api_provider is None:
+    converters: LiteLLMConverters = field(default_factory=LiteLLMConverters)
+
+    timeout: float | None = None
+    # Drop unsupported LLM settings
+    drop_params: bool = True
+    additional_drop_params: list[str] | None = None
+    allowed_openai_params: list[str] | None = None
+    # Mock LLM response for testing
+    mock_response: str | None = None
+
+    _lite_llm_completion_params: dict[str, Any] = field(
+        default_factory=dict[str, Any], init=False, repr=False, compare=False
+    )
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+
+        self._lite_llm_completion_params.update(
+            {
+                "max_retries": self.max_client_retries,
+                "timeout": self.timeout,
+                "drop_params": self.drop_params,
+                "additional_drop_params": self.additional_drop_params,
+                "allowed_openai_params": self.allowed_openai_params,
+                "mock_response": self.mock_response,
+                # "deployment_id": deployment_id,
+                # "api_version": api_version,
+            }
+        )
+
+        _api_provider = self.api_provider
+
+        if self.model_name in litellm.get_valid_models():  # type: ignore[no-untyped-call]
+            _, provider_name, _, _ = litellm.get_llm_provider(self.model_name)  # type: ignore[no-untyped-call]
+            _api_provider = APIProvider(name=provider_name)
+        elif self.api_provider is not None:
+            self._lite_llm_completion_params["api_key"] = self.api_provider.get(
+                "api_key"
+            )
+            self._lite_llm_completion_params["api_base"] = self.api_provider.get(
+                "api_base"
+            )
+        elif self.api_provider is None:
             raise ValueError(
-                f"Model '{model_name}' is not supported by LiteLLM and no API provider "
+                f"Model '{self.model_name}' is not supported by LiteLLM and no API provider "
                 "was specified. Please provide a valid API provider or use a different "
                 "model."
             )

-        if llm_settings is not None:
-            stream_options = llm_settings.get("stream_options") or {}
+        if self.llm_settings is not None:
+            stream_options = self.llm_settings.get("stream_options") or {}
             stream_options["include_usage"] = True
-            _llm_settings = deepcopy(llm_settings)
+            _llm_settings = deepcopy(self.llm_settings)
             _llm_settings["stream_options"] = stream_options
         else:
             _llm_settings = LiteLLMSettings(stream_options={"include_usage": True})

-
-
-
-
-
-
-
-
-            apply_response_schema_via_provider=apply_response_schema_via_provider,
-            api_provider=api_provider,
-            rate_limiter=rate_limiter,
-            max_client_retries=max_client_retries,
-            max_response_retries=max_response_retries,
-        )
+        if (
+            self.apply_response_schema_via_provider
+            and not self.supports_response_schema
+        ):
+            raise ValueError(
+                f"Model '{self.model_name}' does not support response schema "
+                "natively. Please set `apply_response_schema_via_provider=False`"
+            )

-
-
-        for tool in self._tools.values():
-            tool.strict = True
-        if not self.supports_response_schema:
-            raise ValueError(
-                f"Model '{self._model_name}' does not support response schema "
-                "natively. Please set `apply_response_schema_via_provider=False`"
-            )
+        object.__setattr__(self, "api_provider", _api_provider)
+        object.__setattr__(self, "llm_settings", _llm_settings)

     def get_supported_openai_params(self) -> list[Any] | None:
         return get_supported_openai_params(  # type: ignore[no-untyped-call]
-            model=self.
+            model=self.model_name, request_type="chat_completion"
         )

     @property
     def supports_reasoning(self) -> bool:
-        return supports_reasoning(model=self.
+        return supports_reasoning(model=self.model_name)

     @property
     def supports_parallel_function_calling(self) -> bool:
-        return supports_parallel_function_calling(model=self.
+        return supports_parallel_function_calling(model=self.model_name)

     @property
     def supports_prompt_caching(self) -> bool:
-        return supports_prompt_caching(model=self.
+        return supports_prompt_caching(model=self.model_name)

     @property
     def supports_response_schema(self) -> bool:
-        return supports_response_schema(model=self.
+        return supports_response_schema(model=self.model_name)

     @property
     def supports_tool_choice(self) -> bool:
-        return supports_tool_choice(model=self.
+        return supports_tool_choice(model=self.model_name)

     # # client
     # model_list: Optional[list] = (None,)  # pass in a list of api_base,keys, etc.
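LiteLLM gets the same dataclass treatment: what used to be constructor keyword arguments are now declared fields, and provider detection plus LiteLLM parameter plumbing moves into `__post_init__`. A hedged construction sketch using only field names visible in the hunk above and the README example earlier in this diff; the import path is assumed from the package layout and defaults may differ:

```python
from grasp_agents.litellm import LiteLLM, LiteLLMSettings  # import path assumed

# LiteLLM is now a dataclass, so configuration is passed as keyword arguments
# that map directly onto the fields shown in the diff above.
llm = LiteLLM(
    model_name="gpt-4.1",
    llm_settings=LiteLLMSettings(reasoning_effort="low"),
    timeout=60.0,
    drop_params=True,    # drop LLM settings the chosen provider does not support
    mock_response=None,  # set to a string to stub out completions in tests
)
```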
@@ -164,7 +150,7 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
         **api_llm_settings: Any,
     ) -> LiteLLMCompletion:
         completion = await litellm.acompletion(  # type: ignore[no-untyped-call]
-            model=self.
+            model=self.model_name,
             messages=api_messages,
             tools=api_tools,
             tool_choice=api_tool_choice,  # type: ignore[arg-type]
@@ -191,7 +177,7 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
         **api_llm_settings: Any,
     ) -> AsyncIterator[LiteLLMCompletionChunk]:
         stream = await litellm.acompletion(  # type: ignore[no-untyped-call]
-            model=self.
+            model=self.model_name,
             messages=api_messages,
             tools=api_tools,
             tool_choice=api_tool_choice,  # type: ignore[arg-type]
@@ -217,7 +203,10 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
             yield completion_chunk

     def combine_completion_chunks(
-        self,
+        self,
+        completion_chunks: list[LiteLLMCompletionChunk],
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
     ) -> LiteLLMCompletion:
         combined_chunk = cast(
             "LiteLLMCompletion",
|