torch-rb 0.1.1 → 0.1.6

Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +73 -9
  5. data/ext/torch/ext.cpp +148 -315
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +298 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +236 -112
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +52 -25
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -39
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +419 -16
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +191 -19
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +4 -0
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +62 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +60 -0
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +90 -30
  139. data/lib/torch/utils/data/data_loader.rb +15 -0
  140. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  141. data/lib/torch/version.rb +1 -1
  142. metadata +122 -3

data/lib/torch/nn/group_norm.rb
@@ -0,0 +1,36 @@
+ module Torch
+   module NN
+     class GroupNorm < Module
+       def initialize(num_groups, num_channels, eps: 1e-5, affine: true)
+         super()
+         @num_groups = num_groups
+         @num_channels = num_channels
+         @eps = eps
+         @affine = affine
+         if @affine
+           @weight = Parameter.new(Torch::Tensor.new(num_channels))
+           @bias = Parameter.new(Torch::Tensor.new(num_channels))
+         else
+           register_parameter("weight", nil)
+           register_parameter("bias", nil)
+         end
+         reset_parameters
+       end
+
+       def reset_parameters
+         if @affine
+           Init.ones!(@weight)
+           Init.zeros!(@bias)
+         end
+       end
+
+       def forward(input)
+         F.group_norm(input, @num_groups, weight: @weight, bias: @bias, eps: @eps)
+       end
+
+       def extra_inspect
+         format("%{num_groups}, %{num_channels}, eps: %{eps}, affine: %{affine}", **dict)
+       end
+     end
+   end
+ end
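
Usage sketch for the new GroupNorm module (not part of the diff; the tensor shapes and the Module#call dispatch to forward are assumptions for illustration):

  # hypothetical example: 6 channels split into 3 groups
  norm = Torch::NN::GroupNorm.new(3, 6)
  input = Torch.randn(20, 6, 10, 10)   # (batch, channels, height, width)
  output = norm.call(input)            # same shape as input, normalized within each group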

data/lib/torch/nn/gru.rb
@@ -0,0 +1,49 @@
+ module Torch
+   module NN
+     class GRU < RNNBase
+       def initialize(*args, **options)
+         super("GRU", *args, **options)
+       end
+
+       def run_impl(input, hx, batch_sizes)
+         if batch_sizes.nil?
+           Torch.gru(input, hx, _get_flat_weights, @bias, @num_layers,
+             @dropout, @training, @bidirectional, @batch_first)
+         else
+           Torch.gru(input, batch_sizes, hx, _get_flat_weights, @bias,
+             @num_layers, @dropout, @training, @bidirectional)
+         end
+       end
+
+       def forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+         if hx.nil?
+           num_directions = @bidirectional ? 2 : 1
+           hx = Torch.zeros(@num_layers * num_directions, max_batch_size, @hidden_size, dtype: input.dtype, device: input.device)
+         else
+           # Each batch of the hidden state should match the input sequence that
+           # the user believes he/she is passing in.
+           hx = permute_hidden(hx, sorted_indices)
+         end
+
+         check_forward_args(input, hx, batch_sizes)
+         result = run_impl(input, hx, batch_sizes)
+         output = result[0]
+         hidden = result[1]
+         [output, hidden]
+       end
+
+       def forward_tensor(input, hx: nil)
+         batch_sizes = nil
+         max_batch_size = @batch_first ? input.size(0) : input.size(1)
+         sorted_indices = nil
+         unsorted_indices = nil
+         output, hidden = forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+         [output, permute_hidden(hidden, unsorted_indices)]
+       end
+
+       def forward(input, hx: nil)
+         forward_tensor(input, hx: hx)
+       end
+     end
+   end
+ end
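
Usage sketch for the new GRU layer (not part of the diff; the (input_size, hidden_size) constructor arguments come from RNNBase and the shapes follow the PyTorch convention, both assumptions here):

  gru = Torch::NN::GRU.new(10, 20, num_layers: 2)   # input_size 10, hidden_size 20
  input = Torch.randn(5, 3, 10)                     # (seq_len, batch, input_size)
  output, h_n = gru.call(input)                     # output: (5, 3, 20), h_n: (2, 3, 20)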

data/lib/torch/nn/hardshrink.rb
@@ -0,0 +1,18 @@
+ module Torch
+   module NN
+     class Hardshrink < Module
+       def initialize(lambd: 0.5)
+         super()
+         @lambd = lambd
+       end
+
+       def forward(input)
+         F.hardshrink(input, @lambd)
+       end
+
+       def extra_inspect
+         @lambd.to_s
+       end
+     end
+   end
+ end
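
Usage sketch for the activation modules added in this release (Hardshrink shown; values within [-lambd, lambd] are zeroed, everything else passes through unchanged):

  act = Torch::NN::Hardshrink.new(lambd: 0.5)
  output = act.call(Torch.randn(2, 3))   # zeroes entries with absolute value <= 0.5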

data/lib/torch/nn/hinge_embedding_loss.rb
@@ -0,0 +1,14 @@
+ module Torch
+   module NN
+     class HingeEmbeddingLoss < Loss
+       def initialize(margin: 1.0, reduction: "mean")
+         super(reduction)
+         @margin = margin
+       end
+
+       def forward(input, target)
+         F.hinge_embedding_loss(input, target, margin: @margin, reduction: @reduction)
+       end
+     end
+   end
+ end

data/lib/torch/nn/identity.rb
@@ -0,0 +1,14 @@
+ module Torch
+   module NN
+     class Identity < Module
+       # written this way to support unused arguments
+       def initialize(*args, **options)
+         super()
+       end
+
+       def forward(input)
+         input
+       end
+     end
+   end
+ end

data/lib/torch/nn/init.rb
@@ -2,7 +2,64 @@ module Torch
    module NN
      module Init
        class << self
-         def calculate_fan_in_and_fan_out(tensor)
+         def calculate_gain(nonlinearity, param: 0.01)
+           _calculate_gain(nonlinearity, param)
+         end
+
+         def uniform!(tensor, a: 0.0, b: 1.0)
+           _uniform!(tensor, a, b)
+         end
+
+         def normal!(tensor, mean: 0.0, std: 1.0)
+           _normal!(tensor, mean, std)
+         end
+
+         def constant!(tensor, val)
+           _constant!(tensor, val)
+         end
+
+         def ones!(tensor)
+           _ones!(tensor)
+         end
+
+         def zeros!(tensor)
+           _zeros!(tensor)
+         end
+
+         def eye!(tensor)
+           _eye!(tensor)
+         end
+
+         def dirac!(tensor)
+           _dirac!(tensor)
+         end
+
+         def xavier_uniform!(tensor, gain: 1.0)
+           _xavier_uniform!(tensor, gain)
+         end
+
+         def xavier_normal!(tensor, gain: 1.0)
+           _xavier_normal!(tensor, gain)
+         end
+
+         def kaiming_uniform!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
+           _kaiming_uniform!(tensor, a, mode, nonlinearity)
+         end
+
+         def kaiming_normal!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
+           _kaiming_normal!(tensor, a, mode, nonlinearity)
+         end
+
+         def orthogonal!(tensor, gain: 1)
+           _orthogonal!(tensor, gain)
+         end
+
+         def sparse!(tensor, sparsity, std: 0.01)
+           _sparse!(tensor, sparsity, std)
+         end
+
+         # TODO move to C++ when released
+         def _calculate_fan_in_and_fan_out(tensor)
            dimensions = tensor.dim
            if dimensions < 2
              raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
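
The initializer helpers are now exposed with bang names and keyword arguments, delegating to the C++ implementations; a minimal sketch of the new API (Torch::Tensor.new allocates an uninitialized tensor, as used in the Linear and GroupNorm code in this release):

  w = Torch::Tensor.new(3, 5)
  Torch::NN::Init.kaiming_uniform!(w, a: Math.sqrt(5))
  b = Torch::Tensor.new(3)
  Torch::NN::Init.uniform!(b, a: -0.1, b: 0.1)
  Torch::NN::Init.zeros!(b)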

data/lib/torch/nn/instance_norm.rb
@@ -0,0 +1,20 @@
+ module Torch
+   module NN
+     class InstanceNorm < BatchNorm
+       def initialize(num_features, eps: 1e-5, momentum: 0.1, affine: false, track_running_stats: false)
+         super(num_features, eps: eps, momentum: momentum, affine: affine, track_running_stats: track_running_stats)
+       end
+
+       def forward(input)
+         _check_input_dim(input)
+
+         F.instance_norm(
+           input, running_mean: @running_mean, running_var: @running_var,
+           weight: @weight, bias: @bias,
+           use_input_stats: @training || !@track_running_stats,
+           momentum: @momentum, eps: @eps
+         )
+       end
+     end
+   end
+ end

data/lib/torch/nn/instance_norm1d.rb
@@ -0,0 +1,18 @@
+ module Torch
+   module NN
+     class InstanceNorm1d < InstanceNorm
+       def _check_input_dim(input)
+         if input.dim == 2
+           raise ArgumentError,
+             "InstanceNorm1d returns 0-filled tensor to 2D tensor. " +
+             "This is because InstanceNorm1d reshapes inputs to " +
+             "(1, N * C, ...) from (N, C, ...) and this makes " +
+             "variances 0."
+         end
+         if input.dim != 3
+           raise "expected 3D input (got #{input.dim}D input)"
+         end
+       end
+     end
+   end
+ end

data/lib/torch/nn/instance_norm2d.rb
@@ -0,0 +1,11 @@
+ module Torch
+   module NN
+     class InstanceNorm2d < InstanceNorm
+       def _check_input_dim(input)
+         if input.dim != 4
+           raise ArgumentError, "expected 4D input (got #{input.dim}D input)"
+         end
+       end
+     end
+   end
+ end

data/lib/torch/nn/instance_norm3d.rb
@@ -0,0 +1,11 @@
+ module Torch
+   module NN
+     class InstanceNorm3d < InstanceNorm
+       def _check_input_dim(input)
+         if input.dim != 5
+           raise ArgumentError, "expected 5D input (got #{input.dim}D input)"
+         end
+       end
+     end
+   end
+ end
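
The InstanceNorm1d/2d/3d subclasses differ only in the dimensionality check; a usage sketch (shapes assumed to follow the (N, C, ...) convention):

  norm = Torch::NN::InstanceNorm2d.new(100)   # 100 feature channels
  input = Torch.randn(20, 100, 35, 45)        # (N, C, H, W) passes the dim == 4 check
  output = norm.call(input)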

data/lib/torch/nn/kl_div_loss.rb
@@ -0,0 +1,13 @@
+ module Torch
+   module NN
+     class KLDivLoss < Loss
+       def initialize(reduction: "mean")
+         super(reduction)
+       end
+
+       def forward(input, target)
+         F.kl_div(input, target, reduction: @reduction)
+       end
+     end
+   end
+ end

data/lib/torch/nn/l1_loss.rb
@@ -0,0 +1,13 @@
+ module Torch
+   module NN
+     class L1Loss < Loss
+       def initialize(reduction: "mean")
+         super(reduction)
+       end
+
+       def forward(input, target)
+         F.l1_loss(input, target, reduction: @reduction)
+       end
+     end
+   end
+ end
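
The loss classes all follow the same pattern: store the reduction, then delegate to the functional API in forward. A minimal sketch (assuming Module#call forwards both arguments to forward):

  loss_fn = Torch::NN::L1Loss.new(reduction: "mean")
  input = Torch.randn(3, 5)
  target = Torch.randn(3, 5)
  loss = loss_fn.call(input, target)   # mean absolute error as a scalar tensor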

data/lib/torch/nn/layer_norm.rb
@@ -0,0 +1,35 @@
+ module Torch
+   module NN
+     class LayerNorm < Module
+       def initialize(normalized_shape, eps: 1e-5, elementwise_affine: true)
+         super()
+         @normalized_shape = Array(normalized_shape)
+         @eps = eps
+         @elementwise_affine = elementwise_affine
+         if @elementwise_affine
+           @weight = Parameter.new(Torch::Tensor.new(*normalized_shape))
+           @bias = Parameter.new(Torch::Tensor.new(*normalized_shape))
+         else
+           register_parameter("weight", nil)
+           register_parameter("bias", nil)
+         end
+         reset_parameters
+       end
+
+       def reset_parameters
+         if @elementwise_affine
+           Init.ones!(@weight)
+           Init.zeros!(@bias)
+         end
+       end
+
+       def forward(input)
+         F.layer_norm(input, @normalized_shape, weight: @weight, bias: @bias, eps: @eps)
+       end
+
+       def extra_inspect
+         format("%{normalized_shape}, eps: %{eps}, elementwise_affine: %{elementwise_affine}", **dict)
+       end
+     end
+   end
+ end
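
Usage sketch for LayerNorm; normalized_shape is wrapped with Array(), so an Integer or an array both work (the shapes here are illustrative assumptions):

  norm = Torch::NN::LayerNorm.new(10)   # equivalent to LayerNorm.new([10])
  input = Torch.randn(20, 5, 10)
  output = norm.call(input)             # normalized over the last dimension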

data/lib/torch/nn/leaky_relu.rb
@@ -0,0 +1,20 @@
+ module Torch
+   module NN
+     class LeakyReLU < Module
+       def initialize(negative_slope: 1e-2) #, inplace: false)
+         super()
+         @negative_slope = negative_slope
+         # @inplace = inplace
+       end
+
+       def forward(input)
+         F.leaky_relu(input, @negative_slope) #, inplace: @inplace)
+       end
+
+       def extra_inspect
+         inplace_str = @inplace ? ", inplace: true" : ""
+         format("negative_slope: %s%s", @negative_slope, inplace_str)
+       end
+     end
+   end
+ end

data/lib/torch/nn/linear.rb
@@ -1,35 +1,36 @@
  module Torch
    module NN
      class Linear < Module
-       attr_reader :bias, :weight
-
        def initialize(in_features, out_features, bias: true)
+         super()
          @in_features = in_features
          @out_features = out_features

          @weight = Parameter.new(Tensor.new(out_features, in_features))
          if bias
            @bias = Parameter.new(Tensor.new(out_features))
+         else
+           register_parameter("bias", nil)
          end

          reset_parameters
        end

-       def call(input)
-         F.linear(input, @weight, @bias)
-       end
-
        def reset_parameters
-         Init.kaiming_uniform_(@weight, Math.sqrt(5))
+         Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
          if @bias
-           fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
+           fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
            bound = 1 / Math.sqrt(fan_in)
-           Init.uniform_(@bias, -bound, bound)
+           Init.uniform!(@bias, a: -bound, b: bound)
          end
        end

-       def inspect
-         "Linear(in_features: #{@in_features.inspect}, out_features: #{@out_features.inspect}, bias: #{(!@bias.nil?).inspect})"
+       def forward(input)
+         F.linear(input, @weight, @bias)
+       end
+
+       def extra_inspect
+         format("in_features: %s, out_features: %s, bias: %s", @in_features, @out_features, !@bias.nil?)
        end
      end
    end
  end
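
The Linear rewrite moves the computation from call to forward and leans on the shared Module code (also expanded in this release) for call dispatch, parameter registration, and inspect output via extra_inspect. A sketch of how it is used after the change (assuming the generic Module#call):

  layer = Torch::NN::Linear.new(10, 5)
  output = layer.call(Torch.randn(2, 10))   # call dispatches to forward -> F.linear
  layer.inspect                             # rendered by Module from extra_inspect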

data/lib/torch/nn/local_response_norm.rb
@@ -0,0 +1,21 @@
+ module Torch
+   module NN
+     class LocalResponseNorm < Module
+       def initialize(size, alpha: 1e-4, beta: 0.75, k: 1.0)
+         super()
+         @size = size
+         @alpha = alpha
+         @beta = beta
+         @k = k
+       end
+
+       def forward(input)
+         F.local_response_norm(input, @size, alpha: @alpha, beta: @beta, k: @k)
+       end
+
+       def extra_inspect
+         format("%{size}, alpha: %{alpha}, beta: %{beta}, k: %{k}", **dict)
+       end
+     end
+   end
+ end

data/lib/torch/nn/log_sigmoid.rb
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class LogSigmoid < Module
+       def forward(input)
+         F.log_sigmoid(input)
+       end
+     end
+   end
+ end

data/lib/torch/nn/log_softmax.rb
@@ -0,0 +1,14 @@
+ module Torch
+   module NN
+     class LogSoftmax < Module
+       def initialize(dim: nil)
+         super()
+         @dim = dim
+       end
+
+       def forward(input)
+         F.log_softmax(input, @dim)
+       end
+     end
+   end
+ end
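
A short sketch combining LogSoftmax with NLLLoss (nll_loss.rb is also added in this release, per the file list); the dim: keyword selects the axis along which the exponentiated outputs sum to one. The shapes and the Torch.tensor call are illustrative assumptions:

  log_softmax = Torch::NN::LogSoftmax.new(dim: 1)
  nll = Torch::NN::NLLLoss.new
  input = Torch.randn(3, 5)            # (batch, classes)
  target = Torch.tensor([1, 0, 4])     # class indices
  loss = nll.call(log_softmax.call(input), target)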