rumale-tree 0.29.0 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/rumale/tree/version.rb +1 -1
- data/lib/rumale/tree/vr_tree_classifier.rb +98 -0
- data/lib/rumale/tree/vr_tree_regressor.rb +89 -0
- data/lib/rumale/tree.rb +2 -0
- metadata +7 -8
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: fe371a348b02a2b18cb90639fb64af902f76bfa0c2f57bee235cdad12e183584
|
4
|
+
data.tar.gz: a2ffee131bc50ea0f850ac1fcbc5ca29be81d101f6a175a94e4756db4d19dede
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 2b73e0d16c2e63b23ea42ed6235aa22b49bf014e9187d1faba3d73642b1c2e1a203c6cfa2791c6e31be0d94f432b68764f00350dbf77e54d2ec1efd04930518e
|
7
|
+
data.tar.gz: d025cbffcadcc38b904f3afe9c17d0dd8e33347345d9aac02a10d052653abb42ee362819f4a50eb267d40ed5610242139fdc5e4e0c2318ef276cbbe3d44dc4da
|
data/lib/rumale/tree/vr_tree_classifier.rb
CHANGED
@@ -0,0 +1,98 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/tree/decision_tree_classifier'
|
4
|
+
|
5
|
+
module Rumale
  module Tree
    # VRTreeClassifier is a class that implements Variable-Random (VR) tree for classification.
    #
    # @example
    #   require 'rumale/tree/vr_tree_classifier'
    #
    #   estimator =
    #     Rumale::Tree::VRTreeClassifier.new(
    #       criterion: 'gini', alpha: 0.5, max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
    #   estimator.fit(training_samples, training_labels)
    #   results = estimator.predict(testing_samples)
    #
    # *Reference*
    # - Liu, F. T., Ting, K. M., Yu, Y., and Zhou, Z. H., "Spectrum of Variable-Random Trees," Journal of Artificial Intelligence Research, vol. 32, pp. 355--384, 2008.
    class VRTreeClassifier < DecisionTreeClassifier
      # Return the class labels.
      # @return [Numo::Int32] (size: n_classes)
      attr_reader :classes

      # Return the importance for each feature.
      # @return [Numo::DFloat] (size: n_features)
      attr_reader :feature_importances

      # Return the learned tree.
      # @return [Node]
      attr_reader :tree

      # Return the random generator for random selection of feature index.
      # @return [Random]
      attr_reader :rng

      # Return the labels assigned to each leaf.
      # @return [Numo::Int32] (size: n_leaves)
      attr_reader :leaf_labels

      # Create a new classifier with variable-random tree algorithm.
      #
      # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'gini' and 'entropy'.
      # @param alpha [Float] The probability of choosing a deterministic splitting point instead of a random one.
      #   The value is clamped to the range [0.0, 1.0].
      #   If 1.0 is given, the tree is the same as the normal decision tree;
      #   if 0.0 is given, every splitting point is chosen randomly.
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, variable-random tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on variable-random tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, split process considers all features.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      #   It is used to randomly determine the order of features when deciding splitting point.
      def initialize(criterion: 'gini', alpha: 0.5, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
                     random_seed: nil)
        super(criterion: criterion, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
              max_features: max_features, random_seed: random_seed)
        # Out-of-range values are clamped rather than rejected.
        @params[:alpha] = alpha.clamp(0.0, 1.0)
      end

      # @!method fit(x, y)
      #   Fit the model with given training data.
      #
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      #   @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      #   @return [VRTreeClassifier] The learned classifier itself.

      # @!method predict(x)
      #   Predict class labels for samples.
      #
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      #   @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.

      # @!method predict_proba(x)
      #   Predict probability for samples.
      #
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
      #   @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.

      private

      # Determine the splitting point for a set of feature values.
      #
      # A number r in (0, 1] is drawn. If r <= alpha, the deterministic split computed by the
      # parent class is returned. Otherwise two distinct feature values are sampled at random,
      # and the threshold is placed between them.
      # NOTE(review): the draw happens on every call — presumably once per candidate feature
      # rather than once per node; TODO confirm against the parent's split search.
      #
      # @param features [Numo::DFloat] feature values of the samples being split.
      # @param y [Numo::Int32] labels of the same samples (indexed as a 1-D array here).
      # @param whole_impurity [Float] impurity of the samples before splitting.
      # @return [Array] [left impurity, right impurity, threshold, gain]
      def best_split(features, y, whole_impurity)
        r = -@sub_rng.rand(-1.0...0.0) # generate random number with (0, 1]
        return super if r <= @params[:alpha]

        fa, fb = features.to_a.uniq.sample(2, random: @sub_rng)
        fb = fa if fb.nil? # only one distinct value: no real split is possible
        lo, hi = fa < fb ? [fa, fb] : [fb, fa]
        # Halve before adding so that large same-sign values cannot overflow to Infinity.
        threshold = 0.5 * lo + 0.5 * hi
        # Between adjacent floats the midpoint may round onto hi; use lo so that hi stays on the right side.
        threshold = lo if threshold < lo || threshold >= hi
        l_ids = features.le(threshold).where
        r_ids = features.gt(threshold).where
        l_impurity = l_ids.empty? ? 0.0 : impurity(y[l_ids])
        r_impurity = r_ids.empty? ? 0.0 : impurity(y[r_ids])
        gain = whole_impurity -
               l_impurity * l_ids.size.fdiv(y.size) -
               r_impurity * r_ids.size.fdiv(y.size)
        [l_impurity, r_impurity, threshold, gain]
      end
    end
  end
end
|
data/lib/rumale/tree/vr_tree_regressor.rb
CHANGED
@@ -0,0 +1,89 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/tree/decision_tree_regressor'
|
4
|
+
|
5
|
+
module Rumale
  module Tree
    # VRTreeRegressor is a class that implements Variable-Random (VR) tree for regression.
    #
    # @example
    #   require 'rumale/tree/vr_tree_regressor'
    #
    #   estimator =
    #     Rumale::Tree::VRTreeRegressor.new(
    #       criterion: 'mse', alpha: 0.5, max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
    #   estimator.fit(training_samples, training_values)
    #   results = estimator.predict(testing_samples)
    #
    # *Reference*
    # - Liu, F. T., Ting, K. M., Yu, Y., and Zhou, Z. H., "Spectrum of Variable-Random Trees," Journal of Artificial Intelligence Research, vol. 32, pp. 355--384, 2008.
    class VRTreeRegressor < DecisionTreeRegressor
      # Return the importance for each feature.
      # @return [Numo::DFloat] (size: n_features)
      attr_reader :feature_importances

      # Return the learned tree.
      # @return [Node]
      attr_reader :tree

      # Return the random generator for random selection of feature index.
      # @return [Random]
      attr_reader :rng

      # Return the values assigned to each leaf.
      # @return [Numo::DFloat] (shape: [n_leaves, n_outputs])
      attr_reader :leaf_values

      # Create a new regressor with variable-random tree algorithm.
      #
      # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'mae' and 'mse'.
      # @param alpha [Float] The probability of choosing a deterministic splitting point instead of a random one.
      #   The value is clamped to the range [0.0, 1.0].
      #   If 1.0 is given, the tree is the same as the normal decision tree;
      #   if 0.0 is given, every splitting point is chosen randomly.
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, variable-random tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on variable-random tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, split process considers all features.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      #   It is used to randomly determine the order of features when deciding splitting point.
      def initialize(criterion: 'mse', alpha: 0.5, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
                     random_seed: nil)
        super(criterion: criterion, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
              max_features: max_features, random_seed: random_seed)
        # Out-of-range values are clamped rather than rejected.
        @params[:alpha] = alpha.clamp(0.0, 1.0)
      end

      # @!method fit(x, y)
      #   Fit the model with given training data.
      #
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      #   @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
      #   @return [VRTreeRegressor] The learned regressor itself.

      # @!method predict(x)
      #   Predict values for samples.
      #
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
      #   @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.

      private

      # Determine the splitting point for a set of feature values.
      #
      # A number r in (0, 1] is drawn. If r <= alpha, the deterministic split computed by the
      # parent class is returned. Otherwise two distinct feature values are sampled at random,
      # and the threshold is placed between them.
      # NOTE(review): the draw happens on every call — presumably once per candidate feature
      # rather than once per node; TODO confirm against the parent's split search.
      #
      # @param features [Numo::DFloat] feature values of the samples being split.
      # @param y [Numo::DFloat] target values of the same samples (indexed as a 2-D array here).
      # @param whole_impurity [Float] impurity of the samples before splitting.
      # @return [Array] [left impurity, right impurity, threshold, gain]
      def best_split(features, y, whole_impurity)
        r = -@sub_rng.rand(-1.0...0.0) # generate random number with (0, 1]
        return super if r <= @params[:alpha]

        fa, fb = features.to_a.uniq.sample(2, random: @sub_rng)
        fb = fa if fb.nil? # only one distinct value: no real split is possible
        lo, hi = fa < fb ? [fa, fb] : [fb, fa]
        # Halve before adding so that large same-sign values cannot overflow to Infinity.
        threshold = 0.5 * lo + 0.5 * hi
        # Between adjacent floats the midpoint may round onto hi; use lo so that hi stays on the right side.
        threshold = lo if threshold < lo || threshold >= hi
        l_ids = features.le(threshold).where
        r_ids = features.gt(threshold).where
        l_impurity = l_ids.empty? ? 0.0 : impurity(y[l_ids, true])
        r_impurity = r_ids.empty? ? 0.0 : impurity(y[r_ids, true])
        gain = whole_impurity -
               l_impurity * l_ids.size.fdiv(y.shape[0]) -
               r_impurity * r_ids.size.fdiv(y.shape[0])
        [l_impurity, r_impurity, threshold, gain]
      end
    end
  end
end
|
data/lib/rumale/tree.rb
CHANGED
@@ -9,3 +9,5 @@ require_relative 'tree/decision_tree_regressor'
|
|
9
9
|
require_relative 'tree/extra_tree_classifier'
|
10
10
|
require_relative 'tree/extra_tree_regressor'
|
11
11
|
require_relative 'tree/gradient_tree_regressor'
|
12
|
+
require_relative 'tree/vr_tree_classifier'
|
13
|
+
require_relative 'tree/vr_tree_regressor'
|
metadata
CHANGED
@@ -1,14 +1,13 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-tree
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 1.0.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
|
-
autorequire:
|
9
8
|
bindir: exe
|
10
9
|
cert_chain: []
|
11
|
-
date:
|
10
|
+
date: 2025-01-02 00:00:00.000000000 Z
|
12
11
|
dependencies:
|
13
12
|
- !ruby/object:Gem::Dependency
|
14
13
|
name: numo-narray
|
@@ -30,14 +29,14 @@ dependencies:
|
|
30
29
|
requirements:
|
31
30
|
- - "~>"
|
32
31
|
- !ruby/object:Gem::Version
|
33
|
-
version: 0.
|
32
|
+
version: 1.0.0
|
34
33
|
type: :runtime
|
35
34
|
prerelease: false
|
36
35
|
version_requirements: !ruby/object:Gem::Requirement
|
37
36
|
requirements:
|
38
37
|
- - "~>"
|
39
38
|
- !ruby/object:Gem::Version
|
40
|
-
version: 0.
|
39
|
+
version: 1.0.0
|
41
40
|
description: Rumale::Tree provides classifier and regression based on decision tree
|
42
41
|
algorithms with Rumale interface.
|
43
42
|
email:
|
@@ -61,6 +60,8 @@ files:
|
|
61
60
|
- lib/rumale/tree/gradient_tree_regressor.rb
|
62
61
|
- lib/rumale/tree/node.rb
|
63
62
|
- lib/rumale/tree/version.rb
|
63
|
+
- lib/rumale/tree/vr_tree_classifier.rb
|
64
|
+
- lib/rumale/tree/vr_tree_regressor.rb
|
64
65
|
homepage: https://github.com/yoshoku/rumale
|
65
66
|
licenses:
|
66
67
|
- BSD-3-Clause
|
@@ -70,7 +71,6 @@ metadata:
|
|
70
71
|
changelog_uri: https://github.com/yoshoku/rumale/blob/main/CHANGELOG.md
|
71
72
|
documentation_uri: https://yoshoku.github.io/rumale/doc/
|
72
73
|
rubygems_mfa_required: 'true'
|
73
|
-
post_install_message:
|
74
74
|
rdoc_options: []
|
75
75
|
require_paths:
|
76
76
|
- lib
|
@@ -85,8 +85,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
85
85
|
- !ruby/object:Gem::Version
|
86
86
|
version: '0'
|
87
87
|
requirements: []
|
88
|
-
rubygems_version: 3.
|
89
|
-
signing_key:
|
88
|
+
rubygems_version: 3.6.2
|
90
89
|
specification_version: 4
|
91
90
|
summary: Rumale::Tree provides classifier and regression based on decision tree algorithms
|
92
91
|
with Rumale interface.
|