ai4ruby 1.11
Sign up to get free protection for your applications and to get access to all the features.
- data/README.rdoc +47 -0
- data/examples/classifiers/id3_data.csv +121 -0
- data/examples/classifiers/id3_example.rb +29 -0
- data/examples/classifiers/naive_bayes_data.csv +11 -0
- data/examples/classifiers/naive_bayes_example.rb +16 -0
- data/examples/classifiers/results.txt +31 -0
- data/examples/genetic_algorithm/genetic_algorithm_example.rb +37 -0
- data/examples/genetic_algorithm/travel_cost.csv +16 -0
- data/examples/neural_network/backpropagation_example.rb +67 -0
- data/examples/neural_network/patterns_with_base_noise.rb +68 -0
- data/examples/neural_network/patterns_with_noise.rb +66 -0
- data/examples/neural_network/training_patterns.rb +68 -0
- data/examples/neural_network/xor_example.rb +35 -0
- data/examples/som/som_data.rb +156 -0
- data/examples/som/som_multi_node_example.rb +22 -0
- data/examples/som/som_single_example.rb +24 -0
- data/lib/ai4r.rb +33 -0
- data/lib/ai4r/classifiers/classifier.rb +62 -0
- data/lib/ai4r/classifiers/hyperpipes.rb +118 -0
- data/lib/ai4r/classifiers/ib1.rb +121 -0
- data/lib/ai4r/classifiers/id3.rb +326 -0
- data/lib/ai4r/classifiers/multilayer_perceptron.rb +135 -0
- data/lib/ai4r/classifiers/naive_bayes.rb +259 -0
- data/lib/ai4r/classifiers/one_r.rb +110 -0
- data/lib/ai4r/classifiers/prism.rb +197 -0
- data/lib/ai4r/classifiers/zero_r.rb +73 -0
- data/lib/ai4r/clusterers/average_linkage.rb +59 -0
- data/lib/ai4r/clusterers/bisecting_k_means.rb +93 -0
- data/lib/ai4r/clusterers/centroid_linkage.rb +66 -0
- data/lib/ai4r/clusterers/clusterer.rb +61 -0
- data/lib/ai4r/clusterers/complete_linkage.rb +67 -0
- data/lib/ai4r/clusterers/diana.rb +139 -0
- data/lib/ai4r/clusterers/k_means.rb +126 -0
- data/lib/ai4r/clusterers/median_linkage.rb +61 -0
- data/lib/ai4r/clusterers/single_linkage.rb +194 -0
- data/lib/ai4r/clusterers/ward_linkage.rb +64 -0
- data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +31 -0
- data/lib/ai4r/clusterers/weighted_average_linkage.rb +61 -0
- data/lib/ai4r/data/data_set.rb +266 -0
- data/lib/ai4r/data/parameterizable.rb +64 -0
- data/lib/ai4r/data/proximity.rb +100 -0
- data/lib/ai4r/data/statistics.rb +77 -0
- data/lib/ai4r/experiment/classifier_evaluator.rb +95 -0
- data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +270 -0
- data/lib/ai4r/neural_network/backpropagation.rb +326 -0
- data/lib/ai4r/neural_network/hopfield.rb +149 -0
- data/lib/ai4r/som/layer.rb +68 -0
- data/lib/ai4r/som/node.rb +96 -0
- data/lib/ai4r/som/som.rb +155 -0
- data/lib/ai4r/som/two_phase_layer.rb +90 -0
- data/test/classifiers/hyperpipes_test.rb +84 -0
- data/test/classifiers/ib1_test.rb +78 -0
- data/test/classifiers/id3_test.rb +208 -0
- data/test/classifiers/multilayer_perceptron_test.rb +79 -0
- data/test/classifiers/naive_bayes_test.rb +43 -0
- data/test/classifiers/one_r_test.rb +62 -0
- data/test/classifiers/prism_test.rb +85 -0
- data/test/classifiers/zero_r_test.rb +49 -0
- data/test/clusterers/average_linkage_test.rb +51 -0
- data/test/clusterers/bisecting_k_means_test.rb +66 -0
- data/test/clusterers/centroid_linkage_test.rb +53 -0
- data/test/clusterers/complete_linkage_test.rb +57 -0
- data/test/clusterers/diana_test.rb +69 -0
- data/test/clusterers/k_means_test.rb +100 -0
- data/test/clusterers/median_linkage_test.rb +53 -0
- data/test/clusterers/single_linkage_test.rb +122 -0
- data/test/clusterers/ward_linkage_hierarchical_test.rb +61 -0
- data/test/clusterers/ward_linkage_test.rb +53 -0
- data/test/clusterers/weighted_average_linkage_test.rb +53 -0
- data/test/data/data_set_test.rb +96 -0
- data/test/data/proximity_test.rb +81 -0
- data/test/data/statistics_test.rb +65 -0
- data/test/experiment/classifier_evaluator_test.rb +76 -0
- data/test/genetic_algorithm/chromosome_test.rb +58 -0
- data/test/genetic_algorithm/genetic_algorithm_test.rb +81 -0
- data/test/neural_network/backpropagation_test.rb +82 -0
- data/test/neural_network/hopfield_test.rb +72 -0
- data/test/som/som_test.rb +97 -0
- metadata +168 -0
@@ -0,0 +1,78 @@
|
|
1
|
+
# Author:: Sergio Fierens
|
2
|
+
# License:: MPL 1.1
|
3
|
+
# Project:: ai4r
|
4
|
+
# Url:: http://ai4r.rubyforge.org/
|
5
|
+
#
|
6
|
+
# You can redistribute it and/or modify it under the terms of
|
7
|
+
# the Mozilla Public License version 1.1 as published by the
|
8
|
+
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
|
+
|
10
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/ib1'
|
11
|
+
require 'test/unit'
|
12
|
+
|
13
|
+
# Test-only reopening of IB1: adds read/write accessors for the instance
# variables the assertions in this file inspect.
class Ai4r::Classifiers::IB1
  attr_accessor :data_set
  attr_accessor :min_values
  attr_accessor :max_values
end
|
16
|
+
|
17
|
+
include Ai4r::Classifiers
|
18
|
+
include Ai4r::Data
|
19
|
+
|
20
|
+
# Unit tests for the IB1 classifier, driven by a small marketing fixture whose
# columns are city, age, gender and the class value (marketing_target).
class IB1Test < Test::Unit::TestCase

  @@data_labels = [ 'city', 'age', 'gender', 'marketing_target' ]

  @@data_items = [['New York', 25, 'M', 'Y'],
                  ['New York', 23, 'M', 'Y'],
                  ['New York', 18, 'M', 'Y'],
                  ['Chicago', 43, 'M', 'Y'],
                  ['New York', 34, 'F', 'N'],
                  ['Chicago', 33, 'F', 'Y'],
                  ['New York', 31, 'F', 'N'],
                  ['Chicago', 55, 'M', 'N'],
                  ['New York', 58, 'F', 'N'],
                  ['New York', 59, 'M', 'N'],
                  ['Chicago', 71, 'M', 'N'],
                  ['New York', 60, 'F', 'N'],
                  ['Chicago', 85, 'F', 'Y']
                 ]

  # Makes IB1's protected methods callable from the tests and builds a
  # classifier from the fixture before every test.
  def setup
    IB1.send(:public, *IB1.protected_instance_methods)
    @data_set = DataSet.new(:data_items => @@data_items, :data_labels => @@data_labels)
    @classifier = IB1.new.build(@data_set)
  end

  # Building on an empty data set must raise; building on the fixture must
  # record min/max values for the 'age' column only.
  def test_build
    assert_raise(ArgumentError) { IB1.new.build(DataSet.new) }
    assert @classifier.data_set
    assert_equal [nil, 18, nil, nil], @classifier.min_values
    assert_equal [nil, 85, nil, nil], @classifier.max_values
  end

  # norm(value, attribute_index): 0 is expected for the city and gender
  # columns, a scaled value for the age column.
  def test_norm
    assert_equal(0, @classifier.norm('Chicago', 0))
    assert_in_delta(0.5522, @classifier.norm(55, 1), 0.0001)
    # 'F' is a gender value, so it is checked against the gender column
    # (index 2) rather than the city column (index 0).
    assert_equal(0, @classifier.norm('F', 2))
  end

  # Distance between an unlabelled instance and a stored data item.
  def test_distance
    item = ['Chicago', 55, 'M', 'N']
    assert_equal(0, @classifier.distance(['Chicago', 55, 'M'], item))
    assert_equal(1, @classifier.distance([nil, 55, 'M'], item))
    assert_equal(1, @classifier.distance(['New York', 55, 'M'], item))
    assert_in_delta(0.2728, @classifier.distance(['Chicago', 20, 'M'], item), 0.0001)
  end

  # Classification of instances that match or sit close to fixture rows.
  def test_eval
    classifier = IB1.new.build(@data_set)
    assert classifier
    assert_equal('N', classifier.eval(['Chicago', 55, 'M']))
    assert_equal('N', classifier.eval(['New York', 35, 'F']))
    assert_equal('Y', classifier.eval(['New York', 25, 'M']))
    assert_equal('Y', classifier.eval(['Chicago', 85, 'F']))
  end

end
|
77
|
+
|
78
|
+
|
@@ -0,0 +1,208 @@
|
|
1
|
+
# id3_test.rb
|
2
|
+
#
|
3
|
+
# This is a unit test file for the ID3 algorithm (Quinlan) implemented
|
4
|
+
# in ai4r
|
5
|
+
#
|
6
|
+
# Author:: Sergio Fierens
|
7
|
+
# License:: MPL 1.1
|
8
|
+
# Project:: ai4r
|
9
|
+
# Url:: http://ai4r.rubyforge.org/
|
10
|
+
#
|
11
|
+
# You can redistribute it and/or modify it under the terms of
|
12
|
+
# the Mozilla Public License version 1.1 as published by the
|
13
|
+
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
14
|
+
|
15
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/id3'
|
16
|
+
require 'test/unit'
|
17
|
+
|
18
|
+
# Shared fixtures for the ID3 tests.
#
# The outer containers are frozen so that a test cannot accidentally add or
# remove rows/labels and leak that change into the tests that run after it.
# Inner rows are left unfrozen.

# Attribute names; the last one is the class attribute.
DATA_LABELS = [ 'city', 'age_range', 'gender', 'marketing_target' ].freeze

# Training examples: one row per example, columns ordered as in DATA_LABELS.
DATA_ITEMS = [ ['New York', '<30', 'M', 'Y'],
               ['Chicago', '<30', 'M', 'Y'],
               ['Chicago', '<30', 'F', 'Y'],
               ['New York', '<30', 'M', 'Y'],
               ['New York', '<30', 'M', 'Y'],
               ['Chicago', '[30-50)', 'M', 'Y'],
               ['New York', '[30-50)', 'F', 'N'],
               ['Chicago', '[30-50)', 'F', 'Y'],
               ['New York', '[30-50)', 'F', 'N'],
               ['Chicago', '[50-80]', 'M', 'N'],
               ['New York', '[50-80]', 'F', 'N'],
               ['New York', '[50-80]', 'M', 'N'],
               ['Chicago', '[50-80]', 'M', 'N'],
               ['New York', '[50-80]', 'F', 'N'],
               ['Chicago', '>80', 'F', 'Y']
             ].freeze

# DATA_ITEMS partitioned by the value of column 0 (city).
SPLIT_DATA_ITEMS_BY_CITY = [ [
                ["New York", "<30", "M", "Y"],
                ["New York", "<30", "M", "Y"],
                ["New York", "<30", "M", "Y"],
                ["New York", "[30-50)", "F", "N"],
                ["New York", "[30-50)", "F", "N"],
                ["New York", "[50-80]", "F", "N"],
                ["New York", "[50-80]", "M", "N"],
                ["New York", "[50-80]", "F", "N"]],
              [
                ["Chicago", "<30", "M", "Y"],
                ["Chicago", "<30", "F", "Y"],
                ["Chicago", "[30-50)", "M", "Y"],
                ["Chicago", "[30-50)", "F", "Y"],
                ["Chicago", "[50-80]", "M", "N"],
                ["Chicago", "[50-80]", "M", "N"],
                ["Chicago", ">80", "F", "Y"]]
            ].freeze

# DATA_ITEMS partitioned by the value of column 1 (age_range).
SPLIT_DATA_ITEMS_BY_AGE = [ [
                ["New York", "<30", "M", "Y"],
                ["Chicago", "<30", "M", "Y"],
                ["Chicago", "<30", "F", "Y"],
                ["New York", "<30", "M", "Y"],
                ["New York", "<30", "M", "Y"]],
              [
                ["Chicago", "[30-50)", "M", "Y"],
                ["New York", "[30-50)", "F", "N"],
                ["Chicago", "[30-50)", "F", "Y"],
                ["New York", "[30-50)", "F", "N"]],
              [
                ["Chicago", "[50-80]", "M", "N"],
                ["New York", "[50-80]", "F", "N"],
                ["New York", "[50-80]", "M", "N"],
                ["Chicago", "[50-80]", "M", "N"],
                ["New York", "[50-80]", "F", "N"]],
              [
                ["Chicago", ">80", "F", "Y"]]
            ].freeze

# Exact rule text the ID3 tests expect from get_rules for the fixture above.
EXPECTED_RULES_STRING = (
  "if age_range=='<30' then marketing_target='Y'\n"+
  "elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y'\n"+
  "elsif age_range=='[30-50)' and city=='New York' then marketing_target='N'\n"+
  "elsif age_range=='[50-80]' then marketing_target='N'\n"+
  "elsif age_range=='>80' then marketing_target='Y'\n"+
  "else raise 'There was not enough information during training to do a proper induction for this data element' end"
).freeze
|
84
|
+
|
85
|
+
include Ai4r::Classifiers
|
86
|
+
include Ai4r::Data
|
87
|
+
|
88
|
+
# Unit tests for the ID3 classifier, driven by the top-level fixtures
# DATA_ITEMS / DATA_LABELS and their expected splits and rules.
class ID3Test < Test::Unit::TestCase

  # Widens the visibility of ID3's protected and private instance methods so
  # the tests can call them directly. Done before every test so that no test
  # depends on another one having run first.
  def setup
    id3_class = Ai4r::Classifiers::ID3
    hidden = id3_class.protected_instance_methods + id3_class.private_instance_methods
    # Skip the call once everything is public: `public` without arguments
    # would not name any method to expose.
    id3_class.send(:public, *hidden) unless hidden.empty?
  end

  # Building from the labelled fixture yields a classifier object.
  def test_build
    assert build_id3
  end

  def test_log2
    assert_equal 1.0, ID3.log2(2)
    assert_equal 0.0, ID3.log2(0)
    # Two-sided tolerance: a one-sided "expected - actual < eps" check would
    # also accept results that are far too large.
    assert_in_delta 1.585, ID3.log2(3), 0.001
  end

  def test_sum
    assert_equal 28, ID3.sum([5, 0, 22, 1])
    assert_equal 0, ID3.sum([])
  end

  # Labels default to attribute_N / class_value when none are supplied.
  def test_data_labels
    id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS))
    expected_default = [ 'attribute_1', 'attribute_2', 'attribute_3', 'class_value' ]
    assert_equal(expected_default, id3.data_set.data_labels)
    id3 = build_id3
    assert_equal(DATA_LABELS, id3.data_set.data_labels)
  end

  def test_domain
    id3 = build_id3
    expected_domain = [["New York", "Chicago"], ["<30", "[30-50)", "[50-80]", ">80"], ["M", "F"], ["Y", "N"]]
    assert_equal expected_domain, id3.domain(DATA_ITEMS)
  end

  def test_grid
    id3 = build_id3
    expected_grid = [[3, 5], [5, 2]]
    domain = id3.domain(DATA_ITEMS)
    assert_equal expected_grid, id3.freq_grid(0, DATA_ITEMS, domain)
    expected_grid = [[5, 0], [2, 2], [0, 5], [1, 0]]
    assert_equal expected_grid, id3.freq_grid(1, DATA_ITEMS, domain)
  end

  # Entropy of splitting on each of the three attributes, checked with a
  # two-sided tolerance.
  def test_entropy
    id3 = build_id3
    domain = id3.domain(DATA_ITEMS)
    expected_entropy = 0.9118
    freq_grid = id3.freq_grid(0, DATA_ITEMS, domain)
    assert_in_delta expected_entropy, id3.entropy(freq_grid, DATA_ITEMS.length), 0.0001
    expected_entropy = 0.2667
    freq_grid = id3.freq_grid(1, DATA_ITEMS, domain)
    assert_in_delta expected_entropy, id3.entropy(freq_grid, DATA_ITEMS.length), 0.0001
    expected_entropy = 0.9688
    freq_grid = id3.freq_grid(2, DATA_ITEMS, domain)
    assert_in_delta expected_entropy, id3.entropy(freq_grid, DATA_ITEMS.length), 0.0001
  end

  def test_min_entropy_index
    id3 = build_id3
    domain = id3.domain(DATA_ITEMS)
    assert_equal 1, id3.min_entropy_index(DATA_ITEMS, domain)
    assert_equal 0, id3.min_entropy_index(DATA_ITEMS, domain, [1])
    assert_equal 2, id3.min_entropy_index(DATA_ITEMS, domain, [1, 0])
  end

  def test_split_data_examples
    id3 = build_id3
    domain = id3.domain(DATA_ITEMS)
    res = id3.split_data_examples(DATA_ITEMS, domain, 0)
    assert_equal(SPLIT_DATA_ITEMS_BY_CITY, res)
    res = id3.split_data_examples(DATA_ITEMS, domain, 1)
    assert_equal(SPLIT_DATA_ITEMS_BY_AGE, res)
  end

  def test_most_freq
    id3 = build_id3
    domain = id3.domain(DATA_ITEMS)
    assert_equal 'Y', id3.most_freq(DATA_ITEMS, domain)
    assert_equal 'Y', id3.most_freq(SPLIT_DATA_ITEMS_BY_AGE[3], domain)
    assert_equal 'N', id3.most_freq(SPLIT_DATA_ITEMS_BY_AGE[2], domain)
  end

  def test_get_rules
    assert_equal [["marketing_target='N'"]], CategoryNode.new('marketing_target', 'N').get_rules
    id3 = build_id3
    assert_equal EXPECTED_RULES_STRING, id3.get_rules
  end

  def test_eval
    id3 = build_id3
    #if age_range='<30' then marketing_target='Y'
    assert_equal 'Y', id3.eval(['New York', '<30', 'F'])
    assert_equal 'Y', id3.eval(['Chicago', '<30', 'M'])
    #if age_range='[30-50)' and city='Chicago' then marketing_target='Y'
    assert_equal 'Y', id3.eval(['Chicago', '[30-50)', 'F'])
    assert_equal 'Y', id3.eval(['Chicago', '[30-50)', 'M'])
    #if age_range='[30-50)' and city='New York' then marketing_target='N'
    assert_equal 'N', id3.eval(['New York', '[30-50)', 'F'])
    assert_equal 'N', id3.eval(['New York', '[30-50)', 'M'])
    #if age_range='[50-80]' then marketing_target='N'
    assert_equal 'N', id3.eval(['New York', '[50-80]', 'F'])
    assert_equal 'N', id3.eval(['Chicago', '[50-80]', 'M'])
    #if age_range='>80' then marketing_target='Y'
    assert_equal 'Y', id3.eval(['New York', '>80', 'M'])
    assert_equal 'Y', id3.eval(['Chicago', '>80', 'F'])
  end

  # Evaluates the generated rule text as Ruby. The locals age_range, city and
  # marketing_target must keep these names: the rule text reads and assigns
  # them through the method's binding.
  def test_rules_eval
    id3 = build_id3
    #if age_range='<30' then marketing_target='Y'
    age_range = '<30'
    marketing_target = nil
    eval id3.get_rules
    assert_equal 'Y', marketing_target
    #if age_range='[30-50)' and city='New York' then marketing_target='N'
    age_range='[30-50)'
    city='New York'
    eval id3.get_rules
    assert_equal 'N', marketing_target
  end

  private

  # Builds an ID3 classifier from the labelled fixture.
  def build_id3
    ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
  end

end
|
207
|
+
|
208
|
+
|
@@ -0,0 +1,79 @@
|
|
1
|
+
require 'test/unit'
|
2
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/multilayer_perceptron'
|
3
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/data/data_set'
|
4
|
+
|
5
|
+
# Make all accessors and methods public
|
6
|
+
# Test-only reopening of MultilayerPerceptron: adds accessors for the listed
# instance variables and makes two helper methods publicly callable.
class Ai4r::Classifiers::MultilayerPerceptron
  attr_accessor :data_set, :class_value, :network, :domains, :outputs
  public :get_max_index, :data_to_output
end
|
11
|
+
|
12
|
+
# Unit tests for the MultilayerPerceptron classifier, using a seven-row
# categorical fixture (city, age range, gender, class value).
class MultilayerPerceptronTest < Test::Unit::TestCase

  include Ai4r::Classifiers
  include Ai4r::Data

  @@data_set = DataSet.new(:data_items =>[ ['New York', '<30', 'M', 'Y'],
                                           ['Chicago', '<30', 'M', 'Y'],
                                           ['New York', '<30', 'M', 'Y'],
                                           ['New York', '[30-50)', 'F', 'N'],
                                           ['Chicago', '[30-50)', 'F', 'Y'],
                                           ['New York', '[30-50)', 'F', 'N'],
                                           ['Chicago', '[50-80]', 'M', 'N'],
                                         ])

  # Checks the parameter values of a freshly constructed classifier.
  def test_initialize
    classifier = MultilayerPerceptron.new
    assert_equal 1, classifier.active_node_value
    assert_equal 0, classifier.inactive_node_value
    assert_equal Ai4r::NeuralNetwork::Backpropagation, classifier.network_class
    assert_equal [], classifier.hidden_layers
    assert classifier.network_parameters
    assert classifier.network_parameters.empty?
    assert classifier.training_iterations > 1
  end

  # Building on an empty data set must raise; otherwise the network structure
  # is input nodes, then any configured hidden layers, then output nodes.
  def test_build
    assert_raise(ArgumentError) { MultilayerPerceptron.new.build(DataSet.new) }
    classifier = MultilayerPerceptron.new
    # A single iteration is enough here: only the structure is inspected.
    classifier.training_iterations = 1
    classifier.build(@@data_set)
    assert_equal [7,2], classifier.network.structure
    classifier.hidden_layers = [6, 4]
    classifier.build(@@data_set)
    assert_equal [7,6,4,2], classifier.network.structure
  end

  # Classifies four instances that also appear in the training fixture.
  def test_eval
    classifier = MultilayerPerceptron.new.build(@@data_set)
    assert classifier
    assert_equal('N', classifier.eval(['Chicago', '[50-80]', 'M']))
    assert_equal('N', classifier.eval(['New York', '[30-50)', 'F']))
    assert_equal('Y', classifier.eval(['New York', '<30', 'M']))
    assert_equal('Y', classifier.eval(['Chicago', '[30-50)', 'F']))
  end

  # get_rules is expected to return text containing "raise".
  def test_get_rules
    assert_match(/raise/, MultilayerPerceptron.new.get_rules)
  end

  # get_max_index returns the position of the largest element.
  def test_get_max_index
    classifier = MultilayerPerceptron.new
    assert_equal(0, classifier.get_max_index([3, 1, 0.2, -9, 0, 2.99]))
    assert_equal(2, classifier.get_max_index([3, 1, 5, -9, 0, 2.99]))
    assert_equal(5, classifier.get_max_index([3, 1, 5, -9, 0, 6]))
  end

  # data_to_output maps a class value to a one-hot array, following that
  # value's position in the last domain.
  def test_data_to_output
    classifier = MultilayerPerceptron.new
    classifier.outputs = 4
    classifier.domains = [nil, nil, nil, ["A", "B", "C", "D"]]
    assert_equal([1,0,0,0], classifier.data_to_output("A"))
    assert_equal([0,0,1,0], classifier.data_to_output("C"))
    assert_equal([0,0,0,1], classifier.data_to_output("D"))
  end

end
|
79
|
+
|
@@ -0,0 +1,43 @@
|
|
1
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/naive_bayes'
|
2
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/data/data_set'
|
3
|
+
require 'test/unit'
|
4
|
+
|
5
|
+
include Ai4r::Classifiers
|
6
|
+
include Ai4r::Data
|
7
|
+
|
8
|
+
# Unit tests for the NaiveBayes classifier, using a ten-row car-theft fixture
# (color, type, origin, and the class value "Stolen?").
class NaiveBayesTest < Test::Unit::TestCase

  @@data_labels = [ "Color","Type","Origin","Stolen?" ]

  @@data_items = [
    ["Red",   "Sports", "Domestic", "Yes"],
    ["Red",   "Sports", "Domestic", "No"],
    ["Red",   "Sports", "Domestic", "Yes"],
    ["Yellow","Sports", "Domestic", "No"],
    ["Yellow","Sports", "Imported", "Yes"],
    ["Yellow","SUV",    "Imported", "No"],
    ["Yellow","SUV",    "Imported", "Yes"],
    ["Yellow","Sports", "Domestic", "No"],
    ["Red",   "SUV",    "Imported", "No"],
    ["Red",   "Sports", "Imported", "Yes"]
  ]

  # Builds the data set from the fixture and a classifier with parameter
  # :m set to 3 before every test.
  def setup
    @data_set = DataSet.new(:data_items => @@data_items, :data_labels => @@data_labels)
    @b = NaiveBayes.new.set_parameters({:m=>3}).build @data_set
  end

  # An instance that does not appear in the fixture is classified as "No".
  def test_eval
    result = @b.eval(["Red", "SUV", "Domestic"])
    assert_equal "No", result
  end

  # The probability map has one entry per class value, with values near
  # 0.42 ("Yes") and 0.58 ("No"), each within 0.1.
  def test_get_probability_map
    map = @b.get_probability_map(["Red", "SUV", "Domestic"])
    assert_equal 2, map.keys.length
    assert_in_delta 0.42, map["Yes"], 0.1
    assert_in_delta 0.58, map["No"], 0.1
  end

end
|
@@ -0,0 +1,62 @@
|
|
1
|
+
require 'test/unit'
|
2
|
+
require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/one_r'
|
3
|
+
|
4
|
+
# Unit tests for the OneR classifier, using a seven-row categorical fixture
# (city, age range, gender, class value).
class OneRTest < Test::Unit::TestCase

  include Ai4r::Classifiers
  include Ai4r::Data

  @@data_examples = [
    ['New York', '<30',     'M', 'Y'],
    ['Chicago',  '<30',     'M', 'Y'],
    ['New York', '<30',     'M', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['Chicago',  '[30-50)', 'F', 'Y'],
    ['New York', '[30-50)', 'F', 'N'],
    ['Chicago',  '[50-80]', 'M', 'N']
  ]

  @@data_labels = [ 'city', 'age_range', 'gender', 'marketing_target' ]

  # Covers three things: the empty-data-set error, the default labels used
  # when none are given, and the attribute picked for the rule when explicit
  # labels are supplied.
  def test_build
    assert_raise(ArgumentError) { OneR.new.build(DataSet.new) }

    unlabelled = DataSet.new(:data_items => @@data_examples)
    one_r = OneR.new.build(unlabelled)
    assert_not_nil(one_r.data_set.data_labels)
    assert_not_nil(one_r.rule)
    assert_equal("attribute_1", one_r.data_set.data_labels.first)
    assert_equal("class_value", one_r.data_set.data_labels.last)

    labelled = DataSet.new(:data_items => @@data_examples,
                           :data_labels => @@data_labels)
    one_r = OneR.new.build(labelled)
    assert_not_nil(one_r.data_set.data_labels)
    assert_not_nil(one_r.rule)
    assert_equal("city", one_r.data_set.data_labels.first)
    assert_equal("marketing_target", one_r.data_set.data_labels.last)
    assert_equal(1, one_r.rule[:attr_index])
  end

  # Classifies three instances with a classifier built without labels.
  def test_eval
    one_r = OneR.new.build(DataSet.new(:data_items => @@data_examples))
    assert_equal("Y", one_r.eval(['New York', '<30', 'M']))
    assert_equal("N", one_r.eval(['New York', '[30-50)', 'M']))
    assert_equal("N", one_r.eval(['Chicago', '[50-80]', 'M']))
  end

  # Evaluates the generated rule text as Ruby. The locals marketing_target and
  # age_range must keep these names: the rule text reads and assigns them
  # through the method's binding.
  def test_get_rules
    training_set = DataSet.new(:data_items => @@data_examples,
                               :data_labels => @@data_labels)
    one_r = OneR.new.build(training_set)
    marketing_target = nil
    age_range = nil

    eval(one_r.get_rules)
    assert_nil(marketing_target)

    age_range = '<30'
    eval(one_r.get_rules)
    assert_equal("Y", marketing_target)

    age_range = '[30-50)'
    eval(one_r.get_rules)
    assert_equal("N", marketing_target)

    age_range = '[50-80]'
    eval(one_r.get_rules)
    assert_equal("N", marketing_target)
  end

end
|
62
|
+
|