bayesnet 0.0.2 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/Gemfile +1 -1
- data/Gemfile.lock +6 -29
- data/README.md +5 -2
- data/Rakefile +5 -2
- data/bayesnet.gemspec +0 -1
- data/doc/morning-mood-model.png +0 -0
- data/lib/bayesnet/dsl.rb +4 -0
- data/lib/bayesnet/error.rb +2 -0
- data/lib/bayesnet/factor.rb +138 -41
- data/lib/bayesnet/graph.rb +114 -13
- data/lib/bayesnet/logging.rb +13 -0
- data/lib/bayesnet/node.rb +25 -9
- data/lib/bayesnet/parsers/bif.rb +2484 -0
- data/lib/bayesnet/parsers/bif.treetop +250 -0
- data/lib/bayesnet/parsers/builder.rb +37 -0
- data/lib/bayesnet/version.rb +1 -1
- data/lib/bayesnet.rb +7 -0
- metadata +6 -16
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 7591665046345784f55275c06d1129fd91ee3f098f3800b2c03b6f9bbfd8e172
|
4
|
+
data.tar.gz: ec9009ab90593d42fa2506a230e5900d5a39bebb1a7fbd874953d6c86022b2eb
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 5e668b431f55f9239ad3ae06cdc020098ff4a8b68f7934283d6f77a3969a014aa17f68df12b23013d1681e79fdecd8f4c8e4da105e9430a16d2ab8075bbcca7b
|
7
|
+
data.tar.gz: 75eceac300152cfa8d0ce736b16939e779f127dc844ea0c8ce8e1d0f363b04048107533f64c744420b2631fd3e8678d6812e15c6b770e2b37a598ae290af773a
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,18 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [0.6.0] - 2022-06-26
|
4
|
+
- Using variables elimination algorithm to build a distribution
|
5
|
+
|
6
|
+
## [0.5.0] - 2022-02-26
|
7
|
+
|
8
|
+
- Constructing networks out of the `.BIF` ([Interchange Format for Bayesian Networks](https://www.cs.washington.edu/dm/vfml/appendixes/bif.htm)) files.
|
9
|
+
- Fixing inference bug
|
10
|
+
- Network children nodes could be specified ***before** their parents
|
11
|
+
|
12
|
+
## [0.0.3] - 2021-12-29
|
13
|
+
|
14
|
+
- Fixing terminology used in Factor class
|
15
|
+
|
3
16
|
## [0.0.2] - 2021-12-28
|
4
17
|
|
5
18
|
- README, CI/CD for Ruby 2.6, 2.7, 3.1 added
|
data/Gemfile
CHANGED
data/Gemfile.lock
CHANGED
@@ -1,12 +1,11 @@
|
|
1
1
|
PATH
|
2
2
|
remote: .
|
3
3
|
specs:
|
4
|
-
bayesnet (0.0
|
4
|
+
bayesnet (0.6.0)
|
5
5
|
|
6
6
|
GEM
|
7
7
|
remote: https://rubygems.org/
|
8
8
|
specs:
|
9
|
-
ast (2.4.2)
|
10
9
|
byebug (11.1.3)
|
11
10
|
coderay (1.1.3)
|
12
11
|
m (1.5.1)
|
@@ -14,38 +13,16 @@ GEM
|
|
14
13
|
rake (>= 0.9.2.2)
|
15
14
|
method_source (1.0.0)
|
16
15
|
minitest (5.15.0)
|
17
|
-
|
18
|
-
parser (3.0.3.2)
|
19
|
-
ast (~> 2.4.1)
|
16
|
+
polyglot (0.3.5)
|
20
17
|
pry (0.13.1)
|
21
18
|
coderay (~> 1.1)
|
22
19
|
method_source (~> 1.0)
|
23
20
|
pry-byebug (3.9.0)
|
24
21
|
byebug (~> 11.0)
|
25
22
|
pry (~> 0.13.0)
|
26
|
-
rainbow (3.0.0)
|
27
23
|
rake (13.0.6)
|
28
|
-
|
29
|
-
|
30
|
-
rubocop (1.23.0)
|
31
|
-
parallel (~> 1.10)
|
32
|
-
parser (>= 3.0.0.0)
|
33
|
-
rainbow (>= 2.2.2, < 4.0)
|
34
|
-
regexp_parser (>= 1.8, < 3.0)
|
35
|
-
rexml
|
36
|
-
rubocop-ast (>= 1.12.0, < 2.0)
|
37
|
-
ruby-progressbar (~> 1.7)
|
38
|
-
unicode-display_width (>= 1.4.0, < 3.0)
|
39
|
-
rubocop-ast (1.15.0)
|
40
|
-
parser (>= 3.0.1.1)
|
41
|
-
rubocop-performance (1.12.0)
|
42
|
-
rubocop (>= 1.7.0, < 2.0)
|
43
|
-
rubocop-ast (>= 0.4.0)
|
44
|
-
ruby-progressbar (1.11.0)
|
45
|
-
standard (1.5.0)
|
46
|
-
rubocop (= 1.23.0)
|
47
|
-
rubocop-performance (= 1.12.0)
|
48
|
-
unicode-display_width (2.1.0)
|
24
|
+
treetop (1.6.11)
|
25
|
+
polyglot (~> 0.3)
|
49
26
|
|
50
27
|
PLATFORMS
|
51
28
|
x86_64-darwin-19
|
@@ -57,7 +34,7 @@ DEPENDENCIES
|
|
57
34
|
minitest (~> 5.0)
|
58
35
|
pry-byebug (~> 3.9.0)
|
59
36
|
rake (~> 13.0)
|
60
|
-
|
37
|
+
treetop (~> 1.6)
|
61
38
|
|
62
39
|
BUNDLED WITH
|
63
|
-
2.
|
40
|
+
2.3.3
|
data/README.md
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
# Bayesnet
|
2
2
|
|
3
|
-
This gem provides an
|
3
|
+
This gem provides an DSL for constructing Bayesian networks and let to execute basic inference queries. It is also capable of parsing .BIF format ([The Interchange Format for Bayesian Networks](https://www.cs.washington.edu/dm/vfml/appendixes/bif.htm)).
|
4
|
+
|
4
5
|
|
5
6
|
### Example:
|
6
7
|
|
@@ -66,12 +67,14 @@ The inference is based on summing over joint distribution, i.e. it is the simple
|
|
66
67
|
most expensive way to calculate it. No optimization is implemented in this version; the code
|
67
68
|
is more a proof of API.
|
68
69
|
|
70
|
+
### [Another example](https://afurmanov.com/reducing-anxiety-with-bayesian-network) of using this gem
|
71
|
+
|
69
72
|
## Installation
|
70
73
|
|
71
74
|
Add this line to your application's Gemfile:
|
72
75
|
|
73
76
|
```ruby
|
74
|
-
|
77
|
+
em 'bayesnet'
|
75
78
|
```
|
76
79
|
|
77
80
|
And then execute:
|
data/Rakefile
CHANGED
@@ -9,6 +9,9 @@ Rake::TestTask.new(:test) do |t|
|
|
9
9
|
t.test_files = FileList["test/**/*_test.rb"]
|
10
10
|
end
|
11
11
|
|
12
|
-
|
12
|
+
Rake::TestTask.new("regen-bif") do |t|
|
13
|
+
`rm ./lib/bayesnet/parsers/bif.rb`
|
14
|
+
`tt ./lib/bayesnet/parsers/bif.treetop`
|
15
|
+
end
|
13
16
|
|
14
|
-
task default: %i[test
|
17
|
+
task default: %i[test]
|
data/bayesnet.gemspec
CHANGED
@@ -34,7 +34,6 @@ Gem::Specification.new do |spec|
|
|
34
34
|
spec.add_development_dependency "m", "~> 1.5.0"
|
35
35
|
spec.add_development_dependency "minitest", "~> 5.0"
|
36
36
|
spec.add_development_dependency "pry-byebug", "~> 3.9.0"
|
37
|
-
spec.add_development_dependency "standard", "~> 1.3"
|
38
37
|
|
39
38
|
# For more information and examples about making a new gem, checkout our
|
40
39
|
# guide at: https://bundler.io/guides/creating_gem.html
|
data/doc/morning-mood-model.png
CHANGED
Binary file
|
data/lib/bayesnet/dsl.rb
CHANGED
data/lib/bayesnet/error.rb
CHANGED
data/lib/bayesnet/factor.rb
CHANGED
@@ -1,89 +1,186 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
module Bayesnet
|
2
|
-
# Factor if a function of
|
4
|
+
# Factor if a function of several variables (A, B, ...), where
|
5
|
+
# every variable cold take values from some finite set
|
3
6
|
class Factor
|
7
|
+
# +++ Factor DSL +++
|
8
|
+
#
|
9
|
+
# Factor DSL entry point:
|
4
10
|
def self.build(&block)
|
5
11
|
factor = new
|
6
12
|
factor.instance_eval(&block)
|
7
13
|
factor
|
8
14
|
end
|
9
15
|
|
10
|
-
#
|
11
|
-
|
12
|
-
|
16
|
+
# Factor DSL
|
17
|
+
# Defining variable with list of its possible values looks like:
|
18
|
+
# ```
|
19
|
+
# Bayesnet::Factor.build do
|
20
|
+
# scope weather: %i[sunny cloudy]
|
21
|
+
# scope mood: %i[bad good]
|
22
|
+
# ...
|
23
|
+
# ```
|
24
|
+
# ^ this code defines to variables `weather` and `mood`, where
|
25
|
+
# `weather` could be :sunny or :cloudy, and
|
26
|
+
# `mood` could be :bad or :good
|
27
|
+
def scope(var_name_to_values = nil)
|
28
|
+
if var_name_to_values
|
29
|
+
@scope.merge!(var_name_to_values)
|
30
|
+
else
|
31
|
+
@scope
|
32
|
+
end
|
13
33
|
end
|
14
34
|
|
15
|
-
#
|
16
|
-
|
17
|
-
|
18
|
-
|
35
|
+
# Factor DSL
|
36
|
+
# Specifies factor value for some set of variable values, i.e.
|
37
|
+
# ```
|
38
|
+
# Bayesnet::Factor.build do
|
39
|
+
# scope weather: %i[sunny cloudy]
|
40
|
+
# scope mood: %i[bad good]
|
41
|
+
# val :sunny, :bad, 0.1
|
42
|
+
# ...
|
43
|
+
# ```
|
44
|
+
# ^ this code says the value of factor for [weather == :sunny, mood == :bad] is 0.1
|
45
|
+
def val(*context_and_val)
|
46
|
+
context_and_val = context_and_val[0] if context_and_val.size == 1 && context_and_val[0].is_a?(Array)
|
47
|
+
@vals[context_and_val[0..-2]] = context_and_val[-1]
|
19
48
|
end
|
49
|
+
# --- Factor DSL ---
|
20
50
|
|
51
|
+
# List of variable names
|
21
52
|
def var_names
|
22
|
-
@
|
53
|
+
@scope.keys
|
23
54
|
end
|
24
55
|
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
56
|
+
# accessor factor value, i.e
|
57
|
+
# ```
|
58
|
+
# factor = Bayesnet::Factor.build do
|
59
|
+
# scope weather: %i[sunny cloudy]
|
60
|
+
# scope mood: %i[bad good]
|
61
|
+
# val :sunny, :bad, 0.1
|
62
|
+
# ...
|
63
|
+
# end
|
64
|
+
# factor[:sunny, :bad] # 0.1
|
65
|
+
# ```
|
66
|
+
def [](*context)
|
67
|
+
key = if context.size == 1 && context[0].is_a?(Hash)
|
68
|
+
context[0].slice(*var_names).values
|
69
|
+
else
|
70
|
+
context
|
71
|
+
end
|
31
72
|
@vals[key]
|
32
73
|
end
|
33
74
|
|
34
|
-
|
35
|
-
|
36
|
-
end
|
37
|
-
|
38
|
-
def args(*var_names)
|
75
|
+
# returns all combinations of values of `var_names`
|
76
|
+
def contextes(*var_names)
|
39
77
|
return [] if var_names.empty?
|
40
|
-
|
78
|
+
|
79
|
+
@scope[var_names[0]].product(*var_names[1..].map { |var_name| @scope[var_name] })
|
41
80
|
end
|
42
81
|
|
82
|
+
# returns all possible values
|
43
83
|
def values
|
44
84
|
@vals.values
|
45
85
|
end
|
46
86
|
|
87
|
+
# returns new normalized factor, i.e. where sum of all values is 1.0
|
47
88
|
def normalize
|
48
89
|
vals = @vals.clone
|
49
90
|
norm_factor = vals.map(&:last).sum * 1.0
|
50
|
-
vals.each { |k,
|
51
|
-
self.class.new(@
|
91
|
+
vals.each { |k, _v| vals[k] /= norm_factor }
|
92
|
+
self.class.new(@scope.clone, vals)
|
52
93
|
end
|
53
94
|
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
95
|
+
# Returns factor built as follows:
|
96
|
+
# 1. Original factor gets filtered out by variables having values compatible with `context`
|
97
|
+
# 2. Returned factor does not have any variables from `context` (because they have
|
98
|
+
# same values, after step 1)
|
99
|
+
# The `context` argument supposed to be an evidence, somewhat like
|
100
|
+
# `{weather: :sunny}`
|
101
|
+
def reduce_to(context)
|
102
|
+
limited_context = context.slice(*scope.keys)
|
103
|
+
return self.class.new(@scope, @vals) if limited_context.empty?
|
104
|
+
limited_scope = @scope.slice(*(@scope.keys - limited_context.keys))
|
58
105
|
|
59
|
-
|
60
|
-
indices =
|
61
|
-
vals = @vals.select { |k,
|
62
|
-
vals.transform_keys! { |k| k
|
106
|
+
context_vals = limited_context.values
|
107
|
+
indices = limited_context.keys.map { |k| index_by_var_name[k] }
|
108
|
+
vals = @vals.select { |k, _v| indices.map { |i| k[i] } == context_vals }
|
109
|
+
vals.transform_keys! { |k| delete_by_indices(k, indices) }
|
63
110
|
|
64
|
-
self.class.new(
|
111
|
+
self.class.new(limited_scope, vals)
|
65
112
|
end
|
66
113
|
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
114
|
+
# Returns new context defined over `var_names`, all other variables
|
115
|
+
# get eliminated. For every combination of `var_names`'s values
|
116
|
+
# the value of new factor is defined by summing up values in original factor
|
117
|
+
# having compatible value
|
118
|
+
def marginalize(var_names)
|
119
|
+
scope = @scope.slice(*var_names)
|
120
|
+
|
121
|
+
indices = scope.keys.map { |k| index_by_var_name[k] }
|
122
|
+
vals = @vals.group_by { |context, _val| indices.map { |i| context[i] } }
|
71
123
|
vals.transform_values! { |v| v.map(&:last).sum }
|
72
|
-
|
73
|
-
|
124
|
+
|
125
|
+
self.class.new(scope, vals)
|
126
|
+
end
|
127
|
+
|
128
|
+
def eliminate(var_name)
|
129
|
+
keep_var_names = var_names
|
130
|
+
keep_var_names.delete(var_name)
|
131
|
+
marginalize(keep_var_names)
|
132
|
+
end
|
133
|
+
|
134
|
+
def select(subcontext)
|
135
|
+
@vals.select do |context, _|
|
136
|
+
var_names.zip(context).slice(subcontext.keys) == subcontext
|
137
|
+
end
|
138
|
+
end
|
139
|
+
|
140
|
+
def *(other)
|
141
|
+
common_scope = @scope.keys & other.scope.keys
|
142
|
+
new_scope = scope.merge(other.scope)
|
143
|
+
new_vals = {}
|
144
|
+
group1 = group_by_scope_values(common_scope)
|
145
|
+
group2 = other.group_by_scope_values(common_scope)
|
146
|
+
group1.each do |scope, vals1|
|
147
|
+
combo = vals1.product(group2[scope])
|
148
|
+
combo.each do |(val1, val2)|
|
149
|
+
# values in scope must match variables order in new_scope, i.e.
|
150
|
+
# they must match `new_scope.var_names`
|
151
|
+
# The code bellow ensures it by merging two hashes in the same
|
152
|
+
# wasy as `new_scope`` is constructed above
|
153
|
+
val_by_name1 = var_names.zip(val1.first).to_h
|
154
|
+
val_by_name2 = other.var_names.zip(val2.first).to_h
|
155
|
+
new_vals[val_by_name1.merge(val_by_name2).values] = val1.last*val2.last
|
156
|
+
end
|
157
|
+
end
|
158
|
+
Factor.new(new_scope, new_vals)
|
159
|
+
end
|
160
|
+
|
161
|
+
def group_by_scope_values(scope_keys)
|
162
|
+
indices = scope_keys.map { |k| index_by_var_name[k] }
|
163
|
+
@vals.group_by { |context, _val| indices.map { |i| context[i] } }
|
74
164
|
end
|
75
165
|
|
76
166
|
private
|
77
167
|
|
78
|
-
def
|
79
|
-
|
168
|
+
def delete_by_indices(array, indices)
|
169
|
+
result = array.dup
|
170
|
+
indices.map { |i| result[i] = nil }
|
171
|
+
result.compact
|
172
|
+
end
|
173
|
+
|
174
|
+
def initialize(scope = {}, vals = {})
|
175
|
+
@scope = scope
|
80
176
|
@vals = vals
|
81
177
|
end
|
82
178
|
|
83
179
|
def index_by_var_name
|
84
180
|
return @index_by_var_name if @index_by_var_name
|
181
|
+
|
85
182
|
@index_by_var_name = {}
|
86
|
-
@
|
183
|
+
@scope.each_with_index { |(k, _v), i| @index_by_var_name[k] = i }
|
87
184
|
@index_by_var_name
|
88
185
|
end
|
89
186
|
end
|
data/lib/bayesnet/graph.rb
CHANGED
@@ -1,33 +1,123 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
require "bayesnet/node"
|
2
4
|
|
3
5
|
module Bayesnet
|
6
|
+
# Acyclic graph
|
4
7
|
class Graph
|
8
|
+
include Bayesnet::Logging
|
9
|
+
|
5
10
|
attr_reader :nodes
|
6
11
|
|
7
12
|
def initialize
|
8
13
|
@nodes = {}
|
9
14
|
end
|
10
15
|
|
11
|
-
|
12
|
-
nodes.keys
|
13
|
-
end
|
14
|
-
|
16
|
+
# +++ Graph DSL +++
|
15
17
|
def node(name, parents: [], &block)
|
16
18
|
raise Error, "DSL error, #node requires a &block" unless block
|
17
|
-
|
19
|
+
|
20
|
+
node = Node.new(name, parents)
|
18
21
|
node.instance_eval(&block)
|
19
22
|
@nodes[name] = node
|
20
23
|
end
|
24
|
+
# --- Graph DSL ---
|
25
|
+
|
26
|
+
# returns names of all nodes
|
27
|
+
def var_names
|
28
|
+
nodes.keys
|
29
|
+
end
|
30
|
+
|
31
|
+
# returns normalized distribution reduced to `evidence`
|
32
|
+
# and marginalized over `over`
|
33
|
+
def distribution(over: [], evidence: {}, algorithm: :variables_elimination)
|
34
|
+
case algorithm
|
35
|
+
when :brute_force
|
36
|
+
joint_distribution
|
37
|
+
.reduce_to(evidence)
|
38
|
+
.marginalize(over)
|
39
|
+
.normalize
|
40
|
+
when :variables_elimination
|
41
|
+
reduced_factors = nodes.values.map(&:factor).map { |f| f.reduce_to(evidence) }
|
42
|
+
not_include_in_order = evidence.keys.to_set + over.to_set
|
43
|
+
variables_order = elimination_order.reject { |v| not_include_in_order.include?(v) }
|
44
|
+
distribution = eliminate_variables(variables_order, reduced_factors)
|
45
|
+
distribution.normalize
|
46
|
+
else
|
47
|
+
raise "Uknown algorithm #{algorithm}"
|
48
|
+
end
|
49
|
+
end
|
50
|
+
|
51
|
+
def elimination_order
|
52
|
+
return @order if @order
|
53
|
+
@order = []
|
54
|
+
edges = Set.new
|
55
|
+
@nodes.each do |name, node|
|
56
|
+
parents = node.parent_nodes.keys
|
57
|
+
parents.each { |p| edges.add([name, p].to_set) }
|
58
|
+
parents.combination(2) { |p1, p2| edges.add([p1, p2].to_set) }
|
59
|
+
end
|
60
|
+
# edges now are moralized graph of `self`, just represented differently as
|
61
|
+
# set of edges
|
62
|
+
|
63
|
+
remaining_nodes = nodes.keys.to_set
|
64
|
+
until remaining_nodes.empty?
|
65
|
+
best_node = find_min_neighbor(remaining_nodes, edges)
|
66
|
+
remaining_nodes.delete(best_node)
|
67
|
+
@order.push(best_node)
|
68
|
+
clique = edges.select { |e| e.include?(best_node) }
|
69
|
+
edges -= clique
|
70
|
+
if edges.empty? #i.e. clique is the last edge
|
71
|
+
@order += remaining_nodes.to_a
|
72
|
+
remaining_nodes = Set.new
|
73
|
+
end
|
74
|
+
clique.
|
75
|
+
map { |e| e.delete(best_node) }.
|
76
|
+
map(&:first).
|
77
|
+
combination(2) { |p1, p2| edges.add([p1,p2].to_set) }
|
78
|
+
end
|
79
|
+
@order
|
80
|
+
end
|
21
81
|
|
22
|
-
def
|
23
|
-
|
24
|
-
|
82
|
+
def find_min_neighbor(remaining_nodes, edges)
|
83
|
+
result = nil
|
84
|
+
min_neighbors = nil
|
85
|
+
remaining_nodes.each do |name, _|
|
86
|
+
neighbors = edges.count { |e| e.include?(name) }
|
87
|
+
if min_neighbors.nil? || neighbors < min_neighbors
|
88
|
+
min_neighbors = neighbors
|
89
|
+
result = name
|
90
|
+
end
|
91
|
+
end
|
92
|
+
result
|
93
|
+
end
|
94
|
+
|
95
|
+
def eliminate_variables(variables_order, factors)
|
96
|
+
logger.debug "Eliminating variables #{variables_order} from #{factors.size} factors #{factors.map(&:var_names)}"
|
97
|
+
remaining_factors = factors.to_set
|
98
|
+
variables_order.each do |var_name|
|
99
|
+
logger.debug "Eliminating '#{var_name}'..."
|
100
|
+
grouped_factors = remaining_factors.select { |f| f.var_names.include?(var_name) }
|
101
|
+
remaining_factors -= grouped_factors
|
102
|
+
logger.debug "Building new factor out of #{grouped_factors.size} factors having '#{var_name}' - #{grouped_factors.map(&:var_names)}"
|
103
|
+
product_factor = grouped_factors.reduce(&:*)
|
104
|
+
logger.debug "Removing variable from new factor"
|
105
|
+
new_factor = product_factor.eliminate(var_name)
|
106
|
+
logger.debug "New factor variables are #{new_factor.var_names}"
|
107
|
+
remaining_factors.add(new_factor)
|
108
|
+
logger.debug "The variable '#{var_name}' is elminated"
|
109
|
+
end
|
110
|
+
logger.debug "Non-eliminated variables are #{remaining_factors.map(&:var_names).flatten.uniq}"
|
111
|
+
result = remaining_factors.reduce(&:*)
|
112
|
+
logger.debug "Eliminating is done"
|
113
|
+
result
|
25
114
|
end
|
26
115
|
|
27
116
|
# This is MAP query, i.e. Maximum a Posteriory
|
117
|
+
# returns value of `var_name` having maximum likelihood, when `evidence` is observed
|
28
118
|
def most_likely_value(var_name, evidence:)
|
29
119
|
posterior_distribution = distribution(over: [var_name], evidence: evidence)
|
30
|
-
mode = posterior_distribution.
|
120
|
+
mode = posterior_distribution.contextes(var_name).zip(posterior_distribution.values).max_by(&:last)
|
31
121
|
mode.first.first
|
32
122
|
end
|
33
123
|
|
@@ -37,6 +127,7 @@ module Bayesnet
|
|
37
127
|
posterior_distribution[*over_vars.values]
|
38
128
|
end
|
39
129
|
|
130
|
+
# Essentially it builds product of all node's factors
|
40
131
|
def joint_distribution
|
41
132
|
return @joint_distribution if @joint_distribution
|
42
133
|
|
@@ -47,17 +138,27 @@ module Bayesnet
|
|
47
138
|
|
48
139
|
factor = Factor.new
|
49
140
|
@nodes.each do |node_name, node|
|
50
|
-
factor.
|
141
|
+
factor.scope node_name => node.values
|
51
142
|
end
|
52
143
|
|
53
|
-
factor.
|
54
|
-
val_by_name = var_names.zip(
|
144
|
+
factor.contextes(*var_names).each do |context|
|
145
|
+
val_by_name = var_names.zip(context).to_h
|
55
146
|
val = nodes.values.reduce(1.0) do |prob, node|
|
56
147
|
prob * node.factor[val_by_name]
|
57
148
|
end
|
58
|
-
factor.val
|
149
|
+
factor.val context + [val]
|
59
150
|
end
|
60
151
|
@joint_distribution = factor.normalize
|
61
152
|
end
|
153
|
+
|
154
|
+
def parameters
|
155
|
+
nodes.values.map(&:parameters).sum
|
156
|
+
end
|
157
|
+
|
158
|
+
def resolve_factors
|
159
|
+
@nodes.values.each do |node|
|
160
|
+
node.resolve_factor(@nodes.slice(*node.parent_nodes))
|
161
|
+
end
|
162
|
+
end
|
62
163
|
end
|
63
164
|
end
|
data/lib/bayesnet/node.rb
CHANGED
@@ -6,8 +6,10 @@ module Bayesnet
|
|
6
6
|
@name = name
|
7
7
|
@parent_nodes = parent_nodes
|
8
8
|
@values = []
|
9
|
+
@factor = Factor.new
|
9
10
|
end
|
10
11
|
|
12
|
+
# +++ Node DSL +++
|
11
13
|
def values(hash_or_array = nil, &block)
|
12
14
|
case hash_or_array
|
13
15
|
when NilClass
|
@@ -16,7 +18,7 @@ module Bayesnet
|
|
16
18
|
@values = hash_or_array.keys
|
17
19
|
node = self
|
18
20
|
@factor = Factor.build do
|
19
|
-
|
21
|
+
scope node.name => node.values
|
20
22
|
hash_or_array.each do |value, probability|
|
21
23
|
val [value, probability]
|
22
24
|
end
|
@@ -24,25 +26,39 @@ module Bayesnet
|
|
24
26
|
when Array
|
25
27
|
raise Error, "DSL error, #values requires a &block when first argument is an Array" unless block
|
26
28
|
@values = hash_or_array
|
27
|
-
|
28
|
-
@factor = Factor.build do
|
29
|
-
var node.name => node.values
|
30
|
-
node.parent_nodes.each do |parent_node_name, parent_node|
|
31
|
-
var parent_node_name => parent_node.values
|
32
|
-
end
|
33
|
-
end
|
34
|
-
instance_eval(&block)
|
29
|
+
@factor = block
|
35
30
|
end
|
36
31
|
end
|
37
32
|
|
38
33
|
def distributions(&block)
|
39
34
|
instance_eval(&block)
|
40
35
|
end
|
36
|
+
# --- Node DSL ---
|
37
|
+
|
38
|
+
def parameters
|
39
|
+
(values.size - 1) * parent_nodes.values.reduce(1) { |mul, n| mul * n.values.size }
|
40
|
+
end
|
41
41
|
|
42
42
|
def as(distribution, given:)
|
43
43
|
@values.zip(distribution).each do |value, probability|
|
44
44
|
@factor.val [value] + given + [probability]
|
45
45
|
end
|
46
46
|
end
|
47
|
+
|
48
|
+
def resolve_factor(parent_nodes)
|
49
|
+
@parent_nodes = parent_nodes
|
50
|
+
if @factor.is_a?(Proc)
|
51
|
+
proc = @factor
|
52
|
+
node = self
|
53
|
+
@factor = Factor.build do
|
54
|
+
scope node.name => node.values
|
55
|
+
node.parent_nodes.each do |parent_node_name, parent_node|
|
56
|
+
scope parent_node_name => parent_node.values
|
57
|
+
end
|
58
|
+
end
|
59
|
+
instance_eval(&proc)
|
60
|
+
end
|
61
|
+
end
|
62
|
+
|
47
63
|
end
|
48
64
|
end
|