torch-rb 0.1.3 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +30 -0
- data/README.md +5 -2
- data/ext/torch/ext.cpp +130 -555
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +244 -0
- data/lib/torch.rb +209 -171
- data/lib/torch/inspector.rb +23 -19
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +110 -0
- data/lib/torch/native/generator.rb +168 -0
- data/lib/torch/native/native_functions.yaml +6491 -0
- data/lib/torch/native/parser.rb +134 -0
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +19 -0
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +10 -20
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/convnd.rb +3 -3
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropoutnd.rb +2 -2
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +379 -32
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +186 -35
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +2 -2
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +198 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +51 -44
- data/lib/torch/version.rb +1 -1
- metadata +98 -6
- data/lib/torch/ext.bundle +0 -0
@@ -0,0 +1,134 @@
|
|
1
|
+
module Torch
  module Native
    # Resolves a Ruby-level call to exactly one of several native overloads
    # that share a Ruby name, and builds the full argument list for it.
    class Parser
      # @param functions [Array] overloads sharing one Ruby name; each must
      #   respond to +ruby_name+, +cpp_name+, +args+, +out?+ and +out_size+.
      #   +args+ is an array of hashes read via the :name, :type, :pos,
      #   :has_default and :default keys.
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # fewest required positional args any overload needs
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
        # most positional args any overload accepts
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
      end

      # Picks the overload matching +args+ / +options+.
      #
      # Mutates its inputs: trailing nils are popped from +args+, and a
      # multi-tensor +:out+ option is replaced by one keyword per output.
      #
      # @param args [Array] positional arguments from the caller
      # @param options [Hash] keyword arguments from the caller (symbol keys)
      # @return [Hash] +{name: cpp_name, args: [...]}+ with every argument of
      #   the chosen overload in declaration order, or +{error: message}+
      # @raise [Error] for an argument type this parser does not recognize, or
      #   when resolution does not narrow down to exactly one overload
      def parse(args, options)
        candidates = @functions.dup

        # remove trailing nils so those positions fall back to their defaults
        args.pop while args.any? && args.last.nil?

        # TODO account for args passed as options here
        if args.size < @min_args || args.size > @max_args
          expected = @min_args.to_s
          expected = "#{expected}..#{@max_args}" if @max_args != @min_args
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        candidates.reject! { |f| args.size > f.args.size }

        # exclude functions missing required options
        # TODO make more generic
        candidates.reject! { |func| func.out? && !options[:out] }

        # handle out with multiple tensors
        # there should only be one match, so safe to modify all
        out_func = candidates.find(&:out?)
        if out_func && out_func.out_size > 1 && options[:out]
          # The output tensors are the trailing out_size declared args. Using
          # out_size (not a fixed 2) keeps 3-output functions mapped correctly.
          out_args = out_func.args.last(out_func.out_size).map { |a| a[:name] }
          out_args.zip(options.delete(:out)).each do |k, v|
            options[k.to_sym] = v
          end
          candidates = [out_func]
        end

        # exclude functions where options don't match
        options.each_key do |k|
          candidates.select! { |func| func.args.any? { |a| a[:name] == k.to_s } }
          # TODO show all bad keywords at once like Ruby?
          return {error: "unknown keyword: #{k}"} if candidates.empty?
        end

        final_values = {}

        # check args
        candidates.select! do |func|
          good = true

          # positional args by name, then keywords, then declared defaults
          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
          func.args.each do |fa|
            values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
          end

          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h

          values.each_key do |k|
            v = values[k]
            # drop any parenthesized annotation, e.g. "Tensor(a!)" -> "Tensor"
            t = arg_types[k].split("(").first

            good =
              case t
              when "Tensor"
                v.is_a?(Tensor)
              when "Tensor?"
                v.nil? || v.is_a?(Tensor)
              when "Tensor[]"
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
              when "int"
                if k == "reduction"
                  v.is_a?(String)
                else
                  v.is_a?(Integer)
                end
              when "float"
                v.is_a?(Numeric)
              when /int\[.*\]/
                if v.is_a?(Integer)
                  # expand a scalar to the fixed list size, e.g. int[2]
                  size = t[4..-2]
                  raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
                  v = [v] * size.to_i
                  values[k] = v
                end
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
              when "Scalar"
                v.is_a?(Numeric)
              when "ScalarType?"
                v.nil?
              when "bool"
                v == true || v == false
              else
                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
              end

            unless good
              # only a sole remaining candidate produces a user-facing error
              if candidates.size == 1
                k = "input" if k == "self"
                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
              end
              break
            end
          end

          final_values = values if good

          good
        end

        if candidates.size != 1
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        func = candidates.first
        {
          name: func.cpp_name,
          args: func.args.map { |a| final_values[a[:name]] }
        }
      end
    end
  end
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Torch
  module NN
    # 1-D average pooling module; stores its settings and forwards them to
    # F.avg_pool1d. kernel_size, stride and padding are normalized through
    # +_single+ (defined outside this class).
    class AvgPool1d < AvgPoolNd
      # @param kernel_size pooling window size
      # @param stride window step; +kernel_size+ is used when nil
      # @param padding forwarded to F.avg_pool1d
      # @param ceil_mode forwarded to F.avg_pool1d
      # @param count_include_pad forwarded to F.avg_pool1d
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true)
        super()
        effective_stride = stride || kernel_size
        @kernel_size = _single(kernel_size)
        @stride = _single(effective_stride)
        @padding = _single(padding)
        @ceil_mode = ceil_mode
        @count_include_pad = count_include_pad
      end

      def forward(input)
        settings = [@kernel_size, @stride, @padding, @ceil_mode, @count_include_pad]
        F.avg_pool1d(input, *settings)
      end
    end
  end
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
  module NN
    # 2-D average pooling module; every setting is stored as given and
    # forwarded unchanged to F.avg_pool2d.
    class AvgPool2d < AvgPoolNd
      # @param kernel_size pooling window
      # @param stride window step; +kernel_size+ is used when nil
      # @param padding forwarded to F.avg_pool2d
      # @param ceil_mode forwarded to F.avg_pool2d
      # @param count_include_pad forwarded to F.avg_pool2d
      # @param divisor_override forwarded to F.avg_pool2d
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
        super()
        effective_stride = stride || kernel_size
        @kernel_size = kernel_size
        @stride = effective_stride
        @padding = padding
        @ceil_mode = ceil_mode
        @count_include_pad = count_include_pad
        @divisor_override = divisor_override
      end

      def forward(input)
        settings = [@kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override]
        F.avg_pool2d(input, *settings)
      end
    end
  end
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
  module NN
    # 3-D average pooling module; every setting is stored as given and
    # forwarded unchanged to F.avg_pool3d.
    class AvgPool3d < AvgPoolNd
      # @param kernel_size pooling window
      # @param stride window step; +kernel_size+ is used when nil
      # @param padding forwarded to F.avg_pool3d
      # @param ceil_mode forwarded to F.avg_pool3d
      # @param count_include_pad forwarded to F.avg_pool3d
      # @param divisor_override forwarded to F.avg_pool3d
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
        super()
        effective_stride = stride || kernel_size
        @kernel_size = kernel_size
        @stride = effective_stride
        @padding = padding
        @ceil_mode = ceil_mode
        @count_include_pad = count_include_pad
        @divisor_override = divisor_override
      end

      def forward(input)
        settings = [@kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override]
        F.avg_pool3d(input, *settings)
      end
    end
  end
end
|
@@ -0,0 +1,75 @@
|
|
1
|
+
module Torch
  module NN
    # Shared batch-normalization module. +_check_input_dim+ is not defined
    # here; presumably each concrete subclass provides it.
    class BatchNorm < Module
      # @param num_features [Integer] length of weight, bias and running stats
      # @param eps [Float] forwarded to F.batch_norm
      # @param momentum [Float, nil] running-average factor; when nil the
      #   factor becomes 1 / (number of batches tracked so far)
      # @param affine [Boolean] create learnable weight and bias when true
      # @param track_running_stats [Boolean] keep running_mean / running_var /
      #   num_batches_tracked buffers when true
      def initialize(num_features, eps: 1e-5, momentum: 0.1, affine: true, track_running_stats: true)
        super()
        @num_features = num_features
        @eps = eps
        @momentum = momentum
        @affine = affine
        @track_running_stats = track_running_stats
        if @affine
          @weight = Parameter.new(Torch::Tensor.new(num_features))
          @bias = Parameter.new(Torch::Tensor.new(num_features))
        else
          register_parameter("weight", nil)
          register_parameter("bias", nil)
        end
        if track_running_stats
          register_buffer("running_mean", Torch.zeros(num_features))
          register_buffer("running_var", Torch.ones(num_features))
          register_buffer("num_batches_tracked", Torch.tensor(0, dtype: :long))
        else
          register_parameter("running_mean", nil)
          register_parameter("running_var", nil)
          register_parameter("num_batches_tracked", nil)
        end
        reset_parameters
      end

      # Resets running mean to 0, running variance to 1 and the batch counter
      # to 0 (only when running stats are tracked).
      def reset_running_stats
        if @track_running_stats
          @running_mean.zero!
          @running_var.fill!(1)
          @num_batches_tracked.zero!
        end
      end

      # Resets running stats, then weight to ones and bias to zeros when affine.
      def reset_parameters
        reset_running_stats
        if @affine
          Init.ones!(@weight)
          Init.zeros!(@bias)
        end
      end

      def forward(input)
        _check_input_dim(input)

        exponential_average_factor = @momentum.nil? ? 0.0 : @momentum

        if @training && @track_running_stats
          # Count batches only when the counter buffer exists. The previous
          # check was inverted, so the counter never advanced.
          unless @num_batches_tracked.nil?
            # NOTE(review): `+=` rebinds the ivar to a new object; confirm the
            # registered buffer is updated too (an in-place add may be intended).
            @num_batches_tracked += 1
            exponential_average_factor =
              if @momentum.nil?
                # cumulative average over all batches seen so far
                1.0 / @num_batches_tracked.to_f
              else
                @momentum
              end
          end
        end

        F.batch_norm(
          input, @running_mean, @running_var,
          weight: @weight, bias: @bias,
          training: @training || !@track_running_stats,
          momentum: exponential_average_factor, eps: @eps
        )
      end
    end
  end
end
|
@@ -0,0 +1,13 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that delegates to F.binary_cross_entropy, passing the
    # weight and reduction held by the WeightedLoss base.
    class BCELoss < WeightedLoss
      # @param weight optional weight handed to the WeightedLoss base
      # @param reduction reduction name handed to the base (default "mean")
      def initialize(weight: nil, reduction: "mean")
        super(weight, reduction)
      end

      def forward(input, target)
        loss_weight = @weight
        mode = @reduction
        F.binary_cross_entropy(input, target, weight: loss_weight, reduction: mode)
      end
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that delegates to F.binary_cross_entropy_with_logits.
    # +weight+ and +pos_weight+ are registered as buffers (possibly nil).
    class BCEWithLogitsLoss < Loss
      # @param weight optional tensor registered as the "weight" buffer
      # @param reduction reduction name handed to the Loss base
      # @param pos_weight optional tensor registered as the "pos_weight" buffer
      def initialize(weight: nil, reduction: "mean", pos_weight: nil)
        super(reduction)
        register_buffer("weight", weight)
        register_buffer("pos_weight", pos_weight)
      end

      def forward(input, target)
        # `weight` / `pos_weight` are method calls, presumably resolved to the
        # registered buffers by the Module base (not visible here)
        loss_weight = weight
        positive_weight = pos_weight
        F.binary_cross_entropy_with_logits(
          input, target,
          weight: loss_weight, pos_weight: positive_weight, reduction: @reduction
        )
      end
    end
  end
end
|
@@ -0,0 +1,38 @@
|
|
1
|
+
module Torch
  module NN
    # Bilinear layer: holds a weight of shape
    # (out_features, in1_features, in2_features) plus a bias of length
    # out_features, and forwards both inputs to F.bilinear.
    class Bilinear < Module
      # @param in1_features [Integer]
      # @param in2_features [Integer]
      # @param out_features [Integer]
      # @param bias [Boolean] only true is supported; false raises NotImplementedYet
      def initialize(in1_features, in2_features, out_features, bias: true)
        super()

        @in1_features = in1_features
        @in2_features = in2_features
        @out_features = out_features
        @weight = Parameter.new(Tensor.new(out_features, in1_features, in2_features))

        raise NotImplementedYet unless bias

        @bias = Parameter.new(Tensor.new(out_features))
        reset_parameters
      end

      # Re-draws weight (and bias, if present) uniformly in [-bound, bound],
      # where bound = 1 / sqrt(size of the weight's dimension 1).
      def reset_parameters
        limit = 1 / Math.sqrt(@weight.size(1))
        Init.uniform!(@weight, a: -limit, b: limit)
        Init.uniform!(@bias, a: -limit, b: limit) if @bias
      end

      def forward(input1, input2)
        F.bilinear(input1, input2, @weight, @bias)
      end

      def extra_inspect
        has_bias = !@bias.nil?
        format("in1_features: %s, in2_features: %s, out_features: %s, bias: %s", @in1_features, @in2_features, @out_features, has_bias)
      end
    end
  end
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Torch
  module NN
    # Base for constant-value padding modules. Stores the fill value; @padding
    # is not assigned here and is expected to be set by a subclass.
    class ConstantPadNd < Module
      # @param value fill value passed to F.pad
      def initialize(value)
        super()
        @value = value
      end

      def forward(input)
        fill = @value
        F.pad(input, @padding, mode: "constant", value: fill)
      end

      def extra_inspect
        "padding: %s, value: %s" % [@padding, @value]
      end
    end
  end
end
|
@@ -0,0 +1,22 @@
|
|
1
|
+
module Torch
  module NN
    # 1-D convolution module. kernel_size, stride, padding and dilation are
    # normalized with +_single+ (defined elsewhere) before reaching ConvNd.
    class Conv1d < ConvNd
      # @param padding_mode [String] "circular" is accepted here but raises
      #   NotImplementedError in #forward
      def initialize(in_channels, out_channels, kernel_size, stride: 1,
                     padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
        kernel_size, stride, padding, dilation =
          [kernel_size, stride, padding, dilation].map { |setting| _single(setting) }
        # positional order expected by ConvNd: ..., transposed (false),
        # output_padding, groups, bias, padding_mode
        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _single(0), groups, bias, padding_mode)
      end

      def forward(input)
        raise NotImplementedError if @padding_mode == "circular"

        F.conv1d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
      end
    end
  end
end
|