torch-rb 0.1.1 → 0.1.6

Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +73 -9
  5. data/ext/torch/ext.cpp +148 -315
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +298 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +236 -112
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +52 -25
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -39
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +419 -16
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +191 -19
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +4 -0
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +62 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +60 -0
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +90 -30
  139. data/lib/torch/utils/data/data_loader.rb +15 -0
  140. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  141. data/lib/torch/version.rb +1 -1
  142. metadata +122 -3
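
Before the file-by-file diffs, here is a minimal training-loop sketch (not taken from the diff) showing how the classes added in this release fit together. It assumes the new Torch::Optim::SGD and Torch::NN::MSELoss follow the PyTorch-style API that the file list above suggests.

# Hedged sketch only: optimizer/loss APIs assumed to mirror PyTorch.
require "torch"

x = Torch.rand(64, 10)   # toy inputs
y = Torch.rand(64, 1)    # toy targets

model = Torch::NN::Linear.new(10, 1)
criterion = Torch::NN::MSELoss.new
optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.01)

10.times do
  optimizer.zero_grad                       # clear accumulated gradients
  loss = criterion.call(model.call(x), y)   # forward pass and loss
  loss.backward                             # backpropagate
  optimizer.step                            # update weights
end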
data/lib/torch/nn/conv1d.rb
@@ -0,0 +1,22 @@
+ module Torch
+   module NN
+     class Conv1d < ConvNd
+       def initialize(in_channels, out_channels, kernel_size, stride: 1,
+         padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+         kernel_size = _single(kernel_size)
+         stride = _single(stride)
+         padding = _single(padding)
+         dilation = _single(dilation)
+         super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _single(0), groups, bias, padding_mode)
+       end
+
+       def forward(input)
+         if @padding_mode == "circular"
+           raise NotImplementedError
+         end
+         F.conv1d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
+       end
+     end
+   end
+ end
data/lib/torch/nn/conv2d.rb
@@ -1,49 +1,26 @@
  module Torch
    module NN
-     class Conv2d < Module
-       attr_reader :bias, :weight
-
-       def initialize(in_channels, out_channels, kernel_size) #, stride: 1, padding: 0, dilation: 1, groups: 1)
-         @in_channels = in_channels
-         @out_channels = out_channels
-         @kernel_size = pair(kernel_size)
-         @stride = pair(1)
-         # @stride = pair(stride)
-         # @padding = pair(padding)
-         # @dilation = pair(dilation)
-
-         # TODO divide by groups
-         @weight = Parameter.new(Tensor.new(out_channels, in_channels, *@kernel_size))
-         @bias = Parameter.new(Tensor.new(out_channels))
-
-         reset_parameters
+     class Conv2d < ConvNd
+       def initialize(in_channels, out_channels, kernel_size, stride: 1,
+         padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+         kernel_size = _pair(kernel_size)
+         stride = _pair(stride)
+         padding = _pair(padding)
+         dilation = _pair(dilation)
+         super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _pair(0), groups, bias, padding_mode)
        end

-       def reset_parameters
-         Init.kaiming_uniform_(@weight, Math.sqrt(5))
-         if @bias
-           fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
-           bound = 1 / Math.sqrt(fan_in)
-           Init.uniform_(@bias, -bound, bound)
+       def forward(input)
+         if @padding_mode == "circular"
+           raise NotImplementedError
          end
+         F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
        end

-       def call(input)
-         F.conv2d(input, @weight, @bias) # @stride, self.padding, self.dilation, self.groups)
-       end
-
-       def inspect
-         "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
-       end
-
-       private
-
-       def pair(value)
-         if value.is_a?(Array)
-           value
-         else
-           [value] * 2
-         end
+       # TODO add more parameters
+       def extra_inspect
+         format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
        end
      end
    end
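
A short usage sketch (not part of the diff) of the rewritten Conv2d, assuming the stride:/padding: keywords behave as the new initializer above forwards them to ConvNd:

# Hedged example: shapes follow the usual N x C x H x W convention.
conv = Torch::NN::Conv2d.new(3, 16, 3, stride: 1, padding: 1)
input = Torch.rand(1, 3, 28, 28)
output = conv.call(input)   # padding: 1 with a 3x3 kernel preserves the spatial size
output.shape                # expected: [1, 16, 28, 28]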
data/lib/torch/nn/conv3d.rb
@@ -0,0 +1,22 @@
+ module Torch
+   module NN
+     class Conv3d < ConvNd
+       def initialize(in_channels, out_channels, kernel_size, stride: 1,
+         padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+         kernel_size = _triple(kernel_size)
+         stride = _triple(stride)
+         padding = _triple(padding)
+         dilation = _triple(dilation)
+         super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _triple(0), groups, bias, padding_mode)
+       end
+
+       def forward(input)
+         if @padding_mode == "circular"
+           raise NotImplementedError
+         end
+         F.conv3d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
+       end
+     end
+   end
+ end
data/lib/torch/nn/convnd.rb
@@ -0,0 +1,41 @@
+ module Torch
+   module NN
+     class ConvNd < Module
+       def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
+         super()
+         raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
+         raise ArgumentError, "out_channels must be divisible by groups" if out_channels % groups != 0
+         @in_channels = in_channels
+         @out_channels = out_channels
+         @kernel_size = kernel_size
+         @stride = stride
+         @padding = padding
+         @dilation = dilation
+         @transposed = transposed
+         @output_padding = output_padding
+         @groups = groups
+         @padding_mode = padding_mode
+         if transposed
+           @weight = Parameter.new(Tensor.new(in_channels, out_channels / groups, *kernel_size))
+         else
+           @weight = Parameter.new(Tensor.new(out_channels, in_channels / groups, *kernel_size))
+         end
+         if bias
+           @bias = Parameter.new(Tensor.new(out_channels))
+         else
+           raise NotImplementedError
+         end
+         reset_parameters
+       end
+
+       def reset_parameters
+         Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
+         if @bias
+           fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
+           bound = 1 / Math.sqrt(fan_in)
+           Init.uniform!(@bias, a: -bound, b: bound)
+         end
+       end
+     end
+   end
+ end
data/lib/torch/nn/cosine_embedding_loss.rb
@@ -0,0 +1,14 @@
+ module Torch
+   module NN
+     class CosineEmbeddingLoss < Loss
+       def initialize(margin: 0, reduction: "mean")
+         super(reduction)
+         @margin = margin
+       end
+
+       def forward(input1, input2, target)
+         F.cosine_embedding_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+       end
+     end
+   end
+ end
data/lib/torch/nn/cosine_similarity.rb
@@ -0,0 +1,15 @@
+ module Torch
+   module NN
+     class CosineSimilarity < Module
+       def initialize(dim: 1, eps: 1e-8)
+         super()
+         @dim = dim
+         @eps = eps
+       end
+
+       def forward(x1, x2)
+         F.cosine_similarity(x1, x2, dim: @dim, eps: @eps)
+       end
+     end
+   end
+ end
data/lib/torch/nn/cross_entropy_loss.rb
@@ -0,0 +1,14 @@
+ module Torch
+   module NN
+     class CrossEntropyLoss < WeightedLoss
+       def initialize(weight: nil, ignore_index: -100, reduction: "mean")
+         super(weight, reduction)
+         @ignore_index = ignore_index
+       end
+
+       def forward(input, target)
+         F.cross_entropy(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
+       end
+     end
+   end
+ end
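
For reference, a hedged sketch (not in the diff) of the new CrossEntropyLoss in use, with logits of shape [N, C] and integer class targets of shape [N]; it mirrors the F.cross_entropy call shown later in the functional.rb diff.

# Hedged example: assumes Torch.randn and Torch.tensor build float and integer tensors.
loss_fn = Torch::NN::CrossEntropyLoss.new
logits = Torch.randn(4, 10)          # 4 samples, 10 classes
target = Torch.tensor([1, 0, 4, 9])  # class indices
loss = loss_fn.call(logits, target)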
data/lib/torch/nn/ctc_loss.rb
@@ -0,0 +1,15 @@
+ module Torch
+   module NN
+     class CTCLoss < Loss
+       def initialize(blank: 0, reduction: "mean", zero_infinity: false)
+         super(reduction)
+         @blank = blank
+         @zero_infinity = zero_infinity
+       end
+
+       def forward(log_probs, targets, input_lengths, target_lengths)
+         F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: @blank, reduction: @reduction, zero_infinity: @zero_infinity)
+       end
+     end
+   end
+ end
data/lib/torch/nn/dropout.rb
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class Dropout < DropoutNd
+       def forward(input)
+         F.dropout(input, p: @p, training: @training, inplace: @inplace)
+       end
+     end
+   end
+ end
data/lib/torch/nn/dropout2d.rb
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class Dropout2d < DropoutNd
+       def forward(input)
+         F.dropout2d(input, p: @p, training: @training, inplace: @inplace)
+       end
+     end
+   end
+ end
data/lib/torch/nn/dropout3d.rb
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class Dropout3d < DropoutNd
+       def forward(input)
+         F.dropout3d(input, p: @p, training: @training, inplace: @inplace)
+       end
+     end
+   end
+ end
data/lib/torch/nn/dropoutnd.rb
@@ -0,0 +1,15 @@
+ module Torch
+   module NN
+     class DropoutNd < Module
+       def initialize(p: 0.5, inplace: false)
+         super()
+         @p = p
+         @inplace = inplace
+       end
+
+       def extra_inspect
+         format("p: %s, inplace: %s", @p, @inplace)
+       end
+     end
+   end
+ end
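
The dropout modules above only differ in which functional they forward to; a hedged usage sketch (not in the diff), assuming a freshly built module starts in training mode:

# Hedged example: elements are zeroed with probability p, the rest scaled by 1 / (1 - p).
drop = Torch::NN::Dropout.new(p: 0.2)
drop.call(Torch.ones(8, 4))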
data/lib/torch/nn/embedding.rb
@@ -0,0 +1,52 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
+ module Torch
+   module NN
+     class Embedding < Module
+       def initialize(num_embeddings, embedding_dim, padding_idx: nil, max_norm: nil,
+         norm_type: 2.0, scale_grad_by_freq: false, sparse: false, _weight: nil)
+
+         super()
+         @num_embeddings = num_embeddings
+         @embedding_dim = embedding_dim
+
+         if padding_idx
+           if padding_idx > 0
+             raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx < @num_embeddings
+           elsif padding_idx < 0
+             raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx >= -@num_embeddings
+             padding_idx = @num_embeddings + padding_idx
+           end
+         end
+         @padding_idx = padding_idx
+         @max_norm = max_norm
+         @norm_type = norm_type
+         @scale_grad_by_freq = scale_grad_by_freq
+         if _weight.nil?
+           @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
+           reset_parameters
+         else
+           raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
+           @weight = Parameter.new(_weight)
+         end
+         @sparse = sparse
+       end
+
+       def reset_parameters
+         Init.normal!(@weight)
+         if @padding_idx
+           Torch.no_grad do
+             @weight[@padding_idx].fill!(0)
+           end
+         end
+       end
+
+       def forward(input)
+         F.embedding(input, @weight, padding_idx: @padding_idx, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, sparse: @sparse)
+       end
+
+       def inspect
+         "Embedding(#{@num_embeddings}, #{@embedding_dim})"
+       end
+     end
+   end
+ end
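
A hedged usage sketch (not in the diff) of the new Embedding layer: integer indices in, dense vectors out.

# Hedged example: a 10-row embedding table with 3-dimensional vectors.
embedding = Torch::NN::Embedding.new(10, 3)
input = Torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
embedding.call(input)   # expected shape: [2, 4, 3]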
data/lib/torch/nn/embedding_bag.rb
@@ -0,0 +1,34 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
+ module Torch
+   module NN
+     class EmbeddingBag < Module
+       def initialize(num_embeddings, embedding_dim, max_norm: nil, norm_type: 2.0,
+         scale_grad_by_freq: false, mode: "mean", sparse: false, _weight: nil)
+
+         super()
+         @num_embeddings = num_embeddings
+         @embedding_dim = embedding_dim
+         @max_norm = max_norm
+         @norm_type = norm_type
+         @scale_grad_by_freq = scale_grad_by_freq
+         if _weight.nil?
+           @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
+           reset_parameters
+         else
+           raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
+           @weight = Parameter.new(_weight)
+         end
+         @mode = mode
+         @sparse = sparse
+       end
+
+       def reset_parameters
+         Init.normal!(@weight)
+       end
+
+       def forward(input, offsets: nil, per_sample_weights: nil)
+         F.embedding_bag(input, @weight, offsets: offsets, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, mode: @mode, sparse: @sparse, per_sample_weights: per_sample_weights)
+       end
+     end
+   end
+ end
data/lib/torch/nn/feature_alpha_dropout.rb
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class FeatureAlphaDropout < DropoutNd
+       def forward(input)
+         F.feature_alpha_dropout(input, p: @p, training: @training, inplace: @inplace)
+       end
+     end
+   end
+ end
data/lib/torch/nn/fold.rb
@@ -0,0 +1,20 @@
+ module Torch
+   module NN
+     class Fold < Module
+       def initialize(output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+         super()
+         @output_size = output_size
+         @kernel_size = kernel_size
+         @dilation = dilation
+         @padding = padding
+         @stride = stride
+       end
+
+       def forward(input)
+         F.fold(input, @output_size, @kernel_size, dilation: @dilation, padding: @padding, stride: @stride)
+       end
+
+       # TODO add extra_inspect
+     end
+   end
+ end
data/lib/torch/nn/functional.rb
@@ -2,38 +2,441 @@ module Torch
    module NN
      class Functional
        class << self
-         def relu(input)
-           Torch.relu(input)
+         include Utils
+
+         # convolution layers
+
+         def conv1d(*args, **options)
+           Torch.conv1d(*args, **options)
+         end
+
+         def conv2d(*args, **options)
+           Torch.conv2d(*args, **options)
+         end
+
+         def conv3d(*args, **options)
+           Torch.conv3d(*args, **options)
+         end
+
+         def unfold(input, kernel_size, dilation: 1, padding: 0, stride: 1)
+           if input.dim == 4
+             NN.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+           else
+             raise Error, "Input Error: Only 4D input Tensors are supported (got #{input.dim}D)"
+           end
+         end
+
+         def fold(input, output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+           if input.dim == 3
+             NN.col2im(input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+           else
+             raise Error, "Input Error: Only 3D input Tensors are supported (got #{input.dim}D)"
+           end
+         end
+
+         # pooling layers
+
+         def max_pool1d(*args, **options)
+           return_indices = args.pop if args.size == 7
+           if return_indices
+             Torch.max_pool1d_with_indices(*args, **options)
+           else
+             Torch.max_pool1d(*args, **options)
+           end
+         end
+
+         def max_pool2d(*args, **options)
+           return_indices = args.pop if args.size == 7
+           if return_indices
+             NN.max_pool2d_with_indices(*args, **options)
+           else
+             Torch.max_pool2d(*args, **options)
+           end
+         end
+
+         def max_pool3d(*args, **options)
+           return_indices = args.pop if args.size == 7
+           if return_indices
+             NN.max_pool3d_with_indices(*args, **options)
+           else
+             Torch.max_pool3d(*args, **options)
+           end
+         end
+
+         def max_unpool1d(input, indices, kernel_size, stride: nil, padding: 0, output_size: nil)
+           raise NotImplementedYet
+           kernel_size = _single(kernel_size)
+           if !stride.nil?
+             _stride = _single(stride)
+           else
+             _stride = kernel_size
+           end
+           padding = _single(padding)
+           output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
+           output_size = output_size + [1]
+           NN.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), output_size).squeeze(3)
+         end
+
+         def max_unpool2d(*args, **options)
+           raise NotImplementedYet
+           NN.max_unpool2d(*args, **options)
+         end
+
+         def max_unpool3d(*args, **options)
+           raise NotImplementedYet
+           NN.max_unpool3d(*args, **options)
+         end
+
+         def avg_pool1d(*args, **options)
+           Torch.avg_pool1d(*args, **options)
+         end
+
+         def avg_pool2d(*args, **options)
+           NN.avg_pool2d(*args, **options)
+         end
+
+         def avg_pool3d(*args, **options)
+           NN.avg_pool3d(*args, **options)
+         end
+
+         # padding layers
+
+         def pad(input, pad, mode: "constant", value: 0)
+           raise ArgumentError, "Padding length must be divisible by 2" unless pad.size % 2 == 0
+           raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
+
+           if mode == "constant"
+             return Torch.constant_pad_nd(input, pad, value)
+           else
+             raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0
+
+             if input.dim == 3
+               raise ArgumentError, "3D tensors expect 2 values for padding" unless pad.size == 2
+               case mode
+               when "reflect"
+                 NN.reflection_pad1d(input, pad)
+               when "replicate"
+                 NN.replication_pad1d(input, pad)
+               else
+                 raise NotImplementedYet
+               end
+             elsif input.dim == 4
+               raise ArgumentError, "4D tensors expect 4 values for padding" unless pad.size == 4
+               case mode
+               when "reflect"
+                 NN.reflection_pad2d(input, pad)
+               when "replicate"
+                 NN.replication_pad2d(input, pad)
+               else
+                 raise NotImplementedYet
+               end
+             elsif input.dim == 5
+               raise ArgumentError, "5D tensors expect 6 values for padding" unless pad.size == 6
+               case mode
+               when "replicate"
+                 NN.replication_pad3d(input, pad)
+               else
+                 raise NotImplementedYet
+               end
+             else
+               raise ArgumentError, "Only 3D, 4D, 5D padding with non-constant padding are supported for now"
+             end
+           end
+         end
+
+         # activation layers
+
+         def hardshrink(input, lambd = 0.5)
+           Torch.hardshrink(input, lambd)
+         end
+
+         def leaky_relu(input, negative_slope = 0.01)
+           NN.leaky_relu(input, negative_slope)
+         end
+
+         def log_sigmoid(input)
+           NN.log_sigmoid(input)
+         end
+
+         def prelu(input, weight)
+           Torch.prelu(input, weight)
+         end
+
+         def relu(input, inplace: false)
+           if inplace
+             input.relu!
+           else
+             input.relu
+           end
+         end
+
+         def softplus(input, beta: 1, threshold: 20)
+           NN.softplus(input, beta, threshold)
+         end
+
+         def softshrink(*args, **options)
+           NN.softshrink(*args, **options)
+         end
+
+         def softsign(input)
+           input / (input.abs + 1)
+         end
+
+         def tanhshrink(input)
+           input - input.tanh
+         end
+
+         # other activation layers
+
+         def softmin(input, dim: nil)
+           dim ||= softmax_dim(input.dim)
+           (-input).softmax(dim)
+         end
+
+         def softmax(input, dim: nil)
+           dim ||= softmax_dim(input.dim)
+           input.softmax(dim)
+         end
+
+         # TODO make dim keyword argument and update examples
+         def log_softmax(input, dim = nil)
+           dim ||= softmax_dim(input.dim)
+           input.log_softmax(dim)
+         end
+
+         # normalization layers
+
+         def batch_norm(input, running_mean, running_var, weight: nil, bias: nil,
+           training: false, momentum: 0.1, eps: 1e-5)
+
+           if training
+             size = input.size
+             size_prods = size[0]
+             (size.length - 2).times do |i|
+               size_prods *= size[i + 2]
+             end
+             if size_prods == 1
+               raise ArgumentError, "Expected more than 1 value per channel when training, got input size #{size.inspect}"
+             end
+           end
+
+           Torch.batch_norm(
+             input, weight, bias, running_mean, running_var,
+             training, momentum, eps, false
+           )
          end

-         def conv2d(input, weight, bias)
-           Torch.conv2d(input, weight, bias)
+         def group_norm(input, num_groups, weight: nil, bias: nil, eps: 1e-5)
+           Torch.group_norm(input, num_groups, weight, bias, eps, false)
          end

-         def max_pool2d(input, kernel_size)
-           kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
-           Torch.max_pool2d(input, kernel_size)
+         def instance_norm(input, running_mean: nil, running_var: nil, weight: nil,
+           bias: nil, use_input_stats: true, momentum: 0.1, eps: 1e-5)
+
+           Torch.instance_norm(
+             input, weight, bias, running_mean, running_var,
+             use_input_stats, momentum, eps, false
+           )
+         end
+
+         def layer_norm(input, normalized_shape, weight: nil, bias: nil, eps: 1e-5)
+           Torch.layer_norm(input, normalized_shape, weight, bias, eps, false)
+         end
+
+         def local_response_norm(input, size, alpha: 1e-4, beta: 0.75, k: 1.0)
+           dim = input.dim
+           if dim < 3
+             raise ArgumentError, "Expected 3D or higher dimensionality input (got #{dim} dimensions)"
+           end
+           div = input.mul(input).unsqueeze(1)
+           if dim == 3
+             div = pad(div, [0, 0, size / 2, (size - 1) / 2])
+             div = avg_pool2d(div, [size, 1], stride: 1).squeeze(1)
+           else
+             sizes = input.size
+             div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
+             div = pad(div, [0, 0, 0, 0, size / 2, (size - 1) / 2])
+             div = avg_pool3d(div, [size, 1, 1], stride: 1).squeeze(1)
+             div = div.view(sizes)
+           end
+           div = div.mul(alpha).add(k).pow(beta)
+           input / div
          end

+         # linear layers
+
          def linear(input, weight, bias)
-           Torch.linear(input, weight, bias)
+           NN.linear(input, weight, bias)
+         end
+
+         def bilinear(input1, input2, weight, bias)
+           Torch.bilinear(input1, input2, weight, bias)
+         end
+
+         # dropout layers
+
+         def dropout(input, p: 0.5, training: true, inplace: false)
+           if inplace
+             Torch.dropout!(input, p, training)
+           else
+             Torch.dropout(input, p, training)
+           end
+         end
+
+         def dropout2d(input, p: 0.5, training: true, inplace: false)
+           raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
+
+           if inplace
+             Torch.feature_dropout!(input, p, training)
+           else
+             Torch.feature_dropout(input, p, training)
+           end
+         end
+
+         def dropout3d(input, p: 0.5, training: true, inplace: false)
+           if inplace
+             Torch.feature_dropout!(input, p, training)
+           else
+             Torch.feature_dropout(input, p, training)
+           end
+         end
+
+         def alpha_dropout(input, p: 0.5, training: true, inplace: false)
+           if inplace
+             Torch.alpha_dropout!(input, p, training)
+           else
+             Torch.alpha_dropout(input, p, training)
+           end
+         end
+
+         def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
+           if inplace
+             Torch.feature_alpha_dropout!(input, p, training)
+           else
+             Torch.feature_alpha_dropout(input, p, training)
+           end
+         end
+
+         # sparse layers
+
+         def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false)
+           # TODO handle max_norm and norm_type
+           raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+           padding_idx ||= -1
+           # weight and indices are swapped from Python interface
+           Torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
+         end
+
+         def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
+           # TODO handle max_norm and norm_type
+           raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+           mode_enum =
+             case mode
+             when "sum"
+               0
+             when "mean"
+               1
+             when "max"
+               2
+             else
+               raise ArgumentError, "Unknown mode: #{mode}"
+             end
+
+           # weight and input swapped
+           Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+         end
+
+         # distance functions
+
+         def cosine_similarity(x1, x2, dim: 1, eps: 1e-8)
+           Torch.cosine_similarity(x1, x2, dim, eps)
+         end
+
+         def pairwise_distance(x1, x2, p: 2.0, eps: 1e-6, keepdim: false)
+           Torch.pairwise_distance(x1, x2, p, eps, keepdim)
+         end
+
+         # loss functions
+
+         def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
+           NN.binary_cross_entropy(input, target, weight, reduction)
+         end
+
+         def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
+           Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+         end
+
+         def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
+           raise NotImplementedYet
+         end
+
+         def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+           nll_loss(log_softmax(input, 1), target, weight: weight, ignore_index: ignore_index, reduction: reduction)
+         end
+
+         def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
+           # call to_a on input_lengths and target_lengths for C++
+           Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+         end
+
+         def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
+           Torch.hinge_embedding_loss(input, target, margin, reduction)
+         end
+
+         def kl_div(input, target, reduction: "mean")
+           Torch.kl_div(input, target, reduction)
+         end
+
+         def l1_loss(input, target, reduction: "mean")
+           NN.l1_loss(input, target, reduction)
+         end
+
+         def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
+           raise NotImplementedYet
          end

          def mse_loss(input, target, reduction: "mean")
-           Torch.mse_loss(input, target, reduction)
+           NN.mse_loss(input, target, reduction)
          end

-         def cross_entropy(input, target)
-           nll_loss(log_softmax(input, 1), target)
+         def multilabel_margin_loss(input, target, reduction: "mean")
+           NN.multilabel_margin_loss(input, target, reduction)
          end

-         def nll_loss(input, target)
-           # TODO fix for non-1d
-           Torch.nll_loss(input, target)
+         def multilabel_soft_margin_loss(input, target, weight: nil)
+           raise NotImplementedYet
          end

-         def log_softmax(input, dim)
-           input.log_softmax(dim)
+         def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
+           NN.multi_margin_loss(input, target, p, margin, weight, reduction)
+         end
+
+         def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+           NN.nll_loss(input, target, weight, reduction, ignore_index)
+         end
+
+         def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
+           Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
+         end
+
+         def soft_margin_loss(input, target, reduction: "mean")
+           NN.soft_margin_loss(input, target, reduction)
+         end
+
+         def smooth_l1_loss(input, target, reduction: "mean")
+           NN.smooth_l1_loss(input, target, reduction)
+         end
+
+         def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
+           Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+         end
+
+         private
+
+         def softmax_dim(ndim)
+           ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
          end
        end
      end
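
Finally, a hedged sketch (not in the diff) of calling the expanded Functional module directly; inside the modules above it is referenced via the F alias.

# Hedged example: dims and probabilities chosen for illustration only.
x = Torch.randn(2, 5)
Torch::NN::Functional.relu(x)                            # element-wise ReLU
Torch::NN::Functional.softmax(x, dim: 1)                 # softmax over the class dimension
Torch::NN::Functional.dropout(x, p: 0.5, training: true) # functional dropout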