idhja22 0.14.4 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/.gitignore CHANGED
@@ -27,4 +27,4 @@ Gemfile.lock
27
27
  .DS_Store
28
28
 
29
29
  # data directory for storing csvs to run the program against
30
- data
30
+ /data
data/bin/idhja22 CHANGED
@@ -4,13 +4,22 @@ require 'thor'
4
4
  require 'idhja22'
5
5
 
6
6
  class TrainAndValidate < Thor
7
- desc "train_and_validate FILE", "train a tree for the given file and validate is against a validation set"
7
+ desc "train_and_validate_tree FILE", "train a tree for the given file and validate is against a validation set"
8
+ method_option :attributes, :type => :array
8
9
  method_option :"training-proportion", :type => :numeric, :default => 1.0, :aliases => 't'
9
- def train_and_validate(filename)
10
- t, v = Idhja22::Tree.train_and_validate_from_csv(filename, options[:"training-proportion"])
10
+ def train_and_validate_tree(filename)
11
+ t, v = Idhja22::Tree.train_and_validate_from_csv(filename, options)
11
12
  puts t.get_rules
12
13
  puts "Against validation set probability of successful classifiction: #{v}" if options[:"training-proportion"] < 1.0
13
14
  end
15
+
16
+ desc "train_and_validate_bayes FILE", "train a naive Bayesian classifier for the given file and validate is against a validation set"
17
+ method_option :attributes, :type => :array
18
+ method_option :"training-proportion", :type => :numeric, :default => 1.0, :aliases => 't'
19
+ def train_and_validate_bayes(filename)
20
+ t, v = Idhja22::Bayes.train_and_validate_from_csv(filename, options)
21
+ puts "Against validation set probability of successful classifiction: #{v}" if options[:"training-proportion"] < 1.0
22
+ end
14
23
  end
15
24
 
16
25
  TrainAndValidate.start
data/idhja22.gemspec CHANGED
@@ -7,8 +7,8 @@ Gem::Specification.new do |gem|
7
7
  gem.name = "idhja22"
8
8
  gem.version = Idhja22::VERSION
9
9
  gem.authors = ["Henry Addison"]
10
- gem.description = %q{Decision Trees}
11
- gem.summary = %q{A gem for creating decision trees}
10
+ gem.description = %q{Classifiers}
11
+ gem.summary = %q{A gem for creating classifiers (decision trees and naive Bayes so far)}
12
12
  gem.homepage = "https://github.com/henryaddison/idhja22"
13
13
 
14
14
  gem.files = `git ls-files`.split($/)
data/lib/idhja22/bayes.rb CHANGED
@@ -1,5 +1,53 @@
1
1
  module Idhja22
2
- class Bayes
3
-
2
+ class Bayes < BinaryClassifier
3
+ attr_accessor :conditional_probabilities, :prior_probabilities
4
+ class << self
5
+ def calculate_conditional_probabilities dataset, attribute_labels_to_use
6
+ conditional_probabilities = {}
7
+ attribute_labels_to_use.each do |attr_label|
8
+ conditional_probabilities[attr_label] = {}
9
+ dataset.partition_by_category.each do |cat, uniform_category_ds|
10
+ conditional_probabilities[attr_label][cat] = Hash.new(0)
11
+ partitioned_data = uniform_category_ds.partition(attr_label)
12
+ partitioned_data.each do |attr_value, uniform_value_ds|
13
+ conditional_probabilities[attr_label][cat][attr_value] = uniform_value_ds.size.to_f/uniform_category_ds.size.to_f
14
+ end
15
+ end
16
+ end
17
+
18
+ return conditional_probabilities
19
+ end
20
+
21
+ def calculate_priors dataset
22
+ output = Hash.new(0)
23
+ dataset.category_counts.each do |cat, count|
24
+ output[cat] = count.to_f/dataset.size.to_f
25
+ end
26
+ return output
27
+ end
28
+ end
29
+
30
+ def evaluate(query)
31
+ nb_values = {}
32
+ total_values = 0
33
+
34
+ prior_probabilities.each do |cat, prior_prob|
35
+ nb_value = prior_prob
36
+ conditional_probabilities.each do |attr_label, probs|
37
+ raise Idhja22::Dataset::Datum::UnknownAttributeValue, "Not seen value #{query[attr_label]} for attribute #{attr_label} in training." unless probs[cat].has_key? query[attr_label]
38
+ nb_value *= probs[cat][query[attr_label]]
39
+ end
40
+ total_values += nb_value
41
+ nb_values[cat] = nb_value
42
+ end
43
+
44
+ return nb_values['Y']/total_values
45
+ end
46
+
47
+ def train(dataset, attributes_to_use)
48
+ self.conditional_probabilities = self.class.calculate_conditional_probabilities(dataset, attributes_to_use)
49
+ self.prior_probabilities = self.class.calculate_priors(dataset)
50
+ return self
51
+ end
4
52
  end
5
53
  end
@@ -0,0 +1,54 @@
1
+ module Idhja22
2
+ class BinaryClassifier
3
+
4
+ class << self
5
+ # Trains a classifier using the provided Dataset.
6
+ def train(dataset, opts = {})
7
+ attributes_to_use = (opts[:attributes] || dataset.attribute_labels)
8
+ classifier = new
9
+ classifier.train(dataset, attributes_to_use)
10
+ return classifier
11
+ end
12
+
13
+ # Takes a dataset and splits it randomly into training and validation data.
14
+ # Uses the training data to train a classifier whose perfomance then measured using the validation data.
15
+ # @param [Float] Proportion of dataset to use for training. The rest will be used to validate the resulting classifier.
16
+ def train_and_validate(dataset, opts = {})
17
+ opts[:"training-proportion"] ||= 0.5
18
+ training_set, validation_set = dataset.split(opts[:"training-proportion"])
19
+ tree = self.train(training_set, opts)
20
+ validation_value = tree.validate(validation_set)
21
+ return tree, validation_value
22
+ end
23
+
24
+ # see #train
25
+ # @note Takes a CSV filename rather than a Dataset
26
+ def train_from_csv(filename, opts={})
27
+ ds = Dataset.from_csv(filename)
28
+ train(ds, opts)
29
+ end
30
+
31
+ # see #train_and_validate
32
+ # @note Takes a CSV filename rather than a Dataset
33
+ def train_and_validate_from_csv(filename, opts={})
34
+ ds = Dataset.from_csv(filename)
35
+ train_and_validate(ds, opts)
36
+ end
37
+ end
38
+
39
+ def validate(ds)
40
+ output = 0
41
+ ds.data.each do |validation_point|
42
+ begin
43
+ prob = evaluate(validation_point)
44
+ output += (validation_point.category == 'Y' ? prob : 1.0 - prob)
45
+ rescue Idhja22::Dataset::Datum::UnknownAttributeValue
46
+ # if don't recognised the attribute value in the example, then assume the worst:
47
+ # will never classify this point correctly
48
+ # equivalent to output += 0 but no point running this
49
+ end
50
+ end
51
+ return output.to_f/ds.size.to_f
52
+ end
53
+ end
54
+ end
@@ -2,4 +2,5 @@ Configuration.for('default') {
2
2
  default_probability 0.5
3
3
  termination_probability 0.95
4
4
  min_dataset_size 20
5
+ probability_delta 0.01
5
6
  }
@@ -1,4 +1,5 @@
1
1
  module Idhja22
2
+ class IncompleteTree < StandardError; end
2
3
  class Dataset
3
4
  class BadData < ArgumentError; end
4
5
  class InsufficientData < BadData; end
@@ -17,13 +17,13 @@ module Idhja22
17
17
  category_label = labels.pop
18
18
  attribute_labels = labels
19
19
 
20
- data = []
20
+ set = new([], attribute_labels, category_label)
21
21
  csv.each do |row|
22
22
  training_example = Example.new(row, attribute_labels, category_label)
23
- data << training_example
23
+ set << training_example
24
24
  end
25
25
 
26
- new(data, attribute_labels, category_label)
26
+ return set
27
27
  end
28
28
  end
29
29
 
@@ -36,8 +36,9 @@ module Idhja22
36
36
 
37
37
  def category_counts
38
38
  counts = Hash.new(0)
39
- data.each do |d|
40
- counts[d.category]+=1
39
+ split_data = partition_by_category
40
+ split_data.each do |cat, d|
41
+ counts[cat] = d.size
41
42
  end
42
43
  return counts
43
44
  end
@@ -66,5 +67,21 @@ module Idhja22
66
67
 
67
68
  return training_set, validation_set
68
69
  end
70
+
71
+ def <<(example)
72
+ raise Idhja22::Dataset::Datum::UnknownCategoryLabel unless example.category_label == self.category_label
73
+ raise Idhja22::Dataset::Datum::UnknownAttributeLabel unless example.attribute_labels == self.attribute_labels
74
+ self.data << example
75
+ end
76
+
77
+ def partition_by_category
78
+ output = Hash.new do |hash, key|
79
+ hash[key] = self.class.new([], attribute_labels, category_label)
80
+ end
81
+ self.data.each do |d|
82
+ output[d.category] << d
83
+ end
84
+ return output
85
+ end
69
86
  end
70
87
  end
@@ -20,9 +20,7 @@ module Idhja22
20
20
  return Idhja22::LeafNode.new(dataset.probability, dataset.category_label)
21
21
  end
22
22
 
23
- data_split, best_attribute = best_attribute(dataset, attributes_available)
24
-
25
- node = Idhja22::DecisionNode.new(data_split, best_attribute, attributes_available-[best_attribute], depth, dataset.probability)
23
+ node = DecisionNode.build(dataset, attributes_available, depth)
26
24
 
27
25
  return node
28
26
  end
@@ -59,21 +57,34 @@ module Idhja22
59
57
 
60
58
  class DecisionNode < Node
61
59
  attr_reader :branches, :decision_attribute
62
- def initialize(data_split, decision_attribute, attributes_available, depth, parent_probability)
63
- @decision_attribute = decision_attribute
64
- @branches = {}
65
- data_split.each do |value, dataset|
66
- node = Node.build_node(dataset, attributes_available, depth+1, parent_probability)
67
- if(node.is_a?(DecisionNode) && node.branches.values.all? { |n| n.is_a?(LeafNode) })
68
- probs = node.branches.values.collect(&:probability)
69
- if(probs.max - probs.min < 0.01)
70
- node = LeafNode.new(probs.max, dataset.category_label)
71
- end
60
+
61
+ class << self
62
+ def build(dataset, attributes_available, depth)
63
+ data_split, best_attribute = best_attribute(dataset, attributes_available)
64
+
65
+ output_node = new(best_attribute)
66
+
67
+ data_split.each do |value, dataset|
68
+ node = Node.build_node(dataset, attributes_available-[best_attribute], depth+1, dataset.probability)
69
+
70
+ output_node.add_branch(value, node) if node && !(node.is_a?(DecisionNode) && node.branches.empty?)
72
71
  end
73
- @branches[value] = node if node && !(node.is_a?(DecisionNode) && node.branches.empty?)
72
+
73
+ output_node.cleanup_children!
74
+
75
+ return output_node
74
76
  end
75
77
  end
76
78
 
79
+ def initialize(decision_attribute)
80
+ @decision_attribute = decision_attribute
81
+ @branches = {}
82
+ end
83
+
84
+ def add_branch(attr_value, node)
85
+ @branches[attr_value] = node
86
+ end
87
+
77
88
  def get_rules
78
89
  rules = []
79
90
  branches.each do |v,n|
@@ -104,6 +115,29 @@ module Idhja22
104
115
  raise Idhja22::Dataset::Datum::UnknownAttributeValue, "when looking at attribute labelled #{self.decision_attribute} could not find branch for value #{queried_value}" if branch.nil?
105
116
  branch.evaluate(query)
106
117
  end
118
+
119
+ def cleanup_children!
120
+ branches.each do |attr, child_node|
121
+ child_node.cleanup_children!
122
+ leaves = child_node.leaves
123
+ probs = leaves.collect(&:probability)
124
+ if(probs.max - probs.min < Idhja22.config.probability_delta)
125
+ new_node = LeafNode.new(probs.max, category_label)
126
+ add_branch(attr, new_node)
127
+ end
128
+ end
129
+ end
130
+
131
+ def leaves
132
+ raise Idhja22::IncompleteTree, "decision node with no branches" if branches.empty?
133
+ branches.values.flat_map do |child_node|
134
+ child_node.leaves
135
+ end
136
+ end
137
+
138
+ def category_label
139
+ leaves.first.category_label
140
+ end
107
141
  end
108
142
 
109
143
  class LeafNode < Node
@@ -125,5 +159,14 @@ module Idhja22
125
159
  raise Idhja22::Dataset::Datum::UnknownCategoryLabel, "expected category label for query is #{query.category_label} but node is using #{self.category_label}" unless query.category_label == self.category_label
126
160
  return probability
127
161
  end
162
+
163
+ def leaves
164
+ return [self]
165
+ end
166
+
167
+ # no-op method - a leaf node has no children by definition
168
+ def cleanup_children!
169
+
170
+ end
128
171
  end
129
172
  end
data/lib/idhja22/tree.rb CHANGED
@@ -2,42 +2,15 @@ require "idhja22/tree/node"
2
2
 
3
3
  module Idhja22
4
4
  # The main entry class for a training, viewing and evaluating a decision tree.
5
- class Tree
5
+ class Tree < BinaryClassifier
6
6
  attr_accessor :root
7
7
  class << self
8
- # Trains a Tree using the provided Dataset.
9
- def train(dataset)
10
- new(dataset, dataset.attribute_labels)
11
- end
12
-
13
- # Takes a dataset and splits it randomly into training and validation data.
14
- # Uses the training data to train a tree whose perfomance then measured using the validation data.
15
- # @param [Float] Proportion of dataset to use for training. The rest will be used to validate the resulting tree.
16
- def train_and_validate(dataset, training_proportion=0.5)
17
- training_set, validation_set = dataset.split(training_proportion)
18
- tree = self.train(training_set)
19
- validation_value = tree.validate(validation_set)
20
- return tree, validation_value
21
- end
22
-
23
- # see #train
24
- # @note Takes a CSV filename rather than a Dataset
25
- def train_from_csv(filename)
26
- ds = Dataset.from_csv(filename)
27
- train(ds)
28
- end
29
-
30
- # see #train_and_validate
31
- # @note Takes a CSV filename rather than a Dataset
32
- def train_and_validate_from_csv(filename, training_proportion=0.5)
33
- ds = Dataset.from_csv(filename)
34
- train_and_validate(ds, training_proportion)
35
- end
36
8
  end
37
9
 
38
- def initialize(dataset, attributes_available)
10
+ def train(dataset, attributes_available)
39
11
  raise Idhja22::Dataset::InsufficientData, "require at least #{Idhja22.config.min_dataset_size} data points, only have #{dataset.size} in data set provided" if(dataset.size < Idhja22.config.min_dataset_size)
40
12
  @root = Node.build_node(dataset, attributes_available, 0)
13
+ return self
41
14
  end
42
15
 
43
16
  def get_rules
@@ -52,20 +25,5 @@ module Idhja22
52
25
  def evaluate query
53
26
  @root.evaluate(query)
54
27
  end
55
-
56
- def validate(ds)
57
- output = 0
58
- ds.data.each do |validation_point|
59
- begin
60
- prob = evaluate(validation_point)
61
- output += (validation_point.category == 'Y' ? prob : 1.0 - prob)
62
- rescue Idhja22::Dataset::Datum::UnknownAttributeValue
63
- # if don't recognised the attribute value in the example, then assume the worst:
64
- # will never classify this point correctly
65
- # equivalent to output += 0 but no point running this
66
- end
67
- end
68
- return output.to_f/ds.size.to_f
69
- end
70
28
  end
71
29
  end
@@ -1,3 +1,3 @@
1
1
  module Idhja22
2
- VERSION = "0.14.4"
2
+ VERSION = "1.1.0"
3
3
  end
data/lib/idhja22.rb CHANGED
@@ -3,6 +3,7 @@ require 'idhja22/config/default'
3
3
 
4
4
  require "idhja22/version"
5
5
  require "idhja22/dataset"
6
+ require "idhja22/binary_classifier"
6
7
  require "idhja22/tree"
7
8
  require "idhja22/bayes"
8
9
 
data/spec/bayes_spec.rb CHANGED
@@ -1,5 +1,119 @@
1
1
  require 'spec_helper'
2
2
 
3
3
  describe Idhja22::Bayes do
4
+ before(:all) do
5
+ @ds = Idhja22::Dataset.from_csv(File.join(data_dir,'large_spec_data.csv'))
6
+ end
4
7
 
8
+ describe '.train' do
9
+ it 'should train a classifier from a dataset' do
10
+ classifier = Idhja22::Bayes.train @ds, :attributes => %w{0}
11
+ cond_probs = classifier.conditional_probabilities
12
+ cond_probs.keys.should == ['0']
13
+
14
+ cond_probs['0'].keys.should == ['Y', 'N']
15
+
16
+ cond_probs['0']['Y']['a'].should == 5.0/6.0
17
+ cond_probs['0']['N']['a'].should == 0.75
18
+
19
+ cond_probs['0']['Y']['b'].should == 1.0/6.0
20
+ cond_probs['0']['N']['b'].should == 0.25
21
+
22
+ prior_probs = classifier.prior_probabilities
23
+ prior_probs['Y'].should == 0.6
24
+ prior_probs['N'].should == 0.4
25
+ end
26
+ end
27
+
28
+ describe '.calculate_conditional_probabilities' do
29
+ it 'should calculate the conditional probabilities of P(Cat|attr_val) from dataset for given attribute labels' do
30
+ cond_probs = Idhja22::Bayes.calculate_conditional_probabilities @ds, %w{0 2}
31
+ cond_probs.keys.should == ['0', '2']
32
+ cond_probs['0'].keys.should == ['Y','N']
33
+ cond_probs['2'].keys.should == ['Y','N']
34
+
35
+ cond_probs['0']['Y']['a'].should == 5.0/6.0
36
+ cond_probs['0']['N']['a'].should == 0.75
37
+ cond_probs['0']['Y']['b'].should == 1.0/6.0
38
+ cond_probs['0']['N']['b'].should == 0.25
39
+
40
+ cond_probs['2']['Y']['a'].should == 1.0
41
+ cond_probs['2']['N']['a'].should == 0.5
42
+ cond_probs['2']['Y']['b'].should == 0
43
+ cond_probs['2']['N']['b'].should == 0.5
44
+ end
45
+ end
46
+
47
+ describe '.calculate_priors' do
48
+ it 'should calculate the prior probabilities' do
49
+ prior_probs = Idhja22::Bayes.calculate_priors @ds
50
+ prior_probs['Y'].should == 0.6
51
+ prior_probs['N'].should == 0.4
52
+ end
53
+
54
+ context 'all single category' do
55
+ it 'should return 0 for other categories' do
56
+ uniform_ds = Idhja22::Dataset.new([Idhja22::Dataset::Example.new(['high', '20-30', 'vanilla', 'Y'], ['Confidence', 'Age group', 'fav ice cream'] , 'Loves Reading')], ['Confidence', 'Age group', 'fav ice cream'], 'Loves Reading')
57
+ prior_probs = Idhja22::Bayes.calculate_priors uniform_ds
58
+ prior_probs['Y'].should == 1.0
59
+ prior_probs['N'].should == 0
60
+ end
61
+ end
62
+ end
63
+
64
+ describe '#evaluate' do
65
+ before(:all) do
66
+ @bayes = Idhja22::Bayes.new
67
+ @bayes.conditional_probabilities = {
68
+ 'age' => {
69
+ 'Y' => {'young' => 0.98, 'old' => 0.02},
70
+ 'N' => {'young' => 0.98, 'old' => 0.02}
71
+
72
+ },
73
+ 'confidence' => {
74
+ 'Y' => {'high' => 0.6, 'medium' => 0.3, 'low' => 0.1},
75
+ 'N' => {'high' => 0.8, 'medium' => 0.15, 'low' => 0.05}
76
+ },
77
+ 'fav ice cream' => {
78
+ 'Y' => {'vanilla' => 0.75, 'strawberry' => 0.25},
79
+ 'N' => {'vanilla' => 0.5, 'strawberry' => 0.6}
80
+ }
81
+ }
82
+ @bayes.prior_probabilities = {'Y' => 0.75, 'N' => 0.25}
83
+ end
84
+
85
+ context 'Y likely' do
86
+ it 'should return probability of being Y' do
87
+ query = Idhja22::Dataset::Datum.new(['high', 'young', 'vanilla', 'cheddar'], ['confidence', 'age', 'fav ice cream', 'fav cheese'], 'Loves Reading')
88
+ @bayes.evaluate(query).should be_within(0.00001).of(0.77143)
89
+ end
90
+ end
91
+
92
+ context 'N likely' do
93
+ it 'should return probability of being Y' do
94
+ query = Idhja22::Dataset::Datum.new(['high', 'young', 'strawberry', 'cheddar'], ['confidence', 'age', 'fav ice cream', 'fav cheese'], 'Loves Reading')
95
+ @bayes.evaluate(query).should be_within(0.00001).of(0.48387)
96
+ end
97
+ end
98
+
99
+ context 'unrecognised attribute value' do
100
+ it 'should throw an error' do
101
+ query = Idhja22::Dataset::Datum.new(['high', 'young', 'chocolate', 'cheddar'], ['confidence', 'age', 'fav ice cream', 'fav cheese'], 'Loves Reading')
102
+ expect { @bayes.evaluate(query) }.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeValue)
103
+ end
104
+ end
105
+ end
106
+
107
+ describe '#validate' do
108
+ before(:all) do
109
+ @bayes = Idhja22::Bayes.train(@ds)
110
+ end
111
+
112
+ it 'should return the average probability that the tree gets the validation examples correct' do
113
+ vds = Idhja22::Dataset.new([], ['0', '1','2','3','4'],'C')
114
+ vds << Idhja22::Dataset::Example.new(['a','a','a','a','a','Y'],['0', '1','2','3','4'],'C')
115
+ vds << Idhja22::Dataset::Example.new(['a','a','a','a','a','N'],['0', '1','2','3','4'],'C')
116
+ @bayes.validate(vds).should == 0.5
117
+ end
118
+ end
5
119
  end
@@ -0,0 +1,9 @@
1
+ 1,2,C
2
+ a,a,Y
3
+ a,a,N
4
+ a,b,Y
5
+ a,b,N
6
+ b,a,Y
7
+ b,a,N
8
+ b,b,Y
9
+ b,b,N
File without changes
data/spec/dataset_spec.rb CHANGED
@@ -10,7 +10,7 @@ describe Idhja22::Dataset do
10
10
 
11
11
  describe 'from_csv' do
12
12
  before(:all) do
13
- @ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'spec_data.csv'))
13
+ @ds = Idhja22::Dataset.from_csv(File.join(data_dir,'spec_data.csv'))
14
14
  end
15
15
 
16
16
  it 'should extract labels' do
@@ -50,7 +50,7 @@ describe Idhja22::Dataset do
50
50
 
51
51
  context 'ready made' do
52
52
  before(:all) do
53
- @ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'))
53
+ @ds = Idhja22::Dataset.from_csv(File.join(data_dir,'large_spec_data.csv'))
54
54
  end
55
55
 
56
56
  describe '#partition' do
@@ -125,6 +125,53 @@ describe Idhja22::Dataset do
125
125
  vs.size.should == 3
126
126
  end
127
127
  end
128
+
129
+ describe '#partition_by_category' do
130
+ it 'should divide the data set into a set of all Ys and a set of all Ns' do
131
+ sets = @ds.partition_by_category
132
+ sets.length.should == 2
133
+ sets['Y'].data.collect(&:category).uniq.should == ['Y']
134
+ sets['N'].data.collect(&:category).uniq.should == ['N']
135
+ end
136
+ end
137
+
138
+ describe '#<<' do
139
+ it 'should all datum to list of data' do
140
+ added_datum = Idhja22::Dataset::Example.new(['a','b','c','d','e', 'Y'],['0','1','2','3','4'],'C')
141
+ expect { @ds << added_datum}.to change(@ds, :size)
142
+ @ds.data.last.should == added_datum
143
+ end
144
+
145
+ context 'mismatched category label' do
146
+ it 'should throw an error' do
147
+ added_datum = Idhja22::Dataset::Example.new(['a','b','c','d','e', 'Y'],['0','1','2','3','4'],'D')
148
+ expect { @ds << added_datum}.to raise_error(Idhja22::Dataset::Datum::UnknownCategoryLabel)
149
+ end
150
+ end
151
+
152
+ context 'mismatching attributes' do
153
+ context 'extra attribute' do
154
+ it 'should throw an error' do
155
+ added_datum = Idhja22::Dataset::Example.new(['a','b','c','d','e', 'f', 'Y'],['0','1','2','3','4', '5'],'C')
156
+ expect { @ds << added_datum}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeLabel)
157
+ end
158
+ end
159
+
160
+ context 'missing attribute' do
161
+ it 'should throw an error' do
162
+ added_datum = Idhja22::Dataset::Example.new(['a','b','c','d', 'Y'],['0','1','2','3'],'C')
163
+ expect { @ds << added_datum}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeLabel)
164
+ end
165
+ end
166
+
167
+ context 'different attribute' do
168
+ it 'should throw an error' do
169
+ added_datum = Idhja22::Dataset::Example.new(['a','b','c','d', 'e', 'Y'],['0','1','2','3','9'],'C')
170
+ expect { @ds << added_datum}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeLabel)
171
+ end
172
+ end
173
+ end
174
+ end
128
175
  end
129
176
  end
130
177
  end
@@ -0,0 +1,205 @@
1
+ require 'spec_helper'
2
+
3
+ describe Idhja22::DecisionNode do
4
+ before(:all) do
5
+ @ds = Idhja22::Dataset.from_csv(File.join(data_dir,'large_spec_data.csv'))
6
+ @simple_decision_node = Idhja22::DecisionNode.new('3')
7
+
8
+ l1 = Idhja22::LeafNode.new(0.75, 'C')
9
+ l2 = Idhja22::LeafNode.new(0.0, 'C')
10
+
11
+ @simple_decision_node.add_branch('a', l1)
12
+ @simple_decision_node.add_branch('b', l2)
13
+ end
14
+
15
+ describe('#get_rules') do
16
+ it 'should return a list of rules' do
17
+ @simple_decision_node.get_rules.should == ["3 == a and then chance of C = 0.75", "3 == b and then chance of C = 0.0"]
18
+ end
19
+ end
20
+
21
+ describe '#leaves' do
22
+ it 'should return a list of terminating values' do
23
+ @simple_decision_node.leaves.should == [Idhja22::LeafNode.new(0.75, 'C'), Idhja22::LeafNode.new(0.0, 'C')]
24
+ end
25
+
26
+ context 'a branch without a terminating leaf node' do
27
+ it 'should throw an error' do
28
+ decision_node = Idhja22::DecisionNode.new('a')
29
+ decision_node.add_branch('1', Idhja22::LeafNode.new(0.75, 'C'))
30
+ decision_node.add_branch('2', Idhja22::DecisionNode.new('b'))
31
+
32
+ expect { decision_node.leaves }.to raise_error(Idhja22::IncompleteTree)
33
+ end
34
+ end
35
+ end
36
+
37
+ describe(' == ') do
38
+ it 'should return false with different decision attributes' do
39
+ dn = Idhja22::DecisionNode.new('2')
40
+ diff_dn = Idhja22::DecisionNode.new('3')
41
+ dn.should_not == diff_dn
42
+ end
43
+
44
+ it 'should return false with different branches' do
45
+ dn1 = Idhja22::DecisionNode.new('2')
46
+ diff_dn = Idhja22::DecisionNode.new('2')
47
+
48
+ leaf = Idhja22::LeafNode.new(0.75, 'C')
49
+ dn1.add_branch('value', leaf)
50
+
51
+ dn1.should_not == diff_dn
52
+ end
53
+
54
+ it 'should return true if decision node and branches match' do
55
+ dn1 = Idhja22::DecisionNode.new('2')
56
+ dn2 = Idhja22::DecisionNode.new('2')
57
+
58
+ leaf = Idhja22::LeafNode.new(0.75, 'C')
59
+ dn1.add_branch('value', leaf)
60
+ dn2.add_branch('value', leaf)
61
+
62
+ dn1.should == dn2
63
+ end
64
+ end
65
+
66
+ describe 'category_label' do
67
+ it 'should return the category_label from the leaves' do
68
+ @simple_decision_node.category_label.should == 'C'
69
+ end
70
+
71
+ context 'incomplete node' do
72
+ it 'should throw an error' do
73
+ dn = Idhja22::DecisionNode.new('a')
74
+ expect { dn.category_label }.to raise_error(Idhja22::IncompleteTree)
75
+ end
76
+ end
77
+
78
+ end
79
+
80
+ describe 'evaluate' do
81
+ it 'should follow node to probability' do
82
+ query = Idhja22::Dataset::Datum.new(['a', 'a'], ['3', '4'], 'C')
83
+ @simple_decision_node.evaluate(query).should == 0.75
84
+
85
+ query = Idhja22::Dataset::Datum.new(['b', 'a'], ['3', '4'], 'C')
86
+ @simple_decision_node.evaluate(query).should == 0.0
87
+ end
88
+
89
+ context 'mismatching attribute label' do
90
+ it 'should raise an error' do
91
+ query = Idhja22::Dataset::Datum.new(['b', 'a'], ['1', '2'], 'C')
92
+ expect {@simple_decision_node.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeLabel)
93
+ end
94
+ end
95
+
96
+ context 'unknown attribute value' do
97
+ it 'should raise an error' do
98
+ query = Idhja22::Dataset::Datum.new(['c', 'a'], ['3', '4'], 'C')
99
+ expect {@simple_decision_node.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeValue)
100
+ end
101
+ end
102
+ end
103
+
104
+ describe('.build') do
105
+ it 'should build a decision node based on the provided data' do
106
+ node = Idhja22::DecisionNode.build(@ds, @ds.attribute_labels, 0)
107
+ node.decision_attribute.should == "2"
108
+ node.branches.keys.should == ['a','b']
109
+ end
110
+
111
+ it 'should cleanup matching tails' do
112
+ ds = Idhja22::Dataset.from_csv(File.join(data_dir,'evenly_split.csv'))
113
+ node = Idhja22::DecisionNode.build(ds, ds.attribute_labels, 0)
114
+ node.get_rules.should == ['1 == a and then chance of C = 0.5', '1 == b and then chance of C = 0.5']
115
+ end
116
+ end
117
+
118
+ describe '#add_branch' do
119
+ it 'should add a branch for the given attribute value' do
120
+ node = Idhja22::DecisionNode.new 'attribute_name'
121
+ branch_node = Idhja22::DecisionNode.new 'other_name'
122
+ node.add_branch('value', branch_node)
123
+ node.branches.keys.should == ['value']
124
+ node.branches['value'].should == branch_node
125
+ end
126
+ end
127
+
128
+ describe '#cleanup_children!' do
129
+ context 'with matching output at level below' do
130
+ before(:all) do
131
+ @dn = Idhja22::DecisionNode.new('a')
132
+ @dn_below = Idhja22::DecisionNode.new('b')
133
+ @dn_below.add_branch('1', Idhja22::LeafNode.new(0.505, 'Category'))
134
+ @dn_below.add_branch('2', Idhja22::LeafNode.new(0.50, 'Category'))
135
+ @dn.add_branch('1', @dn_below)
136
+ end
137
+ it 'should merge any subnodes with same output into a single leafnode' do
138
+ @dn.cleanup_children!
139
+ @dn.branches['1'].should == Idhja22::LeafNode.new(0.505, 'Category')
140
+ end
141
+ end
142
+
143
+ context 'with matching output at two levels below' do
144
+ before(:all) do
145
+ @dn = Idhja22::DecisionNode.new('a')
146
+ @dn_1_below = Idhja22::DecisionNode.new('b')
147
+ @dn.add_branch('1', @dn_1_below)
148
+
149
+ @dn_2_below = Idhja22::DecisionNode.new('c')
150
+ @dn_1_below.add_branch('1', @dn_2_below)
151
+
152
+ @dn_2_below.add_branch('1', Idhja22::LeafNode.new(0.50, 'Category'))
153
+ @dn_2_below.add_branch('2', Idhja22::LeafNode.new(0.50, 'Category'))
154
+ end
155
+
156
+ it 'should merge nodes recusively' do
157
+ @dn.cleanup_children!
158
+ @dn.branches['1'].should == Idhja22::LeafNode.new(0.50, 'Category')
159
+ end
160
+ end
161
+
162
+ context 'with diverging branches that match internally' do
163
+ before(:all) do
164
+ @dn = Idhja22::DecisionNode.new('a')
165
+
166
+ dn_1_below = Idhja22::DecisionNode.new('b')
167
+ @dn.add_branch('1', dn_1_below)
168
+
169
+ dn_2_below = Idhja22::DecisionNode.new('c')
170
+ dn_1_below.add_branch('1', dn_2_below)
171
+
172
+ dn_2_below.add_branch('1', Idhja22::LeafNode.new(0.50, 'Category'))
173
+ dn_2_below.add_branch('2', Idhja22::LeafNode.new(0.50, 'Category'))
174
+
175
+ dn_2_below = Idhja22::DecisionNode.new('d')
176
+ dn_1_below.add_branch('2', dn_2_below)
177
+
178
+ dn_2_below.add_branch('1', Idhja22::LeafNode.new(0.70, 'Category'))
179
+ dn_2_below.add_branch('2', Idhja22::LeafNode.new(0.70, 'Category'))
180
+ end
181
+
182
+ it 'should merge nodes recusively' do
183
+ @dn.cleanup_children!
184
+ @dn.branches['1'].branches['1'].should == Idhja22::LeafNode.new(0.50, 'Category')
185
+ @dn.branches['1'].branches['2'].should == Idhja22::LeafNode.new(0.70, 'Category')
186
+ end
187
+ end
188
+
189
+ context 'without matching output' do
190
+ before(:all) do
191
+ @dn = Idhja22::DecisionNode.new('a')
192
+ @dn_below = Idhja22::DecisionNode.new('b')
193
+ @dn_below.add_branch('1', Idhja22::LeafNode.new(0.2, 'Category'))
194
+ @dn_below.add_branch('2', Idhja22::LeafNode.new(0.70, 'Category'))
195
+ @dn.add_branch('1', @dn_below)
196
+ end
197
+
198
+ it 'should do nothing' do
199
+ saved_rules = @dn.get_rules
200
+ @dn.cleanup_children!
201
+ @dn.get_rules.should == saved_rules
202
+ end
203
+ end
204
+ end
205
+ end
@@ -0,0 +1,53 @@
1
+ require 'spec_helper'
2
+
3
+ describe Idhja22::LeafNode do
4
+ describe('.new') do
5
+ it 'should store probability and category label' do
6
+ l = Idhja22::LeafNode.new(0.75, 'label')
7
+ l.probability.should == 0.75
8
+ l.category_label.should == 'label'
9
+ end
10
+ end
11
+
12
+ describe('#get_rules') do
13
+ it 'should return the probability' do
14
+ l = Idhja22::LeafNode.new(0.75, 'pudding')
15
+ l.get_rules.should == ['then chance of pudding = 0.75']
16
+ end
17
+ end
18
+
19
+ describe(' == ') do
20
+ let(:l1) { Idhja22::LeafNode.new(0.75, 'pudding') }
21
+ let(:l2) { Idhja22::LeafNode.new(0.75, 'pudding') }
22
+ let(:diff_l1) { Idhja22::LeafNode.new(0.7, 'pudding') }
23
+ let(:diff_l2) { Idhja22::LeafNode.new(0.75, 'starter') }
24
+ it 'should compare attributes' do
25
+ l1.should == l2
26
+ l1.should_not == diff_l1
27
+ l1.should_not == diff_l2
28
+ end
29
+ end
30
+
31
+ describe 'evaluate' do
32
+ let(:leaf) { Idhja22::LeafNode.new(0.6, 'pudding') }
33
+
34
+ it 'should return probability' do
35
+ query = Idhja22::Dataset::Datum.new(['high', 'gusty'], ['temperature', 'windy'], 'pudding')
36
+ leaf.evaluate(query).should == 0.6
37
+ end
38
+
39
+ context 'mismatching category labels' do
40
+ it 'should raise error' do
41
+ query = Idhja22::Dataset::Datum.new(['high', 'gusty'], ['temperature', 'windy'], 'tennis')
42
+ expect {leaf.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownCategoryLabel)
43
+ end
44
+ end
45
+ end
46
+
47
+ describe '#leaves' do
48
+ it 'should return self' do
49
+ leaf = Idhja22::LeafNode.new(0.6, 'pudding')
50
+ leaf.leaves.should == [leaf]
51
+ end
52
+ end
53
+ end
data/spec/spec_helper.rb CHANGED
@@ -15,6 +15,10 @@ Configuration.for('spec', Idhja22.config) {
15
15
 
16
16
  Idhja22.configure('spec')
17
17
 
18
+ def data_dir
19
+ File.dirname(__FILE__) + '/data/'
20
+ end
21
+
18
22
  RSpec.configure do |config|
19
23
 
20
24
  end
data/spec/tree_spec.rb CHANGED
@@ -2,7 +2,7 @@ require 'spec_helper'
2
2
 
3
3
  describe Idhja22::Tree do
4
4
  before(:all) do
5
- @ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'))
5
+ @ds = Idhja22::Dataset.from_csv(File.join(data_dir,'large_spec_data.csv'))
6
6
  end
7
7
 
8
8
 
@@ -29,7 +29,7 @@ describe Idhja22::Tree do
29
29
  it 'should compare root nodes' do
30
30
  tree1 = Idhja22::Tree.train(@ds)
31
31
  tree2 = Idhja22::Tree.train(@ds)
32
- diff_ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'another_large_spec_data.csv'))
32
+ diff_ds = Idhja22::Dataset.from_csv(File.join(data_dir,'another_large_spec_data.csv'))
33
33
  diff_tree = Idhja22::Tree.train(diff_ds)
34
34
  tree1.should == tree2
35
35
  tree1.should_not == diff_tree
@@ -39,7 +39,7 @@ describe Idhja22::Tree do
39
39
  describe('.train_from_csv') do
40
40
  it 'should make the same tree as the one from the dataset' do
41
41
  tree = Idhja22::Tree.train(@ds)
42
- csv_tree = Idhja22::Tree.train_from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'))
42
+ csv_tree = Idhja22::Tree.train_from_csv(File.join(data_dir,'large_spec_data.csv'))
43
43
  tree.should == csv_tree
44
44
  end
45
45
  end
@@ -85,7 +85,7 @@ describe Idhja22::Tree do
85
85
 
86
86
  describe('.train_and_validate_from_csv') do
87
87
  it 'should make the same tree as the one from the dataset' do
88
- csv_tree, validation_value = Idhja22::Tree.train_and_validate_from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'), 0.75)
88
+ csv_tree, validation_value = Idhja22::Tree.train_and_validate_from_csv(File.join(data_dir,'large_spec_data.csv'), :"training-proportion" => 0.75)
89
89
  csv_tree.is_a?(Idhja22::Tree).should be_true
90
90
  (0..1).include?(validation_value).should be_true
91
91
  end
data/spec/version_spec.rb CHANGED
@@ -3,7 +3,7 @@ require 'spec_helper'
3
3
  describe Idhja22 do
4
4
  describe 'VERSION' do
5
5
  it 'should be current version' do
6
- Idhja22::VERSION.should == '0.14.4'
6
+ Idhja22::VERSION.should == '1.1.0'
7
7
  end
8
8
  end
9
9
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: idhja22
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.14.4
4
+ version: 1.1.0
5
5
  prerelease:
6
6
  platform: ruby
7
7
  authors:
@@ -9,7 +9,7 @@ authors:
9
9
  autorequire:
10
10
  bindir: bin
11
11
  cert_chain: []
12
- date: 2012-12-18 00:00:00.000000000 Z
12
+ date: 2012-12-20 00:00:00.000000000 Z
13
13
  dependencies:
14
14
  - !ruby/object:Gem::Dependency
15
15
  name: rspec
@@ -123,7 +123,7 @@ dependencies:
123
123
  - - ! '>='
124
124
  - !ruby/object:Gem::Version
125
125
  version: '0'
126
- description: Decision Trees
126
+ description: Classifiers
127
127
  email:
128
128
  executables:
129
129
  - idhja22
@@ -141,6 +141,7 @@ files:
141
141
  - idhja22.gemspec
142
142
  - lib/idhja22.rb
143
143
  - lib/idhja22/bayes.rb
144
+ - lib/idhja22/binary_classifier.rb
144
145
  - lib/idhja22/config/default.rb
145
146
  - lib/idhja22/dataset.rb
146
147
  - lib/idhja22/dataset/datum.rb
@@ -149,13 +150,15 @@ files:
149
150
  - lib/idhja22/tree.rb
150
151
  - lib/idhja22/tree/node.rb
151
152
  - lib/idhja22/version.rb
152
- - spec/another_large_spec_data.csv
153
153
  - spec/bayes_spec.rb
154
+ - spec/data/another_large_spec_data.csv
155
+ - spec/data/evenly_split.csv
156
+ - spec/data/large_spec_data.csv
157
+ - spec/data/spec_data.csv
154
158
  - spec/dataset/example_spec.rb
155
159
  - spec/dataset_spec.rb
156
- - spec/large_spec_data.csv
157
- - spec/node_spec.rb
158
- - spec/spec_data.csv
160
+ - spec/node/decision_node_spec.rb
161
+ - spec/node/leaf_node_spec.rb
159
162
  - spec/spec_helper.rb
160
163
  - spec/tree_spec.rb
161
164
  - spec/version_spec.rb
@@ -173,7 +176,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
173
176
  version: '0'
174
177
  segments:
175
178
  - 0
176
- hash: -803768552583374641
179
+ hash: 3479458333568153307
177
180
  required_rubygems_version: !ruby/object:Gem::Requirement
178
181
  none: false
179
182
  requirements:
@@ -182,21 +185,23 @@ required_rubygems_version: !ruby/object:Gem::Requirement
182
185
  version: '0'
183
186
  segments:
184
187
  - 0
185
- hash: -803768552583374641
188
+ hash: 3479458333568153307
186
189
  requirements: []
187
190
  rubyforge_project:
188
191
  rubygems_version: 1.8.24
189
192
  signing_key:
190
193
  specification_version: 3
191
- summary: A gem for creating decision trees
194
+ summary: A gem for creating classifiers (decision trees and naive Bayes so far)
192
195
  test_files:
193
- - spec/another_large_spec_data.csv
194
196
  - spec/bayes_spec.rb
197
+ - spec/data/another_large_spec_data.csv
198
+ - spec/data/evenly_split.csv
199
+ - spec/data/large_spec_data.csv
200
+ - spec/data/spec_data.csv
195
201
  - spec/dataset/example_spec.rb
196
202
  - spec/dataset_spec.rb
197
- - spec/large_spec_data.csv
198
- - spec/node_spec.rb
199
- - spec/spec_data.csv
203
+ - spec/node/decision_node_spec.rb
204
+ - spec/node/leaf_node_spec.rb
200
205
  - spec/spec_helper.rb
201
206
  - spec/tree_spec.rb
202
207
  - spec/version_spec.rb
data/spec/node_spec.rb DELETED
@@ -1,97 +0,0 @@
1
- require 'spec_helper'
2
-
3
- describe Idhja22::LeafNode do
4
- describe('.new') do
5
- it 'should store probability and category label' do
6
- l = Idhja22::LeafNode.new(0.75, 'label')
7
- l.probability.should == 0.75
8
- l.category_label.should == 'label'
9
- end
10
- end
11
-
12
- describe('#get_rules') do
13
- it 'should return the probability' do
14
- l = Idhja22::LeafNode.new(0.75, 'pudding')
15
- l.get_rules.should == ['then chance of pudding = 0.75']
16
- end
17
- end
18
-
19
- describe(' == ') do
20
- let(:l1) { Idhja22::LeafNode.new(0.75, 'pudding') }
21
- let(:l2) { Idhja22::LeafNode.new(0.75, 'pudding') }
22
- let(:diff_l1) { Idhja22::LeafNode.new(0.7, 'pudding') }
23
- let(:diff_l2) { Idhja22::LeafNode.new(0.75, 'starter') }
24
- it 'should compare attributes' do
25
- l1.should == l2
26
- l1.should_not == diff_l1
27
- l1.should_not == diff_l2
28
- end
29
- end
30
-
31
- describe 'evaluate' do
32
- let(:leaf) { Idhja22::LeafNode.new(0.6, 'pudding') }
33
-
34
- it 'should return probability' do
35
- query = Idhja22::Dataset::Datum.new(['high', 'gusty'], ['temperature', 'windy'], 'pudding')
36
- leaf.evaluate(query).should == 0.6
37
- end
38
-
39
- context 'mismatching category labels' do
40
- it 'should raise error' do
41
- query = Idhja22::Dataset::Datum.new(['high', 'gusty'], ['temperature', 'windy'], 'tennis')
42
- expect {leaf.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownCategoryLabel)
43
- end
44
- end
45
- end
46
- end
47
-
48
- describe Idhja22::DecisionNode do
49
- before(:all) do
50
- @ds = Idhja22::Dataset.from_csv(File.join(File.dirname(__FILE__),'large_spec_data.csv'))
51
- end
52
-
53
- describe('#get_rules') do
54
- it 'should return a list of rules' do
55
- l = Idhja22::DecisionNode.new(@ds.partition('2'), '3', [], 0, 0.75)
56
- l.get_rules.should == ["3 == a and then chance of C = 0.75", "3 == b and then chance of C = 0.0"]
57
- end
58
- end
59
-
60
- describe(' == ') do
61
- let(:dn1) { Idhja22::DecisionNode.new(@ds.partition('2'), '2', [], 0, 0.75) }
62
- let(:dn2) { Idhja22::DecisionNode.new(@ds.partition('2'), '2', [], 0, 0.75) }
63
- let(:diff_dn1) { Idhja22::DecisionNode.new(@ds.partition('0'), '2', [], 0, 0.75) }
64
- let(:diff_dn2) { Idhja22::DecisionNode.new(@ds.partition('3'), '3', [], 0, 0.75) }
65
-
66
- it 'should compare ' do
67
- dn1.should == dn2
68
- dn1.should_not == diff_dn1
69
- dn1.should_not == diff_dn2
70
- end
71
- end
72
-
73
- describe 'evaluate' do
74
- let(:dn) { Idhja22::DecisionNode.new(@ds.partition('2'), '3', [], 0, 0.75) }
75
- it 'should follow node to probability' do
76
- query = Idhja22::Dataset::Datum.new(['a', 'a'], ['3', '4'], 'C')
77
- dn.evaluate(query).should == 0.75
78
-
79
- query = Idhja22::Dataset::Datum.new(['b', 'a'], ['3', '4'], 'C')
80
- dn.evaluate(query).should == 0.0
81
- end
82
-
83
- context 'mismatching attribute label' do
84
- it 'should raise an error' do
85
- query = Idhja22::Dataset::Datum.new(['b', 'a'], ['1', '2'], 'C')
86
- expect {dn.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeLabel)
87
- end
88
- end
89
-
90
- context 'unknown attribute value' do
91
- it 'should raise an error' do
92
- query = Idhja22::Dataset::Datum.new(['c', 'a'], ['3', '4'], 'C')
93
- expect {dn.evaluate(query)}.to raise_error(Idhja22::Dataset::Datum::UnknownAttributeValue)
94
- end
95
- end
96
- end
97
- end