mini_pgm 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/lib/mini_pgm.rb +6 -0
- data/lib/mini_pgm/edge.rb +12 -0
- data/lib/mini_pgm/model.rb +113 -0
- data/lib/mini_pgm/node.rb +36 -0
- data/lib/mini_pgm/printer.rb +48 -0
- data/lib/mini_pgm/tabular_cpd.rb +82 -0
- data/lib/mini_pgm/variable.rb +8 -0
- metadata +91 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: 371cafde1b4a0699d5150dbeee1c502abdfa4f01b1b02b3c5c1ee9921de430e8
|
4
|
+
data.tar.gz: 557de3515a26ed7ff51883d93927255aa5f61e1cfbb786f1054d09787639b62e
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: d4587b67bacb38663fbcd0ac05ad98cc5e6659eba4fb7af44f0b7ffbe01ae0bb817b756571d6b9f9c7cf5801fdc9f667b1cbc6b7a70efb1b95035f5510e78aba
|
7
|
+
data.tar.gz: 80a6468b078c223fd16092947501517114969980d0fc1af38499c0f39920984d451ba8bef7b96ce54f000a4154e2d3dbffb8db8482c3118af400c2e6a02bb58e
|
data/lib/mini_pgm.rb
ADDED
@@ -0,0 +1,113 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative 'node'
|
4
|
+
|
5
|
+
module MiniPGM
|
6
|
+
#
|
7
|
+
# Represents a Probabilistic Graphical Model (PGM)
|
8
|
+
#
|
9
|
+
class Model
  # Raised by `validate!` when the model is incomplete or inconsistent
  class ModelError < StandardError; end

  #
  # Edges between individual nodes, sorted by label of the outgoing edge, e.g:
  #
  #   Pollution -> Cancer
  #   Smoker -> Cancer
  #
  attr_reader :edges

  #
  # Lookup table of labelled nodes, each associated with a set of labels for all incoming edges, e.g:
  #
  #   { Pollution, Smoker } -> Cancer -> { }
  #   { } -> Pollution -> { Cancer }
  #   { } -> Smoker -> { Cancer }
  #
  attr_reader :nodes

  # most recent error after calling `valid?`
  attr_reader :error

  #
  # Builds the model from a list of edges; each edge must respond to `from` and `to`.
  # A node is created for every label that appears at either end of an edge.
  #
  def initialize(*edges)
    @edges = sort_edges(edges)
    @nodes = reduce_edges(edges)
  end

  #
  # Attaches a CPD to the node whose label matches the CPD's variable.
  #
  # Raises ArgumentError when no such node exists, or when the CPD's evidence labels do not
  # match the node's incoming edges exactly.
  #
  def add_cpd(cpd)
    # capture the label up front; the node lookup may return nil, so the error message
    # must not depend on the node itself
    label = cpd.variable.label
    node = @nodes[label]
    raise ArgumentError, "node does not exist for label #{label}" unless node

    check_cpd_evidence!(cpd.evidence.map(&:label), node.incoming_edges)
    node.cpd = cpd
  end

  #
  # Multi-line summary listing the edges, the nodes (sorted by label) and the result of `valid?`
  #
  def to_s
    ['Edges:', edges_to_s, '', 'Nodes:', nodes_to_s, '', 'Valid:', valid?, ''].join("\n")
  end

  #
  # Raises ModelError if any node is missing a CPD, or if the cardinality that a child's CPD
  # declares for a parent differs from the cardinality of the parent's own CPD.
  #
  def validate!
    @nodes.each_value do |node|
      raise ModelError, "node '#{node.label}' does not have a CPD" unless node.cpd
    end

    # validate cardinality between nodes for each edge
    @edges.each do |edge|
      validate_cardinality!(@nodes[edge.to], @nodes[edge.from])
    end
  end

  #
  # Non-raising variant of `validate!`; the ModelError (if any) is kept in `error`
  #
  def valid?
    @error = nil
    validate!
    true
  rescue ModelError => e
    @error = e
    false
  end

  private

  # Ensures that the evidence labels of a CPD and the dependencies of a node are identical sets
  def check_cpd_evidence!(cpd_evidence, node_dependencies)
    cpd_evidence.each do |evidence|
      raise ArgumentError, "node is missing dependency for CPD evidence '#{evidence}'" \
        unless node_dependencies.include?(evidence)
    end

    node_dependencies.each do |dependency|
      raise ArgumentError, "CPD is missing evidence for node dependency '#{dependency}'" \
        unless cpd_evidence.include?(dependency)
    end
  end

  def edges_to_s
    @edges.map(&:to_s).join("\n")
  end

  def nodes_to_s
    @nodes.keys.sort.map { |key| @nodes[key].to_s }.join("\n")
  end

  # Builds the label => node lookup table, recording incoming and outgoing labels on each node
  def reduce_edges(edges)
    edges.each_with_object({}) do |edge, reduced|
      # create node for incoming edge
      (reduced[edge.to] ||= MiniPGM::Node.new(edge.to)).incoming_edges.add(edge.from)

      # create node for outgoing edge
      (reduced[edge.from] ||= MiniPGM::Node.new(edge.from)).outgoing_edges.add(edge.to)
    end
  end

  def sort_edges(edges)
    edges.sort_by(&:from)
  end

  # Compares the cardinality `to`'s CPD expects of `from` with the one `from`'s CPD declares
  def validate_cardinality!(to, from)
    evidence = to.cpd.evidence.find { |ev| ev.label == from.label }

    # a CPD assigned directly through Node#cpd= bypasses the checks made in add_cpd, so the
    # evidence may be absent; report that as a model error rather than failing on nil
    raise ModelError, "CPD of '#{to.label}' is missing evidence for '#{from.label}'" unless evidence

    expected = evidence.cardinality
    actual = from.cpd.variable.cardinality

    raise ModelError, "cardinality mismatch in CPDs of '#{from.label}' (#{actual}) and '#{to.label}' (#{expected})" \
      unless expected == actual
  end
end
|
113
|
+
end
|
@@ -0,0 +1,36 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'set'
|
4
|
+
|
5
|
+
module MiniPGM
|
6
|
+
#
|
7
|
+
# Represents an individual node, or random variable, in a model
|
8
|
+
#
|
9
|
+
# Example string representation for a node with two incoming edges, and no outgoing edges:
|
10
|
+
#
|
11
|
+
# { Pollution, Smoker } -> Cancer -> { }
|
12
|
+
#
|
13
|
+
# Or for nodes with no incoming edges and only outgoing edges:
|
14
|
+
#
|
15
|
+
# { } -> Pollution -> { Cancer }
|
16
|
+
# { } -> Smoker -> { Cancer }
|
17
|
+
#
|
18
|
+
class Node
  # Labels of neighbouring nodes are kept in sets, so duplicate edges collapse
  attr_reader :label, :incoming_edges, :outgoing_edges

  # CPD associated with this node; nil until one has been assigned
  attr_accessor :cpd

  # Creates a node with the given label and no edges
  def initialize(label)
    @incoming_edges = Set.new
    @outgoing_edges = Set.new
    @label = label
  end

  # Renders the node as "{ incoming } -> label -> { outgoing }"
  def to_s
    incoming = write_set(@incoming_edges)
    outgoing = write_set(@outgoing_edges)

    [incoming, @label, outgoing].join(' -> ')
  end

  # Renders a set of labels in sorted order, e.g. "{ Pollution, Smoker }", or "{ }" when empty
  def write_set(set)
    return '{ }' if set.empty?

    members = set.to_a.sort
    "{ #{members.join(', ')} }"
  end
end
|
36
|
+
end
|
@@ -0,0 +1,48 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module MiniPGM
|
4
|
+
#
|
5
|
+
# Helper class for printing data in an easily readable ASCII format
|
6
|
+
#
|
7
|
+
# Example output:
|
8
|
+
#
|
9
|
+
# +-----------+-------------+-------------+-------------+-------------+
|
10
|
+
# | Smoker | Smoker_0 | Smoker_0 | Smoker_1 | Smoker_1 |
|
11
|
+
# +-----------+-------------+-------------+-------------+-------------+
|
12
|
+
# | Pollution | Pollution_0 | Pollution_1 | Pollution_0 | Pollution_1 |
|
13
|
+
# +-----------+-------------+-------------+-------------+-------------+
|
14
|
+
# | Cancer_0 | 0.03 | 0.05 | 0.001 | 0.02 |
|
15
|
+
# +-----------+-------------+-------------+-------------+-------------+
|
16
|
+
# | Cancer_1 | 0.97 | 0.95 | 0.999 | 0.98 |
|
17
|
+
# +-----------+-------------+-------------+-------------+-------------+
|
18
|
+
#
|
19
|
+
class Printer
  class << self
    #
    # Renders rows of cells as an ASCII table and returns it as a string (nothing is written
    # to stdout). Each column is as wide as its widest cell, and cells are left-justified.
    #
    # The number of columns is taken from the first row. Cells are converted with `to_s`, so
    # numeric values may be passed directly. An empty list of rows yields an empty string.
    #
    def print(rows)
      # nothing to measure or draw; avoids calling `length` on a missing first row
      return '' if rows.empty?

      # calculate column widths
      num_columns = rows[0].length
      column_widths = (0...num_columns).map { |col| max_width(rows, col) }

      # write table
      div = write_divider(column_widths)
      rows.flat_map { |row| [write_row(column_widths, row), div] }
          .unshift(div)
          .join("\n")
    end

    private

    # Width of the widest cell in the given column
    def max_width(rows, column)
      rows.map { |row| row[column].to_s.length }.max
    end

    # Horizontal rule, e.g. "+---+----+" for column widths 1 and 2
    def write_divider(column_widths)
      "+-#{column_widths.map { |g| '-' * g }.join('-+-')}-+"
    end

    # Single table row, with every cell padded to the width of its column
    def write_row(column_widths, values)
      "| #{values.map.with_index { |value, i| value.to_s.ljust(column_widths[i]) }.join(' | ')} |"
    end
  end
end
|
48
|
+
end
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative 'printer'
|
4
|
+
|
5
|
+
module MiniPGM
|
6
|
+
#
|
7
|
+
# Represents a tabular Conditional Probability Distribution (CPD)
|
8
|
+
#
|
9
|
+
# The CPD defines the probabilities of particular outcomes for a variable, given observations
|
10
|
+
# of the variables that it is dependent on (i.e. the evidence).
|
11
|
+
#
|
12
|
+
# For example, assuming the variable 'Cancer' is dependent on 'Pollution' and 'Smoker', each of
|
13
|
+
# which have two outcomes. Then if 'Cancer' also has two outcomes, then the string representation
|
14
|
+
# of the table could look like this:
|
15
|
+
#
|
16
|
+
# +-----------+-------------+-------------+-------------+-------------+
|
17
|
+
# | Smoker | Smoker_0 | Smoker_0 | Smoker_1 | Smoker_1 |
|
18
|
+
# +-----------+-------------+-------------+-------------+-------------+
|
19
|
+
# | Pollution | Pollution_0 | Pollution_1 | Pollution_0 | Pollution_1 |
|
20
|
+
# +-----------+-------------+-------------+-------------+-------------+
|
21
|
+
# | Cancer_0 | 0.03 | 0.05 | 0.001 | 0.02 |
|
22
|
+
# +-----------+-------------+-------------+-------------+-------------+
|
23
|
+
# | Cancer_1 | 0.97 | 0.95 | 0.999 | 0.98 |
|
24
|
+
# +-----------+-------------+-------------+-------------+-------------+
|
25
|
+
#
|
26
|
+
class TabularCPD
  attr_reader :evidence, :variable

  #
  # variable:: the variable described by this table; must respond to `label` and `cardinality`
  # evidence:: list of variables that `variable` is conditioned on
  # data::     one row per outcome of `variable`, one column per combination of evidence outcomes
  #
  # Raises ArgumentError when the shape of `data` does not match the cardinalities.
  #
  def initialize(variable, evidence, data)
    row_count = variable.cardinality
    raise ArgumentError, "wrong number of rows; expected #{row_count}" if data.length != row_count

    column_count = combinations(evidence)
    data.each_with_index do |cells, row_number|
      next if cells.length == column_count

      raise ArgumentError, "wrong number of columns in row #{row_number}; expected #{column_count}"
    end

    @variable = variable
    @evidence = evidence
    @data = data
  end

  # Table rendered by MiniPGM::Printer: one header row per evidence variable, then the data rows
  def to_s
    header_rows = header(@evidence)
    body_rows = body(@data, @variable.label)

    MiniPGM::Printer.print(header_rows + body_rows)
  end

  private

  # Data rows as strings, each prefixed with "<label>_<row index>"
  def body(data, label)
    data.each_with_index.map do |cells, row_number|
      ["#{label}_#{row_number}"] + cells.map(&:to_s)
    end
  end

  # Product of the cardinalities of the given variables; 1 when there are none
  def combinations(variables)
    product = variables.map(&:cardinality).reduce(:*)
    product || 1
  end

  #
  # Header rows for the given evidence variables. The first variable changes slowest; each
  # later variable repeats its whole sequence once per outcome combination of the earlier ones.
  #
  def header(parents, cycles = 1)
    first, *rest = parents

    # base case
    return [] unless first

    # each observation of `first` spans every combination of the remaining variables
    span = combinations(rest)
    name = first.label
    outcomes = first.cardinality
    current = [name] + sequence(outcomes, name, span) * cycles

    # remaining variables follow the current row
    [current] + header(rest, cycles * outcomes)
  end

  # Observation labels "<label>_0", "<label>_1", ... with each one repeated `repeats` times
  def sequence(cardinality, label, repeats)
    (0...cardinality).flat_map do |outcome|
      ["#{label}_#{outcome}"] * repeats
    end
  end
end
|
82
|
+
end
|
metadata
ADDED
@@ -0,0 +1,91 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: mini_pgm
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.0.1
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Tristan Penman
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2021-06-04 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: rspec
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - "~>"
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '3.10'
|
20
|
+
type: :development
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - "~>"
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '3.10'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: rubocop
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - "~>"
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: 1.16.0
|
34
|
+
type: :development
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - "~>"
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: 1.16.0
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: simplecov
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - "~>"
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: 0.21.2
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - "~>"
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: 0.21.2
|
55
|
+
description: A minimal Probabilistic Graphical Model library for Ruby (very experimental)
|
56
|
+
email: tristan@tristanpenman.com
|
57
|
+
executables: []
|
58
|
+
extensions: []
|
59
|
+
extra_rdoc_files: []
|
60
|
+
files:
|
61
|
+
- lib/mini_pgm.rb
|
62
|
+
- lib/mini_pgm/edge.rb
|
63
|
+
- lib/mini_pgm/model.rb
|
64
|
+
- lib/mini_pgm/node.rb
|
65
|
+
- lib/mini_pgm/printer.rb
|
66
|
+
- lib/mini_pgm/tabular_cpd.rb
|
67
|
+
- lib/mini_pgm/variable.rb
|
68
|
+
homepage: https://github.com/tristanpenman/mini-pgm
|
69
|
+
licenses:
|
70
|
+
- MIT
|
71
|
+
metadata: {}
|
72
|
+
post_install_message:
|
73
|
+
rdoc_options: []
|
74
|
+
require_paths:
|
75
|
+
- lib
|
76
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
77
|
+
requirements:
|
78
|
+
- - ">="
|
79
|
+
- !ruby/object:Gem::Version
|
80
|
+
version: '0'
|
81
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
82
|
+
requirements:
|
83
|
+
- - ">="
|
84
|
+
- !ruby/object:Gem::Version
|
85
|
+
version: '0'
|
86
|
+
requirements: []
|
87
|
+
rubygems_version: 3.0.9
|
88
|
+
signing_key:
|
89
|
+
specification_version: 4
|
90
|
+
summary: Minimal PGM library for Ruby
|
91
|
+
test_files: []
|