potnn-1.0.0-py3-none-any.whl
- potnn/__init__.py +86 -0
- potnn/codegen/__init__.py +20 -0
- potnn/codegen/bit2.py +263 -0
- potnn/codegen/fp130.py +269 -0
- potnn/codegen/header.py +460 -0
- potnn/codegen/level5.py +393 -0
- potnn/codegen/scale.py +184 -0
- potnn/codegen/ternary.py +354 -0
- potnn/codegen/unroll.py +616 -0
- potnn/config.py +112 -0
- potnn/export.py +2196 -0
- potnn/fuse.py +167 -0
- potnn/modules/__init__.py +11 -0
- potnn/modules/add.py +114 -0
- potnn/modules/avgpool.py +173 -0
- potnn/modules/base.py +225 -0
- potnn/modules/conv.py +203 -0
- potnn/modules/conv1d.py +317 -0
- potnn/modules/depthwise.py +216 -0
- potnn/modules/linear.py +199 -0
- potnn/quantize/__init__.py +35 -0
- potnn/quantize/calibration.py +233 -0
- potnn/quantize/integer_ops.py +207 -0
- potnn/quantize/integer_sim.py +225 -0
- potnn/quantize/pot.py +455 -0
- potnn/quantize/qat.py +356 -0
- potnn/utils/__init__.py +13 -0
- potnn/utils/allocation.py +240 -0
- potnn/utils/memory.py +158 -0
- potnn/wrapper.py +304 -0
- potnn-1.0.0.dist-info/METADATA +260 -0
- potnn-1.0.0.dist-info/RECORD +35 -0
- potnn-1.0.0.dist-info/WHEEL +5 -0
- potnn-1.0.0.dist-info/licenses/LICENSE +72 -0
- potnn-1.0.0.dist-info/top_level.txt +1 -0
potnn/modules/depthwise.py
ADDED

@@ -0,0 +1,216 @@

"""PoT-quantized Depthwise Conv2d layer with Integer Simulation.

v2: Added integer simulation for C-compatible QAT
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, Tuple

from .base import PoTLayerBase
from ..quantize.pot import quantize_to_pot_ste, quantize_to_pot, quantize_activation_ste
from ..quantize.integer_sim import (
    round_ste, floor_ste, clamp_ste,
    quantize_to_int8_ste, quantize_to_uint8_ste,
    requantize_ste
)


class PoTDepthwiseConv2d(PoTLayerBase):
    """Power-of-Two quantized Depthwise Conv2d layer.

    Depthwise convolution applies a single filter per input channel.
    This is commonly used in MobileNet-style architectures as the first
    part of a depthwise separable convolution.

    Key properties:
    - in_channels == out_channels == channels
    - groups = channels (each channel processed independently)
    - weight shape: [channels, 1, kH, kW]
    """

    def __init__(
        self,
        channels: int,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 1,
        dilation: Union[int, Tuple[int, int]] = 1,
        bias: bool = True,
        encoding: str = 'unroll'
    ):
        """Initialize PoTDepthwiseConv2d layer.

        Args:
            channels: Number of input/output channels
            kernel_size: Size of the convolution kernel (default: 3)
            stride: Stride of the convolution (default: 1)
            padding: Zero-padding added to both sides (default: 1)
            dilation: Spacing between kernel elements (default: 1)
            bias: If True, adds a learnable bias (default: True)
            encoding: Encoding type ('unroll', 'fp130', '5level', '2bit', 'ternary')
        """
        super().__init__(encoding)

        self.channels = channels
        self.in_channels = channels   # alias for compatibility
        self.out_channels = channels  # alias for compatibility
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.padding = padding if isinstance(padding, tuple) else (padding, padding)
        self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
        self.groups = channels  # depthwise: each channel is its own group

        # Initialize weight parameter: [channels, 1, kH, kW]
        self.weight = nn.Parameter(torch.empty(
            channels, 1, *self.kernel_size
        ))

        # Initialize bias parameter
        if bias:
            self.bias = nn.Parameter(torch.zeros(channels))
        else:
            self.register_parameter('bias', None)

        # Initialize weights using Kaiming normal
        nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with optional PoT quantization.

        Args:
            x: Input tensor of shape (N, C, H, W)

        Returns:
            Output tensor of shape (N, C, H_out, W_out)
        """
        if not self.quantize:
            # Float mode (warmup training)
            return F.conv2d(
                x, self.weight, self.bias,
                self.stride, self.padding, self.dilation, self.groups
            )

        if self.use_integer_sim and self.scale_int is not None:
            if self.training:
                # Training: use float QAT for gradient flow
                return self._forward_float_qat(x)
            else:
                # Eval: use integer sim for C-exact match
                return self._forward_integer_sim(x)
        else:
            # Standard float QAT mode
            return self._forward_float_qat(x)

    def _forward_float_qat(self, x: torch.Tensor) -> torch.Tensor:
        """Original float QAT forward."""
        w_q = quantize_to_pot_ste(self.weight, self.alpha, encoding=self.encoding)

        out = F.conv2d(
            x, w_q * self.alpha, self.bias,
            self.stride, self.padding, self.dilation, self.groups
        )

        if self.act_scale is not None:
            out = quantize_activation_ste(out, self.act_scale)

        return out

    def _forward_integer_sim(self, x: torch.Tensor) -> torch.Tensor:
        """Integer simulation forward - matches C inference exactly."""
        DEBUG = False  # set to True for detailed debug output

        is_first = self.is_first_layer.item() if self.is_first_layer is not None else False
        is_last = self.is_last_layer.item() if self.is_last_layer is not None else False

        if DEBUG:
            print(f"\n[DEBUG DW _forward_integer_sim] is_first={is_first}, is_last={is_last}")
            print(f"  input: shape={x.shape}, range=[{x.min():.4f}, {x.max():.4f}]")

        # Step 1: Quantize input
        if is_first:
            # First layer: input is NORMALIZED, denormalize with channel-wise mean
            if self.input_mean is not None and self.input_std is not None:
                avg_std = self.input_std.mean().item()
                mean = self.input_mean.view(1, -1, 1, 1).to(x.device)  # [1, C, 1, 1]
                x_raw = x * avg_std + mean  # channel-wise mean
                x_raw = torch.clamp(x_raw, 0.0, 1.0)
            else:
                x_raw = x
            # [0,1] → [0,255] (uint8), /256 absorbed in shift (+8)
            x_int = quantize_to_uint8_ste(x_raw, 256.0)
        else:
            prev_scale = self.prev_act_scale if self.prev_act_scale is not None else torch.tensor(1.0)
            x_int = quantize_to_int8_ste(x, prev_scale)

        # Step 2: PoT depthwise convolution with STE for gradient flow
        w_pot = quantize_to_pot_ste(self.weight, self.alpha, encoding=self.encoding)

        acc = F.conv2d(
            x_int, w_pot, None,
            self.stride, self.padding, self.dilation, self.groups
        )

        # Step 3: Requantize
        scale_int = self.scale_int.item() if self.scale_int is not None else 1
        shift = self.shift.item() if self.shift is not None else 0
        acc = requantize_ste(acc, scale_int, shift)

        if DEBUG:
            print(f"  acc after requantize: scale_int={scale_int}, shift={shift}, range=[{acc.min():.0f}, {acc.max():.0f}]")

        # Step 4: Add bias (with mean absorption for first layer)
        if self.bias is not None:
            act_scale = self.act_scale if self.act_scale is not None else torch.tensor(1.0)

            if is_first:
                # First layer: absorb mean into bias
                # Use avg_std to match QAT and C inference
                if self.input_mean is not None and self.input_std is not None:
                    avg_std = self.input_std.mean().item()
                    # Depthwise: weight is [channels, 1, kH, kW]
                    channels = w_pot.shape[0]
                    alpha = self.alpha
                    bias_adjusted = self.bias.clone()
                    for c in range(channels):
                        mean_c = self.input_mean[c].item() if c < len(self.input_mean) else 0.0
                        weight_sum_c = w_pot[c].sum() * alpha
                        bias_adjusted[c] = bias_adjusted[c] - (mean_c / avg_std) * weight_sum_c
                else:
                    bias_adjusted = self.bias
                bias_int = round_ste(bias_adjusted * act_scale)

                if DEBUG:
                    print(f"  [First layer DW bias absorption]")
            else:
                bias_int = round_ste(self.bias * act_scale)

            acc = acc + bias_int.view(1, -1, 1, 1)

        # Step 5: Clamp
        if not is_last:
            out = clamp_ste(acc, 0.0, 127.0)
        else:
            out = acc

        # Step 6: Convert back to float
        if self.act_scale is not None and not is_last:
            out = out / self.act_scale

        return out

    def extra_repr(self) -> str:
        """String representation of layer configuration."""
        s = f'channels={self.channels}, kernel_size={self.kernel_size}, stride={self.stride}'
        if self.padding != (0, 0):
            s += f', padding={self.padding}'
        if self.dilation != (1, 1):
            s += f', dilation={self.dilation}'
        if self.bias is None:
            s += ', bias=False'
        if self.quantize:
            s += f', quantize=True, encoding={self.encoding}'
        if self.use_integer_sim:
            s += ', integer_sim=True'
        return s
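Since PoTDepthwiseConv2d mirrors nn.Conv2d with groups=channels, it can be dropped into a model directly. A minimal usage sketch, not taken from the package: it assumes the layer starts in the float/warmup path (i.e. quantize is False until calibration/QAT enables it), as the forward pass above implies.

import torch
from potnn.modules.depthwise import PoTDepthwiseConv2d

# 3x3 depthwise convolution over 32 channels with the 'unroll' PoT encoding
dw = PoTDepthwiseConv2d(channels=32, kernel_size=3, stride=1, padding=1,
                        encoding='unroll')

x = torch.randn(1, 32, 28, 28)  # (N, C, H, W)
y = dw(x)                       # float path while dw.quantize is False (assumed default)
print(y.shape)                  # expected: torch.Size([1, 32, 28, 28])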
potnn/modules/linear.py
ADDED
@@ -0,0 +1,199 @@

"""PoT-quantized Linear layer with Integer Simulation.

v2: Added integer simulation for C-compatible QAT
- Forward pass can simulate C integer operations exactly
- Matches C inference bit-for-bit when use_integer_sim=True
- Eliminates the QAT-to-C accuracy gap
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from .base import PoTLayerBase
from ..quantize.pot import quantize_to_pot_ste, quantize_to_pot, quantize_activation_ste, apply_5level_zero_constraint
from ..quantize.integer_ops import (
    round_half_up_ste, clamp_ste,
    fake_quantize_input, fake_quantize_input_uint8,
    fake_requantize
)


class PoTLinear(PoTLayerBase):
    """Power-of-Two quantized Linear layer.

    This layer implements a Linear (fully connected) layer with PoT weight
    quantization.

    [Integer-Only QAT Mode]
    The forward pass simulates C integer arithmetic EXACTLY:
    1. Input Quantization: float -> int8 (or uint8 for first layer)
    2. Integer Linear: int8 * int8 -> int32
    3. Requantize: (int32 * scale_int + round) >> shift
    4. Bias Add: + round(bias_adjusted * act_scale)
    5. Clamp: [0, 127]
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        encoding: str = 'unroll'
    ):
        """Initialize PoTLinear layer.

        Args:
            in_features: Size of each input sample
            out_features: Size of each output sample
            bias: If True, adds a learnable bias (default: True)
            encoding: Encoding type ('unroll', 'fp130', '5level', '2bit', 'ternary')
        """
        super().__init__(encoding)

        self.in_features = in_features
        self.out_features = out_features

        # Initialize weight parameter
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

        # Initialize bias parameter
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

        # Initialize weights using Kaiming normal
        nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with three modes:
        1. Float warmup (quantize=False): standard linear
        2. Float QAT (use_integer_sim=False): PoT weight + float activation
        3. Integer sim (use_integer_sim=True): C-identical integer ops
        """
        if not self.quantize:
            # Float mode (warmup training)
            return F.linear(x, self.weight, self.bias)

        if not getattr(self, 'use_integer_sim', False):
            # Float QAT: PoT weight + float activation
            # ReLU is applied outside this layer, in the model
            w_pot = quantize_to_pot_ste(self.weight, self.alpha, encoding=self.encoding)

            # 5level constraint
            if self.encoding == '5level' and self.enforce_5level_constraint:
                w_pot = apply_5level_zero_constraint(w_pot)

            out = F.linear(x, w_pot * self.alpha, self.bias)
            return out

        # === Integer Simulation Mode (C-identical) ===

        # === 1. Prepare Integer Parameters ===
        # Always compute dynamically to ensure consistency with export
        scale_int, shift, _ = self._compute_scale_and_shift()

        is_first = self.is_first_layer.item() if self.is_first_layer is not None else False
        is_last = self.is_last_layer.item() if self.is_last_layer is not None else False

        # === 2. Input Quantization ===
        if is_first:
            # First layer: input is NORMALIZED (x - mean) / std
            # Simulate C behavior: raw uint8 input
            if self.input_mean is not None and self.input_std is not None:
                avg_std = self.input_std.mean().item()
                # Handle flattened input for Linear layer
                if x.dim() == 2:  # [batch, features]
                    num_ch = len(self.input_mean)
                    feat_per_ch = x.shape[1] // num_ch
                    if x.shape[1] == num_ch * feat_per_ch:
                        # Reshape to apply channel-wise mean
                        x_reshaped = x.view(x.shape[0], num_ch, feat_per_ch)
                        mean = self.input_mean.view(1, -1, 1).to(x.device)
                        x_raw = x_reshaped * avg_std + mean
                        x_raw = x_raw.view(x.shape[0], -1)
                    else:
                        # Fall back to the average mean if dimensions don't match
                        avg_mean = self.input_mean.mean().item()
                        x_raw = x * avg_std + avg_mean
                else:
                    avg_mean = self.input_mean.mean().item()
                    x_raw = x * avg_std + avg_mean

                x_raw = clamp_ste(x_raw, 0.0, 1.0)
            else:
                x_raw = x

            # Quantize to uint8 [0, 255]
            x_int = fake_quantize_input_uint8(x_raw, 256.0)
        else:
            # Other layers: input is already int8 from the previous layer
            x_int = x

        # === 3. Weight Quantization ===
        w_pot = quantize_to_pot_ste(self.weight, self.alpha, encoding=self.encoding)

        # 5level constraint (always apply for 5level encoding to match export)
        if self.encoding == '5level':
            w_pot = apply_5level_zero_constraint(w_pot)

        # === 4. Integer Linear ===
        # F.linear with integer-valued inputs/weights -> integer-valued output
        acc = F.linear(x_int, w_pot, None)

        # === 5. Requantize ===
        acc_scaled = fake_requantize(acc, scale_int, shift)

        # === 6. Bias Addition ===
        if self.bias is not None:
            act_scale = self.act_scale if self.act_scale is not None else torch.tensor(1.0)

            if is_first and self.input_mean is not None and self.input_std is not None:
                # Absorb mean/std into bias
                avg_std = self.input_std.mean().item()
                alpha = self.alpha
                in_features = self.weight.shape[1]
                bias_adjusted = self.bias.clone()

                num_channels = len(self.input_mean)
                features_per_channel = in_features // num_channels if num_channels > 0 else in_features

                if num_channels > 0 and in_features == num_channels * features_per_channel:
                    for c in range(num_channels):
                        mean_c = self.input_mean[c].item()
                        start_idx = c * features_per_channel
                        end_idx = start_idx + features_per_channel
                        weight_sum_c = w_pot[:, start_idx:end_idx].sum(dim=1) * alpha
                        bias_adjusted = bias_adjusted - (mean_c / avg_std) * weight_sum_c
                else:
                    # Fallback if dimensions don't match
                    bias_adjusted = self.bias
            else:
                bias_adjusted = self.bias

            # Quantize bias: round(bias * act_scale)
            bias_int = round_half_up_ste(bias_adjusted * act_scale)

            # Add bias
            acc_scaled = acc_scaled + bias_int

        # === 7. Clamp (ReLU) ===
        if not is_last:
            out = clamp_ste(acc_scaled, 0.0, 127.0)
        else:
            out = acc_scaled

        # === 8. Output ===
        # Round to ensure exact integers (floating point precision)
        # Use STE to maintain gradient flow during training
        # Return int8 values as-is (same as the C inference)
        out = round_half_up_ste(out)

        return out

    def extra_repr(self) -> str:
        s = super().extra_repr()
        s += f', quantize={self.quantize}'
        return s
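The requantize step listed in the PoTLinear docstring, (int32 * scale_int + round) >> shift, rescales the int32 accumulator back toward the int8 activation range with a fixed-point multiplier. Below is a minimal numeric sketch of that formula; it is an illustration only, not the package's fake_requantize / requantize_ste implementation, and the half-LSB rounding constant is an assumption about how the "+ round" term is realized.

def requantize_reference(acc: int, scale_int: int, shift: int) -> int:
    """Reference for (acc * scale_int + round) >> shift from the docstring above."""
    rounding = 1 << (shift - 1)  # assumed: add half an LSB before shifting
    return (acc * scale_int + rounding) >> shift

# Example: the effective multiplier is scale_int / 2**shift = 77 / 16384 ≈ 0.0047
acc = 21340                      # int32 accumulator from int8 x int8 MACs
print(requantize_reference(acc, scale_int=77, shift=14))  # -> 100 (≈ 21340 * 0.0047)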
potnn/quantize/__init__.py
ADDED

@@ -0,0 +1,35 @@

"""Quantization module for potnn."""

from .pot import (
    quantize_to_pot, quantize_to_pot_ste, quantize_activation_ste,
    quantize_activation_c_aligned_ste, get_pot_values
)
from .calibration import calibrate_model
from .qat import prepare_qat, alpha_reg_loss, enable_integer_sim, disable_integer_sim
from .integer_sim import (
    round_ste, round_half_up_ste, floor_ste, clamp_ste,
    quantize_to_int8_ste, quantize_to_uint8_ste,
    requantize_ste, compute_scale_params
)

__all__ = [
    'quantize_to_pot',
    'quantize_to_pot_ste',
    'quantize_activation_ste',
    'quantize_activation_c_aligned_ste',
    'get_pot_values',
    'calibrate_model',
    'prepare_qat',
    'alpha_reg_loss',
    'enable_integer_sim',
    'disable_integer_sim',
    # Integer simulation
    'round_ste',
    'round_half_up_ste',
    'floor_ste',
    'clamp_ste',
    'quantize_to_int8_ste',
    'quantize_to_uint8_ste',
    'requantize_ste',
    'compute_scale_params',
]
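The exports above suggest the intended training workflow: float warmup, one-time calibration, QAT, then integer simulation for C-exact evaluation. A hedged sketch of that flow follows; only calibrate_model's signature appears in this diff, so prepare_qat, alpha_reg_loss, enable_integer_sim, and disable_integer_sim are referenced in comments rather than called.

from potnn.quantize import (
    calibrate_model, prepare_qat, alpha_reg_loss,
    enable_integer_sim, disable_integer_sim,
)

# Order implied by the calibration docstring below:
#   1. Float warmup training (PoT layers run their float path).
#   2. calibrate_model(model, calib_loader, num_batches=10, mean=..., std=...)
#      - called ONCE; fixes act_scale and initializes alpha per layer.
#   3. QAT fine-tuning (prepare_qat / alpha_reg_loss).
#   4. enable_integer_sim(...) for C-exact evaluation before export;
#      disable_integer_sim(...) reverts to float QAT behavior.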
potnn/quantize/calibration.py
ADDED

@@ -0,0 +1,233 @@

"""Activation calibration for PoT quantization."""

import torch
import math
from torch.utils.data import DataLoader
from typing import Dict, Any, Optional, Union, List

from ..modules.base import PoTLayerBase


def calibrate_model(
    model: torch.nn.Module,
    data_loader: DataLoader,
    num_batches: int = 10,
    mean: Union[float, List[float]] = 0.0,
    std: Union[float, List[float]] = 1.0
) -> Dict[str, float]:
    """Calibrate activation scales for each PoT layer.

    This function measures the maximum activation values for each layer
    during inference on calibration data. These values are used to set
    fixed activation scales for quantization-aware training.

    IMPORTANT: This function should be called ONCE before QAT training,
    after float warmup training.

    Args:
        model: Model containing PoT layers to calibrate
        data_loader: DataLoader for calibration data
        num_batches: Number of batches to use for calibration (default: 10)
        mean: Dataset mean for normalization (default: 0.0)
        std: Dataset std for normalization (default: 1.0)

    Returns:
        Dictionary mapping layer names to maximum activation values
    """
    # ========================================
    # Step 0: Fuse BatchNorm layers BEFORE calibration
    # This ensures calibration measures activations with final fused weights.
    # ========================================
    from ..fuse import fuse_batchnorm
    fuse_batchnorm(model)

    print("Starting calibration...")

    # Set model to evaluation mode
    model.eval()

    # Dictionary to store max activation values for each layer
    activation_max = {}
    hooks = []

    def make_hook(name):
        """Create a forward hook to capture activation values."""
        def hook(module, input, output):
            with torch.no_grad():
                max_val = output.abs().max().item()
                if name not in activation_max:
                    activation_max[name] = max_val
                else:
                    activation_max[name] = max(activation_max[name], max_val)
        return hook

    # Register hooks on all PoT layers
    for name, module in model.named_modules():
        if isinstance(module, PoTLayerBase):
            hook = module.register_forward_hook(make_hook(name))
            hooks.append(hook)

    # Also register hooks on PoTAdd layers to track their output range
    from ..modules.add import PoTAdd

    # Dictionary for tracking PoTAdd input scales
    add_input_max = {}  # {name: {'x': max_x, 'y': max_y}}

    def make_add_hook(name):
        """Create a forward hook for PoTAdd to capture input/output values."""
        def hook(module, input, output):
            with torch.no_grad():
                # Output max
                max_val = output.abs().max().item()
                if name not in activation_max:
                    activation_max[name] = max_val
                else:
                    activation_max[name] = max(activation_max[name], max_val)

                # Input max (x=skip, y=conv)
                x, y = input[0], input[1]
                max_x = x.abs().max().item()
                max_y = y.abs().max().item()

                if name not in add_input_max:
                    add_input_max[name] = {'x': max_x, 'y': max_y}
                else:
                    add_input_max[name]['x'] = max(add_input_max[name]['x'], max_x)
                    add_input_max[name]['y'] = max(add_input_max[name]['y'], max_y)
        return hook

    for name, module in model.named_modules():
        if isinstance(module, PoTAdd):
            hook = module.register_forward_hook(make_add_hook(name))
            hooks.append(hook)

    # Run forward passes to collect statistics
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            if i >= num_batches:
                break

            # Handle different data formats: (data, target) or just data
            if isinstance(batch, (list, tuple)):
                data = batch[0]
            else:
                data = batch

            # Move to the same device as the model
            device = next(model.parameters()).device
            data = data.to(device)

            # Normalize data (support both scalar and per-channel mean/std)
            # Dynamic reshape based on data dimensions (3D for Conv1d, 4D for Conv2d)
            if isinstance(mean, (list, tuple)):
                mean_t = torch.tensor(mean, dtype=data.dtype, device=device)
                if data.dim() == 3:  # Conv1d: (B, C, L)
                    mean_t = mean_t.view(1, -1, 1)
                else:  # Conv2d: (B, C, H, W)
                    mean_t = mean_t.view(1, -1, 1, 1)
            else:
                mean_t = mean
            if isinstance(std, (list, tuple)):
                std_t = torch.tensor(std, dtype=data.dtype, device=device)
                if data.dim() == 3:  # Conv1d: (B, C, L)
                    std_t = std_t.view(1, -1, 1)
                else:  # Conv2d: (B, C, H, W)
                    std_t = std_t.view(1, -1, 1, 1)
            else:
                std_t = std
            data = (data - mean_t) / std_t

            # Forward pass
            _ = model(data)

            # Progress indicator
            if (i + 1) % 5 == 0:
                print(f"  Calibration batch {i + 1}/{min(num_batches, len(data_loader))}")

    # Remove hooks
    for hook in hooks:
        hook.remove()

    # Set activation scales based on calibration
    print("\nSetting activation scales:")
    for name, module in model.named_modules():
        if isinstance(module, PoTLayerBase):
            if name in activation_max and activation_max[name] > 0:
                # Set the activation scale for quantization
                module.calibrate(activation_max[name])
                print(f"  {name}: act_max={activation_max[name]:.2f}, "
                      f"act_scale={module.act_scale.item():.4f}")
            else:
                # Default if the layer wasn't activated during calibration
                module.calibrate(1.0)
                print(f"  {name}: no activations detected, using default scale")

    # Set activation scales for PoTAdd layers
    print("\nSetting PoTAdd scales:")
    for name, module in model.named_modules():
        if isinstance(module, PoTAdd):
            if name in activation_max and activation_max[name] > 0:
                # act_scale = 127 / max_activation
                act_scale = 127.0 / activation_max[name]
                module.act_scale = torch.tensor(act_scale)
                module.quantize = True

                # Set the input scales (essential!)
                if name in add_input_max:
                    max_x = add_input_max[name]['x']
                    max_y = add_input_max[name]['y']
                    scale_x = 127.0 / max_x if max_x > 0 else 1.0
                    scale_y = 127.0 / max_y if max_y > 0 else 1.0
                    module.set_scales(scale_x, scale_y)
                    print(f"  {name}: scale_x={scale_x:.4f}, scale_y={scale_y:.4f}, "
                          f"ratio={scale_y/scale_x:.4f} (PoTAdd)")
                else:
                    print(f"  {name}: act_max={activation_max[name]:.2f}, "
                          f"act_scale={act_scale:.4f} (PoTAdd, no input scales)")
            else:
                # Default scale
                module.act_scale = torch.tensor(1.0)
                module.quantize = True
                print(f"  {name}: no activations detected, using default scale (PoTAdd)")

    # Initialize alpha values based on the weight distribution
    print("\nInitializing alpha values based on weight distribution:")
    for name, module in model.named_modules():
        if isinstance(module, PoTLayerBase):
            with torch.no_grad():
                # Calculate the weight standard deviation
                w_std = module.weight.std().item()

                # Calculate an appropriate raw_alpha value
                # We want: softplus(raw_alpha) ≈ w_std
                # softplus(x) = log(1 + exp(x))
                # So we need: x = log(exp(w_std) - 1) when w_std > log(2)
                if w_std > 0.01:
                    # Inverse softplus: raw = log(exp(target) - 1)
                    # But ensure numerical stability
                    target = w_std
                    if target > 10:  # Avoid overflow
                        raw = target  # For large values, softplus(x) ≈ x
                    elif target > 0.1:
                        raw = math.log(math.exp(target) - 1)
                    else:
                        # For small values, use an approximation
                        raw = math.log(target) if target > 0 else -2.0
                else:
                    # Very small weights, use a small alpha
                    raw = -2.0

                # Update raw_alpha
                module.raw_alpha.data.fill_(raw)

                # Update alpha_init to match the newly initialized alpha
                # This ensures regularization pulls towards the calibrated value
                module.alpha_init.data.fill_(module.alpha.item())

                # Verify the result
                actual_alpha = module.alpha.item()
                print(f"  {name}: weight_std={w_std:.4f}, "
                      f"raw_alpha={raw:.4f}, alpha={actual_alpha:.4f}")

    print(f"\nCalibration complete. Processed {len(activation_max)} layers.")
    return activation_max
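A minimal calibration call matching the signature above. This is a sketch: the loader, the CIFAR-10-style mean/std values, and model are placeholders standing in for a real setup built from the PoT layers in this package.

import torch
from torch.utils.data import DataLoader, TensorDataset
from potnn.quantize import calibrate_model

# Placeholder calibration data standing in for a real dataset loader
images = torch.rand(64, 3, 32, 32)
labels = torch.randint(0, 10, (64,))
calib_loader = DataLoader(TensorDataset(images, labels), batch_size=16)

# model: a float-warmup-trained network containing PoT layers (defined elsewhere)
act_max = calibrate_model(
    model,
    calib_loader,
    num_batches=4,
    mean=[0.4914, 0.4822, 0.4465],   # illustrative per-channel stats
    std=[0.2470, 0.2435, 0.2616],
)
print(sorted(act_max))  # layer names with recorded activation maxima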