potnn-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- potnn/__init__.py +86 -0
- potnn/codegen/__init__.py +20 -0
- potnn/codegen/bit2.py +263 -0
- potnn/codegen/fp130.py +269 -0
- potnn/codegen/header.py +460 -0
- potnn/codegen/level5.py +393 -0
- potnn/codegen/scale.py +184 -0
- potnn/codegen/ternary.py +354 -0
- potnn/codegen/unroll.py +616 -0
- potnn/config.py +112 -0
- potnn/export.py +2196 -0
- potnn/fuse.py +167 -0
- potnn/modules/__init__.py +11 -0
- potnn/modules/add.py +114 -0
- potnn/modules/avgpool.py +173 -0
- potnn/modules/base.py +225 -0
- potnn/modules/conv.py +203 -0
- potnn/modules/conv1d.py +317 -0
- potnn/modules/depthwise.py +216 -0
- potnn/modules/linear.py +199 -0
- potnn/quantize/__init__.py +35 -0
- potnn/quantize/calibration.py +233 -0
- potnn/quantize/integer_ops.py +207 -0
- potnn/quantize/integer_sim.py +225 -0
- potnn/quantize/pot.py +455 -0
- potnn/quantize/qat.py +356 -0
- potnn/utils/__init__.py +13 -0
- potnn/utils/allocation.py +240 -0
- potnn/utils/memory.py +158 -0
- potnn/wrapper.py +304 -0
- potnn-1.0.0.dist-info/METADATA +260 -0
- potnn-1.0.0.dist-info/RECORD +35 -0
- potnn-1.0.0.dist-info/WHEEL +5 -0
- potnn-1.0.0.dist-info/licenses/LICENSE +72 -0
- potnn-1.0.0.dist-info/top_level.txt +1 -0
potnn/quantize/pot.py
ADDED
@@ -0,0 +1,455 @@
"""Power-of-Two (PoT) quantization functions."""

import torch
import numpy as np
from typing import Union, List


# =============================================================================
# Level sets per encoding
# =============================================================================

ENCODING_LEVELS = {
    # unroll: fp130 + zero (17 levels) - for sensitive layers, highest accuracy
    'unroll': [0, 1, -1, 2, -2, 4, -4, 8, -8, 16, -16, 32, -32, 64, -64, 128, -128],

    # fp130: FP1.3.0 format (16 levels, no zero) - for dense layers
    'fp130': [1, -1, 2, -2, 4, -4, 8, -8, 16, -16, 32, -32, 64, -64, 128, -128],

    # 5level: for sparse layers (5 levels, with zero)
    '5level': [-8, -1, 0, 1, 8],

    # 2bit: minimal memory (4 levels, no zero)
    '2bit': [-2, -1, 1, 2],

    # ternary: minimal memory + zero (3 levels)
    'ternary': [-1, 0, 1],
}

# Positive values per encoding (plus 0 where the encoding has it; used for quantization)
ENCODING_POS_VALUES = {
    'unroll': [0, 1, 2, 4, 8, 16, 32, 64, 128],
    'fp130': [1, 2, 4, 8, 16, 32, 64, 128],
    '5level': [0, 1, 8],
    '2bit': [1, 2],
    'ternary': [0, 1],
}

# Whether each encoding includes zero
ENCODING_HAS_ZERO = {
    'unroll': True,
    'fp130': False,
    '5level': True,
    '2bit': False,
    'ternary': True,
}
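The three tables are consistent by construction: ENCODING_POS_VALUES holds each encoding's distinct non-negative levels, and ENCODING_HAS_ZERO records whether 0 is a representable level. A minimal check (illustrative, not part of pot.py):

# --- illustrative consistency check, not part of pot.py ---
for name, levels in ENCODING_LEVELS.items():
    assert sorted(ENCODING_POS_VALUES[name]) == sorted({v for v in levels if v >= 0})
    assert ENCODING_HAS_ZERO[name] == (0 in levels)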


def get_pot_values(encoding: str = 'unroll') -> torch.Tensor:
    """Get PoT values for given encoding.

    Args:
        encoding: Encoding type ('unroll', 'fp130', '5level', '2bit', 'ternary')

    Returns:
        Tensor of PoT values
    """
    if encoding not in ENCODING_LEVELS:
        raise ValueError(
            f"Unsupported encoding: {encoding}. "
            f"Must be one of {list(ENCODING_LEVELS.keys())}"
        )

    return torch.tensor(ENCODING_LEVELS[encoding], dtype=torch.float32)


def get_pot_pos_values(encoding: str = 'unroll') -> torch.Tensor:
    """Get positive PoT values (including 0 if applicable) for given encoding.

    Args:
        encoding: Encoding type

    Returns:
        Tensor of positive PoT values for quantization
    """
    if encoding not in ENCODING_POS_VALUES:
        raise ValueError(
            f"Unsupported encoding: {encoding}. "
            f"Must be one of {list(ENCODING_POS_VALUES.keys())}"
        )

    return torch.tensor(ENCODING_POS_VALUES[encoding], dtype=torch.float32)


# Backward compatibility: levels -> encoding mapping
def _levels_to_encoding(levels: int) -> str:
    """Convert legacy levels parameter to encoding string."""
    mapping = {
        3: 'ternary',
        5: '5level',
        11: 'unroll',
        17: 'unroll',
        16: 'fp130',
    }
    if levels not in mapping:
        raise ValueError(f"Unsupported levels: {levels}")
    return mapping[levels]
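A small usage sketch for the two getters and the legacy mapping (illustrative values, not part of pot.py):

# --- illustrative usage, not part of pot.py ---
get_pot_values('fp130')        # tensor([1., -1., 2., -2., ..., -128.]), 16 levels
get_pot_pos_values('5level')   # tensor([0., 1., 8.])
_levels_to_encoding(17)        # 'unroll' (legacy `levels` argument)
# get_pot_values('4bit')       # would raise ValueError listing the supported encodings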


def quantize_to_pot(
    weight: torch.Tensor,
    alpha: float,
    encoding: str = 'unroll',
    levels: int = None  # Backward compatibility
) -> torch.Tensor:
    """Quantize weight tensor to Power-of-Two values.

    This function quantizes float weights to the nearest PoT value.
    The quantization is done by finding the nearest PoT value for each weight.

    Args:
        weight: Weight tensor to quantize (float)
        alpha: Scaling factor to normalize weights before quantization
        encoding: Encoding type ('unroll', 'fp130', '5level', '2bit', 'ternary')
        levels: (Deprecated) Legacy parameter, use encoding instead

    Returns:
        Quantized weight tensor with PoT values
    """
    # Backward compatibility
    if levels is not None:
        encoding = _levels_to_encoding(levels)

    # Get positive PoT values for this encoding
    pot_pos = get_pot_pos_values(encoding).to(weight.device)
    has_zero = ENCODING_HAS_ZERO[encoding]

    # Separate sign and magnitude
    sign = torch.sign(weight)
    abs_weight = torch.abs(weight / alpha)

    # Find nearest PoT value for absolute weights
    abs_weight_flat = abs_weight.reshape(-1, 1)
    pot_pos_flat = pot_pos.reshape(1, -1)
    distances = torch.abs(abs_weight_flat - pot_pos_flat)
    indices = torch.argmin(distances, dim=1)

    # Get quantized absolute values
    quantized_abs = pot_pos[indices].reshape(weight.shape)

    # Restore sign
    w_q = sign * quantized_abs

    # For zero-free encodings, replace zeros with the value closest to 0
    if not has_zero:
        min_val = pot_pos[pot_pos > 0].min()
        zero_mask = (w_q == 0)

        if encoding == 'fp130':
            # fp130: Use alternating +1/-1 for zeros to reduce bias
            # This matches C export packing logic if zeros were preserved
            flat_indices = torch.arange(w_q.numel(), device=w_q.device)
            toggle_vals = torch.where(flat_indices % 2 == 0, min_val, -min_val)
            toggle_vals = toggle_vals.reshape(w_q.shape)
            w_q = torch.where(zero_mask, toggle_vals, w_q)
        else:
            # Default: replace with the smallest positive value (+1/-1 keeps the sign)
            # If sign is 0 (the original weight was 0) -> use +min_val
            w_q = torch.where(zero_mask & (sign >= 0), min_val, w_q)
            w_q = torch.where(zero_mask & (sign < 0), -min_val, w_q)

    # The 5level encoding constraint is handled in forward via the enforce_5level_constraint flag
    # (not applied here, to stay compatible with torch.export)

    return w_q
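A worked example of the nearest-level search (hypothetical weights and alpha; not part of pot.py). Note that the returned values are PoT levels of weight/alpha, so multiplying by alpha takes them back to the weight domain:

# --- illustrative example, not part of pot.py ---
w = torch.tensor([0.05, -0.13, 0.26, 0.90])
w_q = quantize_to_pot(w, alpha=0.01, encoding='unroll')
# |w| / alpha = [5, 13, 26, 90] -> nearest levels [4, 16, 32, 64]
# with signs restored: w_q == tensor([4., -16., 32., 64.])

quantize_to_pot(torch.tensor([0.0, 0.0]), alpha=1.0, encoding='fp130')
# zero-free encoding: zeros become alternating +1/-1 -> tensor([1., -1.])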


def _apply_5level_constraint_numpy(w_q: torch.Tensor) -> torch.Tensor:
    """Apply 5-level constraint (non-differentiable, for export only)."""
    w_np = w_q.clone()

    if w_q.dim() == 2:  # Linear
        out_features, in_features = w_q.shape
        for o in range(out_features):
            zero_run = 0
            for i in range(in_features):
                if w_np[o, i] == 0:
                    zero_run += 1
                    if zero_run > 3:
                        w_np[o, i] = 1.0
                        zero_run = 0
                else:
                    zero_run = 0
    elif w_q.dim() == 4:  # Conv2d
        out_ch = w_q.shape[0]
        for oc in range(out_ch):
            w_flat = w_np[oc].flatten()
            zero_run = 0
            for i in range(len(w_flat)):
                if w_flat[i] == 0:
                    zero_run += 1
                    if zero_run > 3:
                        w_flat[i] = 1.0
                        zero_run = 0
                else:
                    zero_run = 0
            w_np[oc] = w_flat.view(w_q[oc].shape)

    return w_np


def apply_5level_zero_constraint(w_q: torch.Tensor) -> torch.Tensor:
    """Apply 5-level encoding constraint: max 3 consecutive zeros.

    5-level encoding uses a 2-bit skip field (0~3), so 4+ consecutive zeros
    cannot be encoded. This function replaces the 4th+ consecutive zero
    with +1 (smallest non-zero value).

    Works with 2D (Linear) and 4D (Conv2d) tensors.
    For Conv2d, applies the constraint per output filter.

    Args:
        w_q: Quantized weight tensor

    Returns:
        Constrained weight tensor (differentiable via STE)
    """
    # Detach for modification, but keep gradient flow via STE
    w_constrained = w_q.clone()

    if w_q.dim() == 2:  # Linear: (out_features, in_features)
        out_features, in_features = w_q.shape
        for o in range(out_features):
            zero_run = 0
            for i in range(in_features):
                if w_constrained[o, i] == 0:
                    zero_run += 1
                    if zero_run > 3:
                        w_constrained[o, i] = 1.0  # Replace with +1
                        zero_run = 0
                else:
                    zero_run = 0

    elif w_q.dim() == 4:  # Conv2d: (out_ch, in_ch, kh, kw)
        out_ch = w_q.shape[0]
        flat_size = w_q[0].numel()
        for oc in range(out_ch):
            w_flat = w_constrained[oc].flatten()
            zero_run = 0
            for i in range(flat_size):
                if w_flat[i] == 0:
                    zero_run += 1
                    if zero_run > 3:
                        w_flat[i] = 1.0
                        zero_run = 0
                else:
                    zero_run = 0
            w_constrained[oc] = w_flat.view(w_q[oc].shape)

    # STE: use constrained in forward, but gradient flows to original
    return w_q + (w_constrained - w_q).detach()
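An example of the constraint in action (hypothetical 1x6 weight row, not part of pot.py): a run of four zeros cannot be expressed with the 2-bit skip field, so the fourth consecutive zero is replaced by +1, and the STE return keeps gradients attached to the original tensor:

# --- illustrative example, not part of pot.py ---
row = torch.tensor([[1., 0., 0., 0., 0., -8.]])
apply_5level_zero_constraint(row)
# -> tensor([[ 1.,  0.,  0.,  0.,  1., -8.]])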


class PoTQuantizeSTE(torch.autograd.Function):
    """Straight-Through Estimator for PoT quantization.

    Forward pass: quantize to PoT values
    Backward pass: pass gradient through unchanged
    """

    @staticmethod
    def forward(ctx, weight, alpha, encoding):
        """Forward pass: quantize to PoT values."""
        return quantize_to_pot(weight, alpha, encoding=encoding)

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass: gradient passes through unchanged (STE)."""
        return grad_output, None, None


def quantize_to_pot_ste(
    weight: torch.Tensor,
    alpha: torch.Tensor,
    encoding: str = 'unroll',
    levels: int = None  # Backward compatibility
) -> torch.Tensor:
    """Quantize weight tensor to PoT values with Straight-Through Estimator.

    This function applies PoT quantization in the forward pass while allowing
    gradients to flow through unchanged in the backward pass (STE).

    Args:
        weight: Weight tensor to quantize
        alpha: Scaling factor (learnable parameter)
        encoding: Encoding type ('unroll', 'fp130', '5level', '2bit', 'ternary')
        levels: (Deprecated) Legacy parameter, use encoding instead

    Returns:
        Quantized weight tensor with STE for gradients
    """
    # Backward compatibility
    if levels is not None:
        encoding = _levels_to_encoding(levels)

    return PoTQuantizeSTE.apply(weight, alpha, encoding)
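A minimal QAT sketch (hypothetical shapes and scale, not part of pot.py): the forward pass sees PoT weights, while the backward pass treats the quantizer as identity, so the float weights still receive gradients. Note that this autograd function returns None for alpha, so alpha is not trained through this path:

# --- illustrative QAT sketch, not part of pot.py ---
w = torch.randn(8, 4, requires_grad=True)
alpha = torch.tensor(0.1)
w_q = quantize_to_pot_ste(w, alpha, encoding='unroll')
loss = (w_q * alpha).sum()
loss.backward()
# w.grad == alpha * torch.ones_like(w): the rounding step is bypassed by the STE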


class QuantizeActivationSTE(torch.autograd.Function):
    """Straight-Through Estimator for activation quantization.

    Forward pass: quantize to int8 range using half-up rounding (C style)
    Backward pass: pass gradient through unchanged (STE)

    CRITICAL: Without STE, rounding has zero gradient,
    which blocks gradient flow to earlier layers and causes QAT to fail.

    NOTE: Uses floor(x + 0.5) for half-up rounding to match C behavior:
        C: (int)(x + 0.5) or (x * scale + round_const) >> shift
    """

    @staticmethod
    def forward(ctx, x, scale):
        """Forward pass: quantize activation to int8 range with half-up rounding."""
        # Half-up rounding to match C: floor(x + 0.5)
        return torch.floor(x * scale + 0.5).clamp(-128, 127) / scale

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass: gradient passes through unchanged (STE)."""
        return grad_output, None


def quantize_activation_ste(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Quantize activation tensor with Straight-Through Estimator.

    This function applies int8 quantization in the forward pass while allowing
    gradients to flow through unchanged in the backward pass.

    Args:
        x: Activation tensor to quantize
        scale: Scale factor (127.0 / max_activation)

    Returns:
        Quantized activation tensor with STE for gradients
    """
    return QuantizeActivationSTE.apply(x, scale)
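The rounding mode matters at exact halves. With illustrative values (not part of pot.py), x * scale == 2.5 rounds up under floor(x + 0.5) but down under torch.round's round-half-to-even, which is exactly the QAT/C mismatch this path avoids:

# --- illustrative example, not part of pot.py ---
x = torch.tensor([0.125])
scale = torch.tensor(20.0)                # x * scale == 2.5, exactly on a half
quantize_activation_ste(x, scale)         # floor(2.5 + 0.5) / 20  -> tensor([0.1500])
torch.round(x * scale) / scale            # round half to even     -> tensor([0.1000])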


class QuantizeActivationCAlignedSTE(torch.autograd.Function):
    """C-Aligned activation quantization with STE.

    Uses integer scale_int and shift to match C code exactly:
        C: out = (acc * scale_int + (1 << (shift-1))) >> shift

    This eliminates the floating-point precision gap between QAT and C.
    """

    @staticmethod
    def forward(ctx, x, scale_int, shift, act_scale):
        """Forward: C-style requantize."""
        # C: (x * scale_int + round_const) >> shift
        if shift > 0:
            round_const = 1 << (shift - 1)
        else:
            round_const = 0

        divisor = float(1 << shift)

        # Scale x to integer range first (matching C's integer accumulator)
        x_scaled = x * act_scale

        # Apply C-style requantize
        numerator = x_scaled * float(scale_int) + float(round_const)
        result = torch.floor(numerator / divisor)
        result = result.clamp(-128, 127)

        # Convert back to float range for next layer
        return result / act_scale

    @staticmethod
    def backward(ctx, grad_output):
        """Backward: STE - gradient passes through."""
        return grad_output, None, None, None


def quantize_activation_c_aligned_ste(
    x: torch.Tensor,
    scale_int: int,
    shift: int,
    act_scale: torch.Tensor
) -> torch.Tensor:
    """C-aligned activation quantization with STE.

    Matches C code: out = (acc * scale_int + round) >> shift

    Args:
        x: Activation tensor
        scale_int: Integer scale (from compute_scale_params)
        shift: Shift amount (from compute_scale_params)
        act_scale: Original float act_scale (for scaling back)

    Returns:
        Quantized activation matching C behavior
    """
    return QuantizeActivationCAlignedSTE.apply(x, scale_int, shift, act_scale)
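A sketch of the fixed-point path (not part of pot.py). The real scale_int/shift pair comes from compute_scale_params elsewhere in the package; here a float multiplier is decomposed by hand purely for illustration:

# --- illustrative sketch, not part of pot.py ---
multiplier = 0.37                                   # hypothetical requantize multiplier
shift = 15
scale_int = int(round(multiplier * (1 << shift)))   # 12124 ~= 0.37 * 2**15

x = torch.tensor([1.0, -0.5, 3.2])
act_scale = torch.tensor(20.0)                      # hypothetical activation scale
y = quantize_activation_c_aligned_ste(x, scale_int, shift, act_scale)
# per element: floor((x*act_scale*scale_int + 2**(shift-1)) / 2**shift),
# clamped to [-128, 127], then divided by act_scale to return to the float domain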


def pack_2bit(weights: np.ndarray) -> np.ndarray:
    """Pack 4-level weights into 2-bit representation.

    Encoding:
        -1 -> 0b00
         0 -> 0b01
         1 -> 0b10
         2 -> 0b11

    Args:
        weights: Array of weights with values in {-1, 0, 1, 2}

    Returns:
        Packed uint8 array (4 weights per byte)
    """
    # Map weights to 2-bit codes
    mapping = {-1: 0b00, 0: 0b01, 1: 0b10, 2: 0b11}

    # Flatten weights
    w_flat = weights.flatten()

    # Pad to multiple of 4
    pad_len = (4 - len(w_flat) % 4) % 4
    if pad_len > 0:
        w_flat = np.concatenate([w_flat, np.zeros(pad_len)])

    # Pack 4 weights per byte
    packed = []
    for i in range(0, len(w_flat), 4):
        byte = 0
        for j in range(4):
            code = mapping.get(int(w_flat[i+j]), 0b01)  # Default to 0 if not found
            byte |= (code << (j * 2))
        packed.append(byte)

    return np.array(packed, dtype=np.uint8)


def unpack_2bit(packed: np.ndarray, num_weights: int) -> np.ndarray:
    """Unpack 2-bit representation to 4-level weights.

    Args:
        packed: Packed uint8 array
        num_weights: Number of weights to unpack

    Returns:
        Array of weights with values in {-1, 0, 1, 2}
    """
    # Inverse mapping
    mapping = {0b00: -1, 0b01: 0, 0b10: 1, 0b11: 2}

    weights = []
    for byte in packed:
        for j in range(4):
            code = (byte >> (j * 2)) & 0b11
            weights.append(mapping[code])

    return np.array(weights[:num_weights])
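A round trip through the packer (illustrative values, not part of pot.py):

# --- illustrative round trip, not part of pot.py ---
w = np.array([1, -1, 2, 0, 1])
packed = pack_2bit(w)
# first byte: codes [0b10, 0b00, 0b11, 0b01] packed LSB-first -> 0b01110010 == 114;
# the second byte holds the remaining weight plus three zero-valued padding slots
unpack_2bit(packed, num_weights=len(w))
# -> array([ 1, -1,  2,  0,  1])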