mmgp 2.0.4__py3-none-any.whl → 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mmgp might be problematic. Click here for more details.

mmgp/safetensors2.py ADDED
@@ -0,0 +1,394 @@
1
+ # ------------------ Safetensors2 1.0 by DeepBeepMeep (mmgp)------------------
2
+ #
3
+ # This module, written entirely in Python, is a replacement for the safetensors library that requires much less RAM to load models.
4
+ # It can also be used to keep RAM consumption low when handling data in transit (for instance when quantizing tensors or transferring them to reserved RAM)
5
+ # You are free to use my module for non-commercial use as long as you give me proper credit. You may contact me on Twitter @deepbeepmeep
6
+
7
+
8
+ from typing import Optional, Dict, List, Iterator, Tuple
9
+ from pathlib import Path
10
+ import torch
11
+ import mmap
12
+ import struct
13
+ import json
14
+ import base64
15
+ import safetensors
16
+ import accelerate
17
+ import os
18
+ from collections import OrderedDict
19
+
20
+
21
+ _old_torch_load_file = None
22
+ _old_safe_open = None
23
+
24
+
25
+
26
+ mmm = {}
27
+ verboseLevel = 1
28
+
29
+ import weakref
30
+
31
+ _map_to_dtype = { 'BF16': torch.bfloat16, 'U8': torch.uint8 , 'U16': torch.uint16, 'U32' : torch.uint32 , 'U64' : torch.uint64,
32
+ 'I8': torch.int8, 'I16': torch.int16, 'I32' : torch.int32 , 'I64' : torch.int64,
33
+ 'F64' : torch.float64, 'F32': torch.float32, 'F16': torch.float16, 'BOOL' : torch.bool, "F8_E5M2" : torch.float8_e5m2, "F8_E4M3" : torch.float8_e4m3fn }
34
+
35
+
36
class MmapTracker:
    """Book-keeping for the memory maps created for one file.

    Each registered map is watched through a ``weakref.ref``. Its callback
    counts the release, and once every registered map is gone the tracker
    removes its own entry from the module-level ``mmm`` registry. Releases are
    printed when ``verboseLevel >= 2``.
    """

    def __init__(self, file_path):
        # map_id -> {"mmap": weakref, "start": int, "size": int, "end": int}
        self._maps = {}
        self._already_released = 0
        # Keep only the last two path components: short enough for log
        # messages, and used as the key in ``mmm``.
        tail = Path(file_path).parts[-2:]
        file_path = os.path.join(*tail) if tail else str(file_path)
        self.file_path = file_path
        self.count = 0  # number of maps registered so far
        mmm[file_path] = self

    def register(self, mmap_obj, map_id, start, size):
        """Track *mmap_obj*, covering ``size`` bytes from file offset ``start``, under *map_id*.

        Returns the weak reference whose callback records the release.
        """
        self.count += 1

        def finalizer(ref):
            self._already_released += 1
            all_released = self.count == self._already_released
            if verboseLevel >= 2:
                if all_released:
                    text = " (all the mmaps have been released)"
                else:
                    text = f" ({self.count - self._already_released} left)"
                print(f"MMap Manager of file '{self.file_path}' : MMap no {map_id} has been released" + text)
            # A newer tracker for the same file may have replaced this one in
            # ``mmm``. Only remove the entry if it is still ours; otherwise we
            # would drop the other tracker (or hit a KeyError in this callback).
            if all_released and mmm.get(self.file_path) is self:
                del mmm[self.file_path]
            self._maps.pop(map_id, None)

        wr = weakref.ref(mmap_obj, finalizer)
        self._maps[map_id] = {
            'mmap': wr,
            'start': start,
            'size': size,
            'end': start + size,
        }
        return wr

    def get_active_maps(self):
        """Return a shallow copy of the map_id -> info dict for maps not yet released."""
        return dict(self._maps)
77
+
78
+
79
+ class cached_metadata:
80
+ file_path = None
81
+ file_length = 0
82
+ file_date = None
83
+ catalog = None
84
+ metadata = None
85
+ skip_bytes = 0
86
+
87
+ def __init__(self, file_path, catalog, metadata, skip_bytes):
88
+ self.catalog = catalog
89
+ self.metadata = metadata
90
+ self.skip_bytes = skip_bytes
91
+ file_stats = os.stat(file_path)
92
+ self.file_path = os.path.abspath(file_path)
93
+ self.file_length = file_stats.st_size
94
+ self.file_date = file_stats.st_ctime
95
+
96
+ def get_metadata(self, file_path):
97
+ file_stats = os.stat(file_path)
98
+ file_length = file_stats.st_size
99
+ file_date = file_stats.st_ctime
100
+ file_path = os.path.abspath(file_path)
101
+ if self.file_path != file_path or self.file_length != file_length or self.file_date != file_date:
102
+ return None, None, None
103
+ return self.catalog, self.metadata, self.skip_bytes
104
+
105
+ _cached_entry = None # ideally we should create a dict of the last n entries but one entry covers most cases
106
+
107
+ def _parse_metadata(metadata):
108
+ if metadata == None:
109
+ return None
110
+
111
+ new_metadata= {}
112
+
113
+ for k,v in metadata.items():
114
+ if k.endswith("_base64"):
115
+ v_decoded = json.loads(base64.b64decode(v.encode('utf8')).decode('utf8'))
116
+ p = k.rfind("_")
117
+ new_k = k[:p]
118
+ new_metadata[new_k]= v_decoded
119
+ else:
120
+ new_metadata[k] = v
121
+
122
+ return new_metadata
123
+
124
def _read_safetensors_header(path, file):
    """Read the header of the safetensors file *file*, opened in binary mode.

    Assumes *file* is positioned at the very start of the file, as it is in
    ``SafeTensorFile.load_metadata``. The 8-byte little-endian header length
    is always read. If the single-entry module cache (``_cached_entry``)
    still matches *path* and that length, the JSON header is skipped;
    otherwise it is parsed and cached.

    Returns:
        (catalog, metadata, data_start): the tensor catalog without
        ``__metadata__``, the decoded metadata (or None), and the absolute
        offset of the data section (8 + header length). On return *file* is
        positioned at data_start.

    Raises:
        ValueError: the file is too short to contain the declared header.
    """
    global _cached_entry
    length_of_header_bytes = file.read(8)
    if len(length_of_header_bytes) != 8:
        raise ValueError(f"'{path}' is too short to be a safetensors file")
    # Interpret the bytes as a little-endian unsigned 64-bit integer
    length_of_header = struct.unpack('<Q', length_of_header_bytes)[0]

    catalog = None
    metadata = None
    if _cached_entry is not None:
        hit_catalog, hit_metadata, hit_length = _cached_entry.get_metadata(path)
        # Also require the same header length, in case the file changed without its stat data changing.
        if hit_catalog is not None and hit_length == length_of_header:
            catalog, metadata = hit_catalog, hit_metadata

    if catalog is None:
        header_bytes = file.read(length_of_header)
        if len(header_bytes) != length_of_header:
            raise ValueError(f"'{path}' is truncated: header declares {length_of_header} bytes, "
                             f"only {len(header_bytes)} available")
        catalog = json.loads(header_bytes)
        metadata = _parse_metadata(catalog.pop("__metadata__", None))
        _cached_entry = cached_metadata(path, catalog, metadata, length_of_header)
    else:
        file.seek(length_of_header, 1)  # skip the JSON header we already know

    return catalog, metadata, length_of_header + 8
147
+
148
+
149
def torch_write_file(sd, file_path, quantization_map = None, config = None):
    """Serialize the tensors of *sd* to *file_path* in the safetensors format.

    File layout:
    - an 8-byte little-endian header length;
    - the JSON header, padded with spaces so the data section starts on an
      8-byte boundary;
    - the raw bytes of every tensor, in the iteration order of *sd*.

    Args:
        sd: mapping of tensor name -> torch.Tensor (any device; copied to CPU
            for writing).
        file_path: destination path. An existing file is overwritten.
        quantization_map: optional JSON-serializable object. When given, the
            header metadata gets ``quantization_format = "quanto"`` and the
            object as ``quantization_map_base64`` (JSON, then base64).
        config: optional JSON-serializable object, stored the same way as
            ``config_base64``.

    Raises:
        KeyError: a tensor's dtype has no safetensors code.
        OSError: fewer bytes than expected were written for a tensor.
    """
    from collections import OrderedDict

    # torch dtype attribute -> safetensors dtype code. Resolved with getattr so
    # that dtypes missing from older torch releases are skipped instead of
    # breaking every call.
    dtype_names = (
        ("bfloat16", "BF16"), ("int64", "I64"), ("int32", "I32"), ("int16", "I16"), ("int8", "I8"),
        ("uint64", "U64"), ("uint32", "U32"), ("uint16", "U16"), ("uint8", "U8"),
        ("bool", "BOOL"), ("float64", "F64"), ("float32", "F32"), ("float16", "F16"),
        ("float8_e5m2", "F8_E5M2"), ("float8_e4m3fn", "F8_E4M3"),
    )
    dtype_to_code = {getattr(torch, name): code for name, code in dtype_names if hasattr(torch, name)}

    def encode_json_base64(obj):
        return base64.b64encode(json.dumps(obj, ensure_ascii=False).encode('utf8')).decode('utf8')

    def raw_bytes(t):
        # Flatten before reinterpreting as bytes, because view(torch.uint8)
        # rejects 0-dim tensors whose element size is > 1. detach/cpu make
        # .numpy() usable, and contiguous() makes the bytes follow the
        # tensor's logical element order. The returned uint8 array is written
        # directly, with no extra tobytes() copy.
        return t.detach().cpu().contiguous().reshape(-1).view(torch.uint8).numpy()

    header = OrderedDict()
    pos = 0
    for name, t in sd.items():
        code = dtype_to_code.get(t.dtype)
        if code is None:
            raise KeyError(f"Tensor '{name}' has dtype {t.dtype}, which has no safetensors equivalent")
        size = t.numel() * t.element_size()
        header[name] = {"dtype": code, "shape": list(t.shape), "data_offsets": [pos, pos + size]}
        pos += size

    metadata = {}
    if quantization_map is not None:
        metadata["quantization_format"] = "quanto"
        metadata["quantization_map_base64"] = encode_json_base64(quantization_map)
    if config is not None:
        metadata["config_base64"] = encode_json_base64(config)
    if metadata:
        header["__metadata__"] = metadata

    header_bytes = json.dumps(header).encode()
    # Trailing spaces are valid JSON whitespace. Padding to a multiple of 8
    # (the length prefix is 8 bytes) keeps every tensor at its natural
    # alignment when the file is memory-mapped.
    header_bytes += b" " * (-len(header_bytes) % 8)

    with open(file_path, "wb") as writer:
        writer.write(struct.pack('<Q', len(header_bytes)))
        writer.write(header_bytes)
        for name, t in sd.items():
            expected = t.numel() * t.element_size()
            written = writer.write(raw_bytes(t))
            if written != expected:
                raise OSError(f"Short write for tensor '{name}': {written} of {expected} bytes")
207
+
208
class SafeTensorFile:
    """Main class for accessing safetensors files that provides memory-efficient access.

    Only the header is read when the object is created (``load_metadata``).
    The data section is memory-mapped lazily, on the first tensor access, and
    every returned tensor is a zero-copy view into one of those maps.
    """

    def __init__(self, file_path, metadata, catalog, skip_bytes):
        self._file_path = file_path
        self._metadata = metadata
        # tensor name -> {"dtype": code, "shape": [...], "data_offsets": [begin, end]}
        self._catalog = catalog
        # absolute file offset of the first data byte (8 + header length)
        self._skip_bytes = skip_bytes
        self._keys = None
        self.sd = None
        self.mtracker = None

    @classmethod
    def load_metadata(cls, file_path):
        """Read the header of *file_path* and return a SafeTensorFile for it; no tensor data is read."""
        with open(file_path, 'rb') as f:
            catalog, metadata, skip_bytes = _read_safetensors_header(file_path, f)

        return cls(file_path, metadata, catalog, skip_bytes)

    def init_tensors(self):
        """Create the tensors on first use; return the name -> tensor dict."""
        if self.sd is None:
            self.sd = self.create_tensors()
        return self.sd

    def create_tensors(self):
        """Memory-map the data section and return an OrderedDict of tensor views.

        Tensors are located via their ``data_offsets`` (relative to the data
        section), so the header does not have to list them in file order.
        Tensors are grouped, in file order, into page-aligned maps of at most
        about 1GB. A tensor larger than that gets a map of its own; tensors
        are never split across maps.
        """
        import mmap

        page_size = mmap.ALLOCATIONGRANULARITY  # mmap offsets must be multiples of this
        max_map_size = 1024 * 1024 * 1024  # 1GB

        # Absolute [begin, end) byte range of every tensor in the file.
        spans = {}
        for name, info in self._catalog.items():
            begin, end = info["data_offsets"]
            spans[name] = (self._skip_bytes + begin, self._skip_bytes + end)

        # First pass: walk the tensors in file order and assign each one to a map.
        maps_info = []       # [map_start, map_end] pairs
        map_of_tensor = {}   # tensor name -> index into maps_info
        for name in sorted(spans, key=spans.__getitem__):
            begin, end = spans[name]
            if begin == end:
                continue  # zero-byte tensors are not backed by any map
            if maps_info and end - maps_info[-1][0] <= max_map_size:
                maps_info[-1][1] = max(maps_info[-1][1], end)
            else:
                maps_info.append([(begin // page_size) * page_size, end])
            map_of_tensor[name] = len(maps_info) - 1

        # Second pass: create the maps...
        maps = []
        if maps_info:
            self.mtracker = MmapTracker(self._file_path)
            with open(self._file_path, 'rb') as f:
                for map_id, (map_start, map_end) in enumerate(maps_info):
                    map_size = map_end - map_start
                    # ACCESS_COPY: the pages are writable, but writes stay private
                    # to this process and never reach the file.
                    mm = mmap.mmap(f.fileno(), map_size, offset=map_start, access=mmap.ACCESS_COPY)
                    self.mtracker.register(mm, map_id, map_start, map_size)
                    maps.append((mm, map_start))

        # ...then one tensor per catalog entry, in catalog order.
        sd = OrderedDict()
        for name, info in self._catalog.items():
            dtype = _map_to_dtype[info["dtype"]]
            shape = info["shape"]
            begin, end = spans[name]
            if begin == end:
                # torch.frombuffer rejects empty buffers
                t = torch.empty(shape, dtype=dtype)
            else:
                mm, map_start = maps[map_of_tensor[name]]
                view = memoryview(mm)[begin - map_start:end - map_start]
                # 0-dim tensors are read from the file like any other
                # (a 1-element buffer reshaped to []), keeping the stored value.
                t = torch.frombuffer(view, dtype=dtype).reshape(shape)
            sd[name] = t

        return sd

    def get_tensor(self, name: str) -> torch.Tensor:
        """Get a tensor by name"""
        self.init_tensors()
        return self.sd[name]

    def keys(self) -> List[str]:
        """Get list of tensor names"""
        if self._keys is None:
            self._keys = list(self._catalog)
        return self._keys

    def names(self) -> List[str]:
        """Alias for keys()"""
        return self.keys()

    def tensors(self) -> Dict[str, torch.Tensor]:
        """Get dictionary of all tensors"""
        self.init_tensors()
        return self.sd

    def metadata(self) -> Optional[Dict[str, str]]:
        """Get metadata dictionary"""
        return self._metadata

    def __len__(self) -> int:
        """Get number of tensors, taken from the header; no data is mapped."""
        return len(self.keys())

    def __contains__(self, key: str) -> bool:
        """Check if tensor exists"""
        return key in self.keys()

    def __iter__(self) -> Iterator[Tuple[str, torch.Tensor]]:
        """Iterate over (name, tensor) pairs"""
        return ((name, self.get_tensor(name)) for name in self.keys())

    def _free_resources(self):
        """Drop this object's references to the tensors and catalog (safe to call twice).

        Tensors already handed out keep their own memory maps alive.
        """
        self.sd = None
        self._catalog = None
336
+
337
+ class _SafeTensorLoader:
338
+ """Context manager for loading SafeTensorFile"""
339
+
340
+ def __init__(self, filename: str):
341
+ self.filename = Path(filename)
342
+ self.sft = None
343
+
344
+ if not self.filename.exists():
345
+ raise FileNotFoundError(f"File not found: {filename}")
346
+
347
+ def __enter__(self) -> SafeTensorFile:
348
+ """Open file and return SafeTensorFile instance"""
349
+
350
+ try:
351
+ self.sft = SafeTensorFile.load_metadata(self.filename)
352
+ return self.sft
353
+
354
+ except Exception as e:
355
+ self.close()
356
+ raise Exception(f"Failed to load safetensors file: {e}") from e
357
+
358
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
359
+ """Clean up resources"""
360
+ self.close()
361
+
362
+ def close(self) -> None:
363
+ if self.sft != None:
364
+ self.sft._free_resources()
365
+ pass
366
+
367
+
368
def safe_open(filename: str, framework: str = "pt",device = "cpu") -> _SafeTensorLoader:
    """Drop-in replacement for ``safetensors.safe_open``.

    PyTorch tensors requested on the CPU are served by this module's
    mmap-based ``_SafeTensorLoader``. Any other framework or device is
    delegated to the original ``safetensors.safe_open`` kept in
    ``_old_safe_open``.
    """
    # str() so that a torch.device("cpu") is recognized as well as the string "cpu".
    if framework != "pt" or str(device) != "cpu":
        return _old_safe_open(filename=filename, framework=framework, device=device)
    return _SafeTensorLoader(filename)
373
+
374
def torch_load_file( filename, device = 'cpu' ) -> Dict[str, torch.Tensor]:
    """Drop-in replacement for ``safetensors.torch.load_file``.

    Opens *filename* through this module's ``safe_open`` and returns a dict
    mapping every tensor name, in header order, to its tensor.
    """
    with safe_open(filename, framework="pt", device=device) as handle:
        # The dict is fully built before the context manager exits.
        return {name: handle.get_tensor(name) for name in handle.keys()}
380
+
381
# Install the replacements. The originals are kept so that unsupported cases
# (other frameworks or devices) can still be delegated to them.
# Import the submodules explicitly: ``import safetensors`` alone does not load
# ``safetensors.torch``, which would otherwise only be present because
# accelerate happened to import it first.
import safetensors.torch
import accelerate.utils.modeling

_old_torch_load_file = safetensors.torch.load_file
safetensors.torch.load_file = torch_load_file
_old_safe_open = safetensors.safe_open
safetensors.safe_open = safe_open
accelerate.utils.modeling.safe_open = safe_open
accelerate.utils.modeling.safe_load_file = torch_load_file
try:
    # transformers is optional. Patching it is best-effort, so any failure to
    # import it (or its lazily loaded modeling_utils) is ignored. Unlike a bare
    # except, this does not swallow KeyboardInterrupt/SystemExit.
    import transformers
    transformers.modeling_utils.safe_open = safe_open
    transformers.modeling_utils.safe_load_file = torch_load_file
except Exception:
    pass
393
+
394
+
@@ -1,2 +1,2 @@
1
- GNU GENERAL PUBLIC LICENSE
1
+ GNU GENERAL PUBLIC LICENSE
2
2
  Version 3, 29 June 2007