potnn 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
"""Generate Ternary RLE encoded C code for PoT layers.

Ternary encoding: Run-Length Encoded
- 3 levels: -1, 0, +1
- RLE format: [type(2)][count/value]
  - 00: +1 single
  - 01: -1 single
  - 10: zero run (next 6 bits = count 1-64)
  - 11: non-zero run (next 2 bits = count 1-4, then values); not emitted by
    the simple encoder below, which keeps this code reserved
- Highest compression ratio for sparse ternary weights

The C generator itself uses the "Triple-Run" variant implemented by
pack_weights_triple_run, which packs 2-bit codes into uint32 words.
"""

import numpy as np
from typing import Dict, Any, Tuple, List


def pack_weights_ternary(w_q: np.ndarray) -> Tuple[bytes, int]:
    """Pack quantized weights to Ternary RLE format.

    Args:
        w_q: Quantized weights (values in -1, 0, +1)

    Returns:
        packed: bytes of RLE encoded data
        n_weights: original number of weights
    """
    flat = np.round(w_q.flatten()).astype(np.int8)
    n = len(flat)

    # Simple RLE: encode runs of zeros and individual non-zeros
    # Format: [code(2 bits)][payload]
    # 00 = +1, 01 = -1, 10 = zero run (6 bits count), 11 = reserved

    bits: List[int] = []  # List of bits

    i = 0
    while i < n:
        if flat[i] == 0:
            # Count consecutive zeros (at most 64 per run)
            count = 0
            while i + count < n and flat[i + count] == 0 and count < 64:
                count += 1

            # Encode zero run: 10 + 6-bit count (1-64 encoded as 0-63)
            bits.extend([1, 0])  # type = 10
            for b in range(5, -1, -1):
                bits.append((count - 1) >> b & 1)
            i += count
        else:
            # Single non-zero value
            if flat[i] == 1:
                bits.extend([0, 0])  # +1
            else:  # -1
                bits.extend([0, 1])  # -1
            i += 1

    # Pack bits into bytes, MSB first
    packed = bytearray()
    for i in range(0, len(bits), 8):
        byte = 0
        for j in range(8):
            if i + j < len(bits):
                byte |= bits[i + j] << (7 - j)
        packed.append(byte)

    return bytes(packed), n
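

# Illustrative usage sketch (hypothetical helper, not used elsewhere in the
# module): a zero run costs a fixed 8 bits regardless of its length (up to 64),
# so the byte count shrinks quickly for sparse ternary weights.
def _example_rle_size() -> None:
    w = np.array([0, 0, 0, 0, 1, -1, 0, 0], dtype=np.int8)
    packed, n = pack_weights_ternary(w)
    # 8 bits (zero run of 4) + 2 (+1) + 2 (-1) + 8 (zero run of 2) = 20 bits -> 3 bytes
    assert n == 8
    assert len(packed) == 3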


def pack_weights_triple_run(w_q: np.ndarray) -> Tuple[np.ndarray, int, int]:
    """Pack weights using 'Triple-Run' custom ternary encoding into uint32.

    Encoding:
    - 00: 0
    - 01: +1
    - 10: -1
    - 11: Repeat Previous x2 (Total 3 of same value)

    CRITICAL: Triple Runs MUST NOT cross block boundaries (rows for Linear,
    filters for Conv), because the C decoder processes blocks independently
    and resets state.

    Args:
        w_q: Quantized weights

    Returns:
        packed: uint32 array
        n_codes: number of 2-bit codes
        n: number of weights
    """
    shape = w_q.shape
    flat = np.round(w_q.flatten()).astype(np.int8)
    n = len(flat)

    # Determine Block Size
    if len(shape) == 2:
        # Linear: [Out, In] -> Block size is In
        block_size = shape[1]
    elif len(shape) == 4:
        # Conv2d: [Out, In, KH, KW] -> Block size is In*KH*KW
        block_size = shape[1] * shape[2] * shape[3]
    else:
        # Fallback or 1D
        block_size = n

    codes = []
    i = 0
    while i < n:
        val = flat[i]

        # Check boundary integrity
        # We can only do a Triple Run if i, i+1, i+2 are in the SAME block
        current_block = i // block_size
        next_block = (i + 2) // block_size

        can_run = (i + 2 < n) and \
                  (flat[i+1] == val and flat[i+2] == val) and \
                  (current_block == next_block)

        if can_run:
            # Emit code for val first
            if val == 0: codes.append(0b00)
            elif val == 1: codes.append(0b01)
            else: codes.append(0b10)  # -1

            # Emit repeat code
            codes.append(0b11)  # Repeat x2
            i += 3
        else:
            # Single emit
            if val == 0: codes.append(0b00)
            elif val == 1: codes.append(0b01)
            else: codes.append(0b10)  # -1
            i += 1

    # Pack 16 codes per uint32
    n_codes = len(codes)
    packed_len = (n_codes + 15) // 16
    packed = np.zeros(packed_len, dtype=np.uint32)

    for i in range(0, n_codes, 16):
        b = 0
        for j in range(16):
            if i + j < n_codes:
                b |= (int(codes[i + j]) << (2 * j))
        packed[i // 16] = b

    return packed, n_codes, n
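

# Minimal sketch of the resulting code stream (hypothetical helper, for
# illustration only): three equal values inside one block become two 2-bit
# codes (the value, then 0b11 = "repeat previous x2"), while equal values that
# straddle a row boundary are emitted singly.
def _example_triple_run_codes() -> None:
    w = np.array([[1, 1, 1, 0], [0, 0, 0, -1]], dtype=np.int8)  # two blocks (rows) of 4
    packed, n_codes, n = pack_weights_triple_run(w)
    # Row 0: 01 (+1), 11 (repeat x2), 00 (0)  -> 3 codes
    # Row 1: 00 (0), 11 (repeat x2), 10 (-1)  -> 3 codes
    # The zeros at flat indices 3-5 are equal but cross the row boundary,
    # so they are NOT fused into a Triple Run.
    assert (n_codes, n) == (6, 8)
    assert len(packed) == 1          # 16 codes fit per uint32
    assert int(packed[0]) == 0x0B0D  # = 0b101100001101: codes stored LSB-first, 2 bits each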


def generate_ternary_layer(layer_info: Dict[str, Any]) -> str:
    """Generate Triple-Run Ternary encoded C code."""
    name = layer_info['name']
    layer_type = layer_info['type']
    weight = layer_info['weight']
    bias = layer_info.get('bias', None)
    use_relu = layer_info.get('has_relu', False)
    act_scale = layer_info.get('act_scale', 1.0) or 1.0

    if hasattr(weight, 'numpy'):
        w_q = weight.numpy()
    else:
        w_q = np.array(weight)

    packed, n_codes, n_weights = pack_weights_triple_run(w_q)

    code = f"// {name} - Ternary Triple-Run (00:0, 01:+1, 10:-1, 11:Rep2)\n"
    code += f"// Packed: {len(packed)*4} bytes ({n_codes} codes for {n_weights} weights)\n\n"

    code += f"static const uint32_t {name}_weights[] = {{\n    "
    for i, w in enumerate(packed):
        code += f"0x{w:08x}, "
        if (i + 1) % 8 == 0:
            code += "\n    "
    code += "\n};\n\n"

    if bias is not None:
        code += f"static const int32_t {name}_bias[] = {{\n    "
        for i, b in enumerate(bias):
            bias_val = int(round(b.item() * act_scale))
            # No clipping for int32
            code += f"{bias_val}, "
            if (i + 1) % 16 == 0:
                code += "\n    "
        code += "\n};\n\n"

    if 'Linear' in layer_type:
        code += _generate_linear_ternary(name, w_q.shape, bias, use_relu, act_scale)
    elif 'Conv2d' in layer_type:
        code += _generate_conv2d_ternary(name, layer_info, bias, use_relu, act_scale)

    return code
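

# A small end-to-end sketch of the Linear path (hypothetical layer_info dict;
# the emitted C calls a scale_<name>() helper that is assumed to be defined
# elsewhere by the surrounding generator):
def _example_generate_linear() -> None:
    w = np.array([[1, 1, 1, 0], [0, 0, 0, -1]], dtype=np.int8)
    c_src = generate_ternary_layer({
        'name': 'fc1',
        'type': 'Linear',
        'weight': w,
        'has_relu': True,
    })
    assert 'static const uint32_t fc1_weights[]' in c_src
    assert 'static void fc1_forward(' in c_src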


def _generate_linear_ternary(name: str, shape: tuple, bias, use_relu: bool, act_scale: float) -> str:
    """Stream decoding for Linear layer."""
    out_features, in_features = shape

    code = f"static void {name}_forward(const int8_t* input, int8_t* output) {{\n"
    code += f"    const uint32_t* wp = {name}_weights;\n"
    code += f"    int32_t acc;\n"
    code += f"    uint32_t weight_chunk = *wp++;\n"
    code += f"    uint8_t code;\n"
    code += f"    int8_t prev_val = 0;\n"
    code += f"    int code_idx = 0;\n\n"
    code += f"    for (int o = 0; o < {out_features}; o++) {{\n"
    code += f"        acc = 0;\n"
    code += f"        int i = 0;\n"
    code += f"        while (i < {in_features}) {{\n"
    code += f"            code = (weight_chunk >> (code_idx << 1)) & 0x3;\n"
    code += f"            code_idx++;\n"
    code += f"            if (code_idx == 16) {{\n"
    code += f"                code_idx = 0;\n"
    code += f"                weight_chunk = *wp++;\n"
    code += f"            }}\n"
    code += f"\n"
    code += f"            if (code == 3) {{ // Repeat x2 (Total 3)\n"
    code += f"                // prev_val applied 2 more times\n"
    code += f"                for (int k=0; k<2 && i < {in_features}; k++) {{\n"
    code += f"                    acc += (int32_t)input[i++] * prev_val;\n"
    code += f"                }}\n"
    code += f"            }} else {{\n"
    code += f"                // Decode new value\n"
    code += f"                if (code == 0) prev_val = 0;\n"
    code += f"                else if (code == 1) {{ prev_val = 1; acc += input[i]; }}\n"
    code += f"                else {{ prev_val = -1; acc -= input[i]; }}\n"
    code += f"                i++;\n"
    code += f"            }}\n"
    code += f"        }}\n"

    code += f"        acc = scale_{name}(acc);\n"
    if bias is not None: code += f"        acc += {name}_bias[o];\n"
    if use_relu: code += f"        if (acc < 0) acc = 0;\n"
    code += f"        output[o] = (int8_t)(acc > 127 ? 127 : (acc < -128 ? -128 : acc));\n"
    code += f"    }}\n"
    code += f"}}\n\n"

    return code


def _generate_conv2d_ternary(name: str, layer_info: Dict, bias, use_relu: bool, act_scale: float) -> str:
    """Block decoding for Conv2d layer."""
    weight = layer_info['weight']
    if hasattr(weight, 'shape'):
        w_shape = weight.shape
    else:
        w_shape = np.array(weight).shape

    if len(w_shape) == 4:
        out_ch, in_ch, kh, kw = w_shape
    elif len(w_shape) == 3:
        out_ch, in_ch, kw = w_shape
        kh = 1
    else:
        raise ValueError(f"Unsupported conv weight shape: {w_shape}")

    out_h = layer_info.get('out_h', 1)
    out_w = layer_info.get('out_w', 1)
    in_h = layer_info.get('in_h', 1)
    in_w = layer_info.get('in_w', 1)
    stride = layer_info.get('stride', 1)
    padding = layer_info.get('padding', 0)
    groups = layer_info.get('groups', 1)

    # Handle tuple parameters
    if isinstance(stride, tuple): stride_h, stride_w = stride
    else: stride_h = stride_w = stride

    if isinstance(padding, tuple): pad_h, pad_w = padding
    else: pad_h = pad_w = padding

    kernel_size = in_ch * kh * kw

    code = f"static void {name}_forward(const int8_t* input, int8_t* output) {{\n"
    code += f"    const uint32_t* wp = {name}_weights;\n"
    code += f"    int32_t acc;\n"
    code += f"    int8_t w_buf[{kernel_size}]; // Filter Weights Cache\n"
    code += f"    uint32_t weight_chunk;\n"
    code += f"    uint8_t code;\n"
    code += f"    int8_t prev_val;\n"
    code += f"    int code_idx;\n\n"

    code += f"    // wp is persistent across filters since we decode sequentially\n"
    code += f"    weight_chunk = *wp++;\n"
    code += f"    code_idx = 0;\n"
    code += f"\n"

    code += f"    for (int oc = 0; oc < {out_ch}; oc++) {{\n"
    code += f"        // 1. Decode Filter Weights to Stack Buffer\n"
    code += f"        int w_ptr = 0;\n"
    code += f"        prev_val = 0;\n"
    code += f"        while (w_ptr < {kernel_size}) {{\n"
    code += f"            code = (weight_chunk >> (code_idx << 1)) & 0x3;\n"
    code += f"            code_idx++;\n"
    code += f"            if (code_idx == 16) {{\n"
    code += f"                code_idx = 0;\n"
    code += f"                weight_chunk = *wp++;\n"
    code += f"            }}\n"
    code += f"\n"
    code += f"            if (code == 3) {{ // Repeat x2\n"
    code += f"                w_buf[w_ptr++] = prev_val;\n"
    code += f"                if (w_ptr < {kernel_size}) w_buf[w_ptr++] = prev_val;\n"
    code += f"            }} else {{\n"
    code += f"                if (code == 0) prev_val = 0;\n"
    code += f"                else if (code == 1) prev_val = 1;\n"
    code += f"                else prev_val = -1;\n"
    code += f"                w_buf[w_ptr++] = prev_val;\n"
    code += f"            }}\n"
    code += f"        }}\n"
    code += f"\n"
    code += f"        // 2. Compute Convolution using Cache\n"
    code += f"        for (int oy = 0; oy < {out_h}; oy++) {{\n"
    code += f"            for (int ox = 0; ox < {out_w}; ox++) {{\n"
    code += f"                acc = 0;\n"
    code += f"                int buf_idx = 0;\n"

    # Group offset calculation
    channels_per_group = in_ch
    if groups == out_ch:
        group_stride_str = f"oc * {channels_per_group}"
    elif groups > 1:
        out_per_group = out_ch // groups
        group_stride_str = f"(oc / {out_per_group}) * {channels_per_group}"
    else:
        group_stride_str = "0"

    code += f"                for (int ic = 0; ic < {in_ch}; ic++) {{\n"
    code += f"                    for (int ky = 0; ky < {kh}; ky++) {{\n"
    code += f"                        int iy = oy * {stride_h} + ky - {pad_h};\n"
    code += f"                        if (iy < 0 || iy >= {in_h}) {{ buf_idx += {kw}; continue; }}\n"
    code += f"                        for (int kx = 0; kx < {kw}; kx++) {{\n"
    code += f"                            int ix = ox * {stride_w} + kx - {pad_w};\n"
    code += f"                            if (ix >= 0 && ix < {in_w}) {{\n"

    if group_stride_str == "0":
        input_idx = f"ic * {in_h * in_w} + iy * {in_w} + ix"
    else:
        input_idx = f"({group_stride_str} + ic) * {in_h * in_w} + iy * {in_w} + ix"

    code += f"                                int8_t w = w_buf[buf_idx];\n"
    code += f"                                if (w) acc += (w == 1) ? input[{input_idx}] : -input[{input_idx}];\n"
    code += f"                            }}\n"
    code += f"                            buf_idx++;\n"
    code += f"                        }}\n"
    code += f"                    }}\n"
    code += f"                }}\n"
    code += f"\n"
    code += f"                acc = scale_{name}(acc);\n"
    if bias is not None: code += f"                acc += {name}_bias[oc];\n"
    if use_relu: code += f"                if (acc < 0) acc = 0;\n"
    code += f"                output[oc * {out_h * out_w} + oy * {out_w} + ox] = (int8_t)(acc > 127 ? 127 : (acc < -128 ? -128 : acc));\n"
    code += f"            }}\n"
    code += f"        }}\n"
    code += f"    }}\n"
    code += f"}}\n\n"

    return code
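

# End-to-end sketch for the Conv2d path (hypothetical layer_info values: the
# spatial sizes, stride and padding below are example assumptions, and the
# emitted C calls a scale_<name>() helper assumed to be defined elsewhere):
def _example_generate_conv2d() -> None:
    w = np.zeros((2, 1, 3, 3), dtype=np.int8)
    w[0, 0, 1, 1] = 1    # identity-like 3x3 filter
    w[1, 0, :, 0] = -1   # left-edge filter
    c_src = generate_ternary_layer({
        'name': 'conv1',
        'type': 'Conv2d',
        'weight': w,
        'in_h': 8, 'in_w': 8,
        'out_h': 8, 'out_w': 8,
        'stride': 1,
        'padding': 1,
    })
    assert 'static void conv1_forward(' in c_src
    assert 'int8_t w_buf[9];' in c_src  # in_ch * kh * kw = 9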