compiled-knowledge 4.0.0a20__cp313-cp313-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (178) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37523 -0
  4. ck/circuit/_circuit_cy.cp313-win32.pyd +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19824 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cp313-win32.pyd +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp313-win32.pyd +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp313-win32.pyd +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
ck/pgm.py ADDED
@@ -0,0 +1,3475 @@
1
+ """
2
+ For more documentation on this module, refer to the Jupyter notebook docs/4_PGM_advanced.ipynb.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import math
7
+ from abc import ABC, abstractmethod
8
+ from dataclasses import dataclass
9
+ from itertools import repeat as _repeat
10
+ from typing import Sequence, Tuple, Dict, Optional, overload, Set, Iterable, List, Union, Callable, \
11
+ Collection, Any, Iterator
12
+
13
+ import numpy as np
14
+
15
+ from ck.utils.iter_extras import (
16
+ combos_ranges as _combos_ranges, multiply as _multiply, combos as _combos
17
+ )
18
+ from ck.utils.np_extras import NDArrayFloat64, NDArrayUInt8
19
+
20
# Values that may be used as the states of a random variable.
State = Union[int, str, bool, float, None]

# An instance holds one state index per random variable, co-indexed with
# a known sequence of random variables.
Instance = Sequence[int]

# A key names an instance; a bare integer stands for a
# one-dimensional instance.
Key = Union[Sequence[int], int]

# The shape of a sequence of random variables (e.g., a PGM, Factor or PotentialFunction).
Shape = Sequence[int]

# Tolerance used when checking CPT sums.
DEFAULT_TOLERANCE: float = 1e-6
35
+
36
+
37
+ class PGM:
38
+ """
39
+ A probabilistic graphical model (PGM) represents a joint probability distribution over
40
+ a set of random variables. Specifically, a PGM is a factor graph with discrete random variables.
41
+
42
+ Add a random variable to a PGM, pgm, using `rv = pgm.new_rv(...)`.
43
+
44
+ Add a factor to the PGM, pgm, using `factor = pgm.new_factor(...)`.
45
+
46
+ A PGM may be given a human-readable name.
47
+ """
48
+
49
def __init__(self, name: Optional[str] = None):
    """
    Initialise a PGM that has no random variables, indicators or factors.

    Args:
        name: an optional name for the PGM. When it is None, the name is
            obtained by calling `default_pgm_name(self)`.
    """
    if name is None:
        # The default name is computed before any other attribute is set.
        name = default_pgm_name(self)
    self._name: str = name
    self._rvs: Tuple[RandomVariable, ...] = ()
    self._shape: Shape = ()
    self._indicators: Tuple[Indicator, ...] = ()
    self._factors: Tuple[Factor, ...] = ()
62
+
63
@property
def name(self) -> str:
    """
    The name given to (or generated for) this PGM at construction.

    Returns:
        the name of this PGM.
    """
    return self._name
70
+
71
@property
def number_of_rvs(self) -> int:
    """
    Count the random variables of this PGM.

    Returns:
        the number of random variables defined in this PGM.
    """
    return len(self._rvs)
78
+
79
@property
def shape(self) -> Shape:
    """
    The shape of this PGM.

    Returns:
        the stored shape, documented as the sequence of lengths of `self.rvs`.
    """
    return self._shape
86
+
87
@property
def number_of_indicators(self) -> int:
    """
    Count the indicators of this PGM.

    Returns:
        the number of indicators held by this PGM, documented as being
        equal to `sum(len(rv) for rv in self.rvs)`.
    """
    return len(self._indicators)
94
+
95
@property
def number_of_states(self) -> int:
    """
    The size of the state space of this PGM.

    Returns:
        the result of the module-level `number_of_states` applied to all
        random variables of this PGM, documented as
        `multiply(len(rv) for rv in self.rvs)`.
    """
    all_rvs = self._rvs
    return number_of_states(*all_rvs)
102
+
103
@property
def number_of_factors(self) -> int:
    """
    Count the factors of this PGM.

    Returns:
        the number of factors defined in this PGM.
    """
    return len(self._factors)
110
+
111
@property
def number_of_functions(self) -> int:
    """
    Count the potential functions yielded by `self.functions`.

    Returns:
        the number of potential functions in use by this PGM, including
        zero potential functions.
    """
    count = 0
    for _ in self.functions:
        count += 1
    return count
118
+
119
@property
def number_of_non_zero_functions(self) -> int:
    """
    Count the potential functions yielded by `self.non_zero_functions`.

    Returns:
        the number of potential functions in use by this PGM, excluding
        zero potential functions.
    """
    count = 0
    for _ in self.non_zero_functions:
        count += 1
    return count
126
+
127
@property
def rvs(self) -> Sequence[RandomVariable]:
    """
    The random variables of this PGM.

    Returns:
        the stored random variables, documented as being in `idx` order,
        which is the same as creation order.

    Ensures:
        `self.rvs[rv.idx]` is `rv` (as documented for this property).
    """
    return self._rvs
137
+
138
@property
def rv_log_sizes(self) -> Sequence[float]:
    """
    The base-2 logarithm of the length of each random variable.

    Returns:
        a list holding `log2(len(rv))` for each rv in `self.rvs`,
        co-indexed with `self.rvs`.
    """
    lengths = map(len, self.rvs)
    return list(map(math.log2, lengths))
145
+
146
@property
def indicators(self) -> Sequence[Indicator]:
    """
    The indicators of all random variables of this PGM.

    Returns:
        the stored indicators.

    Ensures (as documented for this property):
        indicators of the same random variable are adjacent,
        indicators of a random variable appear in state index order,
        random variables appear in the same order as `self.rvs`.
    """
    return self._indicators
158
+
159
@property
def factors(self) -> Sequence[Factor]:
    """
    The factors of this PGM.

    Returns:
        the stored factors, documented as being in `idx` order, which is
        the same as creation order.

    Ensures:
        `self.factors[factor.idx]` is `factor` (as documented for this property).
    """
    return self._factors
169
+
170
@property
def functions(self) -> Iterable[PotentialFunction]:
    """
    Iterate over the distinct potential functions used by the factors of
    this PGM, including zero potential functions.

    A function shared by several factors is yielded once only, at the
    position of the first factor that uses it. Functions are told apart
    by object identity.

    Returns:
        An Iterable over all potential functions (including zero potential functions).
    """
    yielded_ids: Set[int] = set()
    for factor in self._factors:
        candidate = factor.function
        candidate_id = id(candidate)
        if candidate_id in yielded_ids:
            continue
        yielded_ids.add(candidate_id)
        yield candidate
185
+
186
@property
def non_zero_functions(self) -> Iterable[PotentialFunction]:
    """
    Iterate over the distinct potential functions used by the factors of
    this PGM, skipping every ZeroPotentialFunction.

    A function shared by several factors is yielded once only, at the
    position of the first factor that uses it. Functions are told apart
    by object identity.

    Returns:
        An Iterable over all potential functions (excluding zero potential functions).
    """
    yielded_ids: Set[int] = set()
    for factor in self._factors:
        candidate = factor.function
        if isinstance(candidate, ZeroPotentialFunction):
            continue
        candidate_id = id(candidate)
        if candidate_id in yielded_ids:
            continue
        yielded_ids.add(candidate_id)
        yield candidate
201
+
202
def new_rv(self, name: str, states: Union[int, Sequence[State]]) -> RandomVariable:
    """
    Create a new random variable belonging to this PGM.

    This delegates to `RandomVariable(self, name, states)`. The documented
    contract is that the new random variable has an `idx` equal to the
    value of `self.number_of_rvs` just before it was added.

    Assumes:
        Provided states contain no duplicates.

    Args:
        name: a name for the random variable.
        states: either an integer number of states, `n`, in which case the
            states are 0, 1, ..., n-1, or a sequence of unique state values.

    Returns:
        a RandomVariable object belonging to this PGM.
    """
    new_random_variable = RandomVariable(self, name, states)
    return new_random_variable
222
+
223
def new_factor(self, *rvs: RandomVariable) -> Factor:
    """
    Create a new factor belonging to this PGM that connects the given
    random variables.

    This delegates to `Factor(self, *rvs)`. The documented contract is that
    the new factor starts with a ZeroPotentialFunction as its potential
    function (which may be changed through methods of the factor), and has
    an `idx` equal to the value of `self.number_of_factors` just before it
    was added.

    Assumes:
        The given random variables all belong to this PGM.
        The random variables contain no duplicates.

    Args:
        *rvs: the random variables.

    Returns:
        a Factor object belonging to this PGM.
    """
    created_factor = Factor(self, *rvs)
    return created_factor
245
+
246
def new_factor_implies(
    self,
    rv_1: RandomVariable,
    state_idxs_1: int | Collection[int],
    rv_2: RandomVariable,
    state_idxs_2: int | Collection[int],
) -> Factor:
    """
    Add a sparse 0/1 factor to this PGM representing the implication:
        rv_1 in state_idxs_1 ==> rv_2 in state_idxs_2.
    That is:
        factor[s1, s2] = 1, if s1 not in state_idxs_1 or s2 in state_idxs_2;
                       = 0, otherwise.

    Args:
        rv_1: The first random variable.
        state_idxs_1: state idxs of the first random variable
            (a single int is treated as a one-element collection).
        rv_2: The second random variable.
        state_idxs_2: state idxs of the second random variable
            (a single int is treated as a one-element collection).

    Returns:
        a Factor object belonging to this PGM, with a configured sparse potential function.
    """
    antecedent = (state_idxs_1,) if isinstance(state_idxs_1, int) else state_idxs_1
    consequent = (state_idxs_2,) if isinstance(state_idxs_2, int) else state_idxs_2

    factor = self.new_factor(rv_1, rv_2)
    table = factor.set_sparse()
    for i_1 in rv_1.state_range():
        for i_2 in rv_2.state_range():
            # Only entries where the implication holds are written.
            if i_1 not in antecedent or i_2 in consequent:
                table[i_1, i_2] = 1
    return factor
285
+
286
def new_factor_equiv(
    self,
    rv_1: RandomVariable,
    state_idxs_1: int | Collection[int],
    rv_2: RandomVariable,
    state_idxs_2: int | Collection[int],
) -> Factor:
    """
    Add a sparse 0/1 factor to this PGM representing the equivalence:
        rv_1 in state_idxs_1 <==> rv_2 in state_idxs_2.
    That is:
        factor[s1, s2] = 1, if (s1 in state_idxs_1) == (s2 in state_idxs_2);
                       = 0, otherwise.

    Args:
        rv_1: The first random variable.
        state_idxs_1: state idxs of the first random variable
            (a single int is treated as a one-element collection).
        rv_2: The second random variable.
        state_idxs_2: state idxs of the second random variable
            (a single int is treated as a one-element collection).

    Returns:
        a Factor object belonging to this PGM, with a configured sparse potential function.
    """
    group_1 = (state_idxs_1,) if isinstance(state_idxs_1, int) else state_idxs_1
    group_2 = (state_idxs_2,) if isinstance(state_idxs_2, int) else state_idxs_2

    factor = self.new_factor(rv_1, rv_2)
    table = factor.set_sparse()
    for i_1 in rv_1.state_range():
        member_1 = i_1 in group_1
        for i_2 in rv_2.state_range():
            if member_1 != (i_2 in group_2):
                continue
            table[i_1, i_2] = 1
    return factor
323
+
324
def new_factor_functional(
    self,
    function: Callable[..., int],
    result_rv: RandomVariable,
    *input_rvs: RandomVariable
) -> Factor:
    """
    Add a sparse 0/1 factor to this PGM representing:
        result_rv == function(*input_rvs).
    That is:
        factor[result_s, *input_s] = 1, if result_s == function(*input_s);
                                   = 0, otherwise.

    Args:
        function: a function from state indexes of the input random variables to a state index
            of the result random variable. The function should take the same number of arguments
            as `input_rvs` and return a state index for `result_rv`.
        result_rv: the random variable defining result values.
        *input_rvs: the random variables defining input values.

    Returns:
        a Factor object belonging to this PGM, with a configured sparse potential function.
    """
    factor = self.new_factor(result_rv, *input_rvs)
    f = factor.set_sparse()
    input_ranges = [list(rv.state_range()) for rv in input_rvs]
    for input_s in _combos(input_ranges):
        result_s = function(*input_s)
        # Build the key by unpacking so that any sequence of input state
        # indexes is accepted, not only a tuple.
        f[(result_s, *input_s)] = 1
    return factor
352
+
353
def indicator_pair(self, indicator: Indicator) -> Tuple[RandomVariable, State]:
    """
    Look up the random variable and state value that an indicator refers to.

    Args:
        indicator: the indicator to convert.

    Returns:
        (rv, state) where
        rv: is the random variable at `indicator.rv_idx` in this PGM.
        state: is the entry at `indicator.state_idx` of `rv.states`.
    """
    owner = self._rvs[indicator.rv_idx]
    return owner, owner.states[indicator.state_idx]
368
+
369
def indicator_str(self, *indicators: Indicator, sep: str = '=', delim: str = ', ') -> str:
    """
    Render indicators as a string, one `rv<sep>state` part per indicator.

    For example:
        pgm = PGM()
        a = pgm.new_rv('A', ('x', 'y', 'z'))
        b = pgm.new_rv('B', (3, 5))
        print(pgm.indicator_str(a[0], b[1], a[2]))
    will print:
        A=x, B=5, A=z

    Args:
        *indicators: the indicators to render.
        sep: the separator to use between the random variable and its state.
        delim: the delimiter to use when rendering multiple indicators.

    Returns:
        a string representation of the given indicators.
    """
    parts: List[str] = []
    for indicator in indicators:
        rv, state = self.indicator_pair(indicator)
        parts.append(f'{_clean_str(rv)}{sep}{_clean_str(state)}')
    return delim.join(parts)
396
+
397
def condition_str(self, *indicators: Indicator) -> str:
    """
    Render indicators as a string, grouping indicators by random variable.

    For example:
        pgm = PGM()
        a = pgm.new_rv('A', ('x', 'y', 'z'))
        b = pgm.new_rv('B', (3, 5))
        print(pgm.condition_str(a[0], b[1], a[2]))
    will print:
        A in {x, z}, B=5

    Args:
        *indicators: the indicators to render.

    Returns:
        a string representation of the given indicators, as a condition.
    """
    result: str = ''  # accumulated result string
    group: Set[Indicator] = set()  # indicators of the random variable being collected
    group_rv_idx: int = -1  # rv_idx of the current group, -1 means not yet started.

    # Visit indicators in ascending order so that indicators with the
    # same rv_idx are seen consecutively.
    for indicator in sorted(indicators):
        if indicator.rv_idx != group_rv_idx:
            if group_rv_idx >= 0:
                # A new random variable starts: flush the finished group.
                result = self._condition_str_rv(result, group)
            group = set()
            group_rv_idx = indicator.rv_idx
        group.add(indicator)

    # Flush the final group, if there were any indicators at all.
    if group_rv_idx >= 0:
        result = self._condition_str_rv(result, group)
    return result
429
+
430
def instance_str(
    self,
    instance: Instance,
    rvs: Optional[Sequence[RandomVariable]] = None,
    sep: str = '=',
    delim: str = ', ',
) -> str:
    """
    Render an instance as a string.

    The result looks something like 'X=x, Y=y, Z=z' where X, Y, and Z are
    random variables and x, y, and z are the states represented by the
    given instance.

    Args:
        instance: the instance to render.
        rvs: the random variables that the instance refers to. If rvs is None, then `self.rvs` is used.
        sep: the separator to use between the random variable and its state.
        delim: the delimiter to use when rendering multiple indicators.

    Returns:
        a string representation of the indicators implied by the given instance.
    """
    target_rvs = self.rvs if rvs is None else rvs
    assert len(instance) == len(target_rvs)
    # Indexing a random variable by a state index gives the indicator to render.
    implied_indicators = [rv[state_idx] for rv, state_idx in zip(target_rvs, instance)]
    return self.indicator_str(*implied_indicators, sep=sep, delim=delim)
461
+
462
def state_str(
    self,
    instance: Instance,
    rvs: Optional[Sequence[RandomVariable]] = None,
    delim: str = ', ',
) -> str:
    """
    Render the states of an instance.

    The result looks something like 'x, y, z' where x, y, and z are
    the states of the random variables represented by the given instance.
    Each state is rendered with `str`.

    Args:
        instance: the instance to render.
        rvs: the random variables that the instance refers to. If rvs is None, then `self.rvs` is used.
        delim: the delimiter to use when rendering multiple states.

    Returns:
        a string representation of the states implied by the given instance.
    """
    target_rvs = self.rvs if rvs is None else rvs
    assert len(instance) == len(target_rvs)
    rendered: List[str] = []
    for rv, state_idx in zip(target_rvs, instance):
        rendered.append(str(rv.states[state_idx]))
    return delim.join(rendered)
486
+
487
def instances(self, flip: bool = False) -> Iterable[Instance]:
    """
    Iterate over all possible instances of this PGM, in natural index
    order (i.e., last random variable changing most quickly).

    Args:
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds random variable state indexes
        co-indexed with this PGM's random variables, `self.rvs`.
    """
    sizes = tuple(map(len, self._rvs))
    # `_combos_ranges` is given the negation of this method's `flip`.
    return _combos_ranges(sizes, flip=not flip)
500
+
501
def instances_as_indicators(self, flip: bool = False) -> Iterable[Sequence[Indicator]]:
    """
    Iterate over all possible instances of this PGM, in the order given by
    `self.instances(flip=flip)`, with each instance converted to indicators
    by `self.state_idxs_to_indicators`.

    Args:
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds random variable indicators
        co-indexed with this PGM's random variables, `self.rvs`.
    """
    yield from map(self.state_idxs_to_indicators, self.instances(flip=flip))
515
+
516
def state_idxs_to_indicators(self, instance: Sequence[int]) -> Sequence[Indicator]:
    """
    Convert an instance (state indexes co-indexed with `self.rvs`) to the
    corresponding indicators, obtained by indexing each random variable
    with its state index.

    Assumes:
        The instance has the same length as `self.rvs`.
        The instance is co-indexed with `self.rvs`.

    Args:
        instance: the instance to convert to indicators.

    Returns:
        a tuple of indicators, co-indexed with `self.rvs`.
    """
    collected = []
    for rv, state_idx in zip(self._rvs, instance):
        collected.append(rv[state_idx])
    return tuple(collected)
532
+
533
def factor_values(self, key: Key) -> Iterable[float]:
    """
    Yield, for each factor of this PGM, the single value that the factor's
    potential function gives for the instance identified by the key.

    Args:
        key: the key defining an instance of this PGM; it is passed through
            `check_key` together with the shape of this PGM.

    Returns:
        an iterator over factor values, co-indexed with the factors of this PGM.
    """
    instance: Instance = check_key(self._shape, key)
    assert len(instance) == len(self._rvs)
    for factor in self._factors:
        # Project the PGM instance onto the random variables of this factor.
        factor_key: Sequence[int] = tuple(instance[rv.idx] for rv in factor.rvs)
        yield factor.function[factor_key]
550
+
551
@property
def is_structure_bayesian(self) -> bool:
    """
    Does the PGM structure correspond to a Bayesian network, where
    each factor is taken to be a CPT and the first random variable of a
    factor is taken to be the child.

    The factor parameters are not inspected, so this does not confirm that
    they correspond to valid CPTs.

    Returns:
        True only if:
            the number of factors equals the number of random variables,
            each random variable appears exactly once as the first random variable of a factor,
            `self._has_cycle` reports no cycle for any factor.
    """
    factor_count = self.number_of_factors
    rv_count = self.number_of_rvs

    # There must be exactly one factor per random variable.
    if factor_count != rv_count:
        return False

    # Map the idx of each child random variable to its factor. Two factors
    # sharing a child collapse to one entry, which the size check detects.
    child_to_factor: Dict[int, Factor] = {}
    for factor in self._factors:
        child_to_factor[factor.rvs[0].idx] = factor
    if len(child_to_factor) != rv_count:
        return False

    # The factors must form a DAG; `states` is working storage shared
    # across the `_has_cycle` calls.
    states: NDArrayUInt8 = np.zeros(factor_count, dtype=np.uint8)
    return not any(
        self._has_cycle(factor, child_to_factor, states)
        for factor in self._factors
    )
589
+
590
def factors_are_cpts(self, tolerance: float = DEFAULT_TOLERANCE) -> bool:
    """
    Check whether every potential function yielded by `self.functions`
    reports, via `is_cpt(tolerance)`, that it conforms to a Conditional
    Probability Table.

    Assumes:
        tolerance is non-negative.

    Args:
        tolerance: a tolerance when testing if values are equal to zero or one.

    Returns:
        True only if every potential function conforms to being a valid CPT.
    """
    for function in self.functions:
        if not function.is_cpt(tolerance):
            return False
    return True
605
+
606
def check_is_bayesian_network(self, tolerance: float = DEFAULT_TOLERANCE) -> bool:
    """
    Is this PGM a Bayesian network, in both structure and parameters.

    The factor parameters are only checked when the structure check passes.

    Assumes:
        tolerance is non-negative.

    Args:
        tolerance: a tolerance when testing if values are equal to zero or one.

    Returns:
        `self.is_structure_bayesian and self.factors_are_cpts(tolerance)`.
    """
    if not self.is_structure_bayesian:
        return False
    return self.factors_are_cpts(tolerance)
620
+
621
def value_product(self, key: Key) -> float:
    """
    Multiply together the values that the factors of this PGM give for the
    instance identified by the key (see `factor_values`).

    Args:
        key: the key defining an instance of this PGM.

    Returns:
        the product of factor values.
    """
    values = self.factor_values(key)
    return _multiply(values)
633
+
634
def value_product_indicators(self, *indicators: Indicator) -> float:
    """
    Return the product of factors, conditioned on the given indicators.

    For random variables not mentioned in the indicators, the result is the sum
    of the value product for each possible combination of states of the
    unmentioned random variables.

    If no indicators are provided, then the value of the partition function (z)
    is returned.

    If multiple indicators are provided for the same random variable, then all
    matching instances are summed.

    This method is documented as having the same semantics as
    `ProbabilitySpace.wmc` without conditioning.

    Warning:
        this is potentially computationally expensive as it marginalises the
        random variables not mentioned in the given indicators.

    Args:
        *indicators: are indicators from random variables of this PGM.

    Returns:
        the sum of `self.value_product(instance)` over every instance that is
        consistent with the given indicators.
    """
    # For each random variable, the state indexes selected by the
    # indicators, or None when no indicator mentions that random variable.
    selected: List[Optional[Set[int]]] = [None for _ in range(self.number_of_rvs)]
    for indicator in indicators:
        rv_idx: int = indicator.rv_idx
        if selected[rv_idx] is None:
            selected[rv_idx] = set()
        selected[rv_idx].add(indicator.state_idx)

    # The states to sum over: the selected states where given, otherwise
    # every state of the random variable.
    sum_space: List[List[int]] = []
    for chosen, rv in zip(selected, self.rvs):
        sum_space.append(list(rv.state_range() if chosen is None else chosen))

    # Accumulate the result.
    total = 0
    for instance in _combos(sum_space):
        total += self.value_product(instance)
    return total
705
+
706
def dump_synopsis(
        self,
        *,
        prefix: str = '',
        precision: int = 3,
        max_state_digits: int = 21,
) -> None:
    """
    Print a synopsis of the PGM.
    This is intended for demonstration and debugging purposes.

    The number of states is printed as an exact integer unless it has more than
    `max_state_digits` decimal digits, in which case it is printed in scientific
    notation. The base 2 logarithm of the number of states is printed as an integer
    when it is a whole number, otherwise as a float limited by `precision`.

    If the PGM has no states at all (e.g., a random variable with zero states), then
    the number of states is printed as `0` and its logarithm as `-inf`.

    Args:
        prefix: optional prefix for indenting all lines.
        precision: a limit on the render precision of floating point numbers.
        max_state_digits: a limit on the number of digits when showing number of states as an integer.
    """
    num_states: int = self.number_of_states
    number_of_parameters = sum(function.number_of_parameters for function in self.functions)
    number_of_nz_parameters = sum(function.number_of_parameters for function in self.non_zero_functions)

    if num_states <= 0:
        # The logarithm is undefined here; math.log10 and math.log2 would raise ValueError.
        num_states_str = f'{num_states:,}'
        log_2_num_states_str = '-inf'
    else:
        # Limit the precision when displaying the number of states.
        log_10_num_states = math.log10(num_states)
        if log_10_num_states > max_state_digits:
            exp = int(log_10_num_states)
            man = math.pow(10, log_10_num_states - exp)
            num_states_str = f'{man:,.{precision}f}e+{exp}'
        else:
            num_states_str = f'{num_states:,}'

        log_2_num_states = math.log2(num_states)
        if (
                log_2_num_states == 0
                or (
                    log_2_num_states == int(log_2_num_states)
                    and math.log10(log_2_num_states) <= max_state_digits
                )
        ):
            # A whole number of bits: show it as an integer.
            log_2_num_states_str = f'{int(log_2_num_states):,}'
        else:
            log_2_num_states_str = f'{log_2_num_states:,.{precision}f}'

    print(f'{prefix}name: {self.name}')
    print(f'{prefix}number of random variables: {self.number_of_rvs:,}')
    print(f'{prefix}number of indicators: {self.number_of_indicators:,}')
    print(f'{prefix}number of states: {num_states_str}')
    print(f'{prefix}log 2 of states: {log_2_num_states_str}')
    print(f'{prefix}number of factors: {self.number_of_factors:,}')
    print(f'{prefix}number of functions: {self.number_of_functions:,}')
    print(f'{prefix}number of non-zero functions: {self.number_of_non_zero_functions:,}')
    print(f'{prefix}number of parameters: {number_of_parameters:,}')
    print(f'{prefix}number of functions (excluding ZeroPotentialFunction): {self.number_of_non_zero_functions:,}')
    print(f'{prefix}number of parameters (excluding ZeroPotentialFunction): {number_of_nz_parameters:,}')
    print(f'{prefix}Bayesian structure: {self.is_structure_bayesian}')
    print(f'{prefix}CPT factors: {self.factors_are_cpts()}')
760
+
761
def dump(
        self,
        *,
        prefix: str = '',
        indent: str = ' ',
        show_function_values: bool = False,
        precision: int = 3,
        max_state_digits: int = 21,
) -> None:
    """
    Print a dump of the PGM.
    This is intended for demonstration and debugging purposes.

    The dump consists of a header line, the synopsis (see `dump_synopsis`),
    then one section each for the random variables, the factors and the
    non-zero functions, and finally a closing line.

    Args:
        prefix: optional prefix for indenting all lines.
        indent: additional prefix to use for extra indentation.
        show_function_values: if true, then the function values will be dumped.
        precision: a limit on the render precision of floating point numbers.
        max_state_digits: a limit on the number of digits when showing number of states as an integer.
    """
    level_1: str = prefix + indent
    level_2: str = level_1 + indent

    print(f'{prefix}PGM id={id(self)} name={self.name!r}')
    self.dump_synopsis(prefix=level_1, precision=precision, max_state_digits=max_state_digits)

    print(f'{prefix}random variables ({self.number_of_rvs})')
    for rv in self.rvs:
        line = f'{level_1}{rv.idx:>3} {rv.name!r} ({len(rv)})'
        if not rv.is_default_states():
            # Only list the states when they are not simply 0, 1, 2, ...
            line += ' [' + ', '.join(repr(s) for s in rv.states) + ']'
        print(line)

    print(f'{prefix}factors ({self.number_of_factors})')
    for factor in self.factors:
        rv_idxs = [rv.idx for rv in factor.rvs]
        function_ref = '<zero>'
        if not factor.is_zero:
            function = factor.function
            function_ref = f'{id(function)}: {function.__class__.__name__}'
        print(f'{level_1}{factor.idx:>3} rvs={rv_idxs} function={function_ref}')

    print(f'{prefix}functions ({self.number_of_functions})')
    # Functions are listed in order of object id, matching the ids shown for the factors.
    for function in sorted(self.non_zero_functions, key=id):
        print(f'{level_1}{id(function):>13}: {function.__class__.__name__}')
        function.dump(prefix=level_2, show_function_values=show_function_values, show_id_class=False)

    print(f'{prefix}end PGM id={id(self)}')
814
+
815
+ def _has_cycle(self, factor: Factor, child_to_factor: Dict[int, Factor], states: NDArrayUInt8) -> bool:
816
+ """
817
+ Support function for `is_structure_bayesian`.
818
+
819
+ A recursive depth-first-search to see if the factors form a DAG.
820
+
821
+ For a factor `f` the value of states[f.idx] is the search state.
822
+ Specifically:
823
+ state 0 => the factor has not been seen yet,
824
+ state 1 => the factor is seen but not fully processed,
825
+ state 2 => the factor is fully processed.
826
+
827
+ Args:
828
+ factor: the current Factor being checked.
829
+ child_to_factor: a dictionary from `RandomVariable.idx` to Factor
830
+ with that random variable as the child.
831
+ states: depth-first-search states, i.e., `states[i]` is the state of a factor with `Factor.idx == i`.
832
+ Returns:
833
+ True if a directed cycle is detected.
834
+ """
835
+ f_idx: int = factor.idx
836
+ match states.item(f_idx):
837
+ case 1:
838
+ return True
839
+ case 0:
840
+ states[f_idx] = 1
841
+ for parent in factor.rvs[1:]:
842
+ parent_factor = child_to_factor[parent.idx]
843
+ if self._has_cycle(parent_factor, child_to_factor, states):
844
+ return True
845
+ states[f_idx] = 2
846
+ return False
847
+ return False
848
+
849
+ def _register_rv(self, rv: RandomVariable) -> None:
850
+ """
851
+ Called by the constructor of RandomVariable to record a newly created Random variable
852
+ of this PGM.
853
+
854
+ Args:
855
+ rv: the newly constructed random variable.
856
+ """
857
+ assert rv.pgm is self
858
+ self._rvs += (rv,)
859
+ self._shape += (len(rv),)
860
+ self._indicators += rv.indicators
861
+
862
+ def _condition_str_rv(
863
+ self,
864
+ cur_str: str,
865
+ cur_rv: Set[Indicator],
866
+ sep: str = ', ',
867
+ equal: str = '=',
868
+ elem: str = ' in ',
869
+ ) -> str:
870
+ """
871
+ Support method for `self.condition_str`.
872
+
873
+ This is a method renders a condition defined by a set of indicators, of the same random variable.
874
+
875
+ Args:
876
+ cur_str: the string to append to.
877
+ cur_rv: a set of indicators, all from the same random variable.
878
+ sep: the separator string to use between condition components.
879
+ equal: the string to use for _rv_ = _state_.
880
+ elem: the string to use for _rv_ in _set_.
881
+
882
+ Returns:
883
+ `cur_str` appended with the new condition, `cur_rv`.
884
+ """
885
+ if cur_str != '':
886
+ cur_str += sep
887
+ if len(cur_rv) == 1:
888
+ cur_str += self.indicator_str(*cur_rv, sep=equal)
889
+ else:
890
+ _cur_rv = sorted(cur_rv)
891
+ rv = self._rvs[_cur_rv[0].rv_idx]
892
+ states_str: str = sep.join(_clean_str(rv.states[ind.state_idx]) for ind in _cur_rv)
893
+ cur_str += f'{_clean_str(rv)}{elem}{{{states_str}}}'
894
+ return cur_str
895
+
896
+
897
+ @dataclass(frozen=True, eq=True, slots=True)
898
+ class Indicator:
899
+ """
900
+ An indicator identifies a random variable being in a particular state.
901
+
902
+ Indicators are immutable and hashable.
903
+
904
+ Note that an Indicator does not know which PGM it came from, therefore indicators from one PGM
905
+ are interchangeable with indicators of another PGM so long as corresponding random variables of the
906
+ PGMs are co-indexed (created in the same order) and corresponding random variables have the same
907
+ states.
908
+
909
+ Fields:
910
+ rv_idx: `rv.idx` where `rv` is the random variable referenced by this indicator.
911
+ state_idx: the state index of the state referenced by this indicator.
912
+ """
913
+ rv_idx: int
914
+ state_idx: int
915
+
916
+ def __lt__(self, other) -> bool:
917
+ """
918
+ Define a sort order over indicators.
919
+ When sorted, indicators are ordered by random variable index, then by state index.
920
+ """
921
+ if isinstance(other, Indicator):
922
+ if self.rv_idx < other.rv_idx:
923
+ return True
924
+ if self.rv_idx > other.rv_idx:
925
+ return False
926
+ return self.state_idx < other.state_idx
927
+ return False
928
+
929
+
930
class RandomVariable(Sequence[Indicator]):
    """
    A random variable in a probabilistic graphical model (PGM).

    Random variables are immutable and hashable.

    A RandomVariable has a fixed, finite number of states, which are indexed
    by integers counting from zero.

    Every RandomVariable object belongs to exactly one PGM object, and has an
    index (counting from zero) which is its position in that PGM's list of
    random variables.

    A random variable behaves like a sequence of Indicators, where each indicator
    represents the random variable being in a particular state. For a random
    variable `rv`, `len(rv)` is the number of states and `rv[i]` is the Indicator
    representing that `rv` is in the ith state. When sliced, the result is a
    tuple, i.e. `rv[1:3] == (rv[1], rv[2])`.

    A RandomVariable has a name. This is for human convenience and has no
    functional purpose within a PGM.
    """

    def __init__(self, pgm: PGM, name: str, states: Union[int, Sequence[State]]):
        """
        Create a new random variable, in the given PGM.

        Args:
            pgm: the PGM that the random variable will belong to.
            name: a name for the random variable.
            states: either an integer number of states or a sequence of state values. If a
                single integer, `n`, is provided then the states will be 0, 1, ..., n-1.
                If a sequence of states are provided then the states must be unique.

        Raises:
            ValueError: if the provided states contain duplicates.
        """
        self._pgm: PGM = pgm
        self._name: str = name

        # An integer `n` is shorthand for the states 0, 1, ..., n-1.
        state_values = range(states) if isinstance(states, int) else states
        self._states: Sequence[State] = tuple(state_values)

        # Reverse lookup, from state to state index. Duplicated states collapse to a
        # single key, which is how duplicates are detected.
        number_of_states: int = len(self._states)
        self._inv_states: Dict[State, int] = dict(zip(self._states, range(number_of_states)))
        if len(self._inv_states) != number_of_states:
            raise ValueError('random variable states are not unique')

        # Read our positions from the PGM before registering with it.
        self._offset: int = pgm.number_of_indicators
        self._idx: int = pgm.number_of_rvs
        rv_idx: int = self._idx
        self._indicators: Sequence[Indicator] = tuple(
            Indicator(rv_idx, state_idx) for state_idx in range(number_of_states)
        )

        # Register self with our PGM
        # noinspection PyProtectedMember
        pgm._register_rv(self)

    @property
    def pgm(self) -> PGM:
        """
        Returns:
            The PGM that owns this random variable.
        """
        return self._pgm

    @property
    def name(self) -> str:
        """
        Returns:
            The name given to this random variable.
        """
        return self._name

    @property
    def idx(self) -> int:
        """
        Returns:
            The position of this random variable in its PGM.

        Ensures:
            `self.pgm.rvs[self.idx] is self`.
        """
        return self._idx

    @property
    def offset(self) -> int:
        """
        Returns:
            The position, in the PGM's indicators, of the first indicator of this random variable.

        Ensures:
            `self.pgm.indicators[self.offset + i] is self[i] for i in range(len(self))`.
        """
        return self._offset

    @property
    def states(self) -> Sequence[State]:
        """
        Returns:
            The states of this random variable, ordered by state index.
        """
        return self._states

    @property
    def indicators(self) -> Sequence[Indicator]:
        """
        Returns:
            The indicators of this random variable, ordered by state index.
        """
        return self._indicators

    def state_range(self) -> Iterable[int]:
        """
        Iterate over the state indexes of this random variable, in order.

        Returns:
            range(len(self))
        """
        number_of_states = len(self._states)
        return range(number_of_states)

    def factors(self) -> Iterable[Factor]:
        """
        Iterate over the factors that this random variable participates in.
        This method performs a search through all `self.pgm.factors`.

        Returns:
            an iterator over factors.
        """
        yield from (factor for factor in self._pgm.factors if self in factor.rvs)

    def markov_blanket(self) -> Set[RandomVariable]:
        """
        Return the set of random variables that are connected
        to this random variable by a factor.
        This method performs a search through all `self.pgm.factors`.

        Returns:
            a set of random variables connected to this random variable by any factor, excluding self.
        """
        neighbours = set().union(*(factor.rvs for factor in self.factors()))
        neighbours.discard(self)
        return neighbours

    def state_idx(self, state: State) -> int:
        """
        Returns:
            the state index of the given state of this random variable.

        Assumes:
            the given state is a state of this random variable.
        """
        inverse = self._inv_states
        return inverse[state]

    def is_default_states(self) -> bool:
        """
        Are the states of this random variable the default states.
        I.e., `self.states[i] == i, for all 0 <= i < len(self)`.

        Returns:
            True only if the states are the same as the state indexes.
        """
        for position, state in enumerate(self._states):
            if not position == state:
                return False
        return True

    def __str__(self) -> str:
        """
        Returns:
            the name of this random variable.
        """
        return self._name

    def __call__(self, state: State) -> Indicator:
        """
        Get the indicator for the given state.
        This is equivalent to self[self.state_idx(state)].

        Returns:
            an indicator of this random variable.

        Assumes:
            the given state is a state of this random variable.
        """
        position: int = self._inv_states[state]
        return self._indicators[position]

    def __hash__(self) -> int:
        """
        A random variable is hashable. Its hash is its index in its PGM.
        """
        return self._idx

    def __eq__(self, other) -> bool:
        """
        Two random variables are equal only if they are the same object.
        """
        return self is other

    def equivalent(self, other: RandomVariable | Sequence[Indicator]) -> bool:
        """
        Two random variables are equivalent if their indicators are equal. Only
        random variable indexes and state indexes are checked.

        This ignores the names of the random variables and the names of their states.
        This means their indicators will work correctly in slot maps, even
        if from different PGMs.

        Args:
            other: either a random variable or a sequence of Indicators.

        Returns:
            True only if they represent the same sequence of indicators.
        """
        if isinstance(other, RandomVariable):
            # The indicators of a random variable are determined by its index and number of states.
            return self.idx == other.idx and len(self) == len(other)

        mine = self._indicators
        if len(mine) != len(other):
            return False
        return all(indicator == other[i] for i, indicator in enumerate(mine))

    def __len__(self) -> int:
        """
        Returns:
            Number of states (or equivalently, the number of indicators) of this random variable.
        """
        return len(self._states)

    def __iter__(self) -> Iterator[Indicator]:
        """
        Iterate over the indicators of this random variable, in state index order.
        """
        return iter(self._indicators)

    @overload
    def __getitem__(self, index: int) -> Indicator:
        ...

    @overload
    def __getitem__(self, index: slice) -> Sequence[Indicator]:
        ...

    def __getitem__(self, index):
        """
        Get the indexed (or sliced) indicators.
        """
        return self._indicators[index]

    def index(self, value: Any, start: int = 0, stop: int = -1) -> int:
        """
        Returns the first index of `value`.
        Raises ValueError if the value is not present.
        Contracted by Sequence[Indicator].

        A negative `stop` is interpreted as `len(self) + stop + 1`, so the
        default of -1 means searching up to the end.

        Warning:
            This method is different to `self.idx`.
        """
        if isinstance(value, Indicator) and value.rv_idx == self._idx:
            position: int = value.state_idx
            upper: int = stop if stop >= 0 else len(self) + stop + 1
            if 0 <= position < len(self) and start <= position < upper:
                return position
        raise ValueError(f'{value!r} is not an indicator of the random variable')

    def count(self, value: Any) -> int:
        """
        Returns the number of occurrences of `value`, which is either 0 or 1.
        Contracted by Sequence[Indicator].
        """
        is_mine = (
            isinstance(value, Indicator)
            and value.rv_idx == self._idx
            and 0 <= value.state_idx < len(self)
        )
        return 1 if is_mine else 0
1207
+
1208
+
1209
class RVMap(Sequence[RandomVariable]):
    """
    Wrap a PGM to provide convenient access to PGM random variables.

    An RVMap of a PGM behaves exactly like the PGM `rvs` property. That is, it
    behaves like a sequence of RandomVariable objects.

    If the underlying PGM is updated, then the RVMap will automatically update.

    Additionally, an RVMap enables access to the PGM random variables via the name
    of each random variable.

    For example, if `pgm.rvs[1]` is a random variable named `xray`, then
    ```
    rvs = RVMap(pgm)

    # These all retrieve the same random variable object.
    xray = rvs[1]
    xray = rvs('xray')
    xray = rvs.xray
    ```

    To use an RVMap on a PGM, the variable names must be unique across the PGM,
    and must not clash with the names of the RVMap's own attributes and methods.
    """

    def __init__(self, pgm: PGM, ignore_case: bool = False):
        """
        Construct an RVMap for the given PGM.

        Args:
            pgm: the PGM to wrap.
            ignore_case: if true, the variable names are not case-sensitive.

        Raises:
            RuntimeError: if the random variable names of the PGM are not unique
                or clash with reserved names.
        """
        self._pgm: PGM = pgm
        self._ignore_case: bool = ignore_case
        self.__rv_map: Dict[str, RandomVariable] = {}
        # Names already used by this object cannot be used as random variable names,
        # as normal attribute lookup would find them before the random variable.
        self._reserved_names: Set[str] = {self._clean_name(name) for name in dir(self)}

        # Force the rv map cache to be updated.
        # This may raise an exception.
        _ = self._rv_map

    def _clean_name(self, name: str) -> str:
        """
        Adjust the case of the given name as needed.
        """
        return name.lower() if self._ignore_case else name

    @property
    def _rv_map(self) -> Dict[str, RandomVariable]:
        """
        Get the cached rv map, updating as needed if the PGM changed.

        A change is detected by comparing the size of the cached map to the
        number of random variables of the PGM.

        Returns:
            a mapping from random variable name to random variable

        Raises:
            RuntimeError: if the random variable names are not unique
                or clash with reserved names.
        """
        rvs = self._pgm.rvs
        if len(self.__rv_map) != len(rvs):
            # There is a difference between the map and the PGM - create a new map.
            rv_map: Dict[str, RandomVariable] = {self._clean_name(rv.name): rv for rv in rvs}
            if len(rv_map) != len(rvs):
                raise RuntimeError('random variable names are not unique')
            if not self._reserved_names.isdisjoint(rv_map.keys()):
                raise RuntimeError('random variable names clash with reserved names.')
            # Only cache a map that passed validation. Otherwise an invalid map with
            # the right size would be returned, unchecked, by every later access.
            self.__rv_map = rv_map
        return self.__rv_map

    def new_rv(self, name: str, states: Union[int, Sequence[State]]) -> RandomVariable:
        """
        As per `PGM.new_rv`.
        Delegate creating a new random variable to the PGM.

        Returns:
            a RandomVariable object belonging to the PGM.
        """
        return self._pgm.new_rv(name, states)

    def __len__(self) -> int:
        """
        Returns:
            the number of random variables of the PGM.
        """
        return len(self._pgm.rvs)

    def __getitem__(self, index: int) -> RandomVariable:
        """
        Get the indexed random variable, as per indexing the PGM `rvs`.
        """
        return self._pgm.rvs[index]

    def items(self) -> Iterable[Tuple[str, RandomVariable]]:
        """
        Returns:
            the (name, random variable) pairs of the rv map.
        """
        return self._rv_map.items()

    def keys(self) -> Iterable[str]:
        """
        Returns:
            the random variable names, as keyed in the rv map.
        """
        return self._rv_map.keys()

    def values(self) -> Iterable[RandomVariable]:
        """
        Returns:
            the random variables of the rv map.
        """
        return self._rv_map.values()

    def get(self, rv_name: str, default=None):
        """
        Get the random variable with the given name, or `default` if there is none.
        """
        return self._rv_map.get(self._clean_name(rv_name), default)

    def __call__(self, rv_name: str) -> RandomVariable:
        """
        Get the random variable with the given name.

        Raises:
            KeyError: if there is no random variable with the given name.
        """
        return self._rv_map[self._clean_name(rv_name)]

    def __getattr__(self, rv_name: str) -> RandomVariable:
        """
        Get the random variable with the given name, using attribute syntax.

        This is only called when normal attribute lookup fails.

        Raises:
            AttributeError: if there is no random variable with the given name.
        """
        if rv_name in ('_pgm', '_ignore_case', '_RVMap__rv_map', '_reserved_names'):
            # Our own state is missing, so this object is not (yet) initialised.
            # Looking in the rv map would need that state and recurse without end.
            raise AttributeError(rv_name)
        try:
            return self(rv_name)
        except KeyError:
            # Attribute lookup must fail with AttributeError so that
            # hasattr, getattr with a default, and the copy module work.
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute or random variable {rv_name!r}'
            ) from None
1306
+
1307
+
1308
class Factor:
    """
    A PGM factor over one or more random variables declares a relationship between
    those variables. A Factor also has a potential function associated with
    it, which defines a real-number value for each combination of states of
    the random variables.

    The default potential function for a factor is a ZeroPotentialFunction that
    is created with, and unique to, the factor.

    The order of a Factor's random variables is important as many things will be
    co-indexed with the random variables. For example, the shape of a Factor is
    the tuple of random variable lengths.

    Note that multiple factors may share a potential function, so long as they all
    belong to the same PGM object and have the same shape.
    """

    def __init__(self, pgm: PGM, *rvs: RandomVariable):
        """
        Add a new factor to the given PGM.

        Args:
            pgm: the PGM that the factor will belong to.
            *rvs: the random variables.

        Raises:
            ValueError: if a random variable is duplicated, if no random variable
                is given, or if a random variable does not belong to the given PGM.
        """
        if len(set(rvs)) != len(rvs):
            raise ValueError('duplicated random variable in factor')
        if not rvs:
            raise ValueError('must be at least one random variable')
        if not all(rv.pgm is pgm for rv in rvs):
            raise ValueError('random variable not from the same PGM')

        self._pgm: PGM = pgm
        self._idx: int = pgm.number_of_factors
        self._rvs: Sequence[RandomVariable] = tuple(rvs)
        self._shape: Shape = tuple(map(len, rvs))

        # Every factor keeps its own zero potential function, which is its initial function.
        zero_function: ZeroPotentialFunction = ZeroPotentialFunction(self)
        self._zero_potential_function: ZeroPotentialFunction = zero_function
        self._potential_function: PotentialFunction = zero_function

        # Register self with our PGM
        # noinspection PyProtectedMember
        pgm._factors += (self,)

    @property
    def rvs(self) -> Sequence[RandomVariable]:
        """
        Returns:
            The random variables of this factor, in factor order.
        """
        return self._rvs

    @property
    def pgm(self) -> PGM:
        """
        Returns:
            The PGM that owns this factor.
        """
        return self._pgm

    @property
    def idx(self) -> int:
        """
        Returns:
            The position of this factor in its PGM.

        Ensures:
            `self.pgm.factors[self.idx] is self`.
        """
        return self._idx

    @property
    def shape(self) -> Shape:
        """
        Returns:
            The number of states of each random variable, co-indexed with `self.rvs`.
        """
        return self._shape

    @property
    def number_of_states(self) -> int:
        """
        How many distinct states are covered by this Factor,
        as reported by its potential function.
        """
        return self._potential_function.number_of_states

    def __str__(self) -> str:
        """
        Return a human-readable string to represent this factor.
        This is intended mainly for debugging purposes.
        """
        quoted_names = [repr(str(rv)) for rv in self._rvs]
        return '(' + ', '.join(quoted_names) + ')'

    def __len__(self) -> int:
        """
        Returns:
            the number of random variables.
        """
        return len(self._rvs)

    @overload
    def __getitem__(self, index: int) -> RandomVariable:
        ...

    @overload
    def __getitem__(self, index: slice) -> Sequence[RandomVariable]:
        ...

    def __getitem__(self, index):
        """
        Get the indexed (or sliced) random variables of this factor.
        """
        return self._rvs[index]

    def instances(self, flip: bool = False) -> Iterable[Instance]:
        """
        Iterate over all possible instances, in natural index order (i.e.,
        last random variable changing most quickly).
        This is delegated to the factor's potential function.

        Args:
            flip: if true, then first random variable changes most quickly

        Returns:
            an iterator over tuples, each tuple holds random variable
            state indexes, co-indexed with this object's shape, i.e., self.shape.
        """
        return self.function.instances(flip)

    def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
        """
        Iterate over all possible instances of the parent random variables, in
        natural index order (i.e., last random variable changing most quickly).
        This is delegated to the factor's potential function.

        Args:
            flip: if true, then first random variable changes most quickly

        Returns:
            an iteration over tuples, each tuple holds random variable states
            co-indexed with this object's 'parent' shape, i.e., `self.shape[1:]`.
        """
        return self.function.parent_instances(flip)

    @property
    def is_zero(self) -> bool:
        """
        Is the potential function of this factor set to the factor's own 'zero' potential function.
        """
        return self._potential_function is self._zero_potential_function

    @property
    def function(self) -> PotentialFunction:
        """
        Returns:
            The potential function currently used by this factor.
        """
        return self._potential_function

    @function.setter
    def function(self, function: PotentialFunction | Factor) -> None:
        """
        Set the potential function for this PGM factor to the given potential function,
        or to the potential function of the given factor.

        If the resulting function is a ZeroPotentialFunction, then this factor
        is reset to its own zero potential function.

        Raises:
            ValueError: if the function belongs to a different PGM or has a different shape.
        """
        candidate = function.function if isinstance(function, Factor) else function
        assert isinstance(candidate, PotentialFunction)

        if self._potential_function is candidate:
            # nothing to do
            return

        if candidate.pgm is not self._pgm:
            raise ValueError('the given function is not of the same PGM as the factor')

        if candidate.shape != self._shape:
            raise ValueError(f'incorrect function shape: expected {self._shape}, got {candidate.shape}')

        if isinstance(candidate, ZeroPotentialFunction):
            self.set_zero()
            return
        self._potential_function = candidate

    def set_zero(self) -> ZeroPotentialFunction:
        """
        Set the factor's potential function to its original ZeroPotentialFunction.

        Returns:
            the potential function.
        """
        zero_function = self._zero_potential_function
        self._potential_function = zero_function
        return zero_function

    def set_dense(self) -> DensePotentialFunction:
        """
        Set the potential function to a new `DensePotentialFunction` object.

        Returns:
            the potential function.
        """
        new_function = DensePotentialFunction(self)
        self._potential_function = new_function
        return new_function

    def set_sparse(self) -> SparsePotentialFunction:
        """
        Set the potential function to a new `SparsePotentialFunction` object.

        Returns:
            the potential function.
        """
        new_function = SparsePotentialFunction(self)
        self._potential_function = new_function
        return new_function

    def set_compact(self) -> CompactPotentialFunction:
        """
        Set the potential function to a new `CompactPotentialFunction` object.

        Returns:
            the potential function.
        """
        new_function = CompactPotentialFunction(self)
        self._potential_function = new_function
        return new_function

    def set_clause(self, *key: int) -> ClausePotentialFunction:
        """
        Set the potential function to a new `ClausePotentialFunction` object.

        Args:
            *key: defines the random variable states of the clause. The key is a sequence of
                random variable state indexes, co-indexed with `Factor.rvs`.

        Returns:
            the potential function.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.
        """
        new_function = ClausePotentialFunction(self, key)
        self._potential_function = new_function
        return new_function

    def set_cpt(self, tolerance: float = DEFAULT_TOLERANCE) -> CPTPotentialFunction:
        """
        Set the potential function to a new `CPTPotentialFunction` object.

        Args:
            tolerance: a tolerance when testing if values are equal to zero or one.

        Returns:
            the potential function.

        Raises:
            ValueError: if tolerance is negative.
        """
        new_function = CPTPotentialFunction(self, tolerance)
        self._potential_function = new_function
        return new_function
1562
+
1563
+
1564
@dataclass(frozen=True)
class ParamId:
    """
    Identifies a single parameter of a potential function.

    Parameter identifiers uniquely identify every parameter within a PGM.

    A ParamId is immutable and hashable, and two ParamId objects are
    equal when both of their fields are equal.
    """
    function_id: int
    param_idx: int
1575
+
1576
+
1577
+ class PotentialFunction(ABC):
1578
+ """
1579
+ A potential function defines the potential values for a Factor, where
1580
+ a factor joins one or more variables of a PGM.
1581
+
1582
+ A potential function may be shared by several Factors of a PGM,
1583
+ i.e., can be applied to multiple variables.
1584
+
1585
+ The `shape` of a potential function is a tuple of integers which defines
1586
+ the number of variables, len(shape), and the number of states of each
1587
+ variable, shape[i].
1588
+
1589
+ The potential function value for variable states (x = i, y = j, ...) is given by
1590
+ self[i, j, ...], i.e., self.__getitem__((i, j, ...)). The tuple, (i, j, ...), is
1591
+ known as a Key.
1592
+
1593
+ The values of a potential function are defined by potential function parameters.
1594
+ The number of potential function parameters is given by number_of_parameters.
1595
+ The value of each parameter is given by get_param(i), where i is the parameter index.
1596
+
1597
+ Every valid key of the potential function is either mapped to a parameter or is
1598
+ "guaranteed zero" which means that the value is zero and cannot be changed by changing
1599
+ the values of the potential function's parameters.
1600
+ """
1601
+
1602
def __init__(self, factor: Factor):
    """
    Create a potential function compatible with the given factor.

    The PGM and shape are copied from the factor, and the number of
    states is computed from the shape.

    Ensures:
        Does not hold a reference to the given factor.
        Does not register the potential function with the PGM.

    Args:
        factor: the factor that this potential function is compatible with.
    """
    pgm: PGM = factor.pgm
    shape: Shape = factor.shape

    self._pgm: PGM = pgm
    self._shape: Shape = shape
    self._number_of_states = _multiply(shape)
1616
+
1617
+ @property
1618
+ def pgm(self) -> PGM:
1619
+ """
1620
+ Returns:
1621
+ The PGM that this potential function belong to.
1622
+ """
1623
+ return self._pgm
1624
+
1625
+ @property
1626
+ def shape(self) -> Shape:
1627
+ """
1628
+ Returns:
1629
+ The shape of this potential function.
1630
+ """
1631
+ return self._shape
1632
+
1633
+ @property
1634
+ def number_of_rvs(self) -> int:
1635
+ """
1636
+ Returns:
1637
+ The number of random variables in this potential function.
1638
+ """
1639
+ return len(self._shape)
1640
+
1641
+ @property
1642
+ def number_of_states(self) -> int:
1643
+ """
1644
+ How many distinct states are covered by this potential function.
1645
+
1646
+ Returns:
1647
+ The size of the state space of this potential function.
1648
+ """
1649
+ return self._number_of_states
1650
+
1651
+ @property
1652
+ def number_of_parent_states(self) -> int:
1653
+ """
1654
+ How many distinct states are covered by this potential function parents,
1655
+ i.e., excluding the first random variable.
1656
+
1657
+ Returns:
1658
+ The size of the state space of this potential function parent random variables.
1659
+ """
1660
+ return _multiply(self._shape[1:])
1661
+
1662
+ def count_usage(self) -> int:
1663
+ """
1664
+ Check all PGM factors to count the number of times that this potential function
1665
+ is used.
1666
+
1667
+ Returns:
1668
+ the number of factors that use this potential function.
1669
+ """
1670
+ return sum(1 for factor in self._pgm.factors if factor.function is self)
1671
+
1672
+ def check_key(self, key: Key) -> Instance:
1673
+ """
1674
+ Convert the key into an instance.
1675
+
1676
+ Arg:
1677
+ key: defines an instance in the state space of the potential function.
1678
+
1679
+ Returns:
1680
+ an instance, which is a tuple of state indexes, co-indexed with `self.rvs`.
1681
+
1682
+ Raises:
1683
+ KeyError: if the key is not valid for the shape of the factor.
1684
+ """
1685
+ return check_key(self._shape, key)
1686
+
1687
+ def valid_key(self, key: Key) -> bool:
1688
+ """
1689
+ Is the given key valid.
1690
+
1691
+ Arg:
1692
+ key: defines an instance in the state space of the potential function.
1693
+
1694
+ Returns:
1695
+ True only if the given key is valid.
1696
+ """
1697
+ return valid_key(self._shape, key)
1698
+
1699
+ def valid_parameter(self, param_idx: int) -> bool:
1700
+ """
1701
+ Is the given parameter index valid.
1702
+
1703
+ Arg:
1704
+ param_idx: a parameter index.
1705
+
1706
+ Returns:
1707
+ True only if `0 <= param_idx < self.number_of_parameters`.
1708
+ """
1709
+ return 0 <= param_idx < self.number_of_parameters
1710
+
1711
+ @property
1712
+ def is_sparse(self) -> bool:
1713
+ """
1714
+ Are there any 'guaranteed zero' parameters values.
1715
+
1716
+ Returns:
1717
+ True only if `self.number_of_not_guaranteed_zero < self._number_of_states`.
1718
+ """
1719
+ return self.number_of_not_guaranteed_zero < self._number_of_states
1720
+
1721
+ @property
1722
+ @abstractmethod
1723
+ def number_of_not_guaranteed_zero(self) -> int:
1724
+ """
1725
+ How many of the states of this potential function are not 'guaranteed zero'.
1726
+ That is, how many keys are associated with a parameter.
1727
+
1728
+ Returns:
1729
+ The number of valid keys that are associated with a parameter.
1730
+
1731
+ Ensures:
1732
+ 0 <= self.number_of_not_guaranteed_zero <= self.number_of_states.
1733
+ """
1734
+ ...
1735
+
1736
+ @property
1737
+ @abstractmethod
1738
+ def number_of_parameters(self) -> int:
1739
+ """
1740
+ Get the number of parameters defining the potential function values.
1741
+ Each valid key of the function maps either to a parameter
1742
+ is 'guaranteed zero'.
1743
+
1744
+ Returns:
1745
+ The number of parameters.
1746
+
1747
+ Ensures:
1748
+ 0 <= self.number_of_parameters <= self.number_of_not_guaranteed_zero.
1749
+ """
1750
+ ...
1751
+
1752
+ @property
1753
+ @abstractmethod
1754
+ def params(self) -> Iterable[Tuple[int, float]]:
1755
+ """
1756
+ Iterate the parameters and their associated values.
1757
+
1758
+ Returns:
1759
+ An iterable over (param_idx, value) tuples, for every possible parameter.
1760
+
1761
+ Assumes:
1762
+ The potential function is not mutated while iterating.
1763
+ """
1764
+ ...
1765
+
1766
+ @property
1767
+ @abstractmethod
1768
+ def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
1769
+ """
1770
+ Iterate the keys that have a parameter associated with them.
1771
+
1772
+ Returns:
1773
+ An iterable over (key, param_idx, value) tuples, for every key with an associated parameter.
1774
+
1775
+ Assumes:
1776
+ The potential function is not mutated while iterating.
1777
+ """
1778
+ ...
1779
+
1780
+ @abstractmethod
1781
+ def __getitem__(self, key: Key) -> float:
1782
+ """
1783
+ Get the potential function value for the given instance key.
1784
+
1785
+ Arg:
1786
+ key: defines an instance in the state space of the potential function.
1787
+
1788
+ Returns:
1789
+ The value of the potential function for the given key.
1790
+
1791
+ Assumes:
1792
+ self.valid_key(key).
1793
+ """
1794
+ ...
1795
+
1796
+ @abstractmethod
1797
+ def param_value(self, param_idx: int) -> float:
1798
+ """
1799
+ Get the potential function value by parameter index.
1800
+
1801
+ Arg:
1802
+ param_idx: a parameter index.
1803
+
1804
+ Assumes:
1805
+ `self.valid_parameter(param_idx)`.
1806
+ """
1807
+ ...
1808
+
1809
+ @abstractmethod
1810
+ def param_idx(self, key: Key) -> Optional[int]:
1811
+ """
1812
+ Get the parameter index for the given potential function random variables states (key).
1813
+
1814
+ Arg:
1815
+ key: defines an instance in the state space of the potential function.
1816
+
1817
+ Returns:
1818
+ either `None` indicating a "guaranteed zero" value, or the parameter index holding
1819
+ the potential function value for the key.
1820
+ """
1821
+ ...
1822
+
1823
+ def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
1824
+ """
1825
+ Is the potential function set with parameters values conforming to a
1826
+ Conditional Probability Table.
1827
+
1828
+ Every parameter value must be non-negative.
1829
+ For every state of the parent (non-first slots)
1830
+ the sum of the parameters over the child states (first slots)
1831
+ must be either 1 or 0.
1832
+
1833
+ Assumes:
1834
+ tolerance is non-negative.
1835
+
1836
+ Args:
1837
+ tolerance: a tolerance when testing if values are equal to zero or one.
1838
+
1839
+ Returns:
1840
+ True only if the potential function is compatible with being a CPT.
1841
+ """
1842
+ # This default implementation calculates the result the long way, by checking
1843
+ # every valid key of the potential function.
1844
+ # Subclasses may override this implementation.
1845
+ low: float = 1.0 - tolerance
1846
+ high: float = 1.0 + tolerance
1847
+ for parent_state in self.parent_instances():
1848
+ total: float = sum(
1849
+ self[(state,) + tuple(parent_state)]
1850
+ for state in range(self.shape[0])
1851
+ )
1852
+ if not ((low <= total <= high) or (0 <= total <= tolerance)):
1853
+ return False
1854
+ return True
1855
+
1856
+ def natural_param_idx(self, key: Key) -> int:
1857
+ """
1858
+ Get the natural parameter index for the given key. This is the same index as used
1859
+ by a DensePotentialFunction with the same shape.
1860
+
1861
+ Args:
1862
+ key: is a valid key of the potential function, referring to an instance in the factor's state space.
1863
+
1864
+ Assumes:
1865
+ `self.valid_key(key)` is true.
1866
+
1867
+ Returns:
1868
+ a hypothetical parameter index assuming that every valid key has a unique parameter
1869
+ as per DensePotentialFunction.
1870
+ """
1871
+ return _natural_key_idx(self._shape, key)
1872
+
1873
+ def param_id(self, param_idx: int) -> ParamId:
1874
+ """
1875
+ Get a hashable object to represent the parameter with the given parameter index.
1876
+
1877
+ Arg:
1878
+ param_idx: a parameter index.
1879
+
1880
+ Returns:
1881
+ a hashable ParamId object for the parameter of this potential function.
1882
+
1883
+ Raises:
1884
+ ValueError: if the parameter index is not valid.
1885
+ """
1886
+ if not (0 <= param_idx < self.number_of_parameters):
1887
+ raise ValueError(f'invalid parameter index: {param_idx}')
1888
+ return ParamId(id(self), param_idx)
1889
+
1890
+ def items(self) -> Iterable[Tuple[Instance, float]]:
1891
+ """
1892
+ Iterate over all keys and values of this potential function.
1893
+
1894
+ Returns:
1895
+ An iterator over all (key, value) pairs, where key is an Instance and value
1896
+ is the value of the potential function for the key.
1897
+ """
1898
+ for key in _combos_ranges(self._shape, flip=True):
1899
+ yield key, self[key]
1900
+
1901
+ def instances(self, flip: bool = False) -> Iterable[Instance]:
1902
+ """
1903
+ Iterate over all possible instances, in natural index order (i.e.,
1904
+ last random variable changing most quickly).
1905
+
1906
+ Args:
1907
+ flip: if true, then first random variable changes most quickly
1908
+
1909
+ Returns:
1910
+ an iterator over tuples, each tuple holds random variable
1911
+ state indexes, co-indexed with this object's shape, i.e., self.shape.
1912
+ """
1913
+ return _combos_ranges(self._shape, flip=not flip)
1914
+
1915
+ def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
1916
+ """
1917
+ Iterate over all possible instances of parent random variable, in
1918
+ natural index order (i.e., last random variable changing most quickly).
1919
+
1920
+ Args:
1921
+ flip: if true, then first random variable changes most quickly
1922
+
1923
+ Returns:
1924
+ an iteration over tuples, each tuple holds random variable states
1925
+ co-indexed with this object's 'parent' shape, i.e., `self.shape[1:]`.
1926
+ """
1927
+ return _combos_ranges(self._shape[1:], flip=not flip)
1928
+
1929
+ def __str__(self) -> str:
1930
+ """
1931
+ Provide a human-readable representation of this potential function.
1932
+ This is intended mainly for debugging purposes.
1933
+ """
1934
+ shape_str: str = ', '.join(str(x) for x in self._shape)
1935
+ return f'{self.__class__.__name__}({shape_str})'
1936
+
1937
+ def dump(
1938
+ self,
1939
+ *,
1940
+ prefix: str = '',
1941
+ indent: str = ' ',
1942
+ show_function_values: bool = False,
1943
+ show_id_class: bool = True,
1944
+ ) -> None:
1945
+ """
1946
+ Print a dump of the function.
1947
+ This is intended for debugging purposes.
1948
+
1949
+ Args:
1950
+ prefix: optional prefix for indenting all lines.
1951
+ indent: additional prefix to use for extra indentation.
1952
+ show_function_values: if true, then the function values will be dumped.
1953
+ show_id_class: if true, then the function id and class will be dumped.
1954
+ """
1955
+
1956
+ shape_str: str = ', '.join(str(x) for x in self._shape)
1957
+
1958
+ if show_id_class:
1959
+ print(f'{prefix}id: {id(self)}')
1960
+ print(f'{prefix}class: {self.__class__.__name__}')
1961
+ print(f'{prefix}usage: {self.count_usage()}')
1962
+ print(f'{prefix}rvs: {self.number_of_rvs}')
1963
+ print(f'{prefix}shape: ({shape_str})')
1964
+ print(f'{prefix}states: {self._number_of_states}')
1965
+ print(f'{prefix}guaranteed zero: {self._number_of_states - self.number_of_not_guaranteed_zero}')
1966
+ print(f'{prefix}not guaranteed zero: {self.number_of_not_guaranteed_zero}')
1967
+ print(f'{prefix}parameters: {self.number_of_parameters}')
1968
+ if show_function_values:
1969
+ next_prefix = prefix + indent
1970
+ for key, param_idx, value in self.keys_with_param:
1971
+ print(f'{next_prefix}{param_idx} {key} = {value}')
1972
+
1973
+
1974
class ZeroPotentialFunction(PotentialFunction):
    """
    A potential function whose every value is zero.

    Like a DensePotentialFunction, there is one parameter for each
    possible key, but a PGM user has no way to change parameter values:
    they are always zero. Even so, no key is considered 'guaranteed zero'.

    The primary use of a ZeroPotentialFunction is as a placeholder
    within a factor, prior to parameter learning.
    """
    __slots__ = ()

    def __init__(self, factor: Factor):
        """
        Create a potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor that this potential function is compatible with.
        """
        super().__init__(factor)

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        """
        Returns:
            `self.number_of_states`, as every key has a parameter.
        """
        return self.number_of_states

    @property
    def number_of_parameters(self) -> int:
        """
        Returns:
            `self.number_of_states`, i.e., one parameter per key.
        """
        return self.number_of_states

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        """
        Returns:
            An iterable over (param_idx, 0) tuples, for every parameter index.
        """
        return ((index, 0) for index in range(self.number_of_parameters))

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """
        Returns:
            An iterable over (key, param_idx, 0) tuples, with keys taken from
            `self.instances()` and parameter indexes counting up from zero.
        """
        return ((key, index, 0) for index, key in enumerate(self.instances()))

    def __getitem__(self, key: Key) -> float:
        """
        Get the potential function value for the given key, which is always zero.

        Arg:
            key: defines an instance in the state space of the potential function.

        Raises:
            whatever `self.check_key(key)` raises for an invalid key.
        """
        self.check_key(key)  # called only to validate the key
        return 0

    def param_value(self, param_idx: int) -> float:
        """
        Get the value of the given parameter, which is always zero.

        Arg:
            param_idx: a parameter index.

        Raises:
            ValueError: if the parameter index is not valid.
        """
        if self.valid_parameter(param_idx):
            return 0
        raise ValueError(f'invalid parameter index: {param_idx}')

    def param_idx(self, key: Key) -> int:
        """
        Get the parameter index for the given key.

        Arg:
            key: defines an instance in the state space of the potential function.

        Returns:
            the natural index of the key for this function's shape.
        """
        return _natural_key_idx(self._shape, key)

    def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
        """
        Returns:
            True, always; `tolerance` is not used.
        """
        return True
2033
+
2034
+
2035
class DensePotentialFunction(PotentialFunction):
    """
    A dense (tabular) potential function.

    There is one parameter for each valid key of the potential function,
    held in a flat float64 array indexed by natural parameter index.
    The initial value for each parameter is zero.
    It is possible to independently change the value corresponding to any key.
    """

    def __init__(self, factor: Factor):
        """
        Create a potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor that this potential function is compatible with.
        """
        super().__init__(factor)
        self._values: NDArrayFloat64 = np.zeros(self.number_of_states, dtype=np.float64)

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        """
        Returns:
            `self.number_of_states`, as every key has a parameter.
        """
        return self.number_of_states

    @property
    def number_of_parameters(self) -> int:
        """
        Returns:
            `self.number_of_states`, i.e., one parameter per key.
        """
        return self.number_of_states

    def __getitem__(self, key: Key) -> float:
        """
        Get the potential function value for the given key.

        Assumes:
            self.valid_key(key).
        """
        index = self.param_idx(key)
        return self._values.item(index)

    def param_value(self, param_idx: int) -> float:
        """
        Get the value of the parameter with the given index.

        Assumes:
            self.valid_parameter(param_idx).
        """
        return self._values.item(param_idx)

    def param_idx(self, key: Key) -> Optional[int]:
        """
        Get the parameter index for the given key, which is the natural
        parameter index of the key.
        """
        return self.natural_param_idx(key)

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        """
        Returns:
            An iterable over (param_idx, value) tuples, enumerating the value array.
        """
        # Type warning due to numpy type erasure
        # noinspection PyTypeChecker
        return enumerate(self._values)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """
        Returns:
            An iterable over (key, param_idx, value) tuples, with keys taken from
            `self.instances()` and parameter indexes counting up from zero.
        """
        for index, instance in enumerate(self.instances()):
            yield instance, index, self.param_value(index)

    # Mutators

    def __setitem__(self, key: Key, value: float) -> None:
        """
        Set the potential function value, for a given key.

        Arg:
            key: defines an instance in the state space of the potential function.
            value: the new value of the potential function for the given key.

        Assumes:
            self.valid_key(key).
        """
        index = self.param_idx(key)
        self._values[index] = value

    def set_param_value(self, param_idx: int, value: float) -> None:
        """
        Set the parameter value.

        Arg:
            param_idx: is the index of the parameter.
            value: the new value of the parameter.

        Assumes:
            self.valid_parameter(param_idx).
        """
        self._values[param_idx] = value

    def clear(self) -> DensePotentialFunction:
        """
        Set all values of the potential function to zero.

        Returns:
            self
        """
        return self.set_all(0)

    def normalise_cpt(self) -> DensePotentialFunction:
        """
        Normalise the parameter values as if this was a CPT.
        That is, treat the first random variable as the child and the others as parents;
        for each combination of parent states, ensure the parameters over the child
        states sum to 1 (or 0).

        Assumes:
            There are no negative parameter values.

        Returns:
            self
        """
        child_states = range(self._shape[0])
        for parent_states in _combos_ranges(self._shape[1:]):
            group = [(state,) + parent_states for state in child_states]
            total = sum(self[key] for key in group)
            if total == 0 or total == 1:
                # Nothing to rescale for this combination of parent states.
                continue
            for key in group:
                self[key] = self[key] / total
        return self

    def normalise(self, grouping_positions: Sequence[int] = ()) -> DensePotentialFunction:
        """
        Convert the potential function to a CPT with 'grouping_positions' nominating
        the parent random variables.

        I.e., for each possible key of the function with the same value at each
        grouping position, the sum of values for matching keys in the factor is scaled
        to be 1 (or 0).

        Parameter 'grouping_positions' are indices into `self.shape`. For example, the
        grouping positions of a factor with parent rvs 'conditioning_rvs' are
        grouping_positions = [i for i, rv in enumerate(factor.rvs) if rv in conditioning_rvs].

        Args:
            grouping_positions: indices into `self.shape`.

        Returns:
            self
        """
        _normalise_potential_function(self, grouping_positions)
        return self

    def set_iter(self, values: Iterable[float]) -> DensePotentialFunction:
        """
        Set the values of the potential function using the given iterator.

        Mapping instances to values is as follows:
            Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
            values[0] represents instance (0,0)
            values[1] represents instance (0,1)
            values[2] represents instance (0,2)
            values[3] represents instance (1,0)
            values[4] represents instance (1,1)
            values[5] represents instance (1,2).

        For example: to set to counts, starting from 1, use `self.set_iter(itertools.count(1))`.

        Args:
            values: an iterable providing values to use.

        Returns:
            self
        """
        # Exactly number_of_parameters items are read from `values`; the
        # value array is replaced by a newly built one.
        wanted = self.number_of_parameters
        self._values = np.fromiter(values, dtype=np.float64, count=wanted)
        return self

    def set_stream(self, stream: Callable[[], float]) -> DensePotentialFunction:
        """
        Set the values of the potential function by repeatedly calling the stream function.
        The order of values is the same as set_iter.

        For example, to set to random numbers, use `self.set_stream(random.random)`.

        Args:
            stream: a callable taking no arguments, returning the values to use.

        Returns:
            self
        """
        # iter(callable, sentinel) calls `stream` until it returns None.
        endless = iter(stream, None)
        return self.set_iter(endless)

    def set_flat(self, *value: float) -> DensePotentialFunction:
        """
        Set the values of the potential function to the given values.
        The order of values is the same as set_iter.

        Args:
            *value: the values to use.

        Returns:
            self

        Raises:
            ValueError: if `len(value) != self.number_of_states`.
        """
        if len(value) == self.number_of_states:
            return self.set_iter(value)
        raise ValueError(f'wrong number of values: expected {self.number_of_states}, got {len(value)}')

    def set_all(self, value: float) -> DensePotentialFunction:
        """
        Set all values of the potential function to the given value.

        Args:
            value: the value to use.

        Returns:
            self
        """
        return self.set_iter(_repeat(value))

    def set_uniform(self) -> DensePotentialFunction:
        """
        Set all values of the potential function to 1/number_of_states.

        Returns:
            self
        """
        share = 1.0 / self.number_of_states
        return self.set_all(share)
2249
+
2250
+
2251
class SparsePotentialFunction(PotentialFunction):
    """
    A sparse potential function.

    There is one parameter for each non-zero key value.
    The user may set the value for any key and parameters will
    be automatically reconfigured as needed. Setting the value for
    a key to zero disassociates the key from its parameter and
    thus makes that key "guaranteed zero".

    Internally, `_values` lists the parameter values and `_params` maps
    each instance that has a parameter to its index in `_values`.
    """

    def __init__(self, factor: Factor):
        """
        Create a potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor that this potential function is compatible with.
        """
        super().__init__(factor)
        self._values: List[float] = []
        self._params: Dict[Instance, int] = {}

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        """
        Returns:
            The number of instances that have a parameter.
        """
        return len(self._params)

    @property
    def number_of_parameters(self) -> int:
        """
        Returns:
            The number of instances that have a parameter.
        """
        return len(self._params)

    def __getitem__(self, key: Key) -> float:
        """
        Get the potential function value for the given key.

        Returns:
            the value of the key's parameter, or zero if the key has no parameter.
        """
        slot: Optional[int] = self.param_idx(key)
        return 0 if slot is None else self._values[slot]

    def param_value(self, param_idx: int) -> float:
        """
        Get the value of the parameter with the given index.

        Assumes:
            self.valid_parameter(param_idx).
        """
        return self._values[param_idx]

    def param_idx(self, key: Key) -> Optional[int]:
        """
        Get the parameter index for the given key.

        Returns:
            either `None`, when the key has no parameter ("guaranteed zero"),
            or the index of the key's parameter.
        """
        instance = _key_to_instance(key)
        return self._params.get(instance)

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        """
        Returns:
            An iterable over (param_idx, value) tuples, enumerating the parameter values.
        """
        return enumerate(self._values)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """
        Returns:
            An iterable over (key, param_idx, value) tuples, for every key with a parameter.
        """
        for instance, slot in self._params.items():
            yield instance, slot, self._values[slot]

    # Mutators

    def __setitem__(self, key: Key, value: float) -> None:
        """
        Set the potential function value, for a given key.

        If value is zero, then the key will become "guaranteed zero".

        Arg:
            key: defines an instance in the state space of the potential function.
            value: the new value of the potential function for the given key.

        Assumes:
            self.valid_key(key).
        """
        instance: Instance = _key_to_instance(key)
        slot: Optional[int] = self._params.get(instance)

        if slot is None:
            # The key has no parameter yet; a zero value needs none.
            if value == 0:
                return
            self._params[instance] = len(self._values)
            self._values.append(value)
            return

        if value != 0:
            # The key keeps its parameter, only the value changes.
            self._values[slot] = value
            return

        # The key had a parameter and is now being zeroed, so the parameter
        # must go. Parameter indexes are kept contiguous by moving the last
        # parameter into the vacated slot and then dropping the last entry.
        last: int = len(self._values) - 1
        if slot != last:
            self._values[slot] = self._values[last]
            # At most one instance refers to the last parameter.
            owner = next(
                (other for other, other_slot in self._params.items() if other_slot == last),
                None,
            )
            if owner is not None:
                self._params[owner] = slot

        del self._values[-1]
        del self._params[instance]

    def set_param_value(self, param_idx: int, value: float) -> None:
        """
        Set the parameter value.

        Arg:
            param_idx: is the index of the parameter.
            value: the new value of the parameter.

        Assumes:
            self.valid_parameter(param_idx).
        """
        self._values[param_idx] = value

    def clear(self) -> SparsePotentialFunction:
        """
        Set all values of the potential function to zero, by discarding
        every parameter.

        Returns:
            self
        """
        self._values, self._params = [], {}
        return self

    def normalise_cpt(self) -> SparsePotentialFunction:
        """
        Normalise the parameter values as if this was a CPT.
        That is, treat the first random variable as the child and the others as parents;
        for each combination of parent states, ensure the parameters over
        the child states sum to 1 (or 0).

        Returns:
            self
        """
        # Every position except the first (the child) is a grouping position.
        parent_positions = list(range(1, self.number_of_rvs))
        _normalise_potential_function(self, parent_positions)
        return self

    def normalise(self, grouping_positions=()) -> SparsePotentialFunction:
        """
        Convert the potential function to a CPT with 'grouping_positions' nominating
        the parent random variables.

        I.e., for each possible key of the function with the same value at each
        grouping position, the sum of values for matching keys in the factor is scaled
        to be 1 (or 0).

        Parameter 'grouping_positions' are indices into `self.shape`. For example, the
        grouping positions of a factor with parent rvs 'conditioning_rvs' are
        grouping_positions = [i for i, rv in enumerate(factor.rvs) if rv in conditioning_rvs].

        Args:
            grouping_positions: indices into `self.shape`.

        Returns:
            self
        """
        _normalise_potential_function(self, grouping_positions)
        return self

    def set_iter(self, values: Iterable[float]) -> SparsePotentialFunction:
        """
        Set the values of the potential function using the given iterator.

        Mapping instances to values is as follows:
            Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
            values[0] represents instance (0,0)
            values[1] represents instance (0,1)
            values[2] represents instance (0,2)
            values[3] represents instance (1,0)
            values[4] represents instance (1,1)
            values[5] represents instance (1,2).

        For example: to set to counts, starting from 1, use `self.set_iter(itertools.count(1))`.

        Args:
            values: an iterable providing values to use.

        Returns:
            self
        """
        # All existing parameters are discarded before any value is read.
        self.clear()
        stored, slot_of = self._values, self._params
        for instance, value in zip(self.instances(), values):
            if value == 0:
                # Zero values get no parameter.
                continue
            slot_of[instance] = len(stored)
            stored.append(value)
        return self

    def set_stream(self, stream: Callable[[], float]) -> SparsePotentialFunction:
        """
        Set the values of the potential function by repeatedly calling the stream function.
        The order of values is the same as set_iter.

        For example, to set to random numbers, use `self.set_stream(random.random)`.

        Args:
            stream: a callable taking no arguments, returning the values to use.

        Returns:
            self
        """
        # iter(callable, sentinel) calls `stream` until it returns None.
        endless = iter(stream, None)
        return self.set_iter(endless)

    def set_flat(self, *value: float) -> SparsePotentialFunction:
        """
        Set the values of the potential function to the given values.
        The order of values is the same as set_iter.

        Args:
            *value: the values to use.

        Returns:
            self

        Raises:
            ValueError: if `len(value) != self.number_of_states`.
        """
        if len(value) == self.number_of_states:
            return self.set_iter(value)
        raise ValueError(f'wrong number of values: expected {self.number_of_states}, got {len(value)}')

    def set_all(self, value: float) -> SparsePotentialFunction:
        """
        Set all values of the potential function to the given value.

        Args:
            value: the value to use.

        Returns:
            self
        """
        if value == 0:
            return self.clear()
        return self.set_iter(_repeat(value))

    def set_uniform(self) -> SparsePotentialFunction:
        """
        Set all values of the potential function to 1/number_of_states.

        Returns:
            self
        """
        share = 1.0 / self.number_of_states
        return self.set_all(share)
2502
+
2503
+
2504
+ class CompactPotentialFunction(PotentialFunction):
2505
+ """
2506
+ A compact potential function is sparse, where values for keys of
2507
+ the same value are represented by a single parameter.
2508
+
2509
+ There is one parameter for each unique, non-zero key value.
2510
+ The user may set the value for any key and parameters will
2511
+ be automatically reconfigured as needed. Setting the value for
2512
+ a key to zero disassociates the key from its parameter and
2513
+ thus makes that key "guaranteed zero".
2514
+ """
2515
+
2516
def __init__(self, factor: Factor):
    """
    Create a potential function for the given factor.

    Ensures:
        Does not hold a reference to the given factor.
        Does not register the potential function with the PGM.

    Args:
        factor: the factor that this potential function is compatible with.
    """
    super().__init__(factor)
    # Parameter values, and for each parameter how many instances use it.
    self._values: List[float] = list()
    self._counts: List[int] = list()
    # Instance -> parameter index, and parameter value -> parameter index.
    self._map: Dict[Instance, int] = dict()
    self._inv_map: Dict[float, int] = dict()
2532
+
2533
@property
def number_of_not_guaranteed_zero(self) -> int:
    """
    Returns:
        The number of instances that are mapped to a parameter.
    """
    mapped_instances = self._map
    return len(mapped_instances)
2536
+
2537
@property
def number_of_parameters(self) -> int:
    """
    Returns:
        The number of parameter values held by this potential function.
    """
    parameter_values = self._values
    return len(parameter_values)
2540
+
2541
@property
def params(self) -> Iterable[Tuple[int, float]]:
    """
    Returns:
        An iterable over (param_idx, value) tuples, enumerating the parameter values.
    """
    parameter_values = self._values
    return enumerate(parameter_values)
2544
+
2545
@property
def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
    """
    Returns:
        An iterable over (key, param_idx, value) tuples, for every key
        that is mapped to a parameter.
    """
    for instance, slot in self._map.items():
        yield instance, slot, self._values[slot]
2550
+
2551
def __getitem__(self, key: Key) -> float:
    """
    Get the potential function value for the given key.

    Returns:
        the value of the key's parameter, or zero if the key has no parameter.
    """
    slot: Optional[int] = self.param_idx(key)
    return 0 if slot is None else self._values[slot]
2557
+
2558
def param_value(self, param_idx: int) -> float:
    """
    Get the value of the parameter with the given index.

    Assumes:
        self.valid_parameter(param_idx).
    """
    parameter_values = self._values
    return parameter_values[param_idx]
2560
+
2561
def param_idx(self, key: Key) -> Optional[int]:
    """
    Get the parameter index for the given key.

    Returns:
        either `None`, when the key has no parameter ("guaranteed zero"),
        or the index of the key's parameter.
    """
    instance = _key_to_instance(key)
    return self._map.get(instance)
2563
+
2564
+ # Mutators
2565
+
2566
def __setitem__(self, key: Key, value: float) -> None:
    """
    Set the potential function value, for a given key.

    If value is zero, then the key will become "guaranteed zero".
    If the value is the same as an existing parameter value, then
    that parameter will be reused.

    Arg:
        key: defines an instance in the state space of the potential function.
        value: the new value of the potential function for the given key.

    Assumes:
        self.valid_key(key).
    """
    instance: Instance = _key_to_instance(key)
    param_idx: Optional[int] = self._map.get(instance)

    if param_idx is None:
        # Previous value for the key was zero.
        if value != 0:
            self._attach_instance(instance, value)
        return

    # The key previously had a non-zero value.
    prev_value: float = self._values[param_idx]
    if value == prev_value:
        # nothing to do
        return

    reference_count: int = self._counts[param_idx]
    if reference_count != 1:
        # Other keys still use the previous parameter, so it stays; only
        # this key is detached from it.
        self._counts[param_idx] = reference_count - 1
        if value == 0:
            # The key becomes "guaranteed zero": it must not be given a
            # parameter whose value is zero.
            del self._map[instance]
        else:
            self._attach_instance(instance, value)
        return

    # This key is the only user of its parameter.
    if value != 0 and value not in self._inv_map:
        # Re-purpose the parameter in place, keeping the value -> index
        # lookup in step so that later assignments of either the old or
        # the new value resolve correctly.
        self._values[param_idx] = value
        del self._inv_map[prev_value]
        self._inv_map[value] = param_idx
        return

    # The parameter is no longer needed: either the key is being zeroed or
    # its new value is already held by another parameter.
    # NOTE(review): _remove_param is defined outside this block; this relies
    # on it leaving the remaining parameters consistently indexed, as the
    # removal sequence below already did before this change - confirm.
    self._remove_param(param_idx)
    self._map.pop(instance)
    self._inv_map.pop(prev_value)
    if value != 0:
        # Look the existing parameter up only after the removal, in case
        # the removal renumbered parameters.
        self._attach_instance(instance, value)

def _attach_instance(self, instance: Instance, value: float) -> None:
    """
    Map the instance to the parameter holding the given non-zero value,
    reusing an existing parameter with that value or allocating a new one.
    """
    param_idx: Optional[int] = self._inv_map.get(value)
    if param_idx is None:
        # need to allocate a new value
        param_idx = len(self._values)
        self._values.append(value)
        self._counts.append(1)
        self._inv_map[value] = param_idx
    else:
        # the value already exists in the function, so reuse it
        self._counts[param_idx] += 1
    self._map[instance] = param_idx
2639
+
2640
def set_iter(self, values: Iterable[float]) -> CompactPotentialFunction:
    """
    Clear the potential function, then assign values from the given iterable.

    Values are paired with `self.instances()` in order. For example,
    given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
        values[0] represents instance (0,0)
        values[1] represents instance (0,1)
        values[2] represents instance (0,2)
        values[3] represents instance (1,0)
        values[4] represents instance (1,1)
        values[5] represents instance (1,2).

    Pairing stops as soon as either the instances or the values run out.

    For example: to set to counts, starting from 1, use `self.set_iter(itertools.count(1))`.

    Args:
        values: an iterable providing values to use.

    Returns:
        self
    """
    self.clear()
    assignments = zip(self.instances(), values)
    for assignment in assignments:
        instance, value = assignment
        self[instance] = value
    return self
2665
+
2666
def set_stream(self, stream: Callable[[], float]) -> CompactPotentialFunction:
    """
    Set the values of the potential function by repeatedly calling the stream function.
    The order of values is the same as set_iter.

    The stream is wrapped with `iter(stream, None)`, so a None returned
    by the stream ends the sequence of values.

    For example, to set to random numbers, use `self.set_stream(random.random)`.

    Args:
        stream: a callable taking no arguments, returning the values to use.

    Returns:
        self
    """
    value_source = iter(stream, None)
    return self.set_iter(value_source)
2680
+
2681
def set_flat(self, *value: float) -> CompactPotentialFunction:
    """
    Set the values of the potential function to the given values.
    The order of values is the same as set_iter.

    Args:
        *value: the values to use, one per state of the potential function.

    Returns:
        self

    Raises:
        ValueError: if `len(value) != self.number_of_states`.
    """
    if len(value) == self.number_of_states:
        return self.set_iter(value)
    raise ValueError(f'wrong number of values: expected {self.number_of_states}, got {len(value)}')
2698
+
2699
def set_all(self, value: float) -> CompactPotentialFunction:
    """
    Set all values of the potential function to the given value.

    A zero value just clears the function. Any other value becomes the
    single parameter (index 0), shared by every instance.

    Args:
        value: the value to use.

    Returns:
        self
    """
    self.clear()
    if value == 0:
        return self

    # one parameter, referenced once per state
    self._values = [value]
    self._counts = [self.number_of_states]
    self._inv_map = {value: 0}
    self._map = dict.fromkeys(self.instances(), 0)
    return self
2716
+
2717
def set_uniform(self) -> CompactPotentialFunction:
    """
    Set all values of the potential function to 1/number_of_states.

    Returns:
        self
    """
    uniform_value = 1.0 / self.number_of_states
    return self.set_all(uniform_value)
2725
+
2726
def clear(self) -> CompactPotentialFunction:
    """
    Set all values of the potential function to zero.

    This replaces the parameter values, reference counts, key map and
    inverse (value to parameter index) map with fresh empty containers.

    Returns:
        self
    """
    self._values, self._counts = [], []
    self._map, self._inv_map = {}, {}
    return self
2738
+
2739
+ def _remove_param(self, param_idx: int) -> None:
2740
+ """
2741
+ Remove the indexed parameter from self._params and self._counts.
2742
+ If the parameter is not at the end of the list of parameters
2743
+ then it will be swapped with the last parameter in the list.
2744
+ """
2745
+
2746
+ # ensure the parameter is at the end of the list
2747
+ end: int = len(self._values) - 1
2748
+ if param_idx != end:
2749
+ # swap `param_idx` with `end`
2750
+ end_value: float = self._values[end]
2751
+ self._values[param_idx] = end_value
2752
+ self._counts[param_idx] = self._counts[end]
2753
+ self._inv_map[end_value] = param_idx
2754
+ for instance, instance_param_idx in self._map.items():
2755
+ if instance_param_idx == end:
2756
+ self._map[instance] = param_idx
2757
+
2758
+ # remove the end parameter
2759
+ self._values.pop()
2760
+ self._counts.pop()
2761
+
2762
+
2763
class ClausePotentialFunction(PotentialFunction):
    """
    A clause potential function represents a clause from a CNF formula.
    I.e. a clause over variables X, Y, Z, is a disjunction of the form: 'X=x or Y=y or Z=z'.

    A clause potential function is guaranteed zero for a key where the clause is false,
    i.e., when 'X != x and Y != y and Z != z'.

    For keys where the clause is true, the value of the potential function
    is given by the only parameter of the potential function. That parameter
    is called the clause 'weight' and is notionally 1.

    The weight of a clause is permitted to be zero, but that is _not_ equivalent to
    guaranteed-zero.
    """

    def __init__(self, factor: Factor, key: Key, weight: float = 1):
        """
        Create a clause potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.

        Args:
            factor: which factor is this potential function is compatible with.
            key: defines the random variable states of the clause.
            weight: the value of the potential function where the clause is true.
        """
        super().__init__(factor)
        self._weight: float = weight
        self._clause: Instance = self.check_key(key)
        # The clause is false exactly when every state differs from its clause
        # state. There are `_zero_space(shape)` such keys, whatever the clause
        # is; every other key in the state space is not guaranteed zero.
        self._num_not_guaranteed_zero: int = _multiply(self.shape) - _zero_space(self.shape)

    def _clause_holds(self, instance: Instance) -> bool:
        """
        Is the clause true for the given instance, i.e., does at least
        one state of the instance equal the corresponding clause state.
        """
        return any(
            state_idx == clause_state_idx
            for state_idx, clause_state_idx in zip(instance, self._clause)
        )

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        return self._num_not_guaranteed_zero

    @property
    def number_of_parameters(self) -> int:
        return 1

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        return ((0, self._weight),)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """
        Iterate over each key where the clause is true, exactly once, in
        natural index order (last random variable changing most quickly).

        Returns:
            an iteration over (instance, 0, weight) triples, where weight
            is the clause weight at the time iteration started.
        """
        from itertools import product

        value = self._weight
        for instance in product(*(range(size) for size in self.shape)):
            if self._clause_holds(instance):
                yield instance, 0, value

    def __getitem__(self, key: Key) -> float:
        instance: Instance = self.check_key(key)
        if self._clause_holds(instance):
            return self._weight
        return 0

    def param_value(self, param_idx: int) -> float:
        if param_idx != 0:
            raise IndexError(param_idx)
        return self._weight

    def param_idx(self, key: Key) -> Optional[int]:
        """
        Get the parameter index for the given key.

        Returns:
            0 if the clause is true for the key (consistent with `__getitem__`),
            otherwise None.
        """
        instance: Instance = _key_to_instance(key)
        if self._clause_holds(instance):
            return 0
        else:
            return None

    def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
        """
        A ClausePotentialFunction can only be a CPT when all entries are zero.

        Returns:
            True only if the weight is within `tolerance` of zero.
        """
        return -tolerance <= self._weight <= tolerance

    def is_sparse(self) -> bool:
        return True

    @property
    def weight(self) -> float:
        """
        Returns:
            the "weight" parameter defining the potential function.
        """
        return self._weight

    @property
    def clause(self) -> Instance:
        """
        Returns:
            the clause defining the potential function.
        """
        return self._clause

    # Mutators

    @weight.setter
    def weight(self, value: float) -> None:
        """
        Set the weight parameter to the given value.
        """
        self._weight = value

    @clause.setter
    def clause(self, key: Key) -> None:
        """
        Set the clause to the given key.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.
        """
        self._clause = self.check_key(key)
2882
+
2883
+
2884
class CPTPotentialFunction(PotentialFunction):
    """
    A potential function implementing a sparse Conditional Probability Table (CPT).

    The first random variable in the signature is the child, and the remaining random
    variables are parents.

    For each instantiation of the parent random variables there is a Conditioned Probability
    Distribution (CPD) over the states of the child random variable.

    If a CPD is not provided for a parent instantiation, then that parent instantiation
    is taken to have probability zero (i.e., all values of the CPD are guaranteed zero).

    Storage:
        `_values` holds the stored CPDs back to back, each `_child_size` long.
        `_map` maps a parent instance to the offset of its CPD in `_values`.
        `_inv_map[offset // _child_size]` is the parent instance stored at that offset.
    """

    def __init__(self, factor: Factor, tolerance: float):
        """
        Create a CPT potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: which factor is this potential function is compatible with.
            tolerance: a tolerance when testing if values are equal to zero or one.

        Raises:
            ValueError: if tolerance is negative.
        """
        super().__init__(factor)

        if tolerance < 0:
            raise ValueError('tolerance cannot be negative')

        self._child_size: int = self.shape[0]
        self._parent_shape: Shape = self.shape[1:]
        self._map: Dict[Instance, int] = {}
        self._values: List[float] = []
        self._inv_map: List[Instance] = []
        self._tolerance: float = tolerance

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        return len(self._values)

    @property
    def number_of_parameters(self) -> int:
        return len(self._values)

    def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
        if tolerance >= self._tolerance:
            # Every stored CPD was checked against self._tolerance by set_cpd.
            return True
        else:
            # The requested tolerance is tighter than ensured.
            # Need to use the default method
            return super().is_cpt(tolerance)

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        return enumerate(self._values)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        child_size: int = self._child_size
        for param_idx, value in enumerate(self._values):
            # param_idx = (CPD number) * child_size + (child state index)
            parent: Instance = self._inv_map[param_idx // child_size]
            key: Instance = (param_idx % child_size,) + tuple(parent)
            yield key, param_idx, value

    def __getitem__(self, key: Key) -> float:
        param_idx: Optional[int] = self.param_idx(key)
        if param_idx is None:
            return 0
        else:
            return self._values[param_idx]

    def param_value(self, param_idx: int) -> float:
        return self._values[param_idx]

    def param_idx(self, key: Key) -> Optional[int]:
        instance: Instance = self.check_key(key)
        # instance[0] is the child state, instance[1:] are the parent states
        offset: Optional[int] = self._map.get(instance[1:])
        if offset is None:
            return None
        else:
            return offset + instance[0]

    @property
    def parent_shape(self) -> Shape:
        """
        What is the shape of the parents.
        """
        return self._parent_shape

    @property
    def number_of_parent_states(self) -> int:
        """
        How many combinations of parent states.
        """
        return _multiply(self._parent_shape)

    @property
    def number_of_child_states(self) -> int:
        """
        Number of child random variable states.

        This is the same as the number of values in each conditional
        probability distribution. This is equivalent to `self.shape[0]`.

        Returns:
            the number of child states.
        """
        return self._child_size

    def get_cpd(self, parent_states: Key) -> List[float]:
        """
        Get the CPD conditioned on parent states indicated by `parent_states`.

        Args:
            parent_states: indicates the parent states.

        Returns:
            The conditioned probability distribution, as a new list. This is
            all zeros if no CPD is stored for the parent states.

        Raises:
            KeyError if the key is not valid.
        """
        parent_instance: Instance = check_key(self._parent_shape, parent_states)
        offset: Optional[int] = self._map.get(parent_instance)
        child_size: int = self._child_size
        if offset is None:
            return [0] * child_size
        else:
            return self._values[offset:offset + child_size]

    def cpds(self) -> Iterable[Tuple[Instance, Sequence[float]]]:
        """
        Iterate over (parent_states, cpd) tuples.
        This will exclude zero CPDs.
        Do not change CPDs to (or from) zero while iterating over them.

        Returns:
            an iterator over pairs (instance, cpd) where,
            instance: indicates the state of the parent random variables.
            cpd: is the conditioned probability distribution, for the parent instance.
        """
        for parent_instance, offset in self._map.items():
            cpd = self._values[offset:offset + self._child_size]
            yield parent_instance, cpd

    # Mutators

    def clear(self) -> CPTPotentialFunction:
        """
        Set all values of the potential function to zero.

        Returns:
            self
        """
        self._map = {}
        self._values = []
        self._inv_map = []
        return self

    def set_uniform(self) -> CPTPotentialFunction:
        """
        Set each CPD to a uniform distribution.

        Returns:
            self
        """
        self.clear()
        for parent_states in self.parent_instances():
            self.set_cpd_uniform(parent_states)
        return self

    def set_random(self, random: Callable[[], float], sparsity: float = 0) -> CPTPotentialFunction:
        """
        Set the values of the potential function to random CPDs.

        Args:
            random: is a stream of random numbers, assumed uniformly distributed in the interval [0, 1].
            sparsity: sets the expected proportion of probability values that are zero.

        Returns:
            self
        """
        self.clear()
        for parent_states in self.parent_instances():
            self.set_cpd_random(parent_states, random, sparsity)
        return self

    def set(self, *rows: Tuple[Key, Sequence[float]]) -> CPTPotentialFunction:
        """
        Calls self.set_cpd(parent_states, cpd) for each row (parent_states, cpd)
        in rows. Any unmentioned parent states will have zero probabilities.

        Example usage, assuming three Boolean random variables:
            pgm.Factor(x, y, z).set_cpt().set(
                # y  z    x[0] x[1]
                ((0, 0), (0.1, 0.9)),
                ((0, 1), (0.1, 0.9)),
                ((1, 0), (0.1, 0.9)),
                ((1, 1), (0.1, 0.9))
            )

        Args:
            *rows: are tuples (key, cpd) used to set the potential function values.

        Raises:
            ValueError: if a CPD is not valid.

        Returns:
            self
        """
        self.clear()
        for parent_states, cpd in rows:
            self.set_cpd(parent_states, cpd)
        return self

    def set_all(self, *cpds: Optional[Sequence[float]]) -> CPTPotentialFunction:
        """
        Set all CPDs using the given `cpds` which are taken to be in the order of
        the parent instances given by `self.parent_instances()`.

        If insufficient CPDs are provided then the remaining parent instantiations are taken to be
        impossible (i.e. not set and guaranteed zero).
        If too many CPDs are provided then the extras are ignored.
        Any list entry may be None, indicating 'guaranteed zero' for the associated parent states.

        Args:
            *cpds: are the CPDs used to set the potential function values.

        Raises:
            ValueError: if a CPD is not valid.

        Returns:
            self
        """
        self.clear()
        for parent_states, cpd in zip(self.parent_instances(), cpds):
            self.set_cpd(parent_states, cpd)
        return self

    def set_cpd(self, parent_states: Key, cpd: Optional[Sequence[float]]) -> CPTPotentialFunction:
        """
        Set the CPD of the given parent states to the given cpd.
        If cpd is None or all zeros, then this is equivalent to clear_cpd(parent_states).

        Args:
            parent_states: indicates the CPD to set, based on the parent states.
            cpd: is a conditioned probability distribution, or None indicating `guaranteed zero`.

        Raises:
            ValueError: if the CPD is not valid, i.e., it has the wrong length,
                a value is outside [0, 1], or the values sum to neither zero
                nor one (within the tolerance given at construction time).
            KeyError if the key is not valid.

        Returns:
            self
        """
        parent_instance: Instance = check_key(self._parent_shape, parent_states)

        if cpd is None:
            self._clear_cpd(parent_instance)
            return self

        if len(cpd) != self._child_size:
            raise ValueError(f'CPD incorrect size: expected {self._child_size}, got {len(cpd)}')
        if not all(0 <= value <= 1 for value in cpd):
            raise ValueError(f'not a valid CPD: {cpd!r}')

        total_value = sum(cpd)
        # Inclusive comparison, so that an all-zero CPD is treated as
        # 'guaranteed zero' even when the tolerance is exactly zero.
        if total_value <= self._tolerance:
            self._clear_cpd(parent_instance)
            return self

        if total_value < 1 - self._tolerance or total_value > 1 + self._tolerance:
            raise ValueError(f'not a valid CPD: sum of values = {total_value}')

        offset: Optional[int] = self._map.get(parent_instance)
        child_size: int = self._child_size
        if offset is None:
            # a new CPD is appended to the end of the stored values
            offset = len(self._values)
            self._values.extend(cpd)
            self._map[parent_instance] = offset
            self._inv_map.append(parent_instance)
        else:
            # overwrite the existing CPD in place
            self._values[offset:offset + child_size] = cpd

        return self

    def clear_cpd(self, parent_states: Key) -> CPTPotentialFunction:
        """
        Set the CPD of the given parent_states to all 'guaranteed zero'.

        Args:
            parent_states: indicates the CPD to clear, based on the parent states.

        Raises:
            KeyError if the key is not valid.

        Returns:
            self
        """
        parent_instance: Instance = check_key(self._parent_shape, parent_states)
        self._clear_cpd(parent_instance)
        return self

    def set_cpd_uniform(self, parent_states: Key) -> CPTPotentialFunction:
        """
        Set the CPD of the given parent_states to a uniform CPD.

        Args:
            parent_states: indicates the CPD to set, based on the parent states.

        Raises:
            KeyError if the key is not valid.

        Returns:
            self
        """
        num_states = self.number_of_child_states
        cpd = [1.0 / num_states] * num_states
        return self.set_cpd(parent_states, cpd)

    def set_cpd_random(
            self,
            parent_states: Key,
            random: Callable[[], float],
            sparsity: float = 0,
    ) -> CPTPotentialFunction:
        """
        Set the CPD of the given parent_states to a random CPD.

        Args:
            parent_states: identifies the CPD being set.
            random: is a stream of random numbers, assumed uniformly distributed in the interval [0, 1].
            sparsity: sets the expected proportion of probability values that are zero.

        Returns:
            self
        """
        cpd = np.zeros(self.number_of_child_states, dtype=np.float64)
        if sparsity <= 0:
            for i in range(len(cpd)):
                # the small offset keeps every entry strictly positive
                cpd[i] = 0.0000001 + random()
        else:
            for i in range(len(cpd)):
                # one draw decides whether the entry is non-zero, a second provides its value
                if random() > sparsity:
                    cpd[i] = 0.0000001 + random()
        sum_value = np.sum(cpd)
        if sum_value > 0:
            cpd /= sum_value
            return self.set_cpd(parent_states, cpd)
        else:
            # every entry came out zero
            return self.clear_cpd(parent_states)

    def _clear_cpd(self, parent_instance: Instance) -> None:
        """
        Remove the parent instance from the parameters.

        If the CPD being removed is not the last one stored, then the last
        stored CPD is moved into its place, which changes the parameter
        indexes of the moved CPD.
        """
        offset: Optional[int] = self._map.get(parent_instance)
        if offset is None:
            # nothing to do
            return

        child_size: int = self._child_size
        end_offset: int = len(self._values) - child_size
        if offset != end_offset:
            # move the last CPD into the slot being vacated
            end_cpd = self._values[end_offset:]
            end_parent_instance = self._inv_map[-1]

            self._values[offset:offset + child_size] = end_cpd
            self._map[end_parent_instance] = offset
            self._inv_map[offset // child_size] = end_parent_instance

        self._map.pop(parent_instance)
        self._inv_map.pop()
        # drop the last CPD
        del self._values[end_offset:]
3264
+
3265
+
3266
def default_pgm_name(pgm: PGM) -> str:
    """
    Make the default name for a PGM, as used when no name is provided to a PGM constructor.

    Args:
        pgm: a PGM object.

    Returns:
        the string 'PGM_' followed by `id(pgm)`.
    """
    return f'PGM_{id(pgm)}'
3277
+
3278
+
3279
def check_key(shape: Shape, key: Key) -> Instance:
    """
    Convert the key into an instance, checking it is valid for the shape.

    A key is valid when it has one state index per entry of the shape and
    each state index `i` satisfies `0 <= i < m`, where `m` is the
    corresponding number of states in the shape.

    Args:
        shape: the shape defining the state space.
        key: a key into the state space.

    Returns:
        A instance from the state space, as a tuple of state indexes, co-indexed with the given shape.

    Raises:
        KeyError if the key is not valid.
    """
    instance: Instance = _key_to_instance(key)
    # State indexes run from 0 to m - 1, so the upper bound is exclusive.
    if len(instance) == len(shape) and all(0 <= i < m for i, m in zip(instance, shape)):
        return instance
    raise KeyError(f'not a valid key for shape {shape}: {key!r}')
3299
+
3300
+
3301
def valid_key(shape: Shape, key: Key) -> bool:
    """
    Is the given key valid for the given shape.

    Args:
        shape: the shape defining the state space.
        key: a key into the state space.

    Returns:
        True only if `check_key(shape, key)` does not raise a KeyError.
    """
    try:
        check_key(shape, key)
    except KeyError:
        return False
    else:
        return True
3317
+
3318
+
3319
def number_of_states(*rvs: RandomVariable) -> int:
    """
    What is the size of the state space of the given random variables.

    Args:
        *rvs: the random variables defining the state space.

    Returns:
        the result of `_multiply` over `len(rv)` for each rv in rvs.
    """
    sizes = (len(rv) for rv in rvs)
    return _multiply(sizes)
3325
+
3326
+
3327
def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterable[Instance]:
    """
    Enumerate instances of the given random variables.

    Each instance is a tuple of state indexes, co-indexed with the given random variables.

    The order is the natural index order (i.e., last random variable changing most quickly).

    Args:
        *rvs: the random variables to enumerate over.
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds state indexes
        co-indexed with the given random variables.
    """
    sizes = list(map(len, rvs))
    # the helper is called with the opposite flip sense to this function
    helper_flip = not flip
    return _combos_ranges(sizes, flip=helper_flip)
3344
+
3345
+
3346
def rv_instances_as_indicators(*rvs: RandomVariable, flip: bool = False) -> Iterable[Sequence[Indicator]]:
    """
    Enumerate instances of the given random variables.

    Each instance is a tuple of indicators, co-indexed with the given random variables.

    The order is the natural index order (i.e., last random variable changing most quickly).

    Args:
        *rvs: the random variables to enumerate over.
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds random variable indicators
        co-indexed with the given random variables.
    """
    # the helper is called with the opposite flip sense to this function
    helper_flip = not flip
    return _combos(rvs, flip=helper_flip)
3362
+
3363
+
3364
+ def _key_to_instance(key: Key) -> Instance:
3365
+ """
3366
+ Convert a key to an instance.
3367
+
3368
+ Args:
3369
+ key: a key into a state space.
3370
+
3371
+ Returns:
3372
+ A instance from the state space, as a tuple of state indexes, co-indexed with the given shape.
3373
+
3374
+ Assumes:
3375
+ The key is valid for the implied state space.
3376
+ """
3377
+ if isinstance(key, int):
3378
+ return (key,)
3379
+ else:
3380
+ return tuple(key)
3381
+
3382
+
3383
def _natural_key_idx(shape: Shape, key: Key) -> int:
    """
    What is the natural index of the given key, assuming the given shape.

    Args:
        shape: the shape defining the state space.
        key: a key into the state space.

    Returns:
        an index as per enumerated instances in their natural order, i.e.
        last random variable changing most quickly. For an empty shape
        (and so an empty key) the result is 0.

    Assumes:
        The key is valid for the shape.
    """
    instance: Instance = _key_to_instance(key)
    # Fold from zero over every position, so that an empty key is handled.
    # The first step gives `0 * shape[0] + instance[0]`, i.e., `instance[0]`.
    result: int = 0
    for size, state_idx in zip(shape, instance):
        result = result * size + state_idx
    return result
3403
+
3404
+
3405
def _zero_space(shape: Shape) -> int:
    """
    Return the size of the zero space of the given shape. This is the number
    of possible instances in the state space that do not have a zero in the instance.

    The zero space is the same as the shape but with one less state
    for each random variable.

    Args:
        shape: the shape defining the state space.

    Returns:
        the size of the zero space, i.e., `_multiply` over `size - 1`
        for each size in the shape.
    """
    reduced_sizes = (size - 1 for size in shape)
    return _multiply(reduced_sizes)
3420
+
3421
+
3422
+ def _normalise_potential_function(
3423
+ function: Union[DensePotentialFunction, SparsePotentialFunction],
3424
+ grouping_positions: Sequence[int],
3425
+ ) -> None:
3426
+ """
3427
+ Convert the potential function to a CPT with 'grouping_positions' nominating
3428
+ the parent random variables.
3429
+
3430
+ I.e., for each possible key of the function with the same value at each
3431
+ grouping position, the sum of values for matching keys in the factor is scaled
3432
+ to be 1 (or 0).
3433
+
3434
+ Parameter 'grouping_positions' are indices into `function.shape`. For example, the
3435
+ grouping positions of a factor with parent rvs 'conditioning_rvs', then
3436
+ grouping_positions = [i for i, rv in enumerate(factor.rvs) if rv in conditioning_rvs].
3437
+
3438
+ Args:
3439
+ function: the potential function to normalise.
3440
+ grouping_positions: indices into `function.shape`.
3441
+ """
3442
+ if len(grouping_positions) == 0:
3443
+ total = sum(
3444
+ function.param_value(param_idx)
3445
+ for param_idx in range(function.number_of_parameters)
3446
+ )
3447
+ if total != 0 and total != 1:
3448
+ for param_key, param_idx, param_value in function.keys_with_param:
3449
+ function.set_param_value(param_idx, param_value / total)
3450
+ else:
3451
+ group_sum = {}
3452
+ for param_key, param_idx, param_value in function.keys_with_param:
3453
+ group = tuple(param_key[i] for i in grouping_positions)
3454
+ group_sum[group] = group_sum.get(group, 0) + param_value
3455
+
3456
+ for param_key, param_idx, param_value in function.keys_with_param:
3457
+ group = tuple(param_key[i] for i in grouping_positions)
3458
+ total = group_sum[group]
3459
+ if total > 0:
3460
+ function.set_param_value(param_idx, param_value / total)
3461
+
3462
+
3463
+ _CLEAN_CHARS: Set[str] = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-+~?.')
3464
+
3465
+
3466
+ def _clean_str(s) -> str:
3467
+ """
3468
+ Quote a string if empty or not all characters are in _CLEAN_CHARS.
3469
+ This is used when rendering indicators.
3470
+ """
3471
+ s = str(s)
3472
+ if len(s) == 0 or not all(c in _CLEAN_CHARS for c in s):
3473
+ return repr(s)
3474
+ else:
3475
+ return s