decisiontree 0.2.0 → 0.3.0

data/CHANGELOG.txt CHANGED
@@ -7,3 +7,11 @@
  * Added support for multiple and symbolic outputs and graphing of continuous trees.
  * Modified to return the default value when no branches are suitable for the input.
  * Refactored entropy code.
+
+ 0.3.0 - Sept. 15/07
+ * ID3Tree can now handle inconsistent datasets.
+ * Ruleset is a new class that trains an ID3Tree with 2/3 of the training data,
+   converts it into a set of rules, and prunes the rules with the remaining 1/3
+   of the training data (in a C4.5 way).
+ * Bagging is a bagging-based trainer that trains 10 Ruleset trainers and, when
+   predicting, chooses the best output by voting.
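For orientation, here is a minimal usage sketch of the 0.3.0 API described above, based on the constructor and train/predict signatures visible in the library diff below; the require lines, attribute name, and data rows are hypothetical:

    require 'rubygems'
    require 'decisiontree'   # assumed entry point for the gem

    # Hypothetical training data: attribute values first, classification last.
    attributes = ['temperature']
    data = [[36.6, 'healthy'], [37.1, 'sick'], [38.2, 'sick'],
            [36.8, 'healthy'], [37.9, 'sick'], [36.7, 'healthy']]

    # A single ID3 tree; :continuous selects threshold-based splits.
    tree = DecisionTree::ID3Tree.new(attributes, data, 'healthy', :continuous)
    tree.train
    prediction, confidence = tree.predict([37.5])

    # Bagging trains 10 Ruleset classifiers and combines them by weighted voting.
    bagger = DecisionTree::Bagging.new(attributes, data, 'healthy', :continuous)
    bagger.train
    prediction, confidence = bagger.predict([37.5])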
@@ -9,34 +9,54 @@ rescue LoadError
   STDERR.puts "graph/graphviz_dot not installed, graphing functionality not included."
 end
 
+class Object
+  def save_to_file(filename)
+    File.open(filename, 'w+') { |f| f << Marshal.dump(self) }
+  end
+
+  def self.load_from_file(filename)
+    Marshal.load(File.read(filename))
+  end
+end
+
 class Array
-  def classification; collect { |v| v.last }; end
-
-  # calculate Information entropy
-  def entropy
-    return 0 if empty?
-
-    info = {}
-    total = 0
-    each { |i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1 }
-
-    result = 0
-    info.each do |symbol, count|
-      result += -count.to_f/total * Math.log(count.to_f/total) / Math.log(2.0) if count > 0
-    end
-    result
+  def classification; collect { |v| v.last }; end
+
+  # calculate information entropy
+  def entropy
+    return 0 if empty?
+
+    info = {}
+    total = 0
+    each { |i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1 }
+
+    result = 0
+    info.each do |symbol, count|
+      result += -count.to_f/total * Math.log(count.to_f/total) / Math.log(2.0) if count > 0
    end
+    result
+  end
 end
 
 module DecisionTree
+  Node = Struct.new(:attribute, :threshold, :gain)
+
   class ID3Tree
-    Node = Struct.new(:attribute, :threshold, :gain)
     def initialize(attributes, data, default, type)
       @used, @tree, @type = {}, {}, type
       @data, @attributes, @default = data, attributes, default
     end
-
+
     def train(data=@data, attributes=@attributes, default=@default)
+      initialize(attributes, data, default, @type)
+
+      # Remove samples with the same attributes, keeping the most common classification
+      data2 = data.inject({}) { |hash, d| hash[d.slice(0..-2)] ||= Hash.new(0); hash[d.slice(0..-2)][d.last] += 1; hash }.map { |key, val| key + [val.sort_by { |k, v| v }.last.first] }
+
+      @tree = id3_train(data2, attributes, default)
+    end
+
+    def id3_train(data, attributes, default, used={})
       # Choose a fitness algorithm
       case @type
       when :discrete; fitness = proc { |a,b,c| id3_discrete(a,b,c) }
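The Object monkey-patch at the top of this hunk gives every object, trained trees included, Marshal-based persistence. A sketch of the intended round trip (the file name and tree variable are hypothetical; since load_from_file is defined on Object, every class inherits it):

    tree.save_to_file('tree.dump')                             # writes Marshal.dump(tree) to disk
    tree = DecisionTree::ID3Tree.load_from_file('tree.dump')   # class method inherited from Object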
@@ -47,43 +67,49 @@ module DecisionTree
 
       # return classification if all examples have the same classification
       return data.first.last if data.classification.uniq.size == 1
-
+
       # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
       performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) }
       max = performance.max { |a,b| a[0] <=> b[0] }
       best = Node.new(attributes[performance.index(max)], max[1], max[0])
+      best.threshold = nil if @type == :discrete
       @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
       tree, l = {best => {}}, ['>=', '<']
 
       case @type
       when :continuous
-        data.partition { |d| d[attributes.index(best.attribute)] > best.threshold }.each_with_index { |examples, i|
-          tree[best][String.new(l[i])] = train(examples, attributes, (data.classification.mode rescue 0), &fitness)
+        data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
+          tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
         }
       when :discrete
         values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
         partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } }
         partitions.each_with_index { |examples, i|
-          tree[best][values[i]] = train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
+          tree[best][values[i]] = id3_train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
         }
       end
-
-      @tree = tree
+
+      tree
     end
 
     # ID3 for binary classification of continuous variables (e.g. healthy / sick based on temperature thresholds)
     def id3_continuous(data, attributes, attribute)
       values, thresholds = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort, []
+      return [-1, -1] if values.size == 1
       values.each_index { |i| thresholds.push((values[i] + (values[i+1].nil? ? values[i] : values[i+1])).to_f / 2) }
-      thresholds -= @used[attribute] if @used.has_key? attribute
+      thresholds.pop
+      #thresholds -= used[attribute] if used.has_key? attribute
 
       gain = thresholds.collect { |threshold|
-        sp = data.partition { |d| d[attributes.index(attribute)] > threshold }
+        sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
         pos = (sp[0].size).to_f / data.size
         neg = (sp[1].size).to_f / data.size
 
         [data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold]
       }.max { |a,b| a[0] <=> b[0] }
+
+      return [-1, -1] if gain.size == 0
+      gain
     end
 
     # ID3 for discrete label cases
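A worked illustration of the threshold handling in id3_continuous, on hypothetical values: each candidate threshold is the midpoint of two neighbouring attribute values, and the new thresholds.pop discards the last entry, which is not a midpoint at all but the largest value paired with itself:

    values = [36.6, 37.2, 38.5]   # unique, sorted values of one attribute
    thresholds = []
    values.each_index { |i| thresholds.push((values[i] + (values[i+1].nil? ? values[i] : values[i+1])).to_f / 2) }
    thresholds       # => [36.9, 37.85, 38.5]
    thresholds.pop   # drop the self-paired 38.5
    thresholds       # => [36.9, 37.85]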
@@ -96,7 +122,7 @@ module DecisionTree
     end
 
     def predict(test)
-      @type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test)
+      return (@type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test)), 1
     end
 
     def graph(filename)
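This is a breaking change to the public API: predict now returns a [prediction, confidence] pair instead of a bare prediction. The confidence is a constant 1 for a plain tree (Ruleset and Bagging below return real accuracies), so 0.2.0-era callers need to destructure; `tree` and `sample` here are hypothetical:

    prediction, confidence = tree.predict(sample)   # 0.2.0: prediction = tree.predict(sample)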
@@ -104,6 +130,30 @@ module DecisionTree
       dgp.write_to_file("#{filename}.png", "png")
     end
 
+    def ruleset
+      rs = Ruleset.new(@attributes, @data, @default, @type)
+      rs.rules = build_rules
+      rs
+    end
+
+    def build_rules(tree=@tree)
+      attr = tree.to_a.first
+      cases = attr[1].to_a
+      rules = []
+      cases.each do |c, child|
+        if child.is_a?(Hash) then
+          build_rules(child).each do |r|
+            r2 = r.clone
+            r2.premises.unshift([attr.first, c])
+            rules << r2
+          end
+        else
+          rules << Rule.new(@attributes, [[attr.first, c]], child)
+        end
+      end
+      rules
+    end
+
     private
     def descend_continuous(tree, test)
       attr = tree.to_a.first
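build_rules flattens the nested tree hash into one Rule per root-to-leaf path, turning each [node, branch label] pair along the path into a premise. A sketch of pulling the rules off a hypothetical trained tree (accuracies print blank until Ruleset#train computes them):

    rs = tree.ruleset   # Ruleset wrapping the flattened rules
    puts rs             # one "premises ... => conclusion (accuracy)" block per rule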
@@ -134,7 +184,7 @@ module DecisionTree
           child_text = "#{child.attribute}\n(#{child.object_id})"
         else
           child = attr[1][key]
-          child_text = "#{child}\n(#{child.to_s.object_id})"
+          child_text = "#{child}\n(#{child.to_s.clone.object_id})"
         end
         label_text = "#{key} #{@type == :continuous ? attr[0].threshold : ""}"
 
@@ -145,4 +195,128 @@ module DecisionTree
       return links
     end
   end
+
+  class Rule
+    attr_accessor :premises
+    attr_accessor :conclusion
+    attr_accessor :attributes
+
+    def initialize(attributes, premises=[], conclusion=nil)
+      @attributes, @premises, @conclusion = attributes, premises, conclusion
+    end
+
+    def to_s
+      str = ''
+      @premises.each do |p|
+        str += "#{p.first.attribute} #{p.last} #{p.first.threshold}" if p.first.threshold
+        str += "#{p.first.attribute} = #{p.last}" if !p.first.threshold
+        str += "\n"
+      end
+      str += "=> #{@conclusion} (#{accuracy})"
+    end
+
+    def predict(test)
+      verifies = true
+      @premises.each do |p|
+        if p.first.threshold then # Continuous
+          if !(p.last == '>=' && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == '<' && test[@attributes.index(p.first.attribute)] < p.first.threshold) then
+            verifies = false; break
+          end
+        else # Discrete
+          if test[@attributes.index(p.first.attribute)] != p.last then
+            verifies = false; break
+          end
+        end
+      end
+      return @conclusion if verifies
+      return nil
+    end
+
+    def get_accuracy(data)
+      correct = 0; total = 0
+      data.each do |d|
+        prediction = predict(d)
+        correct += 1 if d.last == prediction
+        total += 1 if !prediction.nil?
+      end
+      (correct.to_f + 1) / (total.to_f + 2)
+    end
+
+    def accuracy(data=nil)
+      data.nil? ? @accuracy : @accuracy = get_accuracy(data)
+    end
+  end
+
+  class Ruleset
+    attr_accessor :rules
+
+    def initialize(attributes, data, default, type)
+      @attributes, @default, @type = attributes, default, type
+      mixed_data = data.sort_by { rand }
+      cut = (mixed_data.size.to_f * 0.67).to_i
+      @train_data = mixed_data.slice(0..cut-1)
+      @prune_data = mixed_data.slice(cut..-1)
+    end
+
+    def train(train_data=@train_data, attributes=@attributes, default=@default)
+      dec_tree = DecisionTree::ID3Tree.new(attributes, train_data, default, @type)
+      dec_tree.train
+      @rules = dec_tree.build_rules
+      @rules.each { |r| r.accuracy(train_data) } # Calculate accuracy
+      prune
+    end
+
+    def prune(data=@prune_data)
+      @rules.each do |r|
+        (1..r.premises.size).each do
+          acc1 = r.accuracy(data)
+          p = r.premises.pop
+          if acc1 > r.get_accuracy(data) then
+            r.premises.push(p); break
+          end
+        end
+      end
+      @rules = @rules.sort_by { |r| -r.accuracy(data) }
+    end
+
+    def to_s
+      str = ''; @rules.each { |rule| str += "#{rule}\n\n" }
+      str
+    end
+
+    def predict(test)
+      @rules.each do |r|
+        prediction = r.predict(test)
+        return prediction, r.accuracy unless prediction.nil?
+      end
+      return @default, 0.0
+    end
+  end
+
+  class Bagging
+    attr_accessor :classifiers
+    def initialize(attributes, data, default, type)
+      @classifiers, @type = [], type
+      @data, @attributes, @default = data, attributes, default
+    end
+
+    def train(data=@data, attributes=@attributes, default=@default)
+      @classifiers = []
+      10.times { @classifiers << Ruleset.new(attributes, data, default, @type) }
+      @classifiers.each do |c|
+        c.train(data, attributes, default)
+      end
+    end
+
+    def predict(test)
+      predictions = Hash.new(0)
+      @classifiers.each do |c|
+        p, accuracy = c.predict(test)
+        predictions[p] += accuracy unless p.nil?
+      end
+      return @default, 0.0 if predictions.empty?
+      winner = predictions.sort_by { |k,v| -v }.first
+      return winner[0], winner[1].to_f / @classifiers.size.to_f
+    end
+  end
 end
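One detail worth noting in Rule#get_accuracy above: (correct + 1) / (total + 2) is a Laplace-style correction, so sparsely covered rules are pulled toward 0.5 rather than dividing by zero during pruning. On hypothetical counts:

    (8 + 1).to_f / (10 + 2)   # => 0.75 for a rule that gets 8 of the 10 samples it covers right
    (0 + 1).to_f / (0 + 2)    # => 0.5 for a rule that covers no pruning samples at all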
@@ -1,7 +1,7 @@
 module DecisionTree #:nodoc:
   module VERSION #:nodoc:
     MAJOR = 0
-    MINOR = 2
+    MINOR = 3
     TINY = 0
 
     STRING = [MAJOR, MINOR, TINY].join('.')
metadata CHANGED
@@ -3,8 +3,8 @@ rubygems_version: 0.9.2
 specification_version: 1
 name: decisiontree
 version: !ruby/object:Gem::Version
-  version: 0.2.0
-date: 2007-07-07 00:00:00 -04:00
+  version: 0.3.0
+date: 2007-09-15 00:00:00 -04:00
 summary: ID3-based implementation of the M.L. Decision Tree algorithm
 require_paths:
 - lib