rumale 0.8.0

Files changed (85)
  1. checksums.yaml +7 -0
  2. data/.coveralls.yml +1 -0
  3. data/.gitignore +20 -0
  4. data/.rspec +3 -0
  5. data/.rubocop.yml +47 -0
  6. data/.rubocop_todo.yml +58 -0
  7. data/.travis.yml +13 -0
  8. data/CHANGELOG.md +2 -0
  9. data/CODE_OF_CONDUCT.md +74 -0
  10. data/Gemfile +4 -0
  11. data/LICENSE.txt +23 -0
  12. data/README.md +175 -0
  13. data/Rakefile +6 -0
  14. data/bin/console +14 -0
  15. data/bin/setup +8 -0
  16. data/lib/rumale.rb +70 -0
  17. data/lib/rumale/base/base_estimator.rb +13 -0
  18. data/lib/rumale/base/classifier.rb +36 -0
  19. data/lib/rumale/base/cluster_analyzer.rb +31 -0
  20. data/lib/rumale/base/evaluator.rb +17 -0
  21. data/lib/rumale/base/regressor.rb +36 -0
  22. data/lib/rumale/base/splitter.rb +21 -0
  23. data/lib/rumale/base/transformer.rb +22 -0
  24. data/lib/rumale/clustering/dbscan.rb +125 -0
  25. data/lib/rumale/clustering/k_means.rb +138 -0
  26. data/lib/rumale/dataset.rb +110 -0
  27. data/lib/rumale/decomposition/nmf.rb +141 -0
  28. data/lib/rumale/decomposition/pca.rb +148 -0
  29. data/lib/rumale/ensemble/ada_boost_classifier.rb +196 -0
  30. data/lib/rumale/ensemble/ada_boost_regressor.rb +178 -0
  31. data/lib/rumale/ensemble/random_forest_classifier.rb +180 -0
  32. data/lib/rumale/ensemble/random_forest_regressor.rb +141 -0
  33. data/lib/rumale/evaluation_measure/accuracy.rb +29 -0
  34. data/lib/rumale/evaluation_measure/f_score.rb +50 -0
  35. data/lib/rumale/evaluation_measure/log_loss.rb +45 -0
  36. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +29 -0
  37. data/lib/rumale/evaluation_measure/mean_squared_error.rb +29 -0
  38. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +62 -0
  39. data/lib/rumale/evaluation_measure/precision.rb +50 -0
  40. data/lib/rumale/evaluation_measure/precision_recall.rb +91 -0
  41. data/lib/rumale/evaluation_measure/purity.rb +40 -0
  42. data/lib/rumale/evaluation_measure/r2_score.rb +43 -0
  43. data/lib/rumale/evaluation_measure/recall.rb +50 -0
  44. data/lib/rumale/kernel_approximation/rbf.rb +121 -0
  45. data/lib/rumale/kernel_machine/kernel_svc.rb +193 -0
  46. data/lib/rumale/linear_model/base_linear_model.rb +89 -0
  47. data/lib/rumale/linear_model/lasso.rb +136 -0
  48. data/lib/rumale/linear_model/linear_regression.rb +110 -0
  49. data/lib/rumale/linear_model/logistic_regression.rb +159 -0
  50. data/lib/rumale/linear_model/ridge.rb +110 -0
  51. data/lib/rumale/linear_model/svc.rb +183 -0
  52. data/lib/rumale/linear_model/svr.rb +122 -0
  53. data/lib/rumale/model_selection/cross_validation.rb +123 -0
  54. data/lib/rumale/model_selection/grid_search_cv.rb +247 -0
  55. data/lib/rumale/model_selection/k_fold.rb +76 -0
  56. data/lib/rumale/model_selection/stratified_k_fold.rb +94 -0
  57. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +100 -0
  58. data/lib/rumale/naive_bayes/naive_bayes.rb +315 -0
  59. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +111 -0
  60. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +93 -0
  61. data/lib/rumale/optimizer/nadam.rb +90 -0
  62. data/lib/rumale/optimizer/rmsprop.rb +69 -0
  63. data/lib/rumale/optimizer/sgd.rb +65 -0
  64. data/lib/rumale/optimizer/yellow_fin.rb +144 -0
  65. data/lib/rumale/pairwise_metric.rb +91 -0
  66. data/lib/rumale/pipeline/pipeline.rb +197 -0
  67. data/lib/rumale/polynomial_model/base_factorization_machine.rb +99 -0
  68. data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +197 -0
  69. data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +131 -0
  70. data/lib/rumale/preprocessing/l2_normalizer.rb +62 -0
  71. data/lib/rumale/preprocessing/label_encoder.rb +94 -0
  72. data/lib/rumale/preprocessing/min_max_scaler.rb +92 -0
  73. data/lib/rumale/preprocessing/one_hot_encoder.rb +98 -0
  74. data/lib/rumale/preprocessing/standard_scaler.rb +86 -0
  75. data/lib/rumale/probabilistic_output.rb +112 -0
  76. data/lib/rumale/tree/base_decision_tree.rb +153 -0
  77. data/lib/rumale/tree/decision_tree_classifier.rb +163 -0
  78. data/lib/rumale/tree/decision_tree_regressor.rb +135 -0
  79. data/lib/rumale/tree/node.rb +70 -0
  80. data/lib/rumale/utils.rb +37 -0
  81. data/lib/rumale/validation.rb +79 -0
  82. data/lib/rumale/values.rb +13 -0
  83. data/lib/rumale/version.rb +6 -0
  84. data/rumale.gemspec +41 -0
  85. metadata +204 -0
data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb
@@ -0,0 +1,111 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/base_estimator'
+ require 'rumale/base/classifier'
+
+ module Rumale
+   # This module consists of the classes that implement estimators based on the nearest neighbors rule.
+   module NearestNeighbors
+     # KNeighborsClassifier is a class that implements the classifier with the k-nearest neighbors rule.
+     # The current implementation uses the Euclidean distance for finding the neighbors.
+     #
+     # @example
+     #   estimator =
+     #     Rumale::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 5)
+     #   estimator.fit(training_samples, training_labels)
+     #   results = estimator.predict(testing_samples)
+     #
+     class KNeighborsClassifier
+       include Base::BaseEstimator
+       include Base::Classifier
+
+       # Return the prototypes for the nearest neighbor classifier.
+       # @return [Numo::DFloat] (shape: [n_samples, n_features])
+       attr_reader :prototypes
+
+       # Return the labels of the prototypes.
+       # @return [Numo::Int32] (size: n_samples)
+       attr_reader :labels
+
+       # Return the class labels.
+       # @return [Numo::Int32] (size: n_classes)
+       attr_reader :classes
+
+       # Create a new classifier with the nearest neighbor rule.
+       #
+       # @param n_neighbors [Integer] The number of neighbors.
+       def initialize(n_neighbors: 5)
+         check_params_integer(n_neighbors: n_neighbors)
+         check_params_positive(n_neighbors: n_neighbors)
+         @params = {}
+         @params[:n_neighbors] = n_neighbors
+         @prototypes = nil
+         @labels = nil
+         @classes = nil
+       end
+
+       # Fit the model with given training data.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+       # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+       # @return [KNeighborsClassifier] The learned classifier itself.
+       def fit(x, y)
+         check_sample_array(x)
+         check_label_array(y)
+         check_sample_label_size(x, y)
+         @prototypes = Numo::DFloat.asarray(x.to_a)
+         @labels = Numo::Int32.asarray(y.to_a)
+         @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
+         self
+       end
+
+       # Calculate confidence scores for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
+       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
+       def decision_function(x)
+         check_sample_array(x)
+         distance_matrix = PairwiseMetric.euclidean_distance(x, @prototypes)
+         n_samples, n_prototypes = distance_matrix.shape
+         n_classes = @classes.size
+         n_neighbors = [@params[:n_neighbors], n_prototypes].min
+         scores = Numo::DFloat.zeros(n_samples, n_classes)
+         n_samples.times do |m|
+           neighbor_ids = distance_matrix[m, true].to_a.each_with_index.sort.map(&:last)[0...n_neighbors]
+           neighbor_ids.each { |n| scores[m, @classes.to_a.index(@labels[n])] += 1.0 }
+         end
+         scores
+       end
+
+       # Predict class labels for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
+       # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
+       def predict(x)
+         check_sample_array(x)
+         n_samples = x.shape.first
+         decision_values = decision_function(x)
+         Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data about KNeighborsClassifier.
+       def marshal_dump
+         { params: @params,
+           prototypes: @prototypes,
+           labels: @labels,
+           classes: @classes }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @params = obj[:params]
+         @prototypes = obj[:prototypes]
+         @labels = obj[:labels]
+         @classes = obj[:classes]
+         nil
+       end
+     end
+   end
+ end
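
Taken together, `fit` memorizes the training set as prototypes and `decision_function` counts neighbor votes per class. A minimal usage sketch follows; the toy arrays and expected outputs are illustrative assumptions, while the constructor and method calls follow the diff above:

```ruby
require 'rumale'
require 'numo/narray'

# Hypothetical toy data: two well-separated 2-D clusters.
x_train = Numo::DFloat[[0.0, 0.1], [0.1, 0.0], [5.0, 5.1], [5.1, 5.0]]
y_train = Numo::Int32[0, 0, 1, 1]

estimator = Rumale::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 3)
estimator.fit(x_train, y_train)

x_test = Numo::DFloat[[0.2, 0.2], [4.9, 5.2]]
p estimator.predict(x_test)            # expected: Numo::Int32[0, 1]
p estimator.decision_function(x_test)  # neighbor vote counts per class

# The marshal hooks above also make a fitted model serializable:
restored = Marshal.load(Marshal.dump(estimator))
p restored.predict(x_test)
```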
data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb
@@ -0,0 +1,93 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/base_estimator'
+ require 'rumale/base/regressor'
+
+ module Rumale
+   module NearestNeighbors
+     # KNeighborsRegressor is a class that implements the regressor with the k-nearest neighbors rule.
+     # The current implementation uses the Euclidean distance for finding the neighbors.
+     #
+     # @example
+     #   estimator =
+     #     Rumale::NearestNeighbors::KNeighborsRegressor.new(n_neighbors: 5)
+     #   estimator.fit(training_samples, training_target_values)
+     #   results = estimator.predict(testing_samples)
+     #
+     class KNeighborsRegressor
+       include Base::BaseEstimator
+       include Base::Regressor
+
+       # Return the prototypes for the nearest neighbor regressor.
+       # @return [Numo::DFloat] (shape: [n_samples, n_features])
+       attr_reader :prototypes
+
+       # Return the values of the prototypes.
+       # @return [Numo::DFloat] (shape: [n_samples, n_outputs])
+       attr_reader :values
+
+       # Create a new regressor with the nearest neighbor rule.
+       #
+       # @param n_neighbors [Integer] The number of neighbors.
+       def initialize(n_neighbors: 5)
+         check_params_integer(n_neighbors: n_neighbors)
+         check_params_positive(n_neighbors: n_neighbors)
+         @params = {}
+         @params[:n_neighbors] = n_neighbors
+         @prototypes = nil
+         @values = nil
+       end
+
+       # Fit the model with given training data.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+       # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
+       # @return [KNeighborsRegressor] The learned regressor itself.
+       def fit(x, y)
+         check_sample_array(x)
+         check_tvalue_array(y)
+         check_sample_tvalue_size(x, y)
+         @prototypes = x.dup
+         @values = y.dup
+         self
+       end
+
+       # Predict values for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
+       # @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.
+       def predict(x)
+         check_sample_array(x)
+         # Initialize some variables.
+         n_samples, = x.shape
+         n_prototypes, n_outputs = @values.shape
+         n_neighbors = [@params[:n_neighbors], n_prototypes].min
+         # Calculate distance matrix.
+         distance_matrix = PairwiseMetric.euclidean_distance(x, @prototypes)
+         # Predict values for the given samples.
+         predicted_values = Array.new(n_samples) do |n|
+           neighbor_ids = distance_matrix[n, true].to_a.each_with_index.sort.map(&:last)[0...n_neighbors]
+           n_outputs.nil? ? @values[neighbor_ids].mean : @values[neighbor_ids, true].mean(0).to_a
+         end
+         Numo::DFloat[*predicted_values]
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data about KNeighborsRegressor.
+       def marshal_dump
+         { params: @params,
+           prototypes: @prototypes,
+           values: @values }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @params = obj[:params]
+         @prototypes = obj[:prototypes]
+         @values = obj[:values]
+         nil
+       end
+     end
+   end
+ end
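
Prediction here is just the mean of the nearest prototypes' target values (column-wise in the multi-output case, as the `n_outputs.nil?` branch shows). A short sketch under assumed toy data:

```ruby
require 'rumale'
require 'numo/narray'

# Hypothetical 1-D data where y roughly doubles x.
x_train = Numo::DFloat[[1.0], [2.0], [3.0], [4.0]]
y_train = Numo::DFloat[2.1, 3.9, 6.2, 8.0]

regressor = Rumale::NearestNeighbors::KNeighborsRegressor.new(n_neighbors: 2)
regressor.fit(x_train, y_train)

# The two nearest prototypes to 2.5 are x = 2.0 and x = 3.0,
# so the prediction is (3.9 + 6.2) / 2 = 5.05.
p regressor.predict(Numo::DFloat[[2.5]])
```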
data/lib/rumale/optimizer/nadam.rb
@@ -0,0 +1,90 @@
+ # frozen_string_literal: true
+
+ require 'rumale/validation'
+ require 'rumale/base/base_estimator'
+
+ module Rumale
+   # This module consists of the classes that implement optimizers adaptively tuning hyperparameters.
+   module Optimizer
+     # Nadam is a class that implements the Nadam optimizer.
+     #
+     # @example
+     #   optimizer = Rumale::Optimizer::Nadam.new(learning_rate: 0.01, momentum: 0.9, decay1: 0.9, decay2: 0.999)
+     #   estimator = Rumale::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1)
+     #   estimator.fit(samples, values)
+     #
+     # *Reference*
+     # - T. Dozat, "Incorporating Nesterov Momentum into Adam," Tech. Rep., Stanford University, 2015.
+     class Nadam
+       include Base::BaseEstimator
+       include Validation
+
+       # Create a new optimizer with Nadam.
+       #
+       # @param learning_rate [Float] The initial value of learning rate.
+       # @param momentum [Float] The initial value of momentum.
+       # @param decay1 [Float] The smoothing parameter for the first moment.
+       # @param decay2 [Float] The smoothing parameter for the second moment.
+       def initialize(learning_rate: 0.01, momentum: 0.9, decay1: 0.9, decay2: 0.999)
+         check_params_float(learning_rate: learning_rate, momentum: momentum, decay1: decay1, decay2: decay2)
+         check_params_positive(learning_rate: learning_rate, momentum: momentum, decay1: decay1, decay2: decay2)
+         @params = {}
+         @params[:learning_rate] = learning_rate
+         @params[:momentum] = momentum
+         @params[:decay1] = decay1
+         @params[:decay2] = decay2
+         @fst_moment = nil
+         @sec_moment = nil
+         @decay1_prod = 1.0
+         @iter = 0
+       end
+
+       # Calculate the updated weight with the Nadam adaptive learning rate.
+       #
+       # @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
+       # @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
+       # @return [Numo::DFloat] (shape: [n_features]) The updated weight.
+       def call(weight, gradient)
+         @fst_moment ||= Numo::DFloat.zeros(weight.shape[0])
+         @sec_moment ||= Numo::DFloat.zeros(weight.shape[0])
+
+         @iter += 1
+
+         decay1_curr = @params[:decay1] * (1.0 - 0.5 * 0.96**(@iter * 0.004))
+         decay1_next = @params[:decay1] * (1.0 - 0.5 * 0.96**((@iter + 1) * 0.004))
+         decay1_prod_curr = @decay1_prod * decay1_curr
+         decay1_prod_next = @decay1_prod * decay1_curr * decay1_next
+         @decay1_prod = decay1_prod_curr
+
+         @fst_moment = @params[:decay1] * @fst_moment + (1.0 - @params[:decay1]) * gradient
+         @sec_moment = @params[:decay2] * @sec_moment + (1.0 - @params[:decay2]) * gradient**2
+         nm_gradient = gradient / (1.0 - decay1_prod_curr)
+         nm_fst_moment = @fst_moment / (1.0 - decay1_prod_next)
+         nm_sec_moment = @sec_moment / (1.0 - @params[:decay2]**@iter)
+
+         weight - (@params[:learning_rate] / (nm_sec_moment**0.5 + 1e-8)) * ((1 - decay1_curr) * nm_gradient + decay1_next * nm_fst_moment)
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data.
+       def marshal_dump
+         { params: @params,
+           fst_moment: @fst_moment,
+           sec_moment: @sec_moment,
+           decay1_prod: @decay1_prod,
+           iter: @iter }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @params = obj[:params]
+         @fst_moment = obj[:fst_moment]
+         @sec_moment = obj[:sec_moment]
+         @decay1_prod = obj[:decay1_prod]
+         @iter = obj[:iter]
+         nil
+       end
+     end
+   end
+ end
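
The `call` contract is simply weight in, gradient in, updated weight out, with the moment estimates kept as instance state between calls. A rough convergence sketch on an assumed toy objective f(w) = (w - 3)^2; the loop and the objective are illustrative, only `Nadam.new` and `#call` come from the code above:

```ruby
require 'rumale'
require 'numo/narray'

optimizer = Rumale::Optimizer::Nadam.new(learning_rate: 0.1)
weight = Numo::DFloat[0.0]

200.times do
  gradient = 2.0 * (weight - 3.0) # gradient of (w - 3)**2
  weight = optimizer.call(weight, gradient)
end
p weight # should approach Numo::DFloat[3.0]
```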
data/lib/rumale/optimizer/rmsprop.rb
@@ -0,0 +1,69 @@
+ # frozen_string_literal: true
+
+ require 'rumale/validation'
+ require 'rumale/base/base_estimator'
+
+ module Rumale
+   module Optimizer
+     # RMSProp is a class that implements the RMSProp optimizer.
+     #
+     # @example
+     #   optimizer = Rumale::Optimizer::RMSProp.new(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
+     #   estimator = Rumale::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1)
+     #   estimator.fit(samples, values)
+     #
+     # *Reference*
+     # - I. Sutskever, J. Martens, G. Dahl, and G. Hinton, "On the importance of initialization and momentum in deep learning," Proc. ICML '13, pp. 1139--1147, 2013.
+     # - G. Hinton, N. Srivastava, and K. Swersky, "Lecture 6e rmsprop," Neural Networks for Machine Learning, 2012.
+     class RMSProp
+       include Base::BaseEstimator
+       include Validation
+
+       # Create a new optimizer with RMSProp.
+       #
+       # @param learning_rate [Float] The initial value of learning rate.
+       # @param momentum [Float] The initial value of momentum.
+       # @param decay [Float] The smoothing parameter.
+       def initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
+         check_params_float(learning_rate: learning_rate, momentum: momentum, decay: decay)
+         check_params_positive(learning_rate: learning_rate, momentum: momentum, decay: decay)
+         @params = {}
+         @params[:learning_rate] = learning_rate
+         @params[:momentum] = momentum
+         @params[:decay] = decay
+         @moment = nil
+         @update = nil
+       end
+
+       # Calculate the updated weight with the RMSProp adaptive learning rate.
+       #
+       # @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
+       # @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
+       # @return [Numo::DFloat] (shape: [n_features]) The updated weight.
+       def call(weight, gradient)
+         @moment ||= Numo::DFloat.zeros(weight.shape[0])
+         @update ||= Numo::DFloat.zeros(weight.shape[0])
+         @moment = @params[:decay] * @moment + (1.0 - @params[:decay]) * gradient**2
+         @update = @params[:momentum] * @update - (@params[:learning_rate] / (@moment**0.5 + 1.0e-8)) * gradient
+         weight + @update
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data.
+       def marshal_dump
+         { params: @params,
+           moment: @moment,
+           update: @update }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @params = obj[:params]
+         @moment = obj[:moment]
+         @update = obj[:update]
+         nil
+       end
+     end
+   end
+ end
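
The squared-gradient moving average in `@moment` scales the step per dimension, which is what lets RMSProp handle coordinates with very different curvatures. A sketch under an assumed toy objective (the data and hyperparameters here are illustrative):

```ruby
require 'rumale'
require 'numo/narray'

optimizer = Rumale::Optimizer::RMSProp.new(learning_rate: 0.01, momentum: 0.0, decay: 0.9)
weight = Numo::DFloat[0.0, 0.0]

# Assumed toy objective with a 100x curvature gap between dimensions:
# f(w) = 10 * (w0 - 1)**2 + 0.1 * (w1 - 1)**2
1000.times do
  gradient = Numo::DFloat[20.0 * (weight[0] - 1.0), 0.2 * (weight[1] - 1.0)]
  weight = optimizer.call(weight, gradient)
end
p weight # both coordinates should settle near 1.0 despite the curvature gap
```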
data/lib/rumale/optimizer/sgd.rb
@@ -0,0 +1,65 @@
+ # frozen_string_literal: true
+
+ require 'rumale/validation'
+ require 'rumale/base/base_estimator'
+
+ module Rumale
+   module Optimizer
+     # SGD is a class that implements the SGD optimizer.
+     #
+     # @example
+     #   optimizer = Rumale::Optimizer::SGD.new(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
+     #   estimator = Rumale::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1)
+     #   estimator.fit(samples, values)
+     class SGD
+       include Base::BaseEstimator
+       include Validation
+
+       # Create a new optimizer with SGD.
+       #
+       # @param learning_rate [Float] The initial value of learning rate.
+       # @param momentum [Float] The initial value of momentum.
+       # @param decay [Float] The smoothing parameter.
+       def initialize(learning_rate: 0.01, momentum: 0.0, decay: 0.0)
+         check_params_float(learning_rate: learning_rate, momentum: momentum, decay: decay)
+         check_params_positive(learning_rate: learning_rate, momentum: momentum, decay: decay)
+         @params = {}
+         @params[:learning_rate] = learning_rate
+         @params[:momentum] = momentum
+         @params[:decay] = decay
+         @iter = 0
+         @update = nil
+       end
+
+       # Calculate the updated weight with SGD.
+       #
+       # @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
+       # @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
+       # @return [Numo::DFloat] (shape: [n_features]) The updated weight.
+       def call(weight, gradient)
+         @update ||= Numo::DFloat.zeros(weight.shape[0])
+         current_learning_rate = @params[:learning_rate] / (1.0 + @params[:decay] * @iter)
+         @iter += 1
+         @update = @params[:momentum] * @update - current_learning_rate * gradient
+         weight + @update
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data.
+       def marshal_dump
+         { params: @params,
+           iter: @iter,
+           update: @update }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @params = obj[:params]
+         @iter = obj[:iter]
+         @update = obj[:update]
+         nil
+       end
+     end
+   end
+ end
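
The effective step size is learning_rate / (1 + decay * iter), shrinking hyperbolically as `call` increments the internal counter, with a classical momentum term on top. A sketch on the same assumed toy objective as above (loop and objective are illustrative):

```ruby
require 'rumale'
require 'numo/narray'

optimizer = Rumale::Optimizer::SGD.new(learning_rate: 0.1, momentum: 0.9, decay: 0.01)
weight = Numo::DFloat[0.0]

100.times do
  gradient = 2.0 * (weight - 3.0) # gradient of (w - 3)**2
  weight = optimizer.call(weight, gradient)
end
p weight # converges toward Numo::DFloat[3.0]
```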
data/lib/rumale/optimizer/yellow_fin.rb
@@ -0,0 +1,144 @@
+ # frozen_string_literal: true
+
+ require 'rumale/validation'
+ require 'rumale/base/base_estimator'
+
+ module Rumale
+   module Optimizer
+     # YellowFin is a class that implements the YellowFin optimizer.
+     #
+     # @example
+     #   optimizer = Rumale::Optimizer::YellowFin.new(learning_rate: 0.01, momentum: 0.9, decay: 0.999, window_width: 20)
+     #   estimator = Rumale::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1)
+     #   estimator.fit(samples, values)
+     #
+     # *Reference*
+     # - J. Zhang and I. Mitliagkas, "YellowFin and the Art of Momentum Tuning," CoRR abs/1706.03471, 2017.
+     class YellowFin
+       include Base::BaseEstimator
+       include Validation
+
+       # Create a new optimizer with YellowFin.
+       #
+       # @param learning_rate [Float] The initial value of learning rate.
+       # @param momentum [Float] The initial value of momentum.
+       # @param decay [Float] The smoothing parameter.
+       # @param window_width [Integer] The sliding window width for searching curvature range.
+       def initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.999, window_width: 20)
+         check_params_float(learning_rate: learning_rate, momentum: momentum, decay: decay)
+         check_params_integer(window_width: window_width)
+         check_params_positive(learning_rate: learning_rate, momentum: momentum, decay: decay, window_width: window_width)
+         @params = {}
+         @params[:learning_rate] = learning_rate
+         @params[:momentum] = momentum
+         @params[:decay] = decay
+         @params[:window_width] = window_width
+         @smth_learning_rate = learning_rate
+         @smth_momentum = momentum
+         @grad_norms = nil
+         @grad_norm_min = 0.0
+         @grad_norm_max = 0.0
+         @grad_mean_sqr = 0.0
+         @grad_mean = 0.0
+         @grad_var = 0.0
+         @grad_norm_mean = 0.0
+         @curve_mean = 0.0
+         @distance_mean = 0.0
+         @update = nil
+       end
+
+       # Calculate the updated weight with the adaptive momentum coefficient and learning rate.
+       #
+       # @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
+       # @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
+       # @return [Numo::DFloat] (shape: [n_features]) The updated weight.
+       def call(weight, gradient)
+         @update ||= Numo::DFloat.zeros(weight.shape[0])
+         curvature_range(gradient)
+         gradient_variance(gradient)
+         distance_to_optimum(gradient)
+         @smth_momentum = @params[:decay] * @smth_momentum + (1 - @params[:decay]) * current_momentum
+         @smth_learning_rate = @params[:decay] * @smth_learning_rate + (1 - @params[:decay]) * current_learning_rate
+         @update = @smth_momentum * @update - @smth_learning_rate * gradient
+         weight + @update
+       end
+
+       private
+
+       def current_momentum
+         dr = Math.sqrt(@grad_norm_max / @grad_norm_min + 1.0e-8)
+         [cubic_root**2, ((dr - 1) / (dr + 1))**2].max
+       end
+
+       def current_learning_rate
+         (1.0 - Math.sqrt(@params[:momentum]))**2 / (@grad_norm_min + 1.0e-8)
+       end
+
+       def cubic_root
+         p = (@distance_mean**2 * @grad_norm_min**2) / (2 * @grad_var + 1.0e-8)
+         w3 = (-Math.sqrt(p**2 + 4.fdiv(27) * p**3) - p).fdiv(2)
+         w = (w3 >= 0.0 ? 1 : -1) * w3.abs**1.fdiv(3)
+         y = w - p / (3 * w + 1.0e-8)
+         y + 1
+       end
+
+       def curvature_range(gradient)
+         @grad_norms ||= []
+         @grad_norms.push((gradient**2).sum)
+         @grad_norms.shift(@grad_norms.size - @params[:window_width]) if @grad_norms.size > @params[:window_width]
+         @grad_norm_min = @params[:decay] * @grad_norm_min + (1 - @params[:decay]) * @grad_norms.min
+         @grad_norm_max = @params[:decay] * @grad_norm_max + (1 - @params[:decay]) * @grad_norms.max
+       end
+
+       def gradient_variance(gradient)
+         @grad_mean_sqr = @params[:decay] * @grad_mean_sqr + (1 - @params[:decay]) * gradient**2
+         @grad_mean = @params[:decay] * @grad_mean + (1 - @params[:decay]) * gradient
+         @grad_var = (@grad_mean_sqr - @grad_mean**2).sum
+       end
+
+       def distance_to_optimum(gradient)
+         grad_sqr = (gradient**2).sum
+         @grad_norm_mean = @params[:decay] * @grad_norm_mean + (1 - @params[:decay]) * Math.sqrt(grad_sqr + 1.0e-8)
+         @curve_mean = @params[:decay] * @curve_mean + (1 - @params[:decay]) * grad_sqr
+         @distance_mean = @params[:decay] * @distance_mean + (1 - @params[:decay]) * (@grad_norm_mean / @curve_mean)
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data.
+       def marshal_dump
+         { params: @params,
+           smth_learning_rate: @smth_learning_rate,
+           smth_momentum: @smth_momentum,
+           grad_norms: @grad_norms,
+           grad_norm_min: @grad_norm_min,
+           grad_norm_max: @grad_norm_max,
+           grad_mean_sqr: @grad_mean_sqr,
+           grad_mean: @grad_mean,
+           grad_var: @grad_var,
+           grad_norm_mean: @grad_norm_mean,
+           curve_mean: @curve_mean,
+           distance_mean: @distance_mean,
+           update: @update }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @params = obj[:params]
+         @smth_learning_rate = obj[:smth_learning_rate]
+         @smth_momentum = obj[:smth_momentum]
+         @grad_norms = obj[:grad_norms]
+         @grad_norm_min = obj[:grad_norm_min]
+         @grad_norm_max = obj[:grad_norm_max]
+         @grad_mean_sqr = obj[:grad_mean_sqr]
+         @grad_mean = obj[:grad_mean]
+         @grad_var = obj[:grad_var]
+         @grad_norm_mean = obj[:grad_norm_mean]
+         @curve_mean = obj[:curve_mean]
+         @distance_mean = obj[:distance_mean]
+         @update = obj[:update]
+         nil
+       end
+     end
+   end
+ end
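
YellowFin keeps the same `call` interface as the other optimizers but tunes the momentum coefficient and learning rate from the gradient statistics it tracks (curvature range, gradient variance, distance to optimum), so it needs a number of iterations before the smoothed estimates become informative. A deliberately rough sketch; the toy objective and iteration count are assumptions:

```ruby
require 'rumale'
require 'numo/narray'

optimizer = Rumale::Optimizer::YellowFin.new(learning_rate: 0.01, window_width: 10)
weight = Numo::DFloat[0.0]

500.times do
  gradient = 2.0 * (weight - 3.0) # gradient of (w - 3)**2
  weight = optimizer.call(weight, gradient)
end
p weight # should drift toward Numo::DFloat[3.0] as the tuner adapts
```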