torch-rb 0.1.3 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +30 -0
- data/README.md +5 -2
- data/ext/torch/ext.cpp +130 -555
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +244 -0
- data/lib/torch.rb +209 -171
- data/lib/torch/inspector.rb +23 -19
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +110 -0
- data/lib/torch/native/generator.rb +168 -0
- data/lib/torch/native/native_functions.yaml +6491 -0
- data/lib/torch/native/parser.rb +134 -0
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +19 -0
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +10 -20
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/convnd.rb +3 -3
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropoutnd.rb +2 -2
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +379 -32
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +186 -35
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +2 -2
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +198 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +51 -44
- data/lib/torch/version.rb +1 -1
- metadata +98 -6
- data/lib/torch/ext.bundle +0 -0
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -1,36 +1,26 @@
 module Torch
   module NN
     class Conv2d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
 
-        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, pair(0), groups, bias, padding_mode)
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _pair(0), groups, bias, padding_mode)
       end
 
       def forward(input)
         if @padding_mode == "circular"
           raise NotImplementedError
         end
-        F.conv2d(input, @weight, @bias,
+        F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
       end
 
       # TODO add more parameters
-      private
-
-      def pair(value)
-        if value.is_a?(Array)
-          value
-        else
-          [value] * 2
-        end
+      def extra_inspect
+        format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
       end
     end
   end
data/lib/torch/nn/conv3d.rb
ADDED
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class Conv3d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _triple(kernel_size)
+        stride = _triple(stride)
+        padding = _triple(padding)
+        dilation = _triple(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _triple(0), groups, bias, padding_mode)
+      end
+
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
+        end
+        F.conv3d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
+      end
+    end
+  end
+end
data/lib/torch/nn/convnd.rb
CHANGED
@@ -29,11 +29,11 @@ module Torch
       end
 
       def reset_parameters
-        Init.kaiming_uniform!(@weight, Math.sqrt(5))
+        Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
         if @bias
-          fan_in, _ = Init.
+          fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
           bound = 1 / Math.sqrt(fan_in)
-          Init.uniform!(@bias, -bound, bound)
+          Init.uniform!(@bias, a: -bound, b: bound)
         end
       end
     end
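The init helpers now take keyword arguments. A rough sketch of the same fan-in computation outside ConvNd; the weight tensor here is only illustrative:

  w = Torch.randn(16, 3, 3, 3)
  Torch::NN::Init.kaiming_uniform!(w, a: Math.sqrt(5))
  fan_in, _fan_out = Torch::NN::Init._calculate_fan_in_and_fan_out(w)
  bound = 1 / Math.sqrt(fan_in)
  Torch::NN::Init.uniform!(w, a: -bound, b: bound)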
data/lib/torch/nn/cosine_embedding_loss.rb
ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CosineEmbeddingLoss < Loss
+      def initialize(margin: 0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input1, input2, target)
+        F.cosine_embedding_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/cross_entropy_loss.rb
ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CrossEntropyLoss < WeightedLoss
+      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
+        super(weight, reduction)
+        @ignore_index = ignore_index
+      end
+
+      def forward(input, target)
+        F.cross_entropy(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/ctc_loss.rb
ADDED
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class CTCLoss < Loss
+      def initialize(blank: 0, reduction: "mean", zero_infinity: false)
+        super(reduction)
+        @blank = blank
+        @zero_infinity = zero_infinity
+      end
+
+      def forward(log_probs, targets, input_lengths, target_lengths)
+        F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: @blank, reduction: @reduction, zero_infinity: @zero_infinity)
+      end
+    end
+  end
+end
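A short, hedged example of the new loss modules; the shapes follow the usual PyTorch conventions and are illustrative, not taken from this diff:

  loss_fn = Torch::NN::CrossEntropyLoss.new
  input  = Torch.randn(3, 5)          # unnormalized scores for 3 samples over 5 classes
  target = Torch.tensor([1, 0, 4])    # class indices
  loss = loss_fn.call(input, target)  # log_softmax + nll_loss under the hood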
data/lib/torch/nn/dropoutnd.rb
CHANGED
data/lib/torch/nn/embedding_bag.rb
ADDED
@@ -0,0 +1,34 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
+module Torch
+  module NN
+    class EmbeddingBag < Module
+      def initialize(num_embeddings, embedding_dim, max_norm: nil, norm_type: 2.0,
+        scale_grad_by_freq: false, mode: "mean", sparse: false, _weight: nil)
+
+        super()
+        @num_embeddings = num_embeddings
+        @embedding_dim = embedding_dim
+        @max_norm = max_norm
+        @norm_type = norm_type
+        @scale_grad_by_freq = scale_grad_by_freq
+        if _weight.nil?
+          @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
+          reset_parameters
+        else
+          raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
+          @weight = Parameter.new(_weight)
+        end
+        @mode = mode
+        @sparse = sparse
+      end
+
+      def reset_parameters
+        Init.normal!(@weight)
+      end
+
+      def forward(input, offsets: nil, per_sample_weights: nil)
+        F.embedding_bag(input, @weight, offsets: offsets, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, mode: @mode, sparse: @sparse, per_sample_weights: per_sample_weights)
+      end
+    end
+  end
+end
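A sketch of how the new EmbeddingBag module might be called, mirroring the forward signature above; values and shapes are illustrative, and `forward` is used directly to keep the keyword arguments explicit:

  bag = Torch::NN::EmbeddingBag.new(10, 3, mode: "sum")
  input   = Torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
  offsets = Torch.tensor([0, 4])              # two bags: indices 0...4 and 4..-1
  out = bag.forward(input, offsets: offsets)  # => tensor of shape [2, 3]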
data/lib/torch/nn/fold.rb
ADDED
@@ -0,0 +1,20 @@
+module Torch
+  module NN
+    class Fold < Module
+      def initialize(output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+        super()
+        @output_size = output_size
+        @kernel_size = kernel_size
+        @dilation = dilation
+        @padding = padding
+        @stride = stride
+      end
+
+      def forward(input)
+        F.fold(input, @output_size, @kernel_size, dilation: @dilation, padding: @padding, stride: @stride)
+      end
+
+      # TODO add extra_inspect
+    end
+  end
+end
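A rough pairing of Fold with the F.unfold/F.fold helpers defined in functional.rb below; the shapes are assumptions, and since overlapping patches are summed on the way back this is not an exact inverse:

  x = Torch.randn(1, 3, 10, 12)
  patches = Torch::NN::Functional.unfold(x, [2, 2])   # im2col: [1, 3 * 2 * 2, num_patches]
  fold = Torch::NN::Fold.new([10, 12], [2, 2])
  y = fold.call(patches)                              # col2im back to shape [1, 3, 10, 12]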
data/lib/torch/nn/functional.rb
CHANGED
@@ -2,51 +2,284 @@ module Torch
   module NN
     class Functional
       class << self
+        include Utils
+
+        # convolution layers
+
+        def conv1d(*args, **options)
+          Torch.conv1d(*args, **options)
         end
 
-        def conv2d(
-          Torch.conv2d(input, weight, bias, stride, padding, dilation, groups)
+        def conv2d(*args, **options)
+          Torch.conv2d(*args, **options)
         end
 
-          Torch.max_pool2d(input, kernel_size)
+        def conv3d(*args, **options)
+          Torch.conv3d(*args, **options)
         end
 
+        def unfold(input, kernel_size, dilation: 1, padding: 0, stride: 1)
+          if input.dim == 4
+            NN.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+          else
+            raise Error, "Input Error: Only 4D input Tensors are supported (got #{input.dim}D)"
+          end
         end
 
+        def fold(input, output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+          if input.dim == 3
+            NN.col2im(input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+          else
+            raise Error, "Input Error: Only 3D input Tensors are supported (got #{input.dim}D)"
+          end
         end
 
+        # pooling layers
+
+        def max_pool1d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            Torch.max_pool1d_with_indices(*args, **options)
+          else
+            Torch.max_pool1d(*args, **options)
+          end
+        end
+
+        def max_pool2d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            NN.max_pool2d_with_indices(*args, **options)
+          else
+            Torch.max_pool2d(*args, **options)
+          end
+        end
+
+        def max_pool3d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            NN.max_pool3d_with_indices(*args, **options)
+          else
+            Torch.max_pool3d(*args, **options)
+          end
+        end
+
+        def max_unpool1d(input, indices, kernel_size, stride: nil, padding: 0, output_size: nil)
+          raise NotImplementedYet
+          kernel_size = _single(kernel_size)
+          if !stride.nil?
+            _stride = _single(stride)
+          else
+            _stride = kernel_size
+          end
+          padding = _single(padding)
+          output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
+          output_size = output_size + [1]
+          NN.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), output_size).squeeze(3)
+        end
+
+        def max_unpool2d(*args, **options)
+          raise NotImplementedYet
+          NN.max_unpool2d(*args, **options)
+        end
+
+        def max_unpool3d(*args, **options)
+          raise NotImplementedYet
+          NN.max_unpool3d(*args, **options)
         end
 
+        def avg_pool1d(*args, **options)
+          Torch.avg_pool1d(*args, **options)
         end
 
+        def avg_pool2d(*args, **options)
+          NN.avg_pool2d(*args, **options)
+        end
+
+        def avg_pool3d(*args, **options)
+          NN.avg_pool3d(*args, **options)
+        end
+
+        # padding layers
+
+        def pad(input, pad, mode: "constant", value: 0)
+          raise ArgumentError, "Padding length must be divisible by 2" unless pad.size % 2 == 0
+          raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
+
+          if mode == "constant"
+            return Torch.constant_pad_nd(input, pad, value)
+          else
+            raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0
+
+            if input.dim == 3
+              raise ArgumentError, "3D tensors expect 2 values for padding" unless pad.size == 2
+              case mode
+              when "reflect"
+                NN.reflection_pad1d(input, pad)
+              when "replicate"
+                NN.replication_pad1d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            elsif input.dim == 4
+              raise ArgumentError, "4D tensors expect 4 values for padding" unless pad.size == 4
+              case mode
+              when "reflect"
+                NN.reflection_pad2d(input, pad)
+              when "replicate"
+                NN.replication_pad2d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            elsif input.dim == 5
+              raise ArgumentError, "5D tensors expect 6 values for padding" unless pad.size == 6
+              case mode
+              when "replicate"
+                NN.replication_pad3d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            else
+              raise ArgumentError, "Only 3D, 4D, 5D padding with non-constant padding are supported for now"
+            end
+          end
         end
 
+        # activation layers
+
+        def hardshrink(input, lambd = 0.5)
+          Torch.hardshrink(input, lambd)
+        end
+
+        def leaky_relu(input, negative_slope = 0.01)
+          NN.leaky_relu(input, negative_slope)
+        end
+
+        def log_sigmoid(input)
+          NN.log_sigmoid(input)
+        end
+
+        def prelu(input, weight)
+          Torch.prelu(input, weight)
+        end
+
+        def relu(input, inplace: false)
+          if inplace
+            input.relu!
+          else
+            input.relu
+          end
+        end
+
+        def softplus(input, beta: 1, threshold: 20)
+          NN.softplus(input, beta, threshold)
+        end
+
+        def softshrink(*args, **options)
+          NN.softshrink(*args, **options)
+        end
+
+        def softsign(input)
+          input / (input.abs + 1)
+        end
+
+        def tanhshrink(input)
+          input - input.tanh
+        end
+
+        # other activation layers
+
+        def softmin(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          (-input).softmax(dim)
+        end
+
+        def softmax(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          input.softmax(dim)
+        end
+
+        # TODO make dim keyword argument and update examples
+        def log_softmax(input, dim = nil)
+          dim ||= softmax_dim(input.dim)
           input.log_softmax(dim)
         end
 
+        # normalization layers
+
+        def batch_norm(input, running_mean, running_var, weight: nil, bias: nil,
+          training: false, momentum: 0.1, eps: 1e-5)
+
+          if training
+            size = input.size
+            size_prods = size[0]
+            (size.length - 2).times do |i|
+              size_prods *= size[i + 2]
+            end
+            if size_prods == 1
+              raise ArgumentError, "Expected more than 1 value per channel when training, got input size #{size.inspect}"
+            end
+          end
+
+          Torch.batch_norm(
+            input, weight, bias, running_mean, running_var,
+            training, momentum, eps, false
+          )
+        end
+
+        def group_norm(input, num_groups, weight: nil, bias: nil, eps: 1e-5)
+          Torch.group_norm(input, num_groups, weight, bias, eps, false)
+        end
+
+        def instance_norm(input, running_mean: nil, running_var: nil, weight: nil,
+          bias: nil, use_input_stats: true, momentum: 0.1, eps: 1e-5)
+
+          Torch.instance_norm(
+            input, weight, bias, running_mean, running_var,
+            use_input_stats, momentum, eps, false
+          )
+        end
+
+        def layer_norm(input, normalized_shape, weight: nil, bias: nil, eps: 1e-5)
+          Torch.layer_norm(input, normalized_shape, weight, bias, eps, false)
+        end
+
+        def local_response_norm(input, size, alpha: 1e-4, beta: 0.75, k: 1.0)
+          dim = input.dim
+          if dim < 3
+            raise ArgumentError, "Expected 3D or higher dimensionality input (got #{dim} dimensions)"
+          end
+          div = input.mul(input).unsqueeze(1)
+          if dim == 3
+            div = pad(div, [0, 0, size / 2, (size - 1) / 2])
+            div = avg_pool2d(div, [size, 1], stride: 1).squeeze(1)
+          else
+            sizes = input.size
+            div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
+            div = pad(div, [0, 0, 0, 0, size / 2, (size - 1) / 2])
+            div = avg_pool3d(div, [size, 1, 1], stride: 1).squeeze(1)
+            div = div.view(sizes)
+          end
+          div = div.mul(alpha).add(k).pow(beta)
+          input / div
+        end
+
+        # linear layers
+
+        def linear(input, weight, bias)
+          NN.linear(input, weight, bias)
+        end
+
+        def bilinear(input1, input2, weight, bias)
+          Torch.bilinear(input1, input2, weight, bias)
+        end
+
+        # dropout layers
+
         def dropout(input, p: 0.5, training: true, inplace: false)
           if inplace
+            Torch.dropout!(input, p, training)
           else
+            Torch.dropout(input, p, training)
           end
         end
 
@@ -54,42 +287,156 @@ module Torch
           raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
 
           if inplace
+            Torch.feature_dropout!(input, p, training)
           else
+            Torch.feature_dropout(input, p, training)
           end
         end
 
         def dropout3d(input, p: 0.5, training: true, inplace: false)
           if inplace
+            Torch.feature_dropout!(input, p, training)
           else
+            Torch.feature_dropout(input, p, training)
           end
         end
 
         def alpha_dropout(input, p: 0.5, training: true, inplace: false)
           if inplace
+            Torch.alpha_dropout!(input, p, training)
           else
+            Torch.alpha_dropout(input, p, training)
           end
         end
 
         def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
           if inplace
+            Torch.feature_alpha_dropout!(input, p, training)
           else
+            Torch.feature_alpha_dropout(input, p, training)
           end
         end
 
+        # sparse layers
+
         def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false)
           # TODO handle max_norm and norm_type
           raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
 
           padding_idx ||= -1
+          # weight and indices are swapped from Python interface
+          Torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
+        end
+
+        def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
+          # TODO handle max_norm and norm_type
+          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+          mode_enum =
+            case mode
+            when "sum"
+              0
+            when "mean"
+              1
+            when "max"
+              2
+            else
+              raise ArgumentError, "Unknown mode: #{mode}"
+            end
+
+          # weight and input swapped
+          Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+        end
+
+        # distance functions
+
+        def cosine_similarity(x1, x2, dim: 1, eps: 1e-8)
+          Torch.cosine_similarity(x1, x2, dim, eps)
+        end
+
+        def pairwise_distance(x1, x2, p: 2.0, eps: 1e-6, keepdim: false)
+          Torch.pairwise_distance(x1, x2, p, eps, keepdim)
+        end
+
+        # loss functions
+
+        def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
+          NN.binary_cross_entropy(input, target, weight, reduction)
+        end
+
+        def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
+          Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+        end
+
+        def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
+        end
+
+        def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          nll_loss(log_softmax(input, 1), target, weight: weight, ignore_index: ignore_index, reduction: reduction)
+        end
+
+        def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
+          # call to_a on input_lengths and target_lengths for C++
+          Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+        end
+
+        def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
+          Torch.hinge_embedding_loss(input, target, margin, reduction)
+        end
+
+        def kl_div(input, target, reduction: "mean")
+          Torch.kl_div(input, target, reduction)
+        end
+
+        def l1_loss(input, target, reduction: "mean")
+          NN.l1_loss(input, target, reduction)
+        end
+
+        def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
+        end
+
+        def mse_loss(input, target, reduction: "mean")
+          NN.mse_loss(input, target, reduction)
+        end
+
+        def multilabel_margin_loss(input, target, reduction: "mean")
+          NN.multilabel_margin_loss(input, target, reduction)
+        end
+
+        def multilabel_soft_margin_loss(input, target, weight: nil)
+          raise NotImplementedYet
+        end
+
+        def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
+          NN.multi_margin_loss(input, target, p, margin, weight, reduction)
+        end
+
+        def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          NN.nll_loss(input, target, weight, reduction, ignore_index)
+        end
+
+        def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
+          Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
+        end
+
+        def soft_margin_loss(input, target, reduction: "mean")
+          NN.soft_margin_loss(input, target, reduction)
+        end
+
+        def smooth_l1_loss(input, target, reduction: "mean")
+          NN.smooth_l1_loss(input, target, reduction)
+        end
+
+        def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
+          Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+        end
+
+        private
+
+        def softmax_dim(ndim)
+          ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
         end
       end
     end
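To illustrate the expanded functional API above, a few hedged calls; the input shapes are assumptions, and availability of each op depends on the compiled extension:

  F = Torch::NN::Functional
  x = Torch.randn(2, 3, 8, 8)
  F.pad(x, [1, 1, 1, 1], mode: "reflect")    # reflection padding on a 4D tensor
  F.softmax(Torch.randn(2, 5))               # dim inferred via the private softmax_dim helper
  F.dropout(x, p: 0.2, training: true)
  F.mse_loss(Torch.randn(4), Torch.randn(4))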