hybridforest 0.10.0 → 0.14.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +20 -1
- data/Gemfile.lock +1 -1
- data/lib/hybridforest/forests/forest_growers/cart_grower.rb +1 -1
- data/lib/hybridforest/forests/forest_growers/hybrid_grower.rb +19 -15
- data/lib/hybridforest/forests/forest_growers/id3_grower.rb +1 -1
- data/lib/hybridforest/trees/feature_selectors/random_feature_subspace.rb +1 -0
- data/lib/hybridforest/utilities/utils.rb +43 -27
- data/lib/hybridforest/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 4311342fccd332cd6f98b2b6a30f32acc01f2f3090490e452e9bfe8981730f07
|
4
|
+
data.tar.gz: 2a4461a04ac9232d5506271ddbf8b9f9b49b7397a06c00aea894711c3ce0586f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9708316bfa685c814afbffea2d2ea238dd8736ac929bd782a365d6ab65dbe2d1987df7a66750db3d9108071279a99cb1206f179163f1383494c0ebcda704fab6
|
7
|
+
data.tar.gz: efbd6db6210d02830734bb7e3411f8b9553ed8b6172df59f1180e801884057d34edd85a67cd6c63e6b09c8531cf434ccdf265e14f4d8d26e035b333a5805c48d
|
data/CHANGELOG.md
CHANGED
@@ -34,4 +34,23 @@
|
|
34
34
|
|
35
35
|
## [0.10.0] - 2021-12-29
|
36
36
|
|
37
|
-
- Refactor dataframe extensions
|
37
|
+
- Refactor dataframe extensions
|
38
|
+
|
39
|
+
## [0.11.0] - 2021-12-29
|
40
|
+
|
41
|
+
- Randomize Utils.train_test_split
|
42
|
+
- Refactor Utils module
|
43
|
+
|
44
|
+
## [0.12.0] - 2022-01-08
|
45
|
+
|
46
|
+
- Allow Utils.random_sample to be passed a dataframe or a dataframe convertible object
|
47
|
+
- Allow Utils.random_sample's 'size' arg to equal the size of the initial dataframe if the strategy is sampling with replacement
|
48
|
+
|
49
|
+
## [0.13.0] - 2022-01-09
|
50
|
+
|
51
|
+
- Refactor forest growers
|
52
|
+
|
53
|
+
|
54
|
+
## [0.14.0] - 2022-01-09
|
55
|
+
|
56
|
+
- Refactor hybrid forest grower
|
data/Gemfile.lock
CHANGED
data/lib/hybridforest/forests/forest_growers/cart_grower.rb
CHANGED
@@ -8,7 +8,7 @@ module HybridForest
|
|
8
8
|
def grow_forest(instances, number_of_trees)
|
9
9
|
forest = []
|
10
10
|
number_of_trees.times do
|
11
|
-
sample
|
11
|
+
sample = HybridForest::Utils.random_sample(data: instances, size: instances.size)
|
12
12
|
forest << HybridForest::Trees::CARTTree.new.fit(sample)
|
13
13
|
end
|
14
14
|
forest
|
data/lib/hybridforest/forests/forest_growers/hybrid_grower.rb
CHANGED
@@ -11,8 +11,9 @@ module HybridForest
|
|
11
11
|
def grow_forest(instances, number_of_trees)
|
12
12
|
forest = []
|
13
13
|
number_of_trees.times do
|
14
|
-
|
15
|
-
|
14
|
+
iob_data, oob_data, oob_labels = HybridForest::Utils.train_test_bootstrap_split(instances)
|
15
|
+
trees = grow_trees(TREE_TYPES, iob_data)
|
16
|
+
tree_results = predict_evaluate_trees(trees, oob_data, oob_labels)
|
16
17
|
best_tree = select_best_tree(tree_results)
|
17
18
|
forest << best_tree
|
18
19
|
end
|
@@ -21,25 +22,28 @@ module HybridForest
|
|
21
22
|
|
22
23
|
private
|
23
24
|
|
24
|
-
def
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
25
|
+
def grow_trees(tree_types, iob_data)
|
26
|
+
tree_types.collect do |tree_type|
|
27
|
+
tree_type.new.fit(iob_data)
|
28
|
+
end
|
29
|
+
end
|
30
|
+
|
31
|
+
def predict_evaluate_trees(trees, oob_data, oob_labels)
|
32
|
+
trees.collect do |tree|
|
33
|
+
predict_evaluate(tree, oob_data, oob_labels)
|
34
|
+
end
|
35
|
+
end
|
36
|
+
|
37
|
+
def predict_evaluate(tree, data, actual_labels)
|
38
|
+
predicted_labels = tree.predict(data)
|
39
|
+
accuracy = HybridForest::Utils.accuracy(predicted_labels, actual_labels)
|
40
|
+
{tree: tree, oob_accuracy: accuracy}
|
29
41
|
end
|
30
42
|
|
31
43
|
def select_best_tree(tree_results)
|
32
44
|
best_result = tree_results.max_by(1) { |result| result[:oob_accuracy] }.first
|
33
45
|
best_result[:tree]
|
34
46
|
end
|
35
|
-
|
36
|
-
def grow_trees(tree_types, in_of_bag, out_of_bag, out_of_bag_labels)
|
37
|
-
tree_results = []
|
38
|
-
tree_types.each do |tree_type|
|
39
|
-
tree_results << fit_and_predict(tree_type, in_of_bag, out_of_bag, out_of_bag_labels)
|
40
|
-
end
|
41
|
-
tree_results
|
42
|
-
end
|
43
47
|
end
|
44
48
|
end
|
45
49
|
end
|
data/lib/hybridforest/forests/forest_growers/id3_grower.rb
CHANGED
@@ -8,7 +8,7 @@ module HybridForest
|
|
8
8
|
def grow_forest(instances, number_of_trees)
|
9
9
|
forest = []
|
10
10
|
number_of_trees.times do
|
11
|
-
sample
|
11
|
+
sample = HybridForest::Utils.random_sample(data: instances, size: instances.size)
|
12
12
|
forest << HybridForest::Trees::ID3Tree.new.fit(sample)
|
13
13
|
end
|
14
14
|
forest
|
data/lib/hybridforest/utilities/utils.rb
CHANGED
@@ -12,17 +12,16 @@ module HybridForest
|
|
12
12
|
# of independent features and an array of labels. Returns [+training_set+, +testing_set+, +testing_set_labels+]
|
13
13
|
#
|
14
14
|
def self.train_test_split(dataset, test_set_size = 0.20)
|
15
|
-
# TODO:
|
15
|
+
# TODO: Offer stratify param
|
16
16
|
dataset = to_dataframe(dataset)
|
17
|
+
all_rows = (0...dataset.count).to_a
|
17
18
|
|
18
19
|
test_set_count = (dataset.count * test_set_size).floor
|
19
|
-
|
20
|
-
test_set = dataset[
|
21
|
-
test_set_labels = test_set.
|
22
|
-
test_set.except!(test_set.label)
|
20
|
+
test_set_rows = rand_uniq_nums(test_set_count, 0...dataset.count)
|
21
|
+
test_set = dataset[test_set_rows]
|
22
|
+
test_set, test_set_labels = test_set.disconnect_labels
|
23
23
|
|
24
|
-
|
25
|
-
train_set = dataset[train_set_indices]
|
24
|
+
train_set = dataset[all_rows - test_set_rows]
|
26
25
|
|
27
26
|
[train_set, test_set, test_set_labels]
|
28
27
|
end
|
@@ -37,20 +36,13 @@ module HybridForest
|
|
37
36
|
dataset = to_dataframe(dataset)
|
38
37
|
all_rows = (0...dataset.count).to_a
|
39
38
|
|
40
|
-
|
41
|
-
|
42
|
-
dataset.count.times do
|
43
|
-
row = all_rows.sample
|
44
|
-
train_set_rows << row
|
45
|
-
train_set.concat(dataset[row])
|
46
|
-
end
|
39
|
+
train_set_rows = rand_nums(dataset.count, 0...dataset.count)
|
40
|
+
train_set = dataset[train_set_rows]
|
47
41
|
|
48
42
|
return train_test_split(dataset) if train_set_rows.sort == all_rows
|
49
43
|
|
50
|
-
|
51
|
-
test_set =
|
52
|
-
test_set_labels = test_set.class_labels
|
53
|
-
test_set.except!(test_set.label)
|
44
|
+
test_set = dataset[all_rows - train_set_rows]
|
45
|
+
test_set, test_set_labels = test_set.disconnect_labels
|
54
46
|
|
55
47
|
[train_set, test_set, test_set_labels]
|
56
48
|
end
|
@@ -86,18 +78,18 @@ module HybridForest
|
|
86
78
|
# Draws a random sample of +size+ from +data+.
|
87
79
|
#
|
88
80
|
def self.random_sample(data:, size:, with_replacement: true)
|
89
|
-
|
81
|
+
data = to_dataframe(data)
|
90
82
|
|
91
|
-
if with_replacement
|
92
|
-
|
93
|
-
|
83
|
+
if size < 1 || (!with_replacement && size > data.count)
|
84
|
+
raise ArgumentError, "Invalid sample size"
|
85
|
+
end
|
86
|
+
|
87
|
+
rows = if with_replacement
|
88
|
+
rand_nums(size, 0...data.count)
|
94
89
|
else
|
95
|
-
|
96
|
-
until rows.size == size
|
97
|
-
rows << rand(0...data.count)
|
98
|
-
end
|
99
|
-
data[rows.to_a]
|
90
|
+
rand_uniq_nums(size, 0...data.count)
|
100
91
|
end
|
92
|
+
data[rows]
|
101
93
|
end
|
102
94
|
|
103
95
|
# Outputs a report of common prediction metrics.
|
@@ -168,6 +160,12 @@ module HybridForest
|
|
168
160
|
def class_labels
|
169
161
|
self[label].to_a
|
170
162
|
end
|
163
|
+
|
164
|
+
def disconnect_labels
|
165
|
+
labels = class_labels
|
166
|
+
except!(label)
|
167
|
+
[self, labels]
|
168
|
+
end
|
171
169
|
end
|
172
170
|
end
|
173
171
|
|
@@ -202,5 +200,23 @@ module HybridForest
|
|
202
200
|
def false_label?(label)
|
203
201
|
[false, 0].include? label
|
204
202
|
end
|
203
|
+
|
204
|
+
##
|
205
|
+
# Returns an array of +n+ random numbers in the exclusive +range+.
|
206
|
+
def rand_nums(n, range)
|
207
|
+
n.times.collect { rand(range) }
|
208
|
+
end
|
209
|
+
|
210
|
+
##
|
211
|
+
# Returns an array of +n+ _unique_ random numbers in the exclusive +range+.
|
212
|
+
def rand_uniq_nums(n, range)
|
213
|
+
raise ArgumentError if n > range.size
|
214
|
+
|
215
|
+
nums = Set.new
|
216
|
+
until nums.size == n
|
217
|
+
nums << rand(range)
|
218
|
+
end
|
219
|
+
nums.to_a
|
220
|
+
end
|
205
221
|
end
|
206
222
|
end
|
data/lib/hybridforest/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: hybridforest
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.14.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- hi-tech-jazz
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2022-01-09 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rake
|