mmgp-3.5.7-py3-none-any.whl → mmgp-3.6.11-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- mmgp/fp8_quanto_bridge.py +645 -0
- mmgp/fp8_quanto_bridge_old.py +498 -0
- mmgp/offload.py +1038 -248
- mmgp/quant_router.py +518 -0
- mmgp/quanto_int8_cuda.py +97 -0
- mmgp/quanto_int8_inject.py +335 -0
- mmgp/safetensors2.py +57 -10
- {mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/METADATA +2 -2
- mmgp-3.6.11.dist-info/RECORD +14 -0
- {mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/licenses/LICENSE.md +1 -1
- mmgp-3.5.7.dist-info/RECORD +0 -9
- {mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/WHEEL +0 -0
- {mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/top_level.txt +0 -0
mmgp/quant_router.py
ADDED
@@ -0,0 +1,518 @@
import importlib
import inspect
import os

import torch
from optimum.quanto import QModuleMixin, register_qmodule
from optimum.quanto.tensor.qtype import qtype as _quanto_qtype

from . import safetensors2


_QTYPE_QMODULE_CACHE = None
_QMODULE_BASE_ATTRS = None


def _extract_qtypes(handler):
    for obj in vars(handler).values():
        if isinstance(obj, _quanto_qtype):
            yield obj


def _extract_qmodule_classes(handler):
    for obj in vars(handler).values():
        if inspect.isclass(obj) and issubclass(obj, QModuleMixin) and issubclass(obj, torch.nn.Linear):
            if obj is QLinearQuantoRouter:
                continue
            yield obj


def _build_qmodule_cache():
    mapping = {}
    for handler in _load_handlers():
        qmodule_classes = list(_extract_qmodule_classes(handler))
        if len(qmodule_classes) != 1:
            continue
        qmodule_cls = qmodule_classes[0]
        for qt in _extract_qtypes(handler):
            mapping.setdefault(qt, qmodule_cls)
    return mapping


def _get_qmodule_base_attrs():
    global _QMODULE_BASE_ATTRS
    if _QMODULE_BASE_ATTRS is not None:
        return _QMODULE_BASE_ATTRS
    base = torch.nn.Linear(1, 1, bias=True)
    _QMODULE_BASE_ATTRS = set(base.__dict__.keys())
    _QMODULE_BASE_ATTRS.update({
        "_parameters",
        "_buffers",
        "_modules",
        "_non_persistent_buffers_set",
    })
    return _QMODULE_BASE_ATTRS


def _get_qmodule_for_qtype(qtype_obj):
    global _QTYPE_QMODULE_CACHE
    if qtype_obj is None:
        return None
    if _QTYPE_QMODULE_CACHE is None or qtype_obj not in _QTYPE_QMODULE_CACHE:
        _QTYPE_QMODULE_CACHE = _build_qmodule_cache()
    return _QTYPE_QMODULE_CACHE.get(qtype_obj)


def _load_with_qmodule(
    module,
    qmodule_cls,
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    device = module.weight.device if torch.is_tensor(module.weight) else None
    if torch.is_tensor(module.weight) and module.weight.dtype.is_floating_point:
        weight_dtype = module.weight.dtype
    elif torch.is_tensor(getattr(module, "bias", None)) and module.bias.dtype.is_floating_point:
        weight_dtype = module.bias.dtype
    else:
        weight_dtype = torch.float16
    tmp = qmodule_cls(
        module.in_features,
        module.out_features,
        bias=module.bias is not None,
        device=device,
        dtype=weight_dtype,
        weights=module.weight_qtype,
        activations=module.activation_qtype,
        optimizer=module.optimizer,
        quantize_input=True,
    )
    setter = getattr(tmp, "set_default_dtype", None)
    if callable(setter):
        setter(getattr(module, "_router_default_dtype", None) or module.weight.dtype)
    tmp._load_from_state_dict(
        state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    )

    module.weight = tmp.weight
    module.bias = tmp.bias
    module.input_scale = tmp.input_scale
    module.output_scale = tmp.output_scale

    ignore = set(_get_qmodule_base_attrs())
    ignore.update({
        "_quantize_hooks",
        "training",
        "_router_default_dtype",
    })
    for name, value in tmp.__dict__.items():
        if name in ignore:
            continue
        setattr(module, name, value)
    module._router_forward_impl = qmodule_cls.forward


@register_qmodule(torch.nn.Linear)
class QLinearQuantoRouter(QModuleMixin, torch.nn.Linear):
    @classmethod
    def qcreate(
        cls,
        module,
        weights,
        activations=None,
        optimizer=None,
        device=None,
    ):
        if torch.is_tensor(module.weight) and module.weight.dtype.is_floating_point:
            weight_dtype = module.weight.dtype
        elif torch.is_tensor(getattr(module, "bias", None)) and module.bias.dtype.is_floating_point:
            weight_dtype = module.bias.dtype
        else:
            weight_dtype = torch.float16
        return cls(
            module.in_features,
            module.out_features,
            module.bias is not None,
            device=device,
            dtype=weight_dtype,
            weights=weights,
            activations=activations,
            optimizer=optimizer,
            quantize_input=True,
        )

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        weights=None,
        activations=None,
        optimizer=None,
        quantize_input=True,
    ):
        super().__init__(
            in_features,
            out_features,
            bias=bias,
            device=device,
            dtype=dtype,
            weights=weights,
            activations=activations,
            optimizer=optimizer,
            quantize_input=quantize_input,
        )
        self._router_default_dtype = dtype

    def set_default_dtype(self, dtype):
        self._router_default_dtype = dtype

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        impl = getattr(self, "_router_forward_impl", None)
        if impl is not None:
            return impl(self, input)
        return torch.nn.functional.linear(input, self.qweight, bias=self.bias)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        qmodule_cls = _get_qmodule_for_qtype(self.weight_qtype)
        if qmodule_cls is not None and qmodule_cls is not QLinearQuantoRouter:
            return _load_with_qmodule(
                self, qmodule_cls, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
            )
        return super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )


_FP8_QUANTO_BRIDGE_MODULE = ".fp8_quanto_bridge"

_HANDLER_MODULES = [
    _FP8_QUANTO_BRIDGE_MODULE,
]
_HANDLER_OBJECTS = []


def register_handler(handler):
    global _QTYPE_QMODULE_CACHE
    if isinstance(handler, str):
        if handler not in _HANDLER_MODULES:
            _HANDLER_MODULES.append(handler)
            _QTYPE_QMODULE_CACHE = None
        return handler
    if handler not in _HANDLER_OBJECTS:
        _HANDLER_OBJECTS.append(handler)
        _QTYPE_QMODULE_CACHE = None
    return handler


def unregister_handler(handler):
    global _QTYPE_QMODULE_CACHE
    removed = False
    if isinstance(handler, str):
        if handler in _HANDLER_MODULES:
            _HANDLER_MODULES.remove(handler)
            removed = True
    elif handler in _HANDLER_OBJECTS:
        _HANDLER_OBJECTS.remove(handler)
        removed = True
    if removed:
        _QTYPE_QMODULE_CACHE = None
    return removed


def _load_handlers():
    handlers = []
    for mod_path in _HANDLER_MODULES:
        module = importlib.import_module(mod_path, package=__package__)
        if not hasattr(module, "detect") or not hasattr(module, "convert_to_quanto"):
            raise RuntimeError(
                f"Quant handler '{mod_path}' must define detect() and convert_to_quanto() functions."
            )
        handlers.append(module)
    for handler in _HANDLER_OBJECTS:
        if not hasattr(handler, "detect") or not hasattr(handler, "convert_to_quanto"):
            raise RuntimeError(
                "Quant handler object must define detect() and convert_to_quanto() functions."
            )
        handlers.append(handler)
    register_qmodule(torch.nn.Linear)(QLinearQuantoRouter)
    return handlers


def _handler_name(handler):
    return getattr(handler, "HANDLER_NAME", handler.__name__.split(".")[-1])


def detect_safetensors_format(state_dict, verboseLevel=1):
    matches = []
    details = {}
    for handler in _load_handlers():
        result = handler.detect(state_dict, verboseLevel=verboseLevel)
        name = _handler_name(handler)
        details[name] = result
        if result.get("matched", False):
            matches.append(name)
    if len(matches) > 1:
        return {"kind": "mixed", "found": matches, "details": details}
    if len(matches) == 1:
        return {"kind": matches[0], "found": matches, "details": details}
    return {"kind": "none", "found": [], "details": details}


def detect_and_convert(state_dict, default_dtype, verboseLevel=1):
    info = detect_safetensors_format(state_dict, verboseLevel=verboseLevel)
    kind = info.get("kind", "none")
    if kind == "mixed":
        found = info.get("found", [])
        details = info.get("details", {})
        raise RuntimeError(f"Mixed quantization formats detected: {found} details={details}")
    if kind in ("none", "quanto"):
        return {"state_dict": state_dict, "quant_map": {}, "kind": kind, "details": info}
    for handler in _load_handlers():
        if _handler_name(handler) == kind:
            detection = info.get("details", {}).get(kind, {})
            conv = handler.convert_to_quanto(
                state_dict,
                default_dtype=default_dtype,
                verboseLevel=verboseLevel,
                detection=detection,
            )
            conv["kind"] = kind
            conv["details"] = info
            return conv
    raise RuntimeError(f"Unsupported quantization format '{kind}'")


def get_available_qtypes():
    try:
        from optimum.quanto.tensor.qtype import qtypes as _quanto_qtypes
    except Exception:
        return []
    return sorted(_quanto_qtypes.keys())


def get_available_qtype_aliases():
    aliases = set()
    for name in get_available_qtypes():
        key = str(name).lower()
        aliases.add(key)
        if key.startswith("q") and len(key) > 1:
            aliases.add(key[1:])
        if "float8" in key:
            aliases.add("fp8")
    return aliases


def get_quantization_tokens(quantization):
    if quantization is None:
        return []
    key = str(quantization).lower()
    if len(key) == 0:
        return []
    aliases = get_available_qtype_aliases()
    if key not in aliases:
        return []
    tokens = {key}
    if key.startswith("q") and len(key) > 1:
        tokens.add(key[1:])
    if "float8" in key or key == "fp8":
        tokens.add("fp8")
    if "int4" in key:
        tokens.add("int4")
    if "int8" in key:
        tokens.add("int8")
    return sorted(tokens, key=len, reverse=True)


def get_quantization_label(quantization):
    if quantization is None:
        return ""
    key = str(quantization).lower()
    if key in ("", "none", "bf16", "fp16", "float16", "bfloat16"):
        return ""
    aliases = get_available_qtype_aliases()
    if key not in aliases:
        return ""
    if "float8" in key or key == "fp8":
        return "FP8"
    if key.startswith("q"):
        key = key[1:]
    return key.replace("_", " ").upper()


_quantization_filename_cache = {}


def _normalize_quant_file_key(file_path):
    try:
        return os.path.normcase(os.path.abspath(file_path))
    except Exception:
        return str(file_path).lower()


def get_cached_quantization_for_file(file_path):
    if not file_path:
        return None
    return _quantization_filename_cache.get(_normalize_quant_file_key(file_path))


def cache_quantization_for_file(file_path, kind):
    if not file_path or not kind:
        return
    key = _normalize_quant_file_key(file_path)
    if key not in _quantization_filename_cache:
        _quantization_filename_cache[key] = kind


def _infer_qtype_from_quantization_map(quantization_map):
    if not quantization_map:
        return None
    counts = {}
    for entry in quantization_map.values():
        if not isinstance(entry, dict):
            continue
        weights = entry.get("weights")
        if not weights or weights == "none":
            continue
        counts[weights] = counts.get(weights, 0) + 1
    if not counts:
        return None
    return max(counts, key=counts.get)


def detect_quantization_kind_for_file(file_path, verboseLevel=1):
    cached = get_cached_quantization_for_file(file_path)
    if cached:
        return cached
    if not file_path or not os.path.isfile(file_path):
        return None
    if not (".safetensors" in file_path or ".sft" in file_path):
        return None

    def _load_full():
        state_dict = {}
        with safetensors2.safe_open(
            file_path,
            framework="pt",
            device="cpu",
            writable_tensors=False,
        ) as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
            metadata = f.metadata()
        return state_dict, metadata

    def _try_detect(state_dict):
        try:
            info = detect_safetensors_format(state_dict, verboseLevel=verboseLevel)
            return info.get("kind"), True
        except Exception:
            return None, False

    metadata_only = False
    try:
        state_dict, metadata = safetensors2.load_metadata_state_dict(file_path)
        metadata_only = True
    except Exception:
        try:
            state_dict, metadata = _load_full()
        except Exception:
            return None

    kind, ok = _try_detect(state_dict)
    if metadata_only and not ok:
        try:
            state_dict, metadata = _load_full()
            kind, ok = _try_detect(state_dict)
        except Exception:
            kind = None

    if (not kind or kind == "none") and metadata is not None:
        inferred = _infer_qtype_from_quantization_map(metadata.get("quantization_map"))
        if inferred:
            kind = inferred

    cache_quantization_for_file(file_path, kind or "none")
    return kind


def detect_quantization_label_from_filename(filename):
    if not filename:
        return ""
    cached = get_cached_quantization_for_file(filename)
    if cached:
        return get_quantization_label(cached)
    kind = detect_quantization_kind_for_file(filename, verboseLevel=0)
    if kind:
        label = get_quantization_label(kind)
        if label:
            return label
    base = os.path.basename(filename).lower()
    for token in sorted(get_available_qtype_aliases(), key=len, reverse=True):
        if token and token in base:
            return get_quantization_label(token)
    if "quanto" in base:
        return "QUANTO"
    return ""


def apply_pre_quantization(model, state_dict, quantization_map, default_dtype=None, verboseLevel=1):
    remaining = dict(quantization_map or {})
    post_load = []
    for handler in _load_handlers():
        fn = getattr(handler, "apply_pre_quantization", None)
        if fn is None:
            continue
        remaining, hooks = fn(
            model,
            state_dict,
            remaining,
            default_dtype=default_dtype,
            verboseLevel=verboseLevel,
        )
        if hooks:
            post_load.extend(hooks)
    return remaining, post_load

def _patch_marlin_fp8_bias():
    """
    Quanto's Marlin FP8 CUDA kernel currently ignores the bias argument.
    Add it back manually (in-place) so outputs stay correct on CUDA builds.
    """
    try:
        from optimum.quanto.tensor.weights.marlin.fp8 import qbits as marlin_fp8
    except Exception:
        return
    if getattr(marlin_fp8.MarlinF8QBytesLinearFunction, "_wan2gp_bias_patch", False):
        return

    orig_forward = marlin_fp8.MarlinF8QBytesLinearFunction.forward

    def forward_with_bias(ctx, input, other, bias=None):
        out = orig_forward(ctx, input, other, None)
        if bias is None:
            return out
        bias_to_add = bias
        if bias_to_add.device != out.device or bias_to_add.dtype != out.dtype:
            bias_to_add = bias_to_add.to(device=out.device, dtype=out.dtype)
        view_shape = [1] * out.ndim
        view_shape[-1] = bias_to_add.shape[0]
        bias_view = bias_to_add.view(*view_shape)
        out.add_(bias_view)
        return out

    marlin_fp8.MarlinF8QBytesLinearFunction.forward = staticmethod(forward_with_bias)  # type: ignore
    marlin_fp8.MarlinF8QBytesLinearFunction._wan2gp_bias_patch = True  # type: ignore
    marlin_fp8.MarlinF8QBytesLinearFunction._wan2gp_bias_orig = orig_forward  # type: ignore


_patch_marlin_fp8_bias()
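
For orientation, the following is a minimal sketch of how this handler registry could be driven from outside. The detect()/convert_to_quanto() contract is the one enforced by _load_handlers() and exercised by detect_safetensors_format() / detect_and_convert() above; the handler itself (MyInt4Handler, its "my_int4" name, and the ".int4_packed" key suffix) is invented purely for illustration and is not part of mmgp.

# Sketch only: a hypothetical handler wired into mmgp's quant router.
import torch
from mmgp import quant_router

class MyInt4Handler:
    HANDLER_NAME = "my_int4"  # hypothetical format name

    @staticmethod
    def detect(state_dict, verboseLevel=1):
        # Report a match when any tensor key carries our made-up suffix.
        return {"matched": any(k.endswith(".int4_packed") for k in state_dict)}

    @staticmethod
    def convert_to_quanto(state_dict, default_dtype=None, verboseLevel=1, detection=None):
        # A real handler would repack tensors into quanto's layout here and
        # fill quant_map with per-module quantization entries.
        return {"state_dict": state_dict, "quant_map": {}}

quant_router.register_handler(MyInt4Handler())
state_dict = {"blocks.0.attn.qkv.int4_packed": torch.zeros(1)}
info = quant_router.detect_safetensors_format(state_dict)
# info["kind"] == "my_int4" when only this handler matched; "none" or
# "mixed" otherwise. unregister_handler() accepts the same module path or
# object and, like register_handler(), invalidates the qtype cache.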
mmgp/quanto_int8_cuda.py
ADDED
@@ -0,0 +1,97 @@
from __future__ import annotations

import os
from pathlib import Path
from types import SimpleNamespace

import torch
from torch.utils import cpp_extension

_STATE = SimpleNamespace(module=None)

def _maybe_set_msvc_env() -> None:
    if os.name != "nt":
        return
    if os.environ.get("VCToolsInstallDir") and os.environ.get("INCLUDE"):
        return
    vs_root = Path(os.environ.get("VSINSTALLDIR", r"C:\Program Files\Microsoft Visual Studio\2022\Community"))
    msvc_root = vs_root / "VC" / "Tools" / "MSVC"
    sdk_root = Path(r"C:\Program Files (x86)\Windows Kits\10")
    if not msvc_root.exists() or not sdk_root.exists():
        return
    msvc_versions = sorted(msvc_root.iterdir(), reverse=True)
    sdk_includes = sorted((sdk_root / "Include").iterdir(), reverse=True)
    if not msvc_versions or not sdk_includes:
        return
    vc_tools = msvc_versions[0]
    sdk_ver = sdk_includes[0].name
    os.environ.setdefault("VCToolsInstallDir", str(vc_tools) + os.sep)
    os.environ.setdefault("VCINSTALLDIR", str(vs_root / "VC") + os.sep)
    os.environ.setdefault("WindowsSdkDir", str(sdk_root) + os.sep)
    os.environ.setdefault("WindowsSDKVersion", sdk_ver + os.sep)
    os.environ["PATH"] = f"{vc_tools}\\bin\\Hostx64\\x64;{sdk_root}\\bin\\{sdk_ver}\\x64;" + os.environ.get("PATH", "")
    os.environ["INCLUDE"] = (
        f"{vc_tools}\\include;{sdk_root}\\Include\\{sdk_ver}\\ucrt;{sdk_root}\\Include\\{sdk_ver}\\shared;"
        f"{sdk_root}\\Include\\{sdk_ver}\\um;{sdk_root}\\Include\\{sdk_ver}\\winrt;{sdk_root}\\Include\\{sdk_ver}\\cppwinrt"
    )
    os.environ["LIB"] = (
        f"{vc_tools}\\lib\\x64;{sdk_root}\\Lib\\{sdk_ver}\\ucrt\\x64;{sdk_root}\\Lib\\{sdk_ver}\\um\\x64"
    )


def _extra_cflags() -> list[str]:
    if os.name == "nt":
        return ["/O2"]
    return ["-O3"]


def _extra_cuda_cflags() -> list[str]:
    flags = [
        "-O3",
        "--use_fast_math",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
    ]
    return flags


def _sources() -> list[str]:
    src_dir = Path(__file__).parent / "quanto_int8_kernels"
    return [
        str(src_dir / "int8_scaled_mm.cpp"),
        str(src_dir / "int8_scaled_mm.cu"),
    ]


def _build_dir() -> str:
    build_dir = Path(__file__).parent / "quanto_int8_kernels" / "build"
    build_dir.mkdir(parents=True, exist_ok=True)
    return str(build_dir)


def load() -> object:
    if _STATE.module is not None:
        return _STATE.module
    _maybe_set_msvc_env()
    name = "mmgp_quanto_int8_cuda"
    _STATE.module = cpp_extension.load(
        name=name,
        sources=_sources(),
        build_directory=_build_dir(),
        extra_cflags=_extra_cflags(),
        extra_cuda_cflags=_extra_cuda_cflags(),
        verbose=True,
    )
    return _STATE.module


def int8_scaled_mm(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor) -> torch.Tensor:
    return load().int8_scaled_mm(a, b, a_scale, b_scale)


def quantize_per_row_int8(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    return load().quantize_per_row_int8(a)


def scale_int32_to(acc: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor) -> torch.Tensor:
    return load().scale_int32_to(acc, a_scale, b_scale)
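
The kernel sources themselves (int8_scaled_mm.cpp / int8_scaled_mm.cu under quanto_int8_kernels/) are not part of this diff, so only the Python-visible signatures above are certain. A minimal sketch of the expected call path, under the assumption that quantize_per_row_int8 returns an int8 tensor plus one float scale per row and that int8_scaled_mm rescales the int8 matrix product with both scale vectors:

# Sketch under stated assumptions; the authoritative contract lives in
# int8_scaled_mm.cu, which this diff does not include.
import torch
from mmgp import quanto_int8_cuda

x = torch.randn(16, 64, device="cuda", dtype=torch.float16)   # activations
w = torch.randn(32, 64, device="cuda", dtype=torch.float16)   # weight, per-row = per output channel

x_i8, x_scale = quanto_int8_cuda.quantize_per_row_int8(x)
w_i8, w_scale = quanto_int8_cuda.quantize_per_row_int8(w)
# The first call JIT-compiles the extension via torch.utils.cpp_extension.load
# into quanto_int8_kernels/build/ (verbose=True prints the ninja build).
y = quanto_int8_cuda.int8_scaled_mm(x_i8, w_i8, x_scale, w_scale)

load() is memoized in _STATE, so the build runs at most once per process; on Windows, _maybe_set_msvc_env() first fills in the MSVC and Windows SDK paths when they are missing from the environment, which lets the JIT build work outside a developer command prompt.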