ai4r 1.12 → 2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +174 -0
- data/examples/classifiers/hyperpipes_data.csv +14 -0
- data/examples/classifiers/hyperpipes_example.rb +22 -0
- data/examples/classifiers/ib1_example.rb +12 -0
- data/examples/classifiers/id3_example.rb +15 -10
- data/examples/classifiers/id3_graphviz_example.rb +17 -0
- data/examples/classifiers/logistic_regression_example.rb +11 -0
- data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
- data/examples/classifiers/naive_bayes_example.rb +12 -13
- data/examples/classifiers/one_r_example.rb +27 -0
- data/examples/classifiers/parameter_tutorial.rb +29 -0
- data/examples/classifiers/prism_nominal_example.rb +15 -0
- data/examples/classifiers/prism_numeric_example.rb +21 -0
- data/examples/classifiers/simple_linear_regression_example.csv +159 -0
- data/examples/classifiers/simple_linear_regression_example.rb +18 -0
- data/examples/classifiers/zero_and_one_r_example.rb +34 -0
- data/examples/classifiers/zero_one_r_data.csv +8 -0
- data/examples/clusterers/clusterer_example.rb +62 -0
- data/examples/clusterers/dbscan_example.rb +17 -0
- data/examples/clusterers/dendrogram_example.rb +17 -0
- data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
- data/examples/clusterers/kmeans_custom_example.rb +26 -0
- data/examples/genetic_algorithm/bitstring_example.rb +41 -0
- data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
- data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
- data/examples/neural_network/backpropagation_example.rb +49 -48
- data/examples/neural_network/hopfield_example.rb +45 -0
- data/examples/neural_network/patterns_with_base_noise.rb +39 -39
- data/examples/neural_network/patterns_with_noise.rb +41 -39
- data/examples/neural_network/train_epochs_callback.rb +25 -0
- data/examples/neural_network/training_patterns.rb +39 -39
- data/examples/neural_network/transformer_text_classification.rb +78 -0
- data/examples/neural_network/xor_example.rb +23 -22
- data/examples/reinforcement/q_learning_example.rb +10 -0
- data/examples/som/som_data.rb +155 -152
- data/examples/som/som_multi_node_example.rb +12 -13
- data/examples/som/som_single_example.rb +12 -15
- data/examples/transformer/decode_classifier_example.rb +68 -0
- data/examples/transformer/deterministic_example.rb +10 -0
- data/examples/transformer/seq2seq_example.rb +16 -0
- data/lib/ai4r/classifiers/classifier.rb +24 -16
- data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
- data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
- data/lib/ai4r/classifiers/ib1.rb +122 -32
- data/lib/ai4r/classifiers/id3.rb +527 -144
- data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
- data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
- data/lib/ai4r/classifiers/naive_bayes.rb +112 -48
- data/lib/ai4r/classifiers/one_r.rb +112 -44
- data/lib/ai4r/classifiers/prism.rb +167 -76
- data/lib/ai4r/classifiers/random_forest.rb +72 -0
- data/lib/ai4r/classifiers/simple_linear_regression.rb +143 -0
- data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
- data/lib/ai4r/classifiers/votes.rb +57 -0
- data/lib/ai4r/classifiers/zero_r.rb +71 -30
- data/lib/ai4r/clusterers/average_linkage.rb +46 -27
- data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
- data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
- data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
- data/lib/ai4r/clusterers/clusterer.rb +28 -24
- data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
- data/lib/ai4r/clusterers/dbscan.rb +134 -0
- data/lib/ai4r/clusterers/diana.rb +75 -49
- data/lib/ai4r/clusterers/k_means.rb +309 -72
- data/lib/ai4r/clusterers/median_linkage.rb +49 -33
- data/lib/ai4r/clusterers/single_linkage.rb +196 -88
- data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
- data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +63 -0
- data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
- data/lib/ai4r/data/data_set.rb +229 -100
- data/lib/ai4r/data/parameterizable.rb +31 -25
- data/lib/ai4r/data/proximity.rb +72 -50
- data/lib/ai4r/data/statistics.rb +46 -35
- data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
- data/lib/ai4r/experiment/split.rb +39 -0
- data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
- data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
- data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
- data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
- data/lib/ai4r/neural_network/activation_functions.rb +37 -0
- data/lib/ai4r/neural_network/backpropagation.rb +419 -143
- data/lib/ai4r/neural_network/hopfield.rb +175 -58
- data/lib/ai4r/neural_network/transformer.rb +194 -0
- data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
- data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
- data/lib/ai4r/reinforcement/q_learning.rb +51 -0
- data/lib/ai4r/search/a_star.rb +76 -0
- data/lib/ai4r/search/bfs.rb +50 -0
- data/lib/ai4r/search/dfs.rb +50 -0
- data/lib/ai4r/search/mcts.rb +118 -0
- data/lib/ai4r/search.rb +12 -0
- data/lib/ai4r/som/distance_metrics.rb +29 -0
- data/lib/ai4r/som/layer.rb +28 -17
- data/lib/ai4r/som/node.rb +61 -32
- data/lib/ai4r/som/som.rb +158 -41
- data/lib/ai4r/som/two_phase_layer.rb +21 -25
- data/lib/ai4r/version.rb +3 -0
- data/lib/ai4r.rb +58 -27
- metadata +117 -106
- data/README.rdoc +0 -44
- data/test/classifiers/hyperpipes_test.rb +0 -84
- data/test/classifiers/ib1_test.rb +0 -78
- data/test/classifiers/id3_test.rb +0 -208
- data/test/classifiers/multilayer_perceptron_test.rb +0 -79
- data/test/classifiers/naive_bayes_test.rb +0 -43
- data/test/classifiers/one_r_test.rb +0 -62
- data/test/classifiers/prism_test.rb +0 -85
- data/test/classifiers/zero_r_test.rb +0 -50
- data/test/clusterers/average_linkage_test.rb +0 -51
- data/test/clusterers/bisecting_k_means_test.rb +0 -66
- data/test/clusterers/centroid_linkage_test.rb +0 -53
- data/test/clusterers/complete_linkage_test.rb +0 -57
- data/test/clusterers/diana_test.rb +0 -69
- data/test/clusterers/k_means_test.rb +0 -100
- data/test/clusterers/median_linkage_test.rb +0 -53
- data/test/clusterers/single_linkage_test.rb +0 -122
- data/test/clusterers/ward_linkage_test.rb +0 -53
- data/test/clusterers/weighted_average_linkage_test.rb +0 -53
- data/test/data/data_set_test.rb +0 -96
- data/test/data/proximity_test.rb +0 -81
- data/test/data/statistics_test.rb +0 -65
- data/test/experiment/classifier_evaluator_test.rb +0 -76
- data/test/genetic_algorithm/chromosome_test.rb +0 -57
- data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
- data/test/neural_network/backpropagation_test.rb +0 -82
- data/test/neural_network/hopfield_test.rb +0 -72
- data/test/som/som_test.rb +0 -97
@@ -1,118 +1,194 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
# Author:: Sergio Fierens (Implementation only)
|
2
4
|
# License:: MPL 1.1
|
3
5
|
# Project:: ai4r
|
4
|
-
# Url::
|
6
|
+
# Url:: https://github.com/SergioFierens/ai4r
|
5
7
|
#
|
6
|
-
# You can redistribute it and/or modify it under the terms of
|
7
|
-
# the Mozilla Public License version 1.1 as published by the
|
8
|
+
# You can redistribute it and/or modify it under the terms of
|
9
|
+
# the Mozilla Public License version 1.1 as published by the
|
8
10
|
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
11
|
|
10
12
|
require 'set'
|
11
|
-
|
12
|
-
|
13
|
+
require_relative '../data/data_set'
|
14
|
+
require_relative '../classifiers/classifier'
|
15
|
+
require_relative '../classifiers/votes'
|
13
16
|
|
14
17
|
module Ai4r
|
18
|
+
# Collection of classifier algorithms.
|
15
19
|
module Classifiers
|
16
|
-
|
17
20
|
include Ai4r::Data
|
18
|
-
|
21
|
+
|
19
22
|
# = Introduction
|
20
|
-
#
|
21
|
-
# A fast classifier algorithm, created by Lucio de Souza Coelho
|
23
|
+
#
|
24
|
+
# A fast classifier algorithm, created by Lucio de Souza Coelho
|
22
25
|
# and Len Trigg.
|
23
26
|
class Hyperpipes < Classifier
|
24
|
-
|
25
27
|
attr_reader :data_set, :pipes
|
26
28
|
|
29
|
+
parameters_info tie_break:
|
30
|
+
'Strategy used when more than one class has the same maximal vote. ' \
|
31
|
+
'Valid values are :last (default) and :random.',
|
32
|
+
margin: 'Numeric margin added to the bounds of numeric attributes.',
|
33
|
+
random_seed: 'Seed for random tie-breaking when tie_break is :random.'
|
34
|
+
|
35
|
+
# @return [Object]
|
36
|
+
def initialize
|
37
|
+
super()
|
38
|
+
@tie_break = :last
|
39
|
+
@margin = 0
|
40
|
+
@random_seed = nil
|
41
|
+
@rng = nil
|
42
|
+
end
|
43
|
+
|
27
44
|
# Build a new Hyperpipes classifier. You must provide a DataSet instance
|
28
|
-
# as parameter. The last attribute of each item is considered as
|
45
|
+
# as parameter. The last attribute of each item is considered as
|
29
46
|
# the item class.
|
47
|
+
# @param data_set [Object]
|
48
|
+
# @return [Object]
|
30
49
|
def build(data_set)
|
31
50
|
data_set.check_not_empty
|
32
51
|
@data_set = data_set
|
33
52
|
@domains = data_set.build_domains
|
34
|
-
|
53
|
+
|
35
54
|
@pipes = {}
|
36
|
-
@domains.last.each {|cat| @pipes[cat] = build_pipe(@data_set)}
|
37
|
-
@data_set.data_items.each {|item| update_pipe(@pipes[item.last], item) }
|
38
|
-
|
39
|
-
|
55
|
+
@domains.last.each { |cat| @pipes[cat] = build_pipe(@data_set) }
|
56
|
+
@data_set.data_items.each { |item| update_pipe(@pipes[item.last], item) }
|
57
|
+
|
58
|
+
self
|
40
59
|
end
|
41
|
-
|
60
|
+
|
42
61
|
# You can evaluate new data, predicting its class.
|
43
62
|
# e.g.
|
44
|
-
# classifier.eval(['New York', '<30', 'F']) # => 'Y'
|
63
|
+
# classifier.eval(['New York', '<30', 'F']) # => 'Y'
|
64
|
+
# Tie resolution is controlled by +tie_break+ parameter.
|
65
|
+
# @param data [Object]
|
66
|
+
# @return [Object]
|
45
67
|
def eval(data)
|
46
|
-
votes =
|
68
|
+
votes = Votes.new
|
47
69
|
@pipes.each do |category, pipe|
|
48
70
|
pipe.each_with_index do |bounds, i|
|
49
71
|
if data[i].is_a? Numeric
|
50
|
-
votes
|
51
|
-
|
52
|
-
votes
|
72
|
+
votes.increment_category(category) if data[i].between?(bounds[:min], bounds[:max])
|
73
|
+
elsif bounds[data[i]]
|
74
|
+
votes.increment_category(category)
|
53
75
|
end
|
54
76
|
end
|
55
77
|
end
|
56
|
-
|
78
|
+
rng = @rng || (@random_seed.nil? ? Random.new : Random.new(@random_seed))
|
79
|
+
votes.get_winner(@tie_break, rng: rng)
|
57
80
|
end
|
58
|
-
|
81
|
+
# rubocop:enable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/PerceivedComplexity
|
82
|
+
|
59
83
|
# This method returns the generated rules in ruby code.
|
60
84
|
# e.g.
|
61
|
-
#
|
85
|
+
#
|
62
86
|
# classifier.get_rules
|
63
87
|
# # => if age_range == '<30' then marketing_target = 'Y'
|
64
88
|
# elsif age_range == '[30-50)' then marketing_target = 'N'
|
65
89
|
# elsif age_range == '[50-80]' then marketing_target = 'N'
|
66
90
|
# end
|
67
91
|
#
|
68
|
-
# It is a nice way to inspect induction results, and also to execute them:
|
92
|
+
# It is a nice way to inspect induction results, and also to execute them:
|
69
93
|
# marketing_target = nil
|
70
|
-
# eval classifier.get_rules
|
94
|
+
# eval classifier.get_rules
|
71
95
|
# puts marketing_target
|
72
96
|
# # => 'Y'
|
97
|
+
# @return [Object]
|
98
|
+
# rubocop:disable Metrics/AbcSize
|
73
99
|
def get_rules
|
74
100
|
rules = []
|
75
|
-
rules <<
|
101
|
+
rules << 'votes = Votes.new'
|
76
102
|
data = @data_set.data_items.first
|
77
|
-
labels = @data_set.data_labels.collect
|
103
|
+
labels = @data_set.data_labels.collect(&:to_s)
|
78
104
|
@pipes.each do |category, pipe|
|
79
105
|
pipe.each_with_index do |bounds, i|
|
80
|
-
rule = "votes
|
81
|
-
if data[i].is_a? Numeric
|
82
|
-
|
106
|
+
rule = "votes.increment_category('#{category}') "
|
107
|
+
rule += if data[i].is_a? Numeric
|
108
|
+
"if #{labels[i]} >= #{bounds[:min]} && #{labels[i]} <= #{bounds[:max]}"
|
109
|
+
else
|
110
|
+
"if #{bounds.inspect}[#{labels[i]}]"
|
111
|
+
end
|
112
|
+
rules << rule
|
113
|
+
end
|
114
|
+
end
|
115
|
+
rules << "#{labels.last} = votes.get_winner(:#{@tie_break})"
|
116
|
+
rules.join("\n")
|
117
|
+
end
|
118
|
+
# rubocop:enable Metrics/AbcSize
|
119
|
+
# rubocop:enable Naming/AccessorMethodName
|
120
|
+
|
121
|
+
# Return a summary representation of all pipes.
|
122
|
+
#
|
123
|
+
# The returned hash maps each category to another hash where the keys are
|
124
|
+
# attribute labels and the values are either numeric ranges
|
125
|
+
# `[min, max]` (including the optional margin) or a Set of nominal values.
|
126
|
+
#
|
127
|
+
# classifier.pipes_summary
|
128
|
+
# # => { "Y" => { "city" => #{Set['New York', 'Chicago']},
|
129
|
+
# "age" => [18, 85],
|
130
|
+
# "gender" => #{Set['M', 'F']} },
|
131
|
+
# "N" => { ... } }
|
132
|
+
#
|
133
|
+
# The optional +margin+ parameter expands numeric bounds by the given
|
134
|
+
# fraction. A value of 0.1 would enlarge each range by 10%.
|
135
|
+
# @param margin [Object]
|
136
|
+
# @return [Object]
|
137
|
+
# rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/PerceivedComplexity
|
138
|
+
def pipes_summary(margin: 0)
|
139
|
+
raise 'Model not built yet' unless @data_set && @pipes
|
140
|
+
|
141
|
+
labels = @data_set.data_labels[0...-1]
|
142
|
+
summary = {}
|
143
|
+
@pipes.each do |category, pipe|
|
144
|
+
attr_summary = {}
|
145
|
+
pipe.each_with_index do |bounds, i|
|
146
|
+
if bounds.is_a?(Hash) && bounds.key?(:min) && bounds.key?(:max)
|
147
|
+
min = bounds[:min]
|
148
|
+
max = bounds[:max]
|
149
|
+
range_margin = (max - min) * margin
|
150
|
+
attr_summary[labels[i]] = [min - range_margin, max + range_margin]
|
83
151
|
else
|
84
|
-
|
152
|
+
attr_summary[labels[i]] = bounds.select { |_k, v| v }.keys.to_set
|
85
153
|
end
|
86
|
-
rules << rule
|
87
154
|
end
|
155
|
+
summary[category] = attr_summary
|
88
156
|
end
|
89
|
-
|
90
|
-
return rules.join("\n")
|
157
|
+
summary
|
91
158
|
end
|
92
|
-
|
159
|
+
# rubocop:enable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/PerceivedComplexity
|
160
|
+
|
93
161
|
protected
|
94
162
|
|
163
|
+
# @param data_set [Object]
|
164
|
+
# @return [Object]
|
95
165
|
def build_pipe(data_set)
|
96
166
|
data_set.data_items.first[0...-1].collect do |att|
|
97
167
|
if att.is_a? Numeric
|
98
|
-
{:
|
168
|
+
{ min: Float::INFINITY, max: -Float::INFINITY }
|
99
169
|
else
|
100
170
|
Hash.new(false)
|
101
171
|
end
|
102
172
|
end
|
103
173
|
end
|
104
|
-
|
174
|
+
|
175
|
+
# @param pipe [Object]
|
176
|
+
# @param data_item [Object]
|
177
|
+
# @return [Object]
|
178
|
+
# rubocop:disable Metrics/AbcSize
|
105
179
|
def update_pipe(pipe, data_item)
|
106
180
|
data_item[0...-1].each_with_index do |att, i|
|
107
181
|
if att.is_a? Numeric
|
108
|
-
|
109
|
-
|
182
|
+
min_val = att - @margin
|
183
|
+
max_val = att + @margin
|
184
|
+
pipe[i][:min] = min_val if min_val < pipe[i][:min]
|
185
|
+
pipe[i][:max] = max_val if max_val > pipe[i][:max]
|
110
186
|
else
|
111
187
|
pipe[i][att] = true
|
112
|
-
end
|
188
|
+
end
|
113
189
|
end
|
114
190
|
end
|
115
|
-
|
191
|
+
# rubocop:enable Metrics/AbcSize
|
116
192
|
end
|
117
193
|
end
|
118
194
|
end
|
data/lib/ai4r/classifiers/ib1.rb
CHANGED
@@ -1,21 +1,22 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
# Author:: Sergio Fierens (Implementation only)
|
2
4
|
# License:: MPL 1.1
|
3
5
|
# Project:: ai4r
|
4
|
-
# Url::
|
6
|
+
# Url:: https://github.com/SergioFierens/ai4r
|
5
7
|
#
|
6
|
-
# You can redistribute it and/or modify it under the terms of
|
7
|
-
# the Mozilla Public License version 1.1 as published by the
|
8
|
+
# You can redistribute it and/or modify it under the terms of
|
9
|
+
# the Mozilla Public License version 1.1 as published by the
|
8
10
|
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
11
|
|
10
12
|
require 'set'
|
11
|
-
|
12
|
-
|
13
|
+
require_relative '../data/data_set'
|
14
|
+
require_relative '../classifiers/classifier'
|
13
15
|
|
14
16
|
module Ai4r
|
15
17
|
module Classifiers
|
16
|
-
|
17
18
|
# = Introduction
|
18
|
-
#
|
19
|
+
#
|
19
20
|
# IB1 algorithm implementation.
|
20
21
|
# IB1 is the simplest instance-based learning (IBL) algorithm.
|
21
22
|
#
|
@@ -26,45 +27,126 @@ module Ai4r
|
|
26
27
|
# it normalizes its attributes' ranges, processes instances
|
27
28
|
# incrementally, and has a simple policy for tolerating missing values
|
28
29
|
class IB1 < Classifier
|
29
|
-
|
30
|
-
|
30
|
+
attr_reader :data_set, :min_values, :max_values
|
31
|
+
|
32
|
+
parameters_info k: 'Number of nearest neighbors to consider. Default is 1.',
|
33
|
+
distance_function:
|
34
|
+
'Optional custom distance metric taking two instances.',
|
35
|
+
tie_break:
|
36
|
+
'Strategy used when neighbors vote tie. ' \
|
37
|
+
'Valid values are :first (default) and :random.',
|
38
|
+
random_seed:
|
39
|
+
'Seed for random tie-breaking when :tie_break is :random.'
|
40
|
+
|
41
|
+
# @return [Object]
|
42
|
+
def initialize
|
43
|
+
super()
|
44
|
+
@k = 1
|
45
|
+
@distance_function = nil
|
46
|
+
@tie_break = :first
|
47
|
+
@random_seed = nil
|
48
|
+
@rng = nil
|
49
|
+
end
|
31
50
|
|
32
51
|
# Build a new IB1 classifier. You must provide a DataSet instance
|
33
|
-
# as parameter. The last attribute of each item is considered as
|
52
|
+
# as parameter. The last attribute of each item is considered as
|
34
53
|
# the item class.
|
54
|
+
# @param data_set [Object]
|
55
|
+
# @return [Object]
|
35
56
|
def build(data_set)
|
36
57
|
data_set.check_not_empty
|
37
58
|
@data_set = data_set
|
38
59
|
@min_values = Array.new(data_set.data_labels.length)
|
39
60
|
@max_values = Array.new(data_set.data_labels.length)
|
40
61
|
data_set.data_items.each { |data_item| update_min_max(data_item[0...-1]) }
|
41
|
-
|
62
|
+
self
|
63
|
+
end
|
64
|
+
|
65
|
+
# Append a new instance to the internal dataset. The last element is
|
66
|
+
# considered the class label. Minimum and maximum values for numeric
|
67
|
+
# attributes are updated so that future distance calculations remain
|
68
|
+
# normalized.
|
69
|
+
# @param data_item [Object]
|
70
|
+
# @return [Object]
|
71
|
+
def add_instance(data_item)
|
72
|
+
@data_set << data_item
|
73
|
+
update_min_max(data_item[0...-1])
|
74
|
+
self
|
42
75
|
end
|
43
|
-
|
76
|
+
|
44
77
|
# You can evaluate new data, predicting its class.
|
45
78
|
# e.g.
|
46
|
-
# classifier.eval(['New York', '<30', 'F']) # => 'Y'
|
79
|
+
# classifier.eval(['New York', '<30', 'F']) # => 'Y'
|
80
|
+
#
|
81
|
+
# Evaluation does not update internal statistics, keeping the
|
82
|
+
# classifier state unchanged. Use +update_with_instance+ to
|
83
|
+
# incorporate new samples.
|
47
84
|
def eval(data)
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
85
|
+
neighbors = @data_set.data_items.map do |train_item|
|
86
|
+
[distance(data, train_item), train_item.last]
|
87
|
+
end
|
88
|
+
neighbors.sort_by! { |d, _| d }
|
89
|
+
k_limit = [@k, @data_set.data_items.length].min
|
90
|
+
k_neighbors = neighbors.first(k_limit)
|
91
|
+
|
92
|
+
# Include any other neighbors tied with the last selected distance
|
93
|
+
last_distance = k_neighbors.last[0]
|
94
|
+
neighbors[k_limit..].to_a.each do |dist, klass|
|
95
|
+
break if dist > last_distance
|
96
|
+
|
97
|
+
k_neighbors << [dist, klass]
|
57
98
|
end
|
58
|
-
|
99
|
+
|
100
|
+
counts = Hash.new(0)
|
101
|
+
k_neighbors.each { |(_dist, klass)| counts[klass] += 1 }
|
102
|
+
max_votes = counts.values.max
|
103
|
+
tied = counts.select { |_, v| v == max_votes }.keys
|
104
|
+
|
105
|
+
return tied.first if tied.length == 1
|
106
|
+
|
107
|
+
rng = @rng || (@random_seed.nil? ? Random.new : Random.new(@random_seed))
|
108
|
+
|
109
|
+
case @tie_break
|
110
|
+
when :random
|
111
|
+
tied.sample(random: rng)
|
112
|
+
else
|
113
|
+
k_neighbors.each { |(_dist, klass)| return klass if tied.include?(klass) }
|
114
|
+
end
|
115
|
+
end
|
116
|
+
|
117
|
+
# Returns an array with the +k+ nearest instances from the training set
|
118
|
+
# for the given +data+ item. The returned elements are the training data
|
119
|
+
# rows themselves, ordered from the closest to the furthest.
|
120
|
+
# @param data [Object]
|
121
|
+
# @param k [Object]
|
122
|
+
# @return [Object]
|
123
|
+
def neighbors_for(data, k_neighbors)
|
124
|
+
update_min_max(data)
|
125
|
+
@data_set.data_items
|
126
|
+
.map { |train_item| [train_item, distance(data, train_item)] }
|
127
|
+
.sort_by(&:last)
|
128
|
+
.first(k_neighbors)
|
129
|
+
.map(&:first)
|
130
|
+
end
|
131
|
+
|
132
|
+
# Update min/max values with the provided instance attributes. If
|
133
|
+
# +learn+ is true, also append the instance to the training set so the
|
134
|
+
# classifier learns incrementally.
|
135
|
+
def update_with_instance(data_item, learn: false)
|
136
|
+
update_min_max(data_item[0...-1])
|
137
|
+
@data_set << data_item if learn
|
138
|
+
self
|
59
139
|
end
|
60
|
-
|
140
|
+
|
61
141
|
protected
|
62
142
|
|
63
143
|
# We keep in the state the min and max value of each attribute,
|
64
144
|
# to provide normalized distances between to values of a numeric attribute
|
145
|
+
# @param atts [Object]
|
146
|
+
# @return [Object]
|
65
147
|
def update_min_max(atts)
|
66
148
|
atts.each_with_index do |att, i|
|
67
|
-
if att
|
149
|
+
if att.is_a?(Numeric)
|
68
150
|
@min_values[i] = att if @min_values[i].nil? || @min_values[i] > att
|
69
151
|
@max_values[i] = att if @max_values[i].nil? || @max_values[i] < att
|
70
152
|
end
|
@@ -80,10 +162,15 @@ module Ai4r
|
|
80
162
|
# * 1 if both atts are missing
|
81
163
|
# * normalized numeric att value if other att value is missing and > 0.5
|
82
164
|
# * 1.0-normalized numeric att value if other att value is missing and < 0.5
|
83
|
-
|
165
|
+
# @param a [Object]
|
166
|
+
# @param b [Object]
|
167
|
+
# @return [Object]
|
168
|
+
def distance(data_a, data_b)
|
169
|
+
return @distance_function.call(data_a, data_b) if @distance_function
|
170
|
+
|
84
171
|
d = 0
|
85
|
-
|
86
|
-
att_b =
|
172
|
+
data_a.each_with_index do |att_a, i|
|
173
|
+
att_b = data_b[i]
|
87
174
|
if att_a.nil?
|
88
175
|
if att_b.is_a? Numeric
|
89
176
|
diff = norm(att_b, i)
|
@@ -93,7 +180,7 @@ module Ai4r
|
|
93
180
|
end
|
94
181
|
elsif att_a.is_a? Numeric
|
95
182
|
if att_b.is_a? Numeric
|
96
|
-
diff = norm(att_a, i) - norm(att_b, i)
|
183
|
+
diff = norm(att_a, i) - norm(att_b, i)
|
97
184
|
else
|
98
185
|
diff = norm(att_a, i)
|
99
186
|
diff = 1.0 - diff if diff < 0.5
|
@@ -105,17 +192,20 @@ module Ai4r
|
|
105
192
|
end
|
106
193
|
d += diff * diff
|
107
194
|
end
|
108
|
-
|
195
|
+
d
|
109
196
|
end
|
110
197
|
|
111
198
|
# Returns normalized value att
|
112
199
|
#
|
113
200
|
# index is the index of the attribute in the instance.
|
201
|
+
# @param att [Object]
|
202
|
+
# @param index [Object]
|
203
|
+
# @return [Object]
|
114
204
|
def norm(att, index)
|
115
205
|
return 0 if @min_values[index].nil?
|
116
|
-
|
206
|
+
|
207
|
+
1.0 * (att - @min_values[index]) / (@max_values[index] - @min_values[index])
|
117
208
|
end
|
118
|
-
|
119
209
|
end
|
120
210
|
end
|
121
211
|
end
|