mmgp 3.3.1__py3-none-any.whl → 3.6.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,645 @@
1
+ from __future__ import annotations
2
+ import json, re, inspect, os
3
+ from types import SimpleNamespace
4
+ from typing import Dict, Optional, Tuple, Union, Iterable, Callable
5
+
6
+ import torch
7
+ from safetensors.torch import safe_open, save_file
8
+
9
+ # ---------- Constants ----------
10
+ DATA_SUFFIX = "._data"
11
+ SCALE_SUFFIX = "._scale" # per-channel, shape [out, 1, ...]
12
+ IN_SCALE = ".input_scale" # 1-D placeholder tensor [1]
13
+ OUT_SCALE = ".output_scale" # 1-D placeholder tensor [1]
14
+
15
+ _QTYPE_NAME = {
16
+ "e4m3fn": "qfloat8_e4m3fn",
17
+ "e5m2": "qfloat8_e5m2",
18
+ "auto": "qfloat8",
19
+ }
20
+
21
+ _SCALE_META_KEYS = (
22
+ "fp8_scale_map", "fp8.scale_map", "scale_map",
23
+ "quant_scale_map", "weights_scales", "scales",
24
+ )
25
+
26
+ _DTYPE_ALIASES = {
27
+ "float32": torch.float32, "fp32": torch.float32,
28
+ "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
29
+ "float16": torch.float16, "fp16": torch.float16, "half": torch.float16,
30
+ }
31
+
32
+ _fp8_weight_debug_map: Dict[int, str] = {}
33
+
34
+
35
+ def register_fp8_weight_debug_name(weight: Optional[torch.Tensor], name: str) -> None:
36
+ if weight is None:
37
+ return
38
+ if isinstance(name, str) and name:
39
+ try:
40
+ _fp8_weight_debug_map[id(weight)] = name
41
+ except Exception:
42
+ pass
43
+ try:
44
+ setattr(weight, "_debug_name", name)
45
+ except Exception:
46
+ pass
47
+
48
+
49
+ def get_fp8_weight_debug_name(weight: Optional[torch.Tensor]) -> str:
50
+ if weight is None:
51
+ return "unknown"
52
+ debug = getattr(weight, "_debug_name", None)
53
+ if isinstance(debug, str) and debug:
54
+ return debug
55
+ try:
56
+ mapped = _fp8_weight_debug_map.get(id(weight))
57
+ if isinstance(mapped, str) and mapped:
58
+ try:
59
+ setattr(weight, "_debug_name", mapped)
60
+ except Exception:
61
+ pass
62
+ return mapped
63
+ except Exception:
64
+ pass
65
+ if os.environ.get("WAN2GP_FP8_TAG", "").lower() in ("1", "true", "yes", "on"):
66
+ try:
67
+ print(f"[WAN2GP][FP8 tag][miss] type={type(weight).__name__} id={id(weight)} device={weight.device}")
68
+ except Exception:
69
+ print(f"[WAN2GP][FP8 tag][miss] type={type(weight).__name__} id={id(weight)}")
70
+ return "unknown"
71
+
72
+
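+ # Illustrative usage sketch (tensor and module names below are made-up examples):
+ #   w = torch.empty(4, 4, dtype=torch.float8_e4m3fn)
+ #   register_fp8_weight_debug_name(w, "blocks.0.attn.q.weight")
+ #   get_fp8_weight_debug_name(w)          # -> "blocks.0.attn.q.weight"
+ # Set WAN2GP_FP8_TAG=1 to log lookups that miss both the attribute and the id map.
+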
73
+ def _is_weight_key(k: str) -> bool:
74
+ return k.endswith(".weight")
75
+
76
+ # ---------- Accessors (unify file vs dict) ----------
77
+ class Accessor:
78
+ def keys(self) -> Iterable[str]: ...
79
+ def get_tensor(self, key: str) -> torch.Tensor: ...
80
+ def metadata(self) -> Dict[str, str]: ...
81
+ def has(self, key: str) -> bool: ... # NEW
82
+ def can_delete(self) -> bool: return False
83
+ def delete(self, key: str) -> None: raise NotImplementedError
84
+
85
+ class FileAccessor(Accessor):
86
+ def __init__(self, path: str):
87
+ self._fh = safe_open(path, framework="pt")
88
+ self._keys = list(self._fh.keys())
89
+ self._keys_set = set(self._keys) # O(1) membership
90
+ self._meta = self._fh.metadata() or {}
91
+ def keys(self) -> Iterable[str]: return self._keys
92
+ def has(self, key: str) -> bool: return key in self._keys_set
93
+ def get_tensor(self, key: str) -> torch.Tensor: return self._fh.get_tensor(key)
94
+ def metadata(self) -> Dict[str, str]: return self._meta
95
+ def close(self) -> None: self._fh.close()
96
+
97
+ class DictAccessor(Accessor):
98
+ def __init__(self, sd: Dict[str, torch.Tensor], meta: Optional[Dict[str, str]] = None,
99
+ in_place: bool = False, free_cuda_cache: bool = False, cuda_cache_interval: int = 32):
100
+ self.sd = sd
101
+ self._meta = meta or {}
102
+ self._in_place = in_place
103
+ self._free = free_cuda_cache
104
+ self._interval = int(cuda_cache_interval)
105
+ self._deletions = 0
106
+ def keys(self) -> Iterable[str]: return list(self.sd.keys())
107
+ def has(self, key: str) -> bool: return key in self.sd # dict membership = O(1)
108
+ def get_tensor(self, key: str) -> torch.Tensor: return self.sd[key]
109
+ def metadata(self) -> Dict[str, str]: return self._meta
110
+ def can_delete(self) -> bool: return self._in_place
111
+ def delete(self, key: str) -> None:
112
+ if key in self.sd:
113
+ self.sd.pop(key, None)
114
+ self._deletions += 1
115
+ if self._free and (self._deletions % self._interval == 0) and torch.cuda.is_available():
116
+ torch.cuda.empty_cache()
117
+ def _as_accessor(src: Union[str, Dict[str, torch.Tensor]], **dict_opts) -> Tuple[Accessor, Callable[[], None]]:
118
+ if isinstance(src, str):
119
+ acc = FileAccessor(src)
120
+ return acc, acc.close
121
+ acc = DictAccessor(src, **dict_opts)
122
+ return acc, (lambda: None)
123
+
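+ # Illustrative sketch: both call sites below go through the same Accessor API, so the
+ # converter and the detector never branch on "file path vs. in-memory dict".
+ #   acc, closer = _as_accessor("model.safetensors")            # FileAccessor
+ #   acc, closer = _as_accessor(sd, meta={}, in_place=True)     # DictAccessor over a Dict[str, Tensor]
+ #   try:
+ #       weight_keys = [k for k in acc.keys() if _is_weight_key(k)]
+ #   finally:
+ #       closer()
+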
124
+ # ---------- Shared helpers ----------
125
+ def _normalize_scale_dtype(scale_dtype: Union[str, torch.dtype]) -> torch.dtype:
126
+ if isinstance(scale_dtype, torch.dtype):
127
+ return scale_dtype
128
+ key = str(scale_dtype).lower()
129
+ if key not in _DTYPE_ALIASES:
130
+ raise ValueError(f"scale_dtype must be one of {list(_DTYPE_ALIASES.keys())} or a torch.dtype")
131
+ return _DTYPE_ALIASES[key]
132
+
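+ # e.g. _normalize_scale_dtype("bf16") -> torch.bfloat16; a torch.dtype argument passes through unchanged.
+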
133
+ def _json_to_dict(s: str) -> Optional[Dict]:
134
+ # Strictly catch JSON decoding only
135
+ try:
136
+ return json.loads(s)
137
+ except json.JSONDecodeError:
138
+ return None
139
+
140
+ def _maybe_parse_scale_map(meta: Dict[str, str]) -> Optional[Dict[str, float]]:
141
+ def try_parse(obj) -> Optional[Dict[str, float]]:
142
+ if not isinstance(obj, dict):
143
+ return None
144
+ out: Dict[str, float] = {}
145
+ for wk, v in obj.items():
146
+ if isinstance(v, (int, float)):
147
+ out[wk] = float(v)
148
+ elif isinstance(v, dict) and "scale" in v:
149
+ sc = v["scale"]
150
+ if isinstance(sc, (int, float)):
151
+ out[wk] = float(sc)
152
+ elif isinstance(sc, (list, tuple)) and len(sc) == 1 and isinstance(sc[0], (int, float)):
153
+ out[wk] = float(sc[0])
154
+ if out:
155
+ return out
156
+ for sub in ("weights", "tensors", "params", "map"):
157
+ subobj = obj.get(sub)
158
+ if isinstance(subobj, dict):
159
+ got = try_parse(subobj)
160
+ if got:
161
+ return got
162
+ return None
163
+
164
+ # exact keys first
165
+ for k in _SCALE_META_KEYS:
166
+ raw = meta.get(k)
167
+ if isinstance(raw, str) and raw.startswith("{") and raw.endswith("}"):
168
+ parsed = _json_to_dict(raw)
169
+ if parsed:
170
+ got = try_parse(parsed)
171
+ if got:
172
+ return got
173
+
174
+ # loose scan of any JSON-looking value
175
+ for v in meta.values():
176
+ if isinstance(v, str) and v.startswith("{") and v.endswith("}"):
177
+ parsed = _json_to_dict(v)
178
+ if parsed:
179
+ got = try_parse(parsed)
180
+ if got:
181
+ return got
182
+ return None
183
+
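+ # Illustrative metadata layouts this parser accepts (weight names are made-up examples):
+ #   {"blocks.0.ffn.weight": 0.0123}                              # flat name -> scale
+ #   {"blocks.0.ffn.weight": {"scale": [0.0123]}}                 # nested "scale" entry
+ #   {"tensors": {"blocks.0.ffn.weight": 0.0123}}                 # wrapped under a sub-dict
+ # The JSON string is looked up under _SCALE_META_KEYS first, then any JSON-looking metadata value.
+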
184
+ def _quick_fp8_variant_from_sentinel(acc: Accessor) -> Optional[str]:
185
+ if "scaled_fp8" in set(acc.keys()):
186
+ dt = acc.get_tensor("scaled_fp8").dtype
187
+ if dt == torch.float8_e4m3fn: return "e4m3fn"
188
+ if dt == torch.float8_e5m2: return "e5m2"
189
+ return None
190
+
191
+ def _per_channel_reshape(vec: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
192
+ return vec.view(weight.shape[0], *([1] * (weight.ndim - 1)))
193
+
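+ # Shape sketch: a per-output-channel scale vector broadcasts against the weight, e.g.
+ #   weight [out, in, kh, kw] -> scale view [out, 1, 1, 1]
+ #   weight [out, in]         -> scale view [out, 1]
+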
194
+ # ---------- Unified converter ----------
195
+ class ConvertResult(Dict[str, object]):
196
+ @property
197
+ def state_dict(self) -> Dict[str, torch.Tensor]: return self["state_dict"] # type: ignore
198
+ @property
199
+ def quant_map(self) -> Dict[str, Dict]: return self["quant_map"] # type: ignore
200
+ @property
201
+ def fp8_format(self) -> str: return self["fp8_format"] # type: ignore
202
+ @property
203
+ def patch_needed(self) -> bool: return self["patch_needed"] # type: ignore
204
+
205
+ def _preferred_fp8_scale_dtype() -> torch.dtype:
206
+ """
207
+ Marlin's FP8 kernels expect fp16 scale tensors. Using bf16 led to
208
+ noticeable drift on Linux, so force fp16 for stability.
209
+ """
210
+ return torch.float16
211
+
212
+
213
+ def convert_scaled_fp8_to_quanto(
214
+ src: Union[str, Dict[str, torch.Tensor]],
215
+ *,
216
+ fp8_format: Optional[str] = None, # 'e4m3fn' | 'e5m2' | None (auto)
217
+ require_scale: bool = False,
218
+ allow_default_scale: bool = True,
219
+ default_missing_scale: float = 1.0,
220
+ dtype: Union[str, torch.dtype] = "float32",
221
+ add_activation_placeholders: bool = True,
222
+ # dict mode options
223
+ sd_metadata: Optional[Dict[str, str]] = None,
224
+ in_place: bool = False,
225
+ free_cuda_cache: bool = False,
226
+ cuda_cache_interval: int = 32,
227
+ ) -> ConvertResult:
228
+ sd_scale_dtype = _normalize_scale_dtype(dtype)
229
+ patch_needed = False
230
+ if sd_scale_dtype == torch.float32:
231
+ # Quanto's FP8 kernels expect half-precision scales. Coerce float32
232
+ # requests to a supported dtype so we don't need to enable the slow
233
+ # fp32 patch/fallback path.
234
+ sd_scale_dtype = _preferred_fp8_scale_dtype()
235
+
236
+ acc, closer = _as_accessor(
237
+ src,
238
+ meta=sd_metadata,
239
+ in_place=in_place,
240
+ free_cuda_cache=free_cuda_cache,
241
+ cuda_cache_interval=cuda_cache_interval,
242
+ )
243
+ if not acc.can_delete(): in_place = False
244
+ try:
245
+ meta = acc.metadata() or {}
246
+ meta_scale_map = _maybe_parse_scale_map(meta) or {}
247
+
248
+ keys = list(acc.keys())
249
+
250
+ # FP8 variant: sentinel -> first FP8 weight -> 'auto'
251
+ fmt = fp8_format or _quick_fp8_variant_from_sentinel(acc)
252
+ if fmt is None:
253
+ for wk in keys:
254
+ if not _is_weight_key(wk): continue
255
+ dt = acc.get_tensor(wk).dtype
256
+ if dt == torch.float8_e4m3fn: fmt = "e4m3fn"; break
257
+ if dt == torch.float8_e5m2: fmt = "e5m2"; break
258
+ if fmt is None: fmt = "auto"
259
+
260
+ # Map '<base>.scale_weight' -> '<base>.weight'
261
+ scale_weight_map: Dict[str, str] = {}
262
+ for sk in keys:
263
+ if sk.endswith(".scale_weight"):
264
+ base = sk[: -len(".scale_weight")]
265
+ wk = base + ".weight"
266
+ if acc.has(wk):
267
+ scale_weight_map[wk] = sk
268
+
269
+ def get_scale_vec_for_weight(wk: str, out_ch: int) -> Optional[torch.Tensor]:
270
+ # 1) explicit tensor
271
+ sk = scale_weight_map.get(wk)
272
+ if sk is not None:
273
+ s_t = acc.get_tensor(sk).to(torch.float32)
274
+ if in_place: acc.delete(sk)
275
+ if s_t.numel() == 1:
276
+ return torch.full((out_ch,), float(s_t.item()), dtype=torch.float32)
277
+ if s_t.numel() == out_ch:
278
+ return s_t.reshape(out_ch)
279
+ if torch.numel(s_t.unique()) == 1:
280
+ return torch.full((out_ch,), float(s_t.view(-1)[0].item()), dtype=torch.float32)
281
+ raise ValueError(f"Unexpected scale length for '{wk}': {s_t.numel()} (out_ch={out_ch})")
282
+ # 2) metadata exact / normalized
283
+ if wk in meta_scale_map:
284
+ return torch.full((out_ch,), float(meta_scale_map[wk]), dtype=torch.float32)
285
+ for alt in (wk.replace("model.", ""), re.sub(r"(^|\.)weight$", "", wk)):
286
+ if alt in meta_scale_map:
287
+ return torch.full((out_ch,), float(meta_scale_map[alt]), dtype=torch.float32)
288
+ return None
289
+
290
+ out_sd: Dict[str, torch.Tensor] = {}
291
+ qmap: Dict[str, Dict] = {}
292
+
293
+ # Single pass: rewrite FP8 weights, copy-through others
294
+ for k in keys:
295
+ # Drop source-only artifacts
296
+ if k == "scaled_fp8" or k.endswith(".scale_weight") :
297
+ continue
298
+
299
+ t = acc.get_tensor(k)
300
+ if in_place: acc.delete(k)
301
+ if _is_weight_key(k) and t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
302
+ # Quantized: keep original FP8 tensor as _data
303
+ out_sd[k + DATA_SUFFIX] = t
304
+
305
+ out_ch = int(t.shape[0])
306
+ s_vec = get_scale_vec_for_weight(k, out_ch)
307
+ if s_vec is None:
308
+ if require_scale and not allow_default_scale:
309
+ raise KeyError(f"No scale found for '{k}' (looked for '.scale_weight' and metadata).")
310
+ s_vec = torch.full((out_ch,), float(default_missing_scale), dtype=torch.float32)
311
+
312
+ s_grid = _per_channel_reshape(s_vec, t).to(sd_scale_dtype)
313
+ out_sd[k + SCALE_SUFFIX] = s_grid
314
+
315
+ if add_activation_placeholders:
316
+ base = k[:-len(".weight")]
317
+ out_sd[base + IN_SCALE] = torch.tensor([1], dtype=sd_scale_dtype)
318
+ out_sd[base + OUT_SCALE] = torch.tensor([1], dtype=sd_scale_dtype)
319
+
320
+ base = k[:-len(".weight")]
321
+ qmap[base] = {"weights": _QTYPE_NAME[fmt], "activations": "none"}
322
+ else:
323
+ passthrough_dtype = _normalize_scale_dtype(dtype)  # dtype may be a string alias; normalize before use
+ out_sd[k] = t if t.dtype in (passthrough_dtype, torch.float32) else t.to(passthrough_dtype)
324
+ t = None
325
+ return ConvertResult(state_dict=out_sd, quant_map=qmap, fp8_format=fmt, patch_needed=patch_needed)
326
+ finally:
327
+ closer()
328
+
329
+
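+ # Illustrative usage sketch (file name is a placeholder):
+ #   res = convert_scaled_fp8_to_quanto("model_scaled_fp8.safetensors", dtype=torch.float16)
+ #   res.state_dict   # per converted '<base>.weight':
+ #                    #   '<base>.weight._data'   original FP8 payload
+ #                    #   '<base>.weight._scale'  per-channel scales, shape [out, 1, ...]
+ #                    #   '<base>.input_scale' / '<base>.output_scale' placeholders (if enabled)
+ #   res.quant_map    # {'<base>': {'weights': 'qfloat8_e4m3fn', 'activations': 'none'}, ...}
+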
330
+ def detect(state_dict, verboseLevel=1):
331
+ info = detect_safetensors_format(state_dict)
332
+ kind = info.get("kind", "none")
333
+ matched = kind in ("scaled_fp8", "fp8")
334
+ return {"matched": matched, "kind": "fp8" if matched else "none", "details": info}
335
+
336
+
337
+ def convert_to_quanto(state_dict, default_dtype, verboseLevel=1, detection=None):
338
+ if detection is not None and not detection.get("matched", False):
339
+ return {"state_dict": state_dict, "quant_map": {}}
340
+ conv_result = convert_scaled_fp8_to_quanto(state_dict, dtype=default_dtype, in_place=True)
341
+ return {"state_dict": conv_result["state_dict"], "quant_map": conv_result["quant_map"]}
342
+
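+ # Illustrative pipeline sketch (sd is an in-memory state dict):
+ #   det = detect(sd)
+ #   out = convert_to_quanto(sd, torch.float16, detection=det)
+ #   # when det["matched"] is False the state dict is passed through untouched and quant_map is {}
+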
343
+
344
+ def apply_pre_quantization(model, state_dict, quantization_map, default_dtype=None, verboseLevel=1):
345
+ return quantization_map, []
346
+
347
+ def detect_safetensors_format(
348
+ src: Union[str, Dict[str, torch.Tensor]],
349
+ *,
350
+ sd_metadata: Optional[Dict[str, str]] = None,
351
+ probe_weights: bool = False, # if True, we may read up to 2 weights total
352
+ with_hints: bool = False,
353
+ ) -> Dict[str, str]:
354
+ """
355
+ Returns:
356
+ {
357
+ 'kind': 'quanto' | 'scaled_fp8' | 'fp8' | 'none',
358
+ 'quant_format': 'qfloat8_e4m3fn' | 'qfloat8_e5m2' | 'qfloat8' | 'qint8' | 'qint4' | 'unknown' | '',
359
+ 'fp8_format': 'e4m3fn' | 'e5m2' | 'unknown' | '',
360
+ 'hint': '...' # only when with_hints=True
361
+ }
362
+ """
363
+ acc, closer = _as_accessor(src, meta=sd_metadata, in_place=False)
364
+ try:
365
+ # --- O(1) sentinel test up-front (no key scan) ---
366
+ if acc.has("scaled_fp8"):
367
+ dt = acc.get_tensor("scaled_fp8").dtype
368
+ fp8_fmt = "e4m3fn" if dt == torch.float8_e4m3fn else ("e5m2" if dt == torch.float8_e5m2 else "unknown")
369
+ out = {"kind": "scaled_fp8", "quant_format": "", "fp8_format": fp8_fmt}
370
+ if with_hints: out["hint"] = "sentinel"
371
+ return out
372
+
373
+ # --- Bounded key scans: quanto short-circuit first, then scale keys + dtype probe ---
374
+ ks = list(acc.keys())
375
+ has_scale_weight = False
376
+ saw_quanto_data = False
377
+ fp8_variant = None
378
+ fp8_probe_budget = 2 if probe_weights else 1
379
+
380
+ for k in ks:
381
+ # Quanto pack short-circuit
382
+ if not saw_quanto_data and k.endswith(DATA_SUFFIX):
383
+ saw_quanto_data = True
384
+ # the first *._data key is decisive; stop scanning
385
+ break
386
+
387
+ if saw_quanto_data:
388
+ out = {"kind": "quanto", "quant_format": "qfloat8", "fp8_format": ""}
389
+ if with_hints: out["hint"] = "keys:*._data"
390
+ return out
391
+
392
+ # second scan: scale_weight keys plus a bounded weight-dtype probe
393
+ for k in ks:
394
+ if not has_scale_weight and k.endswith(".scale_weight"):
395
+ has_scale_weight = True
396
+ # don't return yet; we may still probe a dtype to grab variant
397
+
398
+ if fp8_probe_budget > 0 and _is_weight_key(k):
399
+ dt = acc.get_tensor(k).dtype
400
+ if dt == torch.float8_e4m3fn:
401
+ fp8_variant = "e4m3fn"; fp8_probe_budget -= 1
402
+ elif dt == torch.float8_e5m2:
403
+ fp8_variant = "e5m2"; fp8_probe_budget -= 1
404
+
405
+ if has_scale_weight:
406
+ out = {"kind": "scaled_fp8", "quant_format": "", "fp8_format": fp8_variant or "unknown"}
407
+ if with_hints: out["hint"] = "scale_weight keys"
408
+ return out
409
+
410
+ if fp8_variant is not None:
411
+ out = {"kind": "fp8", "quant_format": "", "fp8_format": fp8_variant}
412
+ if with_hints: out["hint"] = "weight dtype (plain fp8)"
413
+ return out
414
+
415
+ # --- Cheap metadata peek only if keys didn't decide it (no JSON parsing) ---
416
+ meta = acc.metadata() or {}
417
+ blob = " ".join(v for v in meta.values() if isinstance(v, str)).lower()
418
+
419
+ # scaled-fp8 hinted by metadata only
420
+ has_scale_map = (
421
+ any(k in meta for k in _SCALE_META_KEYS) or
422
+ (("scale" in blob) and (("fp8" in blob) or ("float8" in blob)))
423
+ )
424
+ if has_scale_map:
425
+ fmt = "e4m3fn" if "e4m3" in blob else ("e5m2" if "e5m2" in blob else "unknown")
426
+ out = {"kind": "scaled_fp8", "quant_format": "", "fp8_format": fmt}
427
+ if with_hints: out["hint"] = "metadata"
428
+ return out
429
+
430
+ # quanto hinted by metadata only (not decisive without keys)
431
+ qtype_hint = ""
432
+ for tok in ("qfloat8_e4m3fn", "qfloat8_e5m2", "qfloat8", "qint8", "qint4"):
433
+ if tok in blob:
434
+ qtype_hint = tok
435
+ break
436
+
437
+ out = {"kind": "none", "quant_format": qtype_hint, "fp8_format": ""}
438
+ if with_hints: out["hint"] = "no decisive keys"
439
+ return out
440
+
441
+ finally:
442
+ closer()
443
+
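+ # Illustrative result sketch for a checkpoint carrying the 'scaled_fp8' sentinel tensor:
+ #   detect_safetensors_format("model.safetensors", with_hints=True)
+ #   -> {'kind': 'scaled_fp8', 'quant_format': '', 'fp8_format': 'e4m3fn', 'hint': 'sentinel'}
+ # probe_weights=True permits one extra weight-dtype read when keys alone are inconclusive.
+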
444
+ # ---------- Optional Quanto runtime patch (FP32-scale support), enable/disable ----------
445
+ _patch_state = SimpleNamespace(enabled=False, orig=None, scale_index=None)
446
+ _fp8_kernel_patch_state = SimpleNamespace(enabled=False, orig_forward=None)
447
+ def enable_fp8_fp32_scale_support():
448
+ """
449
+ Version-robust wrapper for WeightQBytesTensor.create:
450
+ - matches both positional/keyword call styles via *args/**kwargs,
451
+ - for FP8 + FP32 scales, expands scalar/uniform scales with a VIEW to the needed length,
452
+ - leaves bf16/fp16 (classic Quanto) untouched.
453
+ Enable only if you emitted float32 scales.
454
+ """
455
+ if _patch_state.enabled:
456
+ return True
457
+
458
+ from optimum.quanto.tensor.weights import qbytes as _qbytes # late import
459
+ orig = _qbytes.WeightQBytesTensor.create
460
+ sig = inspect.signature(orig)
461
+ params = list(sig.parameters.keys())
462
+ scale_index = params.index("scale") if "scale" in params else 5 # fallback
463
+
464
+ def wrapper(*args, **kwargs):
465
+ # Extract fields irrespective of signature
466
+ qtype = kwargs.get("qtype", args[0] if len(args) > 0 else None)
467
+ axis = kwargs.get("axis", args[1] if len(args) > 1 else None)
468
+ size = kwargs.get("size", args[2] if len(args) > 2 else None)
469
+
470
+ if "scale" in kwargs:
471
+ scale = kwargs["scale"]
472
+ def set_scale(new): kwargs.__setitem__("scale", new)
473
+ else:
474
+ scale = args[scale_index] if len(args) > scale_index else None
475
+ def set_scale(new):
476
+ nonlocal args
477
+ args = list(args)
478
+ if len(args) > scale_index:
479
+ args[scale_index] = new
480
+ else:
481
+ kwargs["scale"] = new
482
+ args = tuple(args)
483
+
484
+ is_fp8 = isinstance(qtype, str) and ("float8" in qtype.lower() or "qfloat8" in qtype.lower()) or \
485
+ (not isinstance(qtype, str) and "float8" in str(qtype).lower())
486
+
487
+ if is_fp8 and isinstance(scale, torch.Tensor) and scale.dtype == torch.float32:
488
+ need = int(size[axis]) if (isinstance(size, (tuple, list)) and axis is not None and axis >= 0) else None
489
+ if need is not None:
490
+ if scale.numel() == 1:
491
+ scale = scale.view(1).expand(need, *scale.shape[1:])
492
+ elif scale.shape[0] != need:
493
+ # Expand if uniform; otherwise raise
494
+ uniform = (scale.numel() == 1) or (torch.numel(scale.unique()) == 1)
495
+ if uniform:
496
+ scale = scale.reshape(1, *scale.shape[1:]).expand(need, *scale.shape[1:])
497
+ else:
498
+ raise ValueError(f"Scale leading dim {scale.shape[0]} != required {need}")
499
+ set_scale(scale)
500
+
501
+ return orig(*args, **kwargs)
502
+
503
+ _qbytes.WeightQBytesTensor.create = wrapper
504
+ _patch_state.enabled = True
505
+ _patch_state.orig = orig
506
+ _patch_state.scale_index = scale_index
507
+ return True
508
+
509
+ def disable_fp8_fp32_scale_support():
510
+ """Restore Quanto's original factory."""
511
+ if not _patch_state.enabled:
512
+ return False
513
+ from optimum.quanto.tensor.weights import qbytes as _qbytes
514
+ _qbytes.WeightQBytesTensor.create = _patch_state.orig
515
+ _patch_state.enabled = False
516
+ _patch_state.orig = None
517
+ _patch_state.scale_index = None
518
+ return True
519
+
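+ # Illustrative sketch: scope the patch so it never leaks past the load that needs it.
+ #   enable_fp8_fp32_scale_support()
+ #   try:
+ #       ...  # load / requantize weights that carry float32 scales
+ #   finally:
+ #       disable_fp8_fp32_scale_support()
+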
520
+
521
+ def _quant_map_has_fp8(quant_map) -> bool:
522
+ if not quant_map:
523
+ return False
524
+ for cfg in quant_map.values():
525
+ if not isinstance(cfg, dict):
526
+ continue
527
+ weights = cfg.get("weights")
528
+ if isinstance(weights, str) and "qfloat8" in weights.lower():
529
+ return True
530
+ return False
531
+
532
+
533
+ def enable_fp8_marlin_fallback():
534
+ """
535
+ Replace Quanto's Marlin FP8 linear kernel with a plain matmul fallback.
536
+
537
+ When enabled, FP8 weights are dequantized on-the-fly and multiplied with
538
+ standard PyTorch ops, side-stepping quanto's custom CUDA kernels.
539
+ """
540
+ if _fp8_kernel_patch_state.enabled:
541
+ return True
542
+ try:
543
+ from optimum.quanto.tensor.function import QuantizedLinearFunction as _QLF
544
+ from optimum.quanto.tensor.weights.marlin.fp8 import qbits as marlin_fp8
545
+ except Exception:
546
+ return False
547
+
548
+ orig_forward = marlin_fp8.MarlinF8QBytesLinearFunction.forward
549
+
550
+ def fallback_forward(ctx, input, other, bias=None):
551
+ weight = other.dequantize()
552
+ if weight.dtype != input.dtype:
553
+ weight = weight.to(input.dtype)
554
+ weight = weight.contiguous()
555
+ return _QLF.forward(ctx, input, weight, bias)
556
+
557
+ marlin_fp8.MarlinF8QBytesLinearFunction.forward = staticmethod(fallback_forward)
558
+ _fp8_kernel_patch_state.enabled = True
559
+ _fp8_kernel_patch_state.orig_forward = orig_forward
560
+ return True
561
+
562
+
563
+ def disable_fp8_marlin_fallback():
564
+ """Restore Quanto's original Marlin FP8 kernel."""
565
+ if not _fp8_kernel_patch_state.enabled:
566
+ return False
567
+ from optimum.quanto.tensor.weights.marlin.fp8 import qbits as marlin_fp8
568
+ marlin_fp8.MarlinF8QBytesLinearFunction.forward = staticmethod(_fp8_kernel_patch_state.orig_forward)
569
+ _fp8_kernel_patch_state.enabled = False
570
+ _fp8_kernel_patch_state.orig_forward = None
571
+ return True
572
+
573
+
574
+ def maybe_enable_fp8_marlin_fallback(quantization_map=None):
575
+ """
576
+ Enable the FP8 fallback only when the provided quantization map contains FP8 weights.
577
+ """
578
+ if quantization_map is not None and not _quant_map_has_fp8(quantization_map):
579
+ return False
580
+ return enable_fp8_marlin_fallback()
581
+
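+ # Illustrative sketch: gate the fallback on the quantization map produced by conversion.
+ #   res = convert_scaled_fp8_to_quanto(sd, dtype=torch.float16, in_place=True)
+ #   maybe_enable_fp8_marlin_fallback(res.quant_map)   # returns False (no-op) without qfloat8 entries
+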
582
+
583
+ def enable_fp8_marlin_workspace_fix():
584
+ """No-op placeholder (workspace patch removed)."""
585
+ return False
586
+
587
+
588
+ def disable_fp8_marlin_workspace_fix():
589
+ """No-op placeholder (workspace patch removed)."""
590
+ return False
591
+
592
+
593
+
594
+
595
+ # ---------- Tiny CLI (optional) ----------
596
+ def _cli():
597
+ import argparse, json as _json
598
+ p = argparse.ArgumentParser("fp8_quanto_bridge")
599
+ sub = p.add_subparsers(dest="cmd", required=True)
600
+
601
+ p_conv = sub.add_parser("convert", help="Convert scaled-FP8 (file) to Quanto artifacts.")
602
+ p_conv.add_argument("in_path")
603
+ p_conv.add_argument("out_weights")
604
+ p_conv.add_argument("out_qmap")
605
+ p_conv.add_argument("--fp8-format", choices=("e4m3fn", "e5m2"), default=None)
606
+ p_conv.add_argument("--scale-dtype", default="float32",
607
+ choices=("float32","bfloat16","float16","fp32","bf16","fp16","half"))
608
+ p_conv.add_argument("--no-activation-placeholders", action="store_true")
609
+ p_conv.add_argument("--default-missing-scale", type=float, default=1.0)
610
+
611
+ p_det = sub.add_parser("detect", help="Detect format quickly (path).")
612
+ p_det.add_argument("path")
613
+ p_det.add_argument("--probe", action="store_true")
614
+ p_det.add_argument("--hints", action="store_true")
615
+
616
+ p_patch = sub.add_parser("patch", help="Enable/disable FP32-scale runtime patch.")
617
+ p_patch.add_argument("mode", choices=("enable","disable"))
618
+
619
+ args = p.parse_args()
620
+
621
+ if args.cmd == "convert":
622
+ res = convert_scaled_fp8_to_quanto(
623
+ args.in_path,
624
+ fp8_format=args.fp8_format,
625
+ dtype=args.scale_dtype,
626
+ add_activation_placeholders=not args.no_activation_placeholders,
627
+ default_missing_scale=args.default_missing_scale,
628
+ )
629
+ save_file(res.state_dict, args.out_weights)
630
+ with open(args.out_qmap, "w") as f:
631
+ _json.dump(res.quant_map, f)
632
+ print(f"Wrote: {args.out_weights} and {args.out_qmap}. Patch needed: {res.patch_needed}")
633
+ return 0
634
+
635
+ if args.cmd == "detect":
636
+ info = detect_safetensors_format(args.path, probe_weights=args.probe, with_hints=args.hints)
637
+ print(info); return 0
638
+
639
+ if args.cmd == "patch":
640
+ ok = enable_fp8_fp32_scale_support() if args.mode == "enable" else disable_fp8_fp32_scale_support()
641
+ print(f"patch {args.mode}: {'ok' if ok else 'already in that state'}")
642
+ return 0
643
+
644
+ if __name__ == "__main__":
645
+ raise SystemExit(_cli())
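+
+ # Illustrative CLI invocations (the script path is a placeholder; the parser is named "fp8_quanto_bridge"):
+ #   python fp8_quanto_bridge.py convert in.safetensors out_weights.safetensors out_qmap.json --scale-dtype fp16
+ #   python fp8_quanto_bridge.py detect in.safetensors --probe --hints
+ #   python fp8_quanto_bridge.py patch enable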