potnn 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- potnn/__init__.py +86 -0
- potnn/codegen/__init__.py +20 -0
- potnn/codegen/bit2.py +263 -0
- potnn/codegen/fp130.py +269 -0
- potnn/codegen/header.py +460 -0
- potnn/codegen/level5.py +393 -0
- potnn/codegen/scale.py +184 -0
- potnn/codegen/ternary.py +354 -0
- potnn/codegen/unroll.py +616 -0
- potnn/config.py +112 -0
- potnn/export.py +2196 -0
- potnn/fuse.py +167 -0
- potnn/modules/__init__.py +11 -0
- potnn/modules/add.py +114 -0
- potnn/modules/avgpool.py +173 -0
- potnn/modules/base.py +225 -0
- potnn/modules/conv.py +203 -0
- potnn/modules/conv1d.py +317 -0
- potnn/modules/depthwise.py +216 -0
- potnn/modules/linear.py +199 -0
- potnn/quantize/__init__.py +35 -0
- potnn/quantize/calibration.py +233 -0
- potnn/quantize/integer_ops.py +207 -0
- potnn/quantize/integer_sim.py +225 -0
- potnn/quantize/pot.py +455 -0
- potnn/quantize/qat.py +356 -0
- potnn/utils/__init__.py +13 -0
- potnn/utils/allocation.py +240 -0
- potnn/utils/memory.py +158 -0
- potnn/wrapper.py +304 -0
- potnn-1.0.0.dist-info/METADATA +260 -0
- potnn-1.0.0.dist-info/RECORD +35 -0
- potnn-1.0.0.dist-info/WHEEL +5 -0
- potnn-1.0.0.dist-info/licenses/LICENSE +72 -0
- potnn-1.0.0.dist-info/top_level.txt +1 -0
potnn/export.py
ADDED
@@ -0,0 +1,2196 @@
"""Export functionality for PoT quantized models via ONNX.

Follows the technical specification exactly:
- Input: uint8 [0,255], no runtime normalization
- /256: absorbed into combined_shift (+8)
- mean: absorbed into bias (b' = b - mean × ΣW)
- /std: absorbed into combined_scale (scale × 1/std)
- bias: scaled by act_scale (bias_int = round(bias × act_scale))
- Output: MUL-free C code (shift+add only, except combined_scale MUL per layer)

v8: Fixed all bugs from systematic verification
- Fixed: bias scaling with act_scale
- Fixed: standardization absorption with act_scale
- Fixed: last layer act_scale=None handling
- Fixed: weight to shift conversion (w is already PoT*alpha)
"""

import os
import tempfile
import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List, Optional, Any, Tuple

try:
    import onnx
    from onnx import numpy_helper
    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False

from .modules.base import PoTLayerBase
from .modules.conv import PoTConv2d
from .modules.conv1d import PoTConv1d
from .modules.depthwise import PoTDepthwiseConv2d
from .modules.linear import PoTLinear
from .modules.add import PoTAdd
from .config import Config
from .quantize.pot import quantize_to_pot


def export(model: nn.Module, output_path: str, config: Config, dummy_input: torch.Tensor = None, optimized: bool = True):
    """
    Export PyTorch model to C code.

    Args:
        model: PyTorch model (PoT layers)
        output_path: Path to save .c file
        config: Hardware configuration (Config class)
        dummy_input: Dummy input for ONNX export (optional, auto-generated if None)
        optimized: Enable optimized C kernels (default: True)
            - Loop layers: Full Pipeline (Zero-Padding + im2col + Column-wise +
              Position Blocking + Shift Grouping)
            - Unroll layers: Zero-Padding (eliminates boundary checks)
    """
    if not ONNX_AVAILABLE:
        raise RuntimeError("onnx package required. Install with: pip install onnx")

    print(f"\nStarting export to {output_path}...")
    print(f"Optimized mode: {'ENABLED' if optimized else 'DISABLED'} (Full Pipeline for loop, Zero-Padding for unroll)")
    print("Using ONNX for graph extraction (v10 - hybrid unroll/loop)...")

    model.eval()

    # Step 0a: Disable Integer Simulation for clean ONNX export
    # Integer Sim adds round_ste/clamp_ste which become spurious Add nodes
    from .quantize.qat import disable_integer_sim
    disable_integer_sim(model)

    # Step 0a-2: Disable 5level constraint for torch.export compatibility
    # The constraint uses Python for-loops which torch.export doesn't support
    from .modules.base import PoTLayerBase
    for module in model.modules():
        if isinstance(module, PoTLayerBase):
            module.enforce_5level_constraint = False

    # Step 0b: Fuse BatchNorm layers into Conv/Linear (CRITICAL!)
    from .fuse import fuse_batchnorm, check_bn_fused
    model = fuse_batchnorm(model)
    if not check_bn_fused(model):
        print("Warning: Some BatchNorm layers could not be fused!")

    if dummy_input is None:
        if config.input_w == 1:
            # Conv1d: (B, C, L) - input_h is sequence length
            dummy_input = torch.randn(1, config.input_channels, config.input_h)
        else:
            # Conv2d: (B, C, H, W)
            dummy_input = torch.randn(1, config.input_channels, config.input_h, config.input_w)

        device = next(model.parameters()).device
        dummy_input = dummy_input.to(device)

    # Auto-detect input shape from dummy_input
    if dummy_input is not None:
        if dummy_input.dim() == 4:
            # (B, C, H, W)
            print(f"[Export] Auto-detected input shape: {dummy_input.shape}")
            config.input_channels = dummy_input.shape[1]
            config.input_h = dummy_input.shape[2]
            config.input_w = dummy_input.shape[3]
        elif dummy_input.dim() == 3:
            # (B, C, L)
            print(f"[Export] Auto-detected input shape (1D): {dummy_input.shape}")
            config.input_channels = dummy_input.shape[1]
            config.input_h = dummy_input.shape[2]  # Length
            config.input_w = 1

    # Step 1: Collect PoT layer info from PyTorch model
    pot_layer_infos = collect_pot_layer_info(model)
    print(f"Found {len(pot_layer_infos)} PoT layers in PyTorch model")

    # Collect PoTAdd layer info (for rescale calculation)
    add_layer_infos = collect_add_layer_info(model)

    # Debug: print layer info
    for info in pot_layer_infos:
        print(f" {info['name']}: alpha={info['alpha']:.4f}, act_scale={info['act_scale']}")

    # Step 2: Export to ONNX
    with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:
        onnx_path = f.name

    try:
        # Use legacy ONNX exporter (dynamo=False) to avoid torch.export compatibility issues
        torch.onnx.export(
            model, dummy_input, onnx_path,
            input_names=['input'], output_names=['output'],
            opset_version=11, do_constant_folding=True,
            dynamo=False,  # Legacy exporter
        )
        print(f"ONNX export successful")

        # Step 3: Load ONNX graph
        onnx_model = onnx.load(onnx_path)
        onnx.checker.check_model(onnx_model)

        layer_infos = parse_onnx_graph(onnx_model, pot_layer_infos, config, add_layer_infos)
        print(f"Parsed {len(layer_infos)} layers from ONNX graph")

    finally:
        if os.path.exists(onnx_path):
            os.remove(onnx_path)

    # Step 4: Absorb standardization into first layer (per specification 6.4-6.6)
    if config.mean is not None and config.std is not None:
        absorb_standardization(layer_infos, config.mean, config.std)

    # Step 5: Calculate combined scales (per specification 9.x)
    calculate_combined_scales(layer_infos, config)

    # Step 6: Scale biases by act_scale (CRITICAL FIX)
    scale_biases(layer_infos)

    # Step 6.5: Decide loop/unroll mode (all unroll now)
    allocations = {}

    # Add allocation info to layer_infos
    for info in layer_infos:
        if info['layer_type'] == 'pot':
            name = info['name']
            if name in allocations:
                alloc = allocations[name]
                info['code_mode'] = alloc.mode  # 'unroll' or 'loop'
                info['levels'] = alloc.levels  # 11 or 4
            else:
                # Default to unroll
                info['code_mode'] = 'unroll'
                info['levels'] = 11

    # Step 8: Generate C header
    code = generate_header(layer_infos, config, optimized=optimized)

    # Step 9: Write to file
    with open(output_path, 'w') as f:
        f.write(code)

    # Print summary
    pot_count = sum(1 for info in layer_infos if info['layer_type'] == 'pot')
    maxpool_count = sum(1 for info in layer_infos if info['layer_type'] == 'maxpool')
    unroll_count = sum(1 for info in layer_infos
                       if info['layer_type'] == 'pot')

    print(f"\nExport complete!")
    print(f" Output file: {output_path}")
    print(f" Target MCU: {config.ram}B RAM, {config.flash}B Flash")
    print(f" PoT layers: {pot_count}")
    print(f" MaxPool layers: {maxpool_count}")
    print(f" Total layers: {len(layer_infos)}")


def collect_pot_layer_info(model: nn.Module) -> List[Dict]:
    """Collect alpha, act_scale, weights from PoT layers in order."""
    infos = []

    for name, module in model.named_modules():
        if isinstance(module, PoTLayerBase):
            with torch.no_grad():
                alpha = module.alpha.item()
                # act_scale can be None for last layer
                act_scale = module.act_scale.item() if module.act_scale is not None else None
                weight = module.weight.cpu().clone()
                bias = module.bias.cpu().clone() if module.bias is not None else None

            # Get encoding from module (default: 'unroll')
            encoding = getattr(module, 'encoding', 'unroll')

            # Get PoT quantized weight with encoding-specific levels
            weight_q = quantize_to_pot(weight, alpha, encoding=encoding)

            # Apply 5level constraint (max 3 consecutive zeros)
            if encoding == '5level':
                from .quantize.pot import apply_5level_zero_constraint
                weight_q = apply_5level_zero_constraint(weight_q)

            layer_info = {
                'name': name,
                'pot_index': len(infos),
                'alpha': alpha,
                'act_scale': act_scale,  # Keep None as None!
                'weight': weight_q,
                'bias': bias,  # Original float bias
                'encoding': encoding,
            }

            if isinstance(module, PoTConv2d):
                layer_info['type'] = 'conv'
                layer_info['in_channels'] = module.in_channels
                layer_info['out_channels'] = module.out_channels
                layer_info['kernel_size'] = module.kernel_size
                layer_info['stride'] = module.stride
                layer_info['padding'] = module.padding
                layer_info['groups'] = module.groups
            elif isinstance(module, PoTConv1d):
                layer_info['type'] = 'conv1d'
                layer_info['in_channels'] = module.in_channels
                layer_info['out_channels'] = module.out_channels
                layer_info['kernel_size'] = module.kernel_size
                layer_info['stride'] = module.stride
                layer_info['padding'] = module.padding
                layer_info['groups'] = getattr(module, 'groups', 1)
            elif isinstance(module, PoTDepthwiseConv2d):
                layer_info['type'] = 'depthwise'
                layer_info['channels'] = module.channels
                layer_info['in_channels'] = module.channels
                layer_info['out_channels'] = module.channels
                layer_info['kernel_size'] = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else module.kernel_size
                layer_info['stride'] = module.stride[0] if isinstance(module.stride, tuple) else module.stride
                layer_info['padding'] = module.padding[0] if isinstance(module.padding, tuple) else module.padding
                layer_info['groups'] = module.channels
            elif isinstance(module, PoTLinear):
                layer_info['type'] = 'linear'
                layer_info['in_features'] = module.in_features
                layer_info['out_features'] = module.out_features

            infos.append(layer_info)

    return infos


def collect_add_layer_info(model: nn.Module) -> List[Dict]:
    """Collect PoTAdd layer info in order."""
    infos = []

    for name, module in model.named_modules():
        if isinstance(module, PoTAdd):
            layer_info = {
                'name': name,
                'add_index': len(infos),
                'type': 'add',
                'rescale_mult': module.rescale_mult.item() if module.rescale_mult is not None else 128,
                'rescale_shift': module.rescale_shift.item() if module.rescale_shift is not None else 7,
                'scale_x': module.scale_x.item() if module.scale_x is not None else None,
                'scale_y': module.scale_y.item() if module.scale_y is not None else None,
                'act_scale': module.act_scale.item() if module.act_scale is not None else None,
            }
            infos.append(layer_info)
            print(f" Found PoTAdd: {name}, mult={layer_info['rescale_mult']}, shift={layer_info['rescale_shift']}")

    return infos


def parse_onnx_graph(onnx_model, pot_layer_infos: List[Dict], config: Config, add_layer_infos: List[Dict] = None) -> List[Dict]:
    """Parse ONNX graph and build layer_infos list with graph analysis for skip connections."""
    if add_layer_infos is None:
        add_layer_infos = []

    graph = onnx_model.graph

    initializers = {}
    for init in graph.initializer:
        initializers[init.name] = numpy_helper.to_array(init)

    # ========================================
    # Step 1: Build tensor→node mapping for graph analysis
    # ========================================
    tensor_to_producer = {}  # tensor_name → (node_idx, node)
    tensor_to_shape = {}  # tensor_name → (channels, h, w) or (channels, L) or (features,)
    tensor_to_scale = {}  # tensor_name → act_scale (CRITICAL for Add rescale!)

    # Detect Conv1d mode
    is_conv1d_model = (config.input_w == 1)

    # Input tensor
    input_name = graph.input[0].name
    if is_conv1d_model:
        # Conv1d: (C, L)
        tensor_to_shape[input_name] = (config.input_channels, config.input_h)
    else:
        # Conv2d: (C, H, W)
        tensor_to_shape[input_name] = (config.input_channels, config.input_h, config.input_w)
    tensor_to_producer[input_name] = (-1, None)  # -1 = input
    tensor_to_scale[input_name] = 1.0  # Input scale

    # First pass: collect all node outputs and their producers
    for node_idx, node in enumerate(graph.node):
        for out_name in node.output:
            tensor_to_producer[out_name] = (node_idx, node)

    # Build tensor→consumers mapping for forward tracing
    tensor_to_consumers = {}  # tensor_name → [(node_idx, node), ...]
    for node_idx, node in enumerate(graph.node):
        for in_name in node.input:
            if in_name not in tensor_to_consumers:
                tensor_to_consumers[in_name] = []
            tensor_to_consumers[in_name].append((node_idx, node))

    def trace_forward_to_relu(tensor_name: str) -> bool:
        """Forward trace: starting from a tensor, check whether a Relu follows.

        In a Conv → Mul → Round → Clip → Div → Relu chain, start from the
        Conv output tensor and follow the chain until a Relu is reached.

        Intermediate nodes (Mul, Round, Clip, Div) are treated as pass-through.
        """
        visited = set()
        current = tensor_name

        while current and current not in visited:
            visited.add(current)

            if current not in tensor_to_consumers:
                return False

            consumers = tensor_to_consumers[current]
            if not consumers:
                return False

            # Follow only the first consumer (assumes a linear chain)
            _, node = consumers[0]
            op_type = node.op_type

            if op_type == 'Relu':
                return True

            # Pass-through node: follow its output
            if op_type in ('Mul', 'Round', 'Clip', 'Div', 'Cast', 'Add', 'Floor'):
                if node.output:
                    current = node.output[0]
                else:
                    return False
            else:
                # Any other node type: stop
                return False

        return False

    def trace_to_layer(tensor_name: str, tensor_to_layer_idx: dict) -> int:
        """Backward trace: from a tensor back to the original PoT layer.

        Given a chain like Conv → Mul → Round → Clip → Div → Relu in the
        ONNX graph, find the original Conv layer starting from any tensor
        in the chain.

        Generic by design: whatever the intermediate nodes are, follow the
        first input.
        """
        visited = set()
        current = tensor_name

        while current and current not in visited:
            # Already mapped: return immediately
            if current in tensor_to_layer_idx:
                return tensor_to_layer_idx[current]

            visited.add(current)

            # Find the node that produced this tensor
            if current in tensor_to_producer:
                node_idx, node = tensor_to_producer[current]
                if node is None:
                    # Graph input tensor
                    return -1
                if len(node.input) > 0:
                    # Trace back through the first input
                    current = node.input[0]
                else:
                    break
            else:
                break

        return -1  # Not found

    def trace_to_scale(tensor_name: str) -> float:
        """Trace the activation scale of a tensor.

        Walk back along the tensor chain and return the most recent scale.
        MaxPool, Flatten, etc. propagate their input scale unchanged.
        """
        visited = set()
        current = tensor_name

        while current and current not in visited:
            if current in tensor_to_scale:
                return tensor_to_scale[current]

            visited.add(current)

            if current in tensor_to_producer:
                node_idx, node = tensor_to_producer[current]
                if node is None:
                    return 1.0  # Input
                if len(node.input) > 0:
                    current = node.input[0]
                else:
                    break
            else:
                break

        return 1.0  # Default

    # ========================================
    # Step 2: Process nodes in order
    # ========================================
    current_h = config.input_h
    current_w = config.input_w
    current_ch = config.input_channels

    layer_infos = []
    pot_index = 0
    num_pot_layers = len(pot_layer_infos)

    # Track which layer produces which tensor (for skip connections)
    tensor_to_layer_idx = {input_name: -1}  # -1 = input

    for node_idx, node in enumerate(graph.node):
        op_type = node.op_type
        output_name = node.output[0] if node.output else None

        if op_type == 'Conv':
            pot_info = pot_layer_infos[pot_index] if pot_index < num_pot_layers else None

            if pot_info and pot_info['type'] == 'conv1d':
                # ===== Conv1D Processing =====
                weight = pot_info['weight']
                bias = pot_info['bias']
                out_channels = pot_info['out_channels']
                in_channels = pot_info['in_channels']
                alpha = pot_info['alpha']
                act_scale = pot_info['act_scale']

                kL = pot_info['kernel_size']
                stride = pot_info['stride']
                padding = pot_info['padding']

                # For 1D: current_h represents length, current_w is 1
                in_L = current_h  # Use current_h as length
                out_L = (in_L + 2 * padding - kL) // stride + 1

                layer_idx = len(layer_infos)

                has_relu = trace_forward_to_relu(output_name) if output_name else False

                info = {
                    'name': f'layer_{layer_idx}',
                    'type': 'PoTConv1d',
                    'layer_type': 'pot',
                    'weight': weight,
                    'bias': bias,
                    'alpha': alpha,
                    'act_scale': act_scale,
                    'in_channels': in_channels,
                    'out_channels': out_channels,
                    'kernel_size': kL,
                    'stride': stride,
                    'padding': padding,
                    'in_L': in_L,
                    'out_L': out_L,
                    # Compatibility with 2D: for buffer size calculation
                    'in_h': in_L,
                    'in_w': 1,
                    'out_h': out_L,
                    'out_w': 1,
                    'is_first': (pot_index == 0),
                    'is_last': (pot_index == num_pot_layers - 1),
                    'has_relu': has_relu,
                    'encoding': pot_info.get('encoding', 'unroll'),
                }
                layer_infos.append(info)

                if output_name:
                    tensor_to_layer_idx[output_name] = layer_idx
                    tensor_to_shape[output_name] = (out_channels, out_L)
                    if act_scale is not None:
                        tensor_to_scale[output_name] = act_scale

                current_h, current_w, current_ch = out_L, 1, out_channels
                pot_index += 1

                print(f" Conv1d: {in_channels}x{in_L} -> {out_channels}x{out_L}")

            elif pot_info and pot_info['type'] in ('conv', 'depthwise'):
                # ===== Conv2D Processing =====
                weight = pot_info['weight']
                bias = pot_info['bias']
                out_channels = pot_info['out_channels']
                in_channels = pot_info['in_channels']
                alpha = pot_info['alpha']
                act_scale = pot_info['act_scale']  # Can be None
                is_depthwise = (pot_info['type'] == 'depthwise')

                # Use kernel_size, stride, padding from PyTorch model
                kh_tuple = pot_info['kernel_size']
                stride_tuple = pot_info['stride']
                padding_tuple = pot_info['padding']

                # Normalize to (h, w) tuples
                if isinstance(kh_tuple, int): kh_tuple = (kh_tuple, kh_tuple)
                if isinstance(stride_tuple, int): stride_tuple = (stride_tuple, stride_tuple)
                if isinstance(padding_tuple, int): padding_tuple = (padding_tuple, padding_tuple)

                kh, kw = kh_tuple
                sh, sw = stride_tuple
                ph, pw = padding_tuple

                # [Support Conv1d]
                # If weight is 3D (Out, In, Length), reshape to 4D (Out, In, 1, Length)
                # to satisfy generic Conv2d generators.
                if weight.ndim == 3:
                    weight = weight.unsqueeze(2)  # (N, C, L) -> (N, C, 1, L)
                    # For Conv1d, kh is usually a tuple (k,) or an int k.
                    # If it was a 1D kernel, it is now effectively 1xW.
                    # We might need to ensure kh/kw are correct?
                    # The generator usually reads the shape from w_q.shape.
                    pass
            else:
                raise RuntimeError(f"PoT layer mismatch at index {pot_index}")
            input_tensor_name = node.input[0]
            if input_tensor_name in tensor_to_shape:
                input_shape = tensor_to_shape[input_tensor_name]
                in_h = input_shape[1]
                in_w = input_shape[2] if len(input_shape) > 2 else 1
            else:
                in_h, in_w = current_h, current_w  # fallback
            # Continue only for Conv2D (Conv1D handled above and continues)
            if pot_info and pot_info['type'] in ('conv', 'depthwise'):
                out_h = (in_h + 2 * ph - kh) // sh + 1
                out_w = (in_w + 2 * pw - kw) // sw + 1

                # Determine layer type
                if is_depthwise:
                    layer_type_name = 'PoTDepthwiseConv2d'
                else:
                    layer_type_name = 'PoTConv2d'

                layer_idx = len(layer_infos)

                # Check if ReLU follows this Conv (via ONNX chain)
                has_relu = trace_forward_to_relu(output_name) if output_name else False

                info = {
                    'name': f'layer_{layer_idx}',
                    'type': layer_type_name,
                    'layer_type': 'pot',
                    'weight': weight,
                    'bias': bias,
                    'alpha': alpha,
                    'act_scale': act_scale,  # Keep None as None
                    'in_channels': in_channels,
                    'out_channels': out_channels,
                    'kernel_size': kh,
                    'stride': sh if sh == sw else (sh, sw),
                    'padding': ph if ph == pw else (ph, pw),
                    'in_h': in_h,
                    'in_w': in_w,
                    'out_h': out_h,
                    'out_w': out_w,
                    'is_first': (pot_index == 0),
                    'is_last': (pot_index == num_pot_layers - 1),
                    'has_relu': has_relu,
                    'encoding': pot_info.get('encoding', 'unroll'),
                    'groups': out_channels if is_depthwise else pot_info.get('groups', 1),
                }
                layer_infos.append(info)

                # Track output tensor
                if output_name:
                    tensor_to_layer_idx[output_name] = layer_idx
                    tensor_to_shape[output_name] = (out_channels, out_h, out_w)
                    # Track scale for Add rescale calculation
                    if act_scale is not None:
                        tensor_to_scale[output_name] = act_scale

                current_h, current_w, current_ch = out_h, out_w, out_channels
                pot_index += 1

                layer_desc = 'DepthwiseConv' if is_depthwise else 'Conv'
                print(f" {layer_desc}: {in_channels}x{info['in_h']}x{info['in_w']} -> {out_channels}x{out_h}x{out_w}")

        elif op_type == 'MaxPool':
            kernel_shape = get_attribute(node, 'kernel_shape', [2, 2])
            strides = get_attribute(node, 'strides', [2, 2])

            # Detect 1D vs 2D based on kernel_shape length or is_conv1d_model
            is_maxpool_1d = (len(kernel_shape) == 1) or is_conv1d_model

            k = kernel_shape[0]
            s = strides[0]

            if is_maxpool_1d:
                # MaxPool1d: only update h (length)
                out_h = current_h // s
                out_w = 1  # Keep as 1 for 1D
                pool_type = 'MaxPool1d'
            else:
                # MaxPool2d
                out_h = current_h // s
                out_w = current_w // s
                pool_type = 'MaxPool2d'

            layer_idx = len(layer_infos)
            info = {
                'name': f'layer_{layer_idx}',
                'type': pool_type,
                'layer_type': 'maxpool',
                'kernel_size': k,
                'stride': s,
                'in_h': current_h,
                'in_w': current_w,
                'in_channels': current_ch,
                'out_h': out_h,
                'out_w': out_w,
                'is_1d': is_maxpool_1d,
            }
            layer_infos.append(info)

            # Track output tensor
            if output_name:
                tensor_to_layer_idx[output_name] = layer_idx
                if is_maxpool_1d:
                    tensor_to_shape[output_name] = (current_ch, out_h)
                else:
                    tensor_to_shape[output_name] = (current_ch, out_h, out_w)
                # Propagate scale from input (MaxPool doesn't change scale)
                input_name = node.input[0] if len(node.input) > 0 else None
                if input_name:
                    input_scale = trace_to_scale(input_name)
                    tensor_to_scale[output_name] = input_scale

            current_h, current_w = out_h, out_w

            if is_maxpool_1d:
                print(f" MaxPool1d: {current_ch}x{info['in_h']} -> {current_ch}x{out_h}")
            else:
                print(f" MaxPool2d: {current_ch}x{info['in_h']}x{info['in_w']} -> {current_ch}x{out_h}x{out_w}")

        elif op_type in ('GlobalAveragePool', 'ReduceMean'):
            # GlobalAveragePool: C×H×W → C×1×1
            # ReduceMean with axes=[2,3]: same effect

            # Check if it's ReduceMean over spatial dims
            if op_type == 'ReduceMean':
                axes = get_attribute(node, 'axes', None)

                # axes might be in input[1] instead of attribute (newer ONNX)
                if axes is None and len(node.input) > 1:
                    axes_name = node.input[1]
                    if axes_name in initializers:
                        axes = initializers[axes_name].tolist()

                # axes=[2,3] or axes=[-2,-1] means spatial reduction
                if axes is None:
                    continue
                if set(axes) not in ({2, 3}, {-2, -1}):
                    continue

            layer_idx = len(layer_infos)
            pool_size = current_h * current_w

            # Compute div_mult and div_shift for integer division
            if pool_size > 0 and (pool_size & (pool_size - 1)) == 0:
                # Power of 2
                import math
                div_mult = 1
                div_shift = int(math.log2(pool_size))
            else:
                # Not power of 2: avg ≈ (sum * div_mult) >> div_shift
                import math
                base_shift = 15
                div_mult = round((1 << base_shift) / pool_size)
                while div_mult > 255 and base_shift > 8:
                    base_shift -= 1
                    div_mult = round((1 << base_shift) / pool_size)
                div_shift = base_shift
                div_mult = max(1, min(65535, div_mult))
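                # Illustrative example (editor's note, values not from the source):
                # pool_size = 49 (a 7x7 feature map) starts at base_shift = 15:
                #   round(32768 / 49) = 669 > 255 -> shift 14 -> 334 > 255
                #   -> shift 13 -> round(8192 / 49) = 167 <= 255
                # giving div_mult = 167, div_shift = 13, so
                #   avg ≈ (sum * 167) >> 13 = sum * 0.02039  vs exact 1/49 = 0.02041.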

            info = {
                'name': f'layer_{layer_idx}',
                'type': 'GlobalAvgPool',
                'layer_type': 'global_avg_pool',
                'in_h': current_h,
                'in_w': current_w,
                'in_channels': current_ch,
                'out_channels': current_ch,
                'pool_size': pool_size,
                'div_mult': div_mult,
                'div_shift': div_shift,
            }
            layer_infos.append(info)

            # Track output tensor
            if output_name:
                tensor_to_layer_idx[output_name] = layer_idx
                tensor_to_shape[output_name] = (current_ch, 1, 1)

            # After global avg pool: H=1, W=1
            print(f" GlobalAvgPool: {current_ch}x{current_h}x{current_w} -> {current_ch} (pool_size={pool_size}, mult={div_mult}, shift={div_shift})")
            current_h, current_w = 1, 1

        elif op_type in ('Flatten', 'Reshape'):
            shape = None
            if op_type == 'Reshape' and len(node.input) > 1:
                shape_name = node.input[1]
                if shape_name in initializers:
                    shape = initializers[shape_name].tolist()

            if op_type == 'Flatten' or (shape and len(shape) == 2):
                out_features = current_h * current_w * current_ch

                layer_idx = len(layer_infos)
                info = {
                    'name': f'layer_{layer_idx}',
                    'type': 'Flatten',
                    'layer_type': 'flatten',
                    'in_h': current_h,
                    'in_w': current_w,
                    'in_channels': current_ch,
                    'out_features': out_features,
                }
                layer_infos.append(info)

                # Track output tensor
                if output_name:
                    tensor_to_layer_idx[output_name] = layer_idx

                print(f" Flatten: {current_ch}x{current_h}x{current_w} -> {out_features}")
                current_h, current_w = 1, 1

        elif op_type == 'Gemm':
            pot_info = pot_layer_infos[pot_index] if pot_index < num_pot_layers else None

            if pot_info and pot_info['type'] == 'linear':
                weight = pot_info['weight']
                bias = pot_info['bias']
                in_features = pot_info['in_features']
                out_features = pot_info['out_features']
                alpha = pot_info['alpha']
                act_scale = pot_info['act_scale']  # Can be None
            else:
                raise RuntimeError(f"PoT layer mismatch at index {pot_index}")

            layer_idx = len(layer_infos)

            # Check if ReLU follows this Linear (via ONNX chain)
            has_relu = trace_forward_to_relu(output_name) if output_name else False

            info = {
                'name': f'layer_{layer_idx}',
                'type': 'PoTLinear',
                'layer_type': 'pot',
                'weight': weight,
                'bias': bias,
                'alpha': alpha,
                'act_scale': act_scale,  # Keep None as None
                'in_features': in_features,
                'out_features': out_features,
                'is_first': (pot_index == 0),
                'is_last': (pot_index == num_pot_layers - 1),
                'has_relu': has_relu,
                'encoding': pot_info.get('encoding', 'unroll'),
            }
            layer_infos.append(info)

            # Track output tensor
            if output_name:
                tensor_to_layer_idx[output_name] = layer_idx
                # Track scale (Linear is usually last, but just in case)
                if act_scale is not None:
                    tensor_to_scale[output_name] = act_scale

            current_ch = out_features
            pot_index += 1

            print(f" Linear: {in_features} -> {out_features}")

        elif op_type == 'Relu':
            # ReLU: pass through tensor mapping and scale
            if len(node.input) > 0 and len(node.output) > 0:
                relu_input = node.input[0]
                relu_output = node.output[0]
                if relu_input in tensor_to_layer_idx:
                    tensor_to_layer_idx[relu_output] = tensor_to_layer_idx[relu_input]
                if relu_input in tensor_to_shape:
                    tensor_to_shape[relu_output] = tensor_to_shape[relu_input]
                # Propagate scale
                input_scale = trace_to_scale(relu_input)
                tensor_to_scale[relu_output] = input_scale

        elif op_type in ('Mul', 'Round', 'Clip', 'Div'):
            # Quantization-related intermediate nodes: propagate tensor mapping and scale
            if len(node.input) > 0 and len(node.output) > 0:
                node_input = node.input[0]
                node_output = node.output[0]
                if node_input in tensor_to_layer_idx:
                    tensor_to_layer_idx[node_output] = tensor_to_layer_idx[node_input]
                if node_input in tensor_to_shape:
                    tensor_to_shape[node_output] = tensor_to_shape[node_input]
                # Propagate scale
                input_scale = trace_to_scale(node_input)
                tensor_to_scale[node_output] = input_scale

        elif op_type == 'Add':
            # ========================================
            # ONNX Add node classification (robust version)
            #
            # An Add node appears in an ONNX graph when:
            # 1. Bias addition: Conv/Linear bias (one input is an initializer)
            # 2. Skip connection: x + conv(x) (both inputs are layer outputs)
            # 3. Other arithmetic
            #
            # Classification rules:
            # - One input is an initializer → bias add → ignore
            # - Both inputs trace back to the same source layer → ignore
            # - Inputs trace back to different source layers → skip connection
            # ========================================

            input_a = node.input[0] if len(node.input) > 0 else None
            input_b = node.input[1] if len(node.input) > 1 else None

            # ------------------------------------------
            # Case 1: one input is an initializer (constant/bias)
            # ------------------------------------------
            a_is_initializer = input_a in initializers
            b_is_initializer = input_b in initializers

            if a_is_initializer or b_is_initializer:
                # Bias addition - NOT a skip connection
                # Just propagate scale, shape, and layer index
                print(f" [DEBUG] Bias Add detected (input is initializer), skipping")
                if output_name:
                    non_const_input = input_b if a_is_initializer else input_a

                    # Propagate scale
                    input_scale = trace_to_scale(non_const_input)
                    tensor_to_scale[output_name] = input_scale

                    # Propagate shape by tracing back
                    traced_shape = None
                    trace_name = non_const_input
                    trace_visited = set()
                    while trace_name and trace_name not in trace_visited:
                        trace_visited.add(trace_name)
                        if trace_name in tensor_to_shape:
                            traced_shape = tensor_to_shape[trace_name]
                            break
                        if trace_name in tensor_to_producer:
                            _, prod_node = tensor_to_producer[trace_name]
                            if prod_node and prod_node.input:
                                trace_name = prod_node.input[0]
                            else:
                                break
                        else:
                            break
                    if traced_shape:
                        tensor_to_shape[output_name] = traced_shape

                    # Propagate layer index (this Add is pass-through)
                    traced_layer = trace_to_layer(non_const_input, tensor_to_layer_idx)
                    if traced_layer >= 0:
                        tensor_to_layer_idx[output_name] = traced_layer
                continue

            # ------------------------------------------
            # Case 2: both inputs are tensors - analyze their source layers
            # ------------------------------------------
            source_a = trace_to_layer(input_a, tensor_to_layer_idx)
            source_b = trace_to_layer(input_b, tensor_to_layer_idx)

            print(f" [DEBUG] Add node: source_a={source_a}, source_b={source_b}")

            # Same source → not a skip connection
            if source_a == source_b:
                print(f" [DEBUG] Same source ({source_a}), not a skip connection - skipping")
                if output_name:
                    input_scale = trace_to_scale(input_a)
                    tensor_to_scale[output_name] = input_scale
                    # Propagate shape
                    if input_a in tensor_to_shape:
                        tensor_to_shape[output_name] = tensor_to_shape[input_a]
                    # Propagate layer index
                    if source_a >= 0:
                        tensor_to_layer_idx[output_name] = source_a
                continue

            # Case 2.5: if either source is -1 (comes straight from the input tensor), it is not a skip connection
            if source_a < 0 or source_b < 0:
                print(f" [DEBUG] One source is input tensor ({source_a}, {source_b}), not a skip connection - skipping")
                if output_name:
                    # Propagate info from the valid source
                    valid_source = max(source_a, source_b)
                    valid_tensor = input_a if source_a >= 0 else input_b
                    input_scale = trace_to_scale(valid_tensor)
                    tensor_to_scale[output_name] = input_scale
                    if valid_tensor in tensor_to_shape:
                        tensor_to_shape[output_name] = tensor_to_shape[valid_tensor]
                    if valid_source >= 0:
                        tensor_to_layer_idx[output_name] = valid_source
                continue

            # ------------------------------------------
            # Case 3: different sources = a real skip connection!
            # ------------------------------------------
            print(f" [DEBUG] Skip connection detected: {source_a} vs {source_b}")

            # The skip branch comes from the older (smaller-index) layer
            if source_a < source_b:
                skip_source = source_a
                conv_source = source_b
                skip_tensor = input_a
                conv_tensor = input_b
            else:
                skip_source = source_b
                conv_source = source_a
                skip_tensor = input_b
                conv_tensor = input_a

            # Count how many PoTAdd layers we've already processed
            add_count = sum(1 for l in layer_infos if l.get('type') == 'PoTAdd')

            # ========================================
            # Determine rescale parameters
            # ========================================

            # Option 1: use calibration values from the PyTorch PoTAdd module
            if add_count < len(add_layer_infos) and add_layer_infos[add_count]['scale_x'] is not None:
                skip_act_scale = add_layer_infos[add_count]['scale_x']
                conv_act_scale = add_layer_infos[add_count]['scale_y']
                rescale_mult = add_layer_infos[add_count]['rescale_mult']
                rescale_shift = add_layer_infos[add_count]['rescale_shift']
                print(f" Add rescale (from PoTAdd calibration): skip_scale={skip_act_scale:.4f}, "
                      f"conv_scale={conv_act_scale:.4f}, mult={rescale_mult}, shift={rescale_shift}")

            # Option 2: skip connection built without a PoTAdd module - trace scales from the ONNX graph
            else:
                skip_act_scale = trace_to_scale(skip_tensor)
                conv_act_scale = trace_to_scale(conv_tensor)

                # Calculate rescale ratio: convert skip scale to conv scale
                # skip_aligned = skip * ratio brings the skip branch onto the conv branch's scale
                if skip_act_scale != 0 and skip_act_scale != conv_act_scale:
                    ratio = conv_act_scale / skip_act_scale
                else:
                    ratio = 1.0

                # Convert to integer arithmetic: mult * x >> shift ≈ ratio * x
                rescale_shift = 7
                rescale_mult = round(ratio * (1 << rescale_shift))
                rescale_mult = max(1, min(rescale_mult, 512))  # Clamp to valid range
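                # Illustrative example (editor's note, values not from the source):
                # conv_act_scale = 64.0 and skip_act_scale = 96.0 give ratio = 0.6667,
                # rescale_mult = round(0.6667 * 128) = 85, so the C kernel computes
                # (skip * 85) >> 7 ≈ skip * 0.664 before adding the two branches.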

                if add_count >= len(add_layer_infos):
                    print(f" Add rescale (no PoTAdd module - direct skip): "
                          f"skip_scale={skip_act_scale:.4f}, conv_scale={conv_act_scale:.4f}, "
                          f"ratio={ratio:.4f}, mult={rescale_mult}, shift={rescale_shift}")
                else:
                    print(f" Add rescale (fallback - PoTAdd uncalibrated): "
                          f"skip_scale={skip_act_scale:.4f}, conv_scale={conv_act_scale:.4f}, "
                          f"ratio={ratio:.4f}, mult={rescale_mult}, shift={rescale_shift}")

            layer_idx = len(layer_infos)

            # Check if ReLU follows this Add (via ONNX chain)
            has_relu = trace_forward_to_relu(output_name) if output_name else False

            info = {
                'name': f'layer_{layer_idx}',
                'type': 'PoTAdd',
                'layer_type': 'add',
                'in_h': current_h,
                'in_w': current_w,
                'in_channels': current_ch,
                'out_h': current_h,
                'out_w': current_w,
                'out_channels': current_ch,
                'rescale_mult': rescale_mult,
                'rescale_shift': rescale_shift,
                'skip_source_layer': skip_source,  # layer where the skip branch starts
                'conv_source_layer': conv_source,  # last layer on the conv path
                'act_scale': conv_act_scale,  # Add output uses conv's scale
                'has_relu': has_relu,
            }
            layer_infos.append(info)

            # Track output tensor and its scale
            if output_name:
                tensor_to_layer_idx[output_name] = layer_idx
                tensor_to_shape[output_name] = (current_ch, current_h, current_w)
                tensor_to_scale[output_name] = conv_act_scale

            print(f" Add: {current_ch}x{current_h}x{current_w} (skip from layer_{skip_source}, conv from layer_{conv_source})")

    return layer_infos


def get_attribute(node, name: str, default=None):
    """Get attribute value from ONNX node."""
    for attr in node.attribute:
        if attr.name == name:
            if attr.type == onnx.AttributeProto.INTS:
                return list(attr.ints)
            elif attr.type == onnx.AttributeProto.INT:
                return attr.i
            elif attr.type == onnx.AttributeProto.FLOATS:
                return list(attr.floats)
            elif attr.type == onnx.AttributeProto.FLOAT:
                return attr.f
    return default


def absorb_standardization(layer_infos: List[Dict], mean: List[float], std: List[float]):
    """Absorb input standardization into the first PoT layer.

    CRITICAL: Uses avg_std everywhere to match QAT exactly.
    - mean → bias: b' = b - Σ_c (mean[c]/avg_std) × ΣW[:,c,:,:] × α
    - /std → combined_scale: uses average std

    Args:
        layer_infos: List of layer info dicts
        mean: Per-channel mean values as list (e.g., [0.4914, 0.4822, 0.4465] for CIFAR-10)
        std: Per-channel std values as list (e.g., [0.2470, 0.2435, 0.2616] for CIFAR-10)

    Note: bias scaling by act_scale is done separately in scale_biases()
    """
    # Calculate avg_std upfront
    avg_std = sum(std) / len(std) if std else 1.0

    for info in layer_infos:
        if info['layer_type'] == 'pot' and info.get('is_first', False):
            weight = info['weight']
            bias = info['bias']
            alpha = info['alpha']
            layer_type = info['type']

            if weight is None:
                return

            in_channels = weight.shape[1]

            # Validate channel count
            if len(mean) != in_channels or len(std) != in_channels:
                raise ValueError(
                    f"mean/std length ({len(mean)}/{len(std)}) must match "
                    f"first layer's in_channels ({in_channels})"
                )

            if bias is None:
                bias = torch.zeros(weight.shape[0])

            # Use quantized weights for bias absorption (info['weight'] is already quantized)
            w_q = weight

            # Channel-wise bias adjustment using avg_std
            # b' = b - Σ_c (mean[c]/avg_std) × ΣW_q[:,c,...] × α
            for c in range(in_channels):
                if layer_type == 'PoTConv1d':
                    # Conv1D: weight shape is [out_ch, in_ch, kernel_size]
                    weight_sum_c = w_q[:, c, :].sum(dim=1)  # [out_ch]
                else:
                    # Conv2D: weight shape is [out_ch, in_ch, kh, kw]
                    weight_sum_c = w_q[:, c, :, :].sum(dim=(1, 2))  # [out_ch]

                bias = bias - (mean[c] / avg_std) * weight_sum_c * alpha

            info['bias'] = bias
            info['input_std'] = avg_std

            print(f" Standardization absorbed into {info['name']}: mean={mean}, avg_std={avg_std:.4f}")
            break


def calculate_combined_scales(layer_infos: List[Dict], config: Config):
    """Calculate combined scale for each PoT layer.

    Per specification section 6.3 and 9.x:
    - First layer:  combined_scale = α × act_scale / std
                    combined_shift = base_shift + 8 (/256 absorbed)
    - Other layers: combined_scale = α × act_scale / prev_act_scale
                    combined_shift = base_shift
    - Last layer (act_scale=None): only α / prev_act_scale (no output quantization)
    """
    prev_act_scale = 1.0

    for info in layer_infos:
        # CRITICAL: Update prev_act_scale for Add layers!
        # Add layer changes the scale, so subsequent layers need to use Add's output scale
        if info['layer_type'] == 'add':
            add_act_scale = info.get('act_scale')
            if add_act_scale is not None:
                print(f" {info['name']}: Add layer, updating prev_act_scale to {add_act_scale:.4f}")
                prev_act_scale = add_act_scale
            continue

        if info['layer_type'] != 'pot':
            continue

        is_first = info.get('is_first', False)
        is_last = info.get('is_last', False)
        alpha = info['alpha']
        act_scale = info.get('act_scale')  # Can be None

        # Calculate combined scale
        if is_first:
            input_std = info.get('input_std', 1.0)
            if act_scale is not None:
                # combined_scale = α × act_scale / std
                combined_scale = alpha * act_scale / input_std
            else:
                # Last layer is also first layer (single layer model)
                combined_scale = alpha / input_std
        else:
            if act_scale is not None:
                # combined_scale = α × act_scale / prev_act_scale
                combined_scale = alpha * act_scale / prev_act_scale
            else:
                # Last layer: no act_scale, output is raw logits
                combined_scale = alpha / prev_act_scale

        # Determine shift amount for integer approximation
        base_shift = 0
        scale_magnitude = abs(combined_scale)

        # Target: scale_int around 64-256 for precision
        while scale_magnitude < 64 and base_shift < 24:
            scale_magnitude *= 2
            base_shift += 1
        while scale_magnitude > 512 and base_shift > 0:
            scale_magnitude /= 2
            base_shift -= 1

        # For first layer, add +8 for /256 absorption
        if is_first:
            combined_shift = base_shift + 8
        else:
            combined_shift = base_shift

        # Calculate integer scale
        scale_int = round(combined_scale * (1 << base_shift))

        info['combined_scale'] = combined_scale
        info['scale_int'] = scale_int
        info['combined_shift'] = combined_shift
        info['base_shift'] = base_shift

        print(f" {info['name']}: combined_scale={combined_scale:.6f}, scale_int={scale_int}, shift={combined_shift}, act_scale={act_scale}")

        # Update prev_act_scale for next layer
        if act_scale is not None:
            prev_act_scale = act_scale


def scale_biases(layer_infos: List[Dict]):
    """Scale all biases by act_scale.

    CRITICAL: In PyTorch training:
        output = conv(x) + bias
        output_quantized = output * act_scale

    So bias is also multiplied by act_scale!

    For last layer (act_scale=None), bias is used as-is (no quantization).
    """
    for info in layer_infos:
        if info['layer_type'] != 'pot':
            continue

        bias = info.get('bias')
        if bias is None:
            info['bias_scaled'] = None
            continue

        act_scale = info.get('act_scale')

        if act_scale is not None:
            # Scale bias by act_scale
            # Use floor(x + 0.5) to match round_half_up_ste in PyTorch modules
            # (torch.round uses bankers rounding which causes ±1 mismatch)
            bias_scaled = torch.floor(bias * act_scale + 0.5).to(torch.int32)
            info['bias_scaled'] = bias_scaled
            print(f" {info['name']}: bias scaled by act_scale={act_scale:.2f}, range=[{bias_scaled.min()}, {bias_scaled.max()}]")
        else:
            # Last layer: no act_scale
            # Use floor(x + 0.5) to match round_half_up_ste
            bias_scaled = torch.floor(bias + 0.5).to(torch.int32)
            info['bias_scaled'] = bias_scaled
            print(f" {info['name']}: last layer, bias not scaled, range=[{bias_scaled.min()}, {bias_scaled.max()}]")


def generate_header(layer_infos: List[Dict[str, Any]], config: Config, optimized: bool = True) -> str:
    """Generate complete C header file.

    Args:
        layer_infos: List of layer information dictionaries
        config: Configuration with MCU specs
        optimized: If True, use optimized code generation
    """
    code = []

    code.append("#ifndef POTNN_MODEL_H")
    code.append("#define POTNN_MODEL_H")
    code.append("")
    code.append("#include <stdint.h>")
    code.append("#include <string.h>")  # for memset
    code.append("")

    code.append("/*")
    code.append(f" * PoT Quantized Neural Network")
    code.append(f" * Target: {config.ram}B RAM, {config.flash}B Flash")
    if config.mean is not None:
        # Format list values nicely
        mean_str = ', '.join(f'{m:.4f}' for m in config.mean)
        std_str = ', '.join(f'{s:.4f}' for s in config.std)
        code.append(f" * Input normalization: mean=[{mean_str}], std=[{std_str}]")
        code.append(f" * (absorbed into first layer: mean→bias, /std→scale, /256→shift)")
        code.append(f" * Input: uint8 [0,255] - no runtime normalization needed")
    if config.input_w == 1:
        # Conv1d: CxL format
        code.append(f" * Input size: {config.input_channels}x{config.input_h}")
    else:
        # Conv2d: CxHxW format
        code.append(f" * Input size: {config.input_channels}x{config.input_h}x{config.input_w}")
    if optimized:
        code.append(f" * Optimized: Full Pipeline (loop), Zero-Padding (unroll)")
    code.append(f" * Generated by potnn v8 (all bugs fixed)")
    code.append(" */")
    code.append("")

    # Generate scale functions
    code.append("// Scale functions (combined_scale MUL, once per layer)")
    for info in layer_infos:
        if info['layer_type'] == 'pot':
            scale_func = generate_scale_function(info)
            code.append(scale_func)
    code.append("")

    # Generate layer functions
    code.append("// Layer forward functions")
    for i, info in enumerate(layer_infos):
        if info['layer_type'] == 'pot':
            is_first = info.get('is_first', False)
            # CRITICAL FIX: Disable optimization for first layer due to bias generation bug (missing bias in Ch0)
            use_opt = optimized and not is_first
            layer_code = generate_pot_layer(info, is_first_layer=is_first, optimized=use_opt)
            code.append(layer_code)
        elif info['layer_type'] == 'maxpool':
            layer_code = generate_maxpool_layer(info)
            code.append(layer_code)
        elif info['layer_type'] == 'add':
            layer_code = generate_add_layer(info)
            code.append(layer_code)
        elif info['layer_type'] == 'global_avg_pool':
            layer_code = generate_global_avg_pool_layer(info)
            code.append(layer_code)

    # Generate main predict function
    code.append("// Main prediction function")
    code.append(generate_predict_function(layer_infos, config))

    code.append("#endif // POTNN_MODEL_H")

    return '\n'.join(code)


def generate_scale_function(info: Dict) -> str:
    """Generate a scale function using MUL + shift."""
    name = info['name']
    scale_int = info['scale_int']
    shift = info['combined_shift']

    code = []
    code.append(f"// {name}: combined_scale={info['combined_scale']:.6f}, scale_int={scale_int}, shift={shift}")
    code.append(f"static inline int32_t scale_{name}(int32_t x) {{")

    # Handle edge cases to avoid undefined behavior (1 << -1)
    if shift == 0:
        code.append(f" return (int32_t)((int64_t)x * {scale_int});")
    elif shift == 1:
        code.append(f" return (int32_t)(((int64_t)x * {scale_int} + 1) >> 1);")
    else:
        code.append(f" return (int32_t)(((int64_t)x * {scale_int} + (1 << {shift-1})) >> {shift});")

    code.append("}")
    code.append("")

    return '\n'.join(code)
|
|
1314
|
+
|
|
1315
|
+
|
|
1316
|
+
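# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: one plausible way to
# derive the (scale_int, combined_shift) pair consumed by
# generate_scale_function() from a float combined_scale. The helper name and
# the default 15-bit shift are assumptions for illustration only; it relies on
# the typing imports already present at the top of this file.
def _example_scale_to_fixed_point(combined_scale: float, shift: int = 15) -> Tuple[int, int]:
    """Return (scale_int, shift) so that x * combined_scale ≈ (x * scale_int) >> shift."""
    scale_int = int(round(combined_scale * (1 << shift)))
    return scale_int, shift
# e.g. _example_scale_to_fixed_point(0.0372) -> (1219, 15), and
# (x * 1219 + (1 << 14)) >> 15 approximates x * 0.0372 with rounding.
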
def generate_pot_layer(info: Dict, is_first_layer: bool = False, optimized: bool = True) -> str:
    """Generate PoT Conv2d, DepthwiseConv2d, Linear, Add, or GlobalAvgPool layer code.

    Supports multiple encoding modes:
    - 'unroll': weights embedded as shift-add operations (11 levels, fastest)
    - 'fp130': FP1.3.0 packed weights (16 levels, no zero)
    - '5level': skip-encoded weights (5 levels: -8,-1,0,+1,+8)
    - '2bit': minimal 2-bit encoding (4 levels: ±1,±2)
    - 'ternary': ternary encoding (3 levels: -1,0,+1)
    - 'loop': weights in packed table with loop (slower, less Flash) [legacy]

    When optimized=True:
    - Loop mode: Full Pipeline (Zero-Padding + im2col + Column-wise + Position Blocking + Shift Grouping)
    - Unroll mode: Zero-Padding (eliminates boundary checks)
    """
    layer_type = info['type']
    code_mode = info.get('code_mode', 'unroll')
    levels = info.get('levels', 11)
    encoding = info.get('encoding', 'unroll')

    # Debug: print encoding for each layer
    print(f" [CODEGEN] {info['name']}: encoding={encoding}, code_mode={code_mode}")

    # Non-PoT layers always use fixed implementations
    if layer_type == 'PoTAdd':
        return generate_add_layer(info)
    elif layer_type == 'GlobalAvgPool':
        return generate_global_avg_pool_layer(info)

    # Encoding-specific code generation (overrides code_mode)
    if encoding == 'fp130':
        from .codegen.fp130 import generate_fp130_layer
        print(f" → Using FP1.3.0 encoder")
        return generate_fp130_layer(info)
    elif encoding == '2bit':
        from .codegen.bit2 import generate_2bit_layer
        print(f" → Using 2-bit encoder")
        return generate_2bit_layer(info)
    elif encoding == '5level':
        from .codegen.level5 import generate_5level_layer
        print(f" → Using 5-level encoder")
        return generate_5level_layer(info)
    elif encoding == 'ternary':
        from .codegen.ternary import generate_ternary_layer
        print(f" → Using Ternary encoder")
        return generate_ternary_layer(info)

    # Default: unroll or loop mode (legacy)

    # Default: unroll mode
    if code_mode == 'loop':
        print(f" [WARNING] 'loop' code_mode is deprecated and loop.py is removed. Falling back to 'unroll' mode.")
        if optimized:
            from .codegen.unroll import generate_unrolled_layer_optimized
            return generate_unrolled_layer_optimized(info, is_first_layer=is_first_layer)
        else:
            # Legacy unroll logic dispatch
            pass  # Fallthrough to below

    # Standard unroll dispatch logic continues below...
    if True:  # Indent anchor
        # Unroll mode
        if optimized:
            from .codegen.unroll import generate_unrolled_layer_optimized
            return generate_unrolled_layer_optimized(info, is_first_layer=is_first_layer)
        else:
            if layer_type == 'PoTConv2d':
                return generate_conv_layer(info, is_first_layer)
            elif layer_type == 'PoTConv1d':
                # Conv1d: use optimized version even in non-optimized mode
                from .codegen.unroll import generate_unrolled_layer_optimized
                return generate_unrolled_layer_optimized(info, is_first_layer=is_first_layer)
            elif layer_type == 'PoTDepthwiseConv2d':
                return generate_depthwise_conv_layer(info, is_first_layer)
            elif layer_type == 'PoTLinear':
                return generate_linear_layer(info, is_first_layer)

    return ""

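# ---------------------------------------------------------------------------
# Illustrative example, not part of the original module: the kind of `info`
# dict generate_pot_layer() dispatches on. Field names are taken from the
# accesses in this file; the concrete values are hypothetical, and the real
# dict also carries the 'weight' / 'bias_scaled' tensors omitted here.
_EXAMPLE_LAYER_INFO = {
    'name': 'layer_1', 'type': 'PoTConv2d', 'layer_type': 'pot',
    'encoding': 'unroll', 'code_mode': 'unroll', 'levels': 11,
    'in_channels': 8, 'out_channels': 16, 'kernel_size': 3,
    'stride': 1, 'padding': 1, 'in_h': 16, 'in_w': 16, 'out_h': 16, 'out_w': 16,
    'alpha': 0.1234, 'combined_scale': 0.0372, 'scale_int': 1219, 'combined_shift': 15,
    'has_relu': True, 'is_first': False, 'is_last': False,
}
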
def generate_global_avg_pool_layer(info: Dict) -> str:
    """Generate Global Average Pooling layer.

    C×H×W → C (per-channel average)
    Division done with integer ops: avg = (sum * div_mult) >> div_shift
    """
    name = info['name']
    channels = info['in_channels']
    h = info['in_h']
    w = info['in_w']
    pool_size = info.get('pool_size', h * w)
    div_mult = info.get('div_mult', 1)
    div_shift = info.get('div_shift', 0)

    code = []
    code.append(f"// {name} - Global Average Pooling: {channels}x{h}x{w} -> {channels}")

    if div_mult == 1:
        # Power of 2: shift only
        code.append(f"// pool_size={pool_size} (2^{div_shift}), using shift only")
        code.append(f"static void {name}_forward(const int8_t* input, int8_t* output) {{")
        code.append(f" for (int c = 0; c < {channels}; c++) {{")
        code.append(f" int32_t sum = 0;")
        code.append(f" for (int i = 0; i < {pool_size}; i++) {{")
        code.append(f" sum += input[c * {pool_size} + i];")
        code.append(f" }}")
        code.append(f" output[c] = (int8_t)((sum + {1 << (div_shift - 1)}) >> {div_shift});")
        code.append(f" }}")
        code.append(f"}}")
    else:
        # Not power of 2: mult + shift
        code.append(f"// pool_size={pool_size}, div = (sum * {div_mult}) >> {div_shift}")
        code.append(f"static void {name}_forward(const int8_t* input, int8_t* output) {{")
        code.append(f" for (int c = 0; c < {channels}; c++) {{")
        code.append(f" int32_t sum = 0;")
        code.append(f" for (int i = 0; i < {pool_size}; i++) {{")
        code.append(f" sum += input[c * {pool_size} + i];")
        code.append(f" }}")
        code.append(f" int32_t avg = (sum * {div_mult}) >> {div_shift};")
        code.append(f" output[c] = (int8_t)(avg > 127 ? 127 : (avg < -128 ? -128 : avg));")
        code.append(f" }}")
        code.append(f"}}")

    code.append("")
    return '\n'.join(code)

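# ---------------------------------------------------------------------------
# Illustrative sketch, not the original helper: computing the
# (div_mult, div_shift) pair used above so that sum / pool_size becomes
# (sum * div_mult) >> div_shift in integer arithmetic. The 15-bit default
# shift is an assumption for illustration.
def _example_avg_divider(pool_size: int, div_shift: int = 15) -> Tuple[int, int]:
    if pool_size & (pool_size - 1) == 0:
        # Power of two: a plain shift is enough (the div_mult == 1 path above)
        return 1, pool_size.bit_length() - 1
    return int(round((1 << div_shift) / pool_size)), div_shift
# e.g. _example_avg_divider(49) -> (669, 15): (sum * 669) >> 15 ≈ sum / 49.
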
def generate_add_layer(info: Dict) -> str:
    """Generate Add layer for skip/residual connections.

    Scale alignment of the two inputs is handled with an integer MUL + shift:
    x_aligned = (x * rescale_mult) >> rescale_shift
    output = x_aligned + y

    mult/shift are computed at compile time; no floats at runtime.
    """
    name = info['name']
    channels = info['in_channels']
    h = info['in_h']
    w = info['in_w']
    rescale_mult = info.get('rescale_mult', 128)
    rescale_shift = info.get('rescale_shift', 7)
    has_relu = info.get('has_relu', False)

    size = channels * h * w

    code = []
    code.append(f"// {name} - Add (skip connection): {channels}x{h}x{w}")
    code.append(f"// rescale: x_aligned = (x * {rescale_mult}) >> {rescale_shift}")
    if has_relu:
        code.append(f"// ReLU applied after add")
    code.append(f"static void {name}_forward(const int8_t* input_skip, const int8_t* input_conv, int8_t* output) {{")
    code.append(f" for (int i = 0; i < {size}; i++) {{")
    code.append(f" int32_t x = (int32_t)input_skip[i];")
    code.append(f" int32_t y = (int32_t)input_conv[i];")

    # Rescale x to match y's scale: x * mult >> shift
    if rescale_mult == 128 and rescale_shift == 7:
        # Default (ratio ≈ 1.0): rescale can be skipped
        code.append(f" // ratio ≈ 1.0, skip rescale")
    else:
        code.append(f" x = (x * {rescale_mult}) >> {rescale_shift};")

    code.append(f" int32_t sum = x + y;")

    # ReLU if needed
    if has_relu:
        code.append(f" if (sum < 0) sum = 0;")

    code.append(f" // Clamp to int8 range")
    code.append(f" output[i] = (int8_t)(sum > 127 ? 127 : (sum < -128 ? -128 : sum));")
    code.append(f" }}")
    code.append(f"}}")
    code.append("")

    return '\n'.join(code)

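# ---------------------------------------------------------------------------
# Illustrative sketch, an assumption rather than the original derivation: how
# rescale_mult/rescale_shift could be chosen so the skip input is brought onto
# the conv branch's scale with a single integer MUL + shift. The direction of
# the ratio depends on the activation-scale convention; equal scales give the
# default (128, 7), where (x * 128) >> 7 == x and the rescale is skipped.
def _example_add_rescale(skip_scale: float, conv_scale: float, shift: int = 7) -> Tuple[int, int]:
    ratio = conv_scale / skip_scale  # factor mapping skip values onto the conv grid
    return int(round(ratio * (1 << shift))), shift
# e.g. _example_add_rescale(1.0, 1.0) -> (128, 7).
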
def generate_depthwise_conv_layer(info: Dict, is_first_layer: bool = False) -> str:
    """Generate Depthwise Conv2d layer with PoT weights.

    Depthwise Conv: each channel is processed independently.
    weight shape: [channels, 1, kH, kW]
    Input and output channel counts are identical (channels = in_channels = out_channels).
    """
    name = info['name']
    weight = info['weight']
    bias_scaled = info.get('bias_scaled')
    alpha = info['alpha']
    is_last = info.get('is_last', False)

    channels = info['in_channels']  # in_channels == out_channels for depthwise
    kh = info['kernel_size']
    kw = kh
    stride = info['stride']
    padding = info['padding']
    in_h = info['in_h']
    in_w = info['in_w']
    out_h = info['out_h']
    out_w = info['out_w']

    input_type = "uint8_t" if is_first_layer else "int8_t"

    code = []
    code.append(f"// {name} - DepthwiseConv2d: {channels}x{in_h}x{in_w} -> {channels}x{out_h}x{out_w}")
    code.append(f"// Kernel: {kh}x{kw}, Stride: {stride}, Padding: {padding}, alpha={alpha:.4f}")
    if is_first_layer:
        code.append(f"// First layer: input is uint8 [0,255], /256 absorbed in shift")
    if is_last:
        code.append(f"// Last layer: no ReLU")
    code.append(f"static void {name}_forward(const {input_type}* input, int8_t* output) {{")
    code.append(f" int32_t acc;")
    code.append("")

    # Depthwise: process each channel independently
    code.append(f" for (int c = 0; c < {channels}; c++) {{")
    code.append(f" for (int oy = 0; oy < {out_h}; oy++) {{")
    code.append(f" for (int ox = 0; ox < {out_w}; ox++) {{")
    code.append(f" acc = 0;")
    code.append("")

    # Kernel loop (unrolled)
    for ky in range(kh):
        for kx in range(kw):
            # weight index: [c, 0, ky, kx] → c dimension handled by switch
            iy_base = ky - padding
            ix_base = kx - padding

            # Input coordinate: iy = oy * stride + ky - padding
            if stride == 1:
                if iy_base == 0:
                    iy_expr = "oy"
                elif iy_base > 0:
                    iy_expr = f"oy + {iy_base}"
                else:
                    iy_expr = f"oy - {-iy_base}"
            else:
                if iy_base == 0:
                    iy_expr = f"oy * {stride}"
                elif iy_base > 0:
                    iy_expr = f"oy * {stride} + {iy_base}"
                else:
                    iy_expr = f"oy * {stride} - {-iy_base}"

            if stride == 1:
                if ix_base == 0:
                    ix_expr = "ox"
                elif ix_base > 0:
                    ix_expr = f"ox + {ix_base}"
                else:
                    ix_expr = f"ox - {-ix_base}"
            else:
                if ix_base == 0:
                    ix_expr = f"ox * {stride}"
                elif ix_base > 0:
                    ix_expr = f"ox * {stride} + {ix_base}"
                else:
                    ix_expr = f"ox * {stride} - {-ix_base}"

            # Boundary check
            checks = []
            if stride == 1:
                if iy_base < 0:
                    checks.append(f"oy >= {-iy_base}")
                if iy_base > 0:
                    checks.append(f"oy + {iy_base} < {in_h}")
                if ix_base < 0:
                    checks.append(f"ox >= {-ix_base}")
                if ix_base > 0:
                    checks.append(f"ox + {ix_base} < {in_w}")
            else:
                if iy_base < 0:
                    checks.append(f"oy * {stride} >= {-iy_base}")
                if iy_base > 0:
                    checks.append(f"oy * {stride} + {iy_base} < {in_h}")
                if ix_base < 0:
                    checks.append(f"ox * {stride} >= {-ix_base}")
                if ix_base > 0:
                    checks.append(f"ox * {stride} + {ix_base} < {in_w}")

            # Input index: c * in_h * in_w + iy * in_w + ix
            idx_expr = f"c * {in_h} * {in_w} + ({iy_expr}) * {in_w} + ({ix_expr})"

            # Generate per-channel weight switch
            code.append(f" // Kernel position ({ky}, {kx})")

            # Check which channels have non-zero weights at this kernel position
            non_zero_channels = []
            for c in range(channels):
                w = weight[c, 0, ky, kx].item()
                if abs(w) >= 0.5:  # Not zero
                    w_abs = abs(w)
                    k = round(np.log2(w_abs + 1e-10))
                    k = max(0, min(4, k))
                    sign = '+' if w > 0 else '-'
                    non_zero_channels.append((c, k, sign))

            if len(non_zero_channels) == 0:
                code.append(f" // All weights zero at this position")
                continue

            # Check if all non-zero channels have same k and sign
            all_same = len(set((k, sign) for _, k, sign in non_zero_channels)) == 1

            if all_same and len(non_zero_channels) == channels:
                # All channels same: no switch needed
                k, sign = non_zero_channels[0][1], non_zero_channels[0][2]
                if k == 0:
                    shift_expr = f"(int32_t)input[{idx_expr}]"
                else:
                    shift_expr = f"(int32_t)input[{idx_expr}] << {k}"

                if checks:
                    cond = " && ".join(checks)
                    code.append(f" if ({cond}) acc {sign}= {shift_expr};")
                else:
                    code.append(f" acc {sign}= {shift_expr};")
            else:
                # Different weights per channel: use switch
                cond_prefix = ""
                if checks:
                    cond = " && ".join(checks)
                    cond_prefix = f"if ({cond}) "

                code.append(f" {cond_prefix}switch (c) {{")
                for c, k, sign in non_zero_channels:
                    if k == 0:
                        shift_expr = f"(int32_t)input[{idx_expr}]"
                    else:
                        shift_expr = f"(int32_t)input[{idx_expr}] << {k}"
                    code.append(f" case {c}: acc {sign}= {shift_expr}; break;")
                code.append(f" }}")

            code.append("")

    # Apply scale
    code.append(f" acc = scale_{name}(acc);")

    # Add scaled bias (per-channel)
    if bias_scaled is not None:
        code.append(f" // Add per-channel bias")
        # DEBUG: Print bias values during generation
        if name == 'layer_0':
            print(f" [GEN DEBUG] {name} bias_scaled: {bias_scaled[:5].tolist()}...")
        code.append(f" switch (c) {{")

        for c in range(channels):
            b = int(bias_scaled[c].item())
            if b != 0:
                code.append(f" case {c}: acc += {b}; break;")
        code.append(f" }}")

    # ReLU (based on ONNX graph analysis)
    if info.get('has_relu', False):
        code.append(f" if (acc < 0) acc = 0;")

    # Clamp and store
    code.append(f" int out_idx = c * {out_h} * {out_w} + oy * {out_w} + ox;")
    code.append(f" output[out_idx] = (int8_t)(acc > 127 ? 127 : (acc < -128 ? -128 : acc));")

    code.append(f" }}")
    code.append(f" }}")
    code.append(f" }}")
    code.append(f"}}")
    code.append("")

    return '\n'.join(code)

def generate_conv1x1_layer(info: Dict, is_first_layer: bool = False) -> str:
    """Generate optimized 1x1 Conv2d layer with PoT weights.

    A 1x1 conv performs only a channel-wise dot product, with no spatial work.
    No padding/kernel loops are needed → simplified generated code.
    """
    name = info['name']
    weight = info['weight']
    bias_scaled = info.get('bias_scaled')
    alpha = info['alpha']
    is_last = info.get('is_last', False)

    in_ch = info['in_channels']
    out_ch = info['out_channels']
    stride = info['stride']
    in_h = info['in_h']
    in_w = info['in_w']
    out_h = info['out_h']
    out_w = info['out_w']

    input_type = "uint8_t" if is_first_layer else "int8_t"

    code = []
    code.append(f"// {name} - Conv2d 1x1: {in_ch}x{in_h}x{in_w} -> {out_ch}x{out_h}x{out_w}")
    code.append(f"// Stride: {stride}, alpha={alpha:.4f}")
    if is_first_layer:
        code.append(f"// First layer: input is uint8 [0,255], /256 absorbed in shift")
    if is_last:
        code.append(f"// Last layer: no ReLU")
    code.append(f"static void {name}_forward(const {input_type}* input, int8_t* output) {{")
    code.append(f" int32_t acc;")
    code.append("")

    # 1x1 Conv: only a channel dot product at each output position
    code.append(f" for (int oy = 0; oy < {out_h}; oy++) {{")
    code.append(f" for (int ox = 0; ox < {out_w}; ox++) {{")

    # Apply stride: compute input coordinates
    if stride == 1:
        code.append(f" int iy = oy;")
        code.append(f" int ix = ox;")
    else:
        code.append(f" int iy = oy * {stride};")
        code.append(f" int ix = ox * {stride};")
    code.append("")

    for oc in range(out_ch):
        code.append(f" // Output channel {oc}")
        code.append(f" acc = 0;")

        # Channel dot product (no ky, kx loops since the kernel is 1x1)
        for ic in range(in_ch):
            w = weight[oc, ic, 0, 0].item()
            if abs(w) < 1e-9:  # Skip zero weights
                continue

            w_abs = abs(w)
            if w_abs < 0.5:  # Too small, skip
                continue

            # Find k where 2^k ≈ pot_value
            k = round(np.log2(w_abs + 1e-10))
            k = max(0, min(4, k))  # k in [0, 4] for {1, 2, 4, 8, 16}

            sign = '+' if w > 0 else '-'

            # Input index: [ic, iy, ix] in CHW layout
            if in_ch > 1:
                idx = f"{ic} * {in_h} * {in_w} + iy * {in_w} + ix"
            else:
                idx = f"iy * {in_w} + ix"

            if k == 0:
                shift_expr = f"(int32_t)input[{idx}]"
            else:
                shift_expr = f"(int32_t)input[{idx}] << {k}"

            code.append(f" acc {sign}= {shift_expr};")

        # Apply scale
        code.append(f" acc = scale_{name}(acc);")

        # Add scaled bias
        if bias_scaled is not None:
            b = int(bias_scaled[oc].item())
            if b != 0:
                code.append(f" acc += {b};")

        # ReLU (based on ONNX graph analysis)
        if info.get('has_relu', False):
            code.append(f" if (acc < 0) acc = 0;")

        # Clamp and store
        out_idx = f"{oc} * {out_h} * {out_w} + oy * {out_w} + ox"
        code.append(f" output[{out_idx}] = (int8_t)(acc > 127 ? 127 : (acc < -128 ? -128 : acc));")
        code.append("")

    code.append(f" }}")
    code.append(f" }}")
    code.append(f"}}")
    code.append("")

    return '\n'.join(code)

def generate_conv_layer(info: Dict, is_first_layer: bool = False) -> str:
    """Generate Conv2d layer with PoT weights."""
    name = info['name']
    weight = info['weight']
    bias_scaled = info.get('bias_scaled')
    alpha = info['alpha']
    is_last = info.get('is_last', False)

    in_ch = info['in_channels']
    out_ch = info['out_channels']
    kh = info['kernel_size']
    kw = kh
    stride = info['stride']
    padding = info['padding']
    in_h = info['in_h']
    in_w = info['in_w']
    out_h = info['out_h']
    out_w = info['out_w']

    # 1x1 conv optimization: handled by a separate function
    if kh == 1 and kw == 1:
        return generate_conv1x1_layer(info, is_first_layer)

    # Input type: uint8 for first layer (raw input), int8 for others
    input_type = "uint8_t" if is_first_layer else "int8_t"

    code = []
    code.append(f"// {name} - Conv2d: {in_ch}x{in_h}x{in_w} -> {out_ch}x{out_h}x{out_w}")
    code.append(f"// Kernel: {kh}x{kw}, Stride: {stride}, Padding: {padding}, alpha={alpha:.4f}")
    if is_first_layer:
        code.append(f"// First layer: input is uint8 [0,255], /256 absorbed in shift")
    if is_last:
        code.append(f"// Last layer: no ReLU")
    code.append(f"static void {name}_forward(const {input_type}* input, int8_t* output) {{")
    code.append(f" int32_t acc;")
    code.append("")

    # Generate unrolled convolution
    code.append(f" for (int oy = 0; oy < {out_h}; oy++) {{")
    code.append(f" for (int ox = 0; ox < {out_w}; ox++) {{")

    for oc in range(out_ch):
        code.append(f" // Output channel {oc}")
        code.append(f" acc = 0;")

        # Generate weight operations
        for ic in range(in_ch):
            for ky in range(kh):
                for kx in range(kw):
                    w = weight[oc, ic, ky, kx].item()
                    if abs(w) < 1e-9:  # Skip zero weights
                        continue

                    # weight_q = pot_value * alpha
                    # pot_value = weight_q / alpha ∈ {0, 1, 2, 4, 8, 16}
                    w_abs = abs(w)
                    pot_value = w_abs

                    # pot_value should be close to 1, 2, 4, 8, or 16
                    if pot_value < 0.5:  # Too small, skip
                        continue

                    # Find k where 2^k ≈ pot_value
                    k = round(np.log2(pot_value + 1e-10))
                    k = max(0, min(4, k))  # k in [0, 4] for {1, 2, 4, 8, 16}

                    sign = '+' if w > 0 else '-'

                    # Calculate input coordinates with stride (BUG FIX!)
                    # iy = oy * stride + ky - padding
                    # ix = ox * stride + kx - padding
                    iy_base = ky - padding  # offset from oy * stride
                    ix_base = kx - padding  # offset from ox * stride

                    # Build index expression with stride
                    idx_parts = []
                    if in_ch > 1:
                        idx_parts.append(f"{ic} * {in_h} * {in_w}")

                    # Y coordinate: oy * stride + iy_base
                    if stride == 1:
                        if iy_base == 0:
                            idx_parts.append(f"oy * {in_w}")
                        elif iy_base > 0:
                            idx_parts.append(f"(oy + {iy_base}) * {in_w}")
                        else:
                            idx_parts.append(f"(oy - {-iy_base}) * {in_w}")
                    else:
                        if iy_base == 0:
                            idx_parts.append(f"oy * {stride} * {in_w}")
                        elif iy_base > 0:
                            idx_parts.append(f"(oy * {stride} + {iy_base}) * {in_w}")
                        else:
                            idx_parts.append(f"(oy * {stride} - {-iy_base}) * {in_w}")

                    # X coordinate: ox * stride + ix_base
                    if stride == 1:
                        if ix_base == 0:
                            idx_parts.append("ox")
                        elif ix_base > 0:
                            idx_parts.append(f"(ox + {ix_base})")
                        else:
                            idx_parts.append(f"(ox - {-ix_base})")
                    else:
                        if ix_base == 0:
                            idx_parts.append(f"ox * {stride}")
                        elif ix_base > 0:
                            idx_parts.append(f"(ox * {stride} + {ix_base})")
                        else:
                            idx_parts.append(f"(ox * {stride} - {-ix_base})")

                    idx = " + ".join(idx_parts)

                    # Boundary conditions for padding (with stride)
                    checks = []
                    if stride == 1:
                        if iy_base < 0:
                            checks.append(f"oy >= {-iy_base}")
                        if iy_base > 0:
                            checks.append(f"oy + {iy_base} < {in_h}")
                        if ix_base < 0:
                            checks.append(f"ox >= {-ix_base}")
                        if ix_base > 0:
                            checks.append(f"ox + {ix_base} < {in_w}")
                    else:
                        if iy_base < 0:
                            checks.append(f"oy * {stride} >= {-iy_base}")
                        if iy_base > 0:
                            checks.append(f"oy * {stride} + {iy_base} < {in_h}")
                        if ix_base < 0:
                            checks.append(f"ox * {stride} >= {-ix_base}")
                        if ix_base > 0:
                            checks.append(f"ox * {stride} + {ix_base} < {in_w}")

                    # Generate code
                    if k == 0:
                        shift_expr = f"(int32_t)input[{idx}]"
                    else:
                        shift_expr = f"(int32_t)input[{idx}] << {k}"

                    if checks:
                        cond = " && ".join(checks)
                        code.append(f" if ({cond}) acc {sign}= {shift_expr};")
                    else:
                        code.append(f" acc {sign}= {shift_expr};")

        # Apply scale
        code.append(f" acc = scale_{name}(acc);")

        # Add scaled bias (CRITICAL FIX: use bias_scaled, not raw bias)
        if bias_scaled is not None:
            b = int(bias_scaled[oc].item())
            if b != 0:
                code.append(f" acc += {b};")

        # ReLU (based on ONNX graph analysis)
        if info.get('has_relu', False):
            code.append(f" if (acc < 0) acc = 0;")

        # Clamp and store
        out_idx = f"{oc * out_h * out_w} + oy * {out_w} + ox"
        code.append(f" output[{out_idx}] = (int8_t)(acc > 127 ? 127 : (acc < -128 ? -128 : acc));")
        code.append("")

    code.append(f" }}")
    code.append(f" }}")
    code.append(f"}}")
    code.append("")

    return '\n'.join(code)

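# ---------------------------------------------------------------------------
# Illustrative sketch of the weight → shift mapping used by the unrolled
# generators above: a quantized PoT weight (already divided by alpha) is
# turned into a (sign, k) pair so that w * x becomes (x << k) with k in
# [0, 4]. This helper is an assumption added for illustration, not part of
# the original module; it relies on the numpy and typing imports at the top
# of this file.
def _example_pot_weight_to_shift(w: float) -> Optional[Tuple[str, int]]:
    if abs(w) < 0.5:  # zero level: contributes nothing
        return None
    k = int(round(np.log2(abs(w) + 1e-10)))
    k = max(0, min(4, k))  # clamp to the PoT levels {1, 2, 4, 8, 16}
    return ('+' if w > 0 else '-', k)
# e.g. _example_pot_weight_to_shift(-4.0) -> ('-', 2): acc -= input[idx] << 2.
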
def generate_linear_layer(info: Dict, is_first_layer: bool = False) -> str:
    """Generate Linear layer with PoT weights."""
    name = info['name']
    weight = info['weight']
    bias_scaled = info.get('bias_scaled')
    alpha = info['alpha']
    is_last = info.get('is_last', False)

    in_features = info['in_features']
    out_features = info['out_features']

    input_type = "uint8_t" if is_first_layer else "int8_t"

    code = []
    code.append(f"// {name} - Linear: {in_features} -> {out_features}, alpha={alpha:.4f}")
    if is_last:
        code.append(f"// Last layer: no ReLU")
    code.append(f"static void {name}_forward(const {input_type}* input, int8_t* output) {{")
    code.append(f" int32_t acc;")
    code.append("")

    for o in range(out_features):
        code.append(f" // Output {o}")
        code.append(f" acc = 0;")

        for i in range(in_features):
            w = weight[o, i].item()
            if abs(w) < 1e-9:
                continue

            w_abs = abs(w)
            pot_value = w_abs

            if pot_value < 0.5:
                continue

            k = round(np.log2(pot_value + 1e-10))
            k = max(0, min(4, k))

            sign = '+' if w > 0 else '-'

            if k == 0:
                code.append(f" acc {sign}= (int32_t)input[{i}];")
            else:
                code.append(f" acc {sign}= (int32_t)input[{i}] << {k};")

        code.append(f" acc = scale_{name}(acc);")

        # Add scaled bias (CRITICAL FIX)
        if bias_scaled is not None:
            b = int(bias_scaled[o].item())
            if b != 0:
                code.append(f" acc += {b};")

        # ReLU (based on ONNX graph analysis)
        if info.get('has_relu', False):
            code.append(f" if (acc < 0) acc = 0;")

        code.append(f" output[{o}] = (int8_t)(acc > 127 ? 127 : (acc < -128 ? -128 : acc));")
        code.append("")

    code.append(f"}}")
    code.append("")

    return '\n'.join(code)

def generate_maxpool_layer(info: Dict) -> str:
    """Generate MaxPool1d or MaxPool2d layer."""
    name = info['name']
    in_h = info['in_h']
    in_w = info['in_w']
    in_ch = info['in_channels']
    out_h = info['out_h']
    out_w = info['out_w']
    k = info['kernel_size']
    s = info['stride']
    is_1d = info.get('is_1d', False)

    code = []

    if is_1d:
        # MaxPool1d
        code.append(f"// {name} - MaxPool1d k={k}, stride={s}")
        code.append(f"static void {name}_forward(const int8_t* input, int8_t* output) {{")
        code.append(f" for (int c = 0; c < {in_ch}; c++) {{")
        code.append(f" for (int o = 0; o < {out_h}; o++) {{")
        code.append(f" int8_t max_val = -128;")
        code.append(f" for (int ki = 0; ki < {k}; ki++) {{")
        code.append(f" int idx = c * {in_h} + (o * {s} + ki);")
        code.append(f" if (input[idx] > max_val) max_val = input[idx];")
        code.append(f" }}")
        code.append(f" output[c * {out_h} + o] = max_val;")
        code.append(f" }}")
        code.append(f" }}")
        code.append(f"}}")
    else:
        # MaxPool2d
        code.append(f"// {name} - MaxPool2d {k}x{k}, stride {s}")
        code.append(f"static void {name}_forward(const int8_t* input, int8_t* output) {{")
        code.append(f" for (int c = 0; c < {in_ch}; c++) {{")
        code.append(f" for (int oy = 0; oy < {out_h}; oy++) {{")
        code.append(f" for (int ox = 0; ox < {out_w}; ox++) {{")
        code.append(f" int8_t max_val = -128;")
        code.append(f" for (int ky = 0; ky < {k}; ky++) {{")
        code.append(f" for (int kx = 0; kx < {k}; kx++) {{")
        code.append(f" int idx = c * {in_h} * {in_w} + (oy * {s} + ky) * {in_w} + (ox * {s} + kx);")
        code.append(f" if (input[idx] > max_val) max_val = input[idx];")
        code.append(f" }}")
        code.append(f" }}")
        code.append(f" output[c * {out_h} * {out_w} + oy * {out_w} + ox] = max_val;")
        code.append(f" }}")
        code.append(f" }}")
        code.append(f" }}")
        code.append(f"}}")

    code.append("")

    return '\n'.join(code)

def generate_predict_function(layer_infos: List[Dict], config: Config) -> str:
    """Generate main prediction function with proper skip connection handling."""
    code = []

    input_size = config.input_h * config.input_w

    # Calculate max buffer size
    max_buffer_size = input_size
    for info in layer_infos:
        if 'out_h' in info and 'out_w' in info:
            if 'out_channels' in info:
                size = info['out_channels'] * info['out_h'] * info['out_w']
            elif 'in_channels' in info:
                size = info['in_channels'] * info['out_h'] * info['out_w']
            else:
                size = info['out_h'] * info['out_w']
            max_buffer_size = max(max_buffer_size, size)
        elif 'out_features' in info:
            max_buffer_size = max(max_buffer_size, info['out_features'])

    # ========================================
    # Analyze skip connections
    # ========================================
    # Collect all skip source layers (where skip branches start)
    skip_sources = set()
    for info in layer_infos:
        if info['layer_type'] == 'add' and 'skip_source_layer' in info:
            skip_sources.add(info['skip_source_layer'])

    has_skip = len(skip_sources) > 0

    # Find number of output classes
    num_classes = 10
    for info in reversed(layer_infos):
        if 'out_features' in info and info['layer_type'] == 'pot':
            num_classes = info['out_features']
            break

    code.append("// Input: uint8 [0,255] - raw pixel values, no normalization needed")
    code.append("// Output: int8 [-128,127] - raw logits")
    code.append("void potnn_predict(const uint8_t* input, int8_t* output) {")
    code.append(f" // Intermediate buffers (max size: {max_buffer_size})")
    code.append(f" static int8_t buffer1[{max_buffer_size}];")
    code.append(f" static int8_t buffer2[{max_buffer_size}];")

    if has_skip:
        code.append(f" static int8_t skip_buffer[{max_buffer_size}]; // for skip connections")
        code.append(f" // Skip sources: layers {sorted(skip_sources)}")

    code.append(f" int8_t *current = buffer1;")
    code.append(f" int8_t *next = buffer2;")
    code.append("")

    first_pot_done = False

    for i, info in enumerate(layer_infos):
        layer_type = info['layer_type']

        if layer_type == 'flatten':
            code.append(f" // Layer {i}: Flatten (no-op)")
            continue

        code.append(f" // Layer {i}: {info['type']}")

        if layer_type == 'pot' and not first_pot_done:
            # First PoT layer: input is uint8
            code.append(f" {info['name']}_forward(input, current);")
            first_pot_done = True

        elif layer_type == 'add':
            # Add layer: skip_buffer + current -> next
            skip_src = info.get('skip_source_layer', -1)
            conv_src = info.get('conv_source_layer', -1)
            code.append(f" // skip from layer_{skip_src}, conv from layer_{conv_src}")
            code.append(f" {info['name']}_forward(skip_buffer, current, next);")
            code.append(f" {{ int8_t *tmp = current; current = next; next = tmp; }}")

        elif layer_type == 'global_avg_pool':
            # Global Average Pooling: C×H×W → C
            code.append(f" {info['name']}_forward(current, next);")
            code.append(f" {{ int8_t *tmp = current; current = next; next = tmp; }}")

        else:
            code.append(f" {info['name']}_forward(current, next);")
            code.append(f" {{ int8_t *tmp = current; current = next; next = tmp; }}")

        # ========================================
        # Check if this layer is a skip source → save to skip_buffer
        # ========================================
        if i in skip_sources:
            # Find the size of this layer's output
            if 'out_channels' in info and 'out_h' in info and 'out_w' in info:
                skip_size = info['out_channels'] * info['out_h'] * info['out_w']
            elif 'out_features' in info:
                skip_size = info['out_features']
            elif 'in_channels' in info and 'out_h' in info:
                skip_size = info['in_channels'] * info['out_h'] * info['out_w']
            else:
                skip_size = max_buffer_size

            code.append(f" // Save skip (layer {i} is skip source)")
            code.append(f" for (int _i = 0; _i < {skip_size}; _i++) skip_buffer[_i] = current[_i];")

        code.append("")

    code.append(f" // Copy result to output buffer (num_classes = {num_classes})")
    code.append(f" for (int i = 0; i < {num_classes}; i++) {{")
    code.append(f" output[i] = current[i];")
    code.append(f" }}")
    code.append(f"}}")

    return '\n'.join(code)