ruby-dnn 0.5.7 → 0.5.8

This diff shows the differences between two publicly available versions of a package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 0a3a2ea891bd75d6d2d3b97bc8e850c220c4118d87a1e7fcd47e1f9dd6943a08
4
- data.tar.gz: 002d15961430e42b06b1b180718170b4e7e8f10b99b6569c312a550dc486a0e9
3
+ metadata.gz: '040681c989e47e7c183f46ff921285db8c5fd541112f1922b07262c27383d96a'
4
+ data.tar.gz: 84a1130bed58297aac0414c3eb5842aec3635e4f8aeae826fd1f7d7fbdecdc97
5
5
  SHA512:
6
- metadata.gz: ed91236978caff9d0def15edae2d743d03bb5e11b2546a8fec4f5d7cd3653af4b3f98d8fdf5cc2d93b10273dbebe84cca0f3b3fc7347afb846e489646dae51c2
7
- data.tar.gz: e44ec8438481761ca2ba28ca6b6839abe4c97025a1985eb1be717bb85b9c44d11e71355a68370db1305412e67a3bcf5094c00fa05303f2ba2f33d5df5c324d87
6
+ metadata.gz: 39c4c5b6f2557ae9bb39d67d248c33488f20aba8ec665221d744b9f1ee1f10f310e88e38689a53202536d58cc19ccf1b5cb2579bd50589888d6fd0318b451628
7
+ data.tar.gz: 7d2f2311a1123b5bc34fac29b0f7ee73b5638121ef1771a385be9a129a1fea49a3b0385ccb98fa2f7f0985e26cc681b167a20ff31b5926e15c16cc8c640b3ce9
@@ -1,6 +1,32 @@
1
1
  module DNN
2
2
  module Layers
3
3
 
4
+ class SimpleRNN_Dense
5
+ def initialize(params, grads, activation)
6
+ @params = params
7
+ @grads = grads
8
+ @activation = activation
9
+ end
10
+
11
+ def forward(x, h)
12
+ @x = x
13
+ @h = h
14
+ h2 = x.dot(@params[:weight]) + h.dot(@params[:weight2]) + @params[:bias]
15
+ @activation.forward(h2)
16
+ end
17
+
18
+ def backward(dh2)
19
+ dh2 = @activation.backward(dh2)
20
+ @grads[:weight] += @x.transpose.dot(dh2)
21
+ @grads[:weight2] += @h.transpose.dot(dh2)
22
+ @grads[:bias] += dh2.sum(0)
23
+ dx = dh2.dot(@params[:weight].transpose)
24
+ dh = dh2.dot(@params[:weight2].transpose)
25
+ [dx, dh]
26
+ end
27
+ end
28
+
29
+
4
30
  class SimpleRNN < HasParamLayer
5
31
  include Initializers
6
32
  include Activations
@@ -31,37 +57,34 @@ module DNN
31
57
  @weight_initializer = (weight_initializer || RandomNormal.new)
32
58
  @bias_initializer = (bias_initializer || Zeros.new)
33
59
  @weight_decay = weight_decay
60
+ @layers = []
34
61
  @h = nil
35
62
  end
36
63
 
37
64
  def forward(xs)
38
- @xs = xs
39
- @hs = SFloat.zeros(xs.shape[0], *shape)
65
+ @xs_shape = xs.shape
66
+ hs = SFloat.zeros(xs.shape[0], *shape)
40
67
  h = (@stateful && @h) ? @h : SFloat.zeros(xs.shape[0], @num_nodes)
41
68
  xs.shape[1].times do |t|
42
69
  x = xs[true, t, false]
43
- h = x.dot(@params[:weight]) + h.dot(@params[:weight2]) + @params[:bias]
44
- h = @activation.forward(h)
45
- @hs[true, t, false] = h
70
+ h = @layers[t].forward(x, h)
71
+ hs[true, t, false] = h
46
72
  end
47
73
  @h = h
48
- @hs
74
+ hs
49
75
  end
50
76
 
51
- def backward(douts)
77
+ def backward(dh2s)
52
78
  @grads[:weight] = SFloat.zeros(*@params[:weight].shape)
53
79
  @grads[:weight2] = SFloat.zeros(*@params[:weight2].shape)
54
- dxs = SFloat.zeros(@xs.shape)
55
- (0...douts.shape[1]).to_a.reverse.each do |t|
56
- dout = douts[true, t, false]
57
- x = @xs[true, t, false]
58
- h = @hs[true, t, false]
59
- dout = @activation.backward(dout)
60
- @grads[:weight] += x.transpose.dot(dout)
61
- @grads[:weight2] += h.transpose.dot(dout)
62
- dxs[true, t, false] = dout.dot(@params[:weight].transpose)
80
+ @grads[:bias] = SFloat.zeros(*@params[:bias].shape)
81
+ dxs = SFloat.zeros(@xs_shape)
82
+ dh = 0
83
+ (0...dh2s.shape[1]).to_a.reverse.each do |t|
84
+ dh2 = dh2s[true, t, false]
85
+ dx, dh = @layers[t].backward(dh2 + dh)
86
+ dxs[true, t, false] = dx
63
87
  end
64
- @grads[:bias] = douts.sum(0).sum(0)
65
88
  dxs
66
89
  end
67
90
 
@@ -97,6 +120,9 @@ module DNN
97
120
  @weight_initializer.init_param(self, :weight)
98
121
  @weight_initializer.init_param(self, :weight2)
99
122
  @bias_initializer.init_param(self, :bias)
123
+ @time_length.times do |t|
124
+ @layers << SimpleRNN_Dense.new(@params, @grads, @activation.clone)
125
+ end
100
126
  end
101
127
  end
102
128
 
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.5.7"
2
+ VERSION = "0.5.8"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.5.7
4
+ version: 0.5.8
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-08-08 00:00:00.000000000 Z
11
+ date: 2018-08-11 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray