mpiptop 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mpiptop
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: TUI for viewing MPI Python stacks across hosts
5
5
  Author: yieldthought
6
6
  License-Expression: MIT
@@ -46,11 +46,14 @@ Common options:
46
46
  ```bash
47
47
  mpiptop --rankfile /etc/mpirun/rankfile_01_02
48
48
  mpiptop --prterun-pid 12345
49
+ mpiptop --slurm-job 123456
49
50
  mpiptop --refresh 5
50
51
  mpiptop --pythonpath /path/to/your/code
51
52
  mpiptop record --out ./mpiptop-session-20260123-120000.jsonl
52
53
  ```
53
54
 
55
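+ Slurm notes: when `--slurm-job` is not given, mpiptop auto-detects the job from `SLURM_JOB_ID` if it is set, or otherwise uses your only running job when exactly one is running. Slurm detection is attempted only when no prterun/mpirun process is found.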
56
+
54
57
  Record/review (record is batch mode; use plain `mpiptop` for the TUI):
55
58
  ```bash
56
59
  mpiptop record
@@ -0,0 +1,7 @@
1
+ mpiptop.py,sha256=-obFdgSQcVsdVE38b8NaxtS1yRFQCWR2xGE1NJDKozU,93237
2
+ mpiptop-0.2.1.dist-info/licenses/LICENSE,sha256=ChKmQ8qCXxdXRR_HIJECjIA5NLWlUTEJWh7Xkhm2wAA,1069
3
+ mpiptop-0.2.1.dist-info/METADATA,sha256=pIp11zp75Fs2Y0PRvcb1s4bq7tACnI42EE_1fgxIjtY,2042
4
+ mpiptop-0.2.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
5
+ mpiptop-0.2.1.dist-info/entry_points.txt,sha256=RsGsr8GBLfUNpb432YWS5gz4MWfWdK9xJRr1SmdnLo8,41
6
+ mpiptop-0.2.1.dist-info/top_level.txt,sha256=c2Vdu6tTg0DEPUWD8Odyods7fXsPWMQ2kSvjdKiTClc,8
7
+ mpiptop-0.2.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.10.1)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
mpiptop.py CHANGED
@@ -53,7 +53,9 @@ class ProgramSelector:
53
53
 
54
54
  @dataclasses.dataclass(frozen=True)
55
55
  class State:
56
+ launcher: str
56
57
  prte_pid: int
58
+ slurm_job_id: Optional[str]
57
59
  rankfile: str
58
60
  ranks: List[RankInfo]
59
61
  selector: ProgramSelector
@@ -132,6 +134,7 @@ import os
132
134
 
133
135
  TARGET = os.environ.get("MPIPTOP_TARGET", "")
134
136
  MODULE = os.environ.get("MPIPTOP_MODULE", "")
137
+ JOB_ID = os.environ.get("MPIPTOP_SLURM_JOB_ID", "")
135
138
  ENV_KEYS = [
136
139
  "PATH",
137
140
  "LD_LIBRARY_PATH",
@@ -205,6 +208,12 @@ def matches(cmd):
205
208
  return True
206
209
 
207
210
 
211
+ def matches_job(env):
212
+ if not JOB_ID:
213
+ return True
214
+ return env.get("SLURM_JOB_ID") == JOB_ID
215
+
216
+
208
217
  results = []
209
218
  for pid in os.listdir("/proc"):
210
219
  if not pid.isdigit():
@@ -214,7 +223,14 @@ for pid in os.listdir("/proc"):
214
223
  if not matches(cmd):
215
224
  continue
216
225
  env = read_env(pid)
217
- rank = env.get("OMPI_COMM_WORLD_RANK") or env.get("PMIX_RANK") or env.get("PMI_RANK")
226
+ if not matches_job(env):
227
+ continue
228
+ rank = (
229
+ env.get("OMPI_COMM_WORLD_RANK")
230
+ or env.get("PMIX_RANK")
231
+ or env.get("PMI_RANK")
232
+ or env.get("SLURM_PROCID")
233
+ )
218
234
  if rank is None:
219
235
  continue
220
236
  results.append(
@@ -278,6 +294,8 @@ def write_session_metadata(log_path: str, state: State, refresh: int, pythonpath
278
294
  "refresh": refresh,
279
295
  "rankfile": state.rankfile,
280
296
  "prte_pid": state.prte_pid,
297
+ "launcher": state.launcher,
298
+ "slurm_job_id": state.slurm_job_id,
281
299
  "selector": dataclasses.asdict(state.selector),
282
300
  "ranks": [dataclasses.asdict(rank) for rank in state.ranks],
283
301
  "pythonpath": pythonpath,
@@ -699,6 +717,7 @@ def matches_python_cmd(cmd: List[str], selector: ProgramSelector) -> bool:
699
717
 
700
718
  def find_rank_pids_local(
701
719
  selector: ProgramSelector,
720
+ slurm_job_id: Optional[str],
702
721
  ) -> List[Tuple[int, int, str, Optional[int], Optional[str], Dict[str, str]]]:
703
722
  results: List[Tuple[int, int, str, Optional[int], Optional[str], Dict[str, str]]] = []
704
723
  for pid in os.listdir("/proc"):
@@ -717,7 +736,14 @@ def find_rank_pids_local(
717
736
  continue
718
737
  key, value = item.split(b"=", 1)
719
738
  env[key.decode(errors="ignore")] = value.decode(errors="ignore")
720
- rank = env.get("OMPI_COMM_WORLD_RANK") or env.get("PMIX_RANK") or env.get("PMI_RANK")
739
+ if slurm_job_id and env.get("SLURM_JOB_ID") != slurm_job_id:
740
+ continue
741
+ rank = (
742
+ env.get("OMPI_COMM_WORLD_RANK")
743
+ or env.get("PMIX_RANK")
744
+ or env.get("PMI_RANK")
745
+ or env.get("SLURM_PROCID")
746
+ )
721
747
  if rank is None:
722
748
  continue
723
749
  rss_kb = read_rss_kb(int(pid))
@@ -765,12 +791,15 @@ def run_ssh(host: str, command: str, timeout: int = 8) -> subprocess.CompletedPr
765
791
 
766
792
 
767
793
  def find_rank_pids_remote(
768
- host: str, selector: ProgramSelector
794
+ host: str,
795
+ selector: ProgramSelector,
796
+ slurm_job_id: Optional[str],
769
797
  ) -> Tuple[List[Tuple[int, int, str, Optional[int], Optional[str], Dict[str, str]]], Optional[str]]:
770
798
  env_prefix = build_env_prefix(
771
799
  {
772
800
  "MPIPTOP_TARGET": selector.script or "",
773
801
  "MPIPTOP_MODULE": selector.module or "",
802
+ "MPIPTOP_SLURM_JOB_ID": slurm_job_id or "",
774
803
  }
775
804
  )
776
805
  remote_cmd = f"{env_prefix}python3 - <<'PY'\n{REMOTE_FINDER_SCRIPT}\nPY"
@@ -1872,18 +1901,204 @@ def is_pid_alive(pid: int) -> bool:
1872
1901
  return True
1873
1902
  return True
1874
1903
 
1904
+
1905
+ def parse_scontrol_kv(line: str) -> Dict[str, str]:
1906
+ fields: Dict[str, str] = {}
1907
+ for token in line.split():
1908
+ if "=" not in token:
1909
+ continue
1910
+ key, value = token.split("=", 1)
1911
+ fields[key] = value
1912
+ return fields
1913
+
1914
+
1915
+ def run_scontrol_show_job(job_id: str) -> Dict[str, str]:
1916
+ result = subprocess.run(
1917
+ ["scontrol", "show", "job", "-o", str(job_id)],
1918
+ capture_output=True,
1919
+ text=True,
1920
+ )
1921
+ if result.returncode != 0:
1922
+ stderr = (result.stderr or result.stdout or "").strip()
1923
+ raise SystemExit(f"scontrol show job failed for {job_id}: {stderr or 'unknown error'}")
1924
+ line = (result.stdout or "").strip()
1925
+ if not line:
1926
+ raise SystemExit(f"scontrol show job returned empty output for {job_id}")
1927
+ return parse_scontrol_kv(line)
1928
+
1929
+
1930
+ def expand_slurm_nodelist(nodelist: str) -> List[str]:
1931
+ result = subprocess.run(
1932
+ ["scontrol", "show", "hostnames", nodelist],
1933
+ capture_output=True,
1934
+ text=True,
1935
+ )
1936
+ if result.returncode != 0:
1937
+ stderr = (result.stderr or result.stdout or "").strip()
1938
+ raise SystemExit(f"scontrol show hostnames failed: {stderr or 'unknown error'}")
1939
+ hosts = [line.strip() for line in result.stdout.splitlines() if line.strip()]
1940
+ if not hosts:
1941
+ raise SystemExit(f"no hosts parsed from nodelist: {nodelist}")
1942
+ return hosts
1943
+
1944
+
1945
+ def parse_tasks_per_node(raw: str) -> List[int]:
1946
+ if not raw:
1947
+ return []
1948
+ counts: List[int] = []
1949
+ for part in raw.split(","):
1950
+ part = part.strip()
1951
+ if not part:
1952
+ continue
1953
+ match = re.match(r"(\d+)\(x(\d+)\)", part)
1954
+ if match:
1955
+ value = int(match.group(1))
1956
+ repeat = int(match.group(2))
1957
+ counts.extend([value] * repeat)
1958
+ continue
1959
+ if part.isdigit():
1960
+ counts.append(int(part))
1961
+ return counts
1962
+
1963
+
1964
+ def distribute_tasks(num_tasks: int, num_nodes: int) -> List[int]:
1965
+ if num_nodes <= 0:
1966
+ return []
1967
+ base = num_tasks // num_nodes
1968
+ remainder = num_tasks % num_nodes
1969
+ counts = [base] * num_nodes
1970
+ for idx in range(remainder):
1971
+ counts[idx] += 1
1972
+ return counts
1973
+
1974
+
1975
+ def slurm_job_to_ranks(job_id: str) -> List[RankInfo]:
1976
+ info = run_scontrol_show_job(job_id)
1977
+ nodelist = info.get("NodeList") or info.get("Nodes")
1978
+ if not nodelist:
1979
+ raise SystemExit(f"no NodeList found for slurm job {job_id}")
1980
+ hosts = expand_slurm_nodelist(nodelist)
1981
+ tasks_per_node = parse_tasks_per_node(info.get("TasksPerNode", ""))
1982
+ num_tasks = 0
1983
+ try:
1984
+ num_tasks = int(info.get("NumTasks", "0") or 0)
1985
+ except ValueError:
1986
+ num_tasks = 0
1987
+
1988
+ if len(tasks_per_node) != len(hosts):
1989
+ if num_tasks > 0:
1990
+ tasks_per_node = distribute_tasks(num_tasks, len(hosts))
1991
+ else:
1992
+ tasks_per_node = [1] * len(hosts)
1993
+
1994
+ ranks: List[RankInfo] = []
1995
+ rank_id = 0
1996
+ for host, count in zip(hosts, tasks_per_node):
1997
+ for _ in range(max(0, count)):
1998
+ ranks.append(RankInfo(rank=rank_id, host=host))
1999
+ rank_id += 1
2000
+ return ranks
2001
+
2002
+
2003
+ def resolve_slurm_job_id(args: argparse.Namespace) -> Optional[str]:
2004
+ if getattr(args, "slurm_job", None):
2005
+ return str(args.slurm_job)
2006
+ env_job = os.environ.get("SLURM_JOB_ID")
2007
+ if env_job:
2008
+ return env_job
2009
+ user = os.environ.get("USER")
2010
+ if not user:
2011
+ return None
2012
+ result = subprocess.run(
2013
+ ["squeue", "-u", user, "-h", "-t", "R", "-o", "%i"],
2014
+ capture_output=True,
2015
+ text=True,
2016
+ )
2017
+ if result.returncode != 0:
2018
+ return None
2019
+ jobs = [line.strip() for line in result.stdout.splitlines() if line.strip()]
2020
+ if len(jobs) == 1:
2021
+ return jobs[0]
2022
+ return None
2023
+
2024
+
2025
+ def describe_slurm_jobs() -> str:
2026
+ user = os.environ.get("USER")
2027
+ if not user:
2028
+ return ""
2029
+ result = subprocess.run(
2030
+ ["squeue", "-u", user, "-h", "-o", "%i %t %j %R"],
2031
+ capture_output=True,
2032
+ text=True,
2033
+ )
2034
+ if result.returncode != 0:
2035
+ return ""
2036
+ lines = [line.strip() for line in result.stdout.splitlines() if line.strip()]
2037
+ if not lines:
2038
+ return ""
2039
+ return "\n".join(lines[:10])
2040
+
2041
+
2042
+ def is_slurm_job_alive(job_id: Optional[str]) -> bool:
2043
+ if not job_id:
2044
+ return False
2045
+ result = subprocess.run(
2046
+ ["squeue", "-j", str(job_id), "-h", "-o", "%t"],
2047
+ capture_output=True,
2048
+ text=True,
2049
+ )
2050
+ if result.returncode != 0:
2051
+ return False
2052
+ state = (result.stdout or "").strip()
2053
+ return bool(state)
2054
+
1875
2055
  def detect_state(args: argparse.Namespace) -> State:
1876
2056
  procs = read_ps()
1877
- prte = find_prterun(procs, args.prterun_pid)
1878
- rankfile = args.rankfile or find_rankfile_path(prte.args)
1879
- if not rankfile:
1880
- raise SystemExit("rankfile not found in prterun/mpirun args")
1881
- ranks = parse_rankfile(rankfile)
1882
- children = build_children_map(procs)
1883
- descendants = find_descendants(children, prte.pid)
1884
- program_proc = select_program(procs, descendants)
1885
- selector = parse_python_selector(program_proc.args if program_proc else "")
1886
- return State(prte_pid=prte.pid, rankfile=rankfile, ranks=ranks, selector=selector)
2057
+ prte_error = None
2058
+ prte = None
2059
+ try:
2060
+ prte = find_prterun(procs, args.prterun_pid)
2061
+ except SystemExit as exc:
2062
+ prte_error = str(exc)
2063
+
2064
+ if prte is not None:
2065
+ rankfile = args.rankfile or find_rankfile_path(prte.args)
2066
+ if not rankfile:
2067
+ raise SystemExit("rankfile not found in prterun/mpirun args")
2068
+ ranks = parse_rankfile(rankfile)
2069
+ children = build_children_map(procs)
2070
+ descendants = find_descendants(children, prte.pid)
2071
+ program_proc = select_program(procs, descendants)
2072
+ selector = parse_python_selector(program_proc.args if program_proc else "")
2073
+ return State(
2074
+ launcher="prte",
2075
+ prte_pid=prte.pid,
2076
+ slurm_job_id=None,
2077
+ rankfile=rankfile,
2078
+ ranks=ranks,
2079
+ selector=selector,
2080
+ )
2081
+
2082
+ slurm_job_id = resolve_slurm_job_id(args)
2083
+ if slurm_job_id:
2084
+ ranks = slurm_job_to_ranks(slurm_job_id)
2085
+ selector = ProgramSelector(module=None, script=None, display="")
2086
+ return State(
2087
+ launcher="slurm",
2088
+ prte_pid=0,
2089
+ slurm_job_id=slurm_job_id,
2090
+ rankfile=f"slurm:{slurm_job_id}",
2091
+ ranks=ranks,
2092
+ selector=selector,
2093
+ )
2094
+
2095
+ hint = describe_slurm_jobs()
2096
+ if hint:
2097
+ hint = "\n" + hint
2098
+ raise SystemExit(
2099
+ prte_error
2100
+ or f"no prterun/mpirun process found and no slurm job detected (try --slurm-job){hint}"
2101
+ )
1887
2102
 
1888
2103
 
1889
2104
  def collect_rank_pids(state: State) -> Tuple[Dict[int, RankProcess], List[str]]:
@@ -1894,10 +2109,10 @@ def collect_rank_pids(state: State) -> Tuple[Dict[int, RankProcess], List[str]]:
1894
2109
 
1895
2110
  for host in hosts:
1896
2111
  if is_local_host(host):
1897
- entries = find_rank_pids_local(state.selector)
2112
+ entries = find_rank_pids_local(state.selector, state.slurm_job_id)
1898
2113
  host_error = None
1899
2114
  else:
1900
- entries, host_error = find_rank_pids_remote(host, state.selector)
2115
+ entries, host_error = find_rank_pids_remote(host, state.selector, state.slurm_job_id)
1901
2116
  if host_error:
1902
2117
  errors.append(host_error)
1903
2118
  for rank, pid, cmd, rss_kb, python_exe, env_subset in entries:
@@ -1960,6 +2175,7 @@ def parse_live_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
1960
2175
  parser = argparse.ArgumentParser(description="Show MPI Python stacks across hosts.")
1961
2176
  parser.add_argument("--rankfile", help="Override rankfile path")
1962
2177
  parser.add_argument("--prterun-pid", type=int, help="PID of prterun/mpirun")
2178
+ parser.add_argument("--slurm-job", help="Slurm job ID to inspect")
1963
2179
  parser.add_argument("--refresh", type=int, default=10, help="Refresh interval (seconds)")
1964
2180
  parser.add_argument(
1965
2181
  "--pythonpath",
@@ -2000,6 +2216,7 @@ def parse_record_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespac
2000
2216
  parser = argparse.ArgumentParser(description="Record an mpiptop session.")
2001
2217
  parser.add_argument("--rankfile", help="Override rankfile path")
2002
2218
  parser.add_argument("--prterun-pid", type=int, help="PID of prterun/mpirun")
2219
+ parser.add_argument("--slurm-job", help="Slurm job ID to inspect")
2003
2220
  parser.add_argument("--refresh", type=int, default=10, help="Refresh interval (seconds)")
2004
2221
  parser.add_argument(
2005
2222
  "--pythonpath",
@@ -2186,9 +2403,14 @@ def run_record_batch(args: argparse.Namespace) -> int:
2186
2403
  try:
2187
2404
  while True:
2188
2405
  loop_start = time.time()
2189
- if not is_pid_alive(state.prte_pid):
2190
- stop_reason = "prterun-exited"
2191
- break
2406
+ if state.launcher == "prte":
2407
+ if not is_pid_alive(state.prte_pid):
2408
+ stop_reason = "prterun-exited"
2409
+ break
2410
+ else:
2411
+ if not is_slurm_job_alive(state.slurm_job_id):
2412
+ stop_reason = "slurm-job-exited"
2413
+ break
2192
2414
  rank_to_proc, _pid_errors = collect_rank_pids(state)
2193
2415
  snapshots, _stack_errors = collect_stacks(
2194
2416
  state, rank_to_proc, pythonpath, False, install_attempted
@@ -2249,7 +2471,9 @@ def run_review(args: argparse.Namespace) -> int:
2249
2471
  display=selector_payload.get("display", ""),
2250
2472
  )
2251
2473
  state = State(
2474
+ launcher=str(metadata.get("launcher", "prte")),
2252
2475
  prte_pid=int(metadata.get("prte_pid", 0) or 0),
2476
+ slurm_job_id=metadata.get("slurm_job_id"),
2253
2477
  rankfile=str(metadata.get("rankfile", "")),
2254
2478
  ranks=ranks,
2255
2479
  selector=selector,
@@ -1,7 +0,0 @@
1
- mpiptop.py,sha256=D4h-jOyhYU4M0FzQ6JMX5gDBE8UVd4detm5JDFdTm4c,86492
2
- mpiptop-0.2.0.dist-info/licenses/LICENSE,sha256=ChKmQ8qCXxdXRR_HIJECjIA5NLWlUTEJWh7Xkhm2wAA,1069
3
- mpiptop-0.2.0.dist-info/METADATA,sha256=3vT5lrkqfuh6O2DL6xCm412w7opwQAAzJFdD6IGHs7g,1910
4
- mpiptop-0.2.0.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
5
- mpiptop-0.2.0.dist-info/entry_points.txt,sha256=RsGsr8GBLfUNpb432YWS5gz4MWfWdK9xJRr1SmdnLo8,41
6
- mpiptop-0.2.0.dist-info/top_level.txt,sha256=c2Vdu6tTg0DEPUWD8Odyods7fXsPWMQ2kSvjdKiTClc,8
7
- mpiptop-0.2.0.dist-info/RECORD,,