ai4r 1.13 → 2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +174 -0
- data/examples/classifiers/hyperpipes_data.csv +14 -0
- data/examples/classifiers/hyperpipes_example.rb +22 -0
- data/examples/classifiers/ib1_example.rb +12 -0
- data/examples/classifiers/id3_example.rb +15 -10
- data/examples/classifiers/id3_graphviz_example.rb +17 -0
- data/examples/classifiers/logistic_regression_example.rb +11 -0
- data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
- data/examples/classifiers/naive_bayes_example.rb +12 -13
- data/examples/classifiers/one_r_example.rb +27 -0
- data/examples/classifiers/parameter_tutorial.rb +29 -0
- data/examples/classifiers/prism_nominal_example.rb +15 -0
- data/examples/classifiers/prism_numeric_example.rb +21 -0
- data/examples/classifiers/simple_linear_regression_example.rb +14 -11
- data/examples/classifiers/zero_and_one_r_example.rb +34 -0
- data/examples/classifiers/zero_one_r_data.csv +8 -0
- data/examples/clusterers/clusterer_example.rb +40 -34
- data/examples/clusterers/dbscan_example.rb +17 -0
- data/examples/clusterers/dendrogram_example.rb +17 -0
- data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
- data/examples/clusterers/kmeans_custom_example.rb +26 -0
- data/examples/genetic_algorithm/bitstring_example.rb +41 -0
- data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
- data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
- data/examples/neural_network/backpropagation_example.rb +48 -48
- data/examples/neural_network/hopfield_example.rb +45 -0
- data/examples/neural_network/patterns_with_base_noise.rb +39 -39
- data/examples/neural_network/patterns_with_noise.rb +41 -39
- data/examples/neural_network/train_epochs_callback.rb +25 -0
- data/examples/neural_network/training_patterns.rb +39 -39
- data/examples/neural_network/transformer_text_classification.rb +78 -0
- data/examples/neural_network/xor_example.rb +23 -22
- data/examples/reinforcement/q_learning_example.rb +10 -0
- data/examples/som/som_data.rb +155 -152
- data/examples/som/som_multi_node_example.rb +12 -13
- data/examples/som/som_single_example.rb +12 -15
- data/examples/transformer/decode_classifier_example.rb +68 -0
- data/examples/transformer/deterministic_example.rb +10 -0
- data/examples/transformer/seq2seq_example.rb +16 -0
- data/lib/ai4r/classifiers/classifier.rb +24 -16
- data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
- data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
- data/lib/ai4r/classifiers/ib1.rb +122 -32
- data/lib/ai4r/classifiers/id3.rb +524 -145
- data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
- data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
- data/lib/ai4r/classifiers/naive_bayes.rb +95 -34
- data/lib/ai4r/classifiers/one_r.rb +112 -44
- data/lib/ai4r/classifiers/prism.rb +167 -76
- data/lib/ai4r/classifiers/random_forest.rb +72 -0
- data/lib/ai4r/classifiers/simple_linear_regression.rb +83 -58
- data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
- data/lib/ai4r/classifiers/votes.rb +57 -0
- data/lib/ai4r/classifiers/zero_r.rb +71 -30
- data/lib/ai4r/clusterers/average_linkage.rb +46 -27
- data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
- data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
- data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
- data/lib/ai4r/clusterers/clusterer.rb +29 -14
- data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
- data/lib/ai4r/clusterers/dbscan.rb +134 -0
- data/lib/ai4r/clusterers/diana.rb +75 -49
- data/lib/ai4r/clusterers/k_means.rb +270 -135
- data/lib/ai4r/clusterers/median_linkage.rb +49 -33
- data/lib/ai4r/clusterers/single_linkage.rb +196 -88
- data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
- data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +25 -10
- data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
- data/lib/ai4r/data/data_set.rb +223 -103
- data/lib/ai4r/data/parameterizable.rb +31 -25
- data/lib/ai4r/data/proximity.rb +62 -62
- data/lib/ai4r/data/statistics.rb +46 -35
- data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
- data/lib/ai4r/experiment/split.rb +39 -0
- data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
- data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
- data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
- data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
- data/lib/ai4r/neural_network/activation_functions.rb +37 -0
- data/lib/ai4r/neural_network/backpropagation.rb +399 -134
- data/lib/ai4r/neural_network/hopfield.rb +175 -58
- data/lib/ai4r/neural_network/transformer.rb +194 -0
- data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
- data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
- data/lib/ai4r/reinforcement/q_learning.rb +51 -0
- data/lib/ai4r/search/a_star.rb +76 -0
- data/lib/ai4r/search/bfs.rb +50 -0
- data/lib/ai4r/search/dfs.rb +50 -0
- data/lib/ai4r/search/mcts.rb +118 -0
- data/lib/ai4r/search.rb +12 -0
- data/lib/ai4r/som/distance_metrics.rb +29 -0
- data/lib/ai4r/som/layer.rb +28 -17
- data/lib/ai4r/som/node.rb +61 -32
- data/lib/ai4r/som/som.rb +158 -41
- data/lib/ai4r/som/two_phase_layer.rb +21 -25
- data/lib/ai4r/version.rb +3 -0
- data/lib/ai4r.rb +57 -28
- metadata +79 -109
- data/README.rdoc +0 -39
- data/test/classifiers/hyperpipes_test.rb +0 -84
- data/test/classifiers/ib1_test.rb +0 -78
- data/test/classifiers/id3_test.rb +0 -220
- data/test/classifiers/multilayer_perceptron_test.rb +0 -79
- data/test/classifiers/naive_bayes_test.rb +0 -43
- data/test/classifiers/one_r_test.rb +0 -62
- data/test/classifiers/prism_test.rb +0 -85
- data/test/classifiers/simple_linear_regression_test.rb +0 -37
- data/test/classifiers/zero_r_test.rb +0 -50
- data/test/clusterers/average_linkage_test.rb +0 -51
- data/test/clusterers/bisecting_k_means_test.rb +0 -66
- data/test/clusterers/centroid_linkage_test.rb +0 -53
- data/test/clusterers/complete_linkage_test.rb +0 -57
- data/test/clusterers/diana_test.rb +0 -69
- data/test/clusterers/k_means_test.rb +0 -167
- data/test/clusterers/median_linkage_test.rb +0 -53
- data/test/clusterers/single_linkage_test.rb +0 -122
- data/test/clusterers/ward_linkage_hierarchical_test.rb +0 -81
- data/test/clusterers/ward_linkage_test.rb +0 -53
- data/test/clusterers/weighted_average_linkage_test.rb +0 -53
- data/test/data/data_set_test.rb +0 -104
- data/test/data/proximity_test.rb +0 -87
- data/test/data/statistics_test.rb +0 -65
- data/test/experiment/classifier_evaluator_test.rb +0 -76
- data/test/genetic_algorithm/chromosome_test.rb +0 -57
- data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
- data/test/neural_network/backpropagation_test.rb +0 -82
- data/test/neural_network/hopfield_test.rb +0 -72
- data/test/som/som_test.rb +0 -97
data/lib/ai4r/classifiers/id3.rb
CHANGED
@@ -1,34 +1,34 @@
-#
+# frozen_string_literal: true
+
+# Author:: Sergio Fierens (Implementation, Quinlan is
 # the creator of the algorithm)
 # License:: MPL 1.1
 # Project:: ai4r
-# Url::
+# Url:: https://github.com/SergioFierens/ai4r
 #
-# You can redistribute it and/or modify it under the terms of
-# the Mozilla Public License version 1.1 as published by the
+# You can redistribute it and/or modify it under the terms of
+# the Mozilla Public License version 1.1 as published by the
 # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt

-
-
+require_relative '../data/data_set'
+require_relative '../classifiers/classifier'

 module Ai4r
-
 module Classifiers
-
 # = Introduction
-# This is an implementation of the ID3 algorithm (Quinlan)
-# Given a set of preclassified examples, it builds a top-down
-# induction of decision tree, biased by the information gain and
+# This is an implementation of the ID3 algorithm (Quinlan)
+# Given a set of preclassified examples, it builds a top-down
+# induction of decision tree, biased by the information gain and
 # entropy measure.
 #
 # * http://en.wikipedia.org/wiki/Decision_tree
 # * http://en.wikipedia.org/wiki/ID3_algorithm
 #
 # = How to use it
-#
+#
 # DATA_LABELS = [ 'city', 'age_range', 'gender', 'marketing_target' ]
 #
-# DATA_ITEMS = [
+# DATA_ITEMS = [
 # ['New York', '<30', 'M', 'Y'],
 # ['Chicago', '<30', 'M', 'Y'],
 # ['Chicago', '<30', 'F', 'Y'],
@@ -45,286 +45,665 @@ module Ai4r
 # ['New York', '[50-80]', 'F', 'N'],
 # ['Chicago', '>80', 'F', 'Y']
 # ]
-#
+#
 # data_set = DataSet.new(:data_items=>DATA_SET, :data_labels=>DATA_LABELS)
 # id3 = Ai4r::Classifiers::ID3.new.build(data_set)
-#
+#
 # id3.get_rules
 # # => if age_range=='<30' then marketing_target='Y'
 # elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y'
 # elsif age_range=='[30-50)' and city=='New York' then marketing_target='N'
 # elsif age_range=='[50-80]' then marketing_target='N'
 # elsif age_range=='>80' then marketing_target='Y'
-# else
-#
+# else
+# raise 'There was not enough information during training to do '
+# 'a proper induction for this data element'
+# end
+#
 # id3.eval(['New York', '<30', 'M'])
 # # => 'Y'
-#
-# = A better way to load the data
-#
+#
+# = A better way to load the data
+#
 # In the real life you will use lot more data training examples, with more
-# attributes. Consider moving your data to an external CSV (comma separate
+# attributes. Consider moving your data to an external CSV (comma separate
 # values) file.
-#
+#
 # data_file = "#{File.dirname(__FILE__)}/data_set.csv"
 # data_set = DataSet.load_csv_with_labels data_file
-# id3 = Ai4r::Classifiers::ID3.new.build(data_set)
-#
+# id3 = Ai4r::Classifiers::ID3.new.build(data_set)
+#
 # = A nice tip for data evaluation
-#
+#
 # id3 = Ai4r::Classifiers::ID3.new.build(data_set)
 #
 # age_range = '<30'
 # marketing_target = nil
-# eval id3.get_rules
+# eval id3.get_rules
 # puts marketing_target
-# # => 'Y'
+# # => 'Y'
 #
 # = More about ID3 and decision trees
-#
+#
 # * http://en.wikipedia.org/wiki/Decision_tree
 # * http://en.wikipedia.org/wiki/ID3_algorithm
-#
+#
 # = About the project
 # Author:: Sergio Fierens
 # License:: MPL 1.1
-# Url::
+# Url:: https://github.com/SergioFierens/ai4r
 class ID3 < Classifier
-
-
-
+attr_reader :data_set, :majority_class, :validation_set
+
+parameters_info max_depth: 'Maximum recursion depth. Default is nil (no limit).',
+min_gain: 'Minimum information gain required to split. Default is 0.',
+on_unknown: 'Behaviour when evaluating unseen attribute values: '
+
+# @return [Object]
+def initialize
+super()
+@max_depth = nil
+@min_gain = 0
+@on_unknown = :raise
+end
+
 # Create a new ID3 classifier. You must provide a DataSet instance
 # as parameter. The last attribute of each item is considered as the
 # item class.
-
+# @param data_set [Object]
+# @param options [Object]
+# @return [Object]
+def build(data_set, options = {})
 data_set.check_not_empty
 @data_set = data_set
+@validation_set = options[:validation_set]
 preprocess_data(@data_set.data_items)
-
+prune! if @validation_set
+self
 end

 # You can evaluate new data, predicting its category.
 # e.g.
 # id3.eval(['New York', '<30', 'F']) # => 'Y'
+# @param data [Object]
+# @return [Object]
 def eval(data)
-@tree
+@tree&.value(data, self)
 end

 # This method returns the generated rules in ruby code.
 # e.g.
-#
+#
 # id3.get_rules
 # # => if age_range=='<30' then marketing_target='Y'
 # elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y'
 # elsif age_range=='[30-50)' and city=='New York' then marketing_target='N'
 # elsif age_range=='[50-80]' then marketing_target='N'
 # elsif age_range=='>80' then marketing_target='Y'
-# else
+# else
+# raise 'There was not enough information during training to do '
+# 'a proper induction for this data element'
+# end
 #
-# It is a nice way to inspect induction results, and also to execute them:
+# It is a nice way to inspect induction results, and also to execute them:
 # age_range = '<30'
 # marketing_target = nil
-# eval id3.get_rules
+# eval id3.get_rules
 # puts marketing_target
 # # => 'Y'
+# @return [Object]
 def get_rules
-#return "Empty ID3 tree" if !@tree
+# return "Empty ID3 tree" if !@tree
 rules = @tree.get_rules
 rules = rules.collect do |rule|
-
+"#{rule[0..-2].join(' and ')} then #{rule.last}"
 end
-
+error_msg = 'There was not enough information during training to do a proper induction for this data element'
+"if #{rules.join("\nelsif ")}\nelse raise '#{error_msg}' end"
 end

-
+# Return a nested Hash representation of the decision tree. This
+# structure can easily be converted to JSON or other formats.
+# Leaf nodes are represented by their category value, while internal
+# nodes are hashes keyed by attribute value.
+# @return [Object]
+def to_h
+@tree&.to_h
+end
+
+# Generate GraphViz DOT syntax describing the decision tree. Nodes are
+# labeled with attribute names or category values and edges are labeled
+# with attribute values.
+# @return [Object]
+def to_graphviz
+return 'digraph G {}' unless @tree
+
+lines = ['digraph G {']
+@tree.to_graphviz(0, lines)
+lines << '}'
+lines.join("\n")
+end
+
+# Prune the decision tree using the validation set provided during build.
+# Subtrees are replaced by a single leaf when this increases the
+# classification accuracy on the validation data.
+# @return [Object]
+def prune!
+return self unless @validation_set
+
+@tree = prune_node(@tree, @validation_set.data_items)
+self
+end
+
+# @param data_examples [Object]
+# @return [Object]
 def preprocess_data(data_examples)
-@
+@majority_class = most_freq(data_examples, domain(data_examples))
+@tree = build_node(data_examples, [], 0)
 end

-
-
-
-
-
-
-
-
-return CategoryNode.new(@data_set.
-
-
+# @param data_examples [Object]
+# @param flag_att [Object]
+# @param depth [Object]
+# @return [Object]
+def build_node(data_examples, flag_att = [], depth = 0)
+return ErrorNode.new if data_examples.empty?
+
+domain = domain(data_examples)
+return CategoryNode.new(@data_set.category_label, domain.last[0]) if domain.last.length == 1
+
+if flag_att.length >= domain.length - 1
+return CategoryNode.new(@data_set.category_label,
+most_freq(data_examples,
+domain))
+end
+
+return CategoryNode.new(@data_set.category_label, most_freq(data_examples, domain)) if @max_depth && depth >= @max_depth
+
+best_index = nil
+best_entropy = nil
+best_split = nil
+best_threshold = nil
+numeric = false
+
+domain[0..-2].each_index do |index|
+next if flag_att.include?(index)
+
+if domain[index].all? { |v| v.is_a? Numeric }
+threshold, split, entropy = best_numeric_split(data_examples, index, domain)
+if best_entropy.nil? || entropy < best_entropy
+best_entropy = entropy
+best_index = index
+best_split = split
+best_threshold = threshold
+numeric = true
+end
+else
+freq_grid = freq_grid(index, data_examples, domain)
+entropy = entropy(freq_grid, data_examples.length)
+if best_entropy.nil? || entropy < best_entropy
+best_entropy = entropy
+best_index = index
+best_split = split_data_examples(data_examples, domain, index)
+numeric = false
+end
+end
+end
+
+gain = information_gain(data_examples, domain, best_index)
+if gain < @min_gain
+return CategoryNode.new(@data_set.category_label,
+most_freq(data_examples, domain))
+end
+if best_split.length == 1
+return CategoryNode.new(@data_set.category_label,
+most_freq(data_examples, domain))
+end
+
+nodes = best_split.collect do |partial_data_examples|
+build_node(partial_data_examples, numeric ? flag_att : [*flag_att, best_index], depth + 1)
+end
+majority = most_freq(data_examples, domain)
+
+if numeric
+EvaluationNode.new(@data_set.data_labels, best_index, best_threshold, nodes, true,
+majority)
+else
+EvaluationNode.new(@data_set.data_labels, best_index, domain[best_index], nodes, false,
+majority)
 end
-return EvaluationNode.new(@data_set.data_labels, min_entropy_index, domain[min_entropy_index], nodes)
 end

-
+# @param values [Object]
+# @return [Object]
 def self.sum(values)
-values.
+values.sum
 end

-
+# @param z [Object]
+# @return [Object]
 def self.log2(z)
-return 0.0 if z
-
-
-
-private
-def most_freq(examples, domain)
-freqs = []
-domain.last.length.times { freqs << 0}
-examples.each do |example|
-cat_index = domain.last.index(example.last)
-freq = freqs[cat_index] + 1
-freqs[cat_index] = freq
-end
-max_freq = freqs.max
-max_freq_index = freqs.index(max_freq)
-domain.last[max_freq_index]
+return 0.0 if z.zero?
+
+Math.log(z) / LOG2
 end

 private
+
+# @param examples [Object]
+# @param domain [Object]
+# @return [Object]
+def most_freq(examples, _domain)
+examples.map(&:last).tally.max_by { _2 }&.first
+end
+
+# @param data_examples [Object]
+# @param att_index [Object]
+# @return [Object]
+def split_data_examples_by_value(data_examples, att_index)
+att_value_examples = Hash.new { |hsh, key| hsh[key] = [] }
+data_examples.each do |example|
+att_value = example[att_index]
+att_value_examples[att_value] << example
+end
+att_value_examples
+end
+
+# @param data_examples [Object]
+# @param domain [Object]
+# @param att_index [Object]
+# @return [Object]
 def split_data_examples(data_examples, domain, att_index)
+att_value_examples = split_data_examples_by_value(data_examples, att_index)
+attribute_domain = domain[att_index]
 data_examples_array = []
-att_value_examples
+att_value_examples.each do |att_value, example_set|
+att_value_index = attribute_domain.index(att_value)
+data_examples_array[att_value_index] = example_set
+end
+data_examples_array
+end
+
+# @param data_examples [Object]
+# @param att_index [Object]
+# @param threshold [Object]
+# @return [Object]
+def split_data_examples_numeric(data_examples, att_index, threshold)
+lower = []
+higher = []
 data_examples.each do |example|
-
-
-
-
+if example[att_index] <= threshold
+lower << example
+else
+higher << example
+end
 end
-
-
-
+[lower, higher]
+end
+
+# @param data_examples [Object]
+# @param att_index [Object]
+# @return [Object]
+def candidate_thresholds(data_examples, att_index)
+values = data_examples.collect { |d| d[att_index] }.uniq.sort
+thresholds = []
+values.each_cons(2) { |a, b| thresholds << ((a + b) / 2.0) }
+thresholds
+end
+
+# @param split_data [Object]
+# @param domain [Object]
+# @return [Object]
+def entropy_for_numeric_split(split_data, domain)
+category_domain = domain.last
+grid = split_data.collect do |subset|
+counts = Array.new(category_domain.length, 0)
+subset.each do |example|
+cat_idx = category_domain.index(example.last)
+counts[cat_idx] += 1
+end
+counts
+end
+entropy(grid, split_data[0].length + split_data[1].length)
+end
+
+# @param data_examples [Object]
+# @param att_index [Object]
+# @param domain [Object]
+# @return [Object]
+def best_numeric_split(data_examples, att_index, domain)
+best_threshold = nil
+best_entropy = nil
+best_split = nil
+candidate_thresholds(data_examples, att_index).each do |threshold|
+split = split_data_examples_numeric(data_examples, att_index, threshold)
+e = entropy_for_numeric_split(split, domain)
+next unless best_entropy.nil? || e < best_entropy
+
+best_entropy = e
+best_threshold = threshold
+best_split = split
 end
-
+[best_threshold, best_split, best_entropy]
 end

-
-
+# @param data_examples [Object]
+# @param domain [Object]
+# @param flag_att [Object]
+# @return [Object]
+def min_entropy_index(data_examples, domain, flag_att = [])
 min_entropy = nil
 min_index = 0
 domain[0..-2].each_index do |index|
+next if flag_att.include?(index)
+
 freq_grid = freq_grid(index, data_examples, domain)
 entropy = entropy(freq_grid, data_examples.length)
-if
-min_entropy = entropy
-min_index = index
+if !min_entropy || entropy < min_entropy
+min_entropy = entropy
+min_index = index
 end
 end
-
+min_index
 end

-
+# @param data_examples [Object]
+# @param domain [Object]
+# @param att_index [Object]
+# @return [Object]
+def information_gain(data_examples, domain, att_index)
+total_entropy = class_entropy(data_examples, domain)
+freq_grid_att = freq_grid(att_index, data_examples, domain)
+att_entropy = entropy(freq_grid_att, data_examples.length)
+total_entropy - att_entropy
+end
+
+# @param data_examples [Object]
+# @param domain [Object]
+# @return [Object]
+def class_entropy(data_examples, domain)
+category_domain = domain.last
+freqs = Array.new(category_domain.length, 0)
+data_examples.each do |ex|
+cat = ex.last
+idx = category_domain.index(cat)
+freqs[idx] += 1
+end
+entropy([freqs], data_examples.length)
+end
+
+# @param data_examples [Object]
+# @return [Object]
 def domain(data_examples)
-#return build_domains(data_examples)
-domain = []
-@data_set.data_labels.length.times { domain << [] }
+# return build_domains(data_examples)
+domain = Array.new(@data_set.data_labels.length) { [] }
 data_examples.each do |data|
-data.
-domain[i] <<
+data.each_with_index do |att_value, i|
+domain[i] << att_value if i < domain.length && !domain[i].include?(att_value)
 end
 end
-
+domain
 end
-
-
+
+# @param att_index [Object]
+# @param data_examples [Object]
+# @param domain [Object]
+# @return [Object]
 def freq_grid(att_index, data_examples, domain)
-#Initialize empty grid
-
-domain.last
-grid =
-
-#Fill frecuency with grid
+# Initialize empty grid
+feature_domain = domain[att_index]
+category_domain = domain.last
+grid = Array.new(feature_domain.length) { Array.new(category_domain.length, 0) }
+# Fill frecuency with grid
 data_examples.each do |example|
 att_val = example[att_index]
-att_val_index =
+att_val_index = feature_domain.index(att_val)
 category = example.last
-category_index =
-
-grid[att_val_index][category_index] = freq
+category_index = category_domain.index(category)
+grid[att_val_index][category_index] += 1
 end
-
+grid
 end

-
+# @param freq_grid [Object]
+# @param total_examples [Object]
+# @return [Object]
 def entropy(freq_grid, total_examples)
-#Calc entropy of each element
+# Calc entropy of each element
 entropy = 0
 freq_grid.each do |att_freq|
 att_total_freq = ID3.sum(att_freq)
 partial_entropy = 0
-
+unless att_total_freq.zero?
 att_freq.each do |freq|
-prop = freq.to_f/att_total_freq
-partial_entropy += (-1*prop*ID3.log2(prop))
+prop = freq.to_f / att_total_freq
+partial_entropy += (-1 * prop * ID3.log2(prop))
 end
 end
-entropy += (att_total_freq.to_f/total_examples) * partial_entropy
+entropy += (att_total_freq.to_f / total_examples) * partial_entropy
 end
-
+entropy
+end
+
+# @param node [Object]
+# @param examples [Object]
+# @return [Object]
+def prune_node(node, examples)
+return node if node.is_a?(CategoryNode) || node.is_a?(ErrorNode)
+
+subsets = split_examples(node, examples)
+
+node.nodes.each_with_index do |child, i|
+node.nodes[i] = prune_node(child, subsets[i])
+end
+
+leaf = CategoryNode.new(@data_set.category_label, node.majority)
+replace_with_leaf?(leaf, node, examples) ? leaf : node
+end
+
+def split_examples(node, examples)
+if node.numeric
+Array.new(2) { [] }.tap do |subsets|
+examples.each do |ex|
+idx = ex[node.index] <= node.threshold ? 0 : 1
+subsets[idx] << ex
+end
+end
+else
+Array.new(node.values.length) { [] }.tap do |subsets|
+examples.each do |ex|
+idx = node.values.index(ex[node.index])
+subsets[idx] << ex if idx
+end
+end
+end
+end
+
+def replace_with_leaf?(leaf, node, examples)
+before = accuracy_for_node(node, examples)
+after = accuracy_for_node(leaf, examples)
+after && before && after >= before
+end
+
+# @param node [Object]
+# @param examples [Object]
+# @return [Object]
+def accuracy_for_node(node, examples)
+return nil if examples.empty?
+
+correct = examples.count do |ex|
+node.value(ex[0..-2], self) == ex.last
+end
+correct.to_f / examples.length
 end

-private
 LOG2 = Math.log(2)
 end

-class EvaluationNode
-
-
-
-
+class EvaluationNode # :nodoc: all
+attr_reader :index, :values, :nodes, :numeric, :threshold, :majority
+
+# @param data_labels [Object]
+# @param index [Object]
+# @param values_or_threshold [Object]
+# @param nodes [Object]
+# @param numeric [Object]
+# @param majority [Object]
+# @return [Object]
+def initialize(data_labels, index, values_or_threshold, nodes, numeric = false,
+majority = nil)
 @index = index
-@
+@numeric = numeric
+if numeric
+@threshold = values_or_threshold
+@values = nil
+else
+@values = values_or_threshold
+end
 @nodes = nodes
+@majority = majority
 @data_labels = data_labels
 end
-
-
+
+# @param data [Object]
+# @param classifier [Object]
+# @return [Object]
+def value(data, classifier = nil)
 value = data[@index]
-
-
+if @numeric
+node = value <= @threshold ? @nodes[0] : @nodes[1]
+node.value(data, classifier)
+else
+unless @values.include?(value)
+return nil if classifier&.on_unknown == :nil
+return @majority if classifier&.on_unknown == :most_frequent
+
+return ErrorNode.new.value(data, classifier)
+end
+@nodes[@values.index(value)].value(data, classifier)
+end
 end
-
+
+# @return [Object]
 def get_rules
 rule_set = []
-@nodes.
-
-
+@nodes.each_with_index do |child_node, child_node_index|
+if @numeric
+op = child_node_index.zero? ? '<=' : '>'
+my_rule = "#{@data_labels[@index]} #{op} #{@threshold}"
+else
+my_rule = "#{@data_labels[@index]}=='#{@values[child_node_index]}'"
+end
 child_node_rules = child_node.get_rules
 child_node_rules.each do |child_rule|
 child_rule.unshift(my_rule)
 end
 rule_set += child_node_rules
 end
-
+rule_set
+end
+
+# @return [Object]
+def to_h
+hash = {}
+@nodes.each_with_index do |child, i|
+hash[@values[i]] = child.to_h
+end
+{ @data_labels[@index] => hash }
+end
+
+# @param id [Object]
+# @param lines [Object]
+# @param parent [Object]
+# @param edge_label [Object]
+# @return [Object]
+def to_graphviz(id, lines, parent = nil, edge_label = nil)
+my_id = id
+lines << " node#{my_id} [label=\"#{@data_labels[@index]}\"]"
+lines << " node#{parent} -> node#{my_id} [label=\"#{edge_label}\"]" if parent
+next_id = my_id
+@nodes.each_with_index do |child, idx|
+next_id += 1
+next_id = child.to_graphviz(next_id, lines, my_id, @values[idx])
+end
+next_id
 end
-
 end

-class CategoryNode
+class CategoryNode # :nodoc: all
+# @param label [Object]
+# @param value [Object]
+# @return [Object]
 def initialize(label, value)
 @label = label
 @value = value
 end
-
-
+
+# @param data [Object]
+# @param classifier [Object]
+# @return [Object]
+def value(_data, _classifier = nil)
+@value
 end
+
+# @return [Object]
 def get_rules
-
+[["#{@label}='#{@value}'"]]
+end
+
+# @return [Object]
+def to_h
+@value
+end
+
+# @param id [Object]
+# @param lines [Object]
+# @param parent [Object]
+# @param edge_label [Object]
+# @return [Object]
+def to_graphviz(id, lines, parent = nil, edge_label = nil)
+my_id = id
+lines << " node#{my_id} [label=\"#{@value}\", shape=box]"
+lines << " node#{parent} -> node#{my_id} [label=\"#{edge_label}\"]" if parent
+my_id
 end
 end

+# Raised when the training data is insufficient to build a model.
 class ModelFailureError < StandardError
-
+MSG = 'There was not enough information during training to do a proper ' \
+'induction for this data element.'
 end

-class ErrorNode
-
-
+class ErrorNode # :nodoc: all
+# @param data [Object]
+# @param classifier [Object]
+# @return [Object]
+def value(data, _classifier = nil)
+raise ModelFailureError, "#{ModelFailureError::MSG} for the data element #{data}."
 end
+
+# @return [Object]
 def get_rules
-
+[]
 end
-end

+# @return [Object]
+def to_h
+nil
+end
+
+# @param id [Object]
+# @param lines [Object]
+# @param parent [Object]
+# @param edge_label [Object]
+# @return [Object]
+def to_graphviz(id, lines, parent = nil, edge_label = nil)
+my_id = id
+lines << " node#{my_id} [label=\"?\", shape=box]"
+lines << " node#{parent} -> node#{my_id} [label=\"#{edge_label}\"]" if parent
+my_id
+end
+end
 end
 end
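The rewritten ID3 classifier above adds tuning parameters (max_depth, min_gain, on_unknown), optional post-pruning against a validation set passed to build, and to_h / to_graphviz exporters. The following is a minimal sketch of how those options could be exercised, not code taken from the gem: it assumes ai4r's existing Parameterizable#set_parameters helper and uses made-up sample data.

require 'ai4r'

# Hypothetical training data, shaped like the DATA_ITEMS sample in the class docs.
labels = %w[city age_range gender marketing_target]
items  = [['New York', '<30', 'M', 'Y'],
          ['Chicago',  '<30', 'F', 'Y'],
          ['New York', '[30-50)', 'M', 'N'],
          ['Chicago',  '>80', 'F', 'Y']]
train = Ai4r::Data::DataSet.new(data_items: items, data_labels: labels)

id3 = Ai4r::Classifiers::ID3.new
# New 2.0 parameters; :most_frequent falls back to the majority value on unseen attributes.
id3.set_parameters(max_depth: 3, min_gain: 0.0, on_unknown: :most_frequent)
# Passing :validation_set triggers prune! after training; the training set is reused
# here only to keep the sketch short.
id3.build(train, validation_set: train)

puts id3.eval(['New York', '<30', 'M'])  # predicted marketing_target
puts id3.get_rules                       # induced rules as Ruby code
puts id3.to_graphviz                     # GraphViz DOT output of the tree

The data/examples/classifiers/id3_graphviz_example.rb and parameter_tutorial.rb scripts added in this release presumably exercise the same API in more detail.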