speedy-utils 1.1.23__py3-none-any.whl → 1.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +12 -8
- llm_utils/chat_format/__init__.py +2 -0
- llm_utils/chat_format/display.py +115 -44
- llm_utils/lm/__init__.py +14 -6
- llm_utils/lm/llm.py +413 -0
- llm_utils/lm/llm_signature.py +35 -0
- llm_utils/lm/mixins.py +379 -0
- llm_utils/lm/openai_memoize.py +18 -7
- llm_utils/lm/signature.py +26 -37
- llm_utils/lm/utils.py +61 -76
- speedy_utils/__init__.py +28 -1
- speedy_utils/all.py +30 -1
- speedy_utils/common/utils_io.py +36 -26
- speedy_utils/common/utils_misc.py +25 -1
- speedy_utils/multi_worker/thread.py +145 -58
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.24.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.24.dist-info}/RECORD +19 -18
- llm_utils/lm/llm_as_a_judge.py +0 -390
- llm_utils/lm/llm_task.py +0 -614
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.24.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.24.dist-info}/entry_points.txt +0 -0
llm_utils/lm/utils.py
CHANGED
|
@@ -14,6 +14,7 @@ from openai import OpenAI
|
|
|
14
14
|
|
|
15
15
|
try:
|
|
16
16
|
import psutil
|
|
17
|
+
|
|
17
18
|
HAS_PSUTIL = True
|
|
18
19
|
except ImportError:
|
|
19
20
|
HAS_PSUTIL = False
|
|
@@ -26,7 +27,7 @@ _VLLM_PROCESSES: List[subprocess.Popen] = []
|
|
|
26
27
|
|
|
27
28
|
def _extract_port_from_vllm_cmd(vllm_cmd: str) -> int:
|
|
28
29
|
"""Extract port from VLLM command string."""
|
|
29
|
-
port_match = re.search(r
|
|
30
|
+
port_match = re.search(r"--port\s+(\d+)", vllm_cmd)
|
|
30
31
|
if port_match:
|
|
31
32
|
return int(port_match.group(1))
|
|
32
33
|
return 8000
|
|
@@ -34,39 +35,39 @@ def _extract_port_from_vllm_cmd(vllm_cmd: str) -> int:
|
|
|
34
35
|
|
|
35
36
|
def _parse_env_vars_from_cmd(cmd: str) -> tuple[dict[str, str], str]:
|
|
36
37
|
"""Parse environment variables from command string.
|
|
37
|
-
|
|
38
|
+
|
|
38
39
|
Args:
|
|
39
40
|
cmd: Command string that may contain environment variables like 'VAR=value command...'
|
|
40
|
-
|
|
41
|
+
|
|
41
42
|
Returns:
|
|
42
43
|
Tuple of (env_dict, cleaned_cmd) where env_dict contains parsed env vars
|
|
43
44
|
and cleaned_cmd is the command without the env vars.
|
|
44
45
|
"""
|
|
45
46
|
import shlex
|
|
46
|
-
|
|
47
|
+
|
|
47
48
|
# Split the command while preserving quoted strings
|
|
48
49
|
parts = shlex.split(cmd)
|
|
49
|
-
|
|
50
|
+
|
|
50
51
|
env_vars = {}
|
|
51
52
|
cmd_parts = []
|
|
52
|
-
|
|
53
|
+
|
|
53
54
|
for part in parts:
|
|
54
|
-
if
|
|
55
|
+
if "=" in part and not part.startswith("-"):
|
|
55
56
|
# Check if this looks like an environment variable
|
|
56
57
|
# Should be KEY=VALUE format, not contain spaces (unless quoted), and KEY should be uppercase
|
|
57
|
-
key_value = part.split(
|
|
58
|
+
key_value = part.split("=", 1)
|
|
58
59
|
if len(key_value) == 2:
|
|
59
60
|
key, value = key_value
|
|
60
|
-
if key.isupper() and key.replace(
|
|
61
|
+
if key.isupper() and key.replace("_", "").isalnum():
|
|
61
62
|
env_vars[key] = value
|
|
62
63
|
continue
|
|
63
|
-
|
|
64
|
+
|
|
64
65
|
# Not an env var, add to command parts
|
|
65
66
|
cmd_parts.append(part)
|
|
66
|
-
|
|
67
|
+
|
|
67
68
|
# Reconstruct the cleaned command
|
|
68
|
-
cleaned_cmd =
|
|
69
|
-
|
|
69
|
+
cleaned_cmd = " ".join(cmd_parts)
|
|
70
|
+
|
|
70
71
|
return env_vars, cleaned_cmd
|
|
71
72
|
|
|
72
73
|
|
|
@@ -74,38 +75,33 @@ def _start_vllm_server(vllm_cmd: str, timeout: int = 120) -> subprocess.Popen:
|
|
|
74
75
|
"""Start VLLM server and wait for ready."""
|
|
75
76
|
# Parse environment variables from command
|
|
76
77
|
env_vars, cleaned_cmd = _parse_env_vars_from_cmd(vllm_cmd)
|
|
77
|
-
|
|
78
|
+
|
|
78
79
|
port = _extract_port_from_vllm_cmd(cleaned_cmd)
|
|
79
|
-
|
|
80
|
+
|
|
80
81
|
logger.info(f"Starting VLLM server: {cleaned_cmd}")
|
|
81
82
|
if env_vars:
|
|
82
83
|
logger.info(f"Environment variables: {env_vars}")
|
|
83
84
|
logger.info(f"VLLM output logged to: /tmp/vllm_{port}.txt")
|
|
84
|
-
|
|
85
|
-
with open(f
|
|
85
|
+
|
|
86
|
+
with open(f"/tmp/vllm_{port}.txt", "w") as log_file:
|
|
86
87
|
log_file.write(f"VLLM Server started at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
|
|
87
88
|
log_file.write(f"Command: {cleaned_cmd}\n")
|
|
88
89
|
if env_vars:
|
|
89
90
|
log_file.write(f"Environment: {env_vars}\n")
|
|
90
91
|
log_file.write(f"Port: {port}\n")
|
|
91
92
|
log_file.write("-" * 50 + "\n")
|
|
92
|
-
|
|
93
|
+
|
|
93
94
|
# Prepare environment for subprocess
|
|
94
95
|
env = os.environ.copy()
|
|
95
96
|
env.update(env_vars)
|
|
96
97
|
|
|
97
|
-
with open(f
|
|
98
|
+
with open(f"/tmp/vllm_{port}.txt", "a") as log_file:
|
|
98
99
|
process = subprocess.Popen(
|
|
99
|
-
cleaned_cmd.split(),
|
|
100
|
-
stdout=log_file,
|
|
101
|
-
stderr=subprocess.STDOUT,
|
|
102
|
-
text=True,
|
|
103
|
-
preexec_fn=os.setsid,
|
|
104
|
-
env=env
|
|
100
|
+
cleaned_cmd.split(), stdout=log_file, stderr=subprocess.STDOUT, text=True, preexec_fn=os.setsid, env=env
|
|
105
101
|
)
|
|
106
|
-
|
|
102
|
+
|
|
107
103
|
_VLLM_PROCESSES.append(process)
|
|
108
|
-
|
|
104
|
+
|
|
109
105
|
start_time = time.time()
|
|
110
106
|
while time.time() - start_time < timeout:
|
|
111
107
|
try:
|
|
@@ -115,26 +111,24 @@ def _start_vllm_server(vllm_cmd: str, timeout: int = 120) -> subprocess.Popen:
|
|
|
115
111
|
return process
|
|
116
112
|
except requests.RequestException:
|
|
117
113
|
pass
|
|
118
|
-
|
|
114
|
+
|
|
119
115
|
if process.poll() is not None:
|
|
120
116
|
stdout, stderr = process.communicate()
|
|
121
117
|
raise RuntimeError(
|
|
122
|
-
f"VLLM server terminated unexpectedly. "
|
|
123
|
-
f"Return code: {process.returncode}, "
|
|
124
|
-
f"stderr: {stderr[:200]}..."
|
|
118
|
+
f"VLLM server terminated unexpectedly. Return code: {process.returncode}, stderr: {stderr[:200]}..."
|
|
125
119
|
)
|
|
126
|
-
|
|
120
|
+
|
|
127
121
|
time.sleep(2)
|
|
128
|
-
|
|
122
|
+
|
|
129
123
|
process.terminate()
|
|
130
124
|
try:
|
|
131
125
|
process.wait(timeout=5)
|
|
132
126
|
except subprocess.TimeoutExpired:
|
|
133
127
|
process.kill()
|
|
134
|
-
|
|
128
|
+
|
|
135
129
|
if process in _VLLM_PROCESSES:
|
|
136
130
|
_VLLM_PROCESSES.remove(process)
|
|
137
|
-
|
|
131
|
+
|
|
138
132
|
raise RuntimeError(f"VLLM server failed to start within {timeout}s on port {port}")
|
|
139
133
|
|
|
140
134
|
|
|
@@ -142,7 +136,7 @@ def _kill_vllm_on_port(port: int) -> bool:
|
|
|
142
136
|
"""Kill VLLM server on port."""
|
|
143
137
|
killed = False
|
|
144
138
|
logger.info(f"Checking VLLM server on port {port}")
|
|
145
|
-
|
|
139
|
+
|
|
146
140
|
processes_to_remove = []
|
|
147
141
|
for process in _VLLM_PROCESSES:
|
|
148
142
|
try:
|
|
@@ -151,8 +145,8 @@ def _kill_vllm_on_port(port: int) -> bool:
|
|
|
151
145
|
if HAS_PSUTIL:
|
|
152
146
|
try:
|
|
153
147
|
proc = psutil.Process(process.pid)
|
|
154
|
-
cmdline =
|
|
155
|
-
if f
|
|
148
|
+
cmdline = " ".join(proc.cmdline())
|
|
149
|
+
if f"--port {port}" in cmdline or f"--port={port}" in cmdline:
|
|
156
150
|
logger.info(f"Killing tracked VLLM process {process.pid} on port {port}")
|
|
157
151
|
os.killpg(os.getpgid(process.pid), signal.SIGTERM)
|
|
158
152
|
try:
|
|
@@ -164,7 +158,7 @@ def _kill_vllm_on_port(port: int) -> bool:
|
|
|
164
158
|
killed_process = True
|
|
165
159
|
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
|
166
160
|
pass
|
|
167
|
-
|
|
161
|
+
|
|
168
162
|
if not HAS_PSUTIL or not killed_process:
|
|
169
163
|
logger.info(f"Killing tracked VLLM process {process.pid}")
|
|
170
164
|
try:
|
|
@@ -177,24 +171,23 @@ def _kill_vllm_on_port(port: int) -> bool:
|
|
|
177
171
|
killed = True
|
|
178
172
|
except (ProcessLookupError, OSError):
|
|
179
173
|
pass
|
|
180
|
-
|
|
174
|
+
|
|
181
175
|
processes_to_remove.append(process)
|
|
182
176
|
else:
|
|
183
177
|
processes_to_remove.append(process)
|
|
184
178
|
except (ProcessLookupError, OSError):
|
|
185
179
|
processes_to_remove.append(process)
|
|
186
|
-
|
|
180
|
+
|
|
187
181
|
for process in processes_to_remove:
|
|
188
182
|
if process in _VLLM_PROCESSES:
|
|
189
183
|
_VLLM_PROCESSES.remove(process)
|
|
190
|
-
|
|
184
|
+
|
|
191
185
|
if not killed and HAS_PSUTIL:
|
|
192
186
|
try:
|
|
193
|
-
for proc in psutil.process_iter([
|
|
187
|
+
for proc in psutil.process_iter(["pid", "cmdline"]):
|
|
194
188
|
try:
|
|
195
|
-
cmdline =
|
|
196
|
-
if
|
|
197
|
-
(f'--port {port}' in cmdline or f'--port={port}' in cmdline)):
|
|
189
|
+
cmdline = " ".join(proc.info["cmdline"] or [])
|
|
190
|
+
if "vllm" in cmdline.lower() and (f"--port {port}" in cmdline or f"--port={port}" in cmdline):
|
|
198
191
|
logger.info(f"Killing untracked VLLM process {proc.info['pid']} on port {port}")
|
|
199
192
|
proc.terminate()
|
|
200
193
|
try:
|
|
@@ -207,13 +200,13 @@ def _kill_vllm_on_port(port: int) -> bool:
|
|
|
207
200
|
continue
|
|
208
201
|
except Exception as e:
|
|
209
202
|
logger.warning(f"Error searching processes on port {port}: {e}")
|
|
210
|
-
|
|
203
|
+
|
|
211
204
|
if killed:
|
|
212
205
|
logger.info(f"Killed VLLM server on port {port}")
|
|
213
206
|
time.sleep(2)
|
|
214
207
|
else:
|
|
215
208
|
logger.info(f"No VLLM server on port {port}")
|
|
216
|
-
|
|
209
|
+
|
|
217
210
|
return killed
|
|
218
211
|
|
|
219
212
|
|
|
@@ -262,32 +255,24 @@ def _is_server_running(port: int) -> bool:
|
|
|
262
255
|
return False
|
|
263
256
|
|
|
264
257
|
|
|
265
|
-
def get_base_client(
|
|
266
|
-
client=None,
|
|
267
|
-
cache: bool = True,
|
|
268
|
-
api_key="abc",
|
|
269
|
-
vllm_cmd=None,
|
|
270
|
-
vllm_process=None
|
|
271
|
-
) -> OpenAI:
|
|
258
|
+
def get_base_client(client=None, cache: bool = True, api_key="abc", vllm_cmd=None, vllm_process=None) -> OpenAI:
|
|
272
259
|
"""Get OpenAI client from various inputs."""
|
|
273
260
|
from llm_utils import MOpenAI
|
|
274
261
|
|
|
275
|
-
open_ai_class = OpenAI if not cache else MOpenAI
|
|
276
|
-
|
|
277
262
|
if client is None:
|
|
278
263
|
if vllm_cmd is not None:
|
|
279
264
|
# Parse environment variables from command to get clean command for port extraction
|
|
280
265
|
_, cleaned_cmd = _parse_env_vars_from_cmd(vllm_cmd)
|
|
281
266
|
port = _extract_port_from_vllm_cmd(cleaned_cmd)
|
|
282
|
-
return
|
|
267
|
+
return MOpenAI(base_url=f"http://localhost:{port}/v1", api_key=api_key, cache=cache)
|
|
283
268
|
else:
|
|
284
|
-
|
|
269
|
+
raise ValueError("Either client or vllm_cmd must be provided.")
|
|
285
270
|
elif isinstance(client, int):
|
|
286
|
-
return
|
|
271
|
+
return MOpenAI(base_url=f"http://localhost:{client}/v1", api_key=api_key, cache=cache)
|
|
287
272
|
elif isinstance(client, str):
|
|
288
|
-
return
|
|
273
|
+
return MOpenAI(base_url=client, api_key=api_key, cache=cache)
|
|
289
274
|
elif isinstance(client, OpenAI):
|
|
290
|
-
return client
|
|
275
|
+
return MOpenAI(base_url=client.base_url, api_key=api_key, cache=cache)
|
|
291
276
|
else:
|
|
292
277
|
raise ValueError("Invalid client type. Must be OpenAI, port (int), base_url (str), or None.")
|
|
293
278
|
|
|
@@ -296,17 +281,17 @@ def _is_lora_path(path: str) -> bool:
|
|
|
296
281
|
"""Check if path is LoRA adapter directory."""
|
|
297
282
|
if not os.path.isdir(path):
|
|
298
283
|
return False
|
|
299
|
-
adapter_config_path = os.path.join(path,
|
|
284
|
+
adapter_config_path = os.path.join(path, "adapter_config.json")
|
|
300
285
|
return os.path.isfile(adapter_config_path)
|
|
301
286
|
|
|
302
287
|
|
|
303
288
|
def _get_port_from_client(client: OpenAI) -> Optional[int]:
|
|
304
289
|
"""Extract port from OpenAI client base_url."""
|
|
305
|
-
if hasattr(client,
|
|
290
|
+
if hasattr(client, "base_url") and client.base_url:
|
|
306
291
|
base_url = str(client.base_url)
|
|
307
|
-
if
|
|
292
|
+
if "localhost:" in base_url:
|
|
308
293
|
try:
|
|
309
|
-
port_part = base_url.split(
|
|
294
|
+
port_part = base_url.split("localhost:")[1].split("/")[0]
|
|
310
295
|
return int(port_part)
|
|
311
296
|
except (IndexError, ValueError):
|
|
312
297
|
pass
|
|
@@ -315,14 +300,14 @@ def _get_port_from_client(client: OpenAI) -> Optional[int]:
|
|
|
315
300
|
|
|
316
301
|
def _load_lora_adapter(lora_path: str, port: int) -> str:
|
|
317
302
|
"""Load LoRA adapter from path."""
|
|
318
|
-
lora_name = os.path.basename(lora_path.rstrip(
|
|
303
|
+
lora_name = os.path.basename(lora_path.rstrip("/\\"))
|
|
319
304
|
if not lora_name:
|
|
320
305
|
lora_name = os.path.basename(os.path.dirname(lora_path))
|
|
321
|
-
|
|
306
|
+
|
|
322
307
|
response = requests.post(
|
|
323
|
-
f
|
|
324
|
-
headers={
|
|
325
|
-
json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)}
|
|
308
|
+
f"http://localhost:{port}/v1/load_lora_adapter",
|
|
309
|
+
headers={"accept": "application/json", "Content-Type": "application/json"},
|
|
310
|
+
json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)},
|
|
326
311
|
)
|
|
327
312
|
response.raise_for_status()
|
|
328
313
|
return lora_name
|
|
@@ -331,14 +316,14 @@ def _load_lora_adapter(lora_path: str, port: int) -> str:
|
|
|
331
316
|
def _unload_lora_adapter(lora_path: str, port: int) -> None:
|
|
332
317
|
"""Unload LoRA adapter."""
|
|
333
318
|
try:
|
|
334
|
-
lora_name = os.path.basename(lora_path.rstrip(
|
|
319
|
+
lora_name = os.path.basename(lora_path.rstrip("/\\"))
|
|
335
320
|
if not lora_name:
|
|
336
321
|
lora_name = os.path.basename(os.path.dirname(lora_path))
|
|
337
|
-
|
|
322
|
+
|
|
338
323
|
response = requests.post(
|
|
339
|
-
f
|
|
340
|
-
headers={
|
|
341
|
-
json={"lora_name": lora_name, "lora_int_id": 0}
|
|
324
|
+
f"http://localhost:{port}/v1/unload_lora_adapter",
|
|
325
|
+
headers={"accept": "application/json", "Content-Type": "application/json"},
|
|
326
|
+
json={"lora_name": lora_name, "lora_int_id": 0},
|
|
342
327
|
)
|
|
343
328
|
response.raise_for_status()
|
|
344
329
|
except requests.RequestException as e:
|
speedy_utils/__init__.py
CHANGED
|
@@ -79,7 +79,24 @@ from glob import glob
|
|
|
79
79
|
from multiprocessing import Pool
|
|
80
80
|
from pathlib import Path
|
|
81
81
|
from threading import Lock
|
|
82
|
-
from typing import
|
|
82
|
+
from typing import (
|
|
83
|
+
Any,
|
|
84
|
+
Awaitable,
|
|
85
|
+
Callable as TypingCallable,
|
|
86
|
+
Dict,
|
|
87
|
+
Generic,
|
|
88
|
+
Iterable,
|
|
89
|
+
List,
|
|
90
|
+
Literal,
|
|
91
|
+
Mapping,
|
|
92
|
+
Optional,
|
|
93
|
+
Sequence,
|
|
94
|
+
Set,
|
|
95
|
+
Tuple,
|
|
96
|
+
Type,
|
|
97
|
+
TypeVar,
|
|
98
|
+
Union,
|
|
99
|
+
)
|
|
83
100
|
|
|
84
101
|
# Third-party imports
|
|
85
102
|
import numpy as np
|
|
@@ -124,6 +141,7 @@ from .common.utils_io import (
|
|
|
124
141
|
# Misc utilities
|
|
125
142
|
from .common.utils_misc import (
|
|
126
143
|
convert_to_builtin_python,
|
|
144
|
+
dedup,
|
|
127
145
|
flatten_list,
|
|
128
146
|
get_arg_names,
|
|
129
147
|
is_notebook,
|
|
@@ -171,12 +189,20 @@ __all__ = [
|
|
|
171
189
|
"defaultdict",
|
|
172
190
|
# Typing
|
|
173
191
|
"Any",
|
|
192
|
+
"Awaitable",
|
|
174
193
|
"Callable",
|
|
194
|
+
"TypingCallable",
|
|
175
195
|
"Dict",
|
|
176
196
|
"Generic",
|
|
197
|
+
"Iterable",
|
|
177
198
|
"List",
|
|
178
199
|
"Literal",
|
|
200
|
+
"Mapping",
|
|
179
201
|
"Optional",
|
|
202
|
+
"Sequence",
|
|
203
|
+
"Set",
|
|
204
|
+
"Tuple",
|
|
205
|
+
"Type",
|
|
180
206
|
"TypeVar",
|
|
181
207
|
"Union",
|
|
182
208
|
# Third-party
|
|
@@ -214,6 +240,7 @@ __all__ = [
|
|
|
214
240
|
"get_arg_names",
|
|
215
241
|
"is_notebook",
|
|
216
242
|
"convert_to_builtin_python",
|
|
243
|
+
"dedup",
|
|
217
244
|
# Print utilities
|
|
218
245
|
"display_pretty_table_html",
|
|
219
246
|
"flatten_dict",
|
speedy_utils/all.py
CHANGED
|
@@ -71,7 +71,24 @@ from glob import glob
|
|
|
71
71
|
from multiprocessing import Pool
|
|
72
72
|
from pathlib import Path
|
|
73
73
|
from threading import Lock
|
|
74
|
-
from typing import
|
|
74
|
+
from typing import (
|
|
75
|
+
Any,
|
|
76
|
+
Awaitable,
|
|
77
|
+
Callable as TypingCallable,
|
|
78
|
+
Dict,
|
|
79
|
+
Generic,
|
|
80
|
+
Iterable,
|
|
81
|
+
List,
|
|
82
|
+
Literal,
|
|
83
|
+
Mapping,
|
|
84
|
+
Optional,
|
|
85
|
+
Sequence,
|
|
86
|
+
Set,
|
|
87
|
+
Tuple,
|
|
88
|
+
Type,
|
|
89
|
+
TypeVar,
|
|
90
|
+
Union,
|
|
91
|
+
)
|
|
75
92
|
|
|
76
93
|
# Third-party imports
|
|
77
94
|
import numpy as np
|
|
@@ -115,6 +132,9 @@ from speedy_utils import ( # Clock module; Function decorators; Cache utilities
|
|
|
115
132
|
timef,
|
|
116
133
|
)
|
|
117
134
|
|
|
135
|
+
|
|
136
|
+
choice = random.choice
|
|
137
|
+
|
|
118
138
|
# Define __all__ explicitly with all exports
|
|
119
139
|
__all__ = [
|
|
120
140
|
# Standard library
|
|
@@ -146,12 +166,20 @@ __all__ = [
|
|
|
146
166
|
"defaultdict",
|
|
147
167
|
# Typing
|
|
148
168
|
"Any",
|
|
169
|
+
"Awaitable",
|
|
149
170
|
"Callable",
|
|
171
|
+
"TypingCallable",
|
|
150
172
|
"Dict",
|
|
151
173
|
"Generic",
|
|
174
|
+
"Iterable",
|
|
152
175
|
"List",
|
|
153
176
|
"Literal",
|
|
177
|
+
"Mapping",
|
|
154
178
|
"Optional",
|
|
179
|
+
"Sequence",
|
|
180
|
+
"Set",
|
|
181
|
+
"Tuple",
|
|
182
|
+
"Type",
|
|
155
183
|
"TypeVar",
|
|
156
184
|
"Union",
|
|
157
185
|
# Third-party
|
|
@@ -199,4 +227,5 @@ __all__ = [
|
|
|
199
227
|
# Multi-worker processing
|
|
200
228
|
"multi_process",
|
|
201
229
|
"multi_thread",
|
|
230
|
+
"choice",
|
|
202
231
|
]
|
speedy_utils/common/utils_io.py
CHANGED
|
@@ -29,9 +29,7 @@ def dump_jsonl(list_dictionaries: list[dict], file_name: str = "output.jsonl") -
|
|
|
29
29
|
file.write(json.dumps(dictionary, ensure_ascii=False) + "\n")
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
def dump_json_or_pickle(
|
|
33
|
-
obj: Any, fname: str, ensure_ascii: bool = False, indent: int = 4
|
|
34
|
-
) -> None:
|
|
32
|
+
def dump_json_or_pickle(obj: Any, fname: str, ensure_ascii: bool = False, indent: int = 4) -> None:
|
|
35
33
|
"""
|
|
36
34
|
Dump an object to a file, supporting both JSON and pickle formats.
|
|
37
35
|
"""
|
|
@@ -59,6 +57,7 @@ def dump_json_or_pickle(
|
|
|
59
57
|
if isinstance(obj, BaseModel):
|
|
60
58
|
data = obj.model_dump()
|
|
61
59
|
from fastcore.all import dict2obj, obj2dict
|
|
60
|
+
|
|
62
61
|
obj2 = dict2obj(data)
|
|
63
62
|
with open(fname, "wb") as f:
|
|
64
63
|
pickle.dump(obj2, f)
|
|
@@ -84,7 +83,8 @@ def load_json_or_pickle(fname: str, counter=0) -> Any:
|
|
|
84
83
|
except EOFError:
|
|
85
84
|
time.sleep(1)
|
|
86
85
|
if counter > 5:
|
|
87
|
-
|
|
86
|
+
# Keep message concise and actionable
|
|
87
|
+
print(f"Corrupted cache file {fname} removed; it will be regenerated on next access")
|
|
88
88
|
os.remove(fname)
|
|
89
89
|
raise
|
|
90
90
|
return load_json_or_pickle(fname, counter + 1)
|
|
@@ -92,8 +92,6 @@ def load_json_or_pickle(fname: str, counter=0) -> Any:
|
|
|
92
92
|
raise ValueError(f"Error {e} while loading {fname}") from e
|
|
93
93
|
|
|
94
94
|
|
|
95
|
-
|
|
96
|
-
|
|
97
95
|
try:
|
|
98
96
|
import orjson # type: ignore[import-not-found] # fastest JSON parser when available
|
|
99
97
|
except Exception:
|
|
@@ -113,11 +111,11 @@ def fast_load_jsonl(
|
|
|
113
111
|
use_orjson: bool = True,
|
|
114
112
|
encoding: str = "utf-8",
|
|
115
113
|
errors: str = "strict",
|
|
116
|
-
on_error: str = "raise",
|
|
114
|
+
on_error: str = "raise", # 'raise' | 'warn' | 'skip'
|
|
117
115
|
skip_empty: bool = True,
|
|
118
116
|
max_lines: Optional[int] = None,
|
|
119
117
|
use_multiworker: bool = True,
|
|
120
|
-
multiworker_threshold: int =
|
|
118
|
+
multiworker_threshold: int = 1000000,
|
|
121
119
|
workers: Optional[int] = None,
|
|
122
120
|
) -> Iterable[Any]:
|
|
123
121
|
"""
|
|
@@ -127,7 +125,7 @@ def fast_load_jsonl(
|
|
|
127
125
|
- Optional tqdm progress over bytes (compressed size if gz/bz2/xz/zst).
|
|
128
126
|
- Auto-detects compression by extension: .gz, .bz2, .xz/.lzma, .zst/.zstd.
|
|
129
127
|
- Uses orjson if available (use_orjson=True), falls back to json.
|
|
130
|
-
- Automatically uses multi-worker processing for large files (>
|
|
128
|
+
- Automatically uses multi-worker processing for large files (>100k lines).
|
|
131
129
|
|
|
132
130
|
Args:
|
|
133
131
|
path_or_file: Path-like or file-like object. File-like can be binary or text.
|
|
@@ -140,11 +138,12 @@ def fast_load_jsonl(
|
|
|
140
138
|
max_lines: Stop after reading this many lines (useful for sampling).
|
|
141
139
|
use_multiworker: Enable multi-worker processing for large files.
|
|
142
140
|
multiworker_threshold: Line count threshold to trigger multi-worker processing.
|
|
143
|
-
workers: Number of worker threads (defaults to CPU count).
|
|
141
|
+
workers: Number of worker threads (defaults to 80% of CPU count, max 8).
|
|
144
142
|
|
|
145
143
|
Yields:
|
|
146
144
|
Parsed Python objects per line.
|
|
147
145
|
"""
|
|
146
|
+
|
|
148
147
|
def _open_auto(pth_or_f) -> IO[Any]:
|
|
149
148
|
if hasattr(pth_or_f, "read"):
|
|
150
149
|
# ensure binary buffer for consistent byte-length progress
|
|
@@ -206,39 +205,47 @@ def fast_load_jsonl(
|
|
|
206
205
|
|
|
207
206
|
# Check if we should use multi-worker processing
|
|
208
207
|
should_use_multiworker = (
|
|
209
|
-
use_multiworker
|
|
208
|
+
use_multiworker
|
|
210
209
|
and not hasattr(path_or_file, "read") # Only for file paths, not file objects
|
|
211
210
|
and max_lines is None # Don't use multiworker if we're limiting lines
|
|
212
211
|
)
|
|
213
|
-
|
|
212
|
+
|
|
214
213
|
if should_use_multiworker:
|
|
215
214
|
line_count = _count_lines_fast(cast(Union[str, os.PathLike], path_or_file))
|
|
216
215
|
if line_count > multiworker_threshold:
|
|
217
216
|
# Use multi-worker processing
|
|
218
217
|
from ..multi_worker.thread import multi_thread
|
|
219
218
|
|
|
219
|
+
# Calculate optimal worker count: 80% of CPU count, capped at 8
|
|
220
|
+
cpu_count = os.cpu_count() or 4
|
|
221
|
+
default_workers = min(int(cpu_count * 0.8), 8)
|
|
222
|
+
num_workers = workers if workers is not None else default_workers
|
|
223
|
+
num_workers = max(1, num_workers) # At least 1 worker
|
|
224
|
+
|
|
220
225
|
# Read all lines into chunks
|
|
221
226
|
f = _open_auto(path_or_file)
|
|
222
227
|
all_lines = list(f)
|
|
223
228
|
f.close()
|
|
224
|
-
|
|
225
|
-
# Split into chunks for
|
|
226
|
-
|
|
227
|
-
chunk_size = max(len(all_lines) // num_workers,
|
|
229
|
+
|
|
230
|
+
# Split into chunks - aim for ~10k-20k lines per chunk minimum
|
|
231
|
+
min_chunk_size = 10000
|
|
232
|
+
chunk_size = max(len(all_lines) // num_workers, min_chunk_size)
|
|
228
233
|
chunks = []
|
|
229
234
|
for i in range(0, len(all_lines), chunk_size):
|
|
230
|
-
chunks.append(all_lines[i:i + chunk_size])
|
|
231
|
-
|
|
235
|
+
chunks.append(all_lines[i : i + chunk_size])
|
|
236
|
+
|
|
232
237
|
# Process chunks in parallel
|
|
233
238
|
if progress:
|
|
234
|
-
print(f"Processing {line_count} lines with {num_workers} workers...")
|
|
235
|
-
|
|
239
|
+
print(f"Processing {line_count} lines with {num_workers} workers ({len(chunks)} chunks)...")
|
|
240
|
+
|
|
236
241
|
chunk_results = multi_thread(_process_chunk, chunks, workers=num_workers, progress=progress)
|
|
237
|
-
|
|
242
|
+
|
|
238
243
|
# Flatten results and yield
|
|
239
|
-
|
|
240
|
-
for
|
|
241
|
-
|
|
244
|
+
if chunk_results:
|
|
245
|
+
for chunk_result in chunk_results:
|
|
246
|
+
if chunk_result:
|
|
247
|
+
for obj in chunk_result:
|
|
248
|
+
yield obj
|
|
242
249
|
return
|
|
243
250
|
|
|
244
251
|
# Single-threaded processing (original logic)
|
|
@@ -266,7 +273,11 @@ def fast_load_jsonl(
|
|
|
266
273
|
line_no += 1
|
|
267
274
|
if pbar is not None:
|
|
268
275
|
# raw_line is bytes here; if not, compute byte length
|
|
269
|
-
nbytes =
|
|
276
|
+
nbytes = (
|
|
277
|
+
len(raw_line)
|
|
278
|
+
if isinstance(raw_line, (bytes, bytearray))
|
|
279
|
+
else len(str(raw_line).encode(encoding, errors))
|
|
280
|
+
)
|
|
270
281
|
pbar.update(nbytes)
|
|
271
282
|
|
|
272
283
|
# Normalize to bytes -> str only if needed
|
|
@@ -322,7 +333,6 @@ def fast_load_jsonl(
|
|
|
322
333
|
pass
|
|
323
334
|
|
|
324
335
|
|
|
325
|
-
|
|
326
336
|
def load_by_ext(fname: Union[str, list[str]], do_memoize: bool = False) -> Any:
|
|
327
337
|
"""
|
|
328
338
|
Load data based on file extension.
|
|
@@ -3,10 +3,12 @@
|
|
|
3
3
|
import inspect
|
|
4
4
|
import os
|
|
5
5
|
from collections.abc import Callable
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any, TypeVar
|
|
7
7
|
|
|
8
8
|
from pydantic import BaseModel
|
|
9
9
|
|
|
10
|
+
T = TypeVar("T")
|
|
11
|
+
|
|
10
12
|
|
|
11
13
|
def mkdir_or_exist(dir_name: str) -> None:
|
|
12
14
|
"""Create a directory if it doesn't exist."""
|
|
@@ -50,10 +52,32 @@ def convert_to_builtin_python(input_data: Any) -> Any:
|
|
|
50
52
|
raise ValueError(f"Unsupported type {type(input_data)}")
|
|
51
53
|
|
|
52
54
|
|
|
55
|
+
def dedup(items: list[T], key: Callable[[T], Any]) -> list[T]:
|
|
56
|
+
"""
|
|
57
|
+
Deduplicate items in a list based on a key function.
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
items: The list of items.
|
|
61
|
+
key: A function that takes an item and returns a hashable key.
|
|
62
|
+
|
|
63
|
+
Returns:
|
|
64
|
+
A list with duplicates removed, preserving the first occurrence.
|
|
65
|
+
"""
|
|
66
|
+
seen = set()
|
|
67
|
+
result = []
|
|
68
|
+
for item in items:
|
|
69
|
+
k = key(item)
|
|
70
|
+
if k not in seen:
|
|
71
|
+
seen.add(k)
|
|
72
|
+
result.append(item)
|
|
73
|
+
return result
|
|
74
|
+
|
|
75
|
+
|
|
53
76
|
__all__ = [
|
|
54
77
|
"mkdir_or_exist",
|
|
55
78
|
"flatten_list",
|
|
56
79
|
"get_arg_names",
|
|
57
80
|
"is_notebook",
|
|
58
81
|
"convert_to_builtin_python",
|
|
82
|
+
"dedup",
|
|
59
83
|
]
|