ruby-dnn 0.6.1 → 0.6.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/API-Reference.ja.md +13 -1
- data/lib/dnn/core/activations.rb +24 -0
- data/lib/dnn/core/rnn_layers.rb +128 -4
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 530a8532898c2b8f934c257ff0d986407763fb3df432df8f6d061d31147b97d8
|
4
|
+
data.tar.gz: 333fb2f4e0cfd6c6aa78d3bea8f9159433acec613d8017218d24688ab1384d24
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9979622493c8056c75ea3c9a5561ffa6cb2a63147d7c230a43870125ae8a023d44ed9615b0fa48f169c29d19a3979d109570f46baf9122046be01080d10baf01
|
7
|
+
data.tar.gz: f9730237fcc631dab2109909508d73559dcbbcdad996578be7f1252776ab048cab0510b17e58a6662358dde6e8bd81356348c2576347b73aa8aaa0c0e4cd689b
|
data/API-Reference.ja.md
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
ruby-dnnのAPIリファレンスです。このリファレンスでは、APIを利用するうえで必要となるクラスとメソッドしか記載していません。
|
3
3
|
そのため、プログラムの詳細が必要な場合は、ソースコードを参照してください。
|
4
4
|
|
5
|
-
最終更新バージョン:0.6.
|
5
|
+
最終更新バージョン:0.6.2
|
6
6
|
|
7
7
|
# module DNN
|
8
8
|
ruby-dnnの名前空間をなすモジュールです。
|
@@ -484,6 +484,10 @@ Numo::SFloat
|
|
484
484
|
nilを設定することで、中間層のセルステートをリセットすることができます。
|
485
485
|
|
486
486
|
|
487
|
+
# class GRU < RNN
|
488
|
+
GRUレイヤーを扱うクラスです。
|
489
|
+
|
490
|
+
|
487
491
|
# class Flatten
|
488
492
|
N次元のデータを平坦化します。
|
489
493
|
|
@@ -572,6 +576,14 @@ Numo::SFloat y
|
|
572
576
|
tanh関数のレイヤーです。
|
573
577
|
|
574
578
|
|
579
|
+
# class Softsign < Layer
|
580
|
+
softsign関数のレイヤーです。
|
581
|
+
|
582
|
+
|
583
|
+
# class Softplus < Layer
|
584
|
+
softplus関数のレイヤーです。
|
585
|
+
|
586
|
+
|
575
587
|
# class ReLU < Layer
|
576
588
|
ランプ関数のレイヤーです。
|
577
589
|
|
data/lib/dnn/core/activations.rb
CHANGED
@@ -68,6 +68,30 @@ module DNN
|
|
68
68
|
end
|
69
69
|
|
70
70
|
|
71
|
+
# Softsign activation layer: y = x / (1 + |x|).
class Softsign < Layers::Layer
  # Applies softsign element-wise and remembers the input so that
  # #backward can evaluate the derivative at the same point.
  def forward(x)
    @x = x
    scale = 1 + x.abs
    x / scale
  end

  # Scales the upstream gradient +dout+ by the softsign derivative,
  # 1 / (1 + |x|)^2, evaluated at the input saved by #forward.
  def backward(dout)
    slope = 1 / (1 + @x.abs)**2
    dout * slope
  end
end
|
81
|
+
|
82
|
+
|
83
|
+
# Softplus activation layer: y = log(1 + exp(x)).
class Softplus < Layers::Layer
  # Applies softplus element-wise and remembers the input for #backward.
  #
  # The textbook form log(1 + exp(x)) overflows to Infinity once exp(x)
  # exceeds the float range, even though softplus(x) is then simply ~x.
  # The algebraically identical form
  #   max(x, 0) + log(1 + exp(-|x|))
  # only ever exponentiates a non-positive number, so it cannot overflow.
  def forward(x)
    @x = x
    magnitude = x.abs
    # (x + |x|) / 2 is max(x, 0) written with plain element-wise arithmetic.
    positive_part = (x + magnitude) / 2
    positive_part + Xumo::NMath.log(1 + Xumo::NMath.exp(-magnitude))
  end

  # Scales the upstream gradient +dout+ by the softplus derivative, which is
  # the logistic sigmoid 1 / (1 + exp(-x)) of the input saved by #forward.
  def backward(dout)
    # For very negative x, exp(-x) may overflow to Infinity; under IEEE
    # arithmetic 1 / Infinity is 0, which is the correct limiting gradient.
    dout * (1 / (1 + Xumo::NMath.exp(-@x)))
  end
end
|
93
|
+
|
94
|
+
|
71
95
|
class IdentityMSE < Layers::OutputLayer
|
72
96
|
def forward(x)
|
73
97
|
@out = x
|
data/lib/dnn/core/rnn_layers.rb
CHANGED
@@ -191,10 +191,10 @@ module DNN
|
|
191
191
|
@in = @in_sigmoid.forward(a[true, (num_nodes * 2)...(num_nodes * 3)])
|
192
192
|
@out = @out_sigmoid.forward(a[true, (num_nodes * 3)..-1])
|
193
193
|
|
194
|
-
|
195
|
-
@tanh_cell2 = @tanh.forward(
|
196
|
-
|
197
|
-
[
|
194
|
+
cell2 = @forget * cell + @g * @in
|
195
|
+
@tanh_cell2 = @tanh.forward(cell2)
|
196
|
+
h2 = @out * @tanh_cell2
|
197
|
+
[h2, cell2]
|
198
198
|
end
|
199
199
|
|
200
200
|
def backward(dh2, dcell2)
|
@@ -299,5 +299,129 @@ module DNN
|
|
299
299
|
end
|
300
300
|
end
|
301
301
|
|
302
|
+
|
303
|
+
# A single time step of a GRU. Every step object is handed the same
# +params+ and +grads+ hashes by the owning layer.
#
# Column layout of params[:weight], params[:weight2] and params[:bias],
# with n = number of hidden units:
#   [0, n)    update gate
#   [n, 2n)   reset gate
#   [2n, 3n)  candidate hidden state
class GRU_Dense
  def initialize(params, grads)
    @params = params
    @grads = grads
    @update_sigmoid = Sigmoid.new
    @reset_sigmoid = Sigmoid.new
    @tanh = Tanh.new
  end

  # Computes the next hidden state from input +x+ and previous state +h+.
  # The number of hidden units is taken from h.shape[1].
  # Intermediate values are kept in instance variables for #backward.
  def forward(x, h)
    @x = x
    @h = h
    units = h.shape[1]
    gate_cols = 0...(units * 2)
    cand_cols = (units * 2)..-1

    # Update and reset gates share one fused affine transform.
    @weight_a = @params[:weight][true, gate_cols]
    @weight2_a = @params[:weight2][true, gate_cols]
    gate_pre = x.dot(@weight_a) + h.dot(@weight2_a) + @params[:bias][gate_cols]
    @update = @update_sigmoid.forward(gate_pre[true, 0...units])
    @reset = @reset_sigmoid.forward(gate_pre[true, units..-1])

    # Candidate state: the reset gate scales h before the recurrent product.
    @weight_h = @params[:weight][true, cand_cols]
    @weight2_h = @params[:weight2][true, cand_cols]
    cand_pre = x.dot(@weight_h) + (h * @reset).dot(@weight2_h) + @params[:bias][cand_cols]
    @tanh_h = @tanh.forward(cand_pre)

    # Blend the previous state and the candidate using the update gate.
    (1 - @update) * h + @update * @tanh_h
  end

  # Back-propagates +dh2+ (gradient w.r.t. the state returned by #forward).
  # Adds this step's parameter gradients into @grads and returns
  # [dx, dh]: gradients w.r.t. the input and the previous hidden state.
  def backward(dh2)
    dcand = @tanh.backward(dh2 * @update)
    # Gradient w.r.t. the reset-scaled state (h * reset).
    dscaled_h = dcand.dot(@weight2_h.transpose)

    dreset = @reset_sigmoid.backward(dscaled_h * @h)
    dupdate = @update_sigmoid.backward(dh2 * @tanh_h - dh2 * @h)
    # Same column order as the fused gate transform in #forward.
    dgates = Xumo::SFloat.hstack([dupdate, dreset])

    dx = dcand.dot(@weight_h.transpose)
    dx += dgates.dot(@weight_a.transpose)

    # h feeds the output three ways: the direct blend, the candidate
    # (through the reset gate), and the two gates themselves.
    dh = dh2 * (1 - @update)
    dh += dscaled_h * @reset
    dh += dgates.dot(@weight2_a.transpose)

    x_t = @x.transpose
    h_t = @h.transpose
    @grads[:weight] += Xumo::SFloat.hstack([x_t.dot(dgates), x_t.dot(dcand)])
    @grads[:weight2] += Xumo::SFloat.hstack([h_t.dot(dgates), (@h * @reset).transpose.dot(dcand)])
    @grads[:bias] += Xumo::SFloat.hstack([dgates.sum(0), dcand.sum(0)])
    [dx, dh]
  end
end
|
356
|
+
|
357
|
+
|
358
|
+
# GRU recurrent layer. Runs one GRU_Dense step per time index and lets all
# steps share this layer's parameter and gradient hashes.
class GRU < RNN
  # Rebuilds a GRU layer from the values stored in +hash+.
  def self.load_hash(hash)
    options = {
      stateful: hash[:stateful],
      return_sequences: hash[:return_sequences],
      weight_initializer: Util.load_hash(hash[:weight_initializer]),
      bias_initializer: Util.load_hash(hash[:bias_initializer]),
      weight_decay: hash[:weight_decay]
    }
    new(hash[:num_nodes], **options)
  end

  # Declares the defaults; every argument is forwarded unchanged to RNN.
  def initialize(num_nodes,
                 stateful: false,
                 return_sequences: true,
                 weight_initializer: nil,
                 bias_initializer: nil,
                 weight_decay: 0)
    super
  end

  # Feeds +xs+ through the step layers along axis 1. Returns the state of
  # every step when @return_sequences is set, otherwise only the last one.
  def forward(xs)
    @xs_shape = xs.shape
    batch_size = xs.shape[0]
    outputs = Xumo::SFloat.zeros(batch_size, @time_length, @num_nodes)
    # A stateful layer resumes from the state left by the previous call.
    state = if @stateful && @h
              @h
            else
              Xumo::SFloat.zeros(batch_size, @num_nodes)
            end
    (0...xs.shape[1]).each do |step|
      state = @layers[step].forward(xs[true, step, false], state)
      outputs[true, step, false] = state
    end
    @h = state
    @return_sequences ? outputs : state
  end

  # Back-propagates through time. Resets the shared gradient buffers to
  # zero, then walks the steps from last to first. Returns the gradient
  # w.r.t. the input sequence.
  def backward(dh2s)
    @grads[:weight] = Xumo::SFloat.zeros(*@params[:weight].shape)
    @grads[:weight2] = Xumo::SFloat.zeros(*@params[:weight2].shape)
    @grads[:bias] = Xumo::SFloat.zeros(*@params[:bias].shape)
    unless @return_sequences
      # Only the final step produced output, so every earlier step
      # receives a zero upstream gradient.
      last_grad = dh2s
      dh2s = Xumo::SFloat.zeros(last_grad.shape[0], @time_length, last_grad.shape[1])
      dh2s[true, -1, false] = last_grad
    end
    dxs = Xumo::SFloat.zeros(@xs_shape)
    carried = 0
    (dh2s.shape[1] - 1).downto(0) do |step|
      step_grad = dh2s[true, step, false] + carried
      dx, carried = @layers[step].backward(step_grad)
      dxs[true, step, false] = dx
    end
    dxs
  end

  private

  # Allocates and initializes the fused parameters (three column groups of
  # @num_nodes each), then creates one GRU_Dense per time step.
  def init_params
    super()
    input_size = prev_layer.shape[1]
    fused_width = @num_nodes * 3
    @params[:weight] = Xumo::SFloat.new(input_size, fused_width)
    @params[:weight2] = Xumo::SFloat.new(@num_nodes, fused_width)
    @params[:bias] = Xumo::SFloat.new(fused_width)
    @weight_initializer.init_param(self, :weight)
    @weight_initializer.init_param(self, :weight2)
    @bias_initializer.init_param(self, :bias)
    @time_length.times { @layers << GRU_Dense.new(@params, @grads) }
  end
end
|
425
|
+
|
302
426
|
end
|
303
427
|
end
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.6.1
|
4
|
+
version: 0.6.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-08-
|
11
|
+
date: 2018-08-19 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|