safetensors 0.6.2__tar.gz → 0.7.0rc0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of safetensors might be problematic. Click here for more details.

Files changed (60) hide show
  1. {safetensors-0.6.2 → safetensors-0.7.0rc0}/PKG-INFO +2 -1
  2. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/Cargo.lock +77 -26
  3. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/Cargo.toml +1 -1
  4. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/py_src/safetensors/numpy.py +1 -0
  5. safetensors-0.7.0rc0/bindings/python/py_src/safetensors/paddle.py +290 -0
  6. {safetensors-0.6.2 → safetensors-0.7.0rc0/bindings/python}/py_src/safetensors/torch.py +11 -53
  7. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/src/lib.rs +245 -17
  8. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/src/view.rs +1 -0
  9. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/tests/test_flax_comparison.py +1 -0
  10. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/tests/test_mlx_comparison.py +3 -0
  11. safetensors-0.7.0rc0/bindings/python/tests/test_paddle_comparison.py +243 -0
  12. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/tests/test_pt_comparison.py +26 -0
  13. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/tests/test_tf_comparison.py +1 -0
  14. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/uv.lock +1171 -1167
  15. {safetensors-0.6.2 → safetensors-0.7.0rc0}/py_src/safetensors/numpy.py +1 -0
  16. safetensors-0.7.0rc0/py_src/safetensors/paddle.py +290 -0
  17. {safetensors-0.6.2/bindings/python → safetensors-0.7.0rc0}/py_src/safetensors/torch.py +11 -53
  18. {safetensors-0.6.2 → safetensors-0.7.0rc0}/pyproject.toml +1 -0
  19. {safetensors-0.6.2 → safetensors-0.7.0rc0}/safetensors/Cargo.toml +10 -6
  20. {safetensors-0.6.2 → safetensors-0.7.0rc0}/safetensors/src/lib.rs +1 -6
  21. {safetensors-0.6.2 → safetensors-0.7.0rc0}/safetensors/src/tensor.rs +29 -0
  22. safetensors-0.6.2/bindings/python/py_src/safetensors/paddle.py +0 -144
  23. safetensors-0.6.2/bindings/python/tests/test_paddle_comparison.py +0 -47
  24. safetensors-0.6.2/py_src/safetensors/paddle.py +0 -144
  25. {safetensors-0.6.2 → safetensors-0.7.0rc0}/LICENSE +0 -0
  26. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/.gitignore +0 -0
  27. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/LICENSE +0 -0
  28. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/MANIFEST.in +0 -0
  29. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/Makefile +0 -0
  30. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/README.md +0 -0
  31. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/benches/test_flax.py +0 -0
  32. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/benches/test_mlx.py +0 -0
  33. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/benches/test_paddle.py +0 -0
  34. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/benches/test_pt.py +0 -0
  35. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/benches/test_tf.py +0 -0
  36. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/convert.py +0 -0
  37. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/convert_all.py +0 -0
  38. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/fuzz.py +0 -0
  39. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/py_src/safetensors/__init__.py +0 -0
  40. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/py_src/safetensors/__init__.pyi +0 -0
  41. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/py_src/safetensors/flax.py +0 -0
  42. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/py_src/safetensors/mlx.py +0 -0
  43. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/py_src/safetensors/py.typed +0 -0
  44. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/py_src/safetensors/tensorflow.py +0 -0
  45. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/setup.cfg +0 -0
  46. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/stub.py +0 -0
  47. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/tests/data/__init__.py +0 -0
  48. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/tests/test_handle.py +0 -0
  49. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/tests/test_pt_model.py +0 -0
  50. {safetensors-0.6.2 → safetensors-0.7.0rc0}/bindings/python/tests/test_simple.py +0 -0
  51. {safetensors-0.6.2 → safetensors-0.7.0rc0}/py_src/safetensors/__init__.py +0 -0
  52. {safetensors-0.6.2 → safetensors-0.7.0rc0}/py_src/safetensors/__init__.pyi +0 -0
  53. {safetensors-0.6.2 → safetensors-0.7.0rc0}/py_src/safetensors/flax.py +0 -0
  54. {safetensors-0.6.2 → safetensors-0.7.0rc0}/py_src/safetensors/mlx.py +0 -0
  55. {safetensors-0.6.2 → safetensors-0.7.0rc0}/py_src/safetensors/py.typed +0 -0
  56. {safetensors-0.6.2 → safetensors-0.7.0rc0}/py_src/safetensors/tensorflow.py +0 -0
  57. {safetensors-0.6.2 → safetensors-0.7.0rc0}/safetensors/LICENSE +0 -0
  58. {safetensors-0.6.2 → safetensors-0.7.0rc0}/safetensors/README.md +0 -0
  59. {safetensors-0.6.2 → safetensors-0.7.0rc0}/safetensors/benches/benchmark.rs +0 -0
  60. {safetensors-0.6.2 → safetensors-0.7.0rc0}/safetensors/src/slice.rs +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: safetensors
3
- Version: 0.6.2
3
+ Version: 0.7.0rc0
4
4
  Classifier: Development Status :: 5 - Production/Stable
5
5
  Classifier: Intended Audience :: Developers
6
6
  Classifier: Intended Audience :: Education
@@ -15,6 +15,7 @@ Classifier: Programming Language :: Python :: 3.10
15
15
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
16
  Classifier: Typing :: Typed
17
17
  Requires-Dist: numpy>=1.21.6 ; extra == 'numpy'
18
+ Requires-Dist: packaging ; extra == 'torch'
18
19
  Requires-Dist: safetensors[numpy] ; extra == 'torch'
19
20
  Requires-Dist: torch>=1.10 ; extra == 'torch'
20
21
  Requires-Dist: safetensors[numpy] ; extra == 'tensorflow'
@@ -2,12 +2,42 @@
2
2
  # It is not intended for manual editing.
3
3
  version = 3
4
4
 
5
+ [[package]]
6
+ name = "allocator-api2"
7
+ version = "0.2.21"
8
+ source = "registry+https://github.com/rust-lang/crates.io-index"
9
+ checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
10
+
5
11
  [[package]]
6
12
  name = "autocfg"
7
13
  version = "1.5.0"
8
14
  source = "registry+https://github.com/rust-lang/crates.io-index"
9
15
  checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
10
16
 
17
+ [[package]]
18
+ name = "equivalent"
19
+ version = "1.0.2"
20
+ source = "registry+https://github.com/rust-lang/crates.io-index"
21
+ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
22
+
23
+ [[package]]
24
+ name = "foldhash"
25
+ version = "0.2.0"
26
+ source = "registry+https://github.com/rust-lang/crates.io-index"
27
+ checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
28
+
29
+ [[package]]
30
+ name = "hashbrown"
31
+ version = "0.16.0"
32
+ source = "registry+https://github.com/rust-lang/crates.io-index"
33
+ checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d"
34
+ dependencies = [
35
+ "allocator-api2",
36
+ "equivalent",
37
+ "foldhash",
38
+ "serde",
39
+ ]
40
+
11
41
  [[package]]
12
42
  name = "heck"
13
43
  version = "0.5.0"
@@ -16,9 +46,12 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
16
46
 
17
47
  [[package]]
18
48
  name = "indoc"
19
- version = "2.0.6"
49
+ version = "2.0.7"
20
50
  source = "registry+https://github.com/rust-lang/crates.io-index"
21
- checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
51
+ checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706"
52
+ dependencies = [
53
+ "rustversion",
54
+ ]
22
55
 
23
56
  [[package]]
24
57
  name = "itoa"
@@ -28,21 +61,21 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
28
61
 
29
62
  [[package]]
30
63
  name = "libc"
31
- version = "0.2.174"
64
+ version = "0.2.177"
32
65
  source = "registry+https://github.com/rust-lang/crates.io-index"
33
- checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776"
66
+ checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976"
34
67
 
35
68
  [[package]]
36
69
  name = "memchr"
37
- version = "2.7.5"
70
+ version = "2.7.6"
38
71
  source = "registry+https://github.com/rust-lang/crates.io-index"
39
- checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
72
+ checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
40
73
 
41
74
  [[package]]
42
75
  name = "memmap2"
43
- version = "0.9.5"
76
+ version = "0.9.9"
44
77
  source = "registry+https://github.com/rust-lang/crates.io-index"
45
- checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f"
78
+ checksum = "744133e4a0e0a658e1374cf3bf8e415c4052a15a111acd372764c55b4177d490"
46
79
  dependencies = [
47
80
  "libc",
48
81
  ]
@@ -70,9 +103,9 @@ checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483"
70
103
 
71
104
  [[package]]
72
105
  name = "proc-macro2"
73
- version = "1.0.95"
106
+ version = "1.0.103"
74
107
  source = "registry+https://github.com/rust-lang/crates.io-index"
75
- checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778"
108
+ checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
76
109
  dependencies = [
77
110
  "unicode-ident",
78
111
  ]
@@ -141,13 +174,19 @@ dependencies = [
141
174
 
142
175
  [[package]]
143
176
  name = "quote"
144
- version = "1.0.40"
177
+ version = "1.0.41"
145
178
  source = "registry+https://github.com/rust-lang/crates.io-index"
146
- checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
179
+ checksum = "ce25767e7b499d1b604768e7cde645d14cc8584231ea6b295e9c9eb22c02e1d1"
147
180
  dependencies = [
148
181
  "proc-macro2",
149
182
  ]
150
183
 
184
+ [[package]]
185
+ name = "rustversion"
186
+ version = "1.0.22"
187
+ source = "registry+https://github.com/rust-lang/crates.io-index"
188
+ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
189
+
151
190
  [[package]]
152
191
  name = "ryu"
153
192
  version = "1.0.20"
@@ -156,15 +195,16 @@ checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
156
195
 
157
196
  [[package]]
158
197
  name = "safetensors"
159
- version = "0.6.2"
198
+ version = "0.7.0-rc.0"
160
199
  dependencies = [
200
+ "hashbrown",
161
201
  "serde",
162
202
  "serde_json",
163
203
  ]
164
204
 
165
205
  [[package]]
166
206
  name = "safetensors-python"
167
- version = "0.6.2"
207
+ version = "0.7.0-rc.0"
168
208
  dependencies = [
169
209
  "memmap2",
170
210
  "pyo3",
@@ -174,18 +214,28 @@ dependencies = [
174
214
 
175
215
  [[package]]
176
216
  name = "serde"
177
- version = "1.0.219"
217
+ version = "1.0.228"
178
218
  source = "registry+https://github.com/rust-lang/crates.io-index"
179
- checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
219
+ checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
220
+ dependencies = [
221
+ "serde_core",
222
+ "serde_derive",
223
+ ]
224
+
225
+ [[package]]
226
+ name = "serde_core"
227
+ version = "1.0.228"
228
+ source = "registry+https://github.com/rust-lang/crates.io-index"
229
+ checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
180
230
  dependencies = [
181
231
  "serde_derive",
182
232
  ]
183
233
 
184
234
  [[package]]
185
235
  name = "serde_derive"
186
- version = "1.0.219"
236
+ version = "1.0.228"
187
237
  source = "registry+https://github.com/rust-lang/crates.io-index"
188
- checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
238
+ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
189
239
  dependencies = [
190
240
  "proc-macro2",
191
241
  "quote",
@@ -194,21 +244,22 @@ dependencies = [
194
244
 
195
245
  [[package]]
196
246
  name = "serde_json"
197
- version = "1.0.140"
247
+ version = "1.0.145"
198
248
  source = "registry+https://github.com/rust-lang/crates.io-index"
199
- checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
249
+ checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c"
200
250
  dependencies = [
201
251
  "itoa",
202
252
  "memchr",
203
253
  "ryu",
204
254
  "serde",
255
+ "serde_core",
205
256
  ]
206
257
 
207
258
  [[package]]
208
259
  name = "syn"
209
- version = "2.0.104"
260
+ version = "2.0.108"
210
261
  source = "registry+https://github.com/rust-lang/crates.io-index"
211
- checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40"
262
+ checksum = "da58917d35242480a05c2897064da0a80589a2a0476c9a3f2fdc83b53502e917"
212
263
  dependencies = [
213
264
  "proc-macro2",
214
265
  "quote",
@@ -217,15 +268,15 @@ dependencies = [
217
268
 
218
269
  [[package]]
219
270
  name = "target-lexicon"
220
- version = "0.13.2"
271
+ version = "0.13.3"
221
272
  source = "registry+https://github.com/rust-lang/crates.io-index"
222
- checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a"
273
+ checksum = "df7f62577c25e07834649fc3b39fafdc597c0a3527dc1c60129201ccfcbaa50c"
223
274
 
224
275
  [[package]]
225
276
  name = "unicode-ident"
226
- version = "1.0.18"
277
+ version = "1.0.20"
227
278
  source = "registry+https://github.com/rust-lang/crates.io-index"
228
- checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
279
+ checksum = "462eeb75aeb73aea900253ce739c8e18a67423fadf006037cd3ff27e82748a06"
229
280
 
230
281
  [[package]]
231
282
  name = "unindent"
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "safetensors-python"
3
- version = "0.6.2"
3
+ version = "0.7.0-rc.0"
4
4
  edition = "2021"
5
5
  rust-version = "1.74"
6
6
  readme = "README.md"
@@ -154,6 +154,7 @@ _TYPES = {
154
154
  "I8": np.int8,
155
155
  "U8": np.uint8,
156
156
  "BOOL": bool,
157
+ "C64": np.complex64,
157
158
  }
158
159
 
159
160
 
@@ -0,0 +1,290 @@
1
+ import os
2
+ import sys
3
+ from typing import Any, Dict, Optional, Union
4
+
5
+ import numpy as np
6
+ import paddle
7
+
8
+ from safetensors import numpy, deserialize, safe_open, serialize, serialize_file
9
+
10
+
11
+ def save(
12
+ tensors: Dict[str, paddle.Tensor], metadata: Optional[Dict[str, str]] = None
13
+ ) -> bytes:
14
+ """
15
+ Saves a dictionary of tensors into raw bytes in safetensors format.
16
+
17
+ Args:
18
+ tensors (`Dict[str, paddle.Tensor]`):
19
+ The incoming tensors. Tensors need to be contiguous and dense.
20
+ metadata (`Dict[str, str]`, *optional*, defaults to `None`):
21
+ Optional text only metadata you might want to save in your header.
22
+ For instance it can be useful to specify more about the underlying
23
+ tensors. This is purely informative and does not affect tensor loading.
24
+
25
+ Returns:
26
+ `bytes`: The raw bytes representing the format
27
+
28
+ Example:
29
+
30
+ ```python
31
+ from safetensors.paddle import save
32
+ import paddle
33
+
34
+ tensors = {"embedding": paddle.zeros((512, 1024)), "attention": paddle.zeros((256, 256))}
35
+ byte_data = save(tensors)
36
+ ```
37
+ """
38
+ serialized = serialize(_flatten(tensors), metadata=metadata)
39
+ result = bytes(serialized)
40
+ return result
41
+
42
+
43
+ def save_file(
44
+ tensors: Dict[str, paddle.Tensor],
45
+ filename: Union[str, os.PathLike],
46
+ metadata: Optional[Dict[str, str]] = None,
47
+ ) -> None:
48
+ """
49
+ Saves a dictionary of tensors into raw bytes in safetensors format.
50
+
51
+ Args:
52
+ tensors (`Dict[str, paddle.Tensor]`):
53
+ The incoming tensors. Tensors need to be contiguous and dense.
54
+ filename (`str`, or `os.PathLike`)):
55
+ The filename we're saving into.
56
+ metadata (`Dict[str, str]`, *optional*, defaults to `None`):
57
+ Optional text only metadata you might want to save in your header.
58
+ For instance it can be useful to specify more about the underlying
59
+ tensors. This is purely informative and does not affect tensor loading.
60
+
61
+ Returns:
62
+ `None`
63
+
64
+ Example:
65
+
66
+ ```python
67
+ from safetensors.paddle import save_file
68
+ import paddle
69
+
70
+ tensors = {"embedding": paddle.zeros((512, 1024)), "attention": paddle.zeros((256, 256))}
71
+ save_file(tensors, "model.safetensors")
72
+ ```
73
+ """
74
+ serialize_file(_flatten(tensors), filename, metadata=metadata)
75
+
76
+
77
+ def load(data: bytes, device: str = "cpu") -> Dict[str, paddle.Tensor]:
78
+ """
79
+ Loads a safetensors file into paddle format from pure bytes.
80
+
81
+ Args:
82
+ data (`bytes`):
83
+ The content of a safetensors file
84
+
85
+ Returns:
86
+ `Dict[str, paddle.Tensor]`: dictionary that contains name as key, value as `paddle.Tensor` on cpu
87
+
88
+ Example:
89
+
90
+ ```python
91
+ from safetensors.paddle import load
92
+
93
+ file_path = "./my_folder/bert.safetensors"
94
+ with open(file_path, "rb") as f:
95
+ data = f.read()
96
+
97
+ loaded = load(data)
98
+ ```
99
+ """
100
+ if paddle.__version__ >= "3.2.0":
101
+ flat = deserialize(data)
102
+ return _view2paddle(flat, device)
103
+ else:
104
+ flat = numpy.load(data)
105
+ return _np2paddle(flat, device)
106
+
107
+
108
+ def load_file(
109
+ filename: Union[str, os.PathLike], device="cpu"
110
+ ) -> Dict[str, paddle.Tensor]:
111
+ """
112
+ Loads a safetensors file into paddle format.
113
+
114
+ Args:
115
+ filename (`str`, or `os.PathLike`)):
116
+ The name of the file which contains the tensors
117
+ device (`Union[Dict[str, any], str]`, *optional*, defaults to `cpu`):
118
+ The device where the tensors need to be located after load.
119
+ available options are all regular paddle device locations
120
+
121
+ Returns:
122
+ `Dict[str, paddle.Tensor]`: dictionary that contains name as key, value as `paddle.Tensor`
123
+
124
+ Example:
125
+
126
+ ```python
127
+ from safetensors.paddle import load_file
128
+
129
+ file_path = "./my_folder/bert.safetensors"
130
+ loaded = load_file(file_path)
131
+ ```
132
+ """
133
+ result = {}
134
+ if paddle.__version__ >= "3.2.0":
135
+ with safe_open(filename, framework="paddle", device=device) as f:
136
+ for k in f.offset_keys():
137
+ result[k] = f.get_tensor(k)
138
+ else:
139
+ flat = numpy.load_file(filename)
140
+ result = _np2paddle(flat, device)
141
+ return result
142
+
143
+
144
+ def _np2paddle(
145
+ numpy_dict: Dict[str, np.ndarray], device: str = "cpu"
146
+ ) -> Dict[str, paddle.Tensor]:
147
+ for k, v in numpy_dict.items():
148
+ numpy_dict[k] = paddle.to_tensor(v, place=device)
149
+ return numpy_dict
150
+
151
+
152
+ def _paddle2np(paddle_dict: Dict[str, paddle.Tensor]) -> Dict[str, np.array]:
153
+ for k, v in paddle_dict.items():
154
+ paddle_dict[k] = v.detach().cpu().numpy()
155
+ return paddle_dict
156
+
157
+
158
+ _SIZE = {
159
+ paddle.int64: 8,
160
+ paddle.float32: 4,
161
+ paddle.int32: 4,
162
+ paddle.bfloat16: 2,
163
+ paddle.float16: 2,
164
+ paddle.int16: 2,
165
+ paddle.uint8: 1,
166
+ paddle.int8: 1,
167
+ paddle.bool: 1,
168
+ paddle.float64: 8,
169
+ paddle.float8_e4m3fn: 1,
170
+ paddle.float8_e5m2: 1,
171
+ paddle.complex64: 8,
172
+ # XXX: These are not supported yet in paddle
173
+ # paddle.uint64: 8,
174
+ # paddle.uint32: 4,
175
+ # paddle.uint16: 2,
176
+ # paddle.float8_e8m0: 1,
177
+ # paddle.float4_e2m1_x2: 1,
178
+ }
179
+
180
+ _TYPES = {
181
+ "F64": paddle.float64,
182
+ "F32": paddle.float32,
183
+ "F16": paddle.float16,
184
+ "BF16": paddle.bfloat16,
185
+ "I64": paddle.int64,
186
+ "I32": paddle.int32,
187
+ "I16": paddle.int16,
188
+ "I8": paddle.int8,
189
+ "U8": paddle.uint8,
190
+ "BOOL": paddle.bool,
191
+ "F8_E4M3": paddle.float8_e4m3fn,
192
+ "F8_E5M2": paddle.float8_e5m2,
193
+ }
194
+
195
+ NPDTYPES = {
196
+ paddle.int64: np.int64,
197
+ paddle.float32: np.float32,
198
+ paddle.int32: np.int32,
199
+ # XXX: This is ok because both have the same width
200
+ paddle.bfloat16: np.float16,
201
+ paddle.float16: np.float16,
202
+ paddle.int16: np.int16,
203
+ paddle.uint8: np.uint8,
204
+ paddle.int8: np.int8,
205
+ paddle.bool: bool,
206
+ paddle.float64: np.float64,
207
+ # XXX: This is ok because both have the same width and byteswap is a no-op anyway
208
+ paddle.float8_e4m3fn: np.uint8,
209
+ paddle.float8_e5m2: np.uint8,
210
+ }
211
+
212
+
213
+ def _getdtype(dtype_str: str) -> paddle.dtype:
214
+ return _TYPES[dtype_str]
215
+
216
+
217
+ def _view2paddle(safeview, device) -> Dict[str, paddle.Tensor]:
218
+ result = {}
219
+ for k, v in safeview:
220
+ dtype = _getdtype(v["dtype"])
221
+ if len(v["data"]) == 0:
222
+ # Workaround because frombuffer doesn't accept zero-size tensors
223
+ assert any(x == 0 for x in v["shape"])
224
+ arr = paddle.empty(v["shape"], dtype=dtype)
225
+ else:
226
+ arr = paddle.base.core.frombuffer(v["data"], dtype).reshape(v["shape"])
227
+ if device != "cpu":
228
+ arr = arr.to(device)
229
+ if sys.byteorder == "big":
230
+ arr = paddle.to_tensor(arr.numpy().byteswap(inplace=False), place=device)
231
+ result[k] = arr
232
+
233
+ return result
234
+
235
+
236
+ def _tobytes(tensor: paddle.Tensor, name: str) -> bytes:
237
+ if not tensor.is_contiguous():
238
+ raise ValueError(
239
+ f"You are trying to save a non contiguous tensor: `{name}` which is not allowed. It either means you"
240
+ " are trying to save tensors which are reference of each other in which case it's recommended to save"
241
+ " only the full tensors, and reslice at load time, or simply call `.contiguous()` on your tensor to"
242
+ " pack it before saving."
243
+ )
244
+ if not tensor.place.is_cpu_place():
245
+ # Moving tensor to cpu before saving
246
+ tensor = tensor.cpu()
247
+
248
+ import ctypes
249
+
250
+ import numpy as np
251
+
252
+ # When shape is empty (scalar), np.prod returns a float
253
+ # we need a int for the following calculations
254
+ length = int(np.prod(tensor.shape).item())
255
+ bytes_per_item = _SIZE[tensor.dtype]
256
+
257
+ total_bytes = length * bytes_per_item
258
+
259
+ ptr = tensor.data_ptr()
260
+ if ptr == 0:
261
+ return b""
262
+ newptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_ubyte))
263
+ data = np.ctypeslib.as_array(newptr, (total_bytes,)) # no internal copy
264
+ if sys.byteorder == "big":
265
+ npdtype = NPDTYPES[tensor.dtype]
266
+ # Not in place as that would potentially modify a live running model
267
+ data = data.view(npdtype).byteswap(inplace=False)
268
+ return data.tobytes()
269
+
270
+
271
+ def _flatten(tensors: Dict[str, paddle.Tensor]) -> Dict[str, Dict[str, Any]]:
272
+ if not isinstance(tensors, dict):
273
+ raise ValueError(
274
+ f"Expected a dict of [str, paddle.Tensor] but received {type(tensors)}"
275
+ )
276
+
277
+ for k, v in tensors.items():
278
+ if not isinstance(v, paddle.Tensor):
279
+ raise ValueError(
280
+ f"Key `{k}` is invalid, expected paddle.Tensor but received {type(v)}"
281
+ )
282
+
283
+ return {
284
+ k: {
285
+ "dtype": str(v.dtype).split(".")[-1],
286
+ "shape": v.shape,
287
+ "data": _tobytes(v, k),
288
+ }
289
+ for k, v in tensors.items()
290
+ }
@@ -221,68 +221,23 @@ def load_model(
221
221
  to_removes = _remove_duplicate_names(
222
222
  model_state_dict, preferred_names=state_dict.keys()
223
223
  )
224
-
225
- reverse_to_remove = {}
226
- for key, to_remove_group in to_removes.items():
227
- for to_remove in to_remove_group:
228
- reverse_to_remove[to_remove] = key
229
-
230
- # We iterate on the model, so we'll add keys we find missing
231
- # here
232
- missing = set()
233
- # We start with all keys on disk declared as unexpected, we'll
234
- # slowly remove them when we find them
235
- unexpected = set(state_dict.keys())
236
- # Some keys can be invalid too.
237
- invalid = set()
238
-
239
- for k, mv in model_state_dict.items():
240
- actual_k = reverse_to_remove.get(k, None)
241
- if actual_k is not None:
242
- look_k = actual_k
243
- else:
244
- look_k = k
245
- v = state_dict.get(look_k, None)
246
- if v is None:
247
- missing.add(k)
248
- else:
249
- # We can actually check for the shapes while we're at it.
250
- # For the device, it's trickier given torch's internals
251
- # There might be some Meta device for faster initiation
252
- if v.dtype != mv.dtype or v.shape != mv.shape:
253
- invalid.add(k)
254
- if actual_k is None:
255
- unexpected.remove(k)
256
-
224
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
257
225
  missing = set(missing)
258
- unexpected = set(unexpected)
259
- if strict and (missing or unexpected or invalid):
226
+ for to_remove_group in to_removes.values():
227
+ for to_remove in to_remove_group:
228
+ if to_remove not in missing:
229
+ unexpected.append(to_remove)
230
+ else:
231
+ missing.remove(to_remove)
232
+ if strict and (missing or unexpected):
260
233
  missing_keys = ", ".join([f'"{k}"' for k in sorted(missing)])
261
234
  unexpected_keys = ", ".join([f'"{k}"' for k in sorted(unexpected)])
262
- invalid_keys = ", ".join([f'"{k}"' for k in sorted(invalid)])
263
235
  error = f"Error(s) in loading state_dict for {model.__class__.__name__}:"
264
236
  if missing:
265
237
  error += f"\n Missing key(s) in state_dict: {missing_keys}"
266
238
  if unexpected:
267
239
  error += f"\n Unexpected key(s) in state_dict: {unexpected_keys}"
268
- if invalid:
269
- error += f"\n Invalid key(s) in state_dict: {invalid_keys}, mismatched dtypes or shape."
270
- del state_dict
271
240
  raise RuntimeError(error)
272
-
273
- torch_missing, torch_unexpected = model.load_state_dict(state_dict, strict=False)
274
- # Sanity check that the work we've done matches
275
- # Pytorch internal loading.
276
- torch_missing = set(torch_missing)
277
- torch_unexpected = set(torch_unexpected)
278
- for to_remove_group in to_removes.values():
279
- for to_remove in to_remove_group:
280
- if to_remove not in torch_missing:
281
- torch_unexpected.add(to_remove)
282
- else:
283
- torch_missing.remove(to_remove)
284
- assert torch_missing == missing, f"{torch_missing} != {missing}"
285
- assert torch_unexpected == unexpected, f"{torch_unexpected} != {unexpected}"
286
241
  return missing, unexpected
287
242
 
288
243
 
@@ -428,6 +383,7 @@ _SIZE = {
428
383
  torch.int8: 1,
429
384
  torch.bool: 1,
430
385
  torch.float64: 8,
386
+ torch.complex64: 8,
431
387
  _float8_e4m3fn: 1,
432
388
  _float8_e5m2: 1,
433
389
  _float8_e8m0: 1,
@@ -455,6 +411,7 @@ _TYPES = {
455
411
  "BOOL": torch.bool,
456
412
  "F8_E4M3": _float8_e4m3fn,
457
413
  "F8_E5M2": _float8_e5m2,
414
+ "C64": torch.complex64,
458
415
  }
459
416
  if Version(torch.__version__) >= Version("2.3.0"):
460
417
  _TYPES.update(
@@ -538,6 +495,7 @@ def _tobytes(tensor: torch.Tensor, name: str) -> bytes:
538
495
  # XXX: This is ok because both have the same width and byteswap is a no-op anyway
539
496
  _float8_e4m3fn: np.uint8,
540
497
  _float8_e5m2: np.uint8,
498
+ torch.complex64: np.complex64,
541
499
  }
542
500
  npdtype = NPDTYPES[tensor.dtype]
543
501
  # Not in place as that would potentially modify a live running model