potnn 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,460 @@
1
+ """Main C header generation for potnn."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import Dict, Any, List, Tuple
7
+
8
+ from ..modules import PoTLinear, PoTConv2d
9
+ from .scale import generate_scale_func, calculate_combined_scale, calculate_adjusted_bias # calculate_adjusted_bias is called below but was never imported; .scale is an assumed location
10
+ from .unroll import generate_unrolled_layer
11
+
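+ # Illustrative note (not part of the original source): calculate_combined_scale()
+ # returns an integer pair (scale_int, shift), interpreted here -- as an assumption --
+ # as approximating a real rescale factor by scale_int / 2**shift, while
+ # generate_scale_func() emits C that applies it without a hardware multiplier.
+ # As a hypothetical example, a factor of about 0.708 could be encoded as
+ # scale_int=181, shift=8 (181/256 = 0.707), since acc*181 decomposes into
+ # shifts and adds:
+ #
+ #     acc * 181 = (acc << 7) + (acc << 5) + (acc << 4) + (acc << 2) + acc   # 128+32+16+4+1
+ #     result    = (acc * 181) >> 8
+ #
+ # The exact decomposition is an implementation detail of .scale; the numbers
+ # above are chosen for illustration only.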
12
+
13
+
14
+ def generate_c_header(path: str, model: nn.Module):
15
+ """Generate complete C header file for the model.
16
+
17
+ Args:
18
+ path: Output file path
19
+ model: Trained potnn model
20
+ """
21
+ config = model._potnn_config
22
+ allocations = model._potnn_allocations
23
+ input_shape = model._potnn_input_shape
24
+
25
+ # Analyze model structure and compute layer dimensions
26
+ layer_dims = _compute_layer_dimensions(model, input_shape)
27
+
28
+ # Debug: print computed dimensions
29
+ print("\nComputed layer dimensions:")
30
+ for name, dims in layer_dims.items():
31
+ print(f" {name}: {dims}")
32
+
33
+ # Collect layer information
34
+ layers = []
35
+ other_layers = [] # MaxPool, Flatten, etc.
36
+ prev_act_scale = 1.0
37
+ pot_layer_idx = 0
38
+ num_pot_layers = len([m for m in model.modules() if isinstance(m, (PoTLinear, PoTConv2d))])
39
+
40
+ for name, module in model.named_modules():
41
+ if isinstance(module, (PoTLinear, PoTConv2d)):
42
+ alloc = allocations.get(name)
43
+ if alloc is None:
44
+ continue
45
+
46
+ with torch.no_grad():
47
+ alpha = F.softplus(module.alpha).clamp(min=0.01).item()
48
+
49
+ is_first = (pot_layer_idx == 0)
50
+ is_last = (pot_layer_idx == num_pot_layers - 1)
51
+
52
+ print(f"\n{'='*60}")
53
+ print(f"[DEBUG] Processing layer: {name}")
54
+ print(f" Type: {type(module).__name__}")
55
+ print(f" is_first: {is_first}, is_last: {is_last}")
56
+ print(f" alpha (raw): {module.alpha.item():.6f}")
57
+ print(f" alpha (softplus+clamp): {alpha:.6f}")
58
+ print(f" act_scale: {module.act_scale}")
59
+ print(f" prev_act_scale: {prev_act_scale:.6f}")
60
+
61
+ # Calculate adjusted bias (absorb mean for first layer)
62
+ adjusted_bias = calculate_adjusted_bias(
63
+ bias=module.bias.detach().cpu() if module.bias is not None else None,
64
+ weight=module.weight.detach().cpu(),
65
+ mean=config.mean if is_first else None,
66
+ is_first_layer=is_first
67
+ )
68
+
69
+ scale_int, shift = calculate_combined_scale(
70
+ alpha=alpha,
71
+ act_scale=module.act_scale if module.act_scale else 1.0,
72
+ prev_act_scale=prev_act_scale,
73
+ input_norm=config.input_norm if config.input_norm else 256,
74
+ is_first_layer=is_first,
75
+ is_last_layer=is_last,
76
+ mean=config.mean if is_first else None,
77
+ std=config.std if is_first else None
78
+ )
79
+
80
+ # Debug: Show bias scaling preview
81
+ act_scale_for_bias = module.act_scale if module.act_scale else 1.0
82
+ if adjusted_bias is not None:
83
+ scaled_bias = [int(round(b.item() * act_scale_for_bias)) for b in adjusted_bias[:4]]
84
+ print(f" [DEBUG] Bias scaling preview:")
85
+ print(f" act_scale for bias: {act_scale_for_bias:.4f}")
86
+ print(f" adjusted_bias (first 4): {[round(b.item(), 4) for b in adjusted_bias[:4]]}")
87
+ print(f" scaled_bias (first 4): {scaled_bias}")
88
+
89
+ # Make C-compatible name
90
+ c_name = name.replace('.', '_')
91
+ if c_name and c_name[0].isdigit():
92
+ c_name = 'layer_' + c_name
93
+
94
+ # Get dimensions from layer_dims
95
+ dims = layer_dims.get(name, {})
96
+
97
+ # Determine if ReLU follows this layer
98
+ use_relu = _check_relu_follows(model, name)
99
+
100
+ layer_info = {
101
+ 'name': c_name,
102
+ 'original_name': name,
103
+ 'type': type(module).__name__,
104
+ 'weight': module.weight.detach().cpu(),
105
+ 'bias': adjusted_bias, # Use adjusted bias (mean absorbed for first layer)
106
+ 'alpha': alpha,
107
+ 'act_scale': module.act_scale if module.act_scale else 1.0, # For bias scaling
108
+ 'levels': alloc.levels,
109
+ 'mode': alloc.mode,
110
+ 'scale_int': scale_int,
111
+ 'shift': shift,
112
+ 'use_relu': use_relu,
113
+ # Dimensions
114
+ 'in_h': dims.get('in_h', 0),
115
+ 'in_w': dims.get('in_w', 0),
116
+ 'out_h': dims.get('out_h', 0),
117
+ 'out_w': dims.get('out_w', 0),
118
+ 'in_size': dims.get('in_size', 0),
119
+ 'out_size': dims.get('out_size', 0),
120
+ 'stride': dims.get('stride', 1),
121
+ 'padding': dims.get('padding', 0),
122
+ }
123
+
124
+ layers.append(layer_info)
125
+
126
+ if module.act_scale is not None:
127
+ prev_act_scale = module.act_scale
128
+
129
+ pot_layer_idx += 1
130
+
131
+ elif isinstance(module, nn.MaxPool2d):
132
+ c_name = name.replace('.', '_')
133
+ if c_name and c_name[0].isdigit():
134
+ c_name = 'layer_' + c_name
135
+
136
+ dims = layer_dims.get(name, {})
137
+ other_layers.append({
138
+ 'name': c_name,
139
+ 'original_name': name,
140
+ 'type': 'MaxPool2d',
141
+ 'kernel_size': module.kernel_size if isinstance(module.kernel_size, int) else module.kernel_size[0],
142
+ 'stride': module.stride if isinstance(module.stride, int) else module.stride[0],
143
+ 'in_h': dims.get('in_h', 0),
144
+ 'in_w': dims.get('in_w', 0),
145
+ 'in_ch': dims.get('in_ch', 0),
146
+ 'out_h': dims.get('out_h', 0),
147
+ 'out_w': dims.get('out_w', 0),
148
+ 'in_size': dims.get('in_size', 0),
149
+ 'out_size': dims.get('out_size', 0),
150
+ })
151
+
152
+ elif isinstance(module, nn.Flatten):
153
+ c_name = name.replace('.', '_')
154
+ if c_name and c_name[0].isdigit():
155
+ c_name = 'layer_' + c_name
156
+
157
+ dims = layer_dims.get(name, {})
158
+ other_layers.append({
159
+ 'name': c_name,
160
+ 'original_name': name,
161
+ 'type': 'Flatten',
162
+ 'in_size': dims.get('in_size', 0),
163
+ 'out_size': dims.get('out_size', 0),
164
+ })
165
+
166
+ # Compute buffer sizes
167
+ max_buffer_size = _compute_max_buffer_size(layer_dims)
168
+
169
+ # Get output classes from last layer
170
+ last_layer = layers[-1] if layers else None
171
+ num_classes = last_layer['weight'].shape[0] if last_layer else 10
172
+
173
+ # Build ordered layer sequence
174
+ layer_sequence = _build_layer_sequence(model, layers, other_layers)
175
+
176
+ # Generate header file
177
+ with open(path, 'w') as f:
178
+ _write_header_preamble(f, config)
179
+
180
+ # Add input scale info if available
181
+ if hasattr(model, 'input_scale') and hasattr(model, 'input_max'):
182
+ f.write(f"/* Input scale: {model.input_scale:.3f} (for test data: int8 = normalized_float * input_scale) */\n")
183
+ f.write(f"/* Input max: {model.input_max:.3f} */\n\n")
184
+
185
+ # Scale functions
186
+ f.write("/* Scale functions using only shifts and adds */\n")
187
+ for layer in layers:
188
+ f.write(generate_scale_func(layer['name'], layer['scale_int'], layer['shift']))
189
+
190
+ # MaxPool functions
191
+ for other in other_layers:
192
+ if other['type'] == 'MaxPool2d':
193
+ f.write(_generate_maxpool_func(other))
194
+
195
+ # Layer forward functions
196
+ f.write("/* Layer forward functions */\n")
197
+ for layer in layers:
198
+ if layer['mode'] == 'unroll':
199
+ f.write(generate_unrolled_layer(layer))
200
+ else:
201
+ # Other modes are not implemented yet; fall back to unrolled code generation
202
+ f.write(generate_unrolled_layer(layer))
203
+
204
+ # Main prediction function
205
+ f.write(_generate_main_predict(layer_sequence, max_buffer_size, num_classes))
206
+
207
+ f.write("\n#endif /* POTNN_MODEL_H */\n")
208
+
209
+ print(f"C header generated: {path}")
210
+ print(f" - {len(layers)} PoT layers")
211
+ print(f" - {len(other_layers)} other layers (MaxPool, Flatten)")
212
+ print(f" - Buffer size: {max_buffer_size} bytes")
213
+ print(f" - Output classes: {num_classes}")
214
+
215
+
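+ # Usage sketch (illustrative, not part of the original source). The attribute
+ # names follow the accesses at the top of generate_c_header(); how the model
+ # acquires them (potnn's training / conversion step) is assumed here:
+ #
+ #     model = ...                              # trained net with PoTLinear / PoTConv2d layers
+ #     model._potnn_config = ...                # flash/ram budget, mean/std, input_norm
+ #     model._potnn_allocations = ...           # per-layer levels / mode allocations
+ #     model._potnn_input_shape = (1, 28, 28)   # hypothetical C x H x W input
+ #     generate_c_header("potnn_model.h", model)
+ #
+ # The emitted header defines potnn_predict(const int8_t* input), which returns
+ # the argmax class index as int8_t.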
216
+ def _write_header_preamble(f, config):
217
+ """Write header file preamble."""
218
+ f.write("/* POTNN Generated Model - MUL-FREE Neural Network */\n")
219
+ f.write("/* Target: Ultra-low-cost MCUs without multiplication */\n")
220
+ f.write(f"/* Flash budget: {config.flash} bytes */\n")
221
+ f.write(f"/* RAM budget: {config.ram} bytes */\n\n")
222
+ f.write("#ifndef POTNN_MODEL_H\n")
223
+ f.write("#define POTNN_MODEL_H\n\n")
224
+ f.write("#include <stdint.h>\n\n")
225
+
226
+
227
+ def _generate_maxpool_func(info: Dict) -> str:
228
+ """Generate MaxPool2d function."""
229
+ name = info['name']
230
+ kernel = info['kernel_size']
231
+ stride = info['stride']
232
+ in_h = info['in_h']
233
+ in_w = info['in_w']
234
+ in_ch = info['in_ch']
235
+ out_h = info['out_h']
236
+ out_w = info['out_w']
237
+
238
+ code = f"// {name} - MaxPool2d {kernel}x{kernel}, stride {stride}\n"
239
+ code += f"static void {name}_forward(const int8_t* input, int8_t* output) {{\n"
240
+ code += f" const int IN_H = {in_h}, IN_W = {in_w}, IN_CH = {in_ch};\n"
241
+ code += f" const int OUT_H = {out_h}, OUT_W = {out_w};\n"
242
+ code += f" const int K = {kernel}, S = {stride};\n"
243
+ code += f" \n"
244
+ code += f" for (int c = 0; c < IN_CH; c++) {{\n"
245
+ code += f" for (int oy = 0; oy < OUT_H; oy++) {{\n"
246
+ code += f" for (int ox = 0; ox < OUT_W; ox++) {{\n"
247
+ code += f" int8_t max_val = -128;\n"
248
+ code += f" for (int ky = 0; ky < K; ky++) {{\n"
249
+ code += f" for (int kx = 0; kx < K; kx++) {{\n"
250
+ code += f" int in_y = oy * S + ky;\n"
251
+ code += f" int in_x = ox * S + kx;\n"
252
+ code += f" if (in_y < IN_H && in_x < IN_W) {{\n"
253
+ code += f" int idx = c * IN_H * IN_W + in_y * IN_W + in_x;\n"
254
+ code += f" if (input[idx] > max_val) max_val = input[idx];\n"
255
+ code += f" }}\n"
256
+ code += f" }}\n"
257
+ code += f" }}\n"
258
+ code += f" int out_idx = c * OUT_H * OUT_W + oy * OUT_W + ox;\n"
259
+ code += f" output[out_idx] = max_val;\n"
260
+ code += f" }}\n"
261
+ code += f" }}\n"
262
+ code += f" }}\n"
263
+ code += "}\n\n"
264
+ return code
265
+
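+ # Illustrative call (not part of the original source): the info dict consumed by
+ # _generate_maxpool_func() needs exactly the keys read above; e.g. a hypothetical
+ # 2x2, stride-2 pool over an 8-channel 14x14 feature map:
+ #
+ #     print(_generate_maxpool_func({
+ #         'name': 'pool1', 'kernel_size': 2, 'stride': 2,
+ #         'in_h': 14, 'in_w': 14, 'in_ch': 8,
+ #         'out_h': 7, 'out_w': 7,
+ #     }))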
266
+
267
+ def _generate_main_predict(layer_sequence: List[Dict], buffer_size: int, num_classes: int) -> str:
268
+ """Generate the main prediction function."""
269
+ code = "/* Main prediction function */\n"
270
+ code += "int8_t potnn_predict(const int8_t* input) {\n"
271
+ code += f" static int8_t buffer1[{buffer_size}];\n"
272
+ code += f" static int8_t buffer2[{buffer_size}];\n"
273
+ code += " const int8_t* current_input = input;\n"
274
+ code += " int8_t* current_output = buffer1;\n\n"
275
+
276
+ for i, layer in enumerate(layer_sequence):
277
+ name = layer['name']
278
+ layer_type = layer['type']
279
+ is_last = (i == len(layer_sequence) - 1)
280
+
281
+ if layer_type == 'Flatten':
282
+ code += f" // {name}: Flatten (no-op, just continue)\n"
283
+ code += f" // Input and output share same memory layout\n\n"
284
+ continue
285
+
286
+ code += f" // {name}: {layer_type}\n"
287
+ code += f" {name}_forward(current_input, current_output);\n"
288
+
289
+ if not is_last:
290
+ code += " current_input = current_output;\n"
291
+ code += " current_output = (current_output == buffer1) ? buffer2 : buffer1;\n\n"
292
+
293
+ code += f"\n // Find argmax for classification ({num_classes} classes)\n"
294
+ code += " int8_t max_val = current_output[0];\n"
295
+ code += " int8_t max_idx = 0;\n"
296
+ code += f" for (int i = 1; i < {num_classes}; i++) {{\n"
297
+ code += " if (current_output[i] > max_val) {\n"
298
+ code += " max_val = current_output[i];\n"
299
+ code += " max_idx = i;\n"
300
+ code += " }\n"
301
+ code += " }\n"
302
+ code += " return max_idx;\n"
303
+ code += "}\n"
304
+
305
+ return code
306
+
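+ # The C emitted above ping-pongs activations between two static buffers so each
+ # layer reads the previous layer's output; condensed shape of the generated
+ # function (illustrative, the exact body depends on the layer sequence):
+ #
+ #     int8_t potnn_predict(const int8_t* input) {
+ #         static int8_t buffer1[BUF], buffer2[BUF];
+ #         const int8_t* current_input = input;
+ #         int8_t* current_output = buffer1;
+ #         layer0_forward(current_input, current_output);
+ #         current_input = current_output;
+ #         current_output = (current_output == buffer1) ? buffer2 : buffer1;
+ #         /* ... remaining layers, then argmax over num_classes logits ... */
+ #     }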
307
+
308
+ def _compute_layer_dimensions(model: nn.Module, input_shape: Tuple) -> Dict[str, Dict]:
309
+ """Compute input/output dimensions for each layer.
310
+
311
+ Handles both original nn.Conv2d/nn.Linear and PoTConv2d/PoTLinear.
312
+ """
313
+ dims = {}
314
+
315
+ # Current shape tracking
316
+ if len(input_shape) == 3:
317
+ c, h, w = input_shape
318
+ elif len(input_shape) == 1:
319
+ c, h, w = 1, 1, input_shape[0]
320
+ else:
321
+ c, h, w = 1, input_shape[0], input_shape[1]
322
+
323
+ print(f"\nTracking dimensions from input shape: c={c}, h={h}, w={w}")
324
+
325
+ for name, module in model.named_modules():
326
+ # Skip the root module (empty name)
327
+ if name == '':
328
+ continue
329
+
330
+ # Handle PoTConv2d and nn.Conv2d
331
+ if isinstance(module, PoTConv2d) or isinstance(module, nn.Conv2d):
332
+ in_h, in_w = h, w
333
+ in_ch = c
334
+
335
+ # Get kernel size, stride, padding
336
+ if isinstance(module, PoTConv2d):
337
+ kh = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else module.kernel_size
338
+ kw = module.kernel_size[1] if isinstance(module.kernel_size, tuple) else module.kernel_size
339
+ sh = module.stride[0] if isinstance(module.stride, tuple) else module.stride
340
+ sw = module.stride[1] if isinstance(module.stride, tuple) else module.stride
341
+ ph = module.padding[0] if isinstance(module.padding, tuple) else module.padding
342
+ pw = module.padding[1] if isinstance(module.padding, tuple) else module.padding
343
+ out_ch = module.out_channels
344
+ else:
345
+ kh, kw = module.kernel_size if isinstance(module.kernel_size, tuple) else (module.kernel_size, module.kernel_size)
346
+ sh, sw = module.stride if isinstance(module.stride, tuple) else (module.stride, module.stride)
347
+ ph, pw = module.padding if isinstance(module.padding, tuple) else (module.padding, module.padding)
348
+ out_ch = module.out_channels
349
+
350
+ out_h = (in_h + 2*ph - kh) // sh + 1
351
+ out_w = (in_w + 2*pw - kw) // sw + 1
352
+
353
+ dims[name] = {
354
+ 'in_h': in_h, 'in_w': in_w, 'in_ch': in_ch,
355
+ 'out_h': out_h, 'out_w': out_w, 'out_ch': out_ch,
356
+ 'in_size': in_ch * in_h * in_w,
357
+ 'out_size': out_ch * out_h * out_w,
358
+ 'stride': sh,
359
+ 'padding': ph if ph == pw else (ph, pw),
360
+ }
361
+
362
+ # Update current shape
363
+ c, h, w = out_ch, out_h, out_w
364
+ print(f" {name} (Conv): {in_ch}x{in_h}x{in_w} -> {out_ch}x{out_h}x{out_w}")
365
+
366
+ elif isinstance(module, nn.MaxPool2d):
367
+ in_h, in_w = h, w
368
+ in_ch = c
369
+
370
+ k = module.kernel_size if isinstance(module.kernel_size, int) else module.kernel_size[0]
371
+ s = module.stride if isinstance(module.stride, int) else module.stride[0]
372
+
373
+ out_h = in_h // s
374
+ out_w = in_w // s
375
+
376
+ dims[name] = {
377
+ 'in_h': in_h, 'in_w': in_w, 'in_ch': in_ch,
378
+ 'out_h': out_h, 'out_w': out_w,
379
+ 'in_size': in_ch * in_h * in_w,
380
+ 'out_size': in_ch * out_h * out_w,
381
+ }
382
+
383
+ # Update current shape
384
+ h, w = out_h, out_w
385
+ print(f" {name} (MaxPool): {in_ch}x{in_h}x{in_w} -> {in_ch}x{out_h}x{out_w}")
386
+
387
+ elif isinstance(module, nn.Flatten):
388
+ flat_size = c * h * w
389
+ dims[name] = {
390
+ 'in_size': flat_size,
391
+ 'out_size': flat_size,
392
+ }
393
+ print(f" {name} (Flatten): {c}x{h}x{w} -> {flat_size}")
394
+
395
+ elif isinstance(module, PoTLinear) or isinstance(module, nn.Linear):
396
+ in_features = module.in_features
397
+ out_features = module.out_features
398
+
399
+ dims[name] = {
400
+ 'in_size': in_features,
401
+ 'out_size': out_features,
402
+ }
403
+
404
+ # Update current shape for potential next linear
405
+ c, h, w = 1, 1, out_features
406
+ print(f" {name} (Linear): {in_features} -> {out_features}")
407
+
408
+ # Skip BatchNorm, ReLU, Identity, Dropout - they don't change dimensions
409
+
410
+ return dims
411
+
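+ # Worked example (hypothetical shapes, not from the original source): for a
+ # 1x28x28 input through Conv2d(1, 8, kernel_size=3, stride=1, padding=1),
+ # MaxPool2d(2), Flatten, Linear(8*14*14, 10), the tracking above yields:
+ #
+ #     conv:    out_h = (28 + 2*1 - 3)//1 + 1 = 28   -> 8x28x28, out_size 6272
+ #     maxpool: out_h = 28 // 2 = 14                 -> 8x14x14, out_size 1568
+ #     flatten: in_size = out_size = 8*14*14 = 1568
+ #     linear:  in_size 1568 -> out_size 10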
412
+
413
+ def _compute_max_buffer_size(layer_dims: Dict) -> int:
414
+ """Compute maximum buffer size needed."""
415
+ max_size = 256 # Minimum
416
+ for name, dims in layer_dims.items():
417
+ in_size = dims.get('in_size', 0)
418
+ out_size = dims.get('out_size', 0)
419
+ max_size = max(max_size, in_size, out_size)
420
+ return max_size
421
+
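+ # Sizing note (illustrative, same hypothetical model as in the example above):
+ # the largest in_size/out_size is the conv output, 8*28*28 = 6272, so this
+ # function returns 6272; potnn_predict() allocates two such int8 buffers,
+ # i.e. about 2 * 6272 = 12544 bytes of static activation RAM.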
422
+
423
+ def _check_relu_follows(model: nn.Module, layer_name: str) -> bool:
424
+ """Check if ReLU follows the given layer.
425
+
426
+ Skips Identity layers (fused BatchNorm) and BatchNorm layers.
427
+ """
428
+ found_layer = False
429
+ for name, module in model.named_modules():
430
+ if name == layer_name:
431
+ found_layer = True
432
+ continue
433
+ if found_layer:
434
+ # Skip Identity (fused BatchNorm) and BatchNorm
435
+ if isinstance(module, (nn.Identity, nn.BatchNorm1d, nn.BatchNorm2d)):
436
+ continue
437
+ if isinstance(module, (nn.ReLU, nn.ReLU6)):
438
+ return True
439
+ elif isinstance(module, (nn.Conv2d, nn.Linear, nn.MaxPool2d)):
440
+ return False # Another compute layer before ReLU
441
+ # Also check for PoT modules
442
+ if isinstance(module, (PoTLinear, PoTConv2d)):
443
+ return False # PoT compute layer encountered before any ReLU
444
+ return False
445
+
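+ # Behaviour sketch (hypothetical module orders, not from the original source):
+ #
+ #     Conv2d -> BatchNorm2d -> ReLU   => True   (BatchNorm is skipped)
+ #     Conv2d -> Identity    -> ReLU   => True   (fused BatchNorm is skipped)
+ #     Conv2d -> MaxPool2d   -> ReLU   => False  (another compute layer comes first)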
446
+
447
+ def _build_layer_sequence(model: nn.Module, pot_layers: List[Dict], other_layers: List[Dict]) -> List[Dict]:
448
+ """Build ordered sequence of all layers."""
449
+ # Create lookup by original name
450
+ pot_lookup = {l['original_name']: l for l in pot_layers}
451
+ other_lookup = {l['original_name']: l for l in other_layers}
452
+
453
+ sequence = []
454
+ for name, module in model.named_modules():
455
+ if name in pot_lookup:
456
+ sequence.append(pot_lookup[name])
457
+ elif name in other_lookup:
458
+ sequence.append(other_lookup[name])
459
+
460
+ return sequence