speedy-utils 1.1.45__py3-none-any.whl → 1.1.47__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/lm/llm.py +41 -12
- speedy_utils/__init__.py +4 -0
- speedy_utils/multi_worker/__init__.py +4 -0
- speedy_utils/multi_worker/_multi_process.py +425 -0
- speedy_utils/multi_worker/_multi_process_ray.py +308 -0
- speedy_utils/multi_worker/common.py +879 -0
- speedy_utils/multi_worker/dataset_sharding.py +203 -0
- speedy_utils/multi_worker/process.py +53 -1234
- speedy_utils/multi_worker/progress.py +71 -1
- speedy_utils/multi_worker/thread.py +45 -0
- speedy_utils/scripts/kill_mpython.py +58 -0
- speedy_utils/scripts/mpython.py +63 -16
- {speedy_utils-1.1.45.dist-info → speedy_utils-1.1.47.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.45.dist-info → speedy_utils-1.1.47.dist-info}/RECORD +16 -11
- {speedy_utils-1.1.45.dist-info → speedy_utils-1.1.47.dist-info}/entry_points.txt +1 -0
- {speedy_utils-1.1.45.dist-info → speedy_utils-1.1.47.dist-info}/WHEEL +0 -0
llm_utils/lm/llm.py
CHANGED
|
@@ -4,11 +4,9 @@
|
|
|
4
4
|
Simplified LLM Task module for handling language model interactions with structured input/output.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
import os
|
|
8
7
|
import subprocess
|
|
9
|
-
from typing import Any, Dict, List, Optional,
|
|
8
|
+
from typing import Any, Dict, List, Optional, cast
|
|
10
9
|
|
|
11
|
-
import requests
|
|
12
10
|
from httpx import Timeout
|
|
13
11
|
from loguru import logger
|
|
14
12
|
from openai import AuthenticationError, BadRequestError, OpenAI, RateLimitError
|
|
@@ -28,16 +26,7 @@ from .mixins import (
|
|
|
28
26
|
)
|
|
29
27
|
from .utils import (
|
|
30
28
|
_extract_port_from_vllm_cmd,
|
|
31
|
-
_get_port_from_client,
|
|
32
|
-
_is_lora_path,
|
|
33
|
-
_is_server_running,
|
|
34
|
-
_kill_vllm_on_port,
|
|
35
|
-
_load_lora_adapter,
|
|
36
|
-
_start_vllm_server,
|
|
37
|
-
_unload_lora_adapter,
|
|
38
29
|
get_base_client,
|
|
39
|
-
kill_all_vllm_processes,
|
|
40
|
-
stop_vllm_process,
|
|
41
30
|
)
|
|
42
31
|
|
|
43
32
|
|
|
@@ -216,6 +205,46 @@ class LLM(
|
|
|
216
205
|
results.append(result_dict)
|
|
217
206
|
return results
|
|
218
207
|
|
|
208
|
+
@staticmethod
|
|
209
|
+
def _strip_think_tags(text: str) -> str:
|
|
210
|
+
"""Remove <think> tags if present, returning only the reasoning body."""
|
|
211
|
+
cleaned = text.strip()
|
|
212
|
+
if cleaned.startswith('<think>'):
|
|
213
|
+
cleaned = cleaned[len('<think>') :].lstrip()
|
|
214
|
+
if '</think>' in cleaned:
|
|
215
|
+
cleaned = cleaned.split('</think>', 1)[0].rstrip()
|
|
216
|
+
return cleaned
|
|
217
|
+
|
|
218
|
+
def generate_with_think_prefix(
    self, input_data: str | BaseModel | list[dict], **runtime_kwargs
) -> list[dict[str, Any]]:
    """
    Run ``self.text_completion`` and rewrite every result so its text reads
    ``<think>reasoning</think>`` followed by the answer.

    For each result dict:
      * if it carries no ``reasoning_content`` and its ``parsed`` text
        already begins with ``<think>``, the text is kept as-is;
      * otherwise the reasoning (with any ``<think>`` markup stripped) is
        wrapped in a fresh ``<think>...</think>`` block and the left-stripped
        answer is appended. An empty reasoning yields an empty think block.

    The rewritten text replaces ``result['parsed']`` and, when the last entry
    of ``result['messages']`` is an assistant message dict, that message's
    ``content`` too. The result dicts are mutated in place and the same list
    is returned.
    """
    completions = self.text_completion(input_data, **runtime_kwargs)

    for item in completions:
        answer = item.get('parsed') or ''
        thinking = item.get('reasoning_content') or ''
        answer_text = str(answer)

        already_wrapped = not thinking and answer_text.lstrip().startswith('<think>')
        if already_wrapped:
            merged = answer_text
        else:
            thought = self._strip_think_tags(str(thinking))
            merged = f'<think>\n{thought}\n</think>\n\n{answer_text.lstrip()}'

        item['parsed'] = merged

        # Keep the chat transcript consistent with the rewritten output.
        history = item.get('messages')
        if not (isinstance(history, list) and history):
            continue
        tail = history[-1]
        if isinstance(tail, dict) and tail.get('role') == 'assistant':
            tail['content'] = merged

    return completions
|
|
247
|
+
|
|
219
248
|
@clean_traceback
|
|
220
249
|
def pydantic_parse(
|
|
221
250
|
self,
|
speedy_utils/__init__.py
CHANGED
|
@@ -65,6 +65,8 @@ from .common.utils_error import clean_traceback, handle_exceptions_with_clean_tr
|
|
|
65
65
|
from .multi_worker.process import multi_process
|
|
66
66
|
from .multi_worker.thread import kill_all_thread, multi_thread
|
|
67
67
|
from .multi_worker.dataset_ray import multi_process_dataset_ray, WorkerResources
|
|
68
|
+
from .multi_worker.dataset_sharding import multi_process_dataset
|
|
69
|
+
from .multi_worker.progress import report_progress
|
|
68
70
|
|
|
69
71
|
|
|
70
72
|
__all__ = [
|
|
@@ -167,7 +169,9 @@ __all__ = [
|
|
|
167
169
|
'multi_thread',
|
|
168
170
|
'kill_all_thread',
|
|
169
171
|
'multi_process_dataset_ray',
|
|
172
|
+
'multi_process_dataset',
|
|
170
173
|
'WorkerResources',
|
|
174
|
+
'report_progress',
|
|
171
175
|
# Notebook utilities
|
|
172
176
|
'change_dir',
|
|
173
177
|
]
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
from .process import multi_process, cleanup_phantom_workers, create_progress_tracker
|
|
2
2
|
from .thread import multi_thread
|
|
3
3
|
from .dataset_ray import multi_process_dataset_ray, WorkerResources
|
|
4
|
+
from .dataset_sharding import multi_process_dataset
|
|
5
|
+
from .progress import report_progress
|
|
4
6
|
|
|
5
7
|
__all__ = [
|
|
6
8
|
'multi_process',
|
|
@@ -9,4 +11,6 @@ __all__ = [
|
|
|
9
11
|
'create_progress_tracker',
|
|
10
12
|
'multi_process_dataset_ray',
|
|
11
13
|
'WorkerResources',
|
|
14
|
+
'multi_process_dataset',
|
|
15
|
+
'report_progress',
|
|
12
16
|
]
|
|
@@ -0,0 +1,425 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Multi-process implementation with sequential and threadpool backends.
|
|
3
|
+
|
|
4
|
+
Provides the public `multi_process` dispatcher and implementations for:
|
|
5
|
+
- 'seq': Sequential execution
|
|
6
|
+
- 'mp': ThreadPoolExecutor-based parallelism
|
|
7
|
+
- 'safe': Same as 'mp' but without process tracking (for tests)
|
|
8
|
+
"""
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import concurrent.futures
|
|
12
|
+
import inspect
|
|
13
|
+
import os
|
|
14
|
+
import sys
|
|
15
|
+
import warnings
|
|
16
|
+
from typing import Any, Callable, Iterable, Literal
|
|
17
|
+
|
|
18
|
+
from tqdm import tqdm
|
|
19
|
+
|
|
20
|
+
from .common import (
|
|
21
|
+
ErrorHandlerType,
|
|
22
|
+
ErrorStats,
|
|
23
|
+
_build_cache_dir,
|
|
24
|
+
_call_with_log_control,
|
|
25
|
+
_cleanup_log_gate,
|
|
26
|
+
_exit_on_worker_error,
|
|
27
|
+
_prune_dead_processes,
|
|
28
|
+
_track_multiprocessing_processes,
|
|
29
|
+
create_log_gate_path,
|
|
30
|
+
wrap_dump,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Import thread tracking functions if available
|
|
34
|
+
try:
|
|
35
|
+
from .thread import _prune_dead_threads, _track_executor_threads
|
|
36
|
+
except ImportError:
|
|
37
|
+
_prune_dead_threads = None # type: ignore[assignment]
|
|
38
|
+
_track_executor_threads = None # type: ignore[assignment]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def multi_process(
    func: Callable[[Any], Any],
    items: Iterable[Any] | None = None,
    *,
    inputs: Iterable[Any] | None = None,
    workers: int | None = None,
    lazy_output: bool = False,
    progress: bool = True,
    backend: Literal['seq', 'ray', 'mp', 'safe'] | None = 'mp',
    desc: str | None = None,
    shared_kwargs: list[str] | None = None,
    dump_in_thread: bool = True,
    ray_metrics_port: int | None = None,
    log_worker: Literal['zero', 'first', 'all'] = 'first',
    total_items: int | None = None,
    poll_interval: float = 0.3,
    error_handler: ErrorHandlerType = 'log',
    max_error_files: int = 100,
    process_update_interval: int | None = None,
    batch: int | None = None,
    ordered: bool = True,
    stop_on_error: bool = True,
    **func_kwargs: Any,
) -> list[Any]:
    """
    Multi-process map with selectable backend.

    backend:
        - "seq": run sequentially
        - "ray": run in parallel with Ray
        - "mp": run in parallel with thread pool (uses ThreadPoolExecutor)
        - "safe": run in parallel with thread pool (explicitly safe for tests)
        - None: use "ray" if the ``ray`` package is importable, else "mp"

    shared_kwargs:
        - Optional list of kwarg names that should be shared via Ray's
          zero-copy object store
        - Only works with Ray backend (any other backend is switched to
          "ray" with a UserWarning)
        - Useful for large objects (e.g., models, datasets)
        - Example: shared_kwargs=['model', 'tokenizer']

    dump_in_thread:
        - Whether to dump results to disk in a separate thread (default: True)
        - If False, dumping is done synchronously

    ray_metrics_port:
        - Optional port for Ray metrics export (Ray backend only)

    log_worker:
        - Control worker stdout/stderr noise
        - 'first': only first worker emits logs (default)
        - 'all': allow worker prints
        - 'zero': silence all worker output

    total_items:
        - Optional item-level total for progress tracking (Ray backend only)

    poll_interval:
        - Poll interval in seconds for progress actor updates (Ray only)

    error_handler:
        - 'raise': raise exception on first error
        - 'ignore': continue processing, return None for failed items
        - 'log': same as ignore, but logs errors to files (default)
        - Note: for the 'mp', 'safe' and 'ray' backends, 'raise' prints a
          formatted traceback and exits the process; 'seq' re-raises the
          original exception.

    max_error_files:
        - Maximum number of error log files to write (default: 100)
        - Error logs are written to .cache/speedy_utils/error_logs/{idx}.log
        - First error is always printed to screen with the log file path

    process_update_interval:
        - Legacy parameter, accepted for backward compatibility but not implemented

    batch:
        - Legacy parameter, accepted for backward compatibility but not implemented

    ordered:
        - Whether to maintain order of results (default: True)
        - Legacy parameter, accepted for backward compatibility but not implemented

    stop_on_error:
        - Whether to stop on first error (default: True)
        - Legacy parameter, accepted for backward compatibility
        - Use error_handler parameter instead for error handling control

    If lazy_output=True, every result is saved to .pkl and
    the returned list contains file paths.

    Raises:
        ValueError: if a ``shared_kwargs`` entry is missing from the keyword
            arguments or not accepted by ``func``, if ``backend`` is not
            supported, or if neither ``items`` nor ``inputs`` is given.
    """
    # Default backend selection: prefer Ray when it is importable.
    if backend is None:
        try:
            import ray as _ray_module  # noqa: F401  (availability probe only)
        except ImportError:
            backend = 'mp'
        else:
            backend = 'ray'

    # Validate shared_kwargs against the provided kwargs and func's signature.
    if shared_kwargs:
        sig = inspect.signature(func)
        valid_params = set(sig.parameters.keys())
        # Loop-invariant: a **kwargs parameter accepts any keyword name.
        has_var_keyword = any(
            p.kind == inspect.Parameter.VAR_KEYWORD
            for p in sig.parameters.values()
        )
        for kw in shared_kwargs:
            if kw not in func_kwargs:
                raise ValueError(
                    f"shared_kwargs key '{kw}' not found in "
                    "provided func_kwargs"
                )
            if kw not in valid_params and not has_var_keyword:
                raise ValueError(
                    f"shared_kwargs key '{kw}' is not a valid parameter "
                    f"for function '{func.__name__}'. "
                    f"Valid parameters: {valid_params}"
                )

    # Prefer Ray backend when shared kwargs are requested
    if shared_kwargs and backend != 'ray':
        warnings.warn(
            "shared_kwargs only supported with 'ray' backend, "
            "switching backend to 'ray'",
            UserWarning,
        )
        backend = 'ray'

    # Reject unknown backends before consuming the inputs or creating any
    # cache directory / log gate that would otherwise never be cleaned up.
    supported_backends = ('seq', 'ray', 'mp', 'safe')
    if backend not in supported_backends:
        raise ValueError(f'Unsupported backend: {backend!r}')

    # unify items and coerce to concrete list
    if items is None and inputs is not None:
        items = list(inputs)
    if items is not None and not isinstance(items, list):
        items = list(items)
    if items is None:
        raise ValueError("'items' or 'inputs' must be provided")

    # Ray chooses its own parallelism when workers is None.
    if workers is None and backend != 'ray':
        workers = os.cpu_count() or 1

    # build cache dir + wrap func
    cache_dir = _build_cache_dir(func, items) if lazy_output else None
    f_wrapped = wrap_dump(func, cache_dir, dump_in_thread)

    log_gate_path = create_log_gate_path(log_worker)

    total = len(items)
    if desc:
        desc = desc.strip() + f'[{backend}]'
    else:
        desc = f'Multi-process [{backend}]'

    # Initialize error stats for error handling
    func_name = getattr(func, '__name__', repr(func))
    error_stats = ErrorStats(
        func_name=func_name,
        max_error_files=max_error_files,
        write_logs=error_handler == 'log',
    )

    def _update_pbar_postfix(pbar: tqdm) -> None:
        """Update pbar with success/error counts."""
        postfix = error_stats.get_postfix_dict()
        pbar.set_postfix(postfix)

    # ---- sequential backend ----
    if backend == 'seq':
        return _run_seq_backend(
            f_wrapped=f_wrapped,
            items=items,
            total=total,
            desc=desc,
            progress=progress,
            func_kwargs=func_kwargs,
            log_worker=log_worker,
            log_gate_path=log_gate_path,
            error_handler=error_handler,
            error_stats=error_stats,
            func_name=func_name,
            update_pbar_postfix=_update_pbar_postfix,
        )

    # ---- ray backend ----
    if backend == 'ray':
        from ._multi_process_ray import run_ray_backend
        return run_ray_backend(
            f_wrapped=f_wrapped,
            items=items,
            total=total,
            workers=workers,
            progress=progress,
            desc=desc,
            func_kwargs=func_kwargs,
            shared_kwargs=shared_kwargs,
            log_worker=log_worker,
            log_gate_path=log_gate_path,
            total_items=total_items,
            poll_interval=poll_interval,
            ray_metrics_port=ray_metrics_port,
            error_handler=error_handler,
            error_stats=error_stats,
            func_name=func_name,
        )

    # ---- threadpool backends (mp / safe) ----
    # Both share one runner; only 'mp' additionally tracks and prunes
    # multiprocessing child processes afterwards.
    return _run_threadpool_backend(
        backend_label=backend,
        track_processes=backend == 'mp',
        f_wrapped=f_wrapped,
        items=items,
        total=total,
        workers=workers,
        desc=desc,
        progress=progress,
        func_kwargs=func_kwargs,
        log_worker=log_worker,
        log_gate_path=log_gate_path,
        error_handler=error_handler,
        error_stats=error_stats,
        func_name=func_name,
        update_pbar_postfix=_update_pbar_postfix,
    )
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def _run_seq_backend(
    *,
    f_wrapped,
    items: list,
    total: int,
    desc: str,
    progress: bool,
    func_kwargs: dict,
    log_worker,
    log_gate_path,
    error_handler,
    error_stats: ErrorStats,
    func_name: str,
    update_pbar_postfix,
) -> list[Any]:
    """Run sequential (single-threaded) backend.

    Calls ``f_wrapped`` on each item in order via ``_call_with_log_control``
    and returns the results in input order. A failing item yields ``None``
    and is recorded in ``error_stats``, unless ``error_handler == 'raise'``,
    in which case the original exception propagates.

    The log gate at ``log_gate_path`` is cleaned up on every exit path,
    including when an exception is re-raised.
    """
    results: list[Any] = []
    try:
        with tqdm(
            total=total,
            desc=desc,
            disable=not progress,
            file=sys.stdout,
        ) as pbar:
            for idx, x in enumerate(items):
                try:
                    result = _call_with_log_control(
                        f_wrapped,
                        x,
                        func_kwargs,
                        log_worker,
                        log_gate_path,
                    )
                except Exception as e:
                    if error_handler == 'raise':
                        raise
                    error_stats.record_error(idx, e, x, func_name)
                    result = None
                else:
                    # Only the user call is guarded; bookkeeping failures
                    # are not misreported as item errors.
                    error_stats.record_success()
                results.append(result)
                pbar.update(1)
                update_pbar_postfix(pbar)
    finally:
        # Previously skipped when error_handler == 'raise' re-raised.
        _cleanup_log_gate(log_gate_path)
    return results
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def _run_threadpool_backend(
    *,
    backend_label: str,
    track_processes: bool,
    f_wrapped,
    items: list,
    total: int,
    workers: int | None,
    desc: str,
    progress: bool,
    func_kwargs: dict,
    log_worker,
    log_gate_path,
    error_handler,
    error_stats: ErrorStats,
    func_name: str,
    update_pbar_postfix,
) -> list[Any]:
    """Run ThreadPoolExecutor backend for 'mp' and 'safe' modes.

    All items are submitted up front and results are stored at each item's
    original index, so the returned list is in input order regardless of
    completion order. A failing item yields ``None`` and is recorded in
    ``error_stats``. With ``error_handler == 'raise'``, pending futures are
    cancelled and ``_exit_on_worker_error`` is called with the user's call
    site for reporting.

    When ``track_processes`` is true, multiprocessing children are tracked
    and pruned after the run. The log gate at ``log_gate_path`` is cleaned
    up on every exit path.
    """
    # Capture caller frame for better error reporting.
    # Go back two frames: _run_threadpool_backend -> multi_process -> user
    caller_info = None
    frame = inspect.currentframe()
    try:
        if frame and frame.f_back and frame.f_back.f_back:
            outer = frame.f_back.f_back
            caller_info = {
                'filename': outer.f_code.co_filename,
                'lineno': outer.f_lineno,
                'function': outer.f_code.co_name,
            }
    finally:
        # Break the frame <-> locals reference cycle deterministically
        # (recommended by the inspect module documentation).
        del frame

    def worker_func(x):
        return _call_with_log_control(
            f_wrapped,
            x,
            func_kwargs,
            log_worker,
            log_gate_path,
        )

    results: list[Any] = [None] * total
    try:
        with tqdm(
            total=total,
            desc=desc,
            disable=not progress,
            file=sys.stdout,
        ) as pbar:
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=workers
            ) as executor:
                if _track_executor_threads is not None:
                    _track_executor_threads(executor)

                # Submit all tasks
                future_to_idx = {
                    executor.submit(worker_func, x): idx
                    for idx, x in enumerate(items)
                }

                # Process results as they complete
                for future in concurrent.futures.as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    try:
                        result = future.result()
                        error_stats.record_success()
                        results[idx] = result
                    except Exception as e:
                        if error_handler == 'raise':
                            # Cancel remaining futures
                            for f in future_to_idx:
                                f.cancel()
                            _exit_on_worker_error(
                                e,
                                pbar,
                                caller_info,
                                backend=backend_label,
                            )
                        error_stats.record_error(idx, e, items[idx], func_name)
                        results[idx] = None
                    pbar.update(1)
                    update_pbar_postfix(pbar)

                if _prune_dead_threads is not None:
                    _prune_dead_threads()

                if track_processes:
                    _track_multiprocessing_processes()
                    _prune_dead_processes()
    finally:
        # Runs even if _exit_on_worker_error raises SystemExit or the user
        # interrupts, so the log gate is never leaked.
        _cleanup_log_gate(log_gate_path)
    return results
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
# Public API: only the dispatcher is exported; the underscore-prefixed
# backend runners are module-private.
__all__ = ['multi_process']
|