rumale-tree 0.24.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,154 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/validation'
5
+ require 'rumale/tree/ext'
6
+ require 'rumale/tree/node'
7
+
8
module Rumale
  module Tree
    # BaseDecisionTree is the abstract superclass of the decision tree-based estimators.
    # It is used internally. The hook methods stop_growing?, put_leaf, best_split and impurity
    # raise NotImplementedError here and have to be supplied by the concrete estimators.
    class BaseDecisionTree < ::Rumale::Base::Estimator
      # Set up the hyperparameters and the random generator of a decision tree-based estimator.
      #
      # @param criterion [String] The function to evaluate splitting point.
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, decision tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, split process considers all features.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   It is used to randomly determine the order of features when deciding splitting point.
      #   If nil is given, the value returned by Kernel#srand is used instead.
      def initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
                     random_seed: nil)
        super()
        seed = random_seed || srand
        @params = {
          criterion: criterion,
          max_depth: max_depth,
          max_leaf_nodes: max_leaf_nodes,
          min_samples_leaf: min_samples_leaf,
          max_features: max_features,
          random_seed: seed
        }
        @rng = Random.new(seed)
      end

      # Return the index of the leaf that each sample reached.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Leaf index for sample.
      def apply(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        leaf_ids = Array.new(x.shape[0]) { |row| partial_apply(@tree, x[row, true]) }
        Numo::Int32[*leaf_ids]
      end

      private

      # Walk down from the given root until a leaf is reached and return that leaf's id.
      # A node with only one child is passed through without testing the sample.
      def partial_apply(tree, sample)
        current = tree
        until current.leaf
          current =
            if current.right.nil?
              current.left
            elsif current.left.nil?
              current.right
            elsif sample[current.feature_id] <= current.threshold
              current.left
            else
              current.right
            end
        end
        current.leaf_id
      end

      # Grow the whole tree from the training data and store its root in @tree.
      # One-dimensional targets are expanded to a column first. The list of feature
      # indices is only kept while the tree is being grown.
      def build_tree(x, y)
        y = y.expand_dims(1).dup unless y.shape[1]
        @feature_ids = Array.new(x.shape[1], &:itself)
        @tree = grow_node(0, x, y, impurity(y))
        @feature_ids = nil
        nil
      end

      # Recursively grow the node for the given subset of samples.
      # Returns nil when the leaf budget is exhausted or the subset is smaller than
      # min_samples_leaf; otherwise returns either a leaf (via put_leaf) or an internal node.
      # @n_leaves is read here and is expected to be initialized before growing starts.
      def grow_node(depth, x, y, impurity) # rubocop:disable Metrics/AbcSize, Metrics/PerceivedComplexity
        # initialize node.
        n_samples = x.shape[0]
        node = Node.new(depth: depth, impurity: impurity, n_samples: n_samples)

        # terminate growing.
        leaf_limit = @params[:max_leaf_nodes]
        return nil if !leaf_limit.nil? && @n_leaves >= leaf_limit

        min_leaf = @params[:min_samples_leaf]
        return nil if n_samples < min_leaf

        depth_limit = @params[:max_depth]
        at_limit = n_samples == min_leaf || (!depth_limit.nil? && depth == depth_limit)
        return put_leaf(node, y) if at_limit || stop_growing?(y)

        # calculate optimal parameters: keep the candidate feature whose split has the largest gain.
        candidates = rand_ids.map { |fid| [fid, *best_split(x[true, fid], y, impurity)] }
        feature_id, left_imp, right_imp, threshold, gain = candidates.max_by(&:last)

        return put_leaf(node, y) if gain.nil? || gain.zero?

        column = x[true, feature_id]
        left_ids = column.le(threshold).where
        right_ids = column.gt(threshold).where
        flat_target = y.ndim == 1
        # the left subtree is grown completely before the right one.
        node.left = grow_node(depth + 1, x[left_ids, true], flat_target ? y[left_ids] : y[left_ids, true], left_imp)
        node.right = grow_node(depth + 1, x[right_ids, true], flat_target ? y[right_ids] : y[right_ids, true],
                               right_imp)

        # neither child could be grown, so this node becomes a leaf itself.
        return put_leaf(node, y) if node.left.nil? && node.right.nil?

        node.feature_id = feature_id
        node.threshold = threshold
        node.leaf = false
        node
      end

      # Hook: decide whether the samples with targets _y should not be split any further.
      def stop_growing?(_y)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Hook: turn _node into a leaf for the samples with targets _y.
      def put_leaf(_node, _y)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Randomly pick max_features feature indices using the sub random generator.
      def rand_ids
        @feature_ids.sample(@params[:max_features], random: @sub_rng)
      end

      # Hook: find the split parameters of a single feature column.
      # grow_node destructures the result as [left impurity, right impurity, threshold, gain].
      def best_split(_features, _y, _impurity)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Hook: calculate the impurity of the targets _y.
      def impurity(_y)
        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
      end

      # Accumulate the impurity decrease of every split per feature into @feature_importances,
      # divide it by the number of samples, and scale it to sum to one when the sum is positive.
      def eval_importance(n_samples, n_features)
        @feature_importances = Numo::DFloat.zeros(n_features)
        eval_importance_at_node(@tree)
        @feature_importances = @feature_importances / n_samples
        total = @feature_importances.sum
        @feature_importances = @feature_importances / total if total > 0.0
        nil
      end

      # Add the sample-weighted impurity decrease of the given node to its split feature,
      # then descend into both children. Leaves and nodes with a missing child are skipped.
      def eval_importance_at_node(node)
        return nil if node.leaf

        left = node.left
        right = node.right
        return nil if left.nil? || right.nil?

        weighted = ->(n) { n.n_samples * n.impurity }
        gain = weighted.call(node) - weighted.call(left) - weighted.call(right)
        @feature_importances[node.feature_id] += gain
        eval_importance_at_node(left)
        eval_importance_at_node(right)
      end
    end
  end
end
@@ -0,0 +1,148 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/tree/base_decision_tree'
4
+ require 'rumale/base/classifier'
5
+
6
module Rumale
  module Tree
    # DecisionTreeClassifier is a class that implements decision tree for classification.
    #
    # @example
    #   require 'rumale/tree/decision_tree_classifier'
    #
    #   estimator =
    #     Rumale::Tree::DecisionTreeClassifier.new(
    #       criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
    #   estimator.fit(training_samples, training_labels)
    #   results = estimator.predict(testing_samples)
    #
    class DecisionTreeClassifier < BaseDecisionTree
      include ::Rumale::Base::Classifier
      include ::Rumale::Tree::ExtDecisionTreeClassifier

      # Return the class labels.
      # @return [Numo::Int32] (size: n_classes)
      attr_reader :classes

      # Return the importance for each feature.
      # @return [Numo::DFloat] (size: n_features)
      attr_reader :feature_importances

      # Return the learned tree.
      # @return [Node]
      attr_reader :tree

      # Return the random generator for random selection of feature index.
      # @return [Random]
      attr_reader :rng

      # Return the labels assigned each leaf.
      # @return [Numo::Int32] (size: n_leaves)
      attr_reader :leaf_labels

      # Create a new classifier with decision tree algorithm.
      #
      # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'gini' and 'entropy'.
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, decision tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, split process considers all features.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   It is used to randomly determine the order of features when deciding splitting point.
      def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
                     random_seed: nil)
        super
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [DecisionTreeClassifier] The learned classifier itself.
      def fit(x, y)
        x = ::Rumale::Validation.check_convert_sample_array(x)
        y = ::Rumale::Validation.check_convert_label_array(y)
        ::Rumale::Validation.check_sample_size(x, y)

        n_samples, n_features = x.shape
        # NOTE: the resolved value is written back into the hyperparameters,
        # so it is also what a later call of fit starts from.
        @params[:max_features] = n_features if @params[:max_features].nil?
        @params[:max_features] = [@params[:max_features], n_features].min
        y = Numo::Int32.cast(y) unless y.is_a?(Numo::Int32)
        uniq_y = y.to_a.uniq.sort
        @classes = Numo::Int32.asarray(uniq_y)
        @n_leaves = 0
        @leaf_labels = []
        @feature_ids = Array.new(n_features) { |v| v }
        # growing works on a copy so that the state of @rng is left untouched by fit.
        @sub_rng = @rng.dup
        # the tree is grown on class indices (positions in the sorted unique labels).
        # A lookup table makes the re-encoding one hash access per sample instead of
        # a linear Array#index scan over the classes for every sample.
        class_index = uniq_y.each_with_index.to_h
        build_tree(x, y.map { |label| class_index[label] })
        eval_importance(n_samples, n_features)
        @leaf_labels = Numo::Int32[*@leaf_labels]
        self
      end

      # Predict class labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
      def predict(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        @leaf_labels[apply(x)].dup
      end

      # Predict probability for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
      def predict_proba(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        probs = Array.new(x.shape[0]) { |n| partial_predict_proba(@tree, x[n, true]) }
        Numo::DFloat[*probs]
      end

      private

      # NOTE(review): stop_growing?, find_split_params and node_impurity are not defined in this class;
      # presumably they are supplied by the included ExtDecisionTreeClassifier module - confirm.

      # Walk down from the given root until a leaf is reached and return the class
      # probabilities stored on that leaf. A node with only one child is passed through
      # without testing the sample.
      def partial_predict_proba(tree, sample)
        node = tree
        until node.leaf
          node = if node.right.nil?
                   node.left
                 elsif node.left.nil?
                   node.right
                 else
                   sample[node.feature_id] <= node.threshold ? node.left : node.right
                 end
        end
        node.probs
      end

      # Grow the whole tree and store its root in @tree.
      # Unlike the base implementation, y is passed on as given (one-dimensional class indices)
      # and the list of feature indices prepared by fit is left in place.
      def build_tree(x, y)
        @tree = grow_node(0, x, y, impurity(y))
        nil
      end

      # Turn the node into a leaf: store the relative class frequencies of y on it,
      # give it the next leaf id, and record the most frequent class as the label of that leaf.
      def put_leaf(node, y)
        node.probs = y.bincount(minlength: @classes.size) / node.n_samples.to_f
        node.leaf = true
        node.leaf_id = @n_leaves
        @n_leaves += 1
        @leaf_labels.push(@classes[node.probs.max_index])
        node
      end

      # Find the split parameters of a single feature column by delegating to find_split_params
      # with the column's sort order.
      def best_split(features, y, whole_impurity)
        order = features.sort_index
        n_classes = @classes.size
        find_split_params(@params[:criterion], whole_impurity, order, features, y, n_classes)
      end

      # Calculate the impurity of the class indices y under the configured criterion.
      def impurity(y)
        n_classes = @classes.size
        node_impurity(@params[:criterion], y, n_classes)
      end
    end
  end
end
@@ -0,0 +1,113 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/tree/base_decision_tree'
4
+ require 'rumale/base/regressor'
5
+
6
module Rumale
  module Tree
    # DecisionTreeRegressor is a class that implements decision tree for regression.
    #
    # @example
    #   require 'rumale/tree/decision_tree_regressor'
    #
    #   estimator =
    #     Rumale::Tree::DecisionTreeRegressor.new(
    #       max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
    #   estimator.fit(training_samples, training_values)
    #   results = estimator.predict(testing_samples)
    #
    class DecisionTreeRegressor < BaseDecisionTree
      include ::Rumale::Base::Regressor
      include ::Rumale::Tree::ExtDecisionTreeRegressor

      # Return the importance for each feature.
      # @return [Numo::DFloat] (size: n_features)
      attr_reader :feature_importances

      # Return the learned tree.
      # @return [Node]
      attr_reader :tree

      # Return the random generator for random selection of feature index.
      # @return [Random]
      attr_reader :rng

      # Return the values assigned each leaf.
      # @return [Numo::DFloat] (shape: [n_leaves, n_outputs])
      attr_reader :leaf_values

      # Create a new regressor with decision tree algorithm.
      #
      # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'mae' and 'mse'.
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, decision tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, split process considers all features.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   It is used to randomly determine the order of features when deciding splitting point.
      def initialize(criterion: 'mse', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
                     random_seed: nil)
        super
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
      # @return [DecisionTreeRegressor] The learned regressor itself.
      def fit(x, y)
        x = ::Rumale::Validation.check_convert_sample_array(x)
        y = ::Rumale::Validation.check_convert_target_value_array(y)
        ::Rumale::Validation.check_sample_size(x, y)

        n_samples, n_features = x.shape
        # the resolved number of candidate features is written back into the hyperparameters.
        @params[:max_features] = n_features if @params[:max_features].nil?
        @params[:max_features] = [@params[:max_features], n_features].min
        @n_leaves = 0
        @leaf_values = []
        # growing works on a copy so that the state of @rng is left untouched by fit.
        @sub_rng = @rng.dup
        build_tree(x, y)
        eval_importance(n_samples, n_features)
        stacked = Numo::DFloat.cast(@leaf_values)
        @leaf_values = stacked
        # a single output column is stored as a plain vector of leaf values.
        @leaf_values = stacked.flatten.dup if stacked.shape[1] == 1
        self
      end

      # Predict values for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
      # @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.
      def predict(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        single_output = @leaf_values.shape[1].nil?
        leaf_ids = apply(x)
        picked = single_output ? @leaf_values[leaf_ids] : @leaf_values[leaf_ids, true]
        picked.dup
      end

      private

      # Growing stops when all target rows of the node are identical.
      def stop_growing?(y)
        y.to_a.uniq.length == 1
      end

      # Turn the node into a leaf: give it the next leaf id and record the mean of y
      # along the sample axis as the value of that leaf.
      def put_leaf(node, y)
        node.probs = nil
        node.leaf = true
        node.leaf_id = @n_leaves
        @n_leaves += 1
        @leaf_values.push(y.mean(0))
        node
      end

      # Find the split parameters of a single feature column by delegating to find_split_params
      # with the column's sort order.
      # NOTE(review): find_split_params and node_impurity are not defined in this class;
      # presumably they are supplied by the included ExtDecisionTreeRegressor module - confirm.
      def best_split(features, y, impurity)
        order = features.sort_index
        find_split_params(@params[:criterion], impurity, order, features, y)
      end

      # Calculate the impurity of the target values y under the configured criterion.
      def impurity(y)
        targets = y.to_a
        node_impurity(@params[:criterion], targets)
      end
    end
  end
end
@@ -0,0 +1,89 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/tree/decision_tree_classifier'
4
+
5
module Rumale
  module Tree
    # ExtraTreeClassifier is a class that implements extra randomized tree for classification.
    #
    # @example
    #   require 'rumale/tree/extra_tree_classifier'
    #
    #   estimator =
    #     Rumale::Tree::ExtraTreeClassifier.new(
    #       criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
    #   estimator.fit(training_samples, training_labels)
    #   results = estimator.predict(testing_samples)
    #
    # *Reference*
    # - Geurts, P., Ernst, D., and Wehenkel, L., "Extremely randomized trees," Machine Learning, vol. 63 (1), pp. 3--42, 2006.
    class ExtraTreeClassifier < DecisionTreeClassifier
      # Return the class labels.
      # @return [Numo::Int32] (size: n_classes)
      attr_reader :classes

      # Return the importance for each feature.
      # @return [Numo::DFloat] (size: n_features)
      attr_reader :feature_importances

      # Return the learned tree.
      # @return [Node]
      attr_reader :tree

      # Return the random generator for random selection of feature index.
      # @return [Random]
      attr_reader :rng

      # Return the labels assigned each leaf.
      # @return [Numo::Int32] (size: n_leaves)
      attr_reader :leaf_labels

      # Create a new classifier with extra randomized tree algorithm.
      #
      # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'gini' and 'entropy'.
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, extra tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on extra tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, split process considers all features.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   It is used to randomly determine the order of features when deciding splitting point.
      def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
                     random_seed: nil)
        super
      end

      # The public methods below are not redefined in this class; they are inherited
      # from DecisionTreeClassifier and are listed here for reference only.
      #
      # fit(x, y): Fit the model with given training data.
      #   x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      #   y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      #   Returns [ExtraTreeClassifier] The learned classifier itself.
      #
      # predict(x): Predict class labels for samples.
      #   x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      #   Returns [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
      #
      # predict_proba(x): Predict probability for samples.
      #   x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
      #   Returns [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.

      private

      # Evaluate one randomly drawn threshold for a single feature column.
      # The threshold is drawn from the range between the smallest and the largest value of the column.
      # An empty side contributes an impurity of 0.0.
      # Returns [left impurity, right impurity, threshold, gain], where gain is the impurity of
      # the whole node minus the impurities of both sides weighted by their share of the samples.
      def best_split(features, y, whole_impurity)
        threshold = @sub_rng.rand(features.min..features.max)
        left_rows = features.le(threshold).where
        right_rows = features.gt(threshold).where

        left_impurity = if left_rows.empty?
                          0.0
                        else
                          impurity(y[left_rows])
                        end
        right_impurity = if right_rows.empty?
                           0.0
                         else
                           impurity(y[right_rows])
                         end

        left_weight = left_rows.size.fdiv(y.size)
        right_weight = right_rows.size.fdiv(y.size)
        gain = whole_impurity - left_impurity * left_weight - right_impurity * right_weight
        [left_impurity, right_impurity, threshold, gain]
      end
    end
  end
end
@@ -0,0 +1,80 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/tree/decision_tree_regressor'
4
+
5
module Rumale
  module Tree
    # ExtraTreeRegressor is a class that implements extra randomized tree for regression.
    #
    # @example
    #   require 'rumale/tree/extra_tree_regressor'
    #
    #   estimator =
    #     Rumale::Tree::ExtraTreeRegressor.new(
    #       max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
    #   estimator.fit(training_samples, training_values)
    #   results = estimator.predict(testing_samples)
    #
    # *Reference*
    # - Geurts, P., Ernst, D., and Wehenkel, L., "Extremely randomized trees," Machine Learning, vol. 63 (1), pp. 3--42, 2006.
    class ExtraTreeRegressor < DecisionTreeRegressor
      # Return the importance for each feature.
      # @return [Numo::DFloat] (size: n_features)
      attr_reader :feature_importances

      # Return the learned tree.
      # @return [Node]
      attr_reader :tree

      # Return the random generator for random selection of feature index.
      # @return [Random]
      attr_reader :rng

      # Return the values assigned each leaf.
      # @return [Numo::DFloat] (shape: [n_leaves, n_outputs])
      attr_reader :leaf_values

      # Create a new regressor with extra randomized tree algorithm.
      #
      # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'mae' and 'mse'.
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, extra tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on extra tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, split process considers all features.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   It is used to randomly determine the order of features when deciding splitting point.
      def initialize(criterion: 'mse', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
                     random_seed: nil)
        super
      end

      # The public methods below are not redefined in this class; they are inherited
      # from DecisionTreeRegressor and are listed here for reference only.
      #
      # fit(x, y): Fit the model with given training data.
      #   x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      #   y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
      #   Returns [ExtraTreeRegressor] The learned regressor itself.
      #
      # predict(x): Predict values for samples.
      #   x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
      #   Returns [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.

      private

      # Evaluate one randomly drawn threshold for a single feature column.
      # The threshold is drawn from the range between the smallest and the largest value of the column.
      # An empty side contributes an impurity of 0.0.
      # Returns [left impurity, right impurity, threshold, gain], where gain is the impurity of
      # the whole node minus the impurities of both sides weighted by their share of the rows of y.
      def best_split(features, y, whole_impurity)
        threshold = @sub_rng.rand(features.min..features.max)
        left_rows = features.le(threshold).where
        right_rows = features.gt(threshold).where

        left_impurity = if left_rows.empty?
                          0.0
                        else
                          impurity(y[left_rows, true])
                        end
        right_impurity = if right_rows.empty?
                           0.0
                         else
                           impurity(y[right_rows, true])
                         end

        left_weight = left_rows.size.fdiv(y.shape[0])
        right_weight = right_rows.size.fdiv(y.shape[0])
        gain = whole_impurity - left_impurity * left_weight - right_impurity * right_weight
        [left_impurity, right_impurity, threshold, gain]
      end
    end
  end
end