openhands 0.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of openhands might be problematic.
- openhands-1.0.1.dist-info/METADATA +52 -0
- openhands-1.0.1.dist-info/RECORD +31 -0
- {openhands-0.0.0.dist-info → openhands-1.0.1.dist-info}/WHEEL +1 -2
- openhands-1.0.1.dist-info/entry_points.txt +2 -0
- openhands_cli/__init__.py +8 -0
- openhands_cli/agent_chat.py +186 -0
- openhands_cli/argparsers/main_parser.py +56 -0
- openhands_cli/argparsers/serve_parser.py +31 -0
- openhands_cli/gui_launcher.py +220 -0
- openhands_cli/listeners/__init__.py +4 -0
- openhands_cli/listeners/loading_listener.py +63 -0
- openhands_cli/listeners/pause_listener.py +83 -0
- openhands_cli/llm_utils.py +57 -0
- openhands_cli/locations.py +13 -0
- openhands_cli/pt_style.py +30 -0
- openhands_cli/runner.py +178 -0
- openhands_cli/setup.py +116 -0
- openhands_cli/simple_main.py +59 -0
- openhands_cli/tui/__init__.py +5 -0
- openhands_cli/tui/settings/mcp_screen.py +217 -0
- openhands_cli/tui/settings/settings_screen.py +202 -0
- openhands_cli/tui/settings/store.py +93 -0
- openhands_cli/tui/status.py +109 -0
- openhands_cli/tui/tui.py +100 -0
- openhands_cli/tui/utils.py +14 -0
- openhands_cli/user_actions/__init__.py +17 -0
- openhands_cli/user_actions/agent_action.py +95 -0
- openhands_cli/user_actions/exit_session.py +18 -0
- openhands_cli/user_actions/settings_action.py +171 -0
- openhands_cli/user_actions/types.py +18 -0
- openhands_cli/user_actions/utils.py +199 -0
- openhands/__init__.py +0 -1
- openhands/sdk/__init__.py +0 -45
- openhands/sdk/agent/__init__.py +0 -8
- openhands/sdk/agent/agent/__init__.py +0 -6
- openhands/sdk/agent/agent/agent.py +0 -349
- openhands/sdk/agent/base.py +0 -103
- openhands/sdk/context/__init__.py +0 -28
- openhands/sdk/context/agent_context.py +0 -153
- openhands/sdk/context/condenser/__init__.py +0 -5
- openhands/sdk/context/condenser/condenser.py +0 -73
- openhands/sdk/context/condenser/no_op_condenser.py +0 -13
- openhands/sdk/context/manager.py +0 -5
- openhands/sdk/context/microagents/__init__.py +0 -26
- openhands/sdk/context/microagents/exceptions.py +0 -11
- openhands/sdk/context/microagents/microagent.py +0 -345
- openhands/sdk/context/microagents/types.py +0 -70
- openhands/sdk/context/utils/__init__.py +0 -8
- openhands/sdk/context/utils/prompt.py +0 -52
- openhands/sdk/context/view.py +0 -116
- openhands/sdk/conversation/__init__.py +0 -12
- openhands/sdk/conversation/conversation.py +0 -207
- openhands/sdk/conversation/state.py +0 -50
- openhands/sdk/conversation/types.py +0 -6
- openhands/sdk/conversation/visualizer.py +0 -300
- openhands/sdk/event/__init__.py +0 -27
- openhands/sdk/event/base.py +0 -148
- openhands/sdk/event/condenser.py +0 -49
- openhands/sdk/event/llm_convertible.py +0 -265
- openhands/sdk/event/types.py +0 -5
- openhands/sdk/event/user_action.py +0 -12
- openhands/sdk/event/utils.py +0 -30
- openhands/sdk/llm/__init__.py +0 -19
- openhands/sdk/llm/exceptions.py +0 -108
- openhands/sdk/llm/llm.py +0 -867
- openhands/sdk/llm/llm_registry.py +0 -116
- openhands/sdk/llm/message.py +0 -216
- openhands/sdk/llm/metadata.py +0 -34
- openhands/sdk/llm/utils/fn_call_converter.py +0 -1049
- openhands/sdk/llm/utils/metrics.py +0 -311
- openhands/sdk/llm/utils/model_features.py +0 -153
- openhands/sdk/llm/utils/retry_mixin.py +0 -122
- openhands/sdk/llm/utils/telemetry.py +0 -252
- openhands/sdk/logger.py +0 -167
- openhands/sdk/mcp/__init__.py +0 -20
- openhands/sdk/mcp/client.py +0 -113
- openhands/sdk/mcp/definition.py +0 -69
- openhands/sdk/mcp/tool.py +0 -104
- openhands/sdk/mcp/utils.py +0 -59
- openhands/sdk/tests/llm/test_llm.py +0 -447
- openhands/sdk/tests/llm/test_llm_fncall_converter.py +0 -691
- openhands/sdk/tests/llm/test_model_features.py +0 -221
- openhands/sdk/tool/__init__.py +0 -30
- openhands/sdk/tool/builtins/__init__.py +0 -34
- openhands/sdk/tool/builtins/finish.py +0 -57
- openhands/sdk/tool/builtins/think.py +0 -60
- openhands/sdk/tool/schema.py +0 -236
- openhands/sdk/tool/security_prompt.py +0 -5
- openhands/sdk/tool/tool.py +0 -142
- openhands/sdk/utils/__init__.py +0 -14
- openhands/sdk/utils/discriminated_union.py +0 -210
- openhands/sdk/utils/json.py +0 -48
- openhands/sdk/utils/truncate.py +0 -44
- openhands/tools/__init__.py +0 -44
- openhands/tools/execute_bash/__init__.py +0 -30
- openhands/tools/execute_bash/constants.py +0 -31
- openhands/tools/execute_bash/definition.py +0 -166
- openhands/tools/execute_bash/impl.py +0 -38
- openhands/tools/execute_bash/metadata.py +0 -101
- openhands/tools/execute_bash/terminal/__init__.py +0 -22
- openhands/tools/execute_bash/terminal/factory.py +0 -113
- openhands/tools/execute_bash/terminal/interface.py +0 -189
- openhands/tools/execute_bash/terminal/subprocess_terminal.py +0 -412
- openhands/tools/execute_bash/terminal/terminal_session.py +0 -492
- openhands/tools/execute_bash/terminal/tmux_terminal.py +0 -160
- openhands/tools/execute_bash/utils/command.py +0 -150
- openhands/tools/str_replace_editor/__init__.py +0 -17
- openhands/tools/str_replace_editor/definition.py +0 -158
- openhands/tools/str_replace_editor/editor.py +0 -683
- openhands/tools/str_replace_editor/exceptions.py +0 -41
- openhands/tools/str_replace_editor/impl.py +0 -66
- openhands/tools/str_replace_editor/utils/__init__.py +0 -0
- openhands/tools/str_replace_editor/utils/config.py +0 -2
- openhands/tools/str_replace_editor/utils/constants.py +0 -9
- openhands/tools/str_replace_editor/utils/encoding.py +0 -135
- openhands/tools/str_replace_editor/utils/file_cache.py +0 -154
- openhands/tools/str_replace_editor/utils/history.py +0 -122
- openhands/tools/str_replace_editor/utils/shell.py +0 -72
- openhands/tools/task_tracker/__init__.py +0 -16
- openhands/tools/task_tracker/definition.py +0 -336
- openhands/tools/utils/__init__.py +0 -1
- openhands-0.0.0.dist-info/METADATA +0 -3
- openhands-0.0.0.dist-info/RECORD +0 -94
- openhands-0.0.0.dist-info/top_level.txt +0 -1
openhands/sdk/llm/llm.py
DELETED
@@ -1,867 +0,0 @@
import copy
import json
import os
import time
import warnings
from contextlib import contextmanager
from typing import Any, Callable, Literal, TypeGuard, cast, get_args, get_origin

import httpx
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    PrivateAttr,
    SecretStr,
    field_validator,
    model_validator,
)


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import litellm

from litellm import (
    ChatCompletionToolParam,
    Message as LiteLLMMessage,
    completion as litellm_completion,
)
from litellm.exceptions import (
    APIConnectionError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
    Timeout as LiteLLMTimeout,
)
from litellm.types.utils import (
    Choices,
    ModelResponse,
    StreamingChoices,
)
from litellm.utils import (
    create_pretrained_tokenizer,
    get_model_info,
    supports_vision,
    token_counter,
)

# OpenHands utilities
from openhands.sdk.llm.exceptions import LLMNoResponseError
from openhands.sdk.llm.message import Message
from openhands.sdk.llm.utils.fn_call_converter import (
    STOP_WORDS,
    convert_fncall_messages_to_non_fncall_messages,
    convert_non_fncall_messages_to_fncall_messages,
)
from openhands.sdk.llm.utils.metrics import Metrics
from openhands.sdk.llm.utils.model_features import get_features
from openhands.sdk.llm.utils.telemetry import Telemetry
from openhands.sdk.logger import ENV_LOG_DIR, get_logger


logger = get_logger(__name__)

__all__ = ["LLM"]

# Exceptions we retry on
LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
    APIConnectionError,
    RateLimitError,
    ServiceUnavailableError,
    LiteLLMTimeout,
    InternalServerError,
    LLMNoResponseError,
)


class RetryMixin:
    """Minimal retry mixin kept from your original design."""

    def retry_decorator(
        self,
        *,
        num_retries: int,
        retry_exceptions: tuple[type[Exception], ...],
        retry_min_wait: int,
        retry_max_wait: int,
        retry_multiplier: float,
        retry_listener: Callable[[int, int], None] | None = None,
    ):
        def decorator(fn: Callable[[], Any]):
            def wrapped():
                import random

                attempt = 0
                wait = retry_min_wait
                last_exc = None
                while attempt < num_retries:
                    try:
                        return fn()
                    except retry_exceptions as e:
                        last_exc = e
                        if attempt == num_retries - 1:
                            break
                        # jittered exponential backoff
                        sleep_for = min(
                            retry_max_wait, int(wait + random.uniform(0, 1))
                        )
                        if retry_listener:
                            retry_listener(attempt + 1, num_retries)
                        time.sleep(sleep_for)
                        wait = max(retry_min_wait, int(wait * retry_multiplier))
                        attempt += 1
                assert last_exc is not None
                raise last_exc

            return wrapped

        return decorator


class LLM(BaseModel, RetryMixin):
    """Refactored LLM: simple `completion()`, centralized Telemetry, tiny helpers."""

    # =========================================================================
    # Config fields
    # =========================================================================
    model: str = Field(default="claude-sonnet-4-20250514", description="Model name.")
    api_key: SecretStr | None = Field(default=None, description="API key.")
    base_url: str | None = Field(default=None, description="Custom base URL.")
    api_version: str | None = Field(
        default=None, description="API version (e.g., Azure)."
    )

    aws_access_key_id: SecretStr | None = Field(default=None)
    aws_secret_access_key: SecretStr | None = Field(default=None)
    aws_region_name: str | None = Field(default=None)

    openrouter_site_url: str = Field(default="https://docs.all-hands.dev/")
    openrouter_app_name: str = Field(default="OpenHands")

    num_retries: int = Field(default=5)
    retry_multiplier: float = Field(default=8)
    retry_min_wait: int = Field(default=8)
    retry_max_wait: int = Field(default=64)

    timeout: int | None = Field(default=None, description="HTTP timeout (s).")

    max_message_chars: int = Field(
        default=30_000,
        description="Approx max chars in each event/content sent to the LLM.",
    )

    temperature: float | None = Field(default=0.0)
    top_p: float | None = Field(default=1.0)
    top_k: float | None = Field(default=None)

    custom_llm_provider: str | None = Field(default=None)
    max_input_tokens: int | None = Field(
        default=None,
        description="The maximum number of input tokens. "
        "Note that this is currently unused, and the value at runtime is actually"
        " the total tokens in OpenAI (e.g. 128,000 tokens for GPT-4).",
    )
    max_output_tokens: int | None = Field(
        default=None,
        description="The maximum number of output tokens. This is sent to the LLM.",
    )
    input_cost_per_token: float | None = Field(
        default=None,
        description="The cost per input token. This will available in logs for user.",
    )
    output_cost_per_token: float | None = Field(
        default=None,
        description="The cost per output token. This will available in logs for user.",
    )
    ollama_base_url: str | None = Field(default=None)

    drop_params: bool = Field(default=True)
    modify_params: bool = Field(
        default=True,
        description="Modify params allows litellm to do transformations like adding"
        " a default message, when a message is empty.",
    )
    disable_vision: bool | None = Field(
        default=None,
        description="If model is vision capable, this option allows to disable image "
        "processing (useful for cost reduction).",
    )
    disable_stop_word: bool | None = Field(
        default=False, description="Disable using of stop word."
    )
    caching_prompt: bool = Field(default=True, description="Enable caching of prompts.")
    log_completions: bool = Field(
        default=False, description="Enable logging of completions."
    )
    log_completions_folder: str = Field(
        default=os.path.join(ENV_LOG_DIR, "completions"),
        description="The folder to log LLM completions to. "
        "Required if log_completions is True.",
    )
    custom_tokenizer: str | None = Field(
        default=None, description="A custom tokenizer to use for token counting."
    )
    native_tool_calling: bool | None = Field(
        default=None,
        description="Whether to use native tool calling "
        "if supported by the model. Can be True, False, or not set.",
    )
    reasoning_effort: Literal["low", "medium", "high", "none"] | None = Field(
        default=None,
        description="The effort to put into reasoning. "
        "This is a string that can be one of 'low', 'medium', 'high', or 'none'. "
        "Can apply to all reasoning models.",
    )
    seed: int | None = Field(
        default=None, description="The seed to use for random number generation."
    )
    safety_settings: list[dict[str, str]] | None = Field(
        default=None,
        description=(
            "Safety settings for models that support them (like Mistral AI and Gemini)"
        ),
    )

    # =========================================================================
    # Internal fields (excluded from dumps)
    # =========================================================================
    service_id: str = Field(default="default", exclude=True)
    metrics: Metrics | None = Field(default=None, exclude=True)
    retry_listener: Callable[[int, int], None] | None = Field(
        default=None, exclude=True
    )

    # Runtime-only private attrs
    _model_info: Any = PrivateAttr(default=None)
    _tokenizer: Any = PrivateAttr(default=None)
    _function_calling_active: bool = PrivateAttr(default=False)
    _telemetry: Telemetry | None = PrivateAttr(default=None)

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    # =========================================================================
    # Validators
    # =========================================================================
    @field_validator("api_key", mode="before")
    @classmethod
    def _validate_api_key(cls, v):
        """Convert empty API keys to None to allow boto3 to use alternative auth methods."""  # noqa: E501
        if v is None:
            return None

        # Handle both SecretStr and string inputs
        if isinstance(v, SecretStr):
            secret_value = v.get_secret_value()
        else:
            secret_value = str(v)

        # If the API key is empty or whitespace-only, return None
        if not secret_value or not secret_value.strip():
            return None

        return v

    @model_validator(mode="before")
    @classmethod
    def _coerce_inputs(cls, data):
        if not isinstance(data, dict):
            return data
        d = dict(data)

        model_val = d.get("model")
        if not model_val:
            raise ValueError("model must be specified in LLM")

        # default reasoning_effort unless Gemini 2.5
        # (we keep consistent with old behavior)
        if d.get("reasoning_effort") is None and "gemini-2.5-pro" not in model_val:
            d["reasoning_effort"] = "high"

        # Azure default version
        if model_val.startswith("azure") and not d.get("api_version"):
            d["api_version"] = "2024-12-01-preview"

        # Provider rewrite: openhands/* -> litellm_proxy/*
        if model_val.startswith("openhands/"):
            model_name = model_val.removeprefix("openhands/")
            d["model"] = f"litellm_proxy/{model_name}"
            d.setdefault("base_url", "https://llm-proxy.app.all-hands.dev/")

        # HF doesn't support the OpenAI default value for top_p (1)
        if model_val.startswith("huggingface"):
            if d.get("top_p", 1.0) == 1.0:
                d["top_p"] = 0.9

        return d

    @model_validator(mode="after")
    def _set_env_side_effects(self):
        if self.openrouter_site_url:
            os.environ["OR_SITE_URL"] = self.openrouter_site_url
        if self.openrouter_app_name:
            os.environ["OR_APP_NAME"] = self.openrouter_app_name
        if self.aws_access_key_id:
            os.environ["AWS_ACCESS_KEY_ID"] = self.aws_access_key_id.get_secret_value()
        if self.aws_secret_access_key:
            os.environ["AWS_SECRET_ACCESS_KEY"] = (
                self.aws_secret_access_key.get_secret_value()
            )
        if self.aws_region_name:
            os.environ["AWS_REGION_NAME"] = self.aws_region_name

        # Metrics + Telemetry wiring
        if self.metrics is None:
            self.metrics = Metrics(model_name=self.model)

        self._telemetry = Telemetry(
            model_name=self.model,
            log_enabled=self.log_completions,
            log_dir=self.log_completions_folder if self.log_completions else None,
            metrics=self.metrics,
        )

        # Tokenizer
        if self.custom_tokenizer:
            self._tokenizer = create_pretrained_tokenizer(self.custom_tokenizer)

        # Capabilities + model info
        self._init_model_info_and_caps()

        logger.debug(
            f"LLM ready: model={self.model} base_url={self.base_url} "
            f"reasoning_effort={self.reasoning_effort}"
        )
        return self

    # =========================================================================
    # Public API
    # =========================================================================
    def completion(
        self,
        messages: list[dict[str, Any]] | list[Message],
        tools: list[ChatCompletionToolParam] | None = None,
        return_metrics: bool = False,
        **kwargs,
    ) -> ModelResponse:
        """Single entry point for LLM completion.

        Normalize → (maybe) mock tools → transport → postprocess.
        """
        # Check if streaming is requested
        if kwargs.get("stream", False):
            raise ValueError("Streaming is not supported")

        # 1) serialize messages
        if messages and isinstance(messages[0], Message):
            messages = self.format_messages_for_llm(cast(list[Message], messages))
        else:
            messages = cast(list[dict[str, Any]], messages)

        # 2) choose function-calling strategy
        use_native_fc = self.is_function_calling_active()
        original_fncall_msgs = copy.deepcopy(messages)
        if tools and not use_native_fc:
            logger.debug(
                "LLM.completion: mocking function-calling via prompt "
                f"for model {self.model}"
            )
            messages, kwargs = self._pre_request_prompt_mock(messages, tools, kwargs)

        # 3) normalize provider params
        kwargs["tools"] = tools  # we might remove this field in _normalize_call_kwargs
        has_tools_flag = (
            bool(tools) and use_native_fc
        )  # only keep tools when native FC is active
        call_kwargs = self._normalize_call_kwargs(kwargs, has_tools=has_tools_flag)

        # 4) optional request logging context (kept small)
        assert self._telemetry is not None
        log_ctx = None
        if self._telemetry.log_enabled:
            log_ctx = {
                "messages": messages[:],  # already simple dicts
                "tools": tools,
                "kwargs": {k: v for k, v in call_kwargs.items()},
                "context_window": self.max_input_tokens,
            }
            if tools and not use_native_fc:
                log_ctx["raw_messages"] = original_fncall_msgs
        self._telemetry.on_request(log_ctx=log_ctx)

        # 5) do the call with retries
        @self.retry_decorator(
            num_retries=self.num_retries,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.retry_min_wait,
            retry_max_wait=self.retry_max_wait,
            retry_multiplier=self.retry_multiplier,
            retry_listener=self.retry_listener,
        )
        def _one_attempt() -> ModelResponse:
            assert self._telemetry is not None
            resp = self._transport_call(messages=messages, **call_kwargs)
            raw_resp: ModelResponse | None = None
            if tools and not use_native_fc:
                raw_resp = copy.deepcopy(resp)
                resp = self._post_response_prompt_mock(
                    resp, nonfncall_msgs=messages, tools=tools
                )
            # 6) telemetry
            self._telemetry.on_response(resp, raw_resp=raw_resp)

            # Ensure at least one choice
            if not resp.get("choices") or len(resp["choices"]) < 1:
                raise LLMNoResponseError(
                    "Response choices is less than 1. Response: " + str(resp)
                )

            return resp

        try:
            resp = _one_attempt()
            return resp
        except Exception as e:
            self._telemetry.on_error(e)
            raise

    # =========================================================================
    # Transport + helpers
    # =========================================================================
    def _transport_call(
        self, *, messages: list[dict[str, Any]], **kwargs
    ) -> ModelResponse:
        # litellm.modify_params is GLOBAL; guard it for thread-safety
        with self._litellm_modify_params_ctx(self.modify_params):
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore", category=DeprecationWarning, module="httpx.*"
                )
                warnings.filterwarnings(
                    "ignore",
                    message=r".*content=.*upload.*",
                    category=DeprecationWarning,
                )
                # Some providers need renames handled in _normalize_call_kwargs.
                ret = litellm_completion(
                    model=self.model,
                    api_key=self.api_key.get_secret_value() if self.api_key else None,
                    base_url=self.base_url,
                    api_version=self.api_version,
                    timeout=self.timeout,
                    drop_params=self.drop_params,
                    seed=self.seed,
                    messages=messages,
                    **kwargs,
                )
                assert isinstance(ret, ModelResponse), (
                    f"Expected ModelResponse, got {type(ret)}"
                )
                return ret

    @contextmanager
    def _litellm_modify_params_ctx(self, flag: bool):
        old = getattr(litellm, "modify_params", None)
        try:
            litellm.modify_params = flag
            yield
        finally:
            litellm.modify_params = old

    def _normalize_call_kwargs(self, opts: dict, *, has_tools: bool) -> dict:
        """Central place for provider quirks + param harmonization."""
        out = dict(opts)

        # Respect configured sampling params unless reasoning models override
        if self.top_k is not None:
            out.setdefault("top_k", self.top_k)
        if self.top_p is not None:
            out.setdefault("top_p", self.top_p)
        if self.temperature is not None:
            out.setdefault("temperature", self.temperature)

        # Max tokens wiring differences
        if self.max_output_tokens is not None:
            # OpenAI-compatible param is `max_completion_tokens`
            out.setdefault("max_completion_tokens", self.max_output_tokens)

        # Azure -> uses max_tokens instead
        if self.model.startswith("azure"):
            if "max_completion_tokens" in out:
                out["max_tokens"] = out.pop("max_completion_tokens")

        # Reasoning-model quirks
        if get_features(self.model).supports_reasoning_effort:
            # Preferred: use reasoning_effort
            if self.reasoning_effort is not None:
                out["reasoning_effort"] = self.reasoning_effort
            # Anthropic/OpenAI reasoning models ignore temp/top_p
            out.pop("temperature", None)
            out.pop("top_p", None)
            # Gemini 2.5-pro default to low if not set
            # otherwise litellm doesn't send reasoning, even though it happens
            if "gemini-2.5-pro" in self.model:
                if self.reasoning_effort in {None, "none"}:
                    out["reasoning_effort"] = "low"

        # Anthropic Opus 4.1: prefer temperature when
        # both provided; disable extended thinking
        if "claude-opus-4-1" in self.model.lower():
            if "temperature" in out and "top_p" in out:
                out.pop("top_p", None)
            out.setdefault("thinking", {"type": "disabled"})

        # Mistral / Gemini safety
        if self.safety_settings:
            ml = self.model.lower()
            if "mistral" in ml or "gemini" in ml:
                out["safety_settings"] = self.safety_settings

        # Tools: if not using native, strip tool_choice so we don't confuse providers
        if not has_tools:
            out.pop("tools", None)
            out.pop("tool_choice", None)

        # non litellm proxy special-case: keep `extra_body` off unless model requires it
        if "litellm_proxy" not in self.model:
            out.pop("extra_body", None)

        return out

    def _pre_request_prompt_mock(
        self, messages: list[dict], tools: list[ChatCompletionToolParam], kwargs: dict
    ) -> tuple[list[dict], dict]:
        """Convert to non-fncall prompting when native tool-calling is off."""
        add_iclex = not any(s in self.model for s in ("openhands-lm", "devstral"))
        messages = convert_fncall_messages_to_non_fncall_messages(
            messages, tools, add_in_context_learning_example=add_iclex
        )
        if get_features(self.model).supports_stop_words and not self.disable_stop_word:
            kwargs = dict(kwargs)
            kwargs["stop"] = STOP_WORDS

        # Ensure we don't send tool_choice when mocking
        kwargs.pop("tool_choice", None)
        return messages, kwargs

    def _post_response_prompt_mock(
        self,
        resp: ModelResponse,
        nonfncall_msgs: list[dict],
        tools: list[ChatCompletionToolParam],
    ) -> ModelResponse:
        if len(resp.choices) < 1:
            raise LLMNoResponseError(
                "Response choices is less than 1 (seen in some providers). Resp: "
                + str(resp)
            )

        def _all_choices(
            items: list[Choices | StreamingChoices],
        ) -> TypeGuard[list[Choices]]:
            return all(isinstance(c, Choices) for c in items)

        if not _all_choices(resp.choices):
            raise AssertionError(
                "Expected non-streaming Choices when post-processing mocked tools"
            )

        # Preserve provider-specific reasoning fields before conversion
        orig_msg = resp.choices[0].message
        non_fn_message: dict = orig_msg.model_dump()
        fn_msgs: list[dict] = convert_non_fncall_messages_to_fncall_messages(
            nonfncall_msgs + [non_fn_message], tools
        )
        last: dict = fn_msgs[-1]

        for name in ("reasoning_content", "provider_specific_fields"):
            val = getattr(orig_msg, name, None)
            if not val:
                continue
            last[name] = val

        resp.choices[0].message = LiteLLMMessage.model_validate(last)
        return resp

    # =========================================================================
    # Capabilities, formatting, and info
    # =========================================================================
    def _init_model_info_and_caps(self) -> None:
        # Try to get model info via openrouter or litellm proxy first
        tried = False
        try:
            if self.model.startswith("openrouter"):
                self._model_info = get_model_info(self.model)
                tried = True
        except Exception as e:
            logger.debug(f"get_model_info(openrouter) failed: {e}")

        if not tried and self.model.startswith("litellm_proxy/"):
            # IF we are using LiteLLM proxy, get model info from LiteLLM proxy
            # GET {base_url}/v1/model/info with litellm_model_id as path param
            base_url = self.base_url.strip() if self.base_url else ""
            if not base_url.startswith(("http://", "https://")):
                base_url = "http://" + base_url
            try:
                api_key = self.api_key.get_secret_value() if self.api_key else ""
                response = httpx.get(
                    f"{base_url}/v1/model/info",
                    headers={"Authorization": f"Bearer {api_key}"},
                )
                data = response.json().get("data", [])
                current = next(
                    (
                        info
                        for info in data
                        if info["model_name"]
                        == self.model.removeprefix("litellm_proxy/")
                    ),
                    None,
                )
                if current:
                    self._model_info = current.get("model_info")
                    logger.debug(
                        f"Got model info from litellm proxy: {self._model_info}"
                    )
            except Exception as e:
                logger.info(f"Error fetching model info from proxy: {e}")

        # Fallbacks: try base name variants
        if not self._model_info:
            try:
                self._model_info = get_model_info(self.model.split(":")[0])
            except Exception:
                pass
        if not self._model_info:
            try:
                self._model_info = get_model_info(self.model.split("/")[-1])
            except Exception:
                pass

        # Context window and max_output_tokens
        if (
            self.max_input_tokens is None
            and self._model_info is not None
            and isinstance(self._model_info.get("max_input_tokens"), int)
        ):
            self.max_input_tokens = self._model_info.get("max_input_tokens")

        if self.max_output_tokens is None:
            if any(m in self.model for m in ["claude-3-7-sonnet", "claude-3.7-sonnet"]):
                self.max_output_tokens = (
                    64000  # practical cap (litellm may allow 128k with header)
                )
            elif self._model_info is not None:
                if isinstance(self._model_info.get("max_output_tokens"), int):
                    self.max_output_tokens = self._model_info.get("max_output_tokens")
                elif isinstance(self._model_info.get("max_tokens"), int):
                    self.max_output_tokens = self._model_info.get("max_tokens")

        # Function-calling capabilities
        feats = get_features(self.model)
        logger.info(f"Model features for {self.model}: {feats}")
        self._function_calling_active = (
            self.native_tool_calling
            if self.native_tool_calling is not None
            else feats.supports_function_calling
        )

    def vision_is_active(self) -> bool:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return not self.disable_vision and self._supports_vision()

    def _supports_vision(self) -> bool:
        """Acquire from litellm if model is vision capable.

        Returns:
            bool: True if model is vision capable. Return False if model not
                supported by litellm.
        """
        # litellm.supports_vision currently returns False for 'openai/gpt-...' or 'anthropic/claude-...' (with prefixes)  # noqa: E501
        # but model_info will have the correct value for some reason.
        # we can go with it, but we will need to keep an eye if model_info is correct for Vertex or other providers  # noqa: E501
        # remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608  # noqa: E501
        # Check both the full model name and the name after proxy prefix for vision support  # noqa: E501
        return (
            supports_vision(self.model)
            or supports_vision(self.model.split("/")[-1])
            or (
                self._model_info is not None
                and self._model_info.get("supports_vision", False)
            )
            or False  # fallback to False if model_info is None
        )

    def is_caching_prompt_active(self) -> bool:
        """Check if prompt caching is supported and enabled for current model.

        Returns:
            boolean: True if prompt caching is supported and enabled for the given
                model.
        """
        if not self.caching_prompt:
            return False
        # We don't need to look-up model_info, because
        # only Anthropic models need explicit caching breakpoints
        return self.caching_prompt and get_features(self.model).supports_prompt_cache

    def is_function_calling_active(self) -> bool:
        """Returns whether function calling is supported
        and enabled for this LLM instance.
        """
        return bool(self._function_calling_active)

    @property
    def model_info(self) -> dict | None:
        """Returns the model info dictionary."""
        return self._model_info

    # =========================================================================
    # Utilities preserved from previous class
    # =========================================================================
    def _apply_prompt_caching(self, messages: list[Message]) -> None:
        """Applies caching breakpoints to the messages.

        For new Anthropic API, we only need to mark the last user or
        tool message as cacheable.
        """
        if len(messages) > 0 and messages[0].role == "system":
            messages[0].content[-1].cache_prompt = True
        # NOTE: this is only needed for anthropic
        for message in reversed(messages):
            if message.role in ("user", "tool"):
                message.content[
                    -1
                ].cache_prompt = True  # Last item inside the message content
                break

    def format_messages_for_llm(self, messages: list[Message]) -> list[dict]:
        """Formats Message objects for LLM consumption."""

        messages = copy.deepcopy(messages)
        if self.is_caching_prompt_active():
            self._apply_prompt_caching(messages)

        for message in messages:
            message.cache_enabled = self.is_caching_prompt_active()
            message.vision_enabled = self.vision_is_active()
            message.function_calling_enabled = self.is_function_calling_active()
            if "deepseek" in self.model or (
                "kimi-k2-instruct" in self.model and "groq" in self.model
            ):
                message.force_string_serializer = True

        return [message.to_llm_dict() for message in messages]

    def get_token_count(self, messages: list[dict] | list[Message]) -> int:
        if isinstance(messages, list) and messages and isinstance(messages[0], Message):
            logger.info(
                "Message objects now include serialized tool calls in token counting"
            )
            messages = self.format_messages_for_llm(cast(list[Message], messages))
        try:
            return int(
                token_counter(
                    model=self.model,
                    messages=messages,  # type: ignore[arg-type]
                    custom_tokenizer=self._tokenizer,
                )
            )
        except Exception as e:
            logger.error(
                f"Error getting token count for model {self.model}\n{e}"
                + (
                    f"\ncustom_tokenizer: {self.custom_tokenizer}"
                    if self.custom_tokenizer
                    else ""
                ),
                exc_info=True,
            )
            return 0

    # =========================================================================
    # Serialization helpers
    # =========================================================================
    @classmethod
    def deserialize(cls, data: dict[str, Any]) -> "LLM":
        return cls(**data)

    def serialize(self) -> dict[str, Any]:
        return self.model_dump()

    @classmethod
    def load_from_json(cls, json_path: str) -> "LLM":
        with open(json_path, "r") as f:
            data = json.load(f)
        return cls.deserialize(data)

    @classmethod
    def load_from_env(cls, prefix: str = "LLM_") -> "LLM":
        TRUTHY = {"true", "1", "yes", "on"}

        def _unwrap_type(t: Any) -> Any:
            origin = get_origin(t)
            if origin is None:
                return t
            args = [a for a in get_args(t) if a is not type(None)]
            return args[0] if args else t

        def _cast_value(raw: str, t: Any) -> Any:
            t = _unwrap_type(t)
            if t is SecretStr:
                return SecretStr(raw)
            if t is bool:
                return raw.lower() in TRUTHY
            if t is int:
                try:
                    return int(raw)
                except ValueError:
                    return None
            if t is float:
                try:
                    return float(raw)
                except ValueError:
                    return None
            origin = get_origin(t)
            if (origin in (list, dict, tuple)) or (
                isinstance(t, type) and issubclass(t, BaseModel)
            ):
                try:
                    return json.loads(raw)
                except Exception:
                    pass
            return raw

        data: dict[str, Any] = {}
        fields: dict[str, Any] = {
            name: f.annotation
            for name, f in cls.model_fields.items()
            if not getattr(f, "exclude", False)
        }

        for key, value in os.environ.items():
            if not key.startswith(prefix):
                continue
            field_name = key[len(prefix) :].lower()
            if field_name not in fields:
                continue
            v = _cast_value(value, fields[field_name])
            if v is not None:
                data[field_name] = v
        return cls.deserialize(data)

    @classmethod
    def load_from_toml(cls, toml_path: str) -> "LLM":
        try:
            import tomllib
        except ImportError:
            try:
                import tomli as tomllib  # type: ignore
            except ImportError:
                raise ImportError("tomllib or tomli is required to load TOML files")
        with open(toml_path, "rb") as f:
            data = tomllib.load(f)
        if "llm" in data:
            data = data["llm"]
        return cls.deserialize(data)