ai4r 1.13 → 2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +174 -0
- data/examples/classifiers/hyperpipes_data.csv +14 -0
- data/examples/classifiers/hyperpipes_example.rb +22 -0
- data/examples/classifiers/ib1_example.rb +12 -0
- data/examples/classifiers/id3_example.rb +15 -10
- data/examples/classifiers/id3_graphviz_example.rb +17 -0
- data/examples/classifiers/logistic_regression_example.rb +11 -0
- data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
- data/examples/classifiers/naive_bayes_example.rb +12 -13
- data/examples/classifiers/one_r_example.rb +27 -0
- data/examples/classifiers/parameter_tutorial.rb +29 -0
- data/examples/classifiers/prism_nominal_example.rb +15 -0
- data/examples/classifiers/prism_numeric_example.rb +21 -0
- data/examples/classifiers/simple_linear_regression_example.rb +14 -11
- data/examples/classifiers/zero_and_one_r_example.rb +34 -0
- data/examples/classifiers/zero_one_r_data.csv +8 -0
- data/examples/clusterers/clusterer_example.rb +40 -34
- data/examples/clusterers/dbscan_example.rb +17 -0
- data/examples/clusterers/dendrogram_example.rb +17 -0
- data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
- data/examples/clusterers/kmeans_custom_example.rb +26 -0
- data/examples/genetic_algorithm/bitstring_example.rb +41 -0
- data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
- data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
- data/examples/neural_network/backpropagation_example.rb +48 -48
- data/examples/neural_network/hopfield_example.rb +45 -0
- data/examples/neural_network/patterns_with_base_noise.rb +39 -39
- data/examples/neural_network/patterns_with_noise.rb +41 -39
- data/examples/neural_network/train_epochs_callback.rb +25 -0
- data/examples/neural_network/training_patterns.rb +39 -39
- data/examples/neural_network/transformer_text_classification.rb +78 -0
- data/examples/neural_network/xor_example.rb +23 -22
- data/examples/reinforcement/q_learning_example.rb +10 -0
- data/examples/som/som_data.rb +155 -152
- data/examples/som/som_multi_node_example.rb +12 -13
- data/examples/som/som_single_example.rb +12 -15
- data/examples/transformer/decode_classifier_example.rb +68 -0
- data/examples/transformer/deterministic_example.rb +10 -0
- data/examples/transformer/seq2seq_example.rb +16 -0
- data/lib/ai4r/classifiers/classifier.rb +24 -16
- data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
- data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
- data/lib/ai4r/classifiers/ib1.rb +122 -32
- data/lib/ai4r/classifiers/id3.rb +524 -145
- data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
- data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
- data/lib/ai4r/classifiers/naive_bayes.rb +95 -34
- data/lib/ai4r/classifiers/one_r.rb +112 -44
- data/lib/ai4r/classifiers/prism.rb +167 -76
- data/lib/ai4r/classifiers/random_forest.rb +72 -0
- data/lib/ai4r/classifiers/simple_linear_regression.rb +83 -58
- data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
- data/lib/ai4r/classifiers/votes.rb +57 -0
- data/lib/ai4r/classifiers/zero_r.rb +71 -30
- data/lib/ai4r/clusterers/average_linkage.rb +46 -27
- data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
- data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
- data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
- data/lib/ai4r/clusterers/clusterer.rb +29 -14
- data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
- data/lib/ai4r/clusterers/dbscan.rb +134 -0
- data/lib/ai4r/clusterers/diana.rb +75 -49
- data/lib/ai4r/clusterers/k_means.rb +270 -135
- data/lib/ai4r/clusterers/median_linkage.rb +49 -33
- data/lib/ai4r/clusterers/single_linkage.rb +196 -88
- data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
- data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +25 -10
- data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
- data/lib/ai4r/data/data_set.rb +223 -103
- data/lib/ai4r/data/parameterizable.rb +31 -25
- data/lib/ai4r/data/proximity.rb +62 -62
- data/lib/ai4r/data/statistics.rb +46 -35
- data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
- data/lib/ai4r/experiment/split.rb +39 -0
- data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
- data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
- data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
- data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
- data/lib/ai4r/neural_network/activation_functions.rb +37 -0
- data/lib/ai4r/neural_network/backpropagation.rb +399 -134
- data/lib/ai4r/neural_network/hopfield.rb +175 -58
- data/lib/ai4r/neural_network/transformer.rb +194 -0
- data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
- data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
- data/lib/ai4r/reinforcement/q_learning.rb +51 -0
- data/lib/ai4r/search/a_star.rb +76 -0
- data/lib/ai4r/search/bfs.rb +50 -0
- data/lib/ai4r/search/dfs.rb +50 -0
- data/lib/ai4r/search/mcts.rb +118 -0
- data/lib/ai4r/search.rb +12 -0
- data/lib/ai4r/som/distance_metrics.rb +29 -0
- data/lib/ai4r/som/layer.rb +28 -17
- data/lib/ai4r/som/node.rb +61 -32
- data/lib/ai4r/som/som.rb +158 -41
- data/lib/ai4r/som/two_phase_layer.rb +21 -25
- data/lib/ai4r/version.rb +3 -0
- data/lib/ai4r.rb +57 -28
- metadata +79 -109
- data/README.rdoc +0 -39
- data/test/classifiers/hyperpipes_test.rb +0 -84
- data/test/classifiers/ib1_test.rb +0 -78
- data/test/classifiers/id3_test.rb +0 -220
- data/test/classifiers/multilayer_perceptron_test.rb +0 -79
- data/test/classifiers/naive_bayes_test.rb +0 -43
- data/test/classifiers/one_r_test.rb +0 -62
- data/test/classifiers/prism_test.rb +0 -85
- data/test/classifiers/simple_linear_regression_test.rb +0 -37
- data/test/classifiers/zero_r_test.rb +0 -50
- data/test/clusterers/average_linkage_test.rb +0 -51
- data/test/clusterers/bisecting_k_means_test.rb +0 -66
- data/test/clusterers/centroid_linkage_test.rb +0 -53
- data/test/clusterers/complete_linkage_test.rb +0 -57
- data/test/clusterers/diana_test.rb +0 -69
- data/test/clusterers/k_means_test.rb +0 -167
- data/test/clusterers/median_linkage_test.rb +0 -53
- data/test/clusterers/single_linkage_test.rb +0 -122
- data/test/clusterers/ward_linkage_hierarchical_test.rb +0 -81
- data/test/clusterers/ward_linkage_test.rb +0 -53
- data/test/clusterers/weighted_average_linkage_test.rb +0 -53
- data/test/data/data_set_test.rb +0 -104
- data/test/data/proximity_test.rb +0 -87
- data/test/data/statistics_test.rb +0 -65
- data/test/experiment/classifier_evaluator_test.rb +0 -76
- data/test/genetic_algorithm/chromosome_test.rb +0 -57
- data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
- data/test/neural_network/backpropagation_test.rb +0 -82
- data/test/neural_network/hopfield_test.rb +0 -72
- data/test/som/som_test.rb +0 -97
@@ -1,110 +1,178 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
# Author:: Sergio Fierens (Implementation only)
|
2
4
|
# License:: MPL 1.1
|
3
5
|
# Project:: ai4r
|
4
|
-
# Url::
|
6
|
+
# Url:: https://github.com/SergioFierens/ai4r
|
5
7
|
#
|
6
|
-
# You can redistribute it and/or modify it under the terms of
|
7
|
-
# the Mozilla Public License version 1.1 as published by the
|
8
|
+
# You can redistribute it and/or modify it under the terms of
|
9
|
+
# the Mozilla Public License version 1.1 as published by the
|
8
10
|
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
11
|
|
10
12
|
require 'set'
|
11
|
-
|
12
|
-
|
13
|
+
require_relative '../data/data_set'
|
14
|
+
require_relative '../classifiers/classifier'
|
13
15
|
|
14
16
|
module Ai4r
  module Classifiers
    # = Introduction
    #
    # The idea of the OneR algorithm is to identify the single
    # attribute that classifies the data with the fewest
    # prediction errors.
    # It generates rules based on a single attribute.
    # Numeric attributes are automatically discretized into a fixed
    # number of bins (default is 10).
    class OneR < Classifier
      attr_reader :data_set, :rule

      parameters_info selected_attribute: 'Index of the attribute to force.',
                      tie_break: 'Strategy when two attributes yield the same accuracy.',
                      bin_count: 'Number of bins used to discretize numeric attributes.'

      # Creates a classifier with no forced attribute, the :first
      # tie-breaking strategy and 10 bins for numeric attributes.
      # @return [OneR]
      def initialize
        super()
        @selected_attribute = nil
        @tie_break = :first
        @bin_count = 10
      end

      # Build a new OneR classifier. You must provide a DataSet instance
      # as parameter. The last attribute of each item is considered as
      # the item class.
      #
      # Unless +selected_attribute+ forces a specific attribute index, a rule
      # is built for every non-class attribute and the one that classifies
      # the most training items correctly is kept. On a tie the earliest
      # attribute wins, unless +tie_break+ is :last.
      # @param data_set [Ai4r::Data::DataSet]
      # @return [OneR] self
      def build(data_set)
        data_set.check_not_empty
        @data_set = data_set
        # A single attribute (presumably only the class column) leaves
        # nothing to build a rule from, so delegate to ZeroR.
        if data_set.num_attributes == 1
          @zero_r = ZeroR.new.build(data_set)
          return self
        end

        @zero_r = nil
        domains = @data_set.build_domains
        @rule = if @selected_attribute
                  build_rule(@data_set.data_items, @selected_attribute, domains)
                else
                  best_rule(domains)
                end
        self
      end

      # You can evaluate new data, predicting its class.
      # e.g.
      #   classifier.eval(['New York', '<30', 'F'])  # => 'Y'
      #
      # Returns nil when the attribute value was not seen during training
      # (or, for numeric attributes, falls in no bin).
      # @param data [Array] attribute values of the item to classify
      # @return [Object] predicted class value
      def eval(data)
        return @zero_r.eval(data) if @zero_r

        attr_value = data[@rule[:attr_index]]
        # Numeric rules are keyed by the bin (Range) containing the value.
        attr_value = @rule[:bins].find { |b| b.include?(attr_value) } if @rule[:bins]
        @rule[:rule][attr_value]
      end

      # This method returns the generated rules in ruby code.
      # e.g.
      #
      #   classifier.get_rules
      #     # =>  if age_range == '<30' then marketing_target = 'Y'
      #           elsif age_range == '[30-50)' then marketing_target = 'N'
      #           elsif age_range == '[50-80]' then marketing_target = 'N'
      #           end
      #
      # It is a nice way to inspect induction results, and also to execute them:
      #     marketing_target = nil
      #     eval classifier.get_rules
      #     puts marketing_target
      #       # => 'Y'
      # @return [String] Ruby source for the learned rule
      def get_rules
        return @zero_r.get_rules if @zero_r

        attr_label = @data_set.data_labels[@rule[:attr_index]]
        class_label = @data_set.category_label
        sentences = @rule[:rule].map do |attr_value, class_value|
          if attr_value.is_a?(Range)
            "(#{attr_value}).include?(#{attr_label}) then #{class_label} = '#{class_value}'"
          else
            "#{attr_label} == '#{attr_value}' then #{class_label} = '#{class_value}'"
          end
        end
        "if #{sentences.join("\nelsif ")}\nend"
      end

      protected

      # Builds a rule for every non-class attribute and returns the one
      # with the most correctly classified items (ties resolved by
      # +tie_break+). Every column except the last one (the class) is a
      # candidate, including the last non-class attribute.
      # @param domains [Array] attribute domains, class domain last
      # @return [Hash] the selected rule
      def best_rule(domains)
        best = nil
        (0...(domains.length - 1)).each do |attr_index|
          rule = build_rule(@data_set.data_items, attr_index, domains)
          if best.nil? || rule[:correct] > best[:correct] ||
             (rule[:correct] == best[:correct] && @tie_break == :last)
            best = rule
          end
        end
        best
      end

      # Builds the rule for a single attribute.
      # @param data_examples [Array<Array>] training items, class last
      # @param attr_index [Integer] index of the attribute to use
      # @param domains [Array] attribute domains as returned by build_domains
      # @return [Hash] { attr_index:, rule: {value => class}, correct:, bins: }
      def build_rule(data_examples, attr_index, domains)
        domain = domains[attr_index]
        bins, value_freq = build_frequency(domain, data_examples, attr_index)
        rule, correct_instances = rule_from_frequency(value_freq)
        { attr_index: attr_index, rule: rule, correct: correct_instances, bins: bins }
      end

      # Counts class occurrences per attribute value. Numeric domains
      # (assumed to be [min, max] pairs) are first split into +bin_count+
      # bins, and the bins become the attribute values.
      # @return [Array] [bins or nil, {value => {class => count}}]
      def build_frequency(domain, data_examples, attr_index)
        if numeric_domain?(domain)
          bins = discretize_range(domain, @bin_count)
          value_freq = bins.each_with_object({}) { |b, h| h[b] = Hash.new(0) }
          data_examples.each do |data|
            bin = bins.find { |b| b.include?(data[attr_index]) }
            # Skip values that fall in no bin instead of crashing on a nil key.
            value_freq[bin][data.last] += 1 if bin
          end
          [bins, value_freq]
        else
          value_freq = domain.each_with_object({}) { |v, h| h[v] = Hash.new(0) }
          data_examples.each { |data| value_freq[data[attr_index]][data.last] += 1 }
          [nil, value_freq]
        end
      end

      # Maps each attribute value to its most frequent class.
      # @return [Array] [{value => class}, number of correctly classified items]
      def rule_from_frequency(value_freq)
        rule = {}
        correct_instances = 0
        value_freq.each_pair do |attr, class_freq_hash|
          pair = class_freq_hash.max_by { |_k, v| v }
          # Values (or bins) with no training items produce no rule entry.
          next unless pair

          rule[attr], max_freq = pair
          correct_instances += max_freq
        end
        [rule, correct_instances]
      end

      # True when the domain looks like a numeric [min, max] pair.
      def numeric_domain?(domain)
        domain.is_a?(Array) && domain.length == 2 && domain.all? { |v| v.is_a?(Numeric) }
      end

      # Splits [min, max] into +bins+ equal-width ranges. All ranges are
      # end-exclusive except the last one, which includes +max+.
      # @param range [Array(Numeric, Numeric)] min and max values
      # @param bins [Integer] number of ranges to create
      # @return [Array<Range>]
      def discretize_range(range, bins)
        min, max = range
        step = (max - min).to_f / bins
        ranges = []
        bins.times do |i|
          low = min + (i * step)
          high = i == bins - 1 ? max : min + ((i + 1) * step)
          ranges << (i == bins - 1 ? (low..high) : (low...high))
        end
        ranges
      end
    end
  end
end
|
@@ -1,65 +1,99 @@
|
|
1
|
-
#
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
# Author:: Sergio Fierens (Implementation only, Cendrowska is
|
2
4
|
# the creator of the algorithm)
|
3
5
|
# License:: MPL 1.1
|
4
6
|
# Project:: ai4r
|
5
|
-
# Url::
|
7
|
+
# Url:: https://github.com/SergioFierens/ai4r
|
6
8
|
#
|
7
|
-
# You can redistribute it and/or modify it under the terms of
|
8
|
-
# the Mozilla Public License version 1.1 as published by the
|
9
|
+
# You can redistribute it and/or modify it under the terms of
|
10
|
+
# the Mozilla Public License version 1.1 as published by the
|
9
11
|
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
10
12
|
#
|
11
|
-
# J. Cendrowska (1987). PRISM: An algorithm for inducing modular rules.
|
13
|
+
# J. Cendrowska (1987). PRISM: An algorithm for inducing modular rules.
|
12
14
|
# International Journal of Man-Machine Studies. 27(4):349-370.
|
13
15
|
|
14
|
-
|
15
|
-
|
16
|
+
require_relative '../data/data_set'
|
17
|
+
require_relative '../classifiers/classifier'
|
16
18
|
|
17
19
|
module Ai4r
  module Classifiers
    # = Introduction
    # This is an implementation of the PRISM algorithm (Cendrowska, 1987)
    # Given a set of preclassified examples, it builds a set of rules
    # to predict the class of other instances.
    #
    # J. Cendrowska (1987). PRISM: An algorithm for inducing modular rules.
    # International Journal of Man-Machine Studies. 27(4):349-370.
    class Prism < Classifier
      attr_reader :data_set, :rules, :majority_class

      parameters_info(
        fallback_class: 'Default class returned when no rule matches.',
        bin_count: 'Number of bins used to discretize numeric attributes.',
        default_class: 'Return this value when no rule matches.',
        tie_break: 'Strategy when multiple conditions have equal ratios.'
      )

      # Creates a classifier with no explicit fallback/default class,
      # :first tie-breaking and 10 bins for numeric attributes.
      def initialize
        super()
        @fallback_class = nil
        @default_class = nil
        @tie_break = :first
        @bin_count = 10
        @attr_bins = {}
      end

      # Build a new Prism classifier. You must provide a DataSet instance
      # as parameter. The last attribute of each item is considered as
      # the item class.
      #
      # Rules are induced class by class, following the order of the class
      # domain. The instances covered by each new rule are removed before
      # the next rule is built, so @rules is an ordered decision list.
      # @param data_set [Ai4r::Data::DataSet]
      # @return [Prism] self
      def build(data_set)
        data_set.check_not_empty
        @data_set = data_set

        freqs = Hash.new(0)
        @data_set.data_items.each { |item| freqs[item.last] += 1 }
        @majority_class = freqs.max_by { |_, v| v }&.first
        @fallback_class = @default_class if @default_class
        @fallback_class = @majority_class if @fallback_class.nil?

        domains = @data_set.build_domains
        @attr_bins = build_attr_bins(domains)
        instances = @data_set.data_items.dup
        @rules = []
        domains.last.each do |class_value|
          while class_value?(instances, class_value)
            rule = build_rule(class_value, instances)
            @rules << rule
            instances = instances.reject { |data| matches_conditions(data, rule[:conditions]) }
          end
        end
        self
      end

      # You can evaluate new data, predicting its class.
      # e.g.
      #   classifier.eval(['New York', '<30', 'F'])  # => 'Y'
      #
      # The class of the first matching rule is returned. If no rule
      # matches, +default_class+ is returned when set, otherwise the
      # fallback class.
      # @param instance [Array] attribute values of the item to classify
      # @return [Object] predicted class value
      def eval(instance)
        @rules.each do |rule|
          return rule[:class_value] if matches_conditions(instance, rule[:conditions])
        end
        @default_class || @fallback_class
      end

      # This method returns the generated rules in ruby code.
      # The last rule is emitted as the final +else+ branch.
      # e.g.
      #
      #   classifier.get_rules
      #     # =>  if age_range == '<30' then marketing_target = 'Y'
      #         elsif age_range == '>80' then marketing_target = 'Y'
      #         else marketing_target = 'N'
      #         end
      #
      # It is a nice way to inspect induction results, and also to execute them:
      #     marketing_target = nil
      #     age_range = '<30'
      #     eval(classifier.get_rules)
      #     puts marketing_target
      #       'Y'
      # @return [String] Ruby source for the learned rules
      def get_rules
        first = @rules.first
        lines = ["if #{join_terms(first)} then #{then_clause(first)}"]
        @rules[1...-1].each do |rule|
          lines << "elsif #{join_terms(rule)} then #{then_clause(rule)}"
        end
        lines << "else #{then_clause(@rules.last)}" if @rules.size > 1
        lines << 'end'
        lines.join("\n")
      end

      protected

      # Returns the value of the attribute labelled +attr+ in +data+.
      # @param data [Array]
      # @param attr [Object] attribute label
      # @return [Object]
      def get_attr_value(data, attr)
        data[@data_set.get_index(attr)]
      end

      # True if any of +instances+ belongs to +class_value+.
      # @param instances [Array<Array>]
      # @param class_value [Object]
      # @return [Boolean]
      def class_value?(instances, class_value)
        instances.any? { |data| data.last == class_value }
      end

      # True when no instance matching the rule's conditions belongs to a
      # class other than the rule's class.
      # @param instances [Array<Array>]
      # @param rule [Hash] { class_value:, conditions: }
      # @return [Boolean]
      def perfect?(instances, rule)
        class_value = rule[:class_value]
        instances.none? do |data|
          data.last != class_value && matches_conditions(data, rule[:conditions])
        end
      end

      # True if +data+ satisfies every condition. Range conditions (numeric
      # bins) are tested with Range#include?, others with equality.
      # @param data [Array]
      # @param conditions [Hash] { attr_label => value_or_range }
      # @return [Boolean]
      def matches_conditions(data, conditions)
        conditions.all? do |attr_label, attr_value|
          value = get_attr_value(data, attr_label)
          attr_value.is_a?(Range) ? attr_value.include?(value) : value == attr_value
        end
      end

      # Grows a rule for +class_value+, adding the best condition at each
      # step until the rule is perfect or no attributes are left.
      # @param class_value [Object]
      # @param instances [Array<Array>] instances still to be covered
      # @return [Hash] { class_value:, conditions: }
      def build_rule(class_value, instances)
        rule = { class_value: class_value, conditions: {} }
        rule_instances = instances.dup
        attributes = @data_set.data_labels[0...-1]
        # rule_instances always equals the instances matching rule[:conditions],
        # so checking perfection against it gives the same result as checking
        # the whole instance list, at lower cost.
        until perfect?(rule_instances, rule) || attributes.empty?
          freq_table = build_freq_table(rule_instances, attributes, class_value)
          condition = get_condition(freq_table)
          break unless condition

          rule[:conditions].merge!(condition)
          attributes.delete(condition.keys.first)
          rule_instances = rule_instances.select { |data| matches_conditions(data, condition) }
        end
        rule
      end

      # Returns a structure with the following format:
      # => {attr1_label => { :attr1_value1 => [p, t], attr1_value2 => [p, t], ... },
      #     attr2_label => { :attr2_value1 => [p, t], attr2_value2 => [p, t], ... },
      #     ...
      #     }
      # where p is the number of instances classified as class_value
      # with that attribute value, and t is the total number of instances with
      # that attribute value. Numeric attributes are counted per bin.
      # @param rule_instances [Array<Array>]
      # @param attributes [Array] candidate attribute labels
      # @param class_value [Object]
      # @return [Hash]
      def build_freq_table(rule_instances, attributes, class_value)
        freq_table = {}
        rule_instances.each do |data|
          attributes.each do |attr_label|
            attr_freqs = (freq_table[attr_label] ||= Hash.new([0, 0]))
            value = bin_value(attr_label, get_attr_value(data, attr_label))
            positives, total = attr_freqs[value]
            positives += 1 if data.last == class_value
            attr_freqs[value] = [positives, total + 1]
          end
        end
        freq_table
      end

      # returns a single conditional term: {attrN_label => attrN_valueM}
      # selecting the attribute with higher pt ratio
      # (occurrences of attribute value classified as class_value /
      # occurrences of attribute value)
      # Returns nil when +freq_table+ is empty.
      # @param freq_table [Hash]
      # @return [Hash, nil]
      def get_condition(freq_table)
        best_pt = [0, 0]
        condition = nil
        freq_table.each do |attr_label, attr_freqs|
          attr_freqs.each do |attr_value, pt|
            if better_pt(pt, best_pt)
              condition = { attr_label => attr_value }
              best_pt = pt
            end
          end
        end
        condition
      end

      # pt = [p, t]
      # p = occurrences of attribute value with instance classified as class_value
      # t = occurrences of attribute value
      # a pt is better if:
      #   1- its ratio is higher
      #   2- its ratio is equal, and has a higher p
      #   3- ratio and p are equal, and tie_break is :last
      # Ratios are compared by cross-multiplication to avoid float division.
      # @param pt [Array(Integer, Integer)]
      # @param best_pt [Array(Integer, Integer)]
      # @return [Boolean]
      def better_pt(pt, best_pt)
        return false if pt[1].zero?
        return true if best_pt[1].zero?

        a = pt[0] * best_pt[1]
        b = best_pt[0] * pt[1]
        return true if a > b || (a == b && pt[0] > best_pt[0])
        return true if a == b && pt[0] == best_pt[0] && @tie_break == :last

        false
      end

      # Bins for every numeric attribute, keyed by attribute label.
      # @param domains [Array] attribute domains, class domain last
      # @return [Hash] { attr_label => Array<Range> }
      def build_attr_bins(domains)
        bins = {}
        domains[0...-1].each_with_index do |domain, i|
          next unless numeric_domain?(domain)

          bins[@data_set.data_labels[i]] = discretize_range(domain, @bin_count)
        end
        bins
      end

      # True when the domain looks like a numeric [min, max] pair.
      def numeric_domain?(domain)
        domain.is_a?(Array) && domain.length == 2 && domain.all? { |v| v.is_a?(Numeric) }
      end

      # Replaces a numeric value by the bin (Range) containing it; values of
      # non-binned attributes are returned unchanged.
      def bin_value(attr_label, value)
        bins = @attr_bins[attr_label]
        bins ? bins.find { |b| b.include?(value) } : value
      end

      # Splits [min, max] into +bins+ equal-width ranges. All ranges are
      # end-exclusive except the last one, which includes +max+.
      # @param range [Array(Numeric, Numeric)] min and max values
      # @param bins [Integer] number of ranges to create
      # @return [Array<Range>]
      def discretize_range(range, bins)
        min, max = range
        step = (max - min).to_f / bins
        ranges = []
        bins.times do |i|
          low = min + (i * step)
          high = i == bins - 1 ? max : min + ((i + 1) * step)
          ranges << (i == bins - 1 ? (low..high) : (low...high))
        end
        ranges
      end

      # Renders the rule's conditions as a Ruby boolean expression.
      # A rule without conditions matches everything, so it renders as
      # +true+ rather than an empty (syntactically invalid) condition.
      # @param rule [Hash]
      # @return [String]
      def join_terms(rule)
        terms = rule[:conditions].map do |attr_label, attr_value|
          if attr_value.is_a?(Range)
            "(#{attr_value}).include?(#{attr_label})"
          else
            "#{attr_label} == '#{attr_value}'"
          end
        end
        terms.empty? ? 'true' : terms.join(' and ')
      end

      # Renders the assignment of the rule's class to the class label.
      # @param rule [Hash]
      # @return [String]
      def then_clause(rule)
        "#{@data_set.category_label} = '#{rule[:class_value]}'"
      end
    end
  end
end
|
197
|
-
|
@@ -0,0 +1,72 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
# Author:: OpenAI ChatGPT
|
4
|
+
# License:: MPL 1.1
|
5
|
+
# Project:: ai4r
|
6
|
+
#
|
7
|
+
# A simple Random Forest implementation using ID3 decision trees.
|
8
|
+
|
9
|
+
require_relative 'id3'
|
10
|
+
require_relative '../data/data_set'
|
11
|
+
require_relative '../classifiers/classifier'
|
12
|
+
require_relative 'votes'
|
13
|
+
|
14
|
+
module Ai4r
  module Classifiers
    # RandomForest ensemble classifier built from decision trees.
    #
    # Each tree is an ID3 classifier trained on a sample of the data set
    # drawn with replacement and restricted to a random subset of the
    # attributes. Every tree casts one vote per prediction and the winner
    # is picked by Votes#get_winner.
    class RandomForest < Classifier
      parameters_info n_trees: 'Number of trees to build. Default 10.',
                      sample_size: 'Number of data items for each tree (with replacement). Default: data set size.',
                      feature_fraction:
                        'Fraction of attributes sampled for each tree. Default: sqrt(num_attributes)/num_attributes.',
                      random_seed: 'Seed for reproducible randomness.'

      attr_reader :trees, :features

      # Creates a forest of 10 trees; sample size, feature fraction and
      # seed are derived/left random unless set.
      def initialize
        super()
        @n_trees = 10
        @sample_size = nil
        @feature_fraction = nil
        @random_seed = nil
      end

      # Trains +n_trees+ ID3 trees. The last attribute of each item is
      # considered as the item class.
      # @param data_set [Ai4r::Data::DataSet]
      # @return [RandomForest] self
      # @raise [ArgumentError] if the data set has no attribute besides the class
      def build(data_set)
        data_set.check_not_empty
        num_attributes = data_set.data_labels.length - 1
        if num_attributes < 1
          raise ArgumentError, 'RandomForest needs at least one attribute besides the class'
        end

        rng = @random_seed ? Random.new(@random_seed) : Random.new
        feature_count = feature_count_for(num_attributes)
        # Derive the default per build (without storing it in @sample_size)
        # so rebuilding on another data set does not reuse a stale size.
        sample_size = @sample_size || data_set.data_items.length
        @trees = []
        @features = []
        @n_trees.times do
          # Draw the bootstrap sample before the features to keep the
          # random sequence (and seeded results) stable.
          sampled = Array.new(sample_size) { data_set.data_items.sample(random: rng) }
          feature_idx = (0...num_attributes).to_a.sample(feature_count, random: rng)
          tree_items = sampled.map { |item| feature_idx.map { |i| item[i] } + [item.last] }
          labels = feature_idx.map { |i| data_set.data_labels[i] } + [data_set.data_labels.last]
          ds = Ai4r::Data::DataSet.new(data_items: tree_items, data_labels: labels)
          @trees << ID3.new.build(ds)
          @features << feature_idx
        end
        self
      end

      # Predicts the class of +data+ (attribute values in the original
      # data set order) by letting every tree vote on its own attribute
      # subset.
      # @param data [Array]
      # @return [Object] winning class value
      def eval(data)
        votes = Votes.new
        @trees.each_with_index do |tree, idx|
          sub_data = @features[idx].map { |i| data[i] }
          votes.increment_category(tree.eval(sub_data))
        end
        votes.get_winner
      end

      # @return [String] rule extraction is not supported for forests
      def get_rules
        'RandomForest does not support rule extraction.'
      end

      private

      # Number of attributes sampled per tree: round(num_attributes *
      # feature_fraction), defaulting the fraction to
      # sqrt(num_attributes) / num_attributes, clamped to 1..num_attributes.
      # @param num_attributes [Integer] non-class attribute count (>= 1)
      # @return [Integer]
      def feature_count_for(num_attributes)
        frac = @feature_fraction || (Math.sqrt(num_attributes) / num_attributes)
        (num_attributes * frac).round.clamp(1, num_attributes)
      end
    end
  end
end
|