mmgp 3.6.1__tar.gz → 3.6.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mmgp might be problematic. Click here for more details.
- {mmgp-3.6.1/src/mmgp.egg-info → mmgp-3.6.3}/PKG-INFO +2 -2
- {mmgp-3.6.1 → mmgp-3.6.3}/README.md +1 -1
- {mmgp-3.6.1 → mmgp-3.6.3}/pyproject.toml +1 -1
- mmgp-3.6.3/src/mmgp/fp8_quanto_bridge.py +498 -0
- {mmgp-3.6.1 → mmgp-3.6.3}/src/mmgp/offload.py +44 -32
- {mmgp-3.6.1 → mmgp-3.6.3/src/mmgp.egg-info}/PKG-INFO +2 -2
- {mmgp-3.6.1 → mmgp-3.6.3}/src/mmgp.egg-info/SOURCES.txt +1 -0
- {mmgp-3.6.1 → mmgp-3.6.3}/LICENSE.md +0 -0
- {mmgp-3.6.1 → mmgp-3.6.3}/setup.cfg +0 -0
- {mmgp-3.6.1 → mmgp-3.6.3}/src/__init__.py +0 -0
- {mmgp-3.6.1 → mmgp-3.6.3}/src/mmgp/__init__.py +0 -0
- {mmgp-3.6.1 → mmgp-3.6.3}/src/mmgp/safetensors2.py +0 -0
- {mmgp-3.6.1 → mmgp-3.6.3}/src/mmgp.egg-info/dependency_links.txt +0 -0
- {mmgp-3.6.1 → mmgp-3.6.3}/src/mmgp.egg-info/requires.txt +0 -0
- {mmgp-3.6.1 → mmgp-3.6.3}/src/mmgp.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mmgp
|
|
3
|
-
Version: 3.6.1
|
|
3
|
+
Version: 3.6.3
|
|
4
4
|
Summary: Memory Management for the GPU Poor
|
|
5
5
|
Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
|
|
6
6
|
Requires-Python: >=3.10
|
|
@@ -15,7 +15,7 @@ Dynamic: license-file
|
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
<p align="center">
|
|
18
|
-
<H2>Memory Management 3.6.1 for the GPU Poor by DeepBeepMeep</H2>
|
|
18
|
+
<H2>Memory Management 3.6.3 for the GPU Poor by DeepBeepMeep</H2>
|
|
19
19
|
</p>
|
|
20
20
|
|
|
21
21
|
|
|
@@ -0,0 +1,498 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import json, re, inspect
|
|
3
|
+
from types import SimpleNamespace
|
|
4
|
+
from typing import Dict, Optional, Tuple, Union, Iterable, Callable
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from safetensors.torch import safe_open, save_file
|
|
8
|
+
|
|
9
|
+
# ---------- Constants ----------
# Key suffixes written by the converter for every FP8 "<base>.weight" tensor.
DATA_SUFFIX = "._data"  # the untouched FP8 payload
SCALE_SUFFIX = "._scale"  # per-channel scale, shape [out, 1, ...]
IN_SCALE = ".input_scale"  # 1-D placeholder tensor [1]
OUT_SCALE = ".output_scale"  # 1-D placeholder tensor [1]

# FP8 variant name -> Quanto qtype name recorded in the quantization map.
_QTYPE_NAME = dict(
    e4m3fn="qfloat8_e4m3fn",
    e5m2="qfloat8_e5m2",
    auto="qfloat8",
)

# Metadata keys searched first, in this order, for a JSON-encoded scale map.
_SCALE_META_KEYS = (
    "fp8_scale_map",
    "fp8.scale_map",
    "scale_map",
    "quant_scale_map",
    "weights_scales",
    "scales",
)

# Accepted (lower-case) spellings for the scale dtype, grouped by target dtype.
_DTYPE_ALIASES = {
    alias: target
    for target, spellings in (
        (torch.float32, ("float32", "fp32")),
        (torch.bfloat16, ("bfloat16", "bf16")),
        (torch.float16, ("float16", "fp16", "half")),
    )
    for alias in spellings
}
|
|
31
|
+
|
|
32
|
+
def _is_weight_key(k: str) -> bool:
    """Return True when the tensor name ``k`` ends in ``.weight``."""
    suffix = ".weight"
    return k[-len(suffix):] == suffix
|
|
34
|
+
|
|
35
|
+
# ---------- Accessors: one read interface over a file or an in-memory dict ----------
class Accessor:
    """Read interface shared by the file-backed and dict-backed tensor sources.

    The query methods below are stubs that return ``None``; concrete accessors
    override them. Deletion is opt-in: the base class reports
    ``can_delete() == False`` and ``delete()`` raises ``NotImplementedError``.
    """

    def keys(self) -> Iterable[str]:
        """Return every tensor name available from the source."""

    def get_tensor(self, key: str) -> torch.Tensor:
        """Return the tensor stored under ``key``."""

    def metadata(self) -> Dict[str, str]:
        """Return the source's string metadata."""

    def has(self, key: str) -> bool:
        """Return whether ``key`` exists in the source."""

    def can_delete(self) -> bool:
        """Return whether ``delete()`` may be used to drop consumed entries."""
        return False

    def delete(self, key: str) -> None:
        """Drop ``key`` from the source (unsupported by default)."""
        raise NotImplementedError
|
|
43
|
+
|
|
44
|
+
class FileAccessor(Accessor):
    """Accessor over a ``.safetensors`` file opened with ``safe_open`` (PyTorch framework).

    Key names and metadata are read once at construction; tensors are fetched on
    demand through ``get_tensor``. Call ``close()`` when done.
    """

    def __init__(self, path: str):
        self._fh = safe_open(path, framework="pt")
        self._keys = list(self._fh.keys())
        self._keys_set = set(self._keys)  # O(1) membership for has()
        self._meta = self._fh.metadata() or {}

    def keys(self) -> Iterable[str]:
        return self._keys

    def has(self, key: str) -> bool:
        return key in self._keys_set

    def get_tensor(self, key: str) -> torch.Tensor:
        return self._fh.get_tensor(key)

    def metadata(self) -> Dict[str, str]:
        return self._meta

    def close(self) -> None:
        """Release the underlying file handle; calling it again is a no-op."""
        fh, self._fh = self._fh, None
        if fh is None:
            return
        closer = getattr(fh, "close", None)
        if callable(closer):
            closer()
        else:
            # safetensors documents ``safe_open`` as a context manager rather than
            # exposing close(); leaving the context is what releases the handle.
            fh.__exit__(None, None, None)
|
|
55
|
+
|
|
56
|
+
class DictAccessor(Accessor):
    """Accessor over an in-memory ``{name: tensor}`` state dict.

    Args:
        sd: the state dict to read (and, with ``in_place=True``, consume).
        meta: optional string metadata returned by ``metadata()``.
        in_place: when True, ``can_delete()`` is True and ``delete()`` pops
            entries from ``sd``.
        free_cuda_cache: when True, every ``cuda_cache_interval``-th successful
            deletion calls ``torch.cuda.empty_cache()`` (if CUDA is available).
        cuda_cache_interval: deletion cadence for the cache flush; 0 is treated as 1.
    """

    def __init__(self, sd: Dict[str, torch.Tensor], meta: Optional[Dict[str, str]] = None,
                 in_place: bool = False, free_cuda_cache: bool = False, cuda_cache_interval: int = 32):
        self.sd = sd
        self._meta = meta or {}
        self._in_place = in_place
        self._free = free_cuda_cache
        interval = int(cuda_cache_interval)
        # An interval of 0 would make the modulo in delete() raise ZeroDivisionError.
        # abs() keeps the cadence of negative values unchanged (n % -k == 0 iff n % k == 0).
        self._interval = abs(interval) or 1
        self._deletions = 0

    def keys(self) -> Iterable[str]:
        # Return a snapshot so callers can delete entries while iterating.
        return list(self.sd.keys())

    def has(self, key: str) -> bool:
        return key in self.sd  # dict membership = O(1)

    def get_tensor(self, key: str) -> torch.Tensor:
        return self.sd[key]

    def metadata(self) -> Dict[str, str]:
        return self._meta

    def can_delete(self) -> bool:
        return self._in_place

    def delete(self, key: str) -> None:
        """Pop ``key`` if present, periodically flushing the CUDA cache when enabled."""
        if key not in self.sd:
            return
        del self.sd[key]
        self._deletions += 1
        if self._free and (self._deletions % self._interval == 0) and torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
76
|
+
def _as_accessor(src: Union[str, Dict[str, torch.Tensor]], **dict_opts) -> Tuple[Accessor, Callable[[], None]]:
    """Wrap ``src`` in an Accessor and return it with its cleanup callable.

    A ``str`` is treated as a safetensors path (``dict_opts`` are ignored and the
    cleanup is the accessor's ``close``); anything else is wrapped in a
    DictAccessor built with ``dict_opts`` and gets a no-op cleanup.
    """
    if not isinstance(src, str):
        return DictAccessor(src, **dict_opts), (lambda: None)
    file_acc = FileAccessor(src)
    return file_acc, file_acc.close
|
|
82
|
+
|
|
83
|
+
# ---------- Shared helpers ----------
def _normalize_scale_dtype(scale_dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Resolve a dtype spec to a ``torch.dtype``.

    A ``torch.dtype`` is returned unchanged; any other value is stringified,
    lower-cased and looked up in ``_DTYPE_ALIASES``.

    Raises:
        ValueError: if the alias is not recognised.
    """
    if not isinstance(scale_dtype, torch.dtype):
        resolved = _DTYPE_ALIASES.get(str(scale_dtype).lower())
        if resolved is None:
            raise ValueError(f"scale_dtype must be one of {list(_DTYPE_ALIASES.keys())} or a torch.dtype")
        return resolved
    return scale_dtype
|
|
91
|
+
|
|
92
|
+
def _json_to_dict(s: str) -> Optional[Dict]:
    """Decode ``s`` as JSON, returning ``None`` if it is not valid JSON.

    Only ``json.JSONDecodeError`` is swallowed. The decoded value is returned
    as-is, so callers must still check that it is actually a dict.
    """
    try:
        decoded = json.loads(s)
    except json.JSONDecodeError:
        decoded = None
    return decoded
|
|
98
|
+
|
|
99
|
+
def _maybe_parse_scale_map(meta: Dict[str, str]) -> Optional[Dict[str, float]]:
    """Best-effort extraction of a ``{tensor_name: scalar_scale}`` map from metadata.

    Candidate values are the entries under ``_SCALE_META_KEYS`` (in order),
    followed by every metadata value. Only strings that look like a JSON object
    (``{...}``) are decoded. Inside a decoded object an entry counts as a scale
    when it is a number, or a dict whose ``"scale"`` is a number or a one-element
    list/tuple of a number. If an object yields no such entry, its
    ``"weights"``/``"tensors"``/``"params"``/``"map"`` sub-objects are searched
    recursively. Returns the first non-empty map found, else ``None``.
    """
    def scalar_of(value) -> Optional[float]:
        # A bare number, or {"scale": number} / {"scale": [number]}.
        if isinstance(value, (int, float)):
            return float(value)
        if isinstance(value, dict) and "scale" in value:
            inner = value["scale"]
            if isinstance(inner, (int, float)):
                return float(inner)
            if isinstance(inner, (list, tuple)) and len(inner) == 1 and isinstance(inner[0], (int, float)):
                return float(inner[0])
        return None

    def parse_obj(obj) -> Optional[Dict[str, float]]:
        if not isinstance(obj, dict):
            return None
        found: Dict[str, float] = {}
        for name, value in obj.items():
            scale = scalar_of(value)
            if scale is not None:
                found[name] = scale
        if found:
            return found
        for nested_key in ("weights", "tensors", "params", "map"):
            nested = obj.get(nested_key)
            if isinstance(nested, dict):
                nested_map = parse_obj(nested)
                if nested_map:
                    return nested_map
        return None

    def parse_raw(raw) -> Optional[Dict[str, float]]:
        if not (isinstance(raw, str) and raw.startswith("{") and raw.endswith("}")):
            return None
        decoded = _json_to_dict(raw)
        return parse_obj(decoded) if decoded else None

    # Exact, well-known keys first; then a loose scan over every value.
    candidates = [meta.get(k) for k in _SCALE_META_KEYS] + list(meta.values())
    for raw in candidates:
        scale_map = parse_raw(raw)
        if scale_map:
            return scale_map
    return None
|
|
142
|
+
|
|
143
|
+
def _quick_fp8_variant_from_sentinel(acc: Accessor) -> Optional[str]:
|
|
144
|
+
if "scaled_fp8" in set(acc.keys()):
|
|
145
|
+
dt = acc.get_tensor("scaled_fp8").dtype
|
|
146
|
+
if dt == torch.float8_e4m3fn: return "e4m3fn"
|
|
147
|
+
if dt == torch.float8_e5m2: return "e5m2"
|
|
148
|
+
return None
|
|
149
|
+
|
|
150
|
+
def _per_channel_reshape(vec: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """View ``vec`` as ``[weight.shape[0], 1, ..., 1]`` with ``weight.ndim`` dims."""
    trailing_ones = (1,) * (weight.ndim - 1)
    return vec.view((weight.shape[0],) + trailing_ones)
|
|
152
|
+
|
|
153
|
+
# ---------- Unified converter ----------
class ConvertResult(Dict[str, object]):
    """Dict returned by ``convert_scaled_fp8_to_quanto`` with attribute-style accessors.

    Expected keys: ``state_dict``, ``quant_map``, ``fp8_format``, ``patch_needed``.
    Each property is a plain ``self[key]`` lookup.
    """

    @property
    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Converted tensors."""
        return self["state_dict"]  # type: ignore

    @property
    def quant_map(self) -> Dict[str, Dict]:
        """Per-module quantization entries."""
        return self["quant_map"]  # type: ignore

    @property
    def fp8_format(self) -> str:
        """FP8 variant name used for the conversion."""
        return self["fp8_format"]  # type: ignore

    @property
    def patch_needed(self) -> bool:
        """Whether the emitted scales need the FP32-scale runtime patch."""
        return self["patch_needed"]  # type: ignore
|
|
163
|
+
|
|
164
|
+
def convert_scaled_fp8_to_quanto(
    src: Union[str, Dict[str, torch.Tensor]],
    *,
    fp8_format: Optional[str] = None,  # 'e4m3fn' | 'e5m2' | None (auto)
    require_scale: bool = False,
    allow_default_scale: bool = True,
    default_missing_scale: float = 1.0,
    dtype: Union[str, torch.dtype] = "float32",
    add_activation_placeholders: bool = True,
    # dict mode options
    sd_metadata: Optional[Dict[str, str]] = None,
    in_place: bool = False,
    free_cuda_cache: bool = False,
    cuda_cache_interval: int = 32,
) -> ConvertResult:
    """Convert a "scaled FP8" checkpoint into Quanto-style qfloat8 tensors.

    Every ``*.weight`` tensor stored as ``float8_e4m3fn``/``float8_e5m2`` becomes
    ``<key>._data`` (the unchanged FP8 tensor) plus ``<key>._scale`` (a
    per-output-channel scale shaped ``[out, 1, ...]`` in ``dtype``). Optionally,
    one-element ``<base>.input_scale`` / ``<base>.output_scale`` placeholders are
    added. The ``scaled_fp8`` sentinel and ``*.scale_weight`` keys are dropped.
    Floating-point tensors that are neither ``dtype`` nor float32 are cast to
    ``dtype``; every other tensor is copied through unchanged.

    Args:
        src: path to a ``.safetensors`` file, or an in-memory state dict.
        fp8_format: force the FP8 variant. Otherwise it comes from the
            ``scaled_fp8`` sentinel, then from the first FP8 weight, else "auto".
        require_scale: together with ``allow_default_scale=False``, makes a
            missing scale an error.
        allow_default_scale: see ``require_scale``.
        default_missing_scale: scale used for FP8 weights with no scale found.
        dtype: torch dtype or alias ("float32", "bf16", ...) for the emitted
            scales and for cast non-quantized floating tensors.
        add_activation_placeholders: emit input/output scale placeholders.
        sd_metadata: metadata searched for a scale map (dict mode only).
        in_place: dict mode only; consumed entries are popped from ``src``.
        free_cuda_cache: dict mode only; periodically call
            ``torch.cuda.empty_cache()`` while deleting.
        cuda_cache_interval: deletion cadence for ``free_cuda_cache``.

    Returns:
        ConvertResult with ``state_dict``, ``quant_map`` (``{base: {"weights":
        qtype, "activations": "none"}}``), ``fp8_format`` and ``patch_needed``
        (True when the scales are float32).

    Raises:
        ValueError: unknown ``dtype`` alias, or a ``.scale_weight`` whose length
            matches neither 1 nor the output channels and is not uniform.
        KeyError: an FP8 weight without a scale when ``require_scale=True`` and
            ``allow_default_scale=False``.
    """
    sd_scale_dtype = _normalize_scale_dtype(dtype)
    # float32 scales are what enable_fp8_fp32_scale_support() exists for.
    patch_needed = (sd_scale_dtype == torch.float32)

    acc, closer = _as_accessor(
        src,
        meta=sd_metadata,
        in_place=in_place,
        free_cuda_cache=free_cuda_cache,
        cuda_cache_interval=cuda_cache_interval,
    )
    if not acc.can_delete():
        in_place = False
    try:
        meta = acc.metadata() or {}
        meta_scale_map = _maybe_parse_scale_map(meta) or {}

        keys = list(acc.keys())
        key_set = set(keys)  # O(1) membership for the scale pairing below

        # FP8 variant: explicit argument -> sentinel -> first FP8 weight -> 'auto'
        fmt = fp8_format or _quick_fp8_variant_from_sentinel(acc)
        if fmt is None:
            for wk in keys:
                if not _is_weight_key(wk):
                    continue
                dt = acc.get_tensor(wk).dtype
                if dt == torch.float8_e4m3fn:
                    fmt = "e4m3fn"
                    break
                if dt == torch.float8_e5m2:
                    fmt = "e5m2"
                    break
        if fmt is None:
            fmt = "auto"

        # Map '<base>.weight' -> '<base>.scale_weight'
        scale_weight_map: Dict[str, str] = {}
        for sk in keys:
            if sk.endswith(".scale_weight"):
                wk = sk[: -len(".scale_weight")] + ".weight"
                if wk in key_set:
                    scale_weight_map[wk] = sk

        def get_scale_vec_for_weight(wk: str, out_ch: int) -> Optional[torch.Tensor]:
            """Return a float32 ``[out_ch]`` scale for ``wk``, or None if none is known."""
            # 1) explicit '<base>.scale_weight' tensor
            sk = scale_weight_map.get(wk)
            if sk is not None:
                s_t = acc.get_tensor(sk).to(torch.float32)
                if in_place:
                    acc.delete(sk)  # delete by key; passing the tensor never matched
                if s_t.numel() == 1:
                    return torch.full((out_ch,), float(s_t.item()), dtype=torch.float32)
                if s_t.numel() == out_ch:
                    return s_t.reshape(out_ch)
                if torch.numel(s_t.unique()) == 1:
                    return torch.full((out_ch,), float(s_t.view(-1)[0].item()), dtype=torch.float32)
                raise ValueError(f"Unexpected scale length for '{wk}': {s_t.numel()} (out_ch={out_ch})")
            # 2) metadata scale map: exact name, then normalized variants
            if wk in meta_scale_map:
                return torch.full((out_ch,), float(meta_scale_map[wk]), dtype=torch.float32)
            for alt in (wk.replace("model.", ""), re.sub(r"(^|\.)weight$", "", wk)):
                if alt in meta_scale_map:
                    return torch.full((out_ch,), float(meta_scale_map[alt]), dtype=torch.float32)
            return None

        out_sd: Dict[str, torch.Tensor] = {}
        qmap: Dict[str, Dict] = {}

        # Single pass: rewrite FP8 weights, copy-through others
        for k in keys:
            # Drop source-only artifacts
            if k == "scaled_fp8" or k.endswith(".scale_weight"):
                continue

            t = acc.get_tensor(k)
            if in_place:
                acc.delete(k)

            if _is_weight_key(k) and t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
                # Quantized: keep original FP8 tensor as _data
                out_sd[k + DATA_SUFFIX] = t

                out_ch = int(t.shape[0])
                s_vec = get_scale_vec_for_weight(k, out_ch)
                if s_vec is None:
                    if require_scale and not allow_default_scale:
                        raise KeyError(f"No scale found for '{k}' (looked for '.scale_weight' and metadata).")
                    s_vec = torch.full((out_ch,), float(default_missing_scale), dtype=torch.float32)

                out_sd[k + SCALE_SUFFIX] = _per_channel_reshape(s_vec, t).to(sd_scale_dtype)

                base = k[: -len(".weight")]
                if add_activation_placeholders:
                    out_sd[base + IN_SCALE] = torch.tensor([1], dtype=sd_scale_dtype)
                    out_sd[base + OUT_SCALE] = torch.tensor([1], dtype=sd_scale_dtype)
                qmap[base] = {"weights": _QTYPE_NAME[fmt], "activations": "none"}
            elif t.is_floating_point() and t.dtype not in (sd_scale_dtype, torch.float32):
                # Use the normalized torch.dtype: ``dtype`` may be a string alias,
                # which Tensor.to() would interpret as a device. Integer/bool
                # tensors are never cast.
                out_sd[k] = t.to(sd_scale_dtype)
            else:
                out_sd[k] = t
        return ConvertResult(state_dict=out_sd, quant_map=qmap, fp8_format=fmt, patch_needed=patch_needed)
    finally:
        closer()
|
|
274
|
+
|
|
275
|
+
def detect_safetensors_format(
    src: Union[str, Dict[str, torch.Tensor]],
    *,
    sd_metadata: Optional[Dict[str, str]] = None,
    probe_weights: bool = False,
    with_hints: bool = False,
) -> Dict[str, str]:
    """Classify a safetensors file or state dict by quantization layout.

    Decision order (the first match wins):
      1. a ``scaled_fp8`` sentinel tensor -> 'scaled_fp8' (variant from its dtype);
      2. any key ending in ``._data`` -> 'quanto' with quant_format 'qfloat8';
      3. any ``*.scale_weight`` key -> 'scaled_fp8' (variant from probed weights);
      4. an FP8 ``*.weight`` dtype among the probed weights -> 'fp8';
      5. metadata suggesting an FP8 scale map -> 'scaled_fp8' (variant from text);
      6. otherwise 'none', with quant_format set to any Quanto qtype token
         found in the metadata text.

    Weight probing fetches ``*.weight`` tensors in key order with ``get_tensor``
    until 1 FP8 weight (2 with ``probe_weights=True``) has been seen. Non-FP8
    weights do not consume that budget, so a source with no FP8 weights has
    every ``*.weight`` fetched.

    Returns:
        {
            'kind': 'quanto' | 'scaled_fp8' | 'fp8' | 'none',
            'quant_format': 'qfloat8_e4m3fn' | 'qfloat8_e5m2' | 'qfloat8' | 'qint8' | 'qint4' | '',
            'fp8_format': 'e4m3fn' | 'e5m2' | 'unknown' | '',
            'hint': '...'   # only when with_hints=True
        }
    """
    acc, closer = _as_accessor(src, meta=sd_metadata, in_place=False)

    def result(kind: str, quant_format: str, fp8_format: str, hint: str) -> Dict[str, str]:
        info = {"kind": kind, "quant_format": quant_format, "fp8_format": fp8_format}
        if with_hints:
            info["hint"] = hint
        return info

    try:
        # O(1) sentinel test before touching the key list.
        if acc.has("scaled_fp8"):
            sentinel_dtype = acc.get_tensor("scaled_fp8").dtype
            if sentinel_dtype == torch.float8_e4m3fn:
                variant = "e4m3fn"
            elif sentinel_dtype == torch.float8_e5m2:
                variant = "e5m2"
            else:
                variant = "unknown"
            return result("scaled_fp8", "", variant, "sentinel")

        names = list(acc.keys())
        if any(name.endswith(DATA_SUFFIX) for name in names):
            return result("quanto", "qfloat8", "", "keys:*._data")

        scale_weight_seen = False
        fp8_variant = None
        budget = 2 if probe_weights else 1
        for name in names:
            if name.endswith(".scale_weight"):
                scale_weight_seen = True
            if budget > 0 and _is_weight_key(name):
                weight_dtype = acc.get_tensor(name).dtype
                if weight_dtype == torch.float8_e4m3fn:
                    fp8_variant = "e4m3fn"
                    budget -= 1
                elif weight_dtype == torch.float8_e5m2:
                    fp8_variant = "e5m2"
                    budget -= 1

        if scale_weight_seen:
            return result("scaled_fp8", "", fp8_variant or "unknown", "scale_weight keys")
        if fp8_variant is not None:
            return result("fp8", "", fp8_variant, "weight dtype (plain fp8)")

        # Keys were not decisive: cheap text scan of the metadata (no JSON parsing).
        meta = acc.metadata() or {}
        blob = " ".join(v for v in meta.values() if isinstance(v, str)).lower()
        mentions_fp8_scale = ("scale" in blob) and (("fp8" in blob) or ("float8" in blob))
        if mentions_fp8_scale or any(key in meta for key in _SCALE_META_KEYS):
            if "e4m3" in blob:
                variant = "e4m3fn"
            elif "e5m2" in blob:
                variant = "e5m2"
            else:
                variant = "unknown"
            return result("scaled_fp8", "", variant, "metadata")

        # A Quanto qtype mentioned in metadata is only a hint, not decisive.
        qtype_hint = next(
            (tok for tok in ("qfloat8_e4m3fn", "qfloat8_e5m2", "qfloat8", "qint8", "qint4") if tok in blob),
            "",
        )
        return result("none", qtype_hint, "", "no decisive keys")
    finally:
        closer()
|
|
371
|
+
|
|
372
|
+
# ---------- Optional Quanto runtime patch (FP32-scale support), enable/disable ----------
# Module-wide record of the patch: whether it is active, the saved original
# factory to restore, and the positional index of the ``scale`` argument.
_patch_state = SimpleNamespace(
    enabled=False,
    orig=None,
    scale_index=None,
)
|
|
374
|
+
|
|
375
|
+
def enable_fp8_fp32_scale_support():
    """Patch Quanto's ``WeightQBytesTensor.create`` so FP8 weights accept float32 scales.

    The wrapper:
      - reads ``qtype``/``axis``/``size``/``scale`` whether they are passed
        positionally or by keyword. The positional index of ``scale`` comes from
        the original signature, falling back to 5;
      - for an FP8 qtype with a float32 scale tensor and a known ``size[axis]``:
        always expands a single-element scale to that length; for a uniform
        scale with the wrong leading dim, broadcasts one slice with ``expand``
        (a view, no copy); for a non-uniform scale with the wrong leading dim,
        raises ``ValueError``;
      - forwards everything else (e.g. bf16/fp16 scales, classic Quanto) untouched.

    Enable only if you emitted float32 scales. Idempotent; always returns True.
    Imports ``optimum.quanto`` lazily.
    """
    if _patch_state.enabled:
        return True

    from optimum.quanto.tensor.weights import qbytes as _qbytes  # late import

    target_cls = _qbytes.WeightQBytesTensor
    # Save the raw class attribute (staticmethod/classmethod object) so a restore
    # puts back exactly what was there. Call through the normally resolved one.
    orig_attr = inspect.getattr_static(target_cls, "create")
    orig = target_cls.create
    params = list(inspect.signature(orig).parameters.keys())
    scale_index = params.index("scale") if "scale" in params else 5  # fallback

    def wrapper(*args, **kwargs):
        # Extract fields irrespective of call style
        qtype = kwargs.get("qtype", args[0] if len(args) > 0 else None)
        axis = kwargs.get("axis", args[1] if len(args) > 1 else None)
        size = kwargs.get("size", args[2] if len(args) > 2 else None)

        if "scale" in kwargs:
            scale = kwargs["scale"]

            def set_scale(new):
                kwargs["scale"] = new
        else:
            scale = args[scale_index] if len(args) > scale_index else None

            def set_scale(new):
                nonlocal args
                args = list(args)
                if len(args) > scale_index:
                    args[scale_index] = new
                else:
                    kwargs["scale"] = new
                args = tuple(args)

        qtype_text = qtype.lower() if isinstance(qtype, str) else str(qtype).lower()
        is_fp8 = "float8" in qtype_text  # also covers "qfloat8*"

        if is_fp8 and isinstance(scale, torch.Tensor) and scale.dtype == torch.float32:
            need = int(size[axis]) if (isinstance(size, (tuple, list)) and axis is not None and axis >= 0) else None
            if need is not None:
                if scale.numel() == 1:
                    scale = scale.view(1).expand(need, *scale.shape[1:])
                elif scale.shape[0] != need:
                    if torch.numel(scale.unique()) == 1:
                        # All values are equal: broadcast one leading slice.
                        # reshape(1, ...) cannot be used here since numel > 1.
                        scale = scale[:1].expand(need, *scale.shape[1:])
                    else:
                        raise ValueError(f"Scale leading dim {scale.shape[0]} != required {need}")
                set_scale(scale)

        return orig(*args, **kwargs)

    # staticmethod: the wrapper must not receive ``self`` if reached via an instance.
    target_cls.create = staticmethod(wrapper)
    _patch_state.enabled = True
    _patch_state.orig = orig_attr
    _patch_state.scale_index = scale_index
    return True
|
|
436
|
+
|
|
437
|
+
def disable_fp8_fp32_scale_support():
    """Restore Quanto's original factory.

    Puts the saved ``WeightQBytesTensor.create`` back and clears the patch
    state. Returns False if no patch was active, True after restoring.
    """
    if _patch_state.enabled:
        from optimum.quanto.tensor.weights import qbytes as _qbytes
        _qbytes.WeightQBytesTensor.create = _patch_state.orig
        _patch_state.enabled, _patch_state.orig, _patch_state.scale_index = False, None, None
        return True
    return False
|
|
447
|
+
|
|
448
|
+
# ---------- Tiny CLI (optional) ----------
def _cli():
    """Command-line front end with three sub-commands.

    convert IN OUT_WEIGHTS OUT_QMAP : run convert_scaled_fp8_to_quanto on a file,
        write the state dict with save_file and the quant map as JSON.
    detect PATH [--probe] [--hints] : print detect_safetensors_format's result.
    patch enable|disable            : toggle the FP32-scale runtime patch.

    Returns 0 for every handled sub-command.
    """
    import argparse
    import json as _json

    parser = argparse.ArgumentParser("fp8_quanto_bridge")
    commands = parser.add_subparsers(dest="cmd", required=True)

    conv = commands.add_parser("convert", help="Convert scaled-FP8 (file) to Quanto artifacts.")
    for positional in ("in_path", "out_weights", "out_qmap"):
        conv.add_argument(positional)
    conv.add_argument("--fp8-format", choices=("e4m3fn", "e5m2"), default=None)
    conv.add_argument(
        "--scale-dtype",
        default="float32",
        choices=("float32", "bfloat16", "float16", "fp32", "bf16", "fp16", "half"),
    )
    conv.add_argument("--no-activation-placeholders", action="store_true")
    conv.add_argument("--default-missing-scale", type=float, default=1.0)

    det = commands.add_parser("detect", help="Detect format quickly (path).")
    det.add_argument("path")
    for flag in ("--probe", "--hints"):
        det.add_argument(flag, action="store_true")

    patch = commands.add_parser("patch", help="Enable/disable FP32-scale runtime patch.")
    patch.add_argument("mode", choices=("enable", "disable"))

    args = parser.parse_args()

    if args.cmd == "convert":
        res = convert_scaled_fp8_to_quanto(
            args.in_path,
            fp8_format=args.fp8_format,
            dtype=args.scale_dtype,
            add_activation_placeholders=not args.no_activation_placeholders,
            default_missing_scale=args.default_missing_scale,
        )
        save_file(res.state_dict, args.out_weights)
        with open(args.out_qmap, "w") as qmap_file:
            _json.dump(res.quant_map, qmap_file)
        print(f"Wrote: {args.out_weights} and {args.out_qmap}. Patch needed: {res.patch_needed}")
        return 0

    if args.cmd == "detect":
        print(detect_safetensors_format(args.path, probe_weights=args.probe, with_hints=args.hints))
        return 0

    if args.cmd == "patch":
        toggle = enable_fp8_fp32_scale_support if args.mode == "enable" else disable_fp8_fp32_scale_support
        ok = toggle()
        print(f"patch {args.mode}: {'ok' if ok else 'already in that state'}")
        return 0


if __name__ == "__main__":
    raise SystemExit(_cli())
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# ------------------ Memory Management 3.6.1 for the GPU Poor by DeepBeepMeep (mmgp)------------------
|
|
1
|
+
# ------------------ Memory Management 3.6.3 for the GPU Poor by DeepBeepMeep (mmgp)------------------
|
|
2
2
|
#
|
|
3
3
|
# This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
|
|
4
4
|
# This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
|
|
@@ -71,7 +71,7 @@ import torch
|
|
|
71
71
|
|
|
72
72
|
from mmgp import safetensors2
|
|
73
73
|
from mmgp import profile_type
|
|
74
|
-
|
|
74
|
+
from .fp8_quanto_bridge import convert_scaled_fp8_to_quanto, detect_safetensors_format , enable_fp8_fp32_scale_support
|
|
75
75
|
from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QLinear, QTensor, quantize_module, register_qmodule
|
|
76
76
|
|
|
77
77
|
# support for Embedding module quantization that is not supported by default by quanto
|
|
@@ -688,7 +688,7 @@ def _welcome():
|
|
|
688
688
|
if welcome_displayed:
|
|
689
689
|
return
|
|
690
690
|
welcome_displayed = True
|
|
691
|
-
print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.6.
|
|
691
|
+
print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.6.3) by DeepBeepMeep ************{ENDC}{UNBOLD}")
|
|
692
692
|
|
|
693
693
|
def change_dtype(model, new_dtype, exclude_buffers = False):
|
|
694
694
|
for submodule_name, submodule in model.named_modules():
|
|
@@ -1319,7 +1319,7 @@ def move_loras_to_device(model, device="cpu" ):
|
|
|
1319
1319
|
if ".lora_" in k:
|
|
1320
1320
|
m.to(device)
|
|
1321
1321
|
|
|
1322
|
-
def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, preprocess_sd = None, modules = None, return_shared_modules = None,
|
|
1322
|
+
def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, preprocess_sd = None, modules = None, return_shared_modules = None, default_dtype = torch.bfloat16, ignore_unused_weights = False, configKwargs ={}):
|
|
1323
1323
|
"""
|
|
1324
1324
|
quick version of .LoadfromPretrained of the transformers library
|
|
1325
1325
|
used to build a model and load the corresponding weights (quantized or not)
|
|
@@ -1407,13 +1407,13 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantiza
|
|
|
1407
1407
|
|
|
1408
1408
|
model._config = transformer_config
|
|
1409
1409
|
|
|
1410
|
-
load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors, preprocess_sd = preprocess_sd , modules = modules, return_shared_modules = return_shared_modules, verboseLevel=verboseLevel )
|
|
1410
|
+
load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors, preprocess_sd = preprocess_sd , modules = modules, return_shared_modules = return_shared_modules, default_dtype = default_dtype, ignore_unused_weights = ignore_unused_weights, verboseLevel=verboseLevel )
|
|
1411
1411
|
|
|
1412
1412
|
return model
|
|
1413
1413
|
|
|
1414
1414
|
|
|
1415
1415
|
|
|
1416
|
-
def load_model_data(model, file_path, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, preprocess_sd = None, modules = None, return_shared_modules = None, verboseLevel = -1):
|
|
1416
|
+
def load_model_data(model, file_path, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, preprocess_sd = None, postprocess_sd = None, modules = None, return_shared_modules = None, default_dtype = torch.bfloat16, ignore_unused_weights = False, verboseLevel = -1):
|
|
1417
1417
|
"""
|
|
1418
1418
|
Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
|
|
1419
1419
|
"""
|
|
@@ -1489,29 +1489,41 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
|
|
|
1489
1489
|
state_dict.update(sd)
|
|
1490
1490
|
else:
|
|
1491
1491
|
state_dict, metadata = _safetensors_load_file(file, writable_tensors =writable_tensors)
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
|
|
1495
|
-
|
|
1496
|
-
|
|
1497
|
-
|
|
1498
|
-
|
|
1499
|
-
|
|
1500
|
-
|
|
1501
|
-
|
|
1502
|
-
|
|
1503
|
-
|
|
1504
|
-
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
|
|
1509
|
-
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
|
|
1514
|
-
|
|
1492
|
+
|
|
1493
|
+
if preprocess_sd != None:
|
|
1494
|
+
state_dict = preprocess_sd(state_dict)
|
|
1495
|
+
|
|
1496
|
+
if metadata != None:
|
|
1497
|
+
quantization_map = metadata.get("quantization_map", None)
|
|
1498
|
+
config = metadata.get("config", None)
|
|
1499
|
+
if config is not None:
|
|
1500
|
+
model._config = config
|
|
1501
|
+
|
|
1502
|
+
tied_weights_map = metadata.get("tied_weights_map", None)
|
|
1503
|
+
if tied_weights_map != None:
|
|
1504
|
+
for name, tied_weights_list in tied_weights_map.items():
|
|
1505
|
+
mapped_weight = state_dict[name]
|
|
1506
|
+
for tied_weights in tied_weights_list:
|
|
1507
|
+
state_dict[tied_weights] = mapped_weight
|
|
1508
|
+
|
|
1509
|
+
if quantization_map is None:
|
|
1510
|
+
detection_type = detect_safetensors_format(state_dict)
|
|
1511
|
+
if detection_type["kind"] in ['scaled_fp8','fp8']:
|
|
1512
|
+
conv_result = convert_scaled_fp8_to_quanto(state_dict, dtype = default_dtype, in_place= True)
|
|
1513
|
+
state_dict = conv_result["state_dict"]
|
|
1514
|
+
quantization_map = conv_result["quant_map"]
|
|
1515
|
+
conv_result = None
|
|
1516
|
+
# enable_fp8_fp32_scale_support()
|
|
1517
|
+
|
|
1518
|
+
if quantization_map is None:
|
|
1519
|
+
pos = str.rfind(file, ".")
|
|
1520
|
+
if pos > 0:
|
|
1521
|
+
quantization_map_path = file[:pos]
|
|
1522
|
+
quantization_map_path += "_map.json"
|
|
1523
|
+
|
|
1524
|
+
if os.path.isfile(quantization_map_path):
|
|
1525
|
+
with open(quantization_map_path, 'r') as f:
|
|
1526
|
+
quantization_map = json.load(f)
|
|
1515
1527
|
|
|
1516
1528
|
full_state_dict.update(state_dict)
|
|
1517
1529
|
if quantization_map != None:
|
|
@@ -1530,8 +1542,8 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
|
|
|
1530
1542
|
full_state_dict, full_quantization_map, full_tied_weights_map = None, None, None
|
|
1531
1543
|
|
|
1532
1544
|
# deal if we are trying to load just a sub part of a larger model
|
|
1533
|
-
if
|
|
1534
|
-
state_dict, quantization_map =
|
|
1545
|
+
if postprocess_sd != None:
|
|
1546
|
+
state_dict, quantization_map = postprocess_sd(state_dict, quantization_map)
|
|
1535
1547
|
|
|
1536
1548
|
if modelPrefix != None:
|
|
1537
1549
|
base_model_prefix = modelPrefix + "."
|
|
@@ -1562,7 +1574,7 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
|
|
|
1562
1574
|
|
|
1563
1575
|
del state_dict
|
|
1564
1576
|
|
|
1565
|
-
if len(unexpected_keys) > 0 and verboseLevel >=2:
|
|
1577
|
+
if len(unexpected_keys) > 0 and verboseLevel >=2 and not ignore_unused_weights:
|
|
1566
1578
|
print(f"Unexpected keys while loading '{file_path}': {unexpected_keys}")
|
|
1567
1579
|
|
|
1568
1580
|
for k,p in model.named_parameters():
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mmgp
|
|
3
|
-
Version: 3.6.1
|
|
3
|
+
Version: 3.6.3
|
|
4
4
|
Summary: Memory Management for the GPU Poor
|
|
5
5
|
Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
|
|
6
6
|
Requires-Python: >=3.10
|
|
@@ -15,7 +15,7 @@ Dynamic: license-file
|
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
<p align="center">
|
|
18
|
-
<H2>Memory Management 3.6.1 for the GPU Poor by DeepBeepMeep</H2>
|
|
18
|
+
<H2>Memory Management 3.6.3 for the GPU Poor by DeepBeepMeep</H2>
|
|
19
19
|
</p>
|
|
20
20
|
|
|
21
21
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|