safetensors 0.1.0

data/README.md ADDED
@@ -0,0 +1,67 @@
+ # Safetensors Ruby
+
+ :slightly_smiling_face: Simple, [safe way](https://github.com/huggingface/safetensors) to store and distribute tensors
+
+ Supports [Torch.rb](https://github.com/ankane/torch.rb) and [Numo](https://github.com/ruby-numo/numo-narray)
+
+ [![Build Status](https://github.com/ankane/safetensors-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/safetensors-ruby/actions)
+
+ ## Installation
+
+ Add this line to your application’s Gemfile:
+
+ ```ruby
+ gem "safetensors"
+ ```
+
+ ## Getting Started
+
+ Save tensors
+
+ ```ruby
+ tensors = {
+   "weight1" => Torch.zeros([1024, 1024]),
+   "weight2" => Torch.zeros([1024, 1024])
+ }
+ Safetensors::Torch.save_file(tensors, "model.safetensors")
+ ```
+
+ Load tensors
+
+ ```ruby
+ tensors = Safetensors::Torch.load_file("model.safetensors")
+ # or
+ tensors = {}
+ Safetensors.safe_open("model.safetensors", framework: "torch", device: "cpu") do |f|
+   f.keys.each do |key|
+     tensors[key] = f.get_tensor(key)
+   end
+ end
+ ```
+
+ ## API
+
+ This library follows the [Safetensors Python API](https://huggingface.co/docs/safetensors/index). You can follow Python tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
+
+ ## History
+
+ View the [changelog](https://github.com/ankane/safetensors-ruby/blob/master/CHANGELOG.md)
+
+ ## Contributing
+
+ Everyone is encouraged to help improve this project. Here are a few ways you can help:
+
+ - [Report bugs](https://github.com/ankane/safetensors-ruby/issues)
+ - Fix bugs and [submit pull requests](https://github.com/ankane/safetensors-ruby/pulls)
+ - Write, clarify, or fix documentation
+ - Suggest or add new features
+
+ To get started with development:
+
+ ```sh
+ git clone https://github.com/ankane/safetensors-ruby.git
+ cd safetensors-ruby
+ bundle install
+ bundle exec rake compile
+ bundle exec rake test
+ ```
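The `safe_open` example in the README loads every tensor. The native `SafeOpen` class added later in this diff also exposes `metadata`, so a partial read that only inspects the header and pulls a single tensor looks roughly like this (a sketch; the file name is a placeholder and the metadata contents depend on what was stored):

```ruby
require "safetensors"

Safetensors.safe_open("model.safetensors", framework: "torch", device: "cpu") do |f|
  p f.metadata                        # string hash from the __metadata__ header, or nil
  p f.keys                            # sorted tensor names
  weight = f.get_tensor(f.keys.first) # load just one tensor
end
```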
@@ -0,0 +1,17 @@
+ [package]
+ name = "safetensors"
+ version = "0.1.0"
+ license = "Apache-2.0"
+ authors = ["Andrew Kane <andrew@ankane.org>"]
+ edition = "2021"
+ rust-version = "1.62.0"
+ publish = false
+
+ [lib]
+ crate-type = ["cdylib"]
+
+ [dependencies]
+ magnus = "0.6"
+ memmap2 = "0.5"
+ safetensors = "=0.4.2"
+ serde_json = "1.0"
@@ -0,0 +1,4 @@
+ require "mkmf"
+ require "rb_sys/mkmf"
+
+ create_rust_makefile("safetensors/safetensors")
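This extconf.rb hands Makefile generation to rb_sys. For the extension to compile on `gem install`, the gemspec has to reference it; a minimal, hypothetical excerpt (the `ext/safetensors/extconf.rb` path is an assumption about this gem's layout, not taken from the diff):

```ruby
# safetensors.gemspec (hypothetical excerpt)
Gem::Specification.new do |spec|
  spec.name    = "safetensors"
  spec.version = "0.1.0"
  # build the Rust extension at install time using the extconf.rb above
  spec.extensions = ["ext/safetensors/extconf.rb"]
end
```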
@@ -0,0 +1,464 @@
+ use magnus::{
+     function, kwargs, method, prelude::*, r_hash::ForEach, Error, IntoValue, RArray, RHash,
+     RModule, RString, Ruby, Symbol, TryConvert, Value,
+ };
+ use memmap2::{Mmap, MmapOptions};
+ use safetensors::tensor::{Dtype, Metadata, SafeTensors, TensorView};
+ use std::collections::HashMap;
+ use std::fs::File;
+ use std::path::PathBuf;
+ use std::sync::Arc;
+
+ type RbResult<T> = Result<T, Error>;
+
+ fn prepare(tensor_dict: &RHash) -> RbResult<HashMap<String, TensorView<'_>>> {
+     let mut tensors = HashMap::with_capacity(tensor_dict.len());
+     tensor_dict.foreach(|tensor_name: String, tensor_desc: RHash| {
+         let mut shape: Option<Vec<usize>> = None;
+         let mut dtype: Option<Dtype> = None;
+         let mut data: Option<(*const u8, usize)> = None;
+
+         tensor_desc.foreach(|key: String, value: Value| {
+             match key.as_str() {
+                 "shape" => shape = Some(Vec::try_convert(value)?),
+                 "dtype" => {
+                     let value = String::try_convert(value)?;
+                     dtype = match value.as_str() {
+                         "bool" => Some(Dtype::BOOL),
+                         "int8" => Some(Dtype::I8),
+                         "uint8" => Some(Dtype::U8),
+                         "int16" => Some(Dtype::I16),
+                         "uint16" => Some(Dtype::U16),
+                         "int32" => Some(Dtype::I32),
+                         "uint32" => Some(Dtype::U32),
+                         "int64" => Some(Dtype::I64),
+                         "uint64" => Some(Dtype::U64),
+                         "float16" => Some(Dtype::F16),
+                         "float32" => Some(Dtype::F32),
+                         "float64" => Some(Dtype::F64),
+                         "bfloat16" => Some(Dtype::BF16),
+                         "float8_e4m3fn" => Some(Dtype::F8_E4M3),
+                         "float8_e5m2" => Some(Dtype::F8_E5M2),
+                         dtype_str => {
+                             return Err(SafetensorError::new_err(format!(
+                                 "dtype {dtype_str} is not covered",
+                             )));
+                         }
+                     }
+                 }
+                 "data" => {
+                     let rs = RString::try_convert(value)?;
+                     // SAFETY: No context switching between threads in native extensions
+                     // so the string will not be modified (or garbage collected)
+                     // while the reference is held. Also, the string is a private copy.
+                     let slice = unsafe { rs.as_slice() };
+                     data = Some((slice.as_ptr(), slice.len()))
+                 }
+                 _ => println!("Ignored unknown kwarg option {key}"),
+             };
+
+             Ok(ForEach::Continue)
+         })?;
+         let shape = shape.ok_or_else(|| {
+             SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}"))
+         })?;
+         let dtype = dtype.ok_or_else(|| {
+             SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}"))
+         })?;
+         let data = data.ok_or_else(|| {
+             SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}"))
+         })?;
+         // SAFETY: See comment above.
+         let data = unsafe { std::slice::from_raw_parts(data.0, data.1) };
+         let tensor = TensorView::new(dtype, shape, data)
+             .map_err(|e| SafetensorError::new_err(format!("Error preparing tensor view: {e:?}")))?;
+         tensors.insert(tensor_name, tensor);
+
+         Ok(ForEach::Continue)
+     })?;
+     Ok(tensors)
+ }
+
+ fn serialize(tensor_dict: RHash, metadata: Option<HashMap<String, String>>) -> RbResult<RString> {
+     let tensors = prepare(&tensor_dict)?;
+     let metadata_map = metadata.map(HashMap::from_iter);
+     let out = safetensors::tensor::serialize(&tensors, &metadata_map)
+         .map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
+     let rbbytes = RString::from_slice(&out);
+     Ok(rbbytes)
+ }
+
+ fn serialize_file(
+     tensor_dict: RHash,
+     filename: PathBuf,
+     metadata: Option<HashMap<String, String>>,
+ ) -> RbResult<()> {
+     let tensors = prepare(&tensor_dict)?;
+     safetensors::tensor::serialize_to_file(&tensors, &metadata, filename.as_path())
+         .map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
+     Ok(())
+ }
+
+ fn deserialize(bytes: RString) -> RbResult<RArray> {
+     let safetensor = SafeTensors::deserialize(unsafe { bytes.as_slice() })
+         .map_err(|e| SafetensorError::new_err(format!("Error while deserializing: {e:?}")))?;
+
+     let tensors = safetensor.tensors();
+     let items = RArray::with_capacity(tensors.len());
+
+     for (tensor_name, tensor) in tensors {
+         let rbshape = RArray::from_vec(tensor.shape().to_vec());
+         let rbdtype = format!("{:?}", tensor.dtype());
+
+         let rbdata = RString::from_slice(tensor.data());
+
+         let map = RHash::new();
+         map.aset("shape", rbshape)?;
+         map.aset("dtype", rbdtype)?;
+         map.aset("data", rbdata)?;
+
+         items.push((tensor_name, map))?;
+     }
+     Ok(items)
+ }
+
+ #[derive(Debug, Clone, PartialEq, Eq)]
+ enum Framework {
+     Pytorch,
+     Numo,
+ }
+
+ impl TryConvert for Framework {
+     fn try_convert(ob: Value) -> RbResult<Self> {
+         let name: String = String::try_convert(ob)?;
+         match &name[..] {
+             "pt" => Ok(Framework::Pytorch),
+             "torch" => Ok(Framework::Pytorch),
+             "pytorch" => Ok(Framework::Pytorch),
+
+             "nm" => Ok(Framework::Numo),
+             "numo" => Ok(Framework::Numo),
+
+             name => Err(SafetensorError::new_err(format!(
+                 "framework {name} is invalid"
+             ))),
+         }
+     }
+ }
+
+ #[derive(Debug, Clone, PartialEq, Eq)]
+ enum Device {
+     Cpu,
+     Cuda(usize),
+     Mps,
+     Npu(usize),
+     Xpu(usize),
+ }
+
+ impl TryConvert for Device {
+     fn try_convert(ob: Value) -> RbResult<Self> {
+         if let Ok(name) = String::try_convert(ob) {
+             match &name[..] {
+                 "cpu" => Ok(Device::Cpu),
+                 "cuda" => Ok(Device::Cuda(0)),
+                 "mps" => Ok(Device::Mps),
+                 "npu" => Ok(Device::Npu(0)),
+                 "xpu" => Ok(Device::Xpu(0)),
+                 name if name.starts_with("cuda:") => {
+                     let tokens: Vec<_> = name.split(':').collect();
+                     if tokens.len() == 2 {
+                         let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
+                         Ok(Device::Cuda(device))
+                     } else {
+                         Err(SafetensorError::new_err(format!(
+                             "device {name} is invalid"
+                         )))
+                     }
+                 }
+                 name if name.starts_with("npu:") => {
+                     let tokens: Vec<_> = name.split(':').collect();
+                     if tokens.len() == 2 {
+                         let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
+                         Ok(Device::Npu(device))
+                     } else {
+                         Err(SafetensorError::new_err(format!(
+                             "device {name} is invalid"
+                         )))
+                     }
+                 }
+                 name if name.starts_with("xpu:") => {
+                     let tokens: Vec<_> = name.split(':').collect();
+                     if tokens.len() == 2 {
+                         let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
+                         Ok(Device::Xpu(device))
+                     } else {
+                         Err(SafetensorError::new_err(format!(
+                             "device {name} is invalid"
+                         )))
+                     }
+                 }
+                 name => Err(SafetensorError::new_err(format!(
+                     "device {name} is invalid"
+                 ))),
+             }
+         } else if let Ok(number) = usize::try_convert(ob) {
+             Ok(Device::Cuda(number))
+         } else {
+             Err(SafetensorError::new_err(format!("device {ob} is invalid")))
+         }
+     }
+ }
+
+ impl IntoValue for Device {
+     fn into_value_with(self, ruby: &Ruby) -> Value {
+         match self {
+             Device::Cpu => "cpu".into_value_with(ruby),
+             Device::Cuda(n) => format!("cuda:{n}").into_value_with(ruby),
+             Device::Mps => "mps".into_value_with(ruby),
+             Device::Npu(n) => format!("npu:{n}").into_value_with(ruby),
+             Device::Xpu(n) => format!("xpu:{n}").into_value_with(ruby),
+         }
+     }
+ }
+
+ enum Storage {
+     Mmap(Mmap),
+ }
+
+ struct Open {
+     metadata: Metadata,
+     offset: usize,
+     framework: Framework,
+     device: Device,
+     storage: Arc<Storage>,
+ }
+
+ impl Open {
+     fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> RbResult<Self> {
+         let file = File::open(&filename).map_err(|_| {
+             SafetensorError::new_err(format!("No such file or directory: {filename:?}"))
+         })?;
+         let device = device.unwrap_or(Device::Cpu);
+
+         if device != Device::Cpu && framework != Framework::Pytorch {
+             return Err(SafetensorError::new_err(format!(
+                 "Device {device:?} is not supported for framework {framework:?}",
+             )));
+         }
+
+         // SAFETY: Mmap is used to prevent allocating in Rust
+         // before making a copy within Ruby.
+         let buffer = unsafe { MmapOptions::new().map(&file).map_err(SafetensorError::io)? };
+
+         let (n, metadata) = SafeTensors::read_metadata(&buffer).map_err(|e| {
+             SafetensorError::new_err(format!("Error while deserializing header: {e:?}"))
+         })?;
+
+         let offset = n + 8;
+
+         let storage = Storage::Mmap(buffer);
+
+         let storage = Arc::new(storage);
+
+         Ok(Self {
+             metadata,
+             offset,
+             framework,
+             device,
+             storage,
+         })
+     }
+
+     pub fn metadata(&self) -> Option<HashMap<String, String>> {
+         self.metadata.metadata().clone()
+     }
+
+     pub fn keys(&self) -> RbResult<Vec<String>> {
+         let mut keys: Vec<String> = self.metadata.tensors().keys().cloned().collect();
+         keys.sort();
+         Ok(keys)
+     }
+
+     pub fn get_tensor(&self, name: &str) -> RbResult<Value> {
+         let info = self.metadata.info(name).ok_or_else(|| {
+             SafetensorError::new_err(format!("File does not contain tensor {name}",))
+         })?;
+
+         match &self.storage.as_ref() {
+             Storage::Mmap(mmap) => {
+                 let data =
+                     &mmap[info.data_offsets.0 + self.offset..info.data_offsets.1 + self.offset];
+
+                 let array: Value = RString::from_slice(data).into_value();
+
+                 create_tensor(
+                     &self.framework,
+                     info.dtype,
+                     &info.shape,
+                     array,
+                     &self.device,
+                 )
+             }
+         }
+     }
+ }
+
+ #[magnus::wrap(class = "Safetensors::SafeOpen")]
+ struct SafeOpen {
+     inner: Option<Open>,
+ }
+
+ impl SafeOpen {
+     fn inner(&self) -> RbResult<&Open> {
+         let inner = self
+             .inner
+             .as_ref()
+             .ok_or_else(|| SafetensorError::new_err("File is closed".to_string()))?;
+         Ok(inner)
+     }
+ }
+
+ impl SafeOpen {
+     pub fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> RbResult<Self> {
+         let inner = Some(Open::new(filename, framework, device)?);
+         Ok(Self { inner })
+     }
+
+     pub fn metadata(&self) -> RbResult<Option<HashMap<String, String>>> {
+         Ok(self.inner()?.metadata())
+     }
+
+     pub fn keys(&self) -> RbResult<Vec<String>> {
+         self.inner()?.keys()
+     }
+
+     pub fn get_tensor(&self, name: String) -> RbResult<Value> {
+         self.inner()?.get_tensor(&name)
+     }
+ }
+
+ fn create_tensor(
+     framework: &Framework,
+     dtype: Dtype,
+     shape: &[usize],
+     array: Value,
+     device: &Device,
+ ) -> RbResult<Value> {
+     let ruby = Ruby::get().unwrap();
+     let (module, is_numo): (RModule, bool) = match framework {
+         Framework::Pytorch => (
+             ruby.class_object()
+                 .const_get("Torch")
+                 .map_err(|_| SafetensorError::new_err("Torch not loaded".into()))?,
+             false,
+         ),
+         _ => (
+             ruby.class_object()
+                 .const_get("Numo")
+                 .map_err(|_| SafetensorError::new_err("Numo not loaded".into()))?,
+             true,
+         ),
+     };
+
+     let dtype = get_rbdtype(module, dtype, is_numo)?;
+     let shape = shape.to_vec();
+     let tensor: Value = match framework {
+         Framework::Pytorch => {
+             let options: Value = module.funcall(
+                 "tensor_options",
+                 (kwargs!("dtype" => dtype, "device" => device.clone()),),
+             )?;
+             module.funcall("_from_blob_ref", (array, shape, options))?
+         }
+         _ => {
+             let class: Value = module.funcall("const_get", (dtype,))?;
+             class.funcall("from_binary", (array, shape))?
+         }
+     };
+     Ok(tensor)
+ }
+
+ fn get_rbdtype(_module: RModule, dtype: Dtype, is_numo: bool) -> RbResult<Value> {
+     let dtype: Value = if is_numo {
+         match dtype {
+             Dtype::F64 => Symbol::new("DFloat").into_value(),
+             Dtype::F32 => Symbol::new("SFloat").into_value(),
+             Dtype::U64 => Symbol::new("UInt64").into_value(),
+             Dtype::I64 => Symbol::new("Int64").into_value(),
+             Dtype::U32 => Symbol::new("UInt32").into_value(),
+             Dtype::I32 => Symbol::new("Int32").into_value(),
+             Dtype::U16 => Symbol::new("UInt16").into_value(),
+             Dtype::I16 => Symbol::new("Int16").into_value(),
+             Dtype::U8 => Symbol::new("UInt8").into_value(),
+             Dtype::I8 => Symbol::new("Int8").into_value(),
+             dtype => {
+                 return Err(SafetensorError::new_err(format!(
+                     "Dtype not understood: {dtype:?}"
+                 )))
+             }
+         }
+     } else {
+         match dtype {
+             Dtype::F64 => Symbol::new("float64").into_value(),
+             Dtype::F32 => Symbol::new("float32").into_value(),
+             Dtype::BF16 => Symbol::new("bfloat16").into_value(),
+             Dtype::F16 => Symbol::new("float16").into_value(),
+             Dtype::U64 => Symbol::new("uint64").into_value(),
+             Dtype::I64 => Symbol::new("int64").into_value(),
+             Dtype::U32 => Symbol::new("uint32").into_value(),
+             Dtype::I32 => Symbol::new("int32").into_value(),
+             Dtype::U16 => Symbol::new("uint16").into_value(),
+             Dtype::I16 => Symbol::new("int16").into_value(),
+             Dtype::U8 => Symbol::new("uint8").into_value(),
+             Dtype::I8 => Symbol::new("int8").into_value(),
+             Dtype::BOOL => Symbol::new("bool").into_value(),
+             Dtype::F8_E4M3 => Symbol::new("float8_e4m3fn").into_value(),
+             Dtype::F8_E5M2 => Symbol::new("float8_e5m2").into_value(),
+             dtype => {
+                 return Err(SafetensorError::new_err(format!(
+                     "Dtype not understood: {dtype:?}"
+                 )))
+             }
+         }
+     };
+     Ok(dtype)
+ }
+
+ struct SafetensorError {}
+
+ impl SafetensorError {
+     fn new_err(message: String) -> Error {
+         let class = Ruby::get()
+             .unwrap()
+             .class_object()
+             .const_get::<_, RModule>("Safetensors")
+             .unwrap()
+             .const_get("Error")
+             .unwrap();
+         Error::new(class, message)
+     }
+
+     fn io(err: std::io::Error) -> Error {
+         Self::new_err(err.to_string())
+     }
+
+     fn parse(err: std::num::ParseIntError) -> Error {
+         Self::new_err(err.to_string())
+     }
+ }
+
+ #[magnus::init]
+ fn init(ruby: &Ruby) -> RbResult<()> {
+     let module = ruby.define_module("Safetensors")?;
+     module.define_singleton_method("_serialize", function!(serialize, 2))?;
+     module.define_singleton_method("_serialize_file", function!(serialize_file, 3))?;
+     module.define_singleton_method("deserialize", function!(deserialize, 1))?;
+
+     let class = module.define_class("SafeOpen", ruby.class_object())?;
+     class.define_singleton_method("new", function!(SafeOpen::new, 3))?;
+     class.define_method("metadata", method!(SafeOpen::metadata, 0))?;
+     class.define_method("keys", method!(SafeOpen::keys, 0))?;
+     class.define_method("get_tensor", method!(SafeOpen::get_tensor, 1))?;
+
+     Ok(())
+ }
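The `init` function above is the whole native surface: `Safetensors._serialize`, `Safetensors._serialize_file`, `Safetensors.deserialize`, and `Safetensors::SafeOpen`. Based on `prepare` and `deserialize`, the low-level calls work on plain hashes of dtype/shape/bytes; a rough sketch of that shape (the public `Safetensors.serialize` wrapper is defined in Ruby elsewhere in the gem and not shown in this diff):

```ruby
# Each tensor is described by a dtype name, a shape, and packed bytes.
tensor_dict = {
  "w" => {
    "dtype" => "float32",                      # one of the names accepted by prepare()
    "shape" => [2, 2],
    "data"  => [1.0, 2.0, 3.0, 4.0].pack("e*") # little-endian float32 bytes
  }
}

bytes = Safetensors._serialize(tensor_dict, nil) # second argument is the optional metadata hash

# deserialize returns [name, {"shape", "dtype", "data"}] pairs; the dtype comes
# back in the crate's Debug form, e.g. "F32".
Safetensors.deserialize(bytes)
# => [["w", {"shape" => [2, 2], "dtype" => "F32", "data" => "..."}]]
```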
@@ -0,0 +1,93 @@
+ module Safetensors
+   module Numo
+     DTYPES = {
+       "DFloat" => "float64",
+       "SFloat" => "float32"
+     }
+
+     TYPES = {
+       "F64" => :DFloat,
+       "F32" => :SFloat,
+       "I64" => :Int64,
+       "U64" => :UInt64,
+       "I32" => :Int32,
+       "U32" => :UInt32,
+       "I16" => :Int16,
+       "U16" => :UInt16,
+       "I8" => :Int8,
+       "U8" => :UInt8
+     }
+
+     class << self
+       def save(tensor_dict, metadata: nil)
+         Safetensors.serialize(_flatten(tensor_dict), metadata: metadata)
+       end
+
+       def save_file(tensor_dict, filename, metadata: nil)
+         Safetensors.serialize_file(_flatten(tensor_dict), filename, metadata: metadata)
+       end
+
+       def load(data)
+         flat = Safetensors.deserialize(data)
+         _view2numo(flat)
+       end
+
+       def load_file(filename)
+         result = {}
+         Safetensors.safe_open(filename, framework: "numo") do |f|
+           f.keys.each do |k|
+             result[k] = f.get_tensor(k)
+           end
+         end
+         result
+       end
+
+       private
+
+       def _flatten(tensors)
+         if !tensors.is_a?(Hash)
+           raise ArgumentError, "Expected a hash of [String, Numo::NArray] but received #{tensors.class.name}"
+         end
+
+         tensors.each do |k, v|
+           if !v.is_a?(::Numo::NArray)
+             raise ArgumentError, "Key `#{k}` is invalid, expected Numo::NArray but received #{v.class.name}"
+           end
+         end
+
+         tensors.to_h do |k, v|
+           [
+             k.is_a?(Symbol) ? k.to_s : k,
+             {
+               "dtype" => DTYPES.fetch(v.class.name.split("::").last),
+               "shape" => v.shape,
+               "data" => _tobytes(v)
+             }
+           ]
+         end
+       end
+
+       def _tobytes(tensor)
+         if Safetensors.big_endian?
+           raise "Not yet implemented"
+         end
+
+         tensor.to_binary
+       end
+
+       def _getdtype(dtype_str)
+         TYPES.fetch(dtype_str)
+       end
+
+       def _view2numo(safeview)
+         result = {}
+         safeview.each do |k, v|
+           dtype = _getdtype(v["dtype"])
+           arr = ::Numo.const_get(dtype).from_binary(v["data"], v["shape"])
+           result[k] = arr
+         end
+         result
+       end
+     end
+   end
+ end
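A short usage sketch for the Numo module above, covering both the file and in-memory round trips (file name is a placeholder). Note that `DTYPES` only maps `DFloat` and `SFloat`, so this version can serialize float64/float32 arrays only, even though `TYPES` lets it load the integer dtypes as well:

```ruby
require "numo/narray"
require "safetensors"

tensors = {
  "weight1" => Numo::SFloat.zeros(1024, 1024),
  "weight2" => Numo::DFloat.new(3, 3).seq
}

# file round trip
Safetensors::Numo.save_file(tensors, "model.safetensors", metadata: {"source" => "example"})
restored = Safetensors::Numo.load_file("model.safetensors")
restored["weight2"].shape # => [3, 3]

# in-memory round trip
data = Safetensors::Numo.save(tensors)
Safetensors::Numo.load(data)["weight1"].class # => Numo::SFloat
```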