speedy-utils 1.1.27__py3-none-any.whl → 1.1.29__py3-none-any.whl

This diff compares the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the changes as they appear in that registry.
Files changed (54)
  1. llm_utils/__init__.py +16 -4
  2. llm_utils/chat_format/__init__.py +10 -10
  3. llm_utils/chat_format/display.py +33 -21
  4. llm_utils/chat_format/transform.py +17 -19
  5. llm_utils/chat_format/utils.py +6 -4
  6. llm_utils/group_messages.py +17 -14
  7. llm_utils/lm/__init__.py +6 -5
  8. llm_utils/lm/async_lm/__init__.py +1 -0
  9. llm_utils/lm/async_lm/_utils.py +10 -9
  10. llm_utils/lm/async_lm/async_llm_task.py +141 -137
  11. llm_utils/lm/async_lm/async_lm.py +48 -42
  12. llm_utils/lm/async_lm/async_lm_base.py +59 -60
  13. llm_utils/lm/async_lm/lm_specific.py +4 -3
  14. llm_utils/lm/base_prompt_builder.py +93 -70
  15. llm_utils/lm/llm.py +126 -108
  16. llm_utils/lm/llm_signature.py +4 -2
  17. llm_utils/lm/lm_base.py +72 -73
  18. llm_utils/lm/mixins.py +102 -62
  19. llm_utils/lm/openai_memoize.py +124 -87
  20. llm_utils/lm/signature.py +105 -92
  21. llm_utils/lm/utils.py +42 -23
  22. llm_utils/scripts/vllm_load_balancer.py +23 -30
  23. llm_utils/scripts/vllm_serve.py +8 -7
  24. llm_utils/vector_cache/__init__.py +9 -3
  25. llm_utils/vector_cache/cli.py +1 -1
  26. llm_utils/vector_cache/core.py +59 -63
  27. llm_utils/vector_cache/types.py +7 -5
  28. llm_utils/vector_cache/utils.py +12 -8
  29. speedy_utils/__imports.py +244 -0
  30. speedy_utils/__init__.py +90 -194
  31. speedy_utils/all.py +125 -227
  32. speedy_utils/common/clock.py +37 -42
  33. speedy_utils/common/function_decorator.py +6 -12
  34. speedy_utils/common/logger.py +43 -52
  35. speedy_utils/common/notebook_utils.py +13 -21
  36. speedy_utils/common/patcher.py +21 -17
  37. speedy_utils/common/report_manager.py +42 -44
  38. speedy_utils/common/utils_cache.py +152 -169
  39. speedy_utils/common/utils_io.py +137 -103
  40. speedy_utils/common/utils_misc.py +15 -21
  41. speedy_utils/common/utils_print.py +22 -28
  42. speedy_utils/multi_worker/process.py +66 -79
  43. speedy_utils/multi_worker/thread.py +78 -155
  44. speedy_utils/scripts/mpython.py +38 -36
  45. speedy_utils/scripts/openapi_client_codegen.py +10 -10
  46. {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/METADATA +1 -1
  47. speedy_utils-1.1.29.dist-info/RECORD +57 -0
  48. vision_utils/README.md +202 -0
  49. vision_utils/__init__.py +4 -0
  50. vision_utils/io_utils.py +735 -0
  51. vision_utils/plot.py +345 -0
  52. speedy_utils-1.1.27.dist-info/RECORD +0 -52
  53. {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/WHEEL +0 -0
  54. {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/entry_points.txt +0 -0
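Much of the per-file churn below is mechanical: double-quoted string literals become single-quoted, and type hints move from typing.Optional/List/Dict/Type to PEP 604 unions and built-in generics. A minimal before/after sketch of the hint change, mirroring the list_models signature shown in the lm_base.py diff (illustration only, not code copied from the package):

# 1.1.27 style
from typing import List, Optional

def list_models(base_url: Optional[str] = None) -> List[str]: ...

# 1.1.29 style: PEP 604 union and built-in generic
def list_models(base_url: str | None = None) -> list[str]: ...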
llm_utils/lm/lm_base.py CHANGED
@@ -40,23 +40,23 @@ class LMBase:
  def __init__(
  self,
  *,
- base_url: Optional[str] = None,
- api_key: Optional[str] = None,
+ base_url: str | None = None,
+ api_key: str | None = None,
  cache: bool = True,
- ports: Optional[List[int]] = None,
+ ports: list[int] | None = None,
  ) -> None:
  self.base_url = base_url
- self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
+ self.api_key = api_key or os.getenv('OPENAI_API_KEY', 'abc')
  self._cache = cache
  self.ports = ports

  @property
- def client(self) -> MOpenAI:
+ def client(self) -> MOpenAI: # type: ignore
  # if have multiple ports
  if self.ports and self.base_url:
  import random
  import re
-
+
  port = random.choice(self.ports)
  # Replace port in base_url if it exists
  base_url_pattern = r'(https?://[^:/]+):?\d*(/.*)?'
@@ -64,16 +64,16 @@ class LMBase:
  if match:
  host_part = match.group(1)
  path_part = match.group(2) or '/v1'
- api_base = f"{host_part}:{port}{path_part}"
+ api_base = f'{host_part}:{port}{path_part}'
  else:
  api_base = self.base_url
- logger.debug(f"Using port: {port}")
+ logger.debug(f'Using port: {port}')
  else:
  api_base = self.base_url
-
+
  if api_base is None:
- raise ValueError("base_url must be provided")
-
+ raise ValueError('base_url must be provided')
+
  client = MOpenAI(
  api_key=self.api_key,
  base_url=api_base,
@@ -89,8 +89,8 @@ class LMBase:
  def __call__( # type: ignore
  self,
  *,
- prompt: Optional[str] = ...,
- messages: Optional[RawMsgs] = ...,
+ prompt: str | None = ...,
+ messages: RawMsgs | None = ...,
  response_format: type[str] = str,
  return_openai_response: bool = ...,
  **kwargs: Any,
@@ -100,9 +100,9 @@ class LMBase:
  def __call__(
  self,
  *,
- prompt: Optional[str] = ...,
- messages: Optional[RawMsgs] = ...,
- response_format: Type[TModel],
+ prompt: str | None = ...,
+ messages: RawMsgs | None = ...,
+ response_format: type[TModel],
  return_openai_response: bool = ...,
  **kwargs: Any,
  ) -> TModel: ...
@@ -114,62 +114,62 @@ class LMBase:
  def _convert_messages(msgs: LegacyMsgs) -> Messages:
  converted: Messages = []
  for msg in msgs:
- role = msg["role"]
- content = msg["content"]
- if role == "user":
+ role = msg['role']
+ content = msg['content']
+ if role == 'user':
  converted.append(
- ChatCompletionUserMessageParam(role="user", content=content)
+ ChatCompletionUserMessageParam(role='user', content=content)
  )
- elif role == "assistant":
+ elif role == 'assistant':
  converted.append(
  ChatCompletionAssistantMessageParam(
- role="assistant", content=content
+ role='assistant', content=content
  )
  )
- elif role == "system":
+ elif role == 'system':
  converted.append(
- ChatCompletionSystemMessageParam(role="system", content=content)
+ ChatCompletionSystemMessageParam(role='system', content=content)
  )
- elif role == "tool":
+ elif role == 'tool':
  converted.append(
  ChatCompletionToolMessageParam(
- role="tool",
+ role='tool',
  content=content,
- tool_call_id=msg.get("tool_call_id") or "",
+ tool_call_id=msg.get('tool_call_id') or '',
  )
  )
  else:
- converted.append({"role": role, "content": content}) # type: ignore[arg-type]
+ converted.append({'role': role, 'content': content}) # type: ignore[arg-type]
  return converted

  @staticmethod
  def _parse_output(
- raw_response: Any, response_format: Union[type[str], Type[BaseModel]]
- ) -> Union[str, BaseModel]:
- if hasattr(raw_response, "model_dump"):
+ raw_response: Any, response_format: type[str] | type[BaseModel]
+ ) -> str | BaseModel:
+ if hasattr(raw_response, 'model_dump'):
  raw_response = raw_response.model_dump()

  if response_format is str:
- if isinstance(raw_response, dict) and "choices" in raw_response:
- message = raw_response["choices"][0]["message"]
- return message.get("content", "") or ""
+ if isinstance(raw_response, dict) and 'choices' in raw_response:
+ message = raw_response['choices'][0]['message']
+ return message.get('content', '') or ''
  return cast(str, raw_response)

- model_cls = cast(Type[BaseModel], response_format)
+ model_cls = cast(type[BaseModel], response_format)

- if isinstance(raw_response, dict) and "choices" in raw_response:
- message = raw_response["choices"][0]["message"]
- if "parsed" in message:
- return model_cls.model_validate(message["parsed"])
- content = message.get("content")
+ if isinstance(raw_response, dict) and 'choices' in raw_response:
+ message = raw_response['choices'][0]['message']
+ if 'parsed' in message:
+ return model_cls.model_validate(message['parsed'])
+ content = message.get('content')
  if content is None:
- raise ValueError("Model returned empty content")
+ raise ValueError('Model returned empty content')
  try:
  data = json.loads(content)
  return model_cls.model_validate(data)
  except Exception as exc:
  raise ValueError(
- f"Failed to parse model output as JSON:\n{content}"
+ f'Failed to parse model output as JSON:\n{content}'
  ) from exc

  if isinstance(raw_response, model_cls):
@@ -182,7 +182,7 @@ class LMBase:
  return model_cls.model_validate(data)
  except Exception as exc:
  raise ValueError(
- f"Model did not return valid JSON:\n---\n{raw_response}"
+ f'Model did not return valid JSON:\n---\n{raw_response}'
  ) from exc

  # ------------------------------------------------------------------ #
@@ -190,17 +190,17 @@ class LMBase:
  # ------------------------------------------------------------------ #

  @staticmethod
- def list_models(base_url: Optional[str] = None) -> List[str]:
+ def list_models(base_url: str | None = None) -> list[str]:
  try:
  if base_url is None:
- raise ValueError("base_url must be provided")
+ raise ValueError('base_url must be provided')
  client = LMBase(base_url=base_url).client
  base_url_obj: URL = client.base_url
- logger.debug(f"Base URL: {base_url_obj}")
+ logger.debug(f'Base URL: {base_url_obj}')
  models: SyncPage[Model] = client.models.list() # type: ignore[assignment]
  return [model.id for model in models.data]
  except Exception as exc:
- logger.error(f"Failed to list models: {exc}")
+ logger.error(f'Failed to list models: {exc}')
  return []

  def build_system_prompt(
@@ -212,15 +212,15 @@ class LMBase:
  think,
  ):
  if add_json_schema_to_instruction and response_model:
- schema_block = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
+ schema_block = f'\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>'
  # if schema_block not in system_content:
- if "<output_json_schema>" in system_content:
+ if '<output_json_schema>' in system_content:
  # remove exsting schema block
  import re # replace

  system_content = re.sub(
- r"<output_json_schema>.*?</output_json_schema>",
- "",
+ r'<output_json_schema>.*?</output_json_schema>',
+ '',
  system_content,
  flags=re.DOTALL,
  )
@@ -228,36 +228,35 @@ class LMBase:
  system_content += schema_block

  if think is True:
- if "/think" in system_content:
+ if '/think' in system_content:
  pass
- elif "/no_think" in system_content:
- system_content = system_content.replace("/no_think", "/think")
+ elif '/no_think' in system_content:
+ system_content = system_content.replace('/no_think', '/think')
  else:
- system_content += "\n\n/think"
+ system_content += '\n\n/think'
  elif think is False:
- if "/no_think" in system_content:
+ if '/no_think' in system_content:
  pass
- elif "/think" in system_content:
- system_content = system_content.replace("/think", "/no_think")
+ elif '/think' in system_content:
+ system_content = system_content.replace('/think', '/no_think')
  else:
- system_content += "\n\n/no_think"
+ system_content += '\n\n/no_think'
  return system_content

  def inspect_history(self):
  """Inspect the history of the LLM calls."""
- pass
-

- def get_model_name(client: OpenAI|str|int) -> str:
+
+ def get_model_name(client: OpenAI | str | int) -> str:
  """
  Get the first available model name from the client.
-
+
  Args:
  client: OpenAI client, base_url string, or port number
-
+
  Returns:
  Name of the first available model
-
+
  Raises:
  ValueError: If no models are available or client is invalid
  """
@@ -269,17 +268,17 @@ def get_model_name(client: OpenAI|str|int) -> str:
  openai_client = OpenAI(base_url=client, api_key='abc')
  elif isinstance(client, int):
  # Port number
- base_url = f"http://localhost:{client}/v1"
+ base_url = f'http://localhost:{client}/v1'
  openai_client = OpenAI(base_url=base_url, api_key='abc')
  else:
- raise ValueError(f"Unsupported client type: {type(client)}")
-
+ raise ValueError(f'Unsupported client type: {type(client)}')
+
  models = openai_client.models.list()
  if not models.data:
- raise ValueError("No models available")
-
+ raise ValueError('No models available')
+
  return models.data[0].id
-
+
  except Exception as exc:
- logger.error(f"Failed to get model name: {exc}")
- raise ValueError(f"Could not retrieve model name: {exc}") from exc
+ logger.error(f'Failed to get model name: {exc}')
+ raise ValueError(f'Could not retrieve model name: {exc}') from exc
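The client property in the hunks above rotates across multiple local server ports: it picks a random entry from self.ports and splices it into base_url with a regex before constructing the MOpenAI client. A standalone sketch of that selection step; the helper name pick_api_base and the example URL/ports are hypothetical, not code from the package:

import random
import re

def pick_api_base(base_url: str, ports: list[int] | None) -> str:
    # Simplified mirror of the port-rotation logic in LMBase.client.
    if not ports:
        return base_url
    port = random.choice(ports)
    match = re.match(r'(https?://[^:/]+):?\d*(/.*)?', base_url)
    if not match:
        return base_url
    host_part = match.group(1)
    path_part = match.group(2) or '/v1'
    return f'{host_part}:{port}{path_part}'

# pick_api_base('http://localhost:8000/v1', [8000, 8001, 8002])
# -> e.g. 'http://localhost:8001/v1' (one of the listed ports, chosen at random)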
llm_utils/lm/mixins.py CHANGED
@@ -1,14 +1,21 @@
  """Mixin classes for LLM functionality extensions."""

+ # type: ignore
+
+ from __future__ import annotations
+
  import os
  import subprocess
  from time import sleep
- from typing import Any, Dict, List, Optional, Type, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

  import requests
  from loguru import logger
- from openai import OpenAI
- from pydantic import BaseModel
+
+
+ if TYPE_CHECKING:
+ from openai import OpenAI
+ from pydantic import BaseModel


  class TemperatureRangeMixin:
@@ -16,12 +23,12 @@ class TemperatureRangeMixin:
  def temperature_range_sampling(
  self,
- input_data: Union[str, BaseModel, List[Dict]],
+ input_data: 'str | BaseModel | list[dict]',
  temperature_ranges: tuple[float, float],
  n: int = 32,
- response_model: Optional[Type[BaseModel] | Type[str]] = None,
+ response_model: 'type[BaseModel] | type[str] | None' = None,
  **runtime_kwargs,
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
  """
  Sample LLM responses with a range of temperatures.
@@ -38,11 +45,13 @@ class TemperatureRangeMixin:
  Returns:
  List of response dictionaries from all temperature samples
  """
+ from pydantic import BaseModel
+
  from speedy_utils.multi_worker.thread import multi_thread

  min_temp, max_temp = temperature_ranges
  if n < 2:
- raise ValueError(f"n must be >= 2, got {n}")
+ raise ValueError(f'n must be >= 2, got {n}')

  step = (max_temp - min_temp) / (n - 1)
  list_kwargs = []
@@ -56,7 +65,7 @@ class TemperatureRangeMixin:
  list_kwargs.append(kwargs)

  def f(kwargs):
- i = kwargs.pop("i")
+ i = kwargs.pop('i')
  sleep(i * 0.05)
  return self.__inner_call__(
  input_data,
@@ -73,10 +82,10 @@ class TwoStepPydanticMixin:
  def two_step_pydantic_parse(
  self,
- input_data: Union[str, BaseModel, List[Dict]],
- response_model: Type[BaseModel],
+ input_data,
+ response_model,
  **runtime_kwargs,
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
  """
  Parse responses in two steps: text completion then Pydantic parsing.
@@ -91,32 +100,45 @@ class TwoStepPydanticMixin:
  Returns:
  List of parsed response dictionaries
  """
+ from pydantic import BaseModel
+
  # Step 1: Get text completions
  results = self.text_completion(input_data, **runtime_kwargs)
  parsed_results = []

  for result in results:
- response_text = result["parsed"]
- messages = result["messages"]
+ response_text = result['parsed']
+ messages = result['messages']

  # Handle reasoning models that use <think> tags
- if "</think>" in response_text:
- response_text = response_text.split("</think>")[1]
+ if '</think>' in response_text:
+ response_text = response_text.split('</think>')[1]

  try:
- # Try direct parsing
- parsed = response_model.model_validate_json(response_text)
+ # Try direct parsing - support both Pydantic v1 and v2
+ if hasattr(response_model, 'model_validate_json'):
+ # Pydantic v2
+ parsed = response_model.model_validate_json(response_text)
+ else:
+ # Pydantic v1
+ import json
+
+ parsed = response_model.parse_obj(json.loads(response_text))
  except Exception:
  # Fallback: use LLM to extract JSON
- logger.warning("Failed to parse JSON directly, using LLM to extract")
+ logger.warning('Failed to parse JSON directly, using LLM to extract')
  _parsed_messages = [
  {
- "role": "system",
- "content": ("You are a helpful assistant that extracts JSON from text."),
+ 'role': 'system',
+ 'content': (
+ 'You are a helpful assistant that extracts JSON from text.'
+ ),
  },
  {
- "role": "user",
- "content": (f"Extract JSON from the following text:\n{response_text}"),
+ 'role': 'user',
+ 'content': (
+ f'Extract JSON from the following text:\n{response_text}'
+ ),
  },
  ]
  parsed_result = self.pydantic_parse(
@@ -124,9 +146,9 @@ class TwoStepPydanticMixin:
  response_model=response_model,
  **runtime_kwargs,
  )[0]
- parsed = parsed_result["parsed"]
+ parsed = parsed_result['parsed']

- parsed_results.append({"parsed": parsed, "messages": messages})
+ parsed_results.append({'parsed': parsed, 'messages': messages})

  return parsed_results
@@ -153,7 +175,7 @@ class VLLMMixin:
  get_base_client,
  )

- if not hasattr(self, "vllm_cmd") or not self.vllm_cmd:
+ if not hasattr(self, 'vllm_cmd') or not self.vllm_cmd:
  return

  port = _extract_port_from_vllm_cmd(self.vllm_cmd)
@@ -163,26 +185,30 @@ class VLLMMixin:
  try:
  reuse_client = get_base_client(port, cache=False)
  models_response = reuse_client.models.list()
- if getattr(models_response, "data", None):
+ if getattr(models_response, 'data', None):
  reuse_existing = True
  logger.info(
- f"VLLM server already running on port {port}, reusing existing server (vllm_reuse=True)"
+ f'VLLM server already running on port {port}, reusing existing server (vllm_reuse=True)'
  )
  else:
- logger.info(f"No models returned from VLLM server on port {port}; starting a new server")
+ logger.info(
+ f'No models returned from VLLM server on port {port}; starting a new server'
+ )
  except Exception as exc:
  logger.info(
- f"Unable to reach VLLM server on port {port} (list_models failed): {exc}. Starting a new server."
+ f'Unable to reach VLLM server on port {port} (list_models failed): {exc}. Starting a new server.'
  )

  if not self.vllm_reuse:
  if _is_server_running(port):
- logger.info(f"VLLM server already running on port {port}, killing it first (vllm_reuse=False)")
+ logger.info(
+ f'VLLM server already running on port {port}, killing it first (vllm_reuse=False)'
+ )
  _kill_vllm_on_port(port)
- logger.info(f"Starting new VLLM server on port {port}")
+ logger.info(f'Starting new VLLM server on port {port}')
  self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
  elif not reuse_existing:
- logger.info(f"Starting VLLM server on port {port}")
+ logger.info(f'Starting VLLM server on port {port}')
  self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)

  def _load_lora_adapter(self) -> None:
@@ -195,8 +221,8 @@ class VLLMMixin:
  3. Loads the LoRA adapter and updates the model name
  """
  from .utils import (
- _is_lora_path,
  _get_port_from_client,
+ _is_lora_path,
  _load_lora_adapter,
  )
@@ -204,12 +230,14 @@ class VLLMMixin:
  return

  if not _is_lora_path(self.lora_path):
- raise ValueError(f"Invalid LoRA path '{self.lora_path}': Directory must contain 'adapter_config.json'")
+ raise ValueError(
+ f"Invalid LoRA path '{self.lora_path}': Directory must contain 'adapter_config.json'"
+ )

- logger.info(f"Loading LoRA adapter from: {self.lora_path}")
+ logger.info(f'Loading LoRA adapter from: {self.lora_path}')

  # Get the expected LoRA name (basename of the path)
- lora_name = os.path.basename(self.lora_path.rstrip("/\\"))
+ lora_name = os.path.basename(self.lora_path.rstrip('/\\'))
  if not lora_name: # Handle edge case of empty basename
  lora_name = os.path.basename(os.path.dirname(self.lora_path))
@@ -217,13 +245,17 @@ class VLLMMixin:
  try:
  available_models = [m.id for m in self.client.models.list().data]
  except Exception as e:
- logger.warning(f"Failed to list models, proceeding with LoRA load: {str(e)[:100]}")
+ logger.warning(
+ f'Failed to list models, proceeding with LoRA load: {str(e)[:100]}'
+ )
  available_models = []

  # Check if LoRA is already loaded
  if lora_name in available_models and not self.force_lora_unload:
- logger.info(f"LoRA adapter '{lora_name}' is already loaded, using existing model")
- self.model_kwargs["model"] = lora_name
+ logger.info(
+ f"LoRA adapter '{lora_name}' is already loaded, using existing model"
+ )
+ self.model_kwargs['model'] = lora_name
  return

  # Force unload if requested
@@ -233,43 +265,49 @@ class VLLMMixin:
  if port is not None:
  try:
  VLLMMixin.unload_lora(port, lora_name)
- logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+ logger.info(f'Successfully unloaded LoRA adapter: {lora_name}')
  except Exception as e:
- logger.warning(f"Failed to unload LoRA adapter: {str(e)[:100]}")
+ logger.warning(f'Failed to unload LoRA adapter: {str(e)[:100]}')

  # Get port from client for API calls
  port = _get_port_from_client(self.client)
  if port is None:
  raise ValueError(
  f"Cannot load LoRA adapter '{self.lora_path}': "
- f"Unable to determine port from client base_url. "
- f"LoRA loading requires a client initialized with port."
+ f'Unable to determine port from client base_url. '
+ f'LoRA loading requires a client initialized with port.'
  )

  try:
  # Load the LoRA adapter
  loaded_lora_name = _load_lora_adapter(self.lora_path, port)
- logger.info(f"Successfully loaded LoRA adapter: {loaded_lora_name}")
+ logger.info(f'Successfully loaded LoRA adapter: {loaded_lora_name}')

  # Update model name to the loaded LoRA name
- self.model_kwargs["model"] = loaded_lora_name
+ self.model_kwargs['model'] = loaded_lora_name

  except requests.RequestException as e:
  # Check if error is due to LoRA already being loaded
  error_msg = str(e)
- if "400" in error_msg or "Bad Request" in error_msg:
- logger.info(f"LoRA adapter may already be loaded, attempting to use '{lora_name}'")
+ if '400' in error_msg or 'Bad Request' in error_msg:
+ logger.info(
+ f"LoRA adapter may already be loaded, attempting to use '{lora_name}'"
+ )
  # Refresh the model list to check if it's now available
  try:
  updated_models = [m.id for m in self.client.models.list().data]
  if lora_name in updated_models:
- logger.info(f"Found LoRA adapter '{lora_name}' in updated model list")
- self.model_kwargs["model"] = lora_name
+ logger.info(
+ f"Found LoRA adapter '{lora_name}' in updated model list"
+ )
+ self.model_kwargs['model'] = lora_name
  return
  except Exception:
  pass # Fall through to original error

- raise ValueError(f"Failed to load LoRA adapter from '{self.lora_path}': {error_msg[:100]}")
+ raise ValueError(
+ f"Failed to load LoRA adapter from '{self.lora_path}': {error_msg[:100]}"
+ ) from e

  def unload_lora_adapter(self, lora_path: str) -> None:
  """
@@ -286,14 +324,14 @@ class VLLMMixin:
  port = _get_port_from_client(self.client)
  if port is None:
  raise ValueError(
- "Cannot unload LoRA adapter: "
- "Unable to determine port from client base_url. "
- "LoRA operations require a client initialized with port."
+ 'Cannot unload LoRA adapter: '
+ 'Unable to determine port from client base_url. '
+ 'LoRA operations require a client initialized with port.'
  )

  _unload_lora_adapter(lora_path, port)
- lora_name = os.path.basename(lora_path.rstrip("/\\"))
- logger.info(f"Unloaded LoRA adapter: {lora_name}")
+ lora_name = os.path.basename(lora_path.rstrip('/\\'))
+ logger.info(f'Unloaded LoRA adapter: {lora_name}')

  @staticmethod
  def unload_lora(port: int, lora_name: str) -> None:
@@ -309,15 +347,15 @@ class VLLMMixin:
  """
  try:
  response = requests.post(
- f"http://localhost:{port}/v1/unload_lora_adapter",
+ f'http://localhost:{port}/v1/unload_lora_adapter',
  headers={
- "accept": "application/json",
- "Content-Type": "application/json",
+ 'accept': 'application/json',
+ 'Content-Type': 'application/json',
  },
- json={"lora_name": lora_name, "lora_int_id": 0},
+ json={'lora_name': lora_name, 'lora_int_id': 0},
  )
  response.raise_for_status()
- logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+ logger.info(f'Successfully unloaded LoRA adapter: {lora_name}')
  except requests.RequestException as e:
  logger.error(f"Error unloading LoRA adapter '{lora_name}': {str(e)[:100]}")
  raise
@@ -326,7 +364,7 @@ class VLLMMixin:
  """Stop the VLLM server process if started by this instance."""
  from .utils import stop_vllm_process

- if hasattr(self, "vllm_process") and self.vllm_process is not None:
+ if hasattr(self, 'vllm_process') and self.vllm_process is not None:
  stop_vllm_process(self.vllm_process)
  self.vllm_process = None
@@ -362,7 +400,7 @@ class ModelUtilsMixin:
  """Mixin for model utility methods."""

  @staticmethod
- def list_models(client: Union[OpenAI, int, str, None] = None) -> List[str]:
+ def list_models(client=None) -> list[str]:
  """
  List available models from the OpenAI client.
@@ -372,6 +410,8 @@ class ModelUtilsMixin:
  Returns:
  List of available model names
  """
+ from openai import OpenAI
+
  from .utils import get_base_client

  client_instance = get_base_client(client, cache=False)
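A recurring pattern in the mixins.py changes is deferring heavyweight imports: openai and pydantic move under if TYPE_CHECKING: (with from __future__ import annotations), and the methods that actually need them import locally at call time. A minimal sketch of the pattern, assuming a hypothetical ExampleMixin and first_model method that are not part of the package:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; keeps module import cheap at runtime.
    from openai import OpenAI


class ExampleMixin:
    def first_model(self, client: OpenAI | int) -> str:
        # Heavy import deferred to the call site, as the mixins above now do.
        from openai import OpenAI

        if isinstance(client, int):  # treat an int as a local server port
            client = OpenAI(base_url=f'http://localhost:{client}/v1', api_key='abc')
        return client.models.list().data[0].id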