mmgp 2.0.4__py3-none-any.whl → 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mmgp might be problematic. Click here for more details.
- mmgp/__init__.py +22 -0
- mmgp/offload.py +1470 -0
- mmgp/safetensors2.py +394 -0
- {mmgp-2.0.4.dist-info → mmgp-3.0.1.dist-info}/LICENSE.md +1 -1
- {mmgp-2.0.4.dist-info → mmgp-3.0.1.dist-info}/METADATA +157 -137
- mmgp-3.0.1.dist-info/RECORD +9 -0
- mmgp-2.0.4.dist-info/RECORD +0 -7
- mmgp.py +0 -951
- {mmgp-2.0.4.dist-info → mmgp-3.0.1.dist-info}/WHEEL +0 -0
- {mmgp-2.0.4.dist-info → mmgp-3.0.1.dist-info}/top_level.txt +0 -0
mmgp/safetensors2.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
# ------------------ Safetensors2 1.0 by DeepBeepMeep (mmgp)------------------
|
|
2
|
+
#
|
|
3
|
+
# This module, written entirely in Python, is a replacement for the safetensors library that requires much less RAM to load models.
|
|
4
|
+
# It can conveniently be used to keep RAM consumption low when handling data in transit (for instance when quantizing or transferring tensors to reserved RAM)
|
|
5
|
+
# You are free to use my module for non-commercial use as long as you give me proper credit. You may contact me on Twitter @deepbeepmeep
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
from typing import Optional, Dict, List, Iterator, Tuple
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
import torch
|
|
11
|
+
import mmap
|
|
12
|
+
import struct
|
|
13
|
+
import json
|
|
14
|
+
import base64
|
|
15
|
+
import safetensors
|
|
16
|
+
import accelerate
|
|
17
|
+
import os
|
|
18
|
+
from collections import OrderedDict
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Original safetensors entry points. They are assigned at the bottom of this
# module, just before being replaced by the implementations defined here.
_old_torch_load_file = None
_old_safe_open = None



# Registry of MmapTracker instances, keyed by the tracker's shortened file path.
mmm = {}
# Verbosity of this module's diagnostics (MmapTracker prints a message for
# every released map when this is >= 2).
verboseLevel = 1
|
|
28
|
+
|
|
29
|
+
import weakref
|
|
30
|
+
|
|
31
|
+
_map_to_dtype = { 'BF16': torch.bfloat16, 'U8': torch.uint8 , 'U16': torch.uint16, 'U32' : torch.uint32 , 'U64' : torch.uint64,
|
|
32
|
+
'I8': torch.int8, 'I16': torch.int16, 'I32' : torch.int32 , 'I64' : torch.int64,
|
|
33
|
+
'F64' : torch.float64, 'F32': torch.float32, 'F16': torch.float16, 'BOOL' : torch.bool, "F8_E5M2" : torch.float8_e5m2, "F8_E4M3" : torch.float8_e4m3fn }
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class MmapTracker:
    """Keep track of the memory maps opened on one safetensors file.

    Each tracker registers itself in the module-level ``mmm`` registry under a
    shortened path (at most the last two path components) and removes itself
    again once every map registered with it has been garbage collected.
    """

    def __init__(self, file_path):
        self._maps = {}
        self._already_released = 0
        # Only the last two components are kept: the result is used as a
        # short display name and as the registry key.
        parts = Path(file_path).parts
        file_path = os.path.join(*parts[-2:]) if parts else str(file_path)
        self.file_path = file_path  # os.path.abspath(file_path)
        self.count = 0
        mmm[file_path] = self

    def register(self, mmap_obj, map_id, start, size):
        """Start tracking ``mmap_obj`` and return a weak reference to it.

        Args:
            mmap_obj: the mmap object to watch (it must support weak references).
            map_id: identifier of the map within this file.
            start: file offset at which the map begins.
            size: length of the map in bytes.
        """
        self.count += 1

        def finalizer(ref):
            # Runs when the mmap object is garbage collected.
            self._already_released += 1
            remaining = self.count - self._already_released
            if verboseLevel >= 2:
                if remaining == 0:
                    text = " (all the mmaps have been released)"
                else:
                    text = f" ({remaining} left)"

                print(f"MMap Manager of file '{self.file_path}' : MMap no {map_id} has been released" + text)
            # Only drop the registry entry if it is still ours: a newer tracker
            # may have been registered under the same shortened path, or the
            # entry may already have been removed. An unconditional ``del``
            # would raise KeyError inside this weakref callback.
            if remaining == 0 and mmm.get(self.file_path) is self:
                del mmm[self.file_path]

            self._maps.pop(map_id, None)

        wr = weakref.ref(mmap_obj, finalizer)
        self._maps[map_id] = {
            'mmap': wr,
            'start': start,
            'size': size,
            'end': start + size,
        }
        return wr

    def get_active_maps(self):
        """Return a shallow copy of the maps that have not been released yet."""
        return dict(self._maps)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class cached_metadata:
    """Parsed header of one safetensors file plus the file stats used to validate it."""

    # Class-level defaults; __init__ overwrites every one of them.
    file_path = None
    file_length = 0
    file_date = None
    catalog = None
    metadata = None
    skip_bytes = 0

    def __init__(self, file_path, catalog, metadata, skip_bytes):
        """Store the parsed header together with the file's current size and ctime."""
        self.catalog = catalog
        self.metadata = metadata
        self.skip_bytes = skip_bytes
        stats = os.stat(file_path)
        self.file_path = os.path.abspath(file_path)
        self.file_length = stats.st_size
        self.file_date = stats.st_ctime

    def get_metadata(self, file_path):
        """Return ``(catalog, metadata, skip_bytes)`` when ``file_path`` is the cached file.

        The entry is considered valid only if the absolute path, the size and
        the ctime all match; otherwise ``(None, None, None)`` is returned.
        """
        stats = os.stat(file_path)
        requested = (os.path.abspath(file_path), stats.st_size, stats.st_ctime)
        cached = (self.file_path, self.file_length, self.file_date)
        if requested == cached:
            return self.catalog, self.metadata, self.skip_bytes
        return None, None, None
|
|
104
|
+
|
|
105
|
+
# Most recently parsed header (a cached_metadata instance, or None before the first read).
# Ideally we should keep a dict of the last n entries, but one entry covers most cases.
_cached_entry = None
|
|
106
|
+
|
|
107
|
+
def _parse_metadata(metadata):
|
|
108
|
+
if metadata == None:
|
|
109
|
+
return None
|
|
110
|
+
|
|
111
|
+
new_metadata= {}
|
|
112
|
+
|
|
113
|
+
for k,v in metadata.items():
|
|
114
|
+
if k.endswith("_base64"):
|
|
115
|
+
v_decoded = json.loads(base64.b64decode(v.encode('utf8')).decode('utf8'))
|
|
116
|
+
p = k.rfind("_")
|
|
117
|
+
new_k = k[:p]
|
|
118
|
+
new_metadata[new_k]= v_decoded
|
|
119
|
+
else:
|
|
120
|
+
new_metadata[k] = v
|
|
121
|
+
|
|
122
|
+
return new_metadata
|
|
123
|
+
|
|
124
|
+
def _read_safetensors_header(path, file):
    """Read the header of a safetensors file, reusing the cached copy when valid.

    Args:
        path: path of the file, used to validate the cached entry.
        file: binary file object positioned at the start of the file. On
            return it is positioned just past the header.

    Returns:
        ``(catalog, metadata, skip_bytes)`` where ``catalog`` maps tensor names
        to their header entries, ``metadata`` is the decoded ``__metadata__``
        section (or None) and ``skip_bytes`` is the header length plus the
        8-byte length prefix.
    """
    global _cached_entry
    # The first 8 bytes hold the header length as a little-endian unsigned 64-bit integer.
    length_of_header = struct.unpack('<Q', file.read(8))[0]

    catalog = metadata = None
    if _cached_entry is not None:
        catalog, metadata, _ = _cached_entry.get_metadata(path)

    if catalog is None:
        # Cache miss: parse the JSON header and remember it for the next call.
        catalog = json.loads(file.read(length_of_header))
        metadata = _parse_metadata(catalog.pop("__metadata__", None))
        _cached_entry = cached_metadata(path, catalog, metadata, length_of_header)
    else:
        # Cache hit: skip over the header without parsing it again.
        file.seek(length_of_header, 1)

    return catalog, metadata, length_of_header + 8
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def torch_write_file(sd, file_path, quantization_map = None, config = None):
|
|
150
|
+
from collections import OrderedDict
|
|
151
|
+
sf_sd = OrderedDict()
|
|
152
|
+
|
|
153
|
+
map = { torch.bfloat16 : 'BF16' , torch.int64 : 'I64' , torch.int32 : 'I32' , torch.int16 : 'I16' , torch.int8 : 'I8' ,
|
|
154
|
+
torch.uint64 : 'U64' , torch.uint32 : 'U32' , torch.uint16 : 'U16' , torch.uint8 : 'U8' ,
|
|
155
|
+
torch.bool : 'BOOL' , torch.float64 : 'F64' , torch.float32 : 'F32' , torch.float16 : 'F16', torch.float8_e5m2 : "F8_E5M2", torch.float8_e4m3fn: "F8_E4M3" }
|
|
156
|
+
pos = 0
|
|
157
|
+
i = 0
|
|
158
|
+
mx = 1000000
|
|
159
|
+
for k , t in sd.items():
|
|
160
|
+
entry = {}
|
|
161
|
+
dtypestr= map[t.dtype]
|
|
162
|
+
entry["dtype"] = dtypestr
|
|
163
|
+
entry["shape"] = list(t.shape)
|
|
164
|
+
size = torch.numel(t) * t.element_size()
|
|
165
|
+
entry["data_offsets"] = [pos, pos + size]
|
|
166
|
+
pos += size
|
|
167
|
+
sf_sd[k] = entry
|
|
168
|
+
i+=1
|
|
169
|
+
if i==mx:
|
|
170
|
+
break
|
|
171
|
+
metadata = dict()
|
|
172
|
+
if not quantization_map is None:
|
|
173
|
+
metadata["quantization_format"] = "quanto"
|
|
174
|
+
metadata["quantization_map_base64"] = base64.b64encode(json.dumps(quantization_map, ensure_ascii=False).encode('utf8')).decode('utf8')
|
|
175
|
+
|
|
176
|
+
if not config is None:
|
|
177
|
+
metadata["config_base64"] = base64.b64encode(json.dumps(config, ensure_ascii=False).encode('utf8')).decode('utf8')
|
|
178
|
+
|
|
179
|
+
if len(metadata) > 0:
|
|
180
|
+
sf_sd["__metadata__"] = metadata
|
|
181
|
+
|
|
182
|
+
header_bytes = json.dumps(sf_sd).encode()
|
|
183
|
+
#header_bytes =json.dumps(config, ensure_ascii=False).encode('utf8')
|
|
184
|
+
size_header = len(header_bytes)
|
|
185
|
+
import struct
|
|
186
|
+
|
|
187
|
+
length_of_header_bytes = struct.pack('<Q', size_header)
|
|
188
|
+
|
|
189
|
+
empty_tensor = b'\x80\x3f'
|
|
190
|
+
|
|
191
|
+
with open(file_path, "wb") as writer:
|
|
192
|
+
bytes_written = writer.write(length_of_header_bytes)
|
|
193
|
+
bytes_written = writer.write(header_bytes)
|
|
194
|
+
|
|
195
|
+
i = 0
|
|
196
|
+
for k , t in sd.items():
|
|
197
|
+
size = torch.numel(t) * t.element_size()
|
|
198
|
+
if len(t.shape) == 0:
|
|
199
|
+
bytes_written = writer.write(empty_tensor)
|
|
200
|
+
else:
|
|
201
|
+
buffer = t.view(torch.uint8).numpy().tobytes()
|
|
202
|
+
bytes_written = writer.write(buffer)
|
|
203
|
+
assert bytes_written == size
|
|
204
|
+
i+=1
|
|
205
|
+
if i==mx:
|
|
206
|
+
break
|
|
207
|
+
|
|
208
|
+
class SafeTensorFile:
    """Main class for accessing safetensors files that provides memory-efficient access.

    Tensors are not copied into RAM: they are views over copy-on-write memory
    maps of the file, created lazily the first time a tensor is requested.
    """

    def __init__(self, file_path, metadata, catalog, skip_bytes):
        """Store the parsed header; no tensor data is mapped yet.

        Args:
            file_path: path of the safetensors file.
            metadata: decoded ``__metadata__`` section, or None.
            catalog: dict mapping tensor names to their header entries
                (``dtype``, ``shape``, ``data_offsets``).
            skip_bytes: number of bytes preceding the tensor data in the file.
        """
        self._file_path = file_path
        self._metadata = metadata
        self._catalog = catalog
        self._skip_bytes = skip_bytes
        self._keys = None
        self.sd = None
        self.mtracker = None

    @classmethod
    def load_metadata(cls, file_path):
        """Alternate constructor: read the header of ``file_path`` and wrap it."""
        with open(file_path, 'rb') as f:
            catalog, metadata, skip_bytes = _read_safetensors_header(file_path, f)

        return cls(file_path, metadata, catalog, skip_bytes)

    def init_tensors(self):
        """Create the tensors on first use and return the name -> tensor dict."""
        if self.sd is None:
            self.sd = self.create_tensors()
        return self.sd

    def create_tensors(self):
        """Memory-map the file and build one tensor view per catalog entry.

        The data region is covered by page-aligned maps of roughly 1GB each
        (a single tensor larger than that gets a map of its own size). Tensor
        positions are taken from each entry's ``data_offsets`` rather than
        from the catalog's iteration order.
        """
        self.mtracker = MmapTracker(self._file_path)

        page_size = mmap.ALLOCATIONGRANULARITY
        max_map_size = 1024 * 1024 * 1024  # 1GB
        skip_bytes = self._skip_bytes

        # First pass: walk the tensors in file order and group them into maps
        # whose start is aligned on the allocation granularity.
        map_spans = []   # [start, end) absolute file positions of each map
        placements = {}  # tensor name -> (map index, absolute start, byte length)
        in_file_order = sorted(self._catalog.items(), key=lambda item: item[1]["data_offsets"][0])
        for name, info in in_file_order:
            begin, end = info["data_offsets"]
            length = end - begin
            if length == 0:
                # Nothing to map for an empty tensor.
                continue
            abs_begin = skip_bytes + begin
            abs_end = skip_bytes + end
            if not map_spans or abs_end - map_spans[-1][0] > max_map_size:
                map_spans.append([(abs_begin // page_size) * page_size, abs_end])
            elif abs_end > map_spans[-1][1]:
                map_spans[-1][1] = abs_end
            placements[name] = (len(map_spans) - 1, abs_begin, length)

        # Second pass: create the maps and the tensors that view into them.
        maps = []
        sd = OrderedDict()
        with open(self._file_path, 'rb') as f:
            for map_id, (map_start, map_end) in enumerate(map_spans):
                map_size = map_end - map_start
                # ACCESS_COPY gives a writable private mapping, so the file
                # itself is never modified through the tensors.
                mm = mmap.mmap(f.fileno(), map_size, offset=map_start, access=mmap.ACCESS_COPY)
                maps.append((mm, map_start))
                self.mtracker.register(mm, map_id, map_start, map_size)

            for name, info in self._catalog.items():
                dtype = _map_to_dtype[info["dtype"]]
                shape = info["shape"]
                placement = placements.get(name)
                if placement is None:
                    # Zero-length data: torch.frombuffer rejects empty buffers.
                    t = torch.empty(shape, dtype=dtype, device="cpu")
                else:
                    map_idx, abs_begin, length = placement
                    mm, map_start = maps[map_idx]
                    offset = abs_begin - map_start
                    mv = memoryview(mm)[offset:offset + length]
                    # Scalars (shape []) go through the same path so that the
                    # stored value is read rather than replaced by a constant.
                    t = torch.reshape(torch.frombuffer(mv, dtype=dtype), shape)
                sd[name] = t

        return sd

    def get_tensor(self, name: str) -> torch.Tensor:
        """Get a tensor by name"""
        self.init_tensors()
        return self.sd[name]

    def keys(self) -> List[str]:
        """Get list of tensor names"""
        if self._keys is None:
            self._keys = list(self._catalog)
        return self._keys

    def names(self) -> List[str]:
        """Alias for keys()"""
        return self.keys()

    def tensors(self) -> Dict[str, torch.Tensor]:
        """Get dictionary of all tensors"""
        self.init_tensors()
        return self.sd

    def metadata(self) -> Optional[Dict[str, str]]:
        """Get metadata dictionary"""
        return self._metadata

    def __len__(self) -> int:
        """Get number of tensors (read from the header, nothing gets mapped)"""
        return len(self.keys())

    def __contains__(self, key: str) -> bool:
        """Check if tensor exists"""
        return key in self.keys()

    def __iter__(self) -> Iterator[Tuple[str, torch.Tensor]]:
        """Iterate over (name, tensor) pairs"""
        return ((name, self.get_tensor(name)) for name in self.keys())

    def _free_resources(self):
        """Drop the references to the tensors and the catalog; safe to call twice."""
        self.sd = None
        self._catalog = None
|
|
336
|
+
|
|
337
|
+
class _SafeTensorLoader:
|
|
338
|
+
"""Context manager for loading SafeTensorFile"""
|
|
339
|
+
|
|
340
|
+
def __init__(self, filename: str):
|
|
341
|
+
self.filename = Path(filename)
|
|
342
|
+
self.sft = None
|
|
343
|
+
|
|
344
|
+
if not self.filename.exists():
|
|
345
|
+
raise FileNotFoundError(f"File not found: {filename}")
|
|
346
|
+
|
|
347
|
+
def __enter__(self) -> SafeTensorFile:
|
|
348
|
+
"""Open file and return SafeTensorFile instance"""
|
|
349
|
+
|
|
350
|
+
try:
|
|
351
|
+
self.sft = SafeTensorFile.load_metadata(self.filename)
|
|
352
|
+
return self.sft
|
|
353
|
+
|
|
354
|
+
except Exception as e:
|
|
355
|
+
self.close()
|
|
356
|
+
raise Exception(f"Failed to load safetensors file: {e}") from e
|
|
357
|
+
|
|
358
|
+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
|
359
|
+
"""Clean up resources"""
|
|
360
|
+
self.close()
|
|
361
|
+
|
|
362
|
+
def close(self) -> None:
|
|
363
|
+
if self.sft != None:
|
|
364
|
+
self.sft._free_resources()
|
|
365
|
+
pass
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def safe_open(filename: str, framework: str = "pt", device = "cpu") -> "_SafeTensorLoader":
    """Drop-in replacement for ``safetensors.safe_open``.

    Only the PyTorch-on-CPU case is handled by this module; any other
    framework or device is delegated to the original ``safetensors.safe_open``.

    Args:
        filename: path of the safetensors file.
        framework: framework tag; only ``"pt"`` is handled here.
        device: target device, as a string or an object whose ``str()`` is
            the device name (e.g. ``torch.device("cpu")``).

    Returns:
        A context manager yielding a SafeTensorFile, or the original
        library's handle when the call is delegated.
    """
    # str() so that a torch.device("cpu") is recognised as CPU as well.
    if str(device) != "cpu" or framework != "pt":
        return _old_safe_open(filename=filename, framework=framework, device=device)
    return _SafeTensorLoader(filename)
|
|
373
|
+
|
|
374
|
+
def torch_load_file( filename, device = 'cpu' ) -> Dict[str, torch.Tensor]:
    """Load every tensor of a safetensors file into a plain dict.

    The file is opened through this module's ``safe_open`` for the ``"pt"``
    framework and the given ``device``; tensors are fetched one by one in the
    order reported by the handle's ``keys()``.
    """
    with safe_open(filename, framework="pt", device = device ) as handle:
        loaded = {name: handle.get_tensor(name) for name in handle.keys()}
    return loaded
|
|
380
|
+
|
|
381
|
+
# Install this module's loaders in place of the safetensors ones, keeping the
# originals so that safe_open can delegate the cases it does not handle.
# The submodules are imported explicitly instead of relying on another import
# having loaded them as a side effect.
import safetensors.torch
import accelerate.utils.modeling

_old_torch_load_file = safetensors.torch.load_file
safetensors.torch.load_file = torch_load_file
_old_safe_open = safetensors.safe_open
safetensors.safe_open = safe_open
accelerate.utils.modeling.safe_open = safe_open
accelerate.utils.modeling.safe_load_file = torch_load_file
try:
    import transformers
    transformers.modeling_utils.safe_open = safe_open
    transformers.modeling_utils.safe_load_file = torch_load_file
except Exception:
    # transformers is optional: patching it is best effort. Unlike a bare
    # ``except:``, this lets KeyboardInterrupt and SystemExit propagate.
    pass
|
|
393
|
+
|
|
394
|
+
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
GNU GENERAL PUBLIC LICENSE
|
|
1
|
+
GNU GENERAL PUBLIC LICENSE
|
|
2
2
|
Version 3, 29 June 2007
|