safetensors 0.2.2 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/Cargo.lock +110 -5
- data/README.md +1 -1
- data/ext/safetensors/Cargo.toml +4 -3
- data/ext/safetensors/src/lib.rs +178 -102
- data/ext/safetensors/src/ruby.rs +51 -0
- data/lib/safetensors/version.rb +1 -1
- metadata +4 -3
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: cc628696f449fb899f85e95c1c824c00b832aad7e3f92253bdc48c9c1e30c3af
|
|
4
|
+
data.tar.gz: 433d331c6cfaadcc05a844d7fb8a1e31681b626bcedadc6fdc653960a0cfc578
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: d05099bbce87c9cfe6634ed567378a8e135b3e1c72d6aa541d6ba2c8452ff3a65e249bd460d8490d8b525c01ae6269aafd75d07d07c3a40fedea9418eb7f79a3
|
|
7
|
+
data.tar.gz: 688f2f91a2f340a82dc8e11d73428e2bb6bfb66e2f15f67e24b7e31117582617b82d106f973cb152395cd4de07172bb48a304d0febd196daefa234cd45ac4980
|
data/CHANGELOG.md
CHANGED
data/Cargo.lock
CHANGED
|
@@ -81,12 +81,40 @@ version = "1.0.2"
|
|
|
81
81
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
82
82
|
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
|
83
83
|
|
|
84
|
+
[[package]]
|
|
85
|
+
name = "errno"
|
|
86
|
+
version = "0.3.14"
|
|
87
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
88
|
+
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
|
89
|
+
dependencies = [
|
|
90
|
+
"libc",
|
|
91
|
+
"windows-sys",
|
|
92
|
+
]
|
|
93
|
+
|
|
94
|
+
[[package]]
|
|
95
|
+
name = "fastrand"
|
|
96
|
+
version = "2.4.1"
|
|
97
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
98
|
+
checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6"
|
|
99
|
+
|
|
84
100
|
[[package]]
|
|
85
101
|
name = "foldhash"
|
|
86
102
|
version = "0.2.0"
|
|
87
103
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
88
104
|
checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
|
|
89
105
|
|
|
106
|
+
[[package]]
|
|
107
|
+
name = "getrandom"
|
|
108
|
+
version = "0.3.4"
|
|
109
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
110
|
+
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
|
|
111
|
+
dependencies = [
|
|
112
|
+
"cfg-if",
|
|
113
|
+
"libc",
|
|
114
|
+
"r-efi",
|
|
115
|
+
"wasip2",
|
|
116
|
+
]
|
|
117
|
+
|
|
90
118
|
[[package]]
|
|
91
119
|
name = "glob"
|
|
92
120
|
version = "0.3.2"
|
|
@@ -148,6 +176,12 @@ dependencies = [
|
|
|
148
176
|
"windows-targets",
|
|
149
177
|
]
|
|
150
178
|
|
|
179
|
+
[[package]]
|
|
180
|
+
name = "linux-raw-sys"
|
|
181
|
+
version = "0.11.0"
|
|
182
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
183
|
+
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
|
|
184
|
+
|
|
151
185
|
[[package]]
|
|
152
186
|
name = "magnus"
|
|
153
187
|
version = "0.8.1"
|
|
@@ -179,9 +213,9 @@ checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
|
|
|
179
213
|
|
|
180
214
|
[[package]]
|
|
181
215
|
name = "memmap2"
|
|
182
|
-
version = "0.
|
|
216
|
+
version = "0.9.10"
|
|
183
217
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
184
|
-
checksum = "
|
|
218
|
+
checksum = "714098028fe011992e1c3962653c96b2d578c4b4bce9036e15ff220319b1e0e3"
|
|
185
219
|
dependencies = [
|
|
186
220
|
"libc",
|
|
187
221
|
]
|
|
@@ -202,6 +236,12 @@ dependencies = [
|
|
|
202
236
|
"minimal-lexical",
|
|
203
237
|
]
|
|
204
238
|
|
|
239
|
+
[[package]]
|
|
240
|
+
name = "once_cell"
|
|
241
|
+
version = "1.21.4"
|
|
242
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
243
|
+
checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
|
|
244
|
+
|
|
205
245
|
[[package]]
|
|
206
246
|
name = "proc-macro2"
|
|
207
247
|
version = "1.0.95"
|
|
@@ -220,6 +260,12 @@ dependencies = [
|
|
|
220
260
|
"proc-macro2",
|
|
221
261
|
]
|
|
222
262
|
|
|
263
|
+
[[package]]
|
|
264
|
+
name = "r-efi"
|
|
265
|
+
version = "5.3.0"
|
|
266
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
267
|
+
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
|
|
268
|
+
|
|
223
269
|
[[package]]
|
|
224
270
|
name = "rb-sys"
|
|
225
271
|
version = "0.9.124"
|
|
@@ -285,6 +331,19 @@ version = "1.1.0"
|
|
|
285
331
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
286
332
|
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
|
287
333
|
|
|
334
|
+
[[package]]
|
|
335
|
+
name = "rustix"
|
|
336
|
+
version = "1.1.2"
|
|
337
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
338
|
+
checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e"
|
|
339
|
+
dependencies = [
|
|
340
|
+
"bitflags",
|
|
341
|
+
"errno",
|
|
342
|
+
"libc",
|
|
343
|
+
"linux-raw-sys",
|
|
344
|
+
"windows-sys",
|
|
345
|
+
]
|
|
346
|
+
|
|
288
347
|
[[package]]
|
|
289
348
|
name = "ryu"
|
|
290
349
|
version = "1.0.20"
|
|
@@ -293,21 +352,24 @@ checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
|
|
|
293
352
|
|
|
294
353
|
[[package]]
|
|
295
354
|
name = "safetensors"
|
|
296
|
-
version = "0.
|
|
355
|
+
version = "0.8.0"
|
|
297
356
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
298
|
-
checksum = "
|
|
357
|
+
checksum = "79b079b829cb27a1c3c374341345ed2e8b2c0c839034522cee576c140bd7f846"
|
|
299
358
|
dependencies = [
|
|
300
359
|
"hashbrown",
|
|
360
|
+
"libc",
|
|
301
361
|
"serde",
|
|
302
362
|
"serde_json",
|
|
363
|
+
"tempfile",
|
|
303
364
|
]
|
|
304
365
|
|
|
305
366
|
[[package]]
|
|
306
367
|
name = "safetensors-ruby"
|
|
307
|
-
version = "0.2.2"
|
|
368
|
+
version = "0.3.0"
|
|
308
369
|
dependencies = [
|
|
309
370
|
"magnus",
|
|
310
371
|
"memmap2",
|
|
372
|
+
"rb-sys",
|
|
311
373
|
"safetensors",
|
|
312
374
|
"serde_json",
|
|
313
375
|
]
|
|
@@ -373,12 +435,49 @@ dependencies = [
|
|
|
373
435
|
"unicode-ident",
|
|
374
436
|
]
|
|
375
437
|
|
|
438
|
+
[[package]]
|
|
439
|
+
name = "tempfile"
|
|
440
|
+
version = "3.23.0"
|
|
441
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
442
|
+
checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16"
|
|
443
|
+
dependencies = [
|
|
444
|
+
"fastrand",
|
|
445
|
+
"getrandom",
|
|
446
|
+
"once_cell",
|
|
447
|
+
"rustix",
|
|
448
|
+
"windows-sys",
|
|
449
|
+
]
|
|
450
|
+
|
|
376
451
|
[[package]]
|
|
377
452
|
name = "unicode-ident"
|
|
378
453
|
version = "1.0.18"
|
|
379
454
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
380
455
|
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
|
|
381
456
|
|
|
457
|
+
[[package]]
|
|
458
|
+
name = "wasip2"
|
|
459
|
+
version = "1.0.3+wasi-0.2.9"
|
|
460
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
461
|
+
checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6"
|
|
462
|
+
dependencies = [
|
|
463
|
+
"wit-bindgen",
|
|
464
|
+
]
|
|
465
|
+
|
|
466
|
+
[[package]]
|
|
467
|
+
name = "windows-link"
|
|
468
|
+
version = "0.2.1"
|
|
469
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
470
|
+
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
|
|
471
|
+
|
|
472
|
+
[[package]]
|
|
473
|
+
name = "windows-sys"
|
|
474
|
+
version = "0.61.2"
|
|
475
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
476
|
+
checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc"
|
|
477
|
+
dependencies = [
|
|
478
|
+
"windows-link",
|
|
479
|
+
]
|
|
480
|
+
|
|
382
481
|
[[package]]
|
|
383
482
|
name = "windows-targets"
|
|
384
483
|
version = "0.53.2"
|
|
@@ -442,3 +541,9 @@ name = "windows_x86_64_msvc"
|
|
|
442
541
|
version = "0.53.0"
|
|
443
542
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
444
543
|
checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
|
|
544
|
+
|
|
545
|
+
[[package]]
|
|
546
|
+
name = "wit-bindgen"
|
|
547
|
+
version = "0.57.1"
|
|
548
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
549
|
+
checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e"
|
data/README.md
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
:slightly_smiling_face: Simple, [safe way](https://github.com/huggingface/safetensors) to store and distribute tensors
|
|
4
4
|
|
|
5
|
-
Supports [Torch.rb](https://github.com/ankane/torch.rb) and [
|
|
5
|
+
Supports [Torch.rb](https://github.com/ankane/torch.rb) and [numo-narray-alt](https://github.com/yoshoku/numo-narray-alt)
|
|
6
6
|
|
|
7
7
|
[![Build Status](https://github.com/ankane/safetensors-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/safetensors-ruby/actions)
|
|
8
8
|
|
data/ext/safetensors/Cargo.toml
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[package]
|
|
2
2
|
name = "safetensors-ruby"
|
|
3
|
-
version = "0.2.2"
|
|
3
|
+
version = "0.3.0"
|
|
4
4
|
license = "Apache-2.0"
|
|
5
5
|
authors = ["Andrew Kane <andrew@ankane.org>"]
|
|
6
6
|
edition = "2021"
|
|
@@ -13,6 +13,7 @@ crate-type = ["cdylib"]
|
|
|
13
13
|
|
|
14
14
|
[dependencies]
|
|
15
15
|
magnus = "0.8"
|
|
16
|
-
memmap2 = "0.
|
|
17
|
-
|
|
16
|
+
memmap2 = "0.9"
|
|
17
|
+
rb-sys = "0.9"
|
|
18
|
+
safetensors = "=0.8.0"
|
|
18
19
|
serde_json = "1"
|
data/ext/safetensors/src/lib.rs
CHANGED
|
@@ -1,76 +1,125 @@
|
|
|
1
|
+
mod ruby;
|
|
2
|
+
|
|
3
|
+
use core::slice;
|
|
1
4
|
use magnus::{
|
|
2
5
|
function, kwargs, method, prelude::*, r_hash::ForEach, Error, IntoValue, RArray, RHash,
|
|
3
6
|
RModule, RString, Ruby, Symbol, TryConvert, Value,
|
|
4
7
|
};
|
|
5
8
|
use memmap2::{Mmap, MmapOptions};
|
|
6
|
-
use safetensors::tensor::{Dtype, Metadata, SafeTensors, TensorView};
|
|
9
|
+
use safetensors::tensor::{Dtype, Metadata, SafeTensors};
|
|
10
|
+
use safetensors::View;
|
|
11
|
+
use std::borrow::Cow;
|
|
7
12
|
use std::collections::HashMap;
|
|
13
|
+
use std::fmt;
|
|
8
14
|
use std::fs::File;
|
|
9
15
|
use std::path::PathBuf;
|
|
10
16
|
use std::sync::Arc;
|
|
11
17
|
|
|
18
|
+
use crate::ruby::GvlExt;
|
|
19
|
+
|
|
12
20
|
type RbResult<T> = Result<T, Error>;
|
|
13
21
|
|
|
14
|
-
|
|
22
|
+
#[derive(Clone, Debug)]
|
|
23
|
+
struct TensorSpec {
|
|
24
|
+
dtype: Dtype,
|
|
25
|
+
shape: Vec<usize>,
|
|
26
|
+
data_ptr: u64,
|
|
27
|
+
data_len: usize,
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
impl TensorSpec {
|
|
31
|
+
fn new(dtype: &str, shape: Vec<usize>, data_ptr: u64, data_len: usize) -> RbResult<Self> {
|
|
32
|
+
let dtype = parse_dtype_str(dtype)?;
|
|
33
|
+
let mut shape = shape;
|
|
34
|
+
// F4 packs two elements per byte; the safetensors header records the
|
|
35
|
+
// logical element count, so double the last dim.
|
|
36
|
+
if dtype == Dtype::F4 && !shape.is_empty() {
|
|
37
|
+
let n = shape.len();
|
|
38
|
+
shape[n - 1] = shape[n - 1].checked_mul(2).ok_or_else(|| {
|
|
39
|
+
SafetensorError::new_err(format!(
|
|
40
|
+
"F4 last-dim {} doubled to logical shape overflows usize",
|
|
41
|
+
shape[n - 1]
|
|
42
|
+
))
|
|
43
|
+
})?;
|
|
44
|
+
}
|
|
45
|
+
Ok(Self {
|
|
46
|
+
dtype,
|
|
47
|
+
shape,
|
|
48
|
+
data_ptr,
|
|
49
|
+
data_len,
|
|
50
|
+
})
|
|
51
|
+
}
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
impl View for &TensorSpec {
|
|
55
|
+
fn dtype(&self) -> Dtype {
|
|
56
|
+
self.dtype
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
fn shape(&self) -> &[usize] {
|
|
60
|
+
&self.shape
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
fn data(&self) -> Cow<'_, [u8]> {
|
|
64
|
+
let p = self.data_ptr as *const u8;
|
|
65
|
+
// SAFETY: validated by the caller; see the struct-level safety note.
|
|
66
|
+
unsafe {
|
|
67
|
+
let slice = slice::from_raw_parts(p, self.data_len);
|
|
68
|
+
Cow::Borrowed(slice)
|
|
69
|
+
}
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
fn data_len(&self) -> usize {
|
|
73
|
+
self.data_len
|
|
74
|
+
}
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
fn parse_dtype_str(dtype: &str) -> RbResult<Dtype> {
|
|
78
|
+
Ok(match dtype {
|
|
79
|
+
"bool" => Dtype::BOOL,
|
|
80
|
+
"int8" => Dtype::I8,
|
|
81
|
+
"uint8" => Dtype::U8,
|
|
82
|
+
"int16" => Dtype::I16,
|
|
83
|
+
"uint16" => Dtype::U16,
|
|
84
|
+
"int32" => Dtype::I32,
|
|
85
|
+
"uint32" => Dtype::U32,
|
|
86
|
+
"int64" => Dtype::I64,
|
|
87
|
+
"uint64" => Dtype::U64,
|
|
88
|
+
"float16" => Dtype::F16,
|
|
89
|
+
"float32" => Dtype::F32,
|
|
90
|
+
"float64" => Dtype::F64,
|
|
91
|
+
"bfloat16" => Dtype::BF16,
|
|
92
|
+
"float8_e4m3fn" => Dtype::F8_E4M3,
|
|
93
|
+
"float8_e4m3fnuz" => Dtype::F8_E4M3FNUZ,
|
|
94
|
+
"float8_e5m2" => Dtype::F8_E5M2,
|
|
95
|
+
"float8_e5m2fnuz" => Dtype::F8_E5M2FNUZ,
|
|
96
|
+
"float8_e8m0fnu" => Dtype::F8_E8M0,
|
|
97
|
+
"float4_e2m1fn_x2" => Dtype::F4,
|
|
98
|
+
"complex64" => Dtype::C64,
|
|
99
|
+
other => {
|
|
100
|
+
return Err(SafetensorError::new_err(format!(
|
|
101
|
+
"Unknown dtype {other:?}. Supported dtypes: bool, int8, uint8, int16, uint16, \
|
|
102
|
+
int32, uint32, int64, uint64, float16, float32, float64, bfloat16, \
|
|
103
|
+
float8_e4m3fn, float8_e4m3fnuz, float8_e5m2, float8_e5m2fnuz, float8_e8m0fnu, \
|
|
104
|
+
float4_e2m1fn_x2, complex64",
|
|
105
|
+
)));
|
|
106
|
+
}
|
|
107
|
+
})
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
fn prepare(tensor_dict: &RHash) -> RbResult<HashMap<String, TensorSpec>> {
|
|
15
111
|
let mut tensors = HashMap::with_capacity(tensor_dict.len());
|
|
16
112
|
tensor_dict.foreach(|tensor_name: String, tensor_desc: RHash| {
|
|
17
|
-
let
|
|
18
|
-
let
|
|
19
|
-
let
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
"bool" => Some(Dtype::BOOL),
|
|
28
|
-
"int8" => Some(Dtype::I8),
|
|
29
|
-
"uint8" => Some(Dtype::U8),
|
|
30
|
-
"int16" => Some(Dtype::I16),
|
|
31
|
-
"uint16" => Some(Dtype::U16),
|
|
32
|
-
"int32" => Some(Dtype::I32),
|
|
33
|
-
"uint32" => Some(Dtype::U32),
|
|
34
|
-
"int64" => Some(Dtype::I64),
|
|
35
|
-
"uint64" => Some(Dtype::U64),
|
|
36
|
-
"float16" => Some(Dtype::F16),
|
|
37
|
-
"float32" => Some(Dtype::F32),
|
|
38
|
-
"float64" => Some(Dtype::F64),
|
|
39
|
-
"bfloat16" => Some(Dtype::BF16),
|
|
40
|
-
"float8_e4m3fn" => Some(Dtype::F8_E4M3),
|
|
41
|
-
"float8_e5m2" => Some(Dtype::F8_E5M2),
|
|
42
|
-
dtype_str => {
|
|
43
|
-
return Err(SafetensorError::new_err(format!(
|
|
44
|
-
"dtype {dtype_str} is not covered",
|
|
45
|
-
)));
|
|
46
|
-
}
|
|
47
|
-
}
|
|
48
|
-
}
|
|
49
|
-
"data" => {
|
|
50
|
-
let rs = RString::try_convert(value)?;
|
|
51
|
-
// SAFETY: No context switching between threads in native extensions
|
|
52
|
-
// so the string will not be modified (or garbage collected)
|
|
53
|
-
// while the reference is held. Also, the string is a private copy.
|
|
54
|
-
let slice = unsafe { rs.as_slice() };
|
|
55
|
-
data = Some((slice.as_ptr(), slice.len()));
|
|
56
|
-
}
|
|
57
|
-
_ => println!("Ignored unknown kwarg option {key}"),
|
|
58
|
-
};
|
|
59
|
-
|
|
60
|
-
Ok(ForEach::Continue)
|
|
61
|
-
})?;
|
|
62
|
-
let shape = shape.ok_or_else(|| {
|
|
63
|
-
SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}"))
|
|
64
|
-
})?;
|
|
65
|
-
let dtype = dtype.ok_or_else(|| {
|
|
66
|
-
SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}"))
|
|
67
|
-
})?;
|
|
68
|
-
let data = data.ok_or_else(|| {
|
|
69
|
-
SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}"))
|
|
70
|
-
})?;
|
|
71
|
-
// SAFETY: See comment above.
|
|
72
|
-
let data = unsafe { std::slice::from_raw_parts(data.0, data.1) };
|
|
73
|
-
let tensor = TensorView::new(dtype, shape, data)
|
|
113
|
+
let dtype: String = tensor_desc.aref("dtype")?;
|
|
114
|
+
let shape: Vec<usize> = tensor_desc.aref("shape")?;
|
|
115
|
+
let data: RString = tensor_desc.aref("data")?;
|
|
116
|
+
|
|
117
|
+
// SAFETY: No context switching between threads in native extensions
|
|
118
|
+
// so the string will not be modified (or garbage collected)
|
|
119
|
+
// while the reference is held. Also, the string is a private copy.
|
|
120
|
+
let slice = unsafe { data.as_slice() };
|
|
121
|
+
|
|
122
|
+
let tensor = TensorSpec::new(dtype.as_ref(), shape, slice.as_ptr() as u64, slice.len())
|
|
74
123
|
.map_err(|e| SafetensorError::new_err(format!("Error preparing tensor view: {e:?}")))?;
|
|
75
124
|
tensors.insert(tensor_name, tensor);
|
|
76
125
|
|
|
@@ -86,19 +135,21 @@ fn serialize(
|
|
|
86
135
|
) -> RbResult<RString> {
|
|
87
136
|
let tensors = prepare(&tensor_dict)?;
|
|
88
137
|
let metadata_map = metadata.map(HashMap::from_iter);
|
|
89
|
-
let out =
|
|
138
|
+
let out = ruby
|
|
139
|
+
.detach(|| safetensors::tensor::serialize(&tensors, metadata_map))
|
|
90
140
|
.map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
|
|
91
141
|
let rbbytes = ruby.str_from_slice(&out);
|
|
92
142
|
Ok(rbbytes)
|
|
93
143
|
}
|
|
94
144
|
|
|
95
145
|
fn serialize_file(
|
|
146
|
+
ruby: &Ruby,
|
|
96
147
|
tensor_dict: RHash,
|
|
97
148
|
filename: PathBuf,
|
|
98
149
|
metadata: Option<HashMap<String, String>>,
|
|
99
150
|
) -> RbResult<()> {
|
|
100
151
|
let tensors = prepare(&tensor_dict)?;
|
|
101
|
-
safetensors::tensor::serialize_to_file(&tensors, metadata, filename.as_path())
|
|
152
|
+
ruby.detach(|| safetensors::tensor::serialize_to_file(&tensors, metadata, filename.as_path()))
|
|
102
153
|
.map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
|
|
103
154
|
Ok(())
|
|
104
155
|
}
|
|
@@ -132,6 +183,15 @@ enum Framework {
|
|
|
132
183
|
Numo,
|
|
133
184
|
}
|
|
134
185
|
|
|
186
|
+
impl fmt::Display for Framework {
|
|
187
|
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
188
|
+
f.write_str(match *self {
|
|
189
|
+
Framework::Pytorch => "torch",
|
|
190
|
+
Framework::Numo => "numo",
|
|
191
|
+
})
|
|
192
|
+
}
|
|
193
|
+
}
|
|
194
|
+
|
|
135
195
|
impl TryConvert for Framework {
|
|
136
196
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
|
137
197
|
let name: String = String::try_convert(ob)?;
|
|
@@ -157,56 +217,64 @@ enum Device {
|
|
|
157
217
|
Mps,
|
|
158
218
|
Npu(usize),
|
|
159
219
|
Xpu(usize),
|
|
220
|
+
Xla(usize),
|
|
221
|
+
Mlu(usize),
|
|
222
|
+
Hpu(usize),
|
|
223
|
+
Anonymous(usize),
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
impl fmt::Display for Device {
|
|
227
|
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
228
|
+
match *self {
|
|
229
|
+
Device::Cpu => write!(f, "cpu"),
|
|
230
|
+
Device::Mps => write!(f, "mps"),
|
|
231
|
+
Device::Cuda(index) => write!(f, "cuda:{index}"),
|
|
232
|
+
Device::Npu(index) => write!(f, "npu:{index}"),
|
|
233
|
+
Device::Xpu(index) => write!(f, "xpu:{index}"),
|
|
234
|
+
Device::Xla(index) => write!(f, "xla:{index}"),
|
|
235
|
+
Device::Mlu(index) => write!(f, "mlu:{index}"),
|
|
236
|
+
Device::Hpu(index) => write!(f, "hpu:{index}"),
|
|
237
|
+
Device::Anonymous(index) => write!(f, "{index}"),
|
|
238
|
+
}
|
|
239
|
+
}
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
/// Parsing the device index.
|
|
243
|
+
fn parse_device(name: &str) -> RbResult<usize> {
|
|
244
|
+
let tokens: Vec<_> = name.split(':').collect();
|
|
245
|
+
if tokens.len() == 2 {
|
|
246
|
+
Ok(tokens[1].parse().map_err(SafetensorError::parse)?)
|
|
247
|
+
} else {
|
|
248
|
+
Err(SafetensorError::new_err(format!(
|
|
249
|
+
"device {name} is invalid"
|
|
250
|
+
)))
|
|
251
|
+
}
|
|
160
252
|
}
|
|
161
253
|
|
|
162
254
|
impl TryConvert for Device {
|
|
163
255
|
fn try_convert(ob: Value) -> RbResult<Self> {
|
|
164
256
|
if let Ok(name) = String::try_convert(ob) {
|
|
165
|
-
match
|
|
257
|
+
match name.as_str() {
|
|
166
258
|
"cpu" => Ok(Device::Cpu),
|
|
167
259
|
"cuda" => Ok(Device::Cuda(0)),
|
|
168
260
|
"mps" => Ok(Device::Mps),
|
|
169
261
|
"npu" => Ok(Device::Npu(0)),
|
|
170
262
|
"xpu" => Ok(Device::Xpu(0)),
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
}
|
|
181
|
-
}
|
|
182
|
-
name if name.starts_with("npu:") => {
|
|
183
|
-
let tokens: Vec<_> = name.split(':').collect();
|
|
184
|
-
if tokens.len() == 2 {
|
|
185
|
-
let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
|
|
186
|
-
Ok(Device::Npu(device))
|
|
187
|
-
} else {
|
|
188
|
-
Err(SafetensorError::new_err(format!(
|
|
189
|
-
"device {name} is invalid"
|
|
190
|
-
)))
|
|
191
|
-
}
|
|
192
|
-
}
|
|
193
|
-
name if name.starts_with("xpu:") => {
|
|
194
|
-
let tokens: Vec<_> = name.split(':').collect();
|
|
195
|
-
if tokens.len() == 2 {
|
|
196
|
-
let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
|
|
197
|
-
Ok(Device::Xpu(device))
|
|
198
|
-
} else {
|
|
199
|
-
Err(SafetensorError::new_err(format!(
|
|
200
|
-
"device {name} is invalid"
|
|
201
|
-
)))
|
|
202
|
-
}
|
|
203
|
-
}
|
|
263
|
+
"xla" => Ok(Device::Xla(0)),
|
|
264
|
+
"mlu" => Ok(Device::Mlu(0)),
|
|
265
|
+
"hpu" => Ok(Device::Hpu(0)),
|
|
266
|
+
name if name.starts_with("cuda:") => parse_device(name).map(Device::Cuda),
|
|
267
|
+
name if name.starts_with("npu:") => parse_device(name).map(Device::Npu),
|
|
268
|
+
name if name.starts_with("xpu:") => parse_device(name).map(Device::Xpu),
|
|
269
|
+
name if name.starts_with("xla:") => parse_device(name).map(Device::Xla),
|
|
270
|
+
name if name.starts_with("mlu:") => parse_device(name).map(Device::Mlu),
|
|
271
|
+
name if name.starts_with("hpu:") => parse_device(name).map(Device::Hpu),
|
|
204
272
|
name => Err(SafetensorError::new_err(format!(
|
|
205
273
|
"device {name} is invalid"
|
|
206
274
|
))),
|
|
207
275
|
}
|
|
208
276
|
} else if let Ok(number) = usize::try_convert(ob) {
|
|
209
|
-
Ok(Device::
|
|
277
|
+
Ok(Device::Anonymous(number))
|
|
210
278
|
} else {
|
|
211
279
|
Err(SafetensorError::new_err(format!("device {ob} is invalid")))
|
|
212
280
|
}
|
|
@@ -221,6 +289,10 @@ impl IntoValue for Device {
|
|
|
221
289
|
Device::Mps => "mps".into_value_with(ruby),
|
|
222
290
|
Device::Npu(n) => format!("npu:{n}").into_value_with(ruby),
|
|
223
291
|
Device::Xpu(n) => format!("xpu:{n}").into_value_with(ruby),
|
|
292
|
+
Device::Xla(n) => format!("xla:{n}").into_value_with(ruby),
|
|
293
|
+
Device::Mlu(n) => format!("mlu:{n}").into_value_with(ruby),
|
|
294
|
+
Device::Hpu(n) => format!("hpu:{n}").into_value_with(ruby),
|
|
295
|
+
Device::Anonymous(n) => n.into_value_with(ruby),
|
|
224
296
|
}
|
|
225
297
|
}
|
|
226
298
|
}
|
|
@@ -240,13 +312,13 @@ struct Open {
|
|
|
240
312
|
impl Open {
|
|
241
313
|
fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> RbResult<Self> {
|
|
242
314
|
let file = File::open(&filename).map_err(|_| {
|
|
243
|
-
SafetensorError::new_err(format!("No such file or directory: {
|
|
315
|
+
SafetensorError::new_err(format!("No such file or directory: {}", filename.display()))
|
|
244
316
|
})?;
|
|
245
317
|
let device = device.unwrap_or(Device::Cpu);
|
|
246
318
|
|
|
247
319
|
if device != Device::Cpu && framework != Framework::Pytorch {
|
|
248
320
|
return Err(SafetensorError::new_err(format!(
|
|
249
|
-
"Device {device
|
|
321
|
+
"Device {device} is not support for framework {framework}",
|
|
250
322
|
)));
|
|
251
323
|
}
|
|
252
324
|
|
|
@@ -255,7 +327,7 @@ impl Open {
|
|
|
255
327
|
let buffer = unsafe { MmapOptions::new().map(&file).map_err(SafetensorError::io)? };
|
|
256
328
|
|
|
257
329
|
let (n, metadata) = SafeTensors::read_metadata(&buffer).map_err(|e| {
|
|
258
|
-
SafetensorError::new_err(format!("Error while deserializing header: {e
|
|
330
|
+
SafetensorError::new_err(format!("Error while deserializing header: {e}"))
|
|
259
331
|
})?;
|
|
260
332
|
|
|
261
333
|
let offset = n + 8;
|
|
@@ -285,7 +357,7 @@ impl Open {
|
|
|
285
357
|
|
|
286
358
|
pub fn get_tensor(&self, ruby: &Ruby, name: &str) -> RbResult<Value> {
|
|
287
359
|
let info = self.metadata.info(name).ok_or_else(|| {
|
|
288
|
-
SafetensorError::new_err(format!("File does not contain tensor {name}"
|
|
360
|
+
SafetensorError::new_err(format!("File does not contain tensor {name}"))
|
|
289
361
|
})?;
|
|
290
362
|
|
|
291
363
|
match &self.storage.as_ref() {
|
|
@@ -357,7 +429,7 @@ fn create_tensor(
|
|
|
357
429
|
.map_err(|_| SafetensorError::new_err("Torch not loaded".into()))?,
|
|
358
430
|
false,
|
|
359
431
|
),
|
|
360
|
-
|
|
432
|
+
Framework::Numo => (
|
|
361
433
|
ruby.class_object()
|
|
362
434
|
.const_get("Numo")
|
|
363
435
|
.map_err(|_| SafetensorError::new_err("Numo not loaded".into()))?,
|
|
@@ -375,7 +447,7 @@ fn create_tensor(
|
|
|
375
447
|
)?;
|
|
376
448
|
module.funcall("_from_blob_ref", (array, shape, options))?
|
|
377
449
|
}
|
|
378
|
-
|
|
450
|
+
Framework::Numo => {
|
|
379
451
|
let class: Value = module.funcall("const_get", (dtype,))?;
|
|
380
452
|
class.funcall("from_binary", (array, shape))?
|
|
381
453
|
}
|
|
@@ -419,6 +491,10 @@ fn get_rbdtype(ruby: &Ruby, _module: RModule, dtype: Dtype, is_numo: bool) -> Rb
|
|
|
419
491
|
Dtype::BOOL => ruby.to_symbol("bool"),
|
|
420
492
|
Dtype::F8_E4M3 => ruby.to_symbol("float8_e4m3fn"),
|
|
421
493
|
Dtype::F8_E5M2 => ruby.to_symbol("float8_e5m2"),
|
|
494
|
+
Dtype::F8_E5M2FNUZ => ruby.to_symbol("float8_e5m2fnuz"),
|
|
495
|
+
Dtype::F8_E8M0 => ruby.to_symbol("float8_e8m0fnu"),
|
|
496
|
+
Dtype::F4 => ruby.to_symbol("float4_e2m1fn_x2"),
|
|
497
|
+
Dtype::C64 => ruby.to_symbol("complex64"),
|
|
422
498
|
dtype => {
|
|
423
499
|
return Err(SafetensorError::new_err(format!(
|
|
424
500
|
"Dtype not understood: {dtype:?}"
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
use std::ffi::c_void;
|
|
2
|
+
use std::ptr::null_mut;
|
|
3
|
+
|
|
4
|
+
use magnus::Ruby;
|
|
5
|
+
use rb_sys::rb_thread_call_without_gvl;
|
|
6
|
+
|
|
7
|
+
pub trait GvlExt {
|
|
8
|
+
fn detach<T, F>(&self, func: F) -> T
|
|
9
|
+
where
|
|
10
|
+
F: Send + FnOnce() -> T,
|
|
11
|
+
T: Send;
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
impl GvlExt for Ruby {
|
|
15
|
+
fn detach<T, F>(&self, func: F) -> T
|
|
16
|
+
where
|
|
17
|
+
F: Send + FnOnce() -> T,
|
|
18
|
+
T: Send,
|
|
19
|
+
{
|
|
20
|
+
let mut data = CallbackData {
|
|
21
|
+
func: Some(func),
|
|
22
|
+
result: None,
|
|
23
|
+
};
|
|
24
|
+
|
|
25
|
+
unsafe {
|
|
26
|
+
rb_thread_call_without_gvl(
|
|
27
|
+
Some(call_without_gvl::<F, T>),
|
|
28
|
+
&mut data as *mut _ as *mut c_void,
|
|
29
|
+
None,
|
|
30
|
+
null_mut(),
|
|
31
|
+
);
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
data.result.unwrap()
|
|
35
|
+
}
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
struct CallbackData<F, T> {
|
|
39
|
+
func: Option<F>,
|
|
40
|
+
result: Option<T>,
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
extern "C" fn call_without_gvl<F, T>(data: *mut c_void) -> *mut c_void
|
|
44
|
+
where
|
|
45
|
+
F: FnOnce() -> T,
|
|
46
|
+
{
|
|
47
|
+
let data = unsafe { &mut *(data as *mut CallbackData<F, T>) };
|
|
48
|
+
let func = data.func.take().unwrap();
|
|
49
|
+
data.result = Some(func());
|
|
50
|
+
null_mut()
|
|
51
|
+
}
|
data/lib/safetensors/version.rb
CHANGED
metadata
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: safetensors
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.2.2
|
|
4
|
+
version: 0.3.0
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- Andrew Kane
|
|
@@ -37,6 +37,7 @@ files:
|
|
|
37
37
|
- ext/safetensors/Cargo.toml
|
|
38
38
|
- ext/safetensors/extconf.rb
|
|
39
39
|
- ext/safetensors/src/lib.rs
|
|
40
|
+
- ext/safetensors/src/ruby.rs
|
|
40
41
|
- lib/safetensors.rb
|
|
41
42
|
- lib/safetensors/numo.rb
|
|
42
43
|
- lib/safetensors/torch.rb
|
|
@@ -52,14 +53,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
|
52
53
|
requirements:
|
|
53
54
|
- - ">="
|
|
54
55
|
- !ruby/object:Gem::Version
|
|
55
|
-
version: '3.
|
|
56
|
+
version: '3.3'
|
|
56
57
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
|
57
58
|
requirements:
|
|
58
59
|
- - ">="
|
|
59
60
|
- !ruby/object:Gem::Version
|
|
60
61
|
version: '0'
|
|
61
62
|
requirements: []
|
|
62
|
-
rubygems_version: 4.0.
|
|
63
|
+
rubygems_version: 4.0.6
|
|
63
64
|
specification_version: 4
|
|
64
65
|
summary: Simple, safe way to store and distribute tensors
|
|
65
66
|
test_files: []
|