potnn-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
potnn/quantize/qat.py ADDED
@@ -0,0 +1,356 @@
+ """Quantization-aware training utilities for potnn."""
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from typing import Optional
+
+
+ def _compute_raw_alpha(target_alpha: float) -> float:
+     """Compute the raw_alpha value that produces the target alpha after softplus.
+
+     softplus(raw_alpha) = target_alpha
+     raw_alpha = inverse_softplus(target_alpha) = log(exp(target_alpha) - 1)
+
+     Args:
+         target_alpha: Desired alpha value after softplus
+
+     Returns:
+         raw_alpha value to set
+     """
+     if target_alpha > 10:
+         # For large values, softplus(x) ≈ x
+         return target_alpha
+     elif target_alpha > 0.1:
+         # Normal case: exact inverse softplus
+         return math.log(math.exp(target_alpha) - 1)
+     elif target_alpha > 0.01:
+         # Small values: log(exp(x) - 1) ≈ log(x)
+         return math.log(target_alpha)
+     else:
+         # Very small: clamp to a floor
+         return -4.0  # softplus(-4) ≈ 0.018
+
+
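As a quick sanity check of the inverse-softplus math above (an illustrative snippet, not part of the package; the numbers are examples only):

    import math
    import torch
    import torch.nn.functional as F

    target = 0.05
    raw = math.log(math.exp(target) - 1)     # exact inverse: ~ -2.970
    approx = math.log(target)                # small-value branch: ~ -2.996
    F.softplus(torch.tensor(raw)).item()     # recovers ~ 0.0500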
+ def prepare_qat(model: nn.Module, config=None, force_alpha_init: bool = True):
+     """Prepare model for quantization-aware training.
+
+     This function:
+     1. Fuses BatchNorm layers into preceding Conv layers (CRITICAL!)
+     2. Sets encoding for each layer based on config.layer_encodings
+     3. Initializes alpha values based on weight statistics (if not calibrated)
+     4. Calculates combined scale factors for each layer
+     5. Enables quantization mode for all PoT layers
+
+     Args:
+         model: Model to prepare (should be wrapped with potnn)
+         config: potnn Config with layer_encodings (optional)
+         force_alpha_init: If True, always initialize alpha from weight stats.
+             If False, only initialize if alpha seems uncalibrated (≈0.974).
+             Default True for safety.
+     """
+     from ..modules import PoTLinear, PoTConv2d, PoTDepthwiseConv2d
+     from ..modules.base import PoTLayerBase
+     from ..fuse import fuse_batchnorm, check_bn_fused
+
+     # ========================================
+     # Step 0: Fuse BatchNorm layers BEFORE QAT
+     # This is CRITICAL! BN must be fused before PoT weight training.
+     # ========================================
+     fuse_batchnorm(model)
+
+     # Replace AdaptiveAvgPool2d with PoTGlobalAvgPool BEFORE collecting layers,
+     # so the pooling layer is included in the scale-chain ordering below.
+     from ..modules.avgpool import PoTGlobalAvgPool
+     # Collect the targets first, then replace, to avoid mutating the module
+     # tree while iterating over named_modules().
+     replacements = []
+     for name, module in model.named_modules():
+         if isinstance(module, nn.AdaptiveAvgPool2d):
+             if module.output_size == 1 or module.output_size == (1, 1):
+                 replacements.append((name, module))
+
+     for name, module in replacements:
+         parts = name.split('.')
+         if len(parts) > 1:
+             parent = model.get_submodule('.'.join(parts[:-1]))
+             child_name = parts[-1]
+         else:
+             parent = model
+             child_name = name
+
+         setattr(parent, child_name, PoTGlobalAvgPool())
+         print(f" {name}: Replaced AdaptiveAvgPool2d with PoTGlobalAvgPool")
+
+
+ print("Preparing model for QAT...")
87
+
88
+ # Collect all PoT layers
89
+ pot_layers = []
90
+ for name, module in model.named_modules():
91
+ if isinstance(module, PoTLayerBase):
92
+ pot_layers.append((name, module))
93
+
94
+ # ========================================
95
+ # Step 0.5: Set encoding for each layer
96
+ # ========================================
97
+ if config is not None:
98
+ for name, module in pot_layers:
99
+ encoding = config.get_encoding(name)
100
+ module.encoding = encoding
101
+ print(f" {name}: encoding={encoding}")
102
+
103
+ # Initialize alpha values based on weight statistics
104
+ # This is a safety net if calibration was not called
105
+ for name, module in pot_layers:
106
+ if not hasattr(module, 'weight'):
107
+ continue
108
+
109
+ with torch.no_grad():
110
+ current_alpha = module.alpha.item()
111
+ w_std = module.weight.std().item()
112
+
113
+ # Check if alpha needs initialization
114
+ # Default alpha ≈ 0.974 (from softplus(0.5))
115
+ needs_init = force_alpha_init or (0.9 < current_alpha < 1.1)
116
+
117
+ if needs_init and w_std > 0.001:
118
+ # Use weight std as target alpha
119
+ target_alpha = w_std
120
+ target_alpha = max(target_alpha, 0.01) # Minimum
121
+
122
+ # Compute raw_alpha for this target
123
+ raw = _compute_raw_alpha(target_alpha)
124
+ module.raw_alpha.data.fill_(raw)
125
+
126
+ # Update alpha_init for regularization
127
+ actual_alpha = module.alpha.item()
128
+ module.alpha_init.data.fill_(actual_alpha)
129
+
130
+ print(f" {name}: alpha initialized {current_alpha:.4f} → {actual_alpha:.4f} (w_std={w_std:.4f})")
131
+ else:
132
+ print(f" {name}: alpha kept at {current_alpha:.4f} (w_std={w_std:.4f})")
133
+
134
+     # Calculate combined scale factors and set prev_act_scale
+     # These are used for both Integer-Only QAT and C export
+     prev_act_scale = 1.0
+
+     # Iterate over all modules in order to propagate prev_act_scale correctly
+     # We need to handle both PoTLayerBase and PoTGlobalAvgPool
+     from ..modules.avgpool import PoTGlobalAvgPool
+
+     # Collect relevant layers in order
+     ordered_layers = []
+     for name, module in model.named_modules():
+         if isinstance(module, (PoTLayerBase, PoTGlobalAvgPool)):
+             ordered_layers.append((name, module))
+
+     for i, (name, module) in enumerate(ordered_layers):
+         is_first = (i == 0)
+         is_last = (i == len(ordered_layers) - 1)
+
+         if isinstance(module, PoTLayerBase):
+             module.set_layer_position(is_first, is_last)
+
+             # Set previous layer's act_scale (critical for Integer QAT)
+             if isinstance(prev_act_scale, torch.Tensor):
+                 module.set_prev_act_scale(prev_act_scale.item())
+             else:
+                 module.set_prev_act_scale(prev_act_scale)
+
+             # Also set combined_scale_factor for export (legacy but useful)
+             scale_val = prev_act_scale.item() if isinstance(prev_act_scale, torch.Tensor) else prev_act_scale
+             module.combined_scale_factor = 1.0 / scale_val
+
+             # Update prev_act_scale for next layer
+             if module.act_scale is not None:
+                 prev_act_scale = module.act_scale
+
+         elif isinstance(module, PoTGlobalAvgPool):
+             # Set GAP's input scale
+             scale_val = prev_act_scale.item() if isinstance(prev_act_scale, torch.Tensor) else prev_act_scale
+             module.prepare_qat(act_scale=scale_val)
+
+             # Averaging does not change the quantization scale: the mean of
+             # scaled integers is still expressed in the same scale, so
+             # prev_act_scale stays valid for the next layer.
+
+             print(f" {name}: GAP prepared for QAT (act_scale={scale_val:.4f})")
+
+     # Enable quantization mode
+     for name, module in pot_layers:
+         module.quantize = True
+
+     # Also enable quantization for PoTAdd layers
+     from ..modules.add import PoTAdd
+     add_count = 0
+     for name, module in model.named_modules():
+         if isinstance(module, PoTAdd):
+             module.quantize = True
+             add_count += 1
+
+     print(f"QAT mode enabled for {len(pot_layers)} PoT layers and {add_count} Add layers.")
+
+
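A minimal sketch of how prepare_qat would typically be invoked (model and config are placeholders; only names defined in this file are used):

    model = ...                      # any potnn-wrapped, calibrated model
    prepare_qat(model, config=None)  # fuse BN, swap in PoTGlobalAvgPool, init alpha,
                                     # propagate prev_act_scale, set quantize=True
    # Every PoT layer now trains against its quantized weights; the recorded
    # scale chain is reused later by enable_integer_sim() and the C export.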
+ def alpha_reg_loss(model: nn.Module, lambda_reg: float = 0.01) -> torch.Tensor:
+     """Calculate alpha regularization loss.
+
+     This prevents alpha values from drifting too far from their initial values,
+     which helps stabilize training and prevent weight collapse.
+
+     Args:
+         model: Model with PoT layers
+         lambda_reg: Regularization strength
+
+     Returns:
+         Regularization loss term to add to the main loss
+     """
+     from ..modules import PoTLinear, PoTConv2d, PoTConv1d
+
+     reg_loss = torch.tensor(0.0, device=next(model.parameters()).device)
+
+     for module in model.modules():
+         if isinstance(module, (PoTLinear, PoTConv2d, PoTConv1d)):
+             # Keep the effective alpha (softplus of the trainable raw_alpha)
+             # close to the alpha_init recorded at preparation time.
+             alpha = F.softplus(module.raw_alpha).clamp(min=0.01)
+             reg_loss += (alpha - module.alpha_init) ** 2
+
+     return lambda_reg * reg_loss
+
+
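A hedged sketch of folding the regularizer into an ordinary training step (optimizer, criterion, and train_loader are generic PyTorch placeholders, not potnn APIs):

    import torch
    import torch.nn as nn

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss = loss + alpha_reg_loss(model, lambda_reg=0.01)  # pulls alpha toward alpha_init
        loss.backward()
        optimizer.step()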
+ def enable_integer_sim(model: nn.Module, input_std=1.0, input_mean=0.0, verbose: bool = True):
+     """Enable integer simulation mode for C-compatible inference.
+
+     This function sets up all PoT layers for integer simulation that matches
+     C inference bit-for-bit. Call this after calibration and before QAT training
+     for best results.
+
+     The integer simulation ensures:
+     - Python forward pass uses same integer arithmetic as C
+     - Eliminates QAT-to-C accuracy gap
+     - Proper scale chain propagation between layers
+
+     Args:
+         model: Model with PoT layers (should already be calibrated)
+         input_std: Input standard deviation for first layer.
+             Can be float, List[float], or Config object.
+         input_mean: Input mean for first layer.
+             Can be float or List[float].
+         verbose: Print debug info (default True)
+
+     Example:
+         # After calibration, before QAT training:
+         prepare_qat(model)
+         enable_integer_sim(model, input_std=0.3081, input_mean=0.1307)
+
+         # Or after training, before export:
+         enable_integer_sim(model, input_std=config.std, input_mean=config.mean)
+     """
+     from ..modules.base import PoTLayerBase
+     from ..modules.add import PoTAdd
+     from ..config import Config
+
+     # Handle Config object passed as input_std (common mistake)
+     if isinstance(input_std, Config):
+         config = input_std
+         input_std = config.std
+         input_mean = config.mean if hasattr(config, 'mean') else input_mean
+
+     # Convert list to average (for multi-channel, use avg_std like the C code)
+     if isinstance(input_std, (list, tuple)):
+         input_std = sum(input_std) / len(input_std)
+     if isinstance(input_mean, (list, tuple)):
+         input_mean = sum(input_mean) / len(input_mean)
+
+     if verbose:
+         print("="*60)
+         print("Enabling integer simulation mode...")
+         print(f" input_mean={input_mean}, input_std={input_std}")
+         print("="*60)
+
+     # Collect all PoT layers in order
+     pot_layers = []
+     for name, module in model.named_modules():
+         if isinstance(module, PoTLayerBase) and not isinstance(module, PoTAdd):
+             pot_layers.append((name, module))
+
+     if len(pot_layers) == 0:
+         print(" Warning: No PoT layers found!")
+         return
+
+     # Set up each layer
+     prev_act_scale = 1.0
+
+     for i, (name, layer) in enumerate(pot_layers):
+         is_first = (i == 0)
+         is_last = (i == len(pot_layers) - 1)
+
+         # Set layer position
+         layer.set_layer_position(is_first, is_last)
+
+         # Set previous layer's act_scale
+         layer.set_prev_act_scale(prev_act_scale)
+
+         # Set input stats for first layer
+         if is_first:
+             layer.set_input_std(input_std, input_mean)
+             if verbose:
+                 print(f" [DEBUG] First layer: mean={input_mean}, std={input_std}")
+
+         # Compute integer parameters
+         scale_int, shift = layer.compute_integer_params()
+
+         # Enable integer simulation
+         layer.use_integer_sim = True
+
+         if verbose:
+             act_scale = layer.act_scale.item() if layer.act_scale is not None else None
+             alpha = layer.alpha.item()
+             print(f" {name}: first={is_first}, last={is_last}, "
+                   f"prev_scale={prev_act_scale:.4f}, act_scale={act_scale}, "
+                   f"alpha={alpha:.4f}, scale_int={scale_int}, shift={shift}")
+
+         # Update prev_act_scale for next layer
+         if layer.act_scale is not None:
+             prev_act_scale = layer.act_scale.item()
+
+     # Enable integer sim for PoTGlobalAvgPool layers
+     from ..modules.avgpool import PoTGlobalAvgPool
+     for name, module in model.named_modules():
+         if isinstance(module, PoTGlobalAvgPool):
+             module.integer_sim_enabled = True
+             if verbose:
+                 print(f" {name}: GAP integer sim enabled")
+
+     if verbose:
+         print(f"Integer simulation enabled for {len(pot_layers)} layers.")
+         print("="*60)
+
+
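When per-channel normalization statistics are available, they are averaged into a single scalar before use, mirroring the C side; for example (the three-channel values below are placeholders):

    enable_integer_sim(model,
                       input_std=[0.229, 0.224, 0.225],    # averaged to ~0.226
                       input_mean=[0.485, 0.456, 0.406])   # averaged to ~0.449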
+ def disable_integer_sim(model: nn.Module):
+     """Disable integer simulation mode, reverting to float QAT.
+
+     Args:
+         model: Model with PoT layers
+     """
+     from ..modules.base import PoTLayerBase
+     from ..modules.avgpool import PoTGlobalAvgPool
+
+     for module in model.modules():
+         if isinstance(module, PoTLayerBase):
+             module.use_integer_sim = False
+         elif isinstance(module, PoTGlobalAvgPool):
+             module.integer_sim_enabled = False
+         elif isinstance(module, nn.AdaptiveAvgPool2d):
+             if hasattr(module, '_original_forward'):
+                 module.forward = module._original_forward
+                 del module._original_forward
+
+     print("Integer simulation disabled.")
@@ -0,0 +1,13 @@
+ """Utility functions for potnn."""
+
+ from .memory import estimate_memory_usage, validate_memory
+ from .allocation import allocate_hybrid, allocate_layers, allocate_from_model, LayerAllocation
+
+ __all__ = [
+     'estimate_memory_usage',
+     'validate_memory',
+     'allocate_hybrid',
+     'allocate_layers',
+     'allocate_from_model',
+     'LayerAllocation',
+ ]
@@ -0,0 +1,240 @@
+ """Memory allocation and hybrid unroll/loop decision.
+
+ Simple heuristic: Top 20% largest layers use loop mode, rest use unroll mode.
+ """
+
+ import torch.nn as nn
+ from typing import Dict, List, Tuple, Optional
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class LayerAllocation:
+     """Allocation decision for a single layer."""
+     name: str
+     weight_count: int
+     mode: str    # 'unroll' or 'loop'
+     levels: int  # 11 for both (loop uses 4-bit packing)
+
+     @property
+     def estimated_flash(self) -> int:
+         """Estimated Flash usage in bytes."""
+         if self.mode == 'unroll':
+             # Each weight becomes ~4 bytes of code on average
+             return self.weight_count * 4
+         else:
+             # Loop mode: packed weights + loop code
+             # 11 levels = 4-bit = 2 weights per byte
+             return (self.weight_count + 1) // 2 + 300  # +300 for loop code + decode table
+
+     @property
+     def estimated_ram(self) -> int:
+         """Estimated RAM usage in bytes (stack only, weights in Flash)."""
+         # Both modes store weights in Flash, RAM is just for accumulators
+         return 32  # Fixed stack usage per layer
+
+
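Plugging numbers into the two Flash estimates above, for a hypothetical layer with 4,608 weights:

    LayerAllocation('conv2', 4608, 'unroll', 11).estimated_flash  # 4608 * 4 = 18432 bytes
    LayerAllocation('conv2', 4608, 'loop', 11).estimated_flash    # (4608 + 1) // 2 + 300 = 2604 bytes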
+ def allocate_layers(
+     layer_infos: List[Dict],
+     loop_ratio: float = 0.2,
+     force_loop: Optional[List[str]] = None,
+     force_unroll: Optional[List[str]] = None,
+ ) -> Dict[str, LayerAllocation]:
+     """Allocate layers between unroll and loop modes.
+
+     Simple heuristic:
+     - Sort layers by weight count (descending)
+     - Top `loop_ratio` (default 20%) use loop mode (11 levels, 4-bit packed)
+     - Rest use unroll mode (11 levels)
+
+     Args:
+         layer_infos: List of layer info dicts with 'name' and weight info
+         loop_ratio: Fraction of layers (by count) to use loop mode (0.0-1.0)
+         force_loop: Layer names to force into loop mode
+         force_unroll: Layer names to force into unroll mode
+
+     Returns:
+         Dictionary mapping layer names to allocation decisions
+     """
+     force_loop = set(force_loop or [])
+     force_unroll = set(force_unroll or [])
+
+     # Collect PoT layers with weight counts
+     pot_layers = []
+     for info in layer_infos:
+         if info.get('layer_type') != 'pot':
+             continue
+
+         name = info['name']
+
+         # Calculate weight count from weight tensor
+         weight = info.get('weight')
+         if weight is not None:
+             weight_count = weight.numel()
+         else:
+             # Estimate from layer dimensions
+             layer_type = info.get('type', '')
+             if 'Conv' in layer_type:
+                 in_ch = info.get('in_channels', 1)
+                 out_ch = info.get('out_channels', 1)
+                 ks = info.get('kernel_size', 3)
+                 weight_count = out_ch * in_ch * ks * ks
+             elif 'Linear' in layer_type:
+                 weight_count = info.get('in_features', 1) * info.get('out_features', 1)
+             else:
+                 weight_count = 0
+
+         pot_layers.append({
+             'name': name,
+             'weight_count': weight_count,
+             'info': info
+         })
+
+     if not pot_layers:
+         return {}
+
+     # Sort by weight count (largest first)
+     pot_layers.sort(key=lambda x: x['weight_count'], reverse=True)
+
+     # Calculate how many layers should be loop mode
+     total_layers = len(pot_layers)
+     num_loop = _calculate_loop_count(total_layers, loop_ratio)
+
+     print(f"\nAllocation strategy: {num_loop}/{total_layers} layers will use loop mode")
+
+     # Assign modes
+     allocations = {}
+     for i, layer in enumerate(pot_layers):
+         name = layer['name']
+         weight_count = layer['weight_count']
+
+         # Check forced assignments
+         if name in force_loop:
+             mode = 'loop'
+             levels = 11  # 11-level loop (4-bit packing, no accuracy loss)
+         elif name in force_unroll:
+             mode = 'unroll'
+             levels = 11
+         elif i < num_loop:
+             # Top N largest → loop mode
+             mode = 'loop'
+             levels = 11  # 11-level loop (4-bit packing, no accuracy loss)
+         else:
+             # Rest → unroll mode
+             mode = 'unroll'
+             levels = 11
+
+         allocations[name] = LayerAllocation(
+             name=name,
+             weight_count=weight_count,
+             mode=mode,
+             levels=levels
+         )
+
+     # Print summary
+     _print_allocation_summary(allocations, pot_layers)
+
+     return allocations
+
+
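An illustrative call with two made-up entries (the dict keys mirror the lookups in allocate_layers; real callers usually pass a 'weight' tensor instead of the dimension fields):

    layer_infos = [
        {'name': 'features.0', 'layer_type': 'pot', 'type': 'PoTConv2d',
         'in_channels': 1, 'out_channels': 8, 'kernel_size': 3},   # 72 weights
        {'name': 'classifier', 'layer_type': 'pot', 'type': 'PoTLinear',
         'in_features': 128, 'out_features': 10},                  # 1280 weights
    ]
    allocations = allocate_layers(layer_infos, loop_ratio=0.2)
    # With only 2 layers, _calculate_loop_count(2, 0.2) == 0, so both use unroll mode.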
+ def _calculate_loop_count(total: int, ratio: float) -> int:
+     """Calculate the number of layers to use loop mode.
+
+     Rules:
+     - Fewer than 3 layers: 0 loop (all unroll)
+     - Otherwise round(total * ratio), with a minimum of 1 when ratio > 0
+       (with the default ratio of 0.2: 3-7 layers → 1 loop, 8-12 layers → 2 loop, ...)
+     """
+     if total < 3:
+         return 0
+
+     # Round to nearest integer, minimum 1 if ratio > 0
+     num_loop = round(total * ratio)
+     return max(1, num_loop) if ratio > 0 else 0
+
+
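With the default ratio of 0.2 the helper works out to the following (values follow directly from the rule above):

    [_calculate_loop_count(n, 0.2) for n in (2, 3, 7, 8, 12, 13)]
    # -> [0, 1, 1, 2, 2, 3]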
+ def _print_allocation_summary(
+     allocations: Dict[str, LayerAllocation],
+     pot_layers: List[Dict]
+ ) -> None:
+     """Print allocation summary."""
+     print("\nLayer allocation summary:")
+     print("-" * 70)
+     print(f"{'Layer':<30} {'Weights':>10} {'Mode':<8} {'Levels':>6} {'Flash':>10}")
+     print("-" * 70)
+
+     total_flash = 0
+     for layer in pot_layers:
+         name = layer['name']
+         alloc = allocations[name]
+         flash = alloc.estimated_flash
+         total_flash += flash
+
+         mode_str = f"{alloc.mode}"
+         print(f"{name:<30} {alloc.weight_count:>10} {mode_str:<8} {alloc.levels:>6} {flash:>10}")
+
+     print("-" * 70)
+
+     loop_count = sum(1 for a in allocations.values() if a.mode == 'loop')
+     unroll_count = sum(1 for a in allocations.values() if a.mode == 'unroll')
+
+     print(f"Total: {len(allocations)} layers ({unroll_count} unroll, {loop_count} loop)")
+     print(f"Estimated Flash: {total_flash:,} bytes ({total_flash/1024:.1f} KB)")
+
+
+ def allocate_from_model(
+     model: nn.Module,
+     loop_ratio: float = 0.2,
+ ) -> Dict[str, LayerAllocation]:
+     """Convenience function to allocate directly from PyTorch model.
+
+     Args:
+         model: PyTorch model with PoT layers
+         loop_ratio: Fraction of layers to use loop mode
+
+     Returns:
+         Dictionary mapping layer names to allocation decisions
+     """
+     from ..modules.base import PoTLayerBase
+
+     # Collect layer infos
+     layer_infos = []
+     for name, module in model.named_modules():
+         if isinstance(module, PoTLayerBase):
+             weight_count = module.weight.numel()
+             layer_infos.append({
+                 'name': name,
+                 'layer_type': 'pot',
+                 'type': type(module).__name__,
+                 'weight_count': weight_count,
+                 'weight': module.weight,
+             })
+
+     return allocate_layers(layer_infos, loop_ratio=loop_ratio)
+
+
+ # Backward-compatibility wrapper
+ def allocate_hybrid(
+     model: nn.Module,
+     flash_budget: int,
+     ram_budget: int,
+     input_shape: Tuple = (1, 16, 16),
+     loop_ratio: float = 0.2,
+ ) -> Dict[str, LayerAllocation]:
+     """Backward-compatible wrapper for allocate_from_model.
+
+     This function is kept for compatibility with existing code.
+     New code should use allocate_layers() or allocate_from_model().
+
+     Args:
+         model: PyTorch model with PoT layers
+         flash_budget: Flash memory budget in bytes (currently ignored)
+         ram_budget: RAM budget in bytes (currently ignored)
+         input_shape: Input shape (currently ignored)
+         loop_ratio: Fraction of layers to use loop mode
+
+     Returns:
+         Dictionary mapping layer names to allocation decisions
+     """
+     return allocate_from_model(model, loop_ratio=loop_ratio)
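For reference, a short usage sketch of the model-level entry points (model is assumed to contain PoT layers; the printed fields come from LayerAllocation above):

    allocations = allocate_from_model(model, loop_ratio=0.2)
    for name, alloc in allocations.items():
        print(name, alloc.mode, alloc.levels, alloc.estimated_flash)

    # allocate_hybrid() forwards to the same logic; its budget arguments are currently ignored.
    allocations = allocate_hybrid(model, flash_budget=32 * 1024, ram_budget=2 * 1024)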