ai4r 1.12 → 2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. checksums.yaml +7 -0
  2. data/README.md +174 -0
  3. data/examples/classifiers/hyperpipes_data.csv +14 -0
  4. data/examples/classifiers/hyperpipes_example.rb +22 -0
  5. data/examples/classifiers/ib1_example.rb +12 -0
  6. data/examples/classifiers/id3_example.rb +15 -10
  7. data/examples/classifiers/id3_graphviz_example.rb +17 -0
  8. data/examples/classifiers/logistic_regression_example.rb +11 -0
  9. data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
  10. data/examples/classifiers/naive_bayes_example.rb +12 -13
  11. data/examples/classifiers/one_r_example.rb +27 -0
  12. data/examples/classifiers/parameter_tutorial.rb +29 -0
  13. data/examples/classifiers/prism_nominal_example.rb +15 -0
  14. data/examples/classifiers/prism_numeric_example.rb +21 -0
  15. data/examples/classifiers/simple_linear_regression_example.csv +159 -0
  16. data/examples/classifiers/simple_linear_regression_example.rb +18 -0
  17. data/examples/classifiers/zero_and_one_r_example.rb +34 -0
  18. data/examples/classifiers/zero_one_r_data.csv +8 -0
  19. data/examples/clusterers/clusterer_example.rb +62 -0
  20. data/examples/clusterers/dbscan_example.rb +17 -0
  21. data/examples/clusterers/dendrogram_example.rb +17 -0
  22. data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
  23. data/examples/clusterers/kmeans_custom_example.rb +26 -0
  24. data/examples/genetic_algorithm/bitstring_example.rb +41 -0
  25. data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
  26. data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
  27. data/examples/neural_network/backpropagation_example.rb +49 -48
  28. data/examples/neural_network/hopfield_example.rb +45 -0
  29. data/examples/neural_network/patterns_with_base_noise.rb +39 -39
  30. data/examples/neural_network/patterns_with_noise.rb +41 -39
  31. data/examples/neural_network/train_epochs_callback.rb +25 -0
  32. data/examples/neural_network/training_patterns.rb +39 -39
  33. data/examples/neural_network/transformer_text_classification.rb +78 -0
  34. data/examples/neural_network/xor_example.rb +23 -22
  35. data/examples/reinforcement/q_learning_example.rb +10 -0
  36. data/examples/som/som_data.rb +155 -152
  37. data/examples/som/som_multi_node_example.rb +12 -13
  38. data/examples/som/som_single_example.rb +12 -15
  39. data/examples/transformer/decode_classifier_example.rb +68 -0
  40. data/examples/transformer/deterministic_example.rb +10 -0
  41. data/examples/transformer/seq2seq_example.rb +16 -0
  42. data/lib/ai4r/classifiers/classifier.rb +24 -16
  43. data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
  44. data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
  45. data/lib/ai4r/classifiers/ib1.rb +122 -32
  46. data/lib/ai4r/classifiers/id3.rb +527 -144
  47. data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
  48. data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
  49. data/lib/ai4r/classifiers/naive_bayes.rb +112 -48
  50. data/lib/ai4r/classifiers/one_r.rb +112 -44
  51. data/lib/ai4r/classifiers/prism.rb +167 -76
  52. data/lib/ai4r/classifiers/random_forest.rb +72 -0
  53. data/lib/ai4r/classifiers/simple_linear_regression.rb +143 -0
  54. data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
  55. data/lib/ai4r/classifiers/votes.rb +57 -0
  56. data/lib/ai4r/classifiers/zero_r.rb +71 -30
  57. data/lib/ai4r/clusterers/average_linkage.rb +46 -27
  58. data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
  59. data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
  60. data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
  61. data/lib/ai4r/clusterers/clusterer.rb +28 -24
  62. data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
  63. data/lib/ai4r/clusterers/dbscan.rb +134 -0
  64. data/lib/ai4r/clusterers/diana.rb +75 -49
  65. data/lib/ai4r/clusterers/k_means.rb +309 -72
  66. data/lib/ai4r/clusterers/median_linkage.rb +49 -33
  67. data/lib/ai4r/clusterers/single_linkage.rb +196 -88
  68. data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
  69. data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +63 -0
  70. data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
  71. data/lib/ai4r/data/data_set.rb +229 -100
  72. data/lib/ai4r/data/parameterizable.rb +31 -25
  73. data/lib/ai4r/data/proximity.rb +72 -50
  74. data/lib/ai4r/data/statistics.rb +46 -35
  75. data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
  76. data/lib/ai4r/experiment/split.rb +39 -0
  77. data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
  78. data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
  79. data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
  80. data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
  81. data/lib/ai4r/neural_network/activation_functions.rb +37 -0
  82. data/lib/ai4r/neural_network/backpropagation.rb +419 -143
  83. data/lib/ai4r/neural_network/hopfield.rb +175 -58
  84. data/lib/ai4r/neural_network/transformer.rb +194 -0
  85. data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
  86. data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
  87. data/lib/ai4r/reinforcement/q_learning.rb +51 -0
  88. data/lib/ai4r/search/a_star.rb +76 -0
  89. data/lib/ai4r/search/bfs.rb +50 -0
  90. data/lib/ai4r/search/dfs.rb +50 -0
  91. data/lib/ai4r/search/mcts.rb +118 -0
  92. data/lib/ai4r/search.rb +12 -0
  93. data/lib/ai4r/som/distance_metrics.rb +29 -0
  94. data/lib/ai4r/som/layer.rb +28 -17
  95. data/lib/ai4r/som/node.rb +61 -32
  96. data/lib/ai4r/som/som.rb +158 -41
  97. data/lib/ai4r/som/two_phase_layer.rb +21 -25
  98. data/lib/ai4r/version.rb +3 -0
  99. data/lib/ai4r.rb +58 -27
  100. metadata +117 -106
  101. data/README.rdoc +0 -44
  102. data/test/classifiers/hyperpipes_test.rb +0 -84
  103. data/test/classifiers/ib1_test.rb +0 -78
  104. data/test/classifiers/id3_test.rb +0 -208
  105. data/test/classifiers/multilayer_perceptron_test.rb +0 -79
  106. data/test/classifiers/naive_bayes_test.rb +0 -43
  107. data/test/classifiers/one_r_test.rb +0 -62
  108. data/test/classifiers/prism_test.rb +0 -85
  109. data/test/classifiers/zero_r_test.rb +0 -50
  110. data/test/clusterers/average_linkage_test.rb +0 -51
  111. data/test/clusterers/bisecting_k_means_test.rb +0 -66
  112. data/test/clusterers/centroid_linkage_test.rb +0 -53
  113. data/test/clusterers/complete_linkage_test.rb +0 -57
  114. data/test/clusterers/diana_test.rb +0 -69
  115. data/test/clusterers/k_means_test.rb +0 -100
  116. data/test/clusterers/median_linkage_test.rb +0 -53
  117. data/test/clusterers/single_linkage_test.rb +0 -122
  118. data/test/clusterers/ward_linkage_test.rb +0 -53
  119. data/test/clusterers/weighted_average_linkage_test.rb +0 -53
  120. data/test/data/data_set_test.rb +0 -96
  121. data/test/data/proximity_test.rb +0 -81
  122. data/test/data/statistics_test.rb +0 -65
  123. data/test/experiment/classifier_evaluator_test.rb +0 -76
  124. data/test/genetic_algorithm/chromosome_test.rb +0 -57
  125. data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
  126. data/test/neural_network/backpropagation_test.rb +0 -82
  127. data/test/neural_network/hopfield_test.rb +0 -72
  128. data/test/som/som_test.rb +0 -97
@@ -1,118 +1,194 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Sergio Fierens (Implementation only)
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
- # Url:: http://www.ai4r.org/
6
+ # Url:: https://github.com/SergioFierens/ai4r
5
7
  #
6
- # You can redistribute it and/or modify it under the terms of
7
- # the Mozilla Public License version 1.1 as published by the
8
+ # You can redistribute it and/or modify it under the terms of
9
+ # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
11
 
10
12
  require 'set'
11
- require File.dirname(__FILE__) + '/../data/data_set'
12
- require File.dirname(__FILE__) + '/../classifiers/classifier'
13
+ require_relative '../data/data_set'
14
+ require_relative '../classifiers/classifier'
15
+ require_relative '../classifiers/votes'
13
16
 
14
17
  module Ai4r
18
+ # Collection of classifier algorithms.
15
19
  module Classifiers
16
-
17
20
  include Ai4r::Data
18
-
21
+
19
22
  # = Introduction
20
- #
21
- # A fast classifier algorithm, created by Lucio de Souza Coelho
23
+ #
24
+ # A fast classifier algorithm, created by Lucio de Souza Coelho
22
25
  # and Len Trigg.
23
26
  class Hyperpipes < Classifier
24
-
25
27
  attr_reader :data_set, :pipes
26
28
 
29
+ parameters_info tie_break:
30
+ 'Strategy used when more than one class has the same maximal vote. ' \
31
+ 'Valid values are :last (default) and :random.',
32
+ margin: 'Numeric margin added to the bounds of numeric attributes.',
33
+ random_seed: 'Seed for random tie-breaking when tie_break is :random.'
34
+
35
+ # @return [Object]
36
+ def initialize
37
+ super()
38
+ @tie_break = :last
39
+ @margin = 0
40
+ @random_seed = nil
41
+ @rng = nil
42
+ end
43
+
27
44
  # Build a new Hyperpipes classifier. You must provide a DataSet instance
28
- # as parameter. The last attribute of each item is considered as
45
+ # as parameter. The last attribute of each item is considered as
29
46
  # the item class.
47
+ # @param data_set [Object]
48
+ # @return [Object]
30
49
  def build(data_set)
31
50
  data_set.check_not_empty
32
51
  @data_set = data_set
33
52
  @domains = data_set.build_domains
34
-
53
+
35
54
  @pipes = {}
36
- @domains.last.each {|cat| @pipes[cat] = build_pipe(@data_set)}
37
- @data_set.data_items.each {|item| update_pipe(@pipes[item.last], item) }
38
-
39
- return self
55
+ @domains.last.each { |cat| @pipes[cat] = build_pipe(@data_set) }
56
+ @data_set.data_items.each { |item| update_pipe(@pipes[item.last], item) }
57
+
58
+ self
40
59
  end
41
-
60
+
42
61
  # You can evaluate new data, predicting its class.
43
62
  # e.g.
44
- # classifier.eval(['New York', '<30', 'F']) # => 'Y'
63
+ # classifier.eval(['New York', '<30', 'F']) # => 'Y'
64
+ # Tie resolution is controlled by +tie_break+ parameter.
65
+ # @param data [Object]
66
+ # @return [Object]
45
67
  def eval(data)
46
- votes = Hash.new {0}
68
+ votes = Votes.new
47
69
  @pipes.each do |category, pipe|
48
70
  pipe.each_with_index do |bounds, i|
49
71
  if data[i].is_a? Numeric
50
- votes[category]+=1 if data[i]>=bounds[:min] && data[i]<=bounds[:max]
51
- else
52
- votes[category]+=1 if bounds[data[i]]
72
+ votes.increment_category(category) if data[i].between?(bounds[:min], bounds[:max])
73
+ elsif bounds[data[i]]
74
+ votes.increment_category(category)
53
75
  end
54
76
  end
55
77
  end
56
- return votes.to_a.max {|x, y| x.last <=> y.last}.first
78
+ rng = @rng || (@random_seed.nil? ? Random.new : Random.new(@random_seed))
79
+ votes.get_winner(@tie_break, rng: rng)
57
80
  end
58
-
81
+ # rubocop:enable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/PerceivedComplexity
82
+
59
83
  # This method returns the generated rules in ruby code.
60
84
  # e.g.
61
- #
85
+ #
62
86
  # classifier.get_rules
63
87
  # # => if age_range == '<30' then marketing_target = 'Y'
64
88
  # elsif age_range == '[30-50)' then marketing_target = 'N'
65
89
  # elsif age_range == '[50-80]' then marketing_target = 'N'
66
90
  # end
67
91
  #
68
- # It is a nice way to inspect induction results, and also to execute them:
92
+ # It is a nice way to inspect induction results, and also to execute them:
69
93
  # marketing_target = nil
70
- # eval classifier.get_rules
94
+ # eval classifier.get_rules
71
95
  # puts marketing_target
72
96
  # # => 'Y'
97
+ # @return [Object]
98
+ # rubocop:disable Metrics/AbcSize
73
99
  def get_rules
74
100
  rules = []
75
- rules << "votes = Hash.new {0}"
101
+ rules << 'votes = Votes.new'
76
102
  data = @data_set.data_items.first
77
- labels = @data_set.data_labels.collect {|l| l.to_s}
103
+ labels = @data_set.data_labels.collect(&:to_s)
78
104
  @pipes.each do |category, pipe|
79
105
  pipe.each_with_index do |bounds, i|
80
- rule = "votes['#{category}'] += 1 "
81
- if data[i].is_a? Numeric
82
- rule += "if #{labels[i]} >= #{bounds[:min]} && #{labels[i]} <= #{bounds[:max]}"
106
+ rule = "votes.increment_category('#{category}') "
107
+ rule += if data[i].is_a? Numeric
108
+ "if #{labels[i]} >= #{bounds[:min]} && #{labels[i]} <= #{bounds[:max]}"
109
+ else
110
+ "if #{bounds.inspect}[#{labels[i]}]"
111
+ end
112
+ rules << rule
113
+ end
114
+ end
115
+ rules << "#{labels.last} = votes.get_winner(:#{@tie_break})"
116
+ rules.join("\n")
117
+ end
118
+ # rubocop:enable Metrics/AbcSize
119
+ # rubocop:enable Naming/AccessorMethodName
120
+
121
+ # Return a summary representation of all pipes.
122
+ #
123
+ # The returned hash maps each category to another hash where the keys are
124
+ # attribute labels and the values are either numeric ranges
125
+ # `[min, max]` (including the optional margin) or a Set of nominal values.
126
+ #
127
+ # classifier.pipes_summary
128
+ # # => { "Y" => { "city" => #{Set['New York', 'Chicago']},
129
+ # "age" => [18, 85],
130
+ # "gender" => #{Set['M', 'F']} },
131
+ # "N" => { ... } }
132
+ #
133
+ # The optional +margin+ parameter expands numeric bounds by the given
134
+ # fraction. A value of 0.1 moves each bound outward by 10% of the range width.
135
+ # @param margin [Object]
136
+ # @return [Object]
137
+ # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/PerceivedComplexity
138
+ def pipes_summary(margin: 0)
139
+ raise 'Model not built yet' unless @data_set && @pipes
140
+
141
+ labels = @data_set.data_labels[0...-1]
142
+ summary = {}
143
+ @pipes.each do |category, pipe|
144
+ attr_summary = {}
145
+ pipe.each_with_index do |bounds, i|
146
+ if bounds.is_a?(Hash) && bounds.key?(:min) && bounds.key?(:max)
147
+ min = bounds[:min]
148
+ max = bounds[:max]
149
+ range_margin = (max - min) * margin
150
+ attr_summary[labels[i]] = [min - range_margin, max + range_margin]
83
151
  else
84
- rule += "if #{bounds.inspect}[#{labels[i]}]"
152
+ attr_summary[labels[i]] = bounds.select { |_k, v| v }.keys.to_set
85
153
  end
86
- rules << rule
87
154
  end
155
+ summary[category] = attr_summary
88
156
  end
89
- rules << "#{labels.last} = votes.to_a.max {|x, y| x.last <=> y.last}.first"
90
- return rules.join("\n")
157
+ summary
91
158
  end
92
-
159
+ # rubocop:enable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/PerceivedComplexity
160
+
93
161
  protected
94
162
 
163
+ # @param data_set [Object]
164
+ # @return [Object]
95
165
  def build_pipe(data_set)
96
166
  data_set.data_items.first[0...-1].collect do |att|
97
167
  if att.is_a? Numeric
98
- {:min=>1.0/0, :max=>-1.0/0}
168
+ { min: Float::INFINITY, max: -Float::INFINITY }
99
169
  else
100
170
  Hash.new(false)
101
171
  end
102
172
  end
103
173
  end
104
-
174
+
175
+ # @param pipe [Object]
176
+ # @param data_item [Object]
177
+ # @return [Object]
178
+ # rubocop:disable Metrics/AbcSize
105
179
  def update_pipe(pipe, data_item)
106
180
  data_item[0...-1].each_with_index do |att, i|
107
181
  if att.is_a? Numeric
108
- pipe[i][:min] = att if att < pipe[i][:min]
109
- pipe[i][:max] = att if att > pipe[i][:max]
182
+ min_val = att - @margin
183
+ max_val = att + @margin
184
+ pipe[i][:min] = min_val if min_val < pipe[i][:min]
185
+ pipe[i][:max] = max_val if max_val > pipe[i][:max]
110
186
  else
111
187
  pipe[i][att] = true
112
- end
188
+ end
113
189
  end
114
190
  end
115
-
191
+ # rubocop:enable Metrics/AbcSize
116
192
  end
117
193
  end
118
194
  end
@@ -1,21 +1,22 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Sergio Fierens (Implementation only)
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
- # Url:: http://ai4r.org/
6
+ # Url:: https://github.com/SergioFierens/ai4r
5
7
  #
6
- # You can redistribute it and/or modify it under the terms of
7
- # the Mozilla Public License version 1.1 as published by the
8
+ # You can redistribute it and/or modify it under the terms of
9
+ # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
11
 
10
12
  require 'set'
11
- require File.dirname(__FILE__) + '/../data/data_set'
12
- require File.dirname(__FILE__) + '/../classifiers/classifier'
13
+ require_relative '../data/data_set'
14
+ require_relative '../classifiers/classifier'
13
15
 
14
16
  module Ai4r
15
17
  module Classifiers
16
-
17
18
  # = Introduction
18
- #
19
+ #
19
20
  # IB1 algorithm implementation.
20
21
  # IB1 is the simplest instance-based learning (IBL) algorithm.
21
22
  #
@@ -26,45 +27,126 @@ module Ai4r
26
27
  # it normalizes its attributes' ranges, processes instances
27
28
  # incrementally, and has a simple policy for tolerating missing values
28
29
  class IB1 < Classifier
29
-
30
- attr_reader :data_set
30
+ attr_reader :data_set, :min_values, :max_values
31
+
32
+ parameters_info k: 'Number of nearest neighbors to consider. Default is 1.',
33
+ distance_function:
34
+ 'Optional custom distance metric taking two instances.',
35
+ tie_break:
36
+ 'Strategy used when neighbors vote tie. ' \
37
+ 'Valid values are :first (default) and :random.',
38
+ random_seed:
39
+ 'Seed for random tie-breaking when :tie_break is :random.'
40
+
41
+ # @return [Object]
42
+ def initialize
43
+ super()
44
+ @k = 1
45
+ @distance_function = nil
46
+ @tie_break = :first
47
+ @random_seed = nil
48
+ @rng = nil
49
+ end
31
50
 
32
51
  # Build a new IB1 classifier. You must provide a DataSet instance
33
- # as parameter. The last attribute of each item is considered as
52
+ # as parameter. The last attribute of each item is considered as
34
53
  # the item class.
54
+ # @param data_set [Object]
55
+ # @return [Object]
35
56
  def build(data_set)
36
57
  data_set.check_not_empty
37
58
  @data_set = data_set
38
59
  @min_values = Array.new(data_set.data_labels.length)
39
60
  @max_values = Array.new(data_set.data_labels.length)
40
61
  data_set.data_items.each { |data_item| update_min_max(data_item[0...-1]) }
41
- return self
62
+ self
63
+ end
64
+
65
+ # Append a new instance to the internal dataset. The last element is
66
+ # considered the class label. Minimum and maximum values for numeric
67
+ # attributes are updated so that future distance calculations remain
68
+ # normalized.
69
+ # @param data_item [Object]
70
+ # @return [Object]
71
+ def add_instance(data_item)
72
+ @data_set << data_item
73
+ update_min_max(data_item[0...-1])
74
+ self
42
75
  end
43
-
76
+
44
77
  # You can evaluate new data, predicting its class.
45
78
  # e.g.
46
- # classifier.eval(['New York', '<30', 'F']) # => 'Y'
79
+ # classifier.eval(['New York', '<30', 'F']) # => 'Y'
80
+ #
81
+ # Evaluation does not update internal statistics, keeping the
82
+ # classifier state unchanged. Use +update_with_instance+ to
83
+ # incorporate new samples.
47
84
  def eval(data)
48
- update_min_max(data)
49
- min_distance = 1.0/0
50
- klass = nil
51
- @data_set.data_items.each do |train_item|
52
- d = distance(data, train_item)
53
- if d < min_distance
54
- min_distance = d
55
- klass = train_item.last
56
- end
85
+ neighbors = @data_set.data_items.map do |train_item|
86
+ [distance(data, train_item), train_item.last]
87
+ end
88
+ neighbors.sort_by! { |d, _| d }
89
+ k_limit = [@k, @data_set.data_items.length].min
90
+ k_neighbors = neighbors.first(k_limit)
91
+
92
+ # Include any other neighbors tied with the last selected distance
93
+ last_distance = k_neighbors.last[0]
94
+ neighbors[k_limit..].to_a.each do |dist, klass|
95
+ break if dist > last_distance
96
+
97
+ k_neighbors << [dist, klass]
57
98
  end
58
- return klass
99
+
100
+ counts = Hash.new(0)
101
+ k_neighbors.each { |(_dist, klass)| counts[klass] += 1 }
102
+ max_votes = counts.values.max
103
+ tied = counts.select { |_, v| v == max_votes }.keys
104
+
105
+ return tied.first if tied.length == 1
106
+
107
+ rng = @rng || (@random_seed.nil? ? Random.new : Random.new(@random_seed))
108
+
109
+ case @tie_break
110
+ when :random
111
+ tied.sample(random: rng)
112
+ else
113
+ k_neighbors.each { |(_dist, klass)| return klass if tied.include?(klass) }
114
+ end
115
+ end
116
+
117
+ # Returns an array with the +k+ nearest instances from the training set
118
+ # for the given +data+ item. The returned elements are the training data
119
+ # rows themselves, ordered from the closest to the furthest.
120
+ # @param data [Object]
121
+ # @param k_neighbors [Object]
122
+ # @return [Object]
123
+ def neighbors_for(data, k_neighbors)
124
+ update_min_max(data)
125
+ @data_set.data_items
126
+ .map { |train_item| [train_item, distance(data, train_item)] }
127
+ .sort_by(&:last)
128
+ .first(k_neighbors)
129
+ .map(&:first)
130
+ end
131
+
132
+ # Update min/max values with the provided instance attributes. If
133
+ # +learn+ is true, also append the instance to the training set so the
134
+ # classifier learns incrementally.
135
+ def update_with_instance(data_item, learn: false)
136
+ update_min_max(data_item[0...-1])
137
+ @data_set << data_item if learn
138
+ self
59
139
  end
60
-
140
+
61
141
  protected
62
142
 
63
143
  # We keep in the state the min and max value of each attribute,
64
144
  # to provide normalized distances between to values of a numeric attribute
145
+ # @param atts [Object]
146
+ # @return [Object]
65
147
  def update_min_max(atts)
66
148
  atts.each_with_index do |att, i|
67
- if att && att.is_a?(Numeric)
149
+ if att.is_a?(Numeric)
68
150
  @min_values[i] = att if @min_values[i].nil? || @min_values[i] > att
69
151
  @max_values[i] = att if @max_values[i].nil? || @max_values[i] < att
70
152
  end
@@ -80,10 +162,15 @@ module Ai4r
80
162
  # * 1 if both atts are missing
81
163
  # * normalized numeric att value if other att value is missing and > 0.5
82
164
  # * 1.0-normalized numeric att value if other att value is missing and < 0.5
83
- def distance(a, b)
165
+ # @param data_a [Object]
166
+ # @param data_b [Object]
167
+ # @return [Object]
168
+ def distance(data_a, data_b)
169
+ return @distance_function.call(data_a, data_b) if @distance_function
170
+
84
171
  d = 0
85
- a.each_with_index do |att_a, i|
86
- att_b = b[i]
172
+ data_a.each_with_index do |att_a, i|
173
+ att_b = data_b[i]
87
174
  if att_a.nil?
88
175
  if att_b.is_a? Numeric
89
176
  diff = norm(att_b, i)
@@ -93,7 +180,7 @@ module Ai4r
93
180
  end
94
181
  elsif att_a.is_a? Numeric
95
182
  if att_b.is_a? Numeric
96
- diff = norm(att_a, i) - norm(att_b, i);
183
+ diff = norm(att_a, i) - norm(att_b, i)
97
184
  else
98
185
  diff = norm(att_a, i)
99
186
  diff = 1.0 - diff if diff < 0.5
@@ -105,17 +192,20 @@ module Ai4r
105
192
  end
106
193
  d += diff * diff
107
194
  end
108
- return d
195
+ d
109
196
  end
110
197
 
111
198
  # Returns normalized value att
112
199
  #
113
200
  # index is the index of the attribute in the instance.
201
+ # @param att [Object]
202
+ # @param index [Object]
203
+ # @return [Object]
114
204
  def norm(att, index)
115
205
  return 0 if @min_values[index].nil?
116
- return 1.0*(att - @min_values[index]) / (@max_values[index] -@min_values[index]);
206
+
207
+ 1.0 * (att - @min_values[index]) / (@max_values[index] - @min_values[index])
117
208
  end
118
-
119
209
  end
120
210
  end
121
211
  end