ruby-dnn 0.9.4 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,36 +1,106 @@
 module DNN
+  module Regularizers
 
-  class Lasso
-    def initialize(l1_lambda, param)
-      @l1_lambda = l1_lambda
-      @param = param
-    end
+    class Regularizer
+      attr_accessor :param
 
-    def forward(x)
-      x + @l1_lambda * @param.data.abs.sum
-    end
+      def forward(x)
+        raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'forward'")
+      end
+
+      def backward
+        raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'backward'")
+      end
 
-    def backward
-      dlasso = Xumo::SFloat.ones(*@param.data.shape)
-      dlasso[@param.data < 0] = -1
-      @param.grad += @l1_lambda * dlasso
+      def to_hash(merge_hash)
+        hash = {class: self.class.name}
+        hash.merge!(merge_hash)
+        hash
+      end
     end
-  end
 
+    class L1 < Regularizer
+      attr_accessor :l1_lambda
+
+      def self.from_hash(hash)
+        L1.new(hash[:l1_lambda])
+      end
+
+      def initialize(l1_lambda = 0.01)
+        @l1_lambda = l1_lambda
+      end
 
-  class Ridge
-    def initialize(l2_lambda, param)
-      @l2_lambda = l2_lambda
-      @param = param
+      def forward(x)
+        x + @l1_lambda * @param.data.abs.sum
+      end
+
+      def backward
+        dparam = Xumo::SFloat.ones(*@param.data.shape)
+        dparam[@param.data < 0] = -1
+        @param.grad += @l1_lambda * dparam
+      end
+
+      def to_hash
+        super(l1_lambda: @l1_lambda)
+      end
     end
 
-    def forward(x)
-      x + 0.5 * @l2_lambda * (@param.data**2).sum
+
+    class L2 < Regularizer
+      attr_accessor :l2_lambda
+
+      def self.from_hash(hash)
+        L2.new(hash[:l2_lambda])
+      end
+
+      def initialize(l2_lambda = 0.01)
+        @l2_lambda = l2_lambda
+      end
+
+      def forward(x)
+        x + 0.5 * @l2_lambda * (@param.data**2).sum
+      end
+
+      def backward
+        @param.grad += @l2_lambda * @param.data
+      end
+
+      def to_hash
+        super(l2_lambda: @l2_lambda)
+      end
     end
 
-    def backward
-      @param.grad += @l2_lambda * @param.data
+    class L1L2 < Regularizer
+      attr_accessor :l1_lambda
+      attr_accessor :l2_lambda
+
+      def self.from_hash(hash)
+        L1L2.new(hash[:l1_lambda], hash[:l2_lambda])
+      end
+
+      def initialize(l1_lambda = 0.01, l2_lambda = 0.01)
+        @l1_lambda = l1_lambda
+        @l2_lambda = l2_lambda
+      end
+
+      def forward(x)
+        l1 = @l1_lambda * @param.data.abs.sum
+        l2 = 0.5 * @l2_lambda * (@param.data**2).sum
+        x + l1 + l2
+      end
+
+      def backward
+        dparam = Xumo::SFloat.ones(*@param.data.shape)
+        dparam[@param.data < 0] = -1
+        @param.grad += @l1_lambda * dparam
+        @param.grad += @l2_lambda * @param.data
+      end
+
+      def to_hash
+        super(l1_lambda: l1_lambda, l2_lambda: l2_lambda)
+      end
+
     end
-  end
 
+  end
 end
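
The hunk above replaces the 0.9.4 Lasso/Ridge classes, which were bound to a single parameter at construction time, with a DNN::Regularizers module whose objects are attached to a layer parameter afterwards: #param= points the regularizer at the parameter, #forward adds the penalty term to the loss, and #backward accumulates the penalty gradient into the parameter's grad. A minimal sketch of that contract, using a hand-rolled stand-in for the parameter object (FakeParam is defined here only for illustration; inside the gem the layers assign their own Param instances):

    require "dnn"

    # Stand-in for the framework's parameter object: the regularizer only needs
    # something that responds to #data and #grad, as used in the hunk above.
    FakeParam = Struct.new(:data, :grad)

    w = FakeParam.new(Xumo::SFloat[[0.5, -0.2], [0.1, -0.7]], 0)
    reg = DNN::Regularizers::L1.new(0.01)
    reg.param = w

    loss = reg.forward(1.0)  # => 1.015, i.e. 1.0 + 0.01 * sum(|w.data|)
    reg.backward             # w.grad is now 0.01 * the elementwise sign of w.data

The rnn.rb hunks below show the layer side: regularizer objects are passed in as keyword arguments and collected by RNN#regularizers instead of being built from l1_lambda/l2_lambda.
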
@@ -11,25 +11,36 @@ module DNN
       attr_reader :stateful
       # @return [Bool] Set the false, only the last of each cell of RNN is left.
       attr_reader :return_sequences
+      # @return [DNN::Initializers::Initializer] Recurrent weight initializer.
+      attr_reader :recurrent_weight_initializer
+      # @return [DNN::Regularizers::Regularizer] Recurrent weight regularization.
+      attr_reader :recurrent_weight_regularizer
 
       def initialize(num_nodes,
                      stateful: false,
                      return_sequences: true,
                      weight_initializer: RandomNormal.new,
+                     recurrent_weight_initializer: RandomNormal.new,
                      bias_initializer: Zeros.new,
-                     l1_lambda: 0,
-                     l2_lambda: 0,
+                     weight_regularizer: nil,
+                     recurrent_weight_regularizer: nil,
+                     bias_regularizer: nil,
                      use_bias: true)
         super(weight_initializer: weight_initializer, bias_initializer: bias_initializer,
-              l1_lambda: l1_lambda, l2_lambda: l2_lambda, use_bias: use_bias)
+              weight_regularizer: weight_regularizer, bias_regularizer: bias_regularizer, use_bias: use_bias)
         @num_nodes = num_nodes
         @stateful = stateful
         @return_sequences = return_sequences
         @layers = []
-        @hidden = @params[:h] = Param.new
-        # TODO
-        # Change to a good name.
-        @params[:weight2] = @weight2 = Param.new
+        @hidden = @params[:hidden] = Param.new
+        @params[:recurrent_weight] = @recurrent_weight = Param.new(nil, 0)
+        @recurrent_weight_initializer = recurrent_weight_initializer
+        @recurrent_weight_regularizer = recurrent_weight_regularizer
+      end
+
+      def build(input_shape)
+        super
+        @time_length = @input_shape[0]
       end
 
       def forward(xs)
@@ -38,6 +49,7 @@ module DNN
         h = (@stateful && @hidden.data) ? @hidden.data : Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
         xs.shape[1].times do |t|
           x = xs[true, t, false]
+          @layers[t].trainable = @trainable
           h = @layers[t].forward(x, h)
           hs[true, t, false] = h
         end
@@ -46,9 +58,6 @@ module DNN
       end
 
       def backward(dh2s)
-        @weight.grad = Xumo::SFloat.zeros(*@weight.data.shape)
-        @weight2.grad = Xumo::SFloat.zeros(*@weight2.data.shape)
-        @bias.grad = Xumo::SFloat.zeros(*@bias.data.shape) if @bias
         unless @return_sequences
           dh = dh2s
           dh2s = Xumo::SFloat.zeros(dh.shape[0], @time_length, dh.shape[1])
@@ -72,7 +81,9 @@ module DNN
         hash = {
           num_nodes: @num_nodes,
           stateful: @stateful,
-          return_sequences: @return_sequences
+          return_sequences: @return_sequences,
+          recurrent_weight_initializer: @recurrent_weight_initializer.to_hash,
+          recurrent_weight_regularizer: @recurrent_weight_regularizer&.to_hash,
         }
         hash.merge!(merge_hash) if merge_hash
         super(hash)
@@ -85,48 +96,48 @@ module DNN
 
       def regularizers
         regularizers = []
-        if @l1_lambda > 0
-          regularizers << Lasso.new(@l1_lambda, @weight)
-          regularizers << Lasso.new(@l1_lambda, @weight2)
-        end
-        if @l2_lambda > 0
-          regularizers << Ridge.new(@l2_lambda, @weight)
-          regularizers << Ridge.new(@l2_lambda, @weight2)
-        end
+        regularizers << @weight_regularizer if @weight_regularizer
+        regularizers << @recurrent_weight_regularizer if @recurrent_weight_regularizer
+        regularizers << @bias_regularizer if @bias_regularizer
         regularizers
       end
 
-      private
-
-      def init_params
-        @time_length = @input_shape[0]
+      private def init_weight_and_bias
+        super
+        @recurrent_weight_initializer.init_param(self, @recurrent_weight)
+        @recurrent_weight_regularizer.param = @recurrent_weight if @recurrent_weight_regularizer
       end
     end
 
 
     class SimpleRNN_Dense
-      def initialize(weight, weight2, bias, activation)
+      attr_accessor :trainable
+
+      def initialize(weight, recurrent_weight, bias, activation)
         @weight = weight
-        @weight2 = weight2
+        @recurrent_weight = recurrent_weight
         @bias = bias
         @activation = activation.clone
+        @trainable = true
       end
 
       def forward(x, h)
         @x = x
         @h = h
-        h2 = x.dot(@weight.data) + h.dot(@weight2.data)
+        h2 = x.dot(@weight.data) + h.dot(@recurrent_weight.data)
         h2 += @bias.data if @bias
         @activation.forward(h2)
       end
 
       def backward(dh2)
         dh2 = @activation.backward(dh2)
-        @weight.grad += @x.transpose.dot(dh2)
-        @weight2.grad += @h.transpose.dot(dh2)
-        @bias.grad += dh2.sum(0) if @bias
+        if @trainable
+          @weight.grad += @x.transpose.dot(dh2)
+          @recurrent_weight.grad += @h.transpose.dot(dh2)
+          @bias.grad += dh2.sum(0) if @bias
+        end
         dx = dh2.dot(@weight.data.transpose)
-        dh = dh2.dot(@weight2.data.transpose)
+        dh = dh2.dot(@recurrent_weight.data.transpose)
         [dx, dh]
       end
     end
@@ -137,15 +148,17 @@ module DNN
 
       attr_reader :activation
 
-      def self.load_hash(hash)
+      def self.from_hash(hash)
         simple_rnn = self.new(hash[:num_nodes],
                               stateful: hash[:stateful],
                               return_sequences: hash[:return_sequences],
-                              activation: Utils.load_hash(hash[:activation]),
-                              weight_initializer: Utils.load_hash(hash[:weight_initializer]),
-                              bias_initializer: Utils.load_hash(hash[:bias_initializer]),
-                              l1_lambda: hash[:l1_lambda],
-                              l2_lambda: hash[:l2_lambda],
+                              activation: Utils.from_hash(hash[:activation]),
+                              weight_initializer: Utils.from_hash(hash[:weight_initializer]),
+                              recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
+                              bias_initializer: Utils.from_hash(hash[:bias_initializer]),
+                              weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
+                              recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
+                              bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
                               use_bias: hash[:use_bias])
         simple_rnn
       end
@@ -155,53 +168,56 @@ module DNN
                      return_sequences: true,
                      activation: Tanh.new,
                      weight_initializer: RandomNormal.new,
+                     recurrent_weight_initializer: RandomNormal.new,
                      bias_initializer: Zeros.new,
-                     l1_lambda: 0,
-                     l2_lambda: 0,
+                     weight_regularizer: nil,
+                     recurrent_weight_regularizer: nil,
+                     bias_regularizer: nil,
                      use_bias: true)
         super(num_nodes,
               stateful: stateful,
               return_sequences: return_sequences,
               weight_initializer: weight_initializer,
+              recurrent_weight_initializer: recurrent_weight_initializer,
              bias_initializer: bias_initializer,
-              l1_lambda: l1_lambda,
-              l2_lambda: l2_lambda,
+              weight_regularizer: weight_regularizer,
+              recurrent_weight_regularizer: recurrent_weight_regularizer,
+              bias_regularizer: bias_regularizer,
               use_bias: use_bias)
         @activation = activation
       end
 
-      def to_hash
-        super({activation: @activation.to_hash})
-      end
-
-      private
-
-      def init_params
-        super()
-        num_prev_nodes = @input_shape[1]
+      def build(input_shape)
+        super
+        num_prev_nodes = input_shape[1]
         @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes)
-        @weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
+        @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
         @bias.data = Xumo::SFloat.new(@num_nodes) if @bias
-        @weight_initializer.init_param(self, @weight)
-        @weight_initializer.init_param(self, @weight2)
-        @bias_initializer.init_param(self, @bias) if @bias
+        init_weight_and_bias
         @time_length.times do |t|
-          @layers << SimpleRNN_Dense.new(@weight, @weight2, @bias, @activation)
+          @layers << SimpleRNN_Dense.new(@weight, @recurrent_weight, @bias, @activation)
         end
       end
+
+      def to_hash
+        super({activation: @activation.to_hash})
+      end
     end
 
 
     class LSTM_Dense
-      def initialize(weight, weight2, bias)
+      attr_accessor :trainable
+
+      def initialize(weight, recurrent_weight, bias)
         @weight = weight
-        @weight2 = weight2
+        @recurrent_weight = recurrent_weight
         @bias = bias
         @tanh = Tanh.new
         @g_tanh = Tanh.new
         @forget_sigmoid = Sigmoid.new
         @in_sigmoid = Sigmoid.new
         @out_sigmoid = Sigmoid.new
+        @trainable = true
       end
 
       def forward(x, h, c)
@@ -209,7 +225,7 @@ module DNN
         @h = h
         @c = c
         num_nodes = h.shape[1]
-        a = x.dot(@weight.data) + h.dot(@weight2.data)
+        a = x.dot(@weight.data) + h.dot(@recurrent_weight.data)
         a += @bias.data if @bias
 
         @forget = @forget_sigmoid.forward(a[true, 0...num_nodes])
@@ -234,11 +250,13 @@ module DNN
 
         da = Xumo::SFloat.hstack([dforget, dg, din, dout])
 
-        @weight.grad += @x.transpose.dot(da)
-        @weight2.grad += @h.transpose.dot(da)
-        @bias.grad += da.sum(0) if @bias
+        if @trainable
+          @weight.grad += @x.transpose.dot(da)
+          @recurrent_weight.grad += @h.transpose.dot(da)
+          @bias.grad += da.sum(0) if @bias
+        end
         dx = da.dot(@weight.data.transpose)
-        dh = da.dot(@weight2.data.transpose)
+        dh = da.dot(@recurrent_weight.data.transpose)
         dc = dc2_tmp * @forget
         [dx, dh, dc]
       end
@@ -246,14 +264,16 @@ module DNN
 
 
     class LSTM < RNN
-      def self.load_hash(hash)
+      def self.from_hash(hash)
         lstm = self.new(hash[:num_nodes],
                         stateful: hash[:stateful],
                         return_sequences: hash[:return_sequences],
-                        weight_initializer: Utils.load_hash(hash[:weight_initializer]),
-                        bias_initializer: Utils.load_hash(hash[:bias_initializer]),
-                        l1_lambda: hash[:l1_lambda],
-                        l2_lambda: hash[:l2_lambda],
+                        weight_initializer: Utils.from_hash(hash[:weight_initializer]),
+                        recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
+                        bias_initializer: Utils.from_hash(hash[:bias_initializer]),
+                        weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
+                        recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
+                        bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
                         use_bias: hash[:use_bias])
         lstm
       end
@@ -262,12 +282,26 @@ module DNN
                      stateful: false,
                      return_sequences: true,
                      weight_initializer: RandomNormal.new,
+                     recurrent_weight_initializer: RandomNormal.new,
                      bias_initializer: Zeros.new,
-                     l1_lambda: 0,
-                     l2_lambda: 0,
+                     weight_regularizer: nil,
+                     recurrent_weight_regularizer: nil,
+                     bias_regularizer: nil,
                      use_bias: true)
         super
-        @cell = @params[:c] = Param.new
+        @cell = @params[:cell] = Param.new
+      end
+
+      def build(input_shape)
+        super
+        num_prev_nodes = input_shape[1]
+        @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 4)
+        @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
+        @bias.data = Xumo::SFloat.new(@num_nodes * 4) if @bias
+        init_weight_and_bias
+        @time_length.times do |t|
+          @layers << LSTM_Dense.new(@weight, @recurrent_weight, @bias)
+        end
       end
 
       def forward(xs)
@@ -283,6 +317,7 @@ module DNN
         c ||= Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
         xs.shape[1].times do |t|
           x = xs[true, t, false]
+          @layers[t].trainable = @trainable
           h, c = @layers[t].forward(x, h, c)
           hs[true, t, false] = h
         end
@@ -292,9 +327,6 @@ module DNN
       end
 
       def backward(dh2s)
-        @weight.grad = Xumo::SFloat.zeros(*@weight.data.shape)
-        @weight2.grad = Xumo::SFloat.zeros(*@weight2.data.shape)
-        @bias.grad = Xumo::SFloat.zeros(*@bias.data.shape) if @bias
         unless @return_sequences
           dh = dh2s
           dh2s = Xumo::SFloat.zeros(dh.shape[0], @time_length, dh.shape[1])
@@ -315,33 +347,20 @@ module DNN
         super()
         @cell.data = @cell.data.fill(0) if @cell.data
       end
-
-      private
-
-      def init_params
-        super()
-        num_prev_nodes = @input_shape[1]
-        @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 4)
-        @weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
-        @bias.data = Xumo::SFloat.new(@num_nodes * 4) if @bias
-        @weight_initializer.init_param(self, @weight)
-        @weight_initializer.init_param(self, @weight2)
-        @bias_initializer.init_param(self, @bias) if @bias
-        @time_length.times do |t|
-          @layers << LSTM_Dense.new(@weight, @weight2, @bias)
-        end
-      end
     end
 
 
     class GRU_Dense
-      def initialize(weight, weight2, bias)
+      attr_accessor :trainable
+
+      def initialize(weight, recurrent_weight, bias)
         @weight = weight
-        @weight2 = weight2
+        @recurrent_weight = recurrent_weight
         @bias = bias
         @update_sigmoid = Sigmoid.new
         @reset_sigmoid = Sigmoid.new
         @tanh = Tanh.new
+        @trainable = true
       end
 
       def forward(x, h)
@@ -349,60 +368,68 @@ module DNN
         @h = h
         num_nodes = h.shape[1]
         @weight_a = @weight.data[true, 0...(num_nodes * 2)]
-        @weight2_a = @weight2.data[true, 0...(num_nodes * 2)]
+        @weight2_a = @recurrent_weight.data[true, 0...(num_nodes * 2)]
         a = x.dot(@weight_a) + h.dot(@weight2_a)
         a += @bias.data[0...(num_nodes * 2)] if @bias
         @update = @update_sigmoid.forward(a[true, 0...num_nodes])
         @reset = @reset_sigmoid.forward(a[true, num_nodes..-1])
 
         @weight_h = @weight.data[true, (num_nodes * 2)..-1]
-        @weight2_h = @weight2.data[true, (num_nodes * 2)..-1]
+        @weight2_h = @recurrent_weight.data[true, (num_nodes * 2)..-1]
         @tanh_h = if @bias
           bias_h = @bias.data[(num_nodes * 2)..-1]
           @tanh.forward(x.dot(@weight_h) + (h * @reset).dot(@weight2_h) + bias_h)
         else
           @tanh.forward(x.dot(@weight_h) + (h * @reset).dot(@weight2_h))
         end
-        h2 = (1 - @update) * h + @update * @tanh_h
+        h2 = (1 - @update) * @tanh_h + @update * h
         h2
       end
 
       def backward(dh2)
-        dtanh_h = @tanh.backward(dh2 * @update)
-        dh = dh2 * (1 - @update)
+        dtanh_h = @tanh.backward(dh2 * (1 - @update))
+        dh = dh2 * @update
 
-        dweight_h = @x.transpose.dot(dtanh_h)
+        if @trainable
+          dweight_h = @x.transpose.dot(dtanh_h)
+          dweight2_h = (@h * @reset).transpose.dot(dtanh_h)
+          dbias_h = dtanh_h.sum(0) if @bias
+        end
         dx = dtanh_h.dot(@weight_h.transpose)
-        dweight2_h = (@h * @reset).transpose.dot(dtanh_h)
         dh += dtanh_h.dot(@weight2_h.transpose) * @reset
-        dbias_h = dtanh_h.sum(0) if @bias
 
         dreset = @reset_sigmoid.backward(dtanh_h.dot(@weight2_h.transpose) * @h)
-        dupdate = @update_sigmoid.backward(dh2 * @tanh_h - dh2 * @h)
+        dupdate = @update_sigmoid.backward(dh2 * @h - dh2 * @tanh_h)
         da = Xumo::SFloat.hstack([dupdate, dreset])
-        dweight_a = @x.transpose.dot(da)
+        if @trainable
+          dweight_a = @x.transpose.dot(da)
+          dweight2_a = @h.transpose.dot(da)
+          dbias_a = da.sum(0) if @bias
+        end
         dx += da.dot(@weight_a.transpose)
-        dweight2_a = @h.transpose.dot(da)
         dh += da.dot(@weight2_a.transpose)
-        dbias_a = da.sum(0) if @bias
 
-        @weight.grad += Xumo::SFloat.hstack([dweight_a, dweight_h])
-        @weight2.grad += Xumo::SFloat.hstack([dweight2_a, dweight2_h])
-        @bias.grad += Xumo::SFloat.hstack([dbias_a, dbias_h]) if @bias
+        if @trainable
+          @weight.grad += Xumo::SFloat.hstack([dweight_a, dweight_h])
+          @recurrent_weight.grad += Xumo::SFloat.hstack([dweight2_a, dweight2_h])
+          @bias.grad += Xumo::SFloat.hstack([dbias_a, dbias_h]) if @bias
+        end
         [dx, dh]
      end
    end
 
 
     class GRU < RNN
-      def self.load_hash(hash)
+      def self.from_hash(hash)
         gru = self.new(hash[:num_nodes],
                        stateful: hash[:stateful],
                        return_sequences: hash[:return_sequences],
-                       weight_initializer: Utils.load_hash(hash[:weight_initializer]),
-                       bias_initializer: Utils.load_hash(hash[:bias_initializer]),
-                       l1_lambda: hash[:l1_lambda],
-                       l2_lambda: hash[:l2_lambda],
+                       weight_initializer: Utils.from_hash(hash[:weight_initializer]),
+                       recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
+                       bias_initializer: Utils.from_hash(hash[:bias_initializer]),
+                       weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
+                       recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
+                       bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
                        use_bias: hash[:use_bias])
         gru
       end
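
One behavioral change is buried in the GRU_Dense hunk above: the state update is flipped from h2 = (1 - z) * h + z * h_tilde in 0.9.4 to h2 = (1 - z) * h_tilde + z * h in 0.10.0, where z is the update gate (@update) and h_tilde is the candidate state (@tanh_h). The backward pass (dtanh_h, dh, dupdate) is rewritten to match the new expression, so each release's forward/backward pair stays internally consistent; the two versions simply adopt opposite readings of the update gate, both of which appear in the GRU literature.
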
@@ -411,26 +438,24 @@ module DNN
                      stateful: false,
                      return_sequences: true,
                      weight_initializer: RandomNormal.new,
+                     recurrent_weight_initializer: RandomNormal.new,
                      bias_initializer: Zeros.new,
-                     l1_lambda: 0,
-                     l2_lambda: 0,
+                     weight_regularizer: nil,
+                     recurrent_weight_regularizer: nil,
+                     bias_regularizer: nil,
                      use_bias: true)
         super
       end
-
-      private
 
-      def init_params
-        super()
+      def build(input_shape)
+        super
         num_prev_nodes = @input_shape[1]
         @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 3)
-        @weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
+        @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
         @bias.data = Xumo::SFloat.new(@num_nodes * 3) if @bias
-        @weight_initializer.init_param(self, @weight)
-        @weight_initializer.init_param(self, @weight2)
-        @bias_initializer.init_param(self, @bias) if @bias
+        init_weight_and_bias
         @time_length.times do |t|
-          @layers << GRU_Dense.new(@weight, @weight2, @bias)
+          @layers << GRU_Dense.new(@weight, @recurrent_weight, @bias)
         end
       end
     end
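
Taken together, the rnn.rb hunks rename the weight2 parameter to recurrent_weight, move shape setup from the private init_params into build(input_shape), and replace the l1_lambda/l2_lambda keyword arguments with explicit regularizer objects (the 0.9.4 lambdas applied both penalties to the input and recurrent weights, never to the bias). A rough migration sketch for a SimpleRNN layer, assuming the usual DNN::Layers namespace, which this diff does not show:

    require "dnn"

    # ruby-dnn 0.9.4: scalar lambdas, applied internally via Lasso/Ridge
    rnn = DNN::Layers::SimpleRNN.new(64, l1_lambda: 0.001, l2_lambda: 0.001)

    # ruby-dnn 0.10.0: explicit regularizer objects per parameter
    rnn = DNN::Layers::SimpleRNN.new(64,
                                     weight_regularizer: DNN::Regularizers::L1L2.new(0.001, 0.001),
                                     recurrent_weight_regularizer: DNN::Regularizers::L1L2.new(0.001, 0.001))
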