mini_pgm 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 371cafde1b4a0699d5150dbeee1c502abdfa4f01b1b02b3c5c1ee9921de430e8
4
+ data.tar.gz: 557de3515a26ed7ff51883d93927255aa5f61e1cfbb786f1054d09787639b62e
5
+ SHA512:
6
+ metadata.gz: d4587b67bacb38663fbcd0ac05ad98cc5e6659eba4fb7af44f0b7ffbe01ae0bb817b756571d6b9f9c7cf5801fdc9f667b1cbc6b7a70efb1b95035f5510e78aba
7
+ data.tar.gz: 80a6468b078c223fd16092947501517114969980d0fc1af38499c0f39920984d451ba8bef7b96ce54f000a4154e2d3dbffb8db8482c3118af400c2e6a02bb58e
data/lib/mini_pgm.rb ADDED
@@ -0,0 +1,6 @@
1
+ require 'mini_pgm/edge'
2
+ require 'mini_pgm/model'
3
+ require 'mini_pgm/node'
4
+ require 'mini_pgm/printer'
5
+ require 'mini_pgm/tabular_cpd'
6
+ require 'mini_pgm/variable'
@@ -0,0 +1,12 @@
1
+ # frozen_string_literal: true
2
+
3
module MiniPGM
  #
  # A directed edge in a PGM, pointing from the node labelled `from` to the node labelled `to`
  #
  Edge = Struct.new(:from, :to)

  # Reopened to attach the human-readable rendering used when a model prints its edges
  class Edge
    #
    # Renders the edge as an arrow between its two endpoints, e.g:
    #
    #   Smoker -> Cancer
    #
    def to_s
      format('%s -> %s', from, to)
    end
  end
end
@@ -0,0 +1,113 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative 'node'
4
+
5
module MiniPGM
  #
  # Represents a Probabilistic Graphical Model (PGM)
  #
  # A model is constructed from a list of directed edges. Nodes are derived from the labels used
  # by those edges, and each node must then be given a CPD (via `add_cpd`) before the model can
  # be considered valid.
  #
  class Model
    # Raised by `validate!` when the model is incomplete or its CPDs are inconsistent
    class ModelError < StandardError; end

    #
    # Edges between individual nodes, sorted by the label of the node each edge starts from, e.g:
    #
    #   Pollution -> Cancer
    #   Smoker -> Cancer
    #
    attr_reader :edges

    #
    # Lookup table of labelled nodes, each associated with a set of labels for all incoming edges, e.g:
    #
    #   { Pollution, Smoker } -> Cancer -> { }
    #   { } -> Pollution -> { Cancer }
    #   { } -> Smoker -> { Cancer }
    #
    attr_reader :nodes

    # most recent error after calling `valid?` (nil if the last check passed)
    attr_reader :error

    #
    # Builds a model from zero or more edges; each edge must respond to `from` and `to`
    #
    def initialize(*edges)
      @edges = sort_edges(edges)
      @nodes = reduce_edges(edges)
    end

    #
    # Attaches a CPD to the node whose label matches the CPD's variable
    #
    # Raises ArgumentError if no such node exists, or if the labels of the CPD's evidence do not
    # match the node's incoming edges exactly.
    #
    def add_cpd(cpd)
      # keep the label in a local; the node lookup may return nil, so the error message
      # cannot be built from the node itself
      label = cpd.variable.label
      node = @nodes[label]
      raise ArgumentError, "node does not exist for label #{label}" unless node

      check_cpd_evidence!(cpd.evidence.map(&:label), node.incoming_edges)
      node.cpd = cpd
    end

    #
    # Multi-line summary listing edges, nodes and the result of `valid?`
    #
    # Note that this calls `valid?`, and therefore refreshes `error`
    #
    def to_s
      ['Edges:', edges_to_s, '', 'Nodes:', nodes_to_s, '', 'Valid:', valid?, ''].join("\n")
    end

    #
    # Raises ModelError if any node lacks a CPD, or if the CPDs at either end of an edge disagree
    #
    def validate!
      @nodes.each_value do |node|
        raise ModelError, "node '#{node.label}' does not have a CPD" unless node.cpd
      end

      # validate cardinality between nodes for each edge
      @edges.each do |edge|
        validate_cardinality!(@nodes[edge.to], @nodes[edge.from])
      end
    end

    #
    # Non-raising variant of `validate!`; returns true or false, and records any failure in `error`
    #
    def valid?
      @error = nil
      validate!
      true
    rescue ModelError => e
      @error = e
      false
    end

    private

    # Ensures the CPD's evidence labels and the node's incoming edge labels are the same collection
    def check_cpd_evidence!(cpd_evidence, node_dependencies)
      cpd_evidence.each do |evidence|
        raise ArgumentError, "node is missing dependency for CPD evidence '#{evidence}'" \
          unless node_dependencies.include?(evidence)
      end

      node_dependencies.each do |dependency|
        raise ArgumentError, "CPD is missing evidence for node dependency '#{dependency}'" \
          unless cpd_evidence.include?(dependency)
      end
    end

    def edges_to_s
      @edges.map(&:to_s).join("\n")
    end

    def nodes_to_s
      @nodes.keys.sort.map { |key| @nodes[key].to_s }.join("\n")
    end

    # Builds the label -> node lookup table, recording both directions of every edge
    def reduce_edges(edges)
      edges.each_with_object({}) do |edge, reduced|
        # create node for incoming edge
        (reduced[edge.to] ||= MiniPGM::Node.new(edge.to)).incoming_edges.add(edge.from)

        # create node for outgoing edge
        (reduced[edge.from] ||= MiniPGM::Node.new(edge.from)).outgoing_edges.add(edge.to)
      end
    end

    def sort_edges(edges)
      edges.sort_by(&:from)
    end

    # Checks that the cardinality `to` assumes for `from` (in its evidence) matches the
    # cardinality declared by `from`'s own CPD
    def validate_cardinality!(to, from)
      # a CPD assigned directly through the node (bypassing `add_cpd`) may not list `from` as
      # evidence; report that as a model error rather than failing on nil
      evidence = to.cpd.evidence.find { |ev| ev.label == from.label }
      raise ModelError, "CPD of '#{to.label}' is missing evidence for '#{from.label}'" unless evidence

      expected = evidence.cardinality
      actual = from.cpd.variable.cardinality

      raise ModelError, "cardinality mismatch in CPDs of '#{from.label}' (#{actual}) and '#{to.label}' (#{expected})" \
        unless expected == actual
    end
  end
end
@@ -0,0 +1,36 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'set'
4
+
5
module MiniPGM
  #
  # A single node (random variable) in a model, identified by its label
  #
  # Each node tracks the labels of the nodes it receives edges from and sends edges to, and may
  # hold a CPD. The string form places incoming labels on the left and outgoing labels on the
  # right, e.g. for a node with two incoming edges and no outgoing edges:
  #
  #   { Pollution, Smoker } -> Cancer -> { }
  #
  # And for nodes with only outgoing edges:
  #
  #   { } -> Pollution -> { Cancer }
  #   { } -> Smoker -> { Cancer }
  #
  class Node
    attr_reader :label, :incoming_edges, :outgoing_edges
    attr_accessor :cpd

    # Creates a node with the given label and no edges in either direction
    def initialize(label)
      @incoming_edges = Set.new
      @outgoing_edges = Set.new
      @label = label
    end

    # Renders the node as "incoming -> label -> outgoing"
    def to_s
      parents = write_set(incoming_edges)
      children = write_set(outgoing_edges)
      [parents, label, children].join(' -> ')
    end

    # Formats a collection of labels as a sorted, brace-delimited list
    def write_set(set)
      return '{ }' if set.empty?

      labels = set.to_a.sort
      "{ #{labels.join(', ')} }"
    end
  end
end
@@ -0,0 +1,48 @@
1
+ # frozen_string_literal: true
2
+
3
module MiniPGM
  #
  # Helper class for printing data in an easily readable ASCII format
  #
  # Example output:
  #
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Smoker    | Smoker_0    | Smoker_0    | Smoker_1    | Smoker_1    |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Pollution | Pollution_0 | Pollution_1 | Pollution_0 | Pollution_1 |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Cancer_0  | 0.03        | 0.05        | 0.001       | 0.02        |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Cancer_1  | 0.97        | 0.95        | 0.999       | 0.98        |
  #   +-----------+-------------+-------------+-------------+-------------+
  #
  class Printer
    class << self
      #
      # Renders `rows` (an array of arrays of cells) as an ASCII table and returns it as a string
      #
      # Cells are converted with `to_s`, so they need not be strings. Rows shorter than the
      # longest row are padded with empty cells. An empty list of rows yields an empty string.
      #
      def print(rows)
        return '' if rows.empty?

        # normalise the table so every row has the same number of string cells
        num_columns = rows.map(&:length).max
        cells = rows.map { |row| Array.new(num_columns) { |col| row[col].to_s } }

        # calculate column widths
        column_widths = (0...num_columns).map { |col| max_width(cells, col) }

        # write table
        div = write_divider(column_widths)
        cells.flat_map { |row| [write_row(column_widths, row), div] }
             .unshift(div)
             .join("\n")
      end

      private

      # Width of the widest cell in the given column
      def max_width(rows, column)
        rows.map { |row| row[column].length }.max
      end

      # Horizontal rule matching the column widths, e.g. "+-----+----+"
      def write_divider(column_widths)
        "+-#{column_widths.map { |g| '-' * g }.join('-+-')}-+"
      end

      # Single table row, with each value left-justified to its column width
      def write_row(column_widths, values)
        "| #{values.map.with_index { |value, i| value.ljust(column_widths[i]) }.join(' | ')} |"
      end
    end
  end
end
@@ -0,0 +1,82 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative 'printer'
4
+
5
module MiniPGM
  #
  # A Conditional Probability Distribution (CPD) stored as a table
  #
  # The table gives the probability of each outcome of a variable for every combination of
  # outcomes of the variables it depends on (the evidence). It has one row per outcome of the
  # variable, and one column per combination of evidence outcomes.
  #
  # For instance, if 'Cancer' depends on 'Pollution' and 'Smoker', and all three variables have
  # two outcomes, the string representation could look like this:
  #
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Smoker    | Smoker_0    | Smoker_0    | Smoker_1    | Smoker_1    |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Pollution | Pollution_0 | Pollution_1 | Pollution_0 | Pollution_1 |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Cancer_0  | 0.03        | 0.05        | 0.001       | 0.02        |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Cancer_1  | 0.97        | 0.95        | 0.999       | 0.98        |
  #   +-----------+-------------+-------------+-------------+-------------+
  #
  class TabularCPD
    attr_reader :evidence, :variable

    #
    # Raises ArgumentError unless `data` has one row per outcome of `variable`, and every row has
    # one column per combination of `evidence` outcomes
    #
    def initialize(variable, evidence, data)
      row_count = variable.cardinality
      raise ArgumentError, "wrong number of rows; expected #{row_count}" if data.length != row_count

      col_count = combinations(evidence)
      data.each_with_index do |row, index|
        next if row.length == col_count

        raise ArgumentError, "wrong number of columns in row #{index}; expected #{col_count}"
      end

      @variable = variable
      @evidence = evidence
      @data = data
    end

    # Renders the table: one header row per evidence variable, followed by the data rows
    def to_s
      table = header(@evidence) + body(@data, @variable.label)
      MiniPGM::Printer.print(table)
    end

    private

    # Data rows as strings, each prefixed with "<label>_<row index>"
    def body(data, label)
      data.each_with_index.map do |row, index|
        ["#{label}_#{index}", *row.map(&:to_s)]
      end
    end

    # Number of distinct outcome combinations across the given variables (1 when there are none)
    def combinations(variables)
      variables.reduce(1) { |product, var| product * var.cardinality }
    end

    #
    # Header rows, one per parent. Each observation of a parent is repeated once for every
    # combination of the parents after it, and the whole pattern is repeated `cycles` times to
    # account for the parents before it.
    #
    def header(parents, cycles = 1)
      first, *rest = parents

      # nothing left to describe
      return [] unless first

      cells = sequence(first.cardinality, first.label, combinations(rest)) * cycles

      # this parent's row, followed by the rows of the remaining parents
      [[first.label, *cells], *header(rest, cycles * first.cardinality)]
    end

    # Observations "<label>_0", "<label>_1", ... with each one repeated `repeats` times in a row
    def sequence(cardinality, label, repeats)
      (0...cardinality).each_with_object([]) do |index, cells|
        cells.concat(Array.new(repeats, "#{label}_#{index}"))
      end
    end
  end
end
@@ -0,0 +1,8 @@
1
+ # frozen_string_literal: true
2
+
3
module MiniPGM
  #
  # A variable referenced by a tabular CPD
  #
  # Holds the variable's `label`, and its `cardinality` (the number of outcomes it can take)
  #
  Variable = Struct.new(*%i[label cardinality])
end
metadata ADDED
@@ -0,0 +1,91 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: mini_pgm
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.0.1
5
+ platform: ruby
6
+ authors:
7
+ - Tristan Penman
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2021-06-04 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: rspec
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '3.10'
20
+ type: :development
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '3.10'
27
+ - !ruby/object:Gem::Dependency
28
+ name: rubocop
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
32
+ - !ruby/object:Gem::Version
33
+ version: 1.16.0
34
+ type: :development
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - "~>"
39
+ - !ruby/object:Gem::Version
40
+ version: 1.16.0
41
+ - !ruby/object:Gem::Dependency
42
+ name: simplecov
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - "~>"
46
+ - !ruby/object:Gem::Version
47
+ version: 0.21.2
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - "~>"
53
+ - !ruby/object:Gem::Version
54
+ version: 0.21.2
55
+ description: A minimal Probabilistic Graphical Model library for Ruby (very experimental)
56
+ email: tristan@tristanpenman.com
57
+ executables: []
58
+ extensions: []
59
+ extra_rdoc_files: []
60
+ files:
61
+ - lib/mini_pgm.rb
62
+ - lib/mini_pgm/edge.rb
63
+ - lib/mini_pgm/model.rb
64
+ - lib/mini_pgm/node.rb
65
+ - lib/mini_pgm/printer.rb
66
+ - lib/mini_pgm/tabular_cpd.rb
67
+ - lib/mini_pgm/variable.rb
68
+ homepage: https://github.com/tristanpenman/mini-pgm
69
+ licenses:
70
+ - MIT
71
+ metadata: {}
72
+ post_install_message:
73
+ rdoc_options: []
74
+ require_paths:
75
+ - lib
76
+ required_ruby_version: !ruby/object:Gem::Requirement
77
+ requirements:
78
+ - - ">="
79
+ - !ruby/object:Gem::Version
80
+ version: '0'
81
+ required_rubygems_version: !ruby/object:Gem::Requirement
82
+ requirements:
83
+ - - ">="
84
+ - !ruby/object:Gem::Version
85
+ version: '0'
86
+ requirements: []
87
+ rubygems_version: 3.0.9
88
+ signing_key:
89
+ specification_version: 4
90
+ summary: Minimal PGM library for Ruby
91
+ test_files: []