mmgp 3.6.0__tar.gz → 3.6.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This release of mmgp has been flagged as potentially problematic.
- {mmgp-3.6.0/src/mmgp.egg-info → mmgp-3.6.2}/PKG-INFO +2 -2
- {mmgp-3.6.0 → mmgp-3.6.2}/README.md +1 -1
- {mmgp-3.6.0 → mmgp-3.6.2}/pyproject.toml +1 -1
- mmgp-3.6.2/src/mmgp/fp8_quanto_bridge.py +504 -0
- {mmgp-3.6.0 → mmgp-3.6.2}/src/mmgp/offload.py +29 -12
- {mmgp-3.6.0 → mmgp-3.6.2/src/mmgp.egg-info}/PKG-INFO +2 -2
- {mmgp-3.6.0 → mmgp-3.6.2}/src/mmgp.egg-info/SOURCES.txt +1 -0
- {mmgp-3.6.0 → mmgp-3.6.2}/LICENSE.md +0 -0
- {mmgp-3.6.0 → mmgp-3.6.2}/setup.cfg +0 -0
- {mmgp-3.6.0 → mmgp-3.6.2}/src/__init__.py +0 -0
- {mmgp-3.6.0 → mmgp-3.6.2}/src/mmgp/__init__.py +0 -0
- {mmgp-3.6.0 → mmgp-3.6.2}/src/mmgp/safetensors2.py +0 -0
- {mmgp-3.6.0 → mmgp-3.6.2}/src/mmgp.egg-info/dependency_links.txt +0 -0
- {mmgp-3.6.0 → mmgp-3.6.2}/src/mmgp.egg-info/requires.txt +0 -0
- {mmgp-3.6.0 → mmgp-3.6.2}/src/mmgp.egg-info/top_level.txt +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mmgp
-Version: 3.6.0
+Version: 3.6.2
 Summary: Memory Management for the GPU Poor
 Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
 Requires-Python: >=3.10
@@ -15,7 +15,7 @@ Dynamic: license-file
 
 
 <p align="center">
-<H2>Memory Management 3.6.0 for the GPU Poor by DeepBeepMeep</H2>
+<H2>Memory Management 3.6.2 for the GPU Poor by DeepBeepMeep</H2>
 </p>
 
 
mmgp-3.6.2/src/mmgp/fp8_quanto_bridge.py (new file)
@@ -0,0 +1,504 @@
from __future__ import annotations
import json, re, inspect
from types import SimpleNamespace
from typing import Dict, Optional, Tuple, Union, Iterable, Callable

import torch
from safetensors.torch import safe_open, save_file

# ---------- Constants ----------
DATA_SUFFIX = "._data"
SCALE_SUFFIX = "._scale"        # per-channel, shape [out, 1, ...]
IN_SCALE = ".input_scale"       # 1-D placeholder tensor [1]
OUT_SCALE = ".output_scale"     # 1-D placeholder tensor [1]

_QTYPE_NAME = {
    "e4m3fn": "qfloat8_e4m3fn",
    "e5m2": "qfloat8_e5m2",
    "auto": "qfloat8",
}

_SCALE_META_KEYS = (
    "fp8_scale_map", "fp8.scale_map", "scale_map",
    "quant_scale_map", "weights_scales", "scales",
)

_DTYPE_ALIASES = {
    "float32": torch.float32, "fp32": torch.float32,
    "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
    "float16": torch.float16, "fp16": torch.float16, "half": torch.float16,
}

def _is_weight_key(k: str) -> bool:
    return k.endswith(".weight")

# ---------- Accessors (unify file vs dict) ----------
class Accessor:
    def keys(self) -> Iterable[str]: ...
    def get_tensor(self, key: str) -> torch.Tensor: ...
    def metadata(self) -> Dict[str, str]: ...
    def has(self, key: str) -> bool: ...  # NEW
    def can_delete(self) -> bool: return False
    def delete(self, key: str) -> None: raise NotImplementedError

class FileAccessor(Accessor):
    def __init__(self, path: str):
        self._fh = safe_open(path, framework="pt")
        self._keys = list(self._fh.keys())
        self._keys_set = set(self._keys)  # O(1) membership
        self._meta = self._fh.metadata() or {}
    def keys(self) -> Iterable[str]: return self._keys
    def has(self, key: str) -> bool: return key in self._keys_set
    def get_tensor(self, key: str) -> torch.Tensor: return self._fh.get_tensor(key)
    def metadata(self) -> Dict[str, str]: return self._meta
    def close(self) -> None: self._fh.close()

class DictAccessor(Accessor):
    def __init__(self, sd: Dict[str, torch.Tensor], meta: Optional[Dict[str, str]] = None,
                 in_place: bool = False, free_cuda_cache: bool = False, cuda_cache_interval: int = 32):
        self.sd = sd
        self._meta = meta or {}
        self._in_place = in_place
        self._free = free_cuda_cache
        self._interval = int(cuda_cache_interval)
        self._deletions = 0
    def keys(self) -> Iterable[str]: return list(self.sd.keys())
    def has(self, key: str) -> bool: return key in self.sd  # dict membership = O(1)
    def get_tensor(self, key: str) -> torch.Tensor: return self.sd[key]
    def metadata(self) -> Dict[str, str]: return self._meta
    def can_delete(self) -> bool: return self._in_place
    def delete(self, key: str) -> None:
        if key in self.sd:
            self.sd.pop(key, None)
            self._deletions += 1
            if self._free and (self._deletions % self._interval == 0) and torch.cuda.is_available():
                torch.cuda.empty_cache()
def _as_accessor(src: Union[str, Dict[str, torch.Tensor]], **dict_opts) -> Tuple[Accessor, Callable[[], None]]:
    if isinstance(src, str):
        acc = FileAccessor(src)
        return acc, acc.close
    acc = DictAccessor(src, **dict_opts)
    return acc, (lambda: None)

# ---------- Shared helpers ----------
def _normalize_scale_dtype(scale_dtype: Union[str, torch.dtype]) -> torch.dtype:
    if isinstance(scale_dtype, torch.dtype):
        return scale_dtype
    key = str(scale_dtype).lower()
    if key not in _DTYPE_ALIASES:
        raise ValueError(f"scale_dtype must be one of {list(_DTYPE_ALIASES.keys())} or a torch.dtype")
    return _DTYPE_ALIASES[key]

def _json_to_dict(s: str) -> Optional[Dict]:
    # Strictly catch JSON decoding only
    try:
        return json.loads(s)
    except json.JSONDecodeError:
        return None

def _maybe_parse_scale_map(meta: Dict[str, str]) -> Optional[Dict[str, float]]:
    def try_parse(obj) -> Optional[Dict[str, float]]:
        if not isinstance(obj, dict):
            return None
        out: Dict[str, float] = {}
        for wk, v in obj.items():
            if isinstance(v, (int, float)):
                out[wk] = float(v)
            elif isinstance(v, dict) and "scale" in v:
                sc = v["scale"]
                if isinstance(sc, (int, float)):
                    out[wk] = float(sc)
                elif isinstance(sc, (list, tuple)) and len(sc) == 1 and isinstance(sc[0], (int, float)):
                    out[wk] = float(sc[0])
        if out:
            return out
        for sub in ("weights", "tensors", "params", "map"):
            subobj = obj.get(sub)
            if isinstance(subobj, dict):
                got = try_parse(subobj)
                if got:
                    return got
        return None

    # exact keys first
    for k in _SCALE_META_KEYS:
        raw = meta.get(k)
        if isinstance(raw, str) and raw.startswith("{") and raw.endswith("}"):
            parsed = _json_to_dict(raw)
            if parsed:
                got = try_parse(parsed)
                if got:
                    return got

    # loose scan of any JSON-looking value
    for v in meta.values():
        if isinstance(v, str) and v.startswith("{") and v.endswith("}"):
            parsed = _json_to_dict(v)
            if parsed:
                got = try_parse(parsed)
                if got:
                    return got
    return None

def _quick_fp8_variant_from_sentinel(acc: Accessor) -> Optional[str]:
    if "scaled_fp8" in set(acc.keys()):
        dt = acc.get_tensor("scaled_fp8").dtype
        if dt == torch.float8_e4m3fn: return "e4m3fn"
        if dt == torch.float8_e5m2: return "e5m2"
    return None

def _per_channel_reshape(vec: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return vec.view(weight.shape[0], *([1] * (weight.ndim - 1)))

# ---------- Unified converter ----------
class ConvertResult(Dict[str, object]):
    @property
    def state_dict(self) -> Dict[str, torch.Tensor]: return self["state_dict"]  # type: ignore
    @property
    def quant_map(self) -> Dict[str, Dict]: return self["quant_map"]  # type: ignore
    @property
    def fp8_format(self) -> str: return self["fp8_format"]  # type: ignore
    @property
    def patch_needed(self) -> bool: return self["patch_needed"]  # type: ignore

def convert_scaled_fp8_to_quanto(
    src: Union[str, Dict[str, torch.Tensor]],
    *,
    fp8_format: Optional[str] = None,  # 'e4m3fn' | 'e5m2' | None (auto)
    require_scale: bool = False,
    allow_default_scale: bool = True,
    default_missing_scale: float = 1.0,
    scale_dtype: Union[str, torch.dtype] = "float32",
    add_activation_placeholders: bool = True,
    # dict mode options
    sd_metadata: Optional[Dict[str, str]] = None,
    in_place: bool = False,
    free_cuda_cache: bool = False,
    cuda_cache_interval: int = 32,
) -> ConvertResult:
    sd_scale_dtype = _normalize_scale_dtype(scale_dtype)
    patch_needed = (sd_scale_dtype == torch.float32)

    acc, closer = _as_accessor(
        src,
        meta=sd_metadata,
        in_place=in_place,
        free_cuda_cache=free_cuda_cache,
        cuda_cache_interval=cuda_cache_interval,
    )
    try:
        meta = acc.metadata() or {}
        meta_scale_map = _maybe_parse_scale_map(meta) or {}

        keys = list(acc.keys())

        # FP8 variant: sentinel -> first FP8 weight -> 'auto'
        fmt = fp8_format or _quick_fp8_variant_from_sentinel(acc)
        if fmt is None:
            for wk in keys:
                if not _is_weight_key(wk): continue
                dt = acc.get_tensor(wk).dtype
                if dt == torch.float8_e4m3fn: fmt = "e4m3fn"; break
                if dt == torch.float8_e5m2: fmt = "e5m2"; break
            if fmt is None: fmt = "auto"

        # Map '<base>.scale_weight' -> '<base>.weight'
        scale_weight_map: Dict[str, str] = {}
        for sk in keys:
            if sk.endswith(".scale_weight"):
                base = sk[: -len(".scale_weight")]
                wk = base + ".weight"
                if wk in keys:
                    scale_weight_map[wk] = sk

        def get_scale_vec_for_weight(wk: str, out_ch: int) -> Optional[torch.Tensor]:
            # 1) explicit tensor
            sk = scale_weight_map.get(wk)
            if sk is not None:
                s_t = acc.get_tensor(sk).to(torch.float32)
                if s_t.numel() == 1:
                    return torch.full((out_ch,), float(s_t.item()), dtype=torch.float32)
                if s_t.numel() == out_ch:
                    return s_t.reshape(out_ch)
                if torch.numel(s_t.unique()) == 1:
                    return torch.full((out_ch,), float(s_t.view(-1)[0].item()), dtype=torch.float32)
                raise ValueError(f"Unexpected scale length for '{wk}': {s_t.numel()} (out_ch={out_ch})")
            # 2) metadata exact / normalized
            if wk in meta_scale_map:
                return torch.full((out_ch,), float(meta_scale_map[wk]), dtype=torch.float32)
            for alt in (wk.replace("model.", ""), re.sub(r"(^|\.)weight$", "", wk)):
                if alt in meta_scale_map:
                    return torch.full((out_ch,), float(meta_scale_map[alt]), dtype=torch.float32)
            return None

        # out dict: mutate original dict if in_place, else new dict
        out_sd: Dict[str, torch.Tensor] = acc.sd if isinstance(acc, DictAccessor) and in_place else {}
        qmap: Dict[str, Dict] = {}

        # Single pass: rewrite FP8 weights, copy-through others
        for k in keys:
            # Drop source-only artifacts
            if k == "scaled_fp8" or k.endswith(".scale_weight"):
                if acc.can_delete(): acc.delete(k)
                continue

            if _is_weight_key(k):
                t = acc.get_tensor(k)
                if t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
                    # Quantized: keep original FP8 tensor as _data
                    out_sd[k + DATA_SUFFIX] = t

                    out_ch = int(t.shape[0])
                    s_vec = get_scale_vec_for_weight(k, out_ch)
                    if s_vec is None:
                        if require_scale and not allow_default_scale:
                            raise KeyError(f"No scale found for '{k}' (looked for '.scale_weight' and metadata).")
                        s_vec = torch.full((out_ch,), float(default_missing_scale), dtype=torch.float32)

                    s_grid = _per_channel_reshape(s_vec, t).to(sd_scale_dtype)
                    out_sd[k + SCALE_SUFFIX] = s_grid

                    if add_activation_placeholders:
                        base = k[:-len(".weight")]
                        out_sd[base + IN_SCALE] = torch.tensor([1], dtype=sd_scale_dtype)
                        out_sd[base + OUT_SCALE] = torch.tensor([1], dtype=sd_scale_dtype)

                    base = k[:-len(".weight")]
                    qmap[base] = {"weights": _QTYPE_NAME[fmt], "activations": "none"}

                    if acc.can_delete():
                        acc.delete(k)
                    continue  # don't copy original .weight

            # Copy-through
            if not (isinstance(acc, DictAccessor) and in_place):
                out_sd[k] = acc.get_tensor(k)

        return ConvertResult(state_dict=out_sd, quant_map=qmap, fp8_format=fmt, patch_needed=patch_needed)
    finally:
        closer()

def detect_safetensors_format(
    src: Union[str, Dict[str, torch.Tensor]],
    *,
    sd_metadata: Optional[Dict[str, str]] = None,
    probe_weights: bool = False,  # if True, we may read up to 2 weights total
    with_hints: bool = False,
) -> Dict[str, str]:
    """
    Returns:
      {
        'kind': 'quanto' | 'scaled_fp8' | 'fp8' | 'none',
        'quant_format': 'qfloat8_e4m3fn' | 'qfloat8_e5m2' | 'qfloat8' | 'qint8' | 'qint4' | 'unknown' | '',
        'fp8_format': 'e4m3fn' | 'e5m2' | 'unknown' | '',
        'hint': '...'  # only when with_hints=True
      }
    """
    acc, closer = _as_accessor(src, meta=sd_metadata, in_place=False)
    try:
        # --- O(1) sentinel test up-front (no key scan) ---
        if acc.has("scaled_fp8"):
            dt = acc.get_tensor("scaled_fp8").dtype
            fp8_fmt = "e4m3fn" if dt == torch.float8_e4m3fn else ("e5m2" if dt == torch.float8_e5m2 else "unknown")
            out = {"kind": "scaled_fp8", "quant_format": "", "fp8_format": fp8_fmt}
            if with_hints: out["hint"] = "sentinel"
            return out

        # --- Single pass over keys (no re-scans) ---
        ks = list(acc.keys())
        has_scale_weight = False
        saw_quanto_data = False
        fp8_variant = None
        fp8_probe_budget = 2 if probe_weights else 1

        for k in ks:
            # Quanto pack short-circuit
            if not saw_quanto_data and k.endswith(DATA_SUFFIX):
                saw_quanto_data = True
                # we can break here, but keep minimal state setting uniformity
                break

        if saw_quanto_data:
            out = {"kind": "quanto", "quant_format": "qfloat8", "fp8_format": ""}
            if with_hints: out["hint"] = "keys:*._data"
            return out

        # continue single pass for the rest (scale keys + bounded dtype probe)
        for k in ks:
            if not has_scale_weight and k.endswith(".scale_weight"):
                has_scale_weight = True
                # don't return yet; we may still probe a dtype to grab variant

            if fp8_probe_budget > 0 and _is_weight_key(k):
                dt = acc.get_tensor(k).dtype
                if dt == torch.float8_e4m3fn:
                    fp8_variant = "e4m3fn"; fp8_probe_budget -= 1
                elif dt == torch.float8_e5m2:
                    fp8_variant = "e5m2"; fp8_probe_budget -= 1

        if has_scale_weight:
            out = {"kind": "scaled_fp8", "quant_format": "", "fp8_format": fp8_variant or "unknown"}
            if with_hints: out["hint"] = "scale_weight keys"
            return out

        if fp8_variant is not None:
            out = {"kind": "fp8", "quant_format": "", "fp8_format": fp8_variant}
            if with_hints: out["hint"] = "weight dtype (plain fp8)"
            return out

        # --- Cheap metadata peek only if keys didn't decide it (no JSON parsing) ---
        meta = acc.metadata() or {}
        blob = " ".join(v for v in meta.values() if isinstance(v, str)).lower()

        # scaled-fp8 hinted by metadata only
        has_scale_map = (
            any(k in meta for k in _SCALE_META_KEYS) or
            (("scale" in blob) and (("fp8" in blob) or ("float8" in blob)))
        )
        if has_scale_map:
            fmt = "e4m3fn" if "e4m3" in blob else ("e5m2" if "e5m2" in blob else "unknown")
            out = {"kind": "scaled_fp8", "quant_format": "", "fp8_format": fmt}
            if with_hints: out["hint"] = "metadata"
            return out

        # quanto hinted by metadata only (not decisive without keys)
        qtype_hint = ""
        for tok in ("qfloat8_e4m3fn", "qfloat8_e5m2", "qfloat8", "qint8", "qint4"):
            if tok in blob:
                qtype_hint = tok
                break

        out = {"kind": "none", "quant_format": qtype_hint, "fp8_format": ""}
        if with_hints: out["hint"] = "no decisive keys"
        return out

    finally:
        closer()

# ---------- Optional Quanto runtime patch (FP32-scale support), enable/disable ----------
_patch_state = SimpleNamespace(enabled=False, orig=None, scale_index=None)

def enable_fp8_fp32_scale_support():
    """
    Version-robust wrapper for WeightQBytesTensor.create:
      - matches both positional/keyword call styles via *args/**kwargs,
      - for FP8 + FP32 scales, expands scalar/uniform scales with a VIEW to the needed length,
      - leaves bf16/fp16 (classic Quanto) untouched.
    Enable only if you emitted float32 scales.
    """
    if _patch_state.enabled:
        return True

    from optimum.quanto.tensor.weights import qbytes as _qbytes  # late import
    orig = _qbytes.WeightQBytesTensor.create
    sig = inspect.signature(orig)
    params = list(sig.parameters.keys())
    scale_index = params.index("scale") if "scale" in params else 5  # fallback

    def wrapper(*args, **kwargs):
        # Extract fields irrespective of signature
        qtype = kwargs.get("qtype", args[0] if len(args) > 0 else None)
        axis = kwargs.get("axis", args[1] if len(args) > 1 else None)
        size = kwargs.get("size", args[2] if len(args) > 2 else None)

        if "scale" in kwargs:
            scale = kwargs["scale"]
            def set_scale(new): kwargs.__setitem__("scale", new)
        else:
            scale = args[scale_index] if len(args) > scale_index else None
            def set_scale(new):
                nonlocal args
                args = list(args)
                if len(args) > scale_index:
                    args[scale_index] = new
                else:
                    kwargs["scale"] = new
                args = tuple(args)

        is_fp8 = isinstance(qtype, str) and ("float8" in qtype.lower() or "qfloat8" in qtype.lower()) or \
                 (not isinstance(qtype, str) and "float8" in str(qtype).lower())

        if is_fp8 and isinstance(scale, torch.Tensor) and scale.dtype == torch.float32:
            need = int(size[axis]) if (isinstance(size, (tuple, list)) and axis is not None and axis >= 0) else None
            if need is not None:
                if scale.numel() == 1:
                    scale = scale.view(1).expand(need, *scale.shape[1:])
                elif scale.shape[0] != need:
                    # Expand if uniform; otherwise raise
                    uniform = (scale.numel() == 1) or (torch.numel(scale.unique()) == 1)
                    if uniform:
                        scale = scale.reshape(1, *scale.shape[1:]).expand(need, *scale.shape[1:])
                    else:
                        raise ValueError(f"Scale leading dim {scale.shape[0]} != required {need}")
            set_scale(scale)

        return orig(*args, **kwargs)

    _qbytes.WeightQBytesTensor.create = wrapper
    _patch_state.enabled = True
    _patch_state.orig = orig
    _patch_state.scale_index = scale_index
    return True

def disable_fp8_fp32_scale_support():
    """Restore Quanto's original factory."""
    if not _patch_state.enabled:
        return False
    from optimum.quanto.tensor.weights import qbytes as _qbytes
    _qbytes.WeightQBytesTensor.create = _patch_state.orig
    _patch_state.enabled = False
    _patch_state.orig = None
    _patch_state.scale_index = None
    return True

# ---------- Tiny CLI (optional) ----------
def _cli():
    import argparse, json as _json
    p = argparse.ArgumentParser("fp8_quanto_bridge")
    sub = p.add_subparsers(dest="cmd", required=True)

    p_conv = sub.add_parser("convert", help="Convert scaled-FP8 (file) to Quanto artifacts.")
    p_conv.add_argument("in_path")
    p_conv.add_argument("out_weights")
    p_conv.add_argument("out_qmap")
    p_conv.add_argument("--fp8-format", choices=("e4m3fn", "e5m2"), default=None)
    p_conv.add_argument("--scale-dtype", default="float32",
                        choices=("float32","bfloat16","float16","fp32","bf16","fp16","half"))
    p_conv.add_argument("--no-activation-placeholders", action="store_true")
    p_conv.add_argument("--default-missing-scale", type=float, default=1.0)

    p_det = sub.add_parser("detect", help="Detect format quickly (path).")
    p_det.add_argument("path")
    p_det.add_argument("--probe", action="store_true")
    p_det.add_argument("--hints", action="store_true")

    p_patch = sub.add_parser("patch", help="Enable/disable FP32-scale runtime patch.")
    p_patch.add_argument("mode", choices=("enable","disable"))

    args = p.parse_args()

    if args.cmd == "convert":
        res = convert_scaled_fp8_to_quanto(
            args.in_path,
            fp8_format=args.fp8_format,
            scale_dtype=args.scale_dtype,
            add_activation_placeholders=not args.no_activation_placeholders,
            default_missing_scale=args.default_missing_scale,
        )
        save_file(res.state_dict, args.out_weights)
        with open(args.out_qmap, "w") as f:
            _json.dump(res.quant_map, f)
        print(f"Wrote: {args.out_weights} and {args.out_qmap}. Patch needed: {res.patch_needed}")
        return 0

    if args.cmd == "detect":
        info = detect_safetensors_format(args.path, probe_weights=args.probe, with_hints=args.hints)
        print(info); return 0

    if args.cmd == "patch":
        ok = enable_fp8_fp32_scale_support() if args.mode == "enable" else disable_fp8_fp32_scale_support()
        print(f"patch {args.mode}: {'ok' if ok else 'already in that state'}")
        return 0

if __name__ == "__main__":
    raise SystemExit(_cli())
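The new bridge can also be exercised directly, independently of offload.py. Below is a minimal sketch of the file-based flow, assuming a hypothetical checkpoint "model_fp8.safetensors" (all file names are illustrative, not shipped with the package):

import json
from safetensors.torch import save_file
from mmgp.fp8_quanto_bridge import (
    detect_safetensors_format,
    convert_scaled_fp8_to_quanto,
    enable_fp8_fp32_scale_support,
)

# 1) Cheap detection: sentinel key first, then key suffixes, then metadata.
info = detect_safetensors_format("model_fp8.safetensors", with_hints=True)
print(info)  # e.g. {'kind': 'scaled_fp8', 'quant_format': '', 'fp8_format': 'e4m3fn', 'hint': 'sentinel'}

# 2) Convert scaled-FP8 weights into Quanto-style '._data' / '._scale' pairs.
if info["kind"] in ("scaled_fp8", "fp8"):
    res = convert_scaled_fp8_to_quanto("model_fp8.safetensors", scale_dtype="float32")
    save_file(res.state_dict, "model_quanto.safetensors")
    with open("model_quanto_map.json", "w") as f:
        json.dump(res.quant_map, f)
    # float32 scales require the optional runtime patch before Quanto rebuilds the weights
    if res.patch_needed:
        enable_fp8_fp32_scale_support()

The same three steps are also exposed by the module's small CLI (the convert, detect, and patch subcommands defined in _cli above).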
src/mmgp/offload.py
@@ -1,4 +1,4 @@
-# ------------------ Memory Management 3.6.0 for the GPU Poor by DeepBeepMeep (mmgp)------------------
+# ------------------ Memory Management 3.6.2 for the GPU Poor by DeepBeepMeep (mmgp)------------------
 #
 # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
 # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -71,7 +71,7 @@ import torch
 
 from mmgp import safetensors2
 from mmgp import profile_type
-
+from .fp8_quanto_bridge import convert_scaled_fp8_to_quanto, detect_safetensors_format , enable_fp8_fp32_scale_support
 from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QLinear, QTensor, quantize_module, register_qmodule
 
 # support for Embedding module quantization that is not supported by default by quanto
@@ -688,7 +688,7 @@ def _welcome():
     if welcome_displayed:
         return
     welcome_displayed = True
-    print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.6.0) by DeepBeepMeep ************{ENDC}{UNBOLD}")
+    print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.6.2) by DeepBeepMeep ************{ENDC}{UNBOLD}")
 
 def change_dtype(model, new_dtype, exclude_buffers = False):
     for submodule_name, submodule in model.named_modules():
@@ -1097,7 +1097,9 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
 
     invalid_keys = []
     unexpected_keys = []
-
+    new_state_dict = {}
+    for k in list(state_dict.keys()):
+        v = state_dict.pop(k)
         lora_A = lora_B = diff_b = diff = lora_key = None
         if k.endswith(".diff"):
             diff = v
@@ -1141,6 +1143,7 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
                 error_msg = append(error_msg, msg)
                 fail = True
                 break
+            v = lora_A = lora_A.to(module.weight.dtype)
         elif lora_B != None:
             rank = lora_B.shape[1]
             if module_shape[0] != v.shape[0]:
@@ -1151,6 +1154,7 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
                 error_msg = append(error_msg, msg)
                 fail = True
                 break
+            v = lora_B = lora_B.to(module.weight.dtype)
         elif diff != None:
             lora_B = diff
             if module_shape != v.shape:
@@ -1161,6 +1165,7 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
                 error_msg = append(error_msg, msg)
                 fail = True
                 break
+            v = lora_B = lora_B.to(module.weight.dtype)
         elif diff_b != None:
             rank = diff_b.shape[0]
             if not hasattr(module, "bias"):
@@ -1179,8 +1184,11 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
                 error_msg = append(error_msg, msg)
                 fail = True
                 break
+            v = diff_b = diff_b.to(module.weight.dtype)
 
         if not check_only:
+            new_state_dict[k] = v
+            v = None
             loras_module_data = loras_model_data.get(module, None)
             assert loras_module_data != None
             loras_adapter_data = loras_module_data.get(adapter_name, None)
@@ -1188,11 +1196,11 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
                 loras_adapter_data = [None, None, None, 1.]
                 loras_module_data[adapter_name] = loras_adapter_data
             if lora_A != None:
-                loras_adapter_data[0] = lora_A
+                loras_adapter_data[0] = lora_A
             elif lora_B != None:
-                loras_adapter_data[1] = lora_B
+                loras_adapter_data[1] = lora_B
             else:
-                loras_adapter_data[2] = diff_b
+                loras_adapter_data[2] = diff_b
             if rank != None and lora_key is not None and "lora" in lora_key:
                 alpha_key = k[:-len(lora_key)] + "alpha"
                 alpha = lora_alphas.get(alpha_key, None)
@@ -1220,7 +1228,7 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
     if not check_only:
         # model._loras_tied_weights[adapter_name] = tied_weights
         if pinnedLora:
-            pinned_sd_list.append(state_dict)
+            pinned_sd_list.append(new_state_dict)
            pinned_names_list.append(path)
            # _pin_sd_to_memory(state_dict, path)
 
@@ -1311,7 +1319,7 @@ def move_loras_to_device(model, device="cpu" ):
         if ".lora_" in k:
             m.to(device)
 
-def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, preprocess_sd = None, modules = None, return_shared_modules = None,
+def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, preprocess_sd = None, modules = None, return_shared_modules = None, default_dtype = torch.bfloat16, ignore_unused_weights = False, configKwargs ={}):
     """
     quick version of .LoadfromPretrained of the transformers library
     used to build a model and load the corresponding weights (quantized or not)
@@ -1399,13 +1407,13 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantiza
 
         model._config = transformer_config
 
-    load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors, preprocess_sd = preprocess_sd , modules = modules, return_shared_modules = return_shared_modules, verboseLevel=verboseLevel )
+    load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors, preprocess_sd = preprocess_sd , modules = modules, return_shared_modules = return_shared_modules, default_dtype = default_dtype, ignore_unused_weights = ignore_unused_weights, verboseLevel=verboseLevel )
 
     return model
 
 
 
-def load_model_data(model, file_path, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, preprocess_sd = None, modules = None, return_shared_modules = None, verboseLevel = -1):
+def load_model_data(model, file_path, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, preprocess_sd = None, modules = None, return_shared_modules = None, default_dtype = torch.bfloat16, ignore_unused_weights = False, verboseLevel = -1):
     """
     Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
     """
@@ -1495,6 +1503,15 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
         for tied_weights in tied_weights_list:
             state_dict[tied_weights] = mapped_weight
 
+    if quantization_map is None:
+        detection_type = detect_safetensors_format(state_dict)
+        if detection_type["kind"] in ['scaled_fp8','fp8']:
+            conv_result = convert_scaled_fp8_to_quanto(state_dict, scale_dtype = default_dtype)
+            state_dict = conv_result["state_dict"]
+            quantization_map = conv_result["quant_map"]
+            conv_result = None
+            # enable_fp8_fp32_scale_support()
+
     if quantization_map is None:
         pos = str.rfind(file, ".")
         if pos > 0:
@@ -1554,7 +1571,7 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
 
     del state_dict
 
-    if len(unexpected_keys) > 0 and verboseLevel >=2:
+    if len(unexpected_keys) > 0 and verboseLevel >=2 and not ignore_unused_weights:
         print(f"Unexpected keys while loading '{file_path}': {unexpected_keys}")
 
     for k,p in model.named_parameters():
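In 3.6.2, load_model_data calls detect_safetensors_format on the freshly loaded state dict and, when it reports a scaled-FP8 or plain FP8 checkpoint, runs convert_scaled_fp8_to_quanto with default_dtype as the scale dtype before the usual Quanto setup. The sketch below mirrors that call sequence on a tiny synthetic state dict (layer name, shapes, and scale value are made up for illustration; requires a PyTorch build with float8 dtypes):

import torch
from mmgp.fp8_quanto_bridge import detect_safetensors_format, convert_scaled_fp8_to_quanto

# Synthetic "scaled FP8" checkpoint: one FP8 weight, its per-tensor scale, and a plain bias.
sd = {
    "blocks.0.linear.weight": torch.randn(4, 8).to(torch.float8_e4m3fn),
    "blocks.0.linear.scale_weight": torch.tensor([0.02]),
    "blocks.0.linear.bias": torch.zeros(4),
}

print(detect_safetensors_format(sd))
# expected: {'kind': 'scaled_fp8', 'quant_format': '', 'fp8_format': 'e4m3fn'}

# Same call load_model_data now makes (default_dtype defaults to torch.bfloat16).
res = convert_scaled_fp8_to_quanto(sd, scale_dtype=torch.bfloat16)
print(sorted(res.state_dict.keys()))
# bias is copied through; the weight becomes '._data' + '._scale' plus input_scale/output_scale placeholders
print(res.quant_map)
# {'blocks.0.linear': {'weights': 'qfloat8_e4m3fn', 'activations': 'none'}}

With scale_dtype=torch.float32 instead, res.patch_needed would be True and enable_fp8_fp32_scale_support() (left commented out in the new offload.py hook) would be needed before Quanto reconstructs the weights.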
src/mmgp.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mmgp
-Version: 3.6.0
+Version: 3.6.2
 Summary: Memory Management for the GPU Poor
 Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
 Requires-Python: >=3.10
@@ -15,7 +15,7 @@ Dynamic: license-file
 
 
 <p align="center">
-<H2>Memory Management 3.6.0 for the GPU Poor by DeepBeepMeep</H2>
+<H2>Memory Management 3.6.2 for the GPU Poor by DeepBeepMeep</H2>
 </p>
 
 
Files without changes: LICENSE.md, setup.cfg, src/__init__.py, src/mmgp/__init__.py, src/mmgp/safetensors2.py, src/mmgp.egg-info/dependency_links.txt, src/mmgp.egg-info/requires.txt, src/mmgp.egg-info/top_level.txt