alloc 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alloc/__init__.py +11 -0
- alloc/artifact_writer.py +67 -0
- alloc/callbacks.py +342 -0
- alloc/catalog/__init__.py +138 -0
- alloc/catalog/default_rate_card.json +18 -0
- alloc/catalog/gpus.v1.json +174 -0
- alloc/cli.py +1341 -0
- alloc/config.py +124 -0
- alloc/context.py +191 -0
- alloc/display.py +580 -0
- alloc/extractor_runner.py +141 -0
- alloc/ghost.py +167 -0
- alloc/model_extractor.py +170 -0
- alloc/model_registry.py +138 -0
- alloc/probe.py +461 -0
- alloc/stability.py +144 -0
- alloc/upload.py +138 -0
- alloc/yaml_config.py +287 -0
- alloc-0.0.1.dist-info/METADATA +256 -0
- alloc-0.0.1.dist-info/RECORD +23 -0
- alloc-0.0.1.dist-info/WHEEL +5 -0
- alloc-0.0.1.dist-info/entry_points.txt +2 -0
- alloc-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""Standalone subprocess runner for model parameter extraction.
|
|
2
|
+
|
|
3
|
+
This module is invoked by model_extractor.py as:
|
|
4
|
+
subprocess.run([python, "-m", "alloc.extractor_runner", sidecar_path, script_path])
|
|
5
|
+
|
|
6
|
+
It imports the user's training script in an isolated process, finds
|
|
7
|
+
nn.Module instances/classes, counts parameters, and writes results
|
|
8
|
+
to a JSON sidecar file. Runs with CUDA_VISIBLE_DEVICES="" to prevent
|
|
9
|
+
GPU allocation.
|
|
10
|
+
|
|
11
|
+
Never imported directly by the rest of the CLI — only executed as __main__.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import importlib.util
|
|
17
|
+
import json
|
|
18
|
+
import os
|
|
19
|
+
import sys
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _patch_sys_exit():
|
|
23
|
+
"""Prevent sys.exit() from killing the extractor."""
|
|
24
|
+
class _ExitCatch(SystemExit):
|
|
25
|
+
pass
|
|
26
|
+
|
|
27
|
+
def _fake_exit(code=0):
|
|
28
|
+
raise _ExitCatch(code)
|
|
29
|
+
|
|
30
|
+
sys.exit = _fake_exit
|
|
31
|
+
return _ExitCatch
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _count_params(model):
|
|
35
|
+
"""Count parameters in an nn.Module."""
|
|
36
|
+
total = 0
|
|
37
|
+
dtype_str = "float32"
|
|
38
|
+
try:
|
|
39
|
+
for i, (name, param) in enumerate(model.named_parameters()):
|
|
40
|
+
total += param.numel()
|
|
41
|
+
if i == 0:
|
|
42
|
+
dtype_str = str(param.dtype).replace("torch.", "")
|
|
43
|
+
except Exception:
|
|
44
|
+
pass
|
|
45
|
+
return total, dtype_str
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _write_result(sidecar_path, result):
    """Persist the extraction result dict to the JSON sidecar file."""
    with open(sidecar_path, "w") as f:
        json.dump(result, f)


def main():
    """Entry point: extract model info from a user script into a sidecar JSON.

    argv: [sidecar_path, script_path]. Writes exactly one of:
      {"status": "ok", "param_count": ..., "dtype": ..., "model_name": ...}
      {"status": "no_model"}
      {"status": "error", "error": "..."}
    """
    sidecar_path = sys.argv[1]
    script_path = sys.argv[2]

    result = {"status": "no_model"}

    try:
        import torch  # noqa: F401 -- import check only; nn is what we use below
        import torch.nn as nn
    except ImportError:
        _write_result(sidecar_path, {"status": "error", "error": "torch not installed"})
        return

    ExitCatch = _patch_sys_exit()

    # Add script's directory to sys.path so its relative imports resolve.
    script_dir = os.path.dirname(os.path.abspath(script_path))
    if script_dir not in sys.path:
        sys.path.insert(0, script_dir)

    # Import the user module.
    try:
        spec = importlib.util.spec_from_file_location("__user_module__", script_path)
        if spec is None or spec.loader is None:
            _write_result(sidecar_path, {"status": "error", "error": "cannot load module"})
            return

        module = importlib.util.module_from_spec(spec)
        module.__name__ = "__user_module__"  # skip if __name__ == "__main__" guards
        try:
            spec.loader.exec_module(module)
        except ExitCatch:
            pass  # script called sys.exit(), that's fine
        except SystemExit:
            pass  # catch real SystemExit too
    except Exception as e:
        _write_result(sidecar_path, {"status": "error", "error": str(e)[:200]})
        return

    # Search for nn.Module instances in module globals.
    models = []
    for attr_name in dir(module):
        try:
            obj = getattr(module, attr_name)
            if isinstance(obj, nn.Module):
                count, dtype_str = _count_params(obj)
                if count > 0:
                    models.append((count, dtype_str, attr_name))
        except Exception:
            continue

    # If no instances found, search for nn.Module subclasses and try
    # no-argument instantiation.
    if not models:
        for attr_name in dir(module):
            try:
                obj = getattr(module, attr_name)
                if (isinstance(obj, type)
                        and issubclass(obj, nn.Module)
                        and obj is not nn.Module):
                    try:
                        instance = obj()
                        count, dtype_str = _count_params(instance)
                        if count > 0:
                            models.append((count, dtype_str, attr_name))
                    except Exception:
                        continue
            except Exception:
                continue

    if models:
        # Pick largest model (most params = likely training target)
        models.sort(key=lambda m: m[0], reverse=True)
        best = models[0]
        result = {
            "status": "ok",
            "param_count": best[0],
            "dtype": best[1],
            "model_name": best[2],
        }
    else:
        result = {"status": "no_model"}

    _write_result(sidecar_path, result)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
# Only ever executed as `python -m alloc.extractor_runner <sidecar> <script>`;
# never imported by the rest of the CLI (see module docstring).
if __name__ == "__main__":
    main()
|
alloc/ghost.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
"""Ghost Scan — static VRAM analysis without executing the model.
|
|
2
|
+
|
|
3
|
+
Two paths:
|
|
4
|
+
1. With torch: walk model.named_parameters() for exact sizing
|
|
5
|
+
2. Without torch: pure math from param_count (same formula as API engine)
|
|
6
|
+
|
|
7
|
+
Never crashes user code. All exceptions are caught and logged.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from dataclasses import dataclass, asdict
|
|
13
|
+
from typing import Any, Optional
|
|
14
|
+
|
|
15
|
+
# Canonical dtype → bytes mapping. Keep in sync with apps/api/src/engine/vram.py.
BYTES_PER_DTYPE = {
    "float32": 4,
    "float16": 2,
    "bfloat16": 2,
    "int8": 1,
    "int4": 0.5,  # packed 4-bit: half a byte per parameter
    # Shorthand aliases (torch uses long names, API uses short names)
    "fp32": 4,
    "fp16": 2,
    "bf16": 2,
}

# Adam: params (fp32) + momentum (fp32) + variance (fp32) = 3 states * 4 bytes
# Keep in sync with OPTIMIZER_MULTIPLIER_FP16 in apps/api/src/engine/vram.py.
OPTIMIZER_BYTES_PER_PARAM = 12

# Buffer overhead: 10% of (weights + gradients + activations + optimizer)
# for fragmentation and temp allocations.
BUFFER_OVERHEAD_FACTOR = 0.1
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class GhostReport:
    """Result of a Ghost Scan — VRAM breakdown.

    All *_gb fields are in gibibytes (1024**3 bytes) and are rounded to
    two decimals by the implementation that builds the report.
    """

    param_count: int  # raw parameter count
    param_count_b: float  # param_count expressed in billions
    dtype: str  # canonical dtype name, e.g. "float16"
    weights_gb: float  # model weights
    gradients_gb: float  # gradient buffers (same size as weights)
    optimizer_gb: float  # optimizer state (Adam: 12 bytes/param)
    activations_gb: float  # activation estimate from batch/seq/hidden dims
    buffer_gb: float  # fragmentation / temp-allocation overhead
    total_gb: float  # sum of all of the above
    extraction_method: Optional[str] = None  # "execution", "ast", "manual"

    def to_dict(self) -> dict:
        """Return the report as a plain dict (via dataclasses.asdict)."""
        return asdict(self)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def ghost(
    model: Any = None,
    *,
    param_count: Optional[int] = None,
    param_count_b: Optional[float] = None,
    dtype: str = "float16",
    batch_size: int = 32,
    seq_length: int = 2048,
    hidden_dim: int = 4096,
) -> GhostReport:
    """Run a Ghost Scan on a model or a raw parameter count.

    Usage:
        # With a PyTorch model
        report = alloc.ghost(model)

        # Without torch — just param count
        report = alloc.ghost(param_count_b=7.0)

    Never raises. If the scan itself fails, a zeroed best-effort report
    is returned instead of propagating the error into user code.
    """
    try:
        return _ghost_impl(
            model=model,
            param_count=param_count,
            param_count_b=param_count_b,
            dtype=dtype,
            batch_size=batch_size,
            seq_length=seq_length,
            hidden_dim=hidden_dim,
        )
    except Exception:
        pass  # fall through to the zeroed report — never crash user code

    fallback_count = param_count or int((param_count_b or 0) * 1e9) or 0
    return GhostReport(
        param_count=fallback_count,
        param_count_b=fallback_count / 1e9,
        dtype=dtype,
        weights_gb=0.0,
        gradients_gb=0.0,
        optimizer_gb=0.0,
        activations_gb=0.0,
        buffer_gb=0.0,
        total_gb=0.0,
    )
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _ghost_impl(
    model: Any,
    param_count: Optional[int],
    param_count_b: Optional[float],
    dtype: str,
    batch_size: int,
    seq_length: int,
    hidden_dim: int,
) -> GhostReport:
    """Core Ghost implementation: resolve a param count, then size each pool."""
    resolved_dtype = dtype

    # Resolve the parameter count from whichever input was provided,
    # preferring a live model over explicit counts.
    if model is not None:
        resolved_count, resolved_dtype = _count_from_model(model)
    elif param_count is not None:
        resolved_count = param_count
    elif param_count_b is not None:
        resolved_count = int(param_count_b * 1e9)
    else:
        resolved_count = 0

    if resolved_count <= 0:
        resolved_count = 0  # clamp nonsense (negative) inputs

    # Unknown dtypes fall back to 2 bytes/param (half precision).
    per_param = BYTES_PER_DTYPE.get(resolved_dtype, 2)
    to_gib = 1.0 / (1024 ** 3)

    weights_gb = resolved_count * per_param * to_gib
    gradients_gb = resolved_count * per_param * to_gib
    optimizer_gb = resolved_count * OPTIMIZER_BYTES_PER_PARAM * to_gib
    activations_gb = batch_size * seq_length * hidden_dim * per_param * to_gib
    buffer_gb = BUFFER_OVERHEAD_FACTOR * (
        weights_gb + gradients_gb + optimizer_gb + activations_gb
    )
    total_gb = weights_gb + gradients_gb + optimizer_gb + activations_gb + buffer_gb

    return GhostReport(
        param_count=resolved_count,
        param_count_b=round(resolved_count / 1e9, 3),
        dtype=resolved_dtype,
        weights_gb=round(weights_gb, 2),
        gradients_gb=round(gradients_gb, 2),
        optimizer_gb=round(optimizer_gb, 2),
        activations_gb=round(activations_gb, 2),
        buffer_gb=round(buffer_gb, 2),
        total_gb=round(total_gb, 2),
    )
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _count_from_model(model: Any) -> tuple:
|
|
156
|
+
"""Extract param count and dtype from a torch model. Returns (count, dtype_str)."""
|
|
157
|
+
total = 0
|
|
158
|
+
dtype_str = "float16"
|
|
159
|
+
try:
|
|
160
|
+
for name, param in model.named_parameters():
|
|
161
|
+
total += param.numel()
|
|
162
|
+
# Use dtype of the first parameter
|
|
163
|
+
if total == param.numel():
|
|
164
|
+
dtype_str = str(param.dtype).replace("torch.", "")
|
|
165
|
+
except Exception:
|
|
166
|
+
pass
|
|
167
|
+
return total, dtype_str
|
alloc/model_extractor.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
"""Model extraction — get param count from user scripts.
|
|
2
|
+
|
|
3
|
+
Three extraction methods with fallback chain:
|
|
4
|
+
1. Manual override (--param-count-b flag) → instant, no file needed
|
|
5
|
+
2. Subprocess execution → import module, find nn.Module, count params
|
|
6
|
+
3. AST parsing → find from_pretrained() calls, match against known models
|
|
7
|
+
|
|
8
|
+
Never crashes. Returns None when extraction fails.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import ast
|
|
14
|
+
import json
|
|
15
|
+
import os
|
|
16
|
+
import subprocess
|
|
17
|
+
import sys
|
|
18
|
+
import tempfile
|
|
19
|
+
from dataclasses import dataclass
|
|
20
|
+
from typing import Optional
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class ModelInfo:
    """Result of model extraction."""

    param_count: int  # raw parameter count
    dtype: str  # "float32", "float16", etc.
    model_name: Optional[str]  # class name (execution) or HF model id (ast), if found
    method: str  # "execution" | "ast" | "manual"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def extract_model_info(
    script: str,
    *,
    timeout: int = 60,
    param_count_b: Optional[float] = None,
) -> Optional[ModelInfo]:
    """Extract model info from a script with a fallback chain.

    Order: manual override → subprocess execution → AST parsing.
    Returns ModelInfo, or None when every method fails.
    """
    # Manual override needs no file at all.
    if param_count_b is not None:
        return ModelInfo(
            param_count=int(param_count_b * 1e9),
            dtype="float16",
            model_name=None,
            method="manual",
        )

    # Every remaining method reads the script from disk.
    if not os.path.isfile(script):
        return None

    # Subprocess execution first (exact counts), then static AST parsing.
    for run_extractor in (
        lambda: _extract_via_subprocess(script, timeout=timeout),
        lambda: _extract_via_ast(script),
    ):
        info = run_extractor()
        if info is not None:
            return info

    return None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# ---------------------------------------------------------------------------
|
|
70
|
+
# Subprocess execution
|
|
71
|
+
# ---------------------------------------------------------------------------
|
|
72
|
+
|
|
73
|
+
def _extract_via_subprocess(
|
|
74
|
+
script: str,
|
|
75
|
+
*,
|
|
76
|
+
timeout: int = 60,
|
|
77
|
+
) -> Optional[ModelInfo]:
|
|
78
|
+
"""Execute user script in isolated subprocess, extract model info.
|
|
79
|
+
|
|
80
|
+
Invokes alloc.extractor_runner as a module in a subprocess with
|
|
81
|
+
CUDA_VISIBLE_DEVICES="" to prevent GPU allocation. Results are
|
|
82
|
+
communicated via a JSON sidecar file.
|
|
83
|
+
"""
|
|
84
|
+
sidecar_fd = None
|
|
85
|
+
sidecar_path = None
|
|
86
|
+
|
|
87
|
+
try:
|
|
88
|
+
# Create sidecar file for IPC
|
|
89
|
+
sidecar_fd, sidecar_path = tempfile.mkstemp(suffix=".json", prefix="alloc_extract_")
|
|
90
|
+
os.close(sidecar_fd)
|
|
91
|
+
sidecar_fd = None
|
|
92
|
+
|
|
93
|
+
script_abs = os.path.abspath(script)
|
|
94
|
+
|
|
95
|
+
env = os.environ.copy()
|
|
96
|
+
env["CUDA_VISIBLE_DEVICES"] = "" # prevent GPU allocation
|
|
97
|
+
|
|
98
|
+
subprocess.run(
|
|
99
|
+
[sys.executable, "-m", "alloc.extractor_runner", sidecar_path, script_abs],
|
|
100
|
+
timeout=timeout,
|
|
101
|
+
capture_output=True,
|
|
102
|
+
env=env,
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
# Read sidecar result
|
|
106
|
+
try:
|
|
107
|
+
with open(sidecar_path, "r") as f:
|
|
108
|
+
data = json.load(f)
|
|
109
|
+
except (json.JSONDecodeError, FileNotFoundError, OSError):
|
|
110
|
+
return None
|
|
111
|
+
|
|
112
|
+
if data.get("status") == "ok":
|
|
113
|
+
return ModelInfo(
|
|
114
|
+
param_count=data["param_count"],
|
|
115
|
+
dtype=data.get("dtype", "float32"),
|
|
116
|
+
model_name=data.get("model_name"),
|
|
117
|
+
method="execution",
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
return None
|
|
121
|
+
|
|
122
|
+
except subprocess.TimeoutExpired:
|
|
123
|
+
return None
|
|
124
|
+
except Exception:
|
|
125
|
+
return None
|
|
126
|
+
finally:
|
|
127
|
+
# Clean up sidecar temp file
|
|
128
|
+
if sidecar_path and os.path.exists(sidecar_path):
|
|
129
|
+
try:
|
|
130
|
+
os.unlink(sidecar_path)
|
|
131
|
+
except OSError:
|
|
132
|
+
pass
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# ---------------------------------------------------------------------------
|
|
136
|
+
# AST-based extraction (static fallback)
|
|
137
|
+
# ---------------------------------------------------------------------------
|
|
138
|
+
|
|
139
|
+
from alloc.model_registry import KNOWN_MODELS
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _extract_via_ast(script: str) -> Optional[ModelInfo]:
|
|
143
|
+
"""Parse script AST, find from_pretrained() calls, match against known models."""
|
|
144
|
+
try:
|
|
145
|
+
with open(script, "r") as f:
|
|
146
|
+
source = f.read()
|
|
147
|
+
tree = ast.parse(source)
|
|
148
|
+
except (SyntaxError, OSError):
|
|
149
|
+
return None
|
|
150
|
+
|
|
151
|
+
# Walk AST looking for .from_pretrained("model_name") calls
|
|
152
|
+
for node in ast.walk(tree):
|
|
153
|
+
if not isinstance(node, ast.Call):
|
|
154
|
+
continue
|
|
155
|
+
|
|
156
|
+
# Match *.from_pretrained(...)
|
|
157
|
+
func = node.func
|
|
158
|
+
if isinstance(func, ast.Attribute) and func.attr == "from_pretrained":
|
|
159
|
+
if node.args and isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str):
|
|
160
|
+
model_id = node.args[0].value
|
|
161
|
+
param_count = KNOWN_MODELS.get(model_id)
|
|
162
|
+
if param_count is not None:
|
|
163
|
+
return ModelInfo(
|
|
164
|
+
param_count=param_count,
|
|
165
|
+
dtype="float16",
|
|
166
|
+
model_name=model_id,
|
|
167
|
+
method="ast",
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
return None
|
alloc/model_registry.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""Canonical model registry — single source of truth for known model param counts.
|
|
2
|
+
|
|
3
|
+
Used by:
|
|
4
|
+
- model_extractor.py (AST fallback: HuggingFace model IDs → raw param count)
|
|
5
|
+
- cli.py (alloc scan: short model names → param count in billions)
|
|
6
|
+
|
|
7
|
+
Add new models here. Do NOT maintain separate lookup tables elsewhere.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from typing import Dict, Optional
|
|
13
|
+
|
|
14
|
+
# HuggingFace model IDs → raw parameter count (int).
# Used by AST extractor to match from_pretrained("model_id") calls.
# NOTE: some entries record the actual count rather than the marketing size
# in the name (e.g. "google/gemma-7b" → 8.54e9) — keep real counts when
# adding models.
KNOWN_MODELS: Dict[str, int] = {
    # GPT-2
    "gpt2": int(124e6),
    "gpt2-medium": int(355e6),
    "gpt2-large": int(774e6),
    "gpt2-xl": int(1.5e9),
    # BERT
    "bert-base-uncased": int(110e6),
    "bert-large-uncased": int(340e6),
    # Llama
    "meta-llama/Llama-2-7b-hf": int(7e9),
    "meta-llama/Llama-2-13b-hf": int(13e9),
    "meta-llama/Llama-2-70b-hf": int(70e9),
    # Mistral / Mixtral
    "mistralai/Mistral-7B-v0.1": int(7.24e9),
    "mistralai/Mixtral-8x7B-v0.1": int(46.7e9),
    # Falcon
    "tiiuae/falcon-7b": int(7e9),
    "tiiuae/falcon-40b": int(40e9),
    # Bloom
    "bigscience/bloom-560m": int(560e6),
    "bigscience/bloom-1b7": int(1.7e9),
    "bigscience/bloom-7b1": int(7.1e9),
    # GPT-Neo / GPT-J
    "EleutherAI/gpt-neo-125M": int(125e6),
    "EleutherAI/gpt-neo-1.3B": int(1.3e9),
    "EleutherAI/gpt-neo-2.7B": int(2.7e9),
    "EleutherAI/gpt-j-6B": int(6e9),
    # Phi
    "microsoft/phi-2": int(2.78e9),
    # Gemma
    "google/gemma-2b": int(2.51e9),
    "google/gemma-7b": int(8.54e9),
    # Qwen
    "Qwen/Qwen-7B": int(7.72e9),
    "Qwen/Qwen-14B": int(14.2e9),
    "Qwen/Qwen-72B": int(72.7e9),
    # DeepSeek
    "deepseek-ai/deepseek-llm-7b-base": int(6.9e9),
    "deepseek-ai/deepseek-llm-67b-base": int(67e9),
    # T5
    "t5-small": int(60e6),
    "t5-base": int(220e6),
    "t5-large": int(770e6),
    "google/t5-xl-lm-adapt": int(3e9),
    "google/t5-xxl-lm-adapt": int(11e9),
    # Vision
    "google/vit-base-patch16-224": int(86e6),
    "google/vit-large-patch16-224": int(307e6),
    # Whisper
    "openai/whisper-small": int(244e6),
    "openai/whisper-medium": int(769e6),
    "openai/whisper-large-v3": int(1.55e9),
}
|
|
70
|
+
|
|
71
|
+
# Short CLI names → param count in billions (float).
# Used by `alloc scan --model <name>` for quick lookups without a script.
# Keys are normalized lowercase names; see lookup_model_params().
_CLI_MODEL_PARAMS: Dict[str, float] = {
    # Llama
    "llama-3-70b": 70.0,
    "llama-3-8b": 8.03,
    "llama-2-70b": 70.0,
    "llama-2-13b": 13.0,
    "llama-2-7b": 7.0,
    # Mistral / Mixtral
    "mistral-7b": 7.24,
    "mixtral-8x7b": 46.7,
    # GPT-2
    "gpt2": 0.124,
    "gpt2-medium": 0.355,
    "gpt2-large": 0.774,
    "gpt2-xl": 1.5,
    # BERT
    "bert-base": 0.110,
    "bert-large": 0.340,
    # T5
    "t5-small": 0.060,
    "t5-base": 0.220,
    "t5-large": 0.770,
    "t5-xl": 3.0,
    "t5-xxl": 11.0,
    # Falcon
    "falcon-7b": 7.0,
    "falcon-40b": 40.0,
    # Phi
    "phi-2": 2.78,
    # Gemma
    "gemma-2b": 2.51,
    "gemma-7b": 8.54,
    # Qwen
    "qwen-7b": 7.72,
    "qwen-14b": 14.2,
    "qwen-72b": 72.7,
    # DeepSeek
    "deepseek-7b": 6.9,
    "deepseek-67b": 67.0,
    # Vision
    "vit-base": 0.086,
    "vit-large": 0.307,
    # Whisper
    "whisper-small": 0.244,
    "whisper-medium": 0.769,
    "whisper-large": 1.55,
    # Bloom
    "bloom-560m": 0.560,
    "bloom-1b7": 1.7,
    "bloom-7b1": 7.1,
    # GPT-Neo / GPT-J
    "gpt-neo-125m": 0.125,
    "gpt-neo-1.3b": 1.3,
    "gpt-neo-2.7b": 2.7,
    "gpt-j-6b": 6.0,
}


def lookup_model_params(model: str) -> Optional[float]:
    """Look up a model's parameter count (in billions) by short CLI name.

    Matching is case-insensitive and ignores surrounding whitespace.
    Returns None when the model is not in the table.
    """
    return _CLI_MODEL_PARAMS.get(model.strip().lower())
|