speedy-utils 1.1.21__py3-none-any.whl → 1.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm/utils.py CHANGED
@@ -1,123 +1,345 @@
1
- import fcntl
2
1
  import os
3
- import tempfile
2
+ import signal
4
3
  import time
5
- from typing import List, Dict
6
- import numpy as np
4
+ from typing import Any, List, Optional, cast
5
+
7
6
  from loguru import logger
8
7
 
9
8
 
10
- def _atomic_save(array: np.ndarray, filename: str):
11
- tmp_dir = os.path.dirname(filename) or "."
12
- with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp:
13
- np.save(tmp, array)
14
- temp_name = tmp.name
15
- os.replace(temp_name, filename)
9
+ # Additional imports for VLLM utilities
10
+ import re
11
+ import subprocess
12
+ import requests
13
+ from openai import OpenAI
14
+
15
# psutil is an optional dependency: process-management helpers check
# HAS_PSUTIL and degrade gracefully when it is missing.
try:
    import psutil
    HAS_PSUTIL = True
except ImportError:
    HAS_PSUTIL = False
    # cast keeps static type-checkers quiet; every caller must guard on
    # HAS_PSUTIL before touching psutil.
    psutil = cast(Any, None)
    logger.warning("psutil not available. Some VLLM process management features may be limited.")

# Global tracking of VLLM processes
# NOTE(review): this list is not lock-guarded; assumes servers are
# started/stopped from a single thread — confirm against callers.
_VLLM_PROCESSES: List[subprocess.Popen] = []
25
+
26
+
27
+ def _extract_port_from_vllm_cmd(vllm_cmd: str) -> int:
28
+ """Extract port from VLLM command string."""
29
+ port_match = re.search(r'--port\s+(\d+)', vllm_cmd)
30
+ if port_match:
31
+ return int(port_match.group(1))
32
+ return 8000
33
+
34
+
35
+ def _parse_env_vars_from_cmd(cmd: str) -> tuple[dict[str, str], str]:
36
+ """Parse environment variables from command string.
37
+
38
+ Args:
39
+ cmd: Command string that may contain environment variables like 'VAR=value command...'
40
+
41
+ Returns:
42
+ Tuple of (env_dict, cleaned_cmd) where env_dict contains parsed env vars
43
+ and cleaned_cmd is the command without the env vars.
44
+ """
45
+ import shlex
46
+
47
+ # Split the command while preserving quoted strings
48
+ parts = shlex.split(cmd)
49
+
50
+ env_vars = {}
51
+ cmd_parts = []
52
+
53
+ for part in parts:
54
+ if '=' in part and not part.startswith('-'):
55
+ # Check if this looks like an environment variable
56
+ # Should be KEY=VALUE format, not contain spaces (unless quoted), and KEY should be uppercase
57
+ key_value = part.split('=', 1)
58
+ if len(key_value) == 2:
59
+ key, value = key_value
60
+ if key.isupper() and key.replace('_', '').isalnum():
61
+ env_vars[key] = value
62
+ continue
63
+
64
+ # Not an env var, add to command parts
65
+ cmd_parts.append(part)
66
+
67
+ # Reconstruct the cleaned command
68
+ cleaned_cmd = ' '.join(cmd_parts)
69
+
70
+ return env_vars, cleaned_cmd
71
+
16
72
 
73
def _start_vllm_server(vllm_cmd: str, timeout: int = 120) -> subprocess.Popen:
    """Start a VLLM server subprocess and block until it reports healthy.

    Args:
        vllm_cmd: Full launch command, optionally prefixed with env
            assignments (``VAR=value vllm serve ...``).
        timeout: Seconds to wait for the /health endpoint before giving up.

    Returns:
        The running subprocess once ``GET /health`` returns 200.

    Raises:
        RuntimeError: If the server exits early or is not healthy in time.
    """
    import shlex

    # Separate leading env assignments from the actual command.
    env_vars, cleaned_cmd = _parse_env_vars_from_cmd(vllm_cmd)
    port = _extract_port_from_vllm_cmd(cleaned_cmd)
    log_path = f'/tmp/vllm_{port}.txt'

    logger.info(f"Starting VLLM server: {cleaned_cmd}")
    if env_vars:
        logger.info(f"Environment variables: {env_vars}")
    logger.info(f"VLLM output logged to: {log_path}")

    with open(log_path, 'w') as log_file:
        log_file.write(f"VLLM Server started at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        log_file.write(f"Command: {cleaned_cmd}\n")
        if env_vars:
            log_file.write(f"Environment: {env_vars}\n")
        log_file.write(f"Port: {port}\n")
        log_file.write("-" * 50 + "\n")

    # Child inherits our environment plus the parsed assignments.
    env = os.environ.copy()
    env.update(env_vars)

    with open(log_path, 'a') as log_file:
        # shlex.split (not str.split) keeps quoted arguments intact.
        process = subprocess.Popen(
            shlex.split(cleaned_cmd),
            stdout=log_file,
            stderr=subprocess.STDOUT,
            text=True,
            preexec_fn=os.setsid,  # own session so the whole group can be killed
            env=env,
        )

    _VLLM_PROCESSES.append(process)

    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            response = requests.get(f"http://localhost:{port}/health", timeout=2)
            if response.status_code == 200:
                logger.info(f"VLLM server ready on port {port}")
                return process
        except requests.RequestException:
            pass

        if process.poll() is not None:
            # Stop tracking the dead process before raising.
            if process in _VLLM_PROCESSES:
                _VLLM_PROCESSES.remove(process)
            # stderr is redirected into the log file, so communicate()
            # would yield (None, None); read the log tail for diagnostics.
            try:
                with open(log_path) as log_file:
                    tail = log_file.read()[-500:]
            except OSError:
                tail = "<log unavailable>"
            raise RuntimeError(
                f"VLLM server terminated unexpectedly. "
                f"Return code: {process.returncode}, "
                f"log tail: {tail}"
            )

        time.sleep(2)

    # Timed out: tear the process down and stop tracking it.
    process.terminate()
    try:
        process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        process.kill()

    if process in _VLLM_PROCESSES:
        _VLLM_PROCESSES.remove(process)

    raise RuntimeError(f"VLLM server failed to start within {timeout}s on port {port}")
36
139
 
37
140
 
38
- def _pick_least_used_port(ports: List[int]) -> int:
39
- global_lock_file = "/tmp/ports.lock"
40
- with open(global_lock_file, "w") as lock_file:
41
- fcntl.flock(lock_file, fcntl.LOCK_EX)
141
def _kill_vllm_on_port(port: int) -> bool:
    """Kill the VLLM server listening on *port*.

    First sweeps the module's tracked processes (`_VLLM_PROCESSES`),
    signalling whole process groups; if nothing matched and psutil is
    available, falls back to scanning all system processes for a VLLM
    command line mentioning the port.

    Args:
        port: TCP port the target server was started with (--port).

    Returns:
        True if a process was killed, False otherwise.
    """
    killed = False
    logger.info(f"Checking VLLM server on port {port}")

    processes_to_remove = []
    for process in _VLLM_PROCESSES:
        try:
            if process.poll() is None:
                killed_process = False
                if HAS_PSUTIL:
                    # Match the port against the live command line so only
                    # the server on *this* port is targeted.
                    try:
                        proc = psutil.Process(process.pid)
                        cmdline = ' '.join(proc.cmdline())
                        if f'--port {port}' in cmdline or f'--port={port}' in cmdline:
                            logger.info(f"Killing tracked VLLM process {process.pid} on port {port}")
                            # SIGTERM the whole process group (server was
                            # started with os.setsid), escalate to SIGKILL.
                            os.killpg(os.getpgid(process.pid), signal.SIGTERM)
                            try:
                                process.wait(timeout=5)
                            except subprocess.TimeoutExpired:
                                os.killpg(os.getpgid(process.pid), signal.SIGKILL)
                                process.wait()
                            killed = True
                            killed_process = True
                    except (psutil.NoSuchProcess, psutil.AccessDenied):
                        pass

                # NOTE(review): this fallback kills every tracked live
                # process regardless of its port (no cmdline check is
                # possible without psutil, and it also runs when the
                # psutil path simply didn't match) — confirm intended.
                if not HAS_PSUTIL or not killed_process:
                    logger.info(f"Killing tracked VLLM process {process.pid}")
                    try:
                        os.killpg(os.getpgid(process.pid), signal.SIGTERM)
                        try:
                            process.wait(timeout=5)
                        except subprocess.TimeoutExpired:
                            os.killpg(os.getpgid(process.pid), signal.SIGKILL)
                            process.wait()
                        killed = True
                    except (ProcessLookupError, OSError):
                        pass

                processes_to_remove.append(process)
            else:
                # Already exited: just drop it from tracking.
                processes_to_remove.append(process)
        except (ProcessLookupError, OSError):
            processes_to_remove.append(process)

    # Prune everything handled above from the tracking list.
    for process in processes_to_remove:
        if process in _VLLM_PROCESSES:
            _VLLM_PROCESSES.remove(process)

    # Fallback: hunt for an untracked VLLM process via a system-wide scan.
    if not killed and HAS_PSUTIL:
        try:
            for proc in psutil.process_iter(['pid', 'cmdline']):
                try:
                    cmdline = ' '.join(proc.info['cmdline'] or [])
                    if ('vllm' in cmdline.lower() and
                        (f'--port {port}' in cmdline or f'--port={port}' in cmdline)):
                        logger.info(f"Killing untracked VLLM process {proc.info['pid']} on port {port}")
                        proc.terminate()
                        try:
                            proc.wait(timeout=5)
                        except psutil.TimeoutExpired:
                            proc.kill()
                        killed = True
                        break
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    continue
        except Exception as e:
            logger.warning(f"Error searching processes on port {port}: {e}")

    if killed:
        logger.info(f"Killed VLLM server on port {port}")
        # Give the OS a moment to release the port before any restart.
        time.sleep(2)
    else:
        logger.info(f"No VLLM server on port {port}")

    return killed
218
+
219
+
220
def stop_vllm_process(process: subprocess.Popen, wait_timeout: int = 10) -> None:
    """Terminate a tracked VLLM process and remove it from tracking.

    Sends SIGTERM to the process group, escalating to SIGKILL if the
    process has not exited within ``wait_timeout`` seconds. The process
    is always dropped from the tracking list, even on failure.
    """
    logger.info(f"Stopping VLLM process {process.pid}")
    try:
        if process.poll() is not None:
            logger.info("VLLM process already terminated")
        else:
            # The server was launched in its own session, so signal the
            # whole group to take its children down with it.
            pgid = os.getpgid(process.pid)
            os.killpg(pgid, signal.SIGTERM)
            try:
                process.wait(timeout=wait_timeout)
            except subprocess.TimeoutExpired:
                logger.warning("VLLM process didn't stop gracefully, forcing kill")
                os.killpg(pgid, signal.SIGKILL)
                process.wait()
            else:
                logger.info("VLLM process stopped gracefully")
    except (ProcessLookupError, OSError) as exc:
        logger.warning(f"Process may have already terminated: {exc}")
    finally:
        if process in _VLLM_PROCESSES:
            _VLLM_PROCESSES.remove(process)
240
+
241
+
242
def kill_all_vllm_processes() -> int:
    """Kill every tracked VLLM process.

    Already-dead entries are simply dropped from the tracking list.

    Returns:
        Number of live processes that were stopped.
    """
    killed_count = 0
    # Iterate over a snapshot: stop_vllm_process mutates the list.
    for process in list(_VLLM_PROCESSES):
        if process.poll() is not None:
            _VLLM_PROCESSES.remove(process)
            continue
        logger.info(f"Killing VLLM process with PID {process.pid}")
        stop_vllm_process(process, wait_timeout=5)
        killed_count += 1
    logger.info(f"Killed {killed_count} VLLM processes")
    return killed_count
254
+
255
+
256
def _is_server_running(port: int) -> bool:
    """Return True if a server answers /health with 200 on *port*."""
    try:
        health = requests.get(f"http://localhost:{port}/health", timeout=2)
    except requests.RequestException:
        return False
    return health.status_code == 200
263
+
264
+
265
def get_base_client(
    client=None,
    cache: bool = True,
    api_key="abc",
    vllm_cmd=None,
    vllm_process=None
) -> OpenAI:
    """Build an OpenAI client from flexible inputs.

    ``client`` may be an existing OpenAI client (returned as-is), a
    localhost port (int), a base URL (str), or None — in which case the
    port is derived from ``vllm_cmd`` when given, else library defaults
    apply. With ``cache=True`` a caching MOpenAI client is constructed
    instead of a plain OpenAI one. ``vllm_process`` is accepted for
    interface compatibility and is not used here.

    Raises:
        ValueError: If ``client`` is of an unsupported type.
    """
    from llm_utils import MOpenAI

    # Caching wrapper unless explicitly disabled.
    open_ai_class = MOpenAI if cache else OpenAI

    if isinstance(client, OpenAI):
        return client
    if isinstance(client, int):
        return open_ai_class(base_url=f"http://localhost:{client}/v1", api_key=api_key)
    if isinstance(client, str):
        return open_ai_class(base_url=client, api_key=api_key)
    if client is None:
        if vllm_cmd is None:
            return open_ai_class()
        # Strip env assignments before looking for --port in the command.
        _, cleaned_cmd = _parse_env_vars_from_cmd(vllm_cmd)
        port = _extract_port_from_vllm_cmd(cleaned_cmd)
        return open_ai_class(base_url=f"http://localhost:{port}/v1", api_key=api_key)
    raise ValueError("Invalid client type. Must be OpenAI, port (int), base_url (str), or None.")
293
+
294
+
295
+ def _is_lora_path(path: str) -> bool:
296
+ """Check if path is LoRA adapter directory."""
297
+ if not os.path.isdir(path):
298
+ return False
299
+ adapter_config_path = os.path.join(path, 'adapter_config.json')
300
+ return os.path.isfile(adapter_config_path)
301
+
302
+
303
def _get_port_from_client(client: OpenAI) -> Optional[int]:
    """Extract the localhost port from a client's base_url, if any.

    Returns None when the client has no base_url, the URL is not a
    localhost address, or the port segment is not a valid integer.
    """
    base_url = str(getattr(client, 'base_url', None) or '')
    if 'localhost:' not in base_url:
        return None
    try:
        port_part = base_url.split('localhost:')[1].split('/')[0]
        return int(port_part)
    except (IndexError, ValueError):
        return None
314
+
315
+
316
def _load_lora_adapter(lora_path: str, port: int) -> str:
    """Register a LoRA adapter with the local VLLM server.

    The adapter name is the last path component (parent directory name
    when the path ends in a separator).

    Returns:
        The adapter name used for registration.

    Raises:
        requests.HTTPError: If the server rejects the load request.
    """
    trimmed = lora_path.rstrip('/\\')
    lora_name = os.path.basename(trimmed) or os.path.basename(os.path.dirname(lora_path))

    response = requests.post(
        f'http://localhost:{port}/v1/load_lora_adapter',
        headers={'accept': 'application/json', 'Content-Type': 'application/json'},
        json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)},
    )
    response.raise_for_status()
    return lora_name
329
+
330
+
331
def _unload_lora_adapter(lora_path: str, port: int) -> None:
    """Best-effort unload of a LoRA adapter from the local VLLM server.

    Request failures are logged as warnings and never raised, so callers
    can unload during cleanup without extra error handling.
    """
    try:
        trimmed = lora_path.rstrip('/\\')
        lora_name = os.path.basename(trimmed) or os.path.basename(os.path.dirname(lora_path))

        response = requests.post(
            f'http://localhost:{port}/v1/unload_lora_adapter',
            headers={'accept': 'application/json', 'Content-Type': 'application/json'},
            json={"lora_name": lora_name, "lora_int_id": 0},
        )
        response.raise_for_status()
    except requests.RequestException as e:
        logger.warning(f"Error unloading LoRA adapter: {str(e)[:100]}")
@@ -1,15 +1,17 @@
1
1
  # ray_multi_process.py
2
- import time, os, pickle, uuid, datetime, multiprocessing
3
2
  import datetime
4
3
  import os
5
4
  import pickle
5
+ import threading
6
6
  import time
7
7
  import uuid
8
8
  from pathlib import Path
9
- from typing import Any, Callable
10
- from tqdm import tqdm
9
+ from typing import Any, Callable, Iterable
10
+
11
11
  import psutil
12
- import threading
12
+ from fastcore.parallel import parallel
13
+ from tqdm import tqdm
14
+
13
15
  ray: Any
14
16
  try:
15
17
  import ray as ray # type: ignore
@@ -17,11 +19,75 @@ try:
17
19
  except Exception: # pragma: no cover
18
20
  ray = None # type: ignore
19
21
  _HAS_RAY = False
20
- from typing import Any, Callable, Iterable
21
22
 
22
- import ray
23
- from fastcore.parallel import parallel
24
- from tqdm import tqdm
23
+
24
# ─── global tracking ──────────────────────────────────────────

# Module-level registry of worker processes spawned via this module,
# guarded by a lock since trackers may run from multiple threads.
SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
_SPEEDY_PROCESSES_LOCK = threading.Lock()
29
+
30
+
31
+ def _prune_dead_processes() -> None:
32
+ """Remove dead processes from tracking list."""
33
+ with _SPEEDY_PROCESSES_LOCK:
34
+ SPEEDY_RUNNING_PROCESSES[:] = [p for p in SPEEDY_RUNNING_PROCESSES if p.is_running()]
35
+
36
+
37
+ def _track_processes(processes: list[psutil.Process]) -> None:
38
+ """Add processes to global tracking list."""
39
+ if not processes:
40
+ return
41
+ with _SPEEDY_PROCESSES_LOCK:
42
+ living = [p for p in SPEEDY_RUNNING_PROCESSES if p.is_running()]
43
+ for candidate in processes:
44
+ if not candidate.is_running():
45
+ continue
46
+ if any(existing.pid == candidate.pid for existing in living):
47
+ continue
48
+ living.append(candidate)
49
+ SPEEDY_RUNNING_PROCESSES[:] = living
50
+
51
+
52
+ def _track_ray_processes() -> None:
53
+ """Track Ray worker processes when Ray is initialized."""
54
+ if not _HAS_RAY or not ray.is_initialized():
55
+ return
56
+ try:
57
+ # Get Ray worker processes
58
+ current_pid = os.getpid()
59
+ parent = psutil.Process(current_pid)
60
+ ray_processes = []
61
+ for child in parent.children(recursive=True):
62
+ try:
63
+ if 'ray' in child.name().lower() or 'worker' in child.name().lower():
64
+ ray_processes.append(child)
65
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
66
+ continue
67
+ _track_processes(ray_processes)
68
+ except Exception:
69
+ # Don't fail if process tracking fails
70
+ pass
71
+
72
+
73
+ def _track_multiprocessing_processes() -> None:
74
+ """Track multiprocessing worker processes."""
75
+ try:
76
+ # Find recently created child processes that might be multiprocessing workers
77
+ current_pid = os.getpid()
78
+ parent = psutil.Process(current_pid)
79
+ new_processes = []
80
+ for child in parent.children(recursive=False): # Only direct children
81
+ try:
82
+ # Basic heuristic: if it's a recent child process, it might be a worker
83
+ if time.time() - child.create_time() < 5: # Created within last 5 seconds
84
+ new_processes.append(child)
85
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
86
+ continue
87
+ _track_processes(new_processes)
88
+ except Exception:
89
+ # Don't fail if process tracking fails
90
+ pass
25
91
 
26
92
 
27
93
  # ─── cache helpers ──────────────────────────────────────────
@@ -68,6 +134,7 @@ def ensure_ray(workers: int, pbar: tqdm | None = None):
68
134
  t0 = time.time()
69
135
  ray.init(num_cpus=workers, ignore_reinit_error=True)
70
136
  took = time.time() - t0
137
+ _track_ray_processes() # Track Ray worker processes
71
138
  if pbar:
72
139
  pbar.set_postfix_str(f"ray.init {workers} took {took:.2f}s")
73
140
  RAY_WORKER = workers
@@ -87,11 +154,6 @@ def multi_process(
87
154
  progress: bool = True,
88
155
  # backend: str = "ray", # "seq", "ray", or "fastcore"
89
156
  backend: Literal["seq", "ray", "mp", "threadpool", "safe"] = "mp",
90
- # Additional optional knobs (accepted for compatibility)
91
- batch: int | None = None,
92
- ordered: bool | None = None,
93
- process_update_interval: int | None = None,
94
- stop_on_error: bool | None = None,
95
157
  **func_kwargs: Any,
96
158
  ) -> list[Any]:
97
159
  """
@@ -171,6 +233,8 @@ def multi_process(
171
233
  results = parallel(
172
234
  f_wrapped, items, n_workers=workers, progress=progress, threadpool=False
173
235
  )
236
+ _track_multiprocessing_processes() # Track multiprocessing workers
237
+ _prune_dead_processes() # Clean up dead processes
174
238
  return list(results)
175
239
  if backend == "threadpool":
176
240
  results = parallel(
@@ -180,8 +244,17 @@ def multi_process(
180
244
  if backend == "safe":
181
245
  # Completely safe backend for tests - no multiprocessing, no external progress bars
182
246
  import concurrent.futures
183
- with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
184
- results = list(executor.map(f_wrapped, items))
247
+ # Import thread tracking from thread module
248
+ try:
249
+ from .thread import _track_executor_threads, _prune_dead_threads
250
+ with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
251
+ _track_executor_threads(executor) # Track threads
252
+ results = list(executor.map(f_wrapped, items))
253
+ _prune_dead_threads() # Clean up dead threads
254
+ except ImportError:
255
+ # Fallback if thread module not available
256
+ with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
257
+ results = list(executor.map(f_wrapped, items))
185
258
  return results
186
259
 
187
260
  raise ValueError(f"Unsupported backend: {backend!r}")
@@ -190,12 +263,24 @@ def multi_process(
190
263
 
191
264
  def cleanup_phantom_workers():
192
265
  """
193
- Kill all child processes (phantom workers) without killing the Jupyter kernel itself.
266
+ Kill all tracked processes and threads (phantom workers) without killing the Jupyter kernel itself.
194
267
  Also lists non-daemon threads that remain.
195
268
  """
196
- parent = psutil.Process(os.getpid())
269
+ # Clean up tracked processes first
270
+ _prune_dead_processes()
271
+ killed_processes = 0
272
+ with _SPEEDY_PROCESSES_LOCK:
273
+ for process in SPEEDY_RUNNING_PROCESSES[:]: # Copy to avoid modification during iteration
274
+ try:
275
+ print(f"🔪 Killing tracked process {process.pid} ({process.name()})")
276
+ process.kill()
277
+ killed_processes += 1
278
+ except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
279
+ print(f"⚠️ Could not kill process {process.pid}: {e}")
280
+ SPEEDY_RUNNING_PROCESSES.clear()
197
281
 
198
- # Kill only children, never the current process
282
+ # Also kill any remaining child processes (fallback)
283
+ parent = psutil.Process(os.getpid())
199
284
  for child in parent.children(recursive=True):
200
285
  try:
201
286
  print(f"🔪 Killing child process {child.pid} ({child.name()})")
@@ -203,14 +288,29 @@ def cleanup_phantom_workers():
203
288
  except psutil.NoSuchProcess:
204
289
  pass
205
290
 
206
- # Report stray threads (can't hard-kill them in Python)
207
- for t in threading.enumerate():
208
- if t is threading.current_thread():
209
- continue
210
- if not t.daemon:
211
- print(f"⚠️ Thread {t.name} is still running (cannot be force-killed).")
291
+ # Try to clean up threads using thread module functions if available
292
+ try:
293
+ from .thread import SPEEDY_RUNNING_THREADS, kill_all_thread, _prune_dead_threads
294
+ _prune_dead_threads()
295
+ killed_threads = kill_all_thread()
296
+ if killed_threads > 0:
297
+ print(f"🔪 Killed {killed_threads} tracked threads")
298
+ except ImportError:
299
+ # Fallback: just report stray threads
300
+ for t in threading.enumerate():
301
+ if t is threading.current_thread():
302
+ continue
303
+ if not t.daemon:
304
+ print(f"⚠️ Thread {t.name} is still running (cannot be force-killed).")
212
305
 
213
- print("✅ Cleaned up child processes (kernel untouched).")
306
+ print(f"✅ Cleaned up {killed_processes} tracked processes and child processes (kernel untouched).")
214
307
 
215
308
  # Usage: run this anytime after cancelling a cell
216
309
 
310
+
311
# Explicit public API; remaining helpers are module-private by convention.
__all__ = [
    "SPEEDY_RUNNING_PROCESSES",
    "multi_process",
    "cleanup_phantom_workers",
]
316
+