safetensors 0.1.4 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/Cargo.lock +75 -44
- data/ext/safetensors/Cargo.toml +4 -4
- data/ext/safetensors/src/lib.rs +47 -42
- data/lib/safetensors/version.rb +1 -1
- metadata +3 -3
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 547c0781cd2c693db2a6605bca61004833aa668876f1512673ed7e5a44b6d30d
|
|
4
|
+
data.tar.gz: be088a6e2c725dfbc32a16c3864f8bc23ff4f58ee86fcf6c3388f64d98fb0223
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 270ea7030272173f04939528d34ae49d8d8a15f7ac51adb62fe6c93990e551257bce124c9bc983bbda70851322e66c4149f4483815f26b504a9528067b61dc1e
|
|
7
|
+
data.tar.gz: 7e4c3681d5443bb85d5226590d4dfe1b42a1c8ecf597ace1f136f89c6456070878cefc4f81ddf8247d8243f5c4fc20897af800c55e16aa4746d271d4cccdda53
|
data/CHANGELOG.md
CHANGED
data/Cargo.lock
CHANGED
|
@@ -11,6 +11,12 @@ dependencies = [
|
|
|
11
11
|
"memchr",
|
|
12
12
|
]
|
|
13
13
|
|
|
14
|
+
[[package]]
|
|
15
|
+
name = "allocator-api2"
|
|
16
|
+
version = "0.2.21"
|
|
17
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
18
|
+
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
|
|
19
|
+
|
|
14
20
|
[[package]]
|
|
15
21
|
name = "bindgen"
|
|
16
22
|
version = "0.69.5"
|
|
@@ -33,9 +39,9 @@ dependencies = [
|
|
|
33
39
|
|
|
34
40
|
[[package]]
|
|
35
41
|
name = "bitflags"
|
|
36
|
-
version = "2.9.
|
|
42
|
+
version = "2.9.1"
|
|
37
43
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
38
|
-
checksum = "
|
|
44
|
+
checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967"
|
|
39
45
|
|
|
40
46
|
[[package]]
|
|
41
47
|
name = "cexpr"
|
|
@@ -48,9 +54,9 @@ dependencies = [
|
|
|
48
54
|
|
|
49
55
|
[[package]]
|
|
50
56
|
name = "cfg-if"
|
|
51
|
-
version = "1.0.
|
|
57
|
+
version = "1.0.1"
|
|
52
58
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
53
|
-
checksum = "
|
|
59
|
+
checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268"
|
|
54
60
|
|
|
55
61
|
[[package]]
|
|
56
62
|
name = "clang-sys"
|
|
@@ -69,12 +75,36 @@ version = "1.15.0"
|
|
|
69
75
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
70
76
|
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
|
71
77
|
|
|
78
|
+
[[package]]
|
|
79
|
+
name = "equivalent"
|
|
80
|
+
version = "1.0.2"
|
|
81
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
82
|
+
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
|
83
|
+
|
|
84
|
+
[[package]]
|
|
85
|
+
name = "foldhash"
|
|
86
|
+
version = "0.2.0"
|
|
87
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
88
|
+
checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
|
|
89
|
+
|
|
72
90
|
[[package]]
|
|
73
91
|
name = "glob"
|
|
74
92
|
version = "0.3.2"
|
|
75
93
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
76
94
|
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
|
|
77
95
|
|
|
96
|
+
[[package]]
|
|
97
|
+
name = "hashbrown"
|
|
98
|
+
version = "0.16.0"
|
|
99
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
100
|
+
checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d"
|
|
101
|
+
dependencies = [
|
|
102
|
+
"allocator-api2",
|
|
103
|
+
"equivalent",
|
|
104
|
+
"foldhash",
|
|
105
|
+
"serde",
|
|
106
|
+
]
|
|
107
|
+
|
|
78
108
|
[[package]]
|
|
79
109
|
name = "itertools"
|
|
80
110
|
version = "0.12.1"
|
|
@@ -104,15 +134,15 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
|
|
104
134
|
|
|
105
135
|
[[package]]
|
|
106
136
|
name = "libc"
|
|
107
|
-
version = "0.2.
|
|
137
|
+
version = "0.2.174"
|
|
108
138
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
109
|
-
checksum = "
|
|
139
|
+
checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776"
|
|
110
140
|
|
|
111
141
|
[[package]]
|
|
112
142
|
name = "libloading"
|
|
113
|
-
version = "0.8.
|
|
143
|
+
version = "0.8.8"
|
|
114
144
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
115
|
-
checksum = "
|
|
145
|
+
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
|
|
116
146
|
dependencies = [
|
|
117
147
|
"cfg-if",
|
|
118
148
|
"windows-targets",
|
|
@@ -120,9 +150,9 @@ dependencies = [
|
|
|
120
150
|
|
|
121
151
|
[[package]]
|
|
122
152
|
name = "magnus"
|
|
123
|
-
version = "0.
|
|
153
|
+
version = "0.8.1"
|
|
124
154
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
125
|
-
checksum = "
|
|
155
|
+
checksum = "bd2ac6e71886be00ac34db92aa732c793c5107c95191805b9a1c7e70e6d342e0"
|
|
126
156
|
dependencies = [
|
|
127
157
|
"magnus-macros",
|
|
128
158
|
"rb-sys",
|
|
@@ -132,9 +162,9 @@ dependencies = [
|
|
|
132
162
|
|
|
133
163
|
[[package]]
|
|
134
164
|
name = "magnus-macros"
|
|
135
|
-
version = "0.
|
|
165
|
+
version = "0.8.0"
|
|
136
166
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
137
|
-
checksum = "
|
|
167
|
+
checksum = "47607461fd8e1513cb4f2076c197d8092d921a1ea75bd08af97398f593751892"
|
|
138
168
|
dependencies = [
|
|
139
169
|
"proc-macro2",
|
|
140
170
|
"quote",
|
|
@@ -143,9 +173,9 @@ dependencies = [
|
|
|
143
173
|
|
|
144
174
|
[[package]]
|
|
145
175
|
name = "memchr"
|
|
146
|
-
version = "2.7.
|
|
176
|
+
version = "2.7.5"
|
|
147
177
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
148
|
-
checksum = "
|
|
178
|
+
checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
|
|
149
179
|
|
|
150
180
|
[[package]]
|
|
151
181
|
name = "memmap2"
|
|
@@ -192,18 +222,18 @@ dependencies = [
|
|
|
192
222
|
|
|
193
223
|
[[package]]
|
|
194
224
|
name = "rb-sys"
|
|
195
|
-
version = "0.9.
|
|
225
|
+
version = "0.9.117"
|
|
196
226
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
197
|
-
checksum = "
|
|
227
|
+
checksum = "f900d1ce4629a2ebffaf5de74bd8f9c1188d4c5ed406df02f97e22f77a006f44"
|
|
198
228
|
dependencies = [
|
|
199
229
|
"rb-sys-build",
|
|
200
230
|
]
|
|
201
231
|
|
|
202
232
|
[[package]]
|
|
203
233
|
name = "rb-sys-build"
|
|
204
|
-
version = "0.9.
|
|
234
|
+
version = "0.9.117"
|
|
205
235
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
206
|
-
checksum = "
|
|
236
|
+
checksum = "ef1e9c857028f631056bcd6d88cec390c751e343ce2223ddb26d23eb4a151d59"
|
|
207
237
|
dependencies = [
|
|
208
238
|
"bindgen",
|
|
209
239
|
"lazy_static",
|
|
@@ -216,9 +246,9 @@ dependencies = [
|
|
|
216
246
|
|
|
217
247
|
[[package]]
|
|
218
248
|
name = "rb-sys-env"
|
|
219
|
-
version = "0.
|
|
249
|
+
version = "0.2.2"
|
|
220
250
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
221
|
-
checksum = "
|
|
251
|
+
checksum = "08f8d2924cf136a1315e2b4c7460a39f62ef11ee5d522df9b2750fab55b868b6"
|
|
222
252
|
|
|
223
253
|
[[package]]
|
|
224
254
|
name = "regex"
|
|
@@ -263,20 +293,21 @@ checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
|
|
|
263
293
|
|
|
264
294
|
[[package]]
|
|
265
295
|
name = "safetensors"
|
|
266
|
-
version = "0.1
|
|
296
|
+
version = "0.2.1"
|
|
267
297
|
dependencies = [
|
|
268
298
|
"magnus",
|
|
269
299
|
"memmap2",
|
|
270
|
-
"safetensors 0.
|
|
300
|
+
"safetensors 0.7.0",
|
|
271
301
|
"serde_json",
|
|
272
302
|
]
|
|
273
303
|
|
|
274
304
|
[[package]]
|
|
275
305
|
name = "safetensors"
|
|
276
|
-
version = "0.
|
|
306
|
+
version = "0.7.0"
|
|
277
307
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
278
|
-
checksum = "
|
|
308
|
+
checksum = "675656c1eabb620b921efea4f9199f97fc86e36dd6ffd1fbbe48d0f59a4987f5"
|
|
279
309
|
dependencies = [
|
|
310
|
+
"hashbrown",
|
|
280
311
|
"serde",
|
|
281
312
|
"serde_json",
|
|
282
313
|
]
|
|
@@ -333,9 +364,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
|
|
333
364
|
|
|
334
365
|
[[package]]
|
|
335
366
|
name = "syn"
|
|
336
|
-
version = "2.0.
|
|
367
|
+
version = "2.0.104"
|
|
337
368
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
338
|
-
checksum = "
|
|
369
|
+
checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40"
|
|
339
370
|
dependencies = [
|
|
340
371
|
"proc-macro2",
|
|
341
372
|
"quote",
|
|
@@ -350,9 +381,9 @@ checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
|
|
|
350
381
|
|
|
351
382
|
[[package]]
|
|
352
383
|
name = "windows-targets"
|
|
353
|
-
version = "0.
|
|
384
|
+
version = "0.53.2"
|
|
354
385
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
355
|
-
checksum = "
|
|
386
|
+
checksum = "c66f69fcc9ce11da9966ddb31a40968cad001c5bedeb5c2b82ede4253ab48aef"
|
|
356
387
|
dependencies = [
|
|
357
388
|
"windows_aarch64_gnullvm",
|
|
358
389
|
"windows_aarch64_msvc",
|
|
@@ -366,48 +397,48 @@ dependencies = [
|
|
|
366
397
|
|
|
367
398
|
[[package]]
|
|
368
399
|
name = "windows_aarch64_gnullvm"
|
|
369
|
-
version = "0.
|
|
400
|
+
version = "0.53.0"
|
|
370
401
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
371
|
-
checksum = "
|
|
402
|
+
checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
|
|
372
403
|
|
|
373
404
|
[[package]]
|
|
374
405
|
name = "windows_aarch64_msvc"
|
|
375
|
-
version = "0.
|
|
406
|
+
version = "0.53.0"
|
|
376
407
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
377
|
-
checksum = "
|
|
408
|
+
checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
|
|
378
409
|
|
|
379
410
|
[[package]]
|
|
380
411
|
name = "windows_i686_gnu"
|
|
381
|
-
version = "0.
|
|
412
|
+
version = "0.53.0"
|
|
382
413
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
383
|
-
checksum = "
|
|
414
|
+
checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3"
|
|
384
415
|
|
|
385
416
|
[[package]]
|
|
386
417
|
name = "windows_i686_gnullvm"
|
|
387
|
-
version = "0.
|
|
418
|
+
version = "0.53.0"
|
|
388
419
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
389
|
-
checksum = "
|
|
420
|
+
checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
|
|
390
421
|
|
|
391
422
|
[[package]]
|
|
392
423
|
name = "windows_i686_msvc"
|
|
393
|
-
version = "0.
|
|
424
|
+
version = "0.53.0"
|
|
394
425
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
395
|
-
checksum = "
|
|
426
|
+
checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
|
|
396
427
|
|
|
397
428
|
[[package]]
|
|
398
429
|
name = "windows_x86_64_gnu"
|
|
399
|
-
version = "0.
|
|
430
|
+
version = "0.53.0"
|
|
400
431
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
401
|
-
checksum = "
|
|
432
|
+
checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
|
|
402
433
|
|
|
403
434
|
[[package]]
|
|
404
435
|
name = "windows_x86_64_gnullvm"
|
|
405
|
-
version = "0.
|
|
436
|
+
version = "0.53.0"
|
|
406
437
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
407
|
-
checksum = "
|
|
438
|
+
checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
|
|
408
439
|
|
|
409
440
|
[[package]]
|
|
410
441
|
name = "windows_x86_64_msvc"
|
|
411
|
-
version = "0.
|
|
442
|
+
version = "0.53.0"
|
|
412
443
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
413
|
-
checksum = "
|
|
444
|
+
checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
|
data/ext/safetensors/Cargo.toml
CHANGED
|
@@ -1,17 +1,17 @@
|
|
|
1
1
|
[package]
|
|
2
2
|
name = "safetensors"
|
|
3
|
-
version = "0.1
|
|
3
|
+
version = "0.2.1"
|
|
4
4
|
license = "Apache-2.0"
|
|
5
5
|
authors = ["Andrew Kane <andrew@ankane.org>"]
|
|
6
6
|
edition = "2021"
|
|
7
|
-
rust-version = "1.
|
|
7
|
+
rust-version = "1.80"
|
|
8
8
|
publish = false
|
|
9
9
|
|
|
10
10
|
[lib]
|
|
11
11
|
crate-type = ["cdylib"]
|
|
12
12
|
|
|
13
13
|
[dependencies]
|
|
14
|
-
magnus = "0.
|
|
14
|
+
magnus = "0.8"
|
|
15
15
|
memmap2 = "0.5"
|
|
16
|
-
safetensors = "=0.
|
|
16
|
+
safetensors = "=0.7.0"
|
|
17
17
|
serde_json = "1"
|
data/ext/safetensors/src/lib.rs
CHANGED
|
@@ -79,12 +79,16 @@ fn prepare(tensor_dict: &RHash) -> RbResult<HashMap<String, TensorView<'_>>> {
|
|
|
79
79
|
Ok(tensors)
|
|
80
80
|
}
|
|
81
81
|
|
|
82
|
-
fn serialize(
|
|
82
|
+
fn serialize(
|
|
83
|
+
ruby: &Ruby,
|
|
84
|
+
tensor_dict: RHash,
|
|
85
|
+
metadata: Option<HashMap<String, String>>,
|
|
86
|
+
) -> RbResult<RString> {
|
|
83
87
|
let tensors = prepare(&tensor_dict)?;
|
|
84
88
|
let metadata_map = metadata.map(HashMap::from_iter);
|
|
85
|
-
let out = safetensors::tensor::serialize(&tensors,
|
|
89
|
+
let out = safetensors::tensor::serialize(&tensors, metadata_map)
|
|
86
90
|
.map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
|
|
87
|
-
let rbbytes =
|
|
91
|
+
let rbbytes = ruby.str_from_slice(&out);
|
|
88
92
|
Ok(rbbytes)
|
|
89
93
|
}
|
|
90
94
|
|
|
@@ -94,25 +98,25 @@ fn serialize_file(
|
|
|
94
98
|
metadata: Option<HashMap<String, String>>,
|
|
95
99
|
) -> RbResult<()> {
|
|
96
100
|
let tensors = prepare(&tensor_dict)?;
|
|
97
|
-
safetensors::tensor::serialize_to_file(&tensors,
|
|
101
|
+
safetensors::tensor::serialize_to_file(&tensors, metadata, filename.as_path())
|
|
98
102
|
.map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
|
|
99
103
|
Ok(())
|
|
100
104
|
}
|
|
101
105
|
|
|
102
|
-
fn deserialize(bytes: RString) -> RbResult<RArray> {
|
|
106
|
+
fn deserialize(ruby: &Ruby, bytes: RString) -> RbResult<RArray> {
|
|
103
107
|
let safetensor = SafeTensors::deserialize(unsafe { bytes.as_slice() })
|
|
104
108
|
.map_err(|e| SafetensorError::new_err(format!("Error while deserializing: {e:?}")))?;
|
|
105
109
|
|
|
106
110
|
let tensors = safetensor.tensors();
|
|
107
|
-
let items =
|
|
111
|
+
let items = ruby.ary_new_capa(tensors.len());
|
|
108
112
|
|
|
109
113
|
for (tensor_name, tensor) in tensors {
|
|
110
|
-
let rbshape =
|
|
114
|
+
let rbshape = ruby.ary_from_vec(tensor.shape().to_vec());
|
|
111
115
|
let rbdtype = format!("{:?}", tensor.dtype());
|
|
112
116
|
|
|
113
|
-
let rbdata =
|
|
117
|
+
let rbdata = ruby.str_from_slice(tensor.data());
|
|
114
118
|
|
|
115
|
-
let map =
|
|
119
|
+
let map = ruby.hash_new();
|
|
116
120
|
map.aset("shape", rbshape)?;
|
|
117
121
|
map.aset("dtype", rbdtype)?;
|
|
118
122
|
map.aset("data", rbdata)?;
|
|
@@ -279,7 +283,7 @@ impl Open {
|
|
|
279
283
|
Ok(keys)
|
|
280
284
|
}
|
|
281
285
|
|
|
282
|
-
pub fn get_tensor(&self, name: &str) -> RbResult<Value> {
|
|
286
|
+
pub fn get_tensor(&self, ruby: &Ruby, name: &str) -> RbResult<Value> {
|
|
283
287
|
let info = self.metadata.info(name).ok_or_else(|| {
|
|
284
288
|
SafetensorError::new_err(format!("File does not contain tensor {name}",))
|
|
285
289
|
})?;
|
|
@@ -289,9 +293,10 @@ impl Open {
|
|
|
289
293
|
let data =
|
|
290
294
|
&mmap[info.data_offsets.0 + self.offset..info.data_offsets.1 + self.offset];
|
|
291
295
|
|
|
292
|
-
let array: Value =
|
|
296
|
+
let array: Value = ruby.str_from_slice(data).as_value();
|
|
293
297
|
|
|
294
298
|
create_tensor(
|
|
299
|
+
ruby,
|
|
295
300
|
&self.framework,
|
|
296
301
|
info.dtype,
|
|
297
302
|
&info.shape,
|
|
@@ -332,19 +337,19 @@ impl SafeOpen {
|
|
|
332
337
|
self.inner()?.keys()
|
|
333
338
|
}
|
|
334
339
|
|
|
335
|
-
pub fn get_tensor(&
|
|
336
|
-
|
|
340
|
+
pub fn get_tensor(ruby: &Ruby, rb_self: &Self, name: String) -> RbResult<Value> {
|
|
341
|
+
rb_self.inner()?.get_tensor(ruby, &name)
|
|
337
342
|
}
|
|
338
343
|
}
|
|
339
344
|
|
|
340
345
|
fn create_tensor(
|
|
346
|
+
ruby: &Ruby,
|
|
341
347
|
framework: &Framework,
|
|
342
348
|
dtype: Dtype,
|
|
343
349
|
shape: &[usize],
|
|
344
350
|
array: Value,
|
|
345
351
|
device: &Device,
|
|
346
352
|
) -> RbResult<Value> {
|
|
347
|
-
let ruby = Ruby::get().unwrap();
|
|
348
353
|
let (module, is_numo): (RModule, bool) = match framework {
|
|
349
354
|
Framework::Pytorch => (
|
|
350
355
|
ruby.class_object()
|
|
@@ -360,7 +365,7 @@ fn create_tensor(
|
|
|
360
365
|
),
|
|
361
366
|
};
|
|
362
367
|
|
|
363
|
-
let dtype = get_rbdtype(module, dtype, is_numo)?;
|
|
368
|
+
let dtype = get_rbdtype(ruby, module, dtype, is_numo)?;
|
|
364
369
|
let shape = shape.to_vec();
|
|
365
370
|
let tensor: Value = match framework {
|
|
366
371
|
Framework::Pytorch => {
|
|
@@ -378,19 +383,19 @@ fn create_tensor(
|
|
|
378
383
|
Ok(tensor)
|
|
379
384
|
}
|
|
380
385
|
|
|
381
|
-
fn get_rbdtype(_module: RModule, dtype: Dtype, is_numo: bool) -> RbResult<
|
|
382
|
-
let dtype:
|
|
386
|
+
fn get_rbdtype(ruby: &Ruby, _module: RModule, dtype: Dtype, is_numo: bool) -> RbResult<Symbol> {
|
|
387
|
+
let dtype: Symbol = if is_numo {
|
|
383
388
|
match dtype {
|
|
384
|
-
Dtype::F64 =>
|
|
385
|
-
Dtype::F32 =>
|
|
386
|
-
Dtype::U64 =>
|
|
387
|
-
Dtype::I64 =>
|
|
388
|
-
Dtype::U32 =>
|
|
389
|
-
Dtype::I32 =>
|
|
390
|
-
Dtype::U16 =>
|
|
391
|
-
Dtype::I16 =>
|
|
392
|
-
Dtype::U8 =>
|
|
393
|
-
Dtype::I8 =>
|
|
389
|
+
Dtype::F64 => ruby.to_symbol("DFloat"),
|
|
390
|
+
Dtype::F32 => ruby.to_symbol("SFloat"),
|
|
391
|
+
Dtype::U64 => ruby.to_symbol("UInt64"),
|
|
392
|
+
Dtype::I64 => ruby.to_symbol("Int64"),
|
|
393
|
+
Dtype::U32 => ruby.to_symbol("UInt32"),
|
|
394
|
+
Dtype::I32 => ruby.to_symbol("Int32"),
|
|
395
|
+
Dtype::U16 => ruby.to_symbol("UInt16"),
|
|
396
|
+
Dtype::I16 => ruby.to_symbol("Int16"),
|
|
397
|
+
Dtype::U8 => ruby.to_symbol("UInt8"),
|
|
398
|
+
Dtype::I8 => ruby.to_symbol("Int8"),
|
|
394
399
|
dtype => {
|
|
395
400
|
return Err(SafetensorError::new_err(format!(
|
|
396
401
|
"Dtype not understood: {dtype:?}"
|
|
@@ -399,21 +404,21 @@ fn get_rbdtype(_module: RModule, dtype: Dtype, is_numo: bool) -> RbResult<Value>
|
|
|
399
404
|
}
|
|
400
405
|
} else {
|
|
401
406
|
match dtype {
|
|
402
|
-
Dtype::F64 =>
|
|
403
|
-
Dtype::F32 =>
|
|
404
|
-
Dtype::BF16 =>
|
|
405
|
-
Dtype::F16 =>
|
|
406
|
-
Dtype::U64 =>
|
|
407
|
-
Dtype::I64 =>
|
|
408
|
-
Dtype::U32 =>
|
|
409
|
-
Dtype::I32 =>
|
|
410
|
-
Dtype::U16 =>
|
|
411
|
-
Dtype::I16 =>
|
|
412
|
-
Dtype::U8 =>
|
|
413
|
-
Dtype::I8 =>
|
|
414
|
-
Dtype::BOOL =>
|
|
415
|
-
Dtype::F8_E4M3 =>
|
|
416
|
-
Dtype::F8_E5M2 =>
|
|
407
|
+
Dtype::F64 => ruby.to_symbol("float64"),
|
|
408
|
+
Dtype::F32 => ruby.to_symbol("float32"),
|
|
409
|
+
Dtype::BF16 => ruby.to_symbol("bfloat16"),
|
|
410
|
+
Dtype::F16 => ruby.to_symbol("float16"),
|
|
411
|
+
Dtype::U64 => ruby.to_symbol("uint64"),
|
|
412
|
+
Dtype::I64 => ruby.to_symbol("int64"),
|
|
413
|
+
Dtype::U32 => ruby.to_symbol("uint32"),
|
|
414
|
+
Dtype::I32 => ruby.to_symbol("int32"),
|
|
415
|
+
Dtype::U16 => ruby.to_symbol("uint16"),
|
|
416
|
+
Dtype::I16 => ruby.to_symbol("int16"),
|
|
417
|
+
Dtype::U8 => ruby.to_symbol("uint8"),
|
|
418
|
+
Dtype::I8 => ruby.to_symbol("int8"),
|
|
419
|
+
Dtype::BOOL => ruby.to_symbol("bool"),
|
|
420
|
+
Dtype::F8_E4M3 => ruby.to_symbol("float8_e4m3fn"),
|
|
421
|
+
Dtype::F8_E5M2 => ruby.to_symbol("float8_e5m2"),
|
|
417
422
|
dtype => {
|
|
418
423
|
return Err(SafetensorError::new_err(format!(
|
|
419
424
|
"Dtype not understood: {dtype:?}"
|
data/lib/safetensors/version.rb
CHANGED
metadata
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: safetensors
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.1
|
|
4
|
+
version: 0.2.1
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- Andrew Kane
|
|
@@ -52,14 +52,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
|
52
52
|
requirements:
|
|
53
53
|
- - ">="
|
|
54
54
|
- !ruby/object:Gem::Version
|
|
55
|
-
version: '3.
|
|
55
|
+
version: '3.2'
|
|
56
56
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
|
57
57
|
requirements:
|
|
58
58
|
- - ">="
|
|
59
59
|
- !ruby/object:Gem::Version
|
|
60
60
|
version: '0'
|
|
61
61
|
requirements: []
|
|
62
|
-
rubygems_version: 3.6.
|
|
62
|
+
rubygems_version: 3.6.9
|
|
63
63
|
specification_version: 4
|
|
64
64
|
summary: Simple, safe way to store and distribute tensors
|
|
65
65
|
test_files: []
|