mpiptop 0.1.1-py3-none-any.whl → 0.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mpiptop.py CHANGED
@@ -6,6 +6,7 @@ from __future__ import annotations
6
6
  import argparse
7
7
  import colorsys
8
8
  import dataclasses
9
+ import datetime
9
10
  import hashlib
10
11
  import json
11
12
  import os
@@ -52,7 +53,9 @@ class ProgramSelector:
52
53
 
53
54
  @dataclasses.dataclass(frozen=True)
54
55
  class State:
56
+ launcher: str
55
57
  prte_pid: int
58
+ slurm_job_id: Optional[str]
56
59
  rankfile: str
57
60
  ranks: List[RankInfo]
58
61
  selector: ProgramSelector
@@ -79,10 +82,40 @@ class ParsedPySpy:
79
82
  threads: List[ThreadBlock]
80
83
 
81
84
 
85
+ @dataclasses.dataclass(frozen=True)
86
+ class RankSnapshot:
87
+ output: Optional[str]
88
+ error: Optional[str]
89
+ stack_lines: List[str]
90
+ details: List[str]
91
+
92
+
93
+ @dataclasses.dataclass
94
+ class SessionEvent:
95
+ timestamp: float
96
+ ranks: Dict[int, Dict[str, object]]
97
+
98
+
99
+ @dataclasses.dataclass
100
+ class TimelineLevel:
101
+ start: int
102
+ end: int
103
+ selected: int = 0
104
+ buckets: List[Tuple[int, int]] = dataclasses.field(default_factory=list)
105
+
106
+
82
107
  PUNCT_STYLE = "grey62"
83
108
  BORDER_STYLE = "grey62"
84
109
  KEY_STYLE = "#7ad7ff"
85
110
  HEADER_HEIGHT = 3
111
+ SESSION_VERSION = 1
112
+ SESSION_LOG_FILE = "session.jsonl"
113
+ SESSION_METADATA_FILE = "metadata.json"
114
+ SESSION_EVENTS_FILE = "events.jsonl"
115
+ SPARKLINE_CHARS = "▁▂▃▄▅▆▇█"
116
+ HEARTBEAT_INTERVAL = 60
117
+ DIVERGENCE_THRESHOLD = 0.5
118
+ DIVERGENCE_INTERVAL = 60
86
119
  ENV_KEYS = (
87
120
  "PATH",
88
121
  "LD_LIBRARY_PATH",
@@ -101,6 +134,7 @@ import os
101
134
 
102
135
  TARGET = os.environ.get("MPIPTOP_TARGET", "")
103
136
  MODULE = os.environ.get("MPIPTOP_MODULE", "")
137
+ JOB_ID = os.environ.get("MPIPTOP_SLURM_JOB_ID", "")
104
138
  ENV_KEYS = [
105
139
  "PATH",
106
140
  "LD_LIBRARY_PATH",
@@ -174,6 +208,12 @@ def matches(cmd):
174
208
  return True
175
209
 
176
210
 
211
+ def matches_job(env):
212
+ if not JOB_ID:
213
+ return True
214
+ return env.get("SLURM_JOB_ID") == JOB_ID
215
+
216
+
177
217
  results = []
178
218
  for pid in os.listdir("/proc"):
179
219
  if not pid.isdigit():
@@ -183,7 +223,14 @@ for pid in os.listdir("/proc"):
183
223
  if not matches(cmd):
184
224
  continue
185
225
  env = read_env(pid)
186
- rank = env.get("OMPI_COMM_WORLD_RANK") or env.get("PMIX_RANK") or env.get("PMI_RANK")
226
+ if not matches_job(env):
227
+ continue
228
+ rank = (
229
+ env.get("OMPI_COMM_WORLD_RANK")
230
+ or env.get("PMIX_RANK")
231
+ or env.get("PMI_RANK")
232
+ or env.get("SLURM_PROCID")
233
+ )
187
234
  if rank is None:
188
235
  continue
189
236
  results.append(
@@ -203,6 +250,266 @@ print(json.dumps(results))
203
250
  """
204
251
 
205
252
 
253
+ def iso_timestamp(value: Optional[float] = None) -> str:
254
+ ts = time.time() if value is None else value
255
+ return datetime.datetime.fromtimestamp(ts).isoformat(timespec="seconds")
256
+
257
+
258
+ def default_session_path() -> str:
259
+ stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
260
+ return os.path.abspath(f"mpiptop-session-{stamp}.jsonl")
261
+
262
+
263
+ def normalize_session_path(path: str) -> Tuple[str, str]:
264
+ if path.endswith(".jsonl") or (os.path.exists(path) and os.path.isfile(path)):
265
+ base_dir = os.path.dirname(path) or "."
266
+ return base_dir, path
267
+ return path, os.path.join(path, SESSION_LOG_FILE)
268
+
269
+
270
+ def ensure_session_path(path: str) -> Tuple[str, str]:
271
+ base_dir, log_path = normalize_session_path(path)
272
+ if os.path.exists(path):
273
+ if os.path.isdir(path):
274
+ if os.listdir(path):
275
+ if os.path.exists(log_path) or os.path.exists(os.path.join(path, SESSION_METADATA_FILE)):
276
+ return base_dir, log_path
277
+ raise SystemExit(f"record path exists and is not empty: {path}")
278
+ elif os.path.isfile(path):
279
+ return base_dir, log_path
280
+ else:
281
+ raise SystemExit(f"record path exists and is not a file or directory: {path}")
282
+ else:
283
+ if log_path.endswith(".jsonl"):
284
+ os.makedirs(base_dir, exist_ok=True)
285
+ else:
286
+ os.makedirs(base_dir, exist_ok=True)
287
+ return base_dir, log_path
288
+
289
+
290
+ def write_session_metadata(log_path: str, state: State, refresh: int, pythonpath: str) -> None:
291
+ payload = {
292
+ "version": SESSION_VERSION,
293
+ "created_at": iso_timestamp(),
294
+ "refresh": refresh,
295
+ "rankfile": state.rankfile,
296
+ "prte_pid": state.prte_pid,
297
+ "launcher": state.launcher,
298
+ "slurm_job_id": state.slurm_job_id,
299
+ "selector": dataclasses.asdict(state.selector),
300
+ "ranks": [dataclasses.asdict(rank) for rank in state.ranks],
301
+ "pythonpath": pythonpath,
302
+ "record_on_change": True,
303
+ }
304
+ if os.path.exists(log_path) and os.path.getsize(log_path) > 0:
305
+ return
306
+ with open(log_path, "a", encoding="utf-8") as handle:
307
+ handle.write(json.dumps({"type": "metadata", "data": payload}) + "\n")
308
+
309
+
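For reference, the session log written by write_session_metadata and RecordSession is a JSONL file: one {"type": "metadata", ...} envelope followed by {"type": "event", ...} lines keyed by rank. A minimal standalone sketch of that shape (the host, pid, timestamp, and session.jsonl path are illustrative placeholders, not output from a real run):

    import json

    # Hypothetical records mirroring the {"type": ..., "data": ...} envelopes above.
    metadata_line = json.dumps({
        "type": "metadata",
        "data": {"version": 1, "refresh": 10, "launcher": "slurm",
                 "ranks": [{"rank": 0, "host": "node01"}]},
    })
    event_line = json.dumps({
        "type": "event",
        "data": {"t": 1700000000.0,
                 "ranks": {"0": {"host": "node01", "pid": 4242, "py_spy": "..."}}},
    })

    with open("session.jsonl", "w", encoding="utf-8") as handle:
        handle.write(metadata_line + "\n")
        handle.write(event_line + "\n")

    # Reading it back follows the shape load_session_metadata/load_session_events expect.
    with open("session.jsonl", encoding="utf-8") as handle:
        for raw in handle:
            record = json.loads(raw)
            print(record["type"], sorted(record["data"].keys()))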
310
+ def load_session_metadata(path: str) -> Dict[str, object]:
311
+ base_dir, log_path = normalize_session_path(path)
312
+ metadata_path = os.path.join(base_dir, SESSION_METADATA_FILE)
313
+ if os.path.exists(metadata_path):
314
+ with open(metadata_path, "r", encoding="utf-8") as handle:
315
+ return json.load(handle)
316
+ if not os.path.exists(log_path):
317
+ raise SystemExit(f"metadata not found in {path}")
318
+ with open(log_path, "r", encoding="utf-8") as handle:
319
+ for line in handle:
320
+ raw = line.strip()
321
+ if not raw:
322
+ continue
323
+ data = json.loads(raw)
324
+ if isinstance(data, dict) and data.get("type") == "metadata":
325
+ payload = data.get("data")
326
+ if isinstance(payload, dict):
327
+ return payload
328
+ if isinstance(data, dict) and "version" in data and "ranks" in data:
329
+ return data
330
+ raise SystemExit(f"metadata not found in {log_path}")
331
+
332
+
333
+ def read_last_event(path: str) -> Optional[Dict[str, object]]:
334
+ if not os.path.exists(path):
335
+ return None
336
+ with open(path, "rb") as handle:
337
+ handle.seek(0, os.SEEK_END)
338
+ pos = handle.tell()
339
+ if pos == 0:
340
+ return None
341
+ chunk = b""
342
+ while pos > 0:
343
+ step = min(4096, pos)
344
+ pos -= step
345
+ handle.seek(pos)
346
+ chunk = handle.read(step) + chunk
347
+ if b"\n" in chunk:
348
+ break
349
+ lines = [line for line in chunk.splitlines() if line.strip()]
350
+ while lines:
351
+ raw = lines.pop().decode("utf-8", errors="ignore")
352
+ try:
353
+ data = json.loads(raw)
354
+ except json.JSONDecodeError:
355
+ continue
356
+ if isinstance(data, dict) and data.get("type") == "metadata":
357
+ continue
358
+ if isinstance(data, dict) and data.get("type") == "event":
359
+ payload = data.get("data")
360
+ if isinstance(payload, dict):
361
+ return payload
362
+ return data
363
+ return None
364
+
365
+
366
+ def load_session_events(path: str) -> List[SessionEvent]:
367
+ base_dir, log_path = normalize_session_path(path)
368
+ events_path = os.path.join(base_dir, SESSION_EVENTS_FILE)
369
+ if not os.path.exists(events_path) and not os.path.exists(log_path):
370
+ raise SystemExit(f"events not found in {path}")
371
+ path_to_read = events_path if os.path.exists(events_path) else log_path
372
+ events: List[SessionEvent] = []
373
+ with open(path_to_read, "r", encoding="utf-8") as handle:
374
+ for line in handle:
375
+ raw = line.strip()
376
+ if not raw:
377
+ continue
378
+ data = json.loads(raw)
379
+ if isinstance(data, dict) and data.get("type") == "metadata":
380
+ continue
381
+ if isinstance(data, dict) and data.get("type") == "event":
382
+ data = data.get("data", {})
383
+ if not isinstance(data, dict):
384
+ continue
385
+ timestamp = float(data.get("t", 0.0))
386
+ ranks_raw = data.get("ranks", {})
387
+ ranks: Dict[int, Dict[str, object]] = {}
388
+ for key, value in ranks_raw.items():
389
+ try:
390
+ rank_id = int(key)
391
+ except (TypeError, ValueError):
392
+ continue
393
+ ranks[rank_id] = value
394
+ events.append(SessionEvent(timestamp=timestamp, ranks=ranks))
395
+ return events
396
+
397
+
398
+ def signature_from_snapshot(snapshot: Optional[RankSnapshot]) -> str:
399
+ if snapshot is None:
400
+ return "missing"
401
+ if snapshot.error:
402
+ return f"error:{snapshot.error}"
403
+ if snapshot.output is None:
404
+ return "missing"
405
+ digest = hashlib.sha1(snapshot.output.encode("utf-8", errors="ignore")).hexdigest()
406
+ return digest
407
+
408
+
409
+ def snapshot_signature(ranks: List[RankInfo], snapshots: Dict[int, RankSnapshot]) -> Dict[int, str]:
410
+ signature: Dict[int, str] = {}
411
+ for info in ranks:
412
+ signature[info.rank] = signature_from_snapshot(snapshots.get(info.rank))
413
+ return signature
414
+
415
+
416
+ def signature_from_event(event: Dict[str, object]) -> Optional[Dict[int, str]]:
417
+ ranks = event.get("ranks", {})
418
+ if not isinstance(ranks, dict):
419
+ return None
420
+ signature: Dict[int, str] = {}
421
+ for key, payload in ranks.items():
422
+ try:
423
+ rank_id = int(key)
424
+ except (TypeError, ValueError):
425
+ continue
426
+ if not isinstance(payload, dict):
427
+ signature[rank_id] = "missing"
428
+ continue
429
+ if payload.get("error"):
430
+ signature[rank_id] = f"error:{payload.get('error')}"
431
+ elif payload.get("py_spy"):
432
+ digest = hashlib.sha1(
433
+ str(payload.get("py_spy")).encode("utf-8", errors="ignore")
434
+ ).hexdigest()
435
+ signature[rank_id] = digest
436
+ else:
437
+ signature[rank_id] = "missing"
438
+ return signature
439
+
440
+
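To make the change detection used by RecordSession below concrete: each rank's py-spy text is reduced to a SHA-1 digest, and a new event is appended only when the per-rank signature map differs from the previous one. A standalone sketch of the same idea (the stack strings are made up):

    import hashlib

    def sig(output):
        # Same reduction as signature_from_snapshot: hash the raw py-spy text.
        return hashlib.sha1(output.encode("utf-8", errors="ignore")).hexdigest()

    previous = {0: sig("main -> solve -> mpi_wait"), 1: sig("main -> solve -> mpi_wait")}
    current = {0: sig("main -> solve -> mpi_wait"), 1: sig("main -> solve -> compute")}

    print(previous == current)  # False: rank 1 moved, so record_if_changed would append an event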
441
+ class RecordSession:
442
+ def __init__(self, path: str, state: State, refresh: int, pythonpath: str):
443
+ self.base_dir, self.log_path = ensure_session_path(path)
444
+ write_session_metadata(self.log_path, state, refresh, pythonpath)
445
+ self.handle = open(self.log_path, "a", encoding="utf-8")
446
+ self.event_count = 0
447
+ self.last_signature: Optional[Dict[int, str]] = None
448
+ last_event = read_last_event(self.log_path)
449
+ if last_event:
450
+ self.last_signature = signature_from_event(last_event)
451
+ self.event_count = self._count_events()
452
+
453
+ def _count_events(self) -> int:
454
+ if not os.path.exists(self.log_path):
455
+ return 0
456
+ count = 0
457
+ with open(self.log_path, "r", encoding="utf-8") as handle:
458
+ for line in handle:
459
+ raw = line.strip()
460
+ if not raw:
461
+ continue
462
+ try:
463
+ data = json.loads(raw)
464
+ except json.JSONDecodeError:
465
+ continue
466
+ if isinstance(data, dict) and data.get("type") == "metadata":
467
+ continue
468
+ count += 1
469
+ return count
470
+
471
+ def record_if_changed(
472
+ self,
473
+ state: State,
474
+ rank_to_proc: Dict[int, RankProcess],
475
+ snapshots: Dict[int, RankSnapshot],
476
+ ) -> bool:
477
+ signature = snapshot_signature(state.ranks, snapshots)
478
+ if self.last_signature is not None and signature == self.last_signature:
479
+ return False
480
+ payload: Dict[str, object] = {"t": time.time(), "ranks": {}}
481
+ ranks_payload: Dict[str, object] = {}
482
+ for info in state.ranks:
483
+ rank = info.rank
484
+ proc = rank_to_proc.get(rank)
485
+ snapshot = snapshots.get(rank)
486
+ entry: Dict[str, object] = {"host": info.host}
487
+ if proc is not None:
488
+ entry["pid"] = proc.pid
489
+ entry["cmdline"] = proc.cmdline
490
+ entry["rss_kb"] = proc.rss_kb
491
+ if snapshot is None:
492
+ entry["error"] = "No data"
493
+ elif snapshot.error:
494
+ entry["error"] = snapshot.error
495
+ elif snapshot.output is not None:
496
+ entry["py_spy"] = snapshot.output
497
+ else:
498
+ entry["error"] = "No data"
499
+ ranks_payload[str(rank)] = entry
500
+ payload["ranks"] = ranks_payload
501
+ self.handle.write(json.dumps({"type": "event", "data": payload}) + "\n")
502
+ self.handle.flush()
503
+ self.last_signature = signature
504
+ self.event_count += 1
505
+ return True
506
+
507
+ def close(self) -> None:
508
+ try:
509
+ self.handle.close()
510
+ except Exception:
511
+ pass
512
+
206
513
  def read_ps() -> List[Proc]:
207
514
  result = subprocess.run(
208
515
  ["ps", "-eo", "pid=,ppid=,args="],
@@ -410,6 +717,7 @@ def matches_python_cmd(cmd: List[str], selector: ProgramSelector) -> bool:
410
717
 
411
718
  def find_rank_pids_local(
412
719
  selector: ProgramSelector,
720
+ slurm_job_id: Optional[str],
413
721
  ) -> List[Tuple[int, int, str, Optional[int], Optional[str], Dict[str, str]]]:
414
722
  results: List[Tuple[int, int, str, Optional[int], Optional[str], Dict[str, str]]] = []
415
723
  for pid in os.listdir("/proc"):
@@ -428,7 +736,14 @@ def find_rank_pids_local(
428
736
  continue
429
737
  key, value = item.split(b"=", 1)
430
738
  env[key.decode(errors="ignore")] = value.decode(errors="ignore")
431
- rank = env.get("OMPI_COMM_WORLD_RANK") or env.get("PMIX_RANK") or env.get("PMI_RANK")
739
+ if slurm_job_id and env.get("SLURM_JOB_ID") != slurm_job_id:
740
+ continue
741
+ rank = (
742
+ env.get("OMPI_COMM_WORLD_RANK")
743
+ or env.get("PMIX_RANK")
744
+ or env.get("PMI_RANK")
745
+ or env.get("SLURM_PROCID")
746
+ )
432
747
  if rank is None:
433
748
  continue
434
749
  rss_kb = read_rss_kb(int(pid))
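The rank lookup now falls through several launcher-specific environment variables, since srun sets SLURM_PROCID rather than the Open MPI/PMIx names. A standalone illustration of the same precedence (the environment dicts are made up):

    def rank_from_env(env):
        # Same precedence as above: Open MPI, then PMIx, then PMI, then Slurm.
        return (
            env.get("OMPI_COMM_WORLD_RANK")
            or env.get("PMIX_RANK")
            or env.get("PMI_RANK")
            or env.get("SLURM_PROCID")
        )

    print(rank_from_env({"SLURM_JOB_ID": "12345", "SLURM_PROCID": "3"}))      # "3"
    print(rank_from_env({"OMPI_COMM_WORLD_RANK": "0", "SLURM_PROCID": "7"}))  # "0" wins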
@@ -476,12 +791,15 @@ def run_ssh(host: str, command: str, timeout: int = 8) -> subprocess.CompletedPr
476
791
 
477
792
 
478
793
  def find_rank_pids_remote(
479
- host: str, selector: ProgramSelector
794
+ host: str,
795
+ selector: ProgramSelector,
796
+ slurm_job_id: Optional[str],
480
797
  ) -> Tuple[List[Tuple[int, int, str, Optional[int], Optional[str], Dict[str, str]]], Optional[str]]:
481
798
  env_prefix = build_env_prefix(
482
799
  {
483
800
  "MPIPTOP_TARGET": selector.script or "",
484
801
  "MPIPTOP_MODULE": selector.module or "",
802
+ "MPIPTOP_SLURM_JOB_ID": slurm_job_id or "",
485
803
  }
486
804
  )
487
805
  remote_cmd = f"{env_prefix}python3 - <<'PY'\n{REMOTE_FINDER_SCRIPT}\nPY"
@@ -1134,6 +1452,9 @@ def build_header(
1134
1452
  program_lines = wrap_program_lines(state.selector, width)
1135
1453
  if not program_lines:
1136
1454
  program_lines = [Text("python")]
1455
+ for line in program_lines:
1456
+ line.no_wrap = True
1457
+ line.overflow = "crop"
1137
1458
 
1138
1459
  controls_plain = "q quit | space refresh | t threads | d details"
1139
1460
  padding = max(0, width - len(controls_plain))
@@ -1155,6 +1476,8 @@ def build_header(
1155
1476
  text.append_text(line)
1156
1477
  text.append("\n")
1157
1478
  text.append_text(line2)
1479
+ text.no_wrap = True
1480
+ text.overflow = "crop"
1158
1481
  return text, len(program_lines) + 1
1159
1482
 
1160
1483
 
@@ -1246,18 +1569,536 @@ def build_details_text(
1246
1569
  return output
1247
1570
 
1248
1571
 
1572
+ def format_elapsed(start: Optional[float]) -> str:
1573
+ if start is None:
1574
+ return "0:00"
1575
+ elapsed = max(0, int(time.time() - start))
1576
+ return format_duration(elapsed)
1577
+
1578
+
1579
+ def format_duration(elapsed: int) -> str:
1580
+ hours = elapsed // 3600
1581
+ minutes = (elapsed % 3600) // 60
1582
+ seconds = elapsed % 60
1583
+ if hours:
1584
+ return f"{hours}:{minutes:02d}:{seconds:02d}"
1585
+ return f"{minutes}:{seconds:02d}"
1586
+
1587
+
1588
+ def build_live_header(
1589
+ state: State,
1590
+ last_update: str,
1591
+ refresh: int,
1592
+ record_line: Optional[str],
1593
+ width: int,
1594
+ ) -> Tuple[Text, int]:
1595
+ program_lines = wrap_program_lines(state.selector, width)
1596
+ if not program_lines:
1597
+ program_lines = [Text("python")]
1598
+ for line in program_lines:
1599
+ line.no_wrap = True
1600
+ line.overflow = "crop"
1601
+
1602
+ record_text = None
1603
+ if record_line:
1604
+ record_text = Text()
1605
+ record_text.append("REC", style="bold red")
1606
+ record_text.append(" recording: ")
1607
+ record_text.append(record_line)
1608
+ record_text.truncate(width)
1609
+ record_text.no_wrap = True
1610
+ record_text.overflow = "crop"
1611
+
1612
+ controls_plain = "q quit | space refresh | t threads | d details | r record"
1613
+ padding = max(0, width - len(controls_plain))
1614
+ controls_line = Text(" " * padding + controls_plain)
1615
+ for token in ["q", "space", "t", "d", "r"]:
1616
+ start = controls_plain.find(token)
1617
+ if start != -1:
1618
+ controls_line.stylize(KEY_STYLE, padding + start, padding + start + len(token))
1619
+ controls_line.truncate(width)
1620
+ controls_line.no_wrap = True
1621
+ controls_line.overflow = "crop"
1622
+
1623
+ text = Text()
1624
+ for idx, line in enumerate(program_lines):
1625
+ if idx:
1626
+ text.append("\n")
1627
+ text.append_text(line)
1628
+ text.append("\n")
1629
+ if record_text is not None:
1630
+ text.append_text(record_text)
1631
+ text.append("\n")
1632
+ text.append_text(controls_line)
1633
+ text.no_wrap = True
1634
+ text.overflow = "crop"
1635
+ extra_lines = 2 if record_text is not None else 1
1636
+ return text, len(program_lines) + extra_lines
1637
+
1638
+
1639
+ def build_review_header(
1640
+ state: State,
1641
+ event_index: int,
1642
+ event_total: int,
1643
+ event_time: str,
1644
+ timeline_lines: List[Text],
1645
+ width: int,
1646
+ ) -> Tuple[Text, int]:
1647
+ program_lines = wrap_program_lines(state.selector, width)
1648
+ if not program_lines:
1649
+ program_lines = [Text("python")]
1650
+ status_line = Text(
1651
+ f"review {event_index + 1}/{event_total} | {event_time}"
1652
+ )
1653
+ status_line.truncate(width)
1654
+
1655
+ controls_plain = "q quit | left/right move | down zoom | up zoom out | t threads | d details"
1656
+ padding = max(0, width - len(controls_plain))
1657
+ controls_line = Text(" " * padding + controls_plain)
1658
+ for token in ["q", "left/right", "down", "up", "t", "d"]:
1659
+ start = controls_plain.find(token)
1660
+ if start != -1:
1661
+ controls_line.stylize(KEY_STYLE, padding + start, padding + start + len(token))
1662
+ controls_line.truncate(width)
1663
+ controls_line.no_wrap = True
1664
+ controls_line.overflow = "crop"
1665
+
1666
+ text = Text()
1667
+ for idx, line in enumerate(program_lines):
1668
+ if idx:
1669
+ text.append("\n")
1670
+ text.append_text(line)
1671
+ text.append("\n")
1672
+ text.append_text(status_line)
1673
+ for line in timeline_lines:
1674
+ text.append("\n")
1675
+ text.append_text(line)
1676
+ text.append("\n")
1677
+ text.append_text(controls_line)
1678
+ text.no_wrap = True
1679
+ text.overflow = "crop"
1680
+ return text, len(program_lines) + 1 + len(timeline_lines) + 1
1681
+
1682
+
1683
+ def build_buckets(start: int, end: int, width: int) -> List[Tuple[int, int]]:
1684
+ count = max(0, end - start)
1685
+ if count == 0:
1686
+ return []
1687
+ bucket_count = max(1, min(width, count))
1688
+ base = count // bucket_count
1689
+ remainder = count % bucket_count
1690
+ buckets: List[Tuple[int, int]] = []
1691
+ current = start
1692
+ for idx in range(bucket_count):
1693
+ size = base + (1 if idx < remainder else 0)
1694
+ buckets.append((current, current + size))
1695
+ current += size
1696
+ return buckets
1697
+
1698
+
1699
+ def divergence_color(ratio: float) -> str:
1700
+ clamped = min(1.0, max(0.0, ratio))
1701
+ intensity = clamped ** 0.7
1702
+ base = (170, 170, 170)
1703
+ hot = (255, 122, 0)
1704
+ r = int(base[0] + (hot[0] - base[0]) * intensity)
1705
+ g = int(base[1] + (hot[1] - base[1]) * intensity)
1706
+ b = int(base[2] + (hot[2] - base[2]) * intensity)
1707
+ return f"#{r:02x}{g:02x}{b:02x}"
1708
+
1709
+
1710
+ def compute_event_metrics(
1711
+ events: List[SessionEvent],
1712
+ ranks: List[RankInfo],
1713
+ show_threads: bool,
1714
+ ) -> Tuple[List[int], List[float], List[int]]:
1715
+ max_stack_lens: List[int] = []
1716
+ divergence_ratios: List[float] = []
1717
+ common_prefixes: List[int] = []
1718
+ for event in events:
1719
+ stacks_by_rank: Dict[int, List[str]] = {}
1720
+ for info in ranks:
1721
+ payload = event.ranks.get(info.rank, {})
1722
+ if payload.get("error"):
1723
+ stacks_by_rank[info.rank] = []
1724
+ continue
1725
+ output = payload.get("py_spy")
1726
+ if not output:
1727
+ stacks_by_rank[info.rank] = []
1728
+ continue
1729
+ lines, _details = render_pyspy_output(str(output), show_threads)
1730
+ stacks_by_rank[info.rank] = extract_stack_lines(lines)
1731
+ max_len = max((len(stack) for stack in stacks_by_rank.values()), default=0)
1732
+ common_len = common_prefix_length(stacks_by_rank)
1733
+ similarity = float(common_len) / float(max_len) if max_len else 0.0
1734
+ ratio = 1.0 - similarity if max_len else 0.0
1735
+ max_stack_lens.append(max_len)
1736
+ divergence_ratios.append(ratio)
1737
+ common_prefixes.append(common_len)
1738
+ return max_stack_lens, divergence_ratios, common_prefixes
1739
+
1740
+
1741
+ def render_timeline_lines(
1742
+ levels: List[TimelineLevel],
1743
+ max_stack_lens: List[int],
1744
+ divergence_ratios: List[float],
1745
+ width: int,
1746
+ ) -> List[Text]:
1747
+ lines: List[Text] = []
1748
+ for level_index, level in enumerate(levels):
1749
+ level.buckets = build_buckets(level.start, level.end, width)
1750
+ if level.buckets:
1751
+ level.selected = max(0, min(level.selected, len(level.buckets) - 1))
1752
+ stats: List[Tuple[int, float]] = []
1753
+ for start, end in level.buckets:
1754
+ bucket_heights = max_stack_lens[start:end]
1755
+ bucket_ratios = divergence_ratios[start:end]
1756
+ height = max(bucket_heights) if bucket_heights else 0
1757
+ ratio = max(bucket_ratios) if bucket_ratios else 0.0
1758
+ stats.append((height, ratio))
1759
+ max_height = max((height for height, _ in stats), default=1)
1760
+ if max_height <= 0:
1761
+ max_height = 1
1762
+ text = Text()
1763
+ for idx, (height, ratio) in enumerate(stats):
1764
+ normalized = float(height) / float(max_height) if max_height else 0.0
1765
+ level_idx = int(round(normalized * (len(SPARKLINE_CHARS) - 1)))
1766
+ level_idx = max(0, min(level_idx, len(SPARKLINE_CHARS) - 1))
1767
+ char = SPARKLINE_CHARS[level_idx]
1768
+ style = divergence_color(ratio)
1769
+ if idx == level.selected:
1770
+ if level_index == len(levels) - 1:
1771
+ style = f"{style} bold underline"
1772
+ else:
1773
+ style = f"{style} underline"
1774
+ text.append(char, style=style)
1775
+ text.no_wrap = True
1776
+ text.overflow = "crop"
1777
+ lines.append(text)
1778
+ return lines
1779
+
1780
+
1781
+ def event_snapshots_from_event(
1782
+ event: SessionEvent,
1783
+ ranks: List[RankInfo],
1784
+ show_threads: bool,
1785
+ ) -> Dict[int, RankSnapshot]:
1786
+ snapshots: Dict[int, RankSnapshot] = {}
1787
+ for info in ranks:
1788
+ payload = event.ranks.get(info.rank)
1789
+ if not payload:
1790
+ snapshots[info.rank] = RankSnapshot(
1791
+ output=None,
1792
+ error="No data",
1793
+ stack_lines=["No data"],
1794
+ details=[],
1795
+ )
1796
+ continue
1797
+ if payload.get("error"):
1798
+ snapshots[info.rank] = RankSnapshot(
1799
+ output=None,
1800
+ error=str(payload.get("error")),
1801
+ stack_lines=[str(payload.get("error"))],
1802
+ details=[],
1803
+ )
1804
+ continue
1805
+ output = payload.get("py_spy")
1806
+ if not output:
1807
+ snapshots[info.rank] = RankSnapshot(
1808
+ output=None,
1809
+ error="No data",
1810
+ stack_lines=["No data"],
1811
+ details=[],
1812
+ )
1813
+ continue
1814
+ lines, details = render_pyspy_output(str(output), show_threads)
1815
+ snapshots[info.rank] = RankSnapshot(
1816
+ output=str(output),
1817
+ error=None,
1818
+ stack_lines=lines,
1819
+ details=details,
1820
+ )
1821
+ return snapshots
1822
+
1823
+
1824
+ def rank_to_proc_from_event(
1825
+ event: SessionEvent,
1826
+ ranks: List[RankInfo],
1827
+ ) -> Dict[int, RankProcess]:
1828
+ rank_to_proc: Dict[int, RankProcess] = {}
1829
+ for info in ranks:
1830
+ payload = event.ranks.get(info.rank)
1831
+ if not payload:
1832
+ continue
1833
+ pid = payload.get("pid")
1834
+ cmdline = payload.get("cmdline")
1835
+ rss_kb = payload.get("rss_kb")
1836
+ if pid is None or cmdline is None:
1837
+ continue
1838
+ try:
1839
+ pid_value = int(pid)
1840
+ except (TypeError, ValueError):
1841
+ continue
1842
+ rss_value = None
1843
+ if rss_kb is not None:
1844
+ try:
1845
+ rss_value = int(rss_kb)
1846
+ except (TypeError, ValueError):
1847
+ rss_value = None
1848
+ rank_to_proc[info.rank] = RankProcess(
1849
+ pid=pid_value,
1850
+ cmdline=str(cmdline),
1851
+ rss_kb=rss_value,
1852
+ python_exe=None,
1853
+ env={},
1854
+ )
1855
+ return rank_to_proc
1856
+
1857
+
1858
+ def compute_divergence_from_snapshots(
1859
+ ranks: List[RankInfo], snapshots: Dict[int, RankSnapshot]
1860
+ ) -> Tuple[float, int, int]:
1861
+ stack_lines_by_rank = {
1862
+ info.rank: extract_stack_lines(snapshots.get(info.rank, RankSnapshot(None, "No data", [], [])).stack_lines)
1863
+ for info in ranks
1864
+ }
1865
+ max_len = max((len(stack) for stack in stack_lines_by_rank.values()), default=0)
1866
+ common_len = common_prefix_length(stack_lines_by_rank)
1867
+ similarity = float(common_len) / float(max_len) if max_len else 0.0
1868
+ divergence = 1.0 - similarity if max_len else 0.0
1869
+ return divergence, common_len, max_len
1870
+
1871
+
1872
+ def read_key(timeout: float) -> Optional[str]:
1873
+ if sys.stdin not in select_with_timeout(timeout):
1874
+ return None
1875
+ key = sys.stdin.read(1)
1876
+ if key != "\x1b":
1877
+ return key
1878
+ seq = key
1879
+ for _ in range(2):
1880
+ if sys.stdin in select_with_timeout(0.01):
1881
+ seq += sys.stdin.read(1)
1882
+ if seq == "\x1b[A":
1883
+ return "up"
1884
+ if seq == "\x1b[B":
1885
+ return "down"
1886
+ if seq == "\x1b[C":
1887
+ return "right"
1888
+ if seq == "\x1b[D":
1889
+ return "left"
1890
+ return None
1891
+
1892
+
1893
+ def is_pid_alive(pid: int) -> bool:
1894
+ if pid <= 0:
1895
+ return False
1896
+ try:
1897
+ os.kill(pid, 0)
1898
+ except ProcessLookupError:
1899
+ return False
1900
+ except PermissionError:
1901
+ return True
1902
+ return True
1903
+
1904
+
1905
+ def parse_scontrol_kv(line: str) -> Dict[str, str]:
1906
+ fields: Dict[str, str] = {}
1907
+ for token in line.split():
1908
+ if "=" not in token:
1909
+ continue
1910
+ key, value = token.split("=", 1)
1911
+ fields[key] = value
1912
+ return fields
1913
+
1914
+
1915
+ def run_scontrol_show_job(job_id: str) -> Dict[str, str]:
1916
+ result = subprocess.run(
1917
+ ["scontrol", "show", "job", "-o", str(job_id)],
1918
+ capture_output=True,
1919
+ text=True,
1920
+ )
1921
+ if result.returncode != 0:
1922
+ stderr = (result.stderr or result.stdout or "").strip()
1923
+ raise SystemExit(f"scontrol show job failed for {job_id}: {stderr or 'unknown error'}")
1924
+ line = (result.stdout or "").strip()
1925
+ if not line:
1926
+ raise SystemExit(f"scontrol show job returned empty output for {job_id}")
1927
+ return parse_scontrol_kv(line)
1928
+
1929
+
1930
+ def expand_slurm_nodelist(nodelist: str) -> List[str]:
1931
+ result = subprocess.run(
1932
+ ["scontrol", "show", "hostnames", nodelist],
1933
+ capture_output=True,
1934
+ text=True,
1935
+ )
1936
+ if result.returncode != 0:
1937
+ stderr = (result.stderr or result.stdout or "").strip()
1938
+ raise SystemExit(f"scontrol show hostnames failed: {stderr or 'unknown error'}")
1939
+ hosts = [line.strip() for line in result.stdout.splitlines() if line.strip()]
1940
+ if not hosts:
1941
+ raise SystemExit(f"no hosts parsed from nodelist: {nodelist}")
1942
+ return hosts
1943
+
1944
+
1945
+ def parse_tasks_per_node(raw: str) -> List[int]:
1946
+ if not raw:
1947
+ return []
1948
+ counts: List[int] = []
1949
+ for part in raw.split(","):
1950
+ part = part.strip()
1951
+ if not part:
1952
+ continue
1953
+ match = re.match(r"(\d+)\(x(\d+)\)", part)
1954
+ if match:
1955
+ value = int(match.group(1))
1956
+ repeat = int(match.group(2))
1957
+ counts.extend([value] * repeat)
1958
+ continue
1959
+ if part.isdigit():
1960
+ counts.append(int(part))
1961
+ return counts
1962
+
1963
+
1964
+ def distribute_tasks(num_tasks: int, num_nodes: int) -> List[int]:
1965
+ if num_nodes <= 0:
1966
+ return []
1967
+ base = num_tasks // num_nodes
1968
+ remainder = num_tasks % num_nodes
1969
+ counts = [base] * num_nodes
1970
+ for idx in range(remainder):
1971
+ counts[idx] += 1
1972
+ return counts
1973
+
1974
+
1975
+ def slurm_job_to_ranks(job_id: str) -> List[RankInfo]:
1976
+ info = run_scontrol_show_job(job_id)
1977
+ nodelist = info.get("NodeList") or info.get("Nodes")
1978
+ if not nodelist:
1979
+ raise SystemExit(f"no NodeList found for slurm job {job_id}")
1980
+ hosts = expand_slurm_nodelist(nodelist)
1981
+ tasks_per_node = parse_tasks_per_node(info.get("TasksPerNode", ""))
1982
+ num_tasks = 0
1983
+ try:
1984
+ num_tasks = int(info.get("NumTasks", "0") or 0)
1985
+ except ValueError:
1986
+ num_tasks = 0
1987
+
1988
+ if len(tasks_per_node) != len(hosts):
1989
+ if num_tasks > 0:
1990
+ tasks_per_node = distribute_tasks(num_tasks, len(hosts))
1991
+ else:
1992
+ tasks_per_node = [1] * len(hosts)
1993
+
1994
+ ranks: List[RankInfo] = []
1995
+ rank_id = 0
1996
+ for host, count in zip(hosts, tasks_per_node):
1997
+ for _ in range(max(0, count)):
1998
+ ranks.append(RankInfo(rank=rank_id, host=host))
1999
+ rank_id += 1
2000
+ return ranks
2001
+
2002
+
2003
+ def resolve_slurm_job_id(args: argparse.Namespace) -> Optional[str]:
2004
+ if getattr(args, "slurm_job", None):
2005
+ return str(args.slurm_job)
2006
+ env_job = os.environ.get("SLURM_JOB_ID")
2007
+ if env_job:
2008
+ return env_job
2009
+ user = os.environ.get("USER")
2010
+ if not user:
2011
+ return None
2012
+ result = subprocess.run(
2013
+ ["squeue", "-u", user, "-h", "-t", "R", "-o", "%i"],
2014
+ capture_output=True,
2015
+ text=True,
2016
+ )
2017
+ if result.returncode != 0:
2018
+ return None
2019
+ jobs = [line.strip() for line in result.stdout.splitlines() if line.strip()]
2020
+ if len(jobs) == 1:
2021
+ return jobs[0]
2022
+ return None
2023
+
2024
+
2025
+ def describe_slurm_jobs() -> str:
2026
+ user = os.environ.get("USER")
2027
+ if not user:
2028
+ return ""
2029
+ result = subprocess.run(
2030
+ ["squeue", "-u", user, "-h", "-o", "%i %t %j %R"],
2031
+ capture_output=True,
2032
+ text=True,
2033
+ )
2034
+ if result.returncode != 0:
2035
+ return ""
2036
+ lines = [line.strip() for line in result.stdout.splitlines() if line.strip()]
2037
+ if not lines:
2038
+ return ""
2039
+ return "\n".join(lines[:10])
2040
+
2041
+
2042
+ def is_slurm_job_alive(job_id: Optional[str]) -> bool:
2043
+ if not job_id:
2044
+ return False
2045
+ result = subprocess.run(
2046
+ ["squeue", "-j", str(job_id), "-h", "-o", "%t"],
2047
+ capture_output=True,
2048
+ text=True,
2049
+ )
2050
+ if result.returncode != 0:
2051
+ return False
2052
+ state = (result.stdout or "").strip()
2053
+ return bool(state)
2054
+
1249
2055
  def detect_state(args: argparse.Namespace) -> State:
1250
2056
  procs = read_ps()
1251
- prte = find_prterun(procs, args.prterun_pid)
1252
- rankfile = args.rankfile or find_rankfile_path(prte.args)
1253
- if not rankfile:
1254
- raise SystemExit("rankfile not found in prterun/mpirun args")
1255
- ranks = parse_rankfile(rankfile)
1256
- children = build_children_map(procs)
1257
- descendants = find_descendants(children, prte.pid)
1258
- program_proc = select_program(procs, descendants)
1259
- selector = parse_python_selector(program_proc.args if program_proc else "")
1260
- return State(prte_pid=prte.pid, rankfile=rankfile, ranks=ranks, selector=selector)
2057
+ prte_error = None
2058
+ prte = None
2059
+ try:
2060
+ prte = find_prterun(procs, args.prterun_pid)
2061
+ except SystemExit as exc:
2062
+ prte_error = str(exc)
2063
+
2064
+ if prte is not None:
2065
+ rankfile = args.rankfile or find_rankfile_path(prte.args)
2066
+ if not rankfile:
2067
+ raise SystemExit("rankfile not found in prterun/mpirun args")
2068
+ ranks = parse_rankfile(rankfile)
2069
+ children = build_children_map(procs)
2070
+ descendants = find_descendants(children, prte.pid)
2071
+ program_proc = select_program(procs, descendants)
2072
+ selector = parse_python_selector(program_proc.args if program_proc else "")
2073
+ return State(
2074
+ launcher="prte",
2075
+ prte_pid=prte.pid,
2076
+ slurm_job_id=None,
2077
+ rankfile=rankfile,
2078
+ ranks=ranks,
2079
+ selector=selector,
2080
+ )
2081
+
2082
+ slurm_job_id = resolve_slurm_job_id(args)
2083
+ if slurm_job_id:
2084
+ ranks = slurm_job_to_ranks(slurm_job_id)
2085
+ selector = ProgramSelector(module=None, script=None, display="")
2086
+ return State(
2087
+ launcher="slurm",
2088
+ prte_pid=0,
2089
+ slurm_job_id=slurm_job_id,
2090
+ rankfile=f"slurm:{slurm_job_id}",
2091
+ ranks=ranks,
2092
+ selector=selector,
2093
+ )
2094
+
2095
+ hint = describe_slurm_jobs()
2096
+ if hint:
2097
+ hint = "\n" + hint
2098
+ raise SystemExit(
2099
+ prte_error
2100
+ or f"no prterun/mpirun process found and no slurm job detected (try --slurm-job){hint}"
2101
+ )
1261
2102
 
1262
2103
 
1263
2104
  def collect_rank_pids(state: State) -> Tuple[Dict[int, RankProcess], List[str]]:
@@ -1268,10 +2109,10 @@ def collect_rank_pids(state: State) -> Tuple[Dict[int, RankProcess], List[str]]:
1268
2109
 
1269
2110
  for host in hosts:
1270
2111
  if is_local_host(host):
1271
- entries = find_rank_pids_local(state.selector)
2112
+ entries = find_rank_pids_local(state.selector, state.slurm_job_id)
1272
2113
  host_error = None
1273
2114
  else:
1274
- entries, host_error = find_rank_pids_remote(host, state.selector)
2115
+ entries, host_error = find_rank_pids_remote(host, state.selector, state.slurm_job_id)
1275
2116
  if host_error:
1276
2117
  errors.append(host_error)
1277
2118
  for rank, pid, cmd, rss_kb, python_exe, env_subset in entries:
@@ -1297,42 +2138,105 @@ def collect_stacks(
1297
2138
  pythonpath: str,
1298
2139
  show_threads: bool,
1299
2140
  install_attempted: set,
1300
- ) -> Tuple[Dict[int, List[str]], Dict[int, List[str]], List[str]]:
1301
- stacks: Dict[int, List[str]] = {}
1302
- details_by_rank: Dict[int, List[str]] = {}
2141
+ ) -> Tuple[Dict[int, RankSnapshot], List[str]]:
2142
+ snapshots: Dict[int, RankSnapshot] = {}
1303
2143
  errors: List[str] = []
1304
2144
  for entry in state.ranks:
1305
2145
  proc = rank_to_proc.get(entry.rank)
1306
2146
  if proc is None:
1307
- stacks[entry.rank] = ["No process"]
1308
- details_by_rank[entry.rank] = []
2147
+ snapshots[entry.rank] = RankSnapshot(
2148
+ output=None,
2149
+ error="No process",
2150
+ stack_lines=["No process"],
2151
+ details=[],
2152
+ )
1309
2153
  continue
1310
2154
  output, error = run_py_spy(entry.host, proc, pythonpath, install_attempted)
1311
2155
  if error:
1312
2156
  errors.append(error)
1313
- stacks[entry.rank] = [error]
1314
- details_by_rank[entry.rank] = []
2157
+ snapshots[entry.rank] = RankSnapshot(
2158
+ output=None,
2159
+ error=error,
2160
+ stack_lines=[error],
2161
+ details=[],
2162
+ )
1315
2163
  continue
1316
2164
  lines, details = render_pyspy_output(output or "", show_threads)
1317
- stacks[entry.rank] = lines
1318
- details_by_rank[entry.rank] = details
1319
- return stacks, details_by_rank, errors
2165
+ snapshots[entry.rank] = RankSnapshot(
2166
+ output=output,
2167
+ error=None,
2168
+ stack_lines=lines,
2169
+ details=details,
2170
+ )
2171
+ return snapshots, errors
1320
2172
 
1321
2173
 
1322
- def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
2174
+ def parse_live_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
1323
2175
  parser = argparse.ArgumentParser(description="Show MPI Python stacks across hosts.")
1324
2176
  parser.add_argument("--rankfile", help="Override rankfile path")
1325
2177
  parser.add_argument("--prterun-pid", type=int, help="PID of prterun/mpirun")
2178
+ parser.add_argument("--slurm-job", help="Slurm job ID to inspect")
1326
2179
  parser.add_argument("--refresh", type=int, default=10, help="Refresh interval (seconds)")
1327
2180
  parser.add_argument(
1328
2181
  "--pythonpath",
1329
2182
  help="PYTHONPATH to export remotely (defaults to local PYTHONPATH)",
1330
2183
  )
2184
+ parser.add_argument(
2185
+ "--out",
2186
+ help="Output path for recordings (.jsonl file or directory)",
2187
+ )
1331
2188
  return parser.parse_args(argv)
1332
2189
 
1333
2190
 
1334
- def main(argv: Optional[Sequence[str]] = None) -> int:
1335
- args = parse_args(argv)
2191
+ def parse_review_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
2192
+ parser = argparse.ArgumentParser(description="Review a recorded mpiptop session.")
2193
+ parser.add_argument("path", help="Path to a recorded session (.jsonl file or directory)")
2194
+ return parser.parse_args(argv)
2195
+
2196
+
2197
+ def parse_summarize_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
2198
+ parser = argparse.ArgumentParser(description="Summarize a recorded mpiptop session.")
2199
+ parser.add_argument("path", help="Path to a recorded session (.jsonl file or directory)")
2200
+ parser.add_argument(
2201
+ "--format",
2202
+ choices=["text", "json"],
2203
+ default="text",
2204
+ help="Output format",
2205
+ )
2206
+ parser.add_argument(
2207
+ "--top",
2208
+ type=int,
2209
+ default=5,
2210
+ help="Top signatures to report",
2211
+ )
2212
+ return parser.parse_args(argv)
2213
+
2214
+
2215
+ def parse_record_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
2216
+ parser = argparse.ArgumentParser(description="Record an mpiptop session.")
2217
+ parser.add_argument("--rankfile", help="Override rankfile path")
2218
+ parser.add_argument("--prterun-pid", type=int, help="PID of prterun/mpirun")
2219
+ parser.add_argument("--slurm-job", help="Slurm job ID to inspect")
2220
+ parser.add_argument("--refresh", type=int, default=10, help="Refresh interval (seconds)")
2221
+ parser.add_argument(
2222
+ "--pythonpath",
2223
+ help="PYTHONPATH to export remotely (defaults to local PYTHONPATH)",
2224
+ )
2225
+ parser.add_argument(
2226
+ "--out",
2227
+ help="Output path for recordings (.jsonl file or directory)",
2228
+ )
2229
+ parser.add_argument(
2230
+ "--quiet",
2231
+ action="store_true",
2232
+ help="Only print start/stop lines",
2233
+ )
2234
+ args = parser.parse_args(argv)
2235
+ args.record = True
2236
+ return args
2237
+
2238
+
2239
+ def run_live(args: argparse.Namespace) -> int:
1336
2240
  pythonpath = args.pythonpath if args.pythonpath is not None else os.environ.get("PYTHONPATH", "")
1337
2241
 
1338
2242
  state = detect_state(args)
@@ -1341,6 +2245,10 @@ def main(argv: Optional[Sequence[str]] = None) -> int:
1341
2245
  show_threads = False
1342
2246
  show_details = False
1343
2247
  install_attempted: set = set()
2248
+ record_session: Optional[RecordSession] = None
2249
+ recording_enabled = bool(getattr(args, "record", False))
2250
+ record_started_at: Optional[float] = None
2251
+ record_path = args.out
1344
2252
 
1345
2253
  def handle_sigint(_sig, _frame):
1346
2254
  raise KeyboardInterrupt
@@ -1357,32 +2265,60 @@ def main(argv: Optional[Sequence[str]] = None) -> int:
1357
2265
  last_update = "never"
1358
2266
  next_refresh = 0.0
1359
2267
 
2268
+ def start_recording() -> None:
2269
+ nonlocal record_session, recording_enabled, record_started_at, record_path
2270
+ if record_session is None:
2271
+ record_path = record_path or default_session_path()
2272
+ record_session = RecordSession(record_path, state, refresh, pythonpath)
2273
+ recording_enabled = True
2274
+ if record_started_at is None:
2275
+ record_started_at = time.time()
2276
+
2277
+ def stop_recording() -> None:
2278
+ nonlocal recording_enabled, record_started_at
2279
+ recording_enabled = False
2280
+ record_started_at = None
2281
+
2282
+ if recording_enabled:
2283
+ start_recording()
2284
+
1360
2285
  def refresh_view() -> None:
1361
- nonlocal last_update, state
1362
- rank_to_proc, pid_errors = collect_rank_pids(state)
2286
+ nonlocal last_update, state, record_session
2287
+ rank_to_proc, _pid_errors = collect_rank_pids(state)
1363
2288
  candidate = best_selector_from_procs(rank_to_proc.values())
1364
2289
  if candidate and selector_score(candidate) > selector_score(state.selector):
1365
2290
  state = dataclasses.replace(state, selector=candidate)
1366
- stacks, details_by_rank, stack_errors = collect_stacks(
2291
+ snapshots, _stack_errors = collect_stacks(
1367
2292
  state, rank_to_proc, pythonpath, show_threads, install_attempted
1368
2293
  )
2294
+ if recording_enabled and record_session is not None:
2295
+ record_session.record_if_changed(state, rank_to_proc, snapshots)
1369
2296
  stacks_text: Dict[int, Text] = {}
1370
- stack_lines_by_rank = {rank: extract_stack_lines(lines) for rank, lines in stacks.items()}
2297
+ stack_lines_by_rank = {
2298
+ rank: extract_stack_lines(snapshot.stack_lines)
2299
+ for rank, snapshot in snapshots.items()
2300
+ }
1371
2301
  prefix_len = common_prefix_length(stack_lines_by_rank)
1372
2302
  diff_index = None
1373
2303
  if any(stack_lines_by_rank.values()):
1374
- if prefix_len > 0:
1375
- diff_index = prefix_len - 1
1376
- else:
1377
- diff_index = 0
1378
- for rank, lines in stacks.items():
2304
+ diff_index = max(0, prefix_len - 1) if prefix_len > 0 else 0
2305
+ for rank, snapshot in snapshots.items():
2306
+ lines = snapshot.stack_lines
1379
2307
  marked = mark_diff_line(lines, diff_index) if diff_index is not None else lines
1380
2308
  stacks_text[rank] = style_lines(marked)
1381
- errors = pid_errors + stack_errors
2309
+ details_by_rank = {
2310
+ rank: snapshot.details for rank, snapshot in snapshots.items()
2311
+ }
1382
2312
  last_update = time.strftime("%H:%M:%S")
1383
2313
  width, height = shutil.get_terminal_size((120, 40))
1384
2314
  content_width = max(0, width - 4)
1385
- header, header_lines = build_header(state, last_update, errors, refresh, content_width)
2315
+ record_line = None
2316
+ if record_session is not None and recording_enabled:
2317
+ record_line = f"{record_session.log_path} | events {record_session.event_count} | {format_elapsed(record_started_at)}"
2318
+ record_line = shorten(record_line, max(10, content_width - 12))
2319
+ header, header_lines = build_live_header(
2320
+ state, last_update, refresh, record_line, content_width
2321
+ )
1386
2322
  header_height = header_lines + 2
1387
2323
  header_height = max(3, min(header_height, max(3, height - 1)))
1388
2324
  layout["header"].size = header_height
@@ -1412,26 +2348,405 @@ def main(argv: Optional[Sequence[str]] = None) -> int:
1412
2348
  refresh_view()
1413
2349
  next_refresh = now + refresh
1414
2350
 
1415
- if sys.stdin in select_with_timeout(0.1):
1416
- key = sys.stdin.read(1)
1417
- if key == "q":
1418
- return 0
1419
- if key == " ":
1420
- next_refresh = 0.0
1421
- if key == "t":
1422
- show_threads = not show_threads
1423
- next_refresh = 0.0
1424
- if key == "d":
1425
- show_details = not show_details
1426
- next_refresh = 0.0
2351
+ key = read_key(0.1)
2352
+ if key is None:
2353
+ continue
2354
+ if key == "q":
2355
+ return 0
2356
+ if key == " ":
2357
+ next_refresh = 0.0
2358
+ if key == "t":
2359
+ show_threads = not show_threads
2360
+ next_refresh = 0.0
2361
+ if key == "d":
2362
+ show_details = not show_details
2363
+ next_refresh = 0.0
2364
+ if key == "r":
2365
+ if recording_enabled:
2366
+ stop_recording()
2367
+ else:
2368
+ start_recording()
2369
+ next_refresh = 0.0
1427
2370
  except KeyboardInterrupt:
1428
2371
  return 0
1429
2372
  finally:
1430
2373
  termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
2374
+ if record_session is not None:
2375
+ record_session.close()
2376
+ if record_session.event_count > 0:
2377
+ print(f"Recording saved to: {record_session.log_path}")
1431
2378
 
1432
2379
  return 0
1433
2380
 
1434
2381
 
2382
+ def run_record_batch(args: argparse.Namespace) -> int:
2383
+ pythonpath = args.pythonpath if args.pythonpath is not None else os.environ.get("PYTHONPATH", "")
2384
+ state = detect_state(args)
2385
+ refresh = max(1, args.refresh)
2386
+ record_path = args.out or default_session_path()
2387
+ record_session = RecordSession(record_path, state, refresh, pythonpath)
2388
+ quiet = bool(args.quiet)
2389
+ install_attempted: set = set()
2390
+ start_time = time.time()
2391
+ last_change: Optional[float] = None
2392
+ last_heartbeat = start_time
2393
+ last_divergence_time = 0.0
2394
+ stop_reason = "completed"
2395
+
2396
+ target = state.selector.display or "python"
2397
+ target = shorten(target, 120)
2398
+ print(
2399
+ f"recording start | path={record_session.log_path} | ranks={len(state.ranks)} | "
2400
+ f"refresh={refresh}s | target={target}"
2401
+ )
2402
+
2403
+ try:
2404
+ while True:
2405
+ loop_start = time.time()
2406
+ if state.launcher == "prte":
2407
+ if not is_pid_alive(state.prte_pid):
2408
+ stop_reason = "prterun-exited"
2409
+ break
2410
+ else:
2411
+ if not is_slurm_job_alive(state.slurm_job_id):
2412
+ stop_reason = "slurm-job-exited"
2413
+ break
2414
+ rank_to_proc, _pid_errors = collect_rank_pids(state)
2415
+ snapshots, _stack_errors = collect_stacks(
2416
+ state, rank_to_proc, pythonpath, False, install_attempted
2417
+ )
2418
+ if record_session.record_if_changed(state, rank_to_proc, snapshots):
2419
+ last_change = time.time()
2420
+ divergence, common_len, max_len = compute_divergence_from_snapshots(state.ranks, snapshots)
2421
+ now = time.time()
2422
+ if not quiet and now - last_heartbeat >= HEARTBEAT_INTERVAL:
2423
+ last_change_age = "never"
2424
+ if last_change is not None:
2425
+ last_change_age = format_duration(int(now - last_change))
2426
+ elapsed = format_duration(int(now - start_time))
2427
+ print(
2428
+ f"heartbeat | events={record_session.event_count} | "
2429
+ f"last_change={last_change_age} | elapsed={elapsed}"
2430
+ )
2431
+ last_heartbeat = now
2432
+ if (
2433
+ not quiet
2434
+ and divergence >= DIVERGENCE_THRESHOLD
2435
+ and now - last_divergence_time >= DIVERGENCE_INTERVAL
2436
+ ):
2437
+ print(
2438
+ f"divergence | ratio={divergence:.2f} | common={common_len} | max={max_len}"
2439
+ )
2440
+ last_divergence_time = now
2441
+ elapsed = time.time() - loop_start
2442
+ sleep_for = refresh - elapsed
2443
+ if sleep_for > 0:
2444
+ time.sleep(sleep_for)
2445
+ except KeyboardInterrupt:
2446
+ stop_reason = "interrupted"
2447
+ finally:
2448
+ record_session.close()
2449
+ elapsed = format_duration(int(time.time() - start_time))
2450
+ print(
2451
+ f"recording stop | reason={stop_reason} | events={record_session.event_count} | "
2452
+ f"elapsed={elapsed} | path={record_session.log_path}"
2453
+ )
2454
+
2455
+ return 0
2456
+
2457
+
2458
+ def run_review(args: argparse.Namespace) -> int:
2459
+ metadata = load_session_metadata(args.path)
2460
+ ranks = [
2461
+ RankInfo(rank=int(item["rank"]), host=str(item["host"]))
2462
+ for item in metadata.get("ranks", [])
2463
+ if "rank" in item and "host" in item
2464
+ ]
2465
+ if not ranks:
2466
+ raise SystemExit("no ranks found in metadata")
2467
+ selector_payload = metadata.get("selector", {}) if isinstance(metadata.get("selector"), dict) else {}
2468
+ selector = ProgramSelector(
2469
+ module=selector_payload.get("module"),
2470
+ script=selector_payload.get("script"),
2471
+ display=selector_payload.get("display", ""),
2472
+ )
2473
+ state = State(
2474
+ launcher=str(metadata.get("launcher", "prte")),
2475
+ prte_pid=int(metadata.get("prte_pid", 0) or 0),
2476
+ slurm_job_id=metadata.get("slurm_job_id"),
2477
+ rankfile=str(metadata.get("rankfile", "")),
2478
+ ranks=ranks,
2479
+ selector=selector,
2480
+ )
2481
+ events = load_session_events(args.path)
2482
+ if not events:
2483
+ raise SystemExit("no events recorded")
2484
+
2485
+ console = Console()
2486
+ show_threads = False
2487
+ show_details = False
2488
+ levels = [TimelineLevel(0, len(events), selected=0)]
2489
+ max_stack_lens, divergence_ratios, _ = compute_event_metrics(
2490
+ events, ranks, show_threads
2491
+ )
2492
+
2493
+ def handle_sigint(_sig, _frame):
2494
+ raise KeyboardInterrupt
2495
+
2496
+ signal.signal(signal.SIGINT, handle_sigint)
2497
+
2498
+ fd = sys.stdin.fileno()
2499
+ old_settings = termios.tcgetattr(fd)
2500
+ tty.setcbreak(fd)
2501
+
2502
+ layout = Layout()
2503
+ layout.split_column(Layout(name="header", size=HEADER_HEIGHT), Layout(name="body"))
2504
+
2505
+ def refresh_view() -> None:
2506
+ width, height = shutil.get_terminal_size((120, 40))
2507
+ content_width = max(0, width - 4)
2508
+ timeline_lines = render_timeline_lines(levels, max_stack_lens, divergence_ratios, content_width)
2509
+ active_level = levels[-1]
2510
+ if not active_level.buckets:
2511
+ return
2512
+ current_index = active_level.buckets[active_level.selected][0]
2513
+ current_index = max(0, min(current_index, len(events) - 1))
2514
+ event = events[current_index]
2515
+ snapshots = event_snapshots_from_event(event, ranks, show_threads)
2516
+ rank_to_proc = rank_to_proc_from_event(event, ranks)
2517
+ stack_lines_by_rank = {
2518
+ rank: extract_stack_lines(snapshot.stack_lines)
2519
+ for rank, snapshot in snapshots.items()
2520
+ }
2521
+ prefix_len = common_prefix_length(stack_lines_by_rank)
2522
+ diff_index = None
2523
+ if any(stack_lines_by_rank.values()):
2524
+ diff_index = max(0, prefix_len - 1) if prefix_len > 0 else 0
2525
+ stacks_text: Dict[int, Text] = {}
2526
+ for rank, snapshot in snapshots.items():
2527
+ lines = snapshot.stack_lines
2528
+ marked = mark_diff_line(lines, diff_index) if diff_index is not None else lines
2529
+ stacks_text[rank] = style_lines(marked)
2530
+ details_by_rank = {
2531
+ rank: snapshot.details for rank, snapshot in snapshots.items()
2532
+ }
2533
+ event_time = iso_timestamp(event.timestamp)
2534
+ header, header_lines = build_review_header(
2535
+ state,
2536
+ current_index,
2537
+ len(events),
2538
+ event_time,
2539
+ timeline_lines,
2540
+ content_width,
2541
+ )
2542
+ header_height = header_lines + 2
2543
+ header_height = max(3, min(header_height, max(3, height - 1)))
2544
+ layout["header"].size = header_height
2545
+ body_height = max(1, height - header_height)
2546
+ total_columns = len(ranks) + (1 if show_details else 0)
2547
+ column_width = max(1, content_width // max(1, total_columns))
2548
+ inner_width = max(1, column_width - 4)
2549
+ details_text = (
2550
+ build_details_text(ranks, rank_to_proc, details_by_rank, inner_width)
2551
+ if show_details
2552
+ else None
2553
+ )
2554
+ layout["header"].update(
2555
+ Panel(header, padding=(0, 1), border_style=BORDER_STYLE)
2556
+ )
2557
+ layout["body"].update(
2558
+ render_columns(ranks, stacks_text, details_text, body_height, rank_to_proc)
2559
+ )
2560
+
2561
+ try:
2562
+ refresh_view()
2563
+ with Live(layout, console=console, refresh_per_second=10, screen=True):
2564
+ while True:
2565
+ key = read_key(0.1)
2566
+ if key is None:
2567
+ continue
2568
+ if key == "q":
2569
+ return 0
2570
+ if key == "t":
2571
+ show_threads = not show_threads
2572
+ max_stack_lens, divergence_ratios, _ = compute_event_metrics(
2573
+ events, ranks, show_threads
2574
+ )
2575
+ refresh_view()
2576
+ if key == "d":
2577
+ show_details = not show_details
2578
+ refresh_view()
2579
+ if key == "left":
2580
+ level = levels[-1]
2581
+ level.selected = max(0, level.selected - 1)
2582
+ refresh_view()
2583
+ if key == "right":
2584
+ level = levels[-1]
2585
+ level.selected = min(max(0, len(level.buckets) - 1), level.selected + 1)
2586
+ refresh_view()
2587
+ if key == "down":
2588
+ level = levels[-1]
2589
+ if not level.buckets:
2590
+ continue
2591
+ bucket = level.buckets[level.selected]
2592
+ if bucket[1] - bucket[0] <= 1:
2593
+ continue
2594
+ levels.append(TimelineLevel(bucket[0], bucket[1], selected=0))
2595
+ refresh_view()
2596
+ if key == "up":
2597
+ if len(levels) > 1:
2598
+ levels.pop()
2599
+ refresh_view()
2600
+ except KeyboardInterrupt:
2601
+ return 0
2602
+ finally:
2603
+ termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
2604
+
2605
+ return 0
2606
+
2607
+
2608
+ def run_summarize(args: argparse.Namespace) -> int:
2609
+ metadata = load_session_metadata(args.path)
2610
+ events = load_session_events(args.path)
2611
+ ranks = [
2612
+ RankInfo(rank=int(item["rank"]), host=str(item["host"]))
2613
+ for item in metadata.get("ranks", [])
2614
+ if "rank" in item and "host" in item
2615
+ ]
2616
+ if not ranks:
2617
+ raise SystemExit("no ranks found in metadata")
2618
+ if not events:
2619
+ raise SystemExit("no events recorded")
2620
+
2621
+ rank_order = [info.rank for info in ranks]
2622
+ signature_counts: Dict[Tuple[str, ...], int] = {}
2623
+ signature_examples: Dict[Tuple[str, ...], Dict[int, str]] = {}
2624
+ rank_change_counts: Dict[int, int] = {rank: 0 for rank in rank_order}
2625
+ previous_rank_signature: Dict[int, str] = {rank: "" for rank in rank_order}
2626
+ max_stack_lens, divergence_ratios, common_prefixes = compute_event_metrics(
2627
+ events, ranks, show_threads=False
2628
+ )
2629
+
2630
+ for event in events:
2631
+ per_rank_signature: Dict[int, str] = {}
2632
+ per_rank_top_frame: Dict[int, str] = {}
2633
+ for info in ranks:
2634
+ payload = event.ranks.get(info.rank, {})
2635
+ if payload.get("error"):
2636
+ signature = f"error:{payload.get('error')}"
2637
+ top_frame = signature
2638
+ else:
2639
+ output = payload.get("py_spy")
2640
+ if output:
2641
+ lines, _details = render_pyspy_output(str(output), show_threads=False)
2642
+ stack_lines = extract_stack_lines(lines)
2643
+ signature = hashlib.sha1(
2644
+ "\n".join(stack_lines).encode("utf-8", errors="ignore")
2645
+ ).hexdigest()
2646
+ top_frame = stack_lines[0].strip() if stack_lines else "empty"
2647
+ else:
2648
+ signature = "empty"
2649
+ top_frame = "empty"
2650
+ per_rank_signature[info.rank] = signature
2651
+ per_rank_top_frame[info.rank] = top_frame
2652
+
2653
+ for rank, signature in per_rank_signature.items():
2654
+ if previous_rank_signature.get(rank) != signature:
2655
+ rank_change_counts[rank] = rank_change_counts.get(rank, 0) + 1
2656
+ previous_rank_signature[rank] = signature
2657
+
2658
+ signature_key = tuple(per_rank_signature[rank] for rank in rank_order)
2659
+ signature_counts[signature_key] = signature_counts.get(signature_key, 0) + 1
2660
+ if signature_key not in signature_examples:
2661
+ signature_examples[signature_key] = per_rank_top_frame
2662
+
2663
+ sorted_signatures = sorted(
2664
+ signature_counts.items(), key=lambda item: item[1], reverse=True
2665
+ )
2666
+ top_signatures = sorted_signatures[: max(1, args.top)]
2667
+ total_events = len(events)
2668
+ start_time = iso_timestamp(events[0].timestamp)
2669
+ end_time = iso_timestamp(events[-1].timestamp)
2670
+
2671
+ if args.format == "json":
2672
+ payload = {
2673
+ "metadata": metadata,
2674
+ "event_count": total_events,
2675
+ "time_range": {"start": start_time, "end": end_time},
2676
+ "rank_change_counts": rank_change_counts,
2677
+ "top_signatures": [
2678
+ {
2679
+ "count": count,
2680
+ "ratio": count / float(total_events),
2681
+ "example_top_frames": signature_examples.get(signature_key, {}),
2682
+ }
2683
+ for signature_key, count in top_signatures
2684
+ ],
2685
+ "most_divergent": sorted(
2686
+ [
2687
+ {
2688
+ "index": idx,
2689
+ "timestamp": iso_timestamp(events[idx].timestamp),
2690
+ "divergence_ratio": divergence_ratios[idx],
2691
+ "common_prefix_len": common_prefixes[idx],
2692
+ "max_stack_len": max_stack_lens[idx],
2693
+ }
2694
+ for idx in range(total_events)
2695
+ ],
2696
+ key=lambda item: item["divergence_ratio"],
2697
+ reverse=True,
2698
+ )[:5],
2699
+ }
2700
+ print(json.dumps(payload, indent=2, sort_keys=True))
2701
+ return 0
2702
+
2703
+ print(f"Session: {args.path}")
2704
+ print(f"Events: {total_events} ({start_time} -> {end_time})")
2705
+ print(f"Ranks: {', '.join(str(rank) for rank in rank_order)}")
2706
+ print("")
2707
+ print("Top stack signatures:")
2708
+ for idx, (signature_key, count) in enumerate(top_signatures, start=1):
2709
+ ratio = count / float(total_events)
2710
+ print(f"{idx}. {count} events ({ratio:.1%})")
2711
+ example = signature_examples.get(signature_key, {})
2712
+ for rank in rank_order:
2713
+ frame = example.get(rank, "")
2714
+ frame = shorten(frame, 120)
2715
+ print(f" rank {rank}: {frame}")
2716
+ print("")
2717
+ print("Rank change counts:")
2718
+ for rank in rank_order:
2719
+ print(f" rank {rank}: {rank_change_counts.get(rank, 0)}")
2720
+ print("")
2721
+ print("Most divergent events:")
2722
+ divergent = sorted(
2723
+ range(total_events),
2724
+ key=lambda idx: divergence_ratios[idx],
2725
+ reverse=True,
2726
+ )[:5]
2727
+ for idx in divergent:
2728
+ print(
2729
+ f" #{idx + 1} @ {iso_timestamp(events[idx].timestamp)} | "
2730
+ f"ratio {divergence_ratios[idx]:.2f} | "
2731
+ f"common {common_prefixes[idx]} | "
2732
+ f"max {max_stack_lens[idx]}"
2733
+ )
2734
+ return 0
2735
+
2736
+
2737
+ def main(argv: Optional[Sequence[str]] = None) -> int:
2738
+ argv = list(argv) if argv is not None else sys.argv[1:]
2739
+ if argv and argv[0] in {"review", "summarize", "record"}:
2740
+ command = argv[0]
2741
+ sub_args = argv[1:]
2742
+ if command == "review":
2743
+ return run_review(parse_review_args(sub_args))
2744
+ if command == "record":
2745
+ return run_record_batch(parse_record_args(sub_args))
2746
+ return run_summarize(parse_summarize_args(sub_args))
2747
+ return run_live(parse_live_args(argv))
2748
+
2749
+
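The rewritten entry point dispatches on a leading subcommand (record, review, summarize) and otherwise falls back to the live view. Assuming mpiptop.py is importable as mpiptop, and noting that record and the live view need a running prterun or Slurm job while review and summarize need an existing recording, the argparse definitions above can be exercised roughly as follows; the job id and paths are placeholders:

    import mpiptop

    # Headless recording of a Slurm job, change-triggered events only.
    mpiptop.main(["record", "--slurm-job", "12345", "--out", "run1.jsonl", "--quiet"])

    # Interactive timeline playback of the recording.
    mpiptop.main(["review", "run1.jsonl"])

    # Machine-readable summary of the same session.
    mpiptop.main(["summarize", "run1.jsonl", "--format", "json", "--top", "3"])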
1435
2750
  def select_with_timeout(timeout: float):
1436
2751
  import select
1437
2752