mini_pgm 0.0.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 371cafde1b4a0699d5150dbeee1c502abdfa4f01b1b02b3c5c1ee9921de430e8
4
+ data.tar.gz: 557de3515a26ed7ff51883d93927255aa5f61e1cfbb786f1054d09787639b62e
5
+ SHA512:
6
+ metadata.gz: d4587b67bacb38663fbcd0ac05ad98cc5e6659eba4fb7af44f0b7ffbe01ae0bb817b756571d6b9f9c7cf5801fdc9f667b1cbc6b7a70efb1b95035f5510e78aba
7
+ data.tar.gz: 80a6468b078c223fd16092947501517114969980d0fc1af38499c0f39920984d451ba8bef7b96ce54f000a4154e2d3dbffb8db8482c3118af400c2e6a02bb58e
data/lib/mini_pgm.rb ADDED
@@ -0,0 +1,6 @@
1
+ require 'mini_pgm/edge'
2
+ require 'mini_pgm/model'
3
+ require 'mini_pgm/node'
4
+ require 'mini_pgm/printer'
5
+ require 'mini_pgm/tabular_cpd'
6
+ require 'mini_pgm/variable'
@@ -0,0 +1,12 @@
1
+ # frozen_string_literal: true
2
+
3
module MiniPGM
  #
  # A directed edge in a PGM, running from the node labelled `from` to the node labelled `to`.
  #
  # Its string form is "<from> -> <to>", e.g. "Smoker -> Cancer".
  #
  Edge = Struct.new(:from, :to) do
    def to_s
      format('%<from>s -> %<to>s', from: from, to: to)
    end
  end
end
@@ -0,0 +1,113 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative 'node'
4
+
5
module MiniPGM
  #
  # Represents a Probabilistic Graphical Model (PGM)
  #
  # A model is built from directed edges. Every label that appears at either end of an edge
  # becomes a node. Each node needs a CPD (attached via `add_cpd`) before `validate!` succeeds.
  #
  class Model
    # Raised by `validate!` when the model is incomplete or internally inconsistent
    class ModelError < StandardError; end

    #
    # Edges between individual nodes, sorted by the label of the node each edge leaves, with ties
    # broken by the label of the node it enters, e.g:
    #
    #   Pollution -> Cancer
    #   Smoker -> Cancer
    #
    attr_reader :edges

    #
    # Lookup table of labelled nodes, each associated with a set of labels for all incoming edges, e.g:
    #
    #   { Pollution, Smoker } -> Cancer -> { }
    #   { } -> Pollution -> { Cancer }
    #   { } -> Smoker -> { Cancer }
    #
    attr_reader :nodes

    # most recent error after calling `valid?` (nil when the last call succeeded)
    attr_reader :error

    #
    # @param edges [Array<Edge>] directed edges; each must respond to `from` and `to`
    #
    def initialize(*edges)
      @edges = sort_edges(edges)
      @nodes = reduce_edges(edges)
    end

    #
    # Attaches a CPD to the node whose label matches the CPD's variable.
    #
    # @param cpd [TabularCPD] must respond to `variable` and `evidence`
    # @return the CPD that was attached
    # @raise [ArgumentError] if no node has the CPD variable's label, or if the CPD's evidence
    #   labels do not exactly match the labels on the node's incoming edges
    #
    def add_cpd(cpd)
      label = cpd.variable.label
      node = @nodes[label]
      # `node` is nil on this path, so the message must use the requested label
      raise ArgumentError, "node does not exist for label #{label}" unless node

      check_cpd_evidence!(cpd.evidence.map(&:label), node.incoming_edges)
      node.cpd = cpd
    end

    #
    # Multi-line summary of the model: the sorted edges, the nodes sorted by label, and the
    # result of `valid?`. Because it calls `valid?`, this also updates `error`.
    #
    def to_s
      ['Edges:', edges_to_s, '', 'Nodes:', nodes_to_s, '', 'Valid:', valid?, ''].join("\n")
    end

    #
    # Checks that the model is complete and consistent.
    #
    # @raise [ModelError] if a node has no CPD, if a CPD lacks evidence for one of the node's
    #   parents, or if the cardinality in that evidence disagrees with the parent's own CPD
    #
    def validate!
      @nodes.each_value do |node|
        raise ModelError, "node '#{node.label}' does not have a CPD" unless node.cpd
      end

      # validate cardinality between nodes for each edge
      @edges.each do |edge|
        validate_cardinality!(@nodes[edge.to], @nodes[edge.from])
      end
    end

    #
    # Non-raising form of `validate!`. Returns true/false and records the failure (or nil) in
    # `error`.
    #
    def valid?
      @error = nil
      validate!
      true
    rescue ModelError => e
      @error = e
      false
    end

    private

    # Raises unless the CPD's evidence labels and the node's incoming-edge labels match exactly
    def check_cpd_evidence!(cpd_evidence, node_dependencies)
      cpd_evidence.each do |evidence|
        raise ArgumentError, "node is missing dependency for CPD evidence '#{evidence}'" \
          unless node_dependencies.include?(evidence)
      end

      node_dependencies.each do |dependency|
        raise ArgumentError, "CPD is missing evidence for node dependency '#{dependency}'" \
          unless cpd_evidence.include?(dependency)
      end
    end

    def edges_to_s
      @edges.map(&:to_s).join("\n")
    end

    def nodes_to_s
      @nodes.keys.sort.map { |key| @nodes[key].to_s }.join("\n")
    end

    # Builds the label => Node lookup table, recording each edge on both of its endpoints
    def reduce_edges(edges)
      edges.each_with_object({}) do |edge, reduced|
        # create node for incoming edge
        (reduced[edge.to] ||= MiniPGM::Node.new(edge.to)).incoming_edges.add(edge.from)

        # create node for outgoing edge
        (reduced[edge.from] ||= MiniPGM::Node.new(edge.from)).outgoing_edges.add(edge.to)
      end
    end

    # `sort_by` is not stable, so break ties on `to` to keep the ordering deterministic
    def sort_edges(edges)
      edges.sort_by { |edge| [edge.from, edge.to] }
    end

    # Checks the edge `from` -> `to`: the cardinality recorded for `from` in the evidence of
    # `to`'s CPD must equal the cardinality of `from`'s own variable
    def validate_cardinality!(to, from)
      evidence = to.cpd.evidence.find { |ev| ev.label == from.label }
      # a CPD assigned directly via `Node#cpd=` bypasses `add_cpd`'s evidence check
      raise ModelError, "CPD of '#{to.label}' has no evidence for '#{from.label}'" unless evidence

      expected = evidence.cardinality
      actual = from.cpd.variable.cardinality

      raise ModelError, "cardinality mismatch in CPDs of '#{from.label}' (#{actual}) and '#{to.label}' (#{expected})" \
        unless expected == actual
    end
  end
end
@@ -0,0 +1,36 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'set'
4
+
5
module MiniPGM
  #
  # Represents an individual node, or random variable, in a model
  #
  # A node knows its label, the labels of the nodes on its incoming and outgoing edges, and the
  # CPD assigned to it (nil until one is set). Example string representations:
  #
  #   { Pollution, Smoker } -> Cancer -> { }     (two incoming edges, no outgoing edges)
  #   { } -> Pollution -> { Cancer }             (no incoming edges, one outgoing edge)
  #   { } -> Smoker -> { Cancer }
  #
  class Node
    attr_accessor :cpd
    attr_reader :label, :incoming_edges, :outgoing_edges

    def initialize(label)
      @label = label
      @incoming_edges = Set.new
      @outgoing_edges = Set.new
    end

    # Renders the node as "{ incoming } -> label -> { outgoing }"
    def to_s
      sides = [write_set(@incoming_edges), @label, write_set(@outgoing_edges)]
      sides.join(' -> ')
    end

    # Formats a collection of labels as "{ a, b }" in sorted order, or "{ }" when it is empty
    def write_set(set)
      return '{ }' if set.empty?

      "{ #{set.sort.join(', ')} }"
    end
  end
end
@@ -0,0 +1,48 @@
1
+ # frozen_string_literal: true
2
+
3
module MiniPGM
  #
  # Helper class for printing data in an easily readable ASCII format
  #
  # Example output:
  #
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Smoker    | Smoker_0    | Smoker_0    | Smoker_1    | Smoker_1    |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Pollution | Pollution_0 | Pollution_1 | Pollution_0 | Pollution_1 |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Cancer_0  | 0.03        | 0.05        | 0.001       | 0.02        |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Cancer_1  | 0.97        | 0.95        | 0.999       | 0.98        |
  #   +-----------+-------------+-------------+-------------+-------------+
  #
  class Printer
    class << self
      #
      # Renders rows of cells as an ASCII table, with a divider line above, between and below
      # the rows. Each column is left-justified to the width of its widest cell.
      #
      # @param rows [Array<Array>] table rows; each row is expected to have at least as many cells
      #   as the first row. Cells are converted with `to_s`, so numbers and symbols are accepted.
      # @return [String] the table lines joined with "\n", or '' when `rows` is empty
      #
      def print(rows)
        # with no rows there is no first row to take the column count from
        return '' if rows.empty?

        cells = rows.map { |row| row.map(&:to_s) }

        # calculate column widths
        num_columns = cells.first.length
        column_widths = (0...num_columns).map { |col| max_width(cells, col) }

        # write table
        div = write_divider(column_widths)
        cells.flat_map { |row| [write_row(column_widths, row), div] }
             .unshift(div)
             .join("\n")
      end

      private

      # Length of the longest cell in the given column
      def max_width(rows, column)
        rows.map { |row| row[column].length }.max
      end

      # Horizontal rule such as "+-----+----+", sized to the column widths
      def write_divider(column_widths)
        "+-#{column_widths.map { |g| '-' * g }.join('-+-')}-+"
      end

      # One table line such as "| a   | bb |", each value padded to its column width
      def write_row(column_widths, values)
        "| #{values.map.with_index { |value, i| value.ljust(column_widths[i]) }.join(' | ')} |"
      end
    end
  end
end
@@ -0,0 +1,82 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative 'printer'
4
+
5
module MiniPGM
  #
  # Represents a tabular Conditional Probability Distribution (CPD)
  #
  # The CPD defines the probabilities of particular outcomes for a variable, given observations
  # of the variables that it depends on (i.e. the evidence).
  #
  # For example, assume the variable 'Cancer' depends on 'Pollution' and 'Smoker', each of which
  # has two outcomes. If 'Cancer' also has two outcomes, the string representation of the table
  # could look like this:
  #
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Smoker    | Smoker_0    | Smoker_0    | Smoker_1    | Smoker_1    |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Pollution | Pollution_0 | Pollution_1 | Pollution_0 | Pollution_1 |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Cancer_0  | 0.03        | 0.05        | 0.001       | 0.02        |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Cancer_1  | 0.97        | 0.95        | 0.999       | 0.98        |
  #   +-----------+-------------+-------------+-------------+-------------+
  #
  class TabularCPD
    attr_reader :evidence, :variable

    #
    # @param variable [Variable] variable the distribution is over; its cardinality is the
    #   required number of rows in `data`
    # @param evidence [Array<Variable>] variables this one depends on; the product of their
    #   cardinalities (1 when there are none) is the required number of columns in every row
    # @param data [Array<Array>] the probability table, one row per outcome of `variable`
    # @raise [ArgumentError] if `data` has the wrong number of rows, or a row has the wrong
    #   number of columns
    #
    def initialize(variable, evidence, data)
      row_count = variable.cardinality
      raise ArgumentError, "wrong number of rows; expected #{row_count}" unless data.length == row_count

      column_count = combinations(evidence)
      data.each_with_index do |row, row_index|
        next if row.length == column_count

        raise ArgumentError, "wrong number of columns in row #{row_index}; expected #{column_count}"
      end

      @variable = variable
      @evidence = evidence
      @data = data
    end

    # ASCII table: one header row per evidence variable, then one row per outcome of the variable
    def to_s
      rows = header(@evidence) + body(@data, @variable.label)
      MiniPGM::Printer.print(rows)
    end

    private

    # Stringifies each data row and prefixes it with its outcome label, "<label>_<index>"
    def body(data, label)
      data.each_with_index.map do |row, outcome|
        ["#{label}_#{outcome}"] + row.map(&:to_s)
      end
    end

    # Number of joint outcomes of the given variables: the product of their cardinalities, or 1
    def combinations(variables)
      variables.map(&:cardinality).reduce(:*) || 1
    end

    # One header row per evidence variable. Each outcome label of a variable is repeated once per
    # joint outcome of the variables after it, and that run is cycled once per joint outcome of
    # the variables before it, so every column names a distinct combination.
    def header(parents, cycles = 1)
      parents.each_with_index.map do |parent, position|
        repeats = combinations(parents.drop(position + 1))
        row = [parent.label] + (sequence(parent.cardinality, parent.label, repeats) * cycles)
        cycles *= parent.cardinality
        row
      end
    end

    # Outcome labels "<label>_0" .. "<label>_<cardinality - 1>", each repeated `repeats` times
    def sequence(cardinality, label, repeats)
      (0...cardinality).flat_map { |outcome| Array.new(repeats, "#{label}_#{outcome}") }
    end
  end
end
@@ -0,0 +1,8 @@
1
+ # frozen_string_literal: true
2
+
3
module MiniPGM
  #
  # A variable in a tabular CPD.
  #
  #   label       - name of the variable; Model#add_cpd uses it to find the matching node
  #   cardinality - number of outcomes; a TabularCPD requires this many rows of data
  #
  Variable = Struct.new(*%i[label cardinality])
end
metadata ADDED
@@ -0,0 +1,91 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: mini_pgm
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.0.1
5
+ platform: ruby
6
+ authors:
7
+ - Tristan Penman
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2021-06-04 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: rspec
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '3.10'
20
+ type: :development
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '3.10'
27
+ - !ruby/object:Gem::Dependency
28
+ name: rubocop
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
32
+ - !ruby/object:Gem::Version
33
+ version: 1.16.0
34
+ type: :development
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - "~>"
39
+ - !ruby/object:Gem::Version
40
+ version: 1.16.0
41
+ - !ruby/object:Gem::Dependency
42
+ name: simplecov
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - "~>"
46
+ - !ruby/object:Gem::Version
47
+ version: 0.21.2
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - "~>"
53
+ - !ruby/object:Gem::Version
54
+ version: 0.21.2
55
+ description: A minimal Probabilistic Graphical Model library for Ruby (very experimental)
56
+ email: tristan@tristanpenman.com
57
+ executables: []
58
+ extensions: []
59
+ extra_rdoc_files: []
60
+ files:
61
+ - lib/mini_pgm.rb
62
+ - lib/mini_pgm/edge.rb
63
+ - lib/mini_pgm/model.rb
64
+ - lib/mini_pgm/node.rb
65
+ - lib/mini_pgm/printer.rb
66
+ - lib/mini_pgm/tabular_cpd.rb
67
+ - lib/mini_pgm/variable.rb
68
+ homepage: https://github.com/tristanpenman/mini-pgm
69
+ licenses:
70
+ - MIT
71
+ metadata: {}
72
+ post_install_message:
73
+ rdoc_options: []
74
+ require_paths:
75
+ - lib
76
+ required_ruby_version: !ruby/object:Gem::Requirement
77
+ requirements:
78
+ - - ">="
79
+ - !ruby/object:Gem::Version
80
+ version: '0'
81
+ required_rubygems_version: !ruby/object:Gem::Requirement
82
+ requirements:
83
+ - - ">="
84
+ - !ruby/object:Gem::Version
85
+ version: '0'
86
+ requirements: []
87
+ rubygems_version: 3.0.9
88
+ signing_key:
89
+ specification_version: 4
90
+ summary: Minimal PGM library for Ruby
91
+ test_files: []