rumale-tree 0.24.0

lib/rumale/tree/base_decision_tree.rb
@@ -0,0 +1,154 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/estimator'
+ require 'rumale/validation'
+ require 'rumale/tree/ext'
+ require 'rumale/tree/node'
+
+ module Rumale
+   module Tree
+     # BaseDecisionTree is an abstract class for implementation of decision tree-based estimators.
+     # This class is used internally.
+     class BaseDecisionTree < ::Rumale::Base::Estimator
+       # Initialize a decision tree-based estimator.
+       #
+       # @param criterion [String] The function to evaluate the splitting point.
+       # @param max_depth [Integer] The maximum depth of the tree.
+       #   If nil is given, the decision tree grows without concern for depth.
+       # @param max_leaf_nodes [Integer] The maximum number of leaves on the decision tree.
+       #   If nil is given, the number of leaves is not limited.
+       # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
+       # @param max_features [Integer] The number of features to consider when searching for the optimal split point.
+       #   If nil is given, the split process considers all features.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       #   It is used to randomly determine the order of features when deciding the splitting point.
+       def initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
+                      random_seed: nil)
+         super()
+         @params = {
+           criterion: criterion,
+           max_depth: max_depth,
+           max_leaf_nodes: max_leaf_nodes,
+           min_samples_leaf: min_samples_leaf,
+           max_features: max_features,
+           random_seed: random_seed || srand
+         }
+         @rng = Random.new(@params[:random_seed])
+       end
+
+       # Return the index of the leaf that each sample reached.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
+       # @return [Numo::Int32] (shape: [n_samples]) Leaf index for each sample.
+       def apply(x)
+         x = ::Rumale::Validation.check_convert_sample_array(x)
+
+         Numo::Int32[*(Array.new(x.shape[0]) { |n| partial_apply(@tree, x[n, true]) })]
+       end
+
+       private
+
+       # Walk a single sample down the tree and return the id of the leaf it reaches.
+       def partial_apply(tree, sample)
+         node = tree
+         until node.leaf
+           node = if node.right.nil?
+                    node.left
+                  elsif node.left.nil?
+                    node.right
+                  else
+                    sample[node.feature_id] <= node.threshold ? node.left : node.right
+                  end
+         end
+         node.leaf_id
+       end
+
+       def build_tree(x, y)
+         y = y.expand_dims(1).dup if y.shape[1].nil?
+         @feature_ids = Array.new(x.shape[1]) { |v| v }
+         @tree = grow_node(0, x, y, impurity(y))
+         @feature_ids = nil
+         nil
+       end
+
+       def grow_node(depth, x, y, impurity) # rubocop:disable Metrics/AbcSize, Metrics/PerceivedComplexity
+         # Initialize node.
+         n_samples = x.shape[0]
+         node = Node.new(depth: depth, impurity: impurity, n_samples: n_samples)
+
+         # Terminate growing.
+         return nil if !@params[:max_leaf_nodes].nil? && @n_leaves >= @params[:max_leaf_nodes]
+         return nil if n_samples < @params[:min_samples_leaf]
+         return put_leaf(node, y) if n_samples == @params[:min_samples_leaf]
+         return put_leaf(node, y) if !@params[:max_depth].nil? && depth == @params[:max_depth]
+         return put_leaf(node, y) if stop_growing?(y)
+
+         # Calculate optimal parameters.
+         feature_id, left_imp, right_imp, threshold, gain =
+           rand_ids.map { |n| [n, *best_split(x[true, n], y, impurity)] }.max_by(&:last)
+
+         return put_leaf(node, y) if gain.nil? || gain.zero?
+
+         left_ids = x[true, feature_id].le(threshold).where
+         right_ids = x[true, feature_id].gt(threshold).where
+         node.left = if y.ndim == 1
+                       grow_node(depth + 1, x[left_ids, true], y[left_ids], left_imp)
+                     else
+                       grow_node(depth + 1, x[left_ids, true], y[left_ids, true], left_imp)
+                     end
+         node.right = if y.ndim == 1
+                        grow_node(depth + 1, x[right_ids, true], y[right_ids], right_imp)
+                      else
+                        grow_node(depth + 1, x[right_ids, true], y[right_ids, true], right_imp)
+                      end
+
+         return put_leaf(node, y) if node.left.nil? && node.right.nil?
+
+         node.feature_id = feature_id
+         node.threshold = threshold
+         node.leaf = false
+         node
+       end
+
+       def stop_growing?(_y)
+         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+       end
+
+       def put_leaf(_node, _y)
+         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+       end
+
+       def rand_ids
+         @feature_ids.sample(@params[:max_features], random: @sub_rng)
+       end
+
+       def best_split(_features, _y, _impurity)
+         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+       end
+
+       def impurity(_y)
+         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
+       end
+
+       def eval_importance(n_samples, n_features)
+         @feature_importances = Numo::DFloat.zeros(n_features)
+         eval_importance_at_node(@tree)
+         @feature_importances /= n_samples
+         normalizer = @feature_importances.sum
+         @feature_importances /= normalizer if normalizer > 0.0
+         nil
+       end
+
+       def eval_importance_at_node(node)
+         return nil if node.leaf
+         return nil if node.left.nil? || node.right.nil?
+
+         gain = node.n_samples * node.impurity -
+                node.left.n_samples * node.left.impurity -
+                node.right.n_samples * node.right.impurity
+         @feature_importances[node.feature_id] += gain
+         eval_importance_at_node(node.left)
+         eval_importance_at_node(node.right)
+       end
+     end
+   end
+ end
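The four NotImplementedError hooks above define the subclass contract, and grow_node additionally reads @n_leaves and @sub_rng, so a subclass's fit must prime that state before calling build_tree. The following is a minimal sketch of the contract under stated assumptions: the class name MeanValueTree, the midpoint split, and the variance impurity are made up for illustration and are not part of the gem (the real criteria live in the C extension; the gem's own subclasses follow below).

require 'numo/narray'
require 'rumale/tree/base_decision_tree'

# Hypothetical subclass for illustration only; not part of the gem.
class MeanValueTree < Rumale::Tree::BaseDecisionTree
  attr_reader :leaf_values

  def fit(x, y)
    # The hooks below rely on this state, so fit sets it up first.
    @params[:max_features] ||= x.shape[1]
    @n_leaves = 0
    @leaf_values = []
    @sub_rng = @rng.dup
    build_tree(x, y) # the base class expands y to [n_samples, 1] if it is 1-D
    @leaf_values = Numo::DFloat.cast(@leaf_values)
    self
  end

  def predict(x)
    @leaf_values[apply(x), true]
  end

  private

  # Stop when every target row is identical.
  def stop_growing?(y)
    y.to_a.uniq.size == 1
  end

  # A leaf stores the per-output mean of the targets that reached it.
  def put_leaf(node, y)
    node.leaf = true
    node.leaf_id = @n_leaves
    @n_leaves += 1
    @leaf_values.push(y.mean(0))
    node
  end

  # Crude stand-in split: cut at the midpoint of the feature range and score
  # by weighted variance reduction. Must return [left_imp, right_imp, threshold, gain].
  def best_split(features, y, whole_impurity)
    threshold = (features.min + features.max) / 2.0
    l_ids = features.le(threshold).where
    r_ids = features.gt(threshold).where
    l_imp = l_ids.empty? ? 0.0 : impurity(y[l_ids, true])
    r_imp = r_ids.empty? ? 0.0 : impurity(y[r_ids, true])
    gain = whole_impurity -
           l_imp * l_ids.size.fdiv(y.shape[0]) -
           r_imp * r_ids.size.fdiv(y.shape[0])
    [l_imp, r_imp, threshold, gain]
  end

  # Pooled variance: mean squared deviation from the mean.
  def impurity(y)
    ((y - y.mean)**2).mean
  end
end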
lib/rumale/tree/decision_tree_classifier.rb
@@ -0,0 +1,148 @@
+ # frozen_string_literal: true
+
+ require 'rumale/tree/base_decision_tree'
+ require 'rumale/base/classifier'
+
+ module Rumale
+   module Tree
+     # DecisionTreeClassifier is a class that implements a decision tree for classification.
+     #
+     # @example
+     #   require 'rumale/tree/decision_tree_classifier'
+     #
+     #   estimator =
+     #     Rumale::Tree::DecisionTreeClassifier.new(
+     #       criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
+     #   estimator.fit(training_samples, training_labels)
+     #   results = estimator.predict(testing_samples)
+     #
+     class DecisionTreeClassifier < BaseDecisionTree
+       include ::Rumale::Base::Classifier
+       include ::Rumale::Tree::ExtDecisionTreeClassifier
+
+       # Return the class labels.
+       # @return [Numo::Int32] (size: n_classes)
+       attr_reader :classes
+
+       # Return the importance for each feature.
+       # @return [Numo::DFloat] (size: n_features)
+       attr_reader :feature_importances
+
+       # Return the learned tree.
+       # @return [Node]
+       attr_reader :tree
+
+       # Return the random generator for random selection of feature index.
+       # @return [Random]
+       attr_reader :rng
+
+       # Return the labels assigned to each leaf.
+       # @return [Numo::Int32] (size: n_leaves)
+       attr_reader :leaf_labels
+
+       # Create a new classifier with decision tree algorithm.
+       #
+       # @param criterion [String] The function to evaluate the splitting point. Supported criteria are 'gini' and 'entropy'.
+       # @param max_depth [Integer] The maximum depth of the tree.
+       #   If nil is given, the decision tree grows without concern for depth.
+       # @param max_leaf_nodes [Integer] The maximum number of leaves on the decision tree.
+       #   If nil is given, the number of leaves is not limited.
+       # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
+       # @param max_features [Integer] The number of features to consider when searching for the optimal split point.
+       #   If nil is given, the split process considers all features.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       #   It is used to randomly determine the order of features when deciding the splitting point.
+       def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
+                      random_seed: nil)
+         super
+       end
+
+       # Fit the model with given training data.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+       # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+       # @return [DecisionTreeClassifier] The learned classifier itself.
+       def fit(x, y)
+         x = ::Rumale::Validation.check_convert_sample_array(x)
+         y = ::Rumale::Validation.check_convert_label_array(y)
+         ::Rumale::Validation.check_sample_size(x, y)
+
+         n_samples, n_features = x.shape
+         @params[:max_features] = n_features if @params[:max_features].nil?
+         @params[:max_features] = [@params[:max_features], n_features].min
+         y = Numo::Int32.cast(y) unless y.is_a?(Numo::Int32)
+         uniq_y = y.to_a.uniq.sort
+         @classes = Numo::Int32.asarray(uniq_y)
+         @n_leaves = 0
+         @leaf_labels = []
+         @feature_ids = Array.new(n_features) { |v| v }
+         @sub_rng = @rng.dup
+         build_tree(x, y.map { |v| uniq_y.index(v) })
+         eval_importance(n_samples, n_features)
+         @leaf_labels = Numo::Int32[*@leaf_labels]
+         self
+       end
+
+       # Predict class labels for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
+       # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
+       def predict(x)
+         x = ::Rumale::Validation.check_convert_sample_array(x)
+
+         @leaf_labels[apply(x)].dup
+       end
+
+       # Predict probability for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
+       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
+       def predict_proba(x)
+         x = ::Rumale::Validation.check_convert_sample_array(x)
+
+         Numo::DFloat[*(Array.new(x.shape[0]) { |n| partial_predict_proba(@tree, x[n, true]) })]
+       end
+
+       private
+
+       # Walk a single sample down the tree and return the class probabilities of its leaf.
+       def partial_predict_proba(tree, sample)
+         node = tree
+         until node.leaf
+           node = if node.right.nil?
+                    node.left
+                  elsif node.left.nil?
+                    node.right
+                  else
+                    sample[node.feature_id] <= node.threshold ? node.left : node.right
+                  end
+         end
+         node.probs
+       end
+
+       # Labels arrive encoded as class indices, so no dimension expansion is needed.
+       def build_tree(x, y)
+         @tree = grow_node(0, x, y, impurity(y))
+         nil
+       end
+
+       # A leaf stores class frequencies and predicts the majority class.
+       def put_leaf(node, y)
+         node.probs = y.bincount(minlength: @classes.size) / node.n_samples.to_f
+         node.leaf = true
+         node.leaf_id = @n_leaves
+         @n_leaves += 1
+         @leaf_labels.push(@classes[node.probs.max_index])
+         node
+       end
+
+       def best_split(features, y, whole_impurity)
+         order = features.sort_index
+         n_classes = @classes.size
+         find_split_params(@params[:criterion], whole_impurity, order, features, y, n_classes)
+       end
+
+       def impurity(y)
+         n_classes = @classes.size
+         node_impurity(@params[:criterion], y, n_classes)
+       end
+     end
+   end
+ end
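For concreteness, a short usage sketch beyond the @example above. The toy data is invented for illustration; on a linearly separable set like this, a shallow tree should recover the training labels, though exact outputs depend on the fitted splits.

require 'numo/narray'
require 'rumale/tree/decision_tree_classifier'

# Toy, linearly separable data (values are made up).
x = Numo::DFloat[[0.0, 0.1], [0.2, 0.0], [0.9, 1.0], [1.0, 0.8]]
y = Numo::Int32[0, 0, 1, 1]

tree = Rumale::Tree::DecisionTreeClassifier.new(criterion: 'gini', max_depth: 2, random_seed: 1)
tree.fit(x, y)

tree.predict(x)          # class labels; here, the training labels
tree.predict_proba(x)    # shape [4, 2]; the leaf class frequencies set in put_leaf
tree.apply(x)            # index of the leaf each sample lands in
tree.feature_importances # impurity decrease per feature, normalized to sum to 1 when any split occurred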
lib/rumale/tree/decision_tree_regressor.rb
@@ -0,0 +1,113 @@
+ # frozen_string_literal: true
+
+ require 'rumale/tree/base_decision_tree'
+ require 'rumale/base/regressor'
+
+ module Rumale
+   module Tree
+     # DecisionTreeRegressor is a class that implements a decision tree for regression.
+     #
+     # @example
+     #   require 'rumale/tree/decision_tree_regressor'
+     #
+     #   estimator =
+     #     Rumale::Tree::DecisionTreeRegressor.new(
+     #       max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
+     #   estimator.fit(training_samples, training_values)
+     #   results = estimator.predict(testing_samples)
+     #
+     class DecisionTreeRegressor < BaseDecisionTree
+       include ::Rumale::Base::Regressor
+       include ::Rumale::Tree::ExtDecisionTreeRegressor
+
+       # Return the importance for each feature.
+       # @return [Numo::DFloat] (size: n_features)
+       attr_reader :feature_importances
+
+       # Return the learned tree.
+       # @return [Node]
+       attr_reader :tree
+
+       # Return the random generator for random selection of feature index.
+       # @return [Random]
+       attr_reader :rng
+
+       # Return the values assigned to each leaf.
+       # @return [Numo::DFloat] (shape: [n_leaves, n_outputs])
+       attr_reader :leaf_values
+
+       # Create a new regressor with decision tree algorithm.
+       #
+       # @param criterion [String] The function to evaluate the splitting point. Supported criteria are 'mae' and 'mse'.
+       # @param max_depth [Integer] The maximum depth of the tree.
+       #   If nil is given, the decision tree grows without concern for depth.
+       # @param max_leaf_nodes [Integer] The maximum number of leaves on the decision tree.
+       #   If nil is given, the number of leaves is not limited.
+       # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
+       # @param max_features [Integer] The number of features to consider when searching for the optimal split point.
+       #   If nil is given, the split process considers all features.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       #   It is used to randomly determine the order of features when deciding the splitting point.
+       def initialize(criterion: 'mse', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
+                      random_seed: nil)
+         super
+       end
+
+       # Fit the model with given training data.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+       # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
+       # @return [DecisionTreeRegressor] The learned regressor itself.
+       def fit(x, y)
+         x = ::Rumale::Validation.check_convert_sample_array(x)
+         y = ::Rumale::Validation.check_convert_target_value_array(y)
+         ::Rumale::Validation.check_sample_size(x, y)
+
+         n_samples, n_features = x.shape
+         @params[:max_features] = n_features if @params[:max_features].nil?
+         @params[:max_features] = [@params[:max_features], n_features].min
+         @n_leaves = 0
+         @leaf_values = []
+         @sub_rng = @rng.dup
+         build_tree(x, y)
+         eval_importance(n_samples, n_features)
+         @leaf_values = Numo::DFloat.cast(@leaf_values)
+         @leaf_values = @leaf_values.flatten.dup if @leaf_values.shape[1] == 1
+         self
+       end
+
+       # Predict values for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
+       # @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.
+       def predict(x)
+         x = ::Rumale::Validation.check_convert_sample_array(x)
+
+         @leaf_values.shape[1].nil? ? @leaf_values[apply(x)].dup : @leaf_values[apply(x), true].dup
+       end
+
+       private
+
+       # Stop growing when all target rows are identical.
+       def stop_growing?(y)
+         y.to_a.uniq.size == 1
+       end
+
+       # A leaf predicts the mean of the target values that reached it.
+       def put_leaf(node, y)
+         node.probs = nil
+         node.leaf = true
+         node.leaf_id = @n_leaves
+         @n_leaves += 1
+         @leaf_values.push(y.mean(0))
+         node
+       end
+
+       def best_split(f, y, impurity)
+         find_split_params(@params[:criterion], impurity, f.sort_index, f, y)
+       end
+
+       def impurity(y)
+         node_impurity(@params[:criterion], y.to_a)
+       end
+     end
+   end
+ end
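A corresponding usage sketch for the regressor, again with invented toy data. It highlights the shape behaviour visible in fit: leaf_values is flattened for single-output targets but keeps its [n_leaves, n_outputs] shape for multi-output ones.

require 'numo/narray'
require 'rumale/tree/decision_tree_regressor'

x = Numo::DFloat[[1.0], [2.0], [10.0], [11.0]]
y = Numo::DFloat[1.1, 0.9, 10.2, 9.8]   # one target per sample

reg = Rumale::Tree::DecisionTreeRegressor.new(criterion: 'mse', max_depth: 1, random_seed: 1)
reg.fit(x, y)
reg.predict(x)    # each prediction is the mean target of the reached leaf
reg.leaf_values   # 1-D here, because n_outputs == 1

# Multi-output targets keep leaf_values two-dimensional.
y2 = Numo::DFloat[[1.0, -1.0], [0.9, -0.9], [10.0, -10.0], [9.9, -9.9]]
reg2 = Rumale::Tree::DecisionTreeRegressor.new(max_depth: 1, random_seed: 1).fit(x, y2)
reg2.predict(x).shape  # => [4, 2]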
lib/rumale/tree/extra_tree_classifier.rb
@@ -0,0 +1,89 @@
+ # frozen_string_literal: true
+
+ require 'rumale/tree/decision_tree_classifier'
+
+ module Rumale
+   module Tree
+     # ExtraTreeClassifier is a class that implements an extra randomized tree for classification.
+     #
+     # @example
+     #   require 'rumale/tree/extra_tree_classifier'
+     #
+     #   estimator =
+     #     Rumale::Tree::ExtraTreeClassifier.new(
+     #       criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
+     #   estimator.fit(training_samples, training_labels)
+     #   results = estimator.predict(testing_samples)
+     #
+     # *Reference*
+     # - Geurts, P., Ernst, D., and Wehenkel, L., "Extremely randomized trees," Machine Learning, vol. 63 (1), pp. 3--42, 2006.
+     class ExtraTreeClassifier < DecisionTreeClassifier
+       # Return the class labels.
+       # @return [Numo::Int32] (size: n_classes)
+       attr_reader :classes
+
+       # Return the importance for each feature.
+       # @return [Numo::DFloat] (size: n_features)
+       attr_reader :feature_importances
+
+       # Return the learned tree.
+       # @return [Node]
+       attr_reader :tree
+
+       # Return the random generator for random selection of feature index.
+       # @return [Random]
+       attr_reader :rng
+
+       # Return the labels assigned to each leaf.
+       # @return [Numo::Int32] (size: n_leaves)
+       attr_reader :leaf_labels
+
+       # Create a new classifier with extra randomized tree algorithm.
+       #
+       # @param criterion [String] The function to evaluate the splitting point. Supported criteria are 'gini' and 'entropy'.
+       # @param max_depth [Integer] The maximum depth of the tree.
+       #   If nil is given, the extra tree grows without concern for depth.
+       # @param max_leaf_nodes [Integer] The maximum number of leaves on the extra tree.
+       #   If nil is given, the number of leaves is not limited.
+       # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
+       # @param max_features [Integer] The number of features to consider when searching for the optimal split point.
+       #   If nil is given, the split process considers all features.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       #   It is used to randomly determine the order of features when deciding the splitting point.
+       def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
+                      random_seed: nil)
+         super
+       end
+
+       # Fit the model with given training data (inherited from DecisionTreeClassifier).
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+       # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+       # @return [ExtraTreeClassifier] The learned classifier itself.
+
+       # Predict class labels for samples (inherited from DecisionTreeClassifier).
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
+       # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
+
+       # Predict probability for samples (inherited from DecisionTreeClassifier).
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
+       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
+
+       private
+
+       # Unlike the exhaustive search in DecisionTreeClassifier, the threshold is
+       # drawn uniformly at random from the feature's value range.
+       def best_split(features, y, whole_impurity)
+         threshold = @sub_rng.rand(features.min..features.max)
+         l_ids = features.le(threshold).where
+         r_ids = features.gt(threshold).where
+         l_impurity = l_ids.empty? ? 0.0 : impurity(y[l_ids])
+         r_impurity = r_ids.empty? ? 0.0 : impurity(y[r_ids])
+         gain = whole_impurity -
+                l_impurity * l_ids.size.fdiv(y.size) -
+                r_impurity * r_ids.size.fdiv(y.size)
+         [l_impurity, r_impurity, threshold, gain]
+       end
+     end
+   end
+ end
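Since fit, predict, and predict_proba are inherited, the only behavioural difference is the random threshold in best_split. A small sketch of the consequence, with toy data invented for illustration: differently seeded trees on the same data generally cut at different points, even when any cut would classify this easy set perfectly.

require 'numo/narray'
require 'rumale/tree/extra_tree_classifier'

x = Numo::DFloat[[0.0], [0.1], [0.9], [1.0]]
y = Numo::Int32[0, 0, 1, 1]

# The root threshold is drawn uniformly from [features.min, features.max].
a = Rumale::Tree::ExtraTreeClassifier.new(random_seed: 1).fit(x, y)
b = Rumale::Tree::ExtraTreeClassifier.new(random_seed: 42).fit(x, y)
a.tree.threshold == b.tree.threshold # most likely false
a.predict(x)                         # still recovers the training labels on this separable set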
lib/rumale/tree/extra_tree_regressor.rb
@@ -0,0 +1,80 @@
+ # frozen_string_literal: true
+
+ require 'rumale/tree/decision_tree_regressor'
+
+ module Rumale
+   module Tree
+     # ExtraTreeRegressor is a class that implements an extra randomized tree for regression.
+     #
+     # @example
+     #   require 'rumale/tree/extra_tree_regressor'
+     #
+     #   estimator =
+     #     Rumale::Tree::ExtraTreeRegressor.new(
+     #       max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
+     #   estimator.fit(training_samples, training_values)
+     #   results = estimator.predict(testing_samples)
+     #
+     # *Reference*
+     # - Geurts, P., Ernst, D., and Wehenkel, L., "Extremely randomized trees," Machine Learning, vol. 63 (1), pp. 3--42, 2006.
+     class ExtraTreeRegressor < DecisionTreeRegressor
+       # Return the importance for each feature.
+       # @return [Numo::DFloat] (size: n_features)
+       attr_reader :feature_importances
+
+       # Return the learned tree.
+       # @return [Node]
+       attr_reader :tree
+
+       # Return the random generator for random selection of feature index.
+       # @return [Random]
+       attr_reader :rng
+
+       # Return the values assigned to each leaf.
+       # @return [Numo::DFloat] (shape: [n_leaves, n_outputs])
+       attr_reader :leaf_values
+
+       # Create a new regressor with extra randomized tree algorithm.
+       #
+       # @param criterion [String] The function to evaluate the splitting point. Supported criteria are 'mae' and 'mse'.
+       # @param max_depth [Integer] The maximum depth of the tree.
+       #   If nil is given, the extra tree grows without concern for depth.
+       # @param max_leaf_nodes [Integer] The maximum number of leaves on the extra tree.
+       #   If nil is given, the number of leaves is not limited.
+       # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
+       # @param max_features [Integer] The number of features to consider when searching for the optimal split point.
+       #   If nil is given, the split process considers all features.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       #   It is used to randomly determine the order of features when deciding the splitting point.
+       def initialize(criterion: 'mse', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
+                      random_seed: nil)
+         super
+       end
+
+       # Fit the model with given training data (inherited from DecisionTreeRegressor).
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+       # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
+       # @return [ExtraTreeRegressor] The learned regressor itself.
+
+       # Predict values for samples (inherited from DecisionTreeRegressor).
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
+       # @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.
+
+       private
+
+       # As in ExtraTreeClassifier, the threshold is drawn at random; note the
+       # 2-D indexing of y, since the regressor keeps targets as [n_samples, n_outputs].
+       def best_split(features, y, whole_impurity)
+         threshold = @sub_rng.rand(features.min..features.max)
+         l_ids = features.le(threshold).where
+         r_ids = features.gt(threshold).where
+         l_impurity = l_ids.empty? ? 0.0 : impurity(y[l_ids, true])
+         r_impurity = r_ids.empty? ? 0.0 : impurity(y[r_ids, true])
+         gain = whole_impurity -
+                l_impurity * l_ids.size.fdiv(y.shape[0]) -
+                r_impurity * r_ids.size.fdiv(y.shape[0])
+         [l_impurity, r_impurity, threshold, gain]
+       end
+     end
+   end
+ end
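Per the Geurts et al. reference cited above, extra trees are usually averaged in an ensemble: the randomized splits make individual trees cheap and diverse, and averaging trades their variance away. A hand-rolled averaging sketch with invented data (the loop is illustrative only, not a gem API; Rumale's ensemble classes live in other gems):

require 'numo/narray'
require 'rumale/tree/extra_tree_regressor'

x = Numo::DFloat[[1.0], [2.0], [10.0], [11.0]]
y = Numo::DFloat[1.0, 1.2, 9.8, 10.1]

# Average a few differently seeded extra trees by hand.
preds = (1..5).map do |seed|
  Rumale::Tree::ExtraTreeRegressor.new(max_depth: 3, random_seed: seed).fit(x, y).predict(x)
end
mean_pred = preds.reduce(:+) / preds.size.to_f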