hybridforest 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +89 -0
- data/.rspec +3 -0
- data/.rubocop.yml +13 -0
- data/CHANGELOG.md +33 -0
- data/Gemfile +22 -0
- data/Gemfile.lock +92 -0
- data/LICENSE.txt +21 -0
- data/README.md +96 -0
- data/Rakefile +12 -0
- data/bin/console +15 -0
- data/bin/setup +8 -0
- data/hybridforest.gemspec +44 -0
- data/lib/hybridforest/errors/invalid_state_error.rb +9 -0
- data/lib/hybridforest/forests/forest_growers/cart_grower.rb +19 -0
- data/lib/hybridforest/forests/forest_growers/hybrid_grower.rb +46 -0
- data/lib/hybridforest/forests/forest_growers/id3_grower.rb +19 -0
- data/lib/hybridforest/forests/grower_factory.rb +29 -0
- data/lib/hybridforest/forests/random_forest.rb +84 -0
- data/lib/hybridforest/trees/cart_tree.rb +18 -0
- data/lib/hybridforest/trees/feature_selectors/all_features.rb +14 -0
- data/lib/hybridforest/trees/feature_selectors/max_one_split_per_feature.rb +21 -0
- data/lib/hybridforest/trees/feature_selectors/random_feature_subspace.rb +27 -0
- data/lib/hybridforest/trees/id3_tree.rb +18 -0
- data/lib/hybridforest/trees/impurity_metrics/entropy.rb +21 -0
- data/lib/hybridforest/trees/impurity_metrics/gini_impurity.rb +17 -0
- data/lib/hybridforest/trees/impurity_metrics/impurity.rb +28 -0
- data/lib/hybridforest/trees/nodes/binary_node.rb +37 -0
- data/lib/hybridforest/trees/nodes/leaf_node.rb +31 -0
- data/lib/hybridforest/trees/nodes/multiway_node.rb +45 -0
- data/lib/hybridforest/trees/split.rb +30 -0
- data/lib/hybridforest/trees/tests/equal.rb +37 -0
- data/lib/hybridforest/trees/tests/equal_or_greater.rb +37 -0
- data/lib/hybridforest/trees/tests/less.rb +37 -0
- data/lib/hybridforest/trees/tests/not_equal.rb +37 -0
- data/lib/hybridforest/trees/tests/test.rb +28 -0
- data/lib/hybridforest/trees/tree.rb +46 -0
- data/lib/hybridforest/trees/tree_growers/cart_grower.rb +76 -0
- data/lib/hybridforest/trees/tree_growers/id3_grower.rb +161 -0
- data/lib/hybridforest/utilities/utils.rb +206 -0
- data/lib/hybridforest/version.rb +5 -0
- data/lib/hybridforest.rb +46 -0
- metadata +285 -0
@@ -0,0 +1,18 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "tree"
|
4
|
+
require_relative "tree_growers/cart_grower"
|
5
|
+
|
6
|
+
module HybridForest
  module Trees
    # Decision tree whose structure is built by the CART growing algorithm
    # unless a different grower is injected.
    class CARTTree < Tree
      # @param tree_grower [#grow_tree] growing strategy; a fresh CARTGrower by default
      def initialize(tree_grower: TreeGrowers::CARTGrower.new)
        # Implicit super forwards the (possibly defaulted) tree_grower keyword.
        super
      end

      # @return [String] human-readable name of this tree type
      def name
        "CART Tree"
      end
    end
  end
end
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module HybridForest
  module Trees
    # Feature selector that offers every feature until it has been reported
    # through #update; reported features are never offered again.
    class MaxOneSplitPerFeature
      # @return [Array] features reported via #update, in reporting order
      attr_reader :exhausted_features

      def initialize
        @exhausted_features = []
      end

      # @param all_features [Array] candidate features
      # @return [Array] the candidates that have not been exhausted yet
      def select_features(all_features)
        all_features - exhausted_features
      end

      # Marks +feature+ as used so later calls to #select_features skip it.
      # @return [Array] the exhausted feature list
      def update(feature)
        exhausted_features.push(feature)
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "set"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    # Feature selector implementing the random subspace method: each call
    # draws a fresh random subset of the available features, without repetition.
    class RandomFeatureSubspace
      # @param subspace_size [Integer, nil] number of features to draw per call;
      #   when nil, half the available features (rounded) are drawn
      # @param random [#rand] generator handed to Array#sample; the default
      #   (the Random class) uses the process-wide generator, so Kernel#srand
      #   still controls reproducibility
      def initialize(subspace_size: nil, random: Random)
        @subspace_size = subspace_size
        @random = random
      end

      # @param all_features [Array] candidate features
      # @return [Array] a random selection of distinct elements of +all_features+
      #   (all of them, shuffled, if the requested size exceeds the count)
      def select_features(all_features)
        size = @subspace_size || default_subspace_size(all_features.count)
        # Array#sample draws without replacement in one pass, replacing the
        # previous rejection-sampling loop that retried duplicate indices.
        all_features.sample(size, random: @random)
      end

      # No-op: this selector keeps no state between splits.
      def update(_feature); end

      private

      # Half of +num_of_features+, rounded half away from zero (e.g. 1 -> 1, 3 -> 2).
      def default_subspace_size(num_of_features)
        (num_of_features.to_f / 2).round
      end
    end
  end
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "tree"
|
4
|
+
require_relative "tree_growers/id3_grower"
|
5
|
+
|
6
|
+
module HybridForest
  module Trees
    # Decision tree whose structure is built by the ID3 growing algorithm
    # unless a different grower is injected.
    class ID3Tree < Tree
      # @param tree_grower [#grow_tree] growing strategy; a fresh ID3Grower by default
      def initialize(tree_grower: TreeGrowers::ID3Grower.new)
        # Implicit super forwards the (possibly defaulted) tree_grower keyword.
        super
      end

      # @return [String] human-readable name of this tree type
      def name
        "ID3 Tree"
      end
    end
  end
end
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "impurity"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    # Shannon entropy (base 2) of a dataset's label distribution.
    class Entropy
      include HybridForest::Trees::Impurity

      # @param instances dataset responding to +count+ and +count_labels+
      #   (assumed: count_labels returns a Hash of label => occurrences -- confirm in Utils)
      # @return [Numeric] sum over labels of p * log2(1 / p); 0 when there are no labels
      def compute(instances)
        total = instances.count
        occurrences_per_label = instances.count_labels

        occurrences_per_label.values.sum do |occurrences|
          probability = occurrences.to_f / total
          surprise = Math.log2(1 / probability)
          probability * surprise
        end
      end
    end
  end
end
|
@@ -0,0 +1,17 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "impurity"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    # Gini impurity of a dataset's label distribution: 1 - sum over labels of p^2.
    class GiniImpurity
      include HybridForest::Trees::Impurity

      # @param instances dataset responding to +count+ and +count_labels+
      #   (assumed: count_labels returns a Hash of label => occurrences -- confirm in Utils)
      # @return [Numeric] impurity in [0, 1); 0.0 for an empty dataset
      def compute(instances)
        total_count = instances.count.to_f
        # An empty set is treated as pure, matching Entropy (which sums to 0 here).
        # Without this guard the formula yields 1 - 0 = 1, a value no non-empty set can reach.
        return 0.0 if total_count.zero?

        label_counts = instances.count_labels.each_value
        1 - label_counts.sum { |count| (count / total_count)**2 }
      end
    end
  end
end
|
@@ -0,0 +1,28 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "active_support"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    # Mixin that derives information gain from an impurity measure.
    # Including classes must implement #compute(instances).
    module Impurity
      # Information gain achieved by splitting a parent into +children+.
      #
      # @param children [Array, nil] child datasets, each responding to +count+
      # @param parent_impurity [Numeric] impurity of the unsplit parent
      # @return [Numeric] +parent_impurity+ minus the count-weighted impurity of the
      #   children; 0.0 when there are no children or the children hold no instances
      def information_gain(children, parent_impurity)
        # Plain nil/empty check: no dependency on ActiveSupport's `blank?` core extension.
        return 0.0 if children.nil? || children.empty?

        parent_count = children.sum(&:count)
        # With no instances at all every weight would be 0/0 (NaN); a NaN gain makes
        # Split#<=> return nil and Comparable#> raise during split selection.
        return 0.0 if parent_count.zero?

        children_impurity = children.sum { |child| weighted_impurity(child, parent_count) }
        parent_impurity - children_impurity
      end

      # Impurity of +instances_in_child+ scaled by its share of +parent_count+.
      def weighted_impurity(instances_in_child, parent_count)
        weight = instances_in_child.count.to_f / parent_count
        compute(instances_in_child) * weight
      end

      # @raise [NotImplementedError] unless overridden by the including class
      def compute(instances)
        raise NotImplementedError, "Must be implemented by including classes"
      end
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module HybridForest
|
4
|
+
module Trees
|
5
|
+
class BinaryNode
|
6
|
+
attr_reader :test, :true_branch, :false_branch
|
7
|
+
|
8
|
+
def initialize(test, true_branch, false_branch)
|
9
|
+
@test = test
|
10
|
+
@true_branch = true_branch
|
11
|
+
@false_branch = false_branch
|
12
|
+
end
|
13
|
+
|
14
|
+
def branch_for(instance)
|
15
|
+
if @test.passed_by? instance
|
16
|
+
@true_branch
|
17
|
+
else
|
18
|
+
@false_branch
|
19
|
+
end
|
20
|
+
end
|
21
|
+
|
22
|
+
def classify(instance)
|
23
|
+
branch_for(instance).classify(instance)
|
24
|
+
end
|
25
|
+
|
26
|
+
def print_string(spacing = "")
|
27
|
+
print spacing
|
28
|
+
puts "#{@test} True"
|
29
|
+
@true_branch.print_string(spacing + " ")
|
30
|
+
|
31
|
+
print spacing
|
32
|
+
puts "#{@test} False"
|
33
|
+
@false_branch.print_string(spacing + " ")
|
34
|
+
end
|
35
|
+
end
|
36
|
+
end
|
37
|
+
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module HybridForest
  module Trees
    # Terminal tree node that predicts the most frequent label among the
    # training instances it was built from.
    class LeafNode
      # @param instances dataset exposing its label column via instances[instances.label]
      def initialize(instances)
        @prediction = majority_vote(instances)
      end

      # @return the stored prediction; the instance itself is not inspected
      def classify(_instance)
        @prediction
      end

      # Prints this leaf to stdout at the given indentation.
      def print_string(spacing = "")
        print spacing
        puts to_s
      end

      def to_s
        "Predict #{@prediction}"
      end

      private

      # Most frequent label (ties go to the label seen first); nil when there are no instances.
      # Counts in a single pass; the previous version re-counted the whole column for
      # every element, which is O(n^2) in the number of instances.
      def majority_vote(instances)
        labels = instances[instances.label].to_enum
        label_counts = labels.each_with_object(Hash.new(0)) { |label, counts| counts[label] += 1 }
        label_counts.max_by { |_label, count| count }&.first
      end
    end
  end
end
|
@@ -0,0 +1,45 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module HybridForest
  module Trees
    # Internal tree node with one outgoing path per test. Instances that pass
    # no test are classified with the majority label of the node's training data.
    class MultiwayNode
      attr_reader :paths, :majority_class

      # @param paths [Hash] test => subtree; tests are tried in insertion order
      # @param instances dataset used to compute the fallback majority label
      def initialize(paths, instances)
        @paths = paths
        @majority_class = majority_vote(instances)
      end

      def tests
        @paths.keys
      end

      def branches
        @paths.values
      end

      # @return the subtree of the first test +instance+ passes, or nil if none passes
      def branch_for(instance)
        match = tests.find { |test| test.passed_by? instance }
        paths[match]
      end

      # Classifies via the matching subtree, falling back to the majority label
      # when no subtree matches or the subtree produced no prediction (nil).
      def classify(instance)
        prediction = branch_for(instance)&.classify(instance)
        # Explicit nil check: `|| @majority_class` would also discard a legitimate `false` label.
        prediction.nil? ? @majority_class : prediction
      end

      # Prints every path of this subtree to stdout at the given indentation.
      def print_string(spacing = "")
        paths.each do |test, branch|
          print spacing
          puts "#{test} True"
          branch.print_string(spacing + " ")
        end
      end

      # Most frequent label (ties go to the label seen first); nil when there are no instances.
      # Single-pass count instead of re-counting the column for every element (O(n^2)).
      def majority_vote(instances)
        labels = instances[instances.label].to_enum
        label_counts = labels.each_with_object(Hash.new(0)) { |label, counts| counts[label] += 1 }
        label_counts.max_by { |_label, count| count }&.first
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module HybridForest
  module Trees
    # Candidate split of a dataset on +feature+, ordered by information gain.
    class Split
      include Comparable

      attr_reader :feature, :value, :subsets, :info_gain

      alias_method :better_than?, :>
      alias_method :worse_than?, :<

      # @param feature the column this split is on
      # @param info_gain [Numeric] gain achieved by the split (default 0)
      # @param subsets [Array] partitions of the instances produced by the split
      # @param value split value for a binary split; nil marks a multiway split
      def initialize(feature, info_gain: 0, subsets: [Rover::DataFrame.new, Rover::DataFrame.new], value: nil)
        @feature = feature
        @value = value
        @subsets = subsets
        @info_gain = info_gain
      end

      # Orders splits by information gain. Returns nil (incomparable) for objects
      # without an info_gain, so `split == other_object` is false instead of raising.
      def <=>(other)
        return nil unless other.respond_to?(:info_gain)

        info_gain <=> other.info_gain
      end

      # A split is binary whenever it carries a value -- including `false`,
      # which a plain truthiness check would misreport as multiway.
      def binary?
        !value.nil?
      end

      def multiway?
        !binary?
      end
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "test"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    module Tests
      # Test satisfied when an instance's +feature+ equals +value+.
      # Construction (feature, value) is inherited unchanged from Test.
      class Equal < Test
        # @param instance [#[]] row looked up by feature
        def passed_by?(instance)
          instance[feature] == value
        end

        # Same test type with == feature and value.
        def ==(other)
          other.is_a?(Equal) && feature == other.feature && value == other.value
        end

        # Strict variant (eql? on feature and value), as used for Hash keys.
        def eql?(other)
          other.is_a?(Equal) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} == #{value}?"
        end

        def description
          "#{feature} equal to #{value}?"
        end
      end
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "test"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    module Tests
      # Threshold test satisfied when an instance's +feature+ is >= +value+.
      # Construction (feature, value) is inherited unchanged from Test.
      class EqualOrGreater < Test
        # @param instance [#[]] row looked up by feature
        def passed_by?(instance)
          instance[feature] >= value
        end

        # Same test type with == feature and threshold.
        def ==(other)
          other.is_a?(EqualOrGreater) && feature == other.feature && value == other.value
        end

        # Strict variant (eql? on feature and threshold), as used for Hash keys.
        def eql?(other)
          other.is_a?(EqualOrGreater) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} >= #{value}?"
        end

        def description
          "#{feature} equal to or greater than #{value}?"
        end
      end
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "test"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    module Tests
      # Threshold test satisfied when an instance's +feature+ is < +value+.
      # Construction (feature, value) is inherited unchanged from Test.
      class Less < Test
        # @param instance [#[]] row looked up by feature
        def passed_by?(instance)
          instance[feature] < value
        end

        # Same test type with == feature and threshold.
        def ==(other)
          other.is_a?(Less) && feature == other.feature && value == other.value
        end

        # Strict variant (eql? on feature and threshold), as used for Hash keys.
        def eql?(other)
          other.is_a?(Less) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} < #{value}?"
        end

        def description
          "#{feature} less than #{value}?"
        end
      end
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "test"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    module Tests
      # Test satisfied when an instance's +feature+ differs from +value+.
      # Construction (feature, value) is inherited unchanged from Test.
      class NotEqual < Test
        # @param instance [#[]] row looked up by feature
        def passed_by?(instance)
          instance[feature] != value
        end

        # Same test type with == feature and value.
        def ==(other)
          other.is_a?(NotEqual) && feature == other.feature && value == other.value
        end

        # Strict variant (eql? on feature and value), as used for Hash keys.
        def eql?(other)
          other.is_a?(NotEqual) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} != #{value}?"
        end

        def description
          "#{feature} not equal to #{value}?"
        end
      end
    end
  end
end
|
@@ -0,0 +1,28 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module HybridForest
  module Trees
    module Tests
      # Abstract base for node tests that compare one feature of an instance
      # with a fixed value. Subclasses must implement #passed_by? and #description.
      class Test
        attr_reader :feature, :value

        def initialize(feature, value)
          @feature = feature
          @value = value
        end

        # @raise [NotImplementedError] unless overridden
        def passed_by?(_instance)
          raise NotImplementedError
        end

        # @raise [NotImplementedError] unless overridden
        def description
          raise NotImplementedError
        end

        # Hash over (feature, value). Array#hash is order-sensitive and never collapses
        # to 0, unlike the previous XOR, which made (a, b) collide with (b, a) and
        # hashed every test whose feature equals its value to 0. Objects whose feature
        # and value are pairwise eql? still share a hash.
        def hash
          [feature, value].hash
        end
      end
    end
  end
end
|
@@ -0,0 +1,46 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "../utilities/utils"
|
4
|
+
|
5
|
+
require "stringio"

module HybridForest
  module Trees
    # Base class for decision trees. Building the node structure is delegated
    # to a pluggable tree grower.
    class Tree
      # Creates a new Tree using the specified tree growing algorithm.
      # @param tree_grower [#grow_tree] builds the root node from a dataset
      def initialize(tree_grower:)
        @tree_grower = tree_grower
      end

      ##
      # Fits a model to the given dataset +instances+ and returns +self+.
      #
      def fit(instances)
        instances = HybridForest::Utils.to_dataframe(instances)
        @root = @tree_grower.grow_tree(instances)
        self
      end

      ##
      # Predicts a label for each instance in the dataset +instances+ and returns an array of labels.
      #
      # @raise [Errors::InvalidStateError] if called before #fit
      def predict(instances)
        if @root.nil?
          raise Errors::InvalidStateError,
                "You must call #fit before you call #predict"
        end

        HybridForest::Utils.to_dataframe(instances).each_row.map { |instance| @root.classify(instance) }
      end

      # Returns a String representation of this Tree: the listing written by the
      # root's #print_string, or "Empty tree: ..." before #fit has been called.
      # Previously the nodes' print output went to stdout and inspect returned a
      # non-String (the return value of print_string).
      def inspect
        return "Empty tree: #{super}" if @root.nil?

        capture_stdout { @root.print_string }
      end

      private

      # Runs the block with $stdout redirected to a buffer and returns what was written.
      # NOTE: swaps the global $stdout, so output from other threads during the
      # block is captured as well.
      def capture_stdout
        original_stdout = $stdout
        buffer = StringIO.new
        $stdout = buffer
        yield
        buffer.string
      ensure
        $stdout = original_stdout
      end
    end
  end
end
|
@@ -0,0 +1,76 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "../../utilities/utils"
require_relative "../split"
require_relative "../feature_selectors/random_feature_subspace"
require_relative "../feature_selectors/all_features"
require_relative "../feature_selectors/max_one_split_per_feature"
require_relative "../impurity_metrics/gini_impurity"
require_relative "../impurity_metrics/entropy"
require_relative "../nodes/binary_node"
require_relative "../nodes/leaf_node"
require_relative "../tests/equal"
require_relative "../tests/equal_or_greater"
|
13
|
+
|
14
|
+
module HybridForest
  module Trees
    module TreeGrowers
      # Grows binary decision trees CART-style. At each node the feature/value
      # pair with the highest information gain splits the instances in two.
      # Growth stops with a LeafNode when no split improves impurity.
      class CARTGrower
        # @param feature_selector [#select_features] chooses the features considered at each node
        # @param impurity_metric [#compute, #information_gain] scores candidate splits
        def initialize(feature_selector: RandomFeatureSubspace.new, impurity_metric: GiniImpurity.new)
          @impurity_metric = impurity_metric
          @feature_selector = feature_selector
        end

        # Recursively grows a (sub)tree for +instances+ and returns its root node.
        def grow_tree(instances)
          split = find_best_split(instances)
          if split.info_gain.zero?
            LeafNode.new(instances)
          else
            branch(instances, split)
          end
        end

        private

        # Builds an internal node for +split+, growing the true subset before the false one.
        def branch(instances, split)
          true_instances, false_instances = split.subsets
          true_branch = grow_tree(true_instances)
          false_branch = grow_tree(false_instances)
          test = split_test(instances, split.feature, split.value)
          BinaryNode.new(test, true_branch, false_branch)
        end

        # Node test that routes instances at prediction time the same way the training
        # partition did: >= for numeric features, == otherwise. Previously every node
        # used >=, so categorical features were routed differently from how they were trained.
        # NOTE(review): assumes equal_split returns [matching, non-matching], in the same
        # order equal_or_greater_split returns [passing, failing] -- confirm in Utils.
        def split_test(instances, feature, value)
          if numeric_feature?(instances, feature)
            Tests::EqualOrGreater.new(feature, value)
          else
            Tests::Equal.new(feature, value)
          end
        end

        # Returns the highest-gain Split over the considered features (the first one found
        # wins ties), or the zero-gain default split if nothing improves impurity.
        def find_best_split(instances)
          considered_features = @feature_selector.select_features(instances.features)
          current_impurity = @impurity_metric.compute(instances)
          best_split = default_split(instances, considered_features)

          considered_features.each do |feature|
            numeric = numeric_feature?(instances, feature) # invariant across this feature's values
            instances[feature].uniq.each do |value|
              subsets = partition(instances, feature, value, numeric)
              # A split that leaves one side empty separates nothing.
              next if subsets.any? { |set| set.count.zero? }

              info_gain = @impurity_metric.information_gain(subsets, current_impurity)
              split = Split.new(feature, info_gain: info_gain, subsets: subsets, value: value)
              best_split = split if split.better_than?(best_split)
            end
          end

          best_split
        end

        # Partitions +instances+ on +feature+ at +value+: a threshold split for numeric
        # features, an equality split otherwise.
        def partition(instances, feature, value, numeric)
          if numeric
            instances.equal_or_greater_split(feature, value)
          else
            instances.equal_split(feature, value)
          end
        end

        def numeric_feature?(instances, feature)
          instances[feature].numeric?
        end

        # Zero-gain baseline split. When no features are considered it carries no feature
        # or value, instead of indexing the dataset with a nil column.
        def default_split(instances, features)
          first_feature = features.first
          value = first_feature.nil? ? nil : instances[first_feature].first
          Split.new(first_feature, value: value)
        end
      end
    end
  end
end
|