alloc 0.0.6__tar.gz → 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {alloc-0.0.6 → alloc-0.0.7}/PKG-INFO +1 -1
- {alloc-0.0.6 → alloc-0.0.7}/pyproject.toml +1 -1
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/__init__.py +3 -1
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/cli.py +38 -3
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/model_extractor.py +20 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/probe.py +7 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc.egg-info/PKG-INFO +1 -1
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc.egg-info/SOURCES.txt +2 -0
- alloc-0.0.7/tests/test_ghost_degradation.py +145 -0
- alloc-0.0.7/tests/test_scan_auth.py +142 -0
- {alloc-0.0.6 → alloc-0.0.7}/README.md +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/setup.cfg +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/artifact_loader.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/artifact_writer.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/browser_auth.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/callbacks.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/catalog/__init__.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/catalog/default_rate_card.json +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/catalog/gpus.v1.json +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/code_analyzer.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/config.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/context.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/diagnosis_display.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/diagnosis_engine.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/diagnosis_rules.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/display.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/extractor_runner.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/ghost.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/model_registry.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/stability.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/upload.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc/yaml_config.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc.egg-info/dependency_links.txt +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc.egg-info/entry_points.txt +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc.egg-info/requires.txt +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/src/alloc.egg-info/top_level.txt +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_artifact.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_artifact_loader.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_auth.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_callbacks.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_catalog.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_cli.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_code_analyzer.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_context.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_diagnose_cli.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_diagnosis_engine.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_diagnosis_rules.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_extractor_activation.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_ghost.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_init_from_org.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_interconnect.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_model_extractor.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_probe_hw.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_probe_multi.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_stability.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_upload.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_verdict.py +0 -0
- {alloc-0.0.6 → alloc-0.0.7}/tests/test_yaml_config.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: alloc
|
|
3
|
-
Version: 0.0.6
|
|
3
|
+
Version: 0.0.7
|
|
4
4
|
Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
|
|
5
5
|
Author-email: Alloc Labs <hello@alloclabs.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "alloc"
|
|
7
|
-
version = "0.0.6"
|
|
7
|
+
version = "0.0.7"
|
|
8
8
|
description = "Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = "Apache-2.0"
|
|
@@ -5,9 +5,11 @@ from __future__ import annotations
|
|
|
5
5
|
import warnings as _warnings
|
|
6
6
|
_warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
|
|
7
7
|
_warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
|
|
8
|
+
_warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
|
|
9
|
+
_warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
|
|
8
10
|
del _warnings
|
|
9
11
|
|
|
10
|
-
__version__ = "0.0.6"
|
|
12
|
+
__version__ = "0.0.7"
|
|
11
13
|
|
|
12
14
|
from alloc.ghost import ghost, GhostReport
|
|
13
15
|
from alloc.callbacks import AllocCallback as HuggingFaceCallback
|
|
@@ -19,10 +19,12 @@ import sys
|
|
|
19
19
|
import warnings
|
|
20
20
|
from typing import List, Optional
|
|
21
21
|
|
|
22
|
-
# Suppress noisy third-party warnings globally — pynvml deprecation
|
|
23
|
-
# urllib3 LibreSSL warnings clutter
|
|
22
|
+
# Suppress noisy third-party warnings globally — pynvml deprecation (emitted
|
|
23
|
+
# from torch.cuda.__init__) and urllib3 LibreSSL warnings clutter CLI output.
|
|
24
24
|
warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
|
|
25
25
|
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
|
|
26
|
+
warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
|
|
27
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
|
|
26
28
|
warnings.filterwarnings("ignore", message=".*LibreSSL.*", module="urllib3")
|
|
27
29
|
|
|
28
30
|
import typer
|
|
@@ -75,6 +77,19 @@ def ghost(
|
|
|
75
77
|
console.print(f"[dim]Tip: alloc ghost {script} --param-count-b 7.0[/dim]")
|
|
76
78
|
raise typer.Exit(1)
|
|
77
79
|
|
|
80
|
+
if info.extraction_error:
|
|
81
|
+
if json_output:
|
|
82
|
+
_print_json({
|
|
83
|
+
"error": info.extraction_error,
|
|
84
|
+
"detail": info.extraction_detail,
|
|
85
|
+
"supported": False,
|
|
86
|
+
})
|
|
87
|
+
else:
|
|
88
|
+
console.print(f"[yellow]{info.extraction_detail}[/yellow]")
|
|
89
|
+
if info.extraction_error == "distributed_entrypoint":
|
|
90
|
+
console.print("[dim]Tip: alloc ghost model.py (point to the file that defines your model)[/dim]")
|
|
91
|
+
raise typer.Exit(1)
|
|
92
|
+
|
|
78
93
|
# Use dtype from execution if available, otherwise CLI flag
|
|
79
94
|
resolved_dtype = info.dtype if info.method == "execution" else dtype
|
|
80
95
|
|
|
@@ -2099,12 +2114,32 @@ def scan(
|
|
|
2099
2114
|
|
|
2100
2115
|
try:
|
|
2101
2116
|
headers = {"Content-Type": "application/json"}
|
|
2117
|
+
used_auth = bool(token)
|
|
2118
|
+
|
|
2102
2119
|
if token:
|
|
2103
2120
|
headers["Authorization"] = f"Bearer {token}"
|
|
2121
|
+
endpoint = "/scans"
|
|
2122
|
+
else:
|
|
2123
|
+
endpoint = "/scans/cli"
|
|
2104
2124
|
|
|
2105
|
-
endpoint = "/scans" if token else "/scans/cli"
|
|
2106
2125
|
with httpx.Client(timeout=30) as client:
|
|
2107
2126
|
resp = client.post(f"{api_url}{endpoint}", json=payload, headers=headers)
|
|
2127
|
+
|
|
2128
|
+
# On 401 with a saved token: try refresh, then fall back to public endpoint
|
|
2129
|
+
if resp.status_code == 401 and used_auth:
|
|
2130
|
+
new_token = try_refresh_access_token()
|
|
2131
|
+
if new_token:
|
|
2132
|
+
headers["Authorization"] = f"Bearer {new_token}"
|
|
2133
|
+
resp = client.post(f"{api_url}/scans", json=payload, headers=headers)
|
|
2134
|
+
else:
|
|
2135
|
+
# Token refresh failed — fall back to unauthenticated scan
|
|
2136
|
+
console.print(
|
|
2137
|
+
"[yellow]Session expired — falling back to public scan "
|
|
2138
|
+
"(org fleet context unavailable). Run `alloc login` to restore.[/yellow]",
|
|
2139
|
+
)
|
|
2140
|
+
del headers["Authorization"]
|
|
2141
|
+
resp = client.post(f"{api_url}/scans/cli", json=payload, headers=headers)
|
|
2142
|
+
|
|
2108
2143
|
resp.raise_for_status()
|
|
2109
2144
|
result = resp.json()
|
|
2110
2145
|
|
|
@@ -33,6 +33,8 @@ class ModelInfo:
|
|
|
33
33
|
seq_length: Optional[int] = None
|
|
34
34
|
activation_memory_bytes: Optional[int] = None
|
|
35
35
|
activation_method: Optional[str] = None # "traced" | None
|
|
36
|
+
extraction_error: Optional[str] = None # "distributed_entrypoint" | None
|
|
37
|
+
extraction_detail: Optional[str] = None # human-readable explanation
|
|
36
38
|
|
|
37
39
|
|
|
38
40
|
def extract_model_info(
|
|
@@ -134,6 +136,24 @@ def _extract_via_subprocess(
|
|
|
134
136
|
activation_method=data.get("activation_method"),
|
|
135
137
|
)
|
|
136
138
|
|
|
139
|
+
# Structured degradation for distributed scripts
|
|
140
|
+
if data.get("status") == "error":
|
|
141
|
+
error_msg = data.get("error", "")
|
|
142
|
+
_dist_keywords = ("init_process_group", "NCCL", "gloo", "distributed",
|
|
143
|
+
"MASTER_ADDR", "MASTER_PORT", "RendezvousError")
|
|
144
|
+
if any(kw.lower() in error_msg.lower() for kw in _dist_keywords):
|
|
145
|
+
return ModelInfo(
|
|
146
|
+
param_count=0,
|
|
147
|
+
dtype="float16",
|
|
148
|
+
model_name=None,
|
|
149
|
+
method="execution",
|
|
150
|
+
extraction_error="distributed_entrypoint",
|
|
151
|
+
extraction_detail=(
|
|
152
|
+
"Script requires a distributed runtime (e.g. torchrun). "
|
|
153
|
+
"Run ghost on the model definition file instead of the launcher script."
|
|
154
|
+
),
|
|
155
|
+
)
|
|
156
|
+
|
|
137
157
|
return None
|
|
138
158
|
|
|
139
159
|
except subprocess.TimeoutExpired:
|
|
@@ -18,6 +18,13 @@ from dataclasses import dataclass, field
|
|
|
18
18
|
from enum import Enum
|
|
19
19
|
from typing import List, Optional
|
|
20
20
|
|
|
21
|
+
import warnings as _warnings
|
|
22
|
+
_warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
|
|
23
|
+
_warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
|
|
24
|
+
_warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
|
|
25
|
+
_warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
|
|
26
|
+
del _warnings
|
|
27
|
+
|
|
21
28
|
|
|
22
29
|
class StopReason(str, Enum):
|
|
23
30
|
STABLE = "stable"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: alloc
|
|
3
|
-
Version: 0.0.6
|
|
3
|
+
Version: 0.0.7
|
|
4
4
|
Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
|
|
5
5
|
Author-email: Alloc Labs <hello@alloclabs.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -43,11 +43,13 @@ tests/test_diagnosis_engine.py
|
|
|
43
43
|
tests/test_diagnosis_rules.py
|
|
44
44
|
tests/test_extractor_activation.py
|
|
45
45
|
tests/test_ghost.py
|
|
46
|
+
tests/test_ghost_degradation.py
|
|
46
47
|
tests/test_init_from_org.py
|
|
47
48
|
tests/test_interconnect.py
|
|
48
49
|
tests/test_model_extractor.py
|
|
49
50
|
tests/test_probe_hw.py
|
|
50
51
|
tests/test_probe_multi.py
|
|
52
|
+
tests/test_scan_auth.py
|
|
51
53
|
tests/test_stability.py
|
|
52
54
|
tests/test_upload.py
|
|
53
55
|
tests/test_verdict.py
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Tests for ghost structured degradation on distributed scripts."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import tempfile
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from unittest.mock import patch
|
|
10
|
+
|
|
11
|
+
from typer.testing import CliRunner
|
|
12
|
+
|
|
13
|
+
from alloc.cli import app
|
|
14
|
+
from alloc.model_extractor import ModelInfo, extract_model_info
|
|
15
|
+
|
|
16
|
+
runner = CliRunner()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def test_distributed_error_returns_structured_modelinfo():
|
|
20
|
+
"""When extractor subprocess fails with a distributed keyword, return structured ModelInfo."""
|
|
21
|
+
sidecar_data = json.dumps({
|
|
22
|
+
"status": "error",
|
|
23
|
+
"error": "RuntimeError: torch.distributed.init_process_group requires MASTER_ADDR",
|
|
24
|
+
})
|
|
25
|
+
|
|
26
|
+
def _fake_subprocess_run(*args, **kwargs):
|
|
27
|
+
# Write the sidecar file
|
|
28
|
+
sidecar_path = args[0][3] # [python, -m, alloc.extractor_runner, sidecar_path, script_path]
|
|
29
|
+
with open(sidecar_path, "w") as f:
|
|
30
|
+
f.write(sidecar_data)
|
|
31
|
+
|
|
32
|
+
# Create a dummy script
|
|
33
|
+
fd, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_dist_")
|
|
34
|
+
os.write(fd, b"import torch\ntorch.distributed.init_process_group('nccl')\n")
|
|
35
|
+
os.close(fd)
|
|
36
|
+
|
|
37
|
+
try:
|
|
38
|
+
with patch("subprocess.run", side_effect=_fake_subprocess_run):
|
|
39
|
+
info = extract_model_info(script_path)
|
|
40
|
+
|
|
41
|
+
assert info is not None
|
|
42
|
+
assert info.extraction_error == "distributed_entrypoint"
|
|
43
|
+
assert "distributed runtime" in info.extraction_detail
|
|
44
|
+
assert info.param_count == 0
|
|
45
|
+
finally:
|
|
46
|
+
os.unlink(script_path)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def test_distributed_error_nccl_keyword():
|
|
50
|
+
"""NCCL errors should be caught as distributed failures."""
|
|
51
|
+
sidecar_data = json.dumps({
|
|
52
|
+
"status": "error",
|
|
53
|
+
"error": "NCCL error: unhandled system error",
|
|
54
|
+
})
|
|
55
|
+
|
|
56
|
+
fd, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_nccl_")
|
|
57
|
+
os.write(fd, b"pass\n")
|
|
58
|
+
os.close(fd)
|
|
59
|
+
|
|
60
|
+
try:
|
|
61
|
+
def _fake_run(*args, **kwargs):
|
|
62
|
+
sidecar_path = args[0][3]
|
|
63
|
+
with open(sidecar_path, "w") as f:
|
|
64
|
+
f.write(sidecar_data)
|
|
65
|
+
|
|
66
|
+
with patch("subprocess.run", side_effect=_fake_run):
|
|
67
|
+
info = extract_model_info(script_path)
|
|
68
|
+
|
|
69
|
+
assert info is not None
|
|
70
|
+
assert info.extraction_error == "distributed_entrypoint"
|
|
71
|
+
finally:
|
|
72
|
+
os.unlink(script_path)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def test_non_distributed_error_returns_none():
|
|
76
|
+
"""Non-distributed errors should still return None (fall through to AST)."""
|
|
77
|
+
sidecar_data = json.dumps({
|
|
78
|
+
"status": "error",
|
|
79
|
+
"error": "ImportError: No module named 'custom_lib'",
|
|
80
|
+
})
|
|
81
|
+
|
|
82
|
+
fd, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_other_")
|
|
83
|
+
# Script with no from_pretrained so AST also returns None
|
|
84
|
+
os.write(fd, b"import custom_lib\n")
|
|
85
|
+
os.close(fd)
|
|
86
|
+
|
|
87
|
+
try:
|
|
88
|
+
def _fake_run(*args, **kwargs):
|
|
89
|
+
sidecar_path = args[0][3]
|
|
90
|
+
with open(sidecar_path, "w") as f:
|
|
91
|
+
f.write(sidecar_data)
|
|
92
|
+
|
|
93
|
+
with patch("subprocess.run", side_effect=_fake_run):
|
|
94
|
+
info = extract_model_info(script_path)
|
|
95
|
+
|
|
96
|
+
# Should be None because error is not distributed and AST won't find a model either
|
|
97
|
+
assert info is None
|
|
98
|
+
finally:
|
|
99
|
+
os.unlink(script_path)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def test_ghost_cli_distributed_error_json(tmp_path: Path):
|
|
103
|
+
"""ghost --json shows structured error for distributed scripts."""
|
|
104
|
+
script_path = tmp_path / "train_ddp.py"
|
|
105
|
+
script_path.write_text("import torch\ntorch.distributed.init_process_group('nccl')\n")
|
|
106
|
+
|
|
107
|
+
dist_info = ModelInfo(
|
|
108
|
+
param_count=0,
|
|
109
|
+
dtype="float16",
|
|
110
|
+
model_name=None,
|
|
111
|
+
method="execution",
|
|
112
|
+
extraction_error="distributed_entrypoint",
|
|
113
|
+
extraction_detail="Script requires a distributed runtime (e.g. torchrun). Run ghost on the model definition file instead of the launcher script.",
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
with patch("alloc.model_extractor.extract_model_info", return_value=dist_info):
|
|
117
|
+
result = runner.invoke(app, ["ghost", str(script_path), "--json"])
|
|
118
|
+
|
|
119
|
+
assert result.exit_code != 0
|
|
120
|
+
data = json.loads(result.output)
|
|
121
|
+
assert data["error"] == "distributed_entrypoint"
|
|
122
|
+
assert data["supported"] is False
|
|
123
|
+
assert "distributed runtime" in data["detail"]
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def test_ghost_cli_distributed_error_human(tmp_path: Path):
|
|
127
|
+
"""ghost shows human-readable message with tip for distributed scripts."""
|
|
128
|
+
script_path = tmp_path / "train_ddp.py"
|
|
129
|
+
script_path.write_text("import torch\n")
|
|
130
|
+
|
|
131
|
+
dist_info = ModelInfo(
|
|
132
|
+
param_count=0,
|
|
133
|
+
dtype="float16",
|
|
134
|
+
model_name=None,
|
|
135
|
+
method="execution",
|
|
136
|
+
extraction_error="distributed_entrypoint",
|
|
137
|
+
extraction_detail="Script requires a distributed runtime (e.g. torchrun). Run ghost on the model definition file instead of the launcher script.",
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
with patch("alloc.model_extractor.extract_model_info", return_value=dist_info):
|
|
141
|
+
result = runner.invoke(app, ["ghost", str(script_path)])
|
|
142
|
+
|
|
143
|
+
assert result.exit_code != 0
|
|
144
|
+
assert "distributed runtime" in result.output
|
|
145
|
+
assert "model.py" in result.output # tip about pointing to model file
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""Tests for scan command 401 retry + /scans/cli fallback."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from unittest.mock import MagicMock, patch
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
from typer.testing import CliRunner
|
|
11
|
+
|
|
12
|
+
from alloc.cli import app
|
|
13
|
+
|
|
14
|
+
runner = CliRunner()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _make_resp(status_code: int, body: dict, url: str = "https://api.example.com/scans"):
|
|
18
|
+
req = httpx.Request("POST", url)
|
|
19
|
+
return httpx.Response(
|
|
20
|
+
status_code,
|
|
21
|
+
request=req,
|
|
22
|
+
content=json.dumps(body).encode(),
|
|
23
|
+
headers={"content-type": "application/json"},
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def test_scan_401_refresh_retry(tmp_path: Path):
|
|
28
|
+
"""On 401, refresh token and retry on /scans."""
|
|
29
|
+
resp_401 = _make_resp(401, {"detail": "unauthorized"})
|
|
30
|
+
resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []})
|
|
31
|
+
|
|
32
|
+
mock_client = MagicMock()
|
|
33
|
+
mock_client.__enter__.return_value = mock_client
|
|
34
|
+
mock_client.__exit__.return_value = False
|
|
35
|
+
mock_client.post.side_effect = [resp_401, resp_ok]
|
|
36
|
+
|
|
37
|
+
cfg_file = tmp_path / ".alloc" / "config.json"
|
|
38
|
+
cfg_file.parent.mkdir(parents=True)
|
|
39
|
+
cfg_file.write_text(json.dumps({"token": "old-tok", "refresh_token": "rt"}))
|
|
40
|
+
|
|
41
|
+
env = {
|
|
42
|
+
"HOME": str(tmp_path),
|
|
43
|
+
"ALLOC_API_URL": "https://api.example.com",
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
with (
|
|
47
|
+
patch("httpx.Client", return_value=mock_client),
|
|
48
|
+
patch("alloc.cli.try_refresh_access_token", return_value="new-tok"),
|
|
49
|
+
):
|
|
50
|
+
result = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=env)
|
|
51
|
+
|
|
52
|
+
assert result.exit_code == 0
|
|
53
|
+
assert mock_client.post.call_count == 2
|
|
54
|
+
# Second call should use refreshed token
|
|
55
|
+
second_call = mock_client.post.call_args_list[1]
|
|
56
|
+
assert "Bearer new-tok" in str(second_call)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def test_scan_401_refresh_fails_fallback_public(tmp_path: Path):
|
|
60
|
+
"""On 401 + refresh failure, fall back to /scans/cli with warning."""
|
|
61
|
+
resp_401 = _make_resp(401, {"detail": "unauthorized"})
|
|
62
|
+
resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []},
|
|
63
|
+
url="https://api.example.com/scans/cli")
|
|
64
|
+
|
|
65
|
+
mock_client = MagicMock()
|
|
66
|
+
mock_client.__enter__.return_value = mock_client
|
|
67
|
+
mock_client.__exit__.return_value = False
|
|
68
|
+
mock_client.post.side_effect = [resp_401, resp_ok]
|
|
69
|
+
|
|
70
|
+
cfg_file = tmp_path / ".alloc" / "config.json"
|
|
71
|
+
cfg_file.parent.mkdir(parents=True)
|
|
72
|
+
cfg_file.write_text(json.dumps({"token": "old-tok", "refresh_token": "rt"}))
|
|
73
|
+
|
|
74
|
+
env = {
|
|
75
|
+
"HOME": str(tmp_path),
|
|
76
|
+
"ALLOC_API_URL": "https://api.example.com",
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
with (
|
|
80
|
+
patch("httpx.Client", return_value=mock_client),
|
|
81
|
+
patch("alloc.cli.try_refresh_access_token", return_value=None),
|
|
82
|
+
):
|
|
83
|
+
result = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=env)
|
|
84
|
+
|
|
85
|
+
assert result.exit_code == 0
|
|
86
|
+
assert mock_client.post.call_count == 2
|
|
87
|
+
# Second call should hit /scans/cli
|
|
88
|
+
second_url = str(mock_client.post.call_args_list[1])
|
|
89
|
+
assert "/scans/cli" in second_url
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def test_scan_401_fallback_warns_about_dropped_features(tmp_path: Path):
|
|
93
|
+
"""Fallback to public scan warns user about lost org context."""
|
|
94
|
+
resp_401 = _make_resp(401, {"detail": "unauthorized"})
|
|
95
|
+
resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []})
|
|
96
|
+
|
|
97
|
+
mock_client = MagicMock()
|
|
98
|
+
mock_client.__enter__.return_value = mock_client
|
|
99
|
+
mock_client.__exit__.return_value = False
|
|
100
|
+
mock_client.post.side_effect = [resp_401, resp_ok]
|
|
101
|
+
|
|
102
|
+
cfg_file = tmp_path / ".alloc" / "config.json"
|
|
103
|
+
cfg_file.parent.mkdir(parents=True)
|
|
104
|
+
cfg_file.write_text(json.dumps({"token": "old-tok", "refresh_token": "rt"}))
|
|
105
|
+
|
|
106
|
+
env = {
|
|
107
|
+
"HOME": str(tmp_path),
|
|
108
|
+
"ALLOC_API_URL": "https://api.example.com",
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
with (
|
|
112
|
+
patch("httpx.Client", return_value=mock_client),
|
|
113
|
+
patch("alloc.cli.try_refresh_access_token", return_value=None),
|
|
114
|
+
):
|
|
115
|
+
# Non-JSON mode to see the warning message
|
|
116
|
+
result = runner.invoke(app, ["scan", "--model", "llama-3-8b"], env=env)
|
|
117
|
+
|
|
118
|
+
assert result.exit_code == 0
|
|
119
|
+
assert "expired" in result.output.lower() or "falling back" in result.output.lower()
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def test_scan_no_token_uses_public_directly(tmp_path: Path):
|
|
123
|
+
"""Without a token, scan goes directly to /scans/cli."""
|
|
124
|
+
resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []})
|
|
125
|
+
|
|
126
|
+
mock_client = MagicMock()
|
|
127
|
+
mock_client.__enter__.return_value = mock_client
|
|
128
|
+
mock_client.__exit__.return_value = False
|
|
129
|
+
mock_client.post.return_value = resp_ok
|
|
130
|
+
|
|
131
|
+
env = {
|
|
132
|
+
"HOME": str(tmp_path),
|
|
133
|
+
"ALLOC_API_URL": "https://api.example.com",
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
with patch("httpx.Client", return_value=mock_client):
|
|
137
|
+
result = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=env)
|
|
138
|
+
|
|
139
|
+
assert result.exit_code == 0
|
|
140
|
+
assert mock_client.post.call_count == 1
|
|
141
|
+
call_url = str(mock_client.post.call_args_list[0])
|
|
142
|
+
assert "/scans/cli" in call_url
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|