sbn 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/README +320 -0
- data/lib/combination.rb +78 -0
- data/lib/formats.rb +119 -0
- data/lib/helpers.rb +140 -0
- data/lib/inference.rb +65 -0
- data/lib/learning.rb +141 -0
- data/lib/net.rb +49 -0
- data/lib/numeric_variable.rb +94 -0
- data/lib/sbn.rb +6 -0
- data/lib/string_variable.rb +176 -0
- data/lib/variable.rb +224 -0
- data/test/sbn.rb +5 -0
- data/test/test_combination.rb +51 -0
- data/test/test_helpers.rb +80 -0
- data/test/test_learning.rb +104 -0
- data/test/test_net.rb +136 -0
- data/test/test_variable.rb +373 -0
- metadata +63 -0
data/lib/variable.rb
ADDED
@@ -0,0 +1,224 @@
|
|
1
|
+
class Sbn
  # A discrete variable in a belief network. It keeps its list of states, its
  # parent and child variables, and a probability table holding one
  # probability for every combination of parent states and own state.
  class Variable
    attr_reader :name, :states, :parents, :children, :probability_table

    # Creates a variable and registers it with +net+ via +net.add_variable+.
    #
    # name          - String or Symbol, normalized with +to_underscore_sym+.
    #                 An empty String is replaced by "variable_<n>", where <n>
    #                 counts the variables created so far. The counter is a
    #                 class variable, so subclasses share it.
    # probabilities - flat list with one probability per combination of parent
    #                 states and own state, in the order
    #                 +generate_probability_table+ pairs them. It is copied.
    # states        - this variable's states. They are copied and then
    #                 converted with +symbolize_values!+, so the caller's
    #                 array is not modified.
    def initialize(net, name = '', probabilities = [0.5, 0.5], states = [:true, :false])
      @net = net
      @@variable_count ||= 0
      @@variable_count += 1
      name = "variable_#{@@variable_count}" if name.is_a?(String) && name.empty?
      @name = name.to_underscore_sym
      @children = []
      @parents = []
      @states = []
      @state_frequencies = {} # used for storing sample points
      set_states(states)
      set_probabilities(probabilities)
      net.add_variable(self)
    end

    # Structural equality; see +test_equal+ for the fields compared.
    def ==(obj); test_equal(obj); end
    def eql?(obj); test_equal(obj); end
    def ===(obj); test_equal(obj); end

    # Hash code consistent with +eql?+. Variables that compare equal always
    # share class and name, so equal variables hash alike.
    def hash
      [self.class, @name].hash
    end

    def to_s
      @name.to_s
    end

    # Emits this variable's XMLBIF variable element on the builder-style
    # +xml+ object (presumably a Builder::XmlMarkup; TODO confirm). The
    # element carries the name, one outcome per state, and a property
    # recording the Ruby class. It yields +xml+ so callers can append
    # extra content.
    def to_xmlbif_variable(xml)
      xml.variable(:type => "nature") do
        xml.name(@name.to_s)
        @states.each {|s| xml.outcome(s.to_s) }
        xml.property("SbnVariableType = #{self.class.to_s}")
        yield(xml) if block_given?
      end
    end

    # Emits this variable's XMLBIF definition element. It contains the
    # variable name, one "given" per parent (in parent order), and the
    # probability column of the table joined with spaces. It yields +xml+
    # for extra content.
    def to_xmlbif_definition(xml)
      xml.definition do
        xml.for(@name.to_s)
        @parents.each {|p| xml.given(p.name.to_s) }
        xml.table(@probability_table.transpose.last.join(' '))
        yield(xml) if block_given?
      end
    end

    # Replaces this variable's states and rebuilds the probability table.
    # +states+ is copied before symbolizing so the caller's array is left
    # untouched.
    def set_states(states)
      states = states.dup
      states.symbolize_values!
      @states = states
      generate_probability_table
    end

    # Sets the probability of a single table entry.
    #
    # +event+ maps variable names to states and is normalized with
    # +@net.symbolize_evidence+. It must supply a valid state for this
    # variable and for every parent. Otherwise a RuntimeError is raised
    # (rather than failing later on a nil table index).
    def set_probability(probability, event)
      event = @net.symbolize_evidence(event)
      combination = @parents.map { |p| event[p.name] } + [event[@name]]
      index = state_combinations.index(combination) if can_be_evaluated?(event)
      unless index
        raise "A valid state was not supplied for variable #{@name} and all its parents in call to set_probability()"
      end
      @probabilities[index] = probability
      generate_probability_table
    end

    # Makes +variable+ a child of this variable, and this variable a parent
    # of +variable+.
    def add_child(variable)
      add_child_no_recurse(variable)
      variable.add_parent_no_recurse(self)
    end

    # Makes +variable+ a parent of this variable, and this variable a child
    # of +variable+.
    def add_parent(variable)
      add_parent_no_recurse(variable)
      variable.add_child_no_recurse(self)
    end

    # Replaces the flat probability list and rebuilds the table. A copy is
    # stored so that later +set_probability+ calls never modify the caller's
    # array. A nil list is kept as nil.
    def set_probabilities(probs)
      @probabilities = probs && probs.dup
      generate_probability_table
    end

    # Key under which this variable's value is looked up in evidence hashes.
    def evidence_name # :nodoc:
      @name
    end

    # Records +variable+ as a child without updating its parent list.
    # A StringVariable contributes its covariables instead of itself; ones
    # already present are skipped. Afterwards the child's probability table
    # is rebuilt.
    def add_child_no_recurse(variable) # :nodoc:
      return if variable == self || @children.include?(variable)
      if variable.is_a?(StringVariable)
        variable.covariables.each { |c| @children << c unless @children.include?(c) }
      else
        @children << variable
      end
      variable.generate_probability_table
    end

    # Records +variable+ as a parent without updating its child list.
    # A StringVariable contributes its covariables instead of itself; ones
    # already present are skipped. Afterwards this variable's probability
    # table is rebuilt.
    def add_parent_no_recurse(variable) # :nodoc:
      return if variable == self || @parents.include?(variable)
      if variable.is_a?(StringVariable)
        variable.covariables.each { |c| @parents << c unless @parents.include?(c) }
      else
        @parents << variable
      end
      generate_probability_table
    end

    # True when +evidence+ has a key for +evidence_name+.
    def set_in_evidence?(evidence) # :nodoc:
      evidence.has_key?(evidence_name)
    end

    # This variable's value in +evidence+ (nil when absent).
    def get_observed_state(evidence) # :nodoc:
      evidence[@name]
    end

    # A variable can't be evaluated unless all of its parents have been
    # observed in +evidence+.
    def can_be_evaluated?(evidence) # :nodoc:
      parents.all? { |p| p.set_in_evidence?(evidence) }
    end

    # Draws a state at random, weighted by its marginal given +event+.
    # Picking a state uniformly would ignore the probabilities. Instead a
    # random number in [0, 1) is compared against the running sum of the
    # state probabilities.
    def get_random_state(event = {}) # :nodoc:
      seek_state { |s| evaluate_marginal(s, event) }
    end

    # Like +get_random_state+, but each state is weighted by its Markov
    # blanket evaluation (see +evaluate_markov_blanket+). The weights are
    # normalized before sampling.
    def get_random_state_with_markov_blanket(event) # :nodoc:
      weights = @states.map { |s| evaluate_markov_blanket(s, event) }
      weights.normalize!
      seek_state { weights.shift }
    end

    # Rebuilds +probability_table+ as [combination, probability] pairs. Each
    # combination lists the parents' states (in parent order) followed by
    # this variable's state. The table is nil when no probabilities are set,
    # or when their count does not match the number of combinations (for
    # example, before all parents have been attached).
    def generate_probability_table # :nodoc:
      @probability_table = nil
      return unless @probabilities
      combinations = state_combinations
      return unless @probabilities.size == combinations.size
      @probability_table = combinations.zip(@probabilities)
    end

    # Sums the probabilities of the table rows where this variable is in
    # +state+ and every parent matches its observed state in +event+.
    # Raises RuntimeError if a parent is missing from +event+.
    def evaluate_marginal(state, event) # :nodoc:
      matching = remove_irrelevant_states(@probability_table.dup, state, event)
      matching.inject(0.0) { |sum, entry| sum + entry[1] }
    end

    # Normalizes an evidence value for this variable.
    def transform_evidence_value(val) # :nodoc:
      val.to_underscore_sym
    end

    private

    # Samples a state. It draws rand (uniform in [0, 1)), then walks the
    # states in order, adding the weight the block yields for each, until
    # the running total exceeds the draw. If the weights never exceed the
    # draw, the last state is returned.
    def seek_state
      threshold = rand
      cumulative = 0.0
      chosen = nil
      @states.each do |s|
        chosen = s
        cumulative += yield(s)
        break if threshold < cumulative
      end
      chosen
    end

    # Every combination of parent states followed by an own state, in the
    # order produced by Combination#to_a.
    def state_combinations
      all_states = @parents.map { |p| p.states }
      all_states << @states
      Combination.new(all_states).to_a
    end

    # Filters +probabilities+ (a probability table, modified in place and
    # returned). It keeps only the rows where this variable is in +state+
    # and each parent has its observed state from +evidence+. Raises
    # RuntimeError if a parent is not set in +evidence+.
    def remove_irrelevant_states(probabilities, state, evidence)
      probabilities.reject! { |e| e.first.last != state }
      @parents.each_with_index do |parent, index|
        unless parent.set_in_evidence?(evidence)
          raise "Marginal cannot be evaluated because there are unset parent variables"
        end
        observed = parent.get_observed_state(evidence)
        probabilities.reject! { |e| e.first[index] != observed }
      end
      probabilities
    end

    # Unnormalized weight of this variable taking +state+ given its Markov
    # blanket in +event+. It is the product of this variable's marginal for
    # +state+ and, for each child, the child's marginal for its observed
    # state, evaluated with this variable set to +state+.
    #
    # +event+ is modified temporarily and always restored, even when a
    # marginal raises. A key for this variable that was absent beforehand is
    # removed again instead of being left behind with a nil value.
    def evaluate_markov_blanket(state, event)
      had_own_state = event.has_key?(@name)
      previous_state = get_observed_state(event)
      event[@name] = state
      begin
        weight = evaluate_marginal(state, event)
        @children.each do |child|
          weight *= child.evaluate_marginal(child.get_observed_state(event), event)
        end
        weight
      ensure
        if had_own_state
          event[@name] = previous_state
        else
          event.delete(@name)
        end
      end
    end

    # Structural comparison behind ==, eql? and ===. Two variables are equal
    # when they have the same class and belong to nets with the same name.
    # They must also have the same name, the same parent names (in any
    # order), the same states and the same probabilities. Objects of any
    # other class, including nil, are simply unequal rather than raising.
    def test_equal(variable)
      return false unless variable.class == self.class
      return false unless @net.name == variable.instance_variable_get(:@net).name
      return false unless @name == variable.name
      return false unless sorted_names(@parents) == sorted_names(variable.parents)
      return false unless @states == variable.states
      table_probabilities(@probability_table) == table_probabilities(variable.probability_table)
    end

    # Names of +variables+ as sorted strings, for order-insensitive comparison.
    def sorted_names(variables)
      variables.map { |v| v.name.to_s }.sort
    end

    # The probability column of +table+, or nil when there is no table.
    def table_probabilities(table)
      table && table.transpose.last
    end
  end
end
|
data/test/sbn.rb
ADDED
@@ -0,0 +1,51 @@
|
|
1
|
+
require 'test/unit'
|
2
|
+
require File.dirname(__FILE__) + '/../lib/sbn'
|
3
|
+
|
4
|
+
class TestCombination < Test::Unit::TestCase # :nodoc:
  # Two positions: the first ranges over [1, 2], the second over [3, 4, 5].
  def setup
    @c = Combination.new([[1, 2], [3, 4, 5]])
  end

  # #current reflects the position set by #first and by #last.
  def test_current
    @c.first
    assert_equal [1, 3], @c.current
    @c.last
    assert_equal [2, 5], @c.current
  end

  # #each yields every combination exactly once, last position varying
  # fastest. Each yielded array is copied in case the iterator reuses one
  # array internally.
  def test_each
    @c.first
    expected = [[1, 3], [1, 4], [1, 5], [2, 3], [2, 4], [2, 5]]
    yielded = []
    @c.each { |comb| yielded << comb.dup }
    assert_equal expected, yielded
  end

  def test_first
    @c.first
    assert_equal [1, 3], @c.current
  end

  def test_last
    @c.last
    assert_equal [2, 5], @c.current
  end

  # Stepping forward from the first combination visits the rest in order.
  def test_next_combination
    @c.first
    [[1, 4], [1, 5], [2, 3], [2, 4], [2, 5]].each do |expected|
      assert_equal expected, @c.next_combination
    end
  end

  # Stepping backward from the last combination visits the rest in reverse.
  def test_prev_combination
    @c.last
    [[2, 4], [2, 3], [1, 5], [1, 4], [1, 3]].each do |expected|
      assert_equal expected, @c.prev_combination
    end
  end
end
|
@@ -0,0 +1,80 @@
|
|
1
|
+
require 'test/unit'
|
2
|
+
require File.dirname(__FILE__) + '/../lib/sbn'
|
3
|
+
|
4
|
+
# Fixture for TestHelpers#test_enums. +enums+ is expected to number its
# constants 0, 1, 2, ... and +bitwise_enums+ to assign successive powers of
# two, as the assertions in that test check.
class EnumsTester
  enums(%w[FOO BAR BAZ])
  bitwise_enums(%w[ONE TWO FOUR EIGHT])
end
|
8
|
+
|
9
|
+
class TestHelpers < Test::Unit::TestCase # :nodoc:
  # Tests for Enumerable helpers. Assertions follow the Test::Unit
  # convention of (expected, actual) so failure messages read correctly.
  def test_sum
    assert_equal 45, (1..9).sum
  end

  def test_average
    # Range has no #length method, so the helper is exercised on an Array.
    assert_in_delta 5.0, (1..9).to_a.average, 0.01
  end

  def test_sample_variance
    # NOTE(review): 6.6666 is the population variance of 1..9 (squared
    # deviations sum to 60, divided by n = 9). Dividing by n - 1 would give
    # 7.5. Confirm which definition helpers.rb intends for "sample" variance.
    assert_in_delta 6.6666, (1..9).to_a.sample_variance, 0.0001
  end

  def test_standard_deviation
    assert_in_delta 2.5819, (1..9).to_a.standard_deviation, 0.0001
  end

  def test_enums
    assert_equal 0, EnumsTester::FOO
    assert_equal 1, EnumsTester::BAR
    assert_equal 2, EnumsTester::BAZ
    assert_equal 1, EnumsTester::ONE
    assert_equal 2, EnumsTester::TWO
    assert_equal 4, EnumsTester::FOUR
    assert_equal 8, EnumsTester::EIGHT
  end

  def test_to_underscore_sym
    assert_equal :this_is_an_ugly_string, 'THIS IS AN UGLY STRING'.to_underscore_sym
    assert_equal :this_is_an_ugly_string, 'this is an ugly string'.to_underscore_sym
    assert_equal :this_is_an_ugly_string, :"this is an ugly string".to_underscore_sym
    assert_equal :this_is_an_ugly_string, :"THIS IS AN UGLY STRING".to_underscore_sym
  end

  def test_symbolize_values
    assert_not_equal [:one, :two, :three], %w(one two three)
    assert_equal [:one, :two, :three], %w(one two three).symbolize_values
    arr = %w(one two three)
    arr.symbolize_values!
    assert_equal [:one, :two, :three], arr
  end

  def test_symbolize_keys_and_values
    assert_not_equal({:one => :two, :three => :four}, {"one" => "two", "three" => "four"})
    assert_equal({:one => :two, :three => :four}, {"one" => "two", "three" => "four"}.symbolize_keys_and_values)
    h = {"one" => "two", "three" => "four"}
    h.symbolize_keys_and_values!
    assert_equal({:one => :two, :three => :four}, h)
  end

  def test_normalize
    assert_equal [0.5, 0.5], [0.1, 0.1].normalize
    assert_equal [0.25, 0.25, 0.5], [2, 2, 4].normalize
    assert_equal [0.2, 0.2, 0.2, 0.2, 0.2], [1, 1, 1, 1, 1].normalize
    arr = [1, 1, 1, 1, 1]
    arr.normalize!
    assert_equal [0.2, 0.2, 0.2, 0.2, 0.2], arr
  end

  def test_ngrams
    two_ngram_array = ["TH", "HI", "IS", "S ", " I", "IS", "S ", " A", "A ", " S", "ST", "TR", "RI", "IN", "NG"]
    assert_equal two_ngram_array, "THIS IS A STRING".ngrams(2)
    three_ngram_array = ["THI", "HIS", "IS ", "S I", " IS", "IS ", "S A", " A ", "A S", " ST", "STR", "TRI", "RIN", "ING"]
    assert_equal three_ngram_array, "THIS IS A STRING".ngrams(3)
    four_ngram_array = ["THIS", "HIS ", "IS I", "S IS", " IS ", "IS A", "S A ", " A S", "A ST", " STR", "STRI", "TRIN", "RING"]
    assert_equal four_ngram_array, "THIS IS A STRING".ngrams(4)
    five_ngram_array = ["THIS ", "HIS I", "IS IS", "S IS ", " IS A", "IS A ", "S A S", " A ST", "A STR", " STRI", "STRIN", "TRING"]
    assert_equal five_ngram_array, "THIS IS A STRING".ngrams(5)
  end
end
|
@@ -0,0 +1,104 @@
|
|
1
|
+
require 'test/unit'
|
2
|
+
require File.dirname(__FILE__) + '/../lib/sbn'
|
3
|
+
|
4
|
+
class TestLearning < Test::Unit::TestCase # :nodoc:
  # A three-state :category variable with a StringVariable child :text.
  def setup
    @net = Sbn::Net.new("Categorization")
    @category = Sbn::Variable.new(@net, :category, [0.33, 0.33, 0.33], [:food, :groceries, :gas])
    @text = Sbn::StringVariable.new(@net, :text)
    @category.add_child(@text)
  end

  # Four samples per category should give each category roughly 1/3.
  def test_string_learning
    @net.learn([
      {:category => :food, :text => 'foo'},
      {:category => :food, :text => 'gro'},
      {:category => :food, :text => 'foo'},
      {:category => :food, :text => 'foo'},
      {:category => :groceries, :text => 'gro'},
      {:category => :groceries, :text => 'gro'},
      {:category => :groceries, :text => 'foo'},
      {:category => :groceries, :text => 'gro'},
      {:category => :gas, :text => 'gas'},
      {:category => :gas, :text => 'gas'},
      {:category => :gas, :text => 'gas'},
      {:category => :gas, :text => 'gas'}
    ])
    # Read the probabilities with #last. Popping them off the entries would
    # mutate the variable's own table, because a dup of it is shallow.
    food_prob, groceries_prob, gas_prob = @category.probability_table.map { |entry| entry.last }
    assert_in_delta 0.333, food_prob, 0.001
    assert_in_delta 0.333, groceries_prob, 0.001
    assert_in_delta 0.333, gas_prob, 0.001
  end

  def test_is_complete_evidence_eh
    assert !@text.is_complete_evidence?({})
    assert !@text.is_complete_evidence?(:text => "doughnuts")
    assert @text.is_complete_evidence?(:text => "doughnuts", :category => :food)
  end

  def test_var_add_sample_point
    assert_raise(RuntimeError) { @category.add_sample_point(:text => "apples") }

    # we have to add at least one sample point to initialize the container
    @category.add_sample_point(:category => :groceries, :text => "albertsons")
    sample_point = {:category => :gas, :text => "gas n go"}
    sample_points = @category.instance_variable_get(:@sample_points)
    assert !sample_points.include?(sample_point)
    @category.add_sample_point(sample_point)
    assert sample_points.include?(sample_point)
  end

  def test_var_set_probabilities_from_sample_points
    # test regular variable
    @category.add_sample_point(:category => :food, :text => "foo")
    @category.add_sample_point(:category => :food, :text => "foo")
    @category.add_sample_point(:category => :groceries, :text => 'gro')
    @category.add_sample_point(:category => :gas, :text => 'gas')
    @category.set_probabilities_from_sample_points!
    prob_table = @category.instance_variable_get(:@probability_table)
    category_probs = prob_table.transpose.last
    expected_category_probs = [0.4999, 0.2499, 0.2499]
    # Compare computed floats with a tolerance rather than exact equality.
    assert_equal expected_category_probs.size, category_probs.size
    expected_category_probs.each_with_index do |expected, i|
      assert_in_delta expected, category_probs[i], 1e-6
    end

    # test numeric variable
    basicvar = Sbn::Variable.new(@net, :basicvar)
    numvar = Sbn::NumericVariable.new(@net, :numvar)
    numvar.add_parent(basicvar)
    numvar.add_sample_point(:basicvar => :true, :numvar => 1.0)
    numvar.add_sample_point(:basicvar => :false, :numvar => 2.0)
    numvar.add_sample_point(:basicvar => :true, :numvar => 3.0)
    numvar.add_sample_point(:basicvar => :false, :numvar => 4.0)
    numvar.add_sample_point(:basicvar => :true, :numvar => 5.0)
    numvar.set_probabilities_from_sample_points!
    prob_table = numvar.instance_variable_get(:@probability_table)
    probs = prob_table.transpose.last
    expected_probs = [0.0001, 0.0001, 0.333233333333333, 0.0001, 0.0001,
      0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.333233333333333, 0.0001,
      0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.333233333333333,
      0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.4999,
      0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001,
      0.4999, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001]
    probs.each_with_index { |p, i| assert_in_delta(expected_probs[i], p, 0.001) }
  end

  def test_accumulate_state_frequencies
    @category.add_sample_point(:category => :food, :text => "foo")
    @category.add_sample_point(:category => :food, :text => "foo")
    @category.add_sample_point(:category => :groceries, :text => 'gro')
    @category.add_sample_point(:category => :gas, :text => 'gas')
    # accumulate_state_frequencies is not public; send bypasses visibility
    # without evaluating a string.
    @category.send(:accumulate_state_frequencies)
    freq = @category.instance_variable_get(:@state_frequencies)
    assert_equal({[:groceries] => 1, [:gas] => 1, [:food] => 2}, freq)
  end

  # Every variable that keeps sample points receives the net-level sample.
  def test_net_add_sample_point
    set = {:category => :food, :text => "foo"}
    @net.add_sample_point(set)
    variables = @net.instance_variable_get(:@variables)
    variables.each do |_name, var|
      sample_points = var.instance_variable_get(:@sample_points)
      assert sample_points.include?(set) if sample_points
    end
  end
end
|
data/test/test_net.rb
ADDED
@@ -0,0 +1,136 @@
|
|
1
|
+
require 'test/unit'
|
2
|
+
require File.dirname(__FILE__) + '/../lib/sbn'
|
3
|
+
|
4
|
+
class TestNet < Test::Unit::TestCase # :nodoc:
  # The cloudy / sprinkler / rain / grass-wet network.
  def setup
    # Variable#seek_state samples with Kernel#rand, and these assertions use
    # a 0.1 tolerance, so query results can differ between runs. A fixed
    # seed makes every test start from the same random state, whatever order
    # the tests run in.
    srand(1234)
    @net = Sbn::Net.new("Grass Wetness Belief Net")
    @cloudy = Sbn::Variable.new(@net, :cloudy, [0.5, 0.5])
    @sprinkler = Sbn::Variable.new(@net, :sprinkler, [0.1, 0.9, 0.5, 0.5])
    @rain = Sbn::Variable.new(@net, :rain, [0.8, 0.2, 0.2, 0.8])
    @grass_wet = Sbn::Variable.new(@net, :grass_wet, [0.99, 0.01, 0.9, 0.1, 0.9, 0.1, 0.0, 1.0])
    @cloudy.add_child(@sprinkler)
    @cloudy.add_child(@rain)
    @sprinkler.add_child(@grass_wet)
    @rain.add_child(@grass_wet)
    @evidence = {:sprinkler => :false, :rain => :true}
  end

  # Two-variable network: var1 -> var2.
  def generate_simple_network
    net = Sbn::Net.new("Test")
    var1 = Sbn::Variable.new(net, :var1)
    var2 = Sbn::Variable.new(net, :var2, [0.25, 0.75, 0.75, 0.25])
    var2.add_parent(var1)
    net
  end

  def test_inference_on_grasswet_with_evidence_querying_grasswet
    @net.set_evidence @evidence
    probs = @net.query_variable(:grass_wet)
    assert_in_delta(0.9, probs[:true], 0.1)
    assert_in_delta(0.1, probs[:false], 0.1)
  end

  def test_inference_on_grasswet_with_evidence_querying_cloudy
    @net.set_evidence @evidence
    probs = @net.query_variable(:cloudy)
    assert_in_delta(0.8780487804878049, probs[:true], 0.1)
    assert_in_delta(0.12195121951219512, probs[:false], 0.1)
  end

  def test_inference_on_grasswet_without_evidence_querying_cloudy
    @net.set_evidence({})
    probs = @net.query_variable(:cloudy)
    assert_in_delta(0.5, probs[:true], 0.1)
    assert_in_delta(0.5, probs[:false], 0.1)
  end

  def test_inference_on_grasswet_without_evidence_querying_sprinkler
    @net.set_evidence({})
    probs = @net.query_variable(:sprinkler)
    assert_in_delta(0.3, probs[:true], 0.1)
    assert_in_delta(0.7, probs[:false], 0.1)
  end

  def test_inference_on_grasswet_without_evidence_querying_rain
    @net.set_evidence({})
    probs = @net.query_variable(:rain)
    assert_in_delta(0.5, probs[:true], 0.1)
    assert_in_delta(0.5, probs[:false], 0.1)
  end

  def test_inference_on_grasswet_without_evidence_querying_grasswet
    @net.set_evidence({})
    probs = @net.query_variable(:grass_wet)
    assert_in_delta(0.6471, probs[:true], 0.1)
    assert_in_delta(0.3529, probs[:false], 0.1)
  end

  # Exporting to XMLBIF and importing again yields an equal network.
  def test_import_export
    output = @net.to_xmlbif
    newnet = Sbn::Net.from_xmlbif(output)
    assert_equal @net, newnet
  end

  def test_equality
    net1 = generate_simple_network
    net2 = generate_simple_network
    net3 = Sbn::Net.new('Another Net')
    assert_equal net1, net2
    assert_not_equal net1, net3
  end

  def test_add_variable
    net = generate_simple_network
    assert net.instance_variable_get(:@variables).has_key?(:var1)
    assert net.instance_variable_get(:@variables).has_key?(:var2)
  end

  def test_query_variable
    net = generate_simple_network
    net.set_evidence :var1 => :true
    probs = net.query_variable(:var2)
    assert probs.has_key?(:true)
    assert probs.has_key?(:false)
    assert_in_delta(0.25, probs[:true], 0.1)
    assert_in_delta(0.75, probs[:false], 0.1)
  end

  def test_set_evidence
    net = generate_simple_network
    net.set_evidence :var2 => :false
    variables = net.instance_variable_get(:@variables)
    evidence = net.instance_variable_get(:@evidence)
    assert variables[:var2].set_in_evidence?(evidence)
    assert !variables[:var1].set_in_evidence?(evidence)
  end

  def test_symbolize_evidence
    # Create a network with each kind of variable and make sure evidence
    # transformation works.
    net = Sbn::Net.new("Test")
    basic_var = Sbn::Variable.new(net, :basic_var)
    string_var = Sbn::StringVariable.new(net, :string_var)
    num_var = Sbn::NumericVariable.new(net, :num_var, [0.5, 0.5], [1.0])
    string_var.add_sample_point({:basic_var => :true, :string_var => "test", :num_var => 1.5})

    # should not be able to set covariables directly
    assert_raise(RuntimeError) { net.set_evidence({:string_var_covar_1 => :true}) }

    net.set_evidence 'BASIC VAR' => 'true', 'string_var' => "TesT", 'num_var' => 3
    evidence = net.instance_variable_get(:@evidence)
    assert evidence.has_key?(:basic_var)
    assert !evidence.has_key?('BASIC VAR')
    assert evidence[:basic_var].is_a?(Symbol)
    assert_equal :true, evidence[:basic_var]

    assert evidence.has_key?(:string_var)
    assert !evidence.has_key?('string_var')
    assert evidence[:string_var].is_a?(String)
    assert_equal evidence[:string_var].downcase, evidence[:string_var]

    assert evidence.has_key?(:num_var)
    assert !evidence.has_key?('num_var')
    assert evidence[:num_var].is_a?(Float)
    assert_equal 3.0, evidence[:num_var]
  end
end
|