hybridforest 0.9.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (43) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +89 -0
  3. data/.rspec +3 -0
  4. data/.rubocop.yml +13 -0
  5. data/CHANGELOG.md +33 -0
  6. data/Gemfile +22 -0
  7. data/Gemfile.lock +92 -0
  8. data/LICENSE.txt +21 -0
  9. data/README.md +96 -0
  10. data/Rakefile +12 -0
  11. data/bin/console +15 -0
  12. data/bin/setup +8 -0
  13. data/hybridforest.gemspec +44 -0
  14. data/lib/hybridforest/errors/invalid_state_error.rb +9 -0
  15. data/lib/hybridforest/forests/forest_growers/cart_grower.rb +19 -0
  16. data/lib/hybridforest/forests/forest_growers/hybrid_grower.rb +46 -0
  17. data/lib/hybridforest/forests/forest_growers/id3_grower.rb +19 -0
  18. data/lib/hybridforest/forests/grower_factory.rb +29 -0
  19. data/lib/hybridforest/forests/random_forest.rb +84 -0
  20. data/lib/hybridforest/trees/cart_tree.rb +18 -0
  21. data/lib/hybridforest/trees/feature_selectors/all_features.rb +14 -0
  22. data/lib/hybridforest/trees/feature_selectors/max_one_split_per_feature.rb +21 -0
  23. data/lib/hybridforest/trees/feature_selectors/random_feature_subspace.rb +27 -0
  24. data/lib/hybridforest/trees/id3_tree.rb +18 -0
  25. data/lib/hybridforest/trees/impurity_metrics/entropy.rb +21 -0
  26. data/lib/hybridforest/trees/impurity_metrics/gini_impurity.rb +17 -0
  27. data/lib/hybridforest/trees/impurity_metrics/impurity.rb +28 -0
  28. data/lib/hybridforest/trees/nodes/binary_node.rb +37 -0
  29. data/lib/hybridforest/trees/nodes/leaf_node.rb +31 -0
  30. data/lib/hybridforest/trees/nodes/multiway_node.rb +45 -0
  31. data/lib/hybridforest/trees/split.rb +30 -0
  32. data/lib/hybridforest/trees/tests/equal.rb +37 -0
  33. data/lib/hybridforest/trees/tests/equal_or_greater.rb +37 -0
  34. data/lib/hybridforest/trees/tests/less.rb +37 -0
  35. data/lib/hybridforest/trees/tests/not_equal.rb +37 -0
  36. data/lib/hybridforest/trees/tests/test.rb +28 -0
  37. data/lib/hybridforest/trees/tree.rb +46 -0
  38. data/lib/hybridforest/trees/tree_growers/cart_grower.rb +76 -0
  39. data/lib/hybridforest/trees/tree_growers/id3_grower.rb +161 -0
  40. data/lib/hybridforest/utilities/utils.rb +206 -0
  41. data/lib/hybridforest/version.rb +5 -0
  42. data/lib/hybridforest.rb +46 -0
  43. metadata +285 -0
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 8d0af8ba6e658989bf55d0b8026822ab0be77020c189ffd056ffac3e7a393700
4
+ data.tar.gz: 968c29d52088fe4ba28d93fcd3aedf7288d323f3a35ba000efde5b1673c10e2b
5
+ SHA512:
6
+ metadata.gz: 2e7076a8cdea280c900f61c98a6f53715a335d7765a6861f028a6956a8c6599cbe233b5d3e867b0d55f9a8889e41067e314ca151b00e4455f55071ac407059f5
7
+ data.tar.gz: dbf3dc47bfba70ff33e9452cc9232d5079651410a3eb319b29172992141c1602f8c06bba50e7ee9a56c1c8530e9099763834b1a964a86cdc4ec4c89d944237f4
data/.gitignore ADDED
@@ -0,0 +1,89 @@
1
+ /.bundle/
2
+ /.yardoc
3
+ /_yardoc/
4
+ /coverage/
5
+ /doc/
6
+ /pkg/
7
+ /spec/reports/
8
+ /tmp/
9
+
10
+ # rspec failure tracking
11
+ .rspec_status
12
+
13
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
14
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
15
+
16
+ # User-specific stuff
17
+ .idea/**/workspace.xml
18
+ .idea/**/tasks.xml
19
+ .idea/**/usage.statistics.xml
20
+ .idea/**/dictionaries
21
+ .idea/**/shelf
22
+ .DS_Store
23
+ .idea/
24
+ .rubocop.yml
25
+
26
+ # AWS User-specific
27
+ .idea/**/aws.xml
28
+
29
+ # Generated files
30
+ .idea/**/contentModel.xml
31
+
32
+ # Sensitive or high-churn files
33
+ .idea/**/dataSources/
34
+ .idea/**/dataSources.ids
35
+ .idea/**/dataSources.local.xml
36
+ .idea/**/sqlDataSources.xml
37
+ .idea/**/dynamic.xml
38
+ .idea/**/uiDesigner.xml
39
+ .idea/**/dbnavigator.xml
40
+
41
+ # Gradle
42
+ .idea/**/gradle.xml
43
+ .idea/**/libraries
44
+
45
+ # Gradle and Maven with auto-import
46
+ # When using Gradle or Maven with auto-import, you should exclude module files,
47
+ # since they will be recreated, and may cause churn. Uncomment if using
48
+ # auto-import.
49
+ # .idea/artifacts
50
+ # .idea/compiler.xml
51
+ # .idea/jarRepositories.xml
52
+ # .idea/modules.xml
53
+ # .idea/*.iml
54
+ # .idea/modules
55
+ # *.iml
56
+ # *.ipr
57
+
58
+ # CMake
59
+ cmake-build-*/
60
+
61
+ # Mongo Explorer plugin
62
+ .idea/**/mongoSettings.xml
63
+
64
+ # File-based project format
65
+ *.iws
66
+
67
+ # IntelliJ
68
+ out/
69
+
70
+ # mpeltonen/sbt-idea plugin
71
+ .idea_modules/
72
+
73
+ # JIRA plugin
74
+ atlassian-ide-plugin.xml
75
+
76
+ # Cursive Clojure plugin
77
+ .idea/replstate.xml
78
+
79
+ # Crashlytics plugin (for Android Studio and IntelliJ)
80
+ com_crashlytics_export_strings.xml
81
+ crashlytics.properties
82
+ crashlytics-build.properties
83
+ fabric.properties
84
+
85
+ # Editor-based Rest Client
86
+ .idea/httpRequests
87
+
88
+ # Android studio 3.1+ serialized cache file
89
+ .idea/caches/build_file_checksums.ser
data/.rspec ADDED
@@ -0,0 +1,3 @@
1
+ --format documentation
2
+ --color
3
+ --require spec_helper
data/.rubocop.yml ADDED
@@ -0,0 +1,13 @@
1
+ AllCops:
2
+ TargetRubyVersion: 3.1
3
+
4
+ Style/StringLiterals:
5
+ Enabled: true
6
+ EnforcedStyle: double_quotes
7
+
8
+ Style/StringLiteralsInInterpolation:
9
+ Enabled: true
10
+ EnforcedStyle: double_quotes
11
+
12
+ Layout/LineLength:
13
+ Max: 120
data/CHANGELOG.md ADDED
@@ -0,0 +1,33 @@
1
+ ## [Unreleased]
2
+
3
+ ## [0.3.1] - 2021-12-27
4
+
5
+ - Fix bugs related to bootstrapping and random feature selection.
6
+ - Add more tests.
7
+
8
+ ## [0.4.0] - 2021-12-27
9
+
10
+ - Minor refactoring
11
+
12
+ ## [0.5.0] - 2021-12-27
13
+
14
+ - Make all utility methods module methods
15
+ - Implement random sampling without replacement
16
+
17
+ ## [0.6.0] - 2021-12-28
18
+
19
+ - Add nicer to_s for HybridForest::RandomForest
20
+
21
+ ## [0.7.0] - 2021-12-28
22
+
23
+ - Add specific title to HybridForest::RandomForest#to_s
24
+
25
+ ## [0.8.0] - 2021-12-28
26
+
27
+ - Raise custom error, not string
28
+ - Add supported tree types list
29
+ - Minor refactoring of internal logic
30
+
31
+ ## [0.9.0] - 2021-12-28
32
+
33
+ - Update dependencies
data/Gemfile ADDED
@@ -0,0 +1,22 @@
1
+ # frozen_string_literal: true
2
+
3
source "https://rubygems.org"

# Every dependency (runtime and development) is declared exactly once, in
# hybridforest.gemspec. `gemspec` loads them from there, so the version
# constraints cannot drift apart between this file and the gemspec.
gemspec
data/Gemfile.lock ADDED
@@ -0,0 +1,92 @@
1
+ PATH
2
+ remote: .
3
+ specs:
4
+ hybridforest (0.9.0)
5
+ activesupport (~> 6.1)
6
+ rake (~> 13.0)
7
+ require_all
8
+ rover-df (~> 0.2.6)
9
+ rspec (~> 3.0)
10
+ rubocop (~> 1.7)
11
+ rumale
12
+ terminal-table (~> 3.0)
13
+
14
+ GEM
15
+ remote: https://rubygems.org/
16
+ specs:
17
+ activesupport (6.1.4.1)
18
+ concurrent-ruby (~> 1.0, >= 1.0.2)
19
+ i18n (>= 1.6, < 2)
20
+ minitest (>= 5.1)
21
+ tzinfo (~> 2.0)
22
+ zeitwerk (~> 2.3)
23
+ ast (2.4.2)
24
+ concurrent-ruby (1.1.9)
25
+ diff-lcs (1.4.4)
26
+ i18n (1.8.11)
27
+ concurrent-ruby (~> 1.0)
28
+ lbfgsb (0.4.1)
29
+ numo-narray (>= 0.9.1)
30
+ minitest (5.14.4)
31
+ numo-narray (0.9.2.0)
32
+ parallel (1.21.0)
33
+ parser (3.0.3.1)
34
+ ast (~> 2.4.1)
35
+ rainbow (3.0.0)
36
+ rake (13.0.6)
37
+ regexp_parser (2.2.0)
38
+ require_all (3.0.0)
39
+ rexml (3.2.5)
40
+ rover-df (0.2.6)
41
+ numo-narray (>= 0.9.1.9)
42
+ rspec (3.10.0)
43
+ rspec-core (~> 3.10.0)
44
+ rspec-expectations (~> 3.10.0)
45
+ rspec-mocks (~> 3.10.0)
46
+ rspec-core (3.10.1)
47
+ rspec-support (~> 3.10.0)
48
+ rspec-expectations (3.10.1)
49
+ diff-lcs (>= 1.2.0, < 2.0)
50
+ rspec-support (~> 3.10.0)
51
+ rspec-mocks (3.10.2)
52
+ diff-lcs (>= 1.2.0, < 2.0)
53
+ rspec-support (~> 3.10.0)
54
+ rspec-support (3.10.3)
55
+ rubocop (1.23.0)
56
+ parallel (~> 1.10)
57
+ parser (>= 3.0.0.0)
58
+ rainbow (>= 2.2.2, < 4.0)
59
+ regexp_parser (>= 1.8, < 3.0)
60
+ rexml
61
+ rubocop-ast (>= 1.12.0, < 2.0)
62
+ ruby-progressbar (~> 1.7)
63
+ unicode-display_width (>= 1.4.0, < 3.0)
64
+ rubocop-ast (1.14.0)
65
+ parser (>= 3.0.1.1)
66
+ ruby-progressbar (1.11.0)
67
+ rumale (0.23.1)
68
+ lbfgsb (>= 0.3.0)
69
+ numo-narray (>= 0.9.1)
70
+ terminal-table (3.0.2)
71
+ unicode-display_width (>= 1.1.1, < 3)
72
+ tzinfo (2.0.4)
73
+ concurrent-ruby (~> 1.0)
74
+ unicode-display_width (2.1.0)
75
+ zeitwerk (2.5.1)
76
+
77
+ PLATFORMS
78
+ x86_64-darwin-20
79
+
80
+ DEPENDENCIES
81
+ activesupport (~> 6.1)
82
+ hybridforest!
83
+ rake (~> 13.0)
84
+ require_all
85
+ rover-df (~> 0.2.6)
86
+ rspec (~> 3.0)
87
+ rubocop (~> 1.7)
88
+ rumale
89
+ terminal-table (~> 3.0)
90
+
91
+ BUNDLED WITH
92
+ 2.3.3
data/LICENSE.txt ADDED
@@ -0,0 +1,21 @@
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) 2021 hi-tech-jazz
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in
13
+ all copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21
+ THE SOFTWARE.
data/README.md ADDED
@@ -0,0 +1,96 @@
1
+ # HybridForest
2
+
3
+ <code>HybridForest</code> offers the possibility to build hybrid random forests, i.e., ensembles where the base learners are built from not one but several different decision tree algorithms. As of now, two types of trees are supported:
4
+ * `CARTTree`
5
+ * Performs binary splits at each internal node.
6
+ * Supports categorical and continuous features.
7
+ * Supports binary classification problems.
8
+ * Uses Gini impurity to find the most discriminatory feature.
9
+ * Considers a random subset of features in each split.
10
+ * Loosely based on the original CART algorithm (Breiman et al., 1984).
11
+ * `ID3Tree`
12
+ * Performs multiway (>=2) splits at each internal node.
13
+ * Supports categorical and continuous features.
14
+ * Supports binary classification problems.
15
+ * Uses entropy to find the most discriminatory feature.
16
+ * Considers every feature in max one split.
17
+ * Loosely based on the ID3 algorithm (Quinlan, 1986).
18
+
19
+ The random forest itself is represented by the `RandomForest` class.
20
+ A random forest classifier can be created with one of three base learner configurations.
21
+
22
+ 1. Hybrid mode
23
+ ```ruby
24
+ # Equivalent, hybrid is default.
25
+ HybridForest::RandomForest.new(number_of_trees: 100, ensemble_type: :hybrid)
26
+ HybridForest::RandomForest.new(number_of_trees: 100)
27
+ ```
28
+
29
+
30
+ 2. CART mode
31
+ ```ruby
32
+ HybridForest::RandomForest.new(number_of_trees: 100, ensemble_type: :cart)
33
+ ```
34
+ 3. ID3 mode
35
+
36
+ ```ruby
37
+ HybridForest::RandomForest.new(number_of_trees: 100, ensemble_type: :id3)
38
+ ```
39
+
40
+ The implementation is fairly naive, and several features that would be nice to have are not yet supported, including:
41
+ * Pruning
42
+ * Parallelization
43
+ * More decision trees, e.g., C4.5 and CHAID
44
+ * Additional hyperparameters
45
+
46
+ ## Installation
47
+
48
+ Add this line to your application's Gemfile:
49
+
50
+ ```ruby
51
+ gem 'hybridforest'
52
+ ```
53
+
54
+ And then execute:
55
+
56
+ $ bundle install
57
+
58
+ Or install it yourself as:
59
+
60
+ $ gem install hybridforest
61
+
62
+ ## Usage
63
+
64
+ ```ruby
65
+ require "hybridforest"
66
+
67
+ # Prepare data.
68
+ # A dataset can be passed as a CSV file path, an array of hashes, or a hash of arrays.
69
+ training_set, test_set, actual_test_labels = HybridForest::Utils.train_test_split("data.csv")
70
+
71
+ # Create classifier.
72
+ hybrid_forest = HybridForest::RandomForest.new(number_of_trees: 100)
73
+
74
+ # Fit model.
75
+ hybrid_forest.fit(training_set)
76
+
77
+ # Predict.
78
+ predicted_labels = hybrid_forest.predict(test_set)
79
+
80
+ # Report metrics.
81
+ puts HybridForest::Utils.prediction_report(actual_test_labels, predicted_labels)
82
+ ```
83
+
84
+ ## Development
85
+
86
+ After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
87
+
88
+ To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and the created tag, and push the `.gem` file to [rubygems.org](https://rubygems.org).
89
+
90
+ ## Contributing
91
+
92
+ Bug reports and pull requests are welcome on GitHub at https://github.com/hi-tech-jazz/hybridforest.
93
+
94
+ ## License
95
+
96
+ The gem is available as open source under the terms of the [MIT License](https://opensource.org/licenses/MIT).
data/Rakefile ADDED
@@ -0,0 +1,12 @@
1
+ # frozen_string_literal: true
2
+
3
require "bundler/gem_tasks"
require "rspec/core/rake_task"
require "rubocop/rake_task"

# `rake spec` runs the RSpec suite.
RSpec::Core::RakeTask.new(:spec)

# Defines the `rubocop` lint task referenced by the default task below.
RuboCop::RakeTask.new

# A bare `rake` runs the `spec` and `rubocop` tasks.
task default: %i[spec rubocop]
data/bin/console ADDED
@@ -0,0 +1,15 @@
1
+ #!/usr/bin/env ruby
2
+ # frozen_string_literal: true
3
+
4
require "bundler/setup"
require "hybridforest"

# Put fixtures or initialization code here to make experimenting with the gem
# easier. A different console works too: to use pry, add it to the Gemfile,
# `require "pry"` and call `Pry.start` instead of starting IRB below.

require "irb"
IRB.start(__FILE__)
data/bin/setup ADDED
@@ -0,0 +1,8 @@
1
#!/usr/bin/env bash
# Prepares a fresh checkout for development by installing the bundle.
# Strict mode: abort on command errors, unset variables and pipeline failures.
set -o errexit -o nounset -o pipefail
IFS=$'\n\t'
set -vx

bundle install

# Do any other automated setup that you need to do here
data/hybridforest.gemspec ADDED
@@ -0,0 +1,44 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "lib/hybridforest/version"
4
+
5
Gem::Specification.new do |spec|
  spec.name = "hybridforest"
  spec.version = HybridForest::VERSION
  spec.authors = ["hi-tech-jazz"]
  spec.email = ["jazztechhi@gmail.com"]

  spec.summary = "Hybrid random forest for binary classification tasks."
  # Split across adjacent literals only to respect the 120-column limit; the resulting string is unchanged.
  spec.description = "HybridForest provides random forests built upon combinations of different " \
                     "decision tree algorithms to enable diverse tree ensembles. " \
                     "Until version 1.0.0, please expect breaking changes."
  spec.homepage = "https://github.com/hi-tech-jazz/hybridforest"
  spec.license = "MIT"
  spec.required_ruby_version = ">= 3.0.2"

  spec.metadata["homepage_uri"] = "https://github.com/hi-tech-jazz/hybridforest"
  spec.metadata["source_code_uri"] = "https://github.com/hi-tech-jazz/hybridforest"
  spec.metadata["changelog_uri"] = "https://github.com/hi-tech-jazz/hybridforest/blob/master/CHANGELOG.md"

  # Specify which files should be added to the gem when it is released.
  # `git ls-files -z` lists the git-tracked files; test/spec/features directories are excluded.
  spec.files = Dir.chdir(File.expand_path(__dir__)) do
    `git ls-files -z`.split("\x0").reject { |f| f.match(%r{\A(?:test|spec|features)/}) }
  end
  spec.bindir = "exe"
  spec.executables = spec.files.grep(%r{\Aexe/}) { |f| File.basename(f) }
  spec.require_paths = ["lib"]

  # Runtime dependencies: installed for every consumer of the gem.
  # NOTE(review): rumale may only be needed by the specs — confirm before moving it to development.
  spec.add_dependency "activesupport", "~> 6.1"
  spec.add_dependency "require_all"
  spec.add_dependency "rover-df", "~> 0.2.6"
  spec.add_dependency "rumale"
  spec.add_dependency "terminal-table", "~> 3.0"

  # Build, test and lint tooling is only needed when working on the gem itself,
  # so it is kept out of the runtime set that consumers must install.
  spec.add_development_dependency "rake", "~> 13.0"
  spec.add_development_dependency "rspec", "~> 3.0"
  spec.add_development_dependency "rubocop", "~> 1.7"
end
data/lib/hybridforest/errors/invalid_state_error.rb ADDED
@@ -0,0 +1,9 @@
1
module HybridForest
  module Errors
    ##
    # Raised when an object is asked to do something its current state does not
    # allow, e.g. RandomForest#predict being called before RandomForest#fit.
    # The optional message is handled by StandardError's own constructor.
    class InvalidStateError < StandardError; end
  end
end
data/lib/hybridforest/forests/forest_growers/cart_grower.rb ADDED
@@ -0,0 +1,19 @@
1
+ require_relative "../../utilities/utils"
2
+ require_relative "../../trees/cart_tree"
3
+
4
module HybridForest
  module Forests
    module ForestGrowers
      ##
      # Grows a uniform forest made up exclusively of CART trees.
      class CARTGrower
        ##
        # Fits +number_of_trees+ CARTTree instances and returns them as an array.
        # Each tree is trained on the first (in-bag) part returned by a fresh call
        # to HybridForest::Utils.train_test_bootstrap_split(+instances+); the
        # out-of-bag parts are ignored.
        def grow_forest(instances, number_of_trees)
          number_of_trees.times.map do
            in_bag, _out_of_bag, _out_of_bag_labels = HybridForest::Utils.train_test_bootstrap_split(instances)
            HybridForest::Trees::CARTTree.new.fit(in_bag)
          end
        end
      end
    end
  end
end
data/lib/hybridforest/forests/forest_growers/hybrid_grower.rb ADDED
@@ -0,0 +1,46 @@
1
+ require_relative "../../utilities/utils"
2
+ require_relative "../../trees/id3_tree"
3
+ require_relative "../../trees/cart_tree"
4
+
5
module HybridForest
  module Forests
    module ForestGrowers
      ##
      # Grows a hybrid forest. For every slot in the ensemble, one tree of each
      # type in TREE_TYPES is fitted on the same in-bag sample, and the tree with
      # the highest accuracy on the matching out-of-bag instances is kept.
      class HybridGrower
        TREE_TYPES = [HybridForest::Trees::CARTTree, HybridForest::Trees::ID3Tree].freeze

        ##
        # Returns an array of +number_of_trees+ fitted trees grown from +instances+.
        # A fresh split from HybridForest::Utils.train_test_bootstrap_split is drawn per slot.
        def grow_forest(instances, number_of_trees)
          number_of_trees.times.map do
            in_bag, out_of_bag, out_of_bag_labels = HybridForest::Utils.train_test_bootstrap_split(instances)
            tree_results = grow_trees(TREE_TYPES, in_bag, out_of_bag, out_of_bag_labels)
            select_best_tree(tree_results)
          end
        end

        private

        # Fits one +tree_class+ on +in_bag+ and scores it against the out-of-bag data.
        # Returns a hash with the fitted +:tree+ and its +:oob_accuracy+.
        def fit_and_predict(tree_class, in_bag, out_of_bag, out_of_bag_labels)
          tree = tree_class.new.fit(in_bag)
          tree_predictions = tree.predict(out_of_bag)
          tree_accuracy = HybridForest::Utils.accuracy(tree_predictions, out_of_bag_labels)
          {tree: tree, oob_accuracy: tree_accuracy}
        end

        # Returns the tree with the highest out-of-bag accuracy. On ties, MRI's
        # Enumerable#max_by keeps the earliest result, i.e. the tree type listed
        # first in TREE_TYPES.
        def select_best_tree(tree_results)
          tree_results.max_by { |result| result[:oob_accuracy] }[:tree]
        end

        # Fits and scores one tree per entry in +tree_types+, preserving their order.
        def grow_trees(tree_types, in_bag, out_of_bag, out_of_bag_labels)
          tree_types.map { |tree_type| fit_and_predict(tree_type, in_bag, out_of_bag, out_of_bag_labels) }
        end
      end
    end
  end
end
data/lib/hybridforest/forests/forest_growers/id3_grower.rb ADDED
@@ -0,0 +1,19 @@
1
+ require_relative "../../utilities/utils"
2
+ require_relative "../../trees/id3_tree"
3
+
4
module HybridForest
  module Forests
    module ForestGrowers
      ##
      # Grows a uniform forest made up exclusively of ID3 trees.
      class ID3Grower
        ##
        # Fits +number_of_trees+ ID3Tree instances and returns them as an array.
        # Each tree is trained on the first (in-bag) part returned by a fresh call
        # to HybridForest::Utils.train_test_bootstrap_split(+instances+); the
        # out-of-bag parts are ignored.
        def grow_forest(instances, number_of_trees)
          number_of_trees.times.map do
            in_bag, _out_of_bag, _out_of_bag_labels = HybridForest::Utils.train_test_bootstrap_split(instances)
            HybridForest::Trees::ID3Tree.new.fit(in_bag)
          end
        end
      end
    end
  end
end
data/lib/hybridforest/forests/grower_factory.rb ADDED
@@ -0,0 +1,29 @@
1
+ require_relative "../forests/forest_growers/hybrid_grower"
2
+ require_relative "../forests/forest_growers/cart_grower"
3
+ require_relative "../forests/forest_growers/id3_grower"
4
+
5
module HybridForest
  module Forests
    ##
    # Maps ensemble type symbols to forest grower classes and instantiates them.
    class GrowerFactory
      TYPES = {
        cart: HybridForest::Forests::ForestGrowers::CARTGrower,
        id3: HybridForest::Forests::ForestGrowers::ID3Grower,
        hybrid: HybridForest::Forests::ForestGrowers::HybridGrower
      }.freeze

      ##
      # Returns a new grower for +type+. Types not present in TYPES silently
      # fall back to the default grower class.
      def self.for(type)
        grower_class = TYPES.fetch(type) { default }
        grower_class.new
      end

      ##
      # The grower class used for unrecognized types (the hybrid grower).
      def self.default
        HybridForest::Forests::ForestGrowers::HybridGrower
      end

      ##
      # The ensemble type symbols registered in TYPES.
      def self.types
        TYPES.keys
      end
    end
  end
end
data/lib/hybridforest/forests/random_forest.rb ADDED
@@ -0,0 +1,84 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "../utilities/utils"
4
+
5
module HybridForest
  ##
  # Random forest classifier. The composition of the ensemble is selected with
  # +ensemble_type+ and grown by the matching grower from Forests::GrowerFactory.
  class RandomForest
    ##
    # Creates a new random forest.
    #
    # +number_of_trees+ dictates the size of the tree ensemble; it must be a
    # positive Integer.
    #
    # +ensemble_type+ dictates the composition of the tree ensemble.
    # Valid options are +:hybrid+, +:cart+, +:id3+.
    #
    # Raises ArgumentError if +ensemble_type+ is unknown or +number_of_trees+
    # is not a positive Integer.
    #
    def initialize(number_of_trees:, ensemble_type: :hybrid)
      raise ArgumentError, "Invalid ensemble type" unless Forests::GrowerFactory.types.include? ensemble_type
      # With zero (or negative) trees the growers build an empty forest, and
      # #predict would then silently return no labels at all.
      unless number_of_trees.is_a?(Integer) && number_of_trees.positive?
        raise ArgumentError, "Number of trees must be a positive integer"
      end

      @number_of_trees = number_of_trees
      @ensemble_type = ensemble_type
    end

    ##
    # Fits a model to the given dataset +instances+ and returns +self+.
    #
    # +instances+ is first converted with HybridForest::Utils.to_dataframe.
    #
    def fit(instances)
      instances = HybridForest::Utils.to_dataframe(instances)
      forest_grower = Forests::GrowerFactory.for(@ensemble_type)
      @forest = forest_grower.grow_forest(instances, @number_of_trees)
      self
    end

    ##
    # Predicts a label for each instance in the dataset +instances+ and returns an array of labels.
    #
    # Each label is the majority vote of the individual trees' predictions.
    # Raises Errors::InvalidStateError if called before #fit.
    #
    def predict(instances)
      raise Errors::InvalidStateError, "You must call #fit before you call #predict" if @forest.nil?

      instances = HybridForest::Utils.to_dataframe(instances)
      tree_predictions(instances).map { |votes| majority_vote(votes) }
    end

    ##
    # Returns a table listing how many trees of each type the fitted forest
    # contains, or a placeholder description if #fit has not been called yet.
    #
    def to_s
      return "Empty random forest: \n#{super()}" if @forest.nil?

      title = case @ensemble_type
              when :hybrid then "Hybrid random forest"
              else "Uniform random forest"
              end

      table = Terminal::Table.new do |t|
        t.title = title
        t.headings = %w[Tree Count]
        tally_ensemble.each do |tree_type, count|
          t << [tree_type, count]
          t << :separator
        end
        t << ["Total", @number_of_trees]
      end
      table.to_s
    end

    private

    # Returns the most frequent label in +votes+ (nil if +votes+ is empty).
    # Counting once with tally avoids calling votes.count once per vote (O(n^2)).
    # Hash#tally keeps first-seen order and MRI's max_by keeps the first maximum,
    # so on a tie the label that was voted for first wins.
    def majority_vote(votes)
      votes.tally.max_by { |_label, count| count }&.first
    end

    # Collects each tree's predictions (one array per tree) and transposes them
    # so that every row holds all trees' votes for a single instance.
    def tree_predictions(instances)
      @forest.map { |tree| tree.predict(instances) }.transpose
    end

    # Maps each tree's +name+ to the number of trees in the forest with that name.
    def tally_ensemble
      @forest.map(&:name).tally
    end
  end
end