torch-rb 0.17.0 → 0.17.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: cd28e938ac0d61829d538f4a36e9f3377f06257ab08fd3ec8d8e3912194a101a
4
- data.tar.gz: 7bbb0c98e783c4a02e4ca46a1835c591fff3630c84d1a44abf1ef8df75c45de7
3
+ metadata.gz: f04aeddd9c66d668f78ed38d831ec66907743fd5be2cd076e9e547d00f1f2aa3
4
+ data.tar.gz: e76ca5199c4801c0eaa760ed3e41ebb6e521774bba084809656b1ed9e4b96483
5
5
  SHA512:
6
- metadata.gz: 38415fe461d01d11d620d2072012ea6a86ad000e5d25920af3d2d4cd05c3b088479cb60a2bd2b6dc013da51e3f744913b3d734fc949ac2708b825c3c041f8c52
7
- data.tar.gz: 2c54b30ccfa91f9b7576f7927377c03a9d82b7ff8aaf561e92ad89b2b00a2e7e3e08c268d606e6609cbaa13d2dc2e5e11c8329bf1b3475af7025a21dc8a1921f
6
+ metadata.gz: 683246aa91d19daa624a9d3a89fa5e4898f19fbb45a72ff2bcac0ee3025859f0f8ad26ac153e58150aedf10547cf670e2fe131cffda246546ded1f1344637040
7
+ data.tar.gz: bd0b955b997f09a274d8b59a21eafa8bf45972b6a20883b9d457a29417aa2c6d413b324abadfb6772c5bdc3aa0e633d3e25d6f23dd986fa603c55c69ae4a5a85
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.17.1 (2024-08-19)
2
+
3
+ - Added `persistent` option to `register_buffer` method
4
+ - Added `prefix` and `recurse` options to `named_buffers` method
5
+
1
6
  ## 0.17.0 (2024-07-26)
2
7
 
3
8
  - Updated LibTorch to 2.4.0
@@ -10,16 +10,23 @@ module Torch
10
10
  @parameters = {}
11
11
  @buffers = {}
12
12
  @modules = {}
13
+ @non_persistent_buffers_set = Set.new
13
14
  end
14
15
 
15
16
  def forward
16
17
  raise NotImplementedError
17
18
  end
18
19
 
19
- def register_buffer(name, tensor)
20
+ def register_buffer(name, tensor, persistent: true)
20
21
  # TODO add checks
21
22
  @buffers[name] = tensor
22
23
  instance_variable_set("@#{name}", tensor)
24
+
25
+ if persistent
26
+ @non_persistent_buffers_set.delete(name)
27
+ else
28
+ @non_persistent_buffers_set << name
29
+ end
23
30
  end
24
31
 
25
32
  def register_parameter(name, param)
@@ -190,8 +197,18 @@ module Torch
190
197
  named_buffers.values
191
198
  end
192
199
 
193
- def named_buffers
194
- @buffers || {}
200
+ # TODO set recurse: true in 0.18.0
201
+ def named_buffers(prefix: "", recurse: false)
202
+ buffers = {}
203
+ if recurse
204
+ named_children.each do |name, mod|
205
+ buffers.merge!(mod.named_buffers(prefix: "#{prefix}#{name}.", recurse: recurse))
206
+ end
207
+ end
208
+ (@buffers || {}).each do |k, v|
209
+ buffers[[prefix, k].join] = v
210
+ end
211
+ buffers
195
212
  end
196
213
 
197
214
  def children
@@ -390,7 +407,10 @@ module Torch
390
407
  destination[prefix + k] = v
391
408
  end
392
409
  named_buffers.each do |k, v|
393
- destination[prefix + k] = v
410
+ # TODO exclude v.nil?
411
+ if !@non_persistent_buffers_set.include?(k)
412
+ destination[prefix + k] = v
413
+ end
394
414
  end
395
415
  end
396
416
 
data/lib/torch/tensor.rb CHANGED
@@ -132,9 +132,13 @@ module Torch
132
132
 
133
133
  # TODO read directly from memory
134
134
  def numo
135
- cls = Torch._dtype_to_numo[dtype]
136
- raise Error, "Cannot convert #{dtype} to Numo" unless cls
137
- cls.from_string(_data_str).reshape(*shape)
135
+ if dtype == :bool
136
+ Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape)
137
+ else
138
+ cls = Torch._dtype_to_numo[dtype]
139
+ raise Error, "Cannot convert #{dtype} to Numo" unless cls
140
+ cls.from_string(_data_str).reshape(*shape)
141
+ end
138
142
  end
139
143
 
140
144
  def requires_grad=(requires_grad)
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.17.0"
2
+ VERSION = "0.17.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.17.0
4
+ version: 0.17.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2024-07-26 00:00:00.000000000 Z
11
+ date: 2024-08-19 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 4.1.0
19
+ version: '4.1'
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 4.1.0
26
+ version: '4.1'
27
27
  description:
28
28
  email: andrew@ankane.org
29
29
  executables: []