ai4r 1.13 → 2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. checksums.yaml +7 -0
  2. data/README.md +174 -0
  3. data/examples/classifiers/hyperpipes_data.csv +14 -0
  4. data/examples/classifiers/hyperpipes_example.rb +22 -0
  5. data/examples/classifiers/ib1_example.rb +12 -0
  6. data/examples/classifiers/id3_example.rb +15 -10
  7. data/examples/classifiers/id3_graphviz_example.rb +17 -0
  8. data/examples/classifiers/logistic_regression_example.rb +11 -0
  9. data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
  10. data/examples/classifiers/naive_bayes_example.rb +12 -13
  11. data/examples/classifiers/one_r_example.rb +27 -0
  12. data/examples/classifiers/parameter_tutorial.rb +29 -0
  13. data/examples/classifiers/prism_nominal_example.rb +15 -0
  14. data/examples/classifiers/prism_numeric_example.rb +21 -0
  15. data/examples/classifiers/simple_linear_regression_example.rb +14 -11
  16. data/examples/classifiers/zero_and_one_r_example.rb +34 -0
  17. data/examples/classifiers/zero_one_r_data.csv +8 -0
  18. data/examples/clusterers/clusterer_example.rb +40 -34
  19. data/examples/clusterers/dbscan_example.rb +17 -0
  20. data/examples/clusterers/dendrogram_example.rb +17 -0
  21. data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
  22. data/examples/clusterers/kmeans_custom_example.rb +26 -0
  23. data/examples/genetic_algorithm/bitstring_example.rb +41 -0
  24. data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
  25. data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
  26. data/examples/neural_network/backpropagation_example.rb +48 -48
  27. data/examples/neural_network/hopfield_example.rb +45 -0
  28. data/examples/neural_network/patterns_with_base_noise.rb +39 -39
  29. data/examples/neural_network/patterns_with_noise.rb +41 -39
  30. data/examples/neural_network/train_epochs_callback.rb +25 -0
  31. data/examples/neural_network/training_patterns.rb +39 -39
  32. data/examples/neural_network/transformer_text_classification.rb +78 -0
  33. data/examples/neural_network/xor_example.rb +23 -22
  34. data/examples/reinforcement/q_learning_example.rb +10 -0
  35. data/examples/som/som_data.rb +155 -152
  36. data/examples/som/som_multi_node_example.rb +12 -13
  37. data/examples/som/som_single_example.rb +12 -15
  38. data/examples/transformer/decode_classifier_example.rb +68 -0
  39. data/examples/transformer/deterministic_example.rb +10 -0
  40. data/examples/transformer/seq2seq_example.rb +16 -0
  41. data/lib/ai4r/classifiers/classifier.rb +24 -16
  42. data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
  43. data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
  44. data/lib/ai4r/classifiers/ib1.rb +122 -32
  45. data/lib/ai4r/classifiers/id3.rb +524 -145
  46. data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
  47. data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
  48. data/lib/ai4r/classifiers/naive_bayes.rb +95 -34
  49. data/lib/ai4r/classifiers/one_r.rb +112 -44
  50. data/lib/ai4r/classifiers/prism.rb +167 -76
  51. data/lib/ai4r/classifiers/random_forest.rb +72 -0
  52. data/lib/ai4r/classifiers/simple_linear_regression.rb +83 -58
  53. data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
  54. data/lib/ai4r/classifiers/votes.rb +57 -0
  55. data/lib/ai4r/classifiers/zero_r.rb +71 -30
  56. data/lib/ai4r/clusterers/average_linkage.rb +46 -27
  57. data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
  58. data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
  59. data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
  60. data/lib/ai4r/clusterers/clusterer.rb +29 -14
  61. data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
  62. data/lib/ai4r/clusterers/dbscan.rb +134 -0
  63. data/lib/ai4r/clusterers/diana.rb +75 -49
  64. data/lib/ai4r/clusterers/k_means.rb +270 -135
  65. data/lib/ai4r/clusterers/median_linkage.rb +49 -33
  66. data/lib/ai4r/clusterers/single_linkage.rb +196 -88
  67. data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
  68. data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +25 -10
  69. data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
  70. data/lib/ai4r/data/data_set.rb +223 -103
  71. data/lib/ai4r/data/parameterizable.rb +31 -25
  72. data/lib/ai4r/data/proximity.rb +62 -62
  73. data/lib/ai4r/data/statistics.rb +46 -35
  74. data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
  75. data/lib/ai4r/experiment/split.rb +39 -0
  76. data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
  77. data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
  78. data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
  79. data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
  80. data/lib/ai4r/neural_network/activation_functions.rb +37 -0
  81. data/lib/ai4r/neural_network/backpropagation.rb +399 -134
  82. data/lib/ai4r/neural_network/hopfield.rb +175 -58
  83. data/lib/ai4r/neural_network/transformer.rb +194 -0
  84. data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
  85. data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
  86. data/lib/ai4r/reinforcement/q_learning.rb +51 -0
  87. data/lib/ai4r/search/a_star.rb +76 -0
  88. data/lib/ai4r/search/bfs.rb +50 -0
  89. data/lib/ai4r/search/dfs.rb +50 -0
  90. data/lib/ai4r/search/mcts.rb +118 -0
  91. data/lib/ai4r/search.rb +12 -0
  92. data/lib/ai4r/som/distance_metrics.rb +29 -0
  93. data/lib/ai4r/som/layer.rb +28 -17
  94. data/lib/ai4r/som/node.rb +61 -32
  95. data/lib/ai4r/som/som.rb +158 -41
  96. data/lib/ai4r/som/two_phase_layer.rb +21 -25
  97. data/lib/ai4r/version.rb +3 -0
  98. data/lib/ai4r.rb +57 -28
  99. metadata +79 -109
  100. data/README.rdoc +0 -39
  101. data/test/classifiers/hyperpipes_test.rb +0 -84
  102. data/test/classifiers/ib1_test.rb +0 -78
  103. data/test/classifiers/id3_test.rb +0 -220
  104. data/test/classifiers/multilayer_perceptron_test.rb +0 -79
  105. data/test/classifiers/naive_bayes_test.rb +0 -43
  106. data/test/classifiers/one_r_test.rb +0 -62
  107. data/test/classifiers/prism_test.rb +0 -85
  108. data/test/classifiers/simple_linear_regression_test.rb +0 -37
  109. data/test/classifiers/zero_r_test.rb +0 -50
  110. data/test/clusterers/average_linkage_test.rb +0 -51
  111. data/test/clusterers/bisecting_k_means_test.rb +0 -66
  112. data/test/clusterers/centroid_linkage_test.rb +0 -53
  113. data/test/clusterers/complete_linkage_test.rb +0 -57
  114. data/test/clusterers/diana_test.rb +0 -69
  115. data/test/clusterers/k_means_test.rb +0 -167
  116. data/test/clusterers/median_linkage_test.rb +0 -53
  117. data/test/clusterers/single_linkage_test.rb +0 -122
  118. data/test/clusterers/ward_linkage_hierarchical_test.rb +0 -81
  119. data/test/clusterers/ward_linkage_test.rb +0 -53
  120. data/test/clusterers/weighted_average_linkage_test.rb +0 -53
  121. data/test/data/data_set_test.rb +0 -104
  122. data/test/data/proximity_test.rb +0 -87
  123. data/test/data/statistics_test.rb +0 -65
  124. data/test/experiment/classifier_evaluator_test.rb +0 -76
  125. data/test/genetic_algorithm/chromosome_test.rb +0 -57
  126. data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
  127. data/test/neural_network/backpropagation_test.rb +0 -82
  128. data/test/neural_network/hopfield_test.rb +0 -72
  129. data/test/som/som_test.rb +0 -97
@@ -1,34 +1,34 @@
1
- # Author:: Sergio Fierens (Implementation, Quinlan is
1
+ # frozen_string_literal: true
2
+
3
+ # Author:: Sergio Fierens (Implementation, Quinlan is
2
4
  # the creator of the algorithm)
3
5
  # License:: MPL 1.1
4
6
  # Project:: ai4r
5
- # Url:: http://ai4r.org/
7
+ # Url:: https://github.com/SergioFierens/ai4r
6
8
  #
7
- # You can redistribute it and/or modify it under the terms of
8
- # the Mozilla Public License version 1.1 as published by the
9
+ # You can redistribute it and/or modify it under the terms of
10
+ # the Mozilla Public License version 1.1 as published by the
9
11
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
10
12
 
11
- require File.dirname(__FILE__) + '/../data/data_set'
12
- require File.dirname(__FILE__) + '/../classifiers/classifier'
13
+ require_relative '../data/data_set'
14
+ require_relative '../classifiers/classifier'
13
15
 
14
16
  module Ai4r
15
-
16
17
  module Classifiers
17
-
18
18
  # = Introduction
19
- # This is an implementation of the ID3 algorithm (Quinlan)
20
- # Given a set of preclassified examples, it builds a top-down
21
- # induction of decision tree, biased by the information gain and
19
+ # This is an implementation of the ID3 algorithm (Quinlan)
20
+ # Given a set of preclassified examples, it builds a top-down
21
+ # induction of decision tree, biased by the information gain and
22
22
  # entropy measure.
23
23
  #
24
24
  # * http://en.wikipedia.org/wiki/Decision_tree
25
25
  # * http://en.wikipedia.org/wiki/ID3_algorithm
26
26
  #
27
27
  # = How to use it
28
- #
28
+ #
29
29
  # DATA_LABELS = [ 'city', 'age_range', 'gender', 'marketing_target' ]
30
30
  #
31
- # DATA_ITEMS = [
31
+ # DATA_ITEMS = [
32
32
  # ['New York', '<30', 'M', 'Y'],
33
33
  # ['Chicago', '<30', 'M', 'Y'],
34
34
  # ['Chicago', '<30', 'F', 'Y'],
@@ -45,286 +45,665 @@ module Ai4r
45
45
  # ['New York', '[50-80]', 'F', 'N'],
46
46
  # ['Chicago', '>80', 'F', 'Y']
47
47
  # ]
48
- #
48
+ #
49
49
  # data_set = DataSet.new(:data_items=>DATA_SET, :data_labels=>DATA_LABELS)
50
50
  # id3 = Ai4r::Classifiers::ID3.new.build(data_set)
51
- #
51
+ #
52
52
  # id3.get_rules
53
53
  # # => if age_range=='<30' then marketing_target='Y'
54
54
  # elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y'
55
55
  # elsif age_range=='[30-50)' and city=='New York' then marketing_target='N'
56
56
  # elsif age_range=='[50-80]' then marketing_target='N'
57
57
  # elsif age_range=='>80' then marketing_target='Y'
58
- # else raise 'There was not enough information during training to do a proper induction for this data element' end
59
- #
58
+ # else
59
+ # raise 'There was not enough information during training to do '
60
+ # 'a proper induction for this data element'
61
+ # end
62
+ #
60
63
  # id3.eval(['New York', '<30', 'M'])
61
64
  # # => 'Y'
62
- #
63
- # = A better way to load the data
64
- #
65
+ #
66
+ # = A better way to load the data
67
+ #
65
68
  # In the real life you will use lot more data training examples, with more
66
- # attributes. Consider moving your data to an external CSV (comma separate
69
+ # attributes. Consider moving your data to an external CSV (comma separate
67
70
  # values) file.
68
- #
71
+ #
69
72
  # data_file = "#{File.dirname(__FILE__)}/data_set.csv"
70
73
  # data_set = DataSet.load_csv_with_labels data_file
71
- # id3 = Ai4r::Classifiers::ID3.new.build(data_set)
72
- #
74
+ # id3 = Ai4r::Classifiers::ID3.new.build(data_set)
75
+ #
73
76
  # = A nice tip for data evaluation
74
- #
77
+ #
75
78
  # id3 = Ai4r::Classifiers::ID3.new.build(data_set)
76
79
  #
77
80
  # age_range = '<30'
78
81
  # marketing_target = nil
79
- # eval id3.get_rules
82
+ # eval id3.get_rules
80
83
  # puts marketing_target
81
- # # => 'Y'
84
+ # # => 'Y'
82
85
  #
83
86
  # = More about ID3 and decision trees
84
- #
87
+ #
85
88
  # * http://en.wikipedia.org/wiki/Decision_tree
86
89
  # * http://en.wikipedia.org/wiki/ID3_algorithm
87
- #
90
+ #
88
91
  # = About the project
89
92
  # Author:: Sergio Fierens
90
93
  # License:: MPL 1.1
91
- # Url:: http://ai4r.org/
94
+ # Url:: https://github.com/SergioFierens/ai4r
92
95
  class ID3 < Classifier
93
-
94
- attr_reader :data_set
95
-
96
+ attr_reader :data_set, :majority_class, :validation_set
97
+
98
+ parameters_info max_depth: 'Maximum recursion depth. Default is nil (no limit).',
99
+ min_gain: 'Minimum information gain required to split. Default is 0.',
100
+ on_unknown: 'Behaviour when evaluating unseen attribute values: '
101
+
102
+ # @return [Object]
103
+ def initialize
104
+ super()
105
+ @max_depth = nil
106
+ @min_gain = 0
107
+ @on_unknown = :raise
108
+ end
109
+
96
110
  # Create a new ID3 classifier. You must provide a DataSet instance
97
111
  # as parameter. The last attribute of each item is considered as the
98
112
  # item class.
99
- def build(data_set)
113
+ # @param data_set [Object]
114
+ # @param options [Object]
115
+ # @return [Object]
116
+ def build(data_set, options = {})
100
117
  data_set.check_not_empty
101
118
  @data_set = data_set
119
+ @validation_set = options[:validation_set]
102
120
  preprocess_data(@data_set.data_items)
103
- return self
121
+ prune! if @validation_set
122
+ self
104
123
  end
105
124
 
106
125
  # You can evaluate new data, predicting its category.
107
126
  # e.g.
108
127
  # id3.eval(['New York', '<30', 'F']) # => 'Y'
128
+ # @param data [Object]
129
+ # @return [Object]
109
130
  def eval(data)
110
- @tree.value(data) if @tree
131
+ @tree&.value(data, self)
111
132
  end
112
133
 
113
134
  # This method returns the generated rules in ruby code.
114
135
  # e.g.
115
- #
136
+ #
116
137
  # id3.get_rules
117
138
  # # => if age_range=='<30' then marketing_target='Y'
118
139
  # elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y'
119
140
  # elsif age_range=='[30-50)' and city=='New York' then marketing_target='N'
120
141
  # elsif age_range=='[50-80]' then marketing_target='N'
121
142
  # elsif age_range=='>80' then marketing_target='Y'
122
- # else raise 'There was not enough information during training to do a proper induction for this data element' end
143
+ # else
144
+ # raise 'There was not enough information during training to do '
145
+ # 'a proper induction for this data element'
146
+ # end
123
147
  #
124
- # It is a nice way to inspect induction results, and also to execute them:
148
+ # It is a nice way to inspect induction results, and also to execute them:
125
149
  # age_range = '<30'
126
150
  # marketing_target = nil
127
- # eval id3.get_rules
151
+ # eval id3.get_rules
128
152
  # puts marketing_target
129
153
  # # => 'Y'
154
+ # @return [Object]
130
155
  def get_rules
131
- #return "Empty ID3 tree" if !@tree
156
+ # return "Empty ID3 tree" if !@tree
132
157
  rules = @tree.get_rules
133
158
  rules = rules.collect do |rule|
134
- "#{rule[0..-2].join(' and ')} then #{rule.last}"
159
+ "#{rule[0..-2].join(' and ')} then #{rule.last}"
135
160
  end
136
- return "if #{rules.join("\nelsif ")}\nelse raise 'There was not enough information during training to do a proper induction for this data element' end"
161
+ error_msg = 'There was not enough information during training to do a proper induction for this data element'
162
+ "if #{rules.join("\nelsif ")}\nelse raise '#{error_msg}' end"
137
163
  end
138
164
 
139
- private
165
+ # Return a nested Hash representation of the decision tree. This
166
+ # structure can easily be converted to JSON or other formats.
167
+ # Leaf nodes are represented by their category value, while internal
168
+ # nodes are hashes keyed by attribute value.
169
+ # @return [Object]
170
+ def to_h
171
+ @tree&.to_h
172
+ end
173
+
174
+ # Generate GraphViz DOT syntax describing the decision tree. Nodes are
175
+ # labeled with attribute names or category values and edges are labeled
176
+ # with attribute values.
177
+ # @return [Object]
178
+ def to_graphviz
179
+ return 'digraph G {}' unless @tree
180
+
181
+ lines = ['digraph G {']
182
+ @tree.to_graphviz(0, lines)
183
+ lines << '}'
184
+ lines.join("\n")
185
+ end
186
+
187
+ # Prune the decision tree using the validation set provided during build.
188
+ # Subtrees are replaced by a single leaf when this increases the
189
+ # classification accuracy on the validation data.
190
+ # @return [Object]
191
+ def prune!
192
+ return self unless @validation_set
193
+
194
+ @tree = prune_node(@tree, @validation_set.data_items)
195
+ self
196
+ end
197
+
198
+ # @param data_examples [Object]
199
+ # @return [Object]
140
200
  def preprocess_data(data_examples)
141
- @tree = build_node(data_examples)
201
+ @majority_class = most_freq(data_examples, domain(data_examples))
202
+ @tree = build_node(data_examples, [], 0)
142
203
  end
143
204
 
144
- private
145
- def build_node(data_examples, flag_att = [])
146
- return ErrorNode.new if data_examples.length == 0
147
- domain = domain(data_examples)
148
- return CategoryNode.new(@data_set.data_labels.last, domain.last[0]) if domain.last.length == 1
149
- min_entropy_index = min_entropy_index(data_examples, domain, flag_att)
150
- flag_att << min_entropy_index
151
- split_data_examples = split_data_examples(data_examples, domain, min_entropy_index)
152
- return CategoryNode.new(@data_set.data_labels.last, most_freq(data_examples, domain)) if split_data_examples.length == 1
153
- nodes = split_data_examples.collect do |partial_data_examples|
154
- build_node(partial_data_examples, flag_att)
205
+ # @param data_examples [Object]
206
+ # @param flag_att [Object]
207
+ # @param depth [Object]
208
+ # @return [Object]
209
+ def build_node(data_examples, flag_att = [], depth = 0)
210
+ return ErrorNode.new if data_examples.empty?
211
+
212
+ domain = domain(data_examples)
213
+ return CategoryNode.new(@data_set.category_label, domain.last[0]) if domain.last.length == 1
214
+
215
+ if flag_att.length >= domain.length - 1
216
+ return CategoryNode.new(@data_set.category_label,
217
+ most_freq(data_examples,
218
+ domain))
219
+ end
220
+
221
+ return CategoryNode.new(@data_set.category_label, most_freq(data_examples, domain)) if @max_depth && depth >= @max_depth
222
+
223
+ best_index = nil
224
+ best_entropy = nil
225
+ best_split = nil
226
+ best_threshold = nil
227
+ numeric = false
228
+
229
+ domain[0..-2].each_index do |index|
230
+ next if flag_att.include?(index)
231
+
232
+ if domain[index].all? { |v| v.is_a? Numeric }
233
+ threshold, split, entropy = best_numeric_split(data_examples, index, domain)
234
+ if best_entropy.nil? || entropy < best_entropy
235
+ best_entropy = entropy
236
+ best_index = index
237
+ best_split = split
238
+ best_threshold = threshold
239
+ numeric = true
240
+ end
241
+ else
242
+ freq_grid = freq_grid(index, data_examples, domain)
243
+ entropy = entropy(freq_grid, data_examples.length)
244
+ if best_entropy.nil? || entropy < best_entropy
245
+ best_entropy = entropy
246
+ best_index = index
247
+ best_split = split_data_examples(data_examples, domain, index)
248
+ numeric = false
249
+ end
250
+ end
251
+ end
252
+
253
+ gain = information_gain(data_examples, domain, best_index)
254
+ if gain < @min_gain
255
+ return CategoryNode.new(@data_set.category_label,
256
+ most_freq(data_examples, domain))
257
+ end
258
+ if best_split.length == 1
259
+ return CategoryNode.new(@data_set.category_label,
260
+ most_freq(data_examples, domain))
261
+ end
262
+
263
+ nodes = best_split.collect do |partial_data_examples|
264
+ build_node(partial_data_examples, numeric ? flag_att : [*flag_att, best_index], depth + 1)
265
+ end
266
+ majority = most_freq(data_examples, domain)
267
+
268
+ if numeric
269
+ EvaluationNode.new(@data_set.data_labels, best_index, best_threshold, nodes, true,
270
+ majority)
271
+ else
272
+ EvaluationNode.new(@data_set.data_labels, best_index, domain[best_index], nodes, false,
273
+ majority)
155
274
  end
156
- return EvaluationNode.new(@data_set.data_labels, min_entropy_index, domain[min_entropy_index], nodes)
157
275
  end
158
276
 
159
- private
277
+ # @param values [Object]
278
+ # @return [Object]
160
279
  def self.sum(values)
161
- values.inject( 0 ) { |sum,x| sum+x }
280
+ values.sum
162
281
  end
163
282
 
164
- private
283
+ # @param z [Object]
284
+ # @return [Object]
165
285
  def self.log2(z)
166
- return 0.0 if z == 0
167
- Math.log(z)/LOG2
168
- end
169
-
170
- private
171
- def most_freq(examples, domain)
172
- freqs = []
173
- domain.last.length.times { freqs << 0}
174
- examples.each do |example|
175
- cat_index = domain.last.index(example.last)
176
- freq = freqs[cat_index] + 1
177
- freqs[cat_index] = freq
178
- end
179
- max_freq = freqs.max
180
- max_freq_index = freqs.index(max_freq)
181
- domain.last[max_freq_index]
286
+ return 0.0 if z.zero?
287
+
288
+ Math.log(z) / LOG2
182
289
  end
183
290
 
184
291
  private
292
+
293
+ # @param examples [Object]
294
+ # @param domain [Object]
295
+ # @return [Object]
296
+ def most_freq(examples, _domain)
297
+ examples.map(&:last).tally.max_by { _2 }&.first
298
+ end
299
+
300
+ # @param data_examples [Object]
301
+ # @param att_index [Object]
302
+ # @return [Object]
303
+ def split_data_examples_by_value(data_examples, att_index)
304
+ att_value_examples = Hash.new { |hsh, key| hsh[key] = [] }
305
+ data_examples.each do |example|
306
+ att_value = example[att_index]
307
+ att_value_examples[att_value] << example
308
+ end
309
+ att_value_examples
310
+ end
311
+
312
+ # @param data_examples [Object]
313
+ # @param domain [Object]
314
+ # @param att_index [Object]
315
+ # @return [Object]
185
316
  def split_data_examples(data_examples, domain, att_index)
317
+ att_value_examples = split_data_examples_by_value(data_examples, att_index)
318
+ attribute_domain = domain[att_index]
186
319
  data_examples_array = []
187
- att_value_examples = {}
320
+ att_value_examples.each do |att_value, example_set|
321
+ att_value_index = attribute_domain.index(att_value)
322
+ data_examples_array[att_value_index] = example_set
323
+ end
324
+ data_examples_array
325
+ end
326
+
327
+ # @param data_examples [Object]
328
+ # @param att_index [Object]
329
+ # @param threshold [Object]
330
+ # @return [Object]
331
+ def split_data_examples_numeric(data_examples, att_index, threshold)
332
+ lower = []
333
+ higher = []
188
334
  data_examples.each do |example|
189
- example_set = att_value_examples[example[att_index]]
190
- example_set = [] if !example_set
191
- example_set << example
192
- att_value_examples.store(example[att_index], example_set)
335
+ if example[att_index] <= threshold
336
+ lower << example
337
+ else
338
+ higher << example
339
+ end
193
340
  end
194
- att_value_examples.each_pair do |att_value, example_set|
195
- att_value_index = domain[att_index].index(att_value)
196
- data_examples_array[att_value_index] = example_set
341
+ [lower, higher]
342
+ end
343
+
344
+ # @param data_examples [Object]
345
+ # @param att_index [Object]
346
+ # @return [Object]
347
+ def candidate_thresholds(data_examples, att_index)
348
+ values = data_examples.collect { |d| d[att_index] }.uniq.sort
349
+ thresholds = []
350
+ values.each_cons(2) { |a, b| thresholds << ((a + b) / 2.0) }
351
+ thresholds
352
+ end
353
+
354
+ # @param split_data [Object]
355
+ # @param domain [Object]
356
+ # @return [Object]
357
+ def entropy_for_numeric_split(split_data, domain)
358
+ category_domain = domain.last
359
+ grid = split_data.collect do |subset|
360
+ counts = Array.new(category_domain.length, 0)
361
+ subset.each do |example|
362
+ cat_idx = category_domain.index(example.last)
363
+ counts[cat_idx] += 1
364
+ end
365
+ counts
366
+ end
367
+ entropy(grid, split_data[0].length + split_data[1].length)
368
+ end
369
+
370
+ # @param data_examples [Object]
371
+ # @param att_index [Object]
372
+ # @param domain [Object]
373
+ # @return [Object]
374
+ def best_numeric_split(data_examples, att_index, domain)
375
+ best_threshold = nil
376
+ best_entropy = nil
377
+ best_split = nil
378
+ candidate_thresholds(data_examples, att_index).each do |threshold|
379
+ split = split_data_examples_numeric(data_examples, att_index, threshold)
380
+ e = entropy_for_numeric_split(split, domain)
381
+ next unless best_entropy.nil? || e < best_entropy
382
+
383
+ best_entropy = e
384
+ best_threshold = threshold
385
+ best_split = split
197
386
  end
198
- return data_examples_array
387
+ [best_threshold, best_split, best_entropy]
199
388
  end
200
389
 
201
- private
202
- def min_entropy_index(data_examples, domain, flag_att=[])
390
+ # @param data_examples [Object]
391
+ # @param domain [Object]
392
+ # @param flag_att [Object]
393
+ # @return [Object]
394
+ def min_entropy_index(data_examples, domain, flag_att = [])
203
395
  min_entropy = nil
204
396
  min_index = 0
205
397
  domain[0..-2].each_index do |index|
398
+ next if flag_att.include?(index)
399
+
206
400
  freq_grid = freq_grid(index, data_examples, domain)
207
401
  entropy = entropy(freq_grid, data_examples.length)
208
- if (!min_entropy || entropy < min_entropy) && !flag_att.include?(index)
209
- min_entropy = entropy
210
- min_index = index
402
+ if !min_entropy || entropy < min_entropy
403
+ min_entropy = entropy
404
+ min_index = index
211
405
  end
212
406
  end
213
- return min_index
407
+ min_index
214
408
  end
215
409
 
216
- private
410
+ # @param data_examples [Object]
411
+ # @param domain [Object]
412
+ # @param att_index [Object]
413
+ # @return [Object]
414
+ def information_gain(data_examples, domain, att_index)
415
+ total_entropy = class_entropy(data_examples, domain)
416
+ freq_grid_att = freq_grid(att_index, data_examples, domain)
417
+ att_entropy = entropy(freq_grid_att, data_examples.length)
418
+ total_entropy - att_entropy
419
+ end
420
+
421
+ # @param data_examples [Object]
422
+ # @param domain [Object]
423
+ # @return [Object]
424
+ def class_entropy(data_examples, domain)
425
+ category_domain = domain.last
426
+ freqs = Array.new(category_domain.length, 0)
427
+ data_examples.each do |ex|
428
+ cat = ex.last
429
+ idx = category_domain.index(cat)
430
+ freqs[idx] += 1
431
+ end
432
+ entropy([freqs], data_examples.length)
433
+ end
434
+
435
+ # @param data_examples [Object]
436
+ # @return [Object]
217
437
  def domain(data_examples)
218
- #return build_domains(data_examples)
219
- domain = []
220
- @data_set.data_labels.length.times { domain << [] }
438
+ # return build_domains(data_examples)
439
+ domain = Array.new(@data_set.data_labels.length) { [] }
221
440
  data_examples.each do |data|
222
- data.each_index do |i|
223
- domain[i] << data[i] if i<domain.length && !domain[i].include?(data[i])
441
+ data.each_with_index do |att_value, i|
442
+ domain[i] << att_value if i < domain.length && !domain[i].include?(att_value)
224
443
  end
225
444
  end
226
- return domain
445
+ domain
227
446
  end
228
-
229
- private
447
+
448
+ # @param att_index [Object]
449
+ # @param data_examples [Object]
450
+ # @param domain [Object]
451
+ # @return [Object]
230
452
  def freq_grid(att_index, data_examples, domain)
231
- #Initialize empty grid
232
- grid_element = []
233
- domain.last.length.times { grid_element << 0}
234
- grid = []
235
- domain[att_index].length.times { grid << grid_element.clone }
236
- #Fill frecuency with grid
453
+ # Initialize empty grid
454
+ feature_domain = domain[att_index]
455
+ category_domain = domain.last
456
+ grid = Array.new(feature_domain.length) { Array.new(category_domain.length, 0) }
457
+ # Fill frecuency with grid
237
458
  data_examples.each do |example|
238
459
  att_val = example[att_index]
239
- att_val_index = domain[att_index].index(att_val)
460
+ att_val_index = feature_domain.index(att_val)
240
461
  category = example.last
241
- category_index = domain.last.index(category)
242
- freq = grid[att_val_index][category_index] + 1
243
- grid[att_val_index][category_index] = freq
462
+ category_index = category_domain.index(category)
463
+ grid[att_val_index][category_index] += 1
244
464
  end
245
- return grid
465
+ grid
246
466
  end
247
467
 
248
- private
468
+ # @param freq_grid [Object]
469
+ # @param total_examples [Object]
470
+ # @return [Object]
249
471
  def entropy(freq_grid, total_examples)
250
- #Calc entropy of each element
472
+ # Calc entropy of each element
251
473
  entropy = 0
252
474
  freq_grid.each do |att_freq|
253
475
  att_total_freq = ID3.sum(att_freq)
254
476
  partial_entropy = 0
255
- if att_total_freq != 0
477
+ unless att_total_freq.zero?
256
478
  att_freq.each do |freq|
257
- prop = freq.to_f/att_total_freq
258
- partial_entropy += (-1*prop*ID3.log2(prop))
479
+ prop = freq.to_f / att_total_freq
480
+ partial_entropy += (-1 * prop * ID3.log2(prop))
259
481
  end
260
482
  end
261
- entropy += (att_total_freq.to_f/total_examples) * partial_entropy
483
+ entropy += (att_total_freq.to_f / total_examples) * partial_entropy
262
484
  end
263
- return entropy
485
+ entropy
486
+ end
487
+
488
+ # @param node [Object]
489
+ # @param examples [Object]
490
+ # @return [Object]
491
+ def prune_node(node, examples)
492
+ return node if node.is_a?(CategoryNode) || node.is_a?(ErrorNode)
493
+
494
+ subsets = split_examples(node, examples)
495
+
496
+ node.nodes.each_with_index do |child, i|
497
+ node.nodes[i] = prune_node(child, subsets[i])
498
+ end
499
+
500
+ leaf = CategoryNode.new(@data_set.category_label, node.majority)
501
+ replace_with_leaf?(leaf, node, examples) ? leaf : node
502
+ end
503
+
504
+ def split_examples(node, examples)
505
+ if node.numeric
506
+ Array.new(2) { [] }.tap do |subsets|
507
+ examples.each do |ex|
508
+ idx = ex[node.index] <= node.threshold ? 0 : 1
509
+ subsets[idx] << ex
510
+ end
511
+ end
512
+ else
513
+ Array.new(node.values.length) { [] }.tap do |subsets|
514
+ examples.each do |ex|
515
+ idx = node.values.index(ex[node.index])
516
+ subsets[idx] << ex if idx
517
+ end
518
+ end
519
+ end
520
+ end
521
+
522
+ def replace_with_leaf?(leaf, node, examples)
523
+ before = accuracy_for_node(node, examples)
524
+ after = accuracy_for_node(leaf, examples)
525
+ after && before && after >= before
526
+ end
527
+
528
+ # @param node [Object]
529
+ # @param examples [Object]
530
+ # @return [Object]
531
+ def accuracy_for_node(node, examples)
532
+ return nil if examples.empty?
533
+
534
+ correct = examples.count do |ex|
535
+ node.value(ex[0..-2], self) == ex.last
536
+ end
537
+ correct.to_f / examples.length
264
538
  end
265
539
 
266
- private
267
540
  LOG2 = Math.log(2)
268
541
  end
269
542
 
270
- class EvaluationNode #:nodoc: all
271
-
272
- attr_reader :index, :values, :nodes
273
-
274
- def initialize(data_labels, index, values, nodes)
543
+ class EvaluationNode # :nodoc: all
544
+ attr_reader :index, :values, :nodes, :numeric, :threshold, :majority
545
+
546
+ # @param data_labels [Object]
547
+ # @param index [Object]
548
+ # @param values_or_threshold [Object]
549
+ # @param nodes [Object]
550
+ # @param numeric [Object]
551
+ # @param majority [Object]
552
+ # @return [Object]
553
+ def initialize(data_labels, index, values_or_threshold, nodes, numeric = false,
554
+ majority = nil)
275
555
  @index = index
276
- @values = values
556
+ @numeric = numeric
557
+ if numeric
558
+ @threshold = values_or_threshold
559
+ @values = nil
560
+ else
561
+ @values = values_or_threshold
562
+ end
277
563
  @nodes = nodes
564
+ @majority = majority
278
565
  @data_labels = data_labels
279
566
  end
280
-
281
- def value(data)
567
+
568
+ # @param data [Object]
569
+ # @param classifier [Object]
570
+ # @return [Object]
571
+ def value(data, classifier = nil)
282
572
  value = data[@index]
283
- return ErrorNode.new.value(data) if !@values.include?(value)
284
- return nodes[@values.index(value)].value(data)
573
+ if @numeric
574
+ node = value <= @threshold ? @nodes[0] : @nodes[1]
575
+ node.value(data, classifier)
576
+ else
577
+ unless @values.include?(value)
578
+ return nil if classifier&.on_unknown == :nil
579
+ return @majority if classifier&.on_unknown == :most_frequent
580
+
581
+ return ErrorNode.new.value(data, classifier)
582
+ end
583
+ @nodes[@values.index(value)].value(data, classifier)
584
+ end
285
585
  end
286
-
586
+
587
+ # @return [Object]
287
588
  def get_rules
288
589
  rule_set = []
289
- @nodes.each_index do |child_node_index|
290
- my_rule = "#{@data_labels[@index]}=='#{@values[child_node_index]}'"
291
- child_node = @nodes[child_node_index]
590
+ @nodes.each_with_index do |child_node, child_node_index|
591
+ if @numeric
592
+ op = child_node_index.zero? ? '<=' : '>'
593
+ my_rule = "#{@data_labels[@index]} #{op} #{@threshold}"
594
+ else
595
+ my_rule = "#{@data_labels[@index]}=='#{@values[child_node_index]}'"
596
+ end
292
597
  child_node_rules = child_node.get_rules
293
598
  child_node_rules.each do |child_rule|
294
599
  child_rule.unshift(my_rule)
295
600
  end
296
601
  rule_set += child_node_rules
297
602
  end
298
- return rule_set
603
+ rule_set
604
+ end
605
+
606
+ # @return [Object]
607
+ def to_h
608
+ hash = {}
609
+ @nodes.each_with_index do |child, i|
610
+ hash[@values[i]] = child.to_h
611
+ end
612
+ { @data_labels[@index] => hash }
613
+ end
614
+
615
+ # @param id [Object]
616
+ # @param lines [Object]
617
+ # @param parent [Object]
618
+ # @param edge_label [Object]
619
+ # @return [Object]
620
+ def to_graphviz(id, lines, parent = nil, edge_label = nil)
621
+ my_id = id
622
+ lines << " node#{my_id} [label=\"#{@data_labels[@index]}\"]"
623
+ lines << " node#{parent} -> node#{my_id} [label=\"#{edge_label}\"]" if parent
624
+ next_id = my_id
625
+ @nodes.each_with_index do |child, idx|
626
+ next_id += 1
627
+ next_id = child.to_graphviz(next_id, lines, my_id, @values[idx])
628
+ end
629
+ next_id
299
630
  end
300
-
301
631
  end
302
632
 
303
- class CategoryNode #:nodoc: all
633
+ class CategoryNode # :nodoc: all
634
+ # @param label [Object]
635
+ # @param value [Object]
636
+ # @return [Object]
304
637
  def initialize(label, value)
305
638
  @label = label
306
639
  @value = value
307
640
  end
308
- def value(data)
309
- return @value
641
+
642
+ # @param data [Object]
643
+ # @param classifier [Object]
644
+ # @return [Object]
645
+ def value(_data, _classifier = nil)
646
+ @value
310
647
  end
648
+
649
+ # @return [Object]
311
650
  def get_rules
312
- return [["#{@label}='#{@value}'"]]
651
+ [["#{@label}='#{@value}'"]]
652
+ end
653
+
654
+ # @return [Object]
655
+ def to_h
656
+ @value
657
+ end
658
+
659
+ # @param id [Object]
660
+ # @param lines [Object]
661
+ # @param parent [Object]
662
+ # @param edge_label [Object]
663
+ # @return [Object]
664
+ def to_graphviz(id, lines, parent = nil, edge_label = nil)
665
+ my_id = id
666
+ lines << " node#{my_id} [label=\"#{@value}\", shape=box]"
667
+ lines << " node#{parent} -> node#{my_id} [label=\"#{edge_label}\"]" if parent
668
+ my_id
313
669
  end
314
670
  end
315
671
 
672
+ # Raised when the training data is insufficient to build a model.
316
673
  class ModelFailureError < StandardError
317
- default_message = "There was not enough information during training to do a proper induction for this data element."
674
+ MSG = 'There was not enough information during training to do a proper ' \
675
+ 'induction for this data element.'
318
676
  end
319
677
 
320
- class ErrorNode #:nodoc: all
321
- def value(data)
322
- raise ModelFailureError, "There was not enough information during training to do a proper induction for the data element #{data}."
678
+ class ErrorNode # :nodoc: all
679
+ # @param data [Object]
680
+ # @param classifier [Object]
681
+ # @return [Object]
682
+ def value(data, _classifier = nil)
683
+ raise ModelFailureError, "#{ModelFailureError::MSG} for the data element #{data}."
323
684
  end
685
+
686
+ # @return [Object]
324
687
  def get_rules
325
- return []
688
+ []
326
689
  end
327
- end
328
690
 
691
+ # @return [Object]
692
+ def to_h
693
+ nil
694
+ end
695
+
696
+ # @param id [Object]
697
+ # @param lines [Object]
698
+ # @param parent [Object]
699
+ # @param edge_label [Object]
700
+ # @return [Object]
701
+ def to_graphviz(id, lines, parent = nil, edge_label = nil)
702
+ my_id = id
703
+ lines << " node#{my_id} [label=\"?\", shape=box]"
704
+ lines << " node#{parent} -> node#{my_id} [label=\"#{edge_label}\"]" if parent
705
+ my_id
706
+ end
707
+ end
329
708
  end
330
709
  end