torch-rb 0.17.0 → 0.17.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/lib/torch/nn/module.rb +24 -4
- data/lib/torch/tensor.rb +7 -3
- data/lib/torch/version.rb +1 -1
- metadata +4 -4
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: f04aeddd9c66d668f78ed38d831ec66907743fd5be2cd076e9e547d00f1f2aa3
|
4
|
+
data.tar.gz: e76ca5199c4801c0eaa760ed3e41ebb6e521774bba084809656b1ed9e4b96483
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 683246aa91d19daa624a9d3a89fa5e4898f19fbb45a72ff2bcac0ee3025859f0f8ad26ac153e58150aedf10547cf670e2fe131cffda246546ded1f1344637040
|
7
|
+
data.tar.gz: bd0b955b997f09a274d8b59a21eafa8bf45972b6a20883b9d457a29417aa2c6d413b324abadfb6772c5bdc3aa0e633d3e25d6f23dd986fa603c55c69ae4a5a85
|
data/CHANGELOG.md
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -10,16 +10,23 @@ module Torch
|
|
10
10
|
@parameters = {}
|
11
11
|
@buffers = {}
|
12
12
|
@modules = {}
|
13
|
+
@non_persistent_buffers_set = Set.new
|
13
14
|
end
|
14
15
|
|
15
16
|
def forward
|
16
17
|
raise NotImplementedError
|
17
18
|
end
|
18
19
|
|
19
|
-
def register_buffer(name, tensor)
|
20
|
+
def register_buffer(name, tensor, persistent: true)
|
20
21
|
# TODO add checks
|
21
22
|
@buffers[name] = tensor
|
22
23
|
instance_variable_set("@#{name}", tensor)
|
24
|
+
|
25
|
+
if persistent
|
26
|
+
@non_persistent_buffers_set.delete(name)
|
27
|
+
else
|
28
|
+
@non_persistent_buffers_set << name
|
29
|
+
end
|
23
30
|
end
|
24
31
|
|
25
32
|
def register_parameter(name, param)
|
@@ -190,8 +197,18 @@ module Torch
|
|
190
197
|
named_buffers.values
|
191
198
|
end
|
192
199
|
|
193
|
-
|
194
|
-
|
200
|
+
# TODO set recurse: true in 0.18.0
|
201
|
+
def named_buffers(prefix: "", recurse: false)
|
202
|
+
buffers = {}
|
203
|
+
if recurse
|
204
|
+
named_children.each do |name, mod|
|
205
|
+
buffers.merge!(mod.named_buffers(prefix: "#{prefix}#{name}.", recurse: recurse))
|
206
|
+
end
|
207
|
+
end
|
208
|
+
(@buffers || {}).each do |k, v|
|
209
|
+
buffers[[prefix, k].join] = v
|
210
|
+
end
|
211
|
+
buffers
|
195
212
|
end
|
196
213
|
|
197
214
|
def children
|
@@ -390,7 +407,10 @@ module Torch
|
|
390
407
|
destination[prefix + k] = v
|
391
408
|
end
|
392
409
|
named_buffers.each do |k, v|
|
393
|
-
|
410
|
+
# TODO exclude v.nil?
|
411
|
+
if !@non_persistent_buffers_set.include?(k)
|
412
|
+
destination[prefix + k] = v
|
413
|
+
end
|
394
414
|
end
|
395
415
|
end
|
396
416
|
|
data/lib/torch/tensor.rb
CHANGED
@@ -132,9 +132,13 @@ module Torch
|
|
132
132
|
|
133
133
|
# TODO read directly from memory
|
134
134
|
def numo
|
135
|
-
|
136
|
-
|
137
|
-
|
135
|
+
if dtype == :bool
|
136
|
+
Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape)
|
137
|
+
else
|
138
|
+
cls = Torch._dtype_to_numo[dtype]
|
139
|
+
raise Error, "Cannot convert #{dtype} to Numo" unless cls
|
140
|
+
cls.from_string(_data_str).reshape(*shape)
|
141
|
+
end
|
138
142
|
end
|
139
143
|
|
140
144
|
def requires_grad=(requires_grad)
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.17.0
|
4
|
+
version: 0.17.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-
|
11
|
+
date: 2024-08-19 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version: 4.1
|
19
|
+
version: '4.1'
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version: 4.1
|
26
|
+
version: '4.1'
|
27
27
|
description:
|
28
28
|
email: andrew@ankane.org
|
29
29
|
executables: []
|