rumale 0.23.3 → 0.24.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (142) hide show
  1. checksums.yaml +4 -4
  2. data/LICENSE.txt +5 -1
  3. data/README.md +3 -288
  4. data/lib/rumale/version.rb +1 -1
  5. data/lib/rumale.rb +20 -131
  6. metadata +252 -150
  7. data/CHANGELOG.md +0 -643
  8. data/CODE_OF_CONDUCT.md +0 -74
  9. data/ext/rumale/extconf.rb +0 -37
  10. data/ext/rumale/rumaleext.c +0 -545
  11. data/ext/rumale/rumaleext.h +0 -12
  12. data/lib/rumale/base/base_estimator.rb +0 -49
  13. data/lib/rumale/base/classifier.rb +0 -36
  14. data/lib/rumale/base/cluster_analyzer.rb +0 -31
  15. data/lib/rumale/base/evaluator.rb +0 -17
  16. data/lib/rumale/base/regressor.rb +0 -36
  17. data/lib/rumale/base/splitter.rb +0 -21
  18. data/lib/rumale/base/transformer.rb +0 -22
  19. data/lib/rumale/clustering/dbscan.rb +0 -123
  20. data/lib/rumale/clustering/gaussian_mixture.rb +0 -218
  21. data/lib/rumale/clustering/hdbscan.rb +0 -291
  22. data/lib/rumale/clustering/k_means.rb +0 -122
  23. data/lib/rumale/clustering/k_medoids.rb +0 -141
  24. data/lib/rumale/clustering/mini_batch_k_means.rb +0 -139
  25. data/lib/rumale/clustering/power_iteration.rb +0 -127
  26. data/lib/rumale/clustering/single_linkage.rb +0 -203
  27. data/lib/rumale/clustering/snn.rb +0 -76
  28. data/lib/rumale/clustering/spectral_clustering.rb +0 -115
  29. data/lib/rumale/dataset.rb +0 -246
  30. data/lib/rumale/decomposition/factor_analysis.rb +0 -150
  31. data/lib/rumale/decomposition/fast_ica.rb +0 -188
  32. data/lib/rumale/decomposition/nmf.rb +0 -124
  33. data/lib/rumale/decomposition/pca.rb +0 -159
  34. data/lib/rumale/ensemble/ada_boost_classifier.rb +0 -179
  35. data/lib/rumale/ensemble/ada_boost_regressor.rb +0 -160
  36. data/lib/rumale/ensemble/extra_trees_classifier.rb +0 -139
  37. data/lib/rumale/ensemble/extra_trees_regressor.rb +0 -125
  38. data/lib/rumale/ensemble/gradient_boosting_classifier.rb +0 -306
  39. data/lib/rumale/ensemble/gradient_boosting_regressor.rb +0 -237
  40. data/lib/rumale/ensemble/random_forest_classifier.rb +0 -189
  41. data/lib/rumale/ensemble/random_forest_regressor.rb +0 -153
  42. data/lib/rumale/ensemble/stacking_classifier.rb +0 -215
  43. data/lib/rumale/ensemble/stacking_regressor.rb +0 -163
  44. data/lib/rumale/ensemble/voting_classifier.rb +0 -126
  45. data/lib/rumale/ensemble/voting_regressor.rb +0 -82
  46. data/lib/rumale/evaluation_measure/accuracy.rb +0 -29
  47. data/lib/rumale/evaluation_measure/adjusted_rand_score.rb +0 -74
  48. data/lib/rumale/evaluation_measure/calinski_harabasz_score.rb +0 -56
  49. data/lib/rumale/evaluation_measure/davies_bouldin_score.rb +0 -53
  50. data/lib/rumale/evaluation_measure/explained_variance_score.rb +0 -39
  51. data/lib/rumale/evaluation_measure/f_score.rb +0 -50
  52. data/lib/rumale/evaluation_measure/function.rb +0 -147
  53. data/lib/rumale/evaluation_measure/log_loss.rb +0 -45
  54. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +0 -29
  55. data/lib/rumale/evaluation_measure/mean_squared_error.rb +0 -29
  56. data/lib/rumale/evaluation_measure/mean_squared_log_error.rb +0 -29
  57. data/lib/rumale/evaluation_measure/median_absolute_error.rb +0 -30
  58. data/lib/rumale/evaluation_measure/mutual_information.rb +0 -49
  59. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +0 -53
  60. data/lib/rumale/evaluation_measure/precision.rb +0 -50
  61. data/lib/rumale/evaluation_measure/precision_recall.rb +0 -96
  62. data/lib/rumale/evaluation_measure/purity.rb +0 -40
  63. data/lib/rumale/evaluation_measure/r2_score.rb +0 -43
  64. data/lib/rumale/evaluation_measure/recall.rb +0 -50
  65. data/lib/rumale/evaluation_measure/roc_auc.rb +0 -130
  66. data/lib/rumale/evaluation_measure/silhouette_score.rb +0 -82
  67. data/lib/rumale/feature_extraction/feature_hasher.rb +0 -110
  68. data/lib/rumale/feature_extraction/hash_vectorizer.rb +0 -155
  69. data/lib/rumale/feature_extraction/tfidf_transformer.rb +0 -113
  70. data/lib/rumale/kernel_approximation/nystroem.rb +0 -126
  71. data/lib/rumale/kernel_approximation/rbf.rb +0 -102
  72. data/lib/rumale/kernel_machine/kernel_fda.rb +0 -120
  73. data/lib/rumale/kernel_machine/kernel_pca.rb +0 -97
  74. data/lib/rumale/kernel_machine/kernel_ridge.rb +0 -82
  75. data/lib/rumale/kernel_machine/kernel_ridge_classifier.rb +0 -92
  76. data/lib/rumale/kernel_machine/kernel_svc.rb +0 -193
  77. data/lib/rumale/linear_model/base_sgd.rb +0 -285
  78. data/lib/rumale/linear_model/elastic_net.rb +0 -119
  79. data/lib/rumale/linear_model/lasso.rb +0 -115
  80. data/lib/rumale/linear_model/linear_regression.rb +0 -201
  81. data/lib/rumale/linear_model/logistic_regression.rb +0 -275
  82. data/lib/rumale/linear_model/nnls.rb +0 -137
  83. data/lib/rumale/linear_model/ridge.rb +0 -209
  84. data/lib/rumale/linear_model/svc.rb +0 -213
  85. data/lib/rumale/linear_model/svr.rb +0 -132
  86. data/lib/rumale/manifold/mds.rb +0 -155
  87. data/lib/rumale/manifold/tsne.rb +0 -222
  88. data/lib/rumale/metric_learning/fisher_discriminant_analysis.rb +0 -113
  89. data/lib/rumale/metric_learning/mlkr.rb +0 -161
  90. data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +0 -167
  91. data/lib/rumale/model_selection/cross_validation.rb +0 -125
  92. data/lib/rumale/model_selection/function.rb +0 -42
  93. data/lib/rumale/model_selection/grid_search_cv.rb +0 -225
  94. data/lib/rumale/model_selection/group_k_fold.rb +0 -93
  95. data/lib/rumale/model_selection/group_shuffle_split.rb +0 -115
  96. data/lib/rumale/model_selection/k_fold.rb +0 -81
  97. data/lib/rumale/model_selection/shuffle_split.rb +0 -90
  98. data/lib/rumale/model_selection/stratified_k_fold.rb +0 -99
  99. data/lib/rumale/model_selection/stratified_shuffle_split.rb +0 -118
  100. data/lib/rumale/model_selection/time_series_split.rb +0 -91
  101. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +0 -83
  102. data/lib/rumale/naive_bayes/base_naive_bayes.rb +0 -47
  103. data/lib/rumale/naive_bayes/bernoulli_nb.rb +0 -82
  104. data/lib/rumale/naive_bayes/complement_nb.rb +0 -85
  105. data/lib/rumale/naive_bayes/gaussian_nb.rb +0 -69
  106. data/lib/rumale/naive_bayes/multinomial_nb.rb +0 -74
  107. data/lib/rumale/naive_bayes/negation_nb.rb +0 -71
  108. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +0 -133
  109. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +0 -108
  110. data/lib/rumale/nearest_neighbors/vp_tree.rb +0 -132
  111. data/lib/rumale/neural_network/adam.rb +0 -56
  112. data/lib/rumale/neural_network/base_mlp.rb +0 -248
  113. data/lib/rumale/neural_network/mlp_classifier.rb +0 -120
  114. data/lib/rumale/neural_network/mlp_regressor.rb +0 -90
  115. data/lib/rumale/pairwise_metric.rb +0 -152
  116. data/lib/rumale/pipeline/feature_union.rb +0 -69
  117. data/lib/rumale/pipeline/pipeline.rb +0 -175
  118. data/lib/rumale/preprocessing/bin_discretizer.rb +0 -93
  119. data/lib/rumale/preprocessing/binarizer.rb +0 -60
  120. data/lib/rumale/preprocessing/kernel_calculator.rb +0 -92
  121. data/lib/rumale/preprocessing/l1_normalizer.rb +0 -62
  122. data/lib/rumale/preprocessing/l2_normalizer.rb +0 -63
  123. data/lib/rumale/preprocessing/label_binarizer.rb +0 -89
  124. data/lib/rumale/preprocessing/label_encoder.rb +0 -79
  125. data/lib/rumale/preprocessing/max_abs_scaler.rb +0 -61
  126. data/lib/rumale/preprocessing/max_normalizer.rb +0 -62
  127. data/lib/rumale/preprocessing/min_max_scaler.rb +0 -76
  128. data/lib/rumale/preprocessing/one_hot_encoder.rb +0 -100
  129. data/lib/rumale/preprocessing/ordinal_encoder.rb +0 -109
  130. data/lib/rumale/preprocessing/polynomial_features.rb +0 -109
  131. data/lib/rumale/preprocessing/standard_scaler.rb +0 -71
  132. data/lib/rumale/probabilistic_output.rb +0 -114
  133. data/lib/rumale/tree/base_decision_tree.rb +0 -150
  134. data/lib/rumale/tree/decision_tree_classifier.rb +0 -150
  135. data/lib/rumale/tree/decision_tree_regressor.rb +0 -116
  136. data/lib/rumale/tree/extra_tree_classifier.rb +0 -107
  137. data/lib/rumale/tree/extra_tree_regressor.rb +0 -94
  138. data/lib/rumale/tree/gradient_tree_regressor.rb +0 -202
  139. data/lib/rumale/tree/node.rb +0 -39
  140. data/lib/rumale/utils.rb +0 -42
  141. data/lib/rumale/validation.rb +0 -128
  142. data/lib/rumale/values.rb +0 -13
@@ -1,167 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/base/base_estimator'
4
- require 'rumale/base/transformer'
5
- require 'rumale/utils'
6
- require 'rumale/pairwise_metric'
7
- require 'lbfgsb'
8
-
9
- module Rumale
10
- module MetricLearning
11
- # NeighbourhoodComponentAnalysis is a class that implements Neighbourhood Component Analysis.
12
- #
13
- # @example
14
- # require 'rumale'
15
- #
16
- # transformer = Rumale::MetricLearning::NeighbourhoodComponentAnalysis.new
17
- # transformer.fit(training_samples, training_labels)
18
- # low_samples = transformer.transform(testing_samples)
19
- #
20
- # *Reference*
21
- # - Goldberger, J., Roweis, S., Hinton, G., and Salakhutdinov, R., "Neighbourhood Component Analysis," Advances in NIPS'17, pp. 513--520, 2005.
22
- class NeighbourhoodComponentAnalysis
23
- include Base::BaseEstimator
24
- include Base::Transformer
25
-
26
- # Returns the neighbourhood components.
27
- # @return [Numo::DFloat] (shape: [n_components, n_features])
28
- attr_reader :components
29
-
30
- # Return the number of iterations run for optimization
31
- # @return [Integer]
32
- attr_reader :n_iter
33
-
34
- # Return the random generator.
35
- # @return [Random]
36
- attr_reader :rng
37
-
38
- # Create a new transformer with NeighbourhoodComponentAnalysis.
39
- #
40
- # @param n_components [Integer] The number of components.
41
- # @param init [String] The initialization method for components ('random' or 'pca').
42
- # @param max_iter [Integer] The maximum number of iterations.
43
- # @param tol [Float] The tolerance of termination criterion.
44
- # This value is given as tol / Lbfgsb::DBL_EPSILON to the factr argument of Lbfgsb.minimize method.
45
- # @param verbose [Boolean] The flag indicating whether to output loss during iteration.
46
- # If true is given, 'iterate.dat' file is generated by lbfgsb.rb.
47
- # @param random_seed [Integer] The seed value used to initialize the random generator.
48
- def initialize(n_components: nil, init: 'random', max_iter: 100, tol: 1e-6, verbose: false, random_seed: nil)
49
- check_params_numeric_or_nil(n_components: n_components, random_seed: random_seed)
50
- check_params_numeric(max_iter: max_iter, tol: tol)
51
- check_params_string(init: init)
52
- check_params_boolean(verbose: verbose)
53
- @params = {}
54
- @params[:n_components] = n_components
55
- @params[:init] = init
56
- @params[:max_iter] = max_iter
57
- @params[:tol] = tol
58
- @params[:verbose] = verbose
59
- @params[:random_seed] = random_seed
60
- @params[:random_seed] ||= srand
61
- @components = nil
62
- @n_iter = nil
63
- @rng = Random.new(@params[:random_seed])
64
- end
65
-
66
- # Fit the model with given training data.
67
- #
68
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
69
- # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
70
- # @return [NeighbourhoodComponentAnalysis] The learned transformer itself.
71
- def fit(x, y)
72
- x = check_convert_sample_array(x)
73
- y = check_convert_label_array(y)
74
- check_sample_label_size(x, y)
75
- n_features = x.shape[1]
76
- n_components = if @params[:n_components].nil?
77
- n_features
78
- else
79
- [n_features, @params[:n_components]].min
80
- end
81
- @components, @n_iter = optimize_components(x, y, n_features, n_components)
82
- self
83
- end
84
-
85
- # Fit the model with training data, and then transform them with the learned model.
86
- #
87
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
88
- # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
89
- # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
90
- def fit_transform(x, y)
91
- x = check_convert_sample_array(x)
92
- y = check_convert_label_array(y)
93
- fit(x, y).transform(x)
94
- end
95
-
96
- # Transform the given data with the learned model.
97
- #
98
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
99
- # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
100
- def transform(x)
101
- x = check_convert_sample_array(x)
102
- x.dot(@components.transpose)
103
- end
104
-
105
- private
106
-
107
- def init_components(x, n_features, n_components)
108
- if @params[:init] == 'pca'
109
- pca = Rumale::Decomposition::PCA.new(n_components: n_components)
110
- pca.fit(x).components.flatten.dup
111
- else
112
- Rumale::Utils.rand_normal([n_features, n_components], @rng.dup).flatten.dup
113
- end
114
- end
115
-
116
- def optimize_components(x, y, n_features, n_components)
117
- # initialize components.
118
- comp_init = init_components(x, n_features, n_components)
119
- # initialize optimization results.
120
- res = {}
121
- res[:x] = comp_init
122
- res[:n_iter] = 0
123
- # perform optimization.
124
- verbose = @params[:verbose] ? 1 : -1
125
- res = Lbfgsb.minimize(
126
- fnc: method(:nca_fnc), jcb: true, x_init: comp_init, args: [x, y],
127
- maxiter: @params[:max_iter], factr: @params[:tol] / Lbfgsb::DBL_EPSILON, verbose: verbose
128
- )
129
- # return the results.
130
- n_iter = res[:n_iter]
131
- comps = n_components == 1 ? res[:x].dup : res[:x].reshape(n_components, n_features)
132
- [comps, n_iter]
133
- end
134
-
135
- def nca_fnc(w, x, y)
136
- # initialize some variables.
137
- n_samples, n_features = x.shape
138
- n_components = w.size / n_features
139
- # projection.
140
- w = w.reshape(n_components, n_features)
141
- z = x.dot(w.transpose)
142
- # calculate probability matrix.
143
- prob_mat = probability_matrix(z)
144
- # calculate loss and gradient.
145
- # NOTE:
146
- # NCA attempts to maximize its objective function.
147
- # For the minimization algorithm, the objective function value is subtracted from the maximum value (n_samples).
148
- mask_mat = y.expand_dims(1).eq(y)
149
- masked_prob_mat = prob_mat * mask_mat
150
- loss = n_samples - masked_prob_mat.sum
151
- sum_probs = masked_prob_mat.sum(1)
152
- weight_mat = (sum_probs.expand_dims(1) * prob_mat - masked_prob_mat)
153
- weight_mat += weight_mat.transpose
154
- weight_mat = weight_mat.sum(0).diag - weight_mat
155
- gradient = -2 * z.transpose.dot(weight_mat).dot(x)
156
- [loss, gradient.flatten.dup]
157
- end
158
-
159
- def probability_matrix(z)
160
- prob_mat = Numo::NMath.exp(-Rumale::PairwiseMetric.squared_error(z))
161
- prob_mat[prob_mat.diag_indices] = 0.0
162
- prob_mat /= prob_mat.sum(1).expand_dims(1)
163
- prob_mat
164
- end
165
- end
166
- end
167
- end
@@ -1,125 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
- require 'rumale/base/base_estimator'
5
- require 'rumale/base/classifier'
6
- require 'rumale/base/regressor'
7
- require 'rumale/base/splitter'
8
- require 'rumale/base/evaluator'
9
- require 'rumale/evaluation_measure/log_loss'
10
-
11
- module Rumale
12
- # This module consists of the classes for model validation techniques.
13
- module ModelSelection
14
- # CrossValidation is a class that evaluates a given classifier with cross-validation method.
15
- #
16
- # @example
17
- # svc = Rumale::LinearModel::SVC.new
18
- # kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
19
- # cv = Rumale::ModelSelection::CrossValidation.new(estimator: svc, splitter: kf)
20
- # report = cv.perform(samples, labels)
21
- # mean_test_score = report[:test_score].inject(:+) / kf.n_splits
22
- #
23
- class CrossValidation
24
- include Validation
25
-
26
- # Return the classifier of which performance is evaluated.
27
- # @return [Classifier]
28
- attr_reader :estimator
29
-
30
- # Return the splitter that divides dataset.
31
- # @return [Splitter]
32
- attr_reader :splitter
33
-
34
- # Return the evaluator that calculates score.
35
- # @return [Evaluator]
36
- attr_reader :evaluator
37
-
38
- # Return the flag indicating whether to calculate the score of training dataset.
39
- # @return [Boolean]
40
- attr_reader :return_train_score
41
-
42
- # Create a new evaluator with cross-validation method.
43
- #
44
- # @param estimator [Classifier] The classifier of which performance is evaluated.
45
- # @param splitter [Splitter] The splitter that divides dataset to training and testing dataset.
46
- # @param evaluator [Evaluator] The evaluator that calculates score of estimator results.
47
- # @param return_train_score [Boolean] The flag indicating whether to calculate the score of training dataset.
48
- def initialize(estimator: nil, splitter: nil, evaluator: nil, return_train_score: false)
49
- check_params_type(Rumale::Base::BaseEstimator, estimator: estimator)
50
- check_params_type(Rumale::Base::Splitter, splitter: splitter)
51
- check_params_type_or_nil(Rumale::Base::Evaluator, evaluator: evaluator)
52
- check_params_boolean(return_train_score: return_train_score)
53
- @estimator = estimator
54
- @splitter = splitter
55
- @evaluator = evaluator
56
- @return_train_score = return_train_score
57
- end
58
-
59
- # Perform the evaluation of given classifier with cross-validation method.
60
- #
61
- # @param x [Numo::DFloat] (shape: [n_samples, n_features])
62
- # The dataset to be used to evaluate the estimator.
63
- # @param y [Numo::Int32 / Numo::DFloat] (shape: [n_samples] / [n_samples, n_outputs])
64
- # The labels to be used to evaluate the classifier / The target values to be used to evaluate the regressor.
65
- # @return [Hash] The report summarizing the results of cross-validation.
66
- # * :fit_time (Array<Float>) The calculation times of fitting the estimator for each split.
67
- # * :test_score (Array<Float>) The scores of testing dataset for each split.
68
- # * :train_score (Array<Float>) The scores of training dataset for each split. This option is nil if
69
- # the return_train_score is false.
70
- def perform(x, y)
71
- x = check_convert_sample_array(x)
72
- case @estimator
73
- when Rumale::Base::Classifier
74
- y = check_convert_label_array(y)
75
- check_sample_label_size(x, y)
76
- when Rumale::Base::Regressor
77
- y = check_convert_tvalue_array(y)
78
- check_sample_tvalue_size(x, y)
79
- else
80
- y = Numo::NArray.asarray(y)
81
- end
82
- # Initialize the report of cross validation.
83
- report = { test_score: [], train_score: nil, fit_time: [] }
84
- report[:train_score] = [] if @return_train_score
85
- # Evaluate the estimator on each split.
86
- @splitter.split(x, y).each do |train_ids, test_ids|
87
- # Split dataset into training and testing dataset.
88
- feature_ids = !kernel_machine? || train_ids
89
- train_x = x[train_ids, feature_ids]
90
- train_y = y.shape[1].nil? ? y[train_ids] : y[train_ids, true]
91
- test_x = x[test_ids, feature_ids]
92
- test_y = y.shape[1].nil? ? y[test_ids] : y[test_ids, true]
93
- # Fit the estimator.
94
- start_time = Time.now.to_i
95
- @estimator.fit(train_x, train_y)
96
- # Calculate scores and prepare the report.
97
- report[:fit_time].push(Time.now.to_i - start_time)
98
- if @evaluator.nil?
99
- report[:test_score].push(@estimator.score(test_x, test_y))
100
- report[:train_score].push(@estimator.score(train_x, train_y)) if @return_train_score
101
- elsif log_loss?
102
- report[:test_score].push(@evaluator.score(test_y, @estimator.predict_proba(test_x)))
103
- report[:train_score].push(@evaluator.score(train_y, @estimator.predict_proba(train_x))) if @return_train_score
104
- else
105
- report[:test_score].push(@evaluator.score(test_y, @estimator.predict(test_x)))
106
- report[:train_score].push(@evaluator.score(train_y, @estimator.predict(train_x))) if @return_train_score
107
- end
108
- end
109
- report
110
- end
111
-
112
- private
113
-
114
- def kernel_machine?
115
- class_name = @estimator.class.to_s
116
- class_name = @estimator.params[:estimator].class.to_s if class_name.include?('Multiclass')
117
- class_name.include?('KernelMachine')
118
- end
119
-
120
- def log_loss?
121
- @evaluator.is_a?(Rumale::EvaluationMeasure::LogLoss)
122
- end
123
- end
124
- end
125
- end
@@ -1,42 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/model_selection/shuffle_split'
4
- require 'rumale/model_selection/stratified_shuffle_split'
5
-
6
- module Rumale
7
- module ModelSelection
8
- module_function
9
-
10
- # Randomly split the data set into test and train data.
11
- #
12
- # @example
13
- # x_train, x_test, y_train, y_test = Rumale::ModelSelection.train_test_split(x, y, test_size: 0.2, stratify: true, random_seed: 1)
14
- #
15
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The dataset to be used to generate data indices.
16
- # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used to generate data indices for stratified random permutation.
17
- # If stratify = false, this parameter is ignored.
18
- # @param test_size [Float] The ratio of number of samples for test data.
19
- # @param train_size [Float] The ratio of number of samples for train data.
20
- # If nil is given, it sets to 1 - test_size.
21
- # @param stratify [Boolean] The flag indicating whether to perform stratify split.
22
- # @param random_seed [Integer] The seed value used to initialize the random generator.
23
- # @return [Array<Numo::NArray>] The set of training and testing data.
24
- def train_test_split(x, y = nil, test_size: 0.1, train_size: nil, stratify: false, random_seed: nil)
25
- splitter = if stratify
26
- Rumale::ModelSelection::StratifiedShuffleSplit.new(
27
- n_splits: 1, test_size: test_size, train_size: train_size, random_seed: random_seed
28
- )
29
- else
30
- Rumale::ModelSelection::ShuffleSplit.new(
31
- n_splits: 1, test_size: test_size, train_size: train_size, random_seed: random_seed
32
- )
33
- end
34
- train_ids, test_ids = splitter.split(x, y).first
35
- x_train = x[train_ids, true].dup
36
- y_train = y[train_ids].dup
37
- x_test = x[test_ids, true].dup
38
- y_test = y[test_ids].dup
39
- [x_train, x_test, y_train, y_test]
40
- end
41
- end
42
- end
@@ -1,225 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/validation'
4
- require 'rumale/base/base_estimator'
5
- require 'rumale/base/evaluator'
6
- require 'rumale/base/splitter'
7
- require 'rumale/pipeline/pipeline'
8
-
9
- module Rumale
10
- module ModelSelection
11
- # GridSearchCV is a class that performs hyperparameter optimization with grid search method.
12
- #
13
- # @example
14
- # rfc = Rumale::Ensemble::RandomForestClassifier.new(random_seed: 1)
15
- # pg = { n_estimators: [5, 10], max_depth: [3, 5], max_leaf_nodes: [15, 31] }
16
- # kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
17
- # gs = Rumale::ModelSelection::GridSearchCV.new(estimator: rfc, param_grid: pg, splitter: kf)
18
- # gs.fit(samples, labels)
19
- # p gs.cv_results
20
- # p gs.best_params
21
- #
22
- # @example
23
- # rbf = Rumale::KernelApproximation::RBF.new(random_seed: 1)
24
- # svc = Rumale::LinearModel::SVC.new(random_seed: 1)
25
- # pipe = Rumale::Pipeline::Pipeline.new(steps: { rbf: rbf, svc: svc })
26
- # pg = { rbf__gamma: [32.0, 1.0], rbf__n_components: [4, 128], svc__reg_param: [16.0, 0.1] }
27
- # kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
28
- # gs = Rumale::ModelSelection::GridSearchCV.new(estimator: pipe, param_grid: pg, splitter: kf)
29
- # gs.fit(samples, labels)
30
- # p gs.cv_results
31
- # p gs.best_params
32
- #
33
- class GridSearchCV
34
- include Base::BaseEstimator
35
- include Validation
36
-
37
- # Return the result of cross validation for each parameter.
38
- # @return [Hash]
39
- attr_reader :cv_results
40
-
41
- # Return the score of the estimator learned with the best parameter.
42
- # @return [Float]
43
- attr_reader :best_score
44
-
45
- # Return the best parameter set.
46
- # @return [Hash]
47
- attr_reader :best_params
48
-
49
- # Return the index of the best parameter.
50
- # @return [Integer]
51
- attr_reader :best_index
52
-
53
- # Return the estimator learned with the best parameter.
54
- # @return [Estimator]
55
- attr_reader :best_estimator
56
-
57
- # Create a new grid search method.
58
- #
59
- # @param estimator [Classifier/Regressor] The estimator to be searched for optimal parameters with grid search method.
60
- # @param param_grid [Array<Hash>] The parameter sets are represented as an array of hashes, each of which
61
- # consists of parameter names as keys and array of parameter values as values.
62
- # @param splitter [Splitter] The splitter that divides dataset to training and testing dataset on cross validation.
63
- # @param evaluator [Evaluator] The evaluator that calculates score of estimator results on cross validation.
64
- # If nil is given, the score method of estimator is used for evaluation.
65
- # @param greater_is_better [Boolean] The flag that indicates whether the estimator is better as
66
- # evaluation score is larger.
67
- def initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true)
68
- check_params_type(Rumale::Base::BaseEstimator, estimator: estimator)
69
- check_params_type(Rumale::Base::Splitter, splitter: splitter)
70
- check_params_type_or_nil(Rumale::Base::Evaluator, evaluator: evaluator)
71
- check_params_boolean(greater_is_better: greater_is_better)
72
- @params = {}
73
- @params[:param_grid] = valid_param_grid(param_grid)
74
- @params[:estimator] = Marshal.load(Marshal.dump(estimator))
75
- @params[:splitter] = Marshal.load(Marshal.dump(splitter))
76
- @params[:evaluator] = Marshal.load(Marshal.dump(evaluator))
77
- @params[:greater_is_better] = greater_is_better
78
- @cv_results = nil
79
- @best_score = nil
80
- @best_params = nil
81
- @best_index = nil
82
- @best_estimator = nil
83
- end
84
-
85
- # Fit the model with given training data and all sets of parameters.
86
- #
87
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
88
- # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) The target values or labels to be used for fitting the model.
89
- # @return [GridSearchCV] The learned estimator with grid search.
90
- def fit(x, y)
91
- x = check_convert_sample_array(x)
92
-
93
- init_attrs
94
-
95
- param_combinations.each do |prm_set|
96
- prm_set.each do |prms|
97
- report = perform_cross_validation(x, y, prms)
98
- store_cv_result(prms, report)
99
- end
100
- end
101
-
102
- find_best_params
103
-
104
- @best_estimator = configurated_estimator(@best_params)
105
- @best_estimator.fit(x, y)
106
- self
107
- end
108
-
109
- # Call the decision_function method of learned estimator with the best parameter.
110
- #
111
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
112
- # @return [Numo::DFloat] (shape: [n_samples]) Confidence score per sample.
113
- def decision_function(x)
114
- x = check_convert_sample_array(x)
115
- @best_estimator.decision_function(x)
116
- end
117
-
118
- # Call the predict method of learned estimator with the best parameter.
119
- #
120
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to obtain prediction result.
121
- # @return [Numo::NArray] Predicted results.
122
- def predict(x)
123
- x = check_convert_sample_array(x)
124
- @best_estimator.predict(x)
125
- end
126
-
127
- # Call the predict_log_proba method of learned estimator with the best parameter.
128
- #
129
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the log-probabilities.
130
- # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.
131
- def predict_log_proba(x)
132
- x = check_convert_sample_array(x)
133
- @best_estimator.predict_log_proba(x)
134
- end
135
-
136
- # Call the predict_proba method of learned estimator with the best parameter.
137
- #
138
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
139
- # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
140
- def predict_proba(x)
141
- x = check_convert_sample_array(x)
142
- @best_estimator.predict_proba(x)
143
- end
144
-
145
- # Call the score method of learned estimator with the best parameter.
146
- #
147
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
148
- # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) True target values or labels for testing data.
149
- # @return [Float] The score of estimator.
150
- def score(x, y)
151
- x = check_convert_sample_array(x)
152
- @best_estimator.score(x, y)
153
- end
154
-
155
- private
156
-
157
- def valid_param_grid(grid)
158
- raise TypeError, 'Expect class of param_grid to be Hash or Array' unless grid.is_a?(Hash) || grid.is_a?(Array)
159
-
160
- grid = [grid] if grid.is_a?(Hash)
161
- grid.each do |h|
162
- raise TypeError, 'Expect class of elements in param_grid to be Hash' unless h.is_a?(Hash)
163
- raise TypeError, 'Expect class of parameter values in param_grid to be Array' unless h.values.all?(Array)
164
- end
165
- grid
166
- end
167
-
168
- def param_combinations
169
- @param_combinations ||= @params[:param_grid].map do |prm|
170
- x = prm.sort.to_h.map { |k, v| [k].product(v) }
171
- x[0].product(*x[1...x.size]).map(&:to_h)
172
- end
173
- end
174
-
175
- def perform_cross_validation(x, y, prms)
176
- est = configurated_estimator(prms)
177
- cv = CrossValidation.new(estimator: est, splitter: @params[:splitter],
178
- evaluator: @params[:evaluator], return_train_score: true)
179
- cv.perform(x, y)
180
- end
181
-
182
- def configurated_estimator(prms)
183
- estimator = Marshal.load(Marshal.dump(@params[:estimator]))
184
- if @params[:estimator].is_a?(Rumale::Pipeline::Pipeline)
185
- prms.each do |k, v|
186
- est_name, prm_name = k.to_s.split('__')
187
- estimator.steps[est_name.to_sym].params[prm_name.to_sym] = v
188
- end
189
- else
190
- prms.each { |k, v| estimator.params[k] = v }
191
- end
192
- estimator
193
- end
194
-
195
- def init_attrs
196
- @cv_results = %i[mean_test_score std_test_score
197
- mean_train_score std_train_score
198
- mean_fit_time std_fit_time params].map { |v| [v, []] }.to_h
199
- @best_score = nil
200
- @best_params = nil
201
- @best_index = nil
202
- @best_estimator = nil
203
- end
204
-
205
- def store_cv_result(prms, report)
206
- test_scores = Numo::DFloat[*report[:test_score]]
207
- train_scores = Numo::DFloat[*report[:train_score]]
208
- fit_times = Numo::DFloat[*report[:fit_time]]
209
- @cv_results[:mean_test_score].push(test_scores.mean)
210
- @cv_results[:std_test_score].push(test_scores.stddev)
211
- @cv_results[:mean_train_score].push(train_scores.mean)
212
- @cv_results[:std_train_score].push(train_scores.stddev)
213
- @cv_results[:mean_fit_time].push(fit_times.mean)
214
- @cv_results[:std_fit_time].push(fit_times.stddev)
215
- @cv_results[:params].push(prms)
216
- end
217
-
218
- def find_best_params
219
- @best_score = @params[:greater_is_better] ? @cv_results[:mean_test_score].max : @cv_results[:mean_test_score].min
220
- @best_index = @cv_results[:mean_test_score].index(@best_score)
221
- @best_params = @cv_results[:params][@best_index]
222
- end
223
- end
224
- end
225
- end
@@ -1,93 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'rumale/base/splitter'
4
- require 'rumale/preprocessing/label_encoder'
5
-
6
- module Rumale
7
- module ModelSelection
8
- # GroupKFold is a class that generates the set of data indices for K-fold cross-validation.
9
- # The data points belonging to the same group are not split into different folds.
10
- # The number of groups should be greater than or equal to the number of splits.
11
- #
12
- # @example
13
- # cv = Rumale::ModelSelection::GroupKFold.new(n_splits: 3)
14
- # x = Numo::DFloat.new(8, 2).rand
15
- # groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
16
- # cv.split(x, nil, groups).each do |train_ids, test_ids|
17
- # puts '---'
18
- # pp train_ids
19
- # pp test_ids
20
- # end
21
- #
22
- # # ---
23
- # # [0, 1, 2, 3, 4]
24
- # # [5, 6, 7]
25
- # # ---
26
- # # [3, 4, 5, 6, 7]
27
- # # [0, 1, 2]
28
- # # ---
29
- # # [0, 1, 2, 5, 6, 7]
30
- # # [3, 4]
31
- #
32
- class GroupKFold
33
- include Base::Splitter
34
-
35
- # Return the number of folds.
36
- # @return [Integer]
37
- attr_reader :n_splits
38
-
39
- # Create a new data splitter for grouped K-fold cross validation.
40
- #
41
- # @param n_splits [Integer] The number of folds.
42
- def initialize(n_splits: 5)
43
- check_params_numeric(n_splits: n_splits)
44
- @n_splits = n_splits
45
- end
46
-
47
- # Generate data indices for grouped K-fold cross validation.
48
- #
49
- # @overload split(x, y, groups) -> Array
50
- # @param x [Numo::DFloat] (shape: [n_samples, n_features])
51
- # The dataset to be used to generate data indices for grouped K-fold cross validation.
52
- # @param y [Numo::Int32] (shape: [n_samples])
53
- # This argument exists to unify the interface between the K-fold methods, it is not used in the method.
54
- # @param groups [Numo::Int32] (shape: [n_samples])
55
- # The group labels to be used to generate data indices for grouped K-fold cross validation.
56
- # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
57
- def split(x, _y, groups)
58
- x = check_convert_sample_array(x)
59
- groups = check_convert_label_array(groups)
60
- check_sample_label_size(x, groups)
61
-
62
- encoder = Rumale::Preprocessing::LabelEncoder.new
63
- groups = encoder.fit_transform(groups)
64
- n_groups = encoder.classes.size
65
-
66
- raise ArgumentError, 'The number of groups should be greater than or equal to the number of splits.' if n_groups < @n_splits
67
-
68
- n_samples_per_group = groups.bincount
69
- group_ids = n_samples_per_group.sort_index.reverse
70
- n_samples_per_group = n_samples_per_group[group_ids]
71
-
72
- n_samples_per_fold = Numo::Int32.zeros(@n_splits)
73
- group_to_fold = Numo::Int32.zeros(n_groups)
74
-
75
- n_samples_per_group.each_with_index do |weight, id|
76
- min_sample_fold_id = n_samples_per_fold.min_index
77
- n_samples_per_fold[min_sample_fold_id] += weight
78
- group_to_fold[group_ids[id]] = min_sample_fold_id
79
- end
80
-
81
- n_samples = x.shape[0]
82
- sample_ids = Array(0...n_samples)
83
- fold_ids = group_to_fold[groups]
84
-
85
- Array.new(@n_splits) do |fid|
86
- test_ids = fold_ids.eq(fid).where.to_a
87
- train_ids = sample_ids - test_ids
88
- [train_ids, test_ids]
89
- end
90
- end
91
- end
92
- end
93
- end