bayesnet 0.0.2 → 0.6.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/Gemfile +1 -1
- data/Gemfile.lock +6 -29
- data/README.md +5 -2
- data/Rakefile +5 -2
- data/bayesnet.gemspec +0 -1
- data/doc/morning-mood-model.png +0 -0
- data/lib/bayesnet/dsl.rb +4 -0
- data/lib/bayesnet/error.rb +2 -0
- data/lib/bayesnet/factor.rb +138 -41
- data/lib/bayesnet/graph.rb +114 -13
- data/lib/bayesnet/logging.rb +13 -0
- data/lib/bayesnet/node.rb +25 -9
- data/lib/bayesnet/parsers/bif.rb +2484 -0
- data/lib/bayesnet/parsers/bif.treetop +250 -0
- data/lib/bayesnet/parsers/builder.rb +37 -0
- data/lib/bayesnet/version.rb +1 -1
- data/lib/bayesnet.rb +7 -0
- metadata +6 -16
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 7591665046345784f55275c06d1129fd91ee3f098f3800b2c03b6f9bbfd8e172
|
4
|
+
data.tar.gz: ec9009ab90593d42fa2506a230e5900d5a39bebb1a7fbd874953d6c86022b2eb
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 5e668b431f55f9239ad3ae06cdc020098ff4a8b68f7934283d6f77a3969a014aa17f68df12b23013d1681e79fdecd8f4c8e4da105e9430a16d2ab8075bbcca7b
|
7
|
+
data.tar.gz: 75eceac300152cfa8d0ce736b16939e779f127dc844ea0c8ce8e1d0f363b04048107533f64c744420b2631fd3e8678d6812e15c6b770e2b37a598ae290af773a
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,18 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [0.6.0] - 2022-06-26
|
4
|
+
- Using the variable elimination algorithm to build distributions
|
5
|
+
|
6
|
+
## [0.5.0] - 2022-02-26
|
7
|
+
|
8
|
+
- Constructing networks out of the `.BIF` ([Interchange Format for Bayesian Networks](https://www.cs.washington.edu/dm/vfml/appendixes/bif.htm)) files.
|
9
|
+
- Fixing inference bug
|
10
|
+
- Child nodes of a network can be specified **before** their parents
|
11
|
+
|
12
|
+
## [0.0.3] - 2021-12-29
|
13
|
+
|
14
|
+
- Fixing terminology used in Factor class
|
15
|
+
|
3
16
|
## [0.0.2] - 2021-12-28
|
4
17
|
|
5
18
|
- README, CI/CD for Ruby 2.6, 2.7, 3.1 added
|
data/Gemfile
CHANGED
data/Gemfile.lock
CHANGED
@@ -1,12 +1,11 @@
|
|
1
1
|
PATH
|
2
2
|
remote: .
|
3
3
|
specs:
|
4
|
-
bayesnet (0.0
|
4
|
+
bayesnet (0.6.0)
|
5
5
|
|
6
6
|
GEM
|
7
7
|
remote: https://rubygems.org/
|
8
8
|
specs:
|
9
|
-
ast (2.4.2)
|
10
9
|
byebug (11.1.3)
|
11
10
|
coderay (1.1.3)
|
12
11
|
m (1.5.1)
|
@@ -14,38 +13,16 @@ GEM
|
|
14
13
|
rake (>= 0.9.2.2)
|
15
14
|
method_source (1.0.0)
|
16
15
|
minitest (5.15.0)
|
17
|
-
|
18
|
-
parser (3.0.3.2)
|
19
|
-
ast (~> 2.4.1)
|
16
|
+
polyglot (0.3.5)
|
20
17
|
pry (0.13.1)
|
21
18
|
coderay (~> 1.1)
|
22
19
|
method_source (~> 1.0)
|
23
20
|
pry-byebug (3.9.0)
|
24
21
|
byebug (~> 11.0)
|
25
22
|
pry (~> 0.13.0)
|
26
|
-
rainbow (3.0.0)
|
27
23
|
rake (13.0.6)
|
28
|
-
|
29
|
-
|
30
|
-
rubocop (1.23.0)
|
31
|
-
parallel (~> 1.10)
|
32
|
-
parser (>= 3.0.0.0)
|
33
|
-
rainbow (>= 2.2.2, < 4.0)
|
34
|
-
regexp_parser (>= 1.8, < 3.0)
|
35
|
-
rexml
|
36
|
-
rubocop-ast (>= 1.12.0, < 2.0)
|
37
|
-
ruby-progressbar (~> 1.7)
|
38
|
-
unicode-display_width (>= 1.4.0, < 3.0)
|
39
|
-
rubocop-ast (1.15.0)
|
40
|
-
parser (>= 3.0.1.1)
|
41
|
-
rubocop-performance (1.12.0)
|
42
|
-
rubocop (>= 1.7.0, < 2.0)
|
43
|
-
rubocop-ast (>= 0.4.0)
|
44
|
-
ruby-progressbar (1.11.0)
|
45
|
-
standard (1.5.0)
|
46
|
-
rubocop (= 1.23.0)
|
47
|
-
rubocop-performance (= 1.12.0)
|
48
|
-
unicode-display_width (2.1.0)
|
24
|
+
treetop (1.6.11)
|
25
|
+
polyglot (~> 0.3)
|
49
26
|
|
50
27
|
PLATFORMS
|
51
28
|
x86_64-darwin-19
|
@@ -57,7 +34,7 @@ DEPENDENCIES
|
|
57
34
|
minitest (~> 5.0)
|
58
35
|
pry-byebug (~> 3.9.0)
|
59
36
|
rake (~> 13.0)
|
60
|
-
|
37
|
+
treetop (~> 1.6)
|
61
38
|
|
62
39
|
BUNDLED WITH
|
63
|
-
2.
|
40
|
+
2.3.3
|
data/README.md
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
# Bayesnet
|
2
2
|
|
3
|
-
This gem provides an
|
3
|
+
This gem provides a DSL for constructing Bayesian networks and running basic inference queries. It can also parse the .BIF format ([The Interchange Format for Bayesian Networks](https://www.cs.washington.edu/dm/vfml/appendixes/bif.htm)).
|
4
|
+
|
4
5
|
|
5
6
|
### Example:
|
6
7
|
|
@@ -66,12 +67,14 @@ The inference is based on summing over joint distribution, i.e. it is the simple
|
|
66
67
|
most expensive way to calculate it. No optimization is implemented in this version; the code
|
67
68
|
is more a proof of API.
|
68
69
|
|
70
|
+
### [Another example](https://afurmanov.com/reducing-anxiety-with-bayesian-network) of using this gem
|
71
|
+
|
69
72
|
## Installation
|
70
73
|
|
71
74
|
Add this line to your application's Gemfile:
|
72
75
|
|
73
76
|
```ruby
|
74
|
-
|
77
|
+
gem 'bayesnet'
|
75
78
|
```
|
76
79
|
|
77
80
|
And then execute:
|
data/Rakefile
CHANGED
@@ -9,6 +9,9 @@ Rake::TestTask.new(:test) do |t|
|
|
9
9
|
t.test_files = FileList["test/**/*_test.rb"]
|
10
10
|
end
|
11
11
|
|
12
|
-
|
12
|
+
# Regenerates the Treetop-generated parser. Defined as a plain task so the
# commands run only when `rake regen-bif` is invoked (a TestTask block runs at
# Rakefile load time, i.e. on every rake invocation), and `sh` fails loudly
# instead of silently ignoring errors like backticks do.
desc "Regenerate lib/bayesnet/parsers/bif.rb from the Treetop grammar"
task "regen-bif" do
  rm_f "./lib/bayesnet/parsers/bif.rb"
  # `tt` (treetop's compiler) writes bif.rb next to the grammar file
  sh "tt ./lib/bayesnet/parsers/bif.treetop"
end

task default: %i[test]
|
data/bayesnet.gemspec
CHANGED
@@ -34,7 +34,6 @@ Gem::Specification.new do |spec|
|
|
34
34
|
spec.add_development_dependency "m", "~> 1.5.0"
|
35
35
|
spec.add_development_dependency "minitest", "~> 5.0"
|
36
36
|
spec.add_development_dependency "pry-byebug", "~> 3.9.0"
|
37
|
-
spec.add_development_dependency "standard", "~> 1.3"
|
38
37
|
|
39
38
|
# For more information and examples about making a new gem, checkout our
|
40
39
|
# guide at: https://bundler.io/guides/creating_gem.html
|
data/doc/morning-mood-model.png
CHANGED
Binary file
|
data/lib/bayesnet/dsl.rb
CHANGED
data/lib/bayesnet/error.rb
CHANGED
data/lib/bayesnet/factor.rb
CHANGED
@@ -1,89 +1,186 @@
|
|
1
|
+
# frozen_string_literal: true

module Bayesnet
  # Factor is a function of several variables (A, B, ...), where
  # every variable can take values from some finite set. It maps every
  # combination of variable values (a "context") to a value.
  class Factor
    # +++ Factor DSL +++
    #
    # Factor DSL entry point: evaluates `block` against a new, empty factor
    # and returns that factor.
    def self.build(&block)
      factor = new
      factor.instance_eval(&block)
      factor
    end

    # Factor DSL
    # Defining a variable with the list of its possible values looks like:
    # ```
    # Bayesnet::Factor.build do
    #   scope weather: %i[sunny cloudy]
    #   scope mood: %i[bad good]
    #   ...
    # ```
    # ^ this code defines two variables `weather` and `mood`, where
    # `weather` could be :sunny or :cloudy, and
    # `mood` could be :bad or :good
    #
    # The order in which variables are added is the order of values inside
    # every context key. Called without arguments, returns the scope hash
    # (variable name => possible values).
    def scope(var_name_to_values = nil)
      return @scope unless var_name_to_values

      # New variables change the name -> position mapping; drop the cached one.
      @index_by_var_name = nil
      @scope.merge!(var_name_to_values)
    end

    # Factor DSL
    # Specifies the factor value for some set of variable values, i.e.
    # ```
    # Bayesnet::Factor.build do
    #   scope weather: %i[sunny cloudy]
    #   scope mood: %i[bad good]
    #   val :sunny, :bad, 0.1
    #   ...
    # ```
    # ^ this code says the value of factor for [weather == :sunny, mood == :bad] is 0.1
    # The arguments may also be passed as one array: `val [:sunny, :bad, 0.1]`.
    def val(*context_and_val)
      context_and_val = context_and_val[0] if context_and_val.size == 1 && context_and_val[0].is_a?(Array)
      @vals[context_and_val[0..-2]] = context_and_val[-1]
    end
    # --- Factor DSL ---

    # List of variable names, in scope order
    def var_names
      @scope.keys
    end

    # Factor value accessor, i.e.
    # ```
    # factor = Bayesnet::Factor.build do
    #   scope weather: %i[sunny cloudy]
    #   scope mood: %i[bad good]
    #   val :sunny, :bad, 0.1
    #   ...
    # end
    # factor[:sunny, :bad]                # 0.1
    # factor[mood: :bad, weather: :sunny] # 0.1
    # ```
    # The hash form picks values in scope order and ignores keys that are not
    # variables of this factor. Returns nil for an undefined context.
    def [](*context)
      key = if context.size == 1 && context[0].is_a?(Hash)
        context[0].slice(*var_names).values
      else
        context
      end
      @vals[key]
    end

    # Returns all combinations of values of `var_names` (each combination is an
    # array ordered like the arguments); [] when no names are given.
    def contextes(*var_names)
      return [] if var_names.empty?

      @scope[var_names[0]].product(*var_names[1..].map { |var_name| @scope[var_name] })
    end

    # Returns all factor values, in insertion order
    def values
      @vals.values
    end

    # Returns a new normalized factor, i.e. one where the sum of all values is 1.0.
    # When the values sum to 0 the result contains non-finite values (NaN/Infinity).
    def normalize
      norm_factor = @vals.values.sum * 1.0
      self.class.new(@scope.clone, @vals.transform_values { |v| v / norm_factor })
    end

    # Returns a factor built as follows:
    # 1. Only entries whose values agree with `context` are kept
    # 2. Variables from `context` are dropped from the result (after step 1
    #    every kept entry has the same value for them)
    # The `context` argument is supposed to be evidence, somewhat like
    # `{weather: :sunny}`; keys that are not variables of this factor are ignored.
    def reduce_to(context)
      limited_context = context.slice(*scope.keys)
      # Copy instead of sharing the hashes so that mutating the result
      # (e.g. through the DSL methods) cannot change this factor.
      return self.class.new(@scope.clone, @vals.clone) if limited_context.empty?

      limited_scope = @scope.slice(*(@scope.keys - limited_context.keys))
      context_vals = limited_context.values
      indices = limited_context.keys.map { |k| index_by_var_name[k] }
      vals = @vals.select { |k, _v| indices.map { |i| k[i] } == context_vals }
      vals.transform_keys! { |k| delete_by_indices(k, indices) }

      self.class.new(limited_scope, vals)
    end

    # Returns a new factor defined over `var_names` (in that order); all other
    # variables get eliminated. For every combination of `var_names`'s values
    # the value of the new factor is the sum of the original factor's values
    # having compatible values.
    def marginalize(var_names)
      scope = @scope.slice(*var_names)

      indices = scope.keys.map { |k| index_by_var_name[k] }
      vals = @vals.group_by { |context, _val| indices.map { |i| context[i] } }
      vals.transform_values! { |v| v.map(&:last).sum }

      self.class.new(scope, vals)
    end

    # Returns a new factor with `var_name` summed out
    def eliminate(var_name)
      marginalize(var_names - [var_name])
    end

    # Returns the entries (a Hash of context => value) whose variables match
    # `subcontext`, e.g. `select(weather: :sunny)`
    def select(subcontext)
      names = var_names
      @vals.select do |context, _|
        names.zip(context).to_h.slice(*subcontext.keys) == subcontext
      end
    end

    # Factor product. The result is defined over the union of both scopes
    # (this factor's variables first) and its value for a context is the product
    # of both factors' values for the compatible contexts. Combinations that
    # either factor has no entry for are absent from the result.
    def *(other)
      common_var_names = var_names & other.var_names
      new_scope = scope.merge(other.scope)
      new_vals = {}
      names = var_names
      other_names = other.var_names
      other_groups = other.group_by_scope_values(common_var_names)
      group_by_scope_values(common_var_names).each do |common_values, entries|
        # `other` may have no entry compatible with these common values
        other_entries = other_groups.fetch(common_values, [])
        entries.product(other_entries).each do |(context1, value1), (context2, value2)|
          # Values in the new context must follow the variable order of
          # `new_scope`; merging name => value hashes the same way `new_scope`
          # was merged above guarantees it.
          val_by_name = names.zip(context1).to_h.merge(other_names.zip(context2).to_h)
          new_vals[val_by_name.values] = value1 * value2
        end
      end
      self.class.new(new_scope, new_vals)
    end

    # Groups entries ([context, value] pairs) by the values of the `scope_keys`
    # variables: returns { [values of scope_keys] => [[context, value], ...] }
    def group_by_scope_values(scope_keys)
      indices = scope_keys.map { |k| index_by_var_name[k] }
      @vals.group_by { |context, _val| indices.map { |i| context[i] } }
    end

    private

    # Returns a copy of `array` without the elements at positions `indices`
    # (unlike a nil-tombstone + compact approach, genuine nil values survive)
    def delete_by_indices(array, indices)
      array.reject.with_index { |_element, i| indices.include?(i) }
    end

    def initialize(scope = {}, vals = {})
      @scope = scope
      @vals = vals
    end

    # Memoized map of variable name => its position inside every context key;
    # reset by #scope whenever variables are added
    def index_by_var_name
      @index_by_var_name ||= @scope.keys.each_with_index.to_h
    end
  end
end
|
data/lib/bayesnet/graph.rb
CHANGED
@@ -1,33 +1,123 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
require "bayesnet/node"
|
2
4
|
|
3
5
|
module Bayesnet
|
6
|
+
# Directed acyclic graph of nodes, i.e. a Bayesian network
|
4
7
|
class Graph
|
8
|
+
include Bayesnet::Logging

attr_reader :nodes

# A new graph has no nodes; they are added through the #node DSL call.
def initialize
  @nodes = {}
end

# +++ Graph DSL +++

# Declares a node called `name` with parent node names `parents` and evaluates
# `block` in the context of the new node. Returns the new node.
def node(name, parents: [], &block)
  raise Error, "DSL error, #node requires a &block" if block.nil?

  @nodes[name] = Node.new(name, parents).tap { |new_node| new_node.instance_eval(&block) }
end
# --- Graph DSL ---

# Names of all nodes, in declaration order
def var_names
  nodes.keys
end
|
30
|
+
|
31
|
+
# Returns the normalized distribution reduced to `evidence` and marginalized
# over `over`. The variables of the result are in `over` order for both
# algorithms, so positional lookups like `result[*values]` are consistent.
#
# algorithm: :variables_elimination (default) or :brute_force (builds the
# whole joint distribution first); anything else raises.
def distribution(over: [], evidence: {}, algorithm: :variables_elimination)
  case algorithm
  when :brute_force
    joint_distribution
      .reduce_to(evidence)
      .marginalize(over)
      .normalize
  when :variables_elimination
    reduced_factors = nodes.values.map { |node| node.factor.reduce_to(evidence) }
    # Evidence variables are already gone after reduce_to; queried ones must stay
    excluded = evidence.keys | over.to_a
    variables_order = elimination_order.reject { |v| excluded.include?(v) }
    eliminate_variables(variables_order, reduced_factors)
      .marginalize(over) # reorders the remaining variables to follow `over`
      .normalize
  else
    raise "Unknown algorithm #{algorithm}"
  end
end
|
50
|
+
|
51
|
+
# Returns the order in which variable elimination sums variables out, chosen
# greedily with the min-neighbors heuristic (see #find_min_neighbor) on the
# moral graph. Requires resolved parents (`parent_nodes` being a Hash).
# The result is memoized in `@order`.
def elimination_order
  return @order if @order

  # Moralize: connect every node to its parents and every pair of co-parents.
  # Undirected edges are stored as 2-element Sets.
  edges = Set.new
  @nodes.each do |name, node|
    parents = node.parent_nodes.keys
    parents.each { |parent| edges.add([name, parent].to_set) }
    parents.combination(2) { |p1, p2| edges.add([p1, p2].to_set) }
  end

  # Build into a local and publish only when complete, so an exception part-way
  # (e.g. parents not resolved yet) does not cache a partial or empty order.
  order = []
  remaining_nodes = nodes.keys.to_set
  until remaining_nodes.empty?
    best_node = find_min_neighbor(remaining_nodes, edges)
    remaining_nodes.delete(best_node)
    order.push(best_node)
    clique = edges.select { |e| e.include?(best_node) }
    edges -= clique
    if edges.empty?
      # Only edges touching best_node were left: append the rest as is
      # (their relative order affects efficiency only)
      order.concat(remaining_nodes.to_a)
      remaining_nodes = Set.new
    end
    # Eliminating best_node connects all of its former neighbors (fill-in edges)
    neighbors = clique.map { |e| (e - [best_node]).first }
    neighbors.combination(2) { |n1, n2| edges.add([n1, n2].to_set) }
  end
  @order = order
end
|
21
81
|
|
22
|
-
def
|
23
|
-
|
24
|
-
|
82
|
+
# Returns the node of `remaining_nodes` that takes part in the fewest `edges`
# (each edge being a collection of two node names); ties go to the node
# enumerated first. Returns nil when `remaining_nodes` is empty.
def find_min_neighbor(remaining_nodes, edges)
  # `min_by` keeps the first minimum, and yields each name whole (the former
  # `|name, _|` destructuring mangled Array-valued node names)
  remaining_nodes.min_by { |name| edges.count { |edge| edge.include?(name) } }
end
|
94
|
+
|
95
|
+
# Sums variables out one at a time, in `variables_order`: all remaining
# factors mentioning the variable are multiplied and the variable is
# eliminated from that product. Returns the product of the factors left over
# (nil when `factors` is empty).
def eliminate_variables(variables_order, factors)
  logger.debug "Eliminating variables #{variables_order} from #{factors.size} factors #{factors.map(&:var_names)}"
  remaining_factors = factors.to_set
  variables_order.each do |var_name|
    logger.debug "Eliminating '#{var_name}'..."
    grouped_factors = remaining_factors.select { |f| f.var_names.include?(var_name) }
    # No factor mentions the variable, so there is nothing to sum out; without
    # this guard `reduce` would return nil and `nil.eliminate` would raise.
    next if grouped_factors.empty?

    remaining_factors -= grouped_factors
    logger.debug "Building new factor out of #{grouped_factors.size} factors having '#{var_name}' - #{grouped_factors.map(&:var_names)}"
    product_factor = grouped_factors.reduce(&:*)
    logger.debug "Removing variable from new factor"
    new_factor = product_factor.eliminate(var_name)
    logger.debug "New factor variables are #{new_factor.var_names}"
    remaining_factors.add(new_factor)
    logger.debug "The variable '#{var_name}' is eliminated"
  end
  logger.debug "Non-eliminated variables are #{remaining_factors.map(&:var_names).flatten.uniq}"
  result = remaining_factors.reduce(&:*)
  logger.debug "Eliminating is done"
  result
end
|
26
115
|
|
27
116
|
# This is a MAP query, i.e. maximum a posteriori:
# returns the value of `var_name` with the highest posterior probability when
# `evidence` is observed
def most_likely_value(var_name, evidence:)
  posterior_distribution = distribution(over: [var_name], evidence: evidence)
  # Look every candidate value up by key instead of zipping `contextes` with
  # `values`: that assumed the factor stores entries in the declared value
  # order and misaligned as soon as an entry was missing.
  candidates = posterior_distribution.contextes(var_name).map(&:first)
  candidates.max_by { |value| posterior_distribution[value] || 0.0 }
end
|
33
123
|
|
@@ -37,6 +127,7 @@ module Bayesnet
|
|
37
127
|
posterior_distribution[*over_vars.values]
|
38
128
|
end
|
39
129
|
|
130
|
+
# Builds the normalized product of all nodes' factors, i.e. the joint distribution
|
40
131
|
def joint_distribution
|
41
132
|
return @joint_distribution if @joint_distribution
|
42
133
|
|
@@ -47,17 +138,27 @@ module Bayesnet
|
|
47
138
|
|
48
139
|
factor = Factor.new
|
49
140
|
@nodes.each do |node_name, node|
|
50
|
-
factor.
|
141
|
+
factor.scope node_name => node.values
|
51
142
|
end
|
52
143
|
|
53
|
-
factor.
|
54
|
-
val_by_name = var_names.zip(
|
144
|
+
factor.contextes(*var_names).each do |context|
|
145
|
+
val_by_name = var_names.zip(context).to_h
|
55
146
|
val = nodes.values.reduce(1.0) do |prob, node|
|
56
147
|
prob * node.factor[val_by_name]
|
57
148
|
end
|
58
|
-
factor.val
|
149
|
+
factor.val context + [val]
|
59
150
|
end
|
60
151
|
@joint_distribution = factor.normalize
|
61
152
|
end
|
153
|
+
|
154
|
+
# Total number of free parameters over all nodes of the network
def parameters
  nodes.each_value.sum(&:parameters)
end
|
157
|
+
|
158
|
+
# Hands every node its parent nodes (name => node) so it can build its factor.
# Parents are referenced by name and may be declared after their children.
# Safe to call repeatedly: after the first call `parent_nodes` is already a
# Hash and its keys are reused. Raises Error for parent names that are not nodes.
def resolve_factors
  @nodes.values.each do |node|
    parent_names = node.parent_nodes
    parent_names = parent_names.keys if parent_names.is_a?(Hash)
    unknown = parent_names - @nodes.keys
    raise Error, "Unknown parent node(s) #{unknown} of '#{node.name}'" unless unknown.empty?

    # Hash#slice keeps the argument order, i.e. the declared parent order
    node.resolve_factor(@nodes.slice(*parent_names))
  end
end
|
62
163
|
end
|
63
164
|
end
|
data/lib/bayesnet/node.rb
CHANGED
@@ -6,8 +6,10 @@ module Bayesnet
|
|
6
6
|
@name = name
|
7
7
|
@parent_nodes = parent_nodes
|
8
8
|
@values = []
|
9
|
+
@factor = Factor.new
|
9
10
|
end
|
10
11
|
|
12
|
+
# +++ Node DSL +++
|
11
13
|
def values(hash_or_array = nil, &block)
|
12
14
|
case hash_or_array
|
13
15
|
when NilClass
|
@@ -16,7 +18,7 @@ module Bayesnet
|
|
16
18
|
@values = hash_or_array.keys
|
17
19
|
node = self
|
18
20
|
@factor = Factor.build do
|
19
|
-
|
21
|
+
scope node.name => node.values
|
20
22
|
hash_or_array.each do |value, probability|
|
21
23
|
val [value, probability]
|
22
24
|
end
|
@@ -24,25 +26,39 @@ module Bayesnet
|
|
24
26
|
when Array
|
25
27
|
raise Error, "DSL error, #values requires a &block when first argument is an Array" unless block
|
26
28
|
@values = hash_or_array
|
27
|
-
|
28
|
-
@factor = Factor.build do
|
29
|
-
var node.name => node.values
|
30
|
-
node.parent_nodes.each do |parent_node_name, parent_node|
|
31
|
-
var parent_node_name => parent_node.values
|
32
|
-
end
|
33
|
-
end
|
34
|
-
instance_eval(&block)
|
29
|
+
@factor = block
|
35
30
|
end
|
36
31
|
end
|
37
32
|
|
38
33
|
# Node DSL: evaluates the table definition (a series of `as` calls) against this node
def distributions(&table)
  instance_eval(&table)
end
|
36
|
+
# --- Node DSL ---
|
37
|
+
|
38
|
+
# Number of free parameters of this node's table: (number of values - 1) for
# every combination of parent values
def parameters
  parent_combinations = parent_nodes.values.map { |parent| parent.values.size }.reduce(1, :*)
  (values.size - 1) * parent_combinations
end
|
41
41
|
|
42
42
|
# Node DSL (inside `distributions`): records one row of the node's table.
# `distribution` lists one probability per node value, in the order of the
# node's values, and `given` lists the parent values in parent order.
# Raises Error when the number of probabilities does not match the values.
def as(distribution, given:)
  unless distribution.size == @values.size
    raise Error, "DSL error, #as requires #{@values.size} probabilities, got #{distribution.size}"
  end

  @values.zip(distribution).each do |value, probability|
    @factor.val([value] + given + [probability])
  end
end
|
47
|
+
|
48
|
+
# Stores the resolved parent nodes (name => node). When the node's table was
# given as a block (the Array form of `values`), builds the factor over the
# node and its parents (in that order) and runs the block to fill it in.
def resolve_factor(parent_nodes)
  @parent_nodes = parent_nodes
  return unless @factor.is_a?(Proc)

  table_definition = @factor
  node = self
  @factor = Factor.build do
    scope node.name => node.values
    node.parent_nodes.each { |parent_name, parent| scope parent_name => parent.values }
  end
  instance_eval(&table_definition)
end
|
62
|
+
|
47
63
|
end
|
48
64
|
end
|