mmgp 3.6.1.tar.gz → 3.6.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mmgp might be problematic.

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mmgp
- Version: 3.6.1
+ Version: 3.6.3
  Summary: Memory Management for the GPU Poor
  Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
  Requires-Python: >=3.10
@@ -15,7 +15,7 @@ Dynamic: license-file


  <p align="center">
- <H2>Memory Management 3.6.1 for the GPU Poor by DeepBeepMeep</H2>
+ <H2>Memory Management 3.6.3 for the GPU Poor by DeepBeepMeep</H2>
  </p>


@@ -1,6 +1,6 @@

  <p align="center">
- <H2>Memory Management 3.6.1 for the GPU Poor by DeepBeepMeep</H2>
+ <H2>Memory Management 3.6.3 for the GPU Poor by DeepBeepMeep</H2>
  </p>


@@ -1,6 +1,6 @@
  [project]
  name = "mmgp"
- version = "3.6.1"
+ version = "3.6.3"
  authors = [
    { name = "deepbeepmeep", email = "deepbeepmeep@yahoo.com" },
  ]
@@ -0,0 +1,498 @@
+ from __future__ import annotations
+ import json, re, inspect
+ from types import SimpleNamespace
+ from typing import Dict, Optional, Tuple, Union, Iterable, Callable
+
+ import torch
+ from safetensors.torch import safe_open, save_file
+
+ # ---------- Constants ----------
+ DATA_SUFFIX = "._data"
+ SCALE_SUFFIX = "._scale"        # per-channel, shape [out, 1, ...]
+ IN_SCALE = ".input_scale"       # 1-D placeholder tensor [1]
+ OUT_SCALE = ".output_scale"     # 1-D placeholder tensor [1]
+
+ _QTYPE_NAME = {
+     "e4m3fn": "qfloat8_e4m3fn",
+     "e5m2": "qfloat8_e5m2",
+     "auto": "qfloat8",
+ }
+
+ _SCALE_META_KEYS = (
+     "fp8_scale_map", "fp8.scale_map", "scale_map",
+     "quant_scale_map", "weights_scales", "scales",
+ )
+
+ _DTYPE_ALIASES = {
+     "float32": torch.float32, "fp32": torch.float32,
+     "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
+     "float16": torch.float16, "fp16": torch.float16, "half": torch.float16,
+ }
+
+ def _is_weight_key(k: str) -> bool:
+     return k.endswith(".weight")
+
+ # ---------- Accessors (unify file vs dict) ----------
+ class Accessor:
+     def keys(self) -> Iterable[str]: ...
+     def get_tensor(self, key: str) -> torch.Tensor: ...
+     def metadata(self) -> Dict[str, str]: ...
+     def has(self, key: str) -> bool: ...  # NEW
+     def can_delete(self) -> bool: return False
+     def delete(self, key: str) -> None: raise NotImplementedError
+
+ class FileAccessor(Accessor):
+     def __init__(self, path: str):
+         self._fh = safe_open(path, framework="pt")
+         self._keys = list(self._fh.keys())
+         self._keys_set = set(self._keys)  # O(1) membership
+         self._meta = self._fh.metadata() or {}
+     def keys(self) -> Iterable[str]: return self._keys
+     def has(self, key: str) -> bool: return key in self._keys_set
+     def get_tensor(self, key: str) -> torch.Tensor: return self._fh.get_tensor(key)
+     def metadata(self) -> Dict[str, str]: return self._meta
+     def close(self) -> None: self._fh.close()
+
+ class DictAccessor(Accessor):
+     def __init__(self, sd: Dict[str, torch.Tensor], meta: Optional[Dict[str, str]] = None,
+                  in_place: bool = False, free_cuda_cache: bool = False, cuda_cache_interval: int = 32):
+         self.sd = sd
+         self._meta = meta or {}
+         self._in_place = in_place
+         self._free = free_cuda_cache
+         self._interval = int(cuda_cache_interval)
+         self._deletions = 0
+     def keys(self) -> Iterable[str]: return list(self.sd.keys())
+     def has(self, key: str) -> bool: return key in self.sd  # dict membership = O(1)
+     def get_tensor(self, key: str) -> torch.Tensor: return self.sd[key]
+     def metadata(self) -> Dict[str, str]: return self._meta
+     def can_delete(self) -> bool: return self._in_place
+     def delete(self, key: str) -> None:
+         if key in self.sd:
+             self.sd.pop(key, None)
+             self._deletions += 1
+             if self._free and (self._deletions % self._interval == 0) and torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+ def _as_accessor(src: Union[str, Dict[str, torch.Tensor]], **dict_opts) -> Tuple[Accessor, Callable[[], None]]:
+     if isinstance(src, str):
+         acc = FileAccessor(src)
+         return acc, acc.close
+     acc = DictAccessor(src, **dict_opts)
+     return acc, (lambda: None)
+
+ # ---------- Shared helpers ----------
+ def _normalize_scale_dtype(scale_dtype: Union[str, torch.dtype]) -> torch.dtype:
+     if isinstance(scale_dtype, torch.dtype):
+         return scale_dtype
+     key = str(scale_dtype).lower()
+     if key not in _DTYPE_ALIASES:
+         raise ValueError(f"scale_dtype must be one of {list(_DTYPE_ALIASES.keys())} or a torch.dtype")
+     return _DTYPE_ALIASES[key]
+
+ def _json_to_dict(s: str) -> Optional[Dict]:
+     # Strictly catch JSON decoding only
+     try:
+         return json.loads(s)
+     except json.JSONDecodeError:
+         return None
+
+ def _maybe_parse_scale_map(meta: Dict[str, str]) -> Optional[Dict[str, float]]:
+     def try_parse(obj) -> Optional[Dict[str, float]]:
+         if not isinstance(obj, dict):
+             return None
+         out: Dict[str, float] = {}
+         for wk, v in obj.items():
+             if isinstance(v, (int, float)):
+                 out[wk] = float(v)
+             elif isinstance(v, dict) and "scale" in v:
+                 sc = v["scale"]
+                 if isinstance(sc, (int, float)):
+                     out[wk] = float(sc)
+                 elif isinstance(sc, (list, tuple)) and len(sc) == 1 and isinstance(sc[0], (int, float)):
+                     out[wk] = float(sc[0])
+         if out:
+             return out
+         for sub in ("weights", "tensors", "params", "map"):
+             subobj = obj.get(sub)
+             if isinstance(subobj, dict):
+                 got = try_parse(subobj)
+                 if got:
+                     return got
+         return None
+
+     # exact keys first
+     for k in _SCALE_META_KEYS:
+         raw = meta.get(k)
+         if isinstance(raw, str) and raw.startswith("{") and raw.endswith("}"):
+             parsed = _json_to_dict(raw)
+             if parsed:
+                 got = try_parse(parsed)
+                 if got:
+                     return got
+
+     # loose scan of any JSON-looking value
+     for v in meta.values():
+         if isinstance(v, str) and v.startswith("{") and v.endswith("}"):
+             parsed = _json_to_dict(v)
+             if parsed:
+                 got = try_parse(parsed)
+                 if got:
+                     return got
+     return None
+
+ def _quick_fp8_variant_from_sentinel(acc: Accessor) -> Optional[str]:
+     if "scaled_fp8" in set(acc.keys()):
+         dt = acc.get_tensor("scaled_fp8").dtype
+         if dt == torch.float8_e4m3fn: return "e4m3fn"
+         if dt == torch.float8_e5m2: return "e5m2"
+     return None
+
+ def _per_channel_reshape(vec: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+     return vec.view(weight.shape[0], *([1] * (weight.ndim - 1)))
+
+ # ---------- Unified converter ----------
+ class ConvertResult(Dict[str, object]):
+     @property
+     def state_dict(self) -> Dict[str, torch.Tensor]: return self["state_dict"]  # type: ignore
+     @property
+     def quant_map(self) -> Dict[str, Dict]: return self["quant_map"]  # type: ignore
+     @property
+     def fp8_format(self) -> str: return self["fp8_format"]  # type: ignore
+     @property
+     def patch_needed(self) -> bool: return self["patch_needed"]  # type: ignore
+
+ def convert_scaled_fp8_to_quanto(
+     src: Union[str, Dict[str, torch.Tensor]],
+     *,
+     fp8_format: Optional[str] = None,          # 'e4m3fn' | 'e5m2' | None (auto)
+     require_scale: bool = False,
+     allow_default_scale: bool = True,
+     default_missing_scale: float = 1.0,
+     dtype: Union[str, torch.dtype] = "float32",
+     add_activation_placeholders: bool = True,
+     # dict mode options
+     sd_metadata: Optional[Dict[str, str]] = None,
+     in_place: bool = False,
+     free_cuda_cache: bool = False,
+     cuda_cache_interval: int = 32,
+ ) -> ConvertResult:
+     sd_scale_dtype = _normalize_scale_dtype(dtype)
+     patch_needed = (sd_scale_dtype == torch.float32)
+
+     acc, closer = _as_accessor(
+         src,
+         meta=sd_metadata,
+         in_place=in_place,
+         free_cuda_cache=free_cuda_cache,
+         cuda_cache_interval=cuda_cache_interval,
+     )
+     if not acc.can_delete(): in_place = False
+     try:
+         meta = acc.metadata() or {}
+         meta_scale_map = _maybe_parse_scale_map(meta) or {}
+
+         keys = list(acc.keys())
+
+         # FP8 variant: sentinel -> first FP8 weight -> 'auto'
+         fmt = fp8_format or _quick_fp8_variant_from_sentinel(acc)
+         if fmt is None:
+             for wk in keys:
+                 if not _is_weight_key(wk): continue
+                 dt = acc.get_tensor(wk).dtype
+                 if dt == torch.float8_e4m3fn: fmt = "e4m3fn"; break
+                 if dt == torch.float8_e5m2: fmt = "e5m2"; break
+         if fmt is None: fmt = "auto"
+
+         # Map '<base>.scale_weight' -> '<base>.weight'
+         scale_weight_map: Dict[str, str] = {}
+         for sk in keys:
+             if sk.endswith(".scale_weight"):
+                 base = sk[: -len(".scale_weight")]
+                 wk = base + ".weight"
+                 if wk in keys:
+                     scale_weight_map[wk] = sk
+
+         def get_scale_vec_for_weight(wk: str, out_ch: int) -> Optional[torch.Tensor]:
+             # 1) explicit tensor
+             sk = scale_weight_map.get(wk)
+             if sk is not None:
+                 s_t = acc.get_tensor(sk).to(torch.float32)
+                 if in_place: acc.delete(s_t)
+                 if s_t.numel() == 1:
+                     return torch.full((out_ch,), float(s_t.item()), dtype=torch.float32)
+                 if s_t.numel() == out_ch:
+                     return s_t.reshape(out_ch)
+                 if torch.numel(s_t.unique()) == 1:
+                     return torch.full((out_ch,), float(s_t.view(-1)[0].item()), dtype=torch.float32)
+                 raise ValueError(f"Unexpected scale length for '{wk}': {s_t.numel()} (out_ch={out_ch})")
+             # 2) metadata exact / normalized
+             if wk in meta_scale_map:
+                 return torch.full((out_ch,), float(meta_scale_map[wk]), dtype=torch.float32)
+             for alt in (wk.replace("model.", ""), re.sub(r"(^|\.)weight$", "", wk)):
+                 if alt in meta_scale_map:
+                     return torch.full((out_ch,), float(meta_scale_map[alt]), dtype=torch.float32)
+             return None
+
+         out_sd: Dict[str, torch.Tensor] = {}
+         qmap: Dict[str, Dict] = {}
+
+         # Single pass: rewrite FP8 weights, copy-through others
+         for k in keys:
+             # Drop source-only artifacts
+             if k == "scaled_fp8" or k.endswith(".scale_weight") :
+                 continue
+
+             t = acc.get_tensor(k)
+             if in_place: acc.delete(k)
+             if _is_weight_key(k) and t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+                 # Quantized: keep original FP8 tensor as _data
+                 out_sd[k + DATA_SUFFIX] = t
+
+                 out_ch = int(t.shape[0])
+                 s_vec = get_scale_vec_for_weight(k, out_ch)
+                 if s_vec is None:
+                     if require_scale and not allow_default_scale:
+                         raise KeyError(f"No scale found for '{k}' (looked for '.scale_weight' and metadata).")
+                     s_vec = torch.full((out_ch,), float(default_missing_scale), dtype=torch.float32)
+
+                 s_grid = _per_channel_reshape(s_vec, t).to(sd_scale_dtype)
+                 out_sd[k + SCALE_SUFFIX] = s_grid
+
+                 if add_activation_placeholders:
+                     base = k[:-len(".weight")]
+                     out_sd[base + IN_SCALE] = torch.tensor([1], dtype=sd_scale_dtype)
+                     out_sd[base + OUT_SCALE] = torch.tensor([1], dtype=sd_scale_dtype)
+
+                 base = k[:-len(".weight")]
+                 qmap[base] = {"weights": _QTYPE_NAME[fmt], "activations": "none"}
+             else:
+                 out_sd[k] = t if t.dtype == dtype or t.dtype == torch.float32 else t.to(dtype)
+             t = None
+         return ConvertResult(state_dict=out_sd, quant_map=qmap, fp8_format=fmt, patch_needed=patch_needed)
+     finally:
+         closer()
+
+ def detect_safetensors_format(
+     src: Union[str, Dict[str, torch.Tensor]],
+     *,
+     sd_metadata: Optional[Dict[str, str]] = None,
+     probe_weights: bool = False,               # if True, we may read up to 2 weights total
+     with_hints: bool = False,
+ ) -> Dict[str, str]:
+     """
+     Returns:
+       {
+         'kind': 'quanto' | 'scaled_fp8' | 'fp8' | 'none',
+         'quant_format': 'qfloat8_e4m3fn' | 'qfloat8_e5m2' | 'qfloat8' | 'qint8' | 'qint4' | 'unknown' | '',
+         'fp8_format': 'e4m3fn' | 'e5m2' | 'unknown' | '',
+         'hint': '...'                          # only when with_hints=True
+       }
+     """
+     acc, closer = _as_accessor(src, meta=sd_metadata, in_place=False)
+     try:
+         # --- O(1) sentinel test up-front (no key scan) ---
+         if acc.has("scaled_fp8"):
+             dt = acc.get_tensor("scaled_fp8").dtype
+             fp8_fmt = "e4m3fn" if dt == torch.float8_e4m3fn else ("e5m2" if dt == torch.float8_e5m2 else "unknown")
+             out = {"kind": "scaled_fp8", "quant_format": "", "fp8_format": fp8_fmt}
+             if with_hints: out["hint"] = "sentinel"
+             return out
+
+         # --- Single pass over keys (no re-scans) ---
+         ks = list(acc.keys())
+         has_scale_weight = False
+         saw_quanto_data = False
+         fp8_variant = None
+         fp8_probe_budget = 2 if probe_weights else 1
+
+         for k in ks:
+             # Quanto pack short-circuit
+             if not saw_quanto_data and k.endswith(DATA_SUFFIX):
+                 saw_quanto_data = True
+                 # we can break here, but keep minimal state setting uniformity
+                 break
+
+         if saw_quanto_data:
+             out = {"kind": "quanto", "quant_format": "qfloat8", "fp8_format": ""}
+             if with_hints: out["hint"] = "keys:*._data"
+             return out
+
+         # continue single pass for the rest (scale keys + bounded dtype probe)
+         for k in ks:
+             if not has_scale_weight and k.endswith(".scale_weight"):
+                 has_scale_weight = True
+                 # don't return yet; we may still probe a dtype to grab variant
+
+             if fp8_probe_budget > 0 and _is_weight_key(k):
+                 dt = acc.get_tensor(k).dtype
+                 if dt == torch.float8_e4m3fn:
+                     fp8_variant = "e4m3fn"; fp8_probe_budget -= 1
+                 elif dt == torch.float8_e5m2:
+                     fp8_variant = "e5m2"; fp8_probe_budget -= 1
+
+         if has_scale_weight:
+             out = {"kind": "scaled_fp8", "quant_format": "", "fp8_format": fp8_variant or "unknown"}
+             if with_hints: out["hint"] = "scale_weight keys"
+             return out
+
+         if fp8_variant is not None:
+             out = {"kind": "fp8", "quant_format": "", "fp8_format": fp8_variant}
+             if with_hints: out["hint"] = "weight dtype (plain fp8)"
+             return out
+
+         # --- Cheap metadata peek only if keys didn't decide it (no JSON parsing) ---
+         meta = acc.metadata() or {}
+         blob = " ".join(v for v in meta.values() if isinstance(v, str)).lower()
+
+         # scaled-fp8 hinted by metadata only
+         has_scale_map = (
+             any(k in meta for k in _SCALE_META_KEYS) or
+             (("scale" in blob) and (("fp8" in blob) or ("float8" in blob)))
+         )
+         if has_scale_map:
+             fmt = "e4m3fn" if "e4m3" in blob else ("e5m2" if "e5m2" in blob else "unknown")
+             out = {"kind": "scaled_fp8", "quant_format": "", "fp8_format": fmt}
+             if with_hints: out["hint"] = "metadata"
+             return out
+
+         # quanto hinted by metadata only (not decisive without keys)
+         qtype_hint = ""
+         for tok in ("qfloat8_e4m3fn", "qfloat8_e5m2", "qfloat8", "qint8", "qint4"):
+             if tok in blob:
+                 qtype_hint = tok
+                 break
+
+         out = {"kind": "none", "quant_format": qtype_hint, "fp8_format": ""}
+         if with_hints: out["hint"] = "no decisive keys"
+         return out
+
+     finally:
+         closer()
+
+ # ---------- Optional Quanto runtime patch (FP32-scale support), enable/disable ----------
+ _patch_state = SimpleNamespace(enabled=False, orig=None, scale_index=None)
+
+ def enable_fp8_fp32_scale_support():
+     """
+     Version-robust wrapper for WeightQBytesTensor.create:
+       - matches both positional/keyword call styles via *args/**kwargs,
+       - for FP8 + FP32 scales, expands scalar/uniform scales with a VIEW to the needed length,
+       - leaves bf16/fp16 (classic Quanto) untouched.
+     Enable only if you emitted float32 scales.
+     """
+     if _patch_state.enabled:
+         return True
+
+     from optimum.quanto.tensor.weights import qbytes as _qbytes  # late import
+     orig = _qbytes.WeightQBytesTensor.create
+     sig = inspect.signature(orig)
+     params = list(sig.parameters.keys())
+     scale_index = params.index("scale") if "scale" in params else 5  # fallback
+
+     def wrapper(*args, **kwargs):
+         # Extract fields irrespective of signature
+         qtype = kwargs.get("qtype", args[0] if len(args) > 0 else None)
+         axis = kwargs.get("axis", args[1] if len(args) > 1 else None)
+         size = kwargs.get("size", args[2] if len(args) > 2 else None)
+
+         if "scale" in kwargs:
+             scale = kwargs["scale"]
+             def set_scale(new): kwargs.__setitem__("scale", new)
+         else:
+             scale = args[scale_index] if len(args) > scale_index else None
+             def set_scale(new):
+                 nonlocal args
+                 args = list(args)
+                 if len(args) > scale_index:
+                     args[scale_index] = new
+                 else:
+                     kwargs["scale"] = new
+                 args = tuple(args)
+
+         is_fp8 = isinstance(qtype, str) and ("float8" in qtype.lower() or "qfloat8" in qtype.lower()) or \
+                  (not isinstance(qtype, str) and "float8" in str(qtype).lower())
+
+         if is_fp8 and isinstance(scale, torch.Tensor) and scale.dtype == torch.float32:
+             need = int(size[axis]) if (isinstance(size, (tuple, list)) and axis is not None and axis >= 0) else None
+             if need is not None:
+                 if scale.numel() == 1:
+                     scale = scale.view(1).expand(need, *scale.shape[1:])
+                 elif scale.shape[0] != need:
+                     # Expand if uniform; otherwise raise
+                     uniform = (scale.numel() == 1) or (torch.numel(scale.unique()) == 1)
+                     if uniform:
+                         scale = scale.reshape(1, *scale.shape[1:]).expand(need, *scale.shape[1:])
+                     else:
+                         raise ValueError(f"Scale leading dim {scale.shape[0]} != required {need}")
+             set_scale(scale)
+
+         return orig(*args, **kwargs)
+
+     _qbytes.WeightQBytesTensor.create = wrapper
+     _patch_state.enabled = True
+     _patch_state.orig = orig
+     _patch_state.scale_index = scale_index
+     return True
+
+ def disable_fp8_fp32_scale_support():
+     """Restore Quanto's original factory."""
+     if not _patch_state.enabled:
+         return False
+     from optimum.quanto.tensor.weights import qbytes as _qbytes
+     _qbytes.WeightQBytesTensor.create = _patch_state.orig
+     _patch_state.enabled = False
+     _patch_state.orig = None
+     _patch_state.scale_index = None
+     return True
+
+ # ---------- Tiny CLI (optional) ----------
+ def _cli():
+     import argparse, json as _json
+     p = argparse.ArgumentParser("fp8_quanto_bridge")
+     sub = p.add_subparsers(dest="cmd", required=True)
+
+     p_conv = sub.add_parser("convert", help="Convert scaled-FP8 (file) to Quanto artifacts.")
+     p_conv.add_argument("in_path")
+     p_conv.add_argument("out_weights")
+     p_conv.add_argument("out_qmap")
+     p_conv.add_argument("--fp8-format", choices=("e4m3fn", "e5m2"), default=None)
+     p_conv.add_argument("--scale-dtype", default="float32",
+                         choices=("float32","bfloat16","float16","fp32","bf16","fp16","half"))
+     p_conv.add_argument("--no-activation-placeholders", action="store_true")
+     p_conv.add_argument("--default-missing-scale", type=float, default=1.0)
+
+     p_det = sub.add_parser("detect", help="Detect format quickly (path).")
+     p_det.add_argument("path")
+     p_det.add_argument("--probe", action="store_true")
+     p_det.add_argument("--hints", action="store_true")
+
+     p_patch = sub.add_parser("patch", help="Enable/disable FP32-scale runtime patch.")
+     p_patch.add_argument("mode", choices=("enable","disable"))
+
+     args = p.parse_args()
+
+     if args.cmd == "convert":
+         res = convert_scaled_fp8_to_quanto(
+             args.in_path,
+             fp8_format=args.fp8_format,
+             dtype=args.scale_dtype,
+             add_activation_placeholders=not args.no_activation_placeholders,
+             default_missing_scale=args.default_missing_scale,
+         )
+         save_file(res.state_dict, args.out_weights)
+         with open(args.out_qmap, "w") as f:
+             _json.dump(res.quant_map, f)
+         print(f"Wrote: {args.out_weights} and {args.out_qmap}. Patch needed: {res.patch_needed}")
+         return 0
+
+     if args.cmd == "detect":
+         info = detect_safetensors_format(args.path, probe_weights=args.probe, with_hints=args.hints)
+         print(info); return 0
+
+     if args.cmd == "patch":
+         ok = enable_fp8_fp32_scale_support() if args.mode == "enable" else disable_fp8_fp32_scale_support()
+         print(f"patch {args.mode}: {'ok' if ok else 'already in that state'}")
+         return 0
+
+ if __name__ == "__main__":
+     raise SystemExit(_cli())
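
For orientation, here is a minimal usage sketch of the new src/mmgp/fp8_quanto_bridge.py module; it is not part of the release itself. The function names, keyword arguments and result fields come from the code above, while the file paths model_fp8.safetensors, model_quanto.safetensors and model_quanto_map.json are placeholders.

    import json
    from safetensors.torch import save_file
    from mmgp.fp8_quanto_bridge import (
        detect_safetensors_format,
        convert_scaled_fp8_to_quanto,
        enable_fp8_fp32_scale_support,
    )

    # Probe the checkpoint: 'kind' is one of 'quanto', 'scaled_fp8', 'fp8' or 'none'.
    info = detect_safetensors_format("model_fp8.safetensors", with_hints=True)

    if info["kind"] in ("scaled_fp8", "fp8"):
        # Rewrite FP8 weights into Quanto-style '._data' / '._scale' pairs plus a quantization map.
        res = convert_scaled_fp8_to_quanto("model_fp8.safetensors", dtype="float32")
        save_file(res.state_dict, "model_quanto.safetensors")
        with open("model_quanto_map.json", "w") as f:
            json.dump(res.quant_map, f)
        if res.patch_needed:
            # Only relevant when float32 scales were emitted (see the runtime patch above).
            enable_fp8_fp32_scale_support()

The optional CLI should expose the same operations once the package is installed, for example: python -m mmgp.fp8_quanto_bridge detect model_fp8.safetensors --hints, or python -m mmgp.fp8_quanto_bridge convert model_fp8.safetensors model_quanto.safetensors model_quanto_map.json.
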
@@ -1,4 +1,4 @@
- # ------------------ Memory Management 3.6.1 for the GPU Poor by DeepBeepMeep (mmgp)------------------
+ # ------------------ Memory Management 3.6.3 for the GPU Poor by DeepBeepMeep (mmgp)------------------
  #
  # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
  # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -71,7 +71,7 @@ import torch

  from mmgp import safetensors2
  from mmgp import profile_type
-
+ from .fp8_quanto_bridge import convert_scaled_fp8_to_quanto, detect_safetensors_format , enable_fp8_fp32_scale_support
  from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QLinear, QTensor, quantize_module, register_qmodule

  # support for Embedding module quantization that is not supported by default by quanto
@@ -688,7 +688,7 @@ def _welcome():
  if welcome_displayed:
  return
  welcome_displayed = True
- print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.6.1) by DeepBeepMeep ************{ENDC}{UNBOLD}")
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.6.3) by DeepBeepMeep ************{ENDC}{UNBOLD}")

  def change_dtype(model, new_dtype, exclude_buffers = False):
  for submodule_name, submodule in model.named_modules():
@@ -1319,7 +1319,7 @@ def move_loras_to_device(model, device="cpu" ):
  if ".lora_" in k:
  m.to(device)

- def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, preprocess_sd = None, modules = None, return_shared_modules = None, configKwargs ={}):
+ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, preprocess_sd = None, modules = None, return_shared_modules = None, default_dtype = torch.bfloat16, ignore_unused_weights = False, configKwargs ={}):
  """
  quick version of .LoadfromPretrained of the transformers library
  used to build a model and load the corresponding weights (quantized or not)
@@ -1407,13 +1407,13 @@ def fast_load_transformers_model(model_path: str, do_quantiza

  model._config = transformer_config

- load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors, preprocess_sd = preprocess_sd , modules = modules, return_shared_modules = return_shared_modules, verboseLevel=verboseLevel )
+ load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors, preprocess_sd = preprocess_sd , modules = modules, return_shared_modules = return_shared_modules, default_dtype = default_dtype, ignore_unused_weights = ignore_unused_weights, verboseLevel=verboseLevel )

  return model



- def load_model_data(model, file_path, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, preprocess_sd = None, modules = None, return_shared_modules = None, verboseLevel = -1):
+ def load_model_data(model, file_path, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, preprocess_sd = None, postprocess_sd = None, modules = None, return_shared_modules = None, default_dtype = torch.bfloat16, ignore_unused_weights = False, verboseLevel = -1):
  """
  Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
  """
@@ -1489,29 +1489,41 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
  state_dict.update(sd)
  else:
  state_dict, metadata = _safetensors_load_file(file, writable_tensors =writable_tensors)
-
- if metadata != None:
- quantization_map = metadata.get("quantization_map", None)
- config = metadata.get("config", None)
- if config is not None:
- model._config = config
-
- tied_weights_map = metadata.get("tied_weights_map", None)
- if tied_weights_map != None:
- for name, tied_weights_list in tied_weights_map.items():
- mapped_weight = state_dict[name]
- for tied_weights in tied_weights_list:
- state_dict[tied_weights] = mapped_weight
-
- if quantization_map is None:
- pos = str.rfind(file, ".")
- if pos > 0:
- quantization_map_path = file[:pos]
- quantization_map_path += "_map.json"
-
- if os.path.isfile(quantization_map_path):
- with open(quantization_map_path, 'r') as f:
- quantization_map = json.load(f)
+
+ if preprocess_sd != None:
+ state_dict = preprocess_sd(state_dict)
+
+ if metadata != None:
+ quantization_map = metadata.get("quantization_map", None)
+ config = metadata.get("config", None)
+ if config is not None:
+ model._config = config
+
+ tied_weights_map = metadata.get("tied_weights_map", None)
+ if tied_weights_map != None:
+ for name, tied_weights_list in tied_weights_map.items():
+ mapped_weight = state_dict[name]
+ for tied_weights in tied_weights_list:
+ state_dict[tied_weights] = mapped_weight
+
+ if quantization_map is None:
+ detection_type = detect_safetensors_format(state_dict)
+ if detection_type["kind"] in ['scaled_fp8','fp8']:
+ conv_result = convert_scaled_fp8_to_quanto(state_dict, dtype = default_dtype, in_place= True)
+ state_dict = conv_result["state_dict"]
+ quantization_map = conv_result["quant_map"]
+ conv_result = None
+ # enable_fp8_fp32_scale_support()
+
+ if quantization_map is None:
+ pos = str.rfind(file, ".")
+ if pos > 0:
+ quantization_map_path = file[:pos]
+ quantization_map_path += "_map.json"
+
+ if os.path.isfile(quantization_map_path):
+ with open(quantization_map_path, 'r') as f:
+ quantization_map = json.load(f)

  full_state_dict.update(state_dict)
  if quantization_map != None:
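
The in-memory branch added above can also be exercised on its own. A short sketch, assuming state_dict is a scaled-FP8 state dict that has already been loaded from a safetensors file (variable names are illustrative):

    import torch
    from mmgp.fp8_quanto_bridge import detect_safetensors_format, convert_scaled_fp8_to_quanto

    if detect_safetensors_format(state_dict)["kind"] in ("scaled_fp8", "fp8"):
        # in_place=True lets the converter drop source tensors as it rewrites them,
        # which helps avoid holding two full copies of the weights at once.
        res = convert_scaled_fp8_to_quanto(state_dict, dtype=torch.bfloat16, in_place=True)
        state_dict, quantization_map = res["state_dict"], res["quant_map"]
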
@@ -1530,8 +1542,8 @@ def load_model_data(model, file_path, do_quantize = False, qi
  full_state_dict, full_quantization_map, full_tied_weights_map = None, None, None

  # deal if we are trying to load just a sub part of a larger model
- if preprocess_sd != None:
- state_dict, quantization_map = preprocess_sd(state_dict, quantization_map)
+ if postprocess_sd != None:
+ state_dict, quantization_map = postprocess_sd(state_dict, quantization_map)

  if modelPrefix != None:
  base_model_prefix = modelPrefix + "."
@@ -1562,7 +1574,7 @@ def load_model_data(model, file_path, do_quantize = False, qi

  del state_dict

- if len(unexpected_keys) > 0 and verboseLevel >=2:
+ if len(unexpected_keys) > 0 and verboseLevel >=2 and not ignore_unused_weights:
  print(f"Unexpected keys while loading '{file_path}': {unexpected_keys}")

  for k,p in model.named_parameters():
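
With these changes a caller should be able to hand the loader a scaled-FP8 checkpoint directly. A hedged sketch of a call using the new keyword arguments; the checkpoint path is a placeholder, and other arguments (config paths, model class, and so on) may still be needed depending on the model:

    import torch
    from mmgp import offload

    model = offload.fast_load_transformers_model(
        "transformer_fp8.safetensors",     # placeholder path to a scaled-FP8 checkpoint
        default_dtype=torch.bfloat16,      # dtype used for converted scales and non-FP8 tensors
        ignore_unused_weights=True,        # silences the "Unexpected keys" report shown above
        verboseLevel=2,
    )
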
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mmgp
- Version: 3.6.1
+ Version: 3.6.3
  Summary: Memory Management for the GPU Poor
  Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
  Requires-Python: >=3.10
@@ -15,7 +15,7 @@ Dynamic: license-file


  <p align="center">
- <H2>Memory Management 3.6.1 for the GPU Poor by DeepBeepMeep</H2>
+ <H2>Memory Management 3.6.3 for the GPU Poor by DeepBeepMeep</H2>
  </p>


@@ -3,6 +3,7 @@ README.md
  pyproject.toml
  src/__init__.py
  src/mmgp/__init__.py
+ src/mmgp/fp8_quanto_bridge.py
  src/mmgp/offload.py
  src/mmgp/safetensors2.py
  src/mmgp.egg-info/PKG-INFO
5 files without changes