speedy-utils 1.1.40__py3-none-any.whl → 1.1.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +2 -0
- llm_utils/llm_ray.py +370 -0
- llm_utils/lm/llm.py +36 -29
- speedy_utils/__init__.py +3 -0
- speedy_utils/common/utils_io.py +3 -1
- speedy_utils/multi_worker/__init__.py +12 -0
- speedy_utils/multi_worker/dataset_ray.py +303 -0
- speedy_utils/multi_worker/parallel_gpu_pool.py +178 -0
- speedy_utils/multi_worker/process.py +375 -75
- speedy_utils/multi_worker/progress.py +140 -0
- speedy_utils/scripts/mpython.py +49 -4
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.42.dist-info}/METADATA +3 -2
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.42.dist-info}/RECORD +15 -11
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.42.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.42.dist-info}/entry_points.txt +0 -0
llm_utils/__init__.py
CHANGED
@@ -12,6 +12,7 @@ from llm_utils.lm import (
 from llm_utils.lm.base_prompt_builder import BasePromptBuilder
 from llm_utils.lm.lm_base import get_model_name
 from llm_utils.lm.openai_memoize import MOpenAI
+from llm_utils.llm_ray import LLMRay
 from llm_utils.vector_cache import VectorCache


@@ -57,6 +58,7 @@ __all__ = [
     "AsyncLM",
     "AsyncLLMTask",
     "LLM",
+    "LLMRay",
    "MOpenAI",
     "get_model_name",
     "VectorCache",
llm_utils/llm_ray.py
ADDED
@@ -0,0 +1,370 @@
+"""
+LLMRay: Simplified Ray-based vLLM wrapper for offline batch inference.
+
+Automatically handles data parallelism across available GPUs in Ray cluster.
+Pipeline parallel is always 1 (no layer splitting).
+
+Example:
+    # dp=4, tp=2 means 8 GPUs total, 4 model replicas each using 2 GPUs
+    llm = LLMRay(model_name='Qwen/Qwen3-0.6B', dp=4, tp=2)
+
+    # dp=8, tp=2 means 16 GPUs across nodes, 8 model replicas
+    llm = LLMRay(model_name='meta-llama/Llama-3-70B', dp=8, tp=2)
+"""
+import os
+import datetime
+import ray
+import numpy as np
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional
+from tqdm.auto import tqdm
+
+# Type alias for OpenAI-style messages
+Message = Dict[str, str]  # {'role': str, 'content': str}
+Messages = List[Message]
+
+
+@ray.remote
+class _ProgressTracker:
+    """Ray actor for tracking global progress across workers."""
+
+    def __init__(self, total_items: int):
+        self.total_items = total_items
+        self.processed_count = 0
+        import time
+        self.start_time = time.time()
+
+    def increment(self) -> None:
+        self.processed_count += 1
+
+    def get_stats(self) -> tuple:
+        import time
+        elapsed = time.time() - self.start_time
+        speed = self.processed_count / elapsed if elapsed > 0 else 0
+        return self.processed_count, self.total_items, speed, elapsed
+
+
+class _VLLMWorkerBase(ABC):
+    """Base worker class for vLLM inference."""
+
+    def __init__(
+        self,
+        worker_id: int,
+        log_dir: Optional[str],
+        tracker: Any,
+        **kwargs: Any,
+    ):
+        self.worker_id = worker_id
+        self.log_dir = log_dir
+        self.tracker = tracker
+        self.kwargs = kwargs
+        self._log_file_handle = None
+        self._last_print_time = 0
+
+    @abstractmethod
+    def setup(self) -> None:
+        pass
+
+    @abstractmethod
+    def process_one_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
+        pass
+
+    def _redirect_output(self) -> None:
+        """Workers > 0 write to disk. Worker 0 writes to stdout."""
+        import sys
+        if self.worker_id == 0 or self.log_dir is None:
+            return
+        log_path = os.path.join(self.log_dir, f'worker_{self.worker_id}.log')
+        self._log_file_handle = open(log_path, 'w', buffering=1)
+        sys.stdout = self._log_file_handle
+        sys.stderr = self._log_file_handle
+
+    def _print_global_stats(self) -> None:
+        """Only used by Worker 0 to print global stats."""
+        import time
+        import datetime as dt
+        if self.tracker is None:
+            return
+        if time.time() - self._last_print_time < 5:
+            return
+        count, total, speed, elapsed = ray.get(self.tracker.get_stats.remote())
+        if speed > 0:
+            eta = (total - count) / speed
+            eta_str = str(dt.timedelta(seconds=int(eta)))
+        else:
+            eta_str = '?'
+        msg = (
+            f'[Global] {count}/{total} | {count/total:.1%} | '
+            f'Speed: {speed:.2f} it/s | ETA: {eta_str}'
+        )
+        print(msg)
+        self._last_print_time = time.time()
+
+    def _run_shard(self, shard: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        self._redirect_output()
+        try:
+            self.setup()
+            results = []
+            iterator = shard
+            if self.worker_id > 0:
+                iterator = tqdm(shard, desc=f'Worker {self.worker_id}')
+            for item in iterator:
+                try:
+                    res = self.process_one_item(item)
+                    results.append(res)
+                except Exception as e:
+                    print(f'Error {item}: {e}')
+                    results.append(None)
+                if self.tracker:
+                    self.tracker.increment.remote()
+                if self.worker_id == 0:
+                    self._print_global_stats()
+            return results
+        finally:
+            if self._log_file_handle:
+                self._log_file_handle.close()
+
+
+class _VLLMWorker(_VLLMWorkerBase):
+    """Worker that runs vLLM inference on assigned GPUs."""
+
+    def setup(self) -> None:
+        """Initialize vLLM engine with configured parameters."""
+        from vllm import LLM
+
+        model_name = self.kwargs['model_name']
+        tp = self.kwargs.get('tp', 1)
+        gpu_memory_utilization = self.kwargs.get(
+            'gpu_memory_utilization', 0.9
+        )
+        trust_remote_code = self.kwargs.get('trust_remote_code', True)
+        vllm_kwargs = self.kwargs.get('vllm_kwargs', {})
+
+        print(
+            f'Worker {self.worker_id}: Loading vLLM model {model_name} '
+            f'with TP={tp}...'
+        )
+
+        self.model = LLM(
+            model=model_name,
+            tensor_parallel_size=tp,
+            pipeline_parallel_size=1,  # Always 1 as per requirement
+            gpu_memory_utilization=gpu_memory_utilization,
+            trust_remote_code=trust_remote_code,
+            enforce_eager=True,
+            **vllm_kwargs,
+        )
+
+        # Store default sampling params
+        self.default_sampling_params = self.kwargs.get(
+            'sampling_params', {}
+        )
+
+    def process_one_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
+        """Process a single input item with OpenAI-style messages."""
+        from vllm import SamplingParams
+
+        messages = item.get('messages')
+        if not messages:
+            raise ValueError('Item must contain "messages" key')
+
+        # Validate messages format
+        for msg in messages:
+            if not isinstance(msg, dict):
+                raise ValueError(
+                    f'Each message must be dict, got {type(msg)}'
+                )
+            if 'role' not in msg or 'content' not in msg:
+                raise ValueError(
+                    'Each message must have "role" and "content"'
+                )
+
+        # Build sampling params (item-specific overrides default)
+        sampling_config = {
+            **self.default_sampling_params,
+            **item.get('sampling_params', {}),
+        }
+        sampling_params = SamplingParams(**sampling_config)
+
+        # Use vLLM chat interface
+        outputs = self.model.chat(
+            messages=[messages],
+            sampling_params=sampling_params,
+        )
+        generated_text = outputs[0].outputs[0].text
+
+        # Build result
+        result = {
+            'messages': messages,
+            'generated_text': generated_text,
+            'worker_id': self.worker_id,
+            'finish_reason': outputs[0].outputs[0].finish_reason,
+        }
+
+        # Include any extra metadata from input
+        for key in item:
+            if key not in ['messages', 'sampling_params']:
+                result[f'meta_{key}'] = item[key]
+
+        return result
+
+
+class LLMRay:
+    """
+    Ray-based LLM wrapper for offline batch inference with OpenAI messages.
+
+    Spawns multiple model replicas (data parallel) across GPUs/nodes.
+    Each replica can use multiple GPUs (tensor parallel).
+
+    Args:
+        model_name: HuggingFace model name or path
+        dp: Data parallel - number of model replicas
+        tp: Tensor parallel - GPUs per replica
+        Total GPUs used = dp * tp
+
+    Example:
+        # 8 GPUs: 4 replicas, each using 2 GPUs
+        >>> llm = LLMRay(model_name='Qwen/Qwen3-0.6B', dp=4, tp=2)
+
+        # 16 GPUs across 2 nodes: 8 replicas, each using 2 GPUs
+        >>> llm = LLMRay(model_name='meta-llama/Llama-3-70B', dp=8, tp=2)
+
+        >>> inputs = [
+        ...     [{'role': 'user', 'content': 'What is AI?'}],
+        ...     [{'role': 'user', 'content': 'Explain quantum computing.'}],
+        ... ]
+        >>> results = llm.generate(inputs)
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        dp: int = 1,
+        tp: int = 1,
+        gpu_memory_utilization: float = 0.9,
+        trust_remote_code: bool = True,
+        sampling_params: Optional[Dict[str, Any]] = None,
+        vllm_kwargs: Optional[Dict[str, Any]] = None,
+        ray_address: Optional[str] = None,
+    ):
+        """
+        Initialize LLMRay.
+
+        Args:
+            model_name: HuggingFace model name or path
+            dp: Data parallel - number of model replicas (workers)
+            tp: Tensor parallel - number of GPUs per replica
+            gpu_memory_utilization: Fraction of GPU memory to use
+            trust_remote_code: Whether to trust remote code from HF
+            sampling_params: Default sampling parameters
+            vllm_kwargs: Additional kwargs to pass to vLLM constructor
+            ray_address: Ray cluster address ('auto' for existing cluster,
+                None for local, or specific address like 'ray://...')
+        """
+        self.model_name = model_name
+        self.dp = dp
+        self.tp = tp
+        self.gpu_memory_utilization = gpu_memory_utilization
+        self.trust_remote_code = trust_remote_code
+        self.sampling_params = sampling_params or {
+            'temperature': 0.7,
+            'max_tokens': 512,
+        }
+        self.vllm_kwargs = vllm_kwargs or {}
+        self.ray_address = ray_address
+
+        # Setup logging
+        timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
+        self.log_base = f'/tmp/raylog/llmray_{timestamp}'
+
+        # Initialize Ray
+        self._init_ray()
+
+    def _init_ray(self) -> None:
+        """Initialize Ray cluster connection."""
+        if not ray.is_initialized():
+            if self.ray_address:
+                ray.init(address=self.ray_address, ignore_reinit_error=True)
+            else:
+                ray.init(ignore_reinit_error=True)
+
+        resources = ray.cluster_resources()
+        total_gpus = int(resources.get('GPU', 0))
+        required_gpus = self.dp * self.tp
+
+        if total_gpus == 0:
+            raise RuntimeError('No GPUs found in Ray cluster!')
+
+        if total_gpus < required_gpus:
+            raise RuntimeError(
+                f'Not enough GPUs: need {required_gpus} (dp={self.dp} x '
+                f'tp={self.tp}), but cluster has {total_gpus}'
+            )
+
+        print(f'>>> Ray cluster connected. Total GPUs: {total_gpus}')
+        print(f'>>> Config: dp={self.dp}, tp={self.tp} → {required_gpus} GPUs')
+        print(f'>>> Logs: {self.log_base}')
+        os.makedirs(self.log_base, exist_ok=True)
+
+    def generate(self, inputs: List[Messages]) -> List[Dict[str, Any]]:
+        """
+        Generate responses for a batch of message lists.
+
+        Args:
+            inputs: List of message lists, where each message list is
+                OpenAI-style: [{'role': 'user', 'content': '...'}]
+
+        Returns:
+            List of result dictionaries with generated text and metadata
+        """
+        # Normalize inputs to dict format with 'messages' key
+        normalized_inputs = []
+        for messages in inputs:
+            if not isinstance(messages, list):
+                raise ValueError(
+                    f'Each input must be list of messages, got {type(messages)}'
+                )
+            normalized_inputs.append({'messages': messages})
+
+        num_workers = self.dp
+        print(f'>>> Spawning {num_workers} workers for {len(inputs)} items.')
+
+        # 1. Start the Global Tracker
+        tracker = _ProgressTracker.remote(len(normalized_inputs))
+
+        # 2. Prepare Shards
+        shards = np.array_split(normalized_inputs, num_workers)
+
+        # 3. Create Remote Worker Class with tp GPUs per worker
+        RemoteWorker = ray.remote(num_gpus=self.tp)(_VLLMWorker)
+
+        actors = []
+        futures = []
+
+        for i, shard in enumerate(shards):
+            if len(shard) == 0:
+                continue
+
+            # Initialize Actor
+            actor = RemoteWorker.remote(
+                worker_id=i,
+                log_dir=self.log_base,
+                tracker=tracker,
+                model_name=self.model_name,
+                tp=self.tp,
+                gpu_memory_utilization=self.gpu_memory_utilization,
+                trust_remote_code=self.trust_remote_code,
+                sampling_params=self.sampling_params,
+                vllm_kwargs=self.vllm_kwargs,
+            )
+            actors.append(actor)
+
+            # Launch Task
+            futures.append(actor._run_shard.remote(shard.tolist()))
+
+        results = ray.get(futures)
+        return [item for sublist in results for item in sublist]
+
+    def __call__(self, inputs: List[Messages]) -> List[Dict[str, Any]]:
+        """Alias for generate()."""
+        return self.generate(inputs)
llm_utils/lm/llm.py
CHANGED
@@ -9,6 +9,7 @@ import subprocess
 from typing import Any, Dict, List, Optional, Type, Union, cast

 import requests
+from httpx import Timeout
 from loguru import logger
 from openai import AuthenticationError, BadRequestError, OpenAI, RateLimitError
 from openai.types.chat import ChatCompletionMessageParam
@@ -66,6 +67,7 @@ class LLM(
         vllm_timeout: int = 1200,
         vllm_reuse: bool = True,
         verbose=False,
+        timeout: float | Timeout | None = None,
         **model_kwargs,
     ):
         """Initialize LLMTask."""
@@ -83,8 +85,10 @@ class LLM(
         self.vllm_timeout = vllm_timeout
         self.vllm_reuse = vllm_reuse
         self.vllm_process: subprocess.Popen | None = None
+        self.timeout = timeout
         self.last_ai_response = None  # Store raw response from client
         self.cache = cache
+        self.api_key = client.api_key if isinstance(client, OpenAI) else 'abc'

         # Handle VLLM server startup if vllm_cmd is provided
         if self.vllm_cmd:
@@ -96,7 +100,11 @@ class LLM(
             client = port

         self.client = get_base_client(
-            client,
+            client,
+            cache=cache,
+            api_key=self.api_key,
+            vllm_cmd=self.vllm_cmd,
+            vllm_process=self.vllm_process,
         )
         # check connection of client
         try:
@@ -165,6 +173,9 @@ class LLM(
         # Extract model name from kwargs for API call
         api_kwargs = {k: v for k, v in effective_kwargs.items() if k != 'model'}

+        if 'timeout' not in api_kwargs and self.timeout is not None:
+            api_kwargs['timeout'] = self.timeout
+
         try:
             completion = self.client.chat.completions.create(
                 model=model_name, messages=messages, **api_kwargs
@@ -220,6 +231,9 @@ class LLM(
         # Extract model name from kwargs for API call
         api_kwargs = {k: v for k, v in effective_kwargs.items() if k != 'model'}

+        if 'timeout' not in api_kwargs and self.timeout is not None:
+            api_kwargs['timeout'] = self.timeout
+
         pydantic_model_to_use_opt = response_model or self.output_model
         if pydantic_model_to_use_opt is None:
             raise ValueError(
@@ -398,6 +412,7 @@ class LLM(
         vllm_cmd: str | None = None,
         vllm_timeout: int = 120,
         vllm_reuse: bool = True,
+        timeout: float | Timeout | None = None,
         **model_kwargs,
     ) -> 'LLM':
         """
@@ -415,6 +430,7 @@ class LLM(
             vllm_cmd: Optional VLLM command to start server automatically
            vllm_timeout: Timeout in seconds to wait for VLLM server (default 120)
             vllm_reuse: If True (default), reuse existing server on target port
+            timeout: Optional OpenAI client timeout in seconds
             **model_kwargs: Additional model parameters
         """
         instruction = cls.get_instruction()
@@ -433,11 +449,9 @@ class LLM(
             vllm_cmd=vllm_cmd,
             vllm_timeout=vllm_timeout,
             vllm_reuse=vllm_reuse,
+            timeout=timeout,
             **model_kwargs,
         )
-from typing import Any, Dict, List, Optional, Type, Union
-from pydantic import BaseModel
-from .llm import LLM, Messages

 class LLM_NEMOTRON3(LLM):
     """
@@ -447,15 +461,15 @@ class LLM_NEMOTRON3(LLM):

     def __init__(
         self,
-        model: str =
+        model: str = 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16',
         thinking_budget: int = 1024,
         enable_thinking: bool = True,
-        **kwargs
+        **kwargs,
     ):
         # Force reasoning_model to True to enable reasoning_content extraction
         kwargs['is_reasoning_model'] = True
         super().__init__(**kwargs)
-
+
         self.model_kwargs['model'] = model
         self.thinking_budget = thinking_budget
         self.enable_thinking = enable_thinking
@@ -469,55 +483,48 @@ class LLM_NEMOTRON3(LLM):
         self,
         input_data: str | BaseModel | list[dict],
         thinking_budget: Optional[int] = None,
-        **kwargs
+        **kwargs,
     ) -> List[Dict[str, Any]]:
         budget = thinking_budget or self.thinking_budget
-
+
         if not self.enable_thinking:
             # Simple pass with thinking disabled in template
             return super().__call__(
-                input_data,
-                chat_template_kwargs={"enable_thinking": False},
-                **kwargs
+                input_data, chat_template_kwargs={'enable_thinking': False}, **kwargs
             )

         # --- STEP 1: Generate Thinking Trace ---
         # We manually append <think> to force the reasoning MoE layers
         messages = self._prepare_input(input_data)
-
+
         # We use the raw text completion for the budget phase
         # Stop at the closing tag or budget limit
         thinking_response = self.text_completion(
-            input_data,
-            max_tokens=budget,
-            stop=["</think>"],
-            **kwargs
+            input_data, max_tokens=budget, stop=['</think>'], **kwargs
         )[0]

         reasoning_content = thinking_response['parsed']
-
+
         # Ensure proper tag closing for the second pass
-        if
-            reasoning_content = f
-        elif not reasoning_content.endswith(
+        if '</think>' not in reasoning_content:
+            reasoning_content = f'{reasoning_content}\n</think>'
+        elif not reasoning_content.endswith('</think>'):
             # Ensure it ends exactly with the tag for continuity
-            reasoning_content = reasoning_content.split(
+            reasoning_content = reasoning_content.split('</think>')[0] + '</think>'

         # --- STEP 2: Generate Final Answer ---
         # Append the thought to the assistant role and continue
         final_messages = messages + [
-            {
+            {'role': 'assistant', 'content': f'<think>\n{reasoning_content}\n'}
         ]
-
+
         # Use continue_final_message to prevent the model from repeating the header
         results = super().__call__(
-            final_messages,
-            extra_body={"continue_final_message": True},
-            **kwargs
+            final_messages, extra_body={'continue_final_message': True}, **kwargs
         )

         # Inject the reasoning back into the result for the UI/API
         for res in results:
             res['reasoning_content'] = reasoning_content
-
-        return results
+
+        return results
speedy_utils/__init__.py
CHANGED
@@ -54,6 +54,7 @@ from .common.utils_print import (
 # Multi-worker processing
 from .multi_worker.process import multi_process
 from .multi_worker.thread import kill_all_thread, multi_thread
+from .multi_worker.dataset_ray import multi_process_dataset_ray, WorkerResources


 __all__ = [
@@ -152,6 +153,8 @@ __all__ = [
     'multi_process',
     'multi_thread',
     'kill_all_thread',
+    'multi_process_dataset_ray',
+    'WorkerResources',
     # Notebook utilities
     'change_dir',
 ]
speedy_utils/common/utils_io.py
CHANGED
@@ -87,10 +87,12 @@ def load_json_or_pickle(fname: str, counter=0) -> Any:
     # EOFError: Ran out of input
     except EOFError:
         time.sleep(1)
+
         if counter > 5:
             # Keep message concise and actionable
             print(
-                f
+                f"[load_json_or_pickle] EOFError reading cache file='{fname}' (attempt={counter}). "
+                f"Assuming partial write/corruption; deleted file and will regenerate on next access."
             )
             os.remove(fname)
             raise
speedy_utils/multi_worker/__init__.py
ADDED
@@ -0,0 +1,12 @@
+from .process import multi_process, cleanup_phantom_workers, create_progress_tracker
+from .thread import multi_thread
+from .dataset_ray import multi_process_dataset_ray, WorkerResources
+
+__all__ = [
+    'multi_process',
+    'multi_thread',
+    'cleanup_phantom_workers',
+    'create_progress_tracker',
+    'multi_process_dataset_ray',
+    'WorkerResources',
+]