idhja22 0.14.2

Sign up to get free protection for your applications and to get access to all the features.
data/.gitignore ADDED
@@ -0,0 +1,30 @@
1
+ *.gem
2
+ *.rbc
3
+ .bundle
4
+ .config
5
+ coverage
6
+ InstalledFiles
7
+ lib/bundler/man
8
+ pkg
9
+ rdoc
10
+ spec/reports
11
+ test/tmp
12
+ test/version_tmp
13
+ tmp
14
+
15
+ # YARD artifacts
16
+ .yardoc
17
+ _yardoc
18
+ doc/
19
+
20
+ # RVM
21
+ .rvmrc
22
+
23
+ #bundler artifacts
24
+ Gemfile.lock
25
+
26
+ #OS X files
27
+ .DS_Store
28
+
29
+ # data directory for storing csvs to run the program against
30
+ data
data/.travis.yml ADDED
@@ -0,0 +1,3 @@
1
+ language: ruby
2
+ rvm:
3
+ - "1.9.3"
data/Gemfile ADDED
@@ -0,0 +1,4 @@
1
# Dependencies live in idhja22.gemspec; Bundler pulls them in via `gemspec`.
source "https://rubygems.org"

gemspec
data/LICENSE.txt ADDED
@@ -0,0 +1,22 @@
1
+ Copyright (c) 2012 Henry Addison
2
+
3
+ MIT License
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining
6
+ a copy of this software and associated documentation files (the
7
+ "Software"), to deal in the Software without restriction, including
8
+ without limitation the rights to use, copy, modify, merge, publish,
9
+ distribute, sublicense, and/or sell copies of the Software, and to
10
+ permit persons to whom the Software is furnished to do so, subject to
11
+ the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be
14
+ included in all copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20
+ LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
data/README.md ADDED
@@ -0,0 +1,39 @@
1
+ # Idhja22
2
+
3
+ [![Build Status](https://travis-ci.org/henryaddison/idhja22.png?branch=master)](https://travis-ci.org/henryaddison/idhja22)
4
+
5
+ Mostly my attempt at writing a gem.
6
+
7
+ Used for training a binary classification tree (target values must be Y or N). Each leaf node holds the probability of Y rather than a plain Y or N.
8
+
9
+ ## Installation
10
+
11
+ Add this line to your application's Gemfile:
12
+
13
+ gem 'idhja22'
14
+
15
+ And then execute:
16
+
17
+ $ bundle
18
+
19
+ Or install it yourself as:
20
+
21
+ $ gem install idhja22
22
+
23
+ ## Usage
24
+
25
+ The simplest usage is to train from a CSV of training data. The final column is treated as the target category value of each entry, and the other columns are the attributes of each datum. The first row is used for the attribute and target category labels.
26
+
27
+ > tree = Idhja22::Tree.train_from_csv('/path/to/data.csv')
28
+
29
+ To print out the rules produced by the tree:
30
+
31
+ > puts tree.get_rules
32
+
33
+ ## Contributing
34
+
35
+ 1. Fork it
36
+ 2. Create your feature branch (`git checkout -b my-new-feature`)
37
+ 3. Commit your changes (`git commit -am 'Add some feature'`)
38
+ 4. Push to the branch (`git push origin my-new-feature`)
39
+ 5. Create new Pull Request
data/Rakefile ADDED
@@ -0,0 +1,11 @@
1
require "bundler/gem_tasks"
require "rspec/core/rake_task"

# Plain spec run; also the default task.
RSpec::Core::RakeTask.new("spec")

# Runs the same specs with SimpleCov enabled. spec/spec_helper.rb starts
# SimpleCov only when ENV['COVERAGE'] is set. The variable is set inside a
# plain Rake task body (executed only when this task runs), because some
# rspec-core versions call a RakeTask's configuration block while the Rakefile
# is being loaded, which would switch coverage on for every task. RSpec's rake
# task runs the specs in a child process, which inherits ENV.
desc "Run specs with SimpleCov"
task :coverage do
  ENV["COVERAGE"] = "true"
  Rake::Task["spec"].invoke
end

task :default => :spec
data/bin/idhja22 ADDED
@@ -0,0 +1,16 @@
1
#!/usr/bin/env ruby

require 'thor'
require 'idhja22'

# Command-line entry point: trains a decision tree from a CSV file, prints its
# rules and, when some rows are held back, reports the validation score.
class TrainAndValidate < Thor
  desc "train_and_validate FILE", "train a tree for the given file and validate it against a validation set"
  # Fraction of rows used for training; the rest form the validation set.
  # With the default of 1.0 every row is used for training and no score is printed.
  method_option :"training-proportion", :type => :numeric, :default => 1.0, :aliases => '-t'
  def train_and_validate(filename)
    proportion = options[:"training-proportion"]
    tree, score = Idhja22::Tree.train_and_validate_from_csv(filename, proportion)
    puts tree.get_rules
    puts "Against validation set probability of successful classification: #{score}" if proportion < 1.0
  end
end

TrainAndValidate.start
data/idhja22.gemspec ADDED
@@ -0,0 +1,23 @@
1
# -*- encoding: utf-8 -*-
lib = File.expand_path('../lib', __FILE__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require 'idhja22/version'

Gem::Specification.new do |gem|
  gem.name          = "idhja22"
  gem.version       = Idhja22::VERSION
  gem.authors       = ["Henry Addison"]
  gem.description   = %q{Decision Trees}
  gem.summary       = %q{A different take on decision trees}
  gem.homepage      = "https://github.com/henryaddison/idhja22"
  # LICENSE.txt is the MIT licence; declaring it puts it into the gem metadata.
  gem.license       = "MIT"

  gem.files         = `git ls-files`.split($/)
  gem.executables   = gem.files.grep(%r{^bin/}).map{ |f| File.basename(f) }
  gem.test_files    = gem.files.grep(%r{^(test|spec|features)/})
  gem.require_paths = ["lib"]

  # bin/idhja22 requires thor at runtime, so installing the gem must install it too.
  gem.add_dependency "thor"

  gem.add_development_dependency "rspec", "~>2.10"
  gem.add_development_dependency "rake"
  gem.add_development_dependency 'debugger'
  gem.add_development_dependency 'simplecov'
end
@@ -0,0 +1,40 @@
1
module Idhja22
  class Dataset
    # One row of data: attribute values looked up by their label.
    class Datum
      attr_accessor :attributes, :category_label, :attribute_labels

      # row            - Array of attribute values, in the same order as attr_labels.
      #                  It is copied, so the caller's array is never modified.
      # attr_labels    - Array of attribute labels; must not contain duplicates.
      # category_label - label of the target category.
      #
      # Raises NonUniqueAttributeLabels when attr_labels contains duplicates.
      def initialize(row, attr_labels, category_label)
        self.category_label = category_label
        raise NonUniqueAttributeLabels, "repeated attributes in #{attr_labels}" unless attr_labels == attr_labels.uniq
        self.attribute_labels = attr_labels
        # Copy: Example pops the category off the end of this array, which
        # would otherwise strip it from the caller's row as well.
        self.attributes = row.dup
      end

      def to_a
        attributes
      end

      # Value of the attribute called attr_label.
      # Raises UnknownAttributeLabel when attr_label is not one of attribute_labels.
      def [](attr_label)
        index = attribute_labels.index(attr_label)
        raise UnknownAttributeLabel, "unknown attribute label #{attr_label} in labels #{attribute_labels.join(', ')}" if index.nil?
        attributes[index]
      end
    end

    # A training/validation datum: the last value of the row is its category.
    class Example < Datum
      attr_accessor :category

      # Same arguments as Datum#initialize, but row carries one extra trailing
      # value: the category, which must be 'Y' or 'N'.
      # Raises UnknownCategoryValue for any other category.
      def initialize(row, attr_labels, category_label)
        super
        self.category = attributes.pop
        raise UnknownCategoryValue, "Unrecognised category: #{category} - should be Y or N" unless ['Y', 'N'].include?(category)
      end

      # Attribute values followed by the category, i.e. the original row.
      def to_a
        super + [category]
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Idhja22
  class Dataset
    # Root of the library's data errors; an ArgumentError so callers can treat
    # any of them as bad input.
    BadData = Class.new(ArgumentError)
    # Raised by Tree.new when a dataset has fewer than MIN_DATASET_SIZE points.
    InsufficientData = Class.new(BadData)
    # Raised when a list of attribute labels contains duplicates.
    NonUniqueAttributeLabels = Class.new(BadData)

    class Datum
      # Raised by Datum#[] for a label the datum does not have.
      UnknownAttributeLabel = Class.new(BadData)
      # Raised by DecisionNode#evaluate when no branch matches the queried value.
      UnknownAttributeValue = Class.new(BadData)
      # Raised by LeafNode#evaluate when the query's category label differs from the node's.
      UnknownCategoryLabel = Class.new(BadData)
      # Raised by Example when a category is not 'Y' or 'N'.
      UnknownCategoryValue = Class.new(BadData)
    end
  end
end
@@ -0,0 +1,27 @@
1
module Idhja22
  class Dataset
    # Dataset operations used while growing a tree. Mixed into Dataset and
    # relies on its data, size, category_counts, probability,
    # attribute_labels and category_label methods.
    module TreeMethods
      # Splits the data by each datum's value for attr_label.
      # Returns a Hash of attribute value => Dataset (same labels as self),
      # with keys in order of first appearance.
      def partition(attr_label)
        groups = data.group_by { |datum| datum[attr_label] }
        groups.each_with_object({}) do |(value, members), output|
          output[value] = Dataset.new(members, attribute_labels, category_label)
        end
      end

      # Base-2 Shannon entropy of the category distribution.
      # Sets smaller than Idhja22::MIN_DATASET_SIZE return 1.0 instead, so that
      # a handful of samples is never mistaken for a pure set.
      def entropy
        total = size
        return 1.0 if total < Idhja22::MIN_DATASET_SIZE
        category_counts.values.inject(0.0) do |ent, count|
          prop = count.to_f / total
          ent - prop * Math.log(prop, 2)
        end
      end

      # True when the probability of 'Y' is above TERMINATION_PROBABILITY or
      # below 1 - TERMINATION_PROBABILITY.
      def terminating?
        # probability re-counts every datum, so evaluate it only once.
        prob = probability
        prob > Idhja22::TERMINATION_PROBABILITY || prob < 1 - Idhja22::TERMINATION_PROBABILITY
      end
    end
  end
end
@@ -0,0 +1,70 @@
1
+ require "idhja22/dataset/errors"
2
+ require "idhja22/dataset/tree_methods"
3
+ require "idhja22/dataset/datum"
4
+ require 'csv'
5
+
6
module Idhja22
  # A collection of data points that share one set of attribute labels and one
  # category label. Tree-building helpers are mixed in from TreeMethods.
  class Dataset
    attr_reader :category_label, :attribute_labels, :data

    include Idhja22::Dataset::TreeMethods

    class << self
      # Reads filename as CSV. The header row supplies the labels: the last
      # column is the category label, the others are attribute labels. Every
      # later row becomes an Example whose last value is its Y/N category.
      #
      # Raises InsufficientData when the file has no header row.
      def from_csv(filename)
        rows = CSV.read(filename)

        labels = rows.shift
        raise InsufficientData, "no header row found in #{filename}" if labels.nil?
        category_label = labels.pop
        attribute_labels = labels

        examples = rows.map { |row| Example.new(row, attribute_labels, category_label) }
        new(examples, attribute_labels, category_label)
      end
    end

    # Raises NonUniqueAttributeLabels when attr_labels contains duplicates.
    def initialize(data, attr_labels, category_label)
      @category_label = category_label
      raise NonUniqueAttributeLabels, "repeated attributes in #{attr_labels}" unless attr_labels == attr_labels.uniq
      @attribute_labels = attr_labels
      @data = data
    end

    # Hash of category value => number of data points with it; categories that
    # never occur read as 0.
    def category_counts
      data.each_with_object(Hash.new(0)) { |datum, counts| counts[datum.category] += 1 }
    end

    def size
      data.size
    end

    def empty?
      data.empty?
    end

    # Fraction of data points whose category is 'Y' (NaN for an empty dataset).
    def probability
      category_counts['Y'].to_f / size.to_f
    end

    # Randomly splits the data into [training_set, validation_set]; the
    # training set gets (training_proportion * size).to_i points.
    #
    # training_proportion - Numeric between 0 and 1 inclusive.
    # random              - optional Random used for the shuffle, so a split can
    #                       be reproduced; Ruby's default generator otherwise.
    #
    # Raises ArgumentError when training_proportion is outside 0..1; such
    # values would otherwise produce wrong or nil slices.
    def split(training_proportion, random = nil)
      proportion = training_proportion.to_f
      unless proportion >= 0.0 && proportion <= 1.0
        raise ArgumentError, "training proportion must be between 0 and 1, got #{training_proportion}"
      end

      shuffled_data = random ? data.shuffle(:random => random) : data.shuffle
      cutoff_point = (proportion * size).to_i

      training_set = self.class.new(shuffled_data.take(cutoff_point), attribute_labels, category_label)
      validation_set = self.class.new(shuffled_data.drop(cutoff_point), attribute_labels, category_label)

      return training_set, validation_set
    end
  end
end
@@ -0,0 +1,77 @@
1
module Idhja22
  # Base class of tree nodes. Nodes of different classes are never equal;
  # subclasses extend == to compare their own state.
  class Node
    def ==(other)
      self.class == other.class
    end
  end

  # Internal node that routes a query to one child according to the query's
  # value for decision_attribute.
  class DecisionNode < Node
    # A child decision node whose leaf children all lie within this
    # probability range of each other is replaced by a single leaf.
    COLLAPSE_TOLERANCE = 0.01

    attr_reader :branches, :decision_attribute

    # data_split           - Hash of attribute value => Dataset (see Dataset#partition).
    # decision_attribute   - label of the attribute this node routes on.
    # attributes_available - labels the children may still split on.
    # depth                - depth of this node; children are built at depth + 1.
    # parent_probability   - passed through to Tree.build_node for each child.
    def initialize(data_split, decision_attribute, attributes_available, depth, parent_probability)
      @decision_attribute = decision_attribute
      @branches = {}
      data_split.each do |value, dataset|
        child = Tree.build_node(dataset, attributes_available, depth + 1, parent_probability)
        child = collapse(child, dataset)
        # Decision nodes with nothing below them cannot answer queries; drop them.
        @branches[value] = child if child && !empty_decision?(child)
      end
    end

    # One rule string per leaf reachable from this node.
    def get_rules
      branches.flat_map do |value, node|
        node.get_rules.map { |rule| "#{decision_attribute} == #{value} and #{rule}" }
      end
    end

    def ==(other)
      return false unless super
      return false unless decision_attribute == other.decision_attribute
      return false unless branches.length == other.branches.length
      branches.all? { |value, node| other.branches.has_key?(value) && node == other.branches[value] }
    end

    # Follows the branch for query[decision_attribute].
    # Raises Dataset::Datum::UnknownAttributeValue when no branch matches.
    def evaluate(query)
      queried_value = query[decision_attribute]
      branch = branches[queried_value]
      raise Idhja22::Dataset::Datum::UnknownAttributeValue, "when looking at attribute labelled #{decision_attribute} could not find branch for value #{queried_value}" if branch.nil?
      branch.evaluate(query)
    end

    private

    # Replaces node by a LeafNode (carrying the largest child probability) when
    # it is a decision node whose children are all leaves with probabilities
    # within COLLAPSE_TOLERANCE; otherwise returns node unchanged.
    def collapse(node, dataset)
      return node unless node.is_a?(DecisionNode)
      children = node.branches.values
      # With no children there are no probabilities to compare (max/min would be nil).
      return node if children.empty? || !children.all? { |n| n.is_a?(LeafNode) }
      probs = children.map(&:probability)
      return node unless probs.max - probs.min < COLLAPSE_TOLERANCE
      LeafNode.new(probs.max, dataset.category_label)
    end

    # True for a decision node without any branches.
    def empty_decision?(node)
      node.is_a?(DecisionNode) && node.branches.empty?
    end
  end

  # Terminal node holding the probability that the category is 'Y'.
  class LeafNode < Node
    attr_reader :probability, :category_label

    def initialize(probability, category_label)
      @probability = probability
      @category_label = category_label
    end

    def get_rules
      ["then chance of #{category_label} = #{probability.round(2)}"]
    end

    def ==(other)
      super && probability == other.probability && category_label == other.category_label
    end

    # Returns probability.
    # Raises Dataset::Datum::UnknownCategoryLabel when the query's category
    # label differs from this node's.
    def evaluate(query)
      unless query.category_label == category_label
        raise Idhja22::Dataset::Datum::UnknownCategoryLabel, "expected category label for query is #{query.category_label} but node is using #{category_label}"
      end
      probability
    end
  end
end
@@ -0,0 +1,110 @@
1
module Idhja22
  # A binary (Y/N) decision tree whose leaves hold the probability of 'Y'.
  class Tree
    # Nodes at this depth (the root is depth 0) always become leaves, which
    # keeps trees short.
    MAX_DEPTH = 3

    attr_accessor :root

    class << self
      # Trains a tree on dataset, allowing splits on all of its attribute labels.
      # Raises Dataset::InsufficientData when dataset has fewer than MIN_DATASET_SIZE points.
      def train(dataset)
        new(dataset, dataset.attribute_labels)
      end

      # Splits dataset with Dataset#split(training_proportion), trains on the
      # first part and validates against the second.
      # Returns [tree, validation_score]; see #validate for the score.
      def train_and_validate(dataset, training_proportion=0.5)
        training_set, validation_set = dataset.split(training_proportion)
        tree = train(training_set)
        return tree, tree.validate(validation_set)
      end

      # Like .train, reading the dataset from a CSV file (see Dataset.from_csv).
      def train_from_csv(filename)
        train(Dataset.from_csv(filename))
      end

      # Like .train_and_validate, reading the dataset from a CSV file.
      def train_and_validate_from_csv(filename, training_proportion=0.5)
        train_and_validate(Dataset.from_csv(filename), training_proportion)
      end

      # Builds the node for dataset at the given depth (root = 0):
      # - fewer than MIN_DATASET_SIZE points: a leaf with a probability guessed
      #   from parent_probability (see probability_guess);
      # - below the root, a set that is terminating? (nearly all one category),
      #   a set at MAX_DEPTH, or a set with no attributes left: a leaf with the
      #   set's own probability;
      # - otherwise a DecisionNode splitting on the attribute with the highest
      #   information gain.
      def build_node(dataset, attributes_available, depth, parent_probability = nil)
        if dataset.size < Idhja22::MIN_DATASET_SIZE
          return Idhja22::LeafNode.new(probability_guess(parent_probability, depth), dataset.category_label)
        end

        # depth > 0: never stop before splitting the data at least once.
        pure_enough = dataset.terminating? && depth > 0
        if pure_enough || depth >= MAX_DEPTH || attributes_available.empty?
          return Idhja22::LeafNode.new(dataset.probability, dataset.category_label)
        end

        data_split, split_attribute = best_attribute(dataset, attributes_available)
        Idhja22::DecisionNode.new(data_split, split_attribute, attributes_available - [split_attribute], depth, dataset.probability)
      end

      private

      # Returns [partition, attribute_label] for the attribute whose
      # Dataset#partition gives the largest information gain (parent entropy
      # minus the size-weighted entropy of the parts). Ties keep the attribute
      # listed first.
      def best_attribute(dataset, attributes_available)
        best_split = best_label = nil
        best_gain = -Float::INFINITY
        # The parent's entropy and size do not depend on the candidate attribute.
        parent_entropy = dataset.entropy
        parent_size = dataset.size.to_f

        attributes_available.each do |attr_label|
          candidate = dataset.partition(attr_label)
          gain = candidate.values.inject(parent_entropy) do |acc, subset|
            acc - (subset.size.to_f / parent_size) * subset.entropy
          end
          next unless gain > best_gain
          best_gain, best_split, best_label = gain, candidate, attr_label
        end
        return best_split, best_label
      end

      # Probability for a set too small to trust: parent_probability moved
      # towards DEFAULT_PROBABILITY by (DEFAULT_PROBABILITY - parent) / 2**depth.
      # Without a parent probability (build_node's default) the default
      # probability itself is used instead of failing on nil arithmetic.
      def probability_guess(parent_probability, depth)
        return Idhja22::DEFAULT_PROBABILITY if parent_probability.nil?
        parent_probability + (Idhja22::DEFAULT_PROBABILITY - parent_probability) / 2**depth
      end
    end

    # Grows the tree from dataset, splitting only on attributes_available.
    # Raises Dataset::InsufficientData when dataset has fewer than MIN_DATASET_SIZE points.
    def initialize(dataset, attributes_available)
      if dataset.size < Idhja22::MIN_DATASET_SIZE
        raise Idhja22::Dataset::InsufficientData, "require at least #{Idhja22::MIN_DATASET_SIZE} data points, only have #{dataset.size} in data set provided"
      end
      @root = self.class.build_node(dataset, attributes_available, 0)
    end

    # The tree as text: one "if"/"elsif" line per leaf.
    def get_rules
      "if " + root.get_rules.join("\nelsif ")
    end

    # Trees are equal when their roots are; anything without a root is unequal.
    def ==(other)
      other.respond_to?(:root) && root == other.root
    end

    def evaluate(query)
      root.evaluate(query)
    end

    # Mean probability the tree assigns to each point's true category.
    # Points with an attribute value never seen in training score 0 (assumed
    # misclassified). An empty ds gives NaN (0.0 / 0.0).
    def validate(ds)
      total = ds.data.inject(0.0) do |sum, point|
        begin
          prob = evaluate(point)
          sum + (point.category == 'Y' ? prob : 1.0 - prob)
        rescue Idhja22::Dataset::Datum::UnknownAttributeValue
          sum
        end
      end
      total / ds.size.to_f
    end
  end
end
@@ -0,0 +1,3 @@
1
module Idhja22
  # Gem version; idhja22.gemspec reads it for the package version.
  VERSION = '0.14.2'
end
data/lib/idhja22.rb ADDED
@@ -0,0 +1,10 @@
1
+ require "idhja22/version"
2
+ require "idhja22/dataset"
3
+ require "idhja22/tree"
4
+ require "idhja22/node"
5
+
6
module Idhja22
  # Tree.new refuses datasets smaller than this. While growing the tree,
  # smaller subsets become leaves with a guessed probability, and
  # Dataset#entropy reports 1.0 for them.
  MIN_DATASET_SIZE = 20

  # Dataset#terminating? is true when the probability of 'Y' is above this
  # value or below 1 minus it.
  TERMINATION_PROBABILITY = 0.95

  # Probability that guessed leaf probabilities are pulled towards
  # (see Tree.probability_guess).
  DEFAULT_PROBABILITY = 0.5
end
@@ -0,0 +1,11 @@
1
+ 0,1,2,3,4,C
2
+ a,a,a,a,a,Y
3
+ a,a,b,b,a,N
4
+ a,a,a,c,a,Y
5
+ b,a,a,a,a,Y
6
+ b,a,b,c,a,N
7
+ a,a,a,a,a,Y
8
+ a,a,a,a,a,Y
9
+ a,a,a,a,a,Y
10
+ a,a,a,a,b,N
11
+ a,a,a,a,b,Y
@@ -0,0 +1,59 @@
1
require 'spec_helper'

# Specs for Idhja22::Dataset::Example: parsing a row into attributes plus a
# trailing Y/N category, and looking attributes up by label.
describe Idhja22::Dataset::Example do
  let(:datum) { Idhja22::Dataset::Example.new(['high', '20-30', 'vanilla', 'Y'], ['confidence', 'age', 'fav ice cream'], 'likes') }

  describe 'new' do
    it 'should extract attributes' do
      datum.attributes.should == ['high', '20-30', 'vanilla']
      datum.attribute_labels.should == ['confidence', 'age', 'fav ice cream']
    end

    it 'should extract category' do
      datum.category.should == 'Y'
      datum.category_label.should == 'likes'
    end

    context 'with non-unique attribute labels' do
      it 'should throw an exception' do
        expect {
          Idhja22::Dataset::Example.new(['high', '20-30', 'vanilla', 'Y'], ['confidence', 'age', 'age'], 'likes')
        }.to raise_error(Idhja22::Dataset::NonUniqueAttributeLabels)
      end
    end

    context 'unexpected label' do
      it 'should raise an exception' do
        expect {
          Idhja22::Dataset::Example.new(['high', '20-30', 'vanilla', 'H'], ['confidence', 'age', 'fav ice cream'], 'likes')
        }.to raise_error(Idhja22::Dataset::Example::UnknownCategoryValue)
      end
    end
  end

  describe 'to_a' do
    it 'should list the data in an array format' do
      datum.to_a.should == ['high', '20-30', 'vanilla', 'Y']
    end
  end

  describe '[]' do
    context 'known attribute' do
      it 'should map attribute label to value' do
        datum['age'].should == '20-30'
      end
    end

    context 'unknown attribute' do
      it 'should throw an exception' do
        expect { datum['madeup'] }.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeLabel)
      end
    end
  end
end
@@ -0,0 +1,130 @@
1
require 'spec_helper'

describe Idhja22::Dataset do
  context 'initialization' do

    def check_labels(obj, exp_attr_labels, exp_cat_label)
      obj.attribute_labels.should == exp_attr_labels
      obj.category_label.should == exp_cat_label
    end

    describe 'from_csv' do
      before(:all) do
        @ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'spec_data.csv'))
      end

      it 'should extract labels' do
        check_labels(@ds, ['Weather', 'Temperature', 'Wind'], 'Plays')
      end

      it 'should extract data' do
        @ds.data.length.should == 3
        @ds.data.collect(&:attributes).should == [['sunny', 'hot', 'light'], ['sunny', 'cold', 'medium'], ['raining', 'cold', 'high']]
        @ds.data.collect(&:category).should == ['Y', 'Y','N']
      end
    end

    describe 'new' do
      before(:all) do
        @ds = Idhja22::Dataset.new([Idhja22::Dataset::Example.new(['high', '20-30', 'vanilla', 'Y'], ['Confidence', 'Age group', 'fav ice cream'] , 'Loves Reading')], ['Confidence', 'Age group', 'fav ice cream'], 'Loves Reading')
      end

      it 'should extract labels' do
        check_labels(@ds, ['Confidence', 'Age group', 'fav ice cream'], 'Loves Reading')
      end

      it 'should extract data' do
        @ds.data.length.should == 1
        @ds.data.first.attributes.should == ['high', '20-30', 'vanilla']
        @ds.data.first.category.should == 'Y'
      end

      context 'with repeated attribute labels' do
        it 'should throw an error' do
          expect do
            Idhja22::Dataset.new([Idhja22::Dataset::Example.new(['high', '20-30', 'vanilla', 'Y'], ['Confidence', 'Age group', 'Confidence'] , 'Loves Reading')], ['Confidence', 'Age group', 'Confidence'], 'Loves Reading')
          end.to raise_error(Idhja22::Dataset::NonUniqueAttributeLabels)
        end
      end
    end
  end

  context 'ready made' do
    before(:all) do
      @ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'))
    end

    describe '#partition' do
      it 'should split the data set based on the values of a given attribute' do
        new_sets = @ds.partition('0')
        new_sets.length.should == 2
        new_sets.each do |value, dset|
          dset.data.collect { |d| d.attributes[0] }.uniq.should == [value]
        end
      end

      it 'should preserve the data other than splitting it' do
        new_sets = @ds.partition('3')
        new_sets.length.should == 3
        new_sets['a'].attribute_labels.should == @ds.attribute_labels
        new_sets['a'].category_label.should == @ds.category_label
        new_sets['a'].data.collect(&:to_a).should == [%w{a a a a a Y}, %w{b a a a a Y}, %w{a a a a a Y}, %w{a a a a a Y}, %w{a a a a a Y}, %w{a a a a b N}, %w{a a a a b N}]
      end

      it 'should produce one item when the values are all the same' do
        @ds.partition('1').length.should == 1
      end
    end

    describe '#category_counts' do
      it 'should count the number of entries in each category' do
        @ds.category_counts.should == {'Y' => 6, 'N' => 4}
      end
    end

    describe '#entropy' do
      it 'should calculate entropy of set' do
        @ds.entropy.should be_within(0.000001).of(0.970951)
      end

      context 'with little data' do
        it 'should return 1' do
          ds = Idhja22::Dataset.new([Idhja22::Dataset::Example.new(['high', '20-30', 'vanilla', 'Y'], ['Confidence', 'Age group', 'fav ice cream'] , 'Loves Reading')], ['Confidence', 'Age group', 'fav ice cream'], 'Loves Reading')
          ds.entropy.should == 1.0
        end
      end
    end

    describe '#size' do
      it 'should calculate size of dataset' do
        @ds.size.should == 10
      end
    end

    describe '#empty?' do
      it 'should be false for a dataset with data' do
        @ds.empty?.should be_false
      end

      it 'should be true for a dataset without data' do
        Idhja22::Dataset.new([], @ds.attribute_labels, @ds.category_label).empty?.should be_true
      end
    end

    describe '#probability' do
      it 'should return probabilty category is Y' do
        @ds.probability.should be_within(0.0001).of(0.6)
      end
    end

    describe '#split' do
      it 'should split into a training and validation set according to the given proportion' do
        ts, vs = @ds.split(0.5)
        ts.size.should == 5
        vs.size.should == 5

        ts, vs = @ds.split(0.75)
        ts.size.should == 7
        vs.size.should == 3
      end

      it 'should keep every datum in exactly one of the two sets' do
        ts, vs = @ds.split(0.75)
        (ts.data + vs.data).should =~ @ds.data
      end
    end
  end
end
@@ -0,0 +1,11 @@
1
+ 0,1,2,3,4,C
2
+ a,a,a,a,a,Y
3
+ a,a,b,b,a,N
4
+ a,a,a,c,a,Y
5
+ b,a,a,a,a,Y
6
+ b,a,b,c,a,N
7
+ a,a,a,a,a,Y
8
+ a,a,a,a,a,Y
9
+ a,a,a,a,a,Y
10
+ a,a,a,a,b,N
11
+ a,a,a,a,b,N
data/spec/node_spec.rb ADDED
@@ -0,0 +1,97 @@
1
require 'spec_helper'

describe Idhja22::LeafNode do
  describe('.new') do
    it 'should store probability and category label' do
      l = Idhja22::LeafNode.new(0.75, 'label')
      l.probability.should == 0.75
      l.category_label.should == 'label'
    end
  end

  describe('#get_rules') do
    it 'should return the probability' do
      l = Idhja22::LeafNode.new(0.75, 'pudding')
      l.get_rules.should == ['then chance of pudding = 0.75']
    end
  end

  describe(' == ') do
    let(:l1) { Idhja22::LeafNode.new(0.75, 'pudding') }
    let(:l2) { Idhja22::LeafNode.new(0.75, 'pudding') }
    let(:diff_l1) { Idhja22::LeafNode.new(0.7, 'pudding') }
    let(:diff_l2) { Idhja22::LeafNode.new(0.75, 'starter') }
    it 'should compare attributes' do
      l1.should == l2
      l1.should_not == diff_l1
      l1.should_not == diff_l2
    end
  end

  describe 'evaluate' do
    let(:leaf) { Idhja22::LeafNode.new(0.6, 'pudding') }

    it 'should return probability' do
      query = Idhja22::Dataset::Datum.new(['high', 'gusty'], ['temperature', 'windy'], 'pudding')
      leaf.evaluate(query).should == 0.6
    end

    context 'mismatching category labels' do
      it 'should raise error' do
        query = Idhja22::Dataset::Datum.new(['high', 'gusty'], ['temperature', 'windy'], 'tennis')
        expect {leaf.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownCategoryLabel)
      end
    end
  end
end

describe Idhja22::DecisionNode do
  before(:all) do
    @ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'))
  end

  # These nodes are labelled with the same attribute the data was partitioned
  # on, so the branch keys really are values of the decision attribute.
  describe('#get_rules') do
    it 'should return a list of rules' do
      l = Idhja22::DecisionNode.new(@ds.partition('2'), '2', [], 0, 0.75)
      l.get_rules.should == ["2 == a and then chance of C = 0.75", "2 == b and then chance of C = 0.0"]
    end
  end

  describe(' == ') do
    let(:dn1) { Idhja22::DecisionNode.new(@ds.partition('2'), '2', [], 0, 0.75) }
    let(:dn2) { Idhja22::DecisionNode.new(@ds.partition('2'), '2', [], 0, 0.75) }
    let(:diff_dn1) { Idhja22::DecisionNode.new(@ds.partition('0'), '2', [], 0, 0.75) }
    let(:diff_dn2) { Idhja22::DecisionNode.new(@ds.partition('3'), '3', [], 0, 0.75) }

    it 'should compare ' do
      dn1.should == dn2
      dn1.should_not == diff_dn1
      dn1.should_not == diff_dn2
    end
  end

  describe 'evaluate' do
    let(:dn) { Idhja22::DecisionNode.new(@ds.partition('2'), '2', [], 0, 0.75) }
    it 'should follow node to probability' do
      query = Idhja22::Dataset::Datum.new(['a', 'a'], ['2', '4'], 'C')
      dn.evaluate(query).should == 0.75

      query = Idhja22::Dataset::Datum.new(['b', 'a'], ['2', '4'], 'C')
      dn.evaluate(query).should == 0.0
    end

    context 'mismatching attribute label' do
      it 'should raise an error' do
        query = Idhja22::Dataset::Datum.new(['b', 'a'], ['1', '3'], 'C')
        expect {dn.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeLabel)
      end
    end

    context 'unknown attribute value' do
      it 'should raise an error' do
        query = Idhja22::Dataset::Datum.new(['c', 'a'], ['2', '4'], 'C')
        expect {dn.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeValue)
      end
    end
  end
end
@@ -0,0 +1,4 @@
1
+ Weather,Temperature,Wind,Plays
2
+ sunny,hot,light,Y
3
+ sunny,cold,medium,Y
4
+ raining,cold,high,N
@@ -0,0 +1,20 @@
1
if ENV['COVERAGE']
  require 'simplecov'
  SimpleCov.start do
    add_filter '/spec/'
  end
end
$: << File.dirname(__FILE__) + '/../lib'

require 'idhja22'

# The debugger is only a convenience for interactive sessions; don't fail the
# whole suite when it is not installed.
begin
  require 'ruby-debug'
rescue LoadError
end

# The fixture CSVs are tiny, so lower the minimum dataset size for the specs.
# Remove the library's value first to avoid Ruby's
# "already initialized constant" warning.
module Idhja22
  remove_const(:MIN_DATASET_SIZE) if const_defined?(:MIN_DATASET_SIZE, false)
  MIN_DATASET_SIZE = 2
end

RSpec.configure do |config|

end
data/spec/tree_spec.rb ADDED
@@ -0,0 +1,93 @@
1
require 'spec_helper'

describe Idhja22::Tree do
  before(:all) do
    @ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'))
  end

  describe('.train') do
    it 'should make a tree' do
      tree = Idhja22::Tree.train(@ds)
      tree.should be_a(Idhja22::Tree)
      tree.root.should_not be_nil
    end

    context 'with insufficient data' do
      it 'should throw exception' do
        ds = Idhja22::Dataset.new([Idhja22::Dataset::Datum.new(['high', '20-30', 'Vanilla', 'Y'], ['Confidence', 'Age group', 'Fav ice cream'] , 'Loves Reading')], ['Confidence', 'Age group', 'Fav ice cream'], 'Loves Reading')
        expect { Idhja22::Tree.train(ds) }.to raise_error(Idhja22::Dataset::InsufficientData)
      end
    end
  end

  describe('#get_rules') do
    it 'should list the rules of the tree' do
      Idhja22::Tree.train(@ds).get_rules.should == "if 2 == a and 4 == a and then chance of C = 1.0\nelsif 2 == a and 4 == b and then chance of C = 0.0\nelsif 2 == b and then chance of C = 0.0"
    end
  end

  describe(' == ') do
    it 'should compare root nodes' do
      tree1 = Idhja22::Tree.train(@ds)
      tree2 = Idhja22::Tree.train(@ds)
      diff_ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'another_large_spec_data.csv'))
      diff_tree = Idhja22::Tree.train(diff_ds)
      tree1.should == tree2
      tree1.should_not == diff_tree
    end
  end

  describe('.train_from_csv') do
    it 'should make the same tree as the one from the dataset' do
      tree = Idhja22::Tree.train(@ds)
      csv_tree = Idhja22::Tree.train_from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'))
      tree.should == csv_tree
    end
  end

  describe('#evaluate') do
    it 'should return the probabilty at the leaf of the tree' do
      tree = Idhja22::Tree.train(@ds)
      query = Idhja22::Dataset::Datum.new(['z','z','a','z','a'],['0', '1','2','3','4'],'C')
      tree.evaluate(query).should == 1.0
    end
  end

  describe '#validate' do
    before(:all) do
      @tree = Idhja22::Tree.train(@ds)
    end

    it 'should return the average probability that the tree gets the validation examples correct' do
      vps = [Idhja22::Dataset::Example.new(['z','z','a','z','a','Y'],['0', '1','2','3','4'],'C')]
      vps << Idhja22::Dataset::Example.new(['z','z','a','z','a','N'],['0', '1','2','3','4'],'C')
      @tree.validate(Idhja22::Dataset.new(vps, ['0', '1','2','3','4'],'C')).should == 0.5
    end

    context 'against a data point with an unrecognised attribute value' do
      before(:all) do
        validation_point = Idhja22::Dataset::Example.new(['z','z','o','z','a','Y'],['0', '1','2','3','4'],'C')
        @vps = Idhja22::Dataset.new([validation_point], ['0', '1','2','3','4'],'C')
      end

      it 'should treat a validation example as one it will never get right' do
        @tree.validate(@vps).should == 0.0
      end
    end
  end

  describe '.train_and_validate' do
    it 'should return a tree and the validation result' do
      tree, value = Idhja22::Tree.train_and_validate(@ds)
      tree.is_a?(Idhja22::Tree).should be_true
      (0..1).include?(value).should be_true
    end
  end

  describe('.train_and_validate_from_csv') do
    it 'should return a tree and the validation result' do
      csv_tree, validation_value = Idhja22::Tree.train_and_validate_from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'), 0.75)
      csv_tree.is_a?(Idhja22::Tree).should be_true
      (0..1).include?(validation_value).should be_true
    end
  end
end
@@ -0,0 +1,9 @@
1
require 'spec_helper'

# Guards against releasing with a stale Idhja22::VERSION.
describe Idhja22 do
  describe 'VERSION' do
    subject { Idhja22::VERSION }

    it 'should be current version' do
      subject.should == '0.14.2'
    end
  end
end
metadata ADDED
@@ -0,0 +1,149 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: idhja22
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.14.2
5
+ prerelease:
6
+ platform: ruby
7
+ authors:
8
+ - Henry Addison
9
+ autorequire:
10
+ bindir: bin
11
+ cert_chain: []
12
+ date: 2012-12-17 00:00:00.000000000 Z
13
+ dependencies:
14
+ - !ruby/object:Gem::Dependency
15
+ name: rspec
16
+ requirement: !ruby/object:Gem::Requirement
17
+ none: false
18
+ requirements:
19
+ - - ~>
20
+ - !ruby/object:Gem::Version
21
+ version: '2.10'
22
+ type: :development
23
+ prerelease: false
24
+ version_requirements: !ruby/object:Gem::Requirement
25
+ none: false
26
+ requirements:
27
+ - - ~>
28
+ - !ruby/object:Gem::Version
29
+ version: '2.10'
30
+ - !ruby/object:Gem::Dependency
31
+ name: rake
32
+ requirement: !ruby/object:Gem::Requirement
33
+ none: false
34
+ requirements:
35
+ - - ! '>='
36
+ - !ruby/object:Gem::Version
37
+ version: '0'
38
+ type: :development
39
+ prerelease: false
40
+ version_requirements: !ruby/object:Gem::Requirement
41
+ none: false
42
+ requirements:
43
+ - - ! '>='
44
+ - !ruby/object:Gem::Version
45
+ version: '0'
46
+ - !ruby/object:Gem::Dependency
47
+ name: debugger
48
+ requirement: !ruby/object:Gem::Requirement
49
+ none: false
50
+ requirements:
51
+ - - ! '>='
52
+ - !ruby/object:Gem::Version
53
+ version: '0'
54
+ type: :development
55
+ prerelease: false
56
+ version_requirements: !ruby/object:Gem::Requirement
57
+ none: false
58
+ requirements:
59
+ - - ! '>='
60
+ - !ruby/object:Gem::Version
61
+ version: '0'
62
+ - !ruby/object:Gem::Dependency
63
+ name: simplecov
64
+ requirement: !ruby/object:Gem::Requirement
65
+ none: false
66
+ requirements:
67
+ - - ! '>='
68
+ - !ruby/object:Gem::Version
69
+ version: '0'
70
+ type: :development
71
+ prerelease: false
72
+ version_requirements: !ruby/object:Gem::Requirement
73
+ none: false
74
+ requirements:
75
+ - - ! '>='
76
+ - !ruby/object:Gem::Version
77
+ version: '0'
78
+ description: Decision Trees
79
+ email:
80
+ executables:
81
+ - idhja22
82
+ extensions: []
83
+ extra_rdoc_files: []
84
+ files:
85
+ - .gitignore
86
+ - .travis.yml
87
+ - Gemfile
88
+ - LICENSE.txt
89
+ - README.md
90
+ - Rakefile
91
+ - bin/idhja22
92
+ - idhja22.gemspec
93
+ - lib/idhja22.rb
94
+ - lib/idhja22/dataset.rb
95
+ - lib/idhja22/dataset/datum.rb
96
+ - lib/idhja22/dataset/errors.rb
97
+ - lib/idhja22/dataset/tree_methods.rb
98
+ - lib/idhja22/node.rb
99
+ - lib/idhja22/tree.rb
100
+ - lib/idhja22/version.rb
101
+ - spec/another_large_spec_data.csv
102
+ - spec/dataset/example_spec.rb
103
+ - spec/dataset_spec.rb
104
+ - spec/large_spec_data.csv
105
+ - spec/node_spec.rb
106
+ - spec/spec_data.csv
107
+ - spec/spec_helper.rb
108
+ - spec/tree_spec.rb
109
+ - spec/version_spec.rb
110
+ homepage: https://github.com/henryaddison/idhja22
111
+ licenses: []
112
+ post_install_message:
113
+ rdoc_options: []
114
+ require_paths:
115
+ - lib
116
+ required_ruby_version: !ruby/object:Gem::Requirement
117
+ none: false
118
+ requirements:
119
+ - - ! '>='
120
+ - !ruby/object:Gem::Version
121
+ version: '0'
122
+ segments:
123
+ - 0
124
+ hash: -4104544286961851710
125
+ required_rubygems_version: !ruby/object:Gem::Requirement
126
+ none: false
127
+ requirements:
128
+ - - ! '>='
129
+ - !ruby/object:Gem::Version
130
+ version: '0'
131
+ segments:
132
+ - 0
133
+ hash: -4104544286961851710
134
+ requirements: []
135
+ rubyforge_project:
136
+ rubygems_version: 1.8.24
137
+ signing_key:
138
+ specification_version: 3
139
+ summary: A different take on decision trees
140
+ test_files:
141
+ - spec/another_large_spec_data.csv
142
+ - spec/dataset/example_spec.rb
143
+ - spec/dataset_spec.rb
144
+ - spec/large_spec_data.csv
145
+ - spec/node_spec.rb
146
+ - spec/spec_data.csv
147
+ - spec/spec_helper.rb
148
+ - spec/tree_spec.rb
149
+ - spec/version_spec.rb