rumale 0.8.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (85) hide show
  1. checksums.yaml +7 -0
  2. data/.coveralls.yml +1 -0
  3. data/.gitignore +20 -0
  4. data/.rspec +3 -0
  5. data/.rubocop.yml +47 -0
  6. data/.rubocop_todo.yml +58 -0
  7. data/.travis.yml +13 -0
  8. data/CHANGELOG.md +2 -0
  9. data/CODE_OF_CONDUCT.md +74 -0
  10. data/Gemfile +4 -0
  11. data/LICENSE.txt +23 -0
  12. data/README.md +175 -0
  13. data/Rakefile +6 -0
  14. data/bin/console +14 -0
  15. data/bin/setup +8 -0
  16. data/lib/rumale.rb +70 -0
  17. data/lib/rumale/base/base_estimator.rb +13 -0
  18. data/lib/rumale/base/classifier.rb +36 -0
  19. data/lib/rumale/base/cluster_analyzer.rb +31 -0
  20. data/lib/rumale/base/evaluator.rb +17 -0
  21. data/lib/rumale/base/regressor.rb +36 -0
  22. data/lib/rumale/base/splitter.rb +21 -0
  23. data/lib/rumale/base/transformer.rb +22 -0
  24. data/lib/rumale/clustering/dbscan.rb +125 -0
  25. data/lib/rumale/clustering/k_means.rb +138 -0
  26. data/lib/rumale/dataset.rb +110 -0
  27. data/lib/rumale/decomposition/nmf.rb +141 -0
  28. data/lib/rumale/decomposition/pca.rb +148 -0
  29. data/lib/rumale/ensemble/ada_boost_classifier.rb +196 -0
  30. data/lib/rumale/ensemble/ada_boost_regressor.rb +178 -0
  31. data/lib/rumale/ensemble/random_forest_classifier.rb +180 -0
  32. data/lib/rumale/ensemble/random_forest_regressor.rb +141 -0
  33. data/lib/rumale/evaluation_measure/accuracy.rb +29 -0
  34. data/lib/rumale/evaluation_measure/f_score.rb +50 -0
  35. data/lib/rumale/evaluation_measure/log_loss.rb +45 -0
  36. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +29 -0
  37. data/lib/rumale/evaluation_measure/mean_squared_error.rb +29 -0
  38. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +62 -0
  39. data/lib/rumale/evaluation_measure/precision.rb +50 -0
  40. data/lib/rumale/evaluation_measure/precision_recall.rb +91 -0
  41. data/lib/rumale/evaluation_measure/purity.rb +40 -0
  42. data/lib/rumale/evaluation_measure/r2_score.rb +43 -0
  43. data/lib/rumale/evaluation_measure/recall.rb +50 -0
  44. data/lib/rumale/kernel_approximation/rbf.rb +121 -0
  45. data/lib/rumale/kernel_machine/kernel_svc.rb +193 -0
  46. data/lib/rumale/linear_model/base_linear_model.rb +89 -0
  47. data/lib/rumale/linear_model/lasso.rb +136 -0
  48. data/lib/rumale/linear_model/linear_regression.rb +110 -0
  49. data/lib/rumale/linear_model/logistic_regression.rb +159 -0
  50. data/lib/rumale/linear_model/ridge.rb +110 -0
  51. data/lib/rumale/linear_model/svc.rb +183 -0
  52. data/lib/rumale/linear_model/svr.rb +122 -0
  53. data/lib/rumale/model_selection/cross_validation.rb +123 -0
  54. data/lib/rumale/model_selection/grid_search_cv.rb +247 -0
  55. data/lib/rumale/model_selection/k_fold.rb +76 -0
  56. data/lib/rumale/model_selection/stratified_k_fold.rb +94 -0
  57. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +100 -0
  58. data/lib/rumale/naive_bayes/naive_bayes.rb +315 -0
  59. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +111 -0
  60. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +93 -0
  61. data/lib/rumale/optimizer/nadam.rb +90 -0
  62. data/lib/rumale/optimizer/rmsprop.rb +69 -0
  63. data/lib/rumale/optimizer/sgd.rb +65 -0
  64. data/lib/rumale/optimizer/yellow_fin.rb +144 -0
  65. data/lib/rumale/pairwise_metric.rb +91 -0
  66. data/lib/rumale/pipeline/pipeline.rb +197 -0
  67. data/lib/rumale/polynomial_model/base_factorization_machine.rb +99 -0
  68. data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +197 -0
  69. data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +131 -0
  70. data/lib/rumale/preprocessing/l2_normalizer.rb +62 -0
  71. data/lib/rumale/preprocessing/label_encoder.rb +94 -0
  72. data/lib/rumale/preprocessing/min_max_scaler.rb +92 -0
  73. data/lib/rumale/preprocessing/one_hot_encoder.rb +98 -0
  74. data/lib/rumale/preprocessing/standard_scaler.rb +86 -0
  75. data/lib/rumale/probabilistic_output.rb +112 -0
  76. data/lib/rumale/tree/base_decision_tree.rb +153 -0
  77. data/lib/rumale/tree/decision_tree_classifier.rb +163 -0
  78. data/lib/rumale/tree/decision_tree_regressor.rb +135 -0
  79. data/lib/rumale/tree/node.rb +70 -0
  80. data/lib/rumale/utils.rb +37 -0
  81. data/lib/rumale/validation.rb +79 -0
  82. data/lib/rumale/values.rb +13 -0
  83. data/lib/rumale/version.rb +6 -0
  84. data/rumale.gemspec +41 -0
  85. metadata +204 -0
@@ -0,0 +1,86 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/base_estimator'
4
+ require 'rumale/base/transformer'
5
+
6
module Rumale
  # This module consists of the classes that perform preprocessings.
  module Preprocessing
    # Normalize samples by centering and scaling to unit variance.
    #
    # @example
    #   normalizer = Rumale::Preprocessing::StandardScaler.new
    #   new_training_samples = normalizer.fit_transform(training_samples)
    #   new_testing_samples = normalizer.transform(testing_samples)
    class StandardScaler
      include Base::BaseEstimator
      include Base::Transformer

      # Return the vector consists of the mean value for each feature.
      # @return [Numo::DFloat] (shape: [n_features])
      attr_reader :mean_vec

      # Return the vector consists of the standard deviation for each feature.
      # Features that are constant in the fitted data have a standard deviation of zero;
      # such features are only centered (not divided) by {#transform}.
      # @return [Numo::DFloat] (shape: [n_features])
      attr_reader :std_vec

      # Create a new normalizer for centering and scaling to unit variance.
      def initialize
        @params = {}
        @mean_vec = nil
        @std_vec = nil
      end

      # Calculate the mean value and standard deviation of each feature for scaling.
      #
      # @overload fit(x) -> StandardScaler
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The samples to calculate the mean values and standard deviations.
      # @return [StandardScaler]
      def fit(x, _y = nil)
        check_sample_array(x)
        @mean_vec = x.mean(0)
        @std_vec = x.stddev(0)
        self
      end

      # Calculate the mean values and standard deviations, and then normalize samples using them.
      #
      # @overload fit_transform(x) -> Numo::DFloat
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The samples to calculate the mean values and standard deviations.
      # @return [Numo::DFloat] The scaled samples.
      def fit_transform(x, _y = nil)
        # Both #fit and #transform validate x, so no separate check is needed here.
        fit(x).transform(x)
      end

      # Perform standardization of the given samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be scaled.
      # @return [Numo::DFloat] The scaled samples.
      def transform(x)
        check_sample_array(x)
        n_samples, = x.shape
        (x - @mean_vec.tile(n_samples, 1)) / scale_vec.tile(n_samples, 1)
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about StandardScaler.
      def marshal_dump
        { mean_vec: @mean_vec,
          std_vec: @std_vec }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @mean_vec = obj[:mean_vec]
        @std_vec = obj[:std_vec]
        nil
      end

      private

      # Divisor used by #transform: the fitted standard deviations with zeros replaced by one.
      # Dividing by a zero standard deviation would turn constant features into NaN/Infinity.
      def scale_vec
        scale = @std_vec.dup
        zero_mask = scale.eq(0.0)
        scale[zero_mask] = 1.0 if zero_mask.count.positive?
        scale
      end
    end
  end
end
@@ -0,0 +1,112 @@
1
+ # frozen_string_literal: true
2
+
3
module Rumale
  # Module for calculating posterior class probabilities with SVM outputs.
  # This module is used for internal processes.
  #
  # @example
  #   estimator = Rumale::LinearModel::SVC.new
  #   estimator.fit(x, bin_y)
  #   df = estimator.decision_function(x)
  #   params = Rumale::ProbabilisticOutput.fit_sigmoid(df, bin_y)
  #   probs = 1 / (Numo::NMath.exp(params[0] * df + params[1]) + 1)
  #
  # *Reference*
  # 1. J C. Platt, "Probabilistic Outputs for Support Vector Machines and Comparisons to Regularized Likelihood Methods," Adv. Large Margin Classifiers, pp. 61--74, 2000.
  # 1. H-T Lin, C-J Lin, and R C.Weng, "A Note on Platt's Probabilistic Outputs for Support Vector Machines," J. Machine Learning, Vol. 63 (3), pp. 267--276, 2007.
  module ProbabilisticOutput
    class << self
      # Fit the probabilistic model for binary SVM outputs.
      #
      # The model is the sigmoid P(positive | f) = 1 / (1 + exp(alpha * f + beta)). Its parameters
      # are found by Newton's method combined with a backtracking line search.
      # The smallest label in +bin_y+ is treated as the negative class; every other label is positive.
      #
      # @param df [Numo::DFloat] (shape: [n_samples]) The outputs of decision function to be used for fitting the model.
      # @param bin_y [Numo::Int32] (shape: [n_samples]) The binary labels to be used for fitting the model.
      # @param max_iter [Integer] The maximum number of iterations.
      # @param min_step [Float] The minimum step of Newton's method.
      # @param sigma [Float] The parameter to avoid hessian matrix from becoming singular matrix.
      # @return [Numo::DFloat] (shape: 2) The parameters [alpha, beta] of the model.
      def fit_sigmoid(df, bin_y, max_iter = 100, min_step = 1e-10, sigma = 1e-12)
        negative_label = bin_y.to_a.uniq.min
        pos_mask = bin_y.ne(negative_label)
        neg_mask = bin_y.eq(negative_label)
        n_pos = pos_mask.count
        n_neg = neg_mask.count

        # Smoothed targets (instead of hard 0/1 labels) as proposed by Platt.
        targets = Numo::DFloat.zeros(bin_y.size)
        targets[pos_mask] = (n_pos + 1) / (n_pos + 2.0)
        targets[neg_mask] = 1 / (n_neg + 2.0)

        alpha = 0.0
        beta = Math.log((n_neg + 1) / (n_pos + 1.0))
        err = error_function(targets, df, alpha, beta)
        prev_grad = Numo::DFloat.zeros(2)

        max_iter.times do
          probs = predicted_probs(df, alpha, beta)
          grad = gradient(targets, probs, df)
          # Stop when the gradient is (almost) zero, or when it no longer changes between
          # iterations (which also happens after a line search that could not make progress).
          break if grad.abs.lt(1e-5).count == 2
          break if (prev_grad - grad).abs.sum < 1e-5

          prev_grad = grad
          newton_dir = directions(grad, hessian_matrix(probs, df, sigma))
          slope = grad.dot(newton_dir)
          alpha, beta, err = line_search(targets, df, alpha, beta, err, newton_dir, slope, min_step)
        end

        Numo::DFloat[alpha, beta]
      end

      private

      # Backtracking line search along the Newton direction.
      # Tries step sizes 1, 1/2, 1/4, ... and returns [alpha, beta, err] for the first step that
      # satisfies the sufficient-decrease condition; returns the inputs unchanged when none does.
      def line_search(targets, df, alpha, beta, err, newton_dir, slope, min_step)
        step = 2.0
        while step >= min_step
          step *= 0.5
          trial_alpha = alpha + step * newton_dir[0]
          trial_beta = beta + step * newton_dir[1]
          trial_err = error_function(targets, df, trial_alpha, trial_beta)
          return [trial_alpha, trial_beta, trial_err] if trial_err < err + 0.0001 * step * slope
        end
        [alpha, beta, err]
      end

      # Negative log-likelihood of the targets under the sigmoid with parameters alpha and beta.
      # The two branches are algebraically equal; the split keeps exp() from overflowing.
      def error_function(target_probs, df, alpha, beta)
        fn = alpha * df + beta
        nonneg = fn.ge(0.0)
        negative = fn.lt(0.0)
        total = 0.0
        if nonneg.count.positive?
          f = fn[nonneg]
          total += (target_probs[nonneg] * f + Numo::NMath.log(1 + Numo::NMath.exp(-f))).sum
        end
        if negative.count.positive?
          f = fn[negative]
          total += ((target_probs[negative] - 1) * f + Numo::NMath.log(1 + Numo::NMath.exp(f))).sum
        end
        total
      end

      # Sigmoid probabilities 1 / (1 + exp(alpha * df + beta)), evaluated in an overflow-safe form.
      def predicted_probs(df, alpha, beta)
        fn = alpha * df + beta
        probs = Numo::DFloat.zeros(df.shape[0])
        nonneg = fn.ge(0.0)
        if nonneg.count.positive?
          decay = Numo::NMath.exp(-fn[nonneg])
          probs[nonneg] = decay / (1 + decay)
        end
        negative = fn.lt(0.0)
        probs[negative] = 1 / (1 + Numo::NMath.exp(fn[negative])) if negative.count.positive?
        probs
      end

      # Gradient of the error function with respect to [alpha, beta].
      def gradient(target_probs, probs, df)
        residual = target_probs - probs
        Numo::DFloat[(df * residual).sum, residual.sum]
      end

      # Hessian of the error function; sigma is added to the diagonal to keep it non-singular.
      def hessian_matrix(probs, df, sigma)
        weights = probs * (1 - probs)
        h11 = (df * df * weights).sum + sigma
        h21 = (df * weights).sum
        h22 = weights.sum + sigma
        Numo::DFloat[[h11, h21], [h21, h22]]
      end

      # Newton direction -H^-1 * g, using the closed-form inverse of a 2x2 matrix.
      def directions(grad_vec, hess_mat)
        h00 = hess_mat[0, 0]
        h01 = hess_mat[0, 1]
        h10 = hess_mat[1, 0]
        h11 = hess_mat[1, 1]
        det = h00 * h11 - h01 * h10
        inv_hess = Numo::DFloat[[h11, -h01], [-h10, h00]] / det
        -inv_hess.dot(grad_vec)
      end
    end
  end
end
@@ -0,0 +1,153 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/base_estimator'
4
+ require 'rumale/tree/node'
5
+
6
module Rumale
  # This module consists of the classes that implement tree models.
  module Tree
    # BaseDecisionTree is an abstract class for implementation of decision tree-based estimator.
    # This class is used internally. Subclasses must implement the private methods
    # +stop_growing?(y)+, +put_leaf(node, y)+ and +impurity(targets)+.
    class BaseDecisionTree
      include Base::BaseEstimator

      # Initialize a decision tree-based estimator.
      #
      # @param criterion [String] The function to evaluate splitting point.
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, decision tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, split process considers all features.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   It is used to randomly determine the order of features when deciding splitting point.
      #   If nil is given, the value returned by Kernel#srand is used.
      def initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil)
        @params = {}
        @params[:criterion] = criterion
        @params[:max_depth] = max_depth
        @params[:max_leaf_nodes] = max_leaf_nodes
        @params[:min_samples_leaf] = min_samples_leaf
        @params[:max_features] = max_features
        @params[:random_seed] = random_seed
        @params[:random_seed] ||= srand
        @tree = nil
        @feature_importances = nil
        @n_leaves = nil
        @rng = Random.new(@params[:random_seed])
      end

      # Return the index of the leaf that each sample reached.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Leaf index for sample.
      def apply(x)
        check_sample_array(x)
        Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })]
      end

      private

      # Descend from node to the leaf reached by sample and return that leaf's id.
      # A split with only one surviving child routes every sample to that child.
      def apply_at_node(node, sample)
        return node.leaf_id if node.leaf
        return apply_at_node(node.left, sample) if node.right.nil?
        return apply_at_node(node.right, sample) if node.left.nil?
        if sample[node.feature_id] <= node.threshold
          apply_at_node(node.left, sample)
        else
          apply_at_node(node.right, sample)
        end
      end

      # Grow the whole tree from the training data and store its root in @tree.
      def build_tree(x, y)
        y = y.expand_dims(1).dup if y.shape[1].nil?
        whole_impurity = impurity(y)
        @tree = grow_node(0, x, y, whole_impurity)
        # The root is discarded when there are fewer training samples than min_samples_leaf;
        # fall back to a single-leaf tree so that the fitted estimator can still predict.
        @tree ||= put_leaf(Node.new(depth: 0, impurity: whole_impurity, n_samples: x.shape[0]), y)
        nil
      end

      # Recursively grow a (sub)tree. Returns nil when the node is discarded: the leaf budget is
      # exhausted, or the node holds fewer than min_samples_leaf samples.
      def grow_node(depth, x, y, whole_impurity)
        return nil if !@params[:max_leaf_nodes].nil? && @n_leaves >= @params[:max_leaf_nodes]

        n_samples, n_features = x.shape
        return nil if n_samples < @params[:min_samples_leaf]

        node = Node.new(depth: depth, impurity: whole_impurity, n_samples: n_samples)

        # Splitting a node with exactly min_samples_leaf samples would leave both children too small.
        return put_leaf(node, y) if n_samples == @params[:min_samples_leaf]
        return put_leaf(node, y) if stop_growing?(y)
        return put_leaf(node, y) if !@params[:max_depth].nil? && depth == @params[:max_depth]

        feature_id, threshold, left_ids, right_ids, left_impurity, right_impurity, gain =
          find_split(x, y, whole_impurity, n_features)

        return put_leaf(node, y) if gain.nil? || gain.zero?

        node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids, true], left_impurity)
        node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids, true], right_impurity)

        return put_leaf(node, y) if node.left.nil? && node.right.nil?

        node.feature_id = feature_id
        node.threshold = threshold
        node.leaf = false
        node
      end

      # Search randomly chosen features for the split with the largest impurity gain.
      # Returns [feature_id, threshold, left_ids, right_ids, left_impurity, right_impurity, gain],
      # or nil when none of the chosen features takes more than one distinct value.
      def find_split(x, y, whole_impurity, n_features)
        candidates = rand_ids(n_features).map do |f_id|
          split = best_split(x[true, f_id], y, whole_impurity)
          # Features that are constant within this node cannot be split; without this guard
          # the bare [f_id] entry would be ranked by its feature index instead of a gain.
          [f_id, *split] unless split.nil?
        end
        candidates.compact.max_by(&:last)
      end

      def stop_growing?(_y)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      def put_leaf(_node, _y)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Randomly pick max_features distinct feature indices out of 0...n.
      def rand_ids(n)
        [*0...n].sample(@params[:max_features], random: @rng)
      end

      # Find the best threshold on a single feature. Candidate thresholds are the midpoints between
      # consecutive distinct feature values. Returns
      # [threshold, left_ids, right_ids, left_impurity, right_impurity, gain], or nil when the
      # feature has a single distinct value.
      def best_split(features, targets, whole_impurity)
        n_samples = targets.shape[0]
        features.to_a.uniq.sort.each_cons(2).map do |l, r|
          threshold = 0.5 * (l + r)
          left_ids = features.le(threshold).where
          right_ids = features.gt(threshold).where
          left_impurity = impurity(targets[left_ids, true])
          right_impurity = impurity(targets[right_ids, true])
          gain = whole_impurity -
                 left_impurity * left_ids.size.fdiv(n_samples) -
                 right_impurity * right_ids.size.fdiv(n_samples)
          [threshold, left_ids, right_ids, left_impurity, right_impurity, gain]
        end.max_by(&:last)
      end

      def impurity(_targets)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Compute normalized impurity-based feature importances into @feature_importances.
      def eval_importance(n_samples, n_features)
        @feature_importances = Numo::DFloat.zeros(n_features)
        eval_importance_at_node(@tree)
        @feature_importances /= n_samples
        normalizer = @feature_importances.sum
        @feature_importances /= normalizer if normalizer > 0.0
        nil
      end

      # Accumulate the weighted impurity decrease of every two-child split below node.
      def eval_importance_at_node(node)
        return nil if node.nil? || node.leaf

        if node.left.nil? || node.right.nil?
          # Only one child survived: this node does not split the data, so it earns no gain,
          # but the surviving subtree may still contain real splits.
          eval_importance_at_node(node.left || node.right)
          return nil
        end

        gain = node.n_samples * node.impurity -
               node.left.n_samples * node.left.impurity -
               node.right.n_samples * node.right.impurity
        @feature_importances[node.feature_id] += gain
        eval_importance_at_node(node.left)
        eval_importance_at_node(node.right)
      end
    end
  end
end
@@ -0,0 +1,163 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/tree/base_decision_tree'
4
+ require 'rumale/base/classifier'
5
+
6
module Rumale
  module Tree
    # DecisionTreeClassifier is a class that implements decision tree for classification.
    #
    # @example
    #   estimator =
    #     Rumale::Tree::DecisionTreeClassifier.new(
    #       criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
    #   estimator.fit(training_samples, training_labels)
    #   results = estimator.predict(testing_samples)
    #
    class DecisionTreeClassifier < BaseDecisionTree
      include Base::Classifier

      # Return the class labels.
      # @return [Numo::Int32] (size: n_classes)
      attr_reader :classes

      # Return the importance for each feature.
      # @return [Numo::DFloat] (size: n_features)
      attr_reader :feature_importances

      # Return the learned tree.
      # @return [Node]
      attr_reader :tree

      # Return the random generator for random selection of feature index.
      # @return [Random]
      attr_reader :rng

      # Return the labels assigned each leaf.
      # @return [Numo::Int32] (size: n_leafs)
      attr_reader :leaf_labels

      # Create a new classifier with decision tree algorithm.
      #
      # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'gini' and 'entropy'.
      #   Any other value is treated as 'gini'.
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, decision tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, split process considers all features.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   It is used to randomly determine the order of features when deciding splitting point.
      def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
                     random_seed: nil)
        check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                          max_features: max_features, random_seed: random_seed)
        check_params_integer(min_samples_leaf: min_samples_leaf)
        check_params_string(criterion: criterion)
        check_params_positive(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                              min_samples_leaf: min_samples_leaf, max_features: max_features)
        super
        @leaf_labels = nil
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [DecisionTreeClassifier] The learned classifier itself.
      def fit(x, y)
        check_sample_array(x)
        check_label_array(y)
        check_sample_label_size(x, y)
        n_samples, n_features = x.shape
        @params[:max_features] = n_features if @params[:max_features].nil?
        @params[:max_features] = [@params[:max_features], n_features].min
        uniq_y = y.to_a.uniq.sort
        @classes = Numo::Int32.asarray(uniq_y)
        # Map each label to its index in @classes with a hash lookup instead of an
        # Array#index scan per sample.
        label_ids = uniq_y.each_with_index.to_h
        @n_leaves = 0
        @leaf_labels = []
        build_tree(x, y.map { |v| label_ids[v] })
        eval_importance(n_samples, n_features)
        @leaf_labels = Numo::Int32[*@leaf_labels]
        self
      end

      # Predict class labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
      def predict(x)
        check_sample_array(x)
        @leaf_labels[apply(x)]
      end

      # Predict probability for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
      def predict_proba(x)
        check_sample_array(x)
        Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_proba_at_node(@tree, x[n, true]) })]
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about DecisionTreeClassifier
      def marshal_dump
        { params: @params,
          classes: @classes,
          tree: @tree,
          feature_importances: @feature_importances,
          leaf_labels: @leaf_labels,
          rng: @rng }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @classes = obj[:classes]
        @tree = obj[:tree]
        @feature_importances = obj[:feature_importances]
        @leaf_labels = obj[:leaf_labels]
        @rng = obj[:rng]
        nil
      end

      private

      # Descend to the leaf reached by sample and return that leaf's class probabilities.
      def predict_proba_at_node(node, sample)
        return node.probs if node.leaf
        return predict_proba_at_node(node.left, sample) if node.right.nil?
        return predict_proba_at_node(node.right, sample) if node.left.nil?
        if sample[node.feature_id] <= node.threshold
          predict_proba_at_node(node.left, sample)
        else
          predict_proba_at_node(node.right, sample)
        end
      end

      # A node stops growing once all of its samples share one label index.
      def stop_growing?(y)
        y.flatten.to_a.uniq.size == 1
      end

      # Turn node into a leaf storing the class distribution of y and record its majority label.
      def put_leaf(node, y)
        node.probs = y.flatten.bincount(minlength: @classes.size) / node.n_samples.to_f
        node.leaf = true
        node.leaf_id = @n_leaves
        @n_leaves += 1
        @leaf_labels.push(@classes[node.probs.max_index])
        node
      end

      # Impurity of the label indices in y: Shannon entropy -sum(p * log(p)) for 'entropy',
      # Gini impurity 1 - sum(p^2) otherwise.
      def impurity(y)
        posterior_probs = y.flatten.bincount / y.size.to_f
        if @params[:criterion] == 'entropy'
          # Drop absent classes (p = 0) so that 0 * log(0) contributes 0 instead of NaN.
          present_probs = posterior_probs[posterior_probs.gt(0.0)]
          -(present_probs * Numo::NMath.log(present_probs)).sum
        else
          1.0 - (posterior_probs * posterior_probs).sum
        end
      end
    end
  end
end