potnn 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- potnn/__init__.py +86 -0
- potnn/codegen/__init__.py +20 -0
- potnn/codegen/bit2.py +263 -0
- potnn/codegen/fp130.py +269 -0
- potnn/codegen/header.py +460 -0
- potnn/codegen/level5.py +393 -0
- potnn/codegen/scale.py +184 -0
- potnn/codegen/ternary.py +354 -0
- potnn/codegen/unroll.py +616 -0
- potnn/config.py +112 -0
- potnn/export.py +2196 -0
- potnn/fuse.py +167 -0
- potnn/modules/__init__.py +11 -0
- potnn/modules/add.py +114 -0
- potnn/modules/avgpool.py +173 -0
- potnn/modules/base.py +225 -0
- potnn/modules/conv.py +203 -0
- potnn/modules/conv1d.py +317 -0
- potnn/modules/depthwise.py +216 -0
- potnn/modules/linear.py +199 -0
- potnn/quantize/__init__.py +35 -0
- potnn/quantize/calibration.py +233 -0
- potnn/quantize/integer_ops.py +207 -0
- potnn/quantize/integer_sim.py +225 -0
- potnn/quantize/pot.py +455 -0
- potnn/quantize/qat.py +356 -0
- potnn/utils/__init__.py +13 -0
- potnn/utils/allocation.py +240 -0
- potnn/utils/memory.py +158 -0
- potnn/wrapper.py +304 -0
- potnn-1.0.0.dist-info/METADATA +260 -0
- potnn-1.0.0.dist-info/RECORD +35 -0
- potnn-1.0.0.dist-info/WHEEL +5 -0
- potnn-1.0.0.dist-info/licenses/LICENSE +72 -0
- potnn-1.0.0.dist-info/top_level.txt +1 -0
potnn/quantize/qat.py
ADDED
@@ -0,0 +1,356 @@
"""Quantization-aware training utilities for potnn."""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional


def _compute_raw_alpha(target_alpha: float) -> float:
    """Compute raw_alpha value that produces target alpha after softplus.

    softplus(raw_alpha) = target_alpha
    raw_alpha = inverse_softplus(target_alpha) = log(exp(target_alpha) - 1)

    Args:
        target_alpha: Desired alpha value after softplus

    Returns:
        raw_alpha value to set
    """
    if target_alpha > 10:
        # For large values, softplus(x) ≈ x
        return target_alpha
    elif target_alpha > 0.1:
        # Normal case: inverse softplus
        return math.log(math.exp(target_alpha) - 1)
    elif target_alpha > 0.01:
        # Small values: approximation
        return math.log(target_alpha)
    else:
        # Very small: minimum
        return -4.0  # softplus(-4) ≈ 0.018
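For reference, a quick sanity check of the inverse-softplus helper above (editorial, not part of the wheel); it assumes only `torch` and the function as defined, and exact recovery is only expected in the middle branch since the other two are deliberate approximations:

import torch
import torch.nn.functional as F
from potnn.quantize.qat import _compute_raw_alpha

# softplus(_compute_raw_alpha(t)) should roughly reproduce t.
for target in (0.05, 0.3, 2.0, 20.0):
    raw = _compute_raw_alpha(target)
    recovered = F.softplus(torch.tensor(raw)).item()
    print(f"target={target:<5} raw={raw:+.4f} softplus(raw)={recovered:.4f}")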
def prepare_qat(model: nn.Module, config=None, force_alpha_init: bool = True):
    """Prepare model for quantization-aware training.

    This function:
    1. Fuses BatchNorm layers into preceding Conv layers (CRITICAL!)
    2. Sets encoding for each layer based on config.layer_encodings
    3. Initializes alpha values based on weight statistics (if not calibrated)
    4. Calculates combined scale factors for each layer
    5. Enables quantization mode for all PoT layers

    Args:
        model: Model to prepare (should be wrapped with potnn)
        config: potnn Config with layer_encodings (optional)
        force_alpha_init: If True, always initialize alpha from weight stats.
            If False, only initialize if alpha seems uncalibrated (≈0.974).
            Default True for safety.
    """
    from ..modules import PoTLinear, PoTConv2d, PoTDepthwiseConv2d
    from ..modules.base import PoTLayerBase
    from ..fuse import fuse_batchnorm, check_bn_fused

    # ========================================
    # Step 0: Fuse BatchNorm layers BEFORE QAT
    # This is CRITICAL! BN must be fused before PoT weight training.
    # ========================================
    fuse_batchnorm(model)

    # Replace AdaptiveAvgPool2d with PoTGlobalAvgPool BEFORE collecting layers
    # This ensures it's included in any subsequent processing if needed
    from ..modules.avgpool import PoTGlobalAvgPool
    # Use list(model.named_modules()) to avoid modification during iteration issues
    replacements = []
    for name, module in model.named_modules():
        if isinstance(module, nn.AdaptiveAvgPool2d):
            if module.output_size == 1 or module.output_size == (1, 1):
                replacements.append((name, module))

    for name, module in replacements:
        parts = name.split('.')
        if len(parts) > 1:
            parent = model.get_submodule('.'.join(parts[:-1]))
            child_name = parts[-1]
        else:
            parent = model
            child_name = name

        setattr(parent, child_name, PoTGlobalAvgPool())
        print(f" {name}: Replaced AdaptiveAvgPool2d with PoTGlobalAvgPool")

    print("Preparing model for QAT...")

    # Collect all PoT layers
    pot_layers = []
    for name, module in model.named_modules():
        if isinstance(module, PoTLayerBase):
            pot_layers.append((name, module))

    # ========================================
    # Step 0.5: Set encoding for each layer
    # ========================================
    if config is not None:
        for name, module in pot_layers:
            encoding = config.get_encoding(name)
            module.encoding = encoding
            print(f" {name}: encoding={encoding}")

    # Initialize alpha values based on weight statistics
    # This is a safety net if calibration was not called
    for name, module in pot_layers:
        if not hasattr(module, 'weight'):
            continue

        with torch.no_grad():
            current_alpha = module.alpha.item()
            w_std = module.weight.std().item()

            # Check if alpha needs initialization
            # Default alpha ≈ 0.974 (from softplus(0.5))
            needs_init = force_alpha_init or (0.9 < current_alpha < 1.1)

            if needs_init and w_std > 0.001:
                # Use weight std as target alpha
                target_alpha = w_std
                target_alpha = max(target_alpha, 0.01)  # Minimum

                # Compute raw_alpha for this target
                raw = _compute_raw_alpha(target_alpha)
                module.raw_alpha.data.fill_(raw)

                # Update alpha_init for regularization
                actual_alpha = module.alpha.item()
                module.alpha_init.data.fill_(actual_alpha)

                print(f" {name}: alpha initialized {current_alpha:.4f} → {actual_alpha:.4f} (w_std={w_std:.4f})")
            else:
                print(f" {name}: alpha kept at {current_alpha:.4f} (w_std={w_std:.4f})")

    # Calculate combined scale factors and set prev_act_scale
    # These are used for both Integer-Only QAT and C export
    prev_act_scale = 1.0

    # Iterate over all modules in order to propagate prev_act_scale correctly
    # We need to handle both PoTLayerBase and PoTGlobalAvgPool
    from ..modules.avgpool import PoTGlobalAvgPool

    # Collect relevant layers in order
    ordered_layers = []
    for name, module in model.named_modules():
        if isinstance(module, (PoTLayerBase, PoTGlobalAvgPool)):
            ordered_layers.append((name, module))

    for i, (name, module) in enumerate(ordered_layers):
        is_first = (i == 0)
        is_last = (i == len(ordered_layers) - 1)

        if isinstance(module, PoTLayerBase):
            module.set_layer_position(is_first, is_last)

            # Set previous layer's act_scale (Critical for Integer QAT)
            if isinstance(prev_act_scale, torch.Tensor):
                module.set_prev_act_scale(prev_act_scale.item())
            else:
                module.set_prev_act_scale(prev_act_scale)

            # Also set combined_scale_factor for export (legacy but useful)
            scale_val = prev_act_scale.item() if isinstance(prev_act_scale, torch.Tensor) else prev_act_scale
            module.combined_scale_factor = 1.0 / scale_val

            # Update prev_act_scale for next layer
            if module.act_scale is not None:
                prev_act_scale = module.act_scale

        elif isinstance(module, PoTGlobalAvgPool):
            # Set GAP's input scale
            scale_val = prev_act_scale.item() if isinstance(prev_act_scale, torch.Tensor) else prev_act_scale
            module.prepare_qat(act_scale=scale_val)

            # GAP doesn't change scale (it's averaging), so prev_act_scale remains valid for next layer
            # Unless GAP has its own act_scale?
            # PoTGlobalAvgPool has act_scale (input scale).
            # Output scale is same as input scale for averaging?
            # Yes, average of scaled integers is scaled integer.
            # So prev_act_scale should NOT change.
            pass

            print(f" {name}: GAP prepared for QAT (act_scale={scale_val:.4f})")

    # Enable quantization mode
    for name, module in pot_layers:
        module.quantize = True

    # Also enable quantization for PoTAdd layers
    from ..modules.add import PoTAdd
    add_count = 0
    for name, module in model.named_modules():
        if isinstance(module, PoTAdd):
            module.quantize = True
            add_count += 1

    print(f"QAT mode enabled for {len(pot_layers)} PoT layers and {add_count} Add layers.")
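A minimal usage sketch for the entry point above (editorial, not part of the package). It assumes `model` already contains potnn's PoT layers; everything else follows from the function as defined:

from potnn.quantize.qat import prepare_qat

# prepare_qat mutates the model in place: it fuses BatchNorm, swaps
# AdaptiveAvgPool2d(1) for PoTGlobalAvgPool, initializes alpha from the weight
# std where needed, propagates activation scales, and sets quantize=True.
prepare_qat(model, config=None, force_alpha_init=True)

# Afterwards every PoT (and PoTAdd) layer reports itself as quantizing:
for name, module in model.named_modules():
    if getattr(module, "quantize", False):
        print(f"{name}: QAT enabled")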
def alpha_reg_loss(model: nn.Module, lambda_reg: float = 0.01) -> torch.Tensor:
    """Calculate alpha regularization loss.

    This prevents alpha values from drifting too far from their initial values,
    which helps stabilize training and prevent weight collapse.

    Args:
        model: Model with PoT layers
        lambda_reg: Regularization strength

    Returns:
        Regularization loss term to add to the main loss
    """
    from ..modules import PoTLinear, PoTConv2d, PoTConv1d

    reg_loss = torch.tensor(0.0, device=next(model.parameters()).device)

    for module in model.modules():
        if isinstance(module, (PoTLinear, PoTConv2d, PoTConv1d)):
            # Apply regularization to keep alpha near its initial value
            alpha = F.softplus(module.alpha).clamp(min=0.01)
            reg_loss += (alpha - module.alpha_init) ** 2

    return lambda_reg * reg_loss
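A sketch of how the regularizer plugs into an ordinary training step (editorial; `model` and `train_loader` are assumed to exist, and the hyperparameters are placeholders):

import torch
import torch.nn as nn
from potnn.quantize.qat import alpha_reg_loss

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for inputs, targets in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    # Add the alpha term so per-layer scales stay close to their initial values.
    loss = loss + alpha_reg_loss(model, lambda_reg=0.01)
    loss.backward()
    optimizer.step()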
def enable_integer_sim(model: nn.Module, input_std=1.0, input_mean=0.0, verbose: bool = True):
    """Enable integer simulation mode for C-compatible inference.

    This function sets up all PoT layers for integer simulation that matches
    C inference bit-for-bit. Call this after calibration and before QAT training
    for best results.

    The integer simulation ensures:
    - Python forward pass uses same integer arithmetic as C
    - Eliminates QAT-to-C accuracy gap
    - Proper scale chain propagation between layers

    Args:
        model: Model with PoT layers (should already be calibrated)
        input_std: Input standard deviation for first layer.
            Can be float, List[float], or Config object.
        input_mean: Input mean for first layer.
            Can be float, List[float], or Config object.
        verbose: Print debug info (default True)

    Example:
        # After calibration, before QAT training:
        prepare_qat(model)
        enable_integer_sim(model, input_std=0.3081, input_mean=0.1307)

        # Or after training, before export:
        enable_integer_sim(model, input_std=config.std, input_mean=config.mean)
    """
    from ..modules.base import PoTLayerBase
    from ..modules.add import PoTAdd
    from ..config import Config

    # Handle Config object passed as input_std (common mistake)
    if isinstance(input_std, Config):
        config = input_std
        input_std = config.std
        input_mean = config.mean if hasattr(config, 'mean') else input_mean

    # Convert list to average (for multi-channel, use avg_std like C code)
    if isinstance(input_std, (list, tuple)):
        input_std = sum(input_std) / len(input_std)
    if isinstance(input_mean, (list, tuple)):
        input_mean = sum(input_mean) / len(input_mean)

    if verbose:
        print("="*60)
        print("Enabling integer simulation mode...")
        print(f" input_mean={input_mean}, input_std={input_std}")
        print("="*60)

    # Collect all PoT layers in order
    pot_layers = []
    for name, module in model.named_modules():
        if isinstance(module, PoTLayerBase) and not isinstance(module, PoTAdd):
            pot_layers.append((name, module))

    if len(pot_layers) == 0:
        print(" Warning: No PoT layers found!")
        return

    # Set up each layer
    prev_act_scale = 1.0

    for i, (name, layer) in enumerate(pot_layers):
        is_first = (i == 0)
        is_last = (i == len(pot_layers) - 1)

        # Set layer position
        layer.set_layer_position(is_first, is_last)

        # Set previous layer's act_scale
        layer.set_prev_act_scale(prev_act_scale)

        # Set input stats for first layer
        if is_first:
            layer.set_input_std(input_std, input_mean)
            if verbose:
                print(f" [DEBUG] First layer: mean={input_mean}, std={input_std}")

        # Compute integer parameters
        scale_int, shift = layer.compute_integer_params()

        # Enable integer simulation
        layer.use_integer_sim = True

        if verbose:
            act_scale = layer.act_scale.item() if layer.act_scale is not None else None
            alpha = layer.alpha.item()
            print(f" {name}: first={is_first}, last={is_last}, "
                  f"prev_scale={prev_act_scale:.4f}, act_scale={act_scale}, "
                  f"alpha={alpha:.4f}, scale_int={scale_int}, shift={shift}")

        # Update prev_act_scale for next layer
        if layer.act_scale is not None:
            prev_act_scale = layer.act_scale.item()

    # Enable integer sim for PoTGlobalAvgPool layers
    from ..modules.avgpool import PoTGlobalAvgPool
    for name, module in model.named_modules():
        if isinstance(module, PoTGlobalAvgPool):
            module.integer_sim_enabled = True
            if verbose:
                print(f" {name}: GAP integer sim enabled")

    if verbose:
        print(f"Integer simulation enabled for {len(pot_layers)} layers.")
        print("="*60)


def disable_integer_sim(model: nn.Module):
    """Disable integer simulation mode, reverting to float QAT.

    Args:
        model: Model with PoT layers
    """
    from ..modules.base import PoTLayerBase
    from ..modules.avgpool import PoTGlobalAvgPool

    for module in model.modules():
        if isinstance(module, PoTLayerBase):
            module.use_integer_sim = False
        elif isinstance(module, PoTGlobalAvgPool):
            module.integer_sim_enabled = False
        elif isinstance(module, nn.AdaptiveAvgPool2d):
            if hasattr(module, '_original_forward'):
                module.forward = module._original_forward
                del module._original_forward

    print("Integer simulation disabled.")
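The enable/disable pair above suggests a simple evaluate-then-revert pattern; a hedged sketch follows (editorial: `evaluate` and `test_loader` are placeholders, and the MNIST-style mean/std are taken from the docstring example):

from potnn.quantize.qat import enable_integer_sim, disable_integer_sim

# Evaluate with the C-equivalent integer path...
enable_integer_sim(model, input_std=0.3081, input_mean=0.1307, verbose=False)
int_accuracy = evaluate(model, test_loader)

# ...then revert to float QAT for further training.
disable_integer_sim(model)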
potnn/utils/__init__.py
ADDED
@@ -0,0 +1,13 @@
"""Utility functions for potnn."""

from .memory import estimate_memory_usage, validate_memory
from .allocation import allocate_hybrid, allocate_layers, allocate_from_model, LayerAllocation

__all__ = [
    'estimate_memory_usage',
    'validate_memory',
    'allocate_hybrid',
    'allocate_layers',
    'allocate_from_model',
    'LayerAllocation',
]
potnn/utils/allocation.py
ADDED
@@ -0,0 +1,240 @@
"""Memory allocation and hybrid unroll/loop decision.

Simple heuristic: Top 20% largest layers use loop mode, rest use unroll mode.
"""

import torch.nn as nn
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass


@dataclass
class LayerAllocation:
    """Allocation decision for a single layer."""
    name: str
    weight_count: int
    mode: str  # 'unroll' or 'loop'
    levels: int  # 11 for both (loop uses 4-bit packing)

    @property
    def estimated_flash(self) -> int:
        """Estimated Flash usage in bytes."""
        if self.mode == 'unroll':
            # Each weight becomes ~4 bytes of code on average
            return self.weight_count * 4
        else:
            # Loop mode: packed weights + loop code
            # 11 levels = 4-bit = 2 weights per byte
            return (self.weight_count + 1) // 2 + 300  # +300 for loop code + decode table

    @property
    def estimated_ram(self) -> int:
        """Estimated RAM usage in bytes (stack only, weights in Flash)."""
        # Both modes store weights in Flash, RAM is just for accumulators
        return 32  # Fixed stack usage per layer
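A worked example of the Flash estimate above (editorial, using only the two formulas in `estimated_flash`): a hypothetical 3×3 convolution with 32 input and 16 output channels has 16 × 32 × 3 × 3 = 4,608 weights.

from potnn.utils.allocation import LayerAllocation

unrolled = LayerAllocation(name="conv2", weight_count=4608, mode="unroll", levels=11)
looped = LayerAllocation(name="conv2", weight_count=4608, mode="loop", levels=11)

print(unrolled.estimated_flash)  # 4608 * 4 = 18432 bytes of generated code
print(looped.estimated_flash)    # (4608 + 1) // 2 + 300 = 2604 bytes (packed)
print(looped.estimated_ram)      # 32 bytes: fixed per-layer stack estimate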
def allocate_layers(
    layer_infos: List[Dict],
    loop_ratio: float = 0.2,
    force_loop: Optional[List[str]] = None,
    force_unroll: Optional[List[str]] = None,
) -> Dict[str, LayerAllocation]:
    """Allocate layers between unroll and loop modes.

    Simple heuristic:
    - Sort layers by weight count (descending)
    - Top `loop_ratio` (default 20%) use loop mode (11 levels, 4-bit packed)
    - Rest use unroll mode (11 levels)

    Args:
        layer_infos: List of layer info dicts with 'name' and weight info
        loop_ratio: Fraction of layers (by count) to use loop mode (0.0-1.0)
        force_loop: Layer names to force into loop mode
        force_unroll: Layer names to force into unroll mode

    Returns:
        Dictionary mapping layer names to allocation decisions
    """
    force_loop = set(force_loop or [])
    force_unroll = set(force_unroll or [])

    # Collect PoT layers with weight counts
    pot_layers = []
    for info in layer_infos:
        if info.get('layer_type') != 'pot':
            continue

        name = info['name']

        # Calculate weight count from weight tensor
        weight = info.get('weight')
        if weight is not None:
            weight_count = weight.numel()
        else:
            # Estimate from layer dimensions
            layer_type = info.get('type', '')
            if 'Conv' in layer_type:
                in_ch = info.get('in_channels', 1)
                out_ch = info.get('out_channels', 1)
                ks = info.get('kernel_size', 3)
                weight_count = out_ch * in_ch * ks * ks
            elif 'Linear' in layer_type:
                weight_count = info.get('in_features', 1) * info.get('out_features', 1)
            else:
                weight_count = 0

        pot_layers.append({
            'name': name,
            'weight_count': weight_count,
            'info': info
        })

    if not pot_layers:
        return {}

    # Sort by weight count (largest first)
    pot_layers.sort(key=lambda x: x['weight_count'], reverse=True)

    # Calculate how many layers should be loop mode
    total_layers = len(pot_layers)
    num_loop = _calculate_loop_count(total_layers, loop_ratio)

    print(f"\nAllocation strategy: {num_loop}/{total_layers} layers will use loop mode")

    # Assign modes
    allocations = {}
    for i, layer in enumerate(pot_layers):
        name = layer['name']
        weight_count = layer['weight_count']

        # Check forced assignments
        if name in force_loop:
            mode = 'loop'
            levels = 11  # 11-level loop (4-bit packing, no accuracy loss)
        elif name in force_unroll:
            mode = 'unroll'
            levels = 11
        elif i < num_loop:
            # Top N largest → loop mode
            mode = 'loop'
            levels = 11  # 11-level loop (4-bit packing, no accuracy loss)
        else:
            # Rest → unroll mode
            mode = 'unroll'
            levels = 11

        allocations[name] = LayerAllocation(
            name=name,
            weight_count=weight_count,
            mode=mode,
            levels=levels
        )

    # Print summary
    _print_allocation_summary(allocations, pot_layers)

    return allocations
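A hand-built usage sketch (editorial); normally `allocate_from_model()`, defined further down, assembles these `layer_infos` dicts from a model's PoT layers:

from potnn.utils.allocation import allocate_layers

layer_infos = [
    {'name': 'conv1', 'layer_type': 'pot', 'type': 'PoTConv2d',
     'in_channels': 1, 'out_channels': 8, 'kernel_size': 3},
    {'name': 'conv2', 'layer_type': 'pot', 'type': 'PoTConv2d',
     'in_channels': 8, 'out_channels': 16, 'kernel_size': 3},
    {'name': 'fc', 'layer_type': 'pot', 'type': 'PoTLinear',
     'in_features': 256, 'out_features': 10},
]

# With no 'weight' tensors present, counts are estimated from the dimensions.
allocations = allocate_layers(layer_infos, loop_ratio=0.2, force_loop=['conv2'])
print(allocations['conv2'].mode)  # 'loop' (forced); 'fc' also loops as the largest layer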
def _calculate_loop_count(total: int, ratio: float) -> int:
    """Calculate number of layers to use loop mode.

    Rules:
    - 1-4 layers: 0 loop (all unroll)
    - 5-9 layers: 1 loop
    - 10-14 layers: 2 loop
    - etc.
    """
    if total < 3:
        return 0

    # Round to nearest integer, minimum 1 if ratio > 0 and total >= 5
    num_loop = round(total * ratio)
    return max(1, num_loop) if ratio > 0 else 0
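An editorial check of the rule above at the default `loop_ratio` of 0.2; note the docstring ranges are approximate, since the code already selects one loop layer for 3-4 layers and rounds to the nearest integer for larger counts:

from potnn.utils.allocation import _calculate_loop_count

for total in (2, 3, 4, 5, 8, 10, 13):
    print(total, _calculate_loop_count(total, 0.2))
# 2 -> 0, 3 -> 1, 4 -> 1, 5 -> 1, 8 -> 2, 10 -> 2, 13 -> 3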
def _print_allocation_summary(
    allocations: Dict[str, LayerAllocation],
    pot_layers: List[Dict]
) -> None:
    """Print allocation summary."""
    print("\nLayer allocation summary:")
    print("-" * 70)
    print(f"{'Layer':<30} {'Weights':>10} {'Mode':<8} {'Levels':>6} {'Flash':>10}")
    print("-" * 70)

    total_flash = 0
    for layer in pot_layers:
        name = layer['name']
        alloc = allocations[name]
        flash = alloc.estimated_flash
        total_flash += flash

        mode_str = f"{alloc.mode}"
        print(f"{name:<30} {alloc.weight_count:>10} {mode_str:<8} {alloc.levels:>6} {flash:>10}")

    print("-" * 70)

    loop_count = sum(1 for a in allocations.values() if a.mode == 'loop')
    unroll_count = sum(1 for a in allocations.values() if a.mode == 'unroll')

    print(f"Total: {len(allocations)} layers ({unroll_count} unroll, {loop_count} loop)")
    print(f"Estimated Flash: {total_flash:,} bytes ({total_flash/1024:.1f} KB)")


def allocate_from_model(
    model: nn.Module,
    loop_ratio: float = 0.2,
) -> Dict[str, LayerAllocation]:
    """Convenience function to allocate directly from PyTorch model.

    Args:
        model: PyTorch model with PoT layers
        loop_ratio: Fraction of layers to use loop mode

    Returns:
        Dictionary mapping layer names to allocation decisions
    """
    from ..modules.base import PoTLayerBase

    # Collect layer infos
    layer_infos = []
    for name, module in model.named_modules():
        if isinstance(module, PoTLayerBase):
            weight_count = module.weight.numel()
            layer_infos.append({
                'name': name,
                'layer_type': 'pot',
                'type': type(module).__name__,
                'weight_count': weight_count,
                'weight': module.weight,
            })

    return allocate_layers(layer_infos, loop_ratio=loop_ratio)


# Backward compatibility alias
def allocate_hybrid(
    model: nn.Module,
    flash_budget: int,
    ram_budget: int,
    input_shape: Tuple = (1, 16, 16),
    loop_ratio: float = 0.2,
) -> Dict[str, LayerAllocation]:
    """Backward-compatible wrapper for allocate_from_model.

    This function is kept for compatibility with existing code.
    New code should use allocate_layers() or allocate_from_model().

    Args:
        model: PyTorch model with PoT layers
        flash_budget: Flash memory budget in bytes (currently ignored)
        ram_budget: RAM budget in bytes (currently ignored)
        input_shape: Input shape (currently ignored)
        loop_ratio: Fraction of layers to use loop mode

    Returns:
        Dictionary mapping layer names to allocation decisions
    """
    return allocate_from_model(model, loop_ratio=loop_ratio)
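Putting the module together (editorial sketch; `model` is assumed to contain PoT layers as elsewhere in the package):

from potnn.utils.allocation import allocate_from_model

allocations = allocate_from_model(model, loop_ratio=0.2)
for name, alloc in allocations.items():
    print(f"{name}: {alloc.mode}, ~{alloc.estimated_flash} bytes of Flash")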