torch-rb 0.17.0 → 0.18.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +14 -18
- data/codegen/native_functions.yaml +230 -151
- data/ext/torch/utils.h +1 -1
- data/lib/torch/nn/module.rb +24 -4
- data/lib/torch/tensor.rb +8 -3
- data/lib/torch/version.rb +1 -1
- metadata +5 -5
data/ext/torch/utils.h
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -10,16 +10,23 @@ module Torch
|
|
10
10
|
@parameters = {}
|
11
11
|
@buffers = {}
|
12
12
|
@modules = {}
|
13
|
+
@non_persistent_buffers_set = Set.new
|
13
14
|
end
|
14
15
|
|
15
16
|
def forward
|
16
17
|
raise NotImplementedError
|
17
18
|
end
|
18
19
|
|
19
|
-
def register_buffer(name, tensor)
|
20
|
+
def register_buffer(name, tensor, persistent: true)
|
20
21
|
# TODO add checks
|
21
22
|
@buffers[name] = tensor
|
22
23
|
instance_variable_set("@#{name}", tensor)
|
24
|
+
|
25
|
+
if persistent
|
26
|
+
@non_persistent_buffers_set.delete(name)
|
27
|
+
else
|
28
|
+
@non_persistent_buffers_set << name
|
29
|
+
end
|
23
30
|
end
|
24
31
|
|
25
32
|
def register_parameter(name, param)
|
@@ -190,8 +197,18 @@ module Torch
|
|
190
197
|
named_buffers.values
|
191
198
|
end
|
192
199
|
|
193
|
-
|
194
|
-
|
200
|
+
# TODO set recurse: true in 0.18.0
|
201
|
+
def named_buffers(prefix: "", recurse: false)
|
202
|
+
buffers = {}
|
203
|
+
if recurse
|
204
|
+
named_children.each do |name, mod|
|
205
|
+
buffers.merge!(mod.named_buffers(prefix: "#{prefix}#{name}.", recurse: recurse))
|
206
|
+
end
|
207
|
+
end
|
208
|
+
(@buffers || {}).each do |k, v|
|
209
|
+
buffers[[prefix, k].join] = v
|
210
|
+
end
|
211
|
+
buffers
|
195
212
|
end
|
196
213
|
|
197
214
|
def children
|
@@ -390,7 +407,10 @@ module Torch
|
|
390
407
|
destination[prefix + k] = v
|
391
408
|
end
|
392
409
|
named_buffers.each do |k, v|
|
393
|
-
|
410
|
+
# TODO exclude v.nil?
|
411
|
+
if !@non_persistent_buffers_set.include?(k)
|
412
|
+
destination[prefix + k] = v
|
413
|
+
end
|
394
414
|
end
|
395
415
|
end
|
396
416
|
|
data/lib/torch/tensor.rb
CHANGED
@@ -24,6 +24,7 @@ module Torch
|
|
24
24
|
alias_method :^, :logical_xor
|
25
25
|
alias_method :<<, :__lshift__
|
26
26
|
alias_method :>>, :__rshift__
|
27
|
+
alias_method :~, :bitwise_not
|
27
28
|
|
28
29
|
def self.new(*args)
|
29
30
|
FloatTensor.new(*args)
|
@@ -132,9 +133,13 @@ module Torch
|
|
132
133
|
|
133
134
|
# TODO read directly from memory
|
134
135
|
def numo
|
135
|
-
|
136
|
-
|
137
|
-
|
136
|
+
if dtype == :bool
|
137
|
+
Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape)
|
138
|
+
else
|
139
|
+
cls = Torch._dtype_to_numo[dtype]
|
140
|
+
raise Error, "Cannot convert #{dtype} to Numo" unless cls
|
141
|
+
cls.from_string(_data_str).reshape(*shape)
|
142
|
+
end
|
138
143
|
end
|
139
144
|
|
140
145
|
def requires_grad=(requires_grad)
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.17.0
|
4
|
+
version: 0.18.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-
|
11
|
+
date: 2024-10-22 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version: 4.
|
19
|
+
version: 4.3.3
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version: 4.
|
26
|
+
version: 4.3.3
|
27
27
|
description:
|
28
28
|
email: andrew@ankane.org
|
29
29
|
executables: []
|
@@ -239,7 +239,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
239
239
|
- !ruby/object:Gem::Version
|
240
240
|
version: '0'
|
241
241
|
requirements: []
|
242
|
-
rubygems_version: 3.5.
|
242
|
+
rubygems_version: 3.5.16
|
243
243
|
signing_key:
|
244
244
|
specification_version: 4
|
245
245
|
summary: Deep learning for Ruby, powered by LibTorch
|