svmkit 0.7.3 → 0.8.1

Files changed (78)
  1. checksums.yaml +4 -4
  2. data/.gitignore +0 -9
  3. data/.rspec +1 -0
  4. data/.travis.yml +4 -12
  5. data/LICENSE.txt +1 -1
  6. data/README.md +11 -13
  7. data/lib/svmkit.rb +3 -66
  8. data/svmkit.gemspec +12 -7
  9. metadata +16 -81
  10. data/.coveralls.yml +0 -1
  11. data/.rubocop.yml +0 -47
  12. data/.rubocop_todo.yml +0 -58
  13. data/HISTORY.md +0 -168
  14. data/lib/svmkit/base/base_estimator.rb +0 -13
  15. data/lib/svmkit/base/classifier.rb +0 -34
  16. data/lib/svmkit/base/cluster_analyzer.rb +0 -29
  17. data/lib/svmkit/base/evaluator.rb +0 -13
  18. data/lib/svmkit/base/regressor.rb +0 -34
  19. data/lib/svmkit/base/splitter.rb +0 -17
  20. data/lib/svmkit/base/transformer.rb +0 -18
  21. data/lib/svmkit/clustering/dbscan.rb +0 -127
  22. data/lib/svmkit/clustering/k_means.rb +0 -140
  23. data/lib/svmkit/dataset.rb +0 -109
  24. data/lib/svmkit/decomposition/nmf.rb +0 -147
  25. data/lib/svmkit/decomposition/pca.rb +0 -150
  26. data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
  27. data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
  28. data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
  29. data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
  30. data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
  31. data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
  32. data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
  33. data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
  34. data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
  35. data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
  36. data/lib/svmkit/evaluation_measure/precision.rb +0 -51
  37. data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
  38. data/lib/svmkit/evaluation_measure/purity.rb +0 -41
  39. data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
  40. data/lib/svmkit/evaluation_measure/recall.rb +0 -51
  41. data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
  42. data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
  43. data/lib/svmkit/linear_model/lasso.rb +0 -138
  44. data/lib/svmkit/linear_model/linear_regression.rb +0 -112
  45. data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
  46. data/lib/svmkit/linear_model/ridge.rb +0 -112
  47. data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
  48. data/lib/svmkit/linear_model/svc.rb +0 -184
  49. data/lib/svmkit/linear_model/svr.rb +0 -123
  50. data/lib/svmkit/model_selection/cross_validation.rb +0 -121
  51. data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
  52. data/lib/svmkit/model_selection/k_fold.rb +0 -77
  53. data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
  54. data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
  55. data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
  56. data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
  57. data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
  58. data/lib/svmkit/optimizer/nadam.rb +0 -90
  59. data/lib/svmkit/optimizer/rmsprop.rb +0 -69
  60. data/lib/svmkit/optimizer/sgd.rb +0 -65
  61. data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
  62. data/lib/svmkit/pairwise_metric.rb +0 -91
  63. data/lib/svmkit/pipeline/pipeline.rb +0 -197
  64. data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
  65. data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
  66. data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
  67. data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
  68. data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
  69. data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
  70. data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
  71. data/lib/svmkit/probabilistic_output.rb +0 -112
  72. data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
  73. data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
  74. data/lib/svmkit/tree/node.rb +0 -70
  75. data/lib/svmkit/utils.rb +0 -22
  76. data/lib/svmkit/validation.rb +0 -79
  77. data/lib/svmkit/values.rb +0 -13
  78. data/lib/svmkit/version.rb +0 -7
data/lib/svmkit/preprocessing/standard_scaler.rb
@@ -1,87 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/base/base_estimator'
- require 'svmkit/base/transformer'
-
- module SVMKit
-   # This module consists of the classes that perform preprocessing.
-   module Preprocessing
-     # Normalize samples by centering and scaling to unit variance.
-     #
-     # @example
-     #   normalizer = SVMKit::Preprocessing::StandardScaler.new
-     #   new_training_samples = normalizer.fit_transform(training_samples)
-     #   new_testing_samples = normalizer.transform(testing_samples)
-     class StandardScaler
-       include Base::BaseEstimator
-       include Base::Transformer
-
-       # Return the vector consisting of the mean value of each feature.
-       # @return [Numo::DFloat] (shape: [n_features])
-       attr_reader :mean_vec
-
-       # Return the vector consisting of the standard deviation of each feature.
-       # @return [Numo::DFloat] (shape: [n_features])
-       attr_reader :std_vec
-
-       # Create a new normalizer for centering and scaling to unit variance.
-       def initialize
-         @params = {}
-         @mean_vec = nil
-         @std_vec = nil
-       end
-
-       # Calculate the mean value and standard deviation of each feature for scaling.
-       #
-       # @overload fit(x) -> StandardScaler
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features])
-       #   The samples to calculate the mean values and standard deviations.
-       # @return [StandardScaler]
-       def fit(x, _y = nil)
-         SVMKit::Validation.check_sample_array(x)
-         @mean_vec = x.mean(0)
-         @std_vec = x.stddev(0)
-         self
-       end
-
-       # Calculate the mean values and standard deviations, and then normalize samples using them.
-       #
-       # @overload fit_transform(x) -> Numo::DFloat
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features])
-       #   The samples to calculate the mean values and standard deviations.
-       # @return [Numo::DFloat] The scaled samples.
-       def fit_transform(x, _y = nil)
-         SVMKit::Validation.check_sample_array(x)
-         fit(x).transform(x)
-       end
-
-       # Perform standardization of the given samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be scaled.
-       # @return [Numo::DFloat] The scaled samples.
-       def transform(x)
-         SVMKit::Validation.check_sample_array(x)
-         n_samples, = x.shape
-         (x - @mean_vec.tile(n_samples, 1)) / @std_vec.tile(n_samples, 1)
-       end
-
-       # Dump marshal data.
-       # @return [Hash] The marshal data about StandardScaler.
-       def marshal_dump
-         { mean_vec: @mean_vec,
-           std_vec: @std_vec }
-       end
-
-       # Load marshal data.
-       # @return [nil]
-       def marshal_load(obj)
-         @mean_vec = obj[:mean_vec]
-         @std_vec = obj[:std_vec]
-         nil
-       end
-     end
-   end
- end
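For reference, the removed StandardScaler boils down to per-feature centering and scaling: z = (x - mean) / std. Below is a minimal sketch of the same computation using plain numo-narray; the toy matrix and variable names are illustrative, not from the gem. Broadcasting replaces the tile-based expansion used in #transform above.

```ruby
require 'numo/narray'

# Toy data: 4 samples x 2 features.
x = Numo::DFloat[[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]

mean_vec = x.mean(0)   # per-feature mean, as in StandardScaler#fit
std_vec  = x.stddev(0) # per-feature standard deviation

# Broadcasting achieves what the tile-based expression in #transform does.
z = (x - mean_vec) / std_vec

p z.mean(0)   # ~ [0, 0]
p z.stddev(0) # ~ [1, 1]
```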
data/lib/svmkit/probabilistic_output.rb
@@ -1,112 +0,0 @@
- # frozen_string_literal: true
-
- module SVMKit
-   # Module for calculating posterior class probabilities with SVM outputs.
-   # This module is used for internal processes.
-   #
-   # @example
-   #   estimator = SVMKit::LinearModel::SVC.new
-   #   estimator.fit(x, bin_y)
-   #   df = estimator.decision_function(x)
-   #   params = SVMKit::ProbabilisticOutput.fit_sigmoid(df, bin_y)
-   #   probs = 1 / (Numo::NMath.exp(params[0] * df + params[1]) + 1)
-   #
-   # *Reference*
-   # 1. J C. Platt, "Probabilistic Outputs for Support Vector Machines and Comparisons to Regularized Likelihood Methods," Adv. Large Margin Classifiers, pp. 61--74, 2000.
-   # 1. H-T Lin, C-J Lin, and R C. Weng, "A Note on Platt's Probabilistic Outputs for Support Vector Machines," J. Machine Learning, Vol. 63 (3), pp. 267--276, 2007.
-   module ProbabilisticOutput
-     class << self
-       # Fit the probabilistic model for binary SVM outputs.
-       #
-       # @param df [Numo::DFloat] (shape: [n_samples]) The outputs of the decision function to be used for fitting the model.
-       # @param bin_y [Numo::Int32] (shape: [n_samples]) The binary labels to be used for fitting the model.
-       # @param max_iter [Integer] The maximum number of iterations.
-       # @param min_step [Float] The minimum step of Newton's method.
-       # @param sigma [Float] The parameter to prevent the Hessian matrix from becoming singular.
-       # @return [Numo::DFloat] (shape: [2]) The parameters of the model.
-       def fit_sigmoid(df, bin_y, max_iter = 100, min_step = 1e-10, sigma = 1e-12)
-         # Initialize some variables.
-         n_samples = bin_y.size
-         negative_label = bin_y.to_a.uniq.min
-         pos = bin_y.ne(negative_label)
-         neg = bin_y.eq(negative_label)
-         n_pos_samples = pos.count
-         n_neg_samples = neg.count
-         target_probs = Numo::DFloat.zeros(n_samples)
-         target_probs[pos] = (n_pos_samples + 1) / (n_pos_samples + 2.0)
-         target_probs[neg] = 1 / (n_neg_samples + 2.0)
-         alpha = 0.0
-         beta = Math.log((n_neg_samples + 1) / (n_pos_samples + 1.0))
-         err = error_function(target_probs, df, alpha, beta)
-         # Optimize the parameters for class probability calculation.
-         old_grad_vec = Numo::DFloat.zeros(2)
-         max_iter.times do
-           # Calculate the gradient and Hessian matrix.
-           probs = predicted_probs(df, alpha, beta)
-           grad_vec = gradient(target_probs, probs, df)
-           hess_mat = hessian_matrix(probs, df, sigma)
-           break if grad_vec.abs.lt(1e-5).count == 2
-           break if (old_grad_vec - grad_vec).abs.sum < 1e-5
-           old_grad_vec = grad_vec
-           # Calculate the Newton direction.
-           dirs_vec = directions(grad_vec, hess_mat)
-           grad_dir = grad_vec.dot(dirs_vec)
-           stepsize = 2.0
-           while stepsize >= min_step
-             stepsize *= 0.5
-             new_alpha = alpha + stepsize * dirs_vec[0]
-             new_beta = beta + stepsize * dirs_vec[1]
-             new_err = error_function(target_probs, df, new_alpha, new_beta)
-             next unless new_err < err + 0.0001 * stepsize * grad_dir
-             alpha = new_alpha
-             beta = new_beta
-             err = new_err
-             break
-           end
-         end
-         Numo::DFloat[alpha, beta]
-       end
-
-       private
-
-       def error_function(target_probs, df, alpha, beta)
-         fn = alpha * df + beta
-         pos = fn.ge(0.0)
-         neg = fn.lt(0.0)
-         err = 0.0
-         err += (target_probs[pos] * fn[pos] + Numo::NMath.log(1 + Numo::NMath.exp(-fn[pos]))).sum if pos.count > 0
-         err += ((target_probs[neg] - 1) * fn[neg] + Numo::NMath.log(1 + Numo::NMath.exp(fn[neg]))).sum if neg.count > 0
-         err
-       end
-
-       def predicted_probs(df, alpha, beta)
-         fn = alpha * df + beta
-         pos = fn.ge(0.0)
-         neg = fn.lt(0.0)
-         probs = Numo::DFloat.zeros(df.shape[0])
-         probs[pos] = Numo::NMath.exp(-fn[pos]) / (1 + Numo::NMath.exp(-fn[pos])) if pos.count > 0
-         probs[neg] = 1 / (1 + Numo::NMath.exp(fn[neg])) if neg.count > 0
-         probs
-       end
-
-       def gradient(target_probs, probs, df)
-         sub = target_probs - probs
-         Numo::DFloat[(df * sub).sum, sub.sum]
-       end
-
-       def hessian_matrix(probs, df, sigma)
-         sub = probs * (1 - probs)
-         h11 = (df * df * sub).sum + sigma
-         h22 = sub.sum + sigma
-         h21 = (df * sub).sum
-         Numo::DFloat[[h11, h21], [h21, h22]]
-       end
-
-       def directions(grad_vec, hess_mat)
-         det = hess_mat[0, 0] * hess_mat[1, 1] - hess_mat[0, 1] * hess_mat[1, 0]
-         inv_hess_mat = Numo::DFloat[[hess_mat[1, 1], -hess_mat[0, 1]], [-hess_mat[1, 0], hess_mat[0, 0]]] / det
-         -inv_hess_mat.dot(grad_vec)
-       end
-     end
-   end
- end
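The @example in the module docstring compresses the whole workflow: fit_sigmoid returns the pair [alpha, beta] of Platt's sigmoid, and posterior probabilities follow from P(y = +1 | f) = 1 / (1 + exp(alpha * f + beta)). Here is a small sketch of that final step with made-up parameter values; real values come from the Newton iteration above.

```ruby
require 'numo/narray'

# Hypothetical fitted parameters; in practice these are returned by
# SVMKit::ProbabilisticOutput.fit_sigmoid(df, bin_y).
alpha = -1.5
beta  = 0.2

# Decision-function outputs for three samples.
df = Numo::DFloat[-2.0, 0.0, 3.0]

# Platt's sigmoid: P(y = +1 | df) = 1 / (1 + exp(alpha * df + beta))
probs = 1 / (Numo::NMath.exp(alpha * df + beta) + 1)
p probs # increases with df because alpha is negative
```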
data/lib/svmkit/tree/decision_tree_classifier.rb
@@ -1,276 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/base/base_estimator'
- require 'svmkit/base/classifier'
- require 'svmkit/tree/node'
-
- module SVMKit
-   # This module consists of the classes that implement tree models.
-   module Tree
-     # DecisionTreeClassifier is a class that implements a decision tree for classification.
-     #
-     # @example
-     #   estimator =
-     #     SVMKit::Tree::DecisionTreeClassifier.new(
-     #       criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
-     #   estimator.fit(training_samples, training_labels)
-     #   results = estimator.predict(testing_samples)
-     #
-     class DecisionTreeClassifier
-       include Base::BaseEstimator
-       include Base::Classifier
-
-       # Return the class labels.
-       # @return [Numo::Int32] (size: n_classes)
-       attr_reader :classes
-
-       # Return the importance for each feature.
-       # @return [Numo::DFloat] (size: n_features)
-       attr_reader :feature_importances
-
-       # Return the learned tree.
-       # @return [Node]
-       attr_reader :tree
-
-       # Return the random generator for random selection of feature index.
-       # @return [Random]
-       attr_reader :rng
-
-       # Return the labels assigned to each leaf.
-       # @return [Numo::Int32] (size: n_leaves)
-       attr_reader :leaf_labels
-
-       # Create a new classifier with the decision tree algorithm.
-       #
-       # @param criterion [String] The function to evaluate the splitting point. Supported criteria are 'gini' and 'entropy'.
-       # @param max_depth [Integer] The maximum depth of the tree.
-       #   If nil is given, the decision tree grows without a depth limit.
-       # @param max_leaf_nodes [Integer] The maximum number of leaves on the decision tree.
-       #   If nil is given, the number of leaves is not limited.
-       # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
-       # @param max_features [Integer] The number of features to consider when searching for the optimal split point.
-       #   If nil is given, the split process considers all features.
-       # @param random_seed [Integer] The seed value used to initialize the random generator.
-       #   It is used to randomly determine the order of features when deciding the splitting point.
-       def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
-                      random_seed: nil)
-         SVMKit::Validation.check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
-                                                              max_features: max_features, random_seed: random_seed)
-         SVMKit::Validation.check_params_integer(min_samples_leaf: min_samples_leaf)
-         SVMKit::Validation.check_params_string(criterion: criterion)
-         SVMKit::Validation.check_params_positive(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
-                                                  min_samples_leaf: min_samples_leaf, max_features: max_features)
-         @params = {}
-         @params[:criterion] = criterion
-         @params[:max_depth] = max_depth
-         @params[:max_leaf_nodes] = max_leaf_nodes
-         @params[:min_samples_leaf] = min_samples_leaf
-         @params[:max_features] = max_features
-         @params[:random_seed] = random_seed
-         @params[:random_seed] ||= srand
-         @criterion = :gini
-         @criterion = :entropy if @params[:criterion] == 'entropy'
-         @tree = nil
-         @classes = nil
-         @feature_importances = nil
-         @n_leaves = nil
-         @leaf_labels = nil
-         @rng = Random.new(@params[:random_seed])
-       end
-
-       # Fit the model with given training data.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
-       # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
-       # @return [DecisionTreeClassifier] The learned classifier itself.
-       def fit(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_label_array(y)
-         SVMKit::Validation.check_sample_label_size(x, y)
-         n_samples, n_features = x.shape
-         @params[:max_features] = n_features if @params[:max_features].nil?
-         @params[:max_features] = [@params[:max_features], n_features].min
-         uniq_y = y.to_a.uniq.sort
-         @classes = Numo::Int32.asarray(uniq_y)
-         build_tree(x, y.map { |v| uniq_y.index(v) })
-         eval_importance(n_samples, n_features)
-         self
-       end
-
-       # Predict class labels for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
-       # @return [Numo::Int32] (shape: [n_samples]) The predicted class label per sample.
-       def predict(x)
-         SVMKit::Validation.check_sample_array(x)
-         @leaf_labels[apply(x)]
-       end
-
-       # Predict probabilities for samples.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
-       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) The predicted probability of each class per sample.
-       def predict_proba(x)
-         SVMKit::Validation.check_sample_array(x)
-         Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_at_node(@tree, x[n, true]) })]
-       end
-
-       # Return the index of the leaf that each sample reaches.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
-       # @return [Numo::Int32] (shape: [n_samples]) The leaf index per sample.
-       def apply(x)
-         SVMKit::Validation.check_sample_array(x)
-         Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })]
-       end
-
-       # Dump marshal data.
-       # @return [Hash] The marshal data about DecisionTreeClassifier.
-       def marshal_dump
-         { params: @params,
-           classes: @classes,
-           criterion: @criterion,
-           tree: @tree,
-           feature_importances: @feature_importances,
-           leaf_labels: @leaf_labels,
-           rng: @rng }
-       end
-
-       # Load marshal data.
-       # @return [nil]
-       def marshal_load(obj)
-         @params = obj[:params]
-         @classes = obj[:classes]
-         @criterion = obj[:criterion]
-         @tree = obj[:tree]
-         @feature_importances = obj[:feature_importances]
-         @leaf_labels = obj[:leaf_labels]
-         @rng = obj[:rng]
-         nil
-       end
-
-       private
-
-       def predict_at_node(node, sample)
-         return node.probs if node.leaf
-         branch_at_node('predict', node, sample)
-       end
-
-       def apply_at_node(node, sample)
-         return node.leaf_id if node.leaf
-         branch_at_node('apply', node, sample)
-       end
-
-       def branch_at_node(action, node, sample)
-         return send("#{action}_at_node", node.left, sample) if node.right.nil?
-         return send("#{action}_at_node", node.right, sample) if node.left.nil?
-         if sample[node.feature_id] <= node.threshold
-           send("#{action}_at_node", node.left, sample)
-         else
-           send("#{action}_at_node", node.right, sample)
-         end
-       end
-
-       def build_tree(x, y)
-         @n_leaves = 0
-         @leaf_labels = []
-         @tree = grow_node(0, x, y, impurity(y))
-         @leaf_labels = Numo::Int32[*@leaf_labels]
-         nil
-       end
-
-       def grow_node(depth, x, y, whole_impurity)
-         unless @params[:max_leaf_nodes].nil?
-           return nil if @n_leaves >= @params[:max_leaf_nodes]
-         end
-
-         n_samples, n_features = x.shape
-         return nil if n_samples <= @params[:min_samples_leaf]
-
-         node = Node.new(depth: depth, impurity: whole_impurity, n_samples: n_samples)
-
-         return put_leaf(node, y) if y.to_a.uniq.size == 1
-
-         unless @params[:max_depth].nil?
-           return put_leaf(node, y) if depth == @params[:max_depth]
-         end
-
-         feature_id, threshold, left_ids, right_ids, left_impurity, right_impurity, gain =
-           rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y, whole_impurity)] }.max_by(&:last)
-
-         return put_leaf(node, y) if gain.nil? || gain.zero?
-
-         node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids], left_impurity)
-         node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids], right_impurity)
-
-         return put_leaf(node, y) if node.left.nil? && node.right.nil?
-
-         node.feature_id = feature_id
-         node.threshold = threshold
-         node.leaf = false
-         node
-       end
-
-       def put_leaf(node, y)
-         node.probs = y.bincount(minlength: @classes.size) / node.n_samples.to_f
-         node.leaf = true
-         node.leaf_id = @n_leaves
-         @n_leaves += 1
-         @leaf_labels.push(@classes[node.probs.max_index])
-         node
-       end
-
-       def rand_ids(n)
-         [*0...n].sample(@params[:max_features], random: @rng)
-       end
-
-       def best_split(features, labels, whole_impurity)
-         n_samples = labels.size
-         features.to_a.uniq.sort.each_cons(2).map do |l, r|
-           threshold = 0.5 * (l + r)
-           left_ids = features.le(threshold).where
-           right_ids = features.gt(threshold).where
-           left_impurity = impurity(labels[left_ids])
-           right_impurity = impurity(labels[right_ids])
-           gain = whole_impurity -
-                  left_impurity * left_ids.size.fdiv(n_samples) -
-                  right_impurity * right_ids.size.fdiv(n_samples)
-           [threshold, left_ids, right_ids, left_impurity, right_impurity, gain]
-         end.max_by(&:last)
-       end
-
-       def impurity(labels)
-         send(@criterion, labels.bincount / labels.size.to_f)
-       end
-
-       def gini(posterior_probs)
-         1.0 - (posterior_probs * posterior_probs).sum
-       end
-
-       def entropy(posterior_probs)
-         # Add a small constant to avoid log(0) for zero-probability classes.
-         -(posterior_probs * Numo::NMath.log(posterior_probs + 1.0e-12)).sum
-       end
-
-       def eval_importance(n_samples, n_features)
-         @feature_importances = Numo::DFloat.zeros(n_features)
-         eval_importance_at_node(@tree)
-         @feature_importances /= n_samples
-         normalizer = @feature_importances.sum
-         @feature_importances /= normalizer if normalizer > 0.0
-         nil
-       end
-
-       def eval_importance_at_node(node)
-         return nil if node.leaf
-         return nil if node.left.nil? || node.right.nil?
-         gain = node.n_samples * node.impurity -
-                node.left.n_samples * node.left.impurity -
-                node.right.n_samples * node.right.impurity
-         @feature_importances[node.feature_id] += gain
-         eval_importance_at_node(node.left)
-         eval_importance_at_node(node.right)
-       end
-     end
-   end
- end
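best_split above scores each candidate threshold by the weighted impurity decrease gain = I(parent) - w_left * I(left) - w_right * I(right). The following worked Gini example uses assumed toy counts, not data from the gem, to make that arithmetic concrete.

```ruby
require 'numo/narray'

# Gini impurity of a class-probability vector: 1 - sum(p_k^2),
# matching the gini method above.
def gini(probs)
  1.0 - (probs * probs).sum
end

parent = gini(Numo::DFloat[0.5, 0.5])   # 8 samples: 4 pos, 4 neg => 0.5

# Candidate split: left leaf gets [3 pos, 1 neg], right leaf [1 pos, 3 neg].
left  = gini(Numo::DFloat[0.75, 0.25])  # => 0.375
right = gini(Numo::DFloat[0.25, 0.75])  # => 0.375

# Weighted impurity decrease, as computed in best_split.
gain = parent - left * 4.fdiv(8) - right * 4.fdiv(8)
puts gain # => 0.125
```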