eiporion-0.1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,24 @@
+ Metadata-Version: 2.4
+ Name: eiporion
+ Version: 0.1.0
+ Summary: INT8 linear layers with DQT SR/MB-SR and bitsandbytes acceleration.
+ Requires-Python: >=3.12
+ Description-Content-Type: text/markdown
+ Requires-Dist: bitsandbytes<0.46.0,>=0.45.0
+ Requires-Dist: torch
+
+ # Eiporion
+
+ INT8 linear layers with DQT SR/MB-SR and bitsandbytes acceleration.
+
+ ## Install
+
+ ```bash
+ pip install eiporion
+ ```
+
+ ## Quick use
+
+ ```python
+ from eiporion import BitLinear
+ ```
@@ -0,0 +1,15 @@
+ # Eiporion
+
+ INT8 linear layers with DQT SR/MB-SR and bitsandbytes acceleration.
+
+ ## Install
+
+ ```bash
+ pip install eiporion
+ ```
+
+ ## Quick use
+
+ ```python
+ from eiporion import BitLinear
+ ```
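The quick-use snippet stops at the import. Based on the modules shown further down in this diff, a minimal training-step sketch looks like the following (the layer sizes and loss are made up for illustration, and everything is moved to CUDA because bitsandbytes' 8-bit optimizer states are CUDA-only):

```python
import torch
from torch import nn
from eiporion import BitLinear, EiporionOptim, collect_bitlinear_modules

# A dense layer feeds BitLinear so its input requires grad; Int8LinearFn.backward
# only runs (and only then stashes the int_weight gradient) when that is the case.
model = nn.Sequential(
    nn.Linear(64, 64),
    BitLinear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 4),
).cuda()

optim = EiporionOptim(
    model.parameters(),                            # dense params -> bnb AdamW8bit
    lr=1e-3,
    bit_modules=collect_bitlinear_modules(model),  # int8 weights -> DQT-SR updates
)

x = torch.randn(8, 64, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()
optim.step()       # consumes the stashed int_weight grads, then stochastically rounds them in
optim.zero_grad()
```

`EiporionOptimSR` is a drop-in replacement that uses plain unbiased stochastic rounding instead of the momentum-biased variant.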
@@ -0,0 +1,11 @@
+ from .eiporionkernels import quantize_fp_to_int8
+ from .bitLinear import BitLinear, collect_bitlinear_modules
+ from .eiporionoptim import EiporionOptim, EiporionOptimSR
+
+ __all__ = [
+     "BitLinear",
+     "EiporionOptim",
+     "EiporionOptimSR",
+     "collect_bitlinear_modules",
+     "quantize_fp_to_int8",
+ ]
@@ -0,0 +1,108 @@
+ import math
+
+ import torch
+ from torch import nn
+
+ from .eiporionkernels import (
+     Int8LinearFn,
+     consume_bit_grad,
+     next_bit_handle,
+     register_bit_handle,
+     release_bit_handle,
+ )
+
+
+ def _init_int8_weight(out_features: int, in_features: int):
+     """Kaiming-uniform init, quantise to int8.
+
+     Matches bnb ``int8_vectorwise_quant``: per-row max_abs / 127 scale.
+     W_int8 = clip(round(W / scale), -127, 127) with scale = max_abs_per_row / 127.
+     W_eff = int_weight * weight_scale ≈ original kaiming weight.
+     """
+     weight = torch.empty((out_features, in_features), dtype=torch.float32)
+     nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+     w = weight.float()
+     scale_per_row = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 127.0
+     q = torch.round(w / scale_per_row).clamp(-127, 127).to(torch.int8)
+     return q.contiguous(), scale_per_row.squeeze(1).contiguous()
+
+
+ class BitLinear(nn.Module):
+     """INT8 linear layer matching bnb ``Int8Params`` + DQT paper.
+
+     * ``int_weight`` — int8 buffer ``[O, K]``, the quantised weight (bnb's CB).
+     * ``weight_scale`` — float buffer ``[O]``, per-row max_abs/127 (bnb's SCB/127).
+     * Forward: ``W_eff = int_weight * weight_scale``, then standard matmul.
+     * Gradients for ``int_weight`` are stashed in ``_BIT_GRAD_CACHE`` and consumed
+       by :class:`EiporionOptim` for DQT stochastic rounding.
+     """
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = False,
+     ) -> None:
+         super().__init__()
+         self.in_features = int(in_features)
+         self.out_features = int(out_features)
+
+         int_init, scale_init = _init_int8_weight(
+             out_features=self.out_features, in_features=self.in_features
+         )
+         self.register_buffer("int_weight", int_init, persistent=True)
+         self.register_buffer("weight_scale", scale_init, persistent=True)
+         self.register_buffer(
+             "_bit_handle",
+             torch.tensor(next_bit_handle(), dtype=torch.int64),
+             persistent=True,
+         )
+         self._registered_handle = int(self._bit_handle.item())
+         self.register_load_state_dict_post_hook(self._post_load_state_dict)
+
+         if bias:
+             self.bias = nn.Parameter(torch.zeros(out_features))
+         else:
+             self.register_parameter("bias", None)
+
+     @torch.no_grad()
+     def reset_int8_(self) -> None:
+         int_init, scale_init = _init_int8_weight(
+             out_features=self.out_features, in_features=self.in_features
+         )
+         self.int_weight.copy_(int_init.to(device=self.int_weight.device))
+         self.weight_scale.copy_(scale_init.to(device=self.weight_scale.device))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if x.shape[-1] != self.in_features:
+             raise ValueError(
+                 f"expected input last dim {self.in_features}, got {x.shape[-1]}"
+             )
+         if self.int_weight.device != x.device:
+             raise RuntimeError(
+                 "BitLinear input and int_weight must be on same device. "
+                 "Move module with model.to(device) before forward."
+             )
+         x2d = x.reshape(-1, self.in_features)
+         out2d = Int8LinearFn.apply(
+             x2d,
+             self.int_weight,
+             self.weight_scale,
+             self.bias,
+             int(self._bit_handle.item()),
+         )
+         return out2d.view(*x.shape[:-1], self.out_features)
+
+     def consume_weight_grad(self) -> torch.Tensor | None:
+         return consume_bit_grad(int(self._bit_handle.item()))
+
+     def _post_load_state_dict(self, module, incompatible_keys) -> None:
+         del module, incompatible_keys
+         new_handle = register_bit_handle(int(self._bit_handle.item()))
+         if new_handle != self._registered_handle:
+             release_bit_handle(self._registered_handle)
+             self._registered_handle = new_handle
+
+
+ def collect_bitlinear_modules(module: nn.Module) -> list[BitLinear]:
+     return [m for m in module.modules() if isinstance(m, BitLinear)]
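To make the buffer layout described in the `BitLinear` docstring concrete, here is a small inspection sketch (shapes are arbitrary; it stays on CPU, which only exercises the BF16 fallback path):

```python
import torch
from eiporion import BitLinear

layer = BitLinear(in_features=16, out_features=8, bias=True)

print(layer.int_weight.shape, layer.int_weight.dtype)      # torch.Size([8, 16]) torch.int8
print(layer.weight_scale.shape, layer.weight_scale.dtype)  # torch.Size([8]) torch.float32

# Effective weight used by the forward pass: per-row scale times the int8 matrix.
w_eff = layer.int_weight.float() * layer.weight_scale.unsqueeze(1)

x = torch.randn(4, 16, requires_grad=True)  # requires_grad so Int8LinearFn.backward runs
y = layer(x)
y.sum().backward()

# int_weight is a buffer, so it has no .grad; its gradient is stashed per handle
# and handed to the optimizer via consume_weight_grad().
g = layer.consume_weight_grad()
print(g.shape, g.dtype)                     # torch.Size([8, 16]) torch.bfloat16
```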
@@ -0,0 +1,227 @@
+ import torch
+
+
+ # ---------------------------------------------------------------------------
+ # Handle registry
+ # ---------------------------------------------------------------------------
+ _NEXT_HANDLE = 1
+ _REGISTERED_HANDLES: set[int] = set()
+ _BIT_GRAD_CACHE: dict[int, torch.Tensor] = {}
+
+
+ def next_bit_handle() -> int:
+     global _NEXT_HANDLE
+     while _NEXT_HANDLE in _REGISTERED_HANDLES:
+         _NEXT_HANDLE += 1
+     handle = _NEXT_HANDLE
+     _REGISTERED_HANDLES.add(handle)
+     _NEXT_HANDLE += 1
+     return handle
+
+
+ def register_bit_handle(handle: int) -> int:
+     global _NEXT_HANDLE
+     handle = int(handle)
+     if handle <= 0:
+         raise ValueError(f"handle must be > 0, got {handle}")
+     _REGISTERED_HANDLES.add(handle)
+     _BIT_GRAD_CACHE.pop(handle, None)
+     if handle >= _NEXT_HANDLE:
+         _NEXT_HANDLE = handle + 1
+     return handle
+
+
+ def release_bit_handle(handle: int) -> None:
+     handle = int(handle)
+     _REGISTERED_HANDLES.discard(handle)
+     _BIT_GRAD_CACHE.pop(handle, None)
+
+
+ def consume_bit_grad(handle: int) -> torch.Tensor | None:
+     return _BIT_GRAD_CACHE.pop(int(handle), None)
+
+
+ # ---------------------------------------------------------------------------
+ # bitsandbytes backend
+ # ---------------------------------------------------------------------------
+ _BNB_F = None
+ _BNB_FMT = "col_ampere"
+
+ try:
+     import bitsandbytes.functional as _BNB_F
+
+     if torch.cuda.is_available():
+         cc = torch.cuda.get_device_capability()
+         if cc[0] >= 8:
+             _BNB_FMT = "col_ampere"
+         elif cc[0] == 7 and cc[1] >= 5:
+             _BNB_FMT = "col_turing"
+         else:
+             _BNB_FMT = "col32"
+ except ImportError:
+     pass
+
+ # Weight-transform cache (handle → (CxB, SB, version))
+ _BNB_WCACHE: dict[int, tuple] = {}
+ _BNB_WVERSION: dict[int, int] = {}
+
+
+ def _cached_weight_transform(
+     handle: int, int_weight: torch.Tensor, weight_scale: torch.Tensor
+ ):
+     """Return (CxB, SB) for the current weight; re-transform only if stale."""
+     cur_ver = _BNB_WVERSION.get(handle, 0)
+     entry = _BNB_WCACHE.get(handle)
+     if entry is not None:
+         CxB, SB, cached_ver = entry
+         if cached_ver == cur_ver:
+             return CxB, SB
+
+     # bitsandbytes expects fp16 input for int8_vectorwise_quant
+     w_fp16 = int_weight.to(torch.float16) * weight_scale.to(torch.float16).unsqueeze(1)
+     w_q, _w_s, _ = _BNB_F.int8_vectorwise_quant(w_fp16)
+     CxB, SB = _BNB_F.transform(w_q, _BNB_FMT)
+     _BNB_WCACHE[handle] = (CxB, SB, cur_ver)
+     return CxB, SB
+
+
+ def _invalidate_weight_cache(handle: int):
+     """Call after int_weight is modified by stochastic rounding."""
+     _BNB_WVERSION[handle] = _BNB_WVERSION.get(handle, 0) + 1
+
+
+ # ---------------------------------------------------------------------------
+ # Weight quantisation (static, used at init / reset)
+ # ---------------------------------------------------------------------------
+
+
+ def quantize_fp_to_int8(weight: torch.Tensor, eps: float = 1e-8):
+     if weight.ndim != 2:
+         raise ValueError(
+             f"weight must be 2D [out_features, in_features], got {tuple(weight.shape)}"
+         )
+     w = weight.float()
+     scale = w.abs().amax(dim=1, keepdim=True).clamp_min(float(eps)) / 127.0
+     q = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
+     return q.contiguous(), scale.squeeze(1).contiguous()
+
+
+ # ---------------------------------------------------------------------------
+ # Int8LinearFn
+ # ---------------------------------------------------------------------------
+
+
+ class Int8LinearFn(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx,
+         x2d: torch.Tensor,  # [N, K] (autocast doesn't cover custom Fns)
+         int_weight: torch.Tensor,  # [O, K] int8
+         weight_scale: torch.Tensor,  # [O] float (fixed buffer, not trained)
+         bias: torch.Tensor | None,  # [O] float
+         handle: int,
+     ):
+         return _forward_bf16(ctx, x2d, int_weight, weight_scale, bias, handle)  # only Path B is wired in
+
+     @staticmethod
+     def backward(ctx, grad_out: torch.Tensor):
+         return _backward_impl(ctx, grad_out)
+
+
+ # ---------------------------------------------------------------------------
+ # Path A: bitsandbytes cuBLASLt INT8 Tensor Cores
+ # ---------------------------------------------------------------------------
+
+
+ def _forward_bnb(ctx, x2d, int_weight, weight_scale, bias, handle):
+     # 1. Quantise activation — bitsandbytes works in fp16
+     x_fp16 = x2d.half()
+     CA, SCA, _ = _BNB_F.int8_vectorwise_quant(x_fp16)
+
+     # 2. Transform activation to col32 layout
+     C32A, SA = _BNB_F.transform(CA, "col32")
+
+     # 3. Get cached weight transform
+     CxB, SB = _cached_weight_transform(handle, int_weight, weight_scale)
+
+     # 4. INT8 matmul via cuBLASLt
+     out_i32, _ = _BNB_F.igemmlt(C32A, CxB, SA, SB)
+
+     # 5. Dequantise — output must match activation scale × weight scale
+     out = _BNB_F.mm_dequant(out_i32, SCA, weight_scale.half().unsqueeze(1))
+
+     if bias is not None:
+         out.add_(bias.half())
+
+     ctx.save_for_backward(x2d.to(torch.bfloat16), int_weight, weight_scale)
+     ctx.handle = int(handle)
+     ctx.has_bias = bias is not None
+     ctx.input_dtype = x2d.dtype
+     return out.to(dtype=x2d.dtype)
+
+
+ # ---------------------------------------------------------------------------
+ # Path B: BF16 fallback (torch.matmul on BF16 Tensor Cores, works everywhere)
+ # ---------------------------------------------------------------------------
+
+
+ def _forward_bf16(ctx, x2d, int_weight, weight_scale, bias, handle):
+     x2d_bf16 = x2d.to(torch.bfloat16)
+     w_bf16 = int_weight.to(torch.bfloat16) * weight_scale.to(torch.bfloat16).unsqueeze(
+         1
+     )
+     out = torch.matmul(x2d_bf16, w_bf16.t())
+     if bias is not None:
+         out.add_(bias.to(torch.bfloat16))  # in-place add requires a matching dtype
+     ctx.save_for_backward(x2d_bf16, int_weight, weight_scale)
+     ctx.handle = int(handle)
+     ctx.has_bias = bias is not None
+     ctx.input_dtype = x2d.dtype
+     return out.to(dtype=x2d.dtype)  # cast back to the caller's dtype, as Path A does
+
+
+ # ---------------------------------------------------------------------------
+ # Common backward (BF16 matmul — correct for both forward paths)
+ # ---------------------------------------------------------------------------
+
+
+ def _backward_impl(ctx, grad_out):
+     x2d_bf16, int_weight, weight_scale = ctx.saved_tensors
+     go_bf16 = grad_out.to(torch.bfloat16)
+     ws_bf16 = weight_scale.to(torch.bfloat16)
+
+     w_bf16 = int_weight.to(torch.bfloat16) * ws_bf16.unsqueeze(1)
+
+     grad_in = torch.matmul(go_bf16, w_bf16)
+
+     grad_w = torch.matmul(go_bf16.t(), x2d_bf16)
+     cached = _BIT_GRAD_CACHE.get(ctx.handle)
+     if cached is None:
+         _BIT_GRAD_CACHE[ctx.handle] = grad_w.to(dtype=torch.bfloat16)
+     else:
+         cached.add_(grad_w.to(dtype=torch.bfloat16))
+
+     grad_bias = go_bf16.sum(dim=0).to(dtype=grad_out.dtype) if ctx.has_bias else None
+
+     # int_weight and weight_scale are fixed buffers — no gradient flows to them.
+     # int_weight is updated via DQT in EiporionOptim; weight_scale stays fixed.
+     return (
+         grad_in.to(dtype=grad_out.dtype),
+         None,  # int_weight
+         None,  # weight_scale
+         grad_bias,
+         None,  # handle
+     )
+
+
+ # ---------------------------------------------------------------------------
+ # INT8 weight update
+ # ---------------------------------------------------------------------------
+
+
+ @torch.no_grad()
+ def update_int8_weight_(int_weight: torch.Tensor, delta_q: torch.Tensor) -> None:
+     """In-place int8 weight update: W += delta_q, clamped to [-127, 127]."""
+     result = int_weight.to(torch.int16) + delta_q.to(torch.int16)
+     result.clamp_(-127, 127)
+     int_weight.copy_(result.to(torch.int8))
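As a quick sanity check on the per-row scheme in `quantize_fp_to_int8` and the saturating update in `update_int8_weight_`, the following sketch (arbitrary shapes and values) round-trips a weight and shows that updates never leave the int8 range:

```python
import torch
from eiporion import quantize_fp_to_int8
from eiporion.eiporionkernels import update_int8_weight_

w = torch.randn(4, 8)
q, scale = quantize_fp_to_int8(w)     # q: int8 [4, 8], scale: float32 [4]

# Round-to-nearest error is at most half a quantisation step per row.
w_hat = q.float() * scale.unsqueeze(1)
print(bool((w - w_hat).abs().max() <= scale.max() / 2 + 1e-6))   # True

# Saturating in-place update: entries are clamped to [-127, 127].
delta = torch.full_like(q, 100, dtype=torch.int32)
update_int8_weight_(q, delta)
print(bool((q <= 127).all() and (q >= -127).all()))              # True
```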
@@ -0,0 +1,258 @@
+ import torch
+ from bitsandbytes.optim import AdamW8bit
+
+ from .bitLinear import BitLinear
+ from .eiporionkernels import update_int8_weight_, _invalidate_weight_cache
+
+
+ # ---- 8-bit blockwise helpers with mu-law companding ----
+ _MU = 255.0
+ _1_OVER_LN1P_MU = 1.0 / 5.545177444479562  # 1 / ln(1 + 255)
+
+
+ def _pad_to_block_multiple(t: torch.Tensor, blocksize: int, num_blocks: int):
+     padded_size = num_blocks * blocksize
+     if padded_size > t.numel():
+         return torch.cat(
+             [t, torch.zeros(padded_size - t.numel(), device=t.device, dtype=t.dtype)]
+         )
+     return t
+
+
+ def _quantize_blockwise_signed(x: torch.Tensor, blocksize: int):
+     x_flat = x.float().contiguous().view(-1)
+     numel = x_flat.numel()
+     num_blocks = (numel + blocksize - 1) // blocksize
+     x_padded = _pad_to_block_multiple(x_flat, blocksize, num_blocks)
+     x_blocks = x_padded.view(num_blocks, blocksize)
+     absmax = x_blocks.abs().amax(dim=1).clamp_min(1e-12)
+
+     # mu-law companding: normalise → compress → quantise
+     x_norm = (x_blocks / absmax.unsqueeze(1)).clamp(-1.0, 1.0)
+     x_comp = torch.sign(x_norm) * torch.log1p(_MU * x_norm.abs()) * _1_OVER_LN1P_MU
+
+     q_blocks = torch.round(x_comp * 127.0).clamp(-127, 127).to(torch.int16)
+     q_flat = (q_blocks.view(-1)[:numel] + 128).to(torch.uint8)
+     return q_flat.view_as(x), absmax
+
+
+ def _dequantize_blockwise_signed(q, absmax, blocksize, shape):
+     q_flat = q.contiguous().view(-1)
+     numel = q_flat.numel()
+     num_blocks = absmax.numel()
+     q_padded = _pad_to_block_multiple(q_flat, blocksize, num_blocks)
+     q_blocks = q_padded.view(num_blocks, blocksize).float()
+
+     # inverse mu-law
+     y = ((q_blocks - 128.0) / 127.0).clamp(-1.0, 1.0)
+     x_norm = torch.sign(y) * (torch.exp(y.abs() / _1_OVER_LN1P_MU) - 1.0) / _MU
+     out_blocks = x_norm * absmax.unsqueeze(1)
+     return out_blocks.view(-1)[:numel].view(shape)
+
+
+ def _quantize_blockwise_unsigned(x: torch.Tensor, blocksize: int):
+     x_flat = x.float().contiguous().view(-1)
+     numel = x_flat.numel()
+     num_blocks = (numel + blocksize - 1) // blocksize
+     x_padded = _pad_to_block_multiple(x_flat, blocksize, num_blocks)
+     x_blocks = x_padded.view(num_blocks, blocksize)
+     absmax = x_blocks.amax(dim=1).clamp_min(1e-12)
+
+     # mu-law companding (unsigned: [0, 1] range)
+     x_norm = (x_blocks / absmax.unsqueeze(1)).clamp(0.0, 1.0)
+     x_comp = torch.log1p(_MU * x_norm) * _1_OVER_LN1P_MU
+
+     q_blocks = torch.round(x_comp * 255.0).clamp(0, 255)
+     q_flat = q_blocks.view(-1)[:numel].to(torch.uint8)
+     return q_flat.view_as(x), absmax
+
+
+ def _dequantize_blockwise_unsigned(q, absmax, blocksize, shape):
+     q_flat = q.contiguous().view(-1)
+     numel = q_flat.numel()
+     num_blocks = absmax.numel()
+     q_padded = _pad_to_block_multiple(q_flat, blocksize, num_blocks)
+     q_blocks = q_padded.view(num_blocks, blocksize).float()
+
+     y = (q_blocks / 255.0).clamp(0.0, 1.0)
+     x_norm = (torch.exp(y / _1_OVER_LN1P_MU) - 1.0) / _MU
+     out_blocks = x_norm * absmax.unsqueeze(1)
+     return out_blocks.view(-1)[:numel].view(shape)
+
+
+ class EiporionOptim(AdamW8bit):
+     """bnb AdamW8bit for dense params + 8-bit AdamW + DQT-SR for INT8 weights."""
+
+     def __init__(
+         self,
+         params,
+         lr=1e-3,
+         betas=(0.9, 0.999),
+         eps=1e-8,
+         weight_decay=0.1,
+         bit_modules=None,
+         block_size=256,
+         sr_bias_scale=0.15,
+     ):
+         super().__init__(
+             params,
+             lr=lr,
+             betas=betas,
+             eps=eps,
+             weight_decay=weight_decay,
+         )
+         self.bit_modules = list(bit_modules) if bit_modules is not None else []
+         self._bit_handles = (
+             [int(m._bit_handle.item()) for m in self.bit_modules]
+             if bit_modules is not None
+             else []
+         )
+         self._bit_state: dict[int, dict[str, torch.Tensor | int]] = {}
+         self.block_size = int(block_size)
+         self.sr_bias_scale = float(sr_bias_scale)
+
+     def add_bit_modules(self, modules) -> None:
+         for module in modules:
+             if not isinstance(module, BitLinear):
+                 raise TypeError(f"expected BitLinear, got {type(module).__name__}")
+             if module not in self.bit_modules:
+                 self.bit_modules.append(module)
+                 self._bit_handles.append(int(module._bit_handle.item()))
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         loss = super().step(closure)
+
+         if not self.bit_modules:
+             return loss
+
+         group = self.param_groups[0]
+         lr = float(group["lr"])
+         beta1, beta2 = group["betas"]
+         eps = float(group["eps"])
+         wd = float(group["weight_decay"])
+         bs = self.block_size
+
+         for module, handle in zip(self.bit_modules, self._bit_handles):
+             g = module.consume_weight_grad()
+             if g is None:
+                 continue
+             g = g.float().contiguous()
+
+             state = self._bit_state.setdefault(handle, {})
+             if "step" not in state:
+                 state["step"] = 0
+                 # 8-bit m, v (matching AdamW8bit scheme)
+                 m_q, m_absmax = _quantize_blockwise_signed(
+                     torch.zeros_like(g, dtype=torch.float32), bs
+                 )
+                 v_q, v_absmax = _quantize_blockwise_unsigned(
+                     torch.zeros_like(g, dtype=torch.float32), bs
+                 )
+                 state["m_q"] = m_q
+                 state["m_absmax"] = m_absmax
+                 state["v_q"] = v_q
+                 state["v_absmax"] = v_absmax
+                 # residual also 8-bit blockwise, same scheme as m
+                 r_q, r_absmax = _quantize_blockwise_signed(
+                     torch.zeros_like(g, dtype=torch.float32), bs
+                 )
+                 state["r_q"] = r_q
+                 state["r_absmax"] = r_absmax
+
+             m = _dequantize_blockwise_signed(
+                 state["m_q"], state["m_absmax"], bs, g.shape
+             )
+             v = _dequantize_blockwise_unsigned(
+                 state["v_q"], state["v_absmax"], bs, g.shape
+             )
+             residual = _dequantize_blockwise_signed(
+                 state["r_q"], state["r_absmax"], bs, g.shape
+             )
+             state["step"] += 1
+             t = state["step"]
+
+             m.mul_(beta1).add_(g, alpha=1.0 - beta1)
+             v.mul_(beta2).addcmul_(g, g, value=1.0 - beta2)
+             m_hat = m / (1.0 - beta1**t)
+             v_hat = v / (1.0 - beta2**t)
+
+             ws = module.weight_scale.float().unsqueeze(1).clamp_min(eps)
+             iw = module.int_weight.float()
+             adam_term = m_hat / (v_hat.sqrt() + eps)
+             delta_w_eff = -lr * (adam_term + wd * iw * ws)
+             residual = residual + delta_w_eff / ws
+
+             abs_res = residual.abs()
+             base = torch.floor(abs_res)
+             frac = abs_res - base
+             # Momentum-biased SR
+             bias = torch.tanh(adam_term) * self.sr_bias_scale * torch.sign(residual)
+             frac_biased = (frac + bias).clamp(0.0, 1.0)
+             extra = (torch.rand_like(frac) < frac_biased).float()
+             delta_q = (torch.sign(residual) * (base + extra)).to(torch.int32)
+
+             if torch.any(delta_q != 0):
+                 update_int8_weight_(module.int_weight, delta_q)
+                 residual = residual - delta_q.float()
+                 _invalidate_weight_cache(handle)
+
+             r_q, r_absmax = _quantize_blockwise_signed(residual, bs)
+             state["r_q"] = r_q
+             state["r_absmax"] = r_absmax
+
+             m_q, m_absmax = _quantize_blockwise_signed(m, bs)
+             v_q, v_absmax = _quantize_blockwise_unsigned(v, bs)
+             state["m_q"] = m_q
+             state["m_absmax"] = m_absmax
+             state["v_q"] = v_q
+             state["v_absmax"] = v_absmax
+
+         return loss
+
+     def state_dict(self):
+         state = super().state_dict()
+         state["bit_state"] = {
+             int(handle): {
+                 key: value.detach().cpu() if torch.is_tensor(value) else value
+                 for key, value in per_handle.items()
+             }
+             for handle, per_handle in self._bit_state.items()
+         }
+         return state
+
+     def load_state_dict(self, state_dict):
+         state_dict = dict(state_dict)
+         bit_state = state_dict.pop("bit_state", {})
+         super().load_state_dict(state_dict)
+         self._bit_state = {}
+         for handle, per_handle in bit_state.items():
+             self._bit_state[int(handle)] = {
+                 key: value.clone() if torch.is_tensor(value) else value
+                 for key, value in per_handle.items()
+             }
+
+
+ class EiporionOptimSR(EiporionOptim):
+     """EiporionOptim with unbiased stochastic rounding (no momentum bias)."""
+
+     def __init__(
+         self,
+         params,
+         lr=1e-3,
+         betas=(0.9, 0.999),
+         eps=1e-8,
+         weight_decay=0.1,
+         bit_modules=None,
+         block_size=256,
+     ):
+         super().__init__(
+             params,
+             lr=lr,
+             betas=betas,
+             eps=eps,
+             weight_decay=weight_decay,
+             bit_modules=bit_modules,
+             block_size=block_size,
+             sr_bias_scale=0.0,
+         )
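The mu-law helpers above keep each optimizer-state tensor in 8 bits per element plus one absmax per block. A round-trip sketch (block size and values are arbitrary; the helpers are module-internal, so the import path just mirrors this file layout):

```python
import torch
from eiporion.eiporionoptim import (
    _dequantize_blockwise_signed,
    _quantize_blockwise_signed,
)

m = torch.randn(8, 256) * 1e-3             # momentum-sized values
q, absmax = _quantize_blockwise_signed(m, 256)

print(q.dtype, q.shape, absmax.shape)      # torch.uint8 torch.Size([8, 256]) torch.Size([8])

m_hat = _dequantize_blockwise_signed(q, absmax, 256, m.shape)

# mu-law companding spends more of the 8-bit budget near zero, so small values
# (the common case for momentum) keep low relative error.
print((m - m_hat).abs().max().item())
```

In `step()`, these helpers dequantise `m`, `v`, and the stochastic-rounding residual, the Adam update runs in float32, and the state is re-quantised before returning.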
@@ -0,0 +1,24 @@
+ Metadata-Version: 2.4
+ Name: eiporion
+ Version: 0.1.0
+ Summary: INT8 linear layers with DQT SR/MB-SR and bitsandbytes acceleration.
+ Requires-Python: >=3.12
+ Description-Content-Type: text/markdown
+ Requires-Dist: bitsandbytes<0.46.0,>=0.45.0
+ Requires-Dist: torch
+
+ # Eiporion
+
+ INT8 linear layers with DQT SR/MB-SR and bitsandbytes acceleration.
+
+ ## Install
+
+ ```bash
+ pip install eiporion
+ ```
+
+ ## Quick use
+
+ ```python
+ from eiporion import BitLinear
+ ```
@@ -0,0 +1,11 @@
+ README.md
+ pyproject.toml
+ eiporion/__init__.py
+ eiporion/bitLinear.py
+ eiporion/eiporionkernels.py
+ eiporion/eiporionoptim.py
+ eiporion.egg-info/PKG-INFO
+ eiporion.egg-info/SOURCES.txt
+ eiporion.egg-info/dependency_links.txt
+ eiporion.egg-info/requires.txt
+ eiporion.egg-info/top_level.txt
@@ -0,0 +1,2 @@
+ bitsandbytes<0.46.0,>=0.45.0
+ torch
@@ -0,0 +1 @@
+ eiporion
@@ -0,0 +1,14 @@
+ [project]
+ name = "eiporion"
+ version = "0.1.0"
+ description = "INT8 linear layers with DQT SR/MB-SR and bitsandbytes acceleration."
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = ["bitsandbytes>=0.45.0,<0.46.0", "torch"]
+
+ [build-system]
+ requires = ["setuptools>=68", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [tool.setuptools]
+ packages = ["eiporion"]
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+