bayesnet 0.0.2 → 0.6.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9f17178cdfa472aea14e7769f1831f1903ceee2613b6e7b71e1135990dc974ab
4
- data.tar.gz: c55cb95134d4dbb479e89839be7a4634b6bc6a8be52c5c066f1cc1ad21c584a7
3
+ metadata.gz: 7591665046345784f55275c06d1129fd91ee3f098f3800b2c03b6f9bbfd8e172
4
+ data.tar.gz: ec9009ab90593d42fa2506a230e5900d5a39bebb1a7fbd874953d6c86022b2eb
5
5
  SHA512:
6
- metadata.gz: fda426d9dd319b2ad8300a8ba59acae511740cc271729f40b36ce7a7765d8019a901d9d83a2e955ccd964901ee604a3ad96fba092ce41394da5980a77d1bb198
7
- data.tar.gz: 03d4beb66b11ba684007dd49f67ef2ae085ecd5c6e5ca9b8b63327eee109ada3b1f389731543fa6d970bff110230231bc1533022640367a906f57dec1c41bc6d
6
+ metadata.gz: 5e668b431f55f9239ad3ae06cdc020098ff4a8b68f7934283d6f77a3969a014aa17f68df12b23013d1681e79fdecd8f4c8e4da105e9430a16d2ab8075bbcca7b
7
+ data.tar.gz: 75eceac300152cfa8d0ce736b16939e779f127dc844ea0c8ce8e1d0f363b04048107533f64c744420b2631fd3e8678d6812e15c6b770e2b37a598ae290af773a
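Reviewer note: the digests listed above can be recomputed locally to check a downloaded package. A minimal sketch, assuming the `metadata.gz` and `data.tar.gz` members have already been extracted from the `.gem` archive. The helper name `file_digests` is hypothetical, and the demo hashes a throwaway temporary file rather than real gem contents.

```ruby
require "digest"
require "tempfile"

# Hypothetical helper: returns the SHA256 and SHA512 hex digests of a file,
# the two algorithms listed in checksums.yaml.
def file_digests(path)
  {
    "SHA256" => Digest::SHA256.file(path).hexdigest,
    "SHA512" => Digest::SHA512.file(path).hexdigest
  }
end

# Demo on a temporary file; substitute the path of an extracted archive member.
DEMO_DIGESTS = Tempfile.create("demo") do |f|
  f.write("abc")
  f.flush
  file_digests(f.path)
end

puts DEMO_DIGESTS["SHA256"]
puts DEMO_DIGESTS["SHA512"]
```

Comparing the printed values against the `+` lines of `checksums.yaml` would confirm that the files match the published 0.6.0 release.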
data/CHANGELOG.md CHANGED
@@ -1,5 +1,18 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [0.6.0] - 2022-06-26
4
+ - Using the variable elimination algorithm to build a distribution
5
+
6
+ ## [0.5.0] - 2022-02-26
7
+
8
+ - Constructing networks out of the `.BIF` ([Interchange Format for Bayesian Networks](https://www.cs.washington.edu/dm/vfml/appendixes/bif.htm)) files.
9
+ - Fixing inference bug
10
+ - Network child nodes can be specified **before** their parents
11
+
12
+ ## [0.0.3] - 2021-12-29
13
+
14
+ - Fixing terminology used in Factor class
15
+
3
16
  ## [0.0.2] - 2021-12-28
4
17
 
5
18
  - README, CI/CD for Ruby 2.6, 2.7, 3.1 added
data/Gemfile CHANGED
@@ -6,10 +6,10 @@ source "https://rubygems.org"
6
6
  gemspec
7
7
 
8
8
  gem "rake", "~> 13.0"
9
+ gem "treetop", "~> 1.6"
9
10
 
10
11
  group :development, :test do
11
12
  gem "m", "~> 1.5.0"
12
13
  gem "minitest", "~> 5.0"
13
14
  gem "pry-byebug", "~> 3.9.0"
14
- gem "standard", "~> 1.3"
15
15
  end
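Reviewer note: the `~>` operator used in this Gemfile is the RubyGems pessimistic version constraint. The sketch below uses `Gem::Requirement` and `Gem::Version` from RubyGems to show which versions `~> 1.6` admits. The candidate list is made up for illustration and is not a list of actual treetop releases.

```ruby
require "rubygems"

requirement = Gem::Requirement.new("~> 1.6")

# Illustrative candidate versions, not an actual list of treetop releases.
CANDIDATES = %w[1.5.3 1.6.0 1.6.11 1.9.0 2.0.0].freeze

# "~> 1.6" is shorthand for ">= 1.6" together with "< 2.0".
ACCEPTED = CANDIDATES.select do |v|
  requirement.satisfied_by?(Gem::Version.new(v))
end

puts ACCEPTED.inspect
```

A two-segment constraint such as `~> 1.6` allows minor and patch upgrades, while a three-segment one such as `~> 1.5.0` (used for `m` above) allows patch upgrades only.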
data/Gemfile.lock CHANGED
@@ -1,12 +1,11 @@
1
1
  PATH
2
2
  remote: .
3
3
  specs:
4
- bayesnet (0.0.2)
4
+ bayesnet (0.6.0)
5
5
 
6
6
  GEM
7
7
  remote: https://rubygems.org/
8
8
  specs:
9
- ast (2.4.2)
10
9
  byebug (11.1.3)
11
10
  coderay (1.1.3)
12
11
  m (1.5.1)
@@ -14,38 +13,16 @@ GEM
14
13
  rake (>= 0.9.2.2)
15
14
  method_source (1.0.0)
16
15
  minitest (5.15.0)
17
- parallel (1.21.0)
18
- parser (3.0.3.2)
19
- ast (~> 2.4.1)
16
+ polyglot (0.3.5)
20
17
  pry (0.13.1)
21
18
  coderay (~> 1.1)
22
19
  method_source (~> 1.0)
23
20
  pry-byebug (3.9.0)
24
21
  byebug (~> 11.0)
25
22
  pry (~> 0.13.0)
26
- rainbow (3.0.0)
27
23
  rake (13.0.6)
28
- regexp_parser (2.2.0)
29
- rexml (3.2.5)
30
- rubocop (1.23.0)
31
- parallel (~> 1.10)
32
- parser (>= 3.0.0.0)
33
- rainbow (>= 2.2.2, < 4.0)
34
- regexp_parser (>= 1.8, < 3.0)
35
- rexml
36
- rubocop-ast (>= 1.12.0, < 2.0)
37
- ruby-progressbar (~> 1.7)
38
- unicode-display_width (>= 1.4.0, < 3.0)
39
- rubocop-ast (1.15.0)
40
- parser (>= 3.0.1.1)
41
- rubocop-performance (1.12.0)
42
- rubocop (>= 1.7.0, < 2.0)
43
- rubocop-ast (>= 0.4.0)
44
- ruby-progressbar (1.11.0)
45
- standard (1.5.0)
46
- rubocop (= 1.23.0)
47
- rubocop-performance (= 1.12.0)
48
- unicode-display_width (2.1.0)
24
+ treetop (1.6.11)
25
+ polyglot (~> 0.3)
49
26
 
50
27
  PLATFORMS
51
28
  x86_64-darwin-19
@@ -57,7 +34,7 @@ DEPENDENCIES
57
34
  minitest (~> 5.0)
58
35
  pry-byebug (~> 3.9.0)
59
36
  rake (~> 13.0)
60
- standard (~> 1.3)
37
+ treetop (~> 1.6)
61
38
 
62
39
  BUNDLED WITH
63
- 2.2.32
40
+ 2.3.3
data/README.md CHANGED
@@ -1,6 +1,7 @@
1
1
  # Bayesnet
2
2
 
3
- This gem provides an API for building a Bayesian network and executing some queries against it.
3
+ This gem provides a DSL for constructing Bayesian networks and executing basic inference queries. It can also parse the .BIF format ([The Interchange Format for Bayesian Networks](https://www.cs.washington.edu/dm/vfml/appendixes/bif.htm)).
4
+
4
5
 
5
6
  ### Example:
6
7
 
@@ -66,12 +67,14 @@ The inference is based on summing over joint distribution, i.e. it is the simple
66
67
  most expensive way to calculate it. No optimization is implemented in this version; the code
67
68
  is more a proof of API.
68
69
 
70
+ ### [Another example](https://afurmanov.com/reducing-anxiety-with-bayesian-network) of using this gem
71
+
69
72
  ## Installation
70
73
 
71
74
  Add this line to your application's Gemfile:
72
75
 
73
76
  ```ruby
74
- gem 'bayesnet'
77
+ gem 'bayesnet'
75
78
  ```
76
79
 
77
80
  And then execute:
data/Rakefile CHANGED
@@ -9,6 +9,9 @@ Rake::TestTask.new(:test) do |t|
9
9
  t.test_files = FileList["test/**/*_test.rb"]
10
10
  end
11
11
 
12
- require "standard/rake"
12
+ task "regen-bif" do
13
+ rm_f "./lib/bayesnet/parsers/bif.rb"
14
+ sh "tt ./lib/bayesnet/parsers/bif.treetop"
15
+ end
13
16
 
14
- task default: %i[test standard]
17
+ task default: %i[test]
data/bayesnet.gemspec CHANGED
@@ -34,7 +34,6 @@ Gem::Specification.new do |spec|
34
34
  spec.add_development_dependency "m", "~> 1.5.0"
35
35
  spec.add_development_dependency "minitest", "~> 5.0"
36
36
  spec.add_development_dependency "pry-byebug", "~> 3.9.0"
37
- spec.add_development_dependency "standard", "~> 1.3"
38
37
 
39
38
  # For more information and examples about making a new gem, checkout our
40
39
  # guide at: https://bundler.io/guides/creating_gem.html
Binary file
data/lib/bayesnet/dsl.rb CHANGED
@@ -1,10 +1,14 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require "bayesnet/graph"
2
4
 
3
5
  module Bayesnet
6
+ # Bayesnet::DSL.define ...
4
7
  module DSL
5
8
  def define(&block)
6
9
  graph = Graph.new
7
10
  graph.instance_eval(&block) if block
11
+ graph.resolve_factors
8
12
  graph
9
13
  end
10
14
  end
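Reviewer note: `define` follows a common Ruby DSL pattern. The block is evaluated with `instance_eval`, so bare method calls land on the object being built, and a finalization step then runs once every declaration has been seen. A standalone sketch of that pattern follows. `MiniGraph` and `define_graph` are hypothetical stand-ins, not the gem's `Graph` or `DSL`.

```ruby
# Hypothetical stand-in for a graph builder; not the gem's Graph class.
class MiniGraph
  attr_reader :nodes, :resolved

  def initialize
    @nodes = {}
    @resolved = false
  end

  # DSL method: callable without a receiver inside the block.
  def node(name, parents: [])
    @nodes[name] = parents
  end

  # Finalization step: runs after the whole block has been evaluated,
  # so it can check references regardless of declaration order.
  def resolve
    @nodes.each do |name, parents|
      missing = parents - @nodes.keys
      raise ArgumentError, "#{name} has unknown parents #{missing}" unless missing.empty?
    end
    @resolved = true
  end
end

def define_graph(&block)
  graph = MiniGraph.new
  graph.instance_eval(&block) if block
  graph.resolve
  graph
end

GRAPH = define_graph do
  node :mood, parents: [:weather] # child declared before its parent
  node :weather
end

puts GRAPH.nodes.inspect
```

Running the finalization after `instance_eval` rather than inside each `node` call is what lets declarations appear in any order.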
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  module Bayesnet
2
4
  class Error < StandardError
3
5
  end
@@ -1,89 +1,186 @@
1
+ # frozen_string_literal: true
2
+
1
3
  module Bayesnet
2
- # Factor if a function of sevaral variables (A, B, ...) each defined on values from finite set
4
+ # Factor is a function of several variables (A, B, ...), where
5
+ # every variable could take values from some finite set
3
6
  class Factor
7
+ # +++ Factor DSL +++
8
+ #
9
+ # Factor DSL entry point:
4
10
  def self.build(&block)
5
11
  factor = new
6
12
  factor.instance_eval(&block)
7
13
  factor
8
14
  end
9
15
 
10
- # Specifies variable name together with its values
11
- def var(var_name_to_values)
12
- @vars.merge!(var_name_to_values)
16
+ # Factor DSL
17
+ # Defining variable with list of its possible values looks like:
18
+ # ```
19
+ # Bayesnet::Factor.build do
20
+ # scope weather: %i[sunny cloudy]
21
+ # scope mood: %i[bad good]
22
+ # ...
23
+ # ```
24
+ # ^ this code defines two variables, `weather` and `mood`, where
25
+ # `weather` could be :sunny or :cloudy, and
26
+ # `mood` could be :bad or :good
27
+ def scope(var_name_to_values = nil)
28
+ if var_name_to_values
29
+ @scope.merge!(var_name_to_values)
30
+ else
31
+ @scope
32
+ end
13
33
  end
14
34
 
15
- # Specifies function values for args. Latest args is an function value, all previous are argument values
16
- def val(*args)
17
- args = args[0] if args.size == 1 && args[0].is_a?(Array)
18
- @vals[args[0..-2]] = args[-1]
35
+ # Factor DSL
36
+ # Specifies factor value for some set of variable values, i.e.
37
+ # ```
38
+ # Bayesnet::Factor.build do
39
+ # scope weather: %i[sunny cloudy]
40
+ # scope mood: %i[bad good]
41
+ # val :sunny, :bad, 0.1
42
+ # ...
43
+ # ```
44
+ # ^ this code says the value of factor for [weather == :sunny, mood == :bad] is 0.1
45
+ def val(*context_and_val)
46
+ context_and_val = context_and_val[0] if context_and_val.size == 1 && context_and_val[0].is_a?(Array)
47
+ @vals[context_and_val[0..-2]] = context_and_val[-1]
19
48
  end
49
+ # --- Factor DSL ---
20
50
 
51
+ # List of variable names
21
52
  def var_names
22
- @vars.keys
53
+ @scope.keys
23
54
  end
24
55
 
25
- def [](*args)
26
- key = if args.size == 1 && args[0].is_a?(Hash)
27
- args[0].slice(*var_names).values
28
- else
29
- args
30
- end
56
+ # Accesses a factor value, i.e.
57
+ # ```
58
+ # factor = Bayesnet::Factor.build do
59
+ # scope weather: %i[sunny cloudy]
60
+ # scope mood: %i[bad good]
61
+ # val :sunny, :bad, 0.1
62
+ # ...
63
+ # end
64
+ # factor[:sunny, :bad] # 0.1
65
+ # ```
66
+ def [](*context)
67
+ key = if context.size == 1 && context[0].is_a?(Hash)
68
+ context[0].slice(*var_names).values
69
+ else
70
+ context
71
+ end
31
72
  @vals[key]
32
73
  end
33
74
 
34
- def self.from_distribution(var_distribution)
35
- self.class.new(var_distribution.keys, var_distribution.values.map(&:to_a))
36
- end
37
-
38
- def args(*var_names)
75
+ # returns all combinations of values of `var_names`
76
+ def contextes(*var_names)
39
77
  return [] if var_names.empty?
40
- @vars[var_names[0]].product(*var_names[1..].map { |var_name| @vars[var_name] })
78
+
79
+ @scope[var_names[0]].product(*var_names[1..].map { |var_name| @scope[var_name] })
41
80
  end
42
81
 
82
+ # returns all possible values
43
83
  def values
44
84
  @vals.values
45
85
  end
46
86
 
87
+ # returns new normalized factor, i.e. where sum of all values is 1.0
47
88
  def normalize
48
89
  vals = @vals.clone
49
90
  norm_factor = vals.map(&:last).sum * 1.0
50
- vals.each { |k, v| vals[k] /= norm_factor }
51
- self.class.new(@vars.clone, vals)
91
+ vals.each { |k, _v| vals[k] /= norm_factor }
92
+ self.class.new(@scope.clone, vals)
52
93
  end
53
94
 
54
- def limit_by(evidence)
55
- # todo: use Hash#except when Ruby 2.6 support no longer needed
56
- evidence_keys_set = evidence.keys.to_set
57
- vars = @vars.reject { |k, _| evidence_keys_set.include?(k) }
95
+ # Returns factor built as follows:
96
+ # 1. Original factor gets filtered out by variables having values compatible with `context`
97
+ # 2. Returned factor does not have any variables from `context` (because they have
98
+ # same values, after step 1)
99
+ # The `context` argument is supposed to be evidence, something like
100
+ # `{weather: :sunny}`
101
+ def reduce_to(context)
102
+ limited_context = context.slice(*scope.keys)
103
+ return self.class.new(@scope, @vals) if limited_context.empty?
104
+ limited_scope = @scope.slice(*(@scope.keys - limited_context.keys))
58
105
 
59
- evidence_vals = evidence.values
60
- indices = evidence.keys.map { |k| index_by_var_name[k] }
61
- vals = @vals.select { |k, v| indices.map { |i| k[i] } == evidence_vals }
62
- vals.transform_keys! { |k| k - evidence_vals }
106
+ context_vals = limited_context.values
107
+ indices = limited_context.keys.map { |k| index_by_var_name[k] }
108
+ vals = @vals.select { |k, _v| indices.map { |i| k[i] } == context_vals }
109
+ vals.transform_keys! { |k| delete_by_indices(k, indices) }
63
110
 
64
- self.class.new(vars, vals)
111
+ self.class.new(limited_scope, vals)
65
112
  end
66
113
 
67
- def reduce(over_vars)
68
- vars = @vars.slice(*over_vars)
69
- indices = vars.keys.map { |k| index_by_var_name[k] }
70
- vals = @vals.group_by { |args, val| indices.map { |i| args[i] } }
114
+ # Returns new factor defined over `var_names`, all other variables
115
+ # get eliminated. For every combination of `var_names`'s values
116
+ # the value of new factor is defined by summing up values in original factor
117
+ # having compatible value
118
+ def marginalize(var_names)
119
+ scope = @scope.slice(*var_names)
120
+
121
+ indices = scope.keys.map { |k| index_by_var_name[k] }
122
+ vals = @vals.group_by { |context, _val| indices.map { |i| context[i] } }
71
123
  vals.transform_values! { |v| v.map(&:last).sum }
72
- reduced = self.class.new(vars, vals)
73
- reduced.normalize
124
+
125
+ self.class.new(scope, vals)
126
+ end
127
+
128
+ def eliminate(var_name)
129
+ keep_var_names = var_names
130
+ keep_var_names.delete(var_name)
131
+ marginalize(keep_var_names)
132
+ end
133
+
134
+ def select(subcontext)
135
+ @vals.select do |context, _|
136
+ var_names.zip(context).to_h.slice(*subcontext.keys) == subcontext
137
+ end
138
+ end
139
+
140
+ def *(other)
141
+ common_scope = @scope.keys & other.scope.keys
142
+ new_scope = scope.merge(other.scope)
143
+ new_vals = {}
144
+ group1 = group_by_scope_values(common_scope)
145
+ group2 = other.group_by_scope_values(common_scope)
146
+ group1.each do |scope, vals1|
147
+ combo = vals1.product(group2[scope])
148
+ combo.each do |(val1, val2)|
149
+ # values in scope must match variables order in new_scope, i.e.
150
+ # they must match `new_scope.var_names`
151
+ # The code below ensures it by merging two hashes in the same
152
+ # way as `new_scope` is constructed above
153
+ val_by_name1 = var_names.zip(val1.first).to_h
154
+ val_by_name2 = other.var_names.zip(val2.first).to_h
155
+ new_vals[val_by_name1.merge(val_by_name2).values] = val1.last*val2.last
156
+ end
157
+ end
158
+ Factor.new(new_scope, new_vals)
159
+ end
160
+
161
+ def group_by_scope_values(scope_keys)
162
+ indices = scope_keys.map { |k| index_by_var_name[k] }
163
+ @vals.group_by { |context, _val| indices.map { |i| context[i] } }
74
164
  end
75
165
 
76
166
  private
77
167
 
78
- def initialize(vars = {}, vals = {})
79
- @vars = vars
168
+ def delete_by_indices(array, indices)
169
+ result = array.dup
170
+ indices.map { |i| result[i] = nil }
171
+ result.compact
172
+ end
173
+
174
+ def initialize(scope = {}, vals = {})
175
+ @scope = scope
80
176
  @vals = vals
81
177
  end
82
178
 
83
179
  def index_by_var_name
84
180
  return @index_by_var_name if @index_by_var_name
181
+
85
182
  @index_by_var_name = {}
86
- @vars.each_with_index { |(k, v), i| @index_by_var_name[k] = i }
183
+ @scope.each_with_index { |(k, _v), i| @index_by_var_name[k] = i }
87
184
  @index_by_var_name
88
185
  end
89
186
  end
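Reviewer note: the two operations that variable elimination relies on in this hunk are the factor product (`*`) and summing variables out (`marginalize`). The sketch below reproduces both on plain hashes so the arithmetic can be followed by hand. The hash layout, the helper names `multiply` and `marginalize`, and all probabilities are invented for illustration and are not the gem's `Factor` implementation.

```ruby
# A factor here is a plain Hash: { vars: [names], table: { [values...] => number } }.
# This is an illustrative representation, not the gem's Factor class.

# Pointwise product: rows are combined when they agree on all shared variables.
def multiply(f, g)
  vars = f[:vars] | g[:vars]
  table = {}
  f[:table].each do |ctx_f, val_f|
    row_f = f[:vars].zip(ctx_f).to_h
    g[:table].each do |ctx_g, val_g|
      row_g = g[:vars].zip(ctx_g).to_h
      shared = row_f.keys & row_g.keys
      next unless shared.all? { |v| row_f[v] == row_g[v] }

      merged = row_f.merge(row_g)
      table[vars.map { |v| merged[v] }] = val_f * val_g
    end
  end
  { vars: vars, table: table }
end

# Sums out every variable not listed in `keep`.
def marginalize(f, keep)
  indices = keep.map { |v| f[:vars].index(v) }
  table = Hash.new(0.0)
  f[:table].each { |ctx, val| table[indices.map { |i| ctx[i] }] += val }
  { vars: keep, table: table }
end

P_WEATHER = {
  vars: [:weather],
  table: { [:sunny] => 0.5, [:cloudy] => 0.5 }
}.freeze

# P(mood | weather); numbers are invented for the example.
P_MOOD = {
  vars: %i[mood weather],
  table: {
    %i[good sunny] => 0.9, %i[bad sunny] => 0.1,
    %i[good cloudy] => 0.7, %i[bad cloudy] => 0.3
  }
}.freeze

JOINT = multiply(P_WEATHER, P_MOOD)
MOOD_MARGINAL = marginalize(JOINT, [:mood])

puts MOOD_MARGINAL[:table].inspect
```

With these numbers the marginal for `:good` is 0.5 × 0.9 + 0.5 × 0.7, up to floating-point rounding, which is the sum-of-products that eliminating `weather` computes.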
@@ -1,33 +1,123 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require "bayesnet/node"
2
4
 
3
5
  module Bayesnet
6
+ # Acyclic graph
4
7
  class Graph
8
+ include Bayesnet::Logging
9
+
5
10
  attr_reader :nodes
6
11
 
7
12
  def initialize
8
13
  @nodes = {}
9
14
  end
10
15
 
11
- def var_names
12
- nodes.keys
13
- end
14
-
16
+ # +++ Graph DSL +++
15
17
  def node(name, parents: [], &block)
16
18
  raise Error, "DSL error, #node requires a &block" unless block
17
- node = Node.new(name, @nodes.slice(*parents))
19
+
20
+ node = Node.new(name, parents)
18
21
  node.instance_eval(&block)
19
22
  @nodes[name] = node
20
23
  end
24
+ # --- Graph DSL ---
25
+
26
+ # returns names of all nodes
27
+ def var_names
28
+ nodes.keys
29
+ end
30
+
31
+ # returns normalized distribution reduced to `evidence`
32
+ # and marginalized over `over`
33
+ def distribution(over: [], evidence: {}, algorithm: :variables_elimination)
34
+ case algorithm
35
+ when :brute_force
36
+ joint_distribution
37
+ .reduce_to(evidence)
38
+ .marginalize(over)
39
+ .normalize
40
+ when :variables_elimination
41
+ reduced_factors = nodes.values.map(&:factor).map { |f| f.reduce_to(evidence) }
42
+ not_include_in_order = evidence.keys.to_set + over.to_set
43
+ variables_order = elimination_order.reject { |v| not_include_in_order.include?(v) }
44
+ distribution = eliminate_variables(variables_order, reduced_factors)
45
+ distribution.normalize
46
+ else
47
+ raise "Unknown algorithm #{algorithm}"
48
+ end
49
+ end
50
+
51
+ def elimination_order
52
+ return @order if @order
53
+ @order = []
54
+ edges = Set.new
55
+ @nodes.each do |name, node|
56
+ parents = node.parent_nodes.keys
57
+ parents.each { |p| edges.add([name, p].to_set) }
58
+ parents.combination(2) { |p1, p2| edges.add([p1, p2].to_set) }
59
+ end
60
+ # `edges` now holds the moralized graph of `self`, represented as a
61
+ # set of undirected edges
62
+
63
+ remaining_nodes = nodes.keys.to_set
64
+ until remaining_nodes.empty?
65
+ best_node = find_min_neighbor(remaining_nodes, edges)
66
+ remaining_nodes.delete(best_node)
67
+ @order.push(best_node)
68
+ clique = edges.select { |e| e.include?(best_node) }
69
+ edges -= clique
70
+ if edges.empty? #i.e. clique is the last edge
71
+ @order += remaining_nodes.to_a
72
+ remaining_nodes = Set.new
73
+ end
74
+ clique.
75
+ map { |e| e.delete(best_node) }.
76
+ map(&:first).
77
+ combination(2) { |p1, p2| edges.add([p1,p2].to_set) }
78
+ end
79
+ @order
80
+ end
21
81
 
22
- def distribution(over: [], evidence: {})
23
- limited = joint_distribution.limit_by(evidence)
24
- limited.reduce(over)
82
+ def find_min_neighbor(remaining_nodes, edges)
83
+ result = nil
84
+ min_neighbors = nil
85
+ remaining_nodes.each do |name, _|
86
+ neighbors = edges.count { |e| e.include?(name) }
87
+ if min_neighbors.nil? || neighbors < min_neighbors
88
+ min_neighbors = neighbors
89
+ result = name
90
+ end
91
+ end
92
+ result
93
+ end
94
+
95
+ def eliminate_variables(variables_order, factors)
96
+ logger.debug "Eliminating variables #{variables_order} from #{factors.size} factors #{factors.map(&:var_names)}"
97
+ remaining_factors = factors.to_set
98
+ variables_order.each do |var_name|
99
+ logger.debug "Eliminating '#{var_name}'..."
100
+ grouped_factors = remaining_factors.select { |f| f.var_names.include?(var_name) }
101
+ remaining_factors -= grouped_factors
102
+ logger.debug "Building new factor out of #{grouped_factors.size} factors having '#{var_name}' - #{grouped_factors.map(&:var_names)}"
103
+ product_factor = grouped_factors.reduce(&:*)
104
+ logger.debug "Removing variable from new factor"
105
+ new_factor = product_factor.eliminate(var_name)
106
+ logger.debug "New factor variables are #{new_factor.var_names}"
107
+ remaining_factors.add(new_factor)
108
+ logger.debug "The variable '#{var_name}' is elminated"
109
+ end
110
+ logger.debug "Non-eliminated variables are #{remaining_factors.map(&:var_names).flatten.uniq}"
111
+ result = remaining_factors.reduce(&:*)
112
+ logger.debug "Eliminating is done"
113
+ result
25
114
  end
26
115
 
27
116
  # This is a MAP query, i.e. Maximum a Posteriori
117
+ # returns the value of `var_name` with the highest posterior probability when `evidence` is observed
28
118
  def most_likely_value(var_name, evidence:)
29
119
  posterior_distribution = distribution(over: [var_name], evidence: evidence)
30
- mode = posterior_distribution.args(var_name).zip(posterior_distribution.values).max_by(&:last)
120
+ mode = posterior_distribution.contextes(var_name).zip(posterior_distribution.values).max_by(&:last)
31
121
  mode.first.first
32
122
  end
33
123
 
@@ -37,6 +127,7 @@ module Bayesnet
37
127
  posterior_distribution[*over_vars.values]
38
128
  end
39
129
 
130
+ # Essentially it builds the product of all nodes' factors
40
131
  def joint_distribution
41
132
  return @joint_distribution if @joint_distribution
42
133
 
@@ -47,17 +138,27 @@ module Bayesnet
47
138
 
48
139
  factor = Factor.new
49
140
  @nodes.each do |node_name, node|
50
- factor.var node_name => node.values
141
+ factor.scope node_name => node.values
51
142
  end
52
143
 
53
- factor.args(*var_names).each do |args|
54
- val_by_name = var_names.zip(args).to_h
144
+ factor.contextes(*var_names).each do |context|
145
+ val_by_name = var_names.zip(context).to_h
55
146
  val = nodes.values.reduce(1.0) do |prob, node|
56
147
  prob * node.factor[val_by_name]
57
148
  end
58
- factor.val args + [val]
149
+ factor.val context + [val]
59
150
  end
60
151
  @joint_distribution = factor.normalize
61
152
  end
153
+
154
+ def parameters
155
+ nodes.values.map(&:parameters).sum
156
+ end
157
+
158
+ def resolve_factors
159
+ @nodes.values.each do |node|
160
+ node.resolve_factor(@nodes.slice(*node.parent_nodes))
161
+ end
162
+ end
62
163
  end
63
164
  end
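Reviewer note: `elimination_order` appears to implement a greedy min-neighbors heuristic on the moralized graph. The standalone sketch below shows that heuristic under that reading. The helper names `moral_edges` and `min_neighbors_order` and the four-node example network are assumptions for illustration, and ties between equally good nodes may be broken differently than in the gem.

```ruby
require "set"

# Builds the moral graph of a DAG given as { node => [parents] }:
# connect each node to its parents and "marry" parents sharing a child.
def moral_edges(parents_by_node)
  edges = Set.new
  parents_by_node.each do |node, parents|
    parents.each { |p| edges << Set[node, p] }
    parents.combination(2) { |a, b| edges << Set[a, b] }
  end
  edges
end

# Greedy min-neighbors ordering: repeatedly pick the node with the fewest
# incident edges, remove it, and connect its former neighbors to each other.
def min_neighbors_order(parents_by_node)
  edges = moral_edges(parents_by_node)
  remaining = parents_by_node.keys
  order = []
  until remaining.empty?
    best = remaining.min_by { |n| edges.count { |e| e.include?(n) } }
    remaining -= [best]
    order << best
    incident = edges.select { |e| e.include?(best) }
    neighbors = incident.flat_map(&:to_a) - [best]
    edges -= incident
    neighbors.combination(2) { |a, b| edges << Set[a, b] }
  end
  order
end

# Illustrative network: cloudy -> sprinkler, cloudy -> rain, both -> wet.
NETWORK = {
  cloudy: [],
  sprinkler: [:cloudy],
  rain: [:cloudy],
  wet: %i[sprinkler rain]
}.freeze

ORDER = min_neighbors_order(NETWORK)
puts ORDER.inspect
```

In this network `cloudy` and `wet` each start with two neighbors in the moral graph while `sprinkler` and `rain` have three, so one of the first two is eliminated first.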
@@ -0,0 +1,13 @@
1
+ # lib/logging.rb
2
+ require "logger"
3
+ module Bayesnet
4
+ def self.logger
5
+ @logger ||= Logger.new(STDOUT).tap { |l| l.level = :debug }
6
+ end
7
+
8
+ module Logging
9
+ def logger
10
+ Bayesnet.logger
11
+ end
12
+ end
13
+ end
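Reviewer note: the added file exposes one memoized logger through a mixin. The sketch below shows the same shape with two changes that are assumptions of this example rather than anything in the gem: a writer so callers can swap the logger, and a default level of `WARN` instead of `:debug`. The `Demo` module and `Worker` class are hypothetical.

```ruby
require "logger"
require "stringio"

# Hypothetical module mirroring the shape of the code above.
module Demo
  class << self
    attr_writer :logger

    # Memoized so every includer shares one Logger instance.
    def logger
      @logger ||= Logger.new($stdout).tap { |l| l.level = Logger::WARN }
    end
  end

  module Logging
    def logger
      Demo.logger
    end
  end

  class Worker
    include Logging

    def run
      logger.debug "debug detail, filtered at WARN level"
      logger.warn "something worth reporting"
    end
  end
end

# Redirect output to an in-memory buffer to inspect what was logged.
BUFFER = StringIO.new
Demo.logger = Logger.new(BUFFER).tap { |l| l.level = Logger::WARN }
Demo::Worker.new.run

puts BUFFER.string
```

A library that logs at debug level to standard output by default will print elimination traces into every host application, so a quieter default with an override hook may be worth considering.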
data/lib/bayesnet/node.rb CHANGED
@@ -6,8 +6,10 @@ module Bayesnet
6
6
  @name = name
7
7
  @parent_nodes = parent_nodes
8
8
  @values = []
9
+ @factor = Factor.new
9
10
  end
10
11
 
12
+ # +++ Node DSL +++
11
13
  def values(hash_or_array = nil, &block)
12
14
  case hash_or_array
13
15
  when NilClass
@@ -16,7 +18,7 @@ module Bayesnet
16
18
  @values = hash_or_array.keys
17
19
  node = self
18
20
  @factor = Factor.build do
19
- var node.name => node.values
21
+ scope node.name => node.values
20
22
  hash_or_array.each do |value, probability|
21
23
  val [value, probability]
22
24
  end
@@ -24,25 +26,39 @@ module Bayesnet
24
26
  when Array
25
27
  raise Error, "DSL error, #values requires a &block when first argument is an Array" unless block
26
28
  @values = hash_or_array
27
- node = self
28
- @factor = Factor.build do
29
- var node.name => node.values
30
- node.parent_nodes.each do |parent_node_name, parent_node|
31
- var parent_node_name => parent_node.values
32
- end
33
- end
34
- instance_eval(&block)
29
+ @factor = block
35
30
  end
36
31
  end
37
32
 
38
33
  def distributions(&block)
39
34
  instance_eval(&block)
40
35
  end
36
+ # --- Node DSL ---
37
+
38
+ def parameters
39
+ (values.size - 1) * parent_nodes.values.reduce(1) { |mul, n| mul * n.values.size }
40
+ end
41
41
 
42
42
  def as(distribution, given:)
43
43
  @values.zip(distribution).each do |value, probability|
44
44
  @factor.val [value] + given + [probability]
45
45
  end
46
46
  end
47
+
48
+ def resolve_factor(parent_nodes)
49
+ @parent_nodes = parent_nodes
50
+ if @factor.is_a?(Proc)
51
+ proc = @factor
52
+ node = self
53
+ @factor = Factor.build do
54
+ scope node.name => node.values
55
+ node.parent_nodes.each do |parent_node_name, parent_node|
56
+ scope parent_node_name => parent_node.values
57
+ end
58
+ end
59
+ instance_eval(&proc)
60
+ end
61
+ end
62
+
47
63
  end
48
64
  end
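Reviewer note: `resolve_factor` is what allows child nodes to be declared before their parents. The definition block is stored as a `Proc` and only evaluated once all nodes exist. A minimal standalone sketch of that deferral follows. `LazyNode` is a hypothetical class, and its block simply returns the parent names instead of building a factor.

```ruby
# Hypothetical node that stores its definition block and evaluates it later.
class LazyNode
  attr_reader :name, :parent_names, :table

  def initialize(name, parent_names, &definition)
    @name = name
    @parent_names = parent_names
    @definition = definition # kept as a Proc, not run yet
    @table = nil
  end

  # Called once all nodes exist; only now is the block evaluated,
  # with the resolved parent objects passed in.
  def resolve(all_nodes)
    parents = all_nodes.values_at(*@parent_names)
    @table = @definition.call(parents)
  end
end

NODES = {}
# The child is registered first and refers to a parent that does not exist yet.
NODES[:mood] = LazyNode.new(:mood, [:weather]) do |parents|
  parents.map(&:name)
end
NODES[:weather] = LazyNode.new(:weather, []) { |_parents| [] }

NODES.each_value { |n| n.resolve(NODES) }

puts NODES[:mood].table.inspect
```

Evaluating the block eagerly in the constructor would fail here, because `:weather` is not yet in `NODES` when `:mood` is created.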