decisiontree 0.2.0 → 0.3.0
- data/CHANGELOG.txt +8 -0
- data/lib/decisiontree/id3_tree.rb +201 -27
- data/lib/decisiontree/version.rb +1 -1
- metadata +2 -2
data/CHANGELOG.txt
CHANGED
@@ -7,3 +7,11 @@
 * Added support for multiple, and symbolic outputs and graphing of continuos trees.
 * Modified to return the default value when no branches are suitable for the input.
 * Refactored entropy code.
+
+0.3.0 - Sept. 15/07
+* ID3Tree can now handle inconsistent datasets.
+* Ruleset is a new class that trains an ID3Tree with 2/3 of the training data,
+  converts it into a set of rules and prunes the rules with the remaining 1/3
+  of the training data (in a C4.5 way).
+* Bagging is a bagging-based trainer (quite obvious), which trains 10 Ruleset
+  trainers and when predicting chooses the best output based on voting.
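For orientation, here is a minimal usage sketch of the 0.3.0 API described above. The temperature data, the 'healthy' default and the test row are made up for illustration:

    require 'decisiontree'

    attributes = ['Temperature']
    # Each training row: attribute values followed by the classification.
    data = [[36.6, 'healthy'], [37.0, 'healthy'], [38.5, 'sick'], [39.2, 'sick']]

    # Plain ID3 tree; predict now returns [classification, confidence].
    tree = DecisionTree::ID3Tree.new(attributes, data, 'healthy', :continuous)
    tree.train
    prediction, confidence = tree.predict([38.9])

    # Ruleset: trains an ID3Tree on 2/3 of the data, prunes rules with the other 1/3.
    ruleset = DecisionTree::Ruleset.new(attributes, data, 'healthy', :continuous)
    ruleset.train

    # Bagging: trains 10 Ruleset classifiers, predicts by accuracy-weighted voting.
    bagger = DecisionTree::Bagging.new(attributes, data, 'healthy', :continuous)
    bagger.train
    prediction, confidence = bagger.predict([38.9])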
data/lib/decisiontree/id3_tree.rb
CHANGED
@@ -9,34 +9,54 @@ rescue LoadError
   STDERR.puts "graph/graphviz_dot not installed, graphing functionality not included."
 end
 
+class Object
+  def save_to_file(filename)
+    File.open(filename, 'w+' ) { |f| f << Marshal.dump(self) }
+  end
+
+  def self.load_from_file(filename)
+    Marshal.load( File.read( filename ) )
+  end
+end
+
 class Array
-  [old lines 13-25 of this file are not rendered in the source diff view]
-    end
-    result
+  def classification; collect { |v| v.last }; end
+
+  # calculate information entropy
+  def entropy
+    return 0 if empty?
+
+    info = {}
+    total = 0
+    each {|i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1}
+
+    result = 0
+    info.each do |symbol, count|
+      result += -count.to_f/total*Math.log(count.to_f/total)/Math.log(2.0) if (count > 0)
     end
+    result
+  end
 end
 
 module DecisionTree
+  Node = Struct.new(:attribute, :threshold, :gain)
+
   class ID3Tree
-    Node = Struct.new(:attribute, :threshold, :gain)
     def initialize(attributes, data, default, type)
       @used, @tree, @type = {}, {}, type
       @data, @attributes, @default = data, attributes, default
     end
 
     def train(data=@data, attributes=@attributes, default=@default)
+      initialize(attributes, data, default, @type)
+
+      # Remove samples with same attributes leaving most common classification
+      data2 = data.inject({}) {|hash, d| hash[d.slice(0..-2)] ||= Hash.new(0); hash[d.slice(0..-2)][d.last] += 1; hash }.map{|key,val| key + [val.sort_by{ |k, v| v }.last.first]}
+
+      @tree = id3_train(data2, attributes, default)
+    end
+
+    def id3_train(data, attributes, default, used={})
       # Choose a fitness algorithm
       case @type
       when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
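The Object monkey-patch above gives every object, including a trained tree, Marshal-based persistence, and Array#entropy is the quantity the fitness functions below work against. Continuing the sketch from the changelog section (the filename is arbitrary):

    tree.save_to_file('tree.dump')                 # Marshal.dump the trained tree to disk
    restored = Object.load_from_file('tree.dump')  # and load it back later

    ['sick', 'sick', 'healthy'].entropy            # => ~0.918 bits for a 2/3 vs 1/3 split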
@@ -47,43 +67,49 @@ module DecisionTree
 
       # return classification if all examples have the same classification
       return data.first.last if data.classification.uniq.size == 1
 
       # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
       performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) }
       max = performance.max { |a,b| a[0] <=> b[0] }
       best = Node.new(attributes[performance.index(max)], max[1], max[0])
+      best.threshold = nil if @type == :discrete
       @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
       tree, l = {best => {}}, ['>=', '<']
 
       case @type
       when :continuous
-        data.partition { |d| d[attributes.index(best.attribute)]
-          tree[best][String.new(l[i])] =
+        data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
+          tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
         }
       when :discrete
         values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
         partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } }
         partitions.each_with_index { |examples, i|
-          tree[best][values[i]] =
+          tree[best][values[i]] = id3_train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
         }
       end
 
+      tree
     end
 
     # ID3 for binary classification of continuous variables (e.g. healthy / sick based on temperature thresholds)
     def id3_continuous(data, attributes, attribute)
       values, thresholds = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort, []
+      return [-1, -1] if values.size == 1
       values.each_index { |i| thresholds.push((values[i]+(values[i+1].nil? ? values[i] : values[i+1])).to_f / 2) }
-      thresholds
+      thresholds.pop
+      #thresholds -= used[attribute] if used.has_key? attribute
 
       gain = thresholds.collect { |threshold|
-        sp = data.partition { |d| d[attributes.index(attribute)]
+        sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
         pos = (sp[0].size).to_f / data.size
         neg = (sp[1].size).to_f / data.size
 
         [data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold]
       }.max { |a,b| a[0] <=> b[0] }
+
+      return [-1, -1] if gain.size == 0
+      gain
     end
 
     # ID3 for discrete label cases
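To make the threshold generation in id3_continuous concrete: candidate thresholds are the midpoints between consecutive sorted values, and the degenerate final midpoint (a value paired with itself) is discarded again by thresholds.pop. An equivalent sketch of that computation:

    values = [36.6, 37.0, 38.5]
    thresholds = values.each_cons(2).map { |a, b| (a + b) / 2.0 }  # => [36.8, 37.75]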
@@ -96,7 +122,7 @@ module DecisionTree
     end
 
     def predict(test)
-      @type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test)
+      return (@type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test)), 1
     end
 
     def graph(filename)
@@ -104,6 +130,30 @@ module DecisionTree
       dgp.write_to_file("#{filename}.png", "png")
     end
 
+    def ruleset
+      rs = Ruleset.new(@attributes, @data, @default, @type)
+      rs.rules = build_rules
+      rs
+    end
+
+    def build_rules(tree=@tree)
+      attr = tree.to_a.first
+      cases = attr[1].to_a
+      rules = []
+      cases.each do |c,child|
+        if child.is_a?(Hash) then
+          build_rules(child).each do |r|
+            r2 = r.clone
+            r2.premises.unshift([attr.first, c])
+            rules << r2
+          end
+        else
+          rules << Rule.new(@attributes, [[attr.first, c]], child)
+        end
+      end
+      rules
+    end
+
     private
     def descend_continuous(tree, test)
       attr = tree.to_a.first
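build_rules above flattens the nested tree hash into one Rule per root-to-leaf path, and ruleset wraps those rules in a Ruleset. Continuing the earlier sketch:

    rules = tree.build_rules  # one Rule per root-to-leaf path
    rs    = tree.ruleset      # the same rules wrapped in a Ruleset
    puts rules.first          # premises, then "=> conclusion (accuracy)"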
@@ -134,7 +184,7 @@ module DecisionTree
         child_text = "#{child.attribute}\n(#{child.object_id})"
       else
         child = attr[1][key]
-        child_text = "#{child}\n(#{child.to_s.object_id})"
+        child_text = "#{child}\n(#{child.to_s.clone.object_id})"
       end
       label_text = "#{key} #{@type == :continuous ? attr[0].threshold : ""}"
 
@@ -145,4 +195,128 @@ module DecisionTree
       return links
     end
   end
+
+  class Rule
+    attr_accessor :premises
+    attr_accessor :conclusion
+    attr_accessor :attributes
+
+    def initialize(attributes,premises=[],conclusion=nil)
+      @attributes, @premises, @conclusion = attributes, premises, conclusion
+    end
+
+    def to_s
+      str = ''
+      @premises.each do |p|
+        str += "#{p.first.attribute} #{p.last} #{p.first.threshold}" if p.first.threshold
+        str += "#{p.first.attribute} = #{p.last}" if !p.first.threshold
+        str += "\n"
+      end
+      str += "=> #{@conclusion} (#{accuracy})"
+    end
+
+    def predict(test)
+      verifies = true;
+      @premises.each do |p|
+        if p.first.threshold then # Continuous
+          if !(p.last == '>=' && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == '<' && test[@attributes.index(p.first.attribute)] < p.first.threshold) then
+            verifies = false; break
+          end
+        else # Discrete
+          if test[@attributes.index(p.first.attribute)] != p.last then
+            verifies = false; break
+          end
+        end
+      end
+      return @conclusion if verifies
+      return nil
+    end
+
+    def get_accuracy(data)
+      correct = 0; total = 0
+      data.each do |d|
+        prediction = predict(d)
+        correct += 1 if d.last == prediction
+        total += 1 if !prediction.nil?
+      end
+      (correct.to_f + 1) / (total.to_f + 2)
+    end
+
+    def accuracy(data=nil)
+      data.nil? ? @accuracy : @accuracy = get_accuracy(data)
+    end
+  end
+
+  class Ruleset
+    attr_accessor :rules
+
+    def initialize(attributes, data, default, type)
+      @attributes, @default, @type = attributes, default, type
+      mixed_data = data.sort_by {rand}
+      cut = (mixed_data.size.to_f * 0.67).to_i
+      @train_data = mixed_data.slice(0..cut-1)
+      @prune_data = mixed_data.slice(cut..-1)
+    end
+
+    def train(train_data=@train_data, attributes=@attributes, default=@default)
+      dec_tree = DecisionTree::ID3Tree.new(attributes, train_data, default, @type)
+      dec_tree.train
+      @rules = dec_tree.build_rules
+      @rules.each { |r| r.accuracy(train_data) } # Calculate accuracy
+      prune
+    end
+
+    def prune(data=@prune_data)
+      @rules.each do |r|
+        (1..r.premises.size).each do
+          acc1 = r.accuracy(data)
+          p = r.premises.pop
+          if acc1 > r.get_accuracy(data) then
+            r.premises.push(p); break
+          end
+        end
+      end
+      @rules = @rules.sort_by{|r| -r.accuracy(data)}
+    end
+
+    def to_s
+      str = ''; @rules.each { |rule| str += "#{rule}\n\n" }
+      str
+    end
+
+    def predict(test)
+      @rules.each do |r|
+        prediction = r.predict(test)
+        return prediction, r.accuracy unless prediction.nil?
+      end
+      return @default, 0.0
+    end
+  end
+
+  class Bagging
+    attr_accessor :classifiers
+    def initialize(attributes, data, default, type)
+      @classifiers, @type = [], type
+      @data, @attributes, @default = data, attributes, default
+    end
+
+    def train(data=@data, attributes=@attributes, default=@default)
+      @classifiers = []
+      10.times { @classifiers << Ruleset.new(attributes, data, default, @type) }
+      @classifiers.each do |c|
+        c.train(data, attributes, default)
+      end
+    end
+
+    def predict(test)
+      predictions = Hash.new(0)
+      @classifiers.each do |c|
+        p, accuracy = c.predict(test)
+        predictions[p] += accuracy unless p.nil?
+      end
+      return @default, 0.0 if predictions.empty?
+      winner = predictions.sort_by {|k,v| -v}.first
+      return winner[0], winner[1].to_f / @classifiers.size.to_f
+    end
+  end
 end
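One detail worth noting in Rule#get_accuracy above: the (correct + 1) / (total + 2) form is a Laplace correction, which pulls the accuracy of thinly supported rules toward 0.5 so they rank below equally accurate rules with more coverage. For example:

    (8.0 + 1) / (9 + 2)  # rule right on 8 of the 9 samples it covers => ~0.818
    (2.0 + 1) / (2 + 2)  # perfect on only 2 covered samples => 0.75, ranked lower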
data/lib/decisiontree/version.rb
CHANGED
metadata
CHANGED
@@ -3,8 +3,8 @@ rubygems_version: 0.9.2
 specification_version: 1
 name: decisiontree
 version: !ruby/object:Gem::Version
-  version: 0.2.0
-date: 2007-
+  version: 0.3.0
+date: 2007-09-15 00:00:00 -04:00
 summary: ID3-based implementation of the M.L. Decision Tree algorithm
 require_paths:
 - lib