potnn 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
potnn/utils/memory.py ADDED
@@ -0,0 +1,158 @@
+ """Memory estimation and validation utilities."""
+
+ import torch
+ import torch.nn as nn
+ from typing import Any, Dict, Tuple
+
+
+ def estimate_layer_size(module: nn.Module) -> int:
+     """Estimate the size of a layer in bytes.
+
+     Args:
+         module: Neural network module
+
+     Returns:
+         Estimated size in bytes (one byte per parameter, assuming int8)
+     """
+     param_count = 0
+
+     # Count weight parameters
+     if hasattr(module, 'weight') and module.weight is not None:
+         param_count += module.weight.numel()
+
+     # Count bias parameters
+     if hasattr(module, 'bias') and module.bias is not None:
+         param_count += module.bias.numel()
+
+     return param_count
+
+
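+ # Example: estimate_layer_size(nn.Conv2d(1, 16, 3)) counts 16*1*3*3 = 144
+ # weights plus 16 biases, i.e. 160 parameters = 160 bytes at int8.
+
+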
+ def estimate_activation_size(model: nn.Module, input_shape: Tuple) -> int:
+     """Estimate maximum activation buffer size needed.
+
+     Args:
+         model: Neural network model
+         input_shape: Shape of input tensor (without batch dimension)
+
+     Returns:
+         Maximum activation size in bytes
+     """
+     # Create dummy input
+     dummy_input = torch.zeros(1, *input_shape)
+     device = next(model.parameters()).device
+     dummy_input = dummy_input.to(device)
+
+     max_size = 0
+     hooks = []
+
+     def hook_fn(module, input, output):
+         nonlocal max_size
+         if isinstance(output, torch.Tensor):
+             size = output.numel()  # int8 = 1 byte per element
+             max_size = max(max_size, size)
+
+     # Register hooks
+     for module in model.modules():
+         if isinstance(module, (nn.Conv2d, nn.Linear, nn.MaxPool2d)):
+             hooks.append(module.register_forward_hook(hook_fn))
+
+     # Run forward pass
+     model.eval()
+     with torch.no_grad():
+         model(dummy_input)
+
+     # Clean up hooks
+     for hook in hooks:
+         hook.remove()
+
+     return max_size
+
+
+ def estimate_memory_usage(model: nn.Module, input_shape: Tuple, mode: str = 'all') -> Dict[str, Any]:
+     """Estimate memory usage of the model.
+
+     Args:
+         model: Neural network model
+         input_shape: Shape of input tensor (without batch dimension)
+         mode: 'all', 'weights', or 'activations'
+
+     Returns:
+         Dictionary with memory estimates in bytes
+     """
+     result = {}
+
+     if mode in ['all', 'weights']:
+         # Estimate weight memory
+         total_weights = 0
+         layer_weights = {}
+
+         for name, module in model.named_modules():
+             if isinstance(module, (nn.Conv2d, nn.Linear)):
+                 size = estimate_layer_size(module)
+                 layer_weights[name] = size
+                 total_weights += size
+
+         result['total_weights'] = total_weights
+         result['layer_weights'] = layer_weights
+
+     if mode in ['all', 'activations']:
+         # Estimate activation memory
+         result['max_activation'] = estimate_activation_size(model, input_shape)
+
+     if mode == 'all':
+         # Input buffer size
+         input_size = torch.zeros(1, *input_shape).numel()
+         result['input_buffer'] = input_size
+
+         # Total RAM needed (input + activations + some weights in loop mode)
+         result['estimated_ram'] = result['input_buffer'] + result['max_activation']
+
+         # Total Flash needed (mainly for unrolled weights as code)
+         # This is a rough estimate - actual size depends on unroll/loop decisions
+         result['estimated_flash'] = total_weights * 4  # Rough estimate for unrolled code
+
+     return result
+
+
+ def validate_memory(model: nn.Module, flash_budget: int, ram_budget: int,
+                     input_shape: Tuple = (1, 16, 16)) -> bool:
+     """Validate if model fits within memory budgets.
+
+     Args:
+         model: Neural network model
+         flash_budget: Flash memory budget in bytes
+         ram_budget: RAM budget in bytes
+         input_shape: Input tensor shape (default for 16x16 grayscale)
+
+     Returns:
+         True if the model fits
+
+     Raises:
+         ValueError: If the model does not fit, with details in the message
+     """
+     estimates = estimate_memory_usage(model, input_shape)
+
+     # Check RAM budget
+     min_ram_needed = estimates['input_buffer'] + estimates['max_activation']
+     if min_ram_needed > ram_budget:
+         raise ValueError(
+             f"Model requires at least {min_ram_needed} bytes of RAM "
+             f"(input: {estimates['input_buffer']}, activation: {estimates['max_activation']}), "
+             f"but only {ram_budget} bytes available."
+         )
+
+     # Check if we can fit weights either in Flash (unrolled) or RAM (loop)
+     # This is a simplified check - actual allocation is done by allocate_hybrid
+     # For unrolled code, estimate ~4 bytes per weight
+     # For loop with packing, estimate ~0.25 bytes per weight (2-bit packing)
+     unrolled_size = estimates['total_weights'] * 4
+     packed_size = estimates['total_weights'] // 4  # 2-bit packing
+
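+     # Worked example: 10,000 weights → unrolled ≈ 40,000 bytes of Flash;
+     # 2-bit packed (four weights per byte) = 2,500 bytes of RAM on top of
+     # min_ram_needed.
+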
+     if unrolled_size > flash_budget and packed_size > (ram_budget - min_ram_needed):
+         raise ValueError(
+             f"Model weights ({estimates['total_weights']} parameters) too large. "
+             f"Unrolled: {unrolled_size} bytes > Flash {flash_budget} bytes. "
+             f"Packed: {packed_size} bytes > Available RAM {ram_budget - min_ram_needed} bytes."
+         )
+
+     return True
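+
+
+ # Illustrative usage (a sketch; the budgets and model below are made up for
+ # illustration). The estimators only inspect nn.Conv2d/nn.Linear layers, so a
+ # plain torch.nn model is enough to exercise them:
+ #
+ #     model = nn.Sequential(
+ #         nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Flatten(),
+ #         nn.Linear(8 * 14 * 14, 10),
+ #     )
+ #     est = estimate_memory_usage(model, input_shape=(1, 16, 16))
+ #     # est['total_weights'], est['max_activation'], est['estimated_ram'], ...
+ #     validate_memory(model, flash_budget=128 * 1024, ram_budget=4 * 1024)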
potnn/wrapper.py ADDED
@@ -0,0 +1,304 @@
+ """Model wrapper for potnn conversion."""
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Callable, List, Optional, Union
+ from torch.utils.data import DataLoader
+
+ from .config import Config
+ from .modules import PoTLinear, PoTConv2d
+ from .quantize.calibration import calibrate_model
+ from .quantize.qat import prepare_qat, alpha_reg_loss, enable_integer_sim
+ from .utils import validate_memory, allocate_hybrid
+ from .fuse import fuse_batchnorm
+
+
+ def _normalize_data(data: torch.Tensor,
+                     mean: Union[float, List[float]],
+                     std: Union[float, List[float]]) -> torch.Tensor:
+     """Normalize input data to match C inference exactly.
+
+     C inference uses:
+     - /256 (via shift+8) instead of /255
+     - avg_std (a single value) instead of per-channel std
+
+     QAT must match this for training = deployment consistency.
+
+     Args:
+         data: Input tensor [N, C, H, W] or [N, C, L], expected range [0, 1] from torchvision
+         mean: Dataset mean (float or per-channel List[float])
+         std: Dataset std (float or per-channel List[float])
+
+     Returns:
+         Normalized tensor matching C inference behavior
+     """
+     # Calculate average std (C uses a single scale value)
+     if isinstance(std, (list, tuple)):
+         avg_std = sum(std) / len(std)
+     else:
+         avg_std = std
+
+     # Convert mean to a tensor with the proper shape for broadcasting:
+     # [1, C, 1] for 3D input (Conv1d), [1, C, 1, 1] for 4D input (Conv2d)
+     if isinstance(mean, (list, tuple)):
+         mean = torch.tensor(mean, dtype=data.dtype, device=data.device)
+         if data.dim() == 3:  # Conv1d: (B, C, L)
+             mean = mean.view(1, -1, 1)
+         else:  # Conv2d: (B, C, H, W)
+             mean = mean.view(1, -1, 1, 1)
+
+     # Match C inference:
+     # - data comes as [0,1] from torchvision (raw/255)
+     # - C uses raw/256, so multiply by 256/255 to compensate
+     # - C uses avg_std in scale, so divide by avg_std
+     return (data * (256.0 / 255.0) - mean) / avg_std
+
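+ # Worked example (values assumed for illustration, MNIST-style mean=0.1307,
+ # std=0.3081): a white pixel arrives from torchvision as 1.0, is rescaled to
+ # 256/255 ≈ 1.0039, and normalizes to (1.0039 - 0.1307) / 0.3081 ≈ 2.834.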
+
+ def _validate_model(model: nn.Module) -> None:
+     """Check the model for plain nn.Conv2d / nn.Linear layers.
+
+     potnn only supports PoTConv2d and PoTLinear. If plain layers are mixed
+     in, export will fail, so warn up front.
+     """
+     from .modules.conv import PoTConv2d
+     from .modules.depthwise import PoTDepthwiseConv2d
+     from .modules.linear import PoTLinear
+
+     errors = []
+
+     for name, module in model.named_modules():
+         # nn.Conv2d that is not a PoTConv2d/PoTDepthwiseConv2d
+         if isinstance(module, nn.Conv2d) and not isinstance(module, (PoTConv2d, PoTDepthwiseConv2d)):
+             errors.append(f" - {name}: replace nn.Conv2d with potnn.PoTConv2d")
+
+         # nn.Linear that is not a PoTLinear
+         if isinstance(module, nn.Linear) and not isinstance(module, PoTLinear):
+             errors.append(f" - {name}: replace nn.Linear with potnn.PoTLinear")
+
+     if errors:
+         error_msg = "\n".join(errors)
+         raise ValueError(
+             f"potnn supports PoT layers only. Replace the following layers:\n{error_msg}\n\n"
+             f"Example:\n"
+             f" nn.Conv2d(1, 16, 3) → potnn.PoTConv2d(1, 16, 3)\n"
+             f" nn.Linear(256, 10) → potnn.PoTLinear(256, 10)"
+         )
+
+
+ def train(model: nn.Module,
+           train_loader: DataLoader,
+           test_loader: DataLoader,
+           config: Config,
+           float_epochs: int = 15,
+           qat_epochs: int = 50,
+           float_lr: float = 1e-3,
+           qat_lr: float = 1e-4,
+           device: str = 'cuda',
+           fuse_bn: bool = True,
+           verbose: bool = True) -> nn.Module:
+     """Complete training pipeline: Float → (BN Fusion) → Calibration → QAT → Integer Sim
+
+     Args:
+         model: PoT model (must use PoTConv2d, PoTLinear)
+         train_loader: Training data loader (raw [0,1] input, NO Normalize transform needed)
+         test_loader: Test data loader (raw [0,1] input, NO Normalize transform needed)
+         config: potnn Config (mean/std used for automatic normalization)
+         float_epochs: Float training epochs (default: 15)
+         qat_epochs: QAT training epochs (default: 50)
+         float_lr: Float training learning rate (default: 1e-3)
+         qat_lr: QAT training learning rate (default: 1e-4)
+         device: 'cuda' or 'cpu' (default: 'cuda')
+         fuse_bn: Fuse BatchNorm layers after float training (default: True)
+         verbose: Print progress (default: True)
+
+     Returns:
+         Trained model with Integer Simulation enabled.
+         - model.train(): uses Float QAT for fine-tuning
+         - model.eval(): uses Integer Simulation (matches C exactly)
+
+     Note:
+         Input normalization is handled automatically using config.mean/std.
+         Do NOT add transforms.Normalize() to your DataLoader.
+     """
+     # Validate the model: fail fast if plain nn.Conv2d / nn.Linear layers are used
+     _validate_model(model)
+
+     model = model.to(device)
+
+     # Get normalization params from config
+     mean = config.mean if config.mean is not None else 0.0
+     std = config.std if config.std is not None else 1.0
+
+     # Phase 1: Float Training
+     if verbose:
+         print(f"\n[Phase 1] Float Training ({float_epochs} epochs)...")
+
+     optimizer = torch.optim.AdamW(model.parameters(), lr=float_lr)
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float_epochs)
+
+     best_float_acc = 0
+     for epoch in range(float_epochs):
+         model.train()
+         for data, target in train_loader:
+             data, target = data.to(device), target.to(device)
+             data = _normalize_data(data, mean, std)  # Auto normalize
+             optimizer.zero_grad()
+             output = model(data)
+             loss = F.cross_entropy(output, target)
+             loss.backward()
+             optimizer.step()
+
+         scheduler.step()
+         acc = _evaluate(model, test_loader, device, mean, std)
+         best_float_acc = max(best_float_acc, acc)
+
+         if verbose and (epoch % 5 == 0 or epoch == float_epochs - 1):
+             print(f" Epoch {epoch+1}/{float_epochs}: {acc:.2f}%")
+
+     if verbose:
+         print(f" Best Float: {best_float_acc:.2f}%")
+
+     # Phase 1.5: BatchNorm Fusion (optional)
+     if fuse_bn:
+         if verbose:
+             print(f"\n[Phase 1.5] BatchNorm Fusion...")
+         model = fuse_batchnorm(model)
+
+     # Phase 2: Calibration
+     if verbose:
+         print(f"\n[Phase 2] Calibration...")
+
+     calibrate_model(model, train_loader, mean=mean, std=std)
+
+     # Phase 3: QAT Preparation
+     if verbose:
+         print(f"\n[Phase 3] Preparing QAT...")
+
+     prepare_qat(model, config)
+
+     # Set up first-layer info for mean absorption during QAT
+     # This ensures QAT uses the same bias as Integer Sim
+     from .modules.base import PoTLayerBase
+     pot_layers = [(name, m) for name, m in model.named_modules() if isinstance(m, PoTLayerBase)]
+     for i, (name, layer) in enumerate(pot_layers):
+         is_first = (i == 0)
+         is_last = (i == len(pot_layers) - 1)
+         layer.set_layer_position(is_first, is_last)
+         if is_first:
+             layer.set_input_std(config.std, config.mean)
+
+     # Phase 4: QAT Training (hybrid: float QAT first, integer sim for the last 40%)
+     if verbose:
+         print(f"\n[Phase 4] QAT Training ({qat_epochs} epochs)...")
+
+     optimizer = torch.optim.Adam(model.parameters(), lr=qat_lr)
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, qat_epochs)
+
+     # Phase-switch epochs
+     constraint_start_epoch = int(qat_epochs * 0.8)  # 5-level constraint for the last 20%
+     integer_sim_start_epoch = int(qat_epochs * 0.6)  # Float QAT 60% → Integer sim fine-tune 40%
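+     # e.g. with qat_epochs=50: epochs 0-29 run float QAT, integer simulation
+     # starts at epoch 30, and the 5-level constraint kicks in at epoch 40.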
+
+     best_qat_acc = 0
+     best_state = None
+
+     for epoch in range(qat_epochs):
+         # Enable the 5-level constraint for the last 20% of epochs
+         if epoch == constraint_start_epoch:
+             for name, module in model.named_modules():
+                 if hasattr(module, 'enforce_5level_constraint'):
+                     if hasattr(module, 'encoding') and module.encoding == '5level':
+                         module.enforce_5level_constraint = True
+                         if verbose:
+                             print(f" [{name}] 5level constraint enabled (epoch {epoch+1})")
+
+         # Enable integer sim for the last 40% of epochs (fine-tune phase)
+         if epoch == integer_sim_start_epoch:
+             if verbose:
+                 print(f" [Integer Sim] Enabled for fine-tuning (epoch {epoch+1})")
+             enable_integer_sim(model, input_std=config.std, input_mean=config.mean, verbose=False)
+             # Lower the learning rate for fine-tuning
+             for param_group in optimizer.param_groups:
+                 param_group['lr'] = param_group['lr'] * 0.1
+
+         # Update integer params each epoch if using integer sim
+         if epoch >= integer_sim_start_epoch:
+             for name, module in model.named_modules():
+                 if isinstance(module, PoTLayerBase) and module.use_integer_sim:
+                     module.compute_integer_params()
+
+         model.train()
+         for data, target in train_loader:
+             data, target = data.to(device), target.to(device)
+             data = _normalize_data(data, mean, std)  # Auto normalize
+             optimizer.zero_grad()
+             output = model(data)
+             loss = F.cross_entropy(output, target) + alpha_reg_loss(model, 0.01)
+             loss.backward()
+             optimizer.step()
+
+         scheduler.step()
+         acc = _evaluate(model, test_loader, device, mean, std)
+
+         # Only update best after integer sim starts (to ensure C-compatible weights)
+         if epoch >= integer_sim_start_epoch and acc > best_qat_acc:
+             best_qat_acc = acc
+             best_state = {k: v.clone() for k, v in model.state_dict().items()}
+
+         if verbose and (epoch % 10 == 0 or epoch == qat_epochs - 1):
+             print(f" Epoch {epoch+1}/{qat_epochs}: {acc:.2f}%")
+
+     # Restore the best model
+     if best_state is not None:
+         model.load_state_dict(best_state, strict=False)
+
+     if verbose:
+         print(f" Best QAT: {best_qat_acc:.2f}%")
+
+     # Ensure integer sim is enabled for the final model
+     enable_integer_sim(model, input_std=config.std, input_mean=config.mean, verbose=verbose)
+
+     # Final integer params update
+     for name, module in model.named_modules():
+         if isinstance(module, PoTLayerBase) and module.use_integer_sim:
+             module.compute_integer_params()
+
+     # Final accuracy (with integer sim)
+     final_acc = _evaluate(model, test_loader, device, mean, std)
+
+     if verbose:
+         print(f"\n[Summary] Float: {best_float_acc:.2f}% → QAT: {best_qat_acc:.2f}% → C-Ready: {final_acc:.2f}%")
+
+     # Attach stats to the model for reporting
+     model.train_stats = {
+         'float_acc': best_float_acc,
+         'qat_acc': best_qat_acc,
+         'final_acc': final_acc
+     }
+
+     return model
+
+
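+ # Illustrative call (a sketch; the Config constructor arguments shown are
+ # assumed — only config.mean/config.std usage is confirmed by the code above):
+ #
+ #     config = Config(mean=0.1307, std=0.3081)
+ #     model = train(model, train_loader, test_loader, config, device='cuda')
+ #     print(model.train_stats)  # {'float_acc': ..., 'qat_acc': ..., 'final_acc': ...}
+
+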
+ def _evaluate(model: nn.Module, test_loader: DataLoader, device: str,
+               mean: float = 0.0, std: float = 1.0) -> float:
+     """Evaluate model accuracy with automatic normalization."""
+     model.eval()
+     correct = 0
+     total = 0
+
+     with torch.no_grad():
+         for data, target in test_loader:
+             data, target = data.to(device), target.to(device)
+             data = _normalize_data(data, mean, std)  # Auto normalize
+             output = model(data)
+             _, predicted = output.max(1)
+             total += target.size(0)
+             correct += predicted.eq(target).sum().item()
+
+     return 100. * correct / total
+
+
+ # Note: wrap() function has been removed.
+ # Users must define models using PoT layers directly:
+ # potnn.PoTConv2d, potnn.PoTConv1d, potnn.PoTLinear, etc.
+ # This ensures proper initialization of alpha, QAT parameters, and encoding.
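+ #
+ # Illustrative definition (a sketch; PoT layer signatures assumed to mirror
+ # their torch.nn counterparts, as in the _validate_model error message):
+ #
+ #     class TinyNet(nn.Module):
+ #         def __init__(self):
+ #             super().__init__()
+ #             self.conv = potnn.PoTConv2d(1, 16, 3)    # 16x16 input → 16x14x14
+ #             self.fc = potnn.PoTLinear(16 * 14 * 14, 10)
+ #
+ #         def forward(self, x):
+ #             x = F.relu(self.conv(x))
+ #             return self.fc(torch.flatten(x, 1))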