ruby-dnn 0.14.3 → 0.15.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -53,6 +53,14 @@ module DNN
     def load_hash(hash)
       initialize
     end
+
+    def clean
+      hash = to_hash
+      instance_variables.each do |ivar|
+        instance_variable_set(ivar, nil)
+      end
+      load_hash(hash)
+    end
   end
 
   class MeanSquaredError < Loss
@@ -1,7 +1,97 @@
 module DNN
   module Models
+
+    class LayersList < Array
+      def self.from_hash_list(hash_list)
+        layers_list = new
+        hash_list.each do |hash|
+          obj_class = DNN.const_get(hash[:class])
+          obj = obj_class.allocate
+          if obj.is_a?(Chain)
+            obj = obj_class.new
+            obj.load_hash(hash)
+          else
+            obj = Layers::Layer.from_hash(hash)
+          end
+          layers_list << obj
+        end
+        layers_list
+      end
+
+      def to_hash_list
+        map { |layer| layer.to_hash }
+      end
+
+      # Get the all layers.
+      # @return [Array] All layers array.
+      def layers
+        layers_array = []
+        each do |layer|
+          if layer.is_a?(Layers::Layer)
+            layers_array << layer
+          elsif layer.is_a?(Chain) || layer.is_a?(LayersList)
+            layers_array.concat(layer.layers)
+          end
+        end
+        layers_array
+      end
+    end
+
+    class Chain
+      def call(x)
+        raise NotImplementedError, "Class '#{self.class.name}' has implement method 'call'"
+      end
+
+      # Get the all layers.
+      # @return [Array] All layers array.
+      def layers
+        layers_array = []
+        instance_variables.sort.each do |ivar|
+          obj = instance_variable_get(ivar)
+          if obj.is_a?(Layers::Layer)
+            layers_array << obj
+          elsif obj.is_a?(Chain) || obj.is_a?(LayersList)
+            layers_array.concat(obj.layers)
+          end
+        end
+        layers_array
+      end
+
+      def to_hash
+        layers_hash = { class: self.class.name }
+        instance_variables.sort.each do |ivar|
+          obj = instance_variable_get(ivar)
+          if obj.is_a?(Layers::Layer) || obj.is_a?(Chain)
+            layers_hash[ivar] = obj.to_hash
+          elsif obj.is_a?(LayersList)
+            layers_hash[ivar] = obj.to_hash_list
+          end
+        end
+        layers_hash
+      end
+
+      def load_hash(layers_hash)
+        instance_variables.sort.each do |ivar|
+          hash_or_array = layers_hash[ivar]
+          if hash_or_array.is_a?(Array)
+            instance_variable_set(ivar, LayersList.from_hash_list(hash_or_array))
+          elsif hash_or_array.is_a?(Hash)
+            obj_class = DNN.const_get(hash_or_array[:class])
+            obj = obj_class.allocate
+            if obj.is_a?(Chain)
+              obj = obj_class.new
+              obj.load_hash(hash_or_array)
+              instance_variable_set(ivar, obj)
+            else
+              instance_variable_set(ivar, Layers::Layer.from_hash(hash_or_array))
+            end
+          end
+        end
+      end
+    end
+
     # This class deals with the model of the network.
-    class Model
+    class Model < Chain
       attr_accessor :optimizer
       attr_accessor :loss_func
       attr_reader :last_log
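
For orientation (an illustrative sketch, not part of this diff): under the new Chain API a model can be written as a plain Ruby class whose layers live in instance variables. The Dense/ReLU layer names and the callable-layer style below are assumed from earlier ruby-dnn releases.

    # Hypothetical example of a Chain-style model (Model now inherits from Chain).
    class MLP < DNN::Models::Model
      def initialize
        super
        @l1 = DNN::Layers::Dense.new(64)
        @relu = DNN::Layers::ReLU.new
        @l2 = DNN::Layers::Dense.new(10)
      end

      # Chain subclasses are expected to implement #call.
      def call(x)
        x = @relu.(@l1.(x))
        @l2.(x)
      end
    end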
@@ -10,7 +100,7 @@ module DNN
       # @param [String] file_name File name of marshal model to load.
       # @return [DNN::Models::Model] Return the loaded model.
       def self.load(file_name)
-        model = new
+        model = self.allocate
         loader = Loaders::MarshalLoader.new(model)
         loader.load(file_name)
         model
@@ -113,8 +203,12 @@ module DNN
         end
 
         if test
-          test_met = test(test[0], test[1], batch_size: batch_size)
-          print " " + metrics_to_str(test_met) if verbose
+          acc, loss = if test.is_a?(Array)
+                        evaluate(test[0], test[1], batch_size: batch_size)
+                      else
+                        evaluate_by_iterator(test, batch_size: batch_size)
+                      end
+          print " " + metrics_to_str({ accuracy: acc, test_loss: loss }) if verbose
         end
         puts "" if verbose
         call_callbacks(:after_epoch)
@@ -138,16 +232,6 @@ module DNN
         { loss: loss_value }
       end
 
-      # Implement the test process to be performed.
-      # @param [Numo::SFloat] x Input training data.
-      # @param [Numo::SFloat] y Output training data.
-      # @param [Integer] batch_size Batch size used for one test.
-      # @return [Hash] Hash of contents to be output to log.
-      private def test(x, y, batch_size: 100)
-        acc, test_loss = accuracy(x, y, batch_size: batch_size)
-        { accuracy: acc, test_loss: test_loss }
-      end
-
       # Training once.
       # Setup the model before use this method.
       # @param [Numo::SFloat] x Input training data.
@@ -169,20 +253,24 @@ module DNN
         loss_value
       end
 
-      # Evaluate model and get accuracy of test data.
+      # Evaluate model and get accuracy and loss of test data.
       # @param [Numo::SFloat] x Input test data.
       # @param [Numo::SFloat] y Output test data.
       # @param [Integer] batch_size Batch size used for one test.
       # @return [Array] Returns the test data accuracy and mean loss in the form [accuracy, mean_loss].
-      def accuracy(x, y, batch_size: 100)
+      def evaluate(x, y, batch_size: 100)
         check_xy_type(x, y)
-        num_test_datas = x.is_a?(Array) ? x[0].shape[0] : x.shape[0]
+        evaluate_by_iterator(Iterator.new(x, y, random: false))
+      end
+
+      # Evaluate model by iterator
+      def evaluate_by_iterator(test_iterator, batch_size: 100)
+        num_test_datas = test_iterator.num_datas
         batch_size = batch_size >= num_test_datas[0] ? num_test_datas : batch_size
-        iter = Iterator.new(x, y, random: false)
         total_correct = 0
         sum_loss = 0
         max_steps = (num_test_datas.to_f / batch_size).ceil
-        iter.foreach(batch_size) do |x_batch, y_batch|
+        test_iterator.foreach(batch_size) do |x_batch, y_batch|
           correct, loss_value = test_on_batch(x_batch, y_batch)
           total_correct += correct
           sum_loss += loss_value
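
A usage sketch of the renamed evaluation API shown above (illustrative only; x_test and y_test are assumed to be Numo::SFloat arrays as in earlier releases):

    # The former Model#accuracy is now Model#evaluate and returns [accuracy, mean_loss].
    accuracy, mean_loss = model.evaluate(x_test, y_test, batch_size: 100)

    # An existing iterator can be evaluated directly with the new evaluate_by_iterator.
    test_iterator = DNN::Iterator.new(x_test, y_test, random: false)
    accuracy, mean_loss = model.evaluate_by_iterator(test_iterator, batch_size: 100)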
@@ -201,16 +289,16 @@ module DNN
       def test_on_batch(x, y)
         call_callbacks(:before_test_on_batch)
         x = forward(x, false)
-        correct = evaluate(x, y)
+        correct = accuracy(x, y)
         loss_value = @loss_func.loss(x, y)
         call_callbacks(:after_test_on_batch)
         [correct, loss_value]
       end
 
-      # Implement the process to evaluate this model.
+      # Implement the process to accuracy this model.
       # @param [Numo::SFloat] x Input test data.
       # @param [Numo::SFloat] y Output test data.
-      private def evaluate(x, y)
+      private def accuracy(x, y)
         if x.shape[1..-1] == [1]
           correct = 0
           x.shape[0].times do |i|
@@ -257,11 +345,24 @@ module DNN
         @callbacks = []
       end
 
+      # Load marshal params.
+      # @param [String] file_name File name of marshal model to load.
+      def load_params(file_name)
+        loader = Loaders::MarshalLoader.new(self)
+        loader.load(file_name)
+      end
+
       # Save the model in marshal format.
       # @param [String] file_name Name to save model.
-      # @param [Boolean] include_optimizer Set true to save data included optimizer status.
-      def save(file_name, include_optimizer: true)
-        saver = Savers::MarshalSaver.new(self, include_optimizer: include_optimizer)
+      def save(file_name)
+        saver = Savers::MarshalSaver.new(self, include_model: true)
+        saver.save(file_name)
+      end
+
+      # Save the params in marshal format.
+      # @param [String] file_name Name to save model.
+      def save_params(file_name)
+        saver = Savers::MarshalSaver.new(self, include_model: false)
         saver.save(file_name)
       end
 
@@ -270,37 +371,21 @@ module DNN
         Marshal.load(Marshal.dump(self))
       end
 
-      # Get the all layers.
-      # @return [Array] All layers array.
-      def layers
-        raise DNN_Error, "This model is not built. You need build this model using predict or train." unless built?
-        return @layers_cache if @layers_cache
-        layers = []
-        get_layers = -> link do
-          return unless link
-          layers.unshift(link.layer)
-          if link.is_a?(TwoInputLink)
-            get_layers.(link.prev1)
-            get_layers.(link.prev2)
-          else
-            get_layers.(link.prev)
-          end
-        end
-        get_layers.(@last_link)
-        @layers_cache = layers.uniq
-      end
-
-      # Get the all has param layers.
+      # Get the all trainable layers.
       # @return [Array] All has param layers array.
-      def has_param_layers
-        layers.select { |layer| layer.is_a?(Layers::HasParamLayer) }
+      def trainable_layers
+        layers.select { |layer| layer.is_a?(Layers::TrainableLayer) }
       end
 
       # Get the layer that the model has.
       # @param [Symbol] name The name of the layer to get.
       # @return [DNN::Layers::Layer] Return the layer.
       def get_layer(name)
-        layers.find { |layer| layer.name == name }
+        layer = instance_variable_get("@#{name}")
+        if layer.is_a?(Layers::Layer) || layer.is_a?(Chain) || layer.is_a?(LayersList)
+          return layer
+        end
+        nil
       end
 
       # @return [Boolean] If model have already been built then return true.
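
A short sketch of the lookup changes above (the @l1 instance variable is hypothetical):

    model.trainable_layers    # replaces has_param_layers; selects Layers::TrainableLayer instances
    model.get_layer(:l1)      # now reads @l1 and returns it if it is a Layer, Chain, or LayersList
    model.get_layer(:missing) # anything else (including an undefined ivar) returns nil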
@@ -308,6 +393,31 @@ module DNN
         @built
       end
 
+      def clean_layers
+        layers.each do |layer|
+          layer.clean
+        end
+        @loss_func.clean
+        @last_link = nil
+        @layers_cache = nil
+      end
+
+      def get_all_params_data
+        trainable_layers.map do |layer|
+          layer.get_params.to_h do |key, param|
+            [key, param.data]
+          end
+        end
+      end
+
+      def set_all_params_data(params_data)
+        trainable_layers.each.with_index do |layer, i|
+          params_data[i].each do |(key, data)|
+            layer.get_params[key].data = data
+          end
+        end
+      end
+
       private
 
       def forward(x, learning_phase)
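
A sketch of the new parameter helpers above (illustrative; assumes both models have been built with the same architecture):

    params_data = model.get_all_params_data    # per-layer hashes of { key => Numo::SFloat }
    other_model.set_all_params_data(params_data)

    model.clean_layers                         # resets layer state and clears @last_link / @layers_cache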
@@ -322,7 +432,6 @@ module DNN
         @last_link = output_tensor.link
         unless @built
           @built = true
-          naming
         end
         output_tensor.data
       end
@@ -337,19 +446,6 @@ module DNN
         end
       end
 
-      def naming
-        layers.each do |layer|
-          id = layers.select { |l| l.is_a?(layer.class) }.index(layer)
-          class_name = layer.class.name.split("::").last
-          layer.name = "#{class_name}_#{id}".to_sym unless layer.name
-          if layer.is_a?(Layers::HasParamLayer)
-            layer.get_params.each do |param_key, param|
-              param.name = "#{layer.name}__#{param_key}".to_sym unless param.name
-            end
-          end
-        end
-      end
-
       def metrics_to_str(mertics)
         mertics.map { |key, num| "#{key}: #{sprintf('%.4f', num)}" }.join(", ")
       end
@@ -370,7 +466,7 @@ module DNN
       # @param [Array] stack All layers possessed by the model.
       def initialize(stack = [])
         super()
-        @stack = []
+        @stack = LayersList.new
         stack.each do |layer|
           add(layer)
         end
@@ -380,8 +476,8 @@ module DNN
       # @param [DNN::Layers::Layer] layer Layer to add to the model.
       # @return [DNN::Models::Model] Return self.
       def add(layer)
-        if layer.is_a?(MergeLayers::MergeLayer)
-          raise TypeError, "layer: #{layer.class.name} should not be a DNN::MergeLayers::MergeLayer class."
+        if layer.is_a?(Layers::MergeLayer)
+          raise TypeError, "layer: #{layer.class.name} should not be a DNN::Layers::MergeLayer class."
         end
         unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
           raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
@@ -396,8 +492,8 @@ module DNN
       # @param [DNN::Layers::Layer] layer Layer to add to the model.
       # @return [DNN::Models::Model] Return self.
       def insert(index, layer)
-        if layer.is_a?(MergeLayers::MergeLayer)
-          raise TypeError, "layer: #{layer.class.name} should not be a DNN::MergeLayers::MergeLayer class."
+        if layer.is_a?(Layers::MergeLayer)
+          raise TypeError, "layer: #{layer.class.name} should not be a DNN::Layers::MergeLayer class."
         end
         unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
           raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
@@ -3,7 +3,6 @@ module DNN
 
     # Super class of all optimizer classes.
     class Optimizer
-      attr_reader :status
       attr_accessor :clip_norm
 
       def self.from_hash(hash)
@@ -15,17 +14,6 @@ module DNN
         optimizer
       end
 
-      def self.load(dumped)
-        opt = from_hash(dumped[:hash])
-        return opt unless dumped[:status]
-        dumped[:status].each do |key, state|
-          state = state.clone
-          opt.status[key] = state
-          opt.instance_variable_set("@#{key}", state)
-        end
-        opt
-      end
-
       # @param [Float | NilClass] clip_norm Gradient clip norm.
       def initialize(clip_norm: nil)
         @clip_norm = clip_norm
@@ -33,7 +21,7 @@ module DNN
 
      # Update layers has params.
      def update(layers)
-        target_params = layers.select { |layer| layer.is_a?(Layers::HasParamLayer) && layer.trainable }
+        target_params = layers.select { |layer| layer.is_a?(Layers::TrainableLayer) && layer.trainable }
                               .map { |layer| layer.get_params.values }.flatten.compact
                               .select(&:grad)
        clip_grads(target_params) if @clip_norm
@@ -43,11 +31,6 @@ module DNN
        end
      end
 
-      def dump(require_status = true)
-        status = require_status ? @status : nil
-        { hash: to_hash, status: status }
-      end
-
      def to_hash(merge_hash = nil)
        hash = { class: self.class.name, clip_norm: @clip_norm }
        hash.merge!(merge_hash) if merge_hash
@@ -80,12 +63,11 @@ module DNN
 
      # @param [Float] lr Learning rate.
      # @param [Float] momentum Momentum coefficient.
-      def initialize(lr = 0.01, momentum: 0, clip_norm: nil)
+      def initialize(lr: 0.01, momentum: 0, clip_norm: nil)
        super(clip_norm: clip_norm)
        @lr = lr
        @momentum = momentum
        @v = {}
-        @status = { v: @v }
      end
 
      def to_hash
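
For reference (an illustrative sketch, not part of the diff): the optimizer constructors in this file now take the learning rate as a keyword argument, e.g.:

    DNN::Optimizers::SGD.new(lr: 0.01, momentum: 0.9)   # was SGD.new(0.01, momentum: 0.9)
    DNN::Optimizers::AdaGrad.new(lr: 0.01, eps: 1e-7)
    DNN::Optimizers::RMSProp.new(lr: 0.001, alpha: 0.9)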
@@ -96,30 +78,30 @@ module DNN
        params.each do |param|
          amount = param.grad * @lr
          if @momentum > 0
-            @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-            amount += @momentum * @v[param.name]
-            @v[param.name] = amount
+            @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+            amount += @momentum * @v[param]
+            @v[param] = amount
          end
          param.data -= amount
        end
      end
 
      def load_hash(hash)
-        initialize(hash[:lr], momentum: hash[:momentum], clip_norm: hash[:clip_norm])
+        initialize(lr: hash[:lr], momentum: hash[:momentum], clip_norm: hash[:clip_norm])
      end
    end
 
    class Nesterov < SGD
-      def initialize(lr = 0.01, momentum: 0.9, clip_norm: nil)
-        super(lr, momentum: momentum, clip_norm: clip_norm)
+      def initialize(lr: 0.01, momentum: 0.9, clip_norm: nil)
+        super(lr: lr, momentum: momentum, clip_norm: clip_norm)
      end
 
      private def update_params(params)
        params.each do |param|
-          @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
          amount = param.grad * @lr
-          @v[param.name] = @v[param.name] * @momentum - amount
-          param.data = (param.data + @momentum**2 * @v[param.name]) - (1 + @momentum) * amount
+          @v[param] = @v[param] * @momentum - amount
+          param.data = (param.data + @momentum**2 * @v[param]) - (1 + @momentum) * amount
        end
      end
    end
@@ -130,19 +112,18 @@ module DNN
 
      # @param [Float] lr Learning rate.
      # @param [Float] eps Value to avoid division by zero.
-      def initialize(lr = 0.01, eps: 1e-7, clip_norm: nil)
+      def initialize(lr: 0.01, eps: 1e-7, clip_norm: nil)
        super(clip_norm: clip_norm)
        @lr = lr
        @eps = eps
        @g = {}
-        @status = { g: @g }
      end
 
      private def update_params(params)
        params.each do |param|
-          @g[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @g[param.name] += param.grad**2
-          param.data -= (@lr / Xumo::NMath.sqrt(@g[param.name] + @eps)) * param.grad
+          @g[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @g[param] += param.grad**2
+          param.data -= (@lr / Xumo::NMath.sqrt(@g[param] + @eps)) * param.grad
        end
      end
 
@@ -151,7 +132,7 @@ module DNN
      end
 
      def load_hash(hash)
-        initialize(hash[:lr], eps: hash[:eps], clip_norm: hash[:clip_norm])
+        initialize(lr: hash[:lr], eps: hash[:eps], clip_norm: hash[:clip_norm])
      end
    end
 
@@ -163,13 +144,12 @@ module DNN
      # @param [Float] lr Learning rate.
      # @param [Float] alpha Moving average index of past slopes.
      # @param [Float] eps Value to avoid division by zero.
-      def initialize(lr = 0.001, alpha: 0.9, eps: 1e-7, clip_norm: nil)
+      def initialize(lr: 0.001, alpha: 0.9, eps: 1e-7, clip_norm: nil)
        super(clip_norm: clip_norm)
        @lr = lr
        @alpha = alpha
        @eps = eps
        @g = {}
-        @status = { g: @g }
      end
 
      def to_hash
@@ -178,14 +158,14 @@ module DNN
 
      private def update_params(params)
        params.each do |param|
-          @g[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @g[param.name] = @alpha * @g[param.name] + (1 - @alpha) * param.grad**2
-          param.data -= (@lr / Xumo::NMath.sqrt(@g[param.name] + @eps)) * param.grad
+          @g[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @g[param] = @alpha * @g[param] + (1 - @alpha) * param.grad**2
+          param.data -= (@lr / Xumo::NMath.sqrt(@g[param] + @eps)) * param.grad
        end
      end
 
      def load_hash(hash)
-        initialize(hash[:lr], alpha: hash[:alpha], eps: hash[:eps], clip_norm: hash[:clip_norm])
+        initialize(lr: hash[:lr], alpha: hash[:alpha], eps: hash[:eps], clip_norm: hash[:clip_norm])
      end
    end
 
@@ -201,7 +181,6 @@ module DNN
        @eps = eps
        @h = {}
        @s = {}
-        @status = { h: @h, s: @s }
      end
 
      def to_hash
@@ -210,11 +189,11 @@ module DNN
 
      private def update_params(params)
        params.each do |param|
-          @h[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @h[param.name] = @rho * @h[param.name] + (1 - @rho) * param.grad**2
-          v = (Xumo::NMath.sqrt(@s[param.name] + @eps) / Xumo::NMath.sqrt(@h[param.name] + @eps)) * param.grad
-          @s[param.name] = @rho * @s[param.name] + (1 - @rho) * v**2
+          @h[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @h[param] = @rho * @h[param] + (1 - @rho) * param.grad**2
+          v = (Xumo::NMath.sqrt(@s[param] + @eps) / Xumo::NMath.sqrt(@h[param] + @eps)) * param.grad
+          @s[param] = @rho * @s[param] + (1 - @rho) * v**2
          param.data -= v
        end
      end
@@ -232,14 +211,13 @@ module DNN
      # @param [Float] lr Learning rate.
      # @param [Float] alpha Moving average index of past slopes.
      # @param [Float] eps Value to avoid division by zero.
-      def initialize(lr = 0.0001, alpha: 0.95, eps: 0.0001, clip_norm: nil)
+      def initialize(lr: 0.0001, alpha: 0.95, eps: 0.0001, clip_norm: nil)
        super(clip_norm: clip_norm)
        @lr = lr
        @alpha = alpha
        @eps = eps
        @m = {}
        @v = {}
-        @status = { m: @m, v: @v }
      end
 
      def to_hash
@@ -248,16 +226,16 @@ module DNN
 
      private def update_params(params)
        params.each do |param|
-          @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @m[param.name] = @alpha * @m[param.name] + (1 - @alpha) * param.grad
-          @v[param.name] = @alpha * @v[param.name] + (1 - @alpha) * param.grad**2
-          param.data -= (@lr / Xumo::NMath.sqrt(@v[param.name] - @m[param.name]**2 + @eps)) * param.grad
+          @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @m[param] = @alpha * @m[param] + (1 - @alpha) * param.grad
+          @v[param] = @alpha * @v[param] + (1 - @alpha) * param.grad**2
+          param.data -= (@lr / Xumo::NMath.sqrt(@v[param] - @m[param]**2 + @eps)) * param.grad
        end
      end
 
      def load_hash(hash)
-        initialize(hash[:lr], alpha: hash[:alpha], eps: hash[:eps], clip_norm: hash[:clip_norm])
+        initialize(lr: hash[:lr], alpha: hash[:alpha], eps: hash[:eps], clip_norm: hash[:clip_norm])
      end
    end
 
@@ -284,7 +262,6 @@ module DNN
        @m = {}
        @v = {}
        @s = amsgrad ? {} : nil
-        @status = { t: @t, m: @m, v: @v, s: @s }
      end
 
      def to_hash
@@ -298,16 +275,16 @@ module DNN
        @t += 1
        lr = @alpha * Math.sqrt(1 - @beta2**@t) / (1 - @beta1**@t)
        params.each do |param|
-          @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @m[param.name] += (1 - @beta1) * (param.grad - @m[param.name])
-          @v[param.name] += (1 - @beta2) * (param.grad**2 - @v[param.name])
+          @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @m[param] += (1 - @beta1) * (param.grad - @m[param])
+          @v[param] += (1 - @beta2) * (param.grad**2 - @v[param])
          if @amsgrad
-            @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-            @s[param.name] = Xumo::SFloat.maximum(@s[param.name], @v[param.name])
-            param.data -= lr * @m[param.name] / Xumo::NMath.sqrt(@s[param.name] + @eps)
+            @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+            @s[param] = Xumo::SFloat.maximum(@s[param], @v[param])
+            param.data -= lr * @m[param] / Xumo::NMath.sqrt(@s[param] + @eps)
          else
-            param.data -= lr * @m[param.name] / Xumo::NMath.sqrt(@v[param.name] + @eps)
+            param.data -= lr * @m[param] / Xumo::NMath.sqrt(@v[param] + @eps)
          end
        end
      end
@@ -344,16 +321,16 @@ module DNN
        lower_bound = final_lr * (1 - 1 / (@gamma * @t + 1))
        upper_bound = final_lr * (1 + 1 / (@gamma * @t))
        params.each do |param|
-          @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-          @m[param.name] += (1 - @beta1) * (param.grad - @m[param.name])
-          @v[param.name] += (1 - @beta2) * (param.grad**2 - @v[param.name])
+          @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @m[param] += (1 - @beta1) * (param.grad - @m[param])
+          @v[param] += (1 - @beta2) * (param.grad**2 - @v[param])
          if @amsgrad
-            @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
-            @s[param.name] = Xumo::SFloat.maximum(@s[param.name], @v[param.name])
-            param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@s[param.name]) + @eps), lower_bound, upper_bound) * @m[param.name]
+            @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+            @s[param] = Xumo::SFloat.maximum(@s[param], @v[param])
+            param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@s[param]) + @eps), lower_bound, upper_bound) * @m[param]
          else
-            param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@v[param.name]) + @eps), lower_bound, upper_bound) * @m[param.name]
+            param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@v[param]) + @eps), lower_bound, upper_bound) * @m[param]
          end
        end
      end