hybridforest 0.9.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +89 -0
- data/.rspec +3 -0
- data/.rubocop.yml +13 -0
- data/CHANGELOG.md +33 -0
- data/Gemfile +22 -0
- data/Gemfile.lock +92 -0
- data/LICENSE.txt +21 -0
- data/README.md +96 -0
- data/Rakefile +12 -0
- data/bin/console +15 -0
- data/bin/setup +8 -0
- data/hybridforest.gemspec +44 -0
- data/lib/hybridforest/errors/invalid_state_error.rb +9 -0
- data/lib/hybridforest/forests/forest_growers/cart_grower.rb +19 -0
- data/lib/hybridforest/forests/forest_growers/hybrid_grower.rb +46 -0
- data/lib/hybridforest/forests/forest_growers/id3_grower.rb +19 -0
- data/lib/hybridforest/forests/grower_factory.rb +29 -0
- data/lib/hybridforest/forests/random_forest.rb +84 -0
- data/lib/hybridforest/trees/cart_tree.rb +18 -0
- data/lib/hybridforest/trees/feature_selectors/all_features.rb +14 -0
- data/lib/hybridforest/trees/feature_selectors/max_one_split_per_feature.rb +21 -0
- data/lib/hybridforest/trees/feature_selectors/random_feature_subspace.rb +27 -0
- data/lib/hybridforest/trees/id3_tree.rb +18 -0
- data/lib/hybridforest/trees/impurity_metrics/entropy.rb +21 -0
- data/lib/hybridforest/trees/impurity_metrics/gini_impurity.rb +17 -0
- data/lib/hybridforest/trees/impurity_metrics/impurity.rb +28 -0
- data/lib/hybridforest/trees/nodes/binary_node.rb +37 -0
- data/lib/hybridforest/trees/nodes/leaf_node.rb +31 -0
- data/lib/hybridforest/trees/nodes/multiway_node.rb +45 -0
- data/lib/hybridforest/trees/split.rb +30 -0
- data/lib/hybridforest/trees/tests/equal.rb +37 -0
- data/lib/hybridforest/trees/tests/equal_or_greater.rb +37 -0
- data/lib/hybridforest/trees/tests/less.rb +37 -0
- data/lib/hybridforest/trees/tests/not_equal.rb +37 -0
- data/lib/hybridforest/trees/tests/test.rb +28 -0
- data/lib/hybridforest/trees/tree.rb +46 -0
- data/lib/hybridforest/trees/tree_growers/cart_grower.rb +76 -0
- data/lib/hybridforest/trees/tree_growers/id3_grower.rb +161 -0
- data/lib/hybridforest/utilities/utils.rb +206 -0
- data/lib/hybridforest/version.rb +5 -0
- data/lib/hybridforest.rb +46 -0
- metadata +285 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: 8d0af8ba6e658989bf55d0b8026822ab0be77020c189ffd056ffac3e7a393700
|
4
|
+
data.tar.gz: 968c29d52088fe4ba28d93fcd3aedf7288d323f3a35ba000efde5b1673c10e2b
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: 2e7076a8cdea280c900f61c98a6f53715a335d7765a6861f028a6956a8c6599cbe233b5d3e867b0d55f9a8889e41067e314ca151b00e4455f55071ac407059f5
|
7
|
+
data.tar.gz: dbf3dc47bfba70ff33e9452cc9232d5079651410a3eb319b29172992141c1602f8c06bba50e7ee9a56c1c8530e9099763834b1a964a86cdc4ec4c89d944237f4
|
data/.gitignore
ADDED
@@ -0,0 +1,89 @@
|
|
1
|
+
/.bundle/
|
2
|
+
/.yardoc
|
3
|
+
/_yardoc/
|
4
|
+
/coverage/
|
5
|
+
/doc/
|
6
|
+
/pkg/
|
7
|
+
/spec/reports/
|
8
|
+
/tmp/
|
9
|
+
|
10
|
+
# rspec failure tracking
|
11
|
+
.rspec_status
|
12
|
+
|
13
|
+
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
|
14
|
+
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
15
|
+
|
16
|
+
# User-specific stuff
|
17
|
+
.idea/**/workspace.xml
|
18
|
+
.idea/**/tasks.xml
|
19
|
+
.idea/**/usage.statistics.xml
|
20
|
+
.idea/**/dictionaries
|
21
|
+
.idea/**/shelf
|
22
|
+
.DS_Store
|
23
|
+
.idea/
|
24
|
+
.rubocop.yml
|
25
|
+
|
26
|
+
# AWS User-specific
|
27
|
+
.idea/**/aws.xml
|
28
|
+
|
29
|
+
# Generated files
|
30
|
+
.idea/**/contentModel.xml
|
31
|
+
|
32
|
+
# Sensitive or high-churn files
|
33
|
+
.idea/**/dataSources/
|
34
|
+
.idea/**/dataSources.ids
|
35
|
+
.idea/**/dataSources.local.xml
|
36
|
+
.idea/**/sqlDataSources.xml
|
37
|
+
.idea/**/dynamic.xml
|
38
|
+
.idea/**/uiDesigner.xml
|
39
|
+
.idea/**/dbnavigator.xml
|
40
|
+
|
41
|
+
# Gradle
|
42
|
+
.idea/**/gradle.xml
|
43
|
+
.idea/**/libraries
|
44
|
+
|
45
|
+
# Gradle and Maven with auto-import
|
46
|
+
# When using Gradle or Maven with auto-import, you should exclude module files,
|
47
|
+
# since they will be recreated, and may cause churn. Uncomment if using
|
48
|
+
# auto-import.
|
49
|
+
# .idea/artifacts
|
50
|
+
# .idea/compiler.xml
|
51
|
+
# .idea/jarRepositories.xml
|
52
|
+
# .idea/modules.xml
|
53
|
+
# .idea/*.iml
|
54
|
+
# .idea/modules
|
55
|
+
# *.iml
|
56
|
+
# *.ipr
|
57
|
+
|
58
|
+
# CMake
|
59
|
+
cmake-build-*/
|
60
|
+
|
61
|
+
# Mongo Explorer plugin
|
62
|
+
.idea/**/mongoSettings.xml
|
63
|
+
|
64
|
+
# File-based project format
|
65
|
+
*.iws
|
66
|
+
|
67
|
+
# IntelliJ
|
68
|
+
out/
|
69
|
+
|
70
|
+
# mpeltonen/sbt-idea plugin
|
71
|
+
.idea_modules/
|
72
|
+
|
73
|
+
# JIRA plugin
|
74
|
+
atlassian-ide-plugin.xml
|
75
|
+
|
76
|
+
# Cursive Clojure plugin
|
77
|
+
.idea/replstate.xml
|
78
|
+
|
79
|
+
# Crashlytics plugin (for Android Studio and IntelliJ)
|
80
|
+
com_crashlytics_export_strings.xml
|
81
|
+
crashlytics.properties
|
82
|
+
crashlytics-build.properties
|
83
|
+
fabric.properties
|
84
|
+
|
85
|
+
# Editor-based Rest Client
|
86
|
+
.idea/httpRequests
|
87
|
+
|
88
|
+
# Android studio 3.1+ serialized cache file
|
89
|
+
.idea/caches/build_file_checksums.ser
|
data/.rspec
ADDED
data/.rubocop.yml
ADDED
data/CHANGELOG.md
ADDED
@@ -0,0 +1,33 @@
|
|
1
|
+
## [Unreleased]
|
2
|
+
|
3
|
+
## [0.3.1] - 2021-12-27
|
4
|
+
|
5
|
+
- Fix bugs related to bootstrapping and random feature selection.
|
6
|
+
- Add more tests.
|
7
|
+
|
8
|
+
## [0.4.0] - 2021-12-27
|
9
|
+
|
10
|
+
- Minor refactoring
|
11
|
+
|
12
|
+
## [0.5.0] - 2021-12-27
|
13
|
+
|
14
|
+
- Make all utility methods module methods
|
15
|
+
- Implement random sampling without replacement
|
16
|
+
|
17
|
+
## [0.6.0] - 2021-12-28
|
18
|
+
|
19
|
+
- Add nicer to_s for HybridForest::RandomForest
|
20
|
+
|
21
|
+
## [0.7.0] - 2021-12-28
|
22
|
+
|
23
|
+
- Add specific title to HybridForest::RandomForest#to_s
|
24
|
+
|
25
|
+
## [0.8.0] - 2021-12-28
|
26
|
+
|
27
|
+
- Raise custom error, not string
|
28
|
+
- Add supported tree types list
|
29
|
+
- Minor refactoring of internal logic
|
30
|
+
|
31
|
+
## [0.9.0] - 2021-12-28
|
32
|
+
|
33
|
+
- Update dependencies
|
data/Gemfile
ADDED
@@ -0,0 +1,22 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
source "https://rubygems.org"
|
4
|
+
|
5
|
+
# Specify your gem's dependencies in hybridforest.gemspec
|
6
|
+
gemspec
|
7
|
+
|
8
|
+
gem "rake", "~> 13.0"
|
9
|
+
|
10
|
+
gem "rspec", "~> 3.0"
|
11
|
+
|
12
|
+
gem "rubocop", "~> 1.7"
|
13
|
+
|
14
|
+
gem "rumale"
|
15
|
+
|
16
|
+
gem "activesupport", "~> 6.1"
|
17
|
+
|
18
|
+
gem "rover-df", "~> 0.2.6"
|
19
|
+
|
20
|
+
gem "require_all"
|
21
|
+
|
22
|
+
gem "terminal-table", "~> 3.0"
|
data/Gemfile.lock
ADDED
@@ -0,0 +1,92 @@
|
|
1
|
+
PATH
|
2
|
+
remote: .
|
3
|
+
specs:
|
4
|
+
hybridforest (0.9.0)
|
5
|
+
activesupport (~> 6.1)
|
6
|
+
rake (~> 13.0)
|
7
|
+
require_all
|
8
|
+
rover-df (~> 0.2.6)
|
9
|
+
rspec (~> 3.0)
|
10
|
+
rubocop (~> 1.7)
|
11
|
+
rumale
|
12
|
+
terminal-table (~> 3.0)
|
13
|
+
|
14
|
+
GEM
|
15
|
+
remote: https://rubygems.org/
|
16
|
+
specs:
|
17
|
+
activesupport (6.1.4.1)
|
18
|
+
concurrent-ruby (~> 1.0, >= 1.0.2)
|
19
|
+
i18n (>= 1.6, < 2)
|
20
|
+
minitest (>= 5.1)
|
21
|
+
tzinfo (~> 2.0)
|
22
|
+
zeitwerk (~> 2.3)
|
23
|
+
ast (2.4.2)
|
24
|
+
concurrent-ruby (1.1.9)
|
25
|
+
diff-lcs (1.4.4)
|
26
|
+
i18n (1.8.11)
|
27
|
+
concurrent-ruby (~> 1.0)
|
28
|
+
lbfgsb (0.4.1)
|
29
|
+
numo-narray (>= 0.9.1)
|
30
|
+
minitest (5.14.4)
|
31
|
+
numo-narray (0.9.2.0)
|
32
|
+
parallel (1.21.0)
|
33
|
+
parser (3.0.3.1)
|
34
|
+
ast (~> 2.4.1)
|
35
|
+
rainbow (3.0.0)
|
36
|
+
rake (13.0.6)
|
37
|
+
regexp_parser (2.2.0)
|
38
|
+
require_all (3.0.0)
|
39
|
+
rexml (3.2.5)
|
40
|
+
rover-df (0.2.6)
|
41
|
+
numo-narray (>= 0.9.1.9)
|
42
|
+
rspec (3.10.0)
|
43
|
+
rspec-core (~> 3.10.0)
|
44
|
+
rspec-expectations (~> 3.10.0)
|
45
|
+
rspec-mocks (~> 3.10.0)
|
46
|
+
rspec-core (3.10.1)
|
47
|
+
rspec-support (~> 3.10.0)
|
48
|
+
rspec-expectations (3.10.1)
|
49
|
+
diff-lcs (>= 1.2.0, < 2.0)
|
50
|
+
rspec-support (~> 3.10.0)
|
51
|
+
rspec-mocks (3.10.2)
|
52
|
+
diff-lcs (>= 1.2.0, < 2.0)
|
53
|
+
rspec-support (~> 3.10.0)
|
54
|
+
rspec-support (3.10.3)
|
55
|
+
rubocop (1.23.0)
|
56
|
+
parallel (~> 1.10)
|
57
|
+
parser (>= 3.0.0.0)
|
58
|
+
rainbow (>= 2.2.2, < 4.0)
|
59
|
+
regexp_parser (>= 1.8, < 3.0)
|
60
|
+
rexml
|
61
|
+
rubocop-ast (>= 1.12.0, < 2.0)
|
62
|
+
ruby-progressbar (~> 1.7)
|
63
|
+
unicode-display_width (>= 1.4.0, < 3.0)
|
64
|
+
rubocop-ast (1.14.0)
|
65
|
+
parser (>= 3.0.1.1)
|
66
|
+
ruby-progressbar (1.11.0)
|
67
|
+
rumale (0.23.1)
|
68
|
+
lbfgsb (>= 0.3.0)
|
69
|
+
numo-narray (>= 0.9.1)
|
70
|
+
terminal-table (3.0.2)
|
71
|
+
unicode-display_width (>= 1.1.1, < 3)
|
72
|
+
tzinfo (2.0.4)
|
73
|
+
concurrent-ruby (~> 1.0)
|
74
|
+
unicode-display_width (2.1.0)
|
75
|
+
zeitwerk (2.5.1)
|
76
|
+
|
77
|
+
PLATFORMS
|
78
|
+
x86_64-darwin-20
|
79
|
+
|
80
|
+
DEPENDENCIES
|
81
|
+
activesupport (~> 6.1)
|
82
|
+
hybridforest!
|
83
|
+
rake (~> 13.0)
|
84
|
+
require_all
|
85
|
+
rover-df (~> 0.2.6)
|
86
|
+
rspec (~> 3.0)
|
87
|
+
rubocop (~> 1.7)
|
88
|
+
rumale
|
89
|
+
terminal-table (~> 3.0)
|
90
|
+
|
91
|
+
BUNDLED WITH
|
92
|
+
2.3.3
|
data/LICENSE.txt
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
The MIT License (MIT)
|
2
|
+
|
3
|
+
Copyright (c) 2021 hi-tech-jazz
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in
|
13
|
+
all copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
21
|
+
THE SOFTWARE.
|
data/README.md
ADDED
@@ -0,0 +1,96 @@
|
|
1
|
+
# HybridForest
|
2
|
+
|
3
|
+
`HybridForest` lets you build hybrid random forests, i.e., ensembles whose base learners are built with several different decision tree algorithms rather than just one. Currently, two types of trees are supported:
|
4
|
+
* `CARTTree`
|
5
|
+
* Performs binary splits at each internal node.
|
6
|
+
* Supports categorical and continuous features.
|
7
|
+
* Supports binary classification problems.
|
8
|
+
* Uses Gini impurity to find the most discriminative feature.
|
9
|
+
* Considers a random subset of features in each split.
|
10
|
+
* Loosely based on the original CART algorithm (Breiman et al., 1984).
|
11
|
+
* `ID3Tree`
|
12
|
+
* Performs multiway (>=2) splits at each internal node.
|
13
|
+
* Supports categorical and continuous features.
|
14
|
+
* Supports binary classification problems.
|
15
|
+
* Uses entropy to find the most discriminative feature.
|
16
|
+
* Splits on each feature at most once.
|
17
|
+
* Loosely based on the ID3 algorithm (Quinlan, 1986).
|
18
|
+
|
19
|
+
The random forest itself is represented by the `RandomForest` class.
|
20
|
+
A random forest classifier can be created with one of three base learner configurations.
|
21
|
+
|
22
|
+
1. Hybrid mode
|
23
|
+
```ruby
|
24
|
+
# Equivalent, hybrid is default.
|
25
|
+
HybridForest::RandomForest.new(number_of_trees: 100, ensemble_type: :hybrid)
|
26
|
+
HybridForest::RandomForest.new(number_of_trees: 100)
|
27
|
+
```
|
28
|
+
|
29
|
+
|
30
|
+
2. CART mode
|
31
|
+
```ruby
|
32
|
+
HybridForest::RandomForest.new(number_of_trees: 100, ensemble_type: :cart)
|
33
|
+
```
|
34
|
+
3. ID3 mode
|
35
|
+
|
36
|
+
```ruby
|
37
|
+
HybridForest::RandomForest.new(number_of_trees: 100, ensemble_type: :id3)
|
38
|
+
```
|
39
|
+
|
40
|
+
The implementation is fairly naive, and several features that would be nice to have are not supported yet, including:
|
41
|
+
* Pruning
|
42
|
+
* Parallelization
|
43
|
+
* More decision trees, e.g., C4.5 and CHAID
|
44
|
+
* Additional hyperparameters
|
45
|
+
|
46
|
+
## Installation
|
47
|
+
|
48
|
+
Add this line to your application's Gemfile:
|
49
|
+
|
50
|
+
```ruby
|
51
|
+
gem 'hybridforest'
|
52
|
+
```
|
53
|
+
|
54
|
+
And then execute:
|
55
|
+
|
56
|
+
$ bundle install
|
57
|
+
|
58
|
+
Or install it yourself as:
|
59
|
+
|
60
|
+
$ gem install hybridforest
|
61
|
+
|
62
|
+
## Usage
|
63
|
+
|
64
|
+
```ruby
|
65
|
+
require "hybridforest"
|
66
|
+
|
67
|
+
# Prepare data.
|
68
|
+
# A dataset can be passed as a CSV file path, an array of hashes, or a hash of arrays.
|
69
|
+
training_set, test_set, actual_test_labels = HybridForest::Utils.train_test_split("data.csv")
|
70
|
+
|
71
|
+
# Create classifier.
|
72
|
+
hybrid_forest = HybridForest::RandomForest.new(number_of_trees: 100)
|
73
|
+
|
74
|
+
# Fit model.
|
75
|
+
hybrid_forest.fit(training_set)
|
76
|
+
|
77
|
+
# Predict.
|
78
|
+
predicted_labels = hybrid_forest.predict(test_set)
|
79
|
+
|
80
|
+
# Report metrics.
|
81
|
+
puts HybridForest::Utils.prediction_report(actual_test_labels, predicted_labels)
|
82
|
+
```
|
83
|
+
|
84
|
+
## Development
|
85
|
+
|
86
|
+
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
87
|
+
|
88
|
+
To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and the created tag, and push the `.gem` file to [rubygems.org](https://rubygems.org).
|
89
|
+
|
90
|
+
## Contributing
|
91
|
+
|
92
|
+
Bug reports and pull requests are welcome on GitHub at https://github.com/hi-tech-jazz/hybridforest.
|
93
|
+
|
94
|
+
## License
|
95
|
+
|
96
|
+
The gem is available as open source under the terms of the [MIT License](https://opensource.org/licenses/MIT).
|
data/Rakefile
ADDED
data/bin/console
ADDED
@@ -0,0 +1,15 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "bundler/setup"
|
5
|
+
require "hybridforest"
|
6
|
+
|
7
|
+
# You can add fixtures and/or initialization code here to make experimenting
# with your gem easier. You can also use a different console, if you like.

# (If you use this, don't forget to add pry to your Gemfile!)
# require "pry"
# Pry.start

# Start an interactive IRB session; "hybridforest" was already required above,
# so the gem's classes are available at the prompt.
require "irb"
IRB.start(__FILE__)
|
data/bin/setup
ADDED
@@ -0,0 +1,44 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "lib/hybridforest/version"
|
4
|
+
|
5
|
+
Gem::Specification.new do |spec|
  spec.name = "hybridforest"
  spec.version = HybridForest::VERSION
  spec.authors = ["hi-tech-jazz"]
  spec.email = ["jazztechhi@gmail.com"]

  spec.summary = "Hybrid random forest for binary classification tasks."
  spec.description = "HybridForest provides random forests built upon combinations of different decision tree algorithms to enable diverse tree ensembles. Until version 1.0.0, please expect breaking changes."
  spec.homepage = "https://github.com/hi-tech-jazz/hybridforest"
  spec.license = "MIT"
  spec.required_ruby_version = ">= 3.0.2"

  spec.metadata["homepage_uri"] = "https://github.com/hi-tech-jazz/hybridforest"
  spec.metadata["source_code_uri"] = "https://github.com/hi-tech-jazz/hybridforest"
  spec.metadata["changelog_uri"] = "https://github.com/hi-tech-jazz/hybridforest/blob/master/CHANGELOG.md"

  # Specify which files should be added to the gem when it is released.
  # The `git ls-files -z` loads the files in the RubyGem that have been added into git.
  spec.files = Dir.chdir(File.expand_path(__dir__)) do
    `git ls-files -z`.split("\x0").reject { |f| f.match(%r{\A(?:test|spec|features)/}) }
  end
  spec.bindir = "exe"
  spec.executables = spec.files.grep(%r{\Aexe/}) { |f| File.basename(f) }
  spec.require_paths = ["lib"]

  # Runtime dependencies: installed for every consumer of the gem.
  # NOTE(review): rumale is kept as runtime because its use under lib/ is not
  # visible here; if it is only used in specs, move it to the development group.
  spec.add_dependency "rumale"
  spec.add_dependency "activesupport", "~> 6.1"
  spec.add_dependency "rover-df", "~> 0.2.6"
  spec.add_dependency "require_all"
  spec.add_dependency "terminal-table", "~> 3.0"

  # Development-only tooling (build, test, lint). These were previously also
  # declared as runtime dependencies, which forced every user of the library to
  # install them. Runtime dependencies above are available in development too,
  # so they are not repeated here.
  spec.add_development_dependency "rake", "~> 13.0"
  spec.add_development_dependency "rspec", "~> 3.0"
  spec.add_development_dependency "rubocop", "~> 1.7"
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
require_relative "../../utilities/utils"
|
2
|
+
require_relative "../../trees/cart_tree"
|
3
|
+
|
4
|
+
module HybridForest
  module Forests
    module ForestGrowers
      # Grows a forest made exclusively of CART trees.
      class CARTGrower
        ##
        # Builds +number_of_trees+ CART trees, each fitted on its own bootstrap
        # sample of +instances+, and returns them as an array.
        #
        def grow_forest(instances, number_of_trees)
          number_of_trees.times.map do
            # Only the in-bag sample is used; the out-of-bag part is discarded.
            in_bag, = HybridForest::Utils.train_test_bootstrap_split(instances)
            HybridForest::Trees::CARTTree.new.fit(in_bag)
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,46 @@
|
|
1
|
+
require_relative "../../utilities/utils"
|
2
|
+
require_relative "../../trees/id3_tree"
|
3
|
+
require_relative "../../trees/cart_tree"
|
4
|
+
|
5
|
+
module HybridForest
  module Forests
    module ForestGrowers
      # Grows a forest in which every member is the best of several tree types
      # trained on the same bootstrap sample, judged by out-of-bag accuracy.
      class HybridGrower
        TREE_TYPES = [HybridForest::Trees::CARTTree, HybridForest::Trees::ID3Tree].freeze

        ##
        # Returns an array of +number_of_trees+ trees. For each slot a fresh
        # bootstrap split of +instances+ is drawn, one tree of every type in
        # TREE_TYPES is fitted on the in-bag part, and the tree with the highest
        # accuracy on the out-of-bag part is kept.
        #
        def grow_forest(instances, number_of_trees)
          number_of_trees.times.map do
            in_bag, out_of_bag, out_of_bag_labels = HybridForest::Utils.train_test_bootstrap_split(instances)
            candidates = grow_trees(TREE_TYPES, in_bag, out_of_bag, out_of_bag_labels)
            select_best_tree(candidates)
          end
        end

        private

        # Fits a +tree_class+ tree on +in_bag+, scores it on the out-of-bag data and
        # returns a hash holding the fitted :tree and its :oob_accuracy.
        def fit_and_predict(tree_class, in_bag, out_of_bag, out_of_bag_labels)
          fitted = tree_class.new.fit(in_bag)
          predictions = fitted.predict(out_of_bag)
          {tree: fitted, oob_accuracy: HybridForest::Utils.accuracy(predictions, out_of_bag_labels)}
        end

        # Returns the tree from the result with the highest out-of-bag accuracy.
        def select_best_tree(tree_results)
          winner, = tree_results.max_by(1) { |result| result[:oob_accuracy] }
          winner[:tree]
        end

        # Fits and scores one tree per entry of +tree_types+, keeping their order.
        def grow_trees(tree_types, in_bag, out_of_bag, out_of_bag_labels)
          tree_types.map { |tree_type| fit_and_predict(tree_type, in_bag, out_of_bag, out_of_bag_labels) }
        end
      end
    end
  end
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
require_relative "../../utilities/utils"
|
2
|
+
require_relative "../../trees/id3_tree"
|
3
|
+
|
4
|
+
module HybridForest
  module Forests
    module ForestGrowers
      # Grows a forest made exclusively of ID3 trees.
      class ID3Grower
        ##
        # Builds +number_of_trees+ ID3 trees, each fitted on its own bootstrap
        # sample of +instances+, and returns them as an array.
        #
        def grow_forest(instances, number_of_trees)
          number_of_trees.times.map do
            # Only the in-bag sample is used; the out-of-bag part is discarded.
            in_bag, = HybridForest::Utils.train_test_bootstrap_split(instances)
            HybridForest::Trees::ID3Tree.new.fit(in_bag)
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,29 @@
|
|
1
|
+
require_relative "../forests/forest_growers/hybrid_grower"
|
2
|
+
require_relative "../forests/forest_growers/cart_grower"
|
3
|
+
require_relative "../forests/forest_growers/id3_grower"
|
4
|
+
|
5
|
+
module HybridForest
  module Forests
    # Maps an ensemble type symbol to the forest grower that builds it.
    class GrowerFactory
      TYPES = {
        cart: HybridForest::Forests::ForestGrowers::CARTGrower,
        id3: HybridForest::Forests::ForestGrowers::ID3Grower,
        hybrid: HybridForest::Forests::ForestGrowers::HybridGrower
      }.freeze

      ##
      # Returns a new grower instance for +type+. Types not listed in TYPES fall
      # back to the +default+ grower class.
      #
      def self.for(type)
        grower_class = TYPES.fetch(type) { default }
        grower_class.new
      end

      ##
      # The grower class used for unknown types.
      #
      def self.default
        HybridForest::Forests::ForestGrowers::HybridGrower
      end

      ##
      # The supported ensemble type symbols, in declaration order.
      #
      def self.types
        TYPES.keys
      end
    end
  end
end
|
@@ -0,0 +1,84 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "../utilities/utils"
|
4
|
+
|
5
|
+
module HybridForest
  class RandomForest
    ##
    # Creates a new random forest.
    #
    # +number_of_trees+ dictates the size of the tree ensemble and must be a
    # positive Integer.
    #
    # +ensemble_type+ dictates the composition of the tree ensemble.
    # Valid options are +:hybrid+, +:cart+, +:id3+.
    #
    # Raises ArgumentError if either argument is invalid.
    #
    def initialize(number_of_trees:, ensemble_type: :hybrid)
      raise ArgumentError, "Invalid ensemble type" unless Forests::GrowerFactory.types.include? ensemble_type
      # An empty forest would make #predict silently return [] for any input.
      unless number_of_trees.is_a?(Integer) && number_of_trees.positive?
        raise ArgumentError, "number_of_trees must be a positive integer"
      end

      @number_of_trees = number_of_trees
      @ensemble_type = ensemble_type
    end

    ##
    # Fits a model to the given dataset +instances+ and returns +self+.
    #
    def fit(instances)
      instances = HybridForest::Utils.to_dataframe(instances)
      forest_grower = Forests::GrowerFactory.for(@ensemble_type)
      @forest = forest_grower.grow_forest(instances, @number_of_trees)
      self
    end

    ##
    # Predicts a label for each instance in the dataset +instances+ and returns an array of labels.
    #
    # Raises Errors::InvalidStateError if called before #fit.
    #
    def predict(instances)
      raise Errors::InvalidStateError, "You must call #fit before you call #predict" if @forest.nil?

      instances = HybridForest::Utils.to_dataframe(instances)
      tree_predictions(instances).map { |votes| majority_vote(votes) }
    end

    ##
    # Renders a table of tree types and their counts, or a placeholder when
    # the forest has not been fitted yet.
    #
    def to_s
      return "Empty random forest: \n#{super()}" if @forest.nil?

      title = case @ensemble_type
              when :hybrid then "Hybrid random forest"
              else "Uniform random forest"
              end

      table = Terminal::Table.new do |t|
        t.title = title
        t.headings = %w[Tree Count]
        tally_ensemble.each do |tree_type, count|
          t << [tree_type, count]
          t << :separator
        end
        t << ["Total", @number_of_trees]
      end
      table.to_s
    end

    private

    # Returns the most frequent label in +votes+ (nil when +votes+ is empty).
    # Counts are built in a single pass with #tally instead of calling #count
    # once per vote (which was quadratic in the number of trees). On a tie, the
    # label first seen in +votes+ — i.e. the one from the earliest tree — wins,
    # so the result is deterministic.
    def majority_vote(votes)
      votes.tally.max_by { |_label, count| count }&.first
    end

    # Returns one array of votes per instance: each tree predicts every
    # instance, then the per-tree arrays are transposed to per-instance arrays.
    def tree_predictions(instances)
      @forest.map { |tree| tree.predict(instances) }.transpose
    end

    # Returns a hash mapping each tree's #name to how many trees share it.
    def tally_ensemble
      @forest.map(&:name).tally
    end
  end
end
|