ruby-dnn 0.8.0 → 0.8.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 308a84624d71cb5b27d4b72f1ea69880dbfebf226ac9a5b44cf5775cc7e22703
4
- data.tar.gz: baa0b39dcca002f79eb660129cf0042dfda7d4768e9b2067d7135ce877047770
3
+ metadata.gz: 447ad62d50c89eb90c20e5d8cbec0bc849696a197fe21183c9f101b85b6da926
4
+ data.tar.gz: bc5a5d1f5b96991a38045cf4696b157692c5a82fc95480fe853ba2a691435e72
5
5
  SHA512:
6
- metadata.gz: a98278bb5cfd211bcf1231a4f5b0f13fa9d7ee4684d0334a6d2052d77b405d1873a65bd11d1f2df2f1f29926de484d4d0a7cbd8ef8d88cb4735923904eeb91fd
7
- data.tar.gz: 06d52f9698bf600e0bb63c7cc703697d3bb1ec9827871e6dd595a7fdd17cf3e63980d8abe376c9418db9412309b5aed39187c8b988388fd2bbaf09afc4702aa7
6
+ metadata.gz: 6aff825a4f3e20053cfdc3459ad7186809664604c5fb9df0e8261f5197dc40e45652487f9d0ef7d01a39757506e086d75e2bb835a5e41fef7796321932111f04
7
+ data.tar.gz: 66bd57b376429617cf37012a2816cafb60e428d90a756b97460f9daa6d01120533b930c403fccab6ffce2a732be1684b6f090af30141b54dc2899144bb98fca1
@@ -3,7 +3,9 @@ module DNN
3
3
 
4
4
  class Initializer
5
5
  # Classes that inherit from this class must implement this method.
6
- # def init_param(param) end
6
+ def init_param(layer, param)
7
+ raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'init_params'")
8
+ end
7
9
 
8
10
  def to_hash(merge_hash = nil)
9
11
  hash = {class: self.class.name}
@@ -14,7 +16,7 @@ module DNN
14
16
 
15
17
 
16
18
  class Zeros < Initializer
17
- def init_param(param)
19
+ def init_param(layer, param)
18
20
  param.data = param.data.fill(0)
19
21
  end
20
22
  end
@@ -33,7 +35,7 @@ module DNN
33
35
  @std = std
34
36
  end
35
37
 
36
- def init_param(param)
38
+ def init_param(layer, param)
37
39
  param.data = param.data.rand_norm(@mean, @std)
38
40
  end
39
41
 
@@ -56,7 +58,7 @@ module DNN
56
58
  @max = max
57
59
  end
58
60
 
59
- def init_param(param)
61
+ def init_param(layer, param)
60
62
  param.data = param.data.rand(@min, @max)
61
63
  end
62
64
 
@@ -67,16 +69,16 @@ module DNN
67
69
 
68
70
 
69
71
  class Xavier < Initializer
70
- def init_param(param)
71
- num_prev_nodes = param.layer.prev_layer.shape.reduce(:*)
72
+ def init_param(layer, param)
73
+ num_prev_nodes = layer.prev_layer.shape.reduce(:*)
72
74
  param.data = param.data.rand_norm / Math.sqrt(num_prev_nodes)
73
75
  end
74
76
  end
75
77
 
76
78
 
77
79
  class He < Initializer
78
- def init_param(param)
79
- num_prev_nodes = param.layer.prev_layer.shape.reduce(:*)
80
+ def init_param(layer, param)
81
+ num_prev_nodes = layer.prev_layer.shape.reduce(:*)
80
82
  param.data = param.data.rand_norm / Math.sqrt(num_prev_nodes) * Math.sqrt(2)
81
83
  end
82
84
  end
@@ -20,11 +20,15 @@ module DNN
20
20
 
21
21
  # Forward propagation.
22
22
  # Classes that inherit from this class must implement this method.
23
- # def forward() end
23
+ def forward
24
+ raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'forward'")
25
+ end
24
26
 
25
27
  # Backward propagation.
26
28
  # Classes that inherit from this class must implement this method.
27
- # def backward() end
29
+ def backward
30
+ raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'update'")
31
+ end
28
32
 
29
33
  # Get the shape of the layer.
30
34
  def shape
@@ -73,7 +77,9 @@ module DNN
73
77
 
74
78
  # Initialize of the parameters.
75
79
  # Classes that inherit from this class must implement this method.
76
- def init_params() end
80
+ def init_params
81
+ raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'init_params'")
82
+ end
77
83
  end
78
84
 
79
85
 
@@ -119,8 +125,8 @@ module DNN
119
125
  @bias_initializer = (bias_initializer || Zeros.new)
120
126
  @l1_lambda = l1_lambda
121
127
  @l2_lambda = l2_lambda
122
- @params[:weight] = @weight = LearningParam.new(self)
123
- @params[:bias] = @bias = LearningParam.new(self)
128
+ @params[:weight] = @weight = LearningParam.new
129
+ @params[:bias] = @bias = LearningParam.new
124
130
  end
125
131
 
126
132
  def lasso
@@ -159,8 +165,8 @@ module DNN
159
165
  private
160
166
 
161
167
  def init_params
162
- @weight_initializer.init_param(@weight)
163
- @bias_initializer.init_param(@bias)
168
+ @weight_initializer.init_param(self, @weight)
169
+ @bias_initializer.init_param(self, @bias)
164
170
  end
165
171
  end
166
172
 
@@ -324,8 +330,8 @@ module DNN
324
330
  def initialize(momentum: 0.9)
325
331
  super()
326
332
  @momentum = momentum
327
- @params[:gamma] = @gamma = LearningParam.new(self)
328
- @params[:beta] = @beta = LearningParam.new(self)
333
+ @params[:gamma] = @gamma = LearningParam.new
334
+ @params[:beta] = @beta = LearningParam.new
329
335
  @params[:running_mean] = nil
330
336
  @params[:running_var] = nil
331
337
  end
@@ -1,9 +1,4 @@
1
1
  class DNN::LearningParam
2
2
  attr_accessor :data
3
3
  attr_accessor :grad
4
- attr_reader :layer
5
-
6
- def initialize(layer)
7
- @layer = layer
8
- end
9
4
  end
@@ -11,7 +11,9 @@ module DNN
11
11
 
12
12
  # Update params.
13
13
  # Classes that inherit from this class must implement this method.
14
- # def update(params) end
14
+ def update(params)
15
+ raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'update'")
16
+ end
15
17
 
16
18
  def to_hash(merge_hash = nil)
17
19
  hash = {class: self.class.name, learning_rate: @learning_rate}
@@ -25,7 +25,7 @@ module DNN
25
25
  @return_sequences = return_sequences
26
26
  @layers = []
27
27
  @params[:h] = nil
28
- @params[:weight2] = @weight2 = LearningParam.new(self)
28
+ @params[:weight2] = @weight2 = LearningParam.new
29
29
  end
30
30
 
31
31
  def forward(xs)
@@ -199,9 +199,9 @@ module DNN
199
199
  @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes)
200
200
  @weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
201
201
  @bias.data = Xumo::SFloat.new(@num_nodes)
202
- @weight_initializer.init_param(@weight)
203
- @weight_initializer.init_param(@weight2)
204
- @bias_initializer.init_param(@bias)
202
+ @weight_initializer.init_param(self, @weight)
203
+ @weight_initializer.init_param(self, @weight2)
204
+ @bias_initializer.init_param(self, @bias)
205
205
  @time_length.times do |t|
206
206
  @layers << SimpleRNN_Dense.new(self)
207
207
  end
@@ -343,9 +343,9 @@ module DNN
343
343
  @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 4)
344
344
  @weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
345
345
  @bias.data = Xumo::SFloat.new(@num_nodes * 4)
346
- @weight_initializer.init_param(@weight)
347
- @weight_initializer.init_param(@weight2)
348
- @bias_initializer.init_param(@bias)
346
+ @weight_initializer.init_param(self, @weight)
347
+ @weight_initializer.init_param(self, @weight2)
348
+ @bias_initializer.init_param(self, @bias)
349
349
  @time_length.times do |t|
350
350
  @layers << LSTM_Dense.new(self)
351
351
  end
@@ -444,9 +444,9 @@ module DNN
444
444
  @weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 3)
445
445
  @weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
446
446
  @bias.data = Xumo::SFloat.new(@num_nodes * 3)
447
- @weight_initializer.init_param(@weight)
448
- @weight_initializer.init_param(@weight2)
449
- @bias_initializer.init_param(@bias)
447
+ @weight_initializer.init_param(self, @weight)
448
+ @weight_initializer.init_param(self, @weight2)
449
+ @bias_initializer.init_param(self, @bias)
450
450
  @time_length.times do |t|
451
451
  @layers << GRU_Dense.new(self)
452
452
  end
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.8.0"
2
+ VERSION = "0.8.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.0
4
+ version: 0.8.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-01-06 00:00:00.000000000 Z
11
+ date: 2019-01-27 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray