mmgp 3.3.1__py3-none-any.whl → 3.6.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
mmgp/safetensors2.py CHANGED
@@ -1,454 +1,538 @@
-# ------------------ Safetensors2 1.0 by DeepBeepMeep (mmgp)------------------
-#
-# This module entirely written in Python is a replacement for the safetensor library which requires much less RAM to load models.
-# It can be conveniently used to keep a low RAM consumption when handling transit data (for instance when quantizing or transferring tensors to reserver RAM)
-# You are free to use my module for non commercial use as long you give me proper credits. You may contact me on twitter @deepbeepmeep
-
-
-from typing import Optional, Dict, List, Iterator, Tuple
-from pathlib import Path
-import torch
-import mmap
-import struct
-import json
-import base64
-import safetensors
-import accelerate
-import os
-from collections import OrderedDict
-
-
-_old_torch_load_file = None
-_old_safe_open = None
-
-
-
-mmm = {}
-verboseLevel = 1
-
-import weakref
-
-_map_to_dtype = { 'BF16': torch.bfloat16, 'U8': torch.uint8 , 'U16': torch.uint16, 'U32' : torch.uint32 , 'U64' : torch.uint64,
-                  'I8': torch.int8, 'I16': torch.int16, 'I32' : torch.int32 , 'I64' : torch.int64,
-                  'F64' : torch.float64, 'F32': torch.float32, 'F16': torch.float16, 'BOOL' : torch.bool, "F8_E5M2" : torch.float8_e5m2, "F8_E4M3" : torch.float8_e4m3fn }
-
-
-class MmapTracker:
-    def __init__(self, file_path):
-        self._maps = {}
-        self._already_released = 0
-        from pathlib import Path
-        s = Path(file_path).parts
-        if len(s)>2:
-            s = s[-2:]
-            file_path = os.path.join(*s)
-        self.file_path = file_path # os.path.abspath(file_path)
-        self.count = 0
-        mmm[file_path] = self
-
-    def register(self, mmap_obj, map_id, start, size):
-
-        self.count += 1
-        def finalizer(ref):
-            self._already_released += 1
-            if verboseLevel >=2:
-                if self.count == self._already_released:
-                    text =" (all the mmaps have been released)"
-                else:
-                    text =f" ({self.count-self._already_released:} left)"
-
-                print(f"MMap Manager of file '{self.file_path}' : MMap no {map_id} has been released" + text)
-            if self.count == self._already_released:
-                del mmm[self.file_path]
-
-            self._maps.pop(map_id, None)
-
-        wr = weakref.ref(mmap_obj, finalizer)
-        self._maps[map_id] = {
-            'mmap' : wr,
-            'start': start,
-            'size': size,
-            'end': start + size
-        }
-        return wr
-
-    def get_active_maps(self):
-        return dict(self._maps)
-
-
-class cached_metadata:
-    file_path = None
-    file_length = 0
-    file_date = None
-    catalog = None
-    metadata = None
-    skip_bytes = 0
-
-    def __init__(self, file_path, catalog, metadata, skip_bytes):
-        self.catalog = catalog
-        self.metadata = metadata
-        self.skip_bytes = skip_bytes
-        file_stats = os.stat(file_path)
-        self.file_path = os.path.abspath(file_path)
-        self.file_length = file_stats.st_size
-        self.file_date = file_stats.st_ctime
-
-    def get_metadata(self, file_path):
-        file_stats = os.stat(file_path)
-        file_length = file_stats.st_size
-        file_date = file_stats.st_ctime
-        file_path = os.path.abspath(file_path)
-        if self.file_path != file_path or self.file_length != file_length or self.file_date != file_date:
-            return None, None, None
-        return self.catalog, self.metadata, self.skip_bytes
-
-_cached_entry = None # ideally we should create a dict of the last n entries but one entry covers most cases
-
-def _parse_metadata(metadata):
-    if metadata == None:
-        return None
-
-    new_metadata= {}
-
-    for k,v in metadata.items():
-        if k.endswith("_base64"):
-            v_decoded = json.loads(base64.b64decode(v.encode('utf8')).decode('utf8'))
-            p = k.rfind("_")
-            new_k = k[:p]
-            new_metadata[new_k]= v_decoded
-        else:
-            new_metadata[k] = v
-
-    return new_metadata
-
+# ------------------ Safetensors2 1.3 by DeepBeepMeep (mmgp)------------------
+#
+# This module entirely written in Python is a replacement for the safetensor library which requires much less RAM to load models.
+# It can be conveniently used to keep a low RAM consumption when handling transit data (for instance when quantizing or transferring tensors to reserver RAM)
+# You are free to use my module for non commercial use as long you give me proper credits. You may contact me on twitter @deepbeepmeep
+
+
+from typing import Optional, Dict, List, Iterator, Tuple
+from pathlib import Path
+import torch
+import mmap
+import struct
+import json
+import base64
+import safetensors
+import accelerate
+import os
+from collections import OrderedDict
+import warnings
+
+warnings.filterwarnings("ignore", ".*The given buffer is not writable, and PyTorch does not support non-writable tensors*")
+
+_old_torch_load_file = None
+_old_safe_open = None
+
+all_tensors_are_read_only = False
+
+mmm = {}
+verboseLevel = 1
+
+import weakref
+
+_map_to_dtype = { 'BF16': torch.bfloat16, 'U8': torch.uint8 , 'U16': torch.uint16, 'U32' : torch.uint32 , 'U64' : torch.uint64,
+                  'I8': torch.int8, 'I16': torch.int16, 'I32' : torch.int32 , 'I64' : torch.int64,
+                  'F64' : torch.float64, 'F32': torch.float32, 'F16': torch.float16, 'BOOL' : torch.bool, "F8_E5M2" : torch.float8_e5m2, "F8_E4M3" : torch.float8_e4m3fn }
+
+
+class MmapTracker:
+    def __init__(self, file_path):
+        self._maps = {}
+        self._already_released = 0
+        from pathlib import Path
+        s = Path(file_path).parts
+        if len(s)>2:
+            s = s[-2:]
+            file_path = os.path.join(*s)
+        self.file_path = file_path # os.path.abspath(file_path)
+        self.count = 0
+        key = file_path
+        i = 1
+        while True:
+            if key not in mmm:
+                mmm[key] = self
+                break
+            i +=1
+            key = key + "#" + str(i)
+        self.mmm_key = key
+        # print(f"MMAP Add: {file_path}: {mmm.keys()}")
+
+    def register(self, mmap_obj, map_id, start, size):
+
+        self.count += 1
+        def finalizer(ref):
+            self._already_released += 1
+            if verboseLevel >=2:
+                if self.count == self._already_released:
+                    text =" (all the mmaps have been released)"
+                else:
+                    text =f" ({self.count-self._already_released:} left)"
+
+                print(f"MMap Manager of file '{self.file_path}' : MMap no {map_id} has been released" + text)
+            if self.count == self._already_released:
+                # print(f"MMAP Del: {self.file_path}: {mmm.keys()}")
+                del mmm[self.mmm_key ]
+
+            self._maps.pop(map_id, None)
+
+        wr = weakref.ref(mmap_obj, finalizer)
+        self._maps[map_id] = {
+            'mmap' : wr,
+            'start': start,
+            'size': size,
+            'end': start + size
+        }
+        return wr
+
+    def get_active_maps(self):
+        return dict(self._maps)
+
+class tensor_slice:
+    catalog = None
+    value = None
+    name = None
+
+    def __init__(self, catalog, name, value):
+        self.catalog = catalog
+        self.value = value
+        self.name = name
+
+    def __getitem__(self, s):
+        return self.value[s]
+
+    def get_dtype(self):
+        return self.catalog[self.name]["dtype"]
+
+    def get_shape(self):
+        return self.catalog[self.name]["shape"]
+
+class tensor_stub:
+    dtype = None
+    shape = None
+
+    def __init__(self, dtype, shape):
+        self.dtype = dtype
+        self.shape = tuple(shape)
+
+    @property
+    def ndim(self):
+        return len(self.shape)
+
+    def numel(self):
+        if not self.shape:
+            return 1
+        n = 1
+        for dim in self.shape:
+            n *= int(dim)
+        return n
+
+    @property
+    def device(self):
+        return torch.device("cpu")
+
+class cached_metadata:
+    file_path = None
+    file_length = 0
+    file_date = None
+    catalog = None
+    metadata = None
+    skip_bytes = 0
+
+    def __init__(self, file_path, catalog, metadata, skip_bytes):
+        self.catalog = catalog
+        self.metadata = metadata
+        self.skip_bytes = skip_bytes
+        file_stats = os.stat(file_path)
+        self.file_path = os.path.abspath(file_path)
+        self.file_length = file_stats.st_size
+        self.file_date = file_stats.st_ctime
+
+    def get_metadata(self, file_path):
+        file_stats = os.stat(file_path)
+        file_length = file_stats.st_size
+        file_date = file_stats.st_ctime
+        file_path = os.path.abspath(file_path)
+        if self.file_path != file_path or self.file_length != file_length or self.file_date != file_date:
+            return None, None, None
+        return self.catalog, self.metadata, self.skip_bytes
+
+_cached_entry = None # ideally we should create a dict of the last n entries but one entry covers most cases
+
+def _parse_metadata(metadata):
+    new_metadata= {}
+    if metadata != None:
+        for k,v in metadata.items():
+            if k.endswith("_base64"):
+                v_decoded = json.loads(base64.b64decode(v.encode('utf8')).decode('utf8'))
+                p = k.rfind("_")
+                new_k = k[:p]
+                new_metadata[new_k]= v_decoded
+            else:
+                new_metadata[k] = v
+    if "format" not in new_metadata:
+        new_metadata["format"] = "pt"
+    return new_metadata
+
 def _read_safetensors_header(path, file):
-    global _cached_entry
-    length_of_header_bytes = file.read(8)
-    # Interpret the bytes as a little-endian unsigned 64-bit integer
-    length_of_header = struct.unpack('<Q', length_of_header_bytes)[0]
-
-    if _cached_entry != None:
-        catalog, metadata, _ = _cached_entry.get_metadata(path)
-    else:
-        catalog = None
-
-    if catalog == None:
-        header_bytes = file.read(length_of_header)
-        #catalog = json.loads(header_bytes.decode('utf-8'))
-        catalog = json.loads(header_bytes)
-        metadata = catalog.pop("__metadata__", None)
-        metadata = _parse_metadata(metadata)
-
-        _cached_entry = cached_metadata(path, catalog, metadata,length_of_header )
-    else:
-        file.seek(length_of_header, 1)
-
+    global _cached_entry
+    length_of_header_bytes = file.read(8)
+    # Interpret the bytes as a little-endian unsigned 64-bit integer
+    length_of_header = struct.unpack('<Q', length_of_header_bytes)[0]
+
+    if _cached_entry != None:
+        catalog, metadata, _ = _cached_entry.get_metadata(path)
+    else:
+        catalog = None
+
+    if catalog == None:
+        header_bytes = file.read(length_of_header)
+        #catalog = json.loads(header_bytes.decode('utf-8'))
+        catalog = json.loads(header_bytes)
+        metadata = catalog.pop("__metadata__", None)
+        metadata = _parse_metadata(metadata)
+
+        _cached_entry = cached_metadata(path, catalog, metadata,length_of_header )
+    else:
+        file.seek(length_of_header, 1)
+
     return catalog, metadata, length_of_header + 8
 
-
-def torch_write_file(sd, file_path, quantization_map = None, config = None, extra_meta = None):
-    from collections import OrderedDict
-    sf_sd = OrderedDict()
-
-    map = { torch.bfloat16 : 'BF16' , torch.int64 : 'I64' , torch.int32 : 'I32' , torch.int16 : 'I16' , torch.int8 : 'I8' ,
-            torch.uint64 : 'U64' , torch.uint32 : 'U32' , torch.uint16 : 'U16' , torch.uint8 : 'U8' ,
-            torch.bool : 'BOOL' , torch.float64 : 'F64' , torch.float32 : 'F32' , torch.float16 : 'F16', torch.float8_e5m2 : "F8_E5M2", torch.float8_e4m3fn: "F8_E4M3" }
-    pos = 0
-    i = 0
-    mx = 100000
-    metadata = dict()
-    for k , t in sd.items():
-        if torch.is_tensor(t):
-            entry = {}
-            dtypestr= map[t.dtype]
-            entry["dtype"] = dtypestr
-            entry["shape"] = list(t.shape)
-            size = torch.numel(t) * t.element_size()
-            if size == 0:
-                pass
-            entry["data_offsets"] = [pos, pos + size]
-            pos += size
-            sf_sd[k] = entry
-        else:
-            if isinstance(t, str):
-                metadata[k] = t
-            else:
-                try:
-                    b64 = base64.b64encode(json.dumps(t, ensure_ascii=False).encode('utf8')).decode('utf8')
-                    metadata[k + "_base64"] = b64
-                except:
-                    pass
-
-        i+=1
-        if i==mx:
-            break
-    if not quantization_map is None:
-        metadata["quantization_format"] = "quanto"
-        metadata["quantization_map_base64"] = base64.b64encode(json.dumps(quantization_map, ensure_ascii=False).encode('utf8')).decode('utf8')
-
-    if not config is None:
-        metadata["config_base64"] = base64.b64encode(json.dumps(config, ensure_ascii=False).encode('utf8')).decode('utf8')
-
-    if not extra_meta is None:
-        for n , m in extra_meta.items():
-            if isinstance(m, str):
-                metadata[n] = m
-            else:
-                metadata[n + "_base64"] = base64.b64encode(json.dumps(m, ensure_ascii=False).encode('utf8')).decode('utf8')
-
-
-    if len(metadata) > 0:
-        sf_sd["__metadata__"] = metadata
-
-    header_bytes = json.dumps(sf_sd).encode()
-    #header_bytes =json.dumps(config, ensure_ascii=False).encode('utf8')
-    size_header = len(header_bytes)
-    import struct
-
-    length_of_header_bytes = struct.pack('<Q', size_header)
-
-    with open(file_path, "wb") as writer:
-        bytes_written = writer.write(length_of_header_bytes)
-        bytes_written = writer.write(header_bytes)
-
-        i = 0
-        for k , t in sd.items():
-            if torch.is_tensor(t):
-                size = torch.numel(t) * t.element_size()
-                if size != 0:
-                    dtype = t.dtype
-                    # convert in a friendly format, scalars types not supported by numpy
-                    if dtype == torch.bfloat16:
-                        t = t.view(torch.uint16)
-                    elif dtype == torch.float8_e5m2 or dtype == torch.float8_e4m3fn:
-                        t = t.view(torch.uint8)
-                    buffer = t.numpy().tobytes()
-                    bytes_written = writer.write(buffer)
-                    assert bytes_written == size
-            i+=1
-            if i==mx:
-                break
-
-class SafeTensorFile:
-    """Main class for accessing safetensors files that provides memory-efficient access"""
-
-    def __init__(self, file_path, metadata, catalog, skip_bytes, lazy_loading = True):
-        self._file_path = file_path
-        self._metadata = metadata
-        self._catalog = catalog
-        self._skip_bytes = skip_bytes
-        self._keys = None
-        self.sd = None
-        self.mtracker = None
-        self.lazy_loading = lazy_loading
-
-    @classmethod
-    def load_metadata(cls, file_path, lazy_loading = True):
-        with open(file_path, 'rb') as f:
-            catalog, metadata, skip_bytes = _read_safetensors_header(file_path, f)
-
-        return cls(file_path, metadata, catalog, skip_bytes, lazy_loading)
-
-    def init_tensors(self, lazyTensors = True):
-        if self.sd is None:
-            self.lazy_loading = lazyTensors
-            if lazyTensors:
-                self.sd = self.create_tensors_with_mmap()
-            else:
-                self.sd = self.create_tensors_without_mmap()
-        # else:
-        #     if not self.lazy_loading and lazyTensors:
-        #         raise Exception("Every tensor should be either lazy loaded or not lazy loaded")
-
-        return self.sd
-
-
-    def create_tensors_with_mmap(self):
-
-        self.mtracker = MmapTracker(self._file_path)
-        import mmap
-
-        PAGE_SIZE = mmap.ALLOCATIONGRANULARITY
-        MMAP_SIZE = 1024 * 1024 * 1024 # 1GB
-        # MMAP_SIZE = 256 * 1024 * 1024 # 1GB
-
-        # First pass: find optimal aligned map boundaries
-        skip_bytes = self._skip_bytes
-        tensor_map_indexes = []
-        maps_info = []
-        current_pos = skip_bytes
-        current_map_start = (skip_bytes // PAGE_SIZE) * PAGE_SIZE
-        current_map_size = skip_bytes - current_map_start
-        idx = 0
-        for k,v in self._catalog.items():
-            data_offsets = v["data_offsets"]
-            length = data_offsets[1]-data_offsets[0]
-            if current_map_size + length > MMAP_SIZE:
-                maps_info.append((current_map_start, current_map_size))
-                current_map_start = (current_pos // PAGE_SIZE) * PAGE_SIZE
-                current_map_size = current_pos - current_map_start
-                idx += 1
-            tensor_map_indexes.append(idx)
-            current_map_size += length
-            current_pos += length
-
-        maps_info.append((current_map_start, current_map_size))
-
-        # Second pass: create maps and tensors
-        maps = []
-        sd = OrderedDict()
-
-        current_pos = skip_bytes
-        with open(self._file_path, 'rb') as f:
-            i = 0
-            for map_start, map_size in maps_info:
-                mm = mmap.mmap(f.fileno(), map_size, offset=map_start, access=mmap.ACCESS_COPY) #.ACCESS_READ
-                maps.append((mm, map_start, map_size))
-                self.mtracker.register(mm, i, map_start, map_size)
-                i = i+ 1
-
-        iter_tensor_no = iter(tensor_map_indexes)
-        for k,v in self._catalog.items():
-            dtypestr = v["dtype"]
-            dtype= _map_to_dtype[dtypestr]
-            shape = v["shape"]
-            data_offsets = v["data_offsets"]
-            length = data_offsets[1]-data_offsets[0]
-            map_idx = next(iter_tensor_no)
-            offset = current_pos - maps[map_idx][1]
-            if length == 0:
-                t = torch.empty(shape, dtype=dtype)
-            elif len(shape) == 0:
-                # don't waste a memory view for a scalar
-                t = torch.frombuffer(bytearray(maps[map_idx][0][offset:offset + length]), dtype=torch.uint8)
-                t = t.view(dtype)
-            else:
-                mv = memoryview(maps[map_idx][0])[offset:offset + length]
-                t = torch.frombuffer(mv, dtype=dtype)
-                t = torch.reshape(t, shape)
-            # t._mmap = maps[map_idx][0]
-            sd[k] = t
-            current_pos += length
-
-        return sd
-
-
-    def create_tensors_without_mmap(self):
-        sd = OrderedDict()
-
-        with open(self._file_path, 'rb') as f:
-            f.seek(self._skip_bytes, 0)
-            for k,v in self._catalog.items():
-                dtypestr = v["dtype"]
-                dtype= _map_to_dtype[dtypestr]
-                shape = v["shape"]
-                data_offsets = v["data_offsets"]
-                length = data_offsets[1]-data_offsets[0]
-                buffer = f.read(length)
-                if length == 0:
-                    t = torch.empty(0, dtype=dtype)
-                elif len(shape) == 0:
-                    t = torch.frombuffer(bytearray(buffer), dtype=torch.uint8)
-                    t = t.view(dtype)
-                else:
-                    t = torch.frombuffer(bytearray(buffer), dtype=dtype)
-                    t = torch.reshape(t, shape)
-                sd[k] = t
-        return sd
-
-    def get_tensor(self, name: str) -> torch.tensor:
-        """Get a tensor by name"""
-        # To do : switch to a JIT tensor creation per tensor
-        self.init_tensors()
-        return self.sd[name]
-
-    def keys(self) -> List[str]:
-        """Get list of tensor names"""
-        if self._keys is None:
-            self._keys = list(self._catalog)
-        return self._keys
-
-    def names(self) -> List[str]:
-        """Alias for keys()"""
-        return self.keys()
-
-    def tensors(self) -> Dict[str, torch.tensor]:
-        """Get dictionary of all tensors"""
-        self.init_tensors(self.lazy_loading)
-        return self.sd
-
-    def metadata(self) -> Optional[Dict[str, str]]:
-        """Get metadata dictionary"""
-        return self._metadata
-
-    def __len__(self) -> int:
-        """Get number of tensors"""
-        self.init_tensors(self.lazy_loading)
-        return len(self.keys())
-
-    def __contains__(self, key: str) -> bool:
-        """Check if tensor exists"""
-        return key in self.keys()
-
-    def __iter__(self) -> Iterator[Tuple[str, torch.tensor ]]:
-        """Iterate over (name, tensor) pairs"""
-        return ((name, self.get_tensor(name)) for name in self.keys())
-
-    def _free_resources(self):
-        del self.sd
-        del self._catalog
-
-class _SafeTensorLoader:
-    """Context manager for loading SafeTensorFile"""
-
-    def __init__(self, filename: str ):
-        self.filename = Path(filename)
-        self.sft = None
-        if not self.filename.exists():
-            raise FileNotFoundError(f"File not found: {filename}")
-
-    def __enter__(self) -> SafeTensorFile:
-        """Open file and return SafeTensorFile instance"""
-
-        try:
-            self.sft = SafeTensorFile.load_metadata(self.filename)
-            return self.sft
-
-        except Exception as e:
-            self.close()
-            raise Exception(f"Failed to load safetensors file: {e}") from e
-
-    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
-        """Clean up resources"""
-        self.close()
-
-    def close(self) -> None:
-        if self.sft != None:
-            self.sft._free_resources()
-        pass
-
-
-def safe_open(filename: str, framework: str = "pt",device = "cpu") -> _SafeTensorLoader:
-    if device != "cpu" or framework !="pt":
-        return _old_safe_open(filename =filename, framework=framework, device=device)
-    return _SafeTensorLoader(filename)
-
-def torch_load_file( filename, device = 'cpu' ) -> Dict[str, torch.Tensor]:
-    sd = {}
-    with safe_open(filename, framework="pt", device = device ) as f:
-        for k in f.keys():
-            sd[k] = f.get_tensor(k)
-    return sd
 
-_old_torch_load_file = safetensors.torch.load_file
-safetensors.torch.load_file = torch_load_file
-_old_safe_open = safetensors.safe_open
-safetensors.safe_open = safe_open
-accelerate.utils.modeling.safe_open = safe_open
-accelerate.utils.modeling.safe_load_file = torch_load_file
-try:
-    import transformers
-    transformers.modeling_utils.safe_open = safe_open
-    transformers.modeling_utils.safe_load_file = torch_load_file
-except:
-    pass
+def load_metadata_state_dict(file_path):
+    with open(file_path, 'rb') as f:
+        catalog, metadata, _ = _read_safetensors_header(file_path, f)
+    sd = OrderedDict()
+    for k, v in catalog.items():
+        dtypestr = v["dtype"]
+        dtype = _map_to_dtype.get(dtypestr)
+        if dtype is None:
+            raise KeyError(f"Unknown safetensors dtype '{dtypestr}' in {file_path}")
+        sd[k] = tensor_stub(dtype, v["shape"])
+    return sd, metadata
+
+
+def torch_write_file(sd, file_path, quantization_map = None, config = None, extra_meta = None):
+    from collections import OrderedDict
+    sf_sd = OrderedDict()
+
+    map = { torch.bfloat16 : 'BF16' , torch.int64 : 'I64' , torch.int32 : 'I32' , torch.int16 : 'I16' , torch.int8 : 'I8' ,
+            torch.uint64 : 'U64' , torch.uint32 : 'U32' , torch.uint16 : 'U16' , torch.uint8 : 'U8' ,
+            torch.bool : 'BOOL' , torch.float64 : 'F64' , torch.float32 : 'F32' , torch.float16 : 'F16', torch.float8_e5m2 : "F8_E5M2", torch.float8_e4m3fn: "F8_E4M3" }
+    pos = 0
+    i = 0
+    mx = 100000
+    metadata = dict()
+    for k , t in sd.items():
+        if torch.is_tensor(t):
+            entry = {}
+            dtypestr= map[t.dtype]
+            entry["dtype"] = dtypestr
+            entry["shape"] = list(t.shape)
+            size = torch.numel(t) * t.element_size()
+            if size == 0:
+                pass
+            entry["data_offsets"] = [pos, pos + size]
+            pos += size
+            sf_sd[k] = entry
+        else:
+            if isinstance(t, str):
+                metadata[k] = t
+            else:
+                try:
+                    b64 = base64.b64encode(json.dumps(t, ensure_ascii=False).encode('utf8')).decode('utf8')
+                    metadata[k + "_base64"] = b64
+                except:
+                    pass
+
+        i+=1
+        if i==mx:
+            break
+    if not quantization_map is None:
+        metadata["quantization_format"] = "quanto"
+        metadata["quantization_map_base64"] = base64.b64encode(json.dumps(quantization_map, ensure_ascii=False).encode('utf8')).decode('utf8')
+
+    if not config is None:
+        metadata["config_base64"] = base64.b64encode(json.dumps(config, ensure_ascii=False).encode('utf8')).decode('utf8')
+
+    if not extra_meta is None:
+        for n , m in extra_meta.items():
+            if isinstance(m, str):
+                metadata[n] = m
+            else:
+                metadata[n + "_base64"] = base64.b64encode(json.dumps(m, ensure_ascii=False).encode('utf8')).decode('utf8')
+
+
+    if len(metadata) > 0:
+        sf_sd["__metadata__"] = metadata
+
+    header_bytes = json.dumps(sf_sd).encode()
+    #header_bytes =json.dumps(config, ensure_ascii=False).encode('utf8')
+    size_header = len(header_bytes)
+    import struct
+
+    length_of_header_bytes = struct.pack('<Q', size_header)
+
+    with open(file_path, "wb") as writer:
+        bytes_written = writer.write(length_of_header_bytes)
+        bytes_written = writer.write(header_bytes)
+
+        i = 0
+        for k , t in sd.items():
+            if torch.is_tensor(t):
+                size = torch.numel(t) * t.element_size()
+                if size != 0:
+                    dtype = t.dtype
+                    # convert in a friendly format, scalars types not supported by numpy
+                    if dtype == torch.bfloat16:
+                        t = t.view(torch.uint16)
+                    elif dtype == torch.float8_e5m2 or dtype == torch.float8_e4m3fn:
+                        t = t.view(torch.uint8)
+                    buffer = t.cpu().numpy().tobytes()
+                    bytes_written = writer.write(buffer)
+                    assert bytes_written == size
+            i+=1
+            if i==mx:
+                break
+
+class SafeTensorFile:
+    """Main class for accessing safetensors files that provides memory-efficient access"""
+
+    def __init__(self, file_path, metadata, catalog, skip_bytes, lazy_loading = True, writable_tensors = True):
+        self._file_path = file_path
+        self._metadata = metadata
+        self._catalog = catalog
+        self._skip_bytes = skip_bytes
+        self._keys = None
+        self.sd = None
+        self.mtracker = None
+        self.lazy_loading = lazy_loading
+        self.writable_tensors = writable_tensors
+
+    @classmethod
+    def load_metadata(cls, file_path, lazy_loading = True, writable_tensors = True):
+        with open(file_path, 'rb') as f:
+            catalog, metadata, skip_bytes = _read_safetensors_header(file_path, f)
+
+        return cls(file_path, metadata, catalog, skip_bytes, lazy_loading, writable_tensors )
+
+    def init_tensors(self, lazyTensors = True, writable_tensors = True):
+        if self.sd is None:
+            self.lazy_loading = lazyTensors
+            if lazyTensors:
+                self.sd = self.create_tensors_with_mmap(writable_tensors)
+            else:
+                self.sd = self.create_tensors_without_mmap()
+        # else:
+        #     if not self.lazy_loading and lazyTensors:
+        #         raise Exception("Every tensor should be either lazy loaded or not lazy loaded")
+
+        return self.sd
+
+
+    def create_tensors_with_mmap(self, writable_tensors = True):
+
+        self.mtracker = MmapTracker(self._file_path)
+        import mmap
+
+        PAGE_SIZE = mmap.ALLOCATIONGRANULARITY
+        MMAP_SIZE = 1024 * 1024 * 1024 # 1GB
+        # MMAP_SIZE = 256 * 1024 * 1024 # 1GB
+
+        # First pass: find optimal aligned map boundaries
+        skip_bytes = self._skip_bytes
+        tensor_map_indexes = []
+        maps_info = []
+        current_pos = skip_bytes
+        current_map_start = (skip_bytes // PAGE_SIZE) * PAGE_SIZE
+        current_map_size = skip_bytes - current_map_start
+        idx = 0
+        for k,v in self._catalog.items():
+            data_offsets = v["data_offsets"]
+            length = data_offsets[1]-data_offsets[0]
+            if current_map_size + length > MMAP_SIZE:
+                maps_info.append((current_map_start, current_map_size))
+                current_map_start = (current_pos // PAGE_SIZE) * PAGE_SIZE
+                current_map_size = current_pos - current_map_start
+                idx += 1
+            tensor_map_indexes.append(idx)
+            current_map_size += length
+            current_pos += length
+
+        maps_info.append((current_map_start, current_map_size))
+
+        # Second pass: create maps and tensors
+        maps = []
+        sd = OrderedDict()
+
+        current_pos = skip_bytes
+        with open(self._file_path, 'rb') as f:
+            i = 0
+            for map_start, map_size in maps_info:
+                mm = mmap.mmap(f.fileno(), map_size, offset=map_start, access= mmap.ACCESS_COPY if writable_tensors else mmap.ACCESS_READ)
+                maps.append((mm, map_start, map_size))
+                self.mtracker.register(mm, i, map_start, map_size)
+                i = i+ 1
+
+        iter_tensor_no = iter(tensor_map_indexes)
+        for k,v in self._catalog.items():
+            dtypestr = v["dtype"]
+            dtype= _map_to_dtype[dtypestr]
+            shape = v["shape"]
+            data_offsets = v["data_offsets"]
+            length = data_offsets[1]-data_offsets[0]
+            map_idx = next(iter_tensor_no)
+            offset = current_pos - maps[map_idx][1]
+            if length == 0:
+                t = torch.empty(shape, dtype=dtype)
+            elif len(shape) == 0:
+                # don't waste a memory view for a scalar
+                t = torch.frombuffer(bytearray(maps[map_idx][0][offset:offset + length]), dtype=torch.uint8)
+                t = t.view(dtype)
+            else:
+                mv = memoryview(maps[map_idx][0])[offset:offset + length]
+                t = torch.frombuffer(mv, dtype=dtype)
+                t = torch.reshape(t, shape)
+            # t._mmap = maps[map_idx][0]
+            sd[k] = t
+            current_pos += length
+
+        return sd
+
+
+    def create_tensors_without_mmap(self):
+        sd = OrderedDict()
+
+        with open(self._file_path, 'rb') as f:
+            f.seek(self._skip_bytes, 0)
+            for k,v in self._catalog.items():
+                dtypestr = v["dtype"]
+                dtype= _map_to_dtype[dtypestr]
+                shape = v["shape"]
+                data_offsets = v["data_offsets"]
+                length = data_offsets[1]-data_offsets[0]
+                buffer = f.read(length)
+                if length == 0:
+                    t = torch.empty(0, dtype=dtype)
+                elif len(shape) == 0:
+                    t = torch.frombuffer(bytearray(buffer), dtype=torch.uint8)
+                    t = t.view(dtype)
+                else:
+                    t = torch.frombuffer(bytearray(buffer), dtype=dtype)
+                    t = torch.reshape(t, shape)
+                sd[k] = t
+        return sd
+
+    def get_slice(self, name: str) -> torch.tensor:
+        return tensor_slice(self._catalog, name, self.get_tensor(name))
+
+    def get_tensor(self, name: str) -> torch.tensor:
+        """Get a tensor by name"""
+        # To do : switch to a JIT tensor creation per tensor
+        self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
+        return self.sd[name]
+
+    def keys(self) -> List[str]:
+        """Get list of tensor names"""
+        if self._keys is None:
+            self._keys = list(self._catalog)
+        return self._keys
+
+    def names(self) -> List[str]:
+        """Alias for keys()"""
+        return self.keys()
+
+    def tensors(self) -> Dict[str, torch.tensor]:
+        """Get dictionary of all tensors"""
+        self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
+        return self.sd
+
+    def metadata(self) -> Optional[Dict[str, str]]:
+        """Get metadata dictionary"""
+        return self._metadata
+
+    def __len__(self) -> int:
+        """Get number of tensors"""
+        self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
+        return len(self.keys())
+
+    def __contains__(self, key: str) -> bool:
+        """Check if tensor exists"""
+        return key in self.keys()
+
+    def __iter__(self) -> Iterator[Tuple[str, torch.tensor ]]:
+        """Iterate over (name, tensor) pairs"""
+        return ((name, self.get_tensor(name)) for name in self.keys())
+
+    def _free_resources(self):
+        del self.sd
+        del self._catalog
+
+class _SafeTensorLoader:
+    """Context manager for loading SafeTensorFile"""
+
+    def __init__(self, filename: str, writable_tensors = True ):
+        self.filename = Path(filename)
+        self.writable_tensors = writable_tensors
+        self.sft = None
+        if not self.filename.exists():
+            raise FileNotFoundError(f"File not found: {filename}")
+
+    def __enter__(self) -> SafeTensorFile:
+        """Open file and return SafeTensorFile instance"""
+        writable_tensors = self.writable_tensors
+
+        if all_tensors_are_read_only:
+            writable_tensors = False
+
+        try:
+            self.sft = SafeTensorFile.load_metadata(self.filename, writable_tensors= writable_tensors)
+            return self.sft
+
+        except Exception as e:
+            self.close()
+            raise Exception(f"Failed to load safetensors file: {e}") from e
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        """Clean up resources"""
+        self.close()
+
+    def get_tensor(self, name):
+        if self.sft == None:
+            self.__enter__()
+        return self.sft.get_tensor(name)
+
+    def get_slice(self, name):
+        if self.sft == None:
+            self.__enter__()
+        return self.sft.get_slice(name)
+
+    def close(self) -> None:
+        if self.sft != None:
+            self.sft._free_resources()
+        pass
+
+
+def safe_open(filename: str, framework: str = "pt",device = "cpu", writable_tensors = True) -> _SafeTensorLoader:
+    if device != "cpu" or framework !="pt":
+        return _old_safe_open(filename =filename, framework=framework, device=device)
+    return _SafeTensorLoader(filename, writable_tensors = writable_tensors)
+
+def torch_load_file( filename, device = 'cpu', writable_tensors = True) -> Dict[str, torch.Tensor]:
+    sd = {}
+    with safe_open(filename, framework="pt", device = device, writable_tensors =writable_tensors ) as f:
+        for k in f.keys():
+            sd[k] = f.get_tensor(k)
+    return sd
+
+_old_torch_load_file = safetensors.torch.load_file
+safetensors.torch.load_file = torch_load_file
+_old_safe_open = safetensors.safe_open
+safetensors.safe_open = safe_open
+accelerate.utils.modeling.safe_open = safe_open
+accelerate.utils.modeling.safe_load_file = torch_load_file
+try:
+    import transformers
+    transformers.modeling_utils.safe_open = safe_open
+    transformers.modeling_utils.safe_load_file = torch_load_file
+except:
+    pass
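
The practical upshot of the 3.3.1 → 3.6.11 changes is the `writable_tensors` plumbing: `safe_open` and `torch_load_file` now accept a `writable_tensors` flag, the module-level `all_tensors_are_read_only` switch forces read-only mappings (`mmap.ACCESS_READ` instead of the copy-on-write `mmap.ACCESS_COPY`) for every load, `get_slice` returns a `tensor_slice` wrapper, and `load_metadata_state_dict` builds a state dict of `tensor_stub` placeholders from the header alone. A minimal usage sketch against the patched module, assuming it is imported as `mmgp.safetensors2` and that a local `model.safetensors` file exists (the file name is hypothetical):

from mmgp import safetensors2

# Importing the module monkey-patches safetensors.torch.load_file,
# safetensors.safe_open and the accelerate/transformers loaders, so
# downstream libraries transparently pick up the low-RAM lazy loader.

# writable_tensors=False maps the file with mmap.ACCESS_READ: tensors
# share pages with the OS file cache instead of copy-on-write pages.
sd = safetensors2.torch_load_file("model.safetensors", writable_tensors=False)

# get_slice() wraps a tensor in tensor_slice, which exposes the dtype
# string and shape recorded in the safetensors header plus indexing.
with safetensors2.safe_open("model.safetensors", framework="pt") as f:
    name = next(iter(f.keys()))
    sl = f.get_slice(name)
    print(name, sl.get_dtype(), sl.get_shape())

# load_metadata_state_dict() reads only the header: each value is a
# tensor_stub carrying dtype/shape/numel/device but no tensor data.
stubs, meta = safetensors2.load_metadata_state_dict("model.safetensors")
print(meta.get("format"), stubs[name].ndim, stubs[name].numel())

The read-only path is also why 3.6.11 installs a `warnings.filterwarnings` rule at import time: `torch.frombuffer` over a non-writable buffer triggers PyTorch's "given buffer is not writable" warning, which the filter silences.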