rumale-ensemble 0.28.1 → 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 3ee0234ca988c60ca6ca08b57a985c5b135f371bd1150daf48b5549bfcda2e4a
4
- data.tar.gz: cecfc6af3713696bc299f3835de2a9e84a6c21a59dbb0f0b9a56244fb38b6018
3
+ metadata.gz: 52f48f439afecd4e75af580c46392ae4c2975b808c91a49f3781c78d48e8a43c
4
+ data.tar.gz: b5deb1e9736674d6db4ee733679e90a5b71cc45ffcb9d4e13b2d5956d66a82e7
5
5
  SHA512:
6
- metadata.gz: dab665e72aa587b613fa4dc551f015022648d69f2161bccef32cde6b27ba5891ef7f6c2d4a5ca2ba1f84bc6ea5940f097e8f50a90b8e4ddc6010bbca56555b55
7
- data.tar.gz: 7c4cfd1950e0dd1fb7bf615a59eba6493a7e4b6b4d3ccb381ee49d6c72a362b6d9499479c267f0c3e0b5fb027d925eee1f920a9656e041715be5606e4da8c27d
6
+ metadata.gz: e21818a828be87993169c1eefded133a355a49b29c3a6d39ce6ce1c5e7d3b54af36f70650d8502fa2eb716436f44fa904f6c653fd46069ca4af5c7edb750d890
7
+ data.tar.gz: 5cd5ee453ef7f86a71b097c4f755a7e4faf90ef89787ad0e15b727e745d9e0d4c2eeebcfdd23616b189a65ee4a569c3eb16be39b4ccae956a9616e396a5531a1
data/LICENSE.txt CHANGED
@@ -1,4 +1,4 @@
1
- Copyright (c) 2022-2023 Atsushi Tatsuma
1
+ Copyright (c) 2022-2024 Atsushi Tatsuma
2
2
  All rights reserved.
3
3
 
4
4
  Redistribution and use in source and binary forms, with or without
@@ -5,6 +5,6 @@ module Rumale
5
5
  # This module consists of the classes that implement ensemble-based methods.
6
6
  module Ensemble
7
7
  # @!visibility private
8
- VERSION = '0.28.1'
8
+ VERSION = '1.0.0'
9
9
  end
10
10
  end
@@ -0,0 +1,139 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+ require 'rumale/tree/vr_tree_classifier'
5
+ require 'rumale/ensemble/random_forest_classifier'
6
+ require 'rumale/ensemble/value'
7
+
8
+ module Rumale
9
+ module Ensemble
10
+ # VRTreesClassifier is a class that implements variable-random (VR) trees for classification.
11
+ #
12
+ # @example
13
+ # require 'rumale/ensemble/vr_trees_classifier'
14
+ #
15
+ # estimator =
16
+ # Rumale::Ensemble::VRTreesClassifier.new(
17
+ # n_estimators: 10, criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
18
+ # estimator.fit(training_samples, training_labels)
19
+ # results = estimator.predict(testing_samples)
20
+ #
21
+ # *Reference*
22
+ # - Liu, F. T., Ting, K. M., Yu, Y., and Zhou, Z. H., "Spectrum of Variable-Random Trees," Journal of Artificial Intelligence Research, vol. 32, pp. 355--384, 2008.
23
+ class VRTreesClassifier < RandomForestClassifier
24
+ # Return the set of estimators.
25
+ # @return [Array<VRTreeClassifier>]
26
+ attr_reader :estimators
27
+
28
+ # Return the class labels.
29
+ # @return [Numo::Int32] (size: n_classes)
30
+ attr_reader :classes
31
+
32
+ # Return the importance for each feature.
33
+ # @return [Numo::DFloat] (size: n_features)
34
+ attr_reader :feature_importances
35
+
36
+ # Return the random generator for random selection of feature index.
37
+ # @return [Random]
38
+ attr_reader :rng
39
+
40
+ # Create a new classifier with variable-random trees.
41
+ #
42
+ # @param n_estimators [Integer] The number of trees for constructing variable-random trees.
43
+ # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'gini' and 'entropy'.
44
+ # @param max_depth [Integer] The maximum depth of the tree.
45
+ # If nil is given, variable-random tree grows without concern for depth.
46
+ # @param max_leaf_nodes [Integer] The maximum number of leaves on variable-random tree.
47
+ # If nil is given, number of leaves is not limited.
48
+ # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
49
+ # @param max_features [Integer] The number of features to consider when searching optimal split point.
50
+ # If nil is given, split process considers 'n_features' features.
51
+ # @param n_jobs [Integer] The number of jobs for running the fit method in parallel.
52
+ # If nil is given, the method does not execute in parallel.
53
+ # If zero or less is given, it becomes equal to the number of processors.
54
+ # This parameter is ignored if the Parallel gem is not loaded.
55
+ # @param random_seed [Integer] The seed value used to initialize the random generator.
56
+ # It is used to randomly determine the order of features when deciding splitting point.
57
+ def initialize(n_estimators: 10,
58
+ criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
59
+ max_features: nil, n_jobs: nil, random_seed: nil)
60
+ super
61
+ end
62
+
63
+ # Fit the model with given training data.
64
+ #
65
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
66
+ # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
67
+ # @return [VRTreesClassifier] The learned classifier itself.
68
+ def fit(x, y)
69
+ x = ::Rumale::Validation.check_convert_sample_array(x)
70
+ y = ::Rumale::Validation.check_convert_label_array(y)
71
+ ::Rumale::Validation.check_sample_size(x, y)
72
+
73
+ # Initialize some variables.
74
+ n_features = x.shape[1]
75
+ @params[:max_features] = n_features if @params[:max_features].nil?
76
+ @params[:max_features] = @params[:max_features].clamp(1, n_features)
77
+ @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
78
+ sub_rng = @rng.dup
79
+ # Construct trees.
80
+ rng_seeds = Array.new(@params[:n_estimators]) { sub_rng.rand(::Rumale::Ensemble::Value::SEED_BASE) }
81
+ alpha_ratio = 0.5 / @params[:n_estimators]
82
+ alphas = Array.new(@params[:n_estimators]) { |v| v * alpha_ratio }
83
+ @estimators = if enable_parallel?
84
+ parallel_map(@params[:n_estimators]) { |n| plant_tree(alphas[n], rng_seeds[n]).fit(x, y) }
85
+ else
86
+ Array.new(@params[:n_estimators]) { |n| plant_tree(alphas[n], rng_seeds[n]).fit(x, y) }
87
+ end
88
+ @feature_importances =
89
+ if enable_parallel?
90
+ parallel_map(@params[:n_estimators]) { |n| @estimators[n].feature_importances }.sum
91
+ else
92
+ @estimators.sum(&:feature_importances)
93
+ end
94
+ @feature_importances /= @feature_importances.sum
95
+ self
96
+ end
97
+
98
+ # Predict class labels for samples.
99
+ #
100
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
101
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
102
+ def predict(x)
103
+ x = ::Rumale::Validation.check_convert_sample_array(x)
104
+
105
+ super
106
+ end
107
+
108
+ # Predict probability for samples.
109
+ #
110
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
111
+ # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
112
+ def predict_proba(x)
113
+ x = ::Rumale::Validation.check_convert_sample_array(x)
114
+
115
+ super
116
+ end
117
+
118
+ # Return the index of the leaf that each sample reached.
119
+ #
120
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
121
+ # @return [Numo::Int32] (shape: [n_samples, n_estimators]) Leaf index for sample.
122
+ def apply(x)
123
+ x = ::Rumale::Validation.check_convert_sample_array(x)
124
+
125
+ super
126
+ end
127
+
128
+ private
129
+
130
+ def plant_tree(alpha, rnd_seed)
131
+ ::Rumale::Tree::VRTreeClassifier.new(
132
+ criterion: @params[:criterion], alpha: alpha, max_depth: @params[:max_depth],
133
+ max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
134
+ max_features: @params[:max_features], random_seed: rnd_seed
135
+ )
136
+ end
137
+ end
138
+ end
139
+ end
@@ -0,0 +1,124 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/validation'
4
+ require 'rumale/tree/vr_tree_regressor'
5
+ require 'rumale/ensemble/random_forest_regressor'
6
+ require 'rumale/ensemble/value'
7
+
8
+ module Rumale
9
+ module Ensemble
10
+ # VRTreesRegressor is a class that implements variable-random (VR) trees for regression.
11
+ #
12
+ # @example
13
+ # require 'rumale/ensemble/vr_trees_regressor'
14
+ #
15
+ # estimator =
16
+ # Rumale::Ensemble::VRTreesRegressor.new(
17
+ # n_estimators: 10, criterion: 'mse', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
18
+ # estimator.fit(training_samples, training_values)
19
+ # results = estimator.predict(testing_samples)
20
+ #
21
+ # *Reference*
22
+ # - Liu, F. T., Ting, K. M., Yu, Y., and Zhou, Z. H., "Spectrum of Variable-Random Trees," Journal of Artificial Intelligence Research, vol. 32, pp. 355--384, 2008.
23
+ class VRTreesRegressor < RandomForestRegressor
24
+ # Return the set of estimators.
25
+ # @return [Array<VRTreeRegressor>]
26
+ attr_reader :estimators
27
+
28
+ # Return the importance for each feature.
29
+ # @return [Numo::DFloat] (size: n_features)
30
+ attr_reader :feature_importances
31
+
32
+ # Return the random generator for random selection of feature index.
33
+ # @return [Random]
34
+ attr_reader :rng
35
+
36
+ # Create a new regressor with variable-random trees.
37
+ #
38
+ # @param n_estimators [Integer] The number of trees for constructing variable-random trees.
39
+ # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'mae' and 'mse'.
40
+ # @param max_depth [Integer] The maximum depth of the tree.
41
+ # If nil is given, variable-random tree grows without concern for depth.
42
+ # @param max_leaf_nodes [Integer] The maximum number of leaves on variable-random tree.
43
+ # If nil is given, number of leaves is not limited.
44
+ # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
45
+ # @param max_features [Integer] The number of features to consider when searching optimal split point.
46
+ # If nil is given, split process considers 'n_features' features.
47
+ # @param n_jobs [Integer] The number of jobs for running the fit and predict methods in parallel.
48
+ # If nil is given, the methods do not execute in parallel.
49
+ # If zero or less is given, it becomes equal to the number of processors.
50
+ # This parameter is ignored if the Parallel gem is not loaded.
51
+ # @param random_seed [Integer] The seed value used to initialize the random generator.
52
+ # It is used to randomly determine the order of features when deciding splitting point.
53
+ def initialize(n_estimators: 10,
54
+ criterion: 'mse', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
55
+ max_features: nil, n_jobs: nil, random_seed: nil)
56
+ super
57
+ end
58
+
59
+ # Fit the model with given training data.
60
+ #
61
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
62
+ # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
63
+ # @return [VRTreesRegressor] The learned regressor itself.
64
+ def fit(x, y)
65
+ x = ::Rumale::Validation.check_convert_sample_array(x)
66
+ y = ::Rumale::Validation.check_convert_target_value_array(y)
67
+ ::Rumale::Validation.check_sample_size(x, y)
68
+
69
+ # Initialize some variables.
70
+ n_features = x.shape[1]
71
+ @params[:max_features] = n_features if @params[:max_features].nil?
72
+ @params[:max_features] = @params[:max_features].clamp(1, n_features)
73
+ sub_rng = @rng.dup
74
+ # Construct forest.
75
+ rng_seeds = Array.new(@params[:n_estimators]) { sub_rng.rand(::Rumale::Ensemble::Value::SEED_BASE) }
76
+ alpha_ratio = 0.5 / @params[:n_estimators]
77
+ alphas = Array.new(@params[:n_estimators]) { |v| v * alpha_ratio }
78
+ @estimators = if enable_parallel?
79
+ parallel_map(@params[:n_estimators]) { |n| plant_tree(alphas[n], rng_seeds[n]).fit(x, y) }
80
+ else
81
+ Array.new(@params[:n_estimators]) { |n| plant_tree(alphas[n], rng_seeds[n]).fit(x, y) }
82
+ end
83
+ @feature_importances =
84
+ if enable_parallel?
85
+ parallel_map(@params[:n_estimators]) { |n| @estimators[n].feature_importances }.sum
86
+ else
87
+ @estimators.sum(&:feature_importances)
88
+ end
89
+ @feature_importances /= @feature_importances.sum
90
+ self
91
+ end
92
+
93
+ # Predict values for samples.
94
+ #
95
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
96
+ # @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted value per sample.
97
+ def predict(x)
98
+ x = ::Rumale::Validation.check_convert_sample_array(x)
99
+
100
+ super
101
+ end
102
+
103
+ # Return the index of the leaf that each sample reached.
104
+ #
105
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to assign each leaf.
106
+ # @return [Numo::Int32] (shape: [n_samples, n_estimators]) Leaf index for sample.
107
+ def apply(x)
108
+ x = ::Rumale::Validation.check_convert_sample_array(x)
109
+
110
+ super
111
+ end
112
+
113
+ private
114
+
115
+ def plant_tree(alpha, rnd_seed)
116
+ ::Rumale::Tree::VRTreeRegressor.new(
117
+ criterion: @params[:criterion], alpha: alpha, max_depth: @params[:max_depth],
118
+ max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
119
+ max_features: @params[:max_features], random_seed: rnd_seed
120
+ )
121
+ end
122
+ end
123
+ end
124
+ end
@@ -18,3 +18,5 @@ require_relative 'ensemble/stacking_classifier'
18
18
  require_relative 'ensemble/stacking_regressor'
19
19
  require_relative 'ensemble/voting_classifier'
20
20
  require_relative 'ensemble/voting_regressor'
21
+ require_relative 'ensemble/vr_trees_classifier'
22
+ require_relative 'ensemble/vr_trees_regressor'
metadata CHANGED
@@ -1,14 +1,13 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-ensemble
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.28.1
4
+ version: 1.0.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
- autorequire:
9
8
  bindir: exe
10
9
  cert_chain: []
11
- date: 2023-12-24 00:00:00.000000000 Z
10
+ date: 2025-01-02 00:00:00.000000000 Z
12
11
  dependencies:
13
12
  - !ruby/object:Gem::Dependency
14
13
  name: numo-narray
@@ -30,70 +29,70 @@ dependencies:
30
29
  requirements:
31
30
  - - "~>"
32
31
  - !ruby/object:Gem::Version
33
- version: 0.28.1
32
+ version: 1.0.0
34
33
  type: :runtime
35
34
  prerelease: false
36
35
  version_requirements: !ruby/object:Gem::Requirement
37
36
  requirements:
38
37
  - - "~>"
39
38
  - !ruby/object:Gem::Version
40
- version: 0.28.1
39
+ version: 1.0.0
41
40
  - !ruby/object:Gem::Dependency
42
41
  name: rumale-linear_model
43
42
  requirement: !ruby/object:Gem::Requirement
44
43
  requirements:
45
44
  - - "~>"
46
45
  - !ruby/object:Gem::Version
47
- version: 0.28.1
46
+ version: 1.0.0
48
47
  type: :runtime
49
48
  prerelease: false
50
49
  version_requirements: !ruby/object:Gem::Requirement
51
50
  requirements:
52
51
  - - "~>"
53
52
  - !ruby/object:Gem::Version
54
- version: 0.28.1
53
+ version: 1.0.0
55
54
  - !ruby/object:Gem::Dependency
56
55
  name: rumale-model_selection
57
56
  requirement: !ruby/object:Gem::Requirement
58
57
  requirements:
59
58
  - - "~>"
60
59
  - !ruby/object:Gem::Version
61
- version: 0.28.1
60
+ version: 1.0.0
62
61
  type: :runtime
63
62
  prerelease: false
64
63
  version_requirements: !ruby/object:Gem::Requirement
65
64
  requirements:
66
65
  - - "~>"
67
66
  - !ruby/object:Gem::Version
68
- version: 0.28.1
67
+ version: 1.0.0
69
68
  - !ruby/object:Gem::Dependency
70
69
  name: rumale-preprocessing
71
70
  requirement: !ruby/object:Gem::Requirement
72
71
  requirements:
73
72
  - - "~>"
74
73
  - !ruby/object:Gem::Version
75
- version: 0.28.1
74
+ version: 1.0.0
76
75
  type: :runtime
77
76
  prerelease: false
78
77
  version_requirements: !ruby/object:Gem::Requirement
79
78
  requirements:
80
79
  - - "~>"
81
80
  - !ruby/object:Gem::Version
82
- version: 0.28.1
81
+ version: 1.0.0
83
82
  - !ruby/object:Gem::Dependency
84
83
  name: rumale-tree
85
84
  requirement: !ruby/object:Gem::Requirement
86
85
  requirements:
87
86
  - - "~>"
88
87
  - !ruby/object:Gem::Version
89
- version: 0.28.1
88
+ version: 1.0.0
90
89
  type: :runtime
91
90
  prerelease: false
92
91
  version_requirements: !ruby/object:Gem::Requirement
93
92
  requirements:
94
93
  - - "~>"
95
94
  - !ruby/object:Gem::Version
96
- version: 0.28.1
95
+ version: 1.0.0
97
96
  description: |
98
97
  Rumale::Ensemble provides ensemble learning algorithms,
99
98
  such as AdaBoost, Gradient Tree Boosting, and Random Forest,
@@ -121,6 +120,8 @@ files:
121
120
  - lib/rumale/ensemble/version.rb
122
121
  - lib/rumale/ensemble/voting_classifier.rb
123
122
  - lib/rumale/ensemble/voting_regressor.rb
123
+ - lib/rumale/ensemble/vr_trees_classifier.rb
124
+ - lib/rumale/ensemble/vr_trees_regressor.rb
124
125
  homepage: https://github.com/yoshoku/rumale
125
126
  licenses:
126
127
  - BSD-3-Clause
@@ -130,7 +131,6 @@ metadata:
130
131
  changelog_uri: https://github.com/yoshoku/rumale/blob/main/CHANGELOG.md
131
132
  documentation_uri: https://yoshoku.github.io/rumale/doc/
132
133
  rubygems_mfa_required: 'true'
133
- post_install_message:
134
134
  rdoc_options: []
135
135
  require_paths:
136
136
  - lib
@@ -145,8 +145,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
145
145
  - !ruby/object:Gem::Version
146
146
  version: '0'
147
147
  requirements: []
148
- rubygems_version: 3.4.22
149
- signing_key:
148
+ rubygems_version: 3.6.2
150
149
  specification_version: 4
151
150
  summary: Rumale::Ensemble provides ensemble learning algorithms with Rumale interface.
152
151
  test_files: []