hybridforest 0.9.0 → 0.13.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8d0af8ba6e658989bf55d0b8026822ab0be77020c189ffd056ffac3e7a393700
4
- data.tar.gz: 968c29d52088fe4ba28d93fcd3aedf7288d323f3a35ba000efde5b1673c10e2b
3
+ metadata.gz: aa4457c9f58fde0edfa9fd8af0a4ac5ef23db54892f75a50a8f62ac1f8652631
4
+ data.tar.gz: a8f0afc4b58d067388f3c5986521609a5ede0816643f73235b8517265c6b72c7
5
5
  SHA512:
6
- metadata.gz: 2e7076a8cdea280c900f61c98a6f53715a335d7765a6861f028a6956a8c6599cbe233b5d3e867b0d55f9a8889e41067e314ca151b00e4455f55071ac407059f5
7
- data.tar.gz: dbf3dc47bfba70ff33e9452cc9232d5079651410a3eb319b29172992141c1602f8c06bba50e7ee9a56c1c8530e9099763834b1a964a86cdc4ec4c89d944237f4
6
+ metadata.gz: 9261b0cc8fa086e8b91f910f8b852f2369a1eac0e76ee4b01f1db3dddf9bb7f8f9947cbcaca2f81d6db6a9c734ed516a0848f29b7735661ea5be428b8e720392
7
+ data.tar.gz: 5d8a886b98ca5e735b4173fb51e114d313c28ef480bd8cae32cf339634e49384a44d4fdb53548afb82228bd368960ef002d55bbd1292dd8313706b2458bf543c
data/CHANGELOG.md CHANGED
@@ -30,4 +30,22 @@
30
30
 
31
31
  ## [0.9.0] - 2021-12-28
32
32
 
33
- - Update dependencies
33
+ - Update dependencies
34
+
35
+ ## [0.10.0] - 2021-12-29
36
+
37
+ - Refactor dataframe extensions
38
+
39
+ ## [0.11.0] - 2021-12-29
40
+
41
+ - Randomize Utils.train_test_split
42
+ - Refactor Utils module
43
+
44
+ ## [0.12.0] - 2022-01-08
45
+
46
+ - Allow Utils.random_sample to be passed a dataframe or a dataframe convertible object
47
+ - Allow Utils.random_sample's 'size' arg to equal the size of the initial dataframe if the strategy is sampling with replacement
48
+
49
+ ## [0.13.0] - 2022-01-09
50
+
51
+ - Refactor forest growers
data/Gemfile.lock CHANGED
@@ -1,7 +1,7 @@
1
1
  PATH
2
2
  remote: .
3
3
  specs:
4
- hybridforest (0.9.0)
4
+ hybridforest (0.12.0)
5
5
  activesupport (~> 6.1)
6
6
  rake (~> 13.0)
7
7
  require_all
@@ -8,7 +8,7 @@ module HybridForest
8
8
  def grow_forest(instances, number_of_trees)
9
9
  forest = []
10
10
  number_of_trees.times do
11
- sample, _, _ = HybridForest::Utils.train_test_bootstrap_split(instances)
11
+ sample = HybridForest::Utils.random_sample(data: instances, size: instances.size)
12
12
  forest << HybridForest::Trees::CARTTree.new.fit(sample)
13
13
  end
14
14
  forest
@@ -8,7 +8,7 @@ module HybridForest
8
8
  def grow_forest(instances, number_of_trees)
9
9
  forest = []
10
10
  number_of_trees.times do
11
- sample, _, _ = HybridForest::Utils.train_test_bootstrap_split(instances)
11
+ sample = HybridForest::Utils.random_sample(data: instances, size: instances.size)
12
12
  forest << HybridForest::Trees::ID3Tree.new.fit(sample)
13
13
  end
14
14
  forest
@@ -6,6 +6,7 @@ module HybridForest
6
6
  module Trees
7
7
  class RandomFeatureSubspace
8
8
  def select_features(all_features)
9
+ # TODO: Allow the subspace size to be configured.
9
10
  n = default_subspace_size(all_features.count)
10
11
  indices = Set.new
11
12
  until indices.size == n
@@ -12,17 +12,16 @@ module HybridForest
12
12
  # of independent features and an array of labels. Returns [+training_set+, +testing_set+, +testing_set_labels+]
13
13
  #
14
14
  def self.train_test_split(dataset, test_set_size = 0.20)
15
- # TODO: Shuffle and stratify samples
15
+ # TODO: Offer stratify param
16
16
  dataset = to_dataframe(dataset)
17
+ all_rows = (0...dataset.count).to_a
17
18
 
18
19
  test_set_count = (dataset.count * test_set_size).floor
19
- test_set_indices = 0..test_set_count
20
- test_set = dataset[test_set_indices]
21
- test_set_labels = test_set.class_labels
22
- test_set.except!(test_set.label)
20
+ test_set_rows = rand_uniq_nums(test_set_count, 0...dataset.count)
21
+ test_set = dataset[test_set_rows]
22
+ test_set, test_set_labels = test_set.disconnect_labels
23
23
 
24
- train_set_indices = test_set_count + 1...dataset.count
25
- train_set = dataset[train_set_indices]
24
+ train_set = dataset[all_rows - test_set_rows]
26
25
 
27
26
  [train_set, test_set, test_set_labels]
28
27
  end
@@ -37,20 +36,13 @@ module HybridForest
37
36
  dataset = to_dataframe(dataset)
38
37
  all_rows = (0...dataset.count).to_a
39
38
 
40
- train_set = Rover::DataFrame.new
41
- train_set_rows = []
42
- dataset.count.times do
43
- row = all_rows.sample
44
- train_set_rows << row
45
- train_set.concat(dataset[row])
46
- end
39
+ train_set_rows = rand_nums(dataset.count, 0...dataset.count)
40
+ train_set = dataset[train_set_rows]
47
41
 
48
42
  return train_test_split(dataset) if train_set_rows.sort == all_rows
49
43
 
50
- test_set_rows = all_rows - train_set_rows
51
- test_set = dataset[test_set_rows]
52
- test_set_labels = test_set.class_labels
53
- test_set.except!(test_set.label)
44
+ test_set = dataset[all_rows - train_set_rows]
45
+ test_set, test_set_labels = test_set.disconnect_labels
54
46
 
55
47
  [train_set, test_set, test_set_labels]
56
48
  end
@@ -86,18 +78,18 @@ module HybridForest
86
78
  # Draws a random sample of +size+ from +data+.
87
79
  #
88
80
  def self.random_sample(data:, size:, with_replacement: true)
89
- raise ArgumentError, "Invalid sample size" if size < 1 || size > data.count
81
+ data = to_dataframe(data)
90
82
 
91
- if with_replacement
92
- rows = size.times.collect { rand(0...data.count) }
93
- data[rows]
83
+ if size < 1 || (!with_replacement && size > data.count)
84
+ raise ArgumentError, "Invalid sample size"
85
+ end
86
+
87
+ rows = if with_replacement
88
+ rand_nums(size, 0...data.count)
94
89
  else
95
- rows = Set.new
96
- until rows.size == size
97
- rows << rand(0...data.count)
98
- end
99
- data[rows.to_a]
90
+ rand_uniq_nums(size, 0...data.count)
100
91
  end
92
+ data[rows]
101
93
  end
102
94
 
103
95
  # Outputs a report of common prediction metrics.
@@ -113,7 +105,7 @@ module HybridForest
113
105
  #
114
106
  def self.accuracy(predicted, actual)
115
107
  accurate = predicted.zip(actual).count { |p, a| equal_labels?(p, a) }
116
- accurate.to_f / predicted.count.to_f
108
+ accurate.to_f / predicted.count
117
109
  end
118
110
 
119
111
  # Extensions to simplify common dataframe operations.
@@ -146,11 +138,11 @@ module HybridForest
146
138
  end
147
139
 
148
140
  def feature_count(without_label: true)
149
- without_label ? names.count - 1 : names.count
141
+ without_label ? features.count : names.count
150
142
  end
151
143
 
152
144
  def pure?
153
- column_by_index(-1).uniq.size == 1
145
+ self[label].uniq.size == 1
154
146
  end
155
147
 
156
148
  def features
@@ -158,7 +150,7 @@ module HybridForest
158
150
  end
159
151
 
160
152
  def count_labels
161
- column_by_index(-1).tally
153
+ self[label].tally
162
154
  end
163
155
 
164
156
  def label
@@ -168,6 +160,12 @@ module HybridForest
168
160
  def class_labels
169
161
  self[label].to_a
170
162
  end
163
+
164
+ def disconnect_labels
165
+ labels = class_labels
166
+ except!(label)
167
+ [self, labels]
168
+ end
171
169
  end
172
170
  end
173
171
 
@@ -202,5 +200,23 @@ module HybridForest
202
200
  def false_label?(label)
203
201
  [false, 0].include? label
204
202
  end
203
+
204
+ ##
205
+ # Returns an array of +n+ random numbers in the exclusive +range+.
206
+ def rand_nums(n, range)
207
+ n.times.collect { rand(range) }
208
+ end
209
+
210
+ ##
211
+ # Returns an array of +n+ _unique_ random numbers in the exclusive +range+.
212
+ def rand_uniq_nums(n, range)
213
+ raise ArgumentError if n > range.size
214
+
215
+ nums = Set.new
216
+ until nums.size == n
217
+ nums << rand(range)
218
+ end
219
+ nums.to_a
220
+ end
205
221
  end
206
222
  end
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module HybridForest
4
- VERSION = "0.9.0"
4
+ VERSION = "0.13.0"
5
5
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: hybridforest
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.9.0
4
+ version: 0.13.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - hi-tech-jazz
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2021-12-28 00:00:00.000000000 Z
11
+ date: 2022-01-09 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rake