ai4ruby 1.11

Sign up to get free protection for your applications and to get access to all the features.
Files changed (79)
  1. data/README.rdoc +47 -0
  2. data/examples/classifiers/id3_data.csv +121 -0
  3. data/examples/classifiers/id3_example.rb +29 -0
  4. data/examples/classifiers/naive_bayes_data.csv +11 -0
  5. data/examples/classifiers/naive_bayes_example.rb +16 -0
  6. data/examples/classifiers/results.txt +31 -0
  7. data/examples/genetic_algorithm/genetic_algorithm_example.rb +37 -0
  8. data/examples/genetic_algorithm/travel_cost.csv +16 -0
  9. data/examples/neural_network/backpropagation_example.rb +67 -0
  10. data/examples/neural_network/patterns_with_base_noise.rb +68 -0
  11. data/examples/neural_network/patterns_with_noise.rb +66 -0
  12. data/examples/neural_network/training_patterns.rb +68 -0
  13. data/examples/neural_network/xor_example.rb +35 -0
  14. data/examples/som/som_data.rb +156 -0
  15. data/examples/som/som_multi_node_example.rb +22 -0
  16. data/examples/som/som_single_example.rb +24 -0
  17. data/lib/ai4r.rb +33 -0
  18. data/lib/ai4r/classifiers/classifier.rb +62 -0
  19. data/lib/ai4r/classifiers/hyperpipes.rb +118 -0
  20. data/lib/ai4r/classifiers/ib1.rb +121 -0
  21. data/lib/ai4r/classifiers/id3.rb +326 -0
  22. data/lib/ai4r/classifiers/multilayer_perceptron.rb +135 -0
  23. data/lib/ai4r/classifiers/naive_bayes.rb +259 -0
  24. data/lib/ai4r/classifiers/one_r.rb +110 -0
  25. data/lib/ai4r/classifiers/prism.rb +197 -0
  26. data/lib/ai4r/classifiers/zero_r.rb +73 -0
  27. data/lib/ai4r/clusterers/average_linkage.rb +59 -0
  28. data/lib/ai4r/clusterers/bisecting_k_means.rb +93 -0
  29. data/lib/ai4r/clusterers/centroid_linkage.rb +66 -0
  30. data/lib/ai4r/clusterers/clusterer.rb +61 -0
  31. data/lib/ai4r/clusterers/complete_linkage.rb +67 -0
  32. data/lib/ai4r/clusterers/diana.rb +139 -0
  33. data/lib/ai4r/clusterers/k_means.rb +126 -0
  34. data/lib/ai4r/clusterers/median_linkage.rb +61 -0
  35. data/lib/ai4r/clusterers/single_linkage.rb +194 -0
  36. data/lib/ai4r/clusterers/ward_linkage.rb +64 -0
  37. data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +31 -0
  38. data/lib/ai4r/clusterers/weighted_average_linkage.rb +61 -0
  39. data/lib/ai4r/data/data_set.rb +266 -0
  40. data/lib/ai4r/data/parameterizable.rb +64 -0
  41. data/lib/ai4r/data/proximity.rb +100 -0
  42. data/lib/ai4r/data/statistics.rb +77 -0
  43. data/lib/ai4r/experiment/classifier_evaluator.rb +95 -0
  44. data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +270 -0
  45. data/lib/ai4r/neural_network/backpropagation.rb +326 -0
  46. data/lib/ai4r/neural_network/hopfield.rb +149 -0
  47. data/lib/ai4r/som/layer.rb +68 -0
  48. data/lib/ai4r/som/node.rb +96 -0
  49. data/lib/ai4r/som/som.rb +155 -0
  50. data/lib/ai4r/som/two_phase_layer.rb +90 -0
  51. data/test/classifiers/hyperpipes_test.rb +84 -0
  52. data/test/classifiers/ib1_test.rb +78 -0
  53. data/test/classifiers/id3_test.rb +208 -0
  54. data/test/classifiers/multilayer_perceptron_test.rb +79 -0
  55. data/test/classifiers/naive_bayes_test.rb +43 -0
  56. data/test/classifiers/one_r_test.rb +62 -0
  57. data/test/classifiers/prism_test.rb +85 -0
  58. data/test/classifiers/zero_r_test.rb +49 -0
  59. data/test/clusterers/average_linkage_test.rb +51 -0
  60. data/test/clusterers/bisecting_k_means_test.rb +66 -0
  61. data/test/clusterers/centroid_linkage_test.rb +53 -0
  62. data/test/clusterers/complete_linkage_test.rb +57 -0
  63. data/test/clusterers/diana_test.rb +69 -0
  64. data/test/clusterers/k_means_test.rb +100 -0
  65. data/test/clusterers/median_linkage_test.rb +53 -0
  66. data/test/clusterers/single_linkage_test.rb +122 -0
  67. data/test/clusterers/ward_linkage_hierarchical_test.rb +61 -0
  68. data/test/clusterers/ward_linkage_test.rb +53 -0
  69. data/test/clusterers/weighted_average_linkage_test.rb +53 -0
  70. data/test/data/data_set_test.rb +96 -0
  71. data/test/data/proximity_test.rb +81 -0
  72. data/test/data/statistics_test.rb +65 -0
  73. data/test/experiment/classifier_evaluator_test.rb +76 -0
  74. data/test/genetic_algorithm/chromosome_test.rb +58 -0
  75. data/test/genetic_algorithm/genetic_algorithm_test.rb +81 -0
  76. data/test/neural_network/backpropagation_test.rb +82 -0
  77. data/test/neural_network/hopfield_test.rb +72 -0
  78. data/test/som/som_test.rb +97 -0
  79. metadata +168 -0
@@ -0,0 +1,78 @@
1
+ # Author:: Sergio Fierens
2
+ # License:: MPL 1.1
3
+ # Project:: ai4r
4
+ # Url:: http://ai4r.rubyforge.org/
5
+ #
6
+ # You can redistribute it and/or modify it under the terms of
7
+ # the Mozilla Public License version 1.1 as published by the
8
+ # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
+
10
+ require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/ib1'
11
+ require 'test/unit'
12
+
13
+ class Ai4r::Classifiers::IB1
14
+ attr_accessor :data_set, :min_values, :max_values
15
+ end
16
+
17
+ include Ai4r::Classifiers
18
+ include Ai4r::Data
19
+
20
+ class IB1Test < Test::Unit::TestCase
21
+
22
+ @@data_labels = [ 'city', 'age', 'gender', 'marketing_target' ]
23
+
24
+ @@data_items = [['New York', 25, 'M', 'Y'],
25
+ ['New York', 23, 'M', 'Y'],
26
+ ['New York', 18, 'M', 'Y'],
27
+ ['Chicago', 43, 'M', 'Y'],
28
+ ['New York', 34, 'F', 'N'],
29
+ ['Chicago', 33, 'F', 'Y'],
30
+ ['New York', 31, 'F', 'N'],
31
+ ['Chicago', 55, 'M', 'N'],
32
+ ['New York', 58, 'F', 'N'],
33
+ ['New York', 59, 'M', 'N'],
34
+ ['Chicago', 71, 'M', 'N'],
35
+ ['New York', 60, 'F', 'N'],
36
+ ['Chicago', 85, 'F', 'Y']
37
+ ]
38
+
39
+
40
+ def setup
41
+ IB1.send(:public, *IB1.protected_instance_methods)
42
+ @data_set = DataSet.new(:data_items => @@data_items, :data_labels => @@data_labels)
43
+ @classifier = IB1.new.build(@data_set)
44
+ end
45
+
46
+ def test_build
47
+ assert_raise(ArgumentError) { IB1.new.build(DataSet.new) }
48
+ assert @classifier.data_set
49
+ assert_equal [nil, 18, nil, nil], @classifier.min_values
50
+ assert_equal [nil, 85, nil, nil], @classifier.max_values
51
+ end
52
+
53
+ def test_norm
54
+ assert_equal(0,@classifier.norm('Chicago', 0))
55
+ assert_in_delta(0.5522,@classifier.norm(55, 1),0.0001)
56
+ assert_equal(0,@classifier.norm('F', 0))
57
+ end
58
+
59
+ def test_distance
60
+ item = ['Chicago', 55, 'M', 'N']
61
+ assert_equal(0, @classifier.distance(['Chicago', 55, 'M'], item))
62
+ assert_equal(1, @classifier.distance([nil, 55, 'M'], item))
63
+ assert_equal(1, @classifier.distance(['New York', 55, 'M'], item))
64
+ assert_in_delta(0.2728, @classifier.distance(['Chicago', 20, 'M'], item), 0.0001)
65
+ end
66
+
67
+ def test_eval
68
+ classifier = IB1.new.build(@data_set)
69
+ assert classifier
70
+ assert_equal('N', classifier.eval(['Chicago', 55, 'M']))
71
+ assert_equal('N', classifier.eval(['New York', 35, 'F']))
72
+ assert_equal('Y', classifier.eval(['New York', 25, 'M']))
73
+ assert_equal('Y', classifier.eval(['Chicago', 85, 'F']))
74
+ end
75
+
76
+ end
77
+
78
+
@@ -0,0 +1,208 @@
1
+ # id3_test.rb
2
+ #
3
+ # This is a unit test file for the ID3 algorithm (Quinlan) implemented
4
+ # in ai4r
5
+ #
6
+ # Author:: Sergio Fierens
7
+ # License:: MPL 1.1
8
+ # Project:: ai4r
9
+ # Url:: http://ai4r.rubyforge.org/
10
+ #
11
+ # You can redistribute it and/or modify it under the terms of
12
+ # the Mozilla Public License version 1.1 as published by the
13
+ # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
14
+
15
+ require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/id3'
16
+ require 'test/unit'
17
+
18
+ DATA_LABELS = [ 'city', 'age_range', 'gender', 'marketing_target' ]
19
+
20
+ DATA_ITEMS = [ ['New York', '<30', 'M', 'Y'],
21
+ ['Chicago', '<30', 'M', 'Y'],
22
+ ['Chicago', '<30', 'F', 'Y'],
23
+ ['New York', '<30', 'M', 'Y'],
24
+ ['New York', '<30', 'M', 'Y'],
25
+ ['Chicago', '[30-50)', 'M', 'Y'],
26
+ ['New York', '[30-50)', 'F', 'N'],
27
+ ['Chicago', '[30-50)', 'F', 'Y'],
28
+ ['New York', '[30-50)', 'F', 'N'],
29
+ ['Chicago', '[50-80]', 'M', 'N'],
30
+ ['New York', '[50-80]', 'F', 'N'],
31
+ ['New York', '[50-80]', 'M', 'N'],
32
+ ['Chicago', '[50-80]', 'M', 'N'],
33
+ ['New York', '[50-80]', 'F', 'N'],
34
+ ['Chicago', '>80', 'F', 'Y']
35
+ ]
36
+
37
+ SPLIT_DATA_ITEMS_BY_CITY = [ [
38
+ ["New York", "<30", "M", "Y"],
39
+ ["New York", "<30", "M", "Y"],
40
+ ["New York", "<30", "M", "Y"],
41
+ ["New York", "[30-50)", "F", "N"],
42
+ ["New York", "[30-50)", "F", "N"],
43
+ ["New York", "[50-80]", "F", "N"],
44
+ ["New York", "[50-80]", "M", "N"],
45
+ ["New York", "[50-80]", "F", "N"]],
46
+ [
47
+ ["Chicago", "<30", "M", "Y"],
48
+ ["Chicago", "<30", "F", "Y"],
49
+ ["Chicago", "[30-50)", "M", "Y"],
50
+ ["Chicago", "[30-50)", "F", "Y"],
51
+ ["Chicago", "[50-80]", "M", "N"],
52
+ ["Chicago", "[50-80]", "M", "N"],
53
+ ["Chicago", ">80", "F", "Y"]]
54
+ ]
55
+
56
+ SPLIT_DATA_ITEMS_BY_AGE = [ [
57
+ ["New York", "<30", "M", "Y"],
58
+ ["Chicago", "<30", "M", "Y"],
59
+ ["Chicago", "<30", "F", "Y"],
60
+ ["New York", "<30", "M", "Y"],
61
+ ["New York", "<30", "M", "Y"]],
62
+ [
63
+ ["Chicago", "[30-50)", "M", "Y"],
64
+ ["New York", "[30-50)", "F", "N"],
65
+ ["Chicago", "[30-50)", "F", "Y"],
66
+ ["New York", "[30-50)", "F", "N"]],
67
+ [
68
+ ["Chicago", "[50-80]", "M", "N"],
69
+ ["New York", "[50-80]", "F", "N"],
70
+ ["New York", "[50-80]", "M", "N"],
71
+ ["Chicago", "[50-80]", "M", "N"],
72
+ ["New York", "[50-80]", "F", "N"]],
73
+ [
74
+ ["Chicago", ">80", "F", "Y"]]
75
+ ]
76
+
77
+ EXPECTED_RULES_STRING =
78
+ "if age_range=='<30' then marketing_target='Y'\n"+
79
+ "elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y'\n"+
80
+ "elsif age_range=='[30-50)' and city=='New York' then marketing_target='N'\n"+
81
+ "elsif age_range=='[50-80]' then marketing_target='N'\n"+
82
+ "elsif age_range=='>80' then marketing_target='Y'\n"+
83
+ "else raise 'There was not enough information during training to do a proper induction for this data element' end"
84
+
85
+ include Ai4r::Classifiers
86
+ include Ai4r::Data
87
+
88
+ class ID3Test < Test::Unit::TestCase
89
+
90
+ def test_build
91
+ Ai4r::Classifiers::ID3.send(:public, *Ai4r::Classifiers::ID3.protected_instance_methods)
92
+ Ai4r::Classifiers::ID3.send(:public, *Ai4r::Classifiers::ID3.private_instance_methods)
93
+ end
94
+
95
+ def test_log2
96
+ assert_equal 1.0, ID3.log2(2)
97
+ assert_equal 0.0, ID3.log2(0)
98
+ assert 1.585 - ID3.log2(3) < 0.001
99
+ end
100
+
101
+ def test_sum
102
+ assert_equal 28, ID3.sum([5, 0, 22, 1])
103
+ assert_equal 0, ID3.sum([])
104
+ end
105
+
106
+ def test_data_labels
107
+ id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS))
108
+ expected_default = [ 'attribute_1', 'attribute_2', 'attribute_3', 'class_value' ]
109
+ assert_equal(expected_default, id3.data_set.data_labels)
110
+ id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
111
+ assert_equal(DATA_LABELS, id3.data_set.data_labels)
112
+ end
113
+
114
+ def test_domain
115
+ id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
116
+ expected_domain = [["New York", "Chicago"], ["<30", "[30-50)", "[50-80]", ">80"], ["M", "F"], ["Y", "N"]]
117
+ assert_equal expected_domain, id3.domain(DATA_ITEMS)
118
+ end
119
+
120
+ def test_grid
121
+ id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
122
+ expected_grid = [[3, 5], [5, 2]]
123
+ domain = id3.domain(DATA_ITEMS)
124
+ assert_equal expected_grid, id3.freq_grid(0, DATA_ITEMS, domain)
125
+ expected_grid = [[5, 0], [2, 2], [0, 5], [1, 0]]
126
+ assert_equal expected_grid, id3.freq_grid(1, DATA_ITEMS, domain)
127
+ end
128
+
129
+ def test_entropy
130
+ id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
131
+ expected_entropy = 0.9118
132
+ domain = id3.domain(DATA_ITEMS)
133
+ freq_grid = id3.freq_grid(0, DATA_ITEMS, domain)
134
+ assert expected_entropy - id3.entropy(freq_grid, DATA_ITEMS.length) < 0.0001
135
+ expected_entropy = 0.2667
136
+ freq_grid = id3.freq_grid(1, DATA_ITEMS, domain)
137
+ assert expected_entropy - id3.entropy(freq_grid, DATA_ITEMS.length) < 0.0001
138
+ expected_entropy = 0.9688
139
+ freq_grid = id3.freq_grid(2, DATA_ITEMS, domain)
140
+ assert expected_entropy - id3.entropy(freq_grid, DATA_ITEMS.length) < 0.0001
141
+ end
142
+
143
+ def test_min_entropy_index
144
+ id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
145
+ domain = id3.domain(DATA_ITEMS)
146
+ assert_equal 1, id3.min_entropy_index(DATA_ITEMS, domain)
147
+ assert_equal 0, id3.min_entropy_index(DATA_ITEMS, domain, [1])
148
+ assert_equal 2, id3.min_entropy_index(DATA_ITEMS, domain, [1, 0])
149
+ end
150
+
151
+ def test_split_data_examples
152
+ id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
153
+ domain = id3.domain(DATA_ITEMS)
154
+ res = id3.split_data_examples(DATA_ITEMS, domain, 0)
155
+ assert_equal(SPLIT_DATA_ITEMS_BY_CITY, res)
156
+ res = id3.split_data_examples(DATA_ITEMS, domain, 1)
157
+ assert_equal(SPLIT_DATA_ITEMS_BY_AGE, res)
158
+ end
159
+
160
+ def test_most_freq
161
+ id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
162
+ domain = id3.domain(DATA_ITEMS)
163
+ assert_equal 'Y', id3.most_freq(DATA_ITEMS, domain)
164
+ assert_equal 'Y', id3.most_freq(SPLIT_DATA_ITEMS_BY_AGE[3], domain)
165
+ assert_equal 'N', id3.most_freq(SPLIT_DATA_ITEMS_BY_AGE[2], domain)
166
+ end
167
+
168
+ def test_get_rules
169
+ assert_equal [["marketing_target='N'"]], CategoryNode.new('marketing_target', 'N').get_rules
170
+ id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
171
+ assert_equal EXPECTED_RULES_STRING, id3.get_rules
172
+ end
173
+
174
+ def test_eval
175
+ id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
176
+ #if age_range='<30' then marketing_target='Y'
177
+ assert_equal 'Y', id3.eval(['New York', '<30', 'F'])
178
+ assert_equal 'Y', id3.eval(['Chicago', '<30', 'M'])
179
+ #if age_range='[30-50)' and city='Chicago' then marketing_target='Y'
180
+ assert_equal 'Y', id3.eval(['Chicago', '[30-50)', 'F'])
181
+ assert_equal 'Y', id3.eval(['Chicago', '[30-50)', 'M'])
182
+ #if age_range='[30-50)' and city='New York' then marketing_target='N'
183
+ assert_equal 'N', id3.eval(['New York', '[30-50)', 'F'])
184
+ assert_equal 'N', id3.eval(['New York', '[30-50)', 'M'])
185
+ #if age_range='[50-80]' then marketing_target='N'
186
+ assert_equal 'N', id3.eval(['New York', '[50-80]', 'F'])
187
+ assert_equal 'N', id3.eval(['Chicago', '[50-80]', 'M'])
188
+ #if age_range='>80' then marketing_target='Y'
189
+ assert_equal 'Y', id3.eval(['New York', '>80', 'M'])
190
+ assert_equal 'Y', id3.eval(['Chicago', '>80', 'F'])
191
+ end
192
+
193
+ def test_rules_eval
194
+ id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
195
+ #if age_range='<30' then marketing_target='Y'
196
+ age_range = '<30'
197
+ marketing_target = nil
198
+ eval id3.get_rules
199
+ assert_equal 'Y', marketing_target
200
+ #if age_range='[30-50)' and city='New York' then marketing_target='N'
201
+ age_range='[30-50)'
202
+ city='New York'
203
+ eval id3.get_rules
204
+ assert_equal 'N', marketing_target
205
+ end
206
+ end
207
+
208
+
@@ -0,0 +1,79 @@
1
+ require 'test/unit'
2
+ require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/multilayer_perceptron'
3
+ require File.dirname(__FILE__) + '/../../lib/ai4r/data/data_set'
4
+
5
+ # Make all accessors and methods public
6
+ class Ai4r::Classifiers::MultilayerPerceptron
7
+ attr_accessor :data_set, :class_value, :network, :domains, :outputs
8
+ public :get_max_index
9
+ public :data_to_output
10
+ end
11
+
12
+ class MultilayerPerceptronTest < Test::Unit::TestCase
13
+
14
+ include Ai4r::Classifiers
15
+ include Ai4r::Data
16
+
17
+ @@data_set = DataSet.new(:data_items =>[ ['New York', '<30', 'M', 'Y'],
18
+ ['Chicago', '<30', 'M', 'Y'],
19
+ ['New York', '<30', 'M', 'Y'],
20
+ ['New York', '[30-50)', 'F', 'N'],
21
+ ['Chicago', '[30-50)', 'F', 'Y'],
22
+ ['New York', '[30-50)', 'F', 'N'],
23
+ ['Chicago', '[50-80]', 'M', 'N'],
24
+ ])
25
+
26
+ def test_initialize
27
+ classifier = MultilayerPerceptron.new
28
+ assert_equal 1, classifier.active_node_value
29
+ assert_equal 0, classifier.inactive_node_value
30
+ assert_equal Ai4r::NeuralNetwork::Backpropagation, classifier.network_class
31
+ assert_equal [], classifier.hidden_layers
32
+ assert classifier.network_parameters
33
+ assert classifier.network_parameters.empty?
34
+ assert classifier.training_iterations > 1
35
+ end
36
+
37
+ def test_build
38
+ assert_raise(ArgumentError) { MultilayerPerceptron.new.build(DataSet.new) }
39
+ classifier = MultilayerPerceptron.new
40
+ classifier.training_iterations = 1
41
+ classifier.build(@@data_set)
42
+ assert_equal [7,2], classifier.network.structure
43
+ classifier.hidden_layers = [6, 4]
44
+ classifier.build(@@data_set)
45
+ assert_equal [7,6,4,2], classifier.network.structure
46
+ end
47
+
48
+ def test_eval
49
+ classifier = MultilayerPerceptron.new.build(@@data_set)
50
+ assert classifier
51
+ assert_equal('N', classifier.eval(['Chicago', '[50-80]', 'M']))
52
+ assert_equal('N', classifier.eval(['New York', '[30-50)', 'F']))
53
+ assert_equal('Y', classifier.eval(['New York', '<30', 'M']))
54
+ assert_equal('Y', classifier.eval(['Chicago', '[30-50)', 'F']))
55
+ end
56
+
57
+ def test_get_rules
58
+ assert_match(/raise/, MultilayerPerceptron.new.get_rules)
59
+ end
60
+
61
+ def test_get_max_index
62
+ classifier = MultilayerPerceptron.new
63
+ assert_equal(0, classifier.get_max_index([3, 1, 0.2, -9, 0, 2.99]))
64
+ assert_equal(2, classifier.get_max_index([3, 1, 5, -9, 0, 2.99]))
65
+ assert_equal(5, classifier.get_max_index([3, 1, 5, -9, 0, 6]))
66
+ end
67
+
68
+ def test_data_to_output
69
+ classifier = MultilayerPerceptron.new
70
+ classifier.outputs = 4
71
+ classifier.outputs = 4
72
+ classifier.domains = [nil, nil, nil, ["A", "B", "C", "D"]]
73
+ assert_equal([1,0,0,0], classifier.data_to_output("A"))
74
+ assert_equal([0,0,1,0], classifier.data_to_output("C"))
75
+ assert_equal([0,0,0,1], classifier.data_to_output("D"))
76
+ end
77
+
78
+ end
79
+
@@ -0,0 +1,43 @@
1
+ require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/naive_bayes'
2
+ require File.dirname(__FILE__) + '/../../lib/ai4r/data/data_set'
3
+ require 'test/unit'
4
+
5
+ include Ai4r::Classifiers
6
+ include Ai4r::Data
7
+
8
+ class NaiveBayesTest < Test::Unit::TestCase
9
+
10
+ @@data_labels = [ "Color","Type","Origin","Stolen?" ]
11
+
12
+ @@data_items = [
13
+ ["Red", "Sports", "Domestic", "Yes"],
14
+ ["Red", "Sports", "Domestic", "No"],
15
+ ["Red", "Sports", "Domestic", "Yes"],
16
+ ["Yellow","Sports", "Domestic", "No"],
17
+ ["Yellow","Sports", "Imported", "Yes"],
18
+ ["Yellow","SUV", "Imported", "No"],
19
+ ["Yellow","SUV", "Imported", "Yes"],
20
+ ["Yellow","Sports", "Domestic", "No"],
21
+ ["Red", "SUV", "Imported", "No"],
22
+ ["Red", "Sports", "Imported", "Yes"]
23
+ ]
24
+
25
+ def setup
26
+ @data_set = DataSet.new
27
+ @data_set = DataSet.new(:data_items => @@data_items, :data_labels => @@data_labels)
28
+ @b = NaiveBayes.new.set_parameters({:m=>3}).build @data_set
29
+ end
30
+
31
+ def test_eval
32
+ result = @b.eval(["Red", "SUV", "Domestic"])
33
+ assert_equal "No", result
34
+ end
35
+
36
+ def test_get_probability_map
37
+ map = @b.get_probability_map(["Red", "SUV", "Domestic"])
38
+ assert_equal 2, map.keys.length
39
+ assert_in_delta 0.42, map["Yes"], 0.1
40
+ assert_in_delta 0.58, map["No"], 0.1
41
+ end
42
+
43
+ end
@@ -0,0 +1,62 @@
1
+ require 'test/unit'
2
+ require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/one_r'
3
+
4
+ class OneRTest < Test::Unit::TestCase
5
+
6
+ include Ai4r::Classifiers
7
+ include Ai4r::Data
8
+
9
+ @@data_examples = [ ['New York', '<30', 'M', 'Y'],
10
+ ['Chicago', '<30', 'M', 'Y'],
11
+ ['New York', '<30', 'M', 'Y'],
12
+ ['New York', '[30-50)', 'F', 'N'],
13
+ ['Chicago', '[30-50)', 'F', 'Y'],
14
+ ['New York', '[30-50)', 'F', 'N'],
15
+ ['Chicago', '[50-80]', 'M', 'N']
16
+ ]
17
+
18
+ @@data_labels = [ 'city', 'age_range', 'gender', 'marketing_target' ]
19
+
20
+ def test_build
21
+ assert_raise(ArgumentError) { OneR.new.build(DataSet.new) }
22
+ classifier = OneR.new.build(DataSet.new(:data_items => @@data_examples))
23
+ assert_not_nil(classifier.data_set.data_labels)
24
+ assert_not_nil(classifier.rule)
25
+ assert_equal("attribute_1", classifier.data_set.data_labels.first)
26
+ assert_equal("class_value", classifier.data_set.data_labels.last)
27
+ classifier = OneR.new.build(DataSet.new(:data_items => @@data_examples,
28
+ :data_labels => @@data_labels))
29
+ assert_not_nil(classifier.data_set.data_labels)
30
+ assert_not_nil(classifier.rule)
31
+ assert_equal("city", classifier.data_set.data_labels.first)
32
+ assert_equal("marketing_target", classifier.data_set.data_labels.last)
33
+ assert_equal(1, classifier.rule[:attr_index])
34
+ end
35
+
36
+ def test_eval
37
+ classifier = OneR.new.build(DataSet.new(:data_items => @@data_examples))
38
+ assert_equal("Y", classifier.eval(['New York', '<30', 'M']))
39
+ assert_equal("N", classifier.eval(['New York', '[30-50)', 'M']))
40
+ assert_equal("N", classifier.eval(['Chicago', '[50-80]', 'M']))
41
+ end
42
+
43
+ def test_get_rules
44
+ classifier = OneR.new.build(DataSet.new(:data_items => @@data_examples,
45
+ :data_labels => @@data_labels))
46
+ marketing_target = nil
47
+ age_range = nil
48
+ eval(classifier.get_rules)
49
+ assert_nil(marketing_target)
50
+ age_range = '<30'
51
+ eval(classifier.get_rules)
52
+ assert_equal("Y", marketing_target)
53
+ age_range = '[30-50)'
54
+ eval(classifier.get_rules)
55
+ assert_equal("N", marketing_target)
56
+ age_range = '[50-80]'
57
+ eval(classifier.get_rules)
58
+ assert_equal("N", marketing_target)
59
+ end
60
+
61
+ end
62
+