compiled-knowledge 4.0.0a5__cp313-cp313-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (167) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +13 -0
  3. ck/circuit/circuit.c +38749 -0
  4. ck/circuit/circuit.cpython-313-darwin.so +0 -0
  5. ck/circuit/circuit_py.py +807 -0
  6. ck/circuit/tmp_const.py +74 -0
  7. ck/circuit_compiler/__init__.py +2 -0
  8. ck/circuit_compiler/circuit_compiler.py +26 -0
  9. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  10. ck/circuit_compiler/cython_vm_compiler/_compiler.c +17373 -0
  11. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
  12. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +96 -0
  13. ck/circuit_compiler/interpret_compiler.py +223 -0
  14. ck/circuit_compiler/llvm_compiler.py +388 -0
  15. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  16. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  17. ck/circuit_compiler/support/__init__.py +0 -0
  18. ck/circuit_compiler/support/circuit_analyser.py +81 -0
  19. ck/circuit_compiler/support/input_vars.py +148 -0
  20. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  21. ck/example/__init__.py +53 -0
  22. ck/example/alarm.py +366 -0
  23. ck/example/asia.py +28 -0
  24. ck/example/binary_clique.py +32 -0
  25. ck/example/bow_tie.py +33 -0
  26. ck/example/cancer.py +37 -0
  27. ck/example/chain.py +38 -0
  28. ck/example/child.py +199 -0
  29. ck/example/clique.py +33 -0
  30. ck/example/cnf_pgm.py +39 -0
  31. ck/example/diamond_square.py +68 -0
  32. ck/example/earthquake.py +36 -0
  33. ck/example/empty.py +10 -0
  34. ck/example/hailfinder.py +539 -0
  35. ck/example/hepar2.py +628 -0
  36. ck/example/insurance.py +504 -0
  37. ck/example/loop.py +40 -0
  38. ck/example/mildew.py +38161 -0
  39. ck/example/munin.py +22982 -0
  40. ck/example/pathfinder.py +53674 -0
  41. ck/example/rain.py +39 -0
  42. ck/example/rectangle.py +161 -0
  43. ck/example/run.py +30 -0
  44. ck/example/sachs.py +129 -0
  45. ck/example/sprinkler.py +30 -0
  46. ck/example/star.py +44 -0
  47. ck/example/stress.py +64 -0
  48. ck/example/student.py +43 -0
  49. ck/example/survey.py +46 -0
  50. ck/example/triangle_square.py +54 -0
  51. ck/example/truss.py +49 -0
  52. ck/in_out/__init__.py +3 -0
  53. ck/in_out/parse_ace_lmap.py +216 -0
  54. ck/in_out/parse_ace_nnf.py +288 -0
  55. ck/in_out/parse_net.py +480 -0
  56. ck/in_out/parser_utils.py +185 -0
  57. ck/in_out/pgm_pickle.py +42 -0
  58. ck/in_out/pgm_python.py +268 -0
  59. ck/in_out/render_bugs.py +111 -0
  60. ck/in_out/render_net.py +177 -0
  61. ck/in_out/render_pomegranate.py +184 -0
  62. ck/pgm.py +3494 -0
  63. ck/pgm_circuit/__init__.py +1 -0
  64. ck/pgm_circuit/marginals_program.py +352 -0
  65. ck/pgm_circuit/mpe_program.py +237 -0
  66. ck/pgm_circuit/pgm_circuit.py +75 -0
  67. ck/pgm_circuit/program_with_slotmap.py +234 -0
  68. ck/pgm_circuit/slot_map.py +35 -0
  69. ck/pgm_circuit/support/__init__.py +0 -0
  70. ck/pgm_circuit/support/compile_circuit.py +83 -0
  71. ck/pgm_circuit/target_marginals_program.py +103 -0
  72. ck/pgm_circuit/wmc_program.py +323 -0
  73. ck/pgm_compiler/__init__.py +2 -0
  74. ck/pgm_compiler/ace/__init__.py +1 -0
  75. ck/pgm_compiler/ace/ace.py +252 -0
  76. ck/pgm_compiler/factor_elimination.py +383 -0
  77. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  78. ck/pgm_compiler/pgm_compiler.py +19 -0
  79. ck/pgm_compiler/recursive_conditioning.py +226 -0
  80. ck/pgm_compiler/support/__init__.py +0 -0
  81. ck/pgm_compiler/support/circuit_table/__init__.py +9 -0
  82. ck/pgm_compiler/support/circuit_table/circuit_table.c +16042 -0
  83. ck/pgm_compiler/support/circuit_table/circuit_table.cpython-313-darwin.so +0 -0
  84. ck/pgm_compiler/support/circuit_table/circuit_table_py.py +269 -0
  85. ck/pgm_compiler/support/clusters.py +556 -0
  86. ck/pgm_compiler/support/factor_tables.py +398 -0
  87. ck/pgm_compiler/support/join_tree.py +275 -0
  88. ck/pgm_compiler/support/named_compiler_maker.py +33 -0
  89. ck/pgm_compiler/variable_elimination.py +89 -0
  90. ck/probability/__init__.py +0 -0
  91. ck/probability/empirical_probability_space.py +47 -0
  92. ck/probability/probability_space.py +568 -0
  93. ck/program/__init__.py +3 -0
  94. ck/program/program.py +129 -0
  95. ck/program/program_buffer.py +180 -0
  96. ck/program/raw_program.py +61 -0
  97. ck/sampling/__init__.py +0 -0
  98. ck/sampling/forward_sampler.py +211 -0
  99. ck/sampling/marginals_direct_sampler.py +113 -0
  100. ck/sampling/sampler.py +62 -0
  101. ck/sampling/sampler_support.py +232 -0
  102. ck/sampling/uniform_sampler.py +66 -0
  103. ck/sampling/wmc_direct_sampler.py +169 -0
  104. ck/sampling/wmc_gibbs_sampler.py +147 -0
  105. ck/sampling/wmc_metropolis_sampler.py +159 -0
  106. ck/sampling/wmc_rejection_sampler.py +113 -0
  107. ck/utils/__init__.py +0 -0
  108. ck/utils/iter_extras.py +153 -0
  109. ck/utils/map_list.py +128 -0
  110. ck/utils/map_set.py +128 -0
  111. ck/utils/np_extras.py +51 -0
  112. ck/utils/random_extras.py +64 -0
  113. ck/utils/tmp_dir.py +94 -0
  114. ck_demos/__init__.py +0 -0
  115. ck_demos/ace/__init__.py +0 -0
  116. ck_demos/ace/copy_ace_to_ck.py +15 -0
  117. ck_demos/ace/demo_ace.py +44 -0
  118. ck_demos/all_demos.py +88 -0
  119. ck_demos/circuit/__init__.py +0 -0
  120. ck_demos/circuit/demo_circuit_dump.py +22 -0
  121. ck_demos/circuit/demo_derivatives.py +43 -0
  122. ck_demos/circuit_compiler/__init__.py +0 -0
  123. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  124. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  125. ck_demos/pgm/__init__.py +0 -0
  126. ck_demos/pgm/demo_pgm_dump.py +18 -0
  127. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  128. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  129. ck_demos/pgm/show_examples.py +25 -0
  130. ck_demos/pgm_compiler/__init__.py +0 -0
  131. ck_demos/pgm_compiler/compare_pgm_compilers.py +50 -0
  132. ck_demos/pgm_compiler/demo_compiler_dump.py +50 -0
  133. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  134. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  135. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  136. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  137. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  138. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  139. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  140. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  141. ck_demos/pgm_inference/__init__.py +0 -0
  142. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  143. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  144. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  145. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  146. ck_demos/programs/__init__.py +0 -0
  147. ck_demos/programs/demo_program_buffer.py +24 -0
  148. ck_demos/programs/demo_program_multi.py +24 -0
  149. ck_demos/programs/demo_program_none.py +19 -0
  150. ck_demos/programs/demo_program_single.py +23 -0
  151. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  152. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  153. ck_demos/sampling/__init__.py +0 -0
  154. ck_demos/sampling/check_sampler.py +71 -0
  155. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  156. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  157. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  158. ck_demos/utils/__init__.py +0 -0
  159. ck_demos/utils/compare.py +88 -0
  160. ck_demos/utils/convert_network.py +45 -0
  161. ck_demos/utils/sample_model.py +216 -0
  162. ck_demos/utils/stop_watch.py +384 -0
  163. compiled_knowledge-4.0.0a5.dist-info/METADATA +50 -0
  164. compiled_knowledge-4.0.0a5.dist-info/RECORD +167 -0
  165. compiled_knowledge-4.0.0a5.dist-info/WHEEL +5 -0
  166. compiled_knowledge-4.0.0a5.dist-info/licenses/LICENSE.txt +21 -0
  167. compiled_knowledge-4.0.0a5.dist-info/top_level.txt +2 -0
ck/pgm.py ADDED
@@ -0,0 +1,3494 @@
1
+ """
2
+ This module supports the in-memory creation of probabilistic graphical models.
3
+
4
+ A probabilistic graphical model (PGM) represents a joint probability distribution over
5
+ a set of random variables. Specifically, a PGM is a factor graph with discrete random variables.
6
+
7
+ A random variable is represented by a RandomVariable object. Each random variable has a
8
+ fixed, finite number of states. Many algorithms will assume at least two states.
9
+ Every RandomVariable object belongs to exactly one PGM object. A RandomVariable
10
+ has a name (for human convenience) and its states are indexed by integers, counting
11
+ from zero.
12
+
13
+ A PGM also has factors. Each Factor of a PGM connects a set of RandomVariable objects
14
+ of the PGM. In general, the order of the random variables of a factor is functionally
15
+ irrelevant, but is practically relevant for operating with Factor objects. The "shape"
16
+ of a factor is the list of the numbers of states of the factor's random variables (co-indexed
17
+ with the list of random variables of the factor).
18
+
19
+ If a PGM is representing a Bayesian network, then each factor represents a conditional
20
+ probability table (CPT) and the first random variable of the factor is taken to be the child
21
+ random variable, with the remaining random variables being the parents.
22
+
23
+ Every factor has associated with it a potential function. A potential function maps
24
+ each combination of states of the factor's random variables to a value (of type float).
25
+ A combination of states of random variables is represented as a Key. A Key is essentially
26
+ a list of state indexes, co-indexed with the factor's random variables.
27
+
28
+ A potential function is a map from all possible keys (according to the potential function's
29
+ shape) to a float value. Each potential function has zero or more "parameters" which may be
30
+ adjusted to change the potential function's mapping. The parameters of a potential function
31
+ are indexed sequentially from zero.
32
+
33
+ Each parameter of a potential function is associated with one or more keys. The value of the
34
+ parameter is the value of the potential function for its associated keys. Conversely, each
35
+ key of a potential function is associated with zero or one parameters. That is, it is possible
36
+ that a potential function maps multiple keys to the same parameter, in which case keys that map
37
+ to the same parameter will have the same value.
38
+
39
+ If a key of a potential function is associated with a parameter, then the value of
40
+ the potential function for that key is the value of the parameter.
41
+
42
+ If a key of a potential function is associated with zero parameters then the value of
43
+ the potential function for that key is zero. Furthermore, the key is referred to as
44
+ "guaranteed-zero", meaning that no change in the parameter values of the potential function
45
+ will change the value for that key away from zero.
46
+
47
+ RandomVariable objects are immutable and hashable, including their states.
48
+
49
+ Factor objects cannot change the random variables they are a factor of. However,
50
+ the PotentialFunction associated with a Factor may be updated.
51
+
52
+ Factors may share a potential function, so long as they have the same shape.
53
+
54
+ PotentialFunction objects cannot change their shape, but may be otherwise mutable and
55
+ are generally not hashable. A particular class of potential function may allow its mapping
56
+ to change and even its available parameters to change.
57
+
58
+ There are many kinds of potential function. A DensePotentialFunction has exactly
59
+ one parameter for each possible key (no 'guaranteed-zero' keys) and there are no
60
+ shared parameters. A SparsePotentialFunction only has parameters for explicitly
61
+ mentioned keys.
62
+
63
+ There is a special class of potential function called a ZeroPotentialFunction which
64
+ (like DensePotentialFunction) has a parameter for each possible key (and thus no
65
+ key is guaranteed-zero). However, the value of each parameter is zero and there
66
+ is no mechanism to update these values.
67
+
68
+ A ZeroPotentialFunction is the default PotentialFunction for a Factor. It may be seen
69
+ as a light-weight placeholder until replaced by some other potential function.
70
+ It may also be used as a light-weight surrogate for a DensePotentialFunction when
71
+ performing PGM parameter learning.
72
+
73
+ Each RandomVariable has an index (`idx`) which is a sequence number, starting from zero,
74
+ indicating when that RandomVariable was added to its PGM.
75
+
76
+ Each Factor has an index (`idx`) which is a sequence number, starting from zero,
77
+ indicating when that Factor was added to its PGM.
78
+ """
79
+ from __future__ import annotations
80
+
81
+ import math
82
+ from abc import ABC, abstractmethod
83
+ from dataclasses import dataclass
84
+ from itertools import repeat as _repeat
85
+ from typing import Sequence, Tuple, Dict, Optional, overload, Iterator, Set, Iterable, List, Union, Callable, \
86
+ Collection, Any
87
+
88
+ import numpy as np
89
+
90
+ from ck.utils.iter_extras import (
91
+ combos_ranges as _combos_ranges, multiply as _multiply, combos as _combos
92
+ )
93
+ from ck.utils.np_extras import NDArrayFloat64, NDArrayUInt8
94
+
95
+ # What types are permitted as random variable states
96
+ State = Union[int, str, bool, float, None]
97
+
98
+ # An instance (of a sequence of random variables) is a tuple of integers
99
+ # that are state indexes, co-indexed with a known sequence of random variables.
100
+ Instance = Sequence[int]
101
+
102
+ # A key identifies an instance.
103
+ # A single integer is treated as an instance with one dimension.
104
+ Key = Union[Sequence[int], int]
105
+
106
+ # The shape of a sequence of random variables (e.g., a PGM, Factor or PotentialFunction).
107
+ Shape = Sequence[int]
108
+
109
+ DEFAULT_TOLERANCE: float = 0.000001 # For checking CPT sums.
110
+
111
+
112
+ class PGM:
113
+ """
114
+ A probabilistic graphical model (PGM) represents a joint probability distribution over
115
+ a set of random variables. Specifically, a PGM is a factor graph with discrete random variables.
116
+
117
+ Add a random variable to a PGM, pgm, using `rv = pgm.new_rv(...)`.
118
+
119
+ Add a factor to the PGM, pgm, using `factor = pgm.new_factor(...)`.
120
+
121
+ A PGM may be given a human-readable name.
122
+ """
123
+
124
def __init__(self, name: Optional[str] = None):
    """
    Initialise a PGM that has no random variables, indicators or factors.

    Args:
        name: an optional name for the PGM. When it is None, the name is
            obtained by calling `default_pgm_name(self)`.
    """
    if name is None:
        name = default_pgm_name(self)
    self._name: str = name

    # All collections start empty and are held as tuples.
    self._rvs: Tuple[RandomVariable, ...] = ()
    self._shape: Shape = ()
    self._indicators: Tuple[Indicator, ...] = ()
    self._factors: Tuple[Factor, ...] = ()
137
+
138
@property
def name(self) -> str:
    """
    The name given to this PGM.

    Returns:
        the name stored when the PGM was constructed.
    """
    return self._name
145
+
146
@property
def number_of_rvs(self) -> int:
    """
    Count the random variables of this PGM.

    Returns:
        the number of random variables currently held by this PGM.
    """
    rvs = self._rvs
    return len(rvs)
153
+
154
@property
def shape(self) -> Shape:
    """
    The shape of this PGM.

    Returns:
        the stored shape, described by this class as the sequence of
        the lengths (numbers of states) of `self.rvs`.
    """
    return self._shape
161
+
162
@property
def number_of_indicators(self) -> int:
    """
    Count the indicators of this PGM.

    Returns:
        the number of indicators held by this PGM, documented as being
        `sum(len(rv) for rv in self.rvs)`.
    """
    indicators = self._indicators
    return len(indicators)
169
+
170
@property
def number_of_states(self) -> int:
    """
    The size of the state space of this PGM.

    Returns:
        the result of the module-level `number_of_states` function applied
        to all random variables of this PGM, documented as being
        `multiply(len(rv) for rv in self.rvs)`.
    """
    all_rvs = self._rvs
    return number_of_states(*all_rvs)
177
+
178
@property
def number_of_factors(self) -> int:
    """
    Count the factors of this PGM.

    Returns:
        the number of factors currently held by this PGM.
    """
    factors = self._factors
    return len(factors)
185
+
186
@property
def number_of_functions(self) -> int:
    """
    Count the potential functions in use by this PGM, zero potential
    functions included.

    Returns:
        the number of items produced by `self.functions`.
    """
    count = 0
    for _ in self.functions:
        count += 1
    return count
193
+
194
@property
def number_of_non_zero_functions(self) -> int:
    """
    Count the potential functions in use by this PGM, zero potential
    functions excluded.

    Returns:
        the number of items produced by `self.non_zero_functions`.
    """
    count = 0
    for _ in self.non_zero_functions:
        count += 1
    return count
201
+
202
@property
def rvs(self) -> Sequence[RandomVariable]:
    """
    The random variables of this PGM.

    Returns:
        every random variable, in `idx` order, which is the same as
        creation order.

    Ensures:
        `self.rvs[rv.idx] == rv`
    """
    return self._rvs
212
+
213
@property
def rv_log_sizes(self) -> Sequence[float]:
    """
    The base-2 logarithm of the number of states of each random variable.

    Returns:
        a new list, co-indexed with `self.rvs`, holding `log2(len(rv))`
        for each random variable.
    """
    sizes = map(len, self.rvs)
    return list(map(math.log2, sizes))
220
+
221
@property
def indicators(self) -> Sequence[Indicator]:
    """
    The indicators of all random variables of this PGM.

    Returns:
        every random variable indicator.

    Ensures:
        the indicators of a random variable are adjacent,
        the indicators of a random variable are in state index order,
        the random variables are in the same order as `self.rvs`.
    """
    return self._indicators
233
+
234
@property
def factors(self) -> Sequence[Factor]:
    """
    The factors of this PGM.

    Returns:
        every factor, in `idx` order, which is the same as creation order.

    Ensures:
        `self.factors[factor.idx] == factor`
    """
    return self._factors
244
+
245
@property
def functions(self) -> Iterable[PotentialFunction]:
    """
    Lazily iterate over the distinct potential functions used by the
    factors of this PGM, zero potential functions included.

    A potential function shared by several factors is produced once only,
    at the position of the first factor that uses it. Functions are
    distinguished by object identity.

    Returns:
        An Iterable over all potential functions (including zero potential functions).
    """
    yielded_ids: Set[int] = set()
    for function in (factor.function for factor in self._factors):
        function_id = id(function)
        if function_id in yielded_ids:
            continue
        yielded_ids.add(function_id)
        yield function
260
+
261
@property
def non_zero_functions(self) -> Iterable[PotentialFunction]:
    """
    Lazily iterate over the distinct potential functions used by the
    factors of this PGM, skipping every ZeroPotentialFunction.

    A potential function shared by several factors is produced once only,
    at the position of the first factor that uses it. Functions are
    distinguished by object identity.

    Returns:
        An Iterable over all potential functions (excluding zero potential functions).
    """
    yielded_ids: Set[int] = set()
    for function in (factor.function for factor in self._factors):
        if isinstance(function, ZeroPotentialFunction):
            continue
        function_id = id(function)
        if function_id in yielded_ids:
            continue
        yielded_ids.add(function_id)
        yield function
276
+
277
def new_rv(self, name: str, states: Union[int, Sequence[State]]) -> RandomVariable:
    """
    Create a random variable that belongs to this PGM.

    The returned random variable will have an `idx` equal to the value of
    `self.number_of_rvs` just prior to adding the new random variable.

    Assumes:
        Provided states contain no duplicates.

    Args:
        name: a name for the random variable.
        states: either an integer number of states or a sequence of state values.
            An integer `n` means the states are 0, 1, ..., n-1. A sequence of
            states must contain unique values.

    Returns:
        a RandomVariable object belonging to this PGM.
    """
    # Construction is delegated to RandomVariable, which is handed this PGM.
    rv = RandomVariable(self, name, states)
    return rv
297
+
298
def new_factor(self, *rvs: RandomVariable) -> Factor:
    """
    Create a factor of this PGM that connects the given random variables.

    The returned factor will have a ZeroPotentialFunction as its potential function.
    The potential function may be changed by calling methods on the returned factor.

    The returned factor will have an `idx` equal to the value of
    `self.number_of_factors` just prior to adding the new factor.

    Assumes:
        The given random variables all belong to this PGM.
        The random variables contain no duplicates.

    Args:
        *rvs: the random variables.

    Returns:
        a Factor object belonging to this PGM.
    """
    # Construction is delegated to Factor, which is handed this PGM.
    factor = Factor(self, *rvs)
    return factor
320
+
321
def new_factor_implies(
        self,
        rv_1: RandomVariable,
        state_idxs_1: int | Collection[int],
        rv_2: RandomVariable,
        state_idxs_2: int | Collection[int],
) -> Factor:
    """
    Add a sparse 0/1 factor to this PGM representing:
        rv_1 in state_idxs_1 ==> rv_2 in state_idxs_2.
    That is:
        factor[s1, s2] = 1, if s1 not in state_idxs_1 or s2 in state_idxs_2;
                       = 0, otherwise.

    Args:
        rv_1: The first random variable.
        state_idxs_1: state idxs of the first random variable (a single
            int is treated as a one-element collection).
        rv_2: The second random variable.
        state_idxs_2: state idxs of the second random variable (a single
            int is treated as a one-element collection).

    Returns:
        a Factor object belonging to this PGM, with a configured sparse potential function.
    """
    antecedent = (state_idxs_1,) if isinstance(state_idxs_1, int) else state_idxs_1
    consequent = (state_idxs_2,) if isinstance(state_idxs_2, int) else state_idxs_2

    factor = self.new_factor(rv_1, rv_2)
    table = factor.set_sparse()
    for s1 in rv_1.state_range():
        antecedent_holds = s1 in antecedent
        for s2 in rv_2.state_range():
            # The implication is violated only when the antecedent holds
            # and the consequent does not; only non-violating keys are set.
            if not antecedent_holds or s2 in consequent:
                table[s1, s2] = 1
    return factor
360
+
361
def new_factor_equiv(
        self,
        rv_1: RandomVariable,
        state_idxs_1: int | Collection[int],
        rv_2: RandomVariable,
        state_idxs_2: int | Collection[int],
) -> Factor:
    """
    Add a sparse 0/1 factor to this PGM representing:
        rv_1 in state_idxs_1 <==> rv_2 in state_idxs_2.
    That is:
        factor[s1, s2] = 1, if (s1 in state_idxs_1) == (s2 in state_idxs_2);
                       = 0, otherwise.

    Args:
        rv_1: The first random variable.
        state_idxs_1: state idxs of the first random variable (a single
            int is treated as a one-element collection).
        rv_2: The second random variable.
        state_idxs_2: state idxs of the second random variable (a single
            int is treated as a one-element collection).

    Returns:
        a Factor object belonging to this PGM, with a configured sparse potential function.
    """
    left = (state_idxs_1,) if isinstance(state_idxs_1, int) else state_idxs_1
    right = (state_idxs_2,) if isinstance(state_idxs_2, int) else state_idxs_2

    factor = self.new_factor(rv_1, rv_2)
    table = factor.set_sparse()
    for s1 in rv_1.state_range():
        left_holds = s1 in left
        for s2 in rv_2.state_range():
            # Skip keys where exactly one side of the equivalence holds.
            if left_holds != (s2 in right):
                continue
            table[s1, s2] = 1
    return factor
398
+
399
def new_factor_functional(
        self,
        function: Callable[..., int],
        result_rv: RandomVariable,
        *input_rvs: RandomVariable
) -> Factor:
    """
    Add a sparse 0/1 factor to this PGM representing:
        result_rv == function(*input_rvs).
    That is:
        factor[result_s, *input_s] = 1, if result_s == function(*input_s);
                                   = 0, otherwise.

    The random variables of the factor are `result_rv` followed by
    `input_rvs`, in the order given.

    Args:
        function: a function from state indexes of the input random variables to a state index
            of the result random variable. The function should take the same number of arguments
            as `input_rvs` and return a state index for `result_rv`.
        result_rv: the random variable defining result values.
        *input_rvs: the random variables defining input values.

    Returns:
        a Factor object belonging to this PGM, with a configured sparse potential function.
    """
    factor = self.new_factor(result_rv, *input_rvs)
    f = factor.set_sparse()
    for input_s in _combos([list(rv.state_range()) for rv in input_rvs]):
        result_s = function(*input_s)
        # Unpacking builds the key whether `input_s` is a tuple or any
        # other sequence of state indexes.
        f[(result_s, *input_s)] = 1
    return factor
427
+
428
def indicator_pair(self, indicator: Indicator) -> Tuple[RandomVariable, State]:
    """
    Resolve an indicator to its random variable and state value.

    Args:
        indicator: the indicator to convert. Its `rv_idx` indexes the
            random variables of this PGM and its `state_idx` indexes the
            states of that random variable.

    Returns:
        (rv, state) where
        rv: is the random variable of the indicator.
        state: is the random variable state of the indicator.
    """
    rv_idx, state_idx = indicator.rv_idx, indicator.state_idx
    rv = self._rvs[rv_idx]
    return rv, rv.states[state_idx]
443
+
444
def indicator_str(self, *indicators: Indicator, sep: str = '=', delim: str = ', ') -> str:
    """
    Render indicators as a string, in the order given.

    For example:
        pgm = PGM()
        a = pgm.new_rv('A', ('x', 'y', 'z'))
        b = pgm.new_rv('B', (3, 5))
        print(pgm.indicator_str(a[0], b[1], a[2]))
    will print:
        A=x, B=5, A=z

    Args:
        *indicators: the indicators to render.
        sep: the separator to use between the random variable and its state.
        delim: the delimiter to use when rendering multiple indicators.

    Returns:
        a string representation of the given indicators.
    """
    rendered: List[str] = []
    for indicator in indicators:
        rv, state = self.indicator_pair(indicator)
        rendered.append(f'{rv}{sep}{state}')
    return delim.join(rendered)
471
+
472
def condition_str(self, *indicators: Indicator) -> str:
    """
    Render indicators as a string, grouping indicators by random variable.

    The indicators are visited in sorted order; each run of indicators
    sharing an `rv_idx` is collected into a set and appended to the result
    by `self._condition_str_rv`.

    For example:
        pgm = PGM()
        a = pgm.new_rv('A', ('x', 'y', 'z'))
        b = pgm.new_rv('B', (3, 5))
        print(pgm.condition_str(a[0], b[1], a[2]))
    will print:
        A in {x, z}, B=5

    Args:
        *indicators: the indicators to render.

    Returns:
        a string representation of the given indicators, as a condition.
    """
    # Descending sort followed by a reversal gives the visiting order.
    ordered: List[Indicator] = sorted(indicators, reverse=True)
    ordered.reverse()

    result: str = ''  # accumulated result string
    group: Set[Indicator] = set()  # indicators of the rv currently being collected
    group_rv_idx: int = -1  # -1 means no group has been started yet
    for indicator in ordered:
        if indicator.rv_idx != group_rv_idx:
            if group_rv_idx >= 0:
                # A new random variable begins: flush the finished group.
                result = self._condition_str_rv(result, group)
                group = set()
            group_rv_idx = indicator.rv_idx
        group.add(indicator)
    if group_rv_idx >= 0:
        # Flush the final group.
        result = self._condition_str_rv(result, group)
    return result
504
+
505
def instance_str(
        self,
        instance: Instance,
        rvs: Optional[Sequence[RandomVariable]] = None,
        sep: str = '=',
        delim: str = ', ',
) -> str:
    """
    Render an instance as a string.

    The result looks something like 'X=x, Y=y, Z=z' where X, Y, and Z are
    random variables and x, y, and z are the states represented by the
    given instance.

    Args:
        instance: the instance to render; it must have one state index per random variable.
        rvs: the random variables that the instance refers to. If rvs is None, then `self.rvs` is used.
        sep: the separator to use between the random variable and its state.
        delim: the delimiter to use when rendering multiple indicators.

    Returns:
        a string representation of the indicators implied by the given instance.
    """
    target_rvs = self.rvs if rvs is None else rvs
    assert len(instance) == len(target_rvs)
    # Indexing a random variable by a state index selects the indicator to render.
    selected = [rv[state_idx] for rv, state_idx in zip(target_rvs, instance)]
    return self.indicator_str(*selected, sep=sep, delim=delim)
536
+
537
def state_str(
        self,
        instance: Instance,
        rvs: Optional[Sequence[RandomVariable]] = None,
        delim: str = ', ',
) -> str:
    """
    Render the states of an instance.

    The result looks something like 'x, y, z' where x, y, and z are
    the states of the random variables represented by the given instance.

    Args:
        instance: the instance to render; it must have one state index per random variable.
        rvs: the random variables that the instance refers to. If rvs is None, then `self.rvs` is used.
        delim: the delimiter to use when rendering multiple states.

    Returns:
        a string representation of the states implied by the given instance.
    """
    target_rvs = self.rvs if rvs is None else rvs
    assert len(instance) == len(target_rvs)
    rendered: List[str] = [
        str(rv.states[state_idx])
        for rv, state_idx in zip(target_rvs, instance)
    ]
    return delim.join(rendered)
561
+
562
def instances(self, flip: bool = False) -> Iterator[Instance]:
    """
    Iterate over all possible instances of this PGM, in natural index
    order (i.e., last random variable changing most quickly).

    Args:
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds random variable state indexes
        co-indexed with this PGM's random variables, `self.rvs`.
    """
    sizes = tuple(map(len, self._rvs))
    # The helper is handed the negation of this method's `flip` argument.
    helper_flip = not flip
    return _combos_ranges(sizes, flip=helper_flip)
575
+
576
def instances_as_indicators(self, flip: bool = False) -> Iterator[Sequence[Indicator]]:
    """
    Iterate over all possible instances of this PGM, in natural index
    order (i.e., last random variable changing most quickly), with each
    instance converted to indicators.

    Args:
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds random variable indicators
        co-indexed with this PGM's random variables, `self.rvs`.
    """
    yield from map(self.state_idxs_to_indicators, self.instances(flip=flip))
590
+
591
def state_idxs_to_indicators(self, instance: Sequence[int]) -> Sequence[Indicator]:
    """
    Convert an instance (state indexes co-indexed with the PGM's random
    variables, `self.rvs`) to the corresponding indicators.

    Assumes:
        The instance has the same length as `self.rvs`.
        The instance is co-indexed with `self.rvs`.

    Args:
        instance: the instance to convert to indicators.

    Returns:
        a tuple of indicators, co-indexed with `self.rvs`.
    """
    picked: List[Indicator] = []
    for rv, state_idx in zip(self._rvs, instance):
        # Indexing a random variable by a state index selects its indicator.
        picked.append(rv[state_idx])
    return tuple(picked)
607
+
608
def factor_values(self, key: Key) -> Iterator[float]:
    """
    For a given instance key, each factor defines a single value. This method
    lazily produces those values.

    Args:
        key: the key defining an instance of this PGM. It is validated
            against the shape of this PGM by `check_key`.

    Returns:
        an iterator over factor values, co-indexed with the factors of this PGM.
    """
    instance: Instance = check_key(self._shape, key)
    assert len(instance) == len(self._rvs)
    for factor in self._factors:
        # Project the PGM instance onto the random variables of this factor.
        factor_key: Sequence[int] = tuple(instance[rv.idx] for rv in factor.rvs)
        yield factor.function[factor_key]
625
+
626
@property
def is_structure_bayesian(self) -> bool:
    """
    Does the PGM structure correspond to a Bayesian network, where
    each factor is taken to be a CPT and the first random variable of a factor
    is taken to be the child.

    This method does not check the factor parameters to confirm they correspond
    to valid CPTs.

    Returns:
        True only if:
            the number of factors equals the number of random variables,
            each random variable appears exactly once as the first random variable of a factor,
            there are no directed loops created by the factors.
    """
    rv_count = self.number_of_rvs

    # There must be exactly one factor per random variable.
    if self.number_of_factors != rv_count:
        return False

    # Map the idx of each child random variable to the factor it is the
    # child of. If two factors share a child, the map ends up with fewer
    # entries than there are random variables.
    child_to_factor: Dict[int, Factor] = {}
    for factor in self._factors:
        child_to_factor[factor.rvs[0].idx] = factor
    if len(child_to_factor) != rv_count:
        return False

    # The factors must form a DAG. The one `states` array is shared by
    # every call to `self._has_cycle`.
    states: NDArrayUInt8 = np.zeros(self.number_of_factors, dtype=np.uint8)
    return not any(
        self._has_cycle(factor, child_to_factor, states)
        for factor in self._factors
    )
664
+
665
def factors_are_cpts(self, tolerance: float = DEFAULT_TOLERANCE) -> bool:
    """
    Check that every potential function of this PGM is parameterised
    as a Conditional Probability Table.

    Assumes:
        tolerance is non-negative.

    Args:
        tolerance: a tolerance when testing if values are equal to zero or one.

    Returns:
        True only if every potential function conforms to being a valid CPT.
    """
    for potential_function in self.functions:
        if not potential_function.is_cpt(tolerance):
            # One failure is enough; stop at the first non-CPT function.
            return False
    return True
680
+
681
def check_is_bayesian_network(self, tolerance: float = DEFAULT_TOLERANCE) -> bool:
    """
    Report whether this PGM is a Bayesian network, in both structure
    and parameters.

    Assumes:
        tolerance is non-negative.

    Args:
        tolerance: a tolerance when testing if values are equal to zero or one.

    Returns:
        `is_structure_bayesian and factors_are_cpts(tolerance)`.
    """
    # The structural test comes first; parameters are only inspected when it passes.
    if not self.is_structure_bayesian:
        return False
    return self.factors_are_cpts(tolerance)
695
+
696
def value_product(self, key: Key) -> float:
    """
    Multiply together the values that the factors assign to one instance.

    Each factor defines a single value for the instance identified by `key`
    (see `factor_values`); this method returns the product of those values.

    Args:
        key: the key defining an instance of this PGM.

    Returns:
        the product of factor values.
    """
    per_factor_values: Iterator[float] = self.factor_values(key)
    return _multiply(per_factor_values)
708
+
709
def value_product_indicators(self, *indicators: Indicator) -> float:
    """
    Return the product of factors, conditioned on the given indicators.

    Random variables that are not mentioned by any indicator are marginalised:
    the result is the sum of `value_product` over every combination of states
    of the unmentioned random variables.

    With no indicators at all, every random variable is marginalised and the
    result is the value of the partition function (z).

    Warning:
        this is potentially computationally expensive, as the number of
        combinations grows with every unmentioned random variable.

    Args:
        *indicators: are indicators from random variables of this PGM.

    Returns:
        the product of factors, conditioned on the given instance. This is the
        computed value of the PGM, conditioned on the given instance.

    Raises:
        RuntimeError: if a random variable is referenced multiple times in the given indicators.
    """
    # Start with every random variable unassigned (-1), then record the given states.
    assigned: List[int] = [-1] * self.number_of_rvs
    for given in indicators:
        target: int = given.rv_idx
        if assigned[target] >= 0:
            raise RuntimeError(f'random variable mentioned multiple times: {self.rvs[target]}')
        assigned[target] = given.state_idx

    # Whatever is still unassigned must be summed out.
    free_rvs = [rv for rv in self.rvs if assigned[rv.idx] < 0]

    # Enumerate every completion of the partial instance, reusing one list.
    total = 0
    for completion in rv_instances_as_indicators(*free_rvs):
        for filler in completion:
            assigned[filler.rv_idx] = filler.state_idx
        total += self.value_product(assigned)

    return total
753
+
754
def dump_synopsis(
        self,
        *,
        prefix: str = '',
        precision: int = 3,
        max_state_digits: int = 21,
) -> None:
    """
    Print a synopsis of the PGM.
    This is intended for demonstration and debugging purposes.

    The number of states is printed as a plain integer when it has at most
    `max_state_digits` digits, otherwise in mantissa/exponent form. A PGM
    with no states at all (a random variable may be created with zero
    states) is reported as `0` states with a log of `-inf`, rather than
    raising a math domain error.

    Args:
        prefix: optional prefix for indenting all lines.
        precision: a limit on the render precision of floating point numbers.
        max_state_digits: a limit on the number of digits when showing number of states as an integer.
    """
    num_states: int = self.number_of_states
    number_of_parameters = sum(function.number_of_parameters for function in self.functions)
    number_of_nz_parameters = sum(function.number_of_parameters for function in self.non_zero_functions)

    if num_states <= 0:
        # Logarithms are undefined here: math.log10 and math.log2 would raise ValueError.
        num_states_str = f'{num_states:,}'
        log_2_num_states_str = '-inf' if num_states == 0 else 'nan'
    else:
        # Limit the precision when displaying a very large number of states.
        log_10_num_states = math.log10(num_states)
        if log_10_num_states > max_state_digits:
            exp = int(log_10_num_states)
            man = math.pow(10, log_10_num_states - exp)
            num_states_str = f'{man:,.{precision}f}e+{exp}'
        else:
            num_states_str = f'{num_states:,}'

        # Show the log as an integer when it is one (and is short enough).
        log_2_num_states = math.log2(num_states)
        if (
                log_2_num_states == 0
                or (
                    log_2_num_states == int(log_2_num_states)
                    and math.log10(log_2_num_states) <= max_state_digits
                )
        ):
            log_2_num_states_str = f'{int(log_2_num_states):,}'
        else:
            log_2_num_states_str = f'{log_2_num_states:,.{precision}f}'

    print(f'{prefix}name: {self.name}')
    print(f'{prefix}number of random variables: {self.number_of_rvs:,}')
    print(f'{prefix}number of indicators: {self.number_of_indicators:,}')
    print(f'{prefix}number of states: {num_states_str}')
    print(f'{prefix}log 2 of states: {log_2_num_states_str}')
    print(f'{prefix}number of factors: {self.number_of_factors:,}')
    print(f'{prefix}number of functions: {self.number_of_functions:,}')
    print(f'{prefix}number of non-zero functions: {self.number_of_non_zero_functions:,}')
    print(f'{prefix}number of parameters: {number_of_parameters:,}')
    print(f'{prefix}number of functions (excluding ZeroPotentialFunction): {self.number_of_non_zero_functions:,}')
    print(f'{prefix}number of parameters (excluding ZeroPotentialFunction): {number_of_nz_parameters:,}')
    print(f'{prefix}Bayesian structure: {self.is_structure_bayesian}')
    print(f'{prefix}CPT factors: {self.factors_are_cpts()}')
808
+
809
def dump(
        self,
        *,
        prefix: str = '',
        indent: str = ' ',
        show_function_values: bool = False,
        precision: int = 3,
        max_state_digits: int = 21,
) -> None:
    """
    Print a dump of the PGM.
    This is intended for demonstration and debugging purposes.

    The dump has a header line, the synopsis (see `dump_synopsis`), then one
    section each for the random variables, the factors and the non-zero
    potential functions, and finally a closing line.

    Args:
        prefix: optional prefix for indenting all lines.
        show_function_values: if true, then the function values will be dumped.
        indent: additional prefix to use for extra indentation.
        precision: a limit on the render precision of floating point numbers.
        max_state_digits: a limit on the number of digits when showing number of states as an integer.
    """
    level_1: str = prefix + indent
    level_2: str = level_1 + indent
    own_id = id(self)

    print(f'{prefix}PGM id={own_id} name={self.name!r}')
    self.dump_synopsis(prefix=level_1, precision=precision, max_state_digits=max_state_digits)

    print(f'{prefix}random variables ({self.number_of_rvs})')
    for rv in self.rvs:
        line = f'{level_1}{rv.idx:>3} {rv.name!r} ({len(rv)})'
        # States are only listed when they differ from the default 0, 1, ...
        if not rv.is_default_states():
            line += ' [' + ', '.join(repr(s) for s in rv.states) + ']'
        print(line)

    print(f'{prefix}factors ({self.number_of_factors})')
    for factor in self.factors:
        rv_idxs = [rv.idx for rv in factor.rvs]
        if factor.is_zero:
            function_ref = '<zero>'
        else:
            used = factor.function
            function_ref = f'{id(used)}: {used.__class__.__name__}'
        print(f'{level_1}{factor.idx:>3} rvs={rv_idxs} function={function_ref}')

    print(f'{prefix}functions ({self.number_of_functions})')
    # Functions are identified by, and listed in order of, their object id.
    for function in sorted(self.non_zero_functions, key=id):
        print(f'{level_1}{id(function):>13}: {function.__class__.__name__}')
        function.dump(prefix=level_2, show_function_values=show_function_values, show_id_class=False)

    print(f'{prefix}end PGM id={own_id}')
862
+
863
def _has_cycle(self, factor: Factor, child_to_factor: Dict[int, Factor], states: NDArrayUInt8) -> bool:
    """
    Support function for `is_structure_bayesian`.

    Recursive depth-first search, following each factor's parent random
    variables (all but the first) to the factor that has that random
    variable as its child.

    `states[f.idx]` holds the search state of factor `f`:
        state 0 => the factor has not been seen yet,
        state 1 => the factor is seen but not fully processed,
        state 2 => the factor is fully processed.

    Args:
        factor: the current Factor being checked.
        child_to_factor: a dictionary from `RandomVariable.idx` to Factor
            with that random variable as the child.
        states: depth-first-search states, i.e., `states[i]` is the state of a factor with `Factor.idx == i`.
            Updated in place.
    Returns:
        True if a directed cycle is detected.
    """
    position: int = factor.idx
    search_state = states.item(position)

    if search_state == 1:
        # Reached a factor that is still on the current search path.
        return True
    if search_state != 0:
        # Already fully processed (or any other state): nothing new to find.
        return False

    states[position] = 1
    for parent_rv in factor.rvs[1:]:
        if self._has_cycle(child_to_factor[parent_rv.idx], child_to_factor, states):
            # Leave the states as they are; the caller stops searching.
            return True
    states[position] = 2
    return False
896
+
897
def _register_rv(self, rv: RandomVariable) -> None:
    """
    Record a newly created random variable of this PGM.
    Called by the constructor of RandomVariable.

    Appends the random variable, its number of states and its indicators
    to the PGM's `_rvs`, `_shape` and `_indicators` respectively, keeping
    the three co-indexed.

    Args:
        rv: the newly constructed random variable.
    """
    assert rv.pgm is self
    number_of_states: int = len(rv)
    new_indicators = rv.indicators
    self._indicators += new_indicators
    self._shape += (number_of_states,)
    self._rvs += (rv,)
909
+
910
def _condition_str_rv(
        self,
        cur_str: str,
        cur_rv: Set[Indicator],
        sep: str = ', ',
        equal: str = '=',
        elem: str = ' in ',
) -> str:
    """
    Support method for `self.condition_str`.

    Render a condition given as a set of indicators that all belong to one
    random variable, and append it to `cur_str`. A single indicator is
    rendered by `self.indicator_str` as _rv_ = _state_; several indicators
    are rendered, in sorted order, as _rv_ in {_state_, ...}.

    Args:
        cur_str: the string to append to.
        cur_rv: a set of indicators, all from the same random variable.
        sep: the separator string to use between condition components.
        equal: the string to use for _rv_ = _state_.
        elem: the string to use for _rv_ in _set_.

    Returns:
        `cur_str` appended with the new condition, `cur_rv`.
    """
    # Only separate from earlier conditions when there are some.
    lead = cur_str + sep if cur_str != '' else cur_str

    if len(cur_rv) == 1:
        (only,) = cur_rv
        return lead + self.indicator_str(only, sep=equal)

    ordered = sorted(cur_rv)
    rv = self._rvs[ordered[0].rv_idx]
    states_text = sep.join([str(rv.states[ind.state_idx]) for ind in ordered])
    return f'{lead}{rv}{elem}{{{states_text}}}'
943
+
944
+
945
+ @dataclass(frozen=True, eq=True, slots=True)
946
+ class Indicator:
947
+ """
948
+ An indicator identifies a random variable being in a particular state.
949
+
950
+ Indicators are immutable and hashable.
951
+
952
+ Note that an Indicator does not know which PGM it came from, therefore indicators from one PGM
953
+ are interchangeable with indicators of another PGM so long as corresponding random variables of the
954
+ PGMs are co-indexed (created in the same order) and corresponding random variables have the same
955
+ states.
956
+
957
+ Fields:
958
+ rv_idx: `rv.idx` where `rv` is the random variable referenced by this indicator.
959
+ state_idx: the state index of the state referenced by this indicator.
960
+ """
961
+ rv_idx: int
962
+ state_idx: int
963
+
964
+ def __lt__(self, other) -> bool:
965
+ """
966
+ Define a sort order over indicators.
967
+ When sorted, indicators are ordered by random variable index, then by state index.
968
+ """
969
+ if isinstance(other, Indicator):
970
+ if self.rv_idx < other.rv_idx:
971
+ return True
972
+ if self.rv_idx > other.rv_idx:
973
+ return False
974
+ return self.state_idx < other.state_idx
975
+ return False
976
+
977
+
978
class RandomVariable(Sequence[Indicator]):
    """
    A random variable in a probabilistic graphical model.

    Random variables are immutable and hashable.

    Each RandomVariable has a fixed finite number of states.
    Its states are indexed by integers, counting from zero.

    Every RandomVariable object belongs to exactly one PGM object, and has
    an index (counting from zero) which is its position in that PGM's list
    of random variables.

    A random variable behaves like a sequence of Indicators, where each
    indicator represents the random variable being in a particular state.
    For a random variable rv, len(rv) is the number of states and rv[i] is
    the Indicator representing that rv is in the ith state. Slicing gives
    a tuple, i.e. rv[1:3] = (rv[1], rv[2]).

    A RandomVariable has a name. This is for human convenience and has no
    functional purpose within a PGM.
    """

    def __init__(self, pgm: PGM, name: str, states: Union[int, Sequence[State]]):
        """
        Create a new random variable, in the given PGM.

        Assumes:
            Provided states contain no duplicates.

        Args:
            pgm: the PGM that the random variable will belong to.
            name: a name for the random variable.
            states: either an integer number of states or a sequence of state values. If a
                single integer, `n`, is provided then the states will be 0, 1, ..., n-1.
                If a sequence of states are provided then the states must be unique.

        Raises:
            ValueError: if the provided states are not unique.
        """
        self._pgm: PGM = pgm
        self._name: str = name

        # An integer `n` is shorthand for the default states 0, 1, ..., n-1.
        state_values = range(states) if isinstance(states, int) else states
        self._states: Sequence[State] = tuple(state_values)

        # Reverse lookup from state to state index; duplicates collapse to one key.
        inverse: Dict[State, int] = {}
        for position, state in enumerate(self._states):
            inverse[state] = position
        self._inv_states: Dict[State, int] = inverse

        if len(inverse) != len(self._states):
            raise ValueError('random variable states are not unique')

        # Positions are taken from the PGM before this random variable is registered.
        self._offset: int = pgm.number_of_indicators
        self._idx: int = pgm.number_of_rvs
        self._indicators: Sequence[Indicator] = tuple(
            Indicator(self._idx, state_idx)
            for state_idx in range(len(self._states))
        )

        # Register self with our PGM
        # noinspection PyProtectedMember
        pgm._register_rv(self)

    @property
    def pgm(self) -> PGM:
        """
        Returns:
            The PGM that this random variable belongs to.
        """
        return self._pgm

    @property
    def name(self) -> str:
        """
        Returns:
            The name of this random variable.
        """
        return self._name

    @property
    def idx(self) -> int:
        """
        Returns:
            The index of this random variable into the PGM.

        Ensures:
            `self.pgm.rvs[self.idx] is self`.
        """
        return self._idx

    @property
    def offset(self) -> int:
        """
        Returns:
            The index into the PGM's indicators for the start of this random variable's indicators.

        Ensures:
            `self.pgm.indicators[self.offset + i] is self[i] for i in range(len(self))`.
        """
        return self._offset

    @property
    def states(self) -> Sequence[State]:
        """
        Returns:
            the states of this random variable, in state index order.
        """
        return self._states

    @property
    def indicators(self) -> Sequence[Indicator]:
        """
        Returns:
            the indicators of this random variable, in state index order.
        """
        return self._indicators

    def state_range(self) -> Iterable[int]:
        """
        Iterate over the state indexes of this random variable, in order.

        Returns:
            range(len(self))
        """
        number_of_states = len(self._states)
        return range(number_of_states)

    def factors(self) -> Iterator[Factor]:
        """
        Iterate over factors that this random variable participates in.
        This method performs a search through all `self.pgm.factors`.

        Returns:
            an iterator over factors.
        """
        yield from (factor for factor in self._pgm.factors if self in factor.rvs)

    def markov_blanket(self) -> Set[RandomVariable]:
        """
        Return the set of random variable that are connected
        to this random variable by a factor.
        This method performs a search through all `self.pgm.factors`.

        Returns:
            a set of random variables connected to this random variable by any factor, excluding self.
        """
        neighbours = set().union(*(factor.rvs for factor in self.factors()))
        neighbours.discard(self)
        return neighbours

    def state_idx(self, state: State) -> int:
        """
        Returns:
            the state index of the given state of this random variable.

        Assumes:
            the given state is a state of this random variable.
        """
        return self._inv_states[state]

    def is_default_states(self) -> bool:
        """
        Are the states of this random variable the default states.
        I.e., `self.states[i] == i, for all 0 <= i < len(self)`.

        Returns:
            True only if the states are the same as the state indexes.
        """
        for position, state in enumerate(self._states):
            if not position == state:
                return False
        return True

    def __str__(self) -> str:
        """
        Returns:
            the name of this random variable.
        """
        return self._name

    def __call__(self, state: State) -> Indicator:
        """
        Get the indicator for the given state.
        This is equivalent to self[self.state_idx(state)].

        Returns:
            an indicator of this random variable.

        Assumes:
            the given state is a state of this random variable.
        """
        position = self._inv_states[state]
        return self._indicators[position]

    def __hash__(self) -> int:
        """
        A random variable is hashable; the hash is its index in the PGM.
        """
        return self._idx

    def __eq__(self, other) -> bool:
        """
        Two random variable are equal if they are the same object.
        """
        return self is other

    def equivalent(self, other: RandomVariable | Sequence[Indicator]) -> bool:
        """
        Two random variable are equivalent if their indicators are equal. Only
        random variable indexes and state indexes are checked.

        This ignores the names of the random variable and the names of their states.
        This means their indicators will work correctly in slot maps, even
        if from different PGMs.

        Args:
            other: either a random variable or a sequence of Indicators.

        Returns:
            True only if they represent the same sequence of indicators.
        """
        if isinstance(other, RandomVariable):
            # Indicators are fully determined by the index and the number of states.
            return self.idx == other.idx and len(self) == len(other)

        own = self._indicators
        if len(own) != len(other):
            return False
        return all(mine == theirs for mine, theirs in zip(own, other))

    def __len__(self) -> int:
        """
        Returns:
            Number of states (or equivalently, the number of indicators) of this random variable.
        """
        return len(self._states)

    def __iter__(self) -> Iterator[Indicator]:
        """
        Iterate over the indicators of this random variable.
        """
        return iter(self._indicators)

    @overload
    def __getitem__(self, index: int) -> Indicator:
        ...

    @overload
    def __getitem__(self, index: slice) -> Sequence[Indicator]:
        ...

    def __getitem__(self, index):
        """
        Get the indexed (or sliced) indicators.
        """
        return self._indicators[index]

    def index(self, value: Any, start: int = 0, stop: int = -1) -> int:
        """
        Returns the first index of `value`.
        Raises ValueError if the value is not present.
        Contracted by Sequence[Indicator].

        A negative `stop` is converted to `len(self) + stop + 1`, so the
        default of -1 means "up to the end".

        Warning:
            This method is different to `self.idx`.
        """
        if isinstance(value, Indicator) and value.rv_idx == self._idx:
            position: int = value.state_idx
            upper = stop if stop >= 0 else len(self) + stop + 1
            if 0 <= position < len(self) and start <= position < upper:
                return position
        raise ValueError(f'{value!r} is not an indicator of the random variable')

    def count(self, value: Any) -> int:
        """
        Returns the number of occurrences of `value`.
        Contracted by Sequence[Indicator].

        The result is either 0 or 1, as every indicator appears once.
        """
        is_own_indicator = (
            isinstance(value, Indicator)
            and value.rv_idx == self._idx
            and 0 <= value.state_idx < len(self)
        )
        return 1 if is_own_indicator else 0
1255
+
1256
+
1257
class _UnknownRVError(AttributeError, KeyError):
    """
    Raised by `RVMap.__getattr__` when a name is not a random variable.

    It is an AttributeError, as the attribute protocol requires, so that
    `hasattr` and `getattr(obj, name, default)` work on an RVMap.
    It is also a KeyError so that code catching the KeyError that
    attribute access used to raise keeps working.
    """


class RVMap(Sequence[RandomVariable]):
    """
    Wrap a PGM to provide convenient access to PGM random variables.

    An RVMap of a PGM behaves exactly like the PGM `rvs` property. That is, it
    behaves like a sequence of RandomVariable objects.

    If the underlying PGM is updated, then the RVMap will automatically update.

    Additionally, an RVMap enables access to the PGM random variable via the name
    of each random variable.

    for example, if `pgm.rvs[1]` is a random variable named `xray`, then
    ```
    rvs = RVMap(pgm)

    # These all retrieve the same random variable object.
    xray = rvs[1]
    xray = rvs('xray')
    xray = rvs.xray
    ```

    To use an RVMap on a PGM, the variable names must be unique across the PGM.
    """

    def __init__(self, pgm: PGM, ignore_case: bool = False):
        """
        Construct an RVMap for the given PGM.

        Args:
            pgm: the PGM to wrap.
            ignore_case: if true, the variable name are not case-sensitive.

        Raises:
            RuntimeError: if the random variable names of the PGM are not unique,
                or clash with the attribute names of an RVMap.
        """
        self._pgm: PGM = pgm
        self._ignore_case: bool = ignore_case
        self.__rv_map: Dict[str, RandomVariable] = {}
        # Must be assigned last: `__getattr__` uses its presence to tell
        # that the state of this object is fully set up.
        self._reserved_names: Set[str] = {self._clean_name(name) for name in dir(self)}

        # Force the rv map cache to be updated.
        # This may raise an exception.
        _ = self._rv_map

    def _clean_name(self, name: str) -> str:
        """
        Adjust the case of the given name as needed.
        """
        return name.lower() if self._ignore_case else name

    @property
    def _rv_map(self) -> Dict[str, RandomVariable]:
        """
        Get the cached rv map, updating as needed if the PGM changed.

        A change is detected by comparing the size of the cache with the
        number of random variables of the PGM.

        Returns:
            a mapping from random variable name to random variable

        Raises:
            RuntimeError: if random variable names are not unique or clash
                with reserved names.
        """
        if len(self.__rv_map) != len(self._pgm.rvs):
            # There is a difference between the map and the PGM - create a new map.
            self.__rv_map = {self._clean_name(rv.name): rv for rv in self._pgm.rvs}
            if len(self.__rv_map) != len(self._pgm.rvs):
                raise RuntimeError('random variable names are not unique')
            if not self._reserved_names.isdisjoint(self.__rv_map.keys()):
                raise RuntimeError('random variable names clash with reserved names.')
        return self.__rv_map

    def new_rv(self, name: str, states: Union[int, Sequence[State]]) -> RandomVariable:
        """
        As per `PGM.new_rv`.
        Delegate creating a new random variable to the PGM.

        Returns:
            a RandomVariable object belonging to the PGM.
        """
        return self._pgm.new_rv(name, states)

    def __len__(self) -> int:
        """
        Returns:
            the number of random variables of the PGM.
        """
        return len(self._pgm.rvs)

    def __getitem__(self, index: int) -> RandomVariable:
        """
        Get a random variable by its position in `pgm.rvs`.
        """
        return self._pgm.rvs[index]

    def items(self) -> Iterable[Tuple[str, RandomVariable]]:
        """
        Returns:
            (name, random variable) pairs, as per `dict.items`.
        """
        return self._rv_map.items()

    def keys(self) -> Iterable[str]:
        """
        Returns:
            the random variable names (case adjusted if ignoring case).
        """
        return self._rv_map.keys()

    def values(self) -> Iterable[RandomVariable]:
        """
        Returns:
            the random variables, as per `dict.values`.
        """
        return self._rv_map.values()

    def get(self, rv_name: str, default=None):
        """
        Get a random variable by name, or `default` if there is no such name.
        """
        return self._rv_map.get(self._clean_name(rv_name), default)

    def __call__(self, rv_name: str) -> RandomVariable:
        """
        Get a random variable by name.

        Raises:
            KeyError: if there is no random variable with the given name.
        """
        return self._rv_map[self._clean_name(rv_name)]

    def __getattr__(self, rv_name: str) -> RandomVariable:
        """
        Get a random variable by name, using attribute syntax.

        Python only calls this method when normal attribute lookup fails.

        Raises:
            AttributeError: if there is no random variable with the given name.
                The exception raised is also a KeyError, for backwards compatibility.
        """
        if '_reserved_names' not in self.__dict__:
            # The object is not initialised (e.g., it is being copied or
            # unpickled). Looking up the rv map would fail on the missing
            # state and re-enter this method without end.
            raise _UnknownRVError(rv_name)
        try:
            return self(rv_name)
        except KeyError:
            raise _UnknownRVError(rv_name) from None
1354
+
1355
+
1356
+ class Factor:
1357
+ """
1358
+ A PGM factor over one or more random variables declares a relationship between
1359
+ those variables. A Factor also has a potential function associated with
1360
+ it which defines a real-number value with each combination of states of
1361
+ the random variables.
1362
+
1363
+ The default potential function for a factor is a unique ZeroPotentialFunction.
1364
+
1365
+ The order of a Factors random variables is important as many things will be
1366
+ co-indexed with the random variables. For example, the shape of a Factor is
1367
+ the tuple of random variable lengths.
1368
+
1369
+ Note that multiple factors may share a potential function, so long as they all
1370
+ belong to the same PGM object and have the same shape.
1371
+ """
1372
+
1373
+ def __init__(self, pgm: PGM, *rvs: RandomVariable):
1374
+ """
1375
+ Add a new factor to the given PGM.
1376
+
1377
+ Assumes:
1378
+ The given random variables all belong to this PGM.
1379
+ The random variables contain no duplicates.
1380
+
1381
+ Args:
1382
+ pgm: the PGM that the factor will belong to.
1383
+ *rvs: the random variables.
1384
+
1385
+ Returns:
1386
+ a Factor object belonging to this PGM.
1387
+ """
1388
+ if len(set(rvs)) != len(rvs):
1389
+ raise ValueError('duplicated random variable in factor')
1390
+ if len(rvs) == 0:
1391
+ raise ValueError('must be at least one random variable')
1392
+ if any(rv.pgm is not pgm for rv in rvs):
1393
+ raise ValueError('random variable not from the same PGM')
1394
+
1395
+ self._pgm: PGM = pgm
1396
+ self._idx: int = pgm.number_of_factors
1397
+ self._rvs: Sequence[RandomVariable] = tuple(rvs)
1398
+ self._shape: Shape = tuple(len(rv) for rv in rvs)
1399
+
1400
+ self._zero_potential_function: ZeroPotentialFunction = ZeroPotentialFunction(self)
1401
+ self._potential_function: PotentialFunction = self._zero_potential_function
1402
+
1403
+ # Register self with our PGM
1404
+ # noinspection PyProtectedMember
1405
+ pgm._factors += (self,)
1406
+
1407
+ @property
1408
+ def rvs(self) -> Sequence[RandomVariable]:
1409
+ """
1410
+ Returns:
1411
+ The random variables of this factor.
1412
+ """
1413
+ return self._rvs
1414
+
1415
+ @property
1416
+ def pgm(self) -> PGM:
1417
+ """
1418
+ Returns:
1419
+ The PGM that this factor belongs to.
1420
+ """
1421
+ return self._pgm
1422
+
1423
+ @property
1424
+ def idx(self) -> int:
1425
+ """
1426
+ Returns:
1427
+ The index of this factor into the PGM.
1428
+
1429
+ Ensures:
1430
+ `self.pgm.factors[self.idx] is self`.
1431
+ """
1432
+ return self._idx
1433
+
1434
+ @property
1435
+ def shape(self) -> Shape:
1436
+ return self._shape
1437
+
1438
+ @property
1439
+ def number_of_states(self) -> int:
1440
+ """
1441
+ How many distinct states are covered by this Factor.
1442
+ """
1443
+ return self._potential_function.number_of_states
1444
+
1445
+ def __str__(self) -> str:
1446
+ """
1447
+ Return a human-readable string to represent this factor.
1448
+ This is intended mainly for debugging purposes.
1449
+ """
1450
+ return '(' + ', '.join([repr(str(rv)) for rv in self._rvs]) + ')'
1451
+
1452
+ def __len__(self) -> int:
1453
+ """
1454
+ Returns:
1455
+ the number of random variables.
1456
+ """
1457
+ return len(self._rvs)
1458
+
1459
+ @overload
1460
+ def __getitem__(self, index: int) -> RandomVariable:
1461
+ ...
1462
+
1463
+ @overload
1464
+ def __getitem__(self, index: slice) -> Sequence[RandomVariable]:
1465
+ ...
1466
+
1467
+ def __getitem__(self, index):
1468
+ return self._rvs[index]
1469
+
1470
+ def instances(self, flip: bool = False) -> Iterator[Instance]:
1471
+ """
1472
+ Iterate over all possible instances, in natural index order (i.e.,
1473
+ last random variable changing most quickly).
1474
+
1475
+ Args:
1476
+ flip: if true, then first random variable changes most quickly
1477
+
1478
+ Returns:
1479
+ an iterator over tuples, each tuple holds random variable
1480
+ state indexes, co-indexed with this object's shape, i.e., self.shape.
1481
+ """
1482
+ return self.function.instances(flip)
1483
+
1484
+ def parent_instances(self, flip: bool = False) -> Iterator[Instance]:
1485
+ """
1486
+ Iterate over all possible instances of parent random variable, in
1487
+ natural index order (i.e., last random variable changing most quickly).
1488
+
1489
+ Args:
1490
+ flip: if true, then first random variable changes most quickly
1491
+
1492
+ Returns:
1493
+ an iteration over tuples, each tuple holds random variable states
1494
+ co-indexed with this object's 'parent' shape, i.e., `self.shape[1:]`.
1495
+ """
1496
+ return self.function.parent_instances(flip)
1497
+
1498
+ @property
1499
+ def is_zero(self) -> bool:
1500
+ """
1501
+ Is the potential function of this factor set to the special 'zero' potential function.
1502
+ """
1503
+ return self._potential_function is self._zero_potential_function
1504
+
1505
+ @property
1506
+ def function(self) -> PotentialFunction:
1507
+ return self._potential_function
1508
+
1509
+ @function.setter
1510
+ def function(self, function: PotentialFunction | Factor) -> None:
1511
+ """
1512
+ Set the potential function for this PGM factor to the given potential function
1513
+ or factor.
1514
+
1515
+ Assumes:
1516
+ The given potential function belongs to the same PGM as this Factor.
1517
+ The potential function has the correct shape.
1518
+ """
1519
+ if isinstance(function, Factor):
1520
+ function = function.function
1521
+ assert isinstance(function, PotentialFunction)
1522
+
1523
+ if self._potential_function is function:
1524
+ # nothing to do
1525
+ return
1526
+
1527
+ if function.pgm is not self._pgm:
1528
+ raise ValueError(f'the given function is not of the same PGM as the factor')
1529
+
1530
+ if function.shape != self._shape:
1531
+ raise ValueError(f'incorrect function shape: expected {self._shape}, got {function.shape}')
1532
+
1533
+ if isinstance(function, ZeroPotentialFunction):
1534
+ self.set_zero()
1535
+ else:
1536
+ self._potential_function = function
1537
+
1538
+ def set_zero(self) -> ZeroPotentialFunction:
1539
+ """
1540
+ Set the factor's potential function to its original ZeroPotentialFunction.
1541
+
1542
+ Returns:
1543
+ the potential function.
1544
+ """
1545
+ self._potential_function = self._zero_potential_function
1546
+ return self._potential_function
1547
+
1548
+ def set_dense(self) -> DensePotentialFunction:
1549
+ """
1550
+ Set to the potential function to a new `DensePotentialFunction` object.
1551
+
1552
+ Returns:
1553
+ the potential function.
1554
+ """
1555
+ self._potential_function = DensePotentialFunction(self)
1556
+ return self._potential_function
1557
+
1558
+ def set_sparse(self) -> SparsePotentialFunction:
1559
+ """
1560
+ Set to the potential function to a new `SparsePotentialFunction` object.
1561
+
1562
+ Returns:
1563
+ the potential function.
1564
+ """
1565
+ self._potential_function = SparsePotentialFunction(self)
1566
+ return self._potential_function
1567
+
1568
+ def set_compact(self) -> CompactPotentialFunction:
1569
+ """
1570
+ Set to the potential function to a new `CompactPotentialFunction` object.
1571
+
1572
+ Returns:
1573
+ the potential function.
1574
+ """
1575
+ self._potential_function = CompactPotentialFunction(self)
1576
+ return self._potential_function
1577
+
1578
def set_clause(self, *key: int) -> ClausePotentialFunction:
    """
    Set the potential function to a new `ClausePotentialFunction` object
    created for this factor and the given key.

    Args:
        *key: defines the random variable states of the clause. The key is a sequence of
            random variable state indexes, co-indexed with `Factor.rvs`.

    Returns:
        the new potential function.

    Raises:
        KeyError: if the key is not valid for the shape of the factor.
    """
    # `key` is passed on as a single tuple of state indexes.
    clause_function = ClausePotentialFunction(self, key)
    self._potential_function = clause_function
    return clause_function
1594
+
1595
def set_cpt(self, tolerance: float = DEFAULT_TOLERANCE) -> CPTPotentialFunction:
    """
    Set the potential function to a new `CPTPotentialFunction` object
    created for this factor.

    Args:
        tolerance: a tolerance when testing if values are equal to zero or one.

    Returns:
        the new potential function.

    Raises:
        ValueError: if tolerance is negative.
    """
    cpt_function = CPTPotentialFunction(self, tolerance)
    self._potential_function = cpt_function
    return cpt_function
1610
+
1611
+
1612
@dataclass(eq=True, frozen=True)
class ParamId:
    """
    Identifier of one parameter of a potential function.

    A ParamId pairs an identifier of the owning potential function with the
    index of the parameter within that function. Parameter identifiers
    uniquely identify every parameter within a PGM.

    Instances are immutable and hashable, and compare equal when both
    fields are equal.
    """
    # Identifier of the potential function that owns the parameter.
    function_id: int
    # Index of the parameter within its potential function.
    param_idx: int
1623
+
1624
+
1625
class PotentialFunction(ABC):
    """
    A potential function defines the potential values for a Factor, where
    a factor joins one or more variables of a PGM.

    A potential function may be shared by several Factors of a PGM,
    i.e., can be applied to multiple variables.

    The `shape` of a potential function is a tuple of integers which defines
    the number of variables, len(shape), and the number of states of each
    variable, shape[i].

    The potential function value for variable states (x = i, y = j, ...) is given by
    self[i, j, ...], i.e., self.__getitem__((i, j, ...)). The tuple, (i, j, ...), is
    known as a Key.

    The values of a potential function are defined by potential function parameters.
    The number of potential function parameters is given by number_of_parameters.
    The value of each parameter is given by param_value(i), where i is the parameter index.

    Every valid key of the potential function is either mapped to a parameter or is
    "guaranteed zero", which means that the value is zero and cannot be changed by changing
    the values of the potential function's parameters.
    """

    def __init__(self, factor: Factor):
        """
        Create a potential function compatible with the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor that this potential function is compatible with.
        """
        # Only the PGM and the shape are kept; the factor itself is not stored.
        self._pgm: PGM = factor.pgm
        self._shape: Shape = factor.shape
        self._number_of_states = _multiply(self._shape)

    @property
    def pgm(self) -> PGM:
        """
        Returns:
            The PGM that this potential function belongs to.
        """
        return self._pgm

    @property
    def shape(self) -> Shape:
        """
        Returns:
            The shape of this potential function.
        """
        return self._shape

    @property
    def number_of_rvs(self) -> int:
        """
        Returns:
            The number of random variables in this potential function.
        """
        return len(self._shape)

    @property
    def number_of_states(self) -> int:
        """
        How many distinct states are covered by this potential function.

        Returns:
            The size of the state space of this potential function.
        """
        return self._number_of_states

    @property
    def number_of_parent_states(self) -> int:
        """
        How many distinct states are covered by this potential function's parents,
        i.e., excluding the first random variable.

        Returns:
            The size of the state space of this potential function's parent random variables.
        """
        return _multiply(self._shape[1:])

    def count_usage(self) -> int:
        """
        Check all PGM factors to count the number of times that this potential function
        is used.

        Returns:
            the number of factors that use this potential function.
        """
        return sum(1 for factor in self._pgm.factors if factor.function is self)

    def check_key(self, key: Key) -> Instance:
        """
        Convert the key into an instance.

        Args:
            key: defines an instance in the state space of the potential function.

        Returns:
            an instance, which is a tuple of state indexes, co-indexed with `self.shape`.

        Raises:
            KeyError: if the key is not valid for the shape of the potential function.
        """
        return check_key(self._shape, key)

    def valid_key(self, key: Key) -> bool:
        """
        Is the given key valid.

        Args:
            key: defines an instance in the state space of the potential function.

        Returns:
            True only if the given key is valid.
        """
        return valid_key(self._shape, key)

    def valid_parameter(self, param_idx: int) -> bool:
        """
        Is the given parameter index valid.

        Args:
            param_idx: a parameter index.

        Returns:
            True only if `0 <= param_idx < self.number_of_parameters`.
        """
        return 0 <= param_idx < self.number_of_parameters

    @property
    def is_sparse(self) -> bool:
        """
        Are there any 'guaranteed zero' values.

        Returns:
            True only if `self.number_of_not_guaranteed_zero < self.number_of_states`.
        """
        return self.number_of_not_guaranteed_zero < self._number_of_states

    @property
    @abstractmethod
    def number_of_not_guaranteed_zero(self) -> int:
        """
        How many of the states of this potential function are not 'guaranteed zero'.
        That is, how many keys are associated with a parameter.

        Returns:
            The number of valid keys that are associated with a parameter.

        Ensures:
            0 <= self.number_of_not_guaranteed_zero <= self.number_of_states.
        """
        ...

    @property
    @abstractmethod
    def number_of_parameters(self) -> int:
        """
        Get the number of parameters defining the potential function values.
        Each valid key of the function either maps to a parameter
        or is 'guaranteed zero'.

        Returns:
            The number of parameters.

        Ensures:
            0 <= self.number_of_parameters <= self.number_of_not_guaranteed_zero.
        """
        ...

    @property
    @abstractmethod
    def params(self) -> Iterable[Tuple[int, float]]:
        """
        Iterate the parameters and their associated values.

        Returns:
            An iterable over (param_idx, value) tuples, for every possible parameter.

        Assumes:
            The potential function is not mutated while iterating.
        """
        ...

    @property
    @abstractmethod
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """
        Iterate the keys that have a parameter associated with them.

        Returns:
            An iterable over (key, param_idx, value) tuples, for every key with an associated parameter.

        Assumes:
            The potential function is not mutated while iterating.
        """
        ...

    @abstractmethod
    def __getitem__(self, key: Key) -> float:
        """
        Get the potential function value for the given instance key.

        Args:
            key: defines an instance in the state space of the potential function.

        Returns:
            The value of the potential function for the given key.

        Assumes:
            self.valid_key(key).
        """
        ...

    @abstractmethod
    def param_value(self, param_idx: int) -> float:
        """
        Get the potential function value by parameter index.

        Args:
            param_idx: a parameter index.

        Returns:
            The value of the parameter with the given index.

        Assumes:
            `self.valid_parameter(param_idx)`.
        """
        ...

    @abstractmethod
    def param_idx(self, key: Key) -> Optional[int]:
        """
        Get the parameter index for the given potential function random variable states (key).

        Args:
            key: defines an instance in the state space of the potential function.

        Returns:
            either `None` indicating a "guaranteed zero" value, or the parameter index holding
            the potential function value for the key.
        """
        ...

    def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
        """
        Is the potential function set with parameter values conforming to a
        Conditional Probability Table.

        Every value must be non-negative.
        For every state of the parents (non-first slots)
        the sum of the values over the child states (first slot)
        must be either 1 or 0, within the given tolerance.

        Assumes:
            tolerance is non-negative.

        Args:
            tolerance: a tolerance when testing if values are equal to zero or one.

        Returns:
            True only if the potential function is compatible with being a CPT.
        """
        # This default implementation calculates the result the long way, by checking
        # every valid key of the potential function.
        # Subclasses may override this implementation.
        low: float = 1.0 - tolerance
        high: float = 1.0 + tolerance
        child_states = range(self.shape[0])
        for parent_state in self.parent_instances():
            parent_key = tuple(parent_state)
            total: float = 0
            for state in child_states:
                value: float = self[(state,) + parent_key]
                if value < 0:
                    # A negative value can never be part of a CPT, even when it is
                    # cancelled out by other values so that the total looks valid.
                    return False
                total += value
            if not ((low <= total <= high) or (0 <= total <= tolerance)):
                return False
        return True

    def natural_param_idx(self, key: Key) -> int:
        """
        Get the natural parameter index for the given key. This is the same index as used
        by a DensePotentialFunction with the same shape.

        Args:
            key: is a valid key of the potential function, referring to an instance in the factor's state space.

        Assumes:
            `self.valid_key(key)` is true.

        Returns:
            a hypothetical parameter index assuming that every valid key has a unique parameter
            as per DensePotentialFunction.
        """
        return _natural_key_idx(self._shape, key)

    def param_id(self, param_idx: int) -> ParamId:
        """
        Get a hashable object to represent the parameter with the given parameter index.

        Args:
            param_idx: a parameter index.

        Returns:
            a hashable ParamId object for the parameter of this potential function.

        Raises:
            ValueError: if the parameter index is not valid.
        """
        if not self.valid_parameter(param_idx):
            raise ValueError(f'invalid parameter index: {param_idx}')
        # The function is identified by the Python object id of this potential function.
        return ParamId(id(self), param_idx)

    def items(self) -> Iterator[Tuple[Instance, float]]:
        """
        Iterate over all keys and values of this potential function.

        Keys are produced in the same order as `self.instances()`.

        Returns:
            An iterator over all (key, value) pairs, where key is an Instance and value
            is the value of the potential function for the key.
        """
        for key in _combos_ranges(self._shape, flip=True):
            yield key, self[key]

    def instances(self, flip: bool = False) -> Iterator[Instance]:
        """
        Iterate over all possible instances, in natural index order (i.e.,
        last random variable changing most quickly).

        Args:
            flip: if true, then first random variable changes most quickly

        Returns:
            an iterator over tuples, each tuple holds random variable
            state indexes, co-indexed with this object's shape, i.e., self.shape.
        """
        # The helper's `flip` has the opposite sense to this method's `flip`.
        return _combos_ranges(self._shape, flip=not flip)

    def parent_instances(self, flip: bool = False) -> Iterator[Instance]:
        """
        Iterate over all possible instances of the parent random variables, in
        natural index order (i.e., last random variable changing most quickly).

        Args:
            flip: if true, then first random variable changes most quickly

        Returns:
            an iterator over tuples, each tuple holds random variable states
            co-indexed with this object's 'parent' shape, i.e., `self.shape[1:]`.
        """
        # The helper's `flip` has the opposite sense to this method's `flip`.
        return _combos_ranges(self._shape[1:], flip=not flip)

    def __str__(self) -> str:
        """
        Provide a human-readable representation of this potential function.
        This is intended mainly for debugging purposes.
        """
        shape_str: str = ', '.join(str(x) for x in self._shape)
        return f'{self.__class__.__name__}({shape_str})'

    def dump(
            self,
            *,
            prefix: str = '',
            indent: str = ' ',
            show_function_values: bool = False,
            show_id_class: bool = True,
    ) -> None:
        """
        Print a dump of the function.
        This is intended for debugging purposes.

        Args:
            prefix: optional prefix for indenting all lines.
            indent: additional prefix to use for extra indentation.
            show_function_values: if true, then the function values will be dumped.
            show_id_class: if true, then the function id and class will be dumped.
        """

        shape_str: str = ', '.join(str(x) for x in self._shape)

        if show_id_class:
            print(f'{prefix}id: {id(self)}')
            print(f'{prefix}class: {self.__class__.__name__}')
        print(f'{prefix}usage: {self.count_usage()}')
        print(f'{prefix}rvs: {self.number_of_rvs}')
        print(f'{prefix}shape: ({shape_str})')
        print(f'{prefix}states: {self._number_of_states}')
        print(f'{prefix}guaranteed zero: {self._number_of_states - self.number_of_not_guaranteed_zero}')
        print(f'{prefix}not guaranteed zero: {self.number_of_not_guaranteed_zero}')
        print(f'{prefix}parameters: {self.number_of_parameters}')
        if show_function_values:
            # Only keys that have a parameter are listed, one per line.
            next_prefix = prefix + indent
            for key, param_idx, value in self.keys_with_param:
                print(f'{next_prefix}{param_idx} {key} = {value}')
2020
+
2021
+
2022
class ZeroPotentialFunction(PotentialFunction):
    """
    A potential function whose value is zero for every key.

    It behaves like a DensePotentialFunction in that there is a parameter
    for each possible key. However, a PGM user has no way to change the
    parameter values: they are always zero. Despite the inability to change
    the value of the parameters, no key is considered 'guaranteed zero'.

    The primary use of a ZeroPotentialFunction is as a placeholder
    within a factor, prior to parameter learning.
    """
    __slots__ = ()

    def __init__(self, factor: Factor):
        """
        Create a potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor that this potential function is compatible with.
        """
        super().__init__(factor)

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        # Every state has a parameter, so no key is 'guaranteed zero'.
        return self.number_of_states

    @property
    def number_of_parameters(self) -> int:
        # One parameter per state.
        return self.number_of_states

    @property
    def params(self) -> Iterator[Tuple[int, float]]:
        # Every parameter has the value zero.
        yield from ((idx, 0) for idx in range(self.number_of_parameters))

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        # Parameters are numbered in the order produced by `self.instances()`.
        yield from (
            (instance, idx, 0)
            for idx, instance in enumerate(self.instances())
        )

    def __getitem__(self, key: Key) -> float:
        # The key is still checked, even though the value is always zero.
        self.check_key(key)
        return 0

    def param_value(self, param_idx: int) -> float:
        if self.valid_parameter(param_idx):
            return 0
        raise ValueError(f'invalid parameter index: {param_idx}')

    def param_idx(self, key: Key) -> int:
        shape = self._shape
        return _natural_key_idx(shape, key)

    def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
        # This class always reports itself as a CPT, whatever the tolerance.
        return True
2081
+
2082
+
2083
class DensePotentialFunction(PotentialFunction):
    """
    A dense (tabular) potential function.

    There is one parameter for each valid key of the potential function.
    The initial value for each parameter is zero.
    The value corresponding to any key can be changed independently of the others.
    """

    def __init__(self, factor: Factor):
        """
        Create a potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor that this potential function is compatible with.
        """
        super().__init__(factor)
        # One float64 value per state, all starting at zero.
        self._values: NDArrayFloat64 = np.zeros(self.number_of_states, dtype=np.float64)

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        # Every state has its own parameter.
        return self.number_of_states

    @property
    def number_of_parameters(self) -> int:
        # One parameter per state.
        return self.number_of_states

    def __getitem__(self, key: Key) -> float:
        idx = self.param_idx(key)
        return self._values.item(idx)

    def param_value(self, param_idx: int) -> float:
        return self._values.item(param_idx)

    def param_idx(self, key: Key) -> Optional[int]:
        # Dense functions use the natural index of the key as the parameter index.
        return self.natural_param_idx(key)

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        return enumerate(self._values)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        # Parameters are numbered in the order produced by `self.instances()`.
        for idx, instance in enumerate(self.instances()):
            yield instance, idx, self.param_value(idx)

    # Mutators

    def __setitem__(self, key: Key, value: float) -> None:
        """
        Set the potential function value, for a given key.

        Args:
            key: defines an instance in the state space of the potential function.
            value: the new value of the potential function for the given key.

        Assumes:
            self.valid_key(key).
        """
        idx = self.param_idx(key)
        self._values[idx] = value

    def set_param_value(self, param_idx: int, value: float) -> None:
        """
        Set the parameter value.

        Args:
            param_idx: is the index of the parameter.
            value: the new value of the parameter.

        Assumes:
            self.valid_parameter(param_idx).
        """
        self._values[param_idx] = value

    def clear(self) -> DensePotentialFunction:
        """
        Set all values of the potential function to zero.

        Returns:
            self
        """
        return self.set_all(0)

    def normalise_cpt(self) -> DensePotentialFunction:
        """
        Normalise the parameter values as if this was a CPT.
        That is, treat the first random variable as the child and the others as parents;
        for each combination of parent states, ensure the parameters over the child
        states sum to 1 (or 0).

        Assumes:
            There are no negative parameter values.

        Returns:
            self
        """
        number_of_child_states = self._shape[0]
        for parent_states in _combos_ranges(self._shape[1:]):
            group = [
                (child_state,) + parent_states
                for child_state in range(number_of_child_states)
            ]
            total = sum(self[key] for key in group)
            if total == 0 or total == 1:
                # Already normalised, or nothing to scale.
                continue
            for key in group:
                self[key] /= total
        return self

    def normalise(self, grouping_positions: Sequence[int] = ()) -> DensePotentialFunction:
        """
        Convert the potential function to a CPT with 'grouping_positions' nominating
        the parent random variables.

        I.e., for each possible key of the function with the same value at each
        grouping position, the sum of values for matching keys in the factor is scaled
        to be 1 (or 0).

        Parameter 'grouping_positions' are indices into `self.shape`. For example, the
        grouping positions of a factor with parent rvs 'conditioning_rvs' are
        grouping_positions = [i for i, rv in enumerate(factor.rvs) if rv in conditioning_rvs].

        Args:
            grouping_positions: indices into `self.shape`.

        Returns:
            self
        """
        _normalise_potential_function(self, grouping_positions)
        return self

    def set_iter(self, values: Iterable[float]) -> DensePotentialFunction:
        """
        Set the values of the potential function using the given iterable.

        Mapping instances to values is as follows:
            Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
            values[0] represents instance (0,0)
            values[1] represents instance (0,1)
            values[2] represents instance (0,2)
            values[3] represents instance (1,0)
            values[4] represents instance (1,1)
            values[5] represents instance (1,2).

        For example: to set to counts, starting from 1, use `self.set_iter(itertools.count(1))`.

        Args:
            values: an iterable providing values to use.

        Returns:
            self
        """
        # `count` makes numpy read exactly one value per parameter, so an
        # endless iterable is acceptable here.
        number_of_values = self.number_of_parameters
        self._values = np.fromiter(values, dtype=np.float64, count=number_of_values)
        return self

    def set_stream(self, stream: Callable[[], float]) -> DensePotentialFunction:
        """
        Set the values of the potential function by repeatedly calling the stream function.
        The order of values is the same as set_iter.

        For example, to set to random numbers, use `self.set_stream(random.random)`.

        Args:
            stream: a callable taking no arguments, returning the values to use.

        Returns:
            self
        """
        # iter(callable, sentinel) keeps calling `stream` until it returns None.
        endless_values = iter(stream, None)
        return self.set_iter(endless_values)

    def set_flat(self, *value: float) -> DensePotentialFunction:
        """
        Set the values of the potential function to the given values.
        The order of values is the same as set_iter.

        Args:
            *value: the values to use.

        Returns:
            self

        Raises:
            ValueError: if `len(value) != self.number_of_states`.
        """
        if len(value) == self.number_of_states:
            return self.set_iter(value)
        raise ValueError(f'wrong number of values: expected {self.number_of_states}, got {len(value)}')

    def set_all(self, value: float) -> DensePotentialFunction:
        """
        Set all values of the potential function to the given value.

        Args:
            value: the value to use.

        Returns:
            self
        """
        repeated_value = _repeat(value)
        return self.set_iter(repeated_value)

    def set_uniform(self) -> DensePotentialFunction:
        """
        Set all values of the potential function to 1/number_of_states.

        Returns:
            self
        """
        uniform_value = 1.0 / self.number_of_states
        return self.set_all(uniform_value)
2295
+
2296
+
2297
class SparsePotentialFunction(PotentialFunction):
    """
    A sparse potential function.

    Only keys that have been given a value hold a parameter; every other
    key has the value zero and is "guaranteed zero".
    The user may set the value of any key.
    Setting a key to zero using `self[key] = 0` removes its parameter, whereas
    `set_param_value(param_idx, 0)` keeps the parameter, with a zero value.
    """

    def __init__(self, factor: Factor):
        """
        Create a potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor that this potential function is compatible with.
        """
        super().__init__(factor)
        # Parameter values, indexed by parameter index.
        self._values: List[float] = []
        # Maps each instance that has a parameter to its parameter index.
        self._params: Dict[Instance, int] = {}

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        return len(self._params)

    @property
    def number_of_parameters(self) -> int:
        return len(self._params)

    def __getitem__(self, key: Key) -> float:
        idx: Optional[int] = self.param_idx(key)
        # Keys without a parameter are "guaranteed zero".
        return 0 if idx is None else self._values[idx]

    def param_value(self, param_idx: int) -> float:
        return self._values[param_idx]

    def param_idx(self, key: Key) -> Optional[int]:
        instance = _key_to_instance(key)
        return self._params.get(instance)

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        return enumerate(self._values)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        for instance, idx in self._params.items():
            yield instance, idx, self._values[idx]

    # Mutators

    def __setitem__(self, key: Key, value: float) -> None:
        """
        Set the potential function value, for a given key.

        If value is zero, then the key will become "guaranteed zero".

        Args:
            key: defines an instance in the state space of the potential function.
            value: the new value of the potential function for the given key.

        Assumes:
            self.valid_key(key).
        """
        values = self._values
        params = self._params
        instance: Instance = _key_to_instance(key)
        current_idx: Optional[int] = params.get(instance)

        if current_idx is None:
            # The key has no parameter yet.
            if value == 0:
                # Already "guaranteed zero": nothing to do.
                return
            params[instance] = len(values)
            values.append(value)
            return

        if value != 0:
            # The key already has a parameter: just overwrite its value.
            values[current_idx] = value
            return

        # The key has a parameter which must now be removed. To keep the
        # parameter indexes contiguous, the last parameter is moved into the
        # vacated slot before the list is shortened.
        last_idx: int = len(values) - 1
        if current_idx != last_idx:
            values[current_idx] = values[last_idx]
            for other_instance, other_idx in params.items():
                if other_idx == last_idx:
                    # Exactly one instance maps to the last parameter.
                    params[other_instance] = current_idx
                    break

        values.pop()
        params.pop(instance)

    def set_param_value(self, param_idx: int, value: float) -> None:
        """
        Set the parameter value.

        Args:
            param_idx: is the index of the parameter.
            value: the new value of the parameter.

        Assumes:
            self.valid_parameter(param_idx).
        """
        self._values[param_idx] = value

    def clear(self) -> SparsePotentialFunction:
        """
        Set all values of the potential function to zero,
        by discarding every parameter.

        Returns:
            self
        """
        self._values = []
        self._params = {}
        return self

    def normalise_cpt(self) -> SparsePotentialFunction:
        """
        Normalise the parameter values as if this was a CPT.
        That is, treat the first random variable as the child and the others as parents;
        for each combination of parent states, ensure the parameters over
        the child states sum to 1 (or 0).

        Returns:
            self
        """
        # Every position except the first (the child) is a parent position.
        parent_positions = list(range(1, self.number_of_rvs))
        _normalise_potential_function(self, parent_positions)
        return self

    def normalise(self, grouping_positions=()) -> SparsePotentialFunction:
        """
        Convert the potential function to a CPT with 'grouping_positions' nominating
        the parent random variables.

        I.e., for each possible key of the function with the same value at each
        grouping position, the sum of values for matching keys in the factor is scaled
        to be 1 (or 0).

        Parameter 'grouping_positions' are indices into `self.shape`. For example, the
        grouping positions of a factor with parent rvs 'conditioning_rvs' are
        grouping_positions = [i for i, rv in enumerate(factor.rvs) if rv in conditioning_rvs].

        Args:
            grouping_positions: indices into `self.shape`.

        Returns:
            self
        """
        _normalise_potential_function(self, grouping_positions)
        return self

    def set_iter(self, values: Iterable[float]) -> SparsePotentialFunction:
        """
        Set the values of the potential function using the given iterable.

        Mapping instances to values is as follows:
            Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
            values[0] represents instance (0,0)
            values[1] represents instance (0,1)
            values[2] represents instance (0,2)
            values[3] represents instance (1,0)
            values[4] represents instance (1,1)
            values[5] represents instance (1,2).

        Values equal to zero are not given a parameter.

        For example: to set to counts, starting from 1, use `self.set_iter(itertools.count(1))`.

        Args:
            values: an iterable providing values to use.

        Returns:
            self
        """
        self.clear()
        # `clear` rebinds both containers, so they are only looked up afterwards.
        new_values = self._values
        new_params = self._params
        # zip stops once every instance has a value, so `values` may be endless.
        for instance, value in zip(self.instances(), values):
            if value != 0:
                new_params[instance] = len(new_values)
                new_values.append(value)
        return self

    def set_stream(self, stream: Callable[[], float]) -> SparsePotentialFunction:
        """
        Set the values of the potential function by repeatedly calling the stream function.
        The order of values is the same as set_iter.

        For example, to set to random numbers, use `self.set_stream(random.random)`.

        Args:
            stream: a callable taking no arguments, returning the values to use.

        Returns:
            self
        """
        # iter(callable, sentinel) keeps calling `stream` until it returns None.
        endless_values = iter(stream, None)
        return self.set_iter(endless_values)

    def set_flat(self, *value: float) -> SparsePotentialFunction:
        """
        Set the values of the potential function to the given values.
        The order of values is the same as set_iter.

        Args:
            *value: the values to use.

        Returns:
            self

        Raises:
            ValueError: if `len(value) != self.number_of_states`.
        """
        if len(value) == self.number_of_states:
            return self.set_iter(value)
        raise ValueError(f'wrong number of values: expected {self.number_of_states}, got {len(value)}')

    def set_all(self, value: float) -> SparsePotentialFunction:
        """
        Set all values of the potential function to the given value.

        Args:
            value: the value to use.

        Returns:
            self
        """
        if value == 0:
            # All zero means no parameters at all.
            return self.clear()
        return self.set_iter(_repeat(value))

    def set_uniform(self) -> SparsePotentialFunction:
        """
        Set all values of the potential function to 1/number_of_states.

        Returns:
            self
        """
        uniform_value = 1.0 / self.number_of_states
        return self.set_all(uniform_value)
2540
+
2541
+
2542
class CompactPotentialFunction(PotentialFunction):
    """
    A sparse potential function that shares one parameter between all keys
    having the same value.

    There is one parameter for each unique, non-zero parameter value.
    The default value for each key is zero.
    The user may set the value of any key.

    Internal state (kept mutually consistent by every mutator):
        _values: parameter values, indexed by parameter index.
        _counts: number of keys referring to each parameter, co-indexed with _values.
        _map: maps each non-zero instance to its parameter index.
        _inv_map: maps each parameter value to its parameter index.
    """

    def __init__(self, factor: Factor):
        """
        Create a potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: which factor is this potential function is compatible with.
        """
        super().__init__(factor)
        self._values: List[float] = []
        self._counts: List[int] = []
        self._map: Dict[Instance, int] = {}
        self._inv_map: Dict[float, int] = {}

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        """
        Returns:
            the number of keys that currently have a parameter (i.e. a non-zero value).
        """
        return len(self._map)

    @property
    def number_of_parameters(self) -> int:
        """
        Returns:
            the number of distinct parameter values currently held.
        """
        return len(self._values)

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        """
        Returns:
            an iterable of (param_idx, value) pairs, one per parameter.
        """
        return enumerate(self._values)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """
        Returns:
            an iterable of (instance, param_idx, value) triples, one per key that has a parameter.
        """
        for key, param_idx in self._map.items():
            value: float = self._values[param_idx]
            yield key, param_idx, value

    def __getitem__(self, key: Key) -> float:
        """
        Returns:
            the value for the given key, or zero if the key has no parameter.
        """
        param_idx: Optional[int] = self.param_idx(key)
        if param_idx is None:
            return 0
        else:
            return self._values[param_idx]

    def param_value(self, param_idx: int) -> float:
        """
        Returns:
            the value of the parameter with the given index.
        """
        return self._values[param_idx]

    def param_idx(self, key: Key) -> Optional[int]:
        """
        Returns:
            the parameter index for the given key, or None if the key has no parameter.
        """
        return self._map.get(_key_to_instance(key))

    # Mutators

    def __setitem__(self, key: Key, value: float) -> None:
        """
        Set the potential function value, for a given key.

        If value is zero, then the key will become "guaranteed zero".
        If the value is the same as an existing parameter value, then
        that parameter will be reused.

        Arg:
            key: defines an instance in the state space of the potential function.
            value: the new value of the potential function for the given key.

        Assumes:
            self.valid_key(key).
        """
        instance: Instance = _key_to_instance(key)
        old_idx: Optional[int] = self._map.get(instance)

        if old_idx is None:
            # previous value for the key was zero
            if value != 0:
                self._assign_param(instance, value)
            return

        # the key previously had a non-zero value
        old_value: float = self._values[old_idx]
        if value == old_value:
            # nothing to do
            return

        if self._counts[old_idx] > 1:
            # other keys still use the old parameter, so only detach this key
            self._counts[old_idx] -= 1
            del self._map[instance]
        elif value != 0 and value not in self._inv_map:
            # This key is the only user of the parameter and the new value
            # is not yet a parameter: update in place (keeping the parameter
            # index stable) and keep the value -> index map in step.
            self._values[old_idx] = value
            del self._inv_map[old_value]
            self._inv_map[value] = old_idx
            return
        else:
            # This key is the only user of the parameter, and the parameter
            # is no longer needed (the new value is zero or already exists).
            del self._map[instance]
            del self._inv_map[old_value]
            self._remove_param(old_idx)

        if value != 0:
            # allocate the key to a different parameter
            self._assign_param(instance, value)

    def set_iter(self, values: Iterable[float]) -> CompactPotentialFunction:
        """
        Set the values of the potential function using the given iterator.

        Mapping instances to *values is as follows:
            Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
            values[0] represents instance (0,0)
            values[1] represents instance (0,1)
            values[2] represents instance (0,2)
            values[3] represents instance (1,0)
            values[4] represents instance (1,1)
            values[5] represents instance (1,2).

        For example: to set to counts, starting from 1, use `self.set_iter(itertools.count(1))`.

        Args:
            values: an iterable providing values to use.

        Returns:
            self
        """
        self.clear()
        for instance, value in zip(self.instances(), values):
            self[instance] = value
        return self

    def set_stream(self, stream: Callable[[], float]) -> CompactPotentialFunction:
        """
        Set the values of the potential function by repeatedly calling the stream function.
        The order of values is the same as set_iter.

        For example, to set to random numbers, use `self.set_stream(random.random)`.

        Args:
            stream: a callable taking no arguments, returning the values to use.

        Returns:
            self
        """
        # `iter(callable, sentinel)` calls `stream` until it returns None.
        return self.set_iter(iter(stream, None))

    def set_flat(self, *value: float) -> CompactPotentialFunction:
        """
        Set the values of the potential function to the given values.
        The order of values is the same as set_iter.

        Args:
            *value: the values to use.

        Returns:
            self

        Raises:
            ValueError: if `len(value) != self.number_of_states`.
        """
        if len(value) != self.number_of_states:
            raise ValueError(f'wrong number of values: expected {self.number_of_states}, got {len(value)}')
        return self.set_iter(value)

    def set_all(self, value: float) -> CompactPotentialFunction:
        """
        Set all values of the potential function to the given value.

        Args:
            value: the value to use.

        Returns:
            self
        """
        self.clear()
        if value != 0:
            # a single parameter shared by every instance
            self._values = [value]
            self._counts = [self.number_of_states]
            self._inv_map = {value: 0}
            self._map = {instance: 0 for instance in self.instances()}
        return self

    def set_uniform(self) -> CompactPotentialFunction:
        """
        Set all values of the potential function to 1/number_of_states.

        Returns:
            self
        """
        return self.set_all(1.0 / self.number_of_states)

    def clear(self) -> CompactPotentialFunction:
        """
        Set all values of the potential function to zero.

        Returns:
            self
        """
        self._values = []
        self._counts = []
        self._map = {}
        self._inv_map = {}
        return self

    def _assign_param(self, instance: Instance, value: float) -> None:
        """
        Map the instance to the parameter holding the given value,
        allocating a new parameter if the value is not yet present.

        Assumes:
            value is not zero and the instance is not currently in self._map.
        """
        param_idx: Optional[int] = self._inv_map.get(value)
        if param_idx is None:
            # need to allocate a new value
            param_idx = len(self._values)
            self._values.append(value)
            self._counts.append(1)
            self._inv_map[value] = param_idx
        else:
            # the value already exists in the function, so reuse it
            self._counts[param_idx] += 1
        self._map[instance] = param_idx

    def _remove_param(self, param_idx: int) -> None:
        """
        Remove the indexed parameter from self._values and self._counts.
        If the parameter is not at the end of the list of parameters
        then the end parameter is moved into its place, and self._inv_map
        and self._map are updated to refer to the moved parameter's new index.

        The caller is responsible for removing the old value from self._inv_map
        and for removing (or re-mapping) keys that referred to `param_idx`.
        """

        # ensure the parameter is at the end of the list
        end: int = len(self._values) - 1
        if param_idx != end:
            # move `end` into the slot of `param_idx`
            end_value: float = self._values[end]
            self._values[param_idx] = end_value
            self._counts[param_idx] = self._counts[end]
            self._inv_map[end_value] = param_idx
            # Only values of existing keys are replaced, so the dict size
            # does not change while it is being iterated.
            for instance, instance_param_idx in self._map.items():
                if instance_param_idx == end:
                    self._map[instance] = param_idx

        # remove the end parameter
        self._values.pop()
        self._counts.pop()
2795
+
2796
+
2797
class ClausePotentialFunction(PotentialFunction):
    """
    A potential function for a single clause (from a CNF formula), i.e. a disjunction.
    A clause over variables X, Y, Z, is of the form: 'X=x or Y=y or Z=z'.

    The function is guaranteed zero for the key where the clause is false,
    i.e., when 'X != x and Y != y and Z != z'.

    For keys where the clause is true, the value of the potential function
    is given by the only parameter of the potential function. That parameter
    is called the clause 'weight' and is notionally 1.

    The weight of a clause is permitted to be zero, but that is _not_ equivalent to
    guaranteed-zero.
    """

    def __init__(self, factor: Factor, key: Key, weight: float = 1):
        """
        Create a clause potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.

        Args:
            factor: which factor is this potential function is compatible with.
            key: defines the random variable states of the clause.
            weight: the value of the single parameter (the clause weight).
        """
        super().__init__(factor)
        self._weight: float = weight
        self._clause: Instance = self.check_key(key)
        # NOTE(review): this is the count of instances that match the clause at
        # no position, yet it is reported as "not guaranteed zero" - confirm intended.
        self._num_not_guaranteed_zero: int = _zero_space(self.shape)

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        """
        Returns:
            the count computed at construction time from the factor shape.
        """
        return self._num_not_guaranteed_zero

    @property
    def number_of_parameters(self) -> int:
        """
        Returns:
            always 1: the clause weight is the only parameter.
        """
        return 1

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        """
        Returns:
            a one element tuple holding (0, weight).
        """
        only_param: Tuple[int, float] = (0, self._weight)
        return (only_param,)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """
        For each random variable position, yield the clause with that
        position replaced by each possible state of the random variable.

        NOTE(review): the clause instance itself is yielded once per random
        variable, and keys matching the clause at some position while differing
        at two or more others are not yielded - confirm this is intended.

        Returns:
            an iterable of (instance, 0, weight) triples.
        """
        weight = self._weight
        for position in range(self.number_of_rvs):
            clause = self._clause
            before = tuple(clause[:position])
            after = tuple(clause[position + 1:])
            for state in range(self.shape[position]):
                yield before + (state,) + after, 0, weight

    def __getitem__(self, key: Key) -> float:
        """
        Returns:
            the weight if the key agrees with the clause in at least one position, else zero.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.
        """
        instance: Instance = self.check_key(key)
        clause_holds: bool = any(
            key_state == clause_state
            for key_state, clause_state in zip(instance, self._clause)
        )
        if clause_holds:
            return self._weight
        return 0

    def param_value(self, param_idx: int) -> float:
        """
        Returns:
            the weight, which is the value of parameter 0.

        Raises:
            IndexError: if param_idx is not 0.
        """
        if param_idx == 0:
            return self._weight
        raise IndexError(param_idx)

    def param_idx(self, key: Key) -> Optional[int]:
        """
        NOTE(review): only an exact match with the clause maps to parameter 0,
        whereas __getitem__ returns the weight for any single-position match - confirm intended.

        Returns:
            0 if the key is exactly the clause instance, otherwise None.
        """
        return 0 if _key_to_instance(key) == self._clause else None

    def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
        """
        A ClausePotentialFunction can only be a CPT when all entries are zero,
        so this tests whether the weight is within the tolerance of zero.
        """
        return abs(self._weight) <= tolerance

    def is_sparse(self) -> bool:
        """
        Returns:
            always True.
        """
        return True

    @property
    def weight(self) -> float:
        """
        Returns:
            the "weight" parameter defining the potential function.
        """
        return self._weight

    # Mutators

    @weight.setter
    def weight(self, value: float) -> None:
        """
        Set the weight parameter to the given value.
        """
        self._weight = value

    @property
    def clause(self) -> Instance:
        """
        Returns:
            the clause defining the potential function.
        """
        return self._clause

    @clause.setter
    def clause(self, key: Key) -> None:
        """
        Set the clause to the given key.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.
        """
        self._clause = self.check_key(key)
2916
+
2917
+
2918
class CPTPotentialFunction(PotentialFunction):
    """
    A potential function implementing a sparse Conditional Probability Table (CPT).

    The first random variable in the signature is the child, and the remaining random
    variables are parents.

    For each instantiation of the parent random variables there is a Conditioned Probability
    Distribution (CPD) over the states of the child random variable.

    If a CPD is not provided for a parent instantiation, then that parent instantiation
    is taken to have probability zero (i.e., all values of the CPD are guaranteed zero).

    Internal state:
        _values: the stored CPDs laid end to end, each occupying `_child_size` entries.
        _map: maps a parent instance to the offset of its CPD in `_values`.
        _inv_map: parent instances in storage order, so that
            `_inv_map[offset // _child_size]` is the parent instance of the CPD at `offset`.
    """

    def __init__(self, factor: Factor, tolerance: float):
        """
        Create a CPT potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: which factor is this potential function is compatible with.
            tolerance: a tolerance when testing if values are equal to zero or one.

        Raises:
            ValueError: if tolerance is negative.
        """
        super().__init__(factor)

        if tolerance < 0:
            raise ValueError('tolerance cannot be negative')

        self._child_size: int = self.shape[0]
        self._parent_shape: Shape = self.shape[1:]
        self._map: Dict[Instance, int] = {}
        self._values: List[float] = []
        self._inv_map: List[Instance] = []
        self._tolerance = tolerance

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        """
        Returns:
            the number of stored values (every stored CPD entry has a parameter).
        """
        return len(self._values)

    @property
    def number_of_parameters(self) -> int:
        """
        Returns:
            the number of stored values, one parameter per stored CPD entry.
        """
        return len(self._values)

    def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
        """
        Returns:
            True if the requested tolerance is no tighter than the tolerance
            enforced by set_cpd, otherwise the result of the inherited check.
        """
        if tolerance >= self._tolerance:
            return True
        else:
            # The requested tolerance is tighter than ensured.
            # Need to use the default method
            return super().is_cpt(tolerance)

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        """
        Returns:
            an iterable of (param_idx, value) pairs, one per stored value.
        """
        return enumerate(self._values)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """
        Returns:
            an iterable of (instance, param_idx, value) triples, where each
            instance is the child state followed by the parent states.
        """
        child_size: int = self._child_size
        for param_idx, value in enumerate(self._values):
            parent: Instance = self._inv_map[param_idx // child_size]
            key: Instance = (param_idx % child_size,) + tuple(parent)
            yield key, param_idx, value

    def __getitem__(self, key: Key) -> float:
        """
        Returns:
            the value for the given key, or zero if the parent states have no CPD.

        Raises:
            KeyError: if the key is not valid.
        """
        param_idx: Optional[int] = self.param_idx(key)
        if param_idx is None:
            return 0
        else:
            return self._values[param_idx]

    def param_value(self, param_idx: int) -> float:
        """
        Returns:
            the value of the parameter with the given index.
        """
        return self._values[param_idx]

    def param_idx(self, key: Key) -> Optional[int]:
        """
        Returns:
            the parameter index for the given key, or None if the
            parent states of the key have no CPD.

        Raises:
            KeyError: if the key is not valid.
        """
        instance: Instance = self.check_key(key)
        offset: Optional[int] = self._map.get(instance[1:])
        if offset is None:
            return None
        else:
            # the child state indexes into the CPD stored at `offset`
            return offset + instance[0]

    @property
    def parent_shape(self) -> Shape:
        """
        What is the shape of the parents.
        """
        return self._parent_shape

    @property
    def number_of_parent_states(self) -> int:
        """
        How many combinations of parent states.
        """
        return _multiply(self._parent_shape)

    @property
    def number_of_child_states(self) -> int:
        """
        Number of child random variable states.

        This is the same as the number of values in each conditional
        probability distribution. This is equivalent to `self.shape[0]`.

        Returns:
            the number of child states.
        """
        return self._child_size

    def get_cpd(self, parent_states: Key) -> List[float]:
        """
        Get the CPD conditioned on parent states indicated by `parent_states`.

        Args:
            parent_states: indicates the parent states.

        Returns:
            The conditioned probability distribution, as a new list.
            All zeros if no CPD is stored for the parent states.

        Raises:
            KeyError: if the key is not valid.
        """
        parent_instance: Instance = check_key(self._parent_shape, parent_states)
        offset: Optional[int] = self._map.get(parent_instance)
        child_size: int = self._child_size
        if offset is None:
            return [0] * child_size
        else:
            return self._values[offset:offset + child_size]

    def cpds(self) -> Iterator[Tuple[Instance, Sequence[float]]]:
        """
        Iterate over (parent_states, cpd) tuples.
        This will exclude zero CPDs.
        Do not change CPDs to (or from) zero while iterating over them.

        Returns:
            an iterator over pairs (instance, cpd) where,
            instance: indicates the state of the parent random variables.
            cpd: is the conditioned probability distribution, for the parent instance.
        """
        for parent_instance, offset in self._map.items():
            cpd = self._values[offset:offset + self._child_size]
            yield parent_instance, cpd

    # Mutators

    def clear(self) -> CPTPotentialFunction:
        """
        Set all values of the potential function to zero.

        Returns:
            self
        """
        self._map = {}
        self._values = []
        self._inv_map = []
        return self

    def set_uniform(self) -> CPTPotentialFunction:
        """
        Set each CPD to a uniform distribution.

        Returns:
            self
        """
        self.clear()
        # NOTE(review): `parent_instances` is not defined in this class;
        # presumably it is inherited - confirm.
        for parent_states in self.parent_instances():
            self.set_cpd_uniform(parent_states)
        return self

    def set_random(self, random: Callable[[], float], sparsity: float = 0) -> CPTPotentialFunction:
        """
        Set the values of the potential function to random CPDs.

        Args:
            random: is a stream of random numbers, assumed uniformly distributed in the interval [0, 1].
            sparsity: sets the expected proportion of probability values that are zero.

        Returns:
            self
        """
        self.clear()
        for parent_states in self.parent_instances():
            self.set_cpd_random(parent_states, random, sparsity)
        return self

    def set(self, *rows: Tuple[Key, Sequence[float]]) -> CPTPotentialFunction:
        """
        Calls self.set_cpd(parent_states, cpd) for each row (parent_states, cpd)
        in rows. Any unmentioned parent states will have zero probabilities.

        Example usage, assuming three Boolean random variables:
            pgm.Factor(x, y, z).set_cpt().set(
                # y  z    x[0] x[1]
                ((0, 0), (0.1, 0.9)),
                ((0, 1), (0.1, 0.9)),
                ((1, 0), (0.1, 0.9)),
                ((1, 1), (0.1, 0.9))
            )

        Args:
            *rows: are tuples (key, cpd) used to set the potential function values.

        Raises:
            ValueError: if a CPD is not valid.

        Returns:
            self
        """
        self.clear()
        for parent_states, cpd in rows:
            self.set_cpd(parent_states, cpd)
        return self

    def set_all(self, *cpds: Optional[Sequence[float]]) -> CPTPotentialFunction:
        """
        Set all CPDs using the given `cpds` which are taken to be in order of the parent states
        as enumerated by `self.parent_instances()`.

        If insufficient CPDs are provided then the remaining parent instantiations are taken to be
        impossible (i.e. not set and guaranteed zero).
        If too many CPDs are provided then the extras are ignored.
        Any list entry may be None, indicating 'guaranteed zero' for the associated parent states.

        Args:
            *cpds: are the CPDs used to set the potential function values.

        Raises:
            ValueError: if a CPD is not valid.

        Returns:
            self
        """
        self.clear()
        for parent_states, cpd in zip(self.parent_instances(), cpds):
            self.set_cpd(parent_states, cpd)
        return self

    def set_cpd(self, parent_states: Key, cpd: Optional[Sequence[float]]) -> CPTPotentialFunction:
        """
        Set the CPD of the given parent states to the given cpd.
        If cpd is None or all zeros, then this is equivalent to clear_cpd(parent_states).

        Args:
            parent_states: indicates the CPD to set, based on the parent states.
            cpd: is a conditioned probability distribution, or None indicating `guaranteed zero`.

        Raises:
            ValueError: if the CPD is not valid.
            KeyError if the key is not valid.

        Returns:
            self
        """
        parent_instance: Instance = check_key(self._parent_shape, parent_states)

        if cpd is None:
            self._clear_cpd(parent_instance)
            return self

        child_size: int = self._child_size
        if len(cpd) != child_size:
            raise ValueError(f'CPD incorrect size: expected {child_size}, got {len(cpd)}')
        if not all(0 <= value <= 1 for value in cpd):
            raise ValueError(f'not a valid CPD: {cpd!r}')

        total_value = sum(cpd)
        tolerance = self._tolerance

        # Inclusive comparison, so that an all-zero CPD is treated as
        # 'guaranteed zero' even when the tolerance is exactly zero.
        if total_value <= tolerance:
            self._clear_cpd(parent_instance)
            return self

        if total_value < 1 - tolerance or total_value > 1 + tolerance:
            raise ValueError(f'not a valid CPD: sum of values = {total_value}')

        offset: Optional[int] = self._map.get(parent_instance)
        if offset is None:
            # first CPD for these parent states: append to the end of storage
            offset = len(self._values)
            self._values.extend(cpd)
            self._map[parent_instance] = offset
            self._inv_map.append(parent_instance)
        else:
            # overwrite the existing CPD in place
            self._values[offset:offset + child_size] = cpd

        return self

    def clear_cpd(self, parent_states: Key) -> CPTPotentialFunction:
        """
        Set the CPD of the given parent_states to all 'guaranteed zero'.

        Args:
            parent_states: indicates the CPD to clear, based on the parent states.

        Raises:
            KeyError if the key is not valid.

        Returns:
            self
        """
        parent_instance: Instance = check_key(self._parent_shape, parent_states)
        self._clear_cpd(parent_instance)
        return self

    def set_cpd_uniform(self, parent_states: Key) -> CPTPotentialFunction:
        """
        Set the CPD of the given parent_states to a uniform CPD.

        Args:
            parent_states: indicates the CPD to set, based on the parent states.

        Raises:
            KeyError if the key is not valid.

        Returns:
            self
        """
        num_states = self.number_of_child_states
        cpd = [1.0 / num_states] * num_states
        return self.set_cpd(parent_states, cpd)

    def set_cpd_random(
            self,
            parent_states: Key,
            random: Callable[[], float],
            sparsity: float = 0,
    ) -> CPTPotentialFunction:
        """
        Set the CPD of the given parent_states to a random CPD.

        Args:
            parent_states: identifies the CPD being set.
            random: is a stream of random numbers, assumed uniformly distributed in the interval [0, 1].
            sparsity: sets the expected proportion of probability values that are zero.

        Returns:
            self
        """
        child_size: int = self.number_of_child_states
        cpd = np.zeros(child_size, dtype=np.float64)
        for i in range(child_size):
            # When sparsity is positive, one draw decides whether the entry
            # is non-zero and a second draw supplies its value. When sparsity
            # is not positive, only the value draw is made.
            if sparsity <= 0 or random() > sparsity:
                # the small offset keeps a selected entry strictly positive
                cpd[i] = 0.0000001 + random()
        sum_value = np.sum(cpd)
        if sum_value > 0:
            cpd /= sum_value
            return self.set_cpd(parent_states, cpd)
        else:
            # every entry was made zero
            return self.clear_cpd(parent_states)

    def _clear_cpd(self, parent_instance: Instance) -> None:
        """
        Remove the parent instance from the parameters.

        Storage is kept contiguous by moving the last stored CPD into
        the slot being vacated, then truncating the storage.
        """
        offset: Optional[int] = self._map.pop(parent_instance, None)
        if offset is None:
            # nothing to do
            return

        child_size: int = self._child_size
        end_offset: int = len(self._values) - child_size
        if offset != end_offset:
            # move the last CPD into the vacated slot
            end_parent_instance = self._inv_map[-1]
            self._values[offset:offset + child_size] = self._values[end_offset:]
            self._map[end_parent_instance] = offset
            self._inv_map[offset // child_size] = end_parent_instance

        # drop the (now redundant) last CPD
        self._inv_map.pop()
        del self._values[end_offset:]
3298
+
3299
+
3300
def default_pgm_name(pgm: PGM) -> str:
    """
    Build the name a PGM gets when its constructor is not given one.

    The name is the prefix 'PGM_' followed by `id(pgm)`.

    Args:
        pgm: a PGM object.

    Returns:
        a name for the PGM if none is given at construction time.
    """
    return f'PGM_{id(pgm)}'
3311
+
3312
+
3313
def check_key(shape: Shape, key: Key) -> Instance:
    """
    Convert the key into an instance, verifying it against the given shape.

    A key is valid when it has one state index per entry of `shape`, and
    each state index `i` satisfies `0 <= i < m` where `m` is the
    corresponding entry of `shape` (the number of states).

    Args:
        shape: the shape defining the state space.
        key: a key into the state space.

    Returns:
        A instance from the state space, as a tuple of state indexes, co-indexed with the given shape.

    Raises:
        KeyError if the key is not valid.
    """
    instance: Instance = _key_to_instance(key)
    # The upper bound is exclusive: a random variable with `size` states
    # has state indexes 0 .. size - 1.
    if len(instance) == len(shape) and all(
            0 <= state < size
            for state, size in zip(instance, shape)
    ):
        return tuple(instance)
    raise KeyError(f'not a valid key for shape {shape}: {key!r}')
3333
+
3334
+
3335
def valid_key(shape: Shape, key: Key) -> bool:
    """
    Test whether the given key is valid for the given shape.

    This delegates to `check_key` and reports whether it raised KeyError.

    Args:
        shape: the shape defining the state space.
        key: a key into the state space.

    Returns:
        True only if the key is valid for the given shape.
    """
    try:
        check_key(shape, key)
    except KeyError:
        return False
    else:
        return True
3351
+
3352
+
3353
def number_of_states(*rvs: RandomVariable) -> int:
    """
    Args:
        *rvs: the random variables defining the state space.

    Returns:
        What is the size of the state space, i.e., `multiply(len(rv) for rv in rvs)`.
    """
    sizes = map(len, rvs)
    return _multiply(sizes)
3359
+
3360
+
3361
def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterator[Instance]:
    """
    Enumerate instances of the given random variables.

    Each instance is a tuple of state indexes, co-indexed with the given random variables.

    The order is the natural index order (i.e., last random variable changing most quickly).

    Args:
        *rvs: the random variables to enumerate over.
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds state indexes
        co-indexed with the given random variables.
    """
    sizes = list(map(len, rvs))
    # the helper is called with the opposite flip sense to this function
    helper_flip: bool = not flip
    return _combos_ranges(sizes, flip=helper_flip)
3378
+
3379
+
3380
def rv_instances_as_indicators(*rvs: RandomVariable, flip: bool = False) -> Iterator[Sequence[Indicator]]:
    """
    Enumerate instances of the given random variables.

    Each instance is a tuple of indicators, co-indexed with the given random variables.

    The order is the natural index order (i.e., last random variable changing most quickly).

    Args:
        *rvs: the random variables to enumerate over.
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds random variable indicators
        co-indexed with the given random variables.
    """
    # the helper is called with the opposite flip sense to this function
    helper_flip: bool = not flip
    return _combos(rvs, flip=helper_flip)
3396
+
3397
+
3398
def _key_to_instance(key: Key) -> Instance:
    """
    Normalise a key into an instance tuple.

    A single `int` is wrapped in a one element tuple; anything else is
    passed to `tuple(...)`. No validation against a shape is performed.

    Args:
        key: a key into a state space.

    Returns:
        A instance from the state space, as a tuple of state indexes.

    Assumes:
        The key is valid for the implied state space.
    """
    return (key,) if isinstance(key, int) else tuple(key)
3415
+
3416
+
3417
def _natural_key_idx(shape: Shape, key: Key) -> int:
    """
    Compute the natural index of the given key, assuming the given shape.

    The index starts from the first state index, and for each later position
    is multiplied by that position's size and has that position's state index added.

    Args:
        shape: the shape defining the state space.
        key: a key into the state space.

    Returns:
        an index as per enumerated instances in their natural order, i.e.
        last random variable changing most quickly.

    Assumes:
        The key is valid for the shape.
    """
    states: Instance = _key_to_instance(key)
    index: int = states[0]
    later_positions = zip(shape[1:], states[1:])
    for size, state in later_positions:
        index = index * size + state
    return index
3437
+
3438
+
3439
def _zero_space(shape: Shape) -> int:
    """
    Return the size of the zero space of the given shape. This is the number
    of possible instances in the state space that do not have a zero in the instance.

    The zero space is the same as the shape but with one less state
    for each random variable, so its size is the product of `x - 1`
    over every `x` in the shape.

    Args:
        shape: the shape defining the state space.

    Returns:
        the size of the zero space.
    """
    reduced_sizes = (size - 1 for size in shape)
    return _multiply(reduced_sizes)
3454
+
3455
+
3456
+ def _normalise_potential_function(
3457
+ function: Union[DensePotentialFunction, SparsePotentialFunction],
3458
+ grouping_positions: Sequence[int],
3459
+ ) -> None:
3460
+ """
3461
+ Convert the potential function to a CPT with 'grouping_positions' nominating
3462
+ the parent random variables.
3463
+
3464
+ I.e., for each possible key of the function with the same value at each
3465
+ grouping position, the sum of values for matching keys in the factor is scaled
3466
+ to be 1 (or 0).
3467
+
3468
+ Parameter 'grouping_positions' are indices into `function.shape`. For example, the
3469
+ grouping positions of a factor with parent rvs 'conditioning_rvs', then
3470
+ grouping_positions = [i for i, rv in enumerate(factor.rvs) if rv in conditioning_rvs].
3471
+
3472
+ Args:
3473
+ function: the potential function to normalise.
3474
+ grouping_positions: indices into `function.shape`.
3475
+ """
3476
+ if len(grouping_positions) == 0:
3477
+ total = sum(
3478
+ function.param_value(param_idx)
3479
+ for param_idx in range(function.number_of_parameters)
3480
+ )
3481
+ if total != 0 and total != 1:
3482
+ for param_key, param_idx, param_value in function.keys_with_param:
3483
+ function.set_param_value(param_idx, param_value / total)
3484
+ else:
3485
+ group_sum = {}
3486
+ for param_key, param_idx, param_value in function.keys_with_param:
3487
+ group = tuple(param_key[i] for i in grouping_positions)
3488
+ group_sum[group] = group_sum.get(group, 0) + param_value
3489
+
3490
+ for param_key, param_idx, param_value in function.keys_with_param:
3491
+ group = tuple(param_key[i] for i in grouping_positions)
3492
+ total = group_sum[group]
3493
+ if total > 0:
3494
+ function.set_param_value(param_idx, param_value / total)