mmgp 3.3.1__py3-none-any.whl → 3.6.11__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
@@ -0,0 +1,335 @@
+ from __future__ import annotations
+
+ import importlib
+ import os
+ from types import SimpleNamespace
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+
+ # Env toggles
+ _ENV_ENABLE = "WAN2GP_QUANTO_INT8_KERNEL"
+ _ENV_DEBUG = "WAN2GP_QUANTO_INT8_DEBUG"
+ _ENV_ALIGN_M = "WAN2GP_QUANTO_INT8_ALIGN_M"
+ _ENV_ALIGN_N = "WAN2GP_QUANTO_INT8_ALIGN_N"
+ _ENV_ALIGN_K = "WAN2GP_QUANTO_INT8_ALIGN_K"
+ _ENV_USE_TC = "WAN2GP_QUANTO_INT8_TC"
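+ # WAN2GP_QUANTO_INT8_KERNEL (default on) gates the auto-applied patch below, WAN2GP_QUANTO_INT8_DEBUG
+ # (default off) turns on logging, the ALIGN_* variables override the padding alignments used by the
+ # custom kernel path, and WAN2GP_QUANTO_INT8_TC (default on) selects the torch._int_mm tensor-core path.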
+
+ # Kernel namespace/entrypoints (resolved lazily)
+ _KERNEL_OP = None
+ _KERNEL_MODULES = (
+     "mmgp.quanto_int8_cuda",
+     "mmgp_quanto_int8_cuda",
+     "mmgp_quanto_int8",
+ )
+
+
+ def _env_flag(name: str, default: str = "1") -> bool:
+     val = os.environ.get(name, default)
+     return str(val).strip().lower() in ("1", "true", "yes", "on")
+
+
+ def _env_int(name: str, default: int) -> int:
+     raw = os.environ.get(name)
+     if raw is None:
+         return default
+     try:
+         return int(raw)
+     except ValueError:
+         return default
+
+
+ def _debug(msg: str) -> None:
+     if _env_flag(_ENV_DEBUG, "0"):
+         print(f"[WAN2GP][INT8][quanto] {msg}")
+
+
+ def _get_alignments() -> Tuple[int, int, int]:
+     # Conservative defaults; keep in sync with kernel requirements.
+     align_m = _env_int(_ENV_ALIGN_M, 16)
+     align_n = _env_int(_ENV_ALIGN_N, 16)
+     align_k = _env_int(_ENV_ALIGN_K, 16)
+     return align_m, align_n, align_k
+
+
+ def _is_qbytes_tensor(t: torch.Tensor) -> bool:
+     try:
+         from optimum.quanto.tensor.qbytes import QBytesTensor
+     except Exception:
+         return False
+     return isinstance(t, QBytesTensor)
+
+
+ def _is_weight_qbytes(t: torch.Tensor) -> bool:
+     try:
+         from optimum.quanto.tensor.weights.qbytes import WeightQBytesTensor
+     except Exception:
+         return False
+     return isinstance(t, WeightQBytesTensor)
+
+
+ def _flatten_scale(scale: torch.Tensor) -> torch.Tensor:
+     if scale.ndim == 2 and scale.shape[1] == 1:
+         return scale.view(-1)
+     if scale.ndim == 1:
+         return scale
+     return scale.reshape(-1)
+
+
+ def _quantize_activation_per_row(x_2d: torch.Tensor, scale_dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
+     # Symmetric int8 quantization per-row: scale = max(abs(row)) / 127
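+     # e.g. a row whose largest magnitude is 2.54 gets scale 2.54 / 127 = 0.02,
+     # so a value of 1.0 in that row maps to round(1.0 / 0.02) = 50.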
+     x_fp32 = x_2d.to(torch.float32)
+     amax = x_fp32.abs().amax(dim=1)
+     qmax = 127.0
+     scale = amax / qmax
+     scale = torch.where(scale == 0, torch.ones_like(scale), scale)
+     q = torch.round(x_fp32 / scale[:, None]).clamp(-qmax, qmax).to(torch.int8)
+     return q, scale.to(scale_dtype)
+
+
+ def _pad_to_multiple(x: torch.Tensor, m_pad: int, k_pad: int) -> torch.Tensor:
+     if m_pad == 0 and k_pad == 0:
+         return x
+     # Pad last dim (K) and then rows (M)
+     if k_pad:
+         x = F.pad(x, (0, k_pad))
+     if m_pad:
+         x = F.pad(x, (0, 0, 0, m_pad))
+     return x
+
+
+ def _pad_weights_to_multiple(w: torch.Tensor, n_pad: int, k_pad: int) -> torch.Tensor:
+     if n_pad == 0 and k_pad == 0:
+         return w
+     if k_pad:
+         w = F.pad(w, (0, k_pad))
+     if n_pad:
+         w = F.pad(w, (0, 0, 0, n_pad))
+     return w
+
+
+ def _pad_scale_to_multiple(scale: torch.Tensor, pad: int, pad_value: float = 1.0) -> torch.Tensor:
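+     # Padded scale entries default to 1.0; the matching output rows/columns are sliced away after the matmul.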
+     if pad == 0:
+         return scale
+     return F.pad(scale, (0, pad), value=pad_value)
+
+
+ def _resolve_kernel_op():
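+     # Resolution order: a registered torch.ops.mmgp_quanto_int8 op, then any importable module from
+     # _KERNEL_MODULES exposing int8_scaled_mm (importing it may register the op, hence the second
+     # torch.ops check). The result is cached in _KERNEL_OP.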
+     global _KERNEL_OP
+     if _KERNEL_OP is not None:
+         return _KERNEL_OP
+     # torch.ops path (preferred)
+     ops_ns = getattr(torch.ops, "mmgp_quanto_int8", None)
+     if ops_ns is not None and hasattr(ops_ns, "int8_scaled_mm"):
+         _KERNEL_OP = ops_ns.int8_scaled_mm
+         return _KERNEL_OP
+     # import module path
+     for mod_name in _KERNEL_MODULES:
+         try:
+             mod = importlib.import_module(mod_name)
+         except Exception:
+             continue
+         if hasattr(mod, "int8_scaled_mm"):
+             _KERNEL_OP = mod.int8_scaled_mm
+             return _KERNEL_OP
+     ops_ns = getattr(torch.ops, "mmgp_quanto_int8", None)
+     if ops_ns is not None and hasattr(ops_ns, "int8_scaled_mm"):
+         _KERNEL_OP = ops_ns.int8_scaled_mm
+         return _KERNEL_OP
+     raise RuntimeError(
+         "mmgp int8 kernel extension not loaded. Expected torch.ops.mmgp_quanto_int8.int8_scaled_mm "
+         "or a module exposing int8_scaled_mm (mmgp.quanto_int8_cuda)."
+     )
+
+
+ def _quantize_with_kernel(x_2d: torch.Tensor, scale_dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
+     try:
+         mod = importlib.import_module("mmgp.quanto_int8_cuda")
+         q, s = mod.quantize_per_row_int8(x_2d)
+         if s.dtype != scale_dtype:
+             s = s.to(scale_dtype)
+         return q, s
+     except Exception:
+         return _quantize_activation_per_row(x_2d, scale_dtype)
+
+
+ def _maybe_get_transposed_weight(other: torch.Tensor) -> torch.Tensor:
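+     # Cache the [K, N] transposed copy of the int8 weight on the tensor so it is not rebuilt on every call.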
+     cached = getattr(other, "_mmgp_int8_t", None)
+     if isinstance(cached, torch.Tensor) and cached.device == other._data.device:
+         if cached.shape == (other._data.shape[1], other._data.shape[0]):
+             return cached
+     w_t = other._data.t().contiguous()
+     try:
+         setattr(other, "_mmgp_int8_t", w_t)
+     except Exception:
+         pass
+     return w_t
+
+
+ def _int8_tc_mm(
+     a_int8: torch.Tensor,
+     b_int8: torch.Tensor,
+     a_scale: torch.Tensor,
+     b_scale: torch.Tensor,
+     b_int8_t: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     if a_int8.dtype != torch.int8 or b_int8.dtype != torch.int8:
+         raise RuntimeError("int8 TC path requires int8 tensors")
+     a_int8 = a_int8.contiguous()
+     if b_int8_t is None:
+         b_int8_t = b_int8.t().contiguous()
+
+     # torch._int_mm expects [M, K] @ [K, N]
+     acc = torch._int_mm(a_int8, b_int8_t)
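+     # acc is an int32 [M, N] accumulator; it still has to be scaled by a_scale[m] * b_scale[n] per element.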
+     try:
+         mod = importlib.import_module("mmgp.quanto_int8_cuda")
+         return mod.scale_int32_to(acc, a_scale, b_scale)
+     except Exception:
+         # Fallback to torch ops if the scaling kernel isn't available
+         return (acc.float() * a_scale[:, None] * b_scale[None, :]).to(a_scale.dtype)
+
+
+ def _int8_scaled_mm(
+     a_int8: torch.Tensor,
+     b_int8: torch.Tensor,
+     a_scale: torch.Tensor,
+     b_scale: torch.Tensor,
+     b_int8_t: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     # a_int8: [M, K], b_int8: [N, K], scales: a=[M], b=[N]
+     if not a_int8.is_cuda or not b_int8.is_cuda:
+         raise RuntimeError("int8 kernel requires CUDA tensors")
+     if a_int8.dtype != torch.int8 or b_int8.dtype != torch.int8:
+         raise RuntimeError("int8 kernel requires int8 activations and weights")
+
+     a_int8 = a_int8.contiguous()
+     b_int8 = b_int8.contiguous()
+     a_scale = _flatten_scale(a_scale).contiguous()
+     b_scale = _flatten_scale(b_scale).contiguous()
+
+     m, k = a_int8.shape
+     n = b_int8.shape[0]
+     use_tc = _env_flag(_ENV_USE_TC, "1")
+     if use_tc:
+         # torch._int_mm requires M > 16 and M/N/K multiples of 8
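+         # e.g. m=3 is padded up to 24 rows, and k=300 is padded to 304 (the next multiple of 8)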
+         if m <= 16:
+             m_pad = 24 - m
+         else:
+             m_pad = (8 - (m % 8)) % 8
+         n_pad = (8 - (n % 8)) % 8
+         k_pad = (8 - (k % 8)) % 8
+     else:
+         align_m, align_n, align_k = _get_alignments()
+         m_pad = (align_m - (m % align_m)) % align_m
+         n_pad = (align_n - (n % align_n)) % align_n
+         k_pad = (align_k - (k % align_k)) % align_k
+
+     if m_pad or n_pad or k_pad:
+         a_int8 = _pad_to_multiple(a_int8, m_pad=m_pad, k_pad=k_pad)
+         b_int8 = _pad_weights_to_multiple(b_int8, n_pad=n_pad, k_pad=k_pad)
+         a_scale = _pad_scale_to_multiple(a_scale, pad=m_pad, pad_value=1.0)
+         b_scale = _pad_scale_to_multiple(b_scale, pad=n_pad, pad_value=1.0)
+         if b_int8_t is not None and b_int8_t.shape != (b_int8.shape[1], b_int8.shape[0]):
+             b_int8_t = None
+
+     if use_tc:
+         out = _int8_tc_mm(a_int8, b_int8, a_scale, b_scale, b_int8_t=b_int8_t)
+     else:
+         op = _resolve_kernel_op()
+         out = op(a_int8, b_int8, a_scale, b_scale)
+     if m_pad or n_pad:
+         out = out[:m, :n]
+     return out
+
+
+ def _use_int8_kernel(input: torch.Tensor, other: torch.Tensor) -> bool:
+     if not _is_weight_qbytes(other):
+         return False
+     if other._data.dtype != torch.int8:
+         return False
+     if not other._data.is_cuda:
+         return False
+     if _is_qbytes_tensor(input):
+         return input._data.dtype == torch.int8 and input._data.is_cuda
+     return input.is_cuda and input.dtype in (torch.bfloat16, torch.float16, torch.float32)
+
+
+ def _int8_linear_forward(ctx, input: torch.Tensor, other: torch.Tensor, bias: Optional[torch.Tensor]):
+     ctx.save_for_backward(input, other)
+
+     input_shape = input.shape
+     in_features = input_shape[-1]
+     out_features = other.shape[0]
+
+     # Prepare activations
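+     # Either reuse the data and scale of an already-quantized QBytesTensor activation,
+     # or quantize the floating-point activation per row on the fly.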
+     if _is_qbytes_tensor(input):
+         a_2d = input._data.reshape(-1, in_features)
+         a_scale = input._scale
+         a_scale = _flatten_scale(a_scale).to(other._scale.dtype)
+         if a_scale.numel() == 1:
+             a_scale = a_scale.reshape(1).expand(a_2d.shape[0]).contiguous()
+         elif a_scale.numel() != a_2d.shape[0]:
+             raise RuntimeError("Activation scale length does not match token count")
+         a_int8 = a_2d
+     else:
+         a_2d = input.reshape(-1, in_features)
+         a_int8, a_scale = _quantize_with_kernel(a_2d, other._scale.dtype)
+
+     # Per-output scale is handled inside the kernel: out = sum(a*b) * a_scale[row] * b_scale[col]
+     b_scale = _flatten_scale(other._scale).to(other._scale.dtype)
+     if b_scale.numel() != out_features:
+         raise RuntimeError("Weight scale length does not match output features")
+
+     b_int8_t = _maybe_get_transposed_weight(other)
+     out_2d = _int8_scaled_mm(a_int8, other._data, a_scale, b_scale, b_int8_t=b_int8_t)
+     out = out_2d.reshape(input_shape[:-1] + (out_features,))
+
+     if bias is not None:
+         out = out + bias
+     return out
+
+
+ _PATCH_STATE = SimpleNamespace(enabled=False, orig_forward=None)
+
+
+ def enable_quanto_int8_kernel() -> bool:
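+     # Monkey-patch optimum.quanto's WeightQBytesLinearFunction.forward: eligible int8 CUDA linears are
+     # routed through the custom path, everything else falls back to the original forward.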
+     if _PATCH_STATE.enabled:
+         return True
+     try:
+         from optimum.quanto.tensor.weights import qbytes as _qbytes
+     except Exception:
+         return False
+
+     orig_forward = _qbytes.WeightQBytesLinearFunction.forward
+
+     def forward(ctx, input, other, bias=None):
+         if _use_int8_kernel(input, other):
+             _debug("using mmgp int8 kernel")
+             return _int8_linear_forward(ctx, input, other, bias)
+         return orig_forward(ctx, input, other, bias)
+
+     _qbytes.WeightQBytesLinearFunction.forward = staticmethod(forward)
+     _PATCH_STATE.enabled = True
+     _PATCH_STATE.orig_forward = orig_forward
+     return True
+
+
+ def disable_quanto_int8_kernel() -> bool:
+     if not _PATCH_STATE.enabled:
+         return False
+     from optimum.quanto.tensor.weights import qbytes as _qbytes
+     _qbytes.WeightQBytesLinearFunction.forward = staticmethod(_PATCH_STATE.orig_forward)
+     _PATCH_STATE.enabled = False
+     _PATCH_STATE.orig_forward = None
+     return True
+
+
+ def maybe_enable_quanto_int8_kernel() -> bool:
+     if not _env_flag(_ENV_ENABLE, "1"):
+         return False
+     return enable_quanto_int8_kernel()
+
+
+ # Auto-enable on import (default on, can be disabled via env)
+ maybe_enable_quanto_int8_kernel()
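
For reference, a minimal usage sketch of the new module; the import path is hypothetical, since the diff does not show the file name:

    import os

    # Opt out of the import-time auto-patch before the module is imported.
    os.environ["WAN2GP_QUANTO_INT8_KERNEL"] = "0"

    from mmgp import quanto_int8  # hypothetical import path

    quanto_int8.enable_quanto_int8_kernel()   # patch optimum.quanto's int8 linear forward
    # ... run inference with an int8-quantized model on CUDA ...
    quanto_int8.disable_quanto_int8_kernel()  # restore the original forward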