torch-rb 0.1.2 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +18 -6
  5. data/ext/torch/ext.cpp +148 -369
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +242 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +240 -131
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +27 -22
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -38
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +411 -22
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +201 -20
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +2 -2
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +56 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +48 -16
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +71 -30
  139. data/lib/torch/utils/data/data_loader.rb +10 -4
  140. data/lib/torch/utils/data/tensor_dataset.rb +3 -0
  141. data/lib/torch/version.rb +1 -1
  142. metadata +123 -6
data/lib/torch/nn/mse_loss.rb CHANGED
@@ -1,8 +1,8 @@
 module Torch
   module NN
-    class MSELoss < Module
+    class MSELoss < Loss
       def initialize(reduction: "mean")
-        @reduction = reduction
+        super(reduction)
       end
 
       def forward(input, target)
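Illustrative usage sketch (not part of the released files): the reworked loss classes keep the reduction: keyword while sharing the new Loss base class. The sketch assumes Torch.tensor for construction and that Module#call delegates to forward, as in data/lib/torch/nn/module.rb.

# minimal MSELoss usage sketch; the values below are examples, not from the diff
input  = Torch.tensor([[0.5, 1.0], [2.0, 3.0]])
target = Torch.tensor([[0.0, 1.0], [2.0, 2.0]])
loss_fn = Torch::NN::MSELoss.new(reduction: "mean")
loss_fn.call(input, target) # scalar tensor holding the mean squared error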
data/lib/torch/nn/multi_label_margin_loss.rb ADDED
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class MultiLabelMarginLoss < Loss
+      def initialize(reduction: "mean")
+        super(reduction)
+      end
+
+      def forward(input, target)
+        F.multilabel_margin_loss(input, target, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/multi_label_soft_margin_loss.rb ADDED
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class MultiLabelSoftMarginLoss < WeightedLoss
+      def initialize(weight: nil, reduction: "mean")
+        super(weight, reduction)
+      end
+
+      def forward(input, target)
+        F.multilabel_soft_margin_loss(input, target, weight: @weight, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/multi_margin_loss.rb ADDED
@@ -0,0 +1,17 @@
+module Torch
+  module NN
+    class MultiMarginLoss < WeightedLoss
+      def initialize(p: 1, margin: 1.0, weight: nil, reduction: "mean")
+        super(weight, reduction)
+        raise ArgumentError, "only p == 1 and p == 2 supported" if p != 1 && p != 2
+        raise ArgumentError, "weight must be nil or have one dimension" unless weight.nil? || weight.dim == 1
+        @p = p
+        @margin = margin
+      end
+
+      def forward(input, target)
+        F.multi_margin_loss(input, target, p: @p, margin: @margin, weight: @weight, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/nll_loss.rb ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class NLLLoss < WeightedLoss
+      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
+        super(weight, reduction)
+        @ignore_index = ignore_index
+      end
+
+      def forward(input, target)
+        F.nll_loss(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
+      end
+    end
+  end
+end
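A sketch of how the new NLLLoss might be used (illustrative, not from the diff). It assumes the input holds log-probabilities, for example from the LogSoftmax module added in this release, and that LogSoftmax accepts a dim: option; targets are integer class indices, with ignore_index entries skipped.

# hypothetical pairing of LogSoftmax and NLLLoss
log_probs = Torch::NN::LogSoftmax.new(dim: 1).call(Torch.randn(3, 5))
target = Torch.tensor([1, 0, 4])
Torch::NN::NLLLoss.new(ignore_index: -100).call(log_probs, target)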
data/lib/torch/nn/pairwise_distance.rb ADDED
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class PairwiseDistance < Module
+      def initialize(p: 2.0, eps: 1e-6, keepdim: false)
+        super()
+        @norm = p
+        @eps = eps
+        @keepdim = keepdim
+      end
+
+      def forward(x1, x2)
+        F.pairwise_distance(x1, x2, p: @norm, eps: @eps, keepdim: @keepdim)
+      end
+    end
+  end
+end
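Illustrative sketch (not from the diff): PairwiseDistance computes the p-norm distance between corresponding rows of two equally shaped batches.

x1 = Torch.tensor([[1.0, 2.0], [3.0, 4.0]])
x2 = Torch.tensor([[1.0, 0.0], [0.0, 4.0]])
Torch::NN::PairwiseDistance.new(p: 2.0).call(x1, x2) # 1-D tensor, one distance per row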
data/lib/torch/nn/parameter.rb CHANGED
@@ -6,8 +6,8 @@ module Torch
         Tensor._make_subclass(data, requires_grad)
       end
 
-      def grad
-        _grad if _grad_defined
+      def inspect
+        "Parameter containing:\n#{super}"
       end
     end
   end
data/lib/torch/nn/poisson_nll_loss.rb ADDED
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class PoissonNLLLoss < Loss
+      def initialize(log_input: true, full: false, eps: 1e-8, reduction: "mean")
+        super(reduction)
+        @log_input = log_input
+        @full = full
+        @eps = eps
+      end
+
+      def forward(log_input, target)
+        F.poisson_nll_loss(log_input, target, log_input: @log_input, full: @full, eps: @eps, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/prelu.rb ADDED
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class PReLU < Module
+      def initialize(num_parameters: 1, init: 0.25)
+        @num_parameters = num_parameters
+        super()
+        @weight = Parameter.new(Tensor.new(num_parameters).fill!(init))
+      end
+
+      def forward(input)
+        F.prelu(input, @weight)
+      end
+
+      def extra_inspect
+        format("num_parameters: %s", @num_parameters)
+      end
+    end
+  end
+end
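Illustrative sketch (not from the diff): PReLU's negative slope is a learnable Parameter initialized to init and shared across all inputs when num_parameters is 1.

act = Torch::NN::PReLU.new(num_parameters: 1, init: 0.25)
act.call(Torch.tensor([-2.0, 0.0, 3.0])) # roughly [-0.5, 0.0, 3.0]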
data/lib/torch/nn/reflection_pad1d.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class ReflectionPad1d < ReflectionPadNd
+      def initialize(padding)
+        super()
+        @padding = _pair(padding)
+      end
+    end
+  end
+end
data/lib/torch/nn/reflection_pad2d.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class ReflectionPad2d < ReflectionPadNd
+      def initialize(padding)
+        super()
+        @padding = _quadrupal(padding)
+      end
+    end
+  end
+end
data/lib/torch/nn/reflection_padnd.rb ADDED
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class ReflectionPadNd < Module
+      def forward(input)
+        F.pad(input, @padding, mode: "reflect")
+      end
+
+      def extra_inspect
+        @padding.inspect
+      end
+    end
+  end
+end
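Illustrative sketch (not from the diff): ReflectionPad1d mirrors values at the border of the last dimension by delegating to F.pad with mode: "reflect"; input is expected in (batch, channel, width) layout.

pad = Torch::NN::ReflectionPad1d.new(2)
pad.call(Torch.tensor([[[1.0, 2.0, 3.0, 4.0]]]))
# expected result: [[[3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 2.0]]]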
data/lib/torch/nn/relu.rb CHANGED
@@ -1,12 +1,17 @@
 module Torch
   module NN
     class ReLU < Module
-      def initialize #(inplace: false)
-        # @inplace = inplace
+      def initialize(inplace: false)
+        super()
+        @inplace = inplace
       end
 
       def forward(input)
-        F.relu(input) #, inplace: @inplace)
+        F.relu(input, inplace: @inplace)
+      end
+
+      def extra_inspect
+        @inplace ? "inplace: true" : ""
       end
     end
   end
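Illustrative sketch (not from the diff): with the new keyword, inplace: true clamps the input tensor in place instead of allocating a new result.

x = Torch.tensor([-1.0, 0.5, -2.0])
Torch::NN::ReLU.new(inplace: true).call(x)
# x itself is now [0.0, 0.5, 0.0]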
data/lib/torch/nn/replication_pad1d.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class ReplicationPad1d < ReplicationPadNd
+      def initialize(padding)
+        super()
+        @padding = _pair(padding)
+      end
+    end
+  end
+end
data/lib/torch/nn/replication_pad2d.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class ReplicationPad2d < ReplicationPadNd
+      def initialize(padding)
+        super()
+        @padding = _quadrupal(padding)
+      end
+    end
+  end
+end
data/lib/torch/nn/replication_pad3d.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class ReplicationPad3d < ReplicationPadNd
+      def initialize(padding)
+        super()
+        @padding = _ntuple(6, padding)
+      end
+    end
+  end
+end
data/lib/torch/nn/replication_padnd.rb ADDED
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class ReplicationPadNd < Module
+      def forward(input)
+        F.pad(input, @padding, mode: "replicate")
+      end
+
+      def extra_inspect
+        @padding.inspect
+      end
+    end
+  end
+end
data/lib/torch/nn/rnn.rb ADDED
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class RNN < RNNBase
+      def initialize(*args, **options)
+        if options.key?(:nonlinearity)
+          if options[:nonlinearity] == "tanh"
+            mode = "RNN_TANH"
+          elsif options[:nonlinearity] == "relu"
+            mode = "RNN_RELU"
+          else
+            raise ArgumentError, "Unknown nonlinearity: #{options[:nonlinearity]}"
+          end
+          options.delete(:nonlinearity)
+        else
+          mode = "RNN_TANH"
+        end
+
+        super(mode, *args, **options)
+      end
+    end
+  end
+end
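Illustrative sketch (not from the diff): the nonlinearity: keyword selects RNN_TANH or RNN_RELU before delegating to RNNBase; the remaining arguments follow RNNBase#initialize (input_size, hidden_size, ...).

rnn = Torch::NN::RNN.new(10, 20, num_layers: 2, nonlinearity: "relu")
input = Torch.randn(5, 3, 10)      # (seq_len, batch, input_size)
output, hidden = rnn.call(input)   # output (5, 3, 20), hidden (2, 3, 20)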
data/lib/torch/nn/rnn_base.rb ADDED
@@ -0,0 +1,198 @@
+module Torch
+  module NN
+    class RNNBase < Module
+      def initialize(mode, input_size, hidden_size, num_layers: 1, bias: true,
+        batch_first: false, dropout: 0.0, bidirectional: false)
+
+        super()
+        @mode = mode
+        @input_size = input_size
+        @hidden_size = hidden_size
+        @num_layers = num_layers
+        @bias = bias
+        @batch_first = batch_first
+        @dropout = dropout.to_f
+        @bidirectional = bidirectional
+        num_directions = bidirectional ? 2 : 1
+
+        if !dropout.is_a?(Numeric) || !(dropout >= 0 && dropout <= 1)
+          raise ArgumentError, "dropout should be a number in range [0, 1] " +
+                               "representing the probability of an element being " +
+                               "zeroed"
+        end
+        if dropout > 0 && num_layers == 1
+          warn "dropout option adds dropout after all but last " +
+               "recurrent layer, so non-zero dropout expects " +
+               "num_layers greater than 1, but got dropout=#{dropout} and " +
+               "num_layers=#{num_layers}"
+        end
+
+        gate_size =
+          case mode
+          when "LSTM"
+            4 * hidden_size
+          when "GRU"
+            3 * hidden_size
+          when "RNN_TANH"
+            hidden_size
+          when "RNN_RELU"
+            hidden_size
+          else
+            raise ArgumentError, "Unrecognized RNN mode: #{mode}"
+          end
+
+        @all_weights = []
+        num_layers.times do |layer|
+          num_directions.times do |direction|
+            layer_input_size = layer == 0 ? input_size : hidden_size * num_directions
+
+            w_ih = Parameter.new(Torch::Tensor.new(gate_size, layer_input_size))
+            w_hh = Parameter.new(Torch::Tensor.new(gate_size, hidden_size))
+            b_ih = Parameter.new(Torch::Tensor.new(gate_size))
+            # Second bias vector included for CuDNN compatibility. Only one
+            # bias vector is needed in standard definition.
+            b_hh = Parameter.new(Torch::Tensor.new(gate_size))
+            layer_params = [w_ih, w_hh, b_ih, b_hh]
+
+            suffix = direction == 1 ? "_reverse" : ""
+            param_names = ["weight_ih_l%s%s", "weight_hh_l%s%s"]
+            if bias
+              param_names += ["bias_ih_l%s%s", "bias_hh_l%s%s"]
+            end
+            param_names.map! { |x| x % [layer, suffix] }
+
+            param_names.zip(layer_params) do |name, param|
+              instance_variable_set("@#{name}", param)
+            end
+            @all_weights << param_names
+          end
+        end
+
+        flatten_parameters
+        reset_parameters
+      end
+
+      def flatten_parameters
+        # no-op unless module is on the GPU and cuDNN is enabled
+      end
+
+      def _apply(fn)
+        ret = super
+        flatten_parameters
+        ret
+      end
+
+      def reset_parameters
+        stdv = 1.0 / Math.sqrt(@hidden_size)
+        parameters.each do |weight|
+          Init.uniform!(weight, a: -stdv, b: stdv)
+        end
+      end
+
+      def permute_hidden(hx, permutation)
+        if permutation.nil?
+          return hx
+        end
+        raise NotImplementedYet
+      end
+
+      def forward(input, hx: nil)
+        is_packed = false # TODO isinstance(input, PackedSequence)
+        if is_packed
+          input, batch_sizes, sorted_indices, unsorted_indices = input
+          max_batch_size = batch_sizes[0]
+          max_batch_size = max_batch_size.to_i
+        else
+          batch_sizes = nil
+          max_batch_size = @batch_first ? input.size(0) : input.size(1)
+          sorted_indices = nil
+          unsorted_indices = nil
+        end
+
+        if hx.nil?
+          num_directions = @bidirectional ? 2 : 1
+          hx = Torch.zeros(@num_layers * num_directions, max_batch_size,
+            @hidden_size, dtype: input.dtype, device: input.device)
+        else
+          # Each batch of the hidden state should match the input sequence that
+          # the user believes he/she is passing in.
+          hx = permute_hidden(hx, sorted_indices)
+        end
+
+        check_forward_args(input, hx, batch_sizes)
+        _rnn_impls = {
+          "RNN_TANH" => Torch.method(:rnn_tanh),
+          "RNN_RELU" => Torch.method(:rnn_relu)
+        }
+        _impl = _rnn_impls[@mode]
+        if batch_sizes.nil?
+          result = _impl.call(input, hx, _get_flat_weights, @bias, @num_layers,
+            @dropout, @training, @bidirectional, @batch_first)
+        else
+          result = _impl.call(input, batch_sizes, hx, _get_flat_weights, @bias,
+            @num_layers, @dropout, @training, @bidirectional)
+        end
+        output = result[0]
+        hidden = result[1]
+
+        if is_packed
+          raise NotImplementedYet
+          # output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
+        end
+        [output, permute_hidden(hidden, unsorted_indices)]
+      end
+
+      # TODO add more parameters
+      def extra_inspect
+        s = String.new("%{input_size}, %{hidden_size}")
+        if @num_layers != 1
+          s += ", num_layers: %{num_layers}"
+        end
+        format(s, input_size: @input_size, hidden_size: @hidden_size, num_layers: @num_layers)
+      end
+
+      private
+
+      def _flat_weights
+        @all_weights.flatten.map { |v| instance_variable_get("@#{v}") }.compact
+      end
+
+      def _get_flat_weights
+        _flat_weights
+      end
+
+      def check_input(input, batch_sizes)
+        expected_input_dim = !batch_sizes.nil? ? 2 : 3
+        if input.dim != expected_input_dim
+          raise ArgumentError, "input must have #{expected_input_dim} dimensions, got #{input.dim}"
+        end
+        if @input_size != input.size(-1)
+          raise ArgumentError, "input.size(-1) must be equal to input_size. Expected #{@input_size}, got #{input.size(-1)}"
+        end
+      end
+
+      def get_expected_hidden_size(input, batch_sizes)
+        if !batch_sizes.nil?
+          mini_batch = batch_sizes[0]
+          mini_batch = mini_batch.to_i
+        else
+          mini_batch = @batch_first ? input.size(0) : input.size(1)
+        end
+        num_directions = @bidirectional ? 2 : 1
+        [@num_layers * num_directions, mini_batch, @hidden_size]
+      end
+
+      def check_hidden_size(hx, expected_hidden_size)
+        if hx.size != expected_hidden_size
+          raise ArgumentError, "Expected hidden size #{expected_hidden_size.inspect}, got #{hx.size.inspect}"
+        end
+      end
+
+      def check_forward_args(input, hidden, batch_sizes)
+        check_input(input, batch_sizes)
+        expected_hidden_size = get_expected_hidden_size(input, batch_sizes)
+        check_hidden_size(hidden, expected_hidden_size)
+      end
+    end
+  end
+end
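Illustrative sketch of the shape contract enforced by check_forward_args (not from the diff): a non-packed input must be 3-dimensional with input_size as its last dimension, and an explicit hidden state must have shape (num_layers * num_directions, batch, hidden_size). forward is called directly here since it takes hx as a keyword argument.

rnn = Torch::NN::RNN.new(4, 8)     # input_size 4, hidden_size 8
input = Torch.randn(7, 2, 4)       # (seq_len, batch, input_size)
h0 = Torch.zeros(1, 2, 8)          # (num_layers * num_directions, batch, hidden_size)
output, hn = rnn.forward(input, hx: h0)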