rumale-tree 0.24.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,192 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+ require 'rumale/base/regressor'
5
+ require 'rumale/validation'
6
+ require 'rumale/tree/ext'
7
+ require 'rumale/tree/node'
8
+
9
module Rumale
  module Tree
    # GradientTreeRegressor is a class that implements decision tree for regression with exact greedy algorithm.
    # This class is used internally for estimators with gradient tree boosting.
    #
    # *Reference*
    # - Friedman, J H., "Greedy Function Approximation: A Gradient Boosting Machine," Annals of Statistics, 29 (5), pp. 1189--1232, 2001.
    # - Friedman, J H., "Stochastic Gradient Boosting," Computational Statistics and Data Analysis, 38 (4), pp. 367--378, 2002.
    # - Chen, T., and Guestrin, C., "XGBoost: A Scalable Tree Boosting System," Proc. KDD'16, pp. 785--794, 2016.
    class GradientTreeRegressor < ::Rumale::Base::Estimator
      include ::Rumale::Base::Regressor
      include ::Rumale::Tree::ExtGradientTreeRegressor

      # Return the importance for each feature.
      # The feature importances are calculated based on the numbers of times the feature is used for splitting.
      # @return [Numo::DFloat] (shape: [n_features])
      attr_reader :feature_importances

      # Return the learned tree.
      # @return [Node]
      attr_reader :tree

      # Return the random generator for random selection of feature index.
      # @return [Random]
      attr_reader :rng

      # Return the values assigned to each leaf.
      # Each weight is -shrinkage_rate * sum(g) / (sum(h) + reg_lambda) over the samples that reached the leaf.
      # @return [Numo::DFloat] (shape: [n_leaves])
      attr_reader :leaf_weights

      # Initialize a gradient tree regressor.
      #
      # @param reg_lambda [Float] The L2 regularization term on weight.
      # @param shrinkage_rate [Float] The shrinkage rate for weight.
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, decision tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, split process considers all features.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      #   It is used to randomly determine the order of features when deciding splitting point.
      def initialize(reg_lambda: 0.0, shrinkage_rate: 1.0,
                     max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil)
        super()
        @params = {
          reg_lambda: reg_lambda,
          shrinkage_rate: shrinkage_rate,
          max_depth: max_depth,
          max_leaf_nodes: max_leaf_nodes,
          min_samples_leaf: min_samples_leaf,
          max_features: max_features,
          # Kernel#srand reseeds the global generator and returns the previous seed,
          # which serves as this estimator's seed when none is given.
          random_seed: random_seed || srand
        }
        @rng = Random.new(@params[:random_seed])
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::DFloat] (shape: [n_samples]) The target values to be used for fitting the model.
      # @param g [Numo::DFloat] (shape: [n_samples]) The gradient of loss function.
      # @param h [Numo::DFloat] (shape: [n_samples]) The hessian of loss function.
      # @return [GradientTreeRegressor] The learned regressor itself.
      def fit(x, y, g, h)
        x = ::Rumale::Validation.check_convert_sample_array(x)
        y = ::Rumale::Validation.check_convert_target_value_array(y)
        g = ::Rumale::Validation.check_convert_target_value_array(g)
        h = ::Rumale::Validation.check_convert_target_value_array(h)
        ::Rumale::Validation.check_sample_size(x, y)
        ::Rumale::Validation.check_sample_size(x, g)
        ::Rumale::Validation.check_sample_size(x, h)

        # Initialize some variables.
        n_features = x.shape[1]
        # Resolve the feature budget per fit instead of overwriting @params, so that refitting on
        # data with a different number of features still honors max_features: nil ("all features").
        @max_features = @params[:max_features] || n_features
        @n_leaves = 0
        @leaf_weights = []
        @feature_importances = Numo::DFloat.zeros(n_features)
        # Work on a copy so every fit starts from the same generator state.
        @sub_rng = @rng.dup
        # Build tree.
        build_tree(x, y, g, h)
        @leaf_weights = Numo::DFloat[*@leaf_weights]
        self
      end

      # Predict values for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
      # @return [Numo::DFloat] (shape: [n_samples]) Predicted values per sample.
      def predict(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        @leaf_weights[apply(x)].dup
      end

      # Return the index of the leaf that each sample reached.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to assign to leaves.
      # @return [Numo::Int32] (shape: [n_samples]) Leaf index for sample.
      def apply(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        Numo::Int32[*(Array.new(x.shape[0]) { |n| partial_apply(@tree, x[n, true]) })]
      end

      private

      # Walk from the given node down to a leaf for one sample and return that leaf's index.
      # A child can be nil when growth stopped early (max_leaf_nodes / min_samples_leaf);
      # the sample then follows the remaining child regardless of the threshold.
      def partial_apply(tree, sample)
        node = tree
        until node.leaf
          node = if node.right.nil?
                   node.left
                 elsif node.left.nil?
                   node.right
                 else
                   sample[node.feature_id] <= node.threshold ? node.left : node.right
                 end
        end
        node.leaf_id
      end

      # Grow the whole tree from the root and store it in @tree.
      def build_tree(x, y, g, h)
        @feature_ids = Array.new(x.shape[1]) { |v| v }
        root = grow_node(0, x, y, g, h)
        # grow_node returns nil only when growth is forbidden outright (e.g. min_samples_leaf larger
        # than the number of samples, or max_leaf_nodes of 0). Fall back to a single leaf so that
        # predict/apply still work instead of failing on a nil tree.
        root ||= put_leaf(Node.new(depth: 0, n_samples: x.shape[0]), g.sum, h.sum)
        @tree = root
        @feature_ids = nil
        nil
      end

      # Recursively grow a node. Returns nil when no node may be created here; otherwise returns
      # a leaf or an internal node split on the feature/threshold with the largest gain.
      def grow_node(depth, x, y, g, h) # rubocop:disable Metrics/AbcSize
        n_samples = x.shape[0]

        # terminate growing without creating a node.
        return nil if !@params[:max_leaf_nodes].nil? && @n_leaves >= @params[:max_leaf_nodes]
        return nil if n_samples < @params[:min_samples_leaf]

        # initialize some variables.
        sum_g = g.sum
        sum_h = h.sum
        node = Node.new(depth: depth, n_samples: n_samples)

        # terminate growing with a leaf.
        return put_leaf(node, sum_g, sum_h) if n_samples == @params[:min_samples_leaf]
        return put_leaf(node, sum_g, sum_h) if !@params[:max_depth].nil? && depth == @params[:max_depth]
        return put_leaf(node, sum_g, sum_h) if stop_growing?(y)

        # calculate optimal parameters: best_split yields [threshold, gain] for each candidate feature.
        feature_id, threshold, gain = rand_ids.map { |n| [n, *best_split(x[true, n], g, h, sum_g, sum_h)] }.max_by(&:last)

        return put_leaf(node, sum_g, sum_h) if gain.nil? || gain.zero?

        left_ids = x[true, feature_id].le(threshold).where
        right_ids = x[true, feature_id].gt(threshold).where
        node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids], g[left_ids], h[left_ids])
        node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids], g[right_ids], h[right_ids])

        return put_leaf(node, sum_g, sum_h) if node.left.nil? && node.right.nil?

        @feature_importances[feature_id] += 1.0

        node.feature_id = feature_id
        node.threshold = threshold
        node.leaf = false
        node
      end

      # True when all target values at this node are identical.
      def stop_growing?(y)
        y.to_a.uniq.size == 1
      end

      # Turn the node into a leaf, assign it the next leaf index, and record its weight.
      def put_leaf(node, sum_g, sum_h)
        node.probs = nil
        node.leaf = true
        node.leaf_id = @n_leaves
        weight = -@params[:shrinkage_rate] * sum_g / (sum_h + @params[:reg_lambda])
        @leaf_weights.push(weight)
        @n_leaves += 1
        node
      end

      # Delegate to the native find_split_params with the feature values' sort order.
      def best_split(f, g, h, sum_g, sum_h)
        find_split_params(f.sort_index, f, g, h, sum_g, sum_h, @params[:reg_lambda])
      end

      # Randomly choose (and order) the candidate features for a split.
      def rand_ids
        @feature_ids.sample(@max_features, random: @sub_rng)
      end
    end
  end
end
@@ -0,0 +1,39 @@
1
+ # frozen_string_literal: true
2
+
3
module Rumale
  module Tree
    # Node is a single vertex of a decision tree.
    # It holds node statistics, leaf bookkeeping, and the split rule with its children.
    # This class is used for internal data structures.
    class Node
      # @!visibility private
      attr_accessor :depth, :impurity, :n_samples, :probs, :leaf, :leaf_id

      # @!visibility private
      attr_accessor :left, :right, :feature_id, :threshold

      # Create a new node for decision tree.
      #
      # @param depth [Integer] The depth of the node in tree.
      # @param impurity [Float] The impurity of the node.
      # @param n_samples [Integer] The number of the samples in the node.
      # @param probs [Float] The probability of the node.
      # @param leaf [Boolean] The flag indicating whether the node is a leaf.
      # @param leaf_id [Integer] The leaf index of the node.
      # @param left [Node] The left node.
      # @param right [Node] The right node.
      # @param feature_id [Integer] The feature index used for evaluation.
      # @param threshold [Float] The threshold value of the feature for splitting the node.
      def initialize(depth: 0, impurity: 0.0, n_samples: 0, probs: 0.0,
                     leaf: false, leaf_id: nil,
                     left: nil, right: nil, feature_id: 0, threshold: 0.0)
        # statistics of the samples that reached this node
        @depth, @impurity, @n_samples, @probs = depth, impurity, n_samples, probs
        # leaf bookkeeping
        @leaf, @leaf_id = leaf, leaf_id
        # split rule and children
        @left, @right, @feature_id, @threshold = left, right, feature_id, threshold
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
+ # frozen_string_literal: true
2
+
3
# Rumale is a machine learning library written in Ruby.
module Rumale
  # Namespace for Rumale's tree-based estimators (decision trees, extra trees and gradient trees).
  module Tree
    # Version string of the rumale-tree gem.
    # @!visibility private
    VERSION = '0.24.0'
  end
end
@@ -0,0 +1,11 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'numo/narray'
4
+
5
+ require_relative 'tree/version'
6
+
7
+ require_relative 'tree/decision_tree_classifier'
8
+ require_relative 'tree/decision_tree_regressor'
9
+ require_relative 'tree/extra_tree_classifier'
10
+ require_relative 'tree/extra_tree_regressor'
11
+ require_relative 'tree/gradient_tree_regressor'
metadata ADDED
@@ -0,0 +1,93 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: rumale-tree
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.24.0
5
+ platform: ruby
6
+ authors:
7
+ - yoshoku
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2022-12-31 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: numo-narray
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: 0.9.1
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: 0.9.1
27
+ - !ruby/object:Gem::Dependency
28
+ name: rumale-core
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
32
+ - !ruby/object:Gem::Version
33
+ version: 0.24.0
34
+ type: :runtime
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - "~>"
39
+ - !ruby/object:Gem::Version
40
+ version: 0.24.0
41
+ description: Rumale::Tree provides classifier and regression based on decision tree
42
+ algorithms with Rumale interface.
43
+ email:
44
+ - yoshoku@outlook.com
45
+ executables: []
46
+ extensions:
47
+ - ext/rumale/tree/extconf.rb
48
+ extra_rdoc_files: []
49
+ files:
50
+ - LICENSE.txt
51
+ - README.md
52
+ - ext/rumale/tree/ext.c
53
+ - ext/rumale/tree/ext.h
54
+ - ext/rumale/tree/extconf.rb
55
+ - lib/rumale/tree.rb
56
+ - lib/rumale/tree/base_decision_tree.rb
57
+ - lib/rumale/tree/decision_tree_classifier.rb
58
+ - lib/rumale/tree/decision_tree_regressor.rb
59
+ - lib/rumale/tree/extra_tree_classifier.rb
60
+ - lib/rumale/tree/extra_tree_regressor.rb
61
+ - lib/rumale/tree/gradient_tree_regressor.rb
62
+ - lib/rumale/tree/node.rb
63
+ - lib/rumale/tree/version.rb
64
+ homepage: https://github.com/yoshoku/rumale
65
+ licenses:
66
+ - BSD-3-Clause
67
+ metadata:
68
+ homepage_uri: https://github.com/yoshoku/rumale
69
+ source_code_uri: https://github.com/yoshoku/rumale/tree/main/rumale-tree
70
+ changelog_uri: https://github.com/yoshoku/rumale/blob/main/CHANGELOG.md
71
+ documentation_uri: https://yoshoku.github.io/rumale/doc/
72
+ rubygems_mfa_required: 'true'
73
+ post_install_message:
74
+ rdoc_options: []
75
+ require_paths:
76
+ - lib
77
+ required_ruby_version: !ruby/object:Gem::Requirement
78
+ requirements:
79
+ - - ">="
80
+ - !ruby/object:Gem::Version
81
+ version: '0'
82
+ required_rubygems_version: !ruby/object:Gem::Requirement
83
+ requirements:
84
+ - - ">="
85
+ - !ruby/object:Gem::Version
86
+ version: '0'
87
+ requirements: []
88
+ rubygems_version: 3.3.26
89
+ signing_key:
90
+ specification_version: 4
91
+ summary: Rumale::Tree provides classifier and regression based on decision tree algorithms
92
+ with Rumale interface.
93
+ test_files: []