speedy-utils 1.1.18__py3-none-any.whl → 1.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +3 -2
- llm_utils/lm/async_lm/async_llm_task.py +1 -0
- llm_utils/lm/llm_task.py +303 -10
- llm_utils/lm/openai_memoize.py +10 -2
- llm_utils/vector_cache/core.py +250 -234
- speedy_utils/__init__.py +2 -1
- speedy_utils/common/utils_cache.py +38 -19
- speedy_utils/common/utils_io.py +9 -5
- speedy_utils/multi_worker/process.py +91 -10
- speedy_utils/multi_worker/thread.py +94 -2
- {speedy_utils-1.1.18.dist-info → speedy_utils-1.1.20.dist-info}/METADATA +34 -13
- {speedy_utils-1.1.18.dist-info → speedy_utils-1.1.20.dist-info}/RECORD +19 -19
- {speedy_utils-1.1.18.dist-info → speedy_utils-1.1.20.dist-info}/WHEEL +1 -1
- speedy_utils-1.1.20.dist-info/entry_points.txt +5 -0
- speedy_utils-1.1.18.dist-info/entry_points.txt +0 -6
llm_utils/__init__.py
CHANGED

@@ -4,7 +4,7 @@ from llm_utils.vector_cache import VectorCache
 from llm_utils.lm.lm_base import get_model_name
 from llm_utils.lm.base_prompt_builder import BasePromptBuilder
 
-
+LLM = LLMTask
 
 from .chat_format import (
     build_chatml_input,
@@ -34,5 +34,6 @@ __all__ = [
     "MOpenAI",
     "get_model_name",
     "VectorCache",
-    "BasePromptBuilder"
+    "BasePromptBuilder",
+    "LLM"
 ]
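The net change in this module is a convenience alias: LLM now points at LLMTask and is exported. A minimal sketch of the new spelling, with a placeholder port-number client (running it requires a reachable OpenAI-compatible server, since the constructor now verifies the connection):

from llm_utils import LLM, LLMTask

# LLM is just an alias; both names refer to the same class.
assert LLM is LLMTask

# Placeholder client: LLMTask accepts an OpenAI client, a port number, or a base_url string.
task = LLM(client=8000)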
llm_utils/lm/llm_task.py
CHANGED

@@ -4,10 +4,12 @@
 Simplified LLM Task module for handling language model interactions with structured input/output.
 """
 
+import os
 from typing import Any, Dict, List, Optional, Type, Union, cast
 
+import requests
 from loguru import logger
-from openai import OpenAI
+from openai import OpenAI, AuthenticationError, BadRequestError, RateLimitError
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
@@ -38,6 +40,90 @@ def get_base_client(
     )
 
 
+def _is_lora_path(path: str) -> bool:
+    """Check if the given path is a LoRA adapter directory.
+
+    Args:
+        path: Path to check
+
+    Returns:
+        True if the path contains adapter_config.json, False otherwise
+    """
+    if not os.path.isdir(path):
+        return False
+    adapter_config_path = os.path.join(path, 'adapter_config.json')
+    return os.path.isfile(adapter_config_path)
+
+
+def _get_port_from_client(client: OpenAI) -> Optional[int]:
+    """Extract port number from OpenAI client base_url.
+
+    Args:
+        client: OpenAI client instance
+
+    Returns:
+        Port number if found, None otherwise
+    """
+    if hasattr(client, 'base_url') and client.base_url:
+        base_url = str(client.base_url)
+        if 'localhost:' in base_url:
+            try:
+                # Extract port from localhost:PORT/v1 format
+                port_part = base_url.split('localhost:')[1].split('/')[0]
+                return int(port_part)
+            except (IndexError, ValueError):
+                pass
+    return None
+
+
+def _load_lora_adapter(lora_path: str, port: int) -> str:
+    """Load a LoRA adapter from the specified path.
+
+    Args:
+        lora_path: Path to the LoRA adapter directory
+        port: Port number for the API endpoint
+
+    Returns:
+        Name of the loaded LoRA adapter
+
+    Raises:
+        requests.RequestException: If the API call fails
+    """
+    lora_name = os.path.basename(lora_path.rstrip('/\\'))
+    if not lora_name:  # Handle edge case of empty basename
+        lora_name = os.path.basename(os.path.dirname(lora_path))
+
+    response = requests.post(
+        f'http://localhost:{port}/v1/load_lora_adapter',
+        headers={'accept': 'application/json', 'Content-Type': 'application/json'},
+        json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)}
+    )
+    response.raise_for_status()
+    return lora_name
+
+
+def _unload_lora_adapter(lora_path: str, port: int) -> None:
+    """Unload the current LoRA adapter.
+
+    Args:
+        lora_path: Path to the LoRA adapter directory
+        port: Port number for the API endpoint
+    """
+    try:
+        lora_name = os.path.basename(lora_path.rstrip('/\\'))
+        if not lora_name:  # Handle edge case of empty basename
+            lora_name = os.path.basename(os.path.dirname(lora_path))
+
+        response = requests.post(
+            f'http://localhost:{port}/v1/unload_lora_adapter',
+            headers={'accept': 'application/json', 'Content-Type': 'application/json'},
+            json={"lora_name": lora_name, "lora_int_id": 0}
+        )
+        response.raise_for_status()
+    except requests.RequestException as e:
+        logger.warning(f"Error unloading LoRA adapter: {str(e)[:100]}")
+
+
 class LLMTask:
     """
     Language model task with structured input/output and optional system instruction.
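These helpers assume an OpenAI-compatible server that exposes dynamic LoRA endpoints on localhost (vLLM provides /v1/load_lora_adapter and /v1/unload_lora_adapter when runtime LoRA updating is enabled). A minimal sketch of driving those endpoints directly, mirroring the payloads above; the port and adapter path are placeholders:

import os
import requests

PORT = 8000                    # placeholder: port of the local OpenAI-compatible server
LORA_PATH = "/tmp/my-adapter"  # placeholder: directory containing adapter_config.json
lora_name = os.path.basename(LORA_PATH.rstrip("/"))

# Register the adapter so it becomes selectable as a model name.
resp = requests.post(
    f"http://localhost:{PORT}/v1/load_lora_adapter",
    json={"lora_name": lora_name, "lora_path": os.path.abspath(LORA_PATH)},
)
resp.raise_for_status()

# ...use the adapter via the normal chat API with model=lora_name...

# Remove it again when done.
requests.post(
    f"http://localhost:{PORT}/v1/unload_lora_adapter",
    json={"lora_name": lora_name, "lora_int_id": 0},
).raise_for_status()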
@@ -106,6 +192,9 @@ class LLMTask:
         output_model: Type[BaseModel] | Type[str] = None,
         client: Union[OpenAI, int, str, None] = None,
         cache=True,
+        is_reasoning_model: bool = False,
+        force_lora_unload: bool = False,
+        lora_path: Optional[str] = None,
         **model_kwargs,
     ):
         """
@@ -117,6 +206,12 @@ class LLMTask:
             output_model: Output BaseModel type
             client: OpenAI client, port number, or base_url string
             cache: Whether to use cached responses (default True)
+            is_reasoning_model: Whether the model is a reasoning model (o1-preview, o1-mini, etc.)
+                that outputs reasoning_content separately from content (default False)
+            force_lora_unload: If True, forces unloading of any existing LoRA adapter before loading
+                a new one when lora_path is provided (default False)
+            lora_path: Optional path to LoRA adapter directory. If provided, will load the LoRA
+                and use it as the model. Takes precedence over model parameter.
             **model_kwargs: Additional model parameters including:
                 - temperature: Controls randomness (0.0 to 2.0)
                 - n: Number of responses to generate (when n > 1, returns list)
@@ -127,6 +222,10 @@ class LLMTask:
         self.input_model = input_model
         self.output_model = output_model
         self.model_kwargs = model_kwargs
+        self.is_reasoning_model = is_reasoning_model
+        self.force_lora_unload = force_lora_unload
+        self.lora_path = lora_path
+        self.last_ai_response = None  # Store raw response from client
 
         # if cache:
         #     print("Caching is enabled will use llm_utils.MOpenAI")
@@ -135,11 +234,152 @@ class LLMTask:
         # else:
         #     self.client = OpenAI(base_url=base_url, api_key=api_key)
         self.client = get_base_client(client, cache=cache)
+        # check connection of client
+        try:
+            self.client.models.list()
+        except Exception as e:
+            logger.error(f"Failed to connect to OpenAI client: {str(e)}, base_url={self.client.base_url}")
+            raise e
 
         if not self.model_kwargs.get("model", ""):
             self.model_kwargs["model"] = self.client.models.list().data[0].id
+
+        # Handle LoRA loading if lora_path is provided
+        if self.lora_path:
+            self._load_lora_adapter()
+
         print(self.model_kwargs)
 
+    def _load_lora_adapter(self) -> None:
+        """
+        Load LoRA adapter from the specified lora_path.
+
+        This method:
+        1. Validates that lora_path is a valid LoRA directory
+        2. Checks if LoRA is already loaded (unless force_lora_unload is True)
+        3. Loads the LoRA adapter and updates the model name
+        """
+        if not self.lora_path:
+            return
+
+        if not _is_lora_path(self.lora_path):
+            raise ValueError(
+                f"Invalid LoRA path '{self.lora_path}': "
+                "Directory must contain 'adapter_config.json'"
+            )
+
+        logger.info(f"Loading LoRA adapter from: {self.lora_path}")
+
+        # Get the expected LoRA name (basename of the path)
+        lora_name = os.path.basename(self.lora_path.rstrip('/\\'))
+        if not lora_name:  # Handle edge case of empty basename
+            lora_name = os.path.basename(os.path.dirname(self.lora_path))
+
+        # Get list of available models to check if LoRA is already loaded
+        try:
+            available_models = [m.id for m in self.client.models.list().data]
+        except Exception as e:
+            logger.warning(f"Failed to list models, proceeding with LoRA load: {str(e)[:100]}")
+            available_models = []
+
+        # Check if LoRA is already loaded
+        if lora_name in available_models and not self.force_lora_unload:
+            logger.info(f"LoRA adapter '{lora_name}' is already loaded, using existing model")
+            self.model_kwargs["model"] = lora_name
+            return
+
+        # Force unload if requested
+        if self.force_lora_unload and lora_name in available_models:
+            logger.info(f"Force unloading LoRA adapter '{lora_name}' before reloading")
+            port = _get_port_from_client(self.client)
+            if port is not None:
+                try:
+                    LLMTask.unload_lora(port, lora_name)
+                    logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+                except Exception as e:
+                    logger.warning(f"Failed to unload LoRA adapter: {str(e)[:100]}")
+
+        # Get port from client for API calls
+        port = _get_port_from_client(self.client)
+        if port is None:
+            raise ValueError(
+                f"Cannot load LoRA adapter '{self.lora_path}': "
+                "Unable to determine port from client base_url. "
+                "LoRA loading requires a client initialized with port number."
+            )
+
+        try:
+            # Load the LoRA adapter
+            loaded_lora_name = _load_lora_adapter(self.lora_path, port)
+            logger.info(f"Successfully loaded LoRA adapter: {loaded_lora_name}")
+
+            # Update model name to the loaded LoRA name
+            self.model_kwargs["model"] = loaded_lora_name
+
+        except requests.RequestException as e:
+            # Check if the error is due to LoRA already being loaded
+            error_msg = str(e)
+            if "400" in error_msg or "Bad Request" in error_msg:
+                logger.info(f"LoRA adapter may already be loaded, attempting to use '{lora_name}'")
+                # Refresh the model list to check if it's now available
+                try:
+                    updated_models = [m.id for m in self.client.models.list().data]
+                    if lora_name in updated_models:
+                        logger.info(f"Found LoRA adapter '{lora_name}' in updated model list")
+                        self.model_kwargs["model"] = lora_name
+                        return
+                except Exception:
+                    pass  # Fall through to original error
+
+            raise ValueError(
+                f"Failed to load LoRA adapter from '{self.lora_path}': {error_msg[:100]}"
+            )
+
+    def unload_lora_adapter(self, lora_path: str) -> None:
+        """
+        Unload a LoRA adapter.
+
+        Args:
+            lora_path: Path to the LoRA adapter directory to unload
+
+        Raises:
+            ValueError: If unable to determine port from client
+        """
+        port = _get_port_from_client(self.client)
+        if port is None:
+            raise ValueError(
+                "Cannot unload LoRA adapter: "
+                "Unable to determine port from client base_url. "
+                "LoRA operations require a client initialized with port number."
+            )
+
+        _unload_lora_adapter(lora_path, port)
+        lora_name = os.path.basename(lora_path.rstrip('/\\'))
+        logger.info(f"Unloaded LoRA adapter: {lora_name}")
+
+    @staticmethod
+    def unload_lora(port: int, lora_name: str) -> None:
+        """Static method to unload a LoRA adapter by name.
+
+        Args:
+            port: Port number for the API endpoint
+            lora_name: Name of the LoRA adapter to unload
+
+        Raises:
+            requests.RequestException: If the API call fails
+        """
+        try:
+            response = requests.post(
+                f'http://localhost:{port}/v1/unload_lora_adapter',
+                headers={'accept': 'application/json', 'Content-Type': 'application/json'},
+                json={"lora_name": lora_name, "lora_int_id": 0}
+            )
+            response.raise_for_status()
+            logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+        except requests.RequestException as e:
+            logger.error(f"Error unloading LoRA adapter '{lora_name}': {str(e)[:100]}")
+            raise
+
     def _prepare_input(self, input_data: Union[str, BaseModel, List[Dict]]) -> Messages:
         """Convert input to messages format."""
         if isinstance(input_data, list):
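Taken together, the constructor changes let a task point at a LoRA adapter on disk, with the adapter's basename becoming the effective model name. A hedged sketch of the intended call pattern; the port and adapter path are placeholders, and the server must expose the dynamic LoRA endpoints used above:

from llm_utils.lm.llm_task import LLMTask

task = LLMTask(
    client=8000,                      # port-number client, so the port can be recovered for the LoRA API calls
    lora_path="/path/to/my-adapter",  # must contain adapter_config.json; loaded at construction
    force_lora_unload=True,           # reload even if an adapter with the same name is already registered
)

# model_kwargs["model"] now holds the adapter name; detach it explicitly when done.
task.unload_lora_adapter("/path/to/my-adapter")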
@@ -200,9 +440,24 @@ class LLMTask:
         # Extract model name from kwargs for API call
         api_kwargs = {k: v for k, v in effective_kwargs.items() if k != "model"}
 
-
-
-
+        try:
+            completion = self.client.chat.completions.create(
+                model=model_name, messages=messages, **api_kwargs
+            )
+            # Store raw response from client
+            self.last_ai_response = completion
+        except (AuthenticationError, RateLimitError, BadRequestError) as exc:
+            error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
+            logger.error(error_msg)
+            raise
+        except Exception as e:
+            is_length_error = "Length" in str(e) or "maximum context length" in str(e)
+            if is_length_error:
+                raise ValueError(
+                    f"Input too long for model {model_name}. Error: {str(e)[:100]}..."
+                )
+            # Re-raise all other exceptions
+            raise
         # print(completion)
 
         results: List[Dict[str, Any]] = []
@@ -211,9 +466,13 @@ class LLMTask:
                 Messages,
                 messages + [{"role": "assistant", "content": choice.message.content}],
             )
-
-
-
+            result_dict = {"parsed": choice.message.content, "messages": choice_messages}
+
+            # Add reasoning content if this is a reasoning model
+            if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
+                result_dict["reasoning_content"] = choice.message.reasoning_content
+
+            results.append(result_dict)
         return results
 
     def pydantic_parse(
@@ -239,6 +498,11 @@ class LLMTask:
            List of dicts [{'parsed': parsed_model, 'messages': messages}, ...]
            When n=1: List contains one dict
            When n>1: List contains multiple dicts
+
+        Note:
+            This method ensures consistent Pydantic model output for both fresh and cached responses.
+            When responses are cached and loaded back, the parsed content is re-validated to maintain
+            type consistency between first-time and subsequent calls.
         """
         # Prepare messages
         messages = self._prepare_input(input_data)
@@ -265,12 +529,20 @@ class LLMTask:
                 response_format=pydantic_model_to_use,
                 **api_kwargs,
             )
+            # Store raw response from client
+            self.last_ai_response = completion
+        except (AuthenticationError, RateLimitError, BadRequestError) as exc:
+            error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
+            logger.error(error_msg)
+            raise
         except Exception as e:
            is_length_error = "Length" in str(e) or "maximum context length" in str(e)
            if is_length_error:
                raise ValueError(
                    f"Input too long for model {model_name}. Error: {str(e)[:100]}..."
                )
+            # Re-raise all other exceptions
+            raise
 
         results: List[Dict[str, Any]] = []
         for choice in completion.choices:  # type: ignore[attr-defined]
@@ -278,9 +550,23 @@ class LLMTask:
                 Messages,
                 messages + [{"role": "assistant", "content": choice.message.content}],
             )
-
-
-
+
+            # Ensure consistent Pydantic model output for both fresh and cached responses
+            parsed_content = choice.message.parsed  # type: ignore[attr-defined]
+            if isinstance(parsed_content, dict):
+                # Cached response: validate dict back to Pydantic model
+                parsed_content = pydantic_model_to_use.model_validate(parsed_content)
+            elif not isinstance(parsed_content, pydantic_model_to_use):
+                # Fallback: ensure it's the correct type
+                parsed_content = pydantic_model_to_use.model_validate(parsed_content)
+
+            result_dict = {"parsed": parsed_content, "messages": choice_messages}
+
+            # Add reasoning content if this is a reasoning model
+            if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
+                result_dict["reasoning_content"] = choice.message.reasoning_content
+
+            results.append(result_dict)
         return results
 
     def __call__(
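Since every choice now lands in a result dict, callers can pick up the optional reasoning trace and the raw completion without changing how they read the parsed output. A hedged sketch, assuming a reasoning-capable model behind a placeholder local endpoint and a hypothetical Answer schema:

from pydantic import BaseModel
from llm_utils.lm.llm_task import LLMTask

class Answer(BaseModel):  # hypothetical output schema, for illustration only
    answer: str

task = LLMTask(client=8000, output_model=Answer, is_reasoning_model=True)

results = task.pydantic_parse("What is 2 + 2?")   # list of result dicts, one per choice
first = results[0]
parsed = first["parsed"]                          # an Answer instance, even when served from cache
reasoning = first.get("reasoning_content")        # only present if the server returned reasoning_content
raw = task.last_ai_response                       # the unmodified completion object from the client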
@@ -364,6 +650,8 @@ class LLMTask:
         builder: BasePromptBuilder,
         client: Union[OpenAI, int, str, None] = None,
         cache=True,
+        is_reasoning_model: bool = False,
+        lora_path: Optional[str] = None,
         **model_kwargs,
     ) -> "LLMTask":
         """
@@ -382,6 +670,10 @@ class LLMTask:
             input_model=input_model,
             output_model=output_model,
             client=client,
+            cache=cache,
+            is_reasoning_model=is_reasoning_model,
+            lora_path=lora_path,
+            **model_kwargs,
         )
 
     @staticmethod
@@ -398,3 +690,4 @@ class LLMTask:
         client = get_base_client(client, cache=False)
         models = client.models.list().data
         return [m.id for m in models]
+
llm_utils/lm/openai_memoize.py
CHANGED

@@ -1,4 +1,5 @@
 from openai import OpenAI, AsyncOpenAI
+from typing import Any, Callable
 
 from speedy_utils.common.utils_cache import memoize
 
@@ -30,6 +31,8 @@ class MOpenAI(OpenAI):
     - If you need a shared cache across instances, or more advanced cache controls,
       modify `memoize` or wrap at a class/static level instead of assigning to the
       bound method.
+    - Type information is now fully preserved by the memoize decorator, eliminating
+      the need for type casting.
 
     Example
         m = MOpenAI(api_key="...", model="gpt-4")
@@ -40,7 +43,12 @@ class MOpenAI(OpenAI):
    def __init__(self, *args, cache=True, **kwargs):
        super().__init__(*args, **kwargs)
        if cache:
-
+            # Create a memoized wrapper for the instance's post method.
+            # The memoize decorator now preserves exact type information,
+            # so no casting is needed.
+            orig_post = self.post
+            memoized = memoize(orig_post)
+            self.post = memoized
 
 
 class MAsyncOpenAI(AsyncOpenAI):
@@ -69,4 +77,4 @@ class MAsyncOpenAI(AsyncOpenAI):
    def __init__(self, *args, cache=True, **kwargs):
        super().__init__(*args, **kwargs)
        if cache:
-            self.post = memoize(self.post)
+            self.post = memoize(self.post)  # type: ignore