speedy-utils 1.1.21__py3-none-any.whl → 1.1.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +22 -3
- llm_utils/lm/__init__.py +10 -0
- llm_utils/lm/llm_as_a_judge.py +390 -0
- llm_utils/lm/llm_task.py +172 -251
- llm_utils/lm/signature.py +282 -0
- llm_utils/lm/utils.py +332 -110
- speedy_utils/multi_worker/process.py +125 -25
- speedy_utils/multi_worker/thread.py +341 -226
- {speedy_utils-1.1.21.dist-info → speedy_utils-1.1.23.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.21.dist-info → speedy_utils-1.1.23.dist-info}/RECORD +12 -11
- llm_utils/lm/lm.py +0 -207
- {speedy_utils-1.1.21.dist-info → speedy_utils-1.1.23.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.21.dist-info → speedy_utils-1.1.23.dist-info}/entry_points.txt +0 -0
llm_utils/lm/utils.py
CHANGED
|
@@ -1,123 +1,345 @@
|
|
|
1
|
-
import fcntl
|
|
2
1
|
import os
|
|
3
|
-
import
|
|
2
|
+
import signal
|
|
4
3
|
import time
|
|
5
|
-
from typing import List,
|
|
6
|
-
|
|
4
|
+
from typing import Any, List, Optional, cast
|
|
5
|
+
|
|
7
6
|
from loguru import logger
|
|
8
7
|
|
|
9
8
|
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
9
|
+
# Additional imports for VLLM utilities
|
|
10
|
+
import re
|
|
11
|
+
import subprocess
|
|
12
|
+
import requests
|
|
13
|
+
from openai import OpenAI
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import psutil
|
|
17
|
+
HAS_PSUTIL = True
|
|
18
|
+
except ImportError:
|
|
19
|
+
HAS_PSUTIL = False
|
|
20
|
+
psutil = cast(Any, None)
|
|
21
|
+
logger.warning("psutil not available. Some VLLM process management features may be limited.")
|
|
22
|
+
|
|
23
|
+
# Registry of VLLM server subprocesses launched by this module: handles are
# appended when a server starts and removed when it is stopped or found dead.
_VLLM_PROCESSES: List[subprocess.Popen] = list()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _extract_port_from_vllm_cmd(vllm_cmd: str, default: int = 8000) -> int:
    """Return the port requested by a VLLM command string.

    Both ``--port 8001`` and ``--port=8001`` are recognised; argparse-style
    command lines accept either spelling, and a missed ``=`` form would
    silently fall back to the default port.

    Args:
        vllm_cmd: Full command line, e.g. ``"vllm serve model --port 8001"``.
        default: Port returned when the command does not specify one.

    Returns:
        The port number as an ``int``.
    """
    # `(?:\s+|=)` requires whitespace or '=' right after the flag, so options
    # such as `--port-range` are not mistaken for `--port`.
    match = re.search(r"--port(?:\s+|=)(\d+)", vllm_cmd)
    return int(match.group(1)) if match else default
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _parse_env_vars_from_cmd(cmd: str) -> tuple[dict[str, str], str]:
    """Split leading ``VAR=value`` assignments off a command string.

    Mirrors shell semantics: only assignments that appear *before* the program
    name are environment variables, so an argument such as ``FOO=1`` that
    follows the executable stays in the command. Names must be upper-case
    identifiers (``[A-Z_][A-Z0-9_]*``), e.g. ``CUDA_VISIBLE_DEVICES``.

    Args:
        cmd: Command string, e.g. ``"CUDA_VISIBLE_DEVICES=0 vllm serve model"``.

    Returns:
        ``(env_vars, cleaned_cmd)``. ``env_vars`` maps each parsed name to its
        value. ``cleaned_cmd`` is the remaining tokens re-joined with single
        spaces.

    Raises:
        ValueError: If ``cmd`` has unbalanced quotes (raised by ``shlex.split``).
    """
    import shlex

    # shlex.split honours quoting, e.g. HF_HOME='/tmp/a b' stays one token.
    tokens = shlex.split(cmd)
    env_vars: dict[str, str] = {}
    position = 0
    while position < len(tokens):
        assignment = re.fullmatch(r"([A-Z_][A-Z0-9_]*)=(.*)", tokens[position], re.DOTALL)
        if assignment is None:
            # First non-assignment token is the program; everything after it
            # belongs to the command, even if it looks like VAR=value.
            break
        env_vars[assignment.group(1)] = assignment.group(2)
        position += 1

    return env_vars, " ".join(tokens[position:])
|
|
71
|
+
|
|
16
72
|
|
|
73
|
+
def _vllm_log_tail(log_path: str, max_chars: int = 1000) -> str:
    """Return the last ``max_chars`` characters of ``log_path`` ('' if unreadable)."""
    try:
        with open(log_path, errors="replace") as handle:
            content = handle.read()
    except OSError:
        return ""
    return content[-max_chars:] if max_chars > 0 else ""


def _start_vllm_server(vllm_cmd: str, timeout: int = 120) -> subprocess.Popen:
    """Launch a VLLM server and block until its ``/health`` endpoint answers 200.

    Leading ``VAR=value`` tokens of ``vllm_cmd`` are passed to the child as
    environment variables. The server starts in a new session, so it leads its
    own process group, and its combined stdout/stderr is appended to
    ``/tmp/vllm_<port>.txt``.

    Args:
        vllm_cmd: Server command line, optionally prefixed by env assignments.
        timeout: Seconds to wait for the health check before giving up.

    Returns:
        The running ``subprocess.Popen`` handle, also recorded in
        ``_VLLM_PROCESSES``.

    Raises:
        RuntimeError: If the server exits before becoming healthy, or if it is
            not healthy within ``timeout`` seconds. In the timeout case the
            server's process group is stopped first.
    """
    env_vars, cleaned_cmd = _parse_env_vars_from_cmd(vllm_cmd)
    port = _extract_port_from_vllm_cmd(cleaned_cmd)
    log_path = f'/tmp/vllm_{port}.txt'

    logger.info(f"Starting VLLM server: {cleaned_cmd}")
    if env_vars:
        logger.info(f"Environment variables: {env_vars}")
    logger.info(f"VLLM output logged to: {log_path}")

    with open(log_path, 'w') as log_file:
        log_file.write(f"VLLM Server started at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        log_file.write(f"Command: {cleaned_cmd}\n")
        if env_vars:
            log_file.write(f"Environment: {env_vars}\n")
        log_file.write(f"Port: {port}\n")
        log_file.write("-" * 50 + "\n")

    env = os.environ.copy()
    env.update(env_vars)

    with open(log_path, 'a') as log_file:
        # The child receives its own duplicate of this descriptor, so closing
        # ours once Popen returns does not cut off the server's output.
        process = subprocess.Popen(
            cleaned_cmd.split(),
            stdout=log_file,
            stderr=subprocess.STDOUT,
            text=True,
            # Same effect as preexec_fn=os.setsid (child becomes a session and
            # process-group leader), but safe when the parent runs threads.
            start_new_session=True,
            env=env,
        )

    _VLLM_PROCESSES.append(process)

    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            response = requests.get(f"http://localhost:{port}/health", timeout=2)
            if response.status_code == 200:
                logger.info(f"VLLM server ready on port {port}")
                return process
        except requests.RequestException:
            pass  # server not accepting connections yet

        if process.poll() is not None:
            # Output goes to the log file, so Popen.communicate() would return
            # (None, None); report the end of the log instead.
            if process in _VLLM_PROCESSES:
                _VLLM_PROCESSES.remove(process)
            raise RuntimeError(
                f"VLLM server terminated unexpectedly. "
                f"Return code: {process.returncode}, "
                f"last output ({log_path}): {_vllm_log_tail(log_path)}"
            )

        time.sleep(2)

    # Timed out: signal the whole process group (not only the leader) so no
    # child process of the server is left orphaned.
    try:
        os.killpg(os.getpgid(process.pid), signal.SIGTERM)
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            os.killpg(os.getpgid(process.pid), signal.SIGKILL)
            process.wait()
    except (ProcessLookupError, OSError):
        pass  # process already gone
    finally:
        if process in _VLLM_PROCESSES:
            _VLLM_PROCESSES.remove(process)

    raise RuntimeError(f"VLLM server failed to start within {timeout}s on port {port}")
|
|
36
139
|
|
|
37
140
|
|
|
38
|
-
def
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
141
|
+
def _cmd_targets_port(cmdline: str, port: int) -> bool:
    """Return True if ``cmdline`` sets exactly ``port`` via ``--port N`` or ``--port=N``.

    A plain substring test would let port 800 match ``--port 8000``; the
    negative look-ahead rejects such digit prefixes.
    """
    pattern = rf"--port(?:\s+|=){re.escape(str(port))}(?!\d)"
    return re.search(pattern, cmdline) is not None


def _kill_vllm_on_port(port: int) -> bool:
    """Stop the VLLM server that was launched with ``--port <port>``.

    Tracked handles in ``_VLLM_PROCESSES`` are checked first:
      * exited handles are dropped from the registry;
      * a live one is stopped only if its own argv names ``port``. Its process
        group gets SIGTERM, then SIGKILL if it is still alive after 5 s, and
        the handle is untracked.
    If no tracked server matched and ``psutil`` is available, system
    processes whose command line contains ``vllm`` and this port are scanned.
    The first match, other than the current process, is terminated.

    Args:
        port: Port number of the server to stop.

    Returns:
        True if a matching process was signalled, False otherwise.
    """
    killed = False
    logger.info(f"Checking VLLM server on port {port}")

    for process in list(_VLLM_PROCESSES):
        if process.poll() is not None:
            # Already exited: just forget about it.
            _VLLM_PROCESSES.remove(process)
            continue

        # Popen.args is the argv we launched with, so the port can be checked
        # without psutil. Servers bound to other ports are left running.
        launch_args = process.args
        cmdline = launch_args if isinstance(launch_args, str) else " ".join(map(str, launch_args))
        if not _cmd_targets_port(cmdline, port):
            continue

        logger.info(f"Killing tracked VLLM process {process.pid} on port {port}")
        try:
            os.killpg(os.getpgid(process.pid), signal.SIGTERM)
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                os.killpg(os.getpgid(process.pid), signal.SIGKILL)
                process.wait()
            killed = True
        except (ProcessLookupError, OSError):
            pass  # exited between poll() and the signal
        if process in _VLLM_PROCESSES:
            _VLLM_PROCESSES.remove(process)

    if not killed and HAS_PSUTIL:
        own_pid = os.getpid()
        try:
            for proc in psutil.process_iter(['pid', 'cmdline']):
                try:
                    if proc.info['pid'] == own_pid:
                        continue  # never kill the caller itself
                    cmdline = ' '.join(proc.info['cmdline'] or [])
                    if 'vllm' in cmdline.lower() and _cmd_targets_port(cmdline, port):
                        logger.info(f"Killing untracked VLLM process {proc.info['pid']} on port {port}")
                        proc.terminate()
                        try:
                            proc.wait(timeout=5)
                        except psutil.TimeoutExpired:
                            proc.kill()
                        killed = True
                        break
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    continue
        except Exception as e:
            logger.warning(f"Error searching processes on port {port}: {e}")

    if killed:
        logger.info(f"Killed VLLM server on port {port}")
        time.sleep(2)
    else:
        logger.info(f"No VLLM server on port {port}")

    return killed
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def stop_vllm_process(process: subprocess.Popen, wait_timeout: int = 10) -> None:
    """Shut down one VLLM server's process group and untrack it.

    SIGTERM is sent to the process group. If the leader has not exited after
    ``wait_timeout`` seconds, SIGKILL follows. Signalling errors are logged,
    not raised. ``process`` is removed from ``_VLLM_PROCESSES`` in every case.
    """
    logger.info(f"Stopping VLLM process {process.pid}")
    try:
        if process.poll() is not None:
            logger.info("VLLM process already terminated")
            return

        group_id = os.getpgid(process.pid)
        os.killpg(group_id, signal.SIGTERM)
        try:
            process.wait(timeout=wait_timeout)
        except subprocess.TimeoutExpired:
            logger.warning("VLLM process didn't stop gracefully, forcing kill")
            os.killpg(os.getpgid(process.pid), signal.SIGKILL)
            process.wait()
        else:
            logger.info("VLLM process stopped gracefully")
    except (ProcessLookupError, OSError) as exc:
        logger.warning(f"Process may have already terminated: {exc}")
    finally:
        # Runs on the early return too, so the handle never lingers.
        if process in _VLLM_PROCESSES:
            _VLLM_PROCESSES.remove(process)
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def kill_all_vllm_processes() -> int:
    """Stop every live handle in ``_VLLM_PROCESSES`` and discard exited ones.

    Returns:
        The number of still-running processes handed to ``stop_vllm_process``.
    """
    stopped = 0
    # Iterate over a snapshot: the registry shrinks while we walk it.
    for proc in _VLLM_PROCESSES[:]:
        if proc.poll() is not None:
            _VLLM_PROCESSES.remove(proc)
            continue
        logger.info(f"Killing VLLM process with PID {proc.pid}")
        stop_vllm_process(proc, wait_timeout=5)
        stopped += 1
    logger.info(f"Killed {stopped} VLLM processes")
    return stopped
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _is_server_running(port: int) -> bool:
    """Report whether ``http://localhost:<port>/health`` returns HTTP 200 within 2 s."""
    health_url = f"http://localhost:{port}/health"
    try:
        status = requests.get(health_url, timeout=2).status_code
    except requests.RequestException:
        return False
    return status == 200
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def get_base_client(
    client=None,
    cache: bool = True,
    api_key="abc",
    vllm_cmd=None,
    vllm_process=None
) -> OpenAI:
    """Build, or pass through, an OpenAI-compatible client.

    Args:
        client: One of:
            * ``None``: derive a localhost endpoint from ``vllm_cmd``, or build
              a default client when ``vllm_cmd`` is also ``None``;
            * an ``int``: a port on localhost;
            * a ``str``: a base URL;
            * an ``OpenAI`` instance: returned unchanged.
        cache: Build the caching ``MOpenAI`` class instead of plain ``OpenAI``.
        api_key: Key given to clients built from a port, URL or ``vllm_cmd``.
        vllm_cmd: VLLM launch command whose ``--port`` selects the endpoint.
        vllm_process: Accepted for API compatibility; not used here.

    Raises:
        ValueError: If ``client`` has an unsupported type.
    """
    from llm_utils import MOpenAI

    factory = MOpenAI if cache else OpenAI

    if client is None:
        if vllm_cmd is None:
            return factory()
        # Leading VAR=value tokens are dropped before looking for --port.
        _, cleaned_cmd = _parse_env_vars_from_cmd(vllm_cmd)
        base_url = f"http://localhost:{_extract_port_from_vllm_cmd(cleaned_cmd)}/v1"
    elif isinstance(client, int):
        base_url = f"http://localhost:{client}/v1"
    elif isinstance(client, str):
        base_url = client
    elif isinstance(client, OpenAI):
        return client
    else:
        raise ValueError("Invalid client type. Must be OpenAI, port (int), base_url (str), or None.")

    return factory(base_url=base_url, api_key=api_key)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def _is_lora_path(path: str) -> bool:
    """Tell whether ``path`` is a directory containing an ``adapter_config.json`` file."""
    config_file = os.path.join(path, 'adapter_config.json')
    return os.path.isdir(path) and os.path.isfile(config_file)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _get_port_from_client(client: "OpenAI") -> Optional[int]:
    """Return the port of ``client.base_url`` when it points at this machine.

    Only loopback hosts (``localhost``, ``127.0.0.1``, ``::1``, ``0.0.0.0``)
    yield a port. The LoRA helpers in this module address the server as
    ``localhost:{port}``, so a remote URL such as ``http://gpu-box:8000/v1``
    returns ``None`` rather than a port that would be misapplied to localhost.
    The URL is parsed properly, so a host like ``notlocalhost`` no longer
    matches by substring.

    Returns:
        The explicit port number, or ``None`` when the client has no base URL,
        the host is not local, or no valid port is present.
    """
    from urllib.parse import urlsplit

    loopback_hosts = {"localhost", "127.0.0.1", "::1", "0.0.0.0"}

    base_url = getattr(client, "base_url", None)
    if not base_url:
        return None

    url_text = str(base_url)
    if "://" not in url_text:
        # Without a scheme, urlsplit would read "localhost" in
        # "localhost:8000" as the scheme; "//" marks the start of the netloc.
        url_text = "//" + url_text

    try:
        parts = urlsplit(url_text)
        port = parts.port  # ValueError for non-numeric or out-of-range ports
    except ValueError:
        return None

    if parts.hostname not in loopback_hosts:
        return None
    return port
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def _load_lora_adapter(lora_path: str, port: int, timeout: float = 300.0) -> str:
    """Register a LoRA adapter with the VLLM server on ``localhost:port``.

    The adapter name is the last path component of ``lora_path``, with
    trailing slashes ignored. The server receives the absolute adapter path.

    Args:
        lora_path: Directory containing the adapter.
        port: Port of the local VLLM server.
        timeout: Seconds to wait for the HTTP response. Without a timeout,
            ``requests`` waits indefinitely on an unresponsive server.

    Returns:
        The adapter name used for registration.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failure or timeout.
    """
    lora_name = os.path.basename(lora_path.rstrip('/\\'))
    if not lora_name:
        lora_name = os.path.basename(os.path.dirname(lora_path))

    response = requests.post(
        f'http://localhost:{port}/v1/load_lora_adapter',
        headers={'accept': 'application/json', 'Content-Type': 'application/json'},
        json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)},
        timeout=timeout,
    )
    response.raise_for_status()
    return lora_name
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def _unload_lora_adapter(lora_path: str, port: int, timeout: float = 60.0) -> None:
    """Best-effort removal of a LoRA adapter from the VLLM server on ``localhost:port``.

    The adapter name is the last path component of ``lora_path``, with
    trailing slashes ignored. Any ``requests`` error, including an HTTP error
    status or a timeout, is logged as a warning instead of being raised.

    Args:
        lora_path: Directory the adapter was loaded from.
        port: Port of the local VLLM server.
        timeout: Seconds to wait for the HTTP response. Without a timeout,
            ``requests`` waits indefinitely on an unresponsive server.
    """
    try:
        lora_name = os.path.basename(lora_path.rstrip('/\\'))
        if not lora_name:
            lora_name = os.path.basename(os.path.dirname(lora_path))

        response = requests.post(
            f'http://localhost:{port}/v1/unload_lora_adapter',
            headers={'accept': 'application/json', 'Content-Type': 'application/json'},
            json={"lora_name": lora_name, "lora_int_id": 0},
            timeout=timeout,
        )
        response.raise_for_status()
    except requests.RequestException as e:
        logger.warning(f"Error unloading LoRA adapter: {str(e)[:100]}")
|
|
@@ -1,15 +1,17 @@
|
|
|
1
1
|
# ray_multi_process.py
|
|
2
|
-
import time, os, pickle, uuid, datetime, multiprocessing
|
|
3
2
|
import datetime
|
|
4
3
|
import os
|
|
5
4
|
import pickle
|
|
5
|
+
import threading
|
|
6
6
|
import time
|
|
7
7
|
import uuid
|
|
8
8
|
from pathlib import Path
|
|
9
|
-
from typing import Any, Callable
|
|
10
|
-
|
|
9
|
+
from typing import Any, Callable, Iterable
|
|
10
|
+
|
|
11
11
|
import psutil
|
|
12
|
-
import
|
|
12
|
+
from fastcore.parallel import parallel
|
|
13
|
+
from tqdm import tqdm
|
|
14
|
+
|
|
13
15
|
ray: Any
|
|
14
16
|
try:
|
|
15
17
|
import ray as ray # type: ignore
|
|
@@ -17,11 +19,75 @@ try:
|
|
|
17
19
|
except Exception: # pragma: no cover
|
|
18
20
|
ray = None # type: ignore
|
|
19
21
|
_HAS_RAY = False
|
|
20
|
-
from typing import Any, Callable, Iterable
|
|
21
22
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
23
|
+
|
|
24
|
+
# ─── global tracking ──────────────────────────────────────────
|
|
25
|
+
|
|
26
|
+
# Module-wide registry of worker processes recorded by the tracking helpers.
# The helpers that read or rewrite the list hold _SPEEDY_PROCESSES_LOCK.
_SPEEDY_PROCESSES_LOCK = threading.Lock()
SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = list()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _prune_dead_processes() -> None:
    """Drop registry entries whose process reports ``is_running() == False``."""
    with _SPEEDY_PROCESSES_LOCK:
        still_alive = [proc for proc in SPEEDY_RUNNING_PROCESSES if proc.is_running()]
        # Slice assignment mutates the shared list in place, keeping every
        # existing reference to SPEEDY_RUNNING_PROCESSES valid.
        SPEEDY_RUNNING_PROCESSES[:] = still_alive
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _track_processes(processes: list[psutil.Process]) -> None:
    """Add live processes to SPEEDY_RUNNING_PROCESSES without duplicating PIDs.

    Dead entries already in the registry are pruned in the same pass. Each
    PID is recorded at most once, and earlier entries win. Candidates that are
    no longer running are ignored.
    """
    if not processes:
        return
    with _SPEEDY_PROCESSES_LOCK:
        tracked = [proc for proc in SPEEDY_RUNNING_PROCESSES if proc.is_running()]
        known_pids = {proc.pid for proc in tracked}
        for proc in processes:
            if proc.is_running() and proc.pid not in known_pids:
                tracked.append(proc)
                known_pids.add(proc.pid)
        SPEEDY_RUNNING_PROCESSES[:] = tracked
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _track_ray_processes() -> None:
    """Record descendants of this process whose names suggest Ray workers.

    This does nothing unless Ray was imported and ``ray.is_initialized()`` is
    true. A descendant qualifies when its process name contains "ray" or
    "worker" (case-insensitive). Any failure is swallowed so tracking never
    breaks the caller.
    """
    if not (_HAS_RAY and ray.is_initialized()):
        return
    try:
        candidates = []
        for child in psutil.Process(os.getpid()).children(recursive=True):
            try:
                proc_name = child.name().lower()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue
            if 'ray' in proc_name or 'worker' in proc_name:
                candidates.append(child)
        _track_processes(candidates)
    except Exception:
        # Best effort only: tracking problems must not surface to callers.
        pass
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _track_multiprocessing_processes() -> None:
    """Record direct children created within the last 5 seconds.

    This is a heuristic: any direct child younger than 5 s is assumed to be a
    multiprocessing worker. Failures are swallowed so tracking never breaks
    the caller.
    """
    try:
        me = psutil.Process(os.getpid())
        recent_children = []
        for child in me.children(recursive=False):
            try:
                age_seconds = time.time() - child.create_time()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue
            if age_seconds < 5:
                recent_children.append(child)
        _track_processes(recent_children)
    except Exception:
        # Best effort only: tracking problems must not surface to callers.
        pass
|
|
25
91
|
|
|
26
92
|
|
|
27
93
|
# ─── cache helpers ──────────────────────────────────────────
|
|
@@ -68,6 +134,7 @@ def ensure_ray(workers: int, pbar: tqdm | None = None):
|
|
|
68
134
|
t0 = time.time()
|
|
69
135
|
ray.init(num_cpus=workers, ignore_reinit_error=True)
|
|
70
136
|
took = time.time() - t0
|
|
137
|
+
_track_ray_processes() # Track Ray worker processes
|
|
71
138
|
if pbar:
|
|
72
139
|
pbar.set_postfix_str(f"ray.init {workers} took {took:.2f}s")
|
|
73
140
|
RAY_WORKER = workers
|
|
@@ -87,11 +154,6 @@ def multi_process(
|
|
|
87
154
|
progress: bool = True,
|
|
88
155
|
# backend: str = "ray", # "seq", "ray", or "fastcore"
|
|
89
156
|
backend: Literal["seq", "ray", "mp", "threadpool", "safe"] = "mp",
|
|
90
|
-
# Additional optional knobs (accepted for compatibility)
|
|
91
|
-
batch: int | None = None,
|
|
92
|
-
ordered: bool | None = None,
|
|
93
|
-
process_update_interval: int | None = None,
|
|
94
|
-
stop_on_error: bool | None = None,
|
|
95
157
|
**func_kwargs: Any,
|
|
96
158
|
) -> list[Any]:
|
|
97
159
|
"""
|
|
@@ -171,6 +233,8 @@ def multi_process(
|
|
|
171
233
|
results = parallel(
|
|
172
234
|
f_wrapped, items, n_workers=workers, progress=progress, threadpool=False
|
|
173
235
|
)
|
|
236
|
+
_track_multiprocessing_processes() # Track multiprocessing workers
|
|
237
|
+
_prune_dead_processes() # Clean up dead processes
|
|
174
238
|
return list(results)
|
|
175
239
|
if backend == "threadpool":
|
|
176
240
|
results = parallel(
|
|
@@ -180,8 +244,17 @@ def multi_process(
|
|
|
180
244
|
if backend == "safe":
|
|
181
245
|
# Completely safe backend for tests - no multiprocessing, no external progress bars
|
|
182
246
|
import concurrent.futures
|
|
183
|
-
|
|
184
|
-
|
|
247
|
+
# Import thread tracking from thread module
|
|
248
|
+
try:
|
|
249
|
+
from .thread import _track_executor_threads, _prune_dead_threads
|
|
250
|
+
with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
|
|
251
|
+
_track_executor_threads(executor) # Track threads
|
|
252
|
+
results = list(executor.map(f_wrapped, items))
|
|
253
|
+
_prune_dead_threads() # Clean up dead threads
|
|
254
|
+
except ImportError:
|
|
255
|
+
# Fallback if thread module not available
|
|
256
|
+
with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
|
|
257
|
+
results = list(executor.map(f_wrapped, items))
|
|
185
258
|
return results
|
|
186
259
|
|
|
187
260
|
raise ValueError(f"Unsupported backend: {backend!r}")
|
|
@@ -190,12 +263,24 @@ def multi_process(
|
|
|
190
263
|
|
|
191
264
|
def cleanup_phantom_workers():
|
|
192
265
|
"""
|
|
193
|
-
Kill all
|
|
266
|
+
Kill all tracked processes and threads (phantom workers) without killing the Jupyter kernel itself.
|
|
194
267
|
Also lists non-daemon threads that remain.
|
|
195
268
|
"""
|
|
196
|
-
|
|
269
|
+
# Clean up tracked processes first
|
|
270
|
+
_prune_dead_processes()
|
|
271
|
+
killed_processes = 0
|
|
272
|
+
with _SPEEDY_PROCESSES_LOCK:
|
|
273
|
+
for process in SPEEDY_RUNNING_PROCESSES[:]: # Copy to avoid modification during iteration
|
|
274
|
+
try:
|
|
275
|
+
print(f"🔪 Killing tracked process {process.pid} ({process.name()})")
|
|
276
|
+
process.kill()
|
|
277
|
+
killed_processes += 1
|
|
278
|
+
except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
|
|
279
|
+
print(f"⚠️ Could not kill process {process.pid}: {e}")
|
|
280
|
+
SPEEDY_RUNNING_PROCESSES.clear()
|
|
197
281
|
|
|
198
|
-
#
|
|
282
|
+
# Also kill any remaining child processes (fallback)
|
|
283
|
+
parent = psutil.Process(os.getpid())
|
|
199
284
|
for child in parent.children(recursive=True):
|
|
200
285
|
try:
|
|
201
286
|
print(f"🔪 Killing child process {child.pid} ({child.name()})")
|
|
@@ -203,14 +288,29 @@ def cleanup_phantom_workers():
|
|
|
203
288
|
except psutil.NoSuchProcess:
|
|
204
289
|
pass
|
|
205
290
|
|
|
206
|
-
#
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
291
|
+
# Try to clean up threads using thread module functions if available
|
|
292
|
+
try:
|
|
293
|
+
from .thread import SPEEDY_RUNNING_THREADS, kill_all_thread, _prune_dead_threads
|
|
294
|
+
_prune_dead_threads()
|
|
295
|
+
killed_threads = kill_all_thread()
|
|
296
|
+
if killed_threads > 0:
|
|
297
|
+
print(f"🔪 Killed {killed_threads} tracked threads")
|
|
298
|
+
except ImportError:
|
|
299
|
+
# Fallback: just report stray threads
|
|
300
|
+
for t in threading.enumerate():
|
|
301
|
+
if t is threading.current_thread():
|
|
302
|
+
continue
|
|
303
|
+
if not t.daemon:
|
|
304
|
+
print(f"⚠️ Thread {t.name} is still running (cannot be force-killed).")
|
|
212
305
|
|
|
213
|
-
print("✅ Cleaned up child processes (kernel untouched).")
|
|
306
|
+
print(f"✅ Cleaned up {killed_processes} tracked processes and child processes (kernel untouched).")
|
|
214
307
|
|
|
215
308
|
# Usage: run this anytime after cancelling a cell
|
|
216
309
|
|
|
310
|
+
|
|
311
|
+
# Names exported by `from speedy_utils.multi_worker.process import *`.
__all__ = ["SPEEDY_RUNNING_PROCESSES", "multi_process", "cleanup_phantom_workers"]
|
|
316
|
+
|