memscale 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memscale/__init__.py +29 -0
- memscale/api.py +167 -0
- memscale/core/__init__.py +28 -0
- memscale/core/config.py +112 -0
- memscale/core/decision_engine.py +277 -0
- memscale/core/executor.py +140 -0
- memscale/core/memory_graph.py +100 -0
- memscale/core/profiler.py +271 -0
- memscale/integrations/__init__.py +5 -0
- memscale/integrations/huggingface.py +66 -0
- memscale/observability/__init__.py +5 -0
- memscale/observability/logger.py +89 -0
- memscale/techniques/__init__.py +11 -0
- memscale/techniques/checkpointing.py +47 -0
- memscale/techniques/offloading.py +125 -0
- memscale/techniques/tiling.py +77 -0
- memscale-0.1.0.dist-info/METADATA +246 -0
- memscale-0.1.0.dist-info/RECORD +21 -0
- memscale-0.1.0.dist-info/WHEEL +5 -0
- memscale-0.1.0.dist-info/licenses/LICENSE +201 -0
- memscale-0.1.0.dist-info/top_level.txt +1 -0
memscale/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MemScale — Drop-in memory optimizer for PyTorch training.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
import memscale
|
|
6
|
+
|
|
7
|
+
trainer = memscale.wrap(your_trainer)
|
|
8
|
+
# That's it. VRAM optimization is now active.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from memscale.core.profiler import MemoryProfiler
|
|
12
|
+
from memscale.core.decision_engine import DecisionEngine
|
|
13
|
+
from memscale.core.executor import Executor
|
|
14
|
+
from memscale.core.config import Config, OptimizationMode
|
|
15
|
+
from memscale.api import wrap, optimize, detach
|
|
16
|
+
|
|
17
|
+
__version__ = "0.1.0"
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
"wrap",
|
|
21
|
+
"optimize",
|
|
22
|
+
"detach",
|
|
23
|
+
"Config",
|
|
24
|
+
"OptimizationMode",
|
|
25
|
+
"MemoryProfiler",
|
|
26
|
+
"DecisionEngine",
|
|
27
|
+
"Executor",
|
|
28
|
+
"__version__",
|
|
29
|
+
]
|
memscale/api.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
"""Public API for MemScale.
|
|
2
|
+
|
|
3
|
+
This is the user-facing interface. Most users only need:
|
|
4
|
+
|
|
5
|
+
import memscale
|
|
6
|
+
trainer = memscale.wrap(your_trainer)
|
|
7
|
+
|
|
8
|
+
Or for custom training loops:
|
|
9
|
+
|
|
10
|
+
with memscale.optimize(model, optimizer):
|
|
11
|
+
for batch in dataloader:
|
|
12
|
+
...
|
|
13
|
+
"""
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from contextlib import contextmanager
|
|
18
|
+
from typing import Any, Optional
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
import torch.nn as nn
|
|
22
|
+
|
|
23
|
+
from memscale.core.config import Config
|
|
24
|
+
from memscale.core.decision_engine import DecisionEngine
|
|
25
|
+
from memscale.core.executor import Executor
|
|
26
|
+
from memscale.core.profiler import MemoryProfiler
|
|
27
|
+
from memscale.observability.logger import get_logger
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def wrap(
    target: Any,
    config: Optional[Config] = None,
    sample_input: Optional[torch.Tensor] = None,
) -> Any:
    """Wrap a trainer or model with MemScale optimizations.

    Dispatch rules, in order:

    1. ``target`` is an ``nn.Module``: it is profiled, planned and hooked
       via ``_wrap_module``.
    2. ``target`` is any other object exposing both ``model`` and ``train``
       attributes: it is treated as a HuggingFace-style trainer and handed
       to ``wrap_hf_trainer``.
    3. Anything else raises ``TypeError``.

    Args:
        target: A HuggingFace Trainer (duck-typed) or a PyTorch Module.
        config: Optional configuration. A default ``Config()`` is created
            when omitted.
        sample_input: Optional sample input tensor, forwarded to the
            profiler / trainer integration.

    Returns:
        For a module, the same module object with hooks attached. For a
        trainer, whatever ``wrap_hf_trainer`` returns.

    Raises:
        TypeError: If ``target`` matches none of the supported shapes.

    Example:
        >>> from transformers import Trainer
        >>> import memscale
        >>>
        >>> trainer = Trainer(model=model, args=args, ...)
        >>> trainer = memscale.wrap(trainer)
        >>> trainer.train()  # VRAM optimized automatically
    """
    if config is None:
        config = Config()

    # Initialize observability
    obs_logger = get_logger(config)
    obs_logger.info("MemScale initialized", version="0.1.0", mode=config.mode.value)

    # The module check must come FIRST. Every nn.Module has a ``train()``
    # method, so a module that merely owns a ``model`` attribute (common for
    # wrapper / *ForCausalLM-style modules) would otherwise satisfy the
    # duck-typed trainer test below and be sent down the wrong path.
    if isinstance(target, nn.Module):
        return _wrap_module(target, config, sample_input)

    if hasattr(target, "model") and hasattr(target, "train"):
        # HuggingFace Trainer (imported lazily so the integration is only
        # loaded when it is actually needed).
        from memscale.integrations.huggingface import wrap_hf_trainer

        return wrap_hf_trainer(target, config, sample_input)

    # NOTE(review): the message lists "Lightning Trainer" but no dedicated
    # Lightning branch exists in this function — confirm intended support.
    raise TypeError(
        f"Cannot wrap object of type {type(target).__name__}. "
        f"Supported: HuggingFace Trainer, PyTorch Module, Lightning Trainer."
    )
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _wrap_module(
    model: nn.Module,
    config: Config,
    sample_input: Optional[torch.Tensor],
) -> nn.Module:
    """Profile, plan and hook a raw ``nn.Module``, returning the same object.

    The pipeline is profile -> decide -> execute. The executor is stored on
    the module as ``_memscale_executor`` so it can be detached later.
    """
    # Stage 1: profiling (hardware first, then the model itself).
    mem_profiler = MemoryProfiler(
        prefer_static=config.use_static_profiling,
        fallback_empirical=config.use_empirical_fallback,
    )
    hardware = mem_profiler.detect_hardware()
    logger.info(f"Detected hardware: {hardware}")

    memory_graph = mem_profiler.profile(model, sample_input)
    logger.info(memory_graph.summary())

    # Stage 2: turn the profile into a per-layer plan.
    optimization_plan = DecisionEngine(config).decide(memory_graph, hardware)
    logger.info(optimization_plan.summary())

    # Stage 3: install the plan on the model.
    active_executor = Executor(model, optimization_plan, config)
    active_executor.attach()

    # Keep a handle on the module so the hooks can be removed afterwards.
    model._memscale_executor = active_executor
    return model
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@contextmanager
def optimize(
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    config: Optional[Config] = None,
    sample_input: Optional[torch.Tensor] = None,
):
    """Context manager that keeps MemScale active for the enclosed block.

    Intended for hand-written training loops where no Trainer object exists.
    On entry the model is profiled, a plan is computed and an executor is
    attached; the executor is yielded to the ``with`` body and detached on
    exit, even if the body raises.

    Note: ``optimizer`` is accepted but not used by this function.

    Example:
        >>> import memscale
        >>>
        >>> with memscale.optimize(model, optimizer) as ms:
        ...     for batch in dataloader:
        ...         loss = model(batch).loss
        ...         loss.backward()
        ...         optimizer.step()
        ...         optimizer.zero_grad()
    """
    effective_config = Config() if config is None else config

    mem_profiler = MemoryProfiler(
        prefer_static=effective_config.use_static_profiling,
        fallback_empirical=effective_config.use_empirical_fallback,
    )

    hardware = mem_profiler.detect_hardware()
    memory_graph = mem_profiler.profile(model, sample_input)
    optimization_plan = DecisionEngine(effective_config).decide(memory_graph, hardware)

    logger.info(optimization_plan.summary())

    active_executor = Executor(model, optimization_plan, effective_config)
    active_executor.attach()

    try:
        yield active_executor
    finally:
        # Always undo the hooks, whether the body finished or raised.
        active_executor.detach()
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def detach(model: nn.Module) -> None:
    """Undo a previous ``memscale.wrap()`` on a module.

    If the module carries a ``_memscale_executor`` attribute, that executor
    is detached and the attribute removed. Otherwise only a warning is
    logged.

    Args:
        model: Module previously wrapped with memscale.wrap()
    """
    if not hasattr(model, "_memscale_executor"):
        logger.warning("Model was not wrapped with MemScale — nothing to detach")
        return

    model._memscale_executor.detach()
    del model._memscale_executor
    logger.info("MemScale detached from model")
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MemScale — Drop-in memory optimizer for PyTorch training.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
import memscale
|
|
6
|
+
|
|
7
|
+
trainer = memscale.wrap(your_trainer)
|
|
8
|
+
# That's it. VRAM optimization is now active.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from memscale.core.profiler import MemoryProfiler
|
|
12
|
+
from memscale.core.decision_engine import DecisionEngine
|
|
13
|
+
from memscale.core.executor import Executor
|
|
14
|
+
from memscale.core.config import Config, OptimizationMode
|
|
15
|
+
from memscale.api import wrap, optimize
|
|
16
|
+
|
|
17
|
+
__version__ = "0.1.0"
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
"wrap",
|
|
21
|
+
"optimize",
|
|
22
|
+
"Config",
|
|
23
|
+
"OptimizationMode",
|
|
24
|
+
"MemoryProfiler",
|
|
25
|
+
"DecisionEngine",
|
|
26
|
+
"Executor",
|
|
27
|
+
"__version__",
|
|
28
|
+
]
|
memscale/core/config.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""Configuration classes for MemScale.
|
|
2
|
+
|
|
3
|
+
This module defines the user-facing configuration API and internal
|
|
4
|
+
hardware budget representation.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class OptimizationMode(str, Enum):
    """How aggressively MemScale should trade throughput for memory.

    Three levels are available, from least to most intrusive:

    - CONSERVATIVE: optimize only under high memory pressure.
    - BALANCED: the default trade-off between speed and memory.
    - AGGRESSIVE: optimize as much as possible, accepting some overhead.

    Members are also plain strings, so they compare equal to their
    lowercase value.
    """

    CONSERVATIVE = "conservative"
    BALANCED = "balanced"
    AGGRESSIVE = "aggressive"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class Config:
    """User-facing configuration for MemScale.

    Most users won't need to set anything — every field has a default.

    Example:
        config = Config(
            mode=OptimizationMode.BALANCED,
            enable_offloading=True,
            max_cpu_offload_gb=64,
        )

    Raises:
        ValueError: If ``target_gpu_utilization`` is not a number in
            ``(0, 1]`` (NaN included) or ``warmup_steps`` is below 1.
    """

    # Optimization aggressiveness
    mode: OptimizationMode = OptimizationMode.BALANCED

    # Which techniques to enable
    enable_checkpointing: bool = True
    enable_offloading: bool = True
    enable_tiling: bool = False  # Disabled by default (more experimental)

    # Resource limits
    max_cpu_offload_gb: Optional[float] = None  # None = use available CPU RAM
    target_gpu_utilization: float = 0.85  # Try to use this fraction of GPU memory

    # Profiling
    use_static_profiling: bool = True  # Try torch.fx first
    use_empirical_fallback: bool = True  # Fallback to runtime profiling
    warmup_steps: int = 2  # For empirical profiling

    # Observability
    enable_logging: bool = True
    log_file: Optional[str] = None  # None = stdout only
    observability_port: Optional[int] = None  # If set, expose Prometheus metrics

    # Cost tracking
    gpu_cost_per_hour: float = 2.50  # Default H100 spot rate (USD)

    # Safety
    verify_correctness: bool = False  # If True, runs baseline comparison (slow!)

    def __post_init__(self) -> None:
        """Validate field values right after dataclass construction."""
        # Written as a positive range test on purpose: every comparison with
        # NaN is False, so the previous "<= 0 or > 1.0" form accepted NaN and
        # the failure only surfaced later when the budget was cast to int.
        if not (0 < self.target_gpu_utilization <= 1.0):
            raise ValueError(
                f"target_gpu_utilization must be in (0, 1], got {self.target_gpu_utilization}"
            )
        if self.warmup_steps < 1:
            raise ValueError(f"warmup_steps must be >= 1, got {self.warmup_steps}")
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass
class HardwareBudget:
    """Snapshot of the GPU/CPU memory resources found at runtime.

    Sizes are stored in bytes; the ``*_gb`` properties expose the same
    values divided by 1024**3. Intended to be produced by the profiler's
    hardware detection rather than built by hand.
    """

    gpu_total_bytes: int
    gpu_available_bytes: int
    cpu_total_bytes: int
    cpu_available_bytes: int
    num_gpus: int
    pcie_bandwidth_gbps: float = 16.0  # Reasonable default for PCIe 4.0 x16

    @staticmethod
    def _as_gb(num_bytes: int) -> float:
        # Single place for the bytes -> GB (1024**3) conversion.
        return num_bytes / (1024**3)

    @property
    def gpu_total_gb(self) -> float:
        return self._as_gb(self.gpu_total_bytes)

    @property
    def gpu_available_gb(self) -> float:
        return self._as_gb(self.gpu_available_bytes)

    @property
    def cpu_total_gb(self) -> float:
        return self._as_gb(self.cpu_total_bytes)

    @property
    def cpu_available_gb(self) -> float:
        return self._as_gb(self.cpu_available_bytes)

    def __str__(self) -> str:
        # One "available/total" segment per device class, plus the GPU count.
        segments = (
            f"GPU: {self.gpu_available_gb:.1f}/{self.gpu_total_gb:.1f} GB",
            f"CPU: {self.cpu_available_gb:.1f}/{self.cpu_total_gb:.1f} GB",
            f"GPUs: {self.num_gpus}",
        )
        return "HardwareBudget(" + ", ".join(segments) + ")"
|
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
"""Decision engine: decides which optimization technique to apply per layer.
|
|
2
|
+
|
|
3
|
+
Given a memory graph and hardware budget, produce an optimization plan.
|
|
4
|
+
|
|
5
|
+
Current implementation is rule-based heuristic. Future versions will use
|
|
6
|
+
learned policies trained on customer data.
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from enum import Enum
|
|
13
|
+
from typing import Dict, List
|
|
14
|
+
|
|
15
|
+
from memscale.core.config import Config, HardwareBudget, OptimizationMode
|
|
16
|
+
from memscale.core.memory_graph import LayerMemoryInfo, MemoryGraph
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Technique(str, Enum):
    """Per-layer optimization choices available to the decision engine.

    Members are also plain strings equal to their lowercase value.
    """

    # Leave the layer untouched on the GPU.
    KEEP = "keep"
    # Activation checkpointing: recompute during the backward pass.
    CHECKPOINT = "checkpoint"
    # Move data to CPU memory.
    OFFLOAD = "offload"
    # Activation tiling.
    TILE = "tile"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class OptimizationPlan:
    """Mapping of layer name to chosen technique, plus aggregate estimates.

    Built by ``DecisionEngine.decide()``; ``rationale`` collects short
    human-readable notes explaining individual choices.
    """

    layer_techniques: Dict[str, Technique] = field(default_factory=dict)
    estimated_memory_savings_bytes: int = 0
    estimated_throughput_overhead_pct: float = 0.0
    rationale: List[str] = field(default_factory=list)

    def technique_for(self, layer_name: str) -> Technique:
        """Return the technique assigned to ``layer_name`` (KEEP if unlisted)."""
        fallback = Technique.KEEP
        return self.layer_techniques.get(layer_name, fallback)

    def layers_with_technique(self, technique: Technique) -> List[str]:
        """Return, in insertion order, the layers assigned ``technique``."""
        matching: List[str] = []
        for layer_name, assigned in self.layer_techniques.items():
            if assigned == technique:
                matching.append(layer_name)
        return matching

    def summary(self) -> str:
        """Render a multi-line, human-readable description of the plan."""
        bytes_per_gb = 1024**3

        # Tally how many layers use each technique (zero-filled up front).
        tally = dict.fromkeys(Technique, 0)
        for assigned in self.layer_techniques.values():
            tally[assigned] += 1

        savings_gb = self.estimated_memory_savings_bytes / bytes_per_gb
        report = [
            "OptimizationPlan:",
            f" Estimated VRAM savings: {savings_gb:.2f} GB",
            f" Estimated throughput overhead: {self.estimated_throughput_overhead_pct:.1f}%",
            " Layers per technique:",
            f" KEEP: {tally[Technique.KEEP]}",
            f" CHECKPOINT: {tally[Technique.CHECKPOINT]}",
            f" OFFLOAD: {tally[Technique.OFFLOAD]}",
            f" TILE: {tally[Technique.TILE]}",
        ]

        if self.rationale:
            report.append(" Rationale:")
            # Only the first five notes are shown.
            report.extend(f" - {note}" for note in self.rationale[:5])

        return "\n".join(report)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class DecisionEngine:
    """Rule-based planner that picks an optimization technique per layer.

    Outline of ``decide``:
        1. Compute the total memory the model needs and a target budget.
        2. If the model fits comfortably and the mode is not AGGRESSIVE,
           keep every layer as-is.
        3. Otherwise walk the layers from largest to smallest and choose a
           technique for each one:
             - compute/memory ratio above RECOMPUTE_THRESHOLD: CHECKPOINT
             - enough CPU offload budget left: OFFLOAD
             - large activations with tiling enabled: TILE
             - checkpointing enabled: CHECKPOINT as a last resort
        4. Stop as soon as the projected memory drops under the
           mode-scaled budget.
    """

    # Heuristic thresholds (tunable based on customer data)
    RECOMPUTE_THRESHOLD = 1e6  # FLOPs/byte ratio above which checkpoint is preferred
    MIN_LAYER_SIZE_FOR_OPTIMIZATION = 4 * 1024 * 1024  # 4 MB minimum

    def __init__(self, config: Config):
        self.config = config

    def decide(self, graph: MemoryGraph, hw: HardwareBudget) -> OptimizationPlan:
        """Build an optimization plan for ``graph`` under the budget ``hw``."""
        plan = OptimizationPlan()
        mode = self.config.mode

        # Fraction of the target budget at which the per-layer loop stops.
        threshold_pct = {
            OptimizationMode.CONSERVATIVE: 0.90,
            OptimizationMode.BALANCED: 0.75,
            OptimizationMode.AGGRESSIVE: 0.50,
        }[mode]

        total_memory_needed = graph.total_memory_bytes()

        # With a GPU the budget is a fraction of its total memory. Without
        # one, the budget is that same fraction of the model's own footprint,
        # which lets optimization be exercised in CPU-only/testing setups.
        budget_basis = hw.gpu_total_bytes if hw.gpu_total_bytes > 0 else total_memory_needed
        target_gpu_bytes = int(budget_basis * self.config.target_gpu_utilization)

        bytes_per_gb = 1024**3
        logger.info(
            f"Memory analysis: need {total_memory_needed / bytes_per_gb:.3f} GB, "
            f"target budget {target_gpu_bytes / bytes_per_gb:.3f} GB, "
            f"mode={mode.value}"
        )

        aggressive = mode == OptimizationMode.AGGRESSIVE

        # Early exit: the model uses under 70% of the budget and the user did
        # not explicitly ask for aggressive optimization.
        if total_memory_needed < target_gpu_bytes * 0.7 and not aggressive:
            logger.info("Model fits comfortably. No optimization needed.")
            plan.layer_techniques.update(dict.fromkeys(graph.layers, Technique.KEEP))
            plan.rationale.append("Model fits in GPU; no optimization applied.")
            return plan

        # Baseline: every layer starts as KEEP and is upgraded below.
        plan.layer_techniques.update(dict.fromkeys(graph.layers, Technique.KEEP))

        # Largest layers are considered first.
        candidates = graph.layers_sorted_by_memory()

        projected_memory = total_memory_needed
        cpu_offload_budget = self._compute_cpu_offload_budget(hw)
        cpu_used = 0

        # Smaller layers become eligible as the mode gets more aggressive.
        base_min = self.MIN_LAYER_SIZE_FOR_OPTIMIZATION
        min_layer_size = {
            OptimizationMode.AGGRESSIVE: base_min // 16,  # 256 KB
            OptimizationMode.BALANCED: base_min // 4,  # 1 MB
        }.get(mode, base_min)  # 4 MB otherwise

        stop_below = target_gpu_bytes * threshold_pct

        for layer in candidates:
            # Done once the projection is under the mode-scaled budget.
            if projected_memory < stop_below:
                break

            # Tiny layers are not worth the overhead.
            if layer.total_memory_bytes < min_layer_size:
                continue

            chosen = self._choose_technique_for_layer(
                layer, cpu_offload_budget - cpu_used
            )
            if chosen == Technique.KEEP:
                continue

            plan.layer_techniques[layer.name] = chosen

            saved = self._estimate_savings(layer, chosen)
            projected_memory -= saved
            plan.estimated_memory_savings_bytes += saved

            # Offloaded activations consume the CPU-side budget.
            if chosen == Technique.OFFLOAD:
                cpu_used += layer.activation_bytes

            plan.rationale.append(
                f"{layer.name}: {chosen.value} "
                f"(saves {saved / (1024**2):.1f} MB)"
            )

        plan.estimated_throughput_overhead_pct = self._estimate_overhead(plan, graph)
        return plan

    def _choose_technique_for_layer(
        self, layer: LayerMemoryInfo, cpu_remaining: int
    ) -> Technique:
        """Select a technique for one layer from its size and compute profile."""
        cfg = self.config

        # Minimum activation size worth optimizing, relaxed in bolder modes.
        if cfg.mode == OptimizationMode.AGGRESSIVE:
            min_activation = 64 * 1024  # 64 KB
        else:
            min_activation = (
                512 * 1024  # 512 KB
                if cfg.mode == OptimizationMode.BALANCED
                else 1024 * 1024  # 1 MB
            )

        if layer.activation_bytes < min_activation:
            return Technique.KEEP

        # Compute-heavy layers: prefer checkpointing.
        if cfg.enable_checkpointing and layer.compute_to_memory_ratio > self.RECOMPUTE_THRESHOLD:
            return Technique.CHECKPOINT

        # Offload when the remaining CPU budget covers the activations with
        # a 20% safety margin.
        if cfg.enable_offloading and cpu_remaining > layer.activation_bytes * 1.2:
            return Technique.OFFLOAD

        # Tiling is only considered for activations larger than 16 MB.
        if cfg.enable_tiling and layer.activation_bytes > 16 * 1024 * 1024:
            return Technique.TILE

        # Nothing else applied: fall back to checkpointing when allowed.
        return Technique.CHECKPOINT if cfg.enable_checkpointing else Technique.KEEP

    def _compute_cpu_offload_budget(self, hw: HardwareBudget) -> int:
        """Return the number of CPU bytes that may be used for offloading."""
        limit_gb = self.config.max_cpu_offload_gb
        if limit_gb is None:
            # No explicit cap: use half of the currently available CPU RAM.
            return int(hw.cpu_available_bytes * 0.5)
        return int(limit_gb * 1024**3)

    @staticmethod
    def _estimate_savings(layer: LayerMemoryInfo, technique: Technique) -> int:
        """Estimate the GPU bytes saved by applying ``technique`` to ``layer``."""
        if technique == Technique.TILE:
            # Tiling is credited with 60% of the activation memory.
            return int(layer.activation_bytes * 0.6)
        if technique in (Technique.CHECKPOINT, Technique.OFFLOAD):
            # Both are credited with the full activation memory.
            return layer.activation_bytes
        return 0

    @staticmethod
    def _estimate_overhead(plan: OptimizationPlan, graph: MemoryGraph) -> float:
        """Estimate the plan's throughput overhead as a percentage (max 100)."""
        # Rule-of-thumb per-technique cost, weighted by each layer's share
        # of the total FLOPs.
        cost_pct = {
            Technique.CHECKPOINT: 25.0,
            Technique.OFFLOAD: 5.0,
            Technique.TILE: 10.0,
        }

        # "or 1" guards against division by zero when no FLOPs are recorded.
        total_flops = sum(info.flops for info in graph.layers.values()) or 1

        weighted_overhead = 0.0
        for layer_name, technique in plan.layer_techniques.items():
            layer = graph.layers.get(layer_name)
            if not layer or layer.flops == 0:
                continue

            pct = cost_pct.get(technique)
            if pct is None:
                # KEEP (or anything unlisted) contributes no overhead.
                continue

            weighted_overhead += pct * (layer.flops / total_flops)

        return min(weighted_overhead, 100.0)
|