torch-rb 0.16.0 → 0.17.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_fft_functions(Rice::Module& m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_linalg_functions(Rice::Module& m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_nn_functions(Rice::Module& m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_sparse_functions(Rice::Module& m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_special_functions(Rice::Module& m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_tensor_functions(Rice::Module& m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_torch_functions(Rice::Module& m);
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 3,
9
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 4,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
@@ -134,7 +134,7 @@ module Torch
134
134
  raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
135
135
 
136
136
  if mode == "constant"
137
- return Torch.constant_pad_nd(input, pad, value)
137
+ Torch.constant_pad_nd(input, pad, value)
138
138
  else
139
139
  raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0
140
140
 
@@ -481,6 +481,16 @@ module Torch
481
481
  Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
482
482
  end
483
483
 
484
+ def normalize(input, p: 2.0, dim: 1, eps: 1e-12, out: nil)
485
+ if out.nil?
486
+ denom = input.norm(p, dim, keepdim: true).clamp_min(eps).expand_as(input)
487
+ input / denom
488
+ else
489
+ denom = input.norm(p, dim, keepdim: true).clamp_min!(eps).expand_as(input)
490
+ Torch.div(input, denom, out: out)
491
+ end
492
+ end
493
+
484
494
  # vision
485
495
 
486
496
  def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
@@ -5,10 +5,10 @@ module Torch
5
5
  def in_projection_packed(q, k, v, w, b: nil)
6
6
  e = q.size(-1)
7
7
 
8
- if k.eql? v
9
- if q.eql? k
8
+ if k.eql?(v)
9
+ if q.eql?(k)
10
10
  # self-attention
11
- return linear(q, w, b).chunk(3, dim: -1)
11
+ linear(q, w, b).chunk(3, dim: -1)
12
12
  else
13
13
  # encoder-decoder attention
14
14
  w_q, w_kv = w.split_with_sizes([e, e * 2])
@@ -18,7 +18,7 @@ module Torch
18
18
  b_q, b_kv = b.split_with_sizes([e, e * 2])
19
19
  end
20
20
 
21
- return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
21
+ [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
22
22
  end
23
23
  else
24
24
  w_q, w_k, w_v = w.chunk(3)
@@ -28,7 +28,7 @@ module Torch
28
28
  b_q, b_k, b_v = b.chunk(3)
29
29
  end
30
30
 
31
- return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
31
+ [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
32
32
  end
33
33
  end
34
34
 
@@ -10,16 +10,23 @@ module Torch
10
10
  @parameters = {}
11
11
  @buffers = {}
12
12
  @modules = {}
13
+ @non_persistent_buffers_set = Set.new
13
14
  end
14
15
 
15
16
  def forward
16
17
  raise NotImplementedError
17
18
  end
18
19
 
19
- def register_buffer(name, tensor)
20
+ def register_buffer(name, tensor, persistent: true)
20
21
  # TODO add checks
21
22
  @buffers[name] = tensor
22
23
  instance_variable_set("@#{name}", tensor)
24
+
25
+ if persistent
26
+ @non_persistent_buffers_set.delete(name)
27
+ else
28
+ @non_persistent_buffers_set << name
29
+ end
23
30
  end
24
31
 
25
32
  def register_parameter(name, param)
@@ -190,8 +197,18 @@ module Torch
190
197
  named_buffers.values
191
198
  end
192
199
 
193
- def named_buffers
194
- @buffers || {}
200
+ # TODO set recurse: true in 0.18.0
201
+ def named_buffers(prefix: "", recurse: false)
202
+ buffers = {}
203
+ if recurse
204
+ named_children.each do |name, mod|
205
+ buffers.merge!(mod.named_buffers(prefix: "#{prefix}#{name}.", recurse: recurse))
206
+ end
207
+ end
208
+ (@buffers || {}).each do |k, v|
209
+ buffers[[prefix, k].join] = v
210
+ end
211
+ buffers
195
212
  end
196
213
 
197
214
  def children
@@ -390,7 +407,10 @@ module Torch
390
407
  destination[prefix + k] = v
391
408
  end
392
409
  named_buffers.each do |k, v|
393
- destination[prefix + k] = v
410
+ # TODO exclude v.nil?
411
+ if !@non_persistent_buffers_set.include?(k)
412
+ destination[prefix + k] = v
413
+ end
394
414
  end
395
415
  end
396
416
 
data/lib/torch/tensor.rb CHANGED
@@ -57,7 +57,7 @@ module Torch
57
57
  if shape.empty?
58
58
  arr
59
59
  else
60
- shape[1..-1].reverse.each do |dim|
60
+ shape[1..-1].reverse_each do |dim|
61
61
  arr = arr.each_slice(dim)
62
62
  end
63
63
  arr.to_a
@@ -132,9 +132,13 @@ module Torch
132
132
 
133
133
  # TODO read directly from memory
134
134
  def numo
135
- cls = Torch._dtype_to_numo[dtype]
136
- raise Error, "Cannot convert #{dtype} to Numo" unless cls
137
- cls.from_string(_data_str).reshape(*shape)
135
+ if dtype == :bool
136
+ Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape)
137
+ else
138
+ cls = Torch._dtype_to_numo[dtype]
139
+ raise Error, "Cannot convert #{dtype} to Numo" unless cls
140
+ cls.from_string(_data_str).reshape(*shape)
141
+ end
138
142
  end
139
143
 
140
144
  def requires_grad=(requires_grad)
@@ -160,6 +164,7 @@ module Torch
160
164
  # based on python_variable_indexing.cpp and
161
165
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
162
166
  def [](*indexes)
167
+ indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
163
168
  _index(indexes)
164
169
  end
165
170
 
@@ -167,6 +172,7 @@ module Torch
167
172
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
168
173
  def []=(*indexes, value)
169
174
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
175
+ indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
170
176
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
171
177
  _index_put_custom(indexes, value)
172
178
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.16.0"
2
+ VERSION = "0.17.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.16.0
4
+ version: 0.17.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2024-06-13 00:00:00.000000000 Z
11
+ date: 2024-08-19 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 4.1.0
19
+ version: '4.1'
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 4.1.0
26
+ version: '4.1'
27
27
  description:
28
28
  email: andrew@ankane.org
29
29
  executables: []
@@ -43,17 +43,24 @@ files:
43
43
  - ext/torch/ext.cpp
44
44
  - ext/torch/extconf.rb
45
45
  - ext/torch/fft.cpp
46
+ - ext/torch/fft_functions.h
46
47
  - ext/torch/generator.cpp
47
48
  - ext/torch/ivalue.cpp
48
49
  - ext/torch/linalg.cpp
50
+ - ext/torch/linalg_functions.h
49
51
  - ext/torch/nn.cpp
52
+ - ext/torch/nn_functions.h
50
53
  - ext/torch/random.cpp
51
54
  - ext/torch/ruby_arg_parser.cpp
52
55
  - ext/torch/ruby_arg_parser.h
56
+ - ext/torch/sparse_functions.h
53
57
  - ext/torch/special.cpp
58
+ - ext/torch/special_functions.h
54
59
  - ext/torch/templates.h
55
60
  - ext/torch/tensor.cpp
61
+ - ext/torch/tensor_functions.h
56
62
  - ext/torch/torch.cpp
63
+ - ext/torch/torch_functions.h
57
64
  - ext/torch/utils.h
58
65
  - ext/torch/wrap_outputs.h
59
66
  - lib/torch-rb.rb