alloc 0.0.7.tar.gz → 0.0.9.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {alloc-0.0.7 → alloc-0.0.9}/PKG-INFO +2 -2
  2. {alloc-0.0.7 → alloc-0.0.9}/README.md +1 -1
  3. {alloc-0.0.7 → alloc-0.0.9}/pyproject.toml +1 -1
  4. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/__init__.py +1 -1
  5. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/cli.py +44 -7
  6. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/extractor_runner.py +29 -2
  7. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/model_extractor.py +14 -5
  8. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/probe.py +1 -1
  9. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/upload.py +1 -0
  10. {alloc-0.0.7 → alloc-0.0.9}/src/alloc.egg-info/PKG-INFO +2 -2
  11. {alloc-0.0.7 → alloc-0.0.9}/src/alloc.egg-info/SOURCES.txt +1 -0
  12. {alloc-0.0.7 → alloc-0.0.9}/tests/test_auth.py +29 -0
  13. {alloc-0.0.7 → alloc-0.0.9}/tests/test_cli.py +1 -1
  14. alloc-0.0.9/tests/test_topology_strategy.py +93 -0
  15. {alloc-0.0.7 → alloc-0.0.9}/setup.cfg +0 -0
  16. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/artifact_loader.py +0 -0
  17. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/artifact_writer.py +0 -0
  18. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/browser_auth.py +0 -0
  19. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/callbacks.py +0 -0
  20. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/catalog/__init__.py +0 -0
  21. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/catalog/default_rate_card.json +0 -0
  22. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/catalog/gpus.v1.json +0 -0
  23. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/code_analyzer.py +0 -0
  24. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/config.py +0 -0
  25. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/context.py +0 -0
  26. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/diagnosis_display.py +0 -0
  27. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/diagnosis_engine.py +0 -0
  28. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/diagnosis_rules.py +0 -0
  29. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/display.py +0 -0
  30. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/ghost.py +0 -0
  31. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/model_registry.py +0 -0
  32. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/stability.py +0 -0
  33. {alloc-0.0.7 → alloc-0.0.9}/src/alloc/yaml_config.py +0 -0
  34. {alloc-0.0.7 → alloc-0.0.9}/src/alloc.egg-info/dependency_links.txt +0 -0
  35. {alloc-0.0.7 → alloc-0.0.9}/src/alloc.egg-info/entry_points.txt +0 -0
  36. {alloc-0.0.7 → alloc-0.0.9}/src/alloc.egg-info/requires.txt +0 -0
  37. {alloc-0.0.7 → alloc-0.0.9}/src/alloc.egg-info/top_level.txt +0 -0
  38. {alloc-0.0.7 → alloc-0.0.9}/tests/test_artifact.py +0 -0
  39. {alloc-0.0.7 → alloc-0.0.9}/tests/test_artifact_loader.py +0 -0
  40. {alloc-0.0.7 → alloc-0.0.9}/tests/test_callbacks.py +0 -0
  41. {alloc-0.0.7 → alloc-0.0.9}/tests/test_catalog.py +0 -0
  42. {alloc-0.0.7 → alloc-0.0.9}/tests/test_code_analyzer.py +0 -0
  43. {alloc-0.0.7 → alloc-0.0.9}/tests/test_context.py +0 -0
  44. {alloc-0.0.7 → alloc-0.0.9}/tests/test_diagnose_cli.py +0 -0
  45. {alloc-0.0.7 → alloc-0.0.9}/tests/test_diagnosis_engine.py +0 -0
  46. {alloc-0.0.7 → alloc-0.0.9}/tests/test_diagnosis_rules.py +0 -0
  47. {alloc-0.0.7 → alloc-0.0.9}/tests/test_extractor_activation.py +0 -0
  48. {alloc-0.0.7 → alloc-0.0.9}/tests/test_ghost.py +0 -0
  49. {alloc-0.0.7 → alloc-0.0.9}/tests/test_ghost_degradation.py +0 -0
  50. {alloc-0.0.7 → alloc-0.0.9}/tests/test_init_from_org.py +0 -0
  51. {alloc-0.0.7 → alloc-0.0.9}/tests/test_interconnect.py +0 -0
  52. {alloc-0.0.7 → alloc-0.0.9}/tests/test_model_extractor.py +0 -0
  53. {alloc-0.0.7 → alloc-0.0.9}/tests/test_probe_hw.py +0 -0
  54. {alloc-0.0.7 → alloc-0.0.9}/tests/test_probe_multi.py +0 -0
  55. {alloc-0.0.7 → alloc-0.0.9}/tests/test_scan_auth.py +0 -0
  56. {alloc-0.0.7 → alloc-0.0.9}/tests/test_stability.py +0 -0
  57. {alloc-0.0.7 → alloc-0.0.9}/tests/test_upload.py +0 -0
  58. {alloc-0.0.7 → alloc-0.0.9}/tests/test_verdict.py +0 -0
  59. {alloc-0.0.7 → alloc-0.0.9}/tests/test_yaml_config.py +0 -0
--- alloc-0.0.7/PKG-INFO
+++ alloc-0.0.9/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alloc
-Version: 0.0.7
+Version: 0.0.9
 Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
 Author-email: Alloc Labs <hello@alloclabs.com>
 License-Expression: Apache-2.0
@@ -40,7 +40,7 @@ alloc run python train.py
 ```
 
 ```
-alloc v0.0.2 — Calibrate
+alloc v0.0.8 — Calibrate
 
 Run Summary
 Peak VRAM 31.2 GB / 40.0 GB (A100)
--- alloc-0.0.7/README.md
+++ alloc-0.0.9/README.md
@@ -12,7 +12,7 @@ alloc run python train.py
 ```
 
 ```
-alloc v0.0.2 — Calibrate
+alloc v0.0.8 — Calibrate
 
 Run Summary
 Peak VRAM 31.2 GB / 40.0 GB (A100)
--- alloc-0.0.7/pyproject.toml
+++ alloc-0.0.9/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "alloc"
-version = "0.0.7"
+version = "0.0.9"
 description = "Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints."
 readme = "README.md"
 license = "Apache-2.0"
--- alloc-0.0.7/src/alloc/__init__.py
+++ alloc-0.0.9/src/alloc/__init__.py
@@ -9,7 +9,7 @@ _warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda"
 _warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
 del _warnings
 
-__version__ = "0.0.7"
+__version__ = "0.0.9"
 
 from alloc.ghost import ghost, GhostReport
 from alloc.callbacks import AllocCallback as HuggingFaceCallback
--- alloc-0.0.7/src/alloc/cli.py
+++ alloc-0.0.9/src/alloc/cli.py
@@ -455,7 +455,9 @@ def run(
         "tp_degree": topology.get("tp_degree"),
         "pp_degree": topology.get("pp_degree"),
         "dp_degree": topology.get("dp_degree"),
+        "strategy": topology.get("strategy"),
         "interconnect_type": topology.get("interconnect_type"),
+        "process_map": result.process_map,
         "objective": objective,
         "max_budget_hourly": max_budget_hourly,
         "command": " ".join(command),
@@ -2368,7 +2370,7 @@ def whoami(
 
     out = {
         "api_url": api_url,
-        "logged_in": bool(token),
+        "logged_in": False,
         "token_source": token_source if token else None,
     }
 
@@ -2398,20 +2400,33 @@ def whoami(
             profile = _get("/profile")
             fleet = _get("/gpu-fleet")
         else:
-            if json_output:
+            # whoami is a status command — report structured result, exit 0
+            if e.response.status_code == 401:
+                out["token_status"] = "expired"
+            else:
+                out["token_status"] = "error"
             out["error"] = f"API error {e.response.status_code}"
+            if json_output:
                 _print_json(out)
             else:
-                console.print(f"[red]API error {e.response.status_code}[/red]")
+                if e.response.status_code == 401:
+                    console.print("[yellow]Token expired.[/yellow]")
+                else:
+                    console.print(f"[red]API error {e.response.status_code}[/red]")
                 console.print("[dim]Run: alloc login[/dim]")
-            raise typer.Exit(1)
+            return
     except httpx.ConnectError:
+        out["token_status"] = "unreachable"
+        out["error"] = f"Cannot connect to {api_url}"
        if json_output:
-            out["error"] = f"Cannot connect to {api_url}"
            _print_json(out)
        else:
            console.print(f"[red]Cannot connect to {api_url}[/red]")
-        raise typer.Exit(1)
+        return
+
+    # API validated the token — now we know login is real
+    out["logged_in"] = True
+    out["token_status"] = "valid"
 
    gpus = fleet.get("gpus") or []
    fleet_count = len([g for g in gpus if g.get("fleet_status") == "in_fleet"])
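
Note: the net effect of the two whoami hunks is that `whoami` becomes a pure status command. It always exits 0 and reports token state in a structured `token_status` field instead of raising `typer.Exit(1)`. A sketch of the three JSON shapes implied by the diff (field names come from the hunks above; the concrete values are invented):

# Illustrative only: shapes follow the hunks above, values are invented.
expired = {"api_url": "https://api.example.com", "logged_in": False,
           "token_source": "env", "token_status": "expired",
           "error": "API error 401"}
unreachable = {"api_url": "https://api.example.com", "logged_in": False,
               "token_source": "env", "token_status": "unreachable",
               "error": "Cannot connect to https://api.example.com"}
valid = {"api_url": "https://api.example.com", "logged_in": True,
         "token_source": "env", "token_status": "valid"}  # set only after the API accepts the token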
@@ -3087,7 +3102,7 @@ def status(
 
     out = {
         "version": __version__,
-        "logged_in": bool(token),
+        "has_token": bool(token),
         "api_url": api_url,
         "artifact": None,
         "dashboard_url": None,
@@ -3548,6 +3563,27 @@ def _infer_parallel_topology_from_env(*, num_gpus_detected: int, config_intercon
     if interconnect not in ("pcie", "nvlink", "nvlink_switch", "nvlink_p2p", "infiniband", "unknown"):
         interconnect = "unknown"
 
+    # Infer strategy from degrees — only when evidence exists
+    strategy = None
+    has_tp = tp is not None and tp > 1
+    has_pp = pp is not None and pp > 1
+    if has_tp and has_pp:
+        strategy = "tp+pp+dp"
+    elif has_tp:
+        strategy = "tp+dp" if (dp is not None and dp > 1) else "tp"
+    elif has_pp:
+        strategy = "pp+dp" if (dp is not None and dp > 1) else "pp"
+    elif dp is not None and dp > 1:
+        strategy = "ddp"
+    elif strategy is None and num_gpus_detected > 1 and not has_tp and not has_pp:
+        # Multiple GPUs detected via NVML with no TP/PP env vars →
+        # DDP is PyTorch's default and the only realistic inference.
+        # This is NOT the old `or "ddp"` — it only fires when probe
+        # actually observed multiple GPU processes.
+        strategy = "ddp"
+        if dp is None:
+            dp = num_gpus_detected
+
     return {
         "num_nodes": nnodes or 1,
         "gpus_per_node": gpn,
@@ -3555,6 +3591,7 @@ def _infer_parallel_topology_from_env(*, num_gpus_detected: int, config_intercon
         "pp_degree": pp,
         "dp_degree": dp,
         "interconnect_type": interconnect,
+        "strategy": strategy,
     }

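The inference above only fires on positive evidence (a degree greater than 1, or multiple GPUs actually observed). For readers who want the truth table in one place, a distilled standalone sketch of the decision logic; this is not the alloc API, just the same branching in isolation:

# Standalone sketch of the strategy decision table above (not the alloc API).
# Degrees of None or 1 mean "no evidence of that parallelism dimension".
def infer_strategy(tp=None, pp=None, dp=None, num_gpus=1):
    has_tp = tp is not None and tp > 1
    has_pp = pp is not None and pp > 1
    has_dp = dp is not None and dp > 1
    if has_tp and has_pp:
        return "tp+pp+dp"
    if has_tp:
        return "tp+dp" if has_dp else "tp"
    if has_pp:
        return "pp+dp" if has_dp else "pp"
    if has_dp or num_gpus > 1:  # multi-GPU with no TP/PP evidence: DDP
        return "ddp"
    return None

assert infer_strategy(tp=2, dp=4) == "tp+dp"
assert infer_strategy(num_gpus=4) == "ddp"  # bare multi-GPU probe
assert infer_strategy(num_gpus=1) is None   # single GPU stays None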
--- alloc-0.0.7/src/alloc/extractor_runner.py
+++ alloc-0.0.9/src/alloc/extractor_runner.py
@@ -206,7 +206,11 @@ def main():
     except SystemExit:
         pass  # catch real SystemExit too
     except Exception as e:
-        result = {"status": "error", "error": str(e)[:200]}
+        error_msg = str(e)[:500]
+        _dist_keywords = ("init_process_group", "nccl", "gloo", "distributed",
+                          "master_addr", "master_port", "rendezvouserror")
+        status = "error_distributed" if any(kw in error_msg.lower() for kw in _dist_keywords) else "error"
+        result = {"status": status, "error": error_msg}
         with open(sidecar_path, "w") as f:
             json.dump(result, f)
         return
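
The keyword tuple is lowercased up front, so classification is a single case-insensitive scan over the truncated message. A quick check with an invented NCCL failure string:

# Sketch: the error string is invented; the keywords mirror the hunk above.
error_msg = "ProcessGroupNCCL: failed to init_process_group on rank 0"[:500]
_dist_keywords = ("init_process_group", "nccl", "gloo", "distributed",
                  "master_addr", "master_port", "rendezvouserror")
status = "error_distributed" if any(kw in error_msg.lower() for kw in _dist_keywords) else "error"
assert status == "error_distributed"  # "nccl" and "init_process_group" both match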
@@ -277,7 +281,30 @@ def main():
             "activation_method": activation_result.get("activation_method"),
         }
     else:
-        result = {"status": "no_model"}
+        # No model found — check if this is a distributed training script
+        # that hides the model inside __main__ guard or main()
+        _is_dist = False
+        try:
+            import torch.distributed as _dist_mod
+            if _dist_mod.is_initialized():
+                _is_dist = True
+        except Exception:
+            pass
+        if not _is_dist:
+            # Check if module imported distributed primitives
+            for attr_name in dir(module):
+                try:
+                    obj = getattr(module, attr_name)
+                    mod_name = getattr(obj, "__module__", "") or ""
+                    if "torch.distributed" in mod_name or "torch.nn.parallel" in mod_name:
+                        _is_dist = True
+                        break
+                except Exception:
+                    continue
+        if _is_dist:
+            result = {"status": "error_distributed", "error": "no model found — script uses distributed training"}
+        else:
+            result = {"status": "no_model"}
 
     with open(sidecar_path, "w") as f:
         json.dump(result, f)
--- alloc-0.0.7/src/alloc/model_extractor.py
+++ alloc-0.0.9/src/alloc/model_extractor.py
@@ -108,6 +108,10 @@ def _extract_via_subprocess(
     env.setdefault("WORLD_SIZE", "1")
     env.setdefault("MASTER_ADDR", "127.0.0.1")
     env.setdefault("MASTER_PORT", "29500")
+    # Suppress pynvml/torch.cuda deprecation warnings in subprocess
+    existing = env.get("PYTHONWARNINGS", "")
+    filters = "ignore::FutureWarning,ignore::DeprecationWarning"
+    env["PYTHONWARNINGS"] = f"{existing},{filters}" if existing else filters
 
     subprocess.run(
         [sys.executable, "-m", "alloc.extractor_runner", sidecar_path, script_abs],
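
The composition appends rather than overwrites, so any filters the caller already exported survive. A minimal sketch of the resulting value, assuming a pre-existing filter (the "error::UserWarning" value is invented):

# Sketch: pre-existing value is invented; the append logic matches the hunk.
existing = "error::UserWarning"
filters = "ignore::FutureWarning,ignore::DeprecationWarning"
combined = f"{existing},{filters}" if existing else filters
assert combined == "error::UserWarning,ignore::FutureWarning,ignore::DeprecationWarning"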
@@ -137,11 +141,16 @@
     )
 
     # Structured degradation for distributed scripts
-    if data.get("status") == "error":
-        error_msg = data.get("error", "")
-        _dist_keywords = ("init_process_group", "NCCL", "gloo", "distributed",
-                          "MASTER_ADDR", "MASTER_PORT", "RendezvousError")
-        if any(kw.lower() in error_msg.lower() for kw in _dist_keywords):
+    status = data.get("status", "")
+    if status in ("error", "error_distributed"):
+        is_distributed = status == "error_distributed"
+        if not is_distributed:
+            # Fallback keyword match for older sidecar format
+            error_msg = data.get("error", "")
+            _dist_keywords = ("init_process_group", "NCCL", "gloo", "distributed",
+                              "MASTER_ADDR", "MASTER_PORT", "RendezvousError")
+            is_distributed = any(kw.lower() in error_msg.lower() for kw in _dist_keywords)
+        if is_distributed:
         return ModelInfo(
             param_count=0,
             dtype="float16",
--- alloc-0.0.7/src/alloc/probe.py
+++ alloc-0.0.9/src/alloc/probe.py
@@ -363,7 +363,7 @@ def probe_command(
     """
     pynvml = _try_import_pynvml()
 
-    # Launch the subprocess
+    # Launch the user's training subprocess — do NOT modify env (their warnings matter)
     try:
         proc = subprocess.Popen(
             command,
--- alloc-0.0.7/src/alloc/upload.py
+++ alloc-0.0.9/src/alloc/upload.py
@@ -123,6 +123,7 @@ def upload_artifact(artifact_path: str, api_url: str, token: str) -> dict:
         "dataloader_wait_pct": probe.get("dataloader_wait_pct"),
         "comm_overhead_pct": probe.get("comm_overhead_pct"),
         "per_rank_peak_vram_mb": probe.get("per_rank_peak_vram_mb"),
+        "process_map": probe.get("process_map"),
         # Architecture fields: probe (callbacks) takes priority over ghost defaults
         "batch_size": probe.get("batch_size") or (ghost.get("batch_size") if ghost else None),
         "seq_length": ghost.get("seq_length") if ghost else None,
--- alloc-0.0.7/src/alloc.egg-info/PKG-INFO
+++ alloc-0.0.9/src/alloc.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alloc
-Version: 0.0.7
+Version: 0.0.9
 Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
 Author-email: Alloc Labs <hello@alloclabs.com>
 License-Expression: Apache-2.0
@@ -40,7 +40,7 @@ alloc run python train.py
 ```
 
 ```
-alloc v0.0.2 — Calibrate
+alloc v0.0.8 — Calibrate
 
 Run Summary
 Peak VRAM 31.2 GB / 40.0 GB (A100)
--- alloc-0.0.7/src/alloc.egg-info/SOURCES.txt
+++ alloc-0.0.9/src/alloc.egg-info/SOURCES.txt
@@ -51,6 +51,7 @@ tests/test_probe_hw.py
 tests/test_probe_multi.py
 tests/test_scan_auth.py
 tests/test_stability.py
+tests/test_topology_strategy.py
 tests/test_upload.py
 tests/test_verdict.py
 tests/test_yaml_config.py
--- alloc-0.0.7/tests/test_auth.py
+++ alloc-0.0.9/tests/test_auth.py
@@ -68,6 +68,34 @@ def test_whoami_not_logged_in_json(tmp_path: Path):
     assert data["api_url"] == "https://api.example.com"
 
 
+def test_whoami_stale_token_json(tmp_path: Path):
+    """Stale token should exit 0 with token_status: expired."""
+    mock_resp = MagicMock()
+    mock_resp.status_code = 401
+    mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError(
+        "Unauthorized", request=MagicMock(), response=mock_resp,
+    )
+    mock_client = MagicMock()
+    mock_client.__enter__.return_value = mock_client
+    mock_client.__exit__.return_value = False
+    mock_client.get.return_value = mock_resp
+
+    env = {
+        "HOME": str(tmp_path),
+        "ALLOC_API_URL": "https://api.example.com",
+        "ALLOC_TOKEN": "stale-token",
+    }
+
+    with patch("httpx.Client", return_value=mock_client), \
+         patch("alloc.cli.try_refresh_access_token", return_value=None):
+        result = runner.invoke(app, ["whoami", "--json"], env=env)
+
+    assert result.exit_code == 0
+    data = json.loads(result.output)
+    assert data["logged_in"] is False
+    assert data["token_status"] == "expired"
+
+
 def test_whoami_logged_in_json(tmp_path: Path):
     profile_resp = MagicMock()
     profile_resp.raise_for_status.return_value = None
@@ -110,6 +138,7 @@ def test_whoami_logged_in_json(tmp_path: Path):
     assert result.exit_code == 0
     data = json.loads(result.output)
     assert data["logged_in"] is True
+    assert data["token_status"] == "valid"
     assert data["token_source"] == "env"
     assert data["email"] == "user@example.com"
     assert data["fleet_count"] == 1
--- alloc-0.0.7/tests/test_cli.py
+++ alloc-0.0.9/tests/test_cli.py
@@ -239,7 +239,7 @@ def test_status_json_no_artifact(tmp_path, monkeypatch):
     assert result.exit_code == 0
     data = json.loads(result.output.strip())
     assert data["artifact"] is None
-    assert data["logged_in"] is False
+    assert data["has_token"] is False
     assert "version" in data
 
 
--- /dev/null
+++ alloc-0.0.9/tests/test_topology_strategy.py
@@ -0,0 +1,93 @@
+"""Tests for strategy inference from topology degrees (P0-B)."""
+
+from __future__ import annotations
+
+import os
+from unittest.mock import patch
+
+from alloc.cli import _infer_parallel_topology_from_env
+
+
+class TestStrategyInference:
+    """Strategy should be inferred from TP/PP/DP degrees when present."""
+
+    def _topo(self, env=None, num_gpus=4):
+        env = env or {}
+        with patch.dict(os.environ, env, clear=False):
+            return _infer_parallel_topology_from_env(
+                num_gpus_detected=num_gpus,
+            )
+
+    def test_no_degrees_multi_gpu_infers_ddp(self):
+        """When no degree env vars but multiple GPUs detected, infer DDP."""
+        result = self._topo({}, num_gpus=4)
+        assert result["strategy"] == "ddp"
+        assert result["dp_degree"] == 4
+
+    def test_single_gpu_no_degrees_strategy_none(self):
+        """Single GPU with no degrees → strategy stays None."""
+        result = self._topo({}, num_gpus=1)
+        assert result["strategy"] is None
+
+    def test_dp_only_is_ddp(self):
+        """WORLD_SIZE=4 with no TP/PP → dp inferred → strategy=ddp."""
+        result = self._topo({"WORLD_SIZE": "4"})
+        assert result["strategy"] == "ddp"
+        assert result["dp_degree"] == 4
+
+    def test_tp_only(self):
+        """TP_SIZE=4 alone → strategy=tp."""
+        result = self._topo({"TP_SIZE": "4"})
+        assert result["strategy"] == "tp"
+
+    def test_pp_only(self):
+        """PP_SIZE=4 alone → strategy=pp."""
+        result = self._topo({"PP_SIZE": "4"})
+        assert result["strategy"] == "pp"
+
+    def test_tp_dp(self):
+        """TP_SIZE=2 with DP_SIZE=2 → strategy=tp+dp."""
+        result = self._topo({"TP_SIZE": "2", "DP_SIZE": "2"})
+        assert result["strategy"] == "tp+dp"
+
+    def test_pp_dp(self):
+        """PP_SIZE=2 with DP_SIZE=2 → strategy=pp+dp."""
+        result = self._topo({"PP_SIZE": "2", "DP_SIZE": "2"})
+        assert result["strategy"] == "pp+dp"
+
+    def test_tp_pp_dp(self):
+        """All three degrees → strategy=tp+pp+dp."""
+        result = self._topo({"TP_SIZE": "2", "PP_SIZE": "2", "DP_SIZE": "2"})
+        assert result["strategy"] == "tp+pp+dp"
+
+    def test_tp_pp_no_dp(self):
+        """TP+PP without explicit DP → strategy=tp+pp+dp."""
+        result = self._topo({"TP_SIZE": "2", "PP_SIZE": "2"})
+        assert result["strategy"] == "tp+pp+dp"
+
+    def test_tp_size_1_not_counted(self):
+        """TP_SIZE=1 should not count as tensor parallelism."""
+        result = self._topo({"TP_SIZE": "1", "DP_SIZE": "4"})
+        assert result["strategy"] == "ddp"
+
+    def test_pp_size_1_not_counted(self):
+        """PP_SIZE=1 should not count as pipeline parallelism."""
+        result = self._topo({"PP_SIZE": "1", "DP_SIZE": "4"})
+        assert result["strategy"] == "ddp"
+
+    def test_dp_inferred_from_world_size(self):
+        """DP inferred from WORLD_SIZE / (TP * PP) → strategy includes dp."""
+        result = self._topo({"WORLD_SIZE": "8", "TP_SIZE": "2"})
+        assert result["dp_degree"] == 4
+        assert result["strategy"] == "tp+dp"
+
+
+class TestProcessMapInProbeDictAssembly:
+    """process_map should reach probe_dict from ProbeResult."""
+
+    def test_process_map_present_in_topology_return(self):
+        """Topology dict now includes strategy field."""
+        with patch.dict(os.environ, {"WORLD_SIZE": "4"}, clear=False):
+            topo = _infer_parallel_topology_from_env(num_gpus_detected=4)
+            assert "strategy" in topo
+            assert topo["strategy"] == "ddp"