hybridforest 0.9.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +89 -0
- data/.rspec +3 -0
- data/.rubocop.yml +13 -0
- data/CHANGELOG.md +33 -0
- data/Gemfile +22 -0
- data/Gemfile.lock +92 -0
- data/LICENSE.txt +21 -0
- data/README.md +96 -0
- data/Rakefile +12 -0
- data/bin/console +15 -0
- data/bin/setup +8 -0
- data/hybridforest.gemspec +44 -0
- data/lib/hybridforest/errors/invalid_state_error.rb +9 -0
- data/lib/hybridforest/forests/forest_growers/cart_grower.rb +19 -0
- data/lib/hybridforest/forests/forest_growers/hybrid_grower.rb +46 -0
- data/lib/hybridforest/forests/forest_growers/id3_grower.rb +19 -0
- data/lib/hybridforest/forests/grower_factory.rb +29 -0
- data/lib/hybridforest/forests/random_forest.rb +84 -0
- data/lib/hybridforest/trees/cart_tree.rb +18 -0
- data/lib/hybridforest/trees/feature_selectors/all_features.rb +14 -0
- data/lib/hybridforest/trees/feature_selectors/max_one_split_per_feature.rb +21 -0
- data/lib/hybridforest/trees/feature_selectors/random_feature_subspace.rb +27 -0
- data/lib/hybridforest/trees/id3_tree.rb +18 -0
- data/lib/hybridforest/trees/impurity_metrics/entropy.rb +21 -0
- data/lib/hybridforest/trees/impurity_metrics/gini_impurity.rb +17 -0
- data/lib/hybridforest/trees/impurity_metrics/impurity.rb +28 -0
- data/lib/hybridforest/trees/nodes/binary_node.rb +37 -0
- data/lib/hybridforest/trees/nodes/leaf_node.rb +31 -0
- data/lib/hybridforest/trees/nodes/multiway_node.rb +45 -0
- data/lib/hybridforest/trees/split.rb +30 -0
- data/lib/hybridforest/trees/tests/equal.rb +37 -0
- data/lib/hybridforest/trees/tests/equal_or_greater.rb +37 -0
- data/lib/hybridforest/trees/tests/less.rb +37 -0
- data/lib/hybridforest/trees/tests/not_equal.rb +37 -0
- data/lib/hybridforest/trees/tests/test.rb +28 -0
- data/lib/hybridforest/trees/tree.rb +46 -0
- data/lib/hybridforest/trees/tree_growers/cart_grower.rb +76 -0
- data/lib/hybridforest/trees/tree_growers/id3_grower.rb +161 -0
- data/lib/hybridforest/utilities/utils.rb +206 -0
- data/lib/hybridforest/version.rb +5 -0
- data/lib/hybridforest.rb +46 -0
- metadata +285 -0
@@ -0,0 +1,18 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "tree"
|
4
|
+
require_relative "tree_growers/cart_grower"
|
5
|
+
|
6
|
+
module HybridForest
  module Trees
    # A Tree that uses a CARTGrower to build itself unless another grower is
    # injected.
    class CARTTree < Tree
      # tree_grower:: the object used to grow the tree; a fresh CARTGrower is
      #               created per instance when none is supplied.
      def initialize(tree_grower: TreeGrowers::CARTGrower.new)
        # Bare super forwards the tree_grower keyword unchanged.
        super
      end

      # Returns the display name of this tree type.
      def name
        "CART Tree"
      end
    end
  end
end
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module HybridForest
  module Trees
    # Feature selector that remembers every feature reported through #update
    # and leaves those features out of later selections.
    class MaxOneSplitPerFeature
      attr_reader :exhausted_features

      def initialize
        @exhausted_features = []
      end

      # Returns the entries of +all_features+ that have not been recorded as
      # exhausted.
      def select_features(all_features)
        remaining = all_features - exhausted_features
        remaining
      end

      # Records +feature+ as exhausted and returns the list of exhausted
      # features.
      def update(feature)
        exhausted_features.push(feature)
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "set"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    # Feature selector that returns a random subset of the available features,
    # drawn without replacement.
    class RandomFeatureSubspace
      # Returns a random selection of entries of +all_features+ taken from
      # distinct positions. The selection size is half the number of features,
      # rounded to the nearest integer. An empty list yields an empty list.
      def select_features(all_features)
        # Array#sample(n) picks n distinct positions at random, replacing the
        # manual "add random indices to a Set until it is full" loop.
        all_features.sample(default_subspace_size(all_features.count))
      end

      # No-op: this selector keeps no state between splits.
      def update(_feature)
      end

      private

      # Half of +num_of_features+, rounded to the nearest integer.
      def default_subspace_size(num_of_features)
        (num_of_features.to_f / 2).round
      end
    end
  end
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "tree"
|
4
|
+
require_relative "tree_growers/id3_grower"
|
5
|
+
|
6
|
+
module HybridForest
  module Trees
    # A Tree that uses an ID3Grower to build itself unless another grower is
    # injected.
    class ID3Tree < Tree
      # tree_grower:: the object used to grow the tree; a fresh ID3Grower is
      #               created per instance when none is supplied.
      def initialize(tree_grower: TreeGrowers::ID3Grower.new)
        # Bare super forwards the tree_grower keyword unchanged.
        super
      end

      # Returns the display name of this tree type.
      def name
        "ID3 Tree"
      end
    end
  end
end
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "impurity"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    # Impurity metric computing the Shannon entropy (base 2) of a label
    # distribution.
    class Entropy
      include HybridForest::Trees::Impurity

      # Returns the sum, over every label count reported by
      # +instances.count_labels+, of p * log2(1 / p), where p is that count
      # divided by +instances.count+.
      def compute(instances)
        total = instances.count
        instances.count_labels.values.sum do |occurrences|
          probability = occurrences.to_f / total
          probability * Math.log2(1 / probability)
        end
      end
    end
  end
end
|
@@ -0,0 +1,17 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "impurity"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    # Impurity metric computing the Gini impurity of a label distribution.
    class GiniImpurity
      include HybridForest::Trees::Impurity

      # Returns 1 minus the sum of squared label shares, where a share is a
      # label count from +instances.count_labels+ divided by +instances.count+.
      def compute(instances)
        size = instances.count.to_f
        squared_shares = instances.count_labels.each_value.sum do |occurrences|
          (occurrences / size)**2
        end
        1 - squared_shares
      end
    end
  end
end
|
@@ -0,0 +1,28 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "active_support"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    # Mixin providing information-gain arithmetic on top of an impurity
    # measure. Including classes must implement #compute.
    module Impurity
      # Returns +parent_impurity+ minus the count-weighted impurity of
      # +children+.
      #
      # children::        collection of subsets, each responding to +count+
      # parent_impurity:: impurity of the set the children were split from
      #
      # Returns 0.0 when +children+ is nil or empty, or when the children
      # hold no instances at all.
      def information_gain(children, parent_impurity)
        # Plain nil/empty check: avoids depending on ActiveSupport's #blank?.
        return 0.0 if children.nil? || children.empty?

        parent_count = children.sum(&:count)
        # Without this guard every weight would be 0.0 / 0, i.e. NaN.
        return 0.0 if parent_count.zero?

        children_impurity = children.sum do |child|
          weighted_impurity(child, parent_count)
        end
        parent_impurity - children_impurity
      end

      # Returns the impurity of +instances_in_child+ scaled by the child's
      # share of +parent_count+.
      def weighted_impurity(instances_in_child, parent_count)
        weight = instances_in_child.count.to_f / parent_count
        compute(instances_in_child) * weight
      end

      # Returns the impurity of +instances+.
      # Raises NotImplementedError unless overridden by the including class.
      def compute(instances)
        raise NotImplementedError, "Must be implemented by including classes"
      end
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module HybridForest
  module Trees
    # Internal tree node holding one test and two subtrees: one followed when
    # an instance passes the test, the other when it does not.
    class BinaryNode
      attr_reader :test, :true_branch, :false_branch

      def initialize(test, true_branch, false_branch)
        @test = test
        @true_branch = true_branch
        @false_branch = false_branch
      end

      # Returns the subtree selected by applying the test to +instance+.
      def branch_for(instance)
        @test.passed_by?(instance) ? @true_branch : @false_branch
      end

      # Delegates classification of +instance+ to the selected subtree.
      def classify(instance)
        subtree = branch_for(instance)
        subtree.classify(instance)
      end

      # Writes this node and both subtrees to standard output, indenting each
      # level further. Returns whatever the false branch's #print_string
      # returns.
      def print_string(spacing = "")
        print spacing
        puts "#{@test} True"
        @true_branch.print_string(spacing + " ")

        print spacing
        puts "#{@test} False"
        @false_branch.print_string(spacing + " ")
      end
    end
  end
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module HybridForest
  module Trees
    # Terminal tree node that always predicts the most frequent label among
    # the instances it was built from.
    class LeafNode
      # instances:: dataset responding to +label+ (the key of the label
      #             column) and to +[]+, which must return an object that
      #             supports +each+.
      def initialize(instances)
        @prediction = majority_vote(instances)
      end

      # Returns the stored prediction regardless of the instance given.
      def classify(_instance)
        @prediction
      end

      # Writes this leaf to standard output, prefixed by +spacing+.
      def print_string(spacing = "")
        print spacing
        puts to_s
      end

      def to_s
        "Predict #{@prediction}"
      end

      private

      # Returns the most frequent label in +instances+, or nil when there are
      # no labels. Labels are counted in a single pass instead of re-counting
      # the whole column for every element. On a tie the label that appears
      # first wins.
      def majority_vote(instances)
        counts = instances[instances.label].to_enum.each_with_object(Hash.new(0)) do |label, tally|
          tally[label] += 1
        end
        winner = counts.max_by { |_label, count| count }
        winner&.first
      end
    end
  end
end
|
@@ -0,0 +1,45 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module HybridForest
  module Trees
    # Internal tree node with any number of outgoing paths, each guarded by a
    # test. It also stores the majority label of the instances it was built
    # from, used as a fallback prediction.
    class MultiwayNode
      attr_reader :paths, :majority_class

      # paths::     hash mapping test objects to subtrees
      # instances:: dataset responding to +label+ and +[]+ (see #majority_vote)
      def initialize(paths, instances)
        @paths = paths
        @majority_class = majority_vote(instances)
      end

      # Returns the tests guarding the outgoing paths.
      def tests
        @paths.keys
      end

      # Returns the subtrees reachable from this node.
      def branches
        @paths.values
      end

      # Returns the subtree behind the first test passed by +instance+, or
      # whatever +paths[nil]+ yields (normally nil) when no test is passed.
      def branch_for(instance)
        match = tests.find { |test| test.passed_by? instance }
        paths[match]
      end

      # Classifies +instance+ through the matching subtree. Falls back to the
      # majority class when there is no matching subtree or the subtree
      # yields nil.
      def classify(instance)
        prediction = branch_for(instance)&.classify(instance)
        # Explicit nil check: a legitimate +false+ label must not trigger the
        # fallback.
        prediction.nil? ? @majority_class : prediction
      end

      # Writes every path and its subtree to standard output, indenting each
      # level further.
      def print_string(spacing = "")
        paths.each do |test, branch|
          print spacing
          puts "#{test} True"
          branch.print_string(spacing + " ")
        end
      end

      # Returns the most frequent label in +instances+, or nil when there are
      # no labels. Labels are counted in a single pass. On a tie the label
      # that appears first wins.
      def majority_vote(instances)
        counts = instances[instances.label].to_enum.each_with_object(Hash.new(0)) do |label, tally|
          tally[label] += 1
        end
        winner = counts.max_by { |_label, count| count }
        winner&.first
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module HybridForest
  module Trees
    # Value object describing a candidate split of a dataset. Splits are
    # ordered by their information gain.
    class Split
      include Comparable

      attr_reader :feature, :value, :subsets, :info_gain
      alias_method :better_than?, :>
      alias_method :worse_than?, :<

      # feature::   the feature the split is made on
      # info_gain:: information gain achieved by the split
      # subsets::   the partitions produced by the split
      # value::     threshold/category of a binary split; nil for a multiway
      #             split
      def initialize(feature, info_gain: 0, subsets: [Rover::DataFrame.new, Rover::DataFrame.new], value: nil)
        @feature = feature
        @value = value
        @subsets = subsets
        @info_gain = info_gain
      end

      # Orders splits by information gain.
      def <=>(other)
        info_gain <=> other.info_gain
      end

      # True when the split carries a value. Only nil means "no value", so a
      # split on the boolean value +false+ is still binary.
      def binary?
        !value.nil?
      end

      def multiway?
        !binary?
      end
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "test"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    module Tests
      # Test passed by instances whose value for +feature+ equals +value+.
      # Construction is inherited from Test, which takes (feature, value).
      class Equal < Test
        def passed_by?(instance)
          instance[feature] == value
        end

        # Two Equal tests are == when their features and values are ==.
        def ==(other)
          other.is_a?(Equal) && feature == other.feature && value == other.value
        end

        # Stricter variant of #== based on eql? of feature and value.
        def eql?(other)
          other.is_a?(Equal) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} == #{value}?"
        end

        def description
          "#{feature} equal to #{value}?"
        end
      end
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "test"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    module Tests
      # Test passed by instances whose value for +feature+ is >= +value+.
      # Construction is inherited from Test, which takes (feature, value).
      class EqualOrGreater < Test
        def passed_by?(instance)
          instance[feature] >= value
        end

        # Two EqualOrGreater tests are == when their features and values
        # are ==.
        def ==(other)
          other.is_a?(EqualOrGreater) && feature == other.feature && value == other.value
        end

        # Stricter variant of #== based on eql? of feature and value.
        def eql?(other)
          other.is_a?(EqualOrGreater) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} >= #{value}?"
        end

        def description
          "#{feature} equal to or greater than #{value}?"
        end
      end
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "test"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    module Tests
      # Test passed by instances whose value for +feature+ is < +value+.
      # Construction is inherited from Test, which takes (feature, value).
      class Less < Test
        def passed_by?(instance)
          instance[feature] < value
        end

        # Two Less tests are == when their features and values are ==.
        def ==(other)
          other.is_a?(Less) && feature == other.feature && value == other.value
        end

        # Stricter variant of #== based on eql? of feature and value.
        def eql?(other)
          other.is_a?(Less) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} < #{value}?"
        end

        def description
          "#{feature} less than #{value}?"
        end
      end
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "test"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    module Tests
      # Test passed by instances whose value for +feature+ differs from
      # +value+. Construction is inherited from Test, which takes
      # (feature, value).
      class NotEqual < Test
        def passed_by?(instance)
          instance[feature] != value
        end

        # Two NotEqual tests are == when their features and values are ==.
        def ==(other)
          other.is_a?(NotEqual) && feature == other.feature && value == other.value
        end

        # Stricter variant of #== based on eql? of feature and value.
        def eql?(other)
          other.is_a?(NotEqual) && feature.eql?(other.feature) && value.eql?(other.value)
        end

        def to_s
          "#{feature} != #{value}?"
        end

        def description
          "#{feature} not equal to #{value}?"
        end
      end
    end
  end
end
|
@@ -0,0 +1,28 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module HybridForest
  module Trees
    module Tests
      # Abstract base for a condition on a single feature of an instance.
      # Subclasses implement #passed_by? and #description.
      class Test
        attr_reader :feature, :value

        # feature:: key of the feature the test inspects
        # value::   value the feature is compared against
        def initialize(feature, value)
          @feature = feature
          @value = value
        end

        # Raises NotImplementedError; subclasses decide whether +_instance+
        # satisfies the test.
        def passed_by?(_instance)
          raise NotImplementedError
        end

        # Raises NotImplementedError; subclasses return a readable description.
        def description
          raise NotImplementedError
        end

        # Hash code built from the test class, feature and value. Hashing the
        # triple as an array avoids the weaknesses of XOR-ing the two hashes:
        # swapped feature/value pairs colliding, every test with
        # feature == value hashing to 0, and different test types on the same
        # feature/value sharing a bucket.
        def hash
          [self.class, feature, value].hash
        end
      end
    end
  end
end
|
@@ -0,0 +1,46 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "../utilities/utils"
|
4
|
+
|
5
|
+
module HybridForest
  module Trees
    # A decision tree whose structure is produced by an injected tree grower.
    class Tree
      # Creates a new Tree using the specified tree growing algorithm.
      def initialize(tree_grower:)
        @tree_grower = tree_grower
      end

      ##
      # Converts +instances+ to a dataframe, grows the tree from it and
      # returns +self+.
      #
      def fit(instances)
        dataframe = HybridForest::Utils.to_dataframe(instances)
        @root = @tree_grower.grow_tree(dataframe)
        self
      end

      ##
      # Returns an array with one predicted label per row of +instances+.
      # Raises Errors::InvalidStateError when the tree has not been fitted.
      #
      def predict(instances)
        raise Errors::InvalidStateError, "You must call #fit before you call #predict" if @root.nil?

        rows = HybridForest::Utils.to_dataframe(instances).each_row
        rows.map { |instance| @root.classify(instance) }
      end

      # For an unfitted tree, returns a string starting with "Empty tree: ".
      # Otherwise delegates to the root node's #print_string and returns its
      # result.
      def inspect
        return "Empty tree: #{super}" if @root.nil?

        @root.print_string
      end
    end
  end
end
|
@@ -0,0 +1,76 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "../../utilities/utils"
require_relative "../split"
require_relative "../feature_selectors/random_feature_subspace"
require_relative "../feature_selectors/all_features"
require_relative "../feature_selectors/max_one_split_per_feature"
require_relative "../impurity_metrics/gini_impurity"
require_relative "../impurity_metrics/entropy"
require_relative "../nodes/binary_node"
require_relative "../nodes/leaf_node"
require_relative "../tests/equal"
require_relative "../tests/equal_or_greater"
|
13
|
+
|
14
|
+
module HybridForest
  module Trees
    module TreeGrowers
      # Grows a binary decision tree by recursively choosing the candidate
      # split with the highest information gain.
      class CARTGrower
        # feature_selector:: chooses which features are considered per split
        # impurity_metric::  measures impurity and information gain
        def initialize(feature_selector: RandomFeatureSubspace.new, impurity_metric: GiniImpurity.new)
          @impurity_metric = impurity_metric
          @feature_selector = feature_selector
        end

        # Returns the root node of a tree grown from +instances+: a LeafNode
        # when no candidate split improves on zero information gain,
        # otherwise a BinaryNode whose subtrees are grown recursively.
        def grow_tree(instances)
          split = find_best_split(instances)
          return LeafNode.new(instances) if split.info_gain == 0

          branch(split.subsets, split.feature, split.value,
            numeric: instances[split.feature].numeric?)
        end

        private

        # Grows both subtrees and joins them under a test on +feature+.
        # +numeric+ says how the subsets were produced, so the node's test is
        # the same predicate that partitioned the training data.
        def branch(subsets, feature, value, numeric: true)
          true_instances, false_instances = subsets
          true_branch = grow_tree(true_instances)
          false_branch = grow_tree(false_instances)
          BinaryNode.new(test_for(feature, value, numeric), true_branch, false_branch)
        end

        # Numeric features are split with equal_or_greater_split and all
        # others with equal_split (see #find_best_split); the test used at
        # prediction time has to mirror that choice.
        def test_for(feature, value, numeric)
          test_class = numeric ? Tests::EqualOrGreater : Tests::Equal
          test_class.new(feature, value)
        end

        # Tries every distinct value of every considered feature as a split
        # point and returns the split with the highest information gain.
        # Returns the zero-gain default split when nothing beats it.
        def find_best_split(instances)
          considered_features = @feature_selector.select_features(instances.features)
          current_impurity = @impurity_metric.compute(instances)
          best_split = default_split(instances, considered_features)

          considered_features.each do |feature|
            column = instances[feature]
            # The column type cannot change between values, so check it once.
            numeric = column.numeric?

            column.uniq.each do |value|
              subsets = if numeric
                instances.equal_or_greater_split(feature, value)
              else
                instances.equal_split(feature, value)
              end

              # A split that leaves one side empty separates nothing.
              next if subsets.any? { |set| set.count == 0 }

              info_gain = @impurity_metric.information_gain(subsets, current_impurity)
              split = Split.new(feature, info_gain: info_gain, subsets: subsets, value: value)
              best_split = split if split.better_than? best_split
            end
          end

          best_split
        end

        # Zero-gain placeholder split on the first considered feature. When no
        # features are considered, the placeholder has neither feature nor
        # value, so the column lookup is skipped.
        def default_split(instances, features)
          first_feature = features.first
          value = instances[first_feature].first unless first_feature.nil?
          Split.new(first_feature, value: value)
        end
      end
    end
  end
end
|