ruby-dnn 0.5.7 → 0.5.8
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/dnn/core/rnn_layers.rb +43 -17
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: '040681c989e47e7c183f46ff921285db8c5fd541112f1922b07262c27383d96a'
|
4
|
+
data.tar.gz: 84a1130bed58297aac0414c3eb5842aec3635e4f8aeae826fd1f7d7fbdecdc97
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 39c4c5b6f2557ae9bb39d67d248c33488f20aba8ec665221d744b9f1ee1f10f310e88e38689a53202536d58cc19ccf1b5cb2579bd50589888d6fd0318b451628
|
7
|
+
data.tar.gz: 7d2f2311a1123b5bc34fac29b0f7ee73b5638121ef1771a385be9a129a1fea49a3b0385ccb98fa2f7f0985e26cc681b167a20ff31b5926e15c16cc8c640b3ce9
|
data/lib/dnn/core/rnn_layers.rb
CHANGED
@@ -1,6 +1,32 @@
|
|
1
1
|
module DNN
|
2
2
|
module Layers
|
3
3
|
|
4
|
+
# A single unrolled time step of SimpleRNN.
#
# The step computes an affine map over the current input x and the previous
# hidden state h, then runs it through the activation. SimpleRNN#init_params
# builds one instance per time step and hands every instance the same params
# and grads hashes. Every step therefore reads the same weights, and each step
# adds its gradient into the same accumulator entries.
class SimpleRNN_Dense
  # @param params [Hash] :weight (multiplied with x), :weight2 (multiplied with h), :bias
  # @param grads [Hash] gradient accumulators, keyed like +params+
  # @param activation [#forward, #backward] activation object used only by this step
  def initialize(params, grads, activation)
    @params, @grads, @activation = params, grads, activation
  end

  # Computes activation(x·weight + h·weight2 + bias).
  # x and h are cached because #backward needs them.
  def forward(x, h)
    @x, @h = x, h
    affine = x.dot(@params[:weight]) + h.dot(@params[:weight2]) + @params[:bias]
    @activation.forward(affine)
  end

  # Backpropagates dh2, the gradient w.r.t. this step's output.
  # Adds this step's weight/bias gradients into the shared grads hash.
  #
  # @return [Array] [dx, dh]: the gradients w.r.t. the cached x and h
  def backward(dh2)
    delta = @activation.backward(dh2)
    accumulate_grads(delta)
    [delta.dot(@params[:weight].transpose), delta.dot(@params[:weight2].transpose)]
  end

  private

  # Adds this step's parameter gradients into the shared accumulators.
  def accumulate_grads(delta)
    @grads[:weight] += @x.transpose.dot(delta)
    @grads[:weight2] += @h.transpose.dot(delta)
    @grads[:bias] += delta.sum(0)
  end
end
|
28
|
+
|
29
|
+
|
4
30
|
class SimpleRNN < HasParamLayer
|
5
31
|
include Initializers
|
6
32
|
include Activations
|
@@ -31,37 +57,34 @@ module DNN
|
|
31
57
|
@weight_initializer = (weight_initializer || RandomNormal.new)
|
32
58
|
@bias_initializer = (bias_initializer || Zeros.new)
|
33
59
|
@weight_decay = weight_decay
|
60
|
+
@layers = []
|
34
61
|
@h = nil
|
35
62
|
end
|
36
63
|
|
37
64
|
# Runs the unrolled RNN over every time step of +xs+.
#
# Axis 1 of xs is the time axis: step t feeds xs[true, t, false] into
# @layers[t]. Axis 0 is presumably the batch axis (TODO confirm against
# callers). When @stateful is set and a previous hidden state exists, the
# previous state is reused as the initial state. Otherwise the initial state
# is zeros.
#
# @param xs input tensor; its shape is remembered for #backward
# @return hidden state for every step, allocated as zeros(xs.shape[0], *shape)
def forward(xs)
  @xs_shape = xs.shape
  batch_size = xs.shape[0]
  steps = xs.shape[1]
  hs = SFloat.zeros(batch_size, *shape)
  state = @stateful && @h ? @h : SFloat.zeros(batch_size, @num_nodes)
  steps.times do |t|
    state = @layers[t].forward(xs[true, t, false], state)
    hs[true, t, false] = state
  end
  # Remember the final hidden state for the next stateful call.
  @h = state
  hs
end
|
50
76
|
|
51
|
-
# Backpropagation through time over the unrolled steps, newest step first.
#
# The shared gradient accumulators are reset to zeros first. Each per-step
# layer (@layers[t]) then adds its own parameter gradients to them. The
# hidden-state gradient it returns flows into step t-1.
#
# @param dh2s gradient w.r.t. the hidden states returned by #forward. It is
#   indexed dh2s[true, t, false], so axis 1 is time.
# @return gradient w.r.t. the input, with the shape #forward received
def backward(dh2s)
  @grads[:weight] = SFloat.zeros(*@params[:weight].shape)
  @grads[:weight2] = SFloat.zeros(*@params[:weight2].shape)
  @grads[:bias] = SFloat.zeros(*@params[:bias].shape)
  dxs = SFloat.zeros(@xs_shape)
  # No later step feeds gradient into the last step, so start from 0.
  dh = 0
  # Range#reverse_each walks t = n-1..0 without building and reversing an Array.
  (0...dh2s.shape[1]).reverse_each do |t|
    dx, dh = @layers[t].backward(dh2s[true, t, false] + dh)
    dxs[true, t, false] = dx
  end
  dxs
end
|
67
90
|
|
@@ -97,6 +120,9 @@ module DNN
|
|
97
120
|
@weight_initializer.init_param(self, :weight)
|
98
121
|
@weight_initializer.init_param(self, :weight2)
|
99
122
|
@bias_initializer.init_param(self, :bias)
|
123
|
+
@time_length.times do |t|
|
124
|
+
@layers << SimpleRNN_Dense.new(@params, @grads, @activation.clone)
|
125
|
+
end
|
100
126
|
end
|
101
127
|
end
|
102
128
|
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.5.7
|
4
|
+
version: 0.5.8
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-08-
|
11
|
+
date: 2018-08-11 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|