potnn-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- potnn/__init__.py +86 -0
- potnn/codegen/__init__.py +20 -0
- potnn/codegen/bit2.py +263 -0
- potnn/codegen/fp130.py +269 -0
- potnn/codegen/header.py +460 -0
- potnn/codegen/level5.py +393 -0
- potnn/codegen/scale.py +184 -0
- potnn/codegen/ternary.py +354 -0
- potnn/codegen/unroll.py +616 -0
- potnn/config.py +112 -0
- potnn/export.py +2196 -0
- potnn/fuse.py +167 -0
- potnn/modules/__init__.py +11 -0
- potnn/modules/add.py +114 -0
- potnn/modules/avgpool.py +173 -0
- potnn/modules/base.py +225 -0
- potnn/modules/conv.py +203 -0
- potnn/modules/conv1d.py +317 -0
- potnn/modules/depthwise.py +216 -0
- potnn/modules/linear.py +199 -0
- potnn/quantize/__init__.py +35 -0
- potnn/quantize/calibration.py +233 -0
- potnn/quantize/integer_ops.py +207 -0
- potnn/quantize/integer_sim.py +225 -0
- potnn/quantize/pot.py +455 -0
- potnn/quantize/qat.py +356 -0
- potnn/utils/__init__.py +13 -0
- potnn/utils/allocation.py +240 -0
- potnn/utils/memory.py +158 -0
- potnn/wrapper.py +304 -0
- potnn-1.0.0.dist-info/METADATA +260 -0
- potnn-1.0.0.dist-info/RECORD +35 -0
- potnn-1.0.0.dist-info/WHEEL +5 -0
- potnn-1.0.0.dist-info/licenses/LICENSE +72 -0
- potnn-1.0.0.dist-info/top_level.txt +1 -0
potnn/codegen/header.py
ADDED
@@ -0,0 +1,460 @@
"""Main C header generation for potnn."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any, List, Tuple

from ..modules import PoTLinear, PoTConv2d
# NOTE: calculate_adjusted_bias is called below; it is assumed to live in .scale
# alongside calculate_combined_scale (the published file used it without importing it).
from .scale import generate_scale_func, calculate_combined_scale, calculate_adjusted_bias
from .unroll import generate_unrolled_layer


def generate_c_header(path: str, model: nn.Module):
    """Generate complete C header file for the model.

    Args:
        path: Output file path
        model: Trained potnn model
    """
    config = model._potnn_config
    allocations = model._potnn_allocations
    input_shape = model._potnn_input_shape

    # Analyze model structure and compute layer dimensions
    layer_dims = _compute_layer_dimensions(model, input_shape)

    # Debug: print computed dimensions
    print("\nComputed layer dimensions:")
    for name, dims in layer_dims.items():
        print(f"  {name}: {dims}")

    # Collect layer information
    layers = []
    other_layers = []  # MaxPool, Flatten, etc.
    prev_act_scale = 1.0
    pot_layer_idx = 0
    num_pot_layers = len([m for m in model.modules() if isinstance(m, (PoTLinear, PoTConv2d))])

    for name, module in model.named_modules():
        if isinstance(module, (PoTLinear, PoTConv2d)):
            alloc = allocations.get(name)
            if alloc is None:
                continue

            with torch.no_grad():
                alpha = F.softplus(module.alpha).clamp(min=0.01).item()

            is_first = (pot_layer_idx == 0)
            is_last = (pot_layer_idx == num_pot_layers - 1)

            print(f"\n{'='*60}")
            print(f"[DEBUG] Processing layer: {name}")
            print(f"  Type: {type(module).__name__}")
            print(f"  is_first: {is_first}, is_last: {is_last}")
            print(f"  alpha (raw): {module.alpha.item():.6f}")
            print(f"  alpha (softplus+clamp): {alpha:.6f}")
            print(f"  act_scale: {module.act_scale}")
            print(f"  prev_act_scale: {prev_act_scale:.6f}")

            # Calculate adjusted bias (absorb mean for first layer)
            adjusted_bias = calculate_adjusted_bias(
                bias=module.bias.detach().cpu() if module.bias is not None else None,
                weight=module.weight.detach().cpu(),
                mean=config.mean if is_first else None,
                is_first_layer=is_first
            )

            scale_int, shift = calculate_combined_scale(
                alpha=alpha,
                act_scale=module.act_scale if module.act_scale else 1.0,
                prev_act_scale=prev_act_scale,
                input_norm=config.input_norm if config.input_norm else 256,
                is_first_layer=is_first,
                is_last_layer=is_last,
                mean=config.mean if is_first else None,
                std=config.std if is_first else None
            )

            # Debug: show bias scaling preview
            act_scale_for_bias = module.act_scale if module.act_scale else 1.0
            if adjusted_bias is not None:
                scaled_bias = [int(round(b.item() * act_scale_for_bias)) for b in adjusted_bias[:4]]
                print(f"  [DEBUG] Bias scaling preview:")
                print(f"    act_scale for bias: {act_scale_for_bias:.4f}")
                print(f"    adjusted_bias (first 4): {[round(b.item(), 4) for b in adjusted_bias[:4]]}")
                print(f"    scaled_bias (first 4): {scaled_bias}")

            # Make a C-compatible identifier
            c_name = name.replace('.', '_')
            if c_name and c_name[0].isdigit():
                c_name = 'layer_' + c_name

            # Get dimensions from layer_dims
            dims = layer_dims.get(name, {})

            # Determine whether a ReLU follows this layer
            use_relu = _check_relu_follows(model, name)

            layer_info = {
                'name': c_name,
                'original_name': name,
                'type': type(module).__name__,
                'weight': module.weight.detach().cpu(),
                'bias': adjusted_bias,  # Adjusted bias (mean absorbed for the first layer)
                'alpha': alpha,
                'act_scale': module.act_scale if module.act_scale else 1.0,  # For bias scaling
                'levels': alloc.levels,
                'mode': alloc.mode,
                'scale_int': scale_int,
                'shift': shift,
                'use_relu': use_relu,
                # Dimensions
                'in_h': dims.get('in_h', 0),
                'in_w': dims.get('in_w', 0),
                'out_h': dims.get('out_h', 0),
                'out_w': dims.get('out_w', 0),
                'in_size': dims.get('in_size', 0),
                'out_size': dims.get('out_size', 0),
                'stride': dims.get('stride', 1),
                'padding': dims.get('padding', 0),
            }

            layers.append(layer_info)

            if module.act_scale is not None:
                prev_act_scale = module.act_scale

            pot_layer_idx += 1

        elif isinstance(module, nn.MaxPool2d):
            c_name = name.replace('.', '_')
            if c_name and c_name[0].isdigit():
                c_name = 'layer_' + c_name

            dims = layer_dims.get(name, {})
            other_layers.append({
                'name': c_name,
                'original_name': name,
                'type': 'MaxPool2d',
                'kernel_size': module.kernel_size if isinstance(module.kernel_size, int) else module.kernel_size[0],
                'stride': module.stride if isinstance(module.stride, int) else module.stride[0],
                'in_h': dims.get('in_h', 0),
                'in_w': dims.get('in_w', 0),
                'in_ch': dims.get('in_ch', 0),
                'out_h': dims.get('out_h', 0),
                'out_w': dims.get('out_w', 0),
                'in_size': dims.get('in_size', 0),
                'out_size': dims.get('out_size', 0),
            })

        elif isinstance(module, nn.Flatten):
            c_name = name.replace('.', '_')
            if c_name and c_name[0].isdigit():
                c_name = 'layer_' + c_name

            dims = layer_dims.get(name, {})
            other_layers.append({
                'name': c_name,
                'original_name': name,
                'type': 'Flatten',
                'in_size': dims.get('in_size', 0),
                'out_size': dims.get('out_size', 0),
            })

    # Compute buffer sizes
    max_buffer_size = _compute_max_buffer_size(layer_dims)

    # Get output classes from the last layer
    last_layer = layers[-1] if layers else None
    num_classes = last_layer['weight'].shape[0] if last_layer else 10

    # Build ordered layer sequence
    layer_sequence = _build_layer_sequence(model, layers, other_layers)

    # Generate header file
    with open(path, 'w') as f:
        _write_header_preamble(f, config)

        # Add input scale info if available
        if hasattr(model, 'input_scale'):
            f.write(f"/* Input scale: {model.input_scale:.3f} (for test data: int8 = normalized_float * input_scale) */\n")
            f.write(f"/* Input max: {model.input_max:.3f} */\n\n")

        # Scale functions
        f.write("/* Scale functions using only shifts and adds */\n")
        for layer in layers:
            f.write(generate_scale_func(layer['name'], layer['scale_int'], layer['shift']))

        # MaxPool functions
        for other in other_layers:
            if other['type'] == 'MaxPool2d':
                f.write(_generate_maxpool_func(other))

        # Layer forward functions
        f.write("/* Layer forward functions */\n")
        for layer in layers:
            if layer['mode'] == 'unroll':
                f.write(generate_unrolled_layer(layer))
            else:
                # Other allocation modes currently fall back to the unrolled generator as well
                f.write(generate_unrolled_layer(layer))

        # Main prediction function
        f.write(_generate_main_predict(layer_sequence, max_buffer_size, num_classes))

        f.write("\n#endif /* POTNN_MODEL_H */\n")

    print(f"C header generated: {path}")
    print(f"  - {len(layers)} PoT layers")
    print(f"  - {len(other_layers)} other layers (MaxPool, Flatten)")
    print(f"  - Buffer size: {max_buffer_size} bytes")
    print(f"  - Output classes: {num_classes}")

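# A minimal usage sketch (assumption: the potnn wrapper/export pipeline has already
# attached the _potnn_config, _potnn_allocations and _potnn_input_shape attributes
# that generate_c_header() reads above; the preparation step itself lives in
# potnn.wrapper / potnn.export, not here):
#
#     from potnn.codegen.header import generate_c_header
#     generate_c_header("model.h", trained_model)
#
# The output is a single self-contained header: per-layer scale helpers, one
# *_forward() kernel per layer, and a potnn_predict() entry point.
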
def _write_header_preamble(f, config):
    """Write header file preamble."""
    f.write("/* POTNN Generated Model - MUL-FREE Neural Network */\n")
    f.write("/* Target: Ultra-low-cost MCUs without multiplication */\n")
    f.write(f"/* Flash budget: {config.flash} bytes */\n")
    f.write(f"/* RAM budget: {config.ram} bytes */\n\n")
    f.write("#ifndef POTNN_MODEL_H\n")
    f.write("#define POTNN_MODEL_H\n\n")
    f.write("#include <stdint.h>\n\n")

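# The scale functions emitted right after this preamble (generate_scale_func /
# calculate_combined_scale in codegen/scale.py) reduce a floating-point
# requantization factor to an integer multiply plus a right shift, which the
# generated C then realizes with shifts and adds only.  A minimal sketch of that
# approximation (illustration only, not the package's exact decomposition):

def _approx_scale_demo(scale: float, shift: int = 8):
    """Return (scale_int, shift) so that (x * scale_int) >> shift ~= x * scale."""
    return int(round(scale * (1 << shift))), shift

_si, _sh = _approx_scale_demo(0.37)        # e.g. a combined scale of 0.37
assert 365 <= (1000 * _si) >> _sh <= 375   # ~ 1000 * 0.37
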
def _generate_maxpool_func(info: Dict) -> str:
    """Generate MaxPool2d function."""
    name = info['name']
    kernel = info['kernel_size']
    stride = info['stride']
    in_h = info['in_h']
    in_w = info['in_w']
    in_ch = info['in_ch']
    out_h = info['out_h']
    out_w = info['out_w']

    code = f"// {name} - MaxPool2d {kernel}x{kernel}, stride {stride}\n"
    code += f"static void {name}_forward(const int8_t* input, int8_t* output) {{\n"
    code += f"    const int IN_H = {in_h}, IN_W = {in_w}, IN_CH = {in_ch};\n"
    code += f"    const int OUT_H = {out_h}, OUT_W = {out_w};\n"
    code += f"    const int K = {kernel}, S = {stride};\n"
    code += f"\n"
    code += f"    for (int c = 0; c < IN_CH; c++) {{\n"
    code += f"        for (int oy = 0; oy < OUT_H; oy++) {{\n"
    code += f"            for (int ox = 0; ox < OUT_W; ox++) {{\n"
    code += f"                int8_t max_val = -128;\n"
    code += f"                for (int ky = 0; ky < K; ky++) {{\n"
    code += f"                    for (int kx = 0; kx < K; kx++) {{\n"
    code += f"                        int in_y = oy * S + ky;\n"
    code += f"                        int in_x = ox * S + kx;\n"
    code += f"                        if (in_y < IN_H && in_x < IN_W) {{\n"
    code += f"                            int idx = c * IN_H * IN_W + in_y * IN_W + in_x;\n"
    code += f"                            if (input[idx] > max_val) max_val = input[idx];\n"
    code += f"                        }}\n"
    code += f"                    }}\n"
    code += f"                }}\n"
    code += f"                int out_idx = c * OUT_H * OUT_W + oy * OUT_W + ox;\n"
    code += f"                output[out_idx] = max_val;\n"
    code += f"            }}\n"
    code += f"        }}\n"
    code += f"    }}\n"
    code += "}\n\n"
    return code

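# A pure-Python mirror of the loop nest emitted above (flat CHW int8 buffers),
# useful as a reference when checking the generated C against PyTorch
# (illustration only; the package emits C, it does not run this):

def _maxpool_chw_demo(inp, in_ch, in_h, in_w, k, s):
    out_h, out_w = in_h // s, in_w // s
    out = [-128] * (in_ch * out_h * out_w)
    for c in range(in_ch):
        for oy in range(out_h):
            for ox in range(out_w):
                best = -128
                for ky in range(k):
                    for kx in range(k):
                        y, x = oy * s + ky, ox * s + kx
                        if y < in_h and x < in_w:
                            best = max(best, inp[c * in_h * in_w + y * in_w + x])
                out[c * out_h * out_w + oy * out_w + ox] = best
    return out

assert _maxpool_chw_demo([1, 5, 3, 7], in_ch=1, in_h=2, in_w=2, k=2, s=2) == [7]
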
def _generate_main_predict(layer_sequence: List[Dict], buffer_size: int, num_classes: int) -> str:
    """Generate the main prediction function."""
    code = "/* Main prediction function */\n"
    code += "int8_t potnn_predict(const int8_t* input) {\n"
    code += f"    static int8_t buffer1[{buffer_size}];\n"
    code += f"    static int8_t buffer2[{buffer_size}];\n"
    code += "    const int8_t* current_input = input;\n"
    code += "    int8_t* current_output = buffer1;\n\n"

    for i, layer in enumerate(layer_sequence):
        name = layer['name']
        layer_type = layer['type']
        is_last = (i == len(layer_sequence) - 1)

        if layer_type == 'Flatten':
            code += f"    // {name}: Flatten (no-op, just continue)\n"
            code += f"    // Input and output share same memory layout\n\n"
            continue

        code += f"    // {name}: {layer_type}\n"
        code += f"    {name}_forward(current_input, current_output);\n"

        if not is_last:
            code += "    current_input = current_output;\n"
            code += "    current_output = (current_output == buffer1) ? buffer2 : buffer1;\n\n"

    code += f"\n    // Find argmax for classification ({num_classes} classes)\n"
    code += "    int8_t max_val = current_output[0];\n"
    code += "    int8_t max_idx = 0;\n"
    code += f"    for (int i = 1; i < {num_classes}; i++) {{\n"
    code += "        if (current_output[i] > max_val) {\n"
    code += "            max_val = current_output[i];\n"
    code += "            max_idx = i;\n"
    code += "        }\n"
    code += "    }\n"
    code += "    return max_idx;\n"
    code += "}\n"

    return code

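# The generated potnn_predict() ping-pongs activations between two static buffers,
# so scratch RAM is 2 * buffer_size bytes on top of the caller's input.  A small
# Python analogue of the buffer hand-off and the final argmax (illustration only):

def _pingpong_demo(x, layer_fns, buffer_size, num_classes):
    buf1, buf2 = [0] * buffer_size, [0] * buffer_size
    cur_in, cur_out = x, buf1
    for i, fn in enumerate(layer_fns):
        fn(cur_in, cur_out)
        if i < len(layer_fns) - 1:
            cur_in, cur_out = cur_out, (buf2 if cur_out is buf1 else buf1)
    scores = cur_out[:num_classes]
    return scores.index(max(scores))

def _add_one(src, dst):   # toy stand-ins for the generated *_forward() kernels
    for i in range(4):
        dst[i] = src[i] + 1

def _double(src, dst):
    for i in range(4):
        dst[i] = src[i] * 2

assert _pingpong_demo([1, 3, 0, 2], [_add_one, _double], buffer_size=8, num_classes=4) == 1
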
def _compute_layer_dimensions(model: nn.Module, input_shape: Tuple) -> Dict[str, Dict]:
    """Compute input/output dimensions for each layer.

    Handles both original nn.Conv2d/nn.Linear and PoTConv2d/PoTLinear.
    """
    dims = {}

    # Current shape tracking
    if len(input_shape) == 3:
        c, h, w = input_shape
    elif len(input_shape) == 1:
        c, h, w = 1, 1, input_shape[0]
    else:
        c, h, w = 1, input_shape[0], input_shape[1]

    print(f"\nTracking dimensions from input shape: c={c}, h={h}, w={w}")

    for name, module in model.named_modules():
        # Skip container modules
        if name == '':
            continue

        # Handle PoTConv2d and nn.Conv2d
        if isinstance(module, PoTConv2d) or isinstance(module, nn.Conv2d):
            in_h, in_w = h, w
            in_ch = c

            # Get kernel size, stride, padding
            if isinstance(module, PoTConv2d):
                kh = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else module.kernel_size
                kw = module.kernel_size[1] if isinstance(module.kernel_size, tuple) else module.kernel_size
                sh = module.stride[0] if isinstance(module.stride, tuple) else module.stride
                sw = module.stride[1] if isinstance(module.stride, tuple) else module.stride
                ph = module.padding[0] if isinstance(module.padding, tuple) else module.padding
                pw = module.padding[1] if isinstance(module.padding, tuple) else module.padding
                out_ch = module.out_channels
            else:
                kh, kw = module.kernel_size if isinstance(module.kernel_size, tuple) else (module.kernel_size, module.kernel_size)
                sh, sw = module.stride if isinstance(module.stride, tuple) else (module.stride, module.stride)
                ph, pw = module.padding if isinstance(module.padding, tuple) else (module.padding, module.padding)
                out_ch = module.out_channels

            out_h = (in_h + 2*ph - kh) // sh + 1
            out_w = (in_w + 2*pw - kw) // sw + 1

            dims[name] = {
                'in_h': in_h, 'in_w': in_w, 'in_ch': in_ch,
                'out_h': out_h, 'out_w': out_w, 'out_ch': out_ch,
                'in_size': in_ch * in_h * in_w,
                'out_size': out_ch * out_h * out_w,
                'stride': sh,
                'padding': ph if ph == pw else (ph, pw),
            }

            # Update current shape
            c, h, w = out_ch, out_h, out_w
            print(f"  {name} (Conv): {in_ch}x{in_h}x{in_w} -> {out_ch}x{out_h}x{out_w}")

        elif isinstance(module, nn.MaxPool2d):
            in_h, in_w = h, w
            in_ch = c

            k = module.kernel_size if isinstance(module.kernel_size, int) else module.kernel_size[0]
            s = module.stride if isinstance(module.stride, int) else module.stride[0]

            # Note: in // s equals the exact (in - k) // s + 1 only when
            # kernel_size == stride and padding == 0 (the common pooling setup).
            out_h = in_h // s
            out_w = in_w // s

            dims[name] = {
                'in_h': in_h, 'in_w': in_w, 'in_ch': in_ch,
                'out_h': out_h, 'out_w': out_w,
                'in_size': in_ch * in_h * in_w,
                'out_size': in_ch * out_h * out_w,
            }

            # Update current shape
            h, w = out_h, out_w
            print(f"  {name} (MaxPool): {in_ch}x{in_h}x{in_w} -> {in_ch}x{out_h}x{out_w}")

        elif isinstance(module, nn.Flatten):
            flat_size = c * h * w
            dims[name] = {
                'in_size': flat_size,
                'out_size': flat_size,
            }
            print(f"  {name} (Flatten): {c}x{h}x{w} -> {flat_size}")

        elif isinstance(module, PoTLinear) or isinstance(module, nn.Linear):
            in_features = module.in_features
            out_features = module.out_features

            dims[name] = {
                'in_size': in_features,
                'out_size': out_features,
            }

            # Update current shape for a potential next linear layer
            c, h, w = 1, 1, out_features
            print(f"  {name} (Linear): {in_features} -> {out_features}")

        # Skip BatchNorm, ReLU, Identity, Dropout - they don't change dimensions

    return dims

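# Worked example of the convolution formula used above,
# out = (in + 2*pad - kernel) // stride + 1:

assert (28 + 2 * 1 - 3) // 1 + 1 == 28   # 3x3, stride 1, pad 1 keeps 28x28 maps
assert (32 + 2 * 0 - 5) // 2 + 1 == 14   # 5x5, stride 2, no pad: 32 -> 14
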
def _compute_max_buffer_size(layer_dims: Dict) -> int:
    """Compute maximum buffer size needed."""
    max_size = 256  # Minimum
    for name, dims in layer_dims.items():
        in_size = dims.get('in_size', 0)
        out_size = dims.get('out_size', 0)
        max_size = max(max_size, in_size, out_size)
    return max_size


def _check_relu_follows(model: nn.Module, layer_name: str) -> bool:
    """Check if ReLU follows the given layer.

    Skips Identity layers (fused BatchNorm) and BatchNorm layers.
    """
    found_layer = False
    for name, module in model.named_modules():
        if name == layer_name:
            found_layer = True
            continue
        if found_layer:
            # Skip Identity (fused BatchNorm) and BatchNorm
            if isinstance(module, (nn.Identity, nn.BatchNorm1d, nn.BatchNorm2d)):
                continue
            if isinstance(module, (nn.ReLU, nn.ReLU6)):
                return True
            elif isinstance(module, (nn.Conv2d, nn.Linear, nn.MaxPool2d)):
                return False  # Another compute layer before ReLU
            # Also check for PoT modules
            if isinstance(module, (PoTLinear, PoTConv2d)):
                return False  # PoTLinear or PoTConv2d
    return False

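# Quick check of the look-ahead logic above.  _check_relu_follows() only relies on
# named_modules() order, so a plain nn.Sequential is enough to see it skip a
# BatchNorm before finding the ReLU (sketch, assuming potnn is installed):
#
#     import torch.nn as nn
#     from potnn.codegen.header import _check_relu_follows
#
#     m = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4), nn.ReLU(), nn.Linear(4, 2))
#     _check_relu_follows(m, '0')   # True  - ReLU found after skipping the BatchNorm
#     _check_relu_follows(m, '3')   # False - nothing follows the last Linear
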
def _build_layer_sequence(model: nn.Module, pot_layers: List[Dict], other_layers: List[Dict]) -> List[Dict]:
    """Build ordered sequence of all layers."""
    # Create lookup by original name
    pot_lookup = {l['original_name']: l for l in pot_layers}
    other_lookup = {l['original_name']: l for l in other_layers}

    sequence = []
    for name, module in model.named_modules():
        if name in pot_lookup:
            sequence.append(pot_lookup[name])
        elif name in other_lookup:
            sequence.append(other_lookup[name])

    return sequence
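# The private helpers can also be exercised directly to preview the generated
# driver without a trained model (sketch, assuming potnn is installed; the layer
# names and sizes below are made up for illustration):
#
#     from potnn.codegen.header import _compute_max_buffer_size, _generate_main_predict
#
#     dims = {'conv1': {'in_size': 784, 'out_size': 6272},
#             'fc':    {'in_size': 1568, 'out_size': 10}}
#     seq = [{'name': 'conv1', 'type': 'PoTConv2d'},
#            {'name': 'flatten', 'type': 'Flatten'},
#            {'name': 'fc', 'type': 'PoTLinear'}]
#     print(_generate_main_predict(seq, _compute_max_buffer_size(dims), num_classes=10))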