mini_pgm 0.0.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/lib/mini_pgm.rb +6 -0
- data/lib/mini_pgm/edge.rb +12 -0
- data/lib/mini_pgm/model.rb +113 -0
- data/lib/mini_pgm/node.rb +36 -0
- data/lib/mini_pgm/printer.rb +48 -0
- data/lib/mini_pgm/tabular_cpd.rb +82 -0
- data/lib/mini_pgm/variable.rb +8 -0
- metadata +91 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: 371cafde1b4a0699d5150dbeee1c502abdfa4f01b1b02b3c5c1ee9921de430e8
|
4
|
+
data.tar.gz: 557de3515a26ed7ff51883d93927255aa5f61e1cfbb786f1054d09787639b62e
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: d4587b67bacb38663fbcd0ac05ad98cc5e6659eba4fb7af44f0b7ffbe01ae0bb817b756571d6b9f9c7cf5801fdc9f667b1cbc6b7a70efb1b95035f5510e78aba
|
7
|
+
data.tar.gz: 80a6468b078c223fd16092947501517114969980d0fc1af38499c0f39920984d451ba8bef7b96ce54f000a4154e2d3dbffb8db8482c3118af400c2e6a02bb58e
|
data/lib/mini_pgm/model.rb
ADDED
@@ -0,0 +1,113 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative 'node'
|
4
|
+
|
5
|
+
module MiniPGM
  #
  # Represents a Probabilistic Graphical Model (PGM)
  #
  # The model is built from directed edges; a node is created for every label that appears at
  # either end of an edge. Every node needs a CPD (see `add_cpd`) before the model is valid.
  #
  class Model
    class ModelError < StandardError; end

    #
    # Edges between individual nodes, sorted by the label each edge comes from (ties broken by
    # the label it goes to, so the order is deterministic), e.g:
    #
    #   Pollution -> Cancer
    #   Smoker -> Cancer
    #
    attr_reader :edges

    #
    # Lookup table of labelled nodes, each associated with a set of labels for all incoming edges, e.g:
    #
    #   { Pollution, Smoker } -> Cancer -> { }
    #   { } -> Pollution -> { Cancer }
    #   { } -> Smoker -> { Cancer }
    #
    attr_reader :nodes

    # most recent error after calling `valid?` (nil when the last call returned true)
    attr_reader :error

    #
    # @param edges edges responding to `from` and `to`, each returning a node label
    #
    def initialize(*edges)
      @edges = sort_edges(edges)
      @nodes = reduce_edges(edges)
    end

    #
    # Attaches a CPD to the node whose label matches the CPD's variable.
    #
    # @param cpd a CPD responding to `variable` and `evidence` (e.g. a TabularCPD)
    # @raise [ArgumentError] if no node exists for the CPD's variable, or if the labels of the
    #   CPD's evidence differ from the labels of the node's incoming edges
    #
    def add_cpd(cpd)
      label = cpd.variable.label
      node = @nodes[label]
      # use the looked-up label in the message: `node` is nil here and cannot provide one
      raise ArgumentError, "node does not exist for label #{label}" unless node

      check_cpd_evidence!(cpd.evidence.map(&:label), node.incoming_edges)
      node.cpd = cpd
    end

    # Multi-line summary of the edges, the nodes (sorted by label) and the result of `valid?`
    def to_s
      ['Edges:', edges_to_s, '', 'Nodes:', nodes_to_s, '', 'Valid:', valid?, ''].join("\n")
    end

    #
    # Checks that the model is complete and consistent.
    #
    # @raise [ModelError] if any node lacks a CPD, if a CPD lacks evidence for one of its node's
    #   parents, or if a parent's cardinality differs from the cardinality the child's CPD expects
    #
    def validate!
      @nodes.each_value do |node|
        raise ModelError, "node '#{node.label}' does not have a CPD" unless node.cpd
      end

      # validate cardinality between nodes for each edge
      @edges.each do |edge|
        validate_cardinality!(@nodes[edge.to], @nodes[edge.from])
      end
    end

    #
    # @return [Boolean] true if `validate!` passes; otherwise false, with the failure kept in `error`
    #
    def valid?
      @error = nil
      validate!
      true
    rescue ModelError => e
      @error = e
      false
    end

    private

    # Raises ArgumentError unless the CPD evidence labels and the node's dependency labels match
    # (every evidence label is a dependency, and every dependency appears as evidence)
    def check_cpd_evidence!(cpd_evidence, node_dependencies)
      cpd_evidence.each do |evidence|
        raise ArgumentError, "node is missing dependency for CPD evidence '#{evidence}'" \
          unless node_dependencies.include?(evidence)
      end

      node_dependencies.each do |dependency|
        raise ArgumentError, "CPD is missing evidence for node dependency '#{dependency}'" \
          unless cpd_evidence.include?(dependency)
      end
    end

    def edges_to_s
      @edges.map(&:to_s).join("\n")
    end

    def nodes_to_s
      @nodes.keys.sort.map { |key| @nodes[key].to_s }.join("\n")
    end

    # Builds the label -> Node lookup, recording each edge on both of its endpoint nodes
    def reduce_edges(edges)
      edges.each_with_object({}) do |edge, reduced|
        # create node for incoming edge
        (reduced[edge.to] ||= MiniPGM::Node.new(edge.to)).incoming_edges.add(edge.from)

        # create node for outgoing edge
        (reduced[edge.from] ||= MiniPGM::Node.new(edge.from)).outgoing_edges.add(edge.to)
      end
    end

    # Orders edges by source label, then destination label; `sort_by` is not stable, so the
    # secondary key keeps the order deterministic when several edges share a source
    def sort_edges(edges)
      edges.sort_by { |edge| [edge.from, edge.to] }
    end

    # Raises ModelError unless the child's (`to`) CPD declares evidence for the parent (`from`)
    # with the same cardinality as the parent's own CPD variable
    def validate_cardinality!(to, from)
      evidence = to.cpd.evidence.find { |ev| ev.label == from.label }
      # a CPD assigned directly via `Node#cpd=` bypasses the evidence check done by `add_cpd`
      raise ModelError, "CPD of '#{to.label}' is missing evidence for '#{from.label}'" unless evidence

      expected = evidence.cardinality
      actual = from.cpd.variable.cardinality

      raise ModelError, "cardinality mismatch in CPDs of '#{from.label}' (#{actual}) and '#{to.label}' (#{expected})" \
        unless expected == actual
    end
  end
end
|
@@ -0,0 +1,36 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'set'
|
4
|
+
|
5
|
+
module MiniPGM
  #
  # A single node (random variable) in a model. It records the labels of the nodes it depends
  # on (incoming edges), the labels of the nodes that depend on it (outgoing edges), and an
  # optional CPD.
  #
  # String form for a node with two incoming edges and no outgoing edges:
  #
  #   { Pollution, Smoker } -> Cancer -> { }
  #
  # And for nodes that only have outgoing edges:
  #
  #   { } -> Pollution -> { Cancer }
  #   { } -> Smoker -> { Cancer }
  #
  class Node
    attr_accessor :cpd
    attr_reader :label, :incoming_edges, :outgoing_edges

    # @param label the node's label; both edge sets start out empty
    def initialize(label)
      @label = label
      @incoming_edges, @outgoing_edges = Set.new, Set.new
    end

    # Renders "<incoming set> -> label -> <outgoing set>"
    def to_s
      "#{write_set(@incoming_edges)} -> #{@label} -> #{write_set(@outgoing_edges)}"
    end

    # Renders a set as "{ a, b }" with its members sorted, or "{ }" when it is empty
    def write_set(set)
      return '{ }' if set.empty?

      "{ #{set.sort.join(', ')} }"
    end
  end
end
|
@@ -0,0 +1,48 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module MiniPGM
  #
  # Helper class for printing data in an easily readable ASCII format
  #
  # Each row is an array of cells; cells are rendered with `to_s` and left-justified to the
  # width of their column.
  #
  # Example output:
  #
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Smoker    | Smoker_0    | Smoker_0    | Smoker_1    | Smoker_1    |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Pollution | Pollution_0 | Pollution_1 | Pollution_0 | Pollution_1 |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Cancer_0  | 0.03        | 0.05        | 0.001       | 0.02        |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Cancer_1  | 0.97        | 0.95        | 0.999       | 0.98        |
  #   +-----------+-------------+-------------+-------------+-------------+
  #
  class Printer
    class << self
      #
      # Renders rows as an ASCII table, with a divider line above, between and below the rows.
      #
      # @param rows [Array<Array>] table rows; cells are rendered with `to_s`, and rows shorter
      #   than the widest row are padded with empty cells
      # @return [String] the rendered table, or an empty string when there are no rows
      #
      def print(rows)
        return '' if rows.empty?

        # calculate column widths; the widest row determines the number of columns
        num_columns = rows.map(&:length).max
        column_widths = (0...num_columns).map { |col| max_width(rows, col) }

        # write table
        div = write_divider(column_widths)
        rows.flat_map { |row| [write_row(column_widths, row), div] }
            .unshift(div)
            .join("\n")
      end

      private

      # Width of the widest cell in the given column (missing cells count as empty)
      def max_width(rows, column)
        rows.map { |row| row[column].to_s.length }.max
      end

      def write_divider(column_widths)
        "+-#{column_widths.map { |width| '-' * width }.join('-+-')}-+"
      end

      # Writes one row, left-justifying each cell to its column's width (missing cells are blank)
      def write_row(column_widths, values)
        cells = column_widths.each_with_index.map { |width, i| values[i].to_s.ljust(width) }
        "| #{cells.join(' | ')} |"
      end
    end
  end
end
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative 'printer'
|
4
|
+
|
5
|
+
module MiniPGM
  #
  # Represents a tabular Conditional Probability Distribution (CPD)
  #
  # The table gives the probability of each outcome of a variable (one row per outcome) for every
  # combination of outcomes of the variables it depends on, i.e. the evidence (one column per
  # combination).
  #
  # For example, suppose 'Cancer' depends on 'Pollution' and 'Smoker', each of which has two
  # outcomes, and 'Cancer' also has two outcomes. The string representation of the table could
  # then look like this:
  #
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Smoker    | Smoker_0    | Smoker_0    | Smoker_1    | Smoker_1    |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Pollution | Pollution_0 | Pollution_1 | Pollution_0 | Pollution_1 |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Cancer_0  | 0.03        | 0.05        | 0.001       | 0.02        |
  #   +-----------+-------------+-------------+-------------+-------------+
  #   | Cancer_1  | 0.97        | 0.95        | 0.999       | 0.98        |
  #   +-----------+-------------+-------------+-------------+-------------+
  #
  class TabularCPD
    attr_reader :evidence, :variable

    #
    # @param variable the variable described by this CPD (responds to `label` and `cardinality`)
    # @param evidence ordered list of variables this one depends on
    # @param data one row per outcome of `variable`, each holding one value per evidence combination
    # @raise [ArgumentError] if the number of rows or columns does not match the cardinalities
    #
    def initialize(variable, evidence, data)
      row_count = variable.cardinality
      unless data.length == row_count
        raise ArgumentError, "wrong number of rows; expected #{row_count}"
      end

      col_count = combinations(evidence)
      data.each_with_index do |row, row_index|
        next if row.length == col_count

        raise ArgumentError, "wrong number of columns in row #{row_index}; expected #{col_count}"
      end

      @variable = variable
      @evidence = evidence
      @data = data
    end

    # Renders the evidence header rows followed by one labelled row per outcome
    def to_s
      rows = header(@evidence)
      rows += body(@data, @variable.label)
      MiniPGM::Printer.print(rows)
    end

    private

    # One output row per data row, prefixed with "<label>_<outcome index>"
    def body(data, label)
      data.each_with_index.map { |values, outcome| ["#{label}_#{outcome}", *values.map(&:to_s)] }
    end

    # Number of joint outcomes across the given variables (1 when there are none)
    def combinations(variables)
      variables.map(&:cardinality).inject(:*) || 1
    end

    # Header rows, one per evidence variable. Each observation label is repeated once per
    # combination of the variables after it, and the whole pattern is repeated once per
    # combination of the variables before it.
    def header(parents, cycles = 1)
      rows = []
      current, *rest = parents

      while current
        name = current.label
        span = combinations(rest)
        observations = (0...current.cardinality).flat_map { |outcome| Array.new(span, "#{name}_#{outcome}") }
        rows << ([name] + observations * cycles)

        cycles *= current.cardinality
        current, *rest = rest
      end

      rows
    end
  end
end
|
metadata
ADDED
@@ -0,0 +1,91 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: mini_pgm
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.0.1
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Tristan Penman
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2021-06-04 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: rspec
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - "~>"
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '3.10'
|
20
|
+
type: :development
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - "~>"
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '3.10'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: rubocop
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - "~>"
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: 1.16.0
|
34
|
+
type: :development
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - "~>"
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: 1.16.0
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: simplecov
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - "~>"
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: 0.21.2
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - "~>"
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: 0.21.2
|
55
|
+
description: A minimal Probabilistic Graphical Model library for Ruby (very experimental)
|
56
|
+
email: tristan@tristanpenman.com
|
57
|
+
executables: []
|
58
|
+
extensions: []
|
59
|
+
extra_rdoc_files: []
|
60
|
+
files:
|
61
|
+
- lib/mini_pgm.rb
|
62
|
+
- lib/mini_pgm/edge.rb
|
63
|
+
- lib/mini_pgm/model.rb
|
64
|
+
- lib/mini_pgm/node.rb
|
65
|
+
- lib/mini_pgm/printer.rb
|
66
|
+
- lib/mini_pgm/tabular_cpd.rb
|
67
|
+
- lib/mini_pgm/variable.rb
|
68
|
+
homepage: https://github.com/tristanpenman/mini-pgm
|
69
|
+
licenses:
|
70
|
+
- MIT
|
71
|
+
metadata: {}
|
72
|
+
post_install_message:
|
73
|
+
rdoc_options: []
|
74
|
+
require_paths:
|
75
|
+
- lib
|
76
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
77
|
+
requirements:
|
78
|
+
- - ">="
|
79
|
+
- !ruby/object:Gem::Version
|
80
|
+
version: '0'
|
81
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
82
|
+
requirements:
|
83
|
+
- - ">="
|
84
|
+
- !ruby/object:Gem::Version
|
85
|
+
version: '0'
|
86
|
+
requirements: []
|
87
|
+
rubygems_version: 3.0.9
|
88
|
+
signing_key:
|
89
|
+
specification_version: 4
|
90
|
+
summary: Minimal PGM library for Ruby
|
91
|
+
test_files: []
|