safetensors 0.2.2 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: a7e6ec72940990e896207a173be53e68952eb12c8ad8b4f4c3bd707fabee4944
4
- data.tar.gz: aa06b8480be369bb75db42d3eab5b3d139194d74aa8bad6580f5f71dd4351514
3
+ metadata.gz: cc628696f449fb899f85e95c1c824c00b832aad7e3f92253bdc48c9c1e30c3af
4
+ data.tar.gz: 433d331c6cfaadcc05a844d7fb8a1e31681b626bcedadc6fdc653960a0cfc578
5
5
  SHA512:
6
- metadata.gz: 4086269deb609530821fe5b3257be0c1deedfcc9be19469c8552afb9b86cbfec32ad40160ab761ab5bea082dca7fc05cc22fe3a04a0de604ae37b916e309c21f
7
- data.tar.gz: 7edf607e6a2ff87b4844d50a2ef5a5a629db6aa87a522d9a771559a8f6ed35ed5094f59bfde6d6491f39c914b295606e2a97a09d113ae621ee0e8a78ff7e0329
6
+ metadata.gz: d05099bbce87c9cfe6634ed567378a8e135b3e1c72d6aa541d6ba2c8452ff3a65e249bd460d8490d8b525c01ae6269aafd75d07d07c3a40fedea9418eb7f79a3
7
+ data.tar.gz: 688f2f91a2f340a82dc8e11d73428e2bb6bfb66e2f15f67e24b7e31117582617b82d106f973cb152395cd4de07172bb48a304d0febd196daefa234cd45ac4980
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ ## 0.3.0 (2026-06-09)
2
+
3
+ - Updated Safetensors to 0.8.0
4
+ - Added support for releasing the GVL
5
+ - Dropped support for Ruby < 3.3
6
+
1
7
  ## 0.2.2 (2026-01-05)
2
8
 
3
9
  - Added support for Ruby 4.0
data/Cargo.lock CHANGED
@@ -81,12 +81,40 @@ version = "1.0.2"
81
81
  source = "registry+https://github.com/rust-lang/crates.io-index"
82
82
  checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
83
83
 
84
+ [[package]]
85
+ name = "errno"
86
+ version = "0.3.14"
87
+ source = "registry+https://github.com/rust-lang/crates.io-index"
88
+ checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
89
+ dependencies = [
90
+ "libc",
91
+ "windows-sys",
92
+ ]
93
+
94
+ [[package]]
95
+ name = "fastrand"
96
+ version = "2.4.1"
97
+ source = "registry+https://github.com/rust-lang/crates.io-index"
98
+ checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6"
99
+
84
100
  [[package]]
85
101
  name = "foldhash"
86
102
  version = "0.2.0"
87
103
  source = "registry+https://github.com/rust-lang/crates.io-index"
88
104
  checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
89
105
 
106
+ [[package]]
107
+ name = "getrandom"
108
+ version = "0.3.4"
109
+ source = "registry+https://github.com/rust-lang/crates.io-index"
110
+ checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
111
+ dependencies = [
112
+ "cfg-if",
113
+ "libc",
114
+ "r-efi",
115
+ "wasip2",
116
+ ]
117
+
90
118
  [[package]]
91
119
  name = "glob"
92
120
  version = "0.3.2"
@@ -148,6 +176,12 @@ dependencies = [
148
176
  "windows-targets",
149
177
  ]
150
178
 
179
+ [[package]]
180
+ name = "linux-raw-sys"
181
+ version = "0.11.0"
182
+ source = "registry+https://github.com/rust-lang/crates.io-index"
183
+ checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
184
+
151
185
  [[package]]
152
186
  name = "magnus"
153
187
  version = "0.8.1"
@@ -179,9 +213,9 @@ checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
179
213
 
180
214
  [[package]]
181
215
  name = "memmap2"
182
- version = "0.5.10"
216
+ version = "0.9.10"
183
217
  source = "registry+https://github.com/rust-lang/crates.io-index"
184
- checksum = "83faa42c0a078c393f6b29d5db232d8be22776a891f8f56e5284faee4a20b327"
218
+ checksum = "714098028fe011992e1c3962653c96b2d578c4b4bce9036e15ff220319b1e0e3"
185
219
  dependencies = [
186
220
  "libc",
187
221
  ]
@@ -202,6 +236,12 @@ dependencies = [
202
236
  "minimal-lexical",
203
237
  ]
204
238
 
239
+ [[package]]
240
+ name = "once_cell"
241
+ version = "1.21.4"
242
+ source = "registry+https://github.com/rust-lang/crates.io-index"
243
+ checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
244
+
205
245
  [[package]]
206
246
  name = "proc-macro2"
207
247
  version = "1.0.95"
@@ -220,6 +260,12 @@ dependencies = [
220
260
  "proc-macro2",
221
261
  ]
222
262
 
263
+ [[package]]
264
+ name = "r-efi"
265
+ version = "5.3.0"
266
+ source = "registry+https://github.com/rust-lang/crates.io-index"
267
+ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
268
+
223
269
  [[package]]
224
270
  name = "rb-sys"
225
271
  version = "0.9.124"
@@ -285,6 +331,19 @@ version = "1.1.0"
285
331
  source = "registry+https://github.com/rust-lang/crates.io-index"
286
332
  checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
287
333
 
334
+ [[package]]
335
+ name = "rustix"
336
+ version = "1.1.2"
337
+ source = "registry+https://github.com/rust-lang/crates.io-index"
338
+ checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e"
339
+ dependencies = [
340
+ "bitflags",
341
+ "errno",
342
+ "libc",
343
+ "linux-raw-sys",
344
+ "windows-sys",
345
+ ]
346
+
288
347
  [[package]]
289
348
  name = "ryu"
290
349
  version = "1.0.20"
@@ -293,21 +352,24 @@ checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
293
352
 
294
353
  [[package]]
295
354
  name = "safetensors"
296
- version = "0.7.0"
355
+ version = "0.8.0"
297
356
  source = "registry+https://github.com/rust-lang/crates.io-index"
298
- checksum = "675656c1eabb620b921efea4f9199f97fc86e36dd6ffd1fbbe48d0f59a4987f5"
357
+ checksum = "79b079b829cb27a1c3c374341345ed2e8b2c0c839034522cee576c140bd7f846"
299
358
  dependencies = [
300
359
  "hashbrown",
360
+ "libc",
301
361
  "serde",
302
362
  "serde_json",
363
+ "tempfile",
303
364
  ]
304
365
 
305
366
  [[package]]
306
367
  name = "safetensors-ruby"
307
- version = "0.2.2"
368
+ version = "0.3.0"
308
369
  dependencies = [
309
370
  "magnus",
310
371
  "memmap2",
372
+ "rb-sys",
311
373
  "safetensors",
312
374
  "serde_json",
313
375
  ]
@@ -373,12 +435,49 @@ dependencies = [
373
435
  "unicode-ident",
374
436
  ]
375
437
 
438
+ [[package]]
439
+ name = "tempfile"
440
+ version = "3.23.0"
441
+ source = "registry+https://github.com/rust-lang/crates.io-index"
442
+ checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16"
443
+ dependencies = [
444
+ "fastrand",
445
+ "getrandom",
446
+ "once_cell",
447
+ "rustix",
448
+ "windows-sys",
449
+ ]
450
+
376
451
  [[package]]
377
452
  name = "unicode-ident"
378
453
  version = "1.0.18"
379
454
  source = "registry+https://github.com/rust-lang/crates.io-index"
380
455
  checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
381
456
 
457
+ [[package]]
458
+ name = "wasip2"
459
+ version = "1.0.3+wasi-0.2.9"
460
+ source = "registry+https://github.com/rust-lang/crates.io-index"
461
+ checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6"
462
+ dependencies = [
463
+ "wit-bindgen",
464
+ ]
465
+
466
+ [[package]]
467
+ name = "windows-link"
468
+ version = "0.2.1"
469
+ source = "registry+https://github.com/rust-lang/crates.io-index"
470
+ checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
471
+
472
+ [[package]]
473
+ name = "windows-sys"
474
+ version = "0.61.2"
475
+ source = "registry+https://github.com/rust-lang/crates.io-index"
476
+ checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc"
477
+ dependencies = [
478
+ "windows-link",
479
+ ]
480
+
382
481
  [[package]]
383
482
  name = "windows-targets"
384
483
  version = "0.53.2"
@@ -442,3 +541,9 @@ name = "windows_x86_64_msvc"
442
541
  version = "0.53.0"
443
542
  source = "registry+https://github.com/rust-lang/crates.io-index"
444
543
  checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
544
+
545
+ [[package]]
546
+ name = "wit-bindgen"
547
+ version = "0.57.1"
548
+ source = "registry+https://github.com/rust-lang/crates.io-index"
549
+ checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e"
data/README.md CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  :slightly_smiling_face: Simple, [safe way](https://github.com/huggingface/safetensors) to store and distribute tensors
4
4
 
5
- Supports [Torch.rb](https://github.com/ankane/torch.rb) and [Numo](https://github.com/ruby-numo/numo-narray)
5
+ Supports [Torch.rb](https://github.com/ankane/torch.rb) and [numo-narray-alt](https://github.com/yoshoku/numo-narray-alt)
6
6
 
7
7
  [![Build Status](https://github.com/ankane/safetensors-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/safetensors-ruby/actions)
8
8
 
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "safetensors-ruby"
3
- version = "0.2.2"
3
+ version = "0.3.0"
4
4
  license = "Apache-2.0"
5
5
  authors = ["Andrew Kane <andrew@ankane.org>"]
6
6
  edition = "2021"
@@ -13,6 +13,7 @@ crate-type = ["cdylib"]
13
13
 
14
14
  [dependencies]
15
15
  magnus = "0.8"
16
- memmap2 = "0.5"
17
- safetensors = "=0.7.0"
16
+ memmap2 = "0.9"
17
+ rb-sys = "0.9"
18
+ safetensors = "=0.8.0"
18
19
  serde_json = "1"
@@ -1,76 +1,125 @@
1
+ mod ruby;
2
+
3
+ use core::slice;
1
4
  use magnus::{
2
5
  function, kwargs, method, prelude::*, r_hash::ForEach, Error, IntoValue, RArray, RHash,
3
6
  RModule, RString, Ruby, Symbol, TryConvert, Value,
4
7
  };
5
8
  use memmap2::{Mmap, MmapOptions};
6
- use safetensors::tensor::{Dtype, Metadata, SafeTensors, TensorView};
9
+ use safetensors::tensor::{Dtype, Metadata, SafeTensors};
10
+ use safetensors::View;
11
+ use std::borrow::Cow;
7
12
  use std::collections::HashMap;
13
+ use std::fmt;
8
14
  use std::fs::File;
9
15
  use std::path::PathBuf;
10
16
  use std::sync::Arc;
11
17
 
18
+ use crate::ruby::GvlExt;
19
+
12
20
  type RbResult<T> = Result<T, Error>;
13
21
 
14
- fn prepare(tensor_dict: &RHash) -> RbResult<HashMap<String, TensorView<'_>>> {
22
/// Description of one tensor to be serialized: its dtype, its shape, and the
/// location of its raw bytes.
///
/// The bytes are recorded as an address (`data_ptr`) plus a length
/// (`data_len`) rather than as a borrowed slice; `View::data` rebuilds the
/// slice from them. NOTE(review): storing the address as an integer is
/// presumably what keeps this type `Send` so it can be captured by `detach`
/// closures — confirm.
///
/// # Safety
///
/// Nothing in this type validates `data_ptr`/`data_len`. Whoever constructs a
/// `TensorSpec` must ensure they describe a readable buffer of `data_len` bytes
/// that stays alive and unmodified for as long as the spec's data may be read,
/// including while serialization runs with the GVL released.
/// NOTE(review): `prepare` fills these from a Ruby string via
/// `RString::as_slice`; confirm nothing else can mutate or free that string
/// while the GVL is released.
#[derive(Clone, Debug)]
struct TensorSpec {
    /// Element type, parsed from the Ruby-side dtype name.
    dtype: Dtype,
    /// Tensor shape; for `Dtype::F4`, `TensorSpec::new` doubles the last
    /// dimension (when there is one).
    shape: Vec<usize>,
    /// Address of the first data byte, stored as an integer.
    data_ptr: u64,
    /// Number of data bytes starting at `data_ptr`.
    data_len: usize,
}
29
+
30
+ impl TensorSpec {
31
+ fn new(dtype: &str, shape: Vec<usize>, data_ptr: u64, data_len: usize) -> RbResult<Self> {
32
+ let dtype = parse_dtype_str(dtype)?;
33
+ let mut shape = shape;
34
+ // F4 packs two elements per byte; the safetensors header records the
35
+ // logical element count, so double the last dim.
36
+ if dtype == Dtype::F4 && !shape.is_empty() {
37
+ let n = shape.len();
38
+ shape[n - 1] = shape[n - 1].checked_mul(2).ok_or_else(|| {
39
+ SafetensorError::new_err(format!(
40
+ "F4 last-dim {} doubled to logical shape overflows usize",
41
+ shape[n - 1]
42
+ ))
43
+ })?;
44
+ }
45
+ Ok(Self {
46
+ dtype,
47
+ shape,
48
+ data_ptr,
49
+ data_len,
50
+ })
51
+ }
52
+ }
53
+
54
+ impl View for &TensorSpec {
55
+ fn dtype(&self) -> Dtype {
56
+ self.dtype
57
+ }
58
+
59
+ fn shape(&self) -> &[usize] {
60
+ &self.shape
61
+ }
62
+
63
+ fn data(&self) -> Cow<'_, [u8]> {
64
+ let p = self.data_ptr as *const u8;
65
+ // SAFETY: validated by the caller; see the struct-level safety note.
66
+ unsafe {
67
+ let slice = slice::from_raw_parts(p, self.data_len);
68
+ Cow::Borrowed(slice)
69
+ }
70
+ }
71
+
72
+ fn data_len(&self) -> usize {
73
+ self.data_len
74
+ }
75
+ }
76
+
77
+ fn parse_dtype_str(dtype: &str) -> RbResult<Dtype> {
78
+ Ok(match dtype {
79
+ "bool" => Dtype::BOOL,
80
+ "int8" => Dtype::I8,
81
+ "uint8" => Dtype::U8,
82
+ "int16" => Dtype::I16,
83
+ "uint16" => Dtype::U16,
84
+ "int32" => Dtype::I32,
85
+ "uint32" => Dtype::U32,
86
+ "int64" => Dtype::I64,
87
+ "uint64" => Dtype::U64,
88
+ "float16" => Dtype::F16,
89
+ "float32" => Dtype::F32,
90
+ "float64" => Dtype::F64,
91
+ "bfloat16" => Dtype::BF16,
92
+ "float8_e4m3fn" => Dtype::F8_E4M3,
93
+ "float8_e4m3fnuz" => Dtype::F8_E4M3FNUZ,
94
+ "float8_e5m2" => Dtype::F8_E5M2,
95
+ "float8_e5m2fnuz" => Dtype::F8_E5M2FNUZ,
96
+ "float8_e8m0fnu" => Dtype::F8_E8M0,
97
+ "float4_e2m1fn_x2" => Dtype::F4,
98
+ "complex64" => Dtype::C64,
99
+ other => {
100
+ return Err(SafetensorError::new_err(format!(
101
+ "Unknown dtype {other:?}. Supported dtypes: bool, int8, uint8, int16, uint16, \
102
+ int32, uint32, int64, uint64, float16, float32, float64, bfloat16, \
103
+ float8_e4m3fn, float8_e4m3fnuz, float8_e5m2, float8_e5m2fnuz, float8_e8m0fnu, \
104
+ float4_e2m1fn_x2, complex64",
105
+ )));
106
+ }
107
+ })
108
+ }
109
+
110
+ fn prepare(tensor_dict: &RHash) -> RbResult<HashMap<String, TensorSpec>> {
15
111
  let mut tensors = HashMap::with_capacity(tensor_dict.len());
16
112
  tensor_dict.foreach(|tensor_name: String, tensor_desc: RHash| {
17
- let mut shape: Option<Vec<usize>> = None;
18
- let mut dtype: Option<Dtype> = None;
19
- let mut data: Option<(*const u8, usize)> = None;
20
-
21
- tensor_desc.foreach(|key: String, value: Value| {
22
- match key.as_str() {
23
- "shape" => shape = Some(Vec::try_convert(value)?),
24
- "dtype" => {
25
- let value = String::try_convert(value)?;
26
- dtype = match value.as_str() {
27
- "bool" => Some(Dtype::BOOL),
28
- "int8" => Some(Dtype::I8),
29
- "uint8" => Some(Dtype::U8),
30
- "int16" => Some(Dtype::I16),
31
- "uint16" => Some(Dtype::U16),
32
- "int32" => Some(Dtype::I32),
33
- "uint32" => Some(Dtype::U32),
34
- "int64" => Some(Dtype::I64),
35
- "uint64" => Some(Dtype::U64),
36
- "float16" => Some(Dtype::F16),
37
- "float32" => Some(Dtype::F32),
38
- "float64" => Some(Dtype::F64),
39
- "bfloat16" => Some(Dtype::BF16),
40
- "float8_e4m3fn" => Some(Dtype::F8_E4M3),
41
- "float8_e5m2" => Some(Dtype::F8_E5M2),
42
- dtype_str => {
43
- return Err(SafetensorError::new_err(format!(
44
- "dtype {dtype_str} is not covered",
45
- )));
46
- }
47
- }
48
- }
49
- "data" => {
50
- let rs = RString::try_convert(value)?;
51
- // SAFETY: No context switching between threads in native extensions
52
- // so the string will not be modified (or garbage collected)
53
- // while the reference is held. Also, the string is a private copy.
54
- let slice = unsafe { rs.as_slice() };
55
- data = Some((slice.as_ptr(), slice.len()));
56
- }
57
- _ => println!("Ignored unknown kwarg option {key}"),
58
- };
59
-
60
- Ok(ForEach::Continue)
61
- })?;
62
- let shape = shape.ok_or_else(|| {
63
- SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}"))
64
- })?;
65
- let dtype = dtype.ok_or_else(|| {
66
- SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}"))
67
- })?;
68
- let data = data.ok_or_else(|| {
69
- SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}"))
70
- })?;
71
- // SAFETY: See comment above.
72
- let data = unsafe { std::slice::from_raw_parts(data.0, data.1) };
73
- let tensor = TensorView::new(dtype, shape, data)
113
+ let dtype: String = tensor_desc.aref("dtype")?;
114
+ let shape: Vec<usize> = tensor_desc.aref("shape")?;
115
+ let data: RString = tensor_desc.aref("data")?;
116
+
117
+ // SAFETY: No context switching between threads in native extensions
118
+ // so the string will not be modified (or garbage collected)
119
+ // while the reference is held. Also, the string is a private copy.
120
+ let slice = unsafe { data.as_slice() };
121
+
122
+ let tensor = TensorSpec::new(dtype.as_ref(), shape, slice.as_ptr() as u64, slice.len())
74
123
  .map_err(|e| SafetensorError::new_err(format!("Error preparing tensor view: {e:?}")))?;
75
124
  tensors.insert(tensor_name, tensor);
76
125
 
@@ -86,19 +135,21 @@ fn serialize(
86
135
  ) -> RbResult<RString> {
87
136
  let tensors = prepare(&tensor_dict)?;
88
137
  let metadata_map = metadata.map(HashMap::from_iter);
89
- let out = safetensors::tensor::serialize(&tensors, metadata_map)
138
+ let out = ruby
139
+ .detach(|| safetensors::tensor::serialize(&tensors, metadata_map))
90
140
  .map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
91
141
  let rbbytes = ruby.str_from_slice(&out);
92
142
  Ok(rbbytes)
93
143
  }
94
144
 
95
145
  fn serialize_file(
146
+ ruby: &Ruby,
96
147
  tensor_dict: RHash,
97
148
  filename: PathBuf,
98
149
  metadata: Option<HashMap<String, String>>,
99
150
  ) -> RbResult<()> {
100
151
  let tensors = prepare(&tensor_dict)?;
101
- safetensors::tensor::serialize_to_file(&tensors, metadata, filename.as_path())
152
+ ruby.detach(|| safetensors::tensor::serialize_to_file(&tensors, metadata, filename.as_path()))
102
153
  .map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
103
154
  Ok(())
104
155
  }
@@ -132,6 +183,15 @@ enum Framework {
132
183
  Numo,
133
184
  }
134
185
 
186
+ impl fmt::Display for Framework {
187
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188
+ f.write_str(match *self {
189
+ Framework::Pytorch => "torch",
190
+ Framework::Numo => "numo",
191
+ })
192
+ }
193
+ }
194
+
135
195
  impl TryConvert for Framework {
136
196
  fn try_convert(ob: Value) -> RbResult<Self> {
137
197
  let name: String = String::try_convert(ob)?;
@@ -157,56 +217,64 @@ enum Device {
157
217
  Mps,
158
218
  Npu(usize),
159
219
  Xpu(usize),
220
+ Xla(usize),
221
+ Mlu(usize),
222
+ Hpu(usize),
223
+ Anonymous(usize),
224
+ }
225
+
226
+ impl fmt::Display for Device {
227
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228
+ match *self {
229
+ Device::Cpu => write!(f, "cpu"),
230
+ Device::Mps => write!(f, "mps"),
231
+ Device::Cuda(index) => write!(f, "cuda:{index}"),
232
+ Device::Npu(index) => write!(f, "npu:{index}"),
233
+ Device::Xpu(index) => write!(f, "xpu:{index}"),
234
+ Device::Xla(index) => write!(f, "xla:{index}"),
235
+ Device::Mlu(index) => write!(f, "mlu:{index}"),
236
+ Device::Hpu(index) => write!(f, "hpu:{index}"),
237
+ Device::Anonymous(index) => write!(f, "{index}"),
238
+ }
239
+ }
240
+ }
241
+
242
+ /// Parsing the device index.
243
+ fn parse_device(name: &str) -> RbResult<usize> {
244
+ let tokens: Vec<_> = name.split(':').collect();
245
+ if tokens.len() == 2 {
246
+ Ok(tokens[1].parse().map_err(SafetensorError::parse)?)
247
+ } else {
248
+ Err(SafetensorError::new_err(format!(
249
+ "device {name} is invalid"
250
+ )))
251
+ }
160
252
  }
161
253
 
162
254
  impl TryConvert for Device {
163
255
  fn try_convert(ob: Value) -> RbResult<Self> {
164
256
  if let Ok(name) = String::try_convert(ob) {
165
- match &name[..] {
257
+ match name.as_str() {
166
258
  "cpu" => Ok(Device::Cpu),
167
259
  "cuda" => Ok(Device::Cuda(0)),
168
260
  "mps" => Ok(Device::Mps),
169
261
  "npu" => Ok(Device::Npu(0)),
170
262
  "xpu" => Ok(Device::Xpu(0)),
171
- name if name.starts_with("cuda:") => {
172
- let tokens: Vec<_> = name.split(':').collect();
173
- if tokens.len() == 2 {
174
- let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
175
- Ok(Device::Cuda(device))
176
- } else {
177
- Err(SafetensorError::new_err(format!(
178
- "device {name} is invalid"
179
- )))
180
- }
181
- }
182
- name if name.starts_with("npu:") => {
183
- let tokens: Vec<_> = name.split(':').collect();
184
- if tokens.len() == 2 {
185
- let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
186
- Ok(Device::Npu(device))
187
- } else {
188
- Err(SafetensorError::new_err(format!(
189
- "device {name} is invalid"
190
- )))
191
- }
192
- }
193
- name if name.starts_with("xpu:") => {
194
- let tokens: Vec<_> = name.split(':').collect();
195
- if tokens.len() == 2 {
196
- let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
197
- Ok(Device::Xpu(device))
198
- } else {
199
- Err(SafetensorError::new_err(format!(
200
- "device {name} is invalid"
201
- )))
202
- }
203
- }
263
+ "xla" => Ok(Device::Xla(0)),
264
+ "mlu" => Ok(Device::Mlu(0)),
265
+ "hpu" => Ok(Device::Hpu(0)),
266
+ name if name.starts_with("cuda:") => parse_device(name).map(Device::Cuda),
267
+ name if name.starts_with("npu:") => parse_device(name).map(Device::Npu),
268
+ name if name.starts_with("xpu:") => parse_device(name).map(Device::Xpu),
269
+ name if name.starts_with("xla:") => parse_device(name).map(Device::Xla),
270
+ name if name.starts_with("mlu:") => parse_device(name).map(Device::Mlu),
271
+ name if name.starts_with("hpu:") => parse_device(name).map(Device::Hpu),
204
272
  name => Err(SafetensorError::new_err(format!(
205
273
  "device {name} is invalid"
206
274
  ))),
207
275
  }
208
276
  } else if let Ok(number) = usize::try_convert(ob) {
209
- Ok(Device::Cuda(number))
277
+ Ok(Device::Anonymous(number))
210
278
  } else {
211
279
  Err(SafetensorError::new_err(format!("device {ob} is invalid")))
212
280
  }
@@ -221,6 +289,10 @@ impl IntoValue for Device {
221
289
  Device::Mps => "mps".into_value_with(ruby),
222
290
  Device::Npu(n) => format!("npu:{n}").into_value_with(ruby),
223
291
  Device::Xpu(n) => format!("xpu:{n}").into_value_with(ruby),
292
+ Device::Xla(n) => format!("xla:{n}").into_value_with(ruby),
293
+ Device::Mlu(n) => format!("mlu:{n}").into_value_with(ruby),
294
+ Device::Hpu(n) => format!("hpu:{n}").into_value_with(ruby),
295
+ Device::Anonymous(n) => n.into_value_with(ruby),
224
296
  }
225
297
  }
226
298
  }
@@ -240,13 +312,13 @@ struct Open {
240
312
  impl Open {
241
313
  fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> RbResult<Self> {
242
314
  let file = File::open(&filename).map_err(|_| {
243
- SafetensorError::new_err(format!("No such file or directory: {filename:?}"))
315
+ SafetensorError::new_err(format!("No such file or directory: {}", filename.display()))
244
316
  })?;
245
317
  let device = device.unwrap_or(Device::Cpu);
246
318
 
247
319
  if device != Device::Cpu && framework != Framework::Pytorch {
248
320
  return Err(SafetensorError::new_err(format!(
249
- "Device {device:?} is not support for framework {framework:?}",
321
+ "Device {device} is not support for framework {framework}",
250
322
  )));
251
323
  }
252
324
 
@@ -255,7 +327,7 @@ impl Open {
255
327
  let buffer = unsafe { MmapOptions::new().map(&file).map_err(SafetensorError::io)? };
256
328
 
257
329
  let (n, metadata) = SafeTensors::read_metadata(&buffer).map_err(|e| {
258
- SafetensorError::new_err(format!("Error while deserializing header: {e:?}"))
330
+ SafetensorError::new_err(format!("Error while deserializing header: {e}"))
259
331
  })?;
260
332
 
261
333
  let offset = n + 8;
@@ -285,7 +357,7 @@ impl Open {
285
357
 
286
358
  pub fn get_tensor(&self, ruby: &Ruby, name: &str) -> RbResult<Value> {
287
359
  let info = self.metadata.info(name).ok_or_else(|| {
288
- SafetensorError::new_err(format!("File does not contain tensor {name}",))
360
+ SafetensorError::new_err(format!("File does not contain tensor {name}"))
289
361
  })?;
290
362
 
291
363
  match &self.storage.as_ref() {
@@ -357,7 +429,7 @@ fn create_tensor(
357
429
  .map_err(|_| SafetensorError::new_err("Torch not loaded".into()))?,
358
430
  false,
359
431
  ),
360
- _ => (
432
+ Framework::Numo => (
361
433
  ruby.class_object()
362
434
  .const_get("Numo")
363
435
  .map_err(|_| SafetensorError::new_err("Numo not loaded".into()))?,
@@ -375,7 +447,7 @@ fn create_tensor(
375
447
  )?;
376
448
  module.funcall("_from_blob_ref", (array, shape, options))?
377
449
  }
378
- _ => {
450
+ Framework::Numo => {
379
451
  let class: Value = module.funcall("const_get", (dtype,))?;
380
452
  class.funcall("from_binary", (array, shape))?
381
453
  }
@@ -419,6 +491,10 @@ fn get_rbdtype(ruby: &Ruby, _module: RModule, dtype: Dtype, is_numo: bool) -> Rb
419
491
  Dtype::BOOL => ruby.to_symbol("bool"),
420
492
  Dtype::F8_E4M3 => ruby.to_symbol("float8_e4m3fn"),
421
493
  Dtype::F8_E5M2 => ruby.to_symbol("float8_e5m2"),
494
+ Dtype::F8_E5M2FNUZ => ruby.to_symbol("float8_e5m2fnuz"),
495
+ Dtype::F8_E8M0 => ruby.to_symbol("float8_e8m0fnu"),
496
+ Dtype::F4 => ruby.to_symbol("float4_e2m1fn_x2"),
497
+ Dtype::C64 => ruby.to_symbol("complex64"),
422
498
  dtype => {
423
499
  return Err(SafetensorError::new_err(format!(
424
500
  "Dtype not understood: {dtype:?}"
@@ -0,0 +1,51 @@
1
+ use std::ffi::c_void;
2
+ use std::ptr::null_mut;
3
+
4
+ use magnus::Ruby;
5
+ use rb_sys::rb_thread_call_without_gvl;
6
+
7
/// Extension trait for running a Rust closure with Ruby's Global VM Lock
/// (GVL) released.
pub trait GvlExt {
    /// Runs `func` without holding the GVL and returns its result once the
    /// GVL is held again.
    ///
    /// `func` must not touch Ruby objects or call the Ruby C API, because it
    /// runs without the GVL. NOTE(review): the `Send` bounds are presumably
    /// there to keep magnus Ruby handles (which are not `Send`) out of the
    /// closure and its result — confirm.
    fn detach<T, F>(&self, func: F) -> T
    where
        F: Send + FnOnce() -> T,
        T: Send;
}
13
+
14
+ impl GvlExt for Ruby {
15
+ fn detach<T, F>(&self, func: F) -> T
16
+ where
17
+ F: Send + FnOnce() -> T,
18
+ T: Send,
19
+ {
20
+ let mut data = CallbackData {
21
+ func: Some(func),
22
+ result: None,
23
+ };
24
+
25
+ unsafe {
26
+ rb_thread_call_without_gvl(
27
+ Some(call_without_gvl::<F, T>),
28
+ &mut data as *mut _ as *mut c_void,
29
+ None,
30
+ null_mut(),
31
+ );
32
+ }
33
+
34
+ data.result.unwrap()
35
+ }
36
+ }
37
+
38
+ struct CallbackData<F, T> {
39
+ func: Option<F>,
40
+ result: Option<T>,
41
+ }
42
+
43
+ extern "C" fn call_without_gvl<F, T>(data: *mut c_void) -> *mut c_void
44
+ where
45
+ F: FnOnce() -> T,
46
+ {
47
+ let data = unsafe { &mut *(data as *mut CallbackData<F, T>) };
48
+ let func = data.func.take().unwrap();
49
+ data.result = Some(func());
50
+ null_mut()
51
+ }
@@ -1,3 +1,3 @@
1
1
  module Safetensors
2
- VERSION = "0.2.2"
2
+ VERSION = "0.3.0"
3
3
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: safetensors
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.2
4
+ version: 0.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
@@ -37,6 +37,7 @@ files:
37
37
  - ext/safetensors/Cargo.toml
38
38
  - ext/safetensors/extconf.rb
39
39
  - ext/safetensors/src/lib.rs
40
+ - ext/safetensors/src/ruby.rs
40
41
  - lib/safetensors.rb
41
42
  - lib/safetensors/numo.rb
42
43
  - lib/safetensors/torch.rb
@@ -52,14 +53,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
52
53
  requirements:
53
54
  - - ">="
54
55
  - !ruby/object:Gem::Version
55
- version: '3.2'
56
+ version: '3.3'
56
57
  required_rubygems_version: !ruby/object:Gem::Requirement
57
58
  requirements:
58
59
  - - ">="
59
60
  - !ruby/object:Gem::Version
60
61
  version: '0'
61
62
  requirements: []
62
- rubygems_version: 4.0.3
63
+ rubygems_version: 4.0.6
63
64
  specification_version: 4
64
65
  summary: Simple, safe way to store and distribute tensors
65
66
  test_files: []