torch-rb 0.1.5 → 0.1.6

This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (73)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/README.md +1 -1
  4. data/ext/torch/ext.cpp +0 -170
  5. data/ext/torch/nn_functions.cpp +44 -24
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +48 -0
  8. data/ext/torch/tensor_functions.cpp +76 -16
  9. data/ext/torch/torch_functions.cpp +165 -65
  10. data/lib/torch.rb +51 -42
  11. data/lib/torch/ext.bundle +0 -0
  12. data/lib/torch/native/dispatcher.rb +1 -1
  13. data/lib/torch/native/function.rb +36 -5
  14. data/lib/torch/native/generator.rb +26 -7
  15. data/lib/torch/native/parser.rb +51 -14
  16. data/lib/torch/nn/avg_pool1d.rb +18 -0
  17. data/lib/torch/nn/avg_pool2d.rb +7 -2
  18. data/lib/torch/nn/avg_pool3d.rb +19 -0
  19. data/lib/torch/nn/avg_poolnd.rb +1 -1
  20. data/lib/torch/nn/batch_norm.rb +75 -0
  21. data/lib/torch/nn/batch_norm1d.rb +11 -0
  22. data/lib/torch/nn/batch_norm2d.rb +11 -0
  23. data/lib/torch/nn/batch_norm3d.rb +11 -0
  24. data/lib/torch/nn/constant_pad1d.rb +10 -0
  25. data/lib/torch/nn/constant_pad2d.rb +10 -0
  26. data/lib/torch/nn/constant_pad3d.rb +10 -0
  27. data/lib/torch/nn/constant_padnd.rb +18 -0
  28. data/lib/torch/nn/conv1d.rb +22 -0
  29. data/lib/torch/nn/conv2d.rb +9 -17
  30. data/lib/torch/nn/conv3d.rb +22 -0
  31. data/lib/torch/nn/fold.rb +20 -0
  32. data/lib/torch/nn/functional.rb +320 -100
  33. data/lib/torch/nn/group_norm.rb +36 -0
  34. data/lib/torch/nn/gru.rb +49 -0
  35. data/lib/torch/nn/hardshrink.rb +18 -0
  36. data/lib/torch/nn/instance_norm.rb +20 -0
  37. data/lib/torch/nn/instance_norm1d.rb +18 -0
  38. data/lib/torch/nn/instance_norm2d.rb +11 -0
  39. data/lib/torch/nn/instance_norm3d.rb +11 -0
  40. data/lib/torch/nn/layer_norm.rb +35 -0
  41. data/lib/torch/nn/local_response_norm.rb +21 -0
  42. data/lib/torch/nn/log_sigmoid.rb +9 -0
  43. data/lib/torch/nn/lp_pool1d.rb +9 -0
  44. data/lib/torch/nn/lp_pool2d.rb +9 -0
  45. data/lib/torch/nn/lp_poolnd.rb +22 -0
  46. data/lib/torch/nn/lstm.rb +66 -0
  47. data/lib/torch/nn/max_pool1d.rb +9 -0
  48. data/lib/torch/nn/max_pool2d.rb +1 -1
  49. data/lib/torch/nn/max_pool3d.rb +9 -0
  50. data/lib/torch/nn/max_poolnd.rb +6 -6
  51. data/lib/torch/nn/max_unpool1d.rb +16 -0
  52. data/lib/torch/nn/max_unpool2d.rb +16 -0
  53. data/lib/torch/nn/max_unpool3d.rb +16 -0
  54. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  55. data/lib/torch/nn/module.rb +7 -0
  56. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  57. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  58. data/lib/torch/nn/reflection_padnd.rb +13 -0
  59. data/lib/torch/nn/replication_pad1d.rb +10 -0
  60. data/lib/torch/nn/replication_pad2d.rb +10 -0
  61. data/lib/torch/nn/replication_pad3d.rb +10 -0
  62. data/lib/torch/nn/replication_padnd.rb +13 -0
  63. data/lib/torch/nn/rnn_base.rb +48 -4
  64. data/lib/torch/nn/softshrink.rb +18 -0
  65. data/lib/torch/nn/softsign.rb +9 -0
  66. data/lib/torch/nn/tanh.rb +9 -0
  67. data/lib/torch/nn/tanhshrink.rb +9 -0
  68. data/lib/torch/nn/unfold.rb +19 -0
  69. data/lib/torch/nn/utils.rb +25 -0
  70. data/lib/torch/nn/zero_pad2d.rb +9 -0
  71. data/lib/torch/tensor.rb +14 -25
  72. data/lib/torch/version.rb +1 -1
  73. metadata +50 -2
@@ -2,7 +2,7 @@ module Torch
   module NN
     class AvgPoolNd < Module
       def extra_inspect
-        format("kernel_size: %s", @kernel_size)
+        format("kernel_size: %s, stride: %s, padding: %s", @kernel_size, @stride, @padding)
       end
     end
   end
@@ -0,0 +1,75 @@
+module Torch
+  module NN
+    class BatchNorm < Module
+      def initialize(num_features, eps: 1e-5, momentum: 0.1, affine: true, track_running_stats: true)
+        super()
+        @num_features = num_features
+        @eps = eps
+        @momentum = momentum
+        @affine = affine
+        @track_running_stats = track_running_stats
+        if @affine
+          @weight = Parameter.new(Torch::Tensor.new(num_features))
+          @bias = Parameter.new(Torch::Tensor.new(num_features))
+        else
+          register_parameter("weight", nil)
+          register_parameter("bias", nil)
+        end
+        if track_running_stats
+          register_buffer("running_mean", Torch.zeros(num_features))
+          register_buffer("running_var", Torch.ones(num_features))
+          register_buffer("num_batches_tracked", Torch.tensor(0, dtype: :long))
+        else
+          register_parameter("running_mean", nil)
+          register_parameter("running_var", nil)
+          register_parameter("num_batches_tracked", nil)
+        end
+        reset_parameters
+      end
+
+      def reset_running_stats
+        if @track_running_stats
+          @running_mean.zero!
+          @running_var.fill!(1)
+          @num_batches_tracked.zero!
+        end
+      end
+
+      def reset_parameters
+        reset_running_stats
+        if @affine
+          Init.ones!(@weight)
+          Init.zeros!(@bias)
+        end
+      end
+
+      def forward(input)
+        _check_input_dim(input)
+
+        if @momentum.nil?
+          exponential_average_factor = 0.0
+        else
+          exponential_average_factor = @momentum
+        end
+
+        if @training and @track_running_stats
+          if @num_batches_tracked.nil?
+            @num_batches_tracked += 1
+            if @momentum.nil?
+              exponential_average_factor = 1.0 / @num_batches_tracked.to_f
+            else
+              exponential_average_factor = @momentum
+            end
+          end
+        end
+
+        F.batch_norm(
+          input, @running_mean, @running_var,
+          weight: @weight, bias: @bias,
+          training: @training || !@track_running_stats,
+          momentum: exponential_average_factor, eps: @eps
+        )
+      end
+    end
+  end
+end
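As orientation for reviewing the new module (not part of the diff): a minimal usage sketch, assuming Torch.randn and Module#call behave as elsewhere in torch-rb. The dimensional subclasses added below enforce the expected input rank via _check_input_dim.

    require "torch"

    bn = Torch::NN::BatchNorm2d.new(3)   # 3 feature channels
    input = Torch.randn(2, 3, 8, 8)      # (batch, channels, height, width) => 4D, as BatchNorm2d requires
    output = bn.call(input)              # normalized tensor, same shape as input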
@@ -0,0 +1,11 @@
+module Torch
+  module NN
+    class BatchNorm1d < BatchNorm
+      def _check_input_dim(input)
+        if input.dim != 2 && input.dim != 3
+          raise ArgumentError, "expected 2D or 3D input (got #{input.dim}D input)"
+        end
+      end
+    end
+  end
+end
@@ -0,0 +1,11 @@
+module Torch
+  module NN
+    class BatchNorm2d < BatchNorm
+      def _check_input_dim(input)
+        if input.dim != 4
+          raise ArgumentError, "expected 4D input (got #{input.dim}D input)"
+        end
+      end
+    end
+  end
+end
@@ -0,0 +1,11 @@
+module Torch
+  module NN
+    class BatchNorm3d < BatchNorm
+      def _check_input_dim(input)
+        if input.dim != 5
+          raise ArgumentError, "expected 5D input (got #{input.dim}D input)"
+        end
+      end
+    end
+  end
+end
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class ConstantPad1d < ConstantPadNd
+      def initialize(padding, value)
+        super(value)
+        @padding = _pair(padding)
+      end
+    end
+  end
+end
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class ConstantPad2d < ConstantPadNd
+      def initialize(padding, value)
+        super(value)
+        @padding = _quadrupal(padding)
+      end
+    end
+  end
+end
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class ConstantPad3d < ConstantPadNd
+      def initialize(padding, value)
+        super(value)
+        @padding = _ntuple(6, padding)
+      end
+    end
+  end
+end
@@ -0,0 +1,18 @@
+module Torch
+  module NN
+    class ConstantPadNd < Module
+      def initialize(value)
+        super()
+        @value = value
+      end
+
+      def forward(input)
+        F.pad(input, @padding, mode: "constant", value: @value)
+      end
+
+      def extra_inspect
+        format("padding: %s, value: %s", @padding, @value)
+      end
+    end
+  end
+end
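For reference (not part of the diff), a sketch of how the new constant-padding modules behave, assuming Module#call works as in the rest of the library; ConstantPadNd delegates to the new F.pad shown further down.

    pad = Torch::NN::ConstantPad1d.new(2, 3.5)   # pad both ends of the last dim by 2 with value 3.5
    input = Torch.randn(1, 2, 4)
    output = pad.call(input)                     # expected shape: [1, 2, 8]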
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class Conv1d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _single(kernel_size)
+        stride = _single(stride)
+        padding = _single(padding)
+        dilation = _single(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _single(0), groups, bias, padding_mode)
+      end
+
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
+        end
+        F.conv1d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
+      end
+    end
+  end
+end
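A hypothetical usage sketch for the new Conv1d module (assumes Torch.randn and Module#call as used elsewhere in the gem):

    conv = Torch::NN::Conv1d.new(16, 33, 3, stride: 2)
    input = Torch.randn(20, 16, 50)    # (batch, in_channels, length)
    output = conv.call(input)          # expected shape: [20, 33, 24]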
@@ -1,35 +1,27 @@
 module Torch
   module NN
     class Conv2d < ConvNd
-      def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
-        kernel_size = pair(kernel_size)
-        stride = pair(stride)
-        padding = pair(padding)
-        dilation = pair(dilation)
-        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, pair(0), groups, bias, padding_mode)
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _pair(0), groups, bias, padding_mode)
       end

       def forward(input)
         if @padding_mode == "circular"
           raise NotImplementedError
         end
-        F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding, dilation: @dilation, groups: @groups)
+        F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
       end

       # TODO add more parameters
       def extra_inspect
         format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
       end
-
-      private
-
-      def pair(value)
-        if value.is_a?(Array)
-          value
-        else
-          [value] * 2
-        end
-      end
     end
   end
 end
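The private pair helper is dropped in favor of the shared helpers mixed in from Torch::NN::Utils (utils.rb, listed above but not shown in this section), and F.conv2d is now called with positional arguments to match the generated native bindings. Presumed behavior of those helpers, for orientation only:

    # _single(3)  # => [3]          _single([2])   # => [2]
    # _pair(3)    # => [3, 3]       _pair([2, 1])  # => [2, 1]
    # _triple(3)  # => [3, 3, 3]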
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class Conv3d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _triple(kernel_size)
+        stride = _triple(stride)
+        padding = _triple(padding)
+        dilation = _triple(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _triple(0), groups, bias, padding_mode)
+      end
+
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
+        end
+        F.conv3d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
+      end
+    end
+  end
+end
@@ -0,0 +1,20 @@
+module Torch
+  module NN
+    class Fold < Module
+      def initialize(output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+        super()
+        @output_size = output_size
+        @kernel_size = kernel_size
+        @dilation = dilation
+        @padding = padding
+        @stride = stride
+      end
+
+      def forward(input)
+        F.fold(input, @output_size, @kernel_size, dilation: @dilation, padding: @padding, stride: @stride)
+      end
+
+      # TODO add extra_inspect
+    end
+  end
+end
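For orientation (not part of the diff), a sketch of the new Fold module; the shape arithmetic follows the col2im semantics that F.fold wraps below.

    fold = Torch::NN::Fold.new([4, 5], [2, 2])
    input = Torch.randn(1, 3 * 2 * 2, 12)   # 12 sliding blocks, each holding 3x2x2 values
    output = fold.call(input)               # expected shape: [1, 3, 4, 5]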
@@ -2,6 +2,166 @@ module Torch
   module NN
     class Functional
       class << self
+        include Utils
+
+        # convolution layers
+
+        def conv1d(*args, **options)
+          Torch.conv1d(*args, **options)
+        end
+
+        def conv2d(*args, **options)
+          Torch.conv2d(*args, **options)
+        end
+
+        def conv3d(*args, **options)
+          Torch.conv3d(*args, **options)
+        end
+
+        def unfold(input, kernel_size, dilation: 1, padding: 0, stride: 1)
+          if input.dim == 4
+            NN.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+          else
+            raise Error, "Input Error: Only 4D input Tensors are supported (got #{input.dim}D)"
+          end
+        end
+
+        def fold(input, output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+          if input.dim == 3
+            NN.col2im(input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+          else
+            raise Error, "Input Error: Only 3D input Tensors are supported (got #{input.dim}D)"
+          end
+        end
+
+        # pooling layers
+
+        def max_pool1d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            Torch.max_pool1d_with_indices(*args, **options)
+          else
+            Torch.max_pool1d(*args, **options)
+          end
+        end
+
+        def max_pool2d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            NN.max_pool2d_with_indices(*args, **options)
+          else
+            Torch.max_pool2d(*args, **options)
+          end
+        end
+
+        def max_pool3d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            NN.max_pool3d_with_indices(*args, **options)
+          else
+            Torch.max_pool3d(*args, **options)
+          end
+        end
+
+        def max_unpool1d(input, indices, kernel_size, stride: nil, padding: 0, output_size: nil)
+          raise NotImplementedYet
+          kernel_size = _single(kernel_size)
+          if !stride.nil?
+            _stride = _single(stride)
+          else
+            _stride = kernel_size
+          end
+          padding = _single(padding)
+          output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
+          output_size = output_size + [1]
+          NN.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), output_size).squeeze(3)
+        end
+
+        def max_unpool2d(*args, **options)
+          raise NotImplementedYet
+          NN.max_unpool2d(*args, **options)
+        end
+
+        def max_unpool3d(*args, **options)
+          raise NotImplementedYet
+          NN.max_unpool3d(*args, **options)
+        end
+
+        def avg_pool1d(*args, **options)
+          Torch.avg_pool1d(*args, **options)
+        end
+
+        def avg_pool2d(*args, **options)
+          NN.avg_pool2d(*args, **options)
+        end
+
+        def avg_pool3d(*args, **options)
+          NN.avg_pool3d(*args, **options)
+        end
+
+        # padding layers
+
+        def pad(input, pad, mode: "constant", value: 0)
+          raise ArgumentError, "Padding length must be divisible by 2" unless pad.size % 2 == 0
+          raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
+
+          if mode == "constant"
+            return Torch.constant_pad_nd(input, pad, value)
+          else
+            raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0
+
+            if input.dim == 3
+              raise ArgumentError, "3D tensors expect 2 values for padding" unless pad.size == 2
+              case mode
+              when "reflect"
+                NN.reflection_pad1d(input, pad)
+              when "replicate"
+                NN.replication_pad1d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            elsif input.dim == 4
+              raise ArgumentError, "4D tensors expect 4 values for padding" unless pad.size == 4
+              case mode
+              when "reflect"
+                NN.reflection_pad2d(input, pad)
+              when "replicate"
+                NN.replication_pad2d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            elsif input.dim == 5
+              raise ArgumentError, "5D tensors expect 6 values for padding" unless pad.size == 6
+              case mode
+              when "replicate"
+                NN.replication_pad3d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            else
+              raise ArgumentError, "Only 3D, 4D, 5D padding with non-constant padding are supported for now"
+            end
+          end
+        end
+
+        # activation layers
+
+        def hardshrink(input, lambd = 0.5)
+          Torch.hardshrink(input, lambd)
+        end
+
+        def leaky_relu(input, negative_slope = 0.01)
+          NN.leaky_relu(input, negative_slope)
+        end
+
+        def log_sigmoid(input)
+          NN.log_sigmoid(input)
+        end
+
+        def prelu(input, weight)
+          Torch.prelu(input, weight)
+        end
+
         def relu(input, inplace: false)
           if inplace
             input.relu!
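A sketch of the new F.pad dispatch (assuming Torch::NN::Functional is reachable under that constant, as defined in this file): constant mode maps to Torch.constant_pad_nd for any rank, while reflect/replicate modes route to the rank-specific NN kernels.

    t = Torch.randn(1, 2, 3)
    Torch::NN::Functional.pad(t, [1, 1])                    # constant pad, last dim 3 -> 5
    Torch::NN::Functional.pad(t, [1, 1], mode: "reflect")   # 3D input + 2 pad values -> reflection_pad1d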
@@ -10,37 +170,151 @@ module Torch
           end
         end

-        def conv2d(input, weight, bias, stride: 1, padding: 0, dilation: 1, groups: 1)
-          # TODO pair stride and padding when needed
-          Torch.conv2d(input, weight, bias, stride, padding, dilation, groups)
+        def softplus(input, beta: 1, threshold: 20)
+          NN.softplus(input, beta, threshold)
+        end
+
+        def softshrink(*args, **options)
+          NN.softshrink(*args, **options)
         end

-        def prelu(input, weight)
-          Torch.prelu(input, weight)
+        def softsign(input)
+          input / (input.abs + 1)
         end

-        def leaky_relu(input, negative_slope = 0.01)
-          Torch.leaky_relu(input, negative_slope)
+        def tanhshrink(input)
+          input - input.tanh
         end

-        def max_pool2d(input, kernel_size)
-          kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
-          Torch.max_pool2d(input, kernel_size)
+        # other activation layers
+
+        def softmin(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          (-input).softmax(dim)
+        end
+
+        def softmax(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          input.softmax(dim)
+        end
+
+        # TODO make dim keyword argument and update examples
+        def log_softmax(input, dim = nil)
+          dim ||= softmax_dim(input.dim)
+          input.log_softmax(dim)
+        end
+
+        # normalization layers
+
+        def batch_norm(input, running_mean, running_var, weight: nil, bias: nil,
+          training: false, momentum: 0.1, eps: 1e-5)
+
+          if training
+            size = input.size
+            size_prods = size[0]
+            (size.length - 2).times do |i|
+              size_prods *= size[i + 2]
+            end
+            if size_prods == 1
+              raise ArgumentError, "Expected more than 1 value per channel when training, got input size #{size.inspect}"
+            end
+          end
+
+          Torch.batch_norm(
+            input, weight, bias, running_mean, running_var,
+            training, momentum, eps, false
+          )
+        end
+
+        def group_norm(input, num_groups, weight: nil, bias: nil, eps: 1e-5)
+          Torch.group_norm(input, num_groups, weight, bias, eps, false)
+        end
+
+        def instance_norm(input, running_mean: nil, running_var: nil, weight: nil,
+          bias: nil, use_input_stats: true, momentum: 0.1, eps: 1e-5)
+
+          Torch.instance_norm(
+            input, weight, bias, running_mean, running_var,
+            use_input_stats, momentum, eps, false
+          )
         end

-        def avg_pool2d(input, kernel_size)
-          kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
-          Torch.avg_pool2d(input, kernel_size)
+        def layer_norm(input, normalized_shape, weight: nil, bias: nil, eps: 1e-5)
+          Torch.layer_norm(input, normalized_shape, weight, bias, eps, false)
+        end
+
+        def local_response_norm(input, size, alpha: 1e-4, beta: 0.75, k: 1.0)
+          dim = input.dim
+          if dim < 3
+            raise ArgumentError, "Expected 3D or higher dimensionality input (got #{dim} dimensions)"
+          end
+          div = input.mul(input).unsqueeze(1)
+          if dim == 3
+            div = pad(div, [0, 0, size / 2, (size - 1) / 2])
+            div = avg_pool2d(div, [size, 1], stride: 1).squeeze(1)
+          else
+            sizes = input.size
+            div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
+            div = pad(div, [0, 0, 0, 0, size / 2, (size - 1) / 2])
+            div = avg_pool3d(div, [size, 1, 1], stride: 1).squeeze(1)
+            div = div.view(sizes)
+          end
+          div = div.mul(alpha).add(k).pow(beta)
+          input / div
         end

         # linear layers

+        def linear(input, weight, bias)
+          NN.linear(input, weight, bias)
+        end
+
         def bilinear(input1, input2, weight, bias)
           Torch.bilinear(input1, input2, weight, bias)
         end

-        def linear(input, weight, bias)
-          Torch.linear(input, weight, bias)
+        # dropout layers
+
+        def dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.dropout!(input, p, training)
+          else
+            Torch.dropout(input, p, training)
+          end
+        end
+
+        def dropout2d(input, p: 0.5, training: true, inplace: false)
+          raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
+
+          if inplace
+            Torch.feature_dropout!(input, p, training)
+          else
+            Torch.feature_dropout(input, p, training)
+          end
+        end
+
+        def dropout3d(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.feature_dropout!(input, p, training)
+          else
+            Torch.feature_dropout(input, p, training)
+          end
+        end
+
+        def alpha_dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.alpha_dropout!(input, p, training)
+          else
+            Torch.alpha_dropout(input, p, training)
+          end
+        end
+
+        def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.feature_alpha_dropout!(input, p, training)
+          else
+            Torch.feature_alpha_dropout(input, p, training)
+          end
         end

         # sparse layers
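The dropout family now calls the regenerated native functions, with bang-named variants for in-place operation. A hypothetical call, for orientation only:

    x = Torch.randn(5, 10)
    y = Torch::NN::Functional.dropout(x, p: 0.2)               # new tensor via Torch.dropout
    Torch::NN::Functional.dropout(x, p: 0.2, inplace: true)    # mutates x via Torch.dropout!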
@@ -51,37 +325,47 @@ module Torch
 
           padding_idx ||= -1
           # weight and indices are swapped from Python interface
-          Torch._embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
+          Torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
         end

         def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
-          # need to handle nils
-          raise NotImplementedYet
-
           # TODO handle max_norm and norm_type
           raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0

-          Torch._embedding_bag(input, weight, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights)
+          mode_enum =
+            case mode
+            when "sum"
+              0
+            when "mean"
+              1
+            when "max"
+              2
+            else
+              raise ArgumentError, "Unknown mode: #{mode}"
+            end
+
+          # weight and input swapped
+          Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
         end

         # distance functions

         def cosine_similarity(x1, x2, dim: 1, eps: 1e-8)
-          Torch._cosine_similarity(x1, x2, dim, eps)
+          Torch.cosine_similarity(x1, x2, dim, eps)
         end

         def pairwise_distance(x1, x2, p: 2.0, eps: 1e-6, keepdim: false)
-          Torch._pairwise_distance(x1, x2, p, eps, keepdim)
+          Torch.pairwise_distance(x1, x2, p, eps, keepdim)
         end

         # loss functions

         def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
-          NN._binary_cross_entropy(input, target, weight, reduction)
+          NN.binary_cross_entropy(input, target, weight, reduction)
         end

         def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
-          Torch._binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+          Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
         end

         def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
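embedding_bag is no longer a stub: the mode string is mapped to the enum expected by the native Torch.embedding_bag ("sum" => 0, "mean" => 1, "max" => 2) and the weight/input order is swapped relative to the Python interface. A hedged usage sketch (assumes offsets must be an integer tensor, as the native signature requires):

    weight = Torch.randn(10, 3)
    input = Torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
    offsets = Torch.tensor([0, 4])
    Torch::NN::Functional.embedding_bag(input, weight, offsets: offsets, mode: "sum")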
@@ -94,19 +378,19 @@ module Torch
 
         def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
           # call to_a on input_lengths and target_lengths for C++
-          Torch._ctc_loss_intlist(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+          Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
         end

         def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
-          Torch._hinge_embedding_loss(input, target, margin, reduction)
+          Torch.hinge_embedding_loss(input, target, margin, reduction)
         end

         def kl_div(input, target, reduction: "mean")
-          Torch._kl_div(input, target, reduction)
+          Torch.kl_div(input, target, reduction)
         end

         def l1_loss(input, target, reduction: "mean")
-          NN._l1_loss(input, target, reduction)
+          NN.l1_loss(input, target, reduction)
         end

         def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
@@ -114,11 +398,11 @@ module Torch
         end

         def mse_loss(input, target, reduction: "mean")
-          NN._mse_loss(input, target, reduction)
+          NN.mse_loss(input, target, reduction)
         end

         def multilabel_margin_loss(input, target, reduction: "mean")
-          NN._multilabel_margin_loss(input, target, reduction)
+          NN.multilabel_margin_loss(input, target, reduction)
         end

         def multilabel_soft_margin_loss(input, target, weight: nil)
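Across the loss functions, the hand-written underscore-prefixed wrappers (presumably the ones dropped from ext.cpp, +0 -170 above) are replaced by the public generated bindings such as NN.mse_loss. A minimal sketch, assuming requires_grad: and Tensor#backward work as in the README examples:

    input = Torch.randn(3, 5, requires_grad: true)
    target = Torch.randn(3, 5)
    loss = Torch::NN::Functional.mse_loss(input, target)
    loss.backward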
@@ -126,91 +410,27 @@ module Torch
         end

         def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
-          NN._multi_margin_loss(input, target, p, margin, weight, reduction)
+          NN.multi_margin_loss(input, target, p, margin, weight, reduction)
         end

         def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
-          NN._nll_loss(input, target, weight, reduction, ignore_index)
+          NN.nll_loss(input, target, weight, reduction, ignore_index)
         end

         def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
-          Torch._poisson_nll_loss(input, target, log_input, full, eps, reduction)
+          Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
         end

         def soft_margin_loss(input, target, reduction: "mean")
-          NN._soft_margin_loss(input, target, reduction)
+          NN.soft_margin_loss(input, target, reduction)
         end

         def smooth_l1_loss(input, target, reduction: "mean")
-          NN._smooth_l1_loss(input, target, reduction)
+          NN.smooth_l1_loss(input, target, reduction)
         end

         def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
-          Torch._triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
-        end
-
-        # end loss
-
-        def softmax(input, dim: nil)
-          dim ||= softmax_dim(input.dim)
-          input.softmax(dim: dim)
-        end
-
-        def softmin(input, dim: nil)
-          dim ||= softmax_dim(input.dim)
-          (-input).softmax(dim: dim)
-        end
-
-        def softplus(input, beta: 1, threshold: 20)
-          NN._softplus(input, beta, threshold)
-        end
-
-        # TODO make dim keyword argument and update examples
-        def log_softmax(input, dim = nil)
-          dim ||= softmax_dim(input.dim)
-          input.log_softmax(dim)
-        end
-
-        def dropout(input, p: 0.5, training: true, inplace: false)
-          if inplace
-            Torch._dropout_(input, p, training)
-          else
-            Torch._dropout(input, p, training)
-          end
-        end
-
-        def dropout2d(input, p: 0.5, training: true, inplace: false)
-          raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
-
-          if inplace
-            Torch._feature_dropout_(input, p, training)
-          else
-            Torch._feature_dropout(input, p, training)
-          end
-        end
-
-        def dropout3d(input, p: 0.5, training: true, inplace: false)
-          if inplace
-            Torch._feature_dropout_(input, p, training)
-          else
-            Torch._feature_dropout(input, p, training)
-          end
-        end
-
-        def alpha_dropout(input, p: 0.5, training: true, inplace: false)
-          if inplace
-            Torch._alpha_dropout_(input, p, training)
-          else
-            Torch._alpha_dropout(input, p, training)
-          end
-        end
-
-        def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
-          if inplace
-            Torch._feature_alpha_dropout_(input, p, training)
-          else
-            Torch._feature_alpha_dropout(input, p, training)
-          end
+          Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
         end

         private