decisiontree 0.3.0 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/.gitignore +17 -0
- data/Gemfile +4 -0
- data/README.md +48 -0
- data/Rakefile +5 -121
- data/decisiontree.gemspec +25 -0
- data/examples/simple.rb +0 -0
- data/lib/decisiontree.rb +1 -1
- data/lib/decisiontree/id3_tree.rb +46 -46
- data/spec/id3_spec.rb +77 -0
- data/spec/spec_helper.rb +3 -0
- metadata +118 -69
- data/CHANGELOG.txt +0 -17
- data/History.txt +0 -0
- data/Manifest.txt +0 -24
- data/README.txt +0 -15
- data/lib/decisiontree/version.rb +0 -9
- data/scripts/txt2html +0 -67
- data/setup.rb +0 -1585
- data/test/test_decisiontree.rb +0 -26
- data/test/test_helper.rb +0 -2
- data/website/index.html +0 -11
- data/website/index.txt +0 -38
- data/website/javascripts/rounded_corners_lite.inc.js +0 -285
- data/website/stylesheets/screen.css +0 -138
- data/website/template.rhtml +0 -48
data/.gitignore
ADDED
data/Gemfile
ADDED
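The added Gemfile's four lines are not reproduced in this diff. For a gem that keeps its dependencies in the gemspec (as decisiontree.gemspec below does), a Gemfile is usually just a source declaration plus a `gemspec` directive; the sketch below is a hypothetical example of that shape, not the file's actual contents.

```ruby
# Hypothetical example only — the actual contents of the Gemfile added in this
# release are not shown in the diff. A gemspec-backed Gemfile typically just
# points Bundler at the gemspec for its dependency list.
source 'https://rubygems.org'

gemspec
```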
data/README.md
ADDED
@@ -0,0 +1,48 @@
+# Decision Tree
+
+A ruby library which implements the ID3 (information gain) algorithm for decision tree learning. Currently, continuous and discrete datasets can be learned.
+
+- Discrete model assumes unique labels & can be graphed and converted into a png for visual analysis
+- Continuous looks at all possible values for a variable and iteratively chooses the best threshold between all possible assignments. This results in a binary tree which is partitioned by the threshold at every step. (e.g. temperature > 20C)
+
+## Features
+- ID3 algorithms for continuous and discrete cases, with support for inconsistent datasets.
+- Graphviz component to visualize the learned tree (http://rockit.sourceforge.net/subprojects/graphr/)
+- Support for multiple and symbolic outputs, and graphing of continuous trees.
+- Returns default value when no branches are suitable for input
+
+## Implementation
+
+- Ruleset is a class that trains an ID3Tree with 2/3 of the training data, converts it into a set of rules and prunes the rules with the remaining 1/3 of the training data (in a C4.5 way).
+- Bagging is a bagging-based trainer (quite obvious), which trains 10 Ruleset trainers and, when predicting, chooses the best output based on voting.
+
+Blog post with explanation & examples: http://www.igvita.com/2007/04/16/decision-tree-learning-in-ruby/
+
+## Example
+
+```ruby
+require 'decisiontree'
+
+attributes = ['Temperature']
+training = [
+  [36.6, 'healthy'],
+  [37, 'sick'],
+  [38, 'sick'],
+  [36.7, 'healthy'],
+  [40, 'sick'],
+  [50, 'really sick'],
+]
+
+# Instantiate the tree, and train it based on the data (set default to 'sick')
+dec_tree = DecisionTree::ID3Tree.new(attributes, training, 'sick', :continuous)
+dec_tree.train
+
+test = [37, 'sick']
+decision = dec_tree.predict(test)
+puts "Predicted: #{decision} ... True decision: #{test.last}"
+
+# => Predicted: sick ... True decision: sick
+```
+
+## License
+
+The MIT License - Copyright (c) 2006 Ilya Grigorik
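The README's continuous-case description above (enumerate candidate thresholds, keep the one with the best split) is an information-gain calculation over label entropy, the same quantity that `Array#entropy` and `id3_continuous` compute in id3_tree.rb further down. The standalone Ruby sketch below illustrates the idea only; it is not the gem's code or API.

```ruby
# Standalone illustration (not part of the gem): information gain for one
# candidate threshold, as described in the README's continuous case.
def entropy(labels)
  return 0.0 if labels.empty?
  counts = labels.group_by(&:itself).transform_values(&:size)
  counts.values.sum { |n| prob = n.to_f / labels.size; -prob * Math.log2(prob) }
end

def information_gain(examples, threshold)
  # examples are [value, label] pairs, like the README's training rows
  left, right = examples.partition { |value, _| value >= threshold }
  before = entropy(examples.map(&:last))
  after  = [left, right].sum { |part| (part.size.to_f / examples.size) * entropy(part.map(&:last)) }
  before - after
end

training = [[36.6, 'healthy'], [37, 'sick'], [38, 'sick'],
            [36.7, 'healthy'], [40, 'sick'], [50, 'really sick']]
puts information_gain(training, 37.0)  # higher gain => better threshold
```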
data/Rakefile
CHANGED
@@ -1,123 +1,7 @@
-require '
-
-require 'rake/clean'
-require 'rake/testtask'
-require 'rake/packagetask'
-require 'rake/gempackagetask'
-require 'rake/rdoctask'
-require 'rake/contrib/rubyforgepublisher'
-require 'fileutils'
-require 'hoe'
-
-include FileUtils
-require File.join(File.dirname(__FILE__), 'lib', 'decisiontree', 'version')
-
-AUTHOR = 'Ilya Grigorik'  # can also be an array of Authors
-EMAIL = "ilya <at> fortehost.com"
-DESCRIPTION = "ID3-based implementation of the M.L. Decision Tree algorithm"
-GEM_NAME = 'decisiontree' # what ppl will type to install your gem
-
-@config_file = "~/.rubyforge/user-config.yml"
-@config = nil
-def rubyforge_username
-  unless @config
-    begin
-      @config = YAML.load(File.read(File.expand_path(@config_file)))
-    rescue
-      puts <<-EOS
-ERROR: No rubyforge config file found: #{@config_file}"
-Run 'rubyforge setup' to prepare your env for access to Rubyforge
- - See http://newgem.rubyforge.org/rubyforge.html for more details
-      EOS
-      exit
-    end
-  end
-  @rubyforge_username ||= @config["username"]
-end
-
-RUBYFORGE_PROJECT = 'decisiontree' # The unix name for your project
-HOMEPATH = "http://#{RUBYFORGE_PROJECT}.rubyforge.org"
-DOWNLOAD_PATH = "http://rubyforge.org/projects/#{RUBYFORGE_PROJECT}"
-
-NAME = "decisiontree"
-REV = nil
-# UNCOMMENT IF REQUIRED:
-# REV = `svn info`.each {|line| if line =~ /^Revision:/ then k,v = line.split(': '); break v.chomp; else next; end} rescue nil
-VERS = DecisionTree::VERSION::STRING + (REV ? ".#{REV}" : "")
-CLEAN.include ['**/.*.sw?', '*.gem', '.config', '**/.DS_Store']
-RDOC_OPTS = ['--quiet', '--title', 'decisiontree documentation',
-    "--opname", "index.html",
-    "--line-numbers",
-    "--main", "README",
-    "--inline-source"]
-
-class Hoe
-  def extra_deps
-    @extra_deps.reject { |x| Array(x).first == 'hoe' }
-  end
-end
-
-# Generate all the Rake tasks
-# Run 'rake -T' to see list of generated tasks (from gem root directory)
-hoe = Hoe.new(GEM_NAME, VERS) do |p|
-  p.author = AUTHOR
-  p.description = DESCRIPTION
-  p.email = EMAIL
-  p.summary = DESCRIPTION
-  p.url = HOMEPATH
-  p.rubyforge_name = RUBYFORGE_PROJECT if RUBYFORGE_PROJECT
-  p.test_globs = ["test/**/test_*.rb"]
-  p.clean_globs |= CLEAN # An array of file patterns to delete on clean.
-
-  # == Optional
-  p.changes = p.paragraphs_of("History.txt", 0..1).join("\n\n")
-  #p.extra_deps = [] # An array of rubygem dependencies [name, version], e.g. [ ['active_support', '>= 1.3.1'] ]
-  #p.spec_extras = {} # A hash of extra values to set in the gemspec.
-end
-
-CHANGES = hoe.paragraphs_of('History.txt', 0..1).join("\n\n")
-PATH = (RUBYFORGE_PROJECT == GEM_NAME) ? RUBYFORGE_PROJECT : "#{RUBYFORGE_PROJECT}/#{GEM_NAME}"
-hoe.remote_rdoc_dir = File.join(PATH.gsub(/^#{RUBYFORGE_PROJECT}\/?/,''), 'rdoc')
-
-desc 'Generate website files'
-task :website_generate do
-  Dir['website/**/*.txt'].each do |txt|
-    sh %{ ruby scripts/txt2html #{txt} > #{txt.gsub(/txt$/,'html')} }
-  end
-end
-
-desc 'Upload website files to rubyforge'
-task :website_upload do
-  host = "#{rubyforge_username}@rubyforge.org"
-  remote_dir = "/var/www/gforge-projects/#{PATH}/"
-  local_dir = 'website'
-  sh %{rsync -aCv #{local_dir}/ #{host}:#{remote_dir}}
-end
-
-desc 'Generate and upload website files'
-task :website => [:website_generate, :website_upload, :publish_docs]
-
-desc 'Release the website and new gem version'
-task :deploy => [:check_version, :website, :release] do
-  puts "Remember to create SVN tag:"
-  puts "svn copy svn+ssh://#{rubyforge_username}@rubyforge.org/var/svn/#{PATH}/trunk " +
-    "svn+ssh://#{rubyforge_username}@rubyforge.org/var/svn/#{PATH}/tags/REL-#{VERS} "
-  puts "Suggested comment:"
-  puts "Tagging release #{CHANGES}"
-end
-
-desc 'Runs tasks website_generate and install_gem as a local deployment of the gem'
-task :local_deploy => [:website_generate, :install_gem]
-
-task :check_version do
-  unless ENV['VERSION']
-    puts 'Must pass a VERSION=x.y.z release version'
-    exit
-  end
-  unless ENV['VERSION'] == VERS
-    puts "Please update your version.rb to match the release version, currently #{VERS}"
-    exit
-  end
-end
+require 'bundler'
+Bundler::GemHelper.install_tasks

+require 'rspec/core/rake_task'
+RSpec::Core::RakeTask.new

+task :default => :spec
data/decisiontree.gemspec
ADDED
@@ -0,0 +1,25 @@
+# -*- encoding: utf-8 -*-
+$:.push File.expand_path("../lib", __FILE__)
+
+Gem::Specification.new do |s|
+  s.name        = "decisiontree"
+  s.version     = "0.4.0"
+  s.platform    = Gem::Platform::RUBY
+  s.authors     = ["Ilya Grigorik"]
+  s.email       = ["ilya@igvita.com"]
+  s.homepage    = "https://github.com/igrigorik/decisiontree"
+  s.summary     = %q{ID3-based implementation of the M.L. Decision Tree algorithm}
+  s.description = s.summary
+
+  s.rubyforge_project = "decisiontree"
+
+  s.add_development_dependency "graphr"
+  s.add_development_dependency "rspec"
+  s.add_development_dependency "rspec-given"
+  s.add_development_dependency "pry"
+
+  s.files         = `git ls-files`.split("\n")
+  s.test_files    = `git ls-files -- {test,spec,features}/*`.split("\n")
+  s.executables   = `git ls-files -- bin/*`.split("\n").map{ |f| File.basename(f) }
+  s.require_paths = ["lib"]
+end
data/examples/simple.rb
CHANGED
File without changes
data/lib/decisiontree.rb
CHANGED
@@ -1 +1 @@
-
+require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
data/lib/decisiontree/id3_tree.rb
CHANGED
@@ -1,13 +1,9 @@
-#The MIT License
-
-###Copyright (c) 2007 Ilya Grigorik <ilya AT
-###Modifed at 2007 by José Ignacio Fernández <joseignacio.fernandez AT gmail DOT com>
-
-begin
-  require 'graph/graphviz_dot'
-rescue LoadError
-  STDERR.puts "graph/graphviz_dot not installed, graphing functionality not included."
-end
+# The MIT License
+#
+### Copyright (c) 2007 Ilya Grigorik <ilya AT igvita DOT com>
+### Modified at 2007 by José Ignacio Fernández <joseignacio.fernandez AT gmail DOT com>
+
+require 'graphr'
 
 class Object
   def save_to_file(filename)
@@ -19,9 +15,9 @@ class Object
   end
 end
 
-class Array
-  def classification; collect { |v| v.last }; end
-
+class Array
+  def classification; collect { |v| v.last }; end
+
   # calculate information entropy
   def entropy
     return 0 if empty?
@@ -55,28 +51,34 @@ module DecisionTree
 
       @tree = id3_train(data2, attributes, default)
     end
-
-    def
-
-
-
+
+    def type(attribute)
+      @type.is_a?(Hash) ? @type[attribute.to_sym] : @type
+    end
+
+    def fitness_for(attribute)
+      case type(attribute)
+      when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
       when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)}
       end
-
-
+    end
+
+    def id3_train(data, attributes, default, used={})
+      return default if data.empty?
 
       # return classification if all examples have the same classification
       return data.first.last if data.classification.uniq.size == 1
 
       # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
-      performance = attributes.collect { |attribute|
+      performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
      max = performance.max { |a,b| a[0] <=> b[0] }
      best = Node.new(attributes[performance.index(max)], max[1], max[0])
      best.threshold = nil if @type == :discrete
-      @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
+      @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
      tree, l = {best => {}}, ['>=', '<']
-
-
+
+      fitness = fitness_for(best.attribute)
+      case type(best.attribute)
      when :continuous
        data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
          tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
@@ -86,7 +88,7 @@ module DecisionTree
        partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } }
        partitions.each_with_index { |examples, i|
          tree[best][values[i]] = id3_train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
-        }
+        }
      end
 
      tree
@@ -100,32 +102,32 @@ module DecisionTree
      thresholds.pop
      #thresholds -= used[attribute] if used.has_key? attribute
 
-      gain = thresholds.collect { |threshold|
+      gain = thresholds.collect { |threshold|
        sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
        pos = (sp[0].size).to_f / data.size
        neg = (sp[1].size).to_f / data.size
-
+
        [data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold]
      }.max { |a,b| a[0] <=> b[0] }
 
      return [-1, -1] if gain.size == 0
      gain
    end
-
+
    # ID3 for discrete label cases
    def id3_discrete(data, attributes, attribute)
      values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
      partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
      remainder = partitions.collect {|p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) {|i,s| s+=i }
-
+
      [data.classification.entropy - remainder, attributes.index(attribute)]
    end
 
    def predict(test)
-
+      descend(@tree, test)
    end
 
-    def graph(filename)
+    def graph(filename)
      dgp = DotGraphPrinter.new(build_tree)
      dgp.write_to_file("#{filename}.png", "png")
    end
@@ -155,22 +157,20 @@ module DecisionTree
    end
 
    private
-    def
+    def descend(tree, test)
      attr = tree.to_a.first
      return @default if !attr
-
-
-
-
-
-
-
-
-
-      return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
-      return descend_discrete(attr[1][test[@attributes.index(attr[0].attribute)]],test)
+      if type(attr.first.attribute) == :continuous
+        return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+        return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+        return descend(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+        return descend(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+      else
+        return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
+        return descend(attr[1][test[@attributes.index(attr[0].attribute)]],test)
+      end
    end
-
+
    def build_tree(tree = @tree)
      return [] unless tree.is_a?(Hash)
      return [["Always", @default]] if tree.empty?
@@ -186,7 +186,7 @@ module DecisionTree
          child = attr[1][key]
          child_text = "#{child}\n(#{child.to_s.clone.object_id})"
        end
-        label_text = "#{key} #{
+        label_text = "#{key} #{type(attr[0].attribute) == :continuous ? attr[0].threshold : ""}"
 
        [parent_text, child_text, label_text]
      end
@@ -286,7 +286,7 @@ module DecisionTree
 
    def predict(test)
      @rules.each do |r|
-        prediction = r.predict(test)
+        prediction = r.predict(test)
        return prediction, r.accuracy unless prediction.nil?
      end
      return @default, 0.0
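The key change in id3_tree.rb is that `@type` may now be either a single symbol or a per-attribute hash: the new `type(attribute)` helper resolves it, and `fitness_for`/`descend` dispatch on the result. Below is a standalone illustration of just that dispatch; `TypeDemo` is a made-up class for the example, while the `type` body is copied from the hunk above.

```ruby
# Illustration of the per-attribute type dispatch added above. TypeDemo is a
# made-up holder class; its type() body mirrors the method added to ID3Tree.
class TypeDemo
  def initialize(type)
    @type = type
  end

  def type(attribute)
    @type.is_a?(Hash) ? @type[attribute.to_sym] : @type
  end
end

p TypeDemo.new(:continuous).type("hunger")                              # => :continuous
p TypeDemo.new({ color: :discrete, hunger: :continuous }).type("color") # => :discrete
```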
data/spec/id3_spec.rb
ADDED
@@ -0,0 +1,77 @@
+require 'spec_helper'
+
+describe DecisionTree::ID3Tree do
+
+  describe "simple discrete case" do
+    Given(:labels) { ["sun", "rain"] }
+    Given(:data) do
+      [
+        [1,0,1],
+        [0,1,0]
+      ]
+    end
+    Given(:tree) { DecisionTree::ID3Tree.new(labels, data, 1, :discrete) }
+    When { tree.train }
+    Then { tree.predict([1,0]).should == 1 }
+    Then { tree.predict([0,1]).should == 0 }
+  end
+
+  describe "discrete attributes" do
+    Given(:labels) { ["hungry", "color"] }
+    Given(:data) do
+      [
+        ["yes", "red", "angry"],
+        ["no", "blue", "not angry"],
+        ["yes", "blue", "not angry"],
+        ["no", "red", "not angry"]
+      ]
+    end
+    Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", :discrete) }
+    When { tree.train }
+    Then { tree.predict(["yes", "red"]).should == "angry" }
+    Then { tree.predict(["no", "red"]).should == "not angry" }
+  end
+
+  describe "continuous attributes" do
+    Given(:labels) { ["hunger", "happiness"] }
+    Given(:data) do
+      [
+        [8, 7, "angry"],
+        [6, 7, "angry"],
+        [7, 9, "angry"],
+        [7, 1, "not angry"],
+        [2, 9, "not angry"],
+        [3, 2, "not angry"],
+        [2, 3, "not angry"],
+        [1, 4, "not angry"]
+      ]
+    end
+    Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", :continuous) }
+    When { tree.train }
+    Then { tree.graph("continuous") }
+    Then { tree.predict([7, 7]).should == "angry" }
+    Then { tree.predict([2, 3]).should == "not angry" }
+  end
+
+  describe "a mixture" do
+    Given(:labels) { ["hunger", "color"] }
+    Given(:data) do
+      [
+        [8, "red", "angry"],
+        [6, "red", "angry"],
+        [7, "red", "angry"],
+        [7, "blue", "not angry"],
+        [2, "red", "not angry"],
+        [3, "blue", "not angry"],
+        [2, "blue", "not angry"],
+        [1, "red", "not angry"]
+      ]
+    end
+    Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", color: :discrete, hunger: :continuous) }
+    When { tree.train }
+    Then { tree.graph("continuous") }
+    Then { tree.predict([7, "red"]).should == "angry" }
+    Then { tree.predict([2, "blue"]).should == "not angry" }
+  end
+
+end