torch-rb 0.1.2 → 0.1.7

Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +18 -6
  5. data/ext/torch/ext.cpp +148 -369
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +242 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +240 -131
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +27 -22
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -38
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +411 -22
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +201 -20
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +2 -2
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +56 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +48 -16
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +71 -30
  139. data/lib/torch/utils/data/data_loader.rb +10 -4
  140. data/lib/torch/utils/data/tensor_dataset.rb +3 -0
  141. data/lib/torch/version.rb +1 -1
  142. metadata +123 -6
data/lib/torch/nn/mse_loss.rb CHANGED
@@ -1,8 +1,8 @@
 module Torch
   module NN
-    class MSELoss < Module
+    class MSELoss < Loss
       def initialize(reduction: "mean")
-        @reduction = reduction
+        super(reduction)
       end
 
       def forward(input, target)
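Note: MSELoss now inherits from the shared Loss base class, which stores @reduction. A minimal usage sketch (assuming modules are invoked via call, which dispatches to forward, and that Torch.tensor builds tensors as in the README):

  # hypothetical example, not part of the diff
  input = Torch.tensor([[0.5, 1.0], [2.0, 3.0]], requires_grad: true)
  target = Torch.tensor([[1.0, 1.0], [2.0, 2.5]])
  loss = Torch::NN::MSELoss.new(reduction: "mean").call(input, target)
  loss.backward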
data/lib/torch/nn/multi_label_margin_loss.rb ADDED
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class MultiLabelMarginLoss < Loss
+      def initialize(reduction: "mean")
+        super(reduction)
+      end
+
+      def forward(input, target)
+        F.multilabel_margin_loss(input, target, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/multi_label_soft_margin_loss.rb ADDED
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class MultiLabelSoftMarginLoss < WeightedLoss
+      def initialize(weight: nil, reduction: "mean")
+        super(weight, reduction)
+      end
+
+      def forward(input, target)
+        F.multilabel_soft_margin_loss(input, target, weight: @weight, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/multi_margin_loss.rb ADDED
@@ -0,0 +1,17 @@
+module Torch
+  module NN
+    class MultiMarginLoss < WeightedLoss
+      def initialize(p: 1, margin: 1.0, weight: nil, reduction: "mean")
+        super(weight, reduction)
+        raise ArgumentError, "only p == 1 and p == 2 supported" if p != 1 && p != 2
+        raise ArgumentError, "weight must be nil or have one dimension" unless weight.nil? || weight.dim == 1
+        @p = p
+        @margin = margin
+      end
+
+      def forward(input, target)
+        F.multi_margin_loss(input, target, p: @p, margin: @margin, weight: @weight, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/nll_loss.rb ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class NLLLoss < WeightedLoss
+      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
+        super(weight, reduction)
+        @ignore_index = ignore_index
+      end
+
+      def forward(input, target)
+        F.nll_loss(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
+      end
+    end
+  end
+end
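Note: NLLLoss expects log-probabilities, so it is normally paired with LogSoftmax (also added in this release) or F.log_softmax. A rough sketch, assuming LogSoftmax takes a dim: keyword and that integer class targets are built with Torch.tensor:

  # hypothetical example, not part of the diff
  log_probs = Torch::NN::LogSoftmax.new(dim: 1).call(Torch.randn(3, 5))
  target = Torch.tensor([1, 0, 4])
  loss = Torch::NN::NLLLoss.new(ignore_index: -100).call(log_probs, target)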
data/lib/torch/nn/pairwise_distance.rb ADDED
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class PairwiseDistance < Module
+      def initialize(p: 2.0, eps: 1e-6, keepdim: false)
+        super()
+        @norm = p
+        @eps = eps
+        @keepdim = keepdim
+      end
+
+      def forward(x1, x2)
+        F.pairwise_distance(x1, x2, p: @norm, eps: @eps, keepdim: @keepdim)
+      end
+    end
+  end
+end
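Note: PairwiseDistance computes the batched p-norm distance between corresponding rows of its two inputs. A rough sketch, assuming call dispatches to forward:

  # hypothetical example, not part of the diff
  x1 = Torch.randn(4, 8)
  x2 = Torch.randn(4, 8)
  dist = Torch::NN::PairwiseDistance.new(p: 2.0).call(x1, x2)  # one distance per row, 4 values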
data/lib/torch/nn/parameter.rb CHANGED
@@ -6,8 +6,8 @@ module Torch
         Tensor._make_subclass(data, requires_grad)
       end
 
-      def grad
-        _grad if _grad_defined
+      def inspect
+        "Parameter containing:\n#{super}"
       end
     end
   end
data/lib/torch/nn/poisson_nll_loss.rb ADDED
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class PoissonNLLLoss < Loss
+      def initialize(log_input: true, full: false, eps: 1e-8, reduction: "mean")
+        super(reduction)
+        @log_input = log_input
+        @full = full
+        @eps = eps
+      end
+
+      def forward(log_input, target)
+        F.poisson_nll_loss(log_input, target, log_input: @log_input, full: @full, eps: @eps, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/prelu.rb ADDED
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class PReLU < Module
+      def initialize(num_parameters: 1, init: 0.25)
+        @num_parameters = num_parameters
+        super()
+        @weight = Parameter.new(Tensor.new(num_parameters).fill!(init))
+      end
+
+      def forward(input)
+        F.prelu(input, @weight)
+      end
+
+      def extra_inspect
+        format("num_parameters: %s", @num_parameters)
+      end
+    end
+  end
+end
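Note: PReLU stores its negative slope as a learnable Parameter of size num_parameters, filled with init (0.25 by default). A rough sketch, assuming call dispatches to forward:

  # hypothetical example, not part of the diff
  act = Torch::NN::PReLU.new
  out = act.call(Torch.tensor([-1.0, 0.0, 2.0]))
  # negative inputs are scaled by the learned slope, 0.25 at initialization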
data/lib/torch/nn/reflection_pad1d.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class ReflectionPad1d < ReflectionPadNd
+      def initialize(padding)
+        super()
+        @padding = _pair(padding)
+      end
+    end
+  end
+end
data/lib/torch/nn/reflection_pad2d.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class ReflectionPad2d < ReflectionPadNd
+      def initialize(padding)
+        super()
+        @padding = _quadrupal(padding)
+      end
+    end
+  end
+end
data/lib/torch/nn/reflection_padnd.rb ADDED
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class ReflectionPadNd < Module
+      def forward(input)
+        F.pad(input, @padding, mode: "reflect")
+      end
+
+      def extra_inspect
+        @padding.inspect
+      end
+    end
+  end
+end
data/lib/torch/nn/relu.rb CHANGED
@@ -1,12 +1,17 @@
 module Torch
   module NN
     class ReLU < Module
-      def initialize #(inplace: false)
-        # @inplace = inplace
+      def initialize(inplace: false)
+        super()
+        @inplace = inplace
       end
 
       def forward(input)
-        F.relu(input) #, inplace: @inplace)
+        F.relu(input, inplace: @inplace)
+      end
+
+      def extra_inspect
+        @inplace ? "inplace: true" : ""
       end
     end
   end
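Note: ReLU now accepts an inplace option that is forwarded to F.relu. A rough sketch (the inplace form mutates the input tensor rather than allocating a new one):

  # hypothetical example, not part of the diff
  x = Torch.tensor([-1.0, 0.5, 2.0])
  Torch::NN::ReLU.new(inplace: true).call(x)
  # x now holds [0.0, 0.5, 2.0]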
data/lib/torch/nn/replication_pad1d.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class ReplicationPad1d < ReplicationPadNd
+      def initialize(padding)
+        super()
+        @padding = _pair(padding)
+      end
+    end
+  end
+end
data/lib/torch/nn/replication_pad2d.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class ReplicationPad2d < ReplicationPadNd
+      def initialize(padding)
+        super()
+        @padding = _quadrupal(padding)
+      end
+    end
+  end
+end
data/lib/torch/nn/replication_pad3d.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class ReplicationPad3d < ReplicationPadNd
+      def initialize(padding)
+        super()
+        @padding = _ntuple(6, padding)
+      end
+    end
+  end
+end
data/lib/torch/nn/replication_padnd.rb ADDED
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class ReplicationPadNd < Module
+      def forward(input)
+        F.pad(input, @padding, mode: "replicate")
+      end
+
+      def extra_inspect
+        @padding.inspect
+      end
+    end
+  end
+end
data/lib/torch/nn/rnn.rb ADDED
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class RNN < RNNBase
+      def initialize(*args, **options)
+        if options.key?(:nonlinearity)
+          if options[:nonlinearity] == "tanh"
+            mode = "RNN_TANH"
+          elsif options[:nonlinearity] == "relu"
+            mode = "RNN_RELU"
+          else
+            raise ArgumentError, "Unknown nonlinearity: #{options[:nonlinearity]}"
+          end
+          options.delete(:nonlinearity)
+        else
+          mode = "RNN_TANH"
+        end
+
+        super(mode, *args, **options)
+      end
+    end
+  end
+end
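Note: RNN only translates the nonlinearity option into an RNNBase mode string and passes the remaining arguments through. A rough sketch of the constructor:

  # hypothetical example, not part of the diff
  rnn = Torch::NN::RNN.new(10, 20, num_layers: 2, nonlinearity: "relu")  # mode RNN_RELU
  Torch::NN::RNN.new(10, 20, nonlinearity: "sigmoid")                    # raises ArgumentError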
data/lib/torch/nn/rnn_base.rb ADDED
@@ -0,0 +1,198 @@
+module Torch
+  module NN
+    class RNNBase < Module
+      def initialize(mode, input_size, hidden_size, num_layers: 1, bias: true,
+        batch_first: false, dropout: 0.0, bidirectional: false)
+
+        super()
+        @mode = mode
+        @input_size = input_size
+        @hidden_size = hidden_size
+        @num_layers = num_layers
+        @bias = bias
+        @batch_first = batch_first
+        @dropout = dropout.to_f
+        @bidirectional = bidirectional
+        num_directions = bidirectional ? 2 : 1
+
+        if !dropout.is_a?(Numeric) || !(dropout >= 0 && dropout <= 1)
+          raise ArgumentError, "dropout should be a number in range [0, 1] " +
+            "representing the probability of an element being " +
+            "zeroed"
+        end
+        if dropout > 0 && num_layers == 1
+          warn "dropout option adds dropout after all but last " +
+            "recurrent layer, so non-zero dropout expects " +
+            "num_layers greater than 1, but got dropout=#{dropout} and " +
+            "num_layers=#{num_layers}"
+        end
+
+        gate_size =
+          case mode
+          when "LSTM"
+            4 * hidden_size
+          when "GRU"
+            3 * hidden_size
+          when "RNN_TANH"
+            hidden_size
+          when "RNN_RELU"
+            hidden_size
+          else
+            raise ArgumentError, "Unrecognized RNN mode: #{mode}"
+          end
+
+        @all_weights = []
+        num_layers.times do |layer|
+          num_directions.times do |direction|
+            layer_input_size = layer == 0 ? input_size : hidden_size * num_directions
+
+            w_ih = Parameter.new(Torch::Tensor.new(gate_size, layer_input_size))
+            w_hh = Parameter.new(Torch::Tensor.new(gate_size, hidden_size))
+            b_ih = Parameter.new(Torch::Tensor.new(gate_size))
+            # Second bias vector included for CuDNN compatibility. Only one
+            # bias vector is needed in standard definition.
+            b_hh = Parameter.new(Torch::Tensor.new(gate_size))
+            layer_params = [w_ih, w_hh, b_ih, b_hh]
+
+            suffix = direction == 1 ? "_reverse" : ""
+            param_names = ["weight_ih_l%s%s", "weight_hh_l%s%s"]
+            if bias
+              param_names += ["bias_ih_l%s%s", "bias_hh_l%s%s"]
+            end
+            param_names.map! { |x| x % [layer, suffix] }
+
+            param_names.zip(layer_params) do |name, param|
+              instance_variable_set("@#{name}", param)
+            end
+            @all_weights << param_names
+          end
+        end
+
+        flatten_parameters
+        reset_parameters
+      end
+
+      def flatten_parameters
+        # no-op unless module is on the GPU and cuDNN is enabled
+      end
+
+      def _apply(fn)
+        ret = super
+        flatten_parameters
+        ret
+      end
+
+      def reset_parameters
+        stdv = 1.0 / Math.sqrt(@hidden_size)
+        parameters.each do |weight|
+          Init.uniform!(weight, a: -stdv, b: stdv)
+        end
+      end
+
+      def permute_hidden(hx, permutation)
+        if permutation.nil?
+          return hx
+        end
+        raise NotImplementedYet
+      end
+
+      def forward(input, hx: nil)
+        is_packed = false # TODO isinstance(input, PackedSequence)
+        if is_packed
+          input, batch_sizes, sorted_indices, unsorted_indices = input
+          max_batch_size = batch_sizes[0]
+          max_batch_size = max_batch_size.to_i
+        else
+          batch_sizes = nil
+          max_batch_size = @batch_first ? input.size(0) : input.size(1)
+          sorted_indices = nil
+          unsorted_indices = nil
+        end
+
+        if hx.nil?
+          num_directions = @bidirectional ? 2 : 1
+          hx = Torch.zeros(@num_layers * num_directions, max_batch_size,
+            @hidden_size, dtype: input.dtype, device: input.device)
+        else
+          # Each batch of the hidden state should match the input sequence that
+          # the user believes he/she is passing in.
+          hx = permute_hidden(hx, sorted_indices)
+        end
+
+        check_forward_args(input, hx, batch_sizes)
+        _rnn_impls = {
+          "RNN_TANH" => Torch.method(:rnn_tanh),
+          "RNN_RELU" => Torch.method(:rnn_relu)
+        }
+        _impl = _rnn_impls[@mode]
+        if batch_sizes.nil?
+          result = _impl.call(input, hx, _get_flat_weights, @bias, @num_layers,
+            @dropout, @training, @bidirectional, @batch_first)
+        else
+          result = _impl.call(input, batch_sizes, hx, _get_flat_weights, @bias,
+            @num_layers, @dropout, @training, @bidirectional)
+        end
+        output = result[0]
+        hidden = result[1]
+
+        if is_packed
+          raise NotImplementedYet
+          # output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
+        end
+        [output, permute_hidden(hidden, unsorted_indices)]
+      end
+
+      # TODO add more parameters
+      def extra_inspect
+        s = String.new("%{input_size}, %{hidden_size}")
+        if @num_layers != 1
+          s += ", num_layers: %{num_layers}"
+        end
+        format(s, input_size: @input_size, hidden_size: @hidden_size, num_layers: @num_layers)
+      end
+
+      private
+
+      def _flat_weights
+        @all_weights.flatten.map { |v| instance_variable_get("@#{v}") }.compact
+      end
+
+      def _get_flat_weights
+        _flat_weights
+      end
+
+      def check_input(input, batch_sizes)
+        expected_input_dim = !batch_sizes.nil? ? 2 : 3
+        if input.dim != expected_input_dim
+          raise ArgumentError, "input must have #{expected_input_dim} dimensions, got #{input.dim}"
+        end
+        if @input_size != input.size(-1)
+          raise ArgumentError, "input.size(-1) must be equal to input_size. Expected #{@input_size}, got #{input.size(-1)}"
+        end
+      end
+
+      def get_expected_hidden_size(input, batch_sizes)
+        if !batch_sizes.nil?
+          mini_batch = batch_sizes[0]
+          mini_batch = mini_batch.to_i
+        else
+          mini_batch = @batch_first ? input.size(0) : input.size(1)
+        end
+        num_directions = @bidirectional ? 2 : 1
+        [@num_layers * num_directions, mini_batch, @hidden_size]
+      end
+
+      def check_hidden_size(hx, expected_hidden_size)
+        if hx.size != expected_hidden_size
+          raise ArgumentError, "Expected hidden size #{expected_hidden_size.inspect}, got #{hx.size.inspect}"
+        end
+      end
+
+      def check_forward_args(input, hidden, batch_sizes)
+        check_input(input, batch_sizes)
+        expected_hidden_size = get_expected_hidden_size(input, batch_sizes)
+        check_hidden_size(hidden, expected_hidden_size)
+      end
+    end
+  end
+end
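Note: RNNBase#forward returns [output, hidden]; with batch_first: false (the default) the input is expected as (seq_len, batch, input_size), and a zero hidden state is created when hx is nil. A rough end-to-end sketch, assuming call dispatches to forward:

  # hypothetical example, not part of the diff
  rnn = Torch::NN::RNN.new(10, 20)   # input_size: 10, hidden_size: 20
  input = Torch.randn(5, 3, 10)      # seq_len: 5, batch: 3, features: 10
  output, hidden = rnn.call(input)
  # output: (5, 3, 20), hidden: (1, 3, 20)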