potnn-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- potnn/__init__.py +86 -0
- potnn/codegen/__init__.py +20 -0
- potnn/codegen/bit2.py +263 -0
- potnn/codegen/fp130.py +269 -0
- potnn/codegen/header.py +460 -0
- potnn/codegen/level5.py +393 -0
- potnn/codegen/scale.py +184 -0
- potnn/codegen/ternary.py +354 -0
- potnn/codegen/unroll.py +616 -0
- potnn/config.py +112 -0
- potnn/export.py +2196 -0
- potnn/fuse.py +167 -0
- potnn/modules/__init__.py +11 -0
- potnn/modules/add.py +114 -0
- potnn/modules/avgpool.py +173 -0
- potnn/modules/base.py +225 -0
- potnn/modules/conv.py +203 -0
- potnn/modules/conv1d.py +317 -0
- potnn/modules/depthwise.py +216 -0
- potnn/modules/linear.py +199 -0
- potnn/quantize/__init__.py +35 -0
- potnn/quantize/calibration.py +233 -0
- potnn/quantize/integer_ops.py +207 -0
- potnn/quantize/integer_sim.py +225 -0
- potnn/quantize/pot.py +455 -0
- potnn/quantize/qat.py +356 -0
- potnn/utils/__init__.py +13 -0
- potnn/utils/allocation.py +240 -0
- potnn/utils/memory.py +158 -0
- potnn/wrapper.py +304 -0
- potnn-1.0.0.dist-info/METADATA +260 -0
- potnn-1.0.0.dist-info/RECORD +35 -0
- potnn-1.0.0.dist-info/WHEEL +5 -0
- potnn-1.0.0.dist-info/licenses/LICENSE +72 -0
- potnn-1.0.0.dist-info/top_level.txt +1 -0
potnn/codegen/level5.py
ADDED
@@ -0,0 +1,393 @@
"""Generate 5-level encoded C code for PoT layers.

5-level encoding: [skip(2)][sign(1)][mag(1)] = 4bit
- 5 levels: -8, -1, 0, +1, +8
- Skip field handles consecutive zeros (0-3 positions)
- Decoding: skip positions, then val = (mag ? 8 : 1) * (sign ? -1 : 1)
- Balanced accuracy and memory efficiency
"""
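
# Code layout, illustrated: 0b1011 decodes as skip=2, sign=1, mag=1, i.e.
# two zeros followed by -8; 0b0000 decodes as +1 with no skip.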

import numpy as np
from typing import Dict, Any, Tuple, List


def pack_weights_5level(w_q: np.ndarray) -> Tuple[np.ndarray, int, int]:
    """Pack quantized weights to 5-level format with skip encoding.

    Args:
        w_q: Quantized weights (values in -8, -1, 0, +1, +8)

    Returns:
        packed: uint8 array (2 codes per byte)
        n_codes: number of 4-bit codes
        n_weights: original number of weights
    """
    flat = np.round(w_q.flatten()).astype(np.int16)
    n = len(flat)

    # Generate 4-bit codes with skip encoding
    codes: List[int] = []
    i = 0
    while i < n:
        # Count consecutive zeros (max 3)
        skip = 0
        while i + skip < n and flat[i + skip] == 0 and skip < 3:
            skip += 1

        i += skip

        if i >= n:
            # End of weights - emit dummy code if needed
            if skip > 0:
                codes.append((skip << 2) | 0b00)  # skip + val=+1
            break

        w = flat[i]
        sign = 1 if w < 0 else 0
        mag = 1 if abs(w) == 8 else 0  # 0=1, 1=8

        code = (skip << 2) | (sign << 1) | mag
        codes.append(code)
        i += 1

    # Pack 2 codes per byte
    n_codes = len(codes)
    packed_len = (n_codes + 1) // 2
    packed = np.zeros(packed_len, dtype=np.uint8)
    for i in range(0, n_codes, 2):
        high = codes[i]
        low = codes[i + 1] if i + 1 < n_codes else 0
        packed[i // 2] = (high << 4) | low

    return packed, n_codes, n
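
# Worked example: pack_weights_5level(np.array([1, 0, 0, -8, 8])) yields the
# codes 0b0000 (+1), 0b1011 (skip two zeros, then -8) and 0b0001 (+8), packed
# into the two bytes 0x0b, 0x10 (the low nibble of the last byte is padding).
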
def pack_weights_5level_linear(w_q: np.ndarray) -> Tuple[np.ndarray, List[int]]:
    """Pack Linear weights row by row (each output filter separately).

    Args:
        w_q: 2D weight array (out_features, in_features)

    Returns:
        packed: uint8 array (all rows concatenated, byte-aligned per row)
        row_bytes: list of byte counts per row

    Note:
        5-level encoding can only skip up to 3 consecutive zeros. In a longer
        run, every 4th zero is replaced with +1 (the smallest non-zero value).
        This is a spec limitation.
    """
    out_features, in_features = w_q.shape
    all_packed = []
    row_bytes = []

    for o in range(out_features):
        row = np.round(w_q[o]).astype(np.int16).copy()

        # WORKAROUND: Replace 4th+ consecutive zeros with +1
        # 5-level skip field is 2 bits (max 3), so 4+ zeros can't be encoded
        zero_run = 0
        for j in range(len(row)):
            if row[j] == 0:
                zero_run += 1
                if zero_run > 3:
                    row[j] = 1  # Replace with smallest non-zero
                    zero_run = 0  # Reset counter
            else:
                zero_run = 0

        # Generate codes for this row
        codes: List[int] = []
        i = 0
        while i < in_features:
            skip = 0
            while i + skip < in_features and row[i + skip] == 0 and skip < 3:
                skip += 1

            i += skip

            if i >= in_features:
                # Trailing zeros: emit skip code so decoder advances i
                if skip > 0:
                    codes.append((skip << 2) | 0b00)  # dummy: will be skipped by decoder
                break

            w = row[i]
            sign = 1 if w < 0 else 0
            mag = 1 if abs(w) == 8 else 0

            code = (skip << 2) | (sign << 1) | mag
            codes.append(code)
            i += 1

        # Pack to bytes (byte-aligned per row)
        n_codes = len(codes)
        packed_len = (n_codes + 1) // 2
        for j in range(0, n_codes, 2):
            high = codes[j]
            low = codes[j + 1] if j + 1 < n_codes else 0
            all_packed.append((high << 4) | low)

        row_bytes.append(packed_len)

    return np.array(all_packed, dtype=np.uint8), row_bytes
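
# Zero-run workaround, illustrated: a row [0, 0, 0, 0, 0, 1] is rewritten to
# [0, 0, 0, 1, 0, 1] before encoding, since the 2-bit skip field cannot
# express a run of more than 3 zeros.
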
def generate_5level_layer(layer_info: Dict[str, Any]) -> str:
    """Generate 5-level encoded C code for a layer.

    Args:
        layer_info: Dictionary with layer information

    Returns:
        C code for the layer
    """
    name = layer_info['name']
    layer_type = layer_info['type']
    weight = layer_info['weight']
    bias = layer_info.get('bias', None)
    use_relu = layer_info.get('has_relu', False)
    act_scale = layer_info.get('act_scale', 1.0) or 1.0

    # Get weight as numpy
    if hasattr(weight, 'numpy'):
        w_q = weight.numpy()
    else:
        w_q = np.array(weight)

    if 'Linear' in layer_type:
        # Linear: row-aligned packing
        packed, row_bytes = pack_weights_5level_linear(w_q)

        code = f"// {name} - 5-level encoding (-8, -1, 0, +1, +8)\n"
        code += f"// Packed: {len(packed)} bytes (row-aligned)\n\n"

        # Weight data
        code += f"static const uint8_t {name}_weights[] = {{\n    "
        for i, b in enumerate(packed):
            code += f"0x{b:02x}, "
            if (i + 1) % 16 == 0:
                code += "\n    "
        code += "\n};\n\n"

        # Bias data
        if bias is not None:
            code += f"static const int32_t {name}_bias[] = {{\n    "
            for i, b in enumerate(bias):
                bias_val = int(round(b.item() * act_scale))
                # No clipping for int32
                code += f"{bias_val}, "
            code += "\n};\n\n"

        code += _generate_linear_5level(name, w_q.shape, bias, use_relu, act_scale)
        return code

    elif 'Conv2d' in layer_type:
        # Conv2d: flat packing (pre-decode at runtime)
        packed, n_codes, n_weights = pack_weights_5level(w_q)

        code = f"// {name} - 5-level encoding (-8, -1, 0, +1, +8)\n"
        code += f"// Packed: {len(packed)} bytes ({n_codes} codes for {n_weights} weights)\n\n"

        # Weight data
        code += f"static const uint8_t {name}_weights[] = {{\n    "
        for i, b in enumerate(packed):
            code += f"0x{b:02x}, "
            if (i + 1) % 16 == 0:
                code += "\n    "
        code += "\n};\n\n"

        # Bias data
        if bias is not None:
            code += f"static const int32_t {name}_bias[] = {{\n    "
            for i, b in enumerate(bias):
                bias_val = int(round(b.item() * act_scale))
                # No clipping for int32
                code += f"{bias_val}, "
            code += "\n};\n\n"

        code += _generate_conv2d_5level(name, layer_info, bias, use_relu, act_scale)
        return code

    return ""
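
# Illustrative call (keys as consumed above; values hypothetical):
#   generate_5level_layer({
#       'name': 'conv1', 'type': 'Conv2d', 'weight': w_q, 'bias': b,
#       'has_relu': True, 'act_scale': 16.0,
#       'in_h': 32, 'in_w': 32, 'out_h': 32, 'out_w': 32,
#       'stride': 1, 'padding': 1, 'groups': 1,
#   })
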
def _generate_linear_5level(name: str, shape: tuple, bias, use_relu: bool, act_scale: float) -> str:
    """Generate 5-level Linear layer code."""
    out_features, in_features = shape

    code = f"static void {name}_forward(const int8_t* input, int8_t* output) {{\n"
    code += f"    const uint8_t* wp = {name}_weights;\n"
    code += f"    int32_t acc;\n"
    code += f"    uint8_t packed, code, skip, sign, mag;\n"
    code += f"    int shift, shifted, mask;\n"
    code += f"    int i, nibble;\n\n"

    code += f"    for (int o = 0; o < {out_features}; o++) {{\n"
    code += f"        acc = 0;\n"
    code += f"        i = 0;\n"
    code += f"        nibble = 0; // reset per row (row-aligned packing)\n"
    code += f"        while (i < {in_features}) {{\n"
    code += f"            if (nibble == 0) {{\n"
    code += f"                packed = *wp++;\n"
    code += f"                code = packed >> 4;\n"
    code += f"            }} else {{\n"
    code += f"                code = packed & 0xf;\n"
    code += f"            }}\n"
    code += f"            nibble = 1 - nibble;\n"
    code += f"\n"
    code += f"            skip = (code >> 2) & 0x3;\n"
    code += f"            i += skip; // skip zeros\n"
    code += f"            if (i >= {in_features}) break;\n"
    code += f"\n"
    code += f"            sign = (code >> 1) & 1;\n"
    code += f"            mag = code & 1;\n"
    code += f"            shift = (mag << 1) + mag; // 0 or 3\n"
    code += f"            shifted = (int)input[i] << shift;\n"
    code += f"            mask = -(int)sign;\n"
    code += f"            acc += (shifted ^ mask) - mask;\n"
    code += f"            i++;\n"
    code += f"        }}\n"

    # Scale
    code += f"        acc = scale_{name}(acc);\n"

    # Bias
    if bias is not None:
        code += f"        acc += {name}_bias[o];\n"

    # ReLU
    if use_relu:
        code += f"        if (acc < 0) acc = 0;\n"

    # Clamp and store
    code += f"        output[o] = (int8_t)(acc > 127 ? 127 : (acc < -128 ? -128 : acc));\n"
    code += f"    }}\n"
    code += f"}}\n\n"

    return code
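
# Note on the emitted inner loop: with mask = -sign, the expression
# (shifted ^ mask) - mask is shifted when sign == 0 (mask 0x00000000) and
# -shifted when sign == 1 (mask 0xffffffff; XOR then subtract is
# two's-complement negation), so the weight sign costs no branch.
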
def _generate_conv2d_5level(name: str, layer_info: Dict, bias, use_relu: bool, act_scale: float) -> str:
    """Generate 5-level Conv2d layer code."""
    weight = layer_info['weight']
    if hasattr(weight, 'shape'):
        w_shape = weight.shape
    else:
        w_shape = np.array(weight).shape

    if len(w_shape) == 4:
        out_ch, in_ch, kh, kw = w_shape
    elif len(w_shape) == 3:
        out_ch, in_ch, kw = w_shape
        kh = 1

    out_h = layer_info.get('out_h', 1)
    out_w = layer_info.get('out_w', 1)
    in_h = layer_info.get('in_h', 1)
    in_w = layer_info.get('in_w', 1)
    stride = layer_info.get('stride', 1)
    padding = layer_info.get('padding', 0)
    groups = layer_info.get('groups', 1)

    # Handle tuple parameters
    if isinstance(stride, tuple): stride_h, stride_w = stride
    else: stride_h = stride_w = stride

    if isinstance(padding, tuple): pad_h, pad_w = padding
    else: pad_h = pad_w = padding

    kernel_size = in_ch * kh * kw

    code = f"static void {name}_forward(const int8_t* input, int8_t* output) {{\n"
    code += f"    // Conv2d: {in_ch}x{in_h}x{in_w} -> {out_ch}x{out_h}x{out_w}\n"
    code += f"    // 5-level with skip encoding\n"
    code += f"    int32_t acc;\n"
    code += f"    uint8_t packed, code, skip, sign, mag;\n"
    code += f"    int base, mask, val;\n\n"

    # For conv2d, we need to decode per-filter
    # Pre-decode weights to array for random access during convolution
    code += f"    // Pre-decoded weights for random access\n"
    code += f"    static int8_t w_decoded[{out_ch * kernel_size}];\n"
    code += f"    static int decoded = 0;\n"
    code += f"    if (!decoded) {{\n"
    code += f"        const uint8_t* wp = {name}_weights;\n"
    code += f"        int idx = 0, nibble = 0;\n"
    code += f"        uint8_t p;\n"
    code += f"        while (idx < {out_ch * kernel_size}) {{\n"
    code += f"            if (nibble == 0) {{ p = *wp++; code = p >> 4; }}\n"
    code += f"            else {{ code = p & 0xf; }}\n"
    code += f"            nibble = 1 - nibble;\n"
    code += f"            skip = (code >> 2) & 0x3;\n"
    code += f"            for (int s = 0; s < skip && idx < {out_ch * kernel_size}; s++)\n"
    code += f"                w_decoded[idx++] = 0;\n"
    code += f"            if (idx >= {out_ch * kernel_size}) break;\n"
    code += f"            sign = (code >> 1) & 1;\n"
    code += f"            mag = code & 1;\n"
    code += f"            base = 1 << ((mag << 1) + mag); // 1 or 8\n"
    code += f"            mask = -(int)sign;\n"
    code += f"            w_decoded[idx++] = (base ^ mask) - mask;\n"
    code += f"        }}\n"
    code += f"        decoded = 1;\n"
    code += f"    }}\n\n"

    code += f"    for (int oc = 0; oc < {out_ch}; oc++) {{\n"
    code += f"        const int8_t* wf = w_decoded + oc * {kernel_size};\n"
    code += f"        for (int oy = 0; oy < {out_h}; oy++) {{\n"
    code += f"            for (int ox = 0; ox < {out_w}; ox++) {{\n"
    code += f"                acc = 0;\n"
    code += f"                int w_idx = 0;\n"

    # Group offset calculation
    channels_per_group = in_ch
    if groups == out_ch:
        group_stride_str = f"oc * {channels_per_group}"
    elif groups > 1:
        out_per_group = out_ch // groups
        group_stride_str = f"(oc / {out_per_group}) * {channels_per_group}"
    else:
        group_stride_str = "0"

    code += f"                for (int ic = 0; ic < {in_ch}; ic++) {{\n"
    code += f"                    for (int ky = 0; ky < {kh}; ky++) {{\n"
    code += f"                        int iy = oy * {stride_h} + ky - {pad_h};\n"
    code += f"                        if (iy < 0 || iy >= {in_h}) {{ w_idx += {kw}; continue; }}\n"
    code += f"                        for (int kx = 0; kx < {kw}; kx++) {{\n"
    code += f"                            int ix = ox * {stride_w} + kx - {pad_w};\n"
    code += f"                            if (ix >= 0 && ix < {in_w}) {{\n"
    code += f"                                val = wf[w_idx];\n"
    code += f"                                if (val) {{\n"

    if group_stride_str == "0":
        input_idx = f"ic * {in_h * in_w} + iy * {in_w} + ix"
    else:
        input_idx = f"({group_stride_str} + ic) * {in_h * in_w} + iy * {in_w} + ix"

    code += f"                                    int32_t inp = input[{input_idx}];\n"
    code += f"                                    // 5-level: val is -8, -1, +1, or +8\n"
    code += f"                                    if (val == 1) acc += inp;\n"
    code += f"                                    else if (val == -1) acc -= inp;\n"
    code += f"                                    else if (val == 8) acc += inp << 3;\n"
    code += f"                                    else acc -= inp << 3; // val == -8\n"
    code += f"                                }}\n"
    code += f"                            }}\n"
    code += f"                            w_idx++;\n"
    code += f"                        }}\n"
    code += f"                    }}\n"
    code += f"                }}\n"
    code += f"                acc = scale_{name}(acc);\n"

    if bias is not None:
        code += f"                acc += {name}_bias[oc];\n"

    if use_relu:
        code += f"                if (acc < 0) acc = 0;\n"

    code += f"                output[oc * {out_h * out_w} + oy * {out_w} + ox] = (int8_t)(acc > 127 ? 127 : (acc < -128 ? -128 : acc));\n"
    code += f"            }}\n"
    code += f"        }}\n"
    code += f"    }}\n"
    code += f"}}\n\n"

    return code
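
# Reference sketch (illustrative; not called by the generators above): a
# pure-Python decoder for the flat 5-level stream, mirroring the generated
# C pre-decode loop.
def _decode_5level_reference(packed: np.ndarray, n_weights: int) -> np.ndarray:
    out = np.zeros(n_weights, dtype=np.int8)
    idx = nibble = pos = byte = 0
    while idx < n_weights:
        if nibble == 0:
            byte = int(packed[pos])
            pos += 1
            code = byte >> 4
        else:
            code = byte & 0xF
        nibble = 1 - nibble
        idx += (code >> 2) & 0x3  # skipped positions stay zero
        if idx >= n_weights:
            break
        sign = (code >> 1) & 1
        mag = code & 1
        out[idx] = (8 if mag else 1) * (-1 if sign else 1)
        idx += 1
    return out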
potnn/codegen/scale.py
ADDED
@@ -0,0 +1,184 @@
"""Scale decomposition and combined scale calculation for C code generation."""

import torch
import math
from typing import List, Optional, Tuple


def decompose_scale_to_shifts(scale: int) -> List[int]:
    """Decompose an integer scale into shift positions.

    This finds which bits are set in the scale value,
    allowing multiplication to be replaced with shifts and adds.

    Example:
        scale = 21 = 0b10101 = (1<<0) + (1<<2) + (1<<4)
        Returns: [0, 2, 4]
        Meaning: x * 21 = x + (x<<2) + (x<<4)

    Args:
        scale: Integer scale value

    Returns:
        List of bit positions where scale has 1s
    """
    shifts = []
    for i in range(20):  # Check up to 20 bits (supports scales up to ~1M)
        if scale & (1 << i):
            shifts.append(i)
    return shifts
def generate_scale_func(layer_name: str, scale: float, shift: int) -> str:
    """Generate C function for scale multiplication using shifts and adds.

    This function converts floating-point scale to fixed-point integer,
    then decomposes it into shift+add operations to avoid multiplication.

    Args:
        layer_name: Name of the layer (for function naming)
        scale: Floating-point scale value
        shift: Total right shift to apply after multiplication

    Returns:
        C function code as string
    """
    # Convert to fixed-point integer
    scale_int = int(scale * (1 << 16))  # 16-bit fixed point
    total_shift = shift + 16

    # Decompose scale into shifts
    shifts = decompose_scale_to_shifts(scale_int)

    if not shifts:
        # Scale is 0
        return f"""static inline int32_t scale_{layer_name}(int32_t x) {{
    return 0;
}}
"""

    # Generate C code
    code = f"static inline int32_t scale_{layer_name}(int32_t x) {{\n"

    if len(shifts) == 1 and shifts[0] == 0:
        # Scale is 1, only need shift
        if total_shift == 0:
            code += f"    return x;\n"
        elif total_shift == 1:
            code += f"    return (x + 1) >> 1; // Rounding\n"
        else:
            code += f"    return (x + (1 << {total_shift - 1})) >> {total_shift}; // Rounding\n"
    else:
        # Multiple shifts - need to add them
        terms = []
        for s in shifts:
            if s == 0:
                terms.append("x")
            else:
                terms.append(f"(x << {s})")

        # Add rounding term for proper rounding (handle edge cases)
        if total_shift == 0:
            code += f"    return ({' + '.join(terms)});\n"
        elif total_shift == 1:
            code += f"    return (({' + '.join(terms)}) + 1) >> 1;\n"
        else:
            code += f"    return (({' + '.join(terms)}) + (1 << {total_shift - 1})) >> {total_shift};\n"

    code += "}\n"
    return code
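
# Example (illustrative layer name): generate_scale_func("fc1", 0.75, 0)
# represents 0.75 as 49152 / 2^16 (bits 14 and 15 set) and emits:
#
#     static inline int32_t scale_fc1(int32_t x) {
#         return (((x << 14) + (x << 15)) + (1 << 15)) >> 16;
#     }
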
def calculate_combined_scale(
    alpha: float,
    act_scale: Optional[float],
    prev_act_scale: float = 1.0,
    std: Optional[float] = None,
    is_first: bool = False
) -> Tuple[float, int]:
    """Calculate combined scale for a layer.

    Formula:
    - General layer: combined_scale = alpha * act_scale / prev_act_scale
    - First layer: combined_scale = alpha * act_scale / std (absorb /std)

    The /256 normalization is handled separately via shift adjustment.

    Args:
        alpha: Alpha scaling parameter from the layer
        act_scale: Activation scale from calibration (None for last layer)
        prev_act_scale: Previous layer's activation scale
        std: Standard deviation for input normalization (first layer only)
            Can be float or List[float] (uses average for multi-channel)
        is_first: Whether this is the first layer

    Returns:
        Tuple of (scale_float, base_shift)
    """
    # Start with alpha
    combined_scale = alpha

    # Apply activation scale (if not None)
    if act_scale is not None:
        combined_scale *= act_scale

    # Handle first layer standardization
    if is_first and std is not None:
        # Use average std for multi-channel
        if isinstance(std, (list, tuple)):
            avg_std = sum(std) / len(std)
        else:
            avg_std = std
        combined_scale /= avg_std
    else:
        # Compensate for previous layer's scale
        combined_scale /= prev_act_scale

    # Base shift (will be adjusted for /256 in first layer)
    base_shift = 0

    return combined_scale, base_shift
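
# Example: for a hidden layer with alpha=0.05, act_scale=8.0 and
# prev_act_scale=4.0, combined_scale = 0.05 * 8.0 / 4.0 = 0.1; for a first
# layer with std=0.25 the divisor is the std instead: 0.05 * 8.0 / 0.25 = 1.6.
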
def absorb_standardization(first_layer, mean, std):
    """Absorb input standardization into the first layer.

    This modifies the first layer's bias to absorb the mean subtraction:
        b' = b - Σ_c (mean[c]/std[c]) × ΣW[:,c,:,:] × α

    The std division is handled in combined_scale calculation.

    Args:
        first_layer: The first PoT layer in the model
        mean: Input mean for standardization (float or List[float])
        std: Input standard deviation for standardization (float or List[float])
    """
    if first_layer.bias is None:
        return

    # Normalize to list
    if isinstance(mean, (int, float)):
        mean = [mean]
    if isinstance(std, (int, float)):
        std = [std]

    with torch.no_grad():
        weight = first_layer.weight
        in_channels = weight.shape[1] if len(weight.shape) > 1 else 1

        # Channel-wise bias adjustment
        for c in range(in_channels):
            if hasattr(first_layer, 'kernel_size'):
                # Conv2d: sum over kernel for this channel
                weight_sum_c = weight[:, c, :, :].sum(dim=(1, 2))
            else:
                # Linear: handle flattened input
                features_per_ch = weight.shape[1] // len(mean)
                start_idx = c * features_per_ch
                end_idx = start_idx + features_per_ch
                weight_sum_c = weight[:, start_idx:end_idx].sum(dim=1)

            first_layer.bias.data -= (mean[c] / std[c]) * weight_sum_c

    print(f"Absorbed standardization into first layer:")
    print(f"    mean={mean}, std={std}")
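
# Why this works: for a standardized input, sum_i w_i * (x_i - mean) / std
# = (1/std) * sum_i w_i * x_i - (mean/std) * sum_i w_i, so folding
# -(mean/std) * sum(W) into the bias leaves only the 1/std factor, which
# calculate_combined_scale() applies via the first layer's combined scale.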