rumale-tree 0.29.0 → 1.0.0
- checksums.yaml +4 -4
- data/lib/rumale/tree/version.rb +1 -1
- data/lib/rumale/tree/vr_tree_classifier.rb +98 -0
- data/lib/rumale/tree/vr_tree_regressor.rb +89 -0
- data/lib/rumale/tree.rb +2 -0
- metadata +7 -8
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: fe371a348b02a2b18cb90639fb64af902f76bfa0c2f57bee235cdad12e183584
+  data.tar.gz: a2ffee131bc50ea0f850ac1fcbc5ca29be81d101f6a175a94e4756db4d19dede
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 2b73e0d16c2e63b23ea42ed6235aa22b49bf014e9187d1faba3d73642b1c2e1a203c6cfa2791c6e31be0d94f432b68764f00350dbf77e54d2ec1efd04930518e
+  data.tar.gz: d025cbffcadcc38b904f3afe9c17d0dd8e33347345d9aac02a10d052653abb42ee362819f4a50eb267d40ed5610242139fdc5e4e0c2318ef276cbbe3d44dc4da
data/lib/rumale/tree/version.rb
CHANGED
-    VERSION = '0.29.0'
+    VERSION = '1.0.0'

data/lib/rumale/tree/vr_tree_classifier.rb
ADDED
@@ -0,0 +1,98 @@
+# frozen_string_literal: true
+
+require 'rumale/tree/decision_tree_classifier'
+
+module Rumale
+  module Tree
+    # VRTreeClassifier is a class that implements Variable-Random (VR) tree for classification.
+    #
+    # @example
+    #   require 'rumale/tree/vr_tree_classifier'
+    #
+    #   estimator =
+    #     Rumale::Tree::VRTreeClassifier.new(
+    #       criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
+    #   estimator.fit(training_samples, training_labels)
+    #   results = estimator.predict(testing_samples)
+    #
+    # *Reference*
+    # - Liu, F. T., Ting, K. M., Yu, Y., and Zhou, Z. H., "Spectrum of Variable-Random Trees," Journal of Artificial Intelligence Research, vol. 32, pp. 355--384, 2008.
+    class VRTreeClassifier < DecisionTreeClassifier
+      # Return the class labels.
+      # @return [Numo::Int32] (size: n_classes)
+      attr_reader :classes
+
+      # Return the importance for each feature.
+      # @return [Numo::DFloat] (size: n_features)
+      attr_reader :feature_importances
+
+      # Return the learned tree.
+      # @return [Node]
+      attr_reader :tree
+
+      # Return the random generator for random selection of feature index.
+      # @return [Random]
+      attr_reader :rng
+
+      # Return the labels assigned to each leaf.
+      # @return [Numo::Int32] (size: n_leafs)
+      attr_reader :leaf_labels
+
+      # Create a new classifier with variable-random tree algorithm.
+      #
+      # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'gini' and 'entropy'.
+      # @param alpha [Float] The probability of choosing a deterministic splitting point rather than a random one.
+      #   If 1.0 is given, the tree is the same as the normal decision tree.
+      # @param max_depth [Integer] The maximum depth of the tree.
+      #   If nil is given, variable-random tree grows without concern for depth.
+      # @param max_leaf_nodes [Integer] The maximum number of leaves on variable-random tree.
+      #   If nil is given, number of leaves is not limited.
+      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
+      # @param max_features [Integer] The number of features to consider when searching optimal split point.
+      #   If nil is given, split process considers all features.
+      # @param random_seed [Integer] The seed value used to initialize the random generator.
+      #   It is used to randomly determine the order of features when deciding splitting point.
+      def initialize(criterion: 'gini', alpha: 0.5, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
+                     random_seed: nil)
+        super(criterion: criterion, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
+              max_features: max_features, random_seed: random_seed)
+        @params[:alpha] = alpha.clamp(0.0, 1.0)
+      end
+
+      # Fit the model with given training data.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+      # @return [VRTreeClassifier] The learned classifier itself.
+
+      # Predict class labels for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
+      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
+
+      # Predict probability for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
+      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
+
+      private
+
+      def best_split(features, y, whole_impurity)
+        r = -@sub_rng.rand(-1.0...0.0) # generate a random number in (0, 1]
+        return super if r <= @params[:alpha]
+
+        fa, fb = features.to_a.uniq.sample(2, random: @sub_rng)
+        fb = fa if fb.nil?
+        threshold = 0.5 * (fa + fb)
+        l_ids = features.le(threshold).where
+        r_ids = features.gt(threshold).where
+        l_impurity = l_ids.empty? ? 0.0 : impurity(y[l_ids])
+        r_impurity = r_ids.empty? ? 0.0 : impurity(y[r_ids])
+        gain = whole_impurity -
+               l_impurity * l_ids.size.fdiv(y.size) -
+               r_impurity * r_ids.size.fdiv(y.size)
+        [l_impurity, r_impurity, threshold, gain]
+      end
+    end
+  end
+end
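The random branch of best_split above is the heart of the VR tree: with probability alpha it falls back to the deterministic split search of DecisionTreeClassifier, and otherwise it takes the midpoint of two distinct feature values as the threshold. The following plain-Ruby sketch only illustrates that threshold rule on a toy array; the rng and feature_column names are illustrative and not part of the gem's API.

# Illustration only: the random-threshold rule used by best_split when the
# deterministic branch is skipped, applied to a plain Ruby Array instead of
# a Numo::DFloat feature column.
rng = Random.new(1)
feature_column = [3.0, 1.5, 1.5, 4.2, 2.8]

fa, fb = feature_column.uniq.sample(2, random: rng)
fb = fa if fb.nil? # constant column: fall back to a single value
threshold = 0.5 * (fa + fb)

left  = feature_column.select { |v| v <= threshold }
right = feature_column.reject { |v| v <= threshold }
puts "threshold=#{threshold} left=#{left.inspect} right=#{right.inspect}"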
data/lib/rumale/tree/vr_tree_regressor.rb
ADDED
@@ -0,0 +1,89 @@
+# frozen_string_literal: true
+
+require 'rumale/tree/decision_tree_regressor'
+
+module Rumale
+  module Tree
+    # VRTreeRegressor is a class that implements Variable-Random (VR) tree for regression.
+    #
+    # @example
+    #   require 'rumale/tree/vr_tree_regressor'
+    #
+    #   estimator =
+    #     Rumale::Tree::VRTreeRegressor.new(
+    #       max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
+    #   estimator.fit(training_samples, training_values)
+    #   results = estimator.predict(testing_samples)
+    #
+    # *Reference*
+    # - Liu, F. T., Ting, K. M., Yu, Y., and Zhou, Z. H., "Spectrum of Variable-Random Trees," Journal of Artificial Intelligence Research, vol. 32, pp. 355--384, 2008.
+    class VRTreeRegressor < DecisionTreeRegressor
+      # Return the importance for each feature.
+      # @return [Numo::DFloat] (size: n_features)
+      attr_reader :feature_importances
+
+      # Return the learned tree.
+      # @return [Node]
+      attr_reader :tree
+
+      # Return the random generator for random selection of feature index.
+      # @return [Random]
+      attr_reader :rng
+
+      # Return the values assigned to each leaf.
+      # @return [Numo::DFloat] (shape: [n_leafs, n_outputs])
+      attr_reader :leaf_values
+
+      # Create a new regressor with variable-random tree algorithm.
+      #
+      # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'mae' and 'mse'.
+      # @param alpha [Float] The probability of choosing a deterministic splitting point rather than a random one.
+      #   If 1.0 is given, the tree is the same as the normal decision tree.
+      # @param max_depth [Integer] The maximum depth of the tree.
+      #   If nil is given, variable-random tree grows without concern for depth.
+      # @param max_leaf_nodes [Integer] The maximum number of leaves on variable-random tree.
+      #   If nil is given, number of leaves is not limited.
+      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
+      # @param max_features [Integer] The number of features to consider when searching optimal split point.
+      #   If nil is given, split process considers all features.
+      # @param random_seed [Integer] The seed value used to initialize the random generator.
+      #   It is used to randomly determine the order of features when deciding splitting point.
+      def initialize(criterion: 'mse', alpha: 0.5, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
+                     random_seed: nil)
+        super(criterion: criterion, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
+              max_features: max_features, random_seed: random_seed)
+        @params[:alpha] = alpha.clamp(0.0, 1.0)
+      end
+
+      # Fit the model with given training data.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
+      # @return [VRTreeRegressor] The learned regressor itself.
+
+      # Predict values for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
+      # @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.
+
+      private
+
+      def best_split(features, y, whole_impurity)
+        r = -@sub_rng.rand(-1.0...0.0) # generate a random number in (0, 1]
+        return super if r <= @params[:alpha]
+
+        fa, fb = features.to_a.uniq.sample(2, random: @sub_rng)
+        fb = fa if fb.nil?
+        threshold = 0.5 * (fa + fb)
+        l_ids = features.le(threshold).where
+        r_ids = features.gt(threshold).where
+        l_impurity = l_ids.empty? ? 0.0 : impurity(y[l_ids, true])
+        r_impurity = r_ids.empty? ? 0.0 : impurity(y[r_ids, true])
+        gain = whole_impurity -
+               l_impurity * l_ids.size.fdiv(y.shape[0]) -
+               r_impurity * r_ids.size.fdiv(y.shape[0])
+        [l_impurity, r_impurity, threshold, gain]
+      end
+    end
+  end
+end
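For context, here is a minimal end-to-end sketch of the regressor added above. It assumes rumale-tree 1.0.0 and numo-narray are installed; the synthetic data and the hyperparameter values are illustrative choices, not taken from the gem.

require 'numo/narray'
require 'rumale/tree/vr_tree_regressor'

# Synthetic 1-D regression data: y is roughly 2 * x with small noise.
data_rng = Random.new(42)
x = Numo::DFloat[*Array.new(100) { [data_rng.rand] }] # shape [100, 1]
y = 2.0 * x[true, 0] + Numo::DFloat[*Array.new(100) { data_rng.rand * 0.1 }]

# alpha: 0.3 means roughly 30% deterministic splits and 70% random splits.
estimator = Rumale::Tree::VRTreeRegressor.new(
  criterion: 'mse', alpha: 0.3, max_depth: 4, min_samples_leaf: 5, random_seed: 1
)
estimator.fit(x, y)
puts estimator.predict(x)[0...5].to_a.inspect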
data/lib/rumale/tree.rb
CHANGED
@@ -9,3 +9,5 @@ require_relative 'tree/decision_tree_regressor'
 require_relative 'tree/extra_tree_classifier'
 require_relative 'tree/extra_tree_regressor'
 require_relative 'tree/gradient_tree_regressor'
+require_relative 'tree/vr_tree_classifier'
+require_relative 'tree/vr_tree_regressor'
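With the two require_relative lines added above, loading the umbrella file is enough to reach the new estimators. A quick sanity check, assuming rumale-tree 1.0.0 is installed:

require 'rumale/tree'

# Both VR tree classes are defined once the umbrella file is required,
# and each inherits from the corresponding decision tree estimator.
p Rumale::Tree::VRTreeClassifier.superclass # => Rumale::Tree::DecisionTreeClassifier
p Rumale::Tree::VRTreeRegressor.superclass  # => Rumale::Tree::DecisionTreeRegressor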
metadata
CHANGED
@@ -1,14 +1,13 @@
 --- !ruby/object:Gem::Specification
 name: rumale-tree
 version: !ruby/object:Gem::Version
-  version: 0.29.0
+  version: 1.0.0
 platform: ruby
 authors:
 - yoshoku
-autorequire:
 bindir: exe
 cert_chain: []
-date:
+date: 2025-01-02 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: numo-narray
@@ -30,14 +29,14 @@ dependencies:
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: 0.
+        version: 1.0.0
   type: :runtime
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: 0.
+        version: 1.0.0
 description: Rumale::Tree provides classifier and regression based on decision tree
   algorithms with Rumale interface.
 email:
@@ -61,6 +60,8 @@ files:
 - lib/rumale/tree/gradient_tree_regressor.rb
 - lib/rumale/tree/node.rb
 - lib/rumale/tree/version.rb
+- lib/rumale/tree/vr_tree_classifier.rb
+- lib/rumale/tree/vr_tree_regressor.rb
 homepage: https://github.com/yoshoku/rumale
 licenses:
 - BSD-3-Clause
@@ -70,7 +71,6 @@ metadata:
   changelog_uri: https://github.com/yoshoku/rumale/blob/main/CHANGELOG.md
   documentation_uri: https://yoshoku.github.io/rumale/doc/
   rubygems_mfa_required: 'true'
-post_install_message:
 rdoc_options: []
 require_paths:
 - lib
@@ -85,8 +85,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
   - !ruby/object:Gem::Version
     version: '0'
 requirements: []
-rubygems_version: 3.
-signing_key:
+rubygems_version: 3.6.2
 specification_version: 4
 summary: Rumale::Tree provides classifier and regression based on decision tree algorithms
   with Rumale interface.