ruby-dnn 0.9.4 → 0.10.0

@@ -1,36 +1,106 @@
  module DNN
+ module Regularizers
 
- class Lasso
- def initialize(l1_lambda, param)
- @l1_lambda = l1_lambda
- @param = param
- end
+ class Regularizer
+ attr_accessor :param
 
- def forward(x)
- x + @l1_lambda * @param.data.abs.sum
- end
+ def forward(x)
+ raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'forward'")
+ end
+
+ def backward
+ raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'backward'")
+ end
 
- def backward
- dlasso = Xumo::SFloat.ones(*@param.data.shape)
- dlasso[@param.data < 0] = -1
- @param.grad += @l1_lambda * dlasso
+ def to_hash(merge_hash)
+ hash = {class: self.class.name}
+ hash.merge!(merge_hash)
+ hash
+ end
  end
- end
 
+ class L1 < Regularizer
+ attr_accessor :l1_lambda
+
+ def self.from_hash(hash)
+ L1.new(hash[:l1_lambda])
+ end
+
+ def initialize(l1_lambda = 0.01)
+ @l1_lambda = l1_lambda
+ end
 
- class Ridge
- def initialize(l2_lambda, param)
- @l2_lambda = l2_lambda
- @param = param
+ def forward(x)
+ x + @l1_lambda * @param.data.abs.sum
+ end
+
+ def backward
+ dparam = Xumo::SFloat.ones(*@param.data.shape)
+ dparam[@param.data < 0] = -1
+ @param.grad += @l1_lambda * dparam
+ end
+
+ def to_hash
+ super(l1_lambda: @l1_lambda)
+ end
  end
 
- def forward(x)
- x + 0.5 * @l2_lambda * (@param.data**2).sum
+
+ class L2 < Regularizer
+ attr_accessor :l2_lambda
+
+ def self.from_hash(hash)
+ L2.new(hash[:l2_lambda])
+ end
+
+ def initialize(l2_lambda = 0.01)
+ @l2_lambda = l2_lambda
+ end
+
+ def forward(x)
+ x + 0.5 * @l2_lambda * (@param.data**2).sum
+ end
+
+ def backward
+ @param.grad += @l2_lambda * @param.data
+ end
+
+ def to_hash
+ super(l2_lambda: @l2_lambda)
+ end
  end
 
- def backward
- @param.grad += @l2_lambda * @param.data
+ class L1L2 < Regularizer
+ attr_accessor :l1_lambda
+ attr_accessor :l2_lambda
+
+ def self.from_hash(hash)
+ L1L2.new(hash[:l1_lambda], hash[:l2_lambda])
+ end
+
+ def initialize(l1_lambda = 0.01, l2_lambda = 0.01)
+ @l1_lambda = l1_lambda
+ @l2_lambda = l2_lambda
+ end
+
+ def forward(x)
+ l1 = @l1_lambda * @param.data.abs.sum
+ l2 = 0.5 * @l2_lambda * (@param.data**2).sum
+ x + l1 + l2
+ end
+
+ def backward
+ dparam = Xumo::SFloat.ones(*@param.data.shape)
+ dparam[@param.data < 0] = -1
+ @param.grad += @l1_lambda * dparam
+ @param.grad += @l2_lambda * @param.data
+ end
+
+ def to_hash
+ super(l1_lambda: l1_lambda, l2_lambda: l2_lambda)
+ end
+
  end
- end
 
+ end
  end
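The new Regularizers API decouples a penalty from the parameter it acts on: a regularizer is built with only its lambda, attached to a parameter through the `param` accessor, and then `forward` adds the penalty to the loss while `backward` accumulates its gradient. A minimal standalone sketch, assuming the gem's usual `require "dnn"` entry point and using a stand-in Struct in place of the library's Param (only the `data`/`grad` interface used above is relied on):

    require "dnn"
    require "numo/narray"

    # Stand-in for the library's Param: the regularizers above only touch #data and #grad.
    FakeParam = Struct.new(:data, :grad)

    param = FakeParam.new(Numo::SFloat[[0.5, -1.5], [2.0, -0.2]],
                          Numo::SFloat.zeros(2, 2))

    reg = DNN::Regularizers::L1L2.new(0.01, 0.01)
    reg.param = param

    loss = reg.forward(1.0)  # base loss plus the L1 and L2 penalty terms
    reg.backward             # adds d(penalty)/d(param) into param.grad

The remaining hunks wire this API into the recurrent layers (RNN, SimpleRNN, LSTM, GRU).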
@@ -11,25 +11,36 @@ module DNN
  attr_reader :stateful
  # @return [Bool] Set the false, only the last of each cell of RNN is left.
  attr_reader :return_sequences
+ # @return [DNN::Initializers::Initializer] Recurrent weight initializer.
+ attr_reader :recurrent_weight_initializer
+ # @return [DNN::Regularizers::Regularizer] Recurrent weight regularization.
+ attr_reader :recurrent_weight_regularizer
 
  def initialize(num_nodes,
  stateful: false,
  return_sequences: true,
  weight_initializer: RandomNormal.new,
+ recurrent_weight_initializer: RandomNormal.new,
  bias_initializer: Zeros.new,
- l1_lambda: 0,
- l2_lambda: 0,
+ weight_regularizer: nil,
+ recurrent_weight_regularizer: nil,
+ bias_regularizer: nil,
  use_bias: true)
  super(weight_initializer: weight_initializer, bias_initializer: bias_initializer,
- l1_lambda: l1_lambda, l2_lambda: l2_lambda, use_bias: use_bias)
+ weight_regularizer: weight_regularizer, bias_regularizer: bias_regularizer, use_bias: use_bias)
  @num_nodes = num_nodes
  @stateful = stateful
  @return_sequences = return_sequences
  @layers = []
- @hidden = @params[:h] = Param.new
- # TODO
- # Change to a good name.
- @params[:weight2] = @weight2 = Param.new
+ @hidden = @params[:hidden] = Param.new
+ @params[:recurrent_weight] = @recurrent_weight = Param.new(nil, 0)
+ @recurrent_weight_initializer = recurrent_weight_initializer
+ @recurrent_weight_regularizer = recurrent_weight_regularizer
+ end
+
+ def build(input_shape)
+ super
+ @time_length = @input_shape[0]
  end
 
  def forward(xs)
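In place of the old l1_lambda:/l2_lambda: pair, each weight now takes its own optional regularizer object, and the recurrent weight gets its own initializer. A hedged construction sketch (the DNN::Layers:: and DNN::Initializers:: constant paths are assumed from the rest of the library, not from this hunk):

    rnn = DNN::Layers::SimpleRNN.new(64,
      stateful: false,
      return_sequences: true,
      weight_initializer: DNN::Initializers::RandomNormal.new,
      recurrent_weight_initializer: DNN::Initializers::RandomNormal.new,
      weight_regularizer: DNN::Regularizers::L2.new(0.01),
      recurrent_weight_regularizer: DNN::Regularizers::L2.new(0.01),
      bias_regularizer: nil)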
@@ -38,6 +49,7 @@ module DNN
  h = (@stateful && @hidden.data) ? @hidden.data : Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
  xs.shape[1].times do |t|
  x = xs[true, t, false]
+ @layers[t].trainable = @trainable
  h = @layers[t].forward(x, h)
  hs[true, t, false] = h
  end
@@ -46,9 +58,6 @@ module DNN
  end
 
  def backward(dh2s)
- @weight.grad = Xumo::SFloat.zeros(*@weight.data.shape)
- @weight2.grad = Xumo::SFloat.zeros(*@weight2.data.shape)
- @bias.grad = Xumo::SFloat.zeros(*@bias.data.shape) if @bias
  unless @return_sequences
  dh = dh2s
  dh2s = Xumo::SFloat.zeros(dh.shape[0], @time_length, dh.shape[1])
@@ -72,7 +81,9 @@ module DNN
  hash = {
  num_nodes: @num_nodes,
  stateful: @stateful,
- return_sequences: @return_sequences
+ return_sequences: @return_sequences,
+ recurrent_weight_initializer: @recurrent_weight_initializer.to_hash,
+ recurrent_weight_regularizer: @recurrent_weight_regularizer&.to_hash,
  }
  hash.merge!(merge_hash) if merge_hash
  super(hash)
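Serialization follows the same to_hash/from_hash pattern as the regularizers themselves; the layer hash gains recurrent_weight_initializer and recurrent_weight_regularizer entries, and the `&.to_hash` call keeps the latter nil when no regularizer is attached. A small round-trip sketch using the regularizer classes from the first hunk (the base-layer keys merged in by the layer's `super(hash)` are not shown):

    reg  = DNN::Regularizers::L1L2.new(0.01, 0.005)
    hash = reg.to_hash
    # => {class: "DNN::Regularizers::L1L2", l1_lambda: 0.01, l2_lambda: 0.005}
    restored = DNN::Regularizers::L1L2.from_hash(hash)
    restored.l1_lambda  # => 0.01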
@@ -85,48 +96,48 @@ module DNN
 
  def regularizers
  regularizers = []
- if @l1_lambda > 0
- regularizers << Lasso.new(@l1_lambda, @weight)
- regularizers << Lasso.new(@l1_lambda, @weight2)
- end
- if @l2_lambda > 0
- regularizers << Ridge.new(@l2_lambda, @weight)
- regularizers << Ridge.new(@l2_lambda, @weight2)
- end
+ regularizers << @weight_regularizer if @weight_regularizer
+ regularizers << @recurrent_weight_regularizer if @recurrent_weight_regularizer
+ regularizers << @bias_regularizer if @bias_regularizer
  regularizers
  end
 
- private
-
- def init_params
- @time_length = @input_shape[0]
+ private def init_weight_and_bias
+ super
+ @recurrent_weight_initializer.init_param(self, @recurrent_weight)
+ @recurrent_weight_regularizer.param = @recurrent_weight if @recurrent_weight_regularizer
  end
  end
 
 
  class SimpleRNN_Dense
- def initialize(weight, weight2, bias, activation)
+ attr_accessor :trainable
+
+ def initialize(weight, recurrent_weight, bias, activation)
  @weight = weight
- @weight2 = weight2
+ @recurrent_weight = recurrent_weight
  @bias = bias
  @activation = activation.clone
+ @trainable = true
  end
 
  def forward(x, h)
  @x = x
  @h = h
- h2 = x.dot(@weight.data) + h.dot(@weight2.data)
+ h2 = x.dot(@weight.data) + h.dot(@recurrent_weight.data)
  h2 += @bias.data if @bias
  @activation.forward(h2)
  end
 
  def backward(dh2)
  dh2 = @activation.backward(dh2)
- @weight.grad += @x.transpose.dot(dh2)
- @weight2.grad += @h.transpose.dot(dh2)
- @bias.grad += dh2.sum(0) if @bias
+ if @trainable
+ @weight.grad += @x.transpose.dot(dh2)
+ @recurrent_weight.grad += @h.transpose.dot(dh2)
+ @bias.grad += dh2.sum(0) if @bias
+ end
  dx = dh2.dot(@weight.data.transpose)
- dh = dh2.dot(@weight2.data.transpose)
+ dh = dh2.dot(@recurrent_weight.data.transpose)
  [dx, dh]
  end
  end
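SimpleRNN_Dense is the per-timestep cell: each step mixes the current input through @weight and the previous hidden state through the renamed @recurrent_weight, and gradients are only accumulated when the layer is trainable. The recurrence it computes, as a standalone sketch with numo-narray used directly in place of Xumo (shapes are made up for illustration):

    require "numo/narray"

    batch, in_dim, num_nodes = 2, 3, 4
    x  = Numo::SFloat.new(batch, in_dim).rand               # input at time t
    h  = Numo::SFloat.zeros(batch, num_nodes)               # previous hidden state
    w  = Numo::SFloat.new(in_dim, num_nodes).rand_norm      # @weight.data
    wr = Numo::SFloat.new(num_nodes, num_nodes).rand_norm   # @recurrent_weight.data
    b  = Numo::SFloat.zeros(num_nodes)                      # @bias.data

    h2 = Numo::NMath.tanh(x.dot(w) + h.dot(wr) + b)         # next hidden state (Tanh activation)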
@@ -137,15 +148,17 @@ module DNN
 
  attr_reader :activation
 
- def self.load_hash(hash)
+ def self.from_hash(hash)
  simple_rnn = self.new(hash[:num_nodes],
  stateful: hash[:stateful],
  return_sequences: hash[:return_sequences],
- activation: Utils.load_hash(hash[:activation]),
- weight_initializer: Utils.load_hash(hash[:weight_initializer]),
- bias_initializer: Utils.load_hash(hash[:bias_initializer]),
- l1_lambda: hash[:l1_lambda],
- l2_lambda: hash[:l2_lambda],
+ activation: Utils.from_hash(hash[:activation]),
+ weight_initializer: Utils.from_hash(hash[:weight_initializer]),
+ recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
+ bias_initializer: Utils.from_hash(hash[:bias_initializer]),
+ weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
+ recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
+ bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
  use_bias: hash[:use_bias])
  simple_rnn
  end
@@ -155,53 +168,56 @@ module DNN
  return_sequences: true,
  activation: Tanh.new,
  weight_initializer: RandomNormal.new,
+ recurrent_weight_initializer: RandomNormal.new,
  bias_initializer: Zeros.new,
- l1_lambda: 0,
- l2_lambda: 0,
+ weight_regularizer: nil,
+ recurrent_weight_regularizer: nil,
+ bias_regularizer: nil,
  use_bias: true)
  super(num_nodes,
  stateful: stateful,
  return_sequences: return_sequences,
  weight_initializer: weight_initializer,
+ recurrent_weight_initializer: recurrent_weight_initializer,
  bias_initializer: bias_initializer,
- l1_lambda: l1_lambda,
- l2_lambda: l2_lambda,
+ weight_regularizer: weight_regularizer,
+ recurrent_weight_regularizer: recurrent_weight_regularizer,
+ bias_regularizer: bias_regularizer,
  use_bias: use_bias)
  @activation = activation
  end
 
- def to_hash
- super({activation: @activation.to_hash})
- end
-
- private
-
- def init_params
- super()
- num_prev_nodes = @input_shape[1]
+ def build(input_shape)
+ super
+ num_prev_nodes = input_shape[1]
  @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes)
- @weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
+ @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
  @bias.data = Xumo::SFloat.new(@num_nodes) if @bias
- @weight_initializer.init_param(self, @weight)
- @weight_initializer.init_param(self, @weight2)
- @bias_initializer.init_param(self, @bias) if @bias
+ init_weight_and_bias
  @time_length.times do |t|
- @layers << SimpleRNN_Dense.new(@weight, @weight2, @bias, @activation)
+ @layers << SimpleRNN_Dense.new(@weight, @recurrent_weight, @bias, @activation)
  end
  end
+
+ def to_hash
+ super({activation: @activation.to_hash})
+ end
  end
 
 
  class LSTM_Dense
- def initialize(weight, weight2, bias)
+ attr_accessor :trainable
+
+ def initialize(weight, recurrent_weight, bias)
  @weight = weight
- @weight2 = weight2
+ @recurrent_weight = recurrent_weight
  @bias = bias
  @tanh = Tanh.new
  @g_tanh = Tanh.new
  @forget_sigmoid = Sigmoid.new
  @in_sigmoid = Sigmoid.new
  @out_sigmoid = Sigmoid.new
+ @trainable = true
  end
 
  def forward(x, h, c)
@@ -209,7 +225,7 @@ module DNN
  @h = h
  @c = c
  num_nodes = h.shape[1]
- a = x.dot(@weight.data) + h.dot(@weight2.data)
+ a = x.dot(@weight.data) + h.dot(@recurrent_weight.data)
  a += @bias.data if @bias
 
  @forget = @forget_sigmoid.forward(a[true, 0...num_nodes])
@@ -234,11 +250,13 @@ module DNN
 
  da = Xumo::SFloat.hstack([dforget, dg, din, dout])
 
- @weight.grad += @x.transpose.dot(da)
- @weight2.grad += @h.transpose.dot(da)
- @bias.grad += da.sum(0) if @bias
+ if @trainable
+ @weight.grad += @x.transpose.dot(da)
+ @recurrent_weight.grad += @h.transpose.dot(da)
+ @bias.grad += da.sum(0) if @bias
+ end
  dx = da.dot(@weight.data.transpose)
- dh = da.dot(@weight2.data.transpose)
+ dh = da.dot(@recurrent_weight.data.transpose)
  dc = dc2_tmp * @forget
  [dx, dh, dc]
  end
@@ -246,14 +264,16 @@ module DNN
 
 
  class LSTM < RNN
- def self.load_hash(hash)
+ def self.from_hash(hash)
  lstm = self.new(hash[:num_nodes],
  stateful: hash[:stateful],
  return_sequences: hash[:return_sequences],
- weight_initializer: Utils.load_hash(hash[:weight_initializer]),
- bias_initializer: Utils.load_hash(hash[:bias_initializer]),
- l1_lambda: hash[:l1_lambda],
- l2_lambda: hash[:l2_lambda],
+ weight_initializer: Utils.from_hash(hash[:weight_initializer]),
+ recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
+ bias_initializer: Utils.from_hash(hash[:bias_initializer]),
+ weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
+ recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
+ bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
  use_bias: hash[:use_bias])
  lstm
  end
@@ -262,12 +282,26 @@ module DNN
  stateful: false,
  return_sequences: true,
  weight_initializer: RandomNormal.new,
+ recurrent_weight_initializer: RandomNormal.new,
  bias_initializer: Zeros.new,
- l1_lambda: 0,
- l2_lambda: 0,
+ weight_regularizer: nil,
+ recurrent_weight_regularizer: nil,
+ bias_regularizer: nil,
  use_bias: true)
  super
- @cell = @params[:c] = Param.new
+ @cell = @params[:cell] = Param.new
+ end
+
+ def build(input_shape)
+ super
+ num_prev_nodes = input_shape[1]
+ @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 4)
+ @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
+ @bias.data = Xumo::SFloat.new(@num_nodes * 4) if @bias
+ init_weight_and_bias
+ @time_length.times do |t|
+ @layers << LSTM_Dense.new(@weight, @recurrent_weight, @bias)
+ end
  end
 
  def forward(xs)
@@ -283,6 +317,7 @@ module DNN
  c ||= Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
  xs.shape[1].times do |t|
  x = xs[true, t, false]
+ @layers[t].trainable = @trainable
  h, c = @layers[t].forward(x, h, c)
  hs[true, t, false] = h
  end
@@ -292,9 +327,6 @@ module DNN
  end
 
  def backward(dh2s)
- @weight.grad = Xumo::SFloat.zeros(*@weight.data.shape)
- @weight2.grad = Xumo::SFloat.zeros(*@weight2.data.shape)
- @bias.grad = Xumo::SFloat.zeros(*@bias.data.shape) if @bias
  unless @return_sequences
  dh = dh2s
  dh2s = Xumo::SFloat.zeros(dh.shape[0], @time_length, dh.shape[1])
@@ -315,33 +347,20 @@ module DNN
  super()
  @cell.data = @cell.data.fill(0) if @cell.data
  end
-
- private
-
- def init_params
- super()
- num_prev_nodes = @input_shape[1]
- @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 4)
- @weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
- @bias.data = Xumo::SFloat.new(@num_nodes * 4) if @bias
- @weight_initializer.init_param(self, @weight)
- @weight_initializer.init_param(self, @weight2)
- @bias_initializer.init_param(self, @bias) if @bias
- @time_length.times do |t|
- @layers << LSTM_Dense.new(@weight, @weight2, @bias)
- end
- end
  end
 
 
  class GRU_Dense
- def initialize(weight, weight2, bias)
+ attr_accessor :trainable
+
+ def initialize(weight, recurrent_weight, bias)
  @weight = weight
- @weight2 = weight2
+ @recurrent_weight = recurrent_weight
  @bias = bias
  @update_sigmoid = Sigmoid.new
  @reset_sigmoid = Sigmoid.new
  @tanh = Tanh.new
+ @trainable = true
  end
 
  def forward(x, h)
@@ -349,60 +368,68 @@ module DNN
  @h = h
  num_nodes = h.shape[1]
  @weight_a = @weight.data[true, 0...(num_nodes * 2)]
- @weight2_a = @weight2.data[true, 0...(num_nodes * 2)]
+ @weight2_a = @recurrent_weight.data[true, 0...(num_nodes * 2)]
  a = x.dot(@weight_a) + h.dot(@weight2_a)
  a += @bias.data[0...(num_nodes * 2)] if @bias
  @update = @update_sigmoid.forward(a[true, 0...num_nodes])
  @reset = @reset_sigmoid.forward(a[true, num_nodes..-1])
 
  @weight_h = @weight.data[true, (num_nodes * 2)..-1]
- @weight2_h = @weight2.data[true, (num_nodes * 2)..-1]
+ @weight2_h = @recurrent_weight.data[true, (num_nodes * 2)..-1]
  @tanh_h = if @bias
  bias_h = @bias.data[(num_nodes * 2)..-1]
  @tanh.forward(x.dot(@weight_h) + (h * @reset).dot(@weight2_h) + bias_h)
  else
  @tanh.forward(x.dot(@weight_h) + (h * @reset).dot(@weight2_h))
  end
- h2 = (1 - @update) * h + @update * @tanh_h
+ h2 = (1 - @update) * @tanh_h + @update * h
  h2
  end
 
  def backward(dh2)
- dtanh_h = @tanh.backward(dh2 * @update)
- dh = dh2 * (1 - @update)
+ dtanh_h = @tanh.backward(dh2 * (1 - @update))
+ dh = dh2 * @update
 
- dweight_h = @x.transpose.dot(dtanh_h)
+ if @trainable
+ dweight_h = @x.transpose.dot(dtanh_h)
+ dweight2_h = (@h * @reset).transpose.dot(dtanh_h)
+ dbias_h = dtanh_h.sum(0) if @bias
+ end
  dx = dtanh_h.dot(@weight_h.transpose)
- dweight2_h = (@h * @reset).transpose.dot(dtanh_h)
  dh += dtanh_h.dot(@weight2_h.transpose) * @reset
- dbias_h = dtanh_h.sum(0) if @bias
 
  dreset = @reset_sigmoid.backward(dtanh_h.dot(@weight2_h.transpose) * @h)
- dupdate = @update_sigmoid.backward(dh2 * @tanh_h - dh2 * @h)
+ dupdate = @update_sigmoid.backward(dh2 * @h - dh2 * @tanh_h)
  da = Xumo::SFloat.hstack([dupdate, dreset])
- dweight_a = @x.transpose.dot(da)
+ if @trainable
+ dweight_a = @x.transpose.dot(da)
+ dweight2_a = @h.transpose.dot(da)
+ dbias_a = da.sum(0) if @bias
+ end
  dx += da.dot(@weight_a.transpose)
- dweight2_a = @h.transpose.dot(da)
  dh += da.dot(@weight2_a.transpose)
- dbias_a = da.sum(0) if @bias
 
- @weight.grad += Xumo::SFloat.hstack([dweight_a, dweight_h])
- @weight2.grad += Xumo::SFloat.hstack([dweight2_a, dweight2_h])
- @bias.grad += Xumo::SFloat.hstack([dbias_a, dbias_h]) if @bias
+ if @trainable
+ @weight.grad += Xumo::SFloat.hstack([dweight_a, dweight_h])
+ @recurrent_weight.grad += Xumo::SFloat.hstack([dweight2_a, dweight2_h])
+ @bias.grad += Xumo::SFloat.hstack([dbias_a, dbias_h]) if @bias
+ end
  [dx, dh]
  end
  end
 
 
  class GRU < RNN
- def self.load_hash(hash)
+ def self.from_hash(hash)
  gru = self.new(hash[:num_nodes],
  stateful: hash[:stateful],
  return_sequences: hash[:return_sequences],
- weight_initializer: Utils.load_hash(hash[:weight_initializer]),
- bias_initializer: Utils.load_hash(hash[:bias_initializer]),
- l1_lambda: hash[:l1_lambda],
- l2_lambda: hash[:l2_lambda],
+ weight_initializer: Utils.from_hash(hash[:weight_initializer]),
+ recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
+ bias_initializer: Utils.from_hash(hash[:bias_initializer]),
+ weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
+ recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
+ bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
  use_bias: hash[:use_bias])
  gru
  end
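Beyond the renames, GRU_Dense#forward switches the update-gate convention: 0.10.0 blends the candidate and the previous state as h2 = (1 - update) * tanh_h + update * h, so the gate now measures how much of the old state is carried over, and backward swaps the matching factors. A tiny sketch of why the backward terms change (numo-narray used directly; the values are arbitrary):

    require "numo/narray"

    z       = Numo::SFloat[0.2, 0.8]   # update gate
    h_tilde = Numo::SFloat[1.0, 1.0]   # candidate state (tanh_h)
    h_prev  = Numo::SFloat[0.0, 0.0]   # previous hidden state

    h2 = (1 - z) * h_tilde + z * h_prev   # => [0.8, 0.2]
    # d(h2)/d(h_tilde) = 1 - z and d(h2)/d(h_prev) = z, which is why backward now
    # feeds dh2 * (1 - @update) into @tanh.backward and keeps dh2 * @update for dh.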
@@ -411,26 +438,24 @@ module DNN
  stateful: false,
  return_sequences: true,
  weight_initializer: RandomNormal.new,
+ recurrent_weight_initializer: RandomNormal.new,
  bias_initializer: Zeros.new,
- l1_lambda: 0,
- l2_lambda: 0,
+ weight_regularizer: nil,
+ recurrent_weight_regularizer: nil,
+ bias_regularizer: nil,
  use_bias: true)
  super
  end
-
- private
 
- def init_params
- super()
+ def build(input_shape)
+ super
  num_prev_nodes = @input_shape[1]
  @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 3)
- @weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
+ @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
  @bias.data = Xumo::SFloat.new(@num_nodes * 3) if @bias
- @weight_initializer.init_param(self, @weight)
- @weight_initializer.init_param(self, @weight2)
- @bias_initializer.init_param(self, @bias) if @bias
+ init_weight_and_bias
  @time_length.times do |t|
- @layers << GRU_Dense.new(@weight, @weight2, @bias)
+ @layers << GRU_Dense.new(@weight, @recurrent_weight, @bias)
  end
  end
  end