rumale-tree 0.24.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/LICENSE.txt +27 -0
- data/README.md +33 -0
- data/ext/rumale/tree/ext.c +575 -0
- data/ext/rumale/tree/ext.h +12 -0
- data/ext/rumale/tree/extconf.rb +32 -0
- data/lib/rumale/tree/base_decision_tree.rb +154 -0
- data/lib/rumale/tree/decision_tree_classifier.rb +148 -0
- data/lib/rumale/tree/decision_tree_regressor.rb +113 -0
- data/lib/rumale/tree/extra_tree_classifier.rb +89 -0
- data/lib/rumale/tree/extra_tree_regressor.rb +80 -0
- data/lib/rumale/tree/gradient_tree_regressor.rb +192 -0
- data/lib/rumale/tree/node.rb +39 -0
- data/lib/rumale/tree/version.rb +10 -0
- data/lib/rumale/tree.rb +11 -0
- metadata +93 -0
@@ -0,0 +1,192 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/estimator'
|
4
|
+
require 'rumale/base/regressor'
|
5
|
+
require 'rumale/validation'
|
6
|
+
require 'rumale/tree/ext'
|
7
|
+
require 'rumale/tree/node'
|
8
|
+
|
9
|
+
module Rumale
|
10
|
+
module Tree
|
11
|
+
# GradientTreeRegressor is a class that implements decision tree for regression with exact greedy algorithm.
|
12
|
+
# This class is used internally for estimators with gradient tree boosting.
|
13
|
+
#
|
14
|
+
# *Reference*
|
15
|
+
# - Friedman, J H., "Greedy Function Approximation: A Gradient Boosting Machine," Annals of Statistics, 29 (5), pp. 1189--1232, 2001.
|
16
|
+
# - Friedman, J H., "Stochastic Gradient Boosting," Computational Statistics and Data Analysis, 38 (4), pp. 367--378, 2002.
|
17
|
+
# - Chen, T., and Guestrin, C., "XGBoost: A Scalable Tree Boosting System," Proc. KDD'16, pp. 785--794, 2016.
|
18
|
+
class GradientTreeRegressor < ::Rumale::Base::Estimator
|
19
|
+
include ::Rumale::Base::Regressor
|
20
|
+
include ::Rumale::Tree::ExtGradientTreeRegressor
|
21
|
+
|
22
|
+
# Return the importance for each feature.
|
23
|
+
# The feature importances are calculated based on the number of times the feature is used for splitting.
|
24
|
+
# @return [Numo::DFloat] (shape: [n_features])
|
25
|
+
attr_reader :feature_importances
|
26
|
+
|
27
|
+
# Return the learned tree.
|
28
|
+
# @return [Node]
|
29
|
+
attr_reader :tree
|
30
|
+
|
31
|
+
# Return the random generator for random selection of feature index.
|
32
|
+
# @return [Random]
|
33
|
+
attr_reader :rng
|
34
|
+
|
35
|
+
# Return the values assigned to each leaf.
|
36
|
+
# @return [Numo::DFloat] (shape: [n_leaves])
|
37
|
+
attr_reader :leaf_weights
|
38
|
+
|
39
|
+
# Initialize a gradient tree regressor
|
40
|
+
#
|
41
|
+
# @param reg_lambda [Float] The L2 regularization term on weight.
|
42
|
+
# @param shrinkage_rate [Float] The shrinkage rate for weight.
|
43
|
+
# @param max_depth [Integer] The maximum depth of the tree.
|
44
|
+
# If nil is given, decision tree grows without concern for depth.
|
45
|
+
# @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
|
46
|
+
# If nil is given, number of leaves is not limited.
|
47
|
+
# @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
|
48
|
+
# @param max_features [Integer] The number of features to consider when searching optimal split point.
|
49
|
+
# If nil is given, split process considers all features.
|
50
|
+
# @param random_seed [Integer] The seed value used to initialize the random generator.
|
51
|
+
# It is used to randomly determine the order of features when deciding splitting point.
|
52
|
+
def initialize(reg_lambda: 0.0, shrinkage_rate: 1.0,
               max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil)
  super()
  # When no seed is supplied, the value returned by srand is used instead.
  seed = random_seed || srand
  @params = {
    reg_lambda: reg_lambda,
    shrinkage_rate: shrinkage_rate,
    max_depth: max_depth,
    max_leaf_nodes: max_leaf_nodes,
    min_samples_leaf: min_samples_leaf,
    max_features: max_features,
    random_seed: seed
  }
  # The generator is seeded with the same value that is recorded in the parameters.
  @rng = Random.new(seed)
end
|
66
|
+
|
67
|
+
# Fit the model with given training data.
|
68
|
+
#
|
69
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
|
70
|
+
# @param y [Numo::DFloat] (shape: [n_samples]) The target values to be used for fitting the model.
|
71
|
+
# @param g [Numo::DFloat] (shape: [n_samples]) The gradient of loss function.
|
72
|
+
# @param h [Numo::DFloat] (shape: [n_samples]) The hessian of loss function.
|
73
|
+
# @return [GradientTreeRegressor] The learned regressor itself.
|
74
|
+
def fit(x, y, g, h)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  y = ::Rumale::Validation.check_convert_target_value_array(y)
  ::Rumale::Validation.check_sample_size(x, y)

  # Initialize some variables.
  n_features = x.shape[1]
  # A nil max_features means that every feature is considered at each split.
  @params[:max_features] ||= n_features
  @n_leaves = 0
  @leaf_weights = []
  @feature_importances = Numo::DFloat.zeros(n_features)
  # Work on a copy of the generator so that @rng itself is not advanced by fitting.
  @sub_rng = @rng.dup
  # Build tree.
  build_tree(x, y, g, h)
  # Convert the collected weights in a single call instead of splatting one argument per leaf,
  # which can exhaust the VM stack when the tree has a very large number of leaves.
  @leaf_weights = Numo::DFloat.cast(@leaf_weights)
  self
end
|
91
|
+
|
92
|
+
# Predict values for samples.
|
93
|
+
#
|
94
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
|
95
|
+
# @return [Numo::DFloat] (size: n_samples) Predicted values per sample.
|
96
|
+
def predict(x)
  samples = ::Rumale::Validation.check_convert_sample_array(x)

  # Look up the weight of the leaf reached by each sample and hand back a copy.
  leaf_ids = apply(samples)
  @leaf_weights[leaf_ids].dup
end
|
101
|
+
|
102
|
+
# Return the index of the leaf that each sample reached.
|
103
|
+
#
|
104
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples whose leaf indices are to be returned.
|
105
|
+
# @return [Numo::Int32] (shape: [n_samples]) Leaf index for sample.
|
106
|
+
def apply(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)

  # Route every row of x through the tree to obtain the id of the leaf it reaches.
  leaf_ids = Array.new(x.shape[0]) { |n| partial_apply(@tree, x[n, true]) }
  # Convert the ids in a single call instead of splatting one argument per sample,
  # which can exhaust the VM stack when x has a very large number of rows.
  Numo::Int32.cast(leaf_ids)
end
|
111
|
+
|
112
|
+
private
|
113
|
+
|
114
|
+
# Walk from the given node down to a leaf and return the leaf_id of that leaf.
# A node that has only one child always forwards the sample to that child;
# otherwise the sample goes left when its feature value is less than or equal to the threshold.
def partial_apply(tree, sample)
  current = tree
  until current.leaf
    lower = current.left
    upper = current.right
    current =
      if upper.nil?
        lower
      elsif lower.nil?
        upper
      elsif sample[current.feature_id] <= current.threshold
        lower
      else
        upper
      end
  end
  current.leaf_id
end
|
127
|
+
|
128
|
+
# Grow the whole tree from depth 0 and store its root in @tree.
# @feature_ids holds every column index of x while the tree is being grown and is cleared afterwards.
def build_tree(x, y, g, h)
  n_features = x.shape[1]
  @feature_ids = (0...n_features).to_a
  @tree = grow_node(0, x, y, g, h)
  @feature_ids = nil
  nil
end
|
134
|
+
|
135
|
+
def grow_node(depth, x, y, g, h) # rubocop:disable Metrics/AbcSize
  # Totals of the gradient and the hessian over the samples that reached this node.
  sum_g = g.sum
  sum_h = h.sum
  n_samples = x.shape[0]
  node = Node.new(depth: depth, n_samples: n_samples)

  max_leaf_nodes = @params[:max_leaf_nodes]
  min_samples_leaf = @params[:min_samples_leaf]
  max_depth = @params[:max_depth]

  # No node is produced when the leaf budget is used up or the sample count is below the minimum.
  return nil if !max_leaf_nodes.nil? && @n_leaves >= max_leaf_nodes
  return nil if n_samples < min_samples_leaf

  # The node becomes a leaf when it holds the minimum number of samples, sits at the maximum depth,
  # or contains a single distinct target value.
  return put_leaf(node, sum_g, sum_h) if n_samples == min_samples_leaf
  return put_leaf(node, sum_g, sum_h) if !max_depth.nil? && depth == max_depth
  return put_leaf(node, sum_g, sum_h) if stop_growing?(y)

  # Evaluate every randomly drawn feature and keep the candidate whose last element (the gain) is largest.
  candidates = rand_ids.map { |fid| [fid, *best_split(x[true, fid], g, h, sum_g, sum_h)] }
  feature_id, threshold, gain = candidates.max_by(&:last)

  return put_leaf(node, sum_g, sum_h) if gain.nil? || gain.zero?

  # Partition the samples on the chosen threshold and grow both children one level deeper.
  feature_values = x[true, feature_id]
  left_ids = feature_values.le(threshold).where
  right_ids = feature_values.gt(threshold).where
  node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids], g[left_ids], h[left_ids])
  node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids], g[right_ids], h[right_ids])

  # If neither child could be created, this node is turned into a leaf instead.
  return put_leaf(node, sum_g, sum_h) if node.left.nil? && node.right.nil?

  @feature_importances[feature_id] += 1.0

  node.feature_id = feature_id
  node.threshold = threshold
  node.leaf = false
  node
end
|
168
|
+
|
169
|
+
# Tell whether y contains exactly one distinct value.
def stop_growing?(y)
  distinct_targets = y.to_a.uniq
  distinct_targets.size == 1
end
|
172
|
+
|
173
|
+
# Turn the given node into a leaf, give it the next leaf id, and record its weight.
# The weight is -shrinkage_rate * sum_g / (sum_h + reg_lambda).
# Returns the node itself.
def put_leaf(node, sum_g, sum_h)
  node.probs = nil
  node.leaf = true
  node.leaf_id = @n_leaves

  numerator = -@params[:shrinkage_rate] * sum_g
  denominator = sum_h + @params[:reg_lambda]
  @leaf_weights << (numerator / denominator)

  @n_leaves += 1
  node
end
|
182
|
+
|
183
|
+
# Hand the feature values, together with their sort order and the L2 regularization term,
# to find_split_params and return whatever it returns.
def best_split(f, g, h, sum_g, sum_h)
  sorted_order = f.sort_index
  find_split_params(sorted_order, f, g, h, sum_g, sum_h, @params[:reg_lambda])
end
|
186
|
+
|
187
|
+
# Draw max_features feature indices from @feature_ids using the fitting-time generator.
def rand_ids
  n_candidates = @params[:max_features]
  @feature_ids.sample(n_candidates, random: @sub_rng)
end
|
190
|
+
end
|
191
|
+
end
|
192
|
+
end
|
@@ -0,0 +1,39 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Rumale
|
4
|
+
module Tree
|
5
|
+
# Node is a class that implements node used for construction of decision tree.
|
6
|
+
# This class is used for internal data structures.
|
7
|
+
class Node
  # @!visibility private
  attr_accessor :depth, :impurity, :n_samples, :probs,
                :leaf, :leaf_id,
                :left, :right, :feature_id, :threshold

  # Create a new node for decision tree.
  # Every argument is stored as given in the attribute of the same name.
  #
  # @param depth [Integer] The depth of the node in tree.
  # @param impurity [Float] The impurity of the node.
  # @param n_samples [Integer] The number of the samples in the node.
  # @param probs [Float] The probability of the node.
  # @param leaf [Boolean] The flag indicating whether the node is a leaf.
  # @param leaf_id [Integer] The leaf index of the node.
  # @param left [Node] The left node.
  # @param right [Node] The right node.
  # @param feature_id [Integer] The feature index used for evaluation.
  # @param threshold [Float] The threshold value of the feature for splitting the node.
  def initialize(depth: 0, impurity: 0.0, n_samples: 0, probs: 0.0,
                 leaf: false, leaf_id: nil,
                 left: nil, right: nil, feature_id: 0, threshold: 0.0)
    # Position in the tree and links to the children.
    @depth = depth
    @left = left
    @right = right

    # Statistics of the samples held by the node.
    @n_samples = n_samples
    @impurity = impurity
    @probs = probs

    # Leaf bookkeeping.
    @leaf = leaf
    @leaf_id = leaf_id

    # Splitting rule.
    @feature_id = feature_id
    @threshold = threshold
  end
end
|
38
|
+
end
|
39
|
+
end
|
data/lib/rumale/tree.rb
ADDED
@@ -0,0 +1,11 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'numo/narray'
|
4
|
+
|
5
|
+
require_relative 'tree/version'
|
6
|
+
|
7
|
+
require_relative 'tree/decision_tree_classifier'
|
8
|
+
require_relative 'tree/decision_tree_regressor'
|
9
|
+
require_relative 'tree/extra_tree_classifier'
|
10
|
+
require_relative 'tree/extra_tree_regressor'
|
11
|
+
require_relative 'tree/gradient_tree_regressor'
|
metadata
ADDED
@@ -0,0 +1,93 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: rumale-tree
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.24.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- yoshoku
|
8
|
+
autorequire:
|
9
|
+
bindir: exe
|
10
|
+
cert_chain: []
|
11
|
+
date: 2022-12-31 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: numo-narray
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: 0.9.1
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: 0.9.1
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: rumale-core
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - "~>"
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: 0.24.0
|
34
|
+
type: :runtime
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - "~>"
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: 0.24.0
|
41
|
+
description: Rumale::Tree provides classifier and regression based on decision tree
|
42
|
+
algorithms with Rumale interface.
|
43
|
+
email:
|
44
|
+
- yoshoku@outlook.com
|
45
|
+
executables: []
|
46
|
+
extensions:
|
47
|
+
- ext/rumale/tree/extconf.rb
|
48
|
+
extra_rdoc_files: []
|
49
|
+
files:
|
50
|
+
- LICENSE.txt
|
51
|
+
- README.md
|
52
|
+
- ext/rumale/tree/ext.c
|
53
|
+
- ext/rumale/tree/ext.h
|
54
|
+
- ext/rumale/tree/extconf.rb
|
55
|
+
- lib/rumale/tree.rb
|
56
|
+
- lib/rumale/tree/base_decision_tree.rb
|
57
|
+
- lib/rumale/tree/decision_tree_classifier.rb
|
58
|
+
- lib/rumale/tree/decision_tree_regressor.rb
|
59
|
+
- lib/rumale/tree/extra_tree_classifier.rb
|
60
|
+
- lib/rumale/tree/extra_tree_regressor.rb
|
61
|
+
- lib/rumale/tree/gradient_tree_regressor.rb
|
62
|
+
- lib/rumale/tree/node.rb
|
63
|
+
- lib/rumale/tree/version.rb
|
64
|
+
homepage: https://github.com/yoshoku/rumale
|
65
|
+
licenses:
|
66
|
+
- BSD-3-Clause
|
67
|
+
metadata:
|
68
|
+
homepage_uri: https://github.com/yoshoku/rumale
|
69
|
+
source_code_uri: https://github.com/yoshoku/rumale/tree/main/rumale-tree
|
70
|
+
changelog_uri: https://github.com/yoshoku/rumale/blob/main/CHANGELOG.md
|
71
|
+
documentation_uri: https://yoshoku.github.io/rumale/doc/
|
72
|
+
rubygems_mfa_required: 'true'
|
73
|
+
post_install_message:
|
74
|
+
rdoc_options: []
|
75
|
+
require_paths:
|
76
|
+
- lib
|
77
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
78
|
+
requirements:
|
79
|
+
- - ">="
|
80
|
+
- !ruby/object:Gem::Version
|
81
|
+
version: '0'
|
82
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
83
|
+
requirements:
|
84
|
+
- - ">="
|
85
|
+
- !ruby/object:Gem::Version
|
86
|
+
version: '0'
|
87
|
+
requirements: []
|
88
|
+
rubygems_version: 3.3.26
|
89
|
+
signing_key:
|
90
|
+
specification_version: 4
|
91
|
+
summary: Rumale::Tree provides classifier and regression based on decision tree algorithms
|
92
|
+
with Rumale interface.
|
93
|
+
test_files: []
|