torch-rb 0.1.2 → 0.1.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +35 -0
- data/LICENSE.txt +46 -22
- data/README.md +18 -6
- data/ext/torch/ext.cpp +148 -369
- data/ext/torch/extconf.rb +6 -0
- data/ext/torch/nn_functions.cpp +615 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +242 -0
- data/ext/torch/tensor_functions.cpp +1920 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2975 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +240 -131
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +27 -22
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +109 -0
- data/lib/torch/native/generator.rb +168 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +134 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +19 -0
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +16 -38
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +411 -22
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +201 -20
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +2 -2
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +198 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +56 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +48 -16
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +71 -30
- data/lib/torch/utils/data/data_loader.rb +10 -4
- data/lib/torch/utils/data/tensor_dataset.rb +3 -0
- data/lib/torch/version.rb +1 -1
- metadata +123 -6
@@ -0,0 +1,134 @@
|
|
1
|
+
module Torch
  module Native
    # Picks the single native overload that matches a Ruby call and converts the
    # call's positional args and keyword options into that overload's ordered
    # argument list.
    #
    # Each function object must respond to +ruby_name+, +cpp_name+, +args+
    # (an Array of Hashes with :name, :type, :pos, :has_default, :default),
    # +out?+ and +out_size+.
    class Parser
      # @param functions [Array] non-empty list of overloads sharing one Ruby name
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # arity bounds across all overloads, counting positional args only
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
      end

      # Resolves a call to exactly one overload.
      #
      # @param args [Array] positional arguments; trailing nils are removed in place
      # @param options [Hash] keyword arguments; an :out entry holding several
      #   tensors is expanded in place into the overload's output argument names
      # @return [Hash] +{name: cpp_name, args: [...]}+ on success, or
      #   +{error: message}+ for wrong arity, an unknown keyword, or a wrong
      #   argument type when only one overload is left
      # @raise [Error] for an unsupported signature type or an ambiguous match
      def parse(args, options)
        candidates = @functions.dup

        # remove trailing nils; test emptiness explicitly because Array#any? is
        # false for arrays containing only falsy values such as [false, nil]
        args.pop while !args.empty? && args.last.nil?

        # TODO account for args passed as options here
        if args.size < @min_args || args.size > @max_args
          expected = String.new(@min_args.to_s)
          expected += "..#{@max_args}" if @max_args != @min_args
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        candidates.reject! { |f| args.size > f.args.size }

        # exclude functions missing required options
        # TODO make more generic
        candidates.reject! { |func| func.out? && !options[:out] }

        # handle out with multiple
        # there should only be one match, so safe to modify all
        out_func = candidates.find(&:out?)
        if out_func && out_func.out_size > 1 && options[:out]
          # the output tensors are the last out_size arguments of the signature
          out_args = out_func.args.last(out_func.out_size).map { |a| a[:name] }
          out_args.zip(options.delete(:out)).each do |k, v|
            options[k.to_sym] = v
          end
          candidates = [out_func]
        end

        # exclude functions where options don't match
        options.each_key do |k|
          candidates.select! { |func| func.args.any? { |a| a[:name] == k.to_s } }
          # TODO show all bad keywords at once like Ruby?
          return {error: "unknown keyword: #{k}"} if candidates.empty?
        end

        final_values = {}

        # check args
        candidates.select! do |func|
          good = true

          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
          # anything not given (or given as nil) falls back to its default
          func.args.each do |fa|
            values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
          end

          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h

          values.each_key do |k|
            v = values[k]
            # strip alias annotations, e.g. "Tensor(a!)" -> "Tensor"
            t = arg_types[k].split("(").first

            good =
              case t
              when "Tensor"
                v.is_a?(Tensor)
              when "Tensor?"
                v.nil? || v.is_a?(Tensor)
              when "Tensor[]"
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
              when "int"
                if k == "reduction"
                  v.is_a?(String)
                else
                  v.is_a?(Integer)
                end
              when "float"
                v.is_a?(Numeric)
              when /int\[.*\]/
                if v.is_a?(Integer)
                  # a single int is repeated to the fixed list size, e.g. int[2]
                  size = t[4..-2]
                  raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
                  v = [v] * size.to_i
                  values[k] = v
                end
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
              when "Scalar"
                v.is_a?(Numeric)
              when "ScalarType?"
                v.nil?
              when "bool"
                v == true || v == false
              else
                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
              end

            unless good
              # with a single overload left, report the mismatch to the caller
              if candidates.size == 1
                k = "input" if k == "self"
                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
              end
              break
            end
          end

          final_values = values if good

          good
        end

        if candidates.size != 1
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        func = candidates.first
        {
          name: func.cpp_name,
          args: func.args.map { |a| final_values[a[:name]] }
        }
      end
    end
  end
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Torch
  module NN
    # 1-D average pooling module. Window size, stride and padding are normalized
    # with _single at construction time; everything is handed to F.avg_pool1d
    # on each forward call.
    class AvgPool1d < AvgPoolNd
      # @param kernel_size size of the pooling window
      # @param stride step between windows; when nil, kernel_size is used
      # @param padding implicit padding passed to F.avg_pool1d
      # @param ceil_mode [Boolean] forwarded to F.avg_pool1d
      # @param count_include_pad [Boolean] forwarded to F.avg_pool1d
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true)
        super()
        step = stride || kernel_size
        @kernel_size = _single(kernel_size)
        @stride = _single(step)
        @padding = _single(padding)
        @ceil_mode = ceil_mode
        @count_include_pad = count_include_pad
      end

      def forward(input)
        F.avg_pool1d(
          input,
          @kernel_size,
          @stride,
          @padding,
          @ceil_mode,
          @count_include_pad
        )
      end
    end
  end
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
  module NN
    # 2-D average pooling module. Settings are stored as given and passed
    # unchanged to F.avg_pool2d on each forward call.
    class AvgPool2d < AvgPoolNd
      # @param kernel_size size of the pooling window
      # @param stride step between windows; when nil, kernel_size is used
      # @param padding implicit padding passed to F.avg_pool2d
      # @param ceil_mode [Boolean] forwarded to F.avg_pool2d
      # @param count_include_pad [Boolean] forwarded to F.avg_pool2d
      # @param divisor_override forwarded to F.avg_pool2d (nil by default)
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
        super()
        @kernel_size = kernel_size
        @padding = padding
        @stride = stride || kernel_size
        @count_include_pad = count_include_pad
        @ceil_mode = ceil_mode
        @divisor_override = divisor_override
      end

      def forward(input)
        F.avg_pool2d(
          input,
          @kernel_size,
          @stride,
          @padding,
          @ceil_mode,
          @count_include_pad,
          @divisor_override
        )
      end
    end
  end
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
  module NN
    # 3-D average pooling module. Settings are stored as given and passed
    # unchanged to F.avg_pool3d on each forward call.
    class AvgPool3d < AvgPoolNd
      # @param kernel_size size of the pooling window
      # @param stride step between windows; when nil, kernel_size is used
      # @param padding implicit padding passed to F.avg_pool3d
      # @param ceil_mode [Boolean] forwarded to F.avg_pool3d
      # @param count_include_pad [Boolean] forwarded to F.avg_pool3d
      # @param divisor_override forwarded to F.avg_pool3d (nil by default)
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
        super()
        @kernel_size = kernel_size
        @padding = padding
        @stride = stride || kernel_size
        @count_include_pad = count_include_pad
        @ceil_mode = ceil_mode
        @divisor_override = divisor_override
      end

      def forward(input)
        F.avg_pool3d(
          input,
          @kernel_size,
          @stride,
          @padding,
          @ceil_mode,
          @count_include_pad,
          @divisor_override
        )
      end
    end
  end
end
|
@@ -0,0 +1,75 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class BatchNorm < Module
|
4
|
+
def initialize(num_features, eps: 1e-5, momentum: 0.1, affine: true, track_running_stats: true)
|
5
|
+
super()
|
6
|
+
@num_features = num_features
|
7
|
+
@eps = eps
|
8
|
+
@momentum = momentum
|
9
|
+
@affine = affine
|
10
|
+
@track_running_stats = track_running_stats
|
11
|
+
if @affine
|
12
|
+
@weight = Parameter.new(Torch::Tensor.new(num_features))
|
13
|
+
@bias = Parameter.new(Torch::Tensor.new(num_features))
|
14
|
+
else
|
15
|
+
register_parameter("weight", nil)
|
16
|
+
register_parameter("bias", nil)
|
17
|
+
end
|
18
|
+
if track_running_stats
|
19
|
+
register_buffer("running_mean", Torch.zeros(num_features))
|
20
|
+
register_buffer("running_var", Torch.ones(num_features))
|
21
|
+
register_buffer("num_batches_tracked", Torch.tensor(0, dtype: :long))
|
22
|
+
else
|
23
|
+
register_parameter("running_mean", nil)
|
24
|
+
register_parameter("running_var", nil)
|
25
|
+
register_parameter("num_batches_tracked", nil)
|
26
|
+
end
|
27
|
+
reset_parameters
|
28
|
+
end
|
29
|
+
|
30
|
+
def reset_running_stats
|
31
|
+
if @track_running_stats
|
32
|
+
@running_mean.zero!
|
33
|
+
@running_var.fill!(1)
|
34
|
+
@num_batches_tracked.zero!
|
35
|
+
end
|
36
|
+
end
|
37
|
+
|
38
|
+
def reset_parameters
|
39
|
+
reset_running_stats
|
40
|
+
if @affine
|
41
|
+
Init.ones!(@weight)
|
42
|
+
Init.zeros!(@bias)
|
43
|
+
end
|
44
|
+
end
|
45
|
+
|
46
|
+
def forward(input)
|
47
|
+
_check_input_dim(input)
|
48
|
+
|
49
|
+
if @momentum.nil?
|
50
|
+
exponential_average_factor = 0.0
|
51
|
+
else
|
52
|
+
exponential_average_factor = @momentum
|
53
|
+
end
|
54
|
+
|
55
|
+
if @training and @track_running_stats
|
56
|
+
if @num_batches_tracked.nil?
|
57
|
+
@num_batches_tracked += 1
|
58
|
+
if @momentum.nil?
|
59
|
+
exponential_average_factor = 1.0 / @num_batches_tracked.to_f
|
60
|
+
else
|
61
|
+
exponential_average_factor = @momentum
|
62
|
+
end
|
63
|
+
end
|
64
|
+
end
|
65
|
+
|
66
|
+
F.batch_norm(
|
67
|
+
input, @running_mean, @running_var,
|
68
|
+
weight: @weight, bias: @bias,
|
69
|
+
training: @training || !@track_running_stats,
|
70
|
+
momentum: exponential_average_factor, eps: @eps
|
71
|
+
)
|
72
|
+
end
|
73
|
+
end
|
74
|
+
end
|
75
|
+
end
|
@@ -0,0 +1,13 @@
|
|
1
|
+
module Torch
  module NN
    # Binary cross-entropy loss, computed by F.binary_cross_entropy.
    class BCELoss < WeightedLoss
      # @param weight optional rescaling weight, handed to WeightedLoss
      # @param reduction [String] reduction mode, "mean" by default
      def initialize(weight: nil, reduction: "mean")
        super(weight, reduction)
      end

      # NOTE(review): @weight and @reduction are presumably set by WeightedLoss — confirm.
      def forward(input, target)
        F.binary_cross_entropy(
          input,
          target,
          weight: @weight,
          reduction: @reduction
        )
      end
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
module Torch
  module NN
    # Binary cross-entropy on raw logits, computed by
    # F.binary_cross_entropy_with_logits.
    # weight and pos_weight are stored as buffers rather than instance settings.
    class BCEWithLogitsLoss < Loss
      # @param weight optional rescaling weight, registered as the "weight" buffer
      # @param reduction [String] reduction mode, "mean" by default
      # @param pos_weight optional weight of positive examples, registered as the
      #   "pos_weight" buffer
      def initialize(weight: nil, reduction: "mean", pos_weight: nil)
        super(reduction)
        register_buffer("weight", weight)
        register_buffer("pos_weight", pos_weight)
      end

      # NOTE(review): weight / pos_weight are method calls here, presumably
      # resolved by Module for registered buffers — confirm.
      def forward(input, target)
        F.binary_cross_entropy_with_logits(
          input,
          target,
          weight: weight,
          pos_weight: pos_weight,
          reduction: @reduction
        )
      end
    end
  end
end
|
@@ -0,0 +1,38 @@
|
|
1
|
+
module Torch
  module NN
    # Bilinear transformation of two inputs, delegated to F.bilinear.
    # Only the bias: true configuration is supported so far.
    class Bilinear < Module
      # @param in1_features [Integer] feature size of the first input
      # @param in2_features [Integer] feature size of the second input
      # @param out_features [Integer] feature size of the output
      # @param bias [Boolean] must be truthy; a falsy value raises NotImplementedYet
      def initialize(in1_features, in2_features, out_features, bias: true)
        super()

        @in1_features = in1_features
        @in2_features = in2_features
        @out_features = out_features
        @weight = Parameter.new(Tensor.new(out_features, in1_features, in2_features))

        raise NotImplementedYet unless bias

        @bias = Parameter.new(Tensor.new(out_features))
        reset_parameters
      end

      # Re-draws weight (and bias, when present) with Init.uniform! over
      # [-limit, limit], where limit = 1 / sqrt(weight.size(1)).
      def reset_parameters
        limit = 1 / Math.sqrt(@weight.size(1))
        Init.uniform!(@weight, a: -limit, b: limit)
        Init.uniform!(@bias, a: -limit, b: limit) if @bias
      end

      def forward(input1, input2)
        F.bilinear(input1, input2, @weight, @bias)
      end

      def extra_inspect
        has_bias = !@bias.nil?
        format(
          "in1_features: %s, in2_features: %s, out_features: %s, bias: %s",
          @in1_features, @in2_features, @out_features, has_bias
        )
      end
    end
  end
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Torch
  module NN
    # Base class for constant padding modules: pads the input with a fixed
    # fill value. @padding is not set here; subclasses are expected to set it.
    class ConstantPadNd < Module
      # @param value fill value used for the padded elements
      def initialize(value)
        super()
        @value = value
      end

      # Pads +input+ by @padding using F.pad in "constant" mode.
      def forward(input)
        F.pad(input, @padding, mode: "constant", value: @value)
      end

      def extra_inspect
        "padding: %s, value: %s" % [@padding, @value]
      end
    end
  end
end
|