potnn 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- potnn/__init__.py +86 -0
- potnn/codegen/__init__.py +20 -0
- potnn/codegen/bit2.py +263 -0
- potnn/codegen/fp130.py +269 -0
- potnn/codegen/header.py +460 -0
- potnn/codegen/level5.py +393 -0
- potnn/codegen/scale.py +184 -0
- potnn/codegen/ternary.py +354 -0
- potnn/codegen/unroll.py +616 -0
- potnn/config.py +112 -0
- potnn/export.py +2196 -0
- potnn/fuse.py +167 -0
- potnn/modules/__init__.py +11 -0
- potnn/modules/add.py +114 -0
- potnn/modules/avgpool.py +173 -0
- potnn/modules/base.py +225 -0
- potnn/modules/conv.py +203 -0
- potnn/modules/conv1d.py +317 -0
- potnn/modules/depthwise.py +216 -0
- potnn/modules/linear.py +199 -0
- potnn/quantize/__init__.py +35 -0
- potnn/quantize/calibration.py +233 -0
- potnn/quantize/integer_ops.py +207 -0
- potnn/quantize/integer_sim.py +225 -0
- potnn/quantize/pot.py +455 -0
- potnn/quantize/qat.py +356 -0
- potnn/utils/__init__.py +13 -0
- potnn/utils/allocation.py +240 -0
- potnn/utils/memory.py +158 -0
- potnn/wrapper.py +304 -0
- potnn-1.0.0.dist-info/METADATA +260 -0
- potnn-1.0.0.dist-info/RECORD +35 -0
- potnn-1.0.0.dist-info/WHEEL +5 -0
- potnn-1.0.0.dist-info/licenses/LICENSE +72 -0
- potnn-1.0.0.dist-info/top_level.txt +1 -0
potnn/utils/memory.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""Memory estimation and validation utilities."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
from typing import Dict, Tuple
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def estimate_layer_size(module: nn.Module) -> int:
    """Estimate the storage size of a layer's weight and bias in bytes.

    The element counts of ``module.weight`` and ``module.bias`` are summed,
    and the raw count is returned as the byte estimate. In other words, one
    byte per parameter is assumed. An attribute that is missing or set to
    ``None`` contributes nothing.

    Args:
        module: Neural network module

    Returns:
        Estimated size in bytes (weight elements + bias elements)
    """
    total = 0
    for attr_name in ('weight', 'bias'):
        tensor = getattr(module, attr_name, None)
        if tensor is not None:
            total += tensor.numel()
    return total
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def estimate_activation_size(model: nn.Module, input_shape: Tuple) -> int:
    """Estimate the maximum activation buffer size needed.

    Runs one forward pass on an all-zeros input of shape ``(1, *input_shape)``.
    It records the element count of every tensor output produced by
    ``nn.Conv2d``, ``nn.Linear`` and ``nn.MaxPool2d`` modules (including
    subclasses) and returns the largest one. One byte per element is assumed
    (int8 activations).

    Forward hooks are always removed, even if the forward pass raises. Each
    submodule's train/eval mode is restored afterwards, so calling this does
    not leave the model in eval mode.

    Args:
        model: Neural network model
        input_shape: Shape of input tensor (without batch dimension)

    Returns:
        Maximum activation size in bytes. This is 0 if no hooked module
        produced a tensor output.
    """
    # Put the dummy input on the model's device. A model without parameters
    # falls back to CPU instead of raising StopIteration.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    dummy_input = torch.zeros(1, *input_shape).to(device)

    max_size = 0

    def hook_fn(module, inputs, output):
        nonlocal max_size
        if isinstance(output, torch.Tensor):
            max_size = max(max_size, output.numel())  # int8 = 1 byte per element

    # Remember every submodule's mode so model.eval() below has no lasting effect.
    training_modes = [(m, m.training) for m in model.modules()]

    hooks = [
        module.register_forward_hook(hook_fn)
        for module in model.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear, nn.MaxPool2d))
    ]

    model.eval()
    try:
        with torch.no_grad():
            model(dummy_input)
    finally:
        for hook in hooks:
            hook.remove()
        # modules() yields parents before children. A child's own setting
        # therefore wins over the recursive train() call made on its parent.
        for submodule, mode in training_modes:
            submodule.train(mode)

    return max_size
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def estimate_memory_usage(model: nn.Module, input_shape: Tuple, mode: str = 'all') -> Dict[str, int]:
    """Estimate memory usage of the model.

    Args:
        model: Neural network model
        input_shape: Shape of input tensor (without batch dimension)
        mode: 'all', 'weights', or 'activations'

    Returns:
        Dictionary with memory estimates in bytes:
        - 'weights' / 'all': 'total_weights' (int) and 'layer_weights'
          (dict mapping Conv2d/Linear module name -> size)
        - 'activations' / 'all': 'max_activation'
        - 'all' only: 'input_buffer', 'estimated_ram', 'estimated_flash'

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    import math

    if mode not in ('all', 'weights', 'activations'):
        raise ValueError(
            f"mode must be 'all', 'weights' or 'activations', got {mode!r}"
        )

    result = {}

    if mode in ('all', 'weights'):
        # Weight memory: only Conv2d / Linear layers (and their subclasses) count
        layer_weights = {
            name: estimate_layer_size(module)
            for name, module in model.named_modules()
            if isinstance(module, (nn.Conv2d, nn.Linear))
        }
        result['total_weights'] = sum(layer_weights.values())
        result['layer_weights'] = layer_weights

    if mode in ('all', 'activations'):
        # Largest single activation buffer
        result['max_activation'] = estimate_activation_size(model, input_shape)

    if mode == 'all':
        # Input buffer: one byte per input element for a batch of 1
        result['input_buffer'] = math.prod(input_shape)

        # Total RAM needed (input + activations + some weights in loop mode)
        result['estimated_ram'] = result['input_buffer'] + result['max_activation']

        # Total Flash needed (mainly for unrolled weights as code).
        # This is a rough estimate; the actual size depends on unroll/loop decisions.
        result['estimated_flash'] = result['total_weights'] * 4

    return result
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def validate_memory(model: nn.Module, flash_budget: int, ram_budget: int,
                    input_shape: Tuple = (1, 16, 16)) -> bool:
    """Validate that the model fits within the Flash and RAM budgets.

    Args:
        model: Neural network model
        flash_budget: Flash memory budget in bytes
        ram_budget: RAM budget in bytes
        input_shape: Input tensor shape (default for 16x16 grayscale)

    Returns:
        True when both checks pass. The function never returns False; a
        failing check raises ``ValueError`` instead.

    Raises:
        ValueError: In either of two cases:
            - The input buffer plus the largest activation exceeds ``ram_budget``.
            - The weights fit neither in Flash as unrolled code (about 4 bytes
              per weight) nor, 2-bit packed (4 weights per byte), in the RAM
              left over after the activation buffers.
    """
    estimates = estimate_memory_usage(model, input_shape)

    # Check RAM budget: input buffer and largest activation must both fit
    min_ram_needed = estimates['input_buffer'] + estimates['max_activation']
    if min_ram_needed > ram_budget:
        raise ValueError(
            f"Model requires at least {min_ram_needed} bytes of RAM "
            f"(input: {estimates['input_buffer']}, activation: {estimates['max_activation']}), "
            f"but only {ram_budget} bytes available."
        )

    # Check whether the weights fit either in Flash (unrolled) or in RAM (loop).
    # This is a simplified check; the actual allocation is done by allocate_hybrid.
    total_weights = estimates['total_weights']
    unrolled_size = total_weights * 4  # ~4 bytes of code per weight
    # 2-bit packing stores 4 weights per byte. Round up so a partial byte counts.
    packed_size = (total_weights + 3) // 4
    available_ram = ram_budget - min_ram_needed

    if unrolled_size > flash_budget and packed_size > available_ram:
        raise ValueError(
            f"Model weights ({total_weights} parameters) too large. "
            f"Unrolled: {unrolled_size} bytes > Flash {flash_budget} bytes. "
            f"Packed: {packed_size} bytes > Available RAM {available_ram} bytes."
        )

    return True
|
potnn/wrapper.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
"""Model wrapper for potnn conversion."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from typing import Optional, Callable
|
|
7
|
+
from torch.utils.data import DataLoader
|
|
8
|
+
|
|
9
|
+
from .config import Config
|
|
10
|
+
from .modules import PoTLinear, PoTConv2d
|
|
11
|
+
from .quantize.calibration import calibrate_model
|
|
12
|
+
from .quantize.qat import prepare_qat, alpha_reg_loss, enable_integer_sim
|
|
13
|
+
from .utils import validate_memory, allocate_hybrid
|
|
14
|
+
from .fuse import fuse_batchnorm
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _normalize_data(data: torch.Tensor, mean: float, std: float) -> torch.Tensor:
    """Normalize input data to match C inference exactly.

    C inference differs from ``torchvision.transforms.Normalize`` in two ways:
    - it divides by 256 (via shift+8) instead of 255;
    - it uses avg_std (a single value) instead of a per-channel std.

    QAT must match this for training = deployment consistency.

    Args:
        data: Input tensor whose dim 1 is the channel/feature axis, e.g.
            [N, C, H, W] (Conv2d), [N, C, L] (Conv1d) or [N, F]. The expected
            range is [0, 1], as produced by torchvision.
        mean: Dataset mean - float or List[float]. A list is broadcast along dim 1.
        std: Dataset std - float or List[float]. A list is averaged to one value.

    Returns:
        Normalized tensor matching C inference behavior, with the same shape
        as ``data``.
    """
    # C uses a single scale value, so per-channel stds collapse to their average
    if isinstance(std, (list, tuple)):
        avg_std = sum(std) / len(std)
    else:
        avg_std = std

    # Broadcast a per-channel mean along dim 1 for any input rank:
    # [1, C] for 2D, [1, C, 1] for 3D (Conv1d), [1, C, 1, 1] for 4D (Conv2d).
    if isinstance(mean, (list, tuple)):
        mean = torch.tensor(mean, dtype=data.dtype, device=data.device)
        mean = mean.view(1, -1, *([1] * (data.dim() - 2)))

    # Match C inference:
    # - data comes as [0,1] from torchvision (raw/255)
    # - C uses raw/256, so multiply by 256/255 to compensate
    # - C uses avg_std in scale, so divide by avg_std
    return (data * (256.0 / 255.0) - mean) / avg_std
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _validate_model(model: nn.Module) -> None:
    """Reject models that still contain plain ``nn.Conv2d`` / ``nn.Linear`` layers.

    potnn only supports its PoT layers (PoTConv2d, PoTDepthwiseConv2d,
    PoTLinear). A plain layer would make export fail later, so every
    offending submodule is collected and reported up front.

    Args:
        model: Model whose submodules (via ``named_modules()``) are inspected.

    Raises:
        ValueError: If any plain Conv2d or Linear layer is found. The message
            lists each offending module name with its PoT replacement.
    """
    from .modules.conv import PoTConv2d
    from .modules.depthwise import PoTDepthwiseConv2d
    from .modules.linear import PoTLinear

    pot_conv_types = (PoTConv2d, PoTDepthwiseConv2d)
    problems = []

    for layer_name, layer in model.named_modules():
        # A Conv2d that is neither PoTConv2d nor PoTDepthwiseConv2d
        if isinstance(layer, nn.Conv2d) and not isinstance(layer, pot_conv_types):
            problems.append(f" - {layer_name}: nn.Conv2d → potnn.PoTConv2d로 교체 필요")

        # A Linear that is not PoTLinear
        if isinstance(layer, nn.Linear) and not isinstance(layer, PoTLinear):
            problems.append(f" - {layer_name}: nn.Linear → potnn.PoTLinear로 교체 필요")

    if not problems:
        return

    details = "\n".join(problems)
    raise ValueError(
        f"potnn은 PoT 레이어만 지원합니다. 다음 레이어를 교체하세요:\n{details}\n\n"
        f"예시:\n"
        f" nn.Conv2d(1, 16, 3) → potnn.PoTConv2d(1, 16, 3)\n"
        f" nn.Linear(256, 10) → potnn.PoTLinear(256, 10)"
    )
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def train(model: nn.Module,
          train_loader: DataLoader,
          test_loader: DataLoader,
          config: Config,
          float_epochs: int = 15,
          qat_epochs: int = 50,
          float_lr: float = 1e-3,
          qat_lr: float = 1e-4,
          device: str = 'cuda',
          fuse_bn: bool = True,
          verbose: bool = True) -> nn.Module:
    """Complete training pipeline: Float → (BN Fusion) → Calibration → QAT → Integer Sim

    Phases, as implemented below:
      1.   Float training for ``float_epochs`` (AdamW + cosine LR schedule).
      1.5. Optional BatchNorm fusion (``fuse_bn``).
      2.   Calibration on ``train_loader``.
      3.   QAT preparation. Every PoT layer is told whether it is the first or
           last one, and the first layer receives ``config.std``/``config.mean``.
      4.   QAT training for ``qat_epochs`` (Adam + cosine LR schedule):
           - From epoch ``int(qat_epochs * 0.6)`` (the last ~40%), integer
             simulation is enabled, the learning rate is multiplied by 0.1,
             and integer params are recomputed at the start of every epoch.
           - From epoch ``int(qat_epochs * 0.8)`` (the last ~20%), the 5level
             constraint is switched on for modules with ``encoding == '5level'``.
      Finally, the best checkpoint is restored and integer simulation is
      re-enabled. The best checkpoint is picked by ``test_loader`` accuracy
      and only counts once integer simulation is on. Integer params are then
      recomputed and the final accuracy is measured.

    Args:
        model: PoT model (must use PoTConv2d, PoTLinear)
        train_loader: Training data loader (raw [0,1] input, NO Normalize transform needed)
        test_loader: Test data loader (raw [0,1] input, NO Normalize transform needed).
            Also used to select the best QAT checkpoint.
        config: potnn Config (mean/std used for automatic normalization)
        float_epochs: Float training epochs (default: 15)
        qat_epochs: QAT training epochs (default: 50)
        float_lr: Float training learning rate (default: 1e-3)
        qat_lr: QAT training learning rate (default: 1e-4)
        device: 'cuda' or 'cpu' (default: 'cuda')
        fuse_bn: Fuse BatchNorm layers after float training (default: True)
        verbose: Print progress (default: True)

    Returns:
        Trained model with Integer Simulation enabled.
        - model.train(): uses Float QAT for fine-tuning
        - model.eval(): uses Integer Simulation (matches C exactly)
        The model also gets a ``train_stats`` dict with the keys 'float_acc',
        'qat_acc' and 'final_acc' (accuracies in percent).

    Raises:
        ValueError: From ``_validate_model`` if the model contains plain
            ``nn.Conv2d`` / ``nn.Linear`` layers.

    Note:
        Input normalization is handled automatically using config.mean/std.
        Do NOT add transforms.Normalize() to your DataLoader.
    """
    # Validate the model: raises ValueError if plain nn.Conv2d / nn.Linear layers are used
    _validate_model(model)

    model = model.to(device)

    # Get normalization params from config (identity normalization when unset)
    mean = config.mean if config.mean is not None else 0.0
    std = config.std if config.std is not None else 1.0

    # Phase 1: Float Training
    if verbose:
        print(f"\n[Phase 1] Float Training ({float_epochs} epochs)...")

    optimizer = torch.optim.AdamW(model.parameters(), lr=float_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float_epochs)

    best_float_acc = 0
    for epoch in range(float_epochs):
        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            data = _normalize_data(data, mean, std)  # Auto normalize
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

        scheduler.step()
        acc = _evaluate(model, test_loader, device, mean, std)
        best_float_acc = max(best_float_acc, acc)

        if verbose and (epoch % 5 == 0 or epoch == float_epochs - 1):
            print(f" Epoch {epoch+1}/{float_epochs}: {acc:.2f}%")

    if verbose:
        print(f" Best Float: {best_float_acc:.2f}%")

    # Phase 1.5: BatchNorm Fusion (optional)
    if fuse_bn:
        if verbose:
            print(f"\n[Phase 1.5] BatchNorm Fusion...")
        model = fuse_batchnorm(model)

    # Phase 2: Calibration
    if verbose:
        print(f"\n[Phase 2] Calibration...")

    calibrate_model(model, train_loader, mean=mean, std=std)

    # Phase 3: QAT Preparation
    if verbose:
        print(f"\n[Phase 3] Preparing QAT...")

    prepare_qat(model, config)

    # Set up first layer info for mean absorption during QAT.
    # This ensures QAT uses the same bias as Integer Sim.
    # NOTE(review): set_input_std receives the raw config values, which may be
    # None, whereas the normalization above falls back to 0.0/1.0. Confirm
    # that PoTLayerBase handles None the same way.
    from .modules.base import PoTLayerBase
    pot_layers = [(name, m) for name, m in model.named_modules() if isinstance(m, PoTLayerBase)]
    for i, (name, layer) in enumerate(pot_layers):
        is_first = (i == 0)
        is_last = (i == len(pot_layers) - 1)
        layer.set_layer_position(is_first, is_last)
        if is_first:
            layer.set_input_std(config.std, config.mean)

    # Phase 4: QAT Training (hybrid: float QAT for the first 60%, integer sim for the last 40%)
    if verbose:
        print(f"\n[Phase 4] QAT Training ({qat_epochs} epochs)...")

    optimizer = torch.optim.Adam(model.parameters(), lr=qat_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, qat_epochs)

    # The 5level constraint is activated for the last 20% of QAT epochs
    constraint_start_epoch = int(qat_epochs * 0.8)
    integer_sim_start_epoch = int(qat_epochs * 0.6)  # Float QAT 60% → Integer sim fine-tune 40%

    best_qat_acc = 0
    best_state = None

    for epoch in range(qat_epochs):
        # Enable 5level constraint for the last 20% of epochs
        if epoch == constraint_start_epoch:
            for name, module in model.named_modules():
                if hasattr(module, 'enforce_5level_constraint'):
                    if hasattr(module, 'encoding') and module.encoding == '5level':
                        module.enforce_5level_constraint = True
                        if verbose:
                            print(f" [{name}] 5level constraint enabled (epoch {epoch+1})")

        # Enable integer sim for the last 40% of epochs (fine-tune phase)
        if epoch == integer_sim_start_epoch:
            if verbose:
                print(f" [Integer Sim] Enabled for fine-tuning (epoch {epoch+1})")
            enable_integer_sim(model, input_std=config.std, input_mean=config.mean, verbose=False)
            # Lower learning rate 10x for fine-tuning (applied on top of the cosine schedule)
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.1

        # Update integer params each epoch if using integer sim
        if epoch >= integer_sim_start_epoch:
            for name, module in model.named_modules():
                if isinstance(module, PoTLayerBase) and module.use_integer_sim:
                    module.compute_integer_params()

        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            data = _normalize_data(data, mean, std)  # Auto normalize
            optimizer.zero_grad()
            output = model(data)
            # Task loss plus alpha regularization (weight 0.01)
            loss = F.cross_entropy(output, target) + alpha_reg_loss(model, 0.01)
            loss.backward()
            optimizer.step()

        scheduler.step()
        acc = _evaluate(model, test_loader, device, mean, std)

        # Only update best after integer sim starts (to ensure C-compatible weights).
        # NOTE(review): the best epoch may fall before constraint_start_epoch,
        # i.e. before the 5level constraint was enabled. Confirm that restoring
        # such a checkpoint still yields 5level-compatible weights.
        if epoch >= integer_sim_start_epoch and acc > best_qat_acc:
            best_qat_acc = acc
            best_state = {k: v.clone() for k, v in model.state_dict().items()}

        if verbose and (epoch % 10 == 0 or epoch == qat_epochs - 1):
            print(f" Epoch {epoch+1}/{qat_epochs}: {acc:.2f}%")

    # Restore best model (strict=False tolerates missing/unexpected keys)
    if best_state is not None:
        model.load_state_dict(best_state, strict=False)

    if verbose:
        print(f" Best QAT: {best_qat_acc:.2f}%")

    # Ensure integer sim is enabled for final model
    enable_integer_sim(model, input_std=config.std, input_mean=config.mean, verbose=verbose)

    # Final integer params update (the restored weights may differ from the last epoch's)
    for name, module in model.named_modules():
        if isinstance(module, PoTLayerBase) and module.use_integer_sim:
            module.compute_integer_params()

    # Final accuracy (with integer sim)
    final_acc = _evaluate(model, test_loader, device, mean, std)

    if verbose:
        print(f"\n[Summary] Float: {best_float_acc:.2f}% → QAT: {best_qat_acc:.2f}% → C-Ready: {final_acc:.2f}%")

    # Attach stats to model for reporting
    model.train_stats = {
        'float_acc': best_float_acc,
        'qat_acc': best_qat_acc,
        'final_acc': final_acc
    }

    return model
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def _evaluate(model: nn.Module, test_loader: DataLoader, device: str,
              mean: float = 0.0, std: float = 1.0) -> float:
    """Evaluate model accuracy with automatic normalization.

    The model is put in eval mode and left there. Each batch is normalized
    with ``_normalize_data(data, mean, std)``, and the index of the maximum
    output along dim 1 is compared with the target.

    Args:
        model: Model to evaluate
        test_loader: Loader yielding ``(data, target)`` batches
        device: Device the batches are moved to
        mean: Dataset mean passed to ``_normalize_data``
        std: Dataset std passed to ``_normalize_data``

    Returns:
        Top-1 accuracy in percent, or 0.0 if the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data = _normalize_data(data, mean, std)  # Auto normalize
            output = model(data)
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

    # Return 0.0 for an empty loader instead of raising ZeroDivisionError
    if total == 0:
        return 0.0
    return 100. * correct / total
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
# Note: wrap() function has been removed.
|
|
302
|
+
# Users must define models using PoT layers directly:
|
|
303
|
+
# potnn.PoTConv2d, potnn.PoTConv1d, potnn.PoTLinear, etc.
|
|
304
|
+
# This ensures proper initialization of alpha, QAT parameters, and encoding.
|