torch-rb 0.1.5 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +1 -1
- data/ext/torch/ext.cpp +0 -170
- data/ext/torch/nn_functions.cpp +44 -24
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +48 -0
- data/ext/torch/tensor_functions.cpp +76 -16
- data/ext/torch/torch_functions.cpp +165 -65
- data/lib/torch.rb +51 -42
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/native/dispatcher.rb +1 -1
- data/lib/torch/native/function.rb +36 -5
- data/lib/torch/native/generator.rb +26 -7
- data/lib/torch/native/parser.rb +51 -14
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +7 -2
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +9 -17
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +320 -100
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +1 -1
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +6 -6
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +7 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn_base.rb +48 -4
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/tensor.rb +14 -25
- data/lib/torch/version.rb +1 -1
- metadata +50 -2
data/lib/torch/nn/batch_norm.rb
ADDED
@@ -0,0 +1,75 @@
+module Torch
+  module NN
+    class BatchNorm < Module
+      def initialize(num_features, eps: 1e-5, momentum: 0.1, affine: true, track_running_stats: true)
+        super()
+        @num_features = num_features
+        @eps = eps
+        @momentum = momentum
+        @affine = affine
+        @track_running_stats = track_running_stats
+        if @affine
+          @weight = Parameter.new(Torch::Tensor.new(num_features))
+          @bias = Parameter.new(Torch::Tensor.new(num_features))
+        else
+          register_parameter("weight", nil)
+          register_parameter("bias", nil)
+        end
+        if track_running_stats
+          register_buffer("running_mean", Torch.zeros(num_features))
+          register_buffer("running_var", Torch.ones(num_features))
+          register_buffer("num_batches_tracked", Torch.tensor(0, dtype: :long))
+        else
+          register_parameter("running_mean", nil)
+          register_parameter("running_var", nil)
+          register_parameter("num_batches_tracked", nil)
+        end
+        reset_parameters
+      end
+
+      def reset_running_stats
+        if @track_running_stats
+          @running_mean.zero!
+          @running_var.fill!(1)
+          @num_batches_tracked.zero!
+        end
+      end
+
+      def reset_parameters
+        reset_running_stats
+        if @affine
+          Init.ones!(@weight)
+          Init.zeros!(@bias)
+        end
+      end
+
+      def forward(input)
+        _check_input_dim(input)
+
+        if @momentum.nil?
+          exponential_average_factor = 0.0
+        else
+          exponential_average_factor = @momentum
+        end
+
+        if @training and @track_running_stats
+          if @num_batches_tracked.nil?
+            @num_batches_tracked += 1
+            if @momentum.nil?
+              exponential_average_factor = 1.0 / @num_batches_tracked.to_f
+            else
+              exponential_average_factor = @momentum
+            end
+          end
+        end
+
+        F.batch_norm(
+          input, @running_mean, @running_var,
+          weight: @weight, bias: @bias,
+          training: @training || !@track_running_stats,
+          momentum: exponential_average_factor, eps: @eps
+        )
+      end
+    end
+  end
+end
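For orientation, a minimal usage sketch of the new batch-norm layers. It assumes the 0.1.6 API shown above plus Torch.randn and the Module#call convention, which are not part of this diff:

  require "torch"

  bn = Torch::NN::BatchNorm1d.new(100)   # the BatchNorm1d/2d/3d subclasses add the input-dimension check on top of BatchNorm
  x = Torch.randn([20, 100])
  out = bn.call(x)                       # normalizes per feature and updates running_mean/running_var in training mode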
data/lib/torch/nn/constant_padnd.rb
ADDED
@@ -0,0 +1,18 @@
+module Torch
+  module NN
+    class ConstantPadNd < Module
+      def initialize(value)
+        super()
+        @value = value
+      end
+
+      def forward(input)
+        F.pad(input, @padding, mode: "constant", value: @value)
+      end
+
+      def extra_inspect
+        format("padding: %s, value: %s", @padding, @value)
+      end
+    end
+  end
+end
data/lib/torch/nn/conv1d.rb
ADDED
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class Conv1d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _single(kernel_size)
+        stride = _single(stride)
+        padding = _single(padding)
+        dilation = _single(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _single(0), groups, bias, padding_mode)
+      end
+
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
+        end
+        F.conv1d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
+      end
+    end
+  end
+end
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -1,35 +1,27 @@
 module Torch
   module NN
     class Conv2d < ConvNd
-      def initialize(in_channels, out_channels, kernel_size, stride: 1,
-
-
-
-
-
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _pair(0), groups, bias, padding_mode)
       end
 
       def forward(input)
         if @padding_mode == "circular"
           raise NotImplementedError
         end
-        F.conv2d(input, @weight, @bias,
+        F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
       end
 
       # TODO add more parameters
       def extra_inspect
         format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
       end
-
-      private
-
-      def pair(value)
-        if value.is_a?(Array)
-          value
-        else
-          [value] * 2
-        end
-      end
     end
   end
 end
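A small sketch of what the reworked constructor enables (assuming the 0.1.6 API above plus Torch.randn and Module#call, neither of which is shown in this diff):

  conv = Torch::NN::Conv2d.new(3, 16, 3, stride: 2, padding: 1)
  x = Torch.randn([1, 3, 32, 32])
  y = conv.call(x)   # stride, padding, dilation, and groups are now forwarded to F.conv2d; y should be [1, 16, 16, 16]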
data/lib/torch/nn/conv3d.rb
ADDED
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class Conv3d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _triple(kernel_size)
+        stride = _triple(stride)
+        padding = _triple(padding)
+        dilation = _triple(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _triple(0), groups, bias, padding_mode)
+      end
+
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
+        end
+        F.conv3d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
+      end
+    end
+  end
+end
data/lib/torch/nn/fold.rb
ADDED
@@ -0,0 +1,20 @@
+module Torch
+  module NN
+    class Fold < Module
+      def initialize(output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+        super()
+        @output_size = output_size
+        @kernel_size = kernel_size
+        @dilation = dilation
+        @padding = padding
+        @stride = stride
+      end
+
+      def forward(input)
+        F.fold(input, @output_size, @kernel_size, dilation: @dilation, padding: @padding, stride: @stride)
+      end
+
+      # TODO add extra_inspect
+    end
+  end
+end
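A rough round-trip sketch for the new fold/unfold support. It assumes Torch.randn and that the generated NN.im2col/NN.col2im bindings used by F.unfold and F.fold behave like LibTorch's:

  f = Torch::NN::Functional
  x = Torch.randn([1, 3, 4, 4])
  cols = f.unfold(x, 2, stride: 2)       # extract 2x2 patches, roughly shape [1, 12, 4]
  back = f.fold(cols, 4, 2, stride: 2)   # reassemble to [1, 3, 4, 4]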
data/lib/torch/nn/functional.rb
CHANGED
@@ -2,6 +2,166 @@ module Torch
   module NN
     class Functional
       class << self
+        include Utils
+
+        # convolution layers
+
+        def conv1d(*args, **options)
+          Torch.conv1d(*args, **options)
+        end
+
+        def conv2d(*args, **options)
+          Torch.conv2d(*args, **options)
+        end
+
+        def conv3d(*args, **options)
+          Torch.conv3d(*args, **options)
+        end
+
+        def unfold(input, kernel_size, dilation: 1, padding: 0, stride: 1)
+          if input.dim == 4
+            NN.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+          else
+            raise Error, "Input Error: Only 4D input Tensors are supported (got #{input.dim}D)"
+          end
+        end
+
+        def fold(input, output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+          if input.dim == 3
+            NN.col2im(input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+          else
+            raise Error, "Input Error: Only 3D input Tensors are supported (got #{input.dim}D)"
+          end
+        end
+
+        # pooling layers
+
+        def max_pool1d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            Torch.max_pool1d_with_indices(*args, **options)
+          else
+            Torch.max_pool1d(*args, **options)
+          end
+        end
+
+        def max_pool2d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            NN.max_pool2d_with_indices(*args, **options)
+          else
+            Torch.max_pool2d(*args, **options)
+          end
+        end
+
+        def max_pool3d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            NN.max_pool3d_with_indices(*args, **options)
+          else
+            Torch.max_pool3d(*args, **options)
+          end
+        end
+
+        def max_unpool1d(input, indices, kernel_size, stride: nil, padding: 0, output_size: nil)
+          raise NotImplementedYet
+          kernel_size = _single(kernel_size)
+          if !stride.nil?
+            _stride = _single(stride)
+          else
+            _stride = kernel_size
+          end
+          padding = _single(padding)
+          output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
+          output_size = output_size + [1]
+          NN.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), output_size).squeeze(3)
+        end
+
+        def max_unpool2d(*args, **options)
+          raise NotImplementedYet
+          NN.max_unpool2d(*args, **options)
+        end
+
+        def max_unpool3d(*args, **options)
+          raise NotImplementedYet
+          NN.max_unpool3d(*args, **options)
+        end
+
+        def avg_pool1d(*args, **options)
+          Torch.avg_pool1d(*args, **options)
+        end
+
+        def avg_pool2d(*args, **options)
+          NN.avg_pool2d(*args, **options)
+        end
+
+        def avg_pool3d(*args, **options)
+          NN.avg_pool3d(*args, **options)
+        end
+
+        # padding layers
+
+        def pad(input, pad, mode: "constant", value: 0)
+          raise ArgumentError, "Padding length must be divisible by 2" unless pad.size % 2 == 0
+          raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
+
+          if mode == "constant"
+            return Torch.constant_pad_nd(input, pad, value)
+          else
+            raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0
+
+            if input.dim == 3
+              raise ArgumentError, "3D tensors expect 2 values for padding" unless pad.size == 2
+              case mode
+              when "reflect"
+                NN.reflection_pad1d(input, pad)
+              when "replicate"
+                NN.replication_pad1d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            elsif input.dim == 4
+              raise ArgumentError, "4D tensors expect 4 values for padding" unless pad.size == 4
+              case mode
+              when "reflect"
+                NN.reflection_pad2d(input, pad)
+              when "replicate"
+                NN.replication_pad2d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            elsif input.dim == 5
+              raise ArgumentError, "5D tensors expect 6 values for padding" unless pad.size == 6
+              case mode
+              when "replicate"
+                NN.replication_pad3d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            else
+              raise ArgumentError, "Only 3D, 4D, 5D padding with non-constant padding are supported for now"
+            end
+          end
+        end
+
+        # activation layers
+
+        def hardshrink(input, lambd = 0.5)
+          Torch.hardshrink(input, lambd)
+        end
+
+        def leaky_relu(input, negative_slope = 0.01)
+          NN.leaky_relu(input, negative_slope)
+        end
+
+        def log_sigmoid(input)
+          NN.log_sigmoid(input)
+        end
+
+        def prelu(input, weight)
+          Torch.prelu(input, weight)
+        end
+
         def relu(input, inplace: false)
           if inplace
             input.relu!
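A hedged sketch of the padding and activation helpers added above (Torch.randn is assumed; the rest is the API from this hunk):

  x = Torch.randn([1, 1, 3, 3])
  Torch::NN::Functional.pad(x, [1, 1, 1, 1])                    # constant (zero) padding -> [1, 1, 5, 5]
  Torch::NN::Functional.pad(x, [1, 1, 1, 1], mode: "reflect")   # 4D input dispatches to NN.reflection_pad2d
  Torch::NN::Functional.hardshrink(x)                           # lambd defaults to 0.5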
@@ -10,37 +170,151 @@ module Torch
           end
         end
 
-        def
-
-
+        def softplus(input, beta: 1, threshold: 20)
+          NN.softplus(input, beta, threshold)
+        end
+
+        def softshrink(*args, **options)
+          NN.softshrink(*args, **options)
         end
 
-        def
-
+        def softsign(input)
+          input / (input.abs + 1)
         end
 
-        def
-
+        def tanhshrink(input)
+          input - input.tanh
         end
 
-
-
-
+        # other activation layers
+
+        def softmin(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          (-input).softmax(dim)
+        end
+
+        def softmax(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          input.softmax(dim)
+        end
+
+        # TODO make dim keyword argument and update examples
+        def log_softmax(input, dim = nil)
+          dim ||= softmax_dim(input.dim)
+          input.log_softmax(dim)
+        end
+
+        # normalization layers
+
+        def batch_norm(input, running_mean, running_var, weight: nil, bias: nil,
+          training: false, momentum: 0.1, eps: 1e-5)
+
+          if training
+            size = input.size
+            size_prods = size[0]
+            (size.length - 2).times do |i|
+              size_prods *= size[i + 2]
+            end
+            if size_prods == 1
+              raise ArgumentError, "Expected more than 1 value per channel when training, got input size #{size.inspect}"
+            end
+          end
+
+          Torch.batch_norm(
+            input, weight, bias, running_mean, running_var,
+            training, momentum, eps, false
+          )
+        end
+
+        def group_norm(input, num_groups, weight: nil, bias: nil, eps: 1e-5)
+          Torch.group_norm(input, num_groups, weight, bias, eps, false)
+        end
+
+        def instance_norm(input, running_mean: nil, running_var: nil, weight: nil,
+          bias: nil, use_input_stats: true, momentum: 0.1, eps: 1e-5)
+
+          Torch.instance_norm(
+            input, weight, bias, running_mean, running_var,
+            use_input_stats, momentum, eps, false
+          )
         end
 
-        def
-
-
+        def layer_norm(input, normalized_shape, weight: nil, bias: nil, eps: 1e-5)
+          Torch.layer_norm(input, normalized_shape, weight, bias, eps, false)
+        end
+
+        def local_response_norm(input, size, alpha: 1e-4, beta: 0.75, k: 1.0)
+          dim = input.dim
+          if dim < 3
+            raise ArgumentError, "Expected 3D or higher dimensionality input (got #{dim} dimensions)"
+          end
+          div = input.mul(input).unsqueeze(1)
+          if dim == 3
+            div = pad(div, [0, 0, size / 2, (size - 1) / 2])
+            div = avg_pool2d(div, [size, 1], stride: 1).squeeze(1)
+          else
+            sizes = input.size
+            div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
+            div = pad(div, [0, 0, 0, 0, size / 2, (size - 1) / 2])
+            div = avg_pool3d(div, [size, 1, 1], stride: 1).squeeze(1)
+            div = div.view(sizes)
+          end
+          div = div.mul(alpha).add(k).pow(beta)
+          input / div
         end
 
         # linear layers
 
+        def linear(input, weight, bias)
+          NN.linear(input, weight, bias)
+        end
+
         def bilinear(input1, input2, weight, bias)
           Torch.bilinear(input1, input2, weight, bias)
         end
 
-
-
+        # dropout layers
+
+        def dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.dropout!(input, p, training)
+          else
+            Torch.dropout(input, p, training)
+          end
+        end
+
+        def dropout2d(input, p: 0.5, training: true, inplace: false)
+          raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
+
+          if inplace
+            Torch.feature_dropout!(input, p, training)
+          else
+            Torch.feature_dropout(input, p, training)
+          end
+        end
+
+        def dropout3d(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.feature_dropout!(input, p, training)
+          else
+            Torch.feature_dropout(input, p, training)
+          end
+        end
+
+        def alpha_dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.alpha_dropout!(input, p, training)
+          else
+            Torch.alpha_dropout(input, p, training)
+          end
+        end
+
+        def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.feature_alpha_dropout!(input, p, training)
+          else
+            Torch.feature_alpha_dropout(input, p, training)
+          end
         end
 
         # sparse layers
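A few of the functions added in this hunk, as a usage sketch (Torch.randn is assumed; everything else is defined above):

  x = Torch.randn([8, 4])
  Torch::NN::Functional.dropout(x, p: 0.2, training: true)   # out-of-place; inplace: true calls Torch.dropout!
  Torch::NN::Functional.softmax(x, dim: 1)                    # dim falls back to softmax_dim(input.dim) when nil
  Torch::NN::Functional.softsign(x)                           # x / (x.abs + 1)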
@@ -51,37 +325,47 @@ module Torch
 
           padding_idx ||= -1
           # weight and indices are swapped from Python interface
-          Torch.
+          Torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
         end
 
         def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
-          # need to handle nils
-          raise NotImplementedYet
-
           # TODO handle max_norm and norm_type
           raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
 
-
+          mode_enum =
+            case mode
+            when "sum"
+              0
+            when "mean"
+              1
+            when "max"
+              2
+            else
+              raise ArgumentError, "Unknown mode: #{mode}"
+            end
+
+          # weight and input swapped
+          Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
         end
 
         # distance functions
 
         def cosine_similarity(x1, x2, dim: 1, eps: 1e-8)
-          Torch.
+          Torch.cosine_similarity(x1, x2, dim, eps)
         end
 
         def pairwise_distance(x1, x2, p: 2.0, eps: 1e-6, keepdim: false)
-          Torch.
+          Torch.pairwise_distance(x1, x2, p, eps, keepdim)
         end
 
         # loss functions
 
         def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
-          NN.
+          NN.binary_cross_entropy(input, target, weight, reduction)
         end
 
         def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
-          Torch.
+          Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
        end
 
         def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
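The distance functions now pass their keyword options through to the native calls; a short sketch (Torch.randn assumed):

  x1 = Torch.randn([5, 8])
  x2 = Torch.randn([5, 8])
  Torch::NN::Functional.cosine_similarity(x1, x2)            # delegates to Torch.cosine_similarity(x1, x2, dim, eps)
  Torch::NN::Functional.pairwise_distance(x1, x2, p: 2.0)    # p, eps, and keepdim are forwarded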
@@ -94,19 +378,19 @@ module Torch
 
         def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
           # call to_a on input_lengths and target_lengths for C++
-          Torch.
+          Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
         end
 
         def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
-          Torch.
+          Torch.hinge_embedding_loss(input, target, margin, reduction)
         end
 
         def kl_div(input, target, reduction: "mean")
-          Torch.
+          Torch.kl_div(input, target, reduction)
         end
 
         def l1_loss(input, target, reduction: "mean")
-          NN.
+          NN.l1_loss(input, target, reduction)
         end
 
         def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
@@ -114,11 +398,11 @@ module Torch
         end
 
         def mse_loss(input, target, reduction: "mean")
-          NN.
+          NN.mse_loss(input, target, reduction)
         end
 
         def multilabel_margin_loss(input, target, reduction: "mean")
-          NN.
+          NN.multilabel_margin_loss(input, target, reduction)
         end
 
         def multilabel_soft_margin_loss(input, target, weight: nil)
@@ -126,91 +410,27 @@ module Torch
         end
 
         def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
-          NN.
+          NN.multi_margin_loss(input, target, p, margin, weight, reduction)
         end
 
         def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
-          NN.
+          NN.nll_loss(input, target, weight, reduction, ignore_index)
         end
 
         def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
-          Torch.
+          Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
         end
 
         def soft_margin_loss(input, target, reduction: "mean")
-          NN.
+          NN.soft_margin_loss(input, target, reduction)
         end
 
         def smooth_l1_loss(input, target, reduction: "mean")
-          NN.
+          NN.smooth_l1_loss(input, target, reduction)
         end
 
         def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
-          Torch.
-        end
-
-        # end loss
-
-        def softmax(input, dim: nil)
-          dim ||= softmax_dim(input.dim)
-          input.softmax(dim: dim)
-        end
-
-        def softmin(input, dim: nil)
-          dim ||= softmax_dim(input.dim)
-          (-input).softmax(dim: dim)
-        end
-
-        def softplus(input, beta: 1, threshold: 20)
-          NN._softplus(input, beta, threshold)
-        end
-
-        # TODO make dim keyword argument and update examples
-        def log_softmax(input, dim = nil)
-          dim ||= softmax_dim(input.dim)
-          input.log_softmax(dim)
-        end
-
-        def dropout(input, p: 0.5, training: true, inplace: false)
-          if inplace
-            Torch._dropout_(input, p, training)
-          else
-            Torch._dropout(input, p, training)
-          end
-        end
-
-        def dropout2d(input, p: 0.5, training: true, inplace: false)
-          raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
-
-          if inplace
-            Torch._feature_dropout_(input, p, training)
-          else
-            Torch._feature_dropout(input, p, training)
-          end
-        end
-
-        def dropout3d(input, p: 0.5, training: true, inplace: false)
-          if inplace
-            Torch._feature_dropout_(input, p, training)
-          else
-            Torch._feature_dropout(input, p, training)
-          end
-        end
-
-        def alpha_dropout(input, p: 0.5, training: true, inplace: false)
-          if inplace
-            Torch._alpha_dropout_(input, p, training)
-          else
-            Torch._alpha_dropout(input, p, training)
-          end
-        end
-
-        def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
-          if inplace
-            Torch._feature_alpha_dropout_(input, p, training)
-          else
-            Torch._feature_alpha_dropout(input, p, training)
-          end
+          Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
         end
 
         private