speedy-utils 1.1.18__py3-none-any.whl → 1.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +3 -2
- llm_utils/lm/async_lm/async_llm_task.py +1 -0
- llm_utils/lm/llm_task.py +303 -10
- llm_utils/lm/openai_memoize.py +2 -2
- llm_utils/vector_cache/core.py +3 -3
- speedy_utils/__init__.py +2 -1
- speedy_utils/common/utils_cache.py +1 -1
- speedy_utils/common/utils_io.py +9 -5
- speedy_utils/multi_worker/process.py +63 -6
- speedy_utils/multi_worker/thread.py +94 -2
- {speedy_utils-1.1.18.dist-info → speedy_utils-1.1.19.dist-info}/METADATA +34 -13
- {speedy_utils-1.1.18.dist-info → speedy_utils-1.1.19.dist-info}/RECORD +19 -19
- {speedy_utils-1.1.18.dist-info → speedy_utils-1.1.19.dist-info}/WHEEL +1 -1
- speedy_utils-1.1.19.dist-info/entry_points.txt +5 -0
- speedy_utils-1.1.18.dist-info/entry_points.txt +0 -6
llm_utils/__init__.py
CHANGED
@@ -4,7 +4,7 @@ from llm_utils.vector_cache import VectorCache
 from llm_utils.lm.lm_base import get_model_name
 from llm_utils.lm.base_prompt_builder import BasePromptBuilder
 
-
+LLM = LLMTask
 
 from .chat_format import (
     build_chatml_input,
@@ -34,5 +34,6 @@ __all__ = [
     "MOpenAI",
     "get_model_name",
     "VectorCache",
-    "BasePromptBuilder"
+    "BasePromptBuilder",
+    "LLM"
 ]
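The only functional change here is the new module-level alias `LLM = LLMTask` and its export in `__all__`. A minimal sketch of what this enables (the port below is illustrative; per the `LLMTask` docstring later in this diff, `client` may also be a base URL string or an `OpenAI` instance):

```python
# Sketch only: assumes an OpenAI-compatible server listening on localhost:8000.
from llm_utils import LLM, LLMTask

assert LLM is LLMTask  # new in 1.1.19: LLM is just an alias for LLMTask

task = LLM(client=8000, cache=True)
print(task.model_kwargs["model"])  # falls back to the first model the server reports
```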
llm_utils/lm/llm_task.py
CHANGED
@@ -4,10 +4,12 @@
 Simplified LLM Task module for handling language model interactions with structured input/output.
 """
 
+import os
 from typing import Any, Dict, List, Optional, Type, Union, cast
 
+import requests
 from loguru import logger
-from openai import OpenAI
+from openai import OpenAI, AuthenticationError, BadRequestError, RateLimitError
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
@@ -38,6 +40,90 @@ def get_base_client(
     )
 
 
+def _is_lora_path(path: str) -> bool:
+    """Check if the given path is a LoRA adapter directory.
+
+    Args:
+        path: Path to check
+
+    Returns:
+        True if the path contains adapter_config.json, False otherwise
+    """
+    if not os.path.isdir(path):
+        return False
+    adapter_config_path = os.path.join(path, 'adapter_config.json')
+    return os.path.isfile(adapter_config_path)
+
+
+def _get_port_from_client(client: OpenAI) -> Optional[int]:
+    """Extract port number from OpenAI client base_url.
+
+    Args:
+        client: OpenAI client instance
+
+    Returns:
+        Port number if found, None otherwise
+    """
+    if hasattr(client, 'base_url') and client.base_url:
+        base_url = str(client.base_url)
+        if 'localhost:' in base_url:
+            try:
+                # Extract port from localhost:PORT/v1 format
+                port_part = base_url.split('localhost:')[1].split('/')[0]
+                return int(port_part)
+            except (IndexError, ValueError):
+                pass
+    return None
+
+
+def _load_lora_adapter(lora_path: str, port: int) -> str:
+    """Load a LoRA adapter from the specified path.
+
+    Args:
+        lora_path: Path to the LoRA adapter directory
+        port: Port number for the API endpoint
+
+    Returns:
+        Name of the loaded LoRA adapter
+
+    Raises:
+        requests.RequestException: If the API call fails
+    """
+    lora_name = os.path.basename(lora_path.rstrip('/\\'))
+    if not lora_name:  # Handle edge case of empty basename
+        lora_name = os.path.basename(os.path.dirname(lora_path))
+
+    response = requests.post(
+        f'http://localhost:{port}/v1/load_lora_adapter',
+        headers={'accept': 'application/json', 'Content-Type': 'application/json'},
+        json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)}
+    )
+    response.raise_for_status()
+    return lora_name
+
+
+def _unload_lora_adapter(lora_path: str, port: int) -> None:
+    """Unload the current LoRA adapter.
+
+    Args:
+        lora_path: Path to the LoRA adapter directory
+        port: Port number for the API endpoint
+    """
+    try:
+        lora_name = os.path.basename(lora_path.rstrip('/\\'))
+        if not lora_name:  # Handle edge case of empty basename
+            lora_name = os.path.basename(os.path.dirname(lora_path))
+
+        response = requests.post(
+            f'http://localhost:{port}/v1/unload_lora_adapter',
+            headers={'accept': 'application/json', 'Content-Type': 'application/json'},
+            json={"lora_name": lora_name, "lora_int_id": 0}
+        )
+        response.raise_for_status()
+    except requests.RequestException as e:
+        logger.warning(f"Error unloading LoRA adapter: {str(e)[:100]}")
+
+
 class LLMTask:
     """
     Language model task with structured input/output and optional system instruction.
@@ -106,6 +192,9 @@ class LLMTask:
         output_model: Type[BaseModel] | Type[str] = None,
         client: Union[OpenAI, int, str, None] = None,
         cache=True,
+        is_reasoning_model: bool = False,
+        force_lora_unload: bool = False,
+        lora_path: Optional[str] = None,
         **model_kwargs,
     ):
         """
@@ -117,6 +206,12 @@ class LLMTask:
             output_model: Output BaseModel type
             client: OpenAI client, port number, or base_url string
             cache: Whether to use cached responses (default True)
+            is_reasoning_model: Whether the model is a reasoning model (o1-preview, o1-mini, etc.)
+                that outputs reasoning_content separately from content (default False)
+            force_lora_unload: If True, forces unloading of any existing LoRA adapter before loading
+                a new one when lora_path is provided (default False)
+            lora_path: Optional path to LoRA adapter directory. If provided, will load the LoRA
+                and use it as the model. Takes precedence over model parameter.
             **model_kwargs: Additional model parameters including:
                 - temperature: Controls randomness (0.0 to 2.0)
                 - n: Number of responses to generate (when n > 1, returns list)
@@ -127,6 +222,10 @@ class LLMTask:
         self.input_model = input_model
         self.output_model = output_model
         self.model_kwargs = model_kwargs
+        self.is_reasoning_model = is_reasoning_model
+        self.force_lora_unload = force_lora_unload
+        self.lora_path = lora_path
+        self.last_ai_response = None  # Store raw response from client
 
         # if cache:
         #     print("Caching is enabled will use llm_utils.MOpenAI")
@@ -135,11 +234,152 @@ class LLMTask:
         # else:
         #     self.client = OpenAI(base_url=base_url, api_key=api_key)
         self.client = get_base_client(client, cache=cache)
+        # check connection of client
+        try:
+            self.client.models.list()
+        except Exception as e:
+            logger.error(f"Failed to connect to OpenAI client: {str(e)}, base_url={self.client.base_url}")
+            raise e
 
         if not self.model_kwargs.get("model", ""):
             self.model_kwargs["model"] = self.client.models.list().data[0].id
+
+        # Handle LoRA loading if lora_path is provided
+        if self.lora_path:
+            self._load_lora_adapter()
+
         print(self.model_kwargs)
 
+    def _load_lora_adapter(self) -> None:
+        """
+        Load LoRA adapter from the specified lora_path.
+
+        This method:
+        1. Validates that lora_path is a valid LoRA directory
+        2. Checks if LoRA is already loaded (unless force_lora_unload is True)
+        3. Loads the LoRA adapter and updates the model name
+        """
+        if not self.lora_path:
+            return
+
+        if not _is_lora_path(self.lora_path):
+            raise ValueError(
+                f"Invalid LoRA path '{self.lora_path}': "
+                "Directory must contain 'adapter_config.json'"
+            )
+
+        logger.info(f"Loading LoRA adapter from: {self.lora_path}")
+
+        # Get the expected LoRA name (basename of the path)
+        lora_name = os.path.basename(self.lora_path.rstrip('/\\'))
+        if not lora_name:  # Handle edge case of empty basename
+            lora_name = os.path.basename(os.path.dirname(self.lora_path))
+
+        # Get list of available models to check if LoRA is already loaded
+        try:
+            available_models = [m.id for m in self.client.models.list().data]
+        except Exception as e:
+            logger.warning(f"Failed to list models, proceeding with LoRA load: {str(e)[:100]}")
+            available_models = []
+
+        # Check if LoRA is already loaded
+        if lora_name in available_models and not self.force_lora_unload:
+            logger.info(f"LoRA adapter '{lora_name}' is already loaded, using existing model")
+            self.model_kwargs["model"] = lora_name
+            return
+
+        # Force unload if requested
+        if self.force_lora_unload and lora_name in available_models:
+            logger.info(f"Force unloading LoRA adapter '{lora_name}' before reloading")
+            port = _get_port_from_client(self.client)
+            if port is not None:
+                try:
+                    LLMTask.unload_lora(port, lora_name)
+                    logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+                except Exception as e:
+                    logger.warning(f"Failed to unload LoRA adapter: {str(e)[:100]}")
+
+        # Get port from client for API calls
+        port = _get_port_from_client(self.client)
+        if port is None:
+            raise ValueError(
+                f"Cannot load LoRA adapter '{self.lora_path}': "
+                "Unable to determine port from client base_url. "
+                "LoRA loading requires a client initialized with port number."
+            )
+
+        try:
+            # Load the LoRA adapter
+            loaded_lora_name = _load_lora_adapter(self.lora_path, port)
+            logger.info(f"Successfully loaded LoRA adapter: {loaded_lora_name}")
+
+            # Update model name to the loaded LoRA name
+            self.model_kwargs["model"] = loaded_lora_name
+
+        except requests.RequestException as e:
+            # Check if the error is due to LoRA already being loaded
+            error_msg = str(e)
+            if "400" in error_msg or "Bad Request" in error_msg:
+                logger.info(f"LoRA adapter may already be loaded, attempting to use '{lora_name}'")
+                # Refresh the model list to check if it's now available
+                try:
+                    updated_models = [m.id for m in self.client.models.list().data]
+                    if lora_name in updated_models:
+                        logger.info(f"Found LoRA adapter '{lora_name}' in updated model list")
+                        self.model_kwargs["model"] = lora_name
+                        return
+                except Exception:
+                    pass  # Fall through to original error
+
+            raise ValueError(
+                f"Failed to load LoRA adapter from '{self.lora_path}': {error_msg[:100]}"
+            )
+
+    def unload_lora_adapter(self, lora_path: str) -> None:
+        """
+        Unload a LoRA adapter.
+
+        Args:
+            lora_path: Path to the LoRA adapter directory to unload
+
+        Raises:
+            ValueError: If unable to determine port from client
+        """
+        port = _get_port_from_client(self.client)
+        if port is None:
+            raise ValueError(
+                "Cannot unload LoRA adapter: "
+                "Unable to determine port from client base_url. "
+                "LoRA operations require a client initialized with port number."
+            )
+
+        _unload_lora_adapter(lora_path, port)
+        lora_name = os.path.basename(lora_path.rstrip('/\\'))
+        logger.info(f"Unloaded LoRA adapter: {lora_name}")
+
+    @staticmethod
+    def unload_lora(port: int, lora_name: str) -> None:
+        """Static method to unload a LoRA adapter by name.
+
+        Args:
+            port: Port number for the API endpoint
+            lora_name: Name of the LoRA adapter to unload
+
+        Raises:
+            requests.RequestException: If the API call fails
+        """
+        try:
+            response = requests.post(
+                f'http://localhost:{port}/v1/unload_lora_adapter',
+                headers={'accept': 'application/json', 'Content-Type': 'application/json'},
+                json={"lora_name": lora_name, "lora_int_id": 0}
+            )
+            response.raise_for_status()
+            logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+        except requests.RequestException as e:
+            logger.error(f"Error unloading LoRA adapter '{lora_name}': {str(e)[:100]}")
+            raise
+
     def _prepare_input(self, input_data: Union[str, BaseModel, List[Dict]]) -> Messages:
         """Convert input to messages format."""
         if isinstance(input_data, list):
@@ -200,9 +440,24 @@ class LLMTask:
         # Extract model name from kwargs for API call
         api_kwargs = {k: v for k, v in effective_kwargs.items() if k != "model"}
 
-
-
-
+        try:
+            completion = self.client.chat.completions.create(
+                model=model_name, messages=messages, **api_kwargs
+            )
+            # Store raw response from client
+            self.last_ai_response = completion
+        except (AuthenticationError, RateLimitError, BadRequestError) as exc:
+            error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
+            logger.error(error_msg)
+            raise
+        except Exception as e:
+            is_length_error = "Length" in str(e) or "maximum context length" in str(e)
+            if is_length_error:
+                raise ValueError(
+                    f"Input too long for model {model_name}. Error: {str(e)[:100]}..."
+                )
+            # Re-raise all other exceptions
+            raise
         # print(completion)
 
         results: List[Dict[str, Any]] = []
@@ -211,9 +466,13 @@ class LLMTask:
                 Messages,
                 messages + [{"role": "assistant", "content": choice.message.content}],
             )
-
-
-
+            result_dict = {"parsed": choice.message.content, "messages": choice_messages}
+
+            # Add reasoning content if this is a reasoning model
+            if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
+                result_dict["reasoning_content"] = choice.message.reasoning_content
+
+            results.append(result_dict)
         return results
 
     def pydantic_parse(
@@ -239,6 +498,11 @@ class LLMTask:
             List of dicts [{'parsed': parsed_model, 'messages': messages}, ...]
             When n=1: List contains one dict
             When n>1: List contains multiple dicts
+
+        Note:
+            This method ensures consistent Pydantic model output for both fresh and cached responses.
+            When responses are cached and loaded back, the parsed content is re-validated to maintain
+            type consistency between first-time and subsequent calls.
         """
         # Prepare messages
         messages = self._prepare_input(input_data)
@@ -265,12 +529,20 @@ class LLMTask:
                 response_format=pydantic_model_to_use,
                 **api_kwargs,
             )
+            # Store raw response from client
+            self.last_ai_response = completion
+        except (AuthenticationError, RateLimitError, BadRequestError) as exc:
+            error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
+            logger.error(error_msg)
+            raise
         except Exception as e:
            is_length_error = "Length" in str(e) or "maximum context length" in str(e)
            if is_length_error:
                raise ValueError(
                    f"Input too long for model {model_name}. Error: {str(e)[:100]}..."
                )
+            # Re-raise all other exceptions
+            raise
 
         results: List[Dict[str, Any]] = []
         for choice in completion.choices:  # type: ignore[attr-defined]
@@ -278,9 +550,23 @@ class LLMTask:
                 Messages,
                 messages + [{"role": "assistant", "content": choice.message.content}],
             )
-
-
-
+
+            # Ensure consistent Pydantic model output for both fresh and cached responses
+            parsed_content = choice.message.parsed  # type: ignore[attr-defined]
+            if isinstance(parsed_content, dict):
+                # Cached response: validate dict back to Pydantic model
+                parsed_content = pydantic_model_to_use.model_validate(parsed_content)
+            elif not isinstance(parsed_content, pydantic_model_to_use):
+                # Fallback: ensure it's the correct type
+                parsed_content = pydantic_model_to_use.model_validate(parsed_content)
+
+            result_dict = {"parsed": parsed_content, "messages": choice_messages}
+
+            # Add reasoning content if this is a reasoning model
+            if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
+                result_dict["reasoning_content"] = choice.message.reasoning_content
+
+            results.append(result_dict)
         return results
 
     def __call__(
@@ -364,6 +650,8 @@ class LLMTask:
         builder: BasePromptBuilder,
         client: Union[OpenAI, int, str, None] = None,
         cache=True,
+        is_reasoning_model: bool = False,
+        lora_path: Optional[str] = None,
         **model_kwargs,
     ) -> "LLMTask":
         """
@@ -382,6 +670,10 @@ class LLMTask:
             input_model=input_model,
            output_model=output_model,
            client=client,
+            cache=cache,
+            is_reasoning_model=is_reasoning_model,
+            lora_path=lora_path,
+            **model_kwargs,
        )
 
     @staticmethod
@@ -398,3 +690,4 @@ class LLMTask:
         client = get_base_client(client, cache=False)
         models = client.models.list().data
         return [m.id for m in models]
+
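Taken together, the LoRA additions wire `LLMTask` to vLLM-style dynamic adapter endpoints (`/v1/load_lora_adapter` and `/v1/unload_lora_adapter`) on the same localhost port the client was created from. A hedged usage sketch follows; the adapter path, the port, and the requirement that the server was started with runtime LoRA updates enabled are assumptions, not taken from this diff:

```python
from llm_utils import LLMTask

# Sketch: "./outputs/my_adapter" must contain adapter_config.json; 8000 is an example port.
task = LLMTask(
    client=8000,                       # an int client lets _get_port_from_client find the LoRA API
    lora_path="./outputs/my_adapter",  # loaded at __init__ time; its basename becomes the model name
    force_lora_unload=True,            # unload a same-named adapter first, then reload it
    is_reasoning_model=True,           # also surface choice.message.reasoning_content in results
)

# Later, the adapter can be removed by name without constructing a task:
LLMTask.unload_lora(port=8000, lora_name="my_adapter")
```

Note that `is_reasoning_model` only adds a `reasoning_content` key to each result dict when the server actually returns that attribute; it does not change the request itself.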
llm_utils/lm/openai_memoize.py
CHANGED
@@ -40,7 +40,7 @@ class MOpenAI(OpenAI):
     def __init__(self, *args, cache=True, **kwargs):
         super().__init__(*args, **kwargs)
         if cache:
-            self.post = memoize(self.post)
+            self.post = memoize(self.post)  # type: ignore
 
 
 class MAsyncOpenAI(AsyncOpenAI):
@@ -69,4 +69,4 @@ class MAsyncOpenAI(AsyncOpenAI):
     def __init__(self, *args, cache=True, **kwargs):
         super().__init__(*args, **kwargs)
         if cache:
-            self.post = memoize(self.post)
+            self.post = memoize(self.post)  # type: ignore
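The change here is annotation-only: rebinding the instance's `post` method to a memoized wrapper is what makes `MOpenAI` and `MAsyncOpenAI` cache responses, and the `# type: ignore` just silences type-checker complaints about assigning over a bound method. A small, hedged sketch of the resulting behaviour (the base URL and API key below are placeholders):

```python
from llm_utils import MOpenAI

# With cache=True (the default), every POST goes through the memoized wrapper,
# so repeating an identical request is served from cache instead of hitting the server.
client = MOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY", cache=True)
```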
llm_utils/vector_cache/core.py
CHANGED
@@ -535,9 +535,9 @@ class VectorCache:
         if self.verbose:
             print(f"Computing embeddings for {total_items} missing texts in batches of {batch_size}...")
             if self.backend in ["vllm", "transformers"] and self._model is None:
-                print(
+                print("⚠️ Model will be loaded on first batch (lazy loading enabled)")
             elif self.backend in ["vllm", "transformers"]:
-                print(
+                print("✓ Model already loaded, ready for efficient batch processing")
 
         # Create progress bar
         pbar = None
@@ -571,7 +571,7 @@ class VectorCache:
             # Update progress
             batch_size_actual = len(batch_items)
             if use_tqdm:
-                pbar.update(batch_size_actual)
+                pbar.update(batch_size_actual)  # type: ignore
             else:
                 processed_count += batch_size_actual
                 if self.verbose:
speedy_utils/__init__.py
CHANGED
@@ -138,7 +138,7 @@ from .common.utils_print import (
 
 # Multi-worker processing
 from .multi_worker.process import multi_process
-from .multi_worker.thread import multi_thread
+from .multi_worker.thread import kill_all_thread, multi_thread
 
 # Define __all__ explicitly
 __all__ = [
@@ -224,6 +224,7 @@ __all__ = [
     # Multi-worker processing
     "multi_process",
     "multi_thread",
+    "kill_all_thread",
     # Notebook utilities
     "change_dir",
 ]
speedy_utils/common/utils_io.py
CHANGED
@@ -1,13 +1,18 @@
 # utils/utils_io.py
 
+import bz2
+import gzip
+import io
 import json
+import lzma
 import os
 import os.path as osp
 import pickle
 import time
+import warnings
 from glob import glob
 from pathlib import Path
-from typing import Any, Union
+from typing import IO, Any, Iterable, Optional, Union, cast
 
 from json_repair import loads as jloads
 from pydantic import BaseModel
@@ -53,7 +58,7 @@ def dump_json_or_pickle(
     except Exception as e:
         if isinstance(obj, BaseModel):
             data = obj.model_dump()
-            from fastcore.all import
+            from fastcore.all import dict2obj, obj2dict
             obj2 = dict2obj(data)
             with open(fname, "wb") as f:
                 pickle.dump(obj2, f)
@@ -87,8 +92,7 @@ def load_json_or_pickle(fname: str, counter=0) -> Any:
         raise ValueError(f"Error {e} while loading {fname}") from e
 
 
-
-from typing import Iterable, Union, IO, Any, Optional, cast
+
 
 try:
     import orjson  # type: ignore[import-not-found]  # fastest JSON parser when available
@@ -212,7 +216,7 @@ def fast_load_jsonl(
     if line_count > multiworker_threshold:
         # Use multi-worker processing
         from ..multi_worker.thread import multi_thread
-
+
         # Read all lines into chunks
         f = _open_auto(path_or_file)
         all_lines = list(f)
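The new `bz2`/`gzip`/`lzma`/`io` imports accompany the `_open_auto` helper that `fast_load_jsonl` calls. The diff does not show `_open_auto` itself, so the following is only an illustration of the usual suffix-based dispatch such a helper performs, not the package's actual implementation:

```python
import bz2
import gzip
import lzma
from typing import IO, Any


def open_auto_sketch(path: str, mode: str = "rt") -> IO[Any]:
    """Open plain, .gz, .bz2, or .xz/.lzma files transparently based on the suffix."""
    if path.endswith(".gz"):
        return gzip.open(path, mode)
    if path.endswith(".bz2"):
        return bz2.open(path, mode)
    if path.endswith((".xz", ".lzma")):
        return lzma.open(path, mode)
    return open(path, mode)
```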
speedy_utils/multi_worker/process.py
CHANGED
@@ -1,11 +1,20 @@
 # ray_multi_process.py
-import time, os, pickle, uuid, datetime
+import time, os, pickle, uuid, datetime, multiprocessing
 from pathlib import Path
 from typing import Any, Callable
 from tqdm import tqdm
-import
+import psutil
+import threading
+ray: Any
+try:
+    import ray as ray  # type: ignore
+    _HAS_RAY = True
+except Exception:  # pragma: no cover
+    ray = None  # type: ignore
+    _HAS_RAY = False
 from fastcore.parallel import parallel
 
+
 # ─── cache helpers ──────────────────────────────────────────
 
 def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
@@ -61,7 +70,7 @@ def multi_process(
     lazy_output: bool = False,
     progress: bool = True,
     # backend: str = "ray",  # "seq", "ray", or "fastcore"
-    backend: Literal["seq", "ray", "mp", "threadpool"] =
+    backend: Literal["seq", "ray", "mp", "threadpool", "safe"] | None = None,
     # Additional optional knobs (accepted for compatibility)
     batch: int | None = None,
     ordered: bool | None = None,
@@ -75,12 +84,18 @@ def multi_process(
     backend:
         - "seq": run sequentially
         - "ray": run in parallel with Ray
-        - "
+        - "mp": run in parallel with multiprocessing (uses threadpool to avoid fork warnings)
+        - "threadpool": run in parallel with thread pool
+        - "safe": run in parallel with thread pool (explicitly safe for tests)
 
     If lazy_output=True, every result is saved to .pkl and
     the returned list contains file paths.
     """
 
+    # default backend selection
+    if backend is None:
+        backend = "ray" if _HAS_RAY else "mp"
+
     # unify items
     if items is None and inputs is not None:
         items = list(inputs)
@@ -108,6 +123,13 @@ def multi_process(
 
     # ---- ray backend ----
     if backend == "ray":
+        if not _HAS_RAY:
+            msg = (
+                "Ray backend requested but 'ray' is not installed. "
+                "Install extra: pip install 'speedy-utils[ray]' or "
+                "poetry install -E ray."
+            )
+            raise RuntimeError(msg)
         pbar.set_postfix_str("backend=ray")
         ensure_ray(workers, pbar)
 
@@ -125,10 +147,45 @@ def multi_process(
 
     # ---- fastcore backend ----
     if backend == "mp":
-
+        # Use threadpool instead of multiprocessing to avoid fork warnings
+        # in multi-threaded environments like pytest
+        results = parallel(f_wrapped, items, n_workers=workers, progress=progress, threadpool=True)
         return list(results)
     if backend == "threadpool":
         results = parallel(f_wrapped, items, n_workers=workers, progress=progress, threadpool=True)
         return list(results)
-
+    if backend == "safe":
+        # Completely safe backend for tests - no multiprocessing, no external progress bars
+        import concurrent.futures
+        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
+            results = list(executor.map(f_wrapped, items))
     raise ValueError(f"Unsupported backend: {backend!r}")
+
+
+
+def cleanup_phantom_workers():
+    """
+    Kill all child processes (phantom workers) without killing the Jupyter kernel itself.
+    Also lists non-daemon threads that remain.
+    """
+    parent = psutil.Process(os.getpid())
+
+    # Kill only children, never the current process
+    for child in parent.children(recursive=True):
+        try:
+            print(f"🔪 Killing child process {child.pid} ({child.name()})")
+            child.kill()
+        except psutil.NoSuchProcess:
+            pass
+
+    # Report stray threads (can't hard-kill them in Python)
+    for t in threading.enumerate():
+        if t is threading.current_thread():
+            continue
+        if not t.daemon:
+            print(f"⚠️ Thread {t.name} is still running (cannot be force-killed).")
+
+    print("✅ Cleaned up child processes (kernel untouched).")
+
+    # Usage: run this anytime after cancelling a cell
+
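With this change `ray` becomes optional: `backend=None` (the new default) resolves to `"ray"` only when the import succeeded and otherwise falls back to `"mp"`, which itself now runs on a thread pool. A hedged call sketch follows; the positional function argument and the worker function are illustrative, and only the keyword names visible in this diff are assumed to exist:

```python
from speedy_utils import multi_process

def square(x: int) -> int:
    return x * x

# backend=None picks "ray" if ray is importable, otherwise "mp" (threadpool under the hood).
print(multi_process(square, items=list(range(8))))

# Requesting "ray" without the optional dependency now raises RuntimeError with install hints;
# "threadpool" is always available.
print(multi_process(square, items=list(range(8)), backend="threadpool"))
```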
speedy_utils/multi_worker/thread.py
CHANGED
@@ -77,7 +77,9 @@
 # ============================================================================= #
 """
 
+import ctypes
 import os
+import threading
 import time
 import traceback
 from collections.abc import Callable, Iterable
@@ -98,6 +100,42 @@ DEFAULT_WORKERS = (os.cpu_count() or 4) * 2
 T = TypeVar("T")
 R = TypeVar("R")
 
+SPEEDY_RUNNING_THREADS: list[threading.Thread] = []
+_SPEEDY_THREADS_LOCK = threading.Lock()
+
+_PY_SET_ASYNC_EXC = ctypes.pythonapi.PyThreadState_SetAsyncExc
+try:
+    _PY_SET_ASYNC_EXC.argtypes = (ctypes.c_ulong, ctypes.py_object)  # type: ignore[attr-defined]
+    _PY_SET_ASYNC_EXC.restype = ctypes.c_int  # type: ignore[attr-defined]
+except AttributeError:  # pragma: no cover - platform specific
+    pass
+
+
+def _prune_dead_threads() -> None:
+    with _SPEEDY_THREADS_LOCK:
+        SPEEDY_RUNNING_THREADS[:] = [t for t in SPEEDY_RUNNING_THREADS if t.is_alive()]
+
+
+def _track_threads(threads: Iterable[threading.Thread]) -> None:
+    if not threads:
+        return
+    with _SPEEDY_THREADS_LOCK:
+        living = [t for t in SPEEDY_RUNNING_THREADS if t.is_alive()]
+        for candidate in threads:
+            if not candidate.is_alive():
+                continue
+            if any(existing is candidate for existing in living):
+                continue
+            living.append(candidate)
+        SPEEDY_RUNNING_THREADS[:] = living
+
+
+def _track_executor_threads(pool: ThreadPoolExecutor) -> None:
+    thread_set = getattr(pool, "_threads", None)
+    if not thread_set:
+        return
+    _track_threads(tuple(thread_set))
+
 
 def _group_iter(src: Iterable[T], size: int) -> Iterable[list[T]]:
     """Yield successive chunks from iterable of specified size."""
@@ -273,11 +311,13 @@ def multi_thread(
             fut.idx = next_logical_idx  # type: ignore[attr-defined]
             inflight.add(fut)
             next_logical_idx += len(arg)
+            _track_executor_threads(pool)
         else:
             fut = pool.submit(_worker, arg, func, fixed_kwargs)
             fut.idx = next_logical_idx  # type: ignore[attr-defined]
             inflight.add(fut)
             next_logical_idx += 1
+            _track_executor_threads(pool)
 
     try:
         # Process futures as they complete and add new ones to keep the pool busy
@@ -347,11 +387,13 @@ def multi_thread(
                     fut2.idx = next_logical_idx  # type: ignore[attr-defined]
                     inflight.add(fut2)
                     next_logical_idx += len(arg)
+                    _track_executor_threads(pool)
                 else:
                     fut2 = pool.submit(_worker, arg, func, fixed_kwargs)
                     fut2.idx = next_logical_idx  # type: ignore[attr-defined]
                     inflight.add(fut2)
                     next_logical_idx += 1
+                    _track_executor_threads(pool)
             except StopIteration:
                 pass
 
@@ -370,6 +412,7 @@ def multi_thread(
         bar.close()
     if store_output_pkl_file:
         dump_json_or_pickle(results, store_output_pkl_file)
+    _prune_dead_threads()
     return results
 
 
@@ -396,9 +439,58 @@ def multi_thread_standard(
         Results in same order as input items.
     """
     with ThreadPoolExecutor(max_workers=workers) as executor:
-        futures = [
+        futures = []
+        for item in items:
+            futures.append(executor.submit(fn, item))
+            _track_executor_threads(executor)
         results = [fut.result() for fut in futures]
+    _prune_dead_threads()
     return results
 
 
-
+def _async_raise(thread_id: int, exc_type: type[BaseException]) -> bool:
+    if thread_id <= 0:
+        return False
+    if not issubclass(exc_type, BaseException):
+        raise TypeError("exc_type must derive from BaseException")
+    res = _PY_SET_ASYNC_EXC(ctypes.c_ulong(thread_id), ctypes.py_object(exc_type))
+    if res == 0:
+        return False
+    if res > 1:  # pragma: no cover - defensive branch
+        _PY_SET_ASYNC_EXC(ctypes.c_ulong(thread_id), None)
+        raise SystemError("PyThreadState_SetAsyncExc failed")
+    return True
+
+
+def kill_all_thread(exc_type: type[BaseException] = SystemExit, join_timeout: float = 0.1) -> int:
+    """Forcefully stop tracked worker threads. Returns number of threads signalled."""
+    _prune_dead_threads()
+    current = threading.current_thread()
+    with _SPEEDY_THREADS_LOCK:
+        targets = [t for t in SPEEDY_RUNNING_THREADS if t.is_alive()]
+
+    terminated = 0
+    for thread in targets:
+        if thread is current:
+            continue
+        ident = thread.ident
+        if ident is None:
+            continue
+        try:
+            if _async_raise(ident, exc_type):
+                terminated += 1
+                thread.join(timeout=join_timeout)
+            else:
+                logger.warning("Unable to signal thread %s", thread.name)
+        except Exception as exc:  # pragma: no cover - defensive
+            logger.error("Failed to stop thread %s: %s", thread.name, exc)
+    _prune_dead_threads()
+    return terminated
+
+
+__all__ = [
+    "SPEEDY_RUNNING_THREADS",
+    "multi_thread",
+    "multi_thread_standard",
+    "kill_all_thread",
+]
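The thread-tracking additions exist so that `kill_all_thread` has something to signal: `multi_thread` and `multi_thread_standard` record their pool threads in `SPEEDY_RUNNING_THREADS`, and `kill_all_thread` asks CPython (via `PyThreadState_SetAsyncExc`) to raise `SystemExit` inside each live tracked thread. A hedged sketch of the intended escape hatch, for example after interrupting a notebook cell:

```python
from speedy_utils import kill_all_thread

# Signal every tracked worker thread that is still alive; returns how many were signalled.
stopped = kill_all_thread()  # defaults shown in this diff: exc_type=SystemExit, join_timeout=0.1
print(f"signalled {stopped} worker threads")

# Caveat (CPython behaviour, not specific to this package): the injected exception is only
# delivered when the target thread re-enters Python bytecode, so threads blocked inside
# long C calls may not stop immediately.
```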
{speedy_utils-1.1.18.dist-info → speedy_utils-1.1.19.dist-info}/METADATA
CHANGED
@@ -1,10 +1,14 @@
 Metadata-Version: 2.4
 Name: speedy-utils
-Version: 1.1.18
+Version: 1.1.19
 Summary: Fast and easy-to-use package for data science
-
-
-
+Project-URL: Homepage, https://github.com/anhvth/speedy
+Project-URL: Repository, https://github.com/anhvth/speedy
+Author-email: AnhVTH <anhvth.226@gmail.com>
+License: MIT
+Classifier: Development Status :: 4 - Beta
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
@@ -13,29 +17,34 @@ Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Programming Language :: Python :: 3.14
+Requires-Python: >=3.8
+Requires-Dist: aiohttp>=3.10.11
 Requires-Dist: bump2version
 Requires-Dist: cachetools
 Requires-Dist: debugpy
 Requires-Dist: fastcore
 Requires-Dist: fastprogress
-Requires-Dist: freezegun
+Requires-Dist: freezegun>=1.5.1
 Requires-Dist: ipdb
 Requires-Dist: ipywidgets
-Requires-Dist: json-repair
+Requires-Dist: json-repair<0.31.0,>=0.25.0
 Requires-Dist: jupyterlab
 Requires-Dist: loguru
 Requires-Dist: matplotlib
 Requires-Dist: numpy
-Requires-Dist: openai
-Requires-Dist: packaging
+Requires-Dist: openai>=1.106.0
+Requires-Dist: packaging<25,>=23.2
 Requires-Dist: pandas
 Requires-Dist: pydantic
+Requires-Dist: pytest>=8.3.5
+Requires-Dist: ray>=2.36.1
 Requires-Dist: requests
 Requires-Dist: scikit-learn
 Requires-Dist: tabulate
 Requires-Dist: tqdm
 Requires-Dist: xxhash
-
+Provides-Extra: ray
+Requires-Dist: ray>=2.49.1; (python_version >= '3.9') and extra == 'ray'
 Description-Content-Type: text/markdown
 
 # Speedy Utils
@@ -84,6 +93,19 @@ cd speedy-utils
 pip install .
 ```
 
+### Extras
+
+Optional dependencies can be installed via extras. For the `ray` backend
+support (requires Python >= 3.9):
+
+```bash
+# pip
+pip install 'speedy-utils[ray]'
+
+# Poetry (for developing this repo)
+poetry install -E ray
+```
+
 ## Updating from previous versions
 
 To update from previous versions or switch to v1.x, first uninstall any old
@@ -282,9 +304,8 @@ python speedy_utils/common/dataclass_parser.py
 
 Example output:
 
-| Field
-
-| from_peft
+| Field     | Value                                 |
+| --------- | ------------------------------------- |
+| from_peft | ./outputs/llm_hn_qw32b/hn_results_r3/ |
 
 Please ensure your code adheres to the project's coding standards and includes appropriate tests.
-
{speedy_utils-1.1.18.dist-info → speedy_utils-1.1.19.dist-info}/RECORD
CHANGED
@@ -1,31 +1,31 @@
-llm_utils/__init__.py,sha256=
+llm_utils/__init__.py,sha256=n9m0iB82oygFThbDEdI5hpmozmTNhQgwX148QZulfCE,940
+llm_utils/group_messages.py,sha256=Oe2tlhg-zRodG1-hodYebddrR77j9UdE05LzJw0EvYI,3622
 llm_utils/chat_format/__init__.py,sha256=8dBIUqFJvkgQYedxBtcyxt-4tt8JxAKVap2JlTXmgaM,737
 llm_utils/chat_format/display.py,sha256=3jKDm4OTrvytK1qBhSOjRLltUIObHsYFdBLgm8SVDE8,14159
 llm_utils/chat_format/transform.py,sha256=eU0c3PdAHCNLuGP1UqPwln0B34Lv3bt_uV9v9BrlCN4,5402
 llm_utils/chat_format/utils.py,sha256=xTxN4HrLHcRO2PfCTR43nH1M5zCa7v0kTTdzAcGkZg0,1229
-llm_utils/group_messages.py,sha256=Oe2tlhg-zRodG1-hodYebddrR77j9UdE05LzJw0EvYI,3622
 llm_utils/lm/__init__.py,sha256=totIZnq1P8eNlfVco0OfdGdTNt1-wSXDSRReRRzYYxw,319
+llm_utils/lm/base_prompt_builder.py,sha256=OLqyxbA8QeYIVFzB9EqxUiE_P2p4_MD_Lq4WSwxFtKU,12136
+llm_utils/lm/llm_task.py,sha256=kyBeMDJwW9ZWq5A_OMgE-ou9GQ0bk5c9lxXOvfo31R4,27915
+llm_utils/lm/lm.py,sha256=8TaLuU7naPQbOFmiS2NQyWVLG0jUUzRRBQsR0In7GVo,7249
+llm_utils/lm/lm_base.py,sha256=pqbHZOdR7yUMpvwt8uBG1dZnt76SY_Wk8BkXQQ-mpWs,9557
+llm_utils/lm/openai_memoize.py,sha256=q1cj5tZOSEpvx4QhRNs37pVaFMpMViCdVtwRsoaXgeU,3054
+llm_utils/lm/utils.py,sha256=a0KJj8vjT2fHKb7GKGNJjJHhKLThwpxIL7vnV9Fr3ZY,4584
 llm_utils/lm/async_lm/__init__.py,sha256=PUBbCuf5u6-0GBUu-2PI6YAguzsyXj-LPkU6vccqT6E,121
 llm_utils/lm/async_lm/_utils.py,sha256=P1-pUDf_0pDmo8WTIi43t5ARlyGA1RIJfpAhz-gfA5g,6105
-llm_utils/lm/async_lm/async_llm_task.py,sha256
+llm_utils/lm/async_lm/async_llm_task.py,sha256=-BVOk18ZD8eC2obTLgiPq39f2PP3cji17Ku-Gb7c7Xo,18683
 llm_utils/lm/async_lm/async_lm.py,sha256=e3o9cyMbkVz_jQDTjJv2ybET_5mY012zdZGjNwi4Qk4,13719
 llm_utils/lm/async_lm/async_lm_base.py,sha256=iJgtzI6pVJzWtlXGqVLwgCIb-FzZAa3E5xW8yhyHUmM,8426
 llm_utils/lm/async_lm/lm_specific.py,sha256=KmqdCm3SJ5MqN-dRJd6S5tq5-ve1X2eNWf2CMFtc_3s,3926
-llm_utils/lm/base_prompt_builder.py,sha256=OLqyxbA8QeYIVFzB9EqxUiE_P2p4_MD_Lq4WSwxFtKU,12136
-llm_utils/lm/llm_task.py,sha256=K5c27iYM9etAbdDM1WiO3-GjTvl1dkzt2sIaW3N1YA0,15483
-llm_utils/lm/lm.py,sha256=8TaLuU7naPQbOFmiS2NQyWVLG0jUUzRRBQsR0In7GVo,7249
-llm_utils/lm/lm_base.py,sha256=pqbHZOdR7yUMpvwt8uBG1dZnt76SY_Wk8BkXQQ-mpWs,9557
-llm_utils/lm/openai_memoize.py,sha256=DdMl31cV9AqLlkARajZrqAKCyhvH8JQk2SAHMSzO3mk,3024
-llm_utils/lm/utils.py,sha256=a0KJj8vjT2fHKb7GKGNJjJHhKLThwpxIL7vnV9Fr3ZY,4584
 llm_utils/scripts/README.md,sha256=yuOLnLa2od2jp4wVy3rV0rESeiV3o8zol5MNMsZx0DY,999
 llm_utils/scripts/vllm_load_balancer.py,sha256=TT5Ypq7gUcl52gRFp--ORFFjzhfGlcaX2rkRv8NxlxU,37259
 llm_utils/scripts/vllm_serve.py,sha256=gJ0-y4kybMfSt8qzye1pJqGMY3x9JLRi6Tu7RjJMnss,14771
 llm_utils/vector_cache/__init__.py,sha256=i1KQuC4OhPewYpFl9X6HlWFBuASCTx2qgGizhpZhmn0,862
 llm_utils/vector_cache/cli.py,sha256=DMXTj8nZ2_LRjprbYPb4uzq04qZtOfBbmblmaqDcCuM,6251
-llm_utils/vector_cache/core.py,sha256=
+llm_utils/vector_cache/core.py,sha256=222LcmVJR0bFo0jRAJEG6e5ceWFfySmVbCxywScE6E4,33595
 llm_utils/vector_cache/types.py,sha256=ru8qmUZ8_lNd3_oYpjCMtpXTsqmwsSBe56Z4hTWm3xI,435
 llm_utils/vector_cache/utils.py,sha256=dwbbXlRrARrpmS4YqSlYQqrTURg0UWe8XvaAWcX05MM,1458
-speedy_utils/__init__.py,sha256=
+speedy_utils/__init__.py,sha256=QBvGIbrC5yczQwh4T8iu9KQx6w9u-v_JdoQfA67hLUg,5780
 speedy_utils/all.py,sha256=t-HKzDmhF1MTFnmq7xRnPs5nFG_aZaLH9Ua0RM6nQ9Y,4855
 speedy_utils/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 speedy_utils/common/clock.py,sha256=3n4FkCW0dz46O8By09V5Pve1DSMgpLDRbWEVRryryeQ,7423
@@ -34,17 +34,17 @@ speedy_utils/common/logger.py,sha256=a2iZx0eWyfi2-2X_H2QmfuA3tfR7_XSM7Nd0GdUnUOs
 speedy_utils/common/notebook_utils.py,sha256=-97kehJ_Gg3TzDLubsLIYJcykqX1NXhbvBO6nniZSYM,2063
 speedy_utils/common/patcher.py,sha256=VCmdxyTF87qroggQkQklRPhAOPJbeBqhcJoTsLcDxNw,2303
 speedy_utils/common/report_manager.py,sha256=eBiw5KY6bWUhwki3B4lK5o8bFsp7L5x28X9GCI-Sd1w,3899
-speedy_utils/common/utils_cache.py,sha256=
-speedy_utils/common/utils_io.py,sha256
+speedy_utils/common/utils_cache.py,sha256=8KPCWPUCm91HCH9kvV_gcshlxJl6m4tZ8yAKHhJCfUc,22445
+speedy_utils/common/utils_io.py,sha256=-RkQjYGa3zVqpgVInsdp8dbS5oLwdJdUsRz1XIUSJzg,14257
 speedy_utils/common/utils_misc.py,sha256=cdEuBBpiB1xpuzj0UBDHDuTIerqsMIw37ENq6EXliOw,1795
 speedy_utils/common/utils_print.py,sha256=syRrnSFtguxrV-elx6DDVcSGu4Qy7D_xVNZhPwbUY4A,4864
 speedy_utils/multi_worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-speedy_utils/multi_worker/process.py,sha256=
-speedy_utils/multi_worker/thread.py,sha256=
+speedy_utils/multi_worker/process.py,sha256=ouN65PbOhg0rOGUK7ATB7zXkRA993w9iiPDZ7nZ9g0w,6881
+speedy_utils/multi_worker/thread.py,sha256=xhCPgJokCDjjPrWh6vUtCBlZgs3E6mM81WCAEKvZea0,19522
 speedy_utils/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 speedy_utils/scripts/mpython.py,sha256=IvywP7Y0_V6tWfMP-4MjPvN5_KfxWF21xaLJsCIayCk,3821
 speedy_utils/scripts/openapi_client_codegen.py,sha256=f2125S_q0PILgH5dyzoKRz7pIvNEjCkzpi4Q4pPFRZE,9683
-speedy_utils-1.1.
-speedy_utils-1.1.
-speedy_utils-1.1.
-speedy_utils-1.1.
+speedy_utils-1.1.19.dist-info/METADATA,sha256=AHlhLIK3CLwi6f_-_qJDS1lEfXYvvacZ1RHiV_Gfnb4,8094
+speedy_utils-1.1.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+speedy_utils-1.1.19.dist-info/entry_points.txt,sha256=1rrFMfqvaMUE9hvwGiD6vnVh98kmgy0TARBj-v0Lfhs,244
+speedy_utils-1.1.19.dist-info/RECORD,,