ruby-dnn 1.2.2 → 1.3.0

Files changed (86)
  1. checksums.yaml +4 -4
  2. data/.gitignore +0 -0
  3. data/.travis.yml +0 -0
  4. data/CODE_OF_CONDUCT.md +0 -0
  5. data/Gemfile +0 -0
  6. data/LICENSE.txt +0 -0
  7. data/README.md +0 -0
  8. data/Rakefile +5 -0
  9. data/examples/api-examples/early_stopping_example.rb +0 -0
  10. data/examples/api-examples/initializer_example.rb +0 -0
  11. data/examples/api-examples/regularizer_example.rb +0 -0
  12. data/examples/api-examples/save_example.rb +0 -0
  13. data/examples/cifar100_example.rb +0 -0
  14. data/examples/cifar10_example.rb +0 -0
  15. data/examples/dcgan/dcgan.rb +1 -1
  16. data/examples/dcgan/imgen.rb +0 -0
  17. data/examples/dcgan/train.rb +0 -0
  18. data/examples/iris_example.rb +17 -41
  19. data/examples/iris_example_unused_model.rb +57 -0
  20. data/examples/judge-number/README.md +0 -0
  21. data/examples/judge-number/capture.PNG +0 -0
  22. data/examples/judge-number/convnet8.rb +0 -0
  23. data/examples/judge-number/make_weights.rb +0 -0
  24. data/examples/judge-number/mnist_predict.rb +0 -0
  25. data/examples/judge-number/mnist_train.rb +0 -0
  26. data/examples/judge-number/public/httpRequest.js +0 -0
  27. data/examples/judge-number/public/judgeNumber.js +0 -0
  28. data/examples/judge-number/server.rb +0 -0
  29. data/examples/judge-number/trained_mnist_params.marshal +0 -0
  30. data/examples/judge-number/views/index.erb +0 -0
  31. data/examples/mnist_conv2d_example.rb +0 -0
  32. data/examples/mnist_define_by_run.rb +0 -0
  33. data/examples/mnist_example.rb +0 -0
  34. data/examples/mnist_gpu.rb +0 -0
  35. data/examples/mnist_lstm_example.rb +0 -0
  36. data/examples/pix2pix/dcgan.rb +0 -0
  37. data/examples/pix2pix/imgen.rb +0 -0
  38. data/examples/pix2pix/train.rb +0 -0
  39. data/examples/vae.rb +1 -1
  40. data/examples/xor_example.rb +0 -0
  41. data/ext/rb_stb_image/extconf.rb +0 -0
  42. data/ext/rb_stb_image/rb_stb_image.c +0 -0
  43. data/img/cart-pole.gif +0 -0
  44. data/img/cycle-gan.PNG +0 -0
  45. data/img/facade-pix2pix.png +0 -0
  46. data/lib/dnn/core/callbacks.rb +18 -8
  47. data/lib/dnn/core/error.rb +0 -0
  48. data/lib/dnn/core/global.rb +0 -0
  49. data/lib/dnn/core/initializers.rb +0 -0
  50. data/lib/dnn/core/iterator.rb +20 -4
  51. data/lib/dnn/core/layers/activations.rb +0 -0
  52. data/lib/dnn/core/layers/basic_layers.rb +2 -2
  53. data/lib/dnn/core/layers/cnn_layers.rb +0 -0
  54. data/lib/dnn/core/layers/embedding.rb +0 -0
  55. data/lib/dnn/core/layers/math_layers.rb +0 -0
  56. data/lib/dnn/core/layers/merge_layers.rb +2 -2
  57. data/lib/dnn/core/layers/normalizations.rb +0 -0
  58. data/lib/dnn/core/layers/rnn_layers.rb +20 -24
  59. data/lib/dnn/core/layers/split_layers.rb +0 -0
  60. data/lib/dnn/core/link.rb +0 -0
  61. data/lib/dnn/core/losses.rb +2 -2
  62. data/lib/dnn/core/models.rb +474 -149
  63. data/lib/dnn/core/monkey_patch.rb +0 -0
  64. data/lib/dnn/core/optimizers.rb +0 -0
  65. data/lib/dnn/core/param.rb +0 -0
  66. data/lib/dnn/core/regularizers.rb +0 -0
  67. data/lib/dnn/core/savers.rb +4 -12
  68. data/lib/dnn/core/tensor.rb +0 -0
  69. data/lib/dnn/core/utils.rb +14 -0
  70. data/lib/dnn/datasets/cifar10.rb +0 -0
  71. data/lib/dnn/datasets/cifar100.rb +0 -0
  72. data/lib/dnn/datasets/downloader.rb +12 -3
  73. data/lib/dnn/datasets/fashion-mnist.rb +0 -0
  74. data/lib/dnn/datasets/iris.rb +5 -1
  75. data/lib/dnn/datasets/mnist.rb +0 -0
  76. data/lib/dnn/datasets/stl-10.rb +0 -0
  77. data/lib/dnn/image.rb +1 -1
  78. data/lib/dnn/keras-model-convertor.rb +0 -0
  79. data/lib/dnn/numo2numpy.rb +0 -0
  80. data/lib/dnn/version.rb +1 -1
  81. data/lib/dnn.rb +32 -26
  82. data/ruby-dnn.gemspec +1 -0
  83. data/third_party/stb_image.h +0 -0
  84. data/third_party/stb_image_resize.h +0 -0
  85. data/third_party/stb_image_write.h +0 -0
  86. metadata +21 -6
data/lib/dnn/core/models.rb

@@ -130,6 +130,7 @@ module DNN
       @loss_weights = nil
       @callbacks = []
       @last_log = {}
+      @early_stop_requested = false
     end
 
     def call(input_tensors)
@@ -182,21 +183,24 @@ module DNN
     # @param [Array | NilClass] test If you to test the model for every 1 epoch,
     #                                specify [x_test, y_test]. Don't test to the model, specify nil.
     # @param [Boolean] verbose Set true to display the log. If false is set, the log is not displayed.
-    # @param [Boolean] accuracy Set true to compute the accuracy.
+    # @param [Boolean] need_accuracy Set true to compute the accuracy.
+    # @param [IO] io Specifies the IO object to use for logging.
     def train(x, y, epochs,
               batch_size: 1,
               initial_epoch: 1,
               test: nil,
               verbose: true,
-              accuracy: true)
-      check_xy_type(x, y)
-      train_iterator = Iterator.new(x, y)
-      train_by_iterator(train_iterator, epochs,
-                        batch_size: batch_size,
-                        initial_epoch: initial_epoch,
-                        test: test,
-                        verbose: verbose,
-                        accuracy: accuracy)
+              need_accuracy: true,
+              io: $stdout)
+      trainer = ModelTrainer.new(self)
+      trainer.start_train(x, y, epochs,
+                          batch_size: batch_size,
+                          initial_epoch: initial_epoch,
+                          test: test,
+                          verbose: verbose,
+                          need_accuracy: need_accuracy,
+                          io: io)
+      trainer.update while trainer.training?
     end
 
     alias fit train
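For callers, the visible change in train is the renamed keyword (accuracy: is now need_accuracy:) and the new io: target for log output; the loop itself now runs inside a ModelTrainer (added below). A minimal sketch of the updated call, assuming a model already set up with an optimizer and loss function; the MLP here is illustrative, not from this diff, and x_train/y_train/x_test/y_test stand in for Numo::SFloat arrays:

    require "dnn"

    # Illustrative model; any set-up model is driven the same way.
    model = DNN::Models::Sequential.new([DNN::Layers::InputLayer.new(784),
                                         DNN::Layers::Dense.new(10)])
    model.setup(DNN::Optimizers::Adam.new, DNN::Losses::SoftmaxCrossEntropy.new)

    model.train(x_train, y_train, 10,
                batch_size: 32,
                test: [x_test, y_test],
                need_accuracy: true,                 # renamed from accuracy: in 1.3.0
                io: File.open("train.log", "w"))     # new in 1.3.0; defaults to $stdout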
@@ -210,70 +214,24 @@ module DNN
     # @param [Array | NilClass] test If you to test the model for every 1 epoch,
     #                                specify [x_test, y_test]. Don't test to the model, specify nil.
     # @param [Boolean] verbose Set true to display the log. If false is set, the log is not displayed.
-    # @param [Boolean] accuracy Set true to compute the accuracy.
+    # @param [Boolean] need_accuracy Set true to compute the accuracy.
+    # @param [IO] io Specifies the IO object to use for logging.
     def train_by_iterator(train_iterator, epochs,
                           batch_size: 1,
                           initial_epoch: 1,
                           test: nil,
                           verbose: true,
-                          accuracy: true)
-      raise DNNError, "The model is not optimizer setup complete." unless @optimizer
-      raise DNNError, "The model is not loss_func setup complete." unless @loss_func
-
-      num_train_datas = train_iterator.num_datas
-      num_train_datas = num_train_datas / batch_size * batch_size if train_iterator.last_round_down
-
-      stopped = catch(:stop) do
-        (initial_epoch..epochs).each do |epoch|
-          @last_log[:epoch] = epoch
-          call_callbacks(:before_epoch)
-          puts "【 epoch #{epoch}/#{epochs} 】" if verbose
-
-          train_iterator.foreach(batch_size) do |x_batch, y_batch, index|
-            @last_log[:step] = index
-            train_step_met = train_step(x_batch, y_batch)
-            num_trained_datas = (index + 1) * batch_size
-            num_trained_datas = num_trained_datas > num_train_datas ? num_train_datas : num_trained_datas
-            log = "\r"
-            40.times do |i|
-              if i < num_trained_datas * 40 / num_train_datas
-                log << "="
-              elsif i == num_trained_datas * 40 / num_train_datas
-                log << ">"
-              else
-                log << "_"
-              end
-            end
-
-            log << " #{num_trained_datas}/#{num_train_datas} "
-            log << metrics_to_str(train_step_met)
-            print log if verbose
-          end
-
-          if test
-            acc, loss = if test.is_a?(Array)
-                          evaluate(test[0], test[1], batch_size: batch_size, accuracy: accuracy)
-                        else
-                          evaluate_by_iterator(test, batch_size: batch_size, accuracy: accuracy)
-                        end
-            if verbose
-              metrics = if accuracy
-                          { accuracy: acc, test_loss: loss }
-                        else
-                          { test_loss: loss }
-                        end
-              print " " + metrics_to_str(metrics)
-            end
-          end
-          puts "" if verbose
-          call_callbacks(:after_epoch)
-        end
-        nil
-      end
-
-      if stopped
-        puts "\n#{stopped}" if verbose
-      end
+                          need_accuracy: true,
+                          io: $stdout)
+      trainer = ModelTrainer.new(self)
+      trainer.start_train_by_iterator(train_iterator, epochs,
+                                      batch_size: batch_size,
+                                      initial_epoch: initial_epoch,
+                                      test: test,
+                                      verbose: verbose,
+                                      need_accuracy: need_accuracy,
+                                      io: io)
+      trainer.update while trainer.training?
    end
 
    alias fit_by_iterator train_by_iterator
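train_by_iterator keeps the same shape, with the same two keyword changes. A short sketch, assuming a DNN::Iterator built the way this file does internally:

    it = DNN::Iterator.new(x_train, y_train)
    model.train_by_iterator(it, 10,
                            batch_size: 32,
                            need_accuracy: false)    # skip per-epoch accuracy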
@@ -281,128 +239,154 @@ module DNN
     # Implement the training process to be performed in one step.
     # @param [Numo::SFloat] x Input training data.
     # @param [Numo::SFloat] y Output training data.
+    # @param [Boolean] need_accuracy Set true to compute the accuracy.
     # @return [Hash] Hash of contents to be output to log.
-    private def train_step(x, y)
-      loss_value = train_on_batch(x, y)
-      { loss: loss_value }
+    def train_step(x, y, need_accuracy: false)
+      output_data, loss_data = train_on_batch_internal(x, y)
+      if loss_data.is_a?(Array)
+        loss_value = []
+        acc = [] if need_accuracy
+        loss_data.each_index do |i|
+          loss_value << Utils.to_f(loss_data)
+          acc << accuracy(output_data[i], y[i]).to_f / y[i].shape[0] if need_accuracy
+        end
+      else
+        loss_value = Utils.to_f(loss_data)
+        acc = accuracy(output_data, y).to_f / y.shape[0] if need_accuracy
+      end
+      if need_accuracy
+        { loss: loss_value, accuracy: acc }
+      else
+        { loss: loss_value }
+      end
     end
 
     # Training once.
     # Setup the model before use this method.
     # @param [Numo::SFloat] x Input training data.
     # @param [Numo::SFloat] y Output training data.
-    # @return [Float | Numo::SFloat] Return loss value in the form of Float or Numo::SFloat.
+    # @return [Float | Array] Return loss value in the form of Float or Array.
     def train_on_batch(x, y)
       raise DNNError, "The model is not optimizer setup complete." unless @optimizer
       raise DNNError, "The model is not loss_func setup complete." unless @loss_func
-      check_xy_type(x, y)
-      call_callbacks(:before_train_on_batch)
+      Utils.check_input_data_type("x", x, Xumo::SFloat)
+      Utils.check_input_data_type("y", y, Xumo::SFloat)
+      *, loss_data = train_on_batch_internal(x, y)
+      if loss_data.is_a?(Array)
+        loss_data.map { |v| Utils.to_f(v) }
+      else
+        Utils.to_f(loss_data)
+      end
+    end
+
+    private def train_on_batch_internal(x, y)
       DNN.learning_phase = true
       output_tensors = call(Tensor.convert(x))
       if output_tensors.is_a?(Array)
+        output_data = []
         loss_data = []
         output_tensors.each.with_index do |out, i|
+          output_data << out.data
           loss_opt = {}
           loss_opt[:layers] = layers if i == 0
           loss_opt[:loss_weight] = @loss_weights[i] if @loss_weights
           loss = @loss_func[i].loss(out, Tensor.convert(y[i]), **loss_opt)
-          loss_data << Utils.to_f(loss.data)
+          loss_data << loss.data
           loss.link.backward(Xumo::SFloat.ones(y[i][0...1, false].shape[0], 1))
         end
       else
         out = output_tensors
+        output_data = out.data
         loss = @loss_func.loss(out, Tensor.convert(y), layers: layers)
-        loss_data = Utils.to_f(loss.data)
+        loss_data = loss.data
         loss.link.backward(Xumo::SFloat.ones(y[0...1, false].shape[0], 1))
       end
       @optimizer.update(get_all_trainable_params)
-      @last_log[:train_loss] = loss_data
-      call_callbacks(:after_train_on_batch)
-      loss_data
+      [output_data, loss_data]
     end
 
     # Evaluate model and get accuracy and loss of test data.
     # @param [Numo::SFloat] x Input test data.
     # @param [Numo::SFloat] y Output test data.
     # @param [Integer] batch_size Batch size used for one test.
+    # @param [Boolean] need_accuracy Set true to compute the accuracy.
     # @return [Array] Returns the test data accuracy and mean loss in the form [accuracy, mean_loss].
     #                 If accuracy is not needed returns in the form [nil, mean_loss].
-    def evaluate(x, y, batch_size: 100, accuracy: true)
-      check_xy_type(x, y)
-      evaluate_by_iterator(Iterator.new(x, y, random: false), batch_size: batch_size, accuracy: accuracy)
+    def evaluate(x, y, batch_size: 100, need_accuracy: true)
+      Utils.check_input_data_type("x", x, Xumo::SFloat)
+      Utils.check_input_data_type("y", y, Xumo::SFloat)
+      evaluator = ModelEvaluator.new(self)
+      evaluator.start_evaluate(x, y, batch_size: batch_size, need_accuracy: need_accuracy)
+      evaluator.update while evaluator.evaluating?
+      [@last_log[:test_accuracy], @last_log[:test_loss]]
     end
 
     # Evaluate model by iterator.
     # @param [DNN::Iterator] test_iterator Iterator used for testing.
     # @param [Integer] batch_size Batch size used for one test.
+    # @param [Boolean] need_accuracy Set true to compute the accuracy.
     # @return [Array] Returns the test data accuracy and mean loss in the form [accuracy, mean_loss].
     #                 If accuracy is not needed returns in the form [nil, mean_loss].
-    def evaluate_by_iterator(test_iterator, batch_size: 100, accuracy: true)
-      num_test_datas = test_iterator.num_datas
-      batch_size = batch_size >= num_test_datas ? num_test_datas : batch_size
-      if @loss_func.is_a?(Array)
-        total_correct = Array.new(@loss_func.length, 0)
-        sum_loss = Array.new(@loss_func.length, 0)
-      else
-        total_correct = 0
-        sum_loss = 0
-      end
-      max_steps = (num_test_datas.to_f / batch_size).ceil
-      test_iterator.foreach(batch_size) do |x_batch, y_batch|
-        correct, loss_value = test_on_batch(x_batch, y_batch, accuracy: accuracy)
-        if @loss_func.is_a?(Array)
-          @loss_func.each_index do |i|
-            total_correct[i] += correct[i] if accuracy
-            sum_loss[i] += loss_value[i]
-          end
-        else
-          total_correct += correct if accuracy
-          sum_loss += loss_value
-        end
-      end
-      acc = nil
-      if @loss_func.is_a?(Array)
-        mean_loss = Array.new(@loss_func.length, 0)
-        acc = Array.new(@loss_func.length, 0) if accuracy
-        @loss_func.each_index do |i|
-          mean_loss[i] += sum_loss[i] / max_steps
-          acc[i] += total_correct[i].to_f / num_test_datas if accuracy
+    def evaluate_by_iterator(test_iterator, batch_size: 100, need_accuracy: true)
+      evaluator = ModelEvaluator.new(self)
+      evaluator.start_evaluate_by_iterator(test_iterator, batch_size: batch_size, need_accuracy: need_accuracy)
+      evaluator.update while evaluator.evaluating?
+      [@last_log[:test_accuracy], @last_log[:test_loss]]
+    end
+
+    # Testing process to be performed in one step.
+    # @param [Numo::SFloat] x Input training data.
+    # @param [Numo::SFloat] y Output training data.
+    # @return [Hash] Hash of contents to be output to log.
+    def test_step(x, y, need_accuracy: false)
+      output_data, loss_data = test_on_batch_internal(x, y)
+      if loss_data.is_a?(Array)
+        loss_value = []
+        accuracy = []
+        loss_data.each_index do |i|
+          loss_value << Utils.to_f(loss_data)
+          accuracy << accuracy(output_data[i], y[i]).to_f / y[i].shape[0]
        end
      else
-        mean_loss = sum_loss / max_steps
-        acc = total_correct.to_f / num_test_datas if accuracy
+        loss_value = Utils.to_f(loss_data)
      end
-      @last_log[:test_loss] = mean_loss
-      @last_log[:test_accuracy] = acc
-      [acc, mean_loss]
+      { test_loss: loss_value, test_accuracy: accuracy(output_data, y) }
    end
 
-    # Evaluate once.
+    # Test once.
     # @param [Numo::SFloat | Array] x Input test data.
     # @param [Numo::SFloat | Array] y Output test data.
-    # @return [Array] Returns the test data accuracy and mean loss in the form [accuracy, loss].
-    #                 If accuracy is not needed returns in the form [nil, loss].
-    def test_on_batch(x, y, accuracy: true)
-      call_callbacks(:before_test_on_batch)
+    # @return [Float | Array] Return loss value in the form of Float or Array.
+    def test_on_batch(x, y)
+      raise DNNError, "The model is not loss_func setup complete." unless @loss_func
+      Utils.check_input_data_type("x", x, Xumo::SFloat)
+      Utils.check_input_data_type("y", y, Xumo::SFloat)
+      *, loss_data = test_on_batch_internal(x, y)
+      if loss_data.is_a?(Array)
+        loss_data.map { |v| Utils.to_f(v) }
+      else
+        Utils.to_f(loss_data)
+      end
+    end
+
+    private def test_on_batch_internal(x, y)
       DNN.learning_phase = false
       output_tensors = call(Tensor.convert(x))
-      correct = nil
       if output_tensors.is_a?(Array)
-        correct = [] if accuracy
+        output_data = []
         loss_data = []
         output_tensors.each.with_index do |out, i|
-          correct << accuracy(out.data, y[i]) if accuracy
+          output_data << out.data
          loss = @loss_func[i].(out, Tensor.convert(y[i]))
-          loss_data << Utils.to_f(loss.data)
+          loss_data << loss.data
        end
      else
        out = output_tensors
-        correct = accuracy(out.data, y) if accuracy
+        output_data = out.data
        loss = @loss_func.(out, Tensor.convert(y))
-        loss_data = Utils.to_f(loss.data)
+        loss_data = loss.data
      end
-      call_callbacks(:after_test_on_batch)
-      [correct, loss_data]
+      [output_data, loss_data]
    end
 
    # Implement the process to accuracy this model.
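The batch-level API is narrower than before: train_on_batch and test_on_batch now return only the loss (a Float, or an Array of Floats for multi-output models), no longer run the per-batch callbacks themselves, and no longer accept an accuracy: keyword. Code that destructured [correct, loss] from test_on_batch should move to evaluate or test_step. A sketch of the 1.3.0 shapes, assuming a single-output model set up as above:

    loss = model.train_on_batch(x_batch, y_batch)           # Float
    acc, mean_loss = model.evaluate(x_test, y_test,
                                    need_accuracy: true)    # [accuracy, mean loss]
    met = model.test_step(x_batch, y_batch)                 # { test_loss: ..., test_accuracy: ... }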
@@ -429,7 +413,7 @@ module DNN
     # @param [Numo::SFloat] x Input data.
     # @param [Boolean] use_loss_activation Use loss activation when loss has an activation.
     def predict(x, use_loss_activation: true)
-      check_xy_type(x)
+      Utils.check_input_data_type("x", x, Xumo::SFloat)
       DNN.learning_phase = false
       output_tensors = call(Tensor.convert(x))
       if output_tensors.is_a?(Array)
@@ -454,7 +438,7 @@ module DNN
     # Predict one data.
     # @param [Numo::SFloat] x Input data. However, x is single data.
     def predict1(x, use_loss_activation: true)
-      check_xy_type(x)
+      Utils.check_input_data_type("x", x, Xumo::SFloat)
       input = if x.is_a?(Array)
                 x.map { |v| v.reshape(1, *v.shape) }
               else
@@ -618,7 +602,18 @@ module DNN
       self
     end
 
-    private
+    # Request training early stop.
+    def request_early_stop
+      @early_stop_requested = true
+    end
+
+    def check_early_stop_requested
+      if @early_stop_requested
+        @early_stop_requested = false
+        return true
+      end
+      false
+    end
 
     def get_all_trainable_params
       layers.select { |layer| layer.is_a?(Layers::TrainableLayer) && layer.trainable }
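These two methods replace the catch(:stop)/throw mechanism removed from train_by_iterator above: a callback (or any host code) sets a flag, and the trainer polls it after each step and each epoch. The bundled EarlyStopping callback presumably switches to this flag as well (callbacks.rb also changed in this release). A sketch using LambdaCallback, assuming its pre-1.3.0 signature of an event name plus a block, which may have changed here:

    stop = DNN::Callbacks::LambdaCallback.new(:after_epoch) do
      # Ask the trainer to stop at the next checkpoint.
      model.request_early_stop if model.last_log[:test_loss].to_f < 0.01
    end
    model.add_callback(stop)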
@@ -631,6 +626,245 @@ module DNN
         callback.send(event) if callback.respond_to?(event)
       end
     end
+  end
+
+  class ModelTrainer
+    def initialize(model)
+      @model = model
+      @state = :none
+      @initial_epoch = 1
+      @step = 1
+      @max_steps = 1
+      @train_iterator = nil
+      @max_epochs = 1
+      @batch_size = 1
+      @epoch = 1
+      @test = nil
+      @verbose = false
+      @need_accuracy = false
+      @io = nil
+      @num_train_datas = 0
+    end
+
+    # Start training.
+    # Setup the model before use this method.
+    # @param [Numo::SFloat] x Input training data.
+    # @param [Numo::SFloat] y Output training data.
+    # @param [Integer] epochs Number of training.
+    # @param [Integer] batch_size Batch size used for one training.
+    # @param [Integer] initial_epoch Initial epoch.
+    # @param [Array | NilClass] test If you to test the model for every 1 epoch,
+    #                                specify [x_test, y_test]. Don't test to the model, specify nil.
+    # @param [Boolean] verbose Set true to display the log. If false is set, the log is not displayed.
+    # @param [Boolean] need_accuracy Set true to compute the accuracy.
+    # @param [IO] io Specifies the IO object to use for logging.
+    def start_train(x, y, epochs,
+                    batch_size: 1,
+                    initial_epoch: 1,
+                    test: nil,
+                    verbose: true,
+                    need_accuracy: true,
+                    io: $stdout)
+      Utils.check_input_data_type("x", x, Xumo::SFloat)
+      Utils.check_input_data_type("y", y, Xumo::SFloat)
+      train_iterator = Iterator.new(x, y)
+      start_train_by_iterator(train_iterator, epochs,
+                              batch_size: batch_size,
+                              initial_epoch: initial_epoch,
+                              test: test,
+                              verbose: verbose,
+                              need_accuracy: need_accuracy,
+                              io: io)
+    end
+
+    # Start training by iterator.
+    # Setup the model before use this method.
+    # @param [DNN::Iterator] train_iterator Iterator used for training.
+    # @param [Integer] epochs Number of training.
+    # @param [Integer] batch_size Batch size used for one training.
+    # @param [Integer] initial_epoch Initial epoch.
+    # @param [Array | NilClass] test If you to test the model for every 1 epoch,
+    #                                specify [x_test, y_test]. Don't test to the model, specify nil.
+    # @param [Boolean] verbose Set true to display the log. If false is set, the log is not displayed.
+    # @param [Boolean] need_accuracy Set true to compute the accuracy.
+    # @param [IO] io Specifies the IO object to use for logging.
+    def start_train_by_iterator(train_iterator, epochs,
+                                batch_size: 1,
+                                initial_epoch: 1,
+                                test: nil,
+                                verbose: true,
+                                need_accuracy: true,
+                                io: $stdout)
+      raise DNNError, "The model is not optimizer setup complete." unless @model.optimizer
+      raise DNNError, "The model is not loss_func setup complete." unless @model.loss_func
+      @model.check_early_stop_requested # Clear early stop request.
+      @train_iterator = train_iterator
+      @max_epochs = epochs
+      @batch_size = batch_size
+      @epoch = initial_epoch
+      @test = test
+      @verbose = verbose
+      @need_accuracy = need_accuracy
+      @io = io
+      @state = :start_epoch
+      @max_steps = train_iterator.max_steps(batch_size)
+      @num_train_datas = train_iterator.num_usable_datas(batch_size)
+      @line_first_pos = 0
+      @model.call_callbacks(:before_train)
+    end
+
+    # Check if it is currently evaluating.
+    # @return [Boolean] Returns true if currently training.
+    def training?
+      @state != :none
+    end
+
+    # Update trainer.
+    def update
+      case @state
+      when :start_epoch
+        start_epoch
+      when :start_step
+        start_step
+      when :train_step
+        train_step
+      when :end_step
+        end_step
+      when :end_epoch
+        end_epoch
+      when :start_evaluate
+        start_evaluate
+      when :evaluating
+        evaluating
+      when :end_evaluate
+        end_evaluate
+      when :end_training
+        end_training
+      end
+    end
+
+    private
+
+    def start_epoch
+      @model.last_log[:epoch] = @epoch
+      @model.call_callbacks(:before_epoch)
+      @io.puts "【 epoch #{@epoch}/#{@max_epochs} 】" if @verbose
+      @step = 1
+      @state = :start_step
+    end
+
+    def start_step
+      @model.last_log[:step] = @step
+      @state = :train_step
+    end
+
+    def train_step
+      (x_batch, y_batch) = @train_iterator.next_batch(@batch_size)
+      @model.call_callbacks(:before_train_on_batch)
+      train_step_met = @model.train_step(x_batch, y_batch, need_accuracy: @need_accuracy)
+      @model.last_log.merge!(train_step_met)
+      @model.call_callbacks(:after_train_on_batch)
+      num_trained_datas = @step * @batch_size
+      num_trained_datas = num_trained_datas > @num_train_datas ? @num_train_datas : num_trained_datas
+      if @io == $stdout
+        log = "\r"
+      else
+        @line_first_pos = @io.pos
+        log = ""
+      end
+      40.times do |i|
+        if i < num_trained_datas * 40 / @num_train_datas
+          log << "="
+        elsif i == num_trained_datas * 40 / @num_train_datas
+          log << ">"
+        else
+          log << "_"
+        end
+      end
+      log << " #{num_trained_datas}/#{@num_train_datas} "
+      log << metrics_to_str(train_step_met)
+      @io.print log if @verbose
+      if @model.check_early_stop_requested
+        @io.puts("\nEarly stopped.") if @verbose
+        @state = :end_training
+      else
+        @state = :end_step
+      end
+    end
+
+    def end_step
+      @step += 1
+      if @step <= @max_steps
+        unless @io == $stdout
+          @io.pos = @line_first_pos
+        end
+        @state = :start_step
+      else
+        @state = :end_epoch
+      end
+    end
+
+    def end_epoch
+      @epoch += 1
+      if @test
+        @state = :start_evaluate
+      else
+        @io.puts "" if @verbose
+        @model.call_callbacks(:after_epoch)
+        if @epoch <= @max_epochs
+          @train_iterator.reset
+          @state = :start_epoch
+        else
+          @state = :none
+        end
+      end
+    end
+
+    def start_evaluate
+      @evaluator = ModelEvaluator.new(@model)
+      if @test.is_a?(Array)
+        @evaluator.start_evaluate(@test[0], @test[1], batch_size: @batch_size, need_accuracy: @need_accuracy)
+      else
+        @evaluator.start_evaluate_by_iterator(@test, batch_size: @batch_size, need_accuracy: @need_accuracy)
+      end
+      @state = :evaluating
+    end
+
+    def evaluating
+      @evaluator.update
+      unless @evaluator.evaluating?
+        @state = :end_evaluate
+      end
+    end
+
+    def end_evaluate
+      if @verbose
+        metrics = if @need_accuracy
+                    { test_accuracy: @model.last_log[:test_accuracy], test_loss: @model.last_log[:test_loss] }
+                  else
+                    { test_loss: @model.last_log[:test_loss] }
+                  end
+        @io.print " " + metrics_to_str(metrics)
+      end
+      @io.puts "" if @verbose
+      @model.call_callbacks(:after_epoch)
+      if @epoch <= @max_epochs
+        @train_iterator.reset
+        if @model.check_early_stop_requested
+          @io.puts("Early stopped.") if @verbose
+          @state = :end_training
+        else
+          @state = :start_epoch
+        end
+      else
+        @state = :end_training
+      end
+    end
+
+    def end_training
+      @model.call_callbacks(:after_train)
+      @state = :none
+    end
 
     def metrics_to_str(mertics)
       mertics.map { |key, values|
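The rewrite turns the monolithic training loop into an explicit state machine (:start_epoch, :start_step, :train_step, :end_step, :end_epoch, the evaluate states, :end_training), where each call to update advances exactly one state. That makes training pumpable from an external event loop, such as a GUI. A sketch of driving it by hand, equivalent to what Model#train now does internally, assuming ModelTrainer shares the DNN::Models namespace with Model:

    trainer = DNN::Models::ModelTrainer.new(model)
    trainer.start_train(x_train, y_train, 10, batch_size: 32)
    while trainer.training?
      trainer.update        # one batch, one log line, or one state transition
      poll_gui_events       # hypothetical host-application hook
    end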
@@ -643,28 +877,119 @@ module DNN
         "#{key}: #{str_values}"
       }.join(", ")
     end
+  end
+
+  class ModelEvaluator
+    def initialize(model)
+      @model = model
+      @state = :none
+    end
 
-    def check_xy_type(x, y = nil)
-      if !x.is_a?(Xumo::SFloat) && !x.is_a?(Array)
-        raise TypeError, "x:#{x.class.name} is not an instance of #{Xumo::SFloat.name} class or Array class."
+    # Start evaluate model and get accuracy and loss of test data.
+    # @param [Numo::SFloat] x Input test data.
+    # @param [Numo::SFloat] y Output test data.
+    # @param [Integer] batch_size Batch size used for one test.
+    # @param [Boolean] need_accuracy Set true to compute the accuracy.
+    # @return [Array] Returns the test data accuracy and mean loss in the form [accuracy, mean_loss].
+    #                 If accuracy is not needed returns in the form [nil, mean_loss].
+    def start_evaluate(x, y, batch_size: 100, need_accuracy: true)
+      Utils.check_input_data_type("x", x, Xumo::SFloat)
+      Utils.check_input_data_type("y", y, Xumo::SFloat)
+      start_evaluate_by_iterator(Iterator.new(x, y, random: false), batch_size: batch_size, need_accuracy: need_accuracy)
+    end
+
+    # Start Evaluate model by iterator.
+    # @param [DNN::Iterator] test_iterator Iterator used for testing.
+    # @param [Integer] batch_size Batch size used for one test.
+    # @param [Boolean] need_accuracy Set true to compute the accuracy.
+    # @return [Array] Returns the test data accuracy and mean loss in the form [accuracy, mean_loss].
+    #                 If accuracy is not needed returns in the form [nil, mean_loss].
+    def start_evaluate_by_iterator(test_iterator, batch_size: 100, need_accuracy: true)
+      @test_iterator = test_iterator
+      @num_test_datas = test_iterator.num_datas
+      @batch_size = batch_size >= @num_test_datas ? @num_test_datas : batch_size
+      @need_accuracy = need_accuracy
+      if @loss_func.is_a?(Array)
+        @total_correct = Array.new(@loss_func.length, 0)
+        @sum_loss = Array.new(@loss_func.length, 0)
+      else
+        @total_correct = 0
+        @sum_loss = 0
       end
-      if x.is_a?(Array)
-        x.each.with_index do |v, i|
-          unless v.is_a?(Xumo::SFloat)
-            raise TypeError, "x[#{i}]:#{v.class.name} is not an instance of #{Xumo::SFloat.name} class."
-          end
+      @step = 1
+      @max_steps = (@num_test_datas.to_f / @batch_size).ceil
+      @state = :start_step
+    end
+
+    # Check if it is currently evaluating.
+    # @return [Boolean] Returns true if currently evaluating.
+    def evaluating?
+      @state != :none
+    end
+
+    # Update evaluator.
+    def update
+      case @state
+      when :start_step
+        start_step
+      when :test_step
+        test_step
+      when :end_step
+        end_step
+      when :end_evaluate
+        end_evaluate
+      end
+    end
+
+    private
+
+    def start_step
+      @model.last_log[:step] = @step
+      @state = :test_step
+    end
+
+    def test_step
+      (x_batch, y_batch) = @test_iterator.next_batch(@batch_size)
+      @model.call_callbacks(:before_test_on_batch)
+      test_met = @model.test_step(x_batch, y_batch, need_accuracy: @need_accuracy)
+      @model.call_callbacks(:after_test_on_batch)
+      if @loss_func.is_a?(Array)
+        @loss_func.each_index do |i|
+          @total_correct[i] += test_met[:test_accuracy][i] if @need_accuracy
+          @sum_loss[i] += test_met[:test_loss][i]
        end
+      else
+        @total_correct += test_met[:test_accuracy] if @need_accuracy
+        @sum_loss += test_met[:test_loss]
      end
-      if y && !y.is_a?(Xumo::SFloat) && !y.is_a?(Array)
-        raise TypeError, "y:#{y.class.name} is not an instance of #{Xumo::SFloat.name} class or Array class."
+      @state = :end_step
+    end
+
+    def end_step
+      @step += 1
+      if @step <= @max_steps
+        @state = :start_step
+      else
+        @state = :end_evaluate
      end
-      if y.is_a?(Array)
-        y.each.with_index do |v, i|
-          unless v.is_a?(Xumo::SFloat)
-            raise TypeError, "x[#{i}]:#{v.class.name} is not an instance of #{Xumo::SFloat.name} class."
-          end
+    end
+
+    def end_evaluate
+      acc = nil
+      if @loss_func.is_a?(Array)
+        mean_loss = Array.new(@loss_func.length, 0)
+        acc = Array.new(@loss_func.length, 0) if @need_accuracy
+        @loss_func.each_index do |i|
+          mean_loss[i] += @sum_loss[i] / @max_steps
+          acc[i] += @total_correct[i].to_f / @num_test_datas if @need_accuracy
        end
+      else
+        mean_loss = @sum_loss / @max_steps
+        acc = @total_correct.to_f / @num_test_datas if @need_accuracy
      end
+      @model.last_log[:test_loss] = mean_loss
+      @model.last_log[:test_accuracy] = acc
+      @state = :none
    end
  end
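ModelEvaluator follows the same pattern, and evaluate/evaluate_by_iterator above are now thin wrappers that pump it to completion and read the results back out of model.last_log. A sketch of the same loop driven by hand, with the same namespace assumption as the trainer example:

    evaluator = DNN::Models::ModelEvaluator.new(model)
    evaluator.start_evaluate(x_test, y_test, batch_size: 100, need_accuracy: true)
    evaluator.update while evaluator.evaluating?    # one test batch per update
    p model.last_log[:test_accuracy], model.last_log[:test_loss]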