speedy-utils 1.1.23__py3-none-any.whl → 1.1.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm/mixins.py ADDED
@@ -0,0 +1,379 @@
+"""Mixin classes for LLM functionality extensions."""
+
+import os
+import subprocess
+from time import sleep
+from typing import Any, Dict, List, Optional, Type, Union
+
+import requests
+from loguru import logger
+from openai import OpenAI
+from pydantic import BaseModel
+
+
+class TemperatureRangeMixin:
+    """Mixin for sampling with different temperature ranges."""
+
+    def temperature_range_sampling(
+        self,
+        input_data: Union[str, BaseModel, List[Dict]],
+        temperature_ranges: tuple[float, float],
+        n: int = 32,
+        response_model: Optional[Type[BaseModel] | Type[str]] = None,
+        **runtime_kwargs,
+    ) -> List[Dict[str, Any]]:
+        """
+        Sample LLM responses with a range of temperatures.
+
+        This method generates multiple responses by systematically varying
+        the temperature parameter, which controls randomness in the output.
+
+        Args:
+            input_data: Input data (string, BaseModel, or message list)
+            temperature_ranges: Tuple of (min_temp, max_temp) to sample
+            n: Number of temperature samples to generate (must be >= 2)
+            response_model: Optional response model override
+            **runtime_kwargs: Additional runtime parameters
+
+        Returns:
+            List of response dictionaries from all temperature samples
+        """
+        from speedy_utils.multi_worker.thread import multi_thread
+
+        min_temp, max_temp = temperature_ranges
+        if n < 2:
+            raise ValueError(f"n must be >= 2, got {n}")
+
+        step = (max_temp - min_temp) / (n - 1)
+        list_kwargs = []
+
+        for i in range(n):
+            kwargs = dict(
+                temperature=min_temp + i * step,
+                i=i,
+                **runtime_kwargs,
+            )
+            list_kwargs.append(kwargs)
+
+        def f(kwargs):
+            i = kwargs.pop("i")
+            sleep(i * 0.05)
+            return self.__inner_call__(
+                input_data,
+                response_model=response_model,
+                **kwargs,
+            )[0]
+
+        choices = multi_thread(f, list_kwargs, progress=False)
+        return [c for c in choices if c is not None]
+
+
+class TwoStepPydanticMixin:
+    """Mixin for two-step Pydantic parsing functionality."""
+
+    def two_step_pydantic_parse(
+        self,
+        input_data: Union[str, BaseModel, List[Dict]],
+        response_model: Type[BaseModel],
+        **runtime_kwargs,
+    ) -> List[Dict[str, Any]]:
+        """
+        Parse responses in two steps: text completion then Pydantic parsing.
+
+        This is useful for models that may include reasoning or extra text
+        before the JSON output.
+
+        Args:
+            input_data: Input data (string, BaseModel, or message list)
+            response_model: Pydantic model to parse into
+            **runtime_kwargs: Additional runtime parameters
+
+        Returns:
+            List of parsed response dictionaries
+        """
+        # Step 1: Get text completions
+        results = self.text_completion(input_data, **runtime_kwargs)
+        parsed_results = []
+
+        for result in results:
+            response_text = result["parsed"]
+            messages = result["messages"]
+
+            # Handle reasoning models that use <think> tags
+            if "</think>" in response_text:
+                response_text = response_text.split("</think>")[1]
+
+            try:
+                # Try direct parsing
+                parsed = response_model.model_validate_json(response_text)
+            except Exception:
+                # Fallback: use LLM to extract JSON
+                logger.warning("Failed to parse JSON directly, using LLM to extract")
+                _parsed_messages = [
+                    {
+                        "role": "system",
+                        "content": ("You are a helpful assistant that extracts JSON from text."),
+                    },
+                    {
+                        "role": "user",
+                        "content": (f"Extract JSON from the following text:\n{response_text}"),
+                    },
+                ]
+                parsed_result = self.pydantic_parse(
+                    _parsed_messages,
+                    response_model=response_model,
+                    **runtime_kwargs,
+                )[0]
+                parsed = parsed_result["parsed"]
+
+            parsed_results.append({"parsed": parsed, "messages": messages})
+
+        return parsed_results
+
+
+class VLLMMixin:
+    """Mixin for VLLM server management and LoRA operations."""
+
+    def _setup_vllm_server(self) -> None:
+        """
+        Setup VLLM server if vllm_cmd is provided.
+
+        This method handles:
+        - Server reuse logic
+        - Starting new servers
+        - Port management
+
+        Should be called from __init__.
+        """
+        from .utils import (
+            _extract_port_from_vllm_cmd,
+            _is_server_running,
+            _kill_vllm_on_port,
+            _start_vllm_server,
+            get_base_client,
+        )
+
+        if not hasattr(self, "vllm_cmd") or not self.vllm_cmd:
+            return
+
+        port = _extract_port_from_vllm_cmd(self.vllm_cmd)
+        reuse_existing = False
+
+        if self.vllm_reuse:
+            try:
+                reuse_client = get_base_client(port, cache=False)
+                models_response = reuse_client.models.list()
+                if getattr(models_response, "data", None):
+                    reuse_existing = True
+                    logger.info(
+                        f"VLLM server already running on port {port}, reusing existing server (vllm_reuse=True)"
+                    )
+                else:
+                    logger.info(f"No models returned from VLLM server on port {port}; starting a new server")
+            except Exception as exc:
+                logger.info(
+                    f"Unable to reach VLLM server on port {port} (list_models failed): {exc}. Starting a new server."
+                )
+
+        if not self.vllm_reuse:
+            if _is_server_running(port):
+                logger.info(f"VLLM server already running on port {port}, killing it first (vllm_reuse=False)")
+                _kill_vllm_on_port(port)
+            logger.info(f"Starting new VLLM server on port {port}")
+            self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
+        elif not reuse_existing:
+            logger.info(f"Starting VLLM server on port {port}")
+            self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
+
+    def _load_lora_adapter(self) -> None:
+        """
+        Load LoRA adapter from the specified lora_path.
+
+        This method:
+        1. Validates that lora_path is a valid LoRA directory
+        2. Checks if LoRA is already loaded (unless force_lora_unload)
+        3. Loads the LoRA adapter and updates the model name
+        """
+        from .utils import (
+            _is_lora_path,
+            _get_port_from_client,
+            _load_lora_adapter,
+        )
+
+        if not self.lora_path:
+            return
+
+        if not _is_lora_path(self.lora_path):
+            raise ValueError(f"Invalid LoRA path '{self.lora_path}': Directory must contain 'adapter_config.json'")
+
+        logger.info(f"Loading LoRA adapter from: {self.lora_path}")
+
+        # Get the expected LoRA name (basename of the path)
+        lora_name = os.path.basename(self.lora_path.rstrip("/\\"))
+        if not lora_name: # Handle edge case of empty basename
+            lora_name = os.path.basename(os.path.dirname(self.lora_path))
+
+        # Get list of available models to check if LoRA is already loaded
+        try:
+            available_models = [m.id for m in self.client.models.list().data]
+        except Exception as e:
+            logger.warning(f"Failed to list models, proceeding with LoRA load: {str(e)[:100]}")
+            available_models = []
+
+        # Check if LoRA is already loaded
+        if lora_name in available_models and not self.force_lora_unload:
+            logger.info(f"LoRA adapter '{lora_name}' is already loaded, using existing model")
+            self.model_kwargs["model"] = lora_name
+            return
+
+        # Force unload if requested
+        if self.force_lora_unload and lora_name in available_models:
+            logger.info(f"Force unloading LoRA adapter '{lora_name}' before reloading")
+            port = _get_port_from_client(self.client)
+            if port is not None:
+                try:
+                    VLLMMixin.unload_lora(port, lora_name)
+                    logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+                except Exception as e:
+                    logger.warning(f"Failed to unload LoRA adapter: {str(e)[:100]}")
+
+        # Get port from client for API calls
+        port = _get_port_from_client(self.client)
+        if port is None:
+            raise ValueError(
+                f"Cannot load LoRA adapter '{self.lora_path}': "
+                f"Unable to determine port from client base_url. "
+                f"LoRA loading requires a client initialized with port."
+            )
+
+        try:
+            # Load the LoRA adapter
+            loaded_lora_name = _load_lora_adapter(self.lora_path, port)
+            logger.info(f"Successfully loaded LoRA adapter: {loaded_lora_name}")
+
+            # Update model name to the loaded LoRA name
+            self.model_kwargs["model"] = loaded_lora_name
+
+        except requests.RequestException as e:
+            # Check if error is due to LoRA already being loaded
+            error_msg = str(e)
+            if "400" in error_msg or "Bad Request" in error_msg:
+                logger.info(f"LoRA adapter may already be loaded, attempting to use '{lora_name}'")
+                # Refresh the model list to check if it's now available
+                try:
+                    updated_models = [m.id for m in self.client.models.list().data]
+                    if lora_name in updated_models:
+                        logger.info(f"Found LoRA adapter '{lora_name}' in updated model list")
+                        self.model_kwargs["model"] = lora_name
+                        return
+                except Exception:
+                    pass # Fall through to original error
+
+            raise ValueError(f"Failed to load LoRA adapter from '{self.lora_path}': {error_msg[:100]}")
+
+    def unload_lora_adapter(self, lora_path: str) -> None:
+        """
+        Unload a LoRA adapter.
+
+        Args:
+            lora_path: Path to the LoRA adapter directory to unload
+
+        Raises:
+            ValueError: If unable to determine port from client
+        """
+        from .utils import _get_port_from_client, _unload_lora_adapter
+
+        port = _get_port_from_client(self.client)
+        if port is None:
+            raise ValueError(
+                "Cannot unload LoRA adapter: "
+                "Unable to determine port from client base_url. "
+                "LoRA operations require a client initialized with port."
+            )
+
+        _unload_lora_adapter(lora_path, port)
+        lora_name = os.path.basename(lora_path.rstrip("/\\"))
+        logger.info(f"Unloaded LoRA adapter: {lora_name}")
+
+    @staticmethod
+    def unload_lora(port: int, lora_name: str) -> None:
+        """
+        Static method to unload a LoRA adapter by name.
+
+        Args:
+            port: Port number for the API endpoint
+            lora_name: Name of the LoRA adapter to unload
+
+        Raises:
+            requests.RequestException: If the API call fails
+        """
+        try:
+            response = requests.post(
+                f"http://localhost:{port}/v1/unload_lora_adapter",
+                headers={
+                    "accept": "application/json",
+                    "Content-Type": "application/json",
+                },
+                json={"lora_name": lora_name, "lora_int_id": 0},
+            )
+            response.raise_for_status()
+            logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+        except requests.RequestException as e:
+            logger.error(f"Error unloading LoRA adapter '{lora_name}': {str(e)[:100]}")
+            raise
+
+    def cleanup_vllm_server(self) -> None:
+        """Stop the VLLM server process if started by this instance."""
+        from .utils import stop_vllm_process
+
+        if hasattr(self, "vllm_process") and self.vllm_process is not None:
+            stop_vllm_process(self.vllm_process)
+            self.vllm_process = None
+
+    @staticmethod
+    def kill_all_vllm() -> int:
+        """
+        Kill all tracked VLLM server processes.
+
+        Returns:
+            Number of processes killed
+        """
+        from .utils import kill_all_vllm_processes
+
+        return kill_all_vllm_processes()
+
+    @staticmethod
+    def kill_vllm_on_port(port: int) -> bool:
+        """
+        Kill VLLM server running on a specific port.
+
+        Args:
+            port: Port number to kill server on
+
+        Returns:
+            True if a server was killed, False if no server was running
+        """
+        from .utils import _kill_vllm_on_port
+
+        return _kill_vllm_on_port(port)
+
+
+class ModelUtilsMixin:
+    """Mixin for model utility methods."""
+
+    @staticmethod
+    def list_models(client: Union[OpenAI, int, str, None] = None) -> List[str]:
+        """
+        List available models from the OpenAI client.
+
+        Args:
+            client: OpenAI client, port number, or base_url string
+
+        Returns:
+            List of available model names
+        """
+        from .utils import get_base_client
+
+        client_instance = get_base_client(client, cache=False)
+        models = client_instance.models.list().data
+        return [m.id for m in models]
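A brief usage sketch for the mixins added above. The mixin and method names come from this diff; the composing `LM` class, its import path, its constructor arguments, and the example schema are assumptions for illustration only:

    # Hypothetical sketch -- `LM` stands in for whichever llm_utils class mixes these in.
    from pydantic import BaseModel
    from llm_utils import LM  # assumed import path

    class Answer(BaseModel):
        verdict: str

    lm = LM(model="gpt-4o-mini")  # assumed constructor arguments

    # TemperatureRangeMixin: sweep temperatures evenly between 0.2 and 1.0 over 8 calls
    choices = lm.temperature_range_sampling(
        "Is the sky blue?", temperature_ranges=(0.2, 1.0), n=8
    )

    # TwoStepPydanticMixin: free-form completion first, then parse into Answer,
    # with an LLM-assisted JSON extraction as fallback if direct parsing fails
    results = lm.two_step_pydantic_parse("Is the sky blue?", response_model=Answer)

    # ModelUtilsMixin: list models served by a local endpoint (port 8000 is an assumed example)
    print(LM.list_models(8000))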
@@ -42,13 +42,16 @@ class MOpenAI(OpenAI):
 
     def __init__(self, *args, cache=True, **kwargs):
         super().__init__(*args, **kwargs)
+        self._orig_post = self.post
         if cache:
-            # Create a memoized wrapper for the instance's post method.
-            # The memoize decorator now preserves exact type information,
-            # so no casting is needed.
-            orig_post = self.post
-            memoized = memoize(orig_post)
-            self.post = memoized
+            self.set_cache(cache)
+
+    def set_cache(self, cache: bool) -> None:
+        """Enable or disable caching of the post method."""
+        if cache and self.post == self._orig_post:
+            self.post = memoize(self._orig_post) # type: ignore
+        elif not cache and self.post != self._orig_post:
+            self.post = self._orig_post
 
 
 class MAsyncOpenAI(AsyncOpenAI):
@@ -76,5 +79,13 @@ class MAsyncOpenAI(AsyncOpenAI):
 
     def __init__(self, *args, cache=True, **kwargs):
         super().__init__(*args, **kwargs)
+        self._orig_post = self.post
         if cache:
-            self.post = memoize(self.post) # type: ignore
+            self.set_cache(cache)
+
+    def set_cache(self, cache: bool) -> None:
+        """Enable or disable caching of the post method."""
+        if cache and self.post == self._orig_post:
+            self.post = memoize(self._orig_post) # type: ignore
+        elif not cache and self.post != self._orig_post:
+            self.post = self._orig_post
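The `set_cache` additions to `MOpenAI` and `MAsyncOpenAI` make response caching toggleable after construction: the original `post` is kept in `_orig_post`, and `set_cache` swaps the memoized wrapper in or out, doing nothing if the requested state is already active. A minimal sketch of the intended toggle, assuming a locally served endpoint (the base URL and API key below are placeholders, not taken from this diff):

    client = MOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY", cache=True)
    # ... requests go through the memoized post() and can be served from cache ...
    client.set_cache(False)  # restore the original, uncached post()
    client.set_cache(True)   # re-wrap with memoize(); no-op if already wrapped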
llm_utils/lm/signature.py CHANGED
@@ -5,43 +5,38 @@ This module provides a declarative way to define LLM input/output schemas
 with field descriptions and type annotations.
 """
 
-from typing import Any, Dict, List, Type, Union, get_type_hints, Annotated, get_origin, get_args
+from typing import Any, Dict, List, Type, get_type_hints, Annotated, get_origin, get_args, cast
 from pydantic import BaseModel, Field
 import inspect
 
 
-class InputField:
-    """Represents an input field in a signature."""
+class _FieldProxy:
+    """Proxy that stores field information while appearing type-compatible."""
 
-    def __init__(self, desc: str = "", **kwargs):
+    def __init__(self, field_type: str, desc: str = "", **kwargs):
+        self.field_type = field_type # 'input' or 'output'
         self.desc = desc
         self.kwargs = kwargs
-
-    def __class_getitem__(cls, item):
-        """Support for InputField[type] syntax."""
-        return item
 
 
-class OutputField:
-    """Represents an output field in a signature."""
-
-    def __init__(self, desc: str = "", **kwargs):
-        self.desc = desc
-        self.kwargs = kwargs
-
-    def __class_getitem__(cls, item):
-        """Support for OutputField[type] syntax."""
-        return item
+def InputField(desc: str = "", **kwargs) -> Any:
+    """Create an input field descriptor."""
+    return cast(Any, _FieldProxy('input', desc=desc, **kwargs))
+
+
+def OutputField(desc: str = "", **kwargs) -> Any:
+    """Create an output field descriptor."""
+    return cast(Any, _FieldProxy('output', desc=desc, **kwargs))
 
 
 # Type aliases for cleaner syntax
 def Input(desc: str = "", **kwargs) -> Any:
-    """Create an input field descriptor."""
+    """Create an input field descriptor that's compatible with type annotations."""
     return InputField(desc=desc, **kwargs)
 
 
 def Output(desc: str = "", **kwargs) -> Any:
-    """Create an output field descriptor."""
+    """Create an output field descriptor that's compatible with type annotations."""
     return OutputField(desc=desc, **kwargs)
 
 
@@ -67,24 +62,24 @@ class SignatureMeta(type):
                 if args:
                     # First arg is the actual type
                     field_type = args[0]
-                    # Look for InputField or OutputField in the metadata
+                    # Look for _FieldProxy in the metadata
                     for metadata in args[1:]:
-                        if isinstance(metadata, (InputField, OutputField)):
+                        if isinstance(metadata, _FieldProxy):
                             field_desc = metadata
                             break
 
             # Handle old syntax with direct assignment
-            if field_desc is None and isinstance(field_value, (InputField, OutputField)):
+            if field_desc is None and isinstance(field_value, _FieldProxy):
                 field_desc = field_value
 
             # Store field information
-            if isinstance(field_desc, InputField):
+            if field_desc and field_desc.field_type == 'input':
                 input_fields[field_name] = {
                     'type': field_type,
                     'desc': field_desc.desc,
                     **field_desc.kwargs
                 }
-            elif isinstance(field_desc, OutputField):
+            elif field_desc and field_desc.field_type == 'output':
                 output_fields[field_name] = {
                     'type': field_type,
                     'desc': field_desc.desc,
@@ -136,10 +131,10 @@ class Signature(metaclass=SignatureMeta):
         return instruction
 
     @classmethod
-    def get_input_model(cls) -> Union[Type[BaseModel], type[str]]:
+    def get_input_model(cls) -> Type[BaseModel]:
         """Generate Pydantic input model from input fields."""
         if not cls._input_fields:
-            return str
+            raise ValueError(f"Signature {cls.__name__} must have at least one input field")
 
         fields = {}
         annotations = {}
@@ -170,10 +165,10 @@ class Signature(metaclass=SignatureMeta):
         return input_model
 
     @classmethod
-    def get_output_model(cls) -> Union[Type[BaseModel], type[str]]:
+    def get_output_model(cls) -> Type[BaseModel]:
         """Generate Pydantic output model from output fields."""
         if not cls._output_fields:
-            return str
+            raise ValueError(f"Signature {cls.__name__} must have at least one output field")
 
         fields = {}
         annotations = {}
@@ -259,17 +254,11 @@ if __name__ == "__main__":
 
     print("\nInput Model:")
     input_model = judge_class.get_input_model()
-    if input_model is not str and hasattr(input_model, 'model_json_schema'):
-        print(input_model.model_json_schema()) # type: ignore
-    else:
-        print("String input model")
+    print(input_model.model_json_schema())
 
     print("\nOutput Model:")
     output_model = judge_class.get_output_model()
-    if output_model is not str and hasattr(output_model, 'model_json_schema'):
-        print(output_model.model_json_schema()) # type: ignore
-    else:
-        print("String output model")
+    print(output_model.model_json_schema())
 
     # Test instance usage
     judge = judge_class()
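With these changes, `InputField` and `OutputField` are plain factory functions returning a `_FieldProxy`, intended to sit inside `Annotated[...]` without upsetting type checkers, and `get_input_model`/`get_output_model` now raise `ValueError` instead of silently falling back to `str` when a signature side has no fields. A small sketch of a signature under the new API (the class name, field names, and docstring are illustrative, not taken from this diff):

    from typing import Annotated

    class JudgeAnswer(Signature):
        """Judge whether a proposed answer is correct."""  # illustrative instruction
        question: Annotated[str, InputField(desc="Question being judged")]
        answer: Annotated[str, InputField(desc="Proposed answer")]
        is_correct: Annotated[bool, OutputField(desc="True if the answer is correct")]

    # The metaclass collects the _FieldProxy metadata into input/output fields,
    # so a Pydantic model (and its JSON schema) can be generated for each side.
    print(JudgeAnswer.get_output_model().model_json_schema())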