potnn-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,207 @@
+ """Integer simulation operations with Straight-Through Estimator (STE).
+ 
+ This module provides the core building blocks for "Integer-Only QAT".
+ All operations in the forward pass simulate C integer arithmetic exactly,
+ while the backward pass allows gradients to flow for training.
+ """
+ 
+ import math
+ 
+ import torch
+ import torch.nn.functional as F
+ 
+ # =============================================================================
+ # Core Rounding Functions
+ # =============================================================================
+ 
+ class RoundHalfUpSTE(torch.autograd.Function):
+     """Half-up rounding with STE (C style).
+ 
+     Forward: floor(x + 0.5)
+     Backward: identity (gradient passes through unchanged)
+     """
+     @staticmethod
+     def forward(ctx, x):
+         return torch.floor(x + 0.5)
+ 
+     @staticmethod
+     def backward(ctx, grad_output):
+         return grad_output
+ 
+ def round_half_up_ste(x: torch.Tensor) -> torch.Tensor:
+     """Round half up with STE. Matches C behavior: (int)(x + 0.5)."""
+     return RoundHalfUpSTE.apply(x)
+ 
+ class FloorSTE(torch.autograd.Function):
+     """Floor with STE."""
+     @staticmethod
+     def forward(ctx, x):
+         return torch.floor(x)
+ 
+     @staticmethod
+     def backward(ctx, grad_output):
+         return grad_output
+ 
+ def floor_ste(x: torch.Tensor) -> torch.Tensor:
+     return FloorSTE.apply(x)
+ 
+ class ClampSTE(torch.autograd.Function):
+     """Clamp with STE."""
+     @staticmethod
+     def forward(ctx, x, min_val, max_val):
+         return x.clamp(min_val, max_val)
+ 
+     @staticmethod
+     def backward(ctx, grad_output):
+         return grad_output, None, None
+ 
+ def clamp_ste(x: torch.Tensor, min_val: float, max_val: float) -> torch.Tensor:
+     return ClampSTE.apply(x, min_val, max_val)
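For reference, the three primitives above behave as follows on a few illustrative values (assuming the definitions above are in scope; expected outputs, worked out from those definitions, are shown as comments):

    x = torch.tensor([2.5, -2.5, 3.7, -3.7])
    round_half_up_ste(x)                     # tensor([ 3., -2.,  4., -4.])
    floor_ste(x)                             # tensor([ 2., -3.,  3., -4.])
    clamp_ste(x * 100.0, -128.0, 127.0)      # tensor([ 127., -128.,  127., -128.])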
+ 
+ 
+ # =============================================================================
+ # Integer Simulation Functions
+ # =============================================================================
+ 
+ def fake_quantize_input(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+     """Quantize a float input to the int8 range (simulated as float).
+ 
+     Args:
+         x: Input tensor (float)
+         scale: Input scale factor (127.0 / max_val or similar)
+ 
+     Returns:
+         Quantized tensor (float dtype, but integer values)
+     """
+     # x_int = round(x * scale), then clamp to [-128, 127].
+     # Signed int8 is assumed for the general case; the first layer may use
+     # the uint8 variant below instead.
+     # round_half_up_ste keeps rounding consistent with the C implementation.
+     return clamp_ste(round_half_up_ste(x * scale), -128.0, 127.0)
+ 
+ def fake_quantize_input_uint8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+     """Quantize a float input to the uint8 range [0, 255]."""
+     return clamp_ste(round_half_up_ste(x * scale), 0.0, 255.0)
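For example, with a hypothetical scale of 254.0 (127 / 0.5, i.e. activations bounded by 0.5), the signed variant maps into [-128, 127], while the uint8 variant maps a [0, 1] input onto [0, 255]:

    fake_quantize_input(torch.tensor([0.25, -0.50, 0.49]), torch.tensor(254.0))
    # tensor([  64., -127.,  124.])
    fake_quantize_input_uint8(torch.tensor([0.0, 0.5, 1.0]), torch.tensor(255.0))
    # tensor([  0., 128., 255.])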
+ 
+ 
+ class FakeRequantizeSTE(torch.autograd.Function):
+     """Simulate C-style requantization: (acc * scale_int + round) >> shift.
+ 
+     This is the core of Integer-Only QAT.
+     """
+     @staticmethod
+     def forward(ctx, acc, scale_int, shift):
+         # acc: int32 accumulator (simulated as float)
+         # scale_int: integer scale
+         # shift: integer shift
+         #
+         # C logic:
+         #     int64_t temp = (int64_t)acc * scale_int;
+         #     temp += (1 << (shift - 1));  // round
+         #     output = temp >> shift;
+         ctx.save_for_backward(
+             torch.tensor(scale_int, dtype=torch.float32, device=acc.device),
+             torch.tensor(shift, dtype=torch.float32, device=acc.device))
+ 
+         scale_int_val = int(scale_int)
+         shift_val = int(shift)
+ 
+         if shift_val > 0:
+             round_const = 1 << (shift_val - 1)
+         else:
+             round_const = 0
+ 
+         # 1. Multiply in double precision. acc is float32 but holds integer
+         #    values; acc * scale_int can exceed 2^24 (16M), where float32
+         #    starts losing integer precision. float64 has a 53-bit
+         #    significand, which is sufficient well beyond 10^15.
+         val = acc.double() * scale_int_val
+ 
+         # 2. Add the rounding constant
+         val = val + round_const
+ 
+         # 3. Right shift == floor division by 2^shift
+         divisor = float(1 << shift_val)
+         val = torch.floor(val / divisor)
+ 
+         return val.float()
+ 
+     @staticmethod
+     def backward(ctx, grad_output):
+         # STE: treat the op as multiplication by the effective scale,
+         # out ≈ acc * (scale_int / 2^shift), so
+         # grad_acc = grad_out * (scale_int / 2^shift).
+         scale_int, shift = ctx.saved_tensors
+         effective_scale = scale_int / (2.0 ** shift)
+         return grad_output * effective_scale, None, None
+ 
+ 
+ def fake_requantize(acc: torch.Tensor, scale_int: int, shift: int) -> torch.Tensor:
+     """Simulate C-style requantization with STE."""
+     return FakeRequantizeSTE.apply(acc, float(scale_int), float(shift))
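A worked example with hypothetical parameters scale_int = 187 and shift = 10 (effective scale 187 / 1024 ≈ 0.183):

    acc = torch.tensor([1000.0])       # int32 accumulator value held in a float tensor
    fake_requantize(acc, 187, 10)      # tensor([183.]), since (1000 * 187 + 512) >> 10 == 183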
+ 
+ 
+ def fake_integer_gap(x: torch.Tensor) -> torch.Tensor:
+     """Simulate C-style Global Average Pooling: (sum + 32) >> 6 for 8x8.
+ 
+     For a generic power-of-two pool size H*W: (sum + (HW // 2)) >> log2(HW).
+     Non-power-of-two sizes fall back to half-up rounding of the mean.
+     """
+     # x is assumed to already hold integer values (int8 output of the
+     # previous layer). Shape: [N, C, H, W]
+ 
+     # 1. Sum over H, W
+     sum_val = x.sum(dim=(2, 3))  # [N, C]
+ 
+     # 2. Divide by the pool size with C-style rounding
+     pool_size = x.shape[2] * x.shape[3]
+ 
+     if (pool_size & (pool_size - 1)) == 0:
+         # Power of two: add the rounding constant, then shift
+         shift = int(math.log2(pool_size))
+         round_const = 1 << (shift - 1)
+         val = floor_ste((sum_val + round_const) / (1 << shift))
+     else:
+         # Generic pool size. The C side typically uses a (mult, shift)
+         # pair here; half-up rounding of the mean is a close approximation
+         # when those exact parameters are not available.
+         val = round_half_up_ste(sum_val / pool_size)
+ 
+     return val
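For instance, an 8x8 map whose 64 entries sum to 1000 is pooled to (1000 + 32) >> 6 = 16, matching the C kernel (hypothetical single-channel input):

    x = torch.full((1, 1, 8, 8), 15.0)
    x[0, 0, 0, 0] = 55.0               # 64 entries summing to 1000
    fake_integer_gap(x)                # tensor([[16.]])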
@@ -0,0 +1,225 @@
+ """Integer simulation functions for QAT.
+ 
+ These functions simulate C integer operations in PyTorch while allowing
+ gradient flow through the Straight-Through Estimator (STE).
+ 
+ C operations:
+     - round: (x + 0.5) truncation
+     - clamp: min/max saturation
+     - requantize: (acc * scale_int + round) >> shift
+ 
+ The Python simulation must match C bit-for-bit for QAT to be accurate.
+ 
+ Usage:
+     from potnn.quantize.integer_sim import (
+         round_ste, floor_ste, clamp_ste,
+         quantize_to_int8_ste, quantize_to_uint8_ste,
+         requantize_ste, compute_scale_params
+     )
+ """
+ 
+ import torch
+ import torch.nn as nn
+ 
+ 
+ class RoundSTE(torch.autograd.Function):
+     """Round with Straight-Through Estimator.
+ 
+     Forward: torch.round(x)
+     Backward: gradient passes through unchanged
+     """
+ 
+     @staticmethod
+     def forward(ctx, x):
+         return torch.round(x)
+ 
+     @staticmethod
+     def backward(ctx, grad_output):
+         return grad_output
+ 
+ 
+ class RoundHalfUpSTE(torch.autograd.Function):
+     """Half-up rounding with STE (C style).
+ 
+     Forward: floor(x + 0.5) - matches C's (x + 0.5) truncation
+     Backward: gradient passes through unchanged
+ 
+     This matches C integer rounding:
+         (int)(x + 0.5) for positive x
+         (x * scale + (1 << (shift-1))) >> shift
+     """
+ 
+     @staticmethod
+     def forward(ctx, x):
+         return torch.floor(x + 0.5)
+ 
+     @staticmethod
+     def backward(ctx, grad_output):
+         return grad_output
+ 
+ 
+ class FloorSTE(torch.autograd.Function):
+     """Floor with Straight-Through Estimator.
+ 
+     Forward: torch.floor(x)
+     Backward: gradient passes through unchanged
+ 
+     Used for integer division: a // b = floor(a / b)
+     """
+ 
+     @staticmethod
+     def forward(ctx, x):
+         return torch.floor(x)
+ 
+     @staticmethod
+     def backward(ctx, grad_output):
+         return grad_output
+ 
+ 
+ class ClampSTE(torch.autograd.Function):
+     """Clamp with Straight-Through Estimator.
+ 
+     Forward: torch.clamp(x, min_val, max_val)
+     Backward: gradient passes through unchanged
+ 
+     Note: Standard clamp has zero gradient outside [min, max].
+     The STE version allows gradient to flow for training stability.
+     """
+ 
+     @staticmethod
+     def forward(ctx, x, min_val, max_val):
+         return torch.clamp(x, min_val, max_val)
+ 
+     @staticmethod
+     def backward(ctx, grad_output):
+         return grad_output, None, None
+ 
+ 
+ def round_ste(x: torch.Tensor) -> torch.Tensor:
+     """Round with STE for gradient flow (C-style half-up)."""
+     return RoundHalfUpSTE.apply(x)  # floor(x + 0.5) - matches C
+ 
+ 
+ def round_half_up_ste(x: torch.Tensor) -> torch.Tensor:
+     """Half-up rounding with STE (C style).
+ 
+     This matches C integer rounding behavior.
+     Example: 2.5 -> 3, -2.5 -> -2
+     """
+     return RoundHalfUpSTE.apply(x)
+ 
+ 
+ def floor_ste(x: torch.Tensor) -> torch.Tensor:
+     """Floor with STE for gradient flow."""
+     return FloorSTE.apply(x)
+ 
+ 
+ def clamp_ste(x: torch.Tensor, min_val: float, max_val: float) -> torch.Tensor:
+     """Clamp with STE for gradient flow."""
+     return ClampSTE.apply(x, min_val, max_val)
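What makes these wrappers useful for QAT is the backward pass: the forward output is integer-valued, but the gradient is the identity, so training proceeds through the quantization step. A minimal illustration with made-up values:

    x = torch.tensor([0.4, 1.6, -2.3], requires_grad=True)
    y = round_ste(x)                   # tensor([ 0.,  2., -2.])
    y.sum().backward()
    x.grad                             # tensor([1., 1., 1.]) -- gradient passes straight through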
+ 
+ 
+ def quantize_to_int8_ste(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+     """Quantize tensor to int8 range with STE.
+ 
+     Forward:
+         x_int = round(x * scale)
+         x_int = clamp(x_int, -128, 127)
+ 
+     Backward: gradient passes through unchanged
+ 
+     Args:
+         x: Input tensor (float)
+         scale: Quantization scale (127.0 / max_activation)
+ 
+     Returns:
+         Tensor with int8 values (but float dtype for gradient)
+     """
+     x_scaled = x * scale
+     x_rounded = round_ste(x_scaled)
+     x_clamped = clamp_ste(x_rounded, -128.0, 127.0)
+     return x_clamped
+ 
+ 
+ def quantize_to_uint8_ste(x: torch.Tensor, scale: float = 256.0) -> torch.Tensor:
+     """Quantize tensor to uint8 range with STE.
+ 
+     For the first layer: input [0, 1] -> [0, 255]
+ 
+     Args:
+         x: Input tensor (float, assumed [0, 1] normalized)
+         scale: Quantization scale (default 256 for /256 normalization)
+ 
+     Returns:
+         Tensor with uint8 values (but float dtype for gradient)
+     """
+     x_scaled = x * scale
+     x_rounded = round_ste(x_scaled)
+     x_clamped = clamp_ste(x_rounded, 0.0, 255.0)
+     return x_clamped
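Typical usage, assuming a hypothetical activation bound of 4.0 for the int8 path and a [0, 1] normalized image for the uint8 path:

    act = torch.tensor([0.5, -3.9, 7.0])
    quantize_to_int8_ste(act, torch.tensor(127.0 / 4.0))   # tensor([  16., -124.,  127.])
    img = torch.tensor([0.0, 0.5, 1.0])
    quantize_to_uint8_ste(img)                             # tensor([  0., 128., 255.])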
+ 
+ 
+ def requantize_ste(acc: torch.Tensor, scale_int: int, shift: int) -> torch.Tensor:
+     """Simulate C requantization with STE.
+ 
+     C code:
+         out = ((int64_t)acc * scale_int + (1 << (shift-1))) >> shift
+ 
+     This is equivalent to:
+         out = floor((acc * scale_int + round_const) / divisor)
+ 
+     where round_const = 1 << (shift-1) and divisor = 1 << shift.
+ 
+     Args:
+         acc: Accumulator tensor (int32-range values in a float tensor)
+         scale_int: Integer scale factor
+         shift: Right-shift amount
+ 
+     Returns:
+         Requantized tensor (int32-range values in a float tensor)
+     """
+     if shift > 0:
+         round_const = 1 << (shift - 1)
+     else:
+         round_const = 0
+ 
+     divisor = float(1 << shift)
+ 
+     # Note: acc * scale_int must stay below 2^24 for float32 to represent it
+     # exactly; cast to double first if larger accumulators are expected.
+     numerator = acc * float(scale_int) + float(round_const)
+     result = floor_ste(numerator / divisor)
+ 
+     return result
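The forward result can be cross-checked against plain Python integer arithmetic, which evaluates the C expression exactly (hypothetical values, kept small enough that float32 stays exact):

    acc, scale_int, shift = 12345, 300, 12
    (acc * scale_int + (1 << (shift - 1))) >> shift                   # 904
    requantize_ste(torch.tensor([float(acc)]), scale_int, shift)      # tensor([904.])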
+ 
+ 
+ def compute_scale_params(combined_scale: float, target_range: tuple = (64, 512)) -> tuple:
+     """Compute an integer scale and shift from a float scale.
+ 
+     Find (scale_int, shift) such that:
+         scale_int / (1 << shift) ≈ combined_scale
+         target_range[0] <= scale_int <= target_range[1]
+ 
+     Args:
+         combined_scale: Float scale value (alpha * act_scale / prev_act_scale)
+         target_range: Target range for scale_int (default 64-512 to match export.py)
+ 
+     Returns:
+         (scale_int, shift) tuple
+     """
+     if combined_scale == 0:
+         return 0, 0
+ 
+     min_scale, max_scale = target_range
+     shift = 0
+     scale_magnitude = abs(combined_scale)
+ 
+     while scale_magnitude < min_scale and shift < 24:
+         scale_magnitude *= 2
+         shift += 1
+ 
+     while scale_magnitude > max_scale and shift > 0:
+         scale_magnitude /= 2
+         shift -= 1
+ 
+     scale_int = round(combined_scale * (1 << shift))
+ 
+     return scale_int, shift
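A worked example: for a combined scale of 0.1, the loop doubles the magnitude ten times to reach the 64-512 window, giving shift = 10 and scale_int = round(0.1 * 1024) = 102, which reproduces the float scale to within about 0.4%:

    scale_int, shift = compute_scale_params(0.1)    # (102, 10)
    scale_int / (1 << shift)                        # 0.099609375, close to 0.1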