speedy_utils-1.1.17-py3-none-any.whl → speedy_utils-1.1.18-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +8 -1
- llm_utils/chat_format/display.py +109 -14
- llm_utils/lm/__init__.py +12 -11
- llm_utils/lm/async_lm/async_llm_task.py +0 -10
- llm_utils/lm/async_lm/async_lm.py +13 -4
- llm_utils/lm/async_lm/async_lm_base.py +24 -14
- llm_utils/lm/base_prompt_builder.py +288 -0
- llm_utils/lm/llm_task.py +400 -0
- llm_utils/lm/lm.py +207 -0
- llm_utils/lm/lm_base.py +285 -0
- llm_utils/vector_cache/core.py +285 -89
- speedy_utils/common/patcher.py +68 -0
- speedy_utils/common/utils_cache.py +5 -5
- speedy_utils/common/utils_io.py +232 -6
- speedy_utils/multi_worker/process.py +124 -193
- {speedy_utils-1.1.17.dist-info → speedy_utils-1.1.18.dist-info}/METADATA +3 -2
- {speedy_utils-1.1.17.dist-info → speedy_utils-1.1.18.dist-info}/RECORD +19 -14
- {speedy_utils-1.1.17.dist-info → speedy_utils-1.1.18.dist-info}/WHEEL +1 -1
- {speedy_utils-1.1.17.dist-info → speedy_utils-1.1.18.dist-info}/entry_points.txt +0 -0
llm_utils/lm/lm.py
ADDED
@@ -0,0 +1,207 @@
# # from ._utils import *
# from typing import (
#     Any,
#     List,
#     Literal,
#     Optional,
#     Type,
#     Union,
#     cast,
# )

# from loguru import logger
# from openai import AuthenticationError, BadRequestError, OpenAI, RateLimitError
# from pydantic import BaseModel
# from speedy_utils import jloads

# # from llm_utils.lm.async_lm.async_llm_task import OutputModelType
# from llm_utils.lm.lm_base import LMBase

# from .async_lm._utils import (
#     LegacyMsgs,
#     Messages,
#     OutputModelType,
#     ParsedOutput,
#     RawMsgs,
# )


# class LM(LMBase):
#     """Unified **sync** language‑model wrapper with optional JSON parsing."""

#     def __init__(
#         self,
#         *,
#         model: Optional[str] = None,
#         response_model: Optional[type[BaseModel]] = None,
#         temperature: float = 0.0,
#         max_tokens: int = 2_000,
#         base_url: Optional[str] = None,
#         api_key: Optional[str] = None,
#         cache: bool = True,
#         ports: Optional[List[int]] = None,
#         top_p: float = 1.0,
#         presence_penalty: float = 0.0,
#         top_k: int = 1,
#         repetition_penalty: float = 1.0,
#         frequency_penalty: Optional[float] = None,
#     ) -> None:

#         if model is None:
#             if base_url is None:
#                 raise ValueError("Either model or base_url must be provided")
#             models = OpenAI(base_url=base_url, api_key=api_key or 'abc').models.list().data
#             assert len(models) == 1, f"Found {len(models)} models, please specify one."
#             model = models[0].id
#             print(f"Using model: {model}")

#         super().__init__(
#             ports=ports,
#             base_url=base_url,
#             cache=cache,
#             api_key=api_key,
#         )

#         # Model behavior options
#         self.response_model = response_model

#         # Store all model-related parameters in model_kwargs
#         self.model_kwargs = dict(
#             model=model,
#             temperature=temperature,
#             max_tokens=max_tokens,
#             top_p=top_p,
#             presence_penalty=presence_penalty,
#         )
#         self.extra_body = dict(
#             top_k=top_k,
#             repetition_penalty=repetition_penalty,
#             frequency_penalty=frequency_penalty,
#         )

#     def _unified_client_call(
#         self,
#         messages: RawMsgs,
#         extra_body: Optional[dict] = None,
#         max_tokens: Optional[int] = None,
#     ) -> dict:
#         """Unified method for all client interactions (caching handled by MOpenAI)."""
#         converted_messages: Messages = (
#             self._convert_messages(cast(LegacyMsgs, messages))
#             if messages and isinstance(messages[0], dict)
#             else cast(Messages, messages)
#         )
#         if max_tokens is not None:
#             self.model_kwargs["max_tokens"] = max_tokens

#         try:
#             # Get completion from API (caching handled by MOpenAI)
#             call_kwargs = {
#                 "messages": converted_messages,
#                 **self.model_kwargs,
#             }
#             if extra_body:
#                 call_kwargs["extra_body"] = extra_body

#             completion = self.client.chat.completions.create(**call_kwargs)

#             if hasattr(completion, "model_dump"):
#                 completion = completion.model_dump()

#         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
#             error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
#             logger.error(error_msg)
#             raise

#         return completion

#     def __call__(
#         self,
#         prompt: Optional[str] = None,
#         messages: Optional[RawMsgs] = None,
#         max_tokens: Optional[int] = None,
#     ):  # -> tuple[Any | dict[Any, Any], list[ChatCompletionMessagePar...
#         """Unified sync call for language model, returns (assistant_message.model_dump(), messages)."""
#         if (prompt is None) == (messages is None):
#             raise ValueError("Provide *either* `prompt` or `messages` (but not both).")

#         if prompt is not None:
#             messages = [{"role": "user", "content": prompt}]

#         assert messages is not None

#         openai_msgs: Messages = (
#             self._convert_messages(cast(LegacyMsgs, messages))
#             if isinstance(messages[0], dict)
#             else cast(Messages, messages)
#         )

#         assert self.model_kwargs["model"] is not None, (
#             "Model must be set before making a call."
#         )

#         # Use unified client call
#         raw_response = self._unified_client_call(
#             list(openai_msgs), max_tokens=max_tokens
#         )

#         if hasattr(raw_response, "model_dump"):
#             raw_response = raw_response.model_dump()  # type: ignore

#         # Extract the assistant's message
#         assistant_msg = raw_response["choices"][0]["message"]
#         # Build the full messages list (input + assistant reply)
#         full_messages = list(messages) + [
#             {"role": assistant_msg["role"], "content": assistant_msg["content"]}
#         ]
#         # Return the OpenAI message as model_dump (if available) and the messages list
#         if hasattr(assistant_msg, "model_dump"):
#             msg_dump = assistant_msg.model_dump()
#         else:
#             msg_dump = dict(assistant_msg)
#         return msg_dump, full_messages

#     def parse(
#         self,
#         messages: Messages,
#         response_model: Optional[type[BaseModel]] = None,
#     ) -> ParsedOutput[BaseModel]:
#         """Parse response using OpenAI's native parse API."""
#         # Use provided response_model or fall back to instance default
#         model_to_use = response_model or self.response_model
#         assert model_to_use is not None, "response_model must be provided or set at init."

#         # Use OpenAI's native parse API directly
#         response = self.client.chat.completions.parse(
#             model=self.model_kwargs["model"],
#             messages=messages,
#             response_format=model_to_use,
#             **{k: v for k, v in self.model_kwargs.items() if k != "model"}
#         )

#         parsed = response.choices[0].message.parsed
#         completion = response.model_dump() if hasattr(response, "model_dump") else {}
#         full_messages = list(messages) + [
#             {"role": "assistant", "content": parsed}
#         ]

#         return ParsedOutput(
#             messages=full_messages,
#             parsed=cast(BaseModel, parsed),
#             completion=completion,
#             model_kwargs=self.model_kwargs,
#         )



#     def __enter__(self):
#         return self

#     def __exit__(self, exc_type, exc_val, exc_tb):
#         if hasattr(self, "_last_client"):
#             last_client = self._last_client  # type: ignore
#             if hasattr(last_client, "close"):
#                 last_client.close()
#         else:
#             logger.warning("No last client to close")
LM = None
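Note that every line of the new lm.py above is commented out except the final assignment, so 1.1.18 ships this module as a placeholder with LM = None; the working sync wrapper appears to be LMBase in lm_base.py below. A minimal, hypothetical guard for downstream code that still imports the old symbol (the import path comes from this diff; the fallback behavior and message are illustrative):

import llm_utils.lm.lm as lm_module

if lm_module.LM is None:
    # In 1.1.18, lm.py is fully commented out, so LM is only a placeholder.
    raise ImportError(
        "LM is unavailable in this release; use LMBase from llm_utils.lm.lm_base instead"
    )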
llm_utils/lm/lm_base.py
ADDED
@@ -0,0 +1,285 @@
# from ._utils import *
import json
import os
from typing import (
    Any,
    List,
    Optional,
    Type,
    Union,
    cast,
    overload,
)

from httpx import URL
from loguru import logger
from openai import OpenAI
from openai.pagination import SyncPage
from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionToolMessageParam,
    ChatCompletionUserMessageParam,
)
from openai.types.model import Model
from pydantic import BaseModel

from llm_utils.lm.openai_memoize import MOpenAI

from .async_lm._utils import (
    LegacyMsgs,
    Messages,
    RawMsgs,
    TModel,
)


class LMBase:
    """Unified **sync** language‑model wrapper with optional JSON parsing."""

    def __init__(
        self,
        *,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        cache: bool = True,
        ports: Optional[List[int]] = None,
    ) -> None:
        self.base_url = base_url
        self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
        self._cache = cache
        self.ports = ports

    @property
    def client(self) -> MOpenAI:
        # if have multiple ports
        if self.ports and self.base_url:
            import random
            import re

            port = random.choice(self.ports)
            # Replace port in base_url if it exists
            base_url_pattern = r'(https?://[^:/]+):?\d*(/.*)?'
            match = re.match(base_url_pattern, self.base_url)
            if match:
                host_part = match.group(1)
                path_part = match.group(2) or '/v1'
                api_base = f"{host_part}:{port}{path_part}"
            else:
                api_base = self.base_url
            logger.debug(f"Using port: {port}")
        else:
            api_base = self.base_url

        if api_base is None:
            raise ValueError("base_url must be provided")

        client = MOpenAI(
            api_key=self.api_key,
            base_url=api_base,
            cache=self._cache,
        )
        self._last_client = client
        return client

    # ------------------------------------------------------------------ #
    # Public API – typed overloads
    # ------------------------------------------------------------------ #
    @overload
    def __call__(  # type: ignore
        self,
        *,
        prompt: Optional[str] = ...,
        messages: Optional[RawMsgs] = ...,
        response_format: type[str] = str,
        return_openai_response: bool = ...,
        **kwargs: Any,
    ) -> str: ...

    @overload
    def __call__(
        self,
        *,
        prompt: Optional[str] = ...,
        messages: Optional[RawMsgs] = ...,
        response_format: Type[TModel],
        return_openai_response: bool = ...,
        **kwargs: Any,
    ) -> TModel: ...

    # ------------------------------------------------------------------ #
    # Utilities below are unchanged (sync I/O is acceptable)
    # ------------------------------------------------------------------ #
    @staticmethod
    def _convert_messages(msgs: LegacyMsgs) -> Messages:
        converted: Messages = []
        for msg in msgs:
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                converted.append(
                    ChatCompletionUserMessageParam(role="user", content=content)
                )
            elif role == "assistant":
                converted.append(
                    ChatCompletionAssistantMessageParam(
                        role="assistant", content=content
                    )
                )
            elif role == "system":
                converted.append(
                    ChatCompletionSystemMessageParam(role="system", content=content)
                )
            elif role == "tool":
                converted.append(
                    ChatCompletionToolMessageParam(
                        role="tool",
                        content=content,
                        tool_call_id=msg.get("tool_call_id") or "",
                    )
                )
            else:
                converted.append({"role": role, "content": content})  # type: ignore[arg-type]
        return converted

    @staticmethod
    def _parse_output(
        raw_response: Any, response_format: Union[type[str], Type[BaseModel]]
    ) -> Union[str, BaseModel]:
        if hasattr(raw_response, "model_dump"):
            raw_response = raw_response.model_dump()

        if response_format is str:
            if isinstance(raw_response, dict) and "choices" in raw_response:
                message = raw_response["choices"][0]["message"]
                return message.get("content", "") or ""
            return cast(str, raw_response)

        model_cls = cast(Type[BaseModel], response_format)

        if isinstance(raw_response, dict) and "choices" in raw_response:
            message = raw_response["choices"][0]["message"]
            if "parsed" in message:
                return model_cls.model_validate(message["parsed"])
            content = message.get("content")
            if content is None:
                raise ValueError("Model returned empty content")
            try:
                data = json.loads(content)
                return model_cls.model_validate(data)
            except Exception as exc:
                raise ValueError(
                    f"Failed to parse model output as JSON:\n{content}"
                ) from exc

        if isinstance(raw_response, model_cls):
            return raw_response
        if isinstance(raw_response, dict):
            return model_cls.model_validate(raw_response)

        try:
            data = json.loads(raw_response)
            return model_cls.model_validate(data)
        except Exception as exc:
            raise ValueError(
                f"Model did not return valid JSON:\n---\n{raw_response}"
            ) from exc

    # ------------------------------------------------------------------ #
    # Misc helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def list_models(base_url: Optional[str] = None) -> List[str]:
        try:
            if base_url is None:
                raise ValueError("base_url must be provided")
            client = LMBase(base_url=base_url).client
            base_url_obj: URL = client.base_url
            logger.debug(f"Base URL: {base_url_obj}")
            models: SyncPage[Model] = client.models.list()  # type: ignore[assignment]
            return [model.id for model in models.data]
        except Exception as exc:
            logger.error(f"Failed to list models: {exc}")
            return []

    def build_system_prompt(
        self,
        response_model,
        add_json_schema_to_instruction,
        json_schema,
        system_content,
        think,
    ):
        if add_json_schema_to_instruction and response_model:
            schema_block = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
            # if schema_block not in system_content:
            if "<output_json_schema>" in system_content:
                # remove exsting schema block
                import re  # replace

                system_content = re.sub(
                    r"<output_json_schema>.*?</output_json_schema>",
                    "",
                    system_content,
                    flags=re.DOTALL,
                )
                system_content = system_content.strip()
            system_content += schema_block

        if think is True:
            if "/think" in system_content:
                pass
            elif "/no_think" in system_content:
                system_content = system_content.replace("/no_think", "/think")
            else:
                system_content += "\n\n/think"
        elif think is False:
            if "/no_think" in system_content:
                pass
            elif "/think" in system_content:
                system_content = system_content.replace("/think", "/no_think")
            else:
                system_content += "\n\n/no_think"
        return system_content

    def inspect_history(self):
        """Inspect the history of the LLM calls."""
        pass


def get_model_name(client: OpenAI | str | int) -> str:
    """
    Get the first available model name from the client.

    Args:
        client: OpenAI client, base_url string, or port number

    Returns:
        Name of the first available model

    Raises:
        ValueError: If no models are available or client is invalid
    """
    try:
        if isinstance(client, OpenAI):
            openai_client = client
        elif isinstance(client, str):
            # String base_url
            openai_client = OpenAI(base_url=client, api_key='abc')
        elif isinstance(client, int):
            # Port number
            base_url = f"http://localhost:{client}/v1"
            openai_client = OpenAI(base_url=base_url, api_key='abc')
        else:
            raise ValueError(f"Unsupported client type: {type(client)}")

        models = openai_client.models.list()
        if not models.data:
            raise ValueError("No models available")

        return models.data[0].id

    except Exception as exc:
        logger.error(f"Failed to get model name: {exc}")
        raise ValueError(f"Could not retrieve model name: {exc}") from exc
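For orientation, here is a minimal sketch of how the new lm_base.py pieces might be exercised against a local OpenAI-compatible server; the port, base URL, and running model are hypothetical, and MOpenAI from llm_utils.lm.openai_memoize is assumed to behave as a drop-in, optionally caching OpenAI client, as the import in the diff implies:

from llm_utils.lm.lm_base import LMBase, get_model_name

# Hypothetical local endpoint serving a single model (e.g. a vLLM server); adjust to your deployment.
model_id = get_model_name(8000)  # an int argument expands to http://localhost:8000/v1

lm = LMBase(base_url="http://localhost:8000/v1", cache=True, ports=[8000, 8001])
client = lm.client  # builds an MOpenAI client; a port is picked at random from ports
print(model_id)
print(LMBase.list_models("http://localhost:8000/v1"))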