torch-rb 0.1.2 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +35 -0
- data/LICENSE.txt +46 -22
- data/README.md +18 -6
- data/ext/torch/ext.cpp +148 -369
- data/ext/torch/extconf.rb +6 -0
- data/ext/torch/nn_functions.cpp +615 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +242 -0
- data/ext/torch/tensor_functions.cpp +1920 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2975 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +240 -131
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +27 -22
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +109 -0
- data/lib/torch/native/generator.rb +168 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +134 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +19 -0
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +16 -38
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +411 -22
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +201 -20
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +2 -2
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +198 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +56 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +48 -16
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +71 -30
- data/lib/torch/utils/data/data_loader.rb +10 -4
- data/lib/torch/utils/data/tensor_dataset.rb +3 -0
- data/lib/torch/version.rb +1 -1
- metadata +123 -6
data/lib/torch/nn/multi_label_soft_margin_loss.rb
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that delegates to F.multilabel_soft_margin_loss, passing
    # along the weight and reduction held by WeightedLoss.
    class MultiLabelSoftMarginLoss < WeightedLoss
      # @param weight [Tensor, nil] optional weight, handed to WeightedLoss
      # @param reduction [String] reduction mode (default "mean"), handed to WeightedLoss
      def initialize(weight: nil, reduction: "mean")
        super(weight, reduction)
      end

      # @param input [Tensor]
      # @param target [Tensor]
      # @return whatever F.multilabel_soft_margin_loss returns
      def forward(input, target)
        F.multilabel_soft_margin_loss(
          input,
          target,
          reduction: @reduction,
          weight: @weight
        )
      end
    end
  end
end
|
@@ -0,0 +1,17 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that delegates to F.multi_margin_loss.
    class MultiMarginLoss < WeightedLoss
      # @param p [Integer] only 1 and 2 are accepted
      # @param margin [Float] forwarded as +margin:+
      # @param weight [Tensor, nil] optional weight; must be one-dimensional when given
      # @param reduction [String] reduction mode (default "mean")
      # @raise [ArgumentError] if p is neither 1 nor 2, or weight is not 1-D
      def initialize(p: 1, margin: 1.0, weight: nil, reduction: "mean")
        super(weight, reduction)

        unless p == 1 || p == 2
          raise ArgumentError, "only p == 1 and p == 2 supported"
        end
        if !weight.nil? && weight.dim != 1
          raise ArgumentError, "weight must be nil or have one dimension"
        end

        @p = p
        @margin = margin
      end

      # @param input [Tensor]
      # @param target [Tensor]
      # @return whatever F.multi_margin_loss returns
      def forward(input, target)
        F.multi_margin_loss(
          input,
          target,
          p: @p,
          margin: @margin,
          weight: @weight,
          reduction: @reduction
        )
      end
    end
  end
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that delegates to F.nll_loss.
    class NLLLoss < WeightedLoss
      # @param weight [Tensor, nil] optional weight, handed to WeightedLoss
      # @param ignore_index [Integer] forwarded as +ignore_index:+ (default -100)
      # @param reduction [String] reduction mode (default "mean")
      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
        super(weight, reduction)
        @ignore_index = ignore_index
      end

      # @param input [Tensor]
      # @param target [Tensor]
      # @return whatever F.nll_loss returns
      def forward(input, target)
        F.nll_loss(
          input,
          target,
          ignore_index: @ignore_index,
          weight: @weight,
          reduction: @reduction
        )
      end
    end
  end
end
|
@@ -0,0 +1,16 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.pairwise_distance.
    class PairwiseDistance < Module
      # @param p [Float] forwarded as +p:+ (kept in @norm)
      # @param eps [Float] forwarded as +eps:+
      # @param keepdim [Boolean] forwarded as +keepdim:+
      def initialize(p: 2.0, eps: 1e-6, keepdim: false)
        super()
        @norm = p
        @eps = eps
        @keepdim = keepdim
      end

      # @param x1 [Tensor]
      # @param x2 [Tensor]
      # @return whatever F.pairwise_distance returns
      def forward(x1, x2)
        F.pairwise_distance(
          x1,
          x2,
          keepdim: @keepdim,
          eps: @eps,
          p: @norm
        )
      end
    end
  end
end
|
data/lib/torch/nn/poisson_nll_loss.rb
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
module Torch
  module NN
    # Loss module that delegates to F.poisson_nll_loss.
    class PoissonNLLLoss < Loss
      # @param log_input [Boolean] forwarded as +log_input:+ (default true)
      # @param full [Boolean] forwarded as +full:+ (default false)
      # @param eps [Float] forwarded as +eps:+
      # @param reduction [String] reduction mode (default "mean"), handed to Loss
      def initialize(log_input: true, full: false, eps: 1e-8, reduction: "mean")
        super(reduction)
        @log_input = log_input
        @full = full
        @eps = eps
      end

      # @param log_input [Tensor] prediction tensor (first positional argument)
      # @param target [Tensor]
      # @return whatever F.poisson_nll_loss returns
      def forward(log_input, target)
        F.poisson_nll_loss(
          log_input,
          target,
          reduction: @reduction,
          log_input: @log_input,
          full: @full,
          eps: @eps
        )
      end
    end
  end
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.prelu that owns a weight Parameter.
    class PReLU < Module
      # @param num_parameters [Integer] length of the weight tensor
      # @param init [Numeric] value every weight entry is filled with
      def initialize(num_parameters: 1, init: 0.25)
        @num_parameters = num_parameters
        super()
        initial_values = Tensor.new(num_parameters).fill!(init)
        @weight = Parameter.new(initial_values)
      end

      # @param input [Tensor]
      # @return whatever F.prelu returns
      def forward(input)
        F.prelu(input, @weight)
      end

      # Extra description string, e.g. "num_parameters: 1".
      def extra_inspect
        "num_parameters: #{@num_parameters}"
      end
    end
  end
end
|
data/lib/torch/nn/relu.rb
CHANGED
@@ -1,12 +1,17 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class ReLU < Module
  # @param inplace [Boolean] forwarded to F.relu as +inplace:+ (default false)
  def initialize(inplace: false)
    super()
    @inplace = inplace
  end

  # @param input [Tensor]
  # @return whatever F.relu returns
  def forward(input)
    F.relu(input, inplace: @inplace)
  end

  # Extra description string: empty unless inplace was requested.
  def extra_inspect
    return "" unless @inplace

    "inplace: true"
  end
end
|
12
17
|
end
|
data/lib/torch/nn/rnn.rb
ADDED
@@ -0,0 +1,22 @@
|
|
1
|
+
module Torch
  module NN
    # Plain (Elman-style) recurrent layer. Translates the +nonlinearity:+
    # option into the mode string RNNBase expects and forwards every other
    # argument unchanged.
    class RNN < RNNBase
      # @param args positional arguments passed through to RNNBase
      # @param options keyword options; +nonlinearity:+ ("tanh" or "relu",
      #   default "tanh") is consumed here, the rest go to RNNBase
      # @raise [ArgumentError] for any other nonlinearity value
      def initialize(*args, **options)
        nonlinearity = options.key?(:nonlinearity) ? options.delete(:nonlinearity) : "tanh"

        mode =
          case nonlinearity
          when "tanh" then "RNN_TANH"
          when "relu" then "RNN_RELU"
          else raise ArgumentError, "Unknown nonlinearity: #{nonlinearity}"
          end

        super(mode, *args, **options)
      end
    end
  end
end
|
@@ -0,0 +1,198 @@
|
|
1
|
+
module Torch
  module NN
    # Common machinery for the recurrent layers.
    #
    # The +mode+ string chosen by the subclass fixes how many gate blocks are
    # stacked in each weight matrix. Weights are stored as instance variables
    # named weight_ih_l<layer>[_reverse], weight_hh_l..., and (when bias is
    # enabled) bias_ih_l... / bias_hh_l...; the names for each layer/direction
    # are recorded in @all_weights.
    class RNNBase < Module
      # @param mode [String] "LSTM", "GRU", "RNN_TANH" or "RNN_RELU"
      # @param input_size [Integer] expected size of the input's last dimension
      # @param hidden_size [Integer] hidden state size
      # @param num_layers [Integer] number of stacked layers
      # @param bias [Boolean] register bias_ih/bias_hh parameters when truthy
      # @param batch_first [Boolean] when true, dimension 0 of the input is the batch
      # @param dropout [Numeric] probability in [0, 1] (stored as a Float)
      # @param bidirectional [Boolean] add a "_reverse" set of weights per layer
      # @raise [ArgumentError] for a non-numeric/out-of-range dropout or an unknown mode
      def initialize(mode, input_size, hidden_size, num_layers: 1, bias: true,
                     batch_first: false, dropout: 0.0, bidirectional: false)
        super()

        # Validate before coercing with to_f: an object without #to_f would
        # otherwise escape as NoMethodError instead of ArgumentError.
        unless dropout.is_a?(Numeric) && dropout >= 0 && dropout <= 1
          raise ArgumentError, "dropout should be a number in range [0, 1] " \
                               "representing the probability of an element being " \
                               "zeroed"
        end

        @mode = mode
        @input_size = input_size
        @hidden_size = hidden_size
        @num_layers = num_layers
        @bias = bias
        @batch_first = batch_first
        @dropout = dropout.to_f
        @bidirectional = bidirectional
        num_directions = bidirectional ? 2 : 1

        if dropout > 0 && num_layers == 1
          warn "dropout option adds dropout after all but last " \
               "recurrent layer, so non-zero dropout expects " \
               "num_layers greater than 1, but got dropout=#{dropout} and " \
               "num_layers=#{num_layers}"
        end

        gate_size =
          case mode
          when "LSTM" then 4 * hidden_size
          when "GRU" then 3 * hidden_size
          when "RNN_TANH", "RNN_RELU" then hidden_size
          else raise ArgumentError, "Unrecognized RNN mode: #{mode}"
          end

        @all_weights = []
        num_layers.times do |layer|
          # Layers after the first consume the (possibly bidirectional) hidden output.
          layer_input_size = layer == 0 ? input_size : hidden_size * num_directions

          num_directions.times do |direction|
            suffix = direction == 1 ? "_reverse" : ""

            param_names = ["weight_ih_l#{layer}#{suffix}", "weight_hh_l#{layer}#{suffix}"]
            layer_params = [
              Parameter.new(Torch::Tensor.new(gate_size, layer_input_size)),
              Parameter.new(Torch::Tensor.new(gate_size, hidden_size))
            ]

            # Bias tensors are only allocated when they will actually be registered.
            if bias
              param_names += ["bias_ih_l#{layer}#{suffix}", "bias_hh_l#{layer}#{suffix}"]
              # Second bias vector included for CuDNN compatibility. Only one
              # bias vector is needed in standard definition.
              layer_params += [
                Parameter.new(Torch::Tensor.new(gate_size)),
                Parameter.new(Torch::Tensor.new(gate_size))
              ]
            end

            param_names.zip(layer_params) do |name, param|
              instance_variable_set("@#{name}", param)
            end
            @all_weights << param_names
          end
        end

        flatten_parameters
        reset_parameters
      end

      def flatten_parameters
        # no-op unless module is on the GPU and cuDNN is enabled
      end

      # Delegates to Module#_apply, then re-flattens the parameters.
      def _apply(fn)
        ret = super
        flatten_parameters
        ret
      end

      # Re-initializes every parameter uniformly in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)].
      def reset_parameters
        stdv = 1.0 / Math.sqrt(@hidden_size)
        parameters.each do |weight|
          Init.uniform!(weight, a: -stdv, b: stdv)
        end
      end

      # Returns +hx+ unchanged when no permutation is given; reordering is
      # not implemented yet.
      def permute_hidden(hx, permutation)
        return hx if permutation.nil?

        raise NotImplementedYet
      end

      # Runs the recurrent kernel for the RNN_TANH / RNN_RELU modes.
      #
      # @param input [Tensor] 3-D input (see check_input)
      # @param hx [Tensor, nil] initial hidden state; zeros when nil
      # @return [Array] [output, hidden]
      # @raise [NotImplementedYet] when no kernel is wired up for the mode
      def forward(input, hx: nil)
        is_packed = false # TODO isinstance(input, PackedSequence)
        if is_packed
          input, batch_sizes, sorted_indices, unsorted_indices = input
          max_batch_size = batch_sizes[0].to_i
        else
          batch_sizes = nil
          max_batch_size = @batch_first ? input.size(0) : input.size(1)
          sorted_indices = nil
          unsorted_indices = nil
        end

        if hx.nil?
          num_directions = @bidirectional ? 2 : 1
          hx = Torch.zeros(@num_layers * num_directions, max_batch_size,
                           @hidden_size, dtype: input.dtype, device: input.device)
        else
          # Each batch of the hidden state should match the input sequence that
          # the user believes they are passing in.
          hx = permute_hidden(hx, sorted_indices)
        end

        check_forward_args(input, hx, batch_sizes)

        # Resolve only the kernel that is needed; other modes have no kernel
        # here and are reported explicitly rather than as NoMethodError on nil.
        impl =
          case @mode
          when "RNN_TANH" then Torch.method(:rnn_tanh)
          when "RNN_RELU" then Torch.method(:rnn_relu)
          else raise NotImplementedYet
          end

        result =
          if batch_sizes.nil?
            impl.call(input, hx, _get_flat_weights, @bias, @num_layers,
                      @dropout, @training, @bidirectional, @batch_first)
          else
            impl.call(input, batch_sizes, hx, _get_flat_weights, @bias,
                      @num_layers, @dropout, @training, @bidirectional)
          end
        output = result[0]
        hidden = result[1]

        if is_packed
          raise NotImplementedYet
          # output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        end
        [output, permute_hidden(hidden, unsorted_indices)]
      end

      # Extra description string. Sizes are always shown; the remaining
      # options are appended only when they differ from their defaults, so a
      # default-configured module is described exactly as before.
      def extra_inspect
        s = String.new("%{input_size}, %{hidden_size}")
        s << ", num_layers: %{num_layers}" if @num_layers != 1
        s << ", bias: %{bias}" unless @bias
        s << ", batch_first: %{batch_first}" if @batch_first
        s << ", dropout: %{dropout}" if @dropout != 0
        s << ", bidirectional: %{bidirectional}" if @bidirectional
        format(s, input_size: @input_size, hidden_size: @hidden_size, num_layers: @num_layers,
                  bias: @bias, batch_first: @batch_first, dropout: @dropout,
                  bidirectional: @bidirectional)
      end

      private

      # The registered weight/bias parameters in @all_weights order.
      def _flat_weights
        @all_weights.flatten.map { |v| instance_variable_get("@#{v}") }.compact
      end

      def _get_flat_weights
        _flat_weights
      end

      # @raise [ArgumentError] if the input rank or last dimension is wrong
      def check_input(input, batch_sizes)
        expected_input_dim = batch_sizes.nil? ? 3 : 2
        if input.dim != expected_input_dim
          raise ArgumentError, "input must have #{expected_input_dim} dimensions, got #{input.dim}"
        end
        if @input_size != input.size(-1)
          raise ArgumentError, "input.size(-1) must be equal to input_size. Expected #{@input_size}, got #{input.size(-1)}"
        end
      end

      # @return [Array<Integer>] [num_layers * num_directions, mini_batch, hidden_size]
      def get_expected_hidden_size(input, batch_sizes)
        mini_batch =
          if batch_sizes.nil?
            @batch_first ? input.size(0) : input.size(1)
          else
            batch_sizes[0].to_i
          end
        num_directions = @bidirectional ? 2 : 1
        [@num_layers * num_directions, mini_batch, @hidden_size]
      end

      # @raise [ArgumentError] if hx's shape differs from the expected one
      def check_hidden_size(hx, expected_hidden_size)
        if hx.size != expected_hidden_size
          raise ArgumentError, "Expected hidden size #{expected_hidden_size.inspect}, got #{hx.size.inspect}"
        end
      end

      def check_forward_args(input, hidden, batch_sizes)
        check_input(input, batch_sizes)
        expected_hidden_size = get_expected_hidden_size(input, batch_sizes)
        check_hidden_size(hidden, expected_hidden_size)
      end
    end
  end
end
|