rumale 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. checksums.yaml +7 -0
  2. data/.coveralls.yml +1 -0
  3. data/.gitignore +20 -0
  4. data/.rspec +3 -0
  5. data/.rubocop.yml +47 -0
  6. data/.rubocop_todo.yml +58 -0
  7. data/.travis.yml +13 -0
  8. data/CHANGELOG.md +2 -0
  9. data/CODE_OF_CONDUCT.md +74 -0
  10. data/Gemfile +4 -0
  11. data/LICENSE.txt +23 -0
  12. data/README.md +175 -0
  13. data/Rakefile +6 -0
  14. data/bin/console +14 -0
  15. data/bin/setup +8 -0
  16. data/lib/rumale.rb +70 -0
  17. data/lib/rumale/base/base_estimator.rb +13 -0
  18. data/lib/rumale/base/classifier.rb +36 -0
  19. data/lib/rumale/base/cluster_analyzer.rb +31 -0
  20. data/lib/rumale/base/evaluator.rb +17 -0
  21. data/lib/rumale/base/regressor.rb +36 -0
  22. data/lib/rumale/base/splitter.rb +21 -0
  23. data/lib/rumale/base/transformer.rb +22 -0
  24. data/lib/rumale/clustering/dbscan.rb +125 -0
  25. data/lib/rumale/clustering/k_means.rb +138 -0
  26. data/lib/rumale/dataset.rb +110 -0
  27. data/lib/rumale/decomposition/nmf.rb +141 -0
  28. data/lib/rumale/decomposition/pca.rb +148 -0
  29. data/lib/rumale/ensemble/ada_boost_classifier.rb +196 -0
  30. data/lib/rumale/ensemble/ada_boost_regressor.rb +178 -0
  31. data/lib/rumale/ensemble/random_forest_classifier.rb +180 -0
  32. data/lib/rumale/ensemble/random_forest_regressor.rb +141 -0
  33. data/lib/rumale/evaluation_measure/accuracy.rb +29 -0
  34. data/lib/rumale/evaluation_measure/f_score.rb +50 -0
  35. data/lib/rumale/evaluation_measure/log_loss.rb +45 -0
  36. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +29 -0
  37. data/lib/rumale/evaluation_measure/mean_squared_error.rb +29 -0
  38. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +62 -0
  39. data/lib/rumale/evaluation_measure/precision.rb +50 -0
  40. data/lib/rumale/evaluation_measure/precision_recall.rb +91 -0
  41. data/lib/rumale/evaluation_measure/purity.rb +40 -0
  42. data/lib/rumale/evaluation_measure/r2_score.rb +43 -0
  43. data/lib/rumale/evaluation_measure/recall.rb +50 -0
  44. data/lib/rumale/kernel_approximation/rbf.rb +121 -0
  45. data/lib/rumale/kernel_machine/kernel_svc.rb +193 -0
  46. data/lib/rumale/linear_model/base_linear_model.rb +89 -0
  47. data/lib/rumale/linear_model/lasso.rb +136 -0
  48. data/lib/rumale/linear_model/linear_regression.rb +110 -0
  49. data/lib/rumale/linear_model/logistic_regression.rb +159 -0
  50. data/lib/rumale/linear_model/ridge.rb +110 -0
  51. data/lib/rumale/linear_model/svc.rb +183 -0
  52. data/lib/rumale/linear_model/svr.rb +122 -0
  53. data/lib/rumale/model_selection/cross_validation.rb +123 -0
  54. data/lib/rumale/model_selection/grid_search_cv.rb +247 -0
  55. data/lib/rumale/model_selection/k_fold.rb +76 -0
  56. data/lib/rumale/model_selection/stratified_k_fold.rb +94 -0
  57. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +100 -0
  58. data/lib/rumale/naive_bayes/naive_bayes.rb +315 -0
  59. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +111 -0
  60. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +93 -0
  61. data/lib/rumale/optimizer/nadam.rb +90 -0
  62. data/lib/rumale/optimizer/rmsprop.rb +69 -0
  63. data/lib/rumale/optimizer/sgd.rb +65 -0
  64. data/lib/rumale/optimizer/yellow_fin.rb +144 -0
  65. data/lib/rumale/pairwise_metric.rb +91 -0
  66. data/lib/rumale/pipeline/pipeline.rb +197 -0
  67. data/lib/rumale/polynomial_model/base_factorization_machine.rb +99 -0
  68. data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +197 -0
  69. data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +131 -0
  70. data/lib/rumale/preprocessing/l2_normalizer.rb +62 -0
  71. data/lib/rumale/preprocessing/label_encoder.rb +94 -0
  72. data/lib/rumale/preprocessing/min_max_scaler.rb +92 -0
  73. data/lib/rumale/preprocessing/one_hot_encoder.rb +98 -0
  74. data/lib/rumale/preprocessing/standard_scaler.rb +86 -0
  75. data/lib/rumale/probabilistic_output.rb +112 -0
  76. data/lib/rumale/tree/base_decision_tree.rb +153 -0
  77. data/lib/rumale/tree/decision_tree_classifier.rb +163 -0
  78. data/lib/rumale/tree/decision_tree_regressor.rb +135 -0
  79. data/lib/rumale/tree/node.rb +70 -0
  80. data/lib/rumale/utils.rb +37 -0
  81. data/lib/rumale/validation.rb +79 -0
  82. data/lib/rumale/values.rb +13 -0
  83. data/lib/rumale/version.rb +6 -0
  84. data/rumale.gemspec +41 -0
  85. metadata +204 -0
data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb
@@ -0,0 +1,111 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/base_estimator'
+ require 'rumale/base/classifier'
+
+ module Rumale
+   # This module consists of the classes that implement estimators based on the nearest neighbors rule.
+   module NearestNeighbors
+     # KNeighborsClassifier is a class that implements the classifier with the k-nearest neighbors rule.
+     # The current implementation uses the Euclidean distance for finding the neighbors.
+     #
+     # @example
+     #   estimator =
+     #     Rumale::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 5)
+     #   estimator.fit(training_samples, training_labels)
+     #   results = estimator.predict(testing_samples)
+     #
+     class KNeighborsClassifier
+       include Base::BaseEstimator
+       include Base::Classifier
+
+       # Return the prototypes for the nearest neighbor classifier.
+       # @return [Numo::DFloat] (shape: [n_samples, n_features])
+       attr_reader :prototypes
+
+       # Return the labels of the prototypes.
+       # @return [Numo::Int32] (size: n_samples)
+       attr_reader :labels
+
+       # Return the class labels.
+       # @return [Numo::Int32] (size: n_classes)
+       attr_reader :classes
+
+       # Create a new classifier with the nearest neighbor rule.
+       #
+       # @param n_neighbors [Integer] The number of neighbors.
+       def initialize(n_neighbors: 5)
+         check_params_integer(n_neighbors: n_neighbors)
+         check_params_positive(n_neighbors: n_neighbors)
+         @params = {}
+         @params[:n_neighbors] = n_neighbors
+         @prototypes = nil
+         @labels = nil
+         @classes = nil
+       end
+
+       # Fit the model with given training data.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+       # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+       # @return [KNeighborsClassifier] The learned classifier itself.
+       def fit(x, y)
+         check_sample_array(x)
+         check_label_array(y)
+         check_sample_label_size(x, y)
+         @prototypes = Numo::DFloat.asarray(x.to_a)
+         @labels = Numo::Int32.asarray(y.to_a)
+         @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
+         self
+       end
+
+       # Calculate confidence scores for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
+       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
+       def decision_function(x)
+         check_sample_array(x)
+         distance_matrix = PairwiseMetric.euclidean_distance(x, @prototypes)
+         n_samples, n_prototypes = distance_matrix.shape
+         n_classes = @classes.size
+         n_neighbors = [@params[:n_neighbors], n_prototypes].min
+         scores = Numo::DFloat.zeros(n_samples, n_classes)
+         n_samples.times do |m|
+           neighbor_ids = distance_matrix[m, true].to_a.each_with_index.sort.map(&:last)[0...n_neighbors]
+           neighbor_ids.each { |n| scores[m, @classes.to_a.index(@labels[n])] += 1.0 }
+         end
+         scores
+       end
+
+       # Predict class labels for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
+       # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
+       def predict(x)
+         check_sample_array(x)
+         n_samples = x.shape.first
+         decision_values = decision_function(x)
+         Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data about KNeighborsClassifier.
+       def marshal_dump
+         { params: @params,
+           prototypes: @prototypes,
+           labels: @labels,
+           classes: @classes }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @params = obj[:params]
+         @prototypes = obj[:prototypes]
+         @labels = obj[:labels]
+         @classes = obj[:classes]
+         nil
+       end
+     end
+   end
+ end
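
A minimal usage sketch of the classifier added above, using only the API this hunk defines (`new(n_neighbors:)`, `fit`, `predict`); the toy samples and labels are invented for illustration:

  require 'rumale'

  # Two well-separated clusters: class 0 near the origin, class 1 near (5, 5).
  samples = Numo::DFloat[[0.0, 0.1], [0.2, 0.0], [5.0, 5.1], [4.9, 5.0]]
  labels  = Numo::Int32[0, 0, 1, 1]

  estimator = Rumale::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 3)
  estimator.fit(samples, labels)

  # Each query point receives the majority label among its 3 nearest prototypes.
  p estimator.predict(Numo::DFloat[[0.1, 0.1], [5.0, 5.0]]).to_a # => [0, 1]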
data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb
@@ -0,0 +1,93 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/base_estimator'
+ require 'rumale/base/regressor'
+
+ module Rumale
+   module NearestNeighbors
+     # KNeighborsRegressor is a class that implements the regressor with the k-nearest neighbors rule.
+     # The current implementation uses the Euclidean distance for finding the neighbors.
+     #
+     # @example
+     #   estimator =
+     #     Rumale::NearestNeighbors::KNeighborsRegressor.new(n_neighbors: 5)
+     #   estimator.fit(training_samples, training_target_values)
+     #   results = estimator.predict(testing_samples)
+     #
+     class KNeighborsRegressor
+       include Base::BaseEstimator
+       include Base::Regressor
+
+       # Return the prototypes for the nearest neighbor regressor.
+       # @return [Numo::DFloat] (shape: [n_samples, n_features])
+       attr_reader :prototypes
+
+       # Return the values of the prototypes.
+       # @return [Numo::DFloat] (shape: [n_samples, n_outputs])
+       attr_reader :values
+
+       # Create a new regressor with the nearest neighbor rule.
+       #
+       # @param n_neighbors [Integer] The number of neighbors.
+       def initialize(n_neighbors: 5)
+         check_params_integer(n_neighbors: n_neighbors)
+         check_params_positive(n_neighbors: n_neighbors)
+         @params = {}
+         @params[:n_neighbors] = n_neighbors
+         @prototypes = nil
+         @values = nil
+       end
+
+       # Fit the model with given training data.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+       # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
+       # @return [KNeighborsRegressor] The learned regressor itself.
+       def fit(x, y)
+         check_sample_array(x)
+         check_tvalue_array(y)
+         check_sample_tvalue_size(x, y)
+         @prototypes = x.dup
+         @values = y.dup
+         self
+       end
+
+       # Predict values for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
+       # @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.
+       def predict(x)
+         check_sample_array(x)
+         # Initialize some variables.
+         n_samples, = x.shape
+         n_prototypes, n_outputs = @values.shape
+         n_neighbors = [@params[:n_neighbors], n_prototypes].min
+         # Calculate distance matrix.
+         distance_matrix = PairwiseMetric.euclidean_distance(x, @prototypes)
+         # Predict values for the given samples.
+         predicted_values = Array.new(n_samples) do |n|
+           neighbor_ids = distance_matrix[n, true].to_a.each_with_index.sort.map(&:last)[0...n_neighbors]
+           n_outputs.nil? ? @values[neighbor_ids].mean : @values[neighbor_ids, true].mean(0).to_a
+         end
+         Numo::DFloat[*predicted_values]
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data about KNeighborsRegressor.
+       def marshal_dump
+         { params: @params,
+           prototypes: @prototypes,
+           values: @values }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @params = obj[:params]
+         @prototypes = obj[:prototypes]
+         @values = obj[:values]
+         nil
+       end
+     end
+   end
+ end
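
As `predict` above shows, each prediction is the mean target value of the nearest prototypes (column-wise for multi-output targets). A minimal sketch with invented data:

  require 'rumale'

  # Toy 1-D regression: targets are roughly 2 * x.
  samples = Numo::DFloat[[1.0], [2.0], [3.0], [4.0]]
  values  = Numo::DFloat[2.0, 4.1, 5.9, 8.0]

  estimator = Rumale::NearestNeighbors::KNeighborsRegressor.new(n_neighbors: 2)
  estimator.fit(samples, values)

  # The prediction for x = 2.5 averages the targets of its two nearest
  # prototypes (x = 2.0 and x = 3.0): (4.1 + 5.9) / 2 = 5.0.
  p estimator.predict(Numo::DFloat[[2.5]]).to_a # => [5.0]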
data/lib/rumale/optimizer/nadam.rb
@@ -0,0 +1,90 @@
+ # frozen_string_literal: true
+
+ require 'rumale/validation'
+ require 'rumale/base/base_estimator'
+
+ module Rumale
+   # This module consists of the classes that implement optimizers adaptively tuning hyperparameters.
+   module Optimizer
+     # Nadam is a class that implements the Nadam optimizer.
+     #
+     # @example
+     #   optimizer = Rumale::Optimizer::Nadam.new(learning_rate: 0.01, momentum: 0.9, decay1: 0.9, decay2: 0.999)
+     #   estimator = Rumale::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1)
+     #   estimator.fit(samples, values)
+     #
+     # *Reference*
+     # - T. Dozat, "Incorporating Nesterov Momentum into Adam," Tech. Rep., Stanford University, 2015.
+     class Nadam
+       include Base::BaseEstimator
+       include Validation
+
+       # Create a new optimizer with Nadam.
+       #
+       # @param learning_rate [Float] The initial value of learning rate.
+       # @param momentum [Float] The initial value of momentum.
+       # @param decay1 [Float] The smoothing parameter for the first moment.
+       # @param decay2 [Float] The smoothing parameter for the second moment.
+       def initialize(learning_rate: 0.01, momentum: 0.9, decay1: 0.9, decay2: 0.999)
+         check_params_float(learning_rate: learning_rate, momentum: momentum, decay1: decay1, decay2: decay2)
+         check_params_positive(learning_rate: learning_rate, momentum: momentum, decay1: decay1, decay2: decay2)
+         @params = {}
+         @params[:learning_rate] = learning_rate
+         @params[:momentum] = momentum
+         @params[:decay1] = decay1
+         @params[:decay2] = decay2
+         @fst_moment = nil
+         @sec_moment = nil
+         @decay1_prod = 1.0
+         @iter = 0
+       end
+
+       # Calculate the updated weight with the Nadam adaptive learning rate.
+       #
+       # @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
+       # @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
+       # @return [Numo::DFloat] (shape: [n_features]) The updated weight.
+       def call(weight, gradient)
+         @fst_moment ||= Numo::DFloat.zeros(weight.shape[0])
+         @sec_moment ||= Numo::DFloat.zeros(weight.shape[0])
+
+         @iter += 1
+
+         decay1_curr = @params[:decay1] * (1.0 - 0.5 * 0.96**(@iter * 0.004))
+         decay1_next = @params[:decay1] * (1.0 - 0.5 * 0.96**((@iter + 1) * 0.004))
+         decay1_prod_curr = @decay1_prod * decay1_curr
+         decay1_prod_next = @decay1_prod * decay1_curr * decay1_next
+         @decay1_prod = decay1_prod_curr
+
+         @fst_moment = @params[:decay1] * @fst_moment + (1.0 - @params[:decay1]) * gradient
+         @sec_moment = @params[:decay2] * @sec_moment + (1.0 - @params[:decay2]) * gradient**2
+         nm_gradient = gradient / (1.0 - decay1_prod_curr)
+         nm_fst_moment = @fst_moment / (1.0 - decay1_prod_next)
+         nm_sec_moment = @sec_moment / (1.0 - @params[:decay2]**@iter)
+
+         weight - (@params[:learning_rate] / (nm_sec_moment**0.5 + 1e-8)) * ((1 - decay1_curr) * nm_gradient + decay1_next * nm_fst_moment)
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data.
+       def marshal_dump
+         { params: @params,
+           fst_moment: @fst_moment,
+           sec_moment: @sec_moment,
+           decay1_prod: @decay1_prod,
+           iter: @iter }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @params = obj[:params]
+         @fst_moment = obj[:fst_moment]
+         @sec_moment = obj[:sec_moment]
+         @decay1_prod = obj[:decay1_prod]
+         @iter = obj[:iter]
+         nil
+       end
+     end
+   end
+ end
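
Besides being plugged into a linear model as in the @example, the optimizer can be driven directly through the `call(weight, gradient)` interface defined above. A sketch on an invented quadratic objective; the exact convergence behavior depends on the hyperparameters:

  require 'rumale'

  optimizer = Rumale::Optimizer::Nadam.new(learning_rate: 0.1, decay1: 0.9, decay2: 0.999)

  # Minimize f(w) = w . w by feeding its gradient 2 * w back into the optimizer;
  # the bias-corrected moment estimates should drive both components toward zero.
  weight = Numo::DFloat[1.0, -2.0]
  200.times { weight = optimizer.call(weight, 2.0 * weight) }
  p weight.to_a # both entries end up near 0.0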
data/lib/rumale/optimizer/rmsprop.rb
@@ -0,0 +1,69 @@
+ # frozen_string_literal: true
+
+ require 'rumale/validation'
+ require 'rumale/base/base_estimator'
+
+ module Rumale
+   module Optimizer
+     # RMSProp is a class that implements the RMSProp optimizer.
+     #
+     # @example
+     #   optimizer = Rumale::Optimizer::RMSProp.new(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
+     #   estimator = Rumale::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1)
+     #   estimator.fit(samples, values)
+     #
+     # *Reference*
+     # - I. Sutskever, J. Martens, G. Dahl, and G. Hinton, "On the importance of initialization and momentum in deep learning," Proc. ICML'13, pp. 1139--1147, 2013.
+     # - G. Hinton, N. Srivastava, and K. Swersky, "Lecture 6e rmsprop," Neural Networks for Machine Learning, 2012.
+     class RMSProp
+       include Base::BaseEstimator
+       include Validation
+
+       # Create a new optimizer with RMSProp.
+       #
+       # @param learning_rate [Float] The initial value of learning rate.
+       # @param momentum [Float] The initial value of momentum.
+       # @param decay [Float] The smoothing parameter.
+       def initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
+         check_params_float(learning_rate: learning_rate, momentum: momentum, decay: decay)
+         check_params_positive(learning_rate: learning_rate, momentum: momentum, decay: decay)
+         @params = {}
+         @params[:learning_rate] = learning_rate
+         @params[:momentum] = momentum
+         @params[:decay] = decay
+         @moment = nil
+         @update = nil
+       end
+
+       # Calculate the updated weight with the RMSProp adaptive learning rate.
+       #
+       # @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
+       # @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
+       # @return [Numo::DFloat] (shape: [n_features]) The updated weight.
+       def call(weight, gradient)
+         @moment ||= Numo::DFloat.zeros(weight.shape[0])
+         @update ||= Numo::DFloat.zeros(weight.shape[0])
+         @moment = @params[:decay] * @moment + (1.0 - @params[:decay]) * gradient**2
+         @update = @params[:momentum] * @update - (@params[:learning_rate] / (@moment**0.5 + 1.0e-8)) * gradient
+         weight + @update
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data.
+       def marshal_dump
+         { params: @params,
+           moment: @moment,
+           update: @update }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @params = obj[:params]
+         @moment = obj[:moment]
+         @update = obj[:update]
+         nil
+       end
+     end
+   end
+ end
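
The same direct-call pattern applies to RMSProp. Because each step is normalized by the running RMS of the gradient, the iterate tends to hover near the optimum rather than settle exactly on it; the quadratic below is invented for illustration:

  require 'rumale'

  optimizer = Rumale::Optimizer::RMSProp.new(learning_rate: 0.01, momentum: 0.9, decay: 0.9)

  # Gradient of f(w) = w . w is 2 * w; the RMS-normalized, momentum-smoothed
  # steps carry the weight toward the origin.
  weight = Numo::DFloat[1.0, -2.0]
  500.times { weight = optimizer.call(weight, 2.0 * weight) }
  p weight.to_a # entries close to 0.0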
data/lib/rumale/optimizer/sgd.rb
@@ -0,0 +1,65 @@
+ # frozen_string_literal: true
+
+ require 'rumale/validation'
+ require 'rumale/base/base_estimator'
+
+ module Rumale
+   module Optimizer
+     # SGD is a class that implements the SGD optimizer.
+     #
+     # @example
+     #   optimizer = Rumale::Optimizer::SGD.new(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
+     #   estimator = Rumale::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1)
+     #   estimator.fit(samples, values)
+     class SGD
+       include Base::BaseEstimator
+       include Validation
+
+       # Create a new optimizer with SGD.
+       #
+       # @param learning_rate [Float] The initial value of learning rate.
+       # @param momentum [Float] The initial value of momentum.
+       # @param decay [Float] The smoothing parameter.
+       def initialize(learning_rate: 0.01, momentum: 0.0, decay: 0.0)
+         check_params_float(learning_rate: learning_rate, momentum: momentum, decay: decay)
+         check_params_positive(learning_rate: learning_rate, momentum: momentum, decay: decay)
+         @params = {}
+         @params[:learning_rate] = learning_rate
+         @params[:momentum] = momentum
+         @params[:decay] = decay
+         @iter = 0
+         @update = nil
+       end
+
+       # Calculate the updated weight with SGD.
+       #
+       # @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
+       # @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
+       # @return [Numo::DFloat] (shape: [n_features]) The updated weight.
+       def call(weight, gradient)
+         @update ||= Numo::DFloat.zeros(weight.shape[0])
+         current_learning_rate = @params[:learning_rate] / (1.0 + @params[:decay] * @iter)
+         @iter += 1
+         @update = @params[:momentum] * @update - current_learning_rate * gradient
+         weight + @update
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data.
+       def marshal_dump
+         { params: @params,
+           iter: @iter,
+           update: @update }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @params = obj[:params]
+         @iter = obj[:iter]
+         @update = obj[:update]
+         nil
+       end
+     end
+   end
+ end
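
The decay schedule in `call` shrinks the step hyperbolically: at iteration t the effective learning rate is learning_rate / (1 + decay * t). A short sketch with invented values makes the schedule visible:

  require 'rumale'

  optimizer = Rumale::Optimizer::SGD.new(learning_rate: 0.1, momentum: 0.0, decay: 0.5)

  # With a constant gradient of 1.0 the update magnitude equals the current
  # learning rate: 0.1 / 1.0 = 0.1, then 0.1 / 1.5 ~ 0.0667, then 0.1 / 2.0 = 0.05.
  weight = Numo::DFloat[1.0]
  3.times do
    updated = optimizer.call(weight, Numo::DFloat[1.0])
    step = (weight - updated)[0]
    p step
    weight = updated
  end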
data/lib/rumale/optimizer/yellow_fin.rb
@@ -0,0 +1,144 @@
+ # frozen_string_literal: true
+
+ require 'rumale/validation'
+ require 'rumale/base/base_estimator'
+
+ module Rumale
+   module Optimizer
+     # YellowFin is a class that implements the YellowFin optimizer.
+     #
+     # @example
+     #   optimizer = Rumale::Optimizer::YellowFin.new(learning_rate: 0.01, momentum: 0.9, decay: 0.999, window_width: 20)
+     #   estimator = Rumale::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1)
+     #   estimator.fit(samples, values)
+     #
+     # *Reference*
+     # - J. Zhang and I. Mitliagkas, "YellowFin and the Art of Momentum Tuning," CoRR abs/1706.03471, 2017.
+     class YellowFin
+       include Base::BaseEstimator
+       include Validation
+
+       # Create a new optimizer with YellowFin.
+       #
+       # @param learning_rate [Float] The initial value of learning rate.
+       # @param momentum [Float] The initial value of momentum.
+       # @param decay [Float] The smoothing parameter.
+       # @param window_width [Integer] The sliding window width for searching curvature range.
+       def initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.999, window_width: 20)
+         check_params_float(learning_rate: learning_rate, momentum: momentum, decay: decay)
+         check_params_integer(window_width: window_width)
+         check_params_positive(learning_rate: learning_rate, momentum: momentum, decay: decay, window_width: window_width)
+         @params = {}
+         @params[:learning_rate] = learning_rate
+         @params[:momentum] = momentum
+         @params[:decay] = decay
+         @params[:window_width] = window_width
+         @smth_learning_rate = learning_rate
+         @smth_momentum = momentum
+         @grad_norms = nil
+         @grad_norm_min = 0.0
+         @grad_norm_max = 0.0
+         @grad_mean_sqr = 0.0
+         @grad_mean = 0.0
+         @grad_var = 0.0
+         @grad_norm_mean = 0.0
+         @curve_mean = 0.0
+         @distance_mean = 0.0
+         @update = nil
+       end
+
+       # Calculate the updated weight with the adaptive momentum coefficient and learning rate.
+       #
+       # @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
+       # @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
+       # @return [Numo::DFloat] (shape: [n_features]) The updated weight.
+       def call(weight, gradient)
+         @update ||= Numo::DFloat.zeros(weight.shape[0])
+         curvature_range(gradient)
+         gradient_variance(gradient)
+         distance_to_optimum(gradient)
+         @smth_momentum = @params[:decay] * @smth_momentum + (1 - @params[:decay]) * current_momentum
+         @smth_learning_rate = @params[:decay] * @smth_learning_rate + (1 - @params[:decay]) * current_learning_rate
+         @update = @smth_momentum * @update - @smth_learning_rate * gradient
+         weight + @update
+       end
+
+       private
+
+       def current_momentum
+         dr = Math.sqrt(@grad_norm_max / @grad_norm_min + 1.0e-8)
+         [cubic_root**2, ((dr - 1) / (dr + 1))**2].max
+       end
+
+       def current_learning_rate
+         (1.0 - Math.sqrt(@params[:momentum]))**2 / (@grad_norm_min + 1.0e-8)
+       end
+
+       def cubic_root
+         p = (@distance_mean**2 * @grad_norm_min**2) / (2 * @grad_var + 1.0e-8)
+         w3 = (-Math.sqrt(p**2 + 4.fdiv(27) * p**3) - p).fdiv(2)
+         w = (w3 >= 0.0 ? 1 : -1) * w3.abs**1.fdiv(3)
+         y = w - p / (3 * w + 1.0e-8)
+         y + 1
+       end
+
+       def curvature_range(gradient)
+         @grad_norms ||= []
+         @grad_norms.push((gradient**2).sum)
+         @grad_norms.shift(@grad_norms.size - @params[:window_width]) if @grad_norms.size > @params[:window_width]
+         @grad_norm_min = @params[:decay] * @grad_norm_min + (1 - @params[:decay]) * @grad_norms.min
+         @grad_norm_max = @params[:decay] * @grad_norm_max + (1 - @params[:decay]) * @grad_norms.max
+       end
+
+       def gradient_variance(gradient)
+         @grad_mean_sqr = @params[:decay] * @grad_mean_sqr + (1 - @params[:decay]) * gradient**2
+         @grad_mean = @params[:decay] * @grad_mean + (1 - @params[:decay]) * gradient
+         @grad_var = (@grad_mean_sqr - @grad_mean**2).sum
+       end
+
+       def distance_to_optimum(gradient)
+         grad_sqr = (gradient**2).sum
+         @grad_norm_mean = @params[:decay] * @grad_norm_mean + (1 - @params[:decay]) * Math.sqrt(grad_sqr + 1.0e-8)
+         @curve_mean = @params[:decay] * @curve_mean + (1 - @params[:decay]) * grad_sqr
+         @distance_mean = @params[:decay] * @distance_mean + (1 - @params[:decay]) * (@grad_norm_mean / @curve_mean)
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data.
+       def marshal_dump
+         { params: @params,
+           smth_learning_rate: @smth_learning_rate,
+           smth_momentum: @smth_momentum,
+           grad_norms: @grad_norms,
+           grad_norm_min: @grad_norm_min,
+           grad_norm_max: @grad_norm_max,
+           grad_mean_sqr: @grad_mean_sqr,
+           grad_mean: @grad_mean,
+           grad_var: @grad_var,
+           grad_norm_mean: @grad_norm_mean,
+           curve_mean: @curve_mean,
+           distance_mean: @distance_mean,
+           update: @update }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @params = obj[:params]
+         @smth_learning_rate = obj[:smth_learning_rate]
+         @smth_momentum = obj[:smth_momentum]
+         @grad_norms = obj[:grad_norms]
+         @grad_norm_min = obj[:grad_norm_min]
+         @grad_norm_max = obj[:grad_norm_max]
+         @grad_mean_sqr = obj[:grad_mean_sqr]
+         @grad_mean = obj[:grad_mean]
+         @grad_var = obj[:grad_var]
+         @grad_norm_mean = obj[:grad_norm_mean]
+         @curve_mean = obj[:curve_mean]
+         @distance_mean = obj[:distance_mean]
+         @update = obj[:update]
+         nil
+       end
+     end
+   end
+ end
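
One consequence of the marshal_dump/marshal_load pairs defined throughout these files is that optimizer state survives a Marshal round trip, so a fitted estimator holding an optimizer can be saved and restored. A minimal sketch (assuming, as elsewhere in Rumale, that Base::BaseEstimator exposes a `params` reader):

  require 'rumale'

  optimizer = Rumale::Optimizer::YellowFin.new(learning_rate: 0.01, momentum: 0.9, decay: 0.999, window_width: 20)

  # Ruby's Marshal protocol calls marshal_dump/marshal_load (even when private),
  # so the smoothed learning rate, momentum, and curvature window are preserved.
  restored = Marshal.load(Marshal.dump(optimizer))
  p restored.params[:window_width] # => 20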