eiporion-0.1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eiporion-0.1.0/PKG-INFO +24 -0
- eiporion-0.1.0/README.md +15 -0
- eiporion-0.1.0/eiporion/__init__.py +11 -0
- eiporion-0.1.0/eiporion/bitLinear.py +108 -0
- eiporion-0.1.0/eiporion/eiporionkernels.py +227 -0
- eiporion-0.1.0/eiporion/eiporionoptim.py +258 -0
- eiporion-0.1.0/eiporion.egg-info/PKG-INFO +24 -0
- eiporion-0.1.0/eiporion.egg-info/SOURCES.txt +11 -0
- eiporion-0.1.0/eiporion.egg-info/dependency_links.txt +1 -0
- eiporion-0.1.0/eiporion.egg-info/requires.txt +2 -0
- eiporion-0.1.0/eiporion.egg-info/top_level.txt +1 -0
- eiporion-0.1.0/pyproject.toml +14 -0
- eiporion-0.1.0/setup.cfg +4 -0

eiporion-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,24 @@
+Metadata-Version: 2.4
+Name: eiporion
+Version: 0.1.0
+Summary: INT8 linear layers with DQT SR/MB-SR and bitsandbytes acceleration.
+Requires-Python: >=3.12
+Description-Content-Type: text/markdown
+Requires-Dist: bitsandbytes<0.46.0,>=0.45.0
+Requires-Dist: torch
+
+# Eiporion
+
+INT8 linear layers with DQT SR/MB-SR and bitsandbytes acceleration.
+
+## Install
+
+```bash
+pip install eiporion
+```
+
+## Quick use
+
+```python
+from eiporion import BitLinear
+```
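
The quick-use snippet above stops at the import. As a hedged sketch of how the published modules below are meant to be combined, the following toy training step wires a `BitLinear` body into a small model and hands its INT8 weights to `EiporionOptim`; the model shape, dtype, and CUDA device are illustrative assumptions, not part of the package.

```python
import torch
from torch import nn

from eiporion import BitLinear, EiporionOptim, collect_bitlinear_modules

device = "cuda"  # assumption: bnb's AdamW8bit (EiporionOptim's base class) expects CUDA params

# Illustrative toy model: dense projections around an INT8 body.
# Everything is kept in bf16 because BitLinear's fallback path computes in bf16.
model = nn.Sequential(
    nn.Linear(64, 256),
    nn.ReLU(),
    BitLinear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
).to(device=device, dtype=torch.bfloat16)

opt = EiporionOptim(
    [p for p in model.parameters() if p.requires_grad],    # dense params -> AdamW8bit
    lr=1e-3,
    bit_modules=collect_bitlinear_modules(model),           # INT8 weights -> DQT-SR update
)

x = torch.randn(32, 64, device=device, dtype=torch.bfloat16)
y = torch.randint(0, 10, (32,), device=device)

loss = nn.functional.cross_entropy(model(x), y)
loss.backward()   # the INT8 weight grad is stashed in the bit-grad cache, not in .grad
opt.step()        # 8-bit AdamW for dense params + stochastic-rounded INT8 weight update
opt.zero_grad()
```

Note that a `BitLinear` layer only records a weight gradient when its autograd Function actually runs backward, i.e. when its input requires grad (it sits behind at least one trainable layer) or when it has `bias=True`; a grad-free first layer would leave the bit-grad cache empty and skip the INT8 update.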

eiporion-0.1.0/README.md
ADDED
@@ -0,0 +1,15 @@
+# Eiporion
+
+INT8 linear layers with DQT SR/MB-SR and bitsandbytes acceleration.
+
+## Install
+
+```bash
+pip install eiporion
+```
+
+## Quick use
+
+```python
+from eiporion import BitLinear
+```

eiporion-0.1.0/eiporion/__init__.py
ADDED
@@ -0,0 +1,11 @@
+from .eiporionkernels import quantize_fp_to_int8
+from .bitLinear import BitLinear, collect_bitlinear_modules
+from .eiporionoptim import EiporionOptim, EiporionOptimSR
+
+__all__ = [
+    "BitLinear",
+    "EiporionOptim",
+    "EiporionOptimSR",
+    "collect_bitlinear_modules",
+    "quantize_fp_to_int8",
+]

eiporion-0.1.0/eiporion/bitLinear.py
ADDED
@@ -0,0 +1,108 @@
+import math
+
+import torch
+from torch import nn
+
+from .eiporionkernels import (
+    Int8LinearFn,
+    consume_bit_grad,
+    next_bit_handle,
+    register_bit_handle,
+    release_bit_handle,
+)
+
+
+def _init_int8_weight(out_features: int, in_features: int):
+    """Kaiming-uniform init, quantise to int8.
+
+    Matches bnb ``int8_vectorwise_quant``: per-row max_abs / 127 scale.
+    W_int8 = clip(round(W / scale), -127, 127) with scale = max_abs_per_row / 127.
+    W_eff = int_weight * weight_scale ≈ original kaiming weight.
+    """
+    weight = torch.empty((out_features, in_features), dtype=torch.float32)
+    nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+    w = weight.float()
+    scale_per_row = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 127.0
+    q = torch.round(w / scale_per_row).clamp(-127, 127).to(torch.int8)
+    return q.contiguous(), scale_per_row.squeeze(1).contiguous()
+
+
+class BitLinear(nn.Module):
+    """INT8 linear layer matching bnb ``Int8Params`` + DQT paper.
+
+    * ``int_weight`` — int8 buffer ``[O, K]``, the quantised weight (bnb's CB).
+    * ``weight_scale`` — float buffer ``[O]``, per-row max_abs/127 (bnb's SCB/127).
+    * Forward: ``W_eff = int_weight * weight_scale``, then standard matmul.
+    * Gradients for ``int_weight`` are stashed in ``_BIT_GRAD_CACHE`` and consumed
+      by :class:`EiporionOptim` for DQT stochastic rounding.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+    ) -> None:
+        super().__init__()
+        self.in_features = int(in_features)
+        self.out_features = int(out_features)
+
+        int_init, scale_init = _init_int8_weight(
+            out_features=self.out_features, in_features=self.in_features
+        )
+        self.register_buffer("int_weight", int_init, persistent=True)
+        self.register_buffer("weight_scale", scale_init, persistent=True)
+        self.register_buffer(
+            "_bit_handle",
+            torch.tensor(next_bit_handle(), dtype=torch.int64),
+            persistent=True,
+        )
+        self._registered_handle = int(self._bit_handle.item())
+        self.register_load_state_dict_post_hook(self._post_load_state_dict)
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_features))
+        else:
+            self.register_parameter("bias", None)
+
+    @torch.no_grad()
+    def reset_int8_(self) -> None:
+        int_init, scale_init = _init_int8_weight(
+            out_features=self.out_features, in_features=self.in_features
+        )
+        self.int_weight.copy_(int_init.to(device=self.int_weight.device))
+        self.weight_scale.copy_(scale_init.to(device=self.weight_scale.device))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.shape[-1] != self.in_features:
+            raise ValueError(
+                f"expected input last dim {self.in_features}, got {x.shape[-1]}"
+            )
+        if self.int_weight.device != x.device:
+            raise RuntimeError(
+                "BitLinear input and int_weight must be on same device. "
+                "Move module with model.to(device) before forward."
+            )
+        x2d = x.reshape(-1, self.in_features)
+        out2d = Int8LinearFn.apply(
+            x2d,
+            self.int_weight,
+            self.weight_scale,
+            self.bias,
+            int(self._bit_handle.item()),
+        )
+        return out2d.view(*x.shape[:-1], self.out_features)
+
+    def consume_weight_grad(self) -> torch.Tensor | None:
+        return consume_bit_grad(int(self._bit_handle.item()))
+
+    def _post_load_state_dict(self, module, incompatible_keys) -> None:
+        del module, incompatible_keys
+        new_handle = register_bit_handle(int(self._bit_handle.item()))
+        if new_handle != self._registered_handle:
+            release_bit_handle(self._registered_handle)
+        self._registered_handle = new_handle
+
+
+def collect_bitlinear_modules(module: nn.Module) -> list[BitLinear]:
+    return [m for m in module.modules() if isinstance(m, BitLinear)]
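
`_init_int8_weight` above and `quantize_fp_to_int8` (in eiporionkernels.py below) use the same per-row absmax/127 scheme. A small sketch of the round-trip, showing that the dequantised weight `W_eff = int_weight * weight_scale` stays within half a quantisation step of the original in every row; the shapes are arbitrary:

```python
import torch

from eiporion import quantize_fp_to_int8

w = torch.randn(64, 128)               # fp32 weight [out_features, in_features]
q, scale = quantize_fp_to_int8(w)       # int8 [64, 128] plus fp32 per-row scale [64]

w_eff = q.float() * scale.unsqueeze(1)  # what BitLinear effectively multiplies by at forward time
err = (w - w_eff).abs()

# round() moves each entry by at most 0.5 quantised units, i.e. at most scale/2 per row.
assert torch.all(err <= scale.unsqueeze(1) * 0.5 + 1e-6)
print(err.max().item(), (scale * 0.5).max().item())
```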

eiporion-0.1.0/eiporion/eiporionkernels.py
ADDED
@@ -0,0 +1,227 @@
+import torch
+
+
+# ---------------------------------------------------------------------------
+# Handle registry
+# ---------------------------------------------------------------------------
+_NEXT_HANDLE = 1
+_REGISTERED_HANDLES: set[int] = set()
+_BIT_GRAD_CACHE: dict[int, torch.Tensor] = {}
+
+
+def next_bit_handle() -> int:
+    global _NEXT_HANDLE
+    while _NEXT_HANDLE in _REGISTERED_HANDLES:
+        _NEXT_HANDLE += 1
+    handle = _NEXT_HANDLE
+    _REGISTERED_HANDLES.add(handle)
+    _NEXT_HANDLE += 1
+    return handle
+
+
+def register_bit_handle(handle: int) -> int:
+    global _NEXT_HANDLE
+    handle = int(handle)
+    if handle <= 0:
+        raise ValueError(f"handle must be > 0, got {handle}")
+    _REGISTERED_HANDLES.add(handle)
+    _BIT_GRAD_CACHE.pop(handle, None)
+    if handle >= _NEXT_HANDLE:
+        _NEXT_HANDLE = handle + 1
+    return handle
+
+
+def release_bit_handle(handle: int) -> None:
+    handle = int(handle)
+    _REGISTERED_HANDLES.discard(handle)
+    _BIT_GRAD_CACHE.pop(handle, None)
+
+
+def consume_bit_grad(handle: int) -> torch.Tensor | None:
+    return _BIT_GRAD_CACHE.pop(int(handle), None)
+
+
+# ---------------------------------------------------------------------------
+# bitsandbytes backend
+# ---------------------------------------------------------------------------
+_BNB_F = None
+_BNB_FMT = "col_ampere"
+
+try:
+    import bitsandbytes.functional as _BNB_F
+
+    if torch.cuda.is_available():
+        cc = torch.cuda.get_device_capability()
+        if cc[0] >= 8:
+            _BNB_FMT = "col_ampere"
+        elif cc[0] == 7 and cc[1] >= 5:
+            _BNB_FMT = "col_turing"
+        else:
+            _BNB_FMT = "col32"
+except ImportError:
+    pass
+
+# Weight-transform cache (handle → (CxB, SB, version))
+_BNB_WCACHE: dict[int, tuple] = {}
+_BNB_WVERSION: dict[int, int] = {}
+
+
+def _cached_weight_transform(
+    handle: int, int_weight: torch.Tensor, weight_scale: torch.Tensor
+):
+    """Return (CxB, SB) for the current weight; re-transform only if stale."""
+    cur_ver = _BNB_WVERSION.get(handle, 0)
+    entry = _BNB_WCACHE.get(handle)
+    if entry is not None:
+        CxB, SB, cached_ver = entry
+        if cached_ver == cur_ver:
+            return CxB, SB
+
+    # bitsandbytes expects fp16 input for int8_vectorwise_quant
+    w_fp16 = int_weight.to(torch.float16) * weight_scale.to(torch.float16).unsqueeze(1)
+    w_q, _w_s, _ = _BNB_F.int8_vectorwise_quant(w_fp16)
+    CxB, SB = _BNB_F.transform(w_q, _BNB_FMT)
+    _BNB_WCACHE[handle] = (CxB, SB, cur_ver)
+    return CxB, SB
+
+
+def _invalidate_weight_cache(handle: int):
+    """Call after int_weight is modified by stochastic rounding."""
+    _BNB_WVERSION[handle] = _BNB_WVERSION.get(handle, 0) + 1
+
+
+# ---------------------------------------------------------------------------
+# Weight quantisation (static, used at init / reset)
+# ---------------------------------------------------------------------------
+
+
+def quantize_fp_to_int8(weight: torch.Tensor, eps: float = 1e-8):
+    if weight.ndim != 2:
+        raise ValueError(
+            f"weight must be 2D [out_features, in_features], got {tuple(weight.shape)}"
+        )
+    w = weight.float()
+    scale = w.abs().amax(dim=1, keepdim=True).clamp_min(float(eps)) / 127.0
+    q = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
+    return q.contiguous(), scale.squeeze(1).contiguous()
+
+
+# ---------------------------------------------------------------------------
+# Int8LinearFn
+# ---------------------------------------------------------------------------
+
+
+class Int8LinearFn(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x2d: torch.Tensor,  # [N, K] (autocast doesn't cover custom Fns)
+        int_weight: torch.Tensor,  # [O, K] int8
+        weight_scale: torch.Tensor,  # [O] float (trainable Parameter)
+        bias: torch.Tensor | None,  # [O] float
+        handle: int,
+    ):
+        return _forward_bf16(ctx, x2d, int_weight, weight_scale, bias, handle)
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor):
+        return _backward_impl(ctx, grad_out)
+
+
+# ---------------------------------------------------------------------------
+# Path A: bitsandbytes cuBLASLt INT8 Tensor Cores
+# ---------------------------------------------------------------------------
+
+
+def _forward_bnb(ctx, x2d, int_weight, weight_scale, bias, handle):
+    # 1. Quantise activation — bitsandbytes works in fp16
+    x_fp16 = x2d.half()
+    CA, SCA, _ = _BNB_F.int8_vectorwise_quant(x_fp16)
+
+    # 2. Transform activation to col32 layout
+    C32A, SA = _BNB_F.transform(CA, "col32")
+
+    # 3. Get cached weight transform
+    CxB, SB = _cached_weight_transform(handle, int_weight, weight_scale)
+
+    # 4. INT8 matmul via cuBLASLt
+    out_i32, _ = _BNB_F.igemmlt(C32A, CxB, SA, SB)
+
+    # 5. Dequantise — output must match activation scale × weight scale
+    out = _BNB_F.mm_dequant(out_i32, SCA, weight_scale.half().unsqueeze(1))
+
+    if bias is not None:
+        out.add_(bias.half())
+
+    ctx.save_for_backward(x2d.to(torch.bfloat16), int_weight, weight_scale)
+    ctx.handle = int(handle)
+    ctx.has_bias = bias is not None
+    ctx.input_dtype = x2d.dtype
+    return out.to(dtype=x2d.dtype)
+
+
+# ---------------------------------------------------------------------------
+# Path B: BF16 fallback (torch.matmul on BF16 Tensor Cores, works everywhere)
+# ---------------------------------------------------------------------------
+
+
+def _forward_bf16(ctx, x2d, int_weight, weight_scale, bias, handle):
+    x2d_bf16 = x2d.to(torch.bfloat16)
+    w_bf16 = int_weight.to(torch.bfloat16) * weight_scale.to(torch.bfloat16).unsqueeze(
+        1
+    )
+    out = torch.matmul(x2d_bf16, w_bf16.t())
+    if bias is not None:
+        out.add_(bias)
+    ctx.save_for_backward(x2d_bf16, int_weight, weight_scale)
+    ctx.handle = int(handle)
+    ctx.has_bias = bias is not None
+    ctx.input_dtype = x2d.dtype
+    return out
+
+
+# ---------------------------------------------------------------------------
+# Common backward (BF16 matmul — correct for both forward paths)
+# ---------------------------------------------------------------------------
+
+
+def _backward_impl(ctx, grad_out):
+    x2d_bf16, int_weight, weight_scale = ctx.saved_tensors
+    go_bf16 = grad_out.to(torch.bfloat16)
+    ws_bf16 = weight_scale.to(torch.bfloat16)
+
+    w_bf16 = int_weight.to(torch.bfloat16) * ws_bf16.unsqueeze(1)
+
+    grad_in = torch.matmul(go_bf16, w_bf16)
+
+    grad_w = torch.matmul(go_bf16.t(), x2d_bf16)
+    cached = _BIT_GRAD_CACHE.get(ctx.handle)
+    if cached is None:
+        _BIT_GRAD_CACHE[ctx.handle] = grad_w.to(dtype=torch.bfloat16)
+    else:
+        cached.add_(grad_w.to(dtype=torch.bfloat16))
+
+    grad_bias = go_bf16.sum(dim=0).to(dtype=grad_out.dtype) if ctx.has_bias else None
+
+    # int_weight and weight_scale are fixed buffers — no gradient flows to them.
+    # int_weight is updated via DQT in EiporionOptim; weight_scale stays fixed.
+    return (
+        grad_in.to(dtype=grad_out.dtype),
+        None,  # int_weight
+        None,  # weight_scale
+        grad_bias,
+        None,  # handle
+    )
+
+
+# ---------------------------------------------------------------------------
+# INT8 weight update
+# ---------------------------------------------------------------------------
+
+
+@torch.no_grad()
+def update_int8_weight_(int_weight: torch.Tensor, delta_q: torch.Tensor) -> None:
+    """In-place int8 weight update: W += delta_q, clamped to [-127, 127]."""
+    result = int_weight.to(torch.int16) + delta_q.to(torch.int16)
+    result.clamp_(-127, 127)
+    int_weight.copy_(result.to(torch.int8))
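
A minimal CPU sketch (shapes arbitrary) of how the default Path B forward and the bit-grad cache behave; it uses only the public `BitLinear` API from bitLinear.py above. The input is given `requires_grad=True` so that the autograd Function's backward actually runs and populates the cache:

```python
import torch

from eiporion import BitLinear

layer = BitLinear(in_features=8, out_features=4)
x = torch.randn(3, 8, requires_grad=True)

out = layer(x)  # Path B: bf16 fallback, so `out` is bfloat16

# Reference: the same dequantised matmul done by hand.
w_eff = layer.int_weight.to(torch.bfloat16) * layer.weight_scale.to(torch.bfloat16).unsqueeze(1)
ref = torch.matmul(x.to(torch.bfloat16), w_eff.t())
assert torch.allclose(out.float(), ref.float())

# The weight gradient is not exposed through autograd; it is stashed in the
# module-level cache and retrieved by the optimizer (or manually, as here).
out.float().sum().backward()
print(x.grad.shape)                     # torch.Size([3, 8])  -- input grad via autograd
g = layer.consume_weight_grad()
print(g.shape, g.dtype)                 # torch.Size([4, 8]) torch.bfloat16 -- from the cache
```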

eiporion-0.1.0/eiporion/eiporionoptim.py
ADDED
@@ -0,0 +1,258 @@
+import torch
+from bitsandbytes.optim import AdamW8bit
+
+from .bitLinear import BitLinear
+from .eiporionkernels import update_int8_weight_, _invalidate_weight_cache
+
+
+# ---- 8-bit blockwise helpers with mu-law companding ----
+_MU = 255.0
+_1_OVER_LN1P_MU = 1.0 / 5.545177444479562  # 1 / ln(1 + 255)
+
+
+def _pad_to_block_multiple(t: torch.Tensor, blocksize: int, num_blocks: int):
+    padded_size = num_blocks * blocksize
+    if padded_size > t.numel():
+        return torch.cat(
+            [t, torch.zeros(padded_size - t.numel(), device=t.device, dtype=t.dtype)]
+        )
+    return t
+
+
+def _quantize_blockwise_signed(x: torch.Tensor, blocksize: int):
+    x_flat = x.float().contiguous().view(-1)
+    numel = x_flat.numel()
+    num_blocks = (numel + blocksize - 1) // blocksize
+    x_padded = _pad_to_block_multiple(x_flat, blocksize, num_blocks)
+    x_blocks = x_padded.view(num_blocks, blocksize)
+    absmax = x_blocks.abs().amax(dim=1).clamp_min(1e-12)
+
+    # mu-law companding: normalise → compress → quantise
+    x_norm = (x_blocks / absmax.unsqueeze(1)).clamp(-1.0, 1.0)
+    x_comp = torch.sign(x_norm) * torch.log1p(_MU * x_norm.abs()) * _1_OVER_LN1P_MU
+
+    q_blocks = torch.round(x_comp * 127.0).clamp(-127, 127).to(torch.int16)
+    q_flat = (q_blocks.view(-1)[:numel] + 128).to(torch.uint8)
+    return q_flat.view_as(x), absmax
+
+
+def _dequantize_blockwise_signed(q, absmax, blocksize, shape):
+    q_flat = q.contiguous().view(-1)
+    numel = q_flat.numel()
+    num_blocks = absmax.numel()
+    q_padded = _pad_to_block_multiple(q_flat, blocksize, num_blocks)
+    q_blocks = q_padded.view(num_blocks, blocksize).float()
+
+    # inverse mu-law
+    y = ((q_blocks - 128.0) / 127.0).clamp(-1.0, 1.0)
+    x_norm = torch.sign(y) * (torch.exp(y.abs() / _1_OVER_LN1P_MU) - 1.0) / _MU
+    out_blocks = x_norm * absmax.unsqueeze(1)
+    return out_blocks.view(-1)[:numel].view(shape)
+
+
+def _quantize_blockwise_unsigned(x: torch.Tensor, blocksize: int):
+    x_flat = x.float().contiguous().view(-1)
+    numel = x_flat.numel()
+    num_blocks = (numel + blocksize - 1) // blocksize
+    x_padded = _pad_to_block_multiple(x_flat, blocksize, num_blocks)
+    x_blocks = x_padded.view(num_blocks, blocksize)
+    absmax = x_blocks.amax(dim=1).clamp_min(1e-12)
+
+    # mu-law companding (unsigned: [0, 1] range)
+    x_norm = (x_blocks / absmax.unsqueeze(1)).clamp(0.0, 1.0)
+    x_comp = torch.log1p(_MU * x_norm) * _1_OVER_LN1P_MU
+
+    q_blocks = torch.round(x_comp * 255.0).clamp(0, 255)
+    q_flat = q_blocks.view(-1)[:numel].to(torch.uint8)
+    return q_flat.view_as(x), absmax
+
+
+def _dequantize_blockwise_unsigned(q, absmax, blocksize, shape):
+    q_flat = q.contiguous().view(-1)
+    numel = q_flat.numel()
+    num_blocks = absmax.numel()
+    q_padded = _pad_to_block_multiple(q_flat, blocksize, num_blocks)
+    q_blocks = q_padded.view(num_blocks, blocksize).float()
+
+    y = (q_blocks / 255.0).clamp(0.0, 1.0)
+    x_norm = (torch.exp(y / _1_OVER_LN1P_MU) - 1.0) / _MU
+    out_blocks = x_norm * absmax.unsqueeze(1)
+    return out_blocks.view(-1)[:numel].view(shape)
+
+
+class EiporionOptim(AdamW8bit):
+    """bnb AdamW8bit for dense params + 8-bit AdamW + DQT-SR for INT8 weights."""
+
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0.1,
+        bit_modules=None,
+        block_size=256,
+        sr_bias_scale=0.15,
+    ):
+        super().__init__(
+            params,
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+        )
+        self.bit_modules = list(bit_modules) if bit_modules is not None else []
+        self._bit_handles = (
+            [int(m._bit_handle.item()) for m in self.bit_modules]
+            if bit_modules is not None
+            else []
+        )
+        self._bit_state: dict[int, dict[str, torch.Tensor | int]] = {}
+        self.block_size = int(block_size)
+        self.sr_bias_scale = float(sr_bias_scale)
+
+    def add_bit_modules(self, modules) -> None:
+        for module in modules:
+            if not isinstance(module, BitLinear):
+                raise TypeError(f"expected BitLinear, got {type(module).__name__}")
+            if module not in self.bit_modules:
+                self.bit_modules.append(module)
+                self._bit_handles.append(int(module._bit_handle.item()))
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = super().step(closure)
+
+        if not self.bit_modules:
+            return loss
+
+        group = self.param_groups[0]
+        lr = float(group["lr"])
+        beta1, beta2 = group["betas"]
+        eps = float(group["eps"])
+        wd = float(group["weight_decay"])
+        bs = self.block_size
+
+        for module, handle in zip(self.bit_modules, self._bit_handles):
+            g = module.consume_weight_grad()
+            if g is None:
+                continue
+            g = g.float().contiguous()
+
+            state = self._bit_state.setdefault(handle, {})
+            if "step" not in state:
+                state["step"] = 0
+                # 8-bit m, v (matching AdamW8bit scheme)
+                m_q, m_absmax = _quantize_blockwise_signed(
+                    torch.zeros_like(g, dtype=torch.float32), bs
+                )
+                v_q, v_absmax = _quantize_blockwise_unsigned(
+                    torch.zeros_like(g, dtype=torch.float32), bs
+                )
+                state["m_q"] = m_q
+                state["m_absmax"] = m_absmax
+                state["v_q"] = v_q
+                state["v_absmax"] = v_absmax
+                # residual also 8-bit blockwise, same scheme as m
+                r_q, r_absmax = _quantize_blockwise_signed(
+                    torch.zeros_like(g, dtype=torch.float32), bs
+                )
+                state["r_q"] = r_q
+                state["r_absmax"] = r_absmax
+
+            m = _dequantize_blockwise_signed(
+                state["m_q"], state["m_absmax"], bs, g.shape
+            )
+            v = _dequantize_blockwise_unsigned(
+                state["v_q"], state["v_absmax"], bs, g.shape
+            )
+            residual = _dequantize_blockwise_signed(
+                state["r_q"], state["r_absmax"], bs, g.shape
+            )
+            state["step"] += 1
+            t = state["step"]
+
+            m.mul_(beta1).add_(g, alpha=1.0 - beta1)
+            v.mul_(beta2).addcmul_(g, g, value=1.0 - beta2)
+            m_hat = m / (1.0 - beta1**t)
+            v_hat = v / (1.0 - beta2**t)
+
+            ws = module.weight_scale.float().unsqueeze(1).clamp_min(eps)
+            iw = module.int_weight.float()
+            adam_term = m_hat / (v_hat.sqrt() + eps)
+            delta_w_eff = -lr * (adam_term + wd * iw * ws)
+            residual = residual + delta_w_eff / ws
+
+            abs_res = residual.abs()
+            base = torch.floor(abs_res)
+            frac = abs_res - base
+            # Momentum-biased SR
+            bias = torch.tanh(adam_term) * self.sr_bias_scale * torch.sign(residual)
+            frac_biased = (frac + bias).clamp(0.0, 1.0)
+            extra = (torch.rand_like(frac) < frac_biased).float()
+            delta_q = (torch.sign(residual) * (base + extra)).to(torch.int32)
+
+            if torch.any(delta_q != 0):
+                update_int8_weight_(module.int_weight, delta_q)
+                residual = residual - delta_q.float()
+                _invalidate_weight_cache(handle)
+
+            r_q, r_absmax = _quantize_blockwise_signed(residual, bs)
+            state["r_q"] = r_q
+            state["r_absmax"] = r_absmax
+
+            m_q, m_absmax = _quantize_blockwise_signed(m, bs)
+            v_q, v_absmax = _quantize_blockwise_unsigned(v, bs)
+            state["m_q"] = m_q
+            state["m_absmax"] = m_absmax
+            state["v_q"] = v_q
+            state["v_absmax"] = v_absmax
+
+        return loss
+
+    def state_dict(self):
+        state = super().state_dict()
+        state["bit_state"] = {
+            int(handle): {
+                key: value.detach().cpu() if torch.is_tensor(value) else value
+                for key, value in per_handle.items()
+            }
+            for handle, per_handle in self._bit_state.items()
+        }
+        return state
+
+    def load_state_dict(self, state_dict):
+        state_dict = dict(state_dict)
+        bit_state = state_dict.pop("bit_state", {})
+        super().load_state_dict(state_dict)
+        self._bit_state = {}
+        for handle, per_handle in bit_state.items():
+            self._bit_state[int(handle)] = {
+                key: value.clone() if torch.is_tensor(value) else value
+                for key, value in per_handle.items()
+            }
+
+
+class EiporionOptimSR(EiporionOptim):
+    """EiporionOptim with unbiased stochastic rounding (no momentum bias)."""
+
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0.1,
+        bit_modules=None,
+        block_size=256,
+    ):
+        super().__init__(
+            params,
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            bit_modules=bit_modules,
+            block_size=block_size,
+            sr_bias_scale=0.0,
+        )
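
The blockwise mu-law helpers above are what keep `m`, `v`, and the SR residual in 8 bits between optimizer steps. A quick round-trip sketch of the signed variant (the values and block size are illustrative, and it reaches into the module's private helpers purely for demonstration):

```python
import torch

from eiporion.eiporionoptim import (
    _dequantize_blockwise_signed,
    _quantize_blockwise_signed,
)

torch.manual_seed(0)
x = torch.randn(300) * 0.01       # e.g. a slice of an Adam momentum buffer; 300 is arbitrary
bs = 256                           # the optimizer's default block_size

q, absmax = _quantize_blockwise_signed(x, bs)                  # uint8 codes + one absmax per block
x_rec = _dequantize_blockwise_signed(q, absmax, bs, x.shape)   # fp32 reconstruction

max_rel_err = ((x - x_rec).abs().max() / x.abs().max()).item()
print(q.dtype, tuple(absmax.shape), f"max relative error ~ {max_rel_err:.4f}")
```

The companding spends the 256 code levels logarithmically, so entries far below the block absmax (the common case for momentum and residual buffers) are reconstructed more precisely than a linear absmax/127 grid would allow, at the cost of coarser resolution near the top of each block's range.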

eiporion-0.1.0/eiporion.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,24 @@
+Metadata-Version: 2.4
+Name: eiporion
+Version: 0.1.0
+Summary: INT8 linear layers with DQT SR/MB-SR and bitsandbytes acceleration.
+Requires-Python: >=3.12
+Description-Content-Type: text/markdown
+Requires-Dist: bitsandbytes<0.46.0,>=0.45.0
+Requires-Dist: torch
+
+# Eiporion
+
+INT8 linear layers with DQT SR/MB-SR and bitsandbytes acceleration.
+
+## Install
+
+```bash
+pip install eiporion
+```
+
+## Quick use
+
+```python
+from eiporion import BitLinear
+```

eiporion-0.1.0/eiporion.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,11 @@
+README.md
+pyproject.toml
+eiporion/__init__.py
+eiporion/bitLinear.py
+eiporion/eiporionkernels.py
+eiporion/eiporionoptim.py
+eiporion.egg-info/PKG-INFO
+eiporion.egg-info/SOURCES.txt
+eiporion.egg-info/dependency_links.txt
+eiporion.egg-info/requires.txt
+eiporion.egg-info/top_level.txt

eiporion-0.1.0/eiporion.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
+

eiporion-0.1.0/eiporion.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
+eiporion

eiporion-0.1.0/pyproject.toml
ADDED
@@ -0,0 +1,14 @@
+[project]
+name = "eiporion"
+version = "0.1.0"
+description = "INT8 linear layers with DQT SR/MB-SR and bitsandbytes acceleration."
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = ["bitsandbytes>=0.45.0,<0.46.0", "torch"]
+
+[build-system]
+requires = ["setuptools>=68", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+packages = ["eiporion"]

eiporion-0.1.0/setup.cfg
ADDED