torch-rb 0.15.0 → 0.17.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 2,
9
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 4,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
@@ -0,0 +1,20 @@
1
+ module Torch
2
+ module NN
3
+ class ELU < Module
4
+ def initialize(alpha: 1, inplace: false)
5
+ super()
6
+ @alpha = alpha
7
+ @inplace = inplace
8
+ end
9
+
10
+ def forward(input)
11
+ F.elu(input, alpha: @alpha, inplace: @inplace)
12
+ end
13
+
14
+ def extra_inspect
15
+ inplace_str = @inplace ? ", inplace: true" : ""
16
+ format("alpha: %s", @alpha) + inplace_str
17
+ end
18
+ end
19
+ end
20
+ end
@@ -134,7 +134,7 @@ module Torch
134
134
  raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
135
135
 
136
136
  if mode == "constant"
137
- return Torch.constant_pad_nd(input, pad, value)
137
+ Torch.constant_pad_nd(input, pad, value)
138
138
  else
139
139
  raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0
140
140
 
@@ -174,6 +174,18 @@ module Torch
174
174
 
175
175
  # activation layers
176
176
 
177
+ def elu(input, alpha: 1, inplace: false)
178
+ if inplace
179
+ NN.elu!(input, alpha)
180
+ else
181
+ NN.elu(input, alpha)
182
+ end
183
+ end
184
+
185
+ def gelu(input, approximate: 'none')
186
+ NN.gelu(input, approximate: approximate)
187
+ end
188
+
177
189
  def hardshrink(input, lambd = 0.5)
178
190
  Torch.hardshrink(input, lambd)
179
191
  end
@@ -469,6 +481,16 @@ module Torch
469
481
  Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
470
482
  end
471
483
 
484
+ def normalize(input, p: 2.0, dim: 1, eps: 1e-12, out: nil)
485
+ if out.nil?
486
+ denom = input.norm(p, dim, keepdim: true).clamp_min(eps).expand_as(input)
487
+ input / denom
488
+ else
489
+ denom = input.norm(p, dim, keepdim: true).clamp_min!(eps).expand_as(input)
490
+ Torch.div(input, denom, out: out)
491
+ end
492
+ end
493
+
472
494
  # vision
473
495
 
474
496
  def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
@@ -5,10 +5,10 @@ module Torch
5
5
  def in_projection_packed(q, k, v, w, b: nil)
6
6
  e = q.size(-1)
7
7
 
8
- if k.eql? v
9
- if q.eql? k
8
+ if k.eql?(v)
9
+ if q.eql?(k)
10
10
  # self-attention
11
- return linear(q, w, b).chunk(3, dim: -1)
11
+ linear(q, w, b).chunk(3, dim: -1)
12
12
  else
13
13
  # encoder-decoder attention
14
14
  w_q, w_kv = w.split_with_sizes([e, e * 2])
@@ -18,7 +18,7 @@ module Torch
18
18
  b_q, b_kv = b.split_with_sizes([e, e * 2])
19
19
  end
20
20
 
21
- return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
21
+ [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
22
22
  end
23
23
  else
24
24
  w_q, w_k, w_v = w.chunk(3)
@@ -28,7 +28,7 @@ module Torch
28
28
  b_q, b_k, b_v = b.chunk(3)
29
29
  end
30
30
 
31
- return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
31
+ [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
32
32
  end
33
33
  end
34
34
 
@@ -0,0 +1,18 @@
1
+ module Torch
2
+ module NN
3
+ class GELU < Module
4
+ def initialize(approximate: 'none')
5
+ super()
6
+ @approximate = approximate
7
+ end
8
+
9
+ def forward(input)
10
+ F.gelu(input, approximate: @approximate)
11
+ end
12
+
13
+ def extra_inspect
14
+ "approximate: #{@approximate.inspect}"
15
+ end
16
+ end
17
+ end
18
+ end
@@ -13,7 +13,7 @@ module Torch
13
13
 
14
14
  def extra_inspect
15
15
  inplace_str = @inplace ? ", inplace: true" : ""
16
- format("negative_slope: %s%s", @negative_slope, inplace_str)
16
+ format("negative_slope: %s", @negative_slope) + inplace_str
17
17
  end
18
18
  end
19
19
  end
data/lib/torch/tensor.rb CHANGED
@@ -57,7 +57,7 @@ module Torch
57
57
  if shape.empty?
58
58
  arr
59
59
  else
60
- shape[1..-1].reverse.each do |dim|
60
+ shape[1..-1].reverse_each do |dim|
61
61
  arr = arr.each_slice(dim)
62
62
  end
63
63
  arr.to_a
@@ -160,6 +160,7 @@ module Torch
160
160
  # based on python_variable_indexing.cpp and
161
161
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
162
162
  def [](*indexes)
163
+ indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
163
164
  _index(indexes)
164
165
  end
165
166
 
@@ -167,6 +168,7 @@ module Torch
167
168
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
168
169
  def []=(*indexes, value)
169
170
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
171
+ indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
170
172
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
171
173
  _index_put_custom(indexes, value)
172
174
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.15.0"
2
+ VERSION = "0.17.0"
3
3
  end
data/lib/torch.rb CHANGED
@@ -123,6 +123,8 @@ require_relative "torch/nn/dropout3d"
123
123
  require_relative "torch/nn/feature_alpha_dropout"
124
124
 
125
125
  # nn activations
126
+ require_relative "torch/nn/elu"
127
+ require_relative "torch/nn/gelu"
126
128
  require_relative "torch/nn/hardshrink"
127
129
  require_relative "torch/nn/leaky_relu"
128
130
  require_relative "torch/nn/log_sigmoid"
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.15.0
4
+ version: 0.17.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2024-02-29 00:00:00.000000000 Z
11
+ date: 2024-07-26 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -103,12 +103,14 @@ files:
103
103
  - lib/torch/nn/dropout2d.rb
104
104
  - lib/torch/nn/dropout3d.rb
105
105
  - lib/torch/nn/dropoutnd.rb
106
+ - lib/torch/nn/elu.rb
106
107
  - lib/torch/nn/embedding.rb
107
108
  - lib/torch/nn/embedding_bag.rb
108
109
  - lib/torch/nn/feature_alpha_dropout.rb
109
110
  - lib/torch/nn/fold.rb
110
111
  - lib/torch/nn/functional.rb
111
112
  - lib/torch/nn/functional_attention.rb
113
+ - lib/torch/nn/gelu.rb
112
114
  - lib/torch/nn/group_norm.rb
113
115
  - lib/torch/nn/gru.rb
114
116
  - lib/torch/nn/hardshrink.rb
@@ -230,14 +232,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
230
232
  requirements:
231
233
  - - ">="
232
234
  - !ruby/object:Gem::Version
233
- version: '3'
235
+ version: '3.1'
234
236
  required_rubygems_version: !ruby/object:Gem::Requirement
235
237
  requirements:
236
238
  - - ">="
237
239
  - !ruby/object:Gem::Version
238
240
  version: '0'
239
241
  requirements: []
240
- rubygems_version: 3.5.3
242
+ rubygems_version: 3.5.11
241
243
  signing_key:
242
244
  specification_version: 4
243
245
  summary: Deep learning for Ruby, powered by LibTorch