mmgp 3.4.1-py3-none-any.whl → 3.4.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mmgp may be problematic.
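
The change to mmgp/safetensors2.py shown in this diff is the addition of a small tensor_slice wrapper class and a SafeTensorFile.get_slice() method that returns it, apparently to mirror the slice API of the upstream safetensors library: the wrapper exposes get_shape() and get_dtype() straight from the parsed JSON header and delegates indexing to the underlying tensor. A minimal usage sketch follows (the file name and tensor name are hypothetical, and the import path assumes the module is used as mmgp.safetensors2):

from mmgp import safetensors2

# safe_open() here is the pure-Python loader defined in this file; importing the
# module also monkey-patches safetensors.safe_open / safetensors.torch.load_file
# (and the accelerate equivalents) to route through it.
with safetensors2.safe_open("model.safetensors", framework="pt", device="cpu") as f:
    sl = f.get_slice("decoder.weight")       # new in 3.4.3: returns a tensor_slice
    print(sl.get_shape(), sl.get_dtype())    # shape list and dtype string ('F16', 'BF16', ...) from the header
    rows = sl[0:16]                          # __getitem__ slices the underlying tensor
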

mmgp/safetensors2.py CHANGED
@@ -1,462 +1,483 @@
1
- # ------------------ Safetensors2 1.1 by DeepBeepMeep (mmgp)------------------
2
- #
3
- # This module entirely written in Python is a replacement for the safetensor library which requires much less RAM to load models.
4
- # It can be conveniently used to keep a low RAM consumption when handling transit data (for instance when quantizing or transferring tensors to reserver RAM)
5
- # You are free to use my module for non commercial use as long you give me proper credits. You may contact me on twitter @deepbeepmeep
6
-
7
-
8
- from typing import Optional, Dict, List, Iterator, Tuple
9
- from pathlib import Path
10
- import torch
11
- import mmap
12
- import struct
13
- import json
14
- import base64
15
- import safetensors
16
- import accelerate
17
- import os
18
- from collections import OrderedDict
19
- import warnings
20
-
21
- warnings.filterwarnings("ignore", ".*The given buffer is not writable, and PyTorch does not support non-writable tensors*")
22
-
23
- _old_torch_load_file = None
24
- _old_safe_open = None
25
-
26
- all_tensors_are_read_only = False
27
-
28
- mmm = {}
29
- verboseLevel = 1
30
-
31
- import weakref
32
-
33
- _map_to_dtype = { 'BF16': torch.bfloat16, 'U8': torch.uint8 , 'U16': torch.uint16, 'U32' : torch.uint32 , 'U64' : torch.uint64,
34
- 'I8': torch.int8, 'I16': torch.int16, 'I32' : torch.int32 , 'I64' : torch.int64,
35
- 'F64' : torch.float64, 'F32': torch.float32, 'F16': torch.float16, 'BOOL' : torch.bool, "F8_E5M2" : torch.float8_e5m2, "F8_E4M3" : torch.float8_e4m3fn }
36
-
37
-
38
- class MmapTracker:
39
- def __init__(self, file_path):
40
- self._maps = {}
41
- self._already_released = 0
42
- from pathlib import Path
43
- s = Path(file_path).parts
44
- if len(s)>2:
45
- s = s[-2:]
46
- file_path = os.path.join(*s)
47
- self.file_path = file_path # os.path.abspath(file_path)
48
- self.count = 0
49
- mmm[file_path] = self
50
-
51
- def register(self, mmap_obj, map_id, start, size):
52
-
53
- self.count += 1
54
- def finalizer(ref):
55
- self._already_released += 1
56
- if verboseLevel >=2:
57
- if self.count == self._already_released:
58
- text =" (all the mmaps have been released)"
59
- else:
60
- text =f" ({self.count-self._already_released:} left)"
61
-
62
- print(f"MMap Manager of file '{self.file_path}' : MMap no {map_id} has been released" + text)
63
- if self.count == self._already_released:
64
- del mmm[self.file_path]
65
-
66
- self._maps.pop(map_id, None)
67
-
68
- wr = weakref.ref(mmap_obj, finalizer)
69
- self._maps[map_id] = {
70
- 'mmap' : wr,
71
- 'start': start,
72
- 'size': size,
73
- 'end': start + size
74
- }
75
- return wr
76
-
77
- def get_active_maps(self):
78
- return dict(self._maps)
79
-
80
-
81
- class cached_metadata:
82
- file_path = None
83
- file_length = 0
84
- file_date = None
85
- catalog = None
86
- metadata = None
87
- skip_bytes = 0
88
-
89
- def __init__(self, file_path, catalog, metadata, skip_bytes):
90
- self.catalog = catalog
91
- self.metadata = metadata
92
- self.skip_bytes = skip_bytes
93
- file_stats = os.stat(file_path)
94
- self.file_path = os.path.abspath(file_path)
95
- self.file_length = file_stats.st_size
96
- self.file_date = file_stats.st_ctime
97
-
98
- def get_metadata(self, file_path):
99
- file_stats = os.stat(file_path)
100
- file_length = file_stats.st_size
101
- file_date = file_stats.st_ctime
102
- file_path = os.path.abspath(file_path)
103
- if self.file_path != file_path or self.file_length != file_length or self.file_date != file_date:
104
- return None, None, None
105
- return self.catalog, self.metadata, self.skip_bytes
106
-
107
- _cached_entry = None # ideally we should create a dict of the last n entries but one entry covers most cases
108
-
109
- def _parse_metadata(metadata):
110
- if metadata == None:
111
- return None
112
-
113
- new_metadata= {}
114
-
115
- for k,v in metadata.items():
116
- if k.endswith("_base64"):
117
- v_decoded = json.loads(base64.b64decode(v.encode('utf8')).decode('utf8'))
118
- p = k.rfind("_")
119
- new_k = k[:p]
120
- new_metadata[new_k]= v_decoded
121
- else:
122
- new_metadata[k] = v
123
-
124
- return new_metadata
125
-
126
- def _read_safetensors_header(path, file):
127
- global _cached_entry
128
- length_of_header_bytes = file.read(8)
129
- # Interpret the bytes as a little-endian unsigned 64-bit integer
130
- length_of_header = struct.unpack('<Q', length_of_header_bytes)[0]
131
-
132
- if _cached_entry != None:
133
- catalog, metadata, _ = _cached_entry.get_metadata(path)
134
- else:
135
- catalog = None
136
-
137
- if catalog == None:
138
- header_bytes = file.read(length_of_header)
139
- #catalog = json.loads(header_bytes.decode('utf-8'))
140
- catalog = json.loads(header_bytes)
141
- metadata = catalog.pop("__metadata__", None)
142
- metadata = _parse_metadata(metadata)
143
-
144
- _cached_entry = cached_metadata(path, catalog, metadata,length_of_header )
145
- else:
146
- file.seek(length_of_header, 1)
147
-
148
- return catalog, metadata, length_of_header + 8
149
-
150
-
151
- def torch_write_file(sd, file_path, quantization_map = None, config = None, extra_meta = None):
152
- from collections import OrderedDict
153
- sf_sd = OrderedDict()
154
-
155
- map = { torch.bfloat16 : 'BF16' , torch.int64 : 'I64' , torch.int32 : 'I32' , torch.int16 : 'I16' , torch.int8 : 'I8' ,
156
- torch.uint64 : 'U64' , torch.uint32 : 'U32' , torch.uint16 : 'U16' , torch.uint8 : 'U8' ,
157
- torch.bool : 'BOOL' , torch.float64 : 'F64' , torch.float32 : 'F32' , torch.float16 : 'F16', torch.float8_e5m2 : "F8_E5M2", torch.float8_e4m3fn: "F8_E4M3" }
158
- pos = 0
159
- i = 0
160
- mx = 100000
161
- metadata = dict()
162
- for k , t in sd.items():
163
- if torch.is_tensor(t):
164
- entry = {}
165
- dtypestr= map[t.dtype]
166
- entry["dtype"] = dtypestr
167
- entry["shape"] = list(t.shape)
168
- size = torch.numel(t) * t.element_size()
169
- if size == 0:
170
- pass
171
- entry["data_offsets"] = [pos, pos + size]
172
- pos += size
173
- sf_sd[k] = entry
174
- else:
175
- if isinstance(t, str):
176
- metadata[k] = t
177
- else:
178
- try:
179
- b64 = base64.b64encode(json.dumps(t, ensure_ascii=False).encode('utf8')).decode('utf8')
180
- metadata[k + "_base64"] = b64
181
- except:
182
- pass
183
-
184
- i+=1
185
- if i==mx:
186
- break
187
- if not quantization_map is None:
188
- metadata["quantization_format"] = "quanto"
189
- metadata["quantization_map_base64"] = base64.b64encode(json.dumps(quantization_map, ensure_ascii=False).encode('utf8')).decode('utf8')
190
-
191
- if not config is None:
192
- metadata["config_base64"] = base64.b64encode(json.dumps(config, ensure_ascii=False).encode('utf8')).decode('utf8')
193
-
194
- if not extra_meta is None:
195
- for n , m in extra_meta.items():
196
- if isinstance(m, str):
197
- metadata[n] = m
198
- else:
199
- metadata[n + "_base64"] = base64.b64encode(json.dumps(m, ensure_ascii=False).encode('utf8')).decode('utf8')
200
-
201
-
202
- if len(metadata) > 0:
203
- sf_sd["__metadata__"] = metadata
204
-
205
- header_bytes = json.dumps(sf_sd).encode()
206
- #header_bytes =json.dumps(config, ensure_ascii=False).encode('utf8')
207
- size_header = len(header_bytes)
208
- import struct
209
-
210
- length_of_header_bytes = struct.pack('<Q', size_header)
211
-
212
- with open(file_path, "wb") as writer:
213
- bytes_written = writer.write(length_of_header_bytes)
214
- bytes_written = writer.write(header_bytes)
215
-
216
- i = 0
217
- for k , t in sd.items():
218
- if torch.is_tensor(t):
219
- size = torch.numel(t) * t.element_size()
220
- if size != 0:
221
- dtype = t.dtype
222
- # convert in a friendly format, scalars types not supported by numpy
223
- if dtype == torch.bfloat16:
224
- t = t.view(torch.uint16)
225
- elif dtype == torch.float8_e5m2 or dtype == torch.float8_e4m3fn:
226
- t = t.view(torch.uint8)
227
- buffer = t.numpy().tobytes()
228
- bytes_written = writer.write(buffer)
229
- assert bytes_written == size
230
- i+=1
231
- if i==mx:
232
- break
233
-
234
- class SafeTensorFile:
235
- """Main class for accessing safetensors files that provides memory-efficient access"""
236
-
237
- def __init__(self, file_path, metadata, catalog, skip_bytes, lazy_loading = True, writable_tensors = True):
238
- self._file_path = file_path
239
- self._metadata = metadata
240
- self._catalog = catalog
241
- self._skip_bytes = skip_bytes
242
- self._keys = None
243
- self.sd = None
244
- self.mtracker = None
245
- self.lazy_loading = lazy_loading
246
- self.writable_tensors = writable_tensors
247
-
248
- @classmethod
249
- def load_metadata(cls, file_path, lazy_loading = True, writable_tensors = True):
250
- with open(file_path, 'rb') as f:
251
- catalog, metadata, skip_bytes = _read_safetensors_header(file_path, f)
252
-
253
- return cls(file_path, metadata, catalog, skip_bytes, lazy_loading, writable_tensors )
254
-
255
- def init_tensors(self, lazyTensors = True, writable_tensors = True):
256
- if self.sd is None:
257
- self.lazy_loading = lazyTensors
258
- if lazyTensors:
259
- self.sd = self.create_tensors_with_mmap(writable_tensors)
260
- else:
261
- self.sd = self.create_tensors_without_mmap()
262
- # else:
263
- # if not self.lazy_loading and lazyTensors:
264
- # raise Exception("Every tensor should be either lazy loaded or not lazy loaded")
265
-
266
- return self.sd
267
-
268
-
269
- def create_tensors_with_mmap(self, writable_tensors = True):
270
-
271
- self.mtracker = MmapTracker(self._file_path)
272
- import mmap
273
-
274
- PAGE_SIZE = mmap.ALLOCATIONGRANULARITY
275
- MMAP_SIZE = 1024 * 1024 * 1024 # 1GB
276
- # MMAP_SIZE = 256 * 1024 * 1024 # 1GB
277
-
278
- # First pass: find optimal aligned map boundaries
279
- skip_bytes = self._skip_bytes
280
- tensor_map_indexes = []
281
- maps_info = []
282
- current_pos = skip_bytes
283
- current_map_start = (skip_bytes // PAGE_SIZE) * PAGE_SIZE
284
- current_map_size = skip_bytes - current_map_start
285
- idx = 0
286
- for k,v in self._catalog.items():
287
- data_offsets = v["data_offsets"]
288
- length = data_offsets[1]-data_offsets[0]
289
- if current_map_size + length > MMAP_SIZE:
290
- maps_info.append((current_map_start, current_map_size))
291
- current_map_start = (current_pos // PAGE_SIZE) * PAGE_SIZE
292
- current_map_size = current_pos - current_map_start
293
- idx += 1
294
- tensor_map_indexes.append(idx)
295
- current_map_size += length
296
- current_pos += length
297
-
298
- maps_info.append((current_map_start, current_map_size))
299
-
300
- # Second pass: create maps and tensors
301
- maps = []
302
- sd = OrderedDict()
303
-
304
- current_pos = skip_bytes
305
- with open(self._file_path, 'rb') as f:
306
- i = 0
307
- for map_start, map_size in maps_info:
308
- mm = mmap.mmap(f.fileno(), map_size, offset=map_start, access= mmap.ACCESS_COPY if writable_tensors else mmap.ACCESS_READ)
309
- maps.append((mm, map_start, map_size))
310
- self.mtracker.register(mm, i, map_start, map_size)
311
- i = i+ 1
312
-
313
- iter_tensor_no = iter(tensor_map_indexes)
314
- for k,v in self._catalog.items():
315
- dtypestr = v["dtype"]
316
- dtype= _map_to_dtype[dtypestr]
317
- shape = v["shape"]
318
- data_offsets = v["data_offsets"]
319
- length = data_offsets[1]-data_offsets[0]
320
- map_idx = next(iter_tensor_no)
321
- offset = current_pos - maps[map_idx][1]
322
- if length == 0:
323
- t = torch.empty(shape, dtype=dtype)
324
- elif len(shape) == 0:
325
- # don't waste a memory view for a scalar
326
- t = torch.frombuffer(bytearray(maps[map_idx][0][offset:offset + length]), dtype=torch.uint8)
327
- t = t.view(dtype)
328
- else:
329
- mv = memoryview(maps[map_idx][0])[offset:offset + length]
330
- t = torch.frombuffer(mv, dtype=dtype)
331
- t = torch.reshape(t, shape)
332
- # t._mmap = maps[map_idx][0]
333
- sd[k] = t
334
- current_pos += length
335
-
336
- return sd
337
-
338
-
339
- def create_tensors_without_mmap(self):
340
- sd = OrderedDict()
341
-
342
- with open(self._file_path, 'rb') as f:
343
- f.seek(self._skip_bytes, 0)
344
- for k,v in self._catalog.items():
345
- dtypestr = v["dtype"]
346
- dtype= _map_to_dtype[dtypestr]
347
- shape = v["shape"]
348
- data_offsets = v["data_offsets"]
349
- length = data_offsets[1]-data_offsets[0]
350
- buffer = f.read(length)
351
- if length == 0:
352
- t = torch.empty(0, dtype=dtype)
353
- elif len(shape) == 0:
354
- t = torch.frombuffer(bytearray(buffer), dtype=torch.uint8)
355
- t = t.view(dtype)
356
- else:
357
- t = torch.frombuffer(bytearray(buffer), dtype=dtype)
358
- t = torch.reshape(t, shape)
359
- sd[k] = t
360
- return sd
361
-
362
- def get_tensor(self, name: str) -> torch.tensor:
363
- """Get a tensor by name"""
364
- # To do : switch to a JIT tensor creation per tensor
365
- self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
366
- return self.sd[name]
367
-
368
- def keys(self) -> List[str]:
369
- """Get list of tensor names"""
370
- if self._keys is None:
371
- self._keys = list(self._catalog)
372
- return self._keys
373
-
374
- def names(self) -> List[str]:
375
- """Alias for keys()"""
376
- return self.keys()
377
-
378
- def tensors(self) -> Dict[str, torch.tensor]:
379
- """Get dictionary of all tensors"""
380
- self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
381
- return self.sd
382
-
383
- def metadata(self) -> Optional[Dict[str, str]]:
384
- """Get metadata dictionary"""
385
- return self._metadata
386
-
387
- def __len__(self) -> int:
388
- """Get number of tensors"""
389
- self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
390
- return len(self.keys())
391
-
392
- def __contains__(self, key: str) -> bool:
393
- """Check if tensor exists"""
394
- return key in self.keys()
395
-
396
- def __iter__(self) -> Iterator[Tuple[str, torch.tensor ]]:
397
- """Iterate over (name, tensor) pairs"""
398
- return ((name, self.get_tensor(name)) for name in self.keys())
399
-
400
- def _free_resources(self):
401
- del self.sd
402
- del self._catalog
403
-
404
- class _SafeTensorLoader:
405
- """Context manager for loading SafeTensorFile"""
406
-
407
- def __init__(self, filename: str, writable_tensors = True ):
408
- self.filename = Path(filename)
409
- self.writable_tensors = writable_tensors
410
- self.sft = None
411
- if not self.filename.exists():
412
- raise FileNotFoundError(f"File not found: {filename}")
413
-
414
- def __enter__(self) -> SafeTensorFile:
415
- """Open file and return SafeTensorFile instance"""
416
- writable_tensors = self.writable_tensors
417
-
418
- if all_tensors_are_read_only:
419
- writable_tensors = False
420
-
421
- try:
422
- self.sft = SafeTensorFile.load_metadata(self.filename, writable_tensors= writable_tensors)
423
- return self.sft
424
-
425
- except Exception as e:
426
- self.close()
427
- raise Exception(f"Failed to load safetensors file: {e}") from e
428
-
429
- def __exit__(self, exc_type, exc_val, exc_tb) -> None:
430
- """Clean up resources"""
431
- self.close()
432
-
433
- def close(self) -> None:
434
- if self.sft != None:
435
- self.sft._free_resources()
436
- pass
437
-
438
-
439
- def safe_open(filename: str, framework: str = "pt",device = "cpu", writable_tensors = True) -> _SafeTensorLoader:
440
- if device != "cpu" or framework !="pt":
441
- return _old_safe_open(filename =filename, framework=framework, device=device)
442
- return _SafeTensorLoader(filename, writable_tensors = writable_tensors)
443
-
444
- def torch_load_file( filename, device = 'cpu', writable_tensors = True) -> Dict[str, torch.Tensor]:
445
- sd = {}
446
- with safe_open(filename, framework="pt", device = device, writable_tensors =writable_tensors ) as f:
447
- for k in f.keys():
448
- sd[k] = f.get_tensor(k)
449
- return sd
450
-
451
- _old_torch_load_file = safetensors.torch.load_file
452
- safetensors.torch.load_file = torch_load_file
453
- _old_safe_open = safetensors.safe_open
454
- safetensors.safe_open = safe_open
455
- accelerate.utils.modeling.safe_open = safe_open
456
- accelerate.utils.modeling.safe_load_file = torch_load_file
457
- try:
458
- import transformers
459
- transformers.modeling_utils.safe_open = safe_open
460
- transformers.modeling_utils.safe_load_file = torch_load_file
461
- except:
1
+ # ------------------ Safetensors2 1.1 by DeepBeepMeep (mmgp)------------------
2
+ #
3
+ # This module entirely written in Python is a replacement for the safetensor library which requires much less RAM to load models.
4
+ # It can be conveniently used to keep a low RAM consumption when handling transit data (for instance when quantizing or transferring tensors to reserver RAM)
5
+ # You are free to use my module for non commercial use as long you give me proper credits. You may contact me on twitter @deepbeepmeep
6
+
7
+
8
+ from typing import Optional, Dict, List, Iterator, Tuple
9
+ from pathlib import Path
10
+ import torch
11
+ import mmap
12
+ import struct
13
+ import json
14
+ import base64
15
+ import safetensors
16
+ import accelerate
17
+ import os
18
+ from collections import OrderedDict
19
+ import warnings
20
+
21
+ warnings.filterwarnings("ignore", ".*The given buffer is not writable, and PyTorch does not support non-writable tensors*")
22
+
23
+ _old_torch_load_file = None
24
+ _old_safe_open = None
25
+
26
+ all_tensors_are_read_only = False
27
+
28
+ mmm = {}
29
+ verboseLevel = 1
30
+
31
+ import weakref
32
+
33
+ _map_to_dtype = { 'BF16': torch.bfloat16, 'U8': torch.uint8 , 'U16': torch.uint16, 'U32' : torch.uint32 , 'U64' : torch.uint64,
34
+ 'I8': torch.int8, 'I16': torch.int16, 'I32' : torch.int32 , 'I64' : torch.int64,
35
+ 'F64' : torch.float64, 'F32': torch.float32, 'F16': torch.float16, 'BOOL' : torch.bool, "F8_E5M2" : torch.float8_e5m2, "F8_E4M3" : torch.float8_e4m3fn }
36
+
37
+
38
+ class MmapTracker:
39
+ def __init__(self, file_path):
40
+ self._maps = {}
41
+ self._already_released = 0
42
+ from pathlib import Path
43
+ s = Path(file_path).parts
44
+ if len(s)>2:
45
+ s = s[-2:]
46
+ file_path = os.path.join(*s)
47
+ self.file_path = file_path # os.path.abspath(file_path)
48
+ self.count = 0
49
+ mmm[file_path] = self
50
+
51
+ def register(self, mmap_obj, map_id, start, size):
52
+
53
+ self.count += 1
54
+ def finalizer(ref):
55
+ self._already_released += 1
56
+ if verboseLevel >=2:
57
+ if self.count == self._already_released:
58
+ text =" (all the mmaps have been released)"
59
+ else:
60
+ text =f" ({self.count-self._already_released:} left)"
61
+
62
+ print(f"MMap Manager of file '{self.file_path}' : MMap no {map_id} has been released" + text)
63
+ if self.count == self._already_released:
64
+ del mmm[self.file_path]
65
+
66
+ self._maps.pop(map_id, None)
67
+
68
+ wr = weakref.ref(mmap_obj, finalizer)
69
+ self._maps[map_id] = {
70
+ 'mmap' : wr,
71
+ 'start': start,
72
+ 'size': size,
73
+ 'end': start + size
74
+ }
75
+ return wr
76
+
77
+ def get_active_maps(self):
78
+ return dict(self._maps)
79
+
80
+ class tensor_slice:
81
+ catalog = None
82
+ value = None
83
+ name = None
84
+
85
+ def __init__(self, catalog, name, value):
86
+ self.catalog = catalog
87
+ self.value = value
88
+ self.name = name
89
+
90
+ def __getitem__(self, s):
91
+ return self.value[s]
92
+
93
+ def get_dtype(self):
94
+ return self.catalog[self.name]["dtype"]
95
+
96
+ def get_shape(self):
97
+ return self.catalog[self.name]["shape"]
98
+
99
+ class cached_metadata:
100
+ file_path = None
101
+ file_length = 0
102
+ file_date = None
103
+ catalog = None
104
+ metadata = None
105
+ skip_bytes = 0
106
+
107
+ def __init__(self, file_path, catalog, metadata, skip_bytes):
108
+ self.catalog = catalog
109
+ self.metadata = metadata
110
+ self.skip_bytes = skip_bytes
111
+ file_stats = os.stat(file_path)
112
+ self.file_path = os.path.abspath(file_path)
113
+ self.file_length = file_stats.st_size
114
+ self.file_date = file_stats.st_ctime
115
+
116
+ def get_metadata(self, file_path):
117
+ file_stats = os.stat(file_path)
118
+ file_length = file_stats.st_size
119
+ file_date = file_stats.st_ctime
120
+ file_path = os.path.abspath(file_path)
121
+ if self.file_path != file_path or self.file_length != file_length or self.file_date != file_date:
122
+ return None, None, None
123
+ return self.catalog, self.metadata, self.skip_bytes
124
+
125
+ _cached_entry = None # ideally we should create a dict of the last n entries but one entry covers most cases
126
+
127
+ def _parse_metadata(metadata):
128
+ if metadata == None:
129
+ return None
130
+
131
+ new_metadata= {}
132
+
133
+ for k,v in metadata.items():
134
+ if k.endswith("_base64"):
135
+ v_decoded = json.loads(base64.b64decode(v.encode('utf8')).decode('utf8'))
136
+ p = k.rfind("_")
137
+ new_k = k[:p]
138
+ new_metadata[new_k]= v_decoded
139
+ else:
140
+ new_metadata[k] = v
141
+
142
+ return new_metadata
143
+
144
+ def _read_safetensors_header(path, file):
145
+ global _cached_entry
146
+ length_of_header_bytes = file.read(8)
147
+ # Interpret the bytes as a little-endian unsigned 64-bit integer
148
+ length_of_header = struct.unpack('<Q', length_of_header_bytes)[0]
149
+
150
+ if _cached_entry != None:
151
+ catalog, metadata, _ = _cached_entry.get_metadata(path)
152
+ else:
153
+ catalog = None
154
+
155
+ if catalog == None:
156
+ header_bytes = file.read(length_of_header)
157
+ #catalog = json.loads(header_bytes.decode('utf-8'))
158
+ catalog = json.loads(header_bytes)
159
+ metadata = catalog.pop("__metadata__", None)
160
+ metadata = _parse_metadata(metadata)
161
+
162
+ _cached_entry = cached_metadata(path, catalog, metadata,length_of_header )
163
+ else:
164
+ file.seek(length_of_header, 1)
165
+
166
+ return catalog, metadata, length_of_header + 8
167
+
168
+
169
+ def torch_write_file(sd, file_path, quantization_map = None, config = None, extra_meta = None):
170
+ from collections import OrderedDict
171
+ sf_sd = OrderedDict()
172
+
173
+ map = { torch.bfloat16 : 'BF16' , torch.int64 : 'I64' , torch.int32 : 'I32' , torch.int16 : 'I16' , torch.int8 : 'I8' ,
174
+ torch.uint64 : 'U64' , torch.uint32 : 'U32' , torch.uint16 : 'U16' , torch.uint8 : 'U8' ,
175
+ torch.bool : 'BOOL' , torch.float64 : 'F64' , torch.float32 : 'F32' , torch.float16 : 'F16', torch.float8_e5m2 : "F8_E5M2", torch.float8_e4m3fn: "F8_E4M3" }
176
+ pos = 0
177
+ i = 0
178
+ mx = 100000
179
+ metadata = dict()
180
+ for k , t in sd.items():
181
+ if torch.is_tensor(t):
182
+ entry = {}
183
+ dtypestr= map[t.dtype]
184
+ entry["dtype"] = dtypestr
185
+ entry["shape"] = list(t.shape)
186
+ size = torch.numel(t) * t.element_size()
187
+ if size == 0:
188
+ pass
189
+ entry["data_offsets"] = [pos, pos + size]
190
+ pos += size
191
+ sf_sd[k] = entry
192
+ else:
193
+ if isinstance(t, str):
194
+ metadata[k] = t
195
+ else:
196
+ try:
197
+ b64 = base64.b64encode(json.dumps(t, ensure_ascii=False).encode('utf8')).decode('utf8')
198
+ metadata[k + "_base64"] = b64
199
+ except:
200
+ pass
201
+
202
+ i+=1
203
+ if i==mx:
204
+ break
205
+ if not quantization_map is None:
206
+ metadata["quantization_format"] = "quanto"
207
+ metadata["quantization_map_base64"] = base64.b64encode(json.dumps(quantization_map, ensure_ascii=False).encode('utf8')).decode('utf8')
208
+
209
+ if not config is None:
210
+ metadata["config_base64"] = base64.b64encode(json.dumps(config, ensure_ascii=False).encode('utf8')).decode('utf8')
211
+
212
+ if not extra_meta is None:
213
+ for n , m in extra_meta.items():
214
+ if isinstance(m, str):
215
+ metadata[n] = m
216
+ else:
217
+ metadata[n + "_base64"] = base64.b64encode(json.dumps(m, ensure_ascii=False).encode('utf8')).decode('utf8')
218
+
219
+
220
+ if len(metadata) > 0:
221
+ sf_sd["__metadata__"] = metadata
222
+
223
+ header_bytes = json.dumps(sf_sd).encode()
224
+ #header_bytes =json.dumps(config, ensure_ascii=False).encode('utf8')
225
+ size_header = len(header_bytes)
226
+ import struct
227
+
228
+ length_of_header_bytes = struct.pack('<Q', size_header)
229
+
230
+ with open(file_path, "wb") as writer:
231
+ bytes_written = writer.write(length_of_header_bytes)
232
+ bytes_written = writer.write(header_bytes)
233
+
234
+ i = 0
235
+ for k , t in sd.items():
236
+ if torch.is_tensor(t):
237
+ size = torch.numel(t) * t.element_size()
238
+ if size != 0:
239
+ dtype = t.dtype
240
+ # convert in a friendly format, scalars types not supported by numpy
241
+ if dtype == torch.bfloat16:
242
+ t = t.view(torch.uint16)
243
+ elif dtype == torch.float8_e5m2 or dtype == torch.float8_e4m3fn:
244
+ t = t.view(torch.uint8)
245
+ buffer = t.numpy().tobytes()
246
+ bytes_written = writer.write(buffer)
247
+ assert bytes_written == size
248
+ i+=1
249
+ if i==mx:
250
+ break
251
+
252
+ class SafeTensorFile:
253
+ """Main class for accessing safetensors files that provides memory-efficient access"""
254
+
255
+ def __init__(self, file_path, metadata, catalog, skip_bytes, lazy_loading = True, writable_tensors = True):
256
+ self._file_path = file_path
257
+ self._metadata = metadata
258
+ self._catalog = catalog
259
+ self._skip_bytes = skip_bytes
260
+ self._keys = None
261
+ self.sd = None
262
+ self.mtracker = None
263
+ self.lazy_loading = lazy_loading
264
+ self.writable_tensors = writable_tensors
265
+
266
+ @classmethod
267
+ def load_metadata(cls, file_path, lazy_loading = True, writable_tensors = True):
268
+ with open(file_path, 'rb') as f:
269
+ catalog, metadata, skip_bytes = _read_safetensors_header(file_path, f)
270
+
271
+ return cls(file_path, metadata, catalog, skip_bytes, lazy_loading, writable_tensors )
272
+
273
+ def init_tensors(self, lazyTensors = True, writable_tensors = True):
274
+ if self.sd is None:
275
+ self.lazy_loading = lazyTensors
276
+ if lazyTensors:
277
+ self.sd = self.create_tensors_with_mmap(writable_tensors)
278
+ else:
279
+ self.sd = self.create_tensors_without_mmap()
280
+ # else:
281
+ # if not self.lazy_loading and lazyTensors:
282
+ # raise Exception("Every tensor should be either lazy loaded or not lazy loaded")
283
+
284
+ return self.sd
285
+
286
+
287
+ def create_tensors_with_mmap(self, writable_tensors = True):
288
+
289
+ self.mtracker = MmapTracker(self._file_path)
290
+ import mmap
291
+
292
+ PAGE_SIZE = mmap.ALLOCATIONGRANULARITY
293
+ MMAP_SIZE = 1024 * 1024 * 1024 # 1GB
294
+ # MMAP_SIZE = 256 * 1024 * 1024 # 1GB
295
+
296
+ # First pass: find optimal aligned map boundaries
297
+ skip_bytes = self._skip_bytes
298
+ tensor_map_indexes = []
299
+ maps_info = []
300
+ current_pos = skip_bytes
301
+ current_map_start = (skip_bytes // PAGE_SIZE) * PAGE_SIZE
302
+ current_map_size = skip_bytes - current_map_start
303
+ idx = 0
304
+ for k,v in self._catalog.items():
305
+ data_offsets = v["data_offsets"]
306
+ length = data_offsets[1]-data_offsets[0]
307
+ if current_map_size + length > MMAP_SIZE:
308
+ maps_info.append((current_map_start, current_map_size))
309
+ current_map_start = (current_pos // PAGE_SIZE) * PAGE_SIZE
310
+ current_map_size = current_pos - current_map_start
311
+ idx += 1
312
+ tensor_map_indexes.append(idx)
313
+ current_map_size += length
314
+ current_pos += length
315
+
316
+ maps_info.append((current_map_start, current_map_size))
317
+
318
+ # Second pass: create maps and tensors
319
+ maps = []
320
+ sd = OrderedDict()
321
+
322
+ current_pos = skip_bytes
323
+ with open(self._file_path, 'rb') as f:
324
+ i = 0
325
+ for map_start, map_size in maps_info:
326
+ mm = mmap.mmap(f.fileno(), map_size, offset=map_start, access= mmap.ACCESS_COPY if writable_tensors else mmap.ACCESS_READ)
327
+ maps.append((mm, map_start, map_size))
328
+ self.mtracker.register(mm, i, map_start, map_size)
329
+ i = i+ 1
330
+
331
+ iter_tensor_no = iter(tensor_map_indexes)
332
+ for k,v in self._catalog.items():
333
+ dtypestr = v["dtype"]
334
+ dtype= _map_to_dtype[dtypestr]
335
+ shape = v["shape"]
336
+ data_offsets = v["data_offsets"]
337
+ length = data_offsets[1]-data_offsets[0]
338
+ map_idx = next(iter_tensor_no)
339
+ offset = current_pos - maps[map_idx][1]
340
+ if length == 0:
341
+ t = torch.empty(shape, dtype=dtype)
342
+ elif len(shape) == 0:
343
+ # don't waste a memory view for a scalar
344
+ t = torch.frombuffer(bytearray(maps[map_idx][0][offset:offset + length]), dtype=torch.uint8)
345
+ t = t.view(dtype)
346
+ else:
347
+ mv = memoryview(maps[map_idx][0])[offset:offset + length]
348
+ t = torch.frombuffer(mv, dtype=dtype)
349
+ t = torch.reshape(t, shape)
350
+ # t._mmap = maps[map_idx][0]
351
+ sd[k] = t
352
+ current_pos += length
353
+
354
+ return sd
355
+
356
+
357
+ def create_tensors_without_mmap(self):
358
+ sd = OrderedDict()
359
+
360
+ with open(self._file_path, 'rb') as f:
361
+ f.seek(self._skip_bytes, 0)
362
+ for k,v in self._catalog.items():
363
+ dtypestr = v["dtype"]
364
+ dtype= _map_to_dtype[dtypestr]
365
+ shape = v["shape"]
366
+ data_offsets = v["data_offsets"]
367
+ length = data_offsets[1]-data_offsets[0]
368
+ buffer = f.read(length)
369
+ if length == 0:
370
+ t = torch.empty(0, dtype=dtype)
371
+ elif len(shape) == 0:
372
+ t = torch.frombuffer(bytearray(buffer), dtype=torch.uint8)
373
+ t = t.view(dtype)
374
+ else:
375
+ t = torch.frombuffer(bytearray(buffer), dtype=dtype)
376
+ t = torch.reshape(t, shape)
377
+ sd[k] = t
378
+ return sd
379
+
380
+ def get_slice(self, name: str) -> torch.tensor:
381
+ return tensor_slice(self._catalog, name, self.get_tensor(name))
382
+
383
+ def get_tensor(self, name: str) -> torch.tensor:
384
+ """Get a tensor by name"""
385
+ # To do : switch to a JIT tensor creation per tensor
386
+ self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
387
+ return self.sd[name]
388
+
389
+ def keys(self) -> List[str]:
390
+ """Get list of tensor names"""
391
+ if self._keys is None:
392
+ self._keys = list(self._catalog)
393
+ return self._keys
394
+
395
+ def names(self) -> List[str]:
396
+ """Alias for keys()"""
397
+ return self.keys()
398
+
399
+ def tensors(self) -> Dict[str, torch.tensor]:
400
+ """Get dictionary of all tensors"""
401
+ self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
402
+ return self.sd
403
+
404
+ def metadata(self) -> Optional[Dict[str, str]]:
405
+ """Get metadata dictionary"""
406
+ return self._metadata
407
+
408
+ def __len__(self) -> int:
409
+ """Get number of tensors"""
410
+ self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
411
+ return len(self.keys())
412
+
413
+ def __contains__(self, key: str) -> bool:
414
+ """Check if tensor exists"""
415
+ return key in self.keys()
416
+
417
+ def __iter__(self) -> Iterator[Tuple[str, torch.tensor ]]:
418
+ """Iterate over (name, tensor) pairs"""
419
+ return ((name, self.get_tensor(name)) for name in self.keys())
420
+
421
+ def _free_resources(self):
422
+ del self.sd
423
+ del self._catalog
424
+
425
+ class _SafeTensorLoader:
426
+ """Context manager for loading SafeTensorFile"""
427
+
428
+ def __init__(self, filename: str, writable_tensors = True ):
429
+ self.filename = Path(filename)
430
+ self.writable_tensors = writable_tensors
431
+ self.sft = None
432
+ if not self.filename.exists():
433
+ raise FileNotFoundError(f"File not found: {filename}")
434
+
435
+ def __enter__(self) -> SafeTensorFile:
436
+ """Open file and return SafeTensorFile instance"""
437
+ writable_tensors = self.writable_tensors
438
+
439
+ if all_tensors_are_read_only:
440
+ writable_tensors = False
441
+
442
+ try:
443
+ self.sft = SafeTensorFile.load_metadata(self.filename, writable_tensors= writable_tensors)
444
+ return self.sft
445
+
446
+ except Exception as e:
447
+ self.close()
448
+ raise Exception(f"Failed to load safetensors file: {e}") from e
449
+
450
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
451
+ """Clean up resources"""
452
+ self.close()
453
+
454
+ def close(self) -> None:
455
+ if self.sft != None:
456
+ self.sft._free_resources()
457
+ pass
458
+
459
+
460
+ def safe_open(filename: str, framework: str = "pt",device = "cpu", writable_tensors = True) -> _SafeTensorLoader:
461
+ if device != "cpu" or framework !="pt":
462
+ return _old_safe_open(filename =filename, framework=framework, device=device)
463
+ return _SafeTensorLoader(filename, writable_tensors = writable_tensors)
464
+
465
+ def torch_load_file( filename, device = 'cpu', writable_tensors = True) -> Dict[str, torch.Tensor]:
466
+ sd = {}
467
+ with safe_open(filename, framework="pt", device = device, writable_tensors =writable_tensors ) as f:
468
+ for k in f.keys():
469
+ sd[k] = f.get_tensor(k)
470
+ return sd
471
+
472
+ _old_torch_load_file = safetensors.torch.load_file
473
+ safetensors.torch.load_file = torch_load_file
474
+ _old_safe_open = safetensors.safe_open
475
+ safetensors.safe_open = safe_open
476
+ accelerate.utils.modeling.safe_open = safe_open
477
+ accelerate.utils.modeling.safe_load_file = torch_load_file
478
+ try:
479
+ import transformers
480
+ transformers.modeling_utils.safe_open = safe_open
481
+ transformers.modeling_utils.safe_load_file = torch_load_file
482
+ except:
462
483
  pass
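
For context on how this module keeps RAM usage low: it parses the safetensors header itself, memory-maps the tensor data in page-aligned chunks of roughly 1 GB (ACCESS_COPY when writable tensors are requested, ACCESS_READ otherwise), and builds torch tensors with torch.frombuffer() over those maps, so bytes are paged in only when a tensor is actually touched. The sketch below is a self-contained illustration of the header layout that _read_safetensors_header() consumes; it is not part of the package:

import json
import struct

def read_header(path):
    # A safetensors file starts with an 8-byte little-endian unsigned integer
    # giving the length of a JSON header. The header maps tensor names to
    # {"dtype", "shape", "data_offsets"} and may carry an "__metadata__" dict.
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        catalog = json.loads(f.read(header_len))
    metadata = catalog.pop("__metadata__", None)
    # Tensor bytes begin at offset 8 + header_len; each entry's data_offsets
    # are relative to that point.
    return catalog, metadata, 8 + header_len

Calling read_header("model.safetensors") (hypothetical file name) returns the same catalog/metadata/offset triple that SafeTensorFile.load_metadata() builds on; the module additionally decodes any *_base64 metadata entries via _parse_metadata().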