mmgp 3.6.1__tar.gz → 3.6.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mmgp might be problematic; see the package's registry page for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mmgp
3
- Version: 3.6.1
3
+ Version: 3.6.2
4
4
  Summary: Memory Management for the GPU Poor
5
5
  Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
6
6
  Requires-Python: >=3.10
@@ -15,7 +15,7 @@ Dynamic: license-file
15
15
 
16
16
 
17
17
  <p align="center">
18
- <H2>Memory Management 3.6.1 for the GPU Poor by DeepBeepMeep</H2>
18
+ <H2>Memory Management 3.6.2 for the GPU Poor by DeepBeepMeep</H2>
19
19
  </p>
20
20
 
21
21
 
@@ -1,6 +1,6 @@
1
1
 
2
2
  <p align="center">
3
- <H2>Memory Management 3.6.1 for the GPU Poor by DeepBeepMeep</H2>
3
+ <H2>Memory Management 3.6.2 for the GPU Poor by DeepBeepMeep</H2>
4
4
  </p>
5
5
 
6
6
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "mmgp"
3
- version = "3.6.1"
3
+ version = "3.6.2"
4
4
  authors = [
5
5
  { name = "deepbeepmeep", email = "deepbeepmeep@yahoo.com" },
6
6
  ]
@@ -0,0 +1,504 @@
1
from __future__ import annotations

import functools
import inspect
import json
import os
import re
from types import SimpleNamespace
from typing import Callable, Dict, Iterable, Optional, Tuple, Union

import torch
from safetensors.torch import safe_open, save_file


+ # ---------- Constants ----------
10
+ DATA_SUFFIX = "._data"
11
+ SCALE_SUFFIX = "._scale" # per-channel, shape [out, 1, ...]
12
+ IN_SCALE = ".input_scale" # 1-D placeholder tensor [1]
13
+ OUT_SCALE = ".output_scale" # 1-D placeholder tensor [1]
14
+
15
+ _QTYPE_NAME = {
16
+ "e4m3fn": "qfloat8_e4m3fn",
17
+ "e5m2": "qfloat8_e5m2",
18
+ "auto": "qfloat8",
19
+ }
20
+
21
+ _SCALE_META_KEYS = (
22
+ "fp8_scale_map", "fp8.scale_map", "scale_map",
23
+ "quant_scale_map", "weights_scales", "scales",
24
+ )
25
+
26
+ _DTYPE_ALIASES = {
27
+ "float32": torch.float32, "fp32": torch.float32,
28
+ "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
29
+ "float16": torch.float16, "fp16": torch.float16, "half": torch.float16,
30
+ }
31
+
32
+ def _is_weight_key(k: str) -> bool:
33
+ return k.endswith(".weight")
34
+
35
# ---------- Accessors (unify file vs dict) ----------
class Accessor:
    """Uniform read (and optional delete) interface over a tensor source.

    Subclasses must implement ``keys``, ``get_tensor``, ``metadata`` and ``has``.
    The base versions raise NotImplementedError instead of silently returning
    None. Deletion is opt-in: ``can_delete`` is False here and ``delete`` raises.
    """

    def keys(self) -> Iterable[str]:
        """Return the tensor names available in the source."""
        raise NotImplementedError

    def get_tensor(self, key: str) -> torch.Tensor:
        """Return the tensor stored under *key*."""
        raise NotImplementedError

    def metadata(self) -> Dict[str, str]:
        """Return the string-to-string metadata attached to the source."""
        raise NotImplementedError

    def has(self, key: str) -> bool:
        """Return True if *key* exists in the source."""
        raise NotImplementedError

    def can_delete(self) -> bool:
        """Return True if ``delete`` may be used to drop entries from the source."""
        return False

    def delete(self, key: str) -> None:
        """Drop *key* from the source (unsupported by default)."""
        raise NotImplementedError


class FileAccessor(Accessor):
    """Accessor over a ``.safetensors`` file opened with ``safe_open`` (PyTorch tensors)."""

    def __init__(self, path: str):
        self._fh = safe_open(path, framework="pt")
        self._keys = list(self._fh.keys())
        self._keys_set = set(self._keys)  # O(1) membership for has()
        self._meta = self._fh.metadata() or {}

    def keys(self) -> Iterable[str]:
        return self._keys

    def has(self, key: str) -> bool:
        return key in self._keys_set

    def get_tensor(self, key: str) -> torch.Tensor:
        return self._fh.get_tensor(key)

    def metadata(self) -> Dict[str, str]:
        return self._meta

    def close(self) -> None:
        """Release the underlying ``safe_open`` handle; safe to call more than once.

        ``safe_open`` handles are context managers. Calling ``close()`` on them
        is not guaranteed to exist, so fall back to ``__exit__`` when it is
        missing rather than raising AttributeError from the caller's ``finally``.
        """
        fh, self._fh = self._fh, None
        if fh is None:
            return
        close = getattr(fh, "close", None)
        if callable(close):
            close()
        elif hasattr(fh, "__exit__"):
            fh.__exit__(None, None, None)


class DictAccessor(Accessor):
    """Accessor over an in-memory state dict.

    Args:
        sd: the state dict to read (and, when ``in_place``, to delete from).
        meta: metadata to report from ``metadata()``; defaults to ``{}``.
        in_place: when True, ``can_delete()`` is True and ``delete`` removes
            entries from *sd* itself.
        free_cuda_cache: call ``torch.cuda.empty_cache()`` periodically while deleting.
        cuda_cache_interval: number of deletions between cache flushes
            (values below 1 are treated as 1).
    """

    def __init__(self, sd: Dict[str, torch.Tensor], meta: Optional[Dict[str, str]] = None,
                 in_place: bool = False, free_cuda_cache: bool = False, cuda_cache_interval: int = 32):
        self.sd = sd
        self._meta = meta or {}
        self._in_place = in_place
        self._free = free_cuda_cache
        # Clamp to >= 1: an interval of 0 would make the modulo in delete() divide by zero.
        self._interval = max(1, int(cuda_cache_interval))
        self._deletions = 0

    def keys(self) -> Iterable[str]:
        # Snapshot, so callers can delete entries while iterating.
        return list(self.sd.keys())

    def has(self, key: str) -> bool:
        return key in self.sd  # dict membership is O(1)

    def get_tensor(self, key: str) -> torch.Tensor:
        return self.sd[key]

    def metadata(self) -> Dict[str, str]:
        return self._meta

    def can_delete(self) -> bool:
        return self._in_place

    def delete(self, key: str) -> None:
        """Remove *key* from the dict (no-op if absent), flushing the CUDA cache every N deletions."""
        if key not in self.sd:
            return
        del self.sd[key]
        self._deletions += 1
        if self._free and self._deletions % self._interval == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()

def _as_accessor(src: Union[str, os.PathLike, Dict[str, torch.Tensor]], **dict_opts) -> Tuple[Accessor, Callable[[], None]]:
    """Wrap *src* in an Accessor and return it together with a close callback.

    A filesystem path (``str`` or any ``os.PathLike`` such as ``pathlib.Path``)
    is opened with ``FileAccessor``; the callback closes it, and *dict_opts*
    are ignored. Anything else is treated as a state dict and wrapped as
    ``DictAccessor(src, **dict_opts)``, with a no-op callback.
    """
    if isinstance(src, (str, os.PathLike)):
        acc = FileAccessor(os.fspath(src))
        return acc, acc.close
    return DictAccessor(src, **dict_opts), (lambda: None)


+ # ---------- Shared helpers ----------
84
+ def _normalize_scale_dtype(scale_dtype: Union[str, torch.dtype]) -> torch.dtype:
85
+ if isinstance(scale_dtype, torch.dtype):
86
+ return scale_dtype
87
+ key = str(scale_dtype).lower()
88
+ if key not in _DTYPE_ALIASES:
89
+ raise ValueError(f"scale_dtype must be one of {list(_DTYPE_ALIASES.keys())} or a torch.dtype")
90
+ return _DTYPE_ALIASES[key]
91
+
92
+ def _json_to_dict(s: str) -> Optional[Dict]:
93
+ # Strictly catch JSON decoding only
94
+ try:
95
+ return json.loads(s)
96
+ except json.JSONDecodeError:
97
+ return None
98
+
99
def _maybe_parse_scale_map(meta: Dict[str, str]) -> Optional[Dict[str, float]]:
    """Extract a ``{weight_name: scalar_scale}`` map from safetensors metadata.

    Values stored under the keys in ``_SCALE_META_KEYS`` are tried first, in
    order. Then every other value that looks like a JSON object is tried.
    Each candidate is decoded with ``_json_to_dict`` and walked by ``try_parse``.
    Accepted entries are a number, ``{"scale": number}`` or
    ``{"scale": [number]}``. If none are found, containers named
    weights/tensors/params/map are searched.

    Returns:
        The first non-empty map found, or None.

    NOTE(review): the loose scan accepts any JSON object with numeric values,
    so unrelated numeric metadata could be mistaken for scales.
    """

    def is_number(x) -> bool:
        # bool subclasses int; a JSON true/false is not a scale.
        return isinstance(x, (int, float)) and not isinstance(x, bool)

    def try_parse(obj) -> Optional[Dict[str, float]]:
        if not isinstance(obj, dict):
            return None
        out: Dict[str, float] = {}
        for wk, v in obj.items():
            if is_number(v):
                out[wk] = float(v)
            elif isinstance(v, dict) and "scale" in v:
                sc = v["scale"]
                if is_number(sc):
                    out[wk] = float(sc)
                elif isinstance(sc, (list, tuple)) and len(sc) == 1 and is_number(sc[0]):
                    out[wk] = float(sc[0])
        if out:
            return out
        # Scales may be nested under a container key.
        for sub in ("weights", "tensors", "params", "map"):
            subobj = obj.get(sub)
            if isinstance(subobj, dict):
                got = try_parse(subobj)
                if got:
                    return got
        return None

    def from_raw(raw) -> Optional[Dict[str, float]]:
        if not isinstance(raw, str):
            return None
        raw = raw.strip()
        if not (raw.startswith("{") and raw.endswith("}")):
            return None
        parsed = _json_to_dict(raw)
        return try_parse(parsed) if parsed else None

    # Well-known keys first, in priority order.
    for k in _SCALE_META_KEYS:
        got = from_raw(meta.get(k))
        if got:
            return got

    # Loose scan of the remaining values; the well-known keys already failed above.
    for k, v in meta.items():
        if k in _SCALE_META_KEYS:
            continue
        got = from_raw(v)
        if got:
            return got
    return None


+ def _quick_fp8_variant_from_sentinel(acc: Accessor) -> Optional[str]:
144
+ if "scaled_fp8" in set(acc.keys()):
145
+ dt = acc.get_tensor("scaled_fp8").dtype
146
+ if dt == torch.float8_e4m3fn: return "e4m3fn"
147
+ if dt == torch.float8_e5m2: return "e5m2"
148
+ return None
149
+
150
+ def _per_channel_reshape(vec: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
151
+ return vec.view(weight.shape[0], *([1] * (weight.ndim - 1)))
152
+
153
# ---------- Unified converter ----------
class ConvertResult(dict):
    """Result of ``convert_scaled_fp8_to_quanto``: a plain ``dict`` with typed accessors.

    Holds the keys ``state_dict``, ``quant_map``, ``fp8_format`` and
    ``patch_needed``; each is also readable as a property of the same name.
    Subclasses the builtin ``dict`` (``typing.Dict`` is a deprecated alias).
    """

    @property
    def state_dict(self) -> Dict[str, torch.Tensor]:
        return self["state_dict"]

    @property
    def quant_map(self) -> Dict[str, Dict]:
        return self["quant_map"]

    @property
    def fp8_format(self) -> str:
        return self["fp8_format"]

    @property
    def patch_needed(self) -> bool:
        return self["patch_needed"]


def convert_scaled_fp8_to_quanto(
    src: Union[str, Dict[str, torch.Tensor]],
    *,
    fp8_format: Optional[str] = None,  # 'e4m3fn' | 'e5m2' | None (auto)
    require_scale: bool = False,
    allow_default_scale: bool = True,
    default_missing_scale: float = 1.0,
    scale_dtype: Union[str, torch.dtype] = "float32",
    add_activation_placeholders: bool = True,
    # dict mode options
    sd_metadata: Optional[Dict[str, str]] = None,
    in_place: bool = False,
    free_cuda_cache: bool = False,
    cuda_cache_interval: int = 32,
) -> ConvertResult:
    """Rewrite a scaled-FP8 checkpoint as Quanto-style tensors plus a quantization map.

    For every ``<base>.weight`` stored as ``float8_e4m3fn``/``float8_e5m2`` this emits:
      * ``<base>.weight._data``: the FP8 tensor, unchanged;
      * ``<base>.weight._scale``: per-output-channel scales shaped
        ``[out, 1, ...]`` in *scale_dtype*;
      * ``<base>.input_scale`` / ``<base>.output_scale``: ``[1]`` placeholders,
        only if *add_activation_placeholders* is set.
    It also records ``quant_map[<base>] = {"weights": <qtype>, "activations": "none"}``.

    The ``scaled_fp8`` sentinel and every ``*.scale_weight`` tensor are dropped.
    All other tensors are copied through.

    A weight's scale comes from ``<base>.scale_weight`` if present, otherwise
    from a metadata scale map. Failing both, *default_missing_scale* is used,
    unless ``require_scale=True`` and ``allow_default_scale=False``, which raises.

    Args:
        src: path to a ``.safetensors`` file, or an in-memory state dict.
        fp8_format: force ``'e4m3fn'``/``'e5m2'``. When None it is detected from
            the ``scaled_fp8`` sentinel, then from the first FP8 weight, else ``'auto'``.
        sd_metadata: metadata used when *src* is a dict.
        in_place: for dict input, write results into *src* and delete consumed
            entries (see ``DictAccessor`` for the CUDA-cache options).

    Returns:
        ConvertResult with ``state_dict``, ``quant_map``, ``fp8_format`` and
        ``patch_needed``. ``patch_needed`` is True when the emitted scales are
        float32, the case ``enable_fp8_fp32_scale_support`` is meant for.

    Raises:
        ValueError: unknown *scale_dtype*, or a non-uniform scale tensor whose
            length differs from the weight's output channels.
        KeyError: no scale found while defaults are disallowed (see above).
    """
    sd_scale_dtype = _normalize_scale_dtype(scale_dtype)
    patch_needed = sd_scale_dtype == torch.float32

    acc, closer = _as_accessor(
        src,
        meta=sd_metadata,
        in_place=in_place,
        free_cuda_cache=free_cuda_cache,
        cuda_cache_interval=cuda_cache_interval,
    )
    try:
        meta = acc.metadata() or {}
        meta_scale_map = _maybe_parse_scale_map(meta) or {}
        keys = list(acc.keys())
        writes_in_place = isinstance(acc, DictAccessor) and in_place

        # FP8 variant: explicit argument -> sentinel -> first FP8 weight -> 'auto'
        fmt = fp8_format or _quick_fp8_variant_from_sentinel(acc)
        if fmt is None:
            fmt = "auto"
            for wk in keys:
                if not _is_weight_key(wk):
                    continue
                dt = acc.get_tensor(wk).dtype
                if dt == torch.float8_e4m3fn:
                    fmt = "e4m3fn"
                    break
                if dt == torch.float8_e5m2:
                    fmt = "e5m2"
                    break

        # Map '<base>.weight' -> '<base>.scale_weight'. acc.has() is O(1), so this
        # stays linear in the number of keys (a list membership test would not).
        scale_weight_map: Dict[str, str] = {}
        for sk in keys:
            if sk.endswith(".scale_weight"):
                wk = sk[: -len(".scale_weight")] + ".weight"
                if acc.has(wk):
                    scale_weight_map[wk] = sk

        def get_scale_vec_for_weight(wk: str, out_ch: int) -> Optional[torch.Tensor]:
            """Return a float32 vector of *out_ch* scales for weight *wk*, or None."""
            # 1) explicit '<base>.scale_weight' tensor
            sk = scale_weight_map.get(wk)
            if sk is not None:
                s_t = acc.get_tensor(sk).to(torch.float32)
                if s_t.numel() == 1:
                    return torch.full((out_ch,), float(s_t.item()), dtype=torch.float32)
                if s_t.numel() == out_ch:
                    return s_t.reshape(out_ch)
                if torch.numel(s_t.unique()) == 1:
                    # reshape (not view) also copes with non-contiguous scale tensors
                    return torch.full((out_ch,), float(s_t.reshape(-1)[0].item()), dtype=torch.float32)
                raise ValueError(f"Unexpected scale length for '{wk}': {s_t.numel()} (out_ch={out_ch})")
            # 2) metadata: exact name, then without a leading 'model.', then without 'weight'
            for name in (wk, wk.removeprefix("model."), re.sub(r"(^|\.)weight$", "", wk)):
                if name in meta_scale_map:
                    return torch.full((out_ch,), float(meta_scale_map[name]), dtype=torch.float32)
            return None

        # out dict: the original dict itself when converting in place, else a new dict
        out_sd: Dict[str, torch.Tensor] = acc.sd if writes_in_place else {}
        qmap: Dict[str, Dict] = {}
        # Source-only artifacts are deleted after the pass, not when first seen:
        # '<base>.scale_weight' sorts before '<base>.weight', and deleting it
        # early would make the later scale lookup fail when converting in place.
        artifacts = []

        # Single pass: rewrite FP8 weights, copy-through others
        for k in keys:
            if k == "scaled_fp8" or k.endswith(".scale_weight"):
                artifacts.append(k)
                continue

            t = None
            if _is_weight_key(k):
                t = acc.get_tensor(k)
                if t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
                    base = k[: -len(".weight")]
                    # Quantized: keep original FP8 tensor as _data
                    out_sd[k + DATA_SUFFIX] = t

                    out_ch = int(t.shape[0])
                    s_vec = get_scale_vec_for_weight(k, out_ch)
                    if s_vec is None:
                        if require_scale and not allow_default_scale:
                            raise KeyError(f"No scale found for '{k}' (looked for '.scale_weight' and metadata).")
                        s_vec = torch.full((out_ch,), float(default_missing_scale), dtype=torch.float32)
                    out_sd[k + SCALE_SUFFIX] = _per_channel_reshape(s_vec, t).to(sd_scale_dtype)

                    if add_activation_placeholders:
                        out_sd[base + IN_SCALE] = torch.tensor([1], dtype=sd_scale_dtype)
                        out_sd[base + OUT_SCALE] = torch.tensor([1], dtype=sd_scale_dtype)

                    qmap[base] = {"weights": _QTYPE_NAME[fmt], "activations": "none"}

                    if acc.can_delete():
                        acc.delete(k)
                    continue  # the original '.weight' is replaced by '._data'

            # Copy-through (already present when writing in place); reuse the
            # tensor read above instead of loading it a second time.
            if not writes_in_place:
                out_sd[k] = t if t is not None else acc.get_tensor(k)

        if acc.can_delete():
            for k in artifacts:
                acc.delete(k)

        return ConvertResult(state_dict=out_sd, quant_map=qmap, fp8_format=fmt, patch_needed=patch_needed)
    finally:
        closer()


def detect_safetensors_format(
    src: Union[str, Dict[str, torch.Tensor]],
    *,
    sd_metadata: Optional[Dict[str, str]] = None,
    probe_weights: bool = False,  # allow a second FP8 weight to be probed
    with_hints: bool = False,
) -> Dict[str, str]:
    """Classify a checkpoint as Quanto, scaled-FP8, plain FP8 or none.

    Decision order:
      1. ``scaled_fp8`` sentinel tensor -> ``'scaled_fp8'`` (variant from its dtype);
      2. any key ending in ``._data`` -> ``'quanto'``;
      3. any ``*.scale_weight`` key -> ``'scaled_fp8'``;
      4. an FP8 ``*.weight`` tensor -> ``'fp8'``;
      5. metadata mentioning scales together with fp8/float8 (or a known
         scale-map key) -> ``'scaled_fp8'``;
      6. otherwise ``'none'``, with a qtype hint scraped from metadata if any.

    ``*.weight`` tensors are loaded, to read their dtype, until 1 FP8 weight
    has been seen (2 with *probe_weights*). Non-FP8 weights do not count
    against that budget, so on a checkpoint without FP8 weights every weight
    is loaded.

    Returns:
        {
          'kind': 'quanto' | 'scaled_fp8' | 'fp8' | 'none',
          'quant_format': 'qfloat8' | 'qfloat8_e4m3fn' | 'qfloat8_e5m2' | 'qint8' | 'qint4' | '',
          'fp8_format': 'e4m3fn' | 'e5m2' | 'unknown' | '',
          'hint': '...'  # only when with_hints=True
        }
    """
    acc, closer = _as_accessor(src, meta=sd_metadata, in_place=False)

    def result(kind: str, quant_format: str, fp8_format: str, hint: str) -> Dict[str, str]:
        out = {"kind": kind, "quant_format": quant_format, "fp8_format": fp8_format}
        if with_hints:
            out["hint"] = hint
        return out

    try:
        # --- Sentinel: O(1) membership test, no key scan ---
        if acc.has("scaled_fp8"):
            dt = acc.get_tensor("scaled_fp8").dtype
            fp8_fmt = "e4m3fn" if dt == torch.float8_e4m3fn else ("e5m2" if dt == torch.float8_e5m2 else "unknown")
            return result("scaled_fp8", "", fp8_fmt, "sentinel")

        ks = list(acc.keys())

        # --- Already-packed Quanto checkpoint ---
        if any(k.endswith(DATA_SUFFIX) for k in ks):
            return result("quanto", "qfloat8", "", "keys:*._data")

        # --- Scale keys + bounded FP8 dtype probe ---
        has_scale_weight = False
        fp8_variant = None
        fp8_probe_budget = 2 if probe_weights else 1
        for k in ks:
            if not has_scale_weight and k.endswith(".scale_weight"):
                has_scale_weight = True
            if fp8_probe_budget > 0 and _is_weight_key(k):
                dt = acc.get_tensor(k).dtype
                if dt == torch.float8_e4m3fn:
                    fp8_variant = "e4m3fn"
                    fp8_probe_budget -= 1
                elif dt == torch.float8_e5m2:
                    fp8_variant = "e5m2"
                    fp8_probe_budget -= 1
            if has_scale_weight and fp8_probe_budget == 0:
                break  # neither flag can change for the remaining keys

        if has_scale_weight:
            return result("scaled_fp8", "", fp8_variant or "unknown", "scale_weight keys")

        if fp8_variant is not None:
            return result("fp8", "", fp8_variant, "weight dtype (plain fp8)")

        # --- Cheap metadata peek only if keys didn't decide it (no JSON parsing) ---
        meta = acc.metadata() or {}
        blob = " ".join(v for v in meta.values() if isinstance(v, str)).lower()

        # scaled-fp8 hinted by metadata only
        has_scale_map = (
            any(k in meta for k in _SCALE_META_KEYS)
            or (("scale" in blob) and (("fp8" in blob) or ("float8" in blob)))
        )
        if has_scale_map:
            fmt = "e4m3fn" if "e4m3" in blob else ("e5m2" if "e5m2" in blob else "unknown")
            return result("scaled_fp8", "", fmt, "metadata")

        # quanto hinted by metadata only (not decisive without keys)
        qtype_hint = next(
            (tok for tok in ("qfloat8_e4m3fn", "qfloat8_e5m2", "qfloat8", "qint8", "qint4") if tok in blob),
            "",
        )
        return result("none", qtype_hint, "", "no decisive keys")

    finally:
        closer()


+ # ---------- Optional Quanto runtime patch (FP32-scale support), enable/disable ----------
379
+ _patch_state = SimpleNamespace(enabled=False, orig=None, scale_index=None)
380
+
381
def _expand_fp32_scale(scale: torch.Tensor, need: int) -> torch.Tensor:
    """Return *scale* with a leading dimension of *need*, broadcasting it if uniform.

    One-element and uniform scales are expanded with ``expand`` (a view, no copy).
    A scale whose leading dim already equals *need* is returned unchanged.

    Raises:
        ValueError: the leading dim differs from *need* and the values are not uniform.
    """
    if scale.numel() == 1:
        return scale.view(1).expand(need, *scale.shape[1:])
    if scale.shape[0] == need:
        return scale
    if torch.numel(scale.unique()) == 1:
        # Slice to one row rather than reshape(1, ...): the reshape fails whenever
        # the leading dim is > 1 because the element count would change.
        return scale[:1].expand(need, *scale.shape[1:])
    raise ValueError(f"Scale leading dim {scale.shape[0]} != required {need}")


def enable_fp8_fp32_scale_support():
    """Wrap Quanto's ``WeightQBytesTensor.create`` so FP8 weights accept float32 scales.

    For an FP8 qtype with a float32 ``scale`` tensor, a one-element or uniform
    scale is expanded (as a view) to ``size[axis]`` rows before the original
    factory is called. All other calls, including bf16/fp16 scales (classic
    Quanto), pass through untouched. Arguments are found by keyword, or by
    their position in the original signature.

    Enable only if you emitted float32 scales. Idempotent; returns True.
    """
    if _patch_state.enabled:
        return True

    from optimum.quanto.tensor.weights import qbytes as _qbytes  # late import

    cls = _qbytes.WeightQBytesTensor
    # Keep the raw class attribute (e.g. the staticmethod object) so that
    # disabling restores exactly what was there, not the unwrapped function.
    raw_create = inspect.getattr_static(cls, "create")
    orig = cls.create
    params = list(inspect.signature(orig).parameters.keys())

    def index_of(name: str, fallback: int) -> int:
        return params.index(name) if name in params else fallback

    qtype_index = index_of("qtype", 0)
    axis_index = index_of("axis", 1)
    size_index = index_of("size", 2)
    scale_index = index_of("scale", 5)  # fallback

    def pick(args, kwargs, name, index):
        if name in kwargs:
            return kwargs[name]
        return args[index] if len(args) > index else None

    @functools.wraps(orig)
    def wrapper(*args, **kwargs):
        qtype = pick(args, kwargs, "qtype", qtype_index)
        axis = pick(args, kwargs, "axis", axis_index)
        size = pick(args, kwargs, "size", size_index)
        scale = pick(args, kwargs, "scale", scale_index)

        # Matches a qtype given as a name string or as an object whose str()
        # contains its name (assumed for Quanto qtype objects).
        is_fp8 = qtype is not None and "float8" in str(qtype).lower()

        if is_fp8 and isinstance(scale, torch.Tensor) and scale.dtype == torch.float32:
            if isinstance(size, (tuple, list)) and axis is not None and axis >= 0:
                # NOTE(review): expansion is along dim 0, i.e. assumes axis == 0
                # (per-output-channel scales) — confirm for other axes.
                new_scale = _expand_fp32_scale(scale, int(size[axis]))
                if "scale" in kwargs:
                    kwargs["scale"] = new_scale
                else:
                    args = args[:scale_index] + (new_scale,) + args[scale_index + 1:]

        return orig(*args, **kwargs)

    # Re-wrap as staticmethod when the original was a static/class method so
    # calls through an instance do not receive an extra positional `self`.
    if isinstance(raw_create, (staticmethod, classmethod)):
        cls.create = staticmethod(wrapper)
    else:
        cls.create = wrapper
    _patch_state.enabled = True
    _patch_state.orig = raw_create
    _patch_state.scale_index = scale_index
    return True


def disable_fp8_fp32_scale_support():
    """Undo ``enable_fp8_fp32_scale_support`` by putting back the saved factory.

    Returns True if an active patch was removed, False if none was installed.
    """
    if _patch_state.enabled:
        from optimum.quanto.tensor.weights import qbytes as _qbytes
        _qbytes.WeightQBytesTensor.create = _patch_state.orig
        _patch_state.enabled, _patch_state.orig, _patch_state.scale_index = False, None, None
        return True
    return False


# ---------- Tiny CLI (optional) ----------
def _cli(argv: Optional[Iterable[str]] = None) -> int:
    """Command-line entry point with ``convert``, ``detect`` and ``patch`` sub-commands.

    Args:
        argv: arguments to parse; None (the default) lets argparse read ``sys.argv[1:]``.

    Returns:
        Process exit status (0 on success).
    """
    import argparse

    p = argparse.ArgumentParser("fp8_quanto_bridge")
    sub = p.add_subparsers(dest="cmd", required=True)

    p_conv = sub.add_parser("convert", help="Convert scaled-FP8 (file) to Quanto artifacts.")
    p_conv.add_argument("in_path")
    p_conv.add_argument("out_weights")
    p_conv.add_argument("out_qmap")
    p_conv.add_argument("--fp8-format", choices=("e4m3fn", "e5m2"), default=None)
    # Choices come from _DTYPE_ALIASES so the CLI cannot drift from _normalize_scale_dtype.
    p_conv.add_argument("--scale-dtype", default="float32", choices=tuple(_DTYPE_ALIASES))
    p_conv.add_argument("--no-activation-placeholders", action="store_true")
    p_conv.add_argument("--default-missing-scale", type=float, default=1.0)

    p_det = sub.add_parser("detect", help="Detect format quickly (path).")
    p_det.add_argument("path")
    p_det.add_argument("--probe", action="store_true")
    p_det.add_argument("--hints", action="store_true")

    p_patch = sub.add_parser("patch", help="Enable/disable FP32-scale runtime patch.")
    p_patch.add_argument("mode", choices=("enable", "disable"))

    args = p.parse_args(argv)

    if args.cmd == "convert":
        res = convert_scaled_fp8_to_quanto(
            args.in_path,
            fp8_format=args.fp8_format,
            scale_dtype=args.scale_dtype,
            add_activation_placeholders=not args.no_activation_placeholders,
            default_missing_scale=args.default_missing_scale,
        )
        save_file(res.state_dict, args.out_weights)
        with open(args.out_qmap, "w", encoding="utf-8") as f:
            json.dump(res.quant_map, f)
        print(f"Wrote: {args.out_weights} and {args.out_qmap}. Patch needed: {res.patch_needed}")
        return 0

    if args.cmd == "detect":
        info = detect_safetensors_format(args.path, probe_weights=args.probe, with_hints=args.hints)
        print(info)
        return 0

    # The sub-parser is required, so "patch" is the only command left here.
    ok = enable_fp8_fp32_scale_support() if args.mode == "enable" else disable_fp8_fp32_scale_support()
    print(f"patch {args.mode}: {'ok' if ok else 'already in that state'}")
    return 0


if __name__ == "__main__":
    # Script entry: use _cli()'s return value as the process exit status.
    exit_code = _cli()
    raise SystemExit(exit_code)
@@ -1,4 +1,4 @@
1
- # ------------------ Memory Management 3.6.1 for the GPU Poor by DeepBeepMeep (mmgp)------------------
1
+ # ------------------ Memory Management 3.6.2 for the GPU Poor by DeepBeepMeep (mmgp)------------------
2
2
  #
3
3
  # This module contains multiple optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a GPU card limited to 24 GB of VRAM.
4
4
  # This is a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -71,7 +71,7 @@ import torch
71
71
 
72
72
  from mmgp import safetensors2
73
73
  from mmgp import profile_type
74
-
74
+ from .fp8_quanto_bridge import convert_scaled_fp8_to_quanto, detect_safetensors_format , enable_fp8_fp32_scale_support
75
75
  from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QLinear, QTensor, quantize_module, register_qmodule
76
76
 
77
77
  # support for Embedding module quantization that is not supported by default by quanto
@@ -688,7 +688,7 @@ def _welcome():
688
688
  if welcome_displayed:
689
689
  return
690
690
  welcome_displayed = True
691
- print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.6.1) by DeepBeepMeep ************{ENDC}{UNBOLD}")
691
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.6.2) by DeepBeepMeep ************{ENDC}{UNBOLD}")
692
692
 
693
693
  def change_dtype(model, new_dtype, exclude_buffers = False):
694
694
  for submodule_name, submodule in model.named_modules():
@@ -1319,7 +1319,7 @@ def move_loras_to_device(model, device="cpu" ):
1319
1319
  if ".lora_" in k:
1320
1320
  m.to(device)
1321
1321
 
1322
- def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, preprocess_sd = None, modules = None, return_shared_modules = None, configKwargs ={}):
1322
+ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, preprocess_sd = None, modules = None, return_shared_modules = None, default_dtype = torch.bfloat16, ignore_unused_weights = False, configKwargs ={}):
1323
1323
  """
1324
1324
  quick version of .LoadfromPretrained of the transformers library
1325
1325
  used to build a model and load the corresponding weights (quantized or not)
@@ -1407,13 +1407,13 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantiza
1407
1407
 
1408
1408
  model._config = transformer_config
1409
1409
 
1410
- load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors, preprocess_sd = preprocess_sd , modules = modules, return_shared_modules = return_shared_modules, verboseLevel=verboseLevel )
1410
+ load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors, preprocess_sd = preprocess_sd , modules = modules, return_shared_modules = return_shared_modules, default_dtype = default_dtype, ignore_unused_weights = ignore_unused_weights, verboseLevel=verboseLevel )
1411
1411
 
1412
1412
  return model
1413
1413
 
1414
1414
 
1415
1415
 
1416
- def load_model_data(model, file_path, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, preprocess_sd = None, modules = None, return_shared_modules = None, verboseLevel = -1):
1416
+ def load_model_data(model, file_path, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, preprocess_sd = None, modules = None, return_shared_modules = None, default_dtype = torch.bfloat16, ignore_unused_weights = False, verboseLevel = -1):
1417
1417
  """
1418
1418
  Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
1419
1419
  """
@@ -1503,6 +1503,15 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
1503
1503
  for tied_weights in tied_weights_list:
1504
1504
  state_dict[tied_weights] = mapped_weight
1505
1505
 
1506
+ if quantization_map is None:
1507
+ detection_type = detect_safetensors_format(state_dict)
1508
+ if detection_type["kind"] in ['scaled_fp8','fp8']:
1509
+ conv_result = convert_scaled_fp8_to_quanto(state_dict, scale_dtype = default_dtype)
1510
+ state_dict = conv_result["state_dict"]
1511
+ quantization_map = conv_result["quant_map"]
1512
+ conv_result = None
1513
+ # enable_fp8_fp32_scale_support()
1514
+
1506
1515
  if quantization_map is None:
1507
1516
  pos = str.rfind(file, ".")
1508
1517
  if pos > 0:
@@ -1562,7 +1571,7 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
1562
1571
 
1563
1572
  del state_dict
1564
1573
 
1565
- if len(unexpected_keys) > 0 and verboseLevel >=2:
1574
+ if len(unexpected_keys) > 0 and verboseLevel >=2 and not ignore_unused_weights:
1566
1575
  print(f"Unexpected keys while loading '{file_path}': {unexpected_keys}")
1567
1576
 
1568
1577
  for k,p in model.named_parameters():
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mmgp
3
- Version: 3.6.1
3
+ Version: 3.6.2
4
4
  Summary: Memory Management for the GPU Poor
5
5
  Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
6
6
  Requires-Python: >=3.10
@@ -15,7 +15,7 @@ Dynamic: license-file
15
15
 
16
16
 
17
17
  <p align="center">
18
- <H2>Memory Management 3.6.1 for the GPU Poor by DeepBeepMeep</H2>
18
+ <H2>Memory Management 3.6.2 for the GPU Poor by DeepBeepMeep</H2>
19
19
  </p>
20
20
 
21
21
 
@@ -3,6 +3,7 @@ README.md
3
3
  pyproject.toml
4
4
  src/__init__.py
5
5
  src/mmgp/__init__.py
6
+ src/mmgp/fp8_quanto_bridge.py
6
7
  src/mmgp/offload.py
7
8
  src/mmgp/safetensors2.py
8
9
  src/mmgp.egg-info/PKG-INFO
File without changes
File without changes
File without changes
File without changes
File without changes