idhja22 1.1.0 → 1.1.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,7 @@
1
1
  Configuration.for('default') {
2
2
  default_probability 0.5
3
3
  termination_probability 0.95
4
- min_dataset_size 20
4
+ min_dataset_size 10
5
5
  probability_delta 0.01
6
+ equivalent_sample_size 10
6
7
  }
@@ -55,6 +55,11 @@ module Idhja22
55
55
  category_counts['Y'].to_f/size.to_f
56
56
  end
57
57
 
58
+ def m_estimate(prior)
59
+ prior ||= Idhja22.config.default_probability
60
+ (category_counts['Y'] + (prior*Idhja22.config.equivalent_sample_size)).to_f/(size+Idhja22.config.equivalent_sample_size).to_f
61
+ end
62
+
58
63
  def split(training_proportion)
59
64
  shuffled_data = data.shuffle
60
65
  cutoff_point = (training_proportion.to_f*size).to_i
@@ -3,24 +3,24 @@ module Idhja22
3
3
  class << self
4
4
  def build_node(dataset, attributes_available, depth, parent_probability = nil)
5
5
  if(dataset.size < Idhja22.config.min_dataset_size)
6
- return Idhja22::LeafNode.new(probability_guess(parent_probability, depth), dataset.category_label)
6
+ return Idhja22::LeafNode.new(dataset.m_estimate(parent_probability), dataset.category_label)
7
7
  end
8
8
 
9
9
  #if successful termination - create and return a leaf node
10
10
  if(dataset.terminating? && depth > 0) # don't terminate without splitting the data at least once
11
- return Idhja22::LeafNode.new(dataset.probability, dataset.category_label)
11
+ return Idhja22::LeafNode.new(dataset.m_estimate(parent_probability), dataset.category_label)
12
12
  end
13
13
 
14
14
  if(depth >= 3) # don't let trees get too long
15
- return Idhja22::LeafNode.new(dataset.probability, dataset.category_label)
15
+ return Idhja22::LeafNode.new(dataset.m_estimate(parent_probability), dataset.category_label)
16
16
  end
17
17
 
18
18
  #if we have no more attributes left to split the dataset on, then return a leafnode
19
19
  if(attributes_available.empty?)
20
- return Idhja22::LeafNode.new(dataset.probability, dataset.category_label)
20
+ return Idhja22::LeafNode.new(dataset.m_estimate(parent_probability), dataset.category_label)
21
21
  end
22
22
 
23
- node = DecisionNode.build(dataset, attributes_available, depth)
23
+ node = DecisionNode.build(dataset, attributes_available, depth, dataset.m_estimate(parent_probability))
24
24
 
25
25
  return node
26
26
  end
@@ -44,10 +44,6 @@ module Idhja22
44
44
  end
45
45
  return data_split, best_attribute
46
46
  end
47
-
48
- def probability_guess(parent_probability, depth)
49
- return (parent_probability + (Idhja22.config.default_probability-parent_probability)/2**depth)
50
- end
51
47
  end
52
48
 
53
49
  def ==(other)
@@ -59,13 +55,13 @@ module Idhja22
59
55
  attr_reader :branches, :decision_attribute
60
56
 
61
57
  class << self
62
- def build(dataset, attributes_available, depth)
58
+ def build(dataset, attributes_available, depth, parent_probability=nil)
63
59
  data_split, best_attribute = best_attribute(dataset, attributes_available)
64
60
 
65
61
  output_node = new(best_attribute)
66
62
 
67
63
  data_split.each do |value, dataset|
68
- node = Node.build_node(dataset, attributes_available-[best_attribute], depth+1, dataset.probability)
64
+ node = Node.build_node(dataset, attributes_available-[best_attribute], depth+1, dataset.m_estimate(parent_probability))
69
65
 
70
66
  output_node.add_branch(value, node) if node && !(node.is_a?(DecisionNode) && node.branches.empty?)
71
67
  end
@@ -1,3 +1,3 @@
1
1
  module Idhja22
2
- VERSION = "1.1.0"
2
+ VERSION = "1.1.1"
3
3
  end
data/spec/dataset_spec.rb CHANGED
@@ -114,6 +114,18 @@ describe Idhja22::Dataset do
114
114
  end
115
115
  end
116
116
 
117
+ describe '#m_estimate' do
118
+ it 'should return an estimate for the probability of category being Y' do
119
+ @ds.m_estimate(0.5).should be_within(0.0001).of(0.55)
120
+ end
121
+
122
+ context 'nil prior' do
123
+ it 'should use the default prior' do
124
+ @ds.m_estimate(nil).should be_within(0.0001).of(0.55)
125
+ end
126
+ end
127
+ end
128
+
117
129
  describe '#split' do
118
130
  it 'should split into a training and validation set according to the given proportion' do
119
131
  ts, vs = @ds.split(0.5)
data/spec/tree_spec.rb CHANGED
@@ -21,7 +21,7 @@ describe Idhja22::Tree do
21
21
 
22
22
  describe('#get_rules') do
23
23
  it 'should list the rules of the tree' do
24
- Idhja22::Tree.train(@ds).get_rules.should == "if 2 == a and 4 == a and then chance of C = 1.0\nelsif 2 == a and 4 == b and then chance of C = 0.0\nelsif 2 == b and then chance of C = 0.0"
24
+ Idhja22::Tree.train(@ds).get_rules.should == "if 2 == a and 4 == a and then chance of C = 0.88\nelsif 2 == a and 4 == b and then chance of C = 0.48\nelsif 2 == b and then chance of C = 0.38"
25
25
  end
26
26
  end
27
27
 
@@ -48,7 +48,7 @@ describe Idhja22::Tree do
48
48
  it 'should return the probabilty at the leaf of the tree' do
49
49
  tree = Idhja22::Tree.train(@ds)
50
50
  query = Idhja22::Dataset::Datum.new(['z','z','a','z','a'],['0', '1','2','3','4'],'C')
51
- tree.evaluate(query).should == 1.0
51
+ tree.evaluate(query).should be_within(0.001).of(0.878)
52
52
  end
53
53
  end
54
54
 
data/spec/version_spec.rb CHANGED
@@ -3,7 +3,7 @@ require 'spec_helper'
3
3
  describe Idhja22 do
4
4
  describe 'VERSION' do
5
5
  it 'should be current version' do
6
- Idhja22::VERSION.should == '1.1.0'
6
+ Idhja22::VERSION.should == '1.1.1'
7
7
  end
8
8
  end
9
9
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: idhja22
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.1.0
4
+ version: 1.1.1
5
5
  prerelease:
6
6
  platform: ruby
7
7
  authors:
@@ -176,7 +176,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
176
176
  version: '0'
177
177
  segments:
178
178
  - 0
179
- hash: 3479458333568153307
179
+ hash: -1322747474535878301
180
180
  required_rubygems_version: !ruby/object:Gem::Requirement
181
181
  none: false
182
182
  requirements:
@@ -185,7 +185,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
185
185
  version: '0'
186
186
  segments:
187
187
  - 0
188
- hash: 3479458333568153307
188
+ hash: -1322747474535878301
189
189
  requirements: []
190
190
  rubyforge_project:
191
191
  rubygems_version: 1.8.24