mmgp 3.0.0__py3-none-any.whl → 3.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



mmgp/safetensors2.py CHANGED
@@ -1,387 +1,394 @@
+ # ------------------ Safetensors2 1.0 by DeepBeepMeep (mmgp) ------------------
+ #
+ # This module, written entirely in Python, is a replacement for the safetensors
+ # library that loads models with much less RAM. It can also be used to keep RAM
+ # consumption low when handling transient data (for instance when quantizing or
+ # transferring tensors to reserved RAM).
+ # You are free to use this module for non-commercial purposes as long as you
+ # give proper credit. You may contact me on twitter @deepbeepmeep
+
+
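# Importing this module is enough to activate it: the patching block at the end
# of the file redirects safetensors.torch.load_file / safetensors.safe_open (and
# the copies bound inside accelerate and transformers) to the mmap-based
# implementation below, so existing call sites keep working unchanged.
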
from typing import Optional, Dict, List, Iterator, Tuple
from pathlib import Path
import torch
import mmap
import struct
import json
import base64
import safetensors
import accelerate
import os
import weakref
from collections import OrderedDict


_old_torch_load_file = None
_old_safe_open = None

# live MmapTracker instances, keyed by a shortened file path
mmm = {}
verboseLevel = 1

_map_to_dtype = {
    'BF16': torch.bfloat16, 'U8': torch.uint8, 'U16': torch.uint16, 'U32': torch.uint32, 'U64': torch.uint64,
    'I8': torch.int8, 'I16': torch.int16, 'I32': torch.int32, 'I64': torch.int64,
    'F64': torch.float64, 'F32': torch.float32, 'F16': torch.float16, 'BOOL': torch.bool,
    'F8_E5M2': torch.float8_e5m2, 'F8_E4M3': torch.float8_e4m3fn,
}

class MmapTracker:
    """Tracks the mmaps created for one safetensors file and reports, through
    weakref finalizers, when each of them is released."""

    def __init__(self, file_path):
        self._maps = {}
        self._already_released = 0
        # keep at most the last two path components for display purposes
        s = Path(file_path).parts
        if len(s) > 2:
            s = s[-2:]
            file_path = os.path.join(*s)
        self.file_path = file_path  # os.path.abspath(file_path)
        self.count = 0
        mmm[file_path] = self

    def register(self, mmap_obj, map_id, start, size):
        self.count += 1

        def finalizer(ref):
            self._already_released += 1
            if verboseLevel >= 2:
                if self.count == self._already_released:
                    text = " (all the mmaps have been released)"
                else:
                    text = f" ({self.count - self._already_released} left)"
                print(f"MMap Manager of file '{self.file_path}': MMap no {map_id} has been released" + text)
            if self.count == self._already_released:
                del mmm[self.file_path]
            self._maps.pop(map_id, None)

        wr = weakref.ref(mmap_obj, finalizer)
        self._maps[map_id] = {
            'mmap': wr,
            'start': start,
            'size': size,
            'end': start + size,
        }
        return wr

    def get_active_maps(self):
        return dict(self._maps)

class cached_metadata:
    """Caches the parsed header of the most recently read safetensors file."""
    file_path = None
    file_length = 0
    file_date = None
    catalog = None
    metadata = None
    skip_bytes = 0

    def __init__(self, file_path, catalog, metadata, skip_bytes):
        self.catalog = catalog
        self.metadata = metadata
        self.skip_bytes = skip_bytes
        file_stats = os.stat(file_path)
        self.file_path = os.path.abspath(file_path)
        self.file_length = file_stats.st_size
        self.file_date = file_stats.st_ctime

    def get_metadata(self, file_path):
        """Return the cached header, or (None, None, None) if the file changed."""
        file_stats = os.stat(file_path)
        file_length = file_stats.st_size
        file_date = file_stats.st_ctime
        file_path = os.path.abspath(file_path)
        if self.file_path != file_path or self.file_length != file_length or self.file_date != file_date:
            return None, None, None
        return self.catalog, self.metadata, self.skip_bytes

_cached_entry = None  # ideally we would keep a dict of the last n entries, but one entry covers most cases

def _parse_metadata(metadata):
    if metadata is None:
        return None

    new_metadata = {}
    for k, v in metadata.items():
        # values stored under a "*_base64" key are base64-encoded JSON
        if k.endswith("_base64"):
            v_decoded = json.loads(base64.b64decode(v.encode('utf8')).decode('utf8'))
            p = k.rfind("_")
            new_k = k[:p]
            new_metadata[new_k] = v_decoded
        else:
            new_metadata[k] = v

    return new_metadata

def _read_safetensors_header(path, file):
    global _cached_entry
    length_of_header_bytes = file.read(8)
    # interpret the first 8 bytes as a little-endian unsigned 64-bit integer
    length_of_header = struct.unpack('<Q', length_of_header_bytes)[0]

    if _cached_entry is not None:
        catalog, metadata, _ = _cached_entry.get_metadata(path)
    else:
        catalog = None

    if catalog is None:
        header_bytes = file.read(length_of_header)
        catalog = json.loads(header_bytes)
        metadata = catalog.pop("__metadata__", None)
        metadata = _parse_metadata(metadata)
        _cached_entry = cached_metadata(path, catalog, metadata, length_of_header)
    else:
        # header already cached: just skip past it
        file.seek(length_of_header, 1)

    return catalog, metadata, length_of_header + 8

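# On-disk layout assumed by the reader above and produced by torch_write_file below:
#   bytes 0..8     little-endian uint64 N, the size of the JSON header
#   bytes 8..8+N   JSON catalog {tensor_name: {"dtype", "shape", "data_offsets"}, ...}
#                  plus an optional "__metadata__" dict of string keys/values
#   bytes 8+N..    raw tensor data, located via each tensor's data_offsets
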
def torch_write_file(sd, file_path, quantization_map=None, config=None):
    """Write a state dict to a safetensors file, optionally embedding a quanto
    quantization map and a config dict as base64-encoded metadata."""
    sf_sd = OrderedDict()

    dtype_map = {
        torch.bfloat16: 'BF16', torch.int64: 'I64', torch.int32: 'I32', torch.int16: 'I16', torch.int8: 'I8',
        torch.uint64: 'U64', torch.uint32: 'U32', torch.uint16: 'U16', torch.uint8: 'U8',
        torch.bool: 'BOOL', torch.float64: 'F64', torch.float32: 'F32', torch.float16: 'F16',
        torch.float8_e5m2: 'F8_E5M2', torch.float8_e4m3fn: 'F8_E4M3',
    }

    # first pass: build the JSON catalog with the data offsets of each tensor
    pos = 0
    for k, t in sd.items():
        entry = {}
        entry["dtype"] = dtype_map[t.dtype]
        entry["shape"] = list(t.shape)
        size = torch.numel(t) * t.element_size()
        entry["data_offsets"] = [pos, pos + size]
        pos += size
        sf_sd[k] = entry

    metadata = dict()
    if quantization_map is not None:
        metadata["quantization_format"] = "quanto"
        metadata["quantization_map_base64"] = base64.b64encode(json.dumps(quantization_map, ensure_ascii=False).encode('utf8')).decode('utf8')

    if config is not None:
        metadata["config_base64"] = base64.b64encode(json.dumps(config, ensure_ascii=False).encode('utf8')).decode('utf8')

    if len(metadata) > 0:
        sf_sd["__metadata__"] = metadata

    header_bytes = json.dumps(sf_sd).encode()
    length_of_header_bytes = struct.pack('<Q', len(header_bytes))

    # placeholder payload written for zero-dimensional tensors (reads back as
    # 1.0 in bfloat16; the read path recreates such tensors with torch.ones)
    empty_tensor = b'\x80\x3f'

    with open(file_path, "wb") as writer:
        writer.write(length_of_header_bytes)
        writer.write(header_bytes)

        # second pass: append the raw bytes of each tensor
        for k, t in sd.items():
            size = torch.numel(t) * t.element_size()
            if len(t.shape) == 0:
                writer.write(empty_tensor)
            else:
                buffer = t.view(torch.uint8).numpy().tobytes()
                bytes_written = writer.write(buffer)
                assert bytes_written == size

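# Illustrative round trip (names are arbitrary): torch_write_file(sd,
# "model-quanto.safetensors", quantization_map=qmap) stores qmap under
# "quantization_map_base64" in "__metadata__"; on the next load,
# _parse_metadata() decodes it back into a "quantization_map" entry.
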
class SafeTensorFile:
    """Main class for accessing safetensors files; provides memory-efficient access."""

    def __init__(self, file_path, metadata, catalog, skip_bytes):
        self._file_path = file_path
        self._metadata = metadata
        self._catalog = catalog
        self._skip_bytes = skip_bytes
        self._keys = None
        self.sd = None
        self.mtracker = None

    @classmethod
    def load_metadata(cls, file_path):
        with open(file_path, 'rb') as f:
            catalog, metadata, skip_bytes = _read_safetensors_header(file_path, f)
        return cls(file_path, metadata, catalog, skip_bytes)

    def init_tensors(self):
        if self.sd is None:
            self.sd = self.create_tensors()
        return self.sd
    def create_tensors(self):
        self.mtracker = MmapTracker(self._file_path)

        PAGE_SIZE = mmap.ALLOCATIONGRANULARITY
        MMAP_SIZE = 1024 * 1024 * 1024  # map the file in chunks of at most 1 GB

        # First pass: find page-aligned map boundaries so that each tensor
        # falls entirely inside a single map
        skip_bytes = self._skip_bytes
        tensor_map_indexes = []
        maps_info = []
        current_pos = skip_bytes
        current_map_start = (skip_bytes // PAGE_SIZE) * PAGE_SIZE
        current_map_size = skip_bytes - current_map_start
        idx = 0
        for k, v in self._catalog.items():
            data_offsets = v["data_offsets"]
            length = data_offsets[1] - data_offsets[0]
            if current_map_size + length > MMAP_SIZE:
                maps_info.append((current_map_start, current_map_size))
                current_map_start = (current_pos // PAGE_SIZE) * PAGE_SIZE
                current_map_size = current_pos - current_map_start
                idx += 1
            tensor_map_indexes.append(idx)
            current_map_size += length
            current_pos += length

        maps_info.append((current_map_start, current_map_size))

        # Second pass: create the maps and build tensors on top of them
        maps = []
        sd = OrderedDict()

        current_pos = skip_bytes
        with open(self._file_path, 'rb') as f:
            i = 0
            for map_start, map_size in maps_info:
                mm = mmap.mmap(f.fileno(), map_size, offset=map_start, access=mmap.ACCESS_COPY)  # .ACCESS_READ
                maps.append((mm, map_start, map_size))
                self.mtracker.register(mm, i, map_start, map_size)
                i += 1

            iter_tensor_no = iter(tensor_map_indexes)
            for k, v in self._catalog.items():
                dtype = _map_to_dtype[v["dtype"]]
                shape = v["shape"]
                data_offsets = v["data_offsets"]
                length = data_offsets[1] - data_offsets[0]
                map_idx = next(iter_tensor_no)
                offset = current_pos - maps[map_idx][1]
                if len(shape) == 0:
                    # zero-dim tensors are stored as a placeholder; recreate them directly
                    t = torch.ones((), dtype=dtype, device="cpu")
                else:
                    mv = memoryview(maps[map_idx][0])[offset:offset + length]
                    t = torch.frombuffer(mv, dtype=dtype)
                    t = torch.reshape(t, shape)
                sd[k] = t
                current_pos += length

        return sd

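    # Design note: ACCESS_COPY maps the file copy-on-write. Pages are faulted in
    # on demand from the OS page cache instead of being copied up front, and a
    # write to a tensor dirties only the touched pages, so loading a model this
    # way barely grows process RAM until tensors are actually materialized.
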
    def get_tensor(self, name: str) -> torch.Tensor:
        """Get a tensor by name"""
        self.init_tensors()
        return self.sd[name]

    def keys(self) -> List[str]:
        """Get the list of tensor names"""
        if self._keys is None:
            self._keys = list(self._catalog)
        return self._keys

    def names(self) -> List[str]:
        """Alias for keys()"""
        return self.keys()

    def tensors(self) -> Dict[str, torch.Tensor]:
        """Get a dictionary of all tensors"""
        self.init_tensors()
        return self.sd

    def metadata(self) -> Optional[Dict[str, str]]:
        """Get the metadata dictionary"""
        return self._metadata

    def __len__(self) -> int:
        """Get the number of tensors"""
        self.init_tensors()
        return len(self.keys())

    def __contains__(self, key: str) -> bool:
        """Check whether a tensor exists"""
        return key in self.keys()

    def __iter__(self) -> Iterator[Tuple[str, torch.Tensor]]:
        """Iterate over (name, tensor) pairs"""
        return ((name, self.get_tensor(name)) for name in self.keys())

    def _free_resources(self):
        del self.sd
        del self._catalog

class _SafeTensorLoader:
    """Context manager that loads a SafeTensorFile"""

    def __init__(self, filename: str):
        self.filename = Path(filename)
        self.sft = None

        if not self.filename.exists():
            raise FileNotFoundError(f"File not found: {filename}")

    def __enter__(self) -> SafeTensorFile:
        """Open the file and return a SafeTensorFile instance"""
        try:
            self.sft = SafeTensorFile.load_metadata(self.filename)
            return self.sft
        except Exception as e:
            self.close()
            raise Exception(f"Failed to load safetensors file: {e}") from e

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Clean up resources"""
        self.close()

    def close(self) -> None:
        if self.sft is not None:
            self.sft._free_resources()


def safe_open(filename: str, framework: str = "pt", device="cpu") -> _SafeTensorLoader:
    # fall back to the stock safetensors implementation for anything other
    # than CPU PyTorch tensors
    if device != "cpu" or framework != "pt":
        return _old_safe_open(filename=filename, framework=framework, device=device)
    return _SafeTensorLoader(filename)


def torch_load_file(filename, device='cpu') -> Dict[str, torch.Tensor]:
    sd = {}
    with safe_open(filename, framework="pt", device=device) as f:
        for k in f.keys():
            sd[k] = f.get_tensor(k)
    return sd

# Redirect the standard loaders to the implementations above; the originals are
# saved so that safe_open() can fall back to the stock implementation.
_old_torch_load_file = safetensors.torch.load_file
safetensors.torch.load_file = torch_load_file
_old_safe_open = safetensors.safe_open
safetensors.safe_open = safe_open
accelerate.utils.modeling.safe_open = safe_open
accelerate.utils.modeling.safe_load_file = torch_load_file
try:
    import transformers
    transformers.modeling_utils.safe_open = safe_open
    transformers.modeling_utils.safe_load_file = torch_load_file
except ImportError:
    pass
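Because the module patches the loaders on import, existing code needs no changes. A minimal usage sketch (the checkpoint path is hypothetical; any .safetensors file works):

    import safetensors.torch
    from mmgp import safetensors2  # importing the module installs the patches

    # served by SafeTensorFile over copy-on-write mmaps instead of the native loader
    sd = safetensors.torch.load_file("model.safetensors")

    with safetensors.safe_open("model.safetensors", framework="pt", device="cpu") as f:
        names = f.keys()
        first = f.get_tensor(names[0])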