hybridforest 0.9.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (43) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +89 -0
  3. data/.rspec +3 -0
  4. data/.rubocop.yml +13 -0
  5. data/CHANGELOG.md +33 -0
  6. data/Gemfile +22 -0
  7. data/Gemfile.lock +92 -0
  8. data/LICENSE.txt +21 -0
  9. data/README.md +96 -0
  10. data/Rakefile +12 -0
  11. data/bin/console +15 -0
  12. data/bin/setup +8 -0
  13. data/hybridforest.gemspec +44 -0
  14. data/lib/hybridforest/errors/invalid_state_error.rb +9 -0
  15. data/lib/hybridforest/forests/forest_growers/cart_grower.rb +19 -0
  16. data/lib/hybridforest/forests/forest_growers/hybrid_grower.rb +46 -0
  17. data/lib/hybridforest/forests/forest_growers/id3_grower.rb +19 -0
  18. data/lib/hybridforest/forests/grower_factory.rb +29 -0
  19. data/lib/hybridforest/forests/random_forest.rb +84 -0
  20. data/lib/hybridforest/trees/cart_tree.rb +18 -0
  21. data/lib/hybridforest/trees/feature_selectors/all_features.rb +14 -0
  22. data/lib/hybridforest/trees/feature_selectors/max_one_split_per_feature.rb +21 -0
  23. data/lib/hybridforest/trees/feature_selectors/random_feature_subspace.rb +27 -0
  24. data/lib/hybridforest/trees/id3_tree.rb +18 -0
  25. data/lib/hybridforest/trees/impurity_metrics/entropy.rb +21 -0
  26. data/lib/hybridforest/trees/impurity_metrics/gini_impurity.rb +17 -0
  27. data/lib/hybridforest/trees/impurity_metrics/impurity.rb +28 -0
  28. data/lib/hybridforest/trees/nodes/binary_node.rb +37 -0
  29. data/lib/hybridforest/trees/nodes/leaf_node.rb +31 -0
  30. data/lib/hybridforest/trees/nodes/multiway_node.rb +45 -0
  31. data/lib/hybridforest/trees/split.rb +30 -0
  32. data/lib/hybridforest/trees/tests/equal.rb +37 -0
  33. data/lib/hybridforest/trees/tests/equal_or_greater.rb +37 -0
  34. data/lib/hybridforest/trees/tests/less.rb +37 -0
  35. data/lib/hybridforest/trees/tests/not_equal.rb +37 -0
  36. data/lib/hybridforest/trees/tests/test.rb +28 -0
  37. data/lib/hybridforest/trees/tree.rb +46 -0
  38. data/lib/hybridforest/trees/tree_growers/cart_grower.rb +76 -0
  39. data/lib/hybridforest/trees/tree_growers/id3_grower.rb +161 -0
  40. data/lib/hybridforest/utilities/utils.rb +206 -0
  41. data/lib/hybridforest/version.rb +5 -0
  42. data/lib/hybridforest.rb +46 -0
  43. metadata +285 -0
@@ -0,0 +1,18 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "tree"
4
+ require_relative "tree_growers/cart_grower"
5
+
6
module HybridForest
  module Trees
    # Decision tree that is grown by TreeGrowers::CARTGrower unless another
    # grower is injected.
    class CARTTree < Tree
      # +tree_grower+ is handed to Tree#initialize unchanged; a fresh
      # CARTGrower is built when the caller does not supply one.
      def initialize(tree_grower: TreeGrowers::CARTGrower.new)
        # Bare +super+ forwards the current value of the +tree_grower+ keyword.
        super
      end

      # Display name of this kind of tree.
      def name
        "CART Tree"
      end
    end
  end
end
@@ -0,0 +1,14 @@
1
+ # frozen_string_literal: true
2
+
3
module HybridForest
  module Trees
    # Feature selector that never narrows the set of candidate features.
    class AllFeatures
      # Hands back +all_features+ itself (the same object, not a copy).
      def select_features(all_features)
        all_features.itself
      end

      # Present so that every selector answers +update+; this selector keeps
      # no state, so the argument is ignored and +nil+ is returned.
      def update(_feature)
        nil
      end
    end
  end
end
@@ -0,0 +1,21 @@
1
+ # frozen_string_literal: true
2
+
3
module HybridForest
  module Trees
    # Feature selector that stops offering a feature once it has been
    # reported through +update+.
    class MaxOneSplitPerFeature
      # Features reported so far, in the order they were reported.
      attr_reader :exhausted_features

      def initialize
        @exhausted_features = []
      end

      # Returns +all_features+ minus every feature already reported via
      # +update+.
      def select_features(all_features)
        still_available = all_features - @exhausted_features
        still_available
      end

      # Records +feature+ as used and returns the list of exhausted features.
      def update(feature)
        @exhausted_features.push(feature)
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "set"
4
+
5
module HybridForest
  module Trees
    # Feature selector that offers a random half of the available features on
    # every call.
    class RandomFeatureSubspace
      # Returns a random selection of +all_features+ drawn from distinct
      # positions, so no position is picked twice. The selection holds half
      # of the features (rounded) and comes back in random order.
      #
      # Array#sample replaces the previous draw-and-retry loop over a Set of
      # indices: it performs the same "n distinct positions" sampling without
      # repeated random draws.
      def select_features(all_features)
        all_features.sample(default_subspace_size(all_features.count))
      end

      # This selector keeps no state between calls; the argument is ignored.
      def update(_feature)
      end

      private

      # Half of +num_of_features+, rounded to the nearest integer
      # (1 -> 1, 3 -> 2, 7 -> 4).
      def default_subspace_size(num_of_features)
        (num_of_features.to_f / 2).round
      end
    end
  end
end
@@ -0,0 +1,18 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "tree"
4
+ require_relative "tree_growers/id3_grower"
5
+
6
module HybridForest
  module Trees
    # Decision tree that is grown by TreeGrowers::ID3Grower unless another
    # grower is injected.
    class ID3Tree < Tree
      # +tree_grower+ is handed to Tree#initialize unchanged; a fresh
      # ID3Grower is built when the caller does not supply one.
      def initialize(tree_grower: TreeGrowers::ID3Grower.new)
        # Bare +super+ forwards the current value of the +tree_grower+ keyword.
        super
      end

      # Display name of this kind of tree.
      def name
        "ID3 Tree"
      end
    end
  end
end
@@ -0,0 +1,21 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "impurity"
4
+
5
module HybridForest
  module Trees
    # Impurity metric based on base-2 entropy of the label distribution.
    class Entropy
      include HybridForest::Trees::Impurity

      # Sums p * log2(1 / p) over the labels of +instances+, where p is a
      # label's count divided by the number of instances.
      #
      # +instances+ must answer +count+ and +count_labels+; the latter is
      # used as a mapping whose values are per-label counts.
      def compute(instances)
        total = instances.count
        instances.count_labels.values.sum do |occurrences|
          share = occurrences.to_f / total
          share * Math.log2(1 / share)
        end
      end
    end
  end
end
@@ -0,0 +1,17 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "impurity"
4
+
5
module HybridForest
  module Trees
    # Impurity metric computed as one minus the sum of squared label shares.
    class GiniImpurity
      include HybridForest::Trees::Impurity

      # Returns 1 - sum((count / total)**2) over the label counts of
      # +instances+.
      #
      # +instances+ must answer +count+ and +count_labels+; the latter is
      # used as a mapping whose values are per-label counts.
      def compute(instances)
        size = instances.count.to_f
        squared_shares = instances.count_labels.each_value.sum do |occurrences|
          (occurrences / size)**2
        end
        1 - squared_shares
      end
    end
  end
end
@@ -0,0 +1,28 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "active_support"
4
+
5
module HybridForest
  module Trees
    # Shared behaviour for impurity metrics. Including classes supply
    # +compute+; this module derives weighted impurity and information gain
    # from it.
    module Impurity
      # Returns +parent_impurity+ minus the size-weighted impurity of
      # +children+.
      #
      # +children+ is a collection of subsets, each answering +count+.
      # Returns 0.0 when +children+ is nil or empty, or when the subsets hold
      # no instances at all (a zero total would otherwise turn every weight
      # into NaN).
      def information_gain(children, parent_impurity)
        # Plain Ruby check: does not rely on ActiveSupport's +blank?+
        # core extension having been loaded.
        return 0.0 if children.nil? || children.empty?

        parent_count = children.sum(&:count)
        return 0.0 if parent_count.zero?

        children_impurity = children.sum do |child|
          weighted_impurity(child, parent_count)
        end
        parent_impurity - children_impurity
      end

      # Impurity of +instances_in_child+ scaled by its share of
      # +parent_count+.
      def weighted_impurity(instances_in_child, parent_count)
        weight = instances_in_child.count.to_f / parent_count
        compute(instances_in_child) * weight
      end

      # Impurity of +instances+. Raises NotImplementedError unless the
      # including class overrides it.
      def compute(instances)
        raise NotImplementedError, "Must be implemented by including classes"
      end
    end
  end
end
@@ -0,0 +1,37 @@
1
+ # frozen_string_literal: true
2
+
3
module HybridForest
  module Trees
    # Internal tree node holding one test and the two subtrees for the
    # passing and failing outcome.
    class BinaryNode
      attr_reader :test, :true_branch, :false_branch

      def initialize(test, true_branch, false_branch)
        @test = test
        @true_branch = true_branch
        @false_branch = false_branch
      end

      # Subtree that +instance+ is routed to: the true branch when the test
      # is passed, the false branch otherwise.
      def branch_for(instance)
        @test.passed_by?(instance) ? @true_branch : @false_branch
      end

      # Delegates classification of +instance+ to the subtree it is routed to.
      def classify(instance)
        subtree = branch_for(instance)
        subtree.classify(instance)
      end

      # Writes this node and both subtrees to standard output, the true
      # outcome first. Each level of depth adds one space to +spacing+.
      # Returns whatever the false branch's +print_string+ returns.
      def print_string(spacing = "")
        print_outcome("True", @true_branch, spacing)
        print_outcome("False", @false_branch, spacing)
      end

      private

      # Prints the test with its +outcome+ label, then the matching subtree.
      def print_outcome(outcome, subtree, spacing)
        print spacing
        puts "#{@test} #{outcome}"
        subtree.print_string(spacing + " ")
      end
    end
  end
end
@@ -0,0 +1,31 @@
1
+ # frozen_string_literal: true
2
+
3
module HybridForest
  module Trees
    # Terminal tree node that always predicts the most frequent label of the
    # instances it was built from.
    class LeafNode
      # +instances+ must answer +label+ (the name of the label column) and
      # +[]+ (returning that column's values).
      def initialize(instances)
        @prediction = majority_vote(instances)
      end

      # Returns the stored prediction; the instance itself is not inspected.
      def classify(_instance)
        @prediction
      end

      # Writes +spacing+ followed by the prediction to standard output.
      def print_string(spacing = "")
        print spacing
        puts to_s
      end

      def to_s
        "Predict #{@prediction}"
      end

      private

      # Most frequent label among +instances+, or nil when there are none.
      # Labels are counted in a single pass instead of re-counting the whole
      # column for every element. Ties go to the label that appears first.
      def majority_vote(instances)
        labels = instances[instances.label].to_enum
        counts = labels.each_with_object(Hash.new(0)) do |label, seen|
          seen[label] += 1
        end
        counts.max_by { |_label, occurrences| occurrences }&.first
      end
    end
  end
end
@@ -0,0 +1,45 @@
1
+ # frozen_string_literal: true
2
+
3
module HybridForest
  module Trees
    # Internal tree node with one outgoing branch per test. +paths+ maps each
    # test to the subtree followed when that test is passed.
    class MultiwayNode
      attr_reader :paths, :majority_class

      # +instances+ must answer +label+ (the name of the label column) and
      # +[]+ (returning that column's values). Its majority label becomes the
      # fallback prediction.
      def initialize(paths, instances)
        @paths = paths
        @majority_class = majority_vote(instances)
      end

      # The tests guarding the outgoing branches.
      def tests
        @paths.keys
      end

      # The subtrees reachable from this node.
      def branches
        @paths.values
      end

      # Subtree for the first test that +instance+ passes; when no test is
      # passed the result of looking up nil in +paths+ is returned.
      def branch_for(instance)
        match = tests.find { |test| test.passed_by? instance }
        paths[match]
      end

      # Classifies +instance+ through the matching branch. Falls back to the
      # majority class only when there is no matching branch or the branch
      # yields nil, so a legitimate +false+ label is kept rather than being
      # replaced by the fallback.
      def classify(instance)
        prediction = branch_for(instance)&.classify(instance)
        prediction.nil? ? @majority_class : prediction
      end

      # Writes every test and its subtree to standard output. Each level of
      # depth adds one space to +spacing+.
      def print_string(spacing = "")
        paths.each do |test, branch|
          print spacing
          puts "#{test} True"
          branch.print_string(spacing + " ")
        end
      end

      # Most frequent label among +instances+, or nil when there are none.
      # Labels are counted in a single pass instead of re-counting the whole
      # column for every element. Ties go to the label that appears first.
      def majority_vote(instances)
        labels = instances[instances.label].to_enum
        counts = labels.each_with_object(Hash.new(0)) do |label, seen|
          seen[label] += 1
        end
        counts.max_by { |_label, occurrences| occurrences }&.first
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module HybridForest
  module Trees
    # A candidate partition of a dataset on one feature. Splits are ordered by
    # information gain, so +better_than?+ / +worse_than?+ (and every other
    # Comparable method) compare gains.
    class Split
      include Comparable

      attr_reader :feature, :value, :subsets, :info_gain
      alias_method :better_than?, :>
      alias_method :worse_than?, :<

      # +feature+::   the feature the data was split on
      # +info_gain+:: information gain achieved by the split (default 0)
      # +subsets+::   the partitions produced by the split (default: two
      #               empty Rover data frames)
      # +value+::     the feature value a binary split was made on; nil means
      #               the split has no single value
      def initialize(feature, info_gain: 0, subsets: [Rover::DataFrame.new, Rover::DataFrame.new], value: nil)
        @feature = feature
        @value = value
        @subsets = subsets
        @info_gain = info_gain
      end

      def <=>(other)
        info_gain <=> other.info_gain
      end

      # True when the split was made on a specific value. Only nil counts as
      # "no value", so a split on the boolean value +false+ is still binary.
      def binary?
        !value.nil?
      end

      def multiway?
        !binary?
      end
    end
  end
end
@@ -0,0 +1,37 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "test"
4
+
5
module HybridForest
  module Trees
    module Tests
      # Test passed by instances whose +feature+ equals +value+.
      # Construction (feature, value) is inherited from Test as is.
      class Equal < Test
        def passed_by?(instance)
          instance[feature] == value
        end

        # Equal to another Equal test on the same feature and value (by ==).
        def ==(other)
          other.is_a?(Equal) && feature == other.feature && value == other.value
        end

        # Same as ==, but feature and value are compared with eql?.
        def eql?(other)
          other.is_a?(Equal) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} == #{value}?"
        end

        def description
          "#{feature} equal to #{value}?"
        end
      end
    end
  end
end
@@ -0,0 +1,37 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "test"
4
+
5
module HybridForest
  module Trees
    module Tests
      # Test passed by instances whose +feature+ is at least +value+.
      # Construction (feature, value) is inherited from Test as is.
      class EqualOrGreater < Test
        def passed_by?(instance)
          instance[feature] >= value
        end

        # Equal to another EqualOrGreater test on the same feature and value
        # (by ==).
        def ==(other)
          other.is_a?(EqualOrGreater) && feature == other.feature && value == other.value
        end

        # Same as ==, but feature and value are compared with eql?.
        def eql?(other)
          other.is_a?(EqualOrGreater) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} >= #{value}?"
        end

        def description
          "#{feature} equal to or greater than #{value}?"
        end
      end
    end
  end
end
@@ -0,0 +1,37 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "test"
4
+
5
module HybridForest
  module Trees
    module Tests
      # Test passed by instances whose +feature+ is below +value+.
      # Construction (feature, value) is inherited from Test as is.
      class Less < Test
        def passed_by?(instance)
          instance[feature] < value
        end

        # Equal to another Less test on the same feature and value (by ==).
        def ==(other)
          other.is_a?(Less) && feature == other.feature && value == other.value
        end

        # Same as ==, but feature and value are compared with eql?.
        def eql?(other)
          other.is_a?(Less) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} < #{value}?"
        end

        def description
          "#{feature} less than #{value}?"
        end
      end
    end
  end
end
@@ -0,0 +1,37 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "test"
4
+
5
module HybridForest
  module Trees
    module Tests
      # Test passed by instances whose +feature+ differs from +value+.
      # Construction (feature, value) is inherited from Test as is.
      class NotEqual < Test
        def passed_by?(instance)
          instance[feature] != value
        end

        # Equal to another NotEqual test on the same feature and value
        # (by ==).
        def ==(other)
          other.is_a?(NotEqual) && feature == other.feature && value == other.value
        end

        # Same as ==, but feature and value are compared with eql?.
        def eql?(other)
          other.is_a?(NotEqual) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} != #{value}?"
        end

        def description
          "#{feature} not equal to #{value}?"
        end
      end
    end
  end
end
@@ -0,0 +1,28 @@
1
+ # frozen_string_literal: true
2
+
3
module HybridForest
  module Trees
    module Tests
      # Base class for the checks stored in tree nodes: a +feature+ paired
      # with the +value+ it is compared against. Subclasses supply the
      # comparison (+passed_by?+) and a +description+.
      class Test
        attr_reader :feature, :value

        def initialize(feature, value)
          @feature = feature
          @value = value
        end

        # Whether +_instance+ satisfies this test. Must be overridden.
        def passed_by?(_instance)
          raise NotImplementedError
        end

        # Human-readable wording of this test. Must be overridden.
        def description
          raise NotImplementedError
        end

        # Hash code derived from feature and value, in that order.
        #
        # Hashing the ordered pair avoids the weaknesses of XOR-ing the two
        # hashes: XOR gives the same code for (a, b) and (b, a), and gives 0
        # whenever feature and value are the same object. Pairs whose members
        # are eql? still hash alike.
        def hash
          [feature, value].hash
        end
      end
    end
  end
end
@@ -0,0 +1,46 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "../utilities/utils"
4
+
5
module HybridForest
  module Trees
    # A decision tree whose structure is produced by an injected tree grower.
    class Tree
      # Creates a new Tree that will be grown by +tree_grower+, an object
      # answering +grow_tree+.
      def initialize(tree_grower:)
        @tree_grower = tree_grower
      end

      ##
      # Converts +instances+ to a data frame, grows the tree from it and
      # returns +self+.
      #
      def fit(instances)
        training_data = HybridForest::Utils.to_dataframe(instances)
        @root = @tree_grower.grow_tree(training_data)
        self
      end

      ##
      # Returns an array holding one predicted label per row of +instances+.
      # Raises Errors::InvalidStateError when the tree has not been fitted.
      #
      def predict(instances)
        if @root.nil?
          raise Errors::InvalidStateError,
            "You must call #fit before you call #predict"
        end

        rows = HybridForest::Utils.to_dataframe(instances).each_row
        rows.map { |instance| @root.classify(instance) }
      end

      # For an unfitted tree, returns "Empty tree: " followed by the default
      # inspect string. Otherwise asks the root node to print itself and
      # returns whatever that call returns.
      def inspect
        return "Empty tree: #{super}" if @root.nil?

        @root.print_string
      end
    end
  end
end
@@ -0,0 +1,76 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "../../utilities/utils"
4
+ require_relative "../split"
5
+ require_relative "../feature_selectors/random_feature_subspace"
6
+ require_relative "../feature_selectors/all_features"
7
+ require_relative "../feature_selectors/max_one_split_per_feature"
8
+ require_relative "../impurity_metrics/gini_impurity"
9
+ require_relative "../impurity_metrics/entropy"
10
+ require_relative "../nodes/binary_node"
11
+ require_relative "../nodes/leaf_node"
12
+ require_relative "../tests/equal_or_greater"
13
+
14
module HybridForest
  module Trees
    module TreeGrowers
      # Grows a tree of BinaryNode and LeafNode objects by recursively
      # applying the candidate split with the highest information gain.
      class CARTGrower
        # +feature_selector+ decides which features are considered at each
        # node; +impurity_metric+ scores the candidate splits.
        def initialize(feature_selector: RandomFeatureSubspace.new, impurity_metric: GiniImpurity.new)
          @feature_selector = feature_selector
          @impurity_metric = impurity_metric
        end

        # Returns the root node of a tree grown from +instances+: a LeafNode
        # when the best split found has an information gain of 0, otherwise a
        # BinaryNode whose subtrees are grown from the split's subsets.
        def grow_tree(instances)
          best = find_best_split(instances)
          return LeafNode.new(instances) if best.info_gain == 0

          branch(best.subsets, best.feature, best.value)
        end

        private

        # Grows one subtree per subset (true side first) and joins them under
        # a node testing +feature+ >= +value+.
        #
        # NOTE(review): the node test is always EqualOrGreater, even when
        # find_best_split produced the subsets with +equal_split+ for a
        # non-numeric feature -- confirm this is intended.
        def branch(subsets, feature, value)
          passing, failing = subsets
          true_branch = grow_tree(passing)
          false_branch = grow_tree(failing)
          BinaryNode.new(Tests::EqualOrGreater.new(feature, value), true_branch, false_branch)
        end

        # Tries every distinct value of every considered feature as a split
        # point and returns the split with the highest information gain.
        # Candidates that leave any subset empty are skipped. The default
        # split (gain 0) is returned when nothing beats it.
        def find_best_split(instances)
          candidates = @feature_selector.select_features(instances.features)
          baseline = @impurity_metric.compute(instances)
          leader = default_split(instances, candidates)

          candidates.each do |feature|
            instances[feature].uniq.each do |value|
              subsets = split_on(instances, feature, value)
              next unless subsets.none? { |subset| subset.count == 0 }

              gain = @impurity_metric.information_gain(subsets, baseline)
              challenger = Split.new(feature, info_gain: gain, subsets: subsets, value: value)
              leader = challenger if challenger.better_than?(leader)
            end
          end

          leader
        end

        # Partitions +instances+ on +feature+ at +value+: with
        # +equal_or_greater_split+ for a numeric column, with +equal_split+
        # otherwise.
        def split_on(instances, feature, value)
          if instances[feature].numeric?
            instances.equal_or_greater_split(feature, value)
          else
            instances.equal_split(feature, value)
          end
        end

        # Zero-gain placeholder split on the first considered feature, using
        # the first value of that feature's column.
        def default_split(instances, features)
          fallback_feature = features.first
          Split.new(fallback_feature, value: instances[fallback_feature].first)
        end
      end
    end
  end
end