mmgp-3.3.1-py3-none-any.whl → mmgp-3.6.11-py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- mmgp/fp8_quanto_bridge.py +645 -0
- mmgp/fp8_quanto_bridge_old.py +498 -0
- mmgp/offload.py +3613 -2461
- mmgp/quant_router.py +518 -0
- mmgp/quanto_int8_cuda.py +97 -0
- mmgp/quanto_int8_inject.py +335 -0
- mmgp/safetensors2.py +534 -450
- {mmgp-3.3.1.dist-info → mmgp-3.6.11.dist-info}/METADATA +195 -197
- mmgp-3.6.11.dist-info/RECORD +14 -0
- {mmgp-3.3.1.dist-info → mmgp-3.6.11.dist-info}/WHEEL +1 -1
- mmgp-3.3.1.dist-info/RECORD +0 -9
- {mmgp-3.3.1.dist-info → mmgp-3.6.11.dist-info}/licenses/LICENSE.md +0 -0
- {mmgp-3.3.1.dist-info → mmgp-3.6.11.dist-info}/top_level.txt +0 -0
mmgp/quanto_int8_inject.py
@@ -0,0 +1,335 @@
+from __future__ import annotations
+
+import importlib
+import os
+from types import SimpleNamespace
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+# Env toggles
+_ENV_ENABLE = "WAN2GP_QUANTO_INT8_KERNEL"
+_ENV_DEBUG = "WAN2GP_QUANTO_INT8_DEBUG"
+_ENV_ALIGN_M = "WAN2GP_QUANTO_INT8_ALIGN_M"
+_ENV_ALIGN_N = "WAN2GP_QUANTO_INT8_ALIGN_N"
+_ENV_ALIGN_K = "WAN2GP_QUANTO_INT8_ALIGN_K"
+_ENV_USE_TC = "WAN2GP_QUANTO_INT8_TC"
+
+# Kernel namespace/entrypoints (resolved lazily)
+_KERNEL_OP = None
+_KERNEL_MODULES = (
+    "mmgp.quanto_int8_cuda",
+    "mmgp_quanto_int8_cuda",
+    "mmgp_quanto_int8",
+)
+
+
+def _env_flag(name: str, default: str = "1") -> bool:
+    val = os.environ.get(name, default)
+    return str(val).strip().lower() in ("1", "true", "yes", "on")
+
+
+def _env_int(name: str, default: int) -> int:
+    raw = os.environ.get(name)
+    if raw is None:
+        return default
+    try:
+        return int(raw)
+    except ValueError:
+        return default
+
+
+def _debug(msg: str) -> None:
+    if _env_flag(_ENV_DEBUG, "0"):
+        print(f"[WAN2GP][INT8][quanto] {msg}")
+
+
+def _get_alignments() -> Tuple[int, int, int]:
+    # Conservative defaults; keep in sync with kernel requirements.
+    align_m = _env_int(_ENV_ALIGN_M, 16)
+    align_n = _env_int(_ENV_ALIGN_N, 16)
+    align_k = _env_int(_ENV_ALIGN_K, 16)
+    return align_m, align_n, align_k
+
+
+def _is_qbytes_tensor(t: torch.Tensor) -> bool:
+    try:
+        from optimum.quanto.tensor.qbytes import QBytesTensor
+    except Exception:
+        return False
+    return isinstance(t, QBytesTensor)
+
+
+def _is_weight_qbytes(t: torch.Tensor) -> bool:
+    try:
+        from optimum.quanto.tensor.weights.qbytes import WeightQBytesTensor
+    except Exception:
+        return False
+    return isinstance(t, WeightQBytesTensor)
+
+
+def _flatten_scale(scale: torch.Tensor) -> torch.Tensor:
+    if scale.ndim == 2 and scale.shape[1] == 1:
+        return scale.view(-1)
+    if scale.ndim == 1:
+        return scale
+    return scale.reshape(-1)
+
+
+def _quantize_activation_per_row(x_2d: torch.Tensor, scale_dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Symmetric int8 quantization per-row: scale = max(abs(row)) / 127
+    x_fp32 = x_2d.to(torch.float32)
+    amax = x_fp32.abs().amax(dim=1)
+    qmax = 127.0
+    scale = amax / qmax
+    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
+    q = torch.round(x_fp32 / scale[:, None]).clamp(-qmax, qmax).to(torch.int8)
+    return q, scale.to(scale_dtype)
+
+
+def _pad_to_multiple(x: torch.Tensor, m_pad: int, k_pad: int) -> torch.Tensor:
+    if m_pad == 0 and k_pad == 0:
+        return x
+    # Pad last dim (K) and then rows (M)
+    if k_pad:
+        x = F.pad(x, (0, k_pad))
+    if m_pad:
+        x = F.pad(x, (0, 0, 0, m_pad))
+    return x
+
+
+def _pad_weights_to_multiple(w: torch.Tensor, n_pad: int, k_pad: int) -> torch.Tensor:
+    if n_pad == 0 and k_pad == 0:
+        return w
+    if k_pad:
+        w = F.pad(w, (0, k_pad))
+    if n_pad:
+        w = F.pad(w, (0, 0, 0, n_pad))
+    return w
+
+
+def _pad_scale_to_multiple(scale: torch.Tensor, pad: int, pad_value: float = 1.0) -> torch.Tensor:
+    if pad == 0:
+        return scale
+    return F.pad(scale, (0, pad), value=pad_value)
+
+
+def _resolve_kernel_op():
+    global _KERNEL_OP
+    if _KERNEL_OP is not None:
+        return _KERNEL_OP
+    # torch.ops path (preferred)
+    ops_ns = getattr(torch.ops, "mmgp_quanto_int8", None)
+    if ops_ns is not None and hasattr(ops_ns, "int8_scaled_mm"):
+        _KERNEL_OP = ops_ns.int8_scaled_mm
+        return _KERNEL_OP
+    # import module path
+    for mod_name in _KERNEL_MODULES:
+        try:
+            mod = importlib.import_module(mod_name)
+        except Exception:
+            continue
+        if hasattr(mod, "int8_scaled_mm"):
+            _KERNEL_OP = mod.int8_scaled_mm
+            return _KERNEL_OP
+    ops_ns = getattr(torch.ops, "mmgp_quanto_int8", None)
+    if ops_ns is not None and hasattr(ops_ns, "int8_scaled_mm"):
+        _KERNEL_OP = ops_ns.int8_scaled_mm
+        return _KERNEL_OP
+    raise RuntimeError(
+        "mmgp int8 kernel extension not loaded. Expected torch.ops.mmgp_quanto_int8.int8_scaled_mm "
+        "or a module exposing int8_scaled_mm (mmgp.quanto_int8_cuda)."
+    )
+
+
+def _quantize_with_kernel(x_2d: torch.Tensor, scale_dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
+    try:
+        mod = importlib.import_module("mmgp.quanto_int8_cuda")
+        q, s = mod.quantize_per_row_int8(x_2d)
+        if s.dtype != scale_dtype:
+            s = s.to(scale_dtype)
+        return q, s
+    except Exception:
+        return _quantize_activation_per_row(x_2d, scale_dtype)
+
+
+def _maybe_get_transposed_weight(other: torch.Tensor) -> torch.Tensor:
+    cached = getattr(other, "_mmgp_int8_t", None)
+    if isinstance(cached, torch.Tensor) and cached.device == other._data.device:
+        if cached.shape == (other._data.shape[1], other._data.shape[0]):
+            return cached
+    w_t = other._data.t().contiguous()
+    try:
+        setattr(other, "_mmgp_int8_t", w_t)
+    except Exception:
+        pass
+    return w_t
+
+
+def _int8_tc_mm(
+    a_int8: torch.Tensor,
+    b_int8: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_scale: torch.Tensor,
+    b_int8_t: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if a_int8.dtype != torch.int8 or b_int8.dtype != torch.int8:
+        raise RuntimeError("int8 TC path requires int8 tensors")
+    a_int8 = a_int8.contiguous()
+    if b_int8_t is None:
+        b_int8_t = b_int8.t().contiguous()
+
+    # torch._int_mm expects [M, K] @ [K, N]
+    acc = torch._int_mm(a_int8, b_int8_t)
+    try:
+        mod = importlib.import_module("mmgp.quanto_int8_cuda")
+        return mod.scale_int32_to(acc, a_scale, b_scale)
+    except Exception:
+        # Fallback to torch ops if the scaling kernel isn't available
+        return (acc.float() * a_scale[:, None] * b_scale[None, :]).to(a_scale.dtype)
+
+
+def _int8_scaled_mm(
+    a_int8: torch.Tensor,
+    b_int8: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_scale: torch.Tensor,
+    b_int8_t: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    # a_int8: [M, K], b_int8: [N, K], scales: a=[M], b=[N]
+    if not a_int8.is_cuda or not b_int8.is_cuda:
+        raise RuntimeError("int8 kernel requires CUDA tensors")
+    if a_int8.dtype != torch.int8 or b_int8.dtype != torch.int8:
+        raise RuntimeError("int8 kernel requires int8 activations and weights")
+
+    a_int8 = a_int8.contiguous()
+    b_int8 = b_int8.contiguous()
+    a_scale = _flatten_scale(a_scale).contiguous()
+    b_scale = _flatten_scale(b_scale).contiguous()
+
+    m, k = a_int8.shape
+    n = b_int8.shape[0]
+    use_tc = _env_flag(_ENV_USE_TC, "1")
+    if use_tc:
+        # torch._int_mm requires M > 16 and M/N/K multiples of 8
+        if m <= 16:
+            m_pad = 24 - m
+        else:
+            m_pad = (8 - (m % 8)) % 8
+        n_pad = (8 - (n % 8)) % 8
+        k_pad = (8 - (k % 8)) % 8
+    else:
+        align_m, align_n, align_k = _get_alignments()
+        m_pad = (align_m - (m % align_m)) % align_m
+        n_pad = (align_n - (n % align_n)) % align_n
+        k_pad = (align_k - (k % align_k)) % align_k
+
+    if m_pad or n_pad or k_pad:
+        a_int8 = _pad_to_multiple(a_int8, m_pad=m_pad, k_pad=k_pad)
+        b_int8 = _pad_weights_to_multiple(b_int8, n_pad=n_pad, k_pad=k_pad)
+        a_scale = _pad_scale_to_multiple(a_scale, pad=m_pad, pad_value=1.0)
+        b_scale = _pad_scale_to_multiple(b_scale, pad=n_pad, pad_value=1.0)
+    if b_int8_t is not None and b_int8_t.shape != (b_int8.shape[1], b_int8.shape[0]):
+        b_int8_t = None
+
+    if use_tc:
+        out = _int8_tc_mm(a_int8, b_int8, a_scale, b_scale, b_int8_t=b_int8_t)
+    else:
+        op = _resolve_kernel_op()
+        out = op(a_int8, b_int8, a_scale, b_scale)
+    if m_pad or n_pad:
+        out = out[:m, :n]
+    return out
+
+
+def _use_int8_kernel(input: torch.Tensor, other: torch.Tensor) -> bool:
+    if not _is_weight_qbytes(other):
+        return False
+    if other._data.dtype != torch.int8:
+        return False
+    if not other._data.is_cuda:
+        return False
+    if _is_qbytes_tensor(input):
+        return input._data.dtype == torch.int8 and input._data.is_cuda
+    return input.is_cuda and input.dtype in (torch.bfloat16, torch.float16, torch.float32)
+
+
+def _int8_linear_forward(ctx, input: torch.Tensor, other: torch.Tensor, bias: Optional[torch.Tensor]):
+    ctx.save_for_backward(input, other)
+
+    input_shape = input.shape
+    in_features = input_shape[-1]
+    out_features = other.shape[0]
+
+    # Prepare activations
+    if _is_qbytes_tensor(input):
+        a_2d = input._data.reshape(-1, in_features)
+        a_scale = input._scale
+        a_scale = _flatten_scale(a_scale).to(other._scale.dtype)
+        if a_scale.numel() == 1:
+            a_scale = a_scale.reshape(1).expand(a_2d.shape[0]).contiguous()
+        elif a_scale.numel() != a_2d.shape[0]:
+            raise RuntimeError("Activation scale length does not match token count")
+        a_int8 = a_2d
+    else:
+        a_2d = input.reshape(-1, in_features)
+        a_int8, a_scale = _quantize_with_kernel(a_2d, other._scale.dtype)
+
+    # Per-output scale is handled inside the kernel: out = sum(a*b) * a_scale[row] * b_scale[col]
+    b_scale = _flatten_scale(other._scale).to(other._scale.dtype)
+    if b_scale.numel() != out_features:
+        raise RuntimeError("Weight scale length does not match output features")
+
+    b_int8_t = _maybe_get_transposed_weight(other)
+    out_2d = _int8_scaled_mm(a_int8, other._data, a_scale, b_scale, b_int8_t=b_int8_t)
+    out = out_2d.reshape(input_shape[:-1] + (out_features,))
+
+    if bias is not None:
+        out = out + bias
+    return out
+
+
+_PATCH_STATE = SimpleNamespace(enabled=False, orig_forward=None)
+
+
+def enable_quanto_int8_kernel() -> bool:
+    if _PATCH_STATE.enabled:
+        return True
+    try:
+        from optimum.quanto.tensor.weights import qbytes as _qbytes
+    except Exception:
+        return False
+
+    orig_forward = _qbytes.WeightQBytesLinearFunction.forward
+
+    def forward(ctx, input, other, bias=None):
+        if _use_int8_kernel(input, other):
+            _debug("using mmgp int8 kernel")
+            return _int8_linear_forward(ctx, input, other, bias)
+        return orig_forward(ctx, input, other, bias)
+
+    _qbytes.WeightQBytesLinearFunction.forward = staticmethod(forward)
+    _PATCH_STATE.enabled = True
+    _PATCH_STATE.orig_forward = orig_forward
+    return True
+
+
+def disable_quanto_int8_kernel() -> bool:
+    if not _PATCH_STATE.enabled:
+        return False
+    from optimum.quanto.tensor.weights import qbytes as _qbytes
+    _qbytes.WeightQBytesLinearFunction.forward = staticmethod(_PATCH_STATE.orig_forward)
+    _PATCH_STATE.enabled = False
+    _PATCH_STATE.orig_forward = None
+    return True
+
+
+def maybe_enable_quanto_int8_kernel() -> bool:
+    if not _env_flag(_ENV_ENABLE, "1"):
+        return False
+    return enable_quanto_int8_kernel()
+
+
+# Auto-enable on import (default on, can be disabled via env)
+maybe_enable_quanto_int8_kernel()
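
For reference, a minimal usage sketch based on the module above, assuming a CUDA-enabled PyTorch build, optimum-quanto installed, and the file being importable as mmgp.quanto_int8_inject; the environment variable and function names come from the code shown, everything else is illustrative.

import os

# These toggles are read at call time by the module above.
os.environ["WAN2GP_QUANTO_INT8_DEBUG"] = "1"   # verbose kernel-selection logging
os.environ["WAN2GP_QUANTO_INT8_TC"] = "1"      # prefer the torch._int_mm tensor-core path

# Importing the module auto-patches optimum-quanto's WeightQBytesLinearFunction.forward
# unless WAN2GP_QUANTO_INT8_KERNEL is set to 0/false before import.
from mmgp import quanto_int8_inject

# Explicit control: enable_quanto_int8_kernel() returns False if optimum-quanto
# cannot be imported; disable_quanto_int8_kernel() restores the original forward.
quanto_int8_inject.enable_quanto_int8_kernel()
# ... run inference with an int8-quantized (quanto) model on CUDA ...
quanto_int8_inject.disable_quanto_int8_kernel()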