torch-rb 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/README.md +5 -3
  4. data/ext/torch/ext.cpp +22 -548
  5. data/ext/torch/extconf.rb +6 -0
  6. data/ext/torch/nn_functions.cpp +595 -0
  7. data/ext/torch/nn_functions.hpp +6 -0
  8. data/ext/torch/templates.hpp +250 -0
  9. data/ext/torch/tensor_functions.cpp +1860 -0
  10. data/ext/torch/tensor_functions.hpp +6 -0
  11. data/ext/torch/torch_functions.cpp +2875 -0
  12. data/ext/torch/torch_functions.hpp +6 -0
  13. data/lib/torch.rb +68 -129
  14. data/lib/torch/ext.bundle +0 -0
  15. data/lib/torch/native/dispatcher.rb +48 -0
  16. data/lib/torch/native/function.rb +78 -0
  17. data/lib/torch/native/generator.rb +149 -0
  18. data/lib/torch/native/native_functions.yaml +6837 -0
  19. data/lib/torch/native/parser.rb +97 -0
  20. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  21. data/lib/torch/nn/conv2d.rb +0 -2
  22. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  23. data/lib/torch/nn/functional.rb +55 -16
  24. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  25. data/lib/torch/nn/identity.rb +1 -0
  26. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  27. data/lib/torch/nn/module.rb +59 -12
  28. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  29. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  30. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  31. data/lib/torch/nn/parameter.rb +4 -0
  32. data/lib/torch/nn/rnn.rb +22 -0
  33. data/lib/torch/nn/rnn_base.rb +154 -0
  34. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  35. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  36. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  37. data/lib/torch/tensor.rb +19 -19
  38. data/lib/torch/version.rb +1 -1
  39. metadata +26 -2
data/lib/torch/native/parser.rb
@@ -0,0 +1,97 @@
+module Torch
+  module Native
+    class Parser
+      def initialize(functions)
+        @functions = functions
+        @name = @functions.first.ruby_name
+        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && a[:default].nil? } }.min
+        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
+      end
+
+      def parse(args, options)
+        candidates = @functions.dup
+
+        if args.size < @min_args || args.size > @max_args
+          expected = String.new(@min_args.to_s)
+          expected += "..#{@max_args}" if @max_args != @min_args
+          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
+        end
+
+        # exclude functions where options don't match
+        options.each do |k, v|
+          candidates.select! do |func|
+            func.args.any? { |a| a[:name] == k.to_s }
+          end
+          # TODO show all bad keywords at once like Ruby?
+          return {error: "unknown keyword: #{k}"} if candidates.empty?
+        end
+
+        # exclude functions missing required options
+        candidates.reject! do |func|
+          # TODO make more generic
+          func.out? && !options[:out]
+        end
+
+        final_values = {}
+
+        # check args
+        candidates.select! do |func|
+          good = true
+
+          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
+          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
+          func.args.each do |fa|
+            values[fa[:name]] ||= fa[:default]
+          end
+
+          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
+
+          values.each do |k, v|
+            t = arg_types[k].split("(").first
+            good =
+              case t
+              when "Tensor"
+                v.is_a?(Tensor)
+              when "Tensor[]"
+                v.all? { |v2| v2.is_a?(Tensor) }
+              when "int"
+                v.is_a?(Integer)
+              when "int[]"
+                v.all? { |v2| v2.is_a?(Integer) }
+              when "Scalar"
+                v.is_a?(Numeric)
+              when "bool"
+                v == true || v == false
+              else
+                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}"
+              end
+
+            if !good
+              if candidates.size == 1
+                k = "input" if k == "self"
+                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
+              end
+              break
+            end
+          end
+
+          if good
+            final_values = values
+          end
+
+          good
+        end
+
+        if candidates.size != 1
+          raise Error, "This should never happen. Please report a bug with #{@name}."
+        end
+
+        func = candidates.first
+        {
+          name: func.cpp_name,
+          args: func.args.map { |a| final_values[a[:name]] }
+        }
+      end
+    end
+  end
+end
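For context, the parser resolves which generated overload a Ruby call should dispatch to. A minimal sketch of how it behaves, using a hand-rolled stand-in for the Function objects that Torch::Native::Generator normally builds from native_functions.yaml (the StubFunction struct and the "add_scalar" cpp name are made up for illustration):

    require "torch"  # torch-rb 0.1.5

    # Stand-in exposing only the methods Parser calls on a Function.
    StubFunction = Struct.new(:ruby_name, :cpp_name, :args, keyword_init: true) do
      def out?
        false
      end
    end

    add_fn = StubFunction.new(
      ruby_name: "add",
      cpp_name: "add_scalar",  # hypothetical binding name
      args: [
        {name: "self", type: "Tensor", pos: true, default: nil},
        {name: "other", type: "Scalar", pos: true, default: nil}
      ]
    )

    parser = Torch::Native::Parser.new([add_fn])
    parser.parse([Torch.tensor([1, 2, 3]), 2], {})
    # => {name: "add_scalar", args: [tensor, 2]}
    parser.parse([Torch.tensor([1, 2, 3])], {})
    # => {error: "wrong number of arguments (given 1, expected 2)"}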
data/lib/torch/nn/bce_with_logits_loss.rb
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class BCEWithLogitsLoss < Loss
+      def initialize(weight: nil, reduction: "mean", pos_weight: nil)
+        super(reduction)
+        register_buffer("weight", weight)
+        register_buffer("pos_weight", pos_weight)
+      end
+
+      def forward(input, target)
+        F.binary_cross_entropy_with_logits(input, target, weight: weight, pos_weight: pos_weight, reduction: @reduction)
+      end
+    end
+  end
+end
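A rough usage sketch of the new loss module (values are illustrative; the loss expects raw logits, not sigmoid outputs):

    loss_fn = Torch::NN::BCEWithLogitsLoss.new
    output = Torch.randn(3)                  # raw logits
    target = Torch.tensor([0.0, 1.0, 1.0])
    loss = loss_fn.forward(output, target)   # scalar tensor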
data/lib/torch/nn/conv2d.rb
@@ -1,8 +1,6 @@
 module Torch
   module NN
     class Conv2d < ConvNd
-      attr_reader :bias, :weight
-
       def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
         kernel_size = pair(kernel_size)
         stride = pair(stride)
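The explicit readers are no longer needed because parameters are now resolved through the reworked Module#method_missing (see the module.rb changes further down). A small sketch:

    conv = Torch::NN::Conv2d.new(3, 8, 3)
    conv.weight.shape  # => [8, 3, 3, 3], served via named_parameters
    conv.bias.shape    # => [8]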
data/lib/torch/nn/cosine_embedding_loss.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CosineEmbeddingLoss < Loss
+      def initialize(margin: 0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input1, input2, target)
+        F.cosine_embedding_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/functional.rb
@@ -50,7 +50,8 @@ module Torch
           raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
 
           padding_idx ||= -1
-          Torch._embedding(input, weight, padding_idx, scale_grad_by_freq, sparse)
+          # weight and indices are swapped from Python interface
+          Torch._embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
         end
 
         def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
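The swap only affects the private binding; the Ruby-facing call still takes the indices first and the weight matrix second. A hedged sketch (shapes illustrative):

    weight = Torch.rand(10, 4)                    # 10 embeddings of size 4
    ids = Torch.tensor([[1, 2], [7, 3]])
    Torch::NN::Functional.embedding(ids, weight)  # => tensor of shape [2, 2, 4]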
@@ -76,8 +77,15 @@ module Torch
         # loss functions
 
         def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
-          raise NotImplementedYet if weight
-          Torch.binary_cross_entropy(input, target, reduction)
+          NN._binary_cross_entropy(input, target, weight, reduction)
+        end
+
+        def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
+          Torch._binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+        end
+
+        def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
         end
 
         def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
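Passing a per-element weight no longer raises NotImplementedYet. A quick sketch (inputs are probabilities here, unlike the logits variant above):

    input  = Torch.rand(3)
    target = Torch.tensor([0.0, 1.0, 1.0])
    weight = Torch.tensor([0.5, 1.0, 2.0])
    Torch::NN::Functional.binary_cross_entropy(input, target, weight: weight)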
@@ -86,28 +94,59 @@ module Torch
 
         def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
           # call to_a on input_lengths and target_lengths for C++
-          Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+          Torch._ctc_loss_intlist(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+        end
+
+        def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
+          Torch._hinge_embedding_loss(input, target, margin, reduction)
         end
 
         def kl_div(input, target, reduction: "mean")
-          Torch.kl_div(input, target, reduction)
+          Torch._kl_div(input, target, reduction)
         end
 
         def l1_loss(input, target, reduction: "mean")
-          Torch.l1_loss(input, target, reduction)
+          NN._l1_loss(input, target, reduction)
+        end
+
+        def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
         end
 
         def mse_loss(input, target, reduction: "mean")
-          Torch.mse_loss(input, target, reduction)
+          NN._mse_loss(input, target, reduction)
+        end
+
+        def multilabel_margin_loss(input, target, reduction: "mean")
+          NN._multilabel_margin_loss(input, target, reduction)
+        end
+
+        def multilabel_soft_margin_loss(input, target, weight: nil)
+          raise NotImplementedYet
+        end
+
+        def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
+          NN._multi_margin_loss(input, target, p, margin, weight, reduction)
         end
 
         def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
-          raise NotImplementedYet if weight
-          Torch.nll_loss(input, target, reduction, ignore_index)
+          NN._nll_loss(input, target, weight, reduction, ignore_index)
         end
 
         def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
-          Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
+          Torch._poisson_nll_loss(input, target, log_input, full, eps, reduction)
+        end
+
+        def soft_margin_loss(input, target, reduction: "mean")
+          NN._soft_margin_loss(input, target, reduction)
+        end
+
+        def smooth_l1_loss(input, target, reduction: "mean")
+          NN._smooth_l1_loss(input, target, reduction)
+        end
+
+        def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
+          Torch._triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
         end
 
         # end loss
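Most of these now forward straight to the generated bindings. A couple of hedged examples of the newly wired functions (values illustrative; margin_ranking_loss, cosine_embedding_loss and multilabel_soft_margin_loss still raise NotImplementedYet in this release):

    input  = Torch.randn(4)
    target = Torch.tensor([1.0, -1.0, 1.0, -1.0])
    Torch::NN::Functional.hinge_embedding_loss(input, target)   # margin: 1.0, reduction: "mean"
    Torch::NN::Functional.smooth_l1_loss(input, Torch.randn(4))
    Torch::NN::Functional.triplet_margin_loss(Torch.randn(2, 3), Torch.randn(2, 3), Torch.randn(2, 3))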
@@ -123,7 +162,7 @@ module Torch
         end
 
         def softplus(input, beta: 1, threshold: 20)
-          Torch._softplus(input, beta, threshold)
+          NN._softplus(input, beta, threshold)
         end
 
         # TODO make dim keyword argument and update examples
@@ -134,7 +173,7 @@ module Torch
 
         def dropout(input, p: 0.5, training: true, inplace: false)
           if inplace
-            Torch._dropout!(input, p, training)
+            Torch._dropout_(input, p, training)
           else
             Torch._dropout(input, p, training)
           end
@@ -144,7 +183,7 @@ module Torch
           raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
 
           if inplace
-            Torch._feature_dropout!(input, p, training)
+            Torch._feature_dropout_(input, p, training)
           else
             Torch._feature_dropout(input, p, training)
           end
@@ -152,7 +191,7 @@ module Torch
 
         def dropout3d(input, p: 0.5, training: true, inplace: false)
           if inplace
-            Torch._feature_dropout!(input, p, training)
+            Torch._feature_dropout_(input, p, training)
           else
             Torch._feature_dropout(input, p, training)
           end
@@ -160,7 +199,7 @@ module Torch
 
         def alpha_dropout(input, p: 0.5, training: true, inplace: false)
          if inplace
-            Torch._alpha_dropout!(input, p, training)
+            Torch._alpha_dropout_(input, p, training)
          else
            Torch._alpha_dropout(input, p, training)
          end
@@ -168,7 +207,7 @@ module Torch
 
        def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
          if inplace
-            Torch._feature_alpha_dropout!(input, p, training)
+            Torch._feature_alpha_dropout_(input, p, training)
          else
            Torch._feature_alpha_dropout(input, p, training)
          end
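The bang-style private bindings (_dropout!) were renamed to the trailing-underscore convention (_dropout_) used by the generated functions; the public interface is unchanged. A short sketch:

    x = Torch.rand(2, 3)
    Torch::NN::Functional.dropout(x, p: 0.5)                 # returns a new tensor
    Torch::NN::Functional.dropout(x, p: 0.5, inplace: true)  # mutates x via the renamed _dropout_ binding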
data/lib/torch/nn/hinge_embedding_loss.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class HingeEmbeddingLoss < Loss
+      def initialize(margin: 1.0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input, target)
+        F.hinge_embedding_loss(input, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/identity.rb
@@ -1,6 +1,7 @@
 module Torch
   module NN
     class Identity < Module
+      # written this way to support unused arguments
       def initialize(*args, **options)
         super()
       end
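This lets Identity be dropped in for another layer without touching the caller's constructor arguments. A hedged sketch (assumes Identity#forward simply returns its input, as in PyTorch):

    layer = Torch::NN::Identity.new(64, some_option: true)  # arguments are accepted and ignored
    layer.forward(Torch.tensor([1, 2, 3]))                  # => the same values back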
data/lib/torch/nn/margin_ranking_loss.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class MarginRankingLoss < Loss
+      def initialize(margin: 1.0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input1, input2, target)
+        F.margin_ranking_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/module.rb
@@ -85,25 +85,56 @@ module Torch
       end
 
       def parameters
-        params = []
+        named_parameters.values
+      end
+
+      def named_parameters(prefix: "", recurse: true)
+        params = {}
+        if recurse
+          named_children.each do |name, mod|
+            params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+          end
+        end
         instance_variables.each do |name|
           param = instance_variable_get(name)
-          params << param if param.is_a?(Parameter)
+          params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
+        end
+        @parameters.each do |name, param|
+          params[[prefix, name].join] = param
         end
-        params + modules.flat_map { |_, mod| mod.parameters }
+        params
+      end
+
+      def buffers
+        named_buffers.values
+      end
+
+      def named_buffers
+        @buffers || {}
       end
 
       def children
-        @modules.values
+        named_children.values
       end
 
-      def modules
+      def named_children
        modules = {}
        instance_variables.each do |name|
          mod = instance_variable_get(name)
          modules[name[1..-1]] = mod if mod.is_a?(Module)
        end
-        @modules.merge(modules)
+        @modules.each do |name, mod|
+          modules[name] = mod
+        end
+        modules
+      end
+
+      def modules
+        named_modules.values
+      end
+
+      def named_modules
+        {"" => self}.merge(named_children)
       end
 
       def train(mode = true)
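A rough sketch of the new introspection methods on a small custom module (layer names are illustrative):

    class TwoLayer < Torch::NN::Module
      def initialize
        super()
        @fc1 = Torch::NN::Linear.new(4, 8)
        @fc2 = Torch::NN::Linear.new(8, 1)
      end

      def forward(x)
        @fc2.call(@fc1.call(x))
      end
    end

    net = TwoLayer.new
    net.named_parameters.keys  # => ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
    net.named_children.keys    # => ["fc1", "fc2"]
    net.parameters.size        # => 4
    net.fc1                    # submodules and parameters also resolve via method_missing below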
@@ -140,12 +171,12 @@ module Torch
 
       def inspect
         name = self.class.name.split("::").last
-        if modules.empty?
+        if children.empty?
           "#{name}(#{extra_inspect})"
         else
           str = String.new
           str << "#{name}(\n"
-          modules.each do |name, mod|
+          children.each do |name, mod|
             str << "  (#{name}): #{mod.inspect}\n"
           end
           str << ")"
@@ -153,11 +184,21 @@ module Torch
       end
 
       def method_missing(method, *args, &block)
-        modules[method.to_s] || super
+        name = method.to_s
+        if named_parameters.key?(name)
+          named_parameters[name]
+        elsif named_buffers.key?(name)
+          named_buffers[name]
+        elsif named_modules.key?(name)
+          named_modules[name]
+        else
+          super
+        end
       end
 
       def respond_to?(method, include_private = false)
-        modules.key?(method.to_s) || super
+        name = method.to_s
+        named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
       end
 
       private
@@ -166,8 +207,14 @@ module Torch
         nil
       end
 
-      def format(str, *vars)
-        str % vars.map(&:inspect)
+      def format(str, *vars, **options)
+        vars =
+          if vars.any?
+            vars.map(&:inspect)
+          else
+            options.map { |k, v| [k, v.inspect] }.to_h
+          end
+        str % vars
       end
     end
   end
data/lib/torch/nn/multi_label_margin_loss.rb
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class MultiLabelMarginLoss < Loss
+      def initialize(reduction: "mean")
+        super(reduction)
+      end
+
+      def forward(input, target)
+        F.multilabel_margin_loss(input, target, reduction: @reduction)
+      end
+    end
+  end
+end