mmgp 3.3.1__py3-none-any.whl → 3.6.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mmgp/quant_router.py ADDED
@@ -0,0 +1,518 @@
+ import importlib
+ import inspect
+ import os
+
+ import torch
+ from optimum.quanto import QModuleMixin, register_qmodule
+ from optimum.quanto.tensor.qtype import qtype as _quanto_qtype
+
+ from . import safetensors2
+
+
+ _QTYPE_QMODULE_CACHE = None
+ _QMODULE_BASE_ATTRS = None
+
+
+ def _extract_qtypes(handler):
+     for obj in vars(handler).values():
+         if isinstance(obj, _quanto_qtype):
+             yield obj
+
+
+ def _extract_qmodule_classes(handler):
+     for obj in vars(handler).values():
+         if inspect.isclass(obj) and issubclass(obj, QModuleMixin) and issubclass(obj, torch.nn.Linear):
+             if obj is QLinearQuantoRouter:
+                 continue
+             yield obj
+
+
+ def _build_qmodule_cache():
+     mapping = {}
+     for handler in _load_handlers():
+         qmodule_classes = list(_extract_qmodule_classes(handler))
+         if len(qmodule_classes) != 1:
+             continue
+         qmodule_cls = qmodule_classes[0]
+         for qt in _extract_qtypes(handler):
+             mapping.setdefault(qt, qmodule_cls)
+     return mapping
+
+
+ def _get_qmodule_base_attrs():
+     global _QMODULE_BASE_ATTRS
+     if _QMODULE_BASE_ATTRS is not None:
+         return _QMODULE_BASE_ATTRS
+     base = torch.nn.Linear(1, 1, bias=True)
+     _QMODULE_BASE_ATTRS = set(base.__dict__.keys())
+     _QMODULE_BASE_ATTRS.update({
+         "_parameters",
+         "_buffers",
+         "_modules",
+         "_non_persistent_buffers_set",
+     })
+     return _QMODULE_BASE_ATTRS
+
+
+ def _get_qmodule_for_qtype(qtype_obj):
+     global _QTYPE_QMODULE_CACHE
+     if qtype_obj is None:
+         return None
+     if _QTYPE_QMODULE_CACHE is None or qtype_obj not in _QTYPE_QMODULE_CACHE:
+         _QTYPE_QMODULE_CACHE = _build_qmodule_cache()
+     return _QTYPE_QMODULE_CACHE.get(qtype_obj)
+
+
+ def _load_with_qmodule(
+     module,
+     qmodule_cls,
+     state_dict,
+     prefix,
+     local_metadata,
+     strict,
+     missing_keys,
+     unexpected_keys,
+     error_msgs,
+ ):
+     device = module.weight.device if torch.is_tensor(module.weight) else None
+     if torch.is_tensor(module.weight) and module.weight.dtype.is_floating_point:
+         weight_dtype = module.weight.dtype
+     elif torch.is_tensor(getattr(module, "bias", None)) and module.bias.dtype.is_floating_point:
+         weight_dtype = module.bias.dtype
+     else:
+         weight_dtype = torch.float16
+     tmp = qmodule_cls(
+         module.in_features,
+         module.out_features,
+         bias=module.bias is not None,
+         device=device,
+         dtype=weight_dtype,
+         weights=module.weight_qtype,
+         activations=module.activation_qtype,
+         optimizer=module.optimizer,
+         quantize_input=True,
+     )
+     setter = getattr(tmp, "set_default_dtype", None)
+     if callable(setter):
+         setter(getattr(module, "_router_default_dtype", None) or module.weight.dtype)
+     tmp._load_from_state_dict(
+         state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+     )
+
+     module.weight = tmp.weight
+     module.bias = tmp.bias
+     module.input_scale = tmp.input_scale
+     module.output_scale = tmp.output_scale
+
+     ignore = set(_get_qmodule_base_attrs())
+     ignore.update({
+         "_quantize_hooks",
+         "training",
+         "_router_default_dtype",
+     })
+     for name, value in tmp.__dict__.items():
+         if name in ignore:
+             continue
+         setattr(module, name, value)
+     module._router_forward_impl = qmodule_cls.forward
+
+
+ @register_qmodule(torch.nn.Linear)
+ class QLinearQuantoRouter(QModuleMixin, torch.nn.Linear):
+     @classmethod
+     def qcreate(
+         cls,
+         module,
+         weights,
+         activations=None,
+         optimizer=None,
+         device=None,
+     ):
+         if torch.is_tensor(module.weight) and module.weight.dtype.is_floating_point:
+             weight_dtype = module.weight.dtype
+         elif torch.is_tensor(getattr(module, "bias", None)) and module.bias.dtype.is_floating_point:
+             weight_dtype = module.bias.dtype
+         else:
+             weight_dtype = torch.float16
+         return cls(
+             module.in_features,
+             module.out_features,
+             module.bias is not None,
+             device=device,
+             dtype=weight_dtype,
+             weights=weights,
+             activations=activations,
+             optimizer=optimizer,
+             quantize_input=True,
+         )
+
+     def __init__(
+         self,
+         in_features,
+         out_features,
+         bias=True,
+         device=None,
+         dtype=None,
+         weights=None,
+         activations=None,
+         optimizer=None,
+         quantize_input=True,
+     ):
+         super().__init__(
+             in_features,
+             out_features,
+             bias=bias,
+             device=device,
+             dtype=dtype,
+             weights=weights,
+             activations=activations,
+             optimizer=optimizer,
+             quantize_input=quantize_input,
+         )
+         self._router_default_dtype = dtype
+
+     def set_default_dtype(self, dtype):
+         self._router_default_dtype = dtype
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         impl = getattr(self, "_router_forward_impl", None)
+         if impl is not None:
+             return impl(self, input)
+         return torch.nn.functional.linear(input, self.qweight, bias=self.bias)
+
+     def _load_from_state_dict(
+         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+     ):
+         qmodule_cls = _get_qmodule_for_qtype(self.weight_qtype)
+         if qmodule_cls is not None and qmodule_cls is not QLinearQuantoRouter:
+             return _load_with_qmodule(
+                 self, qmodule_cls, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+             )
+         return super()._load_from_state_dict(
+             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+         )
+
+
+ _FP8_QUANTO_BRIDGE_MODULE = ".fp8_quanto_bridge"
+
+ _HANDLER_MODULES = [
+     _FP8_QUANTO_BRIDGE_MODULE,
+ ]
+ _HANDLER_OBJECTS = []
+
+
+ def register_handler(handler):
+     global _QTYPE_QMODULE_CACHE
+     if isinstance(handler, str):
+         if handler not in _HANDLER_MODULES:
+             _HANDLER_MODULES.append(handler)
+             _QTYPE_QMODULE_CACHE = None
+         return handler
+     if handler not in _HANDLER_OBJECTS:
+         _HANDLER_OBJECTS.append(handler)
+         _QTYPE_QMODULE_CACHE = None
+     return handler
+
+
+ def unregister_handler(handler):
+     global _QTYPE_QMODULE_CACHE
+     removed = False
+     if isinstance(handler, str):
+         if handler in _HANDLER_MODULES:
+             _HANDLER_MODULES.remove(handler)
+             removed = True
+     elif handler in _HANDLER_OBJECTS:
+         _HANDLER_OBJECTS.remove(handler)
+         removed = True
+     if removed:
+         _QTYPE_QMODULE_CACHE = None
+     return removed
+
+
+ def _load_handlers():
+     handlers = []
+     for mod_path in _HANDLER_MODULES:
+         module = importlib.import_module(mod_path, package=__package__)
+         if not hasattr(module, "detect") or not hasattr(module, "convert_to_quanto"):
+             raise RuntimeError(
+                 f"Quant handler '{mod_path}' must define detect() and convert_to_quanto() functions."
+             )
+         handlers.append(module)
+     for handler in _HANDLER_OBJECTS:
+         if not hasattr(handler, "detect") or not hasattr(handler, "convert_to_quanto"):
+             raise RuntimeError(
+                 "Quant handler object must define detect() and convert_to_quanto() functions."
+             )
+         handlers.append(handler)
+     register_qmodule(torch.nn.Linear)(QLinearQuantoRouter)
+     return handlers
+
+
+ def _handler_name(handler):
+     return getattr(handler, "HANDLER_NAME", handler.__name__.split(".")[-1])
+
+
+ def detect_safetensors_format(state_dict, verboseLevel=1):
+     matches = []
+     details = {}
+     for handler in _load_handlers():
+         result = handler.detect(state_dict, verboseLevel=verboseLevel)
+         name = _handler_name(handler)
+         details[name] = result
+         if result.get("matched", False):
+             matches.append(name)
+     if len(matches) > 1:
+         return {"kind": "mixed", "found": matches, "details": details}
+     if len(matches) == 1:
+         return {"kind": matches[0], "found": matches, "details": details}
+     return {"kind": "none", "found": [], "details": details}
+
+
+ def detect_and_convert(state_dict, default_dtype, verboseLevel=1):
+     info = detect_safetensors_format(state_dict, verboseLevel=verboseLevel)
+     kind = info.get("kind", "none")
+     if kind == "mixed":
+         found = info.get("found", [])
+         details = info.get("details", {})
+         raise RuntimeError(f"Mixed quantization formats detected: {found} details={details}")
+     if kind in ("none", "quanto"):
+         return {"state_dict": state_dict, "quant_map": {}, "kind": kind, "details": info}
+     for handler in _load_handlers():
+         if _handler_name(handler) == kind:
+             detection = info.get("details", {}).get(kind, {})
+             conv = handler.convert_to_quanto(
+                 state_dict,
+                 default_dtype=default_dtype,
+                 verboseLevel=verboseLevel,
+                 detection=detection,
+             )
+             conv["kind"] = kind
+             conv["details"] = info
+             return conv
+     raise RuntimeError(f"Unsupported quantization format '{kind}'")
+
+
+ def get_available_qtypes():
+     try:
+         from optimum.quanto.tensor.qtype import qtypes as _quanto_qtypes
+     except Exception:
+         return []
+     return sorted(_quanto_qtypes.keys())
+
+
+ def get_available_qtype_aliases():
+     aliases = set()
+     for name in get_available_qtypes():
+         key = str(name).lower()
+         aliases.add(key)
+         if key.startswith("q") and len(key) > 1:
+             aliases.add(key[1:])
+         if "float8" in key:
+             aliases.add("fp8")
+     return aliases
+
+
+ def get_quantization_tokens(quantization):
+     if quantization is None:
+         return []
+     key = str(quantization).lower()
+     if len(key) == 0:
+         return []
+     aliases = get_available_qtype_aliases()
+     if key not in aliases:
+         return []
+     tokens = {key}
+     if key.startswith("q") and len(key) > 1:
+         tokens.add(key[1:])
+     if "float8" in key or key == "fp8":
+         tokens.add("fp8")
+     if "int4" in key:
+         tokens.add("int4")
+     if "int8" in key:
+         tokens.add("int8")
+     return sorted(tokens, key=len, reverse=True)
+
+
+ def get_quantization_label(quantization):
+     if quantization is None:
+         return ""
+     key = str(quantization).lower()
+     if key in ("", "none", "bf16", "fp16", "float16", "bfloat16"):
+         return ""
+     aliases = get_available_qtype_aliases()
+     if key not in aliases:
+         return ""
+     if "float8" in key or key == "fp8":
+         return "FP8"
+     if key.startswith("q"):
+         key = key[1:]
+     return key.replace("_", " ").upper()
+
+
+ _quantization_filename_cache = {}
+
+
+ def _normalize_quant_file_key(file_path):
+     try:
+         return os.path.normcase(os.path.abspath(file_path))
+     except Exception:
+         return str(file_path).lower()
+
+
+ def get_cached_quantization_for_file(file_path):
+     if not file_path:
+         return None
+     return _quantization_filename_cache.get(_normalize_quant_file_key(file_path))
+
+
+ def cache_quantization_for_file(file_path, kind):
+     if not file_path or not kind:
+         return
+     key = _normalize_quant_file_key(file_path)
+     if key not in _quantization_filename_cache:
+         _quantization_filename_cache[key] = kind
+
+
+ def _infer_qtype_from_quantization_map(quantization_map):
+     if not quantization_map:
+         return None
+     counts = {}
+     for entry in quantization_map.values():
+         if not isinstance(entry, dict):
+             continue
+         weights = entry.get("weights")
+         if not weights or weights == "none":
+             continue
+         counts[weights] = counts.get(weights, 0) + 1
+     if not counts:
+         return None
+     return max(counts, key=counts.get)
+
+
+ def detect_quantization_kind_for_file(file_path, verboseLevel=1):
+     cached = get_cached_quantization_for_file(file_path)
+     if cached:
+         return cached
+     if not file_path or not os.path.isfile(file_path):
+         return None
+     if not (".safetensors" in file_path or ".sft" in file_path):
+         return None
+
+     def _load_full():
+         state_dict = {}
+         with safetensors2.safe_open(
+             file_path,
+             framework="pt",
+             device="cpu",
+             writable_tensors=False,
+         ) as f:
+             for key in f.keys():
+                 state_dict[key] = f.get_tensor(key)
+             metadata = f.metadata()
+         return state_dict, metadata
+
+     def _try_detect(state_dict):
+         try:
+             info = detect_safetensors_format(state_dict, verboseLevel=verboseLevel)
+             return info.get("kind"), True
+         except Exception:
+             return None, False
+
+     metadata_only = False
+     try:
+         state_dict, metadata = safetensors2.load_metadata_state_dict(file_path)
+         metadata_only = True
+     except Exception:
+         try:
+             state_dict, metadata = _load_full()
+         except Exception:
+             return None
+
+     kind, ok = _try_detect(state_dict)
+     if metadata_only and not ok:
+         try:
+             state_dict, metadata = _load_full()
+             kind, ok = _try_detect(state_dict)
+         except Exception:
+             kind = None
+
+     if (not kind or kind == "none") and metadata is not None:
+         inferred = _infer_qtype_from_quantization_map(metadata.get("quantization_map"))
+         if inferred:
+             kind = inferred
+
+     cache_quantization_for_file(file_path, kind or "none")
+     return kind
+
+
+ def detect_quantization_label_from_filename(filename):
+     if not filename:
+         return ""
+     cached = get_cached_quantization_for_file(filename)
+     if cached:
+         return get_quantization_label(cached)
+     kind = detect_quantization_kind_for_file(filename, verboseLevel=0)
+     if kind:
+         label = get_quantization_label(kind)
+         if label:
+             return label
+     base = os.path.basename(filename).lower()
+     for token in sorted(get_available_qtype_aliases(), key=len, reverse=True):
+         if token and token in base:
+             return get_quantization_label(token)
+     if "quanto" in base:
+         return "QUANTO"
+     return ""
+
+
+ def apply_pre_quantization(model, state_dict, quantization_map, default_dtype=None, verboseLevel=1):
+     remaining = dict(quantization_map or {})
+     post_load = []
+     for handler in _load_handlers():
+         fn = getattr(handler, "apply_pre_quantization", None)
+         if fn is None:
+             continue
+         remaining, hooks = fn(
+             model,
+             state_dict,
+             remaining,
+             default_dtype=default_dtype,
+             verboseLevel=verboseLevel,
+         )
+         if hooks:
+             post_load.extend(hooks)
+     return remaining, post_load
+
+ def _patch_marlin_fp8_bias():
+     """
+     Quanto's Marlin FP8 CUDA kernel currently ignores the bias argument.
+     Add it back manually (in-place) so outputs stay correct on CUDA builds.
+     """
+     try:
+         from optimum.quanto.tensor.weights.marlin.fp8 import qbits as marlin_fp8
+     except Exception:
+         return
+     if getattr(marlin_fp8.MarlinF8QBytesLinearFunction, "_wan2gp_bias_patch", False):
+         return
+
+     orig_forward = marlin_fp8.MarlinF8QBytesLinearFunction.forward
+
+     def forward_with_bias(ctx, input, other, bias=None):
+         out = orig_forward(ctx, input, other, None)
+         if bias is None:
+             return out
+         bias_to_add = bias
+         if bias_to_add.device != out.device or bias_to_add.dtype != out.dtype:
+             bias_to_add = bias_to_add.to(device=out.device, dtype=out.dtype)
+         view_shape = [1] * out.ndim
+         view_shape[-1] = bias_to_add.shape[0]
+         bias_view = bias_to_add.view(*view_shape)
+         out.add_(bias_view)
+         return out
+
+     marlin_fp8.MarlinF8QBytesLinearFunction.forward = staticmethod(forward_with_bias)  # type: ignore
+     marlin_fp8.MarlinF8QBytesLinearFunction._wan2gp_bias_patch = True  # type: ignore
+     marlin_fp8.MarlinF8QBytesLinearFunction._wan2gp_bias_orig = orig_forward  # type: ignore
+
+
+ _patch_marlin_fp8_bias()
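
For orientation, here is a minimal usage sketch of the new router module (not part of the diff). It loads a checkpoint with the standard safetensors API, asks quant_router to detect and, if needed, convert a foreign quantization format into quanto layout, and reads back a display label. The checkpoint path is a placeholder, and the exact keys returned for non-trivial formats depend on the registered handler (the built-in one appears to be the fp8_quanto_bridge module).

# Illustrative sketch only -- "model.safetensors" is a hypothetical path.
import torch
from safetensors.torch import load_file

from mmgp import quant_router

state_dict = load_file("model.safetensors")
info = quant_router.detect_safetensors_format(state_dict, verboseLevel=0)
print(info["kind"])  # "none", "quanto", or a handler name such as the FP8 bridge

# Convert a recognized foreign format into quanto layout (no-op for "none"/"quanto").
conv = quant_router.detect_and_convert(state_dict, default_dtype=torch.bfloat16, verboseLevel=0)
state_dict = conv["state_dict"]
quant_map = conv.get("quant_map", {})

# Human-readable label for UIs, derived from file contents or the filename.
print(quant_router.detect_quantization_label_from_filename("model.safetensors"))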
@@ -0,0 +1,97 @@
+ from __future__ import annotations
+
+ import os
+ from pathlib import Path
+ from types import SimpleNamespace
+
+ import torch
+ from torch.utils import cpp_extension
+
+ _STATE = SimpleNamespace(module=None)
+
+ def _maybe_set_msvc_env() -> None:
+     if os.name != "nt":
+         return
+     if os.environ.get("VCToolsInstallDir") and os.environ.get("INCLUDE"):
+         return
+     vs_root = Path(os.environ.get("VSINSTALLDIR", r"C:\Program Files\Microsoft Visual Studio\2022\Community"))
+     msvc_root = vs_root / "VC" / "Tools" / "MSVC"
+     sdk_root = Path(r"C:\Program Files (x86)\Windows Kits\10")
+     if not msvc_root.exists() or not sdk_root.exists():
+         return
+     msvc_versions = sorted(msvc_root.iterdir(), reverse=True)
+     sdk_includes = sorted((sdk_root / "Include").iterdir(), reverse=True)
+     if not msvc_versions or not sdk_includes:
+         return
+     vc_tools = msvc_versions[0]
+     sdk_ver = sdk_includes[0].name
+     os.environ.setdefault("VCToolsInstallDir", str(vc_tools) + os.sep)
+     os.environ.setdefault("VCINSTALLDIR", str(vs_root / "VC") + os.sep)
+     os.environ.setdefault("WindowsSdkDir", str(sdk_root) + os.sep)
+     os.environ.setdefault("WindowsSDKVersion", sdk_ver + os.sep)
+     os.environ["PATH"] = f"{vc_tools}\\bin\\Hostx64\\x64;{sdk_root}\\bin\\{sdk_ver}\\x64;" + os.environ.get("PATH", "")
+     os.environ["INCLUDE"] = (
+         f"{vc_tools}\\include;{sdk_root}\\Include\\{sdk_ver}\\ucrt;{sdk_root}\\Include\\{sdk_ver}\\shared;"
+         f"{sdk_root}\\Include\\{sdk_ver}\\um;{sdk_root}\\Include\\{sdk_ver}\\winrt;{sdk_root}\\Include\\{sdk_ver}\\cppwinrt"
+     )
+     os.environ["LIB"] = (
+         f"{vc_tools}\\lib\\x64;{sdk_root}\\Lib\\{sdk_ver}\\ucrt\\x64;{sdk_root}\\Lib\\{sdk_ver}\\um\\x64"
+     )
+
+
+ def _extra_cflags() -> list[str]:
+     if os.name == "nt":
+         return ["/O2"]
+     return ["-O3"]
+
+
+ def _extra_cuda_cflags() -> list[str]:
+     flags = [
+         "-O3",
+         "--use_fast_math",
+         "--expt-relaxed-constexpr",
+         "--expt-extended-lambda",
+     ]
+     return flags
+
+
+ def _sources() -> list[str]:
+     src_dir = Path(__file__).parent / "quanto_int8_kernels"
+     return [
+         str(src_dir / "int8_scaled_mm.cpp"),
+         str(src_dir / "int8_scaled_mm.cu"),
+     ]
+
+
+ def _build_dir() -> str:
+     build_dir = Path(__file__).parent / "quanto_int8_kernels" / "build"
+     build_dir.mkdir(parents=True, exist_ok=True)
+     return str(build_dir)
+
+
+ def load() -> object:
+     if _STATE.module is not None:
+         return _STATE.module
+     _maybe_set_msvc_env()
+     name = "mmgp_quanto_int8_cuda"
+     _STATE.module = cpp_extension.load(
+         name=name,
+         sources=_sources(),
+         build_directory=_build_dir(),
+         extra_cflags=_extra_cflags(),
+         extra_cuda_cflags=_extra_cuda_cflags(),
+         verbose=True,
+     )
+     return _STATE.module
+
+
+ def int8_scaled_mm(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor) -> torch.Tensor:
+     return load().int8_scaled_mm(a, b, a_scale, b_scale)
+
+
+ def quantize_per_row_int8(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+     return load().quantize_per_row_int8(a)
+
+
+ def scale_int32_to(acc: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor) -> torch.Tensor:
+     return load().scale_int32_to(acc, a_scale, b_scale)
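
A hedged sketch of how these wrappers appear intended to be used, written as if from inside the new module itself (its filename header is missing from this diff). It assumes the compiled extension implements per-row symmetric int8 quantization and a scaled int8 GEMM; the actual semantics and expected layouts are defined by int8_scaled_mm.cpp/.cu, which are not shown here, and load() JIT-builds the extension, so a working CUDA toolchain is required on first use.

# Sketch under assumed kernel semantics; requires CUDA and the bundled kernel sources.
import torch

a = torch.randn(16, 64, device="cuda", dtype=torch.float16)  # activations
w = torch.randn(32, 64, device="cuda", dtype=torch.float16)  # weights, assumed (out, in) layout

a_q, a_scale = quantize_per_row_int8(a)  # int8 values + per-row scales (assumed semantics)
w_q, w_scale = quantize_per_row_int8(w)

out = int8_scaled_mm(a_q, w_q, a_scale, w_scale)  # scaled int8 GEMM (assumed semantics)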