safetensors 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/Cargo.lock +414 -0
- data/Cargo.toml +6 -0
- data/LICENSE.txt +201 -0
- data/README.md +67 -0
- data/ext/safetensors/Cargo.toml +17 -0
- data/ext/safetensors/extconf.rb +4 -0
- data/ext/safetensors/src/lib.rs +464 -0
- data/lib/safetensors/numo.rb +93 -0
- data/lib/safetensors/torch.rb +141 -0
- data/lib/safetensors/version.rb +3 -0
- data/lib/safetensors.rb +37 -0
- metadata +69 -0
data/README.md
ADDED
@@ -0,0 +1,67 @@
|
|
1
|
+
# Safetensors Ruby
|
2
|
+
|
3
|
+
:slightly_smiling_face: Simple, [safe way](https://github.com/huggingface/safetensors) to store and distribute tensors
|
4
|
+
|
5
|
+
Supports [Torch.rb](https://github.com/ankane/torch.rb) and [Numo](https://github.com/ruby-numo/numo-narray)
|
6
|
+
|
7
|
+
[![Build Status](https://github.com/ankane/safetensors-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/safetensors-ruby/actions)
|
8
|
+
|
9
|
+
## Installation
|
10
|
+
|
11
|
+
Add this line to your application’s Gemfile:
|
12
|
+
|
13
|
+
```ruby
|
14
|
+
gem "safetensors"
|
15
|
+
```
|
16
|
+
|
17
|
+
## Getting Started
|
18
|
+
|
19
|
+
Save tensors
|
20
|
+
|
21
|
+
```ruby
|
22
|
+
tensors = {
|
23
|
+
"weight1" => Torch.zeros([1024, 1024]),
|
24
|
+
"weight2" => Torch.zeros([1024, 1024])
|
25
|
+
}
|
26
|
+
Safetensors::Torch.save_file(tensors, "model.safetensors")
|
27
|
+
```
|
28
|
+
|
29
|
+
Load tensors
|
30
|
+
|
31
|
+
```ruby
|
32
|
+
tensors = Safetensors::Torch.load_file("model.safetensors")
|
33
|
+
# or
|
34
|
+
tensors = {}
|
35
|
+
Safetensors.safe_open("model.safetensors", framework: "torch", device: "cpu") do |f|
|
36
|
+
f.keys.each do |key|
|
37
|
+
tensors[key] = f.get_tensor(key)
|
38
|
+
end
|
39
|
+
end
|
40
|
+
```
|
41
|
+
|
42
|
+
## API
|
43
|
+
|
44
|
+
This library follows the [Safetensors Python API](https://huggingface.co/docs/safetensors/index). You can follow Python tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
|
45
|
+
|
46
|
+
## History
|
47
|
+
|
48
|
+
View the [changelog](https://github.com/ankane/safetensors-ruby/blob/master/CHANGELOG.md)
|
49
|
+
|
50
|
+
## Contributing
|
51
|
+
|
52
|
+
Everyone is encouraged to help improve this project. Here are a few ways you can help:
|
53
|
+
|
54
|
+
- [Report bugs](https://github.com/ankane/safetensors-ruby/issues)
|
55
|
+
- Fix bugs and [submit pull requests](https://github.com/ankane/safetensors-ruby/pulls)
|
56
|
+
- Write, clarify, or fix documentation
|
57
|
+
- Suggest or add new features
|
58
|
+
|
59
|
+
To get started with development:
|
60
|
+
|
61
|
+
```sh
|
62
|
+
git clone https://github.com/ankane/safetensors-ruby.git
|
63
|
+
cd safetensors-ruby
|
64
|
+
bundle install
|
65
|
+
bundle exec rake compile
|
66
|
+
bundle exec rake test
|
67
|
+
```
|
@@ -0,0 +1,17 @@
|
|
1
|
+
[package]
|
2
|
+
name = "safetensors"
|
3
|
+
version = "0.1.0"
|
4
|
+
license = "Apache-2.0"
|
5
|
+
authors = ["Andrew Kane <andrew@ankane.org>"]
|
6
|
+
edition = "2021"
|
7
|
+
rust-version = "1.62.0"
|
8
|
+
publish = false
|
9
|
+
|
10
|
+
[lib]
|
11
|
+
crate-type = ["cdylib"]
|
12
|
+
|
13
|
+
[dependencies]
|
14
|
+
magnus = "0.6"
|
15
|
+
memmap2 = "0.5"
|
16
|
+
safetensors = "=0.4.2"
|
17
|
+
serde_json = "1.0"
|
@@ -0,0 +1,464 @@
|
|
1
|
+
use magnus::{
|
2
|
+
function, kwargs, method, prelude::*, r_hash::ForEach, Error, IntoValue, RArray, RHash,
|
3
|
+
RModule, RString, Ruby, Symbol, TryConvert, Value,
|
4
|
+
};
|
5
|
+
use memmap2::{Mmap, MmapOptions};
|
6
|
+
use safetensors::tensor::{Dtype, Metadata, SafeTensors, TensorView};
|
7
|
+
use std::collections::HashMap;
|
8
|
+
use std::fs::File;
|
9
|
+
use std::path::PathBuf;
|
10
|
+
use std::sync::Arc;
|
11
|
+
|
12
|
+
/// Shorthand for results whose error side is a Ruby exception (`magnus::Error`).
type RbResult<T> = Result<T, Error>;
|
13
|
+
|
14
|
+
fn prepare(tensor_dict: &RHash) -> RbResult<HashMap<String, TensorView<'_>>> {
|
15
|
+
let mut tensors = HashMap::with_capacity(tensor_dict.len());
|
16
|
+
tensor_dict.foreach(|tensor_name: String, tensor_desc: RHash| {
|
17
|
+
let mut shape: Option<Vec<usize>> = None;
|
18
|
+
let mut dtype: Option<Dtype> = None;
|
19
|
+
let mut data: Option<(*const u8, usize)> = None;
|
20
|
+
|
21
|
+
tensor_desc.foreach(|key: String, value: Value| {
|
22
|
+
match key.as_str() {
|
23
|
+
"shape" => shape = Some(Vec::try_convert(value)?),
|
24
|
+
"dtype" => {
|
25
|
+
let value = String::try_convert(value)?;
|
26
|
+
dtype = match value.as_str() {
|
27
|
+
"bool" => Some(Dtype::BOOL),
|
28
|
+
"int8" => Some(Dtype::I8),
|
29
|
+
"uint8" => Some(Dtype::U8),
|
30
|
+
"int16" => Some(Dtype::I16),
|
31
|
+
"uint16" => Some(Dtype::U16),
|
32
|
+
"int32" => Some(Dtype::I32),
|
33
|
+
"uint32" => Some(Dtype::U32),
|
34
|
+
"int64" => Some(Dtype::I64),
|
35
|
+
"uint64" => Some(Dtype::U64),
|
36
|
+
"float16" => Some(Dtype::F16),
|
37
|
+
"float32" => Some(Dtype::F32),
|
38
|
+
"float64" => Some(Dtype::F64),
|
39
|
+
"bfloat16" => Some(Dtype::BF16),
|
40
|
+
"float8_e4m3fn" => Some(Dtype::F8_E4M3),
|
41
|
+
"float8_e5m2" => Some(Dtype::F8_E5M2),
|
42
|
+
dtype_str => {
|
43
|
+
return Err(SafetensorError::new_err(format!(
|
44
|
+
"dtype {dtype_str} is not covered",
|
45
|
+
)));
|
46
|
+
}
|
47
|
+
}
|
48
|
+
}
|
49
|
+
"data" => {
|
50
|
+
let rs = RString::try_convert(value)?;
|
51
|
+
// SAFETY: No context switching between threads in native extensions
|
52
|
+
// so the string will not be modified (or garbage collected)
|
53
|
+
// while the reference is held. Also, the string is a private copy.
|
54
|
+
let slice = unsafe { rs.as_slice() };
|
55
|
+
data = Some((slice.as_ptr(), slice.len()))
|
56
|
+
}
|
57
|
+
_ => println!("Ignored unknown kwarg option {key}"),
|
58
|
+
};
|
59
|
+
|
60
|
+
Ok(ForEach::Continue)
|
61
|
+
})?;
|
62
|
+
let shape = shape.ok_or_else(|| {
|
63
|
+
SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}"))
|
64
|
+
})?;
|
65
|
+
let dtype = dtype.ok_or_else(|| {
|
66
|
+
SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}"))
|
67
|
+
})?;
|
68
|
+
let data = data.ok_or_else(|| {
|
69
|
+
SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}"))
|
70
|
+
})?;
|
71
|
+
// SAFETY: See comment above.
|
72
|
+
let data = unsafe { std::slice::from_raw_parts(data.0, data.1) };
|
73
|
+
let tensor = TensorView::new(dtype, shape, data)
|
74
|
+
.map_err(|e| SafetensorError::new_err(format!("Error preparing tensor view: {e:?}")))?;
|
75
|
+
tensors.insert(tensor_name, tensor);
|
76
|
+
|
77
|
+
Ok(ForEach::Continue)
|
78
|
+
})?;
|
79
|
+
Ok(tensors)
|
80
|
+
}
|
81
|
+
|
82
|
+
fn serialize(tensor_dict: RHash, metadata: Option<HashMap<String, String>>) -> RbResult<RString> {
|
83
|
+
let tensors = prepare(&tensor_dict)?;
|
84
|
+
let metadata_map = metadata.map(HashMap::from_iter);
|
85
|
+
let out = safetensors::tensor::serialize(&tensors, &metadata_map)
|
86
|
+
.map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
|
87
|
+
let rbbytes = RString::from_slice(&out);
|
88
|
+
Ok(rbbytes)
|
89
|
+
}
|
90
|
+
|
91
|
+
fn serialize_file(
|
92
|
+
tensor_dict: RHash,
|
93
|
+
filename: PathBuf,
|
94
|
+
metadata: Option<HashMap<String, String>>,
|
95
|
+
) -> RbResult<()> {
|
96
|
+
let tensors = prepare(&tensor_dict)?;
|
97
|
+
safetensors::tensor::serialize_to_file(&tensors, &metadata, filename.as_path())
|
98
|
+
.map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
|
99
|
+
Ok(())
|
100
|
+
}
|
101
|
+
|
102
|
+
fn deserialize(bytes: RString) -> RbResult<RArray> {
|
103
|
+
let safetensor = SafeTensors::deserialize(unsafe { bytes.as_slice() })
|
104
|
+
.map_err(|e| SafetensorError::new_err(format!("Error while deserializing: {e:?}")))?;
|
105
|
+
|
106
|
+
let tensors = safetensor.tensors();
|
107
|
+
let items = RArray::with_capacity(tensors.len());
|
108
|
+
|
109
|
+
for (tensor_name, tensor) in tensors {
|
110
|
+
let rbshape = RArray::from_vec(tensor.shape().to_vec());
|
111
|
+
let rbdtype = format!("{:?}", tensor.dtype());
|
112
|
+
|
113
|
+
let rbdata = RString::from_slice(tensor.data());
|
114
|
+
|
115
|
+
let map = RHash::new();
|
116
|
+
map.aset("shape", rbshape)?;
|
117
|
+
map.aset("dtype", rbdtype)?;
|
118
|
+
map.aset("data", rbdata)?;
|
119
|
+
|
120
|
+
items.push((tensor_name, map))?;
|
121
|
+
}
|
122
|
+
Ok(items)
|
123
|
+
}
|
124
|
+
|
125
|
+
/// Array library used to materialize tensors returned by `SafeOpen#get_tensor`.
#[derive(Clone, Debug, Eq, PartialEq)]
enum Framework {
    /// Torch.rb; selected by `"pt"`, `"torch"` or `"pytorch"`.
    Pytorch,
    /// Numo::NArray; selected by `"nm"` or `"numo"`.
    Numo,
}
|
130
|
+
|
131
|
+
impl TryConvert for Framework {
|
132
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
133
|
+
let name: String = String::try_convert(ob)?;
|
134
|
+
match &name[..] {
|
135
|
+
"pt" => Ok(Framework::Pytorch),
|
136
|
+
"torch" => Ok(Framework::Pytorch),
|
137
|
+
"pytorch" => Ok(Framework::Pytorch),
|
138
|
+
|
139
|
+
"nm" => Ok(Framework::Numo),
|
140
|
+
"numo" => Ok(Framework::Numo),
|
141
|
+
|
142
|
+
name => Err(SafetensorError::new_err(format!(
|
143
|
+
"framework {name} is invalid"
|
144
|
+
))),
|
145
|
+
}
|
146
|
+
}
|
147
|
+
}
|
148
|
+
|
149
|
+
/// Target device for returned tensors; indexed variants carry the device ordinal.
#[derive(Clone, Debug, Eq, PartialEq)]
enum Device {
    /// `"cpu"` (the default).
    Cpu,
    /// `"cuda"` / `"cuda:N"`, or a bare Integer `N`.
    Cuda(usize),
    /// `"mps"`.
    Mps,
    /// `"npu"` / `"npu:N"`.
    Npu(usize),
    /// `"xpu"` / `"xpu:N"`.
    Xpu(usize),
}
|
157
|
+
|
158
|
+
impl TryConvert for Device {
|
159
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
160
|
+
if let Ok(name) = String::try_convert(ob) {
|
161
|
+
match &name[..] {
|
162
|
+
"cpu" => Ok(Device::Cpu),
|
163
|
+
"cuda" => Ok(Device::Cuda(0)),
|
164
|
+
"mps" => Ok(Device::Mps),
|
165
|
+
"npu" => Ok(Device::Npu(0)),
|
166
|
+
"xpu" => Ok(Device::Xpu(0)),
|
167
|
+
name if name.starts_with("cuda:") => {
|
168
|
+
let tokens: Vec<_> = name.split(':').collect();
|
169
|
+
if tokens.len() == 2 {
|
170
|
+
let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
|
171
|
+
Ok(Device::Cuda(device))
|
172
|
+
} else {
|
173
|
+
Err(SafetensorError::new_err(format!(
|
174
|
+
"device {name} is invalid"
|
175
|
+
)))
|
176
|
+
}
|
177
|
+
}
|
178
|
+
name if name.starts_with("npu:") => {
|
179
|
+
let tokens: Vec<_> = name.split(':').collect();
|
180
|
+
if tokens.len() == 2 {
|
181
|
+
let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
|
182
|
+
Ok(Device::Npu(device))
|
183
|
+
} else {
|
184
|
+
Err(SafetensorError::new_err(format!(
|
185
|
+
"device {name} is invalid"
|
186
|
+
)))
|
187
|
+
}
|
188
|
+
}
|
189
|
+
name if name.starts_with("xpu:") => {
|
190
|
+
let tokens: Vec<_> = name.split(':').collect();
|
191
|
+
if tokens.len() == 2 {
|
192
|
+
let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
|
193
|
+
Ok(Device::Xpu(device))
|
194
|
+
} else {
|
195
|
+
Err(SafetensorError::new_err(format!(
|
196
|
+
"device {name} is invalid"
|
197
|
+
)))
|
198
|
+
}
|
199
|
+
}
|
200
|
+
name => Err(SafetensorError::new_err(format!(
|
201
|
+
"device {name} is invalid"
|
202
|
+
))),
|
203
|
+
}
|
204
|
+
} else if let Ok(number) = usize::try_convert(ob) {
|
205
|
+
Ok(Device::Cuda(number))
|
206
|
+
} else {
|
207
|
+
Err(SafetensorError::new_err(format!("device {ob} is invalid")))
|
208
|
+
}
|
209
|
+
}
|
210
|
+
}
|
211
|
+
|
212
|
+
impl IntoValue for Device {
|
213
|
+
fn into_value_with(self, ruby: &Ruby) -> Value {
|
214
|
+
match self {
|
215
|
+
Device::Cpu => "cpu".into_value_with(ruby),
|
216
|
+
Device::Cuda(n) => format!("cuda:{n}").into_value_with(ruby),
|
217
|
+
Device::Mps => "mps".into_value_with(ruby),
|
218
|
+
Device::Npu(n) => format!("npu:{n}").into_value_with(ruby),
|
219
|
+
Device::Xpu(n) => format!("xpu:{n}").into_value_with(ruby),
|
220
|
+
}
|
221
|
+
}
|
222
|
+
}
|
223
|
+
|
224
|
+
enum Storage {
|
225
|
+
Mmap(Mmap),
|
226
|
+
}
|
227
|
+
|
228
|
+
struct Open {
|
229
|
+
metadata: Metadata,
|
230
|
+
offset: usize,
|
231
|
+
framework: Framework,
|
232
|
+
device: Device,
|
233
|
+
storage: Arc<Storage>,
|
234
|
+
}
|
235
|
+
|
236
|
+
impl Open {
|
237
|
+
fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> RbResult<Self> {
|
238
|
+
let file = File::open(&filename).map_err(|_| {
|
239
|
+
SafetensorError::new_err(format!("No such file or directory: {filename:?}"))
|
240
|
+
})?;
|
241
|
+
let device = device.unwrap_or(Device::Cpu);
|
242
|
+
|
243
|
+
if device != Device::Cpu && framework != Framework::Pytorch {
|
244
|
+
return Err(SafetensorError::new_err(format!(
|
245
|
+
"Device {device:?} is not support for framework {framework:?}",
|
246
|
+
)));
|
247
|
+
}
|
248
|
+
|
249
|
+
// SAFETY: Mmap is used to prevent allocating in Rust
|
250
|
+
// before making a copy within Ruby.
|
251
|
+
let buffer = unsafe { MmapOptions::new().map(&file).map_err(SafetensorError::io)? };
|
252
|
+
|
253
|
+
let (n, metadata) = SafeTensors::read_metadata(&buffer).map_err(|e| {
|
254
|
+
SafetensorError::new_err(format!("Error while deserializing header: {e:?}"))
|
255
|
+
})?;
|
256
|
+
|
257
|
+
let offset = n + 8;
|
258
|
+
|
259
|
+
let storage = Storage::Mmap(buffer);
|
260
|
+
|
261
|
+
let storage = Arc::new(storage);
|
262
|
+
|
263
|
+
Ok(Self {
|
264
|
+
metadata,
|
265
|
+
offset,
|
266
|
+
framework,
|
267
|
+
device,
|
268
|
+
storage,
|
269
|
+
})
|
270
|
+
}
|
271
|
+
|
272
|
+
pub fn metadata(&self) -> Option<HashMap<String, String>> {
|
273
|
+
self.metadata.metadata().clone()
|
274
|
+
}
|
275
|
+
|
276
|
+
pub fn keys(&self) -> RbResult<Vec<String>> {
|
277
|
+
let mut keys: Vec<String> = self.metadata.tensors().keys().cloned().collect();
|
278
|
+
keys.sort();
|
279
|
+
Ok(keys)
|
280
|
+
}
|
281
|
+
|
282
|
+
pub fn get_tensor(&self, name: &str) -> RbResult<Value> {
|
283
|
+
let info = self.metadata.info(name).ok_or_else(|| {
|
284
|
+
SafetensorError::new_err(format!("File does not contain tensor {name}",))
|
285
|
+
})?;
|
286
|
+
|
287
|
+
match &self.storage.as_ref() {
|
288
|
+
Storage::Mmap(mmap) => {
|
289
|
+
let data =
|
290
|
+
&mmap[info.data_offsets.0 + self.offset..info.data_offsets.1 + self.offset];
|
291
|
+
|
292
|
+
let array: Value = RString::from_slice(data).into_value();
|
293
|
+
|
294
|
+
create_tensor(
|
295
|
+
&self.framework,
|
296
|
+
info.dtype,
|
297
|
+
&info.shape,
|
298
|
+
array,
|
299
|
+
&self.device,
|
300
|
+
)
|
301
|
+
}
|
302
|
+
}
|
303
|
+
}
|
304
|
+
}
|
305
|
+
|
306
|
+
#[magnus::wrap(class = "Safetensors::SafeOpen")]
|
307
|
+
struct SafeOpen {
|
308
|
+
inner: Option<Open>,
|
309
|
+
}
|
310
|
+
|
311
|
+
impl SafeOpen {
|
312
|
+
fn inner(&self) -> RbResult<&Open> {
|
313
|
+
let inner = self
|
314
|
+
.inner
|
315
|
+
.as_ref()
|
316
|
+
.ok_or_else(|| SafetensorError::new_err("File is closed".to_string()))?;
|
317
|
+
Ok(inner)
|
318
|
+
}
|
319
|
+
}
|
320
|
+
|
321
|
+
impl SafeOpen {
|
322
|
+
pub fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> RbResult<Self> {
|
323
|
+
let inner = Some(Open::new(filename, framework, device)?);
|
324
|
+
Ok(Self { inner })
|
325
|
+
}
|
326
|
+
|
327
|
+
pub fn metadata(&self) -> RbResult<Option<HashMap<String, String>>> {
|
328
|
+
Ok(self.inner()?.metadata())
|
329
|
+
}
|
330
|
+
|
331
|
+
pub fn keys(&self) -> RbResult<Vec<String>> {
|
332
|
+
self.inner()?.keys()
|
333
|
+
}
|
334
|
+
|
335
|
+
pub fn get_tensor(&self, name: String) -> RbResult<Value> {
|
336
|
+
self.inner()?.get_tensor(&name)
|
337
|
+
}
|
338
|
+
}
|
339
|
+
|
340
|
+
fn create_tensor(
|
341
|
+
framework: &Framework,
|
342
|
+
dtype: Dtype,
|
343
|
+
shape: &[usize],
|
344
|
+
array: Value,
|
345
|
+
device: &Device,
|
346
|
+
) -> RbResult<Value> {
|
347
|
+
let ruby = Ruby::get().unwrap();
|
348
|
+
let (module, is_numo): (RModule, bool) = match framework {
|
349
|
+
Framework::Pytorch => (
|
350
|
+
ruby.class_object()
|
351
|
+
.const_get("Torch")
|
352
|
+
.map_err(|_| SafetensorError::new_err("Torch not loaded".into()))?,
|
353
|
+
false,
|
354
|
+
),
|
355
|
+
_ => (
|
356
|
+
ruby.class_object()
|
357
|
+
.const_get("Numo")
|
358
|
+
.map_err(|_| SafetensorError::new_err("Numo not loaded".into()))?,
|
359
|
+
true,
|
360
|
+
),
|
361
|
+
};
|
362
|
+
|
363
|
+
let dtype = get_rbdtype(module, dtype, is_numo)?;
|
364
|
+
let shape = shape.to_vec();
|
365
|
+
let tensor: Value = match framework {
|
366
|
+
Framework::Pytorch => {
|
367
|
+
let options: Value = module.funcall(
|
368
|
+
"tensor_options",
|
369
|
+
(kwargs!("dtype" => dtype, "device" => device.clone()),),
|
370
|
+
)?;
|
371
|
+
module.funcall("_from_blob_ref", (array, shape, options))?
|
372
|
+
}
|
373
|
+
_ => {
|
374
|
+
let class: Value = module.funcall("const_get", (dtype,))?;
|
375
|
+
class.funcall("from_binary", (array, shape))?
|
376
|
+
}
|
377
|
+
};
|
378
|
+
Ok(tensor)
|
379
|
+
}
|
380
|
+
|
381
|
+
fn get_rbdtype(_module: RModule, dtype: Dtype, is_numo: bool) -> RbResult<Value> {
|
382
|
+
let dtype: Value = if is_numo {
|
383
|
+
match dtype {
|
384
|
+
Dtype::F64 => Symbol::new("DFloat").into_value(),
|
385
|
+
Dtype::F32 => Symbol::new("SFloat").into_value(),
|
386
|
+
Dtype::U64 => Symbol::new("UInt64").into_value(),
|
387
|
+
Dtype::I64 => Symbol::new("Int64").into_value(),
|
388
|
+
Dtype::U32 => Symbol::new("UInt32").into_value(),
|
389
|
+
Dtype::I32 => Symbol::new("Int32").into_value(),
|
390
|
+
Dtype::U16 => Symbol::new("UInt16").into_value(),
|
391
|
+
Dtype::I16 => Symbol::new("Int16").into_value(),
|
392
|
+
Dtype::U8 => Symbol::new("UInt8").into_value(),
|
393
|
+
Dtype::I8 => Symbol::new("Int8").into_value(),
|
394
|
+
dtype => {
|
395
|
+
return Err(SafetensorError::new_err(format!(
|
396
|
+
"Dtype not understood: {dtype:?}"
|
397
|
+
)))
|
398
|
+
}
|
399
|
+
}
|
400
|
+
} else {
|
401
|
+
match dtype {
|
402
|
+
Dtype::F64 => Symbol::new("float64").into_value(),
|
403
|
+
Dtype::F32 => Symbol::new("float32").into_value(),
|
404
|
+
Dtype::BF16 => Symbol::new("bfloat16").into_value(),
|
405
|
+
Dtype::F16 => Symbol::new("float16").into_value(),
|
406
|
+
Dtype::U64 => Symbol::new("uint64").into_value(),
|
407
|
+
Dtype::I64 => Symbol::new("int64").into_value(),
|
408
|
+
Dtype::U32 => Symbol::new("uint32").into_value(),
|
409
|
+
Dtype::I32 => Symbol::new("int32").into_value(),
|
410
|
+
Dtype::U16 => Symbol::new("uint16").into_value(),
|
411
|
+
Dtype::I16 => Symbol::new("int16").into_value(),
|
412
|
+
Dtype::U8 => Symbol::new("uint8").into_value(),
|
413
|
+
Dtype::I8 => Symbol::new("int8").into_value(),
|
414
|
+
Dtype::BOOL => Symbol::new("bool").into_value(),
|
415
|
+
Dtype::F8_E4M3 => Symbol::new("float8_e4m3fn").into_value(),
|
416
|
+
Dtype::F8_E5M2 => Symbol::new("float8_e5m2").into_value(),
|
417
|
+
dtype => {
|
418
|
+
return Err(SafetensorError::new_err(format!(
|
419
|
+
"Dtype not understood: {dtype:?}"
|
420
|
+
)))
|
421
|
+
}
|
422
|
+
}
|
423
|
+
};
|
424
|
+
Ok(dtype)
|
425
|
+
}
|
426
|
+
|
427
|
+
struct SafetensorError {}
|
428
|
+
|
429
|
+
impl SafetensorError {
|
430
|
+
fn new_err(message: String) -> Error {
|
431
|
+
let class = Ruby::get()
|
432
|
+
.unwrap()
|
433
|
+
.class_object()
|
434
|
+
.const_get::<_, RModule>("Safetensors")
|
435
|
+
.unwrap()
|
436
|
+
.const_get("Error")
|
437
|
+
.unwrap();
|
438
|
+
Error::new(class, message)
|
439
|
+
}
|
440
|
+
|
441
|
+
fn io(err: std::io::Error) -> Error {
|
442
|
+
Self::new_err(err.to_string())
|
443
|
+
}
|
444
|
+
|
445
|
+
fn parse(err: std::num::ParseIntError) -> Error {
|
446
|
+
Self::new_err(err.to_string())
|
447
|
+
}
|
448
|
+
}
|
449
|
+
|
450
|
+
#[magnus::init]
|
451
|
+
fn init(ruby: &Ruby) -> RbResult<()> {
|
452
|
+
let module = ruby.define_module("Safetensors")?;
|
453
|
+
module.define_singleton_method("_serialize", function!(serialize, 2))?;
|
454
|
+
module.define_singleton_method("_serialize_file", function!(serialize_file, 3))?;
|
455
|
+
module.define_singleton_method("deserialize", function!(deserialize, 1))?;
|
456
|
+
|
457
|
+
let class = module.define_class("SafeOpen", ruby.class_object())?;
|
458
|
+
class.define_singleton_method("new", function!(SafeOpen::new, 3))?;
|
459
|
+
class.define_method("metadata", method!(SafeOpen::metadata, 0))?;
|
460
|
+
class.define_method("keys", method!(SafeOpen::keys, 0))?;
|
461
|
+
class.define_method("get_tensor", method!(SafeOpen::get_tensor, 1))?;
|
462
|
+
|
463
|
+
Ok(())
|
464
|
+
}
|
@@ -0,0 +1,93 @@
|
|
1
|
+
module Safetensors
  # Save/load helpers for Numo::NArray tensors.
  module Numo
    # Numo::NArray class name => dtype string sent to the native serializer.
    # This covers every class TYPES can produce on load, so any array loaded
    # from a file can be saved back.
    DTYPES = {
      "DFloat" => "float64",
      "SFloat" => "float32",
      "Int64" => "int64",
      "UInt64" => "uint64",
      "Int32" => "int32",
      "UInt32" => "uint32",
      "Int16" => "int16",
      "UInt16" => "uint16",
      "Int8" => "int8",
      "UInt8" => "uint8"
    }

    # Dtype name returned by Safetensors.deserialize => Numo class name (Symbol).
    TYPES = {
      "F64" => :DFloat,
      "F32" => :SFloat,
      "I64" => :Int64,
      "U64" => :UInt64,
      "I32" => :Int32,
      "U32" => :UInt32,
      "I16" => :Int16,
      "U16" => :UInt16,
      "I8" => :Int8,
      "U8" => :UInt8
    }

    class << self
      # Serializes a Hash of name => Numo::NArray into safetensors bytes.
      def save(tensor_dict, metadata: nil)
        Safetensors.serialize(_flatten(tensor_dict), metadata: metadata)
      end

      # Serializes a Hash of name => Numo::NArray into +filename+.
      def save_file(tensor_dict, filename, metadata: nil)
        Safetensors.serialize_file(_flatten(tensor_dict), filename, metadata: metadata)
      end

      # Parses safetensors bytes into a Hash of name => Numo::NArray.
      def load(data)
        flat = Safetensors.deserialize(data)
        _view2numo(flat)
      end

      # Reads every tensor in +filename+ into a Hash of name => Numo::NArray.
      def load_file(filename)
        result = {}
        Safetensors.safe_open(filename, framework: "numo") do |f|
          f.keys.each do |k|
            result[k] = f.get_tensor(k)
          end
        end
        result
      end

      private

      # Validates the input and converts it into the Hash layout expected by the
      # native serializer: name => {"dtype", "shape", "data"}. Symbol keys are
      # turned into Strings.
      # Raises ArgumentError for non-Hash input or non-NArray values, and
      # KeyError for NArray classes missing from DTYPES.
      def _flatten(tensors)
        if !tensors.is_a?(Hash)
          raise ArgumentError, "Expected a hash of [String, Numo::NArray] but received #{tensors.class.name}"
        end

        tensors.each do |k, v|
          if !v.is_a?(::Numo::NArray)
            raise ArgumentError, "Key `#{k}` is invalid, expected Numo::NArray but received #{v.class.name}"
          end
        end

        tensors.to_h do |k, v|
          [
            k.is_a?(Symbol) ? k.to_s : k,
            {
              "dtype" => DTYPES.fetch(v.class.name.split("::").last),
              "shape" => v.shape,
              "data" => _tobytes(v)
            }
          ]
        end
      end

      # Raw bytes of +tensor+. Big-endian hosts are rejected because the bytes
      # would need swapping first.
      def _tobytes(tensor)
        if Safetensors.big_endian?
          raise "Not yet implemented"
        end

        tensor.to_binary
      end

      # Maps a deserialized dtype name (e.g. "F32") to a Numo class name Symbol.
      def _getdtype(dtype_str)
        TYPES.fetch(dtype_str)
      end

      # Builds a Hash of name => Numo::NArray from Safetensors.deserialize output.
      def _view2numo(safeview)
        result = {}
        safeview.each do |k, v|
          dtype = _getdtype(v["dtype"])
          arr = ::Numo.const_get(dtype).from_binary(v["data"], v["shape"])
          result[k] = arr
        end
        result
      end
    end
  end
end
|