potnn-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
potnn/fuse.py ADDED
@@ -0,0 +1,167 @@
1
+ """BatchNorm fusion into Conv/Linear layers."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import Dict, List, Tuple, Optional
6
+
7
+
8
+ def fuse_batchnorm(model: nn.Module) -> nn.Module:
9
+ """Fuse BatchNorm layers into preceding Conv/Linear layers.
10
+
11
+ This absorbs BatchNorm parameters (γ, β, μ, σ) into the weight and bias
12
+ of the preceding convolution or linear layer, eliminating the need for
13
+ separate BatchNorm computation at inference time.
14
+
15
+ Formula:
16
+ y = γ * (x - μ) / √(σ² + ε) + β
17
+
18
+ For Conv/Linear followed by BatchNorm:
19
+ out = BN(W*x + b)
20
+ = scale * (W*x + b) + bias'
21
+ = (scale * W) * x + (scale * b + bias')
22
+
23
+ where:
24
+ scale = γ / √(σ² + ε)
25
+ bias' = β - γ * μ / √(σ² + ε)
26
+
27
+ Therefore:
28
+ W_fused = W * scale
29
+ b_fused = b * scale + bias' = (b - μ) * scale + β
30
+
31
+ Args:
32
+ model: Model with Conv/Linear + BatchNorm sequences
33
+
34
+ Returns:
35
+ Model with BatchNorm fused (BatchNorm layers become identity)
36
+ """
37
+ print("Fusing BatchNorm layers...")
38
+
39
+ # Find Conv/Linear -> BatchNorm pairs
40
+ pairs = _find_bn_pairs(model)
41
+
42
+ if not pairs:
43
+ print(" No BatchNorm layers found to fuse.")
44
+ return model
45
+
46
+ # Fuse each pair
47
+ for conv_name, bn_name, conv_module, bn_module in pairs:
48
+ _fuse_single_bn(conv_module, bn_module)
49
+ print(f" Fused: {conv_name} <- {bn_name}")
50
+
51
+ # Replace BatchNorm layers with Identity
52
+ _replace_bn_with_identity(model, [bn_name for _, bn_name, _, _ in pairs])
53
+
54
+ print(f" Total {len(pairs)} BatchNorm layers fused.")
55
+
56
+ return model
57
+
58
+
59
+ def _find_bn_pairs(model: nn.Module) -> List[Tuple[str, str, nn.Module, nn.Module]]:
60
+ """Find Conv/Linear -> BatchNorm pairs in the model.
61
+
62
+ Returns:
63
+ List of (conv_name, bn_name, conv_module, bn_module) tuples
64
+ """
65
+ pairs = []
66
+ prev_name = None
67
+ prev_module = None
68
+
69
+ for name, module in model.named_modules():
70
+ if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
71
+ # Check if previous layer is Conv or Linear
72
+ if prev_module is not None:
73
+ if isinstance(prev_module, (nn.Conv2d, nn.Linear)):
74
+ pairs.append((prev_name, name, prev_module, module))
75
+ elif hasattr(prev_module, 'weight'):
76
+ # PoTConv2d or PoTLinear
77
+ pairs.append((prev_name, name, prev_module, module))
78
+
79
+ # Track previous layer (skip non-compute layers)
80
+ if isinstance(module, (nn.Conv2d, nn.Linear)) or hasattr(module, 'weight'):
81
+ if not isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
82
+ prev_name = name
83
+ prev_module = module
84
+
85
+ return pairs
86
+
87
+
88
+ def _fuse_single_bn(conv: nn.Module, bn: nn.Module):
89
+ """Fuse a single BatchNorm into its preceding Conv/Linear.
90
+
91
+ Modifies conv.weight and conv.bias in-place.
92
+ """
93
+ with torch.no_grad():
94
+ # Get BatchNorm parameters
95
+ gamma = bn.weight # γ (scale)
96
+ beta = bn.bias # β (shift)
97
+ mean = bn.running_mean # μ
98
+ var = bn.running_var # σ²
99
+ eps = bn.eps
100
+
101
+ # Compute scale factor: γ / √(σ² + ε)
102
+ std = torch.sqrt(var + eps)
103
+ scale = gamma / std
104
+
105
+ # Compute bias adjustment: β - γ * μ / √(σ² + ε)
106
+ bias_adjust = beta - gamma * mean / std
107
+
108
+ # Get conv weight shape
109
+ weight = conv.weight
110
+
111
+ if weight.dim() == 4:
112
+ # Conv2d: weight shape is [out_ch, in_ch, kH, kW]
113
+ # Scale each output channel
114
+ scale_shape = scale.view(-1, 1, 1, 1)
115
+ conv.weight.data = weight * scale_shape
116
+ else:
117
+ # Linear [out, in] or Conv1d [out, in, k]: scale each output channel/row
118
+ scale_shape = scale.view(-1, *([1] * (weight.dim() - 1)))
119
+ conv.weight.data = weight * scale_shape
120
+
121
+ # Handle bias
122
+ if conv.bias is not None:
123
+ # Existing bias: b_fused = b * scale + bias_adjust
124
+ conv.bias.data = conv.bias * scale + bias_adjust
125
+ else:
126
+ # No existing bias: create one with just bias_adjust
127
+ conv.bias = nn.Parameter(bias_adjust.clone())
128
+
129
+
130
+ def _replace_bn_with_identity(model: nn.Module, bn_names: List[str]):
131
+ """Replace BatchNorm layers with Identity.
132
+
133
+ This ensures the fused BatchNorm layers don't affect forward pass.
134
+ """
135
+ for bn_name in bn_names:
136
+ # Navigate to parent and replace
137
+ parts = bn_name.split('.')
138
+
139
+ if len(parts) == 1:
140
+ # Top-level module
141
+ setattr(model, bn_name, nn.Identity())
142
+ else:
143
+ # Nested module
144
+ parent = model
145
+ for part in parts[:-1]:
146
+ if part.isdigit():
147
+ parent = parent[int(part)]
148
+ else:
149
+ parent = getattr(parent, part)
150
+
151
+ child_name = parts[-1]
152
+ if child_name.isdigit():
153
+ parent[int(child_name)] = nn.Identity()
154
+ else:
155
+ setattr(parent, child_name, nn.Identity())
156
+
157
+
158
+ def check_bn_fused(model: nn.Module) -> bool:
159
+ """Check if all BatchNorm layers have been fused.
160
+
161
+ Returns:
162
+ True if no BatchNorm layers remain (or all are Identity)
163
+ """
164
+ for module in model.modules():
165
+ if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
166
+ return False
167
+ return True
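
Illustration (not part of the package diff): a minimal numerical sanity check of the fusion above. It assumes only the potnn.fuse module shown in this diff; the toy Conv+BN model, the random statistics, and the tolerance are arbitrary choices for the example.

    import torch
    import torch.nn as nn
    from potnn.fuse import fuse_batchnorm, check_bn_fused

    # Toy Conv -> BN -> ReLU model (hypothetical, for illustration only)
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(8),
        nn.ReLU(),
    )

    # Give the BatchNorm non-trivial statistics so the check is meaningful
    bn = model[1]
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.data.uniform_(0.5, 1.5)
    bn.bias.data.uniform_(-0.5, 0.5)

    model.eval()  # fusion folds the *running* statistics, so use eval mode
    x = torch.randn(1, 3, 16, 16)

    with torch.no_grad():
        y_ref = model(x)        # Conv -> BN -> ReLU
        fuse_batchnorm(model)   # BN absorbed into the Conv, BN replaced by Identity
        y_fused = model(x)      # Conv(fused) -> Identity -> ReLU

    print(check_bn_fused(model))                       # True
    print(torch.allclose(y_ref, y_fused, atol=1e-4))   # True, up to float rounding
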
potnn/modules/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """Neural network modules for potnn."""
2
+
3
+ from .base import PoTLayerBase
4
+ from .linear import PoTLinear
5
+ from .conv import PoTConv2d
6
+ from .conv1d import PoTConv1d
7
+ from .depthwise import PoTDepthwiseConv2d
8
+ from .add import PoTAdd
9
+ from .avgpool import PoTGlobalAvgPool
10
+
11
+ __all__ = ['PoTLayerBase', 'PoTLinear', 'PoTConv2d', 'PoTConv1d', 'PoTDepthwiseConv2d', 'PoTAdd', 'PoTGlobalAvgPool']
potnn/modules/add.py ADDED
@@ -0,0 +1,114 @@
1
+ """PoT Add layer for skip/residual connections."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import math
6
+
7
+
8
+ class PoTAdd(nn.Module):
9
+ """Add layer for residual/skip connections with scale alignment.
10
+
11
+ In a skip connection the two branches may be on different scales:
12
+ - x: the original input (scale_x)
13
+ - y: the output of the conv path (scale_y)
14
+
15
+ This layer aligns the scales and then adds:
16
+ output = rescale(x) + y
17
+
18
+ rescale is implemented with an integer MUL + shift:
19
+ x_aligned = (x * rescale_mult) >> rescale_shift
20
+
21
+ rescale_mult and rescale_shift are computed at compile time;
22
+ no float operations are performed at runtime.
23
+
24
+ Usage example:
25
+ # ResNet block
26
+ identity = x
27
+ out = conv2(relu(conv1(x)))
28
+ out = add_layer(identity, out) # identity + out with scale alignment
29
+ out = relu(out)
30
+ """
31
+
32
+ def __init__(self):
33
+ """Initialize PoTAdd layer."""
34
+ super().__init__()
35
+
36
+ # Scale alignment: x_aligned = (x * rescale_mult) >> rescale_shift
37
+ self.register_buffer('rescale_mult', torch.tensor(128))  # default: 1.0 * 128
38
+ self.register_buffer('rescale_shift', torch.tensor(7))   # default: >> 7
39
+
40
+ # Activation scale for output (set during calibration)
41
+ self.register_buffer('act_scale', None)
42
+
43
+ # Scale info for the two inputs (set during calibration)
44
+ self.register_buffer('scale_x', None) # scale of first input (skip)
45
+ self.register_buffer('scale_y', None) # scale of second input (conv output)
46
+
47
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
48
+ """Forward pass: aligned add.
49
+
50
+ Args:
51
+ x: First input (typically skip/identity branch)
52
+ y: Second input (typically conv output)
53
+
54
+ Returns:
55
+ x + y with scale alignment applied to x
56
+ """
57
+ # QAT mode: simulate integer rescale
58
+ if getattr(self, 'quantize', False) and self.scale_x is not None:
59
+ # Simulate: x_aligned = (x * rescale_mult) >> rescale_shift
60
+ # In float: x_aligned = x * (rescale_mult / 2^rescale_shift)
61
+ ratio = self.rescale_mult.float() / (1 << self.rescale_shift.item())
62
+ x = x * ratio
63
+
64
+ return x + y
65
+
66
+ def set_scales(self, scale_x: float, scale_y: float):
67
+ """Set input scales and compute rescale_mult/rescale_shift.
68
+
69
+ C code: skip_rescaled = (skip_int * mult) >> shift
70
+ To express the skip branch in the conv branch's scale:
71
+ ratio = scale_y / scale_x (conv/skip)
72
+
73
+ Args:
74
+ scale_x: Activation scale of first input (skip branch)
75
+ scale_y: Activation scale of second input (conv branch)
76
+ """
77
+ self.scale_x = torch.tensor(scale_x)
78
+ self.scale_y = torch.tensor(scale_y)
79
+
80
+ # Match the C code: convert the skip branch onto the conv branch's scale
81
+ ratio = scale_y / scale_x
82
+
83
+ # Integer quantization: ratio ≈ rescale_mult / 2^rescale_shift
84
+ # mult = ratio * 2^shift; adjust shift so that mult lands in the 1-255 range
85
+ base_shift = 7
86
+ mult = round(ratio * (1 << base_shift))
87
+
88
+ # If mult is too large, decrease shift (mult = ratio * 2^shift)
89
+ while mult > 255 and base_shift > 0:
90
+ base_shift -= 1
91
+ mult = round(ratio * (1 << base_shift))
92
+
93
+ # If mult is too small, increase shift
94
+ while mult < 32 and base_shift < 15:
95
+ base_shift += 1
96
+ mult = round(ratio * (1 << base_shift))
97
+
98
+ # clamp mult to safe range
99
+ mult = max(1, min(255, mult))
100
+
101
+ self.rescale_mult = torch.tensor(mult)
102
+ self.rescale_shift = torch.tensor(base_shift)
103
+
104
+ # Output scale is same as y's scale (after alignment)
105
+ self.act_scale = torch.tensor(scale_y)
106
+
107
+ def extra_repr(self) -> str:
108
+ """String representation."""
109
+ s = f"rescale_mult={self.rescale_mult.item()}, rescale_shift={self.rescale_shift.item()}"
110
+ if self.scale_x is not None:
111
+ ratio = self.scale_y.item() / self.scale_x.item()  # same convention as set_scales
112
+ approx = self.rescale_mult.item() / (1 << self.rescale_shift.item())
113
+ s += f", ratio={ratio:.3f}, approx={approx:.3f}"
114
+ return s
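
Illustration (not part of the package diff): how set_scales above maps a scale ratio onto an integer multiply-and-shift. The import path follows this diff; the scale values and the sample activation are made up.

    import torch
    from potnn.modules.add import PoTAdd

    add = PoTAdd()
    add.set_scales(scale_x=23.7, scale_y=41.2)   # hypothetical activation scales

    mult = int(add.rescale_mult.item())
    shift = int(add.rescale_shift.item())
    ratio = 41.2 / 23.7                          # scale_y / scale_x, as in set_scales

    print(mult, shift)                  # e.g. 223, 7
    print(mult / (1 << shift), ratio)   # 1.742... vs 1.738...: close approximation

    # The rescale the C code would apply to a quantized skip value
    # (accumulation happens in a wider integer type than int8):
    skip_int = 93
    print((skip_int * mult) >> shift, round(skip_int * ratio))   # both 162
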
potnn/modules/avgpool.py ADDED
@@ -0,0 +1,173 @@
1
+ """PoT Global Average Pooling layer."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import math
6
+
7
+
8
+ from ..quantize.integer_ops import round_half_up_ste, floor_ste
9
+
10
+ class PoTGlobalAvgPool(nn.Module):
11
+ """Global Average Pooling with PoT-compatible quantization.
12
+
13
+ [Integer-Only QAT Mode]
14
+ Forward pass simulates C integer arithmetic:
15
+ - Power of 2 size: (sum + (size//2)) >> log2(size)
16
+ - Generic size: (sum * div_mult + round_const) >> div_shift
17
+ """
18
+
19
+ def __init__(self):
20
+ """Initialize PoTGlobalAvgPool."""
21
+ super().__init__()
22
+
23
+ # Division parameters
24
+ self.register_buffer('div_mult', torch.tensor(1))
25
+ self.register_buffer('div_shift', torch.tensor(0))
26
+ self.register_buffer('pool_size', torch.tensor(0))
27
+
28
+ # Activation scale (passed from previous layer)
29
+ self.register_buffer('act_scale', None)
30
+
31
+ self.quantize = False
32
+
33
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
34
+ """Forward pass: global average pooling.
35
+
36
+ Args:
37
+ x: Input tensor of shape (N, C, H, W)
38
+
39
+ Returns:
40
+ Output tensor of shape (N, C)
41
+ """
42
+ if not self.quantize:
43
+ # Float mode
44
+ return x.mean(dim=(2, 3))
45
+
46
+ if not getattr(self, 'integer_sim_enabled', False):
47
+ # Float QAT mode
48
+ return x.mean(dim=(2, 3))
49
+
50
+ # Calculate pool size dynamically if not set
51
+ current_pool_size = x.shape[2] * x.shape[3]
52
+
53
+ if self.pool_size.item() != current_pool_size:
54
+ self.set_pool_size(x.shape[2], x.shape[3])
55
+
56
+ # Integer mode
57
+ # x is float but represents integer values (from previous layer)
58
+ # We assume input is already scaled by act_scale of previous layer?
59
+ # No, in our new design, previous layer output is "dequantized" float.
60
+ # So x is float.
61
+ # But GAP in C operates on the accumulated integer values?
62
+ # Wait, C GAP input is the output of the previous layer *before* requantization?
63
+ # No, usually GAP follows a Conv/ReLU layer.
64
+ # The previous layer output is int8 (requantized).
65
+ # So x here is int8 values (represented as float).
66
+ # But wait, our Conv layer returns `out / act_scale`.
67
+ # So x is float.
68
+ # We need to recover the int8 values: `x_int = round(x * prev_scale)`
69
+ # But `prev_scale` is `act_scale` of previous layer.
70
+ # If we assume `act_scale` is passed to this layer, we can use it.
71
+
72
+ # However, for GAP, usually we just average the values.
73
+ # mean(x) = sum(x) / N
74
+ # If x = x_int / scale, then mean(x) = sum(x_int) / N / scale
75
+ # = (sum(x_int) / N) / scale
76
+ # So we can just compute mean(x) in float?
77
+ # NO! The rounding behavior of `sum(x_int) / N` in integer arithmetic is different from float mean.
78
+ # C: `(sum(x_int) + N//2) >> log2(N)`
79
+ # Python float: `mean(x_int)` (exact)
80
+ # We must simulate the integer division on `x_int`.
81
+
82
+ # So:
83
+ # 1. Recover x_int: x_int = round(x * act_scale)
84
+ # 2. Compute sum(x_int)
85
+ # 3. Integer division
86
+ # 4. Convert back to float: result / act_scale
87
+
88
+ # We need `act_scale` of the input.
89
+ # Usually this is passed or stored.
90
+ # Let's assume `act_scale` is available (set by `set_prev_act_scale` or similar mechanism).
91
+ # But `PoTGlobalAvgPool` doesn't inherit `PoTLayerBase` currently.
92
+ # Let's assume for now we just operate on `x` assuming it's `x_int` if `act_scale` is 1.0.
93
+ # But wait, if we don't know `act_scale`, we can't recover `x_int`.
94
+
95
+ # In the user's specific case (SimpleNet), GAP follows Conv2.
96
+ # Conv2 output is `out / act_scale`.
97
+ # So GAP input is float.
98
+ # If we want to match C, we need to know `act_scale`.
99
+
100
+ # Let's check how `PoTGlobalAvgPool` is used.
101
+ # It seems it's used in `SimpleNet`.
102
+ # We should probably add `act_scale` management to `PoTGlobalAvgPool`.
103
+
104
+ # For now, let's implement the integer logic assuming `x` is `x_int`?
105
+ # No, `PoTConv2d` divides by scale.
106
+
107
+ # Solution: `PoTGlobalAvgPool` needs `act_scale`.
108
+ # We'll add `set_act_scale` method.
109
+
110
+ scale = self.act_scale if self.act_scale is not None else torch.tensor(1.0)
111
+
112
+ # Input should be integer values from previous layer
113
+ # Round to ensure exact integer (may have floating point precision errors)
114
+ # Use STE to maintain gradient flow during training
115
+ x_int = round_half_up_ste(x)
116
+
119
+ # 2. Sum over H, W
120
+ sum_val = x_int.sum(dim=(2, 3))
121
+
122
+
123
+ # 3. Integer Division
124
+ pool_size = int(self.pool_size.item())
125
+ if (pool_size & (pool_size - 1)) == 0:
126
+ # Power of 2
127
+ shift = int(math.log2(pool_size))
128
+ round_const = 1 << (shift - 1)
132
+ # (sum + round) >> shift
133
+ out_int = floor_ste((sum_val + round_const) / (1 << shift))
134
+
135
+ else:
136
+ # Generic
137
+ mult = self.div_mult.item()
138
+ shift = self.div_shift.item()
139
+ # (sum * mult + round) >> shift
140
+ # round_const for the shift is 1 << (shift - 1),
141
+ # matching the C generic implementation:
142
+ # avg = (sum * div_mult + (1 << (div_shift - 1))) >> div_shift
143
+ round_const = 1 << (shift - 1) if shift > 0 else 0
144
+ val = sum_val * mult + round_const
145
+ out_int = floor_ste(val / (1 << shift))
146
+
147
+ # Output stays on the integer grid (int8 values represented as float); no dequantization here
148
+ return out_int
149
+
150
+ def set_pool_size(self, h: int, w: int):
151
+ """Set pool size and compute div_mult/div_shift."""
152
+ pool_size = h * w
153
+ self.pool_size = torch.tensor(pool_size)
154
+
155
+ if pool_size > 0 and (pool_size & (pool_size - 1)) == 0:
156
+ self.div_mult = torch.tensor(1)
157
+ self.div_shift = torch.tensor(int(math.log2(pool_size)))
158
+ else:
159
+ base_shift = 15
160
+ mult = round((1 << base_shift) / pool_size)
161
+ while mult > 255 and base_shift > 8:
162
+ base_shift -= 1
163
+ mult = round((1 << base_shift) / pool_size)
164
+ self.div_mult = torch.tensor(max(1, min(65535, mult)))
165
+ self.div_shift = torch.tensor(base_shift)
166
+
167
+ def prepare_qat(self, act_scale=None):
168
+ self.quantize = True
169
+ if act_scale is not None:
170
+ self.act_scale = torch.tensor(act_scale)
171
+
172
+ def extra_repr(self) -> str:
173
+ return f"pool_size={self.pool_size.item()}, quantize={self.quantize}"
potnn/modules/base.py ADDED
@@ -0,0 +1,225 @@
1
+ """Base class for all PoT (Power-of-Two) quantized layers."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class PoTLayerBase(nn.Module):
9
+ """Base class for all PoT layers with alpha scaling and activation quantization.
10
+
11
+ This class provides:
12
+ - Alpha scaling parameter (learnable)
13
+ - Activation scale (fixed after calibration)
14
+ - QAT (Quantization-Aware Training) mode management
15
+ - Alpha regularization loss
16
+ - Integer simulation mode for C-compatible inference
17
+ """
18
+
19
+ def __init__(self, encoding='unroll'):
20
+ """Initialize PoT layer base.
21
+
22
+ Args:
23
+ encoding: Encoding type for weight quantization
24
+ - 'unroll': 17 levels {0, ±1, ±2, ±4, ..., ±128} (default)
26
+ - 'fp130': 16 levels {±1, ±2, ±4, ..., ±128} (no zero)
27
+ - '5level': 5 levels {-8, -1, 0, 1, 8}
28
+ - '2bit': 4 levels {-2, -1, 1, 2} (no zero)
29
+ - 'ternary': 3 levels {-1, 0, 1}
29
+ """
30
+ super().__init__()
31
+ self.encoding = encoding
32
+
33
+ # Alpha scaling parameter (learnable)
34
+ # raw_alpha → softplus → clamp(0.01) → alpha
35
+ self.raw_alpha = nn.Parameter(torch.tensor(0.5))
36
+
37
+ # Alpha initial value for regularization
38
+ # This will be updated during calibration to match the initialized alpha
39
+ self.register_buffer('alpha_init', torch.tensor(0.5))
40
+
41
+ # Activation scale (fixed after calibration)
42
+ self.register_buffer('act_scale', None)
43
+
44
+ # QAT mode flag
45
+ self.quantize = False
46
+
47
+ # === Integer Simulation Parameters ===
48
+ # These enable C-compatible integer arithmetic simulation
49
+
50
+ # Layer position flags
51
+ self.register_buffer('is_first_layer', torch.tensor(False))
52
+ self.register_buffer('is_last_layer', torch.tensor(False))
53
+
54
+ # Previous layer's act_scale (for scale chain)
55
+ self.register_buffer('prev_act_scale', None)
56
+
57
+ # Input std (for first layer standardization absorption)
58
+ # Per-channel tensor [in_ch] or None
59
+ self.register_buffer('input_std', None)
60
+
61
+ # Input mean (for first layer bias adjustment)
62
+ # Per-channel tensor [in_ch] or None
63
+ self.register_buffer('input_mean', None)
64
+
65
+ # Pre-computed integer scale parameters
66
+ self.register_buffer('scale_int', None)
67
+ self.register_buffer('shift', None)
68
+
69
+ # Integer simulation mode flag
70
+ self.use_integer_sim = False
71
+
72
+ # 5level encoding constraint flag
73
+ # When True, enforces max 3 consecutive zeros (skip field is 2 bits)
74
+ self.enforce_5level_constraint = False
75
+
76
+ @property
77
+ def alpha(self):
78
+ """Get positive alpha value using softplus + clamp.
79
+
80
+ Returns:
81
+ Positive alpha value for scaling PoT weights.
82
+ """
83
+ return F.softplus(self.raw_alpha).clamp(min=0.01)
84
+
85
+ def calibrate(self, act_max):
86
+ """Set activation scale based on calibration.
87
+
88
+ Args:
89
+ act_max: Maximum activation value from calibration.
90
+ """
91
+ if act_max > 0:
92
+ self.act_scale = torch.tensor(127.0 / act_max)
93
+ else:
94
+ self.act_scale = torch.tensor(1.0)
95
+
96
+ def prepare_qat(self):
97
+ """Enable QAT (Quantization-Aware Training) mode."""
98
+ self.quantize = True
99
+
100
+ def alpha_reg_loss(self, lambda_reg=0.01):
101
+ """Calculate alpha regularization loss.
102
+
103
+ This loss encourages alpha to stay close to its initial value,
104
+ preventing it from drifting too far during training.
105
+
106
+ Args:
107
+ lambda_reg: Regularization strength (default: 0.01)
108
+
109
+ Returns:
110
+ Alpha regularization loss value.
111
+ """
112
+ # Use the stored alpha_init which is set during calibration
113
+ return lambda_reg * (self.alpha - self.alpha_init) ** 2
114
+
115
+ # === Integer Simulation Methods ===
116
+
117
+ def set_layer_position(self, is_first: bool, is_last: bool):
118
+ """Set layer position in the network.
119
+
120
+ Args:
121
+ is_first: True if this is the first PoT layer (input is uint8)
122
+ is_last: True if this is the last PoT layer (no ReLU)
123
+ """
124
+ self.is_first_layer = torch.tensor(is_first)
125
+ self.is_last_layer = torch.tensor(is_last)
126
+
127
+ def set_prev_act_scale(self, prev_scale: float):
128
+ """Set previous layer's activation scale.
129
+
130
+ Args:
131
+ prev_scale: Previous layer's act_scale value
132
+ """
133
+ if prev_scale is not None:
134
+ self.prev_act_scale = torch.tensor(prev_scale)
135
+ else:
136
+ self.prev_act_scale = None
137
+
138
+ def set_input_std(self, std, mean=None):
139
+ """Set input statistics for first layer.
140
+
141
+ Args:
142
+ std: Standard deviation - float (single channel) or List[float] (multi-channel)
143
+ mean: Mean values - float (single channel) or List[float] (multi-channel)
144
+ """
145
+ # Convert to per-channel tensor
146
+ if isinstance(std, (int, float)):
147
+ self.input_std = torch.tensor([float(std)])
148
+ else:
149
+ self.input_std = torch.tensor([float(s) for s in std])
150
+
151
+ if mean is not None:
152
+ if isinstance(mean, (int, float)):
153
+ self.input_mean = torch.tensor([float(mean)])
154
+ else:
155
+ self.input_mean = torch.tensor([float(m) for m in mean])
156
+ else:
157
+ self.input_mean = None
158
+
159
+ def compute_integer_params(self):
160
+ """Compute integer scale parameters for C-compatible inference.
161
+
162
+ MUST match export.py calculate_combined_scales() exactly!
163
+
164
+ Returns:
165
+ (scale_int, shift) tuple
166
+ """
167
+ scale_int, shift, _ = self._compute_scale_and_shift()
168
+
169
+ self.scale_int = torch.tensor(scale_int, device=self.raw_alpha.device)
170
+ self.shift = torch.tensor(shift, device=self.raw_alpha.device)
171
+
172
+ return scale_int, shift
173
+
174
+ def _compute_scale_and_shift(self):
175
+ """Internal method to compute scale_int and shift dynamically.
176
+
177
+ Returns:
178
+ (scale_int, shift, combined_scale)
179
+ """
180
+ # self.alpha is already softplus(raw_alpha).clamp(0.01) via property
181
+ alpha = self.alpha.item()
182
+ act_scale = self.act_scale.item() if self.act_scale is not None else None
183
+
184
+ is_first = self.is_first_layer.item()
185
+
186
+ # Calculate combined_scale - EXACTLY like export.py
187
+ if is_first:
188
+ # Use average std for combined_scale (matches export.py)
189
+ if self.input_std is not None:
190
+ input_std = self.input_std.mean().item()
191
+ else:
192
+ input_std = 1.0
193
+ if act_scale is not None:
194
+ combined_scale = alpha * act_scale / input_std
195
+ else:
196
+ combined_scale = alpha / input_std
197
+ else:
198
+ prev_scale = self.prev_act_scale.item() if self.prev_act_scale is not None else 1.0
199
+ if act_scale is not None:
200
+ combined_scale = alpha * act_scale / prev_scale
201
+ else:
202
+ combined_scale = alpha / prev_scale
203
+
204
+ # Determine shift - EXACTLY like export.py
205
+ base_shift = 0
206
+ scale_magnitude = abs(combined_scale)
207
+
208
+ # Target: scale_int around 64-512 for precision (export.py uses 64-512)
209
+ while scale_magnitude < 64 and base_shift < 24:
210
+ scale_magnitude *= 2
211
+ base_shift += 1
212
+ while scale_magnitude > 512 and base_shift > 0:
213
+ scale_magnitude /= 2
214
+ base_shift -= 1
215
+
216
+ # For first layer, add +8 for /256 absorption
217
+ if is_first:
218
+ combined_shift = base_shift + 8
219
+ else:
220
+ combined_shift = base_shift
221
+
222
+ # Calculate integer scale
223
+ scale_int = round(combined_scale * (1 << base_shift))
224
+
225
+ return scale_int, combined_shift, combined_scale
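
Illustration (not part of the package diff): how a float combined_scale is approximated as scale_int / 2^shift, mirroring _compute_scale_and_shift above for a non-first layer (a first layer additionally adds 8 to the shift to absorb the /256). scale_to_int and the numeric values are made up for the example.

    def scale_to_int(combined_scale: float):
        """Approximate combined_scale as scale_int / 2**shift, with scale_int in ~[64, 512]."""
        shift = 0
        mag = abs(combined_scale)
        while mag < 64 and shift < 24:
            mag *= 2
            shift += 1
        while mag > 512 and shift > 0:
            mag /= 2
            shift -= 1
        return round(combined_scale * (1 << shift)), shift

    combined_scale = 0.8371              # e.g. alpha * act_scale / prev_act_scale
    scale_int, shift = scale_to_int(combined_scale)
    print(scale_int, shift)              # 107, 7  ->  107 / 128 ≈ 0.8359

    # C-style requantization of an int32 accumulator with the integer scale:
    acc = 5213
    print((acc * scale_int) >> shift, acc * combined_scale)   # 4357 vs ~4363.8
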