ruby-dnn 0.8.0 → 0.8.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/dnn/core/initializers.rb +10 -8
- data/lib/dnn/core/layers.rb +15 -9
- data/lib/dnn/core/learning_param.rb +0 -5
- data/lib/dnn/core/optimizers.rb +3 -1
- data/lib/dnn/core/rnn_layers.rb +10 -10
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 447ad62d50c89eb90c20e5d8cbec0bc849696a197fe21183c9f101b85b6da926
|
4
|
+
data.tar.gz: bc5a5d1f5b96991a38045cf4696b157692c5a82fc95480fe853ba2a691435e72
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 6aff825a4f3e20053cfdc3459ad7186809664604c5fb9df0e8261f5197dc40e45652487f9d0ef7d01a39757506e086d75e2bb835a5e41fef7796321932111f04
|
7
|
+
data.tar.gz: 66bd57b376429617cf37012a2816cafb60e428d90a756b97460f9daa6d01120533b930c403fccab6ffce2a732be1684b6f090af30141b54dc2899144bb98fca1
|
@@ -3,7 +3,9 @@ module DNN
|
|
3
3
|
|
4
4
|
class Initializer
|
5
5
|
# Classes that inherit from this class must implement this method.
|
6
|
-
|
6
|
+
def init_param(layer, param)
|
7
|
+
raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'init_params'")
|
8
|
+
end
|
7
9
|
|
8
10
|
def to_hash(merge_hash = nil)
|
9
11
|
hash = {class: self.class.name}
|
@@ -14,7 +16,7 @@ module DNN
|
|
14
16
|
|
15
17
|
|
16
18
|
class Zeros < Initializer
|
17
|
-
def init_param(param)
|
19
|
+
def init_param(layer, param)
|
18
20
|
param.data = param.data.fill(0)
|
19
21
|
end
|
20
22
|
end
|
@@ -33,7 +35,7 @@ module DNN
|
|
33
35
|
@std = std
|
34
36
|
end
|
35
37
|
|
36
|
-
def init_param(param)
|
38
|
+
def init_param(layer, param)
|
37
39
|
param.data = param.data.rand_norm(@mean, @std)
|
38
40
|
end
|
39
41
|
|
@@ -56,7 +58,7 @@ module DNN
|
|
56
58
|
@max = max
|
57
59
|
end
|
58
60
|
|
59
|
-
def init_param(param)
|
61
|
+
def init_param(layer, param)
|
60
62
|
param.data = param.data.rand(@min, @max)
|
61
63
|
end
|
62
64
|
|
@@ -67,16 +69,16 @@ module DNN
|
|
67
69
|
|
68
70
|
|
69
71
|
class Xavier < Initializer
|
70
|
-
def init_param(param)
|
71
|
-
num_prev_nodes =
|
72
|
+
def init_param(layer, param)
|
73
|
+
num_prev_nodes = layer.prev_layer.shape.reduce(:*)
|
72
74
|
param.data = param.data.rand_norm / Math.sqrt(num_prev_nodes)
|
73
75
|
end
|
74
76
|
end
|
75
77
|
|
76
78
|
|
77
79
|
class He < Initializer
|
78
|
-
def init_param(param)
|
79
|
-
num_prev_nodes =
|
80
|
+
def init_param(layer, param)
|
81
|
+
num_prev_nodes = layer.prev_layer.shape.reduce(:*)
|
80
82
|
param.data = param.data.rand_norm / Math.sqrt(num_prev_nodes) * Math.sqrt(2)
|
81
83
|
end
|
82
84
|
end
|
data/lib/dnn/core/layers.rb
CHANGED
@@ -20,11 +20,15 @@ module DNN
|
|
20
20
|
|
21
21
|
# Forward propagation.
|
22
22
|
# Classes that inherit from this class must implement this method.
|
23
|
-
|
23
|
+
def forward
|
24
|
+
raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'forward'")
|
25
|
+
end
|
24
26
|
|
25
27
|
# Backward propagation.
|
26
28
|
# Classes that inherit from this class must implement this method.
|
27
|
-
|
29
|
+
def backward
|
30
|
+
raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'update'")
|
31
|
+
end
|
28
32
|
|
29
33
|
# Get the shape of the layer.
|
30
34
|
def shape
|
@@ -73,7 +77,9 @@ module DNN
|
|
73
77
|
|
74
78
|
# Initialize of the parameters.
|
75
79
|
# Classes that inherit from this class must implement this method.
|
76
|
-
def init_params
|
80
|
+
def init_params
|
81
|
+
raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'init_params'")
|
82
|
+
end
|
77
83
|
end
|
78
84
|
|
79
85
|
|
@@ -119,8 +125,8 @@ module DNN
|
|
119
125
|
@bias_initializer = (bias_initializer || Zeros.new)
|
120
126
|
@l1_lambda = l1_lambda
|
121
127
|
@l2_lambda = l2_lambda
|
122
|
-
@params[:weight] = @weight = LearningParam.new
|
123
|
-
@params[:bias] = @bias = LearningParam.new
|
128
|
+
@params[:weight] = @weight = LearningParam.new
|
129
|
+
@params[:bias] = @bias = LearningParam.new
|
124
130
|
end
|
125
131
|
|
126
132
|
def lasso
|
@@ -159,8 +165,8 @@ module DNN
|
|
159
165
|
private
|
160
166
|
|
161
167
|
def init_params
|
162
|
-
@weight_initializer.init_param(@weight)
|
163
|
-
@bias_initializer.init_param(@bias)
|
168
|
+
@weight_initializer.init_param(self, @weight)
|
169
|
+
@bias_initializer.init_param(self, @bias)
|
164
170
|
end
|
165
171
|
end
|
166
172
|
|
@@ -324,8 +330,8 @@ module DNN
|
|
324
330
|
def initialize(momentum: 0.9)
|
325
331
|
super()
|
326
332
|
@momentum = momentum
|
327
|
-
@params[:gamma] = @gamma = LearningParam.new
|
328
|
-
@params[:beta] = @beta = LearningParam.new
|
333
|
+
@params[:gamma] = @gamma = LearningParam.new
|
334
|
+
@params[:beta] = @beta = LearningParam.new
|
329
335
|
@params[:running_mean] = nil
|
330
336
|
@params[:running_var] = nil
|
331
337
|
end
|
data/lib/dnn/core/optimizers.rb
CHANGED
@@ -11,7 +11,9 @@ module DNN
|
|
11
11
|
|
12
12
|
# Update params.
|
13
13
|
# Classes that inherit from this class must implement this method.
|
14
|
-
|
14
|
+
def update(params)
|
15
|
+
raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'update'")
|
16
|
+
end
|
15
17
|
|
16
18
|
def to_hash(merge_hash = nil)
|
17
19
|
hash = {class: self.class.name, learning_rate: @learning_rate}
|
data/lib/dnn/core/rnn_layers.rb
CHANGED
@@ -25,7 +25,7 @@ module DNN
|
|
25
25
|
@return_sequences = return_sequences
|
26
26
|
@layers = []
|
27
27
|
@params[:h] = nil
|
28
|
-
@params[:weight2] = @weight2 = LearningParam.new
|
28
|
+
@params[:weight2] = @weight2 = LearningParam.new
|
29
29
|
end
|
30
30
|
|
31
31
|
def forward(xs)
|
@@ -199,9 +199,9 @@ module DNN
|
|
199
199
|
@weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes)
|
200
200
|
@weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
|
201
201
|
@bias.data = Xumo::SFloat.new(@num_nodes)
|
202
|
-
@weight_initializer.init_param(@weight)
|
203
|
-
@weight_initializer.init_param(@weight2)
|
204
|
-
@bias_initializer.init_param(@bias)
|
202
|
+
@weight_initializer.init_param(self, @weight)
|
203
|
+
@weight_initializer.init_param(self, @weight2)
|
204
|
+
@bias_initializer.init_param(self, @bias)
|
205
205
|
@time_length.times do |t|
|
206
206
|
@layers << SimpleRNN_Dense.new(self)
|
207
207
|
end
|
@@ -343,9 +343,9 @@ module DNN
|
|
343
343
|
@weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 4)
|
344
344
|
@weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
|
345
345
|
@bias.data = Xumo::SFloat.new(@num_nodes * 4)
|
346
|
-
@weight_initializer.init_param(@weight)
|
347
|
-
@weight_initializer.init_param(@weight2)
|
348
|
-
@bias_initializer.init_param(@bias)
|
346
|
+
@weight_initializer.init_param(self, @weight)
|
347
|
+
@weight_initializer.init_param(self, @weight2)
|
348
|
+
@bias_initializer.init_param(self, @bias)
|
349
349
|
@time_length.times do |t|
|
350
350
|
@layers << LSTM_Dense.new(self)
|
351
351
|
end
|
@@ -444,9 +444,9 @@ module DNN
|
|
444
444
|
@weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 3)
|
445
445
|
@weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
|
446
446
|
@bias.data = Xumo::SFloat.new(@num_nodes * 3)
|
447
|
-
@weight_initializer.init_param(@weight)
|
448
|
-
@weight_initializer.init_param(@weight2)
|
449
|
-
@bias_initializer.init_param(@bias)
|
447
|
+
@weight_initializer.init_param(self, @weight)
|
448
|
+
@weight_initializer.init_param(self, @weight2)
|
449
|
+
@bias_initializer.init_param(self, @bias)
|
450
450
|
@time_length.times do |t|
|
451
451
|
@layers << GRU_Dense.new(self)
|
452
452
|
end
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.8.
|
4
|
+
version: 0.8.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-01-
|
11
|
+
date: 2019-01-27 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|