potnn 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
potnn/export.py ADDED
@@ -0,0 +1,2196 @@
1
+ """Export functionality for PoT quantized models via ONNX.
2
+
3
+ Follows the technical specification exactly:
4
+ - Input: uint8 [0,255], no runtime normalization
5
+ - /256: absorbed into combined_shift (+8)
6
+ - mean: absorbed into bias (b' = b - mean × ΣW)
7
+ - /std: absorbed into combined_scale (scale × 1/std)
8
+ - bias: scaled by act_scale (bias_int = round(bias × act_scale))
9
+ - Output: MUL-free C code (shift+add only, except combined_scale MUL per layer)
10
+
11
+ v8: Fixed all bugs from systematic verification
12
+ - Fixed: bias scaling with act_scale
13
+ - Fixed: standardization absorption with act_scale
14
+ - Fixed: last layer act_scale=None handling
15
+ - Fixed: weight to shift conversion (w is already PoT*alpha)
16
+ """
17
+
18
+ import os
19
+ import tempfile
20
+ import torch
21
+ import torch.nn as nn
22
+ import numpy as np
23
+ from typing import Dict, List, Optional, Any, Tuple
24
+
25
+ try:
26
+ import onnx
27
+ from onnx import numpy_helper
28
+ ONNX_AVAILABLE = True
29
+ except ImportError:
30
+ ONNX_AVAILABLE = False
31
+
32
+ from .modules.base import PoTLayerBase
33
+ from .modules.conv import PoTConv2d
34
+ from .modules.conv1d import PoTConv1d
35
+ from .modules.depthwise import PoTDepthwiseConv2d
36
+ from .modules.linear import PoTLinear
37
+ from .modules.add import PoTAdd
38
+ from .config import Config
39
+ from .quantize.pot import quantize_to_pot
40
+
41
+
42
def export(model: nn.Module, output_path: str, config: Config, dummy_input: torch.Tensor = None, optimized: bool = True):
    """
    Export PyTorch model to C code.

    Pipeline: disable QAT integer simulation, fuse BatchNorm, export to a
    temporary ONNX file, parse the graph back into layer descriptors, absorb
    input standardization, compute combined scales, scale biases, and emit C.

    Args:
        model: PyTorch model (PoT layers)
        output_path: Path to save .c file
        config: Hardware configuration (Config class)
        dummy_input: Dummy input for ONNX export (optional, auto-generated if None)
        optimized: Enable optimized C kernels (default: True)
            - Loop layers: Full Pipeline (Zero-Padding + im2col + Column-wise +
              Position Blocking + Shift Grouping)
            - Unroll layers: Zero-Padding (eliminates boundary checks)

    Raises:
        RuntimeError: if the ``onnx`` package is not installed.
    """
    if not ONNX_AVAILABLE:
        raise RuntimeError("onnx package required. Install with: pip install onnx")

    print(f"\nStarting export to {output_path}...")
    print(f"Optimized mode: {'ENABLED' if optimized else 'DISABLED'} (Full Pipeline for loop, Zero-Padding for unroll)")
    print("Using ONNX for graph extraction (v10 - hybrid unroll/loop)...")

    model.eval()

    # Step 0a: Disable Integer Simulation for clean ONNX export.
    # Integer Sim adds round_ste/clamp_ste which become spurious Add nodes.
    from .quantize.qat import disable_integer_sim
    disable_integer_sim(model)

    # Step 0a-2: Disable 5level constraint for torch.export compatibility.
    # The constraint uses Python for-loops which torch.export doesn't support.
    # (PoTLayerBase is already imported at module level; the original re-import
    # here was redundant and has been removed.)
    for module in model.modules():
        if isinstance(module, PoTLayerBase):
            module.enforce_5level_constraint = False

    # Step 0b: Fuse BatchNorm layers into Conv/Linear (CRITICAL!)
    from .fuse import fuse_batchnorm, check_bn_fused
    model = fuse_batchnorm(model)
    if not check_bn_fused(model):
        print("Warning: Some BatchNorm layers could not be fused!")

    if dummy_input is None:
        if config.input_w == 1:
            # Conv1d: (B, C, L) - input_h is sequence length
            dummy_input = torch.randn(1, config.input_channels, config.input_h)
        else:
            # Conv2d: (B, C, H, W)
            dummy_input = torch.randn(1, config.input_channels, config.input_h, config.input_w)

    device = next(model.parameters()).device
    dummy_input = dummy_input.to(device)

    # Auto-detect input shape from dummy_input.
    # NOTE: dummy_input is guaranteed non-None here (it was just constructed
    # above if the caller passed None), so the former `is not None` guard
    # was redundant and has been dropped.
    if dummy_input.dim() == 4:
        # (B, C, H, W)
        print(f"[Export] Auto-detected input shape: {dummy_input.shape}")
        config.input_channels = dummy_input.shape[1]
        config.input_h = dummy_input.shape[2]
        config.input_w = dummy_input.shape[3]
    elif dummy_input.dim() == 3:
        # (B, C, L)
        print(f"[Export] Auto-detected input shape (1D): {dummy_input.shape}")
        config.input_channels = dummy_input.shape[1]
        config.input_h = dummy_input.shape[2]  # Length
        config.input_w = 1

    # Step 1: Collect PoT layer info from PyTorch model
    pot_layer_infos = collect_pot_layer_info(model)
    print(f"Found {len(pot_layer_infos)} PoT layers in PyTorch model")

    # Collect PoTAdd layer info (for rescale calculation)
    add_layer_infos = collect_add_layer_info(model)

    # Debug: print layer info
    for info in pot_layer_infos:
        print(f" {info['name']}: alpha={info['alpha']:.4f}, act_scale={info['act_scale']}")

    # Step 2: Export to ONNX (temporary file, removed in the finally block)
    with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:
        onnx_path = f.name

    try:
        # Use legacy ONNX exporter (dynamo=False) to avoid torch.export compatibility issues
        torch.onnx.export(
            model, dummy_input, onnx_path,
            input_names=['input'], output_names=['output'],
            opset_version=11, do_constant_folding=True,
            dynamo=False,  # Legacy exporter
        )
        print("ONNX export successful")

        # Step 3: Load and validate ONNX graph
        onnx_model = onnx.load(onnx_path)
        onnx.checker.check_model(onnx_model)

        layer_infos = parse_onnx_graph(onnx_model, pot_layer_infos, config, add_layer_infos)
        print(f"Parsed {len(layer_infos)} layers from ONNX graph")

    finally:
        if os.path.exists(onnx_path):
            os.remove(onnx_path)

    # Step 4: Absorb standardization into first layer (per specification 6.4-6.6)
    if config.mean is not None and config.std is not None:
        absorb_standardization(layer_infos, config.mean, config.std)

    # Step 5: Calculate combined scales (per specification 9.x)
    calculate_combined_scales(layer_infos, config)

    # Step 6: Scale biases by act_scale (CRITICAL FIX)
    scale_biases(layer_infos)

    # Step 6.5: Decide loop/unroll mode. `allocations` is currently always
    # empty (every PoT layer defaults to unroll); it is kept as a hook for a
    # future allocation strategy.
    allocations = {}
    for info in layer_infos:
        if info['layer_type'] == 'pot':
            alloc = allocations.get(info['name'])
            if alloc is not None:
                info['code_mode'] = alloc.mode    # 'unroll' or 'loop'
                info['levels'] = alloc.levels     # 11 or 4
            else:
                # Default to unroll
                info['code_mode'] = 'unroll'
                info['levels'] = 11

    # Step 8: Generate C header
    code = generate_header(layer_infos, config, optimized=optimized)

    # Step 9: Write to file
    with open(output_path, 'w') as f:
        f.write(code)

    # Print summary.
    # (The original also computed an `unroll_count` that duplicated
    # `pot_count` and was never used; that dead code has been removed.)
    pot_count = sum(1 for info in layer_infos if info['layer_type'] == 'pot')
    maxpool_count = sum(1 for info in layer_infos if info['layer_type'] == 'maxpool')

    print("\nExport complete!")
    print(f" Output file: {output_path}")
    print(f" Target MCU: {config.ram}B RAM, {config.flash}B Flash")
    print(f" PoT layers: {pot_count}")
    print(f" MaxPool layers: {maxpool_count}")
    print(f" Total layers: {len(layer_infos)}")
190
+
191
+
192
def collect_pot_layer_info(model: nn.Module) -> List[Dict]:
    """Walk the model in module order and describe every PoT layer.

    For each PoT layer, records alpha, act_scale (kept as None for the last
    layer), the PoT-quantized weight, the original float bias, the encoding,
    plus layer-type-specific geometry fields.
    """
    infos: List[Dict] = []

    for name, module in model.named_modules():
        if not isinstance(module, PoTLayerBase):
            continue

        with torch.no_grad():
            alpha = module.alpha.item()
            # act_scale can legitimately be None (last layer).
            act_scale = None if module.act_scale is None else module.act_scale.item()
            raw_weight = module.weight.cpu().clone()
            bias = None if module.bias is None else module.bias.cpu().clone()

            # Encoding defaults to 'unroll' when the module doesn't declare one.
            encoding = getattr(module, 'encoding', 'unroll')

            # PoT quantization uses encoding-specific level counts.
            weight_q = quantize_to_pot(raw_weight, alpha, encoding=encoding)

            # 5level encoding additionally forbids more than 3 consecutive zeros.
            if encoding == '5level':
                from .quantize.pot import apply_5level_zero_constraint
                weight_q = apply_5level_zero_constraint(weight_q)

            entry = {
                'name': name,
                'pot_index': len(infos),
                'alpha': alpha,
                'act_scale': act_scale,  # Keep None as None!
                'weight': weight_q,
                'bias': bias,  # Original float bias
                'encoding': encoding,
            }

            # Layer-type-specific geometry. Order matters: PoTConv2d is
            # tested before PoTDepthwiseConv2d, exactly as in the original.
            if isinstance(module, PoTConv2d):
                entry.update(
                    type='conv',
                    in_channels=module.in_channels,
                    out_channels=module.out_channels,
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding,
                    groups=module.groups,
                )
            elif isinstance(module, PoTConv1d):
                entry.update(
                    type='conv1d',
                    in_channels=module.in_channels,
                    out_channels=module.out_channels,
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding,
                    groups=getattr(module, 'groups', 1),
                )
            elif isinstance(module, PoTDepthwiseConv2d):
                ch = module.channels
                k, s, p = module.kernel_size, module.stride, module.padding
                entry.update(
                    type='depthwise',
                    channels=ch,
                    in_channels=ch,
                    out_channels=ch,
                    # Depthwise attributes may be tuples; flatten to scalars.
                    kernel_size=k[0] if isinstance(k, tuple) else k,
                    stride=s[0] if isinstance(s, tuple) else s,
                    padding=p[0] if isinstance(p, tuple) else p,
                    groups=ch,
                )
            elif isinstance(module, PoTLinear):
                entry.update(
                    type='linear',
                    in_features=module.in_features,
                    out_features=module.out_features,
                )

            infos.append(entry)

    return infos
259
+
260
+
261
def collect_add_layer_info(model: nn.Module) -> List[Dict]:
    """Collect descriptors for all PoTAdd modules, in module order."""
    infos: List[Dict] = []

    def _item_or(tensor, fallback):
        # Scalar value of a 0-d tensor, or the fallback when it is absent.
        return fallback if tensor is None else tensor.item()

    for name, module in model.named_modules():
        if not isinstance(module, PoTAdd):
            continue

        entry = {
            'name': name,
            'add_index': len(infos),
            'type': 'add',
            'rescale_mult': _item_or(module.rescale_mult, 128),
            'rescale_shift': _item_or(module.rescale_shift, 7),
            'scale_x': _item_or(module.scale_x, None),
            'scale_y': _item_or(module.scale_y, None),
            'act_scale': _item_or(module.act_scale, None),
        }
        infos.append(entry)
        print(f" Found PoTAdd: {name}, mult={entry['rescale_mult']}, shift={entry['rescale_shift']}")

    return infos
281
+
282
+
283
+ def parse_onnx_graph(onnx_model, pot_layer_infos: List[Dict], config: Config, add_layer_infos: List[Dict] = None) -> List[Dict]:
284
+ """Parse ONNX graph and build layer_infos list with graph analysis for skip connections."""
285
+ if add_layer_infos is None:
286
+ add_layer_infos = []
287
+
288
+ graph = onnx_model.graph
289
+
290
+ initializers = {}
291
+ for init in graph.initializer:
292
+ initializers[init.name] = numpy_helper.to_array(init)
293
+
294
+ # ========================================
295
+ # Step 1: Build tensor→node mapping for graph analysis
296
+ # ========================================
297
+ tensor_to_producer = {} # tensor_name → (node_idx, node)
298
+ tensor_to_shape = {} # tensor_name → (channels, h, w) or (channels, L) or (features,)
299
+ tensor_to_scale = {} # tensor_name → act_scale (CRITICAL for Add rescale!)
300
+
301
+ # Detect Conv1d mode
302
+ is_conv1d_model = (config.input_w == 1)
303
+
304
+ # Input tensor
305
+ input_name = graph.input[0].name
306
+ if is_conv1d_model:
307
+ # Conv1d: (C, L)
308
+ tensor_to_shape[input_name] = (config.input_channels, config.input_h)
309
+ else:
310
+ # Conv2d: (C, H, W)
311
+ tensor_to_shape[input_name] = (config.input_channels, config.input_h, config.input_w)
312
+ tensor_to_producer[input_name] = (-1, None) # -1 = input
313
+ tensor_to_scale[input_name] = 1.0 # Input scale
314
+
315
+ # First pass: collect all node outputs and their producers
316
+ for node_idx, node in enumerate(graph.node):
317
+ for out_name in node.output:
318
+ tensor_to_producer[out_name] = (node_idx, node)
319
+
320
+ # Build tensor→consumers mapping for forward tracing
321
+ tensor_to_consumers = {} # tensor_name → [(node_idx, node), ...]
322
+ for node_idx, node in enumerate(graph.node):
323
+ for in_name in node.input:
324
+ if in_name not in tensor_to_consumers:
325
+ tensor_to_consumers[in_name] = []
326
+ tensor_to_consumers[in_name].append((node_idx, node))
327
+
328
+ def trace_forward_to_relu(tensor_name: str) -> bool:
329
+ """순방향 추적: 텐서에서 시작해서 Relu가 나오는지 확인.
330
+
331
+ Conv → Mul → Round → Clip → Div → Relu 체인에서
332
+ Conv 출력 텐서에서 시작해서 Relu까지 따라감.
333
+
334
+ 중간 노드(Mul, Round, Clip, Div)는 pass through.
335
+ """
336
+ visited = set()
337
+ current = tensor_name
338
+
339
+ while current and current not in visited:
340
+ visited.add(current)
341
+
342
+ if current not in tensor_to_consumers:
343
+ return False
344
+
345
+ consumers = tensor_to_consumers[current]
346
+ if not consumers:
347
+ return False
348
+
349
+ # 첫 번째 consumer만 따라감 (linear chain 가정)
350
+ _, node = consumers[0]
351
+ op_type = node.op_type
352
+
353
+ if op_type == 'Relu':
354
+ return True
355
+
356
+ # Pass through 노드: 출력을 따라감
357
+ if op_type in ('Mul', 'Round', 'Clip', 'Div', 'Cast', 'Add', 'Floor'):
358
+ if node.output:
359
+ current = node.output[0]
360
+ else:
361
+ return False
362
+ else:
363
+ # 다른 노드 만나면 중단
364
+ return False
365
+
366
+ return False
367
+
368
+ def trace_to_layer(tensor_name: str, tensor_to_layer_idx: dict) -> int:
369
+ """역추적: 텐서에서 시작해서 원래 PoT 레이어까지 추적.
370
+
371
+ ONNX 그래프에서 Conv → Mul → Round → Clip → Div → Relu 같은
372
+ 체인이 있을 때, 어떤 텐서에서든 원래 Conv 레이어를 찾아냄.
373
+
374
+ 범용성: 어떤 중간 노드가 있든 첫 번째 입력을 따라감.
375
+ """
376
+ visited = set()
377
+ current = tensor_name
378
+
379
+ while current and current not in visited:
380
+ # 이미 매핑된 텐서면 바로 반환
381
+ if current in tensor_to_layer_idx:
382
+ return tensor_to_layer_idx[current]
383
+
384
+ visited.add(current)
385
+
386
+ # 이 텐서를 생성한 노드 찾기
387
+ if current in tensor_to_producer:
388
+ node_idx, node = tensor_to_producer[current]
389
+ if node is None:
390
+ # 입력 텐서
391
+ return -1
392
+ if len(node.input) > 0:
393
+ # 첫 번째 입력으로 역추적
394
+ current = node.input[0]
395
+ else:
396
+ break
397
+ else:
398
+ break
399
+
400
+ return -1 # 못 찾음
401
+
402
+ def trace_to_scale(tensor_name: str) -> float:
403
+ """텐서의 activation scale 추적.
404
+
405
+ 텐서 체인을 따라가며 가장 최근 scale 값을 찾음.
406
+ MaxPool, Flatten 등은 입력 scale을 그대로 전파.
407
+ """
408
+ visited = set()
409
+ current = tensor_name
410
+
411
+ while current and current not in visited:
412
+ if current in tensor_to_scale:
413
+ return tensor_to_scale[current]
414
+
415
+ visited.add(current)
416
+
417
+ if current in tensor_to_producer:
418
+ node_idx, node = tensor_to_producer[current]
419
+ if node is None:
420
+ return 1.0 # Input
421
+ if len(node.input) > 0:
422
+ current = node.input[0]
423
+ else:
424
+ break
425
+ else:
426
+ break
427
+
428
+ return 1.0 # Default
429
+
430
+ # ========================================
431
+ # Step 2: Process nodes in order
432
+ # ========================================
433
+ current_h = config.input_h
434
+ current_w = config.input_w
435
+ current_ch = config.input_channels
436
+
437
+ layer_infos = []
438
+ pot_index = 0
439
+ num_pot_layers = len(pot_layer_infos)
440
+
441
+ # Track which layer produces which tensor (for skip connections)
442
+ tensor_to_layer_idx = {input_name: -1} # -1 = input
443
+
444
+ for node_idx, node in enumerate(graph.node):
445
+ op_type = node.op_type
446
+ output_name = node.output[0] if node.output else None
447
+
448
+ if op_type == 'Conv':
449
+ pot_info = pot_layer_infos[pot_index] if pot_index < num_pot_layers else None
450
+
451
+ if pot_info and pot_info['type'] == 'conv1d':
452
+ # ===== Conv1D Processing =====
453
+ weight = pot_info['weight']
454
+ bias = pot_info['bias']
455
+ out_channels = pot_info['out_channels']
456
+ in_channels = pot_info['in_channels']
457
+ alpha = pot_info['alpha']
458
+ act_scale = pot_info['act_scale']
459
+
460
+ kL = pot_info['kernel_size']
461
+ stride = pot_info['stride']
462
+ padding = pot_info['padding']
463
+
464
+ # For 1D: current_h represents length, current_w is 1
465
+ in_L = current_h # Use current_h as length
466
+ out_L = (in_L + 2 * padding - kL) // stride + 1
467
+
468
+ layer_idx = len(layer_infos)
469
+
470
+ has_relu = trace_forward_to_relu(output_name) if output_name else False
471
+
472
+ info = {
473
+ 'name': f'layer_{layer_idx}',
474
+ 'type': 'PoTConv1d',
475
+ 'layer_type': 'pot',
476
+ 'weight': weight,
477
+ 'bias': bias,
478
+ 'alpha': alpha,
479
+ 'act_scale': act_scale,
480
+ 'in_channels': in_channels,
481
+ 'out_channels': out_channels,
482
+ 'kernel_size': kL,
483
+ 'stride': stride,
484
+ 'padding': padding,
485
+ 'in_L': in_L,
486
+ 'out_L': out_L,
487
+ # Compatibility with 2D: for buffer size calculation
488
+ 'in_h': in_L,
489
+ 'in_w': 1,
490
+ 'out_h': out_L,
491
+ 'out_w': 1,
492
+ 'is_first': (pot_index == 0),
493
+ 'is_last': (pot_index == num_pot_layers - 1),
494
+ 'has_relu': has_relu,
495
+ 'encoding': pot_info.get('encoding', 'unroll'),
496
+ }
497
+ layer_infos.append(info)
498
+
499
+ if output_name:
500
+ tensor_to_layer_idx[output_name] = layer_idx
501
+ tensor_to_shape[output_name] = (out_channels, out_L)
502
+ if act_scale is not None:
503
+ tensor_to_scale[output_name] = act_scale
504
+
505
+ current_h, current_w, current_ch = out_L, 1, out_channels
506
+ pot_index += 1
507
+
508
+ print(f" Conv1d: {in_channels}x{in_L} -> {out_channels}x{out_L}")
509
+
510
+ elif pot_info and pot_info['type'] in ('conv', 'depthwise'):
511
+ # ===== Conv2D Processing =====
512
+ weight = pot_info['weight']
513
+ bias = pot_info['bias']
514
+ out_channels = pot_info['out_channels']
515
+ in_channels = pot_info['in_channels']
516
+ alpha = pot_info['alpha']
517
+ act_scale = pot_info['act_scale'] # Can be None
518
+ is_depthwise = (pot_info['type'] == 'depthwise')
519
+
520
+ # Use kernel_size, stride, padding from PyTorch model
521
+ kh_tuple = pot_info['kernel_size']
522
+ stride_tuple = pot_info['stride']
523
+ padding_tuple = pot_info['padding']
524
+
525
+ # Normalize to (h, w) tuples
526
+ if isinstance(kh_tuple, int): kh_tuple = (kh_tuple, kh_tuple)
527
+ if isinstance(stride_tuple, int): stride_tuple = (stride_tuple, stride_tuple)
528
+ if isinstance(padding_tuple, int): padding_tuple = (padding_tuple, padding_tuple)
529
+
530
+ kh, kw = kh_tuple
531
+ sh, sw = stride_tuple
532
+ ph, pw = padding_tuple
533
+
534
+ # [Support Conv1d]
535
+ # If weight is 3D (Out, In, Length), reshape to 4D (Out, In, 1, Length)
536
+ # to satisfy generic Conv2d generators.
537
+ if weight.ndim == 3:
538
+ weight = weight.unsqueeze(2) # (N, C, L) -> (N, C, 1, L)
539
+ # For Conv1d, kh is usually tuple (k,) or int k
540
+ # If it was 1D kernel, now it's effectively 1xW.
541
+ # We might need to ensure kh/kw are correct?
542
+ # The generator reads shape from w_q.shape usually.
543
+ pass
544
+ else:
545
+ raise RuntimeError(f"PoT layer mismatch at index {pot_index}")
546
+ input_tensor_name = node.input[0]
547
+ if input_tensor_name in tensor_to_shape:
548
+ input_shape = tensor_to_shape[input_tensor_name]
549
+ in_h = input_shape[1]
550
+ in_w = input_shape[2] if len(input_shape) > 2 else 1
551
+ else:
552
+ in_h, in_w = current_h, current_w # fallback
553
+ # Continue only for Conv2D (Conv1D handled above and continues)
554
+ if pot_info and pot_info['type'] in ('conv', 'depthwise'):
555
+ out_h = (in_h + 2 * ph - kh) // sh + 1
556
+ out_w = (in_w + 2 * pw - kw) // sw + 1
557
+
558
+ # Determine layer type
559
+ if is_depthwise:
560
+ layer_type_name = 'PoTDepthwiseConv2d'
561
+ else:
562
+ layer_type_name = 'PoTConv2d'
563
+
564
+ layer_idx = len(layer_infos)
565
+
566
+ # Check if ReLU follows this Conv (via ONNX chain)
567
+ has_relu = trace_forward_to_relu(output_name) if output_name else False
568
+
569
+ info = {
570
+ 'name': f'layer_{layer_idx}',
571
+ 'type': layer_type_name,
572
+ 'layer_type': 'pot',
573
+ 'weight': weight,
574
+ 'bias': bias,
575
+ 'alpha': alpha,
576
+ 'act_scale': act_scale, # Keep None as None
577
+ 'in_channels': in_channels,
578
+ 'out_channels': out_channels,
579
+ 'kernel_size': kh,
580
+ 'stride': sh if sh == sw else (sh, sw),
581
+ 'padding': ph if ph == pw else (ph, pw),
582
+ 'in_h': in_h,
583
+ 'in_w': in_w,
584
+ 'out_h': out_h,
585
+ 'out_w': out_w,
586
+ 'is_first': (pot_index == 0),
587
+ 'is_last': (pot_index == num_pot_layers - 1),
588
+ 'has_relu': has_relu,
589
+ 'encoding': pot_info.get('encoding', 'unroll'),
590
+ 'groups': out_channels if is_depthwise else pot_info.get('groups', 1),
591
+ }
592
+ layer_infos.append(info)
593
+
594
+ # Track output tensor
595
+ if output_name:
596
+ tensor_to_layer_idx[output_name] = layer_idx
597
+ tensor_to_shape[output_name] = (out_channels, out_h, out_w)
598
+ # Track scale for Add rescale calculation
599
+ if act_scale is not None:
600
+ tensor_to_scale[output_name] = act_scale
601
+
602
+ current_h, current_w, current_ch = out_h, out_w, out_channels
603
+ pot_index += 1
604
+
605
+ layer_desc = 'DepthwiseConv' if is_depthwise else 'Conv'
606
+ print(f" {layer_desc}: {in_channels}x{info['in_h']}x{info['in_w']} -> {out_channels}x{out_h}x{out_w}")
607
+
608
+ elif op_type == 'MaxPool':
609
+ kernel_shape = get_attribute(node, 'kernel_shape', [2, 2])
610
+ strides = get_attribute(node, 'strides', [2, 2])
611
+
612
+ # Detect 1D vs 2D based on kernel_shape length or is_conv1d_model
613
+ is_maxpool_1d = (len(kernel_shape) == 1) or is_conv1d_model
614
+
615
+ k = kernel_shape[0]
616
+ s = strides[0]
617
+
618
+ if is_maxpool_1d:
619
+ # MaxPool1d: only update h (length)
620
+ out_h = current_h // s
621
+ out_w = 1 # Keep as 1 for 1D
622
+ pool_type = 'MaxPool1d'
623
+ else:
624
+ # MaxPool2d
625
+ out_h = current_h // s
626
+ out_w = current_w // s
627
+ pool_type = 'MaxPool2d'
628
+
629
+ layer_idx = len(layer_infos)
630
+ info = {
631
+ 'name': f'layer_{layer_idx}',
632
+ 'type': pool_type,
633
+ 'layer_type': 'maxpool',
634
+ 'kernel_size': k,
635
+ 'stride': s,
636
+ 'in_h': current_h,
637
+ 'in_w': current_w,
638
+ 'in_channels': current_ch,
639
+ 'out_h': out_h,
640
+ 'out_w': out_w,
641
+ 'is_1d': is_maxpool_1d,
642
+ }
643
+ layer_infos.append(info)
644
+
645
+ # Track output tensor
646
+ if output_name:
647
+ tensor_to_layer_idx[output_name] = layer_idx
648
+ if is_maxpool_1d:
649
+ tensor_to_shape[output_name] = (current_ch, out_h)
650
+ else:
651
+ tensor_to_shape[output_name] = (current_ch, out_h, out_w)
652
+ # Propagate scale from input (MaxPool doesn't change scale)
653
+ input_name = node.input[0] if len(node.input) > 0 else None
654
+ if input_name:
655
+ input_scale = trace_to_scale(input_name)
656
+ tensor_to_scale[output_name] = input_scale
657
+
658
+ current_h, current_w = out_h, out_w
659
+
660
+ if is_maxpool_1d:
661
+ print(f" MaxPool1d: {current_ch}x{info['in_h']} -> {current_ch}x{out_h}")
662
+ else:
663
+ print(f" MaxPool2d: {current_ch}x{info['in_h']}x{info['in_w']} -> {current_ch}x{out_h}x{out_w}")
664
+
665
+ elif op_type in ('GlobalAveragePool', 'ReduceMean'):
666
+ # GlobalAveragePool: C×H×W → C×1×1
667
+ # ReduceMean with axes=[2,3]: same effect
668
+
669
+ # Check if it's ReduceMean over spatial dims
670
+ if op_type == 'ReduceMean':
671
+ axes = get_attribute(node, 'axes', None)
672
+
673
+ # axes might be in input[1] instead of attribute (newer ONNX)
674
+ if axes is None and len(node.input) > 1:
675
+ axes_name = node.input[1]
676
+ if axes_name in initializers:
677
+ axes = initializers[axes_name].tolist()
678
+
679
+ # axes=[2,3] or axes=[-2,-1] means spatial reduction
680
+ if axes is None:
681
+ continue
682
+ if set(axes) not in ({2, 3}, {-2, -1}):
683
+ continue
684
+
685
+ layer_idx = len(layer_infos)
686
+ pool_size = current_h * current_w
687
+
688
+ # Compute div_mult and div_shift for integer division
689
+ if pool_size > 0 and (pool_size & (pool_size - 1)) == 0:
690
+ # Power of 2
691
+ import math
692
+ div_mult = 1
693
+ div_shift = int(math.log2(pool_size))
694
+ else:
695
+ # Not power of 2: avg ≈ (sum * div_mult) >> div_shift
696
+ import math
697
+ base_shift = 15
698
+ div_mult = round((1 << base_shift) / pool_size)
699
+ while div_mult > 255 and base_shift > 8:
700
+ base_shift -= 1
701
+ div_mult = round((1 << base_shift) / pool_size)
702
+ div_shift = base_shift
703
+ div_mult = max(1, min(65535, div_mult))
704
+
705
+ info = {
706
+ 'name': f'layer_{layer_idx}',
707
+ 'type': 'GlobalAvgPool',
708
+ 'layer_type': 'global_avg_pool',
709
+ 'in_h': current_h,
710
+ 'in_w': current_w,
711
+ 'in_channels': current_ch,
712
+ 'out_channels': current_ch,
713
+ 'pool_size': pool_size,
714
+ 'div_mult': div_mult,
715
+ 'div_shift': div_shift,
716
+ }
717
+ layer_infos.append(info)
718
+
719
+ # Track output tensor
720
+ if output_name:
721
+ tensor_to_layer_idx[output_name] = layer_idx
722
+ tensor_to_shape[output_name] = (current_ch, 1, 1)
723
+
724
+ # After global avg pool: H=1, W=1
725
+ print(f" GlobalAvgPool: {current_ch}x{current_h}x{current_w} -> {current_ch} (pool_size={pool_size}, mult={div_mult}, shift={div_shift})")
726
+ current_h, current_w = 1, 1
727
+
728
+ elif op_type in ('Flatten', 'Reshape'):
729
+ shape = None
730
+ if op_type == 'Reshape' and len(node.input) > 1:
731
+ shape_name = node.input[1]
732
+ if shape_name in initializers:
733
+ shape = initializers[shape_name].tolist()
734
+
735
+ if op_type == 'Flatten' or (shape and len(shape) == 2):
736
+ out_features = current_h * current_w * current_ch
737
+
738
+ layer_idx = len(layer_infos)
739
+ info = {
740
+ 'name': f'layer_{layer_idx}',
741
+ 'type': 'Flatten',
742
+ 'layer_type': 'flatten',
743
+ 'in_h': current_h,
744
+ 'in_w': current_w,
745
+ 'in_channels': current_ch,
746
+ 'out_features': out_features,
747
+ }
748
+ layer_infos.append(info)
749
+
750
+ # Track output tensor
751
+ if output_name:
752
+ tensor_to_layer_idx[output_name] = layer_idx
753
+
754
+ print(f" Flatten: {current_ch}x{current_h}x{current_w} -> {out_features}")
755
+ current_h, current_w = 1, 1
756
+
757
+ elif op_type == 'Gemm':
758
+ pot_info = pot_layer_infos[pot_index] if pot_index < num_pot_layers else None
759
+
760
+ if pot_info and pot_info['type'] == 'linear':
761
+ weight = pot_info['weight']
762
+ bias = pot_info['bias']
763
+ in_features = pot_info['in_features']
764
+ out_features = pot_info['out_features']
765
+ alpha = pot_info['alpha']
766
+ act_scale = pot_info['act_scale'] # Can be None
767
+ else:
768
+ raise RuntimeError(f"PoT layer mismatch at index {pot_index}")
769
+
770
+ layer_idx = len(layer_infos)
771
+
772
+ # Check if ReLU follows this Linear (via ONNX chain)
773
+ has_relu = trace_forward_to_relu(output_name) if output_name else False
774
+
775
+ info = {
776
+ 'name': f'layer_{layer_idx}',
777
+ 'type': 'PoTLinear',
778
+ 'layer_type': 'pot',
779
+ 'weight': weight,
780
+ 'bias': bias,
781
+ 'alpha': alpha,
782
+ 'act_scale': act_scale, # Keep None as None
783
+ 'in_features': in_features,
784
+ 'out_features': out_features,
785
+ 'is_first': (pot_index == 0),
786
+ 'is_last': (pot_index == num_pot_layers - 1),
787
+ 'has_relu': has_relu,
788
+ 'encoding': pot_info.get('encoding', 'unroll'),
789
+ }
790
+ layer_infos.append(info)
791
+
792
+ # Track output tensor
793
+ if output_name:
794
+ tensor_to_layer_idx[output_name] = layer_idx
795
+ # Track scale (Linear is usually last, but just in case)
796
+ if act_scale is not None:
797
+ tensor_to_scale[output_name] = act_scale
798
+
799
+ current_ch = out_features
800
+ pot_index += 1
801
+
802
+ print(f" Linear: {in_features} -> {out_features}")
803
+
804
+ elif op_type == 'Relu':
805
+ # ReLU: pass through tensor mapping and scale
806
+ if len(node.input) > 0 and len(node.output) > 0:
807
+ relu_input = node.input[0]
808
+ relu_output = node.output[0]
809
+ if relu_input in tensor_to_layer_idx:
810
+ tensor_to_layer_idx[relu_output] = tensor_to_layer_idx[relu_input]
811
+ if relu_input in tensor_to_shape:
812
+ tensor_to_shape[relu_output] = tensor_to_shape[relu_input]
813
+ # Propagate scale
814
+ input_scale = trace_to_scale(relu_input)
815
+ tensor_to_scale[relu_output] = input_scale
816
+
817
+ elif op_type in ('Mul', 'Round', 'Clip', 'Div'):
818
+ # 양자화 관련 중간 노드: 텐서 매핑 및 scale 전파
819
+ if len(node.input) > 0 and len(node.output) > 0:
820
+ node_input = node.input[0]
821
+ node_output = node.output[0]
822
+ if node_input in tensor_to_layer_idx:
823
+ tensor_to_layer_idx[node_output] = tensor_to_layer_idx[node_input]
824
+ if node_input in tensor_to_shape:
825
+ tensor_to_shape[node_output] = tensor_to_shape[node_input]
826
+ # Propagate scale
827
+ input_scale = trace_to_scale(node_input)
828
+ tensor_to_scale[node_output] = input_scale
829
+
830
+ elif op_type == 'Add':
831
+ # ========================================
832
+ # ONNX Add 노드 분류 (robust version)
833
+ #
834
+ # ONNX에서 Add 노드가 생기는 경우:
835
+ # 1. Bias addition: Conv/Linear의 bias (한 입력이 initializer)
836
+ # 2. Skip connection: x + conv(x) (두 입력 모두 레이어 출력)
837
+ # 3. 기타 산술 연산
838
+ #
839
+ # 판단 기준:
840
+ # - 한 입력이 initializer → bias add → 무시
841
+ # - 두 입력의 소스 레이어가 같음 → 무시
842
+ # - 두 입력의 소스 레이어가 다름 → skip connection
843
+ # ========================================
844
+
845
+ input_a = node.input[0] if len(node.input) > 0 else None
846
+ input_b = node.input[1] if len(node.input) > 1 else None
847
+
848
+ # ------------------------------------------
849
+ # Case 1: 한 입력이 initializer (상수/bias)
850
+ # ------------------------------------------
851
+ a_is_initializer = input_a in initializers
852
+ b_is_initializer = input_b in initializers
853
+
854
+ if a_is_initializer or b_is_initializer:
855
+ # Bias addition - NOT a skip connection
856
+ # Just propagate scale, shape, and layer index
857
+ print(f" [DEBUG] Bias Add detected (input is initializer), skipping")
858
+ if output_name:
859
+ non_const_input = input_b if a_is_initializer else input_a
860
+
861
+ # Propagate scale
862
+ input_scale = trace_to_scale(non_const_input)
863
+ tensor_to_scale[output_name] = input_scale
864
+
865
+ # Propagate shape by tracing back
866
+ traced_shape = None
867
+ trace_name = non_const_input
868
+ trace_visited = set()
869
+ while trace_name and trace_name not in trace_visited:
870
+ trace_visited.add(trace_name)
871
+ if trace_name in tensor_to_shape:
872
+ traced_shape = tensor_to_shape[trace_name]
873
+ break
874
+ if trace_name in tensor_to_producer:
875
+ _, prod_node = tensor_to_producer[trace_name]
876
+ if prod_node and prod_node.input:
877
+ trace_name = prod_node.input[0]
878
+ else:
879
+ break
880
+ else:
881
+ break
882
+ if traced_shape:
883
+ tensor_to_shape[output_name] = traced_shape
884
+
885
+ # Propagate layer index (this Add is pass-through)
886
+ traced_layer = trace_to_layer(non_const_input, tensor_to_layer_idx)
887
+ if traced_layer >= 0:
888
+ tensor_to_layer_idx[output_name] = traced_layer
889
+ continue
890
+
891
+ # ------------------------------------------
892
+ # Case 2: 둘 다 텐서 - 소스 레이어 분석
893
+ # ------------------------------------------
894
+ source_a = trace_to_layer(input_a, tensor_to_layer_idx)
895
+ source_b = trace_to_layer(input_b, tensor_to_layer_idx)
896
+
897
+ print(f" [DEBUG] Add node: source_a={source_a}, source_b={source_b}")
898
+
899
+ # 같은 소스에서 오면 skip connection 아님
900
+ if source_a == source_b:
901
+ print(f" [DEBUG] Same source ({source_a}), not a skip connection - skipping")
902
+ if output_name:
903
+ input_scale = trace_to_scale(input_a)
904
+ tensor_to_scale[output_name] = input_scale
905
+ # Propagate shape
906
+ if input_a in tensor_to_shape:
907
+ tensor_to_shape[output_name] = tensor_to_shape[input_a]
908
+ # Propagate layer index
909
+ if source_a >= 0:
910
+ tensor_to_layer_idx[output_name] = source_a
911
+ continue
912
+
913
+ # Case 2.5: 한쪽이 -1이면 (입력 텐서에서 직접) skip connection 아님
914
+ if source_a < 0 or source_b < 0:
915
+ print(f" [DEBUG] One source is input tensor ({source_a}, {source_b}), not a skip connection - skipping")
916
+ if output_name:
917
+ # 유효한 소스의 정보를 전파
918
+ valid_source = max(source_a, source_b)
919
+ valid_tensor = input_a if source_a >= 0 else input_b
920
+ input_scale = trace_to_scale(valid_tensor)
921
+ tensor_to_scale[output_name] = input_scale
922
+ if valid_tensor in tensor_to_shape:
923
+ tensor_to_shape[output_name] = tensor_to_shape[valid_tensor]
924
+ if valid_source >= 0:
925
+ tensor_to_layer_idx[output_name] = valid_source
926
+ continue
927
+
928
+ # ------------------------------------------
929
+ # Case 3: 다른 소스 = 실제 skip connection!
930
+ # ------------------------------------------
931
+ print(f" [DEBUG] Skip connection detected: {source_a} vs {source_b}")
932
+
933
+ # skip은 더 오래된(더 작은 인덱스) 레이어에서 온 것
934
+ if source_a < source_b:
935
+ skip_source = source_a
936
+ conv_source = source_b
937
+ skip_tensor = input_a
938
+ conv_tensor = input_b
939
+ else:
940
+ skip_source = source_b
941
+ conv_source = source_a
942
+ skip_tensor = input_b
943
+ conv_tensor = input_a
944
+
945
+ # Count how many PoTAdd layers we've already processed
946
+ add_count = sum(1 for l in layer_infos if l.get('type') == 'PoTAdd')
947
+
948
+ # ========================================
949
+ # Rescale 파라미터 결정
950
+ # ========================================
951
+
952
+ # Option 1: PyTorch PoTAdd 모듈에서 calibration 값 사용
953
+ if add_count < len(add_layer_infos) and add_layer_infos[add_count]['scale_x'] is not None:
954
+ skip_act_scale = add_layer_infos[add_count]['scale_x']
955
+ conv_act_scale = add_layer_infos[add_count]['scale_y']
956
+ rescale_mult = add_layer_infos[add_count]['rescale_mult']
957
+ rescale_shift = add_layer_infos[add_count]['rescale_shift']
958
+ print(f" Add rescale (from PoTAdd calibration): skip_scale={skip_act_scale:.4f}, "
959
+ f"conv_scale={conv_act_scale:.4f}, mult={rescale_mult}, shift={rescale_shift}")
960
+
961
+ # Option 2: PoTAdd 없이 skip connection 구현된 경우 - ONNX scale 추적
962
+ else:
963
+ skip_act_scale = trace_to_scale(skip_tensor)
964
+ conv_act_scale = trace_to_scale(conv_tensor)
965
+
966
+ # Calculate rescale ratio: convert skip scale to conv scale
967
+ # skip_aligned = skip * ratio 해서 conv와 같은 scale로 맞춤
968
+ if skip_act_scale != 0 and skip_act_scale != conv_act_scale:
969
+ ratio = conv_act_scale / skip_act_scale
970
+ else:
971
+ ratio = 1.0
972
+
973
+ # Convert to integer arithmetic: mult * x >> shift ≈ ratio * x
974
+ rescale_shift = 7
975
+ rescale_mult = round(ratio * (1 << rescale_shift))
976
+ rescale_mult = max(1, min(rescale_mult, 512)) # Clamp to valid range
977
+
978
+ if add_count >= len(add_layer_infos):
979
+ print(f" Add rescale (no PoTAdd module - direct skip): "
980
+ f"skip_scale={skip_act_scale:.4f}, conv_scale={conv_act_scale:.4f}, "
981
+ f"ratio={ratio:.4f}, mult={rescale_mult}, shift={rescale_shift}")
982
+ else:
983
+ print(f" Add rescale (fallback - PoTAdd uncalibrated): "
984
+ f"skip_scale={skip_act_scale:.4f}, conv_scale={conv_act_scale:.4f}, "
985
+ f"ratio={ratio:.4f}, mult={rescale_mult}, shift={rescale_shift}")
986
+
987
+ layer_idx = len(layer_infos)
988
+
989
+ # Check if ReLU follows this Add (via ONNX chain)
990
+ has_relu = trace_forward_to_relu(output_name) if output_name else False
991
+
992
+ info = {
993
+ 'name': f'layer_{layer_idx}',
994
+ 'type': 'PoTAdd',
995
+ 'layer_type': 'add',
996
+ 'in_h': current_h,
997
+ 'in_w': current_w,
998
+ 'in_channels': current_ch,
999
+ 'out_h': current_h,
1000
+ 'out_w': current_w,
1001
+ 'out_channels': current_ch,
1002
+ 'rescale_mult': rescale_mult,
1003
+ 'rescale_shift': rescale_shift,
1004
+ 'skip_source_layer': skip_source, # skip이 시작된 레이어
1005
+ 'conv_source_layer': conv_source, # conv 경로의 마지막 레이어
1006
+ 'act_scale': conv_act_scale, # Add output uses conv's scale
1007
+ 'has_relu': has_relu,
1008
+ }
1009
+ layer_infos.append(info)
1010
+
1011
+ # Track output tensor and its scale
1012
+ if output_name:
1013
+ tensor_to_layer_idx[output_name] = layer_idx
1014
+ tensor_to_shape[output_name] = (current_ch, current_h, current_w)
1015
+ tensor_to_scale[output_name] = conv_act_scale
1016
+
1017
+ print(f" Add: {current_ch}x{current_h}x{current_w} (skip from layer_{skip_source}, conv from layer_{conv_source})")
1018
+
1019
+ return layer_infos
1020
+
1021
+
1022
def get_attribute(node, name: str, default=None):
    """Return the value of the named attribute from an ONNX node.

    Handles the INT/INTS/FLOAT/FLOATS attribute kinds; an attribute with a
    matching name but any other kind is ignored and the search continues.
    Returns *default* when no usable attribute is found.
    """
    for attr in node.attribute:
        if attr.name != name:
            continue
        kind = attr.type
        if kind == onnx.AttributeProto.INTS:
            return list(attr.ints)
        if kind == onnx.AttributeProto.INT:
            return attr.i
        if kind == onnx.AttributeProto.FLOATS:
            return list(attr.floats)
        if kind == onnx.AttributeProto.FLOAT:
            return attr.f
    return default
1035
+
1036
+
1037
def absorb_standardization(layer_infos: List[Dict], mean: List[float], std: List[float]):
    """Fold input standardization ((x - mean) / std) into the first PoT layer.

    CRITICAL: uses the channel-averaged std everywhere to match QAT exactly:
    - mean is absorbed into the bias: b' = b - Σ_c (mean[c]/avg_std) × ΣW[:,c,...] × α
    - /std is absorbed later into combined_scale via info['input_std']

    Args:
        layer_infos: List of layer info dicts (the first PoT layer is mutated).
        mean: Per-channel mean values (e.g. [0.4914, 0.4822, 0.4465] for CIFAR-10).
        std: Per-channel std values (e.g. [0.2470, 0.2435, 0.2616] for CIFAR-10).

    Raises:
        ValueError: If mean/std length does not match the first layer's
            input channel count.

    Note: bias scaling by act_scale is done separately in scale_biases().
    """
    avg_std = sum(std) / len(std) if std else 1.0

    for info in layer_infos:
        if info['layer_type'] != 'pot' or not info.get('is_first', False):
            continue

        # info['weight'] is already PoT-quantized, so it is used directly.
        w_q = info['weight']
        if w_q is None:
            return

        n_in = w_q.shape[1]
        if len(mean) != n_in or len(std) != n_in:
            raise ValueError(
                f"mean/std length ({len(mean)}/{len(std)}) must match "
                f"first layer's in_channels ({n_in})"
            )

        bias = info['bias']
        if bias is None:
            bias = torch.zeros(w_q.shape[0])

        alpha = info['alpha']
        is_conv1d = info['type'] == 'PoTConv1d'

        # Channel-wise bias adjustment using avg_std:
        # b' = b - Σ_c (mean[c]/avg_std) × ΣW_q[:,c,...] × α
        for c in range(n_in):
            if is_conv1d:
                # Conv1D weight is [out_ch, in_ch, kernel] -> per-out-channel sum
                per_out_sum = w_q[:, c, :].sum(dim=1)
            else:
                # Conv2D weight is [out_ch, in_ch, kh, kw] -> per-out-channel sum
                per_out_sum = w_q[:, c, :, :].sum(dim=(1, 2))
            bias = bias - (mean[c] / avg_std) * per_out_sum * alpha

        info['bias'] = bias
        info['input_std'] = avg_std

        print(f" Standardization absorbed into {info['name']}: mean={mean}, avg_std={avg_std:.4f}")
        break
1097
+
1098
+
1099
def calculate_combined_scales(layer_infos: List[Dict], config: Config):
    """Compute the per-layer combined scale / shift used by the generated C.

    Per specification sections 6.3 and 9.x:
    - First layer:  combined_scale = α × act_scale / std and
      combined_shift = base_shift + 8 (the /256 of the uint8 input is absorbed)
    - Other layers: combined_scale = α × act_scale / prev_act_scale
    - Last layer (act_scale is None): combined_scale = α / prev_act_scale
      (raw logits, no output quantization)

    Each PoT layer's info dict gains 'combined_scale', 'scale_int',
    'combined_shift' and 'base_shift'. Add layers contribute no scale of
    their own but update the running previous activation scale.
    """
    prev_act_scale = 1.0

    for info in layer_infos:
        kind = info['layer_type']

        # CRITICAL: an Add layer changes the activation scale, so every
        # subsequent layer must rescale relative to the Add's output scale.
        if kind == 'add':
            add_act_scale = info.get('act_scale')
            if add_act_scale is not None:
                print(f" {info['name']}: Add layer, updating prev_act_scale to {add_act_scale:.4f}")
                prev_act_scale = add_act_scale
            continue

        if kind != 'pot':
            continue

        is_first = info.get('is_first', False)
        alpha = info['alpha']
        act_scale = info.get('act_scale')  # None on the last layer

        # Denominator: input std for the first layer, otherwise the
        # previous layer's activation scale.
        denom = info.get('input_std', 1.0) if is_first else prev_act_scale
        if act_scale is not None:
            combined_scale = alpha * act_scale / denom
        else:
            # No output quantization (last layer, or single-layer model).
            combined_scale = alpha / denom

        # Choose a power-of-two shift that lands |scale| roughly in
        # [64, 512] so the integer scale keeps enough precision.
        base_shift = 0
        magnitude = abs(combined_scale)
        while magnitude < 64 and base_shift < 24:
            magnitude *= 2
            base_shift += 1
        while magnitude > 512 and base_shift > 0:
            magnitude /= 2
            base_shift -= 1

        # First layer absorbs the /256 input normalization as 8 extra shifts.
        combined_shift = base_shift + 8 if is_first else base_shift
        scale_int = round(combined_scale * (1 << base_shift))

        info['combined_scale'] = combined_scale
        info['scale_int'] = scale_int
        info['combined_shift'] = combined_shift
        info['base_shift'] = base_shift

        print(f" {info['name']}: combined_scale={combined_scale:.6f}, scale_int={scale_int}, shift={combined_shift}, act_scale={act_scale}")

        # Carry this layer's activation scale forward.
        if act_scale is not None:
            prev_act_scale = act_scale
1178
+
1179
+
1180
def scale_biases(layer_infos: List[Dict]):
    """Convert each PoT layer's float bias into the quantized integer domain.

    CRITICAL: during PyTorch training the quantized output is
    (conv(x) + bias) × act_scale, so the bias must also be multiplied by
    act_scale before rounding. The last layer has act_scale=None and its
    bias is only rounded, never scaled (its output stays as raw logits).

    Rounding is floor(x + 0.5) to replicate round_half_up_ste from the
    PyTorch modules — torch.round's banker's rounding would drift by ±1.

    Each pot layer gains 'bias_scaled' (int32 tensor, or None for layers
    without a bias).
    """
    for info in layer_infos:
        if info['layer_type'] != 'pot':
            continue

        bias = info.get('bias')
        if bias is None:
            info['bias_scaled'] = None
            continue

        act_scale = info.get('act_scale')
        pre_round = bias if act_scale is None else bias * act_scale
        bias_int = torch.floor(pre_round + 0.5).to(torch.int32)
        info['bias_scaled'] = bias_int

        if act_scale is not None:
            print(f" {info['name']}: bias scaled by act_scale={act_scale:.2f}, range=[{bias_int.min()}, {bias_int.max()}]")
        else:
            print(f" {info['name']}: last layer, bias not scaled, range=[{bias_int.min()}, {bias_int.max()}]")
1215
+
1216
+
1217
def generate_header(layer_infos: List[Dict[str, Any]], config: Config, optimized: bool = True) -> str:
    """Generate complete C header file.

    Emits, in order: include guard and stdlib includes, a banner comment
    describing the target and input handling, one scale helper per PoT
    layer (the single MUL each layer performs), one forward function per
    layer, and the top-level predict function.

    Args:
        layer_infos: List of layer information dictionaries
        config: Configuration with MCU specs
        optimized: If True, use optimized code generation

    Returns:
        The full C header source as one string.
    """
    code = []

    # Include guard plus the only two headers the generated code needs
    # (<stdint.h> for fixed-width ints, <string.h> for memset).
    code.append("#ifndef POTNN_MODEL_H")
    code.append("#define POTNN_MODEL_H")
    code.append("")
    code.append("#include <stdint.h>")
    code.append("#include <string.h>") # for memset
    code.append("")

    # Banner comment: target sizes, input normalization note, input shape.
    code.append("/*")
    code.append(f" * PoT Quantized Neural Network")
    code.append(f" * Target: {config.ram}B RAM, {config.flash}B Flash")
    if config.mean is not None:
        # Format list values nicely
        mean_str = ', '.join(f'{m:.4f}' for m in config.mean)
        std_str = ', '.join(f'{s:.4f}' for s in config.std)
        code.append(f" * Input normalization: mean=[{mean_str}], std=[{std_str}]")
        code.append(f" * (absorbed into first layer: mean→bias, /std→scale, /256→shift)")
        code.append(f" * Input: uint8 [0,255] - no runtime normalization needed")
    if config.input_w == 1:
        # Conv1d: CxL format
        code.append(f" * Input size: {config.input_channels}x{config.input_h}")
    else:
        # Conv2d: CxHxW format
        code.append(f" * Input size: {config.input_channels}x{config.input_h}x{config.input_w}")
    if optimized:
        code.append(f" * Optimized: Full Pipeline (loop), Zero-Padding (unroll)")
    code.append(f" * Generated by potnn v8 (all bugs fixed)")
    code.append(" */")
    code.append("")

    # Generate scale functions (one combined_scale MUL helper per PoT layer)
    code.append("// Scale functions (combined_scale MUL, 1회/layer)")
    for info in layer_infos:
        if info['layer_type'] == 'pot':
            scale_func = generate_scale_function(info)
            code.append(scale_func)
    code.append("")

    # Generate layer forward functions, dispatched on layer_type.
    code.append("// Layer forward functions")
    for i, info in enumerate(layer_infos):
        if info['layer_type'] == 'pot':
            is_first = info.get('is_first', False)
            # CRITICAL FIX: Disable optimization for first layer due to bias generation bug (missing bias in Ch0)
            use_opt = optimized and not is_first
            layer_code = generate_pot_layer(info, is_first_layer=is_first, optimized=use_opt)
            code.append(layer_code)
        elif info['layer_type'] == 'maxpool':
            layer_code = generate_maxpool_layer(info)
            code.append(layer_code)
        elif info['layer_type'] == 'add':
            layer_code = generate_add_layer(info)
            code.append(layer_code)
        elif info['layer_type'] == 'global_avg_pool':
            layer_code = generate_global_avg_pool_layer(info)
            code.append(layer_code)

    # Generate main predict function
    code.append("// Main prediction function")
    code.append(generate_predict_function(layer_infos, config))

    code.append("#endif // POTNN_MODEL_H")

    return '\n'.join(code)
1290
+
1291
+
1292
def generate_scale_function(info: Dict) -> str:
    """Emit the per-layer C scale helper (a single MUL plus rounding shift).

    The helper computes round(x * combined_scale) in integer arithmetic as
    (x * scale_int + 2^(shift-1)) >> shift. shift == 0 and shift == 1 are
    special-cased so the generated C never contains the undefined (1 << -1)
    or a redundant rounding term.
    """
    name = info['name']
    scale_int = info['scale_int']
    shift = info['combined_shift']

    if shift == 0:
        ret_line = f" return (int32_t)((int64_t)x * {scale_int});"
    elif shift == 1:
        ret_line = f" return (int32_t)(((int64_t)x * {scale_int} + 1) >> 1);"
    else:
        ret_line = f" return (int32_t)(((int64_t)x * {scale_int} + (1 << {shift-1})) >> {shift});"

    lines = [
        f"// {name}: combined_scale={info['combined_scale']:.6f}, scale_int={scale_int}, shift={shift}",
        f"static inline int32_t scale_{name}(int32_t x) {{",
        ret_line,
        "}",
        "",
    ]
    return '\n'.join(lines)
1314
+
1315
+
1316
def generate_pot_layer(info: Dict, is_first_layer: bool = False, optimized: bool = True) -> str:
    """Dispatch C code generation for one PoT layer.

    Supported encoding modes:
    - 'unroll': weights embedded as shift-add operations (11 levels, fastest)
    - 'fp130': FP1.3.0 packed weights (16 levels, no zero)
    - '5level': skip-encoded weights (5 levels: -8,-1,0,+1,+8)
    - '2bit': minimal 2-bit encoding (4 levels: ±1,±2)
    - 'ternary': ternary encoding (3 levels: -1,0,+1)
    - 'loop': legacy packed-table mode, now deprecated and redirected to unroll

    When optimized=True the unroll path uses the fully optimized generator
    (zero-padding, no boundary checks); otherwise per-type reference
    generators are used.
    """
    layer_type = info['type']
    code_mode = info.get('code_mode', 'unroll')
    levels = info.get('levels', 11)
    encoding = info.get('encoding', 'unroll')

    # Debug: print encoding for each layer
    print(f" [CODEGEN] {info['name']}: encoding={encoding}, code_mode={code_mode}")

    # Non-PoT layers have fixed, encoding-independent implementations.
    if layer_type == 'PoTAdd':
        return generate_add_layer(info)
    if layer_type == 'GlobalAvgPool':
        return generate_global_avg_pool_layer(info)

    # Specialized weight encodings override code_mode entirely.
    if encoding == 'fp130':
        from .codegen.fp130 import generate_fp130_layer
        print(f" → Using FP1.3.0 encoder")
        return generate_fp130_layer(info)
    if encoding == '2bit':
        from .codegen.bit2 import generate_2bit_layer
        print(f" → Using 2-bit encoder")
        return generate_2bit_layer(info)
    if encoding == '5level':
        from .codegen.level5 import generate_5level_layer
        print(f" → Using 5-level encoder")
        return generate_5level_layer(info)
    if encoding == 'ternary':
        from .codegen.ternary import generate_ternary_layer
        print(f" → Using Ternary encoder")
        return generate_ternary_layer(info)

    # Legacy 'loop' mode: loop.py was removed, so only warn and fall through
    # to the unroll paths below.
    if code_mode == 'loop':
        print(f" [WARNING] 'loop' code_mode is deprecated and loop.py is removed. Falling back to 'unroll' mode.")

    if optimized:
        from .codegen.unroll import generate_unrolled_layer_optimized
        return generate_unrolled_layer_optimized(info, is_first_layer=is_first_layer)

    # Non-optimized reference generators, per layer type.
    if layer_type == 'PoTConv2d':
        return generate_conv_layer(info, is_first_layer)
    if layer_type == 'PoTConv1d':
        # Conv1d has no reference path: always use the optimized generator.
        from .codegen.unroll import generate_unrolled_layer_optimized
        return generate_unrolled_layer_optimized(info, is_first_layer=is_first_layer)
    if layer_type == 'PoTDepthwiseConv2d':
        return generate_depthwise_conv_layer(info, is_first_layer)
    if layer_type == 'PoTLinear':
        return generate_linear_layer(info, is_first_layer)

    return ""
1394
+
1395
+
1396
def generate_global_avg_pool_layer(info: Dict) -> str:
    """Generate C code for a Global Average Pooling layer: CxHxW -> C.

    The per-channel division by pool_size is done in integer arithmetic:
    - pool_size a power of two (div_mult == 1): rounding right shift,
      avg = (sum + 2^(div_shift-1)) >> div_shift
    - otherwise: fixed-point reciprocal, avg = (sum * div_mult) >> div_shift,
      clamped to the int8 range.

    Bug fix: div_shift == 0 (pool_size == 1, also the default when the
    div_* keys are absent) used to emit `1 << (div_shift - 1)` which is
    `1 << -1` and raises ValueError in Python. That case now emits a plain
    copy (dividing by 1 needs no rounding term), mirroring the shift==0
    special case in generate_scale_function.
    """
    name = info['name']
    channels = info['in_channels']
    h = info['in_h']
    w = info['in_w']
    pool_size = info.get('pool_size', h * w)
    div_mult = info.get('div_mult', 1)
    div_shift = info.get('div_shift', 0)

    code = []
    code.append(f"// {name} - Global Average Pooling: {channels}x{h}x{w} -> {channels}")

    if div_mult == 1:
        # Power of 2: shift only
        code.append(f"// pool_size={pool_size} (2^{div_shift}), using shift only")
        code.append(f"static void {name}_forward(const int8_t* input, int8_t* output) {{")
        code.append(f" for (int c = 0; c < {channels}; c++) {{")
        code.append(f" int32_t sum = 0;")
        code.append(f" for (int i = 0; i < {pool_size}; i++) {{")
        code.append(f" sum += input[c * {pool_size} + i];")
        code.append(f" }}")
        if div_shift == 0:
            # pool_size == 1: nothing to average, direct copy (no 1 << -1!)
            code.append(f" output[c] = (int8_t)sum;")
        else:
            code.append(f" output[c] = (int8_t)((sum + {1 << (div_shift - 1)}) >> {div_shift});")
        code.append(f" }}")
        code.append(f"}}")
    else:
        # Not power of 2: mult + shift
        code.append(f"// pool_size={pool_size}, div = (sum * {div_mult}) >> {div_shift}")
        code.append(f"static void {name}_forward(const int8_t* input, int8_t* output) {{")
        code.append(f" for (int c = 0; c < {channels}; c++) {{")
        code.append(f" int32_t sum = 0;")
        code.append(f" for (int i = 0; i < {pool_size}; i++) {{")
        code.append(f" sum += input[c * {pool_size} + i];")
        code.append(f" }}")
        code.append(f" int32_t avg = (sum * {div_mult}) >> {div_shift};")
        code.append(f" output[c] = (int8_t)(avg > 127 ? 127 : (avg < -128 ? -128 : avg));")
        code.append(f" }}")
        code.append(f"}}")

    code.append("")
    return '\n'.join(code)
1441
+
1442
+
1443
def generate_add_layer(info: Dict) -> str:
    """Emit C code for an Add layer (skip / residual connection).

    Aligns the skip path's activation scale to the conv path's with a
    compile-time integer multiply + shift:
        x_aligned = (x * rescale_mult) >> rescale_shift
        output    = clamp_int8(x_aligned + y)   [optionally after ReLU]
    The mult/shift pair is fixed at generation time, so the runtime code
    contains no floating point. mult=128, shift=7 encodes a ratio of
    exactly 1.0 and skips the rescale entirely.
    """
    name = info['name']
    channels = info['in_channels']
    h = info['in_h']
    w = info['in_w']
    mult = info.get('rescale_mult', 128)
    shift = info.get('rescale_shift', 7)
    relu = info.get('has_relu', False)
    n_elems = channels * h * w

    lines = [
        f"// {name} - Add (skip connection): {channels}x{h}x{w}",
        f"// rescale: x_aligned = (x * {mult}) >> {shift}",
    ]
    if relu:
        lines.append(f"// ReLU applied after add")
    lines += [
        f"static void {name}_forward(const int8_t* input_skip, const int8_t* input_conv, int8_t* output) {{",
        f" for (int i = 0; i < {n_elems}; i++) {{",
        f" int32_t x = (int32_t)input_skip[i];",
        f" int32_t y = (int32_t)input_conv[i];",
    ]

    # Rescale x onto y's scale; the default pair is an exact no-op.
    if (mult, shift) == (128, 7):
        lines.append(f" // ratio ≈ 1.0, skip rescale")
    else:
        lines.append(f" x = (x * {mult}) >> {shift};")

    lines.append(f" int32_t sum = x + y;")
    if relu:
        lines.append(f" if (sum < 0) sum = 0;")
    lines += [
        f" // Clamp to int8 range",
        f" output[i] = (int8_t)(sum > 127 ? 127 : (sum < -128 ? -128 : sum));",
        f" }}",
        f"}}",
        "",
    ]
    return '\n'.join(lines)
1492
+
1493
+
1494
def generate_depthwise_conv_layer(info: Dict, is_first_layer: bool = False) -> str:
    """Generate Depthwise Conv2d layer with PoT weights.

    Depthwise Conv: each channel is convolved independently with its own
    kernel; the weight shape is [channels, 1, kH, kW] and input/output
    channel counts are identical (channels = in_channels = out_channels).

    The kernel loop is fully unrolled at generation time. For each kernel
    position, channels whose weight shares the same shift amount and sign
    are folded into one shift-add; otherwise a C `switch (c)` selects the
    per-channel shift-add. Weights are already PoT*alpha, so only the
    power-of-two exponent (clamped to 0..4) and the sign are emitted.

    Args:
        info: Layer info dict (weight tensor, bias_scaled, geometry, flags).
        is_first_layer: If True, the C input type is uint8_t (raw [0,255]
            pixels; the /256 is absorbed into the layer's combined shift).

    Returns:
        C source for the layer's forward function as a string.
    """
    name = info['name']
    weight = info['weight']
    bias_scaled = info.get('bias_scaled')
    alpha = info['alpha']
    is_last = info.get('is_last', False)

    channels = info['in_channels']  # in_channels == out_channels for depthwise
    kh = info['kernel_size']
    kw = kh  # square kernel assumed
    stride = info['stride']
    padding = info['padding']
    in_h = info['in_h']
    in_w = info['in_w']
    out_h = info['out_h']
    out_w = info['out_w']

    input_type = "uint8_t" if is_first_layer else "int8_t"

    code = []
    code.append(f"// {name} - DepthwiseConv2d: {channels}x{in_h}x{in_w} -> {channels}x{out_h}x{out_w}")
    code.append(f"// Kernel: {kh}x{kw}, Stride: {stride}, Padding: {padding}, alpha={alpha:.4f}")
    if is_first_layer:
        code.append(f"// First layer: input is uint8 [0,255], /256 absorbed in shift")
    if is_last:
        code.append(f"// Last layer: no ReLU")
    code.append(f"static void {name}_forward(const {input_type}* input, int8_t* output) {{")
    code.append(f" int32_t acc;")
    code.append("")

    # Depthwise: every channel is processed independently.
    code.append(f" for (int c = 0; c < {channels}; c++) {{")
    code.append(f" for (int oy = 0; oy < {out_h}; oy++) {{")
    code.append(f" for (int ox = 0; ox < {out_w}; ox++) {{")
    code.append(f" acc = 0;")
    code.append("")

    # Kernel loop (unrolled at generation time).
    for ky in range(kh):
        for kx in range(kw):
            # weight index: [c, 0, ky, kx] → the c dimension is handled by a C switch
            iy_base = ky - padding
            ix_base = kx - padding

            # Input coordinate: iy = oy * stride + ky - padding
            if stride == 1:
                if iy_base == 0:
                    iy_expr = "oy"
                elif iy_base > 0:
                    iy_expr = f"oy + {iy_base}"
                else:
                    iy_expr = f"oy - {-iy_base}"
            else:
                if iy_base == 0:
                    iy_expr = f"oy * {stride}"
                elif iy_base > 0:
                    iy_expr = f"oy * {stride} + {iy_base}"
                else:
                    iy_expr = f"oy * {stride} - {-iy_base}"

            if stride == 1:
                if ix_base == 0:
                    ix_expr = "ox"
                elif ix_base > 0:
                    ix_expr = f"ox + {ix_base}"
                else:
                    ix_expr = f"ox - {-ix_base}"
            else:
                if ix_base == 0:
                    ix_expr = f"ox * {stride}"
                elif ix_base > 0:
                    ix_expr = f"ox * {stride} + {ix_base}"
                else:
                    ix_expr = f"ox * {stride} - {-ix_base}"

            # Boundary check: only emitted for kernel taps that can fall
            # outside the padded input.
            checks = []
            if stride == 1:
                if iy_base < 0:
                    checks.append(f"oy >= {-iy_base}")
                if iy_base > 0:
                    checks.append(f"oy + {iy_base} < {in_h}")
                if ix_base < 0:
                    checks.append(f"ox >= {-ix_base}")
                if ix_base > 0:
                    checks.append(f"ox + {ix_base} < {in_w}")
            else:
                if iy_base < 0:
                    checks.append(f"oy * {stride} >= {-iy_base}")
                if iy_base > 0:
                    checks.append(f"oy * {stride} + {iy_base} < {in_h}")
                if ix_base < 0:
                    checks.append(f"ox * {stride} >= {-ix_base}")
                if ix_base > 0:
                    checks.append(f"ox * {stride} + {ix_base} < {in_w}")

            # Input index: c * in_h * in_w + iy * in_w + ix (CHW layout)
            idx_expr = f"c * {in_h} * {in_w} + ({iy_expr}) * {in_w} + ({ix_expr})"

            # Generate per-channel weight switch for this kernel position.
            code.append(f" // Kernel position ({ky}, {kx})")

            # Collect (channel, shift, sign) for every non-zero weight here.
            non_zero_channels = []
            for c in range(channels):
                w = weight[c, 0, ky, kx].item()
                if abs(w) >= 0.5:  # Not zero
                    w_abs = abs(w)
                    k = round(np.log2(w_abs + 1e-10))
                    k = max(0, min(4, k))  # 2^k in {1, 2, 4, 8, 16}
                    sign = '+' if w > 0 else '-'
                    non_zero_channels.append((c, k, sign))

            if len(non_zero_channels) == 0:
                code.append(f" // All weights zero at this position")
                continue

            # If all non-zero channels share one (shift, sign) pair, and no
            # channel is zero, the switch collapses to a single shift-add.
            all_same = len(set((k, sign) for _, k, sign in non_zero_channels)) == 1

            if all_same and len(non_zero_channels) == channels:
                # All channels same: no switch needed
                k, sign = non_zero_channels[0][1], non_zero_channels[0][2]
                if k == 0:
                    shift_expr = f"(int32_t)input[{idx_expr}]"
                else:
                    shift_expr = f"(int32_t)input[{idx_expr}] << {k}"

                if checks:
                    cond = " && ".join(checks)
                    code.append(f" if ({cond}) acc {sign}= {shift_expr};")
                else:
                    code.append(f" acc {sign}= {shift_expr};")
            else:
                # Different weights per channel: use switch
                cond_prefix = ""
                if checks:
                    cond = " && ".join(checks)
                    cond_prefix = f"if ({cond}) "

                code.append(f" {cond_prefix}switch (c) {{")
                for c, k, sign in non_zero_channels:
                    if k == 0:
                        shift_expr = f"(int32_t)input[{idx_expr}]"
                    else:
                        shift_expr = f"(int32_t)input[{idx_expr}] << {k}"
                    code.append(f" case {c}: acc {sign}= {shift_expr}; break;")
                code.append(f" }}")

            code.append("")

    # Apply the layer's combined scale (the one MUL per output value).
    code.append(f" acc = scale_{name}(acc);")

    # Add scaled bias (per-channel, via a C switch over c).
    if bias_scaled is not None:
        code.append(f" // Add per-channel bias")
        # DEBUG: Print bias values during generation
        if name == 'layer_0':
            print(f" [GEN DEBUG] {name} bias_scaled: {bias_scaled[:5].tolist()}...")
        code.append(f" switch (c) {{")

        for c in range(channels):
            b = int(bias_scaled[c].item())
            if b != 0:  # zero biases need no case at all
                code.append(f" case {c}: acc += {b}; break;")
        code.append(f" }}")

    # ReLU (based on ONNX graph analysis)
    if info.get('has_relu', False):
        code.append(f" if (acc < 0) acc = 0;")

    # Clamp and store
    code.append(f" int out_idx = c * {out_h} * {out_w} + oy * {out_w} + ox;")
    code.append(f" output[out_idx] = (int8_t)(acc > 127 ? 127 : (acc < -128 ? -128 : acc));")

    code.append(f" }}")
    code.append(f" }}")
    code.append(f" }}")
    code.append(f"}}")
    code.append("")

    return '\n'.join(code)
1683
+
1684
+
1685
def generate_conv1x1_layer(info: Dict, is_first_layer: bool = False) -> str:
    """Emit C code for a pointwise (1x1) Conv2d with PoT weights.

    A 1x1 convolution is a per-pixel dot product over the input channels,
    so no padding/kernel loops are generated. Each output channel becomes
    an unrolled sequence of shift-add terms (weight magnitudes are powers
    of two), followed by the per-layer scale MUL, the bias add, an
    optional ReLU and the int8 clamp.
    """
    name = info['name']
    weight = info['weight']
    bias_scaled = info.get('bias_scaled')
    in_ch = info['in_channels']
    out_ch = info['out_channels']
    stride = info['stride']
    in_h, in_w = info['in_h'], info['in_w']
    out_h, out_w = info['out_h'], info['out_w']
    input_type = "uint8_t" if is_first_layer else "int8_t"

    code = [
        f"// {name} - Conv2d 1x1: {in_ch}x{in_h}x{in_w} -> {out_ch}x{out_h}x{out_w}",
        f"// Stride: {stride}, alpha={info['alpha']:.4f}",
    ]
    if is_first_layer:
        code.append(f"// First layer: input is uint8 [0,255], /256 absorbed in shift")
    if info.get('is_last', False):
        code.append(f"// Last layer: no ReLU")
    code.append(f"static void {name}_forward(const {input_type}* input, int8_t* output) {{")
    code.append(f" int32_t acc;")
    code.append("")
    code.append(f" for (int oy = 0; oy < {out_h}; oy++) {{")
    code.append(f" for (int ox = 0; ox < {out_w}; ox++) {{")

    # Input pixel coordinate for this output position (stride applied).
    if stride == 1:
        code.append(f" int iy = oy;")
        code.append(f" int ix = ox;")
    else:
        code.append(f" int iy = oy * {stride};")
        code.append(f" int ix = ox * {stride};")
    code.append("")

    for oc in range(out_ch):
        code.append(f" // Output channel {oc}")
        code.append(f" acc = 0;")

        # Channel dot product as shift-adds (no ky/kx loop for 1x1).
        for ic in range(in_ch):
            w_val = weight[oc, ic, 0, 0].item()
            mag = abs(w_val)
            if mag < 0.5:
                # Zero (or below the PoT range): contributes nothing.
                continue

            # Power-of-two exponent: 2^exp ≈ |w|, exp clamped to {0..4}
            # i.e. magnitudes {1, 2, 4, 8, 16}.
            exp = round(np.log2(mag + 1e-10))
            exp = max(0, min(4, exp))
            op = '+' if w_val > 0 else '-'

            # CHW input index; drop the channel term when there is only one.
            if in_ch > 1:
                idx = f"{ic} * {in_h} * {in_w} + iy * {in_w} + ix"
            else:
                idx = f"iy * {in_w} + ix"

            if exp == 0:
                term = f"(int32_t)input[{idx}]"
            else:
                term = f"(int32_t)input[{idx}] << {exp}"
            code.append(f" acc {op}= {term};")

        # Per-layer combined scale (the single MUL).
        code.append(f" acc = scale_{name}(acc);")

        # Bias (already scaled to the integer domain); skip zero biases.
        if bias_scaled is not None:
            b = int(bias_scaled[oc].item())
            if b != 0:
                code.append(f" acc += {b};")

        # ReLU (based on ONNX graph analysis)
        if info.get('has_relu', False):
            code.append(f" if (acc < 0) acc = 0;")

        # Clamp to int8 and store in CHW layout.
        out_idx = f"{oc} * {out_h} * {out_w} + oy * {out_w} + ox"
        code.append(f" output[{out_idx}] = (int8_t)(acc > 127 ? 127 : (acc < -128 ? -128 : acc));")
        code.append("")

    code.append(f" }}")
    code.append(f" }}")
    code.append(f"}}")
    code.append("")

    return '\n'.join(code)
1788
+
1789
+
1790
def generate_conv_layer(info: Dict, is_first_layer: bool = False) -> str:
    """Generate fully unrolled C source for a Conv2d layer with PoT weights.

    The convolution is unrolled per output channel and per non-zero weight:
    each weight becomes a shift-and-accumulate on the corresponding input
    element, with compile-time-constant index expressions and boundary
    checks for padding. 1x1 kernels are delegated to the dedicated
    ``generate_conv1x1_layer`` generator.

    Args:
        info: Layer description dict. Reads 'name', 'weight' (quantized
            4-D tensor; magnitudes are expected to be (near) powers of
            two times alpha — TODO confirm against the quantizer),
            'bias_scaled' (optional pre-scaled integer bias), 'alpha',
            'is_last', 'in_channels', 'out_channels', 'kernel_size'
            (square kernels only: kw is taken equal to kh), 'stride',
            'padding', 'in_h', 'in_w', 'out_h', 'out_w', and 'has_relu'.
        is_first_layer: If True, the emitted function takes uint8 input
            (raw pixels [0,255]); otherwise int8.

    Returns:
        C source text for a ``static void <name>_forward(input, output)``
        function (multiply-free except the per-layer scale helper).
    """
    name = info['name']
    weight = info['weight']
    bias_scaled = info.get('bias_scaled')
    alpha = info['alpha']
    is_last = info.get('is_last', False)

    in_ch = info['in_channels']
    out_ch = info['out_channels']
    kh = info['kernel_size']
    kw = kh  # square kernels assumed throughout
    stride = info['stride']
    padding = info['padding']
    in_h = info['in_h']
    in_w = info['in_w']
    out_h = info['out_h']
    out_w = info['out_w']

    # 1x1 conv optimization: handled by a dedicated, simpler generator.
    if kh == 1 and kw == 1:
        return generate_conv1x1_layer(info, is_first_layer)

    # Input type: uint8 for first layer (raw input), int8 for others
    input_type = "uint8_t" if is_first_layer else "int8_t"

    code = []
    code.append(f"// {name} - Conv2d: {in_ch}x{in_h}x{in_w} -> {out_ch}x{out_h}x{out_w}")
    code.append(f"// Kernel: {kh}x{kw}, Stride: {stride}, Padding: {padding}, alpha={alpha:.4f}")
    if is_first_layer:
        code.append(f"// First layer: input is uint8 [0,255], /256 absorbed in shift")
    if is_last:
        code.append(f"// Last layer: no ReLU")
    code.append(f"static void {name}_forward(const {input_type}* input, int8_t* output) {{")
    code.append(f" int32_t acc;")
    code.append("")

    # Generate unrolled convolution
    code.append(f" for (int oy = 0; oy < {out_h}; oy++) {{")
    code.append(f" for (int ox = 0; ox < {out_w}; ox++) {{")

    for oc in range(out_ch):
        code.append(f" // Output channel {oc}")
        code.append(f" acc = 0;")

        # Generate one shift-add per non-zero weight.
        for ic in range(in_ch):
            for ky in range(kh):
                for kx in range(kw):
                    w = weight[oc, ic, ky, kx].item()
                    if abs(w) < 1e-9:  # Skip zero weights
                        continue

                    # weight_q = pot_value * alpha
                    # pot_value = weight_q / alpha ∈ {0, 1, 2, 4, 8, 16}
                    w_abs = abs(w)
                    pot_value = w_abs

                    # pot_value should be close to 1, 2, 4, 8, or 16
                    if pot_value < 0.5:  # Too small, skip
                        continue

                    # Find k where 2^k ≈ pot_value
                    k = round(np.log2(pot_value + 1e-10))
                    k = max(0, min(4, k))  # k in [0, 4] for {1, 2, 4, 8, 16}

                    sign = '+' if w > 0 else '-'

                    # Calculate input coordinates with stride (BUG FIX!)
                    # iy = oy * stride + ky - padding
                    # ix = ox * stride + kx - padding
                    iy_base = ky - padding  # offset from oy * stride
                    ix_base = kx - padding  # offset from ox * stride

                    # Build the C index expression; the stride==1 /
                    # offset==0 cases are special-cased so the emitted
                    # arithmetic stays minimal.
                    idx_parts = []
                    if in_ch > 1:
                        idx_parts.append(f"{ic} * {in_h} * {in_w}")

                    # Y coordinate: oy * stride + iy_base
                    if stride == 1:
                        if iy_base == 0:
                            idx_parts.append(f"oy * {in_w}")
                        elif iy_base > 0:
                            idx_parts.append(f"(oy + {iy_base}) * {in_w}")
                        else:
                            idx_parts.append(f"(oy - {-iy_base}) * {in_w}")
                    else:
                        if iy_base == 0:
                            idx_parts.append(f"oy * {stride} * {in_w}")
                        elif iy_base > 0:
                            idx_parts.append(f"(oy * {stride} + {iy_base}) * {in_w}")
                        else:
                            idx_parts.append(f"(oy * {stride} - {-iy_base}) * {in_w}")

                    # X coordinate: ox * stride + ix_base
                    if stride == 1:
                        if ix_base == 0:
                            idx_parts.append("ox")
                        elif ix_base > 0:
                            idx_parts.append(f"(ox + {ix_base})")
                        else:
                            idx_parts.append(f"(ox - {-ix_base})")
                    else:
                        if ix_base == 0:
                            idx_parts.append(f"ox * {stride}")
                        elif ix_base > 0:
                            idx_parts.append(f"(ox * {stride} + {ix_base})")
                        else:
                            idx_parts.append(f"(ox * {stride} - {-ix_base})")

                    idx = " + ".join(idx_parts)

                    # Boundary conditions for padding (with stride):
                    # guard any tap that can fall outside the input.
                    checks = []
                    if stride == 1:
                        if iy_base < 0:
                            checks.append(f"oy >= {-iy_base}")
                        if iy_base > 0:
                            checks.append(f"oy + {iy_base} < {in_h}")
                        if ix_base < 0:
                            checks.append(f"ox >= {-ix_base}")
                        if ix_base > 0:
                            checks.append(f"ox + {ix_base} < {in_w}")
                    else:
                        if iy_base < 0:
                            checks.append(f"oy * {stride} >= {-iy_base}")
                        if iy_base > 0:
                            checks.append(f"oy * {stride} + {iy_base} < {in_h}")
                        if ix_base < 0:
                            checks.append(f"ox * {stride} >= {-ix_base}")
                        if ix_base > 0:
                            checks.append(f"ox * {stride} + {ix_base} < {in_w}")

                    # Generate code
                    if k == 0:
                        shift_expr = f"(int32_t)input[{idx}]"
                    else:
                        shift_expr = f"(int32_t)input[{idx}] << {k}"

                    if checks:
                        cond = " && ".join(checks)
                        code.append(f" if ({cond}) acc {sign}= {shift_expr};")
                    else:
                        code.append(f" acc {sign}= {shift_expr};")

        # Apply the per-layer combined scale (the only MUL in the layer).
        code.append(f" acc = scale_{name}(acc);")

        # Add scaled bias (CRITICAL FIX: use bias_scaled, not raw bias)
        if bias_scaled is not None:
            b = int(bias_scaled[oc].item())
            if b != 0:
                code.append(f" acc += {b};")

        # ReLU (based on ONNX graph analysis)
        if info.get('has_relu', False):
            code.append(f" if (acc < 0) acc = 0;")

        # Clamp to int8 range and store (channel offset folded to a constant).
        out_idx = f"{oc * out_h * out_w} + oy * {out_w} + ox"
        code.append(f" output[{out_idx}] = (int8_t)(acc > 127 ? 127 : (acc < -128 ? -128 : acc));")
        code.append("")

    code.append(f" }}")
    code.append(f" }}")
    code.append(f"}}")
    code.append("")

    return '\n'.join(code)
1960
+
1961
+
1962
def generate_linear_layer(info: Dict, is_first_layer: bool = False) -> str:
    """Emit unrolled C source for a fully-connected (Linear) PoT layer.

    Every non-zero power-of-two weight becomes a shift-and-accumulate
    statement, so the generated function is multiply-free apart from the
    per-layer scale helper.
    """
    layer_name = info['name']
    w_mat = info['weight']
    scaled_bias = info.get('bias_scaled')
    alpha = info['alpha']
    last = info.get('is_last', False)

    n_in = info['in_features']
    n_out = info['out_features']

    # First layer consumes raw uint8 pixels; later layers are int8.
    in_type = "uint8_t" if is_first_layer else "int8_t"

    out = [f"// {layer_name} - Linear: {n_in} -> {n_out}, alpha={alpha:.4f}"]
    if last:
        out.append(f"// Last layer: no ReLU")
    out.append(f"static void {layer_name}_forward(const {in_type}* input, int8_t* output) {{")
    out.append(f" int32_t acc;")
    out.append("")

    apply_relu = info.get('has_relu', False)

    for row in range(n_out):
        out.append(f" // Output {row}")
        out.append(f" acc = 0;")

        for col in range(n_in):
            w = w_mat[row, col].item()
            mag = abs(w)
            # Zero or underflowed weights contribute nothing — drop them.
            if mag < 0.5:
                continue

            # Shift amount: 2^shift ≈ |w|, clamped to the PoT set {1,2,4,8,16}.
            shift = max(0, min(4, round(np.log2(mag + 1e-10))))
            op = '+' if w > 0 else '-'

            term = f"(int32_t)input[{col}]" if shift == 0 else f"(int32_t)input[{col}] << {shift}"
            out.append(f" acc {op}= {term};")

        # Per-layer combined scale (the only MUL in the layer).
        out.append(f" acc = scale_{layer_name}(acc);")

        # Pre-scaled integer bias, skipped when zero.
        if scaled_bias is not None:
            bias_val = int(scaled_bias[row].item())
            if bias_val:
                out.append(f" acc += {bias_val};")

        # ReLU (based on ONNX graph analysis)
        if apply_relu:
            out.append(f" if (acc < 0) acc = 0;")

        # Clamp to int8 range and store.
        out.append(f" output[{row}] = (int8_t)(acc > 127 ? 127 : (acc < -128 ? -128 : acc));")
        out.append("")

    out.append(f"}}")
    out.append("")

    return '\n'.join(out)
2027
+
2028
+
2029
def generate_maxpool_layer(info: Dict) -> str:
    """Emit C source for a MaxPool1d or MaxPool2d layer (int8 in, int8 out)."""
    layer_name = info['name']
    src_h = info['in_h']
    src_w = info['in_w']
    channels = info['in_channels']
    dst_h = info['out_h']
    dst_w = info['out_w']
    ksize = info['kernel_size']
    step = info['stride']

    lines = []

    if info.get('is_1d', False):
        # 1-D pooling: one spatial axis, window of ksize, stride of step.
        lines += [
            f"// {layer_name} - MaxPool1d k={ksize}, stride={step}",
            f"static void {layer_name}_forward(const int8_t* input, int8_t* output) {{",
            f" for (int c = 0; c < {channels}; c++) {{",
            f" for (int o = 0; o < {dst_h}; o++) {{",
            f" int8_t max_val = -128;",
            f" for (int ki = 0; ki < {ksize}; ki++) {{",
            f" int idx = c * {src_h} + (o * {step} + ki);",
            f" if (input[idx] > max_val) max_val = input[idx];",
            f" }}",
            f" output[c * {dst_h} + o] = max_val;",
            f" }}",
            f" }}",
            f"}}",
        ]
    else:
        # 2-D pooling over a ksize x ksize window in CHW layout.
        lines += [
            f"// {layer_name} - MaxPool2d {ksize}x{ksize}, stride {step}",
            f"static void {layer_name}_forward(const int8_t* input, int8_t* output) {{",
            f" for (int c = 0; c < {channels}; c++) {{",
            f" for (int oy = 0; oy < {dst_h}; oy++) {{",
            f" for (int ox = 0; ox < {dst_w}; ox++) {{",
            f" int8_t max_val = -128;",
            f" for (int ky = 0; ky < {ksize}; ky++) {{",
            f" for (int kx = 0; kx < {ksize}; kx++) {{",
            f" int idx = c * {src_h} * {src_w} + (oy * {step} + ky) * {src_w} + (ox * {step} + kx);",
            f" if (input[idx] > max_val) max_val = input[idx];",
            f" }}",
            f" }}",
            f" output[c * {dst_h} * {dst_w} + oy * {dst_w} + ox] = max_val;",
            f" }}",
            f" }}",
            f" }}",
            f"}}",
        ]

    lines.append("")

    return '\n'.join(lines)
2081
+
2082
+
2083
def generate_predict_function(layer_infos: List[Dict], config: "Config") -> str:
    """Generate the top-level C prediction function.

    Wires the per-layer ``<name>_forward`` functions together using two
    ping-pong int8 buffers, plus a third ``skip_buffer`` when residual
    (skip) connections are present.

    Args:
        layer_infos: Ordered layer descriptions. Each dict carries a
            'layer_type' ('pot', 'add', 'flatten', 'global_avg_pool', ...),
            a 'type' label, a 'name', and shape keys ('out_channels',
            'out_h'/'out_w' or 'out_features') used for buffer sizing.
        config: Export config; only ``input_h`` and ``input_w`` are read.

    Returns:
        C source for ``void potnn_predict(const uint8_t* input, int8_t* output)``.
    """
    code = []

    input_size = config.input_h * config.input_w

    # Buffers must hold the largest intermediate activation of any layer.
    max_buffer_size = input_size
    for info in layer_infos:
        if 'out_h' in info and 'out_w' in info:
            if 'out_channels' in info:
                size = info['out_channels'] * info['out_h'] * info['out_w']
            elif 'in_channels' in info:
                # Pooling layers keep their input's channel count.
                size = info['in_channels'] * info['out_h'] * info['out_w']
            else:
                size = info['out_h'] * info['out_w']
            max_buffer_size = max(max_buffer_size, size)
        elif 'out_features' in info:
            max_buffer_size = max(max_buffer_size, info['out_features'])

    # ========================================
    # Analyze skip connections
    # ========================================
    # Collect all skip source layers (where skip branches start)
    skip_sources = set()
    for info in layer_infos:
        if info['layer_type'] == 'add' and 'skip_source_layer' in info:
            skip_sources.add(info['skip_source_layer'])

    has_skip = len(skip_sources) > 0

    # Output width = out_features of the last 'pot' layer (default 10).
    num_classes = 10
    for info in reversed(layer_infos):
        if 'out_features' in info and info['layer_type'] == 'pot':
            num_classes = info['out_features']
            break

    code.append("// Input: uint8 [0,255] - raw pixel values, no normalization needed")
    # Fixed generated comment: int8 cannot hold [0,255]; the output is raw logits.
    code.append("// Output: int8 - raw logits")
    code.append("void potnn_predict(const uint8_t* input, int8_t* output) {")
    code.append(f" // Intermediate buffers (max size: {max_buffer_size})")
    code.append(f" static int8_t buffer1[{max_buffer_size}];")
    code.append(f" static int8_t buffer2[{max_buffer_size}];")

    if has_skip:
        code.append(f" static int8_t skip_buffer[{max_buffer_size}]; // for skip connections")
        code.append(f" // Skip sources: layers {sorted(skip_sources)}")

    code.append(f" int8_t *current = buffer1;")
    code.append(f" int8_t *next = buffer2;")
    code.append("")

    first_pot_done = False

    for i, info in enumerate(layer_infos):
        layer_type = info['layer_type']

        if layer_type == 'flatten':
            # CHW buffers are already flat in memory — nothing to emit.
            code.append(f" // Layer {i}: Flatten (no-op)")
            continue

        code.append(f" // Layer {i}: {info['type']}")

        if layer_type == 'pot' and not first_pot_done:
            # First PoT layer reads the raw uint8 network input directly
            # into `current`; no buffer swap needed.
            code.append(f" {info['name']}_forward(input, current);")
            first_pot_done = True

        elif layer_type == 'add':
            # Residual join: skip_buffer + current -> next, then swap.
            skip_src = info.get('skip_source_layer', -1)
            conv_src = info.get('conv_source_layer', -1)
            code.append(f" // skip from layer_{skip_src}, conv from layer_{conv_src}")
            code.append(f" {info['name']}_forward(skip_buffer, current, next);")
            code.append(f" {{ int8_t *tmp = current; current = next; next = tmp; }}")

        else:
            # Ordinary layer (conv/linear/pool/global_avg_pool):
            # current -> next, then ping-pong swap.
            code.append(f" {info['name']}_forward(current, next);")
            code.append(f" {{ int8_t *tmp = current; current = next; next = tmp; }}")

        # ========================================
        # Check if this layer is a skip source → save to skip_buffer
        # ========================================
        if i in skip_sources:
            # Snapshot this layer's activation for the later 'add' node.
            if 'out_channels' in info and 'out_h' in info and 'out_w' in info:
                skip_size = info['out_channels'] * info['out_h'] * info['out_w']
            elif 'out_features' in info:
                skip_size = info['out_features']
            elif 'in_channels' in info and 'out_h' in info:
                skip_size = info['in_channels'] * info['out_h'] * info['out_w']
            else:
                skip_size = max_buffer_size

            code.append(f" // Save skip (layer {i} is skip source)")
            code.append(f" for (int _i = 0; _i < {skip_size}; _i++) skip_buffer[_i] = current[_i];")

        code.append("")

    code.append(f" // Copy result to output buffer (num_classes = {num_classes})")
    code.append(f" for (int i = 0; i < {num_classes}; i++) {{")
    code.append(f" output[i] = current[i];")
    code.append(f" }}")
    code.append(f"}}")

    return '\n'.join(code)