mpiptop 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mpiptop.py ADDED
@@ -0,0 +1,1424 @@
1
+ #!/usr/bin/env python3
2
+ """mpiptop: visualize MPI python stacks across hosts using py-spy."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import colorsys
8
+ import dataclasses
9
+ import hashlib
10
+ import json
11
+ import os
12
+ import re
13
+ import shlex
14
+ import shutil
15
+ import signal
16
+ import socket
17
+ import subprocess
18
+ import sys
19
+ import termios
20
+ import time
21
+ import textwrap
22
+ import tty
23
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
24
+
25
+ from rich.console import Console
26
+ from rich.layout import Layout
27
+ from rich.live import Live
28
+ from rich.panel import Panel
29
+ from rich.table import Table
30
+ from rich.text import Text
31
+
32
+
33
@dataclasses.dataclass(frozen=True)
class Proc:
    """Immutable record for one process row parsed from ``ps`` output."""

    # Process id.
    pid: int
    # Parent process id.
    ppid: int
    # Full command line as reported by ps.
    args: str
38
+
39
+
40
@dataclasses.dataclass(frozen=True)
class RankInfo:
    """Immutable pairing of an MPI rank number with the host it runs on."""

    # MPI rank number.
    rank: int
    # Host name taken from the rankfile.
    host: str
44
+
45
+
46
@dataclasses.dataclass(frozen=True)
class ProgramSelector:
    """Immutable description of which python program the ranks are running."""

    # Module name when the program was launched with ``-m``; otherwise None.
    module: Optional[str]
    # Script token taken from the command line; None if nothing was found.
    script: Optional[str]
    # Command line rendered for display.
    display: str
51
+
52
+
53
@dataclasses.dataclass(frozen=True)
class State:
    """Immutable snapshot of the launcher, its rankfile, and the target program."""

    # Pid of the prterun/mpirun launcher process.
    prte_pid: int
    # Path of the rankfile used by the launcher.
    rankfile: str
    # Ranks parsed from the rankfile.
    ranks: List[RankInfo]
    # Selector identifying the python program to inspect.
    selector: ProgramSelector
59
+
60
+
61
@dataclasses.dataclass(frozen=True)
class RankProcess:
    """Immutable description of one rank's python process on a host."""

    # Process id on the host where the rank runs.
    pid: int
    # Space-joined command line of the process.
    cmdline: str
    # Resident set size in kB, or None when it could not be read.
    rss_kb: Optional[int]
    # Interpreter path for the process, when known.
    python_exe: Optional[str]
    # Selected environment variables captured from the process.
    env: Dict[str, str]
68
+
69
+
70
@dataclasses.dataclass(frozen=True)
class ThreadBlock:
    """Immutable record of one thread section of a py-spy dump."""

    # The ``Thread ...`` header line.
    header: str
    # Lines that followed the header, in dump order.
    stack: List[str]
74
+
75
+
76
@dataclasses.dataclass(frozen=True)
class ParsedPySpy:
    """Immutable result of splitting a py-spy dump into preamble and threads."""

    # Lines that appeared before the first thread header.
    details: List[str]
    # One block per thread header, in dump order.
    threads: List[ThreadBlock]
80
+
81
+
82
# Rich style names / colours shared by the rendering helpers.
PUNCT_STYLE = "grey62"
BORDER_STYLE = "grey62"
KEY_STYLE = "#7ad7ff"

# NOTE(review): presumably the number of terminal rows reserved for the header
# panel -- its use is not visible in this part of the file; confirm.
HEADER_HEIGHT = 3

# Environment variables copied from a rank's process so that py-spy / pip can
# later be run with a matching environment.
ENV_KEYS = (
    "PATH", "LD_LIBRARY_PATH", "PYTHONPATH",
    "VIRTUAL_ENV", "CONDA_PREFIX", "CONDA_DEFAULT_ENV",
    "PYTHONHOME", "HOME",
)
96
+
97
+
98
# Self-contained python program that find_rank_pids_remote() feeds to
# ``python3 -`` over ssh.  It scans /proc for python processes matching the
# MPIPTOP_MODULE / MPIPTOP_TARGET environment variables, keeps those that carry
# an MPI rank variable, and prints a JSON list of
# [rank, pid, cmdline, rss_kb, exe, env_subset] entries on stdout.
REMOTE_FINDER_SCRIPT = r"""
import json
import os

TARGET = os.environ.get("MPIPTOP_TARGET", "")
MODULE = os.environ.get("MPIPTOP_MODULE", "")
ENV_KEYS = [
    "PATH",
    "LD_LIBRARY_PATH",
    "PYTHONPATH",
    "VIRTUAL_ENV",
    "CONDA_PREFIX",
    "CONDA_DEFAULT_ENV",
    "PYTHONHOME",
    "HOME",
]


def read_cmdline(pid):
    with open(f"/proc/{pid}/cmdline", "rb") as f:
        data = f.read().split(b"\0")
    return [x.decode(errors="ignore") for x in data if x]


def read_env(pid):
    with open(f"/proc/{pid}/environ", "rb") as f:
        data = f.read().split(b"\0")
    env = {}
    for item in data:
        if b"=" in item:
            k, v = item.split(b"=", 1)
            env[k.decode(errors="ignore")] = v.decode(errors="ignore")
    return env


def select_env_subset(env):
    return {key: env[key] for key in ENV_KEYS if key in env}


def read_rss_kb(pid):
    try:
        with open(f"/proc/{pid}/status", "r", encoding="utf-8") as f:
            for line in f:
                if line.startswith("VmRSS:"):
                    parts = line.split()
                    if len(parts) >= 2 and parts[1].isdigit():
                        return int(parts[1])
    except Exception:
        return None
    return None


def read_exe(pid):
    try:
        return os.readlink(f"/proc/{pid}/exe")
    except Exception:
        return ""


def matches(cmd):
    if not cmd:
        return False
    exe = os.path.basename(cmd[0])
    if "python" not in exe:
        return False
    if MODULE:
        try:
            idx = cmd.index("-m")
        except ValueError:
            return False
        return idx + 1 < len(cmd) and cmd[idx + 1] == MODULE
    if TARGET:
        for arg in cmd:
            if arg == TARGET or arg.endswith("/" + TARGET) or os.path.basename(arg) == os.path.basename(TARGET):
                return True
        return False
    return True


results = []
for pid in os.listdir("/proc"):
    if not pid.isdigit():
        continue
    try:
        cmd = read_cmdline(pid)
        if not matches(cmd):
            continue
        env = read_env(pid)
        rank = env.get("OMPI_COMM_WORLD_RANK") or env.get("PMIX_RANK") or env.get("PMI_RANK")
        if rank is None:
            continue
        results.append(
            [
                int(rank),
                int(pid),
                " ".join(cmd),
                read_rss_kb(pid),
                read_exe(pid),
                select_env_subset(env),
            ]
        )
    except Exception:
        continue

print(json.dumps(results))
"""
204
+
205
+
206
def read_ps() -> List[Proc]:
    """Return every process reported by ``ps -eo pid=,ppid=,args=``.

    Rows that do not have all three columns, or whose pid/ppid are not
    integers, are dropped.  A failing ``ps`` raises ``CalledProcessError``.
    """
    completed = subprocess.run(
        ["ps", "-eo", "pid=,ppid=,args="],
        check=True,
        capture_output=True,
        text=True,
    )
    table: List[Proc] = []
    for row in completed.stdout.splitlines():
        # pid and ppid are the first two columns; the rest is the command line.
        fields = row.strip().split(maxsplit=2)
        if len(fields) < 3:
            continue
        pid_text, ppid_text, command = fields
        try:
            pid_value, ppid_value = int(pid_text), int(ppid_text)
        except ValueError:
            continue
        table.append(Proc(pid=pid_value, ppid=ppid_value, args=command))
    return table
229
+
230
+
231
def select_env_subset(env: Dict[str, str]) -> Dict[str, str]:
    """Return the ENV_KEYS entries of *env* that are present and non-empty."""
    subset: Dict[str, str] = {}
    for name in ENV_KEYS:
        if name in env and env[name]:
            subset[name] = env[name]
    return subset
233
+
234
+
235
def find_prterun(procs: Sequence[Proc], prterun_pid: Optional[int]) -> Proc:
    """Locate the MPI launcher process among *procs*.

    With an explicit *prterun_pid* the process with that pid is returned.
    Otherwise processes whose command line names prterun/mpirun/orterun are
    considered, those with a rankfile argument are preferred, and the one with
    the highest pid wins.

    Raises:
        SystemExit: if the requested pid or any launcher cannot be found.
    """
    if prterun_pid is not None:
        requested = next((p for p in procs if p.pid == prterun_pid), None)
        if requested is None:
            raise SystemExit(f"prterun/mpirun pid {prterun_pid} not found")
        return requested

    launcher_re = re.compile(r"(?:^|/)(prterun|mpirun|orterun)\b")
    launchers = [p for p in procs if launcher_re.search(p.args)]
    if not launchers:
        raise SystemExit("no prterun/mpirun process found")

    # Narrow to launchers that carry a rankfile, when any of them do.
    preferred = [p for p in launchers if find_rankfile_path(p.args)] or launchers
    return max(preferred, key=lambda p: p.pid)
253
+
254
+
255
def find_rankfile_path(args: str) -> Optional[str]:
    """Extract the rankfile path from a launcher command line, if present.

    Recognises ``rankfile:file=PATH``, ``--rankfile PATH`` and ``-rf PATH``,
    tried in that order; returns None when none of them appear.
    """
    forms = (
        r"rankfile:file=([^\s]+)",
        r"--rankfile\s+([^\s]+)",
        r"\s-rf\s+([^\s]+)",
    )
    for form in forms:
        hit = re.search(form, args)
        if hit:
            return hit.group(1)
    return None
266
+
267
+
268
def parse_rankfile(path: str) -> List[RankInfo]:
    """Read a rankfile and return its ``rank N=host`` entries sorted by rank.

    Blank lines, ``#`` comments and lines that do not match the pattern are
    ignored.

    Raises:
        SystemExit: if *path* does not exist or yields no rank entries.
    """
    if not os.path.exists(path):
        raise SystemExit(f"rankfile not found: {path}")
    entry_re = re.compile(r"rank\s+(\d+)\s*=\s*([^\s]+)")
    with open(path, "r", encoding="utf-8") as handle:
        stripped = (raw.strip() for raw in handle)
        hits = [entry_re.match(text) for text in stripped if text and not text.startswith("#")]
    entries = [RankInfo(rank=int(hit.group(1)), host=hit.group(2)) for hit in hits if hit]
    if not entries:
        raise SystemExit(f"no ranks parsed from {path}")
    return sorted(entries, key=lambda info: info.rank)
287
+
288
+
289
def build_children_map(procs: Sequence[Proc]) -> Dict[int, List[int]]:
    """Group child pids under their parent pid, keeping the order of *procs*."""
    tree: Dict[int, List[int]] = {}
    for entry in procs:
        siblings = tree.get(entry.ppid)
        if siblings is None:
            tree[entry.ppid] = [entry.pid]
        else:
            siblings.append(entry.pid)
    return tree
294
+
295
+
296
def find_descendants(children: Dict[int, List[int]], root_pid: int) -> List[int]:
    """Return every pid reachable from *root_pid* through *children*.

    The root itself is not included.  A visited set guards against cycles in
    the parent/child map.
    """
    pending = [root_pid]
    visited = set()
    found: List[int] = []
    while pending:
        current = pending.pop()
        if current in visited:
            continue
        visited.add(current)
        kids = children.get(current, [])
        found.extend(kids)
        pending.extend(kids)
    return found
309
+
310
+
311
def is_python_process(args: str) -> bool:
    """Return True when the executable named by *args* looks like python.

    Only the basename of the first whitespace-separated token is inspected.
    Empty and whitespace-only command lines yield False; a whitespace-only
    string is truthy but splits to an empty list, so it must not be indexed.
    """
    tokens = args.split(maxsplit=1) if args else []
    if not tokens:
        return False
    return "python" in os.path.basename(tokens[0])
315
+
316
+
317
def select_program(procs: Sequence[Proc], descendants: Iterable[int]) -> Optional[Proc]:
    """Pick the python process that best represents the launched program.

    Only processes whose pid is in *descendants*, that look like python and do
    not mention ``py-spy`` are considered.  Command lines containing ``.py``
    rank first, then longer command lines.  Returns None without a candidate.
    """
    wanted = set(descendants)
    pool = [
        p
        for p in procs
        if p.pid in wanted and is_python_process(p.args) and "py-spy" not in p.args
    ]
    if not pool:
        return None
    return max(pool, key=lambda p: (1 if ".py" in p.args else 0, len(p.args)))
328
+
329
+
330
def parse_python_selector(args: str) -> ProgramSelector:
    """Derive a ProgramSelector from a python command line.

    Interpreter options are scanned left to right after the executable:

    * ``-m NAME`` selects a module and ends the scan, so a ``-m`` that belongs
      to the script's own arguments is no longer mistaken for it.
    * ``-c`` ends the scan with neither module nor script.
    * ``-W`` / ``-X`` / ``--check-hash-based-pycs`` consume the following
      token, so their value is not mistaken for the script path.
    * Any other token starting with ``-`` is skipped.
    * The first remaining token is the script.

    ``module`` and ``script`` are therefore mutually exclusive.  ``display`` is
    the space-joined token list.  An empty *args* gives an empty selector.
    """
    if not args:
        return ProgramSelector(module=None, script=None, display="")
    try:
        parts = shlex.split(args)
    except ValueError:
        # Unbalanced quotes: fall back to plain whitespace splitting.
        parts = args.split()

    value_options = ("-W", "-X", "--check-hash-based-pycs")
    module = None
    script = None
    idx = 1  # parts[0] is the interpreter itself
    while idx < len(parts):
        token = parts[idx]
        if token == "-m":
            if idx + 1 < len(parts):
                module = parts[idx + 1]
            break
        if token == "-c":
            break
        if token in value_options:
            idx += 2
            continue
        if token.startswith("-"):
            idx += 1
            continue
        script = token
        break
    return ProgramSelector(module=module, script=script, display=" ".join(parts))
350
+
351
+
352
def extract_python_exe(cmdline: str) -> Optional[str]:
    """Return the first token of *cmdline* if its basename contains "python".

    Returns None for an empty command line or a non-python executable.
    """
    if not cmdline:
        return None
    try:
        tokens = shlex.split(cmdline)
    except ValueError:
        # Unbalanced quotes: fall back to plain whitespace splitting.
        tokens = cmdline.split()
    if tokens and "python" in os.path.basename(tokens[0]):
        return tokens[0]
    return None
365
+
366
+
367
def matches_python_cmd(cmd: List[str], selector: ProgramSelector) -> bool:
    """Tell whether the argv list *cmd* is the program described by *selector*.

    The executable basename must contain "python".  A selector module must
    follow the first ``-m``; otherwise a selector script must match some
    argument exactly, as a path suffix, or by basename.  A selector with
    neither matches any python command.
    """
    if not cmd or "python" not in os.path.basename(cmd[0]):
        return False
    if selector.module:
        if "-m" not in cmd:
            return False
        after_flag = cmd.index("-m") + 1
        return after_flag < len(cmd) and cmd[after_flag] == selector.module
    if not selector.script:
        return True
    wanted = selector.script
    wanted_base = os.path.basename(wanted)
    return any(
        arg == wanted or arg.endswith("/" + wanted) or os.path.basename(arg) == wanted_base
        for arg in cmd
    )
387
+
388
+
389
def find_rank_pids_local(
    selector: ProgramSelector,
) -> List[Tuple[int, int, str, Optional[int], Optional[str], Dict[str, str]]]:
    """Scan /proc on this host for rank processes matching *selector*.

    Returns one ``(rank, pid, cmdline, rss_kb, python_exe, env_subset)`` tuple
    per matching process that exposes OMPI_COMM_WORLD_RANK, PMIX_RANK or
    PMI_RANK.  Any process that cannot be inspected is skipped.
    """
    found: List[Tuple[int, int, str, Optional[int], Optional[str], Dict[str, str]]] = []
    for entry in os.listdir("/proc"):
        if not entry.isdigit():
            continue
        try:
            with open(f"/proc/{entry}/cmdline", "rb") as handle:
                argv = [chunk.decode(errors="ignore") for chunk in handle.read().split(b"\0") if chunk]
            if not matches_python_cmd(argv, selector):
                continue
            with open(f"/proc/{entry}/environ", "rb") as handle:
                raw_env = handle.read().split(b"\0")
            env: Dict[str, str] = {}
            for pair in raw_env:
                if not pair or b"=" not in pair:
                    continue
                name, _, value = pair.partition(b"=")
                env[name.decode(errors="ignore")] = value.decode(errors="ignore")
            rank = env.get("OMPI_COMM_WORLD_RANK") or env.get("PMIX_RANK") or env.get("PMI_RANK")
            if rank is None:
                continue
            pid = int(entry)
            rss_kb = read_rss_kb(pid)
            cmdline = " ".join(argv)
            try:
                exe_path = os.readlink(f"/proc/{entry}/exe")
            except Exception:
                exe_path = ""
            # Prefer the resolved executable; fall back to argv[0].
            python_exe = exe_path or extract_python_exe(cmdline)
            found.append((int(rank), pid, cmdline, rss_kb, python_exe, select_env_subset(env)))
        except Exception:
            # Process vanished, is not readable, or carried a malformed rank.
            continue
    return found
424
+
425
+
426
def read_rss_kb(pid: int) -> Optional[int]:
    """Return the VmRSS value from /proc/<pid>/status, or None.

    None is returned when the file cannot be read or holds no numeric VmRSS.
    """
    try:
        with open(f"/proc/{pid}/status", "r", encoding="utf-8") as handle:
            for row in handle:
                if not row.startswith("VmRSS:"):
                    continue
                fields = row.split()
                if len(fields) >= 2 and fields[1].isdigit():
                    return int(fields[1])
    except Exception:
        return None
    return None
437
+
438
+
439
def run_ssh(host: str, command: str, timeout: int = 8) -> subprocess.CompletedProcess:
    """Run *command* on *host* through non-interactive ssh and capture its output.

    ssh is invoked with BatchMode=yes and ConnectTimeout=5.  *timeout* is
    passed to ``subprocess.run``, which raises ``TimeoutExpired`` when exceeded.
    """
    ssh_argv = ["ssh", "-o", "BatchMode=yes", "-o", "ConnectTimeout=5", host, command]
    return subprocess.run(ssh_argv, capture_output=True, text=True, timeout=timeout)
454
+
455
+
456
def find_rank_pids_remote(
    host: str, selector: ProgramSelector
) -> Tuple[List[Tuple[int, int, str, Optional[int], Optional[str], Dict[str, str]]], Optional[str]]:
    """Find rank processes on *host* by running REMOTE_FINDER_SCRIPT over ssh.

    Returns ``(entries, error)``.  On success ``error`` is None and each entry
    is ``(rank, pid, cmdline, rss_kb, python_exe, env_subset)``.  On failure
    ``entries`` is empty and ``error`` describes the problem.

    Entries may carry 3 to 6 fields; missing trailing fields default to None
    or an empty dict.  Entries that are too short or hold non-numeric
    rank/pid/rss values are skipped rather than aborting the whole host.  When
    the remote reports no executable path, the interpreter is taken from the
    command line, as the local finder does.
    """
    env_prefix = build_env_prefix(
        {
            "MPIPTOP_TARGET": selector.script or "",
            "MPIPTOP_MODULE": selector.module or "",
        }
    )
    remote_cmd = f"{env_prefix}python3 - <<'PY'\n{REMOTE_FINDER_SCRIPT}\nPY"
    try:
        result = run_ssh(host, remote_cmd)
    except subprocess.TimeoutExpired:
        return [], f"ssh timeout to {host}"
    except OSError as exc:
        # e.g. the ssh binary itself is missing on this machine.
        return [], f"{host}: {exc}"
    if result.returncode != 0:
        stderr = (result.stderr or result.stdout).strip()
        msg = stderr or f"ssh failed ({result.returncode})"
        return [], f"{host}: {msg}"
    try:
        data = json.loads(result.stdout.strip() or "[]")
    except json.JSONDecodeError:
        return [], f"{host}: invalid json from remote"
    if not isinstance(data, list):
        return [], f"{host}: invalid json from remote"

    parsed: List[Tuple[int, int, str, Optional[int], Optional[str], Dict[str, str]]] = []
    for entry in data:
        if not isinstance(entry, (list, tuple)) or len(entry) < 3:
            continue
        # Pad to six fields so shorter entries unpack uniformly.
        fields = list(entry[:6]) + [None] * (6 - len(entry))
        r, p, cmd, rss_kb, python_exe, env_subset = fields
        try:
            rank = int(r)
            pid = int(p)
            rss_value = int(rss_kb) if rss_kb is not None else None
        except (TypeError, ValueError):
            continue
        cmdline = str(cmd)
        if not isinstance(env_subset, dict):
            env_subset = {}
        parsed.append(
            (rank, pid, cmdline, rss_value, python_exe or extract_python_exe(cmdline), env_subset)
        )
    return parsed, None
499
+
500
+
501
def run_py_spy(
    host: str,
    proc: RankProcess,
    pythonpath: str,
    install_attempted: set,
    timeout: int = 8,
) -> Tuple[Optional[str], Optional[str]]:
    """Run ``py-spy dump`` for *proc* on *host*, installing py-spy if missing.

    Args:
        host: Host the process runs on; local hosts use subprocess, others ssh.
        proc: The rank process to dump.
        pythonpath: PYTHONPATH override passed to ``merge_env``.
        install_attempted: Mutable set of ``"host|python"`` keys; an install is
            attempted at most once per key and the key is added here.
        timeout: Seconds allowed for each py-spy invocation.

    Returns:
        ``(stdout, None)`` on success, or ``(None, error_message)`` on failure.
        Timeouts during dumping or installation are reported as messages and
        never propagate as exceptions.
    """
    local = is_local_host(host)
    env_vars = merge_env(proc, pythonpath, os.environ.copy() if local else None)
    env_prefix = build_env_prefix(env_vars)

    def resolve_py_spy_path() -> Optional[str]:
        # Re-evaluated on every dump so that a py-spy installed between the
        # first attempt and the retry is picked up.
        candidates = []
        if proc.python_exe:
            candidates.append(os.path.join(os.path.dirname(proc.python_exe), "py-spy"))
        venv = env_vars.get("VIRTUAL_ENV")
        if venv:
            candidates.append(os.path.join(venv, "bin", "py-spy"))
        for candidate in candidates:
            # Remote paths cannot be checked from here; the remote shell
            # snippet in run_dump() tests them with ``[ -x ... ]`` instead.
            if not local or os.access(candidate, os.X_OK):
                return candidate
        return None

    def run_dump() -> subprocess.CompletedProcess:
        py_spy_path = resolve_py_spy_path()
        if local:
            cmd = [py_spy_path or "py-spy", "dump", "-p", str(proc.pid)]
            return subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=timeout,
                env=env_vars,
            )
        if py_spy_path:
            spy_cmd = shlex.quote(py_spy_path)
            script = (
                f"if [ -x {spy_cmd} ]; then {spy_cmd} dump -p {proc.pid}; "
                f"else py-spy dump -p {proc.pid}; fi"
            )
            remote_cmd = f"{env_prefix}sh -lc {shlex.quote(script)}"
        else:
            remote_cmd = f"{env_prefix}py-spy dump -p {proc.pid}"
        return run_ssh(host, remote_cmd, timeout=timeout)

    def missing_py_spy(exc: Optional[BaseException], result: Optional[subprocess.CompletedProcess]) -> bool:
        if isinstance(exc, FileNotFoundError):
            return True
        if result is None:
            return False
        stderr = (result.stderr or result.stdout or "").lower()
        return result.returncode == 127 or "py-spy: command not found" in stderr

    def ensure_installed() -> Optional[str]:
        # Returns None on success, otherwise an error description.
        venv = env_vars.get("VIRTUAL_ENV")
        venv_python = os.path.join(venv, "bin", "python") if venv else None
        python_exe = proc.python_exe or venv_python or (sys.executable if local else "python3")

        if local:
            pip_error = ""
            # Plain install first, then a per-user install.
            for extra in ([], ["--user"]):
                try:
                    attempt = subprocess.run(
                        [python_exe, "-m", "pip", "install", *extra, "py-spy"],
                        capture_output=True,
                        text=True,
                        timeout=120,
                        env=env_vars,
                    )
                except Exception as exc:
                    return str(exc)
                if attempt.returncode == 0:
                    return None
                pip_error = (attempt.stderr or attempt.stdout or "").strip() or "pip install py-spy failed"
            if should_try_uv(pip_error):
                uv_error = uv_install_local(python_exe, env_vars, pip_error)
                if uv_error is None:
                    return None
                fallback_error = pip_user_install_local(env_vars)
                if fallback_error is None:
                    return None
                return f"{uv_error}\n{fallback_error}"
            return pip_error

        env_prefix_install = build_env_prefix(env_vars)
        pip_error = ""
        for flag in ("", "--user "):
            cmd = f"{env_prefix_install}{shlex.quote(python_exe)} -m pip install {flag}py-spy"
            try:
                attempt = run_ssh(host, cmd, timeout=120)
            except subprocess.TimeoutExpired:
                return "pip install py-spy timeout"
            except OSError as exc:
                return str(exc)
            if attempt.returncode == 0:
                return None
            pip_error = (attempt.stderr or attempt.stdout or "").strip() or "pip install py-spy failed"
        if should_try_uv(pip_error):
            uv_error = uv_install_remote(host, python_exe, env_prefix_install)
            if uv_error is None:
                return None
            fallback_error = pip_user_install_remote(host, env_prefix_install)
            if fallback_error is None:
                return None
            return f"{uv_error}\n{fallback_error}"
        return pip_error

    error_exc: Optional[BaseException] = None
    result: Optional[subprocess.CompletedProcess] = None
    try:
        result = run_dump()
    except FileNotFoundError as exc:
        error_exc = exc
    except subprocess.TimeoutExpired:
        return None, f"py-spy timeout on {host}:{proc.pid}"

    if result is not None and result.returncode == 0:
        return result.stdout, None

    if missing_py_spy(error_exc, result):
        key = f"{host}|{proc.python_exe or sys.executable}"
        if key not in install_attempted:
            install_attempted.add(key)
            install_error = ensure_installed()
            if install_error:
                return None, f"py-spy install failed on {host}: {install_error}"
            try:
                retry = run_dump()
            except FileNotFoundError:
                return None, f"py-spy still missing on {host}:{proc.pid}"
            except subprocess.TimeoutExpired:
                return None, f"py-spy timeout on {host}:{proc.pid}"
            if retry.returncode == 0:
                return retry.stdout, None
            stderr = (retry.stderr or retry.stdout or "").strip()
            return None, stderr or f"py-spy failed on {host}:{proc.pid}"

    stderr = (result.stderr or result.stdout or "").strip() if result is not None else str(error_exc)
    return None, stderr or f"py-spy failed on {host}:{proc.pid}"
649
+
650
+
651
def build_env_prefix(env: Dict[str, str]) -> str:
    """Render *env* as a shell ``KEY=value `` prefix with quoted values.

    Entries with empty values are omitted.  The result ends with one trailing
    space, or is the empty string when nothing remains.
    """
    assignments = [
        f"{name}={shlex.quote(val)}" for name, val in (env or {}).items() if val
    ]
    if not assignments:
        return ""
    return " ".join(assignments) + " "
660
+
661
+
662
def merge_env(
    proc: RankProcess,
    pythonpath: str,
    base: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    """Build the environment used to run tools against *proc*.

    Starts from a copy of *base*, overlays the non-empty entries of
    ``proc.env``, moves ``$VIRTUAL_ENV/bin`` to the front of PATH when a
    virtualenv is set, and finally applies *pythonpath* if it is non-empty.
    """
    merged = dict(base or {})
    merged.update({name: val for name, val in (proc.env or {}).items() if val})

    venv = merged.get("VIRTUAL_ENV")
    if venv:
        venv_bin = os.path.join(venv, "bin")
        current = merged.get("PATH", "")
        entries = current.split(":") if current else []
        if not entries or entries[0] != venv_bin:
            # Drop any later duplicate so venv_bin appears exactly once.
            merged["PATH"] = ":".join([venv_bin, *(e for e in entries if e != venv_bin)])

    if pythonpath:
        merged["PYTHONPATH"] = pythonpath
    return merged
681
+
682
+
683
def should_try_uv(error_text: str) -> bool:
    """Return True when a pip error says the environment is externally/uv managed."""
    haystack = error_text.lower()
    return any(
        marker in haystack
        for marker in ("externally-managed-environment", "managed by uv")
    )
686
+
687
+
688
def find_uv_binary(env: Dict[str, str]) -> Optional[str]:
    """Locate an executable ``uv``.

    Looks on the PATH from *env* (or the default search path), then in
    ``$HOME/.local/bin`` and ``$HOME/.cargo/bin`` when *env* has HOME, then in
    ``/usr/local/bin`` and ``/opt/homebrew/bin``.  Returns None if none exist.
    """
    search_path = env.get("PATH") if env else None
    located = shutil.which("uv", path=search_path)
    if located:
        return located

    home = env.get("HOME") if env else None
    fallbacks: List[str] = []
    if home:
        fallbacks += [
            os.path.join(home, ".local", "bin", "uv"),
            os.path.join(home, ".cargo", "bin", "uv"),
        ]
    fallbacks += ["/usr/local/bin/uv", "/opt/homebrew/bin/uv"]
    return next(
        (path for path in fallbacks if os.path.isfile(path) and os.access(path, os.X_OK)),
        None,
    )
706
+
707
+
708
def uv_install_local(python_exe: str, env: Dict[str, str], pip_error: str) -> Optional[str]:
    """Install py-spy for *python_exe* on this machine using ``uv pip install``.

    Returns None on success.  On failure the returned message starts with
    *pip_error* followed by a line describing the uv problem.
    """
    uv_binary = find_uv_binary(env)
    if not uv_binary:
        return f"{pip_error}\nuv not found on PATH"
    install_argv = [uv_binary, "pip", "install", "--python", python_exe, "py-spy"]
    try:
        outcome = subprocess.run(
            install_argv,
            capture_output=True,
            text=True,
            timeout=120,
            env=env,
        )
    except Exception as exc:
        return f"{pip_error}\nuv install failed: {exc}"
    if outcome.returncode == 0:
        return None
    detail = (outcome.stderr or outcome.stdout or "").strip()
    return f"{pip_error}\nuv install failed: {detail or 'unknown error'}"
726
+
727
+
728
def uv_install_remote(host: str, python_exe: str, env_prefix: str) -> Optional[str]:
    """Install py-spy for *python_exe* on *host* with ``uv pip install`` over ssh.

    The remote shell snippet looks for ``uv`` on PATH and in a few well-known
    locations, exiting 127 if none is found.  Returns None on success or an
    error message otherwise.
    """
    script = textwrap.dedent(
        f"""
        uv_cmd=""
        if command -v uv >/dev/null 2>&1; then
          uv_cmd="uv"
        else
          for cand in "$HOME/.local/bin/uv" "$HOME/.cargo/bin/uv" "/usr/local/bin/uv" "/opt/homebrew/bin/uv"; do
            if [ -x "$cand" ]; then
              uv_cmd="$cand"
              break
            fi
          done
        fi
        if [ -z "$uv_cmd" ]; then
          echo "uv not found on PATH" 1>&2
          exit 127
        fi
        "$uv_cmd" pip install --python {shlex.quote(python_exe)} py-spy
        """
    ).strip()
    remote_cmd = f"{env_prefix}sh -lc {shlex.quote(script)}"
    try:
        outcome = run_ssh(host, remote_cmd, timeout=120)
    except subprocess.TimeoutExpired:
        return "uv install timeout"
    if outcome.returncode == 0:
        return None
    detail = (outcome.stderr or outcome.stdout or "").strip() or "uv install failed"
    return f"uv install failed: {detail}"
758
+
759
+
760
def pip_user_install_local(env: Dict[str, str]) -> Optional[str]:
    """Run ``pip install --user py-spy`` locally with *env*.

    Returns None on success, otherwise a message describing the failure.
    """
    install_argv = ["pip", "install", "--user", "py-spy"]
    try:
        outcome = subprocess.run(
            install_argv,
            capture_output=True,
            text=True,
            timeout=120,
            env=env,
        )
    except FileNotFoundError:
        return "pip not found on PATH"
    except Exception as exc:
        return f"pip install failed: {exc}"
    if outcome.returncode != 0:
        return (outcome.stderr or outcome.stdout or "").strip() or "pip install --user failed"
    return None
776
+
777
+
778
def pip_user_install_remote(host: str, env_prefix: str) -> Optional[str]:
    """Run ``pip install --user py-spy`` on *host* over ssh.

    Returns None on success, otherwise a message describing the failure.
    """
    remote_cmd = f"{env_prefix}pip install --user py-spy"
    try:
        outcome = run_ssh(host, remote_cmd, timeout=120)
    except subprocess.TimeoutExpired:
        return "pip install timeout"
    if outcome.returncode != 0:
        return (outcome.stderr or outcome.stdout or "").strip() or "pip install --user failed"
    return None
787
+
788
+
789
# Matches a stack frame line of the form "<indent><function> (<file>:<line>)".
# Groups: leading whitespace, function text, file path, line number.
STACK_LINE_RE = re.compile(r"^(\s*)(.*?)\s+\((.*):(\d+)\)\s*$")
790
+
791
+
792
def pastel_color(key: str) -> str:
    """Map *key* to a stable pastel ``#rrggbb`` colour.

    The hue comes from the SHA-1 digest of the key; saturation and value are
    fixed, so the same key always yields the same colour.
    """
    digest = hashlib.sha1(key.encode("utf-8", errors="ignore")).hexdigest()
    hue = (int(digest[:8], 16) % 360) / 360.0
    channels = colorsys.hsv_to_rgb(hue, 0.35, 0.92)
    return "#" + "".join(f"{int(channel * 255):02x}" for channel in channels)
799
+
800
+
801
def highlight_substring(text: str, substring: str, color: str) -> Text:
    """Return *text* as a Text with the first *substring* occurrence styled *color*.

    When *substring* does not occur, the text is returned unstyled.
    """
    start = text.find(substring)
    if start == -1:
        return Text(text)
    end = start + len(substring)
    styled = Text()
    styled.append(text[:start])
    styled.append(substring, style=color)
    styled.append(text[end:])
    return styled
810
+
811
+
812
def highlight_py_path(text: str) -> Text:
    """Return *text* as a Text with its first ``*.py`` token coloured.

    The colour is derived from the matched path via ``pastel_color``; text
    without such a token is returned unstyled.
    """
    found = re.search(r"\S+\.py\b", text)
    if found is None:
        return Text(text)
    before, path, after = text[: found.start()], found.group(0), text[found.end():]
    styled = Text()
    styled.append(before)
    styled.append(path, style=pastel_color(path))
    styled.append(after)
    return styled
823
+
824
+
825
def style_program_display(selector: ProgramSelector) -> Text:
    """Render ``selector.display`` with the script or module name highlighted.

    Falls back to highlighting any ``*.py`` path, and to the plain word
    "python" when the selector has no display text.
    """
    shown = selector.display
    if not shown:
        return Text("python")
    # The script takes precedence over the module when both occur.
    for needle in (selector.script, selector.module):
        if needle and needle in shown:
            return highlight_substring(shown, needle, pastel_color(needle))
    return highlight_py_path(shown)
833
+
834
+
835
def style_program_line(line: str, selector: ProgramSelector) -> Text:
    """Render one wrapped program line with the script or module highlighted.

    Falls back to highlighting any ``*.py`` path in *line*.
    """
    # The script takes precedence over the module when both occur.
    for needle in (selector.script, selector.module):
        if needle and needle in line:
            return highlight_substring(line, needle, pastel_color(needle))
    return highlight_py_path(line)
841
+
842
+
843
def wrap_program_lines(selector: ProgramSelector, width: int) -> List[Text]:
    """Wrap the program command line to *width* columns as styled lines.

    The command is split into a prefix (interpreter through the module or
    script token) and its arguments.  Arguments are wrapped with a hanging
    indent aligned just past the prefix.  If the prefix alone is as wide as
    *width*, the whole display string is wrapped without that alignment.
    """
    display = selector.display or "python"
    if width <= 0:
        return [style_program_line(display, selector)]
    try:
        tokens = shlex.split(display)
    except ValueError:
        tokens = display.split()
    if not tokens:
        return [style_program_line(display, selector)]

    # Index of the first argument token; by default only the interpreter
    # belongs to the prefix.
    split_at = 1
    if selector.module and "-m" in tokens:
        flag_pos = tokens.index("-m")
        if flag_pos + 1 < len(tokens):
            split_at = flag_pos + 2
    elif selector.script:
        script_base = os.path.basename(selector.script)
        for pos in range(1, len(tokens)):
            tok = tokens[pos]
            if tok == selector.script or tok.endswith("/" + selector.script) or os.path.basename(tok) == script_base:
                split_at = pos + 1
                break

    prefix = " ".join(tokens[:split_at])
    args_str = " ".join(tokens[split_at:])
    if not args_str:
        return [style_program_line(prefix, selector)]

    indent = len(prefix) + 1
    if indent >= width:
        wrapped = textwrap.wrap(display, width=width, break_long_words=False, break_on_hyphens=False) or [display]
        return [style_program_line(piece, selector) for piece in wrapped]

    pad = " " * indent
    wrapper = textwrap.TextWrapper(
        width=width,
        initial_indent=pad,
        subsequent_indent=pad,
        break_long_words=False,
        break_on_hyphens=False,
    )
    wrapped_args = wrapper.wrap(args_str) or [args_str]
    head = wrapped_args[0]
    if head.startswith(pad):
        # The first line carries the prefix in place of the padding.
        head = head[indent:]
    lines = [f"{prefix} {head}", *wrapped_args[1:]]
    return [style_program_line(piece, selector) for piece in lines]
900
+
901
+
902
def style_detail_line(line: str) -> Text:
    """Style a py-spy detail line; only ``Program:`` lines get a highlighted path."""
    is_program_line = line.lower().startswith("program:")
    return highlight_py_path(line) if is_program_line else Text(line)
907
+
908
+
909
def style_stack_line(line: str) -> Text:
    """Colourise one stack line of the form ``func (file:line)``.

    A leading "➤ " marker is kept and drawn in KEY_STYLE.  The function name
    and file path share a colour derived from the file path; punctuation and
    the line number use PUNCT_STYLE.  Lines that do not match the frame
    pattern are emitted unstyled (apart from the marker).
    """
    styled = Text()
    if line.startswith("➤ "):
        styled.append("➤ ", style=KEY_STYLE)
        line = line[2:]

    frame = STACK_LINE_RE.match(line)
    if frame is None:
        styled.append(line)
        return styled

    indent, func, file_path, line_no = frame.groups()
    file_color = pastel_color(file_path)
    styled.append(indent)
    styled.append(func, style=file_color)
    styled.append(" ")
    for piece, piece_style in (
        ("(", PUNCT_STYLE),
        (file_path, file_color),
        (":", PUNCT_STYLE),
        (line_no, PUNCT_STYLE),
        (")", PUNCT_STYLE),
    ):
        styled.append(piece, style=piece_style)
    return styled
935
+
936
+
937
def style_lines(lines: List[str]) -> Text:
    """Join *lines* into one newline-separated Text, styling each as a stack line.

    The result is marked no-wrap with cropped overflow.
    """
    combined = Text()
    needs_separator = False
    for entry in lines:
        if needs_separator:
            combined.append("\n")
        combined.append_text(style_stack_line(entry))
        needs_separator = True
    combined.no_wrap = True
    combined.overflow = "crop"
    return combined
946
+
947
+
948
def format_rss(rss_kb: Optional[int]) -> str:
    """Format a resident-set size given in kB as a human-readable string.

    The value is converted to bytes (x1024) and scaled by powers of 1000 up to
    TB.  Bytes print as an integer, larger units with one decimal.  None gives
    "unknown".
    """
    if rss_kb is None:
        return "unknown"
    amount = float(rss_kb) * 1024.0
    unit = "B"
    for bigger in ("KB", "MB", "GB", "TB"):
        if amount < 1000.0:
            break
        amount /= 1000.0
        unit = bigger
    if unit == "B":
        return f"{int(amount)} {unit}"
    return f"{amount:.1f} {unit}"
960
+
961
+
962
def parse_pyspy_output(output: str) -> ParsedPySpy:
    """Split raw ``py-spy dump`` text into preamble lines and thread blocks.

    Lines before the first line starting with "Thread " become ``details``.
    Each such header starts a new ThreadBlock that collects the lines after it.
    """
    preamble: List[str] = []
    blocks: List[ThreadBlock] = []
    header: Optional[str] = None
    frames: List[str] = []
    for line in output.splitlines():
        if line.startswith("Thread "):
            if header is not None:
                blocks.append(ThreadBlock(header=header, stack=frames))
            header, frames = line, []
        elif header is None:
            preamble.append(line)
        else:
            frames.append(line)
    if header is not None:
        # Flush the last open thread.
        blocks.append(ThreadBlock(header=header, stack=frames))
    return ParsedPySpy(details=preamble, threads=blocks)
983
+
984
+
985
def invert_stack_lines(lines: List[str]) -> List[str]:
    """Reverse each run of consecutive space-indented lines, leaving others in place."""
    result: List[str] = []
    run: List[str] = []
    for entry in lines:
        if entry.startswith(" "):
            run.append(entry)
        else:
            # A non-indented line ends the current run.
            result.extend(run[::-1])
            run = []
            result.append(entry)
    result.extend(run[::-1])
    return result
999
+
1000
+
1001
def filter_detail_lines(lines: List[str]) -> List[str]:
    """Keep only the ``Program:`` and ``Python version:`` lines (case-insensitive)."""
    wanted_prefixes = ("program:", "python version:")
    return [entry for entry in lines if entry.lower().startswith(wanted_prefixes)]
1008
+
1009
+
1010
def select_threads(threads: List[ThreadBlock], show_threads: bool) -> Tuple[List[ThreadBlock], int]:
    """Choose which threads to display and count the hidden ones.

    With *show_threads* every thread is returned and nothing is hidden.
    Otherwise only the first thread whose header mentions "MainThread" is
    returned (or the first thread if none does), with the count of the rest.
    """
    if show_threads:
        return threads, 0
    if not threads:
        return [], 0
    chosen = next((t for t in threads if "MainThread" in t.header), threads[0])
    return [chosen], len(threads) - 1
1023
+
1024
+
1025
def render_pyspy_output(output: str, show_threads: bool) -> Tuple[List[str], List[str]]:
    """Turn raw py-spy text into display lines plus filtered detail lines.

    The output is parsed into thread blocks, each block's indented runs are
    reversed, and the threads to show are selected. When several threads are
    shown they are separated by an empty line; when only one is shown, a
    trailing ``(+N other threads)`` note reports how many were hidden.

    Args:
        output: Text captured from py-spy.
        show_threads: Show every thread instead of a single one.

    Returns:
        ``(lines, details)`` where ``lines`` is the stack display (or
        ``["no thread data"]`` when nothing can be shown) and ``details`` is
        the filtered preamble.
    """
    parsed = parse_pyspy_output(output)
    kept_details = filter_detail_lines(parsed.details)
    flipped = [
        ThreadBlock(header=block.header, stack=invert_stack_lines(block.stack))
        for block in parsed.threads
    ]
    shown, hidden = select_threads(flipped, show_threads)
    if not shown:
        return ["no thread data"], kept_details

    rendered: List[str] = []
    final_position = len(shown) - 1
    for position, block in enumerate(shown):
        if block.header:
            rendered.append(block.header)
        rendered += block.stack
        # Blank separator between threads, but not after the last one.
        if show_threads and position < final_position:
            rendered.append("")
    if hidden > 0 and not show_threads:
        rendered.append(f"(+{hidden} other threads)")
    return rendered, kept_details
1046
+
1047
+
1048
def extract_stack_lines(lines: List[str]) -> List[str]:
    """Return the first contiguous run of indented lines.

    A leading ``"Thread "`` header line, if present, is skipped. Collection
    then continues while lines start with a space and stops at the first line
    that does not.

    Args:
        lines: Rendered display lines for one rank.

    Returns:
        A new list with the leading stack lines; empty when there are none.
    """
    if not lines:
        return []
    begin = 1 if lines[0].startswith("Thread ") else 0
    end = begin
    while end < len(lines) and lines[end].startswith(" "):
        end += 1
    return lines[begin:end]
1059
+
1060
+
1061
def common_prefix_length(stacks_by_rank: Dict[int, List[str]]) -> int:
    """Count how many leading stack lines are identical across all ranks.

    Args:
        stacks_by_rank: Stack lines keyed by rank.

    Returns:
        The length of the shared prefix; 0 for an empty mapping or when any
        rank has an empty stack.
    """
    if not stacks_by_rank:
        return 0
    shared = 0
    # zip stops at the shortest stack, which bounds the comparison.
    for column in zip(*stacks_by_rank.values()):
        if any(frame != column[0] for frame in column):
            break
        shared += 1
    return shared
1074
+
1075
+
1076
def mark_diff_line(lines: List[str], diff_index: Optional[int]) -> List[str]:
    """Put a ``➤`` marker on the ``diff_index``-th indented line.

    Only lines starting with a space are counted. When the selected line
    begins with two spaces, those two characters are overwritten by the marker
    so the remaining text keeps its column. Otherwise the marker is prepended,
    so no character of the line is lost.

    Args:
        lines: Display lines for one rank.
        diff_index: Zero-based position among the indented lines, or ``None``
            to leave the lines untouched.

    Returns:
        ``lines`` itself when ``diff_index`` is ``None``; otherwise a new list
        (identical to the input when no indented line has that position).
    """
    if diff_index is None:
        return lines
    marked = list(lines)
    stack_pos = 0
    for idx, line in enumerate(marked):
        if not line.startswith(" "):
            continue
        if stack_pos == diff_index:
            if line.startswith("  "):
                # Two leading spaces: replace them in place with the marker.
                marked[idx] = "➤ " + line[2:]
            else:
                # Only one leading space: slicing two characters would eat
                # real text, so prepend instead.
                marked[idx] = "➤ " + line
            break
        stack_pos += 1
    return marked
1091
+
1092
+
1093
def is_local_host(host: str) -> bool:
    """Report whether ``host`` names the machine this process runs on.

    The full name is first compared against the loopback names, so address
    literals such as ``127.0.0.1`` are recognised before any dot splitting
    (cutting at the first dot would reduce the address to ``"127"``). After
    that, the part before the first dot is compared with the short form of
    ``socket.gethostname()`` and with the loopback names, so
    ``localhost.localdomain`` also matches.

    Args:
        host: Host name or address taken from the rankfile.

    Returns:
        ``True`` when the host is considered local.
    """
    loopback_names = {"localhost", "127.0.0.1", "::1"}
    if host in loopback_names:
        return True
    short = host.split(".")[0]
    local = socket.gethostname().split(".")[0]
    return short == local or short in loopback_names
1097
+
1098
+
1099
def shorten(text: str, width: int) -> str:
    """Clip ``text`` to at most ``width`` characters.

    Text that already fits is returned unchanged. Longer text ends with
    ``...`` when the width leaves room for at least one character before it;
    for widths of three or less the text is simply cut. A non-positive width
    yields the empty string.
    """
    if width <= 0:
        return ""
    if len(text) <= width:
        return text
    ellipsis = "..."
    if width <= len(ellipsis):
        return text[:width]
    return text[: width - len(ellipsis)] + ellipsis
1107
+
1108
+
1109
def build_header(
    state: State, last_update: str, errors: List[str], refresh: int, width: int
) -> Tuple[Text, int]:
    """Build the header text: program summary lines plus a key-hint line.

    The program lines come from ``wrap_program_lines``; the rank count and
    rankfile path are appended to the last of them (or form a single
    ``python | ...`` line when there are none). The key-hint line is padded on
    the left so it ends at ``width`` and is truncated to ``width``.

    Args:
        state: Detected job state (selector, ranks, rankfile).
        last_update: Accepted but not used by this function.
        errors: Accepted but not used by this function.
        refresh: Accepted but not used by this function.
        width: Available text width in characters.

    Returns:
        ``(text, line_count)`` where ``line_count`` is the number of program
        lines plus one for the key-hint line.
    """
    top_lines = wrap_program_lines(state.selector, width)
    if not top_lines:
        top_lines = [Text(f"python | ranks: {len(state.ranks)} | rankfile: {state.rankfile}")]
    else:
        top_lines[-1].append(f" | ranks: {len(state.ranks)} | rankfile: {state.rankfile}")

    # Plain copy of the hint line, used only to measure the padding.
    controls_plain = "q quit | space refresh | t threads | d details"
    controls = Text(" " * max(0, width - len(controls_plain)))
    key_hints = (
        ("q", " quit | "),
        ("space", " refresh | "),
        ("t", " threads | "),
        ("d", " details"),
    )
    for key, label in key_hints:
        controls.append(key, style=KEY_STYLE)
        controls.append(label)
    controls.truncate(width)

    header = Text()
    for row in top_lines:
        header.append_text(row)
        header.append("\n")
    header.append_text(controls)
    return header, len(top_lines) + 1
1140
+
1141
+
1142
def render_columns(
    ranks: List[RankInfo],
    stacks: Dict[int, Text],
    details: Optional[Text],
    body_height: int,
    rank_to_proc: Dict[int, RankProcess],
) -> Table:
    """Lay out one panel per rank, plus an optional details panel, in a row.

    Each rank panel is titled ``rank N @ host`` and, when the rank's process
    has a known RSS, the formatted RSS is appended to the title. The ``Text``
    objects passed in are switched to no-wrap / crop overflow in place.

    Args:
        ranks: Ranks to show, in column order.
        stacks: Styled stack text per rank; missing ranks show ``No process``.
        details: Text for the extra details column, or ``None`` to omit it.
        body_height: Height given to every panel.
        rank_to_proc: Process information per rank, used for the RSS title.

    Returns:
        An expanding grid with equally weighted columns.
    """

    def boxed(content: Text, heading: str) -> Panel:
        # Crop instead of wrapping so long frames do not reflow the panel.
        content.no_wrap = True
        content.overflow = "crop"
        return Panel(
            content,
            title=heading,
            height=body_height,
            padding=(0, 1),
            border_style=BORDER_STYLE,
        )

    cells = []
    for entry in ranks:
        heading = f"rank {entry.rank} @ {entry.host}"
        proc = rank_to_proc.get(entry.rank)
        if proc and proc.rss_kb is not None:
            heading = f"{heading} | {format_rss(proc.rss_kb)}"
        cells.append(boxed(stacks.get(entry.rank, Text("No process")), heading))
    if details is not None:
        cells.append(boxed(details, "details"))

    grid = Table.grid(expand=True)
    for _ in cells:
        grid.add_column(ratio=1)
    grid.add_row(*cells)
    return grid
1184
+
1185
+
1186
def wrap_cmdline(cmdline: str, width: int) -> List[str]:
    """Wrap a command line to ``width`` columns behind a ``cmd: `` prefix.

    Continuation lines are indented to line up under the first line's text.
    Words are never split, neither when they are too long nor at hyphens, so
    a line may exceed ``width``. When the width is too small to be useful, or
    wrapping yields nothing, a single unwrapped line is returned.

    Args:
        cmdline: The process command line.
        width: Target line width in characters.

    Returns:
        One or more display lines.
    """
    prefix = "cmd: "
    unwrapped = [f"{prefix}{cmdline}"]
    if width <= len(prefix) + 2:
        return unwrapped
    pieces = textwrap.wrap(
        cmdline,
        width=width,
        initial_indent=prefix,
        subsequent_indent=" " * len(prefix),
        break_long_words=False,
        break_on_hyphens=False,
    )
    return pieces if pieces else unwrapped
1198
+
1199
+
1200
def build_details_text(
    ranks: List[RankInfo],
    rank_to_proc: Dict[int, RankProcess],
    details_by_rank: Dict[int, List[str]],
    cmd_width: int,
) -> Text:
    """Compose the text shown in the details panel.

    Each rank contributes a bold ``rank N @ host`` heading followed by either
    ``No process`` or its pid, RSS, wrapped command line and any extra detail
    lines. Rank sections are separated by an empty line.

    Args:
        ranks: Ranks to describe, in order.
        rank_to_proc: Process information per rank.
        details_by_rank: Extra detail lines per rank.
        cmd_width: Width used when wrapping the command line.

    Returns:
        The assembled ``Text``.
    """
    report = Text()

    def add_styled(fragment: Text) -> None:
        report.append("\n")
        report.append_text(fragment)

    for position, entry in enumerate(ranks):
        if position > 0:
            report.append("\n\n")
        report.append(f"rank {entry.rank} @ {entry.host}", style="bold")
        proc = rank_to_proc.get(entry.rank)
        if proc is None:
            report.append("\n")
            report.append("No process")
            continue
        report.append("\n")
        report.append(f"pid: {proc.pid}")
        report.append("\n")
        report.append(f"rss: {format_rss(proc.rss_kb)}")
        for piece in wrap_cmdline(proc.cmdline, cmd_width):
            add_styled(highlight_py_path(piece))
        for extra in details_by_rank.get(entry.rank, []):
            add_styled(style_detail_line(extra))
    return report
1228
+
1229
+
1230
def detect_state(args: argparse.Namespace) -> State:
    """Discover the running job and describe it as a ``State``.

    The launcher process is located via ``find_prterun`` (honouring
    ``args.prterun_pid``), the rankfile is taken from ``args.rankfile`` or
    from the launcher's arguments, and the program selector is derived from a
    process chosen among the launcher's descendants.

    Args:
        args: Parsed command-line options.

    Returns:
        The launcher pid, rankfile path, parsed ranks and program selector.

    Raises:
        SystemExit: When no rankfile path can be determined.
    """
    process_table = read_ps()
    launcher = find_prterun(process_table, args.prterun_pid)

    rankfile_path = args.rankfile or find_rankfile_path(launcher.args)
    if not rankfile_path:
        raise SystemExit("rankfile not found in prterun/mpirun args")
    rank_entries = parse_rankfile(rankfile_path)

    below_launcher = find_descendants(build_children_map(process_table), launcher.pid)
    chosen = select_program(process_table, below_launcher)
    program_selector = parse_python_selector(chosen.args if chosen else "")

    return State(
        prte_pid=launcher.pid,
        rankfile=rankfile_path,
        ranks=rank_entries,
        selector=program_selector,
    )
1242
+
1243
+
1244
def collect_rank_pids(state: State) -> Tuple[Dict[int, RankProcess], List[str]]:
    """Find the process that belongs to each rank on every host of the job.

    Hosts are visited in sorted order; local hosts are searched directly and
    other hosts through ``find_rank_pids_remote``. Entries for ranks that are
    not part of the job are ignored, and when several processes claim the
    same rank the one with the highest pid is kept.

    Args:
        state: Detected job state.

    Returns:
        ``(rank_to_proc, errors)`` where ``errors`` collects the per-host
        error messages reported by the remote lookups.
    """
    problems: List[str] = []
    by_rank: Dict[int, RankProcess] = {}
    known_ranks = {entry.rank for entry in state.ranks}

    for host in sorted({entry.host for entry in state.ranks}):
        if is_local_host(host):
            found, problem = find_rank_pids_local(state.selector), None
        else:
            found, problem = find_rank_pids_remote(host, state.selector)
        if problem:
            problems.append(problem)

        for rank, pid, cmd, rss_kb, python_exe, env_subset in found:
            if rank not in known_ranks:
                continue
            current = by_rank.get(rank)
            if current is not None and current.pid >= pid:
                continue
            # Interpreter fallback chain: reported executable, then the
            # virtualenv's python, then whatever the command line reveals.
            venv = env_subset.get("VIRTUAL_ENV") if env_subset else None
            venv_python = os.path.join(venv, "bin", "python") if venv else None
            by_rank[rank] = RankProcess(
                pid=pid,
                cmdline=cmd,
                rss_kb=rss_kb,
                python_exe=python_exe or venv_python or extract_python_exe(cmd),
                env=env_subset or {},
            )
    return by_rank, problems
1273
+
1274
+
1275
def collect_stacks(
    state: State,
    rank_to_proc: Dict[int, RankProcess],
    pythonpath: str,
    show_threads: bool,
    install_attempted: set,
) -> Tuple[Dict[int, List[str]], Dict[int, List[str]], List[str]]:
    """Run py-spy for every rank and gather the rendered results.

    Args:
        state: Detected job state.
        rank_to_proc: Process information per rank.
        pythonpath: Passed through to ``run_py_spy``.
        show_threads: Render every thread instead of a single one.
        install_attempted: Passed through to ``run_py_spy``.

    Returns:
        ``(stacks, details_by_rank, errors)``. A rank without a process gets
        the single line ``No process``; a rank whose py-spy run failed gets the
        error message as its only line, and the message is also added to
        ``errors``. In both cases the rank's detail list is empty.
    """
    stacks: Dict[int, List[str]] = {}
    details_by_rank: Dict[int, List[str]] = {}
    errors: List[str] = []

    for entry in state.ranks:
        rank = entry.rank
        # Default for the failure paths; replaced when rendering succeeds.
        details_by_rank[rank] = []

        proc = rank_to_proc.get(rank)
        if proc is None:
            stacks[rank] = ["No process"]
            continue

        output, error = run_py_spy(entry.host, proc, pythonpath, install_attempted)
        if error:
            errors.append(error)
            stacks[rank] = [error]
            continue

        stacks[rank], details_by_rank[rank] = render_pyspy_output(output or "", show_threads)

    return stacks, details_by_rank, errors
1301
+
1302
+
1303
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv``.

    Returns:
        Namespace with ``rankfile``, ``prterun_pid``, ``refresh`` (default 10)
        and ``pythonpath``.
    """
    parser = argparse.ArgumentParser(description="Show MPI Python stacks across hosts.")
    option_table = (
        ("--rankfile", {"help": "Override rankfile path"}),
        ("--prterun-pid", {"type": int, "help": "PID of prterun/mpirun"}),
        ("--refresh", {"type": int, "default": 10, "help": "Refresh interval (seconds)"}),
        (
            "--pythonpath",
            {"help": "PYTHONPATH to export remotely (defaults to local PYTHONPATH)"},
        ),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser.parse_args(argv)
1313
+
1314
+
1315
def main(argv: Optional[Sequence[str]] = None) -> int:
    """Run the interactive viewer until the user quits.

    The job is detected once at start-up. The view is then rebuilt every
    ``--refresh`` seconds (at least 1) inside a full-screen ``Live`` display.
    Single key presses are read from stdin in cbreak mode:

    * ``q``: quit
    * space: refresh now
    * ``t``: toggle showing all threads, then refresh
    * ``d``: toggle the details column, then refresh

    The terminal settings saved at start-up are restored on every exit path.

    Args:
        argv: Command-line arguments; ``None`` uses ``sys.argv``.

    Returns:
        Process exit status (always 0; Ctrl-C is treated as a normal quit).
    """
    args = parse_args(argv)
    pythonpath = args.pythonpath if args.pythonpath is not None else os.environ.get("PYTHONPATH", "")

    state = detect_state(args)
    console = Console()
    refresh = max(1, args.refresh)
    show_threads = False
    show_details = False
    install_attempted: set = set()

    def handle_sigint(_sig, _frame):
        raise KeyboardInterrupt

    signal.signal(signal.SIGINT, handle_sigint)

    fd = sys.stdin.fileno()
    old_settings = termios.tcgetattr(fd)

    layout = Layout()
    layout.split_column(Layout(name="header", size=HEADER_HEIGHT), Layout(name="body"))

    last_update = "never"
    next_refresh = 0.0

    def refresh_view() -> None:
        """Collect fresh data from all ranks and update both layout regions."""
        nonlocal last_update
        rank_to_proc, pid_errors = collect_rank_pids(state)
        stacks, details_by_rank, stack_errors = collect_stacks(
            state, rank_to_proc, pythonpath, show_threads, install_attempted
        )
        stacks_text: Dict[int, Text] = {}
        stack_lines_by_rank = {rank: extract_stack_lines(lines) for rank, lines in stacks.items()}
        prefix_len = common_prefix_length(stack_lines_by_rank)
        # Mark the last frame shared by all ranks, or the first frame when
        # nothing is shared; leave everything unmarked when no rank has frames.
        diff_index = None
        if any(stack_lines_by_rank.values()):
            if prefix_len > 0:
                diff_index = prefix_len - 1
            else:
                diff_index = 0
        for rank, lines in stacks.items():
            marked = mark_diff_line(lines, diff_index) if diff_index is not None else lines
            stacks_text[rank] = style_lines(marked)
        errors = pid_errors + stack_errors
        last_update = time.strftime("%H:%M:%S")
        width, height = shutil.get_terminal_size((120, 40))
        content_width = max(0, width - 4)
        header, header_lines = build_header(state, last_update, errors, refresh, content_width)
        # Two extra rows for the header panel's border, clamped so the header
        # is at least 3 rows and does not take the whole screen.
        header_height = header_lines + 2
        header_height = max(3, min(header_height, max(3, height - 1)))
        layout["header"].size = header_height
        body_height = max(1, height - header_height)
        total_columns = len(state.ranks) + (1 if show_details else 0)
        column_width = max(1, content_width // max(1, total_columns))
        inner_width = max(1, column_width - 4)
        details_text = (
            build_details_text(state.ranks, rank_to_proc, details_by_rank, inner_width)
            if show_details
            else None
        )
        layout["header"].update(
            Panel(header, padding=(0, 1), border_style=BORDER_STYLE)
        )
        layout["body"].update(
            render_columns(state.ranks, stacks_text, details_text, body_height, rank_to_proc)
        )

    try:
        # Switch to cbreak mode only inside the try block, so the `finally`
        # clause restores the saved settings whatever fails afterwards.
        tty.setcbreak(fd)
        refresh_view()
        next_refresh = time.time() + refresh
        with Live(layout, console=console, refresh_per_second=10, screen=True):
            while True:
                now = time.time()
                if now >= next_refresh:
                    refresh_view()
                    next_refresh = now + refresh

                if sys.stdin in select_with_timeout(0.1):
                    key = sys.stdin.read(1)
                    if key == "q":
                        return 0
                    # Setting next_refresh to 0 forces a refresh on the next
                    # loop iteration.
                    if key == " ":
                        next_refresh = 0.0
                    if key == "t":
                        show_threads = not show_threads
                        next_refresh = 0.0
                    if key == "d":
                        show_details = not show_details
                        next_refresh = 0.0
    except KeyboardInterrupt:
        return 0
    finally:
        termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)

    return 0
1411
+
1412
+
1413
def select_with_timeout(timeout: float):
    """Wait up to ``timeout`` seconds for stdin to become readable.

    Returns:
        The list of readable objects reported by ``select.select`` (either
        ``[sys.stdin]`` or empty), or an empty list when ``select`` raises
        ``ValueError``.
    """
    import select

    try:
        ready = select.select([sys.stdin], [], [], timeout)[0]
    except ValueError:
        return []
    return ready
1421
+
1422
+
1423
# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())