mmgp 3.3.1__py3-none-any.whl → 3.6.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mmgp/fp8_quanto_bridge.py +645 -0
- mmgp/fp8_quanto_bridge_old.py +498 -0
- mmgp/offload.py +3613 -2461
- mmgp/quant_router.py +518 -0
- mmgp/quanto_int8_cuda.py +97 -0
- mmgp/quanto_int8_inject.py +335 -0
- mmgp/safetensors2.py +534 -450
- {mmgp-3.3.1.dist-info → mmgp-3.6.11.dist-info}/METADATA +195 -197
- mmgp-3.6.11.dist-info/RECORD +14 -0
- {mmgp-3.3.1.dist-info → mmgp-3.6.11.dist-info}/WHEEL +1 -1
- mmgp-3.3.1.dist-info/RECORD +0 -9
- {mmgp-3.3.1.dist-info → mmgp-3.6.11.dist-info}/licenses/LICENSE.md +0 -0
- {mmgp-3.3.1.dist-info → mmgp-3.6.11.dist-info}/top_level.txt +0 -0
mmgp/safetensors2.py
CHANGED
@@ -1,454 +1,538 @@
-# ------------------ Safetensors2 1.
-#
-# This module entirely written in Python is a replacement for the safetensor library which requires much less RAM to load models.
-# It can be conveniently used to keep a low RAM consumption when handling transit data (for instance when quantizing or transferring tensors to reserver RAM)
-# You are free to use my module for non commercial use as long you give me proper credits. You may contact me on twitter @deepbeepmeep
-
-
-from typing import Optional, Dict, List, Iterator, Tuple
-from pathlib import Path
-import torch
-import mmap
-import struct
-import json
-import base64
-import safetensors
-import accelerate
-import os
-from collections import OrderedDict
[removed old lines 19-123 were not captured in this extract; only the fragments "self." (line 64), "def" (line 87) and "self." (line 88) survive]
+# ------------------ Safetensors2 1.3 by DeepBeepMeep (mmgp)------------------
+#
+# This module entirely written in Python is a replacement for the safetensor library which requires much less RAM to load models.
+# It can be conveniently used to keep a low RAM consumption when handling transit data (for instance when quantizing or transferring tensors to reserver RAM)
+# You are free to use my module for non commercial use as long you give me proper credits. You may contact me on twitter @deepbeepmeep
+
+
+from typing import Optional, Dict, List, Iterator, Tuple
+from pathlib import Path
+import torch
+import mmap
+import struct
+import json
+import base64
+import safetensors
+import accelerate
+import os
+from collections import OrderedDict
+import warnings
+
+warnings.filterwarnings("ignore", ".*The given buffer is not writable, and PyTorch does not support non-writable tensors*")
+
+_old_torch_load_file = None
+_old_safe_open = None
+
+all_tensors_are_read_only = False
+
+mmm = {}
+verboseLevel = 1
+
+import weakref
+
+_map_to_dtype = { 'BF16': torch.bfloat16, 'U8': torch.uint8 , 'U16': torch.uint16, 'U32' : torch.uint32 , 'U64' : torch.uint64,
+                  'I8': torch.int8, 'I16': torch.int16, 'I32' : torch.int32 , 'I64' : torch.int64,
+                  'F64' : torch.float64, 'F32': torch.float32, 'F16': torch.float16, 'BOOL' : torch.bool, "F8_E5M2" : torch.float8_e5m2, "F8_E4M3" : torch.float8_e4m3fn }
+
+
+class MmapTracker:
+    def __init__(self, file_path):
+        self._maps = {}
+        self._already_released = 0
+        from pathlib import Path
+        s = Path(file_path).parts
+        if len(s)>2:
+            s = s[-2:]
+            file_path = os.path.join(*s)
+        self.file_path = file_path # os.path.abspath(file_path)
+        self.count = 0
+        key = file_path
+        i = 1
+        while True:
+            if key not in mmm:
+                mmm[key] = self
+                break
+            i +=1
+            key = key + "#" + str(i)
+        self.mmm_key = key
+        # print(f"MMAP Add: {file_path}: {mmm.keys()}")
+
+    def register(self, mmap_obj, map_id, start, size):
+
+        self.count += 1
+        def finalizer(ref):
+            self._already_released += 1
+            if verboseLevel >=2:
+                if self.count == self._already_released:
+                    text =" (all the mmaps have been released)"
+                else:
+                    text =f" ({self.count-self._already_released:} left)"
+
+                print(f"MMap Manager of file '{self.file_path}' : MMap no {map_id} has been released" + text)
+            if self.count == self._already_released:
+                # print(f"MMAP Del: {self.file_path}: {mmm.keys()}")
+                del mmm[self.mmm_key ]
+
+            self._maps.pop(map_id, None)
+
+        wr = weakref.ref(mmap_obj, finalizer)
+        self._maps[map_id] = {
+            'mmap' : wr,
+            'start': start,
+            'size': size,
+            'end': start + size
+        }
+        return wr
+
+    def get_active_maps(self):
+        return dict(self._maps)
+
+class tensor_slice:
+    catalog = None
+    value = None
+    name = None
+
+    def __init__(self, catalog, name, value):
+        self.catalog = catalog
+        self.value = value
+        self.name = name
+
+    def __getitem__(self, s):
+        return self.value[s]
+
+    def get_dtype(self):
+        return self.catalog[self.name]["dtype"]
+
+    def get_shape(self):
+        return self.catalog[self.name]["shape"]
+
+class tensor_stub:
+    dtype = None
+    shape = None
+
+    def __init__(self, dtype, shape):
+        self.dtype = dtype
+        self.shape = tuple(shape)
+
+    @property
+    def ndim(self):
+        return len(self.shape)
+
+    def numel(self):
+        if not self.shape:
+            return 1
+        n = 1
+        for dim in self.shape:
+            n *= int(dim)
+        return n
+
+    @property
+    def device(self):
+        return torch.device("cpu")
+
+class cached_metadata:
+    file_path = None
+    file_length = 0
+    file_date = None
+    catalog = None
+    metadata = None
+    skip_bytes = 0
+
+    def __init__(self, file_path, catalog, metadata, skip_bytes):
+        self.catalog = catalog
+        self.metadata = metadata
+        self.skip_bytes = skip_bytes
+        file_stats = os.stat(file_path)
+        self.file_path = os.path.abspath(file_path)
+        self.file_length = file_stats.st_size
+        self.file_date = file_stats.st_ctime
+
+    def get_metadata(self, file_path):
+        file_stats = os.stat(file_path)
+        file_length = file_stats.st_size
+        file_date = file_stats.st_ctime
+        file_path = os.path.abspath(file_path)
+        if self.file_path != file_path or self.file_length != file_length or self.file_date != file_date:
+            return None, None, None
+        return self.catalog, self.metadata, self.skip_bytes
+
+_cached_entry = None # ideally we should create a dict of the last n entries but one entry covers most cases
+
+def _parse_metadata(metadata):
+    new_metadata= {}
+    if metadata != None:
+        for k,v in metadata.items():
+            if k.endswith("_base64"):
+                v_decoded = json.loads(base64.b64decode(v.encode('utf8')).decode('utf8'))
+                p = k.rfind("_")
+                new_k = k[:p]
+                new_metadata[new_k]= v_decoded
+            else:
+                new_metadata[k] = v
+    if "format" not in new_metadata:
+        new_metadata["format"] = "pt"
+    return new_metadata
+
 def _read_safetensors_header(path, file):
-    global _cached_entry
-    length_of_header_bytes = file.read(8)
-    # Interpret the bytes as a little-endian unsigned 64-bit integer
-    length_of_header = struct.unpack('<Q', length_of_header_bytes)[0]
-
-    if _cached_entry != None:
-        catalog, metadata, _ = _cached_entry.get_metadata(path)
-    else:
-        catalog = None
-
-    if catalog == None:
-        header_bytes = file.read(length_of_header)
-        #catalog = json.loads(header_bytes.decode('utf-8'))
-        catalog = json.loads(header_bytes)
-        metadata = catalog.pop("__metadata__", None)
-        metadata = _parse_metadata(metadata)
-
-        _cached_entry = cached_metadata(path, catalog, metadata,length_of_header )
-    else:
-        file.seek(length_of_header, 1)
-
+    global _cached_entry
+    length_of_header_bytes = file.read(8)
+    # Interpret the bytes as a little-endian unsigned 64-bit integer
+    length_of_header = struct.unpack('<Q', length_of_header_bytes)[0]
+
+    if _cached_entry != None:
+        catalog, metadata, _ = _cached_entry.get_metadata(path)
+    else:
+        catalog = None
+
+    if catalog == None:
+        header_bytes = file.read(length_of_header)
+        #catalog = json.loads(header_bytes.decode('utf-8'))
+        catalog = json.loads(header_bytes)
+        metadata = catalog.pop("__metadata__", None)
+        metadata = _parse_metadata(metadata)
+
+        _cached_entry = cached_metadata(path, catalog, metadata,length_of_header )
+    else:
+        file.seek(length_of_header, 1)
+
     return catalog, metadata, length_of_header + 8
 
-
-def torch_write_file(sd, file_path, quantization_map = None, config = None, extra_meta = None):
-    from collections import OrderedDict
-    sf_sd = OrderedDict()
-
-    map = { torch.bfloat16 : 'BF16' , torch.int64 : 'I64' , torch.int32 : 'I32' , torch.int16 : 'I16' , torch.int8 : 'I8' ,
-            torch.uint64 : 'U64' , torch.uint32 : 'U32' , torch.uint16 : 'U16' , torch.uint8 : 'U8' ,
-            torch.bool : 'BOOL' , torch.float64 : 'F64' , torch.float32 : 'F32' , torch.float16 : 'F16', torch.float8_e5m2 : "F8_E5M2", torch.float8_e4m3fn: "F8_E4M3" }
-    pos = 0
-    i = 0
-    mx = 100000
-    metadata = dict()
-    for k , t in sd.items():
-        if torch.is_tensor(t):
-            entry = {}
-            dtypestr= map[t.dtype]
-            entry["dtype"] = dtypestr
-            entry["shape"] = list(t.shape)
-            size = torch.numel(t) * t.element_size()
-            if size == 0:
-                pass
-            entry["data_offsets"] = [pos, pos + size]
-            pos += size
-            sf_sd[k] = entry
-        else:
-            if isinstance(t, str):
-                metadata[k] = t
-            else:
-                try:
-                    b64 = base64.b64encode(json.dumps(t, ensure_ascii=False).encode('utf8')).decode('utf8')
-                    metadata[k + "_base64"] = b64
-                except:
-                    pass
-
-        i+=1
-        if i==mx:
-            break
-    if not quantization_map is None:
-        metadata["quantization_format"] = "quanto"
-        metadata["quantization_map_base64"] = base64.b64encode(json.dumps(quantization_map, ensure_ascii=False).encode('utf8')).decode('utf8')
-
-    if not config is None:
-        metadata["config_base64"] = base64.b64encode(json.dumps(config, ensure_ascii=False).encode('utf8')).decode('utf8')
-
-    if not extra_meta is None:
-        for n , m in extra_meta.items():
-            if isinstance(m, str):
-                metadata[n] = m
-            else:
-                metadata[n + "_base64"] = base64.b64encode(json.dumps(m, ensure_ascii=False).encode('utf8')).decode('utf8')
-
-
-    if len(metadata) > 0:
-        sf_sd["__metadata__"] = metadata
-
-    header_bytes = json.dumps(sf_sd).encode()
-    #header_bytes =json.dumps(config, ensure_ascii=False).encode('utf8')
-    size_header = len(header_bytes)
-    import struct
-
-    length_of_header_bytes = struct.pack('<Q', size_header)
-
-    with open(file_path, "wb") as writer:
-        bytes_written = writer.write(length_of_header_bytes)
-        bytes_written = writer.write(header_bytes)
-
-        i = 0
-        for k , t in sd.items():
-            if torch.is_tensor(t):
-                size = torch.numel(t) * t.element_size()
-                if size != 0:
-                    dtype = t.dtype
-                    # convert in a friendly format, scalars types not supported by numpy
-                    if dtype == torch.bfloat16:
-                        t = t.view(torch.uint16)
-                    elif dtype == torch.float8_e5m2 or dtype == torch.float8_e4m3fn:
-                        t = t.view(torch.uint8)
-                    buffer = t.numpy().tobytes()
-                    bytes_written = writer.write(buffer)
-                    assert bytes_written == size
-            i+=1
-            if i==mx:
-                break
-
-class SafeTensorFile:
-    """Main class for accessing safetensors files that provides memory-efficient access"""
-
-    def __init__(self, file_path, metadata, catalog, skip_bytes, lazy_loading = True):
-        self._file_path = file_path
-        self._metadata = metadata
-        self._catalog = catalog
-        self._skip_bytes = skip_bytes
-        self._keys = None
-        self.sd = None
-        self.mtracker = None
-        self.lazy_loading = lazy_loading
-
-    @classmethod
-    def load_metadata(cls, file_path, lazy_loading = True):
-        with open(file_path, 'rb') as f:
-            catalog, metadata, skip_bytes = _read_safetensors_header(file_path, f)
-
-        return cls(file_path, metadata, catalog, skip_bytes, lazy_loading)
-
-    def init_tensors(self, lazyTensors = True):
-        if self.sd is None:
-            self.lazy_loading = lazyTensors
-            if lazyTensors:
-                self.sd = self.create_tensors_with_mmap()
-            else:
-                self.sd = self.create_tensors_without_mmap()
-        # else:
-        #     if not self.lazy_loading and lazyTensors:
-        #         raise Exception("Every tensor should be either lazy loaded or not lazy loaded")
-
-        return self.sd
-
-
-    def create_tensors_with_mmap(self):
-
-        self.mtracker = MmapTracker(self._file_path)
-        import mmap
-
-        PAGE_SIZE = mmap.ALLOCATIONGRANULARITY
-        MMAP_SIZE = 1024 * 1024 * 1024 # 1GB
-        # MMAP_SIZE = 256 * 1024 * 1024 # 1GB
-
-        # First pass: find optimal aligned map boundaries
-        skip_bytes = self._skip_bytes
-        tensor_map_indexes = []
-        maps_info = []
-        current_pos = skip_bytes
-        current_map_start = (skip_bytes // PAGE_SIZE) * PAGE_SIZE
-        current_map_size = skip_bytes - current_map_start
-        idx = 0
-        for k,v in self._catalog.items():
-            data_offsets = v["data_offsets"]
-            length = data_offsets[1]-data_offsets[0]
-            if current_map_size + length > MMAP_SIZE:
-                maps_info.append((current_map_start, current_map_size))
-                current_map_start = (current_pos // PAGE_SIZE) * PAGE_SIZE
-                current_map_size = current_pos - current_map_start
-                idx += 1
-            tensor_map_indexes.append(idx)
-            current_map_size += length
-            current_pos += length
-
-        maps_info.append((current_map_start, current_map_size))
-
-        # Second pass: create maps and tensors
-        maps = []
-        sd = OrderedDict()
-
-        current_pos = skip_bytes
-        with open(self._file_path, 'rb') as f:
-            i = 0
-            for map_start, map_size in maps_info:
-                mm = mmap.mmap(f.fileno(), map_size, offset=map_start, access=mmap.ACCESS_COPY) #.ACCESS_READ
-                maps.append((mm, map_start, map_size))
-                self.mtracker.register(mm, i, map_start, map_size)
-                i = i+ 1
-
-            iter_tensor_no = iter(tensor_map_indexes)
-            for k,v in self._catalog.items():
-                dtypestr = v["dtype"]
-                dtype= _map_to_dtype[dtypestr]
-                shape = v["shape"]
-                data_offsets = v["data_offsets"]
-                length = data_offsets[1]-data_offsets[0]
-                map_idx = next(iter_tensor_no)
-                offset = current_pos - maps[map_idx][1]
-                if length == 0:
-                    t = torch.empty(shape, dtype=dtype)
-                elif len(shape) == 0:
-                    # don't waste a memory view for a scalar
-                    t = torch.frombuffer(bytearray(maps[map_idx][0][offset:offset + length]), dtype=torch.uint8)
-                    t = t.view(dtype)
-                else:
-                    mv = memoryview(maps[map_idx][0])[offset:offset + length]
-                    t = torch.frombuffer(mv, dtype=dtype)
-                    t = torch.reshape(t, shape)
-                    # t._mmap = maps[map_idx][0]
-                sd[k] = t
-                current_pos += length
-
-        return sd
-
-
-    def create_tensors_without_mmap(self):
-        sd = OrderedDict()
-
-        with open(self._file_path, 'rb') as f:
-            f.seek(self._skip_bytes, 0)
-            for k,v in self._catalog.items():
-                dtypestr = v["dtype"]
-                dtype= _map_to_dtype[dtypestr]
-                shape = v["shape"]
-                data_offsets = v["data_offsets"]
-                length = data_offsets[1]-data_offsets[0]
-                buffer = f.read(length)
-                if length == 0:
-                    t = torch.empty(0, dtype=dtype)
-                elif len(shape) == 0:
-                    t = torch.frombuffer(bytearray(buffer), dtype=torch.uint8)
-                    t = t.view(dtype)
-                else:
-                    t = torch.frombuffer(bytearray(buffer), dtype=dtype)
-                    t = torch.reshape(t, shape)
-                sd[k] = t
-        return sd
-
-    def get_tensor(self, name: str) -> torch.tensor:
-        """Get a tensor by name"""
-        # To do : switch to a JIT tensor creation per tensor
-        self.init_tensors()
-        return self.sd[name]
-
-    def keys(self) -> List[str]:
-        """Get list of tensor names"""
-        if self._keys is None:
-            self._keys = list(self._catalog)
-        return self._keys
-
-    def names(self) -> List[str]:
-        """Alias for keys()"""
-        return self.keys()
-
-    def tensors(self) -> Dict[str, torch.tensor]:
-        """Get dictionary of all tensors"""
-        self.init_tensors(self.lazy_loading)
-        return self.sd
-
-    def metadata(self) -> Optional[Dict[str, str]]:
-        """Get metadata dictionary"""
-        return self._metadata
-
-    def __len__(self) -> int:
-        """Get number of tensors"""
-        self.init_tensors(self.lazy_loading)
-        return len(self.keys())
-
-    def __contains__(self, key: str) -> bool:
-        """Check if tensor exists"""
-        return key in self.keys()
-
-    def __iter__(self) -> Iterator[Tuple[str, torch.tensor ]]:
-        """Iterate over (name, tensor) pairs"""
-        return ((name, self.get_tensor(name)) for name in self.keys())
-
-    def _free_resources(self):
-        del self.sd
-        del self._catalog
-
-class _SafeTensorLoader:
-    """Context manager for loading SafeTensorFile"""
-
-    def __init__(self, filename: str ):
-        self.filename = Path(filename)
-        self.sft = None
-        if not self.filename.exists():
-            raise FileNotFoundError(f"File not found: {filename}")
-
-    def __enter__(self) -> SafeTensorFile:
-        """Open file and return SafeTensorFile instance"""
-
-        try:
-            self.sft = SafeTensorFile.load_metadata(self.filename)
-            return self.sft
-
-        except Exception as e:
-            self.close()
-            raise Exception(f"Failed to load safetensors file: {e}") from e
-
-    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
-        """Clean up resources"""
-        self.close()
-
-    def close(self) -> None:
-        if self.sft != None:
-            self.sft._free_resources()
-        pass
-
-
-def safe_open(filename: str, framework: str = "pt",device = "cpu") -> _SafeTensorLoader:
-    if device != "cpu" or framework !="pt":
-        return _old_safe_open(filename =filename, framework=framework, device=device)
-    return _SafeTensorLoader(filename)
-
-def torch_load_file( filename, device = 'cpu' ) -> Dict[str, torch.Tensor]:
-    sd = {}
-    with safe_open(filename, framework="pt", device = device ) as f:
-        for k in f.keys():
-            sd[k] = f.get_tensor(k)
-    return sd
 
-
-
-
-
-
-
-
-
-
-
-
+def load_metadata_state_dict(file_path):
+    with open(file_path, 'rb') as f:
+        catalog, metadata, _ = _read_safetensors_header(file_path, f)
+    sd = OrderedDict()
+    for k, v in catalog.items():
+        dtypestr = v["dtype"]
+        dtype = _map_to_dtype.get(dtypestr)
+        if dtype is None:
+            raise KeyError(f"Unknown safetensors dtype '{dtypestr}' in {file_path}")
+        sd[k] = tensor_stub(dtype, v["shape"])
+    return sd, metadata
+
+
+def torch_write_file(sd, file_path, quantization_map = None, config = None, extra_meta = None):
+    from collections import OrderedDict
+    sf_sd = OrderedDict()
+
+    map = { torch.bfloat16 : 'BF16' , torch.int64 : 'I64' , torch.int32 : 'I32' , torch.int16 : 'I16' , torch.int8 : 'I8' ,
+            torch.uint64 : 'U64' , torch.uint32 : 'U32' , torch.uint16 : 'U16' , torch.uint8 : 'U8' ,
+            torch.bool : 'BOOL' , torch.float64 : 'F64' , torch.float32 : 'F32' , torch.float16 : 'F16', torch.float8_e5m2 : "F8_E5M2", torch.float8_e4m3fn: "F8_E4M3" }
+    pos = 0
+    i = 0
+    mx = 100000
+    metadata = dict()
+    for k , t in sd.items():
+        if torch.is_tensor(t):
+            entry = {}
+            dtypestr= map[t.dtype]
+            entry["dtype"] = dtypestr
+            entry["shape"] = list(t.shape)
+            size = torch.numel(t) * t.element_size()
+            if size == 0:
+                pass
+            entry["data_offsets"] = [pos, pos + size]
+            pos += size
+            sf_sd[k] = entry
+        else:
+            if isinstance(t, str):
+                metadata[k] = t
+            else:
+                try:
+                    b64 = base64.b64encode(json.dumps(t, ensure_ascii=False).encode('utf8')).decode('utf8')
+                    metadata[k + "_base64"] = b64
+                except:
+                    pass
+
+        i+=1
+        if i==mx:
+            break
+    if not quantization_map is None:
+        metadata["quantization_format"] = "quanto"
+        metadata["quantization_map_base64"] = base64.b64encode(json.dumps(quantization_map, ensure_ascii=False).encode('utf8')).decode('utf8')
+
+    if not config is None:
+        metadata["config_base64"] = base64.b64encode(json.dumps(config, ensure_ascii=False).encode('utf8')).decode('utf8')
+
+    if not extra_meta is None:
+        for n , m in extra_meta.items():
+            if isinstance(m, str):
+                metadata[n] = m
+            else:
+                metadata[n + "_base64"] = base64.b64encode(json.dumps(m, ensure_ascii=False).encode('utf8')).decode('utf8')
+
+
+    if len(metadata) > 0:
+        sf_sd["__metadata__"] = metadata
+
+    header_bytes = json.dumps(sf_sd).encode()
+    #header_bytes =json.dumps(config, ensure_ascii=False).encode('utf8')
+    size_header = len(header_bytes)
+    import struct
+
+    length_of_header_bytes = struct.pack('<Q', size_header)
+
+    with open(file_path, "wb") as writer:
+        bytes_written = writer.write(length_of_header_bytes)
+        bytes_written = writer.write(header_bytes)
+
+        i = 0
+        for k , t in sd.items():
+            if torch.is_tensor(t):
+                size = torch.numel(t) * t.element_size()
+                if size != 0:
+                    dtype = t.dtype
+                    # convert in a friendly format, scalars types not supported by numpy
+                    if dtype == torch.bfloat16:
+                        t = t.view(torch.uint16)
+                    elif dtype == torch.float8_e5m2 or dtype == torch.float8_e4m3fn:
+                        t = t.view(torch.uint8)
+                    buffer = t.cpu().numpy().tobytes()
+                    bytes_written = writer.write(buffer)
+                    assert bytes_written == size
+            i+=1
+            if i==mx:
+                break
+
+class SafeTensorFile:
+    """Main class for accessing safetensors files that provides memory-efficient access"""
+
+    def __init__(self, file_path, metadata, catalog, skip_bytes, lazy_loading = True, writable_tensors = True):
+        self._file_path = file_path
+        self._metadata = metadata
+        self._catalog = catalog
+        self._skip_bytes = skip_bytes
+        self._keys = None
+        self.sd = None
+        self.mtracker = None
+        self.lazy_loading = lazy_loading
+        self.writable_tensors = writable_tensors
+
+    @classmethod
+    def load_metadata(cls, file_path, lazy_loading = True, writable_tensors = True):
+        with open(file_path, 'rb') as f:
+            catalog, metadata, skip_bytes = _read_safetensors_header(file_path, f)
+
+        return cls(file_path, metadata, catalog, skip_bytes, lazy_loading, writable_tensors )
+
+    def init_tensors(self, lazyTensors = True, writable_tensors = True):
+        if self.sd is None:
+            self.lazy_loading = lazyTensors
+            if lazyTensors:
+                self.sd = self.create_tensors_with_mmap(writable_tensors)
+            else:
+                self.sd = self.create_tensors_without_mmap()
+        # else:
+        #     if not self.lazy_loading and lazyTensors:
+        #         raise Exception("Every tensor should be either lazy loaded or not lazy loaded")
+
+        return self.sd
+
+
+    def create_tensors_with_mmap(self, writable_tensors = True):
+
+        self.mtracker = MmapTracker(self._file_path)
+        import mmap
+
+        PAGE_SIZE = mmap.ALLOCATIONGRANULARITY
+        MMAP_SIZE = 1024 * 1024 * 1024 # 1GB
+        # MMAP_SIZE = 256 * 1024 * 1024 # 1GB
+
+        # First pass: find optimal aligned map boundaries
+        skip_bytes = self._skip_bytes
+        tensor_map_indexes = []
+        maps_info = []
+        current_pos = skip_bytes
+        current_map_start = (skip_bytes // PAGE_SIZE) * PAGE_SIZE
+        current_map_size = skip_bytes - current_map_start
+        idx = 0
+        for k,v in self._catalog.items():
+            data_offsets = v["data_offsets"]
+            length = data_offsets[1]-data_offsets[0]
+            if current_map_size + length > MMAP_SIZE:
+                maps_info.append((current_map_start, current_map_size))
+                current_map_start = (current_pos // PAGE_SIZE) * PAGE_SIZE
+                current_map_size = current_pos - current_map_start
+                idx += 1
+            tensor_map_indexes.append(idx)
+            current_map_size += length
+            current_pos += length
+
+        maps_info.append((current_map_start, current_map_size))
+
+        # Second pass: create maps and tensors
+        maps = []
+        sd = OrderedDict()
+
+        current_pos = skip_bytes
+        with open(self._file_path, 'rb') as f:
+            i = 0
+            for map_start, map_size in maps_info:
+                mm = mmap.mmap(f.fileno(), map_size, offset=map_start, access= mmap.ACCESS_COPY if writable_tensors else mmap.ACCESS_READ)
+                maps.append((mm, map_start, map_size))
+                self.mtracker.register(mm, i, map_start, map_size)
+                i = i+ 1
+
+            iter_tensor_no = iter(tensor_map_indexes)
+            for k,v in self._catalog.items():
+                dtypestr = v["dtype"]
+                dtype= _map_to_dtype[dtypestr]
+                shape = v["shape"]
+                data_offsets = v["data_offsets"]
+                length = data_offsets[1]-data_offsets[0]
+                map_idx = next(iter_tensor_no)
+                offset = current_pos - maps[map_idx][1]
+                if length == 0:
+                    t = torch.empty(shape, dtype=dtype)
+                elif len(shape) == 0:
+                    # don't waste a memory view for a scalar
+                    t = torch.frombuffer(bytearray(maps[map_idx][0][offset:offset + length]), dtype=torch.uint8)
+                    t = t.view(dtype)
+                else:
+                    mv = memoryview(maps[map_idx][0])[offset:offset + length]
+                    t = torch.frombuffer(mv, dtype=dtype)
+                    t = torch.reshape(t, shape)
+                    # t._mmap = maps[map_idx][0]
+                sd[k] = t
+                current_pos += length
+
+        return sd
+
+
+    def create_tensors_without_mmap(self):
+        sd = OrderedDict()
+
+        with open(self._file_path, 'rb') as f:
+            f.seek(self._skip_bytes, 0)
+            for k,v in self._catalog.items():
+                dtypestr = v["dtype"]
+                dtype= _map_to_dtype[dtypestr]
+                shape = v["shape"]
+                data_offsets = v["data_offsets"]
+                length = data_offsets[1]-data_offsets[0]
+                buffer = f.read(length)
+                if length == 0:
+                    t = torch.empty(0, dtype=dtype)
+                elif len(shape) == 0:
+                    t = torch.frombuffer(bytearray(buffer), dtype=torch.uint8)
+                    t = t.view(dtype)
+                else:
+                    t = torch.frombuffer(bytearray(buffer), dtype=dtype)
+                    t = torch.reshape(t, shape)
+                sd[k] = t
+        return sd
+
+    def get_slice(self, name: str) -> torch.tensor:
+        return tensor_slice(self._catalog, name, self.get_tensor(name))
+
+    def get_tensor(self, name: str) -> torch.tensor:
+        """Get a tensor by name"""
+        # To do : switch to a JIT tensor creation per tensor
+        self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
+        return self.sd[name]
+
+    def keys(self) -> List[str]:
+        """Get list of tensor names"""
+        if self._keys is None:
+            self._keys = list(self._catalog)
+        return self._keys
+
+    def names(self) -> List[str]:
+        """Alias for keys()"""
+        return self.keys()
+
+    def tensors(self) -> Dict[str, torch.tensor]:
+        """Get dictionary of all tensors"""
+        self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
+        return self.sd
+
+    def metadata(self) -> Optional[Dict[str, str]]:
+        """Get metadata dictionary"""
+        return self._metadata
+
+    def __len__(self) -> int:
+        """Get number of tensors"""
+        self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
+        return len(self.keys())
+
+    def __contains__(self, key: str) -> bool:
+        """Check if tensor exists"""
+        return key in self.keys()
+
+    def __iter__(self) -> Iterator[Tuple[str, torch.tensor ]]:
+        """Iterate over (name, tensor) pairs"""
+        return ((name, self.get_tensor(name)) for name in self.keys())
+
+    def _free_resources(self):
+        del self.sd
+        del self._catalog
+
+class _SafeTensorLoader:
+    """Context manager for loading SafeTensorFile"""
+
+    def __init__(self, filename: str, writable_tensors = True ):
+        self.filename = Path(filename)
+        self.writable_tensors = writable_tensors
+        self.sft = None
+        if not self.filename.exists():
+            raise FileNotFoundError(f"File not found: {filename}")
+
+    def __enter__(self) -> SafeTensorFile:
+        """Open file and return SafeTensorFile instance"""
+        writable_tensors = self.writable_tensors
+
+        if all_tensors_are_read_only:
+            writable_tensors = False
+
+        try:
+            self.sft = SafeTensorFile.load_metadata(self.filename, writable_tensors= writable_tensors)
+            return self.sft
+
+        except Exception as e:
+            self.close()
+            raise Exception(f"Failed to load safetensors file: {e}") from e
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        """Clean up resources"""
+        self.close()
+
+    def get_tensor(self, name):
+        if self.sft == None:
+            self.__enter__()
+        return self.sft.get_tensor(name)
+
+    def get_slice(self, name):
+        if self.sft == None:
+            self.__enter__()
+        return self.sft.get_slice(name)
+
+    def close(self) -> None:
+        if self.sft != None:
+            self.sft._free_resources()
+        pass
+
+
+def safe_open(filename: str, framework: str = "pt",device = "cpu", writable_tensors = True) -> _SafeTensorLoader:
+    if device != "cpu" or framework !="pt":
+        return _old_safe_open(filename =filename, framework=framework, device=device)
+    return _SafeTensorLoader(filename, writable_tensors = writable_tensors)
+
+def torch_load_file( filename, device = 'cpu', writable_tensors = True) -> Dict[str, torch.Tensor]:
+    sd = {}
+    with safe_open(filename, framework="pt", device = device, writable_tensors =writable_tensors ) as f:
+        for k in f.keys():
+            sd[k] = f.get_tensor(k)
+    return sd
+
+_old_torch_load_file = safetensors.torch.load_file
+safetensors.torch.load_file = torch_load_file
+_old_safe_open = safetensors.safe_open
+safetensors.safe_open = safe_open
+accelerate.utils.modeling.safe_open = safe_open
+accelerate.utils.modeling.safe_load_file = torch_load_file
+try:
+    import transformers
+    transformers.modeling_utils.safe_open = safe_open
+    transformers.modeling_utils.safe_load_file = torch_load_file
+except:
+    pass
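
A minimal usage sketch of the 3.6.11 API shown in the diff above (assuming mmgp 3.6.11 is installed; "model.safetensors" is a placeholder path). It exercises only calls that appear in the diff: safe_open with the new writable_tensors flag, torch_load_file, and the new load_metadata_state_dict helper.

    from mmgp import safetensors2

    # Lazy, mmap-backed access. writable_tensors=False maps the file with
    # mmap.ACCESS_READ instead of mmap.ACCESS_COPY, so pages stay shared
    # and read-only instead of being copied on write.
    with safetensors2.safe_open("model.safetensors", framework="pt", device="cpu",
                                writable_tensors=False) as f:
        meta = f.metadata()        # "_base64" entries arrive already decoded by _parse_metadata
        names = f.keys()
        t = f.get_tensor(names[0])

    # Whole-file load into a plain dict, same flag:
    sd = safetensors2.torch_load_file("model.safetensors", writable_tensors=False)

    # Header-only inspection: returns tensor_stub objects (dtype, shape, numel())
    # built from the JSON header, without reading any tensor data.
    stubs, meta = safetensors2.load_metadata_state_dict("model.safetensors")
    for name, stub in stubs.items():
        print(name, stub.dtype, stub.shape, stub.numel())

Because the module reassigns safetensors.safe_open, safetensors.torch.load_file and the accelerate/transformers loaders at import time, libraries going through those entry points pick up the low-RAM loader automatically; setting safetensors2.all_tensors_are_read_only = True forces writable_tensors=False for every subsequent safe_open, per _SafeTensorLoader.__enter__.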