torch-rb 0.1.4 → 0.1.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (39)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/README.md +5 -3
  4. data/ext/torch/ext.cpp +22 -548
  5. data/ext/torch/extconf.rb +6 -0
  6. data/ext/torch/nn_functions.cpp +595 -0
  7. data/ext/torch/nn_functions.hpp +6 -0
  8. data/ext/torch/templates.hpp +250 -0
  9. data/ext/torch/tensor_functions.cpp +1860 -0
  10. data/ext/torch/tensor_functions.hpp +6 -0
  11. data/ext/torch/torch_functions.cpp +2875 -0
  12. data/ext/torch/torch_functions.hpp +6 -0
  13. data/lib/torch.rb +68 -129
  14. data/lib/torch/ext.bundle +0 -0
  15. data/lib/torch/native/dispatcher.rb +48 -0
  16. data/lib/torch/native/function.rb +78 -0
  17. data/lib/torch/native/generator.rb +149 -0
  18. data/lib/torch/native/native_functions.yaml +6837 -0
  19. data/lib/torch/native/parser.rb +97 -0
  20. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  21. data/lib/torch/nn/conv2d.rb +0 -2
  22. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  23. data/lib/torch/nn/functional.rb +55 -16
  24. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  25. data/lib/torch/nn/identity.rb +1 -0
  26. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  27. data/lib/torch/nn/module.rb +59 -12
  28. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  29. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  30. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  31. data/lib/torch/nn/parameter.rb +4 -0
  32. data/lib/torch/nn/rnn.rb +22 -0
  33. data/lib/torch/nn/rnn_base.rb +154 -0
  34. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  35. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  36. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  37. data/lib/torch/tensor.rb +19 -19
  38. data/lib/torch/version.rb +1 -1
  39. metadata +26 -2
@@ -0,0 +1,97 @@
module Torch
  module Native
    # Picks the single native overload that matches a Ruby call.
    #
    # All functions handed to the parser share one Ruby name. Each is expected
    # to respond to #ruby_name, #cpp_name, #out? and #args. #args is an Array of
    # Hashes with :name, :type, :pos and :default keys; this structure is
    # inferred from the accesses below.
    class Parser
      # @param functions [Array] overloads sharing a Ruby name; must be non-empty
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # fewest required positional args / most positional args over all overloads
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && a[:default].nil? } }.min
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
      end

      # Matches +args+ / +options+ against the overloads.
      #
      # @param args [Array] positional arguments of the Ruby call
      # @param options [Hash] keyword arguments of the Ruby call
      # @return [Hash] +{name:, args:}+ with the C++ name and the argument values
      #   in declaration order, or +{error:}+ with a user-facing message
      # @raise [Error] for an unsupported argument type, or when the overloads
      #   cannot be narrowed down to exactly one
      def parse(args, options)
        candidates = @functions.dup

        if args.size < @min_args || args.size > @max_args
          expected = @min_args.to_s
          expected = "#{expected}..#{@max_args}" if @max_args != @min_args
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        # exclude functions where options don't match
        options.each_key do |k|
          candidates.select! { |func| func.args.any? { |a| a[:name] == k.to_s } }
          # TODO show all bad keywords at once like Ruby?
          return {error: "unknown keyword: #{k}"} if candidates.empty?
        end

        # exclude functions missing required options
        # TODO make more generic
        candidates.reject! { |func| func.out? && !options[:out] }

        # Exclude overloads that take fewer positional args than were given.
        # Zipping those would bind the extras to keyword-only params or to nil.
        candidates.reject! { |func| positional_args(func).size < args.size }

        # A type mismatch is only reported to the user when no other overload
        # is left to try. Capture this before select! starts filtering.
        single_candidate = candidates.size == 1
        final_values = {}

        # check args
        candidates.select! do |func|
          values = bind_values(func, args, options)
          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h

          # first mismatching [name, value] pair, or nil when every type matches
          mismatch = values.find { |k, v| !type_matches?(arg_types[k], v) }

          if mismatch && single_candidate
            k = mismatch.first
            type = base_type(arg_types[k])
            k = "input" if k == "self"
            return {error: "#{@name}(): argument '#{k}' must be #{type}"}
          end

          final_values = values unless mismatch
          mismatch.nil?
        end

        if candidates.size != 1
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        func = candidates.first
        {
          name: func.cpp_name,
          args: func.args.map { |a| final_values[a[:name]] }
        }
      end

      private

      # Positional parameters of +func+, in declaration order.
      def positional_args(func)
        func.args.select { |a| a[:pos] }
      end

      # Maps parameter name => value. Positional args are bound first, then
      # keyword options, then the declared default for any parameter still nil.
      # An explicit +false+ is kept rather than being replaced by the default.
      def bind_values(func, args, options)
        values = args.zip(positional_args(func)).map { |a, fa| [fa[:name], a] }.to_h
        options.each { |k, v| values[k.to_s] = v }
        func.args.each do |fa|
          values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
        end
        values
      end

      # Type string without its alias annotation, e.g. "Tensor(a!)" -> "Tensor".
      def base_type(type)
        type.split("(").first
      end

      # Whether +value+ is acceptable for the native +type+ string.
      # @raise [Error] for types this parser does not understand
      def type_matches?(type, value)
        case base_type(type)
        when "Tensor" then value.is_a?(Tensor)
        when "Tensor[]" then value.all? { |v| v.is_a?(Tensor) }
        when "int" then value.is_a?(Integer)
        when "int[]" then value.all? { |v| v.is_a?(Integer) }
        when "Scalar" then value.is_a?(Numeric)
        when "bool" then value == true || value == false
        else
          raise Error, "Unknown argument type: #{type}. Please report a bug with #{@name}"
        end
      end
    end
  end
end
@@ -0,0 +1,15 @@
module Torch
  module NN
    # Loss module that delegates to F.binary_cross_entropy_with_logits.
    # The optional +weight+ and +pos_weight+ values are registered as buffers
    # and read back through the +weight+ / +pos_weight+ readers in #forward.
    class BCEWithLogitsLoss < Loss
      def initialize(weight: nil, reduction: "mean", pos_weight: nil)
        super(reduction)
        # registration order kept: weight first, then pos_weight
        register_buffer("weight", weight)
        register_buffer("pos_weight", pos_weight)
      end

      # Computes the loss of +input+ against +target+ with the stored options.
      def forward(input, target)
        F.binary_cross_entropy_with_logits(
          input,
          target,
          reduction: @reduction,
          weight: weight,
          pos_weight: pos_weight
        )
      end
    end
  end
end
@@ -1,8 +1,6 @@
1
1
  module Torch
2
2
  module NN
3
3
  class Conv2d < ConvNd
4
- attr_reader :bias, :weight
5
-
6
4
  def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
7
5
  kernel_size = pair(kernel_size)
8
6
  stride = pair(stride)
@@ -0,0 +1,14 @@
module Torch
  module NN
    # Loss module that delegates to F.cosine_embedding_loss, using the margin
    # and reduction captured at construction time.
    class CosineEmbeddingLoss < Loss
      def initialize(margin: 0, reduction: "mean")
        super(reduction)
        @margin = margin
      end

      # Computes the loss for the pair (+input1+, +input2+) against +target+.
      def forward(input1, input2, target)
        F.cosine_embedding_loss(
          input1,
          input2,
          target,
          reduction: @reduction,
          margin: @margin
        )
      end
    end
  end
end
@@ -50,7 +50,8 @@ module Torch
50
50
  raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
51
51
 
52
52
  padding_idx ||= -1
53
- Torch._embedding(input, weight, padding_idx, scale_grad_by_freq, sparse)
53
+ # weight and indices are swapped from Python interface
54
+ Torch._embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
54
55
  end
55
56
 
56
57
  def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
@@ -76,8 +77,15 @@ module Torch
76
77
  # loss functions
77
78
 
78
79
  def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
79
- raise NotImplementedYet if weight
80
- Torch.binary_cross_entropy(input, target, reduction)
80
+ NN._binary_cross_entropy(input, target, weight, reduction)
81
+ end
82
+
83
+ def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
84
+ Torch._binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
85
+ end
86
+
87
+ def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
88
+ raise NotImplementedYet
81
89
  end
82
90
 
83
91
  def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
@@ -86,28 +94,59 @@ module Torch
86
94
 
87
95
  def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
88
96
  # call to_a on input_lengths and target_lengths for C++
89
- Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
97
+ Torch._ctc_loss_intlist(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
98
+ end
99
+
100
+ def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
101
+ Torch._hinge_embedding_loss(input, target, margin, reduction)
90
102
  end
91
103
 
92
104
  def kl_div(input, target, reduction: "mean")
93
- Torch.kl_div(input, target, reduction)
105
+ Torch._kl_div(input, target, reduction)
94
106
  end
95
107
 
96
108
  def l1_loss(input, target, reduction: "mean")
97
- Torch.l1_loss(input, target, reduction)
109
+ NN._l1_loss(input, target, reduction)
110
+ end
111
+
112
+ def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
113
+ raise NotImplementedYet
98
114
  end
99
115
 
100
116
  def mse_loss(input, target, reduction: "mean")
101
- Torch.mse_loss(input, target, reduction)
117
+ NN._mse_loss(input, target, reduction)
118
+ end
119
+
120
+ def multilabel_margin_loss(input, target, reduction: "mean")
121
+ NN._multilabel_margin_loss(input, target, reduction)
122
+ end
123
+
124
+ def multilabel_soft_margin_loss(input, target, weight: nil)
125
+ raise NotImplementedYet
126
+ end
127
+
128
+ def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
129
+ NN._multi_margin_loss(input, target, p, margin, weight, reduction)
102
130
  end
103
131
 
104
132
  def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
105
- raise NotImplementedYet if weight
106
- Torch.nll_loss(input, target, reduction, ignore_index)
133
+ NN._nll_loss(input, target, weight, reduction, ignore_index)
107
134
  end
108
135
 
109
136
  def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
110
- Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
137
+ Torch._poisson_nll_loss(input, target, log_input, full, eps, reduction)
138
+ end
139
+
140
+ def soft_margin_loss(input, target, reduction: "mean")
141
+ NN._soft_margin_loss(input, target, reduction)
142
+ end
143
+
144
+ def smooth_l1_loss(input, target, reduction: "mean")
145
+ NN._smooth_l1_loss(input, target, reduction)
146
+ end
147
+
148
+ def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
149
+ Torch._triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
111
150
  end
112
151
 
113
152
  # end loss
@@ -123,7 +162,7 @@ module Torch
123
162
  end
124
163
 
125
164
  def softplus(input, beta: 1, threshold: 20)
126
- Torch._softplus(input, beta, threshold)
165
+ NN._softplus(input, beta, threshold)
127
166
  end
128
167
 
129
168
  # TODO make dim keyword argument and update examples
@@ -134,7 +173,7 @@ module Torch
134
173
 
135
174
  def dropout(input, p: 0.5, training: true, inplace: false)
136
175
  if inplace
137
- Torch._dropout!(input, p, training)
176
+ Torch._dropout_(input, p, training)
138
177
  else
139
178
  Torch._dropout(input, p, training)
140
179
  end
@@ -144,7 +183,7 @@ module Torch
144
183
  raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
145
184
 
146
185
  if inplace
147
- Torch._feature_dropout!(input, p, training)
186
+ Torch._feature_dropout_(input, p, training)
148
187
  else
149
188
  Torch._feature_dropout(input, p, training)
150
189
  end
@@ -152,7 +191,7 @@ module Torch
152
191
 
153
192
  def dropout3d(input, p: 0.5, training: true, inplace: false)
154
193
  if inplace
155
- Torch._feature_dropout!(input, p, training)
194
+ Torch._feature_dropout_(input, p, training)
156
195
  else
157
196
  Torch._feature_dropout(input, p, training)
158
197
  end
@@ -160,7 +199,7 @@ module Torch
160
199
 
161
200
  def alpha_dropout(input, p: 0.5, training: true, inplace: false)
162
201
  if inplace
163
- Torch._alpha_dropout!(input, p, training)
202
+ Torch._alpha_dropout_(input, p, training)
164
203
  else
165
204
  Torch._alpha_dropout(input, p, training)
166
205
  end
@@ -168,7 +207,7 @@ module Torch
168
207
 
169
208
  def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
170
209
  if inplace
171
- Torch._feature_alpha_dropout!(input, p, training)
210
+ Torch._feature_alpha_dropout_(input, p, training)
172
211
  else
173
212
  Torch._feature_alpha_dropout(input, p, training)
174
213
  end
@@ -0,0 +1,14 @@
module Torch
  module NN
    # Loss module that delegates to F.hinge_embedding_loss, using the margin
    # and reduction captured at construction time.
    class HingeEmbeddingLoss < Loss
      def initialize(margin: 1.0, reduction: "mean")
        super(reduction)
        @margin = margin
      end

      # Computes the loss of +input+ against +target+.
      def forward(input, target)
        F.hinge_embedding_loss(
          input,
          target,
          reduction: @reduction,
          margin: @margin
        )
      end
    end
  end
end
@@ -1,6 +1,7 @@
1
1
  module Torch
2
2
  module NN
3
3
  class Identity < Module
4
+ # written this way to support unused arguments
4
5
  def initialize(*args, **options)
5
6
  super()
6
7
  end
@@ -0,0 +1,14 @@
module Torch
  module NN
    # Loss module that delegates to F.margin_ranking_loss, using the margin
    # and reduction captured at construction time.
    class MarginRankingLoss < Loss
      # @param margin [Numeric] defaults to 0 so that the module agrees with
      #   F.margin_ranking_loss(margin: 0) and with PyTorch's
      #   nn.MarginRankingLoss (margin=0.0)
      # @param reduction [String] reduction mode handed to the parent Loss
      def initialize(margin: 0, reduction: "mean")
        super(reduction)
        @margin = margin
      end

      # Computes the loss for the pair (+input1+, +input2+) against +target+.
      def forward(input1, input2, target)
        F.margin_ranking_loss(input1, input2, target, margin: @margin, reduction: @reduction)
      end
    end
  end
end
@@ -85,25 +85,56 @@ module Torch
85
85
  end
86
86
 
87
87
  def parameters
88
- params = []
88
+ named_parameters.values
89
+ end
90
+
91
+ def named_parameters(prefix: "", recurse: true)
92
+ params = {}
93
+ if recurse
94
+ named_children.each do |name, mod|
95
+ params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
96
+ end
97
+ end
89
98
  instance_variables.each do |name|
90
99
  param = instance_variable_get(name)
91
- params << param if param.is_a?(Parameter)
100
+ params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
101
+ end
102
+ @parameters.each do |name, param|
103
+ params[[prefix, name].join] = param
92
104
  end
93
- params + modules.flat_map { |_, mod| mod.parameters }
105
+ params
106
+ end
107
+
108
+ def buffers
109
+ named_buffers.values
110
+ end
111
+
112
+ def named_buffers
113
+ @buffers || {}
94
114
  end
95
115
 
96
116
  def children
97
- @modules.values
117
+ named_children.values
98
118
  end
99
119
 
100
- def modules
120
+ def named_children
101
121
  modules = {}
102
122
  instance_variables.each do |name|
103
123
  mod = instance_variable_get(name)
104
124
  modules[name[1..-1]] = mod if mod.is_a?(Module)
105
125
  end
106
- @modules.merge(modules)
126
+ @modules.each do |name, mod|
127
+ modules[name] = mod
128
+ end
129
+ modules
130
+ end
131
+
132
+ def modules
133
+ named_modules.values
134
+ end
135
+
136
+ def named_modules
137
+ {"" => self}.merge(named_children)
107
138
  end
108
139
 
109
140
  def train(mode = true)
@@ -140,12 +171,12 @@ module Torch
140
171
 
141
172
  def inspect
142
173
  name = self.class.name.split("::").last
143
- if modules.empty?
174
+ if children.empty?
144
175
  "#{name}(#{extra_inspect})"
145
176
  else
146
177
  str = String.new
147
178
  str << "#{name}(\n"
148
- modules.each do |name, mod|
179
+ children.each do |name, mod|
149
180
  str << " (#{name}): #{mod.inspect}\n"
150
181
  end
151
182
  str << ")"
@@ -153,11 +184,21 @@ module Torch
153
184
  end
154
185
 
155
186
  def method_missing(method, *args, &block)
156
- modules[method.to_s] || super
187
+ name = method.to_s
188
+ if named_parameters.key?(name)
189
+ named_parameters[name]
190
+ elsif named_buffers.key?(name)
191
+ named_buffers[name]
192
+ elsif named_modules.key?(name)
193
+ named_modules[name]
194
+ else
195
+ super
196
+ end
157
197
  end
158
198
 
159
199
  def respond_to?(method, include_private = false)
160
- modules.key?(method.to_s) || super
200
+ name = method.to_s
201
+ named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
161
202
  end
162
203
 
163
204
  private
@@ -166,8 +207,14 @@ module Torch
166
207
  nil
167
208
  end
168
209
 
169
- def format(str, *vars)
170
- str % vars.map(&:inspect)
210
+ def format(str, *vars, **options)
211
+ vars =
212
+ if vars.any?
213
+ vars.map(&:inspect)
214
+ else
215
+ options.map { |k, v| [k, v.inspect] }.to_h
216
+ end
217
+ str % vars
171
218
  end
172
219
  end
173
220
  end
@@ -0,0 +1,13 @@
module Torch
  module NN
    # Loss module that delegates to F.multilabel_margin_loss. The only option
    # is the reduction mode, which is stored by the parent Loss.
    class MultiLabelMarginLoss < Loss
      def initialize(reduction: "mean")
        super(reduction)
      end

      # Computes the loss of +input+ against +target+.
      def forward(input, target)
        F.multilabel_margin_loss(
          input,
          target,
          reduction: @reduction
        )
      end
    end
  end
end