torch-rb 0.17.0 → 0.18.0

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 4,
9
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 5,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
@@ -10,16 +10,23 @@ module Torch
10
10
  @parameters = {}
11
11
  @buffers = {}
12
12
  @modules = {}
13
+ @non_persistent_buffers_set = Set.new
13
14
  end
14
15
 
15
16
  def forward
16
17
  raise NotImplementedError
17
18
  end
18
19
 
19
- def register_buffer(name, tensor)
20
+ def register_buffer(name, tensor, persistent: true)
20
21
  # TODO add checks
21
22
  @buffers[name] = tensor
22
23
  instance_variable_set("@#{name}", tensor)
24
+
25
+ if persistent
26
+ @non_persistent_buffers_set.delete(name)
27
+ else
28
+ @non_persistent_buffers_set << name
29
+ end
23
30
  end
24
31
 
25
32
  def register_parameter(name, param)
@@ -190,8 +197,18 @@ module Torch
190
197
  named_buffers.values
191
198
  end
192
199
 
193
- def named_buffers
194
- @buffers || {}
200
+ # TODO set recurse: true in 0.18.0
201
+ def named_buffers(prefix: "", recurse: false)
202
+ buffers = {}
203
+ if recurse
204
+ named_children.each do |name, mod|
205
+ buffers.merge!(mod.named_buffers(prefix: "#{prefix}#{name}.", recurse: recurse))
206
+ end
207
+ end
208
+ (@buffers || {}).each do |k, v|
209
+ buffers[[prefix, k].join] = v
210
+ end
211
+ buffers
195
212
  end
196
213
 
197
214
  def children
@@ -390,7 +407,10 @@ module Torch
390
407
  destination[prefix + k] = v
391
408
  end
392
409
  named_buffers.each do |k, v|
393
- destination[prefix + k] = v
410
+ # TODO exclude v.nil?
411
+ if !@non_persistent_buffers_set.include?(k)
412
+ destination[prefix + k] = v
413
+ end
394
414
  end
395
415
  end
396
416
 
data/lib/torch/tensor.rb CHANGED
@@ -24,6 +24,7 @@ module Torch
24
24
  alias_method :^, :logical_xor
25
25
  alias_method :<<, :__lshift__
26
26
  alias_method :>>, :__rshift__
27
+ alias_method :~, :bitwise_not
27
28
 
28
29
  def self.new(*args)
29
30
  FloatTensor.new(*args)
@@ -132,9 +133,13 @@ module Torch
132
133
 
133
134
  # TODO read directly from memory
134
135
  def numo
135
- cls = Torch._dtype_to_numo[dtype]
136
- raise Error, "Cannot convert #{dtype} to Numo" unless cls
137
- cls.from_string(_data_str).reshape(*shape)
136
+ if dtype == :bool
137
+ Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape)
138
+ else
139
+ cls = Torch._dtype_to_numo[dtype]
140
+ raise Error, "Cannot convert #{dtype} to Numo" unless cls
141
+ cls.from_string(_data_str).reshape(*shape)
142
+ end
138
143
  end
139
144
 
140
145
  def requires_grad=(requires_grad)
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.17.0"
2
+ VERSION = "0.18.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.17.0
4
+ version: 0.18.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2024-07-26 00:00:00.000000000 Z
11
+ date: 2024-10-22 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 4.1.0
19
+ version: 4.3.3
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 4.1.0
26
+ version: 4.3.3
27
27
  description:
28
28
  email: andrew@ankane.org
29
29
  executables: []
@@ -239,7 +239,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
239
239
  - !ruby/object:Gem::Version
240
240
  version: '0'
241
241
  requirements: []
242
- rubygems_version: 3.5.11
242
+ rubygems_version: 3.5.16
243
243
  signing_key:
244
244
  specification_version: 4
245
245
  summary: Deep learning for Ruby, powered by LibTorch