safetensors-0.6.0rc0-cp38-abi3-win32.whl → safetensors-0.6.1rc0-cp38-abi3-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of safetensors might be problematic.

safetensors/__init__.py CHANGED
@@ -4,6 +4,7 @@ from ._safetensors_rust import ( # noqa: F401
     __version__,
     deserialize,
     safe_open,
+    _safe_open_handle,
     serialize,
     serialize_file,
 )
safetensors/__init__.pyi CHANGED
@@ -49,7 +49,7 @@ def serialize_file(tensor_dict, filename, metadata=None):
 
     Returns:
         (`NoneType`):
-            On success return None.
+            On success return None
     """
     pass
 
@@ -68,19 +68,21 @@ class safe_open:
         device (`str`, defaults to `"cpu"`):
             The device on which you want the tensors.
     """
-
     def __init__(self, filename, framework, device=...):
         pass
+
     def __enter__(self):
         """
         Start the context manager
         """
         pass
+
     def __exit__(self, _exc_type, _exc_value, _traceback):
         """
         Exits the context manager
         """
         pass
+
     def get_slice(self, name):
         """
         Returns a full slice view object
@@ -102,6 +104,7 @@ class safe_open:
         ```
         """
         pass
+
     def get_tensor(self, name):
         """
         Returns a full tensor
@@ -124,6 +127,7 @@ class safe_open:
         ```
         """
         pass
+
     def keys(self):
         """
         Returns the names of the tensors in the file.
@@ -133,6 +137,7 @@ class safe_open:
             The name of the tensors contained in that file
         """
         pass
+
     def metadata(self):
         """
         Return the special non tensor information in the header
@@ -143,6 +148,16 @@ class safe_open:
         """
         pass
 
+    def offset_keys(self):
+        """
+        Returns the names of the tensors in the file, ordered by offset.
+
+        Returns:
+            (`List[str]`):
+                The name of the tensors contained in that file
+        """
+        pass
+
 class SafetensorError(Exception):
     """
     Custom Python Exception for Safetensor errors.
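Note: the only functional addition in this stub is `safe_open.offset_keys()`. A minimal usage sketch, assuming a local `model.safetensors` file (the file name and framework choice are illustrative, not taken from the diff):

```python
# Hedged sketch of the new offset_keys() API in 0.6.1rc0.
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    # keys() returns the tensor names; offset_keys() returns the same names
    # ordered by their byte offset in the file, so iterating in this order
    # reads the file front to back.
    for name in f.offset_keys():
        tensor = f.get_tensor(name)
```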
safetensors/_safetensors_rust.pyd CHANGED
Binary file
safetensors/mlx.py CHANGED
@@ -7,7 +7,9 @@ import mlx.core as mx
 from safetensors import numpy, safe_open
 
 
-def save(tensors: Dict[str, mx.array], metadata: Optional[Dict[str, str]] = None) -> bytes:
+def save(
+    tensors: Dict[str, mx.array], metadata: Optional[Dict[str, str]] = None
+) -> bytes:
     """
     Saves a dictionary of tensors into raw bytes in safetensors format.
 
safetensors/numpy.py CHANGED
@@ -13,7 +13,9 @@ def _tobytes(tensor: np.ndarray) -> bytes:
     return tensor.tobytes()
 
 
-def save(tensor_dict: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]] = None) -> bytes:
+def save(
+    tensor_dict: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]] = None
+) -> bytes:
     """
     Saves a dictionary of tensors into raw bytes in safetensors format.
 
@@ -38,14 +40,19 @@ def save(tensor_dict: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]]
     byte_data = save(tensors)
     ```
     """
-    flattened = {k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)} for k, v in tensor_dict.items()}
+    flattened = {
+        k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)}
+        for k, v in tensor_dict.items()
+    }
     serialized = serialize(flattened, metadata=metadata)
     result = bytes(serialized)
     return result
 
 
 def save_file(
-    tensor_dict: Dict[str, np.ndarray], filename: Union[str, os.PathLike], metadata: Optional[Dict[str, str]] = None
+    tensor_dict: Dict[str, np.ndarray],
+    filename: Union[str, os.PathLike],
+    metadata: Optional[Dict[str, str]] = None,
 ) -> None:
     """
     Saves a dictionary of tensors into raw bytes in safetensors format.
@@ -73,7 +80,10 @@ def save_file(
     save_file(tensors, "model.safetensors")
     ```
     """
-    flattened = {k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)} for k, v in tensor_dict.items()}
+    flattened = {
+        k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)}
+        for k, v in tensor_dict.items()
+    }
     serialize_file(flattened, filename, metadata=metadata)
 
 
safetensors/paddle.py CHANGED
@@ -7,7 +7,9 @@ import paddle
 from safetensors import numpy
 
 
-def save(tensors: Dict[str, paddle.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes:
+def save(
+    tensors: Dict[str, paddle.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> bytes:
     """
     Saves a dictionary of tensors into raw bytes in safetensors format.
 
@@ -98,7 +100,9 @@ def load(data: bytes, device: str = "cpu") -> Dict[str, paddle.Tensor]:
     return _np2paddle(flat, device)
 
 
-def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, paddle.Tensor]:
+def load_file(
+    filename: Union[str, os.PathLike], device="cpu"
+) -> Dict[str, paddle.Tensor]:
     """
     Loads a safetensors file into paddle format.
 
@@ -126,7 +130,9 @@ def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, padd
     return output
 
 
-def _np2paddle(numpy_dict: Dict[str, np.ndarray], device: str = "cpu") -> Dict[str, paddle.Tensor]:
+def _np2paddle(
+    numpy_dict: Dict[str, np.ndarray], device: str = "cpu"
+) -> Dict[str, paddle.Tensor]:
     for k, v in numpy_dict.items():
         numpy_dict[k] = paddle.to_tensor(v, place=device)
     return numpy_dict
safetensors/tensorflow.py CHANGED
@@ -7,7 +7,9 @@ import tensorflow as tf
 from safetensors import numpy, safe_open
 
 
-def save(tensors: Dict[str, tf.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes:
+def save(
+    tensors: Dict[str, tf.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> bytes:
     """
     Saves a dictionary of tensors into raw bytes in safetensors format.
 
safetensors/torch.py CHANGED
@@ -2,6 +2,7 @@ import os
 import sys
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from packaging.version import Version
 
 import torch
 
@@ -41,7 +42,9 @@ def storage_size(tensor: torch.Tensor) -> int:
     return tensor.nelement() * _SIZE[tensor.dtype]
 
 
-def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
+def _filter_shared_not_shared(
+    tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
+) -> List[Set[str]]:
     filtered_tensors = []
     for shared in tensors:
         if len(shared) < 2:
@@ -69,7 +72,11 @@ def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, tor
 def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
     tensors = defaultdict(set)
     for k, v in state_dict.items():
-        if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
+        if (
+            v.device != torch.device("meta")
+            and storage_ptr(v) != 0
+            and storage_size(v) != 0
+        ):
             # Need to add device as key because of multiple GPU.
             tensors[(v.device, storage_ptr(v), storage_size(v))].add(k)
     tensors = list(sorted(tensors.values()))
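Note: the wrapped condition above is behavior-preserving. For context, a small sketch of what `_find_shared_tensors` (a private helper in this file) groups — two views over one storage land in the same group, which is how duplicates are detected before saving:

```python
# Illustration only: views that share storage form one group.
import torch
from safetensors.torch import _find_shared_tensors  # private helper shown in this diff

base = torch.zeros(4)
view = base[:2]  # shares base's underlying storage
groups = _find_shared_tensors({"base": base, "view": view})
assert groups == [{"base", "view"}]
```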
@@ -78,7 +85,9 @@ def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
 
 
 def _is_complete(tensor: torch.Tensor) -> bool:
-    return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _SIZE[tensor.dtype] == storage_size(tensor)
+    return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _SIZE[
+        tensor.dtype
+    ] == storage_size(tensor)
 
 
 def _remove_duplicate_names(
@@ -97,7 +106,9 @@ def _remove_duplicate_names(
     shareds = _find_shared_tensors(state_dict)
     to_remove = defaultdict(list)
     for shared in shareds:
-        complete_names = set([name for name in shared if _is_complete(state_dict[name])])
+        complete_names = set(
+            [name for name in shared if _is_complete(state_dict[name])]
+        )
         if not complete_names:
             raise RuntimeError(
                 "Error while trying to find names to remove to save state dict, but found no suitable name to keep"
@@ -207,7 +218,9 @@ def load_model(
     """
     state_dict = load_file(filename, device=device)
     model_state_dict = model.state_dict()
-    to_removes = _remove_duplicate_names(model_state_dict, preferred_names=state_dict.keys())
+    to_removes = _remove_duplicate_names(
+        model_state_dict, preferred_names=state_dict.keys()
+    )
 
     reverse_to_remove = {}
     for key, to_remove_group in to_removes.items():
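Note: the call above is a pure re-wrap of `_remove_duplicate_names(...)` inside `load_model`. A hedged round-trip sketch of the public entry points it serves (the file name is illustrative):

```python
# save_model/load_model handle shared tensors and return (missing, unexpected).
import torch
from safetensors.torch import load_model, save_model

model = torch.nn.Linear(4, 4)
save_model(model, "linear.safetensors")
missing, unexpected = load_model(model, "linear.safetensors", device="cpu")
assert not missing and not unexpected
```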
@@ -273,7 +286,9 @@ def load_model(
     return missing, unexpected
 
 
-def save(tensors: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes:
+def save(
+    tensors: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> bytes:
     """
     Saves a dictionary of tensors into raw bytes in safetensors format.
 
@@ -337,7 +352,9 @@ def save_file(
     serialize_file(_flatten(tensors), filename, metadata=metadata)
 
 
-def load_file(filename: Union[str, os.PathLike], device: Union[str, int] = "cpu") -> Dict[str, torch.Tensor]:
+def load_file(
+    filename: Union[str, os.PathLike], device: Union[str, int] = "cpu"
+) -> Dict[str, torch.Tensor]:
     """
     Loads a safetensors file into torch format.
 
@@ -416,6 +433,14 @@ _SIZE = {
     _float8_e8m0: 1,
     _float4_e2m1_x2: 1,
 }
+if Version(torch.__version__) > Version("2.0.0"):
+    _SIZE.update(
+        {
+            torch.uint64: 8,
+            torch.uint32: 4,
+            torch.uint16: 2,
+        }
+    )
 
 _TYPES = {
     "F64": torch.float64,
@@ -423,17 +448,22 @@ _TYPES = {
     "F16": torch.float16,
     "BF16": torch.bfloat16,
     "I64": torch.int64,
-    # "U64": torch.uint64,
     "I32": torch.int32,
-    # "U32": torch.uint32,
     "I16": torch.int16,
-    # "U16": torch.uint16,
     "I8": torch.int8,
     "U8": torch.uint8,
     "BOOL": torch.bool,
     "F8_E4M3": _float8_e4m3fn,
     "F8_E5M2": _float8_e5m2,
 }
+if Version(torch.__version__) > Version("2.0.0"):
+    _TYPES.update(
+        {
+            "U64": torch.uint64,
+            "U32": torch.uint32,
+            "U16": torch.uint16,
+        }
+    )
 
 
 def _getdtype(dtype_str: str) -> torch.dtype:
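Note: the unsigned integer dtypes move from commented-out entries to a runtime gate on the installed torch version. A sketch of what the gate implies, checking the private `_TYPES` mapping shown above (the `save_file` call is a hedged example, not taken from the diff):

```python
# With torch > 2.0.0 the unsigned dtypes are registered; otherwise the
# "U16"/"U32"/"U64" keys are simply absent from the mapping.
import torch
from packaging.version import Version
from safetensors.torch import _TYPES, save_file

if Version(torch.__version__) > Version("2.0.0"):
    assert _TYPES["U16"] is torch.uint16 and _TYPES["U32"] is torch.uint32
    # Hedged usage: a uint16 tensor can now be serialized (stored as "U16").
    save_file({"ids": torch.tensor([1, 2, 3], dtype=torch.uint16)}, "u16.safetensors")
else:
    assert "U16" not in _TYPES and "U32" not in _TYPES
```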
@@ -517,12 +547,16 @@ def _tobytes(tensor: torch.Tensor, name: str) -> bytes:
 
 def _flatten(tensors: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, Any]]:
     if not isinstance(tensors, dict):
-        raise ValueError(f"Expected a dict of [str, torch.Tensor] but received {type(tensors)}")
+        raise ValueError(
+            f"Expected a dict of [str, torch.Tensor] but received {type(tensors)}"
+        )
 
     invalid_tensors = []
     for k, v in tensors.items():
         if not isinstance(v, torch.Tensor):
-            raise ValueError(f"Key `{k}` is invalid, expected torch.Tensor but received {type(v)}")
+            raise ValueError(
+                f"Key `{k}` is invalid, expected torch.Tensor but received {type(v)}"
+            )
 
         if v.layout != torch.strided:
             invalid_tensors.append(k)
safetensors-0.6.1rc0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: safetensors
-Version: 0.6.0rc0
+Version: 0.6.1rc0
 Classifier: Development Status :: 5 - Production/Stable
 Classifier: Intended Audience :: Developers
 Classifier: Intended Audience :: Education
@@ -28,10 +28,7 @@ Requires-Dist: jaxlib>=0.3.25 ; extra == 'jax'
 Requires-Dist: mlx>=0.0.9 ; extra == 'mlx'
 Requires-Dist: safetensors[numpy] ; extra == 'paddlepaddle'
 Requires-Dist: paddlepaddle>=2.4.1 ; extra == 'paddlepaddle'
-Requires-Dist: black==22.3 ; extra == 'quality'
-Requires-Dist: click==8.0.4 ; extra == 'quality'
-Requires-Dist: isort>=5.5.4 ; extra == 'quality'
-Requires-Dist: flake8>=3.8.3 ; extra == 'quality'
+Requires-Dist: ruff ; extra == 'quality'
 Requires-Dist: safetensors[numpy] ; extra == 'testing'
 Requires-Dist: h5py>=3.7.0 ; extra == 'testing'
 Requires-Dist: huggingface-hub>=0.12.1 ; extra == 'testing'
@@ -39,6 +36,12 @@ Requires-Dist: setuptools-rust>=1.5.2 ; extra == 'testing'
 Requires-Dist: pytest>=7.2.0 ; extra == 'testing'
 Requires-Dist: pytest-benchmark>=4.0.0 ; extra == 'testing'
 Requires-Dist: hypothesis>=6.70.2 ; extra == 'testing'
+Requires-Dist: safetensors[numpy] ; extra == 'testingfree'
+Requires-Dist: huggingface-hub>=0.12.1 ; extra == 'testingfree'
+Requires-Dist: setuptools-rust>=1.5.2 ; extra == 'testingfree'
+Requires-Dist: pytest>=7.2.0 ; extra == 'testingfree'
+Requires-Dist: pytest-benchmark>=4.0.0 ; extra == 'testingfree'
+Requires-Dist: hypothesis>=6.70.2 ; extra == 'testingfree'
 Requires-Dist: safetensors[torch] ; extra == 'all'
 Requires-Dist: safetensors[numpy] ; extra == 'all'
 Requires-Dist: safetensors[pinned-tf] ; extra == 'all'
@@ -56,6 +59,7 @@ Provides-Extra: mlx
 Provides-Extra: paddlepaddle
 Provides-Extra: quality
 Provides-Extra: testing
+Provides-Extra: testingfree
 Provides-Extra: all
 Provides-Extra: dev
 License-File: LICENSE
safetensors-0.6.1rc0.dist-info/RECORD ADDED
@@ -0,0 +1,14 @@
+safetensors-0.6.1rc0.dist-info/METADATA,sha256=b9UhFet-V41dU1DjVtpEBwxjqd4KdYo1TXZI1aTCeDI,4144
+safetensors-0.6.1rc0.dist-info/WHEEL,sha256=f_z1-UCPQDLldsOrrwsrPzP3Lu45fZVQJi3w5k9L9Kw,90
+safetensors-0.6.1rc0.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+safetensors/__init__.py,sha256=HYY5VVsb3b-cxhZBwhNx53ZKqSIB4M14nIXLTOAM1Wc,204
+safetensors/__init__.pyi,sha256=tnVaPqYbh8ggFbOZdYKUC4ArqitiWDfrIQt1BNJ377k,4183
+safetensors/_safetensors_rust.pyd,sha256=Ka0Sw1QC34OT7VG3HnygBtlgL76WcpGmATrh9PHdmxg,611328
+safetensors/flax.py,sha256=SnuiGojmth0eCFIWoKEvAfh95nZP9uCZ9E-S4NndrbU,3991
+safetensors/mlx.py,sha256=KvfTWusLSx1hSPWQgg99iL-z9VoD6zQ8l4-RAsCe7P8,3990
+safetensors/numpy.py,sha256=MaUhU4V3J4nDjLmoy0OdeTc6JC8Dq2PBHmnxjQU2bfQ,5189
+safetensors/paddle.py,sha256=B8TLF5MFeqeipUxynSOB_NqPNypL_dE1C3vihtEWj0A,4337
+safetensors/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+safetensors/tensorflow.py,sha256=DajI3qkz00Zy2h7jublSAvTaD51QOPdaIgKQIeSiCRs,4042
+safetensors/torch.py,sha256=vpMhMmierfW5Os6BLlwYgYbb7akoPTBiWPD8fW6nwSE,20849
+safetensors-0.6.1rc0.dist-info/RECORD,,
safetensors-0.6.1rc0.dist-info/WHEEL CHANGED
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: maturin (1.8.7)
+Generator: maturin (1.9.3)
 Root-Is-Purelib: false
 Tag: cp38-abi3-win32
safetensors-0.6.0rc0.dist-info/RECORD DELETED
@@ -1,14 +0,0 @@
-safetensors-0.6.0rc0.dist-info/METADATA,sha256=JOcXGydzbxQyLCKSV_L6kQ3DYJdXoKD5R1dJZJRzdCM,3908
-safetensors-0.6.0rc0.dist-info/WHEEL,sha256=SDCbBFz5TSmn0QRHfGLgEdqCqxx9FJdkEZPReKjVInM,90
-safetensors-0.6.0rc0.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-safetensors/__init__.py,sha256=xxEn5gc4JWHjA0wnafb0Wwwq1m2QXwlbm5nsItGRqO0,180
-safetensors/__init__.pyi,sha256=ItiQgMqqoYAuwZH8EWkXh3ApUwGkDTRiLkDoMBIMR3s,3918
-safetensors/_safetensors_rust.pyd,sha256=J8rqSSSTCpRV2DOOg3Peya6Wgu-dvvHg4gGg7Zaf5oA,603648
-safetensors/flax.py,sha256=SnuiGojmth0eCFIWoKEvAfh95nZP9uCZ9E-S4NndrbU,3991
-safetensors/mlx.py,sha256=Bcb3g1LiEbdY_h3r2v_USXgFzdQdwMAI-dlSY4cBGNY,3982
-safetensors/numpy.py,sha256=JKxjDTfr3Z0bqri-WO7LxSwstY58DUP4cAFOcGuoDHY,5120
-safetensors/paddle.py,sha256=ZKyerzx1bwVb-hB10sKQnpSI0z_U-iDknmBEqxVzFRI,4313
-safetensors/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-safetensors/tensorflow.py,sha256=zcg2HyMfSwb3XHtsSQsFm3VgwCoIeRgs73xlPA72Jzc,4034
-safetensors/torch.py,sha256=MhfQdczycyNM0AMmuDNkSwdBkvYbCFPQPfkp-5fWVVc,20315
-safetensors-0.6.0rc0.dist-info/RECORD,,