torch-rb 0.16.0 → 0.17.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/ext/torch/fft_functions.h ADDED
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_fft_functions(Rice::Module& m);

data/ext/torch/linalg_functions.h ADDED
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_linalg_functions(Rice::Module& m);

data/ext/torch/nn_functions.h ADDED
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_nn_functions(Rice::Module& m);

data/ext/torch/sparse_functions.h ADDED
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_sparse_functions(Rice::Module& m);

data/ext/torch/special_functions.h ADDED
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_special_functions(Rice::Module& m);

data/ext/torch/tensor_functions.h ADDED
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_tensor_functions(Rice::Module& m);

data/ext/torch/torch_functions.h ADDED
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_torch_functions(Rice::Module& m);
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
  #include <rice/stl.hpp>

  static_assert(
-   TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 3,
+   TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 4,
    "Incompatible LibTorch version"
  );

data/lib/torch/nn/functional.rb CHANGED
@@ -134,7 +134,7 @@ module Torch
      raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim

      if mode == "constant"
-       return Torch.constant_pad_nd(input, pad, value)
+       Torch.constant_pad_nd(input, pad, value)
      else
        raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0

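The constant branch above delegates to Torch.constant_pad_nd (the change only drops a redundant `return`). A quick usage sketch, assuming these functional helpers are exposed through `Torch::NN::Functional` and that `pad` takes `mode:`/`value:` keywords as the surrounding code suggests:

```ruby
require "torch"

x = Torch.tensor([[1.0, 2.0, 3.0]])
# pad one element on the left and two on the right with zeros
# (routes through Torch.constant_pad_nd in the branch above)
Torch::NN::Functional.pad(x, [1, 2], mode: "constant", value: 0)
# => tensor([[0.0, 1.0, 2.0, 3.0, 0.0, 0.0]])
```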
@@ -481,6 +481,16 @@ module Torch
        Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
      end

+     def normalize(input, p: 2.0, dim: 1, eps: 1e-12, out: nil)
+       if out.nil?
+         denom = input.norm(p, dim, keepdim: true).clamp_min(eps).expand_as(input)
+         input / denom
+       else
+         denom = input.norm(p, dim, keepdim: true).clamp_min!(eps).expand_as(input)
+         Torch.div(input, denom, out: out)
+       end
+     end
+
      # vision

      def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
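The new `normalize` mirrors PyTorch's `F.normalize`: it divides `input` by its `p`-norm along `dim`, clamping the denominator at `eps` so zero vectors don't trigger a division by zero. A minimal sketch, assuming the method is reachable via `Torch::NN::Functional`:

```ruby
require "torch"

x = Torch.tensor([[3.0, 4.0], [6.0, 8.0]])
# scale each row (dim: 1) to unit L2 norm: [3, 4] / 5 and [6, 8] / 10
Torch::NN::Functional.normalize(x, p: 2.0, dim: 1)
# => tensor([[0.6, 0.8], [0.6, 0.8]])
```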
data/lib/torch/nn/functional_attention.rb CHANGED
@@ -5,10 +5,10 @@ module Torch
      def in_projection_packed(q, k, v, w, b: nil)
        e = q.size(-1)

-       if k.eql? v
-         if q.eql? k
+       if k.eql?(v)
+         if q.eql?(k)
            # self-attention
-           return linear(q, w, b).chunk(3, dim: -1)
+           linear(q, w, b).chunk(3, dim: -1)
          else
            # encoder-decoder attention
            w_q, w_kv = w.split_with_sizes([e, e * 2])
@@ -18,7 +18,7 @@ module Torch
            b_q, b_kv = b.split_with_sizes([e, e * 2])
          end

-         return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
+         [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
        end
      else
        w_q, w_k, w_v = w.chunk(3)
@@ -28,7 +28,7 @@ module Torch
          b_q, b_k, b_v = b.chunk(3)
        end

-       return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+       [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
      end
    end

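Behavior is unchanged here; the edits drop redundant `return`s and parenthesize the `eql?` calls. For context, the self-attention fast path multiplies once against the packed `[3e, e]` weight and chunks the result into the three projections. A standalone sketch of that path using only public tensor ops (shapes illustrative):

```ruby
require "torch"

e = 4
x = Torch.randn([2, 5, e])    # in self-attention, q, k, and v are the same tensor
w = Torch.randn([3 * e, e])   # packed q/k/v projection weights

# one matmul over the packed weights, then split into thirds
packed = x.matmul(w.t)                            # shape [2, 5, 12]
q_proj, k_proj, v_proj = packed.chunk(3, dim: -1) # three [2, 5, 4] tensors
```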
data/lib/torch/tensor.rb CHANGED
@@ -57,7 +57,7 @@ module Torch
      if shape.empty?
        arr
      else
-       shape[1..-1].reverse.each do |dim|
+       shape[1..-1].reverse_each do |dim|
          arr = arr.each_slice(dim)
        end
        arr.to_a
@@ -160,6 +160,7 @@ module Torch
      # based on python_variable_indexing.cpp and
      # https://pytorch.org/cppdocs/notes/tensor_indexing.html
      def [](*indexes)
+       indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
        _index(indexes)
      end

@@ -167,6 +168,7 @@ module Torch
      # https://pytorch.org/cppdocs/notes/tensor_indexing.html
      def []=(*indexes, value)
        raise ArgumentError, "Tensor does not support deleting items" if value.nil?
+       indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
        value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
        _index_put_custom(indexes, value)
      end
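With this change, plain Ruby arrays passed to `[]` and `[]=` are wrapped in `Torch.tensor`, so LibTorch-style integer-array (advanced) indexing works without constructing the index tensor yourself:

```ruby
require "torch"

x = Torch.tensor([10, 20, 30, 40])

x[[0, 2]]     # array converted to an index tensor => tensor([10, 30])

x[[1, 3]] = 0 # assignment through an array index
x             # => tensor([10, 0, 30, 0])
```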
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
  module Torch
-   VERSION = "0.16.0"
+   VERSION = "0.17.0"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torch-rb
  version: !ruby/object:Gem::Version
-   version: 0.16.0
+   version: 0.17.0
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2024-06-13 00:00:00.000000000 Z
+ date: 2024-07-26 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: rice
@@ -43,17 +43,24 @@ files:
  - ext/torch/ext.cpp
  - ext/torch/extconf.rb
  - ext/torch/fft.cpp
+ - ext/torch/fft_functions.h
  - ext/torch/generator.cpp
  - ext/torch/ivalue.cpp
  - ext/torch/linalg.cpp
+ - ext/torch/linalg_functions.h
  - ext/torch/nn.cpp
+ - ext/torch/nn_functions.h
  - ext/torch/random.cpp
  - ext/torch/ruby_arg_parser.cpp
  - ext/torch/ruby_arg_parser.h
+ - ext/torch/sparse_functions.h
  - ext/torch/special.cpp
+ - ext/torch/special_functions.h
  - ext/torch/templates.h
  - ext/torch/tensor.cpp
+ - ext/torch/tensor_functions.h
  - ext/torch/torch.cpp
+ - ext/torch/torch_functions.h
  - ext/torch/utils.h
  - ext/torch/wrap_outputs.h
  - lib/torch-rb.rb