ruby-dnn 0.5.7 → 0.5.8

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 0a3a2ea891bd75d6d2d3b97bc8e850c220c4118d87a1e7fcd47e1f9dd6943a08
4
- data.tar.gz: 002d15961430e42b06b1b180718170b4e7e8f10b99b6569c312a550dc486a0e9
3
+ metadata.gz: '040681c989e47e7c183f46ff921285db8c5fd541112f1922b07262c27383d96a'
4
+ data.tar.gz: 84a1130bed58297aac0414c3eb5842aec3635e4f8aeae826fd1f7d7fbdecdc97
5
5
  SHA512:
6
- metadata.gz: ed91236978caff9d0def15edae2d743d03bb5e11b2546a8fec4f5d7cd3653af4b3f98d8fdf5cc2d93b10273dbebe84cca0f3b3fc7347afb846e489646dae51c2
7
- data.tar.gz: e44ec8438481761ca2ba28ca6b6839abe4c97025a1985eb1be717bb85b9c44d11e71355a68370db1305412e67a3bcf5094c00fa05303f2ba2f33d5df5c324d87
6
+ metadata.gz: 39c4c5b6f2557ae9bb39d67d248c33488f20aba8ec665221d744b9f1ee1f10f310e88e38689a53202536d58cc19ccf1b5cb2579bd50589888d6fd0318b451628
7
+ data.tar.gz: 7d2f2311a1123b5bc34fac29b0f7ee73b5638121ef1771a385be9a129a1fea49a3b0385ccb98fa2f7f0985e26cc681b167a20ff31b5926e15c16cc8c640b3ce9
@@ -1,6 +1,32 @@
1
1
  module DNN
2
2
  module Layers
3
3
 
4
+ class SimpleRNN_Dense
5
+ def initialize(params, grads, activation)
6
+ @params = params
7
+ @grads = grads
8
+ @activation = activation
9
+ end
10
+
11
+ def forward(x, h)
12
+ @x = x
13
+ @h = h
14
+ h2 = x.dot(@params[:weight]) + h.dot(@params[:weight2]) + @params[:bias]
15
+ @activation.forward(h2)
16
+ end
17
+
18
+ def backward(dh2)
19
+ dh2 = @activation.backward(dh2)
20
+ @grads[:weight] += @x.transpose.dot(dh2)
21
+ @grads[:weight2] += @h.transpose.dot(dh2)
22
+ @grads[:bias] += dh2.sum(0)
23
+ dx = dh2.dot(@params[:weight].transpose)
24
+ dh = dh2.dot(@params[:weight2].transpose)
25
+ [dx, dh]
26
+ end
27
+ end
28
+
29
+
4
30
  class SimpleRNN < HasParamLayer
5
31
  include Initializers
6
32
  include Activations
@@ -31,37 +57,34 @@ module DNN
31
57
  @weight_initializer = (weight_initializer || RandomNormal.new)
32
58
  @bias_initializer = (bias_initializer || Zeros.new)
33
59
  @weight_decay = weight_decay
60
+ @layers = []
34
61
  @h = nil
35
62
  end
36
63
 
37
64
  def forward(xs)
38
- @xs = xs
39
- @hs = SFloat.zeros(xs.shape[0], *shape)
65
+ @xs_shape = xs.shape
66
+ hs = SFloat.zeros(xs.shape[0], *shape)
40
67
  h = (@stateful && @h) ? @h : SFloat.zeros(xs.shape[0], @num_nodes)
41
68
  xs.shape[1].times do |t|
42
69
  x = xs[true, t, false]
43
- h = x.dot(@params[:weight]) + h.dot(@params[:weight2]) + @params[:bias]
44
- h = @activation.forward(h)
45
- @hs[true, t, false] = h
70
+ h = @layers[t].forward(x, h)
71
+ hs[true, t, false] = h
46
72
  end
47
73
  @h = h
48
- @hs
74
+ hs
49
75
  end
50
76
 
51
- def backward(douts)
77
+ def backward(dh2s)
52
78
  @grads[:weight] = SFloat.zeros(*@params[:weight].shape)
53
79
  @grads[:weight2] = SFloat.zeros(*@params[:weight2].shape)
54
- dxs = SFloat.zeros(@xs.shape)
55
- (0...douts.shape[1]).to_a.reverse.each do |t|
56
- dout = douts[true, t, false]
57
- x = @xs[true, t, false]
58
- h = @hs[true, t, false]
59
- dout = @activation.backward(dout)
60
- @grads[:weight] += x.transpose.dot(dout)
61
- @grads[:weight2] += h.transpose.dot(dout)
62
- dxs[true, t, false] = dout.dot(@params[:weight].transpose)
80
+ @grads[:bias] = SFloat.zeros(*@params[:bias].shape)
81
+ dxs = SFloat.zeros(@xs_shape)
82
+ dh = 0
83
+ (0...dh2s.shape[1]).to_a.reverse.each do |t|
84
+ dh2 = dh2s[true, t, false]
85
+ dx, dh = @layers[t].backward(dh2 + dh)
86
+ dxs[true, t, false] = dx
63
87
  end
64
- @grads[:bias] = douts.sum(0).sum(0)
65
88
  dxs
66
89
  end
67
90
 
@@ -97,6 +120,9 @@ module DNN
97
120
  @weight_initializer.init_param(self, :weight)
98
121
  @weight_initializer.init_param(self, :weight2)
99
122
  @bias_initializer.init_param(self, :bias)
123
+ @time_length.times do |t|
124
+ @layers << SimpleRNN_Dense.new(@params, @grads, @activation.clone)
125
+ end
100
126
  end
101
127
  end
102
128
 
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.5.7"
2
+ VERSION = "0.5.8"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.5.7
4
+ version: 0.5.8
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-08-08 00:00:00.000000000 Z
11
+ date: 2018-08-11 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray