speedy-utils 1.1.46__py3-none-any.whl → 1.1.47__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,879 @@
1
+ """
2
+ Common utilities shared across multi_process backends.
3
+
4
+ Includes:
5
+ - Error formatting and logging
6
+ - Log gating (stdout/stderr control)
7
+ - Cache helpers
8
+ - Process/thread tracking
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import contextlib
13
+ import linecache
14
+ import os
15
+ import pickle
16
+ import sys
17
+ import tempfile
18
+ import threading
19
+ import time
20
+ import uuid
21
+ from pathlib import Path
22
+ from typing import TYPE_CHECKING, Any, Callable, Literal
23
+
24
+ import psutil
25
+
26
+ if TYPE_CHECKING:
27
+ from tqdm import tqdm
28
+
29
# ─── error handler types ────────────────────────────────────────
# Allowed values for a backend's error-handling option. What each value does
# is implemented by the backends, not in this module (presumably 'log' is the
# mode that uses ErrorStats below — TODO confirm against the backends).
ErrorHandlerType = Literal['raise', 'ignore', 'log']
31
+
32
+
33
class ErrorStats:
    """Thread-safe error statistics tracker.

    Counts successful and failed items. When ``write_logs`` is true, it also
    writes one log file per failed item, up to ``max_error_files`` of them,
    into a fresh ``.cache/speedy_utils/error_logs/<func_name>_run_<N>``
    directory. The first error it records is printed to stderr.
    """

    def __init__(
        self,
        func_name: str,
        max_error_files: int = 100,
        write_logs: bool = True,
    ):
        # Guards the counters and the "first error shown" flag.
        self._lock = threading.Lock()
        self._success_count = 0
        self._error_count = 0
        self._first_error_shown = False
        self._max_error_files = max_error_files
        self._write_logs = write_logs
        # The run-directory path is always computed (which creates the base
        # directory), but the run directory itself only when logs are written.
        self._error_log_dir = self._get_error_log_dir(func_name)
        if write_logs:
            self._error_log_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _get_error_log_dir(func_name: str) -> Path:
        """Return the next unused ``<func_name>_run_<N>`` directory path.

        N is one more than the largest existing run number for *func_name*,
        or 1 if there is none. Only the base directory is created here.
        """
        base_dir = Path('.cache/speedy_utils/error_logs')
        base_dir.mkdir(parents=True, exist_ok=True)

        # Match on the exact prefix and parse only the trailing counter.
        # Splitting on '_run_' miscounts names that themselves contain it
        # (e.g. 'my_run_job'). A glob pattern would also misbehave on names
        # with glob metacharacters.
        prefix = f'{func_name}_run_'
        last_run = 0
        for entry in base_dir.iterdir():
            if not entry.name.startswith(prefix):
                continue
            try:
                last_run = max(last_run, int(entry.name[len(prefix):]))
            except ValueError:
                continue  # e.g. 'job_run_backup' is not a numbered run
        return base_dir / f'{prefix}{last_run + 1}'

    def record_success(self) -> None:
        """Increment the success counter."""
        with self._lock:
            self._success_count += 1

    def record_error(
        self,
        idx: int,
        error: Exception,
        input_value: Any,
        func_name: str,
        ray_task_error: Exception | None = None,
    ) -> str | None:
        """
        Record an error and write it to a log file (within the file budget).

        The first recorded error is also printed to stderr.

        Args:
            idx: Index of the failed item; used as the log file name.
            error: The exception raised for the item.
            input_value: The item's input, included in the log.
            func_name: Name of the mapped function, included in the log.
            ray_task_error: Optional RayTaskError for fallback frame extraction
                when the native traceback is unavailable.

        Returns:
            The log file path if one was written, None otherwise.
        """
        with self._lock:
            self._error_count += 1
            should_show_first = not self._first_error_shown
            if should_show_first:
                self._first_error_shown = True
            should_write = (
                self._write_logs and self._error_count <= self._max_error_files
            )

        log_path = None
        if should_write:
            try:
                log_path = self._write_error_log(
                    idx, error, input_value, func_name, ray_task_error
                )
            except OSError:
                # Failing to write a diagnostic log must not mask the item's
                # real error; continue without a log path.
                log_path = None

        if should_show_first:
            self._print_first_error(
                error, input_value, func_name, log_path, ray_task_error
            )

        return log_path

    def _write_error_log(
        self,
        idx: int,
        error: Exception,
        input_value: Any,
        func_name: str,
        ray_task_error: Exception | None = None,
    ) -> str:
        """Write error details for item *idx* to ``<run dir>/<idx>.log``.

        Returns:
            The path of the written log file, as a string.
        """
        import json
        from io import StringIO

        from rich.console import Console

        log_path = self._error_log_dir / f'{idx}.log'

        output = StringIO()
        console = Console(file=output, width=120, no_color=False)

        # Format traceback using unified extraction
        tb_lines = _format_traceback_lines(
            _extract_frames(error, ray_task_error),
            include_locals=False,
        )

        console.print(f'{"=" * 60}')
        console.print(f'Error at index: {idx}')
        console.print(f'Function: {func_name}', markup=False, emoji=False)
        console.print(f'Error Type: {type(error).__name__}')
        # Error messages and inputs are arbitrary user data. Disable Rich
        # markup/emoji parsing so text like "[/x]" cannot raise MarkupError
        # and "[b]" or ":smile:" are written literally.
        console.print(f'Error Message: {error}', markup=False, emoji=False)
        console.print(f'{"=" * 60}')
        console.print('')
        console.print('Input:')
        console.print('-' * 40)
        try:
            input_text = json.dumps(input_value, indent=2)
        except Exception:
            input_text = repr(input_value)
        console.print(input_text, markup=False, emoji=False)
        console.print('')
        console.print('Traceback:')
        console.print('-' * 40)
        # Traceback lines are intentionally Rich markup; render them as such.
        for line in tb_lines:
            console.print(line)

        # Explicit UTF-8: the traceback uses non-ASCII glyphs (❱, │) that a
        # platform default encoding such as cp1252 cannot encode.
        log_path.write_text(output.getvalue(), encoding='utf-8')

        return str(log_path)

    def _print_first_error(
        self,
        error: Exception,
        input_value: Any,
        func_name: str,
        log_path: str | None,
        ray_task_error: Exception | None = None,
    ) -> None:
        """Print the first error to stderr (Rich panel, plain-text fallback)."""
        try:
            from rich.console import Console
            from rich.markup import escape
            from rich.panel import Panel
            console = Console(stderr=True)

            # Use unified frame extraction
            tb_lines = _format_traceback_lines(
                _extract_frames(error, ray_task_error),
                include_locals=False,
            )

            console.print()
            console.print(
                Panel(
                    '\n'.join(tb_lines),
                    title=(
                        '[bold red]First Error '
                        '(continuing with remaining items)[/bold red]'
                    ),
                    border_style='yellow',
                    expand=False,
                )
            )
            # Escape user-controlled text so brackets are shown literally.
            console.print(
                f'[bold red]{type(error).__name__}[/bold red]: '
                f'{escape(str(error))}'
            )
            if log_path:
                console.print(f'[dim]Error log: {escape(log_path)}[/dim]')
            console.print()
        except ImportError:
            # Fallback to plain print
            print('\n--- First Error ---', file=sys.stderr)
            print(f'{type(error).__name__}: {error}', file=sys.stderr)
            if log_path:
                print(f'Error log: {log_path}', file=sys.stderr)
            print('', file=sys.stderr)

    @property
    def success_count(self) -> int:
        """Number of items recorded as successful."""
        with self._lock:
            return self._success_count

    @property
    def error_count(self) -> int:
        """Number of items recorded as failed."""
        with self._lock:
            return self._error_count

    def get_postfix_dict(self) -> dict[str, int]:
        """Get a ``{'ok': ..., 'err': ...}`` dict for a progress-bar postfix."""
        with self._lock:
            return {'ok': self._success_count, 'err': self._error_count}
225
+
226
+
227
+ # ─── Traceback formatting utilities ─────────────────────────────
228
+
229
+
230
+ def _should_skip_frame(filepath: str) -> bool:
231
+ """Check if a frame should be filtered from traceback display."""
232
+ skip_patterns = [
233
+ 'ray/_private',
234
+ 'ray/worker',
235
+ 'site-packages/ray',
236
+ 'speedy_utils/multi_worker',
237
+ 'concurrent/futures',
238
+ 'multiprocessing/',
239
+ 'fastcore/parallel',
240
+ 'fastcore/foundation',
241
+ 'fastcore/basics',
242
+ 'site-packages/fastcore',
243
+ '/threading.py',
244
+ '/concurrent/',
245
+ ]
246
+ return any(skip in filepath for skip in skip_patterns)
247
+
248
+
249
+ def _should_show_local(name: str, value: object) -> bool:
250
+ """Check if a local variable should be displayed in traceback."""
251
+ import types
252
+
253
+ # Skip dunder variables
254
+ if name.startswith('__') and name.endswith('__'):
255
+ return False
256
+
257
+ # Skip modules
258
+ if isinstance(value, types.ModuleType):
259
+ return False
260
+
261
+ # Skip type objects and classes
262
+ if isinstance(value, type):
263
+ return False
264
+
265
+ # Skip functions and methods
266
+ if isinstance(
267
+ value,
268
+ (types.FunctionType, types.MethodType, types.BuiltinFunctionType)
269
+ ):
270
+ return False
271
+
272
+ # Skip common typing aliases
273
+ value_str = str(value)
274
+ if value_str.startswith('typing.'):
275
+ return False
276
+
277
+ # Skip large objects that would clutter output
278
+ skip_markers = ['module', 'function', 'method', 'built-in']
279
+ if value_str.startswith('<') and any(x in value_str for x in skip_markers):
280
+ return False
281
+
282
+ return True
283
+
284
+
285
def _format_locals(frame_locals: dict) -> list[str]:
    """Format local variables as Rich-markup lines, filtering out noise.

    Each kept variable (see ``_should_show_local``) is rendered with Rich's
    ``Pretty`` at width 60 and truncated to 100 characters. The repr is
    escaped so brackets in the value are not parsed as markup. Returns an
    empty list when nothing remains after filtering.
    """
    from io import StringIO

    from rich.console import Console
    from rich.markup import escape
    from rich.pretty import Pretty

    max_len = 100

    def _truncate(text: str) -> str:
        # Keep the panel compact: cut long values and mark with '...'
        return text if len(text) <= max_len else text[: max_len - 3] + '...'

    # Filter locals
    clean_locals = {
        k: v for k, v in frame_locals.items() if _should_show_local(k, v)
    }

    if not clean_locals:
        return []

    lines = ['[dim]╭─ locals ─╮[/dim]']

    for name, value in clean_locals.items():
        # Use Rich's Pretty for nice formatting; fall back to repr().
        try:
            buffer = StringIO()
            Console(file=buffer, width=60).print(Pretty(value), end='')
            value_str = buffer.getvalue().strip()
        except Exception:
            try:
                value_str = repr(value)
            except Exception:
                # A broken __repr__ must not abort traceback rendering.
                value_str = f'<unrepresentable {type(value).__name__}>'

        # Escape so a value such as "[/x]" cannot raise MarkupError and
        # "[red]" is shown literally instead of restyling the line.
        lines.append(f'[dim]│[/dim] {name} = {escape(_truncate(value_str))}')

    lines.append('[dim]╰──────────╯[/dim]')
    return lines
321
+
322
+
323
def _format_frame_with_context(
    filepath: str,
    lineno: int,
    funcname: str,
    frame_locals: dict | None = None,
) -> list[str]:
    """Format one frame as Rich-markup lines.

    The output contains, in order:

    - a ``file:line in func`` header
    - up to 3 source lines either side of *lineno* (read via ``linecache``;
      blank or unavailable lines are omitted)
    - the formatted locals, when *frame_locals* is non-empty

    Source text, paths and function names are escaped so code such as
    ``data[i]`` or ``arr[b]`` is shown verbatim rather than parsed as Rich
    style tags.
    """
    from rich.markup import escape

    lines = [
        f'[cyan]{escape(filepath)}[/cyan]:[yellow]{lineno}[/yellow] '
        f'in [green]{escape(funcname)}[/green]',
        '',
    ]

    # Get context lines
    context_size = 3
    start_line = max(1, lineno - context_size)
    end_line = lineno + context_size + 1

    for line_num in range(start_line, end_line):
        line_text = linecache.getline(filepath, line_num).rstrip()
        if not line_text:
            continue
        num_str = str(line_num).rjust(4)
        safe_text = escape(line_text)
        if line_num == lineno:
            lines.append(f'[dim]{num_str}[/dim] [red]❱[/red] {safe_text}')
        else:
            lines.append(f'[dim]{num_str} │[/dim] {safe_text}')

    # Add locals if available
    if frame_locals:
        locals_lines = _format_locals(frame_locals)
        if locals_lines:
            lines.append('')
            lines.extend(locals_lines)

    lines.append('')
    return lines
361
+
362
+
363
def _format_traceback_lines(
    frames: list[tuple[str, int, str, dict]],
    *,
    caller_info: dict | None = None,
    include_locals: bool = True,
) -> list[str]:
    """Render a list of frames as Rich-markup lines.

    Args:
        frames: ``(filepath, lineno, funcname, locals)`` tuples, outermost first.
        caller_info: Optional dict with ``filename``/``lineno``/``function``
            keys. When given, it is rendered first and without locals.
        include_locals: When False, every frame's locals are suppressed.

    Returns:
        The concatenated lines produced for each frame.
    """
    rendered: list[str] = []

    if caller_info:
        rendered += _format_frame_with_context(
            caller_info['filename'],
            caller_info['lineno'],
            caller_info['function'],
            None,
        )

    for path, line_no, func, local_vars in frames:
        shown_locals = local_vars if include_locals else None
        rendered += _format_frame_with_context(path, line_no, func, shown_locals)

    return rendered
391
+
392
+
393
def _extract_frames_from_traceback(
    error: Exception,
) -> list[tuple[str, int, str, dict]]:
    """Walk *error*'s ``__traceback__`` and collect the non-framework frames.

    Returns:
        ``(filename, lineno, funcname, locals_copy)`` tuples, outermost call
        first. Frames matched by ``_should_skip_frame`` are omitted. Returns
        an empty list when there is no traceback.
    """
    collected: list[tuple[str, int, str, dict]] = []
    node = getattr(error, '__traceback__', None)
    while node is not None:
        code = node.tb_frame.f_code
        if not _should_skip_frame(code.co_filename):
            # Snapshot the locals so later frame mutation doesn't affect us
            snapshot = dict(node.tb_frame.f_locals)
            collected.append(
                (code.co_filename, node.tb_lineno, code.co_name, snapshot)
            )
        node = node.tb_next
    return collected
413
+
414
+
415
def _extract_frames_from_ray_error(
    ray_task_error: Exception,
) -> list[tuple[str, int, str, dict]]:
    """Parse ``File "...", line N, in func`` lines out of ``str(ray_task_error)``.

    Ray does not preserve frame locals, so each tuple carries an empty dict.
    Frames matched by ``_should_skip_frame`` are omitted.
    """
    import re

    frame_line = re.compile(r'\s*File "([^"]+)", line (\d+), in (.+)')
    result: list[tuple[str, int, str, dict]] = []

    for text_line in str(ray_task_error).split('\n'):
        hit = frame_line.match(text_line)
        if hit is None:
            continue
        path, line_no, func = hit.groups()
        if _should_skip_frame(path):
            continue
        result.append((path, int(line_no), func, {}))

    return result
435
+
436
+
437
def _extract_frames(
    error: Exception,
    ray_task_error: Exception | None = None,
) -> list[tuple[str, int, str, dict]]:
    """
    Extract user frames from a native exception, falling back to Ray's text.

    Frames come from ``error.__traceback__`` when any survive filtering.
    Only when that yields nothing and *ray_task_error* is given is the Ray
    error's string representation parsed instead.
    """
    native = _extract_frames_from_traceback(error)
    if native or ray_task_error is None:
        return native
    return _extract_frames_from_ray_error(ray_task_error)
456
+
457
+
458
def _display_formatted_error_and_exit(
    exc_type_name: str,
    exc_msg: str,
    frames: list[tuple[str, int, str, dict]],
    caller_info: dict | None,
    backend: str,
    pbar: tqdm | None = None,
) -> None:
    """Print a formatted worker error to stderr and exit the process.

    Sets ``RAY_IGNORE_UNHANDLED_ERRORS=1`` to suppress follow-up Ray error
    logs and closes *pbar* if given. When there are frames (or
    *caller_info*), it shows a Rich traceback panel that includes locals.
    Otherwise it prints only ``Type: message``. Falls back to plain text if
    rich cannot be imported.

    Raises:
        SystemExit: always, with exit code 1.
    """
    # Suppress additional error logs
    os.environ['RAY_IGNORE_UNHANDLED_ERRORS'] = '1'

    # Close progress bar cleanly if provided
    if pbar is not None:
        pbar.close()

    try:
        from rich.console import Console
        from rich.markup import escape
        from rich.panel import Panel
    except ImportError:
        # Plain-text fallback, mirroring ErrorStats._print_first_error.
        print('', file=sys.stderr)
        print(f'{exc_type_name}: {exc_msg}', file=sys.stderr)
        print('', file=sys.stderr)
    else:
        console = Console(stderr=True)
        # The message is arbitrary text: escape it so brackets are shown
        # literally instead of raising MarkupError or being eaten as tags.
        summary = f'[bold red]{escape(exc_type_name)}[/bold red]: {escape(exc_msg)}'

        console.print()
        if frames or caller_info:
            display_lines = _format_traceback_lines(
                frames,
                caller_info=caller_info,
                include_locals=True,
            )
            # Escape "[<backend>]": unescaped, Rich parses it as a style tag
            # and the backend label silently disappears from the title.
            title = (
                '[bold red]Traceback (most recent call last) '
                f'{escape(f"[{backend}]")}[/bold red]'
            )
            console.print(
                Panel(
                    '\n'.join(display_lines),
                    title=title,
                    border_style='red',
                    expand=False,
                )
            )
        console.print(summary)
        console.print()

    # Ensure output is flushed
    sys.stderr.flush()
    sys.stdout.flush()
    sys.exit(1)
510
+
511
+
512
def _exit_on_worker_error(
    error: Exception,
    pbar: tqdm | None = None,
    caller_info: dict | None = None,
    backend: str = 'unknown',
) -> None:
    """Show a filtered traceback for a worker exception, then exit(1).

    Frames come only from *error*'s own ``__traceback__``. Rendering, progress
    bar closing and the exit itself are done by
    ``_display_formatted_error_and_exit``.
    """
    user_frames = _extract_frames_from_traceback(error)
    _display_formatted_error_and_exit(
        type(error).__name__,
        str(error),
        user_frames,
        caller_info,
        backend,
        pbar=pbar,
    )
528
+
529
+
530
def _exit_on_ray_error(
    ray_task_error: Exception,
    pbar: tqdm | None = None,
    caller_info: dict | None = None,
) -> None:
    """Show a clean traceback for a RayTaskError, then exit(1).

    The underlying exception is taken from ``ray_task_error.cause``, else
    ``__cause__``. Frames come from that exception's traceback. If there are
    none, they are parsed from the Ray error's string form.
    """
    cause = getattr(ray_task_error, 'cause', None)
    if cause is None:
        cause = ray_task_error.__cause__

    if cause:
        exc_type_name = type(cause).__name__
        exc_msg = str(cause)
        frames = _extract_frames_from_traceback(cause)
    else:
        exc_type_name = 'Error'
        exc_msg = str(ray_task_error)
        frames = []

    if not frames:
        frames = _extract_frames_from_ray_error(ray_task_error)

    _display_formatted_error_and_exit(
        exc_type_name=exc_type_name,
        exc_msg=exc_msg,
        frames=frames,
        caller_info=caller_info,
        backend='ray',
        pbar=pbar,
    )
561
+
562
+
563
# ─── Process/thread tracking ────────────────────────────────────

# Worker processes registered via _track_processes(); pruned by
# _prune_dead_processes() and killed and cleared by cleanup_phantom_workers().
SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
# Held by every function in this module that reads or rewrites
# SPEEDY_RUNNING_PROCESSES.
_SPEEDY_PROCESSES_LOCK = threading.Lock()
567
+
568
+
569
def _prune_dead_processes() -> None:
    """Drop entries from SPEEDY_RUNNING_PROCESSES that are no longer running."""
    with _SPEEDY_PROCESSES_LOCK:
        still_alive = [
            proc for proc in SPEEDY_RUNNING_PROCESSES if proc.is_running()
        ]
        # Slice-assign so other holders of the list object see the update
        SPEEDY_RUNNING_PROCESSES[:] = still_alive
575
+
576
+
577
def _track_processes(processes: list[psutil.Process]) -> None:
    """Register running *processes* in SPEEDY_RUNNING_PROCESSES.

    Entries are de-duplicated by PID. Already-tracked processes that have
    exited are dropped in the same pass. Does nothing for an empty input.
    """
    if not processes:
        return
    with _SPEEDY_PROCESSES_LOCK:
        tracked = [proc for proc in SPEEDY_RUNNING_PROCESSES if proc.is_running()]
        known_pids = {proc.pid for proc in tracked}
        for proc in processes:
            if proc.is_running() and proc.pid not in known_pids:
                tracked.append(proc)
                known_pids.add(proc.pid)
        SPEEDY_RUNNING_PROCESSES[:] = tracked
590
+
591
+
592
def _track_multiprocessing_processes() -> None:
    """Best-effort: register direct children created in the last 5 seconds.

    Any failure is swallowed so that tracking never breaks the caller.
    """
    with contextlib.suppress(Exception):
        me = psutil.Process(os.getpid())
        recent_children = []
        for proc in me.children(recursive=False):
            # Children that vanish or are inaccessible are simply skipped
            with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
                if time.time() - proc.create_time() < 5:
                    recent_children.append(proc)
        _track_processes(recent_children)
610
+
611
+
612
def _track_ray_processes() -> None:
    """Best-effort: register descendants whose name contains 'ray' or 'worker'.

    Any failure is swallowed so that tracking never breaks the caller.
    """
    with contextlib.suppress(Exception):
        me = psutil.Process(os.getpid())
        matches = []
        for proc in me.children(recursive=True):
            with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
                proc_name = proc.name().lower()
                if 'ray' in proc_name or 'worker' in proc_name:
                    matches.append(proc)
        _track_processes(matches)
628
+
629
+
630
# ─── Log gating utilities ───────────────────────────────────────

# Module-level memo of _should_allow_worker_logs() decisions for mode
# 'first', keyed by str(gate_path). Repeated calls in the same process get
# the same answer without touching the gate file again.
_LOG_GATE_CACHE: dict[str, bool] = {}
633
+
634
+
635
+ def _should_allow_worker_logs(
636
+ mode: Literal['all', 'zero', 'first'],
637
+ gate_path: Path | None,
638
+ ) -> bool:
639
+ """Determine if current worker should emit logs for the given mode."""
640
+ if mode == 'all':
641
+ return True
642
+ if mode == 'zero':
643
+ return False
644
+ if mode == 'first':
645
+ if gate_path is None:
646
+ return True
647
+ key = str(gate_path)
648
+ cached = _LOG_GATE_CACHE.get(key)
649
+ if cached is not None:
650
+ return cached
651
+ gate_path.parent.mkdir(parents=True, exist_ok=True)
652
+ try:
653
+ fd = os.open(key, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
654
+ except FileExistsError:
655
+ allowed = False
656
+ else:
657
+ os.close(fd)
658
+ allowed = True
659
+ _LOG_GATE_CACHE[key] = allowed
660
+ return allowed
661
+ raise ValueError(f'Unsupported log mode: {mode!r}')
662
+
663
+
664
+ def _cleanup_log_gate(gate_path: Path | None) -> None:
665
+ """Remove the log gate file if it exists."""
666
+ if gate_path is None:
667
+ return
668
+ try:
669
+ gate_path.unlink(missing_ok=True)
670
+ except OSError:
671
+ pass
672
+
673
+
674
+ class _PrefixedWriter:
675
+ """Stream wrapper that prefixes each line with worker id."""
676
+
677
+ def __init__(self, stream, prefix: str):
678
+ self._stream = stream
679
+ self._prefix = prefix
680
+ self._at_line_start = True
681
+
682
+ def write(self, s):
683
+ if not s:
684
+ return 0
685
+ total = 0
686
+ for chunk in s.splitlines(True):
687
+ if self._at_line_start:
688
+ self._stream.write(self._prefix)
689
+ total += len(self._prefix)
690
+ self._stream.write(chunk)
691
+ total += len(chunk)
692
+ self._at_line_start = chunk.endswith('\n')
693
+ return total
694
+
695
+ def flush(self):
696
+ self._stream.flush()
697
+
698
+
699
def _call_with_log_control(
    func: Callable,
    x: Any,
    func_kwargs: dict[str, Any],
    log_mode: Literal['all', 'zero', 'first'],
    gate_path: Path | None,
):
    """Call ``func(x, **func_kwargs)`` with stdout/stderr gated by *log_mode*.

    When the worker is allowed to log, both streams go to the current
    ``sys.stderr``, prefixed with ``[worker-<pid>] ``. Otherwise both are sent
    to ``os.devnull``. Returns whatever *func* returns.
    """
    if not _should_allow_worker_logs(log_mode, gate_path):
        with (
            open(os.devnull, 'w') as sink,
            contextlib.redirect_stdout(sink),
            contextlib.redirect_stderr(sink),
        ):
            return func(x, **func_kwargs)

    # stderr (not stdout) keeps worker chatter from clobbering tqdm bars
    writer = _PrefixedWriter(sys.stderr, f'[worker-{os.getpid()}] ')
    with contextlib.redirect_stdout(writer), contextlib.redirect_stderr(writer):
        return func(x, **func_kwargs)
721
+
722
+
723
+ # ─── Cache helpers ──────────────────────────────────────────────
724
+
725
+
726
+ def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
727
+ """Build cache dir with function name + timestamp."""
728
+ import datetime
729
+ func_name = getattr(func, '__name__', 'func')
730
+ now = datetime.datetime.now()
731
+ stamp = now.strftime('%m%d_%Hh%Mm%Ss')
732
+ run_id = f'{func_name}_{stamp}_{uuid.uuid4().hex[:6]}'
733
+ path = Path('.cache') / run_id
734
+ path.mkdir(parents=True, exist_ok=True)
735
+ return path
736
+
737
+
738
def wrap_dump(
    func: Callable,
    cache_dir: Path | None,
    dump_in_thread: bool = True,
):
    """Wrap *func* so each result is pickled to ``cache_dir/<uuid>.pkl``.

    Returns *func* unchanged when *cache_dir* is None. Otherwise the wrapper
    calls *func* and returns the pickle path as a string instead of the
    result. With *dump_in_thread*, the write happens in a background thread,
    so the file may not exist yet when the wrapper returns. The wrapper keeps
    *func*'s name, docstring and ``__wrapped__`` via ``functools.wraps``.
    """
    if cache_dir is None:
        return func

    import functools

    @functools.wraps(func)
    def wrapped(x, *args, **kwargs):
        res = func(x, *args, **kwargs)
        p = cache_dir / f'{uuid.uuid4().hex}.pkl'

        def save():
            with open(p, 'wb') as fh:
                pickle.dump(res, fh)

        if dump_in_thread:
            thread = threading.Thread(target=save)
            # Crude back-pressure: wait while the process has >16 threads.
            while threading.active_count() > 16:
                time.sleep(0.1)
            thread.start()
        else:
            save()
        return str(p)

    return wrapped
765
+
766
+
767
+ # ─── Log gate path helper ───────────────────────────────────────
768
+
769
+
770
def create_log_gate_path(
    log_worker: Literal['zero', 'first', 'all'],
) -> Path | None:
    """Return a unique temp-dir gate path for ``'first'``; None otherwise.

    The path has the form
    ``<tempdir>/speedy_utils_log_gate_<pid>_<uuid hex>.gate``. It is not
    created here.

    Raises:
        ValueError: if *log_worker* is not ``'zero'``, ``'first'`` or ``'all'``.
    """
    if log_worker in ('zero', 'all'):
        return None
    if log_worker != 'first':
        raise ValueError(f'Unsupported log_worker: {log_worker!r}')
    gate_name = f'speedy_utils_log_gate_{os.getpid()}_{uuid.uuid4().hex}.gate'
    return Path(tempfile.gettempdir()) / gate_name
782
+
783
+
784
+ # ─── Cleanup utility ────────────────────────────────────────────
785
+
786
+
787
def cleanup_phantom_workers() -> None:
    """
    Kill all tracked processes and threads (phantom workers).

    Steps:
    1. Kill every process in SPEEDY_RUNNING_PROCESSES, then clear the list.
    2. Kill any remaining descendant processes of the current process.
    3. If the sibling ``.thread`` module can be imported, kill its tracked
       threads. Otherwise only list the non-daemon threads that remain.

    A failure to kill an individual process (it has already exited, or
    permission is denied) is reported and skipped rather than aborting the
    cleanup.
    """
    # Clean up tracked processes first
    _prune_dead_processes()
    killed_processes = 0
    with _SPEEDY_PROCESSES_LOCK:
        for process in SPEEDY_RUNNING_PROCESSES[:]:
            try:
                print(
                    f'🔪 Killing tracked process {process.pid} '
                    f'({process.name()})'
                )
                process.kill()
                killed_processes += 1
            except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
                print(f'⚠️ Could not kill process {process.pid}: {e}')
        SPEEDY_RUNNING_PROCESSES.clear()

    # Also kill any remaining child processes (fallback)
    parent = psutil.Process(os.getpid())
    for child in parent.children(recursive=True):
        try:
            print(f'🔪 Killing child process {child.pid} ({child.name()})')
            child.kill()
        except psutil.NoSuchProcess:
            # Already exited (possibly killed in the loop above).
            pass
        except psutil.AccessDenied as e:
            # Report and continue so the remaining children and threads
            # are still cleaned up.
            print(f'⚠️ Could not kill process {child.pid}: {e}')

    # Try to clean up threads using thread module functions if available
    try:
        from .thread import (
            SPEEDY_RUNNING_THREADS,
            _prune_dead_threads,
            kill_all_thread,
        )

        _prune_dead_threads()
        killed_threads = kill_all_thread()
        if killed_threads > 0:
            print(f'🔪 Killed {killed_threads} tracked threads')
    except ImportError:
        # Fallback: just report stray threads
        for t in threading.enumerate():
            if t is threading.current_thread():
                continue
            if not t.daemon:
                print(
                    f'⚠️ Thread {t.name} is still running '
                    f'(cannot be force-killed).'
                )

    print(
        f'✅ Cleaned up {killed_processes} tracked processes and '
        f'child processes (kernel untouched).'
    )
845
+
846
+
847
# Explicit export list. It deliberately includes underscore-prefixed helpers
# so that ``from <this module> import *`` exposes them (presumably to the
# multi_process backend modules — confirm against their imports).
__all__ = [
    # Types
    'ErrorHandlerType',
    'ErrorStats',
    # Process tracking globals
    'SPEEDY_RUNNING_PROCESSES',
    '_SPEEDY_PROCESSES_LOCK',
    # Error utilities
    '_should_skip_frame',
    '_format_traceback_lines',
    '_extract_frames_from_traceback',
    '_extract_frames_from_ray_error',
    '_display_formatted_error_and_exit',
    '_exit_on_worker_error',
    '_exit_on_ray_error',
    # Process tracking
    '_prune_dead_processes',
    '_track_processes',
    '_track_multiprocessing_processes',
    '_track_ray_processes',
    # Log gating
    '_LOG_GATE_CACHE',
    '_should_allow_worker_logs',
    '_cleanup_log_gate',
    '_PrefixedWriter',
    '_call_with_log_control',
    'create_log_gate_path',
    # Cache helpers
    '_build_cache_dir',
    'wrap_dump',
    # Cleanup
    'cleanup_phantom_workers',
]