speedy-utils 1.1.27__py3-none-any.whl → 1.1.28__py3-none-any.whl
This diff shows the content changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- llm_utils/__init__.py +16 -4
- llm_utils/chat_format/__init__.py +10 -10
- llm_utils/chat_format/display.py +33 -21
- llm_utils/chat_format/transform.py +17 -19
- llm_utils/chat_format/utils.py +6 -4
- llm_utils/group_messages.py +17 -14
- llm_utils/lm/__init__.py +6 -5
- llm_utils/lm/async_lm/__init__.py +1 -0
- llm_utils/lm/async_lm/_utils.py +10 -9
- llm_utils/lm/async_lm/async_llm_task.py +141 -137
- llm_utils/lm/async_lm/async_lm.py +48 -42
- llm_utils/lm/async_lm/async_lm_base.py +59 -60
- llm_utils/lm/async_lm/lm_specific.py +4 -3
- llm_utils/lm/base_prompt_builder.py +93 -70
- llm_utils/lm/llm.py +126 -108
- llm_utils/lm/llm_signature.py +4 -2
- llm_utils/lm/lm_base.py +72 -73
- llm_utils/lm/mixins.py +102 -62
- llm_utils/lm/openai_memoize.py +124 -87
- llm_utils/lm/signature.py +105 -92
- llm_utils/lm/utils.py +42 -23
- llm_utils/scripts/vllm_load_balancer.py +23 -30
- llm_utils/scripts/vllm_serve.py +8 -7
- llm_utils/vector_cache/__init__.py +9 -3
- llm_utils/vector_cache/cli.py +1 -1
- llm_utils/vector_cache/core.py +59 -63
- llm_utils/vector_cache/types.py +7 -5
- llm_utils/vector_cache/utils.py +12 -8
- speedy_utils/__imports.py +244 -0
- speedy_utils/__init__.py +90 -194
- speedy_utils/all.py +125 -227
- speedy_utils/common/clock.py +37 -42
- speedy_utils/common/function_decorator.py +6 -12
- speedy_utils/common/logger.py +43 -52
- speedy_utils/common/notebook_utils.py +13 -21
- speedy_utils/common/patcher.py +21 -17
- speedy_utils/common/report_manager.py +42 -44
- speedy_utils/common/utils_cache.py +152 -169
- speedy_utils/common/utils_io.py +137 -103
- speedy_utils/common/utils_misc.py +15 -21
- speedy_utils/common/utils_print.py +22 -28
- speedy_utils/multi_worker/process.py +66 -79
- speedy_utils/multi_worker/thread.py +78 -155
- speedy_utils/scripts/mpython.py +38 -36
- speedy_utils/scripts/openapi_client_codegen.py +10 -10
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.28.dist-info}/METADATA +1 -1
- speedy_utils-1.1.28.dist-info/RECORD +57 -0
- vision_utils/README.md +202 -0
- vision_utils/__init__.py +5 -0
- vision_utils/io_utils.py +470 -0
- vision_utils/plot.py +345 -0
- speedy_utils-1.1.27.dist-info/RECORD +0 -52
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.28.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.28.dist-info}/entry_points.txt +0 -0
llm_utils/lm/lm_base.py
CHANGED
@@ -40,23 +40,23 @@ class LMBase:
     def __init__(
         self,
         *,
-        base_url:
-        api_key:
+        base_url: str | None = None,
+        api_key: str | None = None,
         cache: bool = True,
-        ports:
+        ports: list[int] | None = None,
     ) -> None:
         self.base_url = base_url
-        self.api_key = api_key or os.getenv(
+        self.api_key = api_key or os.getenv('OPENAI_API_KEY', 'abc')
         self._cache = cache
         self.ports = ports

     @property
-    def client(self) -> MOpenAI:
+    def client(self) -> MOpenAI: # type: ignore
         # if have multiple ports
         if self.ports and self.base_url:
             import random
             import re
-
+
             port = random.choice(self.ports)
             # Replace port in base_url if it exists
             base_url_pattern = r'(https?://[^:/]+):?\d*(/.*)?'
@@ -64,16 +64,16 @@ class LMBase:
             if match:
                 host_part = match.group(1)
                 path_part = match.group(2) or '/v1'
-                api_base = f
+                api_base = f'{host_part}:{port}{path_part}'
             else:
                 api_base = self.base_url
-            logger.debug(f
+            logger.debug(f'Using port: {port}')
         else:
             api_base = self.base_url
-
+
         if api_base is None:
-            raise ValueError(
-
+            raise ValueError('base_url must be provided')
+
         client = MOpenAI(
             api_key=self.api_key,
             base_url=api_base,
@@ -89,8 +89,8 @@ class LMBase:
     def __call__( # type: ignore
         self,
         *,
-        prompt:
-        messages:
+        prompt: str | None = ...,
+        messages: RawMsgs | None = ...,
         response_format: type[str] = str,
         return_openai_response: bool = ...,
         **kwargs: Any,
@@ -100,9 +100,9 @@ class LMBase:
     def __call__(
         self,
         *,
-        prompt:
-        messages:
-        response_format:
+        prompt: str | None = ...,
+        messages: RawMsgs | None = ...,
+        response_format: type[TModel],
         return_openai_response: bool = ...,
         **kwargs: Any,
     ) -> TModel: ...
@@ -114,62 +114,62 @@ class LMBase:
     def _convert_messages(msgs: LegacyMsgs) -> Messages:
         converted: Messages = []
         for msg in msgs:
-            role = msg[
-            content = msg[
-            if role ==
+            role = msg['role']
+            content = msg['content']
+            if role == 'user':
                 converted.append(
-                    ChatCompletionUserMessageParam(role=
+                    ChatCompletionUserMessageParam(role='user', content=content)
                 )
-            elif role ==
+            elif role == 'assistant':
                 converted.append(
                     ChatCompletionAssistantMessageParam(
-                        role=
+                        role='assistant', content=content
                     )
                 )
-            elif role ==
+            elif role == 'system':
                 converted.append(
-                    ChatCompletionSystemMessageParam(role=
+                    ChatCompletionSystemMessageParam(role='system', content=content)
                 )
-            elif role ==
+            elif role == 'tool':
                 converted.append(
                     ChatCompletionToolMessageParam(
-                        role=
+                        role='tool',
                         content=content,
-                        tool_call_id=msg.get(
+                        tool_call_id=msg.get('tool_call_id') or '',
                     )
                 )
             else:
-                converted.append({
+                converted.append({'role': role, 'content': content}) # type: ignore[arg-type]
         return converted

     @staticmethod
     def _parse_output(
-        raw_response: Any, response_format:
-    ) ->
-        if hasattr(raw_response,
+        raw_response: Any, response_format: type[str] | type[BaseModel]
+    ) -> str | BaseModel:
+        if hasattr(raw_response, 'model_dump'):
             raw_response = raw_response.model_dump()

         if response_format is str:
-            if isinstance(raw_response, dict) and
-                message = raw_response[
-                return message.get(
+            if isinstance(raw_response, dict) and 'choices' in raw_response:
+                message = raw_response['choices'][0]['message']
+                return message.get('content', '') or ''
             return cast(str, raw_response)

-        model_cls = cast(
+        model_cls = cast(type[BaseModel], response_format)

-        if isinstance(raw_response, dict) and
-            message = raw_response[
-            if
-                return model_cls.model_validate(message[
-            content = message.get(
+        if isinstance(raw_response, dict) and 'choices' in raw_response:
+            message = raw_response['choices'][0]['message']
+            if 'parsed' in message:
+                return model_cls.model_validate(message['parsed'])
+            content = message.get('content')
             if content is None:
-                raise ValueError(
+                raise ValueError('Model returned empty content')
             try:
                 data = json.loads(content)
                 return model_cls.model_validate(data)
             except Exception as exc:
                 raise ValueError(
-                    f
+                    f'Failed to parse model output as JSON:\n{content}'
                 ) from exc

         if isinstance(raw_response, model_cls):
@@ -182,7 +182,7 @@ class LMBase:
             return model_cls.model_validate(data)
         except Exception as exc:
             raise ValueError(
-                f
+                f'Model did not return valid JSON:\n---\n{raw_response}'
             ) from exc

     # ------------------------------------------------------------------ #
@@ -190,17 +190,17 @@ class LMBase:
     # ------------------------------------------------------------------ #

     @staticmethod
-    def list_models(base_url:
+    def list_models(base_url: str | None = None) -> list[str]:
         try:
             if base_url is None:
-                raise ValueError(
+                raise ValueError('base_url must be provided')
             client = LMBase(base_url=base_url).client
             base_url_obj: URL = client.base_url
-            logger.debug(f
+            logger.debug(f'Base URL: {base_url_obj}')
             models: SyncPage[Model] = client.models.list() # type: ignore[assignment]
             return [model.id for model in models.data]
         except Exception as exc:
-            logger.error(f
+            logger.error(f'Failed to list models: {exc}')
             return []

     def build_system_prompt(
@@ -212,15 +212,15 @@ class LMBase:
         think,
     ):
         if add_json_schema_to_instruction and response_model:
-            schema_block = f
+            schema_block = f'\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>'
             # if schema_block not in system_content:
-            if
+            if '<output_json_schema>' in system_content:
                 # remove exsting schema block
                 import re # replace

                 system_content = re.sub(
-                    r
-
+                    r'<output_json_schema>.*?</output_json_schema>',
+                    '',
                     system_content,
                     flags=re.DOTALL,
                 )
@@ -228,36 +228,35 @@ class LMBase:
             system_content += schema_block

         if think is True:
-            if
+            if '/think' in system_content:
                 pass
-            elif
-                system_content = system_content.replace(
+            elif '/no_think' in system_content:
+                system_content = system_content.replace('/no_think', '/think')
             else:
-                system_content +=
+                system_content += '\n\n/think'
         elif think is False:
-            if
+            if '/no_think' in system_content:
                 pass
-            elif
-                system_content = system_content.replace(
+            elif '/think' in system_content:
+                system_content = system_content.replace('/think', '/no_think')
             else:
-                system_content +=
+                system_content += '\n\n/no_think'
         return system_content

     def inspect_history(self):
         """Inspect the history of the LLM calls."""
-        pass
-

-
+
+def get_model_name(client: OpenAI | str | int) -> str:
     """
     Get the first available model name from the client.
-
+
     Args:
         client: OpenAI client, base_url string, or port number
-
+
     Returns:
         Name of the first available model
-
+
     Raises:
         ValueError: If no models are available or client is invalid
     """
@@ -269,17 +268,17 @@ def get_model_name(client: OpenAI|str|int) -> str:
             openai_client = OpenAI(base_url=client, api_key='abc')
         elif isinstance(client, int):
             # Port number
-            base_url = f
+            base_url = f'http://localhost:{client}/v1'
             openai_client = OpenAI(base_url=base_url, api_key='abc')
         else:
-            raise ValueError(f
-
+            raise ValueError(f'Unsupported client type: {type(client)}')
+
         models = openai_client.models.list()
         if not models.data:
-            raise ValueError(
-
+            raise ValueError('No models available')
+
         return models.data[0].id
-
+
     except Exception as exc:
-        logger.error(f
-        raise ValueError(f
+        logger.error(f'Failed to get model name: {exc}')
+        raise ValueError(f'Could not retrieve model name: {exc}') from exc
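Reader note: the constructor above now accepts `ports: list[int] | None`, and the `client` property picks one port at random and splices it into `base_url`. The standalone sketch below is not part of the package; the helper name `pick_api_base` is invented for illustration, but the regex and f-string are the ones shown in the hunk above.

import random
import re


def pick_api_base(base_url: str, ports: list[int]) -> str:
    """Pick a random port and splice it into base_url (mirrors the diff above)."""
    port = random.choice(ports)
    # Same pattern as the client property: scheme + host, optional old port, path.
    match = re.match(r'(https?://[^:/]+):?\d*(/.*)?', base_url)
    if match:
        host_part = match.group(1)
        path_part = match.group(2) or '/v1'
        return f'{host_part}:{port}{path_part}'
    return base_url


print(pick_api_base('http://localhost:8000/v1', [8001, 8002, 8003]))
# e.g. http://localhost:8002/v1

Called repeatedly, this spreads client construction across the listed ports, which is how the new `ports` option load-balances across several local servers.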
llm_utils/lm/mixins.py
CHANGED
@@ -1,14 +1,21 @@
 """Mixin classes for LLM functionality extensions."""

+# type: ignore
+
+from __future__ import annotations
+
 import os
 import subprocess
 from time import sleep
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

 import requests
 from loguru import logger
-
-
+
+
+if TYPE_CHECKING:
+    from openai import OpenAI
+    from pydantic import BaseModel


 class TemperatureRangeMixin:
@@ -16,12 +23,12 @@ class TemperatureRangeMixin:

     def temperature_range_sampling(
         self,
-        input_data:
+        input_data: 'str | BaseModel | list[dict]',
         temperature_ranges: tuple[float, float],
         n: int = 32,
-        response_model:
+        response_model: 'type[BaseModel] | type[str] | None' = None,
         **runtime_kwargs,
-    ) ->
+    ) -> list[dict[str, Any]]:
         """
         Sample LLM responses with a range of temperatures.

@@ -38,11 +45,13 @@ class TemperatureRangeMixin:
         Returns:
             List of response dictionaries from all temperature samples
         """
+        from pydantic import BaseModel
+
         from speedy_utils.multi_worker.thread import multi_thread

         min_temp, max_temp = temperature_ranges
         if n < 2:
-            raise ValueError(f
+            raise ValueError(f'n must be >= 2, got {n}')

         step = (max_temp - min_temp) / (n - 1)
         list_kwargs = []
@@ -56,7 +65,7 @@ class TemperatureRangeMixin:
             list_kwargs.append(kwargs)

         def f(kwargs):
-            i = kwargs.pop(
+            i = kwargs.pop('i')
             sleep(i * 0.05)
             return self.__inner_call__(
                 input_data,
@@ -73,10 +82,10 @@ class TwoStepPydanticMixin:

     def two_step_pydantic_parse(
         self,
-        input_data
-        response_model
+        input_data,
+        response_model,
         **runtime_kwargs,
-    ) ->
+    ) -> list[dict[str, Any]]:
         """
         Parse responses in two steps: text completion then Pydantic parsing.

@@ -91,32 +100,45 @@ class TwoStepPydanticMixin:
         Returns:
             List of parsed response dictionaries
         """
+        from pydantic import BaseModel
+
         # Step 1: Get text completions
         results = self.text_completion(input_data, **runtime_kwargs)
         parsed_results = []

         for result in results:
-            response_text = result[
-            messages = result[
+            response_text = result['parsed']
+            messages = result['messages']

             # Handle reasoning models that use <think> tags
-            if
-                response_text = response_text.split(
+            if '</think>' in response_text:
+                response_text = response_text.split('</think>')[1]

             try:
-                # Try direct parsing
-
+                # Try direct parsing - support both Pydantic v1 and v2
+                if hasattr(response_model, 'model_validate_json'):
+                    # Pydantic v2
+                    parsed = response_model.model_validate_json(response_text)
+                else:
+                    # Pydantic v1
+                    import json
+
+                    parsed = response_model.parse_obj(json.loads(response_text))
             except Exception:
                 # Fallback: use LLM to extract JSON
-                logger.warning(
+                logger.warning('Failed to parse JSON directly, using LLM to extract')
                 _parsed_messages = [
                     {
-
-
+                        'role': 'system',
+                        'content': (
+                            'You are a helpful assistant that extracts JSON from text.'
+                        ),
                     },
                     {
-
-
+                        'role': 'user',
+                        'content': (
+                            f'Extract JSON from the following text:\n{response_text}'
+                        ),
                     },
                 ]
                 parsed_result = self.pydantic_parse(
@@ -124,9 +146,9 @@ class TwoStepPydanticMixin:
                     response_model=response_model,
                     **runtime_kwargs,
                 )[0]
-                parsed = parsed_result[
+                parsed = parsed_result['parsed']

-            parsed_results.append({
+            parsed_results.append({'parsed': parsed, 'messages': messages})

         return parsed_results

@@ -153,7 +175,7 @@ class VLLMMixin:
             get_base_client,
         )

-        if not hasattr(self,
+        if not hasattr(self, 'vllm_cmd') or not self.vllm_cmd:
             return

         port = _extract_port_from_vllm_cmd(self.vllm_cmd)
@@ -163,26 +185,30 @@ class VLLMMixin:
         try:
             reuse_client = get_base_client(port, cache=False)
             models_response = reuse_client.models.list()
-            if getattr(models_response,
+            if getattr(models_response, 'data', None):
                 reuse_existing = True
                 logger.info(
-                    f
+                    f'VLLM server already running on port {port}, reusing existing server (vllm_reuse=True)'
                 )
             else:
-                logger.info(
+                logger.info(
+                    f'No models returned from VLLM server on port {port}; starting a new server'
+                )
         except Exception as exc:
             logger.info(
-                f
+                f'Unable to reach VLLM server on port {port} (list_models failed): {exc}. Starting a new server.'
             )

         if not self.vllm_reuse:
             if _is_server_running(port):
-                logger.info(
+                logger.info(
+                    f'VLLM server already running on port {port}, killing it first (vllm_reuse=False)'
+                )
                 _kill_vllm_on_port(port)
-            logger.info(f
+            logger.info(f'Starting new VLLM server on port {port}')
             self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
         elif not reuse_existing:
-            logger.info(f
+            logger.info(f'Starting VLLM server on port {port}')
             self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)

     def _load_lora_adapter(self) -> None:
@@ -195,8 +221,8 @@ class VLLMMixin:
         3. Loads the LoRA adapter and updates the model name
         """
         from .utils import (
-            _is_lora_path,
             _get_port_from_client,
+            _is_lora_path,
             _load_lora_adapter,
         )

@@ -204,12 +230,14 @@ class VLLMMixin:
             return

         if not _is_lora_path(self.lora_path):
-            raise ValueError(
+            raise ValueError(
+                f"Invalid LoRA path '{self.lora_path}': Directory must contain 'adapter_config.json'"
+            )

-        logger.info(f
+        logger.info(f'Loading LoRA adapter from: {self.lora_path}')

         # Get the expected LoRA name (basename of the path)
-        lora_name = os.path.basename(self.lora_path.rstrip(
+        lora_name = os.path.basename(self.lora_path.rstrip('/\\'))
         if not lora_name: # Handle edge case of empty basename
             lora_name = os.path.basename(os.path.dirname(self.lora_path))

@@ -217,13 +245,17 @@ class VLLMMixin:
         try:
             available_models = [m.id for m in self.client.models.list().data]
         except Exception as e:
-            logger.warning(
+            logger.warning(
+                f'Failed to list models, proceeding with LoRA load: {str(e)[:100]}'
+            )
             available_models = []

         # Check if LoRA is already loaded
         if lora_name in available_models and not self.force_lora_unload:
-            logger.info(
-
+            logger.info(
+                f"LoRA adapter '{lora_name}' is already loaded, using existing model"
+            )
+            self.model_kwargs['model'] = lora_name
             return

         # Force unload if requested
@@ -233,43 +265,49 @@ class VLLMMixin:
         if port is not None:
             try:
                 VLLMMixin.unload_lora(port, lora_name)
-                logger.info(f
+                logger.info(f'Successfully unloaded LoRA adapter: {lora_name}')
             except Exception as e:
-                logger.warning(f
+                logger.warning(f'Failed to unload LoRA adapter: {str(e)[:100]}')

         # Get port from client for API calls
         port = _get_port_from_client(self.client)
         if port is None:
             raise ValueError(
                 f"Cannot load LoRA adapter '{self.lora_path}': "
-                f
-                f
+                f'Unable to determine port from client base_url. '
+                f'LoRA loading requires a client initialized with port.'
             )

         try:
             # Load the LoRA adapter
             loaded_lora_name = _load_lora_adapter(self.lora_path, port)
-            logger.info(f
+            logger.info(f'Successfully loaded LoRA adapter: {loaded_lora_name}')

             # Update model name to the loaded LoRA name
-            self.model_kwargs[
+            self.model_kwargs['model'] = loaded_lora_name

         except requests.RequestException as e:
             # Check if error is due to LoRA already being loaded
             error_msg = str(e)
-            if
-                logger.info(
+            if '400' in error_msg or 'Bad Request' in error_msg:
+                logger.info(
+                    f"LoRA adapter may already be loaded, attempting to use '{lora_name}'"
+                )
                 # Refresh the model list to check if it's now available
                 try:
                     updated_models = [m.id for m in self.client.models.list().data]
                     if lora_name in updated_models:
-                        logger.info(
-
+                        logger.info(
+                            f"Found LoRA adapter '{lora_name}' in updated model list"
+                        )
+                        self.model_kwargs['model'] = lora_name
                         return
                 except Exception:
                     pass # Fall through to original error

-            raise ValueError(
+            raise ValueError(
+                f"Failed to load LoRA adapter from '{self.lora_path}': {error_msg[:100]}"
+            ) from e

     def unload_lora_adapter(self, lora_path: str) -> None:
         """
@@ -286,14 +324,14 @@ class VLLMMixin:
         port = _get_port_from_client(self.client)
         if port is None:
             raise ValueError(
-
-
-
+                'Cannot unload LoRA adapter: '
+                'Unable to determine port from client base_url. '
+                'LoRA operations require a client initialized with port.'
             )

         _unload_lora_adapter(lora_path, port)
-        lora_name = os.path.basename(lora_path.rstrip(
-        logger.info(f
+        lora_name = os.path.basename(lora_path.rstrip('/\\'))
+        logger.info(f'Unloaded LoRA adapter: {lora_name}')

     @staticmethod
     def unload_lora(port: int, lora_name: str) -> None:
@@ -309,15 +347,15 @@ class VLLMMixin:
         """
         try:
             response = requests.post(
-                f
+                f'http://localhost:{port}/v1/unload_lora_adapter',
                 headers={
-
-
+                    'accept': 'application/json',
+                    'Content-Type': 'application/json',
                 },
-                json={
+                json={'lora_name': lora_name, 'lora_int_id': 0},
             )
             response.raise_for_status()
-            logger.info(f
+            logger.info(f'Successfully unloaded LoRA adapter: {lora_name}')
         except requests.RequestException as e:
             logger.error(f"Error unloading LoRA adapter '{lora_name}': {str(e)[:100]}")
             raise
@@ -326,7 +364,7 @@ class VLLMMixin:
         """Stop the VLLM server process if started by this instance."""
         from .utils import stop_vllm_process

-        if hasattr(self,
+        if hasattr(self, 'vllm_process') and self.vllm_process is not None:
             stop_vllm_process(self.vllm_process)
             self.vllm_process = None

@@ -362,7 +400,7 @@ class ModelUtilsMixin:
     """Mixin for model utility methods."""

     @staticmethod
-    def list_models(client
+    def list_models(client=None) -> list[str]:
         """
         List available models from the OpenAI client.

@@ -372,6 +410,8 @@ class ModelUtilsMixin:
         Returns:
             List of available model names
         """
+        from openai import OpenAI
+
         from .utils import get_base_client

         client_instance = get_base_client(client, cache=False)