torch-rb 0.1.4 → 0.1.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +5 -3
- data/ext/torch/ext.cpp +22 -548
- data/ext/torch/extconf.rb +6 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +68 -129
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/conv2d.rb +0 -2
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/functional.rb +55 -16
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +1 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/module.rb +59 -12
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/tensor.rb +19 -19
- data/lib/torch/version.rb +1 -1
- metadata +26 -2
@@ -0,0 +1,97 @@
|
|
1
|
+
module Torch
  module Native
    # Picks the native overload that matches a Ruby call.
    #
    # Given every overload that shares one Ruby name, #parse narrows them by
    # arity, keyword names, the `out` requirement and argument types, then
    # returns the single survivor with its argument values laid out in
    # declaration order. Caller mistakes come back as {error: message};
    # internal inconsistencies raise Error.
    class Parser
      # @param functions [Array] overloads sharing one Ruby name (must not be
      #   empty); each must respond to #ruby_name, #cpp_name, #out? and #args,
      #   where #args is an Array of Hashes with :name, :type, :pos, :default
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # arity bounds across all overloads, used for Ruby-style arity errors
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && a[:default].nil? } }.min
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
      end

      # Resolves a call to exactly one overload.
      #
      # @param args [Array] positional arguments of the Ruby call
      # @param options [Hash] keyword arguments of the Ruby call
      # @return [Hash] {name: cpp_name, args: values in declaration order} on
      #   success, or {error: message} when the call matches no overload
      # @raise [Error] for an argument type this parser cannot check, or when
      #   more than one overload matches
      def parse(args, options)
        candidates = @functions.dup

        if args.size < @min_args || args.size > @max_args
          expected = String.new(@min_args.to_s)
          expected << "..#{@max_args}" if @max_args != @min_args
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        # exclude functions where options don't match
        options.each_key do |k|
          candidates.select! { |func| func.args.any? { |a| a[:name] == k.to_s } }
          # TODO show all bad keywords at once like Ruby?
          return {error: "unknown keyword: #{k}"} if candidates.empty?
        end

        # exclude functions missing required options
        # TODO make more generic
        candidates.reject! { |func| func.out? && !options[:out] }

        # A precise per-argument type error only makes sense when exactly one
        # overload is left to blame; decide that before checking types.
        single = candidates.size == 1
        matches = []

        candidates.each do |func|
          values = bind_values(func, args, options)
          next unless values

          mismatch = first_mismatch(func, values)
          if mismatch.nil?
            matches << [func, values]
          elsif single
            k, t = mismatch
            k = "input" if k == "self"
            return {error: "#{@name}(): argument '#{k}' must be #{t}"}
          end
        end

        # No overload fits: a caller mistake (e.g. wrong types against several
        # overloads), not an internal bug.
        return {error: "#{@name}() received an invalid combination of arguments"} if matches.empty?

        if matches.size > 1
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        func, values = matches.first
        {
          name: func.cpp_name,
          args: func.args.map { |a| values[a[:name]] }
        }
      end

      private

      # Builds the name => value map for one overload. Positional args come
      # first, then keyword options, then declared defaults for anything still
      # nil. Returns nil when the call passes more positional args than this
      # overload declares, so the overload cannot be a candidate.
      def bind_values(func, args, options)
        positional = func.args.select { |a| a[:pos] }
        return if args.size > positional.size

        values = args.zip(positional).map { |a, fa| [fa[:name], a] }.to_h
        values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
        func.args.each do |fa|
          # nil check rather than ||= so an explicit `false` survives a `true` default
          values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
        end
        values
      end

      # Returns [name, base_type] for the first value (in binding order) that
      # does not fit its declared type, or nil when every value fits.
      def first_mismatch(func, values)
        arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
        values.each do |k, v|
          # keep only the base type before any "(" suffix
          t = arg_types[k].split("(").first
          return [k, t] unless type_matches?(t, v, arg_types[k])
        end
        nil
      end

      # Whether value v fits base type t. full_type is the declared type, used
      # in the error message. Raises Error for a type this parser cannot check.
      def type_matches?(t, v, full_type)
        case t
        when "Tensor"
          v.is_a?(Tensor)
        when "Tensor[]"
          # reject non-Arrays instead of crashing on #all?
          v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
        when "int"
          v.is_a?(Integer)
        when "int[]"
          v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
        when "Scalar"
          v.is_a?(Numeric)
        when "bool"
          v == true || v == false
        else
          raise Error, "Unknown argument type: #{full_type}. Please report a bug with #{@name}"
        end
      end
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that delegates to F.binary_cross_entropy_with_logits.
    #
    # `weight` and `pos_weight` are registered as buffers under those names
    # (either may be nil) and are read back by name in #forward.
    class BCEWithLogitsLoss < Loss
      # @param weight optional tensor stored as the "weight" buffer (default nil)
      # @param reduction [String] handed to Loss (default "mean")
      # @param pos_weight optional tensor stored as the "pos_weight" buffer (default nil)
      def initialize(weight: nil, reduction: "mean", pos_weight: nil)
        super(reduction)
        {"weight" => weight, "pos_weight" => pos_weight}.each do |buffer_name, tensor|
          register_buffer(buffer_name, tensor)
        end
      end

      # @return the result of F.binary_cross_entropy_with_logits for the
      #   registered buffers and the configured reduction
      def forward(input, target)
        F.binary_cross_entropy_with_logits(
          input,
          target,
          reduction: @reduction,
          weight: weight,
          pos_weight: pos_weight
        )
      end
    end
  end
end
|
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -1,8 +1,6 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class Conv2d < ConvNd
|
4
|
-
attr_reader :bias, :weight
|
5
|
-
|
6
4
|
def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
|
7
5
|
kernel_size = pair(kernel_size)
|
8
6
|
stride = pair(stride)
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module wrapping F.cosine_embedding_loss. It stores the margin given
    # at construction and forwards it, together with the reduction, on every call.
    #
    # NOTE(review): in this release F.cosine_embedding_loss raises
    # NotImplementedYet, so #forward raises until that is implemented.
    class CosineEmbeddingLoss < Loss
      # @param margin forwarded as `margin:` (default 0)
      # @param reduction [String] handed to Loss (default "mean")
      def initialize(margin: 0, reduction: "mean")
        super(reduction)
        @margin = margin
      end

      # @return the result of F.cosine_embedding_loss
      def forward(input1, input2, target)
        loss_options = {margin: @margin, reduction: @reduction}
        F.cosine_embedding_loss(input1, input2, target, **loss_options)
      end
    end
  end
end
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -50,7 +50,8 @@ module Torch
|
|
50
50
|
raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
|
51
51
|
|
52
52
|
padding_idx ||= -1
|
53
|
-
|
53
|
+
# weight and indices are swapped from Python interface
|
54
|
+
Torch._embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
|
54
55
|
end
|
55
56
|
|
56
57
|
def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
|
@@ -76,8 +77,15 @@ module Torch
|
|
76
77
|
# loss functions
|
77
78
|
|
78
79
|
def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
|
79
|
-
|
80
|
-
|
80
|
+
NN._binary_cross_entropy(input, target, weight, reduction)
|
81
|
+
end
|
82
|
+
|
83
|
+
def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
|
84
|
+
Torch._binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
|
85
|
+
end
|
86
|
+
|
87
|
+
def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
|
88
|
+
raise NotImplementedYet
|
81
89
|
end
|
82
90
|
|
83
91
|
def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
|
@@ -86,28 +94,59 @@ module Torch
|
|
86
94
|
|
87
95
|
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
|
88
96
|
# call to_a on input_lengths and target_lengths for C++
|
89
|
-
Torch.
|
97
|
+
Torch._ctc_loss_intlist(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
|
98
|
+
end
|
99
|
+
|
100
|
+
def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
|
101
|
+
Torch._hinge_embedding_loss(input, target, margin, reduction)
|
90
102
|
end
|
91
103
|
|
92
104
|
def kl_div(input, target, reduction: "mean")
|
93
|
-
Torch.
|
105
|
+
Torch._kl_div(input, target, reduction)
|
94
106
|
end
|
95
107
|
|
96
108
|
def l1_loss(input, target, reduction: "mean")
|
97
|
-
|
109
|
+
NN._l1_loss(input, target, reduction)
|
110
|
+
end
|
111
|
+
|
112
|
+
def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
|
113
|
+
raise NotImplementedYet
|
98
114
|
end
|
99
115
|
|
100
116
|
def mse_loss(input, target, reduction: "mean")
|
101
|
-
|
117
|
+
NN._mse_loss(input, target, reduction)
|
118
|
+
end
|
119
|
+
|
120
|
+
def multilabel_margin_loss(input, target, reduction: "mean")
|
121
|
+
NN._multilabel_margin_loss(input, target, reduction)
|
122
|
+
end
|
123
|
+
|
124
|
+
def multilabel_soft_margin_loss(input, target, weight: nil)
|
125
|
+
raise NotImplementedYet
|
126
|
+
end
|
127
|
+
|
128
|
+
def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
|
129
|
+
NN._multi_margin_loss(input, target, p, margin, weight, reduction)
|
102
130
|
end
|
103
131
|
|
104
132
|
def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
|
105
|
-
|
106
|
-
Torch.nll_loss(input, target, reduction, ignore_index)
|
133
|
+
NN._nll_loss(input, target, weight, reduction, ignore_index)
|
107
134
|
end
|
108
135
|
|
109
136
|
def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
|
110
|
-
Torch.
|
137
|
+
Torch._poisson_nll_loss(input, target, log_input, full, eps, reduction)
|
138
|
+
end
|
139
|
+
|
140
|
+
def soft_margin_loss(input, target, reduction: "mean")
|
141
|
+
NN._soft_margin_loss(input, target, reduction)
|
142
|
+
end
|
143
|
+
|
144
|
+
def smooth_l1_loss(input, target, reduction: "mean")
|
145
|
+
NN._smooth_l1_loss(input, target, reduction)
|
146
|
+
end
|
147
|
+
|
148
|
+
def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
|
149
|
+
Torch._triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
|
111
150
|
end
|
112
151
|
|
113
152
|
# end loss
|
@@ -123,7 +162,7 @@ module Torch
|
|
123
162
|
end
|
124
163
|
|
125
164
|
def softplus(input, beta: 1, threshold: 20)
|
126
|
-
|
165
|
+
NN._softplus(input, beta, threshold)
|
127
166
|
end
|
128
167
|
|
129
168
|
# TODO make dim keyword argument and update examples
|
@@ -134,7 +173,7 @@ module Torch
|
|
134
173
|
|
135
174
|
def dropout(input, p: 0.5, training: true, inplace: false)
|
136
175
|
if inplace
|
137
|
-
Torch.
|
176
|
+
Torch._dropout_(input, p, training)
|
138
177
|
else
|
139
178
|
Torch._dropout(input, p, training)
|
140
179
|
end
|
@@ -144,7 +183,7 @@ module Torch
|
|
144
183
|
raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
|
145
184
|
|
146
185
|
if inplace
|
147
|
-
Torch.
|
186
|
+
Torch._feature_dropout_(input, p, training)
|
148
187
|
else
|
149
188
|
Torch._feature_dropout(input, p, training)
|
150
189
|
end
|
@@ -152,7 +191,7 @@ module Torch
|
|
152
191
|
|
153
192
|
def dropout3d(input, p: 0.5, training: true, inplace: false)
|
154
193
|
if inplace
|
155
|
-
Torch.
|
194
|
+
Torch._feature_dropout_(input, p, training)
|
156
195
|
else
|
157
196
|
Torch._feature_dropout(input, p, training)
|
158
197
|
end
|
@@ -160,7 +199,7 @@ module Torch
|
|
160
199
|
|
161
200
|
def alpha_dropout(input, p: 0.5, training: true, inplace: false)
|
162
201
|
if inplace
|
163
|
-
Torch.
|
202
|
+
Torch._alpha_dropout_(input, p, training)
|
164
203
|
else
|
165
204
|
Torch._alpha_dropout(input, p, training)
|
166
205
|
end
|
@@ -168,7 +207,7 @@ module Torch
|
|
168
207
|
|
169
208
|
def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
|
170
209
|
if inplace
|
171
|
-
Torch.
|
210
|
+
Torch._feature_alpha_dropout_(input, p, training)
|
172
211
|
else
|
173
212
|
Torch._feature_alpha_dropout(input, p, training)
|
174
213
|
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that forwards to F.hinge_embedding_loss using the margin
    # captured at construction time.
    class HingeEmbeddingLoss < Loss
      # @param margin forwarded as `margin:` (default 1.0)
      # @param reduction [String] handed to Loss (default "mean")
      def initialize(margin: 1.0, reduction: "mean")
        super(reduction)
        @margin = margin
      end

      # @return the result of F.hinge_embedding_loss
      def forward(input, target)
        margin = @margin
        reduction = @reduction
        F.hinge_embedding_loss(input, target, margin: margin, reduction: reduction)
      end
    end
  end
end
|
data/lib/torch/nn/identity.rb
CHANGED
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that forwards to F.margin_ranking_loss with the stored margin.
    #
    # NOTE(review): the default margin here is 1.0, while F.margin_ranking_loss
    # defaults margin to 0 (PyTorch's MarginRankingLoss also uses 0.0). Confirm
    # which default is intended. Also, F.margin_ranking_loss currently raises
    # NotImplementedYet.
    class MarginRankingLoss < Loss
      # @param margin forwarded as `margin:` (default 1.0)
      # @param reduction [String] handed to Loss (default "mean")
      def initialize(margin: 1.0, reduction: "mean")
        super(reduction)
        @margin = margin
      end

      # @return the result of F.margin_ranking_loss
      def forward(input1, input2, target)
        F.margin_ranking_loss(
          input1,
          input2,
          target,
          reduction: @reduction,
          margin: @margin
        )
      end
    end
  end
end
|
data/lib/torch/nn/module.rb
CHANGED
@@ -85,25 +85,56 @@ module Torch
|
|
85
85
|
end
|
86
86
|
|
87
87
|
def parameters
|
88
|
-
|
88
|
+
named_parameters.values
|
89
|
+
end
|
90
|
+
|
91
|
+
# Collects this module's parameters into a name => Parameter hash.
#
# Parameters are gathered from three places: child modules (when `recurse`
# is true, keyed as "<prefix><child>.<name>"), instance variables holding a
# Parameter (keyed by the variable name without "@"), and the @parameters
# hash. Later sources overwrite earlier ones on a key clash.
#
# @param prefix [String] prepended to every key (default "")
# @param recurse [Boolean] whether to include child modules' parameters
# @return [Hash{String => Parameter}]
def named_parameters(prefix: "", recurse: true)
  params = {}
  if recurse
    named_children.each do |name, mod|
      # Carry our own prefix down so nested names stay fully qualified
      # ("outer.inner.weight"). Otherwise same-named sub-modules under
      # different children would collide and drop parameters.
      params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
    end
  end
  instance_variables.each do |name|
    param = instance_variable_get(name)
    params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
  end
  @parameters.each do |name, param|
    params[[prefix, name].join] = param
  end
  params
end
|
107
|
+
|
108
|
+
def buffers
|
109
|
+
named_buffers.values
|
110
|
+
end
|
111
|
+
|
112
|
+
def named_buffers
|
113
|
+
@buffers || {}
|
94
114
|
end
|
95
115
|
|
96
116
|
def children
|
97
|
-
|
117
|
+
named_children.values
|
98
118
|
end
|
99
119
|
|
100
|
-
def
|
120
|
+
def named_children
|
101
121
|
modules = {}
|
102
122
|
instance_variables.each do |name|
|
103
123
|
mod = instance_variable_get(name)
|
104
124
|
modules[name[1..-1]] = mod if mod.is_a?(Module)
|
105
125
|
end
|
106
|
-
@modules.
|
126
|
+
@modules.each do |name, mod|
|
127
|
+
modules[name] = mod
|
128
|
+
end
|
129
|
+
modules
|
130
|
+
end
|
131
|
+
|
132
|
+
def modules
|
133
|
+
named_modules.values
|
134
|
+
end
|
135
|
+
|
136
|
+
def named_modules
|
137
|
+
{"" => self}.merge(named_children)
|
107
138
|
end
|
108
139
|
|
109
140
|
def train(mode = true)
|
@@ -140,12 +171,12 @@ module Torch
|
|
140
171
|
|
141
172
|
def inspect
|
142
173
|
name = self.class.name.split("::").last
|
143
|
-
if
|
174
|
+
if children.empty?
|
144
175
|
"#{name}(#{extra_inspect})"
|
145
176
|
else
|
146
177
|
str = String.new
|
147
178
|
str << "#{name}(\n"
|
148
|
-
|
179
|
+
children.each do |name, mod|
|
149
180
|
str << " (#{name}): #{mod.inspect}\n"
|
150
181
|
end
|
151
182
|
str << ")"
|
@@ -153,11 +184,21 @@ module Torch
|
|
153
184
|
end
|
154
185
|
|
155
186
|
def method_missing(method, *args, &block)
|
156
|
-
|
187
|
+
name = method.to_s
|
188
|
+
if named_parameters.key?(name)
|
189
|
+
named_parameters[name]
|
190
|
+
elsif named_buffers.key?(name)
|
191
|
+
named_buffers[name]
|
192
|
+
elsif named_modules.key?(name)
|
193
|
+
named_modules[name]
|
194
|
+
else
|
195
|
+
super
|
196
|
+
end
|
157
197
|
end
|
158
198
|
|
159
199
|
def respond_to?(method, include_private = false)
|
160
|
-
|
200
|
+
name = method.to_s
|
201
|
+
named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
|
161
202
|
end
|
162
203
|
|
163
204
|
private
|
@@ -166,8 +207,14 @@ module Torch
|
|
166
207
|
nil
|
167
208
|
end
|
168
209
|
|
169
|
-
def format(str, *vars)
|
170
|
-
|
210
|
+
def format(str, *vars, **options)
|
211
|
+
vars =
|
212
|
+
if vars.any?
|
213
|
+
vars.map(&:inspect)
|
214
|
+
else
|
215
|
+
options.map { |k, v| [k, v.inspect] }.to_h
|
216
|
+
end
|
217
|
+
str % vars
|
171
218
|
end
|
172
219
|
end
|
173
220
|
end
|