ruby-dnn 0.14.3 → 0.15.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +5 -3
- data/Rakefile +4 -2
- data/examples/api-examples/save_example.rb +7 -5
- data/examples/dcgan/imgen.rb +2 -7
- data/examples/dcgan/train.rb +0 -1
- data/lib/dnn.rb +10 -10
- data/lib/dnn/core/callbacks.rb +6 -2
- data/lib/dnn/core/iterator.rb +10 -2
- data/lib/dnn/core/{activations.rb → layers/activations.rb} +0 -0
- data/lib/dnn/core/{layers.rb → layers/basic_layers.rb} +31 -8
- data/lib/dnn/core/{cnn_layers.rb → layers/cnn_layers.rb} +0 -0
- data/lib/dnn/core/{embedding.rb → layers/embedding.rb} +5 -4
- data/lib/dnn/core/{merge_layers.rb → layers/merge_layers.rb} +1 -1
- data/lib/dnn/core/{normalizations.rb → layers/normalizations.rb} +9 -5
- data/lib/dnn/core/{rnn_layers.rb → layers/rnn_layers.rb} +25 -16
- data/lib/dnn/core/losses.rb +8 -0
- data/lib/dnn/core/models.rb +164 -68
- data/lib/dnn/core/optimizers.rb +49 -72
- data/lib/dnn/core/param.rb +0 -2
- data/lib/dnn/core/savers.rb +40 -49
- data/lib/dnn/datasets/stl-10.rb +65 -0
- data/lib/dnn/version.rb +1 -1
- metadata +10 -9
data/lib/dnn/core/losses.rb
CHANGED
data/lib/dnn/core/models.rb
CHANGED
@@ -1,7 +1,97 @@
 module DNN
   module Models
+
+    class LayersList < Array
+      def self.from_hash_list(hash_list)
+        layers_list = new
+        hash_list.each do |hash|
+          obj_class = DNN.const_get(hash[:class])
+          obj = obj_class.allocate
+          if obj.is_a?(Chain)
+            obj = obj_class.new
+            obj.load_hash(hash)
+          else
+            obj = Layers::Layer.from_hash(hash)
+          end
+          layers_list << obj
+        end
+        layers_list
+      end
+
+      def to_hash_list
+        map { |layer| layer.to_hash }
+      end
+
+      # Get the all layers.
+      # @return [Array] All layers array.
+      def layers
+        layers_array = []
+        each do |layer|
+          if layer.is_a?(Layers::Layer)
+            layers_array << layer
+          elsif layer.is_a?(Chain) || layer.is_a?(LayersList)
+            layers_array.concat(layer.layers)
+          end
+        end
+        layers_array
+      end
+    end
+
+    class Chain
+      def call(x)
+        raise NotImplementedError, "Class '#{self.class.name}' has implement method 'call'"
+      end
+
+      # Get the all layers.
+      # @return [Array] All layers array.
+      def layers
+        layers_array = []
+        instance_variables.sort.each do |ivar|
+          obj = instance_variable_get(ivar)
+          if obj.is_a?(Layers::Layer)
+            layers_array << obj
+          elsif obj.is_a?(Chain) || obj.is_a?(LayersList)
+            layers_array.concat(obj.layers)
+          end
+        end
+        layers_array
+      end
+
+      def to_hash
+        layers_hash = { class: self.class.name }
+        instance_variables.sort.each do |ivar|
+          obj = instance_variable_get(ivar)
+          if obj.is_a?(Layers::Layer) || obj.is_a?(Chain)
+            layers_hash[ivar] = obj.to_hash
+          elsif obj.is_a?(LayersList)
+            layers_hash[ivar] = obj.to_hash_list
+          end
+        end
+        layers_hash
+      end
+
+      def load_hash(layers_hash)
+        instance_variables.sort.each do |ivar|
+          hash_or_array = layers_hash[ivar]
+          if hash_or_array.is_a?(Array)
+            instance_variable_set(ivar, LayersList.from_hash_list(hash_or_array))
+          elsif hash_or_array.is_a?(Hash)
+            obj_class = DNN.const_get(hash_or_array[:class])
+            obj = obj_class.allocate
+            if obj.is_a?(Chain)
+              obj = obj_class.new
+              obj.load_hash(hash_or_array)
+              instance_variable_set(ivar, obj)
+            else
+              instance_variable_set(ivar, Layers::Layer.from_hash(hash_or_array))
+            end
+          end
+        end
+      end
+    end
+
     # This class deals with the model of the network.
-    class Model
+    class Model < Chain
       attr_accessor :optimizer
       attr_accessor :loss_func
       attr_reader :last_log
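With `Model` now inheriting from `Chain`, a model can hold its layers in instance variables and wire them together in `call`; `Chain#layers`, `#to_hash`, and `#load_hash` then discover those layers by walking the instance variables. Below is a minimal, hypothetical sketch of what such a model might look like; the layer classes (`InputLayer`, `Dense`, `ReLU`) and the callable-layer `.(x)` style are taken from the gem's examples, not from this diff.

```ruby
require "dnn"

# Illustration only: layer names and the .(x) call style are assumptions
# based on the ruby-dnn examples, not confirmed by this diff.
class MLP < DNN::Models::Model
  def initialize
    super
    @l1 = DNN::Layers::Dense.new(256)
    @l2 = DNN::Layers::Dense.new(10)
  end

  # Chain#call is the one method a subclass must implement.
  def call(x)
    x = DNN::Layers::InputLayer.new(784).(x)
    x = @l1.(x)
    x = DNN::Layers::ReLU.new.(x)
    @l2.(x)
  end
end

model = MLP.new
model.layers # => layers discovered from the instance variables (@l1, @l2)
```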
@@ -10,7 +100,7 @@ module DNN
     # @param [String] file_name File name of marshal model to load.
     # @return [DNN::Models::Model] Return the loaded model.
     def self.load(file_name)
-      model =
+      model = self.allocate
       loader = Loaders::MarshalLoader.new(model)
       loader.load(file_name)
       model
@@ -113,8 +203,12 @@ module DNN
          end
 
          if test
-
-
+            acc, loss = if test.is_a?(Array)
+                          evaluate(test[0], test[1], batch_size: batch_size)
+                        else
+                          evaluate_by_iterator(test, batch_size: batch_size)
+                        end
+            print " " + metrics_to_str({ accuracy: acc, test_loss: loss }) if verbose
          end
          puts "" if verbose
          call_callbacks(:after_epoch)
@@ -138,16 +232,6 @@ module DNN
        { loss: loss_value }
      end
 
-      # Implement the test process to be performed.
-      # @param [Numo::SFloat] x Input training data.
-      # @param [Numo::SFloat] y Output training data.
-      # @param [Integer] batch_size Batch size used for one test.
-      # @return [Hash] Hash of contents to be output to log.
-      private def test(x, y, batch_size: 100)
-        acc, test_loss = accuracy(x, y, batch_size: batch_size)
-        { accuracy: acc, test_loss: test_loss }
-      end
-
      # Training once.
      # Setup the model before use this method.
      # @param [Numo::SFloat] x Input training data.
@@ -169,20 +253,24 @@ module DNN
        loss_value
      end
 
-      # Evaluate model and get accuracy of test data.
+      # Evaluate model and get accuracy and loss of test data.
      # @param [Numo::SFloat] x Input test data.
      # @param [Numo::SFloat] y Output test data.
      # @param [Integer] batch_size Batch size used for one test.
      # @return [Array] Returns the test data accuracy and mean loss in the form [accuracy, mean_loss].
-      def
+      def evaluate(x, y, batch_size: 100)
        check_xy_type(x, y)
-
+        evaluate_by_iterator(Iterator.new(x, y, random: false))
+      end
+
+      # Evaluate model by iterator
+      def evaluate_by_iterator(test_iterator, batch_size: 100)
+        num_test_datas = test_iterator.num_datas
        batch_size = batch_size >= num_test_datas[0] ? num_test_datas : batch_size
-        iter = Iterator.new(x, y, random: false)
        total_correct = 0
        sum_loss = 0
        max_steps = (num_test_datas.to_f / batch_size).ceil
-
+        test_iterator.foreach(batch_size) do |x_batch, y_batch|
          correct, loss_value = test_on_batch(x_batch, y_batch)
          total_correct += correct
          sum_loss += loss_value
@@ -201,16 +289,16 @@ module DNN
      def test_on_batch(x, y)
        call_callbacks(:before_test_on_batch)
        x = forward(x, false)
-        correct =
+        correct = accuracy(x, y)
        loss_value = @loss_func.loss(x, y)
        call_callbacks(:after_test_on_batch)
        [correct, loss_value]
      end
 
-      # Implement the process to
+      # Implement the process to accuracy this model.
      # @param [Numo::SFloat] x Input test data.
      # @param [Numo::SFloat] y Output test data.
-      private def
+      private def accuracy(x, y)
        if x.shape[1..-1] == [1]
          correct = 0
          x.shape[0].times do |i|
@@ -257,11 +345,24 @@ module DNN
        @callbacks = []
      end
 
+      # Load marshal params.
+      # @param [String] file_name File name of marshal model to load.
+      def load_params(file_name)
+        loader = Loaders::MarshalLoader.new(self)
+        loader.load(file_name)
+      end
+
      # Save the model in marshal format.
      # @param [String] file_name Name to save model.
-
-
-        saver
+      def save(file_name)
+        saver = Savers::MarshalSaver.new(self, include_model: true)
+        saver.save(file_name)
+      end
+
+      # Save the params in marshal format.
+      # @param [String] file_name Name to save model.
+      def save_params(file_name)
+        saver = Savers::MarshalSaver.new(self, include_model: false)
        saver.save(file_name)
      end
 
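The new persistence helpers split whole-model serialization (`save` / `Model.load`) from parameter-only serialization (`save_params` / `load_params`). A rough usage sketch, assuming `model` is an already trained instance of a concrete model class `MLP` (file names are placeholders):

```ruby
# Save everything (structure + params) vs. params only.
model.save("trained_model.marshal")
model.save_params("trained_params.marshal")

# Restore the whole model; load is defined on the model class itself.
restored = MLP.load("trained_model.marshal")

# Or rebuild the architecture in code and pull in just the weights.
fresh = MLP.new
fresh.load_params("trained_params.marshal")

# evaluate now returns [accuracy, mean_loss].
acc, loss = restored.evaluate(x_test, y_test, batch_size: 100)
```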
@@ -270,37 +371,21 @@ module DNN
        Marshal.load(Marshal.dump(self))
      end
 
-      # Get the all layers.
-      # @return [Array] All layers array.
-      def layers
-        raise DNN_Error, "This model is not built. You need build this model using predict or train." unless built?
-        return @layers_cache if @layers_cache
-        layers = []
-        get_layers = -> link do
-          return unless link
-          layers.unshift(link.layer)
-          if link.is_a?(TwoInputLink)
-            get_layers.(link.prev1)
-            get_layers.(link.prev2)
-          else
-            get_layers.(link.prev)
-          end
-        end
-        get_layers.(@last_link)
-        @layers_cache = layers.uniq
-      end
-
-      # Get the all has param layers.
+      # Get the all trainable layers.
      # @return [Array] All has param layers array.
-      def
-        layers.select { |layer| layer.is_a?(Layers::
+      def trainable_layers
+        layers.select { |layer| layer.is_a?(Layers::TrainableLayer) }
      end
 
      # Get the layer that the model has.
      # @param [Symbol] name The name of the layer to get.
      # @return [DNN::Layers::Layer] Return the layer.
      def get_layer(name)
-
+        layer = instance_variable_get("@#{name}")
+        if layer.is_a?(Layers::Layer) || layer.is_a?(Chain) || layer.is_a?(LayersList)
+          return layer
+        end
+        nil
      end
 
      # @return [Boolean] If model have already been built then return true.
@@ -308,6 +393,31 @@ module DNN
        @built
      end
 
+      def clean_layers
+        layers.each do |layer|
+          layer.clean
+        end
+        @loss_func.clean
+        @last_link = nil
+        @layers_cache = nil
+      end
+
+      def get_all_params_data
+        trainable_layers.map do |layer|
+          layer.get_params.to_h do |key, param|
+            [key, param.data]
+          end
+        end
+      end
+
+      def set_all_params_data(params_data)
+        trainable_layers.each.with_index do |layer, i|
+          params_data[i].each do |(key, data)|
+            layer.get_params[key].data = data
+          end
+        end
+      end
+
      private
 
      def forward(x, learning_phase)
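`get_all_params_data` and `set_all_params_data` expose the raw parameter arrays of every trainable layer, which makes it straightforward to copy weights between two models of the same architecture. A hedged sketch, assuming `src` and `dst` are two models with identical layer layouts:

```ruby
# Per trainable layer, a Hash of { param_key => Numo::SFloat data }.
params_data = src.get_all_params_data

# Layers are matched purely by position, so the architectures must line up.
dst.set_all_params_data(params_data)
```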
@@ -322,7 +432,6 @@ module DNN
        @last_link = output_tensor.link
        unless @built
          @built = true
-          naming
        end
        output_tensor.data
      end
@@ -337,19 +446,6 @@ module DNN
        end
      end
 
-      def naming
-        layers.each do |layer|
-          id = layers.select { |l| l.is_a?(layer.class) }.index(layer)
-          class_name = layer.class.name.split("::").last
-          layer.name = "#{class_name}_#{id}".to_sym unless layer.name
-          if layer.is_a?(Layers::HasParamLayer)
-            layer.get_params.each do |param_key, param|
-              param.name = "#{layer.name}__#{param_key}".to_sym unless param.name
-            end
-          end
-        end
-      end
-
      def metrics_to_str(mertics)
        mertics.map { |key, num| "#{key}: #{sprintf('%.4f', num)}" }.join(", ")
      end
@@ -370,7 +466,7 @@ module DNN
      # @param [Array] stack All layers possessed by the model.
      def initialize(stack = [])
        super()
-        @stack =
+        @stack = LayersList.new
        stack.each do |layer|
          add(layer)
        end
@@ -380,8 +476,8 @@ module DNN
      # @param [DNN::Layers::Layer] layer Layer to add to the model.
      # @return [DNN::Models::Model] Return self.
      def add(layer)
-        if layer.is_a?(
-          raise TypeError, "layer: #{layer.class.name} should not be a DNN::
+        if layer.is_a?(Layers::MergeLayer)
+          raise TypeError, "layer: #{layer.class.name} should not be a DNN::Layers::MergeLayer class."
        end
        unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
          raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
@@ -396,8 +492,8 @@ module DNN
      # @param [DNN::Layers::Layer] layer Layer to add to the model.
      # @return [DNN::Models::Model] Return self.
      def insert(index, layer)
-        if layer.is_a?(
-          raise TypeError, "layer: #{layer.class.name} should not be a DNN::
+        if layer.is_a?(Layers::MergeLayer)
+          raise TypeError, "layer: #{layer.class.name} should not be a DNN::Layers::MergeLayer class."
        end
        unless layer.is_a?(Layers::Layer) || layer.is_a?(Model)
          raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Model class."
data/lib/dnn/core/optimizers.rb
CHANGED
@@ -3,7 +3,6 @@ module DNN
 
     # Super class of all optimizer classes.
     class Optimizer
-      attr_reader :status
       attr_accessor :clip_norm
 
       def self.from_hash(hash)
@@ -15,17 +14,6 @@ module DNN
        optimizer
      end
 
-      def self.load(dumped)
-        opt = from_hash(dumped[:hash])
-        return opt unless dumped[:status]
-        dumped[:status].each do |key, state|
-          state = state.clone
-          opt.status[key] = state
-          opt.instance_variable_set("@#{key}", state)
-        end
-        opt
-      end
-
      # @param [Float | NilClass] clip_norm Gradient clip norm.
      def initialize(clip_norm: nil)
        @clip_norm = clip_norm
@@ -33,7 +21,7 @@ module DNN
 
      # Update layers has params.
      def update(layers)
-        target_params = layers.select { |layer| layer.is_a?(Layers::
+        target_params = layers.select { |layer| layer.is_a?(Layers::TrainableLayer) && layer.trainable }
                               .map { |layer| layer.get_params.values }.flatten.compact
                               .select(&:grad)
        clip_grads(target_params) if @clip_norm
@@ -43,11 +31,6 @@ module DNN
        end
      end
 
-      def dump(require_status = true)
-        status = require_status ? @status : nil
-        { hash: to_hash, status: status }
-      end
-
      def to_hash(merge_hash = nil)
        hash = { class: self.class.name, clip_norm: @clip_norm }
        hash.merge!(merge_hash) if merge_hash
@@ -80,12 +63,11 @@ module DNN
 
      # @param [Float] lr Learning rate.
      # @param [Float] momentum Momentum coefficient.
-      def initialize(lr
+      def initialize(lr: 0.01, momentum: 0, clip_norm: nil)
        super(clip_norm: clip_norm)
        @lr = lr
        @momentum = momentum
        @v = {}
-        @status = { v: @v }
      end
 
      def to_hash
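Note that the optimizer constructors now take the learning rate as a keyword argument (the removed `load_hash` bodies passed `hash[:lr]` positionally, the new ones pass `lr:`), so callers presumably need to change accordingly. A before/after sketch for SGD:

```ruby
# 0.14.x style (positional lr, as implied by the removed load_hash calls):
# opt = DNN::Optimizers::SGD.new(0.01, momentum: 0.9)

# 0.15.0 style (keyword lr):
opt = DNN::Optimizers::SGD.new(lr: 0.01, momentum: 0.9, clip_norm: 1.0)
```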
@@ -96,30 +78,30 @@ module DNN
        params.each do |param|
          amount = param.grad * @lr
          if @momentum > 0
-            @v[param
-            amount += @momentum * @v[param
-            @v[param
+            @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+            amount += @momentum * @v[param]
+            @v[param] = amount
          end
          param.data -= amount
        end
      end
 
      def load_hash(hash)
-        initialize(hash[:lr], momentum: hash[:momentum], clip_norm: hash[:clip_norm])
+        initialize(lr: hash[:lr], momentum: hash[:momentum], clip_norm: hash[:clip_norm])
      end
    end
 
    class Nesterov < SGD
-      def initialize(lr
-        super(lr, momentum: momentum, clip_norm: clip_norm)
+      def initialize(lr: 0.01, momentum: 0.9, clip_norm: nil)
+        super(lr: lr, momentum: momentum, clip_norm: clip_norm)
      end
 
      private def update_params(params)
        params.each do |param|
-          @v[param
+          @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
          amount = param.grad * @lr
-          @v[param
-          param.data = (param.data + @momentum**2 * @v[param
+          @v[param] = @v[param] * @momentum - amount
+          param.data = (param.data + @momentum**2 * @v[param]) - (1 + @momentum) * amount
        end
      end
    end
@@ -130,19 +112,18 @@ module DNN
 
      # @param [Float] lr Learning rate.
      # @param [Float] eps Value to avoid division by zero.
-      def initialize(lr
+      def initialize(lr: 0.01, eps: 1e-7, clip_norm: nil)
        super(clip_norm: clip_norm)
        @lr = lr
        @eps = eps
        @g = {}
-        @status = { g: @g }
      end
 
      private def update_params(params)
        params.each do |param|
-          @g[param
-          @g[param
-          param.data -= (@lr / Xumo::NMath.sqrt(@g[param
+          @g[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @g[param] += param.grad**2
+          param.data -= (@lr / Xumo::NMath.sqrt(@g[param] + @eps)) * param.grad
        end
      end
 
@@ -151,7 +132,7 @@ module DNN
      end
 
      def load_hash(hash)
-        initialize(hash[:lr], eps: hash[:eps], clip_norm: hash[:clip_norm])
+        initialize(lr: hash[:lr], eps: hash[:eps], clip_norm: hash[:clip_norm])
      end
    end
 
@@ -163,13 +144,12 @@ module DNN
      # @param [Float] lr Learning rate.
      # @param [Float] alpha Moving average index of past slopes.
      # @param [Float] eps Value to avoid division by zero.
-      def initialize(lr
+      def initialize(lr: 0.001, alpha: 0.9, eps: 1e-7, clip_norm: nil)
        super(clip_norm: clip_norm)
        @lr = lr
        @alpha = alpha
        @eps = eps
        @g = {}
-        @status = { g: @g }
      end
 
      def to_hash
@@ -178,14 +158,14 @@ module DNN
 
      private def update_params(params)
        params.each do |param|
-          @g[param
-          @g[param
-          param.data -= (@lr / Xumo::NMath.sqrt(@g[param
+          @g[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @g[param] = @alpha * @g[param] + (1 - @alpha) * param.grad**2
+          param.data -= (@lr / Xumo::NMath.sqrt(@g[param] + @eps)) * param.grad
        end
      end
 
      def load_hash(hash)
-        initialize(hash[:lr], alpha: hash[:alpha], eps: hash[:eps], clip_norm: hash[:clip_norm])
+        initialize(lr: hash[:lr], alpha: hash[:alpha], eps: hash[:eps], clip_norm: hash[:clip_norm])
      end
    end
 
@@ -201,7 +181,6 @@ module DNN
        @eps = eps
        @h = {}
        @s = {}
-        @status = { h: @h, s: @s }
      end
 
      def to_hash
@@ -210,11 +189,11 @@ module DNN
 
      private def update_params(params)
        params.each do |param|
-          @h[param
-          @s[param
-          @h[param
-          v = (Xumo::NMath.sqrt(@s[param
-          @s[param
+          @h[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @h[param] = @rho * @h[param] + (1 - @rho) * param.grad**2
+          v = (Xumo::NMath.sqrt(@s[param] + @eps) / Xumo::NMath.sqrt(@h[param] + @eps)) * param.grad
+          @s[param] = @rho * @s[param] + (1 - @rho) * v**2
          param.data -= v
        end
      end
@@ -232,14 +211,13 @@ module DNN
      # @param [Float] lr Learning rate.
      # @param [Float] alpha Moving average index of past slopes.
      # @param [Float] eps Value to avoid division by zero.
-      def initialize(lr
+      def initialize(lr: 0.0001, alpha: 0.95, eps: 0.0001, clip_norm: nil)
        super(clip_norm: clip_norm)
        @lr = lr
        @alpha = alpha
        @eps = eps
        @m = {}
        @v = {}
-        @status = { m: @m, v: @v }
      end
 
      def to_hash
@@ -248,16 +226,16 @@ module DNN
 
      private def update_params(params)
        params.each do |param|
-          @m[param
-          @v[param
-          @m[param
-          @v[param
-          param.data -= (@lr / Xumo::NMath.sqrt(@v[param
+          @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @m[param] = @alpha * @m[param] + (1 - @alpha) * param.grad
+          @v[param] = @alpha * @v[param] + (1 - @alpha) * param.grad**2
+          param.data -= (@lr / Xumo::NMath.sqrt(@v[param] - @m[param]**2 + @eps)) * param.grad
        end
      end
 
      def load_hash(hash)
-        initialize(hash[:lr], alpha: hash[:alpha], eps: hash[:eps], clip_norm: hash[:clip_norm])
+        initialize(lr: hash[:lr], alpha: hash[:alpha], eps: hash[:eps], clip_norm: hash[:clip_norm])
      end
    end
 
@@ -284,7 +262,6 @@ module DNN
        @m = {}
        @v = {}
        @s = amsgrad ? {} : nil
-        @status = { t: @t, m: @m, v: @v, s: @s }
      end
 
      def to_hash
@@ -298,16 +275,16 @@ module DNN
        @t += 1
        lr = @alpha * Math.sqrt(1 - @beta2**@t) / (1 - @beta1**@t)
        params.each do |param|
-          @m[param
-          @v[param
-          @m[param
-          @v[param
+          @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @m[param] += (1 - @beta1) * (param.grad - @m[param])
+          @v[param] += (1 - @beta2) * (param.grad**2 - @v[param])
          if @amsgrad
-            @s[param
-            @s[param
-            param.data -= lr * @m[param
+            @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+            @s[param] = Xumo::SFloat.maximum(@s[param], @v[param])
+            param.data -= lr * @m[param] / Xumo::NMath.sqrt(@s[param] + @eps)
          else
-            param.data -= lr * @m[param
+            param.data -= lr * @m[param] / Xumo::NMath.sqrt(@v[param] + @eps)
          end
        end
      end
@@ -344,16 +321,16 @@ module DNN
        lower_bound = final_lr * (1 - 1 / (@gamma * @t + 1))
        upper_bound = final_lr * (1 + 1 / (@gamma * @t))
        params.each do |param|
-          @m[param
-          @v[param
-          @m[param
-          @v[param
+          @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+          @m[param] += (1 - @beta1) * (param.grad - @m[param])
+          @v[param] += (1 - @beta2) * (param.grad**2 - @v[param])
          if @amsgrad
-            @s[param
-            @s[param
-            param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@s[param
+            @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+            @s[param] = Xumo::SFloat.maximum(@s[param], @v[param])
+            param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@s[param]) + @eps), lower_bound, upper_bound) * @m[param]
          else
-            param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@v[param
+            param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@v[param]) + @eps), lower_bound, upper_bound) * @m[param]
          end
        end
      end
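With `attr_reader :status`, `Optimizer.load`, and `Optimizer#dump` removed, optimizer state is no longer exported as a separate name-keyed hash; the moment hashes (`@m`, `@v`, `@g`, ...) are now keyed directly by the `Param` objects and are persisted, if at all, through the marshal savers. An end-to-end sketch of the 0.15.0 flow, assuming `Model#setup` and `Model#train` keep their existing signatures from earlier releases (not shown in this diff):

```ruby
model = MLP.new
model.setup(DNN::Optimizers::Adam.new(alpha: 0.001),
            DNN::Losses::SoftmaxCrossEntropy.new)

model.train(x_train, y_train, 10,
            batch_size: 128,
            test: [x_test, y_test])

model.save("mlp.marshal") # MarshalSaver with include_model: true
```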