safetensors 0.5.1__tar.gz → 0.5.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of safetensors might be problematic; see the notice linked from the original diff page for more details.

Files changed (53)
  1. {safetensors-0.5.1 → safetensors-0.5.3}/PKG-INFO +20 -20
  2. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/Cargo.lock +33 -33
  3. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/Cargo.toml +2 -1
  4. {safetensors-0.5.1 → safetensors-0.5.3/bindings/python}/py_src/safetensors/__init__.pyi +1 -1
  5. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/src/lib.rs +11 -3
  6. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/stub.py +7 -19
  7. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/tests/test_pt_comparison.py +18 -0
  8. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/tests/test_simple.py +15 -1
  9. {safetensors-0.5.1/bindings/python → safetensors-0.5.3}/py_src/safetensors/__init__.pyi +1 -1
  10. {safetensors-0.5.1 → safetensors-0.5.3}/safetensors/Cargo.toml +10 -3
  11. safetensors-0.5.3/safetensors/src/lib.rs +43 -0
  12. {safetensors-0.5.1 → safetensors-0.5.3}/safetensors/src/slice.rs +64 -7
  13. {safetensors-0.5.1 → safetensors-0.5.3}/safetensors/src/tensor.rs +29 -18
  14. safetensors-0.5.1/safetensors/src/lib.rs +0 -5
  15. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/.gitignore +0 -0
  16. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/MANIFEST.in +0 -0
  17. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/Makefile +0 -0
  18. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/README.md +0 -0
  19. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/benches/test_flax.py +0 -0
  20. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/benches/test_mlx.py +0 -0
  21. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/benches/test_paddle.py +0 -0
  22. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/benches/test_pt.py +0 -0
  23. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/benches/test_tf.py +0 -0
  24. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/convert.py +0 -0
  25. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/convert_all.py +0 -0
  26. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/fuzz.py +0 -0
  27. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/py_src/safetensors/__init__.py +0 -0
  28. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/py_src/safetensors/flax.py +0 -0
  29. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/py_src/safetensors/mlx.py +0 -0
  30. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/py_src/safetensors/numpy.py +0 -0
  31. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/py_src/safetensors/paddle.py +0 -0
  32. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/py_src/safetensors/py.typed +0 -0
  33. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/py_src/safetensors/tensorflow.py +0 -0
  34. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/py_src/safetensors/torch.py +0 -0
  35. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/setup.cfg +0 -0
  36. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/tests/data/__init__.py +0 -0
  37. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/tests/test_flax_comparison.py +0 -0
  38. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/tests/test_mlx_comparison.py +0 -0
  39. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/tests/test_paddle_comparison.py +0 -0
  40. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/tests/test_pt_model.py +0 -0
  41. {safetensors-0.5.1 → safetensors-0.5.3}/bindings/python/tests/test_tf_comparison.py +0 -0
  42. {safetensors-0.5.1 → safetensors-0.5.3}/py_src/safetensors/__init__.py +0 -0
  43. {safetensors-0.5.1 → safetensors-0.5.3}/py_src/safetensors/flax.py +0 -0
  44. {safetensors-0.5.1 → safetensors-0.5.3}/py_src/safetensors/mlx.py +0 -0
  45. {safetensors-0.5.1 → safetensors-0.5.3}/py_src/safetensors/numpy.py +0 -0
  46. {safetensors-0.5.1 → safetensors-0.5.3}/py_src/safetensors/paddle.py +0 -0
  47. {safetensors-0.5.1 → safetensors-0.5.3}/py_src/safetensors/py.typed +0 -0
  48. {safetensors-0.5.1 → safetensors-0.5.3}/py_src/safetensors/tensorflow.py +0 -0
  49. {safetensors-0.5.1 → safetensors-0.5.3}/py_src/safetensors/torch.py +0 -0
  50. {safetensors-0.5.1 → safetensors-0.5.3}/pyproject.toml +0 -0
  51. {safetensors-0.5.1 → safetensors-0.5.3}/safetensors/LICENSE +0 -0
  52. {safetensors-0.5.1 → safetensors-0.5.3}/safetensors/README.md +0 -0
  53. {safetensors-0.5.1 → safetensors-0.5.3}/safetensors/benches/benchmark.rs +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: safetensors
3
- Version: 0.5.1
3
+ Version: 0.5.3
4
4
  Classifier: Development Status :: 5 - Production/Stable
5
5
  Classifier: Intended Audience :: Developers
6
6
  Classifier: Intended Audience :: Education
@@ -14,31 +14,31 @@ Classifier: Programming Language :: Python :: 3.9
14
14
  Classifier: Programming Language :: Python :: 3.10
15
15
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
16
  Classifier: Typing :: Typed
17
- Requires-Dist: numpy >=1.21.6 ; extra == 'numpy'
17
+ Requires-Dist: numpy>=1.21.6 ; extra == 'numpy'
18
18
  Requires-Dist: safetensors[numpy] ; extra == 'torch'
19
- Requires-Dist: torch >=1.10 ; extra == 'torch'
19
+ Requires-Dist: torch>=1.10 ; extra == 'torch'
20
20
  Requires-Dist: safetensors[numpy] ; extra == 'tensorflow'
21
- Requires-Dist: tensorflow >=2.11.0 ; extra == 'tensorflow'
21
+ Requires-Dist: tensorflow>=2.11.0 ; extra == 'tensorflow'
22
22
  Requires-Dist: safetensors[numpy] ; extra == 'pinned-tf'
23
- Requires-Dist: tensorflow ==2.18.0 ; extra == 'pinned-tf'
23
+ Requires-Dist: tensorflow==2.18.0 ; extra == 'pinned-tf'
24
24
  Requires-Dist: safetensors[numpy] ; extra == 'jax'
25
- Requires-Dist: flax >=0.6.3 ; extra == 'jax'
26
- Requires-Dist: jax >=0.3.25 ; extra == 'jax'
27
- Requires-Dist: jaxlib >=0.3.25 ; extra == 'jax'
28
- Requires-Dist: mlx >=0.0.9 ; extra == 'mlx'
25
+ Requires-Dist: flax>=0.6.3 ; extra == 'jax'
26
+ Requires-Dist: jax>=0.3.25 ; extra == 'jax'
27
+ Requires-Dist: jaxlib>=0.3.25 ; extra == 'jax'
28
+ Requires-Dist: mlx>=0.0.9 ; extra == 'mlx'
29
29
  Requires-Dist: safetensors[numpy] ; extra == 'paddlepaddle'
30
- Requires-Dist: paddlepaddle >=2.4.1 ; extra == 'paddlepaddle'
31
- Requires-Dist: black ==22.3 ; extra == 'quality'
32
- Requires-Dist: click ==8.0.4 ; extra == 'quality'
33
- Requires-Dist: isort >=5.5.4 ; extra == 'quality'
34
- Requires-Dist: flake8 >=3.8.3 ; extra == 'quality'
30
+ Requires-Dist: paddlepaddle>=2.4.1 ; extra == 'paddlepaddle'
31
+ Requires-Dist: black==22.3 ; extra == 'quality'
32
+ Requires-Dist: click==8.0.4 ; extra == 'quality'
33
+ Requires-Dist: isort>=5.5.4 ; extra == 'quality'
34
+ Requires-Dist: flake8>=3.8.3 ; extra == 'quality'
35
35
  Requires-Dist: safetensors[numpy] ; extra == 'testing'
36
- Requires-Dist: h5py >=3.7.0 ; extra == 'testing'
37
- Requires-Dist: huggingface-hub >=0.12.1 ; extra == 'testing'
38
- Requires-Dist: setuptools-rust >=1.5.2 ; extra == 'testing'
39
- Requires-Dist: pytest >=7.2.0 ; extra == 'testing'
40
- Requires-Dist: pytest-benchmark >=4.0.0 ; extra == 'testing'
41
- Requires-Dist: hypothesis >=6.70.2 ; extra == 'testing'
36
+ Requires-Dist: h5py>=3.7.0 ; extra == 'testing'
37
+ Requires-Dist: huggingface-hub>=0.12.1 ; extra == 'testing'
38
+ Requires-Dist: setuptools-rust>=1.5.2 ; extra == 'testing'
39
+ Requires-Dist: pytest>=7.2.0 ; extra == 'testing'
40
+ Requires-Dist: pytest-benchmark>=4.0.0 ; extra == 'testing'
41
+ Requires-Dist: hypothesis>=6.70.2 ; extra == 'testing'
42
42
  Requires-Dist: safetensors[torch] ; extra == 'all'
43
43
  Requires-Dist: safetensors[numpy] ; extra == 'all'
44
44
  Requires-Dist: safetensors[pinned-tf] ; extra == 'all'
@@ -1,6 +1,6 @@
1
1
  # This file is automatically @generated by Cargo.
2
2
  # It is not intended for manual editing.
3
- version = 4
3
+ version = 3
4
4
 
5
5
  [[package]]
6
6
  name = "autocfg"
@@ -34,9 +34,9 @@ checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
34
34
 
35
35
  [[package]]
36
36
  name = "libc"
37
- version = "0.2.169"
37
+ version = "0.2.170"
38
38
  source = "registry+https://github.com/rust-lang/crates.io-index"
39
- checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
39
+ checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828"
40
40
 
41
41
  [[package]]
42
42
  name = "memchr"
@@ -64,30 +64,30 @@ dependencies = [
64
64
 
65
65
  [[package]]
66
66
  name = "once_cell"
67
- version = "1.20.2"
67
+ version = "1.20.3"
68
68
  source = "registry+https://github.com/rust-lang/crates.io-index"
69
- checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
69
+ checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e"
70
70
 
71
71
  [[package]]
72
72
  name = "portable-atomic"
73
- version = "1.10.0"
73
+ version = "1.11.0"
74
74
  source = "registry+https://github.com/rust-lang/crates.io-index"
75
- checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6"
75
+ checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e"
76
76
 
77
77
  [[package]]
78
78
  name = "proc-macro2"
79
- version = "1.0.92"
79
+ version = "1.0.93"
80
80
  source = "registry+https://github.com/rust-lang/crates.io-index"
81
- checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0"
81
+ checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99"
82
82
  dependencies = [
83
83
  "unicode-ident",
84
84
  ]
85
85
 
86
86
  [[package]]
87
87
  name = "pyo3"
88
- version = "0.23.3"
88
+ version = "0.23.5"
89
89
  source = "registry+https://github.com/rust-lang/crates.io-index"
90
- checksum = "e484fd2c8b4cb67ab05a318f1fd6fa8f199fcc30819f08f07d200809dba26c15"
90
+ checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
91
91
  dependencies = [
92
92
  "cfg-if",
93
93
  "indoc",
@@ -103,9 +103,9 @@ dependencies = [
103
103
 
104
104
  [[package]]
105
105
  name = "pyo3-build-config"
106
- version = "0.23.3"
106
+ version = "0.23.5"
107
107
  source = "registry+https://github.com/rust-lang/crates.io-index"
108
- checksum = "dc0e0469a84f208e20044b98965e1561028180219e35352a2afaf2b942beff3b"
108
+ checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
109
109
  dependencies = [
110
110
  "once_cell",
111
111
  "target-lexicon",
@@ -113,9 +113,9 @@ dependencies = [
113
113
 
114
114
  [[package]]
115
115
  name = "pyo3-ffi"
116
- version = "0.23.3"
116
+ version = "0.23.5"
117
117
  source = "registry+https://github.com/rust-lang/crates.io-index"
118
- checksum = "eb1547a7f9966f6f1a0f0227564a9945fe36b90da5a93b3933fc3dc03fae372d"
118
+ checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
119
119
  dependencies = [
120
120
  "libc",
121
121
  "pyo3-build-config",
@@ -123,9 +123,9 @@ dependencies = [
123
123
 
124
124
  [[package]]
125
125
  name = "pyo3-macros"
126
- version = "0.23.3"
126
+ version = "0.23.5"
127
127
  source = "registry+https://github.com/rust-lang/crates.io-index"
128
- checksum = "fdb6da8ec6fa5cedd1626c886fc8749bdcbb09424a86461eb8cdf096b7c33257"
128
+ checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
129
129
  dependencies = [
130
130
  "proc-macro2",
131
131
  "pyo3-macros-backend",
@@ -135,9 +135,9 @@ dependencies = [
135
135
 
136
136
  [[package]]
137
137
  name = "pyo3-macros-backend"
138
- version = "0.23.3"
138
+ version = "0.23.5"
139
139
  source = "registry+https://github.com/rust-lang/crates.io-index"
140
- checksum = "38a385202ff5a92791168b1136afae5059d3ac118457bb7bc304c197c2d33e7d"
140
+ checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
141
141
  dependencies = [
142
142
  "heck",
143
143
  "proc-macro2",
@@ -157,13 +157,13 @@ dependencies = [
157
157
 
158
158
  [[package]]
159
159
  name = "ryu"
160
- version = "1.0.18"
160
+ version = "1.0.19"
161
161
  source = "registry+https://github.com/rust-lang/crates.io-index"
162
- checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
162
+ checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
163
163
 
164
164
  [[package]]
165
165
  name = "safetensors"
166
- version = "0.5.1"
166
+ version = "0.5.3"
167
167
  dependencies = [
168
168
  "serde",
169
169
  "serde_json",
@@ -171,7 +171,7 @@ dependencies = [
171
171
 
172
172
  [[package]]
173
173
  name = "safetensors-python"
174
- version = "0.5.1"
174
+ version = "0.5.3"
175
175
  dependencies = [
176
176
  "memmap2",
177
177
  "pyo3",
@@ -181,18 +181,18 @@ dependencies = [
181
181
 
182
182
  [[package]]
183
183
  name = "serde"
184
- version = "1.0.217"
184
+ version = "1.0.218"
185
185
  source = "registry+https://github.com/rust-lang/crates.io-index"
186
- checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70"
186
+ checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60"
187
187
  dependencies = [
188
188
  "serde_derive",
189
189
  ]
190
190
 
191
191
  [[package]]
192
192
  name = "serde_derive"
193
- version = "1.0.217"
193
+ version = "1.0.218"
194
194
  source = "registry+https://github.com/rust-lang/crates.io-index"
195
- checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0"
195
+ checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b"
196
196
  dependencies = [
197
197
  "proc-macro2",
198
198
  "quote",
@@ -201,9 +201,9 @@ dependencies = [
201
201
 
202
202
  [[package]]
203
203
  name = "serde_json"
204
- version = "1.0.134"
204
+ version = "1.0.139"
205
205
  source = "registry+https://github.com/rust-lang/crates.io-index"
206
- checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d"
206
+ checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6"
207
207
  dependencies = [
208
208
  "itoa",
209
209
  "memchr",
@@ -213,9 +213,9 @@ dependencies = [
213
213
 
214
214
  [[package]]
215
215
  name = "syn"
216
- version = "2.0.94"
216
+ version = "2.0.98"
217
217
  source = "registry+https://github.com/rust-lang/crates.io-index"
218
- checksum = "987bc0be1cdea8b10216bd06e2ca407d40b9543468fafd3ddfb02f36e77f71f3"
218
+ checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1"
219
219
  dependencies = [
220
220
  "proc-macro2",
221
221
  "quote",
@@ -230,9 +230,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
230
230
 
231
231
  [[package]]
232
232
  name = "unicode-ident"
233
- version = "1.0.14"
233
+ version = "1.0.17"
234
234
  source = "registry+https://github.com/rust-lang/crates.io-index"
235
- checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83"
235
+ checksum = "00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe"
236
236
 
237
237
  [[package]]
238
238
  name = "unindent"
@@ -1,7 +1,8 @@
1
1
  [package]
2
2
  name = "safetensors-python"
3
- version = "0.5.1"
3
+ version = "0.5.3"
4
4
  edition = "2021"
5
+ rust-version = "1.74"
5
6
 
6
7
  # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7
8
  [lib]
@@ -69,7 +69,7 @@ class safe_open:
69
69
  The device on which you want the tensors.
70
70
  """
71
71
 
72
- def __init__(filename, framework, device=...):
72
+ def __init__(self, filename, framework, device=...):
73
73
  pass
74
74
  def __enter__(self):
75
75
  """
@@ -267,7 +267,8 @@ enum Device {
267
267
  Xpu(usize),
268
268
  Xla(usize),
269
269
  Mlu(usize),
270
- /// User didn't specify acceletor, torch
270
+ Hpu(usize),
271
+ /// User didn't specify accelerator, torch
271
272
  /// is responsible for choosing.
272
273
  Anonymous(usize),
273
274
  }
@@ -296,11 +297,13 @@ impl<'source> FromPyObject<'source> for Device {
296
297
  "xpu" => Ok(Device::Xpu(0)),
297
298
  "xla" => Ok(Device::Xla(0)),
298
299
  "mlu" => Ok(Device::Mlu(0)),
300
+ "hpu" => Ok(Device::Hpu(0)),
299
301
  name if name.starts_with("cuda:") => parse_device(name).map(Device::Cuda),
300
302
  name if name.starts_with("npu:") => parse_device(name).map(Device::Npu),
301
303
  name if name.starts_with("xpu:") => parse_device(name).map(Device::Xpu),
302
304
  name if name.starts_with("xla:") => parse_device(name).map(Device::Xla),
303
305
  name if name.starts_with("mlu:") => parse_device(name).map(Device::Mlu),
306
+ name if name.starts_with("hpu:") => parse_device(name).map(Device::Hpu),
304
307
  name => Err(SafetensorError::new_err(format!(
305
308
  "device {name} is invalid"
306
309
  ))),
@@ -327,6 +330,7 @@ impl<'py> IntoPyObject<'py> for Device {
327
330
  Device::Xpu(n) => format!("xpu:{n}").into_pyobject(py).map(|x| x.into_any()),
328
331
  Device::Xla(n) => format!("xla:{n}").into_pyobject(py).map(|x| x.into_any()),
329
332
  Device::Mlu(n) => format!("mlu:{n}").into_pyobject(py).map(|x| x.into_any()),
333
+ Device::Hpu(n) => format!("hpu:{n}").into_pyobject(py).map(|x| x.into_any()),
330
334
  Device::Anonymous(n) => n.into_pyobject(py).map(|x| x.into_any()),
331
335
  }
332
336
  }
@@ -795,8 +799,12 @@ struct Disp(Vec<TensorIndexer>);
795
799
  impl fmt::Display for Disp {
796
800
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
797
801
  write!(f, "[")?;
798
- for item in &self.0 {
799
- write!(f, "{item}")?;
802
+ for (i, item) in self.0.iter().enumerate() {
803
+ if i != self.0.len() - 1 {
804
+ write!(f, "{item}, ")?;
805
+ } else {
806
+ write!(f, "{item}")?;
807
+ }
800
808
  }
801
809
  write!(f, "]")
802
810
  }
@@ -42,10 +42,7 @@ def fn_predicate(obj):
42
42
  return (
43
43
  obj.__doc__
44
44
  and obj.__text_signature__
45
- and (
46
- not obj.__name__.startswith("_")
47
- or obj.__name__ in {"__enter__", "__exit__"}
48
- )
45
+ and (not obj.__name__.startswith("_") or obj.__name__ in {"__enter__", "__exit__"})
49
46
  )
50
47
  if inspect.isgetsetdescriptor(obj):
51
48
  return obj.__doc__ and not obj.__name__.startswith("_")
@@ -81,15 +78,14 @@ def pyi_file(obj, indent=""):
81
78
 
82
79
  body = ""
83
80
  if obj.__doc__:
84
- body += (
85
- f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'
86
- )
81
+ body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'
87
82
 
88
83
  fns = inspect.getmembers(obj, fn_predicate)
89
84
 
90
85
  # Init
91
86
  if obj.__text_signature__:
92
- body += f"{indent}def __init__{obj.__text_signature__}:\n"
87
+ signature = obj.__text_signature__.replace("(", "(self, ")
88
+ body += f"{indent}def __init__{signature}:\n"
93
89
  body += f"{indent+INDENT}pass\n"
94
90
  body += "\n"
95
91
 
@@ -146,11 +142,7 @@ def do_black(content, is_pyi):
146
142
 
147
143
 
148
144
  def write(module, directory, origin, check=False):
149
- submodules = [
150
- (name, member)
151
- for name, member in inspect.getmembers(module)
152
- if inspect.ismodule(member)
153
- ]
145
+ submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)]
154
146
 
155
147
  filename = os.path.join(directory, "__init__.pyi")
156
148
  pyi_content = pyi_file(module)
@@ -159,9 +151,7 @@ def write(module, directory, origin, check=False):
159
151
  if check:
160
152
  with open(filename, "r") as f:
161
153
  data = f.read()
162
- assert (
163
- data == pyi_content
164
- ), f"The content of {filename} seems outdated, please run `python stub.py`"
154
+ assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
165
155
  else:
166
156
  with open(filename, "w") as f:
167
157
  f.write(pyi_content)
@@ -184,9 +174,7 @@ def write(module, directory, origin, check=False):
184
174
  if check:
185
175
  with open(filename, "r") as f:
186
176
  data = f.read()
187
- assert (
188
- data == py_content
189
- ), f"The content of {filename} seems outdated, please run `python stub.py`"
177
+ assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
190
178
  else:
191
179
  with open(filename, "w") as f:
192
180
  f.write(py_content)
@@ -170,6 +170,24 @@ class TorchTestCase(unittest.TestCase):
170
170
  for k, v in reloaded.items():
171
171
  self.assertTrue(torch.allclose(data[k], reloaded[k]))
172
172
 
173
+ def test_hpu(self):
174
+ # must be run to load torch with Intel Gaudi bindings
175
+ try:
176
+ import habana_frameworks.torch.core as htcore
177
+ except ImportError:
178
+ self.skipTest("HPU is not available")
179
+
180
+ data = {
181
+ "test1": torch.zeros((2, 2), dtype=torch.float32).to("hpu"),
182
+ "test2": torch.zeros((2, 2), dtype=torch.float16).to("hpu"),
183
+ }
184
+ local = "./tests/data/out_safe_pt_mmap_small_hpu.safetensors"
185
+ save_file(data, local)
186
+
187
+ reloaded = load_file(local, device="hpu")
188
+ for k, v in reloaded.items():
189
+ self.assertTrue(torch.allclose(data[k], reloaded[k]))
190
+
173
191
  @unittest.skipIf(not torch.cuda.is_available(), "Cuda is not available")
174
192
  def test_anonymous_accelerator(self):
175
193
  data = {
@@ -340,5 +340,19 @@ class ReadmeTestCase(unittest.TestCase):
340
340
  tensor = slice_[2:, 20]
341
341
  self.assertEqual(
342
342
  str(cm.exception),
343
- "Error during slicing [2:20] with shape [10, 5]: SliceOutOfRange { dim_index: 1, asked: 20, dim_size: 5 }",
343
+ "Error during slicing [2:, 20] with shape [10, 5]: SliceOutOfRange { dim_index: 1, asked: 20, dim_size: 5 }",
344
+ )
345
+
346
+ with self.assertRaises(SafetensorError) as cm:
347
+ tensor = slice_[:20]
348
+ self.assertEqual(
349
+ str(cm.exception),
350
+ "Error during slicing [:20] with shape [10, 5]: SliceOutOfRange { dim_index: 0, asked: 19, dim_size: 10 }",
351
+ )
352
+
353
+ with self.assertRaises(SafetensorError) as cm:
354
+ tensor = slice_[:, :20]
355
+ self.assertEqual(
356
+ str(cm.exception),
357
+ "Error during slicing [:, :20] with shape [10, 5]: SliceOutOfRange { dim_index: 1, asked: 19, dim_size: 5 }",
344
358
  )
@@ -69,7 +69,7 @@ class safe_open:
69
69
  The device on which you want the tensors.
70
70
  """
71
71
 
72
- def __init__(filename, framework, device=...):
72
+ def __init__(self, filename, framework, device=...):
73
73
  pass
74
74
  def __enter__(self):
75
75
  """
@@ -1,7 +1,8 @@
1
1
  [package]
2
2
  name = "safetensors"
3
- version = "0.5.1"
3
+ version = "0.5.3"
4
4
  edition = "2021"
5
+ rust-version = "1.74"
5
6
  homepage = "https://github.com/huggingface/safetensors"
6
7
  repository = "https://github.com/huggingface/safetensors"
7
8
  documentation = "https://docs.rs/safetensors/"
@@ -21,14 +22,20 @@ exclude = [ "rust-toolchain", "target/*", "Cargo.lock"]
21
22
  # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
22
23
 
23
24
  [dependencies]
24
- serde = {version = "1.0", features = ["derive"]}
25
- serde_json = "1.0"
25
+ hashbrown = { version = "0.15.2", features = ["serde"], optional = true }
26
+ serde = { version = "1.0", default-features = false, features = ["derive"] }
27
+ serde_json = { version = "1.0", default-features = false }
26
28
 
27
29
  [dev-dependencies]
28
30
  criterion = "0.5"
29
31
  memmap2 = "0.9"
30
32
  proptest = "1.4"
31
33
 
34
+ [features]
35
+ default = ["std"]
36
+ std = ["serde/default", "serde_json/default"]
37
+ alloc = ["serde/alloc", "serde_json/alloc", "hashbrown"]
38
+
32
39
  [[bench]]
33
40
  name = "benchmark"
34
41
  harness = false
@@ -0,0 +1,43 @@
1
+ #![deny(missing_docs)]
2
+ #![doc = include_str!("../README.md")]
3
+ #![cfg_attr(not(feature = "std"), no_std)]
4
+ pub mod slice;
5
+ pub mod tensor;
6
+ /// serialize_to_file only valid in std
7
+ #[cfg(feature = "std")]
8
+ pub use tensor::serialize_to_file;
9
+ pub use tensor::{serialize, Dtype, SafeTensorError, SafeTensors, View};
10
+
11
+ #[cfg(feature = "alloc")]
12
+ #[macro_use]
13
+ extern crate alloc;
14
+
15
+ #[cfg(all(feature = "std", feature = "alloc"))]
16
+ compile_error!("must choose either the `std` or `alloc` feature, but not both.");
17
+ #[cfg(all(not(feature = "std"), not(feature = "alloc")))]
18
+ compile_error!("must choose either the `std` or `alloc` feature");
19
+
20
+ /// A facade around all the types we need from the `std`, `core`, and `alloc`
21
+ /// crates. This avoids elaborate import wrangling having to happen in every
22
+ /// module.
23
+ mod lib {
24
+ #[cfg(not(feature = "std"))]
25
+ mod no_stds {
26
+ pub use alloc::borrow::Cow;
27
+ pub use alloc::string::{String, ToString};
28
+ pub use alloc::vec::Vec;
29
+ pub use hashbrown::HashMap;
30
+ }
31
+ #[cfg(feature = "std")]
32
+ mod stds {
33
+ pub use std::borrow::Cow;
34
+ pub use std::collections::HashMap;
35
+ pub use std::string::{String, ToString};
36
+ pub use std::vec::Vec;
37
+ }
38
+ /// choose std or no_std to export by feature flag
39
+ #[cfg(not(feature = "std"))]
40
+ pub use no_stds::*;
41
+ #[cfg(feature = "std")]
42
+ pub use stds::*;
43
+ }
@@ -1,12 +1,13 @@
1
1
  //! Module handling lazy loading via iterating on slices on the original buffer.
2
+ use crate::lib::{String, ToString, Vec};
2
3
  use crate::tensor::TensorView;
3
- use std::fmt;
4
- use std::ops::{
4
+ use core::ops::{
5
5
  Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
6
6
  };
7
7
 
8
8
  /// Error representing invalid slicing attempt
9
9
  #[derive(Debug)]
10
+ #[cfg_attr(test, derive(Eq, PartialEq))]
10
11
  pub enum InvalidSlice {
11
12
  /// When the client asked for more slices than the tensors has dimensions
12
13
  TooManySlices,
@@ -40,8 +41,8 @@ fn display_bound(bound: &Bound<usize>) -> String {
40
41
  }
41
42
 
42
43
  /// Intended for Python users mostly or at least for its conventions
43
- impl fmt::Display for TensorIndexer {
44
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44
+ impl core::fmt::Display for TensorIndexer {
45
+ fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
45
46
  match self {
46
47
  TensorIndexer::Select(n) => {
47
48
  write!(f, "{n}")
@@ -77,7 +78,7 @@ macro_rules! impl_from_range {
77
78
  ($range_type:ty) => {
78
79
  impl From<$range_type> for TensorIndexer {
79
80
  fn from(range: $range_type) -> Self {
80
- use std::ops::Bound::*;
81
+ use core::ops::Bound::*;
81
82
 
82
83
  let start = match range.start_bound() {
83
84
  Included(idx) => Included(*idx),
@@ -235,6 +236,7 @@ where
235
236
 
236
237
  /// Iterator used to return the bits of the overall tensor buffer
237
238
  /// when client asks for a slice of the original tensor.
239
+ #[cfg_attr(test, derive(Debug, Eq, PartialEq))]
238
240
  pub struct SliceIterator<'data> {
239
241
  view: &'data TensorView<'data>,
240
242
  indices: Vec<(usize, usize)>,
@@ -284,10 +286,15 @@ impl<'data> SliceIterator<'data> {
284
286
  }
285
287
  TensorIndexer::Select(s) => (*s, *s + 1),
286
288
  };
287
- if start >= shape && stop > shape {
289
+ if start >= shape || stop > shape {
290
+ let asked = if start >= shape {
291
+ start
292
+ } else {
293
+ stop.saturating_sub(1)
294
+ };
288
295
  return Err(InvalidSlice::SliceOutOfRange {
289
296
  dim_index: i,
290
- asked: stop.saturating_sub(1),
297
+ asked,
291
298
  dim_size: shape,
292
299
  });
293
300
  }
@@ -573,4 +580,54 @@ mod tests {
573
580
  assert_eq!(iterator.next(), Some(&data[12..16]));
574
581
  assert_eq!(iterator.next(), None);
575
582
  }
583
+
584
+ #[test]
585
+ fn test_invalid_range() {
586
+ let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
587
+ .into_iter()
588
+ .flat_map(|f| f.to_le_bytes())
589
+ .collect();
590
+
591
+ let attn_0 = TensorView::new(Dtype::F32, vec![2, 3], &data).unwrap();
592
+
593
+ assert_eq!(
594
+ SliceIterator::new(
595
+ &attn_0,
596
+ &[
597
+ TensorIndexer::Select(1),
598
+ TensorIndexer::Narrow(Bound::Included(1), Bound::Excluded(4)),
599
+ ],
600
+ ),
601
+ Err(InvalidSlice::SliceOutOfRange {
602
+ asked: 3,
603
+ dim_index: 1,
604
+ dim_size: 3,
605
+ })
606
+ );
607
+ assert_eq!(
608
+ SliceIterator::new(
609
+ &attn_0,
610
+ &[
611
+ TensorIndexer::Select(1),
612
+ TensorIndexer::Narrow(Bound::Included(3), Bound::Excluded(2)),
613
+ ],
614
+ ),
615
+ Err(InvalidSlice::SliceOutOfRange {
616
+ asked: 3,
617
+ dim_index: 1,
618
+ dim_size: 3,
619
+ })
620
+ );
621
+ assert_eq!(
622
+ SliceIterator::new(
623
+ &attn_0,
624
+ &[
625
+ TensorIndexer::Select(1),
626
+ TensorIndexer::Select(1),
627
+ TensorIndexer::Select(1),
628
+ ],
629
+ ),
630
+ Err(InvalidSlice::TooManySlices)
631
+ );
632
+ }
576
633
  }
@@ -1,11 +1,9 @@
1
1
  //! Module Containing the most important structures
2
+ use crate::lib::{Cow, HashMap, String, ToString, Vec};
2
3
  use crate::slice::{InvalidSlice, SliceIterator, TensorIndexer};
3
4
  use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
4
- use std::borrow::Cow;
5
- use std::collections::HashMap;
6
- use std::fs::File;
7
- use std::io::{BufWriter, Write};
8
- use std::path::Path;
5
+ #[cfg(feature = "std")]
6
+ use std::io::Write;
9
7
 
10
8
  const MAX_HEADER_SIZE: usize = 100_000_000;
11
9
 
@@ -32,6 +30,7 @@ pub enum SafeTensorError {
32
30
  /// The offsets declared for tensor with name `String` in the header are invalid
33
31
  InvalidOffset(String),
34
32
  /// IoError
33
+ #[cfg(feature = "std")]
35
34
  IoError(std::io::Error),
36
35
  /// JSON error
37
36
  JsonError(serde_json::Error),
@@ -46,6 +45,7 @@ pub enum SafeTensorError {
46
45
  ValidationOverflow,
47
46
  }
48
47
 
48
+ #[cfg(feature = "std")]
49
49
  impl From<std::io::Error> for SafeTensorError {
50
50
  fn from(error: std::io::Error) -> SafeTensorError {
51
51
  SafeTensorError::IoError(error)
@@ -58,12 +58,16 @@ impl From<serde_json::Error> for SafeTensorError {
58
58
  }
59
59
  }
60
60
 
61
- impl std::fmt::Display for SafeTensorError {
62
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61
+ impl core::fmt::Display for SafeTensorError {
62
+ fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
63
63
  write!(f, "{self:?}")
64
64
  }
65
65
  }
66
66
 
67
+ #[cfg(not(feature = "std"))]
68
+ impl core::error::Error for SafeTensorError {}
69
+
70
+ #[cfg(feature = "std")]
67
71
  impl std::error::Error for SafeTensorError {}
68
72
 
69
73
  struct PreparedData {
@@ -164,7 +168,7 @@ pub trait View {
164
168
  fn data_len(&self) -> usize;
165
169
  }
166
170
 
167
- fn prepare<S: AsRef<str> + Ord + std::fmt::Display, V: View, I: IntoIterator<Item = (S, V)>>(
171
+ fn prepare<S: AsRef<str> + Ord + core::fmt::Display, V: View, I: IntoIterator<Item = (S, V)>>(
168
172
  data: I,
169
173
  data_info: &Option<HashMap<String, String>>,
170
174
  // ) -> Result<(Metadata, Vec<&'hash TensorView<'data>>, usize), SafeTensorError> {
@@ -212,7 +216,7 @@ fn prepare<S: AsRef<str> + Ord + std::fmt::Display, V: View, I: IntoIterator<Ite
212
216
 
213
217
  /// Serialize to an owned byte buffer the dictionnary of tensors.
214
218
  pub fn serialize<
215
- S: AsRef<str> + Ord + std::fmt::Display,
219
+ S: AsRef<str> + Ord + core::fmt::Display,
216
220
  V: View,
217
221
  I: IntoIterator<Item = (S, V)>,
218
222
  >(
@@ -240,14 +244,15 @@ pub fn serialize<
240
244
  /// Serialize to a regular file the dictionnary of tensors.
241
245
  /// Writing directly to file reduces the need to allocate the whole amount to
242
246
  /// memory.
247
+ #[cfg(feature = "std")]
243
248
  pub fn serialize_to_file<
244
- S: AsRef<str> + Ord + std::fmt::Display,
249
+ S: AsRef<str> + Ord + core::fmt::Display,
245
250
  V: View,
246
251
  I: IntoIterator<Item = (S, V)>,
247
252
  >(
248
253
  data: I,
249
254
  data_info: &Option<HashMap<String, String>>,
250
- filename: &Path,
255
+ filename: &std::path::Path,
251
256
  ) -> Result<(), SafeTensorError> {
252
257
  let (
253
258
  PreparedData {
@@ -255,7 +260,7 @@ pub fn serialize_to_file<
255
260
  },
256
261
  tensors,
257
262
  ) = prepare(data, data_info)?;
258
- let mut f = BufWriter::new(File::create(filename)?);
263
+ let mut f = std::io::BufWriter::new(std::fs::File::create(filename)?);
259
264
  f.write_all(n.to_le_bytes().as_ref())?;
260
265
  f.write_all(&header_bytes)?;
261
266
  for tensor in tensors {
@@ -303,7 +308,7 @@ impl<'data> SafeTensors<'data> {
303
308
  return Err(SafeTensorError::InvalidHeaderLength);
304
309
  }
305
310
  let string =
306
- std::str::from_utf8(&buffer[8..stop]).map_err(|_| SafeTensorError::InvalidHeader)?;
311
+ core::str::from_utf8(&buffer[8..stop]).map_err(|_| SafeTensorError::InvalidHeader)?;
307
312
  // Assert the string starts with {
308
313
  // NOTE: Add when we move to 0.4.0
309
314
  // if !string.starts_with('{') {
@@ -719,6 +724,9 @@ mod tests {
719
724
  use super::*;
720
725
  use crate::slice::IndexOp;
721
726
  use proptest::prelude::*;
727
+ #[cfg(not(feature = "std"))]
728
+ extern crate std;
729
+ use std::io::Write;
722
730
 
723
731
  const MAX_DIMENSION: usize = 8;
724
732
  const MAX_SIZE: usize = 8;
@@ -1021,10 +1029,13 @@ mod tests {
1021
1029
  std::fs::remove_file(&filename).unwrap();
1022
1030
 
1023
1031
  // File api
1024
- serialize_to_file(&metadata, &None, Path::new(&filename)).unwrap();
1025
- let raw = std::fs::read(&filename).unwrap();
1026
- let _deserialized = SafeTensors::deserialize(&raw).unwrap();
1027
- std::fs::remove_file(&filename).unwrap();
1032
+ #[cfg(feature = "std")]
1033
+ {
1034
+ serialize_to_file(&metadata, &None, std::path::Path::new(&filename)).unwrap();
1035
+ let raw = std::fs::read(&filename).unwrap();
1036
+ let _deserialized = SafeTensors::deserialize(&raw).unwrap();
1037
+ std::fs::remove_file(&filename).unwrap();
1038
+ }
1028
1039
  }
1029
1040
 
1030
1041
  #[test]
@@ -1097,7 +1108,7 @@ mod tests {
1097
1108
  let n = serialized.len();
1098
1109
 
1099
1110
  let filename = "out.safetensors";
1100
- let mut f = BufWriter::new(File::create(filename).unwrap());
1111
+ let mut f = std::io::BufWriter::new(std::fs::File::create(filename).unwrap());
1101
1112
  f.write_all(n.to_le_bytes().as_ref()).unwrap();
1102
1113
  f.write_all(serialized).unwrap();
1103
1114
  f.write_all(b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0").unwrap();
@@ -1,5 +0,0 @@
1
- #![deny(missing_docs)]
2
- #![doc = include_str!("../README.md")]
3
- pub mod slice;
4
- pub mod tensor;
5
- pub use tensor::{serialize, serialize_to_file, Dtype, SafeTensorError, SafeTensors, View};
File without changes