mmgp 3.4.0__py3-none-any.whl → 3.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mmgp might be problematic. Click here for more details.
- mmgp/offload.py +2558 -2494
- mmgp/safetensors2.py +482 -461
- {mmgp-3.4.0.dist-info → mmgp-3.4.2.dist-info}/METADATA +2 -2
- mmgp-3.4.2.dist-info/RECORD +9 -0
- {mmgp-3.4.0.dist-info → mmgp-3.4.2.dist-info}/WHEEL +1 -1
- {mmgp-3.4.0.dist-info → mmgp-3.4.2.dist-info}/licenses/LICENSE.md +1 -1
- mmgp-3.4.0.dist-info/RECORD +0 -9
- {mmgp-3.4.0.dist-info → mmgp-3.4.2.dist-info}/top_level.txt +0 -0
mmgp/safetensors2.py
CHANGED
|
@@ -1,462 +1,483 @@
|
|
|
1
|
-
# ------------------ Safetensors2 1.1 by DeepBeepMeep (mmgp)------------------
|
|
2
|
-
#
|
|
3
|
-
# This module entirely written in Python is a replacement for the safetensor library which requires much less RAM to load models.
|
|
4
|
-
# It can be conveniently used to keep a low RAM consumption when handling transit data (for instance when quantizing or transferring tensors to reserver RAM)
|
|
5
|
-
# You are free to use my module for non commercial use as long you give me proper credits. You may contact me on twitter @deepbeepmeep
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
from typing import Optional, Dict, List, Iterator, Tuple
|
|
9
|
-
from pathlib import Path
|
|
10
|
-
import torch
|
|
11
|
-
import mmap
|
|
12
|
-
import struct
|
|
13
|
-
import json
|
|
14
|
-
import base64
|
|
15
|
-
import safetensors
|
|
16
|
-
import accelerate
|
|
17
|
-
import os
|
|
18
|
-
from collections import OrderedDict
|
|
19
|
-
import warnings
|
|
20
|
-
|
|
21
|
-
warnings.filterwarnings("ignore", ".*The given buffer is not writable, and PyTorch does not support non-writable tensors*")
|
|
22
|
-
|
|
23
|
-
_old_torch_load_file = None
|
|
24
|
-
_old_safe_open = None
|
|
25
|
-
|
|
26
|
-
all_tensors_are_read_only = False
|
|
27
|
-
|
|
28
|
-
mmm = {}
|
|
29
|
-
verboseLevel = 1
|
|
30
|
-
|
|
31
|
-
import weakref
|
|
32
|
-
|
|
33
|
-
_map_to_dtype = { 'BF16': torch.bfloat16, 'U8': torch.uint8 , 'U16': torch.uint16, 'U32' : torch.uint32 , 'U64' : torch.uint64,
|
|
34
|
-
'I8': torch.int8, 'I16': torch.int16, 'I32' : torch.int32 , 'I64' : torch.int64,
|
|
35
|
-
'F64' : torch.float64, 'F32': torch.float32, 'F16': torch.float16, 'BOOL' : torch.bool, "F8_E5M2" : torch.float8_e5m2, "F8_E4M3" : torch.float8_e4m3fn }
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
class MmapTracker:
|
|
39
|
-
def __init__(self, file_path):
|
|
40
|
-
self._maps = {}
|
|
41
|
-
self._already_released = 0
|
|
42
|
-
from pathlib import Path
|
|
43
|
-
s = Path(file_path).parts
|
|
44
|
-
if len(s)>2:
|
|
45
|
-
s = s[-2:]
|
|
46
|
-
file_path = os.path.join(*s)
|
|
47
|
-
self.file_path = file_path # os.path.abspath(file_path)
|
|
48
|
-
self.count = 0
|
|
49
|
-
mmm[file_path] = self
|
|
50
|
-
|
|
51
|
-
def register(self, mmap_obj, map_id, start, size):
|
|
52
|
-
|
|
53
|
-
self.count += 1
|
|
54
|
-
def finalizer(ref):
|
|
55
|
-
self._already_released += 1
|
|
56
|
-
if verboseLevel >=2:
|
|
57
|
-
if self.count == self._already_released:
|
|
58
|
-
text =" (all the mmaps have been released)"
|
|
59
|
-
else:
|
|
60
|
-
text =f" ({self.count-self._already_released:} left)"
|
|
61
|
-
|
|
62
|
-
print(f"MMap Manager of file '{self.file_path}' : MMap no {map_id} has been released" + text)
|
|
63
|
-
if self.count == self._already_released:
|
|
64
|
-
del mmm[self.file_path]
|
|
65
|
-
|
|
66
|
-
self._maps.pop(map_id, None)
|
|
67
|
-
|
|
68
|
-
wr = weakref.ref(mmap_obj, finalizer)
|
|
69
|
-
self._maps[map_id] = {
|
|
70
|
-
'mmap' : wr,
|
|
71
|
-
'start': start,
|
|
72
|
-
'size': size,
|
|
73
|
-
'end': start + size
|
|
74
|
-
}
|
|
75
|
-
return wr
|
|
76
|
-
|
|
77
|
-
def get_active_maps(self):
|
|
78
|
-
return dict(self._maps)
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
catalog
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
self.
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
self.
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
def
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
return self.
|
|
382
|
-
|
|
383
|
-
def
|
|
384
|
-
"""Get
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
self.
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
self.
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
1
|
+
# ------------------ Safetensors2 1.1 by DeepBeepMeep (mmgp)------------------
|
|
2
|
+
#
|
|
3
|
+
# This module entirely written in Python is a replacement for the safetensor library which requires much less RAM to load models.
|
|
4
|
+
# It can be conveniently used to keep a low RAM consumption when handling transit data (for instance when quantizing or transferring tensors to reserver RAM)
|
|
5
|
+
# You are free to use my module for non commercial use as long you give me proper credits. You may contact me on twitter @deepbeepmeep
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
from typing import Optional, Dict, List, Iterator, Tuple
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
import torch
|
|
11
|
+
import mmap
|
|
12
|
+
import struct
|
|
13
|
+
import json
|
|
14
|
+
import base64
|
|
15
|
+
import safetensors
|
|
16
|
+
import accelerate
|
|
17
|
+
import os
|
|
18
|
+
from collections import OrderedDict
|
|
19
|
+
import warnings
|
|
20
|
+
|
|
21
|
+
warnings.filterwarnings("ignore", ".*The given buffer is not writable, and PyTorch does not support non-writable tensors*")
|
|
22
|
+
|
|
23
|
+
_old_torch_load_file = None
|
|
24
|
+
_old_safe_open = None
|
|
25
|
+
|
|
26
|
+
all_tensors_are_read_only = False
|
|
27
|
+
|
|
28
|
+
mmm = {}
|
|
29
|
+
verboseLevel = 1
|
|
30
|
+
|
|
31
|
+
import weakref
|
|
32
|
+
|
|
33
|
+
_map_to_dtype = { 'BF16': torch.bfloat16, 'U8': torch.uint8 , 'U16': torch.uint16, 'U32' : torch.uint32 , 'U64' : torch.uint64,
|
|
34
|
+
'I8': torch.int8, 'I16': torch.int16, 'I32' : torch.int32 , 'I64' : torch.int64,
|
|
35
|
+
'F64' : torch.float64, 'F32': torch.float32, 'F16': torch.float16, 'BOOL' : torch.bool, "F8_E5M2" : torch.float8_e5m2, "F8_E4M3" : torch.float8_e4m3fn }
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class MmapTracker:
    """Tracks the mmap objects created for one safetensors file.

    Each map is registered with a weakref finalizer: when the last tensor
    referencing a map is garbage-collected the tracker notices, optionally
    logs the release, and removes itself from the module-level ``mmm``
    registry once every map of the file has been released.
    """

    def __init__(self, file_path):
        self._maps = {}
        self._already_released = 0
        # Keep at most the last two path components so log lines stay short.
        # (Path is already imported at module level; the previous local
        # `from pathlib import Path` was redundant.)
        parts = Path(file_path).parts
        if len(parts) > 2:
            file_path = os.path.join(*parts[-2:])
        self.file_path = file_path  # os.path.abspath(file_path)
        self.count = 0
        mmm[file_path] = self

    def register(self, mmap_obj, map_id, start, size):
        """Register `mmap_obj`, covering file bytes [start, start+size), under `map_id`.

        Returns the weakref so the caller may keep it; the finalizer fires
        when `mmap_obj` is collected.
        """
        self.count += 1

        def finalizer(ref):
            # Invoked by the GC when the mmap object dies.
            self._already_released += 1
            if verboseLevel >= 2:
                if self.count == self._already_released:
                    text = " (all the mmaps have been released)"
                else:
                    text = f" ({self.count-self._already_released:} left)"
                print(f"MMap Manager of file '{self.file_path}' : MMap no {map_id} has been released" + text)
            if self.count == self._already_released:
                # NOTE(review): assumes mmm[self.file_path] is still this
                # tracker; a newer tracker on the same path would be evicted
                # here — confirm files are never re-opened while maps live.
                del mmm[self.file_path]
            self._maps.pop(map_id, None)

        wr = weakref.ref(mmap_obj, finalizer)
        self._maps[map_id] = {
            'mmap': wr,
            'start': start,
            'size': size,
            'end': start + size,
        }
        return wr

    def get_active_maps(self):
        """Return a shallow copy of the currently registered maps."""
        return dict(self._maps)
|
|
79
|
+
|
|
80
|
+
class tensor_slice:
    """Minimal stand-in for the slice object returned by safetensors'
    ``safe_open(...).get_slice()``: supports indexing plus dtype/shape
    lookups backed by the parsed file-header catalog."""

    catalog = None
    value = None
    name = None

    def __init__(self, catalog, name, value):
        self.catalog = catalog
        self.name = name
        self.value = value

    def __getitem__(self, s):
        # Delegate slicing straight to the wrapped tensor/value.
        return self.value[s]

    def get_dtype(self):
        # dtype tag string (e.g. "F32") as stored in the header catalog.
        return self.catalog[self.name]["dtype"]

    def get_shape(self):
        # Shape list as stored in the header catalog.
        return self.catalog[self.name]["shape"]
|
|
98
|
+
|
|
99
|
+
class cached_metadata:
    """Caches the parsed header of a single safetensors file.

    The cache key is (absolute path, file size, ctime): ``get_metadata``
    only serves the stored header while all three still match, so a
    replaced or modified file is re-parsed instead of served stale.
    """

    file_path = None
    file_length = 0
    file_date = None
    catalog = None
    metadata = None
    skip_bytes = 0

    def __init__(self, file_path, catalog, metadata, skip_bytes):
        self.catalog = catalog
        self.metadata = metadata
        self.skip_bytes = skip_bytes
        stats = os.stat(file_path)
        self.file_path = os.path.abspath(file_path)
        self.file_length = stats.st_size
        self.file_date = stats.st_ctime

    def get_metadata(self, file_path):
        """Return (catalog, metadata, skip_bytes) when `file_path` still matches
        the cached entry, otherwise (None, None, None)."""
        stats = os.stat(file_path)
        same = (self.file_path == os.path.abspath(file_path)
                and self.file_length == stats.st_size
                and self.file_date == stats.st_ctime)
        if not same:
            return None, None, None
        return self.catalog, self.metadata, self.skip_bytes
|
|
124
|
+
|
|
125
|
+
_cached_entry = None # ideally we should create a dict of the last n entries but one entry covers most cases
|
|
126
|
+
|
|
127
|
+
def _parse_metadata(metadata):
    """Decode the "__metadata__" dict of a safetensors header.

    Keys ending in "_base64" hold base64-encoded JSON (the safetensors
    metadata section only allows string values); they are decoded and
    stored under the key minus its "_base64" suffix.  Plain string values
    are kept as-is.  Returns None when `metadata` is None.
    """
    if metadata is None:  # idiom fix: was `== None`
        return None

    new_metadata = {}
    for k, v in metadata.items():
        if k.endswith("_base64"):
            v_decoded = json.loads(base64.b64decode(v.encode('utf8')).decode('utf8'))
            # Strip the trailing "_base64" marker to recover the real key.
            new_metadata[k[:k.rfind("_")]] = v_decoded
        else:
            new_metadata[k] = v

    return new_metadata
|
|
143
|
+
|
|
144
|
+
def _read_safetensors_header(path, file):
    """Read the safetensors header from `file` (positioned at offset 0).

    Returns (catalog, metadata, total_header_bytes) where the byte count
    includes the 8-byte length prefix, i.e. the offset where tensor data
    starts.  A one-entry module-level cache avoids re-parsing the JSON
    header when the same unchanged file is opened repeatedly.
    """
    global _cached_entry
    length_of_header_bytes = file.read(8)
    # Interpret the bytes as a little-endian unsigned 64-bit integer
    length_of_header = struct.unpack('<Q', length_of_header_bytes)[0]

    if _cached_entry is not None:  # idiom fix: was `!= None`
        catalog, metadata, _ = _cached_entry.get_metadata(path)
    else:
        catalog = None

    if catalog is None:  # idiom fix: was `== None`
        header_bytes = file.read(length_of_header)
        # catalog = json.loads(header_bytes.decode('utf-8'))
        catalog = json.loads(header_bytes)
        metadata = catalog.pop("__metadata__", None)
        metadata = _parse_metadata(metadata)

        _cached_entry = cached_metadata(path, catalog, metadata, length_of_header)
    else:
        # Cache hit: skip over the header without parsing it again.
        file.seek(length_of_header, 1)

    return catalog, metadata, length_of_header + 8
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def torch_write_file(sd, file_path, quantization_map = None, config = None, extra_meta = None):
|
|
170
|
+
from collections import OrderedDict
|
|
171
|
+
sf_sd = OrderedDict()
|
|
172
|
+
|
|
173
|
+
map = { torch.bfloat16 : 'BF16' , torch.int64 : 'I64' , torch.int32 : 'I32' , torch.int16 : 'I16' , torch.int8 : 'I8' ,
|
|
174
|
+
torch.uint64 : 'U64' , torch.uint32 : 'U32' , torch.uint16 : 'U16' , torch.uint8 : 'U8' ,
|
|
175
|
+
torch.bool : 'BOOL' , torch.float64 : 'F64' , torch.float32 : 'F32' , torch.float16 : 'F16', torch.float8_e5m2 : "F8_E5M2", torch.float8_e4m3fn: "F8_E4M3" }
|
|
176
|
+
pos = 0
|
|
177
|
+
i = 0
|
|
178
|
+
mx = 100000
|
|
179
|
+
metadata = dict()
|
|
180
|
+
for k , t in sd.items():
|
|
181
|
+
if torch.is_tensor(t):
|
|
182
|
+
entry = {}
|
|
183
|
+
dtypestr= map[t.dtype]
|
|
184
|
+
entry["dtype"] = dtypestr
|
|
185
|
+
entry["shape"] = list(t.shape)
|
|
186
|
+
size = torch.numel(t) * t.element_size()
|
|
187
|
+
if size == 0:
|
|
188
|
+
pass
|
|
189
|
+
entry["data_offsets"] = [pos, pos + size]
|
|
190
|
+
pos += size
|
|
191
|
+
sf_sd[k] = entry
|
|
192
|
+
else:
|
|
193
|
+
if isinstance(t, str):
|
|
194
|
+
metadata[k] = t
|
|
195
|
+
else:
|
|
196
|
+
try:
|
|
197
|
+
b64 = base64.b64encode(json.dumps(t, ensure_ascii=False).encode('utf8')).decode('utf8')
|
|
198
|
+
metadata[k + "_base64"] = b64
|
|
199
|
+
except:
|
|
200
|
+
pass
|
|
201
|
+
|
|
202
|
+
i+=1
|
|
203
|
+
if i==mx:
|
|
204
|
+
break
|
|
205
|
+
if not quantization_map is None:
|
|
206
|
+
metadata["quantization_format"] = "quanto"
|
|
207
|
+
metadata["quantization_map_base64"] = base64.b64encode(json.dumps(quantization_map, ensure_ascii=False).encode('utf8')).decode('utf8')
|
|
208
|
+
|
|
209
|
+
if not config is None:
|
|
210
|
+
metadata["config_base64"] = base64.b64encode(json.dumps(config, ensure_ascii=False).encode('utf8')).decode('utf8')
|
|
211
|
+
|
|
212
|
+
if not extra_meta is None:
|
|
213
|
+
for n , m in extra_meta.items():
|
|
214
|
+
if isinstance(m, str):
|
|
215
|
+
metadata[n] = m
|
|
216
|
+
else:
|
|
217
|
+
metadata[n + "_base64"] = base64.b64encode(json.dumps(m, ensure_ascii=False).encode('utf8')).decode('utf8')
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
if len(metadata) > 0:
|
|
221
|
+
sf_sd["__metadata__"] = metadata
|
|
222
|
+
|
|
223
|
+
header_bytes = json.dumps(sf_sd).encode()
|
|
224
|
+
#header_bytes =json.dumps(config, ensure_ascii=False).encode('utf8')
|
|
225
|
+
size_header = len(header_bytes)
|
|
226
|
+
import struct
|
|
227
|
+
|
|
228
|
+
length_of_header_bytes = struct.pack('<Q', size_header)
|
|
229
|
+
|
|
230
|
+
with open(file_path, "wb") as writer:
|
|
231
|
+
bytes_written = writer.write(length_of_header_bytes)
|
|
232
|
+
bytes_written = writer.write(header_bytes)
|
|
233
|
+
|
|
234
|
+
i = 0
|
|
235
|
+
for k , t in sd.items():
|
|
236
|
+
if torch.is_tensor(t):
|
|
237
|
+
size = torch.numel(t) * t.element_size()
|
|
238
|
+
if size != 0:
|
|
239
|
+
dtype = t.dtype
|
|
240
|
+
# convert in a friendly format, scalars types not supported by numpy
|
|
241
|
+
if dtype == torch.bfloat16:
|
|
242
|
+
t = t.view(torch.uint16)
|
|
243
|
+
elif dtype == torch.float8_e5m2 or dtype == torch.float8_e4m3fn:
|
|
244
|
+
t = t.view(torch.uint8)
|
|
245
|
+
buffer = t.numpy().tobytes()
|
|
246
|
+
bytes_written = writer.write(buffer)
|
|
247
|
+
assert bytes_written == size
|
|
248
|
+
i+=1
|
|
249
|
+
if i==mx:
|
|
250
|
+
break
|
|
251
|
+
|
|
252
|
+
class SafeTensorFile:
    """Main class for accessing safetensors files that provides memory-efficient access"""

    def __init__(self, file_path, metadata, catalog, skip_bytes, lazy_loading = True, writable_tensors = True):
        self._file_path = file_path
        self._metadata = metadata
        self._catalog = catalog
        self._skip_bytes = skip_bytes      # header size including the 8-byte length prefix
        self._keys = None
        self.sd = None                     # name -> tensor dict, built lazily
        self.mtracker = None
        self.lazy_loading = lazy_loading
        self.writable_tensors = writable_tensors

    @classmethod
    def load_metadata(cls, file_path, lazy_loading = True, writable_tensors = True):
        """Parse only the header of `file_path` and return a SafeTensorFile."""
        with open(file_path, 'rb') as f:
            catalog, metadata, skip_bytes = _read_safetensors_header(file_path, f)
        return cls(file_path, metadata, catalog, skip_bytes, lazy_loading, writable_tensors)

    def init_tensors(self, lazyTensors = True, writable_tensors = True):
        """Materialise self.sd on first use; subsequent calls return the cached dict."""
        if self.sd is None:
            self.lazy_loading = lazyTensors
            if lazyTensors:
                self.sd = self.create_tensors_with_mmap(writable_tensors)
            else:
                self.sd = self.create_tensors_without_mmap()
        # else:
        #     if not self.lazy_loading and lazyTensors:
        #         raise Exception("Every tensor should be either lazy loaded or not lazy loaded")
        return self.sd

    def create_tensors_with_mmap(self, writable_tensors = True):
        """Build every tensor as a zero-copy view over up-to-1GB mmap windows."""
        self.mtracker = MmapTracker(self._file_path)
        import mmap

        PAGE_SIZE = mmap.ALLOCATIONGRANULARITY
        MMAP_SIZE = 1024 * 1024 * 1024 # 1GB
        # MMAP_SIZE = 256 * 1024 * 1024 # 1GB

        # First pass: find optimal aligned map boundaries
        skip_bytes = self._skip_bytes
        tensor_map_indexes = []
        maps_info = []
        current_pos = skip_bytes
        current_map_start = (skip_bytes // PAGE_SIZE) * PAGE_SIZE
        current_map_size = skip_bytes - current_map_start
        idx = 0
        for entry in self._catalog.values():
            offsets = entry["data_offsets"]
            length = offsets[1] - offsets[0]
            if current_map_size + length > MMAP_SIZE:
                # Close the current window and open a new page-aligned one.
                maps_info.append((current_map_start, current_map_size))
                current_map_start = (current_pos // PAGE_SIZE) * PAGE_SIZE
                current_map_size = current_pos - current_map_start
                idx += 1
            tensor_map_indexes.append(idx)
            current_map_size += length
            current_pos += length

        maps_info.append((current_map_start, current_map_size))

        # Second pass: create maps and tensors
        maps = []
        sd = OrderedDict()

        current_pos = skip_bytes
        with open(self._file_path, 'rb') as f:
            access_mode = mmap.ACCESS_COPY if writable_tensors else mmap.ACCESS_READ
            for map_no, (map_start, map_size) in enumerate(maps_info):
                mm = mmap.mmap(f.fileno(), map_size, offset=map_start, access=access_mode)
                maps.append((mm, map_start, map_size))
                self.mtracker.register(mm, map_no, map_start, map_size)

            iter_tensor_no = iter(tensor_map_indexes)
            for name, entry in self._catalog.items():
                dtype = _map_to_dtype[entry["dtype"]]
                shape = entry["shape"]
                offsets = entry["data_offsets"]
                length = offsets[1] - offsets[0]
                map_idx = next(iter_tensor_no)
                offset = current_pos - maps[map_idx][1]
                if length == 0:
                    t = torch.empty(shape, dtype=dtype)
                elif len(shape) == 0:
                    # don't waste a memory view for a scalar
                    t = torch.frombuffer(bytearray(maps[map_idx][0][offset:offset + length]), dtype=torch.uint8)
                    t = t.view(dtype)
                else:
                    mv = memoryview(maps[map_idx][0])[offset:offset + length]
                    t = torch.reshape(torch.frombuffer(mv, dtype=dtype), shape)
                # t._mmap = maps[map_idx][0]
                sd[name] = t
                current_pos += length

        return sd

    def create_tensors_without_mmap(self):
        """Eagerly read every tensor's bytes into freshly allocated buffers."""
        sd = OrderedDict()

        with open(self._file_path, 'rb') as f:
            f.seek(self._skip_bytes, 0)
            for name, entry in self._catalog.items():
                dtype = _map_to_dtype[entry["dtype"]]
                shape = entry["shape"]
                offsets = entry["data_offsets"]
                length = offsets[1] - offsets[0]
                buffer = f.read(length)
                if length == 0:
                    t = torch.empty(0, dtype=dtype)
                elif len(shape) == 0:
                    t = torch.frombuffer(bytearray(buffer), dtype=torch.uint8).view(dtype)
                else:
                    t = torch.reshape(torch.frombuffer(bytearray(buffer), dtype=dtype), shape)
                sd[name] = t
        return sd

    def get_slice(self, name: str) -> torch.tensor:
        return tensor_slice(self._catalog, name, self.get_tensor(name))

    def get_tensor(self, name: str) -> torch.tensor:
        """Get a tensor by name"""
        # To do : switch to a JIT tensor creation per tensor
        self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
        return self.sd[name]

    def keys(self) -> List[str]:
        """Get list of tensor names"""
        if self._keys is None:
            self._keys = list(self._catalog)
        return self._keys

    def names(self) -> List[str]:
        """Alias for keys()"""
        return self.keys()

    def tensors(self) -> Dict[str, torch.tensor]:
        """Get dictionary of all tensors"""
        self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
        return self.sd

    def metadata(self) -> Optional[Dict[str, str]]:
        """Get metadata dictionary"""
        return self._metadata

    def __len__(self) -> int:
        """Get number of tensors"""
        self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
        return len(self.keys())

    def __contains__(self, key: str) -> bool:
        """Check if tensor exists"""
        return key in self.keys()

    def __iter__(self) -> Iterator[Tuple[str, torch.tensor ]]:
        """Iterate over (name, tensor) pairs"""
        return ((name, self.get_tensor(name)) for name in self.keys())

    def _free_resources(self):
        # Drop the tensor dict and catalog so the mmaps can be reclaimed.
        del self.sd
        del self._catalog
|
|
424
|
+
|
|
425
|
+
class _SafeTensorLoader:
    """Context manager for loading SafeTensorFile"""

    def __init__(self, filename: str, writable_tensors = True ):
        self.filename = Path(filename)
        self.writable_tensors = writable_tensors
        self.sft = None
        if not self.filename.exists():
            # Fix: the f-string had no placeholder, so the message never
            # named the missing file.  Include the offending path.
            raise FileNotFoundError(f"File not found: {filename}")

    def __enter__(self) -> SafeTensorFile:
        """Open file and return SafeTensorFile instance"""
        writable_tensors = self.writable_tensors

        if all_tensors_are_read_only:
            # Module-wide override: hand out zero-copy read-only tensors.
            writable_tensors = False

        try:
            self.sft = SafeTensorFile.load_metadata(self.filename, writable_tensors= writable_tensors)
            return self.sft

        except Exception as e:
            self.close()
            raise Exception(f"Failed to load safetensors file: {e}") from e

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Clean up resources"""
        self.close()

    def close(self) -> None:
        """Release the tensors/catalog held by the underlying SafeTensorFile."""
        if self.sft != None:
            self.sft._free_resources()
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def safe_open(filename: str, framework: str = "pt",device = "cpu", writable_tensors = True) -> _SafeTensorLoader:
    """Drop-in replacement for ``safetensors.safe_open`` (CPU + PyTorch only).

    Requests for any other framework or device are delegated to the
    original safetensors implementation saved in ``_old_safe_open``.
    """
    handled_here = (device == "cpu") and (framework == "pt")
    if not handled_here:
        return _old_safe_open(filename =filename, framework=framework, device=device)
    return _SafeTensorLoader(filename, writable_tensors = writable_tensors)
|
|
464
|
+
|
|
465
|
+
def torch_load_file( filename, device = 'cpu', writable_tensors = True) -> Dict[str, torch.Tensor]:
    """Drop-in replacement for ``safetensors.torch.load_file``: load every
    tensor of `filename` into a plain dict via this module's safe_open."""
    with safe_open(filename, framework="pt", device = device, writable_tensors =writable_tensors ) as f:
        return {name: f.get_tensor(name) for name in f.keys()}
|
|
471
|
+
|
|
472
|
+
# Monkey-patch safetensors / accelerate (and transformers when available)
# so their file loading goes through this module's low-RAM implementations.
# The originals are kept for delegation/fallback.
_old_torch_load_file = safetensors.torch.load_file
safetensors.torch.load_file = torch_load_file
_old_safe_open = safetensors.safe_open
safetensors.safe_open = safe_open
accelerate.utils.modeling.safe_open = safe_open
accelerate.utils.modeling.safe_load_file = torch_load_file
try:
    import transformers
    transformers.modeling_utils.safe_open = safe_open
    transformers.modeling_utils.safe_load_file = torch_load_file
except Exception:
    # transformers is optional (or its internals moved): patching is
    # best-effort.  Narrowed from a bare `except:` so that SystemExit /
    # KeyboardInterrupt still propagate.
    pass