rumale 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. checksums.yaml +7 -0
  2. data/.coveralls.yml +1 -0
  3. data/.gitignore +20 -0
  4. data/.rspec +3 -0
  5. data/.rubocop.yml +47 -0
  6. data/.rubocop_todo.yml +58 -0
  7. data/.travis.yml +13 -0
  8. data/CHANGELOG.md +2 -0
  9. data/CODE_OF_CONDUCT.md +74 -0
  10. data/Gemfile +4 -0
  11. data/LICENSE.txt +23 -0
  12. data/README.md +175 -0
  13. data/Rakefile +6 -0
  14. data/bin/console +14 -0
  15. data/bin/setup +8 -0
  16. data/lib/rumale.rb +70 -0
  17. data/lib/rumale/base/base_estimator.rb +13 -0
  18. data/lib/rumale/base/classifier.rb +36 -0
  19. data/lib/rumale/base/cluster_analyzer.rb +31 -0
  20. data/lib/rumale/base/evaluator.rb +17 -0
  21. data/lib/rumale/base/regressor.rb +36 -0
  22. data/lib/rumale/base/splitter.rb +21 -0
  23. data/lib/rumale/base/transformer.rb +22 -0
  24. data/lib/rumale/clustering/dbscan.rb +125 -0
  25. data/lib/rumale/clustering/k_means.rb +138 -0
  26. data/lib/rumale/dataset.rb +110 -0
  27. data/lib/rumale/decomposition/nmf.rb +141 -0
  28. data/lib/rumale/decomposition/pca.rb +148 -0
  29. data/lib/rumale/ensemble/ada_boost_classifier.rb +196 -0
  30. data/lib/rumale/ensemble/ada_boost_regressor.rb +178 -0
  31. data/lib/rumale/ensemble/random_forest_classifier.rb +180 -0
  32. data/lib/rumale/ensemble/random_forest_regressor.rb +141 -0
  33. data/lib/rumale/evaluation_measure/accuracy.rb +29 -0
  34. data/lib/rumale/evaluation_measure/f_score.rb +50 -0
  35. data/lib/rumale/evaluation_measure/log_loss.rb +45 -0
  36. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +29 -0
  37. data/lib/rumale/evaluation_measure/mean_squared_error.rb +29 -0
  38. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +62 -0
  39. data/lib/rumale/evaluation_measure/precision.rb +50 -0
  40. data/lib/rumale/evaluation_measure/precision_recall.rb +91 -0
  41. data/lib/rumale/evaluation_measure/purity.rb +40 -0
  42. data/lib/rumale/evaluation_measure/r2_score.rb +43 -0
  43. data/lib/rumale/evaluation_measure/recall.rb +50 -0
  44. data/lib/rumale/kernel_approximation/rbf.rb +121 -0
  45. data/lib/rumale/kernel_machine/kernel_svc.rb +193 -0
  46. data/lib/rumale/linear_model/base_linear_model.rb +89 -0
  47. data/lib/rumale/linear_model/lasso.rb +136 -0
  48. data/lib/rumale/linear_model/linear_regression.rb +110 -0
  49. data/lib/rumale/linear_model/logistic_regression.rb +159 -0
  50. data/lib/rumale/linear_model/ridge.rb +110 -0
  51. data/lib/rumale/linear_model/svc.rb +183 -0
  52. data/lib/rumale/linear_model/svr.rb +122 -0
  53. data/lib/rumale/model_selection/cross_validation.rb +123 -0
  54. data/lib/rumale/model_selection/grid_search_cv.rb +247 -0
  55. data/lib/rumale/model_selection/k_fold.rb +76 -0
  56. data/lib/rumale/model_selection/stratified_k_fold.rb +94 -0
  57. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +100 -0
  58. data/lib/rumale/naive_bayes/naive_bayes.rb +315 -0
  59. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +111 -0
  60. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +93 -0
  61. data/lib/rumale/optimizer/nadam.rb +90 -0
  62. data/lib/rumale/optimizer/rmsprop.rb +69 -0
  63. data/lib/rumale/optimizer/sgd.rb +65 -0
  64. data/lib/rumale/optimizer/yellow_fin.rb +144 -0
  65. data/lib/rumale/pairwise_metric.rb +91 -0
  66. data/lib/rumale/pipeline/pipeline.rb +197 -0
  67. data/lib/rumale/polynomial_model/base_factorization_machine.rb +99 -0
  68. data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +197 -0
  69. data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +131 -0
  70. data/lib/rumale/preprocessing/l2_normalizer.rb +62 -0
  71. data/lib/rumale/preprocessing/label_encoder.rb +94 -0
  72. data/lib/rumale/preprocessing/min_max_scaler.rb +92 -0
  73. data/lib/rumale/preprocessing/one_hot_encoder.rb +98 -0
  74. data/lib/rumale/preprocessing/standard_scaler.rb +86 -0
  75. data/lib/rumale/probabilistic_output.rb +112 -0
  76. data/lib/rumale/tree/base_decision_tree.rb +153 -0
  77. data/lib/rumale/tree/decision_tree_classifier.rb +163 -0
  78. data/lib/rumale/tree/decision_tree_regressor.rb +135 -0
  79. data/lib/rumale/tree/node.rb +70 -0
  80. data/lib/rumale/utils.rb +37 -0
  81. data/lib/rumale/validation.rb +79 -0
  82. data/lib/rumale/values.rb +13 -0
  83. data/lib/rumale/version.rb +6 -0
  84. data/rumale.gemspec +41 -0
  85. metadata +204 -0
data/lib/rumale/preprocessing/standard_scaler.rb
@@ -0,0 +1,86 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/base_estimator'
+ require 'rumale/base/transformer'
+
+ module Rumale
+   # This module consists of the classes that perform preprocessing.
+   module Preprocessing
+     # Normalize samples by centering and scaling to unit variance.
+     #
+     # @example
+     #   normalizer = Rumale::Preprocessing::StandardScaler.new
+     #   new_training_samples = normalizer.fit_transform(training_samples)
+     #   new_testing_samples = normalizer.transform(testing_samples)
+     class StandardScaler
+       include Base::BaseEstimator
+       include Base::Transformer
+
+       # Return the vector consisting of the mean value of each feature.
+       # @return [Numo::DFloat] (shape: [n_features])
+       attr_reader :mean_vec
+
+       # Return the vector consisting of the standard deviation of each feature.
+       # @return [Numo::DFloat] (shape: [n_features])
+       attr_reader :std_vec
+
+       # Create a new normalizer for centering and scaling to unit variance.
+       def initialize
+         @params = {}
+         @mean_vec = nil
+         @std_vec = nil
+       end
+
+       # Calculate the mean value and standard deviation of each feature for scaling.
+       #
+       # @overload fit(x) -> StandardScaler
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features])
+       #   The samples to calculate the mean values and standard deviations.
+       # @return [StandardScaler]
+       def fit(x, _y = nil)
+         check_sample_array(x)
+         @mean_vec = x.mean(0)
+         @std_vec = x.stddev(0)
+         self
+       end
+
+       # Calculate the mean values and standard deviations, and then normalize samples using them.
+       #
+       # @overload fit_transform(x) -> Numo::DFloat
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features])
+       #   The samples to calculate the mean values and standard deviations.
+       # @return [Numo::DFloat] The scaled samples.
+       def fit_transform(x, _y = nil)
+         check_sample_array(x)
+         fit(x).transform(x)
+       end
+
+       # Perform standardization of the given samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be scaled.
+       # @return [Numo::DFloat] The scaled samples.
+       def transform(x)
+         check_sample_array(x)
+         n_samples, = x.shape
+         (x - @mean_vec.tile(n_samples, 1)) / @std_vec.tile(n_samples, 1)
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data about StandardScaler.
+       def marshal_dump
+         { mean_vec: @mean_vec,
+           std_vec: @std_vec }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @mean_vec = obj[:mean_vec]
+         @std_vec = obj[:std_vec]
+         nil
+       end
+     end
+   end
+ end
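
The transform step above computes z = (x - mean) / std feature-wise. A minimal usage sketch of the class (the data values are invented for illustration; it assumes the rumale gem and its numo-narray dependency are installed):

    require 'rumale'

    # Three samples with two features (toy data).
    x = Numo::DFloat[[1.0, -2.0], [3.0, 0.0], [5.0, 2.0]]

    scaler = Rumale::Preprocessing::StandardScaler.new
    z = scaler.fit_transform(x)

    p scaler.mean_vec.to_a # => [3.0, 0.0] (per-feature means)
    p scaler.std_vec.to_a  # per-feature standard deviations
    p z.mean(0).to_a       # => approximately [0.0, 0.0] after centering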
data/lib/rumale/probabilistic_output.rb
@@ -0,0 +1,112 @@
+ # frozen_string_literal: true
+
+ module Rumale
+   # Module for calculating posterior class probabilities from SVM outputs.
+   # This module is used for internal processes.
+   #
+   # @example
+   #   estimator = Rumale::LinearModel::SVC.new
+   #   estimator.fit(x, bin_y)
+   #   df = estimator.decision_function(x)
+   #   params = Rumale::ProbabilisticOutput.fit_sigmoid(df, bin_y)
+   #   probs = 1 / (Numo::NMath.exp(params[0] * df + params[1]) + 1)
+   #
+   # *Reference*
+   # 1. J C. Platt, "Probabilistic Outputs for Support Vector Machines and Comparisons to Regularized Likelihood Methods," Adv. Large Margin Classifiers, pp. 61--74, 2000.
+   # 1. H-T Lin, C-J Lin, and R C. Weng, "A Note on Platt's Probabilistic Outputs for Support Vector Machines," Machine Learning, Vol. 68 (3), pp. 267--276, 2007.
+   module ProbabilisticOutput
+     class << self
+       # Fit the probabilistic model for binary SVM outputs.
+       #
+       # @param df [Numo::DFloat] (shape: [n_samples]) The outputs of the decision function to be used for fitting the model.
+       # @param bin_y [Numo::Int32] (shape: [n_samples]) The binary labels to be used for fitting the model.
+       # @param max_iter [Integer] The maximum number of iterations.
+       # @param min_step [Float] The minimum step size of Newton's method.
+       # @param sigma [Float] The parameter that keeps the Hessian matrix from becoming singular.
+       # @return [Numo::DFloat] (shape: [2]) The parameters of the model.
+       def fit_sigmoid(df, bin_y, max_iter = 100, min_step = 1e-10, sigma = 1e-12)
+         # Initialize some variables.
+         n_samples = bin_y.size
+         negative_label = bin_y.to_a.uniq.min
+         pos = bin_y.ne(negative_label)
+         neg = bin_y.eq(negative_label)
+         n_pos_samples = pos.count
+         n_neg_samples = neg.count
+         target_probs = Numo::DFloat.zeros(n_samples)
+         target_probs[pos] = (n_pos_samples + 1) / (n_pos_samples + 2.0)
+         target_probs[neg] = 1 / (n_neg_samples + 2.0)
+         alpha = 0.0
+         beta = Math.log((n_neg_samples + 1) / (n_pos_samples + 1.0))
+         err = error_function(target_probs, df, alpha, beta)
+         # Optimize the parameters for class probability calculation.
+         old_grad_vec = Numo::DFloat.zeros(2)
+         max_iter.times do
+           # Calculate the gradient and Hessian matrix.
+           probs = predicted_probs(df, alpha, beta)
+           grad_vec = gradient(target_probs, probs, df)
+           hess_mat = hessian_matrix(probs, df, sigma)
+           break if grad_vec.abs.lt(1e-5).count == 2
+           break if (old_grad_vec - grad_vec).abs.sum < 1e-5
+           old_grad_vec = grad_vec
+           # Calculate the Newton direction and backtrack the step size.
+           dirs_vec = directions(grad_vec, hess_mat)
+           grad_dir = grad_vec.dot(dirs_vec)
+           stepsize = 2.0
+           while stepsize >= min_step
+             stepsize *= 0.5
+             new_alpha = alpha + stepsize * dirs_vec[0]
+             new_beta = beta + stepsize * dirs_vec[1]
+             new_err = error_function(target_probs, df, new_alpha, new_beta)
+             next unless new_err < err + 0.0001 * stepsize * grad_dir
+             alpha = new_alpha
+             beta = new_beta
+             err = new_err
+             break
+           end
+         end
+         Numo::DFloat[alpha, beta]
+       end
+
+       private
+
+       def error_function(target_probs, df, alpha, beta)
+         fn = alpha * df + beta
+         pos = fn.ge(0.0)
+         neg = fn.lt(0.0)
+         err = 0.0
+         err += (target_probs[pos] * fn[pos] + Numo::NMath.log(1 + Numo::NMath.exp(-fn[pos]))).sum if pos.count.positive?
+         err += ((target_probs[neg] - 1) * fn[neg] + Numo::NMath.log(1 + Numo::NMath.exp(fn[neg]))).sum if neg.count.positive?
+         err
+       end
+
+       def predicted_probs(df, alpha, beta)
+         fn = alpha * df + beta
+         pos = fn.ge(0.0)
+         neg = fn.lt(0.0)
+         probs = Numo::DFloat.zeros(df.shape[0])
+         probs[pos] = Numo::NMath.exp(-fn[pos]) / (1 + Numo::NMath.exp(-fn[pos])) if pos.count.positive?
+         probs[neg] = 1 / (1 + Numo::NMath.exp(fn[neg])) if neg.count.positive?
+         probs
+       end
+
+       def gradient(target_probs, probs, df)
+         sub = target_probs - probs
+         Numo::DFloat[(df * sub).sum, sub.sum]
+       end
+
+       def hessian_matrix(probs, df, sigma)
+         sub = probs * (1 - probs)
+         h11 = (df * df * sub).sum + sigma
+         h22 = sub.sum + sigma
+         h21 = (df * sub).sum
+         Numo::DFloat[[h11, h21], [h21, h22]]
+       end
+
+       def directions(grad_vec, hess_mat)
+         det = hess_mat[0, 0] * hess_mat[1, 1] - hess_mat[0, 1] * hess_mat[1, 0]
+         inv_hess_mat = Numo::DFloat[[hess_mat[1, 1], -hess_mat[0, 1]], [-hess_mat[1, 0], hess_mat[0, 0]]] / det
+         -inv_hess_mat.dot(grad_vec)
+       end
+     end
+   end
+ end
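
For reference, fit_sigmoid implements Platt scaling: it fits the scalar parameters A (alpha) and B (beta) of the sigmoid model by minimizing the cross-entropy error computed by error_function,

    \min_{A,B} F(A,B) = -\sum_{i=1}^{n} \bigl[ t_i \log p_i + (1 - t_i) \log(1 - p_i) \bigr], \qquad p_i = \frac{1}{1 + \exp(A f_i + B)},

where f_i is the decision-function output df for sample i, and the smoothed targets (the target_probs initialization) are t_i = (N_+ + 1)/(N_+ + 2) for positive samples and t_i = 1/(N_- + 2) for negative ones. Each iteration takes a Newton step d = -H^{-1} \nabla F (the directions helper) and halves the step size s until the sufficient-decrease condition F_{\mathrm{new}} < F + 10^{-4}\, s\, \nabla F^\top d holds, which is the 0.0001 * stepsize * grad_dir test in the loop.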
data/lib/rumale/tree/base_decision_tree.rb
@@ -0,0 +1,153 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/base_estimator'
+ require 'rumale/tree/node'
+
+ module Rumale
+   # This module consists of the classes that implement tree models.
+   module Tree
+     # BaseDecisionTree is an abstract class for the implementation of decision tree-based estimators.
+     # This class is used internally.
+     class BaseDecisionTree
+       include Base::BaseEstimator
+
+       # Initialize a decision tree-based estimator.
+       #
+       # @param criterion [String] The function used to evaluate the splitting point.
+       # @param max_depth [Integer] The maximum depth of the tree.
+       #   If nil is given, the decision tree grows without concern for depth.
+       # @param max_leaf_nodes [Integer] The maximum number of leaves on the decision tree.
+       #   If nil is given, the number of leaves is not limited.
+       # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
+       # @param max_features [Integer] The number of features to consider when searching for the optimal split point.
+       #   If nil is given, the split process considers all features.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       #   It is used to randomly determine the order of features when deciding the splitting point.
+       def initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil)
+         @params = {}
+         @params[:criterion] = criterion
+         @params[:max_depth] = max_depth
+         @params[:max_leaf_nodes] = max_leaf_nodes
+         @params[:min_samples_leaf] = min_samples_leaf
+         @params[:max_features] = max_features
+         @params[:random_seed] = random_seed
+         @params[:random_seed] ||= srand
+         @tree = nil
+         @feature_importances = nil
+         @n_leaves = nil
+         @rng = Random.new(@params[:random_seed])
+       end
+
+       # Return the index of the leaf that each sample reached.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The input samples.
+       # @return [Numo::Int32] (shape: [n_samples]) Leaf index for each sample.
+       def apply(x)
+         check_sample_array(x)
+         Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })]
+       end
+
+       private
+
+       def apply_at_node(node, sample)
+         return node.leaf_id if node.leaf
+         return apply_at_node(node.left, sample) if node.right.nil?
+         return apply_at_node(node.right, sample) if node.left.nil?
+         if sample[node.feature_id] <= node.threshold
+           apply_at_node(node.left, sample)
+         else
+           apply_at_node(node.right, sample)
+         end
+       end
+
+       def build_tree(x, y)
+         y = y.expand_dims(1).dup if y.shape[1].nil?
+         @tree = grow_node(0, x, y, impurity(y))
+         nil
+       end
+
+       def grow_node(depth, x, y, whole_impurity)
+         unless @params[:max_leaf_nodes].nil?
+           return nil if @n_leaves >= @params[:max_leaf_nodes]
+         end
+
+         n_samples, n_features = x.shape
+         return nil if n_samples <= @params[:min_samples_leaf]
+
+         node = Node.new(depth: depth, impurity: whole_impurity, n_samples: n_samples)
+
+         return put_leaf(node, y) if stop_growing?(y)
+
+         unless @params[:max_depth].nil?
+           return put_leaf(node, y) if depth == @params[:max_depth]
+         end
+
+         feature_id, threshold, left_ids, right_ids, left_impurity, right_impurity, gain =
+           rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y, whole_impurity)] }.max_by(&:last)
+
+         return put_leaf(node, y) if gain.nil? || gain.zero?
+
+         node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids, true], left_impurity)
+         node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids, true], right_impurity)
+
+         return put_leaf(node, y) if node.left.nil? && node.right.nil?
+
+         node.feature_id = feature_id
+         node.threshold = threshold
+         node.leaf = false
+         node
+       end
+
+       def stop_growing?(_y)
+         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+       end
+
+       def put_leaf(_node, _y)
+         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+       end
+
+       def rand_ids(n)
+         [*0...n].sample(@params[:max_features], random: @rng)
+       end
+
+       def best_split(features, targets, whole_impurity)
+         n_samples = targets.shape[0]
+         features.to_a.uniq.sort.each_cons(2).map do |l, r|
+           threshold = 0.5 * (l + r)
+           left_ids = features.le(threshold).where
+           right_ids = features.gt(threshold).where
+           left_impurity = impurity(targets[left_ids, true])
+           right_impurity = impurity(targets[right_ids, true])
+           gain = whole_impurity -
+                  left_impurity * left_ids.size.fdiv(n_samples) -
+                  right_impurity * right_ids.size.fdiv(n_samples)
+           [threshold, left_ids, right_ids, left_impurity, right_impurity, gain]
+         end.max_by(&:last)
+       end
+
+       def impurity(_targets)
+         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+       end
+
+       def eval_importance(n_samples, n_features)
+         @feature_importances = Numo::DFloat.zeros(n_features)
+         eval_importance_at_node(@tree)
+         @feature_importances /= n_samples
+         normalizer = @feature_importances.sum
+         @feature_importances /= normalizer if normalizer > 0.0
+         nil
+       end
+
+       def eval_importance_at_node(node)
+         return nil if node.leaf
+         return nil if node.left.nil? || node.right.nil?
+         gain = node.n_samples * node.impurity -
+                node.left.n_samples * node.left.impurity -
+                node.right.n_samples * node.right.impurity
+         @feature_importances[node.feature_id] += gain
+         eval_importance_at_node(node.left)
+         eval_importance_at_node(node.right)
+       end
+     end
+   end
+ end
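
The split search in best_split tries the midpoint between every pair of consecutive distinct feature values as a candidate threshold and keeps the one that maximizes the impurity decrease

    \Delta I = I(S) - \frac{|S_L|}{|S|} I(S_L) - \frac{|S_R|}{|S|} I(S_R),

where S is the set of samples reaching the node, S_L and S_R are the subsets on either side of the threshold, and I(\cdot) is the impurity function supplied by the subclass. eval_importance accumulates the unnormalized version of the same quantity, n I(S) - n_L I(S_L) - n_R I(S_R), per split feature over all internal nodes and then normalizes the totals to sum to one.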
data/lib/rumale/tree/decision_tree_classifier.rb
@@ -0,0 +1,163 @@
+ # frozen_string_literal: true
+
+ require 'rumale/tree/base_decision_tree'
+ require 'rumale/base/classifier'
+
+ module Rumale
+   module Tree
+     # DecisionTreeClassifier is a class that implements a decision tree for classification.
+     #
+     # @example
+     #   estimator =
+     #     Rumale::Tree::DecisionTreeClassifier.new(
+     #       criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
+     #   estimator.fit(training_samples, training_labels)
+     #   results = estimator.predict(testing_samples)
+     #
+     class DecisionTreeClassifier < BaseDecisionTree
+       include Base::Classifier
+
+       # Return the class labels.
+       # @return [Numo::Int32] (size: n_classes)
+       attr_reader :classes
+
+       # Return the importance of each feature.
+       # @return [Numo::DFloat] (size: n_features)
+       attr_reader :feature_importances
+
+       # Return the learned tree.
+       # @return [Node]
+       attr_reader :tree
+
+       # Return the random generator for random selection of the feature index.
+       # @return [Random]
+       attr_reader :rng
+
+       # Return the labels assigned to each leaf.
+       # @return [Numo::Int32] (size: n_leaves)
+       attr_reader :leaf_labels
+
+       # Create a new classifier with the decision tree algorithm.
+       #
+       # @param criterion [String] The function used to evaluate the splitting point. Supported criteria are 'gini' and 'entropy'.
+       # @param max_depth [Integer] The maximum depth of the tree.
+       #   If nil is given, the decision tree grows without concern for depth.
+       # @param max_leaf_nodes [Integer] The maximum number of leaves on the decision tree.
+       #   If nil is given, the number of leaves is not limited.
+       # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
+       # @param max_features [Integer] The number of features to consider when searching for the optimal split point.
+       #   If nil is given, the split process considers all features.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       #   It is used to randomly determine the order of features when deciding the splitting point.
+       def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
+                      random_seed: nil)
+         check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
+                                           max_features: max_features, random_seed: random_seed)
+         check_params_integer(min_samples_leaf: min_samples_leaf)
+         check_params_string(criterion: criterion)
+         check_params_positive(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
+                               min_samples_leaf: min_samples_leaf, max_features: max_features)
+         super
+         @leaf_labels = nil
+       end
+
+       # Fit the model with the given training data.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+       # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+       # @return [DecisionTreeClassifier] The learned classifier itself.
+       def fit(x, y)
+         check_sample_array(x)
+         check_label_array(y)
+         check_sample_label_size(x, y)
+         n_samples, n_features = x.shape
+         @params[:max_features] = n_features if @params[:max_features].nil?
+         @params[:max_features] = [@params[:max_features], n_features].min
+         uniq_y = y.to_a.uniq.sort
+         @classes = Numo::Int32.asarray(uniq_y)
+         @n_leaves = 0
+         @leaf_labels = []
+         build_tree(x, y.map { |v| uniq_y.index(v) })
+         eval_importance(n_samples, n_features)
+         @leaf_labels = Numo::Int32[*@leaf_labels]
+         self
+       end
+
+       # Predict class labels for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
+       # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
+       def predict(x)
+         check_sample_array(x)
+         @leaf_labels[apply(x)]
+       end
+
+       # Predict probabilities for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
+       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
+       def predict_proba(x)
+         check_sample_array(x)
+         Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_proba_at_node(@tree, x[n, true]) })]
+       end
+
+       # Dump marshal data.
+       # @return [Hash] The marshal data about DecisionTreeClassifier.
+       def marshal_dump
+         { params: @params,
+           classes: @classes,
+           tree: @tree,
+           feature_importances: @feature_importances,
+           leaf_labels: @leaf_labels,
+           rng: @rng }
+       end
+
+       # Load marshal data.
+       # @return [nil]
+       def marshal_load(obj)
+         @params = obj[:params]
+         @classes = obj[:classes]
+         @tree = obj[:tree]
+         @feature_importances = obj[:feature_importances]
+         @leaf_labels = obj[:leaf_labels]
+         @rng = obj[:rng]
+         nil
+       end
+
+       private
+
+       def predict_proba_at_node(node, sample)
+         return node.probs if node.leaf
+         return predict_proba_at_node(node.left, sample) if node.right.nil?
+         return predict_proba_at_node(node.right, sample) if node.left.nil?
+         if sample[node.feature_id] <= node.threshold
+           predict_proba_at_node(node.left, sample)
+         else
+           predict_proba_at_node(node.right, sample)
+         end
+       end
+
+       def stop_growing?(y)
+         y.flatten.to_a.uniq.size == 1
+       end
+
+       def put_leaf(node, y)
+         node.probs = y.flatten.bincount(minlength: @classes.size) / node.n_samples.to_f
+         node.leaf = true
+         node.leaf_id = @n_leaves
+         @n_leaves += 1
+         @leaf_labels.push(@classes[node.probs.max_index])
+         node
+       end
+
+       def impurity(y)
+         posterior_probs = y.flatten.bincount / y.size.to_f
+         if @params[:criterion] == 'entropy'
+           -(posterior_probs * Numo::NMath.log(posterior_probs + 1e-12)).sum # epsilon keeps the p = 0 terms finite
+         else
+           1.0 - (posterior_probs * posterior_probs).sum
+         end
+       end
+     end
+   end
+ end
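
The two supported criteria correspond to the standard impurity measures over the class proportions p_c at a node:

    I_{\mathrm{gini}}(S) = 1 - \sum_c p_c^2, \qquad I_{\mathrm{entropy}}(S) = -\sum_c p_c \log p_c,

with the convention 0 \log 0 = 0; the small epsilon added inside the logarithm in impurity realizes that convention numerically, since Numo would otherwise produce NaN from 0 \times (-\infty).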