torch-rb 0.17.0 → 0.17.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: cd28e938ac0d61829d538f4a36e9f3377f06257ab08fd3ec8d8e3912194a101a
4
- data.tar.gz: 7bbb0c98e783c4a02e4ca46a1835c591fff3630c84d1a44abf1ef8df75c45de7
3
+ metadata.gz: f04aeddd9c66d668f78ed38d831ec66907743fd5be2cd076e9e547d00f1f2aa3
4
+ data.tar.gz: e76ca5199c4801c0eaa760ed3e41ebb6e521774bba084809656b1ed9e4b96483
5
5
  SHA512:
6
- metadata.gz: 38415fe461d01d11d620d2072012ea6a86ad000e5d25920af3d2d4cd05c3b088479cb60a2bd2b6dc013da51e3f744913b3d734fc949ac2708b825c3c041f8c52
7
- data.tar.gz: 2c54b30ccfa91f9b7576f7927377c03a9d82b7ff8aaf561e92ad89b2b00a2e7e3e08c268d606e6609cbaa13d2dc2e5e11c8329bf1b3475af7025a21dc8a1921f
6
+ metadata.gz: 683246aa91d19daa624a9d3a89fa5e4898f19fbb45a72ff2bcac0ee3025859f0f8ad26ac153e58150aedf10547cf670e2fe131cffda246546ded1f1344637040
7
+ data.tar.gz: bd0b955b997f09a274d8b59a21eafa8bf45972b6a20883b9d457a29417aa2c6d413b324abadfb6772c5bdc3aa0e633d3e25d6f23dd986fa603c55c69ae4a5a85
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.17.1 (2024-08-19)
2
+
3
+ - Added `persistent` option to `register_buffer` method
4
+ - Added `prefix` and `recurse` options to `named_buffers` method
5
+
1
6
  ## 0.17.0 (2024-07-26)
2
7
 
3
8
  - Updated LibTorch to 2.4.0
@@ -10,16 +10,23 @@ module Torch
10
10
  @parameters = {}
11
11
  @buffers = {}
12
12
  @modules = {}
13
+ @non_persistent_buffers_set = Set.new
13
14
  end
14
15
 
15
16
  def forward
16
17
  raise NotImplementedError
17
18
  end
18
19
 
19
- def register_buffer(name, tensor)
20
+ def register_buffer(name, tensor, persistent: true)
20
21
  # TODO add checks
21
22
  @buffers[name] = tensor
22
23
  instance_variable_set("@#{name}", tensor)
24
+
25
+ if persistent
26
+ @non_persistent_buffers_set.delete(name)
27
+ else
28
+ @non_persistent_buffers_set << name
29
+ end
23
30
  end
24
31
 
25
32
  def register_parameter(name, param)
@@ -190,8 +197,18 @@ module Torch
190
197
  named_buffers.values
191
198
  end
192
199
 
193
- def named_buffers
194
- @buffers || {}
200
+ # TODO set recurse: true in 0.18.0
201
+ def named_buffers(prefix: "", recurse: false)
202
+ buffers = {}
203
+ if recurse
204
+ named_children.each do |name, mod|
205
+ buffers.merge!(mod.named_buffers(prefix: "#{prefix}#{name}.", recurse: recurse))
206
+ end
207
+ end
208
+ (@buffers || {}).each do |k, v|
209
+ buffers[[prefix, k].join] = v
210
+ end
211
+ buffers
195
212
  end
196
213
 
197
214
  def children
@@ -390,7 +407,10 @@ module Torch
390
407
  destination[prefix + k] = v
391
408
  end
392
409
  named_buffers.each do |k, v|
393
- destination[prefix + k] = v
410
+ # TODO exclude v.nil?
411
+ if !@non_persistent_buffers_set.include?(k)
412
+ destination[prefix + k] = v
413
+ end
394
414
  end
395
415
  end
396
416
 
data/lib/torch/tensor.rb CHANGED
@@ -132,9 +132,13 @@ module Torch
132
132
 
133
133
  # TODO read directly from memory
134
134
  def numo
135
- cls = Torch._dtype_to_numo[dtype]
136
- raise Error, "Cannot convert #{dtype} to Numo" unless cls
137
- cls.from_string(_data_str).reshape(*shape)
135
+ if dtype == :bool
136
+ Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape)
137
+ else
138
+ cls = Torch._dtype_to_numo[dtype]
139
+ raise Error, "Cannot convert #{dtype} to Numo" unless cls
140
+ cls.from_string(_data_str).reshape(*shape)
141
+ end
138
142
  end
139
143
 
140
144
  def requires_grad=(requires_grad)
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.17.0"
2
+ VERSION = "0.17.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.17.0
4
+ version: 0.17.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2024-07-26 00:00:00.000000000 Z
11
+ date: 2024-08-19 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 4.1.0
19
+ version: '4.1'
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 4.1.0
26
+ version: '4.1'
27
27
  description:
28
28
  email: andrew@ankane.org
29
29
  executables: []