safetensors 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/Cargo.lock +414 -0
- data/Cargo.toml +6 -0
- data/LICENSE.txt +201 -0
- data/README.md +67 -0
- data/ext/safetensors/Cargo.toml +17 -0
- data/ext/safetensors/extconf.rb +4 -0
- data/ext/safetensors/src/lib.rs +464 -0
- data/lib/safetensors/numo.rb +93 -0
- data/lib/safetensors/torch.rb +141 -0
- data/lib/safetensors/version.rb +3 -0
- data/lib/safetensors.rb +37 -0
- metadata +69 -0
data/README.md
ADDED
@@ -0,0 +1,67 @@
|
|
1
|
+
# Safetensors Ruby
|
2
|
+
|
3
|
+
:slightly_smiling_face: Simple, [safe way](https://github.com/huggingface/safetensors) to store and distribute tensors
|
4
|
+
|
5
|
+
Supports [Torch.rb](https://github.com/ankane/torch.rb) and [Numo](https://github.com/ruby-numo/numo-narray)
|
6
|
+
|
7
|
+
[Build Status](https://github.com/ankane/safetensors-ruby/actions)
|
8
|
+
|
9
|
+
## Installation
|
10
|
+
|
11
|
+
Add this line to your application’s Gemfile:
|
12
|
+
|
13
|
+
```ruby
|
14
|
+
gem "safetensors"
|
15
|
+
```
|
16
|
+
|
17
|
+
## Getting Started
|
18
|
+
|
19
|
+
Save tensors
|
20
|
+
|
21
|
+
```ruby
|
22
|
+
tensors = {
|
23
|
+
"weight1" => Torch.zeros([1024, 1024]),
|
24
|
+
"weight2" => Torch.zeros([1024, 1024])
|
25
|
+
}
|
26
|
+
Safetensors::Torch.save_file(tensors, "model.safetensors")
|
27
|
+
```
|
28
|
+
|
29
|
+
Load tensors
|
30
|
+
|
31
|
+
```ruby
|
32
|
+
tensors = Safetensors::Torch.load_file("model.safetensors")
|
33
|
+
# or
|
34
|
+
tensors = {}
|
35
|
+
Safetensors.safe_open("model.safetensors", framework: "torch", device: "cpu") do |f|
|
36
|
+
f.keys.each do |key|
|
37
|
+
tensors[key] = f.get_tensor(key)
|
38
|
+
end
|
39
|
+
end
|
40
|
+
```
|
41
|
+
|
42
|
+
## API
|
43
|
+
|
44
|
+
This library follows the [Safetensors Python API](https://huggingface.co/docs/safetensors/index). You can follow Python tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
|
45
|
+
|
46
|
+
## History
|
47
|
+
|
48
|
+
View the [changelog](https://github.com/ankane/safetensors-ruby/blob/master/CHANGELOG.md)
|
49
|
+
|
50
|
+
## Contributing
|
51
|
+
|
52
|
+
Everyone is encouraged to help improve this project. Here are a few ways you can help:
|
53
|
+
|
54
|
+
- [Report bugs](https://github.com/ankane/safetensors-ruby/issues)
|
55
|
+
- Fix bugs and [submit pull requests](https://github.com/ankane/safetensors-ruby/pulls)
|
56
|
+
- Write, clarify, or fix documentation
|
57
|
+
- Suggest or add new features
|
58
|
+
|
59
|
+
To get started with development:
|
60
|
+
|
61
|
+
```sh
|
62
|
+
git clone https://github.com/ankane/safetensors-ruby.git
|
63
|
+
cd safetensors-ruby
|
64
|
+
bundle install
|
65
|
+
bundle exec rake compile
|
66
|
+
bundle exec rake test
|
67
|
+
```
|
@@ -0,0 +1,17 @@
|
|
1
|
+
[package]
|
2
|
+
name = "safetensors"
|
3
|
+
version = "0.1.0"
|
4
|
+
license = "Apache-2.0"
|
5
|
+
authors = ["Andrew Kane <andrew@ankane.org>"]
|
6
|
+
edition = "2021"
|
7
|
+
rust-version = "1.62.0"
|
8
|
+
publish = false
|
9
|
+
|
10
|
+
[lib]
|
11
|
+
crate-type = ["cdylib"]
|
12
|
+
|
13
|
+
[dependencies]
|
14
|
+
magnus = "0.6"
|
15
|
+
memmap2 = "0.5"
|
16
|
+
safetensors = "=0.4.2"
|
17
|
+
serde_json = "1.0"
|
@@ -0,0 +1,464 @@
|
|
1
|
+
use magnus::{
|
2
|
+
function, kwargs, method, prelude::*, r_hash::ForEach, Error, IntoValue, RArray, RHash,
|
3
|
+
RModule, RString, Ruby, Symbol, TryConvert, Value,
|
4
|
+
};
|
5
|
+
use memmap2::{Mmap, MmapOptions};
|
6
|
+
use safetensors::tensor::{Dtype, Metadata, SafeTensors, TensorView};
|
7
|
+
use std::collections::HashMap;
|
8
|
+
use std::fs::File;
|
9
|
+
use std::path::PathBuf;
|
10
|
+
use std::sync::Arc;
|
11
|
+
|
12
|
+
/// Result alias for functions exposed to Ruby; the error type is `magnus::Error`
/// (a Ruby exception raised when the function returns `Err`).
type RbResult<T> = Result<T, Error>;
|
13
|
+
|
14
|
+
fn prepare(tensor_dict: &RHash) -> RbResult<HashMap<String, TensorView<'_>>> {
|
15
|
+
let mut tensors = HashMap::with_capacity(tensor_dict.len());
|
16
|
+
tensor_dict.foreach(|tensor_name: String, tensor_desc: RHash| {
|
17
|
+
let mut shape: Option<Vec<usize>> = None;
|
18
|
+
let mut dtype: Option<Dtype> = None;
|
19
|
+
let mut data: Option<(*const u8, usize)> = None;
|
20
|
+
|
21
|
+
tensor_desc.foreach(|key: String, value: Value| {
|
22
|
+
match key.as_str() {
|
23
|
+
"shape" => shape = Some(Vec::try_convert(value)?),
|
24
|
+
"dtype" => {
|
25
|
+
let value = String::try_convert(value)?;
|
26
|
+
dtype = match value.as_str() {
|
27
|
+
"bool" => Some(Dtype::BOOL),
|
28
|
+
"int8" => Some(Dtype::I8),
|
29
|
+
"uint8" => Some(Dtype::U8),
|
30
|
+
"int16" => Some(Dtype::I16),
|
31
|
+
"uint16" => Some(Dtype::U16),
|
32
|
+
"int32" => Some(Dtype::I32),
|
33
|
+
"uint32" => Some(Dtype::U32),
|
34
|
+
"int64" => Some(Dtype::I64),
|
35
|
+
"uint64" => Some(Dtype::U64),
|
36
|
+
"float16" => Some(Dtype::F16),
|
37
|
+
"float32" => Some(Dtype::F32),
|
38
|
+
"float64" => Some(Dtype::F64),
|
39
|
+
"bfloat16" => Some(Dtype::BF16),
|
40
|
+
"float8_e4m3fn" => Some(Dtype::F8_E4M3),
|
41
|
+
"float8_e5m2" => Some(Dtype::F8_E5M2),
|
42
|
+
dtype_str => {
|
43
|
+
return Err(SafetensorError::new_err(format!(
|
44
|
+
"dtype {dtype_str} is not covered",
|
45
|
+
)));
|
46
|
+
}
|
47
|
+
}
|
48
|
+
}
|
49
|
+
"data" => {
|
50
|
+
let rs = RString::try_convert(value)?;
|
51
|
+
// SAFETY: No context switching between threads in native extensions
|
52
|
+
// so the string will not be modified (or garbage collected)
|
53
|
+
// while the reference is held. Also, the string is a private copy.
|
54
|
+
let slice = unsafe { rs.as_slice() };
|
55
|
+
data = Some((slice.as_ptr(), slice.len()))
|
56
|
+
}
|
57
|
+
_ => println!("Ignored unknown kwarg option {key}"),
|
58
|
+
};
|
59
|
+
|
60
|
+
Ok(ForEach::Continue)
|
61
|
+
})?;
|
62
|
+
let shape = shape.ok_or_else(|| {
|
63
|
+
SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}"))
|
64
|
+
})?;
|
65
|
+
let dtype = dtype.ok_or_else(|| {
|
66
|
+
SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}"))
|
67
|
+
})?;
|
68
|
+
let data = data.ok_or_else(|| {
|
69
|
+
SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}"))
|
70
|
+
})?;
|
71
|
+
// SAFETY: See comment above.
|
72
|
+
let data = unsafe { std::slice::from_raw_parts(data.0, data.1) };
|
73
|
+
let tensor = TensorView::new(dtype, shape, data)
|
74
|
+
.map_err(|e| SafetensorError::new_err(format!("Error preparing tensor view: {e:?}")))?;
|
75
|
+
tensors.insert(tensor_name, tensor);
|
76
|
+
|
77
|
+
Ok(ForEach::Continue)
|
78
|
+
})?;
|
79
|
+
Ok(tensors)
|
80
|
+
}
|
81
|
+
|
82
|
+
fn serialize(tensor_dict: RHash, metadata: Option<HashMap<String, String>>) -> RbResult<RString> {
|
83
|
+
let tensors = prepare(&tensor_dict)?;
|
84
|
+
let metadata_map = metadata.map(HashMap::from_iter);
|
85
|
+
let out = safetensors::tensor::serialize(&tensors, &metadata_map)
|
86
|
+
.map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
|
87
|
+
let rbbytes = RString::from_slice(&out);
|
88
|
+
Ok(rbbytes)
|
89
|
+
}
|
90
|
+
|
91
|
+
fn serialize_file(
|
92
|
+
tensor_dict: RHash,
|
93
|
+
filename: PathBuf,
|
94
|
+
metadata: Option<HashMap<String, String>>,
|
95
|
+
) -> RbResult<()> {
|
96
|
+
let tensors = prepare(&tensor_dict)?;
|
97
|
+
safetensors::tensor::serialize_to_file(&tensors, &metadata, filename.as_path())
|
98
|
+
.map_err(|e| SafetensorError::new_err(format!("Error while serializing: {e:?}")))?;
|
99
|
+
Ok(())
|
100
|
+
}
|
101
|
+
|
102
|
+
fn deserialize(bytes: RString) -> RbResult<RArray> {
|
103
|
+
let safetensor = SafeTensors::deserialize(unsafe { bytes.as_slice() })
|
104
|
+
.map_err(|e| SafetensorError::new_err(format!("Error while deserializing: {e:?}")))?;
|
105
|
+
|
106
|
+
let tensors = safetensor.tensors();
|
107
|
+
let items = RArray::with_capacity(tensors.len());
|
108
|
+
|
109
|
+
for (tensor_name, tensor) in tensors {
|
110
|
+
let rbshape = RArray::from_vec(tensor.shape().to_vec());
|
111
|
+
let rbdtype = format!("{:?}", tensor.dtype());
|
112
|
+
|
113
|
+
let rbdata = RString::from_slice(tensor.data());
|
114
|
+
|
115
|
+
let map = RHash::new();
|
116
|
+
map.aset("shape", rbshape)?;
|
117
|
+
map.aset("dtype", rbdtype)?;
|
118
|
+
map.aset("data", rbdata)?;
|
119
|
+
|
120
|
+
items.push((tensor_name, map))?;
|
121
|
+
}
|
122
|
+
Ok(items)
|
123
|
+
}
|
124
|
+
|
125
|
+
#[derive(Debug, Clone, PartialEq, Eq)]
|
126
|
+
enum Framework {
|
127
|
+
Pytorch,
|
128
|
+
Numo,
|
129
|
+
}
|
130
|
+
|
131
|
+
impl TryConvert for Framework {
|
132
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
133
|
+
let name: String = String::try_convert(ob)?;
|
134
|
+
match &name[..] {
|
135
|
+
"pt" => Ok(Framework::Pytorch),
|
136
|
+
"torch" => Ok(Framework::Pytorch),
|
137
|
+
"pytorch" => Ok(Framework::Pytorch),
|
138
|
+
|
139
|
+
"nm" => Ok(Framework::Numo),
|
140
|
+
"numo" => Ok(Framework::Numo),
|
141
|
+
|
142
|
+
name => Err(SafetensorError::new_err(format!(
|
143
|
+
"framework {name} is invalid"
|
144
|
+
))),
|
145
|
+
}
|
146
|
+
}
|
147
|
+
}
|
148
|
+
|
149
|
+
#[derive(Debug, Clone, PartialEq, Eq)]
|
150
|
+
enum Device {
|
151
|
+
Cpu,
|
152
|
+
Cuda(usize),
|
153
|
+
Mps,
|
154
|
+
Npu(usize),
|
155
|
+
Xpu(usize),
|
156
|
+
}
|
157
|
+
|
158
|
+
impl TryConvert for Device {
|
159
|
+
fn try_convert(ob: Value) -> RbResult<Self> {
|
160
|
+
if let Ok(name) = String::try_convert(ob) {
|
161
|
+
match &name[..] {
|
162
|
+
"cpu" => Ok(Device::Cpu),
|
163
|
+
"cuda" => Ok(Device::Cuda(0)),
|
164
|
+
"mps" => Ok(Device::Mps),
|
165
|
+
"npu" => Ok(Device::Npu(0)),
|
166
|
+
"xpu" => Ok(Device::Xpu(0)),
|
167
|
+
name if name.starts_with("cuda:") => {
|
168
|
+
let tokens: Vec<_> = name.split(':').collect();
|
169
|
+
if tokens.len() == 2 {
|
170
|
+
let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
|
171
|
+
Ok(Device::Cuda(device))
|
172
|
+
} else {
|
173
|
+
Err(SafetensorError::new_err(format!(
|
174
|
+
"device {name} is invalid"
|
175
|
+
)))
|
176
|
+
}
|
177
|
+
}
|
178
|
+
name if name.starts_with("npu:") => {
|
179
|
+
let tokens: Vec<_> = name.split(':').collect();
|
180
|
+
if tokens.len() == 2 {
|
181
|
+
let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
|
182
|
+
Ok(Device::Npu(device))
|
183
|
+
} else {
|
184
|
+
Err(SafetensorError::new_err(format!(
|
185
|
+
"device {name} is invalid"
|
186
|
+
)))
|
187
|
+
}
|
188
|
+
}
|
189
|
+
name if name.starts_with("xpu:") => {
|
190
|
+
let tokens: Vec<_> = name.split(':').collect();
|
191
|
+
if tokens.len() == 2 {
|
192
|
+
let device: usize = tokens[1].parse().map_err(SafetensorError::parse)?;
|
193
|
+
Ok(Device::Xpu(device))
|
194
|
+
} else {
|
195
|
+
Err(SafetensorError::new_err(format!(
|
196
|
+
"device {name} is invalid"
|
197
|
+
)))
|
198
|
+
}
|
199
|
+
}
|
200
|
+
name => Err(SafetensorError::new_err(format!(
|
201
|
+
"device {name} is invalid"
|
202
|
+
))),
|
203
|
+
}
|
204
|
+
} else if let Ok(number) = usize::try_convert(ob) {
|
205
|
+
Ok(Device::Cuda(number))
|
206
|
+
} else {
|
207
|
+
Err(SafetensorError::new_err(format!("device {ob} is invalid")))
|
208
|
+
}
|
209
|
+
}
|
210
|
+
}
|
211
|
+
|
212
|
+
impl IntoValue for Device {
|
213
|
+
fn into_value_with(self, ruby: &Ruby) -> Value {
|
214
|
+
match self {
|
215
|
+
Device::Cpu => "cpu".into_value_with(ruby),
|
216
|
+
Device::Cuda(n) => format!("cuda:{n}").into_value_with(ruby),
|
217
|
+
Device::Mps => "mps".into_value_with(ruby),
|
218
|
+
Device::Npu(n) => format!("npu:{n}").into_value_with(ruby),
|
219
|
+
Device::Xpu(n) => format!("xpu:{n}").into_value_with(ruby),
|
220
|
+
}
|
221
|
+
}
|
222
|
+
}
|
223
|
+
|
224
|
+
enum Storage {
|
225
|
+
Mmap(Mmap),
|
226
|
+
}
|
227
|
+
|
228
|
+
struct Open {
|
229
|
+
metadata: Metadata,
|
230
|
+
offset: usize,
|
231
|
+
framework: Framework,
|
232
|
+
device: Device,
|
233
|
+
storage: Arc<Storage>,
|
234
|
+
}
|
235
|
+
|
236
|
+
impl Open {
|
237
|
+
fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> RbResult<Self> {
|
238
|
+
let file = File::open(&filename).map_err(|_| {
|
239
|
+
SafetensorError::new_err(format!("No such file or directory: {filename:?}"))
|
240
|
+
})?;
|
241
|
+
let device = device.unwrap_or(Device::Cpu);
|
242
|
+
|
243
|
+
if device != Device::Cpu && framework != Framework::Pytorch {
|
244
|
+
return Err(SafetensorError::new_err(format!(
|
245
|
+
"Device {device:?} is not support for framework {framework:?}",
|
246
|
+
)));
|
247
|
+
}
|
248
|
+
|
249
|
+
// SAFETY: Mmap is used to prevent allocating in Rust
|
250
|
+
// before making a copy within Ruby.
|
251
|
+
let buffer = unsafe { MmapOptions::new().map(&file).map_err(SafetensorError::io)? };
|
252
|
+
|
253
|
+
let (n, metadata) = SafeTensors::read_metadata(&buffer).map_err(|e| {
|
254
|
+
SafetensorError::new_err(format!("Error while deserializing header: {e:?}"))
|
255
|
+
})?;
|
256
|
+
|
257
|
+
let offset = n + 8;
|
258
|
+
|
259
|
+
let storage = Storage::Mmap(buffer);
|
260
|
+
|
261
|
+
let storage = Arc::new(storage);
|
262
|
+
|
263
|
+
Ok(Self {
|
264
|
+
metadata,
|
265
|
+
offset,
|
266
|
+
framework,
|
267
|
+
device,
|
268
|
+
storage,
|
269
|
+
})
|
270
|
+
}
|
271
|
+
|
272
|
+
pub fn metadata(&self) -> Option<HashMap<String, String>> {
|
273
|
+
self.metadata.metadata().clone()
|
274
|
+
}
|
275
|
+
|
276
|
+
pub fn keys(&self) -> RbResult<Vec<String>> {
|
277
|
+
let mut keys: Vec<String> = self.metadata.tensors().keys().cloned().collect();
|
278
|
+
keys.sort();
|
279
|
+
Ok(keys)
|
280
|
+
}
|
281
|
+
|
282
|
+
pub fn get_tensor(&self, name: &str) -> RbResult<Value> {
|
283
|
+
let info = self.metadata.info(name).ok_or_else(|| {
|
284
|
+
SafetensorError::new_err(format!("File does not contain tensor {name}",))
|
285
|
+
})?;
|
286
|
+
|
287
|
+
match &self.storage.as_ref() {
|
288
|
+
Storage::Mmap(mmap) => {
|
289
|
+
let data =
|
290
|
+
&mmap[info.data_offsets.0 + self.offset..info.data_offsets.1 + self.offset];
|
291
|
+
|
292
|
+
let array: Value = RString::from_slice(data).into_value();
|
293
|
+
|
294
|
+
create_tensor(
|
295
|
+
&self.framework,
|
296
|
+
info.dtype,
|
297
|
+
&info.shape,
|
298
|
+
array,
|
299
|
+
&self.device,
|
300
|
+
)
|
301
|
+
}
|
302
|
+
}
|
303
|
+
}
|
304
|
+
}
|
305
|
+
|
306
|
+
#[magnus::wrap(class = "Safetensors::SafeOpen")]
|
307
|
+
struct SafeOpen {
|
308
|
+
inner: Option<Open>,
|
309
|
+
}
|
310
|
+
|
311
|
+
impl SafeOpen {
|
312
|
+
fn inner(&self) -> RbResult<&Open> {
|
313
|
+
let inner = self
|
314
|
+
.inner
|
315
|
+
.as_ref()
|
316
|
+
.ok_or_else(|| SafetensorError::new_err("File is closed".to_string()))?;
|
317
|
+
Ok(inner)
|
318
|
+
}
|
319
|
+
}
|
320
|
+
|
321
|
+
impl SafeOpen {
|
322
|
+
pub fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> RbResult<Self> {
|
323
|
+
let inner = Some(Open::new(filename, framework, device)?);
|
324
|
+
Ok(Self { inner })
|
325
|
+
}
|
326
|
+
|
327
|
+
pub fn metadata(&self) -> RbResult<Option<HashMap<String, String>>> {
|
328
|
+
Ok(self.inner()?.metadata())
|
329
|
+
}
|
330
|
+
|
331
|
+
pub fn keys(&self) -> RbResult<Vec<String>> {
|
332
|
+
self.inner()?.keys()
|
333
|
+
}
|
334
|
+
|
335
|
+
pub fn get_tensor(&self, name: String) -> RbResult<Value> {
|
336
|
+
self.inner()?.get_tensor(&name)
|
337
|
+
}
|
338
|
+
}
|
339
|
+
|
340
|
+
fn create_tensor(
|
341
|
+
framework: &Framework,
|
342
|
+
dtype: Dtype,
|
343
|
+
shape: &[usize],
|
344
|
+
array: Value,
|
345
|
+
device: &Device,
|
346
|
+
) -> RbResult<Value> {
|
347
|
+
let ruby = Ruby::get().unwrap();
|
348
|
+
let (module, is_numo): (RModule, bool) = match framework {
|
349
|
+
Framework::Pytorch => (
|
350
|
+
ruby.class_object()
|
351
|
+
.const_get("Torch")
|
352
|
+
.map_err(|_| SafetensorError::new_err("Torch not loaded".into()))?,
|
353
|
+
false,
|
354
|
+
),
|
355
|
+
_ => (
|
356
|
+
ruby.class_object()
|
357
|
+
.const_get("Numo")
|
358
|
+
.map_err(|_| SafetensorError::new_err("Numo not loaded".into()))?,
|
359
|
+
true,
|
360
|
+
),
|
361
|
+
};
|
362
|
+
|
363
|
+
let dtype = get_rbdtype(module, dtype, is_numo)?;
|
364
|
+
let shape = shape.to_vec();
|
365
|
+
let tensor: Value = match framework {
|
366
|
+
Framework::Pytorch => {
|
367
|
+
let options: Value = module.funcall(
|
368
|
+
"tensor_options",
|
369
|
+
(kwargs!("dtype" => dtype, "device" => device.clone()),),
|
370
|
+
)?;
|
371
|
+
module.funcall("_from_blob_ref", (array, shape, options))?
|
372
|
+
}
|
373
|
+
_ => {
|
374
|
+
let class: Value = module.funcall("const_get", (dtype,))?;
|
375
|
+
class.funcall("from_binary", (array, shape))?
|
376
|
+
}
|
377
|
+
};
|
378
|
+
Ok(tensor)
|
379
|
+
}
|
380
|
+
|
381
|
+
fn get_rbdtype(_module: RModule, dtype: Dtype, is_numo: bool) -> RbResult<Value> {
|
382
|
+
let dtype: Value = if is_numo {
|
383
|
+
match dtype {
|
384
|
+
Dtype::F64 => Symbol::new("DFloat").into_value(),
|
385
|
+
Dtype::F32 => Symbol::new("SFloat").into_value(),
|
386
|
+
Dtype::U64 => Symbol::new("UInt64").into_value(),
|
387
|
+
Dtype::I64 => Symbol::new("Int64").into_value(),
|
388
|
+
Dtype::U32 => Symbol::new("UInt32").into_value(),
|
389
|
+
Dtype::I32 => Symbol::new("Int32").into_value(),
|
390
|
+
Dtype::U16 => Symbol::new("UInt16").into_value(),
|
391
|
+
Dtype::I16 => Symbol::new("Int16").into_value(),
|
392
|
+
Dtype::U8 => Symbol::new("UInt8").into_value(),
|
393
|
+
Dtype::I8 => Symbol::new("Int8").into_value(),
|
394
|
+
dtype => {
|
395
|
+
return Err(SafetensorError::new_err(format!(
|
396
|
+
"Dtype not understood: {dtype:?}"
|
397
|
+
)))
|
398
|
+
}
|
399
|
+
}
|
400
|
+
} else {
|
401
|
+
match dtype {
|
402
|
+
Dtype::F64 => Symbol::new("float64").into_value(),
|
403
|
+
Dtype::F32 => Symbol::new("float32").into_value(),
|
404
|
+
Dtype::BF16 => Symbol::new("bfloat16").into_value(),
|
405
|
+
Dtype::F16 => Symbol::new("float16").into_value(),
|
406
|
+
Dtype::U64 => Symbol::new("uint64").into_value(),
|
407
|
+
Dtype::I64 => Symbol::new("int64").into_value(),
|
408
|
+
Dtype::U32 => Symbol::new("uint32").into_value(),
|
409
|
+
Dtype::I32 => Symbol::new("int32").into_value(),
|
410
|
+
Dtype::U16 => Symbol::new("uint16").into_value(),
|
411
|
+
Dtype::I16 => Symbol::new("int16").into_value(),
|
412
|
+
Dtype::U8 => Symbol::new("uint8").into_value(),
|
413
|
+
Dtype::I8 => Symbol::new("int8").into_value(),
|
414
|
+
Dtype::BOOL => Symbol::new("bool").into_value(),
|
415
|
+
Dtype::F8_E4M3 => Symbol::new("float8_e4m3fn").into_value(),
|
416
|
+
Dtype::F8_E5M2 => Symbol::new("float8_e5m2").into_value(),
|
417
|
+
dtype => {
|
418
|
+
return Err(SafetensorError::new_err(format!(
|
419
|
+
"Dtype not understood: {dtype:?}"
|
420
|
+
)))
|
421
|
+
}
|
422
|
+
}
|
423
|
+
};
|
424
|
+
Ok(dtype)
|
425
|
+
}
|
426
|
+
|
427
|
+
struct SafetensorError {}
|
428
|
+
|
429
|
+
impl SafetensorError {
|
430
|
+
fn new_err(message: String) -> Error {
|
431
|
+
let class = Ruby::get()
|
432
|
+
.unwrap()
|
433
|
+
.class_object()
|
434
|
+
.const_get::<_, RModule>("Safetensors")
|
435
|
+
.unwrap()
|
436
|
+
.const_get("Error")
|
437
|
+
.unwrap();
|
438
|
+
Error::new(class, message)
|
439
|
+
}
|
440
|
+
|
441
|
+
fn io(err: std::io::Error) -> Error {
|
442
|
+
Self::new_err(err.to_string())
|
443
|
+
}
|
444
|
+
|
445
|
+
fn parse(err: std::num::ParseIntError) -> Error {
|
446
|
+
Self::new_err(err.to_string())
|
447
|
+
}
|
448
|
+
}
|
449
|
+
|
450
|
+
#[magnus::init]
|
451
|
+
fn init(ruby: &Ruby) -> RbResult<()> {
|
452
|
+
let module = ruby.define_module("Safetensors")?;
|
453
|
+
module.define_singleton_method("_serialize", function!(serialize, 2))?;
|
454
|
+
module.define_singleton_method("_serialize_file", function!(serialize_file, 3))?;
|
455
|
+
module.define_singleton_method("deserialize", function!(deserialize, 1))?;
|
456
|
+
|
457
|
+
let class = module.define_class("SafeOpen", ruby.class_object())?;
|
458
|
+
class.define_singleton_method("new", function!(SafeOpen::new, 3))?;
|
459
|
+
class.define_method("metadata", method!(SafeOpen::metadata, 0))?;
|
460
|
+
class.define_method("keys", method!(SafeOpen::keys, 0))?;
|
461
|
+
class.define_method("get_tensor", method!(SafeOpen::get_tensor, 1))?;
|
462
|
+
|
463
|
+
Ok(())
|
464
|
+
}
|
@@ -0,0 +1,93 @@
|
|
1
|
+
module Safetensors
  # Save/load helpers that map between Numo::NArray objects and safetensors data.
  module Numo
    # Numo::NArray class name => dtype name sent to the native serializer.
    # Covers every Numo type that TYPES can load back, so integer arrays round-trip.
    DTYPES = {
      "DFloat" => "float64",
      "SFloat" => "float32",
      "Int64" => "int64",
      "UInt64" => "uint64",
      "Int32" => "int32",
      "UInt32" => "uint32",
      "Int16" => "int16",
      "UInt16" => "uint16",
      "Int8" => "int8",
      "UInt8" => "uint8"
    }.freeze

    # dtype string reported by Safetensors.deserialize => Numo class name used on load.
    TYPES = {
      "F64" => :DFloat,
      "F32" => :SFloat,
      "I64" => :Int64,
      "U64" => :UInt64,
      "I32" => :Int32,
      "U32" => :UInt32,
      "I16" => :Int16,
      "U16" => :UInt16,
      "I8" => :Int8,
      "U8" => :UInt8
    }.freeze

    class << self
      # Serializes a hash of name => Numo::NArray via Safetensors.serialize.
      def save(tensor_dict, metadata: nil)
        Safetensors.serialize(_flatten(tensor_dict), metadata: metadata)
      end

      # Writes a hash of name => Numo::NArray to +filename+ via Safetensors.serialize_file.
      def save_file(tensor_dict, filename, metadata: nil)
        Safetensors.serialize_file(_flatten(tensor_dict), filename, metadata: metadata)
      end

      # Deserializes +data+ and returns a hash of name => Numo::NArray.
      def load(data)
        flat = Safetensors.deserialize(data)
        _view2numo(flat)
      end

      # Reads every tensor in +filename+ into a hash of name => Numo::NArray.
      def load_file(filename)
        result = {}
        Safetensors.safe_open(filename, framework: "numo") do |f|
          f.keys.each do |k|
            result[k] = f.get_tensor(k)
          end
        end
        result
      end

      private

      # Validates +tensors+ and converts it to the hash-of-descriptions format
      # ("dtype", "shape", "data") passed to the serializer. Symbol keys become strings.
      #
      # Raises ArgumentError for non-Hash input or non-NArray values, and KeyError for
      # NArray classes with no entry in DTYPES.
      def _flatten(tensors)
        unless tensors.is_a?(Hash)
          raise ArgumentError, "Expected a hash of [String, Numo::NArray] but received #{tensors.class.name}"
        end

        tensors.each do |k, v|
          unless v.is_a?(::Numo::NArray)
            raise ArgumentError, "Key `#{k}` is invalid, expected Numo::NArray but received #{v.class.name}"
          end
        end

        tensors.to_h do |k, v|
          [
            k.is_a?(Symbol) ? k.to_s : k,
            {
              "dtype" => DTYPES.fetch(v.class.name.split("::").last),
              "shape" => v.shape,
              "data" => _tobytes(v)
            }
          ]
        end
      end

      # Returns the raw bytes of +tensor+; big-endian hosts are not supported yet.
      def _tobytes(tensor)
        raise "Not yet implemented" if Safetensors.big_endian?

        tensor.to_binary
      end

      # Looks up the Numo class name for a deserialized dtype string (KeyError if unknown).
      def _getdtype(dtype_str)
        TYPES.fetch(dtype_str)
      end

      # Builds Numo arrays from the [name, {"dtype", "shape", "data"}] pairs
      # returned by Safetensors.deserialize.
      def _view2numo(safeview)
        result = {}
        safeview.each do |k, v|
          dtype = _getdtype(v["dtype"])
          arr = ::Numo.const_get(dtype).from_binary(v["data"], v["shape"])
          result[k] = arr
        end
        result
      end
    end
  end
end
|