hybridforest 0.9.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +89 -0
  3. data/.rspec +3 -0
  4. data/.rubocop.yml +13 -0
  5. data/CHANGELOG.md +33 -0
  6. data/Gemfile +22 -0
  7. data/Gemfile.lock +92 -0
  8. data/LICENSE.txt +21 -0
  9. data/README.md +96 -0
  10. data/Rakefile +12 -0
  11. data/bin/console +15 -0
  12. data/bin/setup +8 -0
  13. data/hybridforest.gemspec +44 -0
  14. data/lib/hybridforest/errors/invalid_state_error.rb +9 -0
  15. data/lib/hybridforest/forests/forest_growers/cart_grower.rb +19 -0
  16. data/lib/hybridforest/forests/forest_growers/hybrid_grower.rb +46 -0
  17. data/lib/hybridforest/forests/forest_growers/id3_grower.rb +19 -0
  18. data/lib/hybridforest/forests/grower_factory.rb +29 -0
  19. data/lib/hybridforest/forests/random_forest.rb +84 -0
  20. data/lib/hybridforest/trees/cart_tree.rb +18 -0
  21. data/lib/hybridforest/trees/feature_selectors/all_features.rb +14 -0
  22. data/lib/hybridforest/trees/feature_selectors/max_one_split_per_feature.rb +21 -0
  23. data/lib/hybridforest/trees/feature_selectors/random_feature_subspace.rb +27 -0
  24. data/lib/hybridforest/trees/id3_tree.rb +18 -0
  25. data/lib/hybridforest/trees/impurity_metrics/entropy.rb +21 -0
  26. data/lib/hybridforest/trees/impurity_metrics/gini_impurity.rb +17 -0
  27. data/lib/hybridforest/trees/impurity_metrics/impurity.rb +28 -0
  28. data/lib/hybridforest/trees/nodes/binary_node.rb +37 -0
  29. data/lib/hybridforest/trees/nodes/leaf_node.rb +31 -0
  30. data/lib/hybridforest/trees/nodes/multiway_node.rb +45 -0
  31. data/lib/hybridforest/trees/split.rb +30 -0
  32. data/lib/hybridforest/trees/tests/equal.rb +37 -0
  33. data/lib/hybridforest/trees/tests/equal_or_greater.rb +37 -0
  34. data/lib/hybridforest/trees/tests/less.rb +37 -0
  35. data/lib/hybridforest/trees/tests/not_equal.rb +37 -0
  36. data/lib/hybridforest/trees/tests/test.rb +28 -0
  37. data/lib/hybridforest/trees/tree.rb +46 -0
  38. data/lib/hybridforest/trees/tree_growers/cart_grower.rb +76 -0
  39. data/lib/hybridforest/trees/tree_growers/id3_grower.rb +161 -0
  40. data/lib/hybridforest/utilities/utils.rb +206 -0
  41. data/lib/hybridforest/version.rb +5 -0
  42. data/lib/hybridforest.rb +46 -0
  43. metadata +285 -0
@@ -0,0 +1,18 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "tree"
4
+ require_relative "tree_growers/cart_grower"
5
+
6
module HybridForest
  module Trees
    # Tree whose structure is built by TreeGrowers::CARTGrower unless a
    # different grower is injected.
    class CARTTree < Tree
      # @param tree_grower [#grow_tree] defaults to a fresh TreeGrowers::CARTGrower
      def initialize(tree_grower: HybridForest::Trees::TreeGrowers::CARTGrower.new)
        # Implicit-argument super forwards the tree_grower: keyword unchanged.
        super
      end

      # @return [String] display name of this tree type
      def name
        "CART Tree"
      end
    end
  end
end
@@ -0,0 +1,14 @@
1
+ # frozen_string_literal: true
2
+
3
module HybridForest
  module Trees
    # Feature selector that never narrows the candidate set: every feature is
    # eligible at every split.
    class AllFeatures
      # @param all_features the candidate features
      # @return the very same +all_features+ object, unmodified
      def select_features(all_features)
        all_features.itself
      end

      # Records nothing; present so this selector answers the same
      # #select_features / #update calls as the other feature selectors.
      # @return [nil]
      def update(_feature); end
    end
  end
end
@@ -0,0 +1,21 @@
1
+ # frozen_string_literal: true
2
+
3
module HybridForest
  module Trees
    # Feature selector that offers each feature only until it has been
    # reported through #update; reported features are excluded afterwards.
    class MaxOneSplitPerFeature
      # @return [Array] features reported via #update, in reporting order
      attr_reader :exhausted_features

      def initialize
        @exhausted_features = []
      end

      # @param all_features [Array] the candidate features
      # @return [Array] the candidates that have not been reported via #update
      def select_features(all_features)
        remaining = all_features - exhausted_features
        remaining
      end

      # Marks +feature+ as used so later selections skip it.
      # @return [Array] the (mutated) exhausted feature list
      def update(feature)
        exhausted_features.push(feature)
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "set"
4
+
5
module HybridForest
  module Trees
    # Feature selector implementing a random subspace: each call offers a
    # randomly chosen half (rounded) of the candidate features, each position
    # picked at most once.
    class RandomFeatureSubspace
      # @param random [#rand] randomness source handed to Array#sample.
      #   Defaults to the global generator; pass e.g. Random.new(seed) for
      #   reproducible feature selection.
      def initialize(random: Random)
        @random = random
      end

      # @param all_features [Array] the candidate features
      # @return [Array] round(all_features.count / 2) features drawn without
      #   replacement (empty when +all_features+ is empty)
      def select_features(all_features)
        subspace_size = default_subspace_size(all_features.count)
        # Array#sample draws distinct positions directly, replacing the former
        # retry-until-unique loop over random indices.
        all_features.sample(subspace_size, random: @random)
      end

      # Stateless selector: nothing to record.
      def update(_feature); end

      private

      # Half of the features, rounded half away from zero (1 feature -> 1).
      def default_subspace_size(num_of_features)
        (num_of_features.to_f / 2).round
      end
    end
  end
end
@@ -0,0 +1,18 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "tree"
4
+ require_relative "tree_growers/id3_grower"
5
+
6
module HybridForest
  module Trees
    # Tree whose structure is built by TreeGrowers::ID3Grower unless a
    # different grower is injected.
    class ID3Tree < Tree
      # @param tree_grower [#grow_tree] defaults to a fresh TreeGrowers::ID3Grower
      def initialize(tree_grower: HybridForest::Trees::TreeGrowers::ID3Grower.new)
        # Implicit-argument super forwards the tree_grower: keyword unchanged.
        super
      end

      # @return [String] display name of this tree type
      def name
        "ID3 Tree"
      end
    end
  end
end
@@ -0,0 +1,21 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "impurity"
4
+
5
module HybridForest
  module Trees
    # Shannon-entropy impurity metric (base-2 logarithm, i.e. bits); the
    # information-gain helpers come from the Impurity mixin.
    class Entropy
      include HybridForest::Trees::Impurity

      # Entropy of the label distribution: sum over labels of p * log2(1 / p).
      # @param instances dataset responding to #count and #count_labels
      #   (assumed to be a Hash-like label => occurrences map; only its values
      #   are read)
      # @return [Numeric] 0 when there are no labels to sum over
      def compute(instances)
        total = instances.count
        instances.count_labels.values.sum do |occurrences|
          share = occurrences.to_f / total
          share * Math.log2(1 / share)
        end
      end
    end
  end
end
@@ -0,0 +1,17 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "impurity"
4
+
5
module HybridForest
  module Trees
    # Gini impurity metric; the information-gain helpers come from the
    # Impurity mixin.
    class GiniImpurity
      include HybridForest::Trees::Impurity

      # Gini impurity of the label distribution: 1 - sum over labels of p**2.
      # @param instances dataset responding to #count and #count_labels
      #   (assumed to be a Hash-like label => occurrences map)
      # @return [Float] 0.0 for an empty dataset, otherwise a value in [0, 1)
      def compute(instances)
        total_count = instances.count.to_f
        # With no instances the squared-proportion sum is 0, so the formula
        # would report 1.0 (maximal impurity). An empty set has no impurity,
        # which also matches Entropy returning 0 for empty input.
        return 0.0 if total_count.zero?

        label_counts = instances.count_labels.each_value
        1 - label_counts.sum { |count| (count / total_count)**2 }
      end
    end
  end
end
@@ -0,0 +1,28 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "active_support"
4
+
5
module HybridForest
  module Trees
    # Mixin for impurity metrics. Including classes implement #compute; this
    # module derives weighted child impurity and information gain from it.
    module Impurity
      # Information gain of splitting a parent into +children+.
      # @param children [Array, nil] subsets produced by a split; each must
      #   respond to #count and be accepted by #compute
      # @param parent_impurity [Numeric] impurity of the unsplit parent
      # @return [Float] parent_impurity minus the count-weighted impurity of
      #   the children; 0.0 when there are no children or they hold no instances
      def information_gain(children, parent_impurity)
        # Plain nil/empty? check instead of ActiveSupport's #blank?, which a
        # bare `require "active_support"` does not guarantee to load.
        return 0.0 if children.nil? || children.empty?

        parent_count = children.sum(&:count)
        # No instances means no meaningful weights (0 / 0 would yield NaN).
        return 0.0 if parent_count.zero?

        children_impurity = children.sum { |child| weighted_impurity(child, parent_count) }
        parent_impurity - children_impurity
      end

      # Impurity of one child scaled by its share of the parent's instances.
      # @param instances_in_child subset responding to #count
      # @param parent_count [Integer] total instances across all children
      # @return [Float]
      def weighted_impurity(instances_in_child, parent_count)
        weight = instances_in_child.count.to_f / parent_count
        compute(instances_in_child) * weight
      end

      # Impurity of +instances+; must be overridden by including classes.
      # @raise [NotImplementedError] always, in this default implementation
      def compute(instances)
        raise NotImplementedError, "Must be implemented by including classes"
      end
    end
  end
end
@@ -0,0 +1,37 @@
1
+ # frozen_string_literal: true
2
+
3
+ module HybridForest
4
+ module Trees
5
+ class BinaryNode
6
+ attr_reader :test, :true_branch, :false_branch
7
+
8
+ def initialize(test, true_branch, false_branch)
9
+ @test = test
10
+ @true_branch = true_branch
11
+ @false_branch = false_branch
12
+ end
13
+
14
+ def branch_for(instance)
15
+ if @test.passed_by? instance
16
+ @true_branch
17
+ else
18
+ @false_branch
19
+ end
20
+ end
21
+
22
+ def classify(instance)
23
+ branch_for(instance).classify(instance)
24
+ end
25
+
26
+ def print_string(spacing = "")
27
+ print spacing
28
+ puts "#{@test} True"
29
+ @true_branch.print_string(spacing + " ")
30
+
31
+ print spacing
32
+ puts "#{@test} False"
33
+ @false_branch.print_string(spacing + " ")
34
+ end
35
+ end
36
+ end
37
+ end
@@ -0,0 +1,31 @@
1
+ # frozen_string_literal: true
2
+
3
module HybridForest
  module Trees
    # Terminal tree node that always predicts the majority label of the
    # instances it was built from.
    class LeafNode
      # @param instances dataset responding to #label (the label column key)
      #   and #[] (column access)
      def initialize(instances)
        @prediction = majority_vote(instances)
      end

      # @return the stored prediction; the argument is ignored
      def classify(_instance)
        @prediction
      end

      # Prints this leaf at the given indentation.
      def print_string(spacing = "")
        print spacing
        puts to_s
      end

      def to_s
        "Predict #{@prediction}"
      end

      private

      # Most frequent label; ties go to the label seen first, and nil is
      # returned when there are no instances. Labels are counted in one pass
      # rather than calling Enumerable#count for every element (O(n**2)).
      def majority_vote(instances)
        label_counts = instances[instances.label].to_enum.each_with_object(Hash.new(0)) do |label, counts|
          counts[label] += 1
        end
        label_counts.max_by { |_label, count| count }&.first
      end
    end
  end
end
@@ -0,0 +1,45 @@
1
+ # frozen_string_literal: true
2
+
3
+ module HybridForest
4
+ module Trees
5
+ class MultiwayNode
6
+ attr_reader :paths, :majority_class
7
+
8
+ def initialize(paths, instances)
9
+ @paths = paths
10
+ @majority_class = majority_vote(instances)
11
+ end
12
+
13
+ def tests
14
+ @paths.keys
15
+ end
16
+
17
+ def branches
18
+ @paths.values
19
+ end
20
+
21
+ def branch_for(instance)
22
+ match = tests.find { |test| test.passed_by? instance }
23
+ paths[match]
24
+ end
25
+
26
+ def classify(instance)
27
+ branch = branch_for instance
28
+ branch&.classify(instance) || @majority_class
29
+ end
30
+
31
+ def print_string(spacing = "")
32
+ paths.each do |test, branch|
33
+ print spacing
34
+ puts "#{test} True"
35
+ branch.print_string(spacing + " ")
36
+ end
37
+ end
38
+
39
+ def majority_vote(instances)
40
+ labels = instances[instances.label].to_enum
41
+ labels.max_by(1) { |label| labels.count(label) }.first
42
+ end
43
+ end
44
+ end
45
+ end
@@ -0,0 +1,30 @@
1
# frozen_string_literal: true

module HybridForest
  module Trees
    # A candidate split of a dataset on one feature, ordered by information
    # gain so splits can be compared with better_than? / worse_than?.
    class Split
      include Comparable

      attr_reader :feature, :value, :subsets, :info_gain

      alias_method :better_than?, :>
      alias_method :worse_than?, :<

      # @param feature the feature being split on
      # @param info_gain [Numeric] gain achieved by the split
      # @param subsets [Array] partitions produced by the split
      # @param value threshold/category for a binary split; nil for multiway
      def initialize(feature, info_gain: 0, subsets: [Rover::DataFrame.new, Rover::DataFrame.new], value: nil)
        @feature = feature
        @value = value
        @subsets = subsets
        @info_gain = info_gain
      end

      # Orders by info_gain. Returns nil for objects without #info_gain, so
      # `split == nil` is false instead of raising NoMethodError (ordering
      # operators then raise ArgumentError, per Comparable).
      def <=>(other)
        return nil unless other.respond_to?(:info_gain)

        info_gain <=> other.info_gain
      end

      # True when the split carries a value. Checks nil? rather than
      # truthiness so a split on the value `false` still counts as binary.
      def binary?
        !value.nil?
      end

      def multiway?
        !binary?
      end
    end
  end
end
@@ -0,0 +1,37 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "test"
4
+
5
module HybridForest
  module Trees
    module Tests
      # Split test that passes when an instance's +feature+ equals +value+.
      # The (feature, value) constructor is inherited unchanged from Test.
      class Equal < Test
        # @param instance row indexable by feature (#[])
        def passed_by?(instance)
          instance[feature] == value
        end

        # Loose equality: another Equal whose feature and value are ==.
        def ==(other)
          other.is_a?(Equal) && feature == other.feature && value == other.value
        end

        # Strict equality for Hash keys; pairs with Test#hash.
        def eql?(other)
          other.is_a?(Equal) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} == #{value}?"
        end

        def description
          "#{feature} equal to #{value}?"
        end
      end
    end
  end
end
@@ -0,0 +1,37 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "test"
4
+
5
module HybridForest
  module Trees
    module Tests
      # Split test that passes when an instance's +feature+ is >= +value+.
      # The (feature, value) constructor is inherited unchanged from Test.
      class EqualOrGreater < Test
        # @param instance row indexable by feature (#[])
        def passed_by?(instance)
          instance[feature] >= value
        end

        # Loose equality: another EqualOrGreater whose feature and value are ==.
        def ==(other)
          other.is_a?(EqualOrGreater) && feature == other.feature && value == other.value
        end

        # Strict equality for Hash keys; pairs with Test#hash.
        def eql?(other)
          other.is_a?(EqualOrGreater) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} >= #{value}?"
        end

        def description
          "#{feature} equal to or greater than #{value}?"
        end
      end
    end
  end
end
@@ -0,0 +1,37 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "test"
4
+
5
module HybridForest
  module Trees
    module Tests
      # Split test that passes when an instance's +feature+ is < +value+.
      # The (feature, value) constructor is inherited unchanged from Test.
      class Less < Test
        # @param instance row indexable by feature (#[])
        def passed_by?(instance)
          instance[feature] < value
        end

        # Loose equality: another Less whose feature and value are ==.
        def ==(other)
          other.is_a?(Less) && feature == other.feature && value == other.value
        end

        # Strict equality for Hash keys; pairs with Test#hash.
        def eql?(other)
          other.is_a?(Less) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} < #{value}?"
        end

        def description
          "#{feature} less than #{value}?"
        end
      end
    end
  end
end
@@ -0,0 +1,37 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "test"
4
+
5
module HybridForest
  module Trees
    module Tests
      # Split test that passes when an instance's +feature+ differs from +value+.
      # The (feature, value) constructor is inherited unchanged from Test.
      class NotEqual < Test
        # @param instance row indexable by feature (#[])
        def passed_by?(instance)
          instance[feature] != value
        end

        # Loose equality: another NotEqual whose feature and value are ==.
        def ==(other)
          other.is_a?(NotEqual) && feature == other.feature && value == other.value
        end

        # Strict equality for Hash keys; pairs with Test#hash.
        def eql?(other)
          other.is_a?(NotEqual) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} != #{value}?"
        end

        def description
          "#{feature} not equal to #{value}?"
        end
      end
    end
  end
end
@@ -0,0 +1,28 @@
1
+ # frozen_string_literal: true
2
+
3
module HybridForest
  module Trees
    module Tests
      # Abstract base for split tests: holds the feature/value pair that
      # subclasses evaluate in #passed_by?.
      class Test
        attr_reader :feature, :value

        # @param feature the feature (column key) the test reads
        # @param value the value the feature is compared against
        def initialize(feature, value)
          @feature = feature
          @value = value
        end

        # @abstract
        # @raise [NotImplementedError] unless overridden by a subclass
        def passed_by?(_instance)
          raise NotImplementedError, "Must be implemented by subclasses"
        end

        # @abstract
        # @raise [NotImplementedError] unless overridden by a subclass
        def description
          raise NotImplementedError, "Must be implemented by subclasses"
        end

        # Hash consistent with subclasses' eql? (eql? feature and value give
        # equal hashes). Array#hash is order-sensitive, unlike the previous
        # XOR, which collapsed to 0 when feature == value and collided for
        # swapped (feature, value) pairs.
        def hash
          [feature, value].hash
        end
      end
    end
  end
end
@@ -0,0 +1,46 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "../utilities/utils"
4
+
5
require "stringio"

module HybridForest
  module Trees
    # Decision tree whose node structure is produced by a pluggable
    # tree-growing strategy.
    class Tree
      # Creates a new Tree using the specified tree growing algorithm.
      # @param tree_grower [#grow_tree] builds the root node from a dataset
      def initialize(tree_grower:)
        @tree_grower = tree_grower
      end

      ##
      # Fits a model to the given dataset +instances+ and returns +self+.
      # +instances+ is first converted with HybridForest::Utils.to_dataframe.
      #
      def fit(instances)
        instances = HybridForest::Utils.to_dataframe(instances)
        @root = @tree_grower.grow_tree(instances)
        self
      end

      ##
      # Predicts a label for each instance in the dataset +instances+ and returns an array of labels.
      # @raise [Errors::InvalidStateError] if called before #fit
      #
      def predict(instances)
        if @root.nil?
          raise Errors::InvalidStateError,
                "You must call #fit before you call #predict"
        end

        HybridForest::Utils.to_dataframe(instances).each_row.map { |instance| @root.classify(instance) }
      end

      # Returns a String representation of this Tree: the rendered node
      # hierarchy once fitted, or an "Empty tree" description before #fit.
      # Nodes render themselves by printing to $stdout, so that output is
      # captured and returned. Previously it leaked to the console and
      # #inspect itself returned nil.
      def inspect
        return "Empty tree: #{super}" if @root.nil?

        capture_stdout { @root.print_string }
      end

      private

      # Runs the block with $stdout redirected to an in-memory buffer and
      # returns everything written; $stdout is always restored. Note that
      # $stdout is process-global, so output from other threads during the
      # call would be captured as well.
      def capture_stdout
        original_stdout = $stdout
        $stdout = StringIO.new
        yield
        $stdout.string
      ensure
        $stdout = original_stdout
      end
    end
  end
end
@@ -0,0 +1,76 @@
1
+ # frozen_string_literal: true
2
+
3
require_relative "../../utilities/utils"
require_relative "../split"
require_relative "../feature_selectors/random_feature_subspace"
require_relative "../feature_selectors/all_features"
require_relative "../feature_selectors/max_one_split_per_feature"
require_relative "../impurity_metrics/gini_impurity"
require_relative "../impurity_metrics/entropy"
require_relative "../nodes/binary_node"
require_relative "../nodes/leaf_node"
require_relative "../tests/equal"
require_relative "../tests/equal_or_greater"
13
+
14
module HybridForest
  module Trees
    module TreeGrowers
      # Grows binary decision trees CART-style. Each internal node holds the
      # single (feature, value) test with the highest information gain among
      # the features offered by the feature selector. Growth stops when no
      # candidate split yields positive gain.
      class CARTGrower
        # @param feature_selector [#select_features] picks the features considered at each node
        # @param impurity_metric [#compute, #information_gain] scores candidate splits
        def initialize(feature_selector: RandomFeatureSubspace.new, impurity_metric: GiniImpurity.new)
          @impurity_metric = impurity_metric
          @feature_selector = feature_selector
        end

        # Recursively grows a subtree for +instances+.
        # @return [BinaryNode, LeafNode] a leaf when no split improves purity
        def grow_tree(instances)
          split = find_best_split(instances)
          # find_best_split only keeps strict improvements over a zero-gain
          # default, so gain is never negative; <= is purely defensive.
          return LeafNode.new(instances) if split.info_gain <= 0

          branch(instances, split)
        end

        private

        # Grows both subsets of +split+ and joins them under a BinaryNode.
        def branch(instances, split)
          true_instances, false_instances = split.subsets
          true_branch = grow_tree(true_instances)
          false_branch = grow_tree(false_instances)
          numeric = instances[split.feature].numeric?
          test = build_test(split.feature, split.value, numeric)
          BinaryNode.new(test, true_branch, false_branch)
        end

        # The node's test must mirror how #partition split the training data.
        # Numeric features use >=. Categorical features use ==, because always
        # using >= would route strings lexicographically at prediction time and
        # raise NoMethodError for booleans. Assumes equal_split returns
        # [matching, non-matching], as equal_or_greater_split's first subset
        # is the >= side.
        def build_test(feature, value, numeric)
          if numeric
            Tests::EqualOrGreater.new(feature, value)
          else
            Tests::Equal.new(feature, value)
          end
        end

        # Two-way partition of +instances+ on +feature+ at +value+.
        def partition(instances, feature, value, numeric)
          if numeric
            instances.equal_or_greater_split(feature, value)
          else
            instances.equal_split(feature, value)
          end
        end

        # Highest-gain split over every value of every considered feature;
        # falls back to a zero-gain default when nothing improves purity.
        def find_best_split(instances)
          considered_features = @feature_selector.select_features(instances.features)
          best_split = default_split(instances, considered_features)
          return best_split if considered_features.empty?

          current_impurity = @impurity_metric.compute(instances)
          considered_features.each do |feature|
            # Column lookup and type check are per feature, not per value.
            column = instances[feature]
            numeric = column.numeric?
            column.uniq.each do |value|
              subsets = partition(instances, feature, value, numeric)
              # A split that leaves one side empty separates nothing.
              next if subsets.any? { |subset| subset.count.zero? }

              info_gain = @impurity_metric.information_gain(subsets, current_impurity)
              split = Split.new(feature, info_gain: info_gain, subsets: subsets, value: value)
              best_split = split if split.better_than?(best_split)
            end
          end

          best_split
        end

        # Zero-gain placeholder split. With no features to consider it carries
        # a nil feature instead of indexing the dataset with nil.
        def default_split(instances, features)
          return Split.new(nil) if features.empty?

          first_feature = features.first
          Split.new(first_feature, value: instances[first_feature].first)
        end
      end
    end
  end
end