ruby-dnn 0.5.12 → 0.6.0
- checksums.yaml +4 -4
- data/lib/dnn/core/activations.rb +5 -6
- data/lib/dnn/core/initializers.rb +1 -1
- data/lib/dnn/core/layers.rb +1 -1
- data/lib/dnn/core/model.rb +13 -22
- data/lib/dnn/core/optimizers.rb +1 -1
- data/lib/dnn/core/rnn_layers.rb +6 -2
- data/lib/dnn/core/util.rb +1 -1
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: '08a74d28175b0e881a1cb7b93d9d973ac601e60498f1422890214d6888f7671a'
+  data.tar.gz: ebde049f70c4e43b90daae6c3eea3bafbf2bd04d9a8556726da46e6cb75ed313
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 46a59448402d63fc8e594259d8bb084e029595acd58a73c7f9558bfe949029f12a406cf9b996b41822a093a3c130e9b7a1249b3b9db61a4ae11611df5ff91264
+  data.tar.gz: 31fd9ff92eb43999f30162650677ac470bec417845ca2cc613cf282b4840e4b5f3d698724d05392ac631c19b2a705ff0034c303684a88998996be8214231779f
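These digests cover the gem's two inner archives. A quick local check, assuming metadata.gz and data.tar.gz have already been extracted from the downloaded .gem tarball:

    require "digest"

    # Compare against the SHA256 values recorded above.
    puts Digest::SHA256.file("metadata.gz").hexdigest   # expect 08a74d28...
    puts Digest::SHA256.file("data.tar.gz").hexdigest   # expect ebde049f...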
data/lib/dnn/core/activations.rb
CHANGED
@@ -1,8 +1,7 @@
 module DNN
   module Activations
-    Layer = Layers::Layer
 
-    class Sigmoid < Layer
+    class Sigmoid < Layers::Layer
       def forward(x)
         @out = 1 / (1 + Xumo::NMath.exp(-x))
       end
@@ -13,7 +12,7 @@ module DNN
     end
 
 
-    class Tanh < Layer
+    class Tanh < Layers::Layer
       def forward(x)
         @out = Xumo::NMath.tanh(x)
       end
@@ -24,7 +23,7 @@ module DNN
     end
 
 
-    class ReLU < Layer
+    class ReLU < Layers::Layer
       def forward(x)
         @x = x.clone
         x[x < 0] = 0
@@ -39,7 +38,7 @@ module DNN
     end
 
 
-    class LeakyReLU < Layer
+    class LeakyReLU < Layers::Layer
       attr_reader :alpha
 
       def initialize(alpha = 0.3)
@@ -64,7 +63,7 @@ module DNN
       end
 
       def to_hash
-        {
+        {class: self.class.name, alpha: alpha}
       end
     end
 
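With the Layer alias removed, activation classes inherit from Layers::Layer directly, so any code written against the old DNN::Activations::Layer constant must be updated. A minimal sketch of a custom activation in the 0.6.0 style (the Swish class and its formula are illustrative, not part of the gem):

    module DNN
      module Activations
        # Hypothetical activation: x * sigmoid(x), following the same
        # forward/backward conventions as the built-in classes above.
        class Swish < Layers::Layer
          def forward(x)
            @x = x
            @sig = 1 / (1 + Xumo::NMath.exp(-x))
            @out = x * @sig
          end

          def backward(dout)
            # d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
            dout * (@sig + @x * @sig * (1 - @sig))
          end
        end
      end
    end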
data/lib/dnn/core/layers.rb
CHANGED
data/lib/dnn/core/model.rb
CHANGED
@@ -22,7 +22,6 @@ module DNN
     def initialize
       @layers = []
       @optimizer = nil
-      @batch_size = nil
       @training = false
       @compiled = false
     end
@@ -103,13 +102,12 @@ module DNN
       unless compiled?
         raise DNN_Error.new("The model is not compiled.")
       end
-      @batch_size = batch_size
       num_train_data = x.shape[0]
       (1..epochs).each do |epoch|
         puts "【 epoch #{epoch}/#{epochs} 】" if verbose
-        (num_train_data.to_f /
-        x_batch, y_batch = Util.get_minibatch(x, y,
-        loss = train_on_batch(x_batch, y_batch,
+        (num_train_data.to_f / batch_size).ceil.times do |index|
+          x_batch, y_batch = Util.get_minibatch(x, y, batch_size)
+          loss = train_on_batch(x_batch, y_batch, &batch_proc)
           if loss.nan?
             puts "\nloss is nan" if verbose
             return
@@ -130,7 +128,7 @@ module DNN
         print log if verbose
       end
       if verbose && test
-        acc = accurate(test[0], test[1], batch_size
+        acc = accurate(test[0], test[1], batch_size, &batch_proc)
         print " accurate: #{acc}"
       end
       puts "" if verbose
@@ -138,8 +136,7 @@ module DNN
       end
     end
 
-    def train_on_batch(x, y,
-      @batch_size = batch_size
+    def train_on_batch(x, y, &batch_proc)
       x, y = batch_proc.call(x, y) if batch_proc
       forward(x, true)
       loss = @layers[-1].loss(y)
@@ -148,27 +145,21 @@ module DNN
       loss
     end
 
-    def accurate(x, y, batch_size =
-
-      if @batch_size
-        batch_size = @batch_size >= x.shape[0] ? @batch_size : x.shape[0]
-      else
-        batch_size = 1
-      end
-    end
+    def accurate(x, y, batch_size = 1, &batch_proc)
+      batch_size = batch_size >= x.shape[0] ? batch_size : x.shape[0]
       correct = 0
-      (x.shape[0].to_f /
-      x_batch = Xumo::SFloat.zeros(
-      y_batch = Xumo::SFloat.zeros(
-
-      k = i *
+      (x.shape[0].to_f / batch_size).ceil.times do |i|
+        x_batch = Xumo::SFloat.zeros(batch_size, *x.shape[1..-1])
+        y_batch = Xumo::SFloat.zeros(batch_size, *y.shape[1..-1])
+        batch_size.times do |j|
+          k = i * batch_size + j
           break if k >= x.shape[0]
           x_batch[j, false] = x[k, false]
           y_batch[j, false] = y[k, false]
         end
         x_batch, y_batch = batch_proc.call(x_batch, y_batch) if batch_proc
         out = forward(x_batch, false)
-
+        batch_size.times do |j|
           if @layers[-1].shape == [1]
             correct += 1 if out[j, 0].round == y_batch[j, 0].round
           else
data/lib/dnn/core/optimizers.rb
CHANGED
data/lib/dnn/core/rnn_layers.rb
CHANGED
@@ -6,6 +6,7 @@ module DNN
       include Initializers
       include Activations
 
+      attr_accessor :h
       attr_reader :num_nodes
       attr_reader :stateful
       attr_reader :weight_decay
@@ -29,7 +30,7 @@ module DNN
 
       def to_hash(merge_hash = nil)
         hash = {
-
+          class: self.class.name,
           num_nodes: @num_nodes,
           stateful: @stateful,
           return_sequences: @return_sequences,
@@ -48,7 +49,7 @@ module DNN
 
       def ridge
         if @weight_decay > 0
-          0.5 * (@weight_decay * (@params[:weight]**2).sum +
+          0.5 * (@weight_decay * ((@params[:weight]**2).sum + (@params[:weight2]**2).sum))
         else
           0
         end
@@ -219,6 +220,8 @@ module DNN
 
 
     class LSTM < RNN
+      attr_accessor :cell
+
       def self.load_hash(hash)
         self.new(hash[:num_nodes],
                  stateful: hash[:stateful],
@@ -235,6 +238,7 @@ module DNN
                      bias_initializer: nil,
                      weight_decay: 0)
         super
+        @cell = nil
       end
 
       def forward(xs)
data/lib/dnn/core/util.rb
CHANGED
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: ruby-dnn
 version: !ruby/object:Gem::Version
-  version: 0.
+  version: 0.6.0
 platform: ruby
 authors:
 - unagiootoro
 autorequire:
 bindir: exe
 cert_chain: []
-date: 2018-08-
+date: 2018-08-16 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: numo-narray