decisiontree 0.2.0 → 0.3.0

This diff shows the published contents of the two package versions as they appear in their public registry, and is provided for informational purposes only.
data/CHANGELOG.txt CHANGED
@@ -7,3 +7,11 @@
   * Added support for multiple and symbolic outputs and graphing of continuous trees.
   * Modified to return the default value when no branches are suitable for the input.
   * Refactored entropy code.
+
+ 0.3.0 - Sept. 15/07
+ * ID3Tree can now handle inconsistent datasets.
+ * Ruleset is a new class that trains an ID3Tree with 2/3 of the training data,
+   converts it into a set of rules and prunes the rules with the remaining 1/3
+   of the training data (in the C4.5 style).
+ * Bagging is a bagging-based trainer (as the name suggests) that trains 10 Ruleset
+   trainers and, when predicting, chooses the best output by voting.
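
For orientation, the new classes described in the 0.3.0 entries above share the existing ID3Tree constructor signature of (attributes, data, default, type). A minimal usage sketch, assuming a toy continuous dataset (the require path, attribute name, rows, and default label are illustrative, not taken from the package):

    require 'decisiontree'  # assumed require path for the gem

    # Hypothetical data: each row is [attribute values..., classification]
    attributes = ['Temperature']
    data = [[98.6, 'healthy'], [99.1, 'healthy'], [102.5, 'sick'], [103.2, 'sick']]

    # Ruleset: trains an ID3Tree on 2/3 of the data, converts it to rules,
    # and prunes them against the held-out 1/3.
    ruleset = DecisionTree::Ruleset.new(attributes, data, 'healthy', :continuous)
    ruleset.train
    prediction, accuracy = ruleset.predict([101.0])

    # Bagging: trains 10 Ruleset classifiers and predicts by accuracy-weighted voting.
    bagger = DecisionTree::Bagging.new(attributes, data, 'healthy', :continuous)
    bagger.train
    prediction, confidence = bagger.predict([101.0])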
@@ -9,34 +9,54 @@ rescue LoadError
   STDERR.puts "graph/graphviz_dot not installed, graphing functionality not included."
 end
 
+class Object
+  def save_to_file(filename)
+    File.open(filename, 'w+') { |f| f << Marshal.dump(self) }
+  end
+
+  def self.load_from_file(filename)
+    Marshal.load( File.read( filename ) )
+  end
+end
+
 class Array
-  def classification; collect { |v| v.last }; end
-
-  # calculate Information entropy
-  def entropy
-    return 0 if empty?
-
-    info = {}
-    total = 0
-    each {|i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1}
-
-    result = 0
-    info.each do |symbol, count|
-      result += -count.to_f/total*Math.log(count.to_f/total)/Math.log(2.0) if (count > 0)
-    end
-    result
+  def classification; collect { |v| v.last }; end
+
+  # calculate information entropy
+  def entropy
+    return 0 if empty?
+
+    info = {}
+    total = 0
+    each {|i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1}
+
+    result = 0
+    info.each do |symbol, count|
+      result += -count.to_f/total*Math.log(count.to_f/total)/Math.log(2.0) if (count > 0)
     end
+    result
+  end
 end
 
 module DecisionTree
+  Node = Struct.new(:attribute, :threshold, :gain)
+
   class ID3Tree
-    Node = Struct.new(:attribute, :threshold, :gain)
     def initialize(attributes, data, default, type)
       @used, @tree, @type = {}, {}, type
       @data, @attributes, @default = data, attributes, default
     end
-
+
     def train(data=@data, attributes=@attributes, default=@default)
+      initialize(attributes, data, default, @type)
+
+      # Remove samples with same attributes leaving most common classification
+      data2 = data.inject({}) {|hash, d| hash[d.slice(0..-2)] ||= Hash.new(0); hash[d.slice(0..-2)][d.last] += 1; hash }.map{|key,val| key + [val.sort_by{ |k, v| v }.last.first]}
+
+      @tree = id3_train(data2, attributes, default)
+    end
+
+    def id3_train(data, attributes, default, used={})
       # Choose a fitness algorithm
       case @type
       when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
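
The `data2` reduction inside `train` above is what lets ID3Tree handle inconsistent datasets: rows with identical attribute values but conflicting labels are collapsed into a single row carrying the majority label before `id3_train` runs. A sketch of the same reduction with the inject one-liner unrolled for readability (variable names are mine):

    # Count label occurrences per distinct attribute vector, then keep the
    # most frequent label for each vector (ties resolved by sort order).
    counts = Hash.new { |h, k| h[k] = Hash.new(0) }
    data.each { |row| counts[row[0..-2]][row.last] += 1 }
    data2 = counts.map { |attrs, labels| attrs + [labels.sort_by { |_, n| n }.last.first] }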
@@ -47,43 +67,49 @@ module DecisionTree
 
       # return classification if all examples have the same classification
       return data.first.last if data.classification.uniq.size == 1
-
+
       # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
       performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) }
       max = performance.max { |a,b| a[0] <=> b[0] }
       best = Node.new(attributes[performance.index(max)], max[1], max[0])
+      best.threshold = nil if @type == :discrete
       @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
       tree, l = {best => {}}, ['>=', '<']
 
       case @type
       when :continuous
-        data.partition { |d| d[attributes.index(best.attribute)] > best.threshold }.each_with_index { |examples, i|
-          tree[best][String.new(l[i])] = train(examples, attributes, (data.classification.mode rescue 0), &fitness)
+        data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
+          tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
         }
       when :discrete
         values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
         partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } }
         partitions.each_with_index { |examples, i|
-          tree[best][values[i]] = train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
+          tree[best][values[i]] = id3_train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
         }
       end
-
-      @tree = tree
+
+      tree
     end
 
     # ID3 for binary classification of continuous variables (e.g. healthy / sick based on temperature thresholds)
    def id3_continuous(data, attributes, attribute)
      values, thresholds = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort, []
+      return [-1, -1] if values.size == 1
      values.each_index { |i| thresholds.push((values[i]+(values[i+1].nil? ? values[i] : values[i+1])).to_f / 2) }
-      thresholds -= @used[attribute] if @used.has_key? attribute
+      thresholds.pop
+      #thresholds -= used[attribute] if used.has_key? attribute
 
       gain = thresholds.collect { |threshold|
-        sp = data.partition { |d| d[attributes.index(attribute)] > threshold }
+        sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
         pos = (sp[0].size).to_f / data.size
         neg = (sp[1].size).to_f / data.size
 
         [data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold]
       }.max { |a,b| a[0] <=> b[0] }
+
+      return [-1, -1] if gain.size == 0
+      gain
     end
 
     # ID3 for discrete label cases
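
In `id3_continuous`, candidate split thresholds are the midpoints between consecutive sorted values; the new `thresholds.pop` discards the degenerate self-midpoint produced when `values[i+1]` is nil, and the switch to `>=` makes the partition match the `['>=', '<']` branch labels used in `id3_train`. A worked illustration with hypothetical values:

    values = [98.6, 99.1, 102.5]   # sorted unique attribute values
    # midpoints pushed: (98.6+99.1)/2, (99.1+102.5)/2, (102.5+102.5)/2
    # => [98.85, 100.8, 102.5]; thresholds.pop then removes the trailing 102.5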
@@ -96,7 +122,7 @@ module DecisionTree
     end
 
     def predict(test)
-      @type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test)
+      return (@type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test)), 1
     end
 
     def graph(filename)
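
Note the public API change here: `ID3Tree#predict` now returns a `[prediction, confidence]` pair, in line with `Ruleset#predict` and `Bagging#predict` below, with the confidence fixed at 1 for the bare tree. A caller upgrading from 0.2.0 would destructure the result (hypothetical call):

    prediction, confidence = tree.predict([101.0])  # confidence is always 1 for ID3Tree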
@@ -104,6 +130,30 @@ module DecisionTree
       dgp.write_to_file("#{filename}.png", "png")
     end
 
+    def ruleset
+      rs = Ruleset.new(@attributes, @data, @default, @type)
+      rs.rules = build_rules
+      rs
+    end
+
+    def build_rules(tree=@tree)
+      attr = tree.to_a.first
+      cases = attr[1].to_a
+      rules = []
+      cases.each do |c,child|
+        if child.is_a?(Hash) then
+          build_rules(child).each do |r|
+            r2 = r.clone
+            r2.premises.unshift([attr.first, c])
+            rules << r2
+          end
+        else
+          rules << Rule.new(@attributes, [[attr.first, c]], child)
+        end
+      end
+      rules
+    end
+
     private
     def descend_continuous(tree, test)
       attr = tree.to_a.first
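
`build_rules` is a standard tree-to-rules conversion: it walks the nested `{Node => {branch_label => subtree_or_leaf}}` hashes, and each root-to-leaf path becomes one Rule whose premises are the `[node, branch_label]` pairs along that path. For a hypothetical one-split continuous tree:

    # tree: Temperature split at threshold 100.8
    #   '>=' => 'sick',  '<' => 'healthy'
    # build_rules yields two Rule objects, one per leaf:
    #   Temperature >= 100.8 => sick
    #   Temperature <  100.8 => healthy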
@@ -134,7 +184,7 @@ module DecisionTree
         child_text = "#{child.attribute}\n(#{child.object_id})"
       else
         child = attr[1][key]
-        child_text = "#{child}\n(#{child.to_s.object_id})"
+        child_text = "#{child}\n(#{child.to_s.clone.object_id})"
       end
       label_text = "#{key} #{@type == :continuous ? attr[0].threshold : ""}"
 
@@ -145,4 +195,128 @@ module DecisionTree
       return links
     end
   end
+
+  class Rule
+    attr_accessor :premises
+    attr_accessor :conclusion
+    attr_accessor :attributes
+
+    def initialize(attributes,premises=[],conclusion=nil)
+      @attributes, @premises, @conclusion = attributes, premises, conclusion
+    end
+
+    def to_s
+      str = ''
+      @premises.each do |p|
+        str += "#{p.first.attribute} #{p.last} #{p.first.threshold}" if p.first.threshold
+        str += "#{p.first.attribute} = #{p.last}" if !p.first.threshold
+        str += "\n"
+      end
+      str += "=> #{@conclusion} (#{accuracy})"
+    end
+
+    def predict(test)
+      verifies = true;
+      @premises.each do |p|
+        if p.first.threshold then # Continuous
+          if !(p.last == '>=' && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == '<' && test[@attributes.index(p.first.attribute)] < p.first.threshold) then
+            verifies = false; break
+          end
+        else # Discrete
+          if test[@attributes.index(p.first.attribute)] != p.last then
+            verifies = false; break
+          end
+        end
+      end
+      return @conclusion if verifies
+      return nil
+    end
+
+    def get_accuracy(data)
+      correct = 0; total = 0
+      data.each do |d|
+        prediction = predict(d)
+        correct += 1 if d.last == prediction
+        total += 1 if !prediction.nil?
+      end
+      (correct.to_f + 1) / (total.to_f + 2)
+    end
+
+    def accuracy(data=nil)
+      data.nil? ? @accuracy : @accuracy = get_accuracy(data)
+    end
+  end
+
+  class Ruleset
+    attr_accessor :rules
+
+    def initialize(attributes, data, default, type)
+      @attributes, @default, @type = attributes, default, type
+      mixed_data = data.sort_by {rand}
+      cut = (mixed_data.size.to_f * 0.67).to_i
+      @train_data = mixed_data.slice(0..cut-1)
+      @prune_data = mixed_data.slice(cut..-1)
+    end
+
+    def train(train_data=@train_data, attributes=@attributes, default=@default)
+      dec_tree = DecisionTree::ID3Tree.new(attributes, train_data, default, @type)
+      dec_tree.train
+      @rules = dec_tree.build_rules
+      @rules.each { |r| r.accuracy(train_data) } # Calculate accuracy
+      prune
+    end
+
+    def prune(data=@prune_data)
+      @rules.each do |r|
+        (1..r.premises.size).each do
+          acc1 = r.accuracy(data)
+          p = r.premises.pop
+          if acc1 > r.get_accuracy(data) then
+            r.premises.push(p); break
+          end
+        end
+      end
+      @rules = @rules.sort_by{|r| -r.accuracy(data)}
+    end
+
+    def to_s
+      str = ''; @rules.each { |rule| str += "#{rule}\n\n" }
+      str
+    end
+
+    def predict(test)
+      @rules.each do |r|
+        prediction = r.predict(test)
+        return prediction, r.accuracy unless prediction.nil?
+      end
+      return @default, 0.0
+    end
+  end
+
+  class Bagging
+    attr_accessor :classifiers
+    def initialize(attributes, data, default, type)
+      @classifiers, @type = [], type
+      @data, @attributes, @default = data, attributes, default
+    end
+
+    def train(data=@data, attributes=@attributes, default=@default)
+      @classifiers = []
+      10.times { @classifiers << Ruleset.new(attributes, data, default, @type) }
+      @classifiers.each do |c|
+        c.train(data, attributes, default)
+      end
+    end
+
+    def predict(test)
+      predictions = Hash.new(0)
+      @classifiers.each do |c|
+        p, accuracy = c.predict(test)
+        predictions[p] += accuracy unless p.nil?
+      end
+      return @default, 0.0 if predictions.empty?
+      winner = predictions.sort_by {|k,v| -v}.first
+      return winner[0], winner[1].to_f / @classifiers.size.to_f
+    end
+  end
 end
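
Two details in the classes above are easy to miss. `Rule#get_accuracy` applies a Laplace-style correction, (correct + 1) / (total + 2), so a rule covering few samples scores lower than an equally accurate rule with more support; and `Ruleset#prune` repeatedly drops the last premise of each rule (in the C4.5 spirit) as long as the corrected accuracy on the held-out prune data does not drop. A quick arithmetic check of the correction:

    (8 + 1).to_f / (9 + 2)  # => ~0.818 for a rule right on 8 of 9 covered rows
    (1 + 1).to_f / (1 + 2)  # => ~0.667 for a rule right on its only covered row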
@@ -1,7 +1,7 @@
 module DecisionTree #:nodoc:
   module VERSION #:nodoc:
     MAJOR = 0
-    MINOR = 2
+    MINOR = 3
     TINY = 0
 
     STRING = [MAJOR, MINOR, TINY].join('.')
metadata CHANGED
@@ -3,8 +3,8 @@ rubygems_version: 0.9.2
 specification_version: 1
 name: decisiontree
 version: !ruby/object:Gem::Version
-  version: 0.2.0
-date: 2007-07-07 00:00:00 -04:00
+  version: 0.3.0
+date: 2007-09-15 00:00:00 -04:00
 summary: ID3-based implementation of the M.L. Decision Tree algorithm
 require_paths:
 - lib