safetensors 0.6.2__cp38-abi3-win_amd64.whl → 0.7.0rc0__cp38-abi3-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of safetensors might be problematic. Click here for more details.

Binary file
safetensors/numpy.py CHANGED
@@ -154,6 +154,7 @@ _TYPES = {
154
154
  "I8": np.int8,
155
155
  "U8": np.uint8,
156
156
  "BOOL": bool,
157
+ "C64": np.complex64,
157
158
  }
158
159
 
159
160
 
safetensors/paddle.py CHANGED
@@ -1,10 +1,11 @@
1
1
  import os
2
- from typing import Dict, Optional, Union
2
+ import sys
3
+ from typing import Any, Dict, Optional, Union
3
4
 
4
5
  import numpy as np
5
-
6
6
  import paddle
7
- from safetensors import numpy
7
+
8
+ from safetensors import numpy, deserialize, safe_open, serialize, serialize_file
8
9
 
9
10
 
10
11
  def save(
@@ -34,8 +35,9 @@ def save(
34
35
  byte_data = save(tensors)
35
36
  ```
36
37
  """
37
- np_tensors = _paddle2np(tensors)
38
- return numpy.save(np_tensors, metadata=metadata)
38
+ serialized = serialize(_flatten(tensors), metadata=metadata)
39
+ result = bytes(serialized)
40
+ return result
39
41
 
40
42
 
41
43
  def save_file(
@@ -69,8 +71,7 @@ def save_file(
69
71
  save_file(tensors, "model.safetensors")
70
72
  ```
71
73
  """
72
- np_tensors = _paddle2np(tensors)
73
- return numpy.save_file(np_tensors, filename, metadata=metadata)
74
+ serialize_file(_flatten(tensors), filename, metadata=metadata)
74
75
 
75
76
 
76
77
  def load(data: bytes, device: str = "cpu") -> Dict[str, paddle.Tensor]:
@@ -96,8 +97,12 @@ def load(data: bytes, device: str = "cpu") -> Dict[str, paddle.Tensor]:
96
97
  loaded = load(data)
97
98
  ```
98
99
  """
99
- flat = numpy.load(data)
100
- return _np2paddle(flat, device)
100
+ if paddle.__version__ >= "3.2.0":
101
+ flat = deserialize(data)
102
+ return _view2paddle(flat, device)
103
+ else:
104
+ flat = numpy.load(data)
105
+ return _np2paddle(flat, device)
101
106
 
102
107
 
103
108
  def load_file(
@@ -125,9 +130,15 @@ def load_file(
125
130
  loaded = load_file(file_path)
126
131
  ```
127
132
  """
128
- flat = numpy.load_file(filename)
129
- output = _np2paddle(flat, device)
130
- return output
133
+ result = {}
134
+ if paddle.__version__ >= "3.2.0":
135
+ with safe_open(filename, framework="paddle", device=device) as f:
136
+ for k in f.offset_keys():
137
+ result[k] = f.get_tensor(k)
138
+ else:
139
+ flat = numpy.load_file(filename)
140
+ result = _np2paddle(flat, device)
141
+ return result
131
142
 
132
143
 
133
144
  def _np2paddle(
@@ -142,3 +153,138 @@ def _paddle2np(paddle_dict: Dict[str, paddle.Tensor]) -> Dict[str, np.array]:
142
153
  for k, v in paddle_dict.items():
143
154
  paddle_dict[k] = v.detach().cpu().numpy()
144
155
  return paddle_dict
156
+
157
+
158
+ _SIZE = {
159
+ paddle.int64: 8,
160
+ paddle.float32: 4,
161
+ paddle.int32: 4,
162
+ paddle.bfloat16: 2,
163
+ paddle.float16: 2,
164
+ paddle.int16: 2,
165
+ paddle.uint8: 1,
166
+ paddle.int8: 1,
167
+ paddle.bool: 1,
168
+ paddle.float64: 8,
169
+ paddle.float8_e4m3fn: 1,
170
+ paddle.float8_e5m2: 1,
171
+ paddle.complex64: 8,
172
+ # XXX: These are not supported yet in paddle
173
+ # paddle.uint64: 8,
174
+ # paddle.uint32: 4,
175
+ # paddle.uint16: 2,
176
+ # paddle.float8_e8m0: 1,
177
+ # paddle.float4_e2m1_x2: 1,
178
+ }
179
+
180
+ _TYPES = {
181
+ "F64": paddle.float64,
182
+ "F32": paddle.float32,
183
+ "F16": paddle.float16,
184
+ "BF16": paddle.bfloat16,
185
+ "I64": paddle.int64,
186
+ "I32": paddle.int32,
187
+ "I16": paddle.int16,
188
+ "I8": paddle.int8,
189
+ "U8": paddle.uint8,
190
+ "BOOL": paddle.bool,
191
+ "F8_E4M3": paddle.float8_e4m3fn,
192
+ "F8_E5M2": paddle.float8_e5m2,
193
+ }
194
+
195
+ NPDTYPES = {
196
+ paddle.int64: np.int64,
197
+ paddle.float32: np.float32,
198
+ paddle.int32: np.int32,
199
+ # XXX: This is ok because both have the same width
200
+ paddle.bfloat16: np.float16,
201
+ paddle.float16: np.float16,
202
+ paddle.int16: np.int16,
203
+ paddle.uint8: np.uint8,
204
+ paddle.int8: np.int8,
205
+ paddle.bool: bool,
206
+ paddle.float64: np.float64,
207
+ # XXX: This is ok because both have the same width and byteswap is a no-op anyway
208
+ paddle.float8_e4m3fn: np.uint8,
209
+ paddle.float8_e5m2: np.uint8,
210
+ }
211
+
212
+
213
+ def _getdtype(dtype_str: str) -> paddle.dtype:
214
+ return _TYPES[dtype_str]
215
+
216
+
217
+ def _view2paddle(safeview, device) -> Dict[str, paddle.Tensor]:
218
+ result = {}
219
+ for k, v in safeview:
220
+ dtype = _getdtype(v["dtype"])
221
+ if len(v["data"]) == 0:
222
+ # Workaround because frombuffer doesn't accept zero-size tensors
223
+ assert any(x == 0 for x in v["shape"])
224
+ arr = paddle.empty(v["shape"], dtype=dtype)
225
+ else:
226
+ arr = paddle.base.core.frombuffer(v["data"], dtype).reshape(v["shape"])
227
+ if device != "cpu":
228
+ arr = arr.to(device)
229
+ if sys.byteorder == "big":
230
+ arr = paddle.to_tensor(arr.numpy().byteswap(inplace=False), place=device)
231
+ result[k] = arr
232
+
233
+ return result
234
+
235
+
236
+ def _tobytes(tensor: paddle.Tensor, name: str) -> bytes:
237
+ if not tensor.is_contiguous():
238
+ raise ValueError(
239
+ f"You are trying to save a non contiguous tensor: `{name}` which is not allowed. It either means you"
240
+ " are trying to save tensors which are reference of each other in which case it's recommended to save"
241
+ " only the full tensors, and reslice at load time, or simply call `.contiguous()` on your tensor to"
242
+ " pack it before saving."
243
+ )
244
+ if not tensor.place.is_cpu_place():
245
+ # Moving tensor to cpu before saving
246
+ tensor = tensor.cpu()
247
+
248
+ import ctypes
249
+
250
+ import numpy as np
251
+
252
+ # When shape is empty (scalar), np.prod returns a float
253
+ # we need a int for the following calculations
254
+ length = int(np.prod(tensor.shape).item())
255
+ bytes_per_item = _SIZE[tensor.dtype]
256
+
257
+ total_bytes = length * bytes_per_item
258
+
259
+ ptr = tensor.data_ptr()
260
+ if ptr == 0:
261
+ return b""
262
+ newptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_ubyte))
263
+ data = np.ctypeslib.as_array(newptr, (total_bytes,)) # no internal copy
264
+ if sys.byteorder == "big":
265
+ npdtype = NPDTYPES[tensor.dtype]
266
+ # Not in place as that would potentially modify a live running model
267
+ data = data.view(npdtype).byteswap(inplace=False)
268
+ return data.tobytes()
269
+
270
+
271
+ def _flatten(tensors: Dict[str, paddle.Tensor]) -> Dict[str, Dict[str, Any]]:
272
+ if not isinstance(tensors, dict):
273
+ raise ValueError(
274
+ f"Expected a dict of [str, paddle.Tensor] but received {type(tensors)}"
275
+ )
276
+
277
+ for k, v in tensors.items():
278
+ if not isinstance(v, paddle.Tensor):
279
+ raise ValueError(
280
+ f"Key `{k}` is invalid, expected paddle.Tensor but received {type(v)}"
281
+ )
282
+
283
+ return {
284
+ k: {
285
+ "dtype": str(v.dtype).split(".")[-1],
286
+ "shape": v.shape,
287
+ "data": _tobytes(v, k),
288
+ }
289
+ for k, v in tensors.items()
290
+ }
safetensors/torch.py CHANGED
@@ -221,68 +221,23 @@ def load_model(
221
221
  to_removes = _remove_duplicate_names(
222
222
  model_state_dict, preferred_names=state_dict.keys()
223
223
  )
224
-
225
- reverse_to_remove = {}
226
- for key, to_remove_group in to_removes.items():
227
- for to_remove in to_remove_group:
228
- reverse_to_remove[to_remove] = key
229
-
230
- # We iterate on the model, so we'll add keys we find missing
231
- # here
232
- missing = set()
233
- # We start with all keys on disk declared as unexpected, we'll
234
- # slowly remove them when we find them
235
- unexpected = set(state_dict.keys())
236
- # Some keys can be invalid too.
237
- invalid = set()
238
-
239
- for k, mv in model_state_dict.items():
240
- actual_k = reverse_to_remove.get(k, None)
241
- if actual_k is not None:
242
- look_k = actual_k
243
- else:
244
- look_k = k
245
- v = state_dict.get(look_k, None)
246
- if v is None:
247
- missing.add(k)
248
- else:
249
- # We can actually check for the shapes while we're at it.
250
- # For the device, it's trickier given torch's internals
251
- # There might be some Meta device for faster initiation
252
- if v.dtype != mv.dtype or v.shape != mv.shape:
253
- invalid.add(k)
254
- if actual_k is None:
255
- unexpected.remove(k)
256
-
224
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
257
225
  missing = set(missing)
258
- unexpected = set(unexpected)
259
- if strict and (missing or unexpected or invalid):
226
+ for to_remove_group in to_removes.values():
227
+ for to_remove in to_remove_group:
228
+ if to_remove not in missing:
229
+ unexpected.append(to_remove)
230
+ else:
231
+ missing.remove(to_remove)
232
+ if strict and (missing or unexpected):
260
233
  missing_keys = ", ".join([f'"{k}"' for k in sorted(missing)])
261
234
  unexpected_keys = ", ".join([f'"{k}"' for k in sorted(unexpected)])
262
- invalid_keys = ", ".join([f'"{k}"' for k in sorted(invalid)])
263
235
  error = f"Error(s) in loading state_dict for {model.__class__.__name__}:"
264
236
  if missing:
265
237
  error += f"\n Missing key(s) in state_dict: {missing_keys}"
266
238
  if unexpected:
267
239
  error += f"\n Unexpected key(s) in state_dict: {unexpected_keys}"
268
- if invalid:
269
- error += f"\n Invalid key(s) in state_dict: {invalid_keys}, mismatched dtypes or shape."
270
- del state_dict
271
240
  raise RuntimeError(error)
272
-
273
- torch_missing, torch_unexpected = model.load_state_dict(state_dict, strict=False)
274
- # Sanity check that the work we've done matches
275
- # Pytorch internal loading.
276
- torch_missing = set(torch_missing)
277
- torch_unexpected = set(torch_unexpected)
278
- for to_remove_group in to_removes.values():
279
- for to_remove in to_remove_group:
280
- if to_remove not in torch_missing:
281
- torch_unexpected.add(to_remove)
282
- else:
283
- torch_missing.remove(to_remove)
284
- assert torch_missing == missing, f"{torch_missing} != {missing}"
285
- assert torch_unexpected == unexpected, f"{torch_unexpected} != {unexpected}"
286
241
  return missing, unexpected
287
242
 
288
243
 
@@ -428,6 +383,7 @@ _SIZE = {
428
383
  torch.int8: 1,
429
384
  torch.bool: 1,
430
385
  torch.float64: 8,
386
+ torch.complex64: 8,
431
387
  _float8_e4m3fn: 1,
432
388
  _float8_e5m2: 1,
433
389
  _float8_e8m0: 1,
@@ -455,6 +411,7 @@ _TYPES = {
455
411
  "BOOL": torch.bool,
456
412
  "F8_E4M3": _float8_e4m3fn,
457
413
  "F8_E5M2": _float8_e5m2,
414
+ "C64": torch.complex64,
458
415
  }
459
416
  if Version(torch.__version__) >= Version("2.3.0"):
460
417
  _TYPES.update(
@@ -538,6 +495,7 @@ def _tobytes(tensor: torch.Tensor, name: str) -> bytes:
538
495
  # XXX: This is ok because both have the same width and byteswap is a no-op anyway
539
496
  _float8_e4m3fn: np.uint8,
540
497
  _float8_e5m2: np.uint8,
498
+ torch.complex64: np.complex64,
541
499
  }
542
500
  npdtype = NPDTYPES[tensor.dtype]
543
501
  # Not in place as that would potentially modify a live running model
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: safetensors
3
- Version: 0.6.2
3
+ Version: 0.7.0rc0
4
4
  Classifier: Development Status :: 5 - Production/Stable
5
5
  Classifier: Intended Audience :: Developers
6
6
  Classifier: Intended Audience :: Education
@@ -15,6 +15,7 @@ Classifier: Programming Language :: Python :: 3.10
15
15
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
16
  Classifier: Typing :: Typed
17
17
  Requires-Dist: numpy>=1.21.6 ; extra == 'numpy'
18
+ Requires-Dist: packaging ; extra == 'torch'
18
19
  Requires-Dist: safetensors[numpy] ; extra == 'torch'
19
20
  Requires-Dist: torch>=1.10 ; extra == 'torch'
20
21
  Requires-Dist: safetensors[numpy] ; extra == 'tensorflow'
@@ -0,0 +1,14 @@
1
+ safetensors-0.7.0rc0.dist-info/METADATA,sha256=Q7QxCj4WdwYSXi6cJQ2xyMMj1Mtiz0NBsR4n7As7WWs,4188
2
+ safetensors-0.7.0rc0.dist-info/WHEEL,sha256=CG8OzNtm0LMpJ2zhrjswlO8N-965OeMLklsQAG-nMvQ,94
3
+ safetensors-0.7.0rc0.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
4
+ safetensors/__init__.py,sha256=HYY5VVsb3b-cxhZBwhNx53ZKqSIB4M14nIXLTOAM1Wc,204
5
+ safetensors/__init__.pyi,sha256=tnVaPqYbh8ggFbOZdYKUC4ArqitiWDfrIQt1BNJ377k,4183
6
+ safetensors/_safetensors_rust.pyd,sha256=SGfYnZXx8FAAnnTG6tfkE5Icp9Fvzo_nABy-BV55SZ8,736768
7
+ safetensors/flax.py,sha256=SnuiGojmth0eCFIWoKEvAfh95nZP9uCZ9E-S4NndrbU,3991
8
+ safetensors/mlx.py,sha256=KvfTWusLSx1hSPWQgg99iL-z9VoD6zQ8l4-RAsCe7P8,3990
9
+ safetensors/numpy.py,sha256=8ci56gDXetlYHH1-Nru83auiUVi-Q1P9bKvfsdkLKPw,5215
10
+ safetensors/paddle.py,sha256=EhXpflqrhKr_NFh4jxV9SUnW0B1vcX_KdPdTqcytrDs,9011
11
+ safetensors/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
+ safetensors/tensorflow.py,sha256=DajI3qkz00Zy2h7jublSAvTaD51QOPdaIgKQIeSiCRs,4042
13
+ safetensors/torch.py,sha256=CLVWgWQdLm_tVzhRPaeihBHt-4iGAtUW5fY2ys3TyMc,19160
14
+ safetensors-0.7.0rc0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: maturin (1.9.3)
2
+ Generator: maturin (1.9.6)
3
3
  Root-Is-Purelib: false
4
4
  Tag: cp38-abi3-win_amd64
@@ -1,14 +0,0 @@
1
- safetensors-0.6.2.dist-info/METADATA,sha256=kkR-LvpyTCvbdGfZ-pH_IiPF4JNP81N9GPrYjY41BDo,4141
2
- safetensors-0.6.2.dist-info/WHEEL,sha256=lvaVdaNOIbpDjZxhxQcXMmDSpIrmQUI6MiaH-nloUu8,94
3
- safetensors-0.6.2.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
4
- safetensors/__init__.py,sha256=HYY5VVsb3b-cxhZBwhNx53ZKqSIB4M14nIXLTOAM1Wc,204
5
- safetensors/__init__.pyi,sha256=tnVaPqYbh8ggFbOZdYKUC4ArqitiWDfrIQt1BNJ377k,4183
6
- safetensors/_safetensors_rust.pyd,sha256=05luNkOSehiTNRXXrg4CjBd644FvderG0Xw6TZJPrMc,704000
7
- safetensors/flax.py,sha256=SnuiGojmth0eCFIWoKEvAfh95nZP9uCZ9E-S4NndrbU,3991
8
- safetensors/mlx.py,sha256=KvfTWusLSx1hSPWQgg99iL-z9VoD6zQ8l4-RAsCe7P8,3990
9
- safetensors/numpy.py,sha256=MaUhU4V3J4nDjLmoy0OdeTc6JC8Dq2PBHmnxjQU2bfQ,5189
10
- safetensors/paddle.py,sha256=B8TLF5MFeqeipUxynSOB_NqPNypL_dE1C3vihtEWj0A,4337
11
- safetensors/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- safetensors/tensorflow.py,sha256=DajI3qkz00Zy2h7jublSAvTaD51QOPdaIgKQIeSiCRs,4042
13
- safetensors/torch.py,sha256=wxR0dOwmZBHxv5R5WTmSH4NLOXmvH9JGeVD_uImDlew,20851
14
- safetensors-0.6.2.dist-info/RECORD,,