torch-rb 0.1.0 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +40 -0
- data/LICENSE.txt +46 -22
- data/README.md +85 -19
- data/ext/torch/ext.cpp +274 -256
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +199 -84
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +52 -25
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool2d.rb +14 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/conv2d.rb +14 -29
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +194 -11
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/module.rb +184 -19
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +92 -21
- data/lib/torch/utils/data/data_loader.rb +15 -0
- data/lib/torch/utils/data/tensor_dataset.rb +8 -1
- data/lib/torch/version.rb +1 -1
- metadata +74 -3
data/lib/torch/native/parser.rb ADDED
@@ -0,0 +1,97 @@
+module Torch
+  module Native
+    class Parser
+      def initialize(functions)
+        @functions = functions
+        @name = @functions.first.ruby_name
+        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && a[:default].nil? } }.min
+        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
+      end
+
+      def parse(args, options)
+        candidates = @functions.dup
+
+        if args.size < @min_args || args.size > @max_args
+          expected = String.new(@min_args.to_s)
+          expected += "..#{@max_args}" if @max_args != @min_args
+          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
+        end
+
+        # exclude functions where options don't match
+        options.each do |k, v|
+          candidates.select! do |func|
+            func.args.any? { |a| a[:name] == k.to_s }
+          end
+          # TODO show all bad keywords at once like Ruby?
+          return {error: "unknown keyword: #{k}"} if candidates.empty?
+        end
+
+        # exclude functions missing required options
+        candidates.reject! do |func|
+          # TODO make more generic
+          func.out? && !options[:out]
+        end
+
+        final_values = {}
+
+        # check args
+        candidates.select! do |func|
+          good = true
+
+          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
+          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
+          func.args.each do |fa|
+            values[fa[:name]] ||= fa[:default]
+          end
+
+          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
+
+          values.each do |k, v|
+            t = arg_types[k].split("(").first
+            good =
+              case t
+              when "Tensor"
+                v.is_a?(Tensor)
+              when "Tensor[]"
+                v.all? { |v2| v2.is_a?(Tensor) }
+              when "int"
+                v.is_a?(Integer)
+              when "int[]"
+                v.all? { |v2| v2.is_a?(Integer) }
+              when "Scalar"
+                v.is_a?(Numeric)
+              when "bool"
+                v == true || v == false
+              else
+                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}"
+              end
+
+            if !good
+              if candidates.size == 1
+                k = "input" if k == "self"
+                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
+              end
+              break
+            end
+          end
+
+          if good
+            final_values = values
+          end
+
+          good
+        end
+
+        if candidates.size != 1
+          raise Error, "This should never happen. Please report a bug with #{@name}."
+        end
+
+        func = candidates.first
+        {
+          name: func.cpp_name,
+          args: func.args.map { |a| final_values[a[:name]] }
+        }
+      end
+    end
+  end
+end

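The Parser above resolves a Ruby call against several generated overloads: it first narrows candidates by positional-argument count and keyword names, then type-checks each supplied value against the declared native type. As a rough standalone illustration of that idea (plain Ruby only; the TYPE_CHECKS, OVERLOADS, and resolve names below are hypothetical and not part of torch-rb, whose real Parser works on Torch::Native::Function objects generated from native_functions.yaml):

# Standalone sketch of overload resolution by arity and argument type.
# Illustrative only; see the lead-in above for what is hypothetical.

TYPE_CHECKS = {
  "int"    => ->(v) { v.is_a?(Integer) },
  "int[]"  => ->(v) { v.is_a?(Array) && v.all? { |x| x.is_a?(Integer) } },
  "bool"   => ->(v) { v == true || v == false },
  "Scalar" => ->(v) { v.is_a?(Numeric) }
}

# Each overload: a target name plus an ordered list of [arg_name, type] pairs.
OVERLOADS = [
  { name: "sum",     args: [["values", "int[]"]] },
  { name: "sum_dim", args: [["values", "int[]"], ["dim", "int"]] }
]

def resolve(args)
  # keep only overloads with a matching number of positional arguments
  candidates = OVERLOADS.select { |o| o[:args].size == args.size }
  # then type-check each supplied value against the declared type
  candidates.select! do |o|
    o[:args].zip(args).all? { |(_, type), value| TYPE_CHECKS[type].call(value) }
  end
  raise "no matching overload" if candidates.size != 1
  candidates.first[:name]
end

p resolve([[1, 2, 3]])       # => "sum"
p resolve([[1, 2, 3], 0])    # => "sum_dim"

In the real parser, mismatches are reported as {error: ...} hashes rather than raised directly, so the caller can turn them into Ruby-style ArgumentError messages.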
data/lib/torch/nn/bce_loss.rb ADDED
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class BCELoss < WeightedLoss
+      def initialize(weight: nil, reduction: "mean")
+        super(weight, reduction)
+      end
+
+      def forward(input, target)
+        F.binary_cross_entropy(input, target, weight: @weight, reduction: @reduction)
+      end
+    end
+  end
+end

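For context, a minimal usage sketch of the new loss modules (an assumption-laden example, not part of the diff: it presumes the gem and its native extension are installed and that Torch.rand and Torch.tensor behave as shown in the project README):

require "torch"

# BCELoss expects probabilities in [0, 1] and float targets of the same shape.
input  = Torch.rand(3)                   # e.g. predicted probabilities
target = Torch.tensor([0.0, 1.0, 1.0])   # ground-truth labels

loss = Torch::NN::BCELoss.new(reduction: "mean")
puts loss.forward(input, target)         # delegates to F.binary_cross_entropy

BCEWithLogitsLoss (next file) takes raw scores instead and applies the sigmoid internally via F.binary_cross_entropy_with_logits.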
data/lib/torch/nn/bce_with_logits_loss.rb ADDED
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class BCEWithLogitsLoss < Loss
+      def initialize(weight: nil, reduction: "mean", pos_weight: nil)
+        super(reduction)
+        register_buffer("weight", weight)
+        register_buffer("pos_weight", pos_weight)
+      end
+
+      def forward(input, target)
+        F.binary_cross_entropy_with_logits(input, target, weight: weight, pos_weight: pos_weight, reduction: @reduction)
+      end
+    end
+  end
+end

data/lib/torch/nn/bilinear.rb ADDED
@@ -0,0 +1,38 @@
+module Torch
+  module NN
+    class Bilinear < Module
+      def initialize(in1_features, in2_features, out_features, bias: true)
+        super()
+
+        @in1_features = in1_features
+        @in2_features = in2_features
+        @out_features = out_features
+        @weight = Parameter.new(Tensor.new(out_features, in1_features, in2_features))
+
+        if bias
+          @bias = Parameter.new(Tensor.new(out_features))
+        else
+          raise NotImplementedYet
+        end
+
+        reset_parameters
+      end
+
+      def reset_parameters
+        bound = 1 / Math.sqrt(@weight.size(1))
+        Init.uniform!(@weight, a: -bound, b: bound)
+        if @bias
+          Init.uniform!(@bias, a: -bound, b: bound)
+        end
+      end
+
+      def forward(input1, input2)
+        F.bilinear(input1, input2, @weight, @bias)
+      end
+
+      def extra_inspect
+        format("in1_features: %s, in2_features: %s, out_features: %s, bias: %s", @in1_features, @in2_features, @out_features, !@bias.nil?)
+      end
+    end
+  end
+end

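Bilinear applies a learned bilinear form to a pair of inputs, mapping batches of shape (N, in1_features) and (N, in2_features) to (N, out_features). A minimal sketch, under the same assumptions as above about the installed gem (Torch.randn and Tensor#shape as in the README):

require "torch"

m  = Torch::NN::Bilinear.new(20, 30, 40)
x1 = Torch.randn(128, 20)
x2 = Torch.randn(128, 30)

out = m.forward(x1, x2)
p out.shape   # expected: [128, 40]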
data/lib/torch/nn/conv2d.rb CHANGED
@@ -1,39 +1,24 @@
 module Torch
   module NN
-    class Conv2d <
-
-
-
-
-
-
-        @stride = pair(1)
-        # @stride = pair(stride)
-        # @padding = pair(padding)
-        # @dilation = pair(dilation)
-
-        # TODO divide by groups
-        @weight = Parameter.new(Tensor.new(out_channels, in_channels, *@kernel_size))
-        @bias = Parameter.new(Tensor.new(out_channels))
-
-        reset_parameters
+    class Conv2d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+        kernel_size = pair(kernel_size)
+        stride = pair(stride)
+        padding = pair(padding)
+        dilation = pair(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, pair(0), groups, bias, padding_mode)
       end
 
-      def
-
-
-        fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
-        bound = 1 / Math.sqrt(fan_in)
-        Init.uniform_(@bias, -bound, bound)
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
         end
+        F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding, dilation: @dilation, groups: @groups)
       end
 
-
-
-
-
-      def inspect
-        "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
+      # TODO add more parameters
+      def extra_inspect
+        format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
       end
 
       private

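The rewritten Conv2d now delegates parameter setup and initialization to the shared ConvNd base (next file) and implements forward in terms of F.conv2d. A minimal sketch of how it is constructed and applied (same assumptions as above about the installed gem; input layout is NCHW):

require "torch"

conv = Torch::NN::Conv2d.new(3, 16, 3, stride: 1, padding: 1)  # 3 -> 16 channels, 3x3 kernel
x    = Torch.randn(1, 3, 28, 28)                               # one 3-channel 28x28 image

out = conv.forward(x)
p out.shape   # expected: [1, 16, 28, 28] with padding: 1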
data/lib/torch/nn/convnd.rb ADDED
@@ -0,0 +1,41 @@
+module Torch
+  module NN
+    class ConvNd < Module
+      def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
+        super()
+        raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
+        raise ArgumentError, "out_channels must be divisible by groups" if out_channels % groups != 0
+        @in_channels = in_channels
+        @out_channels = out_channels
+        @kernel_size = kernel_size
+        @stride = stride
+        @padding = padding
+        @dilation = dilation
+        @transposed = transposed
+        @output_padding = output_padding
+        @groups = groups
+        @padding_mode = padding_mode
+        if transposed
+          @weight = Parameter.new(Tensor.new(in_channels, out_channels / groups, *kernel_size))
+        else
+          @weight = Parameter.new(Tensor.new(out_channels, in_channels / groups, *kernel_size))
+        end
+        if bias
+          @bias = Parameter.new(Tensor.new(out_channels))
+        else
+          raise NotImplementedError
+        end
+        reset_parameters
+      end
+
+      def reset_parameters
+        Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
+        if @bias
+          fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
+          bound = 1 / Math.sqrt(fan_in)
+          Init.uniform!(@bias, a: -bound, b: bound)
+        end
+      end
+    end
+  end
+end

data/lib/torch/nn/cosine_embedding_loss.rb ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CosineEmbeddingLoss < Loss
+      def initialize(margin: 0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input1, input2, target)
+        F.cosine_embedding_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end

data/lib/torch/nn/cross_entropy_loss.rb ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CrossEntropyLoss < WeightedLoss
+      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
+        super(weight, reduction)
+        @ignore_index = ignore_index
+      end
+
+      def forward(input, target)
+        F.cross_entropy(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
+      end
+    end
+  end
+end

data/lib/torch/nn/ctc_loss.rb ADDED
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class CTCLoss < Loss
+      def initialize(blank: 0, reduction: "mean", zero_infinity: false)
+        super(reduction)
+        @blank = blank
+        @zero_infinity = zero_infinity
+      end
+
+      def forward(log_probs, targets, input_lengths, target_lengths)
+        F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: @blank, reduction: @reduction, zero_infinity: @zero_infinity)
+      end
+    end
+  end
+end

data/lib/torch/nn/embedding.rb ADDED
@@ -0,0 +1,52 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
+module Torch
+  module NN
+    class Embedding < Module
+      def initialize(num_embeddings, embedding_dim, padding_idx: nil, max_norm: nil,
+                     norm_type: 2.0, scale_grad_by_freq: false, sparse: false, _weight: nil)
+
+        super()
+        @num_embeddings = num_embeddings
+        @embedding_dim = embedding_dim
+
+        if padding_idx
+          if padding_idx > 0
+            raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx < @num_embeddings
+          elsif padding_idx < 0
+            raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx >= -@num_embeddings
+            padding_idx = @num_embeddings + padding_idx
+          end
+        end
+        @padding_idx = padding_idx
+        @max_norm = max_norm
+        @norm_type = norm_type
+        @scale_grad_by_freq = scale_grad_by_freq
+        if _weight.nil?
+          @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
+          reset_parameters
+        else
+          raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
+          @weight = Parameter.new(_weight)
+        end
+        @sparse = sparse
+      end
+
+      def reset_parameters
+        Init.normal!(@weight)
+        if @padding_idx
+          Torch.no_grad do
+            @weight[@padding_idx].fill!(0)
+          end
+        end
+      end
+
+      def forward(input)
+        F.embedding(input, @weight, padding_idx: @padding_idx, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, sparse: @sparse)
+      end
+
+      def inspect
+        "Embedding(#{@num_embeddings}, #{@embedding_dim})"
+      end
+    end
+  end
+end

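Embedding stores a (num_embeddings, embedding_dim) weight matrix and looks up rows by integer index. A minimal sketch, under the same assumptions as the earlier examples about the installed gem (integer tensors serve as indices):

require "torch"

emb = Torch::NN::Embedding.new(10, 4)        # 10 embeddings, 4 dimensions each
idx = Torch.tensor([[1, 2, 4], [3, 0, 9]])   # batch of index sequences

vectors = emb.forward(idx)
p vectors.shape   # expected: [2, 3, 4]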