speedy-utils 1.1.46__py3-none-any.whl → 1.1.48__py3-none-any.whl

llm_utils/lm/llm.py CHANGED
@@ -4,19 +4,17 @@
  Simplified LLM Task module for handling language model interactions with structured input/output.
  """

- import os
  import subprocess
- from typing import Any, Dict, List, Optional, Type, Union, cast
+ from typing import Any, Dict, List, Optional, cast

- import requests
  from httpx import Timeout
  from loguru import logger
- from openai import AuthenticationError, BadRequestError, OpenAI, RateLimitError
+ from openai import AuthenticationError, BadRequestError, OpenAI, RateLimitError, APITimeoutError
  from openai.types.chat import ChatCompletionMessageParam
  from pydantic import BaseModel

- from speedy_utils.common.utils_io import jdumps
  from speedy_utils import clean_traceback
+ from speedy_utils.common.utils_io import jdumps

  from .base_prompt_builder import BasePromptBuilder
  from .mixins import (
@@ -28,16 +26,7 @@ from .mixins import (
  )
  from .utils import (
      _extract_port_from_vllm_cmd,
-     _get_port_from_client,
-     _is_lora_path,
-     _is_server_running,
-     _kill_vllm_on_port,
-     _load_lora_adapter,
-     _start_vllm_server,
-     _unload_lora_adapter,
      get_base_client,
-     kill_all_vllm_processes,
-     stop_vllm_process,
  )


@@ -184,38 +173,89 @@ class LLM(
              )
              # Store raw response from client
              self.last_ai_response = completion
+         except APITimeoutError as exc:
+             error_msg = f'OpenAI API timeout ({api_kwargs["timeout"]}) error: {exc} for model {model_name}'
+             logger.error(error_msg)
+             raise
          except (AuthenticationError, RateLimitError, BadRequestError) as exc:
              error_msg = f'OpenAI API error ({type(exc).__name__}): {exc}'
              logger.error(error_msg)
              raise
+         except ValueError as exc:
+             logger.error(f'ValueError during API call: {exc}')
+             raise
          except Exception as e:
              is_length_error = 'Length' in str(e) or 'maximum context length' in str(e)
              if is_length_error:
                  raise ValueError(
                      f'Input too long for model {model_name}. Error: {str(e)[:100]}...'
                  ) from e
-             # Re-raise all other exceptions
              raise
          # print(completion)

          results: list[dict[str, Any]] = []
          for choice in completion.choices:
+             assistant_message = [{'role': 'assistant', 'content': choice.message.content}]
+             try:
+                 reasoning_content = choice.message.reasoning
+             except AttributeError:
+                 reasoning_content = None
+             if reasoning_content:
+                 assistant_message[0]['reasoning_content'] = reasoning_content
+
              choice_messages = cast(
                  Messages,
-                 messages + [{'role': 'assistant', 'content': choice.message.content}],
+                 messages + assistant_message,
              )
              result_dict = {
                  'parsed': choice.message.content,
                  'messages': choice_messages,
              }

-             # Add reasoning content if this is a reasoning model
-             if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
-                 result_dict['reasoning_content'] = choice.message.reasoning_content

              results.append(result_dict)
          return results

+     @staticmethod
+     def _strip_think_tags(text: str) -> str:
+         """Remove <think> tags if present, returning only the reasoning body."""
+         cleaned = text.strip()
+         if cleaned.startswith('<think>'):
+             cleaned = cleaned[len('<think>') :].lstrip()
+         if '</think>' in cleaned:
+             cleaned = cleaned.split('</think>', 1)[0].rstrip()
+         return cleaned
+
+     def generate_with_think_prefix(
+         self, input_data: str | BaseModel | list[dict], **runtime_kwargs
+     ) -> list[dict[str, Any]]:
+         """
+         Generate text and format output as:
+         <think>reasoning</think>
+         """
+         results = self.text_completion(input_data, **runtime_kwargs)
+
+         for result in results:
+             content = result.get('parsed') or ''
+             reasoning = result.get('reasoning_content') or ''
+
+             if not reasoning and str(content).lstrip().startswith('<think>'):
+                 formatted = str(content)
+             else:
+                 reasoning_body = self._strip_think_tags(str(reasoning))
+                 formatted = (
+                     f'<think>\n{reasoning_body}\n</think>\n\n{str(content).lstrip()}'
+                 )
+
+             result['parsed'] = formatted
+             messages = result.get('messages')
+             if isinstance(messages, list) and messages:
+                 last_msg = messages[-1]
+                 if isinstance(last_msg, dict) and last_msg.get('role') == 'assistant':
+                     last_msg['content'] = formatted
+
+         return results
+
      @clean_traceback
      def pydantic_parse(
          self,
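
The new generate_with_think_prefix helper normalizes reasoning output into a <think>…</think> prefix followed by the final answer. A minimal standalone sketch of that formatting (mirroring _strip_think_tags above; plain stdlib, nothing imported from the package):

    def strip_think_tags(text: str) -> str:
        # Same logic as LLM._strip_think_tags: keep only the reasoning body.
        cleaned = text.strip()
        if cleaned.startswith('<think>'):
            cleaned = cleaned[len('<think>'):].lstrip()
        if '</think>' in cleaned:
            cleaned = cleaned.split('</think>', 1)[0].rstrip()
        return cleaned

    reasoning = '<think>compare 9.8 and 9.11 digit by digit</think>'
    content = '9.8 is larger.'
    body = strip_think_tags(reasoning)
    print(f'<think>\n{body}\n</think>\n\n{content}')
    # <think>
    # compare 9.8 and 9.11 digit by digit
    # </think>
    #
    # 9.8 is larger.
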
@@ -365,12 +405,12 @@
      ) -> list[dict[str, Any]]:
          """Inspect the message history of a specific response choice."""
          if hasattr(self, '_last_conversations'):
-             from llm_utils import show_chat_v2
+             from llm_utils import show_chat

              conv = self._last_conversations[idx]
              if k_last_messages > 0:
                  conv = conv[-k_last_messages:]
-             return show_chat_v2(conv)
+             return show_chat(conv)
          raise ValueError('No message history available. Make a call first.')

      def __inner_call__(
@@ -413,7 +453,7 @@
          is_reasoning_model: bool = False,
          lora_path: str | None = None,
          vllm_cmd: str | None = None,
-         vllm_timeout: int = 120,
+         vllm_timeout: float = 0.1,
          vllm_reuse: bool = True,
          timeout: float | Timeout | None = None,
          **model_kwargs,
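
With APITimeoutError imported and handled first, client-side timeouts are logged and re-raised instead of falling into the generic Exception branch. A sketch of how that exception surfaces with the openai client (base URL, key, and model name are placeholders):

    from openai import APITimeoutError, OpenAI

    # A deliberately tiny timeout forces APITimeoutError on any real request.
    client = OpenAI(base_url='http://localhost:8000/v1', api_key='EMPTY', timeout=0.1)
    try:
        client.chat.completions.create(
            model='placeholder-model',
            messages=[{'role': 'user', 'content': 'hello'}],
        )
    except APITimeoutError as exc:
        print(f'request timed out: {exc}')
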
speedy_utils/__init__.py CHANGED
@@ -65,6 +65,8 @@ from .common.utils_error import clean_traceback, handle_exceptions_with_clean_tr
  from .multi_worker.process import multi_process
  from .multi_worker.thread import kill_all_thread, multi_thread
  from .multi_worker.dataset_ray import multi_process_dataset_ray, WorkerResources
+ from .multi_worker.dataset_sharding import multi_process_dataset
+ from .multi_worker.progress import report_progress


  __all__ = [
@@ -167,7 +169,9 @@ __all__ = [
      'multi_thread',
      'kill_all_thread',
      'multi_process_dataset_ray',
+     'multi_process_dataset',
      'WorkerResources',
+     'report_progress',
      # Notebook utilities
      'change_dir',
  ]
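
With these re-exports in place, both helpers resolve from the package root (the diff shows only the exports; their signatures live in the new modules):

    # New in 1.1.48: re-exported at the package root.
    from speedy_utils import multi_process_dataset, report_progress
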
speedy_utils/multi_worker/__init__.py CHANGED
@@ -1,6 +1,8 @@
  from .process import multi_process, cleanup_phantom_workers, create_progress_tracker
  from .thread import multi_thread
  from .dataset_ray import multi_process_dataset_ray, WorkerResources
+ from .dataset_sharding import multi_process_dataset
+ from .progress import report_progress

  __all__ = [
      'multi_process',
@@ -9,4 +11,6 @@ __all__ = [
      'create_progress_tracker',
      'multi_process_dataset_ray',
      'WorkerResources',
+     'multi_process_dataset',
+     'report_progress',
  ]
@@ -0,0 +1,425 @@
+ """
+ Multi-process implementation with sequential and threadpool backends.
+
+ Provides the public `multi_process` dispatcher and implementations for:
+ - 'seq': Sequential execution
+ - 'mp': ThreadPoolExecutor-based parallelism
+ - 'safe': Same as 'mp' but without process tracking (for tests)
+ """
+ from __future__ import annotations
+
+ import concurrent.futures
+ import inspect
+ import os
+ import sys
+ import warnings
+ from typing import Any, Callable, Iterable, Literal
+
+ from tqdm import tqdm
+
+ from .common import (
+     ErrorHandlerType,
+     ErrorStats,
+     _build_cache_dir,
+     _call_with_log_control,
+     _cleanup_log_gate,
+     _exit_on_worker_error,
+     _prune_dead_processes,
+     _track_multiprocessing_processes,
+     create_log_gate_path,
+     wrap_dump,
+ )
+
+ # Import thread tracking functions if available
+ try:
+     from .thread import _prune_dead_threads, _track_executor_threads
+ except ImportError:
+     _prune_dead_threads = None  # type: ignore[assignment]
+     _track_executor_threads = None  # type: ignore[assignment]
+
+
+ def multi_process(
+     func: Callable[[Any], Any],
+     items: Iterable[Any] | None = None,
+     *,
+     inputs: Iterable[Any] | None = None,
+     workers: int | None = None,
+     lazy_output: bool = False,
+     progress: bool = True,
+     backend: Literal['seq', 'ray', 'mp', 'safe'] = 'mp',
+     desc: str | None = None,
+     shared_kwargs: list[str] | None = None,
+     dump_in_thread: bool = True,
+     ray_metrics_port: int | None = None,
+     log_worker: Literal['zero', 'first', 'all'] = 'first',
+     total_items: int | None = None,
+     poll_interval: float = 0.3,
+     error_handler: ErrorHandlerType = 'log',
+     max_error_files: int = 100,
+     process_update_interval: int | None = None,
+     batch: int | None = None,
+     ordered: bool = True,
+     stop_on_error: bool = True,
+     **func_kwargs: Any,
+ ) -> list[Any]:
+     """
+     Multi-process map with selectable backend.
+
+     backend:
+         - "seq": run sequentially
+         - "ray": run in parallel with Ray
+         - "mp": run in parallel with thread pool (uses ThreadPoolExecutor)
+         - "safe": run in parallel with thread pool (explicitly safe for tests)
+
+     shared_kwargs:
+         - Optional list of kwarg names that should be shared via Ray's
+           zero-copy object store
+         - Only works with Ray backend
+         - Useful for large objects (e.g., models, datasets)
+         - Example: shared_kwargs=['model', 'tokenizer']
+
+     dump_in_thread:
+         - Whether to dump results to disk in a separate thread (default: True)
+         - If False, dumping is done synchronously
+
+     ray_metrics_port:
+         - Optional port for Ray metrics export (Ray backend only)
+
+     log_worker:
+         - Control worker stdout/stderr noise
+         - 'first': only first worker emits logs (default)
+         - 'all': allow worker prints
+         - 'zero': silence all worker output
+
+     total_items:
+         - Optional item-level total for progress tracking (Ray backend only)
+
+     poll_interval:
+         - Poll interval in seconds for progress actor updates (Ray only)
+
+     error_handler:
+         - 'raise': raise exception on first error
+         - 'ignore': continue processing, return None for failed items
+         - 'log': same as ignore, but logs errors to files (default)
+         - Note: for 'mp' and 'ray' backends, 'raise' prints a formatted
+           traceback and exits the process.
+
+     max_error_files:
+         - Maximum number of error log files to write (default: 100)
+         - Error logs are written to .cache/speedy_utils/error_logs/{idx}.log
+         - First error is always printed to screen with the log file path
+
+     process_update_interval:
+         - Legacy parameter, accepted for backward compatibility but not implemented
+
+     batch:
+         - Legacy parameter, accepted for backward compatibility but not implemented
+
+     ordered:
+         - Whether to maintain order of results (default: True)
+         - Legacy parameter, accepted for backward compatibility but not implemented
+
+     stop_on_error:
+         - Whether to stop on first error (default: True)
+         - Legacy parameter, accepted for backward compatibility
+         - Use error_handler parameter instead for error handling control
+
+     If lazy_output=True, every result is saved to .pkl and
+     the returned list contains file paths.
+     """
+
+     # default backend selection
+     if backend is None:
+         try:
+             import ray as _ray_module
+             backend = 'ray'
+         except ImportError:
+             backend = 'mp'
+
+     # Validate shared_kwargs
+     if shared_kwargs:
+         sig = inspect.signature(func)
+         valid_params = set(sig.parameters.keys())
+
+         for kw in shared_kwargs:
+             if kw not in func_kwargs:
+                 raise ValueError(
+                     f"shared_kwargs key '{kw}' not found in "
+                     f"provided func_kwargs"
+                 )
+             has_var_keyword = any(
+                 p.kind == inspect.Parameter.VAR_KEYWORD
+                 for p in sig.parameters.values()
+             )
+             if kw not in valid_params and not has_var_keyword:
+                 raise ValueError(
+                     f"shared_kwargs key '{kw}' is not a valid parameter "
+                     f"for function '{func.__name__}'. "
+                     f"Valid parameters: {valid_params}"
+                 )
+
+     # Prefer Ray backend when shared kwargs are requested
+     if shared_kwargs and backend != 'ray':
+         warnings.warn(
+             "shared_kwargs only supported with 'ray' backend, "
+             "switching backend to 'ray'",
+             UserWarning,
+         )
+         backend = 'ray'
+
+     # unify items and coerce to concrete list
+     if items is None and inputs is not None:
+         items = list(inputs)
+     if items is not None and not isinstance(items, list):
+         items = list(items)
+     if items is None:
+         raise ValueError("'items' or 'inputs' must be provided")
+
+     if workers is None and backend != 'ray':
+         workers = os.cpu_count() or 1
+
+     # build cache dir + wrap func
+     cache_dir = _build_cache_dir(func, items) if lazy_output else None
+     f_wrapped = wrap_dump(func, cache_dir, dump_in_thread)
+
+     log_gate_path = create_log_gate_path(log_worker)
+
+     total = len(items)
+     if desc:
+         desc = desc.strip() + f'[{backend}]'
+     else:
+         desc = f'Multi-process [{backend}]'
+
+     # Initialize error stats for error handling
+     func_name = getattr(func, '__name__', repr(func))
+     error_stats = ErrorStats(
+         func_name=func_name,
+         max_error_files=max_error_files,
+         write_logs=error_handler == 'log'
+     )
+
+     def _update_pbar_postfix(pbar: tqdm) -> None:
+         """Update pbar with success/error counts."""
+         postfix = error_stats.get_postfix_dict()
+         pbar.set_postfix(postfix)
+
+     # ---- sequential backend ----
+     if backend == 'seq':
+         return _run_seq_backend(
+             f_wrapped=f_wrapped,
+             items=items,
+             total=total,
+             desc=desc,
+             progress=progress,
+             func_kwargs=func_kwargs,
+             log_worker=log_worker,
+             log_gate_path=log_gate_path,
+             error_handler=error_handler,
+             error_stats=error_stats,
+             func_name=func_name,
+             update_pbar_postfix=_update_pbar_postfix,
+         )
+
+     # ---- ray backend ----
+     if backend == 'ray':
+         from ._multi_process_ray import run_ray_backend
+         return run_ray_backend(
+             f_wrapped=f_wrapped,
+             items=items,
+             total=total,
+             workers=workers,
+             progress=progress,
+             desc=desc,
+             func_kwargs=func_kwargs,
+             shared_kwargs=shared_kwargs,
+             log_worker=log_worker,
+             log_gate_path=log_gate_path,
+             total_items=total_items,
+             poll_interval=poll_interval,
+             ray_metrics_port=ray_metrics_port,
+             error_handler=error_handler,
+             error_stats=error_stats,
+             func_name=func_name,
+         )
+
+     # ---- threadpool backends (mp / safe) ----
+     if backend == 'mp':
+         return _run_threadpool_backend(
+             backend_label='mp',
+             track_processes=True,
+             f_wrapped=f_wrapped,
+             items=items,
+             total=total,
+             workers=workers,
+             desc=desc,
+             progress=progress,
+             func_kwargs=func_kwargs,
+             log_worker=log_worker,
+             log_gate_path=log_gate_path,
+             error_handler=error_handler,
+             error_stats=error_stats,
+             func_name=func_name,
+             update_pbar_postfix=_update_pbar_postfix,
+         )
+
+     if backend == 'safe':
+         return _run_threadpool_backend(
+             backend_label='safe',
+             track_processes=False,
+             f_wrapped=f_wrapped,
+             items=items,
+             total=total,
+             workers=workers,
+             desc=desc,
+             progress=progress,
+             func_kwargs=func_kwargs,
+             log_worker=log_worker,
+             log_gate_path=log_gate_path,
+             error_handler=error_handler,
+             error_stats=error_stats,
+             func_name=func_name,
+             update_pbar_postfix=_update_pbar_postfix,
+         )
+
+     raise ValueError(f'Unsupported backend: {backend!r}')
+
+
+ def _run_seq_backend(
+     *,
+     f_wrapped,
+     items: list,
+     total: int,
+     desc: str,
+     progress: bool,
+     func_kwargs: dict,
+     log_worker,
+     log_gate_path,
+     error_handler,
+     error_stats: ErrorStats,
+     func_name: str,
+     update_pbar_postfix,
+ ) -> list[Any]:
+     """Run sequential (single-threaded) backend."""
+     results: list[Any] = []
+     with tqdm(
+         total=total,
+         desc=desc,
+         disable=not progress,
+         file=sys.stdout,
+     ) as pbar:
+         for idx, x in enumerate(items):
+             try:
+                 result = _call_with_log_control(
+                     f_wrapped,
+                     x,
+                     func_kwargs,
+                     log_worker,
+                     log_gate_path,
+                 )
+                 error_stats.record_success()
+                 results.append(result)
+             except Exception as e:
+                 if error_handler == 'raise':
+                     raise
+                 error_stats.record_error(idx, e, x, func_name)
+                 results.append(None)
+             pbar.update(1)
+             update_pbar_postfix(pbar)
+     _cleanup_log_gate(log_gate_path)
+     return results
+
+
+ def _run_threadpool_backend(
+     *,
+     backend_label: str,
+     track_processes: bool,
+     f_wrapped,
+     items: list,
+     total: int,
+     workers: int | None,
+     desc: str,
+     progress: bool,
+     func_kwargs: dict,
+     log_worker,
+     log_gate_path,
+     error_handler,
+     error_stats: ErrorStats,
+     func_name: str,
+     update_pbar_postfix,
+ ) -> list[Any]:
+     """Run ThreadPoolExecutor backend for 'mp' and 'safe' modes."""
+     # Capture caller frame for better error reporting
+     caller_frame = inspect.currentframe()
+     caller_info = None
+     if caller_frame and caller_frame.f_back and caller_frame.f_back.f_back:
+         # Go back two frames: _run_threadpool_backend -> multi_process -> user
+         outer = caller_frame.f_back.f_back
+         caller_info = {
+             'filename': outer.f_code.co_filename,
+             'lineno': outer.f_lineno,
+             'function': outer.f_code.co_name,
+         }
+
+     def worker_func(x):
+         return _call_with_log_control(
+             f_wrapped,
+             x,
+             func_kwargs,
+             log_worker,
+             log_gate_path,
+         )
+
+     results: list[Any] = [None] * total
+     with tqdm(
+         total=total,
+         desc=desc,
+         disable=not progress,
+         file=sys.stdout,
+     ) as pbar:
+         with concurrent.futures.ThreadPoolExecutor(
+             max_workers=workers
+         ) as executor:
+             if _track_executor_threads is not None:
+                 _track_executor_threads(executor)
+
+             # Submit all tasks
+             future_to_idx = {
+                 executor.submit(worker_func, x): idx
+                 for idx, x in enumerate(items)
+             }
+
+             # Process results as they complete
+             for future in concurrent.futures.as_completed(future_to_idx):
+                 idx = future_to_idx[future]
+                 try:
+                     result = future.result()
+                     error_stats.record_success()
+                     results[idx] = result
+                 except Exception as e:
+                     if error_handler == 'raise':
+                         # Cancel remaining futures
+                         for f in future_to_idx:
+                             f.cancel()
+                         _exit_on_worker_error(
+                             e,
+                             pbar,
+                             caller_info,
+                             backend=backend_label,
+                         )
+                     error_stats.record_error(idx, e, items[idx], func_name)
+                     results[idx] = None
+                 pbar.update(1)
+                 update_pbar_postfix(pbar)
+
+     if _prune_dead_threads is not None:
+         _prune_dead_threads()
+
+     if track_processes:
+         _track_multiprocessing_processes()
+         _prune_dead_processes()
+
+     _cleanup_log_gate(log_gate_path)
+     return results
+
+
+ __all__ = ['multi_process']
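
A usage sketch of the dispatcher, based on the docstring above (behavior as documented there, not taken from the package's own examples):

    from speedy_utils import multi_process

    def square(x: int) -> int:
        return x * x

    # Default 'mp' backend: thread-pool map, results kept in input order.
    out = multi_process(square, items=range(8), workers=4, progress=False)
    assert out == [x * x for x in range(8)]

    # 'seq' backend with error_handler='ignore': a failing item yields None.
    out = multi_process(
        square, items=[1, None, 3], backend='seq',
        error_handler='ignore', progress=False,
    )
    # -> [1, None, 9]  (None * None raises TypeError, recorded and skipped)
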