ruby-dnn 0.10.1 → 0.10.2

@@ -1,9 +1,9 @@
-class DNN::Param
-  attr_accessor :data
-  attr_accessor :grad
-
-  def initialize(data = nil, grad = nil)
-    @data = data
-    @grad = grad
-  end
-end
+class DNN::Param
+  attr_accessor :data
+  attr_accessor :grad
+
+  def initialize(data = nil, grad = nil)
+    @data = data
+    @grad = grad
+  end
+end
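Purely for orientation (this note and the snippet are not part of the diff): DNN::Param is a plain data/grad container, so it can be exercised on its own. The sketch assumes the dnn gem is installed and that Xumo::SFloat, the array type used throughout these files, is available after require "dnn".

require "dnn"

# Hold a small weight vector and accumulate a gradient onto it.
param = DNN::Param.new(Xumo::SFloat[1, -2, 3], 0)
param.grad += Xumo::SFloat[0.1, 0.1, 0.1]
p param.data.to_a   # => [1.0, -2.0, 3.0]
p param.grad.to_a   # => [0.1, 0.1, 0.1]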
@@ -1,106 +1,106 @@
-module DNN
-  module Regularizers
-
-    class Regularizer
-      attr_accessor :param
-
-      def forward(x)
-        raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'forward'")
-      end
-
-      def backward
-        raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'backward'")
-      end
-
-      def to_hash(merge_hash)
-        hash = {class: self.class.name}
-        hash.merge!(merge_hash)
-        hash
-      end
-    end
-
-    class L1 < Regularizer
-      attr_accessor :l1_lambda
-
-      def self.from_hash(hash)
-        L1.new(hash[:l1_lambda])
-      end
-
-      def initialize(l1_lambda = 0.01)
-        @l1_lambda = l1_lambda
-      end
-
-      def forward(x)
-        x + @l1_lambda * @param.data.abs.sum
-      end
-
-      def backward
-        dparam = Xumo::SFloat.ones(*@param.data.shape)
-        dparam[@param.data < 0] = -1
-        @param.grad += @l1_lambda * dparam
-      end
-
-      def to_hash
-        super(l1_lambda: @l1_lambda)
-      end
-    end
-
-
-    class L2 < Regularizer
-      attr_accessor :l2_lambda
-
-      def self.from_hash(hash)
-        L2.new(hash[:l2_lambda])
-      end
-
-      def initialize(l2_lambda = 0.01)
-        @l2_lambda = l2_lambda
-      end
-
-      def forward(x)
-        x + 0.5 * @l2_lambda * (@param.data**2).sum
-      end
-
-      def backward
-        @param.grad += @l2_lambda * @param.data
-      end
-
-      def to_hash
-        super(l2_lambda: @l2_lambda)
-      end
-    end
-
-    class L1L2 < Regularizer
-      attr_accessor :l1_lambda
-      attr_accessor :l2_lambda
-
-      def self.from_hash(hash)
-        L1L2.new(hash[:l1_lambda], hash[:l2_lambda])
-      end
-
-      def initialize(l1_lambda = 0.01, l2_lambda = 0.01)
-        @l1_lambda = l1_lambda
-        @l2_lambda = l2_lambda
-      end
-
-      def forward(x)
-        l1 = @l1_lambda * @param.data.abs.sum
-        l2 = 0.5 * @l2_lambda * (@param.data**2).sum
-        x + l1 + l2
-      end
-
-      def backward
-        dparam = Xumo::SFloat.ones(*@param.data.shape)
-        dparam[@param.data < 0] = -1
-        @param.grad += @l1_lambda * dparam
-        @param.grad += @l2_lambda * @param.data
-      end
-
-      def to_hash
-        super(l1_lambda: l1_lambda, l2_lambda: l2_lambda)
-      end
-
-    end
-
-  end
-end
+module DNN
+  module Regularizers
+
+    class Regularizer
+      attr_accessor :param
+
+      def forward(x)
+        raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'forward'")
+      end
+
+      def backward
+        raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'backward'")
+      end
+
+      def to_hash(merge_hash)
+        hash = {class: self.class.name}
+        hash.merge!(merge_hash)
+        hash
+      end
+    end
+
+    class L1 < Regularizer
+      attr_accessor :l1_lambda
+
+      def self.from_hash(hash)
+        L1.new(hash[:l1_lambda])
+      end
+
+      def initialize(l1_lambda = 0.01)
+        @l1_lambda = l1_lambda
+      end
+
+      def forward(x)
+        x + @l1_lambda * @param.data.abs.sum
+      end
+
+      def backward
+        dparam = Xumo::SFloat.ones(*@param.data.shape)
+        dparam[@param.data < 0] = -1
+        @param.grad += @l1_lambda * dparam
+      end
+
+      def to_hash
+        super(l1_lambda: @l1_lambda)
+      end
+    end
+
+
+    class L2 < Regularizer
+      attr_accessor :l2_lambda
+
+      def self.from_hash(hash)
+        L2.new(hash[:l2_lambda])
+      end
+
+      def initialize(l2_lambda = 0.01)
+        @l2_lambda = l2_lambda
+      end
+
+      def forward(x)
+        x + 0.5 * @l2_lambda * (@param.data**2).sum
+      end
+
+      def backward
+        @param.grad += @l2_lambda * @param.data
+      end
+
+      def to_hash
+        super(l2_lambda: @l2_lambda)
+      end
+    end
+
+    class L1L2 < Regularizer
+      attr_accessor :l1_lambda
+      attr_accessor :l2_lambda
+
+      def self.from_hash(hash)
+        L1L2.new(hash[:l1_lambda], hash[:l2_lambda])
+      end
+
+      def initialize(l1_lambda = 0.01, l2_lambda = 0.01)
+        @l1_lambda = l1_lambda
+        @l2_lambda = l2_lambda
+      end
+
+      def forward(x)
+        l1 = @l1_lambda * @param.data.abs.sum
+        l2 = 0.5 * @l2_lambda * (@param.data**2).sum
+        x + l1 + l2
+      end
+
+      def backward
+        dparam = Xumo::SFloat.ones(*@param.data.shape)
+        dparam[@param.data < 0] = -1
+        @param.grad += @l1_lambda * dparam
+        @param.grad += @l2_lambda * @param.data
+      end
+
+      def to_hash
+        super(l1_lambda: l1_lambda, l2_lambda: l2_lambda)
+      end
+
+    end
+
+  end
+end
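A hand-driven sketch of the regularizer interface above (illustrative only, not part of the diff; it assumes require "dnn" and the Xumo::SFloat arrays used throughout this file): attach a Param to a regularizer, add the penalty to a scalar loss with forward, and accumulate the penalty gradient with backward.

require "dnn"

weight = DNN::Param.new(Xumo::SFloat[[0.5, -1.0], [2.0, 0.0]], 0)
reg = DNN::Regularizers::L2.new(0.01)
reg.param = weight

loss = reg.forward(1.23)   # adds 0.5 * l2_lambda * sum(w**2) to the loss value
reg.backward               # accumulates l2_lambda * w into weight.grad
p loss
p weight.grad.to_a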
@@ -1,464 +1,464 @@
-module DNN
-  module Layers
-
-    # Super class of all RNN classes.
-    class RNN < Connection
-      include Initializers
-
-      # @return [Integer] number of nodes.
-      attr_reader :num_nodes
-      # @return [Bool] Maintain state between batches.
-      attr_reader :stateful
-      # @return [Bool] Set the false, only the last of each cell of RNN is left.
-      attr_reader :return_sequences
-      # @return [DNN::Initializers::Initializer] Recurrent weight initializer.
-      attr_reader :recurrent_weight_initializer
-      # @return [DNN::Regularizers::Regularizer] Recurrent weight regularization.
-      attr_reader :recurrent_weight_regularizer
-
-      def initialize(num_nodes,
-                     stateful: false,
-                     return_sequences: true,
-                     weight_initializer: RandomNormal.new,
-                     recurrent_weight_initializer: RandomNormal.new,
-                     bias_initializer: Zeros.new,
-                     weight_regularizer: nil,
-                     recurrent_weight_regularizer: nil,
-                     bias_regularizer: nil,
-                     use_bias: true)
-        super(weight_initializer: weight_initializer, bias_initializer: bias_initializer,
-              weight_regularizer: weight_regularizer, bias_regularizer: bias_regularizer, use_bias: use_bias)
-        @num_nodes = num_nodes
-        @stateful = stateful
-        @return_sequences = return_sequences
-        @layers = []
-        @hidden = @params[:hidden] = Param.new
-        @params[:recurrent_weight] = @recurrent_weight = Param.new(nil, 0)
-        @recurrent_weight_initializer = recurrent_weight_initializer
-        @recurrent_weight_regularizer = recurrent_weight_regularizer
-      end
-
-      def build(input_shape)
-        super
-        @time_length = @input_shape[0]
-      end
-
-      def forward(xs)
-        @xs_shape = xs.shape
-        hs = Xumo::SFloat.zeros(xs.shape[0], @time_length, @num_nodes)
-        h = (@stateful && @hidden.data) ? @hidden.data : Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
-        xs.shape[1].times do |t|
-          x = xs[true, t, false]
-          @layers[t].trainable = @trainable
-          h = @layers[t].forward(x, h)
-          hs[true, t, false] = h
-        end
-        @hidden.data = h
-        @return_sequences ? hs : h
-      end
-
-      def backward(dh2s)
-        unless @return_sequences
-          dh = dh2s
-          dh2s = Xumo::SFloat.zeros(dh.shape[0], @time_length, dh.shape[1])
-          dh2s[true, -1, false] = dh
-        end
-        dxs = Xumo::SFloat.zeros(@xs_shape)
-        dh = 0
-        (0...dh2s.shape[1]).to_a.reverse.each do |t|
-          dh2 = dh2s[true, t, false]
-          dx, dh = @layers[t].backward(dh2 + dh)
-          dxs[true, t, false] = dx
-        end
-        dxs
-      end
-
-      def output_shape
-        @return_sequences ? [@time_length, @num_nodes] : [@num_nodes]
-      end
-
-      def to_hash(merge_hash = nil)
-        hash = {
-          num_nodes: @num_nodes,
-          stateful: @stateful,
-          return_sequences: @return_sequences,
-          recurrent_weight_initializer: @recurrent_weight_initializer.to_hash,
-          recurrent_weight_regularizer: @recurrent_weight_regularizer&.to_hash,
-        }
-        hash.merge!(merge_hash) if merge_hash
-        super(hash)
-      end
-
-      # Reset the state of RNN.
-      def reset_state
-        @hidden.data = @hidden.data.fill(0) if @hidden.data
-      end
-
-      def regularizers
-        regularizers = []
-        regularizers << @weight_regularizer if @weight_regularizer
-        regularizers << @recurrent_weight_regularizer if @recurrent_weight_regularizer
-        regularizers << @bias_regularizer if @bias_regularizer
-        regularizers
-      end
-
-      private def init_weight_and_bias
-        super
-        @recurrent_weight_initializer.init_param(self, @recurrent_weight)
-        @recurrent_weight_regularizer.param = @recurrent_weight if @recurrent_weight_regularizer
-      end
-    end
-
-
-    class SimpleRNN_Dense
-      attr_accessor :trainable
-
-      def initialize(weight, recurrent_weight, bias, activation)
-        @weight = weight
-        @recurrent_weight = recurrent_weight
-        @bias = bias
-        @activation = activation.clone
-        @trainable = true
-      end
-
-      def forward(x, h)
-        @x = x
-        @h = h
-        h2 = x.dot(@weight.data) + h.dot(@recurrent_weight.data)
-        h2 += @bias.data if @bias
-        @activation.forward(h2)
-      end
-
-      def backward(dh2)
-        dh2 = @activation.backward(dh2)
-        if @trainable
-          @weight.grad += @x.transpose.dot(dh2)
-          @recurrent_weight.grad += @h.transpose.dot(dh2)
-          @bias.grad += dh2.sum(0) if @bias
-        end
-        dx = dh2.dot(@weight.data.transpose)
-        dh = dh2.dot(@recurrent_weight.data.transpose)
-        [dx, dh]
-      end
-    end
-
-
-    class SimpleRNN < RNN
-      include Activations
-
-      attr_reader :activation
-
-      def self.from_hash(hash)
-        simple_rnn = self.new(hash[:num_nodes],
-                              stateful: hash[:stateful],
-                              return_sequences: hash[:return_sequences],
-                              activation: Utils.from_hash(hash[:activation]),
-                              weight_initializer: Utils.from_hash(hash[:weight_initializer]),
-                              recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
-                              bias_initializer: Utils.from_hash(hash[:bias_initializer]),
-                              weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
-                              recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
-                              bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
-                              use_bias: hash[:use_bias])
-        simple_rnn
-      end
-
-      def initialize(num_nodes,
-                     stateful: false,
-                     return_sequences: true,
-                     activation: Tanh.new,
-                     weight_initializer: RandomNormal.new,
-                     recurrent_weight_initializer: RandomNormal.new,
-                     bias_initializer: Zeros.new,
-                     weight_regularizer: nil,
-                     recurrent_weight_regularizer: nil,
-                     bias_regularizer: nil,
-                     use_bias: true)
-        super(num_nodes,
-              stateful: stateful,
-              return_sequences: return_sequences,
-              weight_initializer: weight_initializer,
-              recurrent_weight_initializer: recurrent_weight_initializer,
-              bias_initializer: bias_initializer,
-              weight_regularizer: weight_regularizer,
-              recurrent_weight_regularizer: recurrent_weight_regularizer,
-              bias_regularizer: bias_regularizer,
-              use_bias: use_bias)
-        @activation = activation
-      end
-
-      def build(input_shape)
-        super
-        num_prev_nodes = input_shape[1]
-        @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes)
-        @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
-        @bias.data = Xumo::SFloat.new(@num_nodes) if @bias
-        init_weight_and_bias
-        @time_length.times do |t|
-          @layers << SimpleRNN_Dense.new(@weight, @recurrent_weight, @bias, @activation)
-        end
-      end
-
-      def to_hash
-        super({activation: @activation.to_hash})
-      end
-    end
-
-
-    class LSTM_Dense
-      attr_accessor :trainable
-
-      def initialize(weight, recurrent_weight, bias)
-        @weight = weight
-        @recurrent_weight = recurrent_weight
-        @bias = bias
-        @tanh = Tanh.new
-        @g_tanh = Tanh.new
-        @forget_sigmoid = Sigmoid.new
-        @in_sigmoid = Sigmoid.new
-        @out_sigmoid = Sigmoid.new
-        @trainable = true
-      end
-
-      def forward(x, h, c)
-        @x = x
-        @h = h
-        @c = c
-        num_nodes = h.shape[1]
-        a = x.dot(@weight.data) + h.dot(@recurrent_weight.data)
-        a += @bias.data if @bias
-
-        @forget = @forget_sigmoid.forward(a[true, 0...num_nodes])
-        @g = @g_tanh.forward(a[true, num_nodes...(num_nodes * 2)])
-        @in = @in_sigmoid.forward(a[true, (num_nodes * 2)...(num_nodes * 3)])
-        @out = @out_sigmoid.forward(a[true, (num_nodes * 3)..-1])
-
-        c2 = @forget * c + @g * @in
-        @tanh_c2 = @tanh.forward(c2)
-        h2 = @out * @tanh_c2
-        [h2, c2]
-      end
-
-      def backward(dh2, dc2)
-        dh2_tmp = @tanh_c2 * dh2
-        dc2_tmp = @tanh.backward(@out * dh2) + dc2
-
-        dout = @out_sigmoid.backward(dh2_tmp)
-        din = @in_sigmoid.backward(dc2_tmp * @g)
-        dg = @g_tanh.backward(dc2_tmp * @in)
-        dforget = @forget_sigmoid.backward(dc2_tmp * @c)
-
-        da = Xumo::SFloat.hstack([dforget, dg, din, dout])
-
-        if @trainable
-          @weight.grad += @x.transpose.dot(da)
-          @recurrent_weight.grad += @h.transpose.dot(da)
-          @bias.grad += da.sum(0) if @bias
-        end
-        dx = da.dot(@weight.data.transpose)
-        dh = da.dot(@recurrent_weight.data.transpose)
-        dc = dc2_tmp * @forget
-        [dx, dh, dc]
-      end
-    end
-
-
-    class LSTM < RNN
-      def self.from_hash(hash)
-        lstm = self.new(hash[:num_nodes],
-                        stateful: hash[:stateful],
-                        return_sequences: hash[:return_sequences],
-                        weight_initializer: Utils.from_hash(hash[:weight_initializer]),
-                        recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
-                        bias_initializer: Utils.from_hash(hash[:bias_initializer]),
-                        weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
-                        recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
-                        bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
-                        use_bias: hash[:use_bias])
-        lstm
-      end
-
-      def initialize(num_nodes,
-                     stateful: false,
-                     return_sequences: true,
-                     weight_initializer: RandomNormal.new,
-                     recurrent_weight_initializer: RandomNormal.new,
-                     bias_initializer: Zeros.new,
-                     weight_regularizer: nil,
-                     recurrent_weight_regularizer: nil,
-                     bias_regularizer: nil,
-                     use_bias: true)
-        super
-        @cell = @params[:cell] = Param.new
-      end
-
-      def build(input_shape)
-        super
-        num_prev_nodes = input_shape[1]
-        @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 4)
-        @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
-        @bias.data = Xumo::SFloat.new(@num_nodes * 4) if @bias
-        init_weight_and_bias
-        @time_length.times do |t|
-          @layers << LSTM_Dense.new(@weight, @recurrent_weight, @bias)
-        end
-      end
-
-      def forward(xs)
-        @xs_shape = xs.shape
-        hs = Xumo::SFloat.zeros(xs.shape[0], @time_length, @num_nodes)
-        h = nil
-        c = nil
-        if @stateful
-          h = @hidden.data if @hidden.data
-          c = @cell.data if @cell.data
-        end
-        h ||= Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
-        c ||= Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
-        xs.shape[1].times do |t|
-          x = xs[true, t, false]
-          @layers[t].trainable = @trainable
-          h, c = @layers[t].forward(x, h, c)
-          hs[true, t, false] = h
-        end
-        @hidden.data = h
-        @cell.data = c
-        @return_sequences ? hs : h
-      end
-
-      def backward(dh2s)
-        unless @return_sequences
-          dh = dh2s
-          dh2s = Xumo::SFloat.zeros(dh.shape[0], @time_length, dh.shape[1])
-          dh2s[true, -1, false] = dh
-        end
-        dxs = Xumo::SFloat.zeros(@xs_shape)
-        dh = 0
-        dc = 0
-        (0...dh2s.shape[1]).to_a.reverse.each do |t|
-          dh2 = dh2s[true, t, false]
-          dx, dh, dc = @layers[t].backward(dh2 + dh, dc)
-          dxs[true, t, false] = dx
-        end
-        dxs
-      end
-
-      def reset_state
-        super()
-        @cell.data = @cell.data.fill(0) if @cell.data
-      end
-    end
-
-
-    class GRU_Dense
-      attr_accessor :trainable
-
-      def initialize(weight, recurrent_weight, bias)
-        @weight = weight
-        @recurrent_weight = recurrent_weight
-        @bias = bias
-        @update_sigmoid = Sigmoid.new
-        @reset_sigmoid = Sigmoid.new
-        @tanh = Tanh.new
-        @trainable = true
-      end
-
-      def forward(x, h)
-        @x = x
-        @h = h
-        num_nodes = h.shape[1]
-        @weight_a = @weight.data[true, 0...(num_nodes * 2)]
-        @weight2_a = @recurrent_weight.data[true, 0...(num_nodes * 2)]
-        a = x.dot(@weight_a) + h.dot(@weight2_a)
-        a += @bias.data[0...(num_nodes * 2)] if @bias
-        @update = @update_sigmoid.forward(a[true, 0...num_nodes])
-        @reset = @reset_sigmoid.forward(a[true, num_nodes..-1])
-
-        @weight_h = @weight.data[true, (num_nodes * 2)..-1]
-        @weight2_h = @recurrent_weight.data[true, (num_nodes * 2)..-1]
-        @tanh_h = if @bias
-                    bias_h = @bias.data[(num_nodes * 2)..-1]
-                    @tanh.forward(x.dot(@weight_h) + (h * @reset).dot(@weight2_h) + bias_h)
-                  else
-                    @tanh.forward(x.dot(@weight_h) + (h * @reset).dot(@weight2_h))
-                  end
-        h2 = (1 - @update) * @tanh_h + @update * h
-        h2
-      end
-
-      def backward(dh2)
-        dtanh_h = @tanh.backward(dh2 * (1 - @update))
-        dh = dh2 * @update
-
-        if @trainable
-          dweight_h = @x.transpose.dot(dtanh_h)
-          dweight2_h = (@h * @reset).transpose.dot(dtanh_h)
-          dbias_h = dtanh_h.sum(0) if @bias
-        end
-        dx = dtanh_h.dot(@weight_h.transpose)
-        dh += dtanh_h.dot(@weight2_h.transpose) * @reset
-
-        dreset = @reset_sigmoid.backward(dtanh_h.dot(@weight2_h.transpose) * @h)
-        dupdate = @update_sigmoid.backward(dh2 * @h - dh2 * @tanh_h)
-        da = Xumo::SFloat.hstack([dupdate, dreset])
-        if @trainable
-          dweight_a = @x.transpose.dot(da)
-          dweight2_a = @h.transpose.dot(da)
-          dbias_a = da.sum(0) if @bias
-        end
-        dx += da.dot(@weight_a.transpose)
-        dh += da.dot(@weight2_a.transpose)
-
-        if @trainable
-          @weight.grad += Xumo::SFloat.hstack([dweight_a, dweight_h])
-          @recurrent_weight.grad += Xumo::SFloat.hstack([dweight2_a, dweight2_h])
-          @bias.grad += Xumo::SFloat.hstack([dbias_a, dbias_h]) if @bias
-        end
-        [dx, dh]
-      end
-    end
-
-
-    class GRU < RNN
-      def self.from_hash(hash)
-        gru = self.new(hash[:num_nodes],
-                       stateful: hash[:stateful],
-                       return_sequences: hash[:return_sequences],
-                       weight_initializer: Utils.from_hash(hash[:weight_initializer]),
-                       recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
-                       bias_initializer: Utils.from_hash(hash[:bias_initializer]),
-                       weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
-                       recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
-                       bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
-                       use_bias: hash[:use_bias])
-        gru
-      end
-
-      def initialize(num_nodes,
-                     stateful: false,
-                     return_sequences: true,
-                     weight_initializer: RandomNormal.new,
-                     recurrent_weight_initializer: RandomNormal.new,
-                     bias_initializer: Zeros.new,
-                     weight_regularizer: nil,
-                     recurrent_weight_regularizer: nil,
-                     bias_regularizer: nil,
-                     use_bias: true)
-        super
-      end
-
-      def build(input_shape)
-        super
-        num_prev_nodes = @input_shape[1]
-        @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 3)
-        @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
-        @bias.data = Xumo::SFloat.new(@num_nodes * 3) if @bias
-        init_weight_and_bias
-        @time_length.times do |t|
-          @layers << GRU_Dense.new(@weight, @recurrent_weight, @bias)
-        end
-      end
-    end
-
-  end
-end
+module DNN
+  module Layers
+
+    # Super class of all RNN classes.
+    class RNN < Connection
+      include Initializers
+
+      # @return [Integer] number of nodes.
+      attr_reader :num_nodes
+      # @return [Bool] Maintain state between batches.
+      attr_reader :stateful
+      # @return [Bool] Set the false, only the last of each cell of RNN is left.
+      attr_reader :return_sequences
+      # @return [DNN::Initializers::Initializer] Recurrent weight initializer.
+      attr_reader :recurrent_weight_initializer
+      # @return [DNN::Regularizers::Regularizer] Recurrent weight regularization.
+      attr_reader :recurrent_weight_regularizer
+
+      def initialize(num_nodes,
+                     stateful: false,
+                     return_sequences: true,
+                     weight_initializer: RandomNormal.new,
+                     recurrent_weight_initializer: RandomNormal.new,
+                     bias_initializer: Zeros.new,
+                     weight_regularizer: nil,
+                     recurrent_weight_regularizer: nil,
+                     bias_regularizer: nil,
+                     use_bias: true)
+        super(weight_initializer: weight_initializer, bias_initializer: bias_initializer,
+              weight_regularizer: weight_regularizer, bias_regularizer: bias_regularizer, use_bias: use_bias)
+        @num_nodes = num_nodes
+        @stateful = stateful
+        @return_sequences = return_sequences
+        @layers = []
+        @hidden = @params[:hidden] = Param.new
+        @params[:recurrent_weight] = @recurrent_weight = Param.new(nil, 0)
+        @recurrent_weight_initializer = recurrent_weight_initializer
+        @recurrent_weight_regularizer = recurrent_weight_regularizer
+      end
+
+      def build(input_shape)
+        super
+        @time_length = @input_shape[0]
+      end
+
+      def forward(xs)
+        @xs_shape = xs.shape
+        hs = Xumo::SFloat.zeros(xs.shape[0], @time_length, @num_nodes)
+        h = (@stateful && @hidden.data) ? @hidden.data : Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
+        xs.shape[1].times do |t|
+          x = xs[true, t, false]
+          @layers[t].trainable = @trainable
+          h = @layers[t].forward(x, h)
+          hs[true, t, false] = h
+        end
+        @hidden.data = h
+        @return_sequences ? hs : h
+      end
+
+      def backward(dh2s)
+        unless @return_sequences
+          dh = dh2s
+          dh2s = Xumo::SFloat.zeros(dh.shape[0], @time_length, dh.shape[1])
+          dh2s[true, -1, false] = dh
+        end
+        dxs = Xumo::SFloat.zeros(@xs_shape)
+        dh = 0
+        (0...dh2s.shape[1]).to_a.reverse.each do |t|
+          dh2 = dh2s[true, t, false]
+          dx, dh = @layers[t].backward(dh2 + dh)
+          dxs[true, t, false] = dx
+        end
+        dxs
+      end
+
+      def output_shape
+        @return_sequences ? [@time_length, @num_nodes] : [@num_nodes]
+      end
+
+      def to_hash(merge_hash = nil)
+        hash = {
+          num_nodes: @num_nodes,
+          stateful: @stateful,
+          return_sequences: @return_sequences,
+          recurrent_weight_initializer: @recurrent_weight_initializer.to_hash,
+          recurrent_weight_regularizer: @recurrent_weight_regularizer&.to_hash,
+        }
+        hash.merge!(merge_hash) if merge_hash
+        super(hash)
+      end
+
+      # Reset the state of RNN.
+      def reset_state
+        @hidden.data = @hidden.data.fill(0) if @hidden.data
+      end
+
+      def regularizers
+        regularizers = []
+        regularizers << @weight_regularizer if @weight_regularizer
+        regularizers << @recurrent_weight_regularizer if @recurrent_weight_regularizer
+        regularizers << @bias_regularizer if @bias_regularizer
+        regularizers
+      end
+
+      private def init_weight_and_bias
+        super
+        @recurrent_weight_initializer.init_param(self, @recurrent_weight)
+        @recurrent_weight_regularizer.param = @recurrent_weight if @recurrent_weight_regularizer
+      end
+    end
+
+
+    class SimpleRNN_Dense
+      attr_accessor :trainable
+
+      def initialize(weight, recurrent_weight, bias, activation)
+        @weight = weight
+        @recurrent_weight = recurrent_weight
+        @bias = bias
+        @activation = activation.clone
+        @trainable = true
+      end
+
+      def forward(x, h)
+        @x = x
+        @h = h
+        h2 = x.dot(@weight.data) + h.dot(@recurrent_weight.data)
+        h2 += @bias.data if @bias
+        @activation.forward(h2)
+      end
+
+      def backward(dh2)
+        dh2 = @activation.backward(dh2)
+        if @trainable
+          @weight.grad += @x.transpose.dot(dh2)
+          @recurrent_weight.grad += @h.transpose.dot(dh2)
+          @bias.grad += dh2.sum(0) if @bias
+        end
+        dx = dh2.dot(@weight.data.transpose)
+        dh = dh2.dot(@recurrent_weight.data.transpose)
+        [dx, dh]
+      end
+    end
+
+
+    class SimpleRNN < RNN
+      include Activations
+
+      attr_reader :activation
+
+      def self.from_hash(hash)
+        simple_rnn = self.new(hash[:num_nodes],
+                              stateful: hash[:stateful],
+                              return_sequences: hash[:return_sequences],
+                              activation: Utils.from_hash(hash[:activation]),
+                              weight_initializer: Utils.from_hash(hash[:weight_initializer]),
+                              recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
+                              bias_initializer: Utils.from_hash(hash[:bias_initializer]),
+                              weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
+                              recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
+                              bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
+                              use_bias: hash[:use_bias])
+        simple_rnn
+      end
+
+      def initialize(num_nodes,
+                     stateful: false,
+                     return_sequences: true,
+                     activation: Tanh.new,
+                     weight_initializer: RandomNormal.new,
+                     recurrent_weight_initializer: RandomNormal.new,
+                     bias_initializer: Zeros.new,
+                     weight_regularizer: nil,
+                     recurrent_weight_regularizer: nil,
+                     bias_regularizer: nil,
+                     use_bias: true)
+        super(num_nodes,
+              stateful: stateful,
+              return_sequences: return_sequences,
+              weight_initializer: weight_initializer,
+              recurrent_weight_initializer: recurrent_weight_initializer,
+              bias_initializer: bias_initializer,
+              weight_regularizer: weight_regularizer,
+              recurrent_weight_regularizer: recurrent_weight_regularizer,
+              bias_regularizer: bias_regularizer,
+              use_bias: use_bias)
+        @activation = activation
+      end
+
+      def build(input_shape)
+        super
+        num_prev_nodes = input_shape[1]
+        @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes)
+        @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
+        @bias.data = Xumo::SFloat.new(@num_nodes) if @bias
+        init_weight_and_bias
+        @time_length.times do |t|
+          @layers << SimpleRNN_Dense.new(@weight, @recurrent_weight, @bias, @activation)
+        end
+      end
+
+      def to_hash
+        super({activation: @activation.to_hash})
+      end
+    end
+
+
+    class LSTM_Dense
+      attr_accessor :trainable
+
+      def initialize(weight, recurrent_weight, bias)
+        @weight = weight
+        @recurrent_weight = recurrent_weight
+        @bias = bias
+        @tanh = Tanh.new
+        @g_tanh = Tanh.new
+        @forget_sigmoid = Sigmoid.new
+        @in_sigmoid = Sigmoid.new
+        @out_sigmoid = Sigmoid.new
+        @trainable = true
+      end
+
+      def forward(x, h, c)
+        @x = x
+        @h = h
+        @c = c
+        num_nodes = h.shape[1]
+        a = x.dot(@weight.data) + h.dot(@recurrent_weight.data)
+        a += @bias.data if @bias
+
+        @forget = @forget_sigmoid.forward(a[true, 0...num_nodes])
+        @g = @g_tanh.forward(a[true, num_nodes...(num_nodes * 2)])
+        @in = @in_sigmoid.forward(a[true, (num_nodes * 2)...(num_nodes * 3)])
+        @out = @out_sigmoid.forward(a[true, (num_nodes * 3)..-1])
+
+        c2 = @forget * c + @g * @in
+        @tanh_c2 = @tanh.forward(c2)
+        h2 = @out * @tanh_c2
+        [h2, c2]
+      end
+
+      def backward(dh2, dc2)
+        dh2_tmp = @tanh_c2 * dh2
+        dc2_tmp = @tanh.backward(@out * dh2) + dc2
+
+        dout = @out_sigmoid.backward(dh2_tmp)
+        din = @in_sigmoid.backward(dc2_tmp * @g)
+        dg = @g_tanh.backward(dc2_tmp * @in)
+        dforget = @forget_sigmoid.backward(dc2_tmp * @c)
+
+        da = Xumo::SFloat.hstack([dforget, dg, din, dout])
+
+        if @trainable
+          @weight.grad += @x.transpose.dot(da)
+          @recurrent_weight.grad += @h.transpose.dot(da)
+          @bias.grad += da.sum(0) if @bias
+        end
+        dx = da.dot(@weight.data.transpose)
+        dh = da.dot(@recurrent_weight.data.transpose)
+        dc = dc2_tmp * @forget
+        [dx, dh, dc]
+      end
+    end
+
+
+    class LSTM < RNN
+      def self.from_hash(hash)
+        lstm = self.new(hash[:num_nodes],
+                        stateful: hash[:stateful],
+                        return_sequences: hash[:return_sequences],
+                        weight_initializer: Utils.from_hash(hash[:weight_initializer]),
+                        recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
+                        bias_initializer: Utils.from_hash(hash[:bias_initializer]),
+                        weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
+                        recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
+                        bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
+                        use_bias: hash[:use_bias])
+        lstm
+      end
+
+      def initialize(num_nodes,
+                     stateful: false,
+                     return_sequences: true,
+                     weight_initializer: RandomNormal.new,
+                     recurrent_weight_initializer: RandomNormal.new,
+                     bias_initializer: Zeros.new,
+                     weight_regularizer: nil,
+                     recurrent_weight_regularizer: nil,
+                     bias_regularizer: nil,
+                     use_bias: true)
+        super
+        @cell = @params[:cell] = Param.new
+      end
+
+      def build(input_shape)
+        super
+        num_prev_nodes = input_shape[1]
+        @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 4)
+        @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
+        @bias.data = Xumo::SFloat.new(@num_nodes * 4) if @bias
+        init_weight_and_bias
+        @time_length.times do |t|
+          @layers << LSTM_Dense.new(@weight, @recurrent_weight, @bias)
+        end
+      end
+
+      def forward(xs)
+        @xs_shape = xs.shape
+        hs = Xumo::SFloat.zeros(xs.shape[0], @time_length, @num_nodes)
+        h = nil
+        c = nil
+        if @stateful
+          h = @hidden.data if @hidden.data
+          c = @cell.data if @cell.data
+        end
+        h ||= Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
+        c ||= Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
+        xs.shape[1].times do |t|
+          x = xs[true, t, false]
+          @layers[t].trainable = @trainable
+          h, c = @layers[t].forward(x, h, c)
+          hs[true, t, false] = h
+        end
+        @hidden.data = h
+        @cell.data = c
+        @return_sequences ? hs : h
+      end
+
+      def backward(dh2s)
+        unless @return_sequences
+          dh = dh2s
+          dh2s = Xumo::SFloat.zeros(dh.shape[0], @time_length, dh.shape[1])
+          dh2s[true, -1, false] = dh
+        end
+        dxs = Xumo::SFloat.zeros(@xs_shape)
+        dh = 0
+        dc = 0
+        (0...dh2s.shape[1]).to_a.reverse.each do |t|
+          dh2 = dh2s[true, t, false]
+          dx, dh, dc = @layers[t].backward(dh2 + dh, dc)
+          dxs[true, t, false] = dx
+        end
+        dxs
+      end
+
+      def reset_state
+        super()
+        @cell.data = @cell.data.fill(0) if @cell.data
+      end
+    end
+
+
+    class GRU_Dense
+      attr_accessor :trainable
+
+      def initialize(weight, recurrent_weight, bias)
+        @weight = weight
+        @recurrent_weight = recurrent_weight
+        @bias = bias
+        @update_sigmoid = Sigmoid.new
+        @reset_sigmoid = Sigmoid.new
+        @tanh = Tanh.new
+        @trainable = true
+      end
+
+      def forward(x, h)
+        @x = x
+        @h = h
+        num_nodes = h.shape[1]
+        @weight_a = @weight.data[true, 0...(num_nodes * 2)]
+        @weight2_a = @recurrent_weight.data[true, 0...(num_nodes * 2)]
+        a = x.dot(@weight_a) + h.dot(@weight2_a)
+        a += @bias.data[0...(num_nodes * 2)] if @bias
+        @update = @update_sigmoid.forward(a[true, 0...num_nodes])
+        @reset = @reset_sigmoid.forward(a[true, num_nodes..-1])
+
+        @weight_h = @weight.data[true, (num_nodes * 2)..-1]
+        @weight2_h = @recurrent_weight.data[true, (num_nodes * 2)..-1]
+        @tanh_h = if @bias
+                    bias_h = @bias.data[(num_nodes * 2)..-1]
+                    @tanh.forward(x.dot(@weight_h) + (h * @reset).dot(@weight2_h) + bias_h)
+                  else
+                    @tanh.forward(x.dot(@weight_h) + (h * @reset).dot(@weight2_h))
+                  end
+        h2 = (1 - @update) * @tanh_h + @update * h
+        h2
+      end
+
+      def backward(dh2)
+        dtanh_h = @tanh.backward(dh2 * (1 - @update))
+        dh = dh2 * @update
+
+        if @trainable
+          dweight_h = @x.transpose.dot(dtanh_h)
+          dweight2_h = (@h * @reset).transpose.dot(dtanh_h)
+          dbias_h = dtanh_h.sum(0) if @bias
+        end
+        dx = dtanh_h.dot(@weight_h.transpose)
+        dh += dtanh_h.dot(@weight2_h.transpose) * @reset
+
+        dreset = @reset_sigmoid.backward(dtanh_h.dot(@weight2_h.transpose) * @h)
+        dupdate = @update_sigmoid.backward(dh2 * @h - dh2 * @tanh_h)
+        da = Xumo::SFloat.hstack([dupdate, dreset])
+        if @trainable
+          dweight_a = @x.transpose.dot(da)
+          dweight2_a = @h.transpose.dot(da)
+          dbias_a = da.sum(0) if @bias
+        end
+        dx += da.dot(@weight_a.transpose)
+        dh += da.dot(@weight2_a.transpose)
+
+        if @trainable
+          @weight.grad += Xumo::SFloat.hstack([dweight_a, dweight_h])
+          @recurrent_weight.grad += Xumo::SFloat.hstack([dweight2_a, dweight2_h])
+          @bias.grad += Xumo::SFloat.hstack([dbias_a, dbias_h]) if @bias
+        end
+        [dx, dh]
+      end
+    end
+
+
+    class GRU < RNN
+      def self.from_hash(hash)
+        gru = self.new(hash[:num_nodes],
+                       stateful: hash[:stateful],
+                       return_sequences: hash[:return_sequences],
+                       weight_initializer: Utils.from_hash(hash[:weight_initializer]),
+                       recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
+                       bias_initializer: Utils.from_hash(hash[:bias_initializer]),
+                       weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
+                       recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
+                       bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
+                       use_bias: hash[:use_bias])
+        gru
+      end
+
+      def initialize(num_nodes,
+                     stateful: false,
+                     return_sequences: true,
+                     weight_initializer: RandomNormal.new,
+                     recurrent_weight_initializer: RandomNormal.new,
+                     bias_initializer: Zeros.new,
+                     weight_regularizer: nil,
+                     recurrent_weight_regularizer: nil,
+                     bias_regularizer: nil,
+                     use_bias: true)
+        super
+      end
+
+      def build(input_shape)
+        super
+        num_prev_nodes = @input_shape[1]
+        @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 3)
+        @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
+        @bias.data = Xumo::SFloat.new(@num_nodes * 3) if @bias
+        init_weight_and_bias
+        @time_length.times do |t|
+          @layers << GRU_Dense.new(@weight, @recurrent_weight, @bias)
+        end
+      end
+    end
+
+  end
+end
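For orientation, a construction sketch that mirrors the keyword interface shown above (illustrative only, not part of the diff; it assumes require "dnn" loads the Connection superclass, the initializers, and the regularizers this file refers to):

require "dnn"

lstm = DNN::Layers::LSTM.new(64,
                             stateful: true,
                             return_sequences: false,  # per output_shape above, the layer then emits [64] instead of [time_length, 64]
                             recurrent_weight_initializer: DNN::Initializers::RandomNormal.new,
                             recurrent_weight_regularizer: DNN::Regularizers::L2.new(0.01))

lstm.reset_state    # clears any retained hidden and cell state between sequences
p lstm.num_nodes    # => 64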