mmgp 3.3.1__py3-none-any.whl → 3.6.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mmgp/fp8_quanto_bridge.py +645 -0
- mmgp/fp8_quanto_bridge_old.py +498 -0
- mmgp/offload.py +3613 -2461
- mmgp/quant_router.py +518 -0
- mmgp/quanto_int8_cuda.py +97 -0
- mmgp/quanto_int8_inject.py +335 -0
- mmgp/safetensors2.py +534 -450
- {mmgp-3.3.1.dist-info → mmgp-3.6.11.dist-info}/METADATA +195 -197
- mmgp-3.6.11.dist-info/RECORD +14 -0
- {mmgp-3.3.1.dist-info → mmgp-3.6.11.dist-info}/WHEEL +1 -1
- mmgp-3.3.1.dist-info/RECORD +0 -9
- {mmgp-3.3.1.dist-info → mmgp-3.6.11.dist-info}/licenses/LICENSE.md +0 -0
- {mmgp-3.3.1.dist-info → mmgp-3.6.11.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,498 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import json, re, inspect
|
|
3
|
+
from types import SimpleNamespace
|
|
4
|
+
from typing import Dict, Optional, Tuple, Union, Iterable, Callable
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from safetensors.torch import safe_open, save_file
|
|
8
|
+
|
|
9
|
+
# ---------- Constants ----------
# Suffixes appended to a '<base>.weight' key in the emitted Quanto-style state dict.
DATA_SUFFIX = "._data"  # FP8 payload copied from the source weight
SCALE_SUFFIX = "._scale"  # per-channel, shape [out, 1, ...]
# Suffixes appended to '<base>' for the optional activation-scale placeholders.
IN_SCALE = ".input_scale"  # 1-D placeholder tensor [1]
OUT_SCALE = ".output_scale"  # 1-D placeholder tensor [1]

# FP8 variant tag -> Quanto weight qtype name written into the quantization map.
_QTYPE_NAME = dict(
    e4m3fn="qfloat8_e4m3fn",
    e5m2="qfloat8_e5m2",
    auto="qfloat8",
)

# Metadata keys consulted first, in this order, when looking for a JSON scale map.
_SCALE_META_KEYS = (
    "fp8_scale_map",
    "fp8.scale_map",
    "scale_map",
    "quant_scale_map",
    "weights_scales",
    "scales",
)

# Accepted (lower-case) spellings for the scale dtype; insertion order is kept
# because the key list is shown in the ValueError raised for unknown names.
_DTYPE_ALIASES = {
    alias: resolved
    for resolved, spellings in (
        (torch.float32, ("float32", "fp32")),
        (torch.bfloat16, ("bfloat16", "bf16")),
        (torch.float16, ("float16", "fp16", "half")),
    )
    for alias in spellings
}
|
|
31
|
+
|
|
32
|
+
def _is_weight_key(k: str) -> bool:
    """Return True when the state-dict key *k* ends with '.weight'."""
    weight_suffix = ".weight"
    return k.endswith(weight_suffix)
|
|
34
|
+
|
|
35
|
+
# ---------- Accessors (unify file vs dict) ----------
class Accessor:
    """Common read interface over a tensor source (file-backed or dict-backed).

    The base implementations are stubs: the query methods return None,
    ``can_delete`` reports False, and ``delete`` raises NotImplementedError.
    """

    def keys(self) -> Iterable[str]:
        """Return every tensor key (stub: returns None)."""

    def get_tensor(self, key: str) -> torch.Tensor:
        """Return the tensor stored under *key* (stub: returns None)."""

    def metadata(self) -> Dict[str, str]:
        """Return the string metadata mapping (stub: returns None)."""

    def has(self, key: str) -> bool:
        """Return whether *key* exists (stub: returns None)."""

    def can_delete(self) -> bool:
        """Report whether ``delete`` may be used; the base never allows it."""
        return False

    def delete(self, key: str) -> None:
        """Remove *key* from the source; not supported by the base class."""
        raise NotImplementedError
|
|
43
|
+
|
|
44
|
+
class FileAccessor(Accessor):
    """Accessor over a safetensors file opened with ``safe_open(..., framework="pt")``.

    The key list, a key set (for O(1) ``has``) and the header metadata are read
    once in ``__init__``; tensors are fetched on demand via ``get_tensor``.
    """

    def __init__(self, path: str):
        self._fh = safe_open(path, framework="pt")
        self._keys = list(self._fh.keys())
        self._keys_set = set(self._keys)  # O(1) membership
        self._meta = self._fh.metadata() or {}

    def keys(self) -> Iterable[str]:
        return self._keys

    def has(self, key: str) -> bool:
        return key in self._keys_set

    def get_tensor(self, key: str) -> torch.Tensor:
        return self._fh.get_tensor(key)

    def metadata(self) -> Dict[str, str]:
        return self._meta

    def close(self) -> None:
        """Release the underlying safetensors handle.

        NOTE(review): the ``safe_open`` handle may not expose a ``close()`` method
        in every safetensors release (it is a context manager). Fall back to its
        ``__exit__`` so this cleanup, which runs from ``finally`` blocks, never
        raises AttributeError and masks the caller's result.
        """
        fh = self._fh
        close = getattr(fh, "close", None)
        if callable(close):
            close()
            return
        exit_ = getattr(fh, "__exit__", None)
        if callable(exit_):
            exit_(None, None, None)
|
|
55
|
+
|
|
56
|
+
class DictAccessor(Accessor):
    """Accessor over an in-memory ``{key: tensor}`` dict, optionally consumed in place.

    Args:
        sd: the source dict; ``delete`` removes entries from it.
        meta: optional string metadata returned by ``metadata()`` (``{}`` if None).
        in_place: value reported by ``can_delete()``.
        free_cuda_cache: call ``torch.cuda.empty_cache()`` every
            *cuda_cache_interval* deletions (only when CUDA is available).
        cuda_cache_interval: deletions between cache flushes; values below 1 are
            treated as 1.
    """

    def __init__(self, sd: Dict[str, torch.Tensor], meta: Optional[Dict[str, str]] = None,
                 in_place: bool = False, free_cuda_cache: bool = False, cuda_cache_interval: int = 32):
        self.sd = sd
        self._meta = meta or {}
        self._in_place = in_place
        self._free = free_cuda_cache
        # Clamp to >= 1: the flush check in delete() is a modulo, so an interval
        # of 0 would raise ZeroDivisionError on the first deletion.
        self._interval = max(1, int(cuda_cache_interval))
        self._deletions = 0

    def keys(self) -> Iterable[str]:
        # Returns a copy, so deleting entries while iterating the result is safe.
        return list(self.sd.keys())

    def has(self, key: str) -> bool:
        return key in self.sd  # dict membership = O(1)

    def get_tensor(self, key: str) -> torch.Tensor:
        return self.sd[key]

    def metadata(self) -> Dict[str, str]:
        return self._meta

    def can_delete(self) -> bool:
        return self._in_place

    def delete(self, key: str) -> None:
        """Drop *key* from the dict (no-op if absent), periodically flushing the CUDA cache."""
        if key not in self.sd:
            return
        del self.sd[key]
        self._deletions += 1
        if self._free and self._deletions % self._interval == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
76
|
+
def _as_accessor(src: Union[str, Dict[str, torch.Tensor]], **dict_opts) -> Tuple[Accessor, Callable[[], None]]:
    """Wrap *src* in an Accessor and pair it with a cleanup callable.

    A ``str`` is treated as a safetensors path: a FileAccessor is returned with
    its ``close`` method as the cleanup, and *dict_opts* are ignored. Any other
    value is passed to DictAccessor together with *dict_opts*, and the cleanup
    is a no-op.
    """
    if not isinstance(src, str):
        return DictAccessor(src, **dict_opts), (lambda: None)
    file_acc = FileAccessor(src)
    return file_acc, file_acc.close
|
|
82
|
+
|
|
83
|
+
# ---------- Shared helpers ----------
def _normalize_scale_dtype(scale_dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Resolve *scale_dtype* to a torch.dtype.

    A torch.dtype is returned unchanged; anything else is stringified,
    lower-cased and looked up in ``_DTYPE_ALIASES``.

    Raises:
        ValueError: if the lower-cased name is not a known alias.
    """
    if not isinstance(scale_dtype, torch.dtype):
        alias = str(scale_dtype).lower()
        resolved = _DTYPE_ALIASES.get(alias)
        if resolved is None:
            raise ValueError(f"scale_dtype must be one of {list(_DTYPE_ALIASES.keys())} or a torch.dtype")
        return resolved
    return scale_dtype
|
|
91
|
+
|
|
92
|
+
def _json_to_dict(s: str) -> Optional[Dict]:
    """Decode *s* as JSON and return it only when the top-level value is an object.

    Returns None when *s* is not valid JSON, or when it decodes to a non-dict
    value (list, number, string, ...). This keeps the result consistent with the
    function name and its ``Optional[Dict]`` annotation. Only
    ``json.JSONDecodeError`` is caught; other exceptions propagate.
    """
    try:
        parsed = json.loads(s)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None
|
|
98
|
+
|
|
99
|
+
def _maybe_parse_scale_map(meta: Dict[str, str]) -> Optional[Dict[str, float]]:
    """Look for a JSON ``{weight_key: scale}`` map inside safetensors string metadata.

    Candidates are the values stored under ``_SCALE_META_KEYS`` (in that order),
    followed by every metadata value. A candidate counts only if it is a string
    that starts with '{' and ends with '}' and decodes through ``_json_to_dict``.
    The first candidate yielding a non-empty map wins; None if none does.
    """

    def _extract(node) -> Optional[Dict[str, float]]:
        # Accept entries of the form key: number or key: {"scale": number | [number]}
        # at this level. If none qualify, descend into common wrapper keys.
        if not isinstance(node, dict):
            return None
        found: Dict[str, float] = {}
        for name, val in node.items():
            if isinstance(val, (int, float)):
                found[name] = float(val)
                continue
            if not (isinstance(val, dict) and "scale" in val):
                continue
            scale = val["scale"]
            if isinstance(scale, (int, float)):
                found[name] = float(scale)
            elif isinstance(scale, (list, tuple)) and len(scale) == 1 and isinstance(scale[0], (int, float)):
                found[name] = float(scale[0])
        if found:
            return found
        for wrapper_key in ("weights", "tensors", "params", "map"):
            inner = node.get(wrapper_key)
            if isinstance(inner, dict):
                hit = _extract(inner)
                if hit:
                    return hit
        return None

    # Exact keys first, then a loose scan of every metadata value.
    candidates = [meta.get(name) for name in _SCALE_META_KEYS]
    candidates.extend(meta.values())

    for raw in candidates:
        if not (isinstance(raw, str) and raw.startswith("{") and raw.endswith("}")):
            continue
        hit = _extract(_json_to_dict(raw))
        if hit:
            return hit
    return None
|
|
142
|
+
|
|
143
|
+
def _quick_fp8_variant_from_sentinel(acc: Accessor) -> Optional[str]:
|
|
144
|
+
if "scaled_fp8" in set(acc.keys()):
|
|
145
|
+
dt = acc.get_tensor("scaled_fp8").dtype
|
|
146
|
+
if dt == torch.float8_e4m3fn: return "e4m3fn"
|
|
147
|
+
if dt == torch.float8_e5m2: return "e5m2"
|
|
148
|
+
return None
|
|
149
|
+
|
|
150
|
+
def _per_channel_reshape(vec: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """View *vec* as ``[weight.shape[0], 1, ..., 1]`` with the same rank as *weight*.

    Uses ``Tensor.view``, so no data is copied and *vec* must be viewable in that shape.
    """
    trailing_ones = (1,) * (weight.ndim - 1)
    return vec.view((weight.shape[0],) + trailing_ones)
|
|
152
|
+
|
|
153
|
+
# ---------- Unified converter ----------
class ConvertResult(Dict[str, object]):
    """Dict returned by ``convert_scaled_fp8_to_quanto``, with attribute-style accessors.

    Each property reads the same-named dict entry and raises KeyError if it is missing.
    """

    @property
    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Converted tensors, i.e. ``self["state_dict"]``."""
        return self["state_dict"]  # type: ignore

    @property
    def quant_map(self) -> Dict[str, Dict]:
        """Per-module quantization description, i.e. ``self["quant_map"]``."""
        return self["quant_map"]  # type: ignore

    @property
    def fp8_format(self) -> str:
        """FP8 variant tag that was used, i.e. ``self["fp8_format"]``."""
        return self["fp8_format"]  # type: ignore

    @property
    def patch_needed(self) -> bool:
        """Whether the FP32-scale runtime patch is required, i.e. ``self["patch_needed"]``."""
        return self["patch_needed"]  # type: ignore
|
|
163
|
+
|
|
164
|
+
def convert_scaled_fp8_to_quanto(
    src: Union[str, Dict[str, torch.Tensor]],
    *,
    fp8_format: Optional[str] = None,  # 'e4m3fn' | 'e5m2' | None (auto)
    require_scale: bool = False,
    allow_default_scale: bool = True,
    default_missing_scale: float = 1.0,
    dtype: Union[str, torch.dtype] = "float32",
    add_activation_placeholders: bool = True,
    # dict mode options
    sd_metadata: Optional[Dict[str, str]] = None,
    in_place: bool = False,
    free_cuda_cache: bool = False,
    cuda_cache_interval: int = 32,
) -> ConvertResult:
    """Rewrite a scaled-FP8 checkpoint into Quanto-style tensors plus a quantization map.

    For every '<base>.weight' stored as float8_e4m3fn / float8_e5m2:
      * the FP8 tensor is emitted unchanged as '<base>.weight._data';
      * a per-output-channel scale shaped [out, 1, ...] in the resolved *dtype*
        is emitted as '<base>.weight._scale'. It is taken from
        '<base>.scale_weight', else from a JSON scale map found in the metadata,
        else *default_missing_scale* is used;
      * when *add_activation_placeholders* is set, [1]-shaped
        '<base>.input_scale' and '<base>.output_scale' tensors are added;
      * ``quant_map['<base>']`` records the Quanto weight qtype.
    The 'scaled_fp8' sentinel and every '*.scale_weight' key are dropped. All
    other tensors are copied through. Floating-point ones are cast to the
    resolved dtype unless they already have it or are float32; other tensors
    (integer/bool buffers) keep their dtype.

    Args:
        src: safetensors path, or an in-memory state dict.
        fp8_format: force the variant. Otherwise it is taken from the sentinel
            dtype, then from the first FP8 '.weight' tensor, falling back to 'auto'.
        require_scale, allow_default_scale: a missing scale raises KeyError only
            when require_scale is True and allow_default_scale is False.
        default_missing_scale: scale value used when no scale is found.
        dtype: scale dtype, as a torch.dtype or an alias accepted by
            ``_normalize_scale_dtype``.
        add_activation_placeholders: emit the input/output scale placeholders.
        sd_metadata, in_place, free_cuda_cache, cuda_cache_interval: dict-mode
            options forwarded to the dict accessor. With in_place, every tensor
            read (including '*.scale_weight' tensors) is removed from *src*.

    Returns:
        ConvertResult with 'state_dict', 'quant_map', 'fp8_format' and
        'patch_needed' (True when the scale dtype is float32).

    Raises:
        ValueError: unknown *dtype* alias, or a '.scale_weight' tensor whose
            length matches neither 1 nor the output-channel count and is not uniform.
        KeyError: missing scale under require_scale=True, allow_default_scale=False.
    """
    sd_scale_dtype = _normalize_scale_dtype(dtype)
    patch_needed = sd_scale_dtype == torch.float32

    acc, closer = _as_accessor(
        src,
        meta=sd_metadata,
        in_place=in_place,
        free_cuda_cache=free_cuda_cache,
        cuda_cache_interval=cuda_cache_interval,
    )
    if not acc.can_delete():
        in_place = False
    try:
        meta = acc.metadata() or {}
        meta_scale_map = _maybe_parse_scale_map(meta) or {}

        keys = list(acc.keys())
        key_set = set(keys)  # O(1) membership for the scale-key pairing below

        # FP8 variant: explicit argument -> sentinel -> first FP8 weight -> 'auto'
        fmt = fp8_format or _quick_fp8_variant_from_sentinel(acc)
        if fmt is None:
            for wk in keys:
                if not _is_weight_key(wk):
                    continue
                dt = acc.get_tensor(wk).dtype
                if dt == torch.float8_e4m3fn:
                    fmt = "e4m3fn"
                    break
                if dt == torch.float8_e5m2:
                    fmt = "e5m2"
                    break
        if fmt is None:
            fmt = "auto"

        # Map '<base>.weight' -> '<base>.scale_weight' for weights that have a scale tensor.
        scale_weight_map: Dict[str, str] = {}
        for sk in keys:
            if sk.endswith(".scale_weight"):
                wk = sk[: -len(".scale_weight")] + ".weight"
                if wk in key_set:
                    scale_weight_map[wk] = sk

        def get_scale_vec_for_weight(wk: str, out_ch: int) -> Optional[torch.Tensor]:
            """Return a float32 [out_ch] scale vector for *wk*, or None if none is known."""
            # 1) explicit '<base>.scale_weight' tensor
            sk = scale_weight_map.get(wk)
            if sk is not None:
                s_t = acc.get_tensor(sk).to(torch.float32)
                if in_place:
                    # Delete by key so the source dict actually releases the scale tensor.
                    acc.delete(sk)
                if s_t.numel() == 1:
                    return torch.full((out_ch,), float(s_t.item()), dtype=torch.float32)
                if s_t.numel() == out_ch:
                    return s_t.reshape(out_ch)
                if torch.numel(s_t.unique()) == 1:
                    # reshape (not view): s_t may be non-contiguous.
                    return torch.full((out_ch,), float(s_t.reshape(-1)[0].item()), dtype=torch.float32)
                raise ValueError(f"Unexpected scale length for '{wk}': {s_t.numel()} (out_ch={out_ch})")
            # 2) metadata scale map: exact key, then normalized variants
            for cand in (wk, wk.replace("model.", ""), re.sub(r"(^|\.)weight$", "", wk)):
                if cand in meta_scale_map:
                    return torch.full((out_ch,), float(meta_scale_map[cand]), dtype=torch.float32)
            return None

        out_sd: Dict[str, torch.Tensor] = {}
        qmap: Dict[str, Dict] = {}
        fp8_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
        keep_dtypes = (sd_scale_dtype, torch.float32)

        # Single pass: rewrite FP8 weights, copy-through others
        for k in keys:
            # Drop source-only artifacts
            if k == "scaled_fp8" or k.endswith(".scale_weight"):
                continue

            t = acc.get_tensor(k)
            if in_place:
                acc.delete(k)

            if not (_is_weight_key(k) and t.dtype in fp8_dtypes):
                # Copy-through. Compare/cast against the *resolved* dtype: the raw
                # ``dtype`` argument may be a string alias, and Tensor.to("float32")
                # raises. Non-floating buffers keep their dtype.
                if t.is_floating_point() and t.dtype not in keep_dtypes:
                    t = t.to(sd_scale_dtype)
                out_sd[k] = t
                continue

            # Quantized: keep original FP8 tensor as _data
            out_sd[k + DATA_SUFFIX] = t

            out_ch = int(t.shape[0])
            s_vec = get_scale_vec_for_weight(k, out_ch)
            if s_vec is None:
                if require_scale and not allow_default_scale:
                    raise KeyError(f"No scale found for '{k}' (looked for '.scale_weight' and metadata).")
                s_vec = torch.full((out_ch,), float(default_missing_scale), dtype=torch.float32)
            out_sd[k + SCALE_SUFFIX] = _per_channel_reshape(s_vec, t).to(sd_scale_dtype)

            base = k[: -len(".weight")]
            if add_activation_placeholders:
                out_sd[base + IN_SCALE] = torch.tensor([1], dtype=sd_scale_dtype)
                out_sd[base + OUT_SCALE] = torch.tensor([1], dtype=sd_scale_dtype)
            qmap[base] = {"weights": _QTYPE_NAME[fmt], "activations": "none"}

        return ConvertResult(state_dict=out_sd, quant_map=qmap, fp8_format=fmt, patch_needed=patch_needed)
    finally:
        closer()
|
|
274
|
+
|
|
275
|
+
def detect_safetensors_format(
    src: Union[str, Dict[str, torch.Tensor]],
    *,
    sd_metadata: Optional[Dict[str, str]] = None,
    probe_weights: bool = False,  # dtype probe stops after 1 FP8 weight, or 2 when True
    with_hints: bool = False,
) -> Dict[str, str]:
    """
    Classify *src* (safetensors path or state dict) by its quantization layout.

    Checks, in order:
      1. a 'scaled_fp8' sentinel tensor;
      2. any '*._data' key (Quanto pack);
      3. any '*.scale_weight' key (scaled FP8);
      4. an FP8 dtype on a '.weight' tensor (plain FP8);
      5. plain-text hints in the metadata values;
      6. otherwise 'none'.

    NOTE(review): only FP8 hits count against the dtype-probe budget, so on a
    checkpoint without FP8 weights every '.weight' tensor is loaded once.

    Returns:
        {
          'kind': 'quanto' | 'scaled_fp8' | 'fp8' | 'none',
          'quant_format': 'qfloat8_e4m3fn' | 'qfloat8_e5m2' | 'qfloat8' | 'qint8' | 'qint4' | 'unknown' | '',
          'fp8_format': 'e4m3fn' | 'e5m2' | 'unknown' | '',
          'hint': '...'   # only when with_hints=True
        }
    """
    acc, closer = _as_accessor(src, meta=sd_metadata, in_place=False)
    try:
        # --- O(1) sentinel test up-front (no key scan) ---
        if acc.has("scaled_fp8"):
            dt = acc.get_tensor("scaled_fp8").dtype
            if dt == torch.float8_e4m3fn:
                fp8_fmt = "e4m3fn"
            elif dt == torch.float8_e5m2:
                fp8_fmt = "e5m2"
            else:
                fp8_fmt = "unknown"
            out = {"kind": "scaled_fp8", "quant_format": "", "fp8_format": fp8_fmt}
            if with_hints:
                out["hint"] = "sentinel"
            return out

        ks = list(acc.keys())

        # --- Quanto pack: any '*._data' key is decisive ---
        if any(k.endswith(DATA_SUFFIX) for k in ks):
            out = {"kind": "quanto", "quant_format": "qfloat8", "fp8_format": ""}
            if with_hints:
                out["hint"] = "keys:*._data"
            return out

        # --- Scale keys + bounded FP8 dtype probe in one pass ---
        has_scale_weight = False
        fp8_variant = None
        fp8_probe_budget = 2 if probe_weights else 1
        for k in ks:
            if not has_scale_weight and k.endswith(".scale_weight"):
                has_scale_weight = True
            if fp8_probe_budget > 0 and _is_weight_key(k):
                dt = acc.get_tensor(k).dtype
                if dt == torch.float8_e4m3fn:
                    fp8_variant = "e4m3fn"
                    fp8_probe_budget -= 1
                elif dt == torch.float8_e5m2:
                    fp8_variant = "e5m2"
                    fp8_probe_budget -= 1
            if has_scale_weight and fp8_probe_budget == 0:
                # Both signals are final; the remaining keys cannot change the result.
                break

        if has_scale_weight:
            out = {"kind": "scaled_fp8", "quant_format": "", "fp8_format": fp8_variant or "unknown"}
            if with_hints:
                out["hint"] = "scale_weight keys"
            return out

        if fp8_variant is not None:
            out = {"kind": "fp8", "quant_format": "", "fp8_format": fp8_variant}
            if with_hints:
                out["hint"] = "weight dtype (plain fp8)"
            return out

        # --- Cheap metadata peek only if keys didn't decide it (no JSON parsing) ---
        meta = acc.metadata() or {}
        blob = " ".join(v for v in meta.values() if isinstance(v, str)).lower()

        # scaled-fp8 hinted by metadata only
        has_scale_map = (
            any(k in meta for k in _SCALE_META_KEYS) or
            (("scale" in blob) and (("fp8" in blob) or ("float8" in blob)))
        )
        if has_scale_map:
            fmt = "e4m3fn" if "e4m3" in blob else ("e5m2" if "e5m2" in blob else "unknown")
            out = {"kind": "scaled_fp8", "quant_format": "", "fp8_format": fmt}
            if with_hints:
                out["hint"] = "metadata"
            return out

        # quanto hinted by metadata only (not decisive without keys)
        qtype_hint = next(
            (tok for tok in ("qfloat8_e4m3fn", "qfloat8_e5m2", "qfloat8", "qint8", "qint4") if tok in blob),
            "",
        )

        out = {"kind": "none", "quant_format": qtype_hint, "fp8_format": ""}
        if with_hints:
            out["hint"] = "no decisive keys"
        return out

    finally:
        closer()
|
|
371
|
+
|
|
372
|
+
# ---------- Optional Quanto runtime patch (FP32-scale support), enable/disable ----------
# Module-level record of the runtime patch:
#   enabled     - True while WeightQBytesTensor.create is replaced by the wrapper
#   orig        - the factory captured when the patch was enabled (None otherwise)
#   scale_index - positional index used for the 'scale' argument (None otherwise)
_patch_state = SimpleNamespace(
    enabled=False,
    orig=None,
    scale_index=None,
)
|
|
374
|
+
|
|
375
|
+
def enable_fp8_fp32_scale_support():
    """
    Version-robust wrapper for WeightQBytesTensor.create:
    - matches both positional/keyword call styles via *args/**kwargs,
    - for FP8 + FP32 scales, expands scalar/uniform scales with a VIEW to the needed length,
    - leaves bf16/fp16 (classic Quanto) untouched.
    Enable only if you emitted float32 scales.

    Imports ``optimum.quanto`` lazily, replaces ``WeightQBytesTensor.create`` with
    the wrapper, and records the original factory in ``_patch_state``. Returns
    True, including when the patch is already active.

    The wrapper raises ValueError when a float32 FP8 scale has a leading
    dimension different from ``size[axis]`` and its values are not uniform.
    """
    if _patch_state.enabled:
        return True

    from optimum.quanto.tensor.weights import qbytes as _qbytes  # late import

    orig = _qbytes.WeightQBytesTensor.create
    params = list(inspect.signature(orig).parameters.keys())
    # Positional slot of 'scale'; 5 is the fallback when the signature has no 'scale' name.
    scale_index = params.index("scale") if "scale" in params else 5

    def wrapper(*args, **kwargs):
        # Extract fields irrespective of signature
        qtype = kwargs.get("qtype", args[0] if len(args) > 0 else None)
        axis = kwargs.get("axis", args[1] if len(args) > 1 else None)
        size = kwargs.get("size", args[2] if len(args) > 2 else None)

        scale_in_kwargs = "scale" in kwargs
        if scale_in_kwargs:
            scale = kwargs["scale"]
        else:
            scale = args[scale_index] if len(args) > scale_index else None

        # Matches qtype given either as a string (e.g. "qfloat8_e4m3fn") or as an
        # object whose str() mentions float8.
        is_fp8 = "float8" in str(qtype).lower()

        if is_fp8 and isinstance(scale, torch.Tensor) and scale.dtype == torch.float32:
            need = int(size[axis]) if (isinstance(size, (tuple, list)) and axis is not None and axis >= 0) else None
            if need is not None:
                if scale.numel() == 1:
                    scale = scale.view(1).expand(need, *scale.shape[1:])
                elif scale.shape[0] != need:
                    # Expand if uniform; otherwise raise. Broadcast the first leading
                    # slice (shape [1, ...]): reshaping the whole tensor to [1, ...]
                    # fails whenever scale.shape[0] > 1 (element counts differ).
                    if torch.numel(scale.unique()) == 1:
                        scale = scale[:1].expand(need, *scale.shape[1:])
                    else:
                        raise ValueError(f"Scale leading dim {scale.shape[0]} != required {need}")
                if scale_in_kwargs:
                    kwargs["scale"] = scale
                else:
                    # scale is a tensor here, so it came from args[scale_index].
                    args = args[:scale_index] + (scale,) + args[scale_index + 1:]

        return orig(*args, **kwargs)

    _qbytes.WeightQBytesTensor.create = wrapper
    _patch_state.enabled = True
    _patch_state.orig = orig
    _patch_state.scale_index = scale_index
    return True
|
|
436
|
+
|
|
437
|
+
def disable_fp8_fp32_scale_support():
    """Restore Quanto's original factory.

    Returns False if the patch is not active. Otherwise it reinstalls the
    factory saved in ``_patch_state.orig``, clears the saved state and returns True.
    """
    if _patch_state.enabled:
        from optimum.quanto.tensor.weights import qbytes as _qbytes

        _qbytes.WeightQBytesTensor.create = _patch_state.orig
        _patch_state.enabled, _patch_state.orig, _patch_state.scale_index = False, None, None
        return True
    return False
|
|
447
|
+
|
|
448
|
+
# ---------- Tiny CLI (optional) ----------
def _cli():
    """Command-line entry point with 'convert', 'detect' and 'patch' subcommands.

    Returns 0 once the selected subcommand has run; the value becomes the exit code.
    """
    import argparse
    import json as _json

    parser = argparse.ArgumentParser("fp8_quanto_bridge")
    commands = parser.add_subparsers(dest="cmd", required=True)

    convert_p = commands.add_parser("convert", help="Convert scaled-FP8 (file) to Quanto artifacts.")
    for positional in ("in_path", "out_weights", "out_qmap"):
        convert_p.add_argument(positional)
    convert_p.add_argument("--fp8-format", choices=("e4m3fn", "e5m2"), default=None)
    convert_p.add_argument(
        "--scale-dtype",
        default="float32",
        choices=("float32", "bfloat16", "float16", "fp32", "bf16", "fp16", "half"),
    )
    convert_p.add_argument("--no-activation-placeholders", action="store_true")
    convert_p.add_argument("--default-missing-scale", type=float, default=1.0)

    detect_p = commands.add_parser("detect", help="Detect format quickly (path).")
    detect_p.add_argument("path")
    detect_p.add_argument("--probe", action="store_true")
    detect_p.add_argument("--hints", action="store_true")

    patch_p = commands.add_parser("patch", help="Enable/disable FP32-scale runtime patch.")
    patch_p.add_argument("mode", choices=("enable", "disable"))

    opts = parser.parse_args()

    if opts.cmd == "convert":
        result = convert_scaled_fp8_to_quanto(
            opts.in_path,
            fp8_format=opts.fp8_format,
            dtype=opts.scale_dtype,
            add_activation_placeholders=not opts.no_activation_placeholders,
            default_missing_scale=opts.default_missing_scale,
        )
        save_file(result.state_dict, opts.out_weights)
        with open(opts.out_qmap, "w") as qmap_file:
            _json.dump(result.quant_map, qmap_file)
        print(f"Wrote: {opts.out_weights} and {opts.out_qmap}. Patch needed: {result.patch_needed}")
        return 0

    if opts.cmd == "detect":
        print(detect_safetensors_format(opts.path, probe_weights=opts.probe, with_hints=opts.hints))
        return 0

    if opts.cmd == "patch":
        if opts.mode == "enable":
            ok = enable_fp8_fp32_scale_support()
        else:
            ok = disable_fp8_fp32_scale_support()
        print(f"patch {opts.mode}: {'ok' if ok else 'already in that state'}")
        return 0
|
|
496
|
+
|
|
497
|
+
if __name__ == "__main__":
    # Run the CLI and use its return value as the process exit status.
    _exit_code = _cli()
    raise SystemExit(_exit_code)
|