potnn-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,216 @@
+ """PoT-quantized Depthwise Conv2d layer with Integer Simulation.
+
+ v2: Added integer simulation for C-compatible QAT
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Union, Tuple
+
+ from .base import PoTLayerBase
+ from ..quantize.pot import quantize_to_pot_ste, quantize_to_pot, quantize_activation_ste
+ from ..quantize.integer_sim import (
+     round_ste, floor_ste, clamp_ste,
+     quantize_to_int8_ste, quantize_to_uint8_ste,
+     requantize_ste
+ )
+
+
+ class PoTDepthwiseConv2d(PoTLayerBase):
+     """Power-of-Two quantized Depthwise Conv2d layer.
+
+     Depthwise convolution applies a single filter per input channel.
+     This is commonly used in MobileNet-style architectures as the first
+     part of depthwise separable convolution.
+
+     Key properties:
+     - in_channels == out_channels == channels
+     - groups = channels (each channel processed independently)
+     - weight shape: [channels, 1, kH, kW]
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         kernel_size: Union[int, Tuple[int, int]] = 3,
+         stride: Union[int, Tuple[int, int]] = 1,
+         padding: Union[int, Tuple[int, int]] = 1,
+         dilation: Union[int, Tuple[int, int]] = 1,
+         bias: bool = True,
+         encoding: str = 'unroll'
+     ):
+         """Initialize PoTDepthwiseConv2d layer.
+
+         Args:
+             channels: Number of input/output channels
+             kernel_size: Size of the convolution kernel (default: 3)
+             stride: Stride of the convolution (default: 1)
+             padding: Zero-padding added to both sides (default: 1)
+             dilation: Spacing between kernel elements (default: 1)
+             bias: If True, adds a learnable bias (default: True)
+             encoding: Encoding type ('unroll', 'fp130', '5level', '2bit', 'ternary')
+         """
+         super().__init__(encoding)
+
+         self.channels = channels
+         self.in_channels = channels  # alias for compatibility
+         self.out_channels = channels  # alias for compatibility
+         self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
+         self.stride = stride if isinstance(stride, tuple) else (stride, stride)
+         self.padding = padding if isinstance(padding, tuple) else (padding, padding)
+         self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
+         self.groups = channels  # depthwise: each channel is its own group
+
+         # Initialize weight parameter: [channels, 1, kH, kW]
+         self.weight = nn.Parameter(torch.empty(
+             channels, 1, *self.kernel_size
+         ))
+
+         # Initialize bias parameter
+         if bias:
+             self.bias = nn.Parameter(torch.zeros(channels))
+         else:
+             self.register_parameter('bias', None)
+
+         # Initialize weights using Kaiming normal
+         nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass with optional PoT quantization.
+
+         Args:
+             x: Input tensor of shape (N, C, H, W)
+
+         Returns:
+             Output tensor of shape (N, C, H_out, W_out)
+         """
+         if not self.quantize:
+             # Float mode (warmup training)
+             return F.conv2d(
+                 x, self.weight, self.bias,
+                 self.stride, self.padding, self.dilation, self.groups
+             )
+
+         if self.use_integer_sim and self.scale_int is not None:
+             if self.training:
+                 # Training: use float QAT for gradient flow
+                 return self._forward_float_qat(x)
+             else:
+                 # Eval: use integer sim for C-exact match
+                 return self._forward_integer_sim(x)
+         else:
+             # Standard Float QAT Mode
+             return self._forward_float_qat(x)
+
+     def _forward_float_qat(self, x: torch.Tensor) -> torch.Tensor:
+         """Original float QAT forward."""
+         w_q = quantize_to_pot_ste(self.weight, self.alpha, encoding=self.encoding)
+
+         out = F.conv2d(
+             x, w_q * self.alpha, self.bias,
+             self.stride, self.padding, self.dilation, self.groups
+         )
+
+         if self.act_scale is not None:
+             out = quantize_activation_ste(out, self.act_scale)
+
+         return out
+
+     def _forward_integer_sim(self, x: torch.Tensor) -> torch.Tensor:
+         """Integer simulation forward - matches C inference exactly."""
+         DEBUG = False  # set to True for detailed debug output
+
+         is_first = self.is_first_layer.item() if self.is_first_layer is not None else False
+         is_last = self.is_last_layer.item() if self.is_last_layer is not None else False
+
+         if DEBUG:
+             print(f"\n[DEBUG DW _forward_integer_sim] is_first={is_first}, is_last={is_last}")
+             print(f" input: shape={x.shape}, range=[{x.min():.4f}, {x.max():.4f}]")
+
+         # Step 1: Quantize input
+         if is_first:
+             # First layer: input is NORMALIZED, denormalize with channel-wise mean
+             if self.input_mean is not None and self.input_std is not None:
+                 avg_std = self.input_std.mean().item()
+                 mean = self.input_mean.view(1, -1, 1, 1).to(x.device)  # [1, C, 1, 1]
+                 x_raw = x * avg_std + mean  # channel-wise mean
+                 x_raw = torch.clamp(x_raw, 0.0, 1.0)
+             else:
+                 x_raw = x
+             # [0,1] → [0,255] (uint8), /256 absorbed in shift (+8)
+             x_int = quantize_to_uint8_ste(x_raw, 256.0)
+         else:
+             prev_scale = self.prev_act_scale if self.prev_act_scale is not None else torch.tensor(1.0)
+             x_int = quantize_to_int8_ste(x, prev_scale)
+
+         # Step 2: PoT Depthwise Convolution with STE for gradient flow
+         w_pot = quantize_to_pot_ste(self.weight, self.alpha, encoding=self.encoding)
+
+         acc = F.conv2d(
+             x_int, w_pot, None,
+             self.stride, self.padding, self.dilation, self.groups
+         )
+
+         # Step 3: Requantize
+         scale_int = self.scale_int.item() if self.scale_int is not None else 1
+         shift = self.shift.item() if self.shift is not None else 0
+         acc = requantize_ste(acc, scale_int, shift)
+
+         if DEBUG:
+             print(f" acc after requantize: scale_int={scale_int}, shift={shift}, range=[{acc.min():.0f}, {acc.max():.0f}]")
+
+         # Step 4: Add bias (with mean absorption for first layer)
+         if self.bias is not None:
+             act_scale = self.act_scale if self.act_scale is not None else torch.tensor(1.0)
+
+             if is_first:
+                 # First layer: absorb mean into bias
+                 # Use avg_std to match QAT and C inference
+                 if self.input_mean is not None and self.input_std is not None:
+                     avg_std = self.input_std.mean().item()
+                     # Depthwise: weight is [channels, 1, kH, kW]
+                     channels = w_pot.shape[0]
+                     alpha = self.alpha
+                     bias_adjusted = self.bias.clone()
+                     for c in range(channels):
+                         mean_c = self.input_mean[c].item() if c < len(self.input_mean) else 0.0
+                         weight_sum_c = w_pot[c].sum() * alpha
+                         bias_adjusted[c] = bias_adjusted[c] - (mean_c / avg_std) * weight_sum_c
+                 else:
+                     bias_adjusted = self.bias
+                 bias_int = round_ste(bias_adjusted * act_scale)
+
+                 if DEBUG:
+                     print(f" [First layer DW bias absorption]")
+             else:
+                 bias_int = round_ste(self.bias * act_scale)
+
+             acc = acc + bias_int.view(1, -1, 1, 1)
+
+         # Step 5: Clamp
+         if not is_last:
+             out = clamp_ste(acc, 0.0, 127.0)
+         else:
+             out = acc
+
+         # Step 6: Convert back to float
+         if self.act_scale is not None and not is_last:
+             out = out / self.act_scale
+
+         return out
+
+     def extra_repr(self) -> str:
+         """String representation of layer configuration."""
+         s = f'channels={self.channels}, kernel_size={self.kernel_size}, stride={self.stride}'
+         if self.padding != (0, 0):
+             s += f', padding={self.padding}'
+         if self.dilation != (1, 1):
+             s += f', dilation={self.dilation}'
+         if self.bias is None:
+             s += ', bias=False'
+         if self.quantize:
+             s += f', quantize=True, encoding={self.encoding}'
+         if self.use_integer_sim:
+             s += ', integer_sim=True'
+         return s
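
As a reading aid for the eval-time path above, the sketch below replays Steps 1-6 of _forward_integer_sim with plain PyTorch ops on a toy depthwise layer. The STE helpers (quantize_to_uint8_ste, requantize_ste) are not included in this diff, so the sketch substitutes the simplest plausible behaviour for them (scale-floor-clamp for input quantization, fixed-point multiply plus right shift for requantization); the scale_int, shift, and act_scale values are made up for illustration and the real helpers may round differently.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    C = 4
    x = torch.rand(1, C, 8, 8)                          # raw input in [0, 1]

    # Power-of-two weights: sign * 2^exponent, one 3x3 filter per channel
    exponents = torch.randint(-3, 1, (C, 1, 3, 3)).float()
    signs = torch.randint(0, 2, (C, 1, 3, 3)).float() * 2 - 1
    w_pot = signs * torch.pow(2.0, exponents)

    # Step 1: quantize the raw input to uint8 codes (scale 256, clamp to [0, 255])
    x_int = torch.clamp(torch.floor(x * 256.0), 0, 255)

    # Step 2: depthwise convolution on integer-valued tensors (groups = C)
    acc = F.conv2d(x_int, w_pot, None, stride=1, padding=1, groups=C)

    # Step 3: requantize with a fixed-point multiplier and a right shift
    scale_int, shift = 181, 16                          # illustrative values only
    acc = torch.floor(acc * scale_int / (1 << shift))

    # Steps 4-5: bias add omitted here; clamp to the int8 activation range
    out_int8 = torch.clamp(acc, 0.0, 127.0)

    # Step 6: convert back to float for the next layer using act_scale
    act_scale = 127.0 / 3.0                             # e.g. act_max of 3.0 from calibration
    out = out_int8 / act_scale
    print(out.shape, out_int8.max().item())
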
@@ -0,0 +1,199 @@
+ """PoT-quantized Linear layer with Integer Simulation.
+
+ v2: Added integer simulation for C-compatible QAT
+ - Forward pass can simulate C integer operations exactly
+ - Matches C inference bit-for-bit when use_integer_sim=True
+ - Eliminates QAT-C accuracy gap
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional
+
+ from .base import PoTLayerBase
+ from ..quantize.pot import quantize_to_pot_ste, quantize_to_pot, quantize_activation_ste, apply_5level_zero_constraint
+ from ..quantize.integer_ops import (
+     round_half_up_ste, clamp_ste,
+     fake_quantize_input, fake_quantize_input_uint8,
+     fake_requantize
+ )
+
+
+ class PoTLinear(PoTLayerBase):
+     """Power-of-Two quantized Linear layer.
+
+     This layer implements a Linear (fully connected) layer with PoT weight
+     quantization.
+
+     [Integer-Only QAT Mode]
+     The forward pass simulates C integer arithmetic EXACTLY:
+     1. Input Quantization: float -> int8 (or uint8 for first layer)
+     2. Integer Linear: int8 * int8 -> int32
+     3. Requantize: (int32 * scale_int + round) >> shift
+     4. Bias Add: + round(bias_adjusted * act_scale)
+     5. Clamp: [0, 127]
+     """
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = True,
+         encoding: str = 'unroll'
+     ):
+         """Initialize PoTLinear layer.
+
+         Args:
+             in_features: Size of each input sample
+             out_features: Size of each output sample
+             bias: If True, adds a learnable bias (default: True)
+             encoding: Encoding type ('unroll', 'fp130', '5level', '2bit', 'ternary')
+         """
+         super().__init__(encoding)
+
+         self.in_features = in_features
+         self.out_features = out_features
+
+         # Initialize weight parameter
+         self.weight = nn.Parameter(torch.empty(out_features, in_features))
+
+         # Initialize bias parameter
+         if bias:
+             self.bias = nn.Parameter(torch.zeros(out_features))
+         else:
+             self.register_parameter('bias', None)
+
+         # Initialize weights using Kaiming normal
+         nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass with three modes:
+         1. Float warmup (quantize=False): Standard linear
+         2. Float QAT (use_integer_sim=False): PoT weight + float activation
+         3. Integer sim (use_integer_sim=True): C-identical integer ops
+         """
+         if not self.quantize:
+             # Float mode (warmup training)
+             return F.linear(x, self.weight, self.bias)
+
+         if not getattr(self, 'use_integer_sim', False):
+             # Float QAT: PoT weight + float activation
+             # ReLU is applied outside this layer, by the model
+             w_pot = quantize_to_pot_ste(self.weight, self.alpha, encoding=self.encoding)
+
+             # 5level constraint
+             if self.encoding == '5level' and self.enforce_5level_constraint:
+                 w_pot = apply_5level_zero_constraint(w_pot)
+
+             out = F.linear(x, w_pot * self.alpha, self.bias)
+             return out
+
+         # === Integer Simulation Mode (C-identical) ===
+
+         # === 1. Prepare Integer Parameters ===
+         # Always compute dynamically to ensure consistency with export
+         scale_int, shift, _ = self._compute_scale_and_shift()
+
+         is_first = self.is_first_layer.item() if self.is_first_layer is not None else False
+         is_last = self.is_last_layer.item() if self.is_last_layer is not None else False
+
+         # === 2. Input Quantization ===
+         if is_first:
+             # First layer: input is NORMALIZED (x - mean) / std
+             # Simulate C behavior: raw uint8 input
+             if self.input_mean is not None and self.input_std is not None:
+                 avg_std = self.input_std.mean().item()
+                 # Handle flattened input for Linear layer
+                 if x.dim() == 2:  # [batch, features]
+                     num_ch = len(self.input_mean)
+                     feat_per_ch = x.shape[1] // num_ch
+                     if x.shape[1] == num_ch * feat_per_ch:
+                         # Reshape to apply channel-wise mean
+                         x_reshaped = x.view(x.shape[0], num_ch, feat_per_ch)
+                         mean = self.input_mean.view(1, -1, 1).to(x.device)
+                         x_raw = x_reshaped * avg_std + mean
+                         x_raw = x_raw.view(x.shape[0], -1)
+                     else:
+                         # Fallback to average mean if dimensions don't match
+                         avg_mean = self.input_mean.mean().item()
+                         x_raw = x * avg_std + avg_mean
+                 else:
+                     avg_mean = self.input_mean.mean().item()
+                     x_raw = x * avg_std + avg_mean
+
+                 x_raw = clamp_ste(x_raw, 0.0, 1.0)
+             else:
+                 x_raw = x
+
+             # Quantize to uint8 [0, 255]
+             x_int = fake_quantize_input_uint8(x_raw, 256.0)
+         else:
+             # Other layers: Input is already int8 from previous layer
+             x_int = x
+
+         # === 3. Weight Quantization ===
+         w_pot = quantize_to_pot_ste(self.weight, self.alpha, encoding=self.encoding)
+
+         # 5level constraint (always apply for 5level encoding to match export)
+         if self.encoding == '5level':
+             w_pot = apply_5level_zero_constraint(w_pot)
+
+         # === 4. Integer Linear ===
+         # F.linear with integer-valued inputs/weights -> integer-valued output
+         acc = F.linear(x_int, w_pot, None)
+
+         # === 5. Requantize ===
+         acc_scaled = fake_requantize(acc, scale_int, shift)
+
+         # === 6. Bias Addition ===
+         if self.bias is not None:
+             act_scale = self.act_scale if self.act_scale is not None else torch.tensor(1.0)
+
+             if is_first and self.input_mean is not None and self.input_std is not None:
+                 # Absorb mean/std into bias
+                 avg_std = self.input_std.mean().item()
+                 alpha = self.alpha
+                 in_features = self.weight.shape[1]
+                 bias_adjusted = self.bias.clone()
+
+                 num_channels = len(self.input_mean)
+                 features_per_channel = in_features // num_channels if num_channels > 0 else in_features
+
+                 if num_channels > 0 and in_features == num_channels * features_per_channel:
+                     for c in range(num_channels):
+                         mean_c = self.input_mean[c].item()
+                         start_idx = c * features_per_channel
+                         end_idx = start_idx + features_per_channel
+                         weight_sum_c = w_pot[:, start_idx:end_idx].sum(dim=1) * alpha
+                         bias_adjusted = bias_adjusted - (mean_c / avg_std) * weight_sum_c
+                 else:
+                     # Fallback if dimensions don't match
+                     bias_adjusted = self.bias
+             else:
+                 bias_adjusted = self.bias
+
+             # Quantize bias: round(bias * act_scale)
+             bias_int = round_half_up_ste(bias_adjusted * act_scale)
+
+             # Add bias
+             acc_scaled = acc_scaled + bias_int
+
+         # === 7. Clamp (ReLU) ===
+         if not is_last:
+             out = clamp_ste(acc_scaled, 0.0, 127.0)
+         else:
+             out = acc_scaled
+
+         # === 8. Output ===
+         # Round to ensure exact integer (floating point precision)
+         # Use STE to maintain gradient flow during training
+         # Return the int8 values as-is (same as the C runtime)
+         out = round_half_up_ste(out)
+
+         return out
+
+     def extra_repr(self) -> str:
+         s = super().extra_repr()
+         s += f', quantize={self.quantize}'
+         return s
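
Two details of the integer path above are easy to verify in isolation. First, steps 2-5 of the PoTLinear docstring reduce to ordinary C-style integer arithmetic; the snippet below replays them with plain Python ints. The scale_int, shift, and bias values are invented, and the "+ round" term is assumed to mean add-half-before-shift, which fake_requantize and round_half_up_ste are expected to mirror but whose exact rounding is not visible in this diff. Second, the first-layer bias adjustment follows from W·((x_raw - mean)/avg_std) = (W·x_raw)/avg_std - (W·mean)/avg_std: the mean term folds into the bias, while the 1/avg_std factor (and the /256 from the uint8 input) folds into the requantize scale.

    x_int = [200, 13, 0, 255, 17, 64, 128, 99]            # uint8 input codes (first layer)
    w_pot = [                                              # PoT weights: 0 or +/- 2^k
        [1, -2, 0, 4, -1, 0, 2, -4],
        [0, 1, -1, 2, 0, -2, 4, 1],
    ]
    scale_int, shift = 300, 18                             # illustrative requantize parameters
    bias_int = [3, -1]                                     # round(bias_adjusted * act_scale)

    out_int8 = []
    for row, b in zip(w_pot, bias_int):
        acc = sum(w * v for w, v in zip(row, x_int))       # 2. integer accumulate (int32 in C)
        acc = (acc * scale_int + (1 << (shift - 1))) >> shift   # 3. requantize
        acc += b                                           # 4. bias add
        out_int8.append(min(max(acc, 0), 127))             # 5. clamp to [0, 127]

    print(out_int8)
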
@@ -0,0 +1,35 @@
+ """Quantization module for potnn."""
+
+ from .pot import (
+     quantize_to_pot, quantize_to_pot_ste, quantize_activation_ste,
+     quantize_activation_c_aligned_ste, get_pot_values
+ )
+ from .calibration import calibrate_model
+ from .qat import prepare_qat, alpha_reg_loss, enable_integer_sim, disable_integer_sim
+ from .integer_sim import (
+     round_ste, round_half_up_ste, floor_ste, clamp_ste,
+     quantize_to_int8_ste, quantize_to_uint8_ste,
+     requantize_ste, compute_scale_params
+ )
+
+ __all__ = [
+     'quantize_to_pot',
+     'quantize_to_pot_ste',
+     'quantize_activation_ste',
+     'quantize_activation_c_aligned_ste',
+     'get_pot_values',
+     'calibrate_model',
+     'prepare_qat',
+     'alpha_reg_loss',
+     'enable_integer_sim',
+     'disable_integer_sim',
+     # Integer simulation
+     'round_ste',
+     'round_half_up_ste',
+     'floor_ste',
+     'clamp_ste',
+     'quantize_to_int8_ste',
+     'quantize_to_uint8_ste',
+     'requantize_ste',
+     'compute_scale_params',
+ ]
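
The integer_sim and integer_ops modules these names come from are not included in this diff. For context only, the _ste suffix conventionally denotes a straight-through estimator: the forward pass applies the non-differentiable integer op, while the backward pass lets gradients through unchanged so QAT can still train. The sketch below shows that generic pattern; it is not potnn's implementation.

    import torch

    def round_ste_sketch(x: torch.Tensor) -> torch.Tensor:
        # forward: round(x); backward: identity gradient
        return x + (torch.round(x) - x).detach()

    def clamp_ste_sketch(x: torch.Tensor, lo: float, hi: float) -> torch.Tensor:
        # forward: clamp(x, lo, hi); backward: identity gradient
        return x + (torch.clamp(x, lo, hi) - x).detach()

    x = torch.randn(4, requires_grad=True)
    y = clamp_ste_sketch(round_ste_sketch(x * 10.0), 0.0, 127.0).sum()
    y.backward()
    print(x.grad)   # all 10s: rounding and clamping did not block the gradient
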
@@ -0,0 +1,233 @@
+ """Activation calibration for PoT quantization."""
+
+ import torch
+ import math
+ from torch.utils.data import DataLoader
+ from typing import Dict, Any, Optional, Union, List
+
+ from ..modules.base import PoTLayerBase
+
+
+ def calibrate_model(
+     model: torch.nn.Module,
+     data_loader: DataLoader,
+     num_batches: int = 10,
+     mean: Union[float, List[float]] = 0.0,
+     std: Union[float, List[float]] = 1.0
+ ) -> Dict[str, float]:
+     """Calibrate activation scales for each PoT layer.
+
+     This function measures the maximum activation values for each layer
+     during inference on calibration data. These values are used to set
+     fixed activation scales for quantization-aware training.
+
+     IMPORTANT: This function should be called ONCE before QAT training,
+     after float warmup training.
+
+     Args:
+         model: Model containing PoT layers to calibrate
+         data_loader: DataLoader for calibration data
+         num_batches: Number of batches to use for calibration (default: 10)
+         mean: Dataset mean for normalization (default: 0.0)
+         std: Dataset std for normalization (default: 1.0)
+
+     Returns:
+         Dictionary mapping layer names to maximum activation values
+     """
+     # ========================================
+     # Step 0: Fuse BatchNorm layers BEFORE calibration
+     # This ensures calibration measures activations with final fused weights.
+     # ========================================
+     from ..fuse import fuse_batchnorm
+     fuse_batchnorm(model)
+
+     print("Starting calibration...")
+
+     # Set model to evaluation mode
+     model.eval()
+
+     # Dictionary to store max activation values for each layer
+     activation_max = {}
+     hooks = []
+
+     def make_hook(name):
+         """Create a forward hook to capture activation values."""
+         def hook(module, input, output):
+             with torch.no_grad():
+                 max_val = output.abs().max().item()
+                 if name not in activation_max:
+                     activation_max[name] = max_val
+                 else:
+                     activation_max[name] = max(activation_max[name], max_val)
+         return hook
+
+     # Register hooks on all PoT layers
+     for name, module in model.named_modules():
+         if isinstance(module, PoTLayerBase):
+             hook = module.register_forward_hook(make_hook(name))
+             hooks.append(hook)
+
+     # Also register hooks on PoTAdd layers to track their output range
+     from ..modules.add import PoTAdd
+
+     # Dictionary for tracking PoTAdd input scales
+     add_input_max = {}  # {name: {'x': max_x, 'y': max_y}}
+
+     def make_add_hook(name):
+         """Create a forward hook for PoTAdd to capture input/output values."""
+         def hook(module, input, output):
+             with torch.no_grad():
+                 # output max
+                 max_val = output.abs().max().item()
+                 if name not in activation_max:
+                     activation_max[name] = max_val
+                 else:
+                     activation_max[name] = max(activation_max[name], max_val)
+
+                 # input max (x=skip, y=conv)
+                 x, y = input[0], input[1]
+                 max_x = x.abs().max().item()
+                 max_y = y.abs().max().item()
+
+                 if name not in add_input_max:
+                     add_input_max[name] = {'x': max_x, 'y': max_y}
+                 else:
+                     add_input_max[name]['x'] = max(add_input_max[name]['x'], max_x)
+                     add_input_max[name]['y'] = max(add_input_max[name]['y'], max_y)
+         return hook
+
+     for name, module in model.named_modules():
+         if isinstance(module, PoTAdd):
+             hook = module.register_forward_hook(make_add_hook(name))
+             hooks.append(hook)
+
+     # Run forward passes to collect statistics
+     with torch.no_grad():
+         for i, batch in enumerate(data_loader):
+             if i >= num_batches:
+                 break
+
+             # Handle different data formats (data, target) or just data
+             if isinstance(batch, (list, tuple)):
+                 data = batch[0]
+             else:
+                 data = batch
+
+             # Move to same device as model
+             device = next(model.parameters()).device
+             data = data.to(device)
+
+             # Normalize data (support both scalar and per-channel mean/std)
+             # Dynamic reshape based on data dimensions (3D for Conv1d, 4D for Conv2d)
+             if isinstance(mean, (list, tuple)):
+                 mean_t = torch.tensor(mean, dtype=data.dtype, device=device)
+                 if data.dim() == 3:  # Conv1d: (B, C, L)
+                     mean_t = mean_t.view(1, -1, 1)
+                 else:  # Conv2d: (B, C, H, W)
+                     mean_t = mean_t.view(1, -1, 1, 1)
+             else:
+                 mean_t = mean
+             if isinstance(std, (list, tuple)):
+                 std_t = torch.tensor(std, dtype=data.dtype, device=device)
+                 if data.dim() == 3:  # Conv1d: (B, C, L)
+                     std_t = std_t.view(1, -1, 1)
+                 else:  # Conv2d: (B, C, H, W)
+                     std_t = std_t.view(1, -1, 1, 1)
+             else:
+                 std_t = std
+             data = (data - mean_t) / std_t
+
+             # Forward pass
+             _ = model(data)
+
+             # Progress indicator
+             if (i + 1) % 5 == 0:
+                 print(f" Calibration batch {i + 1}/{min(num_batches, len(data_loader))}")
+
+     # Remove hooks
+     for hook in hooks:
+         hook.remove()
+
+     # Set activation scales based on calibration
+     print("\nSetting activation scales:")
+     for name, module in model.named_modules():
+         if isinstance(module, PoTLayerBase):
+             if name in activation_max and activation_max[name] > 0:
+                 # Set the activation scale for quantization
+                 module.calibrate(activation_max[name])
+                 print(f" {name}: act_max={activation_max[name]:.2f}, "
+                       f"act_scale={module.act_scale.item():.4f}")
+             else:
+                 # Default if layer wasn't activated during calibration
+                 module.calibrate(1.0)
+                 print(f" {name}: no activations detected, using default scale")
+
+     # Set activation scales for PoTAdd layers
+     print("\nSetting PoTAdd scales:")
+     for name, module in model.named_modules():
+         if isinstance(module, PoTAdd):
+             if name in activation_max and activation_max[name] > 0:
+                 # act_scale = 127 / max_activation
+                 act_scale = 127.0 / activation_max[name]
+                 module.act_scale = torch.tensor(act_scale)
+                 module.quantize = True
+
+                 # Set the input scales (this is the key step!)
+                 if name in add_input_max:
+                     max_x = add_input_max[name]['x']
+                     max_y = add_input_max[name]['y']
+                     scale_x = 127.0 / max_x if max_x > 0 else 1.0
+                     scale_y = 127.0 / max_y if max_y > 0 else 1.0
+                     module.set_scales(scale_x, scale_y)
+                     print(f" {name}: scale_x={scale_x:.4f}, scale_y={scale_y:.4f}, "
+                           f"ratio={scale_y/scale_x:.4f} (PoTAdd)")
+                 else:
+                     print(f" {name}: act_max={activation_max[name]:.2f}, "
+                           f"act_scale={act_scale:.4f} (PoTAdd, no input scales)")
+             else:
+                 # Default scale
+                 module.act_scale = torch.tensor(1.0)
+                 module.quantize = True
+                 print(f" {name}: no activations detected, using default scale (PoTAdd)")
+
+     # Initialize alpha values based on weight distribution
+     print("\nInitializing alpha values based on weight distribution:")
+     for name, module in model.named_modules():
+         if isinstance(module, PoTLayerBase):
+             with torch.no_grad():
+                 # Calculate weight standard deviation
+                 w_std = module.weight.std().item()
+
+                 # Calculate appropriate raw_alpha value
+                 # We want: softplus(raw_alpha) ≈ w_std
+                 # softplus(x) = log(1 + exp(x))
+                 # So we need: raw_alpha = log(exp(w_std) - 1)  (inverse softplus, valid for w_std > 0)
+                 if w_std > 0.01:
+                     # Inverse softplus: raw = log(exp(target) - 1)
+                     # But ensure numerical stability
+                     target = w_std
+                     if target > 10:  # Avoid overflow
+                         raw = target  # For large values, softplus(x) ≈ x
+                     elif target > 0.1:
+                         raw = math.log(math.exp(target) - 1)
+                     else:
+                         # For small values, use approximation
+                         raw = math.log(target) if target > 0 else -2.0
+                 else:
+                     # Very small weights, use small alpha
+                     raw = -2.0
+
+                 # Update raw_alpha
+                 module.raw_alpha.data.fill_(raw)
+
+                 # Update alpha_init to match the newly initialized alpha
+                 # This ensures regularization pulls towards the calibrated value
+                 module.alpha_init.data.fill_(module.alpha.item())
+
+                 # Verify the result
+                 actual_alpha = module.alpha.item()
+                 print(f" {name}: weight_std={w_std:.4f}, "
+                       f"raw_alpha={raw:.4f}, alpha={actual_alpha:.4f}")
+
+     print(f"\nCalibration complete. Processed {len(activation_max)} layers.")
+     return activation_max
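
The alpha initialisation at the end of calibrate_model can be checked numerically: raw_alpha is the inverse softplus of the weight standard deviation, so softplus(raw_alpha) = log(1 + exp(raw_alpha)) should recover w_std under each of the three branches. The values below are illustrative; only the math mirrors the code above.

    import math

    def inverse_softplus(target: float) -> float:
        # Same branches as the alpha initialisation above
        if target > 10:
            return target                       # softplus(x) ~ x for large x
        elif target > 0.1:
            return math.log(math.exp(target) - 1)
        else:
            return math.log(target) if target > 0 else -2.0

    def softplus(x: float) -> float:
        return math.log1p(math.exp(x))

    for w_std in (0.02, 0.15, 0.8, 12.0):
        raw = inverse_softplus(w_std)
        print(f"w_std={w_std:.2f}  raw_alpha={raw:+.4f}  softplus(raw)={softplus(raw):.4f}")

    # Activation scales use the same convention made explicit for PoTAdd:
    # act_scale = 127 / act_max, so int8 code 127 maps back to the observed maximum.
    act_max = 6.35
    print(127.0 / act_max)   # 20.0
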