compiled-knowledge 4.0.0a20__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37525 -0
- ck/circuit/_circuit_cy.cpython-313-darwin.so +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-313-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
ck/pgm.py
ADDED
|
@@ -0,0 +1,3475 @@
|
|
|
1
|
+
"""
|
|
2
|
+
For more documentation on this module, refer to the Jupyter notebook docs/4_PGM_advanced.ipynb.
|
|
3
|
+
"""
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import math
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from itertools import repeat as _repeat
|
|
10
|
+
from typing import Sequence, Tuple, Dict, Optional, overload, Set, Iterable, List, Union, Callable, \
|
|
11
|
+
Collection, Any, Iterator
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
from ck.utils.iter_extras import (
|
|
16
|
+
combos_ranges as _combos_ranges, multiply as _multiply, combos as _combos
|
|
17
|
+
)
|
|
18
|
+
from ck.utils.np_extras import NDArrayFloat64, NDArrayUInt8
|
|
19
|
+
|
|
20
|
+
# The types that are permitted as random variable states.
State = Union[int, str, bool, float, None]

# An instance (of a sequence of random variables) is a sequence of integers
# that are state indexes, co-indexed with a known sequence of random variables.
Instance = Sequence[int]

# A key identifies an instance.
# A single integer is treated as an instance with one dimension.
Key = Union[Sequence[int], int]

# The shape of a sequence of random variables (e.g., a PGM, Factor or PotentialFunction).
Shape = Sequence[int]

# Default tolerance used when checking CPT sums.
DEFAULT_TOLERANCE: float = 0.000001  # For checking CPT sums.
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class PGM:
|
|
38
|
+
"""
|
|
39
|
+
A probabilistic graphical model (PGM) represents a joint probability distribution over
|
|
40
|
+
a set of random variables. Specifically, a PGM is a factor graph with discrete random variables.
|
|
41
|
+
|
|
42
|
+
Add a random variable to a PGM, pgm, using `rv = pgm.new_rv(...)`.
|
|
43
|
+
|
|
44
|
+
Add a factor to the PGM, pgm, using `factor = pgm.new_factor(...)`.
|
|
45
|
+
|
|
46
|
+
A PGM may be given a human-readable name.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
def __init__(self, name: Optional[str] = None):
|
|
50
|
+
"""
|
|
51
|
+
Create an empty PGM.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
name: an optional name for the PGM. If not provided, a default name will be
|
|
55
|
+
created using `default_pgm_name`.
|
|
56
|
+
"""
|
|
57
|
+
self._name: str = name if name is not None else default_pgm_name(self)
|
|
58
|
+
self._rvs: Tuple[RandomVariable, ...] = ()
|
|
59
|
+
self._shape: Shape = ()
|
|
60
|
+
self._indicators: Tuple[Indicator, ...] = ()
|
|
61
|
+
self._factors: Tuple[Factor, ...] = ()
|
|
62
|
+
|
|
63
|
+
@property
|
|
64
|
+
def name(self) -> str:
|
|
65
|
+
"""
|
|
66
|
+
Returns:
|
|
67
|
+
The name of the PGM.
|
|
68
|
+
"""
|
|
69
|
+
return self._name
|
|
70
|
+
|
|
71
|
+
@property
|
|
72
|
+
def number_of_rvs(self) -> int:
|
|
73
|
+
"""
|
|
74
|
+
Returns:
|
|
75
|
+
How many random variables are defined in this PGM.
|
|
76
|
+
"""
|
|
77
|
+
return len(self._rvs)
|
|
78
|
+
|
|
79
|
+
@property
|
|
80
|
+
def shape(self) -> Shape:
|
|
81
|
+
"""
|
|
82
|
+
Returns:
|
|
83
|
+
a sequence of the lengths of `self.rvs`.
|
|
84
|
+
"""
|
|
85
|
+
return self._shape
|
|
86
|
+
|
|
87
|
+
@property
|
|
88
|
+
def number_of_indicators(self) -> int:
|
|
89
|
+
"""
|
|
90
|
+
Returns:
|
|
91
|
+
How many indicators are defined in this PGM, i.e., `sum(len(rv) for rv in self.rvs)`.
|
|
92
|
+
"""
|
|
93
|
+
return len(self._indicators)
|
|
94
|
+
|
|
95
|
+
@property
def number_of_states(self) -> int:
    """
    The size of the state space of this PGM.

    Returns:
        the result of the module-level `number_of_states` function applied
        to this PGM's random variables, documented as being
        `multiply(len(rv) for rv in self.rvs)`.
    """
    # Inside a method body the bare name resolves to the module-level
    # function `number_of_states`, not to this property.
    all_rvs = self._rvs
    return number_of_states(*all_rvs)
|
|
102
|
+
|
|
103
|
+
@property
|
|
104
|
+
def number_of_factors(self) -> int:
|
|
105
|
+
"""
|
|
106
|
+
Returns:
|
|
107
|
+
How many factors are defined in this PGM.
|
|
108
|
+
"""
|
|
109
|
+
return len(self._factors)
|
|
110
|
+
|
|
111
|
+
@property
|
|
112
|
+
def number_of_functions(self) -> int:
|
|
113
|
+
"""
|
|
114
|
+
Returns:
|
|
115
|
+
How many potential functions are defined in this PGM, including zero potential functions.
|
|
116
|
+
"""
|
|
117
|
+
return sum(1 for _ in self.functions)
|
|
118
|
+
|
|
119
|
+
@property
|
|
120
|
+
def number_of_non_zero_functions(self) -> int:
|
|
121
|
+
"""
|
|
122
|
+
Returns:
|
|
123
|
+
How many potential functions are defined in this PGM, excluding zero potential functions.
|
|
124
|
+
"""
|
|
125
|
+
return sum(1 for _ in self.non_zero_functions)
|
|
126
|
+
|
|
127
|
+
@property
|
|
128
|
+
def rvs(self) -> Sequence[RandomVariable]:
|
|
129
|
+
"""
|
|
130
|
+
Returns:
|
|
131
|
+
All the random variables, in `idx` order, which is the same as creation order.
|
|
132
|
+
|
|
133
|
+
Ensures:
|
|
134
|
+
`self.rvs[rv.idx] = rv`
|
|
135
|
+
"""
|
|
136
|
+
return self._rvs
|
|
137
|
+
|
|
138
|
+
@property
|
|
139
|
+
def rv_log_sizes(self) -> Sequence[float]:
|
|
140
|
+
"""
|
|
141
|
+
Returns:
|
|
142
|
+
[log2(len(rv)) for rv in self.rvs]
|
|
143
|
+
"""
|
|
144
|
+
return [math.log2(len(rv)) for rv in self.rvs]
|
|
145
|
+
|
|
146
|
+
@property
|
|
147
|
+
def indicators(self) -> Sequence[Indicator]:
|
|
148
|
+
"""
|
|
149
|
+
Returns:
|
|
150
|
+
All the random variable indicators.
|
|
151
|
+
|
|
152
|
+
Ensures:
|
|
153
|
+
the indicators of a random variable are adjacent,
|
|
154
|
+
the indicators of a random variable are in state index order,
|
|
155
|
+
the random variables are in the same order as `self.rvs`.
|
|
156
|
+
"""
|
|
157
|
+
return self._indicators
|
|
158
|
+
|
|
159
|
+
@property
|
|
160
|
+
def factors(self) -> Sequence[Factor]:
|
|
161
|
+
"""
|
|
162
|
+
Returns:
|
|
163
|
+
All the factors, in `idx` order, which is the same as creation order.
|
|
164
|
+
|
|
165
|
+
Ensures:
|
|
166
|
+
`self.factors[factor.idx] = factor`
|
|
167
|
+
"""
|
|
168
|
+
return self._factors
|
|
169
|
+
|
|
170
|
+
@property
|
|
171
|
+
def functions(self) -> Iterable[PotentialFunction]:
|
|
172
|
+
"""
|
|
173
|
+
Iterate over all in-use potential functions of this PGM, including
|
|
174
|
+
zero potential functions.
|
|
175
|
+
|
|
176
|
+
Returns:
|
|
177
|
+
An Iterable over all potential functions (including zero potential functions).
|
|
178
|
+
"""
|
|
179
|
+
seen: Set[int] = set()
|
|
180
|
+
for factor in self._factors:
|
|
181
|
+
function = factor.function
|
|
182
|
+
if id(function) not in seen:
|
|
183
|
+
seen.add(id(function))
|
|
184
|
+
yield function
|
|
185
|
+
|
|
186
|
+
@property
|
|
187
|
+
def non_zero_functions(self) -> Iterable[PotentialFunction]:
|
|
188
|
+
"""
|
|
189
|
+
Iterate over all in-use potential functions of this PGM, excluding
|
|
190
|
+
zero potential functions.
|
|
191
|
+
|
|
192
|
+
Returns:
|
|
193
|
+
An Iterable over all potential functions (excluding zero potential functions).
|
|
194
|
+
"""
|
|
195
|
+
seen: Set[int] = set()
|
|
196
|
+
for factor in self._factors:
|
|
197
|
+
function = factor.function
|
|
198
|
+
if not (isinstance(function, ZeroPotentialFunction) or id(function) in seen):
|
|
199
|
+
seen.add(id(function))
|
|
200
|
+
yield function
|
|
201
|
+
|
|
202
|
+
def new_rv(self, name: str, states: Union[int, Sequence[State]]) -> RandomVariable:
    """
    Create a random variable and add it to this PGM.

    The returned random variable is documented as having an `idx` equal to
    the value of `self.number_of_rvs` just prior to the call.

    Assumes:
        Provided states contain no duplicates.

    Args:
        name: a name for the random variable.
        states: either an integer number of states or a sequence of state values.
            If a single integer, `n`, is provided then the states will be 0, 1, ..., n-1.
            If a sequence of states is provided then the states must be unique.

    Returns:
        a RandomVariable object belonging to this PGM.
    """
    # Construction is delegated to RandomVariable, which is given this PGM as its owner.
    rv = RandomVariable(self, name, states)
    return rv
|
|
222
|
+
|
|
223
|
+
def new_factor(self, *rvs: RandomVariable) -> Factor:
    """
    Create a factor connecting the given random variables and add it to this PGM.

    The returned factor is documented as starting with a ZeroPotentialFunction
    as its potential function, which may be changed by calling methods on
    the returned factor.

    The returned factor is documented as having an `idx` equal to the value
    of `self.number_of_factors` just prior to the call.

    Assumes:
        The given random variables all belong to this PGM.
        The random variables contain no duplicates.

    Args:
        *rvs: the random variables.

    Returns:
        a Factor object belonging to this PGM.
    """
    # Construction is delegated to Factor, which is given this PGM as its owner.
    factor = Factor(self, *rvs)
    return factor
|
|
245
|
+
|
|
246
|
+
def new_factor_implies(
        self,
        rv_1: RandomVariable,
        state_idxs_1: int | Collection[int],
        rv_2: RandomVariable,
        state_idxs_2: int | Collection[int],
) -> Factor:
    """
    Add a sparse 0/1 factor to this PGM representing the implication:
        rv_1 in state_idxs_1 ==> rv_2 in state_idxs_2.
    That is:
        factor[s1, s2] = 1, if s1 not in state_idxs_1 or s2 in state_idxs_2;
                       = 0, otherwise.

    Args:
        rv_1: The first random variable.
        state_idxs_1: state idxs of the first random variable. A single int
            is treated as a collection holding just that state index.
        rv_2: The second random variable.
        state_idxs_2: state idxs of the second random variable. A single int
            is treated as a collection holding just that state index.

    Returns:
        a Factor object belonging to this PGM, with a configured sparse potential function.
    """
    # Normalise single state indexes to one-element collections.
    if isinstance(state_idxs_1, int):
        state_idxs_1 = (state_idxs_1,)
    if isinstance(state_idxs_2, int):
        state_idxs_2 = (state_idxs_2,)

    factor = self.new_factor(rv_1, rv_2)
    table = factor.set_sparse()
    for s1 in rv_1.state_range():
        antecedent_holds = s1 in state_idxs_1
        for s2 in rv_2.state_range():
            # An implication is satisfied when the antecedent is false,
            # or when the consequent is true. Only satisfied entries are set.
            if not antecedent_holds or s2 in state_idxs_2:
                table[s1, s2] = 1
    return factor
|
|
285
|
+
|
|
286
|
+
def new_factor_equiv(
        self,
        rv_1: RandomVariable,
        state_idxs_1: int | Collection[int],
        rv_2: RandomVariable,
        state_idxs_2: int | Collection[int],
) -> Factor:
    """
    Add a sparse 0/1 factor to this PGM representing the equivalence:
        rv_1 in state_idxs_1 <==> rv_2 in state_idxs_2.
    That is:
        factor[s1, s2] = 1, if (s1 in state_idxs_1) == (s2 in state_idxs_2);
                       = 0, otherwise.

    Args:
        rv_1: The first random variable.
        state_idxs_1: state idxs of the first random variable. A single int
            is treated as a collection holding just that state index.
        rv_2: The second random variable.
        state_idxs_2: state idxs of the second random variable. A single int
            is treated as a collection holding just that state index.

    Returns:
        a Factor object belonging to this PGM, with a configured sparse potential function.
    """
    # Normalise single state indexes to one-element collections.
    if isinstance(state_idxs_1, int):
        state_idxs_1 = (state_idxs_1,)
    if isinstance(state_idxs_2, int):
        state_idxs_2 = (state_idxs_2,)

    factor = self.new_factor(rv_1, rv_2)
    table = factor.set_sparse()
    for s1 in rv_1.state_range():
        member_1 = s1 in state_idxs_1
        for s2 in rv_2.state_range():
            member_2 = s2 in state_idxs_2
            if member_1 != member_2:
                # Exactly one side holds: the equivalence is violated, leave the entry unset.
                continue
            table[s1, s2] = 1
    return factor
|
|
323
|
+
|
|
324
|
+
def new_factor_functional(
        self,
        function: Callable[..., int],
        result_rv: RandomVariable,
        *input_rvs: RandomVariable
) -> Factor:
    """
    Add a sparse 0/1 factor to this PGM representing:
        result_rv == function(*input_rvs).
    That is:
        factor[result_s, *input_s] = 1, if result_s == function(*input_s);
                                   = 0, otherwise.

    The factor's random variables are `result_rv` followed by `input_rvs`,
    in the order given.

    Args:
        function: a function from state indexes of the input random variables to a state index
            of the result random variable. The function should take the same number of arguments
            as `input_rvs` and return a state index for `result_rv`. It is called once for
            each combination of input state indexes.
        result_rv: the random variable defining result values.
        *input_rvs: the random variables defining input values.

    Returns:
        a Factor object belonging to this PGM, with a configured sparse potential function.
    """
    factor = self.new_factor(result_rv, *input_rvs)
    f = factor.set_sparse()
    for input_s in _combos([list(rv.state_range()) for rv in input_rvs]):
        result_s = function(*input_s)
        # Build the key by unpacking, so any sequence of input state
        # indexes is accepted (not only a tuple).
        f[(result_s, *input_s)] = 1
    return factor
|
|
352
|
+
|
|
353
|
+
def indicator_pair(self, indicator: Indicator) -> Tuple[RandomVariable, State]:
|
|
354
|
+
"""
|
|
355
|
+
Convert the given indicator to its RandomVariable and State value.
|
|
356
|
+
|
|
357
|
+
Args:
|
|
358
|
+
indicator: the indicator to convert.
|
|
359
|
+
|
|
360
|
+
Returns:
|
|
361
|
+
(rv, state) where
|
|
362
|
+
rv: is the random variable of the indicator.
|
|
363
|
+
state: is the random variable state of the indicator.
|
|
364
|
+
"""
|
|
365
|
+
rv = self._rvs[indicator.rv_idx]
|
|
366
|
+
state = rv.states[indicator.state_idx]
|
|
367
|
+
return rv, state
|
|
368
|
+
|
|
369
|
+
def indicator_str(self, *indicators: Indicator, sep: str = '=', delim: str = ', ') -> str:
|
|
370
|
+
"""
|
|
371
|
+
Render indicators as a string.
|
|
372
|
+
|
|
373
|
+
For example:
|
|
374
|
+
pgm = PGM()
|
|
375
|
+
a = pgm.new_rv('A', ('x', 'y', 'z'))
|
|
376
|
+
b = pgm.new_rv('B', (3, 5))
|
|
377
|
+
print(pgm.indicator_str(a[0], b[1], a[2]))
|
|
378
|
+
will print:
|
|
379
|
+
A=x, B=5, A=z
|
|
380
|
+
|
|
381
|
+
Args:
|
|
382
|
+
*indicators: the indicators to render.
|
|
383
|
+
sep: the separator to use between the random variable and its state.
|
|
384
|
+
delim: the delimiter to used when rendering multiple indicators.
|
|
385
|
+
|
|
386
|
+
Returns:
|
|
387
|
+
a string representation of the given indicators.
|
|
388
|
+
"""
|
|
389
|
+
return delim.join(
|
|
390
|
+
f'{_clean_str(rv)}{sep}{_clean_str(state)}'
|
|
391
|
+
for rv, state in (
|
|
392
|
+
self.indicator_pair(indicator)
|
|
393
|
+
for indicator in indicators
|
|
394
|
+
)
|
|
395
|
+
)
|
|
396
|
+
|
|
397
|
+
def condition_str(self, *indicators: Indicator) -> str:
|
|
398
|
+
"""
|
|
399
|
+
Render indicators as a string, grouping indicators by random variable.
|
|
400
|
+
|
|
401
|
+
For example:
|
|
402
|
+
pgm = PGM()
|
|
403
|
+
a = pgm.new_rv('A', ('x', 'y', 'z'))
|
|
404
|
+
b = pgm.new_rv('B', (3, 5))
|
|
405
|
+
print(pgm.condition_str(a[0], b[1], a[2]))
|
|
406
|
+
will print:
|
|
407
|
+
A in {x, z}, B=5
|
|
408
|
+
|
|
409
|
+
Args:
|
|
410
|
+
*indicators: the indicators to render.
|
|
411
|
+
Return:
|
|
412
|
+
a string representation of the given indicators, as a condition.
|
|
413
|
+
"""
|
|
414
|
+
indicators: List[Indicator] = sorted(indicators, reverse=True)
|
|
415
|
+
cur_rv: Set[Indicator] = set()
|
|
416
|
+
cur_idx: int = -1 # rv_idx of the rv we are currently working on, -1 means not yet started.
|
|
417
|
+
cur_str: str = '' # accumulated result string
|
|
418
|
+
while len(indicators) > 0:
|
|
419
|
+
this_ind = indicators.pop()
|
|
420
|
+
if this_ind.rv_idx != cur_idx:
|
|
421
|
+
if cur_idx >= 0:
|
|
422
|
+
cur_str = self._condition_str_rv(cur_str, cur_rv)
|
|
423
|
+
cur_rv = set()
|
|
424
|
+
cur_idx = this_ind.rv_idx
|
|
425
|
+
cur_rv.add(this_ind)
|
|
426
|
+
if cur_idx >= 0:
|
|
427
|
+
cur_str = self._condition_str_rv(cur_str, cur_rv)
|
|
428
|
+
return cur_str
|
|
429
|
+
|
|
430
|
+
def instance_str(
|
|
431
|
+
self,
|
|
432
|
+
instance: Instance,
|
|
433
|
+
rvs: Optional[Sequence[RandomVariable]] = None,
|
|
434
|
+
sep: str = '=',
|
|
435
|
+
delim: str = ', ',
|
|
436
|
+
) -> str:
|
|
437
|
+
"""
|
|
438
|
+
Render an instance as a string.
|
|
439
|
+
|
|
440
|
+
The result looks something like 'X=x, Y=y, Z=z' where X, Y, and X are
|
|
441
|
+
random variables and x, y, and z are the states represented by the
|
|
442
|
+
given instance.
|
|
443
|
+
|
|
444
|
+
Args:
|
|
445
|
+
instance: the instance to render.
|
|
446
|
+
rvs: the random variables that the instance refers to. If rvs is None, then `self.rvs` is used.
|
|
447
|
+
sep: the separator to use between the random variable and its state.
|
|
448
|
+
delim: the delimiter to used when rendering multiple indicators.
|
|
449
|
+
|
|
450
|
+
Returns:
|
|
451
|
+
a string representation of the indicators implied by the given instance.
|
|
452
|
+
"""
|
|
453
|
+
if rvs is None:
|
|
454
|
+
rvs = self.rvs
|
|
455
|
+
assert len(instance) == len(rvs)
|
|
456
|
+
return self.indicator_str(
|
|
457
|
+
*[rv[state] for rv, state in zip(rvs, instance)],
|
|
458
|
+
sep=sep,
|
|
459
|
+
delim=delim
|
|
460
|
+
)
|
|
461
|
+
|
|
462
|
+
def state_str(
|
|
463
|
+
self,
|
|
464
|
+
instance: Instance,
|
|
465
|
+
rvs: Optional[Sequence[RandomVariable]] = None,
|
|
466
|
+
delim: str = ', ',
|
|
467
|
+
) -> str:
|
|
468
|
+
"""
|
|
469
|
+
Render the states of an instance.
|
|
470
|
+
|
|
471
|
+
The result looks something like 'x, y, z' where x, y, and z are
|
|
472
|
+
the states of the random variables represented by the given instance.
|
|
473
|
+
|
|
474
|
+
Args:
|
|
475
|
+
instance: the instance to render.
|
|
476
|
+
rvs: the random variables that the instance refers to. If rvs is None, then `self.rvs` is used.
|
|
477
|
+
delim: the delimiter to used when rendering multiple indicators.
|
|
478
|
+
|
|
479
|
+
Returns:
|
|
480
|
+
a string representation of the states implied by the given instance.
|
|
481
|
+
"""
|
|
482
|
+
if rvs is None:
|
|
483
|
+
rvs = self.rvs
|
|
484
|
+
assert len(instance) == len(rvs)
|
|
485
|
+
return delim.join(str(rv.states[i]) for rv, i in zip(rvs, instance))
|
|
486
|
+
|
|
487
|
+
def instances(self, flip: bool = False) -> Iterable[Instance]:
    """
    Iterate over all possible instances of this PGM, in natural index
    order (i.e., last random variable changing most quickly).

    Args:
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds random variable state indexes
        co-indexed with this PGM's random variables, `self.rvs`.
    """
    rv_sizes = tuple(map(len, self._rvs))
    # The helper is passed the negation of this method's `flip` argument.
    helper_flip = not flip
    return _combos_ranges(rv_sizes, flip=helper_flip)
|
|
500
|
+
|
|
501
|
+
def instances_as_indicators(self, flip: bool = False) -> Iterable[Sequence[Indicator]]:
    """
    Lazily iterate over all possible instances of this PGM, each expressed
    as indicators, in natural index order (i.e., last random variable
    changing most quickly).

    Args:
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds random variable indicators
        co-indexed with this PGM's random variables, `self.rvs`.
    """
    # Same order as `self.instances`, with each instance converted to indicators.
    yield from map(self.state_idxs_to_indicators, self.instances(flip=flip))
|
|
515
|
+
|
|
516
|
+
def state_idxs_to_indicators(self, instance: Sequence[int]) -> Sequence[Indicator]:
|
|
517
|
+
"""
|
|
518
|
+
Given an instance (list of random variable state indexes), co-indexed with the PGM's
|
|
519
|
+
random variables, `self.rvs`, return the corresponding indicators.
|
|
520
|
+
|
|
521
|
+
Assumes:
|
|
522
|
+
The instance has the same length as `self.rvs`.
|
|
523
|
+
The instance is co-indexed with `self.rvs`.
|
|
524
|
+
|
|
525
|
+
Args:
|
|
526
|
+
instance: the instance to convert to indicators.
|
|
527
|
+
|
|
528
|
+
Returns:
|
|
529
|
+
a tuple of indicators, co-indexed with `self.rvs`.
|
|
530
|
+
"""
|
|
531
|
+
return tuple(rv[state] for rv, state in zip(self._rvs, instance))
|
|
532
|
+
|
|
533
|
+
def factor_values(self, key: Key) -> Iterable[float]:
    """
    For a given instance key, each factor defines a single value. This method
    lazily produces those values, one per factor.

    The key is checked (via `check_key` against this PGM's shape) only once
    iteration of the result starts, because this method is a generator.

    Args:
        key: the key defining an instance of this PGM.

    Returns:
        an iterator over factor values, co-indexed with the factors of this PGM.
    """
    instance: Instance = check_key(self._shape, key)
    assert len(instance) == len(self._rvs)
    for factor in self._factors:
        # Project the full instance onto the random variables of this factor.
        factor_states: Sequence[int] = tuple(instance[rv.idx] for rv in factor.rvs)
        yield factor.function[factor_states]
|
|
550
|
+
|
|
551
|
+
@property
|
|
552
|
+
def is_structure_bayesian(self) -> bool:
|
|
553
|
+
"""
|
|
554
|
+
Does the PGM structure correspond to a Bayesian network, where
|
|
555
|
+
each factor is taken to be a CPT and the first random variable of factor
|
|
556
|
+
is taken to be the child.
|
|
557
|
+
|
|
558
|
+
This method does not check the factor parameters to confirm they correspond
|
|
559
|
+
to valid CPTs.
|
|
560
|
+
|
|
561
|
+
Return:
|
|
562
|
+
True only if:
|
|
563
|
+
the number of factors equals the number of random variables,
|
|
564
|
+
each random variable appears exactly once as the first random variable of a factor,
|
|
565
|
+
there are no directed loops created by the factors.
|
|
566
|
+
"""
|
|
567
|
+
|
|
568
|
+
# One factor per random variable.
|
|
569
|
+
if self.number_of_factors != self.number_of_rvs:
|
|
570
|
+
return False
|
|
571
|
+
|
|
572
|
+
# Each random variable is a child.
|
|
573
|
+
# Map each random variable to the factor it is a child of
|
|
574
|
+
child_to_factor: Dict[int, Factor] = {
|
|
575
|
+
factor.rvs[0].idx: factor
|
|
576
|
+
for factor in self._factors
|
|
577
|
+
}
|
|
578
|
+
if len(child_to_factor) != self.number_of_rvs:
|
|
579
|
+
return False
|
|
580
|
+
|
|
581
|
+
# Factors form a DAG
|
|
582
|
+
states: NDArrayUInt8 = np.zeros(self.number_of_factors, dtype=np.uint8)
|
|
583
|
+
for factor in self._factors:
|
|
584
|
+
if self._has_cycle(factor, child_to_factor, states):
|
|
585
|
+
return False
|
|
586
|
+
|
|
587
|
+
# All tests passed
|
|
588
|
+
return True
|
|
589
|
+
|
|
590
|
+
def factors_are_cpts(self, tolerance: float = DEFAULT_TOLERANCE) -> bool:
    """
    Are all factor potential functions set with parameter values
    conforming to Conditional Probability Tables.

    Each distinct in-use potential function (see `self.functions`) is asked
    `is_cpt(tolerance)`. Checking stops at the first one that is not a CPT.

    Assumes:
        tolerance is non-negative.

    Args:
        tolerance: a tolerance when testing if values are equal to zero or one.

    Returns:
        True only if every potential function conforms to being a valid CPT.
    """
    for function in self.functions:
        if not function.is_cpt(tolerance):
            return False
    return True
|
|
605
|
+
|
|
606
|
+
def check_is_bayesian_network(self, tolerance: float = DEFAULT_TOLERANCE) -> bool:
    """
    Decide whether this PGM is a Bayesian network.

    The structural test is consulted first; the factor parameters are only
    examined when the structure is acceptable.

    Assumes:
        tolerance is non-negative.

    Args:
        tolerance: a tolerance when testing if values are equal to zero or one.

    Returns:
        `is_structure_bayesian and factors_are_cpts(tolerance)`.
    """
    structure_ok = self.is_structure_bayesian
    if not structure_ok:
        return structure_ok
    return self.factors_are_cpts(tolerance)
|
|
620
|
+
|
|
621
|
+
def value_product(self, key: Key) -> float:
    """
    Multiply together the values that the factors take for one instance.

    `self.factor_values(key)` supplies one value per factor for the given
    instance key; the result is the product of those values.

    Args:
        key: the key defining an instance of this PGM.

    Returns:
        the product of factor values.
    """
    values = self.factor_values(key)
    return _multiply(values)
|
|
633
|
+
|
|
634
|
+
def value_product_indicators(self, *indicators: Indicator) -> float:
    """
    Return the product of factors, conditioned on the given indicators.

    Random variables that are not mentioned by any indicator are
    marginalised: the result is the sum of `value_product` over every
    combination of their states. When several indicators refer to the same
    random variable, all matching instances are summed. With no indicators
    at all, the value of the partition function (z) is returned.

    This method has the same semantics as `ProbabilitySpace.wmc` without conditioning.

    Warning:
        this is potentially computationally expensive as it marginalises
        random variables not mentioned in the given indicators.

    Args:
        *indicators: are indicators from random variables of this PGM.

    Returns:
        the product of factors, conditioned on the given instance. This is the
        computed value of the PGM, conditioned on the given instance.
    """
    # Group the requested state indexes by random variable.
    # An entry left as None marks a random variable with no indicator.
    selected: List[Optional[Set[int]]] = [None for _ in range(self.number_of_rvs)]
    for indicator in indicators:
        rv_idx: int = indicator.rv_idx
        states_of_rv = selected[rv_idx]
        if states_of_rv is None:
            states_of_rv = set()
            selected[rv_idx] = states_of_rv
        states_of_rv.add(indicator.state_idx)

    # For each random variable, the list of state indexes to sum over:
    # the selected states, or every state when nothing was selected.
    sum_space: List[List[int]] = []
    for states_of_rv, rv in zip(selected, self.rvs):
        if states_of_rv is None:
            sum_space.append(list(rv.state_range()))
        else:
            sum_space.append(list(states_of_rv))

    # Accumulate the value product over every instance in the sum space.
    total = 0
    for instance in _combos(sum_space):
        total += self.value_product(instance)
    return total
|
|
705
|
+
|
|
706
|
+
def dump_synopsis(
        self,
        *,
        prefix: str = '',
        precision: int = 3,
        max_state_digits: int = 21,
) -> None:
    """
    Print a synopsis of the PGM.
    This is intended for demonstration and debugging purposes.

    The number of states is shown as a plain integer unless it has more than
    `max_state_digits` digits, in which case it is shown as a mantissa and
    exponent. A PGM with no states at all (for example, one containing a
    random variable with zero states) is reported as `0` states with a
    log 2 of `-inf`, rather than failing on the undefined logarithm.

    Args:
        prefix: optional prefix for indenting all lines.
        precision: a limit on the render precision of floating point numbers.
        max_state_digits: a limit on the number of digits when showing number of states as an integer.
    """
    num_states: int = self.number_of_states
    number_of_parameters = sum(function.number_of_parameters for function in self.functions)
    number_of_nz_parameters = sum(function.number_of_parameters for function in self.non_zero_functions)

    if num_states <= 0:
        # Logarithms are undefined here; math.log10 / math.log2 would raise ValueError.
        num_states_str = f'{num_states:,}'
        log_2_num_states_str = '-inf'
    else:
        # Limit the precision when displaying a very large number of states.
        log_10_num_states = math.log10(num_states)
        if log_10_num_states > max_state_digits:
            exp = int(log_10_num_states)
            man = math.pow(10, log_10_num_states - exp)
            num_states_str = f'{man:,.{precision}f}e+{exp}'
        else:
            num_states_str = f'{num_states:,}'

        # Show log 2 as an integer when it is exactly integral (and not too long).
        log_2_num_states = math.log2(num_states)
        if (
                log_2_num_states == 0
                or (
                    log_2_num_states == int(log_2_num_states)
                    and math.log10(log_2_num_states) <= max_state_digits
                )
        ):
            log_2_num_states_str = f'{int(log_2_num_states):,}'
        else:
            log_2_num_states_str = f'{log_2_num_states:,.{precision}f}'

    print(f'{prefix}name: {self.name}')
    print(f'{prefix}number of random variables: {self.number_of_rvs:,}')
    print(f'{prefix}number of indicators: {self.number_of_indicators:,}')
    print(f'{prefix}number of states: {num_states_str}')
    print(f'{prefix}log 2 of states: {log_2_num_states_str}')
    print(f'{prefix}number of factors: {self.number_of_factors:,}')
    print(f'{prefix}number of functions: {self.number_of_functions:,}')
    print(f'{prefix}number of non-zero functions: {self.number_of_non_zero_functions:,}')
    print(f'{prefix}number of parameters: {number_of_parameters:,}')
    print(f'{prefix}number of functions (excluding ZeroPotentialFunction): {self.number_of_non_zero_functions:,}')
    print(f'{prefix}number of parameters (excluding ZeroPotentialFunction): {number_of_nz_parameters:,}')
    print(f'{prefix}Bayesian structure: {self.is_structure_bayesian}')
    print(f'{prefix}CPT factors: {self.factors_are_cpts()}')
|
|
760
|
+
|
|
761
|
+
def dump(
        self,
        *,
        prefix: str = '',
        indent: str = ' ',
        show_function_values: bool = False,
        precision: int = 3,
        max_state_digits: int = 21,
) -> None:
    """
    Print a dump of the PGM.
    This is intended for demonstration and debugging purposes.

    The dump consists of a header line, the synopsis (see `dump_synopsis`),
    then one section each for the random variables, the factors and the
    non-zero potential functions (ordered by object id), and a footer line.

    Args:
        prefix: optional prefix for indenting all lines.
        indent: additional prefix to use for extra indentation.
        show_function_values: if true, then the function values will be dumped.
        precision: a limit on the render precision of floating point numbers.
        max_state_digits: a limit on the number of digits when showing number of states as an integer.
    """
    inner: str = prefix + indent
    innermost: str = inner + indent
    pgm_id = id(self)

    print(f'{prefix}PGM id={pgm_id} name={self.name!r}')
    self.dump_synopsis(prefix=inner, precision=precision, max_state_digits=max_state_digits)

    print(f'{prefix}random variables ({self.number_of_rvs})')
    for rv in self.rvs:
        line = f'{inner}{rv.idx:>3} {rv.name!r} ({len(rv)})'
        if not rv.is_default_states():
            # Only list the states when they are not simply 0, 1, 2, ...
            line += ' [' + ', '.join(map(repr, rv.states)) + ']'
        print(line)

    print(f'{prefix}factors ({self.number_of_factors})')
    for factor in self.factors:
        rv_idxs = [rv.idx for rv in factor.rvs]
        if factor.is_zero:
            function_ref = '<zero>'
        else:
            factor_function = factor.function
            function_ref = f'{id(factor_function)}: {factor_function.__class__.__name__}'
        print(f'{inner}{factor.idx:>3} rvs={rv_idxs} function={function_ref}')

    print(f'{prefix}functions ({self.number_of_functions})')
    for function in sorted(self.non_zero_functions, key=id):
        print(f'{inner}{id(function):>13}: {function.__class__.__name__}')
        function.dump(prefix=innermost, show_function_values=show_function_values, show_id_class=False)

    print(f'{prefix}end PGM id={pgm_id}')
|
|
814
|
+
|
|
815
|
+
def _has_cycle(self, factor: Factor, child_to_factor: Dict[int, Factor], states: NDArrayUInt8) -> bool:
    """
    Support function for `is_structure_bayesian`.

    Performs a recursive depth-first search from `factor`, following each
    parent random variable to the factor that has it as a child, looking for
    a directed cycle.

    The search state of a factor `f` is kept in `states[f.idx]`:
        0 => the factor has not been seen yet,
        1 => the factor is seen but not fully processed (it is on the current search path),
        2 => the factor is fully processed.

    Args:
        factor: the current Factor being checked.
        child_to_factor: a dictionary from `RandomVariable.idx` to Factor
            with that random variable as the child.
        states: depth-first-search states, i.e., `states[i]` is the state of a factor with `Factor.idx == i`.

    Returns:
        True if a directed cycle is detected.
    """
    f_idx: int = factor.idx
    search_state = states.item(f_idx)

    if search_state == 1:
        # Reached a factor that is still on the current search path.
        return True
    if search_state != 0:
        # Already fully processed, nothing new to discover from here.
        return False

    states[f_idx] = 1
    if any(
            self._has_cycle(child_to_factor[parent.idx], child_to_factor, states)
            for parent in factor.rvs[1:]
    ):
        return True
    states[f_idx] = 2
    return False
|
|
848
|
+
|
|
849
|
+
def _register_rv(self, rv: RandomVariable) -> None:
    """
    Record a newly created random variable of this PGM.

    Called by the constructor of RandomVariable. The random variable, its
    number of states and its indicators are appended to this PGM's
    `_rvs`, `_shape` and `_indicators` respectively.

    Args:
        rv: the newly constructed random variable.
    """
    assert rv.pgm is self
    rv_entry = (rv,)
    self._rvs += rv_entry
    self._shape += (len(rv),)
    self._indicators += rv.indicators
|
|
861
|
+
|
|
862
|
+
def _condition_str_rv(
        self,
        cur_str: str,
        cur_rv: Set[Indicator],
        sep: str = ', ',
        equal: str = '=',
        elem: str = ' in ',
) -> str:
    """
    Support method for `self.condition_str`.

    Renders one condition, defined by a set of indicators that all belong to
    the same random variable, and appends it to `cur_str`. A single indicator
    is rendered with `self.indicator_str`; several indicators are rendered
    as the random variable followed by the set of its (sorted) states.

    Args:
        cur_str: the string to append to.
        cur_rv: a set of indicators, all from the same random variable.
        sep: the separator string to use between condition components.
        equal: the string to use for _rv_ = _state_.
        elem: the string to use for _rv_ in _set_.

    Returns:
        `cur_str` appended with the new condition, `cur_rv`.
    """
    # A separator is only needed when something has already been rendered.
    lead: str = cur_str + sep if cur_str != '' else cur_str

    if len(cur_rv) == 1:
        return lead + self.indicator_str(*cur_rv, sep=equal)

    ordered = sorted(cur_rv)
    rv = self._rvs[ordered[0].rv_idx]
    states_str: str = sep.join(_clean_str(rv.states[ind.state_idx]) for ind in ordered)
    return lead + f'{_clean_str(rv)}{elem}{{{states_str}}}'
|
|
895
|
+
|
|
896
|
+
|
|
897
|
+
@dataclass(frozen=True, eq=True, slots=True)
|
|
898
|
+
class Indicator:
|
|
899
|
+
"""
|
|
900
|
+
An indicator identifies a random variable being in a particular state.
|
|
901
|
+
|
|
902
|
+
Indicators are immutable and hashable.
|
|
903
|
+
|
|
904
|
+
Note that an Indicator does not know which PGM it came from, therefore indicators from one PGM
|
|
905
|
+
are interchangeable with indicators of another PGM so long as corresponding random variables of the
|
|
906
|
+
PGMs are co-indexed (created in the same order) and corresponding random variables have the same
|
|
907
|
+
states.
|
|
908
|
+
|
|
909
|
+
Fields:
|
|
910
|
+
rv_idx: `rv.idx` where `rv` is the random variable referenced by this indicator.
|
|
911
|
+
state_idx: the state index of the state referenced by this indicator.
|
|
912
|
+
"""
|
|
913
|
+
rv_idx: int
|
|
914
|
+
state_idx: int
|
|
915
|
+
|
|
916
|
+
def __lt__(self, other) -> bool:
|
|
917
|
+
"""
|
|
918
|
+
Define a sort order over indicators.
|
|
919
|
+
When sorted, indicators are ordered by random variable index, then by state index.
|
|
920
|
+
"""
|
|
921
|
+
if isinstance(other, Indicator):
|
|
922
|
+
if self.rv_idx < other.rv_idx:
|
|
923
|
+
return True
|
|
924
|
+
if self.rv_idx > other.rv_idx:
|
|
925
|
+
return False
|
|
926
|
+
return self.state_idx < other.state_idx
|
|
927
|
+
return False
|
|
928
|
+
|
|
929
|
+
|
|
930
|
+
class RandomVariable(Sequence[Indicator]):
    """
    A discrete random variable of a probabilistic graphical model (PGM).

    A random variable is immutable and hashable. It has a fixed, finite
    number of states, indexed by integers counting from zero, and belongs to
    exactly one PGM. Its index (counting from zero) is its position in the
    list of random variables of that PGM.

    A random variable also behaves like a sequence of Indicators, each
    indicator representing the random variable being in a particular state:
    `len(rv)` is the number of states and `rv[i]` is the Indicator for `rv`
    being in the ith state. Slicing gives a tuple, i.e. rv[1:3] = (rv[1], rv[2]).

    The name of a random variable is for human convenience and has no
    functional purpose within a PGM.
    """

    def __init__(self, pgm: PGM, name: str, states: Union[int, Sequence[State]]):
        """
        Create a new random variable, in the given PGM.

        Args:
            pgm: the PGM that the random variable will belong to.
            name: a name for the random variable.
            states: either an integer number of states or a sequence of state values. If a
                single integer, `n`, is provided then the states will be 0, 1, ..., n-1.
                If a sequence of states are provided then the states must be unique.

        Raises:
            ValueError: if the provided states contain duplicates.
        """
        self._pgm: PGM = pgm
        self._name: str = name

        state_values = range(states) if isinstance(states, int) else states
        self._states: Sequence[State] = tuple(state_values)

        # Inverse lookup, from state value to state index.
        self._inv_states: Dict[State, int] = {}
        for position, state in enumerate(self._states):
            self._inv_states[state] = position

        if len(self._inv_states) != len(self._states):
            raise ValueError('random variable states are not unique')

        # Position of this random variable within the PGM, taken before registration.
        self._offset: int = pgm.number_of_indicators
        self._idx: int = pgm.number_of_rvs
        self._indicators: Sequence[Indicator] = tuple(
            Indicator(self._idx, state_idx) for state_idx in range(len(self._states))
        )

        # Register self with our PGM
        # noinspection PyProtectedMember
        pgm._register_rv(self)

    @property
    def pgm(self) -> PGM:
        """
        Returns:
            The PGM that this random variable belongs to.
        """
        return self._pgm

    @property
    def name(self) -> str:
        """
        Returns:
            The name of this random variable.
        """
        return self._name

    @property
    def idx(self) -> int:
        """
        Returns:
            The index of this random variable into the PGM.

        Ensures:
            `self.pgm.rvs[self.idx] is self`.
        """
        return self._idx

    @property
    def offset(self) -> int:
        """
        Returns:
            The index into the PGM's indicators for the start of this random variable's indicators.

        Ensures:
            `self.pgm.indicators[self.offset + i] is self[i] for i in range(len(self))`.
        """
        return self._offset

    @property
    def states(self) -> Sequence[State]:
        """
        Returns:
            the states of this random variable, in state index order.
        """
        return self._states

    @property
    def indicators(self) -> Sequence[Indicator]:
        """
        Returns:
            the indicators of this random variable, in state index order.
        """
        return self._indicators

    def state_range(self) -> Iterable[int]:
        """
        Iterate over the state indexes of this random variable, in order.

        Returns:
            range(len(self))
        """
        number_of_states = len(self._states)
        return range(number_of_states)

    def factors(self) -> Iterable[Factor]:
        """
        Iterate over factors that this random variable participates in.
        This method performs a search through all `self.pgm.factors`.

        Returns:
            an iterator over factors.
        """
        yield from (factor for factor in self._pgm.factors if self in factor.rvs)

    def markov_blanket(self) -> Set[RandomVariable]:
        """
        Return the set of random variables that are connected
        to this random variable by a factor.
        This method performs a search through all `self.pgm.factors`.

        Returns:
            a set of random variables connected to this random variable by any factor, excluding self.
        """
        connected = set()
        connected.update(*(factor.rvs for factor in self.factors()))
        connected.discard(self)
        return connected

    def state_idx(self, state: State) -> int:
        """
        Returns:
            the state index of the given state of this random variable.

        Assumes:
            the given state is a state of this random variable.
        """
        return self._inv_states[state]

    def is_default_states(self) -> bool:
        """
        Are the states of this random variable the default states.
        I.e., `self.states[i] == i, for all 0 <= i < len(self)`.

        Returns:
            True only if the states are the same as the state indexes.
        """
        for position, state in enumerate(self._states):
            if not (position == state):
                return False
        return True

    def __str__(self) -> str:
        """
        Returns:
            the name of this random variable.
        """
        return self._name

    def __call__(self, state: State) -> Indicator:
        """
        Get the indicator for the given state.
        This is equivalent to self[self.state_idx(state)].

        Returns:
            an indicator of this random variable.

        Assumes:
            the given state is a state of this random variable.
        """
        position = self._inv_states[state]
        return self._indicators[position]

    def __hash__(self) -> int:
        """
        A random variable is hashable; its hash is its index in the PGM.
        """
        return self._idx

    def __eq__(self, other) -> bool:
        """
        Two random variables are equal only if they are the same object.
        """
        return self is other

    def equivalent(self, other: RandomVariable | Sequence[Indicator]) -> bool:
        """
        Two random variables are equivalent if their indicators are equal. Only
        random variable indexes and state indexes are checked.

        The names of the random variables and of their states are ignored.
        This means their indicators will work correctly in slot maps, even
        if from different PGMs.

        Args:
            other: either a random variable or a sequence of Indicators.

        Returns:
            True only if they represent the same sequence of indicators.
        """
        if isinstance(other, RandomVariable):
            # Same index and same number of states implies equal indicators.
            return self.idx == other.idx and len(self) == len(other)

        mine = self._indicators
        if len(mine) != len(other):
            return False
        return all(indicator == other[i] for i, indicator in enumerate(mine))

    def __len__(self) -> int:
        """
        Returns:
            Number of states (or equivalently, the number of indicators) of this random variable.
        """
        return len(self._states)

    def __iter__(self) -> Iterator[Indicator]:
        """
        Iterate over the indicators of this random variable.
        """
        return iter(self._indicators)

    @overload
    def __getitem__(self, index: int) -> Indicator:
        ...

    @overload
    def __getitem__(self, index: slice) -> Sequence[Indicator]:
        ...

    def __getitem__(self, index):
        """
        Get the indexed (or sliced) indicators.
        """
        return self._indicators[index]

    def index(self, value: Any, start: int = 0, stop: int = -1) -> int:
        """
        Returns the first index of `value`.
        Raises ValueError if the value is not present.
        Contracted by Sequence[Indicator].

        A negative `stop` is taken relative to `len(self) + 1`, so the
        default of -1 means the end of the sequence.

        Warning:
            This method is different to `self.idx`.
        """
        if isinstance(value, Indicator) and value.rv_idx == self._idx:
            size = len(self)
            upper = stop if stop >= 0 else size + stop + 1
            position: int = value.state_idx
            if 0 <= position < size and start <= position < upper:
                return position
        raise ValueError(f'{value!r} is not an indicator of the random variable')

    def count(self, value: Any) -> int:
        """
        Returns the number of occurrences of `value` (either 0 or 1).
        Contracted by Sequence[Indicator].
        """
        is_member = (
            isinstance(value, Indicator)
            and value.rv_idx == self._idx
            and 0 <= value.state_idx < len(self)
        )
        return 1 if is_member else 0
|
|
1207
|
+
|
|
1208
|
+
|
|
1209
|
+
class RVMap(Sequence[RandomVariable]):
    """
    Wrap a PGM to provide convenient access to PGM random variables.

    An RVMap of a PGM behaves exactly like the PGM `rvs` property. That is, it
    behaves like a sequence of RandomVariable objects.

    If the underlying PGM is updated, then the RVMap will automatically update.

    Additionally, an RVMap enables access to the PGM random variable via the name
    of each random variable.

    For example, if `pgm.rvs[1]` is a random variable named `xray`, then
    ```
    rvs = RVMap(pgm)

    # These all retrieve the same random variable object.
    xray = rvs[1]
    xray = rvs('xray')
    xray = rvs.xray
    ```

    To use an RVMap on a PGM, the variable names must be unique across the PGM
    and must not clash with the attribute names of the RVMap itself.
    """

    def __init__(self, pgm: PGM, ignore_case: bool = False):
        """
        Construct an RVMap for the given PGM.

        Args:
            pgm: the PGM to wrap.
            ignore_case: if true, the variable names are not case-sensitive.

        Raises:
            RuntimeError: if the random variable names are not unique or
                clash with a reserved (attribute) name.
        """
        self._pgm: PGM = pgm
        self._ignore_case: bool = ignore_case
        self.__rv_map: Dict[str, RandomVariable] = {}
        # Names already used as attributes of this object cannot be random variable names.
        self._reserved_names: Set[str] = {self._clean_name(name) for name in dir(self)}

        # Force the rv map cache to be updated.
        # This may raise an exception.
        _ = self._rv_map

    def _clean_name(self, name: str) -> str:
        """
        Adjust the case of the given name as needed.
        """
        return name.lower() if self._ignore_case else name

    @property
    def _rv_map(self) -> Dict[str, RandomVariable]:
        """
        Get the cached rv map, updating as needed if the PGM changed.

        A change is detected by comparing the size of the cached map with the
        number of random variables of the PGM.

        Returns:
            a mapping from random variable name to random variable

        Raises:
            RuntimeError: if the random variable names are not unique or
                clash with a reserved (attribute) name.
        """
        if len(self.__rv_map) != len(self._pgm.rvs):
            # There is a difference between the map and the PGM - create a new map.
            self.__rv_map = {self._clean_name(rv.name): rv for rv in self._pgm.rvs}
            if len(self.__rv_map) != len(self._pgm.rvs):
                raise RuntimeError('random variable names are not unique')
            if not self._reserved_names.isdisjoint(self.__rv_map.keys()):
                raise RuntimeError('random variable names clash with reserved names.')
        return self.__rv_map

    def new_rv(self, name: str, states: Union[int, Sequence[State]]) -> RandomVariable:
        """
        As per `PGM.new_rv`.
        Delegate creating a new random variable to the PGM.

        Returns:
            a RandomVariable object belonging to the PGM.
        """
        return self._pgm.new_rv(name, states)

    def __len__(self) -> int:
        """
        Returns:
            the number of random variables of the wrapped PGM.
        """
        return len(self._pgm.rvs)

    def __getitem__(self, index: int) -> RandomVariable:
        """
        Get a random variable by its position in the wrapped PGM's `rvs`.
        """
        return self._pgm.rvs[index]

    def items(self) -> Iterable[Tuple[str, RandomVariable]]:
        """
        Returns:
            the (name, random variable) pairs of the rv map.
        """
        return self._rv_map.items()

    def keys(self) -> Iterable[str]:
        """
        Returns:
            the random variable names (case adjusted as needed).
        """
        return self._rv_map.keys()

    def values(self) -> Iterable[RandomVariable]:
        """
        Returns:
            the random variables of the rv map.
        """
        return self._rv_map.values()

    def get(self, rv_name: str, default=None):
        """
        Get a random variable by name, or `default` if there is no such name.
        """
        return self._rv_map.get(self._clean_name(rv_name), default)

    def __call__(self, rv_name: str) -> RandomVariable:
        """
        Get a random variable by name.

        Raises:
            KeyError: if there is no random variable with the given name.
        """
        return self._rv_map[self._clean_name(rv_name)]

    def __getattr__(self, rv_name: str) -> RandomVariable:
        """
        Get a random variable by name, using attribute syntax.

        Python only calls this method when normal attribute lookup fails.

        Raises:
            AttributeError: if there is no random variable with the given
                name, as required for `hasattr` and three-argument `getattr`
                to work.
        """
        if '_reserved_names' not in self.__dict__:
            # The object is not initialised (e.g., it is being copied or unpickled).
            # Consulting the rv map now would re-enter __getattr__ without bound.
            raise AttributeError(rv_name)
        try:
            return self(rv_name)
        except KeyError:
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {rv_name!r}'
            ) from None
|
|
1306
|
+
|
|
1307
|
+
|
|
1308
|
+
class Factor:
|
|
1309
|
+
"""
|
|
1310
|
+
A PGM factor over one or more random variables declares a relationship between
|
|
1311
|
+
those variables. A Factor also has a potential function associated with
|
|
1312
|
+
it which defines a real-number value with each combination of states of
|
|
1313
|
+
the random variables.
|
|
1314
|
+
|
|
1315
|
+
The default potential function for a factor is a unique ZeroPotentialFunction.
|
|
1316
|
+
|
|
1317
|
+
The order of a Factors random variables is important as many things will be
|
|
1318
|
+
co-indexed with the random variables. For example, the shape of a Factor is
|
|
1319
|
+
the tuple of random variable lengths.
|
|
1320
|
+
|
|
1321
|
+
Note that multiple factors may share a potential function, so long as they all
|
|
1322
|
+
belong to the same PGM object and have the same shape.
|
|
1323
|
+
"""
|
|
1324
|
+
|
|
1325
|
+
def __init__(self, pgm: PGM, *rvs: RandomVariable):
    """
    Add a new factor to the given PGM.

    The new factor starts with its own ZeroPotentialFunction as its
    potential function, and is appended to the factors of the PGM.

    Args:
        pgm: the PGM that the factor will belong to.
        *rvs: the random variables.

    Raises:
        ValueError: if a random variable is duplicated, if no random variable
            is given, or if a random variable does not belong to `pgm`.
    """
    if len(rvs) != len(set(rvs)):
        raise ValueError('duplicated random variable in factor')
    if not rvs:
        raise ValueError('must be at least one random variable')
    if not all(rv.pgm is pgm for rv in rvs):
        raise ValueError('random variable not from the same PGM')

    self._pgm: PGM = pgm
    # The index is the position this factor will take in the PGM's factors.
    self._idx: int = pgm.number_of_factors
    self._rvs: Sequence[RandomVariable] = tuple(rvs)
    self._shape: Shape = tuple(map(len, rvs))

    default_function = ZeroPotentialFunction(self)
    self._zero_potential_function: ZeroPotentialFunction = default_function
    self._potential_function: PotentialFunction = default_function

    # Register self with our PGM
    # noinspection PyProtectedMember
    pgm._factors += (self,)
|
|
1358
|
+
|
|
1359
|
+
@property
def rvs(self) -> Sequence[RandomVariable]:
    """
    Returns:
        The random variables of this factor, in the order given at construction.
    """
    return self._rvs
|
|
1366
|
+
|
|
1367
|
+
@property
def pgm(self) -> PGM:
    """
    Returns:
        The PGM that owns this factor.
    """
    return self._pgm
|
|
1374
|
+
|
|
1375
|
+
@property
def idx(self) -> int:
    """
    Returns:
        The position of this factor in the factors of its PGM.

    Ensures:
        `self.pgm.factors[self.idx] is self`.
    """
    return self._idx
|
|
1385
|
+
|
|
1386
|
+
@property
def shape(self) -> Shape:
    """
    Returns:
        The shape of this factor: the lengths of its random variables,
        co-indexed with `self.rvs`.
    """
    return self._shape
|
|
1389
|
+
|
|
1390
|
+
@property
def number_of_states(self) -> int:
    """
    How many distinct states are covered by this Factor, as reported by
    its current potential function.
    """
    potential_function = self._potential_function
    return potential_function.number_of_states
|
|
1396
|
+
|
|
1397
|
+
def __str__(self) -> str:
|
|
1398
|
+
"""
|
|
1399
|
+
Return a human-readable string to represent this factor.
|
|
1400
|
+
This is intended mainly for debugging purposes.
|
|
1401
|
+
"""
|
|
1402
|
+
return '(' + ', '.join([repr(str(rv)) for rv in self._rvs]) + ')'
|
|
1403
|
+
|
|
1404
|
+
def __len__(self) -> int:
|
|
1405
|
+
"""
|
|
1406
|
+
Returns:
|
|
1407
|
+
the number of random variables.
|
|
1408
|
+
"""
|
|
1409
|
+
return len(self._rvs)
|
|
1410
|
+
|
|
1411
|
+
@overload
|
|
1412
|
+
def __getitem__(self, index: int) -> RandomVariable:
|
|
1413
|
+
...
|
|
1414
|
+
|
|
1415
|
+
@overload
|
|
1416
|
+
def __getitem__(self, index: slice) -> Sequence[RandomVariable]:
|
|
1417
|
+
...
|
|
1418
|
+
|
|
1419
|
+
def __getitem__(self, index):
|
|
1420
|
+
return self._rvs[index]
|
|
1421
|
+
|
|
1422
|
+
def instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
1423
|
+
"""
|
|
1424
|
+
Iterate over all possible instances, in natural index order (i.e.,
|
|
1425
|
+
last random variable changing most quickly).
|
|
1426
|
+
|
|
1427
|
+
Args:
|
|
1428
|
+
flip: if true, then first random variable changes most quickly
|
|
1429
|
+
|
|
1430
|
+
Returns:
|
|
1431
|
+
an iterator over tuples, each tuple holds random variable
|
|
1432
|
+
state indexes, co-indexed with this object's shape, i.e., self.shape.
|
|
1433
|
+
"""
|
|
1434
|
+
return self.function.instances(flip)
|
|
1435
|
+
|
|
1436
|
+
def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
1437
|
+
"""
|
|
1438
|
+
Iterate over all possible instances of parent random variable, in
|
|
1439
|
+
natural index order (i.e., last random variable changing most quickly).
|
|
1440
|
+
|
|
1441
|
+
Args:
|
|
1442
|
+
flip: if true, then first random variable changes most quickly
|
|
1443
|
+
|
|
1444
|
+
Returns:
|
|
1445
|
+
an iteration over tuples, each tuple holds random variable states
|
|
1446
|
+
co-indexed with this object's 'parent' shape, i.e., `self.shape[1:]`.
|
|
1447
|
+
"""
|
|
1448
|
+
return self.function.parent_instances(flip)
|
|
1449
|
+
|
|
1450
|
+
@property
|
|
1451
|
+
def is_zero(self) -> bool:
|
|
1452
|
+
"""
|
|
1453
|
+
Is the potential function of this factor set to the special 'zero' potential function.
|
|
1454
|
+
"""
|
|
1455
|
+
return self._potential_function is self._zero_potential_function
|
|
1456
|
+
|
|
1457
|
+
@property
def function(self) -> PotentialFunction:
    """
    Returns:
        The potential function currently assigned to this factor.
    """
    return self._potential_function

@function.setter
def function(self, function: PotentialFunction | Factor) -> None:
    """
    Set the potential function for this PGM factor to the given potential function
    or factor.

    If a Factor is given, then its current potential function is used (and thus
    becomes shared by both factors). If the resulting function is a
    ZeroPotentialFunction, then this factor reverts to its own zero potential
    function rather than adopting the given object.

    Args:
        function: a potential function, or a factor whose potential function to use.

    Raises:
        TypeError: if the given object is neither a PotentialFunction nor a Factor.
        ValueError: if the function belongs to a different PGM than this factor.
        ValueError: if the function shape is different to the shape of this factor.
    """
    if isinstance(function, Factor):
        function = function.function

    # Explicit check rather than `assert`, so that validation is still
    # performed when Python runs with assertions disabled (-O).
    if not isinstance(function, PotentialFunction):
        raise TypeError(
            f'expected a PotentialFunction or Factor, got {type(function).__name__}'
        )

    if self._potential_function is function:
        # nothing to do
        return

    if function.pgm is not self._pgm:
        raise ValueError('the given function is not of the same PGM as the factor')

    if function.shape != self._shape:
        raise ValueError(f'incorrect function shape: expected {self._shape}, got {function.shape}')

    if isinstance(function, ZeroPotentialFunction):
        # Every factor keeps its own zero function, never one from another factor.
        self.set_zero()
    else:
        self._potential_function = function
|
|
1489
|
+
|
|
1490
|
+
def set_zero(self) -> ZeroPotentialFunction:
|
|
1491
|
+
"""
|
|
1492
|
+
Set the factor's potential function to its original ZeroPotentialFunction.
|
|
1493
|
+
|
|
1494
|
+
Returns:
|
|
1495
|
+
the potential function.
|
|
1496
|
+
"""
|
|
1497
|
+
self._potential_function = self._zero_potential_function
|
|
1498
|
+
return self._potential_function
|
|
1499
|
+
|
|
1500
|
+
def set_dense(self) -> DensePotentialFunction:
    """
    Replace the potential function with a new `DensePotentialFunction` object
    created for this factor.

    Returns:
        the potential function.
    """
    dense = DensePotentialFunction(self)
    self._potential_function = dense
    return dense
|
|
1509
|
+
|
|
1510
|
+
def set_sparse(self) -> SparsePotentialFunction:
    """
    Replace the potential function with a new `SparsePotentialFunction` object
    created for this factor.

    Returns:
        the potential function.
    """
    sparse = SparsePotentialFunction(self)
    self._potential_function = sparse
    return sparse
|
|
1519
|
+
|
|
1520
|
+
def set_compact(self) -> CompactPotentialFunction:
    """
    Replace the potential function with a new `CompactPotentialFunction` object
    created for this factor.

    Returns:
        the potential function.
    """
    compact = CompactPotentialFunction(self)
    self._potential_function = compact
    return compact
|
|
1529
|
+
|
|
1530
|
+
def set_clause(self, *key: int) -> ClausePotentialFunction:
    """
    Replace the potential function with a new `ClausePotentialFunction` object
    created for this factor.

    Args:
        *key: defines the random variable states of the clause. The key is a sequence of
            random variable state indexes, co-indexed with `Factor.rvs`.

    Returns:
        the potential function.

    Raises:
        KeyError: if the key is not valid for the shape of the factor.
    """
    clause = ClausePotentialFunction(self, key)
    self._potential_function = clause
    return clause
|
|
1546
|
+
|
|
1547
|
+
def set_cpt(self, tolerance: float = DEFAULT_TOLERANCE) -> CPTPotentialFunction:
    """
    Replace the potential function with a new `CPTPotentialFunction` object
    created for this factor.

    Args:
        tolerance: a tolerance when testing if values are equal to zero or one.

    Returns:
        the potential function.

    Raises:
        ValueError: if tolerance is negative.
    """
    cpt = CPTPotentialFunction(self, tolerance)
    self._potential_function = cpt
    return cpt
|
|
1562
|
+
|
|
1563
|
+
|
|
1564
|
+
@dataclass(eq=True, frozen=True)
class ParamId:
    """
    Identifier for a single parameter of a potential function.

    A ParamId uniquely identifies a parameter within a PGM by pairing an
    identifier of the owning potential function with the index of the
    parameter within that function.

    A ParamId is immutable and hashable.
    """
    # Identifier of the potential function owning the parameter.
    function_id: int
    # Index of the parameter within the potential function.
    param_idx: int
|
|
1575
|
+
|
|
1576
|
+
|
|
1577
|
+
class PotentialFunction(ABC):
    """
    A potential function defines the potential values for a Factor, where
    a factor joins one or more variables of a PGM.

    A potential function may be shared by several Factors of a PGM,
    i.e., can be applied to multiple variables.

    The `shape` of a potential function is a tuple of integers which defines
    the number of variables, len(shape), and the number of states of each
    variable, shape[i].

    The potential function value for variable states (x = i, y = j, ...) is given by
    self[i, j, ...], i.e., self.__getitem__((i, j, ...)). The tuple, (i, j, ...), is
    known as a Key.

    The values of a potential function are defined by potential function parameters.
    The number of potential function parameters is given by number_of_parameters.
    The value of each parameter is given by param_value(i), where i is the parameter index.

    Every valid key of the potential function is either mapped to a parameter or is
    "guaranteed zero" which means that the value is zero and cannot be changed by changing
    the values of the potential function's parameters.
    """

    def __init__(self, factor: Factor):
        """
        Create a potential function compatible with the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor that this potential function is compatible with.
        """
        self._pgm: PGM = factor.pgm
        self._shape: Shape = factor.shape
        # Cached as the shape never changes after construction.
        self._number_of_states = _multiply(self._shape)

    @property
    def pgm(self) -> PGM:
        """
        Returns:
            The PGM that this potential function belongs to.
        """
        return self._pgm

    @property
    def shape(self) -> Shape:
        """
        Returns:
            The shape of this potential function.
        """
        return self._shape

    @property
    def number_of_rvs(self) -> int:
        """
        Returns:
            The number of random variables in this potential function.
        """
        return len(self._shape)

    @property
    def number_of_states(self) -> int:
        """
        How many distinct states are covered by this potential function.

        Returns:
            The size of the state space of this potential function.
        """
        return self._number_of_states

    @property
    def number_of_parent_states(self) -> int:
        """
        How many distinct states are covered by this potential function's parents,
        i.e., excluding the first random variable.

        Returns:
            The size of the state space of this potential function's parent random variables.
        """
        return _multiply(self._shape[1:])

    def count_usage(self) -> int:
        """
        Check all PGM factors to count the number of times that this potential function
        is used.

        Returns:
            the number of factors that use this potential function.
        """
        return sum(1 for factor in self._pgm.factors if factor.function is self)

    def check_key(self, key: Key) -> Instance:
        """
        Convert the key into an instance.

        Args:
            key: defines an instance in the state space of the potential function.

        Returns:
            an instance, which is a tuple of state indexes, co-indexed with `self.shape`.

        Raises:
            KeyError: if the key is not valid for the shape of the potential function.
        """
        return check_key(self._shape, key)

    def valid_key(self, key: Key) -> bool:
        """
        Is the given key valid.

        Args:
            key: defines an instance in the state space of the potential function.

        Returns:
            True only if the given key is valid.
        """
        return valid_key(self._shape, key)

    def valid_parameter(self, param_idx: int) -> bool:
        """
        Is the given parameter index valid.

        Args:
            param_idx: a parameter index.

        Returns:
            True only if `0 <= param_idx < self.number_of_parameters`.
        """
        return 0 <= param_idx < self.number_of_parameters

    @property
    def is_sparse(self) -> bool:
        """
        Are there any 'guaranteed zero' values.

        Returns:
            True only if `self.number_of_not_guaranteed_zero < self.number_of_states`.
        """
        return self.number_of_not_guaranteed_zero < self._number_of_states

    @property
    @abstractmethod
    def number_of_not_guaranteed_zero(self) -> int:
        """
        How many of the states of this potential function are not 'guaranteed zero'.
        That is, how many keys are associated with a parameter.

        Returns:
            The number of valid keys that are associated with a parameter.

        Ensures:
            0 <= self.number_of_not_guaranteed_zero <= self.number_of_states.
        """
        ...

    @property
    @abstractmethod
    def number_of_parameters(self) -> int:
        """
        Get the number of parameters defining the potential function values.
        Each valid key of the function either maps to a parameter
        or is 'guaranteed zero'.

        Returns:
            The number of parameters.

        Ensures:
            0 <= self.number_of_parameters <= self.number_of_not_guaranteed_zero.
        """
        ...

    @property
    @abstractmethod
    def params(self) -> Iterable[Tuple[int, float]]:
        """
        Iterate the parameters and their associated values.

        Returns:
            An iterable over (param_idx, value) tuples, for every possible parameter.

        Assumes:
            The potential function is not mutated while iterating.
        """
        ...

    @property
    @abstractmethod
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """
        Iterate the keys that have a parameter associated with them.

        Returns:
            An iterable over (key, param_idx, value) tuples, for every key with an associated parameter.

        Assumes:
            The potential function is not mutated while iterating.
        """
        ...

    @abstractmethod
    def __getitem__(self, key: Key) -> float:
        """
        Get the potential function value for the given instance key.

        Args:
            key: defines an instance in the state space of the potential function.

        Returns:
            The value of the potential function for the given key.

        Assumes:
            self.valid_key(key).
        """
        ...

    @abstractmethod
    def param_value(self, param_idx: int) -> float:
        """
        Get the potential function value by parameter index.

        Args:
            param_idx: a parameter index.

        Returns:
            The value of the parameter with the given index.

        Assumes:
            `self.valid_parameter(param_idx)`.
        """
        ...

    @abstractmethod
    def param_idx(self, key: Key) -> Optional[int]:
        """
        Get the parameter index for the given potential function random variables states (key).

        Args:
            key: defines an instance in the state space of the potential function.

        Returns:
            either `None` indicating a "guaranteed zero" value, or the parameter index holding
            the potential function value for the key.
        """
        ...

    def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
        """
        Is the potential function set with parameter values conforming to a
        Conditional Probability Table.

        Every value must be non-negative.
        For every state of the parents (non-first slots)
        the sum of the values over the child states (first slot)
        must be either 1 or 0.

        Assumes:
            tolerance is non-negative.

        Args:
            tolerance: a tolerance when testing if sums are equal to zero or one.

        Returns:
            True only if the potential function is compatible with being a CPT.
        """
        # This default implementation calculates the result the long way, by checking
        # every valid key of the potential function.
        # Subclasses may override this implementation.
        low: float = 1.0 - tolerance
        high: float = 1.0 + tolerance
        child_states = range(self.shape[0])
        for parent_state in self.parent_instances():
            parent_key = tuple(parent_state)
            values = [self[(state,) + parent_key] for state in child_states]
            # A negative value can never be a probability, even when it
            # happens to be cancelled out in the sum by the other values.
            if any(value < 0 for value in values):
                return False
            total: float = sum(values)
            if not ((low <= total <= high) or (0 <= total <= tolerance)):
                return False
        return True

    def natural_param_idx(self, key: Key) -> int:
        """
        Get the natural parameter index for the given key. This is the same index as used
        by a DensePotentialFunction with the same shape.

        Args:
            key: is a valid key of the potential function, referring to an instance in the
                state space of the potential function.

        Assumes:
            `self.valid_key(key)` is true.

        Returns:
            a hypothetical parameter index assuming that every valid key has a unique parameter
            as per DensePotentialFunction.
        """
        return _natural_key_idx(self._shape, key)

    def param_id(self, param_idx: int) -> ParamId:
        """
        Get a hashable object to represent the parameter with the given parameter index.

        Args:
            param_idx: a parameter index.

        Returns:
            a hashable ParamId object for the parameter of this potential function.

        Raises:
            ValueError: if the parameter index is not valid.
        """
        if not self.valid_parameter(param_idx):
            raise ValueError(f'invalid parameter index: {param_idx}')
        # The function is identified by object identity, via `id`.
        return ParamId(id(self), param_idx)

    def items(self) -> Iterable[Tuple[Instance, float]]:
        """
        Iterate over all keys and values of this potential function.

        Returns:
            An iterator over all (key, value) pairs, where key is an Instance and value
            is the value of the potential function for the key.
        """
        for key in _combos_ranges(self._shape, flip=True):
            yield key, self[key]

    def instances(self, flip: bool = False) -> Iterable[Instance]:
        """
        Iterate over all possible instances, in natural index order (i.e.,
        last random variable changing most quickly).

        Args:
            flip: if true, then first random variable changes most quickly

        Returns:
            an iterator over tuples, each tuple holds random variable
            state indexes, co-indexed with this object's shape, i.e., self.shape.
        """
        # The `flip` argument is inverted when passed on to `_combos_ranges`.
        return _combos_ranges(self._shape, flip=not flip)

    def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
        """
        Iterate over all possible instances of parent random variables, in
        natural index order (i.e., last random variable changing most quickly).

        Args:
            flip: if true, then first random variable changes most quickly

        Returns:
            an iteration over tuples, each tuple holds random variable states
            co-indexed with this object's 'parent' shape, i.e., `self.shape[1:]`.
        """
        # The `flip` argument is inverted when passed on to `_combos_ranges`.
        return _combos_ranges(self._shape[1:], flip=not flip)

    def __str__(self) -> str:
        """
        Provide a human-readable representation of this potential function.
        This is intended mainly for debugging purposes.
        """
        shape_str: str = ', '.join(str(x) for x in self._shape)
        return f'{self.__class__.__name__}({shape_str})'

    def dump(
            self,
            *,
            prefix: str = '',
            indent: str = ' ',
            show_function_values: bool = False,
            show_id_class: bool = True,
    ) -> None:
        """
        Print a dump of the function.
        This is intended for debugging purposes.

        Args:
            prefix: optional prefix for indenting all lines.
            indent: additional prefix to use for extra indentation.
            show_function_values: if true, then the function values will be dumped.
            show_id_class: if true, then the function id and class will be dumped.
        """
        shape_str: str = ', '.join(str(x) for x in self._shape)
        # Evaluate once; subclasses may need to compute this.
        not_guaranteed_zero: int = self.number_of_not_guaranteed_zero

        if show_id_class:
            print(f'{prefix}id: {id(self)}')
            print(f'{prefix}class: {self.__class__.__name__}')
        print(f'{prefix}usage: {self.count_usage()}')
        print(f'{prefix}rvs: {self.number_of_rvs}')
        print(f'{prefix}shape: ({shape_str})')
        print(f'{prefix}states: {self._number_of_states}')
        print(f'{prefix}guaranteed zero: {self._number_of_states - not_guaranteed_zero}')
        print(f'{prefix}not guaranteed zero: {not_guaranteed_zero}')
        print(f'{prefix}parameters: {self.number_of_parameters}')
        if show_function_values:
            next_prefix = prefix + indent
            for key, param_idx, value in self.keys_with_param:
                print(f'{next_prefix}{param_idx} {key} = {value}')
|
|
1972
|
+
|
|
1973
|
+
|
|
1974
|
+
class ZeroPotentialFunction(PotentialFunction):
    """
    A potential function whose value is zero for every key.

    It behaves like a DensePotentialFunction in that there is a
    parameter for each possible key. However, a PGM user has no way
    to change parameter values: they are always zero.
    Despite the inability to change the value of the parameters,
    no key is considered 'guaranteed zero'.

    The primary use of a ZeroPotentialFunction is as a placeholder
    within a factor, prior to parameter learning.
    """
    __slots__ = ()

    def __init__(self, factor: Factor):
        """
        Create a potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor that this potential function is compatible with.
        """
        super().__init__(factor)

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        """
        Returns:
            the number of states, as no key is 'guaranteed zero'.
        """
        return self.number_of_states

    @property
    def number_of_parameters(self) -> int:
        """
        Returns:
            the number of states, as there is one parameter per key.
        """
        return self.number_of_states

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        """
        Iterate over (param_idx, value) tuples; every value is zero.
        """
        number_of_parameters = self.number_of_parameters
        for idx in range(number_of_parameters):
            yield idx, 0

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """
        Iterate over (key, param_idx, value) tuples; every value is zero.
        Parameter indexes follow the order of `self.instances()`.
        """
        idx = 0
        for key in self.instances():
            yield key, idx, 0
            idx += 1

    def __getitem__(self, key: Key) -> float:
        """
        Get the value for the given key, which is always zero.

        The key is passed through `check_key`, so an invalid key is
        reported rather than silently accepted.
        """
        self.check_key(key)
        return 0

    def param_value(self, param_idx: int) -> float:
        """
        Get the value of the given parameter, which is always zero.

        Raises:
            ValueError: if the parameter index is not valid.
        """
        if self.valid_parameter(param_idx):
            return 0
        raise ValueError(f'invalid parameter index: {param_idx}')

    def param_idx(self, key: Key) -> int:
        """
        Get the parameter index of the given key, which is its natural index.
        """
        return self.natural_param_idx(key)

    def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
        """
        A zero potential function is always reported as compatible with being a CPT.
        """
        return True
|
|
2033
|
+
|
|
2034
|
+
|
|
2035
|
+
class DensePotentialFunction(PotentialFunction):
|
|
2036
|
+
"""
|
|
2037
|
+
A dense (tabular) potential function.
|
|
2038
|
+
There is one parameter for each valid key of the potential function.
|
|
2039
|
+
The initial value for each parameter is zero.
|
|
2040
|
+
It is possible independently change any value corresponding to any key.
|
|
2041
|
+
"""
|
|
2042
|
+
|
|
2043
|
+
def __init__(self, factor: Factor):
|
|
2044
|
+
"""
|
|
2045
|
+
Create a potential function for the given factor.
|
|
2046
|
+
|
|
2047
|
+
Ensures:
|
|
2048
|
+
Does not hold a reference to the given factor.
|
|
2049
|
+
Does not register the potential function with the PGM.
|
|
2050
|
+
|
|
2051
|
+
Args:
|
|
2052
|
+
factor: which factor is this potential function is compatible with.
|
|
2053
|
+
"""
|
|
2054
|
+
super().__init__(factor)
|
|
2055
|
+
self._values: NDArrayFloat64 = np.zeros(self.number_of_states, dtype=np.float64)
|
|
2056
|
+
|
|
2057
|
+
@property
|
|
2058
|
+
def number_of_not_guaranteed_zero(self) -> int:
|
|
2059
|
+
return self.number_of_states
|
|
2060
|
+
|
|
2061
|
+
@property
|
|
2062
|
+
def number_of_parameters(self) -> int:
|
|
2063
|
+
return self.number_of_states
|
|
2064
|
+
|
|
2065
|
+
def __getitem__(self, key: Key) -> float:
|
|
2066
|
+
return self._values.item(self.param_idx(key))
|
|
2067
|
+
|
|
2068
|
+
def param_value(self, param_idx: int) -> float:
|
|
2069
|
+
return self._values.item(param_idx)
|
|
2070
|
+
|
|
2071
|
+
def param_idx(self, key: Key) -> Optional[int]:
|
|
2072
|
+
return self.natural_param_idx(key)
|
|
2073
|
+
|
|
2074
|
+
@property
|
|
2075
|
+
def params(self) -> Iterable[Tuple[int, float]]:
|
|
2076
|
+
# Type warning due to numpy type erasure
|
|
2077
|
+
# noinspection PyTypeChecker
|
|
2078
|
+
return enumerate(self._values)
|
|
2079
|
+
|
|
2080
|
+
@property
|
|
2081
|
+
def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
|
|
2082
|
+
for param_idx, key in enumerate(self.instances()):
|
|
2083
|
+
value: float = self.param_value(param_idx)
|
|
2084
|
+
yield key, param_idx, value
|
|
2085
|
+
|
|
2086
|
+
# Mutators
|
|
2087
|
+
|
|
2088
|
+
def __setitem__(self, key: Key, value: float) -> None:
|
|
2089
|
+
"""
|
|
2090
|
+
Set the potential function value, for a given key.
|
|
2091
|
+
|
|
2092
|
+
Arg:
|
|
2093
|
+
key: defines an instance in the state space of the potential function.
|
|
2094
|
+
value: the new value of the potential function for the given key.
|
|
2095
|
+
|
|
2096
|
+
Assumes:
|
|
2097
|
+
self.valid_key(key).
|
|
2098
|
+
"""
|
|
2099
|
+
self._values[self.param_idx(key)] = value
|
|
2100
|
+
|
|
2101
|
+
def set_param_value(self, param_idx: int, value: float) -> None:
|
|
2102
|
+
"""
|
|
2103
|
+
Set the parameter value.
|
|
2104
|
+
|
|
2105
|
+
Arg:
|
|
2106
|
+
param_idx: is the index of the parameter.
|
|
2107
|
+
value: the new value of the potential function for the given key.
|
|
2108
|
+
|
|
2109
|
+
Assumes:
|
|
2110
|
+
self.valid_param(param_idx).
|
|
2111
|
+
"""
|
|
2112
|
+
self._values[param_idx] = value
|
|
2113
|
+
|
|
2114
|
+
def clear(self) -> DensePotentialFunction:
|
|
2115
|
+
"""
|
|
2116
|
+
Set all values of the potential function to zero.
|
|
2117
|
+
|
|
2118
|
+
Returns:
|
|
2119
|
+
self
|
|
2120
|
+
"""
|
|
2121
|
+
return self.set_all(0)
|
|
2122
|
+
|
|
2123
|
+
def normalise_cpt(self) -> DensePotentialFunction:
|
|
2124
|
+
"""
|
|
2125
|
+
Normalise the parameter values as if this was a CPT.
|
|
2126
|
+
That is, treat the first random variable as the child and the others as parents;
|
|
2127
|
+
for each combination of parent states, ensure the parameters over the child
|
|
2128
|
+
states sum to 1 (or 0).
|
|
2129
|
+
|
|
2130
|
+
Assumes:
|
|
2131
|
+
There are no negative parameter values.
|
|
2132
|
+
|
|
2133
|
+
Returns:
|
|
2134
|
+
self
|
|
2135
|
+
"""
|
|
2136
|
+
child = self._shape[0]
|
|
2137
|
+
parents = self._shape[1:]
|
|
2138
|
+
for parent_states in _combos_ranges(parents):
|
|
2139
|
+
keys = [(c,) + parent_states for c in range(child)]
|
|
2140
|
+
total = sum(self[key] for key in keys)
|
|
2141
|
+
if total != 0 and total != 1:
|
|
2142
|
+
for key in keys:
|
|
2143
|
+
self[key] /= total
|
|
2144
|
+
return self
|
|
2145
|
+
|
|
2146
|
+
def normalise(self, grouping_positions: Sequence[int] = ()) -> DensePotentialFunction:
|
|
2147
|
+
"""
|
|
2148
|
+
Convert the potential function to a CPT with 'grouping_positions' nominating
|
|
2149
|
+
the parent random variables.
|
|
2150
|
+
|
|
2151
|
+
I.e., for each possible key of the function with the same value at each
|
|
2152
|
+
grouping position, the sum of values for matching keys in the factor is scaled
|
|
2153
|
+
to be 1 (or 0).
|
|
2154
|
+
|
|
2155
|
+
Parameter 'grouping_positions' are indices into `self.shape`. For example, the
|
|
2156
|
+
grouping positions of a factor with parent rvs 'conditioning_rvs', then
|
|
2157
|
+
grouping_positions = [i for i, rv in enumerate(factor.rvs) if rv in conditioning_rvs].
|
|
2158
|
+
|
|
2159
|
+
Args:
|
|
2160
|
+
grouping_positions: indices into `self.shape`.
|
|
2161
|
+
|
|
2162
|
+
Returns:
|
|
2163
|
+
self
|
|
2164
|
+
"""
|
|
2165
|
+
_normalise_potential_function(self, grouping_positions)
|
|
2166
|
+
return self
|
|
2167
|
+
|
|
2168
|
+
def set_iter(self, values: Iterable[float]) -> DensePotentialFunction:
|
|
2169
|
+
"""
|
|
2170
|
+
Set the values of the potential function using the given iterator.
|
|
2171
|
+
|
|
2172
|
+
Mapping instances to *values is as follows:
|
|
2173
|
+
Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
|
|
2174
|
+
values[0] represents instance (0,0)
|
|
2175
|
+
values[1] represents instance (0,1)
|
|
2176
|
+
values[2] represents instance (0,2)
|
|
2177
|
+
values[3] represents instance (1,0)
|
|
2178
|
+
values[4] represents instance (1,1)
|
|
2179
|
+
values[5] represents instance (1,2).
|
|
2180
|
+
|
|
2181
|
+
For example: to set to counts, starting from 1, use `self.set_iter(itertools.count(1))`.
|
|
2182
|
+
|
|
2183
|
+
Args:
|
|
2184
|
+
values: an iterable providing values to use.
|
|
2185
|
+
|
|
2186
|
+
Returns:
|
|
2187
|
+
self
|
|
2188
|
+
"""
|
|
2189
|
+
self._values = np.fromiter(
|
|
2190
|
+
values,
|
|
2191
|
+
dtype=np.float64,
|
|
2192
|
+
count=self.number_of_parameters
|
|
2193
|
+
)
|
|
2194
|
+
return self
|
|
2195
|
+
|
|
2196
|
+
def set_stream(self, stream: Callable[[], float]) -> DensePotentialFunction:
|
|
2197
|
+
"""
|
|
2198
|
+
Set the values of the potential function by repeatedly calling the stream function.
|
|
2199
|
+
The order of values is the same as set_iter.
|
|
2200
|
+
|
|
2201
|
+
For example, to set to random numbers, use `self.set_stream(random.random)`.
|
|
2202
|
+
|
|
2203
|
+
Args:
|
|
2204
|
+
stream: a callable taking no arguments, returning the values to use.
|
|
2205
|
+
|
|
2206
|
+
Returns:
|
|
2207
|
+
self
|
|
2208
|
+
"""
|
|
2209
|
+
return self.set_iter(iter(stream, None))
|
|
2210
|
+
|
|
2211
|
+
def set_flat(self, *value: float) -> DensePotentialFunction:
    """
    Set the values of the potential function to the given values.

    Values are assigned in the same order as for `set_iter`.

    Args:
        *value: the values to use, one per state of the function.

    Returns:
        self

    Raises:
        ValueError: if `len(value) != self.number_of_states`.
    """
    expected: int = self.number_of_states
    supplied: int = len(value)
    if supplied == expected:
        return self.set_iter(value)
    raise ValueError(f'wrong number of values: expected {expected}, got {supplied}')
|
|
2228
|
+
|
|
2229
|
+
def set_all(self, value: float) -> DensePotentialFunction:
    """
    Set every value of the potential function to `value`.

    Args:
        value: the value to use.

    Returns:
        self
    """
    # presumably `_repeat` yields `value` endlessly; `set_iter` decides
    # how many items to take -- confirm against the helper's definition.
    constant_source = _repeat(value)
    return self.set_iter(constant_source)
|
|
2240
|
+
|
|
2241
|
+
def set_uniform(self) -> DensePotentialFunction:
    """
    Set all values of the potential function to 1/number_of_states.

    Returns:
        self
    """
    share: float = 1.0 / self.number_of_states
    return self.set_all(share)
|
|
2249
|
+
|
|
2250
|
+
|
|
2251
|
+
class SparsePotentialFunction(PotentialFunction):
    """
    A sparse potential function.

    Each key that has been given a non-zero value owns one parameter:
    `self._params` maps the key's instance to an index into the list of
    parameter values, `self._values`. A key that is absent from
    `self._params` is "guaranteed zero".

    Values may be assigned to any key and parameters are created or
    released as needed. Assigning zero to a key releases its parameter
    and so makes that key "guaranteed zero".
    """

    def __init__(self, factor: Factor):
        """
        Create a potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor this potential function is compatible with.
        """
        super().__init__(factor)
        self._values: List[float] = []
        self._params: Dict[Instance, int] = {}

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        """Number of keys that currently have a parameter."""
        return len(self._params)

    @property
    def number_of_parameters(self) -> int:
        """Number of parameters; one per key held in `self._params`."""
        return len(self._params)

    def __getitem__(self, key: Key) -> float:
        """Return the value for `key`, or 0 when the key has no parameter."""
        idx: Optional[int] = self.param_idx(key)
        return 0 if idx is None else self._values[idx]

    def param_value(self, param_idx: int) -> float:
        """Return the value of the parameter at index `param_idx`."""
        return self._values[param_idx]

    def param_idx(self, key: Key) -> Optional[int]:
        """Return the parameter index for `key`, or None if it has none."""
        instance: Instance = _key_to_instance(key)
        return self._params.get(instance)

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        """Iterate over (param_idx, value) pairs."""
        return enumerate(self._values)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """Iterate over (instance, param_idx, value) for keys with a parameter."""
        for instance, idx in self._params.items():
            yield instance, idx, self._values[idx]

    # Mutators

    def __setitem__(self, key: Key, value: float) -> None:
        """
        Set the potential function value, for a given key.

        If value is zero, then the key will become "guaranteed zero".

        Args:
            key: defines an instance in the state space of the potential function.
            value: the new value of the potential function for the given key.

        Assumes:
            self.valid_key(key).
        """
        instance: Instance = _key_to_instance(key)
        slot: Optional[int] = self._params.get(instance)

        if slot is None:
            # The key has no parameter yet; zero needs none.
            if value == 0:
                return
            self._params[instance] = len(self._values)
            self._values.append(value)
            return

        if value != 0:
            # The key keeps its parameter; only the value changes.
            self._values[slot] = value
            return

        # The key had a parameter and is being set to zero: release it.
        # Indices stay contiguous by moving the last parameter into the
        # vacated slot, then dropping the (now redundant) last entry.
        last: int = len(self._values) - 1
        if slot != last:
            self._values[slot] = self._values[last]
            # Exactly one key refers to the last parameter; re-point it.
            mover = next(
                (inst for inst, idx in self._params.items() if idx == last),
                None,
            )
            if mover is not None:
                self._params[mover] = slot

        self._values.pop()
        self._params.pop(instance)

    def set_param_value(self, param_idx: int, value: float) -> None:
        """
        Set the parameter value.

        Args:
            param_idx: the index of the parameter.
            value: the new value of the parameter.

        Assumes:
            self.valid_param(param_idx).
        """
        self._values[param_idx] = value

    def clear(self) -> SparsePotentialFunction:
        """
        Set all values of the potential function to zero.

        Returns:
            self
        """
        # Fresh containers are bound rather than emptied in place.
        self._params = {}
        self._values = []
        return self

    def normalise_cpt(self) -> SparsePotentialFunction:
        """
        Normalise the parameter values as if this was a CPT.

        The first random variable is treated as the child and the others
        as parents; for each combination of parent states, the parameters
        over the child states are made to sum to 1 (or 0).

        Returns:
            self
        """
        parent_positions = list(range(1, self.number_of_rvs))
        _normalise_potential_function(self, parent_positions)
        return self

    def normalise(self, grouping_positions=()) -> SparsePotentialFunction:
        """
        Convert the potential function to a CPT with `grouping_positions`
        nominating the parent random variables.

        I.e., for each possible key of the function with the same value at
        each grouping position, the sum of values for matching keys in the
        factor is scaled to be 1 (or 0).

        `grouping_positions` are indices into the function's shape. For
        example, for a factor with parent rvs `conditioning_rvs`:
        grouping_positions = [i for i, rv in enumerate(factor.rvs) if rv in conditioning_rvs].

        Args:
            grouping_positions: indices into `self.shape`.

        Returns:
            self
        """
        _normalise_potential_function(self, grouping_positions)
        return self

    def set_iter(self, values: Iterable[float]) -> SparsePotentialFunction:
        """
        Set the values of the potential function using the given iterator.

        Values map onto instances with the last random variable varying fastest.
        Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
            values[0] represents instance (0,0)
            values[1] represents instance (0,1)
            values[2] represents instance (0,2)
            values[3] represents instance (1,0)
            values[4] represents instance (1,1)
            values[5] represents instance (1,2).

        Zero values are not stored; their keys are left "guaranteed zero".

        For example: to set to counts, starting from 1, use `self.set_iter(itertools.count(1))`.

        Args:
            values: an iterable providing values to use.

        Returns:
            self
        """
        self.clear()
        for instance, value in zip(self.instances(), values):
            if value == 0:
                continue
            self._params[instance] = len(self._values)
            self._values.append(value)
        return self

    def set_stream(self, stream: Callable[[], float]) -> SparsePotentialFunction:
        """
        Set the values of the potential function by repeatedly calling the stream function.
        The order of values is the same as set_iter.

        For example, to set to random numbers, use `self.set_stream(random.random)`.

        Args:
            stream: a callable taking no arguments, returning the values to use.

        Returns:
            self
        """
        # iter(callable, sentinel): one `stream()` call per item requested.
        value_source = iter(stream, None)
        return self.set_iter(value_source)

    def set_flat(self, *value: float) -> SparsePotentialFunction:
        """
        Set the values of the potential function to the given values.
        The order of values is the same as set_iter.

        Args:
            *value: the values to use.

        Returns:
            self

        Raises:
            ValueError: if `len(value) != self.number_of_states`.
        """
        expected: int = self.number_of_states
        supplied: int = len(value)
        if supplied == expected:
            return self.set_iter(value)
        raise ValueError(f'wrong number of values: expected {expected}, got {supplied}')

    def set_all(self, value: float) -> SparsePotentialFunction:
        """
        Set all values of the potential function to the given value.

        Args:
            value: the value to use.

        Returns:
            self
        """
        # An all-zero function needs no parameters at all.
        return self.clear() if value == 0 else self.set_iter(_repeat(value))

    def set_uniform(self) -> SparsePotentialFunction:
        """
        Set all values of the potential function to 1/number_of_states.

        Returns:
            self
        """
        share: float = 1.0 / self.number_of_states
        return self.set_all(share)
|
|
2502
|
+
|
|
2503
|
+
|
|
2504
|
+
class CompactPotentialFunction(PotentialFunction):
    """
    A compact potential function is sparse, where values for keys of
    the same value are represented by a single parameter.

    There is one parameter for each unique, non-zero key value.
    The user may set the value for any key and parameters will
    be automatically reconfigured as needed. Setting the value for
    a key to zero disassociates the key from its parameter and
    thus makes that key "guaranteed zero".

    Internal state:
        _values:  parameter values, indexed by parameter index.
        _counts:  for each parameter, the number of keys referring to it.
        _map:     instance -> parameter index (absent means guaranteed zero).
        _inv_map: parameter value -> parameter index.
    """

    def __init__(self, factor: Factor):
        """
        Create a potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor this potential function is compatible with.
        """
        super().__init__(factor)
        self._values: List[float] = []
        self._counts: List[int] = []
        self._map: Dict[Instance, int] = {}
        self._inv_map: Dict[float, int] = {}

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        """Number of keys that currently refer to a parameter."""
        return len(self._map)

    @property
    def number_of_parameters(self) -> int:
        """Number of parameters held by the function."""
        return len(self._values)

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        """Iterate over (param_idx, value) pairs."""
        return enumerate(self._values)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """Iterate over (instance, param_idx, value) for keys with a parameter."""
        for key, param_idx in self._map.items():
            yield key, param_idx, self._values[param_idx]

    def __getitem__(self, key: Key) -> float:
        """Return the value for `key`, or 0 when the key has no parameter."""
        param_idx: Optional[int] = self.param_idx(key)
        if param_idx is None:
            return 0
        return self._values[param_idx]

    def param_value(self, param_idx: int) -> float:
        """Return the value of the parameter at index `param_idx`."""
        return self._values[param_idx]

    def param_idx(self, key: Key) -> Optional[int]:
        """Return the parameter index for `key`, or None if it has none."""
        return self._map.get(_key_to_instance(key))

    # Mutators

    def __setitem__(self, key: Key, value: float) -> None:
        """
        Set the potential function value, for a given key.

        If value is zero, then the key will become "guaranteed zero".
        If the value is the same as an existing parameter value, then
        that parameter will be reused.

        Args:
            key: defines an instance in the state space of the potential function.
            value: the new value of the potential function for the given key.

        Assumes:
            self.valid_key(key).
        """
        instance: Instance = _key_to_instance(key)
        old_idx: Optional[int] = self._map.get(instance)

        if old_idx is None:
            # Previous value for the key was zero.
            if value != 0:
                self._attach(instance, value)
            return

        # The key previously had a parameter.
        old_value: float = self._values[old_idx]
        if value == old_value:
            # Nothing to do.
            return

        if self._counts[old_idx] > 1:
            # Other keys still use the old parameter: just detach this key.
            self._counts[old_idx] -= 1
            del self._map[instance]
            # A zero value must leave the key unmapped (guaranteed zero),
            # never allocate a zero-valued parameter.
            if value != 0:
                self._attach(instance, value)
            return

        # This key is the only user of its parameter.
        if value != 0 and value not in self._inv_map:
            # Re-value the parameter in place so its index stays stable,
            # keeping the value -> index lookup in step.
            self._values[old_idx] = value
            del self._inv_map[old_value]
            self._inv_map[value] = old_idx
            return

        # The parameter is no longer needed: either the key becomes zero,
        # or another parameter already holds the new value and is reused.
        # The key is unmapped first so `_remove_param` cannot re-point it.
        del self._map[instance]
        del self._inv_map[old_value]
        self._remove_param(old_idx)
        if value != 0:
            self._attach(instance, value)

    def set_iter(self, values: Iterable[float]) -> CompactPotentialFunction:
        """
        Set the values of the potential function using the given iterator.

        Values map onto instances with the last random variable varying fastest.
        Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
            values[0] represents instance (0,0)
            values[1] represents instance (0,1)
            values[2] represents instance (0,2)
            values[3] represents instance (1,0)
            values[4] represents instance (1,1)
            values[5] represents instance (1,2).

        For example: to set to counts, starting from 1, use `self.set_iter(itertools.count(1))`.

        Args:
            values: an iterable providing values to use.

        Returns:
            self
        """
        self.clear()
        for instance, value in zip(self.instances(), values):
            self[instance] = value
        return self

    def set_stream(self, stream: Callable[[], float]) -> CompactPotentialFunction:
        """
        Set the values of the potential function by repeatedly calling the stream function.
        The order of values is the same as set_iter.

        For example, to set to random numbers, use `self.set_stream(random.random)`.

        Args:
            stream: a callable taking no arguments, returning the values to use.

        Returns:
            self
        """
        return self.set_iter(iter(stream, None))

    def set_flat(self, *value: float) -> CompactPotentialFunction:
        """
        Set the values of the potential function to the given values.
        The order of values is the same as set_iter.

        Args:
            *value: the values to use.

        Returns:
            self

        Raises:
            ValueError: if `len(value) != self.number_of_states`.
        """
        if len(value) != self.number_of_states:
            raise ValueError(f'wrong number of values: expected {self.number_of_states}, got {len(value)}')
        return self.set_iter(value)

    def set_all(self, value: float) -> CompactPotentialFunction:
        """
        Set all values of the potential function to the given value.

        Args:
            value: the value to use.

        Returns:
            self
        """
        self.clear()
        if value != 0:
            # A single parameter is shared by every instance.
            self._values = [value]
            self._counts = [self.number_of_states]
            self._inv_map = {value: 0}
            self._map = {instance: 0 for instance in self.instances()}
        return self

    def set_uniform(self) -> CompactPotentialFunction:
        """
        Set all values of the potential function to 1/number_of_states.

        Returns:
            self
        """
        return self.set_all(1.0 / self.number_of_states)

    def clear(self) -> CompactPotentialFunction:
        """
        Set all values of the potential function to zero.

        Returns:
            self
        """
        self._values = []
        self._counts = []
        self._map = {}
        self._inv_map = {}
        return self

    def _attach(self, instance: Instance, value: float) -> None:
        """
        Point `instance` at the parameter holding `value`, creating that
        parameter if no existing parameter has the value.

        Assumes:
            value is non-zero and `instance` is not currently in `self._map`.
        """
        param_idx: Optional[int] = self._inv_map.get(value)
        if param_idx is None:
            # Need to allocate a new parameter.
            param_idx = len(self._values)
            self._values.append(value)
            self._counts.append(1)
            self._inv_map[value] = param_idx
        else:
            # The value already exists in the function, so reuse it.
            self._counts[param_idx] += 1
        self._map[instance] = param_idx

    def _remove_param(self, param_idx: int) -> None:
        """
        Remove the indexed parameter from self._values and self._counts.
        If the parameter is not at the end of the list of parameters
        then the last parameter is moved into its place, and every
        reference to the last parameter (in self._map and self._inv_map)
        is updated to the new index.

        The caller is responsible for removing the old value's entry
        from self._inv_map and any self._map entries for the removed parameter.
        """
        end: int = len(self._values) - 1
        if param_idx != end:
            # Move the `end` parameter into the vacated `param_idx` slot.
            end_value: float = self._values[end]
            self._values[param_idx] = end_value
            self._counts[param_idx] = self._counts[end]
            self._inv_map[end_value] = param_idx
            # Only values of existing keys change, so mutating the dict
            # while iterating over it is safe.
            for instance, instance_param_idx in self._map.items():
                if instance_param_idx == end:
                    self._map[instance] = param_idx

        # Remove the end parameter.
        self._values.pop()
        self._counts.pop()
|
|
2761
|
+
|
|
2762
|
+
|
|
2763
|
+
class ClausePotentialFunction(PotentialFunction):
    """
    A clause potential function represents a clause from a CNF formula.
    I.e. a clause over variables X, Y, Z, is a disjunction of the form: 'X=x or Y=y or Z=z'.

    A clause potential function is guaranteed zero for a key where the clause is false,
    i.e., when 'X != x and Y != y and Z != z'.

    For keys where the clause is true, the value of the potential function
    is given by the only parameter of the potential function. That parameter
    is called the clause 'weight' and is notionally 1.

    The weight of a clause is permitted to be zero, but that is _not_ equivalent to
    guaranteed-zero.
    """

    def __init__(self, factor: Factor, key: Key, weight: float = 1):
        """
        Create a clause potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.

        Args:
            factor: the factor this potential function is compatible with.
            key: defines the random variable states of the clause.
            weight: the value of the function for keys where the clause is true.
        """
        super().__init__(factor)
        self._weight: float = weight
        self._clause: Instance = self.check_key(key)
        # presumably the number of keys for which the clause is true,
        # which depends only on the shape -- confirm against `_zero_space`.
        self._num_not_guaranteed_zero: int = _zero_space(self.shape)

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        """The count computed from the shape at construction time."""
        return self._num_not_guaranteed_zero

    @property
    def number_of_parameters(self) -> int:
        """Always 1: the clause weight is the only parameter."""
        return 1

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        """Iterate over (param_idx, value) pairs; there is only (0, weight)."""
        return ((0, self._weight),)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """
        Iterate over (instance, 0, weight) for each key where the clause is true.

        A key satisfies the clause when it matches the clause state at one
        or more positions, which is the same test `__getitem__` applies.
        Each such key is yielded exactly once.
        """
        value = self._weight
        clause = self._clause
        for instance in self.instances():
            if any(state == clause_state for state, clause_state in zip(instance, clause)):
                yield tuple(instance), 0, value

    def __getitem__(self, key: Key) -> float:
        """
        Return the weight if `key` satisfies the clause, otherwise 0.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.
        """
        instance: Instance = self.check_key(key)
        for key_state_idx, clause_state_idx in zip(instance, self._clause):
            if key_state_idx == clause_state_idx:
                return self._weight
        return 0

    def param_value(self, param_idx: int) -> float:
        """
        Return the value of the indexed parameter.

        Raises:
            IndexError: if `param_idx` is not 0, the only parameter index.
        """
        if param_idx != 0:
            raise IndexError(param_idx)
        return self._weight

    def param_idx(self, key: Key) -> Optional[int]:
        """
        Return 0 if `key` satisfies the clause (so its value is the weight
        parameter), or None if the key is guaranteed zero.
        """
        instance: Instance = _key_to_instance(key)
        # Same test as `__getitem__`: any position matching the clause.
        if any(state == clause_state for state, clause_state in zip(instance, self._clause)):
            return 0
        return None

    def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
        """
        A ClausePotentialFunction can only be a CPT when all entries are zero,
        i.e., when the weight is within `tolerance` of zero.
        """
        return -tolerance <= self._weight <= tolerance

    def is_sparse(self) -> bool:
        """A clause potential function always reports itself as sparse."""
        return True

    @property
    def weight(self) -> float:
        """
        Returns:
            the "weight" parameter defining the potential function.
        """
        return self._weight

    @property
    def clause(self) -> Instance:
        """
        Returns:
            the clause defining the potential function.
        """
        return self._clause

    # Mutators

    @weight.setter
    def weight(self, value: float) -> None:
        """
        Set the weight parameter to the given value.
        """
        self._weight = value

    @clause.setter
    def clause(self, key: Key) -> None:
        """
        Set the clause to the given key.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.
        """
        self._clause = self.check_key(key)
|
|
2882
|
+
|
|
2883
|
+
|
|
2884
|
+
class CPTPotentialFunction(PotentialFunction):
|
|
2885
|
+
"""
|
|
2886
|
+
A potential function implementing a sparse Conditional Probability Table (CPT).
|
|
2887
|
+
|
|
2888
|
+
The first random variable in the signature is the child, and the remaining random
|
|
2889
|
+
variables are parents.
|
|
2890
|
+
|
|
2891
|
+
For each instantiation of the parent random variables there is a Conditioned Probability
|
|
2892
|
+
Distribution (CPD) over the states of the child random variable.
|
|
2893
|
+
|
|
2894
|
+
If a CPD is not provided for a parent instantiation, then that parent instantiation
|
|
2895
|
+
is taken to have probability zero (i.e., all values of the CPD are guaranteed zero).
|
|
2896
|
+
"""
|
|
2897
|
+
|
|
2898
|
+
def __init__(self, factor: Factor, tolerance: float):
    """
    Create a CPT potential function for the given factor.

    Ensures:
        Does not hold a reference to the given factor.
        Does not register the potential function with the PGM.

    Args:
        factor: the factor this potential function is compatible with.
        tolerance: a tolerance when testing if values are equal to zero or one.

    Raises:
        ValueError: if tolerance is negative.
    """
    super().__init__(factor)

    if tolerance < 0:
        raise ValueError('tolerance cannot be negative')
    self._tolerance = tolerance

    # The first random variable is the child; the rest are parents.
    shape = self.shape
    self._child_size: int = shape[0]
    self._parent_shape: Shape = shape[1:]

    # Nothing is stored until CPDs are provided.
    self._values: List[float] = []
    self._map: Dict[Instance, int] = {}
    self._inv_map: List[Instance] = []
|
|
2924
|
+
|
|
2925
|
+
@property
def number_of_not_guaranteed_zero(self) -> int:
    """Number of stored values, i.e. the length of `self._values`."""
    stored: int = len(self._values)
    return stored
|
|
2928
|
+
|
|
2929
|
+
@property
def number_of_parameters(self) -> int:
    """Number of parameters, i.e. the length of `self._values`."""
    parameter_count: int = len(self._values)
    return parameter_count
|
|
2932
|
+
|
|
2933
|
+
def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
    """
    Is this potential function a CPT, to within `tolerance`?

    When the requested tolerance is no tighter than the tolerance given
    at construction, the answer is True without inspecting any values.
    Otherwise the inherited check is used.
    """
    if not (tolerance >= self._tolerance):
        # The requested tolerance is tighter than ensured,
        # so fall back to the default method.
        return super().is_cpt(tolerance)
    return True
|
|
2940
|
+
|
|
2941
|
+
@property
def params(self) -> Iterable[Tuple[int, float]]:
    """Iterate over (param_idx, value) pairs, in parameter index order."""
    indexed_values = enumerate(self._values)
    return indexed_values
|
|
2944
|
+
|
|
2945
|
+
@property
def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
    """
    Iterate over (instance, param_idx, value) for every stored value.

    Values are stored as consecutive runs of `self._child_size` entries,
    one run per parent instance, in the order given by `self._inv_map`.
    The yielded instance is the child state followed by the parent states.
    """
    run_length: int = self._child_size
    for param_idx, value in enumerate(self._values):
        # Which run (parent instance) and which position in it (child state).
        run_number, child_state = divmod(param_idx, run_length)
        parent: Instance = self._inv_map[run_number]
        yield (child_state,) + tuple(parent), param_idx, value
|
|
2952
|
+
|
|
2953
|
+
def __getitem__(self, key: Key) -> float:
    """Return the value for `key`, or 0 when the key has no parameter."""
    idx: Optional[int] = self.param_idx(key)
    return 0 if idx is None else self._values[idx]
|
|
2959
|
+
|
|
2960
|
+
def param_value(self, param_idx: int) -> float:
    """Return the value of the parameter at index `param_idx`."""
    stored_values = self._values
    return stored_values[param_idx]
|
|
2962
|
+
|
|
2963
|
+
def param_idx(self, key: Key) -> Optional[int]:
    """
    Return the parameter index for `key`, or None if its parent
    instantiation has no stored CPD.

    The key is validated with `self.check_key`. Its first element is the
    child state; the remaining elements identify the parent instantiation,
    whose offset into the stored values is looked up in `self._map`.
    """
    instance: Instance = self.check_key(key)
    offset: Optional[int] = self._map.get(instance[1:])
    return None if offset is None else offset + instance[0]
|
|
2970
|
+
|
|
2971
|
+
@property
def parent_shape(self) -> Shape:
    """
    The shape of the parent random variables, as recorded at construction.
    """
    shape_of_parents: Shape = self._parent_shape
    return shape_of_parents
|
|
2977
|
+
|
|
2978
|
+
@property
def number_of_parent_states(self) -> int:
    """
    The number of combinations of parent states.
    """
    # presumably `_multiply` returns the product of the sizes -- confirm.
    combinations: int = _multiply(self._parent_shape)
    return combinations
|
|
2984
|
+
|
|
2985
|
+
@property
def number_of_child_states(self) -> int:
    """
    Number of child random variable states.

    This is the number of values in each conditional probability
    distribution, and is equivalent to `self.shape[0]`.

    Returns:
        the number of child states.
    """
    child_states: int = self._child_size
    return child_states
|
|
2997
|
+
|
|
2998
|
+
def get_cpd(self, parent_states: Key) -> List[float]:
    """
    Get the CPD conditioned on parent states indicated by `parent_states`.

    Args:
        parent_states: indicates the parent states.

    Returns:
        The conditioned probability distribution, as a new list with one
        entry per child state. It is all zeros when no CPD is stored for
        the parent states.
    """
    parent_instance: Instance = check_key(self._parent_shape, parent_states)
    start: Optional[int] = self._map.get(parent_instance)
    width: int = self._child_size
    if start is not None:
        # Slicing copies, so the caller cannot alter the stored values.
        return self._values[start:start + width]
    return [0] * width
|
|
3015
|
+
|
|
3016
|
+
def cpds(self) -> Iterable[Tuple[Instance, Sequence[float]]]:
    """
    Iterate over (parent_states, cpd) tuples.
    This will exclude zero CPDs.
    Do not change CPDs to (or from) zero while iterating over them.

    Returns:
        an iterator over pairs (instance, cpd) where,
        instance: indicates the state of the parent random variables.
        cpd: is the conditioned probability distribution, for the parent instance.
    """
    for parent_instance, start in self._map.items():
        stop: int = start + self._child_size
        # Each CPD is a copy of its run of stored values.
        yield parent_instance, self._values[start:stop]
|
|
3032
|
+
|
|
3033
|
+
# Mutators
|
|
3034
|
+
|
|
3035
|
+
def clear(self) -> CPTPotentialFunction:
    """
    Reset the potential function so that every value is zero.

    All stored CPDs are discarded, along with the lookup structures
    that index them.

    Returns:
        self
    """
    self._map, self._values, self._inv_map = {}, [], []
    return self
|
|
3046
|
+
|
|
3047
|
+
def set_uniform(self) -> CPTPotentialFunction:
    """
    Make every CPD a uniform distribution.

    The function is cleared first, then a uniform CPD is set for each
    parent instance reported by `self.parent_instances()`.

    Returns:
        self
    """
    self.clear()
    assign_uniform = self.set_cpd_uniform
    for instance in self.parent_instances():
        assign_uniform(instance)
    return self
|
|
3058
|
+
|
|
3059
|
+
def set_random(self, random: Callable[[], float], sparsity: float = 0) -> CPTPotentialFunction:
    """
    Replace all values of the potential function with random CPDs.

    The function is cleared first, then `set_cpd_random` is applied to
    each parent instance reported by `self.parent_instances()`.

    Args:
        random: is a stream of random numbers, assumed uniformly distributed in the interval [0, 1].
        sparsity: sets the expected proportion of probability values that are zero.

    Returns:
        self
    """
    self.clear()
    assign_random = self.set_cpd_random
    for instance in self.parent_instances():
        assign_random(instance, random, sparsity)
    return self
|
|
3074
|
+
|
|
3075
|
+
def set(self, *rows: Tuple[Key, Sequence[float]]) -> CPTPotentialFunction:
    """
    Clear the function, then call self.set_cpd(parent_states, cpd) for
    each row (parent_states, cpd) in rows. Parent states that are not
    mentioned are left with zero probabilities.

    Example usage, assuming three Boolean random variables:
        pgm.Factor(x, y, z).set_cpt().set(
            # y  z     x[0] x[1]
            ((0, 0), (0.1, 0.9)),
            ((0, 1), (0.1, 0.9)),
            ((1, 0), (0.1, 0.9)),
            ((1, 1), (0.1, 0.9))
        )

    Args:
        *rows: are tuples (key, cpd) used to set the potential function values.

    Raises:
        ValueError: if a CPD is not valid.

    Returns:
        self
    """
    self.clear()
    for row in rows:
        key, distribution = row
        self.set_cpd(key, distribution)
    return self
|
|
3102
|
+
|
|
3103
|
+
def set_all(self, *cpds: Optional[Sequence[float]]) -> CPTPotentialFunction:
    """
    Clear the function, then assign the given `cpds` one by one to the
    parent instances, in the order produced by `self.parent_instances()`
    (documented as the last parent variable changing state most rapidly).

    If insufficient CPDs are provided then the remaining parent instantiations are taken to be
    impossible (i.e. not set and guaranteed zero).
    If too many CPDs are provided then the extras are ignored.
    Any entry may be None, indicating 'guaranteed zero' for the associated parent states.

    Args:
        *cpds: are the CPDs used to set the potential function values.

    Raises:
        ValueError: if a CPD is not valid.

    Returns:
        self
    """
    self.clear()
    # zip stops at the shorter input, which gives the documented handling
    # of too few or too many CPDs.
    pairs = zip(self.parent_instances(), cpds)
    for instance, distribution in pairs:
        self.set_cpd(instance, distribution)
    return self
|
|
3126
|
+
|
|
3127
|
+
def set_cpd(self, parent_states: Key, cpd: Optional[Sequence[float]]) -> CPTPotentialFunction:
    """
    Store `cpd` as the CPD for the given parent states.

    Passing None, or a CPD whose values sum to less than the tolerance,
    removes any stored CPD for the parent states, as clear_cpd(parent_states) does.

    Args:
        parent_states: indicates the CPD to set, based on the parent states.
        cpd: is a conditioned probability distribution, or None indicating `guaranteed zero`.

    Raises:
        ValueError: if the CPD has the wrong length, has a value outside [0, 1],
            or does not sum to 1 (within the tolerance).
        KeyError: if the key is not valid.

    Returns:
        self
    """
    instance: Instance = check_key(self._parent_shape, parent_states)

    if cpd is None:
        self._clear_cpd(instance)
        return self

    size: int = self._child_size
    if len(cpd) != size:
        raise ValueError(f'CPD incorrect size: expected {self._child_size}, got {len(cpd)}')
    if any(not (0 <= value <= 1) for value in cpd):
        raise ValueError(f'not a valid CPD: {cpd!r}')

    total_value = sum(cpd)
    tolerance = self._tolerance
    if total_value < tolerance:
        # Effectively all zero: treat as 'guaranteed zero' rather than an error.
        self._clear_cpd(instance)
        return self
    if total_value < 1 - tolerance or total_value > 1 + tolerance:
        raise ValueError(f'not a valid CPD: sum of values = {total_value}')

    start: Optional[int] = self._map.get(instance)
    if start is not None:
        # Overwrite the existing CPD in place.
        self._values[start:start + size] = cpd
        return self

    # First CPD for this parent instance: append it and index it.
    start = len(self._values)
    self._values.extend(cpd)
    self._map[instance] = start
    self._inv_map.append(instance)
    return self
|
|
3173
|
+
|
|
3174
|
+
def clear_cpd(self, parent_states: Key) -> CPTPotentialFunction:
    """
    Make the CPD of the given parent_states all 'guaranteed zero'.

    Args:
        parent_states: indicates the CPD to clear, based on the parent states.

    Raises:
        KeyError: if the key is not valid.

    Returns:
        self
    """
    self._clear_cpd(check_key(self._parent_shape, parent_states))
    return self
|
|
3190
|
+
|
|
3191
|
+
def set_cpd_uniform(self, parent_states: Key) -> CPTPotentialFunction:
    """
    Give the CPD of the given parent_states a uniform distribution.

    Args:
        parent_states: indicates the CPD to set, based on the parent states.

    Raises:
        KeyError: if the key is not valid.

    Returns:
        self
    """
    count = self.number_of_child_states
    probability = 1.0 / count
    return self.set_cpd(parent_states, [probability] * count)
|
|
3207
|
+
|
|
3208
|
+
def set_cpd_random(
        self,
        parent_states: Key,
        random: Callable[[], float],
        sparsity: float = 0,
) -> CPTPotentialFunction:
    """
    Give the CPD of the given parent_states random values.

    Each retained entry is drawn as a small positive constant plus `random()`,
    and the entries are then scaled to sum to one. If every entry ends up
    zero, the CPD is cleared instead.

    Args:
        parent_states: identifies the CPD being set.
        random: is a stream of random numbers, assumed uniformly distributed in the interval [0, 1].
        sparsity: sets the expected proportion of probability values that are zero.

    Returns:
        self
    """
    cpd = np.zeros(self.number_of_child_states, dtype=np.float64)
    for i in range(len(cpd)):
        # With no sparsity every entry is drawn. Otherwise one draw decides
        # whether the entry is kept, and a second draw supplies its value.
        if sparsity <= 0 or random() > sparsity:
            cpd[i] = 0.0000001 + random()

    sum_value = np.sum(cpd)
    if not sum_value > 0:
        return self.clear_cpd(parent_states)
    cpd /= sum_value
    return self.set_cpd(parent_states, cpd)
|
|
3239
|
+
|
|
3240
|
+
def _clear_cpd(self, parent_instance: Instance) -> None:
    """
    Drop the stored CPD for `parent_instance`, if there is one.

    CPDs are stored back to back in `self._values`. To keep that storage
    contiguous, the last stored CPD is moved into the vacated slot before
    the tail of the list is trimmed.
    """
    start: Optional[int] = self._map.get(parent_instance)
    if start is None:
        # No CPD is stored for this parent instance.
        return

    size: int = self._child_size
    last_start: int = len(self._values) - size
    if start != last_start:
        # Relocate the final CPD into the slot being freed, and update
        # both indexes to point at its new position.
        moved_instance = self._inv_map[-1]
        self._values[start:start + size] = self._values[last_start:]
        self._map[moved_instance] = start
        self._inv_map[start // size] = moved_instance

    self._map.pop(parent_instance)
    self._inv_map.pop()
    del self._values[last_start:]
|
|
3264
|
+
|
|
3265
|
+
|
|
3266
|
+
def default_pgm_name(pgm: PGM) -> str:
    """
    Build the name a PGM gets when none is supplied to its constructor.

    The name is the prefix 'PGM_' followed by `id(pgm)`.

    Args:
        pgm: a PGM object.

    Returns:
        a name for the PGM if none is given at construction time.
    """
    return f'PGM_{id(pgm)}'
|
|
3277
|
+
|
|
3278
|
+
|
|
3279
|
+
def check_key(shape: Shape, key: Key) -> Instance:
    """
    Convert the key into an instance, verifying that it lies within the shape.

    Args:
        shape: the shape defining the state space; each entry is the number
            of states of the corresponding random variable.
        key: a key into the state space.

    Returns:
        An instance from the state space, as a tuple of state indexes, co-indexed with the given shape.

    Raises:
        KeyError: if the key has the wrong length for the shape, or any state
            index is outside the range 0 .. (number of states - 1).
    """
    instance: Instance = _key_to_instance(key)
    # Each shape entry is a count of states, so the largest valid index is
    # m - 1. The upper bound must be strict.
    if len(instance) == len(shape) and all(0 <= i < m for i, m in zip(instance, shape)):
        return instance
    raise KeyError(f'not a valid key for shape {shape}: {key!r}')
|
|
3299
|
+
|
|
3300
|
+
|
|
3301
|
+
def valid_key(shape: Shape, key: Key) -> bool:
    """
    Report whether the given key is valid for the given shape.

    Validity is decided by `check_key`: the key is valid exactly when
    `check_key` does not raise a KeyError.

    Args:
        shape: the shape defining the state space.
        key: a key into the state space.

    Returns:
        True only if the key is valid for the given shape.
    """
    try:
        check_key(shape, key)
    except KeyError:
        return False
    return True
|
|
3317
|
+
|
|
3318
|
+
|
|
3319
|
+
def number_of_states(*rvs: RandomVariable) -> int:
    """
    Compute the size of the state space of the given random variables.

    Returns:
        the product of `len(rv)` over the given random variables.
    """
    sizes = map(len, rvs)
    return _multiply(sizes)
|
|
3325
|
+
|
|
3326
|
+
|
|
3327
|
+
def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterable[Instance]:
    """
    Enumerate instances of the given random variables.

    Each instance is a tuple of state indexes, co-indexed with the given random variables.

    The order is the natural index order (i.e., last random variable changing most quickly).

    Args:
        *rvs: the random variables to enumerate over.
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds state indexes
        co-indexed with the given random variables.
    """
    sizes = list(map(len, rvs))
    # The helper is called with the opposite flip sense to this function's.
    helper_flip = not flip
    return _combos_ranges(sizes, flip=helper_flip)
|
|
3344
|
+
|
|
3345
|
+
|
|
3346
|
+
def rv_instances_as_indicators(*rvs: RandomVariable, flip: bool = False) -> Iterable[Sequence[Indicator]]:
    """
    Enumerate instances of the given random variables, as indicators.

    Each instance is a tuple of indicators, co-indexed with the given random variables.

    The order is the natural index order (i.e., last random variable changing most quickly).

    Args:
        *rvs: the random variables to enumerate over.
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds random variable indicators
        co-indexed with the given random variables.
    """
    # The helper is called with the opposite flip sense to this function's.
    helper_flip = not flip
    return _combos(rvs, flip=helper_flip)
|
|
3362
|
+
|
|
3363
|
+
|
|
3364
|
+
def _key_to_instance(key: Key) -> Instance:
    """
    Turn a key into an instance.

    A bare integer becomes a one-element tuple; anything else is converted
    with `tuple(key)`.

    Args:
        key: a key into a state space.

    Returns:
        An instance from the state space, as a tuple of state indexes.

    Assumes:
        The key is valid for the implied state space.
    """
    return (key,) if isinstance(key, int) else tuple(key)
|
|
3381
|
+
|
|
3382
|
+
|
|
3383
|
+
def _natural_key_idx(shape: Shape, key: Key) -> int:
    """
    Compute the natural index of the given key, assuming the given shape.

    The index is accumulated left to right: the running index is multiplied
    by the size of each later dimension before that dimension's state is added.

    Args:
        shape: the shape defining the state space.
        key: a key into the state space.

    Returns:
        an index as per enumerated instances in their natural order, i.e.
        last random variable changing most quickly.

    Assumes:
        The key is valid for the shape.
    """
    instance: Instance = _key_to_instance(key)
    index: int = instance[0]
    # Stop at the shorter of the two sequences.
    for position in range(1, min(len(shape), len(instance))):
        index = index * shape[position] + instance[position]
    return index
|
|
3403
|
+
|
|
3404
|
+
|
|
3405
|
+
def _zero_space(shape: Shape) -> int:
    """
    Compute the size of the zero space of the given shape. This is the number
    of possible instances in the state space that do not have a zero in the instance.

    The zero space is the same as the shape but with one less state
    for each random variable.

    Args:
        shape: the shape defining the state space.

    Returns:
        the size of the zero space, i.e. the product of (size - 1) over the shape.
    """
    reduced_sizes = (size - 1 for size in shape)
    return _multiply(reduced_sizes)
|
|
3420
|
+
|
|
3421
|
+
|
|
3422
|
+
def _normalise_potential_function(
        function: Union[DensePotentialFunction, SparsePotentialFunction],
        grouping_positions: Sequence[int],
) -> None:
    """
    Rescale the potential function in place so that it forms a CPT, with
    'grouping_positions' nominating the parent random variables.

    Keys that agree at every grouping position form a group. The parameter
    values of each group are scaled so that they sum to 1; groups whose sum
    is not positive are left untouched.

    With no grouping positions the whole function is one group. It is
    rescaled unless its total is exactly 0 or 1.

    Parameter 'grouping_positions' are indices into `function.shape`. For example, for
    a factor with parent rvs 'conditioning_rvs',
    grouping_positions = [i for i, rv in enumerate(factor.rvs) if rv in conditioning_rvs].

    Args:
        function: the potential function to normalise.
        grouping_positions: indices into `function.shape`.
    """
    if len(grouping_positions) == 0:
        grand_total = sum(map(function.param_value, range(function.number_of_parameters)))
        if grand_total == 0 or grand_total == 1:
            # Nothing to scale by, or already normalised.
            return
        for _key, idx, value in function.keys_with_param:
            function.set_param_value(idx, value / grand_total)
        return

    def group_of(key):
        # Project a key onto the grouping (parent) positions.
        return tuple(key[position] for position in grouping_positions)

    # First pass: total the parameter values of each group.
    totals = {}
    for key, _idx, value in function.keys_with_param:
        group = group_of(key)
        totals[group] = totals.get(group, 0) + value

    # Second pass: scale each parameter by the total of its group.
    for key, idx, value in function.keys_with_param:
        group_total = totals[group_of(key)]
        if group_total > 0:
            function.set_param_value(idx, value / group_total)
|
|
3461
|
+
|
|
3462
|
+
|
|
3463
|
+
# Characters that may appear in a rendered string without it needing quotes.
_CLEAN_CHARS: Set[str] = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-+~?.')


def _clean_str(s) -> str:
    """
    Render `s` as a string, quoting it (via `repr`) when it is empty or
    contains any character outside _CLEAN_CHARS.
    This is used when rendering indicators.
    """
    text = str(s)
    if text and _CLEAN_CHARS.issuperset(text):
        return text
    return repr(text)
|