alloc-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
alloc/extractor_runner.py ADDED
@@ -0,0 +1,141 @@
+ """Standalone subprocess runner for model parameter extraction.
+
+ This module is invoked by model_extractor.py as:
+     subprocess.run([python, "-m", "alloc.extractor_runner", sidecar_path, script_path])
+
+ It imports the user's training script in an isolated process, finds
+ nn.Module instances/classes, counts parameters, and writes results
+ to a JSON sidecar file. Runs with CUDA_VISIBLE_DEVICES="" to prevent
+ GPU allocation.
+
+ Never imported directly by the rest of the CLI; only executed as __main__.
+ """
+
+ from __future__ import annotations
+
+ import importlib.util
+ import json
+ import os
+ import sys
+
+
+ def _patch_sys_exit():
+     """Prevent sys.exit() from killing the extractor."""
+     class _ExitCatch(SystemExit):
+         pass
+
+     def _fake_exit(code=0):
+         raise _ExitCatch(code)
+
+     sys.exit = _fake_exit
+     return _ExitCatch
+
+
+ def _count_params(model):
+     """Count parameters in an nn.Module."""
+     total = 0
+     dtype_str = "float32"
+     try:
+         for i, (_name, param) in enumerate(model.named_parameters()):
+             total += param.numel()
+             if i == 0:
+                 dtype_str = str(param.dtype).replace("torch.", "")
+     except Exception:
+         pass
+     return total, dtype_str
+
+
+ def main():
+     sidecar_path = sys.argv[1]
+     script_path = sys.argv[2]
+
+     result = {"status": "no_model"}
+
+     try:
+         import torch
+         import torch.nn as nn
+     except ImportError:
+         result = {"status": "error", "error": "torch not installed"}
+         with open(sidecar_path, "w") as f:
+             json.dump(result, f)
+         return
+
+     ExitCatch = _patch_sys_exit()
+
+     # Add the script's directory to sys.path for relative imports
+     script_dir = os.path.dirname(os.path.abspath(script_path))
+     if script_dir not in sys.path:
+         sys.path.insert(0, script_dir)
+
+     # Import the user module
+     try:
+         spec = importlib.util.spec_from_file_location("__user_module__", script_path)
+         if spec is None or spec.loader is None:
+             result = {"status": "error", "error": "cannot load module"}
+             with open(sidecar_path, "w") as f:
+                 json.dump(result, f)
+             return
+
+         module = importlib.util.module_from_spec(spec)
+         module.__name__ = "__user_module__"  # so `if __name__ == "__main__"` blocks don't run
+         try:
+             spec.loader.exec_module(module)
+         except ExitCatch:
+             pass  # the script called sys.exit(); that's fine
+         except SystemExit:
+             pass  # catch a bare SystemExit too
+     except Exception as e:
+         result = {"status": "error", "error": str(e)[:200]}
+         with open(sidecar_path, "w") as f:
+             json.dump(result, f)
+         return
+
+     # Search for nn.Module instances in module globals
+     models = []
+     for attr_name in dir(module):
+         try:
+             obj = getattr(module, attr_name)
+             if isinstance(obj, nn.Module):
+                 count, dtype_str = _count_params(obj)
+                 if count > 0:
+                     models.append((count, dtype_str, attr_name))
+         except Exception:
+             continue
+
+     # If no instances were found, search for nn.Module subclasses and try instantiation
+     if not models:
+         for attr_name in dir(module):
+             try:
+                 obj = getattr(module, attr_name)
+                 if (isinstance(obj, type)
+                         and issubclass(obj, nn.Module)
+                         and obj is not nn.Module):
+                     try:
+                         instance = obj()
+                         count, dtype_str = _count_params(instance)
+                         if count > 0:
+                             models.append((count, dtype_str, attr_name))
+                     except Exception:
+                         continue
+             except Exception:
+                 continue
+
+     if models:
+         # Pick the largest model (most params = likely the training target)
+         models.sort(key=lambda m: m[0], reverse=True)
+         best = models[0]
+         result = {
+             "status": "ok",
+             "param_count": best[0],
+             "dtype": best[1],
+             "model_name": best[2],
+         }
+     else:
+         result = {"status": "no_model"}
+
+     with open(sidecar_path, "w") as f:
+         json.dump(result, f)
+
+
+ if __name__ == "__main__":
+     main()
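
For reference, a minimal sketch of the sidecar handshake described in the module docstring; the script path and the printed result are illustrative, not actual output:

    import json
    import os
    import subprocess
    import sys
    import tempfile

    fd, sidecar = tempfile.mkstemp(suffix=".json")  # runner writes its result here
    os.close(fd)
    subprocess.run(
        [sys.executable, "-m", "alloc.extractor_runner", sidecar, "train.py"],
        capture_output=True,
        timeout=60,
    )
    with open(sidecar) as f:
        print(json.load(f))
    # e.g. {"status": "ok", "param_count": 124439808, "dtype": "float32", "model_name": "model"}
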
alloc/ghost.py ADDED
@@ -0,0 +1,167 @@
+ """Ghost Scan: static VRAM analysis without executing the model.
+
+ Two paths:
+ 1. With torch: walk model.named_parameters() for exact sizing
+ 2. Without torch: pure math from param_count (same formula as the API engine)
+
+ Never crashes user code. All exceptions are caught and logged.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, asdict
+ from typing import Any, Optional
+
+ # Canonical dtype → bytes mapping. Keep in sync with apps/api/src/engine/vram.py.
+ BYTES_PER_DTYPE = {
+     "float32": 4,
+     "float16": 2,
+     "bfloat16": 2,
+     "int8": 1,
+     "int4": 0.5,
+     # Shorthand aliases (torch uses long names, the API uses short names)
+     "fp32": 4,
+     "fp16": 2,
+     "bf16": 2,
+ }
+
+ # Adam: params (fp32) + momentum (fp32) + variance (fp32) = 3 states * 4 bytes.
+ # Keep in sync with OPTIMIZER_MULTIPLIER_FP16 in apps/api/src/engine/vram.py.
+ OPTIMIZER_BYTES_PER_PARAM = 12
+
+ # Buffer overhead: 10% of (weights + gradients + activations + optimizer)
+ # for fragmentation and temp allocations.
+ BUFFER_OVERHEAD_FACTOR = 0.1
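+
+ # Worked example (illustrative): 7e9 params in fp16 with the defaults below.
+ #   weights     = 7e9 * 2 B            ≈ 13.04 GiB
+ #   gradients   = 7e9 * 2 B            ≈ 13.04 GiB
+ #   optimizer   = 7e9 * 12 B           ≈ 78.23 GiB
+ #   activations = 32*2048*4096 * 2 B   = 0.50 GiB
+ #   buffer      = 10% of the above     ≈ 10.48 GiB, so total ≈ 115.29 GiB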
+
+
+ @dataclass
+ class GhostReport:
+     """Result of a Ghost Scan: the VRAM breakdown."""
+
+     param_count: int
+     param_count_b: float
+     dtype: str
+     weights_gb: float
+     gradients_gb: float
+     optimizer_gb: float
+     activations_gb: float
+     buffer_gb: float
+     total_gb: float
+     extraction_method: Optional[str] = None  # "execution", "ast", or "manual"
+
+     def to_dict(self) -> dict:
+         return asdict(self)
+
+
+ def ghost(
+     model: Any = None,
+     *,
+     param_count: Optional[int] = None,
+     param_count_b: Optional[float] = None,
+     dtype: str = "float16",
+     batch_size: int = 32,
+     seq_length: int = 2048,
+     hidden_dim: int = 4096,
+ ) -> GhostReport:
+     """Run a Ghost Scan on a model or a param count.
+
+     Usage:
+         # With a PyTorch model
+         report = alloc.ghost(model)
+
+         # Without torch: just a param count
+         report = alloc.ghost(param_count_b=7.0)
+
+     Never raises. If something goes wrong, returns a best-effort report.
+     """
+     try:
+         return _ghost_impl(
+             model=model,
+             param_count=param_count,
+             param_count_b=param_count_b,
+             dtype=dtype,
+             batch_size=batch_size,
+             seq_length=seq_length,
+             hidden_dim=hidden_dim,
+         )
+     except Exception:
+         # Never crash user code
+         count = param_count or int((param_count_b or 0) * 1e9)
+         return GhostReport(
+             param_count=count,
+             param_count_b=count / 1e9,
+             dtype=dtype,
+             weights_gb=0.0,
+             gradients_gb=0.0,
+             optimizer_gb=0.0,
+             activations_gb=0.0,
+             buffer_gb=0.0,
+             total_gb=0.0,
+         )
+
+
+ def _ghost_impl(
+     model: Any,
+     param_count: Optional[int],
+     param_count_b: Optional[float],
+     dtype: str,
+     batch_size: int,
+     seq_length: int,
+     hidden_dim: int,
+ ) -> GhostReport:
+     """Core Ghost implementation."""
+     resolved_count = 0
+     resolved_dtype = dtype
+
+     if model is not None:
+         # Path 1: walk torch model parameters
+         resolved_count, resolved_dtype = _count_from_model(model)
+     elif param_count is not None:
+         resolved_count = param_count
+     elif param_count_b is not None:
+         resolved_count = int(param_count_b * 1e9)
+
+     resolved_count = max(resolved_count, 0)
+
+     bytes_per_param = BYTES_PER_DTYPE.get(resolved_dtype, 2)
+     to_gb = 1.0 / (1024 ** 3)
+
+     weights_bytes = resolved_count * bytes_per_param
+     gradients_bytes = resolved_count * bytes_per_param
+     optimizer_bytes = resolved_count * OPTIMIZER_BYTES_PER_PARAM
+     activations_bytes = batch_size * seq_length * hidden_dim * bytes_per_param
+
+     weights_gb = weights_bytes * to_gb
+     gradients_gb = gradients_bytes * to_gb
+     optimizer_gb = optimizer_bytes * to_gb
+     activations_gb = activations_bytes * to_gb
+     buffer_gb = BUFFER_OVERHEAD_FACTOR * (weights_gb + gradients_gb + optimizer_gb + activations_gb)
+     total_gb = weights_gb + gradients_gb + optimizer_gb + activations_gb + buffer_gb
+
+     return GhostReport(
+         param_count=resolved_count,
+         param_count_b=round(resolved_count / 1e9, 3),
+         dtype=resolved_dtype,
+         weights_gb=round(weights_gb, 2),
+         gradients_gb=round(gradients_gb, 2),
+         optimizer_gb=round(optimizer_gb, 2),
+         activations_gb=round(activations_gb, 2),
+         buffer_gb=round(buffer_gb, 2),
+         total_gb=round(total_gb, 2),
+     )
+
+
+ def _count_from_model(model: Any) -> tuple[int, str]:
+     """Extract param count and dtype from a torch model. Returns (count, dtype_str)."""
+     total = 0
+     dtype_str = "float16"
+     try:
+         for i, (_name, param) in enumerate(model.named_parameters()):
+             total += param.numel()
+             if i == 0:
+                 # Use the dtype of the first parameter
+                 dtype_str = str(param.dtype).replace("torch.", "")
+     except Exception:
+         pass
+     return total, dtype_str
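
A minimal usage sketch, assuming the wheel is installed so the module is importable (the docstring also suggests a top-level `alloc.ghost` re-export); the expected numbers follow directly from the constants above:

    from alloc.ghost import ghost

    report = ghost(param_count_b=7.0, dtype="float16")  # 7B params, no torch needed
    print(report.total_gb)   # ≈ 115.29: 13.04 weights + 13.04 grads + 78.23 Adam + 0.5 activations + 10.48 buffer
    print(report.to_dict())  # JSON-friendly breakdown
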
alloc/model_extractor.py ADDED
@@ -0,0 +1,170 @@
+ """Model extraction: get a param count from user scripts.
+
+ Three extraction methods with a fallback chain:
+ 1. Manual override (--param-count-b flag) → instant, no file needed
+ 2. Subprocess execution → import the module, find nn.Module, count params
+ 3. AST parsing → find from_pretrained() calls, match against known models
+
+ Never crashes. Returns None when extraction fails.
+ """
+
+ from __future__ import annotations
+
+ import ast
+ import json
+ import os
+ import subprocess
+ import sys
+ import tempfile
+ from dataclasses import dataclass
+ from typing import Optional
+
+ from alloc.model_registry import KNOWN_MODELS
+
+
+ @dataclass
+ class ModelInfo:
+     """Result of model extraction."""
+
+     param_count: int  # raw parameter count
+     dtype: str  # "float32", "float16", etc.
+     model_name: Optional[str]  # class name if found
+     method: str  # "execution" | "ast" | "manual"
+
+
+ def extract_model_info(
+     script: str,
+     *,
+     timeout: int = 60,
+     param_count_b: Optional[float] = None,
+ ) -> Optional[ModelInfo]:
+     """Extract model info from a script with a fallback chain.
+
+     Returns ModelInfo, or None if extraction fails.
+     """
+     # 1. Manual override: no file needed
+     if param_count_b is not None:
+         return ModelInfo(
+             param_count=int(param_count_b * 1e9),
+             dtype="float16",
+             model_name=None,
+             method="manual",
+         )
+
+     # 2. The file must exist for the remaining methods
+     if not os.path.isfile(script):
+         return None
+
+     # 3. Try subprocess execution
+     result = _extract_via_subprocess(script, timeout=timeout)
+     if result is not None:
+         return result
+
+     # 4. Try AST parsing
+     result = _extract_via_ast(script)
+     if result is not None:
+         return result
+
+     return None
+
+
+ # ---------------------------------------------------------------------------
+ # Subprocess execution
+ # ---------------------------------------------------------------------------
+
+ def _extract_via_subprocess(
+     script: str,
+     *,
+     timeout: int = 60,
+ ) -> Optional[ModelInfo]:
+     """Execute the user script in an isolated subprocess and extract model info.
+
+     Invokes alloc.extractor_runner as a module in a subprocess with
+     CUDA_VISIBLE_DEVICES="" to prevent GPU allocation. Results are
+     communicated via a JSON sidecar file.
+     """
+     sidecar_fd = None
+     sidecar_path = None
+
+     try:
+         # Create the sidecar file for IPC
+         sidecar_fd, sidecar_path = tempfile.mkstemp(suffix=".json", prefix="alloc_extract_")
+         os.close(sidecar_fd)
+         sidecar_fd = None
+
+         script_abs = os.path.abspath(script)
+
+         env = os.environ.copy()
+         env["CUDA_VISIBLE_DEVICES"] = ""  # prevent GPU allocation
+
+         subprocess.run(
+             [sys.executable, "-m", "alloc.extractor_runner", sidecar_path, script_abs],
+             timeout=timeout,
+             capture_output=True,
+             env=env,
+         )
+
+         # Read the sidecar result
+         try:
+             with open(sidecar_path, "r") as f:
+                 data = json.load(f)
+         except (json.JSONDecodeError, OSError):
+             return None
+
+         if data.get("status") == "ok":
+             return ModelInfo(
+                 param_count=data["param_count"],
+                 dtype=data.get("dtype", "float32"),
+                 model_name=data.get("model_name"),
+                 method="execution",
+             )
+
+         return None
+
+     except subprocess.TimeoutExpired:
+         return None
+     except Exception:
+         return None
+     finally:
+         # Clean up the sidecar temp file
+         if sidecar_path and os.path.exists(sidecar_path):
+             try:
+                 os.unlink(sidecar_path)
+             except OSError:
+                 pass
+
+
+ # ---------------------------------------------------------------------------
+ # AST-based extraction (static fallback)
+ # ---------------------------------------------------------------------------
+
+ def _extract_via_ast(script: str) -> Optional[ModelInfo]:
+     """Parse the script's AST, find from_pretrained() calls, match against known models."""
+     try:
+         with open(script, "r") as f:
+             source = f.read()
+         tree = ast.parse(source)
+     except (SyntaxError, ValueError, OSError):
+         return None
+
+     # Walk the AST looking for .from_pretrained("model_name") calls
+     for node in ast.walk(tree):
+         if not isinstance(node, ast.Call):
+             continue
+
+         # Match *.from_pretrained(...)
+         func = node.func
+         if isinstance(func, ast.Attribute) and func.attr == "from_pretrained":
+             if node.args and isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str):
+                 model_id = node.args[0].value
+                 param_count = KNOWN_MODELS.get(model_id)
+                 if param_count is not None:
+                     return ModelInfo(
+                         param_count=param_count,
+                         dtype="float16",
+                         model_name=model_id,
+                         method="ast",
+                     )
+
+     return None
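
A sketch of the fallback chain from the caller's side; file names are hypothetical, and which method wins depends on the environment:

    from alloc.model_extractor import extract_model_info

    # 1. Manual override: returns immediately, no file needed
    info = extract_model_info("missing.py", param_count_b=7.0)
    assert info is not None and info.method == "manual" and info.param_count == 7_000_000_000

    # 2./3. With a real script, subprocess execution is tried first; if torch
    # is unavailable in that subprocess, the AST fallback can still match a
    # known checkpoint name such as "gpt2" from the registry below.
    info = extract_model_info("train.py")
    print(info.method if info else "extraction failed")
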
alloc/model_registry.py ADDED
@@ -0,0 +1,138 @@
+ """Canonical model registry: the single source of truth for known model param counts.
+
+ Used by:
+ - model_extractor.py (AST fallback: HuggingFace model IDs → raw param count)
+ - cli.py (alloc scan: short model names → param count in billions)
+
+ Add new models here. Do NOT maintain separate lookup tables elsewhere.
+ """
+
+ from __future__ import annotations
+
+ from typing import Dict, Optional
+
+ # HuggingFace model IDs → raw parameter count (int).
+ # Used by the AST extractor to match from_pretrained("model_id") calls.
+ KNOWN_MODELS: Dict[str, int] = {
+     # GPT-2
+     "gpt2": int(124e6),
+     "gpt2-medium": int(355e6),
+     "gpt2-large": int(774e6),
+     "gpt2-xl": int(1.5e9),
+     # BERT
+     "bert-base-uncased": int(110e6),
+     "bert-large-uncased": int(340e6),
+     # Llama
+     "meta-llama/Llama-2-7b-hf": int(7e9),
+     "meta-llama/Llama-2-13b-hf": int(13e9),
+     "meta-llama/Llama-2-70b-hf": int(70e9),
+     # Mistral / Mixtral
+     "mistralai/Mistral-7B-v0.1": int(7.24e9),
+     "mistralai/Mixtral-8x7B-v0.1": int(46.7e9),
+     # Falcon
+     "tiiuae/falcon-7b": int(7e9),
+     "tiiuae/falcon-40b": int(40e9),
+     # Bloom
+     "bigscience/bloom-560m": int(560e6),
+     "bigscience/bloom-1b7": int(1.7e9),
+     "bigscience/bloom-7b1": int(7.1e9),
+     # GPT-Neo / GPT-J
+     "EleutherAI/gpt-neo-125M": int(125e6),
+     "EleutherAI/gpt-neo-1.3B": int(1.3e9),
+     "EleutherAI/gpt-neo-2.7B": int(2.7e9),
+     "EleutherAI/gpt-j-6B": int(6e9),
+     # Phi
+     "microsoft/phi-2": int(2.78e9),
+     # Gemma
+     "google/gemma-2b": int(2.51e9),
+     "google/gemma-7b": int(8.54e9),
+     # Qwen
+     "Qwen/Qwen-7B": int(7.72e9),
+     "Qwen/Qwen-14B": int(14.2e9),
+     "Qwen/Qwen-72B": int(72.7e9),
+     # DeepSeek
+     "deepseek-ai/deepseek-llm-7b-base": int(6.9e9),
+     "deepseek-ai/deepseek-llm-67b-base": int(67e9),
+     # T5
+     "t5-small": int(60e6),
+     "t5-base": int(220e6),
+     "t5-large": int(770e6),
+     "google/t5-xl-lm-adapt": int(3e9),
+     "google/t5-xxl-lm-adapt": int(11e9),
+     # Vision
+     "google/vit-base-patch16-224": int(86e6),
+     "google/vit-large-patch16-224": int(307e6),
+     # Whisper
+     "openai/whisper-small": int(244e6),
+     "openai/whisper-medium": int(769e6),
+     "openai/whisper-large-v3": int(1.55e9),
+ }
+
+ # Short CLI names → param count in billions.
+ # Used by `alloc scan --model <name>` for quick lookups without a script.
+ # Maps normalized lowercase names to billions (float).
+ _CLI_MODEL_PARAMS: Dict[str, float] = {
+     # Llama
+     "llama-3-70b": 70.0,
+     "llama-3-8b": 8.03,
+     "llama-2-70b": 70.0,
+     "llama-2-13b": 13.0,
+     "llama-2-7b": 7.0,
+     # Mistral / Mixtral
+     "mistral-7b": 7.24,
+     "mixtral-8x7b": 46.7,
+     # GPT-2
+     "gpt2": 0.124,
+     "gpt2-medium": 0.355,
+     "gpt2-large": 0.774,
+     "gpt2-xl": 1.5,
+     # BERT
+     "bert-base": 0.110,
+     "bert-large": 0.340,
+     # T5
+     "t5-small": 0.060,
+     "t5-base": 0.220,
+     "t5-large": 0.770,
+     "t5-xl": 3.0,
+     "t5-xxl": 11.0,
+     # Falcon
+     "falcon-7b": 7.0,
+     "falcon-40b": 40.0,
+     # Phi
+     "phi-2": 2.78,
+     # Gemma
+     "gemma-2b": 2.51,
+     "gemma-7b": 8.54,
+     # Qwen
+     "qwen-7b": 7.72,
+     "qwen-14b": 14.2,
+     "qwen-72b": 72.7,
+     # DeepSeek
+     "deepseek-7b": 6.9,
+     "deepseek-67b": 67.0,
+     # Vision
+     "vit-base": 0.086,
+     "vit-large": 0.307,
+     # Whisper
+     "whisper-small": 0.244,
+     "whisper-medium": 0.769,
+     "whisper-large": 1.55,
+     # Bloom
+     "bloom-560m": 0.560,
+     "bloom-1b7": 1.7,
+     "bloom-7b1": 7.1,
+     # GPT-Neo / GPT-J
+     "gpt-neo-125m": 0.125,
+     "gpt-neo-1.3b": 1.3,
+     "gpt-neo-2.7b": 2.7,
+     "gpt-j-6b": 6.0,
+ }
+
+
+ def lookup_model_params(model: str) -> Optional[float]:
+     """Look up a model's param count (in billions) by short CLI name.
+
+     Returns None if the model is not found.
+     """
+     normalized = model.lower().strip()
+     return _CLI_MODEL_PARAMS.get(normalized)
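
And the registry's lookup helper in use, assuming the wheel is installed; it normalizes case and surrounding whitespace before the lookup:

    from alloc.model_registry import lookup_model_params

    print(lookup_model_params("Mistral-7B"))    # 7.24
    print(lookup_model_params(" llama-3-8b "))  # 8.03
    print(lookup_model_params("nonexistent"))   # None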