torch-rb 0.1.0 → 0.1.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +40 -0
- data/LICENSE.txt +46 -22
- data/README.md +85 -19
- data/ext/torch/ext.cpp +274 -256
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +199 -84
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +52 -25
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool2d.rb +14 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/conv2d.rb +14 -29
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +194 -11
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/module.rb +184 -19
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +92 -21
- data/lib/torch/utils/data/data_loader.rb +15 -0
- data/lib/torch/utils/data/tensor_dataset.rb +8 -1
- data/lib/torch/version.rb +1 -1
- metadata +74 -3
@@ -0,0 +1,97 @@
|
|
1
|
+
module Torch
  module Native
    # Picks which overload of a native function matches a Ruby call.
    #
    # Each function passed in must respond to #ruby_name, #cpp_name, #args and
    # #out?. Each entry of #args is a Hash with :name, :type, :pos (true when
    # the argument can be given positionally) and :default.
    class Parser
      # functions - the overloads that share one Ruby name (non-empty)
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # fewest positional args any overload requires / most any overload accepts
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && a[:default].nil? } }.min
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
      end

      # Resolves a call made with positional +args+ and keyword +options+.
      #
      # Returns {name: cpp_name, args: [...]} with one value per declared
      # argument of the chosen overload, or {error: message} for a caller error
      # (wrong arity, unknown keyword, or a type mismatch when only one
      # overload was a candidate).
      #
      # Raises Error when an argument type string is not recognised or when
      # resolution does not end with exactly one candidate.
      def parse(args, options)
        candidates = @functions.dup

        if args.size < @min_args || args.size > @max_args
          expected = @min_args == @max_args ? @min_args.to_s : "#{@min_args}..#{@max_args}"
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        # exclude functions where options don't match
        options.each_key do |k|
          candidates.select! do |func|
            func.args.any? { |a| a[:name] == k.to_s }
          end
          # TODO show all bad keywords at once like Ruby?
          return {error: "unknown keyword: #{k}"} if candidates.empty?
        end

        # exclude functions missing required options
        candidates.reject! do |func|
          # TODO make more generic
          func.out? && !options[:out]
        end

        final_values = {}

        # check args
        candidates.select! do |func|
          # An overload that declares fewer positional arguments than were given
          # cannot match; zipping would otherwise pair extras with a nil formal
          # (crash) or with a keyword-only formal (silent mis-binding).
          next false if args.size > func.args.count { |a| a[:pos] }

          good = true

          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
          func.args.each do |fa|
            # Only fill in values that are missing. `||=` would also replace an
            # explicit `false` (e.g. keepdim: false) with the default.
            values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
          end

          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h

          values.each do |k, v|
            t = arg_types[k].split("(").first
            good =
              case t
              when "Tensor"
                v.is_a?(Tensor)
              when "Tensor[]"
                # a non-Array value is a mismatch, not a NoMethodError
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
              when "int"
                v.is_a?(Integer)
              when "int[]"
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
              when "Scalar"
                v.is_a?(Numeric)
              when "bool"
                v == true || v == false
              else
                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}"
              end

            unless good
              # with a single overload, report the mismatch instead of failing later
              if candidates.size == 1
                k = "input" if k == "self"
                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
              end
              break
            end
          end

          final_values = values if good

          good
        end

        if candidates.size != 1
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        func = candidates.first
        {
          name: func.cpp_name,
          args: func.args.map { |a| final_values[a[:name]] }
        }
      end
    end
  end
end
|
@@ -0,0 +1,13 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that delegates to F.binary_cross_entropy.
    class BCELoss < WeightedLoss
      # weight    - optional weight, handed to WeightedLoss and forwarded as weight:
      # reduction - reduction mode string forwarded to the functional (default "mean")
      def initialize(weight: nil, reduction: "mean")
        super(weight, reduction)
      end

      # Computes the loss for +input+ against +target+.
      def forward(input, target)
        loss_opts = {weight: @weight, reduction: @reduction}
        F.binary_cross_entropy(input, target, **loss_opts)
      end
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that delegates to F.binary_cross_entropy_with_logits.
    class BCEWithLogitsLoss < Loss
      # weight     - optional tensor, registered as the "weight" buffer
      # reduction  - reduction mode string handed to Loss (default "mean")
      # pos_weight - optional tensor, registered as the "pos_weight" buffer
      def initialize(weight: nil, reduction: "mean", pos_weight: nil)
        super(reduction)
        # registration order matches declaration order: weight, then pos_weight
        {"weight" => weight, "pos_weight" => pos_weight}.each do |buffer_name, tensor|
          register_buffer(buffer_name, tensor)
        end
      end

      # Computes the loss for +input+ against +target+.
      # NOTE(review): `weight` / `pos_weight` here are method calls, presumably
      # readers exposed for the registered buffers — confirm in Module.
      def forward(input, target)
        F.binary_cross_entropy_with_logits(
          input,
          target,
          weight: weight,
          pos_weight: pos_weight,
          reduction: @reduction
        )
      end
    end
  end
end
|
@@ -0,0 +1,38 @@
|
|
1
|
+
module Torch
  module NN
    # Module that applies F.bilinear to two inputs using a learned weight
    # of shape (out_features, in1_features, in2_features) and a bias.
    class Bilinear < Module
      # in1_features - feature size of the first input
      # in2_features - feature size of the second input
      # out_features - output feature size
      # bias         - only true is supported; false raises NotImplementedYet
      def initialize(in1_features, in2_features, out_features, bias: true)
        super()

        @in1_features = in1_features
        @in2_features = in2_features
        @out_features = out_features
        @weight = Parameter.new(Tensor.new(out_features, in1_features, in2_features))

        raise NotImplementedYet unless bias

        @bias = Parameter.new(Tensor.new(out_features))
        reset_parameters
      end

      # Re-initializes weight (and bias, when present) uniformly within
      # +/- 1/sqrt(size of weight dimension 1), i.e. in1_features.
      def reset_parameters
        limit = 1 / Math.sqrt(@weight.size(1))
        Init.uniform!(@weight, a: -limit, b: limit)
        Init.uniform!(@bias, a: -limit, b: limit) if @bias
      end

      def forward(input1, input2)
        F.bilinear(input1, input2, @weight, @bias)
      end

      def extra_inspect
        has_bias = !@bias.nil?
        format("in1_features: %s, in2_features: %s, out_features: %s, bias: %s",
               @in1_features, @in2_features, @out_features, has_bias)
      end
    end
  end
end
|
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -1,39 +1,24 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
|
-
class Conv2d <
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
@stride = pair(1)
|
11
|
-
# @stride = pair(stride)
|
12
|
-
# @padding = pair(padding)
|
13
|
-
# @dilation = pair(dilation)
|
14
|
-
|
15
|
-
# TODO divide by groups
|
16
|
-
@weight = Parameter.new(Tensor.new(out_channels, in_channels, *@kernel_size))
|
17
|
-
@bias = Parameter.new(Tensor.new(out_channels))
|
18
|
-
|
19
|
-
reset_parameters
|
3
|
+
class Conv2d < ConvNd
|
4
|
+
def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
|
5
|
+
kernel_size = pair(kernel_size)
|
6
|
+
stride = pair(stride)
|
7
|
+
padding = pair(padding)
|
8
|
+
dilation = pair(dilation)
|
9
|
+
super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, pair(0), groups, bias, padding_mode)
|
20
10
|
end
|
21
11
|
|
22
|
-
def
|
23
|
-
|
24
|
-
|
25
|
-
fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
|
26
|
-
bound = 1 / Math.sqrt(fan_in)
|
27
|
-
Init.uniform_(@bias, -bound, bound)
|
12
|
+
def forward(input)
|
13
|
+
if @padding_mode == "circular"
|
14
|
+
raise NotImplementedError
|
28
15
|
end
|
16
|
+
F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding, dilation: @dilation, groups: @groups)
|
29
17
|
end
|
30
18
|
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
def inspect
|
36
|
-
"Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
|
19
|
+
# TODO add more parameters
|
20
|
+
def extra_inspect
|
21
|
+
format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
|
37
22
|
end
|
38
23
|
|
39
24
|
private
|
@@ -0,0 +1,41 @@
|
|
1
|
+
module Torch
  module NN
    # Shared base for convolution modules: validates channel/group
    # divisibility, stores hyperparameters, allocates weight and bias.
    class ConvNd < Module
      def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
        super()

        [["in_channels", in_channels], ["out_channels", out_channels]].each do |label, channels|
          raise ArgumentError, "#{label} must be divisible by groups" if channels % groups != 0
        end

        @in_channels = in_channels
        @out_channels = out_channels
        @kernel_size = kernel_size
        @stride = stride
        @padding = padding
        @dilation = dilation
        @transposed = transposed
        @output_padding = output_padding
        @groups = groups
        @padding_mode = padding_mode

        # transposed convolutions swap the leading two weight dimensions
        weight_shape =
          if transposed
            [in_channels, out_channels / groups, *kernel_size]
          else
            [out_channels, in_channels / groups, *kernel_size]
          end
        @weight = Parameter.new(Tensor.new(*weight_shape))

        raise NotImplementedError unless bias

        @bias = Parameter.new(Tensor.new(out_channels))
        reset_parameters
      end

      # Kaiming-uniform weight init (a = sqrt(5)); bias drawn uniformly within
      # +/- 1/sqrt(fan_in) of the weight.
      def reset_parameters
        Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
        return unless @bias

        fan_in, _fan_out = Init._calculate_fan_in_and_fan_out(@weight)
        limit = 1 / Math.sqrt(fan_in)
        Init.uniform!(@bias, a: -limit, b: limit)
      end
    end
  end
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that delegates to F.cosine_embedding_loss.
    class CosineEmbeddingLoss < Loss
      # margin    - forwarded as margin: (default 0)
      # reduction - reduction mode string handed to Loss (default "mean")
      def initialize(margin: 0, reduction: "mean")
        super(reduction)
        @margin = margin
      end

      # Computes the loss for the pair (+input1+, +input2+) against +target+.
      def forward(input1, input2, target)
        loss_opts = {margin: @margin, reduction: @reduction}
        F.cosine_embedding_loss(input1, input2, target, **loss_opts)
      end
    end
  end
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that delegates to F.cross_entropy.
    class CrossEntropyLoss < WeightedLoss
      # weight       - optional weight, handed to WeightedLoss and forwarded as weight:
      # ignore_index - forwarded as ignore_index: (default -100)
      # reduction    - reduction mode string (default "mean")
      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
        super(weight, reduction)
        @ignore_index = ignore_index
      end

      # Computes the loss for +input+ against +target+.
      def forward(input, target)
        F.cross_entropy(
          input,
          target,
          weight: @weight,
          ignore_index: @ignore_index,
          reduction: @reduction
        )
      end
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that delegates to F.ctc_loss.
    class CTCLoss < Loss
      # blank         - forwarded as blank: (default 0)
      # reduction     - reduction mode string handed to Loss (default "mean")
      # zero_infinity - forwarded as zero_infinity: (default false)
      def initialize(blank: 0, reduction: "mean", zero_infinity: false)
        super(reduction)
        @zero_infinity = zero_infinity
        @blank = blank
      end

      # Computes the loss from +log_probs+, +targets+ and the two length tensors.
      def forward(log_probs, targets, input_lengths, target_lengths)
        F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                   blank: @blank,
                   reduction: @reduction,
                   zero_infinity: @zero_infinity)
      end
    end
  end
end
|
@@ -0,0 +1,52 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
module Torch
  module NN
    # Lookup table of num_embeddings rows, each of size embedding_dim,
    # applied through F.embedding.
    class Embedding < Module
      # num_embeddings - number of rows in the table
      # embedding_dim  - size of each row
      # padding_idx    - optional row index; a negative value is counted from the
      #                  end. When the table is freshly allocated, that row is zeroed.
      # max_norm, norm_type, scale_grad_by_freq, sparse - forwarded to F.embedding
      # _weight        - optional existing table; its shape must equal
      #                  [num_embeddings, embedding_dim]
      def initialize(num_embeddings, embedding_dim, padding_idx: nil, max_norm: nil,
                     norm_type: 2.0, scale_grad_by_freq: false, sparse: false, _weight: nil)

        super()
        @num_embeddings = num_embeddings
        @embedding_dim = embedding_dim

        if padding_idx && padding_idx < 0
          raise ArgumentError, "Padding_idx must be within num_embeddings" if padding_idx < -@num_embeddings
          # normalize a negative index to its non-negative equivalent
          padding_idx += @num_embeddings
        elsif padding_idx && padding_idx > 0
          raise ArgumentError, "Padding_idx must be within num_embeddings" if padding_idx >= @num_embeddings
        end

        @padding_idx = padding_idx
        @max_norm = max_norm
        @norm_type = norm_type
        @scale_grad_by_freq = scale_grad_by_freq

        @weight =
          if _weight.nil?
            Parameter.new(Tensor.new(num_embeddings, embedding_dim))
          else
            expected_shape = [num_embeddings, embedding_dim]
            raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == expected_shape
            Parameter.new(_weight)
          end
        # only a freshly allocated table is initialized; a supplied one is kept as-is
        reset_parameters if _weight.nil?

        @sparse = sparse
      end

      # Fills the table via Init.normal!, then zeroes the padding row (if any)
      # inside a no_grad block.
      def reset_parameters
        Init.normal!(@weight)
        return unless @padding_idx

        Torch.no_grad { @weight[@padding_idx].fill!(0) }
      end

      def forward(input)
        F.embedding(
          input,
          @weight,
          padding_idx: @padding_idx,
          max_norm: @max_norm,
          norm_type: @norm_type,
          scale_grad_by_freq: @scale_grad_by_freq,
          sparse: @sparse
        )
      end

      def inspect
        format("Embedding(%s, %s)", @num_embeddings, @embedding_dim)
      end
    end
  end
end
|