bayesnet 0.0.2 → 0.6.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 9f17178cdfa472aea14e7769f1831f1903ceee2613b6e7b71e1135990dc974ab
- data.tar.gz: c55cb95134d4dbb479e89839be7a4634b6bc6a8be52c5c066f1cc1ad21c584a7
+ metadata.gz: 7591665046345784f55275c06d1129fd91ee3f098f3800b2c03b6f9bbfd8e172
+ data.tar.gz: ec9009ab90593d42fa2506a230e5900d5a39bebb1a7fbd874953d6c86022b2eb
  SHA512:
- metadata.gz: fda426d9dd319b2ad8300a8ba59acae511740cc271729f40b36ce7a7765d8019a901d9d83a2e955ccd964901ee604a3ad96fba092ce41394da5980a77d1bb198
- data.tar.gz: 03d4beb66b11ba684007dd49f67ef2ae085ecd5c6e5ca9b8b63327eee109ada3b1f389731543fa6d970bff110230231bc1533022640367a906f57dec1c41bc6d
+ metadata.gz: 5e668b431f55f9239ad3ae06cdc020098ff4a8b68f7934283d6f77a3969a014aa17f68df12b23013d1681e79fdecd8f4c8e4da105e9430a16d2ab8075bbcca7b
+ data.tar.gz: 75eceac300152cfa8d0ce736b16939e779f127dc844ea0c8ce8e1d0f363b04048107533f64c744420b2631fd3e8678d6812e15c6b770e2b37a598ae290af773a
data/CHANGELOG.md CHANGED
@@ -1,5 +1,18 @@
  ## [Unreleased]

+ ## [0.6.0] - 2022-06-26
+ - Using the variable elimination algorithm to build a distribution
+
+ ## [0.5.0] - 2022-02-26
+
+ - Constructing networks out of the `.BIF` ([Interchange Format for Bayesian Networks](https://www.cs.washington.edu/dm/vfml/appendixes/bif.htm)) files.
+ - Fixing an inference bug
+ - Network child nodes can now be specified **before** their parents
+
+ ## [0.0.3] - 2021-12-29
+
+ - Fixing terminology used in Factor class
+
  ## [0.0.2] - 2021-12-28

  - README, CI/CD for Ruby 2.6, 2.7, 3.1 added
data/Gemfile CHANGED
@@ -6,10 +6,10 @@ source "https://rubygems.org"
  gemspec

  gem "rake", "~> 13.0"
+ gem "treetop", "~> 1.6"

  group :development, :test do
  gem "m", "~> 1.5.0"
  gem "minitest", "~> 5.0"
  gem "pry-byebug", "~> 3.9.0"
- gem "standard", "~> 1.3"
  end
data/Gemfile.lock CHANGED
@@ -1,12 +1,11 @@
  PATH
  remote: .
  specs:
- bayesnet (0.0.2)
+ bayesnet (0.6.0)

  GEM
  remote: https://rubygems.org/
  specs:
- ast (2.4.2)
  byebug (11.1.3)
  coderay (1.1.3)
  m (1.5.1)
@@ -14,38 +13,16 @@ GEM
  rake (>= 0.9.2.2)
  method_source (1.0.0)
  minitest (5.15.0)
- parallel (1.21.0)
- parser (3.0.3.2)
- ast (~> 2.4.1)
+ polyglot (0.3.5)
  pry (0.13.1)
  coderay (~> 1.1)
  method_source (~> 1.0)
  pry-byebug (3.9.0)
  byebug (~> 11.0)
  pry (~> 0.13.0)
- rainbow (3.0.0)
  rake (13.0.6)
- regexp_parser (2.2.0)
- rexml (3.2.5)
- rubocop (1.23.0)
- parallel (~> 1.10)
- parser (>= 3.0.0.0)
- rainbow (>= 2.2.2, < 4.0)
- regexp_parser (>= 1.8, < 3.0)
- rexml
- rubocop-ast (>= 1.12.0, < 2.0)
- ruby-progressbar (~> 1.7)
- unicode-display_width (>= 1.4.0, < 3.0)
- rubocop-ast (1.15.0)
- parser (>= 3.0.1.1)
- rubocop-performance (1.12.0)
- rubocop (>= 1.7.0, < 2.0)
- rubocop-ast (>= 0.4.0)
- ruby-progressbar (1.11.0)
- standard (1.5.0)
- rubocop (= 1.23.0)
- rubocop-performance (= 1.12.0)
- unicode-display_width (2.1.0)
+ treetop (1.6.11)
+ polyglot (~> 0.3)

  PLATFORMS
  x86_64-darwin-19
@@ -57,7 +34,7 @@ DEPENDENCIES
  minitest (~> 5.0)
  pry-byebug (~> 3.9.0)
  rake (~> 13.0)
- standard (~> 1.3)
+ treetop (~> 1.6)

  BUNDLED WITH
- 2.2.32
+ 2.3.3
data/README.md CHANGED
@@ -1,6 +1,7 @@
  # Bayesnet

- This gem provides an API for building a Bayesian network and executing some queries against it.
+ This gem provides a DSL for constructing Bayesian networks and lets you execute basic inference queries against them. It can also parse the .BIF format ([The Interchange Format for Bayesian Networks](https://www.cs.washington.edu/dm/vfml/appendixes/bif.htm)).
+

  ### Example:

@@ -66,12 +67,14 @@ The inference is based on summing over joint distribution, i.e. it is the simple
  most expensive way to calculate it. No optimization is implemented in this version; the code
  is more a proof of API.

+ ### [Another example](https://afurmanov.com/reducing-anxiety-with-bayesian-network) of using this gem
+
  ## Installation

  Add this line to your application's Gemfile:

  ```ruby
- gem 'bayesnet'
+ em 'bayesnet'
  ```

  And then execute:
data/Rakefile CHANGED
@@ -9,6 +9,9 @@ Rake::TestTask.new(:test) do |t|
  t.test_files = FileList["test/**/*_test.rb"]
  end

- require "standard/rake"
+ Rake::TestTask.new("regen-bif") do |t|
+ `rm ./lib/bayesnet/parsers/bif.rb`
+ `tt ./lib/bayesnet/parsers/bif.treetop`
+ end

- task default: %i[test standard]
+ task default: %i[test]
data/bayesnet.gemspec CHANGED
@@ -34,7 +34,6 @@ Gem::Specification.new do |spec|
  spec.add_development_dependency "m", "~> 1.5.0"
  spec.add_development_dependency "minitest", "~> 5.0"
  spec.add_development_dependency "pry-byebug", "~> 3.9.0"
- spec.add_development_dependency "standard", "~> 1.3"

  # For more information and examples about making a new gem, checkout our
  # guide at: https://bundler.io/guides/creating_gem.html
Binary file
data/lib/bayesnet/dsl.rb CHANGED
@@ -1,10 +1,14 @@
+ # frozen_string_literal: true
+
  require "bayesnet/graph"

  module Bayesnet
+ # Bayesnet::DSL.define ...
  module DSL
  def define(&block)
  graph = Graph.new
  graph.instance_eval(&block) if block
+ graph.resolve_factors
  graph
  end
  end
@@ -1,3 +1,5 @@
+ # frozen_string_literal: true
+
  module Bayesnet
  class Error < StandardError
  end
@@ -1,89 +1,186 @@
+ # frozen_string_literal: true
+
  module Bayesnet
- # Factor if a function of sevaral variables (A, B, ...) each defined on values from finite set
+ # Factor is a function of several variables (A, B, ...), where
+ # every variable could take values from some finite set
  class Factor
+ # +++ Factor DSL +++
+ #
+ # Factor DSL entry point:
  def self.build(&block)
  factor = new
  factor.instance_eval(&block)
  factor
  end

- # Specifies variable name together with its values
- def var(var_name_to_values)
- @vars.merge!(var_name_to_values)
+ # Factor DSL
+ # Defining a variable with the list of its possible values looks like:
+ # ```
+ # Bayesnet::Factor.build do
+ # scope weather: %i[sunny cloudy]
+ # scope mood: %i[bad good]
+ # ...
+ # ```
+ # ^ this code defines two variables, `weather` and `mood`, where
+ # `weather` could be :sunny or :cloudy, and
+ # `mood` could be :bad or :good
+ def scope(var_name_to_values = nil)
+ if var_name_to_values
+ @scope.merge!(var_name_to_values)
+ else
+ @scope
+ end
  end

- # Specifies function values for args. Latest args is an function value, all previous are argument values
- def val(*args)
- args = args[0] if args.size == 1 && args[0].is_a?(Array)
- @vals[args[0..-2]] = args[-1]
+ # Factor DSL
+ # Specifies the factor value for some set of variable values, i.e.
+ # ```
+ # Bayesnet::Factor.build do
+ # scope weather: %i[sunny cloudy]
+ # scope mood: %i[bad good]
+ # val :sunny, :bad, 0.1
+ # ...
+ # ```
+ # ^ this code says the value of the factor for [weather == :sunny, mood == :bad] is 0.1
+ def val(*context_and_val)
+ context_and_val = context_and_val[0] if context_and_val.size == 1 && context_and_val[0].is_a?(Array)
+ @vals[context_and_val[0..-2]] = context_and_val[-1]
  end
+ # --- Factor DSL ---

+ # List of variable names
  def var_names
- @vars.keys
+ @scope.keys
  end

- def [](*args)
- key = if args.size == 1 && args[0].is_a?(Hash)
- args[0].slice(*var_names).values
- else
- args
- end
+ # Accessor for a factor value, e.g.
+ # ```
+ # factor = Bayesnet::Factor.build do
+ # scope weather: %i[sunny cloudy]
+ # scope mood: %i[bad good]
+ # val :sunny, :bad, 0.1
+ # ...
+ # end
+ # factor[:sunny, :bad] # 0.1
+ # ```
+ def [](*context)
+ key = if context.size == 1 && context[0].is_a?(Hash)
+ context[0].slice(*var_names).values
+ else
+ context
+ end
  @vals[key]
  end

- def self.from_distribution(var_distribution)
- self.class.new(var_distribution.keys, var_distribution.values.map(&:to_a))
- end
-
- def args(*var_names)
+ # returns all combinations of values of `var_names`
+ def contextes(*var_names)
  return [] if var_names.empty?
- @vars[var_names[0]].product(*var_names[1..].map { |var_name| @vars[var_name] })
+
+ @scope[var_names[0]].product(*var_names[1..].map { |var_name| @scope[var_name] })
  end

+ # returns all possible values
  def values
  @vals.values
  end

+ # returns a new normalized factor, i.e. one where the sum of all values is 1.0
  def normalize
  vals = @vals.clone
  norm_factor = vals.map(&:last).sum * 1.0
- vals.each { |k, v| vals[k] /= norm_factor }
- self.class.new(@vars.clone, vals)
+ vals.each { |k, _v| vals[k] /= norm_factor }
+ self.class.new(@scope.clone, vals)
  end

- def limit_by(evidence)
- # todo: use Hash#except when Ruby 2.6 support no longer needed
- evidence_keys_set = evidence.keys.to_set
- vars = @vars.reject { |k, _| evidence_keys_set.include?(k) }
+ # Returns a factor built as follows:
+ # 1. The original factor is filtered down to entries whose variable values are compatible with `context`
+ # 2. The returned factor does not have any variables from `context` (because they all have
+ # the same values, after step 1)
+ # The `context` argument is supposed to be evidence, something like
+ # `{weather: :sunny}`
+ def reduce_to(context)
+ limited_context = context.slice(*scope.keys)
+ return self.class.new(@scope, @vals) if limited_context.empty?
+ limited_scope = @scope.slice(*(@scope.keys - limited_context.keys))

- evidence_vals = evidence.values
- indices = evidence.keys.map { |k| index_by_var_name[k] }
- vals = @vals.select { |k, v| indices.map { |i| k[i] } == evidence_vals }
- vals.transform_keys! { |k| k - evidence_vals }
+ context_vals = limited_context.values
+ indices = limited_context.keys.map { |k| index_by_var_name[k] }
+ vals = @vals.select { |k, _v| indices.map { |i| k[i] } == context_vals }
+ vals.transform_keys! { |k| delete_by_indices(k, indices) }

- self.class.new(vars, vals)
+ self.class.new(limited_scope, vals)
  end

- def reduce(over_vars)
- vars = @vars.slice(*over_vars)
- indices = vars.keys.map { |k| index_by_var_name[k] }
- vals = @vals.group_by { |args, val| indices.map { |i| args[i] } }
+ # Returns a new factor defined over `var_names`; all other variables
+ # get eliminated. For every combination of values of `var_names`,
+ # the value of the new factor is defined by summing up the values of the
+ # original factor with compatible values
+ def marginalize(var_names)
+ scope = @scope.slice(*var_names)
+
+ indices = scope.keys.map { |k| index_by_var_name[k] }
+ vals = @vals.group_by { |context, _val| indices.map { |i| context[i] } }
  vals.transform_values! { |v| v.map(&:last).sum }
- reduced = self.class.new(vars, vals)
- reduced.normalize
+
+ self.class.new(scope, vals)
+ end
+
+ def eliminate(var_name)
+ keep_var_names = var_names
+ keep_var_names.delete(var_name)
+ marginalize(keep_var_names)
+ end
+
+ def select(subcontext)
+ @vals.select do |context, _|
+ var_names.zip(context).to_h.slice(*subcontext.keys) == subcontext
+ end
+ end
+
+ def *(other)
+ common_scope = @scope.keys & other.scope.keys
+ new_scope = scope.merge(other.scope)
+ new_vals = {}
+ group1 = group_by_scope_values(common_scope)
+ group2 = other.group_by_scope_values(common_scope)
+ group1.each do |scope, vals1|
+ combo = vals1.product(group2[scope])
+ combo.each do |(val1, val2)|
+ # values in scope must match the variable order in new_scope, i.e.
+ # they must match `new_scope.var_names`
+ # The code below ensures it by merging two hashes in the same
+ # way as `new_scope` is constructed above
+ val_by_name1 = var_names.zip(val1.first).to_h
+ val_by_name2 = other.var_names.zip(val2.first).to_h
+ new_vals[val_by_name1.merge(val_by_name2).values] = val1.last * val2.last
+ end
+ end
+ Factor.new(new_scope, new_vals)
+ end
+
+ def group_by_scope_values(scope_keys)
+ indices = scope_keys.map { |k| index_by_var_name[k] }
+ @vals.group_by { |context, _val| indices.map { |i| context[i] } }
  end

  private

- def initialize(vars = {}, vals = {})
- @vars = vars
+ def delete_by_indices(array, indices)
+ result = array.dup
+ indices.map { |i| result[i] = nil }
+ result.compact
+ end
+
+ def initialize(scope = {}, vals = {})
+ @scope = scope
  @vals = vals
  end

  def index_by_var_name
  return @index_by_var_name if @index_by_var_name
+
  @index_by_var_name = {}
- @vars.each_with_index { |(k, v), i| @index_by_var_name[k] = i }
+ @scope.each_with_index { |(k, _v), i| @index_by_var_name[k] = i }
  @index_by_var_name
  end
  end
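
To make the reworked Factor API easier to follow, here is a minimal usage sketch of the renamed DSL (`scope`/`val`) together with the new `reduce_to`, `marginalize`, and `normalize` operations. The `weather`/`mood` variables and the 0.1 value come from the doc comments in this diff; the other probabilities are invented for illustration.

```ruby
require "bayesnet"

# Hypothetical factor over two variables, mirroring the doc comments above.
factor = Bayesnet::Factor.build do
  scope weather: %i[sunny cloudy]
  scope mood: %i[bad good]
  val :sunny, :bad, 0.1    # value taken from the doc comment
  val :sunny, :good, 0.6   # remaining values are made up for the example
  val :cloudy, :bad, 0.2
  val :cloudy, :good, 0.1
end

factor[:sunny, :bad]              # => 0.1
factor.reduce_to(weather: :sunny) # factor over :mood only, values 0.1 and 0.6
factor.marginalize([:weather])    # factor over :weather, values 0.7 and 0.3
factor.normalize.values.sum       # => 1.0
```

`reduce_to` followed by `marginalize` and `normalize` is exactly what the `:brute_force` branch of `Graph#distribution` does further down in this diff.
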
@@ -1,33 +1,123 @@
+ # frozen_string_literal: true
+
  require "bayesnet/node"

  module Bayesnet
+ # Acyclic graph
  class Graph
+ include Bayesnet::Logging
+
  attr_reader :nodes

  def initialize
  @nodes = {}
  end

- def var_names
- nodes.keys
- end
-
+ # +++ Graph DSL +++
  def node(name, parents: [], &block)
  raise Error, "DSL error, #node requires a &block" unless block
- node = Node.new(name, @nodes.slice(*parents))
+
+ node = Node.new(name, parents)
  node.instance_eval(&block)
  @nodes[name] = node
  end
+ # --- Graph DSL ---
+
+ # returns names of all nodes
+ def var_names
+ nodes.keys
+ end
+
+ # returns a normalized distribution reduced to `evidence`
+ # and marginalized over `over`
+ def distribution(over: [], evidence: {}, algorithm: :variables_elimination)
+ case algorithm
+ when :brute_force
+ joint_distribution
+ .reduce_to(evidence)
+ .marginalize(over)
+ .normalize
+ when :variables_elimination
+ reduced_factors = nodes.values.map(&:factor).map { |f| f.reduce_to(evidence) }
+ not_include_in_order = evidence.keys.to_set + over.to_set
+ variables_order = elimination_order.reject { |v| not_include_in_order.include?(v) }
+ distribution = eliminate_variables(variables_order, reduced_factors)
+ distribution.normalize
+ else
+ raise "Unknown algorithm #{algorithm}"
+ end
+ end
+
+ def elimination_order
+ return @order if @order
+ @order = []
+ edges = Set.new
+ @nodes.each do |name, node|
+ parents = node.parent_nodes.keys
+ parents.each { |p| edges.add([name, p].to_set) }
+ parents.combination(2) { |p1, p2| edges.add([p1, p2].to_set) }
+ end
+ # edges is now the moralized graph of `self`, just represented differently,
+ # as a set of edges
+
+ remaining_nodes = nodes.keys.to_set
+ until remaining_nodes.empty?
+ best_node = find_min_neighbor(remaining_nodes, edges)
+ remaining_nodes.delete(best_node)
+ @order.push(best_node)
+ clique = edges.select { |e| e.include?(best_node) }
+ edges -= clique
+ if edges.empty? # i.e. clique is the last edge
+ @order += remaining_nodes.to_a
+ remaining_nodes = Set.new
+ end
+ clique.
+ map { |e| e.delete(best_node) }.
+ map(&:first).
+ combination(2) { |p1, p2| edges.add([p1, p2].to_set) }
+ end
+ @order
+ end

- def distribution(over: [], evidence: {})
- limited = joint_distribution.limit_by(evidence)
- limited.reduce(over)
+ def find_min_neighbor(remaining_nodes, edges)
+ result = nil
+ min_neighbors = nil
+ remaining_nodes.each do |name, _|
+ neighbors = edges.count { |e| e.include?(name) }
+ if min_neighbors.nil? || neighbors < min_neighbors
+ min_neighbors = neighbors
+ result = name
+ end
+ end
+ result
+ end
+
+ def eliminate_variables(variables_order, factors)
+ logger.debug "Eliminating variables #{variables_order} from #{factors.size} factors #{factors.map(&:var_names)}"
+ remaining_factors = factors.to_set
+ variables_order.each do |var_name|
+ logger.debug "Eliminating '#{var_name}'..."
+ grouped_factors = remaining_factors.select { |f| f.var_names.include?(var_name) }
+ remaining_factors -= grouped_factors
+ logger.debug "Building new factor out of #{grouped_factors.size} factors having '#{var_name}' - #{grouped_factors.map(&:var_names)}"
+ product_factor = grouped_factors.reduce(&:*)
+ logger.debug "Removing variable from new factor"
+ new_factor = product_factor.eliminate(var_name)
+ logger.debug "New factor variables are #{new_factor.var_names}"
+ remaining_factors.add(new_factor)
+ logger.debug "The variable '#{var_name}' is eliminated"
+ end
+ logger.debug "Non-eliminated variables are #{remaining_factors.map(&:var_names).flatten.uniq}"
+ result = remaining_factors.reduce(&:*)
+ logger.debug "Eliminating is done"
+ result
  end

  # This is MAP query, i.e. Maximum a Posteriory
+ # returns the value of `var_name` having maximum likelihood, when `evidence` is observed
  def most_likely_value(var_name, evidence:)
  posterior_distribution = distribution(over: [var_name], evidence: evidence)
- mode = posterior_distribution.args(var_name).zip(posterior_distribution.values).max_by(&:last)
+ mode = posterior_distribution.contextes(var_name).zip(posterior_distribution.values).max_by(&:last)
  mode.first.first
  end

@@ -37,6 +127,7 @@ module Bayesnet
  posterior_distribution[*over_vars.values]
  end

+ # Essentially it builds the product of all the nodes' factors
  def joint_distribution
  return @joint_distribution if @joint_distribution

@@ -47,17 +138,27 @@ module Bayesnet

  factor = Factor.new
  @nodes.each do |node_name, node|
- factor.var node_name => node.values
+ factor.scope node_name => node.values
  end

- factor.args(*var_names).each do |args|
- val_by_name = var_names.zip(args).to_h
+ factor.contextes(*var_names).each do |context|
+ val_by_name = var_names.zip(context).to_h
  val = nodes.values.reduce(1.0) do |prob, node|
  prob * node.factor[val_by_name]
  end
- factor.val args + [val]
+ factor.val context + [val]
  end
  @joint_distribution = factor.normalize
  end
+
+ def parameters
+ nodes.values.map(&:parameters).sum
+ end
+
+ def resolve_factors
+ @nodes.values.each do |node|
+ node.resolve_factor(@nodes.slice(*node.parent_nodes))
+ end
+ end
  end
  end
@@ -0,0 +1,13 @@
+ # lib/logging.rb
+
+ module Bayesnet
+ def self.logger
+ @logger ||= Logger.new(STDOUT).tap { |l| l.level = :debug }
+ end
+
+ module Logging
+ def logger
+ Bayesnet.logger
+ end
+ end
+ end
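
The new logging mixin defaults to a `:debug` logger on STDOUT, so variable elimination is quite chatty. A one-line sketch for quieting it; the `Logger::WARN` threshold is an arbitrary choice here:

```ruby
require "bayesnet"

# Raise the threshold of the shared logger defined above to silence the
# per-variable messages emitted by Graph#eliminate_variables.
Bayesnet.logger.level = Logger::WARN
```
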
data/lib/bayesnet/node.rb CHANGED
@@ -6,8 +6,10 @@ module Bayesnet
  @name = name
  @parent_nodes = parent_nodes
  @values = []
+ @factor = Factor.new
  end

+ # +++ Node DSL +++
  def values(hash_or_array = nil, &block)
  case hash_or_array
  when NilClass
@@ -16,7 +18,7 @@ module Bayesnet
  @values = hash_or_array.keys
  node = self
  @factor = Factor.build do
- var node.name => node.values
+ scope node.name => node.values
  hash_or_array.each do |value, probability|
  val [value, probability]
  end
@@ -24,25 +26,39 @@ module Bayesnet
  when Array
  raise Error, "DSL error, #values requires a &block when first argument is an Array" unless block
  @values = hash_or_array
- node = self
- @factor = Factor.build do
- var node.name => node.values
- node.parent_nodes.each do |parent_node_name, parent_node|
- var parent_node_name => parent_node.values
- end
- end
- instance_eval(&block)
+ @factor = block
  end
  end

  def distributions(&block)
  instance_eval(&block)
  end
+ # --- Node DSL ---
+
+ def parameters
+ (values.size - 1) * parent_nodes.values.reduce(1) { |mul, n| mul * n.values.size }
+ end

  def as(distribution, given:)
  @values.zip(distribution).each do |value, probability|
  @factor.val [value] + given + [probability]
  end
  end
+
+ def resolve_factor(parent_nodes)
+ @parent_nodes = parent_nodes
+ if @factor.is_a?(Proc)
+ proc = @factor
+ node = self
+ @factor = Factor.build do
+ scope node.name => node.values
+ node.parent_nodes.each do |parent_node_name, parent_node|
+ scope parent_node_name => parent_node.values
+ end
+ end
+ instance_eval(&proc)
+ end
+ end
+

  end
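
Putting the Graph, Node, and Factor changes together: a minimal end-to-end sketch of defining a network and querying it with the new `algorithm:` option. It assumes the gem exposes the DSL entry point as `Bayesnet.define` (only `Bayesnet::DSL#define` is visible in this diff), and the network, values, and probabilities are invented for illustration.

```ruby
require "bayesnet"

# Hypothetical two-node network; child nodes may name parents that are defined
# later, since factors are now resolved after the block (Graph#resolve_factors).
net = Bayesnet.define do
  node :weather do
    values sunny: 0.7, cloudy: 0.3
  end

  node :mood, parents: [:weather] do
    values %i[good bad] do
      distributions do
        as [0.9, 0.1], given: [:sunny]
        as [0.4, 0.6], given: [:cloudy]
      end
    end
  end
end

# 0.6.0 defaults to variable elimination; :brute_force keeps the old
# joint-distribution behaviour.
net.distribution(over: [:mood], evidence: {weather: :cloudy}).values
# => [0.4, 0.6]
net.distribution(over: [:mood], evidence: {weather: :cloudy}, algorithm: :brute_force).values
# => [0.4, 0.6]
net.most_likely_value(:mood, evidence: {weather: :cloudy}) # => :bad
```

Both algorithms should agree; the variable elimination path only multiplies the node factors that mention each eliminated variable, following `elimination_order`, instead of materializing the full joint distribution.
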