ruby-dnn 1.2.3 → 1.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -130,6 +130,7 @@ module DNN
130
130
  @loss_weights = nil
131
131
  @callbacks = []
132
132
  @last_log = {}
133
+ @early_stop_requested = false
133
134
  end
134
135
 
135
136
  def call(input_tensors)
@@ -182,21 +183,24 @@ module DNN
182
183
  # @param [Array | NilClass] test If you want to test the model after every epoch,
183
184
  # specify [x_test, y_test]. If you don't want to test the model, specify nil.
184
185
  # @param [Boolean] verbose Set true to display the log. If false is set, the log is not displayed.
185
- # @param [Boolean] accuracy Set true to compute the accuracy.
186
+ # @param [Boolean] need_accuracy Set true to compute the accuracy.
187
+ # @param [IO] io Specifies the IO object to use for logging.
186
188
  def train(x, y, epochs,
187
189
  batch_size: 1,
188
190
  initial_epoch: 1,
189
191
  test: nil,
190
192
  verbose: true,
191
- accuracy: true)
192
- check_xy_type(x, y)
193
- train_iterator = Iterator.new(x, y)
194
- train_by_iterator(train_iterator, epochs,
195
- batch_size: batch_size,
196
- initial_epoch: initial_epoch,
197
- test: test,
198
- verbose: verbose,
199
- accuracy: accuracy)
193
+ need_accuracy: true,
194
+ io: $stdout)
195
+ trainer = ModelTrainer.new(self)
196
+ trainer.start_train(x, y, epochs,
197
+ batch_size: batch_size,
198
+ initial_epoch: initial_epoch,
199
+ test: test,
200
+ verbose: verbose,
201
+ need_accuracy: need_accuracy,
202
+ io: io)
203
+ trainer.update while trainer.training?
200
204
  end
201
205
 
202
206
  alias fit train
@@ -210,70 +214,24 @@ module DNN
210
214
  # @param [Array | NilClass] test If you want to test the model after every epoch,
211
215
  # specify [x_test, y_test]. If you don't want to test the model, specify nil.
212
216
  # @param [Boolean] verbose Set true to display the log. If false is set, the log is not displayed.
213
- # @param [Boolean] accuracy Set true to compute the accuracy.
217
+ # @param [Boolean] need_accuracy Set true to compute the accuracy.
218
+ # @param [IO] io Specifies the IO object to use for logging.
214
219
  def train_by_iterator(train_iterator, epochs,
215
220
  batch_size: 1,
216
221
  initial_epoch: 1,
217
222
  test: nil,
218
223
  verbose: true,
219
- accuracy: true)
220
- raise DNNError, "The model is not optimizer setup complete." unless @optimizer
221
- raise DNNError, "The model is not loss_func setup complete." unless @loss_func
222
-
223
- num_train_datas = train_iterator.num_datas
224
- num_train_datas = num_train_datas / batch_size * batch_size if train_iterator.last_round_down
225
-
226
- stopped = catch(:stop) do
227
- (initial_epoch..epochs).each do |epoch|
228
- @last_log[:epoch] = epoch
229
- call_callbacks(:before_epoch)
230
- puts "【 epoch #{epoch}/#{epochs} 】" if verbose
231
-
232
- train_iterator.foreach(batch_size) do |x_batch, y_batch, index|
233
- @last_log[:step] = index
234
- train_step_met = train_step(x_batch, y_batch)
235
- num_trained_datas = (index + 1) * batch_size
236
- num_trained_datas = num_trained_datas > num_train_datas ? num_train_datas : num_trained_datas
237
- log = "\r"
238
- 40.times do |i|
239
- if i < num_trained_datas * 40 / num_train_datas
240
- log << "="
241
- elsif i == num_trained_datas * 40 / num_train_datas
242
- log << ">"
243
- else
244
- log << "_"
245
- end
246
- end
247
-
248
- log << " #{num_trained_datas}/#{num_train_datas} "
249
- log << metrics_to_str(train_step_met)
250
- print log if verbose
251
- end
252
-
253
- if test
254
- acc, loss = if test.is_a?(Array)
255
- evaluate(test[0], test[1], batch_size: batch_size, accuracy: accuracy)
256
- else
257
- evaluate_by_iterator(test, batch_size: batch_size, accuracy: accuracy)
258
- end
259
- if verbose
260
- metrics = if accuracy
261
- { accuracy: acc, test_loss: loss }
262
- else
263
- { test_loss: loss }
264
- end
265
- print " " + metrics_to_str(metrics)
266
- end
267
- end
268
- puts "" if verbose
269
- call_callbacks(:after_epoch)
270
- end
271
- nil
272
- end
273
-
274
- if stopped
275
- puts "\n#{stopped}" if verbose
276
- end
224
+ need_accuracy: true,
225
+ io: $stdout)
226
+ trainer = ModelTrainer.new(self)
227
+ trainer.start_train_by_iterator(train_iterator, epochs,
228
+ batch_size: batch_size,
229
+ initial_epoch: initial_epoch,
230
+ test: test,
231
+ verbose: verbose,
232
+ need_accuracy: need_accuracy,
233
+ io: io)
234
+ trainer.update while trainer.training?
277
235
  end
278
236
 
279
237
  alias fit_by_iterator train_by_iterator
@@ -281,128 +239,154 @@ module DNN
281
239
  # Implement the training process to be performed in one step.
282
240
  # @param [Numo::SFloat] x Input training data.
283
241
  # @param [Numo::SFloat] y Output training data.
242
+ # @param [Boolean] need_accuracy Set true to compute the accuracy.
284
243
  # @return [Hash] Hash of contents to be output to log.
285
- private def train_step(x, y)
286
- loss_value = train_on_batch(x, y)
287
- { loss: loss_value }
244
def train_step(x, y, need_accuracy: false)
  output_data, loss_data = train_on_batch_internal(x, y)
  if loss_data.is_a?(Array)
    # Multi-output model: convert each loss individually. Passing the whole
    # array to Utils.to_f (as the code previously did) is wrong.
    loss_value = loss_data.map { |loss| Utils.to_f(loss) }
    if need_accuracy
      # Per-output accuracy as the result of `accuracy` divided by the batch size.
      acc = loss_data.each_index.map do |i|
        accuracy(output_data[i], y[i]).to_f / y[i].shape[0]
      end
    end
  else
    loss_value = Utils.to_f(loss_data)
    acc = accuracy(output_data, y).to_f / y.shape[0] if need_accuracy
  end
  # The :accuracy key is only present when accuracy was requested.
  need_accuracy ? { loss: loss_value, accuracy: acc } : { loss: loss_value }
end
289
263
 
290
264
  # Training once.
291
265
  # Setup the model before use this method.
292
266
  # @param [Numo::SFloat] x Input training data.
293
267
  # @param [Numo::SFloat] y Output training data.
294
- # @return [Float | Numo::SFloat] Return loss value in the form of Float or Numo::SFloat.
268
+ # @return [Float | Array] Return loss value in the form of Float or Array.
295
269
  def train_on_batch(x, y)
296
270
  raise DNNError, "The model is not optimizer setup complete." unless @optimizer
297
271
  raise DNNError, "The model is not loss_func setup complete." unless @loss_func
298
- check_xy_type(x, y)
299
- call_callbacks(:before_train_on_batch)
272
+ Utils.check_input_data_type("x", x, Xumo::SFloat)
273
+ Utils.check_input_data_type("y", y, Xumo::SFloat)
274
+ *, loss_data = train_on_batch_internal(x, y)
275
+ if loss_data.is_a?(Array)
276
+ loss_data.map { |v| Utils.to_f(v) }
277
+ else
278
+ Utils.to_f(loss_data)
279
+ end
280
+ end
281
+
282
# Runs one forward/backward pass and one optimizer update.
# Returns [output_data, loss_data]; both are Arrays when the model has several outputs.
private def train_on_batch_internal(x, y)
  DNN.learning_phase = true
  outputs = call(Tensor.convert(x))
  if outputs.is_a?(Array)
    output_data = []
    loss_data = []
    outputs.each_with_index do |out, idx|
      output_data << out.data
      # Only the first loss is given the layers (presumably so layer-dependent
      # terms are added once - confirm against the loss implementation).
      loss_opt = {}
      loss_opt[:layers] = layers if idx == 0
      loss_opt[:loss_weight] = @loss_weights[idx] if @loss_weights
      loss = @loss_func[idx].loss(out, Tensor.convert(y[idx]), **loss_opt)
      loss_data << loss.data
      seed_rows = y[idx][0...1, false].shape[0]
      loss.link.backward(Xumo::SFloat.ones(seed_rows, 1))
    end
  else
    output_data = outputs.data
    loss = @loss_func.loss(outputs, Tensor.convert(y), layers: layers)
    loss_data = loss.data
    seed_rows = y[0...1, false].shape[0]
    loss.link.backward(Xumo::SFloat.ones(seed_rows, 1))
  end
  @optimizer.update(get_all_trainable_params)
  [output_data, loss_data]
end
323
307
 
324
308
  # Evaluate model and get accuracy and loss of test data.
325
309
  # @param [Numo::SFloat] x Input test data.
326
310
  # @param [Numo::SFloat] y Output test data.
327
311
  # @param [Integer] batch_size Batch size used for one test.
312
+ # @param [Boolean] need_accuracy Set true to compute the accuracy.
328
313
  # @return [Array] Returns the test data accuracy and mean loss in the form [accuracy, mean_loss].
329
314
  # If accuracy is not needed returns in the form [nil, mean_loss].
330
- def evaluate(x, y, batch_size: 100, accuracy: true)
331
- check_xy_type(x, y)
332
- evaluate_by_iterator(Iterator.new(x, y, random: false), batch_size: batch_size, accuracy: accuracy)
315
def evaluate(x, y, batch_size: 100, need_accuracy: true)
  # Validate both inputs with the same checker, x first and then y.
  [["x", x], ["y", y]].each do |label, data|
    Utils.check_input_data_type(label, data, Xumo::SFloat)
  end
  runner = ModelEvaluator.new(self)
  runner.start_evaluate(x, y, batch_size: batch_size, need_accuracy: need_accuracy)
  runner.update while runner.evaluating?
  # The result is read back from the shared log as [accuracy, loss].
  @last_log.values_at(:test_accuracy, :test_loss)
end
334
323
 
335
324
  # Evaluate model by iterator.
336
325
  # @param [DNN::Iterator] test_iterator Iterator used for testing.
337
326
  # @param [Integer] batch_size Batch size used for one test.
327
+ # @param [Boolean] need_accuracy Set true to compute the accuracy.
338
328
  # @return [Array] Returns the test data accuracy and mean loss in the form [accuracy, mean_loss].
339
329
  # If accuracy is not needed returns in the form [nil, mean_loss].
340
- def evaluate_by_iterator(test_iterator, batch_size: 100, accuracy: true)
341
- num_test_datas = test_iterator.num_datas
342
- batch_size = batch_size >= num_test_datas ? num_test_datas : batch_size
343
- if @loss_func.is_a?(Array)
344
- total_correct = Array.new(@loss_func.length, 0)
345
- sum_loss = Array.new(@loss_func.length, 0)
346
- else
347
- total_correct = 0
348
- sum_loss = 0
349
- end
350
- max_steps = (num_test_datas.to_f / batch_size).ceil
351
- test_iterator.foreach(batch_size) do |x_batch, y_batch|
352
- correct, loss_value = test_on_batch(x_batch, y_batch, accuracy: accuracy)
353
- if @loss_func.is_a?(Array)
354
- @loss_func.each_index do |i|
355
- total_correct[i] += correct[i] if accuracy
356
- sum_loss[i] += loss_value[i]
357
- end
358
- else
359
- total_correct += correct if accuracy
360
- sum_loss += loss_value
361
- end
362
- end
363
- acc = nil
364
- if @loss_func.is_a?(Array)
365
- mean_loss = Array.new(@loss_func.length, 0)
366
- acc = Array.new(@loss_func.length, 0) if accuracy
367
- @loss_func.each_index do |i|
368
- mean_loss[i] += sum_loss[i] / max_steps
369
- acc[i] += total_correct[i].to_f / num_test_datas if accuracy
330
def evaluate_by_iterator(test_iterator, batch_size: 100, need_accuracy: true)
  runner = ModelEvaluator.new(self)
  runner.start_evaluate_by_iterator(test_iterator, batch_size: batch_size, need_accuracy: need_accuracy)
  # Step the evaluator until it reports that it has finished.
  while runner.evaluating?
    runner.update
  end
  # The result is read back from the shared log as [accuracy, loss].
  @last_log.values_at(:test_accuracy, :test_loss)
end
336
+
337
+ # Testing process to be performed in one step.
338
+ # @param [Numo::SFloat] x Input training data.
339
+ # @param [Numo::SFloat] y Output training data.
340
+ # @return [Hash] Hash of contents to be output to log.
341
def test_step(x, y, need_accuracy: false)
  output_data, loss_data = test_on_batch_internal(x, y)
  if loss_data.is_a?(Array)
    # Multi-output model: convert each loss individually (the whole array was
    # previously passed to Utils.to_f by mistake).
    loss_value = loss_data.map { |loss| Utils.to_f(loss) }
    # One accuracy result per output, in the same unit as the single-output
    # branch (the raw return value of `accuracy`), so callers can treat both
    # branches uniformly.
    acc = loss_data.each_index.map { |i| accuracy(output_data[i], y[i]) } if need_accuracy
  else
    loss_value = Utils.to_f(loss_data)
    acc = accuracy(output_data, y) if need_accuracy
  end
  # :test_accuracy is always present; it is nil when accuracy was not requested.
  { test_loss: loss_value, test_accuracy: acc }
end
379
355
 
380
- # Evaluate once.
356
+ # Test once.
381
357
  # @param [Numo::SFloat | Array] x Input test data.
382
358
  # @param [Numo::SFloat | Array] y Output test data.
383
- # @return [Array] Returns the test data accuracy and mean loss in the form [accuracy, loss].
384
- # If accuracy is not needed returns in the form [nil, loss].
385
- def test_on_batch(x, y, accuracy: true)
386
- call_callbacks(:before_test_on_batch)
359
+ # @return [Float | Array] Return loss value in the form of Float or Array.
360
def test_on_batch(x, y)
  raise DNNError, "The model is not loss_func setup complete." unless @loss_func
  # Validate both inputs with the same checker, x first and then y.
  [["x", x], ["y", y]].each do |label, data|
    Utils.check_input_data_type(label, data, Xumo::SFloat)
  end
  # Only the loss part of the internal result is exposed to the caller.
  raw_loss = test_on_batch_internal(x, y).last
  return Utils.to_f(raw_loss) unless raw_loss.is_a?(Array)

  raw_loss.map { |loss| Utils.to_f(loss) }
end
371
+
372
# Runs one forward pass with the learning phase switched off and computes the loss.
# Returns [output_data, loss_data]; both are Arrays when the model has several outputs.
private def test_on_batch_internal(x, y)
  DNN.learning_phase = false
  outputs = call(Tensor.convert(x))
  unless outputs.is_a?(Array)
    loss = @loss_func.(outputs, Tensor.convert(y))
    return [outputs.data, loss.data]
  end

  output_data = []
  loss_data = []
  outputs.each_with_index do |out, idx|
    output_data << out.data
    loss = @loss_func[idx].(out, Tensor.convert(y[idx]))
    loss_data << loss.data
  end
  [output_data, loss_data]
end
407
391
 
408
392
  # Implement the process to accuracy this model.
@@ -429,7 +413,7 @@ module DNN
429
413
  # @param [Numo::SFloat] x Input data.
430
414
  # @param [Boolean] use_loss_activation Use loss activation when loss has an activation.
431
415
  def predict(x, use_loss_activation: true)
432
- check_xy_type(x)
416
+ Utils.check_input_data_type("x", x, Xumo::SFloat)
433
417
  DNN.learning_phase = false
434
418
  output_tensors = call(Tensor.convert(x))
435
419
  if output_tensors.is_a?(Array)
@@ -454,7 +438,7 @@ module DNN
454
438
  # Predict one data.
455
439
  # @param [Numo::SFloat] x Input data. However, x is single data.
456
440
  def predict1(x, use_loss_activation: true)
457
- check_xy_type(x)
441
+ Utils.check_input_data_type("x", x, Xumo::SFloat)
458
442
  input = if x.is_a?(Array)
459
443
  x.map { |v| v.reshape(1, *v.shape) }
460
444
  else
@@ -618,7 +602,18 @@ module DNN
618
602
  self
619
603
  end
620
604
 
621
- private
605
+ # Request training early stop.
606
def request_early_stop
  # Only raises the flag; it is consumed later by check_early_stop_requested.
  @early_stop_requested = true
end
609
+
610
# Reports whether an early stop was requested and clears the request.
# @return [Boolean] true if a request was pending (it is reset to false), otherwise false.
def check_early_stop_requested
  return false unless @early_stop_requested

  # Consume the request so that it is reported only once.
  @early_stop_requested = false
  true
end
622
617
 
623
618
  def get_all_trainable_params
624
619
  layers.select { |layer| layer.is_a?(Layers::TrainableLayer) && layer.trainable }
@@ -631,6 +626,245 @@ module DNN
631
626
  callback.send(event) if callback.respond_to?(event)
632
627
  end
633
628
  end
629
+ end
630
+
631
+ class ModelTrainer
632
# Creates an idle trainer bound to the given model.
# @param [Object] model The model whose training this object drives.
def initialize(model)
  @model = model
  # :none is the idle state (training? compares against it).
  @state = :none
  # Epoch/step counters and their limits all start at 1.
  @initial_epoch = @epoch = @step = 1
  @max_epochs = @max_steps = @batch_size = 1
  @num_train_datas = 0
  # Collaborators and options are filled in when training is started.
  @train_iterator = @test = @io = nil
  @verbose = @need_accuracy = false
end
648
+
649
+ # Start training.
650
+ # Setup the model before use this method.
651
+ # @param [Numo::SFloat] x Input training data.
652
+ # @param [Numo::SFloat] y Output training data.
653
+ # @param [Integer] epochs Number of training.
654
+ # @param [Integer] batch_size Batch size used for one training.
655
+ # @param [Integer] initial_epoch Initial epoch.
656
+ # @param [Array | NilClass] test If you to test the model for every 1 epoch,
657
+ # specify [x_test, y_test]. Don't test to the model, specify nil.
658
+ # @param [Boolean] verbose Set true to display the log. If false is set, the log is not displayed.
659
+ # @param [Boolean] need_accuracy Set true to compute the accuracy.
660
+ # @param [IO] io Specifies the IO object to use for logging.
661
def start_train(x, y, epochs,
                batch_size: 1,
                initial_epoch: 1,
                test: nil,
                verbose: true,
                need_accuracy: true,
                io: $stdout)
  # Validate both inputs with the same checker, x first and then y.
  [["x", x], ["y", y]].each do |label, data|
    Utils.check_input_data_type(label, data, Xumo::SFloat)
  end
  options = {
    batch_size: batch_size,
    initial_epoch: initial_epoch,
    test: test,
    verbose: verbose,
    need_accuracy: need_accuracy,
    io: io,
  }
  # Wrap the raw data in an iterator and delegate to the iterator-based entry point.
  start_train_by_iterator(Iterator.new(x, y), epochs, **options)
end
679
+
680
+ # Start training by iterator.
681
+ # Setup the model before use this method.
682
+ # @param [DNN::Iterator] train_iterator Iterator used for training.
683
+ # @param [Integer] epochs Number of training.
684
+ # @param [Integer] batch_size Batch size used for one training.
685
+ # @param [Integer] initial_epoch Initial epoch.
686
+ # @param [Array | NilClass] test If you to test the model for every 1 epoch,
687
+ # specify [x_test, y_test]. Don't test to the model, specify nil.
688
+ # @param [Boolean] verbose Set true to display the log. If false is set, the log is not displayed.
689
+ # @param [Boolean] need_accuracy Set true to compute the accuracy.
690
+ # @param [IO] io Specifies the IO object to use for logging.
691
def start_train_by_iterator(train_iterator, epochs,
                            batch_size: 1,
                            initial_epoch: 1,
                            test: nil,
                            verbose: true,
                            need_accuracy: true,
                            io: $stdout)
  raise DNNError, "The model is not optimizer setup complete." unless @model.optimizer
  raise DNNError, "The model is not loss_func setup complete." unless @model.loss_func

  # Reading the flag also resets it, so a stale request cannot stop this run.
  @model.check_early_stop_requested

  # Remember the run configuration.
  @train_iterator, @test = train_iterator, test
  @max_epochs, @epoch = epochs, initial_epoch
  @batch_size = batch_size
  @verbose, @need_accuracy, @io = verbose, need_accuracy, io

  # The next update call begins the first epoch.
  @state = :start_epoch
  @max_steps = train_iterator.max_steps(batch_size)
  @num_train_datas = train_iterator.num_usable_datas(batch_size)
  @line_first_pos = 0
  @model.call_callbacks(:before_train)
end
715
+
716
+ # Check if it is currently training.
717
+ # @return [Boolean] Returns true if currently training.
718
def training?
  # :none is the idle state; any other state means a run is in progress.
  !(@state == :none)
end
721
+
722
+ # Update trainer.
723
def update
  # Each of these states is handled by the private method of the same name.
  handled = %i[
    start_epoch start_step train_step end_step end_epoch
    start_evaluate evaluating end_evaluate end_training
  ]
  # Any other state (including :none) does nothing and returns nil.
  send(@state) if handled.include?(@state)
end
745
+
746
+ private
747
+
748
# Begins an epoch: records it in the model log, fires :before_epoch and resets the step counter.
def start_epoch
  @model.last_log.store(:epoch, @epoch)
  @model.call_callbacks(:before_epoch)
  if @verbose
    @io.puts "【 epoch #{@epoch}/#{@max_epochs} 】"
  end
  @step = 1
  @state = :start_step
end
755
+
756
# Records the current step number in the model log, then moves on to the training step.
def start_step
  @model.last_log.store(:step, @step)
  @state = :train_step
end
760
+
761
# Trains on the next batch, renders the 40-character progress bar and chooses the next state.
def train_step
  x_batch, y_batch = @train_iterator.next_batch(@batch_size)
  @model.call_callbacks(:before_train_on_batch)
  step_metrics = @model.train_step(x_batch, y_batch, need_accuracy: @need_accuracy)
  @model.last_log.merge!(step_metrics)
  @model.call_callbacks(:after_train_on_batch)

  # Number of samples processed so far, capped at the usable total.
  done = [@step * @batch_size, @num_train_datas].min

  if @io == $stdout
    # On stdout the line is redrawn in place with a carriage return.
    prefix = "\r"
  else
    # For other IO objects remember where this line starts; end_step seeks back to it.
    @line_first_pos = @io.pos
    prefix = ""
  end

  # Cells before the marker are "=", the marker cell is ">", the rest are "_".
  marker = done * 40 / @num_train_datas
  glyphs = { -1 => "=", 0 => ">", 1 => "_" }
  bar = Array.new(40) { |cell| glyphs[cell <=> marker] }.join

  log = "#{prefix}#{bar} #{done}/#{@num_train_datas} #{metrics_to_str(step_metrics)}"
  @io.print log if @verbose

  if @model.check_early_stop_requested
    @io.puts("\nEarly stopped.") if @verbose
    @state = :end_training
  else
    @state = :end_step
  end
end
794
+
795
# Advances the step counter and decides whether the epoch continues.
def end_step
  @step += 1
  return @state = :end_epoch if @step > @max_steps

  # For non-stdout IO, seek back to the start of the progress line before the next step.
  @io.pos = @line_first_pos unless @io == $stdout
  @state = :start_step
end
806
+
807
# Finishes an epoch.
# With test data the epoch is closed by the evaluation states instead.
# Without test data this fires :after_epoch and then either starts the next
# epoch or moves to :end_training.
def end_epoch
  @epoch += 1
  return @state = :start_evaluate if @test

  @io.puts "" if @verbose
  @model.call_callbacks(:after_epoch)
  if @epoch > @max_epochs
    # Go through :end_training (not straight to :none) so the :after_train
    # callback also fires when no test data is used.
    @state = :end_training
  else
    @train_iterator.reset
    # Honour an early-stop request here as well, mirroring end_evaluate.
    if @model.check_early_stop_requested
      @io.puts("Early stopped.") if @verbose
      @state = :end_training
    else
      @state = :start_epoch
    end
  end
end
822
+
823
# Creates an evaluator for the per-epoch test run and starts it.
def start_evaluate
  @evaluator = ModelEvaluator.new(@model)
  options = { batch_size: @batch_size, need_accuracy: @need_accuracy }
  case @test
  when Array
    # An Array is taken as [x_test, y_test].
    @evaluator.start_evaluate(@test[0], @test[1], **options)
  else
    # Anything else is passed through as a test iterator.
    @evaluator.start_evaluate_by_iterator(@test, **options)
  end
  @state = :evaluating
end
832
+
833
# Advances the evaluator by one update and leaves this state once it has finished.
def evaluating
  @evaluator.update
  @state = :end_evaluate unless @evaluator.evaluating?
end
839
+
840
# Prints the test metrics, fires :after_epoch and chooses between the next epoch and the end of training.
def end_evaluate
  if @verbose
    # Accuracy is listed before the loss, and only when it was requested.
    shown = @need_accuracy ? %i[test_accuracy test_loss] : %i[test_loss]
    metrics = shown.map { |key| [key, @model.last_log[key]] }.to_h
    @io.print " " + metrics_to_str(metrics)
    @io.puts ""
  end
  @model.call_callbacks(:after_epoch)
  if @epoch > @max_epochs
    @state = :end_training
  else
    @train_iterator.reset
    stop_requested = @model.check_early_stop_requested
    @io.puts("Early stopped.") if stop_requested && @verbose
    @state = stop_requested ? :end_training : :start_epoch
  end
end
863
+
864
# Final state of a run: fires the :after_train callback and returns the trainer to idle.
def end_training
  @model.call_callbacks(:after_train)
  # :none makes training? report false, which ends the caller's update loop.
  @state = :none
end
634
868
 
635
869
  def metrics_to_str(mertics)
636
870
  mertics.map { |key, values|
@@ -643,28 +877,119 @@ module DNN
643
877
  "#{key}: #{str_values}"
644
878
  }.join(", ")
645
879
  end
880
+ end
881
+
882
# Step-wise evaluator for a model.
# Call one of the start_* methods, then call #update repeatedly while
# #evaluating? is true. The mean loss and the accuracy are written to
# model.last_log[:test_loss] and model.last_log[:test_accuracy].
class ModelEvaluator
  # @param [Object] model The model to evaluate. It must provide loss_func,
  #   last_log, call_callbacks and test_step.
  def initialize(model)
    @model = model
    @state = :none
  end

  # Start evaluating the model on the given test data.
  # Results are not returned here; they are stored in the model's last_log
  # once evaluation has finished.
  # @param [Numo::SFloat] x Input test data.
  # @param [Numo::SFloat] y Output test data.
  # @param [Integer] batch_size Batch size used for one test.
  # @param [Boolean] need_accuracy Set true to compute the accuracy.
  def start_evaluate(x, y, batch_size: 100, need_accuracy: true)
    Utils.check_input_data_type("x", x, Xumo::SFloat)
    Utils.check_input_data_type("y", y, Xumo::SFloat)
    start_evaluate_by_iterator(Iterator.new(x, y, random: false), batch_size: batch_size, need_accuracy: need_accuracy)
  end

  # Start evaluating the model by iterator.
  # Results are not returned here; they are stored in the model's last_log
  # once evaluation has finished.
  # @param [DNN::Iterator] test_iterator Iterator used for testing.
  # @param [Integer] batch_size Batch size used for one test.
  # @param [Boolean] need_accuracy Set true to compute the accuracy.
  def start_evaluate_by_iterator(test_iterator, batch_size: 100, need_accuracy: true)
    @test_iterator = test_iterator
    @num_test_datas = test_iterator.num_datas
    # Never use a batch larger than the data set itself.
    @batch_size = batch_size >= @num_test_datas ? @num_test_datas : batch_size
    @need_accuracy = need_accuracy
    if multi_output?
      # One accumulator per output.
      output_count = @model.loss_func.length
      @total_correct = Array.new(output_count, 0)
      @sum_loss = Array.new(output_count, 0)
    else
      @total_correct = 0
      @sum_loss = 0
    end
    @step = 1
    @max_steps = (@num_test_datas.to_f / @batch_size).ceil
    @state = :start_step
  end

  # Check if it is currently evaluating.
  # @return [Boolean] Returns true if currently evaluating.
  def evaluating?
    @state != :none
  end

  # Update evaluator: runs the handler for the current state.
  def update
    case @state
    when :start_step
      start_step
    when :test_step
      test_step
    when :end_step
      end_step
    when :end_evaluate
      end_evaluate
    end
  end

  private

  # The loss function lives on the model, not on the evaluator; an Array means
  # the model has several outputs.
  def multi_output?
    @model.loss_func.is_a?(Array)
  end

  def start_step
    @model.last_log[:step] = @step
    @state = :test_step
  end

  # Tests one batch and accumulates its loss and accuracy.
  def test_step
    (x_batch, y_batch) = @test_iterator.next_batch(@batch_size)
    @model.call_callbacks(:before_test_on_batch)
    test_met = @model.test_step(x_batch, y_batch, need_accuracy: @need_accuracy)
    @model.call_callbacks(:after_test_on_batch)
    if multi_output?
      @model.loss_func.each_index do |i|
        @total_correct[i] += test_met[:test_accuracy][i] if @need_accuracy
        @sum_loss[i] += test_met[:test_loss][i]
      end
    else
      @total_correct += test_met[:test_accuracy] if @need_accuracy
      @sum_loss += test_met[:test_loss]
    end
    @state = :end_step
  end

  def end_step
    @step += 1
    if @step <= @max_steps
      @state = :start_step
    else
      @state = :end_evaluate
    end
  end

  # Turns the accumulated sums into the mean loss (per step) and the accuracy
  # (per test sample) and publishes them in the model log.
  def end_evaluate
    acc = nil
    if multi_output?
      output_count = @model.loss_func.length
      mean_loss = Array.new(output_count, 0)
      acc = Array.new(output_count, 0) if @need_accuracy
      @model.loss_func.each_index do |i|
        mean_loss[i] += @sum_loss[i] / @max_steps
        acc[i] += @total_correct[i].to_f / @num_test_datas if @need_accuracy
      end
    else
      mean_loss = @sum_loss / @max_steps
      acc = @total_correct.to_f / @num_test_datas if @need_accuracy
    end
    @model.last_log[:test_loss] = mean_loss
    @model.last_log[:test_accuracy] = acc
    @state = :none
  end
end
670
995