safetensors 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,141 @@
1
module Safetensors
  # Bridges the generic Safetensors serialize/deserialize entry points and
  # ::Torch tensors: flattens tensors into plain hashes for saving and turns
  # deserialized views back into tensors when loading.
  module Torch
    # Maps the dtype names found in safetensors data to the dtype symbols
    # handed to ::Torch. Frozen so the shared lookup table cannot be mutated.
    TYPES = {
      "F64" => :float64,
      "F32" => :float32,
      "F16" => :float16,
      "BF16" => :bfloat16,
      "I64" => :int64,
      "U64" => :uint64,
      "I32" => :int32,
      "U32" => :uint32,
      "I16" => :int16,
      "U16" => :uint16,
      "I8" => :int8,
      "U8" => :uint8,
      "BOOL" => :bool,
      "F8_E4M3" => :float8_e4m3fn,
      "F8_E5M2" => :float8_e5m2
    }.freeze

    class << self
      # Flattens +tensors+ and passes the result to Safetensors.serialize.
      #
      # @param tensors [Hash] tensor name (String or Symbol) => ::Torch::Tensor
      # @param metadata [Object, nil] forwarded unchanged to Safetensors.serialize
      # @return [Object] whatever Safetensors.serialize returns
      # @raise [ArgumentError] if +tensors+ is not a Hash, a value is not a
      #   ::Torch::Tensor, or a tensor is not strided / not contiguous
      def save(tensors, metadata: nil)
        Safetensors.serialize(_flatten(tensors), metadata: metadata)
      end

      # Same as {save}, but hands the flattened tensors and +filename+ to
      # Safetensors.serialize_file.
      #
      # @param tensors [Hash] tensor name (String or Symbol) => ::Torch::Tensor
      # @param filename [Object] forwarded unchanged to Safetensors.serialize_file
      # @param metadata [Object, nil] forwarded unchanged
      # @raise [ArgumentError] under the same conditions as {save}
      def save_file(tensors, filename, metadata: nil)
        Safetensors.serialize_file(_flatten(tensors), filename, metadata: metadata)
      end

      # Opens +filename+ through Safetensors.safe_open with the "torch"
      # framework and collects the tensor for every key it reports.
      #
      # @param filename [Object] forwarded to Safetensors.safe_open
      # @param device [String] forwarded to Safetensors.safe_open
      # @return [Hash] key => value returned by +get_tensor+ for that key
      def load_file(filename, device: "cpu")
        loaded = {}
        Safetensors.safe_open(filename, framework: "torch", device: device) do |f|
          f.keys.each { |key| loaded[key] = f.get_tensor(key) }
        end
        loaded
      end

      # Deserializes +data+ with Safetensors.deserialize and converts each
      # resulting view into a tensor.
      #
      # @param data [Object] forwarded unchanged to Safetensors.deserialize
      # @return [Hash] name => tensor built by ::Torch._from_blob_ref
      # @raise [KeyError] if a view reports a dtype missing from TYPES
      def load(data)
        _view2torch(Safetensors.deserialize(data))
      end

      private

      # Placeholder: shared-memory detection is not implemented yet, so this
      # always reports no groups and the shared-tensor check in _flatten
      # never fires.
      def _find_shared_tensors(state_dict)
        # TODO
        []
      end

      # Looks up the dtype symbol for a safetensors dtype name; raises
      # KeyError for names missing from TYPES.
      def _getdtype(dtype_str)
        TYPES.fetch(dtype_str)
      end

      # Builds a tensor for every (name, view) pair in +safeview+. Each view
      # is read through its "dtype", "shape" and "data" keys.
      def _view2torch(safeview)
        tensors = {}
        safeview.each do |name, view|
          dtype = _getdtype(view["dtype"])
          # tensor_options is invoked via send, as in the original code
          # (presumably because it is not public on ::Torch - confirm).
          options = ::Torch.send(:tensor_options, dtype: dtype)
          tensor = ::Torch._from_blob_ref(view["data"], view["shape"], options)

          # TODO: big-endian hosts are not handled yet
          raise "not yet implemented" if Safetensors.big_endian?

          tensors[name] = tensor
        end
        tensors
      end

      # Validates a single tensor and returns the value of its +_data_str+.
      #
      # @param tensor [::Torch::Tensor] tensor to convert
      # @param name [Object] key used only in error messages
      # @raise [ArgumentError] if the tensor is not strided or not contiguous
      def _tobytes(tensor, name)
        if tensor.layout != :strided
          raise ArgumentError, "You are trying to save a sparse tensor: `#{name}` which this library does not support. You can make it a dense tensor before saving with `.to_dense()` but be aware this might make a much larger file than needed."
        end

        unless tensor.contiguous?
          raise ArgumentError, "You are trying to save a non contiguous tensor: `#{name}` which is not allowed. It either means you are trying to save tensors which are reference of each other in which case it's recommended to save only the full tensors, and reslice at load time, or simply call `.contiguous()` on your tensor to pack it before saving."
        end

        # Moving tensor to cpu before saving
        tensor = tensor.to("cpu") if tensor.device != "cpu"

        # TODO: big-endian hosts are not handled yet
        raise "not yet implemented" if Safetensors.big_endian?

        tensor._data_str
      end

      # Validates the whole hash and converts it into
      # { name => { "dtype", "shape", "data" } }, turning Symbol keys into
      # Strings and leaving other keys untouched.
      #
      # @raise [ArgumentError] if +tensors+ is not a Hash, any value is not a
      #   ::Torch::Tensor, or any tensor is not strided / not contiguous
      # @raise [RuntimeError] if _find_shared_tensors reports a group with
      #   more than one name
      def _flatten(tensors)
        unless tensors.is_a?(Hash)
          raise ArgumentError, "Expected a hash of [String, Torch::Tensor] but received #{tensors.class.name}"
        end

        # Type check first so the first non-tensor value (in hash order) is
        # the one reported, before any layout is inspected on the rest.
        tensors.each do |name, tensor|
          next if tensor.is_a?(::Torch::Tensor)

          raise ArgumentError, "Key `#{name}` is invalid, expected Torch::Tensor but received #{tensor.class.name}"
        end

        sparse_names = tensors.select { |_name, tensor| tensor.layout != :strided }.keys
        if sparse_names.any?
          raise ArgumentError, "You are trying to save a sparse tensors: `#{sparse_names}` which this library does not support. You can make it a dense tensor before saving with `.to_dense()` but be aware this might make a much larger file than needed."
        end

        failing = _find_shared_tensors(tensors).select { |names| names.length > 1 }
        if failing.any?
          raise <<~MSG
            Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: #{failing}.
            A potential way to correctly save your model is to use `save_model`.
            More information at https://huggingface.co/docs/safetensors/torch_shared_tensors
          MSG
        end

        tensors.to_h do |name, tensor|
          [
            name.is_a?(Symbol) ? name.to_s : name,
            {
              "dtype" => tensor.dtype.to_s,
              "shape" => tensor.shape,
              "data" => _tobytes(tensor, name)
            }
          ]
        end
      end
    end
  end
end
@@ -0,0 +1,3 @@
1
module Safetensors
  # Version string of the library; frozen so it cannot be mutated in place.
  VERSION = "0.1.0".freeze
end
@@ -0,0 +1,37 @@
1
+ # ext
2
# Try the extension under a "<major>.<minor>" directory for the running Ruby
# first, then fall back to the unversioned path.
# The version is sliced with a regexp instead of `RUBY_VERSION.to_f`, because
# the float conversion would turn a two-digit minor such as "3.10.0" into 3.1
# and pick the wrong directory.
begin
  require "safetensors/#{RUBY_VERSION[/\d+\.\d+/]}/safetensors"
rescue LoadError
  require "safetensors/safetensors"
end
7
+
8
+ # modules
9
+ require_relative "safetensors/numo"
10
+ require_relative "safetensors/torch"
11
+ require_relative "safetensors/version"
12
+
13
module Safetensors
  # Error class namespaced under Safetensors.
  class Error < StandardError; end

  # Whether the host packs a native int in big-endian byte order. Computed
  # once at load time, because big_endian? is queried for every tensor that
  # is saved or loaded and the answer cannot change while the process runs.
  NATIVE_BIG_ENDIAN = [1].pack("i") == [1].pack("i!>")
  private_constant :NATIVE_BIG_ENDIAN

  # Serializes +tensor_dict+ by delegating to +_serialize+, which is defined
  # outside this file (presumably by the native extension - confirm).
  #
  # @param tensor_dict [Object] forwarded unchanged
  # @param metadata [Object, nil] forwarded as the second positional argument
  # @return [Object] whatever +_serialize+ returns
  def self.serialize(tensor_dict, metadata: nil)
    _serialize(tensor_dict, metadata)
  end

  # Delegates to +_serialize_file+, which is defined outside this file
  # (presumably by the native extension - confirm).
  #
  # @param tensor_dict [Object] forwarded unchanged
  # @param filename [Object] forwarded unchanged
  # @param metadata [Object, nil] forwarded as the third positional argument
  # @return [Object] whatever +_serialize_file+ returns
  def self.serialize_file(tensor_dict, filename, metadata: nil)
    _serialize_file(tensor_dict, filename, metadata)
  end

  # Builds a SafeOpen for +filename+.
  #
  # @param filename [Object] forwarded to SafeOpen.new
  # @param framework [Object] required; forwarded to SafeOpen.new
  # @param device [String] forwarded to SafeOpen.new
  # @yield [SafeOpen] the opened handle, when a block is given
  # @return [Object] the block's value when a block is given, otherwise the
  #   SafeOpen instance itself
  def self.safe_open(filename, framework:, device: "cpu")
    handle = SafeOpen.new(filename, framework, device)
    return handle unless block_given?

    yield handle
  end

  # Internal helper, but it has to stay public: Safetensors::Torch calls it
  # with an explicit receiver.
  #
  # @return [Boolean] true when the native int byte order is big-endian
  def self.big_endian?
    NATIVE_BIG_ENDIAN
  end
end
metadata ADDED
@@ -0,0 +1,69 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: safetensors
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Andrew Kane
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2024-02-25 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: rb_sys
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: '0'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: '0'
27
+ description:
28
+ email: andrew@ankane.org
29
+ executables: []
30
+ extensions:
31
+ - ext/safetensors/extconf.rb
32
+ extra_rdoc_files: []
33
+ files:
34
+ - CHANGELOG.md
35
+ - Cargo.lock
36
+ - Cargo.toml
37
+ - LICENSE.txt
38
+ - README.md
39
+ - ext/safetensors/Cargo.toml
40
+ - ext/safetensors/extconf.rb
41
+ - ext/safetensors/src/lib.rs
42
+ - lib/safetensors.rb
43
+ - lib/safetensors/numo.rb
44
+ - lib/safetensors/torch.rb
45
+ - lib/safetensors/version.rb
46
+ homepage: https://github.com/ankane/safetensors-ruby
47
+ licenses:
48
+ - Apache-2.0
49
+ metadata: {}
50
+ post_install_message:
51
+ rdoc_options: []
52
+ require_paths:
53
+ - lib
54
+ required_ruby_version: !ruby/object:Gem::Requirement
55
+ requirements:
56
+ - - ">="
57
+ - !ruby/object:Gem::Version
58
+ version: '3.1'
59
+ required_rubygems_version: !ruby/object:Gem::Requirement
60
+ requirements:
61
+ - - ">="
62
+ - !ruby/object:Gem::Version
63
+ version: '0'
64
+ requirements: []
65
+ rubygems_version: 3.5.3
66
+ signing_key:
67
+ specification_version: 4
68
+ summary: Simple, safe way to store and distribute tensors
69
+ test_files: []