torch-rb 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +5 -3
- data/ext/torch/ext.cpp +22 -548
- data/ext/torch/extconf.rb +6 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +68 -129
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/conv2d.rb +0 -2
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/functional.rb +55 -16
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +1 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/module.rb +59 -12
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/tensor.rb +19 -19
- data/lib/torch/version.rb +1 -1
- metadata +26 -2
@@ -0,0 +1,97 @@
|
|
1
|
+
module Torch
  module Native
    # Picks the overload of a native function that matches a Ruby call.
    #
    # All functions given to the parser share one Ruby name. #parse filters
    # them by arity, keyword names, the `out` keyword and argument types, then
    # returns the surviving overload's C++ name together with its argument
    # values in declaration order (defaults filled in).
    class Parser
      # functions - overloads sharing a Ruby name; each responds to
      #             #ruby_name, #cpp_name, #out? and #args (an Array of
      #             Hashes with :name, :type, :default and :pos keys)
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # fewest positional args any overload needs (positional, no default)
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && a[:default].nil? } }.min
        # most positional args any overload accepts
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
      end

      # args    - positional arguments of the Ruby call
      # options - keyword arguments of the Ruby call
      #
      # Returns {name: cpp_name, args: values} for the matching overload, or
      # {error: message} when the call cannot match any overload.
      # Raises Error for an unsupported argument type or when more than one
      # overload matches.
      def parse(args, options)
        candidates = @functions.dup

        if args.size < @min_args || args.size > @max_args
          expected = String.new(@min_args.to_s)
          expected += "..#{@max_args}" if @max_args != @min_args
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        # exclude functions where options don't match
        options.each_key do |k|
          candidates.select! do |func|
            func.args.any? { |a| a[:name] == k.to_s }
          end
          # TODO show all bad keywords at once like Ruby?
          return {error: "unknown keyword: #{k}"} if candidates.empty?
        end

        # exclude functions missing required options
        candidates.reject! do |func|
          # TODO make more generic
          func.out? && !options[:out]
        end

        # exclude functions with fewer positional parameters than args given;
        # zipping the extra args onto them would bind keyword-only parameters
        # (or run past the end of func.args and fail on nil)
        candidates.select! do |func|
          args.size <= func.args.count { |a| a[:pos] }
        end

        # decided up front rather than reading candidates.size inside select!,
        # whose in-progress size is an implementation detail
        single_candidate = candidates.size == 1
        final_values = {}

        # check args
        candidates.select! do |func|
          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
          func.args.each do |fa|
            # only fill missing/nil values so an explicit `false` is not
            # replaced by a `true` default (which `||=` would do)
            values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
          end

          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h

          good = true
          values.each do |k, v|
            t = arg_types[k].split("(").first
            good =
              case t
              when "Tensor"
                v.is_a?(Tensor)
              when "Tensor[]"
                v.all? { |v2| v2.is_a?(Tensor) }
              when "int"
                v.is_a?(Integer)
              when "int[]"
                v.all? { |v2| v2.is_a?(Integer) }
              when "Scalar"
                v.is_a?(Numeric)
              when "bool"
                v == true || v == false
              else
                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}"
              end

            next if good

            if single_candidate
              arg_name = k == "self" ? "input" : k
              return {error: "#{@name}(): argument '#{arg_name}' must be #{t}"}
            end
            break
          end

          final_values = values if good
          good
        end

        # every overload rejected the arguments: a user error, not a bug
        if candidates.empty?
          return {error: "#{@name}(): received an invalid combination of arguments"}
        end

        if candidates.size != 1
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        func = candidates.first
        {
          name: func.cpp_name,
          args: func.args.map { |a| final_values[a[:name]] }
        }
      end
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.binary_cross_entropy_with_logits.
    #
    # weight and pos_weight are registered as buffers and read back through
    # the module's dynamic accessors in #forward.
    class BCEWithLogitsLoss < Loss
      def initialize(weight: nil, reduction: "mean", pos_weight: nil)
        super(reduction)
        {"weight" => weight, "pos_weight" => pos_weight}.each do |buffer_name, tensor|
          register_buffer(buffer_name, tensor)
        end
      end

      def forward(input, target)
        F.binary_cross_entropy_with_logits(
          input,
          target,
          weight: weight,
          pos_weight: pos_weight,
          reduction: @reduction
        )
      end
    end
  end
end
|
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -1,8 +1,6 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class Conv2d < ConvNd
|
4
|
-
attr_reader :bias, :weight
|
5
|
-
|
6
4
|
def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
|
7
5
|
kernel_size = pair(kernel_size)
|
8
6
|
stride = pair(stride)
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.cosine_embedding_loss; stores the margin and
    # forwards it with the reduction on every call.
    class CosineEmbeddingLoss < Loss
      def initialize(margin: 0, reduction: "mean")
        super(reduction)
        @margin = margin
      end

      def forward(input1, input2, target)
        loss_options = {margin: @margin, reduction: @reduction}
        F.cosine_embedding_loss(input1, input2, target, **loss_options)
      end
    end
  end
end
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -50,7 +50,8 @@ module Torch
|
|
50
50
|
raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
|
51
51
|
|
52
52
|
padding_idx ||= -1
|
53
|
-
|
53
|
+
# weight and indices are swapped from Python interface
|
54
|
+
Torch._embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
|
54
55
|
end
|
55
56
|
|
56
57
|
def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
|
@@ -76,8 +77,15 @@ module Torch
|
|
76
77
|
# loss functions
|
77
78
|
|
78
79
|
# Binary cross entropy loss. Forwards to the native
# NN._binary_cross_entropy binding, passing weight and reduction positionally.
def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
  NN._binary_cross_entropy(
    input,
    target,
    weight,
    reduction
  )
end
|
82
|
+
|
83
|
+
# Binary cross entropy on logits. Forwards to
# Torch._binary_cross_entropy_with_logits; note the native argument order is
# weight, pos_weight, reduction (unlike the keyword order here).
def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
  native_args = [input, target, weight, pos_weight, reduction]
  Torch._binary_cross_entropy_with_logits(*native_args)
end
|
86
|
+
|
87
|
+
# Cosine embedding loss — not implemented yet; always raises
# NotImplementedYet regardless of the arguments.
def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
  raise(NotImplementedYet)
end
|
82
90
|
|
83
91
|
def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
|
@@ -86,28 +94,59 @@ module Torch
|
|
86
94
|
|
87
95
|
# Connectionist temporal classification loss. Forwards to
# Torch._ctc_loss_intlist.
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
  # the C++ binding takes plain arrays for the two length arguments
  input_lens = input_lengths.to_a
  target_lens = target_lengths.to_a
  Torch._ctc_loss_intlist(log_probs, targets, input_lens, target_lens, blank, reduction, zero_infinity)
end
|
99
|
+
|
100
|
+
# Hinge embedding loss. Forwards to Torch._hinge_embedding_loss with margin
# and reduction passed positionally.
def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
  Torch._hinge_embedding_loss(
    input, target,
    margin, reduction
  )
end
|
91
103
|
|
92
104
|
# Kullback-Leibler divergence loss. Forwards to Torch._kl_div.
def kl_div(input, target, reduction: "mean")
  native_args = [input, target, reduction]
  Torch._kl_div(*native_args)
end
|
95
107
|
|
96
108
|
# L1 loss. Forwards to the native NN._l1_loss binding.
def l1_loss(input, target, reduction: "mean")
  NN._l1_loss(
    input,
    target,
    reduction
  )
end
|
111
|
+
|
112
|
+
# Margin ranking loss — not implemented yet; always raises
# NotImplementedYet regardless of the arguments.
def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
  raise(NotImplementedYet)
end
|
99
115
|
|
100
116
|
# Mean squared error loss. Forwards to the native NN._mse_loss binding.
def mse_loss(input, target, reduction: "mean")
  native_args = [input, target, reduction]
  NN._mse_loss(*native_args)
end
|
119
|
+
|
120
|
+
# Multi-label margin loss. Forwards to NN._multilabel_margin_loss.
def multilabel_margin_loss(input, target, reduction: "mean")
  NN._multilabel_margin_loss(
    input,
    target,
    reduction
  )
end
|
123
|
+
|
124
|
+
# Multi-label soft margin loss — not implemented yet; always raises
# NotImplementedYet regardless of the arguments.
def multilabel_soft_margin_loss(input, target, weight: nil)
  raise(NotImplementedYet)
end
|
127
|
+
|
128
|
+
# Multi-class margin loss. Forwards to NN._multi_margin_loss; the native
# order is p, margin, weight, reduction.
def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
  native_args = [input, target, p, margin, weight, reduction]
  NN._multi_margin_loss(*native_args)
end
|
103
131
|
|
104
132
|
# Negative log likelihood loss. Forwards to NN._nll_loss; the native order
# puts reduction before ignore_index.
def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
  NN._nll_loss(
    input,
    target,
    weight,
    reduction,
    ignore_index
  )
end
|
108
135
|
|
109
136
|
# Poisson negative log likelihood loss. Forwards to Torch._poisson_nll_loss.
def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
  native_args = [input, target, log_input, full, eps, reduction]
  Torch._poisson_nll_loss(*native_args)
end
|
139
|
+
|
140
|
+
# Soft margin loss. Forwards to NN._soft_margin_loss.
def soft_margin_loss(input, target, reduction: "mean")
  NN._soft_margin_loss(
    input,
    target,
    reduction
  )
end
|
143
|
+
|
144
|
+
# Smooth L1 loss. Forwards to NN._smooth_l1_loss.
def smooth_l1_loss(input, target, reduction: "mean")
  native_args = [input, target, reduction]
  NN._smooth_l1_loss(*native_args)
end
|
147
|
+
|
148
|
+
# Triplet margin loss. Forwards to Torch._triplet_margin_loss with every
# option passed positionally in declaration order.
def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
  Torch._triplet_margin_loss(
    anchor, positive, negative,
    margin, p, eps, swap,
    reduction
  )
end
|
112
151
|
|
113
152
|
# end loss
|
@@ -123,7 +162,7 @@ module Torch
|
|
123
162
|
end
|
124
163
|
|
125
164
|
# Softplus activation. Forwards to NN._softplus with beta and threshold.
def softplus(input, beta: 1, threshold: 20)
  native_args = [input, beta, threshold]
  NN._softplus(*native_args)
end
|
128
167
|
|
129
168
|
# TODO make dim keyword argument and update examples
|
@@ -134,7 +173,7 @@ module Torch
|
|
134
173
|
|
135
174
|
def dropout(input, p: 0.5, training: true, inplace: false)
|
136
175
|
if inplace
|
137
|
-
Torch.
|
176
|
+
Torch._dropout_(input, p, training)
|
138
177
|
else
|
139
178
|
Torch._dropout(input, p, training)
|
140
179
|
end
|
@@ -144,7 +183,7 @@ module Torch
|
|
144
183
|
raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
|
145
184
|
|
146
185
|
if inplace
|
147
|
-
Torch.
|
186
|
+
Torch._feature_dropout_(input, p, training)
|
148
187
|
else
|
149
188
|
Torch._feature_dropout(input, p, training)
|
150
189
|
end
|
@@ -152,7 +191,7 @@ module Torch
|
|
152
191
|
|
153
192
|
def dropout3d(input, p: 0.5, training: true, inplace: false)
|
154
193
|
if inplace
|
155
|
-
Torch.
|
194
|
+
Torch._feature_dropout_(input, p, training)
|
156
195
|
else
|
157
196
|
Torch._feature_dropout(input, p, training)
|
158
197
|
end
|
@@ -160,7 +199,7 @@ module Torch
|
|
160
199
|
|
161
200
|
def alpha_dropout(input, p: 0.5, training: true, inplace: false)
|
162
201
|
if inplace
|
163
|
-
Torch.
|
202
|
+
Torch._alpha_dropout_(input, p, training)
|
164
203
|
else
|
165
204
|
Torch._alpha_dropout(input, p, training)
|
166
205
|
end
|
@@ -168,7 +207,7 @@ module Torch
|
|
168
207
|
|
169
208
|
def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
|
170
209
|
if inplace
|
171
|
-
Torch.
|
210
|
+
Torch._feature_alpha_dropout_(input, p, training)
|
172
211
|
else
|
173
212
|
Torch._feature_alpha_dropout(input, p, training)
|
174
213
|
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.hinge_embedding_loss; stores the margin and
    # forwards it with the reduction on every call.
    class HingeEmbeddingLoss < Loss
      def initialize(margin: 1.0, reduction: "mean")
        super(reduction)
        @margin = margin
      end

      def forward(input, target)
        loss_options = {margin: @margin, reduction: @reduction}
        F.hinge_embedding_loss(input, target, **loss_options)
      end
    end
  end
end
|
data/lib/torch/nn/margin_ranking_loss.rb
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.margin_ranking_loss.
    class MarginRankingLoss < Loss
      # margin    - defaults to 0, matching F.margin_ranking_loss(margin: 0)
      #             and PyTorch's nn.MarginRankingLoss (previously 1.0)
      # reduction - forwarded to the Loss superclass; #forward reads it back
      #             as @reduction
      def initialize(margin: 0, reduction: "mean")
        super(reduction)
        @margin = margin
      end

      def forward(input1, input2, target)
        F.margin_ranking_loss(input1, input2, target, margin: @margin, reduction: @reduction)
      end
    end
  end
end
|
data/lib/torch/nn/module.rb
CHANGED
@@ -85,25 +85,56 @@ module Torch
|
|
85
85
|
end
|
86
86
|
|
87
87
|
# Returns the values of named_parameters as an Array, in the same order.
def parameters
  named_parameters.map { |_name, param| param }
end
|
90
|
+
|
91
|
+
# Returns a Hash mapping dotted names to parameters.
#
# When +recurse+ is true, parameters of child modules come first, keyed by
# their full path from this module. They are followed by Parameter-valued
# instance variables and then the entries of @parameters.
#
# prefix  - String prepended to every key (used when recursing)
# recurse - whether to include parameters of child modules
def named_parameters(prefix: "", recurse: true)
  params = {}
  if recurse
    named_children.each do |name, mod|
      # carry the current prefix down so deeply nested names keep their full path
      params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
    end
  end
  instance_variables.each do |name|
    param = instance_variable_get(name)
    params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
  end
  # tolerate an unset @parameters, as named_buffers does for @buffers
  (@parameters || {}).each do |name, param|
    params[[prefix, name].join] = param
  end
  params
end
|
107
|
+
|
108
|
+
# Returns the values of named_buffers as an Array, in the same order.
def buffers
  named_buffers.each_value.to_a
end
|
111
|
+
|
112
|
+
# Returns @buffers, or an empty Hash when @buffers is unset.
def named_buffers
  return {} unless @buffers

  @buffers
end
|
95
115
|
|
96
116
|
# Returns the values of named_children as an Array, in the same order.
def children
  named_children.map { |_name, child| child }
end
|
99
119
|
|
100
|
-
def
|
120
|
+
# Returns a Hash of direct children. It first collects every instance
# variable holding a Module, keyed by the variable name without its leading
# "@". It then adds the entries of @modules, which win on a name clash.
def named_children
  children_by_name = instance_variables.each_with_object({}) do |ivar, found|
    value = instance_variable_get(ivar)
    found[ivar[1..-1]] = value if value.is_a?(Module)
  end
  @modules.each { |key, child| children_by_name[key] = child }
  children_by_name
end
|
131
|
+
|
132
|
+
# Returns the values of named_modules as an Array, in the same order.
def modules
  named_modules.each_value.to_a
end
|
135
|
+
|
136
|
+
# Returns this module under the key "" followed by named_children (one
# level only; grandchildren are not included).
def named_modules
  all = {"" => self}
  named_children.each { |name, mod| all[name] = mod }
  all
end
|
108
139
|
|
109
140
|
def train(mode = true)
|
@@ -140,12 +171,12 @@ module Torch
|
|
140
171
|
|
141
172
|
def inspect
|
142
173
|
name = self.class.name.split("::").last
|
143
|
-
if
|
174
|
+
if children.empty?
|
144
175
|
"#{name}(#{extra_inspect})"
|
145
176
|
else
|
146
177
|
str = String.new
|
147
178
|
str << "#{name}(\n"
|
148
|
-
|
179
|
+
children.each do |name, mod|
|
149
180
|
str << " (#{name}): #{mod.inspect}\n"
|
150
181
|
end
|
151
182
|
str << ")"
|
@@ -153,11 +184,21 @@ module Torch
|
|
153
184
|
end
|
154
185
|
|
155
186
|
# Serves parameters, buffers and child modules as reader methods
# (e.g. `mod.weight`). Names are looked up in that order; anything else goes
# to super, which raises NoMethodError.
def method_missing(method, *args, &block)
  name = method.to_s

  # build each lookup table once instead of once for key? and again for []
  params = named_parameters
  return params[name] if params.key?(name)

  bufs = named_buffers
  return bufs[name] if bufs.key?(name)

  mods = named_modules
  return mods[name] if mods.key?(name)

  super
end
|
158
198
|
|
159
199
|
# Reports the dynamic readers served by method_missing (parameter, buffer
# and child-module names). Implementing respond_to_missing? instead of
# overriding respond_to? has two effects. Real methods are resolved without
# building the lookup tables. Object#method also works for these names.
def respond_to_missing?(method, include_private = false)
  name = method.to_s
  named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
end
|
162
203
|
|
163
204
|
private
|
@@ -166,8 +207,14 @@ module Torch
|
|
166
207
|
nil
|
167
208
|
end
|
168
209
|
|
169
|
-
def format(str, *vars)
|
170
|
-
|
210
|
+
# Formats +str+ with String#% after converting each value to its #inspect
# form. Positional +vars+ fill sequential references such as %s. When no
# positional values are given, keyword +options+ fill named references such
# as %{name}.
def format(str, *vars, **options)
  vars =
    # empty? rather than any?: any? is false for [nil] or [false], which
    # would wrongly switch to the keyword branch and print "{}"
    if vars.empty?
      options.map { |k, v| [k, v.inspect] }.to_h
    else
      vars.map(&:inspect)
    end
  str % vars
end
|
172
219
|
end
|
173
220
|
end
|