idhja22 0.14.4 → 1.1.0

Sign up to get free protection for your applications and to get access to all the features.
data/.gitignore CHANGED
@@ -27,4 +27,4 @@ Gemfile.lock
27
27
  .DS_Store
28
28
 
29
29
  # data directory for storing csvs to run the program against
30
- data
30
+ /data
data/bin/idhja22 CHANGED
@@ -4,13 +4,22 @@ require 'thor'
4
4
  require 'idhja22'
5
5
 
6
6
  class TrainAndValidate < Thor
7
- desc "train_and_validate FILE", "train a tree for the given file and validate is against a validation set"
7
+ desc "train_and_validate_tree FILE", "train a tree for the given file and validate is against a validation set"
8
+ method_option :attributes, :type => :array
8
9
  method_option :"training-proportion", :type => :numeric, :default => 1.0, :aliases => 't'
9
- def train_and_validate(filename)
10
- t, v = Idhja22::Tree.train_and_validate_from_csv(filename, options[:"training-proportion"])
10
+ def train_and_validate_tree(filename)
11
+ t, v = Idhja22::Tree.train_and_validate_from_csv(filename, options)
11
12
  puts t.get_rules
12
13
  puts "Against validation set probability of successful classifiction: #{v}" if options[:"training-proportion"] < 1.0
13
14
  end
15
+
16
+ desc "train_and_validate_bayes FILE", "train a naive Bayesian classifier for the given file and validate is against a validation set"
17
+ method_option :attributes, :type => :array
18
+ method_option :"training-proportion", :type => :numeric, :default => 1.0, :aliases => 't'
19
+ def train_and_validate_bayes(filename)
20
+ t, v = Idhja22::Bayes.train_and_validate_from_csv(filename, options)
21
+ puts "Against validation set probability of successful classifiction: #{v}" if options[:"training-proportion"] < 1.0
22
+ end
14
23
  end
15
24
 
16
25
  TrainAndValidate.start
data/idhja22.gemspec CHANGED
@@ -7,8 +7,8 @@ Gem::Specification.new do |gem|
7
7
  gem.name = "idhja22"
8
8
  gem.version = Idhja22::VERSION
9
9
  gem.authors = ["Henry Addison"]
10
- gem.description = %q{Decision Trees}
11
- gem.summary = %q{A gem for creating decision trees}
10
+ gem.description = %q{Classifiers}
11
+ gem.summary = %q{A gem for creating classifiers (decision trees and naive Bayes so far)}
12
12
  gem.homepage = "https://github.com/henryaddison/idhja22"
13
13
 
14
14
  gem.files = `git ls-files`.split($/)
data/lib/idhja22/bayes.rb CHANGED
@@ -1,5 +1,53 @@
1
1
  module Idhja22
2
- class Bayes
3
-
2
+ class Bayes < BinaryClassifier
3
+ attr_accessor :conditional_probabilities, :prior_probabilities
4
+ class << self
5
+ def calculate_conditional_probabilities dataset, attribute_labels_to_use
6
+ conditional_probabilities = {}
7
+ attribute_labels_to_use.each do |attr_label|
8
+ conditional_probabilities[attr_label] = {}
9
+ dataset.partition_by_category.each do |cat, uniform_category_ds|
10
+ conditional_probabilities[attr_label][cat] = Hash.new(0)
11
+ partitioned_data = uniform_category_ds.partition(attr_label)
12
+ partitioned_data.each do |attr_value, uniform_value_ds|
13
+ conditional_probabilities[attr_label][cat][attr_value] = uniform_value_ds.size.to_f/uniform_category_ds.size.to_f
14
+ end
15
+ end
16
+ end
17
+
18
+ return conditional_probabilities
19
+ end
20
+
21
+ def calculate_priors dataset
22
+ output = Hash.new(0)
23
+ dataset.category_counts.each do |cat, count|
24
+ output[cat] = count.to_f/dataset.size.to_f
25
+ end
26
+ return output
27
+ end
28
+ end
29
+
30
+ def evaluate(query)
31
+ nb_values = {}
32
+ total_values = 0
33
+
34
+ prior_probabilities.each do |cat, prior_prob|
35
+ nb_value = prior_prob
36
+ conditional_probabilities.each do |attr_label, probs|
37
+ raise Idhja22::Dataset::Datum::UnknownAttributeValue, "Not seen value #{query[attr_label]} for attribute #{attr_label} in training." unless probs[cat].has_key? query[attr_label]
38
+ nb_value *= probs[cat][query[attr_label]]
39
+ end
40
+ total_values += nb_value
41
+ nb_values[cat] = nb_value
42
+ end
43
+
44
+ return nb_values['Y']/total_values
45
+ end
46
+
47
+ def train(dataset, attributes_to_use)
48
+ self.conditional_probabilities = self.class.calculate_conditional_probabilities(dataset, attributes_to_use)
49
+ self.prior_probabilities = self.class.calculate_priors(dataset)
50
+ return self
51
+ end
4
52
  end
5
53
  end
@@ -0,0 +1,54 @@
1
+ module Idhja22
2
+ class BinaryClassifier
3
+
4
+ class << self
5
+ # Trains a classifier using the provided Dataset.
6
+ def train(dataset, opts = {})
7
+ attributes_to_use = (opts[:attributes] || dataset.attribute_labels)
8
+ classifier = new
9
+ classifier.train(dataset, attributes_to_use)
10
+ return classifier
11
+ end
12
+
13
+ # Takes a dataset and splits it randomly into training and validation data.
14
+ # Uses the training data to train a classifier whose performance is then measured using the validation data.
15
+ # @param [Hash] opts Options; :"training-proportion" (Float, defaults to 0.5) is the proportion of dataset to use for training. The rest will be used to validate the resulting classifier.
16
+ def train_and_validate(dataset, opts = {})
17
+ opts[:"training-proportion"] ||= 0.5
18
+ training_set, validation_set = dataset.split(opts[:"training-proportion"])
19
+ tree = self.train(training_set, opts)
20
+ validation_value = tree.validate(validation_set)
21
+ return tree, validation_value
22
+ end
23
+
24
+ # see #train
25
+ # @note Takes a CSV filename rather than a Dataset
26
+ def train_from_csv(filename, opts={})
27
+ ds = Dataset.from_csv(filename)
28
+ train(ds, opts)
29
+ end
30
+
31
+ # see #train_and_validate
32
+ # @note Takes a CSV filename rather than a Dataset
33
+ def train_and_validate_from_csv(filename, opts={})
34
+ ds = Dataset.from_csv(filename)
35
+ train_and_validate(ds, opts)
36
+ end
37
+ end
38
+
39
+ def validate(ds)
40
+ output = 0
41
+ ds.data.each do |validation_point|
42
+ begin
43
+ prob = evaluate(validation_point)
44
+ output += (validation_point.category == 'Y' ? prob : 1.0 - prob)
45
+ rescue Idhja22::Dataset::Datum::UnknownAttributeValue
46
+ # if we don't recognise the attribute value in the example, then assume the worst:
47
+ # will never classify this point correctly
48
+ # equivalent to output += 0 but no point running this
49
+ end
50
+ end
51
+ return output.to_f/ds.size.to_f
52
+ end
53
+ end
54
+ end
@@ -2,4 +2,5 @@ Configuration.for('default') {
2
2
  default_probability 0.5
3
3
  termination_probability 0.95
4
4
  min_dataset_size 20
5
+ probability_delta 0.01
5
6
  }
@@ -1,4 +1,5 @@
1
1
  module Idhja22
2
+ class IncompleteTree < StandardError; end
2
3
  class Dataset
3
4
  class BadData < ArgumentError; end
4
5
  class InsufficientData < BadData; end
@@ -17,13 +17,13 @@ module Idhja22
17
17
  category_label = labels.pop
18
18
  attribute_labels = labels
19
19
 
20
- data = []
20
+ set = new([], attribute_labels, category_label)
21
21
  csv.each do |row|
22
22
  training_example = Example.new(row, attribute_labels, category_label)
23
- data << training_example
23
+ set << training_example
24
24
  end
25
25
 
26
- new(data, attribute_labels, category_label)
26
+ return set
27
27
  end
28
28
  end
29
29
 
@@ -36,8 +36,9 @@ module Idhja22
36
36
 
37
37
  def category_counts
38
38
  counts = Hash.new(0)
39
- data.each do |d|
40
- counts[d.category]+=1
39
+ split_data = partition_by_category
40
+ split_data.each do |cat, d|
41
+ counts[cat] = d.size
41
42
  end
42
43
  return counts
43
44
  end
@@ -66,5 +67,21 @@ module Idhja22
66
67
 
67
68
  return training_set, validation_set
68
69
  end
70
+
71
+ def <<(example)
72
+ raise Idhja22::Dataset::Datum::UnknownCategoryLabel unless example.category_label == self.category_label
73
+ raise Idhja22::Dataset::Datum::UnknownAttributeLabel unless example.attribute_labels == self.attribute_labels
74
+ self.data << example
75
+ end
76
+
77
+ def partition_by_category
78
+ output = Hash.new do |hash, key|
79
+ hash[key] = self.class.new([], attribute_labels, category_label)
80
+ end
81
+ self.data.each do |d|
82
+ output[d.category] << d
83
+ end
84
+ return output
85
+ end
69
86
  end
70
87
  end
@@ -20,9 +20,7 @@ module Idhja22
20
20
  return Idhja22::LeafNode.new(dataset.probability, dataset.category_label)
21
21
  end
22
22
 
23
- data_split, best_attribute = best_attribute(dataset, attributes_available)
24
-
25
- node = Idhja22::DecisionNode.new(data_split, best_attribute, attributes_available-[best_attribute], depth, dataset.probability)
23
+ node = DecisionNode.build(dataset, attributes_available, depth)
26
24
 
27
25
  return node
28
26
  end
@@ -59,21 +57,34 @@ module Idhja22
59
57
 
60
58
  class DecisionNode < Node
61
59
  attr_reader :branches, :decision_attribute
62
- def initialize(data_split, decision_attribute, attributes_available, depth, parent_probability)
63
- @decision_attribute = decision_attribute
64
- @branches = {}
65
- data_split.each do |value, dataset|
66
- node = Node.build_node(dataset, attributes_available, depth+1, parent_probability)
67
- if(node.is_a?(DecisionNode) && node.branches.values.all? { |n| n.is_a?(LeafNode) })
68
- probs = node.branches.values.collect(&:probability)
69
- if(probs.max - probs.min < 0.01)
70
- node = LeafNode.new(probs.max, dataset.category_label)
71
- end
60
+
61
+ class << self
62
+ def build(dataset, attributes_available, depth)
63
+ data_split, best_attribute = best_attribute(dataset, attributes_available)
64
+
65
+ output_node = new(best_attribute)
66
+
67
+ data_split.each do |value, dataset|
68
+ node = Node.build_node(dataset, attributes_available-[best_attribute], depth+1, dataset.probability)
69
+
70
+ output_node.add_branch(value, node) if node && !(node.is_a?(DecisionNode) && node.branches.empty?)
72
71
  end
73
- @branches[value] = node if node && !(node.is_a?(DecisionNode) && node.branches.empty?)
72
+
73
+ output_node.cleanup_children!
74
+
75
+ return output_node
74
76
  end
75
77
  end
76
78
 
79
+ def initialize(decision_attribute)
80
+ @decision_attribute = decision_attribute
81
+ @branches = {}
82
+ end
83
+
84
+ def add_branch(attr_value, node)
85
+ @branches[attr_value] = node
86
+ end
87
+
77
88
  def get_rules
78
89
  rules = []
79
90
  branches.each do |v,n|
@@ -104,6 +115,29 @@ module Idhja22
104
115
  raise Idhja22::Dataset::Datum::UnknownAttributeValue, "when looking at attribute labelled #{self.decision_attribute} could not find branch for value #{queried_value}" if branch.nil?
105
116
  branch.evaluate(query)
106
117
  end
118
+
119
+ def cleanup_children!
120
+ branches.each do |attr, child_node|
121
+ child_node.cleanup_children!
122
+ leaves = child_node.leaves
123
+ probs = leaves.collect(&:probability)
124
+ if(probs.max - probs.min < Idhja22.config.probability_delta)
125
+ new_node = LeafNode.new(probs.max, category_label)
126
+ add_branch(attr, new_node)
127
+ end
128
+ end
129
+ end
130
+
131
+ def leaves
132
+ raise Idhja22::IncompleteTree, "decision node with no branches" if branches.empty?
133
+ branches.values.flat_map do |child_node|
134
+ child_node.leaves
135
+ end
136
+ end
137
+
138
+ def category_label
139
+ leaves.first.category_label
140
+ end
107
141
  end
108
142
 
109
143
  class LeafNode < Node
@@ -125,5 +159,14 @@ module Idhja22
125
159
  raise Idhja22::Dataset::Datum::UnknownCategoryLabel, "expected category label for query is #{query.category_label} but node is using #{self.category_label}" unless query.category_label == self.category_label
126
160
  return probability
127
161
  end
162
+
163
+ def leaves
164
+ return [self]
165
+ end
166
+
167
+ # no-op method - a leaf node has no children by definition
168
+ def cleanup_children!
169
+
170
+ end
128
171
  end
129
172
  end
data/lib/idhja22/tree.rb CHANGED
@@ -2,42 +2,15 @@ require "idhja22/tree/node"
2
2
 
3
3
  module Idhja22
4
4
  # The main entry class for training, viewing and evaluating a decision tree.
5
- class Tree
5
+ class Tree < BinaryClassifier
6
6
  attr_accessor :root
7
7
  class << self
8
- # Trains a Tree using the provided Dataset.
9
- def train(dataset)
10
- new(dataset, dataset.attribute_labels)
11
- end
12
-
13
- # Takes a dataset and splits it randomly into training and validation data.
14
- # Uses the training data to train a tree whose perfomance then measured using the validation data.
15
- # @param [Float] Proportion of dataset to use for training. The rest will be used to validate the resulting tree.
16
- def train_and_validate(dataset, training_proportion=0.5)
17
- training_set, validation_set = dataset.split(training_proportion)
18
- tree = self.train(training_set)
19
- validation_value = tree.validate(validation_set)
20
- return tree, validation_value
21
- end
22
-
23
- # see #train
24
- # @note Takes a CSV filename rather than a Dataset
25
- def train_from_csv(filename)
26
- ds = Dataset.from_csv(filename)
27
- train(ds)
28
- end
29
-
30
- # see #train_and_validate
31
- # @note Takes a CSV filename rather than a Dataset
32
- def train_and_validate_from_csv(filename, training_proportion=0.5)
33
- ds = Dataset.from_csv(filename)
34
- train_and_validate(ds, training_proportion)
35
- end
36
8
  end
37
9
 
38
- def initialize(dataset, attributes_available)
10
+ def train(dataset, attributes_available)
39
11
  raise Idhja22::Dataset::InsufficientData, "require at least #{Idhja22.config.min_dataset_size} data points, only have #{dataset.size} in data set provided" if(dataset.size < Idhja22.config.min_dataset_size)
40
12
  @root = Node.build_node(dataset, attributes_available, 0)
13
+ return self
41
14
  end
42
15
 
43
16
  def get_rules
@@ -52,20 +25,5 @@ module Idhja22
52
25
  def evaluate query
53
26
  @root.evaluate(query)
54
27
  end
55
-
56
- def validate(ds)
57
- output = 0
58
- ds.data.each do |validation_point|
59
- begin
60
- prob = evaluate(validation_point)
61
- output += (validation_point.category == 'Y' ? prob : 1.0 - prob)
62
- rescue Idhja22::Dataset::Datum::UnknownAttributeValue
63
- # if don't recognised the attribute value in the example, then assume the worst:
64
- # will never classify this point correctly
65
- # equivalent to output += 0 but no point running this
66
- end
67
- end
68
- return output.to_f/ds.size.to_f
69
- end
70
28
  end
71
29
  end
@@ -1,3 +1,3 @@
1
1
  module Idhja22
2
- VERSION = "0.14.4"
2
+ VERSION = "1.1.0"
3
3
  end
data/lib/idhja22.rb CHANGED
@@ -3,6 +3,7 @@ require 'idhja22/config/default'
3
3
 
4
4
  require "idhja22/version"
5
5
  require "idhja22/dataset"
6
+ require "idhja22/binary_classifier"
6
7
  require "idhja22/tree"
7
8
  require "idhja22/bayes"
8
9
 
data/spec/bayes_spec.rb CHANGED
@@ -1,5 +1,119 @@
1
1
  require 'spec_helper'
2
2
 
3
3
  describe Idhja22::Bayes do
4
+ before(:all) do
5
+ @ds = Idhja22::Dataset.from_csv(File.join(data_dir,'large_spec_data.csv'))
6
+ end
4
7
 
8
+ describe '.train' do
9
+ it 'should train a classifier from a dataset' do
10
+ classifier = Idhja22::Bayes.train @ds, :attributes => %w{0}
11
+ cond_probs = classifier.conditional_probabilities
12
+ cond_probs.keys.should == ['0']
13
+
14
+ cond_probs['0'].keys.should == ['Y', 'N']
15
+
16
+ cond_probs['0']['Y']['a'].should == 5.0/6.0
17
+ cond_probs['0']['N']['a'].should == 0.75
18
+
19
+ cond_probs['0']['Y']['b'].should == 1.0/6.0
20
+ cond_probs['0']['N']['b'].should == 0.25
21
+
22
+ prior_probs = classifier.prior_probabilities
23
+ prior_probs['Y'].should == 0.6
24
+ prior_probs['N'].should == 0.4
25
+ end
26
+ end
27
+
28
+ describe '.calculate_conditional_probabilities' do
29
+ it 'should calculate the conditional probabilities of P(Cat|attr_val) from dataset for given attribute labels' do
30
+ cond_probs = Idhja22::Bayes.calculate_conditional_probabilities @ds, %w{0 2}
31
+ cond_probs.keys.should == ['0', '2']
32
+ cond_probs['0'].keys.should == ['Y','N']
33
+ cond_probs['2'].keys.should == ['Y','N']
34
+
35
+ cond_probs['0']['Y']['a'].should == 5.0/6.0
36
+ cond_probs['0']['N']['a'].should == 0.75
37
+ cond_probs['0']['Y']['b'].should == 1.0/6.0
38
+ cond_probs['0']['N']['b'].should == 0.25
39
+
40
+ cond_probs['2']['Y']['a'].should == 1.0
41
+ cond_probs['2']['N']['a'].should == 0.5
42
+ cond_probs['2']['Y']['b'].should == 0
43
+ cond_probs['2']['N']['b'].should == 0.5
44
+ end
45
+ end
46
+
47
+ describe '.calculate_priors' do
48
+ it 'should calculate the prior probabilities' do
49
+ prior_probs = Idhja22::Bayes.calculate_priors @ds
50
+ prior_probs['Y'].should == 0.6
51
+ prior_probs['N'].should == 0.4
52
+ end
53
+
54
+ context 'all single category' do
55
+ it 'should return 0 for other categories' do
56
+ uniform_ds = Idhja22::Dataset.new([Idhja22::Dataset::Example.new(['high', '20-30', 'vanilla', 'Y'], ['Confidence', 'Age group', 'fav ice cream'] , 'Loves Reading')], ['Confidence', 'Age group', 'fav ice cream'], 'Loves Reading')
57
+ prior_probs = Idhja22::Bayes.calculate_priors uniform_ds
58
+ prior_probs['Y'].should == 1.0
59
+ prior_probs['N'].should == 0
60
+ end
61
+ end
62
+ end
63
+
64
+ describe '#evaluate' do
65
+ before(:all) do
66
+ @bayes = Idhja22::Bayes.new
67
+ @bayes.conditional_probabilities = {
68
+ 'age' => {
69
+ 'Y' => {'young' => 0.98, 'old' => 0.02},
70
+ 'N' => {'young' => 0.98, 'old' => 0.02}
71
+
72
+ },
73
+ 'confidence' => {
74
+ 'Y' => {'high' => 0.6, 'medium' => 0.3, 'low' => 0.1},
75
+ 'N' => {'high' => 0.8, 'medium' => 0.15, 'low' => 0.05}
76
+ },
77
+ 'fav ice cream' => {
78
+ 'Y' => {'vanilla' => 0.75, 'strawberry' => 0.25},
79
+ 'N' => {'vanilla' => 0.5, 'strawberry' => 0.6}
80
+ }
81
+ }
82
+ @bayes.prior_probabilities = {'Y' => 0.75, 'N' => 0.25}
83
+ end
84
+
85
+ context 'Y likely' do
86
+ it 'should return probability of being Y' do
87
+ query = Idhja22::Dataset::Datum.new(['high', 'young', 'vanilla', 'cheddar'], ['confidence', 'age', 'fav ice cream', 'fav cheese'], 'Loves Reading')
88
+ @bayes.evaluate(query).should be_within(0.00001).of(0.77143)
89
+ end
90
+ end
91
+
92
+ context 'N likely' do
93
+ it 'should return probability of being Y' do
94
+ query = Idhja22::Dataset::Datum.new(['high', 'young', 'strawberry', 'cheddar'], ['confidence', 'age', 'fav ice cream', 'fav cheese'], 'Loves Reading')
95
+ @bayes.evaluate(query).should be_within(0.00001).of(0.48387)
96
+ end
97
+ end
98
+
99
+ context 'unrecognised attribute value' do
100
+ it 'should throw an error' do
101
+ query = Idhja22::Dataset::Datum.new(['high', 'young', 'chocolate', 'cheddar'], ['confidence', 'age', 'fav ice cream', 'fav cheese'], 'Loves Reading')
102
+ expect { @bayes.evaluate(query) }.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeValue)
103
+ end
104
+ end
105
+ end
106
+
107
+ describe '#validate' do
108
+ before(:all) do
109
+ @bayes = Idhja22::Bayes.train(@ds)
110
+ end
111
+
112
+ it 'should return the average probability that the tree gets the validation examples correct' do
113
+ vds = Idhja22::Dataset.new([], ['0', '1','2','3','4'],'C')
114
+ vds << Idhja22::Dataset::Example.new(['a','a','a','a','a','Y'],['0', '1','2','3','4'],'C')
115
+ vds << Idhja22::Dataset::Example.new(['a','a','a','a','a','N'],['0', '1','2','3','4'],'C')
116
+ @bayes.validate(vds).should == 0.5
117
+ end
118
+ end
5
119
  end
@@ -0,0 +1,9 @@
1
+ 1,2,C
2
+ a,a,Y
3
+ a,a,N
4
+ a,b,Y
5
+ a,b,N
6
+ b,a,Y
7
+ b,a,N
8
+ b,b,Y
9
+ b,b,N
File without changes
data/spec/dataset_spec.rb CHANGED
@@ -10,7 +10,7 @@ describe Idhja22::Dataset do
10
10
 
11
11
  describe 'from_csv' do
12
12
  before(:all) do
13
- @ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'spec_data.csv'))
13
+ @ds = Idhja22::Dataset.from_csv(File.join(data_dir,'spec_data.csv'))
14
14
  end
15
15
 
16
16
  it 'should extract labels' do
@@ -50,7 +50,7 @@ describe Idhja22::Dataset do
50
50
 
51
51
  context 'ready made' do
52
52
  before(:all) do
53
- @ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'))
53
+ @ds = Idhja22::Dataset.from_csv(File.join(data_dir,'large_spec_data.csv'))
54
54
  end
55
55
 
56
56
  describe '#partition' do
@@ -125,6 +125,53 @@ describe Idhja22::Dataset do
125
125
  vs.size.should == 3
126
126
  end
127
127
  end
128
+
129
+ describe '#partition_by_category' do
130
+ it 'should divide the data set into a set of all Ys and a set of all Ns' do
131
+ sets = @ds.partition_by_category
132
+ sets.length.should == 2
133
+ sets['Y'].data.collect(&:category).uniq.should == ['Y']
134
+ sets['N'].data.collect(&:category).uniq.should == ['N']
135
+ end
136
+ end
137
+
138
+ describe '#<<' do
139
+ it 'should all datum to list of data' do
140
+ added_datum = Idhja22::Dataset::Example.new(['a','b','c','d','e', 'Y'],['0','1','2','3','4'],'C')
141
+ expect { @ds << added_datum}.to change(@ds, :size)
142
+ @ds.data.last.should == added_datum
143
+ end
144
+
145
+ context 'mismatched category label' do
146
+ it 'should throw an error' do
147
+ added_datum = Idhja22::Dataset::Example.new(['a','b','c','d','e', 'Y'],['0','1','2','3','4'],'D')
148
+ expect { @ds << added_datum}.to raise_error(Idhja22::Dataset::Datum::UnknownCategoryLabel)
149
+ end
150
+ end
151
+
152
+ context 'mismatching attributes' do
153
+ context 'extra attribute' do
154
+ it 'should throw an error' do
155
+ added_datum = Idhja22::Dataset::Example.new(['a','b','c','d','e', 'f', 'Y'],['0','1','2','3','4', '5'],'C')
156
+ expect { @ds << added_datum}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeLabel)
157
+ end
158
+ end
159
+
160
+ context 'missing attribute' do
161
+ it 'should throw an error' do
162
+ added_datum = Idhja22::Dataset::Example.new(['a','b','c','d', 'Y'],['0','1','2','3'],'C')
163
+ expect { @ds << added_datum}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeLabel)
164
+ end
165
+ end
166
+
167
+ context 'different attribute' do
168
+ it 'should throw an error' do
169
+ added_datum = Idhja22::Dataset::Example.new(['a','b','c','d', 'e', 'Y'],['0','1','2','3','9'],'C')
170
+ expect { @ds << added_datum}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeLabel)
171
+ end
172
+ end
173
+ end
174
+ end
128
175
  end
129
176
  end
130
177
  end
@@ -0,0 +1,205 @@
1
+ require 'spec_helper'
2
+
3
+ describe Idhja22::DecisionNode do
4
+ before(:all) do
5
+ @ds = Idhja22::Dataset.from_csv(File.join(data_dir,'large_spec_data.csv'))
6
+ @simple_decision_node = Idhja22::DecisionNode.new('3')
7
+
8
+ l1 = Idhja22::LeafNode.new(0.75, 'C')
9
+ l2 = Idhja22::LeafNode.new(0.0, 'C')
10
+
11
+ @simple_decision_node.add_branch('a', l1)
12
+ @simple_decision_node.add_branch('b', l2)
13
+ end
14
+
15
+ describe('#get_rules') do
16
+ it 'should return a list of rules' do
17
+ @simple_decision_node.get_rules.should == ["3 == a and then chance of C = 0.75", "3 == b and then chance of C = 0.0"]
18
+ end
19
+ end
20
+
21
+ describe '#leaves' do
22
+ it 'should return a list of terminating values' do
23
+ @simple_decision_node.leaves.should == [Idhja22::LeafNode.new(0.75, 'C'), Idhja22::LeafNode.new(0.0, 'C')]
24
+ end
25
+
26
+ context 'a branch without a terminating leaf node' do
27
+ it 'should throw an error' do
28
+ decision_node = Idhja22::DecisionNode.new('a')
29
+ decision_node.add_branch('1', Idhja22::LeafNode.new(0.75, 'C'))
30
+ decision_node.add_branch('2', Idhja22::DecisionNode.new('b'))
31
+
32
+ expect { decision_node.leaves }.to raise_error(Idhja22::IncompleteTree)
33
+ end
34
+ end
35
+ end
36
+
37
+ describe(' == ') do
38
+ it 'should return false with different decision attributes' do
39
+ dn = Idhja22::DecisionNode.new('2')
40
+ diff_dn = Idhja22::DecisionNode.new('3')
41
+ dn.should_not == diff_dn
42
+ end
43
+
44
+ it 'should return false with different branches' do
45
+ dn1 = Idhja22::DecisionNode.new('2')
46
+ diff_dn = Idhja22::DecisionNode.new('2')
47
+
48
+ leaf = Idhja22::LeafNode.new(0.75, 'C')
49
+ dn1.add_branch('value', leaf)
50
+
51
+ dn1.should_not == diff_dn
52
+ end
53
+
54
+ it 'should return true if decision node and branches match' do
55
+ dn1 = Idhja22::DecisionNode.new('2')
56
+ dn2 = Idhja22::DecisionNode.new('2')
57
+
58
+ leaf = Idhja22::LeafNode.new(0.75, 'C')
59
+ dn1.add_branch('value', leaf)
60
+ dn2.add_branch('value', leaf)
61
+
62
+ dn1.should == dn2
63
+ end
64
+ end
65
+
66
+ describe 'category_label' do
67
+ it 'should return the category_label from the leaves' do
68
+ @simple_decision_node.category_label.should == 'C'
69
+ end
70
+
71
+ context 'incomplete node' do
72
+ it 'should throw an error' do
73
+ dn = Idhja22::DecisionNode.new('a')
74
+ expect { dn.category_label }.to raise_error(Idhja22::IncompleteTree)
75
+ end
76
+ end
77
+
78
+ end
79
+
80
+ describe 'evaluate' do
81
+ it 'should follow node to probability' do
82
+ query = Idhja22::Dataset::Datum.new(['a', 'a'], ['3', '4'], 'C')
83
+ @simple_decision_node.evaluate(query).should == 0.75
84
+
85
+ query = Idhja22::Dataset::Datum.new(['b', 'a'], ['3', '4'], 'C')
86
+ @simple_decision_node.evaluate(query).should == 0.0
87
+ end
88
+
89
+ context 'mismatching attribute label' do
90
+ it 'should raise an error' do
91
+ query = Idhja22::Dataset::Datum.new(['b', 'a'], ['1', '2'], 'C')
92
+ expect {@simple_decision_node.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeLabel)
93
+ end
94
+ end
95
+
96
+ context 'unknown attribute value' do
97
+ it 'should raise an error' do
98
+ query = Idhja22::Dataset::Datum.new(['c', 'a'], ['3', '4'], 'C')
99
+ expect {@simple_decision_node.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeValue)
100
+ end
101
+ end
102
+ end
103
+
104
+ describe('.build') do
105
+ it 'should build a decision node based on the provided data' do
106
+ node = Idhja22::DecisionNode.build(@ds, @ds.attribute_labels, 0)
107
+ node.decision_attribute.should == "2"
108
+ node.branches.keys.should == ['a','b']
109
+ end
110
+
111
+ it 'should cleanup matching tails' do
112
+ ds = Idhja22::Dataset.from_csv(File.join(data_dir,'evenly_split.csv'))
113
+ node = Idhja22::DecisionNode.build(ds, ds.attribute_labels, 0)
114
+ node.get_rules.should == ['1 == a and then chance of C = 0.5', '1 == b and then chance of C = 0.5']
115
+ end
116
+ end
117
+
118
+ describe '#add_branch' do
119
+ it 'should add a branch for the given attribute value' do
120
+ node = Idhja22::DecisionNode.new 'attribute_name'
121
+ branch_node = Idhja22::DecisionNode.new 'other_name'
122
+ node.add_branch('value', branch_node)
123
+ node.branches.keys.should == ['value']
124
+ node.branches['value'].should == branch_node
125
+ end
126
+ end
127
+
128
+ describe '#cleanup_children!' do
129
+ context 'with matching output at level below' do
130
+ before(:all) do
131
+ @dn = Idhja22::DecisionNode.new('a')
132
+ @dn_below = Idhja22::DecisionNode.new('b')
133
+ @dn_below.add_branch('1', Idhja22::LeafNode.new(0.505, 'Category'))
134
+ @dn_below.add_branch('2', Idhja22::LeafNode.new(0.50, 'Category'))
135
+ @dn.add_branch('1', @dn_below)
136
+ end
137
+ it 'should merge any subnodes with same output into a single leafnode' do
138
+ @dn.cleanup_children!
139
+ @dn.branches['1'].should == Idhja22::LeafNode.new(0.505, 'Category')
140
+ end
141
+ end
142
+
143
+ context 'with matching output at two levels below' do
144
+ before(:all) do
145
+ @dn = Idhja22::DecisionNode.new('a')
146
+ @dn_1_below = Idhja22::DecisionNode.new('b')
147
+ @dn.add_branch('1', @dn_1_below)
148
+
149
+ @dn_2_below = Idhja22::DecisionNode.new('c')
150
+ @dn_1_below.add_branch('1', @dn_2_below)
151
+
152
+ @dn_2_below.add_branch('1', Idhja22::LeafNode.new(0.50, 'Category'))
153
+ @dn_2_below.add_branch('2', Idhja22::LeafNode.new(0.50, 'Category'))
154
+ end
155
+
156
+ it 'should merge nodes recusively' do
157
+ @dn.cleanup_children!
158
+ @dn.branches['1'].should == Idhja22::LeafNode.new(0.50, 'Category')
159
+ end
160
+ end
161
+
162
+ context 'with diverging branches that match internally' do
163
+ before(:all) do
164
+ @dn = Idhja22::DecisionNode.new('a')
165
+
166
+ dn_1_below = Idhja22::DecisionNode.new('b')
167
+ @dn.add_branch('1', dn_1_below)
168
+
169
+ dn_2_below = Idhja22::DecisionNode.new('c')
170
+ dn_1_below.add_branch('1', dn_2_below)
171
+
172
+ dn_2_below.add_branch('1', Idhja22::LeafNode.new(0.50, 'Category'))
173
+ dn_2_below.add_branch('2', Idhja22::LeafNode.new(0.50, 'Category'))
174
+
175
+ dn_2_below = Idhja22::DecisionNode.new('d')
176
+ dn_1_below.add_branch('2', dn_2_below)
177
+
178
+ dn_2_below.add_branch('1', Idhja22::LeafNode.new(0.70, 'Category'))
179
+ dn_2_below.add_branch('2', Idhja22::LeafNode.new(0.70, 'Category'))
180
+ end
181
+
182
+ it 'should merge nodes recusively' do
183
+ @dn.cleanup_children!
184
+ @dn.branches['1'].branches['1'].should == Idhja22::LeafNode.new(0.50, 'Category')
185
+ @dn.branches['1'].branches['2'].should == Idhja22::LeafNode.new(0.70, 'Category')
186
+ end
187
+ end
188
+
189
+ context 'without matching output' do
190
+ before(:all) do
191
+ @dn = Idhja22::DecisionNode.new('a')
192
+ @dn_below = Idhja22::DecisionNode.new('b')
193
+ @dn_below.add_branch('1', Idhja22::LeafNode.new(0.2, 'Category'))
194
+ @dn_below.add_branch('2', Idhja22::LeafNode.new(0.70, 'Category'))
195
+ @dn.add_branch('1', @dn_below)
196
+ end
197
+
198
+ it 'should do nothing' do
199
+ saved_rules = @dn.get_rules
200
+ @dn.cleanup_children!
201
+ @dn.get_rules.should == saved_rules
202
+ end
203
+ end
204
+ end
205
+ end
@@ -0,0 +1,53 @@
1
+ require 'spec_helper'
2
+
3
+ describe Idhja22::LeafNode do
4
+ describe('.new') do
5
+ it 'should store probability and category label' do
6
+ l = Idhja22::LeafNode.new(0.75, 'label')
7
+ l.probability.should == 0.75
8
+ l.category_label.should == 'label'
9
+ end
10
+ end
11
+
12
+ describe('#get_rules') do
13
+ it 'should return the probability' do
14
+ l = Idhja22::LeafNode.new(0.75, 'pudding')
15
+ l.get_rules.should == ['then chance of pudding = 0.75']
16
+ end
17
+ end
18
+
19
+ describe(' == ') do
20
+ let(:l1) { Idhja22::LeafNode.new(0.75, 'pudding') }
21
+ let(:l2) { Idhja22::LeafNode.new(0.75, 'pudding') }
22
+ let(:diff_l1) { Idhja22::LeafNode.new(0.7, 'pudding') }
23
+ let(:diff_l2) { Idhja22::LeafNode.new(0.75, 'starter') }
24
+ it 'should compare attributes' do
25
+ l1.should == l2
26
+ l1.should_not == diff_l1
27
+ l1.should_not == diff_l2
28
+ end
29
+ end
30
+
31
+ describe 'evaluate' do
32
+ let(:leaf) { Idhja22::LeafNode.new(0.6, 'pudding') }
33
+
34
+ it 'should return probability' do
35
+ query = Idhja22::Dataset::Datum.new(['high', 'gusty'], ['temperature', 'windy'], 'pudding')
36
+ leaf.evaluate(query).should == 0.6
37
+ end
38
+
39
+ context 'mismatching category labels' do
40
+ it 'should raise error' do
41
+ query = Idhja22::Dataset::Datum.new(['high', 'gusty'], ['temperature', 'windy'], 'tennis')
42
+ expect {leaf.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownCategoryLabel)
43
+ end
44
+ end
45
+ end
46
+
47
+ describe '#leaves' do
48
+ it 'should return self' do
49
+ leaf = Idhja22::LeafNode.new(0.6, 'pudding')
50
+ leaf.leaves.should == [leaf]
51
+ end
52
+ end
53
+ end
data/spec/spec_helper.rb CHANGED
@@ -15,6 +15,10 @@ Configuration.for('spec', Idhja22.config) {
15
15
 
16
16
  Idhja22.configure('spec')
17
17
 
18
+ def data_dir
19
+ File.dirname(__FILE__) + '/data/'
20
+ end
21
+
18
22
  RSpec.configure do |config|
19
23
 
20
24
  end
data/spec/tree_spec.rb CHANGED
@@ -2,7 +2,7 @@ require 'spec_helper'
2
2
 
3
3
  describe Idhja22::Tree do
4
4
  before(:all) do
5
- @ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'))
5
+ @ds = Idhja22::Dataset.from_csv(File.join(data_dir,'large_spec_data.csv'))
6
6
  end
7
7
 
8
8
 
@@ -29,7 +29,7 @@ describe Idhja22::Tree do
29
29
  it 'should compare root nodes' do
30
30
  tree1 = Idhja22::Tree.train(@ds)
31
31
  tree2 = Idhja22::Tree.train(@ds)
32
- diff_ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'another_large_spec_data.csv'))
32
+ diff_ds = Idhja22::Dataset.from_csv(File.join(data_dir,'another_large_spec_data.csv'))
33
33
  diff_tree = Idhja22::Tree.train(diff_ds)
34
34
  tree1.should == tree2
35
35
  tree1.should_not == diff_tree
@@ -39,7 +39,7 @@ describe Idhja22::Tree do
39
39
  describe('.train_from_csv') do
40
40
  it 'should make the same tree as the one from the dataset' do
41
41
  tree = Idhja22::Tree.train(@ds)
42
- csv_tree = Idhja22::Tree.train_from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'))
42
+ csv_tree = Idhja22::Tree.train_from_csv(File.join(data_dir,'large_spec_data.csv'))
43
43
  tree.should == csv_tree
44
44
  end
45
45
  end
@@ -85,7 +85,7 @@ describe Idhja22::Tree do
85
85
 
86
86
  describe('.train_and_validate_from_csv') do
87
87
  it 'should make the same tree as the one from the dataset' do
88
- csv_tree, validation_value = Idhja22::Tree.train_and_validate_from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'), 0.75)
88
+ csv_tree, validation_value = Idhja22::Tree.train_and_validate_from_csv(File.join(data_dir,'large_spec_data.csv'), :"training-proportion" => 0.75)
89
89
  csv_tree.is_a?(Idhja22::Tree).should be_true
90
90
  (0..1).include?(validation_value).should be_true
91
91
  end
data/spec/version_spec.rb CHANGED
@@ -3,7 +3,7 @@ require 'spec_helper'
3
3
  describe Idhja22 do
4
4
  describe 'VERSION' do
5
5
  it 'should be current version' do
6
- Idhja22::VERSION.should == '0.14.4'
6
+ Idhja22::VERSION.should == '1.1.0'
7
7
  end
8
8
  end
9
9
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: idhja22
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.14.4
4
+ version: 1.1.0
5
5
  prerelease:
6
6
  platform: ruby
7
7
  authors:
@@ -9,7 +9,7 @@ authors:
9
9
  autorequire:
10
10
  bindir: bin
11
11
  cert_chain: []
12
- date: 2012-12-18 00:00:00.000000000 Z
12
+ date: 2012-12-20 00:00:00.000000000 Z
13
13
  dependencies:
14
14
  - !ruby/object:Gem::Dependency
15
15
  name: rspec
@@ -123,7 +123,7 @@ dependencies:
123
123
  - - ! '>='
124
124
  - !ruby/object:Gem::Version
125
125
  version: '0'
126
- description: Decision Trees
126
+ description: Classifiers
127
127
  email:
128
128
  executables:
129
129
  - idhja22
@@ -141,6 +141,7 @@ files:
141
141
  - idhja22.gemspec
142
142
  - lib/idhja22.rb
143
143
  - lib/idhja22/bayes.rb
144
+ - lib/idhja22/binary_classifier.rb
144
145
  - lib/idhja22/config/default.rb
145
146
  - lib/idhja22/dataset.rb
146
147
  - lib/idhja22/dataset/datum.rb
@@ -149,13 +150,15 @@ files:
149
150
  - lib/idhja22/tree.rb
150
151
  - lib/idhja22/tree/node.rb
151
152
  - lib/idhja22/version.rb
152
- - spec/another_large_spec_data.csv
153
153
  - spec/bayes_spec.rb
154
+ - spec/data/another_large_spec_data.csv
155
+ - spec/data/evenly_split.csv
156
+ - spec/data/large_spec_data.csv
157
+ - spec/data/spec_data.csv
154
158
  - spec/dataset/example_spec.rb
155
159
  - spec/dataset_spec.rb
156
- - spec/large_spec_data.csv
157
- - spec/node_spec.rb
158
- - spec/spec_data.csv
160
+ - spec/node/decision_node_spec.rb
161
+ - spec/node/leaf_node_spec.rb
159
162
  - spec/spec_helper.rb
160
163
  - spec/tree_spec.rb
161
164
  - spec/version_spec.rb
@@ -173,7 +176,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
173
176
  version: '0'
174
177
  segments:
175
178
  - 0
176
- hash: -803768552583374641
179
+ hash: 3479458333568153307
177
180
  required_rubygems_version: !ruby/object:Gem::Requirement
178
181
  none: false
179
182
  requirements:
@@ -182,21 +185,23 @@ required_rubygems_version: !ruby/object:Gem::Requirement
182
185
  version: '0'
183
186
  segments:
184
187
  - 0
185
- hash: -803768552583374641
188
+ hash: 3479458333568153307
186
189
  requirements: []
187
190
  rubyforge_project:
188
191
  rubygems_version: 1.8.24
189
192
  signing_key:
190
193
  specification_version: 3
191
- summary: A gem for creating decision trees
194
+ summary: A gem for creating classifiers (decision trees and naive Bayes so far)
192
195
  test_files:
193
- - spec/another_large_spec_data.csv
194
196
  - spec/bayes_spec.rb
197
+ - spec/data/another_large_spec_data.csv
198
+ - spec/data/evenly_split.csv
199
+ - spec/data/large_spec_data.csv
200
+ - spec/data/spec_data.csv
195
201
  - spec/dataset/example_spec.rb
196
202
  - spec/dataset_spec.rb
197
- - spec/large_spec_data.csv
198
- - spec/node_spec.rb
199
- - spec/spec_data.csv
203
+ - spec/node/decision_node_spec.rb
204
+ - spec/node/leaf_node_spec.rb
200
205
  - spec/spec_helper.rb
201
206
  - spec/tree_spec.rb
202
207
  - spec/version_spec.rb
data/spec/node_spec.rb DELETED
@@ -1,97 +0,0 @@
1
- require 'spec_helper'
2
-
3
- describe Idhja22::LeafNode do
4
- describe('.new') do
5
- it 'should store probability and category label' do
6
- l = Idhja22::LeafNode.new(0.75, 'label')
7
- l.probability.should == 0.75
8
- l.category_label.should == 'label'
9
- end
10
- end
11
-
12
- describe('#get_rules') do
13
- it 'should return the probability' do
14
- l = Idhja22::LeafNode.new(0.75, 'pudding')
15
- l.get_rules.should == ['then chance of pudding = 0.75']
16
- end
17
- end
18
-
19
- describe(' == ') do
20
- let(:l1) { Idhja22::LeafNode.new(0.75, 'pudding') }
21
- let(:l2) { Idhja22::LeafNode.new(0.75, 'pudding') }
22
- let(:diff_l1) { Idhja22::LeafNode.new(0.7, 'pudding') }
23
- let(:diff_l2) { Idhja22::LeafNode.new(0.75, 'starter') }
24
- it 'should compare attributes' do
25
- l1.should == l2
26
- l1.should_not == diff_l1
27
- l1.should_not == diff_l2
28
- end
29
- end
30
-
31
- describe 'evaluate' do
32
- let(:leaf) { Idhja22::LeafNode.new(0.6, 'pudding') }
33
-
34
- it 'should return probability' do
35
- query = Idhja22::Dataset::Datum.new(['high', 'gusty'], ['temperature', 'windy'], 'pudding')
36
- leaf.evaluate(query).should == 0.6
37
- end
38
-
39
- context 'mismatching category labels' do
40
- it 'should raise error' do
41
- query = Idhja22::Dataset::Datum.new(['high', 'gusty'], ['temperature', 'windy'], 'tennis')
42
- expect {leaf.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownCategoryLabel)
43
- end
44
- end
45
- end
46
- end
47
-
48
- describe Idhja22::DecisionNode do
49
- before(:all) do
50
- @ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'))
51
- end
52
-
53
- describe('#get_rules') do
54
- it 'should return a list of rules' do
55
- l = Idhja22::DecisionNode.new(@ds.partition('2'), '3', [], 0, 0.75)
56
- l.get_rules.should == ["3 == a and then chance of C = 0.75", "3 == b and then chance of C = 0.0"]
57
- end
58
- end
59
-
60
- describe(' == ') do
61
- let(:dn1) { Idhja22::DecisionNode.new(@ds.partition('2'), '2', [], 0, 0.75) }
62
- let(:dn2) { Idhja22::DecisionNode.new(@ds.partition('2'), '2', [], 0, 0.75) }
63
- let(:diff_dn1) { Idhja22::DecisionNode.new(@ds.partition('0'), '2', [], 0, 0.75) }
64
- let(:diff_dn2) { Idhja22::DecisionNode.new(@ds.partition('3'), '3', [], 0, 0.75) }
65
-
66
- it 'should compare ' do
67
- dn1.should == dn2
68
- dn1.should_not == diff_dn1
69
- dn1.should_not == diff_dn2
70
- end
71
- end
72
-
73
- describe 'evaluate' do
74
- let(:dn) { Idhja22::DecisionNode.new(@ds.partition('2'), '3', [], 0, 0.75) }
75
- it 'should follow node to probability' do
76
- query = Idhja22::Dataset::Datum.new(['a', 'a'], ['3', '4'], 'C')
77
- dn.evaluate(query).should == 0.75
78
-
79
- query = Idhja22::Dataset::Datum.new(['b', 'a'], ['3', '4'], 'C')
80
- dn.evaluate(query).should == 0.0
81
- end
82
-
83
- context 'mismatching attribute label' do
84
- it 'should raise an error' do
85
- query = Idhja22::Dataset::Datum.new(['b', 'a'], ['1', '2'], 'C')
86
- expect {dn.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeLabel)
87
- end
88
- end
89
-
90
- context 'unknown attribute value' do
91
- it 'should raise an error' do
92
- query = Idhja22::Dataset::Datum.new(['c', 'a'], ['3', '4'], 'C')
93
- expect {dn.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeValue)
94
- end
95
- end
96
- end
97
- end