decisiontree 0.2.0 → 0.3.0
This diff shows the changes between publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
- data/CHANGELOG.txt +8 -0
- data/lib/decisiontree/id3_tree.rb +201 -27
- data/lib/decisiontree/version.rb +1 -1
- metadata +2 -2
data/CHANGELOG.txt
CHANGED
@@ -7,3 +7,11 @@
 * Added support for multiple and symbolic outputs and graphing of continuous trees.
 * Modified to return the default value when no branches are suitable for the input.
 * Refactored entropy code.
+
+0.3.0 - Sept. 15/07
+* ID3Tree can now handle inconsistent datasets.
+* Ruleset is a new class that trains an ID3Tree with 2/3 of the training data,
+  converts it into a set of rules and prunes the rules with the remaining 1/3
+  of the training data (in a C4.5 way).
+* Bagging is a bagging-based trainer (quite obvious), which trains 10 Ruleset
+  trainers and when predicting chooses the best output based on voting.
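For orientation before the code diff: a minimal usage sketch of what 0.3.0 adds, assembled only from the classes and signatures visible in this diff. The `require` line, the temperature rows, and the labels are illustrative assumptions, not taken from the package.

    require 'decisiontree'

    # Invented sample data: each row is [attribute value(s)..., classification].
    attributes = ['Temperature']
    data = [
      [36.6, 'healthy'], [37.0, 'sick'], [38.5, 'sick'],
      [36.7, 'healthy'], [38.2, 'sick'], [36.8, 'healthy'],
      [36.9, 'healthy'], [39.0, 'sick'], [37.1, 'sick'],
    ]

    # The plain tree API; note ID3Tree#predict now returns a pair (see below).
    dec_tree = DecisionTree::ID3Tree.new(attributes, data, 'healthy', :continuous)
    dec_tree.train
    prediction, weight = dec_tree.predict([37.5])

    # Ruleset (new) trains an ID3Tree on ~2/3 of the data and prunes the
    # extracted rules against the remaining ~1/3, C4.5-style.
    ruleset = DecisionTree::Ruleset.new(attributes, data, 'healthy', :continuous)
    ruleset.train
    prediction, accuracy = ruleset.predict([37.5])

    # Bagging (new) trains 10 Rulesets and predicts by accuracy-weighted voting.
    bagger = DecisionTree::Bagging.new(attributes, data, 'healthy', :continuous)
    bagger.train
    prediction, confidence = bagger.predict([37.5])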
data/lib/decisiontree/id3_tree.rb
CHANGED
@@ -9,34 +9,54 @@ rescue LoadError
   STDERR.puts "graph/graphviz_dot not installed, graphing functionality not included."
 end
 
+class Object
+  def save_to_file(filename)
+    File.open(filename, 'w+' ) { |f| f << Marshal.dump(self) }
+  end
+
+  def self.load_from_file(filename)
+    Marshal.load( File.read( filename ) )
+  end
+end
+
 class Array
-  [old lines 13-25 not captured in this view]
-  end
-  result
+  def classification; collect { |v| v.last }; end
+
+  # calculate information entropy
+  def entropy
+    return 0 if empty?
+
+    info = {}
+    total = 0
+    each {|i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1}
+
+    result = 0
+    info.each do |symbol, count|
+      result += -count.to_f/total*Math.log(count.to_f/total)/Math.log(2.0) if (count > 0)
     end
+    result
+  end
 end
 
 module DecisionTree
+  Node = Struct.new(:attribute, :threshold, :gain)
+
   class ID3Tree
-    Node = Struct.new(:attribute, :threshold, :gain)
     def initialize(attributes, data, default, type)
       @used, @tree, @type = {}, {}, type
       @data, @attributes, @default = data, attributes, default
     end
-
+
     def train(data=@data, attributes=@attributes, default=@default)
+      initialize(attributes, data, default, @type)
+
+      # Remove samples with same attributes leaving most common classification
+      data2 = data.inject({}) {|hash, d| hash[d.slice(0..-2)] ||= Hash.new(0); hash[d.slice(0..-2)][d.last] += 1; hash }.map{|key,val| key + [val.sort_by{ |k, v| v }.last.first]}
+
+      @tree = id3_train(data2, attributes, default)
+    end
+
+    def id3_train(data, attributes, default, used={})
       # Choose a fitness algorithm
       case @type
       when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
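Two notes on the additions above. Because the Marshal helpers are monkey-patched onto Object, any object can be saved, and `load_from_file`, defined as a class method of Object, is inherited by every class, including ID3Tree. The refactored `Array#entropy` is ordinary base-2 Shannon entropy; the expected values below are sanity checks, and the file name is made up.

    # Persist and restore a trained tree (continuing the sketch above):
    dec_tree.save_to_file('tree.marshal')
    dec_tree = DecisionTree::ID3Tree.load_from_file('tree.marshal')

    # Array#entropy sanity checks:
    ['sick', 'sick', 'healthy', 'healthy'].entropy  # => 1.0 (even 50/50 split)
    ['sick', 'sick', 'sick'].entropy                # => 0.0 (pure split)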
@@ -47,43 +67,49 @@ module DecisionTree
 
       # return classification if all examples have the same classification
       return data.first.last if data.classification.uniq.size == 1
-
+
       # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
       performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) }
       max = performance.max { |a,b| a[0] <=> b[0] }
       best = Node.new(attributes[performance.index(max)], max[1], max[0])
+      best.threshold = nil if @type == :discrete
       @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
       tree, l = {best => {}}, ['>=', '<']
 
       case @type
       when :continuous
-        data.partition { |d| d[attributes.index(best.attribute)]
-          tree[best][String.new(l[i])] =
+        data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
+          tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
         }
       when :discrete
         values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
         partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } }
         partitions.each_with_index { |examples, i|
-          tree[best][values[i]] =
+          tree[best][values[i]] = id3_train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
         }
       end
-
-
+
+      tree
     end
 
     # ID3 for binary classification of continuous variables (e.g. healthy / sick based on temperature thresholds)
     def id3_continuous(data, attributes, attribute)
       values, thresholds = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort, []
+      return [-1, -1] if values.size == 1
       values.each_index { |i| thresholds.push((values[i]+(values[i+1].nil? ? values[i] : values[i+1])).to_f / 2) }
-      thresholds
+      thresholds.pop
+      #thresholds -= used[attribute] if used.has_key? attribute
 
       gain = thresholds.collect { |threshold|
-        sp = data.partition { |d| d[attributes.index(attribute)]
+        sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
         pos = (sp[0].size).to_f / data.size
         neg = (sp[1].size).to_f / data.size
 
         [data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold]
       }.max { |a,b| a[0] <=> b[0] }
+
+      return [-1, -1] if gain.size == 0
+      gain
     end
 
     # ID3 for discrete label cases
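To make the threshold search in `id3_continuous` concrete, a small worked example, assuming the code exactly as shown:

    # values = [36.6, 37.2, 38.0]  (sorted unique attribute values)
    #   i = 0: (36.6 + 37.2) / 2 = 36.9
    #   i = 1: (37.2 + 38.0) / 2 = 37.6
    #   i = 2: (38.0 + 38.0) / 2 = 38.0  <- degenerate self-midpoint, discarded by thresholds.pop
    # thresholds = [36.9, 37.6]
    #
    # Each candidate t is then scored by information gain over the partition
    # S1 = {d | d[attr] >= t}, S2 = {d | d[attr] < t}:
    #   gain(t) = H(S) - (|S1|/|S|)*H(S1) - (|S2|/|S|)*H(S2)
    # and the method returns the [gain, threshold] pair with the largest gain.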
@@ -96,7 +122,7 @@ module DecisionTree
     end
 
     def predict(test)
-      @type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test)
+      return (@type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test)), 1
     end
 
     def graph(filename)
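Note the breaking change in `predict` above: in 0.2.0 it returned the bare classification; in 0.3.0 it returns a pair, mirroring `Ruleset#predict` and `Bagging#predict` (whose second element is an accuracy estimate). For the bare tree the second element is always 1:

    # 0.2.0: prediction = dec_tree.predict(test)
    # 0.3.0:
    prediction, weight = dec_tree.predict(test)  # weight is the constant 1 here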
@@ -104,6 +130,30 @@ module DecisionTree
       dgp.write_to_file("#{filename}.png", "png")
     end
 
+    def ruleset
+      rs = Ruleset.new(@attributes, @data, @default, @type)
+      rs.rules = build_rules
+      rs
+    end
+
+    def build_rules(tree=@tree)
+      attr = tree.to_a.first
+      cases = attr[1].to_a
+      rules = []
+      cases.each do |c,child|
+        if child.is_a?(Hash) then
+          build_rules(child).each do |r|
+            r2 = r.clone
+            r2.premises.unshift([attr.first, c])
+            rules << r2
+          end
+        else
+          rules << Rule.new(@attributes, [[attr.first, c]], child)
+        end
+      end
+      rules
+    end
+
     private
     def descend_continuous(tree, test)
       attr = tree.to_a.first
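A sketch of what `build_rules` produces, using a hypothetical two-level discrete tree (the attribute and value names are invented). Each root-to-leaf path becomes one Rule whose premises are the (Node, branch value) pairs collected on the way down:

    # Tree (Hash keys are Node structs, abbreviated here to their attribute):
    #   outlook -> 'sunny' -> humidity -> 'high'   -> 'no'
    #                                  -> 'normal' -> 'yes'
    #           -> 'rain'  -> 'yes'
    #
    # build_rules yields three rules:
    #   IF outlook = sunny AND humidity = high    THEN no
    #   IF outlook = sunny AND humidity = normal  THEN yes
    #   IF outlook = rain                         THEN yes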
@@ -134,7 +184,7 @@ module DecisionTree
         child_text = "#{child.attribute}\n(#{child.object_id})"
       else
         child = attr[1][key]
-        child_text = "#{child}\n(#{child.to_s.object_id})"
+        child_text = "#{child}\n(#{child.to_s.clone.object_id})"
       end
       label_text = "#{key} #{@type == :continuous ? attr[0].threshold : ""}"
 
@@ -145,4 +195,128 @@ module DecisionTree
       return links
     end
   end
+
+  class Rule
+    attr_accessor :premises
+    attr_accessor :conclusion
+    attr_accessor :attributes
+
+    def initialize(attributes,premises=[],conclusion=nil)
+      @attributes, @premises, @conclusion = attributes, premises, conclusion
+    end
+
+    def to_s
+      str = ''
+      @premises.each do |p|
+        str += "#{p.first.attribute} #{p.last} #{p.first.threshold}" if p.first.threshold
+        str += "#{p.first.attribute} = #{p.last}" if !p.first.threshold
+        str += "\n"
+      end
+      str += "=> #{@conclusion} (#{accuracy})"
+    end
+
+    def predict(test)
+      verifies = true;
+      @premises.each do |p|
+        if p.first.threshold then # Continuous
+          if !(p.last == '>=' && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == '<' && test[@attributes.index(p.first.attribute)] < p.first.threshold) then
+            verifies = false; break
+          end
+        else # Discrete
+          if test[@attributes.index(p.first.attribute)] != p.last then
+            verifies = false; break
+          end
+        end
+      end
+      return @conclusion if verifies
+      return nil
+    end
+
+    def get_accuracy(data)
+      correct = 0; total = 0
+      data.each do |d|
+        prediction = predict(d)
+        correct += 1 if d.last == prediction
+        total += 1 if !prediction.nil?
+      end
+      (correct.to_f + 1) / (total.to_f + 2)
+    end
+
+    def accuracy(data=nil)
+      data.nil? ? @accuracy : @accuracy = get_accuracy(data)
+    end
+  end
+
+  class Ruleset
+    attr_accessor :rules
+
+    def initialize(attributes, data, default, type)
+      @attributes, @default, @type = attributes, default, type
+      mixed_data = data.sort_by {rand}
+      cut = (mixed_data.size.to_f * 0.67).to_i
+      @train_data = mixed_data.slice(0..cut-1)
+      @prune_data = mixed_data.slice(cut..-1)
+    end
+
+    def train(train_data=@train_data, attributes=@attributes, default=@default)
+      dec_tree = DecisionTree::ID3Tree.new(attributes, train_data, default, @type)
+      dec_tree.train
+      @rules = dec_tree.build_rules
+      @rules.each { |r| r.accuracy(train_data) } # Calculate accuracy
+      prune
+    end
+
+    def prune(data=@prune_data)
+      @rules.each do |r|
+        (1..r.premises.size).each do
+          acc1 = r.accuracy(data)
+          p = r.premises.pop
+          if acc1 > r.get_accuracy(data) then
+            r.premises.push(p); break
+          end
+        end
+      end
+      @rules = @rules.sort_by{|r| -r.accuracy(data)}
+    end
+
+    def to_s
+      str = ''; @rules.each { |rule| str += "#{rule}\n\n" }
+      str
+    end
+
+    def predict(test)
+      @rules.each do |r|
+        prediction = r.predict(test)
+        return prediction, r.accuracy unless prediction.nil?
+      end
+      return @default, 0.0
+    end
+  end
+
+  class Bagging
+    attr_accessor :classifiers
+    def initialize(attributes, data, default, type)
+      @classifiers, @type = [], type
+      @data, @attributes, @default = data, attributes, default
+    end
+
+    def train(data=@data, attributes=@attributes, default=@default)
+      @classifiers = []
+      10.times { @classifiers << Ruleset.new(attributes, data, default, @type) }
+      @classifiers.each do |c|
+        c.train(data, attributes, default)
+      end
+    end
+
+    def predict(test)
+      predictions = Hash.new(0)
+      @classifiers.each do |c|
+        p, accuracy = c.predict(test)
+        predictions[p] += accuracy unless p.nil?
+      end
+      return @default, 0.0 if predictions.empty?
+      winner = predictions.sort_by {|k,v| -v}.first
+      return winner[0], winner[1].to_f / @classifiers.size.to_f
+    end
+  end
 end
data/lib/decisiontree/version.rb
CHANGED
metadata
CHANGED
@@ -3,8 +3,8 @@ rubygems_version: 0.9.2
 specification_version: 1
 name: decisiontree
 version: !ruby/object:Gem::Version
-  version: 0.
-date: 2007-
+  version: 0.3.0
+date: 2007-09-15 00:00:00 -04:00
 summary: ID3-based implementation of the M.L. Decision Tree algorithm
 require_paths:
 - lib