torch-rb 0.1.2 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +18 -6
  5. data/ext/torch/ext.cpp +148 -369
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +242 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +240 -131
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +27 -22
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -38
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +411 -22
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +201 -20
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +2 -2
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +56 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +48 -16
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +71 -30
  139. data/lib/torch/utils/data/data_loader.rb +10 -4
  140. data/lib/torch/utils/data/tensor_dataset.rb +3 -0
  141. data/lib/torch/version.rb +1 -1
  142. metadata +123 -6

data/lib/torch/nn/conv1d.rb
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class Conv1d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _single(kernel_size)
+        stride = _single(stride)
+        padding = _single(padding)
+        dilation = _single(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _single(0), groups, bias, padding_mode)
+      end
+
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
+        end
+        F.conv1d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
+      end
+    end
+  end
+end

data/lib/torch/nn/conv2d.rb
@@ -1,48 +1,26 @@
 module Torch
   module NN
-    class Conv2d < Module
-      attr_reader :bias, :weight
-
-      def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0) #, dilation: 1, groups: 1)
-        @in_channels = in_channels
-        @out_channels = out_channels
-        @kernel_size = pair(kernel_size)
-        @stride = pair(stride)
-        @padding = pair(padding)
-        # @dilation = pair(dilation)
-
-        # TODO divide by groups
-        @weight = Parameter.new(Tensor.new(out_channels, in_channels, *@kernel_size))
-        @bias = Parameter.new(Tensor.new(out_channels))
-
-        reset_parameters
+    class Conv2d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _pair(0), groups, bias, padding_mode)
       end
 
-      def reset_parameters
-        Init.kaiming_uniform_(@weight, Math.sqrt(5))
-        if @bias
-          fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
-          bound = 1 / Math.sqrt(fan_in)
-          Init.uniform_(@bias, -bound, bound)
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
         end
+        F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
       end
 
-      def call(input)
-        F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding) #, @dilation, @groups)
-      end
-
-      def inspect
-        "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
-      end
-
-      private
-
-      def pair(value)
-        if value.is_a?(Array)
-          value
-        else
-          [value] * 2
-        end
+      # TODO add more parameters
+      def extra_inspect
+        format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
       end
     end
   end

data/lib/torch/nn/conv3d.rb
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class Conv3d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _triple(kernel_size)
+        stride = _triple(stride)
+        padding = _triple(padding)
+        dilation = _triple(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _triple(0), groups, bias, padding_mode)
+      end
+
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
+        end
+        F.conv3d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
+      end
+    end
+  end
+end

data/lib/torch/nn/convnd.rb
@@ -0,0 +1,41 @@
+module Torch
+  module NN
+    class ConvNd < Module
+      def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
+        super()
+        raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
+        raise ArgumentError, "out_channels must be divisible by groups" if out_channels % groups != 0
+        @in_channels = in_channels
+        @out_channels = out_channels
+        @kernel_size = kernel_size
+        @stride = stride
+        @padding = padding
+        @dilation = dilation
+        @transposed = transposed
+        @output_padding = output_padding
+        @groups = groups
+        @padding_mode = padding_mode
+        if transposed
+          @weight = Parameter.new(Tensor.new(in_channels, out_channels / groups, *kernel_size))
+        else
+          @weight = Parameter.new(Tensor.new(out_channels, in_channels / groups, *kernel_size))
+        end
+        if bias
+          @bias = Parameter.new(Tensor.new(out_channels))
+        else
+          raise NotImplementedError
+        end
+        reset_parameters
+      end
+
+      def reset_parameters
+        Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
+        if @bias
+          fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
+          bound = 1 / Math.sqrt(fan_in)
+          Init.uniform!(@bias, a: -bound, b: bound)
+        end
+      end
+    end
+  end
+end
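
Taken together, the ConvNd base class and the Conv1d/Conv2d/Conv3d subclasses above replace the old hand-rolled Conv2d. A minimal, hypothetical usage sketch follows; it assumes Torch.randn accepts a dimension list and that NN::Module#call dispatches to #forward, as the reworked module.rb suggests:

    require "torch"

    # 3 input channels, 8 output channels, 3x3 kernel
    conv = Torch::NN::Conv2d.new(3, 8, 3, padding: 1)
    x = Torch.randn([1, 3, 28, 28])
    y = conv.call(x)   # assumed to forward to F.conv2d
    p y.shape          # expected: [1, 8, 28, 28]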

data/lib/torch/nn/cosine_embedding_loss.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CosineEmbeddingLoss < Loss
+      def initialize(margin: 0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input1, input2, target)
+        F.cosine_embedding_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end

data/lib/torch/nn/cosine_similarity.rb
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class CosineSimilarity < Module
+      def initialize(dim: 1, eps: 1e-8)
+        super()
+        @dim = dim
+        @eps = eps
+      end
+
+      def forward(x1, x2)
+        F.cosine_similarity(x1, x2, dim: @dim, eps: @eps)
+      end
+    end
+  end
+end

data/lib/torch/nn/cross_entropy_loss.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CrossEntropyLoss < WeightedLoss
+      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
+        super(weight, reduction)
+        @ignore_index = ignore_index
+      end
+
+      def forward(input, target)
+        F.cross_entropy(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
+      end
+    end
+  end
+end
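
The loss modules introduced here wrap the corresponding Functional methods. A hedged sketch of CrossEntropyLoss based on the forward(input, target) signature above, assuming Torch.randn/Torch.tensor behave as in the gem's README and that #call dispatches to #forward:

    criterion = Torch::NN::CrossEntropyLoss.new
    logits = Torch.randn([3, 5])        # batch of 3, 5 classes
    target = Torch.tensor([0, 2, 4])    # integer class indices
    loss = criterion.call(logits, target)
    p loss                              # mean reduction by default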

data/lib/torch/nn/ctc_loss.rb
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class CTCLoss < Loss
+      def initialize(blank: 0, reduction: "mean", zero_infinity: false)
+        super(reduction)
+        @blank = blank
+        @zero_infinity = zero_infinity
+      end
+
+      def forward(log_probs, targets, input_lengths, target_lengths)
+        F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: @blank, reduction: @reduction, zero_infinity: @zero_infinity)
+      end
+    end
+  end
+end

data/lib/torch/nn/dropout.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class Dropout < DropoutNd
+      def forward(input)
+        F.dropout(input, p: @p, training: @training, inplace: @inplace)
+      end
+    end
+  end
+end

data/lib/torch/nn/dropout2d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class Dropout2d < DropoutNd
+      def forward(input)
+        F.dropout2d(input, p: @p, training: @training, inplace: @inplace)
+      end
+    end
+  end
+end

data/lib/torch/nn/dropout3d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class Dropout3d < DropoutNd
+      def forward(input)
+        F.dropout3d(input, p: @p, training: @training, inplace: @inplace)
+      end
+    end
+  end
+end

data/lib/torch/nn/dropoutnd.rb
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class DropoutNd < Module
+      def initialize(p: 0.5, inplace: false)
+        super()
+        @p = p
+        @inplace = inplace
+      end
+
+      def extra_inspect
+        format("p: %s, inplace: %s", @p, @inplace)
+      end
+    end
+  end
+end

data/lib/torch/nn/embedding.rb
@@ -0,0 +1,52 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
+module Torch
+  module NN
+    class Embedding < Module
+      def initialize(num_embeddings, embedding_dim, padding_idx: nil, max_norm: nil,
+        norm_type: 2.0, scale_grad_by_freq: false, sparse: false, _weight: nil)
+
+        super()
+        @num_embeddings = num_embeddings
+        @embedding_dim = embedding_dim
+
+        if padding_idx
+          if padding_idx > 0
+            raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx < @num_embeddings
+          elsif padding_idx < 0
+            raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx >= -@num_embeddings
+            padding_idx = @num_embeddings + padding_idx
+          end
+        end
+        @padding_idx = padding_idx
+        @max_norm = max_norm
+        @norm_type = norm_type
+        @scale_grad_by_freq = scale_grad_by_freq
+        if _weight.nil?
+          @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
+          reset_parameters
+        else
+          raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
+          @weight = Parameter.new(_weight)
+        end
+        @sparse = sparse
+      end
+
+      def reset_parameters
+        Init.normal!(@weight)
+        if @padding_idx
+          Torch.no_grad do
+            @weight[@padding_idx].fill!(0)
+          end
+        end
+      end
+
+      def forward(input)
+        F.embedding(input, @weight, padding_idx: @padding_idx, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, sparse: @sparse)
+      end
+
+      def inspect
+        "Embedding(#{@num_embeddings}, #{@embedding_dim})"
+      end
+    end
+  end
+end
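
A short, hypothetical example of the new Embedding module (indices must be an integer tensor; Torch.tensor with nested integer arrays is assumed to produce one, and #call is assumed to dispatch to #forward):

    emb = Torch::NN::Embedding.new(10, 4, padding_idx: 0)
    ids = Torch.tensor([[1, 2, 0], [4, 5, 0]])   # 2 sequences of length 3
    out = emb.call(ids)                          # forwards to F.embedding
    p out.shape                                  # expected: [2, 3, 4]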

data/lib/torch/nn/embedding_bag.rb
@@ -0,0 +1,34 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
+module Torch
+  module NN
+    class EmbeddingBag < Module
+      def initialize(num_embeddings, embedding_dim, max_norm: nil, norm_type: 2.0,
+        scale_grad_by_freq: false, mode: "mean", sparse: false, _weight: nil)
+
+        super()
+        @num_embeddings = num_embeddings
+        @embedding_dim = embedding_dim
+        @max_norm = max_norm
+        @norm_type = norm_type
+        @scale_grad_by_freq = scale_grad_by_freq
+        if _weight.nil?
+          @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
+          reset_parameters
+        else
+          raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
+          @weight = Parameter.new(_weight)
+        end
+        @mode = mode
+        @sparse = sparse
+      end
+
+      def reset_parameters
+        Init.normal!(@weight)
+      end
+
+      def forward(input, offsets: nil, per_sample_weights: nil)
+        F.embedding_bag(input, @weight, offsets: offsets, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, mode: @mode, sparse: @sparse, per_sample_weights: per_sample_weights)
+      end
+    end
+  end
+end

data/lib/torch/nn/feature_alpha_dropout.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class FeatureAlphaDropout < DropoutNd
+      def forward(input)
+        F.feature_alpha_dropout(input, p: @p, training: @training, inplace: @inplace)
+      end
+    end
+  end
+end

data/lib/torch/nn/fold.rb
@@ -0,0 +1,20 @@
+module Torch
+  module NN
+    class Fold < Module
+      def initialize(output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+        super()
+        @output_size = output_size
+        @kernel_size = kernel_size
+        @dilation = dilation
+        @padding = padding
+        @stride = stride
+      end
+
+      def forward(input)
+        F.fold(input, @output_size, @kernel_size, dilation: @dilation, padding: @padding, stride: @stride)
+      end
+
+      # TODO add extra_inspect
+    end
+  end
+end

data/lib/torch/nn/functional.rb
@@ -2,52 +2,441 @@ module Torch
   module NN
     class Functional
       class << self
-        def relu(input)
-          Torch.relu(input)
+        include Utils
+
+        # convolution layers
+
+        def conv1d(*args, **options)
+          Torch.conv1d(*args, **options)
+        end
+
+        def conv2d(*args, **options)
+          Torch.conv2d(*args, **options)
+        end
+
+        def conv3d(*args, **options)
+          Torch.conv3d(*args, **options)
+        end
+
+        def unfold(input, kernel_size, dilation: 1, padding: 0, stride: 1)
+          if input.dim == 4
+            NN.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+          else
+            raise Error, "Input Error: Only 4D input Tensors are supported (got #{input.dim}D)"
+          end
+        end
+
+        def fold(input, output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+          if input.dim == 3
+            NN.col2im(input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+          else
+            raise Error, "Input Error: Only 3D input Tensors are supported (got #{input.dim}D)"
+          end
+        end
+
+        # pooling layers
+
+        def max_pool1d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            Torch.max_pool1d_with_indices(*args, **options)
+          else
+            Torch.max_pool1d(*args, **options)
+          end
+        end
+
+        def max_pool2d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            NN.max_pool2d_with_indices(*args, **options)
+          else
+            Torch.max_pool2d(*args, **options)
+          end
+        end
+
+        def max_pool3d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            NN.max_pool3d_with_indices(*args, **options)
+          else
+            Torch.max_pool3d(*args, **options)
+          end
+        end
+
+        def max_unpool1d(input, indices, kernel_size, stride: nil, padding: 0, output_size: nil)
+          raise NotImplementedYet
+          kernel_size = _single(kernel_size)
+          if !stride.nil?
+            _stride = _single(stride)
+          else
+            _stride = kernel_size
+          end
+          padding = _single(padding)
+          output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
+          output_size = output_size + [1]
+          NN.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), output_size).squeeze(3)
+        end
+
+        def max_unpool2d(*args, **options)
+          raise NotImplementedYet
+          NN.max_unpool2d(*args, **options)
+        end
+
+        def max_unpool3d(*args, **options)
+          raise NotImplementedYet
+          NN.max_unpool3d(*args, **options)
+        end
+
+        def avg_pool1d(*args, **options)
+          Torch.avg_pool1d(*args, **options)
+        end
+
+        def avg_pool2d(*args, **options)
+          NN.avg_pool2d(*args, **options)
+        end
+
+        def avg_pool3d(*args, **options)
+          NN.avg_pool3d(*args, **options)
+        end
+
+        # padding layers
+
+        def pad(input, pad, mode: "constant", value: 0)
+          raise ArgumentError, "Padding length must be divisible by 2" unless pad.size % 2 == 0
+          raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
+
+          if mode == "constant"
+            return Torch.constant_pad_nd(input, pad, value)
+          else
+            raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0
+
+            if input.dim == 3
+              raise ArgumentError, "3D tensors expect 2 values for padding" unless pad.size == 2
+              case mode
+              when "reflect"
+                NN.reflection_pad1d(input, pad)
+              when "replicate"
+                NN.replication_pad1d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            elsif input.dim == 4
+              raise ArgumentError, "4D tensors expect 4 values for padding" unless pad.size == 4
+              case mode
+              when "reflect"
+                NN.reflection_pad2d(input, pad)
+              when "replicate"
+                NN.replication_pad2d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            elsif input.dim == 5
+              raise ArgumentError, "5D tensors expect 6 values for padding" unless pad.size == 6
+              case mode
+              when "replicate"
+                NN.replication_pad3d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            else
+              raise ArgumentError, "Only 3D, 4D, 5D padding with non-constant padding are supported for now"
+            end
+          end
+        end
+
+        # activation layers
+
+        def hardshrink(input, lambd = 0.5)
+          Torch.hardshrink(input, lambd)
+        end
+
+        def leaky_relu(input, negative_slope = 0.01)
+          NN.leaky_relu(input, negative_slope)
         end
 
-        def conv2d(input, weight, bias, stride: 1, padding: 0)
-          # TODO pair stride and padding when needed
-          Torch.conv2d(input, weight, bias, stride, padding)
+        def log_sigmoid(input)
+          NN.log_sigmoid(input)
         end
 
         def prelu(input, weight)
           Torch.prelu(input, weight)
         end
 
-        def leaky_relu(input, negative_slope = 0.01)
-          Torch.leaky_relu(input, negative_slope)
+        def relu(input, inplace: false)
+          if inplace
+            input.relu!
+          else
+            input.relu
+          end
+        end
+
+        def softplus(input, beta: 1, threshold: 20)
+          NN.softplus(input, beta, threshold)
+        end
+
+        def softshrink(*args, **options)
+          NN.softshrink(*args, **options)
+        end
+
+        def softsign(input)
+          input / (input.abs + 1)
+        end
+
+        def tanhshrink(input)
+          input - input.tanh
+        end
+
+        # other activation layers
+
+        def softmin(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          (-input).softmax(dim)
+        end
+
+        def softmax(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          input.softmax(dim)
+        end
+
+        # TODO make dim keyword argument and update examples
+        def log_softmax(input, dim = nil)
+          dim ||= softmax_dim(input.dim)
+          input.log_softmax(dim)
+        end
+
+        # normalization layers
+
+        def batch_norm(input, running_mean, running_var, weight: nil, bias: nil,
+          training: false, momentum: 0.1, eps: 1e-5)
+
+          if training
+            size = input.size
+            size_prods = size[0]
+            (size.length - 2).times do |i|
+              size_prods *= size[i + 2]
+            end
+            if size_prods == 1
+              raise ArgumentError, "Expected more than 1 value per channel when training, got input size #{size.inspect}"
+            end
+          end
+
+          Torch.batch_norm(
+            input, weight, bias, running_mean, running_var,
+            training, momentum, eps, false
+          )
+        end
+
+        def group_norm(input, num_groups, weight: nil, bias: nil, eps: 1e-5)
+          Torch.group_norm(input, num_groups, weight, bias, eps, false)
         end
 
-        def max_pool2d(input, kernel_size)
-          kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
-          Torch.max_pool2d(input, kernel_size)
+        def instance_norm(input, running_mean: nil, running_var: nil, weight: nil,
+          bias: nil, use_input_stats: true, momentum: 0.1, eps: 1e-5)
+
+          Torch.instance_norm(
+            input, weight, bias, running_mean, running_var,
+            use_input_stats, momentum, eps, false
+          )
+        end
+
+        def layer_norm(input, normalized_shape, weight: nil, bias: nil, eps: 1e-5)
+          Torch.layer_norm(input, normalized_shape, weight, bias, eps, false)
         end
 
-        def avg_pool2d(input, kernel_size)
-          kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
-          Torch.avg_pool2d(input, kernel_size)
+        def local_response_norm(input, size, alpha: 1e-4, beta: 0.75, k: 1.0)
+          dim = input.dim
+          if dim < 3
+            raise ArgumentError, "Expected 3D or higher dimensionality input (got #{dim} dimensions)"
+          end
+          div = input.mul(input).unsqueeze(1)
+          if dim == 3
+            div = pad(div, [0, 0, size / 2, (size - 1) / 2])
+            div = avg_pool2d(div, [size, 1], stride: 1).squeeze(1)
+          else
+            sizes = input.size
+            div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
+            div = pad(div, [0, 0, 0, 0, size / 2, (size - 1) / 2])
+            div = avg_pool3d(div, [size, 1, 1], stride: 1).squeeze(1)
+            div = div.view(sizes)
+          end
+          div = div.mul(alpha).add(k).pow(beta)
+          input / div
         end
 
+        # linear layers
+
         def linear(input, weight, bias)
-          Torch.linear(input, weight, bias)
+          NN.linear(input, weight, bias)
+        end
+
+        def bilinear(input1, input2, weight, bias)
+          Torch.bilinear(input1, input2, weight, bias)
+        end
+
+        # dropout layers
+
+        def dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.dropout!(input, p, training)
+          else
+            Torch.dropout(input, p, training)
+          end
+        end
+
+        def dropout2d(input, p: 0.5, training: true, inplace: false)
+          raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
+
+          if inplace
+            Torch.feature_dropout!(input, p, training)
+          else
+            Torch.feature_dropout(input, p, training)
+          end
+        end
+
+        def dropout3d(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.feature_dropout!(input, p, training)
+          else
+            Torch.feature_dropout(input, p, training)
+          end
+        end
+
+        def alpha_dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.alpha_dropout!(input, p, training)
+          else
+            Torch.alpha_dropout(input, p, training)
+          end
+        end
+
+        def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch.feature_alpha_dropout!(input, p, training)
+          else
+            Torch.feature_alpha_dropout(input, p, training)
+          end
+        end
+
+        # sparse layers
+
+        def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false)
+          # TODO handle max_norm and norm_type
+          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+          padding_idx ||= -1
+          # weight and indices are swapped from Python interface
+          Torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
+        end
+
+        def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
+          # TODO handle max_norm and norm_type
+          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+          mode_enum =
+            case mode
+            when "sum"
+              0
+            when "mean"
+              1
+            when "max"
+              2
+            else
+              raise ArgumentError, "Unknown mode: #{mode}"
+            end
+
+          # weight and input swapped
+          Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+        end
+
+        # distance functions
+
+        def cosine_similarity(x1, x2, dim: 1, eps: 1e-8)
+          Torch.cosine_similarity(x1, x2, dim, eps)
+        end
+
+        def pairwise_distance(x1, x2, p: 2.0, eps: 1e-6, keepdim: false)
+          Torch.pairwise_distance(x1, x2, p, eps, keepdim)
+        end
+
+        # loss functions
+
+        def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
+          NN.binary_cross_entropy(input, target, weight, reduction)
+        end
+
+        def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
+          Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+        end
+
+        def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
+        end
+
+        def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          nll_loss(log_softmax(input, 1), target, weight: weight, ignore_index: ignore_index, reduction: reduction)
+        end
+
+        def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
+          # call to_a on input_lengths and target_lengths for C++
+          Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+        end
+
+        def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
+          Torch.hinge_embedding_loss(input, target, margin, reduction)
+        end
+
+        def kl_div(input, target, reduction: "mean")
+          Torch.kl_div(input, target, reduction)
+        end
+
+        def l1_loss(input, target, reduction: "mean")
+          NN.l1_loss(input, target, reduction)
+        end
+
+        def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
         end
 
         def mse_loss(input, target, reduction: "mean")
-          Torch.mse_loss(input, target, reduction)
+          NN.mse_loss(input, target, reduction)
         end
 
-        def cross_entropy(input, target)
-          nll_loss(log_softmax(input, 1), target)
+        def multilabel_margin_loss(input, target, reduction: "mean")
+          NN.multilabel_margin_loss(input, target, reduction)
         end
 
-        def nll_loss(input, target)
-          # TODO fix for non-1d
-          Torch.nll_loss(input, target)
+        def multilabel_soft_margin_loss(input, target, weight: nil)
+          raise NotImplementedYet
         end
 
-        def log_softmax(input, dim)
-          input.log_softmax(dim)
+        def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
+          NN.multi_margin_loss(input, target, p, margin, weight, reduction)
+        end
+
+        def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          NN.nll_loss(input, target, weight, reduction, ignore_index)
+        end
+
+        def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
+          Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
+        end
+
+        def soft_margin_loss(input, target, reduction: "mean")
+          NN.soft_margin_loss(input, target, reduction)
+        end
+
+        def smooth_l1_loss(input, target, reduction: "mean")
+          NN.smooth_l1_loss(input, target, reduction)
+        end
+
+        def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
+          Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+        end
+
+        private
+
+        def softmax_dim(ndim)
+          ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
         end
       end
     end
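
Finally, a hedged sketch of calling the expanded Functional interface directly, based on the method signatures added above (Torch.randn with a dimension list is assumed to be available):

    F = Torch::NN::Functional

    x = Torch.randn([2, 5])
    y = F.relu(x)                              # inplace: false by default
    d = F.dropout(x, p: 0.5, training: true)
    s = F.softmax(x, dim: 1)                   # dim chosen via softmax_dim when omitted
    loss = F.mse_loss(y, Torch.randn([2, 5]))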