ruby-dnn 0.14.3 → 0.15.0

@@ -53,6 +53,14 @@ module DNN
       def load_hash(hash)
         initialize
       end
+
+      def clean
+        hash = to_hash
+        instance_variables.each do |ivar|
+          instance_variable_set(ivar, nil)
+        end
+        load_hash(hash)
+      end
     end
 
     class MeanSquaredError < Loss
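A minimal usage sketch of the new Loss#clean (the loss class and call below are illustrative, not taken from this diff): clean snapshots the configuration with to_hash, nils out every instance variable, and then restores the configuration via load_hash, so transient state is dropped before marshaling.

    loss = DNN::Losses::MeanSquaredError.new
    loss.clean   # configuration survives via to_hash/load_hash; cached runtime state is cleared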
@@ -1,7 +1,97 @@
 module DNN
   module Models
+
+    class LayersList < Array
+      def self.from_hash_list(hash_list)
+        layers_list = new
+        hash_list.each do |hash|
+          obj_class = DNN.const_get(hash[:class])
+          obj = obj_class.allocate
+          if obj.is_a?(Chain)
+            obj = obj_class.new
+            obj.load_hash(hash)
+          else
+            obj = Layers::Layer.from_hash(hash)
+          end
+          layers_list << obj
+        end
+        layers_list
+      end
+
+      def to_hash_list
+        map { |layer| layer.to_hash }
+      end
+
+      # Get all the layers.
+      # @return [Array] All layers array.
+      def layers
+        layers_array = []
+        each do |layer|
+          if layer.is_a?(Layers::Layer)
+            layers_array << layer
+          elsif layer.is_a?(Chain) || layer.is_a?(LayersList)
+            layers_array.concat(layer.layers)
+          end
+        end
+        layers_array
+      end
+    end
+
+    class Chain
+      def call(x)
+        raise NotImplementedError, "Class '#{self.class.name}' must implement method 'call'"
+      end
+
+      # Get all the layers.
+      # @return [Array] All layers array.
+      def layers
+        layers_array = []
+        instance_variables.sort.each do |ivar|
+          obj = instance_variable_get(ivar)
+          if obj.is_a?(Layers::Layer)
+            layers_array << obj
+          elsif obj.is_a?(Chain) || obj.is_a?(LayersList)
+            layers_array.concat(obj.layers)
+          end
+        end
+        layers_array
+      end
+
+      def to_hash
+        layers_hash = { class: self.class.name }
+        instance_variables.sort.each do |ivar|
+          obj = instance_variable_get(ivar)
+          if obj.is_a?(Layers::Layer) || obj.is_a?(Chain)
+            layers_hash[ivar] = obj.to_hash
+          elsif obj.is_a?(LayersList)
+            layers_hash[ivar] = obj.to_hash_list
+          end
+        end
+        layers_hash
+      end
+
+      def load_hash(layers_hash)
+        instance_variables.sort.each do |ivar|
+          hash_or_array = layers_hash[ivar]
+          if hash_or_array.is_a?(Array)
+            instance_variable_set(ivar, LayersList.from_hash_list(hash_or_array))
+          elsif hash_or_array.is_a?(Hash)
+            obj_class = DNN.const_get(hash_or_array[:class])
+            obj = obj_class.allocate
+            if obj.is_a?(Chain)
+              obj = obj_class.new
+              obj.load_hash(hash_or_array)
+              instance_variable_set(ivar, obj)
+            else
+              instance_variable_set(ivar, Layers::Layer.from_hash(hash_or_array))
+            end
+          end
+        end
+      end
+    end
+
     # This class deals with the model of the network.
-    class Model
+    class Model < Chain
       attr_accessor :optimizer
       attr_accessor :loss_func
       attr_reader :last_log
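The Chain class above enables a define-by-run style: a model holds its layers in instance variables and wires them together in call. A minimal sketch under that assumption; the layer classes, sizes, and the MLP name are illustrative, not taken from this diff.

    class MLP < DNN::Models::Model
      def initialize
        super
        @l1 = DNN::Layers::Dense.new(64)
        @l2 = DNN::Layers::Dense.new(10)
      end

      def call(x)                         # required by Chain, which otherwise raises NotImplementedError
        x = DNN::Layers::InputLayer.new(784).(x)
        x = @l1.(x)
        x = DNN::Layers::ReLU.new.(x)
        @l2.(x)
      end
    end

Because Chain#layers, #to_hash, and #load_hash walk the instance variables, the layers assigned to ivars (or collected in a LayersList) are exactly what gets serialized and restored.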
@@ -10,7 +100,7 @@ module DNN
       # @param [String] file_name File name of marshal model to load.
       # @return [DNN::Models::Model] Return the loaded model.
       def self.load(file_name)
-        model = new
+        model = self.allocate
         loader = Loaders::MarshalLoader.new(model)
         loader.load(file_name)
         model
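Since self.allocate is used, load builds an instance of whatever class it is called on without running initialize, so a marshaled model is loaded through its own class. A usage sketch (class and file names are illustrative):

    model = MLP.load("trained_model.marshal")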
@@ -113,8 +203,12 @@ module DNN
         end
 
         if test
-          test_met = test(test[0], test[1], batch_size: batch_size)
-          print " " + metrics_to_str(test_met) if verbose
+          acc, loss = if test.is_a?(Array)
+                        evaluate(test[0], test[1], batch_size: batch_size)
+                      else
+                        evaluate_by_iterator(test, batch_size: batch_size)
+                      end
+          print " " + metrics_to_str({ accuracy: acc, test_loss: loss }) if verbose
         end
         puts "" if verbose
         call_callbacks(:after_epoch)
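With this change the test: option accepts either an [x_test, y_test] pair or an iterator. A sketch, assuming the surrounding method is Model#train with the same positional and keyword arguments as 0.14.x (variable names are illustrative):

    model.train(x_train, y_train, 10, batch_size: 128, test: [x_test, y_test])
    model.train(x_train, y_train, 10, batch_size: 128,
                test: DNN::Iterator.new(x_test, y_test, random: false))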
@@ -138,16 +232,6 @@ module DNN
         { loss: loss_value }
       end
 
-      # Implement the test process to be performed.
-      # @param [Numo::SFloat] x Input training data.
-      # @param [Numo::SFloat] y Output training data.
-      # @param [Integer] batch_size Batch size used for one test.
-      # @return [Hash] Hash of contents to be output to log.
-      private def test(x, y, batch_size: 100)
-        acc, test_loss = accuracy(x, y, batch_size: batch_size)
-        { accuracy: acc, test_loss: test_loss }
-      end
-
       # Training once.
       # Setup the model before use this method.
       # @param [Numo::SFloat] x Input training data.
@@ -169,20 +253,24 @@ module DNN
         loss_value
       end
 
-      # Evaluate model and get accuracy of test data.
+      # Evaluate model and get accuracy and loss of test data.
       # @param [Numo::SFloat] x Input test data.
       # @param [Numo::SFloat] y Output test data.
       # @param [Integer] batch_size Batch size used for one test.
       # @return [Array] Returns the test data accuracy and mean loss in the form [accuracy, mean_loss].
-      def accuracy(x, y, batch_size: 100)
+      def evaluate(x, y, batch_size: 100)
         check_xy_type(x, y)
-        num_test_datas = x.is_a?(Array) ? x[0].shape[0] : x.shape[0]
+        evaluate_by_iterator(Iterator.new(x, y, random: false))
+      end
+
+      # Evaluate model by iterator.
+      def evaluate_by_iterator(test_iterator, batch_size: 100)
+        num_test_datas = test_iterator.num_datas
         batch_size = batch_size >= num_test_datas[0] ? num_test_datas : batch_size
-        iter = Iterator.new(x, y, random: false)
         total_correct = 0
         sum_loss = 0
         max_steps = (num_test_datas.to_f / batch_size).ceil
-        iter.foreach(batch_size) do |x_batch, y_batch|
+        test_iterator.foreach(batch_size) do |x_batch, y_batch|
           correct, loss_value = test_on_batch(x_batch, y_batch)
           total_correct += correct
           sum_loss += loss_value
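Usage sketch of the renamed evaluation API: the old accuracy method becomes evaluate, and evaluate_by_iterator accepts an iterator over the test data (variable names below are illustrative).

    acc, mean_loss = model.evaluate(x_test, y_test, batch_size: 100)
    acc, mean_loss = model.evaluate_by_iterator(DNN::Iterator.new(x_test, y_test, random: false))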
@@ -201,16 +289,16 @@ module DNN
       def test_on_batch(x, y)
         call_callbacks(:before_test_on_batch)
         x = forward(x, false)
-        correct = evaluate(x, y)
+        correct = accuracy(x, y)
         loss_value = @loss_func.loss(x, y)
         call_callbacks(:after_test_on_batch)
         [correct, loss_value]
       end
 
-      # Implement the process to evaluate this model.
+      # Implement the process to calculate the accuracy of this model.
       # @param [Numo::SFloat] x Input test data.
       # @param [Numo::SFloat] y Output test data.
-      private def evaluate(x, y)
+      private def accuracy(x, y)
         if x.shape[1..-1] == [1]
           correct = 0
           x.shape[0].times do |i|
@@ -257,11 +345,24 @@ module DNN
         @callbacks = []
       end
 
+      # Load marshal params.
+      # @param [String] file_name File name of marshal model to load.
+      def load_params(file_name)
+        loader = Loaders::MarshalLoader.new(self)
+        loader.load(file_name)
+      end
+
       # Save the model in marshal format.
       # @param [String] file_name Name to save model.
-      # @param [Boolean] include_optimizer Set true to save data included optimizer status.
-      def save(file_name, include_optimizer: true)
-        saver = Savers::MarshalSaver.new(self, include_optimizer: include_optimizer)
+      def save(file_name)
+        saver = Savers::MarshalSaver.new(self, include_model: true)
+        saver.save(file_name)
+      end
+
+      # Save the params in marshal format.
+      # @param [String] file_name Name to save model.
+      def save_params(file_name)
+        saver = Savers::MarshalSaver.new(self, include_model: false)
         saver.save(file_name)
       end
 
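Saving now splits into whole-model and parameters-only forms, and load_params restores parameters into an existing model. A usage sketch (file names and the MLP class are illustrative):

    model.save("mlp.marshal")                  # whole model (include_model: true)
    model.save_params("mlp_params.marshal")    # trainable parameters only
    model2 = MLP.new
    model2.load_params("mlp_params.marshal")   # receiving model needs the same layer structure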
@@ -270,37 +371,21 @@ module DNN
         Marshal.load(Marshal.dump(self))
       end
 
-      # Get the all layers.
-      # @return [Array] All layers array.
-      def layers
-        raise DNN_Error, "This model is not built. You need build this model using predict or train." unless built?
-        return @layers_cache if @layers_cache
-        layers = []
-        get_layers = -> link do
-          return unless link
-          layers.unshift(link.layer)
-          if link.is_a?(TwoInputLink)
-            get_layers.(link.prev1)
-            get_layers.(link.prev2)
-          else
-            get_layers.(link.prev)
-          end
-        end
-        get_layers.(@last_link)
-        @layers_cache = layers.uniq
-      end
-
-      # Get the all has param layers.
+      # Get all the trainable layers.
       # @return [Array] All has param layers array.
-      def has_param_layers
-        layers.select { |layer| layer.is_a?(Layers::HasParamLayer) }
+      def trainable_layers
+        layers.select { |layer| layer.is_a?(Layers::TrainableLayer) }
       end
 
       # Get the layer that the model has.
       # @param [Symbol] name The name of the layer to get.
       # @return [DNN::Layers::Layer] Return the layer.
       def get_layer(name)
-        layers.find { |layer| layer.name == name }
+        layer = instance_variable_get("@#{name}")
+        if layer.is_a?(Layers::Layer) || layer.is_a?(Chain) || layer.is_a?(LayersList)
+          return layer
+        end
+        nil
       end
 
       # @return [Boolean] If model have already been built then return true.
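Because get_layer now looks up an instance variable instead of a generated layer name, the symbol must match an ivar that holds a Layer, Chain, or LayersList. Sketch (the @l1 ivar comes from the illustrative MLP above):

    model.get_layer(:l1)   # returns the object stored in @l1, or nil if it is not a layer-like object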
@@ -308,6 +393,31 @@ module DNN
         @built
       end
 
+      def clean_layers
+        layers.each do |layer|
+          layer.clean
+        end
+        @loss_func.clean
+        @last_link = nil
+        @layers_cache = nil
+      end
+
+      def get_all_params_data
+        trainable_layers.map do |layer|
+          layer.get_params.to_h do |key, param|
+            [key, param.data]
+          end
+        end
+      end
+
+      def set_all_params_data(params_data)
+        trainable_layers.each.with_index do |layer, i|
+          params_data[i].each do |(key, data)|
+            layer.get_params[key].data = data
+          end
+        end
+      end
+
       private
 
       def forward(x, learning_phase)
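get_all_params_data and set_all_params_data move raw parameter arrays in and out of the trainable layers, which is what the params-only save path relies on. A sketch of copying weights between two models of identical structure (both assumed already built):

    params = model.get_all_params_data        # per-layer hashes of { key => Numo::SFloat }
    other_model.set_all_params_data(params)   # layer order and shapes must match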
@@ -322,7 +432,6 @@ module DNN
         @last_link = output_tensor.link
         unless @built
           @built = true
-          naming
         end
         output_tensor.data
       end
@@ -337,19 +446,6 @@ module DNN
         end
       end
 
-      def naming
-        layers.each do |layer|
-          id = layers.select { |l| l.is_a?(layer.class) }.index(layer)
-          class_name = layer.class.name.split("::").last
-          layer.name = "#{class_name}_#{id}".to_sym unless layer.name
-          if layer.is_a?(Layers::HasParamLayer)
-            layer.get_params.each do |param_key, param|
-              param.name = "#{layer.name}__#{param_key}".to_sym unless param.name
-            end
-          end
-        end
-      end
-
       def metrics_to_str(mertics)
         mertics.map { |key, num| "#{key}: #{sprintf('%.4f', num)}" }.join(", ")
       end
@@ -370,7 +466,7 @@ module DNN
       # @param [Array] stack All layers possessed by the model.
       def initialize(stack = [])
         super()
-        @stack = []
+        @stack = LayersList.new
         stack.each do |layer|
           add(layer)
         end
@@ -380,8 +476,8 @@ module DNN
       # @param [DNN::Layers::Layer] layer Layer to add to the model.
       # @return [DNN::Models::Model] Return self.
       def add(layer)
-        if layer.is_a?(MergeLayers::MergeLayer)
-          raise TypeError, "layer: #{layer.class.name} should not be a DNN::MergeLayers::MergeLayer class."
+        if layer.is_a?(Layers::MergeLayer)
+          raise TypeError, "layer: #{layer.class.name} should not be a DNN::Layers::MergeLayer class."
         end
         unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
           raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
@@ -396,8 +492,8 @@ module DNN
       # @param [DNN::Layers::Layer] layer Layer to add to the model.
       # @return [DNN::Models::Model] Return self.
       def insert(index, layer)
-        if layer.is_a?(MergeLayers::MergeLayer)
-          raise TypeError, "layer: #{layer.class.name} should not be a DNN::MergeLayers::MergeLayer class."
+        if layer.is_a?(Layers::MergeLayer)
+          raise TypeError, "layer: #{layer.class.name} should not be a DNN::Layers::MergeLayer class."
         end
         unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
           raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
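Sequential usage itself is unchanged; only the namespace checked by add/insert moves from MergeLayers::MergeLayer to Layers::MergeLayer. A small sketch (layer size illustrative):

    seq = DNN::Models::Sequential.new
    seq.add(DNN::Layers::Dense.new(10))   # Layer and Model instances are accepted
    # passing an instance of a DNN::Layers::MergeLayer subclass raises TypeError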
@@ -3,7 +3,6 @@ module DNN
 
     # Super class of all optimizer classes.
     class Optimizer
-      attr_reader :status
       attr_accessor :clip_norm
 
       def self.from_hash(hash)
@@ -15,17 +14,6 @@ module DNN
         optimizer
       end
 
-      def self.load(dumped)
-        opt = from_hash(dumped[:hash])
-        return opt unless dumped[:status]
-        dumped[:status].each do |key, state|
-          state = state.clone
-          opt.status[key] = state
-          opt.instance_variable_set("@#{key}", state)
-        end
-        opt
-      end
-
       # @param [Float | NilClass] clip_norm Gradient clip norm.
       def initialize(clip_norm: nil)
         @clip_norm = clip_norm
@@ -33,7 +21,7 @@ module DNN
 
       # Update layers has params.
       def update(layers)
-        target_params = layers.select { |layer| layer.is_a?(Layers::HasParamLayer) && layer.trainable }
+        target_params = layers.select { |layer| layer.is_a?(Layers::TrainableLayer) && layer.trainable }
                               .map { |layer| layer.get_params.values }.flatten.compact
                               .select(&:grad)
         clip_grads(target_params) if @clip_norm
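Optimizer state is no longer exposed through status or rebuilt via Optimizer.load; an optimizer is attached to a model in the usual way and keeps its per-parameter state internally. A sketch, assuming Model#setup keeps its 0.14.x signature:

    model.setup(DNN::Optimizers::Adam.new, DNN::Losses::SoftmaxCrossEntropy.new)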
@@ -43,11 +31,6 @@ module DNN
         end
       end
 
-      def dump(require_status = true)
-        status = require_status ? @status : nil
-        { hash: to_hash, status: status }
-      end
-
       def to_hash(merge_hash = nil)
         hash = { class: self.class.name, clip_norm: @clip_norm }
         hash.merge!(merge_hash) if merge_hash
@@ -80,12 +63,11 @@ module DNN
 
       # @param [Float] lr Learning rate.
       # @param [Float] momentum Momentum coefficient.
-      def initialize(lr = 0.01, momentum: 0, clip_norm: nil)
+      def initialize(lr: 0.01, momentum: 0, clip_norm: nil)
         super(clip_norm: clip_norm)
         @lr = lr
         @momentum = momentum
         @v = {}
-        @status = { v: @v }
       end
 
       def to_hash
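The learning rate is now a keyword argument across the optimizers (SGD, Nesterov, AdaGrad, RMSProp, RMSPropGraves, and the lr values restored by load_hash), so constructor calls change accordingly. Sketch:

    opt = DNN::Optimizers::SGD.new(lr: 0.01, momentum: 0.9)
    # 0.14.x positional form: DNN::Optimizers::SGD.new(0.01, momentum: 0.9)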
@@ -96,30 +78,30 @@ module DNN
         params.each do |param|
           amount = param.grad * @lr
           if @momentum > 0
-            @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-            amount += @momentum * @v[param.name]
-            @v[param.name] = amount
+            @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+            amount += @momentum * @v[param]
+            @v[param] = amount
           end
           param.data -= amount
         end
       end
 
       def load_hash(hash)
-        initialize(hash[:lr], momentum: hash[:momentum], clip_norm: hash[:clip_norm])
+        initialize(lr: hash[:lr], momentum: hash[:momentum], clip_norm: hash[:clip_norm])
       end
     end
 
     class Nesterov < SGD
-      def initialize(lr = 0.01, momentum: 0.9, clip_norm: nil)
-        super(lr, momentum: momentum, clip_norm: clip_norm)
+      def initialize(lr: 0.01, momentum: 0.9, clip_norm: nil)
+        super(lr: lr, momentum: momentum, clip_norm: clip_norm)
       end
 
       private def update_params(params)
         params.each do |param|
-          @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
           amount = param.grad * @lr
-          @v[param.name] = @v[param.name] * @momentum - amount
-          param.data = (param.data + @momentum**2 * @v[param.name]) - (1 + @momentum) * amount
+          @v[param] = @v[param] * @momentum - amount
+          param.data = (param.data + @momentum**2 * @v[param]) - (1 + @momentum) * amount
         end
       end
     end
130
112
 
131
113
  # @param [Float] lr Learning rate.
132
114
  # @param [Float] eps Value to avoid division by zero.
133
- def initialize(lr = 0.01, eps: 1e-7, clip_norm: nil)
115
+ def initialize(lr: 0.01, eps: 1e-7, clip_norm: nil)
134
116
  super(clip_norm: clip_norm)
135
117
  @lr = lr
136
118
  @eps = eps
137
119
  @g = {}
138
- @status = { g: @g }
139
120
  end
140
121
 
141
122
  private def update_params(params)
142
123
  params.each do |param|
143
- @g[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
144
- @g[param.name] += param.grad**2
145
- param.data -= (@lr / Xumo::NMath.sqrt(@g[param.name] + @eps)) * param.grad
124
+ @g[param] ||= Xumo::SFloat.zeros(*param.data.shape)
125
+ @g[param] += param.grad**2
126
+ param.data -= (@lr / Xumo::NMath.sqrt(@g[param] + @eps)) * param.grad
146
127
  end
147
128
  end
148
129
 
@@ -151,7 +132,7 @@ module DNN
151
132
  end
152
133
 
153
134
  def load_hash(hash)
154
- initialize(hash[:lr], eps: hash[:eps], clip_norm: hash[:clip_norm])
135
+ initialize(lr: hash[:lr], eps: hash[:eps], clip_norm: hash[:clip_norm])
155
136
  end
156
137
  end
157
138
 
@@ -163,13 +144,12 @@ module DNN
       # @param [Float] lr Learning rate.
       # @param [Float] alpha Moving average index of past slopes.
       # @param [Float] eps Value to avoid division by zero.
-      def initialize(lr = 0.001, alpha: 0.9, eps: 1e-7, clip_norm: nil)
+      def initialize(lr: 0.001, alpha: 0.9, eps: 1e-7, clip_norm: nil)
         super(clip_norm: clip_norm)
         @lr = lr
         @alpha = alpha
         @eps = eps
         @g = {}
-        @status = { g: @g }
       end
 
       def to_hash
@@ -178,14 +158,14 @@ module DNN
 
       private def update_params(params)
         params.each do |param|
-          @g[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @g[param.name] = @alpha * @g[param.name] + (1 - @alpha) * param.grad**2
-          param.data -= (@lr / Xumo::NMath.sqrt(@g[param.name] + @eps)) * param.grad
+          @g[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @g[param] = @alpha * @g[param] + (1 - @alpha) * param.grad**2
+          param.data -= (@lr / Xumo::NMath.sqrt(@g[param] + @eps)) * param.grad
         end
       end
 
       def load_hash(hash)
-        initialize(hash[:lr], alpha: hash[:alpha], eps: hash[:eps], clip_norm: hash[:clip_norm])
+        initialize(lr: hash[:lr], alpha: hash[:alpha], eps: hash[:eps], clip_norm: hash[:clip_norm])
       end
     end
 
@@ -201,7 +181,6 @@ module DNN
         @eps = eps
         @h = {}
         @s = {}
-        @status = { h: @h, s: @s }
       end
 
       def to_hash
@@ -210,11 +189,11 @@ module DNN
 
       private def update_params(params)
         params.each do |param|
-          @h[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @h[param.name] = @rho * @h[param.name] + (1 - @rho) * param.grad**2
-          v = (Xumo::NMath.sqrt(@s[param.name] + @eps) / Xumo::NMath.sqrt(@h[param.name] + @eps)) * param.grad
-          @s[param.name] = @rho * @s[param.name] + (1 - @rho) * v**2
+          @h[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @h[param] = @rho * @h[param] + (1 - @rho) * param.grad**2
+          v = (Xumo::NMath.sqrt(@s[param] + @eps) / Xumo::NMath.sqrt(@h[param] + @eps)) * param.grad
+          @s[param] = @rho * @s[param] + (1 - @rho) * v**2
           param.data -= v
         end
       end
@@ -232,14 +211,13 @@ module DNN
       # @param [Float] lr Learning rate.
       # @param [Float] alpha Moving average index of past slopes.
       # @param [Float] eps Value to avoid division by zero.
-      def initialize(lr = 0.0001, alpha: 0.95, eps: 0.0001, clip_norm: nil)
+      def initialize(lr: 0.0001, alpha: 0.95, eps: 0.0001, clip_norm: nil)
         super(clip_norm: clip_norm)
         @lr = lr
         @alpha = alpha
         @eps = eps
         @m = {}
         @v = {}
-        @status = { m: @m, v: @v }
       end
 
       def to_hash
@@ -248,16 +226,16 @@ module DNN
 
       private def update_params(params)
         params.each do |param|
-          @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @m[param.name] = @alpha * @m[param.name] + (1 - @alpha) * param.grad
-          @v[param.name] = @alpha * @v[param.name] + (1 - @alpha) * param.grad**2
-          param.data -= (@lr / Xumo::NMath.sqrt(@v[param.name] - @m[param.name]**2 + @eps)) * param.grad
+          @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @m[param] = @alpha * @m[param] + (1 - @alpha) * param.grad
+          @v[param] = @alpha * @v[param] + (1 - @alpha) * param.grad**2
+          param.data -= (@lr / Xumo::NMath.sqrt(@v[param] - @m[param]**2 + @eps)) * param.grad
         end
       end
 
       def load_hash(hash)
-        initialize(hash[:lr], alpha: hash[:alpha], eps: hash[:eps], clip_norm: hash[:clip_norm])
+        initialize(lr: hash[:lr], alpha: hash[:alpha], eps: hash[:eps], clip_norm: hash[:clip_norm])
       end
     end
 
@@ -284,7 +262,6 @@ module DNN
         @m = {}
         @v = {}
         @s = amsgrad ? {} : nil
-        @status = { t: @t, m: @m, v: @v, s: @s }
       end
 
       def to_hash
@@ -298,16 +275,16 @@ module DNN
         @t += 1
         lr = @alpha * Math.sqrt(1 - @beta2**@t) / (1 - @beta1**@t)
         params.each do |param|
-          @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @m[param.name] += (1 - @beta1) * (param.grad - @m[param.name])
-          @v[param.name] += (1 - @beta2) * (param.grad**2 - @v[param.name])
+          @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @m[param] += (1 - @beta1) * (param.grad - @m[param])
+          @v[param] += (1 - @beta2) * (param.grad**2 - @v[param])
           if @amsgrad
-            @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-            @s[param.name] = Xumo::SFloat.maximum(@s[param.name], @v[param.name])
-            param.data -= lr * @m[param.name] / Xumo::NMath.sqrt(@s[param.name] + @eps)
+            @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+            @s[param] = Xumo::SFloat.maximum(@s[param], @v[param])
+            param.data -= lr * @m[param] / Xumo::NMath.sqrt(@s[param] + @eps)
           else
-            param.data -= lr * @m[param.name] / Xumo::NMath.sqrt(@v[param.name] + @eps)
+            param.data -= lr * @m[param] / Xumo::NMath.sqrt(@v[param] + @eps)
           end
         end
       end
@@ -344,16 +321,16 @@ module DNN
         lower_bound = final_lr * (1 - 1 / (@gamma * @t + 1))
         upper_bound = final_lr * (1 + 1 / (@gamma * @t))
         params.each do |param|
-          @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @m[param.name] += (1 - @beta1) * (param.grad - @m[param.name])
-          @v[param.name] += (1 - @beta2) * (param.grad**2 - @v[param.name])
+          @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @m[param] += (1 - @beta1) * (param.grad - @m[param])
+          @v[param] += (1 - @beta2) * (param.grad**2 - @v[param])
           if @amsgrad
-            @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-            @s[param.name] = Xumo::SFloat.maximum(@s[param.name], @v[param.name])
-            param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@s[param.name]) + @eps), lower_bound, upper_bound) * @m[param.name]
+            @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+            @s[param] = Xumo::SFloat.maximum(@s[param], @v[param])
+            param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@s[param]) + @eps), lower_bound, upper_bound) * @m[param]
          else
-            param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@v[param.name]) + @eps), lower_bound, upper_bound) * @m[param.name]
+            param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@v[param]) + @eps), lower_bound, upper_bound) * @m[param]
          end
        end
      end