safetensors 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/README.md ADDED
@@ -0,0 +1,67 @@
1
+ # Safetensors Ruby
2
+
3
+ :slightly_smiling_face: Simple, [safe way](https://github.com/huggingface/safetensors) to store and distribute tensors
4
+
5
+ Supports [Torch.rb](https://github.com/ankane/torch.rb) and [Numo](https://github.com/ruby-numo/numo-narray)
6
+
7
+ [![Build Status](https://github.com/ankane/safetensors-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/safetensors-ruby/actions)
8
+
9
+ ## Installation
10
+
11
+ Add this line to your application’s Gemfile:
12
+
13
+ ```ruby
14
+ gem "safetensors"
15
+ ```
16
+
17
+ ## Getting Started
18
+
19
+ Save tensors
20
+
21
+ ```ruby
22
+ tensors = {
23
+ "weight1" => Torch.zeros([1024, 1024]),
24
+ "weight2" => Torch.zeros([1024, 1024])
25
+ }
26
+ Safetensors::Torch.save_file(tensors, "model.safetensors")
27
+ ```
28
+
29
+ Load tensors
30
+
31
+ ```ruby
32
+ tensors = Safetensors::Torch.load_file("model.safetensors")
33
+ # or
34
+ tensors = {}
35
+ Safetensors.safe_open("model.safetensors", framework: "torch", device: "cpu") do |f|
36
+ f.keys.each do |key|
37
+ tensors[key] = f.get_tensor(key)
38
+ end
39
+ end
40
+ ```
41
+
42
+ ## API
43
+
44
+ This library follows the [Safetensors Python API](https://huggingface.co/docs/safetensors/index). You can follow Python tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
45
+
46
+ ## History
47
+
48
+ View the [changelog](https://github.com/ankane/safetensors-ruby/blob/master/CHANGELOG.md)
49
+
50
+ ## Contributing
51
+
52
+ Everyone is encouraged to help improve this project. Here are a few ways you can help:
53
+
54
+ - [Report bugs](https://github.com/ankane/safetensors-ruby/issues)
55
+ - Fix bugs and [submit pull requests](https://github.com/ankane/safetensors-ruby/pulls)
56
+ - Write, clarify, or fix documentation
57
+ - Suggest or add new features
58
+
59
+ To get started with development:
60
+
61
+ ```sh
62
+ git clone https://github.com/ankane/safetensors-ruby.git
63
+ cd safetensors-ruby
64
+ bundle install
65
+ bundle exec rake compile
66
+ bundle exec rake test
67
+ ```
@@ -0,0 +1,17 @@
1
+ [package]
2
+ name = "safetensors"
3
+ version = "0.1.0"
4
+ license = "Apache-2.0"
5
+ authors = ["Andrew Kane <andrew@ankane.org>"]
6
+ edition = "2021"
7
+ rust-version = "1.62.0"
8
+ publish = false
9
+
10
+ [lib]
11
+ crate-type = ["cdylib"]
12
+
13
+ [dependencies]
14
+ magnus = "0.6"
15
+ memmap2 = "0.5"
16
+ safetensors = "=0.4.2"
17
+ serde_json = "1.0"
@@ -0,0 +1,4 @@
1
+ require "mkmf"
2
+ require "rb_sys/mkmf"
3
+
4
+ create_rust_makefile("safetensors/safetensors") # helper from rb_sys/mkmf; presumably emits the Makefile that builds the Rust extension as safetensors/safetensors (confirm against rb_sys docs)
@@ -0,0 +1,464 @@
1
+ use magnus::{
2
+ function, kwargs, method, prelude::*, r_hash::ForEach, Error, IntoValue, RArray, RHash,
3
+ RModule, RString, Ruby, Symbol, TryConvert, Value,
4
+ };
5
+ use memmap2::{Mmap, MmapOptions};
6
+ use safetensors::tensor::{Dtype, Metadata, SafeTensors, TensorView};
7
+ use std::collections::HashMap;
8
+ use std::fs::File;
9
+ use std::path::PathBuf;
10
+ use std::sync::Arc;
11
+
12
+ type RbResult<T> = Result<T, Error>;
13
+
14
+ fn prepare(tensor_dict: &RHash) -> RbResult<HashMap<String, TensorView<'_>>> {
15
+ let mut tensors = HashMap::with_capacity(tensor_dict.len());
16
+ tensor_dict.foreach(|tensor_name: String, tensor_desc: RHash| {
17
+ let mut shape: Option<Vec<usize>> = None;
18
+ let mut dtype: Option<Dtype> = None;
19
+ let mut data: Option<(*const u8, usize)> = None;
20
+
21
+ tensor_desc.foreach(|key: String, value: Value| {
22
+ match key.as_str() {
23
+ "shape" => shape = Some(Vec::try_convert(value)?),
24
+ "dtype" => {
25
+ let value = String::try_convert(value)?;
26
+ dtype = match value.as_str() {
27
+ "bool" => Some(Dtype::BOOL),
28
+ "int8" => Some(Dtype::I8),
29
+ "uint8" => Some(Dtype::U8),
30
+ "int16" => Some(Dtype::I16),
31
+ "uint16" => Some(Dtype::U16),
32
+ "int32" => Some(Dtype::I32),
33
+ "uint32" => Some(Dtype::U32),
34
+ "int64" => Some(Dtype::I64),
35
+ "uint64" => Some(Dtype::U64),
36
+ "float16" => Some(Dtype::F16),
37
+ "float32" => Some(Dtype::F32),
38
+ "float64" => Some(Dtype::F64),
39
+ "bfloat16" => Some(Dtype::BF16),
40
+ "float8_e4m3fn" => Some(Dtype::F8_E4M3),
41
+ "float8_e5m2" => Some(Dtype::F8_E5M2),
42
+ dtype_str => {
43
+ return Err(SafetensorError::new_err(format!(
44
+ "dtype {dtype_str} is not covered",
45
+ )));
46
+ }
47
+ }
48
+ }
49
+ "data" => {
50
+ let rs = RString::try_convert(value)?;
51
+ // SAFETY: No context switching between threads in native extensions
52
+ // so the string will not be modified (or garbage collected)
53
+ // while the reference is held. Also, the string is a private copy.
54
+ let slice = unsafe { rs.as_slice() };
55
+ data = Some((slice.as_ptr(), slice.len()))
56
+ }
57
+ _ => println!("Ignored unknown kwarg option {key}"),
58
+ };
59
+
60
+ Ok(ForEach::Continue)
61
+ })?;
62
+ let shape = shape.ok_or_else(|| {
63
+ SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}"))
64
+ })?;
65
+ let dtype = dtype.ok_or_else(|| {
66
+ SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}"))
67
+ })?;
68
+ let data = data.ok_or_else(|| {
69
+ SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}"))
70
+ })?;
71
+ // SAFETY: See comment above.
72
+ let data = unsafe { std::slice::from_raw_parts(data.0, data.1) };
73
+ let tensor = TensorView::new(dtype, shape, data)
74
+ .map_err(|e| SafetensorError::new_err(format!("Error preparing tensor view: {e:?}")))?;
75
+ tensors.insert(tensor_name, tensor);
76
+
77
+ Ok(ForEach::Continue)
78
+ })?;
79
+ Ok(tensors)
80
+ }
81
+
82
+ fn serialize(tensor_dict: RHash, metadata: Option<HashMap<String, String>>) -> RbResult<RString> {
83
+ let tensors = prepare(&tensor_dict)?;
84
+ let metadata_map = metadata.map(HashMap::from_iter);
85
+ let out = safetensors::tensor::serialize(&tensors, &metadata_map)
86
+ .map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
87
+ let rbbytes = RString::from_slice(&out);
88
+ Ok(rbbytes)
89
+ }
90
+
91
+ fn serialize_file(
92
+ tensor_dict: RHash,
93
+ filename: PathBuf,
94
+ metadata: Option<HashMap<String, String>>,
95
+ ) -> RbResult<()> {
96
+ let tensors = prepare(&tensor_dict)?;
97
+ safetensors::tensor::serialize_to_file(&tensors, &metadata, filename.as_path())
98
+ .map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
99
+ Ok(())
100
+ }
101
+
102
+ fn deserialize(bytes: RString) -> RbResult<RArray> {
103
+ let safetensor = SafeTensors::deserialize(unsafe { bytes.as_slice() })
104
+ .map_err(|e| SafetensorError::new_err(format!("Error while deserializing: {e:?}")))?;
105
+
106
+ let tensors = safetensor.tensors();
107
+ let items = RArray::with_capacity(tensors.len());
108
+
109
+ for (tensor_name, tensor) in tensors {
110
+ let rbshape = RArray::from_vec(tensor.shape().to_vec());
111
+ let rbdtype = format!("{:?}", tensor.dtype());
112
+
113
+ let rbdata = RString::from_slice(tensor.data());
114
+
115
+ let map = RHash::new();
116
+ map.aset("shape", rbshape)?;
117
+ map.aset("dtype", rbdtype)?;
118
+ map.aset("data", rbdata)?;
119
+
120
+ items.push((tensor_name, map))?;
121
+ }
122
+ Ok(items)
123
+ }
124
+
125
/// Tensor library that loaded tensors are materialised with.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Framework {
    /// Tensors are built through the Ruby `Torch` module.
    Pytorch,
    /// Tensors are built through the Ruby `Numo` module.
    Numo,
}
130
+
131
+ impl TryConvert for Framework {
132
+ fn try_convert(ob: Value) -> RbResult<Self> {
133
+ let name: String = String::try_convert(ob)?;
134
+ match &name[..] {
135
+ "pt" => Ok(Framework::Pytorch),
136
+ "torch" => Ok(Framework::Pytorch),
137
+ "pytorch" => Ok(Framework::Pytorch),
138
+
139
+ "nm" => Ok(Framework::Numo),
140
+ "numo" => Ok(Framework::Numo),
141
+
142
+ name => Err(SafetensorError::new_err(format!(
143
+ "framework {name} is invalid"
144
+ ))),
145
+ }
146
+ }
147
+ }
148
+
149
/// Device a loaded tensor is requested on. The indexed variants carry the
/// device ordinal parsed from strings such as `"cuda:1"`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Device {
    Cpu,
    Cuda(usize),
    Mps,
    Npu(usize),
    Xpu(usize),
}
157
+
158
+ impl TryConvert for Device {
159
+ fn try_convert(ob: Value) -> RbResult<Self> {
160
+ if let Ok(name) = String::try_convert(ob) {
161
+ match &name[..] {
162
+ "cpu" => Ok(Device::Cpu),
163
+ "cuda" => Ok(Device::Cuda(0)),
164
+ "mps" => Ok(Device::Mps),
165
+ "npu" => Ok(Device::Npu(0)),
166
+ "xpu" => Ok(Device::Xpu(0)),
167
+ name if name.starts_with("cuda:") => {
168
+ let tokens: Vec<_> = name.split(':').collect();
169
+ if tokens.len() == 2 {
170
+ let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
171
+ Ok(Device::Cuda(device))
172
+ } else {
173
+ Err(SafetensorError::new_err(format!(
174
+ "device {name} is invalid"
175
+ )))
176
+ }
177
+ }
178
+ name if name.starts_with("npu:") => {
179
+ let tokens: Vec<_> = name.split(':').collect();
180
+ if tokens.len() == 2 {
181
+ let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
182
+ Ok(Device::Npu(device))
183
+ } else {
184
+ Err(SafetensorError::new_err(format!(
185
+ "device {name} is invalid"
186
+ )))
187
+ }
188
+ }
189
+ name if name.starts_with("xpu:") => {
190
+ let tokens: Vec<_> = name.split(':').collect();
191
+ if tokens.len() == 2 {
192
+ let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
193
+ Ok(Device::Xpu(device))
194
+ } else {
195
+ Err(SafetensorError::new_err(format!(
196
+ "device {name} is invalid"
197
+ )))
198
+ }
199
+ }
200
+ name => Err(SafetensorError::new_err(format!(
201
+ "device {name} is invalid"
202
+ ))),
203
+ }
204
+ } else if let Ok(number) = usize::try_convert(ob) {
205
+ Ok(Device::Cuda(number))
206
+ } else {
207
+ Err(SafetensorError::new_err(format!("device {ob} is invalid")))
208
+ }
209
+ }
210
+ }
211
+
212
+ impl IntoValue for Device {
213
+ fn into_value_with(self, ruby: &Ruby) -> Value {
214
+ match self {
215
+ Device::Cpu => "cpu".into_value_with(ruby),
216
+ Device::Cuda(n) => format!("cuda:{n}").into_value_with(ruby),
217
+ Device::Mps => "mps".into_value_with(ruby),
218
+ Device::Npu(n) => format!("npu:{n}").into_value_with(ruby),
219
+ Device::Xpu(n) => format!("xpu:{n}").into_value_with(ruby),
220
+ }
221
+ }
222
+ }
223
+
224
+ enum Storage {
225
+ Mmap(Mmap),
226
+ }
227
+
228
+ struct Open {
229
+ metadata: Metadata,
230
+ offset: usize,
231
+ framework: Framework,
232
+ device: Device,
233
+ storage: Arc<Storage>,
234
+ }
235
+
236
+ impl Open {
237
+ fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> RbResult<Self> {
238
+ let file = File::open(&filename).map_err(|_| {
239
+ SafetensorError::new_err(format!("No such file or directory: {filename:?}"))
240
+ })?;
241
+ let device = device.unwrap_or(Device::Cpu);
242
+
243
+ if device != Device::Cpu && framework != Framework::Pytorch {
244
+ return Err(SafetensorError::new_err(format!(
245
+ "Device {device:?} is not support for framework {framework:?}",
246
+ )));
247
+ }
248
+
249
+ // SAFETY: Mmap is used to prevent allocating in Rust
250
+ // before making a copy within Ruby.
251
+ let buffer = unsafe { MmapOptions::new().map(&file).map_err(SafetensorError::io)? };
252
+
253
+ let (n, metadata) = SafeTensors::read_metadata(&buffer).map_err(|e| {
254
+ SafetensorError::new_err(format!("Error while deserializing header: {e:?}"))
255
+ })?;
256
+
257
+ let offset = n + 8;
258
+
259
+ let storage = Storage::Mmap(buffer);
260
+
261
+ let storage = Arc::new(storage);
262
+
263
+ Ok(Self {
264
+ metadata,
265
+ offset,
266
+ framework,
267
+ device,
268
+ storage,
269
+ })
270
+ }
271
+
272
+ pub fn metadata(&self) -> Option<HashMap<String, String>> {
273
+ self.metadata.metadata().clone()
274
+ }
275
+
276
+ pub fn keys(&self) -> RbResult<Vec<String>> {
277
+ let mut keys: Vec<String> = self.metadata.tensors().keys().cloned().collect();
278
+ keys.sort();
279
+ Ok(keys)
280
+ }
281
+
282
+ pub fn get_tensor(&self, name: &str) -> RbResult<Value> {
283
+ let info = self.metadata.info(name).ok_or_else(|| {
284
+ SafetensorError::new_err(format!("File does not contain tensor {name}",))
285
+ })?;
286
+
287
+ match &self.storage.as_ref() {
288
+ Storage::Mmap(mmap) => {
289
+ let data =
290
+ &mmap[info.data_offsets.0 + self.offset..info.data_offsets.1 + self.offset];
291
+
292
+ let array: Value = RString::from_slice(data).into_value();
293
+
294
+ create_tensor(
295
+ &self.framework,
296
+ info.dtype,
297
+ &info.shape,
298
+ array,
299
+ &self.device,
300
+ )
301
+ }
302
+ }
303
+ }
304
+ }
305
+
306
+ #[magnus::wrap(class = "Safetensors::SafeOpen")]
307
+ struct SafeOpen {
308
+ inner: Option<Open>,
309
+ }
310
+
311
+ impl SafeOpen {
312
+ fn inner(&self) -> RbResult<&Open> {
313
+ let inner = self
314
+ .inner
315
+ .as_ref()
316
+ .ok_or_else(|| SafetensorError::new_err("File is closed".to_string()))?;
317
+ Ok(inner)
318
+ }
319
+ }
320
+
321
+ impl SafeOpen {
322
+ pub fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> RbResult<Self> {
323
+ let inner = Some(Open::new(filename, framework, device)?);
324
+ Ok(Self { inner })
325
+ }
326
+
327
+ pub fn metadata(&self) -> RbResult<Option<HashMap<String, String>>> {
328
+ Ok(self.inner()?.metadata())
329
+ }
330
+
331
+ pub fn keys(&self) -> RbResult<Vec<String>> {
332
+ self.inner()?.keys()
333
+ }
334
+
335
+ pub fn get_tensor(&self, name: String) -> RbResult<Value> {
336
+ self.inner()?.get_tensor(&name)
337
+ }
338
+ }
339
+
340
+ fn create_tensor(
341
+ framework: &Framework,
342
+ dtype: Dtype,
343
+ shape: &[usize],
344
+ array: Value,
345
+ device: &Device,
346
+ ) -> RbResult<Value> {
347
+ let ruby = Ruby::get().unwrap();
348
+ let (module, is_numo): (RModule, bool) = match framework {
349
+ Framework::Pytorch => (
350
+ ruby.class_object()
351
+ .const_get("Torch")
352
+ .map_err(|_| SafetensorError::new_err("Torch not loaded".into()))?,
353
+ false,
354
+ ),
355
+ _ => (
356
+ ruby.class_object()
357
+ .const_get("Numo")
358
+ .map_err(|_| SafetensorError::new_err("Numo not loaded".into()))?,
359
+ true,
360
+ ),
361
+ };
362
+
363
+ let dtype = get_rbdtype(module, dtype, is_numo)?;
364
+ let shape = shape.to_vec();
365
+ let tensor: Value = match framework {
366
+ Framework::Pytorch => {
367
+ let options: Value = module.funcall(
368
+ "tensor_options",
369
+ (kwargs!("dtype" => dtype, "device" => device.clone()),),
370
+ )?;
371
+ module.funcall("_from_blob_ref", (array, shape, options))?
372
+ }
373
+ _ => {
374
+ let class: Value = module.funcall("const_get", (dtype,))?;
375
+ class.funcall("from_binary", (array, shape))?
376
+ }
377
+ };
378
+ Ok(tensor)
379
+ }
380
+
381
+ fn get_rbdtype(_module: RModule, dtype: Dtype, is_numo: bool) -> RbResult<Value> {
382
+ let dtype: Value = if is_numo {
383
+ match dtype {
384
+ Dtype::F64 => Symbol::new("DFloat").into_value(),
385
+ Dtype::F32 => Symbol::new("SFloat").into_value(),
386
+ Dtype::U64 => Symbol::new("UInt64").into_value(),
387
+ Dtype::I64 => Symbol::new("Int64").into_value(),
388
+ Dtype::U32 => Symbol::new("UInt32").into_value(),
389
+ Dtype::I32 => Symbol::new("Int32").into_value(),
390
+ Dtype::U16 => Symbol::new("UInt16").into_value(),
391
+ Dtype::I16 => Symbol::new("Int16").into_value(),
392
+ Dtype::U8 => Symbol::new("UInt8").into_value(),
393
+ Dtype::I8 => Symbol::new("Int8").into_value(),
394
+ dtype => {
395
+ return Err(SafetensorError::new_err(format!(
396
+ "Dtype not understood: {dtype:?}"
397
+ )))
398
+ }
399
+ }
400
+ } else {
401
+ match dtype {
402
+ Dtype::F64 => Symbol::new("float64").into_value(),
403
+ Dtype::F32 => Symbol::new("float32").into_value(),
404
+ Dtype::BF16 => Symbol::new("bfloat16").into_value(),
405
+ Dtype::F16 => Symbol::new("float16").into_value(),
406
+ Dtype::U64 => Symbol::new("uint64").into_value(),
407
+ Dtype::I64 => Symbol::new("int64").into_value(),
408
+ Dtype::U32 => Symbol::new("uint32").into_value(),
409
+ Dtype::I32 => Symbol::new("int32").into_value(),
410
+ Dtype::U16 => Symbol::new("uint16").into_value(),
411
+ Dtype::I16 => Symbol::new("int16").into_value(),
412
+ Dtype::U8 => Symbol::new("uint8").into_value(),
413
+ Dtype::I8 => Symbol::new("int8").into_value(),
414
+ Dtype::BOOL => Symbol::new("bool").into_value(),
415
+ Dtype::F8_E4M3 => Symbol::new("float8_e4m3fn").into_value(),
416
+ Dtype::F8_E5M2 => Symbol::new("float8_e5m2").into_value(),
417
+ dtype => {
418
+ return Err(SafetensorError::new_err(format!(
419
+ "Dtype not understood: {dtype:?}"
420
+ )))
421
+ }
422
+ }
423
+ };
424
+ Ok(dtype)
425
+ }
426
+
427
+ struct SafetensorError {}
428
+
429
+ impl SafetensorError {
430
+ fn new_err(message: String) -> Error {
431
+ let class = Ruby::get()
432
+ .unwrap()
433
+ .class_object()
434
+ .const_get::<_, RModule>("Safetensors")
435
+ .unwrap()
436
+ .const_get("Error")
437
+ .unwrap();
438
+ Error::new(class, message)
439
+ }
440
+
441
+ fn io(err: std::io::Error) -> Error {
442
+ Self::new_err(err.to_string())
443
+ }
444
+
445
+ fn parse(err: std::num::ParseIntError) -> Error {
446
+ Self::new_err(err.to_string())
447
+ }
448
+ }
449
+
450
+ #[magnus::init]
451
+ fn init(ruby: &Ruby) -> RbResult<()> {
452
+ let module = ruby.define_module("Safetensors")?;
453
+ module.define_singleton_method("_serialize", function!(serialize, 2))?;
454
+ module.define_singleton_method("_serialize_file", function!(serialize_file, 3))?;
455
+ module.define_singleton_method("deserialize", function!(deserialize, 1))?;
456
+
457
+ let class = module.define_class("SafeOpen", ruby.class_object())?;
458
+ class.define_singleton_method("new", function!(SafeOpen::new, 3))?;
459
+ class.define_method("metadata", method!(SafeOpen::metadata, 0))?;
460
+ class.define_method("keys", method!(SafeOpen::keys, 0))?;
461
+ class.define_method("get_tensor", method!(SafeOpen::get_tensor, 1))?;
462
+
463
+ Ok(())
464
+ }
@@ -0,0 +1,93 @@
1
module Safetensors
  # Save and load helpers for Numo::NArray tensors.
  module Numo
    # Numo class name (without the Numo:: prefix) => dtype name handed to the
    # native serializer. Covers every class listed in TYPES so that anything
    # that can be loaded can also be saved.
    DTYPES = {
      "DFloat" => "float64",
      "SFloat" => "float32",
      "Int64" => "int64",
      "UInt64" => "uint64",
      "Int32" => "int32",
      "UInt32" => "uint32",
      "Int16" => "int16",
      "UInt16" => "uint16",
      "Int8" => "int8",
      "UInt8" => "uint8"
    }

    # dtype string returned by Safetensors.deserialize => Numo class name.
    TYPES = {
      "F64" => :DFloat,
      "F32" => :SFloat,
      "I64" => :Int64,
      "U64" => :UInt64,
      "I32" => :Int32,
      "U32" => :UInt32,
      "I16" => :Int16,
      "U16" => :UInt16,
      "I8" => :Int8,
      "U8" => :UInt8
    }

    class << self
      # Serializes a hash of name => Numo::NArray and returns the result of
      # Safetensors.serialize.
      def save(tensor_dict, metadata: nil)
        Safetensors.serialize(_flatten(tensor_dict), metadata: metadata)
      end

      # Serializes a hash of name => Numo::NArray into the file at +filename+.
      def save_file(tensor_dict, filename, metadata: nil)
        Safetensors.serialize_file(_flatten(tensor_dict), filename, metadata: metadata)
      end

      # Deserializes +data+ and returns a hash of name => Numo array.
      def load(data)
        flat = Safetensors.deserialize(data)
        _view2numo(flat)
      end

      # Reads every tensor in +filename+ and returns a hash of name => Numo array.
      def load_file(filename)
        result = {}
        Safetensors.safe_open(filename, framework: "numo") do |f|
          f.keys.each do |k|
            result[k] = f.get_tensor(k)
          end
        end
        result
      end

      private

      # Validates +tensors+ and converts it to the
      # name => {"dtype", "shape", "data"} form used for serialization.
      # Symbol keys are converted to strings.
      # Raises ArgumentError for a non-hash argument or a non-NArray value, and
      # KeyError for an NArray class that has no entry in DTYPES.
      def _flatten(tensors)
        if !tensors.is_a?(Hash)
          raise ArgumentError, "Expected a hash of [String, Numo::NArray] but received #{tensors.class.name}"
        end

        tensors.each do |k, v|
          if !v.is_a?(::Numo::NArray)
            raise ArgumentError, "Key `#{k}` is invalid, expected Numo::NArray but received #{v.class.name}"
          end
        end

        tensors.to_h do |k, v|
          [
            k.is_a?(Symbol) ? k.to_s : k,
            {
              "dtype" => DTYPES.fetch(v.class.name.split("::").last),
              "shape" => v.shape,
              "data" => _tobytes(v)
            }
          ]
        end
      end

      # Returns the binary form of +tensor+. Big-endian hosts are not supported yet.
      def _tobytes(tensor)
        if Safetensors.big_endian?
          raise "Not yet implemented"
        end

        tensor.to_binary
      end

      # Maps a deserialized dtype string to the Numo class name (KeyError if unknown).
      def _getdtype(dtype_str)
        TYPES.fetch(dtype_str)
      end

      # Converts the [name, descriptor] pairs from Safetensors.deserialize
      # into a hash of name => Numo array.
      def _view2numo(safeview)
        result = {}
        safeview.each do |k, v|
          dtype = _getdtype(v["dtype"])
          arr = ::Numo.const_get(dtype).from_binary(v["data"], v["shape"])
          result[k] = arr
        end
        result
      end
    end
  end
end