torch-rb 0.1.5 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/README.md +1 -1
  4. data/ext/torch/ext.cpp +0 -170
  5. data/ext/torch/nn_functions.cpp +44 -24
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +48 -0
  8. data/ext/torch/tensor_functions.cpp +76 -16
  9. data/ext/torch/torch_functions.cpp +165 -65
  10. data/lib/torch.rb +51 -42
  11. data/lib/torch/ext.bundle +0 -0
  12. data/lib/torch/native/dispatcher.rb +1 -1
  13. data/lib/torch/native/function.rb +36 -5
  14. data/lib/torch/native/generator.rb +26 -7
  15. data/lib/torch/native/parser.rb +51 -14
  16. data/lib/torch/nn/avg_pool1d.rb +18 -0
  17. data/lib/torch/nn/avg_pool2d.rb +7 -2
  18. data/lib/torch/nn/avg_pool3d.rb +19 -0
  19. data/lib/torch/nn/avg_poolnd.rb +1 -1
  20. data/lib/torch/nn/batch_norm.rb +75 -0
  21. data/lib/torch/nn/batch_norm1d.rb +11 -0
  22. data/lib/torch/nn/batch_norm2d.rb +11 -0
  23. data/lib/torch/nn/batch_norm3d.rb +11 -0
  24. data/lib/torch/nn/constant_pad1d.rb +10 -0
  25. data/lib/torch/nn/constant_pad2d.rb +10 -0
  26. data/lib/torch/nn/constant_pad3d.rb +10 -0
  27. data/lib/torch/nn/constant_padnd.rb +18 -0
  28. data/lib/torch/nn/conv1d.rb +22 -0
  29. data/lib/torch/nn/conv2d.rb +9 -17
  30. data/lib/torch/nn/conv3d.rb +22 -0
  31. data/lib/torch/nn/fold.rb +20 -0
  32. data/lib/torch/nn/functional.rb +320 -100
  33. data/lib/torch/nn/group_norm.rb +36 -0
  34. data/lib/torch/nn/gru.rb +49 -0
  35. data/lib/torch/nn/hardshrink.rb +18 -0
  36. data/lib/torch/nn/instance_norm.rb +20 -0
  37. data/lib/torch/nn/instance_norm1d.rb +18 -0
  38. data/lib/torch/nn/instance_norm2d.rb +11 -0
  39. data/lib/torch/nn/instance_norm3d.rb +11 -0
  40. data/lib/torch/nn/layer_norm.rb +35 -0
  41. data/lib/torch/nn/local_response_norm.rb +21 -0
  42. data/lib/torch/nn/log_sigmoid.rb +9 -0
  43. data/lib/torch/nn/lp_pool1d.rb +9 -0
  44. data/lib/torch/nn/lp_pool2d.rb +9 -0
  45. data/lib/torch/nn/lp_poolnd.rb +22 -0
  46. data/lib/torch/nn/lstm.rb +66 -0
  47. data/lib/torch/nn/max_pool1d.rb +9 -0
  48. data/lib/torch/nn/max_pool2d.rb +1 -1
  49. data/lib/torch/nn/max_pool3d.rb +9 -0
  50. data/lib/torch/nn/max_poolnd.rb +6 -6
  51. data/lib/torch/nn/max_unpool1d.rb +16 -0
  52. data/lib/torch/nn/max_unpool2d.rb +16 -0
  53. data/lib/torch/nn/max_unpool3d.rb +16 -0
  54. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  55. data/lib/torch/nn/module.rb +7 -0
  56. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  57. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  58. data/lib/torch/nn/reflection_padnd.rb +13 -0
  59. data/lib/torch/nn/replication_pad1d.rb +10 -0
  60. data/lib/torch/nn/replication_pad2d.rb +10 -0
  61. data/lib/torch/nn/replication_pad3d.rb +10 -0
  62. data/lib/torch/nn/replication_padnd.rb +13 -0
  63. data/lib/torch/nn/rnn_base.rb +48 -4
  64. data/lib/torch/nn/softshrink.rb +18 -0
  65. data/lib/torch/nn/softsign.rb +9 -0
  66. data/lib/torch/nn/tanh.rb +9 -0
  67. data/lib/torch/nn/tanhshrink.rb +9 -0
  68. data/lib/torch/nn/unfold.rb +19 -0
  69. data/lib/torch/nn/utils.rb +25 -0
  70. data/lib/torch/nn/zero_pad2d.rb +9 -0
  71. data/lib/torch/tensor.rb +14 -25
  72. data/lib/torch/version.rb +1 -1
  73. metadata +50 -2
@@ -0,0 +1,36 @@
1
module Torch
  module NN
    # Group normalization layer: forwards the input to F.group_norm with
    # the configured number of groups, optional affine parameters and eps.
    #
    # @param num_groups [Integer] number of groups handed to F.group_norm
    # @param num_channels [Integer] size of the weight/bias parameters
    # @param eps [Float] epsilon handed to F.group_norm (default 1e-5)
    # @param affine [Boolean] when true, weight and bias parameters are created;
    #   otherwise both are registered as nil
    class GroupNorm < Module
      def initialize(num_groups, num_channels, eps: 1e-5, affine: true)
        super()
        @num_groups = num_groups
        @num_channels = num_channels
        @eps = eps
        @affine = affine

        if affine
          # values are assigned by reset_parameters below
          @weight = Parameter.new(Torch::Tensor.new(num_channels))
          @bias = Parameter.new(Torch::Tensor.new(num_channels))
        else
          %w[weight bias].each { |name| register_parameter(name, nil) }
        end

        reset_parameters
      end

      # Sets weight to ones and bias to zeros; a no-op when affine is false.
      def reset_parameters
        return unless @affine

        Init.ones!(@weight)
        Init.zeros!(@bias)
      end

      def forward(input)
        F.group_norm(
          input,
          @num_groups,
          weight: @weight,
          bias: @bias,
          eps: @eps
        )
      end

      # Summary used when inspecting the module.
      def extra_inspect
        template = "%{num_groups}, %{num_channels}, eps: %{eps}, affine: %{affine}"
        format(template, **dict)
      end
    end
  end
end
@@ -0,0 +1,49 @@
1
module Torch
  module NN
    # Multi-layer GRU built on RNNBase; the heavy lifting is done by Torch.gru.
    class GRU < RNNBase
      # All arguments are forwarded to RNNBase together with the "GRU" mode name.
      def initialize(*args, **options)
        super("GRU", *args, **options)
      end

      # Calls Torch.gru. When batch_sizes is given it is passed right after
      # the input and batch_first is omitted; otherwise batch_first is the last argument.
      def run_impl(input, hx, batch_sizes)
        weights = _get_flat_weights
        if batch_sizes.nil?
          Torch.gru(input, hx, weights, @bias, @num_layers,
                    @dropout, @training, @bidirectional, @batch_first)
        else
          Torch.gru(input, batch_sizes, hx, weights, @bias,
                    @num_layers, @dropout, @training, @bidirectional)
        end
      end

      # Builds a zero initial hidden state when none is supplied, validates
      # the arguments and returns [output, hidden].
      def forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
        if hx.nil?
          directions = @bidirectional ? 2 : 1
          hx = Torch.zeros(@num_layers * directions, max_batch_size, @hidden_size,
                           dtype: input.dtype, device: input.device)
        else
          # reorder the supplied hidden state so each batch entry matches the
          # input sequence the caller intended
          hx = permute_hidden(hx, sorted_indices)
        end

        check_forward_args(input, hx, batch_sizes)
        result = run_impl(input, hx, batch_sizes)
        [result[0], result[1]]
      end

      # Plain-tensor path: no batch_sizes and no sorting permutation.
      def forward_tensor(input, hx: nil)
        max_batch_size = input.size(@batch_first ? 0 : 1)
        output, hidden = forward_impl(input, hx, nil, max_batch_size, nil)
        [output, permute_hidden(hidden, nil)]
      end

      def forward(input, hx: nil)
        forward_tensor(input, hx: hx)
      end
    end
  end
end
@@ -0,0 +1,18 @@
1
module Torch
  module NN
    # Element-wise hard shrinkage, delegated to F.hardshrink.
    #
    # @param lambd [Numeric] threshold handed to F.hardshrink (default 0.5)
    class Hardshrink < Module
      def initialize(lambd: 0.5)
        super()
        @lambd = lambd
      end

      def forward(input)
        F.hardshrink(input, @lambd)
      end

      # Shows just the threshold value.
      def extra_inspect
        "#{@lambd}"
      end
    end
  end
end
@@ -0,0 +1,20 @@
1
module Torch
  module NN
    # Shared base for InstanceNorm1d/2d/3d. Construction is delegated to
    # BatchNorm (presumably where weight/bias and running stats are set up —
    # confirm in batch_norm.rb), with affine and track_running_stats
    # defaulting to false.
    class InstanceNorm < BatchNorm
      def initialize(num_features, eps: 1e-5, momentum: 0.1, affine: false, track_running_stats: false)
        # bare super forwards num_features and every keyword unchanged
        super
      end

      def forward(input)
        # rank validation is supplied by the dimension-specific subclasses
        _check_input_dim(input)

        # use batch statistics while training, or whenever running stats are not tracked
        use_input_stats = @training || !@track_running_stats

        F.instance_norm(
          input,
          running_mean: @running_mean,
          running_var: @running_var,
          weight: @weight,
          bias: @bias,
          use_input_stats: use_input_stats,
          momentum: @momentum,
          eps: @eps
        )
      end
    end
  end
end
@@ -0,0 +1,18 @@
1
module Torch
  module NN
    # Instance normalization that only accepts 3D input.
    class InstanceNorm1d < InstanceNorm
      # Validates the input rank before normalization.
      #
      # @param input [Tensor] tensor to check (only #dim is used)
      # @raise [ArgumentError] when the input is 2D (the reshape to
      #   (1, N * C, ...) would make every variance 0), or is not 3D
      def _check_input_dim(input)
        dim = input.dim
        if dim == 2
          # fragments now end with a space so the sentences don't run together
          raise ArgumentError,
                "InstanceNorm1d returns 0-filled tensor to 2D tensor. " \
                "This is because InstanceNorm1d reshapes inputs to " \
                "(1, N * C, ...) from (N, C,...) and this makes " \
                "variances 0."
        end

        # ArgumentError (instead of a bare RuntimeError) to match InstanceNorm2d/3d
        raise ArgumentError, "expected 3D input (got #{dim}D input)" unless dim == 3
      end
    end
  end
end
@@ -0,0 +1,11 @@
1
module Torch
  module NN
    # Instance normalization that only accepts 4D input.
    class InstanceNorm2d < InstanceNorm
      # @raise [ArgumentError] unless the input is 4D
      def _check_input_dim(input)
        return if input.dim == 4

        raise ArgumentError, "expected 4D input (got #{input.dim}D input)"
      end
    end
  end
end
@@ -0,0 +1,11 @@
1
module Torch
  module NN
    # Instance normalization that only accepts 5D input.
    class InstanceNorm3d < InstanceNorm
      # @raise [ArgumentError] unless the input is 5D
      def _check_input_dim(input)
        return if input.dim == 5

        raise ArgumentError, "expected 5D input (got #{input.dim}D input)"
      end
    end
  end
end
@@ -0,0 +1,35 @@
1
module Torch
  module NN
    # Layer normalization over the trailing normalized_shape dimensions,
    # delegated to F.layer_norm.
    #
    # @param normalized_shape [Integer, Array<Integer>] stored as an Array;
    #   also the shape of the weight/bias parameters
    # @param eps [Float] epsilon handed to F.layer_norm (default 1e-5)
    # @param elementwise_affine [Boolean] when true, weight and bias
    #   parameters are created; otherwise both are registered as nil
    class LayerNorm < Module
      def initialize(normalized_shape, eps: 1e-5, elementwise_affine: true)
        super()
        @normalized_shape = Array(normalized_shape)
        @eps = eps
        @elementwise_affine = elementwise_affine

        if elementwise_affine
          # values are assigned by reset_parameters below
          @weight = Parameter.new(Torch::Tensor.new(*normalized_shape))
          @bias = Parameter.new(Torch::Tensor.new(*normalized_shape))
        else
          %w[weight bias].each { |name| register_parameter(name, nil) }
        end

        reset_parameters
      end

      # Sets weight to ones and bias to zeros; a no-op without elementwise_affine.
      def reset_parameters
        return unless @elementwise_affine

        Init.ones!(@weight)
        Init.zeros!(@bias)
      end

      def forward(input)
        F.layer_norm(
          input,
          @normalized_shape,
          weight: @weight,
          bias: @bias,
          eps: @eps
        )
      end

      # Summary used when inspecting the module.
      def extra_inspect
        template = "%{normalized_shape}, eps: %{eps}, elementwise_affine: %{elementwise_affine}"
        format(template, **dict)
      end
    end
  end
end
@@ -0,0 +1,21 @@
1
module Torch
  module NN
    # Local response normalization, delegated to F.local_response_norm.
    #
    # @param size [Integer] size argument handed to F.local_response_norm
    # @param alpha [Float] default 1e-4
    # @param beta [Float] default 0.75
    # @param k [Float] default 1.0
    class LocalResponseNorm < Module
      def initialize(size, alpha: 1e-4, beta: 0.75, k: 1.0)
        super()
        @size = size
        @alpha = alpha
        @beta = beta
        @k = k
      end

      def forward(input)
        F.local_response_norm(
          input,
          @size,
          alpha: @alpha,
          beta: @beta,
          k: @k
        )
      end

      # Summary used when inspecting the module.
      def extra_inspect
        template = "%{size}, alpha: %{alpha}, beta: %{beta}, k: %{k}"
        format(template, **dict)
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module Torch
  module NN
    # Element-wise log-sigmoid with no configuration; delegates to F.log_sigmoid.
    class LogSigmoid < Module
      def forward(input)
        F.log_sigmoid(input)
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module Torch
  module NN
    # 1D power-average pooling; the configuration lives in LPPoolNd.
    class LPPool1d < LPPoolNd
      def forward(input)
        # norm_type is coerced to a Float before being handed to F
        power = @norm_type.to_f
        F.lp_pool1d(input, power, @kernel_size, @stride, @ceil_mode)
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module Torch
  module NN
    # 2D power-average pooling; the configuration lives in LPPoolNd.
    class LPPool2d < LPPoolNd
      def forward(input)
        # norm_type is coerced to a Float before being handed to F
        power = @norm_type.to_f
        F.lp_pool2d(input, power, @kernel_size, @stride, @ceil_mode)
      end
    end
  end
end
@@ -0,0 +1,22 @@
1
module Torch
  module NN
    # Shared configuration for the LPPool1d/LPPool2d layers.
    #
    # @param norm_type [Numeric] power used by the pooling
    # @param kernel_size [Integer, Array<Integer>] pooling window
    # @param stride [Integer, Array<Integer>, nil] stored as given (nil by default)
    # @param ceil_mode [Boolean] default false
    class LPPoolNd < Module
      def initialize(norm_type, kernel_size, stride: nil, ceil_mode: false)
        super()
        @norm_type = norm_type
        @kernel_size = kernel_size
        @stride = stride
        @ceil_mode = ceil_mode
      end

      # Summary used when inspecting the module; a nil stride renders as empty.
      def extra_inspect
        settings = {
          norm_type: @norm_type,
          kernel_size: @kernel_size,
          stride: @stride,
          ceil_mode: @ceil_mode
        }
        format("norm_type: %{norm_type}, kernel_size: %{kernel_size}, stride: %{stride}, ceil_mode: %{ceil_mode}", **settings)
      end
    end
  end
end
@@ -0,0 +1,66 @@
1
module Torch
  module NN
    # Multi-layer LSTM built on RNNBase. The hidden state is handled as a
    # two-element collection (hidden[0], hidden[1]).
    class LSTM < RNNBase
      # All arguments are forwarded to RNNBase together with the "LSTM" mode name.
      def initialize(*args, **options)
        super("LSTM", *args, **options)
      end

      # Checks the input, then checks both hidden-state entries against the
      # same expected size.
      def check_forward_args(input, hidden, batch_sizes)
        check_input(input, batch_sizes)
        expected = get_expected_hidden_size(input, batch_sizes)

        # TODO pass message
        [hidden[0], hidden[1]].each { |state| check_hidden_size(state, expected) }
      end

      # Only the nil (identity) permutation is supported so far.
      def permute_hidden(hx, permutation)
        raise NotImplementedYet unless permutation.nil?

        hx
      end

      # Builds a zero initial state when none is supplied, validates the
      # arguments, runs Torch.lstm and returns [output, rest_of_result].
      def forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
        if hx.nil?
          directions = @bidirectional ? 2 : 1
          zeros = Torch.zeros(@num_layers * directions, max_batch_size, @hidden_size,
                              dtype: input.dtype, device: input.device)
          # one zero tensor is reused for both entries of the pair
          hx = [zeros, zeros]
        else
          # reorder the supplied state so each batch entry matches the
          # input sequence the caller intended
          hx = permute_hidden(hx, sorted_indices)
        end

        check_forward_args(input, hx, batch_sizes)

        weights = _get_flat_weights
        result =
          if batch_sizes.nil?
            Torch.lstm(input, hx, weights, @bias, @num_layers,
                       @dropout, @training, @bidirectional, @batch_first)
          else
            Torch.lstm(input, batch_sizes, hx, weights, @bias,
                       @num_layers, @dropout, @training, @bidirectional)
          end

        # first element is the output; everything after it is the hidden state
        [result[0], result[1..-1]]
      end

      # Plain-tensor path: no batch_sizes and no sorting permutation.
      def forward_tensor(input, hx: nil)
        max_batch_size = input.size(@batch_first ? 0 : 1)
        output, hidden = forward_impl(input, hx, nil, max_batch_size, nil)
        [output, permute_hidden(hidden, nil)]
      end

      def forward(input, hx: nil)
        # TODO PackedSequence
        forward_tensor(input, hx: hx)
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module Torch
  module NN
    # 1D max pooling; all settings come from MaxPoolNd and are handed to
    # F.max_pool1d positionally.
    class MaxPool1d < MaxPoolNd
      def forward(input)
        F.max_pool1d(
          input,
          @kernel_size,
          @stride,
          @padding,
          @dilation,
          @ceil_mode,
          @return_indices
        )
      end
    end
  end
end
@@ -2,7 +2,7 @@ module Torch
2
2
  module NN
3
3
  class MaxPool2d < MaxPoolNd
4
4
  def forward(input)
5
- F.max_pool2d(input, @kernel_size) # TODO other parameters
5
+ F.max_pool2d(input, @kernel_size, @stride, @padding, @dilation, @ceil_mode, @return_indices)
6
6
  end
7
7
  end
8
8
  end
@@ -0,0 +1,9 @@
1
module Torch
  module NN
    # 3D max pooling; all settings come from MaxPoolNd and are handed to
    # F.max_pool3d positionally.
    class MaxPool3d < MaxPoolNd
      def forward(input)
        F.max_pool3d(
          input,
          @kernel_size,
          @stride,
          @padding,
          @dilation,
          @ceil_mode,
          @return_indices
        )
      end
    end
  end
end
@@ -1,14 +1,14 @@
1
1
  module Torch
2
2
  module NN
3
3
  class MaxPoolNd < Module
4
- def initialize(kernel_size) #, stride: nil, padding: 0, dilation: 1, return_indices: false, ceil_mode: false)
4
+ def initialize(kernel_size, stride: nil, padding: 0, dilation: 1, return_indices: false, ceil_mode: false)
5
5
  super()
6
6
  @kernel_size = kernel_size
7
- # @stride = stride || kernel_size
8
- # @padding = padding
9
- # @dilation = dilation
10
- # @return_indices = return_indices
11
- # @ceil_mode = ceil_mode
7
+ @stride = stride || kernel_size
8
+ @padding = padding
9
+ @dilation = dilation
10
+ @return_indices = return_indices
11
+ @ceil_mode = ceil_mode
12
12
  end
13
13
 
14
14
  def extra_inspect
@@ -0,0 +1,16 @@
1
module Torch
  module NN
    # Partial inverse of 1D max pooling, delegated to F.max_unpool1d.
    #
    # @param kernel_size [Integer] wrapped with _single
    # @param stride [Integer, nil] falls back to kernel_size when nil
    # @param padding [Integer] default 0
    class MaxUnpool1d < MaxUnpoolNd
      def initialize(kernel_size, stride: nil, padding: 0)
        super()
        @kernel_size = _single(kernel_size)
        @stride = _single(stride || kernel_size)
        @padding = _single(padding)
      end

      def forward(input, indices, output_size: nil)
        # NOTE(review): MaxUnpool2d passes stride/padding/output_size to F
        # positionally, while this call uses keywords — confirm which form
        # matches F.max_unpool1d's signature in functional.rb.
        F.max_unpool1d(
          input,
          indices,
          @kernel_size,
          stride: @stride,
          padding: @padding,
          output_size: output_size
        )
      end
    end
  end
end
@@ -0,0 +1,16 @@
1
module Torch
  module NN
    # Partial inverse of 2D max pooling, delegated to F.max_unpool2d.
    #
    # @param kernel_size [Integer, Array<Integer>] wrapped with _pair
    # @param stride [Integer, Array<Integer>, nil] falls back to kernel_size when nil
    # @param padding [Integer, Array<Integer>] default 0
    class MaxUnpool2d < MaxUnpoolNd
      def initialize(kernel_size, stride: nil, padding: 0)
        super()
        @kernel_size = _pair(kernel_size)
        @stride = _pair(stride || kernel_size)
        @padding = _pair(padding)
      end

      def forward(input, indices, output_size: nil)
        # NOTE(review): MaxUnpool1d passes these settings to F as keywords,
        # while this call is positional — confirm which form matches
        # F.max_unpool2d's signature in functional.rb.
        F.max_unpool2d(
          input,
          indices,
          @kernel_size,
          @stride,
          @padding,
          output_size
        )
      end
    end
  end
end