speedy-utils 1.1.46__py3-none-any.whl → 1.1.47__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
llm_utils/lm/llm.py CHANGED
@@ -4,11 +4,9 @@
  Simplified LLM Task module for handling language model interactions with structured input/output.
  """

- import os
  import subprocess
- from typing import Any, Dict, List, Optional, Type, Union, cast
+ from typing import Any, Dict, List, Optional, cast

- import requests
  from httpx import Timeout
  from loguru import logger
  from openai import AuthenticationError, BadRequestError, OpenAI, RateLimitError
@@ -28,16 +26,7 @@ from .mixins import (
  )
  from .utils import (
      _extract_port_from_vllm_cmd,
-     _get_port_from_client,
-     _is_lora_path,
-     _is_server_running,
-     _kill_vllm_on_port,
-     _load_lora_adapter,
-     _start_vllm_server,
-     _unload_lora_adapter,
      get_base_client,
-     kill_all_vllm_processes,
-     stop_vllm_process,
  )
@@ -216,6 +205,46 @@ class LLM(
          results.append(result_dict)
          return results

+     @staticmethod
+     def _strip_think_tags(text: str) -> str:
+         """Remove <think> tags if present, returning only the reasoning body."""
+         cleaned = text.strip()
+         if cleaned.startswith('<think>'):
+             cleaned = cleaned[len('<think>') :].lstrip()
+         if '</think>' in cleaned:
+             cleaned = cleaned.split('</think>', 1)[0].rstrip()
+         return cleaned
+
+     def generate_with_think_prefix(
+         self, input_data: str | BaseModel | list[dict], **runtime_kwargs
+     ) -> list[dict[str, Any]]:
+         """
+         Generate text and format output as:
+         <think>reasoning</think>
+         """
+         results = self.text_completion(input_data, **runtime_kwargs)
+
+         for result in results:
+             content = result.get('parsed') or ''
+             reasoning = result.get('reasoning_content') or ''
+
+             if not reasoning and str(content).lstrip().startswith('<think>'):
+                 formatted = str(content)
+             else:
+                 reasoning_body = self._strip_think_tags(str(reasoning))
+                 formatted = (
+                     f'<think>\n{reasoning_body}\n</think>\n\n{str(content).lstrip()}'
+                 )
+
+             result['parsed'] = formatted
+             messages = result.get('messages')
+             if isinstance(messages, list) and messages:
+                 last_msg = messages[-1]
+                 if isinstance(last_msg, dict) and last_msg.get('role') == 'assistant':
+                     last_msg['content'] = formatted
+
+         return results
+
      @clean_traceback
      def pydantic_parse(
          self,
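
For orientation, a minimal usage sketch of the new method. The model name and constructor arguments are assumptions, since the LLM constructor is not part of this diff:

    from llm_utils.lm.llm import LLM

    lm = LLM(model='qwen3-8b')  # hypothetical model name / constructor args
    for result in lm.generate_with_think_prefix('Why is the sky blue?'):
        # 'parsed' now always carries '<think>...</think>' followed by the
        # answer, whether the reasoning arrived in 'reasoning_content' or
        # inline in the completion text.
        print(result['parsed'])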
speedy_utils/__init__.py CHANGED
@@ -65,6 +65,8 @@ from .common.utils_error import clean_traceback, handle_exceptions_with_clean_tr
  from .multi_worker.process import multi_process
  from .multi_worker.thread import kill_all_thread, multi_thread
  from .multi_worker.dataset_ray import multi_process_dataset_ray, WorkerResources
+ from .multi_worker.dataset_sharding import multi_process_dataset
+ from .multi_worker.progress import report_progress


  __all__ = [
@@ -167,7 +169,9 @@ __all__ = [
      'multi_thread',
      'kill_all_thread',
      'multi_process_dataset_ray',
+     'multi_process_dataset',
      'WorkerResources',
+     'report_progress',
      # Notebook utilities
      'change_dir',
  ]
speedy_utils/multi_worker/__init__.py CHANGED
@@ -1,6 +1,8 @@
  from .process import multi_process, cleanup_phantom_workers, create_progress_tracker
  from .thread import multi_thread
  from .dataset_ray import multi_process_dataset_ray, WorkerResources
+ from .dataset_sharding import multi_process_dataset
+ from .progress import report_progress

  __all__ = [
      'multi_process',
@@ -9,4 +11,6 @@ __all__ = [
      'create_progress_tracker',
      'multi_process_dataset_ray',
      'WorkerResources',
+     'multi_process_dataset',
+     'report_progress',
  ]
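
After this change the two new helpers are importable from both the package root and the subpackage. A sketch of the import surface only, since their call signatures are not shown in this diff:

    # Both paths resolve to the same objects after this release.
    from speedy_utils import multi_process_dataset, report_progress
    from speedy_utils.multi_worker import multi_process_dataset, report_progress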
@@ -0,0 +1,425 @@
+ """
+ Multi-process implementation with sequential and threadpool backends.
+
+ Provides the public `multi_process` dispatcher and implementations for:
+ - 'seq': Sequential execution
+ - 'mp': ThreadPoolExecutor-based parallelism
+ - 'safe': Same as 'mp' but without process tracking (for tests)
+ """
+ from __future__ import annotations
+
+ import concurrent.futures
+ import inspect
+ import os
+ import sys
+ import warnings
+ from typing import Any, Callable, Iterable, Literal
+
+ from tqdm import tqdm
+
+ from .common import (
+     ErrorHandlerType,
+     ErrorStats,
+     _build_cache_dir,
+     _call_with_log_control,
+     _cleanup_log_gate,
+     _exit_on_worker_error,
+     _prune_dead_processes,
+     _track_multiprocessing_processes,
+     create_log_gate_path,
+     wrap_dump,
+ )
+
+ # Import thread tracking functions if available
+ try:
+     from .thread import _prune_dead_threads, _track_executor_threads
+ except ImportError:
+     _prune_dead_threads = None  # type: ignore[assignment]
+     _track_executor_threads = None  # type: ignore[assignment]
+
+
+ def multi_process(
+     func: Callable[[Any], Any],
+     items: Iterable[Any] | None = None,
+     *,
+     inputs: Iterable[Any] | None = None,
+     workers: int | None = None,
+     lazy_output: bool = False,
+     progress: bool = True,
+     backend: Literal['seq', 'ray', 'mp', 'safe'] = 'mp',
+     desc: str | None = None,
+     shared_kwargs: list[str] | None = None,
+     dump_in_thread: bool = True,
+     ray_metrics_port: int | None = None,
+     log_worker: Literal['zero', 'first', 'all'] = 'first',
+     total_items: int | None = None,
+     poll_interval: float = 0.3,
+     error_handler: ErrorHandlerType = 'log',
+     max_error_files: int = 100,
+     process_update_interval: int | None = None,
+     batch: int | None = None,
+     ordered: bool = True,
+     stop_on_error: bool = True,
+     **func_kwargs: Any,
+ ) -> list[Any]:
+     """
+     Multi-process map with selectable backend.
+
+     backend:
+         - "seq": run sequentially
+         - "ray": run in parallel with Ray
+         - "mp": run in parallel with thread pool (uses ThreadPoolExecutor)
+         - "safe": run in parallel with thread pool (explicitly safe for tests)
+
+     shared_kwargs:
+         - Optional list of kwarg names that should be shared via Ray's
+           zero-copy object store
+         - Only works with Ray backend
+         - Useful for large objects (e.g., models, datasets)
+         - Example: shared_kwargs=['model', 'tokenizer']
+
+     dump_in_thread:
+         - Whether to dump results to disk in a separate thread (default: True)
+         - If False, dumping is done synchronously
+
+     ray_metrics_port:
+         - Optional port for Ray metrics export (Ray backend only)
+
+     log_worker:
+         - Control worker stdout/stderr noise
+         - 'first': only first worker emits logs (default)
+         - 'all': allow worker prints
+         - 'zero': silence all worker output
+
+     total_items:
+         - Optional item-level total for progress tracking (Ray backend only)
+
+     poll_interval:
+         - Poll interval in seconds for progress actor updates (Ray only)
+
+     error_handler:
+         - 'raise': raise exception on first error
+         - 'ignore': continue processing, return None for failed items
+         - 'log': same as ignore, but logs errors to files (default)
+         - Note: for 'mp' and 'ray' backends, 'raise' prints a formatted
+           traceback and exits the process.
+
+     max_error_files:
+         - Maximum number of error log files to write (default: 100)
+         - Error logs are written to .cache/speedy_utils/error_logs/{idx}.log
+         - First error is always printed to screen with the log file path
+
+     process_update_interval:
+         - Legacy parameter, accepted for backward compatibility but not implemented
+
+     batch:
+         - Legacy parameter, accepted for backward compatibility but not implemented
+
+     ordered:
+         - Whether to maintain order of results (default: True)
+         - Legacy parameter, accepted for backward compatibility but not implemented
+
+     stop_on_error:
+         - Whether to stop on first error (default: True)
+         - Legacy parameter, accepted for backward compatibility
+         - Use error_handler parameter instead for error handling control
+
+     If lazy_output=True, every result is saved to .pkl and
+     the returned list contains file paths.
+     """
+
+     # default backend selection
+     if backend is None:
+         try:
+             import ray as _ray_module
+             backend = 'ray'
+         except ImportError:
+             backend = 'mp'
+
+     # Validate shared_kwargs
+     if shared_kwargs:
+         sig = inspect.signature(func)
+         valid_params = set(sig.parameters.keys())
+
+         for kw in shared_kwargs:
+             if kw not in func_kwargs:
+                 raise ValueError(
+                     f"shared_kwargs key '{kw}' not found in "
+                     f"provided func_kwargs"
+                 )
+             has_var_keyword = any(
+                 p.kind == inspect.Parameter.VAR_KEYWORD
+                 for p in sig.parameters.values()
+             )
+             if kw not in valid_params and not has_var_keyword:
+                 raise ValueError(
+                     f"shared_kwargs key '{kw}' is not a valid parameter "
+                     f"for function '{func.__name__}'. "
+                     f"Valid parameters: {valid_params}"
+                 )
+
+     # Prefer Ray backend when shared kwargs are requested
+     if shared_kwargs and backend != 'ray':
+         warnings.warn(
+             "shared_kwargs only supported with 'ray' backend, "
+             "switching backend to 'ray'",
+             UserWarning,
+         )
+         backend = 'ray'
+
+     # unify items and coerce to concrete list
+     if items is None and inputs is not None:
+         items = list(inputs)
+     if items is not None and not isinstance(items, list):
+         items = list(items)
+     if items is None:
+         raise ValueError("'items' or 'inputs' must be provided")
+
+     if workers is None and backend != 'ray':
+         workers = os.cpu_count() or 1
+
+     # build cache dir + wrap func
+     cache_dir = _build_cache_dir(func, items) if lazy_output else None
+     f_wrapped = wrap_dump(func, cache_dir, dump_in_thread)
+
+     log_gate_path = create_log_gate_path(log_worker)
+
+     total = len(items)
+     if desc:
+         desc = desc.strip() + f'[{backend}]'
+     else:
+         desc = f'Multi-process [{backend}]'
+
+     # Initialize error stats for error handling
+     func_name = getattr(func, '__name__', repr(func))
+     error_stats = ErrorStats(
+         func_name=func_name,
+         max_error_files=max_error_files,
+         write_logs=error_handler == 'log'
+     )
+
+     def _update_pbar_postfix(pbar: tqdm) -> None:
+         """Update pbar with success/error counts."""
+         postfix = error_stats.get_postfix_dict()
+         pbar.set_postfix(postfix)
+
+     # ---- sequential backend ----
+     if backend == 'seq':
+         return _run_seq_backend(
+             f_wrapped=f_wrapped,
+             items=items,
+             total=total,
+             desc=desc,
+             progress=progress,
+             func_kwargs=func_kwargs,
+             log_worker=log_worker,
+             log_gate_path=log_gate_path,
+             error_handler=error_handler,
+             error_stats=error_stats,
+             func_name=func_name,
+             update_pbar_postfix=_update_pbar_postfix,
+         )
+
+     # ---- ray backend ----
+     if backend == 'ray':
+         from ._multi_process_ray import run_ray_backend
+         return run_ray_backend(
+             f_wrapped=f_wrapped,
+             items=items,
+             total=total,
+             workers=workers,
+             progress=progress,
+             desc=desc,
+             func_kwargs=func_kwargs,
+             shared_kwargs=shared_kwargs,
+             log_worker=log_worker,
+             log_gate_path=log_gate_path,
+             total_items=total_items,
+             poll_interval=poll_interval,
+             ray_metrics_port=ray_metrics_port,
+             error_handler=error_handler,
+             error_stats=error_stats,
+             func_name=func_name,
+         )
+
+     # ---- threadpool backends (mp / safe) ----
+     if backend == 'mp':
+         return _run_threadpool_backend(
+             backend_label='mp',
+             track_processes=True,
+             f_wrapped=f_wrapped,
+             items=items,
+             total=total,
+             workers=workers,
+             desc=desc,
+             progress=progress,
+             func_kwargs=func_kwargs,
+             log_worker=log_worker,
+             log_gate_path=log_gate_path,
+             error_handler=error_handler,
+             error_stats=error_stats,
+             func_name=func_name,
+             update_pbar_postfix=_update_pbar_postfix,
+         )
+
+     if backend == 'safe':
+         return _run_threadpool_backend(
+             backend_label='safe',
+             track_processes=False,
+             f_wrapped=f_wrapped,
+             items=items,
+             total=total,
+             workers=workers,
+             desc=desc,
+             progress=progress,
+             func_kwargs=func_kwargs,
+             log_worker=log_worker,
+             log_gate_path=log_gate_path,
+             error_handler=error_handler,
+             error_stats=error_stats,
+             func_name=func_name,
+             update_pbar_postfix=_update_pbar_postfix,
+         )
+
+     raise ValueError(f'Unsupported backend: {backend!r}')
+
+
+ def _run_seq_backend(
+     *,
+     f_wrapped,
+     items: list,
+     total: int,
+     desc: str,
+     progress: bool,
+     func_kwargs: dict,
+     log_worker,
+     log_gate_path,
+     error_handler,
+     error_stats: ErrorStats,
+     func_name: str,
+     update_pbar_postfix,
+ ) -> list[Any]:
+     """Run sequential (single-threaded) backend."""
+     results: list[Any] = []
+     with tqdm(
+         total=total,
+         desc=desc,
+         disable=not progress,
+         file=sys.stdout,
+     ) as pbar:
+         for idx, x in enumerate(items):
+             try:
+                 result = _call_with_log_control(
+                     f_wrapped,
+                     x,
+                     func_kwargs,
+                     log_worker,
+                     log_gate_path,
+                 )
+                 error_stats.record_success()
+                 results.append(result)
+             except Exception as e:
+                 if error_handler == 'raise':
+                     raise
+                 error_stats.record_error(idx, e, x, func_name)
+                 results.append(None)
+             pbar.update(1)
+             update_pbar_postfix(pbar)
+     _cleanup_log_gate(log_gate_path)
+     return results
+
+
+ def _run_threadpool_backend(
+     *,
+     backend_label: str,
+     track_processes: bool,
+     f_wrapped,
+     items: list,
+     total: int,
+     workers: int | None,
+     desc: str,
+     progress: bool,
+     func_kwargs: dict,
+     log_worker,
+     log_gate_path,
+     error_handler,
+     error_stats: ErrorStats,
+     func_name: str,
+     update_pbar_postfix,
+ ) -> list[Any]:
+     """Run ThreadPoolExecutor backend for 'mp' and 'safe' modes."""
+     # Capture caller frame for better error reporting
+     caller_frame = inspect.currentframe()
+     caller_info = None
+     if caller_frame and caller_frame.f_back and caller_frame.f_back.f_back:
+         # Go back two frames: _run_threadpool_backend -> multi_process -> user
+         outer = caller_frame.f_back.f_back
+         caller_info = {
+             'filename': outer.f_code.co_filename,
+             'lineno': outer.f_lineno,
+             'function': outer.f_code.co_name,
+         }
+
+     def worker_func(x):
+         return _call_with_log_control(
+             f_wrapped,
+             x,
+             func_kwargs,
+             log_worker,
+             log_gate_path,
+         )
+
+     results: list[Any] = [None] * total
+     with tqdm(
+         total=total,
+         desc=desc,
+         disable=not progress,
+         file=sys.stdout,
+     ) as pbar:
+         with concurrent.futures.ThreadPoolExecutor(
+             max_workers=workers
+         ) as executor:
+             if _track_executor_threads is not None:
+                 _track_executor_threads(executor)
+
+             # Submit all tasks
+             future_to_idx = {
+                 executor.submit(worker_func, x): idx
+                 for idx, x in enumerate(items)
+             }
+
+             # Process results as they complete
+             for future in concurrent.futures.as_completed(future_to_idx):
+                 idx = future_to_idx[future]
+                 try:
+                     result = future.result()
+                     error_stats.record_success()
+                     results[idx] = result
+                 except Exception as e:
+                     if error_handler == 'raise':
+                         # Cancel remaining futures
+                         for f in future_to_idx:
+                             f.cancel()
+                         _exit_on_worker_error(
+                             e,
+                             pbar,
+                             caller_info,
+                             backend=backend_label,
+                         )
+                     error_stats.record_error(idx, e, items[idx], func_name)
+                     results[idx] = None
+                 pbar.update(1)
+                 update_pbar_postfix(pbar)
+
+     if _prune_dead_threads is not None:
+         _prune_dead_threads()
+
+     if track_processes:
+         _track_multiprocessing_processes()
+         _prune_dead_processes()
+
+     _cleanup_log_gate(log_gate_path)
+     return results
+
+
+ __all__ = ['multi_process']
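
A short usage sketch of the dispatcher defined above (the worker function and inputs are illustrative):

    from speedy_utils import multi_process

    def square(x: int) -> int:
        return x * x

    # Thread-pool backend (the default): results are placed by submission
    # index, so order is preserved, with a tqdm progress bar on stdout.
    results = multi_process(square, items=range(8), workers=4, desc='squares')
    assert results == [x * x for x in range(8)]

    # Sequential backend that re-raises the first worker error instead of
    # logging it and returning None for the failed item.
    multi_process(square, items=[1, 2, 3], backend='seq', error_handler='raise')

Note on shared_kwargs: per the validation above, each listed name must appear both in the keyword arguments passed to multi_process and in the worker's signature; the Ray backend then places those values in the object store once and shares them zero-copy across workers.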