safetensors 0.6.0rc0__tar.gz → 0.6.1rc0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of safetensors might be problematic. Click here for more details.

Files changed (57) hide show
  1. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/PKG-INFO +9 -5
  2. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/Cargo.lock +13 -13
  3. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/Cargo.toml +2 -1
  4. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/benches/test_pt.py +8 -2
  5. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/convert.py +111 -28
  6. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/convert_all.py +7 -2
  7. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/py_src/safetensors/__init__.py +1 -0
  8. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/py_src/safetensors/__init__.pyi +17 -2
  9. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/py_src/safetensors/mlx.py +3 -1
  10. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/py_src/safetensors/numpy.py +14 -4
  11. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0/bindings/python}/py_src/safetensors/paddle.py +9 -3
  12. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0/bindings/python}/py_src/safetensors/tensorflow.py +3 -1
  13. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0/bindings/python}/py_src/safetensors/torch.py +46 -12
  14. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/src/lib.rs +126 -5
  15. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/stub.py +32 -24
  16. safetensors-0.6.1rc0/bindings/python/tests/test_handle.py +65 -0
  17. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/tests/test_mlx_comparison.py +4 -4
  18. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/tests/test_pt_comparison.py +6 -1
  19. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/tests/test_pt_model.py +27 -9
  20. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/tests/test_simple.py +10 -6
  21. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/uv.lock +78 -149
  22. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/py_src/safetensors/__init__.py +1 -0
  23. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/py_src/safetensors/__init__.pyi +17 -2
  24. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/py_src/safetensors/mlx.py +3 -1
  25. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/py_src/safetensors/numpy.py +14 -4
  26. {safetensors-0.6.0rc0/bindings/python → safetensors-0.6.1rc0}/py_src/safetensors/paddle.py +9 -3
  27. {safetensors-0.6.0rc0/bindings/python → safetensors-0.6.1rc0}/py_src/safetensors/tensorflow.py +3 -1
  28. {safetensors-0.6.0rc0/bindings/python → safetensors-0.6.1rc0}/py_src/safetensors/torch.py +46 -12
  29. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/pyproject.toml +10 -4
  30. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/safetensors/Cargo.toml +1 -1
  31. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/LICENSE +0 -0
  32. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/.gitignore +0 -0
  33. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/LICENSE +0 -0
  34. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/MANIFEST.in +0 -0
  35. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/Makefile +0 -0
  36. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/README.md +0 -0
  37. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/benches/test_flax.py +0 -0
  38. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/benches/test_mlx.py +0 -0
  39. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/benches/test_paddle.py +0 -0
  40. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/benches/test_tf.py +0 -0
  41. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/fuzz.py +0 -0
  42. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/py_src/safetensors/flax.py +0 -0
  43. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/py_src/safetensors/py.typed +0 -0
  44. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/setup.cfg +0 -0
  45. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/src/view.rs +0 -0
  46. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/tests/data/__init__.py +0 -0
  47. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/tests/test_flax_comparison.py +0 -0
  48. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/tests/test_paddle_comparison.py +0 -0
  49. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/bindings/python/tests/test_tf_comparison.py +0 -0
  50. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/py_src/safetensors/flax.py +0 -0
  51. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/py_src/safetensors/py.typed +0 -0
  52. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/safetensors/LICENSE +0 -0
  53. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/safetensors/README.md +0 -0
  54. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/safetensors/benches/benchmark.rs +0 -0
  55. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/safetensors/src/lib.rs +0 -0
  56. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/safetensors/src/slice.rs +0 -0
  57. {safetensors-0.6.0rc0 → safetensors-0.6.1rc0}/safetensors/src/tensor.rs +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: safetensors
3
- Version: 0.6.0rc0
3
+ Version: 0.6.1rc0
4
4
  Classifier: Development Status :: 5 - Production/Stable
5
5
  Classifier: Intended Audience :: Developers
6
6
  Classifier: Intended Audience :: Education
@@ -28,10 +28,7 @@ Requires-Dist: jaxlib>=0.3.25 ; extra == 'jax'
28
28
  Requires-Dist: mlx>=0.0.9 ; extra == 'mlx'
29
29
  Requires-Dist: safetensors[numpy] ; extra == 'paddlepaddle'
30
30
  Requires-Dist: paddlepaddle>=2.4.1 ; extra == 'paddlepaddle'
31
- Requires-Dist: black==22.3 ; extra == 'quality'
32
- Requires-Dist: click==8.0.4 ; extra == 'quality'
33
- Requires-Dist: isort>=5.5.4 ; extra == 'quality'
34
- Requires-Dist: flake8>=3.8.3 ; extra == 'quality'
31
+ Requires-Dist: ruff ; extra == 'quality'
35
32
  Requires-Dist: safetensors[numpy] ; extra == 'testing'
36
33
  Requires-Dist: h5py>=3.7.0 ; extra == 'testing'
37
34
  Requires-Dist: huggingface-hub>=0.12.1 ; extra == 'testing'
@@ -39,6 +36,12 @@ Requires-Dist: setuptools-rust>=1.5.2 ; extra == 'testing'
39
36
  Requires-Dist: pytest>=7.2.0 ; extra == 'testing'
40
37
  Requires-Dist: pytest-benchmark>=4.0.0 ; extra == 'testing'
41
38
  Requires-Dist: hypothesis>=6.70.2 ; extra == 'testing'
39
+ Requires-Dist: safetensors[numpy] ; extra == 'testingfree'
40
+ Requires-Dist: huggingface-hub>=0.12.1 ; extra == 'testingfree'
41
+ Requires-Dist: setuptools-rust>=1.5.2 ; extra == 'testingfree'
42
+ Requires-Dist: pytest>=7.2.0 ; extra == 'testingfree'
43
+ Requires-Dist: pytest-benchmark>=4.0.0 ; extra == 'testingfree'
44
+ Requires-Dist: hypothesis>=6.70.2 ; extra == 'testingfree'
42
45
  Requires-Dist: safetensors[torch] ; extra == 'all'
43
46
  Requires-Dist: safetensors[numpy] ; extra == 'all'
44
47
  Requires-Dist: safetensors[pinned-tf] ; extra == 'all'
@@ -56,6 +59,7 @@ Provides-Extra: mlx
56
59
  Provides-Extra: paddlepaddle
57
60
  Provides-Extra: quality
58
61
  Provides-Extra: testing
62
+ Provides-Extra: testingfree
59
63
  Provides-Extra: all
60
64
  Provides-Extra: dev
61
65
  License-File: LICENSE
@@ -1,12 +1,12 @@
1
1
  # This file is automatically @generated by Cargo.
2
2
  # It is not intended for manual editing.
3
- version = 3
3
+ version = 4
4
4
 
5
5
  [[package]]
6
6
  name = "autocfg"
7
- version = "1.4.0"
7
+ version = "1.5.0"
8
8
  source = "registry+https://github.com/rust-lang/crates.io-index"
9
- checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
9
+ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
10
10
 
11
11
  [[package]]
12
12
  name = "heck"
@@ -28,9 +28,9 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
28
28
 
29
29
  [[package]]
30
30
  name = "libc"
31
- version = "0.2.173"
31
+ version = "0.2.174"
32
32
  source = "registry+https://github.com/rust-lang/crates.io-index"
33
- checksum = "d8cfeafaffdbc32176b64fb251369d52ea9f0a8fbc6f8759edffef7b525d64bb"
33
+ checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776"
34
34
 
35
35
  [[package]]
36
36
  name = "memchr"
@@ -40,9 +40,9 @@ checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
40
40
 
41
41
  [[package]]
42
42
  name = "memmap2"
43
- version = "0.9.5"
43
+ version = "0.9.7"
44
44
  source = "registry+https://github.com/rust-lang/crates.io-index"
45
- checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f"
45
+ checksum = "483758ad303d734cec05e5c12b41d7e93e6a6390c5e9dae6bdeb7c1259012d28"
46
46
  dependencies = [
47
47
  "libc",
48
48
  ]
@@ -156,7 +156,7 @@ checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
156
156
 
157
157
  [[package]]
158
158
  name = "safetensors"
159
- version = "0.6.0-rc.0"
159
+ version = "0.6.1-rc.0"
160
160
  dependencies = [
161
161
  "serde",
162
162
  "serde_json",
@@ -164,7 +164,7 @@ dependencies = [
164
164
 
165
165
  [[package]]
166
166
  name = "safetensors-python"
167
- version = "0.6.0-rc.0"
167
+ version = "0.6.1-rc.0"
168
168
  dependencies = [
169
169
  "memmap2",
170
170
  "pyo3",
@@ -194,9 +194,9 @@ dependencies = [
194
194
 
195
195
  [[package]]
196
196
  name = "serde_json"
197
- version = "1.0.140"
197
+ version = "1.0.142"
198
198
  source = "registry+https://github.com/rust-lang/crates.io-index"
199
- checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
199
+ checksum = "030fedb782600dcbd6f02d479bf0d817ac3bb40d644745b769d6a96bc3afc5a7"
200
200
  dependencies = [
201
201
  "itoa",
202
202
  "memchr",
@@ -206,9 +206,9 @@ dependencies = [
206
206
 
207
207
  [[package]]
208
208
  name = "syn"
209
- version = "2.0.103"
209
+ version = "2.0.104"
210
210
  source = "registry+https://github.com/rust-lang/crates.io-index"
211
- checksum = "e4307e30089d6fd6aff212f2da3a1f9e32f3223b1f010fb09b7c95f90f3ca1e8"
211
+ checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40"
212
212
  dependencies = [
213
213
  "proc-macro2",
214
214
  "quote",
@@ -1,8 +1,9 @@
1
1
  [package]
2
2
  name = "safetensors-python"
3
- version = "0.6.0-rc.0"
3
+ version = "0.6.1-rc.0"
4
4
  edition = "2021"
5
5
  rust-version = "1.74"
6
+ readme = "README.md"
6
7
 
7
8
  # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
8
9
  [lib]
@@ -118,7 +118,10 @@ def test_pt_sf_load_gpu(benchmark):
118
118
  assert torch.allclose(v, tv)
119
119
 
120
120
 
121
- @pytest.mark.skipif(not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(), reason="requires mps")
121
+ @pytest.mark.skipif(
122
+ not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(),
123
+ reason="requires mps",
124
+ )
122
125
  def test_pt_pt_load_mps(benchmark):
123
126
  # benchmark something
124
127
  weights = create_gpt2(12)
@@ -133,7 +136,10 @@ def test_pt_pt_load_mps(benchmark):
133
136
  assert torch.allclose(v, tv)
134
137
 
135
138
 
136
- @pytest.mark.skipif(not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(), reason="requires mps")
139
+ @pytest.mark.skipif(
140
+ not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(),
141
+ reason="requires mps",
142
+ )
137
143
  def test_pt_sf_load_mps(benchmark):
138
144
  # benchmark something
139
145
  weights = create_gpt2(12)
@@ -8,7 +8,13 @@ from typing import Dict, List, Optional, Set, Tuple
8
8
 
9
9
  import torch
10
10
 
11
- from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
11
+ from huggingface_hub import (
12
+ CommitInfo,
13
+ CommitOperationAdd,
14
+ Discussion,
15
+ HfApi,
16
+ hf_hub_download,
17
+ )
12
18
  from huggingface_hub.file_download import repo_folder_name
13
19
  from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file
14
20
 
@@ -49,7 +55,9 @@ def _remove_duplicate_names(
49
55
  shareds = _find_shared_tensors(state_dict)
50
56
  to_remove = defaultdict(list)
51
57
  for shared in shareds:
52
- complete_names = set([name for name in shared if _is_complete(state_dict[name])])
58
+ complete_names = set(
59
+ [name for name in shared if _is_complete(state_dict[name])]
60
+ )
53
61
  if not complete_names:
54
62
  if len(shared) == 1:
55
63
  # Force contiguous
@@ -81,14 +89,20 @@ def _remove_duplicate_names(
81
89
  return to_remove
82
90
 
83
91
 
84
- def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
92
+ def get_discard_names(
93
+ model_id: str, revision: Optional[str], folder: str, token: Optional[str]
94
+ ) -> List[str]:
85
95
  try:
86
96
  import json
87
97
 
88
98
  import transformers
89
99
 
90
100
  config_filename = hf_hub_download(
91
- model_id, revision=revision, filename="config.json", token=token, cache_dir=folder
101
+ model_id,
102
+ revision=revision,
103
+ filename="config.json",
104
+ token=token,
105
+ cache_dir=folder,
92
106
  )
93
107
  with open(config_filename, "r") as f:
94
108
  config = json.load(f)
@@ -129,10 +143,19 @@ def rename(pt_filename: str) -> str:
129
143
 
130
144
 
131
145
  def convert_multi(
132
- model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str]
146
+ model_id: str,
147
+ *,
148
+ revision=Optional[str],
149
+ folder: str,
150
+ token: Optional[str],
151
+ discard_names: List[str],
133
152
  ) -> ConversionResult:
134
153
  filename = hf_hub_download(
135
- repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder
154
+ repo_id=model_id,
155
+ revision=revision,
156
+ filename="pytorch_model.bin.index.json",
157
+ token=token,
158
+ cache_dir=folder,
136
159
  )
137
160
  with open(filename, "r") as f:
138
161
  data = json.load(f)
@@ -140,7 +163,9 @@ def convert_multi(
140
163
  filenames = set(data["weight_map"].values())
141
164
  local_filenames = []
142
165
  for filename in filenames:
143
- pt_filename = hf_hub_download(repo_id=model_id, filename=filename, token=token, cache_dir=folder)
166
+ pt_filename = hf_hub_download(
167
+ repo_id=model_id, filename=filename, token=token, cache_dir=folder
168
+ )
144
169
 
145
170
  sf_filename = rename(pt_filename)
146
171
  sf_filename = os.path.join(folder, sf_filename)
@@ -156,7 +181,8 @@ def convert_multi(
156
181
  local_filenames.append(index)
157
182
 
158
183
  operations = [
159
- CommitOperationAdd(path_in_repo=os.path.basename(local), path_or_fileobj=local) for local in local_filenames
184
+ CommitOperationAdd(path_in_repo=os.path.basename(local), path_or_fileobj=local)
185
+ for local in local_filenames
160
186
  ]
161
187
  errors: List[Tuple[str, "Exception"]] = []
162
188
 
@@ -164,10 +190,19 @@ def convert_multi(
164
190
 
165
191
 
166
192
  def convert_single(
167
- model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]
193
+ model_id: str,
194
+ *,
195
+ revision: Optional[str],
196
+ folder: str,
197
+ token: Optional[str],
198
+ discard_names: List[str],
168
199
  ) -> ConversionResult:
169
200
  pt_filename = hf_hub_download(
170
- repo_id=model_id, revision=revision, filename="pytorch_model.bin", token=token, cache_dir=folder
201
+ repo_id=model_id,
202
+ revision=revision,
203
+ filename="pytorch_model.bin",
204
+ token=token,
205
+ cache_dir=folder,
171
206
  )
172
207
 
173
208
  sf_name = "model.safetensors"
@@ -219,20 +254,30 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
219
254
  sf_only = sf_set - pt_set
220
255
 
221
256
  if pt_only:
222
- errors.append(f"{key} : PT warnings contain {pt_only} which are not present in SF warnings")
257
+ errors.append(
258
+ f"{key} : PT warnings contain {pt_only} which are not present in SF warnings"
259
+ )
223
260
  if sf_only:
224
- errors.append(f"{key} : SF warnings contain {sf_only} which are not present in PT warnings")
261
+ errors.append(
262
+ f"{key} : SF warnings contain {sf_only} which are not present in PT warnings"
263
+ )
225
264
  return "\n".join(errors)
226
265
 
227
266
 
228
- def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[str]) -> Optional["Discussion"]:
267
+ def previous_pr(
268
+ api: "HfApi", model_id: str, pr_title: str, revision=Optional[str]
269
+ ) -> Optional["Discussion"]:
229
270
  try:
230
271
  revision_commit = api.model_info(model_id, revision=revision).sha
231
272
  discussions = api.get_repo_discussions(repo_id=model_id)
232
273
  except Exception:
233
274
  return None
234
275
  for discussion in discussions:
235
- if discussion.status in {"open", "closed"} and discussion.is_pull_request and discussion.title == pr_title:
276
+ if (
277
+ discussion.status in {"open", "closed"}
278
+ and discussion.is_pull_request
279
+ and discussion.title == pr_title
280
+ ):
236
281
  commits = api.list_repo_commits(model_id, revision=discussion.git_reference)
237
282
 
238
283
  if revision_commit == commits[1].commit_id:
@@ -241,7 +286,12 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[st
241
286
 
242
287
 
243
288
  def convert_generic(
244
- model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]
289
+ model_id: str,
290
+ *,
291
+ revision=Optional[str],
292
+ folder: str,
293
+ filenames: Set[str],
294
+ token: Optional[str],
245
295
  ) -> ConversionResult:
246
296
  operations = []
247
297
  errors = []
@@ -251,7 +301,11 @@ def convert_generic(
251
301
  prefix, ext = os.path.splitext(filename)
252
302
  if ext in extensions:
253
303
  pt_filename = hf_hub_download(
254
- model_id, revision=revision, filename=filename, token=token, cache_dir=folder
304
+ model_id,
305
+ revision=revision,
306
+ filename=filename,
307
+ token=token,
308
+ cache_dir=folder,
255
309
  )
256
310
  dirname, raw_filename = os.path.split(filename)
257
311
  if raw_filename == "pytorch_model.bin":
@@ -263,7 +317,11 @@ def convert_generic(
263
317
  sf_filename = os.path.join(folder, sf_in_repo)
264
318
  try:
265
319
  convert_file(pt_filename, sf_filename, discard_names=[])
266
- operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
320
+ operations.append(
321
+ CommitOperationAdd(
322
+ path_in_repo=sf_in_repo, path_or_fileobj=sf_filename
323
+ )
324
+ )
267
325
  except Exception as e:
268
326
  errors.append((pt_filename, e))
269
327
  return operations, errors
@@ -285,28 +343,50 @@ def convert(
285
343
  pr = previous_pr(api, model_id, pr_title, revision=revision)
286
344
 
287
345
  library_name = getattr(info, "library_name", None)
288
- if any(filename.endswith(".safetensors") for filename in filenames) and not force:
289
- raise AlreadyExists(f"Model {model_id} is already converted, skipping..")
346
+ if (
347
+ any(filename.endswith(".safetensors") for filename in filenames)
348
+ and not force
349
+ ):
350
+ raise AlreadyExists(
351
+ f"Model {model_id} is already converted, skipping.."
352
+ )
290
353
  elif pr is not None and not force:
291
354
  url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
292
355
  new_pr = pr
293
- raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
356
+ raise AlreadyExists(
357
+ f"Model {model_id} already has an open PR check out {url}"
358
+ )
294
359
  elif library_name == "transformers":
295
-
296
- discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
360
+ discard_names = get_discard_names(
361
+ model_id, revision=revision, folder=folder, token=api.token
362
+ )
297
363
  if "pytorch_model.bin" in filenames:
298
364
  operations, errors = convert_single(
299
- model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
365
+ model_id,
366
+ revision=revision,
367
+ folder=folder,
368
+ token=api.token,
369
+ discard_names=discard_names,
300
370
  )
301
371
  elif "pytorch_model.bin.index.json" in filenames:
302
372
  operations, errors = convert_multi(
303
- model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
373
+ model_id,
374
+ revision=revision,
375
+ folder=folder,
376
+ token=api.token,
377
+ discard_names=discard_names,
304
378
  )
305
379
  else:
306
- raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
380
+ raise RuntimeError(
381
+ f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert"
382
+ )
307
383
  else:
308
384
  operations, errors = convert_generic(
309
- model_id, revision=revision, folder=folder, filenames=filenames, token=api.token
385
+ model_id,
386
+ revision=revision,
387
+ folder=folder,
388
+ filenames=filenames,
389
+ token=api.token,
310
390
  )
311
391
 
312
392
  if operations:
@@ -366,7 +446,9 @@ if __name__ == "__main__":
366
446
  " Continue [Y/n] ?"
367
447
  )
368
448
  if txt.lower() in {"", "y"}:
369
- commit_info, errors = convert(api, model_id, revision=args.revision, force=args.force)
449
+ commit_info, errors = convert(
450
+ api, model_id, revision=args.revision, force=args.force
451
+ )
370
452
  string = f"""
371
453
  ### Success 🔥
372
454
  Yay! This model was successfully converted and a PR was open using your token, here:
@@ -375,7 +457,8 @@ Yay! This model was successfully converted and a PR was open using your token, h
375
457
  if errors:
376
458
  string += "\nErrors during conversion:\n"
377
459
  string += "\n".join(
378
- f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
460
+ f"Error while converting {filename}: {e}, skipped conversion"
461
+ for filename, e in errors
379
462
  )
380
463
  print(string)
381
464
  else:
@@ -1,4 +1,5 @@
1
1
  """Simple utility tool to convert automatically most downloaded models"""
2
+
2
3
  from convert import AlreadyExists, convert
3
4
  from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
4
5
  from transformers import AutoConfig
@@ -10,7 +11,11 @@ if __name__ == "__main__":
10
11
 
11
12
  total = 50
12
13
  models = list(
13
- api.list_models(filter=ModelFilter(library=args.library.Transformers), sort="downloads", direction=-1)
14
+ api.list_models(
15
+ filter=ModelFilter(library=args.library.Transformers),
16
+ sort="downloads",
17
+ direction=-1,
18
+ )
14
19
  )[:total]
15
20
 
16
21
  correct = 0
@@ -40,4 +45,4 @@ if __name__ == "__main__":
40
45
 
41
46
  print(f"Errors: {errors}")
42
47
  print(f"File size is difference {len(errors)}")
43
- print(f"Correct rate {correct}/{total} ({correct/total * 100:.2f}%)")
48
+ print(f"Correct rate {correct}/{total} ({correct / total * 100:.2f}%)")
@@ -4,6 +4,7 @@ from ._safetensors_rust import ( # noqa: F401
4
4
  __version__,
5
5
  deserialize,
6
6
  safe_open,
7
+ _safe_open_handle,
7
8
  serialize,
8
9
  serialize_file,
9
10
  )
@@ -49,7 +49,7 @@ def serialize_file(tensor_dict, filename, metadata=None):
49
49
 
50
50
  Returns:
51
51
  (`NoneType`):
52
- On success return None.
52
+ On success return None
53
53
  """
54
54
  pass
55
55
 
@@ -68,19 +68,21 @@ class safe_open:
68
68
  device (`str`, defaults to `"cpu"`):
69
69
  The device on which you want the tensors.
70
70
  """
71
-
72
71
  def __init__(self, filename, framework, device=...):
73
72
  pass
73
+
74
74
  def __enter__(self):
75
75
  """
76
76
  Start the context manager
77
77
  """
78
78
  pass
79
+
79
80
  def __exit__(self, _exc_type, _exc_value, _traceback):
80
81
  """
81
82
  Exits the context manager
82
83
  """
83
84
  pass
85
+
84
86
  def get_slice(self, name):
85
87
  """
86
88
  Returns a full slice view object
@@ -102,6 +104,7 @@ class safe_open:
102
104
  ```
103
105
  """
104
106
  pass
107
+
105
108
  def get_tensor(self, name):
106
109
  """
107
110
  Returns a full tensor
@@ -124,6 +127,7 @@ class safe_open:
124
127
  ```
125
128
  """
126
129
  pass
130
+
127
131
  def keys(self):
128
132
  """
129
133
  Returns the names of the tensors in the file.
@@ -133,6 +137,7 @@ class safe_open:
133
137
  The name of the tensors contained in that file
134
138
  """
135
139
  pass
140
+
136
141
  def metadata(self):
137
142
  """
138
143
  Return the special non tensor information in the header
@@ -143,6 +148,16 @@ class safe_open:
143
148
  """
144
149
  pass
145
150
 
151
+ def offset_keys(self):
152
+ """
153
+ Returns the names of the tensors in the file, ordered by offset.
154
+
155
+ Returns:
156
+ (`List[str]`):
157
+ The name of the tensors contained in that file
158
+ """
159
+ pass
160
+
146
161
  class SafetensorError(Exception):
147
162
  """
148
163
  Custom Python Exception for Safetensor errors.
@@ -7,7 +7,9 @@ import mlx.core as mx
7
7
  from safetensors import numpy, safe_open
8
8
 
9
9
 
10
- def save(tensors: Dict[str, mx.array], metadata: Optional[Dict[str, str]] = None) -> bytes:
10
+ def save(
11
+ tensors: Dict[str, mx.array], metadata: Optional[Dict[str, str]] = None
12
+ ) -> bytes:
11
13
  """
12
14
  Saves a dictionary of tensors into raw bytes in safetensors format.
13
15
 
@@ -13,7 +13,9 @@ def _tobytes(tensor: np.ndarray) -> bytes:
13
13
  return tensor.tobytes()
14
14
 
15
15
 
16
- def save(tensor_dict: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]] = None) -> bytes:
16
+ def save(
17
+ tensor_dict: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]] = None
18
+ ) -> bytes:
17
19
  """
18
20
  Saves a dictionary of tensors into raw bytes in safetensors format.
19
21
 
@@ -38,14 +40,19 @@ def save(tensor_dict: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]]
38
40
  byte_data = save(tensors)
39
41
  ```
40
42
  """
41
- flattened = {k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)} for k, v in tensor_dict.items()}
43
+ flattened = {
44
+ k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)}
45
+ for k, v in tensor_dict.items()
46
+ }
42
47
  serialized = serialize(flattened, metadata=metadata)
43
48
  result = bytes(serialized)
44
49
  return result
45
50
 
46
51
 
47
52
  def save_file(
48
- tensor_dict: Dict[str, np.ndarray], filename: Union[str, os.PathLike], metadata: Optional[Dict[str, str]] = None
53
+ tensor_dict: Dict[str, np.ndarray],
54
+ filename: Union[str, os.PathLike],
55
+ metadata: Optional[Dict[str, str]] = None,
49
56
  ) -> None:
50
57
  """
51
58
  Saves a dictionary of tensors into raw bytes in safetensors format.
@@ -73,7 +80,10 @@ def save_file(
73
80
  save_file(tensors, "model.safetensors")
74
81
  ```
75
82
  """
76
- flattened = {k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)} for k, v in tensor_dict.items()}
83
+ flattened = {
84
+ k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)}
85
+ for k, v in tensor_dict.items()
86
+ }
77
87
  serialize_file(flattened, filename, metadata=metadata)
78
88
 
79
89
 
@@ -7,7 +7,9 @@ import paddle
7
7
  from safetensors import numpy
8
8
 
9
9
 
10
- def save(tensors: Dict[str, paddle.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes:
10
+ def save(
11
+ tensors: Dict[str, paddle.Tensor], metadata: Optional[Dict[str, str]] = None
12
+ ) -> bytes:
11
13
  """
12
14
  Saves a dictionary of tensors into raw bytes in safetensors format.
13
15
 
@@ -98,7 +100,9 @@ def load(data: bytes, device: str = "cpu") -> Dict[str, paddle.Tensor]:
98
100
  return _np2paddle(flat, device)
99
101
 
100
102
 
101
- def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, paddle.Tensor]:
103
+ def load_file(
104
+ filename: Union[str, os.PathLike], device="cpu"
105
+ ) -> Dict[str, paddle.Tensor]:
102
106
  """
103
107
  Loads a safetensors file into paddle format.
104
108
 
@@ -126,7 +130,9 @@ def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, padd
126
130
  return output
127
131
 
128
132
 
129
- def _np2paddle(numpy_dict: Dict[str, np.ndarray], device: str = "cpu") -> Dict[str, paddle.Tensor]:
133
+ def _np2paddle(
134
+ numpy_dict: Dict[str, np.ndarray], device: str = "cpu"
135
+ ) -> Dict[str, paddle.Tensor]:
130
136
  for k, v in numpy_dict.items():
131
137
  numpy_dict[k] = paddle.to_tensor(v, place=device)
132
138
  return numpy_dict
@@ -7,7 +7,9 @@ import tensorflow as tf
7
7
  from safetensors import numpy, safe_open
8
8
 
9
9
 
10
- def save(tensors: Dict[str, tf.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes:
10
+ def save(
11
+ tensors: Dict[str, tf.Tensor], metadata: Optional[Dict[str, str]] = None
12
+ ) -> bytes:
11
13
  """
12
14
  Saves a dictionary of tensors into raw bytes in safetensors format.
13
15