decisiontree 0.3.0 → 0.4.0
Sign up to get free protection for your applications and to get access to all the features.
- data/.gitignore +17 -0
- data/Gemfile +4 -0
- data/README.md +48 -0
- data/Rakefile +5 -121
- data/decisiontree.gemspec +25 -0
- data/examples/simple.rb +0 -0
- data/lib/decisiontree.rb +1 -1
- data/lib/decisiontree/id3_tree.rb +46 -46
- data/spec/id3_spec.rb +77 -0
- data/spec/spec_helper.rb +3 -0
- metadata +118 -69
- data/CHANGELOG.txt +0 -17
- data/History.txt +0 -0
- data/Manifest.txt +0 -24
- data/README.txt +0 -15
- data/lib/decisiontree/version.rb +0 -9
- data/scripts/txt2html +0 -67
- data/setup.rb +0 -1585
- data/test/test_decisiontree.rb +0 -26
- data/test/test_helper.rb +0 -2
- data/website/index.html +0 -11
- data/website/index.txt +0 -38
- data/website/javascripts/rounded_corners_lite.inc.js +0 -285
- data/website/stylesheets/screen.css +0 -138
- data/website/template.rhtml +0 -48
data/.gitignore
ADDED
data/Gemfile
ADDED
data/README.md
ADDED
@@ -0,0 +1,48 @@
|
|
1
|
+
# Decision Tree
|
2
|
+
|
3
|
+
A Ruby library that implements the ID3 (information gain) algorithm for decision tree learning. Both continuous and discrete datasets can currently be learned.
|
4
|
+
|
5
|
+
- The discrete model assumes unique labels and can be graphed and converted into a PNG for visual analysis.
|
6
|
+
- The continuous model looks at all possible values of a variable and iteratively chooses the best threshold among all possible assignments. This results in a binary tree that is partitioned by the threshold at every step (e.g. temperature > 20°C).
|
7
|
+
|
8
|
+
## Features
|
9
|
+
- ID3 algorithms for continuous and discrete cases, with support for inconsistent datasets.
|
10
|
+
- Graphviz component to visualize the learned tree (http://rockit.sourceforge.net/subprojects/graphr/)
|
11
|
+
- Support for multiple and symbolic outputs, and graphing of continuous trees.
|
12
|
+
- Returns the default value when no branch is suitable for the input.
|
13
|
+
|
14
|
+
## Implementation
|
15
|
+
|
16
|
+
- Ruleset is a class that trains an ID3Tree with 2/3 of the training data, converts it into a set of rules, and prunes the rules with the remaining 1/3 of the training data (in the style of C4.5).
|
17
|
+
- Bagging is a bagging-based trainer that trains 10 Ruleset trainers and, when predicting, chooses the best output by voting.
|
18
|
+
|
19
|
+
Blog post with explanation & examples: http://www.igvita.com/2007/04/16/decision-tree-learning-in-ruby/
|
20
|
+
|
21
|
+
## Example
|
22
|
+
|
23
|
+
```ruby
|
24
|
+
require 'decisiontree'
|
25
|
+
|
26
|
+
attributes = ['Temperature']
|
27
|
+
training = [
|
28
|
+
[36.6, 'healthy'],
|
29
|
+
[37, 'sick'],
|
30
|
+
[38, 'sick'],
|
31
|
+
[36.7, 'healthy'],
|
32
|
+
[40, 'sick'],
|
33
|
+
[50, 'really sick'],
|
34
|
+
]
|
35
|
+
|
36
|
+
# Instantiate the tree and train it on the data (default prediction is 'sick')
|
37
|
+
dec_tree = DecisionTree::ID3Tree.new(attributes, training, 'sick', :continuous)
|
38
|
+
dec_tree.train
|
39
|
+
|
40
|
+
decision = dec_tree.predict([37, 'sick'])
|
41
|
+
puts "Predicted: #{decision} ... True decision: #{test.last}";
|
42
|
+
|
43
|
+
# => Predicted: sick ... True decision: sick
|
44
|
+
```
|
45
|
+
|
46
|
+
## License
|
47
|
+
|
48
|
+
The MIT License - Copyright (c) 2006 Ilya Grigorik
|
data/Rakefile
CHANGED
@@ -1,123 +1,7 @@
|
|
1
|
-
require '
|
2
|
-
|
3
|
-
require 'rake/clean'
|
4
|
-
require 'rake/testtask'
|
5
|
-
require 'rake/packagetask'
|
6
|
-
require 'rake/gempackagetask'
|
7
|
-
require 'rake/rdoctask'
|
8
|
-
require 'rake/contrib/rubyforgepublisher'
|
9
|
-
require 'fileutils'
|
10
|
-
require 'hoe'
|
11
|
-
|
12
|
-
include FileUtils
|
13
|
-
require File.join(File.dirname(__FILE__), 'lib', 'decisiontree', 'version')
|
14
|
-
|
15
|
-
AUTHOR = 'Ilya Grigorik' # can also be an array of Authors
|
16
|
-
EMAIL = "ilya <at> fortehost.com"
|
17
|
-
DESCRIPTION = "ID3-based implementation of the M.L. Decision Tree algorithm"
|
18
|
-
GEM_NAME = 'decisiontree' # what ppl will type to install your gem
|
19
|
-
|
20
|
-
@config_file = "~/.rubyforge/user-config.yml"
|
21
|
-
@config = nil
|
22
|
-
def rubyforge_username
|
23
|
-
unless @config
|
24
|
-
begin
|
25
|
-
@config = YAML.load(File.read(File.expand_path(@config_file)))
|
26
|
-
rescue
|
27
|
-
puts <<-EOS
|
28
|
-
ERROR: No rubyforge config file found: #{@config_file}"
|
29
|
-
Run 'rubyforge setup' to prepare your env for access to Rubyforge
|
30
|
-
- See http://newgem.rubyforge.org/rubyforge.html for more details
|
31
|
-
EOS
|
32
|
-
exit
|
33
|
-
end
|
34
|
-
end
|
35
|
-
@rubyforge_username ||= @config["username"]
|
36
|
-
end
|
37
|
-
|
38
|
-
RUBYFORGE_PROJECT = 'decisiontree' # The unix name for your project
|
39
|
-
HOMEPATH = "http://#{RUBYFORGE_PROJECT}.rubyforge.org"
|
40
|
-
DOWNLOAD_PATH = "http://rubyforge.org/projects/#{RUBYFORGE_PROJECT}"
|
41
|
-
|
42
|
-
NAME = "decisiontree"
|
43
|
-
REV = nil
|
44
|
-
# UNCOMMENT IF REQUIRED:
|
45
|
-
# REV = `svn info`.each {|line| if line =~ /^Revision:/ then k,v = line.split(': '); break v.chomp; else next; end} rescue nil
|
46
|
-
VERS = DecisionTree::VERSION::STRING + (REV ? ".#{REV}" : "")
|
47
|
-
CLEAN.include ['**/.*.sw?', '*.gem', '.config', '**/.DS_Store']
|
48
|
-
RDOC_OPTS = ['--quiet', '--title', 'decisiontree documentation',
|
49
|
-
"--opname", "index.html",
|
50
|
-
"--line-numbers",
|
51
|
-
"--main", "README",
|
52
|
-
"--inline-source"]
|
53
|
-
|
54
|
-
class Hoe
|
55
|
-
def extra_deps
|
56
|
-
@extra_deps.reject { |x| Array(x).first == 'hoe' }
|
57
|
-
end
|
58
|
-
end
|
59
|
-
|
60
|
-
# Generate all the Rake tasks
|
61
|
-
# Run 'rake -T' to see list of generated tasks (from gem root directory)
|
62
|
-
hoe = Hoe.new(GEM_NAME, VERS) do |p|
|
63
|
-
p.author = AUTHOR
|
64
|
-
p.description = DESCRIPTION
|
65
|
-
p.email = EMAIL
|
66
|
-
p.summary = DESCRIPTION
|
67
|
-
p.url = HOMEPATH
|
68
|
-
p.rubyforge_name = RUBYFORGE_PROJECT if RUBYFORGE_PROJECT
|
69
|
-
p.test_globs = ["test/**/test_*.rb"]
|
70
|
-
p.clean_globs |= CLEAN #An array of file patterns to delete on clean.
|
71
|
-
|
72
|
-
# == Optional
|
73
|
-
p.changes = p.paragraphs_of("History.txt", 0..1).join("\n\n")
|
74
|
-
#p.extra_deps = [] # An array of rubygem dependencies [name, version], e.g. [ ['active_support', '>= 1.3.1'] ]
|
75
|
-
#p.spec_extras = {} # A hash of extra values to set in the gemspec.
|
76
|
-
end
|
77
|
-
|
78
|
-
CHANGES = hoe.paragraphs_of('History.txt', 0..1).join("\n\n")
|
79
|
-
PATH = (RUBYFORGE_PROJECT == GEM_NAME) ? RUBYFORGE_PROJECT : "#{RUBYFORGE_PROJECT}/#{GEM_NAME}"
|
80
|
-
hoe.remote_rdoc_dir = File.join(PATH.gsub(/^#{RUBYFORGE_PROJECT}\/?/,''), 'rdoc')
|
81
|
-
|
82
|
-
desc 'Generate website files'
|
83
|
-
task :website_generate do
|
84
|
-
Dir['website/**/*.txt'].each do |txt|
|
85
|
-
sh %{ ruby scripts/txt2html #{txt} > #{txt.gsub(/txt$/,'html')} }
|
86
|
-
end
|
87
|
-
end
|
88
|
-
|
89
|
-
desc 'Upload website files to rubyforge'
|
90
|
-
task :website_upload do
|
91
|
-
host = "#{rubyforge_username}@rubyforge.org"
|
92
|
-
remote_dir = "/var/www/gforge-projects/#{PATH}/"
|
93
|
-
local_dir = 'website'
|
94
|
-
sh %{rsync -aCv #{local_dir}/ #{host}:#{remote_dir}}
|
95
|
-
end
|
96
|
-
|
97
|
-
desc 'Generate and upload website files'
|
98
|
-
task :website => [:website_generate, :website_upload, :publish_docs]
|
99
|
-
|
100
|
-
desc 'Release the website and new gem version'
|
101
|
-
task :deploy => [:check_version, :website, :release] do
|
102
|
-
puts "Remember to create SVN tag:"
|
103
|
-
puts "svn copy svn+ssh://#{rubyforge_username}@rubyforge.org/var/svn/#{PATH}/trunk " +
|
104
|
-
"svn+ssh://#{rubyforge_username}@rubyforge.org/var/svn/#{PATH}/tags/REL-#{VERS} "
|
105
|
-
puts "Suggested comment:"
|
106
|
-
puts "Tagging release #{CHANGES}"
|
107
|
-
end
|
108
|
-
|
109
|
-
desc 'Runs tasks website_generate and install_gem as a local deployment of the gem'
|
110
|
-
task :local_deploy => [:website_generate, :install_gem]
|
111
|
-
|
112
|
-
task :check_version do
|
113
|
-
unless ENV['VERSION']
|
114
|
-
puts 'Must pass a VERSION=x.y.z release version'
|
115
|
-
exit
|
116
|
-
end
|
117
|
-
unless ENV['VERSION'] == VERS
|
118
|
-
puts "Please update your version.rb to match the release version, currently #{VERS}"
|
119
|
-
exit
|
120
|
-
end
|
121
|
-
end
|
1
|
+
require 'bundler'
|
2
|
+
Bundler::GemHelper.install_tasks
|
122
3
|
|
4
|
+
require 'rspec/core/rake_task'
|
5
|
+
RSpec::Core::RakeTask.new
|
123
6
|
|
7
|
+
task :default => :spec
|
@@ -0,0 +1,25 @@
|
|
1
|
+
# -*- encoding: utf-8 -*-
|
2
|
+
$:.push File.expand_path("../lib", __FILE__)
|
3
|
+
|
4
|
+
Gem::Specification.new do |s|
|
5
|
+
s.name = "decisiontree"
|
6
|
+
s.version = "0.4.0"
|
7
|
+
s.platform = Gem::Platform::RUBY
|
8
|
+
s.authors = ["Ilya Grigorik"]
|
9
|
+
s.email = ["ilya@igvita.com"]
|
10
|
+
s.homepage = "https://github.com/igrigorik/decisiontree"
|
11
|
+
s.summary = %q{ID3-based implementation of the M.L. Decision Tree algorithm}
|
12
|
+
s.description = s.summary
|
13
|
+
|
14
|
+
s.rubyforge_project = "decisiontree"
|
15
|
+
|
16
|
+
s.add_development_dependency "graphr"
|
17
|
+
s.add_development_dependency "rspec"
|
18
|
+
s.add_development_dependency "rspec-given"
|
19
|
+
s.add_development_dependency "pry"
|
20
|
+
|
21
|
+
s.files = `git ls-files`.split("\n")
|
22
|
+
s.test_files = `git ls-files -- {test,spec,features}/*`.split("\n")
|
23
|
+
s.executables = `git ls-files -- bin/*`.split("\n").map{ |f| File.basename(f) }
|
24
|
+
s.require_paths = ["lib"]
|
25
|
+
end
|
data/examples/simple.rb
CHANGED
File without changes
|
data/lib/decisiontree.rb
CHANGED
@@ -1 +1 @@
|
|
1
|
-
|
1
|
+
require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
|
@@ -1,13 +1,9 @@
|
|
1
|
-
#The MIT License
|
2
|
-
|
3
|
-
###Copyright (c) 2007 Ilya Grigorik <ilya AT
|
4
|
-
###Modifed at 2007 by José Ignacio Fernández <joseignacio.fernandez AT gmail DOT com>
|
5
|
-
|
6
|
-
|
7
|
-
require 'graph/graphviz_dot'
|
8
|
-
rescue LoadError
|
9
|
-
STDERR.puts "graph/graphviz_dot not installed, graphing functionality not included."
|
10
|
-
end
|
1
|
+
# The MIT License
|
2
|
+
#
|
3
|
+
### Copyright (c) 2007 Ilya Grigorik <ilya AT igvita DOT com>
|
4
|
+
### Modifed at 2007 by José Ignacio Fernández <joseignacio.fernandez AT gmail DOT com>
|
5
|
+
|
6
|
+
require 'graphr'
|
11
7
|
|
12
8
|
class Object
|
13
9
|
def save_to_file(filename)
|
@@ -19,9 +15,9 @@ class Object
|
|
19
15
|
end
|
20
16
|
end
|
21
17
|
|
22
|
-
class Array
|
23
|
-
def classification; collect { |v| v.last }; end
|
24
|
-
|
18
|
+
class Array
|
19
|
+
def classification; collect { |v| v.last }; end
|
20
|
+
|
25
21
|
# calculate information entropy
|
26
22
|
def entropy
|
27
23
|
return 0 if empty?
|
@@ -55,28 +51,34 @@ module DecisionTree
|
|
55
51
|
|
56
52
|
@tree = id3_train(data2, attributes, default)
|
57
53
|
end
|
58
|
-
|
59
|
-
def
|
60
|
-
|
61
|
-
|
62
|
-
|
54
|
+
|
55
|
+
def type(attribute)
|
56
|
+
@type.is_a?(Hash) ? @type[attribute.to_sym] : @type
|
57
|
+
end
|
58
|
+
|
59
|
+
def fitness_for(attribute)
|
60
|
+
case type(attribute)
|
61
|
+
when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
|
63
62
|
when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)}
|
64
63
|
end
|
65
|
-
|
66
|
-
|
64
|
+
end
|
65
|
+
|
66
|
+
def id3_train(data, attributes, default, used={})
|
67
|
+
return default if data.empty?
|
67
68
|
|
68
69
|
# return classification if all examples have the same classification
|
69
70
|
return data.first.last if data.classification.uniq.size == 1
|
70
71
|
|
71
72
|
# Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
|
72
|
-
performance = attributes.collect { |attribute|
|
73
|
+
performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
|
73
74
|
max = performance.max { |a,b| a[0] <=> b[0] }
|
74
75
|
best = Node.new(attributes[performance.index(max)], max[1], max[0])
|
75
76
|
best.threshold = nil if @type == :discrete
|
76
|
-
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
77
|
+
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
77
78
|
tree, l = {best => {}}, ['>=', '<']
|
78
|
-
|
79
|
-
|
79
|
+
|
80
|
+
fitness = fitness_for(best.attribute)
|
81
|
+
case type(best.attribute)
|
80
82
|
when :continuous
|
81
83
|
data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
|
82
84
|
tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
|
@@ -86,7 +88,7 @@ module DecisionTree
|
|
86
88
|
partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } }
|
87
89
|
partitions.each_with_index { |examples, i|
|
88
90
|
tree[best][values[i]] = id3_train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
|
89
|
-
}
|
91
|
+
}
|
90
92
|
end
|
91
93
|
|
92
94
|
tree
|
@@ -100,32 +102,32 @@ module DecisionTree
|
|
100
102
|
thresholds.pop
|
101
103
|
#thresholds -= used[attribute] if used.has_key? attribute
|
102
104
|
|
103
|
-
gain = thresholds.collect { |threshold|
|
105
|
+
gain = thresholds.collect { |threshold|
|
104
106
|
sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
|
105
107
|
pos = (sp[0].size).to_f / data.size
|
106
108
|
neg = (sp[1].size).to_f / data.size
|
107
|
-
|
109
|
+
|
108
110
|
[data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold]
|
109
111
|
}.max { |a,b| a[0] <=> b[0] }
|
110
112
|
|
111
113
|
return [-1, -1] if gain.size == 0
|
112
114
|
gain
|
113
115
|
end
|
114
|
-
|
116
|
+
|
115
117
|
# ID3 for discrete label cases
|
116
118
|
def id3_discrete(data, attributes, attribute)
|
117
119
|
values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
|
118
120
|
partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
|
119
121
|
remainder = partitions.collect {|p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) {|i,s| s+=i }
|
120
|
-
|
122
|
+
|
121
123
|
[data.classification.entropy - remainder, attributes.index(attribute)]
|
122
124
|
end
|
123
125
|
|
124
126
|
def predict(test)
|
125
|
-
|
127
|
+
descend(@tree, test)
|
126
128
|
end
|
127
129
|
|
128
|
-
def graph(filename)
|
130
|
+
def graph(filename)
|
129
131
|
dgp = DotGraphPrinter.new(build_tree)
|
130
132
|
dgp.write_to_file("#{filename}.png", "png")
|
131
133
|
end
|
@@ -155,22 +157,20 @@ module DecisionTree
|
|
155
157
|
end
|
156
158
|
|
157
159
|
private
|
158
|
-
def
|
160
|
+
def descend(tree, test)
|
159
161
|
attr = tree.to_a.first
|
160
162
|
return @default if !attr
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
|
171
|
-
return descend_discrete(attr[1][test[@attributes.index(attr[0].attribute)]],test)
|
163
|
+
if type(attr.first.attribute) == :continuous
|
164
|
+
return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
165
|
+
return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
166
|
+
return descend(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
167
|
+
return descend(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
168
|
+
else
|
169
|
+
return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
|
170
|
+
return descend(attr[1][test[@attributes.index(attr[0].attribute)]],test)
|
171
|
+
end
|
172
172
|
end
|
173
|
-
|
173
|
+
|
174
174
|
def build_tree(tree = @tree)
|
175
175
|
return [] unless tree.is_a?(Hash)
|
176
176
|
return [["Always", @default]] if tree.empty?
|
@@ -186,7 +186,7 @@ module DecisionTree
|
|
186
186
|
child = attr[1][key]
|
187
187
|
child_text = "#{child}\n(#{child.to_s.clone.object_id})"
|
188
188
|
end
|
189
|
-
label_text = "#{key} #{
|
189
|
+
label_text = "#{key} #{type(attr[0].attribute) == :continuous ? attr[0].threshold : ""}"
|
190
190
|
|
191
191
|
[parent_text, child_text, label_text]
|
192
192
|
end
|
@@ -286,7 +286,7 @@ module DecisionTree
|
|
286
286
|
|
287
287
|
def predict(test)
|
288
288
|
@rules.each do |r|
|
289
|
-
prediction = r.predict(test)
|
289
|
+
prediction = r.predict(test)
|
290
290
|
return prediction, r.accuracy unless prediction.nil?
|
291
291
|
end
|
292
292
|
return @default, 0.0
|
data/spec/id3_spec.rb
ADDED
@@ -0,0 +1,77 @@
|
|
1
|
+
require 'spec_helper'
|
2
|
+
|
3
|
+
describe describe DecisionTree::ID3Tree do
|
4
|
+
|
5
|
+
describe "simple discrete case" do
|
6
|
+
Given(:labels) { ["sun", "rain"]}
|
7
|
+
Given(:data) do
|
8
|
+
[
|
9
|
+
[1,0,1],
|
10
|
+
[0,1,0]
|
11
|
+
]
|
12
|
+
end
|
13
|
+
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, 1, :discrete) }
|
14
|
+
When { tree.train }
|
15
|
+
Then { tree.predict([1,0]).should == 1 }
|
16
|
+
Then { tree.predict([0,1]).should == 0 }
|
17
|
+
end
|
18
|
+
|
19
|
+
describe "discrete attributes" do
|
20
|
+
Given(:labels) { ["hungry", "color"] }
|
21
|
+
Given(:data) do
|
22
|
+
[
|
23
|
+
["yes", "red", "angry"],
|
24
|
+
["no", "blue", "not angry"],
|
25
|
+
["yes", "blue", "not angry"],
|
26
|
+
["no", "red", "not angry"]
|
27
|
+
]
|
28
|
+
end
|
29
|
+
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", :discrete) }
|
30
|
+
When { tree.train }
|
31
|
+
Then { tree.predict(["yes", "red"]).should == "angry" }
|
32
|
+
Then { tree.predict(["no", "red"]).should == "not angry" }
|
33
|
+
end
|
34
|
+
|
35
|
+
describe "discrete attributes" do
|
36
|
+
Given(:labels) { ["hunger", "happiness"] }
|
37
|
+
Given(:data) do
|
38
|
+
[
|
39
|
+
[8, 7, "angry"],
|
40
|
+
[6, 7, "angry"],
|
41
|
+
[7, 9, "angry"],
|
42
|
+
[7, 1, "not angry"],
|
43
|
+
[2, 9, "not angry"],
|
44
|
+
[3, 2, "not angry"],
|
45
|
+
[2, 3, "not angry"],
|
46
|
+
[1, 4, "not angry"]
|
47
|
+
]
|
48
|
+
end
|
49
|
+
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", :continuous) }
|
50
|
+
When { tree.train }
|
51
|
+
Then { tree.graph("continuous") }
|
52
|
+
Then { tree.predict([7, 7]).should == "angry" }
|
53
|
+
Then { tree.predict([2, 3]).should == "not angry" }
|
54
|
+
end
|
55
|
+
|
56
|
+
describe "a mixture" do
|
57
|
+
Given(:labels) { ["hunger", "color"] }
|
58
|
+
Given(:data) do
|
59
|
+
[
|
60
|
+
[8, "red", "angry"],
|
61
|
+
[6, "red", "angry"],
|
62
|
+
[7, "red", "angry"],
|
63
|
+
[7, "blue", "not angry"],
|
64
|
+
[2, "red", "not angry"],
|
65
|
+
[3, "blue", "not angry"],
|
66
|
+
[2, "blue", "not angry"],
|
67
|
+
[1, "red", "not angry"]
|
68
|
+
]
|
69
|
+
end
|
70
|
+
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", color: :discrete, hunger: :continuous) }
|
71
|
+
When { tree.train }
|
72
|
+
Then { tree.graph("continuous") }
|
73
|
+
Then { tree.predict([7, "red"]).should == "angry" }
|
74
|
+
Then { tree.predict([2, "blue"]).should == "not angry" }
|
75
|
+
end
|
76
|
+
|
77
|
+
end
|