safetensors 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/Cargo.lock +414 -0
- data/Cargo.toml +6 -0
- data/LICENSE.txt +201 -0
- data/README.md +67 -0
- data/ext/safetensors/Cargo.toml +17 -0
- data/ext/safetensors/extconf.rb +4 -0
- data/ext/safetensors/src/lib.rs +464 -0
- data/lib/safetensors/numo.rb +93 -0
- data/lib/safetensors/torch.rb +141 -0
- data/lib/safetensors/version.rb +3 -0
- data/lib/safetensors.rb +37 -0
- metadata +69 -0
@@ -0,0 +1,141 @@
|
|
1
|
+
module Safetensors
  # Helpers for saving and loading hashes of ::Torch::Tensor objects through
  # the Safetensors serialization entry points.
  module Torch
    # Maps the dtype tag found in deserialized data to the dtype symbol that
    # is handed to ::Torch when rebuilding a tensor. Frozen so the shared
    # lookup table cannot be mutated at runtime.
    TYPES = {
      "F64" => :float64,
      "F32" => :float32,
      "F16" => :float16,
      "BF16" => :bfloat16,
      "I64" => :int64,
      "U64" => :uint64,
      "I32" => :int32,
      "U32" => :uint32,
      "I16" => :int16,
      "U16" => :uint16,
      "I8" => :int8,
      "U8" => :uint8,
      "BOOL" => :bool,
      "F8_E4M3" => :float8_e4m3fn,
      "F8_E5M2" => :float8_e5m2
    }.freeze

    class << self
      # Flattens +tensors+ and passes the result to Safetensors.serialize.
      #
      # @param tensors [Hash] name => ::Torch::Tensor
      # @param metadata [Object, nil] forwarded unchanged
      # @return whatever Safetensors.serialize returns
      # @raise [ArgumentError] when +tensors+ fails validation (see _flatten)
      def save(tensors, metadata: nil)
        Safetensors.serialize(_flatten(tensors), metadata: metadata)
      end

      # Flattens +tensors+ and passes the result to Safetensors.serialize_file.
      #
      # @param tensors [Hash] name => ::Torch::Tensor
      # @param filename [Object] forwarded unchanged
      # @param metadata [Object, nil] forwarded unchanged
      # @return whatever Safetensors.serialize_file returns
      # @raise [ArgumentError] when +tensors+ fails validation (see _flatten)
      def save_file(tensors, filename, metadata: nil)
        Safetensors.serialize_file(_flatten(tensors), filename, metadata: metadata)
      end

      # Opens +filename+ with the "torch" framework and collects every tensor
      # it exposes.
      #
      # @param filename [Object] forwarded to Safetensors.safe_open
      # @param device [Object] forwarded to Safetensors.safe_open
      # @return [Hash] key => result of +get_tensor(key)+
      def load_file(filename, device: "cpu")
        Safetensors.safe_open(filename, framework: "torch", device: device) do |f|
          return f.keys.each_with_object({}) { |k, result| result[k] = f.get_tensor(k) }
        end
        # Reached only when safe_open never yields; mirrors the original's
        # empty-hash result in that case.
        {}
      end

      # Deserializes +data+ and converts each entry into a tensor.
      #
      # @param data [Object] passed to Safetensors.deserialize
      # @return [Hash] name => tensor built by _view2torch
      def load(data)
        _view2torch(Safetensors.deserialize(data))
      end

      private

      # Placeholder: shared-memory detection is not implemented yet, so no
      # groups are ever reported.
      #
      # @return [Array] always empty
      def _find_shared_tensors(state_dict)
        # TODO
        []
      end

      # Looks up the Torch dtype symbol for a dtype tag.
      #
      # @raise [KeyError] when +dtype_str+ is not a key of TYPES
      def _getdtype(dtype_str)
        TYPES.fetch(dtype_str)
      end

      # Builds a tensor for every entry of +safeview+ from its "dtype",
      # "shape" and "data" values.
      #
      # @param safeview [#each] yields (name, entry) pairs
      # @return [Hash] name => tensor
      # @raise [RuntimeError] on big-endian hosts (not yet implemented)
      def _view2torch(safeview)
        result = {}
        safeview.each do |k, v|
          dtype = _getdtype(v["dtype"])
          # tensor_options is invoked through send, as in the original code.
          options = ::Torch.send(:tensor_options, dtype: dtype)
          arr = ::Torch._from_blob_ref(v["data"], v["shape"], options)
          # TODO: byte swapping for big-endian hosts
          raise "not yet implemented" if Safetensors.big_endian?

          result[k] = arr
        end
        result
      end

      # Returns the raw data string of +tensor+ after validating it.
      #
      # @param tensor [::Torch::Tensor]
      # @param name [Object] key used in error messages
      # @raise [ArgumentError] for non-strided or non-contiguous tensors
      # @raise [RuntimeError] on big-endian hosts (not yet implemented)
      def _tobytes(tensor, name)
        if tensor.layout != :strided
          raise ArgumentError, "You are trying to save a sparse tensor: `#{name}` which this library does not support. You can make it a dense tensor before saving with `.to_dense()` but be aware this might make a much larger file than needed."
        end

        unless tensor.contiguous?
          raise ArgumentError, "You are trying to save a non contiguous tensor: `#{name}` which is not allowed. It either means you are trying to save tensors which are reference of each other in which case it's recommended to save only the full tensors, and reslice at load time, or simply call `.contiguous()` on your tensor to pack it before saving."
        end

        # Moving tensor to cpu before saving.
        # NOTE(review): this compares the device against the String "cpu"; if
        # `device` returns a non-String object the branch is presumably always
        # taken - confirm against the tensor API.
        tensor = tensor.to("cpu") if tensor.device != "cpu"

        # TODO: byte swapping for big-endian hosts
        raise "not yet implemented" if Safetensors.big_endian?

        tensor._data_str
      end

      # Validates +tensors+ and converts it into the hash passed to the
      # serializer: String name => {"dtype", "shape", "data"}.
      #
      # Symbol keys are converted to Strings; other keys are kept as given.
      #
      # @param tensors [Hash] name => ::Torch::Tensor
      # @return [Hash]
      # @raise [ArgumentError] when +tensors+ is not a Hash, a value is not a
      #   ::Torch::Tensor, or a tensor's layout is not :strided
      # @raise [RuntimeError] when tensors are reported as sharing memory
      def _flatten(tensors)
        unless tensors.is_a?(Hash)
          raise ArgumentError, "Expected a hash of [String, Torch::Tensor] but received #{tensors.class.name}"
        end

        # A value of the wrong type aborts immediately; sparse tensors are
        # collected so they can all be reported in a single error.
        invalid_tensors = []
        tensors.each do |k, v|
          unless v.is_a?(::Torch::Tensor)
            raise ArgumentError, "Key `#{k}` is invalid, expected Torch::Tensor but received #{v.class.name}"
          end

          invalid_tensors << k if v.layout != :strided
        end
        if invalid_tensors.any?
          raise ArgumentError, "You are trying to save a sparse tensors: `#{invalid_tensors}` which this library does not support. You can make it a dense tensor before saving with `.to_dense()` but be aware this might make a much larger file than needed."
        end

        # Only groups with more than one name are a problem.
        failing = _find_shared_tensors(tensors).select { |names| names.length > 1 }
        if failing.any?
          raise <<~MSG
            Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: #{failing}.
            A potential way to correctly save your model is to use `save_model`.
            More information at https://huggingface.co/docs/safetensors/torch_shared_tensors
          MSG
        end

        tensors.to_h do |k, v|
          [
            k.is_a?(Symbol) ? k.to_s : k,
            {
              "dtype" => v.dtype.to_s,
              "shape" => v.shape,
              "data" => _tobytes(v, k)
            }
          ]
        end
      end
    end
  end
end
|
data/lib/safetensors.rb
ADDED
@@ -0,0 +1,37 @@
|
|
1
|
+
# ext
|
2
|
+
begin
|
3
|
+
require "safetensors/#{RUBY_VERSION.to_f}/safetensors"
|
4
|
+
rescue LoadError
|
5
|
+
require "safetensors/safetensors"
|
6
|
+
end
|
7
|
+
|
8
|
+
# modules
|
9
|
+
require_relative "safetensors/numo"
|
10
|
+
require_relative "safetensors/torch"
|
11
|
+
require_relative "safetensors/version"
|
12
|
+
|
13
|
+
module Safetensors
  # Error class namespaced under Safetensors.
  class Error < StandardError; end

  class << self
    # Serializes +tensor_dict+ by delegating to +_serialize+ (not defined in
    # this file; presumably supplied by the extension required above).
    #
    # @param tensor_dict [Object] forwarded unchanged
    # @param metadata [Object, nil] forwarded as the second positional argument
    # @return whatever +_serialize+ returns
    def serialize(tensor_dict, metadata: nil)
      _serialize(tensor_dict, metadata)
    end

    # Serializes +tensor_dict+ to +filename+ by delegating to
    # +_serialize_file+ (not defined in this file).
    #
    # @param tensor_dict [Object] forwarded unchanged
    # @param filename [Object] forwarded unchanged
    # @param metadata [Object, nil] forwarded as the third positional argument
    # @return whatever +_serialize_file+ returns
    def serialize_file(tensor_dict, filename, metadata: nil)
      _serialize_file(tensor_dict, filename, metadata)
    end

    # Builds a SafeOpen handle for +filename+.
    #
    # With a block, the handle is yielded and the block's value is returned;
    # without one, the handle itself is returned.
    #
    # @param filename [Object] forwarded to SafeOpen.new
    # @param framework [Object] required keyword, forwarded to SafeOpen.new
    # @param device [Object] forwarded to SafeOpen.new
    def safe_open(filename, framework:, device: "cpu")
      handle = SafeOpen.new(filename, framework, device)
      return handle unless block_given?

      yield handle
    end

    # Internal helper (marked "private" by convention only; it remains
    # publicly callable).
    #
    # @return [Boolean] true when the native packing of the integer 1 equals
    #   its explicit big-endian packing
    def big_endian?
      native_bytes = [1].pack("i")
      native_bytes == [1].pack("i!>")
    end
  end
end
|
metadata
ADDED
@@ -0,0 +1,69 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: safetensors
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Andrew Kane
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2024-02-25 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: rb_sys
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0'
|
27
|
+
description:
|
28
|
+
email: andrew@ankane.org
|
29
|
+
executables: []
|
30
|
+
extensions:
|
31
|
+
- ext/safetensors/extconf.rb
|
32
|
+
extra_rdoc_files: []
|
33
|
+
files:
|
34
|
+
- CHANGELOG.md
|
35
|
+
- Cargo.lock
|
36
|
+
- Cargo.toml
|
37
|
+
- LICENSE.txt
|
38
|
+
- README.md
|
39
|
+
- ext/safetensors/Cargo.toml
|
40
|
+
- ext/safetensors/extconf.rb
|
41
|
+
- ext/safetensors/src/lib.rs
|
42
|
+
- lib/safetensors.rb
|
43
|
+
- lib/safetensors/numo.rb
|
44
|
+
- lib/safetensors/torch.rb
|
45
|
+
- lib/safetensors/version.rb
|
46
|
+
homepage: https://github.com/ankane/safetensors-ruby
|
47
|
+
licenses:
|
48
|
+
- Apache-2.0
|
49
|
+
metadata: {}
|
50
|
+
post_install_message:
|
51
|
+
rdoc_options: []
|
52
|
+
require_paths:
|
53
|
+
- lib
|
54
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
55
|
+
requirements:
|
56
|
+
- - ">="
|
57
|
+
- !ruby/object:Gem::Version
|
58
|
+
version: '3.1'
|
59
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
60
|
+
requirements:
|
61
|
+
- - ">="
|
62
|
+
- !ruby/object:Gem::Version
|
63
|
+
version: '0'
|
64
|
+
requirements: []
|
65
|
+
rubygems_version: 3.5.3
|
66
|
+
signing_key:
|
67
|
+
specification_version: 4
|
68
|
+
summary: Simple, safe way to store and distribute tensors
|
69
|
+
test_files: []
|