mmgp-3.5.7-py3-none-any.whl → mmgp-3.6.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mmgp/fp8_quanto_bridge.py +645 -0
- mmgp/fp8_quanto_bridge_old.py +498 -0
- mmgp/offload.py +1038 -248
- mmgp/quant_router.py +518 -0
- mmgp/quanto_int8_cuda.py +97 -0
- mmgp/quanto_int8_inject.py +335 -0
- mmgp/safetensors2.py +57 -10
- {mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/METADATA +2 -2
- mmgp-3.6.11.dist-info/RECORD +14 -0
- {mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/licenses/LICENSE.md +1 -1
- mmgp-3.5.7.dist-info/RECORD +0 -9
- {mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/WHEEL +0 -0
- {mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/top_level.txt +0 -0
mmgp/quanto_int8_inject.py
ADDED
@@ -0,0 +1,335 @@
+from __future__ import annotations
+
+import importlib
+import os
+from types import SimpleNamespace
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+# Env toggles
+_ENV_ENABLE = "WAN2GP_QUANTO_INT8_KERNEL"
+_ENV_DEBUG = "WAN2GP_QUANTO_INT8_DEBUG"
+_ENV_ALIGN_M = "WAN2GP_QUANTO_INT8_ALIGN_M"
+_ENV_ALIGN_N = "WAN2GP_QUANTO_INT8_ALIGN_N"
+_ENV_ALIGN_K = "WAN2GP_QUANTO_INT8_ALIGN_K"
+_ENV_USE_TC = "WAN2GP_QUANTO_INT8_TC"
+
+# Kernel namespace/entrypoints (resolved lazily)
+_KERNEL_OP = None
+_KERNEL_MODULES = (
+    "mmgp.quanto_int8_cuda",
+    "mmgp_quanto_int8_cuda",
+    "mmgp_quanto_int8",
+)
+
+
+def _env_flag(name: str, default: str = "1") -> bool:
+    val = os.environ.get(name, default)
+    return str(val).strip().lower() in ("1", "true", "yes", "on")
+
+
+def _env_int(name: str, default: int) -> int:
+    raw = os.environ.get(name)
+    if raw is None:
+        return default
+    try:
+        return int(raw)
+    except ValueError:
+        return default
+
+
+def _debug(msg: str) -> None:
+    if _env_flag(_ENV_DEBUG, "0"):
+        print(f"[WAN2GP][INT8][quanto] {msg}")
+
+
+def _get_alignments() -> Tuple[int, int, int]:
+    # Conservative defaults; keep in sync with kernel requirements.
+    align_m = _env_int(_ENV_ALIGN_M, 16)
+    align_n = _env_int(_ENV_ALIGN_N, 16)
+    align_k = _env_int(_ENV_ALIGN_K, 16)
+    return align_m, align_n, align_k
+
+
+def _is_qbytes_tensor(t: torch.Tensor) -> bool:
+    try:
+        from optimum.quanto.tensor.qbytes import QBytesTensor
+    except Exception:
+        return False
+    return isinstance(t, QBytesTensor)
+
+
+def _is_weight_qbytes(t: torch.Tensor) -> bool:
+    try:
+        from optimum.quanto.tensor.weights.qbytes import WeightQBytesTensor
+    except Exception:
+        return False
+    return isinstance(t, WeightQBytesTensor)
+
+
+def _flatten_scale(scale: torch.Tensor) -> torch.Tensor:
+    if scale.ndim == 2 and scale.shape[1] == 1:
+        return scale.view(-1)
+    if scale.ndim == 1:
+        return scale
+    return scale.reshape(-1)
+
+
+def _quantize_activation_per_row(x_2d: torch.Tensor, scale_dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Symmetric int8 quantization per-row: scale = max(abs(row)) / 127
+    x_fp32 = x_2d.to(torch.float32)
+    amax = x_fp32.abs().amax(dim=1)
+    qmax = 127.0
+    scale = amax / qmax
+    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
+    q = torch.round(x_fp32 / scale[:, None]).clamp(-qmax, qmax).to(torch.int8)
+    return q, scale.to(scale_dtype)
+
+
+def _pad_to_multiple(x: torch.Tensor, m_pad: int, k_pad: int) -> torch.Tensor:
+    if m_pad == 0 and k_pad == 0:
+        return x
+    # Pad last dim (K) and then rows (M)
+    if k_pad:
+        x = F.pad(x, (0, k_pad))
+    if m_pad:
+        x = F.pad(x, (0, 0, 0, m_pad))
+    return x
+
+
+def _pad_weights_to_multiple(w: torch.Tensor, n_pad: int, k_pad: int) -> torch.Tensor:
+    if n_pad == 0 and k_pad == 0:
+        return w
+    if k_pad:
+        w = F.pad(w, (0, k_pad))
+    if n_pad:
+        w = F.pad(w, (0, 0, 0, n_pad))
+    return w
+
+
+def _pad_scale_to_multiple(scale: torch.Tensor, pad: int, pad_value: float = 1.0) -> torch.Tensor:
+    if pad == 0:
+        return scale
+    return F.pad(scale, (0, pad), value=pad_value)
+
+
+def _resolve_kernel_op():
+    global _KERNEL_OP
+    if _KERNEL_OP is not None:
+        return _KERNEL_OP
+    # torch.ops path (preferred)
+    ops_ns = getattr(torch.ops, "mmgp_quanto_int8", None)
+    if ops_ns is not None and hasattr(ops_ns, "int8_scaled_mm"):
+        _KERNEL_OP = ops_ns.int8_scaled_mm
+        return _KERNEL_OP
+    # import module path
+    for mod_name in _KERNEL_MODULES:
+        try:
+            mod = importlib.import_module(mod_name)
+        except Exception:
+            continue
+        if hasattr(mod, "int8_scaled_mm"):
+            _KERNEL_OP = mod.int8_scaled_mm
+            return _KERNEL_OP
+    ops_ns = getattr(torch.ops, "mmgp_quanto_int8", None)
+    if ops_ns is not None and hasattr(ops_ns, "int8_scaled_mm"):
+        _KERNEL_OP = ops_ns.int8_scaled_mm
+        return _KERNEL_OP
+    raise RuntimeError(
+        "mmgp int8 kernel extension not loaded. Expected torch.ops.mmgp_quanto_int8.int8_scaled_mm "
+        "or a module exposing int8_scaled_mm (mmgp.quanto_int8_cuda)."
+    )
+
+
+def _quantize_with_kernel(x_2d: torch.Tensor, scale_dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
+    try:
+        mod = importlib.import_module("mmgp.quanto_int8_cuda")
+        q, s = mod.quantize_per_row_int8(x_2d)
+        if s.dtype != scale_dtype:
+            s = s.to(scale_dtype)
+        return q, s
+    except Exception:
+        return _quantize_activation_per_row(x_2d, scale_dtype)
+
+
+def _maybe_get_transposed_weight(other: torch.Tensor) -> torch.Tensor:
+    cached = getattr(other, "_mmgp_int8_t", None)
+    if isinstance(cached, torch.Tensor) and cached.device == other._data.device:
+        if cached.shape == (other._data.shape[1], other._data.shape[0]):
+            return cached
+    w_t = other._data.t().contiguous()
+    try:
+        setattr(other, "_mmgp_int8_t", w_t)
+    except Exception:
+        pass
+    return w_t
+
+
+def _int8_tc_mm(
+    a_int8: torch.Tensor,
+    b_int8: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_scale: torch.Tensor,
+    b_int8_t: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if a_int8.dtype != torch.int8 or b_int8.dtype != torch.int8:
+        raise RuntimeError("int8 TC path requires int8 tensors")
+    a_int8 = a_int8.contiguous()
+    if b_int8_t is None:
+        b_int8_t = b_int8.t().contiguous()
+
+    # torch._int_mm expects [M, K] @ [K, N]
+    acc = torch._int_mm(a_int8, b_int8_t)
+    try:
+        mod = importlib.import_module("mmgp.quanto_int8_cuda")
+        return mod.scale_int32_to(acc, a_scale, b_scale)
+    except Exception:
+        # Fallback to torch ops if the scaling kernel isn't available
+        return (acc.float() * a_scale[:, None] * b_scale[None, :]).to(a_scale.dtype)
+
+
+def _int8_scaled_mm(
+    a_int8: torch.Tensor,
+    b_int8: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_scale: torch.Tensor,
+    b_int8_t: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    # a_int8: [M, K], b_int8: [N, K], scales: a=[M], b=[N]
+    if not a_int8.is_cuda or not b_int8.is_cuda:
+        raise RuntimeError("int8 kernel requires CUDA tensors")
+    if a_int8.dtype != torch.int8 or b_int8.dtype != torch.int8:
+        raise RuntimeError("int8 kernel requires int8 activations and weights")
+
+    a_int8 = a_int8.contiguous()
+    b_int8 = b_int8.contiguous()
+    a_scale = _flatten_scale(a_scale).contiguous()
+    b_scale = _flatten_scale(b_scale).contiguous()
+
+    m, k = a_int8.shape
+    n = b_int8.shape[0]
+    use_tc = _env_flag(_ENV_USE_TC, "1")
+    if use_tc:
+        # torch._int_mm requires M > 16 and M/N/K multiples of 8
+        if m <= 16:
+            m_pad = 24 - m
+        else:
+            m_pad = (8 - (m % 8)) % 8
+        n_pad = (8 - (n % 8)) % 8
+        k_pad = (8 - (k % 8)) % 8
+    else:
+        align_m, align_n, align_k = _get_alignments()
+        m_pad = (align_m - (m % align_m)) % align_m
+        n_pad = (align_n - (n % align_n)) % align_n
+        k_pad = (align_k - (k % align_k)) % align_k
+
+    if m_pad or n_pad or k_pad:
+        a_int8 = _pad_to_multiple(a_int8, m_pad=m_pad, k_pad=k_pad)
+        b_int8 = _pad_weights_to_multiple(b_int8, n_pad=n_pad, k_pad=k_pad)
+        a_scale = _pad_scale_to_multiple(a_scale, pad=m_pad, pad_value=1.0)
+        b_scale = _pad_scale_to_multiple(b_scale, pad=n_pad, pad_value=1.0)
+        if b_int8_t is not None and b_int8_t.shape != (b_int8.shape[1], b_int8.shape[0]):
+            b_int8_t = None
+
+    if use_tc:
+        out = _int8_tc_mm(a_int8, b_int8, a_scale, b_scale, b_int8_t=b_int8_t)
+    else:
+        op = _resolve_kernel_op()
+        out = op(a_int8, b_int8, a_scale, b_scale)
+    if m_pad or n_pad:
+        out = out[:m, :n]
+    return out
+
+
+def _use_int8_kernel(input: torch.Tensor, other: torch.Tensor) -> bool:
+    if not _is_weight_qbytes(other):
+        return False
+    if other._data.dtype != torch.int8:
+        return False
+    if not other._data.is_cuda:
+        return False
+    if _is_qbytes_tensor(input):
+        return input._data.dtype == torch.int8 and input._data.is_cuda
+    return input.is_cuda and input.dtype in (torch.bfloat16, torch.float16, torch.float32)
+
+
+def _int8_linear_forward(ctx, input: torch.Tensor, other: torch.Tensor, bias: Optional[torch.Tensor]):
+    ctx.save_for_backward(input, other)
+
+    input_shape = input.shape
+    in_features = input_shape[-1]
+    out_features = other.shape[0]
+
+    # Prepare activations
+    if _is_qbytes_tensor(input):
+        a_2d = input._data.reshape(-1, in_features)
+        a_scale = input._scale
+        a_scale = _flatten_scale(a_scale).to(other._scale.dtype)
+        if a_scale.numel() == 1:
+            a_scale = a_scale.reshape(1).expand(a_2d.shape[0]).contiguous()
+        elif a_scale.numel() != a_2d.shape[0]:
+            raise RuntimeError("Activation scale length does not match token count")
+        a_int8 = a_2d
+    else:
+        a_2d = input.reshape(-1, in_features)
+        a_int8, a_scale = _quantize_with_kernel(a_2d, other._scale.dtype)
+
+    # Per-output scale is handled inside the kernel: out = sum(a*b) * a_scale[row] * b_scale[col]
+    b_scale = _flatten_scale(other._scale).to(other._scale.dtype)
+    if b_scale.numel() != out_features:
+        raise RuntimeError("Weight scale length does not match output features")
+
+    b_int8_t = _maybe_get_transposed_weight(other)
+    out_2d = _int8_scaled_mm(a_int8, other._data, a_scale, b_scale, b_int8_t=b_int8_t)
+    out = out_2d.reshape(input_shape[:-1] + (out_features,))
+
+    if bias is not None:
+        out = out + bias
+    return out
+
+
+_PATCH_STATE = SimpleNamespace(enabled=False, orig_forward=None)
+
+
+def enable_quanto_int8_kernel() -> bool:
+    if _PATCH_STATE.enabled:
+        return True
+    try:
+        from optimum.quanto.tensor.weights import qbytes as _qbytes
+    except Exception:
+        return False
+
+    orig_forward = _qbytes.WeightQBytesLinearFunction.forward
+
+    def forward(ctx, input, other, bias=None):
+        if _use_int8_kernel(input, other):
+            _debug("using mmgp int8 kernel")
+            return _int8_linear_forward(ctx, input, other, bias)
+        return orig_forward(ctx, input, other, bias)
+
+    _qbytes.WeightQBytesLinearFunction.forward = staticmethod(forward)
+    _PATCH_STATE.enabled = True
+    _PATCH_STATE.orig_forward = orig_forward
+    return True
+
+
+def disable_quanto_int8_kernel() -> bool:
+    if not _PATCH_STATE.enabled:
+        return False
+    from optimum.quanto.tensor.weights import qbytes as _qbytes
+    _qbytes.WeightQBytesLinearFunction.forward = staticmethod(_PATCH_STATE.orig_forward)
+    _PATCH_STATE.enabled = False
+    _PATCH_STATE.orig_forward = None
+    return True
+
+
+def maybe_enable_quanto_int8_kernel() -> bool:
+    if not _env_flag(_ENV_ENABLE, "1"):
+        return False
+    return enable_quanto_int8_kernel()
+
+
+# Auto-enable on import (default on, can be disabled via env)
+maybe_enable_quanto_int8_kernel()
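Note: the module above monkey-patches optimum-quanto's WeightQBytesLinearFunction.forward so that int8-quantized linears route through torch._int_mm (or the bundled CUDA kernel) with per-row activation scales and per-output-channel weight scales. A minimal usage sketch, using only the env toggles and functions defined above; the surrounding model code is assumed and omitted:

    import os

    # Opt out of the auto-enable on import, then patch explicitly.
    os.environ["WAN2GP_QUANTO_INT8_KERNEL"] = "0"
    os.environ["WAN2GP_QUANTO_INT8_DEBUG"] = "1"   # log when the int8 path is taken

    from mmgp import quanto_int8_inject

    quanto_int8_inject.enable_quanto_int8_kernel()   # patch WeightQBytesLinearFunction.forward
    # ... run an optimum-quanto int8 model on CUDA here (not shown) ...
    quanto_int8_inject.disable_quanto_int8_kernel()  # restore the original forward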
mmgp/safetensors2.py
CHANGED
@@ -46,7 +46,16 @@ class MmapTracker:
         file_path = os.path.join(*s)
         self.file_path = file_path # os.path.abspath(file_path)
         self.count = 0
-
+        key = file_path
+        i = 1
+        while True:
+            if key not in mmm:
+                mmm[key] = self
+                break
+            i +=1
+            key = key + "#" + str(i)
+        self.mmm_key = key
+        # print(f"MMAP Add: {file_path}: {mmm.keys()}")
 
     def register(self, mmap_obj, map_id, start, size):
 
@@ -61,7 +70,8 @@ class MmapTracker:
 
         print(f"MMap Manager of file '{self.file_path}' : MMap no {map_id} has been released" + text)
         if self.count == self._already_released:
-
+            # print(f"MMAP Del: {self.file_path}: {mmm.keys()}")
+            del mmm[self.mmm_key ]
 
         self._maps.pop(map_id, None)
 
@@ -77,7 +87,7 @@ class MmapTracker:
     def get_active_maps(self):
         return dict(self._maps)
 
-class tensor_slice:
+class tensor_slice:
    catalog = None
    value = None
    name = None
@@ -93,9 +103,33 @@ class tensor_slice:
    def get_dtype(self):
        return self.catalog[self.name]["dtype"]
 
-    def get_shape(self):
-        return self.catalog[self.name]["shape"]
-
+    def get_shape(self):
+        return self.catalog[self.name]["shape"]
+
+class tensor_stub:
+    dtype = None
+    shape = None
+
+    def __init__(self, dtype, shape):
+        self.dtype = dtype
+        self.shape = tuple(shape)
+
+    @property
+    def ndim(self):
+        return len(self.shape)
+
+    def numel(self):
+        if not self.shape:
+            return 1
+        n = 1
+        for dim in self.shape:
+            n *= int(dim)
+        return n
+
+    @property
+    def device(self):
+        return torch.device("cpu")
+
 class cached_metadata:
    file_path = None
    file_length = 0
@@ -139,7 +173,7 @@ def _parse_metadata(metadata):
    new_metadata["format"] = "pt"
    return new_metadata
 
-def _read_safetensors_header(path, file):
+def _read_safetensors_header(path, file):
    global _cached_entry
    length_of_header_bytes = file.read(8)
    # Interpret the bytes as a little-endian unsigned 64-bit integer
@@ -161,7 +195,20 @@ def _read_safetensors_header(path, file):
    else:
        file.seek(length_of_header, 1)
 
-    return catalog, metadata, length_of_header + 8
+    return catalog, metadata, length_of_header + 8
+
+
+def load_metadata_state_dict(file_path):
+    with open(file_path, 'rb') as f:
+        catalog, metadata, _ = _read_safetensors_header(file_path, f)
+    sd = OrderedDict()
+    for k, v in catalog.items():
+        dtypestr = v["dtype"]
+        dtype = _map_to_dtype.get(dtypestr)
+        if dtype is None:
+            raise KeyError(f"Unknown safetensors dtype '{dtypestr}' in {file_path}")
+        sd[k] = tensor_stub(dtype, v["shape"])
+    return sd, metadata
 
 
 def torch_write_file(sd, file_path, quantization_map = None, config = None, extra_meta = None):
@@ -240,7 +287,7 @@ def torch_write_file(sd, file_path, quantization_map = None, config = None, extra_meta = None):
            t = t.view(torch.uint16)
        elif dtype == torch.float8_e5m2 or dtype == torch.float8_e4m3fn:
            t = t.view(torch.uint8)
-        buffer = t.numpy().tobytes()
+        buffer = t.cpu().numpy().tobytes()
        bytes_written = writer.write(buffer)
        assert bytes_written == size
        i+=1
@@ -488,4 +535,4 @@ try:
    transformers.modeling_utils.safe_open = safe_open
    transformers.modeling_utils.safe_load_file = torch_load_file
 except:
-    pass
+    pass
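The new load_metadata_state_dict helper parses only the safetensors header and wraps each entry in a tensor_stub (dtype, shape, a CPU device placeholder, no tensor data), so a checkpoint can be inspected cheaply before deciding how to offload or quantize it. A small sketch, assuming a local checkpoint file; the path is a placeholder:

    from mmgp import safetensors2

    # "model.safetensors" is a hypothetical local file
    sd, metadata = safetensors2.load_metadata_state_dict("model.safetensors")
    for name, stub in sd.items():
        # stubs carry metadata only: no tensor payload is read or mapped
        print(name, stub.dtype, tuple(stub.shape), stub.numel())
    print(metadata)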
{mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mmgp
-Version: 3.5.7
+Version: 3.6.11
 Summary: Memory Management for the GPU Poor
 Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
 Requires-Python: >=3.10
@@ -15,7 +15,7 @@ Dynamic: license-file
 
 
 <p align="center">
-<H2>Memory Management 3.5.7 for the GPU Poor by DeepBeepMeep</H2>
+<H2>Memory Management 3.6.11 for the GPU Poor by DeepBeepMeep</H2>
 </p>
 
 
mmgp-3.6.11.dist-info/RECORD
ADDED
@@ -0,0 +1,14 @@
+__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
+mmgp/fp8_quanto_bridge.py,sha256=vTKnzWKe88MgMl2z0gzcu7EGTVHM7c2Wg95dmqMU3vM,26356
+mmgp/fp8_quanto_bridge_old.py,sha256=VtUaD6wzo7Yn9vGY0LMtbhwt6KMWRpSWLc65bU_sfZU,21155
+mmgp/offload.py,sha256=emZtKU6Fq_3d39EIplkJ6c4pd5PNIPwLqsD74qOixkk,160548
+mmgp/quant_router.py,sha256=WGh9C0eRV4EPC1eAEbUjrfxwfieqC6lgMnGNYdeUEjg,16515
+mmgp/quanto_int8_cuda.py,sha256=LD5pTtM-bNgKseI1wnHp8JxeUl2Q4uSL8TbFYT1Jg5s,3258
+mmgp/quanto_int8_inject.py,sha256=MyegEMZvUmc2iOmaRC7_zTlbvqlfKaN0xDoOGTBkAzY,11038
+mmgp/safetensors2.py,sha256=vPaM8rGjrJosj_5WAYe9Xgr2_oGKmeB-8bLOANhA2aQ,19935
+mmgp-3.6.11.dist-info/licenses/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
+mmgp-3.6.11.dist-info/METADATA,sha256=wwOcRwAyxkjiXvhWvgpdVGBWp7QcLJQQ7Zl_eMzCyig,16311
+mmgp-3.6.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+mmgp-3.6.11.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
+mmgp-3.6.11.dist-info/RECORD,,
{mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/licenses/LICENSE.md
CHANGED
@@ -1,2 +1,2 @@
-GNU GENERAL PUBLIC LICENSE
+GNU GENERAL PUBLIC LICENSE
 Version 3, 29 June 2007
mmgp-3.5.7.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
-mmgp/offload.py,sha256=SKt-EunQrH6omBFI7aNLe82GIoXBKW9y1i0HMPFrKLY,127089
-mmgp/safetensors2.py,sha256=4nKV13qCMabnNEB1TA_ueFbfGYYmiQ9racR_C6SsGug,18693
-mmgp-3.5.7.dist-info/licenses/LICENSE.md,sha256=DD-WIS0BkPoWJ_8hQO3J8hMP9K_1-dyrYv1YCbkxcDU,94
-mmgp-3.5.7.dist-info/METADATA,sha256=s420bK-WQuSZM2RpVwYjzXY-QmtIHkRbIiL9hAyV7sA,16309
-mmgp-3.5.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-mmgp-3.5.7.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
-mmgp-3.5.7.dist-info/RECORD,,
{mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/WHEEL
File without changes

{mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/top_level.txt
File without changes