ai4ruby 1.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. data/README.rdoc +47 -0
  2. data/examples/classifiers/id3_data.csv +121 -0
  3. data/examples/classifiers/id3_example.rb +29 -0
  4. data/examples/classifiers/naive_bayes_data.csv +11 -0
  5. data/examples/classifiers/naive_bayes_example.rb +16 -0
  6. data/examples/classifiers/results.txt +31 -0
  7. data/examples/genetic_algorithm/genetic_algorithm_example.rb +37 -0
  8. data/examples/genetic_algorithm/travel_cost.csv +16 -0
  9. data/examples/neural_network/backpropagation_example.rb +67 -0
  10. data/examples/neural_network/patterns_with_base_noise.rb +68 -0
  11. data/examples/neural_network/patterns_with_noise.rb +66 -0
  12. data/examples/neural_network/training_patterns.rb +68 -0
  13. data/examples/neural_network/xor_example.rb +35 -0
  14. data/examples/som/som_data.rb +156 -0
  15. data/examples/som/som_multi_node_example.rb +22 -0
  16. data/examples/som/som_single_example.rb +24 -0
  17. data/lib/ai4r.rb +33 -0
  18. data/lib/ai4r/classifiers/classifier.rb +62 -0
  19. data/lib/ai4r/classifiers/hyperpipes.rb +118 -0
  20. data/lib/ai4r/classifiers/ib1.rb +121 -0
  21. data/lib/ai4r/classifiers/id3.rb +326 -0
  22. data/lib/ai4r/classifiers/multilayer_perceptron.rb +135 -0
  23. data/lib/ai4r/classifiers/naive_bayes.rb +259 -0
  24. data/lib/ai4r/classifiers/one_r.rb +110 -0
  25. data/lib/ai4r/classifiers/prism.rb +197 -0
  26. data/lib/ai4r/classifiers/zero_r.rb +73 -0
  27. data/lib/ai4r/clusterers/average_linkage.rb +59 -0
  28. data/lib/ai4r/clusterers/bisecting_k_means.rb +93 -0
  29. data/lib/ai4r/clusterers/centroid_linkage.rb +66 -0
  30. data/lib/ai4r/clusterers/clusterer.rb +61 -0
  31. data/lib/ai4r/clusterers/complete_linkage.rb +67 -0
  32. data/lib/ai4r/clusterers/diana.rb +139 -0
  33. data/lib/ai4r/clusterers/k_means.rb +126 -0
  34. data/lib/ai4r/clusterers/median_linkage.rb +61 -0
  35. data/lib/ai4r/clusterers/single_linkage.rb +194 -0
  36. data/lib/ai4r/clusterers/ward_linkage.rb +64 -0
  37. data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +31 -0
  38. data/lib/ai4r/clusterers/weighted_average_linkage.rb +61 -0
  39. data/lib/ai4r/data/data_set.rb +266 -0
  40. data/lib/ai4r/data/parameterizable.rb +64 -0
  41. data/lib/ai4r/data/proximity.rb +100 -0
  42. data/lib/ai4r/data/statistics.rb +77 -0
  43. data/lib/ai4r/experiment/classifier_evaluator.rb +95 -0
  44. data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +270 -0
  45. data/lib/ai4r/neural_network/backpropagation.rb +326 -0
  46. data/lib/ai4r/neural_network/hopfield.rb +149 -0
  47. data/lib/ai4r/som/layer.rb +68 -0
  48. data/lib/ai4r/som/node.rb +96 -0
  49. data/lib/ai4r/som/som.rb +155 -0
  50. data/lib/ai4r/som/two_phase_layer.rb +90 -0
  51. data/test/classifiers/hyperpipes_test.rb +84 -0
  52. data/test/classifiers/ib1_test.rb +78 -0
  53. data/test/classifiers/id3_test.rb +208 -0
  54. data/test/classifiers/multilayer_perceptron_test.rb +79 -0
  55. data/test/classifiers/naive_bayes_test.rb +43 -0
  56. data/test/classifiers/one_r_test.rb +62 -0
  57. data/test/classifiers/prism_test.rb +85 -0
  58. data/test/classifiers/zero_r_test.rb +49 -0
  59. data/test/clusterers/average_linkage_test.rb +51 -0
  60. data/test/clusterers/bisecting_k_means_test.rb +66 -0
  61. data/test/clusterers/centroid_linkage_test.rb +53 -0
  62. data/test/clusterers/complete_linkage_test.rb +57 -0
  63. data/test/clusterers/diana_test.rb +69 -0
  64. data/test/clusterers/k_means_test.rb +100 -0
  65. data/test/clusterers/median_linkage_test.rb +53 -0
  66. data/test/clusterers/single_linkage_test.rb +122 -0
  67. data/test/clusterers/ward_linkage_hierarchical_test.rb +61 -0
  68. data/test/clusterers/ward_linkage_test.rb +53 -0
  69. data/test/clusterers/weighted_average_linkage_test.rb +53 -0
  70. data/test/data/data_set_test.rb +96 -0
  71. data/test/data/proximity_test.rb +81 -0
  72. data/test/data/statistics_test.rb +65 -0
  73. data/test/experiment/classifier_evaluator_test.rb +76 -0
  74. data/test/genetic_algorithm/chromosome_test.rb +58 -0
  75. data/test/genetic_algorithm/genetic_algorithm_test.rb +81 -0
  76. data/test/neural_network/backpropagation_test.rb +82 -0
  77. data/test/neural_network/hopfield_test.rb +72 -0
  78. data/test/som/som_test.rb +97 -0
  79. metadata +168 -0
@@ -0,0 +1,78 @@
1
+ # Author:: Sergio Fierens
2
+ # License:: MPL 1.1
3
+ # Project:: ai4r
4
+ # Url:: http://ai4r.rubyforge.org/
5
+ #
6
+ # You can redistribute it and/or modify it under the terms of
7
+ # the Mozilla Public License version 1.1 as published by the
8
+ # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
+
10
+ require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/ib1'
11
+ require 'test/unit'
12
+
13
# Test-only reopening of IB1: expose its fitted state (the training data set
# and the per-attribute min/max bounds asserted in IB1Test#test_build) through
# read/write accessors.
class Ai4r::Classifiers::IB1
  attr_accessor :data_set
  attr_accessor :min_values
  attr_accessor :max_values
end

# Make the classifier and data-set classes available without their namespace.
include Ai4r::Classifiers
include Ai4r::Data
19
+
20
# Unit tests for the IB1 instance-based classifier on a small mixed
# nominal/numeric marketing data set.
class IB1Test < Test::Unit::TestCase

  # Column names; the last column is the class value.
  @@data_labels = %w[city age gender marketing_target]

  # Thirteen rows of [city, age, gender, marketing_target].
  @@data_items = [
    ['New York', 25, 'M', 'Y'],
    ['New York', 23, 'M', 'Y'],
    ['New York', 18, 'M', 'Y'],
    ['Chicago',  43, 'M', 'Y'],
    ['New York', 34, 'F', 'N'],
    ['Chicago',  33, 'F', 'Y'],
    ['New York', 31, 'F', 'N'],
    ['Chicago',  55, 'M', 'N'],
    ['New York', 58, 'F', 'N'],
    ['New York', 59, 'M', 'N'],
    ['Chicago',  71, 'M', 'N'],
    ['New York', 60, 'F', 'N'],
    ['Chicago',  85, 'F', 'Y']
  ]

  # Make IB1's protected helpers (e.g. norm and distance, used below)
  # callable, then fit a fresh classifier on the fixture for every test.
  def setup
    IB1.send(:public, *IB1.protected_instance_methods)
    @data_set = DataSet.new(:data_items => @@data_items, :data_labels => @@data_labels)
    @classifier = IB1.new.build(@data_set)
  end

  # Building on an empty data set is rejected; otherwise only the numeric
  # column (age) receives min/max bounds.
  def test_build
    assert_raise(ArgumentError) { IB1.new.build(DataSet.new) }
    assert(@classifier.data_set)
    assert_equal([nil, 18, nil, nil], @classifier.min_values)
    assert_equal([nil, 85, nil, nil], @classifier.max_values)
  end

  # Nominal values normalise to 0; age 55 is expected to land at
  # (55 - 18) / (85 - 18) ~= 0.5522 given the bounds checked above.
  def test_norm
    assert_equal(0, @classifier.norm('Chicago', 0))
    assert_in_delta(0.5522, @classifier.norm(55, 1), 0.0001)
    assert_equal(0, @classifier.norm('F', 0))
  end

  # Distance against a fixed training row: identical attributes give 0, a
  # mismatched or missing nominal value gives 1, and a numeric gap gives a
  # fractional distance.
  def test_distance
    target = ['Chicago', 55, 'M', 'N']
    assert_equal(0, @classifier.distance(['Chicago', 55, 'M'], target))
    assert_equal(1, @classifier.distance([nil, 55, 'M'], target))
    assert_equal(1, @classifier.distance(['New York', 55, 'M'], target))
    assert_in_delta(0.2728, @classifier.distance(['Chicago', 20, 'M'], target), 0.0001)
  end

  # Classification of a handful of queries, checked in a fixed order.
  def test_eval
    classifier = IB1.new.build(@data_set)
    assert(classifier)
    [
      [['Chicago', 55, 'M'], 'N'],
      [['New York', 35, 'F'], 'N'],
      [['New York', 25, 'M'], 'Y'],
      [['Chicago', 85, 'F'], 'Y']
    ].each do |query, expected|
      assert_equal(expected, classifier.eval(query))
    end
  end

end
77
+
78
+
@@ -0,0 +1,208 @@
1
+ # id3_test.rb
2
+ #
3
+ # This is a unit test file for the ID3 algorithm (Quinlan) implemented
4
+ # in ai4r
5
+ #
6
+ # Author:: Sergio Fierens
7
+ # License:: MPL 1.1
8
+ # Project:: ai4r
9
+ # Url:: http://ai4r.rubyforge.org/
10
+ #
11
+ # You can redistribute it and/or modify it under the terms of
12
+ # the Mozilla Public License version 1.1 as published by the
13
+ # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
14
+
15
+ require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/id3'
16
+ require 'test/unit'
17
+
18
# Column names for the ID3 marketing fixture; the last column is the class.
DATA_LABELS = %w[city age_range gender marketing_target]

# Fifteen training rows of [city, age_range, gender, marketing_target].
DATA_ITEMS = [
  ['New York', '<30',     'M', 'Y'],
  ['Chicago',  '<30',     'M', 'Y'],
  ['Chicago',  '<30',     'F', 'Y'],
  ['New York', '<30',     'M', 'Y'],
  ['New York', '<30',     'M', 'Y'],
  ['Chicago',  '[30-50)', 'M', 'Y'],
  ['New York', '[30-50)', 'F', 'N'],
  ['Chicago',  '[30-50)', 'F', 'Y'],
  ['New York', '[30-50)', 'F', 'N'],
  ['Chicago',  '[50-80]', 'M', 'N'],
  ['New York', '[50-80]', 'F', 'N'],
  ['New York', '[50-80]', 'M', 'N'],
  ['Chicago',  '[50-80]', 'M', 'N'],
  ['New York', '[50-80]', 'F', 'N'],
  ['Chicago',  '>80',     'F', 'Y']
]

# Expected result of splitting DATA_ITEMS on attribute 0 (city): the
# New York rows, then the Chicago rows, each group in DATA_ITEMS order.
SPLIT_DATA_ITEMS_BY_CITY = [
  [
    ['New York', '<30',     'M', 'Y'],
    ['New York', '<30',     'M', 'Y'],
    ['New York', '<30',     'M', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['New York', '[30-50)', 'F', 'N'],
    ['New York', '[50-80]', 'F', 'N'],
    ['New York', '[50-80]', 'M', 'N'],
    ['New York', '[50-80]', 'F', 'N']
  ],
  [
    ['Chicago', '<30',     'M', 'Y'],
    ['Chicago', '<30',     'F', 'Y'],
    ['Chicago', '[30-50)', 'M', 'Y'],
    ['Chicago', '[30-50)', 'F', 'Y'],
    ['Chicago', '[50-80]', 'M', 'N'],
    ['Chicago', '[50-80]', 'M', 'N'],
    ['Chicago', '>80',     'F', 'Y']
  ]
]

# Expected result of splitting DATA_ITEMS on attribute 1 (age_range): one
# group per range ('<30', '[30-50)', '[50-80]', '>80'), rows in DATA_ITEMS order.
SPLIT_DATA_ITEMS_BY_AGE = [
  [
    ['New York', '<30', 'M', 'Y'],
    ['Chicago',  '<30', 'M', 'Y'],
    ['Chicago',  '<30', 'F', 'Y'],
    ['New York', '<30', 'M', 'Y'],
    ['New York', '<30', 'M', 'Y']
  ],
  [
    ['Chicago',  '[30-50)', 'M', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['Chicago',  '[30-50)', 'F', 'Y'],
    ['New York', '[30-50)', 'F', 'N']
  ],
  [
    ['Chicago',  '[50-80]', 'M', 'N'],
    ['New York', '[50-80]', 'F', 'N'],
    ['New York', '[50-80]', 'M', 'N'],
    ['Chicago',  '[50-80]', 'M', 'N'],
    ['New York', '[50-80]', 'F', 'N']
  ],
  [
    ['Chicago', '>80', 'F', 'Y']
  ]
]

# Ruby source the trained tree is expected to emit from get_rules: one rule
# per line, newline-separated, with no trailing newline.
EXPECTED_RULES_STRING = [
  "if age_range=='<30' then marketing_target='Y'",
  "elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y'",
  "elsif age_range=='[30-50)' and city=='New York' then marketing_target='N'",
  "elsif age_range=='[50-80]' then marketing_target='N'",
  "elsif age_range=='>80' then marketing_target='Y'",
  "else raise 'There was not enough information during training to do a proper induction for this data element' end"
].join("\n")
84
+
85
include Ai4r::Classifiers
include Ai4r::Data

# Unit tests for the ID3 decision-tree classifier, driven by the DATA_ITEMS /
# DATA_LABELS fixture and the expected SPLIT_* and EXPECTED_RULES_STRING
# constants defined at the top of this file.
class ID3Test < Test::Unit::TestCase

  # Make ID3's protected and private instance methods public so helpers such
  # as domain, freq_grid and entropy can be called directly, whatever their
  # declared visibility. This runs before every test so no test depends on
  # another having run first. It used to live in test_build, which only
  # worked while test_build happened to execute before the other tests.
  def setup
    Ai4r::Classifiers::ID3.send(:public, *Ai4r::Classifiers::ID3.protected_instance_methods)
    Ai4r::Classifiers::ID3.send(:public, *Ai4r::Classifiers::ID3.private_instance_methods)
  end

  # Building from the labelled fixture yields a classifier object.
  def test_build
    assert_not_nil build_labelled_id3
  end

  # log2(0) is expected to be 0.0 rather than -Infinity. log2(3) is checked
  # with a two-sided tolerance.
  def test_log2
    assert_equal 1.0, ID3.log2(2)
    assert_equal 0.0, ID3.log2(0)
    assert_in_delta 1.585, ID3.log2(3), 0.001
  end

  def test_sum
    assert_equal 28, ID3.sum([5, 0, 22, 1])
    assert_equal 0, ID3.sum([])
  end

  # Without explicit labels the data set uses generated names; explicit
  # labels are kept as given.
  def test_data_labels
    id3 = ID3.new.build(DataSet.new(:data_items => DATA_ITEMS))
    expected_default = ['attribute_1', 'attribute_2', 'attribute_3', 'class_value']
    assert_equal(expected_default, id3.data_set.data_labels)
    id3 = build_labelled_id3
    assert_equal(DATA_LABELS, id3.data_set.data_labels)
  end

  def test_domain
    id3 = build_labelled_id3
    expected_domain = [["New York", "Chicago"], ["<30", "[30-50)", "[50-80]", ">80"], ["M", "F"], ["Y", "N"]]
    assert_equal expected_domain, id3.domain(DATA_ITEMS)
  end

  # For each value of an attribute, the grid holds [count of Y, count of N].
  def test_grid
    id3 = build_labelled_id3
    domain = id3.domain(DATA_ITEMS)
    assert_equal [[3, 5], [5, 2]], id3.freq_grid(0, DATA_ITEMS, domain)
    assert_equal [[5, 0], [2, 2], [0, 5], [1, 0]], id3.freq_grid(1, DATA_ITEMS, domain)
  end

  # Weighted entropy of each attribute's frequency grid, checked to four
  # decimals in BOTH directions. The previous `expected - actual < delta`
  # form accepted any actual value larger than the expectation.
  def test_entropy
    id3 = build_labelled_id3
    domain = id3.domain(DATA_ITEMS)
    [[0, 0.9118], [1, 0.2667], [2, 0.9688]].each do |attr_index, expected_entropy|
      freq_grid = id3.freq_grid(attr_index, DATA_ITEMS, domain)
      assert_in_delta expected_entropy, id3.entropy(freq_grid, DATA_ITEMS.length), 0.0001
    end
  end

  # age_range (index 1) has the lowest entropy; the remaining indices are
  # chosen once earlier ones are excluded.
  def test_min_entropy_index
    id3 = build_labelled_id3
    domain = id3.domain(DATA_ITEMS)
    assert_equal 1, id3.min_entropy_index(DATA_ITEMS, domain)
    assert_equal 0, id3.min_entropy_index(DATA_ITEMS, domain, [1])
    assert_equal 2, id3.min_entropy_index(DATA_ITEMS, domain, [1, 0])
  end

  def test_split_data_examples
    id3 = build_labelled_id3
    domain = id3.domain(DATA_ITEMS)
    res = id3.split_data_examples(DATA_ITEMS, domain, 0)
    assert_equal(SPLIT_DATA_ITEMS_BY_CITY, res)
    res = id3.split_data_examples(DATA_ITEMS, domain, 1)
    assert_equal(SPLIT_DATA_ITEMS_BY_AGE, res)
  end

  def test_most_freq
    id3 = build_labelled_id3
    domain = id3.domain(DATA_ITEMS)
    assert_equal 'Y', id3.most_freq(DATA_ITEMS, domain)
    assert_equal 'Y', id3.most_freq(SPLIT_DATA_ITEMS_BY_AGE[3], domain)
    assert_equal 'N', id3.most_freq(SPLIT_DATA_ITEMS_BY_AGE[2], domain)
  end

  def test_get_rules
    assert_equal [["marketing_target='N'"]], CategoryNode.new('marketing_target', 'N').get_rules
    id3 = build_labelled_id3
    assert_equal EXPECTED_RULES_STRING, id3.get_rules
  end

  def test_eval
    id3 = build_labelled_id3
    #if age_range='<30' then marketing_target='Y'
    assert_equal 'Y', id3.eval(['New York', '<30', 'F'])
    assert_equal 'Y', id3.eval(['Chicago', '<30', 'M'])
    #if age_range='[30-50)' and city='Chicago' then marketing_target='Y'
    assert_equal 'Y', id3.eval(['Chicago', '[30-50)', 'F'])
    assert_equal 'Y', id3.eval(['Chicago', '[30-50)', 'M'])
    #if age_range='[30-50)' and city='New York' then marketing_target='N'
    assert_equal 'N', id3.eval(['New York', '[30-50)', 'F'])
    assert_equal 'N', id3.eval(['New York', '[30-50)', 'M'])
    #if age_range='[50-80]' then marketing_target='N'
    assert_equal 'N', id3.eval(['New York', '[50-80]', 'F'])
    assert_equal 'N', id3.eval(['Chicago', '[50-80]', 'M'])
    #if age_range='>80' then marketing_target='Y'
    assert_equal 'Y', id3.eval(['New York', '>80', 'M'])
    assert_equal 'Y', id3.eval(['Chicago', '>80', 'F'])
  end

  # The rule string is plain Ruby that reads the attribute locals and
  # assigns marketing_target. marketing_target must already exist as a local
  # so the eval'd assignment updates it instead of creating an eval-local
  # variable. The local names must match the data labels.
  def test_rules_eval
    id3 = build_labelled_id3
    #if age_range='<30' then marketing_target='Y'
    age_range = '<30'
    marketing_target = nil
    eval id3.get_rules
    assert_equal 'Y', marketing_target
    #if age_range='[30-50)' and city='New York' then marketing_target='N'
    age_range = '[30-50)'
    city = 'New York'
    eval id3.get_rules
    assert_equal 'N', marketing_target
  end

  private

  # Fresh ID3 classifier trained on the labelled fixture.
  def build_labelled_id3
    ID3.new.build(DataSet.new(:data_items => DATA_ITEMS, :data_labels => DATA_LABELS))
  end
end
207
+
208
+
@@ -0,0 +1,79 @@
1
+ require 'test/unit'
2
+ require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/multilayer_perceptron'
3
+ require File.dirname(__FILE__) + '/../../lib/ai4r/data/data_set'
4
+
5
# Test-only reopening of MultilayerPerceptron: add read/write accessors for
# its internal state and make two helper methods public so the tests can
# call them directly.
class Ai4r::Classifiers::MultilayerPerceptron
  attr_accessor :data_set, :class_value, :network
  attr_accessor :domains, :outputs
  public :get_max_index, :data_to_output
end
11
+
12
# Unit tests for the MultilayerPerceptron classifier, which wraps a neural
# network (Backpropagation by default) around categorical data.
class MultilayerPerceptronTest < Test::Unit::TestCase

  include Ai4r::Classifiers
  include Ai4r::Data

  # Seven unlabelled rows of [city, age_range, gender, class].
  @@data_set = DataSet.new(:data_items => [
    ['New York', '<30', 'M', 'Y'],
    ['Chicago', '<30', 'M', 'Y'],
    ['New York', '<30', 'M', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['Chicago', '[30-50)', 'F', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['Chicago', '[50-80]', 'M', 'N']
  ])

  # Default configuration of a freshly created classifier.
  def test_initialize
    classifier = MultilayerPerceptron.new
    assert_equal 1, classifier.active_node_value
    assert_equal 0, classifier.inactive_node_value
    assert_equal Ai4r::NeuralNetwork::Backpropagation, classifier.network_class
    assert_equal [], classifier.hidden_layers
    assert classifier.network_parameters
    assert classifier.network_parameters.empty?
    assert classifier.training_iterations > 1
  end

  # Network shape is [inputs, hidden layers..., outputs]. For this fixture
  # that is presumably one input per distinct attribute value
  # (2 cities + 3 age ranges + 2 genders = 7) and one output per class (2).
  def test_build
    assert_raise(ArgumentError) { MultilayerPerceptron.new.build(DataSet.new) }
    classifier = MultilayerPerceptron.new
    classifier.training_iterations = 1
    classifier.build(@@data_set)
    assert_equal [7,2], classifier.network.structure
    classifier.hidden_layers = [6, 4]
    classifier.build(@@data_set)
    assert_equal [7,6,4,2], classifier.network.structure
  end

  # NOTE(review): relies on training with the default iteration count
  # converging on this tiny data set.
  def test_eval
    classifier = MultilayerPerceptron.new.build(@@data_set)
    assert classifier
    assert_equal('N', classifier.eval(['Chicago', '[50-80]', 'M']))
    assert_equal('N', classifier.eval(['New York', '[30-50)', 'F']))
    assert_equal('Y', classifier.eval(['New York', '<30', 'M']))
    assert_equal('Y', classifier.eval(['Chicago', '[30-50)', 'F']))
  end

  # A neural network cannot be expressed as rules, so get_rules yields code
  # containing a raise.
  def test_get_rules
    assert_match(/raise/, MultilayerPerceptron.new.get_rules)
  end

  # Index of the largest element; the first index wins when it is the max.
  def test_get_max_index
    classifier = MultilayerPerceptron.new
    assert_equal(0, classifier.get_max_index([3, 1, 0.2, -9, 0, 2.99]))
    assert_equal(2, classifier.get_max_index([3, 1, 5, -9, 0, 2.99]))
    assert_equal(5, classifier.get_max_index([3, 1, 5, -9, 0, 6]))
  end

  # One-hot encoding of a class value against the class domain, which is set
  # here as the last entry of domains. outputs is set once (it was previously
  # assigned twice).
  def test_data_to_output
    classifier = MultilayerPerceptron.new
    classifier.outputs = 4
    classifier.domains = [nil, nil, nil, ["A", "B", "C", "D"]]
    assert_equal([1,0,0,0], classifier.data_to_output("A"))
    assert_equal([0,0,1,0], classifier.data_to_output("C"))
    assert_equal([0,0,0,1], classifier.data_to_output("D"))
  end

end
79
+
@@ -0,0 +1,43 @@
1
+ require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/naive_bayes'
2
+ require File.dirname(__FILE__) + '/../../lib/ai4r/data/data_set'
3
+ require 'test/unit'
4
+
5
include Ai4r::Classifiers
include Ai4r::Data

# Unit tests for the NaiveBayes classifier on a small car data set whose
# class column is "Stolen?".
class NaiveBayesTest < Test::Unit::TestCase

  @@data_labels = [ "Color","Type","Origin","Stolen?" ]

  @@data_items = [
    ["Red",   "Sports", "Domestic", "Yes"],
    ["Red",   "Sports", "Domestic", "No"],
    ["Red",   "Sports", "Domestic", "Yes"],
    ["Yellow","Sports", "Domestic", "No"],
    ["Yellow","Sports", "Imported", "Yes"],
    ["Yellow","SUV",    "Imported", "No"],
    ["Yellow","SUV",    "Imported", "Yes"],
    ["Yellow","Sports", "Domestic", "No"],
    ["Red",   "SUV",    "Imported", "No"],
    ["Red",   "Sports", "Imported", "Yes"]
  ]

  # Train a classifier with :m => 3 before each test. m is presumably the
  # m-estimate smoothing parameter; confirm in naive_bayes.rb. The previous
  # version first built an empty DataSet that was immediately overwritten.
  def setup
    @data_set = DataSet.new(:data_items => @@data_items, :data_labels => @@data_labels)
    @b = NaiveBayes.new.set_parameters({:m=>3}).build @data_set
  end

  def test_eval
    result = @b.eval(["Red", "SUV", "Domestic"])
    assert_equal "No", result
  end

  # The probability map has one entry per class value.
  def test_get_probability_map
    map = @b.get_probability_map(["Red", "SUV", "Domestic"])
    assert_equal 2, map.keys.length
    assert_in_delta 0.42, map["Yes"], 0.1
    assert_in_delta 0.58, map["No"], 0.1
  end

end
@@ -0,0 +1,62 @@
1
+ require 'test/unit'
2
+ require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/one_r'
3
+
4
# Unit tests for the OneR classifier, which bases its prediction on a single
# attribute (exposed as rule[:attr_index]).
class OneRTest < Test::Unit::TestCase

  include Ai4r::Classifiers
  include Ai4r::Data

  # Seven rows of [city, age_range, gender, marketing_target].
  @@data_examples = [
    ['New York', '<30',     'M', 'Y'],
    ['Chicago',  '<30',     'M', 'Y'],
    ['New York', '<30',     'M', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['Chicago',  '[30-50)', 'F', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['Chicago',  '[50-80]', 'M', 'N']
  ]

  @@data_labels = %w[city age_range gender marketing_target]

  # An empty data set is rejected. Unlabelled data gets generated labels,
  # labelled data keeps its own, and the chosen attribute is age_range.
  def test_build
    assert_raise(ArgumentError) { OneR.new.build(DataSet.new) }

    unlabeled = OneR.new.build(DataSet.new(:data_items => @@data_examples))
    assert_not_nil(unlabeled.data_set.data_labels)
    assert_not_nil(unlabeled.rule)
    assert_equal("attribute_1", unlabeled.data_set.data_labels.first)
    assert_equal("class_value", unlabeled.data_set.data_labels.last)

    labeled = OneR.new.build(DataSet.new(:data_items => @@data_examples,
                                         :data_labels => @@data_labels))
    assert_not_nil(labeled.data_set.data_labels)
    assert_not_nil(labeled.rule)
    assert_equal("city", labeled.data_set.data_labels.first)
    assert_equal("marketing_target", labeled.data_set.data_labels.last)
    assert_equal(1, labeled.rule[:attr_index])
  end

  def test_eval
    one_r = OneR.new.build(DataSet.new(:data_items => @@data_examples))
    [
      [['New York', '<30', 'M'], "Y"],
      [['New York', '[30-50)', 'M'], "N"],
      [['Chicago', '[50-80]', 'M'], "N"]
    ].each do |query, expected|
      assert_equal(expected, one_r.eval(query))
    end
  end

  # get_rules returns Ruby source that reads age_range and assigns
  # marketing_target. Both locals are declared up front so the eval'd code
  # reads and writes them. Their names must match the data labels.
  def test_get_rules
    one_r = OneR.new.build(DataSet.new(:data_items => @@data_examples,
                                       :data_labels => @@data_labels))
    marketing_target = nil
    age_range = nil
    eval(one_r.get_rules)
    assert_nil(marketing_target)

    [['<30', "Y"], ['[30-50)', "N"], ['[50-80]', "N"]].each do |range, expected|
      age_range = range
      eval(one_r.get_rules)
      assert_equal(expected, marketing_target)
    end
  end

end
62
+