alloc 0.0.8.tar.gz → 0.0.10.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {alloc-0.0.8 → alloc-0.0.10}/PKG-INFO +2 -2
  2. {alloc-0.0.8 → alloc-0.0.10}/README.md +1 -1
  3. {alloc-0.0.8 → alloc-0.0.10}/pyproject.toml +1 -1
  4. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/__init__.py +1 -1
  5. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/browser_auth.py +3 -2
  6. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/callbacks.py +4 -1
  7. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/cli.py +23 -6
  8. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/diagnosis_engine.py +1 -1
  9. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/extractor_runner.py +24 -1
  10. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/probe.py +13 -2
  11. {alloc-0.0.8 → alloc-0.0.10}/src/alloc.egg-info/PKG-INFO +2 -2
  12. {alloc-0.0.8 → alloc-0.0.10}/tests/test_auth.py +29 -0
  13. {alloc-0.0.8 → alloc-0.0.10}/tests/test_callbacks.py +68 -0
  14. {alloc-0.0.8 → alloc-0.0.10}/tests/test_diagnosis_engine.py +12 -0
  15. {alloc-0.0.8 → alloc-0.0.10}/tests/test_probe_multi.py +55 -0
  16. {alloc-0.0.8 → alloc-0.0.10}/tests/test_topology_strategy.py +9 -3
  17. {alloc-0.0.8 → alloc-0.0.10}/setup.cfg +0 -0
  18. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/artifact_loader.py +0 -0
  19. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/artifact_writer.py +0 -0
  20. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/catalog/__init__.py +0 -0
  21. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/catalog/default_rate_card.json +0 -0
  22. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/catalog/gpus.v1.json +0 -0
  23. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/code_analyzer.py +0 -0
  24. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/config.py +0 -0
  25. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/context.py +0 -0
  26. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/diagnosis_display.py +0 -0
  27. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/diagnosis_rules.py +0 -0
  28. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/display.py +0 -0
  29. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/ghost.py +0 -0
  30. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/model_extractor.py +0 -0
  31. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/model_registry.py +0 -0
  32. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/stability.py +0 -0
  33. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/upload.py +0 -0
  34. {alloc-0.0.8 → alloc-0.0.10}/src/alloc/yaml_config.py +0 -0
  35. {alloc-0.0.8 → alloc-0.0.10}/src/alloc.egg-info/SOURCES.txt +0 -0
  36. {alloc-0.0.8 → alloc-0.0.10}/src/alloc.egg-info/dependency_links.txt +0 -0
  37. {alloc-0.0.8 → alloc-0.0.10}/src/alloc.egg-info/entry_points.txt +0 -0
  38. {alloc-0.0.8 → alloc-0.0.10}/src/alloc.egg-info/requires.txt +0 -0
  39. {alloc-0.0.8 → alloc-0.0.10}/src/alloc.egg-info/top_level.txt +0 -0
  40. {alloc-0.0.8 → alloc-0.0.10}/tests/test_artifact.py +0 -0
  41. {alloc-0.0.8 → alloc-0.0.10}/tests/test_artifact_loader.py +0 -0
  42. {alloc-0.0.8 → alloc-0.0.10}/tests/test_catalog.py +0 -0
  43. {alloc-0.0.8 → alloc-0.0.10}/tests/test_cli.py +0 -0
  44. {alloc-0.0.8 → alloc-0.0.10}/tests/test_code_analyzer.py +0 -0
  45. {alloc-0.0.8 → alloc-0.0.10}/tests/test_context.py +0 -0
  46. {alloc-0.0.8 → alloc-0.0.10}/tests/test_diagnose_cli.py +0 -0
  47. {alloc-0.0.8 → alloc-0.0.10}/tests/test_diagnosis_rules.py +0 -0
  48. {alloc-0.0.8 → alloc-0.0.10}/tests/test_extractor_activation.py +0 -0
  49. {alloc-0.0.8 → alloc-0.0.10}/tests/test_ghost.py +0 -0
  50. {alloc-0.0.8 → alloc-0.0.10}/tests/test_ghost_degradation.py +0 -0
  51. {alloc-0.0.8 → alloc-0.0.10}/tests/test_init_from_org.py +0 -0
  52. {alloc-0.0.8 → alloc-0.0.10}/tests/test_interconnect.py +0 -0
  53. {alloc-0.0.8 → alloc-0.0.10}/tests/test_model_extractor.py +0 -0
  54. {alloc-0.0.8 → alloc-0.0.10}/tests/test_probe_hw.py +0 -0
  55. {alloc-0.0.8 → alloc-0.0.10}/tests/test_scan_auth.py +0 -0
  56. {alloc-0.0.8 → alloc-0.0.10}/tests/test_stability.py +0 -0
  57. {alloc-0.0.8 → alloc-0.0.10}/tests/test_upload.py +0 -0
  58. {alloc-0.0.8 → alloc-0.0.10}/tests/test_verdict.py +0 -0
  59. {alloc-0.0.8 → alloc-0.0.10}/tests/test_yaml_config.py +0 -0

{alloc-0.0.8 → alloc-0.0.10}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alloc
-Version: 0.0.8
+Version: 0.0.10
 Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
 Author-email: Alloc Labs <hello@alloclabs.com>
 License-Expression: Apache-2.0

@@ -40,7 +40,7 @@ alloc run python train.py
 ```
 
 ```
-alloc v0.0.2 — Calibrate
+alloc v0.0.9 — Calibrate
 
 Run Summary
 Peak VRAM 31.2 GB / 40.0 GB (A100)

{alloc-0.0.8 → alloc-0.0.10}/README.md
@@ -12,7 +12,7 @@ alloc run python train.py
 ```
 
 ```
-alloc v0.0.2 — Calibrate
+alloc v0.0.9 — Calibrate
 
 Run Summary
 Peak VRAM 31.2 GB / 40.0 GB (A100)

{alloc-0.0.8 → alloc-0.0.10}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "alloc"
-version = "0.0.8"
+version = "0.0.10"
 description = "Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints."
 readme = "README.md"
 license = "Apache-2.0"

{alloc-0.0.8 → alloc-0.0.10}/src/alloc/__init__.py
@@ -9,7 +9,7 @@ _warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda"
 _warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
 del _warnings
 
-__version__ = "0.0.8"
+__version__ = "0.0.10"
 
 from alloc.ghost import ghost, GhostReport
 from alloc.callbacks import AllocCallback as HuggingFaceCallback

{alloc-0.0.8 → alloc-0.0.10}/src/alloc/browser_auth.py
@@ -121,8 +121,9 @@ def browser_login(
     })
     authorize_url = f"{supabase_url}/auth/v1/authorize?{authorize_params}"
 
-    # Bind to 0.0.0.0 so both localhost and 127.0.0.1 reach the server.
-    server = HTTPServer(("0.0.0.0", port), _CallbackHandler)
+    # Bind to 127.0.0.1 only: the auth callback server should never be
+    # reachable from the network.
+    server = HTTPServer(("127.0.0.1", port), _CallbackHandler)
     server.auth_code = None  # type: ignore[attr-defined]
     server.auth_error = None  # type: ignore[attr-defined]
     server.timeout = 1  # poll interval for handle_request()
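
The change above is the standard loopback-redirect pattern for CLI OAuth flows. A minimal self-contained sketch of that pattern (the handler name and port below are illustrative, not the package's actual code):

from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import parse_qs, urlparse

class CallbackHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        # Capture the one-time auth code from the redirect query string.
        query = parse_qs(urlparse(self.path).query)
        self.server.auth_code = (query.get("code") or [None])[0]
        self.send_response(200)
        self.end_headers()
        self.wfile.write(b"Login complete. You may close this tab.")

# Loopback-only bind: nothing outside this machine can reach the port.
server = HTTPServer(("127.0.0.1", 8765), CallbackHandler)
server.auth_code = None
server.timeout = 1  # handle_request() returns after 1s with no traffic
while server.auth_code is None:
    server.handle_request()
server.server_close()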

{alloc-0.0.8 → alloc-0.0.10}/src/alloc/callbacks.py
@@ -501,7 +501,10 @@ class _NvmlMonitor:
 
                 self._hw_context["nvlink_active_links"] = active_links
             except Exception:
-                pass
+                # NVLink detection code failed after entering the try block.
+                # We know NVML is functional (handles exist), so fall back to
+                # generic "nvlink" rather than leaving interconnect_type unset.
+                self._hw_context["interconnect_type"] = "nvlink"
 
         self._thread = threading.Thread(target=self._sample_loop, daemon=True)
         self._thread.start()
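
For readers unfamiliar with the detection being guarded: NVML exposes per-link NVLink state, and a probe of roughly the shape below sits inside the try block. This is a hedged sketch of the general technique, not the package's actual detection code:

import pynvml

def detect_interconnect() -> str:
    """Best-effort interconnect detection; degrades to 'nvlink' on partial failure."""
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        active_links = 0
        for link in range(pynvml.NVML_NVLINK_MAX_LINKS):
            try:
                state = pynvml.nvmlDeviceGetNvLinkState(handle, link)
                if state == pynvml.NVML_FEATURE_ENABLED:
                    active_links += 1
            except pynvml.NVMLError:
                break  # this GPU has fewer links than the max
        return "nvlink" if active_links else "pcie"
    except Exception:
        # Same spirit as the diff: NVML itself works (init succeeded),
        # so report a generic answer instead of nothing.
        return "nvlink"
    finally:
        pynvml.nvmlShutdown()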

{alloc-0.0.8 → alloc-0.0.10}/src/alloc/cli.py
@@ -2400,23 +2400,33 @@ def whoami(
             profile = _get("/profile")
             fleet = _get("/gpu-fleet")
         else:
-            if json_output:
+            # whoami is a status command — report structured result, exit 0
+            if e.response.status_code == 401:
+                out["token_status"] = "expired"
+            else:
+                out["token_status"] = "error"
             out["error"] = f"API error {e.response.status_code}"
+            if json_output:
                 _print_json(out)
             else:
-                console.print(f"[red]API error {e.response.status_code}[/red]")
+                if e.response.status_code == 401:
+                    console.print("[yellow]Token expired.[/yellow]")
+                else:
+                    console.print(f"[red]API error {e.response.status_code}[/red]")
                 console.print("[dim]Run: alloc login[/dim]")
-            raise typer.Exit(1)
+            return
     except httpx.ConnectError:
+        out["token_status"] = "unreachable"
+        out["error"] = f"Cannot connect to {api_url}"
         if json_output:
-            out["error"] = f"Cannot connect to {api_url}"
            _print_json(out)
        else:
            console.print(f"[red]Cannot connect to {api_url}[/red]")
-        raise typer.Exit(1)
+        return
 
     # API validated the token — now we know login is real
     out["logged_in"] = True
+    out["token_status"] = "valid"
 
     gpus = fleet.get("gpus") or []
     fleet_count = len([g for g in gpus if g.get("fleet_status") == "in_fleet"])
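
The net effect: whoami no longer raises typer.Exit for degraded states, and token_status takes one of four values ("valid", "expired", "error", "unreachable"). A hypothetical downstream script can now branch on the JSON instead of the exit code:

import json
import subprocess

# Exit code is 0 even for stale tokens, so check=True is safe here.
proc = subprocess.run(
    ["alloc", "whoami", "--json"],
    capture_output=True, text=True, check=True,
)
status = json.loads(proc.stdout).get("token_status")
if status != "valid":
    print(f"token_status={status}; run `alloc login` first")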

@@ -3565,7 +3575,14 @@ def _infer_parallel_topology_from_env(*, num_gpus_detected: int, config_intercon
         strategy = "pp+dp" if (dp is not None and dp > 1) else "pp"
     elif dp is not None and dp > 1:
         strategy = "ddp"
-    # If none of the above matched, strategy stays None (unknown)
+    elif strategy is None and num_gpus_detected > 1 and not has_tp and not has_pp:
+        # Multiple GPUs detected via NVML with no TP/PP env vars →
+        # DDP is PyTorch's default and the only realistic inference.
+        # This is NOT the old `or "ddp"` — it only fires when probe
+        # actually observed multiple GPU processes.
+        strategy = "ddp"
+        if dp is None:
+            dp = num_gpus_detected
 
     return {
         "num_nodes": nnodes or 1,

{alloc-0.0.8 → alloc-0.0.10}/src/alloc/diagnosis_engine.py
@@ -403,7 +403,7 @@ def _estimate_model_params(model_name: str) -> Optional[float]:
         "whisper-large": 1.55,
     }
 
-    for key, params in estimates.items():
+    for key, params in sorted(estimates.items(), key=lambda x: len(x[0]), reverse=True):
         if key in name:
             return params
 

{alloc-0.0.8 → alloc-0.0.10}/src/alloc/extractor_runner.py
@@ -281,7 +281,30 @@ def main():
             "activation_method": activation_result.get("activation_method"),
         }
     else:
-        result = {"status": "no_model"}
+        # No model found — check if this is a distributed training script
+        # that hides the model inside __main__ guard or main()
+        _is_dist = False
+        try:
+            import torch.distributed as _dist_mod
+            if _dist_mod.is_initialized():
+                _is_dist = True
+        except Exception:
+            pass
+        if not _is_dist:
+            # Check if module imported distributed primitives
+            for attr_name in dir(module):
+                try:
+                    obj = getattr(module, attr_name)
+                    mod_name = getattr(obj, "__module__", "") or ""
+                    if "torch.distributed" in mod_name or "torch.nn.parallel" in mod_name:
+                        _is_dist = True
+                        break
+                except Exception:
+                    continue
+        if _is_dist:
+            result = {"status": "error_distributed", "error": "no model found — script uses distributed training"}
+        else:
+            result = {"status": "no_model"}
 
     with open(sidecar_path, "w") as f:
         json.dump(result, f)
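
The attribute scan generalizes to any imported module object. Extracted as a standalone helper (hypothetical name, same logic as the diff):

import types

def looks_distributed(module: types.ModuleType) -> bool:
    """True if the module references torch.distributed / DDP machinery."""
    for attr_name in dir(module):
        try:
            obj = getattr(module, attr_name)
        except Exception:
            continue
        mod_name = getattr(obj, "__module__", "") or ""
        if "torch.distributed" in mod_name or "torch.nn.parallel" in mod_name:
            return True
    return False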

{alloc-0.0.8 → alloc-0.0.10}/src/alloc/probe.py
@@ -215,8 +215,19 @@ def _discover_gpu_indices(proc_pid, pynvml, fallback_index=0, expected_gpus=None
                 if 0 <= idx < device_count:
                     visible_physical.append(idx)
             except ValueError:
-                visible_physical = list(range(device_count))
-                break
+                # UUID-style device identifiers — try NVML UUID matching
+                try:
+                    for phys_idx in range(device_count):
+                        handle = pynvml.nvmlDeviceGetHandleByIndex(phys_idx)
+                        uuid = pynvml.nvmlDeviceGetUUID(handle)
+                        if isinstance(uuid, bytes):
+                            uuid = uuid.decode("utf-8", errors="replace")
+                        if d in uuid:
+                            visible_physical.append(phys_idx)
+                            break
+                except Exception:
+                    visible_physical = list(range(device_count))
+                    break
         search_indices = visible_physical if visible_physical else list(range(device_count))
     else:
         search_indices = list(range(device_count))
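
NVML GPU UUIDs (e.g. "GPU-aaaa-1111") are stable identifiers, so substring matching against each device's UUID recovers the physical index. The same logic as a standalone helper (hypothetical name; assumes pynvml.nvmlInit has already run):

from typing import Optional

import pynvml

def resolve_cvd_uuid(entry: str) -> Optional[int]:
    """Map one UUID-style CUDA_VISIBLE_DEVICES entry to a physical index."""
    for phys_idx in range(pynvml.nvmlDeviceGetCount()):
        handle = pynvml.nvmlDeviceGetHandleByIndex(phys_idx)
        uuid = pynvml.nvmlDeviceGetUUID(handle)
        if isinstance(uuid, bytes):  # older pynvml builds return bytes
            uuid = uuid.decode("utf-8", errors="replace")
        if entry in uuid:
            return phys_idx
    return None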

{alloc-0.0.8 → alloc-0.0.10}/src/alloc.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alloc
-Version: 0.0.8
+Version: 0.0.10
 Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
 Author-email: Alloc Labs <hello@alloclabs.com>
 License-Expression: Apache-2.0

@@ -40,7 +40,7 @@ alloc run python train.py
 ```
 
 ```
-alloc v0.0.2 — Calibrate
+alloc v0.0.9 — Calibrate
 
 Run Summary
 Peak VRAM 31.2 GB / 40.0 GB (A100)

{alloc-0.0.8 → alloc-0.0.10}/tests/test_auth.py
@@ -68,6 +68,34 @@ def test_whoami_not_logged_in_json(tmp_path: Path):
     assert data["api_url"] == "https://api.example.com"
 
 
+def test_whoami_stale_token_json(tmp_path: Path):
+    """Stale token should exit 0 with token_status: expired."""
+    mock_resp = MagicMock()
+    mock_resp.status_code = 401
+    mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError(
+        "Unauthorized", request=MagicMock(), response=mock_resp,
+    )
+    mock_client = MagicMock()
+    mock_client.__enter__.return_value = mock_client
+    mock_client.__exit__.return_value = False
+    mock_client.get.return_value = mock_resp
+
+    env = {
+        "HOME": str(tmp_path),
+        "ALLOC_API_URL": "https://api.example.com",
+        "ALLOC_TOKEN": "stale-token",
+    }
+
+    with patch("httpx.Client", return_value=mock_client), \
+         patch("alloc.cli.try_refresh_access_token", return_value=None):
+        result = runner.invoke(app, ["whoami", "--json"], env=env)
+
+    assert result.exit_code == 0
+    data = json.loads(result.output)
+    assert data["logged_in"] is False
+    assert data["token_status"] == "expired"
+
+
 def test_whoami_logged_in_json(tmp_path: Path):
     profile_resp = MagicMock()
     profile_resp.raise_for_status.return_value = None

@@ -110,6 +138,7 @@ def test_whoami_logged_in_json(tmp_path: Path):
     assert result.exit_code == 0
     data = json.loads(result.output)
     assert data["logged_in"] is True
+    assert data["token_status"] == "valid"
     assert data["token_source"] == "env"
     assert data["email"] == "user@example.com"
     assert data["fleet_count"] == 1

{alloc-0.0.8 → alloc-0.0.10}/tests/test_callbacks.py
@@ -1189,3 +1189,71 @@ class TestNvmlMonitorThreadSafety:
         assert len(probe["per_rank_peak_vram_mb"]) == 2
         for peak in probe["per_rank_peak_vram_mb"]:
             assert peak > 0
+
+
+class TestNvmlMonitorNvlinkFallback:
+    def test_nvlink_detection_failure_sets_nvlink_fallback(self):
+        """When the outer NVLink detection block raises, fall back to 'nvlink'.
+
+        We trigger this by making nvmlDeviceGetNvLinkState raise on the first
+        call (inner except breaks the loop → active_links=0 → 'pcie'), and
+        then making the active_links comparison itself blow up. The simplest
+        trigger is having _gpu_handles[0] raise IndexError (empty list after
+        the early-return guard).
+        """
+        mock_pynvml = MagicMock()
+        mock_pynvml.nvmlInit.return_value = None
+        mock_pynvml.nvmlShutdown.return_value = None
+        mock_pynvml.nvmlDeviceGetCount.return_value = 2
+        mock_pynvml.nvmlDeviceGetName.return_value = "NVIDIA A100-SXM4-80GB"
+        mem = SimpleNamespace(total=80 * 1024**3, used=1 * 1024**3)
+        mock_pynvml.nvmlDeviceGetMemoryInfo.return_value = mem
+        mock_pynvml.nvmlSystemGetDriverVersion.return_value = "535"
+        mock_pynvml.nvmlSystemGetCudaDriverVersion.return_value = 12000
+        mock_pynvml.nvmlDeviceGetCudaComputeCapability.return_value = (8, 0)
+        util = SimpleNamespace(gpu=75, memory=60)
+        mock_pynvml.nvmlDeviceGetUtilizationRates.return_value = util
+        mock_pynvml.nvmlDeviceGetPowerUsage.return_value = 300000
+
+        # Use a handle list that passes the `if not self._gpu_handles` guard
+        # (it's truthy) but raises IndexError on `self._gpu_handles[0]`.
+        class BadHandleList:
+            """Truthy but raises on index access."""
+            def __bool__(self):
+                return True
+            def __len__(self):
+                return 2
+            def __iter__(self):
+                return iter([])
+            def __getitem__(self, idx):
+                raise IndexError("corrupted handle list")
+
+        with patch("alloc.callbacks._try_import_pynvml", return_value=mock_pynvml):
+            monitor = _NvmlMonitor()
+
+        # Replace handles after __init__ but before start().
+        # start() will re-populate from nvmlDeviceGetCount, so we also need to
+        # make the handle-building loop produce our bad list. We do this by
+        # patching nvmlDeviceGetHandleByIndex to raise, so _gpu_handles stays
+        # empty after the try/except in handle building. But that triggers
+        # the early return. Instead, we patch _gpu_handles AFTER start()
+        # builds them but BEFORE NVLink detection runs. We achieve this by
+        # having nvmlDeviceGetCudaComputeCapability (the last hw-context call
+        # before NVLink detection) swap in the bad handles as a side effect.
+        original_sm = mock_pynvml.nvmlDeviceGetCudaComputeCapability
+
+        def swap_handles_then_return_sm(handle):
+            monitor._gpu_handles = BadHandleList()
+            return (8, 0)
+
+        mock_pynvml.nvmlDeviceGetCudaComputeCapability = MagicMock(
+            side_effect=swap_handles_then_return_sm
+        )
+
+        monitor.start()
+        import time
+        time.sleep(0.02)
+        monitor.stop()
+
+        hw, _ = monitor.get_results()
+        assert hw.get("interconnect_type") == "nvlink"

{alloc-0.0.8 → alloc-0.0.10}/tests/test_diagnosis_engine.py
@@ -359,3 +359,15 @@ def test_estimate_model_params_known_vision_model():
 
     result = _estimate_model_params("stable-diffusion")
     assert result == 0.865
+
+
+def test_estimate_model_params_gpt2_medium_prefix_match():
+    """gpt2-medium-finetuned should match gpt2-medium (0.355), not gpt2 (0.124)."""
+    result = _estimate_model_params("gpt2-medium-finetuned")
+    assert result == 0.355
+
+
+def test_estimate_model_params_gpt2_alone():
+    """Plain gpt2 should still match 0.124."""
+    result = _estimate_model_params("gpt2")
+    assert result == 0.124
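
These values trace directly to the one-line fix in diagnosis_engine.py above: plain dict iteration is insertion-ordered, so the shorter key "gpt2" used to shadow "gpt2-medium". A self-contained demonstration of why longest-key-first matters:

estimates = {"gpt2": 0.124, "gpt2-medium": 0.355}
name = "gpt2-medium-finetuned"

# Insertion order: "gpt2" matches first, returning the wrong estimate.
naive = next(p for k, p in estimates.items() if k in name)

# Longest key first: "gpt2-medium" is tried before "gpt2".
ordered = sorted(estimates.items(), key=lambda kv: len(kv[0]), reverse=True)
fixed = next(p for k, p in ordered if k in name)

assert naive == 0.124   # the bug
assert fixed == 0.355   # the fix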

{alloc-0.0.8 → alloc-0.0.10}/tests/test_probe_multi.py
@@ -158,6 +158,61 @@ def test_parse_plain_python():
     assert _parse_launcher_gpu_count(["python", "train.py"]) is None
 
 
+# ── CVD UUID resolution ──
+
+
+def test_cvd_uuid_resolves_to_correct_index():
+    """UUID-style CUDA_VISIBLE_DEVICES should resolve to the matching physical GPU index."""
+    mock = _mock_pynvml_multi_gpu(
+        proc_pid=1000,
+        gpu_process_map={0: [1000], 1: [], 2: []},
+    )
+    mock.nvmlDeviceGetCount.return_value = 3
+
+    # Set up UUID resolution: GPU 0 → UUID-A, GPU 1 → UUID-B, GPU 2 → UUID-C
+    uuid_map = {0: "GPU-aaaa-1111", 1: "GPU-bbbb-2222", 2: "GPU-cccc-3333"}
+    handles = {}
+    for idx in range(3):
+        handles[idx] = MagicMock(name=f"handle_{idx}")
+
+    def get_handle(idx):
+        return handles[idx]
+
+    def get_uuid(handle):
+        for idx, h in handles.items():
+            if handle == h:
+                return uuid_map[idx]
+        return "GPU-unknown"
+
+    mock.nvmlDeviceGetHandleByIndex = MagicMock(side_effect=get_handle)
+    mock.nvmlDeviceGetUUID = MagicMock(side_effect=get_uuid)
+
+    # CVD set to GPU 2's UUID
+    with patch("alloc.probe._get_child_pids", return_value=[]):
+        with patch.dict("os.environ", {"CUDA_VISIBLE_DEVICES": "GPU-cccc-3333"}):
+            result = _discover_gpu_indices(1000, mock, fallback_index=0)
+    # Should only search GPU index 2
+    assert 2 in result or result == [0]  # either found on idx 2, or fallback if no PID match
+
+
+def test_cvd_invalid_uuid_falls_back_to_all_gpus():
+    """Invalid UUID that doesn't match any device should fall back to all GPUs."""
+    mock = _mock_pynvml_multi_gpu(
+        proc_pid=1000,
+        gpu_process_map={0: [1000], 1: []},
+    )
+    mock.nvmlDeviceGetCount.return_value = 2
+
+    # UUID lookup raises for all devices
+    mock.nvmlDeviceGetUUID = MagicMock(side_effect=RuntimeError("no UUID support"))
+
+    with patch("alloc.probe._get_child_pids", return_value=[]):
+        with patch.dict("os.environ", {"CUDA_VISIBLE_DEVICES": "GPU-nonexistent"}):
+            result = _discover_gpu_indices(1000, mock, fallback_index=0)
+    # Should fall back to searching all GPUs and find PID 1000 on GPU 0
+    assert 0 in result
+
+
 def test_parse_torch_distributed_launch():
     assert _parse_launcher_gpu_count([
         "python", "-m", "torch.distributed.launch", "--nproc_per_node=2", "train.py"

{alloc-0.0.8 → alloc-0.0.10}/tests/test_topology_strategy.py
@@ -18,9 +18,15 @@ class TestStrategyInference:
             num_gpus_detected=num_gpus,
         )
 
-    def test_no_degrees_strategy_none(self):
-        """When no degree env vars set, strategy should be None."""
-        result = self._topo({})
+    def test_no_degrees_multi_gpu_infers_ddp(self):
+        """When no degree env vars but multiple GPUs detected, infer DDP."""
+        result = self._topo({}, num_gpus=4)
+        assert result["strategy"] == "ddp"
+        assert result["dp_degree"] == 4
+
+    def test_single_gpu_no_degrees_strategy_none(self):
+        """Single GPU with no degrees → strategy stays None."""
+        result = self._topo({}, num_gpus=1)
         assert result["strategy"] is None
 
     def test_dp_only_is_ddp(self):
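
Restated compactly, the inference rule these tests pin down looks like the sketch below. This is a hedged condensation: the real _infer_parallel_topology_from_env in cli.py takes more inputs (launcher env vars, interconnect config) and returns a full topology dict, and its TP-only branch is not shown in this diff.

from typing import Optional

def infer_strategy(num_gpus_detected: int, has_tp: bool, has_pp: bool,
                   dp: Optional[int]) -> Optional[str]:
    # Mirrors only the branches visible in the diff above.
    if has_pp:
        return "pp+dp" if (dp is not None and dp > 1) else "pp"
    if dp is not None and dp > 1:
        return "ddp"
    if num_gpus_detected > 1 and not has_tp:
        return "ddp"  # observed multi-GPU, no TP/PP hints: PyTorch default
    return None       # single GPU, no hints: genuinely unknown

assert infer_strategy(4, False, False, None) == "ddp"
assert infer_strategy(1, False, False, None) is None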