ai4ruby 1.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/README.rdoc +47 -0
- data/examples/classifiers/id3_data.csv +121 -0
- data/examples/classifiers/id3_example.rb +29 -0
- data/examples/classifiers/naive_bayes_data.csv +11 -0
- data/examples/classifiers/naive_bayes_example.rb +16 -0
- data/examples/classifiers/results.txt +31 -0
- data/examples/genetic_algorithm/genetic_algorithm_example.rb +37 -0
- data/examples/genetic_algorithm/travel_cost.csv +16 -0
- data/examples/neural_network/backpropagation_example.rb +67 -0
- data/examples/neural_network/patterns_with_base_noise.rb +68 -0
- data/examples/neural_network/patterns_with_noise.rb +66 -0
- data/examples/neural_network/training_patterns.rb +68 -0
- data/examples/neural_network/xor_example.rb +35 -0
- data/examples/som/som_data.rb +156 -0
- data/examples/som/som_multi_node_example.rb +22 -0
- data/examples/som/som_single_example.rb +24 -0
- data/lib/ai4r.rb +33 -0
- data/lib/ai4r/classifiers/classifier.rb +62 -0
- data/lib/ai4r/classifiers/hyperpipes.rb +118 -0
- data/lib/ai4r/classifiers/ib1.rb +121 -0
- data/lib/ai4r/classifiers/id3.rb +326 -0
- data/lib/ai4r/classifiers/multilayer_perceptron.rb +135 -0
- data/lib/ai4r/classifiers/naive_bayes.rb +259 -0
- data/lib/ai4r/classifiers/one_r.rb +110 -0
- data/lib/ai4r/classifiers/prism.rb +197 -0
- data/lib/ai4r/classifiers/zero_r.rb +73 -0
- data/lib/ai4r/clusterers/average_linkage.rb +59 -0
- data/lib/ai4r/clusterers/bisecting_k_means.rb +93 -0
- data/lib/ai4r/clusterers/centroid_linkage.rb +66 -0
- data/lib/ai4r/clusterers/clusterer.rb +61 -0
- data/lib/ai4r/clusterers/complete_linkage.rb +67 -0
- data/lib/ai4r/clusterers/diana.rb +139 -0
- data/lib/ai4r/clusterers/k_means.rb +126 -0
- data/lib/ai4r/clusterers/median_linkage.rb +61 -0
- data/lib/ai4r/clusterers/single_linkage.rb +194 -0
- data/lib/ai4r/clusterers/ward_linkage.rb +64 -0
- data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +31 -0
- data/lib/ai4r/clusterers/weighted_average_linkage.rb +61 -0
- data/lib/ai4r/data/data_set.rb +266 -0
- data/lib/ai4r/data/parameterizable.rb +64 -0
- data/lib/ai4r/data/proximity.rb +100 -0
- data/lib/ai4r/data/statistics.rb +77 -0
- data/lib/ai4r/experiment/classifier_evaluator.rb +95 -0
- data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +270 -0
- data/lib/ai4r/neural_network/backpropagation.rb +326 -0
- data/lib/ai4r/neural_network/hopfield.rb +149 -0
- data/lib/ai4r/som/layer.rb +68 -0
- data/lib/ai4r/som/node.rb +96 -0
- data/lib/ai4r/som/som.rb +155 -0
- data/lib/ai4r/som/two_phase_layer.rb +90 -0
- data/test/classifiers/hyperpipes_test.rb +84 -0
- data/test/classifiers/ib1_test.rb +78 -0
- data/test/classifiers/id3_test.rb +208 -0
- data/test/classifiers/multilayer_perceptron_test.rb +79 -0
- data/test/classifiers/naive_bayes_test.rb +43 -0
- data/test/classifiers/one_r_test.rb +62 -0
- data/test/classifiers/prism_test.rb +85 -0
- data/test/classifiers/zero_r_test.rb +49 -0
- data/test/clusterers/average_linkage_test.rb +51 -0
- data/test/clusterers/bisecting_k_means_test.rb +66 -0
- data/test/clusterers/centroid_linkage_test.rb +53 -0
- data/test/clusterers/complete_linkage_test.rb +57 -0
- data/test/clusterers/diana_test.rb +69 -0
- data/test/clusterers/k_means_test.rb +100 -0
- data/test/clusterers/median_linkage_test.rb +53 -0
- data/test/clusterers/single_linkage_test.rb +122 -0
- data/test/clusterers/ward_linkage_hierarchical_test.rb +61 -0
- data/test/clusterers/ward_linkage_test.rb +53 -0
- data/test/clusterers/weighted_average_linkage_test.rb +53 -0
- data/test/data/data_set_test.rb +96 -0
- data/test/data/proximity_test.rb +81 -0
- data/test/data/statistics_test.rb +65 -0
- data/test/experiment/classifier_evaluator_test.rb +76 -0
- data/test/genetic_algorithm/chromosome_test.rb +58 -0
- data/test/genetic_algorithm/genetic_algorithm_test.rb +81 -0
- data/test/neural_network/backpropagation_test.rb +82 -0
- data/test/neural_network/hopfield_test.rb +72 -0
- data/test/som/som_test.rb +97 -0
- metadata +168 -0
@@ -0,0 +1,78 @@
|
|
1
|
+
# Author:: Sergio Fierens
|
2
|
+
# License:: MPL 1.1
|
3
|
+
# Project:: ai4r
|
4
|
+
# Url:: http://ai4r.rubyforge.org/
|
5
|
+
#
|
6
|
+
# You can redistribute it and/or modify it under the terms of
|
7
|
+
# the Mozilla Public License version 1.1 as published by the
|
8
|
+
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
|
+
|
10
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/ib1'
|
11
|
+
require 'test/unit'
|
12
|
+
|
13
|
+
# Test-only reopening of IB1: adds read/write accessors for the three
# instance variables that the assertions in this file inspect.
class Ai4r::Classifiers::IB1
  attr_accessor :data_set
  attr_accessor :min_values
  attr_accessor :max_values
end
|
16
|
+
|
17
|
+
include Ai4r::Classifiers
|
18
|
+
include Ai4r::Data
|
19
|
+
|
20
|
+
# Unit tests for Ai4r::Classifiers::IB1.
class IB1Test < Test::Unit::TestCase

  # Column names of the fixture; the last column holds the class value.
  DATA_LABELS = ['city', 'age', 'gender', 'marketing_target']

  # Training rows: city (nominal), age (numeric), gender (nominal) and the
  # expected marketing_target.
  DATA_ITEMS = [
    ['New York', 25, 'M', 'Y'],
    ['New York', 23, 'M', 'Y'],
    ['New York', 18, 'M', 'Y'],
    ['Chicago',  43, 'M', 'Y'],
    ['New York', 34, 'F', 'N'],
    ['Chicago',  33, 'F', 'Y'],
    ['New York', 31, 'F', 'N'],
    ['Chicago',  55, 'M', 'N'],
    ['New York', 58, 'F', 'N'],
    ['New York', 59, 'M', 'N'],
    ['Chicago',  71, 'M', 'N'],
    ['New York', 60, 'F', 'N'],
    ['Chicago',  85, 'F', 'Y']
  ]

  # Makes IB1's protected instance methods public so the tests can call them
  # directly, then builds a classifier over the fixture.
  def setup
    IB1.send(:public, *IB1.protected_instance_methods)
    @data_set = DataSet.new(:data_items => DATA_ITEMS, :data_labels => DATA_LABELS)
    @classifier = IB1.new.build(@data_set)
  end

  # An empty data set must be rejected; a valid one records the data set and
  # the min/max of the numeric column only (nominal columns stay nil).
  def test_build
    assert_raise(ArgumentError) { IB1.new.build(DataSet.new) }
    assert @classifier.data_set
    assert_equal [nil, 18, nil, nil], @classifier.min_values
    assert_equal [nil, 85, nil, nil], @classifier.max_values
  end

  # norm(value, attribute_index): 0 is expected for the nominal columns,
  # and (55 - 18) / (85 - 18) for the age column.
  def test_norm
    assert_equal(0, @classifier.norm('Chicago', 0))
    assert_in_delta(0.5522, @classifier.norm(55, 1), 0.0001)
    # 'F' is a gender value, so it is checked against the gender column (2),
    # not the city column.
    assert_equal(0, @classifier.norm('F', 2))
  end

  # Distance between an unlabelled instance and a labelled training item.
  def test_distance
    item = ['Chicago', 55, 'M', 'N']
    assert_equal(0, @classifier.distance(['Chicago', 55, 'M'], item))
    assert_equal(1, @classifier.distance([nil, 55, 'M'], item))
    assert_equal(1, @classifier.distance(['New York', 55, 'M'], item))
    assert_in_delta(0.2728, @classifier.distance(['Chicago', 20, 'M'], item), 0.0001)
  end

  # Classification of a few instances against a freshly built classifier.
  def test_eval
    classifier = IB1.new.build(@data_set)
    assert classifier
    assert_equal('N', classifier.eval(['Chicago', 55, 'M']))
    assert_equal('N', classifier.eval(['New York', 35, 'F']))
    assert_equal('Y', classifier.eval(['New York', 25, 'M']))
    assert_equal('Y', classifier.eval(['Chicago', 85, 'F']))
  end

end
|
77
|
+
|
78
|
+
|
@@ -0,0 +1,208 @@
|
|
1
|
+
# id3_test.rb
|
2
|
+
#
|
3
|
+
# This is a unit test file for the ID3 algorithm (Quinlan) implemented
|
4
|
+
# in ai4r
|
5
|
+
#
|
6
|
+
# Author:: Sergio Fierens
|
7
|
+
# License:: MPL 1.1
|
8
|
+
# Project:: ai4r
|
9
|
+
# Url:: http://ai4r.rubyforge.org/
|
10
|
+
#
|
11
|
+
# You can redistribute it and/or modify it under the terms of
|
12
|
+
# the Mozilla Public License version 1.1 as published by the
|
13
|
+
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
14
|
+
|
15
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/id3'
|
16
|
+
require 'test/unit'
|
17
|
+
|
18
|
+
# Column names of the training fixture used by the ID3 tests.
DATA_LABELS = ['city', 'age_range', 'gender', 'marketing_target']

# Training rows, one per example, in the column order given by DATA_LABELS.
DATA_ITEMS = [
  ['New York', '<30',     'M', 'Y'],
  ['Chicago',  '<30',     'M', 'Y'],
  ['Chicago',  '<30',     'F', 'Y'],
  ['New York', '<30',     'M', 'Y'],
  ['New York', '<30',     'M', 'Y'],
  ['Chicago',  '[30-50)', 'M', 'Y'],
  ['New York', '[30-50)', 'F', 'N'],
  ['Chicago',  '[30-50)', 'F', 'Y'],
  ['New York', '[30-50)', 'F', 'N'],
  ['Chicago',  '[50-80]', 'M', 'N'],
  ['New York', '[50-80]', 'F', 'N'],
  ['New York', '[50-80]', 'M', 'N'],
  ['Chicago',  '[50-80]', 'M', 'N'],
  ['New York', '[50-80]', 'F', 'N'],
  ['Chicago',  '>80',     'F', 'Y']
]

# DATA_ITEMS grouped by the first column (city): New York rows, then Chicago rows.
SPLIT_DATA_ITEMS_BY_CITY = [
  [
    ['New York', '<30',     'M', 'Y'],
    ['New York', '<30',     'M', 'Y'],
    ['New York', '<30',     'M', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['New York', '[30-50)', 'F', 'N'],
    ['New York', '[50-80]', 'F', 'N'],
    ['New York', '[50-80]', 'M', 'N'],
    ['New York', '[50-80]', 'F', 'N']
  ],
  [
    ['Chicago', '<30',     'M', 'Y'],
    ['Chicago', '<30',     'F', 'Y'],
    ['Chicago', '[30-50)', 'M', 'Y'],
    ['Chicago', '[30-50)', 'F', 'Y'],
    ['Chicago', '[50-80]', 'M', 'N'],
    ['Chicago', '[50-80]', 'M', 'N'],
    ['Chicago', '>80',     'F', 'Y']
  ]
]

# DATA_ITEMS grouped by the second column (age_range), one group per range.
SPLIT_DATA_ITEMS_BY_AGE = [
  [
    ['New York', '<30', 'M', 'Y'],
    ['Chicago',  '<30', 'M', 'Y'],
    ['Chicago',  '<30', 'F', 'Y'],
    ['New York', '<30', 'M', 'Y'],
    ['New York', '<30', 'M', 'Y']
  ],
  [
    ['Chicago',  '[30-50)', 'M', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['Chicago',  '[30-50)', 'F', 'Y'],
    ['New York', '[30-50)', 'F', 'N']
  ],
  [
    ['Chicago',  '[50-80]', 'M', 'N'],
    ['New York', '[50-80]', 'F', 'N'],
    ['New York', '[50-80]', 'M', 'N'],
    ['Chicago',  '[50-80]', 'M', 'N'],
    ['New York', '[50-80]', 'F', 'N']
  ],
  [
    ['Chicago', '>80', 'F', 'Y']
  ]
]

# Rule source the tests expect get_rules to return for the fixture above:
# an if/elsif chain, one clause per line, closed by a raising else branch.
EXPECTED_RULES_STRING = [
  "if age_range=='<30' then marketing_target='Y'\n",
  "elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y'\n",
  "elsif age_range=='[30-50)' and city=='New York' then marketing_target='N'\n",
  "elsif age_range=='[50-80]' then marketing_target='N'\n",
  "elsif age_range=='>80' then marketing_target='Y'\n",
  "else raise 'There was not enough information during training to do a proper induction for this data element' end"
].join
|
84
|
+
|
85
|
+
include Ai4r::Classifiers
|
86
|
+
include Ai4r::Data
|
87
|
+
|
88
|
+
# Unit tests for Ai4r::Classifiers::ID3, including its non-public helpers.
class ID3Test < Test::Unit::TestCase

  # Makes ID3's protected and private instance methods public so the tests
  # can call them directly. Done in setup rather than inside a single test so
  # that no test depends on another one having run first.
  def setup
    Ai4r::Classifiers::ID3.send(:public, *Ai4r::Classifiers::ID3.protected_instance_methods)
    Ai4r::Classifiers::ID3.send(:public, *Ai4r::Classifiers::ID3.private_instance_methods)
  end

  # Building over the labelled fixture yields an ID3 classifier.
  def test_build
    assert_kind_of ID3, build_id3
  end

  def test_log2
    assert_equal 1.0, ID3.log2(2)
    assert_equal 0.0, ID3.log2(0)
    # Two-sided comparison: a plain "expected - actual < eps" would also
    # accept arbitrarily large results.
    assert_in_delta 1.585, ID3.log2(3), 0.001
  end

  def test_sum
    assert_equal 28, ID3.sum([5, 0, 22, 1])
    assert_equal 0, ID3.sum([])
  end

  # Without explicit labels the data set falls back to generated names;
  # with labels it keeps the ones supplied.
  def test_data_labels
    id3 = ID3.new.build(DataSet.new(:data_items => DATA_ITEMS))
    expected_default = ['attribute_1', 'attribute_2', 'attribute_3', 'class_value']
    assert_equal(expected_default, id3.data_set.data_labels)
    id3 = build_id3
    assert_equal(DATA_LABELS, id3.data_set.data_labels)
  end

  # The domain lists the distinct values of every column, in order of
  # first appearance.
  def test_domain
    id3 = build_id3
    expected_domain = [["New York", "Chicago"], ["<30", "[30-50)", "[50-80]", ">80"], ["M", "F"], ["Y", "N"]]
    assert_equal expected_domain, id3.domain(DATA_ITEMS)
  end

  # Frequency grid: one row per attribute value, one column per class value.
  def test_grid
    id3 = build_id3
    domain = id3.domain(DATA_ITEMS)
    expected_grid = [[3, 5], [5, 2]]
    assert_equal expected_grid, id3.freq_grid(0, DATA_ITEMS, domain)
    expected_grid = [[5, 0], [2, 2], [0, 5], [1, 0]]
    assert_equal expected_grid, id3.freq_grid(1, DATA_ITEMS, domain)
  end

  # Entropy of the split on each attribute, compared with a two-sided
  # tolerance.
  def test_entropy
    id3 = build_id3
    domain = id3.domain(DATA_ITEMS)
    [[0, 0.9118], [1, 0.2667], [2, 0.9688]].each do |attr_index, expected_entropy|
      freq_grid = id3.freq_grid(attr_index, DATA_ITEMS, domain)
      assert_in_delta expected_entropy, id3.entropy(freq_grid, DATA_ITEMS.length), 0.0001
    end
  end

  # The optional third argument lists attribute indexes to leave out.
  def test_min_entropy_index
    id3 = build_id3
    domain = id3.domain(DATA_ITEMS)
    assert_equal 1, id3.min_entropy_index(DATA_ITEMS, domain)
    assert_equal 0, id3.min_entropy_index(DATA_ITEMS, domain, [1])
    assert_equal 2, id3.min_entropy_index(DATA_ITEMS, domain, [1, 0])
  end

  def test_split_data_examples
    id3 = build_id3
    domain = id3.domain(DATA_ITEMS)
    res = id3.split_data_examples(DATA_ITEMS, domain, 0)
    assert_equal(SPLIT_DATA_ITEMS_BY_CITY, res)
    res = id3.split_data_examples(DATA_ITEMS, domain, 1)
    assert_equal(SPLIT_DATA_ITEMS_BY_AGE, res)
  end

  def test_most_freq
    id3 = build_id3
    domain = id3.domain(DATA_ITEMS)
    assert_equal 'Y', id3.most_freq(DATA_ITEMS, domain)
    assert_equal 'Y', id3.most_freq(SPLIT_DATA_ITEMS_BY_AGE[3], domain)
    assert_equal 'N', id3.most_freq(SPLIT_DATA_ITEMS_BY_AGE[2], domain)
  end

  def test_get_rules
    assert_equal [["marketing_target='N'"]], CategoryNode.new('marketing_target', 'N').get_rules
    id3 = build_id3
    assert_equal EXPECTED_RULES_STRING, id3.get_rules
  end

  def test_eval
    id3 = build_id3
    #if age_range='<30' then marketing_target='Y'
    assert_equal 'Y', id3.eval(['New York', '<30', 'F'])
    assert_equal 'Y', id3.eval(['Chicago', '<30', 'M'])
    #if age_range='[30-50)' and city='Chicago' then marketing_target='Y'
    assert_equal 'Y', id3.eval(['Chicago', '[30-50)', 'F'])
    assert_equal 'Y', id3.eval(['Chicago', '[30-50)', 'M'])
    #if age_range='[30-50)' and city='New York' then marketing_target='N'
    assert_equal 'N', id3.eval(['New York', '[30-50)', 'F'])
    assert_equal 'N', id3.eval(['New York', '[30-50)', 'M'])
    #if age_range='[50-80]' then marketing_target='N'
    assert_equal 'N', id3.eval(['New York', '[50-80]', 'F'])
    assert_equal 'N', id3.eval(['Chicago', '[50-80]', 'M'])
    #if age_range='>80' then marketing_target='Y'
    assert_equal 'Y', id3.eval(['New York', '>80', 'M'])
    assert_equal 'Y', id3.eval(['Chicago', '>80', 'F'])
  end

  # Runs the generated rule source with Kernel#eval in this method's scope:
  # the rules read the locals age_range / city and assign marketing_target.
  # The input is the classifier's own output for a fixed fixture, not
  # external data.
  def test_rules_eval
    id3 = build_id3
    #if age_range='<30' then marketing_target='Y'
    age_range = '<30'
    marketing_target = nil
    eval id3.get_rules
    assert_equal 'Y', marketing_target
    #if age_range='[30-50)' and city='New York' then marketing_target='N'
    age_range = '[30-50)'
    city = 'New York'
    eval id3.get_rules
    assert_equal 'N', marketing_target
  end

  private

  # Builds an ID3 classifier over the labelled fixture.
  def build_id3
    ID3.new.build(DataSet.new(:data_items => DATA_ITEMS, :data_labels => DATA_LABELS))
  end

end
|
207
|
+
|
208
|
+
|
@@ -0,0 +1,79 @@
|
|
1
|
+
require 'test/unit'
|
2
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/multilayer_perceptron'
|
3
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/data/data_set'
|
4
|
+
|
5
|
+
# Test-only reopening of MultilayerPerceptron: adds read/write accessors for
# the listed instance variables and makes two helper methods public so the
# tests below can call them.
class Ai4r::Classifiers::MultilayerPerceptron
  attr_accessor :data_set, :class_value, :network, :domains, :outputs
  public :get_max_index, :data_to_output
end
|
11
|
+
|
12
|
+
# Unit tests for Ai4r::Classifiers::MultilayerPerceptron.
class MultilayerPerceptronTest < Test::Unit::TestCase

  include Ai4r::Classifiers
  include Ai4r::Data

  # Shared training fixture: city, age range, gender and the class value.
  DATA_SET = DataSet.new(:data_items => [
    ['New York', '<30',     'M', 'Y'],
    ['Chicago',  '<30',     'M', 'Y'],
    ['New York', '<30',     'M', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['Chicago',  '[30-50)', 'F', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['Chicago',  '[50-80]', 'M', 'N']
  ])

  # Default parameter values of a freshly created classifier.
  def test_initialize
    classifier = MultilayerPerceptron.new
    assert_equal 1, classifier.active_node_value
    assert_equal 0, classifier.inactive_node_value
    assert_equal Ai4r::NeuralNetwork::Backpropagation, classifier.network_class
    assert_equal [], classifier.hidden_layers
    assert classifier.network_parameters
    assert classifier.network_parameters.empty?
    assert classifier.training_iterations > 1
  end

  # An empty data set is rejected. Otherwise the network has 7 inputs and
  # 2 outputs for this fixture, with any configured hidden layers in between.
  def test_build
    assert_raise(ArgumentError) { MultilayerPerceptron.new.build(DataSet.new) }
    classifier = MultilayerPerceptron.new
    # A single iteration is enough here: only the structure is checked.
    classifier.training_iterations = 1
    classifier.build(DATA_SET)
    assert_equal [7, 2], classifier.network.structure
    classifier.hidden_layers = [6, 4]
    classifier.build(DATA_SET)
    assert_equal [7, 6, 4, 2], classifier.network.structure
  end

  # NOTE(review): this asserts exact predictions from a trained network; if
  # training starts from random weights the test may be flaky -- confirm.
  def test_eval
    classifier = MultilayerPerceptron.new.build(DATA_SET)
    assert classifier
    assert_equal('N', classifier.eval(['Chicago', '[50-80]', 'M']))
    assert_equal('N', classifier.eval(['New York', '[30-50)', 'F']))
    assert_equal('Y', classifier.eval(['New York', '<30', 'M']))
    assert_equal('Y', classifier.eval(['Chicago', '[30-50)', 'F']))
  end

  # get_rules is expected to return source text containing a raise.
  def test_get_rules
    assert_match(/raise/, MultilayerPerceptron.new.get_rules)
  end

  # get_max_index returns the position of the largest element.
  def test_get_max_index
    classifier = MultilayerPerceptron.new
    assert_equal(0, classifier.get_max_index([3, 1, 0.2, -9, 0, 2.99]))
    assert_equal(2, classifier.get_max_index([3, 1, 5, -9, 0, 2.99]))
    assert_equal(5, classifier.get_max_index([3, 1, 5, -9, 0, 6]))
  end

  # data_to_output maps a class value to a one-hot vector whose position
  # follows the order of the last domain.
  def test_data_to_output
    classifier = MultilayerPerceptron.new
    classifier.outputs = 4
    classifier.domains = [nil, nil, nil, ["A", "B", "C", "D"]]
    assert_equal([1, 0, 0, 0], classifier.data_to_output("A"))
    assert_equal([0, 0, 1, 0], classifier.data_to_output("C"))
    assert_equal([0, 0, 0, 1], classifier.data_to_output("D"))
  end

end
|
79
|
+
|
@@ -0,0 +1,43 @@
|
|
1
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/naive_bayes'
|
2
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/data/data_set'
|
3
|
+
require 'test/unit'
|
4
|
+
|
5
|
+
include Ai4r::Classifiers
|
6
|
+
include Ai4r::Data
|
7
|
+
|
8
|
+
# Unit tests for Ai4r::Classifiers::NaiveBayes.
class NaiveBayesTest < Test::Unit::TestCase

  # Column names of the fixture; the last column holds the class value.
  DATA_LABELS = ["Color", "Type", "Origin", "Stolen?"]

  # Training rows in the column order given by DATA_LABELS.
  DATA_ITEMS = [
    ["Red",    "Sports", "Domestic", "Yes"],
    ["Red",    "Sports", "Domestic", "No"],
    ["Red",    "Sports", "Domestic", "Yes"],
    ["Yellow", "Sports", "Domestic", "No"],
    ["Yellow", "Sports", "Imported", "Yes"],
    ["Yellow", "SUV",    "Imported", "No"],
    ["Yellow", "SUV",    "Imported", "Yes"],
    ["Yellow", "Sports", "Domestic", "No"],
    ["Red",    "SUV",    "Imported", "No"],
    ["Red",    "Sports", "Imported", "Yes"]
  ]

  # Builds a classifier over the fixture with the :m parameter set to 3.
  def setup
    @data_set = DataSet.new(:data_items => DATA_ITEMS, :data_labels => DATA_LABELS)
    @classifier = NaiveBayes.new.set_parameters({:m => 3}).build @data_set
  end

  def test_eval
    result = @classifier.eval(["Red", "SUV", "Domestic"])
    assert_equal "No", result
  end

  # The probability map has one entry per class value.
  def test_get_probability_map
    map = @classifier.get_probability_map(["Red", "SUV", "Domestic"])
    assert_equal 2, map.keys.length
    assert_in_delta 0.42, map["Yes"], 0.1
    assert_in_delta 0.58, map["No"], 0.1
  end

end
|
@@ -0,0 +1,62 @@
|
|
1
|
+
require 'test/unit'
|
2
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/one_r'
|
3
|
+
|
4
|
+
# Unit tests for Ai4r::Classifiers::OneR.
class OneRTest < Test::Unit::TestCase

  include Ai4r::Classifiers
  include Ai4r::Data

  # Training rows: city, age range, gender and the class value.
  DATA_EXAMPLES = [
    ['New York', '<30',     'M', 'Y'],
    ['Chicago',  '<30',     'M', 'Y'],
    ['New York', '<30',     'M', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['Chicago',  '[30-50)', 'F', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['Chicago',  '[50-80]', 'M', 'N']
  ]

  # Column names matching DATA_EXAMPLES.
  DATA_LABELS = ['city', 'age_range', 'gender', 'marketing_target']

  # An empty data set is rejected. A valid one yields a rule and keeps either
  # the generated default labels or the labels supplied. For this fixture
  # the rule is expected on attribute index 1 (age_range).
  def test_build
    assert_raise(ArgumentError) { OneR.new.build(DataSet.new) }

    classifier = unlabelled_classifier
    assert_not_nil(classifier.data_set.data_labels)
    assert_not_nil(classifier.rule)
    assert_equal("attribute_1", classifier.data_set.data_labels.first)
    assert_equal("class_value", classifier.data_set.data_labels.last)

    classifier = labelled_classifier
    assert_not_nil(classifier.data_set.data_labels)
    assert_not_nil(classifier.rule)
    assert_equal("city", classifier.data_set.data_labels.first)
    assert_equal("marketing_target", classifier.data_set.data_labels.last)
    assert_equal(1, classifier.rule[:attr_index])
  end

  def test_eval
    classifier = unlabelled_classifier
    assert_equal("Y", classifier.eval(['New York', '<30', 'M']))
    assert_equal("N", classifier.eval(['New York', '[30-50)', 'M']))
    assert_equal("N", classifier.eval(['Chicago', '[50-80]', 'M']))
  end

  # Runs the generated rule source with Kernel#eval in this method's scope:
  # the rules read the local age_range and assign marketing_target. The
  # input is the classifier's own output for a fixed fixture, not external
  # data.
  def test_get_rules
    classifier = labelled_classifier
    marketing_target = nil
    age_range = nil
    eval(classifier.get_rules)
    assert_nil(marketing_target)
    age_range = '<30'
    eval(classifier.get_rules)
    assert_equal("Y", marketing_target)
    age_range = '[30-50)'
    eval(classifier.get_rules)
    assert_equal("N", marketing_target)
    age_range = '[50-80]'
    eval(classifier.get_rules)
    assert_equal("N", marketing_target)
  end

  private

  # Classifier built from the fixture rows only (labels are generated).
  def unlabelled_classifier
    OneR.new.build(DataSet.new(:data_items => DATA_EXAMPLES))
  end

  # Classifier built from the fixture rows together with DATA_LABELS.
  def labelled_classifier
    OneR.new.build(DataSet.new(:data_items => DATA_EXAMPLES,
                               :data_labels => DATA_LABELS))
  end

end
|
62
|
+
|