rumale-tree 0.24.0

lib/rumale/tree/gradient_tree_regressor.rb ADDED
@@ -0,0 +1,192 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/estimator'
+ require 'rumale/base/regressor'
+ require 'rumale/validation'
+ require 'rumale/tree/ext'
+ require 'rumale/tree/node'
+
+ module Rumale
+   module Tree
+     # GradientTreeRegressor is a class that implements a decision tree for regression with the exact greedy algorithm.
+     # This class is used internally by estimators with gradient tree boosting.
+     #
+     # *Reference*
+     # - Friedman, J H., "Greedy Function Approximation: A Gradient Boosting Machine," Annals of Statistics, 29 (5), pp. 1189--1232, 2001.
+     # - Friedman, J H., "Stochastic Gradient Boosting," Computational Statistics and Data Analysis, 38 (4), pp. 367--378, 2002.
+     # - Chen, T., and Guestrin, C., "XGBoost: A Scalable Tree Boosting System," Proc. KDD'16, pp. 785--794, 2016.
+     class GradientTreeRegressor < ::Rumale::Base::Estimator
+       include ::Rumale::Base::Regressor
+       include ::Rumale::Tree::ExtGradientTreeRegressor
+
+       # Return the importance for each feature.
+       # The feature importances are calculated based on the number of times the feature is used for splitting.
+       # @return [Numo::DFloat] (shape: [n_features])
+       attr_reader :feature_importances
+
+       # Return the learned tree.
+       # @return [Node]
+       attr_reader :tree
+
+       # Return the random generator for random selection of feature index.
+       # @return [Random]
+       attr_reader :rng
+
+       # Return the values assigned to each leaf.
+       # @return [Numo::DFloat] (shape: [n_leaves])
+       attr_reader :leaf_weights
+
+       # Initialize a gradient tree regressor.
+       #
+       # @param reg_lambda [Float] The L2 regularization term on weight.
+       # @param shrinkage_rate [Float] The shrinkage rate for weight.
+       # @param max_depth [Integer] The maximum depth of the tree.
+       #   If nil is given, the decision tree grows without concern for depth.
+       # @param max_leaf_nodes [Integer] The maximum number of leaves on the decision tree.
+       #   If nil is given, the number of leaves is not limited.
+       # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
+       # @param max_features [Integer] The number of features to consider when searching the optimal split point.
+       #   If nil is given, the split process considers all features.
+       # @param random_seed [Integer] The seed value used to initialize the random generator.
+       #   It is used to randomly determine the order of features when deciding the splitting point.
+       def initialize(reg_lambda: 0.0, shrinkage_rate: 1.0,
+                      max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil)
+         super()
+         @params = {
+           reg_lambda: reg_lambda,
+           shrinkage_rate: shrinkage_rate,
+           max_depth: max_depth,
+           max_leaf_nodes: max_leaf_nodes,
+           min_samples_leaf: min_samples_leaf,
+           max_features: max_features,
+           random_seed: random_seed || srand
+         }
+         @rng = Random.new(@params[:random_seed])
+       end
+
+       # Fit the model with given training data.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+       # @param y [Numo::DFloat] (shape: [n_samples]) The target values to be used for fitting the model.
+       # @param g [Numo::DFloat] (shape: [n_samples]) The gradient of the loss function.
+       # @param h [Numo::DFloat] (shape: [n_samples]) The hessian of the loss function.
+       # @return [GradientTreeRegressor] The learned regressor itself.
+       def fit(x, y, g, h)
+         x = ::Rumale::Validation.check_convert_sample_array(x)
+         y = ::Rumale::Validation.check_convert_target_value_array(y)
+         ::Rumale::Validation.check_sample_size(x, y)
+
+         # Initialize some variables.
+         n_features = x.shape[1]
+         @params[:max_features] ||= n_features
+         @n_leaves = 0
+         @leaf_weights = []
+         @feature_importances = Numo::DFloat.zeros(n_features)
+         @sub_rng = @rng.dup
+         # Build tree.
+         build_tree(x, y, g, h)
+         @leaf_weights = Numo::DFloat[*@leaf_weights]
+         self
+       end
+
+       # Predict values for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
+       # @return [Numo::DFloat] (size: n_samples) Predicted values per sample.
+       def predict(x)
+         x = ::Rumale::Validation.check_convert_sample_array(x)
+
+         @leaf_weights[apply(x)].dup
+       end
+
+       # Return the index of the leaf that each sample reached.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
+       # @return [Numo::Int32] (shape: [n_samples]) Leaf index for each sample.
+       def apply(x)
+         x = ::Rumale::Validation.check_convert_sample_array(x)
+
+         Numo::Int32[*(Array.new(x.shape[0]) { |n| partial_apply(@tree, x[n, true]) })]
+       end
+
+       private
+
+       def partial_apply(tree, sample)
+         node = tree
+         until node.leaf
+           node = if node.right.nil?
+                    node.left
+                  elsif node.left.nil?
+                    node.right
+                  else
+                    sample[node.feature_id] <= node.threshold ? node.left : node.right
+                  end
+         end
+         node.leaf_id
+       end
+
+       def build_tree(x, y, g, h)
+         @feature_ids = Array.new(x.shape[1]) { |v| v }
+         @tree = grow_node(0, x, y, g, h)
+         @feature_ids = nil
+         nil
+       end
+
+       def grow_node(depth, x, y, g, h) # rubocop:disable Metrics/AbcSize
+         # initialize some variables.
+         sum_g = g.sum
+         sum_h = h.sum
+         n_samples = x.shape[0]
+         node = Node.new(depth: depth, n_samples: n_samples)
+
+         # terminate growing.
+         return nil if !@params[:max_leaf_nodes].nil? && @n_leaves >= @params[:max_leaf_nodes]
+         return nil if n_samples < @params[:min_samples_leaf]
+         return put_leaf(node, sum_g, sum_h) if n_samples == @params[:min_samples_leaf]
+         return put_leaf(node, sum_g, sum_h) if !@params[:max_depth].nil? && depth == @params[:max_depth]
+         return put_leaf(node, sum_g, sum_h) if stop_growing?(y)
+
+         # calculate optimal parameters.
+         feature_id, threshold, gain = rand_ids.map { |n| [n, *best_split(x[true, n], g, h, sum_g, sum_h)] }.max_by(&:last)
+
+         return put_leaf(node, sum_g, sum_h) if gain.nil? || gain.zero?
+
+         left_ids = x[true, feature_id].le(threshold).where
+         right_ids = x[true, feature_id].gt(threshold).where
+         node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids], g[left_ids], h[left_ids])
+         node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids], g[right_ids], h[right_ids])
+
+         return put_leaf(node, sum_g, sum_h) if node.left.nil? && node.right.nil?
+
+         @feature_importances[feature_id] += 1.0
+
+         node.feature_id = feature_id
+         node.threshold = threshold
+         node.leaf = false
+         node
+       end
+
+       def stop_growing?(y)
+         y.to_a.uniq.size == 1
+       end
+
+       def put_leaf(node, sum_g, sum_h)
+         node.probs = nil
+         node.leaf = true
+         node.leaf_id = @n_leaves
+         weight = -@params[:shrinkage_rate] * sum_g / (sum_h + @params[:reg_lambda])
+         @leaf_weights.push(weight)
+         @n_leaves += 1
+         node
+       end
+
+       def best_split(f, g, h, sum_g, sum_h)
+         find_split_params(f.sort_index, f, g, h, sum_g, sum_h, @params[:reg_lambda])
+       end
+
+       def rand_ids
+         @feature_ids.sample(@params[:max_features], random: @sub_rng)
+       end
+     end
+   end
+ end
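
For orientation: the weight stored in put_leaf is the closed-form minimizer of the second-order (gradient/hessian) approximation of the loss, scaled by the shrinkage rate, as in the XGBoost reference cited in the class comment. With G and H the sums of gradients and hessians over the samples reaching a leaf, nu = shrinkage_rate, and lambda = reg_lambda:

    w^{*} = -\nu \, \frac{G}{H + \lambda}

The split gain maximized in best_split is computed by find_split_params in the native extension (ext.c), which is not part of this listing; under the assumption that it follows the same reference (up to constant factors), it corresponds to

    \mathrm{gain} = \frac{G_L^2}{H_L + \lambda} + \frac{G_R^2}{H_R + \lambda} - \frac{(G_L + G_R)^2}{H_L + H_R + \lambda}

where the L and R subscripts denote the candidate left and right partitions of the node.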
lib/rumale/tree/node.rb ADDED
@@ -0,0 +1,39 @@
+ # frozen_string_literal: true
+
+ module Rumale
+   module Tree
+     # Node is a class that implements the nodes used for constructing decision trees.
+     # This class is used for internal data structures.
+     class Node
+       # @!visibility private
+       attr_accessor :depth, :impurity, :n_samples, :probs, :leaf, :leaf_id, :left, :right, :feature_id, :threshold
+
+       # Create a new node for a decision tree.
+       #
+       # @param depth [Integer] The depth of the node in the tree.
+       # @param impurity [Float] The impurity of the node.
+       # @param n_samples [Integer] The number of the samples in the node.
+       # @param probs [Float] The probability of the node.
+       # @param leaf [Boolean] The flag indicating whether the node is a leaf.
+       # @param leaf_id [Integer] The leaf index of the node.
+       # @param left [Node] The left node.
+       # @param right [Node] The right node.
+       # @param feature_id [Integer] The feature index used for evaluation.
+       # @param threshold [Float] The threshold value of the feature for splitting the node.
+       def initialize(depth: 0, impurity: 0.0, n_samples: 0, probs: 0.0,
+                      leaf: false, leaf_id: nil,
+                      left: nil, right: nil, feature_id: 0, threshold: 0.0)
+         @depth = depth
+         @impurity = impurity
+         @n_samples = n_samples
+         @probs = probs
+         @leaf = leaf
+         @leaf_id = leaf_id
+         @left = left
+         @right = right
+         @feature_id = feature_id
+         @threshold = threshold
+       end
+     end
+   end
+ end
lib/rumale/tree/version.rb ADDED
@@ -0,0 +1,10 @@
+ # frozen_string_literal: true
+
+ # Rumale is a machine learning library in Ruby.
+ module Rumale
+   # This module consists of the classes that implement tree models.
+   module Tree
+     # @!visibility private
+     VERSION = '0.24.0'
+   end
+ end
lib/rumale/tree.rb ADDED
@@ -0,0 +1,11 @@
+ # frozen_string_literal: true
+
+ require 'numo/narray'
+
+ require_relative 'tree/version'
+
+ require_relative 'tree/decision_tree_classifier'
+ require_relative 'tree/decision_tree_regressor'
+ require_relative 'tree/extra_tree_classifier'
+ require_relative 'tree/extra_tree_regressor'
+ require_relative 'tree/gradient_tree_regressor'
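
As a usage sketch (not part of the gem sources): GradientTreeRegressor is normally driven by the gradient tree boosting estimators that build on this gem, but it can be exercised directly by supplying gradients and hessians by hand. The snippet below assumes rumale-tree and its dependencies are installed; the data is synthetic and the g/h values are computed for a squared-error loss, matching the fit/predict signatures documented above.

    require 'rumale/tree'

    # Synthetic regression data (illustrative only).
    x = Numo::DFloat.new(100, 4).rand
    y = x[true, 0] * 2.0 + Numo::DFloat.new(100).rand_norm(0, 0.1)

    # One boosting step with squared-error loss: start from the constant
    # prediction f0, so the gradient is f0 - y and the hessian is 1.
    f0 = y.mean
    g = Numo::DFloat.new(100).fill(f0) - y
    h = Numo::DFloat.ones(100)

    tree = Rumale::Tree::GradientTreeRegressor.new(
      reg_lambda: 1.0, shrinkage_rate: 0.1, max_depth: 3, random_seed: 1
    )
    tree.fit(x, y, g, h)

    # The tree predicts additive corrections to the base prediction.
    predictions = tree.predict(x) + f0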
metadata ADDED
@@ -0,0 +1,93 @@
+ --- !ruby/object:Gem::Specification
+ name: rumale-tree
+ version: !ruby/object:Gem::Version
+   version: 0.24.0
+ platform: ruby
+ authors:
+ - yoshoku
+ autorequire:
+ bindir: exe
+ cert_chain: []
+ date: 2022-12-31 00:00:00.000000000 Z
+ dependencies:
+ - !ruby/object:Gem::Dependency
+   name: numo-narray
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: 0.9.1
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: 0.9.1
+ - !ruby/object:Gem::Dependency
+   name: rumale-core
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 0.24.0
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 0.24.0
+ description: Rumale::Tree provides classifier and regression based on decision tree
+   algorithms with Rumale interface.
+ email:
+ - yoshoku@outlook.com
+ executables: []
+ extensions:
+ - ext/rumale/tree/extconf.rb
+ extra_rdoc_files: []
+ files:
+ - LICENSE.txt
+ - README.md
+ - ext/rumale/tree/ext.c
+ - ext/rumale/tree/ext.h
+ - ext/rumale/tree/extconf.rb
+ - lib/rumale/tree.rb
+ - lib/rumale/tree/base_decision_tree.rb
+ - lib/rumale/tree/decision_tree_classifier.rb
+ - lib/rumale/tree/decision_tree_regressor.rb
+ - lib/rumale/tree/extra_tree_classifier.rb
+ - lib/rumale/tree/extra_tree_regressor.rb
+ - lib/rumale/tree/gradient_tree_regressor.rb
+ - lib/rumale/tree/node.rb
+ - lib/rumale/tree/version.rb
+ homepage: https://github.com/yoshoku/rumale
+ licenses:
+ - BSD-3-Clause
+ metadata:
+   homepage_uri: https://github.com/yoshoku/rumale
+   source_code_uri: https://github.com/yoshoku/rumale/tree/main/rumale-tree
+   changelog_uri: https://github.com/yoshoku/rumale/blob/main/CHANGELOG.md
+   documentation_uri: https://yoshoku.github.io/rumale/doc/
+   rubygems_mfa_required: 'true'
+ post_install_message:
+ rdoc_options: []
+ require_paths:
+ - lib
+ required_ruby_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '0'
+ required_rubygems_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '0'
+ requirements: []
+ rubygems_version: 3.3.26
+ signing_key:
+ specification_version: 4
+ summary: Rumale::Tree provides classifier and regression based on decision tree algorithms
+   with Rumale interface.
+ test_files: []