torch-rb 0.1.3 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +30 -0
  3. data/README.md +5 -2
  4. data/ext/torch/ext.cpp +130 -555
  5. data/ext/torch/extconf.rb +9 -0
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +244 -0
  8. data/lib/torch.rb +209 -171
  9. data/lib/torch/inspector.rb +23 -19
  10. data/lib/torch/native/dispatcher.rb +48 -0
  11. data/lib/torch/native/function.rb +110 -0
  12. data/lib/torch/native/generator.rb +168 -0
  13. data/lib/torch/native/native_functions.yaml +6491 -0
  14. data/lib/torch/native/parser.rb +134 -0
  15. data/lib/torch/nn/avg_pool1d.rb +18 -0
  16. data/lib/torch/nn/avg_pool2d.rb +19 -0
  17. data/lib/torch/nn/avg_pool3d.rb +19 -0
  18. data/lib/torch/nn/avg_poolnd.rb +9 -0
  19. data/lib/torch/nn/batch_norm.rb +75 -0
  20. data/lib/torch/nn/batch_norm1d.rb +11 -0
  21. data/lib/torch/nn/batch_norm2d.rb +11 -0
  22. data/lib/torch/nn/batch_norm3d.rb +11 -0
  23. data/lib/torch/nn/bce_loss.rb +13 -0
  24. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  25. data/lib/torch/nn/bilinear.rb +38 -0
  26. data/lib/torch/nn/constant_pad1d.rb +10 -0
  27. data/lib/torch/nn/constant_pad2d.rb +10 -0
  28. data/lib/torch/nn/constant_pad3d.rb +10 -0
  29. data/lib/torch/nn/constant_padnd.rb +18 -0
  30. data/lib/torch/nn/conv1d.rb +22 -0
  31. data/lib/torch/nn/conv2d.rb +10 -20
  32. data/lib/torch/nn/conv3d.rb +22 -0
  33. data/lib/torch/nn/convnd.rb +3 -3
  34. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  35. data/lib/torch/nn/cosine_similarity.rb +15 -0
  36. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  37. data/lib/torch/nn/ctc_loss.rb +15 -0
  38. data/lib/torch/nn/dropoutnd.rb +2 -2
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/fold.rb +20 -0
  41. data/lib/torch/nn/functional.rb +379 -32
  42. data/lib/torch/nn/group_norm.rb +36 -0
  43. data/lib/torch/nn/gru.rb +49 -0
  44. data/lib/torch/nn/hardshrink.rb +18 -0
  45. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  46. data/lib/torch/nn/identity.rb +14 -0
  47. data/lib/torch/nn/init.rb +58 -1
  48. data/lib/torch/nn/instance_norm.rb +20 -0
  49. data/lib/torch/nn/instance_norm1d.rb +18 -0
  50. data/lib/torch/nn/instance_norm2d.rb +11 -0
  51. data/lib/torch/nn/instance_norm3d.rb +11 -0
  52. data/lib/torch/nn/kl_div_loss.rb +13 -0
  53. data/lib/torch/nn/l1_loss.rb +13 -0
  54. data/lib/torch/nn/layer_norm.rb +35 -0
  55. data/lib/torch/nn/leaky_relu.rb +20 -0
  56. data/lib/torch/nn/linear.rb +12 -11
  57. data/lib/torch/nn/local_response_norm.rb +21 -0
  58. data/lib/torch/nn/log_sigmoid.rb +9 -0
  59. data/lib/torch/nn/log_softmax.rb +14 -0
  60. data/lib/torch/nn/loss.rb +10 -0
  61. data/lib/torch/nn/lp_pool1d.rb +9 -0
  62. data/lib/torch/nn/lp_pool2d.rb +9 -0
  63. data/lib/torch/nn/lp_poolnd.rb +22 -0
  64. data/lib/torch/nn/lstm.rb +66 -0
  65. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  66. data/lib/torch/nn/max_pool1d.rb +9 -0
  67. data/lib/torch/nn/max_pool2d.rb +9 -0
  68. data/lib/torch/nn/max_pool3d.rb +9 -0
  69. data/lib/torch/nn/max_poolnd.rb +19 -0
  70. data/lib/torch/nn/max_unpool1d.rb +16 -0
  71. data/lib/torch/nn/max_unpool2d.rb +16 -0
  72. data/lib/torch/nn/max_unpool3d.rb +16 -0
  73. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  74. data/lib/torch/nn/module.rb +186 -35
  75. data/lib/torch/nn/mse_loss.rb +2 -2
  76. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  77. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  78. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  79. data/lib/torch/nn/nll_loss.rb +14 -0
  80. data/lib/torch/nn/pairwise_distance.rb +16 -0
  81. data/lib/torch/nn/parameter.rb +2 -2
  82. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  83. data/lib/torch/nn/prelu.rb +19 -0
  84. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  85. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  86. data/lib/torch/nn/reflection_padnd.rb +13 -0
  87. data/lib/torch/nn/relu.rb +8 -3
  88. data/lib/torch/nn/replication_pad1d.rb +10 -0
  89. data/lib/torch/nn/replication_pad2d.rb +10 -0
  90. data/lib/torch/nn/replication_pad3d.rb +10 -0
  91. data/lib/torch/nn/replication_padnd.rb +13 -0
  92. data/lib/torch/nn/rnn.rb +22 -0
  93. data/lib/torch/nn/rnn_base.rb +198 -0
  94. data/lib/torch/nn/sequential.rb +1 -10
  95. data/lib/torch/nn/sigmoid.rb +9 -0
  96. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  97. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  98. data/lib/torch/nn/softmax.rb +18 -0
  99. data/lib/torch/nn/softmax2d.rb +10 -0
  100. data/lib/torch/nn/softmin.rb +14 -0
  101. data/lib/torch/nn/softplus.rb +19 -0
  102. data/lib/torch/nn/softshrink.rb +18 -0
  103. data/lib/torch/nn/softsign.rb +9 -0
  104. data/lib/torch/nn/tanh.rb +9 -0
  105. data/lib/torch/nn/tanhshrink.rb +9 -0
  106. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  107. data/lib/torch/nn/unfold.rb +19 -0
  108. data/lib/torch/nn/utils.rb +25 -0
  109. data/lib/torch/nn/weighted_loss.rb +10 -0
  110. data/lib/torch/nn/zero_pad2d.rb +9 -0
  111. data/lib/torch/random.rb +10 -0
  112. data/lib/torch/tensor.rb +51 -44
  113. data/lib/torch/version.rb +1 -1
  114. metadata +98 -6
  115. data/lib/torch/ext.bundle +0 -0
@@ -0,0 +1,36 @@
module Torch
  module NN
    # Group Normalization layer. The channels are split into +num_groups+
    # groups and the computation itself is delegated to F.group_norm.
    class GroupNorm < Module
      # @param num_groups [Integer] number of groups the channels are split into
      # @param num_channels [Integer] number of input channels; also the length
      #   of the learnable +weight+ / +bias+ when +affine+ is true
      # @param eps [Float] forwarded to F.group_norm
      # @param affine [Boolean] whether to create learnable weight and bias
      def initialize(num_groups, num_channels, eps: 1e-5, affine: true)
        super()
        @num_groups = num_groups
        @num_channels = num_channels
        @eps = eps
        @affine = affine

        if affine
          @weight = Parameter.new(Torch::Tensor.new(num_channels))
          @bias = Parameter.new(Torch::Tensor.new(num_channels))
        else
          # Record both parameter names with nil values instead of tensors.
          %w[weight bias].each { |name| register_parameter(name, nil) }
        end

        reset_parameters
      end

      # Fills the weight with ones and the bias with zeros.
      # Does nothing (returns nil) when the layer is not affine.
      def reset_parameters
        return unless @affine

        Init.ones!(@weight)
        Init.zeros!(@bias)
      end

      # @param input [Tensor]
      # @return [Tensor] the result of F.group_norm
      def forward(input)
        F.group_norm(input, @num_groups, weight: @weight, bias: @bias, eps: @eps)
      end

      # Summary shown when the module is inspected; values are taken from +dict+.
      def extra_inspect
        template = "%{num_groups}, %{num_channels}, eps: %{eps}, affine: %{affine}"
        format(template, **dict)
      end
    end
  end
end
@@ -0,0 +1,49 @@
module Torch
  module NN
    # Multi-layer gated recurrent unit. The helpers used below
    # (_get_flat_weights, permute_hidden, check_forward_args) are not defined
    # in this class and come from RNNBase or its ancestors.
    class GRU < RNNBase
      # Every argument is forwarded to RNNBase, prefixed with the "GRU" mode.
      def initialize(*args, **options)
        super("GRU", *args, **options)
      end

      # Calls the native Torch.gru kernel.
      #
      # Without +batch_sizes+, the overload that takes +batch_first+ is used.
      # With +batch_sizes+ (presumably a packed sequence; confirm in RNNBase),
      # the overload that takes the batch sizes is used instead.
      def run_impl(input, hx, batch_sizes)
        flat_weights = _get_flat_weights
        if batch_sizes.nil?
          Torch.gru(input, hx, flat_weights, @bias, @num_layers,
                    @dropout, @training, @bidirectional, @batch_first)
        else
          Torch.gru(input, batch_sizes, hx, flat_weights, @bias,
                    @num_layers, @dropout, @training, @bidirectional)
        end
      end

      # Runs the recurrence. When +hx+ is nil a zero initial hidden state of
      # shape (num_layers * directions, max_batch_size, hidden_size) is built.
      #
      # @return [Array] [output, hidden]
      def forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
        hx =
          if hx.nil?
            directions = @bidirectional ? 2 : 1
            Torch.zeros(@num_layers * directions, max_batch_size, @hidden_size,
                        dtype: input.dtype, device: input.device)
          else
            # Reorder the caller-supplied hidden state so each batch entry
            # lines up with its (possibly re-sorted) input sequence.
            permute_hidden(hx, sorted_indices)
          end

        check_forward_args(input, hx, batch_sizes)
        result = run_impl(input, hx, batch_sizes)
        [result[0], result[1]]
      end

      # Forward pass for a plain tensor input (no batch sizes, no sorting).
      # The batch dimension is 0 when batch_first is set, otherwise 1.
      def forward_tensor(input, hx: nil)
        batch_dim = @batch_first ? 0 : 1
        output, hidden = forward_impl(input, hx, nil, input.size(batch_dim), nil)
        [output, permute_hidden(hidden, nil)]
      end

      # @param input [Tensor]
      # @param hx [Tensor, nil] initial hidden state; zeros are used when nil
      # @return [Array] [output, hidden]
      def forward(input, hx: nil)
        forward_tensor(input, hx: hx)
      end
    end
  end
end
@@ -0,0 +1,18 @@
module Torch
  module NN
    # Hard shrinkage activation; the computation is delegated to F.hardshrink.
    class Hardshrink < Module
      # @param lambd [Float] threshold forwarded to F.hardshrink
      def initialize(lambd: 0.5)
        super()
        @lambd = lambd
      end

      # @param input [Tensor]
      # @return [Tensor] the result of F.hardshrink
      def forward(input)
        F.hardshrink(input, @lambd)
      end

      # @return [String] the threshold as text, e.g. "0.5"
      def extra_inspect
        "#{@lambd}"
      end
    end
  end
end
@@ -0,0 +1,14 @@
module Torch
  module NN
    # Hinge embedding loss; delegates to F.hinge_embedding_loss.
    class HingeEmbeddingLoss < Loss
      # @param margin [Float] forwarded to F.hinge_embedding_loss
      # @param reduction [String] reduction mode handed to Loss
      def initialize(margin: 1.0, reduction: "mean")
        super(reduction)
        @margin = margin
      end

      # @param input [Tensor]
      # @param target [Tensor]
      # @return [Tensor] the result of F.hinge_embedding_loss
      def forward(input, target)
        F.hinge_embedding_loss(
          input, target,
          margin: @margin,
          reduction: @reduction
        )
      end
    end
  end
end
@@ -0,0 +1,14 @@
module Torch
  module NN
    # Pass-through module: forward returns its input unchanged.
    class Identity < Module
      # Accepts and ignores any positional or keyword arguments, so it can be
      # dropped in wherever another module's constructor arguments are given.
      def initialize(*_args, **_options)
        super()
      end

      # @return the very same object that was passed in
      def forward(input)
        input
      end
    end
  end
end
@@ -2,7 +2,64 @@ module Torch
2
2
  module NN
3
3
  module Init
4
4
  class << self
5
- def calculate_fan_in_and_fan_out(tensor)
# Public initializers of Torch::NN::Init (these live inside `class << self`).
# Each one turns its keyword options into the positional arguments of the
# matching underscore-prefixed method, which is defined elsewhere.

# Delegates to _calculate_gain; +param+ defaults to 0.01.
def calculate_gain(nonlinearity, param: 0.01)
  _calculate_gain(nonlinearity, param)
end

# Delegates to _uniform! with lower bound +a+ and upper bound +b+.
def uniform!(tensor, a: 0.0, b: 1.0)
  _uniform!(tensor, a, b)
end

# Delegates to _normal! with the given +mean+ and +std+.
def normal!(tensor, mean: 0.0, std: 1.0)
  _normal!(tensor, mean, std)
end

# Delegates to _constant! with the fill value +val+.
def constant!(tensor, val)
  _constant!(tensor, val)
end

# Delegates to _ones!.
def ones!(tensor)
  _ones!(tensor)
end

# Delegates to _zeros!.
def zeros!(tensor)
  _zeros!(tensor)
end

# Delegates to _eye!.
def eye!(tensor)
  _eye!(tensor)
end

# Delegates to _dirac!.
def dirac!(tensor)
  _dirac!(tensor)
end

# Delegates to _xavier_uniform! with scaling +gain+.
def xavier_uniform!(tensor, gain: 1.0)
  _xavier_uniform!(tensor, gain)
end

# Delegates to _xavier_normal! with scaling +gain+.
def xavier_normal!(tensor, gain: 1.0)
  _xavier_normal!(tensor, gain)
end

# Delegates to _kaiming_uniform!; arguments are passed as (a, mode, nonlinearity).
def kaiming_uniform!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
  _kaiming_uniform!(tensor, a, mode, nonlinearity)
end

# Delegates to _kaiming_normal!; arguments are passed as (a, mode, nonlinearity).
def kaiming_normal!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
  _kaiming_normal!(tensor, a, mode, nonlinearity)
end

# Delegates to _orthogonal! with scaling +gain+ (Integer default of 1).
def orthogonal!(tensor, gain: 1)
  _orthogonal!(tensor, gain)
end

# Delegates to _sparse! with the given +sparsity+ and +std+.
def sparse!(tensor, sparsity, std: 0.01)
  _sparse!(tensor, sparsity, std)
end
60
+
61
+ # TODO move to C++ when released
62
+ def _calculate_fan_in_and_fan_out(tensor)
6
63
  dimensions = tensor.dim
7
64
  if dimensions < 2
8
65
  raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
@@ -0,0 +1,20 @@
module Torch
  module NN
    # Base class for the InstanceNorm1d/2d/3d layers. Construction is handed
    # to BatchNorm with the same options, but +affine+ and
    # +track_running_stats+ default to false here. Subclasses provide
    # _check_input_dim, which forward calls first.
    class InstanceNorm < BatchNorm
      def initialize(num_features, eps: 1e-5, momentum: 0.1, affine: false, track_running_stats: false)
        super(num_features, eps: eps, momentum: momentum, affine: affine, track_running_stats: track_running_stats)
      end

      # Validates the input rank, then delegates to F.instance_norm.
      def forward(input)
        _check_input_dim(input)

        # Use the statistics of the current input while training, or always
        # when running statistics are not being tracked.
        use_input_stats = @training || !@track_running_stats

        F.instance_norm(
          input,
          running_mean: @running_mean,
          running_var: @running_var,
          weight: @weight,
          bias: @bias,
          use_input_stats: use_input_stats,
          momentum: @momentum,
          eps: @eps
        )
      end
    end
  end
end
@@ -0,0 +1,18 @@
module Torch
  module NN
    # Instance normalization for 3D input; see InstanceNorm.
    class InstanceNorm1d < InstanceNorm
      # Rejects inputs that are not 3D.
      #
      # A 2D input gets a dedicated message, because instance-normalizing it
      # would yield a 0-filled tensor.
      #
      # @raise [ArgumentError] when +input+ is 2D or otherwise not 3D
      #   (ArgumentError, matching InstanceNorm2d / InstanceNorm3d)
      def _check_input_dim(input)
        if input.dim == 2
          # Adjacent literals are joined at parse time; the trailing spaces keep
          # the sentences from running together in the final message.
          raise ArgumentError,
            "InstanceNorm1d returns 0-filled tensor to 2D tensor. " \
            "This is because InstanceNorm1d reshapes inputs to " \
            "(1, N * C, ...) from (N, C,...) and this makes " \
            "variances 0."
        end
        if input.dim != 3
          raise ArgumentError, "expected 3D input (got #{input.dim}D input)"
        end
      end
    end
  end
end
@@ -0,0 +1,11 @@
module Torch
  module NN
    # Instance normalization for 4D input; see InstanceNorm.
    class InstanceNorm2d < InstanceNorm
      # @raise [ArgumentError] unless +input+ has exactly 4 dimensions
      def _check_input_dim(input)
        rank = input.dim
        raise ArgumentError, "expected 4D input (got #{rank}D input)" unless rank == 4
      end
    end
  end
end
@@ -0,0 +1,11 @@
module Torch
  module NN
    # Instance normalization for 5D input; see InstanceNorm.
    class InstanceNorm3d < InstanceNorm
      # @raise [ArgumentError] unless +input+ has exactly 5 dimensions
      def _check_input_dim(input)
        rank = input.dim
        raise ArgumentError, "expected 5D input (got #{rank}D input)" unless rank == 5
      end
    end
  end
end
@@ -0,0 +1,13 @@
module Torch
  module NN
    # Kullback-Leibler divergence loss; delegates to F.kl_div.
    class KLDivLoss < Loss
      # @param reduction [String] reduction mode handed to Loss
      def initialize(reduction: "mean")
        super(reduction)
      end

      # @param input [Tensor]
      # @param target [Tensor]
      # @return [Tensor] the result of F.kl_div
      def forward(input, target)
        F.kl_div(input, target, reduction: @reduction)
      end
    end
  end
end
@@ -0,0 +1,13 @@
module Torch
  module NN
    # Mean absolute error (L1) loss; delegates to F.l1_loss.
    class L1Loss < Loss
      # @param reduction [String] reduction mode handed to Loss
      def initialize(reduction: "mean")
        super(reduction)
      end

      # @param input [Tensor]
      # @param target [Tensor]
      # @return [Tensor] the result of F.l1_loss
      def forward(input, target)
        F.l1_loss(input, target, reduction: @reduction)
      end
    end
  end
end
@@ -0,0 +1,35 @@
module Torch
  module NN
    # Layer Normalization over the trailing +normalized_shape+ dimensions;
    # the computation is delegated to F.layer_norm.
    class LayerNorm < Module
      # @param normalized_shape [Integer, Array<Integer>] shape of the
      #   normalized trailing dimensions (an Integer is wrapped in an Array)
      # @param eps [Float] forwarded to F.layer_norm
      # @param elementwise_affine [Boolean] whether to create learnable
      #   +weight+ and +bias+ parameters of shape +normalized_shape+
      def initialize(normalized_shape, eps: 1e-5, elementwise_affine: true)
        super()
        @normalized_shape = Array(normalized_shape)
        @eps = eps
        @elementwise_affine = elementwise_affine

        if elementwise_affine
          @weight = Parameter.new(Torch::Tensor.new(*@normalized_shape))
          @bias = Parameter.new(Torch::Tensor.new(*@normalized_shape))
        else
          # Record both parameter names with nil values instead of tensors.
          %w[weight bias].each { |name| register_parameter(name, nil) }
        end

        reset_parameters
      end

      # Fills the weight with ones and the bias with zeros.
      # Does nothing (returns nil) without element-wise affine parameters.
      def reset_parameters
        return unless @elementwise_affine

        Init.ones!(@weight)
        Init.zeros!(@bias)
      end

      # @param input [Tensor]
      # @return [Tensor] the result of F.layer_norm
      def forward(input)
        F.layer_norm(input, @normalized_shape, weight: @weight, bias: @bias, eps: @eps)
      end

      # Summary shown when the module is inspected; values are taken from +dict+.
      def extra_inspect
        template = "%{normalized_shape}, eps: %{eps}, elementwise_affine: %{elementwise_affine}"
        format(template, **dict)
      end
    end
  end
end
@@ -0,0 +1,20 @@
module Torch
  module NN
    # Leaky ReLU activation; the computation is delegated to F.leaky_relu.
    #
    # The `inplace:` option is not supported yet. It is therefore neither
    # accepted by the constructor nor shown by #extra_inspect.
    class LeakyReLU < Module
      # @param negative_slope [Float] slope forwarded to F.leaky_relu
      def initialize(negative_slope: 1e-2)
        super()
        @negative_slope = negative_slope
      end

      # @param input [Tensor]
      # @return [Tensor] the result of F.leaky_relu
      def forward(input)
        F.leaky_relu(input, @negative_slope)
      end

      # @return [String] e.g. "negative_slope: 0.01"
      def extra_inspect
        # No longer reads the never-assigned @inplace, which was always nil and
        # warned under `ruby -w` on Ruby < 3.0.
        format("negative_slope: %s", @negative_slope)
      end
    end
  end
end
@@ -1,35 +1,36 @@
module Torch
  module NN
    # Fully connected layer. It holds a weight of shape
    # (out_features, in_features) and an optional bias of length
    # out_features, and delegates the computation to F.linear.
    class Linear < Module
      # @param in_features [Integer] size of each input sample
      # @param out_features [Integer] size of each output sample
      # @param bias [Boolean] when false, the bias is registered as nil
      def initialize(in_features, out_features, bias: true)
        super()
        @in_features = in_features
        @out_features = out_features

        @weight = Parameter.new(Tensor.new(out_features, in_features))
        if bias
          @bias = Parameter.new(Tensor.new(out_features))
        else
          register_parameter("bias", nil)
        end

        reset_parameters
      end

      # Re-initializes the weight with Kaiming-uniform (a = sqrt(5)). The bias
      # is drawn uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)].
      def reset_parameters
        Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
        if @bias
          fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
          # With in_features == 0, 1 / sqrt(0) is Infinity, which would give an
          # invalid (-Infinity, Infinity) range; fall back to a zero bound.
          bound = fan_in > 0 ? 1 / Math.sqrt(fan_in) : 0.0
          Init.uniform!(@bias, a: -bound, b: bound)
        end
      end

      # @param input [Tensor]
      # @return [Tensor] the result of F.linear
      def forward(input)
        F.linear(input, @weight, @bias)
      end

      # @return [String] e.g. "in_features: 3, out_features: 2, bias: true"
      def extra_inspect
        format("in_features: %s, out_features: %s, bias: %s", @in_features, @out_features, !@bias.nil?)
      end
    end
  end
end
@@ -0,0 +1,21 @@
module Torch
  module NN
    # Local response normalization; delegates to F.local_response_norm.
    class LocalResponseNorm < Module
      # @param size [Integer] forwarded as the positional size argument
      # @param alpha [Float] forwarded to F.local_response_norm
      # @param beta [Float] forwarded to F.local_response_norm
      # @param k [Float] forwarded to F.local_response_norm
      def initialize(size, alpha: 1e-4, beta: 0.75, k: 1.0)
        super()
        @size = size
        @alpha = alpha
        @beta = beta
        @k = k
      end

      # @param input [Tensor]
      # @return [Tensor] the result of F.local_response_norm
      def forward(input)
        F.local_response_norm(input, @size, alpha: @alpha, beta: @beta, k: @k)
      end

      # Summary shown when the module is inspected; values are taken from +dict+.
      def extra_inspect
        template = "%{size}, alpha: %{alpha}, beta: %{beta}, k: %{k}"
        format(template, **dict)
      end
    end
  end
end
@@ -0,0 +1,9 @@
module Torch
  module NN
    # Element-wise log-sigmoid activation with no configuration; delegates
    # to F.log_sigmoid.
    class LogSigmoid < Module
      # @param input [Tensor]
      # @return [Tensor] the result of F.log_sigmoid
      def forward(input)
        F.log_sigmoid(input)
      end
    end
  end
end
@@ -0,0 +1,14 @@
module Torch
  module NN
    # Log-softmax along +dim+; delegates to F.log_softmax.
    class LogSoftmax < Module
      # @param dim [Integer, nil] dimension forwarded to F.log_softmax; how
      #   nil is handled is decided by F.log_softmax
      def initialize(dim: nil)
        super()
        @dim = dim
      end

      # @param input [Tensor]
      # @return [Tensor] the result of F.log_softmax
      def forward(input)
        F.log_softmax(input, @dim)
      end
    end
  end
end