ai4r 1.13 → 2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129) hide show
  1. checksums.yaml +7 -0
  2. data/README.md +174 -0
  3. data/examples/classifiers/hyperpipes_data.csv +14 -0
  4. data/examples/classifiers/hyperpipes_example.rb +22 -0
  5. data/examples/classifiers/ib1_example.rb +12 -0
  6. data/examples/classifiers/id3_example.rb +15 -10
  7. data/examples/classifiers/id3_graphviz_example.rb +17 -0
  8. data/examples/classifiers/logistic_regression_example.rb +11 -0
  9. data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
  10. data/examples/classifiers/naive_bayes_example.rb +12 -13
  11. data/examples/classifiers/one_r_example.rb +27 -0
  12. data/examples/classifiers/parameter_tutorial.rb +29 -0
  13. data/examples/classifiers/prism_nominal_example.rb +15 -0
  14. data/examples/classifiers/prism_numeric_example.rb +21 -0
  15. data/examples/classifiers/simple_linear_regression_example.rb +14 -11
  16. data/examples/classifiers/zero_and_one_r_example.rb +34 -0
  17. data/examples/classifiers/zero_one_r_data.csv +8 -0
  18. data/examples/clusterers/clusterer_example.rb +40 -34
  19. data/examples/clusterers/dbscan_example.rb +17 -0
  20. data/examples/clusterers/dendrogram_example.rb +17 -0
  21. data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
  22. data/examples/clusterers/kmeans_custom_example.rb +26 -0
  23. data/examples/genetic_algorithm/bitstring_example.rb +41 -0
  24. data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
  25. data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
  26. data/examples/neural_network/backpropagation_example.rb +48 -48
  27. data/examples/neural_network/hopfield_example.rb +45 -0
  28. data/examples/neural_network/patterns_with_base_noise.rb +39 -39
  29. data/examples/neural_network/patterns_with_noise.rb +41 -39
  30. data/examples/neural_network/train_epochs_callback.rb +25 -0
  31. data/examples/neural_network/training_patterns.rb +39 -39
  32. data/examples/neural_network/transformer_text_classification.rb +78 -0
  33. data/examples/neural_network/xor_example.rb +23 -22
  34. data/examples/reinforcement/q_learning_example.rb +10 -0
  35. data/examples/som/som_data.rb +155 -152
  36. data/examples/som/som_multi_node_example.rb +12 -13
  37. data/examples/som/som_single_example.rb +12 -15
  38. data/examples/transformer/decode_classifier_example.rb +68 -0
  39. data/examples/transformer/deterministic_example.rb +10 -0
  40. data/examples/transformer/seq2seq_example.rb +16 -0
  41. data/lib/ai4r/classifiers/classifier.rb +24 -16
  42. data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
  43. data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
  44. data/lib/ai4r/classifiers/ib1.rb +122 -32
  45. data/lib/ai4r/classifiers/id3.rb +524 -145
  46. data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
  47. data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
  48. data/lib/ai4r/classifiers/naive_bayes.rb +95 -34
  49. data/lib/ai4r/classifiers/one_r.rb +112 -44
  50. data/lib/ai4r/classifiers/prism.rb +167 -76
  51. data/lib/ai4r/classifiers/random_forest.rb +72 -0
  52. data/lib/ai4r/classifiers/simple_linear_regression.rb +83 -58
  53. data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
  54. data/lib/ai4r/classifiers/votes.rb +57 -0
  55. data/lib/ai4r/classifiers/zero_r.rb +71 -30
  56. data/lib/ai4r/clusterers/average_linkage.rb +46 -27
  57. data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
  58. data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
  59. data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
  60. data/lib/ai4r/clusterers/clusterer.rb +29 -14
  61. data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
  62. data/lib/ai4r/clusterers/dbscan.rb +134 -0
  63. data/lib/ai4r/clusterers/diana.rb +75 -49
  64. data/lib/ai4r/clusterers/k_means.rb +270 -135
  65. data/lib/ai4r/clusterers/median_linkage.rb +49 -33
  66. data/lib/ai4r/clusterers/single_linkage.rb +196 -88
  67. data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
  68. data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +25 -10
  69. data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
  70. data/lib/ai4r/data/data_set.rb +223 -103
  71. data/lib/ai4r/data/parameterizable.rb +31 -25
  72. data/lib/ai4r/data/proximity.rb +62 -62
  73. data/lib/ai4r/data/statistics.rb +46 -35
  74. data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
  75. data/lib/ai4r/experiment/split.rb +39 -0
  76. data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
  77. data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
  78. data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
  79. data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
  80. data/lib/ai4r/neural_network/activation_functions.rb +37 -0
  81. data/lib/ai4r/neural_network/backpropagation.rb +399 -134
  82. data/lib/ai4r/neural_network/hopfield.rb +175 -58
  83. data/lib/ai4r/neural_network/transformer.rb +194 -0
  84. data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
  85. data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
  86. data/lib/ai4r/reinforcement/q_learning.rb +51 -0
  87. data/lib/ai4r/search/a_star.rb +76 -0
  88. data/lib/ai4r/search/bfs.rb +50 -0
  89. data/lib/ai4r/search/dfs.rb +50 -0
  90. data/lib/ai4r/search/mcts.rb +118 -0
  91. data/lib/ai4r/search.rb +12 -0
  92. data/lib/ai4r/som/distance_metrics.rb +29 -0
  93. data/lib/ai4r/som/layer.rb +28 -17
  94. data/lib/ai4r/som/node.rb +61 -32
  95. data/lib/ai4r/som/som.rb +158 -41
  96. data/lib/ai4r/som/two_phase_layer.rb +21 -25
  97. data/lib/ai4r/version.rb +3 -0
  98. data/lib/ai4r.rb +57 -28
  99. metadata +79 -109
  100. data/README.rdoc +0 -39
  101. data/test/classifiers/hyperpipes_test.rb +0 -84
  102. data/test/classifiers/ib1_test.rb +0 -78
  103. data/test/classifiers/id3_test.rb +0 -220
  104. data/test/classifiers/multilayer_perceptron_test.rb +0 -79
  105. data/test/classifiers/naive_bayes_test.rb +0 -43
  106. data/test/classifiers/one_r_test.rb +0 -62
  107. data/test/classifiers/prism_test.rb +0 -85
  108. data/test/classifiers/simple_linear_regression_test.rb +0 -37
  109. data/test/classifiers/zero_r_test.rb +0 -50
  110. data/test/clusterers/average_linkage_test.rb +0 -51
  111. data/test/clusterers/bisecting_k_means_test.rb +0 -66
  112. data/test/clusterers/centroid_linkage_test.rb +0 -53
  113. data/test/clusterers/complete_linkage_test.rb +0 -57
  114. data/test/clusterers/diana_test.rb +0 -69
  115. data/test/clusterers/k_means_test.rb +0 -167
  116. data/test/clusterers/median_linkage_test.rb +0 -53
  117. data/test/clusterers/single_linkage_test.rb +0 -122
  118. data/test/clusterers/ward_linkage_hierarchical_test.rb +0 -81
  119. data/test/clusterers/ward_linkage_test.rb +0 -53
  120. data/test/clusterers/weighted_average_linkage_test.rb +0 -53
  121. data/test/data/data_set_test.rb +0 -104
  122. data/test/data/proximity_test.rb +0 -87
  123. data/test/data/statistics_test.rb +0 -65
  124. data/test/experiment/classifier_evaluator_test.rb +0 -76
  125. data/test/genetic_algorithm/chromosome_test.rb +0 -57
  126. data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
  127. data/test/neural_network/backpropagation_test.rb +0 -82
  128. data/test/neural_network/hopfield_test.rb +0 -72
  129. data/test/som/som_test.rb +0 -97
@@ -1,110 +1,178 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Sergio Fierens (Implementation only)
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
- # Url:: http://www.ai4r.org/
6
+ # Url:: https://github.com/SergioFierens/ai4r
5
7
  #
6
- # You can redistribute it and/or modify it under the terms of
7
- # the Mozilla Public License version 1.1 as published by the
8
+ # You can redistribute it and/or modify it under the terms of
9
+ # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
11
 
10
12
  require 'set'
11
- require File.dirname(__FILE__) + '/../data/data_set'
12
- require File.dirname(__FILE__) + '/../classifiers/classifier'
13
+ require_relative '../data/data_set'
14
+ require_relative '../classifiers/classifier'
13
15
 
14
16
  module Ai4r
15
17
  module Classifiers
16
-
17
18
  # = Introduction
18
- #
19
+ #
19
20
  # The idea of the OneR algorithm is identify the single
20
- # attribute to use to classify data that makes
21
+ # attribute to use to classify data that makes
21
22
  # fewest prediction errors.
22
23
  # It generates rules based on a single attribute.
24
+ # Numeric attributes are automatically discretized into a fixed
25
+ # number of bins (default is 10).
23
26
  class OneR < Classifier
24
-
25
27
  attr_reader :data_set, :rule
26
28
 
29
+ parameters_info selected_attribute: 'Index of the attribute to force.',
30
+ tie_break: 'Strategy when two attributes yield the same accuracy.',
31
+ bin_count: 'Number of bins used to discretize numeric attributes.'
32
+
33
+ # @return [Object]
34
+ def initialize
35
+ super()
36
+ @selected_attribute = nil
37
+ @tie_break = :first
38
+ @bin_count = 10
39
+ end
40
+
27
41
  # Build a new OneR classifier. You must provide a DataSet instance
28
- # as parameter. The last attribute of each item is considered as
42
+ # as parameter. The last attribute of each item is considered as
29
43
  # the item class.
44
+ # @param data_set [Object]
45
+ # @return [Object]
30
46
  def build(data_set)
31
47
  data_set.check_not_empty
32
48
  @data_set = data_set
33
- if (data_set.num_attributes == 1)
49
+ if data_set.num_attributes == 1
34
50
  @zero_r = ZeroR.new.build(data_set)
35
- return self;
51
+ return self
36
52
  else
37
- @zero_r = nil;
53
+ @zero_r = nil
38
54
  end
39
55
  domains = @data_set.build_domains
40
56
  @rule = nil
41
- domains[1...-1].each_index do |attr_index|
42
- rule = build_rule(@data_set.data_items, attr_index, domains)
43
- @rule = rule if !@rule || rule[:correct] > @rule[:correct]
57
+ if @selected_attribute
58
+ @rule = build_rule(@data_set.data_items, @selected_attribute, domains)
59
+ else
60
+ domains[1...-1].each_index do |attr_index|
61
+ rule = build_rule(@data_set.data_items, attr_index, domains)
62
+ if !@rule || rule[:correct] > @rule[:correct] ||
63
+ (rule[:correct] == @rule[:correct] && @tie_break == :last)
64
+ @rule = rule
65
+ end
66
+ end
44
67
  end
45
- return self
68
+ self
46
69
  end
47
-
70
+
48
71
  # You can evaluate new data, predicting its class.
49
72
  # e.g.
50
- # classifier.eval(['New York', '<30', 'F']) # => 'Y'
73
+ # classifier.eval(['New York', '<30', 'F']) # => 'Y'
74
+ # @param data [Object]
75
+ # @return [Object]
51
76
  def eval(data)
52
77
  return @zero_r.eval(data) if @zero_r
78
+
53
79
  attr_value = data[@rule[:attr_index]]
54
- return @rule[:rule][attr_value]
80
+ if @rule[:bins]
81
+ bin = @rule[:bins].find { |b| b.include?(attr_value) }
82
+ attr_value = bin
83
+ end
84
+ @rule[:rule][attr_value]
55
85
  end
56
-
86
+
57
87
  # This method returns the generated rules in ruby code.
58
88
  # e.g.
59
- #
89
+ #
60
90
  # classifier.get_rules
61
91
  # # => if age_range == '<30' then marketing_target = 'Y'
62
92
  # elsif age_range == '[30-50)' then marketing_target = 'N'
63
93
  # elsif age_range == '[50-80]' then marketing_target = 'N'
64
94
  # end
65
95
  #
66
- # It is a nice way to inspect induction results, and also to execute them:
96
+ # It is a nice way to inspect induction results, and also to execute them:
67
97
  # marketing_target = nil
68
- # eval classifier.get_rules
98
+ # eval classifier.get_rules
69
99
  # puts marketing_target
70
100
  # # => 'Y'
101
+ # @return [Object]
71
102
  def get_rules
72
103
  return @zero_r.get_rules if @zero_r
104
+
73
105
  sentences = []
74
106
  attr_label = @data_set.data_labels[@rule[:attr_index]]
75
- class_label = @data_set.data_labels.last
107
+ class_label = @data_set.category_label
76
108
  @rule[:rule].each_pair do |attr_value, class_value|
77
- sentences << "#{attr_label} == '#{attr_value}' then #{class_label} = '#{class_value}'"
109
+ sentences << if attr_value.is_a?(Range)
110
+ "(#{attr_value}).include?(#{attr_label}) then #{class_label} = '#{class_value}'"
111
+ else
112
+ "#{attr_label} == '#{attr_value}' then #{class_label} = '#{class_value}'"
113
+ end
78
114
  end
79
- return "if " + sentences.join("\nelsif ") + "\nend"
115
+ "if #{sentences.join("\nelsif ")}\nend"
80
116
  end
81
-
117
+
82
118
  protected
83
-
119
+
120
+ # @param data_examples [Object]
121
+ # @param attr_index [Object]
122
+ # @param domains [Object]
123
+ # @return [Object]
84
124
  def build_rule(data_examples, attr_index, domains)
85
125
  domain = domains[attr_index]
86
- value_freq = Hash.new
87
- domain.each do |attr_value|
88
- value_freq[attr_value] = Hash.new { |hash, key| hash[key] = 0 }
89
- end
90
- data_examples.each do |data|
91
- value_freq[data[attr_index]][data.last] = value_freq[data[attr_index]][data.last] + 1
126
+ bins, value_freq = build_frequency(domain, data_examples, attr_index)
127
+ rule, correct_instances = rule_from_frequency(value_freq)
128
+ { attr_index: attr_index, rule: rule, correct: correct_instances, bins: bins }
129
+ end
130
+
131
+ def build_frequency(domain, data_examples, attr_index)
132
+ if domain.is_a?(Array) && domain.length == 2 && domain.all? { |v| v.is_a? Numeric }
133
+ bins = discretize_range(domain, @bin_count)
134
+ value_freq = bins.each_with_object({}) { |b, h| h[b] = Hash.new(0) }
135
+ data_examples.each do |data|
136
+ bin = bins.find { |b| b.include?(data[attr_index]) }
137
+ value_freq[bin][data.last] += 1
138
+ end
139
+ else
140
+ bins = nil
141
+ value_freq = domain.each_with_object({}) { |v, h| h[v] = Hash.new(0) }
142
+ data_examples.each do |data|
143
+ value_freq[data[attr_index]][data.last] += 1
144
+ end
92
145
  end
146
+ [bins, value_freq]
147
+ end
148
+
149
+ def rule_from_frequency(value_freq)
93
150
  rule = {}
94
151
  correct_instances = 0
95
- value_freq.each_pair do |attr, class_freq_hash|
96
- max_freq = 0
97
- class_freq_hash.each_pair do |class_value, freq|
98
- if max_freq < freq
99
- rule[attr] = class_value
100
- max_freq = freq
101
- end
102
- end
152
+ value_freq.each_pair do |attr, class_freq_hash|
153
+ pair = class_freq_hash.max_by { |_k, v| v }
154
+ next unless pair
155
+
156
+ rule[attr], max_freq = pair
103
157
  correct_instances += max_freq
104
158
  end
105
- return {:attr_index => attr_index, :rule => rule, :correct => correct_instances}
159
+ [rule, correct_instances]
106
160
  end
107
161
 
162
+ # @param range [Object]
163
+ # @param bins [Object]
164
+ # @return [Object]
165
+ def discretize_range(range, bins)
166
+ min, max = range
167
+ step = (max - min).to_f / bins
168
+ ranges = []
169
+ bins.times do |i|
170
+ low = min + (i * step)
171
+ high = i == bins - 1 ? max : min + ((i + 1) * step)
172
+ ranges << (i == bins - 1 ? (low..high) : (low...high))
173
+ end
174
+ ranges
175
+ end
108
176
  end
109
177
  end
110
178
  end
@@ -1,65 +1,99 @@
1
- # Author:: Sergio Fierens (Implementation only, Cendrowska is
1
+ # frozen_string_literal: true
2
+
3
+ # Author:: Sergio Fierens (Implementation only, Cendrowska is
2
4
  # the creator of the algorithm)
3
5
  # License:: MPL 1.1
4
6
  # Project:: ai4r
5
- # Url:: http://www.ai4r.org/
7
+ # Url:: https://github.com/SergioFierens/ai4r
6
8
  #
7
- # You can redistribute it and/or modify it under the terms of
8
- # the Mozilla Public License version 1.1 as published by the
9
+ # You can redistribute it and/or modify it under the terms of
10
+ # the Mozilla Public License version 1.1 as published by the
9
11
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
10
12
  #
11
- # J. Cendrowska (1987). PRISM: An algorithm for inducing modular rules.
13
+ # J. Cendrowska (1987). PRISM: An algorithm for inducing modular rules.
12
14
  # International Journal of Man-Machine Studies. 27(4):349-370.
13
15
 
14
- require File.dirname(__FILE__) + '/../data/data_set'
15
- require File.dirname(__FILE__) + '/../classifiers/classifier'
16
+ require_relative '../data/data_set'
17
+ require_relative '../classifiers/classifier'
16
18
 
17
19
  module Ai4r
18
20
  module Classifiers
19
-
20
21
  # = Introduction
21
- # This is an implementation of the PRISM algorithm (Cendrowska, 1987)
22
+ # This is an implementation of the PRISM algorithm (Cendrowska, 1987)
22
23
  # Given a set of preclassified examples, it builds a set of rules
23
24
  # to predict the class of other instaces.
24
- #
25
- # J. Cendrowska (1987). PRISM: An algorithm for inducing modular rules.
25
+ #
26
+ # J. Cendrowska (1987). PRISM: An algorithm for inducing modular rules.
26
27
  # International Journal of Man-Machine Studies. 27(4):349-370.
27
28
  class Prism < Classifier
28
-
29
- attr_reader :data_set, :rules
29
+ attr_reader :data_set, :rules, :majority_class
30
+
31
+ parameters_info(
32
+ fallback_class: 'Default class returned when no rule matches.',
33
+ bin_count: 'Number of bins used to discretize numeric attributes.',
34
+ default_class: 'Return this value when no rule matches.',
35
+ tie_break: 'Strategy when multiple conditions have equal ratios.'
36
+ )
37
+
38
+ def initialize
39
+ super()
40
+ @fallback_class = nil
41
+ @bin_count = 10
42
+ @attr_bins = {}
43
+
44
+ @default_class = nil
45
+ @tie_break = :first
46
+ @bin_count = 10
47
+ @attr_bins = {}
48
+ end
30
49
 
31
50
  # Build a new Prism classifier. You must provide a DataSet instance
32
- # as parameter. The last attribute of each item is considered as
51
+ # as parameter. The last attribute of each item is considered as
33
52
  # the item class.
53
+ # @param data_set [Object]
54
+ # @return [Object]
34
55
  def build(data_set)
35
56
  data_set.check_not_empty
36
57
  @data_set = data_set
58
+
59
+ freqs = Hash.new(0)
60
+ @data_set.data_items.each { |item| freqs[item.last] += 1 }
61
+ @majority_class = freqs.max_by { |_, v| v }&.first
62
+ @fallback_class = @default_class if @default_class
63
+ @fallback_class = @majority_class if @fallback_class.nil?
64
+
37
65
  domains = @data_set.build_domains
38
- instances = @data_set.data_items.collect {|data| data }
66
+ @attr_bins = {}
67
+ domains[0...-1].each_with_index do |domain, i|
68
+ @attr_bins[@data_set.data_labels[i]] = discretize_range(domain, @bin_count) if domain.is_a?(Array) && domain.length == 2 && domain.all? { |v| v.is_a? Numeric }
69
+ end
70
+ instances = @data_set.data_items.collect { |data| data }
39
71
  @rules = []
40
72
  domains.last.each do |class_value|
41
- while(has_class_value(instances, class_value))
73
+ while class_value?(instances, class_value)
42
74
  rule = build_rule(class_value, instances)
43
75
  @rules << rule
44
- instances = instances.select {|data| !matches_conditions(data, rule[:conditions])}
76
+ instances = instances.reject { |data| matches_conditions(data, rule[:conditions]) }
45
77
  end
46
78
  end
47
- return self
79
+ self
48
80
  end
49
81
 
50
82
  # You can evaluate new data, predicting its class.
51
83
  # e.g.
52
- # classifier.eval(['New York', '<30', 'F']) # => 'Y'
84
+ # classifier.eval(['New York', '<30', 'F']) # => 'Y'
85
+ # @param instace [Object]
86
+ # @return [Object]
53
87
  def eval(instace)
54
88
  @rules.each do |rule|
55
89
  return rule[:class_value] if matches_conditions(instace, rule[:conditions])
56
90
  end
57
- return nil
91
+ @default_class || @fallback_class
58
92
  end
59
-
93
+
60
94
  # This method returns the generated rules in ruby code.
61
95
  # e.g.
62
- #
96
+ #
63
97
  # classifier.get_rules
64
98
  # # => if age_range == '<30' then marketing_target = 'Y'
65
99
  # elsif age_range == '>80' then marketing_target = 'Y'
@@ -67,131 +101,188 @@ module Ai4r
67
101
  # else marketing_target = 'N'
68
102
  # end
69
103
  #
70
- # It is a nice way to inspect induction results, and also to execute them:
104
+ # It is a nice way to inspect induction results, and also to execute them:
71
105
  # age_range = '[30-50)'
72
106
  # city = 'New York'
73
- # eval(classifier.get_rules)
107
+ # eval(classifier.get_rules)
74
108
  # puts marketing_target
75
109
  # 'Y'
110
+ # @return [Object]
76
111
  def get_rules
77
112
  out = "if #{join_terms(@rules.first)} then #{then_clause(@rules.first)}"
78
- @rules[1...-1].each do |rule|
113
+ @rules[1...-1].each do |rule|
79
114
  out += "\nelsif #{join_terms(rule)} then #{then_clause(rule)}"
80
115
  end
81
116
  out += "\nelse #{then_clause(@rules.last)}" if @rules.size > 1
82
117
  out += "\nend"
83
- return out
118
+ out
84
119
  end
85
-
120
+
86
121
  protected
87
-
122
+
123
+ # @param data [Object]
124
+ # @param attr [Object]
125
+ # @return [Object]
88
126
  def get_attr_value(data, attr)
89
127
  data[@data_set.get_index(attr)]
90
128
  end
91
-
92
- def has_class_value(instances, class_value)
93
- instances.each { |data| return true if data.last == class_value}
94
- return false
129
+
130
+ # @param instances [Object]
131
+ # @param class_value [Object]
132
+ # @return [Object]
133
+ def class_value?(instances, class_value)
134
+ instances.any? { |data| data.last == class_value }
95
135
  end
96
-
97
- def is_perfect(instances, rule)
136
+
137
+ # @param instances [Object]
138
+ # @param rule [Object]
139
+ # @return [Object]
140
+ def perfect?(instances, rule)
98
141
  class_value = rule[:class_value]
99
- instances.each do |data|
100
- return false if data.last != class_value and matches_conditions(data, rule[:conditions])
142
+ instances.each do |data|
143
+ return false if (data.last != class_value) && matches_conditions(data, rule[:conditions])
101
144
  end
102
- return true
145
+ true
103
146
  end
104
-
147
+
148
+ # @param data [Object]
149
+ # @param conditions [Object]
150
+ # @return [Object]
105
151
  def matches_conditions(data, conditions)
106
152
  conditions.each_pair do |attr_label, attr_value|
107
- return false if get_attr_value(data, attr_label) != attr_value
153
+ value = get_attr_value(data, attr_label)
154
+ if attr_value.is_a?(Range)
155
+ return false unless attr_value.include?(value)
156
+ else
157
+ return false unless value == attr_value
158
+ end
108
159
  end
109
- return true
160
+ true
110
161
  end
111
-
162
+
163
+ # @param class_value [Object]
164
+ # @param instances [Object]
165
+ # @return [Object]
112
166
  def build_rule(class_value, instances)
113
- rule = {:class_value => class_value, :conditions => {}}
114
- rule_instances = instances.collect {|data| data }
115
- attributes = @data_set.data_labels[0...-1].collect {|label| label }
116
- until(is_perfect(instances, rule) || attributes.empty?)
167
+ rule = { class_value: class_value, conditions: {} }
168
+ rule_instances = instances.collect { |data| data }
169
+ attributes = @data_set.data_labels[0...-1].collect { |label| label }
170
+ until perfect?(instances, rule) || attributes.empty?
117
171
  freq_table = build_freq_table(rule_instances, attributes, class_value)
118
172
  condition = get_condition(freq_table)
119
173
  rule[:conditions].merge!(condition)
120
- rule_instances = rule_instances.select do |data|
121
- matches_conditions(data, condition)
174
+ attributes.delete(condition.keys.first)
175
+ rule_instances = rule_instances.select do |data|
176
+ matches_conditions(data, condition)
122
177
  end
123
178
  end
124
- return rule
179
+ rule
125
180
  end
126
-
181
+
127
182
  # Returns a structure with the folloring format:
128
183
  # => {attr1_label => { :attr1_value1 => [p, t], attr1_value2 => [p, t], ... },
129
184
  # attr2_label => { :attr2_value1 => [p, t], attr2_value2 => [p, t], ... },
130
185
  # ...
131
186
  # }
132
187
  # where p is the number of instances classified as class_value
133
- # with that attribute value, and t is the total number of instances with
188
+ # with that attribute value, and t is the total number of instances with
134
189
  # that attribute value
190
+ # @param rule_instances [Object]
191
+ # @param attributes [Object]
192
+ # @param class_value [Object]
193
+ # @return [Object]
135
194
  def build_freq_table(rule_instances, attributes, class_value)
136
- freq_table = Hash.new()
195
+ freq_table = {}
137
196
  rule_instances.each do |data|
138
197
  attributes.each do |attr_label|
139
198
  attr_freqs = freq_table[attr_label] || Hash.new([0, 0])
140
- pt = attr_freqs[get_attr_value(data, attr_label)]
141
- pt = [(data.last == class_value) ? pt[0]+1 : pt[0], pt[1]+1]
142
- attr_freqs[get_attr_value(data, attr_label)] = pt
199
+ value = get_attr_value(data, attr_label)
200
+ if (bins = @attr_bins[attr_label])
201
+ value = bins.find { |b| b.include?(value) }
202
+ end
203
+ pt = attr_freqs[value]
204
+ pt = [data.last == class_value ? pt[0] + 1 : pt[0], pt[1] + 1]
205
+ attr_freqs[value] = pt
143
206
  freq_table[attr_label] = attr_freqs
144
207
  end
145
208
  end
146
- return freq_table
209
+ freq_table
147
210
  end
148
-
211
+
149
212
  # returns a single conditional term: {attrN_label => attrN_valueM}
150
213
  # selecting the attribute with higher pt ratio
151
- # (occurrences of attribute value classified as class_value /
214
+ # (occurrences of attribute value classified as class_value /
152
215
  # occurrences of attribute value)
216
+ # @param freq_table [Object]
217
+ # @return [Object]
153
218
  def get_condition(freq_table)
154
219
  best_pt = [0, 0]
155
220
  condition = nil
156
221
  freq_table.each do |attr_label, attr_freqs|
157
222
  attr_freqs.each do |attr_value, pt|
158
- if(better_pt(pt, best_pt))
223
+ if better_pt(pt, best_pt)
159
224
  condition = { attr_label => attr_value }
160
225
  best_pt = pt
161
226
  end
162
227
  end
163
228
  end
164
- return condition
229
+ condition
165
230
  end
166
-
231
+
167
232
  # pt = [p, t]
168
233
  # p = occurrences of attribute value with instance classified as class_value
169
234
  # t = occurrences of attribute value
170
235
  # a pt is better if:
171
236
  # 1- its ratio is higher
172
- # 2- its ratio is equal, and has a higher p
237
+ # 2- its ratio is equal, and has a higher p
238
+ # @param pt [Object]
239
+ # @param best_pt [Object]
240
+ # @return [Object]
173
241
  def better_pt(pt, best_pt)
174
- return false if pt[1] == 0
175
- return true if best_pt[1] == 0
176
- a = pt[0]*best_pt[1]
177
- b = best_pt[0]*pt[1]
178
- return true if a>b || (a==b && pt[0]>best_pt[0])
179
- return false
242
+ return false if pt[1].zero?
243
+ return true if best_pt[1].zero?
244
+
245
+ a = pt[0] * best_pt[1]
246
+ b = best_pt[0] * pt[1]
247
+ return true if a > b || (a == b && pt[0] > best_pt[0])
248
+ return true if a == b && pt[0] == best_pt[0] && @tie_break == :last
249
+
250
+ false
180
251
  end
181
-
252
+
253
+ # @param range [Object]
254
+ # @param bins [Object]
255
+ # @return [Object]
256
+ def discretize_range(range, bins)
257
+ min, max = range
258
+ step = (max - min).to_f / bins
259
+ ranges = []
260
+ bins.times do |i|
261
+ low = min + (i * step)
262
+ high = i == bins - 1 ? max : min + ((i + 1) * step)
263
+ ranges << (i == bins - 1 ? (low..high) : (low...high))
264
+ end
265
+ ranges
266
+ end
267
+
268
+ # @param rule [Object]
269
+ # @return [Object]
182
270
  def join_terms(rule)
183
- terms = []
184
- rule[:conditions].each do |attr_label, attr_value|
185
- terms << "#{attr_label} == '#{attr_value}'"
271
+ terms = rule[:conditions].map do |attr_label, attr_value|
272
+ if attr_value.is_a?(Range)
273
+ "(#{attr_value}).include?(#{attr_label})"
274
+ else
275
+ "#{attr_label} == '#{attr_value}'"
276
+ end
186
277
  end
187
- "#{terms.join(" and ")}"
278
+ terms.join(' and ').to_s
188
279
  end
189
-
280
+
281
+ # @param rule [Object]
282
+ # @return [Object]
190
283
  def then_clause(rule)
191
- "#{@data_set.data_labels.last} = '#{rule[:class_value]}'"
284
+ "#{@data_set.category_label} = '#{rule[:class_value]}'"
192
285
  end
193
-
194
286
  end
195
287
  end
196
288
  end
197
-
@@ -0,0 +1,72 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Author:: OpenAI ChatGPT
4
+ # License:: MPL 1.1
5
+ # Project:: ai4r
6
+ #
7
+ # A simple Random Forest implementation using ID3 decision trees.
8
+
9
+ require_relative 'id3'
10
+ require_relative '../data/data_set'
11
+ require_relative '../classifiers/classifier'
12
+ require_relative 'votes'
13
+
14
+ module Ai4r
15
+ module Classifiers
16
+ # RandomForest ensemble classifier built from decision trees.
17
+ class RandomForest < Classifier
18
+ parameters_info n_trees: 'Number of trees to build. Default 10.',
19
+ sample_size: 'Number of data items for each tree (with replacement). Default: data set size.',
20
+ feature_fraction:
21
+ 'Fraction of attributes sampled for each tree. Default: sqrt(num_attributes)/num_attributes.',
22
+ random_seed: 'Seed for reproducible randomness.'
23
+
24
+ attr_reader :trees, :features
25
+
26
+ def initialize
27
+ super()
28
+ @n_trees = 10
29
+ @sample_size = nil
30
+ @feature_fraction = nil
31
+ @random_seed = nil
32
+ end
33
+
34
+ def build(data_set)
35
+ data_set.check_not_empty
36
+ rng = @random_seed ? Random.new(@random_seed) : Random.new
37
+ num_attributes = data_set.data_labels.length - 1
38
+ frac = @feature_fraction || (Math.sqrt(num_attributes) / num_attributes)
39
+ feature_count = [1, (num_attributes * frac).round].max
40
+ @sample_size ||= data_set.data_items.length
41
+ @trees = []
42
+ @features = []
43
+ @n_trees.times do
44
+ sampled = Array.new(@sample_size) { data_set.data_items.sample(random: rng) }
45
+ feature_idx = (0...num_attributes).to_a.sample(feature_count, random: rng)
46
+ tree_items = sampled.map do |item|
47
+ values = feature_idx.map { |i| item[i] }
48
+ values + [item.last]
49
+ end
50
+ labels = feature_idx.map { |i| data_set.data_labels[i] } + [data_set.data_labels.last]
51
+ ds = Ai4r::Data::DataSet.new(data_items: tree_items, data_labels: labels)
52
+ @trees << ID3.new.build(ds)
53
+ @features << feature_idx
54
+ end
55
+ self
56
+ end
57
+
58
+ def eval(data)
59
+ votes = Votes.new
60
+ @trees.each_with_index do |tree, idx|
61
+ sub_data = @features[idx].map { |i| data[i] }
62
+ votes.increment_category(tree.eval(sub_data))
63
+ end
64
+ votes.get_winner
65
+ end
66
+
67
+ def get_rules
68
+ 'RandomForest does not support rule extraction.'
69
+ end
70
+ end
71
+ end
72
+ end