torch-rb 0.1.3 → 0.1.8

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (115)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +30 -0
  3. data/README.md +5 -2
  4. data/ext/torch/ext.cpp +130 -555
  5. data/ext/torch/extconf.rb +9 -0
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +244 -0
  8. data/lib/torch.rb +209 -171
  9. data/lib/torch/inspector.rb +23 -19
  10. data/lib/torch/native/dispatcher.rb +48 -0
  11. data/lib/torch/native/function.rb +110 -0
  12. data/lib/torch/native/generator.rb +168 -0
  13. data/lib/torch/native/native_functions.yaml +6491 -0
  14. data/lib/torch/native/parser.rb +134 -0
  15. data/lib/torch/nn/avg_pool1d.rb +18 -0
  16. data/lib/torch/nn/avg_pool2d.rb +19 -0
  17. data/lib/torch/nn/avg_pool3d.rb +19 -0
  18. data/lib/torch/nn/avg_poolnd.rb +9 -0
  19. data/lib/torch/nn/batch_norm.rb +75 -0
  20. data/lib/torch/nn/batch_norm1d.rb +11 -0
  21. data/lib/torch/nn/batch_norm2d.rb +11 -0
  22. data/lib/torch/nn/batch_norm3d.rb +11 -0
  23. data/lib/torch/nn/bce_loss.rb +13 -0
  24. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  25. data/lib/torch/nn/bilinear.rb +38 -0
  26. data/lib/torch/nn/constant_pad1d.rb +10 -0
  27. data/lib/torch/nn/constant_pad2d.rb +10 -0
  28. data/lib/torch/nn/constant_pad3d.rb +10 -0
  29. data/lib/torch/nn/constant_padnd.rb +18 -0
  30. data/lib/torch/nn/conv1d.rb +22 -0
  31. data/lib/torch/nn/conv2d.rb +10 -20
  32. data/lib/torch/nn/conv3d.rb +22 -0
  33. data/lib/torch/nn/convnd.rb +3 -3
  34. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  35. data/lib/torch/nn/cosine_similarity.rb +15 -0
  36. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  37. data/lib/torch/nn/ctc_loss.rb +15 -0
  38. data/lib/torch/nn/dropoutnd.rb +2 -2
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/fold.rb +20 -0
  41. data/lib/torch/nn/functional.rb +379 -32
  42. data/lib/torch/nn/group_norm.rb +36 -0
  43. data/lib/torch/nn/gru.rb +49 -0
  44. data/lib/torch/nn/hardshrink.rb +18 -0
  45. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  46. data/lib/torch/nn/identity.rb +14 -0
  47. data/lib/torch/nn/init.rb +58 -1
  48. data/lib/torch/nn/instance_norm.rb +20 -0
  49. data/lib/torch/nn/instance_norm1d.rb +18 -0
  50. data/lib/torch/nn/instance_norm2d.rb +11 -0
  51. data/lib/torch/nn/instance_norm3d.rb +11 -0
  52. data/lib/torch/nn/kl_div_loss.rb +13 -0
  53. data/lib/torch/nn/l1_loss.rb +13 -0
  54. data/lib/torch/nn/layer_norm.rb +35 -0
  55. data/lib/torch/nn/leaky_relu.rb +20 -0
  56. data/lib/torch/nn/linear.rb +12 -11
  57. data/lib/torch/nn/local_response_norm.rb +21 -0
  58. data/lib/torch/nn/log_sigmoid.rb +9 -0
  59. data/lib/torch/nn/log_softmax.rb +14 -0
  60. data/lib/torch/nn/loss.rb +10 -0
  61. data/lib/torch/nn/lp_pool1d.rb +9 -0
  62. data/lib/torch/nn/lp_pool2d.rb +9 -0
  63. data/lib/torch/nn/lp_poolnd.rb +22 -0
  64. data/lib/torch/nn/lstm.rb +66 -0
  65. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  66. data/lib/torch/nn/max_pool1d.rb +9 -0
  67. data/lib/torch/nn/max_pool2d.rb +9 -0
  68. data/lib/torch/nn/max_pool3d.rb +9 -0
  69. data/lib/torch/nn/max_poolnd.rb +19 -0
  70. data/lib/torch/nn/max_unpool1d.rb +16 -0
  71. data/lib/torch/nn/max_unpool2d.rb +16 -0
  72. data/lib/torch/nn/max_unpool3d.rb +16 -0
  73. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  74. data/lib/torch/nn/module.rb +186 -35
  75. data/lib/torch/nn/mse_loss.rb +2 -2
  76. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  77. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  78. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  79. data/lib/torch/nn/nll_loss.rb +14 -0
  80. data/lib/torch/nn/pairwise_distance.rb +16 -0
  81. data/lib/torch/nn/parameter.rb +2 -2
  82. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  83. data/lib/torch/nn/prelu.rb +19 -0
  84. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  85. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  86. data/lib/torch/nn/reflection_padnd.rb +13 -0
  87. data/lib/torch/nn/relu.rb +8 -3
  88. data/lib/torch/nn/replication_pad1d.rb +10 -0
  89. data/lib/torch/nn/replication_pad2d.rb +10 -0
  90. data/lib/torch/nn/replication_pad3d.rb +10 -0
  91. data/lib/torch/nn/replication_padnd.rb +13 -0
  92. data/lib/torch/nn/rnn.rb +22 -0
  93. data/lib/torch/nn/rnn_base.rb +198 -0
  94. data/lib/torch/nn/sequential.rb +1 -10
  95. data/lib/torch/nn/sigmoid.rb +9 -0
  96. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  97. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  98. data/lib/torch/nn/softmax.rb +18 -0
  99. data/lib/torch/nn/softmax2d.rb +10 -0
  100. data/lib/torch/nn/softmin.rb +14 -0
  101. data/lib/torch/nn/softplus.rb +19 -0
  102. data/lib/torch/nn/softshrink.rb +18 -0
  103. data/lib/torch/nn/softsign.rb +9 -0
  104. data/lib/torch/nn/tanh.rb +9 -0
  105. data/lib/torch/nn/tanhshrink.rb +9 -0
  106. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  107. data/lib/torch/nn/unfold.rb +19 -0
  108. data/lib/torch/nn/utils.rb +25 -0
  109. data/lib/torch/nn/weighted_loss.rb +10 -0
  110. data/lib/torch/nn/zero_pad2d.rb +9 -0
  111. data/lib/torch/random.rb +10 -0
  112. data/lib/torch/tensor.rb +51 -44
  113. data/lib/torch/version.rb +1 -1
  114. metadata +98 -6
  115. data/lib/torch/ext.bundle +0 -0

data/lib/torch/nn/conv2d.rb
@@ -1,36 +1,26 @@
 module Torch
   module NN
     class Conv2d < ConvNd
-      attr_reader :bias, :weight
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
 
-      def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
-        kernel_size = pair(kernel_size)
-        stride = pair(stride)
-        padding = pair(padding)
-        dilation = pair(dilation)
-        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, pair(0), groups, bias, padding_mode)
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _pair(0), groups, bias, padding_mode)
       end
 
       def forward(input)
         if @padding_mode == "circular"
           raise NotImplementedError
         end
-        F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding, dilation: @dilation, groups: @groups)
+        F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
       end
 
       # TODO add more parameters
-      def inspect
-        "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
-      end
-
-      private
-
-      def pair(value)
-        if value.is_a?(Array)
-          value
-        else
-          [value] * 2
-        end
+      def extra_inspect
+        format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
       end
     end
   end
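
Usage note (not from the diff): a minimal sketch of the reworked Conv2d layer. It assumes the torch-rb Module API follows PyTorch conventions, i.e. Torch.randn creates a tensor and #call dispatches to #forward; the sizes and variable names are illustrative only.

  require "torch"

  # NCHW input: batch of 2, 3 channels, 8x8 spatial size
  input = Torch.randn(2, 3, 8, 8)

  # kernel_size, stride, padding and dilation accept an integer or a pair (normalized via _pair)
  conv = Torch::NN::Conv2d.new(3, 16, 3, stride: 1, padding: 1)

  output = conv.call(input) # forward delegates to F.conv2d
  # with stride: 1 and padding: 1 the spatial size is preserved: [2, 16, 8, 8]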

data/lib/torch/nn/conv3d.rb
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class Conv3d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1,
+        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+
+        kernel_size = _triple(kernel_size)
+        stride = _triple(stride)
+        padding = _triple(padding)
+        dilation = _triple(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _triple(0), groups, bias, padding_mode)
+      end
+
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
+        end
+        F.conv3d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
+      end
+    end
+  end
+end

data/lib/torch/nn/convnd.rb
@@ -29,11 +29,11 @@ module Torch
       end
 
       def reset_parameters
-        Init.kaiming_uniform!(@weight, Math.sqrt(5))
+        Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
         if @bias
-          fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
+          fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
           bound = 1 / Math.sqrt(fan_in)
-          Init.uniform!(@bias, -bound, bound)
+          Init.uniform!(@bias, a: -bound, b: bound)
         end
       end
     end
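
Usage note (not from the diff): the Init helpers now take keyword arguments. A sketch, allocating an uninitialized tensor with Torch::Tensor.new the same way the new EmbeddingBag module does:

  w = Torch::Tensor.new(3, 5)
  Torch::NN::Init.kaiming_uniform!(w, a: Math.sqrt(5))

  fan_in, _fan_out = Torch::NN::Init._calculate_fan_in_and_fan_out(w)
  bound = 1 / Math.sqrt(fan_in)

  b = Torch::Tensor.new(3)
  Torch::NN::Init.uniform!(b, a: -bound, b: bound)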

data/lib/torch/nn/cosine_embedding_loss.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CosineEmbeddingLoss < Loss
+      def initialize(margin: 0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input1, input2, target)
+        F.cosine_embedding_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end

data/lib/torch/nn/cosine_similarity.rb
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class CosineSimilarity < Module
+      def initialize(dim: 1, eps: 1e-8)
+        super()
+        @dim = dim
+        @eps = eps
+      end
+
+      def forward(x1, x2)
+        F.cosine_similarity(x1, x2, dim: @dim, eps: @eps)
+      end
+    end
+  end
+end

data/lib/torch/nn/cross_entropy_loss.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CrossEntropyLoss < WeightedLoss
+      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
+        super(weight, reduction)
+        @ignore_index = ignore_index
+      end
+
+      def forward(input, target)
+        F.cross_entropy(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
+      end
+    end
+  end
+end
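
Usage note (not from the diff): the new loss classes follow the PyTorch pattern of building a criterion object and applying it to predictions and targets. A sketch, assuming Torch.randn, Torch.tensor and Module#call behave as in the rest of the library:

  predictions = Torch.randn(3, 5)             # unnormalized scores, 3 samples x 5 classes
  targets = Torch.tensor([1, 0, 4])           # class indices
  criterion = Torch::NN::CrossEntropyLoss.new
  loss = criterion.call(predictions, targets) # forward -> F.cross_entropy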

data/lib/torch/nn/ctc_loss.rb
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class CTCLoss < Loss
+      def initialize(blank: 0, reduction: "mean", zero_infinity: false)
+        super(reduction)
+        @blank = blank
+        @zero_infinity = zero_infinity
+      end
+
+      def forward(log_probs, targets, input_lengths, target_lengths)
+        F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: @blank, reduction: @reduction, zero_infinity: @zero_infinity)
+      end
+    end
+  end
+end

data/lib/torch/nn/dropoutnd.rb
@@ -7,8 +7,8 @@ module Torch
         @inplace = inplace
       end
 
-      def inspect
-        "#{self.class.name.split("::").last}(p: #{@p.inspect}, inplace: #{@inplace.inspect})"
+      def extra_inspect
+        format("p: %s, inplace: %s", @p, @inplace)
       end
     end
   end

data/lib/torch/nn/embedding_bag.rb
@@ -0,0 +1,34 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
+module Torch
+  module NN
+    class EmbeddingBag < Module
+      def initialize(num_embeddings, embedding_dim, max_norm: nil, norm_type: 2.0,
+        scale_grad_by_freq: false, mode: "mean", sparse: false, _weight: nil)
+
+        super()
+        @num_embeddings = num_embeddings
+        @embedding_dim = embedding_dim
+        @max_norm = max_norm
+        @norm_type = norm_type
+        @scale_grad_by_freq = scale_grad_by_freq
+        if _weight.nil?
+          @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
+          reset_parameters
+        else
+          raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
+          @weight = Parameter.new(_weight)
+        end
+        @mode = mode
+        @sparse = sparse
+      end
+
+      def reset_parameters
+        Init.normal!(@weight)
+      end
+
+      def forward(input, offsets: nil, per_sample_weights: nil)
+        F.embedding_bag(input, @weight, offsets: offsets, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, mode: @mode, sparse: @sparse, per_sample_weights: per_sample_weights)
+      end
+    end
+  end
+end
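
Usage note (not from the diff): EmbeddingBag pools the embeddings of each bag (mode: "mean" by default). As in PyTorch, input holds the flattened indices and offsets marks where each bag starts; forward is called directly here because offsets is a keyword argument. Torch.tensor is assumed.

  bag = Torch::NN::EmbeddingBag.new(10, 4)        # 10 embeddings of dimension 4
  input = Torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])  # flattened indices for all bags
  offsets = Torch.tensor([0, 4])                  # two bags: indices 0..3 and 4..7
  pooled = bag.forward(input, offsets: offsets)   # expected shape: [2, 4]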

data/lib/torch/nn/fold.rb
@@ -0,0 +1,20 @@
+module Torch
+  module NN
+    class Fold < Module
+      def initialize(output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+        super()
+        @output_size = output_size
+        @kernel_size = kernel_size
+        @dilation = dilation
+        @padding = padding
+        @stride = stride
+      end
+
+      def forward(input)
+        F.fold(input, @output_size, @kernel_size, dilation: @dilation, padding: @padding, stride: @stride)
+      end
+
+      # TODO add extra_inspect
+    end
+  end
+end
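
Usage note (not from the diff): F.unfold extracts sliding local blocks from a 4D tensor (im2col) and Fold puts them back (col2im), with overlapping values summed. A sketch, assuming Torch.randn:

  F = Torch::NN::Functional

  x = Torch.randn(1, 3, 4, 4)
  blocks = F.unfold(x, 2)                 # only 4D input is supported
  fold = Torch::NN::Fold.new([4, 4], 2)
  restored = fold.forward(blocks)         # same shape as x; overlaps are summed, not averaged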

data/lib/torch/nn/functional.rb
@@ -2,51 +2,284 @@ module Torch
   module NN
     class Functional
       class << self
-        def relu(input)
-          Torch.relu(input)
+        include Utils
+
+        # convolution layers
+
+        def conv1d(*args, **options)
+          Torch.conv1d(*args, **options)
         end
 
-        def conv2d(input, weight, bias, stride: 1, padding: 0, dilation: 1, groups: 1)
-          # TODO pair stride and padding when needed
-          Torch.conv2d(input, weight, bias, stride, padding, dilation, groups)
+        def conv2d(*args, **options)
+          Torch.conv2d(*args, **options)
         end
 
-        def max_pool2d(input, kernel_size)
-          kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
-          Torch.max_pool2d(input, kernel_size)
+        def conv3d(*args, **options)
+          Torch.conv3d(*args, **options)
         end
 
-        def avg_pool2d(input, kernel_size)
-          kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
-          Torch.avg_pool2d(input, kernel_size)
+        def unfold(input, kernel_size, dilation: 1, padding: 0, stride: 1)
+          if input.dim == 4
+            NN.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+          else
+            raise Error, "Input Error: Only 4D input Tensors are supported (got #{input.dim}D)"
+          end
         end
 
-        def linear(input, weight, bias)
-          Torch.linear(input, weight, bias)
+        def fold(input, output_size, kernel_size, dilation: 1, padding: 0, stride: 1)
+          if input.dim == 3
+            NN.col2im(input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
+          else
+            raise Error, "Input Error: Only 3D input Tensors are supported (got #{input.dim}D)"
+          end
         end
 
-        def mse_loss(input, target, reduction: "mean")
-          Torch.mse_loss(input, target, reduction)
+        # pooling layers
+
+        def max_pool1d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            Torch.max_pool1d_with_indices(*args, **options)
+          else
+            Torch.max_pool1d(*args, **options)
+          end
+        end
+
+        def max_pool2d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            NN.max_pool2d_with_indices(*args, **options)
+          else
+            Torch.max_pool2d(*args, **options)
+          end
+        end
+
+        def max_pool3d(*args, **options)
+          return_indices = args.pop if args.size == 7
+          if return_indices
+            NN.max_pool3d_with_indices(*args, **options)
+          else
+            Torch.max_pool3d(*args, **options)
+          end
+        end
+
+        def max_unpool1d(input, indices, kernel_size, stride: nil, padding: 0, output_size: nil)
+          raise NotImplementedYet
+          kernel_size = _single(kernel_size)
+          if !stride.nil?
+            _stride = _single(stride)
+          else
+            _stride = kernel_size
+          end
+          padding = _single(padding)
+          output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
+          output_size = output_size + [1]
+          NN.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), output_size).squeeze(3)
+        end
+
+        def max_unpool2d(*args, **options)
+          raise NotImplementedYet
+          NN.max_unpool2d(*args, **options)
+        end
+
+        def max_unpool3d(*args, **options)
+          raise NotImplementedYet
+          NN.max_unpool3d(*args, **options)
         end
 
-        def cross_entropy(input, target)
-          nll_loss(log_softmax(input, 1), target)
+        def avg_pool1d(*args, **options)
+          Torch.avg_pool1d(*args, **options)
         end
 
-        def nll_loss(input, target, reduction: "mean")
-          # TODO fix for non-1d
-          Torch.nll_loss(input, target, reduction)
+        def avg_pool2d(*args, **options)
+          NN.avg_pool2d(*args, **options)
+        end
+
+        def avg_pool3d(*args, **options)
+          NN.avg_pool3d(*args, **options)
+        end
+
+        # padding layers
+
+        def pad(input, pad, mode: "constant", value: 0)
+          raise ArgumentError, "Padding length must be divisible by 2" unless pad.size % 2 == 0
+          raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
+
+          if mode == "constant"
+            return Torch.constant_pad_nd(input, pad, value)
+          else
+            raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0
+
+            if input.dim == 3
+              raise ArgumentError, "3D tensors expect 2 values for padding" unless pad.size == 2
+              case mode
+              when "reflect"
+                NN.reflection_pad1d(input, pad)
+              when "replicate"
+                NN.replication_pad1d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            elsif input.dim == 4
+              raise ArgumentError, "4D tensors expect 4 values for padding" unless pad.size == 4
+              case mode
+              when "reflect"
+                NN.reflection_pad2d(input, pad)
+              when "replicate"
+                NN.replication_pad2d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            elsif input.dim == 5
+              raise ArgumentError, "5D tensors expect 6 values for padding" unless pad.size == 6
+              case mode
+              when "replicate"
+                NN.replication_pad3d(input, pad)
+              else
+                raise NotImplementedYet
+              end
+            else
+              raise ArgumentError, "Only 3D, 4D, 5D padding with non-constant padding are supported for now"
+            end
+          end
         end
 
-        def log_softmax(input, dim)
+        # activation layers
+
+        def hardshrink(input, lambd = 0.5)
+          Torch.hardshrink(input, lambd)
+        end
+
+        def leaky_relu(input, negative_slope = 0.01)
+          NN.leaky_relu(input, negative_slope)
+        end
+
+        def log_sigmoid(input)
+          NN.log_sigmoid(input)
+        end
+
+        def prelu(input, weight)
+          Torch.prelu(input, weight)
+        end
+
+        def relu(input, inplace: false)
+          if inplace
+            input.relu!
+          else
+            input.relu
+          end
+        end
+
+        def softplus(input, beta: 1, threshold: 20)
+          NN.softplus(input, beta, threshold)
+        end
+
+        def softshrink(*args, **options)
+          NN.softshrink(*args, **options)
+        end
+
+        def softsign(input)
+          input / (input.abs + 1)
+        end
+
+        def tanhshrink(input)
+          input - input.tanh
+        end
+
+        # other activation layers
+
+        def softmin(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          (-input).softmax(dim)
+        end
+
+        def softmax(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          input.softmax(dim)
+        end
+
+        # TODO make dim keyword argument and update examples
+        def log_softmax(input, dim = nil)
+          dim ||= softmax_dim(input.dim)
           input.log_softmax(dim)
         end
 
+        # normalization layers
+
+        def batch_norm(input, running_mean, running_var, weight: nil, bias: nil,
+          training: false, momentum: 0.1, eps: 1e-5)
+
+          if training
+            size = input.size
+            size_prods = size[0]
+            (size.length - 2).times do |i|
+              size_prods *= size[i + 2]
+            end
+            if size_prods == 1
+              raise ArgumentError, "Expected more than 1 value per channel when training, got input size #{size.inspect}"
+            end
+          end
+
+          Torch.batch_norm(
+            input, weight, bias, running_mean, running_var,
+            training, momentum, eps, false
+          )
+        end
+
+        def group_norm(input, num_groups, weight: nil, bias: nil, eps: 1e-5)
+          Torch.group_norm(input, num_groups, weight, bias, eps, false)
+        end
+
+        def instance_norm(input, running_mean: nil, running_var: nil, weight: nil,
+          bias: nil, use_input_stats: true, momentum: 0.1, eps: 1e-5)
+
+          Torch.instance_norm(
+            input, weight, bias, running_mean, running_var,
+            use_input_stats, momentum, eps, false
+          )
+        end
+
+        def layer_norm(input, normalized_shape, weight: nil, bias: nil, eps: 1e-5)
+          Torch.layer_norm(input, normalized_shape, weight, bias, eps, false)
+        end
+
+        def local_response_norm(input, size, alpha: 1e-4, beta: 0.75, k: 1.0)
+          dim = input.dim
+          if dim < 3
+            raise ArgumentError, "Expected 3D or higher dimensionality input (got #{dim} dimensions)"
+          end
+          div = input.mul(input).unsqueeze(1)
+          if dim == 3
+            div = pad(div, [0, 0, size / 2, (size - 1) / 2])
+            div = avg_pool2d(div, [size, 1], stride: 1).squeeze(1)
+          else
+            sizes = input.size
+            div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
+            div = pad(div, [0, 0, 0, 0, size / 2, (size - 1) / 2])
+            div = avg_pool3d(div, [size, 1, 1], stride: 1).squeeze(1)
+            div = div.view(sizes)
+          end
+          div = div.mul(alpha).add(k).pow(beta)
+          input / div
+        end
+
+        # linear layers
+
+        def linear(input, weight, bias)
+          NN.linear(input, weight, bias)
+        end
+
+        def bilinear(input1, input2, weight, bias)
+          Torch.bilinear(input1, input2, weight, bias)
+        end
+
+        # dropout layers
+
         def dropout(input, p: 0.5, training: true, inplace: false)
           if inplace
-            Torch._dropout!(input, p, training)
+            Torch.dropout!(input, p, training)
           else
-            Torch._dropout(input, p, training)
+            Torch.dropout(input, p, training)
           end
         end
 
@@ -54,42 +287,156 @@ module Torch
           raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
 
           if inplace
-            Torch._feature_dropout!(input, p, training)
+            Torch.feature_dropout!(input, p, training)
           else
-            Torch._feature_dropout(input, p, training)
+            Torch.feature_dropout(input, p, training)
           end
         end
 
         def dropout3d(input, p: 0.5, training: true, inplace: false)
           if inplace
-            Torch._feature_dropout!(input, p, training)
+            Torch.feature_dropout!(input, p, training)
          else
-            Torch._feature_dropout(input, p, training)
+            Torch.feature_dropout(input, p, training)
          end
        end
 
        def alpha_dropout(input, p: 0.5, training: true, inplace: false)
          if inplace
-            Torch._alpha_dropout!(input, p, training)
+            Torch.alpha_dropout!(input, p, training)
          else
-            Torch._alpha_dropout(input, p, training)
+            Torch.alpha_dropout(input, p, training)
          end
        end
 
        def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
          if inplace
-            Torch._feature_alpha_dropout!(input, p, training)
+            Torch.feature_alpha_dropout!(input, p, training)
          else
-            Torch._feature_alpha_dropout(input, p, training)
+            Torch.feature_alpha_dropout(input, p, training)
          end
        end
 
+        # sparse layers
+
        def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false)
          # TODO handle max_norm and norm_type
          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
 
          padding_idx ||= -1
-          Torch._embedding(input, weight, padding_idx, scale_grad_by_freq, sparse)
+          # weight and indices are swapped from Python interface
+          Torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
+        end
+
+        def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
+          # TODO handle max_norm and norm_type
+          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+          mode_enum =
+            case mode
+            when "sum"
+              0
+            when "mean"
+              1
+            when "max"
+              2
+            else
+              raise ArgumentError, "Unknown mode: #{mode}"
+            end
+
+          # weight and input swapped
+          Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+        end
+
+        # distance functions
+
+        def cosine_similarity(x1, x2, dim: 1, eps: 1e-8)
+          Torch.cosine_similarity(x1, x2, dim, eps)
+        end
+
+        def pairwise_distance(x1, x2, p: 2.0, eps: 1e-6, keepdim: false)
+          Torch.pairwise_distance(x1, x2, p, eps, keepdim)
+        end
+
+        # loss functions
+
+        def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
+          NN.binary_cross_entropy(input, target, weight, reduction)
+        end
+
+        def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
+          Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+        end
+
+        def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
+        end
+
+        def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          nll_loss(log_softmax(input, 1), target, weight: weight, ignore_index: ignore_index, reduction: reduction)
+        end
+
+        def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
+          # call to_a on input_lengths and target_lengths for C++
+          Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+        end
+
+        def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
+          Torch.hinge_embedding_loss(input, target, margin, reduction)
+        end
+
+        def kl_div(input, target, reduction: "mean")
+          Torch.kl_div(input, target, reduction)
+        end
+
+        def l1_loss(input, target, reduction: "mean")
+          NN.l1_loss(input, target, reduction)
+        end
+
+        def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
+        end
+
+        def mse_loss(input, target, reduction: "mean")
+          NN.mse_loss(input, target, reduction)
+        end
+
+        def multilabel_margin_loss(input, target, reduction: "mean")
+          NN.multilabel_margin_loss(input, target, reduction)
+        end
+
+        def multilabel_soft_margin_loss(input, target, weight: nil)
+          raise NotImplementedYet
+        end
+
+        def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
+          NN.multi_margin_loss(input, target, p, margin, weight, reduction)
+        end
+
+        def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          NN.nll_loss(input, target, weight, reduction, ignore_index)
+        end
+
+        def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
+          Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
+        end
+
+        def soft_margin_loss(input, target, reduction: "mean")
+          NN.soft_margin_loss(input, target, reduction)
+        end
+
+        def smooth_l1_loss(input, target, reduction: "mean")
+          NN.smooth_l1_loss(input, target, reduction)
+        end
+
+        def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
+          Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+        end
+
+        private
+
+        def softmax_dim(ndim)
+          ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
        end
      end
    end
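
Usage note (not from the diff): the functional interface can also be used directly, without wrapping layers in modules. A sketch of a few of the functions added above, assuming Torch.randn and Torch.tensor:

  F = Torch::NN::Functional

  x = Torch.randn(1, 3, 8, 8)
  padded = F.pad(x, [1, 1, 1, 1], mode: "reflect")  # 4D input takes 4 padding values
  probs = F.softmax(Torch.randn(2, 5), dim: 1)      # rows sum to 1

  input = Torch.randn(3, 5)
  target = Torch.tensor([1, 0, 4])
  loss = F.cross_entropy(input, target)             # log_softmax + nll_loss
  F.l1_loss(Torch.randn(4), Torch.randn(4), reduction: "sum")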