torch-rb 0.1.2 → 0.1.7
- checksums.yaml +4 -4
- data/CHANGELOG.md +35 -0
- data/LICENSE.txt +46 -22
- data/README.md +18 -6
- data/ext/torch/ext.cpp +148 -369
- data/ext/torch/extconf.rb +6 -0
- data/ext/torch/nn_functions.cpp +615 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +242 -0
- data/ext/torch/tensor_functions.cpp +1920 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2975 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +240 -131
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +27 -22
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +109 -0
- data/lib/torch/native/generator.rb +168 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +134 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +19 -0
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +16 -38
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +411 -22
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +201 -20
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +2 -2
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +198 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +56 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +48 -16
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +71 -30
- data/lib/torch/utils/data/data_loader.rb +10 -4
- data/lib/torch/utils/data/tensor_dataset.rb +3 -0
- data/lib/torch/version.rb +1 -1
- metadata +123 -6
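
The listing above gives only line counts. As a rough, hypothetical sketch of how the pieces added in this release fit together (Module subclassing, the Torch::NN::F functional namespace, the new loss modules, and the new optimizers), assuming the PyTorch-style API that torch-rb mirrors:

require "torch"

class Net < Torch::NN::Module
  def initialize
    super()
    @fc1 = Torch::NN::Linear.new(4, 8)
    @fc2 = Torch::NN::Linear.new(8, 3)
  end

  def forward(x)
    x = Torch::NN::F.relu(@fc1.call(x))
    @fc2.call(x)
  end
end

net = Net.new
optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)
criterion = Torch::NN::CrossEntropyLoss.new

x = Torch.randn(10, 4)                           # toy inputs
y = Torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0]) # toy class labels

optimizer.zero_grad
loss = criterion.call(net.call(x), y)
loss.backward
optimizer.step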
data/lib/torch/nn/conv1d.rb ADDED
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class Conv1d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _single(kernel_size)
+        stride = _single(stride)
+        padding = _single(padding)
+        dilation = _single(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _single(0), groups, bias, padding_mode)
+      end
+
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
+        end
+        F.conv1d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
+      end
+    end
+  end
+end
data/lib/torch/nn/conv2d.rb CHANGED
@@ -1,48 +1,26 @@
 module Torch
   module NN
-    class Conv2d <
-
-
-
-
-
-
-
-
-        # @dilation = pair(dilation)
-
-        # TODO divide by groups
-        @weight = Parameter.new(Tensor.new(out_channels, in_channels, *@kernel_size))
-        @bias = Parameter.new(Tensor.new(out_channels))
-
-        reset_parameters
+    class Conv2d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _pair(0), groups, bias, padding_mode)
       end

-      def
-
-
-        fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
-        bound = 1 / Math.sqrt(fan_in)
-        Init.uniform_(@bias, -bound, bound)
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
         end
+        F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
       end

-
-
-
-
-      def inspect
-        "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
-      end
-
-      private
-
-      def pair(value)
-        if value.is_a?(Array)
-          value
-        else
-          [value] * 2
-        end
+      # TODO add more parameters
+      def extra_inspect
+        format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
       end
     end
   end
data/lib/torch/nn/conv3d.rb ADDED
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class Conv3d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _triple(kernel_size)
+        stride = _triple(stride)
+        padding = _triple(padding)
+        dilation = _triple(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _triple(0), groups, bias, padding_mode)
+      end
+
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
+        end
+        F.conv3d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
+      end
+    end
+  end
+end
data/lib/torch/nn/convnd.rb ADDED
@@ -0,0 +1,41 @@
+module Torch
+  module NN
+    class ConvNd < Module
+      def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
+        super()
+        raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
+        raise ArgumentError, "out_channels must be divisible by groups" if out_channels % groups != 0
+        @in_channels = in_channels
+        @out_channels = out_channels
+        @kernel_size = kernel_size
+        @stride = stride
+        @padding = padding
+        @dilation = dilation
+        @transposed = transposed
+        @output_padding = output_padding
+        @groups = groups
+        @padding_mode = padding_mode
+        if transposed
+          @weight = Parameter.new(Tensor.new(in_channels, out_channels / groups, *kernel_size))
+        else
+          @weight = Parameter.new(Tensor.new(out_channels, in_channels / groups, *kernel_size))
+        end
+        if bias
+          @bias = Parameter.new(Tensor.new(out_channels))
+        else
+          raise NotImplementedError
+        end
+        reset_parameters
+      end
+
+      def reset_parameters
+        Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
+        if @bias
+          fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
+          bound = 1 / Math.sqrt(fan_in)
+          Init.uniform!(@bias, a: -bound, b: bound)
+        end
+      end
+    end
+  end
+end
data/lib/torch/nn/cosine_embedding_loss.rb ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CosineEmbeddingLoss < Loss
+      def initialize(margin: 0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input1, input2, target)
+        F.cosine_embedding_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/cross_entropy_loss.rb ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CrossEntropyLoss < WeightedLoss
+      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
+        super(weight, reduction)
+        @ignore_index = ignore_index
+      end
+
+      def forward(input, target)
+        F.cross_entropy(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/ctc_loss.rb ADDED
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class CTCLoss < Loss
+      def initialize(blank: 0, reduction: "mean", zero_infinity: false)
+        super(reduction)
+        @blank = blank
+        @zero_infinity = zero_infinity
+      end
+
+      def forward(log_probs, targets, input_lengths, target_lengths)
+        F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: @blank, reduction: @reduction, zero_infinity: @zero_infinity)
+      end
+    end
+  end
+end
data/lib/torch/nn/embedding.rb ADDED
@@ -0,0 +1,52 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
+module Torch
+  module NN
+    class Embedding < Module
+      def initialize(num_embeddings, embedding_dim, padding_idx: nil, max_norm: nil,
+        norm_type: 2.0, scale_grad_by_freq: false, sparse: false, _weight: nil)
+
+        super()
+        @num_embeddings = num_embeddings
+        @embedding_dim = embedding_dim
+
+        if padding_idx
+          if padding_idx > 0
+            raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx < @num_embeddings
+          elsif padding_idx < 0
+            raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx >= -@num_embeddings
+            padding_idx = @num_embeddings + padding_idx
+          end
+        end
+        @padding_idx = padding_idx
+        @max_norm = max_norm
+        @norm_type = norm_type
+        @scale_grad_by_freq = scale_grad_by_freq
+        if _weight.nil?
+          @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
+          reset_parameters
+        else
+          raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
+          @weight = Parameter.new(_weight)
+        end
+        @sparse = sparse
+      end
+
+      def reset_parameters
+        Init.normal!(@weight)
+        if @padding_idx
+          Torch.no_grad do
+            @weight[@padding_idx].fill!(0)
+          end
+        end
+      end
+
+      def forward(input)
+        F.embedding(input, @weight, padding_idx: @padding_idx, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, sparse: @sparse)
+      end
+
+      def inspect
+        "Embedding(#{@num_embeddings}, #{@embedding_dim})"
+      end
+    end
+  end
+end
data/lib/torch/nn/embedding_bag.rb ADDED
@@ -0,0 +1,34 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
+module Torch
+  module NN
+    class EmbeddingBag < Module
+      def initialize(num_embeddings, embedding_dim, max_norm: nil, norm_type: 2.0,
+        scale_grad_by_freq: false, mode: "mean", sparse: false, _weight: nil)
+
+        super()
+        @num_embeddings = num_embeddings
+        @embedding_dim = embedding_dim
+        @max_norm = max_norm
+        @norm_type = norm_type
+        @scale_grad_by_freq = scale_grad_by_freq
+        if _weight.nil?
+          @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
+          reset_parameters
+        else
+          raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
+          @weight = Parameter.new(_weight)
+        end
+        @mode = mode
+        @sparse = sparse
+      end
+
+      def reset_parameters
+        Init.normal!(@weight)
+      end
+
+      def forward(input, offsets: nil, per_sample_weights: nil)
+        F.embedding_bag(input, @weight, offsets: offsets, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, mode: @mode, sparse: @sparse, per_sample_weights: per_sample_weights)
+      end
+    end
+  end
+end
data/lib/torch/nn/fold.rb ADDED
@@ -0,0 +1,20 @@
+module Torch
+  module NN
+    class Fold < Module
+      def initialize(output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+        super()
+        @output_size = output_size
+        @kernel_size = kernel_size
+        @dilation = dilation
+        @padding = padding
+        @stride = stride
+      end
+
+      def forward(input)
+        F.fold(input, @output_size, @kernel_size, dilation: @dilation, padding: @padding, stride: @stride)
+      end
+
+      # TODO add extra_inspect
+    end
+  end
+end
data/lib/torch/nn/functional.rb CHANGED
@@ -2,52 +2,441 @@ module Torch
   module NN
     class Functional
       class << self
-
-
+        include Utils
+
+        # convolution layers
+
+        def conv1d(*args, **options)
+          Torch.conv1d(*args, **options)
+        end
+
+        def conv2d(*args, **options)
+          Torch.conv2d(*args, **options)
+        end
+
+        def conv3d(*args, **options)
+          Torch.conv3d(*args, **options)
+        end
+
+        def unfold(input, kernel_size, dilation: 1, padding: 0, stride: 1)
+          if input.dim == 4
+            NN.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+          else
+            raise Error, "Input Error: Only 4D input Tensors are supported (got #{input.dim}D)"
+          end
+        end
+
+        def fold(input, output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+          if input.dim == 3
+            NN.col2im(input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+          else
+            raise Error, "Input Error: Only 3D input Tensors are supported (got #{input.dim}D)"
+          end
+        end
+
+        # pooling layers
+
+        def max_pool1d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            Torch.max_pool1d_with_indices(*args, **options)
+          else
+            Torch.max_pool1d(*args, **options)
+          end
+        end
+
+        def max_pool2d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            NN.max_pool2d_with_indices(*args, **options)
+          else
+            Torch.max_pool2d(*args, **options)
+          end
+        end
+
+        def max_pool3d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            NN.max_pool3d_with_indices(*args, **options)
+          else
+            Torch.max_pool3d(*args, **options)
+          end
+        end
+
+        def max_unpool1d(input, indices, kernel_size, stride: nil, padding: 0, output_size: nil)
+          raise NotImplementedYet
+          kernel_size = _single(kernel_size)
+          if !stride.nil?
+            _stride = _single(stride)
+          else
+            _stride = kernel_size
+          end
+          padding = _single(padding)
+          output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
+          output_size = output_size + [1]
+          NN.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), output_size).squeeze(3)
+        end
+
+        def max_unpool2d(*args, **options)
+          raise NotImplementedYet
+          NN.max_unpool2d(*args, **options)
+        end
+
+        def max_unpool3d(*args, **options)
+          raise NotImplementedYet
+          NN.max_unpool3d(*args, **options)
+        end
+
+        def avg_pool1d(*args, **options)
+          Torch.avg_pool1d(*args, **options)
+        end
+
+        def avg_pool2d(*args, **options)
+          NN.avg_pool2d(*args, **options)
+        end
+
+        def avg_pool3d(*args, **options)
+          NN.avg_pool3d(*args, **options)
+        end
+
+        # padding layers
+
+        def pad(input, pad, mode: "constant", value: 0)
+          raise ArgumentError, "Padding length must be divisible by 2" unless pad.size % 2 == 0
+          raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
+
+          if mode == "constant"
+            return Torch.constant_pad_nd(input, pad, value)
+          else
+            raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0
+
+            if input.dim == 3
+              raise ArgumentError, "3D tensors expect 2 values for padding" unless pad.size == 2
+              case mode
+              when "reflect"
+                NN.reflection_pad1d(input, pad)
+              when "replicate"
+                NN.replication_pad1d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            elsif input.dim == 4
+              raise ArgumentError, "4D tensors expect 4 values for padding" unless pad.size == 4
+              case mode
+              when "reflect"
+                NN.reflection_pad2d(input, pad)
+              when "replicate"
+                NN.replication_pad2d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            elsif input.dim == 5
+              raise ArgumentError, "5D tensors expect 6 values for padding" unless pad.size == 6
+              case mode
+              when "replicate"
+                NN.replication_pad3d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            else
+              raise ArgumentError, "Only 3D, 4D, 5D padding with non-constant padding are supported for now"
+            end
+          end
+        end
+
+        # activation layers
+
+        def hardshrink(input, lambd = 0.5)
+          Torch.hardshrink(input, lambd)
+        end
+
+        def leaky_relu(input, negative_slope = 0.01)
+          NN.leaky_relu(input, negative_slope)
         end

-        def
-
-          Torch.conv2d(input, weight, bias, stride, padding)
+        def log_sigmoid(input)
+          NN.log_sigmoid(input)
         end

         def prelu(input, weight)
           Torch.prelu(input, weight)
         end

-        def
-
+        def relu(input, inplace: false)
+          if inplace
+            input.relu!
+          else
+            input.relu
+          end
+        end
+
+        def softplus(input, beta: 1, threshold: 20)
+          NN.softplus(input, beta, threshold)
+        end
+
+        def softshrink(*args, **options)
+          NN.softshrink(*args, **options)
+        end
+
+        def softsign(input)
+          input / (input.abs + 1)
+        end
+
+        def tanhshrink(input)
+          input - input.tanh
+        end
+
+        # other activation layers
+
+        def softmin(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          (-input).softmax(dim)
+        end
+
+        def softmax(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          input.softmax(dim)
+        end
+
+        # TODO make dim keyword argument and update examples
+        def log_softmax(input, dim = nil)
+          dim ||= softmax_dim(input.dim)
+          input.log_softmax(dim)
+        end
+
+        # normalization layers
+
+        def batch_norm(input, running_mean, running_var, weight: nil, bias: nil,
+          training: false, momentum: 0.1, eps: 1e-5)
+
+          if training
+            size = input.size
+            size_prods = size[0]
+            (size.length - 2).times do |i|
+              size_prods *= size[i + 2]
+            end
+            if size_prods == 1
+              raise ArgumentError, "Expected more than 1 value per channel when training, got input size #{size.inspect}"
+            end
+          end
+
+          Torch.batch_norm(
+            input, weight, bias, running_mean, running_var,
+            training, momentum, eps, false
+          )
+        end
+
+        def group_norm(input, num_groups, weight: nil, bias: nil, eps: 1e-5)
+          Torch.group_norm(input, num_groups, weight, bias, eps, false)
         end

-        def
-
-
+        def instance_norm(input, running_mean: nil, running_var: nil, weight: nil,
+          bias: nil, use_input_stats: true, momentum: 0.1, eps: 1e-5)
+
+          Torch.instance_norm(
+            input, weight, bias, running_mean, running_var,
+            use_input_stats, momentum, eps, false
+          )
+        end
+
+        def layer_norm(input, normalized_shape, weight: nil, bias: nil, eps: 1e-5)
+          Torch.layer_norm(input, normalized_shape, weight, bias, eps, false)
         end

-        def
-
-
+        def local_response_norm(input, size, alpha: 1e-4, beta: 0.75, k: 1.0)
+          dim = input.dim
+          if dim < 3
+            raise ArgumentError, "Expected 3D or higher dimensionality input (got #{dim} dimensions)"
+          end
+          div = input.mul(input).unsqueeze(1)
+          if dim == 3
+            div = pad(div, [0, 0, size / 2, (size - 1) / 2])
+            div = avg_pool2d(div, [size, 1], stride: 1).squeeze(1)
+          else
+            sizes = input.size
+            div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
+            div = pad(div, [0, 0, 0, 0, size / 2, (size - 1) / 2])
+            div = avg_pool3d(div, [size, 1, 1], stride: 1).squeeze(1)
+            div = div.view(sizes)
+          end
+          div = div.mul(alpha).add(k).pow(beta)
+          input / div
         end

+        # linear layers
+
         def linear(input, weight, bias)
-
+          NN.linear(input, weight, bias)
+        end
+
+        def bilinear(input1, input2, weight, bias)
+          Torch.bilinear(input1, input2, weight, bias)
+        end
+
+        # dropout layers
+
+        def dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.dropout!(input, p, training)
+          else
+            Torch.dropout(input, p, training)
+          end
+        end
+
+        def dropout2d(input, p: 0.5, training: true, inplace: false)
+          raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
+
+          if inplace
+            Torch.feature_dropout!(input, p, training)
+          else
+            Torch.feature_dropout(input, p, training)
+          end
+        end
+
+        def dropout3d(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.feature_dropout!(input, p, training)
+          else
+            Torch.feature_dropout(input, p, training)
+          end
+        end
+
+        def alpha_dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.alpha_dropout!(input, p, training)
+          else
+            Torch.alpha_dropout(input, p, training)
+          end
+        end
+
+        def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.feature_alpha_dropout!(input, p, training)
+          else
+            Torch.feature_alpha_dropout(input, p, training)
+          end
+        end
+
+        # sparse layers
+
+        def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false)
+          # TODO handle max_norm and norm_type
+          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+          padding_idx ||= -1
+          # weight and indices are swapped from Python interface
+          Torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
+        end
+
+        def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
+          # TODO handle max_norm and norm_type
+          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+          mode_enum =
+            case mode
+            when "sum"
+              0
+            when "mean"
+              1
+            when "max"
+              2
+            else
+              raise ArgumentError, "Unknown mode: #{mode}"
+            end
+
+          # weight and input swapped
+          Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+        end
+
+        # distance functions
+
+        def cosine_similarity(x1, x2, dim: 1, eps: 1e-8)
+          Torch.cosine_similarity(x1, x2, dim, eps)
+        end
+
+        def pairwise_distance(x1, x2, p: 2.0, eps: 1e-6, keepdim: false)
+          Torch.pairwise_distance(x1, x2, p, eps, keepdim)
+        end
+
+        # loss functions
+
+        def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
+          NN.binary_cross_entropy(input, target, weight, reduction)
+        end
+
+        def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
+          Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+        end
+
+        def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
+        end
+
+        def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          nll_loss(log_softmax(input, 1), target, weight: weight, ignore_index: ignore_index, reduction: reduction)
+        end
+
+        def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
+          # call to_a on input_lengths and target_lengths for C++
+          Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+        end
+
+        def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
+          Torch.hinge_embedding_loss(input, target, margin, reduction)
+        end
+
+        def kl_div(input, target, reduction: "mean")
+          Torch.kl_div(input, target, reduction)
+        end
+
+        def l1_loss(input, target, reduction: "mean")
+          NN.l1_loss(input, target, reduction)
+        end
+
+        def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
         end

         def mse_loss(input, target, reduction: "mean")
-
+          NN.mse_loss(input, target, reduction)
         end

-        def
-
+        def multilabel_margin_loss(input, target, reduction: "mean")
+          NN.multilabel_margin_loss(input, target, reduction)
         end

-        def
-
-          Torch.nll_loss(input, target)
+        def multilabel_soft_margin_loss(input, target, weight: nil)
+          raise NotImplementedYet
         end

-        def
-
+        def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
+          NN.multi_margin_loss(input, target, p, margin, weight, reduction)
+        end
+
+        def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          NN.nll_loss(input, target, weight, reduction, ignore_index)
+        end
+
+        def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
+          Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
+        end
+
+        def soft_margin_loss(input, target, reduction: "mean")
+          NN.soft_margin_loss(input, target, reduction)
+        end
+
+        def smooth_l1_loss(input, target, reduction: "mean")
+          NN.smooth_l1_loss(input, target, reduction)
+        end
+
+        def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
+          Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+        end
+
+        private
+
+        def softmax_dim(ndim)
+          ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
         end
       end
     end