svmkit 0.7.3 → 0.8.1

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (78)
  1. checksums.yaml +4 -4
  2. data/.gitignore +0 -9
  3. data/.rspec +1 -0
  4. data/.travis.yml +4 -12
  5. data/LICENSE.txt +1 -1
  6. data/README.md +11 -13
  7. data/lib/svmkit.rb +3 -66
  8. data/svmkit.gemspec +12 -7
  9. metadata +16 -81
  10. data/.coveralls.yml +0 -1
  11. data/.rubocop.yml +0 -47
  12. data/.rubocop_todo.yml +0 -58
  13. data/HISTORY.md +0 -168
  14. data/lib/svmkit/base/base_estimator.rb +0 -13
  15. data/lib/svmkit/base/classifier.rb +0 -34
  16. data/lib/svmkit/base/cluster_analyzer.rb +0 -29
  17. data/lib/svmkit/base/evaluator.rb +0 -13
  18. data/lib/svmkit/base/regressor.rb +0 -34
  19. data/lib/svmkit/base/splitter.rb +0 -17
  20. data/lib/svmkit/base/transformer.rb +0 -18
  21. data/lib/svmkit/clustering/dbscan.rb +0 -127
  22. data/lib/svmkit/clustering/k_means.rb +0 -140
  23. data/lib/svmkit/dataset.rb +0 -109
  24. data/lib/svmkit/decomposition/nmf.rb +0 -147
  25. data/lib/svmkit/decomposition/pca.rb +0 -150
  26. data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
  27. data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
  28. data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
  29. data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
  30. data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
  31. data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
  32. data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
  33. data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
  34. data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
  35. data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
  36. data/lib/svmkit/evaluation_measure/precision.rb +0 -51
  37. data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
  38. data/lib/svmkit/evaluation_measure/purity.rb +0 -41
  39. data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
  40. data/lib/svmkit/evaluation_measure/recall.rb +0 -51
  41. data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
  42. data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
  43. data/lib/svmkit/linear_model/lasso.rb +0 -138
  44. data/lib/svmkit/linear_model/linear_regression.rb +0 -112
  45. data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
  46. data/lib/svmkit/linear_model/ridge.rb +0 -112
  47. data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
  48. data/lib/svmkit/linear_model/svc.rb +0 -184
  49. data/lib/svmkit/linear_model/svr.rb +0 -123
  50. data/lib/svmkit/model_selection/cross_validation.rb +0 -121
  51. data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
  52. data/lib/svmkit/model_selection/k_fold.rb +0 -77
  53. data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
  54. data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
  55. data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
  56. data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
  57. data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
  58. data/lib/svmkit/optimizer/nadam.rb +0 -90
  59. data/lib/svmkit/optimizer/rmsprop.rb +0 -69
  60. data/lib/svmkit/optimizer/sgd.rb +0 -65
  61. data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
  62. data/lib/svmkit/pairwise_metric.rb +0 -91
  63. data/lib/svmkit/pipeline/pipeline.rb +0 -197
  64. data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
  65. data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
  66. data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
  67. data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
  68. data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
  69. data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
  70. data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
  71. data/lib/svmkit/probabilistic_output.rb +0 -112
  72. data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
  73. data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
  74. data/lib/svmkit/tree/node.rb +0 -70
  75. data/lib/svmkit/utils.rb +0 -22
  76. data/lib/svmkit/validation.rb +0 -79
  77. data/lib/svmkit/values.rb +0 -13
  78. data/lib/svmkit/version.rb +0 -7
data/lib/svmkit/preprocessing/standard_scaler.rb
@@ -1,87 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/base/base_estimator'
- require 'svmkit/base/transformer'
-
- module SVMKit
-   # This module consists of the classes that perform preprocessing.
-   module Preprocessing
-     # Normalize samples by centering and scaling to unit variance.
-     #
-     # @example
-     #   normalizer = SVMKit::Preprocessing::StandardScaler.new
-     #   new_training_samples = normalizer.fit_transform(training_samples)
-     #   new_testing_samples = normalizer.transform(testing_samples)
-     class StandardScaler
-       include Base::BaseEstimator
-       include Base::Transformer
-
-       # Return the vector consisting of the mean values for each feature.
-       # @return [Numo::DFloat] (shape: [n_features])
-       attr_reader :mean_vec
-
-       # Return the vector consisting of the standard deviations for each feature.
-       # @return [Numo::DFloat] (shape: [n_features])
-       attr_reader :std_vec
-
-       # Create a new normalizer for centering and scaling to unit variance.
-       def initialize
-         @params = {}
-         @mean_vec = nil
-         @std_vec = nil
-       end
-
-       # Calculate the mean value and standard deviation of each feature for scaling.
-       #
-       # @overload fit(x) -> StandardScaler
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features])
-       #   The samples to calculate the mean values and standard deviations.
-       # @return [StandardScaler]
-       def fit(x, _y = nil)
-         SVMKit::Validation.check_sample_array(x)
-         @mean_vec = x.mean(0)
-         @std_vec = x.stddev(0)
-         self
-       end
-
-       # Calculate the mean values and standard deviations, and then normalize samples using them.
-       #
-       # @overload fit_transform(x) -> Numo::DFloat
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features])
-       #   The samples to calculate the mean values and standard deviations.
-       # @return [Numo::DFloat] The scaled samples.
-       def fit_transform(x, _y = nil)
-         SVMKit::Validation.check_sample_array(x)
-         fit(x).transform(x)
-       end
-
-       # Perform standardization of the given samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be scaled.
-       # @return [Numo::DFloat] The scaled samples.
-       def transform(x)
-         SVMKit::Validation.check_sample_array(x)
-         n_samples, = x.shape
-         (x - @mean_vec.tile(n_samples, 1)) / @std_vec.tile(n_samples, 1)
-       end
-
-       # Dump marshal data.
-       # @return [Hash] The marshal data about StandardScaler.
-       def marshal_dump
-         { mean_vec: @mean_vec,
-           std_vec: @std_vec }
-       end
-
-       # Load marshal data.
-       # @return [nil]
-       def marshal_load(obj)
-         @mean_vec = obj[:mean_vec]
-         @std_vec = obj[:std_vec]
-         nil
-       end
-     end
-   end
- end
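
For context: the removed StandardScaler boils down to the per-feature z-score transform z = (x - mean) / std. The following is a minimal, self-contained sketch of that computation (not part of the diff), assuming only the numo-narray gem; Numo's broadcasting makes the explicit tile calls above unnecessary.

    require 'numo/narray'

    # Toy data: 3 samples x 2 features.
    x = Numo::DFloat[[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]

    mean = x.mean(0)    # per-feature means                      => [2.0, 20.0]
    std  = x.stddev(0)  # per-feature sample standard deviations => [1.0, 10.0]

    # Broadcasting stretches the [n_features] vectors over the rows.
    z = (x - mean) / std
    puts z.inspect
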
data/lib/svmkit/probabilistic_output.rb
@@ -1,112 +0,0 @@
- # frozen_string_literal: true
-
- module SVMKit
-   # Module for calculating posterior class probabilities with SVM outputs.
-   # This module is used for internal processes.
-   #
-   # @example
-   #   estimator = SVMKit::LinearModel::SVC.new
-   #   estimator.fit(x, bin_y)
-   #   df = estimator.decision_function(x)
-   #   params = SVMKit::ProbabilisticOutput.fit_sigmoid(df, bin_y)
-   #   probs = 1 / (Numo::NMath.exp(params[0] * df + params[1]) + 1)
-   #
-   # *Reference*
-   # 1. J C. Platt, "Probabilistic Outputs for Support Vector Machines and Comparisons to Regularized Likelihood Methods," Adv. Large Margin Classifiers, pp. 61--74, 2000.
-   # 1. H-T Lin, C-J Lin, and R C. Weng, "A Note on Platt's Probabilistic Outputs for Support Vector Machines," J. Machine Learning, Vol. 63 (3), pp. 267--276, 2007.
-   module ProbabilisticOutput
-     class << self
-       # Fit the probabilistic model for binary SVM outputs.
-       #
-       # @param df [Numo::DFloat] (shape: [n_samples]) The outputs of the decision function to be used for fitting the model.
-       # @param bin_y [Numo::Int32] (shape: [n_samples]) The binary labels to be used for fitting the model.
-       # @param max_iter [Integer] The maximum number of iterations.
-       # @param min_step [Float] The minimum step of Newton's method.
-       # @param sigma [Float] The parameter that keeps the Hessian matrix from becoming singular.
-       # @return [Numo::DFloat] (shape: 2) The parameters of the model.
-       def fit_sigmoid(df, bin_y, max_iter = 100, min_step = 1e-10, sigma = 1e-12)
-         # Initialize some variables.
-         n_samples = bin_y.size
-         negative_label = bin_y.to_a.uniq.min
-         pos = bin_y.ne(negative_label)
-         neg = bin_y.eq(negative_label)
-         n_pos_samples = pos.count
-         n_neg_samples = neg.count
-         target_probs = Numo::DFloat.zeros(n_samples)
-         target_probs[pos] = (n_pos_samples + 1) / (n_pos_samples + 2.0)
-         target_probs[neg] = 1 / (n_neg_samples + 2.0)
-         alpha = 0.0
-         beta = Math.log((n_neg_samples + 1) / (n_pos_samples + 1.0))
-         err = error_function(target_probs, df, alpha, beta)
-         # Optimize parameters for class probability calculation.
-         old_grad_vec = Numo::DFloat.zeros(2)
-         max_iter.times do
-           # Calculate gradient and hessian matrix.
-           probs = predicted_probs(df, alpha, beta)
-           grad_vec = gradient(target_probs, probs, df)
-           hess_mat = hessian_matrix(probs, df, sigma)
-           break if grad_vec.abs.lt(1e-5).count == 2
-           break if (old_grad_vec - grad_vec).abs.sum < 1e-5
-           old_grad_vec = grad_vec
-           # Calculate Newton directions.
-           dirs_vec = directions(grad_vec, hess_mat)
-           grad_dir = grad_vec.dot(dirs_vec)
-           stepsize = 2.0
-           while stepsize >= min_step
-             stepsize *= 0.5
-             new_alpha = alpha + stepsize * dirs_vec[0]
-             new_beta = beta + stepsize * dirs_vec[1]
-             new_err = error_function(target_probs, df, new_alpha, new_beta)
-             next unless new_err < err + 0.0001 * stepsize * grad_dir
-             alpha = new_alpha
-             beta = new_beta
-             err = new_err
-             break
-           end
-         end
-         Numo::DFloat[alpha, beta]
-       end
-
-       private
-
-       def error_function(target_probs, df, alpha, beta)
-         fn = alpha * df + beta
-         pos = fn.ge(0.0)
-         neg = fn.lt(0.0)
-         err = 0.0
-         err += (target_probs[pos] * fn[pos] + Numo::NMath.log(1 + Numo::NMath.exp(-fn[pos]))).sum if pos.count > 0
-         err += ((target_probs[neg] - 1) * fn[neg] + Numo::NMath.log(1 + Numo::NMath.exp(fn[neg]))).sum if neg.count > 0
-         err
-       end
-
-       def predicted_probs(df, alpha, beta)
-         fn = alpha * df + beta
-         pos = fn.ge(0.0)
-         neg = fn.lt(0.0)
-         probs = Numo::DFloat.zeros(df.shape[0])
-         probs[pos] = Numo::NMath.exp(-fn[pos]) / (1 + Numo::NMath.exp(-fn[pos])) if pos.count > 0
-         probs[neg] = 1 / (1 + Numo::NMath.exp(fn[neg])) if neg.count > 0
-         probs
-       end
-
-       def gradient(target_probs, probs, df)
-         sub = target_probs - probs
-         Numo::DFloat[(df * sub).sum, sub.sum]
-       end
-
-       def hessian_matrix(probs, df, sigma)
-         sub = probs * (1 - probs)
-         h11 = (df * df * sub).sum + sigma
-         h22 = sub.sum + sigma
-         h21 = (df * sub).sum
-         Numo::DFloat[[h11, h21], [h21, h22]]
-       end
-
-       def directions(grad_vec, hess_mat)
-         det = hess_mat[0, 0] * hess_mat[1, 1] - hess_mat[0, 1] * hess_mat[1, 0]
-         inv_hess_mat = Numo::DFloat[[hess_mat[1, 1], -hess_mat[0, 1]], [-hess_mat[1, 0], hess_mat[0, 0]]] / det
-         -inv_hess_mat.dot(grad_vec)
-       end
-     end
-   end
- end
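
The removed ProbabilisticOutput module is an implementation of Platt scaling: Newton's method fits parameters (alpha, beta) so that P(y = +1 | f) = 1 / (1 + exp(alpha * f + beta)) for decision values f. Below is a minimal sketch of applying already-fitted parameters, mirroring the @example above; the decision values and parameter values are made up for illustration, not taken from any real fit.

    require 'numo/narray'

    # Hypothetical decision-function outputs and Platt parameters,
    # in the form fit_sigmoid returns them (alpha, beta).
    df    = Numo::DFloat[-2.0, -0.5, 0.0, 0.5, 2.0]
    alpha = -1.3
    beta  = 0.1

    # Platt scaling: P(y = +1 | f) = 1 / (1 + exp(alpha * f + beta)).
    probs = 1 / (Numo::NMath.exp(alpha * df + beta) + 1)
    puts probs.inspect
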
data/lib/svmkit/tree/decision_tree_classifier.rb
@@ -1,276 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/base/base_estimator'
- require 'svmkit/base/classifier'
- require 'svmkit/tree/node'
-
- module SVMKit
-   # This module consists of the classes that implement tree models.
-   module Tree
-     # DecisionTreeClassifier is a class that implements a decision tree for classification.
-     #
-     # @example
-     #   estimator =
-     #     SVMKit::Tree::DecisionTreeClassifier.new(
-     #       criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
-     #   estimator.fit(training_samples, training_labels)
-     #   results = estimator.predict(testing_samples)
-     #
-     class DecisionTreeClassifier
-       include Base::BaseEstimator
-       include Base::Classifier
-
-       # Return the class labels.
-       # @return [Numo::Int32] (size: n_classes)
-       attr_reader :classes
-
-       # Return the importance for each feature.
-       # @return [Numo::DFloat] (size: n_features)
-       attr_reader :feature_importances
-
-       # Return the learned tree.
-       # @return [Node]
-       attr_reader :tree
-
-       # Return the random generator for random selection of feature index.
-       # @return [Random]
-       attr_reader :rng
-
-       # Return the labels assigned to each leaf.
-       # @return [Numo::Int32] (size: n_leaves)
-       attr_reader :leaf_labels
-
-       # Create a new classifier with the decision tree algorithm.
-       #
-       # @param criterion [String] The function used to evaluate a splitting point. Supported criteria are 'gini' and 'entropy'.
-       # @param max_depth [Integer] The maximum depth of the tree.
-       #   If nil is given, the decision tree grows without concern for depth.
-       # @param max_leaf_nodes [Integer] The maximum number of leaves on the decision tree.
-       #   If nil is given, the number of leaves is not limited.
-       # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
-       # @param max_features [Integer] The number of features to consider when searching for the optimal split point.
-       #   If nil is given, the split process considers all features.
-       # @param random_seed [Integer] The seed value used to initialize the random generator.
-       #   It is used to randomly determine the order of features when deciding the splitting point.
-       def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
-                      random_seed: nil)
-         SVMKit::Validation.check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
-                                                              max_features: max_features, random_seed: random_seed)
-         SVMKit::Validation.check_params_integer(min_samples_leaf: min_samples_leaf)
-         SVMKit::Validation.check_params_string(criterion: criterion)
-         SVMKit::Validation.check_params_positive(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
-                                                  min_samples_leaf: min_samples_leaf, max_features: max_features)
-         @params = {}
-         @params[:criterion] = criterion
-         @params[:max_depth] = max_depth
-         @params[:max_leaf_nodes] = max_leaf_nodes
-         @params[:min_samples_leaf] = min_samples_leaf
-         @params[:max_features] = max_features
-         @params[:random_seed] = random_seed
-         @params[:random_seed] ||= srand
-         @criterion = :gini
-         @criterion = :entropy if @params[:criterion] == 'entropy'
-         @tree = nil
-         @classes = nil
-         @feature_importances = nil
-         @n_leaves = nil
-         @leaf_labels = nil
-         @rng = Random.new(@params[:random_seed])
-       end
-
-       # Fit the model with given training data.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
-       # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
-       # @return [DecisionTreeClassifier] The learned classifier itself.
-       def fit(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_label_array(y)
-         SVMKit::Validation.check_sample_label_size(x, y)
-         n_samples, n_features = x.shape
-         @params[:max_features] = n_features if @params[:max_features].nil?
-         @params[:max_features] = [@params[:max_features], n_features].min
-         uniq_y = y.to_a.uniq.sort
-         @classes = Numo::Int32.asarray(uniq_y)
-         build_tree(x, y.map { |v| uniq_y.index(v) })
-         eval_importance(n_samples, n_features)
-         self
-       end
-
-       # Predict class labels for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
-       # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
-       def predict(x)
-         SVMKit::Validation.check_sample_array(x)
-         @leaf_labels[apply(x)]
-       end
-
-       # Predict probability for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
-       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
-       def predict_proba(x)
-         SVMKit::Validation.check_sample_array(x)
-         Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_at_node(@tree, x[n, true]) })]
-       end
-
-       # Return the index of the leaf that each sample reached.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
-       # @return [Numo::Int32] (shape: [n_samples]) Leaf index for each sample.
-       def apply(x)
-         SVMKit::Validation.check_sample_array(x)
-         Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })]
-       end
-
-       # Dump marshal data.
-       # @return [Hash] The marshal data about DecisionTreeClassifier.
-       def marshal_dump
-         { params: @params,
-           classes: @classes,
-           criterion: @criterion,
-           tree: @tree,
-           feature_importances: @feature_importances,
-           leaf_labels: @leaf_labels,
-           rng: @rng }
-       end
-
-       # Load marshal data.
-       # @return [nil]
-       def marshal_load(obj)
-         @params = obj[:params]
-         @classes = obj[:classes]
-         @criterion = obj[:criterion]
-         @tree = obj[:tree]
-         @feature_importances = obj[:feature_importances]
-         @leaf_labels = obj[:leaf_labels]
-         @rng = obj[:rng]
-         nil
-       end
-
-       private
-
-       def predict_at_node(node, sample)
-         return node.probs if node.leaf
-         branch_at_node('predict', node, sample)
-       end
-
-       def apply_at_node(node, sample)
-         return node.leaf_id if node.leaf
-         branch_at_node('apply', node, sample)
-       end
-
-       def branch_at_node(action, node, sample)
-         return send("#{action}_at_node", node.left, sample) if node.right.nil?
-         return send("#{action}_at_node", node.right, sample) if node.left.nil?
-         if sample[node.feature_id] <= node.threshold
-           send("#{action}_at_node", node.left, sample)
-         else
-           send("#{action}_at_node", node.right, sample)
-         end
-       end
-
-       def build_tree(x, y)
-         @n_leaves = 0
-         @leaf_labels = []
-         @tree = grow_node(0, x, y, impurity(y))
-         @leaf_labels = Numo::Int32[*@leaf_labels]
-         nil
-       end
-
-       def grow_node(depth, x, y, whole_impurity)
-         unless @params[:max_leaf_nodes].nil?
-           return nil if @n_leaves >= @params[:max_leaf_nodes]
-         end
-
-         n_samples, n_features = x.shape
-         return nil if n_samples <= @params[:min_samples_leaf]
-
-         node = Node.new(depth: depth, impurity: whole_impurity, n_samples: n_samples)
-
-         return put_leaf(node, y) if y.to_a.uniq.size == 1
-
-         unless @params[:max_depth].nil?
-           return put_leaf(node, y) if depth == @params[:max_depth]
-         end
-
-         feature_id, threshold, left_ids, right_ids, left_impurity, right_impurity, gain =
-           rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y, whole_impurity)] }.max_by(&:last)
-
-         return put_leaf(node, y) if gain.nil? || gain.zero?
-
-         node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids], left_impurity)
-         node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids], right_impurity)
-
-         return put_leaf(node, y) if node.left.nil? && node.right.nil?
-
-         node.feature_id = feature_id
-         node.threshold = threshold
-         node.leaf = false
-         node
-       end
-
-       def put_leaf(node, y)
-         node.probs = y.bincount(minlength: @classes.size) / node.n_samples.to_f
-         node.leaf = true
-         node.leaf_id = @n_leaves
-         @n_leaves += 1
-         @leaf_labels.push(@classes[node.probs.max_index])
-         node
-       end
-
-       def rand_ids(n)
-         [*0...n].sample(@params[:max_features], random: @rng)
-       end
-
-       def best_split(features, labels, whole_impurity)
-         n_samples = labels.size
-         features.to_a.uniq.sort.each_cons(2).map do |l, r|
-           threshold = 0.5 * (l + r)
-           left_ids = features.le(threshold).where
-           right_ids = features.gt(threshold).where
-           left_impurity = impurity(labels[left_ids])
-           right_impurity = impurity(labels[right_ids])
-           gain = whole_impurity -
-                  left_impurity * left_ids.size.fdiv(n_samples) -
-                  right_impurity * right_ids.size.fdiv(n_samples)
-           [threshold, left_ids, right_ids, left_impurity, right_impurity, gain]
-         end.max_by(&:last)
-       end
-
-       def impurity(labels)
-         send(@criterion, labels.bincount / labels.size.to_f)
-       end
-
-       def gini(posterior_probs)
-         1.0 - (posterior_probs * posterior_probs).sum
-       end
-
-       def entropy(posterior_probs)
-         -(posterior_probs * Numo::NMath.log(posterior_probs + 1)).sum
-       end
-
-       def eval_importance(n_samples, n_features)
-         @feature_importances = Numo::DFloat.zeros(n_features)
-         eval_importance_at_node(@tree)
-         @feature_importances /= n_samples
-         normalizer = @feature_importances.sum
-         @feature_importances /= normalizer if normalizer > 0.0
-         nil
-       end
-
-       def eval_importance_at_node(node)
-         return nil if node.leaf
-         return nil if node.left.nil? || node.right.nil?
-         gain = node.n_samples * node.impurity -
-                node.left.n_samples * node.left.impurity -
-                node.right.n_samples * node.right.impurity
-         @feature_importances[node.feature_id] += gain
-         eval_importance_at_node(node.left)
-         eval_importance_at_node(node.right)
-       end
-     end
-   end
- end
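
For context: grow_node above ranks candidate thresholds per feature by impurity gain, where impurity is either Gini or an entropy-like measure. Below is a minimal sketch of the two criteria on a toy label vector, assuming only numo-narray. Note that the removed #entropy uses log(p + 1) rather than the textbook log(p), so the epsilon-guarded Shannon entropy shown here is an illustrative substitute, not the library's exact behavior.

    require 'numo/narray'

    # Toy node: 8 samples of class 0 and 2 of class 1.
    labels = Numo::Int32[0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
    probs  = labels.bincount / labels.size.to_f   # => [0.8, 0.2]

    # Gini impurity, as in the removed #gini: 1 - sum(p_k^2).
    gini = 1.0 - (probs * probs).sum              # => 0.32

    # Shannon entropy, with a small epsilon guarding log(0).
    entropy = -(probs * Numo::NMath.log(probs + 1e-12)).sum

    puts format('gini: %.3f, entropy: %.3f', gini, entropy)
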