decisiontree 0.3.0 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/.gitignore ADDED
@@ -0,0 +1,17 @@
+ *.gem
+ *.rbc
+ .bundle
+ .config
+ .yardoc
+ Gemfile.lock
+ InstalledFiles
+ _yardoc
+ coverage
+ doc/
+ lib/bundler/man
+ pkg
+ rdoc
+ spec/reports
+ test/tmp
+ test/version_tmp
+ tmp
data/Gemfile ADDED
@@ -0,0 +1,4 @@
+ source 'https://rubygems.org'
+
+ # Specify your gem's dependencies in decisiontree.gemspec
+ gemspec
data/README.md ADDED
@@ -0,0 +1,49 @@
+ # Decision Tree
+
+ A Ruby library which implements the ID3 (information gain) algorithm for decision tree learning. Currently, both continuous and discrete datasets can be learned.
+
+ - The discrete model assumes unique labels, and the learned tree can be graphed and converted into a PNG for visual analysis.
+ - The continuous model looks at all possible values for a variable and iteratively chooses the best threshold between all possible assignments. This results in a binary tree which is partitioned by the threshold at every step (e.g. temperature > 20C).
+
+ ## Features
+ - ID3 algorithms for continuous and discrete cases, with support for inconsistent datasets.
+ - Graphviz component to visualize the learned tree (http://rockit.sourceforge.net/subprojects/graphr/)
+ - Support for multiple and symbolic outputs, and graphing of continuous trees.
+ - Returns a default value when no branch is suitable for the input.
+
+ ## Implementation
+
+ - Ruleset is a class that trains an ID3Tree with 2/3 of the training data, converts it into a set of rules, and prunes the rules with the remaining 1/3 of the training data (as in C4.5).
+ - Bagging is a bagging-based trainer that trains 10 Ruleset trainers and, when predicting, chooses the best output based on voting.
+
+ Blog post with explanation & examples: http://www.igvita.com/2007/04/16/decision-tree-learning-in-ruby/
+
+ ## Example
+
+ ```ruby
+ require 'decisiontree'
+
+ attributes = ['Temperature']
+ training = [
+   [36.6, 'healthy'],
+   [37, 'sick'],
+   [38, 'sick'],
+   [36.7, 'healthy'],
+   [40, 'sick'],
+   [50, 'really sick'],
+ ]
+
+ # Instantiate the tree and train it on the data (default value: 'sick')
+ dec_tree = DecisionTree::ID3Tree.new(attributes, training, 'sick', :continuous)
+ dec_tree.train
+
+ test = [37, 'sick']
+ decision = dec_tree.predict(test)
+ puts "Predicted: #{decision} ... True decision: #{test.last}"
+
+ # => Predicted: sick ... True decision: sick
+ ```
+
+ ## License
+
+ The MIT License - Copyright (c) 2006 Ilya Grigorik
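
The Implementation notes above introduce Ruleset and Bagging without usage. A minimal sketch, assuming both classes mirror ID3Tree's `(attributes, data, default, type)` constructor and that `predict` returns a `[prediction, confidence]` pair (as the `Ruleset#predict` hunk further down in this diff suggests):

```ruby
require 'decisiontree'

attributes = ['Temperature']
training = [
  [36.6, 'healthy'], [37, 'sick'], [38, 'sick'],
  [36.7, 'healthy'], [40, 'sick'], [50, 'really sick'],
]

# Ruleset: trains an ID3Tree on 2/3 of the data, prunes rules on the rest.
ruleset = DecisionTree::Ruleset.new(attributes, training, 'sick', :continuous)
ruleset.train
prediction, confidence = ruleset.predict([37])

# Bagging: trains 10 Rulesets and picks the winning output by voting.
bagging = DecisionTree::Bagging.new(attributes, training, 'sick', :continuous)
bagging.train
prediction, confidence = bagging.predict([37])
```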
data/Rakefile CHANGED
@@ -1,123 +1,7 @@
- require 'rubygems'
- require 'rake'
- require 'rake/clean'
- require 'rake/testtask'
- require 'rake/packagetask'
- require 'rake/gempackagetask'
- require 'rake/rdoctask'
- require 'rake/contrib/rubyforgepublisher'
- require 'fileutils'
- require 'hoe'
-
- include FileUtils
- require File.join(File.dirname(__FILE__), 'lib', 'decisiontree', 'version')
-
- AUTHOR = 'Ilya Grigorik' # can also be an array of Authors
- EMAIL = "ilya <at> fortehost.com"
- DESCRIPTION = "ID3-based implementation of the M.L. Decision Tree algorithm"
- GEM_NAME = 'decisiontree' # what ppl will type to install your gem
-
- @config_file = "~/.rubyforge/user-config.yml"
- @config = nil
- def rubyforge_username
-   unless @config
-     begin
-       @config = YAML.load(File.read(File.expand_path(@config_file)))
-     rescue
-       puts <<-EOS
- ERROR: No rubyforge config file found: #{@config_file}"
- Run 'rubyforge setup' to prepare your env for access to Rubyforge
-  - See http://newgem.rubyforge.org/rubyforge.html for more details
-       EOS
-       exit
-     end
-   end
-   @rubyforge_username ||= @config["username"]
- end
-
- RUBYFORGE_PROJECT = 'decisiontree' # The unix name for your project
- HOMEPATH = "http://#{RUBYFORGE_PROJECT}.rubyforge.org"
- DOWNLOAD_PATH = "http://rubyforge.org/projects/#{RUBYFORGE_PROJECT}"
-
- NAME = "decisiontree"
- REV = nil
- # UNCOMMENT IF REQUIRED:
- # REV = `svn info`.each {|line| if line =~ /^Revision:/ then k,v = line.split(': '); break v.chomp; else next; end} rescue nil
- VERS = DecisionTree::VERSION::STRING + (REV ? ".#{REV}" : "")
- CLEAN.include ['**/.*.sw?', '*.gem', '.config', '**/.DS_Store']
- RDOC_OPTS = ['--quiet', '--title', 'decisiontree documentation',
-   "--opname", "index.html",
-   "--line-numbers",
-   "--main", "README",
-   "--inline-source"]
-
- class Hoe
-   def extra_deps
-     @extra_deps.reject { |x| Array(x).first == 'hoe' }
-   end
- end
-
- # Generate all the Rake tasks
- # Run 'rake -T' to see list of generated tasks (from gem root directory)
- hoe = Hoe.new(GEM_NAME, VERS) do |p|
-   p.author = AUTHOR
-   p.description = DESCRIPTION
-   p.email = EMAIL
-   p.summary = DESCRIPTION
-   p.url = HOMEPATH
-   p.rubyforge_name = RUBYFORGE_PROJECT if RUBYFORGE_PROJECT
-   p.test_globs = ["test/**/test_*.rb"]
-   p.clean_globs |= CLEAN #An array of file patterns to delete on clean.
-
-   # == Optional
-   p.changes = p.paragraphs_of("History.txt", 0..1).join("\n\n")
-   #p.extra_deps = [] # An array of rubygem dependencies [name, version], e.g. [ ['active_support', '>= 1.3.1'] ]
-   #p.spec_extras = {} # A hash of extra values to set in the gemspec.
- end
-
- CHANGES = hoe.paragraphs_of('History.txt', 0..1).join("\n\n")
- PATH = (RUBYFORGE_PROJECT == GEM_NAME) ? RUBYFORGE_PROJECT : "#{RUBYFORGE_PROJECT}/#{GEM_NAME}"
- hoe.remote_rdoc_dir = File.join(PATH.gsub(/^#{RUBYFORGE_PROJECT}\/?/,''), 'rdoc')
-
- desc 'Generate website files'
- task :website_generate do
-   Dir['website/**/*.txt'].each do |txt|
-     sh %{ ruby scripts/txt2html #{txt} > #{txt.gsub(/txt$/,'html')} }
-   end
- end
-
- desc 'Upload website files to rubyforge'
- task :website_upload do
-   host = "#{rubyforge_username}@rubyforge.org"
-   remote_dir = "/var/www/gforge-projects/#{PATH}/"
-   local_dir = 'website'
-   sh %{rsync -aCv #{local_dir}/ #{host}:#{remote_dir}}
- end
-
- desc 'Generate and upload website files'
- task :website => [:website_generate, :website_upload, :publish_docs]
-
- desc 'Release the website and new gem version'
- task :deploy => [:check_version, :website, :release] do
-   puts "Remember to create SVN tag:"
-   puts "svn copy svn+ssh://#{rubyforge_username}@rubyforge.org/var/svn/#{PATH}/trunk " +
-     "svn+ssh://#{rubyforge_username}@rubyforge.org/var/svn/#{PATH}/tags/REL-#{VERS} "
-   puts "Suggested comment:"
-   puts "Tagging release #{CHANGES}"
- end
-
- desc 'Runs tasks website_generate and install_gem as a local deployment of the gem'
- task :local_deploy => [:website_generate, :install_gem]
-
- task :check_version do
-   unless ENV['VERSION']
-     puts 'Must pass a VERSION=x.y.z release version'
-     exit
-   end
-   unless ENV['VERSION'] == VERS
-     puts "Please update your version.rb to match the release version, currently #{VERS}"
-     exit
-   end
- end
+ require 'bundler'
+ Bundler::GemHelper.install_tasks
 
+ require 'rspec/core/rake_task'
+ RSpec::Core::RakeTask.new
 
+ task :default => :spec
data/decisiontree.gemspec ADDED
@@ -0,0 +1,25 @@
+ # -*- encoding: utf-8 -*-
+ $:.push File.expand_path("../lib", __FILE__)
+
+ Gem::Specification.new do |s|
+   s.name        = "decisiontree"
+   s.version     = "0.4.0"
+   s.platform    = Gem::Platform::RUBY
+   s.authors     = ["Ilya Grigorik"]
+   s.email       = ["ilya@igvita.com"]
+   s.homepage    = "https://github.com/igrigorik/decisiontree"
+   s.summary     = %q{ID3-based implementation of the M.L. Decision Tree algorithm}
+   s.description = s.summary
+
+   s.rubyforge_project = "decisiontree"
+
+   s.add_development_dependency "graphr"
+   s.add_development_dependency "rspec"
+   s.add_development_dependency "rspec-given"
+   s.add_development_dependency "pry"
+
+   s.files         = `git ls-files`.split("\n")
+   s.test_files    = `git ls-files -- {test,spec,features}/*`.split("\n")
+   s.executables   = `git ls-files -- bin/*`.split("\n").map{ |f| File.basename(f) }
+   s.require_paths = ["lib"]
+ end
data/lib/decisiontree.rb CHANGED
@@ -1 +1 @@
- Dir[File.join(File.dirname(__FILE__), 'decisiontree/**/*.rb')].sort.each { |lib| require lib }
+ require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
data/lib/decisiontree/id3_tree.rb CHANGED
@@ -1,13 +1,9 @@
- #The MIT License
-
- ###Copyright (c) 2007 Ilya Grigorik <ilya AT fortehost DOT com>
- ###Modifed at 2007 by José Ignacio Fernández <joseignacio.fernandez AT gmail DOT com>
-
- begin;
- require 'graph/graphviz_dot'
- rescue LoadError
- STDERR.puts "graph/graphviz_dot not installed, graphing functionality not included."
- end
+ # The MIT License
+ #
+ ### Copyright (c) 2007 Ilya Grigorik <ilya AT igvita DOT com>
+ ### Modified in 2007 by José Ignacio Fernández <joseignacio.fernandez AT gmail DOT com>
+
+ require 'graphr'
 
  class Object
    def save_to_file(filename)
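
Graphing is now backed by the graphr gem: the optional `graph/graphviz_dot` require becomes a hard `require 'graphr'`, and graphr appears as a development dependency in the new gemspec. A short sketch of rendering a trained tree, assuming Graphviz is installed on the system since DotGraphPrinter delegates to it:

```ruby
require 'decisiontree'

training = [[36.6, 'healthy'], [37, 'sick'], [40, 'sick']]
tree = DecisionTree::ID3Tree.new(['Temperature'], training, 'sick', :continuous)
tree.train
tree.graph('tree') # writes tree.png, per the graph method shown below
```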
@@ -19,9 +15,9 @@ class Object
    end
  end
 
- class Array
-   def classification; collect { |v| v.last }; end
-
+ class Array
+   def classification; collect { |v| v.last }; end
+
    # calculate information entropy
    def entropy
      return 0 if empty?
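
The `Array#entropy` helper in the context above computes Shannon entropy over the classification column, H(S) = -Σ pᵢ log₂ pᵢ, which both fitness functions build on. A standalone sketch of the same quantity (variable names here are illustrative, not the library's):

```ruby
# Shannon entropy over class labels: H = -sum(p_i * log2(p_i))
labels = ['sick', 'sick', 'sick', 'healthy']
probabilities = labels.group_by(&:itself).values.map { |g| g.size.to_f / labels.size }
entropy = -probabilities.sum { |p| p * Math.log2(p) }
# => ~0.811 bits; a 50/50 split gives 1.0, a pure set gives 0.0
```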
@@ -55,28 +51,34 @@ module DecisionTree
 
      @tree = id3_train(data2, attributes, default)
    end
-
-   def id3_train(data, attributes, default, used={})
-     # Choose a fitness algorithm
-     case @type
-     when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
+
+   def type(attribute)
+     @type.is_a?(Hash) ? @type[attribute.to_sym] : @type
+   end
+
+   def fitness_for(attribute)
+     case type(attribute)
+     when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
      when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)}
      end
-
-     return default if data.empty?
+   end
+
+   def id3_train(data, attributes, default, used={})
+     return default if data.empty?
 
      # return classification if all examples have the same classification
      return data.first.last if data.classification.uniq.size == 1
 
      # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
-     performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) }
+     performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
      max = performance.max { |a,b| a[0] <=> b[0] }
      best = Node.new(attributes[performance.index(max)], max[1], max[0])
      best.threshold = nil if @type == :discrete
-     @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
+     @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
      tree, l = {best => {}}, ['>=', '<']
-
-     case @type
+
+     fitness = fitness_for(best.attribute)
+     case type(best.attribute)
      when :continuous
        data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
          tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
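
The new `type(attribute)` helper lets `@type` be a per-attribute Hash instead of a single symbol, so one tree can mix discrete and continuous columns. Condensed from the spec added later in this diff:

```ruby
labels = ['hunger', 'color']
data = [
  [8, 'red',  'angry'],
  [7, 'blue', 'not angry'],
  [2, 'red',  'not angry'],
  [3, 'blue', 'not angry'],
]

# Hash keys are attribute names as symbols; values select the fitness function.
tree = DecisionTree::ID3Tree.new(labels, data, 'not angry', color: :discrete, hunger: :continuous)
tree.train
tree.predict([7, 'red'])
```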
@@ -86,7 +88,7 @@ module DecisionTree
        partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } }
        partitions.each_with_index { |examples, i|
          tree[best][values[i]] = id3_train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
-       }
+       }
      end
 
      tree
@@ -100,32 +102,32 @@ module DecisionTree
      thresholds.pop
      #thresholds -= used[attribute] if used.has_key? attribute
 
-     gain = thresholds.collect { |threshold|
+     gain = thresholds.collect { |threshold|
        sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
        pos = (sp[0].size).to_f / data.size
        neg = (sp[1].size).to_f / data.size
-
+
        [data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold]
      }.max { |a,b| a[0] <=> b[0] }
 
      return [-1, -1] if gain.size == 0
      gain
    end
-
+
    # ID3 for discrete label cases
    def id3_discrete(data, attributes, attribute)
      values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
      partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
      remainder = partitions.collect {|p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) {|i,s| s+=i }
-
+
      [data.classification.entropy - remainder, attributes.index(attribute)]
    end
 
    def predict(test)
-     return (@type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test)), 1
+     descend(@tree, test)
    end
 
-   def graph(filename)
+   def graph(filename)
      dgp = DotGraphPrinter.new(build_tree)
      dgp.write_to_file("#{filename}.png", "png")
    end
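
One behavioral change worth noting from the hunk above: `ID3Tree#predict` previously returned a `(value, 1)` pair, and now returns the bare classification from the unified `descend`. Callers that destructured the old pair need a small update:

```ruby
require 'decisiontree'

tree = DecisionTree::ID3Tree.new(['Temperature'], [[36.6, 'healthy'], [38, 'sick']], 'sick', :continuous)
tree.train

# 0.3.0: decision, confidence = tree.predict([38])
# 0.4.0:
decision = tree.predict([38])
```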
@@ -155,22 +157,20 @@ module DecisionTree
    end
 
    private
-   def descend_continuous(tree, test)
+   def descend(tree, test)
      attr = tree.to_a.first
      return @default if !attr
-     return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
-     return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
-     return descend_continuous(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
-     return descend_continuous(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
-   end
-
-   def descend_discrete(tree, test)
-     attr = tree.to_a.first
-     return @default if !attr
-     return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
-     return descend_discrete(attr[1][test[@attributes.index(attr[0].attribute)]],test)
+     if type(attr.first.attribute) == :continuous
+       return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+       return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+       return descend(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+       return descend(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+     else
+       return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
+       return descend(attr[1][test[@attributes.index(attr[0].attribute)]],test)
+     end
    end
-
+
    def build_tree(tree = @tree)
      return [] unless tree.is_a?(Hash)
      return [["Always", @default]] if tree.empty?
@@ -186,7 +186,7 @@ module DecisionTree
        child = attr[1][key]
        child_text = "#{child}\n(#{child.to_s.clone.object_id})"
      end
-     label_text = "#{key} #{@type == :continuous ? attr[0].threshold : ""}"
+     label_text = "#{key} #{type(attr[0].attribute) == :continuous ? attr[0].threshold : ""}"
 
      [parent_text, child_text, label_text]
    end
@@ -286,7 +286,7 @@ module DecisionTree
 
    def predict(test)
      @rules.each do |r|
-       prediction = r.predict(test)
+       prediction = r.predict(test)
        return prediction, r.accuracy unless prediction.nil?
      end
      return @default, 0.0
data/spec/id3_tree_spec.rb ADDED
@@ -0,0 +1,77 @@
+ require 'spec_helper'
+
+ describe DecisionTree::ID3Tree do
+
+   describe "simple discrete case" do
+     Given(:labels) { ["sun", "rain"] }
+     Given(:data) do
+       [
+         [1,0,1],
+         [0,1,0]
+       ]
+     end
+     Given(:tree) { DecisionTree::ID3Tree.new(labels, data, 1, :discrete) }
+     When { tree.train }
+     Then { tree.predict([1,0]).should == 1 }
+     Then { tree.predict([0,1]).should == 0 }
+   end
+
+   describe "discrete attributes" do
+     Given(:labels) { ["hungry", "color"] }
+     Given(:data) do
+       [
+         ["yes", "red", "angry"],
+         ["no", "blue", "not angry"],
+         ["yes", "blue", "not angry"],
+         ["no", "red", "not angry"]
+       ]
+     end
+     Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", :discrete) }
+     When { tree.train }
+     Then { tree.predict(["yes", "red"]).should == "angry" }
+     Then { tree.predict(["no", "red"]).should == "not angry" }
+   end
+
+   describe "continuous attributes" do
+     Given(:labels) { ["hunger", "happiness"] }
+     Given(:data) do
+       [
+         [8, 7, "angry"],
+         [6, 7, "angry"],
+         [7, 9, "angry"],
+         [7, 1, "not angry"],
+         [2, 9, "not angry"],
+         [3, 2, "not angry"],
+         [2, 3, "not angry"],
+         [1, 4, "not angry"]
+       ]
+     end
+     Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", :continuous) }
+     When { tree.train }
+     Then { tree.graph("continuous") }
+     Then { tree.predict([7, 7]).should == "angry" }
+     Then { tree.predict([2, 3]).should == "not angry" }
+   end
+
+   describe "a mixture" do
+     Given(:labels) { ["hunger", "color"] }
+     Given(:data) do
+       [
+         [8, "red", "angry"],
+         [6, "red", "angry"],
+         [7, "red", "angry"],
+         [7, "blue", "not angry"],
+         [2, "red", "not angry"],
+         [3, "blue", "not angry"],
+         [2, "blue", "not angry"],
+         [1, "red", "not angry"]
+       ]
+     end
+     Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", color: :discrete, hunger: :continuous) }
+     When { tree.train }
+     Then { tree.graph("continuous") }
+     Then { tree.predict([7, "red"]).should == "angry" }
+     Then { tree.predict([2, "blue"]).should == "not angry" }
+   end
+
+ end