compiled-knowledge 4.0.0__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (182) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37515 -0
  4. ck/circuit/_circuit_cy.cpython-312-darwin.so +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +75 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +27 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19835 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-darwin.so +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +128 -0
  16. ck/circuit_compiler/interpret_compiler.py +255 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +552 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-darwin.so +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +251 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +70 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +56 -0
  58. ck/example/truss.py +51 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +482 -0
  63. ck/in_out/parser_utils.py +189 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3482 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +236 -0
  73. ck/pgm_circuit/pgm_circuit.py +88 -0
  74. ck/pgm_circuit/program_with_slotmap.py +217 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +78 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +60 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-darwin.so +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +572 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +52 -0
  100. ck/probability/pgm_probability_space.py +36 -0
  101. ck/probability/probability_space.py +627 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +106 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +234 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +164 -0
  118. ck/utils/local_config.py +278 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/ace/simple_ace_demo.py +18 -0
  129. ck_demos/all_demos.py +88 -0
  130. ck_demos/circuit/__init__.py +0 -0
  131. ck_demos/circuit/demo_circuit_dump.py +22 -0
  132. ck_demos/circuit/demo_derivatives.py +43 -0
  133. ck_demos/circuit_compiler/__init__.py +0 -0
  134. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  135. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  136. ck_demos/getting_started/__init__.py +0 -0
  137. ck_demos/getting_started/simple_demo.py +18 -0
  138. ck_demos/pgm/__init__.py +0 -0
  139. ck_demos/pgm/demo_pgm_dump.py +18 -0
  140. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  141. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  142. ck_demos/pgm/show_examples.py +25 -0
  143. ck_demos/pgm_compiler/__init__.py +0 -0
  144. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  145. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  146. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  147. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  148. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  149. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  150. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  151. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  152. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  153. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  154. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  155. ck_demos/pgm_inference/__init__.py +0 -0
  156. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  157. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  158. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  159. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  160. ck_demos/programs/__init__.py +0 -0
  161. ck_demos/programs/demo_program_buffer.py +24 -0
  162. ck_demos/programs/demo_program_multi.py +24 -0
  163. ck_demos/programs/demo_program_none.py +19 -0
  164. ck_demos/programs/demo_program_single.py +23 -0
  165. ck_demos/programs/demo_raw_program_dump.py +17 -0
  166. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  167. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  168. ck_demos/sampling/__init__.py +0 -0
  169. ck_demos/sampling/check_sampler.py +71 -0
  170. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  171. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  172. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  173. ck_demos/utils/__init__.py +0 -0
  174. ck_demos/utils/compare.py +120 -0
  175. ck_demos/utils/convert_network.py +45 -0
  176. ck_demos/utils/sample_model.py +216 -0
  177. ck_demos/utils/stop_watch.py +384 -0
  178. compiled_knowledge-4.0.0.dist-info/METADATA +50 -0
  179. compiled_knowledge-4.0.0.dist-info/RECORD +182 -0
  180. compiled_knowledge-4.0.0.dist-info/WHEEL +6 -0
  181. compiled_knowledge-4.0.0.dist-info/licenses/LICENSE.txt +21 -0
  182. compiled_knowledge-4.0.0.dist-info/top_level.txt +2 -0
ck/pgm.py ADDED
@@ -0,0 +1,3482 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from abc import ABC, abstractmethod
5
+ from dataclasses import dataclass
6
+ from itertools import repeat as _repeat
7
+ from typing import Sequence, Tuple, Dict, Optional, overload, Set, Iterable, List, Union, Callable, \
8
+ Collection, Any, Iterator, TypeAlias
9
+
10
+ import numpy as np
11
+
12
+ from ck.utils.iter_extras import (
13
+ combos_ranges as _combos_ranges, multiply as _multiply, combos as _combos
14
+ )
15
+ from ck.utils.np_extras import NDArrayFloat64, NDArrayUInt8
16
+
17
+ State: TypeAlias = Union[int, str, bool, float, None]
18
+ """
19
+ The type for a possible state of a random variable.
20
+ """
21
+
22
+ Instance: TypeAlias = Sequence[int]
23
+ """
24
+ An instance (of a sequence of random variables) is a sequence of integers
25
+ that are state indexes, co-indexed with a known sequence of random variables.
26
+ """
27
+
28
+ Key: TypeAlias = Union[Instance, int]
29
+ """
30
+ A key identifies an instance, either as an instance itself or a
31
+ single integer, representing an instance with one dimension.
32
+ """
33
+
34
+ Shape: TypeAlias = Sequence[int]
35
+ """
36
+ The type for the "shape" of a sequence of random variables.
37
+ That is, the shape of (rv1, rv2, rv3) is (len(rv1), len(rv2), len(rv3)).
38
+ """
39
+
40
+ DEFAULT_CPT_TOLERANCE: float = 0.000001
41
+ """
42
+ A tolerance when checking CPT distributions sum to one (or zero).
43
+ """
44
+
45
+
46
+ class PGM:
47
+ """
48
+ A probabilistic graphical model (PGM) represents a joint probability distribution over
49
+ a set of random variables. Specifically, a PGM is a factor graph with discrete random variables.
50
+
51
+ Add a random variable to a PGM, `pgm`, using `rv = pgm.new_rv(...)`.
52
+
53
+ Add a factor to the PGM, `pgm`, using `factor = pgm.new_factor(...)`.
54
+ """
55
+
56
def __init__(self, name: Optional[str] = None):
    """
    Construct a PGM that has no random variables and no factors.

    Args:
        name: optional name for the PGM. When None, the name is
            obtained by calling `default_pgm_name(self)`.
    """
    if name is None:
        name = default_pgm_name(self)
    self._name: str = name

    # All collections start empty; they are held as immutable tuples.
    self._rvs: Tuple[RandomVariable, ...] = ()
    self._shape: Shape = ()
    self._indicators: Tuple[Indicator, ...] = ()
    self._factors: Tuple[Factor, ...] = ()
69
+
70
@property
def name(self) -> str:
    """
    Returns:
        the name this PGM was given (or assigned by default) at construction.
    """
    return self._name
77
+
78
@property
def number_of_rvs(self) -> int:
    """
    Returns:
        the count of random variables currently defined in this PGM.
    """
    return len(self._rvs)
85
+
86
@property
def shape(self) -> Shape:
    """
    Returns:
        the stored shape of this PGM: a sequence of the lengths of `self.rvs`.
    """
    return self._shape
93
+
94
@property
def number_of_indicators(self) -> int:
    """
    Returns:
        the count of indicators defined in this PGM,
        i.e., `sum(len(rv) for rv in self.rvs)`.
    """
    return len(self._indicators)
101
+
102
@property
def number_of_states(self) -> int:
    """
    Returns:
        the size of the state space of this PGM,
        i.e., `multiply(len(rv) for rv in self.rvs)`.
    """
    # Delegates to the module-level `number_of_states` function.
    all_rvs = self._rvs
    return number_of_states(*all_rvs)
109
+
110
@property
def number_of_factors(self) -> int:
    """
    Returns:
        the count of factors currently defined in this PGM.
    """
    return len(self._factors)
117
+
118
@property
def number_of_functions(self) -> int:
    """
    Returns:
        the count of potential functions yielded by `self.functions`,
        which includes zero potential functions.
    """
    count = 0
    for _ in self.functions:
        count += 1
    return count
125
+
126
@property
def number_of_non_zero_functions(self) -> int:
    """
    Returns:
        the count of potential functions yielded by `self.non_zero_functions`,
        which excludes zero potential functions.
    """
    count = 0
    for _ in self.non_zero_functions:
        count += 1
    return count
133
+
134
@property
def rvs(self) -> Sequence[RandomVariable]:
    """
    Returns:
        every random variable of this PGM, ordered by `idx`
        (which is also the order of creation).

    Ensures:
        `self.rvs[rv.idx] = rv`
    """
    return self._rvs
144
+
145
@property
def rv_log_sizes(self) -> Sequence[float]:
    """
    Returns:
        a new list holding `log2(len(rv))` for each rv in `self.rvs`,
        in the same order as `self.rvs`.
    """
    sizes = map(len, self.rvs)
    return list(map(math.log2, sizes))
152
+
153
@property
def indicators(self) -> Sequence[Indicator]:
    """
    Returns:
        every indicator of every random variable of this PGM.

    Ensures:
        indicators belonging to one random variable are adjacent,
        indicators belonging to one random variable are in state index order,
        random variables appear in the same order as in `self.rvs`.
    """
    return self._indicators
165
+
166
@property
def factors(self) -> Sequence[Factor]:
    """
    Returns:
        every factor of this PGM, ordered by `idx`
        (which is also the order of creation).

    Ensures:
        `self.factors[factor.idx] = factor`
    """
    return self._factors
176
+
177
@property
def functions(self) -> Iterable[PotentialFunction]:
    """
    Iterate over the potential functions that are in use by the factors
    of this PGM, zero potential functions included.

    A function object shared by several factors is yielded once only,
    at the position of the first factor that uses it.

    Returns:
        An Iterable over all potential functions (including zero potential functions).
    """
    # Functions are de-duplicated by object identity, not by equality.
    emitted_ids: Set[int] = set()
    for factor in self._factors:
        candidate = factor.function
        candidate_id = id(candidate)
        if candidate_id in emitted_ids:
            continue
        emitted_ids.add(candidate_id)
        yield candidate
192
+
193
@property
def non_zero_functions(self) -> Iterable[PotentialFunction]:
    """
    Iterate over the potential functions that are in use by the factors
    of this PGM, skipping every `ZeroPotentialFunction`.

    A function object shared by several factors is yielded once only,
    at the position of the first factor that uses it.

    Returns:
        An Iterable over all potential functions (excluding zero potential functions).
    """
    # Functions are de-duplicated by object identity, not by equality.
    emitted_ids: Set[int] = set()
    for factor in self._factors:
        candidate = factor.function
        if isinstance(candidate, ZeroPotentialFunction):
            continue
        if id(candidate) in emitted_ids:
            continue
        emitted_ids.add(id(candidate))
        yield candidate
208
+
209
def new_rv(self, name: str, states: Union[int, Sequence[State]]) -> RandomVariable:
    """
    Create a random variable and add it to this PGM.

    The `idx` of the returned random variable equals the value that
    `self.number_of_rvs` had immediately before the call.

    `states` may be an integer, `n`, giving the number of states, in which
    case the states are 0, 1, ..., n-1. Alternatively `states` may be
    a sequence of state values, which must be unique.

    Assumes:
        Provided states contain no duplicates.

    Args:
        name: a name for the random variable.
        states: either the number of states or a sequence of state values.

    Returns:
        a RandomVariable object belonging to this PGM.
    """
    created = RandomVariable(self, name, states)
    return created
232
+
233
def new_factor(self, *rvs: RandomVariable) -> Factor:
    """
    Create a factor connecting the given random variables and add it to this PGM.

    The potential function of the returned factor is a ZeroPotentialFunction;
    it may be changed by calling methods on the returned factor.

    The `idx` of the returned factor equals the value that
    `self.number_of_factors` had immediately before the call.

    Assumes:
        The given random variables all belong to this PGM.

        The random variables contain no duplicates.

    Args:
        rvs: the random variables.

    Returns:
        a Factor object belonging to this PGM.
    """
    created = Factor(self, *rvs)
    return created
256
+
257
def new_factor_implies(
        self,
        rv_1: RandomVariable,
        state_idxs_1: int | Collection[int],
        rv_2: RandomVariable,
        state_idxs_2: int | Collection[int],
) -> Factor:
    """
    Add a sparse 0/1 factor to this PGM representing:
        rv_1 in state_idxs_1 ==> rv_2 in state_idxs_2.
    That is:
        factor[s1, s2] = 1, if s1 not in state_idxs_1 or s2 in state_idxs_2;
                       = 0, otherwise.

    Args:
        rv_1: The first random variable.
        state_idxs_1: state idxs of the first random variable (a single int is
            treated as a collection of one).
        rv_2: The second random variable.
        state_idxs_2: state idxs of the second random variable (a single int is
            treated as a collection of one).

    Returns:
        a Factor object belonging to this PGM, with a configured sparse potential function.
    """
    # Normalise each argument to a frozenset once, so that the membership
    # tests inside the nested loops are constant time for any collection type.
    antecedent = frozenset((state_idxs_1,) if isinstance(state_idxs_1, int) else state_idxs_1)
    consequent = frozenset((state_idxs_2,) if isinstance(state_idxs_2, int) else state_idxs_2)

    factor = self.new_factor(rv_1, rv_2)
    f = factor.set_sparse()
    for i_1 in rv_1.state_range():
        antecedent_holds = i_1 in antecedent
        for i_2 in rv_2.state_range():
            # The implication is violated only when the antecedent holds
            # and the consequent does not; those entries are left unset.
            if not antecedent_holds or i_2 in consequent:
                f[i_1, i_2] = 1
    return factor
296
+
297
def new_factor_equiv(
        self,
        rv_1: RandomVariable,
        state_idxs_1: int | Collection[int],
        rv_2: RandomVariable,
        state_idxs_2: int | Collection[int],
) -> Factor:
    """
    Add a sparse 0/1 factor to this PGM representing:
        rv_1 in state_idxs_1 <==> rv_2 in state_idxs_2.
    That is:
        factor[s1, s2] = 1, if (s1 in state_idxs_1) == (s2 in state_idxs_2);
                       = 0, otherwise.

    Args:
        rv_1: The first random variable.
        state_idxs_1: state idxs of the first random variable (a single int is
            treated as a collection of one).
        rv_2: The second random variable.
        state_idxs_2: state idxs of the second random variable (a single int is
            treated as a collection of one).

    Returns:
        a Factor object belonging to this PGM, with a configured sparse potential function.
    """
    # Normalise each argument to a frozenset once, so that the membership
    # tests inside the nested loops are constant time for any collection type.
    members_1 = frozenset((state_idxs_1,) if isinstance(state_idxs_1, int) else state_idxs_1)
    members_2 = frozenset((state_idxs_2,) if isinstance(state_idxs_2, int) else state_idxs_2)

    factor = self.new_factor(rv_1, rv_2)
    f = factor.set_sparse()
    for i_1 in rv_1.state_range():
        in_1 = i_1 in members_1
        for i_2 in rv_2.state_range():
            # Set the entry only when both sides agree (both in, or both out).
            if in_1 == (i_2 in members_2):
                f[i_1, i_2] = 1
    return factor
334
+
335
def new_factor_functional(
        self,
        function: Callable[..., int],
        result_rv: RandomVariable,
        *input_rvs: RandomVariable
) -> Factor:
    """
    Add a sparse 0/1 factor to this PGM representing `result_rv == function(*input_rvs)`.
    That is::

        factor[result_s, *input_s] = 1, if result_s == function(*input_s);
                                   = 0, otherwise.

    Args:
        function: a function from state indexes of the input random variables to a state index
            of the result random variable. The function should take the same number of arguments
            as `input_rvs` and return a state index for `result_rv`.
        result_rv: the random variable defining result values.
        input_rvs: the random variables defining input values.

    Returns:
        a Factor object belonging to this PGM, with a configured sparse potential function.
    """
    factor = self.new_factor(result_rv, *input_rvs)
    f = factor.set_sparse()
    for input_s in _combos([list(rv.state_range()) for rv in input_rvs]):
        result_s = function(*input_s)
        # Unpacking builds the tuple key whatever sequence type the
        # combination iterator yields (tuple or list).
        f[(result_s, *input_s)] = 1
    return factor
364
+
365
def indicator_pair(self, indicator: Indicator) -> Tuple[RandomVariable, State]:
    """
    Resolve an indicator into the random variable and state value it denotes.

    Args:
        indicator: the indicator to convert.

    Returns:
        (rv, state) where
        rv: is the random variable at position `indicator.rv_idx` of this PGM.
        state: is the state of that random variable at position `indicator.state_idx`.
    """
    owner = self._rvs[indicator.rv_idx]
    return owner, owner.states[indicator.state_idx]
380
+
381
def indicator_str(self, *indicators: Indicator, sep: str = '=', delim: str = ', ') -> str:
    """
    Render the given indicators as a single string, in the order given.

    For example::
        pgm = PGM()
        a = pgm.new_rv('A', ('x', 'y', 'z'))
        b = pgm.new_rv('B', (3, 5))
        print(pgm.indicator_str(a[0], b[1], a[2]))

    will print::
        A=x, B=5, A=z

    Args:
        indicators: the indicators to render.
        sep: the separator placed between a random variable and its state.
        delim: the delimiter placed between rendered indicators.

    Returns:
        a string representation of the given indicators.
    """
    rendered: List[str] = []
    for indicator in indicators:
        rv, state = self.indicator_pair(indicator)
        rendered.append(f'{_clean_str(rv)}{sep}{_clean_str(state)}')
    return delim.join(rendered)
409
+
410
def condition_str(self, *indicators: Indicator) -> str:
    """
    Render indicators as a condition string, with the indicators of each
    random variable grouped together.

    For example::
        pgm = PGM()
        a = pgm.new_rv('A', ('x', 'y', 'z'))
        b = pgm.new_rv('B', (3, 5))
        print(pgm.condition_str(a[0], b[1], a[2]))

    will print::
        A in {x, z}, B=5

    Args:
        indicators: the indicators to render.
    Return:
        a string representation of the given indicators, as a condition.
    """
    # Sorting brings indicators with the same rv_idx together; walking the
    # descending sort backwards visits them in ascending order.
    descending: List[Indicator] = sorted(indicators, reverse=True)

    group: Set[Indicator] = set()  # indicators collected for the current rv
    group_rv_idx: int = -1  # rv_idx of the current group, -1 means no group started
    text: str = ''  # accumulated result string

    for this_ind in reversed(descending):
        if this_ind.rv_idx != group_rv_idx:
            # A new random variable begins: flush the finished group, if any.
            if group_rv_idx >= 0:
                text = self._condition_str_rv(text, group)
            group = set()
            group_rv_idx = this_ind.rv_idx
        group.add(this_ind)

    # Flush the final group.
    if group_rv_idx >= 0:
        text = self._condition_str_rv(text, group)
    return text
443
+
444
def instance_str(
        self,
        instance: Instance,
        rvs: Optional[Sequence[RandomVariable]] = None,
        sep: str = '=',
        delim: str = ', ',
) -> str:
    """
    Render an instance as a string.

    The result looks something like 'X=x, Y=y, Z=z' where X, Y, and Z are
    random variables and x, y, and z are the states selected by the
    given instance.

    Args:
        instance: the instance to render.
        rvs: the random variables that the instance refers to. If rvs is None, then `self.rvs` is used.
        sep: the separator placed between a random variable and its state.
        delim: the delimiter placed between rendered indicators.

    Returns:
        a string representation of the indicators implied by the given instance.
    """
    target_rvs = self.rvs if rvs is None else rvs
    assert len(instance) == len(target_rvs)

    # Indexing a random variable by a state index gives the matching indicator.
    implied = [rv[state_idx] for rv, state_idx in zip(target_rvs, instance)]
    return self.indicator_str(*implied, sep=sep, delim=delim)
475
+
476
def state_str(
        self,
        instance: Instance,
        rvs: Optional[Sequence[RandomVariable]] = None,
        delim: str = ', ',
) -> str:
    """
    Render only the states of an instance.

    The result looks something like 'x, y, z' where x, y, and z are
    the states of the random variables selected by the given instance.

    Args:
        instance: the instance to render.
        rvs: the random variables that the instance refers to. If rvs is None, then `self.rvs` is used.
        delim: the delimiter placed between rendered states.

    Returns:
        a string representation of the states implied by the given instance.
    """
    target_rvs = self.rvs if rvs is None else rvs
    assert len(instance) == len(target_rvs)

    rendered: List[str] = [
        str(rv.states[state_idx])
        for rv, state_idx in zip(target_rvs, instance)
    ]
    return delim.join(rendered)
500
+
501
def instances(self, flip: bool = False) -> Iterable[Instance]:
    """
    Iterate over every possible instance of this PGM, in natural index
    order (i.e., last random variable changing most quickly).

    Args:
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds random variable state indexes
        co-indexed with this PGM's random variables, `self.rvs`.
    """
    sizes = tuple(map(len, self._rvs))
    # The helper is passed the negation of `flip`; presumably its own `flip`
    # convention is the opposite of this method's (confirm in `combos_ranges`).
    helper_flip = not flip
    return _combos_ranges(sizes, flip=helper_flip)
514
+
515
def instances_as_indicators(self, flip: bool = False) -> Iterable[Sequence[Indicator]]:
    """
    Iterate over every possible instance of this PGM, in the same order
    as `self.instances(flip=flip)`, with each instance converted to indicators.

    Args:
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds random variable indicators
        co-indexed with this PGM's random variables, `self.rvs`.
    """
    yield from map(self.state_idxs_to_indicators, self.instances(flip=flip))
529
+
530
def state_idxs_to_indicators(self, instance: Sequence[int]) -> Sequence[Indicator]:
    """
    Convert an instance (state indexes co-indexed with `self.rvs`) into
    the corresponding indicators.

    Assumes:
        The instance has the same length as `self.rvs`.
        The instance is co-indexed with `self.rvs`.

    Args:
        instance: the instance to convert to indicators.

    Returns:
        a tuple of indicators, co-indexed with `self.rvs`.
    """
    # Indexing a random variable by a state index gives the matching indicator.
    paired = zip(self._rvs, instance)
    return tuple(rv[state_idx] for rv, state_idx in paired)
546
+
547
def factor_values(self, key: Key) -> Iterable[float]:
    """
    Yield, for the instance identified by `key`, the single value that
    each factor of this PGM defines for that instance.

    Args:
        key: the key defining an instance of this PGM.

    Returns:
        an iterator over factor values, co-indexed with the factors of this PGM.
    """
    full_instance: Instance = check_key(self._shape, key)
    assert len(full_instance) == len(self._rvs)
    for factor in self._factors:
        # Project the full instance onto the random variables of this factor.
        factor_states: Sequence[int] = tuple(full_instance[rv.idx] for rv in factor.rvs)
        yield factor.function[factor_states]
564
+
565
@property
def is_structure_bayesian(self) -> bool:
    """
    Does the structure of this PGM correspond to a Bayesian network, where
    every factor is taken to be a CPT whose first random variable is the child.

    Only the structure is examined; factor parameters are not checked
    to confirm they are valid CPTs.

    Return:
        True only if:
            the number of factors equals the number of random variables,
            each random variable appears exactly once as the first random variable of a factor,
            there are no directed loops created by the factors.
    """
    rv_count = self.number_of_rvs

    # There must be exactly one factor per random variable.
    if self.number_of_factors != rv_count:
        return False

    # Map each child random variable (by idx) to the factor it is the child of.
    # A random variable that is the child of two factors collapses to one
    # entry, which makes the size check below fail.
    child_to_factor: Dict[int, Factor] = {}
    for factor in self._factors:
        child_to_factor[factor.rvs[0].idx] = factor
    if len(child_to_factor) != rv_count:
        return False

    # The factors must form a DAG. The array is scratch state shared by every
    # `_has_cycle` call (presumably per-factor visit marks; see `_has_cycle`).
    visit_states: NDArrayUInt8 = np.zeros(self.number_of_factors, dtype=np.uint8)
    return not any(
        self._has_cycle(factor, child_to_factor, visit_states)
        for factor in self._factors
    )
603
+
604
def factors_are_cpts(self, tolerance: float = DEFAULT_CPT_TOLERANCE) -> bool:
    """
    Do the parameter values of every in-use potential function
    conform to a Conditional Probability Table.

    Assumes:
        tolerance is non-negative.

    Args:
        tolerance: a tolerance when testing if values are equal to zero or one.

    Returns:
        True only if every potential function conforms to being a valid CPT.
    """
    for function in self.functions:
        if not function.is_cpt(tolerance):
            # Stop at the first function that is not a CPT.
            return False
    return True
619
+
620
def check_is_bayesian_network(self, tolerance: float = DEFAULT_CPT_TOLERANCE) -> bool:
    """
    Is this PGM a Bayesian network, in both structure and parameters.

    Assumes:
        tolerance is non-negative.

    Args:
        tolerance: a tolerance when testing if values are equal to zero or one.

    Returns:
        `is_structure_bayesian and factors_are_cpts(tolerance)`.
    """
    # The parameter check is skipped when the structure test fails.
    structure_ok = self.is_structure_bayesian
    return structure_ok and self.factors_are_cpts(tolerance)
634
+
635
def value_product(self, key: Key) -> float:
    """
    Multiply together the values that the factors of this PGM
    define for the instance identified by `key`.

    Args:
        key: the key defining an instance of this PGM.

    Returns:
        the product of factor values.
    """
    per_factor_values = self.factor_values(key)
    return _multiply(per_factor_values)
647
+
648
def value_product_indicators(self, *indicators: Indicator) -> float:
    """
    Return the product of factors, conditioned on the given indicators.

    A random variable not mentioned by any indicator is marginalised: the
    result sums the value product over every state of that random variable.

    With no indicators at all, the result is the value of the
    partition function (z).

    When several indicators mention the same random variable, all matching
    instances are summed.

    This method has the same semantics as `ProbabilitySpace.wmc` without conditioning.

    Warning:
        this is potentially computationally expensive as it marginalises random
        variables not mentioned in the given indicators.

    Args:
        *indicators: are indicators from random variables of this PGM.

    Returns:
        the product of factors, conditioned on the given instance. This is the
        computed value of the PGM, conditioned on the given instance.
    """
    # Instead of enumerating the whole state space, only the states permitted
    # by the indicators are enumerated, so tightly constraining indicators
    # lead to a small sum.

    # Per random variable: the permitted state indexes, or None when
    # no indicator mentions that random variable.
    permitted: List[Optional[Set[int]]] = [None for _ in range(self.number_of_rvs)]
    for indicator in indicators:
        rv_idx: int = indicator.rv_idx
        states = permitted[rv_idx]
        if states is None:
            states = set()
            permitted[rv_idx] = states
        states.add(indicator.state_idx)

    # Per random variable: the concrete list of state indexes to sum over.
    sum_space: List[List[int]] = []
    for states, rv in zip(permitted, self.rvs):
        if states is None:
            sum_space.append(list(rv.state_range()))
        else:
            sum_space.append(list(states))

    # The builtin `sum` is used deliberately so that floating point
    # accumulation is performed by the interpreter, not by a manual loop.
    return sum(
        self.value_product(instance)
        for instance in _combos(sum_space)
    )
701
+
702
def dump_synopsis(
        self,
        *,
        prefix: str = '',
        precision: int = 3,
        max_state_digits: int = 21,
) -> None:
    """
    Print a synopsis of the PGM.
    This is intended for demonstration and debugging purposes.

    The number of states is printed in full when it has at most
    `max_state_digits` digits, otherwise in scientific notation.
    A state count of zero is printed as `0` with a log of `-inf`,
    rather than failing on an undefined logarithm.

    Args:
        prefix: optional prefix for indenting all lines.
        precision: a limit on the render precision of floating point numbers.
        max_state_digits: a limit on the number of digits when showing number of states as an integer.
    """
    num_states: int = self.number_of_states
    number_of_parameters = sum(function.number_of_parameters for function in self.functions)
    number_of_nz_parameters = sum(function.number_of_parameters for function in self.non_zero_functions)

    if num_states == 0:
        # Logarithms are undefined for zero, so render this case directly.
        # NOTE(review): presumably arises when some random variable has no
        # states - confirm against `number_of_states`.
        num_states_str = '0'
        log_2_num_states_str = '-inf'
    else:
        # Determine a limit to precision when displaying number of states
        log_10_num_states = math.log10(num_states)
        if log_10_num_states > max_state_digits:
            exp = int(log_10_num_states)
            man = math.pow(10, log_10_num_states - exp)
            num_states_str = f'{man:,.{precision}f}e+{exp}'
        else:
            num_states_str = f'{num_states:,}'

        # Show the log as an integer only when it is a whole number that
        # fits within the permitted number of digits.
        log_2_num_states = math.log2(num_states)
        if (
                log_2_num_states == 0
                or (
                    log_2_num_states == int(log_2_num_states)
                    and math.log10(log_2_num_states) <= max_state_digits
                )
        ):
            log_2_num_states_str = f'{int(log_2_num_states):,}'
        else:
            log_2_num_states_str = f'{log_2_num_states:,.{precision}f}'

    print(f'{prefix}name: {self.name}')
    print(f'{prefix}number of random variables: {self.number_of_rvs:,}')
    print(f'{prefix}number of indicators: {self.number_of_indicators:,}')
    print(f'{prefix}number of states: {num_states_str}')
    print(f'{prefix}log 2 of states: {log_2_num_states_str}')
    print(f'{prefix}number of factors: {self.number_of_factors:,}')
    print(f'{prefix}number of functions: {self.number_of_functions:,}')
    print(f'{prefix}number of non-zero functions: {self.number_of_non_zero_functions:,}')
    print(f'{prefix}number of parameters: {number_of_parameters:,}')
    print(f'{prefix}number of functions (excluding ZeroPotentialFunction): {self.number_of_non_zero_functions:,}')
    print(f'{prefix}number of parameters (excluding ZeroPotentialFunction): {number_of_nz_parameters:,}')
    print(f'{prefix}Bayesian structure: {self.is_structure_bayesian}')
    print(f'{prefix}CPT factors: {self.factors_are_cpts()}')
754
+
755
def dump(
        self,
        *,
        prefix: str = '',
        indent: str = ' ',
        show_function_values: bool = False,
        precision: int = 3,
        max_state_digits: int = 21,
) -> None:
    """
    Print a dump of the PGM.
    This is intended for demonstration and debugging purposes.

    The dump consists of a header line, the synopsis (see `dump_synopsis`),
    then one section each for the random variables, the factors and the
    non-zero functions, and finally a closing line.

    Args:
        prefix: optional prefix for indenting all lines.
        indent: additional prefix to use for extra indentation.
        show_function_values: if true, then the function values will be dumped.
        precision: a limit on the render precision of floating point numbers.
        max_state_digits: a limit on the number of digits when showing number of states as an integer.
    """
    inner_prefix: str = prefix + indent
    function_prefix: str = inner_prefix + indent
    pgm_id: int = id(self)

    print(f'{prefix}PGM id={pgm_id} name={self.name!r}')
    self.dump_synopsis(prefix=inner_prefix, precision=precision, max_state_digits=max_state_digits)

    # Random variables: states are listed only when they are not the defaults.
    print(f'{prefix}random variables ({self.number_of_rvs})')
    for rv in self.rvs:
        line = f'{inner_prefix}{rv.idx:>3} {rv.name!r} ({len(rv)})'
        if not rv.is_default_states():
            line += ' [' + ', '.join(map(repr, rv.states)) + ']'
        print(line)

    # Factors: each refers to its function by id, or '<zero>' if unset.
    print(f'{prefix}factors ({self.number_of_factors})')
    for factor in self.factors:
        rv_idxs = [rv.idx for rv in factor.rvs]
        if not factor.is_zero:
            function = factor.function
            function_ref = f'{id(function)}: {function.__class__.__name__}'
        else:
            function_ref = '<zero>'
        print(f'{inner_prefix}{factor.idx:>3} rvs={rv_idxs} function={function_ref}')

    # Functions: ordered by id so they can be matched to the factor listing.
    print(f'{prefix}functions ({self.number_of_functions})')
    for function in sorted(self.non_zero_functions, key=id):
        print(f'{inner_prefix}{id(function):>13}: {function.__class__.__name__}')
        function.dump(prefix=function_prefix, show_function_values=show_function_values, show_id_class=False)

    print(f'{prefix}end PGM id={pgm_id}')
808
+
809
+ def _has_cycle(self, factor: Factor, child_to_factor: Dict[int, Factor], states: NDArrayUInt8) -> bool:
810
+ """
811
+ Support function for `is_structure_bayesian`.
812
+
813
+ A recursive depth-first-search to see if the factors form a DAG.
814
+
815
+ For a factor `f` the value of states[f.idx] is the search state.
816
+ Specifically:
817
+ state 0 => the factor has not been seen yet,
818
+ state 1 => the factor is seen but not fully processed,
819
+ state 2 => the factor is fully processed.
820
+
821
+ Args:
822
+ factor: the current Factor being checked.
823
+ child_to_factor: a dictionary from `RandomVariable.idx` to Factor
824
+ with that random variable as the child.
825
+ states: depth-first-search states, i.e., `states[i]` is the state of a factor with `Factor.idx == i`.
826
+ Returns:
827
+ True if a directed cycle is detected.
828
+ """
829
+ f_idx: int = factor.idx
830
+ match states.item(f_idx):
831
+ case 1:
832
+ return True
833
+ case 0:
834
+ states[f_idx] = 1
835
+ for parent in factor.rvs[1:]:
836
+ parent_factor = child_to_factor[parent.idx]
837
+ if self._has_cycle(parent_factor, child_to_factor, states):
838
+ return True
839
+ states[f_idx] = 2
840
+ return False
841
+ return False
842
+
843
def _register_rv(self, rv: RandomVariable) -> None:
    """
    Record a newly created random variable in this PGM.

    Called by the constructor of RandomVariable. Appends the random variable
    to `self._rvs`, its number of states to `self._shape`, and its
    indicators to `self._indicators`.

    Args:
        rv: the newly constructed random variable.
    """
    # The random variable must have been constructed for this PGM.
    assert rv.pgm is self
    self._rvs += (rv,)
    self._shape += (len(rv),)
    self._indicators += rv.indicators
855
+
856
def _condition_str_rv(
        self,
        cur_str: str,
        cur_rv: Set[Indicator],
        sep: str = ', ',
        equal: str = '=',
        elem: str = ' in ',
) -> str:
    """
    Support method for `self.condition_str`.

    Renders a condition defined by a set of indicators that all belong to
    the same random variable, and appends it to a string being built.

    A single indicator is rendered as _rv_ = _state_ (via `indicator_str`);
    several indicators are rendered as _rv_ in {_state_, ...}, with the
    states in sorted indicator order.

    Args:
        cur_str: the string to append to.
        cur_rv: a set of indicators, all from the same random variable.
        sep: the separator string to use between condition components.
        equal: the string to use for _rv_ = _state_.
        elem: the string to use for _rv_ in _set_.

    Returns:
        `cur_str` appended with the new condition, `cur_rv`.
    """
    # A separator is only needed when something was already rendered.
    lead = cur_str + sep if cur_str != '' else cur_str

    if len(cur_rv) == 1:
        return lead + self.indicator_str(*cur_rv, sep=equal)

    ordered = sorted(cur_rv)
    rv = self._rvs[ordered[0].rv_idx]
    rendered_states: str = sep.join(_clean_str(rv.states[ind.state_idx]) for ind in ordered)
    return lead + f'{_clean_str(rv)}{elem}{{{rendered_states}}}'
889
+
890
+
891
+ @dataclass(frozen=True, eq=True, slots=True)
892
+ class Indicator:
893
+ """
894
+ An indicator identifies a random variable being in a particular state.
895
+
896
+ Indicators are immutable and hashable.
897
+
898
+ Note that an Indicator does not know which PGM it came from, therefore indicators from one PGM
899
+ are interchangeable with indicators of another PGM so long as corresponding random variables of the
900
+ PGMs are co-indexed (created in the same order) and corresponding random variables have the same
901
+ states.
902
+
903
+ Fields:
904
+ rv_idx: `rv.idx` where `rv` is the random variable referenced by this indicator.
905
+ state_idx: the state index of the state referenced by this indicator.
906
+ """
907
+ rv_idx: int
908
+ state_idx: int
909
+
910
+ def __lt__(self, other) -> bool:
911
+ """
912
+ Define a sort order over indicators.
913
+ When sorted, indicators are ordered by random variable index, then by state index.
914
+ """
915
+ if isinstance(other, Indicator):
916
+ if self.rv_idx < other.rv_idx:
917
+ return True
918
+ if self.rv_idx > other.rv_idx:
919
+ return False
920
+ return self.state_idx < other.state_idx
921
+ return False
922
+
923
+
924
class RandomVariable(Sequence[Indicator]):
    """
    A random variable of a probabilistic graphical model (PGM).

    Random variables are immutable and hashable.

    A RandomVariable has a fixed, finite number of states, which are
    indexed by integers counting from zero.

    Each RandomVariable object belongs to exactly one PGM object, and has an
    index (counting from zero) giving its position in that PGM's list of
    random variables.

    A random variable behaves like a sequence of Indicators, one per state.
    For a random variable `rv`, `len(rv)` is its number of states and
    `rv[i]` is the Indicator representing `rv` being in its ith state.
    Slicing returns a tuple, i.e. `rv[1:3] = (rv[1], rv[2])`.

    The name of a RandomVariable is for human convenience and has no
    functional purpose within a PGM.
    """

    def __init__(self, pgm: PGM, name: str, states: Union[int, Sequence[State]]):
        """
        Create a new random variable, in the given PGM.

        The states may be given either as an integer `n`, in which case the
        states are 0, 1, ..., n-1, or as a sequence of state values, which
        must be unique.

        Args:
            pgm: the PGM that the random variable will belong to.
            name: a name for the random variable.
            states: either the number of states or a sequence of state values.

        Raises:
            ValueError: if the provided states are not unique.
        """
        self._pgm: PGM = pgm
        self._name: str = name

        # An integer `n` is shorthand for the states 0, 1, ..., n-1.
        state_values = range(states) if isinstance(states, int) else states
        self._states: Sequence[State] = tuple(state_values)

        # Reverse lookup from state value to state index.
        self._inv_states: Dict[State, int] = dict(zip(self._states, range(len(self._states))))
        if len(self._inv_states) != len(self._states):
            raise ValueError('random variable states are not unique')

        # Read the PGM's current counts before this object is registered.
        self._offset: int = pgm.number_of_indicators
        self._idx: int = pgm.number_of_rvs
        self._indicators: Sequence[Indicator] = tuple(
            Indicator(self._idx, state_idx) for state_idx, _ in enumerate(self._states)
        )

        # Register self with our PGM
        # noinspection PyProtectedMember
        pgm._register_rv(self)

    @property
    def pgm(self) -> PGM:
        """
        Returns:
            The PGM that owns this random variable.
        """
        return self._pgm

    @property
    def name(self) -> str:
        """
        Returns:
            The name given to this random variable.
        """
        return self._name

    @property
    def idx(self) -> int:
        """
        Returns:
            The position of this random variable in its PGM.

        Ensures:
            `self.pgm.rvs[self.idx] is self`.
        """
        return self._idx

    @property
    def offset(self) -> int:
        """
        Returns:
            The position, within the PGM's indicators, of the first indicator of this random variable.

        Ensures:
            `self.pgm.indicators[self.offset + i] is self[i] for i in range(len(self))`.
        """
        return self._offset

    @property
    def states(self) -> Sequence[State]:
        """
        Returns:
            the state values of this random variable, ordered by state index.
        """
        return self._states

    @property
    def indicators(self) -> Sequence[Indicator]:
        """
        Returns:
            the indicators of this random variable, ordered by state index.
        """
        return self._indicators

    def state_range(self) -> Iterable[int]:
        """
        The state indexes of this random variable, in ascending order.

        Returns:
            range(len(self))
        """
        number_of_states = len(self._states)
        return range(number_of_states)

    def factors(self) -> Iterable[Factor]:
        """
        Iterate over the factors that include this random variable.
        This method performs a search through all `self.pgm.factors`.

        Returns:
            an iterator over factors.
        """
        yield from (factor for factor in self._pgm.factors if self in factor.rvs)

    def markov_blanket(self) -> Set[RandomVariable]:
        """
        Collect the random variables that share a factor with this one.
        This method performs a search through all `self.pgm.factors`.

        Returns:
            a set of random variables connected to this random variable by any factor, excluding self.
        """
        neighbours = {rv for factor in self.factors() for rv in factor.rvs}
        neighbours.discard(self)
        return neighbours

    def state_idx(self, state: State) -> int:
        """
        Returns:
            the state index of the given state of this random variable.

        Assumes:
            the given state is a state of this random variable.
        """
        return self._inv_states[state]

    def is_default_states(self) -> bool:
        """
        Check whether every state equals its own state index.
        I.e., `self.states[i] == i, for all 0 <= i < len(self)`.

        Returns:
            True only if the states are the same as the state indexes.
        """
        for position, state in enumerate(self._states):
            if position == state:
                continue
            return False
        return True

    def __str__(self) -> str:
        """
        Returns:
            the name of this random variable.
        """
        return self._name

    def __call__(self, state: State) -> Indicator:
        """
        Look up the indicator for a state value.
        This is equivalent to self[self.state_idx(state)].

        Returns:
            an indicator of this random variable.

        Assumes:
            the given state is a state of this random variable.
        """
        position = self._inv_states[state]
        return self._indicators[position]

    def __hash__(self) -> int:
        """
        A random variable hashes to its index in the PGM.
        """
        return self._idx

    def __eq__(self, other) -> bool:
        """
        Equality is object identity.
        """
        return self is other

    def equivalent(self, other: RandomVariable | Sequence[Indicator]) -> bool:
        """
        Check whether this random variable has the same indicators as `other`.
        Only random variable indexes and state indexes are compared; names
        of random variables and of their states are ignored.

        Slot maps operate across `equivalent` random variables.
        This means indicators of equivalent random variables will work
        correctly in slot maps, even if from different PGMs.

        Args:
            other: either a random variable or a sequence of Indicators.

        Returns:
            True only if they represent the same sequence of indicators.
        """
        if isinstance(other, RandomVariable):
            # Same index and same number of states imply equal indicators.
            return self.idx == other.idx and len(self) == len(other)

        mine = self._indicators
        return (
            len(mine) == len(other) and
            all(indicator == other[i] for i, indicator in enumerate(mine))
        )

    def __len__(self) -> int:
        """
        Returns:
            Number of states (or equivalently, the number of indicators) of this random variable.
        """
        return len(self._states)

    def __iter__(self) -> Iterator[Indicator]:
        """
        Iterate over the indicators, in state index order.
        """
        return iter(self._indicators)

    @overload
    def __getitem__(self, index: int) -> Indicator:
        ...

    @overload
    def __getitem__(self, index: slice) -> Sequence[Indicator]:
        ...

    def __getitem__(self, index):
        """
        Get the indicator at an index, or the indicators of a slice.
        """
        return self._indicators[index]

    def index(self, value: Any, start: int = 0, stop: int = -1) -> int:
        """
        Returns the first index of `value`.
        Raises ValueError if the value is not present.

        A negative `stop` is taken relative to the end, such that the
        default of -1 means the search extends to the last indicator.

        This method is contracted by `Sequence[Indicator]`.

        Warning:
            This method is different to `self.idx`.
        """
        if isinstance(value, Indicator) and value.rv_idx == self._idx:
            size = len(self)
            upper = stop if stop >= 0 else size + stop + 1
            position: int = value.state_idx
            if 0 <= position < size and start <= position < upper:
                return position
        raise ValueError(f'{value!r} is not an indicator of the random variable')

    def count(self, value: Any) -> int:
        """
        Returns the number of occurrences of `value`.
        This is 1 when `value` is an indicator of this random variable,
        and 0 otherwise.

        This method is contracted by `Sequence[Indicator]`.
        """
        if (
                isinstance(value, Indicator)
                and value.rv_idx == self._idx
                and 0 <= value.state_idx < len(self)
        ):
            return 1
        return 0
1209
+
1210
+
1211
class _RVMapAttributeError(AttributeError, KeyError):
    """
    Raised by `RVMap.__getattr__` when a name is not a random variable.

    It is an AttributeError so that `hasattr`, `getattr` with a default,
    `copy` and `pickle` behave correctly, and also a KeyError so that
    callers that already catch KeyError keep working.
    """


class RVMap(Sequence[RandomVariable]):
    """
    Wrap a PGM to provide convenient access to PGM random variables.

    An RVMap of a PGM behaves like the PGM `rvs` property (sequence of
    RandomVariable objects), with additional access methods for the PGM's
    random variables.

    If the underlying PGM is updated, then the RVMap will automatically update.

    In addition to accessing a random variable by its index, an RVMap enables
    access to the PGM random variable via the name of each random variable.

    For example, if `pgm.rvs[1]` is a random variable named `xray`, then::

        rvs = RVMap(pgm)

        # These all retrieve the same random variable object.
        xray = rvs[1]
        xray = rvs('xray')
        xray = rvs.xray

    To use an RVMap on a PGM, the random variable names must be unique across the PGM.
    """

    def __init__(self, pgm: PGM, ignore_case: bool = False):
        """
        Construct an RVMap for the given PGM.

        Args:
            pgm: the PGM to wrap.
            ignore_case: if true, the variable name are not case-sensitive.

        Raises:
            RuntimeError: if the random variable names of the PGM are not
                unique, or clash with an attribute name of the RVMap.
        """
        self._pgm: PGM = pgm
        self._ignore_case: bool = ignore_case
        self.__rv_map: Dict[str, RandomVariable] = {}
        # Attribute names of this object cannot be used as random variable
        # names, as they would be shadowed in attribute style access.
        self._reserved_names: Set[str] = {self._clean_name(name) for name in dir(self)}

        # Force the rv map cache to be updated.
        # This may raise an exception.
        _ = self._rv_map

    def new_rv(self, name: str, states: Union[int, Sequence[State]]) -> RandomVariable:
        """
        As per `PGM.new_rv`.
        Delegate creating a new random variable to the PGM.

        Returns:
            a RandomVariable object belonging to the PGM.
        """
        return self._pgm.new_rv(name, states)

    def __len__(self) -> int:
        """
        Returns:
            the number of random variables in the PGM.
        """
        return len(self._pgm.rvs)

    def __getitem__(self, index: int) -> RandomVariable:
        """
        Get a random variable by its index, as per `pgm.rvs[index]`.
        """
        return self._pgm.rvs[index]

    def items(self) -> Iterable[Tuple[str, RandomVariable]]:
        """
        Returns:
            (name, random variable) pairs, as per `dict.items`.
        """
        return self._rv_map.items()

    def keys(self) -> Iterable[str]:
        """
        Returns:
            the random variable names, as per `dict.keys`.
        """
        return self._rv_map.keys()

    def values(self) -> Iterable[RandomVariable]:
        """
        Returns:
            the random variables, as per `dict.values`.
        """
        return self._rv_map.values()

    def get(self, rv_name: str, default=None):
        """
        Get a random variable by name, or `default` if there is no random
        variable with that name.
        """
        return self._rv_map.get(self._clean_name(rv_name), default)

    def __call__(self, rv_name: str) -> RandomVariable:
        """
        Get a random variable by name.

        Raises:
            KeyError: if there is no random variable with the given name.
        """
        return self._rv_map[self._clean_name(rv_name)]

    def __getattr__(self, rv_name: str) -> RandomVariable:
        """
        Get a random variable by name, using attribute access.

        Python only calls this method when normal attribute lookup fails.

        Raises:
            AttributeError: if there is no random variable with the given
                name. The exception raised is also a KeyError.
        """
        # If one of this object's own attributes is missing, the object has
        # not been initialised (e.g., it is being copied or unpickled).
        # Looking up a random variable would need those attributes and so
        # would recurse back into this method without end.
        if rv_name in ('_pgm', '_ignore_case', '_reserved_names', '_RVMap__rv_map'):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {rv_name!r}'
            )
        try:
            return self(rv_name)
        except KeyError as err:
            # `__getattr__` must signal a missing name with AttributeError,
            # otherwise `hasattr` and `getattr(..., default)` propagate
            # the KeyError instead of handling it.
            raise _RVMapAttributeError(
                f'{type(self).__name__!r} object has no attribute or random variable {rv_name!r}'
            ) from err

    @property
    def _rv_map(self) -> Dict[str, RandomVariable]:
        """
        Get the cached random variable map, updating as needed if the PGM changed.

        Returns:
            a mapping from random variable name to random variable

        Raises:
            RuntimeError: if the random variable names of the PGM are not
                unique, or clash with an attribute name of the RVMap.
        """
        pgm_rvs = self._pgm.rvs
        if len(self.__rv_map) != len(pgm_rvs):
            # There is a difference between the map and the PGM - create a new map.
            fresh: Dict[str, RandomVariable] = {self._clean_name(rv.name): rv for rv in pgm_rvs}
            if len(fresh) != len(pgm_rvs):
                raise RuntimeError('random variable names are not unique')
            if not self._reserved_names.isdisjoint(fresh.keys()):
                raise RuntimeError('random variable names clash with reserved names.')
            # Cache only a validated map, so that an invalid PGM is reported
            # on every access rather than only the first.
            self.__rv_map = fresh
        return self.__rv_map

    def _clean_name(self, name: str) -> str:
        """
        Adjust the case of the given name as needed.
        """
        return name.lower() if self._ignore_case else name
1309
+
1310
+
1311
class Factor:
    """
    A factor of a PGM, declaring a relationship between one or more
    random variables.

    Every Factor has an associated potential function, which defines a
    real-number value for each combination of states of the factor's
    random variables.

    The default potential function for a factor is a unique ZeroPotentialFunction.

    The order of a Factor's random variables matters, because many things
    are co-indexed with them. For example, the shape of a Factor is the
    tuple of its random variables' lengths.

    Several factors may share one potential function, provided they all
    belong to the same PGM object and have the same shape.
    """

    def __init__(self, pgm: PGM, *rvs: RandomVariable):
        """
        Add a new factor to the given PGM.

        Args:
            pgm: the PGM that the factor will belong to.
            *rvs: the random variables.

        Raises:
            ValueError: if a random variable is duplicated, if no random
                variable is given, or if a random variable does not belong
                to the given PGM.
        """
        if len(rvs) != len(set(rvs)):
            raise ValueError('duplicated random variable in factor')
        if not rvs:
            raise ValueError('must be at least one random variable')
        for rv in rvs:
            if rv.pgm is not pgm:
                raise ValueError('random variable not from the same PGM')

        self._pgm: PGM = pgm
        self._idx: int = pgm.number_of_factors
        self._rvs: Sequence[RandomVariable] = tuple(rvs)
        self._shape: Shape = tuple(map(len, rvs))

        # The factor keeps its own zero function so that it can always be
        # reset to it; this is also the initial potential function.
        zero_function = ZeroPotentialFunction(self)
        self._zero_potential_function: ZeroPotentialFunction = zero_function
        self._potential_function: PotentialFunction = zero_function

        # Register self with our PGM
        # noinspection PyProtectedMember
        pgm._factors += (self,)

    @property
    def rvs(self) -> Sequence[RandomVariable]:
        """
        Returns:
            The random variables that this factor is over.
        """
        return self._rvs

    @property
    def pgm(self) -> PGM:
        """
        Returns:
            The PGM that owns this factor.
        """
        return self._pgm

    @property
    def idx(self) -> int:
        """
        Returns:
            The position of this factor in its PGM.

        Ensures:
            `self.pgm.factors[self.idx] is self`.
        """
        return self._idx

    @property
    def shape(self) -> Shape:
        """
        Returns:
            The number of states of each random variable, co-indexed with `self.rvs`.
        """
        return self._shape

    @property
    def number_of_states(self) -> int:
        """
        How many distinct states are covered by this Factor,
        as reported by its potential function.
        """
        return self._potential_function.number_of_states

    def __str__(self) -> str:
        """
        Return a human-readable string to represent this factor.
        This is intended mainly for debugging purposes.
        """
        quoted_names = ', '.join(repr(str(rv)) for rv in self._rvs)
        return f'({quoted_names})'

    def __len__(self) -> int:
        """
        Returns:
            the number of random variables.
        """
        return len(self._rvs)

    @overload
    def __getitem__(self, index: int) -> RandomVariable:
        ...

    @overload
    def __getitem__(self, index: slice) -> Sequence[RandomVariable]:
        ...

    def __getitem__(self, index):
        """
        Get the random variable at an index, or the random variables of a slice.
        """
        return self._rvs[index]

    def instances(self, flip: bool = False) -> Iterable[Instance]:
        """
        Iterate over all possible instances, in natural index order (i.e.,
        last random variable changing most quickly).
        Delegates to the potential function of this factor.

        Args:
            flip: if true, then first random variable changes most quickly

        Returns:
            an iterator over tuples, each tuple holds random variable
            state indexes, co-indexed with this object's shape, i.e., self.shape.
        """
        potential_function = self.function
        return potential_function.instances(flip)

    def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
        """
        Iterate over all possible instances of parent random variable, in
        natural index order (i.e., last random variable changing most quickly).
        Delegates to the potential function of this factor.

        Args:
            flip: if true, then first random variable changes most quickly

        Returns:
            an iteration over tuples, each tuple holds random variable states
            co-indexed with this object's 'parent' shape, i.e., `self.shape[1:]`.
        """
        potential_function = self.function
        return potential_function.parent_instances(flip)

    @property
    def is_zero(self) -> bool:
        """
        Is the potential function of this factor set to the special 'zero' potential function.
        """
        return self._potential_function is self._zero_potential_function

    @property
    def function(self) -> PotentialFunction:
        """
        Returns:
            The current potential function of this factor.
        """
        return self._potential_function

    @function.setter
    def function(self, function: PotentialFunction | Factor) -> None:
        """
        Set the potential function for this PGM factor to the given potential function
        or, if a factor is given, to that factor's potential function.

        Assigning any ZeroPotentialFunction resets this factor to its own
        zero potential function.

        Raises:
            ValueError: if the function belongs to a different PGM, or
                does not have the same shape as this factor.
        """
        new_function = function.function if isinstance(function, Factor) else function
        assert isinstance(new_function, PotentialFunction)

        if new_function is self._potential_function:
            # nothing to do
            return

        if new_function.pgm is not self._pgm:
            raise ValueError('the given function is not of the same PGM as the factor')

        if new_function.shape != self._shape:
            raise ValueError(f'incorrect function shape: expected {self._shape}, got {new_function.shape}')

        if isinstance(new_function, ZeroPotentialFunction):
            self.set_zero()
            return
        self._potential_function = new_function

    def set_zero(self) -> ZeroPotentialFunction:
        """
        Reset the factor's potential function to its original ZeroPotentialFunction.

        Returns:
            the potential function.
        """
        zero_function = self._zero_potential_function
        self._potential_function = zero_function
        return zero_function

    def set_dense(self) -> DensePotentialFunction:
        """
        Set the potential function to a new `DensePotentialFunction` object.

        Returns:
            the potential function.
        """
        new_function = DensePotentialFunction(self)
        self._potential_function = new_function
        return new_function

    def set_sparse(self) -> SparsePotentialFunction:
        """
        Set the potential function to a new `SparsePotentialFunction` object.

        Returns:
            the potential function.
        """
        new_function = SparsePotentialFunction(self)
        self._potential_function = new_function
        return new_function

    def set_compact(self) -> CompactPotentialFunction:
        """
        Set the potential function to a new `CompactPotentialFunction` object.

        Returns:
            the potential function.
        """
        new_function = CompactPotentialFunction(self)
        self._potential_function = new_function
        return new_function

    def set_clause(self, *key: int) -> ClausePotentialFunction:
        """
        Set the potential function to a new `ClausePotentialFunction` object.

        Args:
            key: defines the random variable states of the clause. The key is a sequence of
                random variable state indexes, co-indexed with `Factor.rvs`.

        Returns:
            the potential function.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.
        """
        new_function = ClausePotentialFunction(self, key)
        self._potential_function = new_function
        return new_function

    def set_cpt(self, tolerance: float = DEFAULT_CPT_TOLERANCE) -> CPTPotentialFunction:
        """
        Set the potential function to a new `CPTPotentialFunction` object.

        Args:
            tolerance: a tolerance when testing if values are equal to zero or one.

        Returns:
            the potential function.

        Raises:
            ValueError: if tolerance is negative.
        """
        new_function = CPTPotentialFunction(self, tolerance)
        self._potential_function = new_function
        return new_function
1565
+
1566
+
1567
+ @dataclass(frozen=True, eq=True, slots=True)
1568
+ class ParamId:
1569
+ """
1570
+ A ParamId identifies a parameter of a potential function.
1571
+
1572
+ Parameter identifiers uniquely identify every parameter within a PGM.
1573
+
1574
+ A ParamId is immutable and hashable.
1575
+ """
1576
+ function_id: int
1577
+ param_idx: int
1578
+
1579
+
1580
class PotentialFunction(ABC):
    """
    A potential function defines the potential values for a Factor, where
    a factor joins one or more random variables of a PGM.

    A potential function may be shared by several Factors of a PGM,
    i.e., it can be applied to multiple groups of random variables.

    The `shape` of a potential function is a tuple of integers which defines
    the number of random variables, `len(shape)`, and the number of states of
    each random variable, `shape[i]`.

    The potential function value for variable states (x = i, y = j, ...) is
    given by `self[i, j, ...]`, i.e., `self.__getitem__((i, j, ...))`. The
    tuple `(i, j, ...)` is known as a Key.

    The values of a potential function are defined by potential function
    parameters. The number of parameters is given by `number_of_parameters`.
    The value of each parameter is given by `param_value(i)`, where `i` is the
    parameter index.

    Every valid key of the potential function is either mapped to a parameter
    or is "guaranteed zero", which means that the value is zero and cannot be
    changed by changing the values of the potential function's parameters.
    """

    def __init__(self, factor: Factor):
        """
        Create a potential function compatible with the given factor.

        Ensures:
            Does not hold a reference to the given factor (only its PGM and
            shape are retained).
            Does not register the potential function with the PGM.

        Args:
            factor: the factor that this potential function is compatible with.
        """
        self._pgm: PGM = factor.pgm
        self._shape: Shape = factor.shape
        # Cached once: the shape is fixed for the life of the function.
        self._number_of_states = _multiply(self._shape)

    @property
    def pgm(self) -> PGM:
        """
        Returns:
            The PGM that this potential function belongs to.
        """
        return self._pgm

    @property
    def shape(self) -> Shape:
        """
        Returns:
            The shape of this potential function.
        """
        return self._shape

    @property
    def number_of_rvs(self) -> int:
        """
        Returns:
            The number of random variables in this potential function,
            i.e., `len(self.shape)`.
        """
        return len(self._shape)

    @property
    def number_of_states(self) -> int:
        """
        How many distinct states are covered by this potential function.

        Returns:
            The size of the state space of this potential function.
        """
        return self._number_of_states

    @property
    def number_of_parent_states(self) -> int:
        """
        How many distinct states are covered by this potential function's
        parents, i.e., excluding the first random variable.

        Returns:
            The size of the state space of the parent random variables,
            computed from `self.shape[1:]`.
        """
        return _multiply(self._shape[1:])

    def count_usage(self) -> int:
        """
        Check all PGM factors to count the number of times that this
        potential function is used.

        Returns:
            the number of factors whose `function` is this very object.
        """
        return sum(1 for factor in self._pgm.factors if factor.function is self)

    def check_key(self, key: Key) -> Instance:
        """
        Convert the key into an instance.

        Args:
            key: defines an instance in the state space of the potential function.

        Returns:
            an instance, which is a tuple of state indexes, co-indexed with `self.shape`.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.
        """
        return check_key(self._shape, key)

    def valid_key(self, key: Key) -> bool:
        """
        Is the given key valid.

        Args:
            key: defines an instance in the state space of the potential function.

        Returns:
            True only if the given key is valid.
        """
        return valid_key(self._shape, key)

    def valid_parameter(self, param_idx: int) -> bool:
        """
        Is the given parameter index valid.

        Args:
            param_idx: a parameter index.

        Returns:
            True only if `0 <= param_idx < self.number_of_parameters`.
        """
        return 0 <= param_idx < self.number_of_parameters

    @property
    def is_sparse(self) -> bool:
        """
        Are there any 'guaranteed zero' keys.

        Returns:
            True only if `self.number_of_not_guaranteed_zero < self.number_of_states`.
        """
        return self.number_of_not_guaranteed_zero < self._number_of_states

    @property
    @abstractmethod
    def number_of_not_guaranteed_zero(self) -> int:
        """
        How many of the states of this potential function are not 'guaranteed zero'.
        That is, how many keys are associated with a parameter.

        Returns:
            The number of valid keys that are associated with a parameter.

        Ensures:
            0 <= self.number_of_not_guaranteed_zero <= self.number_of_states.
        """
        ...

    @property
    @abstractmethod
    def number_of_parameters(self) -> int:
        """
        Get the number of parameters defining the potential function values.
        Each valid key of the function either maps to a parameter or
        is 'guaranteed zero'.

        Returns:
            The number of parameters.

        Ensures:
            0 <= self.number_of_parameters <= self.number_of_not_guaranteed_zero.
        """
        ...

    @property
    @abstractmethod
    def params(self) -> Iterable[Tuple[int, float]]:
        """
        Iterate the parameters and their associated values.

        Returns:
            An iterable over (param_idx, value) tuples, for every possible parameter.

        Assumes:
            The potential function is not mutated while iterating.
        """
        ...

    @property
    @abstractmethod
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """
        Iterate the keys that have a parameter associated with them.

        Returns:
            An iterable over (key, param_idx, value) tuples, for every key
            with an associated parameter.

        Assumes:
            The potential function is not mutated while iterating.
        """
        ...

    @abstractmethod
    def __getitem__(self, key: Key) -> float:
        """
        Get the potential function value for the given instance key.

        Args:
            key: defines an instance in the state space of the potential function.

        Returns:
            The value of the potential function for the given key.

        Assumes:
            self.valid_key(key).
        """
        ...

    @abstractmethod
    def param_value(self, param_idx: int) -> float:
        """
        Get the potential function value by parameter index.

        Args:
            param_idx: a parameter index.

        Returns:
            The value of the parameter with the given index.

        Assumes:
            `self.valid_parameter(param_idx)`.
        """
        ...

    @abstractmethod
    def param_idx(self, key: Key) -> Optional[int]:
        """
        Get the parameter index for the given potential function random variables states (key).

        Args:
            key: defines an instance in the state space of the potential function.

        Returns:
            either `None` indicating a "guaranteed zero" value, or the parameter index holding
            the potential function value for the key.
        """
        ...

    def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
        """
        Is the potential function set with parameter values conforming to a
        Conditional Probability Table.

        Every value must be non-negative.
        For every state of the parents (non-first slots)
        the sum of the values over the child states (first slot)
        must be either 1 or 0, within the given tolerance.

        Assumes:
            tolerance is non-negative.

        Args:
            tolerance: a tolerance when testing if sums are equal to zero or one.

        Returns:
            True only if the potential function is compatible with being a CPT.
        """
        # This default implementation calculates the result the long way, by checking
        # every valid key of the potential function.
        # Subclasses may override this implementation.
        low: float = 1.0 - tolerance
        high: float = 1.0 + tolerance
        child_states = range(self.shape[0])
        for parent_state in self.parent_instances():
            parent_key = tuple(parent_state)
            total: float = 0.0
            for state in child_states:
                value: float = self[(state,) + parent_key]
                # A negative entry can never be a probability, even when the
                # column total happens to be 0 or 1 (e.g. 1.5 and -0.5).
                if value < 0:
                    return False
                total += value
            # All entries are non-negative here, so total >= 0 (or is NaN,
            # in which case both comparisons fail and the result is False).
            if not ((low <= total <= high) or (total <= tolerance)):
                return False
        return True

    def natural_param_idx(self, key: Key) -> int:
        """
        Get the natural parameter index for the given key. This is the same index as used
        by a DensePotentialFunction with the same shape.

        Args:
            key: is a valid key of the potential function, referring to an instance
                in the factor's state space.

        Assumes:
            `self.valid_key(key)` is true.

        Returns:
            a hypothetical parameter index assuming that every valid key has a unique parameter
            as per DensePotentialFunction.
        """
        return _natural_key_idx(self._shape, key)

    def param_id(self, param_idx: int) -> ParamId:
        """
        Get a hashable object to represent the parameter with the given parameter index.

        Args:
            param_idx: a parameter index.

        Returns:
            a hashable ParamId object for the parameter of this potential function,
            built from `id(self)` and the parameter index.

        Raises:
            ValueError: if the parameter index is not valid.
        """
        if not self.valid_parameter(param_idx):
            raise ValueError(f'invalid parameter index: {param_idx}')
        return ParamId(id(self), param_idx)

    def items(self) -> Iterable[Tuple[Instance, float]]:
        """
        Iterate over all keys and values of this potential function.

        Returns:
            An iterator over all (key, value) pairs, where key is an Instance and value
            is the value of the potential function for the key.
        """
        for key in _combos_ranges(self._shape, flip=True):
            yield key, self[key]

    def instances(self, flip: bool = False) -> Iterable[Instance]:
        """
        Iterate over all possible instances, in natural index order (i.e.,
        last random variable changing most quickly).

        Args:
            flip: if true, then first random variable changes most quickly

        Returns:
            an iterator over tuples, each tuple holds random variable
            state indexes, co-indexed with this object's shape, i.e., self.shape.
        """
        # The helper's `flip` has the opposite sense to this method's `flip`.
        return _combos_ranges(self._shape, flip=not flip)

    def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
        """
        Iterate over all possible instances of parent random variables, in
        natural index order (i.e., last random variable changing most quickly).

        Args:
            flip: if true, then first random variable changes most quickly

        Returns:
            an iteration over tuples, each tuple holds random variable states
            co-indexed with this object's 'parent' shape, i.e., `self.shape[1:]`.
        """
        # The helper's `flip` has the opposite sense to this method's `flip`.
        return _combos_ranges(self._shape[1:], flip=not flip)

    def __str__(self) -> str:
        """
        Provide a human-readable representation of this potential function.
        This is intended mainly for debugging purposes.

        Returns:
            the class name followed by the comma-separated shape in parentheses.
        """
        shape_str: str = ', '.join(str(x) for x in self._shape)
        return f'{self.__class__.__name__}({shape_str})'

    def dump(
            self,
            *,
            prefix: str = '',
            indent: str = ' ',  # NOTE(review): width taken as shown in SOURCE - confirm
            show_function_values: bool = False,
            show_id_class: bool = True,
    ) -> None:
        """
        Print a dump of the function.
        This is intended for debugging purposes.

        Args:
            prefix: optional prefix for indenting all lines.
            indent: additional prefix to use for extra indentation.
            show_function_values: if true, then the function values will be dumped.
            show_id_class: if true, then the function id and class will be dumped.
        """
        shape_str: str = ', '.join(str(x) for x in self._shape)
        # Read the (possibly computed) property once; it is reported twice below.
        not_guaranteed_zero: int = self.number_of_not_guaranteed_zero

        if show_id_class:
            print(f'{prefix}id: {id(self)}')
            print(f'{prefix}class: {self.__class__.__name__}')
        print(f'{prefix}usage: {self.count_usage()}')
        print(f'{prefix}rvs: {self.number_of_rvs}')
        print(f'{prefix}shape: ({shape_str})')
        print(f'{prefix}states: {self._number_of_states}')
        print(f'{prefix}guaranteed zero: {self._number_of_states - not_guaranteed_zero}')
        print(f'{prefix}not guaranteed zero: {not_guaranteed_zero}')
        print(f'{prefix}parameters: {self.number_of_parameters}')
        if show_function_values:
            next_prefix = prefix + indent
            for key, param_idx, value in self.keys_with_param:
                print(f'{next_prefix}{param_idx} {key} = {value}')
1975
+
1976
+
1977
class ZeroPotentialFunction(PotentialFunction):
    """
    A potential function whose value is zero for every key.

    Like a DensePotentialFunction, there is one parameter for each
    possible key, but this class offers no way to change a parameter
    value: every parameter is always zero.
    Even though the values cannot be changed, no key is reported as
    'guaranteed zero'.

    The primary use of a ZeroPotentialFunction is as a placeholder
    within a factor, prior to parameter learning.
    """
    __slots__ = ()

    def __init__(self, factor: Factor):
        """
        Create a potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor that this potential function is compatible with.
        """
        super().__init__(factor)

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        # Every key has a parameter, so no key counts as guaranteed zero.
        return self._number_of_states

    @property
    def number_of_parameters(self) -> int:
        # One parameter per key.
        return self._number_of_states

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        # Every parameter index, each paired with the constant value zero.
        for idx in range(self.number_of_parameters):
            yield idx, 0

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        # Parameter indexes follow the order in which instances() yields keys.
        for idx, key in enumerate(self.instances()):
            yield key, idx, 0

    def __getitem__(self, key: Key) -> float:
        # Called for validation only: check_key raises for an invalid key.
        self.check_key(key)
        return 0

    def param_value(self, param_idx: int) -> float:
        if self.valid_parameter(param_idx):
            return 0
        raise ValueError(f'invalid parameter index: {param_idx}')

    def param_idx(self, key: Key) -> int:
        # Same indexing as a dense function of the same shape.
        return self.natural_param_idx(key)

    def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
        # All values are zero, so every column sums to zero.
        return True
2036
+
2037
+
2038
class DensePotentialFunction(PotentialFunction):
    """
    A dense (tabular) potential function.

    There is exactly one parameter for each valid key of the potential
    function, stored in a flat float64 array indexed by the natural
    parameter index of the key.
    The initial value for each parameter is zero.
    The value for any key can be changed independently of all other keys.
    """

    def __init__(self, factor: Factor):
        """
        Create a potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor that this potential function is compatible with.
        """
        super().__init__(factor)
        self._values: NDArrayFloat64 = np.zeros(self.number_of_states, dtype=np.float64)

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        # Every key has its own parameter.
        return self.number_of_states

    @property
    def number_of_parameters(self) -> int:
        return self.number_of_states

    def __getitem__(self, key: Key) -> float:
        idx = self.param_idx(key)
        return self._values.item(idx)

    def param_value(self, param_idx: int) -> float:
        return self._values.item(param_idx)

    def param_idx(self, key: Key) -> Optional[int]:
        return self.natural_param_idx(key)

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        # Type warning due to numpy type erasure
        # noinspection PyTypeChecker
        return enumerate(self._values)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        # Parameter indexes follow the order in which instances() yields keys.
        for idx, instance in enumerate(self.instances()):
            yield instance, idx, self.param_value(idx)

    # Mutators

    def __setitem__(self, key: Key, value: float) -> None:
        """
        Set the potential function value, for a given key.

        Args:
            key: defines an instance in the state space of the potential function.
            value: the new value of the potential function for the given key.

        Assumes:
            self.valid_key(key).
        """
        idx = self.param_idx(key)
        self._values[idx] = value

    def set_param_value(self, param_idx: int, value: float) -> None:
        """
        Set the parameter value.

        Args:
            param_idx: is the index of the parameter.
            value: the new value of the parameter.

        Assumes:
            self.valid_parameter(param_idx).
        """
        self._values[param_idx] = value

    def clear(self) -> DensePotentialFunction:
        """
        Set all values of the potential function to zero.

        Returns:
            self
        """
        return self.set_all(0)

    def normalise_cpt(self) -> DensePotentialFunction:
        """
        Normalise the parameter values as if this was a CPT.
        That is, treat the first random variable as the child and the others as parents;
        for each combination of parent states, scale the parameters over the child
        states so that they sum to 1 (a sum of exactly 0 or 1 is left untouched).

        Assumes:
            There are no negative parameter values.

        Returns:
            self
        """
        child_states = range(self._shape[0])
        for parent_states in _combos_ranges(self._shape[1:]):
            column = [(c,) + parent_states for c in child_states]
            total = sum(self[key] for key in column)
            if total == 0 or total == 1:
                # Nothing to scale: all-zero, or already normalised.
                continue
            for key in column:
                self[key] = self[key] / total
        return self

    def normalise(self, grouping_positions: Sequence[int] = ()) -> DensePotentialFunction:
        """
        Convert the potential function to a CPT with 'grouping_positions' nominating
        the parent random variables.

        I.e., for each possible key of the function with the same value at each
        grouping position, the sum of values for matching keys in the factor is scaled
        to be 1 (or 0).

        Parameter 'grouping_positions' are indices into `self.shape`. For example, the
        grouping positions of a factor with parent rvs 'conditioning_rvs' are given by
        grouping_positions = [i for i, rv in enumerate(factor.rvs) if rv in conditioning_rvs].

        Args:
            grouping_positions: indices into `self.shape`.

        Returns:
            self
        """
        _normalise_potential_function(self, grouping_positions)
        return self

    def set_iter(self, values: Iterable[float]) -> DensePotentialFunction:
        """
        Set the values of the potential function using the given iterator.

        Mapping instances to values is as follows:
            Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
            values[0] represents instance (0,0)
            values[1] represents instance (0,1)
            values[2] represents instance (0,2)
            values[3] represents instance (1,0)
            values[4] represents instance (1,1)
            values[5] represents instance (1,2).

        For example: to set to counts, starting from 1, use `self.set_iter(itertools.count(1))`.

        Exactly `self.number_of_parameters` values are read from the iterable,
        so an endless iterable is acceptable.

        Args:
            values: an iterable providing values to use.

        Returns:
            self
        """
        # A fresh array replaces the old one; `count` bounds how much is read.
        self._values = np.fromiter(
            values,
            count=self.number_of_parameters,
            dtype=np.float64,
        )
        return self

    def set_stream(self, stream: Callable[[], float]) -> DensePotentialFunction:
        """
        Set the values of the potential function by repeatedly calling the stream function.
        The order of values is the same as set_iter.

        For example, to set to random numbers, use `self.set_stream(random.random)`.

        Args:
            stream: a callable taking no arguments, returning the values to use.

        Returns:
            self
        """
        # iter(callable, sentinel) keeps calling `stream` until it returns None.
        endless = iter(stream, None)
        return self.set_iter(endless)

    def set_flat(self, *value: float) -> DensePotentialFunction:
        """
        Set the values of the potential function to the given values.
        The order of values is the same as set_iter.

        Args:
            *value: the values to use.

        Returns:
            self

        Raises:
            ValueError: if `len(value) != self.number_of_states`.
        """
        if len(value) == self.number_of_states:
            return self.set_iter(value)
        raise ValueError(f'wrong number of values: expected {self.number_of_states}, got {len(value)}')

    def set_all(self, value: float) -> DensePotentialFunction:
        """
        Set all values of the potential function to the given value.

        Args:
            value: the value to use.

        Returns:
            self
        """
        return self.set_iter(_repeat(value))

    def set_uniform(self) -> DensePotentialFunction:
        """
        Set all values of the potential function to 1/number_of_states.

        Returns:
            self
        """
        uniform_value = 1.0 / self.number_of_states
        return self.set_all(uniform_value)
2252
+
2253
+
2254
class SparsePotentialFunction(PotentialFunction):
    """
    A sparse potential function.

    There is one parameter for each key that has been given a non-zero value.
    The user may set the value for any key and parameters will
    be automatically reconfigured as needed. Setting the value for
    a key to zero disassociates the key from its parameter and
    thus makes that key "guaranteed zero".

    Internally, `_values` is the list of parameter values and `_params`
    maps each instance that has a parameter to its index in `_values`.
    """

    def __init__(self, factor: Factor):
        """
        Create a potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: the factor that this potential function is compatible with.
        """
        super().__init__(factor)
        self._values: List[float] = []
        self._params: Dict[Instance, int] = {}

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        return len(self._params)

    @property
    def number_of_parameters(self) -> int:
        return len(self._params)

    def __getitem__(self, key: Key) -> float:
        idx: Optional[int] = self.param_idx(key)
        # A key without a parameter is guaranteed zero.
        return 0 if idx is None else self._values[idx]

    def param_value(self, param_idx: int) -> float:
        return self._values[param_idx]

    def param_idx(self, key: Key) -> Optional[int]:
        instance = _key_to_instance(key)
        return self._params.get(instance)

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        return enumerate(self._values)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        for instance, idx in self._params.items():
            yield instance, idx, self._values[idx]

    # Mutators

    def __setitem__(self, key: Key, value: float) -> None:
        """
        Set the potential function value, for a given key.

        If value is zero, then the key will become "guaranteed zero".

        Args:
            key: defines an instance in the state space of the potential function.
            value: the new value of the potential function for the given key.

        Assumes:
            self.valid_key(key).
        """
        instance: Instance = _key_to_instance(key)
        slot: Optional[int] = self._params.get(instance)

        if slot is None:
            # The key is currently guaranteed zero.
            if value == 0:
                # Still zero, so there is nothing to record.
                return
            # Allocate a new parameter at the end of the value list.
            self._params[instance] = len(self._values)
            self._values.append(value)
            return

        if value != 0:
            # The key already has a parameter: overwrite it in place.
            self._values[slot] = value
            return

        # The key has a parameter but is being set to zero, so that parameter
        # must be deleted. To keep parameter indexes contiguous, the last
        # parameter is moved into the vacated slot, then the list is shortened.
        last: int = len(self._values) - 1
        if slot != last:
            self._values[slot] = self._values[last]
            # Re-point whichever instance owned the last parameter.
            for other, other_idx in self._params.items():
                if other_idx == last:
                    self._params[other] = slot
                    # Only one instance can own a given parameter.
                    break

        del self._values[-1]
        del self._params[instance]

    def set_param_value(self, param_idx: int, value: float) -> None:
        """
        Set the parameter value.

        Args:
            param_idx: is the index of the parameter.
            value: the new value of the parameter.

        Assumes:
            self.valid_parameter(param_idx).
        """
        self._values[param_idx] = value

    def clear(self) -> SparsePotentialFunction:
        """
        Set all values of the potential function to zero,
        by discarding every parameter.

        Returns:
            self
        """
        self._values = []
        self._params = {}
        return self

    def normalise_cpt(self) -> SparsePotentialFunction:
        """
        Normalise the parameter values as if this was a CPT.
        That is, treat the first random variable as the child and the others as parents;
        for each combination of parent states, ensure the parameters over
        the child states sum to 1 (or 0).

        Returns:
            self
        """
        # The parents are every position except the first.
        parent_positions = list(range(1, self.number_of_rvs))
        _normalise_potential_function(self, parent_positions)
        return self

    def normalise(self, grouping_positions=()) -> SparsePotentialFunction:
        """
        Convert the potential function to a CPT with 'grouping_positions' nominating
        the parent random variables.

        I.e., for each possible key of the function with the same value at each
        grouping position, the sum of values for matching keys in the factor is scaled
        to be 1 (or 0).

        Parameter 'grouping_positions' are indices into `self.shape`. For example, the
        grouping positions of a factor with parent rvs 'conditioning_rvs' are given by
        grouping_positions = [i for i, rv in enumerate(factor.rvs) if rv in conditioning_rvs].

        Args:
            grouping_positions: indices into `self.shape`.

        Returns:
            self
        """
        _normalise_potential_function(self, grouping_positions)
        return self

    def set_iter(self, values: Iterable[float]) -> SparsePotentialFunction:
        """
        Set the values of the potential function using the given iterator.

        Mapping instances to values is as follows:
            Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
            values[0] represents instance (0,0)
            values[1] represents instance (0,1)
            values[2] represents instance (0,2)
            values[3] represents instance (1,0)
            values[4] represents instance (1,1)
            values[5] represents instance (1,2).

        For example: to set to counts, starting from 1, use `self.set_iter(itertools.count(1))`.

        All existing values are cleared first. Values are paired with instances
        until either runs out, so an endless iterable is acceptable, and any
        instance left without a value stays "guaranteed zero".

        Args:
            values: an iterable providing values to use.

        Returns:
            self
        """
        self.clear()
        stored = self._values
        index_of = self._params
        for instance, value in zip(self.instances(), values):
            if value != 0:
                # Only non-zero values are given a parameter.
                index_of[instance] = len(stored)
                stored.append(value)
        return self

    def set_stream(self, stream: Callable[[], float]) -> SparsePotentialFunction:
        """
        Set the values of the potential function by repeatedly calling the stream function.
        The order of values is the same as set_iter.

        For example, to set to random numbers, use `self.set_stream(random.random)`.

        Args:
            stream: a callable taking no arguments, returning the values to use.

        Returns:
            self
        """
        # iter(callable, sentinel) keeps calling `stream` until it returns None.
        endless = iter(stream, None)
        return self.set_iter(endless)

    def set_flat(self, *value: float) -> SparsePotentialFunction:
        """
        Set the values of the potential function to the given values.
        The order of values is the same as set_iter.

        Args:
            *value: the values to use.

        Returns:
            self

        Raises:
            ValueError: if `len(value) != self.number_of_states`.
        """
        if len(value) == self.number_of_states:
            return self.set_iter(value)
        raise ValueError(f'wrong number of values: expected {self.number_of_states}, got {len(value)}')

    def set_all(self, value: float) -> SparsePotentialFunction:
        """
        Set all values of the potential function to the given value.

        Args:
            value: the value to use.

        Returns:
            self
        """
        # Zero everywhere simply means having no parameters at all.
        return self.clear() if value == 0 else self.set_iter(_repeat(value))

    def set_uniform(self) -> SparsePotentialFunction:
        """
        Set all values of the potential function to 1/number_of_states.

        Returns:
            self
        """
        uniform_value = 1.0 / self.number_of_states
        return self.set_all(uniform_value)
2505
+
2506
+
2507
+ class CompactPotentialFunction(PotentialFunction):
2508
+ """
2509
+ A compact potential function is sparse, where values for keys of
2510
+ the same value are represented by a single parameter.
2511
+
2512
+ There is one parameter for each unique, non-zero key value.
2513
+ The user may set the value for any key and parameters will
2514
+ be automatically reconfigured as needed. Setting the value for
2515
+ a key to zero disassociates the key from its parameter and
2516
+ thus makes that key "guaranteed zero".
2517
+ """
2518
+
2519
+ def __init__(self, factor: Factor):
2520
+ """
2521
+ Create a potential function for the given factor.
2522
+
2523
+ Ensures:
2524
+ Does not hold a reference to the given factor.
2525
+ Does not register the potential function with the PGM.
2526
+
2527
+ Args:
2528
+ factor: which factor is this potential function is compatible with.
2529
+ """
2530
+ super().__init__(factor)
2531
+ self._values: List[float] = []
2532
+ self._counts: List[int] = []
2533
+ self._map: Dict[Instance, int] = {}
2534
+ self._inv_map: Dict[float, int] = {}
2535
+
2536
+ @property
2537
+ def number_of_not_guaranteed_zero(self) -> int:
2538
+ return len(self._map)
2539
+
2540
+ @property
2541
+ def number_of_parameters(self) -> int:
2542
+ return len(self._values)
2543
+
2544
+ @property
2545
+ def params(self) -> Iterable[Tuple[int, float]]:
2546
+ return enumerate(self._values)
2547
+
2548
+ @property
2549
+ def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
2550
+ for key, param_idx in self._map.items():
2551
+ value: float = self._values[param_idx]
2552
+ yield key, param_idx, value
2553
+
2554
+ def __getitem__(self, key: Key) -> float:
2555
+ param_idx: Optional[int] = self.param_idx(key)
2556
+ if param_idx is None:
2557
+ return 0
2558
+ else:
2559
+ return self._values[param_idx]
2560
+
2561
+ def param_value(self, param_idx: int) -> float:
2562
+ return self._values[param_idx]
2563
+
2564
+ def param_idx(self, key: Key) -> Optional[int]:
2565
+ return self._map.get(_key_to_instance(key))
2566
+
2567
+ # Mutators
2568
+
2569
def __setitem__(self, key: Key, value: float) -> None:
    """
    Set the potential function value, for a given key.

    If value is zero, then the key will become "guaranteed zero".
    If the value is the same as an existing parameter value, then
    that parameter will be reused.

    The internal tables are kept consistent:
        `_values[i]` is the value of parameter `i`,
        `_counts[i]` is the number of keys mapped to parameter `i`,
        `_inv_map[v]` is the index of the (unique) parameter with value `v`,
        `_map[instance]` is the index of the parameter for `instance`.

    Args:
        key: defines an instance in the state space of the potential function.
        value: the new value of the potential function for the given key.

    Assumes:
        self.valid_key(key).
    """
    instance: Instance = _key_to_instance(key)
    prev_idx: Optional[int] = self._map.get(instance)

    if prev_idx is not None:
        # the key previously had a non-zero value
        prev_value: float = self._values[prev_idx]
        if value == prev_value:
            # nothing to do
            return

        if self._counts[prev_idx] > 1:
            # other keys still use the previous parameter: just detach this key
            self._counts[prev_idx] -= 1
        elif value != 0 and value not in self._inv_map:
            # This key is the only user of the parameter and the new value is
            # not held by any other parameter: update in place (the parameter
            # index is unchanged) and keep the reverse lookup in step.
            self._values[prev_idx] = value
            del self._inv_map[prev_value]
            self._inv_map[value] = prev_idx
            return
        else:
            # This key is the only user of the parameter and the parameter is
            # no longer needed (new value is zero, or is held by another parameter).
            del self._inv_map[prev_value]
            self._remove_param(prev_idx)

        if value == 0:
            # the key becomes guaranteed zero
            del self._map[instance]
            return
    elif value == 0:
        # previous value for the key was zero: nothing to do
        return

    # Map the key to the parameter holding `value`. The lookup happens here,
    # after any parameter removal, as removal may renumber a parameter.
    param_idx: Optional[int] = self._inv_map.get(value)
    if param_idx is None:
        # need to allocate a new parameter
        param_idx = len(self._values)
        self._values.append(value)
        self._counts.append(1)
        self._inv_map[value] = param_idx
    else:
        # the value already exists in the function, so reuse it
        self._counts[param_idx] += 1
    self._map[instance] = param_idx
2642
+
2643
def set_iter(self, values: Iterable[float]) -> CompactPotentialFunction:
    """
    Clear the potential function, then assign values drawn from `values`.

    Values are paired with instances in the order of `self.instances()`.
    For example, given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
        values[0] represents instance (0,0)
        values[1] represents instance (0,1)
        values[2] represents instance (0,2)
        values[3] represents instance (1,0)
        values[4] represents instance (1,1)
        values[5] represents instance (1,2).

    Pairing stops when either the instances or the values run out.

    For example: to set to counts, starting from 1, use `self.set_iter(itertools.count(1))`.

    Args:
        values: an iterable providing values to use.

    Returns:
        self
    """
    self.clear()
    assignments = zip(self.instances(), values)
    for target, new_value in assignments:
        self[target] = new_value
    return self
2668
+
2669
def set_stream(self, stream: Callable[[], float]) -> CompactPotentialFunction:
    """
    Set the values of the potential function by repeatedly calling `stream`.
    The order of values is the same as set_iter.

    The stream is wrapped as `iter(stream, None)`, so a returned `None`
    ends the sequence of values.

    For example, to set to random numbers, use `self.set_stream(random.random)`.

    Args:
        stream: a callable taking no arguments, returning the values to use.

    Returns:
        self
    """
    value_source = iter(stream, None)
    return self.set_iter(value_source)
2683
+
2684
def set_flat(self, *value: float) -> CompactPotentialFunction:
    """
    Set the values of the potential function to exactly the given values.
    The order of values is the same as set_iter.

    Args:
        value: the values to use, one per state of the potential function.

    Returns:
        self

    Raises:
        ValueError: if `len(value) != self.number_of_states`.
    """
    if len(value) == self.number_of_states:
        return self.set_iter(value)
    raise ValueError(f'wrong number of values: expected {self.number_of_states}, got {len(value)}')
2701
+
2702
def set_all(self, value: float) -> CompactPotentialFunction:
    """
    Set every value of the potential function to the one given value.

    A zero value leaves the function cleared. Any other value is held
    as a single parameter shared by every instance.

    Args:
        value: the value to use.

    Returns:
        self
    """
    self.clear()
    if value == 0:
        return self

    # one parameter (index 0), referenced by every state
    self._values = [value]
    self._counts = [self.number_of_states]
    self._inv_map = {value: 0}
    self._map = dict.fromkeys(self.instances(), 0)
    return self
2719
+
2720
def set_uniform(self) -> CompactPotentialFunction:
    """
    Set all values of the potential function to 1/number_of_states.

    Returns:
        self
    """
    uniform_value = 1.0 / self.number_of_states
    return self.set_all(uniform_value)
2728
+
2729
def clear(self) -> CompactPotentialFunction:
    """
    Set all values of the potential function to zero by discarding
    every parameter and every key mapping.

    Returns:
        self
    """
    # fresh containers are bound (the previous ones are not mutated)
    self._values, self._counts = [], []
    self._map, self._inv_map = {}, {}
    return self
2741
+
2742
def _remove_param(self, param_idx: int) -> None:
    """
    Remove the indexed parameter from self._values and self._counts.

    If the parameter is not the last one, the last parameter is moved into
    the vacated slot: its value and count are copied, and its reverse lookup
    entry and every key mapped to it are re-pointed to `param_idx`.

    The reverse lookup entry and key mappings of the removed parameter
    itself are not touched here.
    """
    last: int = len(self._values) - 1
    if param_idx != last:
        # move the last parameter into the slot being vacated
        moved_value: float = self._values[last]
        self._values[param_idx] = moved_value
        self._counts[param_idx] = self._counts[last]
        self._inv_map[moved_value] = param_idx
        stale = [instance for instance, idx in self._map.items() if idx == last]
        for instance in stale:
            self._map[instance] = param_idx

    # drop the (now redundant) last parameter
    self._values.pop()
    self._counts.pop()
2764
+
2765
+
2766
class ClausePotentialFunction(PotentialFunction):
    """
    A clause potential function represents a clause from a CNF formula.
    I.e. a clause over variables X, Y, Z, is a disjunction of the form: 'X=x or Y=y or Z=z'.

    A clause potential function is guaranteed zero for a key where the clause is false,
    i.e., when 'X != x and Y != y and Z != z'.

    For keys where the clause is true, the value of the potential function
    is given by the only parameter of the potential function. That parameter
    is called the clause 'weight' and is notionally 1.

    The weight of a clause is permitted to be zero, but that is _not_ equivalent to
    guaranteed-zero.
    """

    def __init__(self, factor: Factor, key: Key, weight: float = 1):
        """
        Create a clause potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.

        Args:
            factor: which factor this potential function is compatible with.
            key: defines the random variable states of the clause.
            weight: the value of the potential function where the clause is true.
        """
        super().__init__(factor)
        self._weight: float = weight
        self._clause: Instance = self.check_key(key)
        # `_zero_space` counts the keys where every random variable differs
        # from its clause state (the clause is false, i.e. guaranteed zero).
        # The keys that are not guaranteed zero are all the remaining keys.
        # This count depends only on the shape, not on the particular clause.
        self._num_not_guaranteed_zero: int = _multiply(self.shape) - _zero_space(self.shape)

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        """
        Returns:
            the number of keys for which the clause is true.
        """
        return self._num_not_guaranteed_zero

    @property
    def number_of_parameters(self) -> int:
        """
        Returns:
            1, as the clause weight is the only parameter.
        """
        return 1

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        """
        Returns:
            a single (param_idx, value) pair, for the clause weight.
        """
        return ((0, self._weight),)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        """
        Iterate over each key for which the clause is true, exactly once,
        in natural index order (last random variable changing most quickly).

        Returns:
            an iteration over (instance, 0, weight) triples.
        """
        from itertools import product

        value = self._weight
        for instance in product(*(range(size) for size in self.shape)):
            if self._is_satisfied_by(instance):
                yield instance, 0, value

    def __getitem__(self, key: Key) -> float:
        """
        Returns:
            the clause weight if the clause is true for the key, otherwise 0.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.
        """
        instance: Instance = self.check_key(key)
        if self._is_satisfied_by(instance):
            return self._weight
        return 0

    def param_value(self, param_idx: int) -> float:
        """
        Returns:
            the clause weight.

        Raises:
            IndexError: if `param_idx` is not 0.
        """
        if param_idx != 0:
            raise IndexError(param_idx)
        return self._weight

    def param_idx(self, key: Key) -> Optional[int]:
        """
        Returns:
            0 if the clause is true for the key (agreeing with `__getitem__`),
            otherwise None, indicating the key is guaranteed zero.
        """
        instance: Instance = _key_to_instance(key)
        if self._is_satisfied_by(instance):
            return 0
        else:
            return None

    def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
        """
        A ClausePotentialFunction can only be a CPT when all entries are zero.
        """
        return -tolerance <= self._weight <= tolerance

    def is_sparse(self) -> bool:
        return True

    @property
    def weight(self) -> float:
        """
        Returns:
            the "weight" parameter defining the potential function.
        """
        return self._weight

    @property
    def clause(self) -> Instance:
        """
        Returns:
            the clause defining the potential function.
        """
        return self._clause

    # Mutators

    @weight.setter
    def weight(self, value: float) -> None:
        """
        Set the weight parameter to the given value.
        """
        self._weight = value

    @clause.setter
    def clause(self, key: Key) -> None:
        """
        Set the clause to the given key.

        Raises:
            KeyError: if the key is not valid for the shape of the factor.
        """
        self._clause = self.check_key(key)

    def _is_satisfied_by(self, instance: Instance) -> bool:
        """
        Is the clause true for the given instance, i.e., does at least one
        state of the instance equal the corresponding state of the clause.
        """
        return any(
            state_idx == clause_state_idx
            for state_idx, clause_state_idx in zip(instance, self._clause)
        )
2885
+
2886
+
2887
class CPTPotentialFunction(PotentialFunction):
    """
    A potential function implementing a sparse Conditional Probability Table (CPT).

    The first random variable in the signature is the child, and the remaining random
    variables are parents.

    For each instantiation of the parent random variables there is a Conditioned Probability
    Distribution (CPD) over the states of the child random variable.

    If a CPD is not provided for a parent instantiation, then that parent instantiation
    is taken to have probability zero (i.e., all values of the CPD are guaranteed zero).
    """

    def __init__(self, factor: Factor, tolerance: float):
        """
        Create a CPT potential function for the given factor.

        Ensures:
            Does not hold a reference to the given factor.
            Does not register the potential function with the PGM.

        Args:
            factor: which factor this potential function is compatible with.
            tolerance: a tolerance when testing if values are equal to zero or one.

        Raises:
            ValueError: if tolerance is negative.
        """
        super().__init__(factor)

        if tolerance < 0:
            raise ValueError('tolerance cannot be negative')

        self._child_size: int = self.shape[0]
        self._parent_shape: Shape = self.shape[1:]
        # parent instance -> offset into `_values` of the first value of its CPD
        self._map: Dict[Instance, int] = {}
        # stored CPDs, concatenated; each CPD occupies `_child_size` consecutive values
        self._values: List[float] = []
        # position of a CPD within `_values` (offset // `_child_size`) -> parent instance
        self._inv_map: List[Instance] = []
        self._tolerance = tolerance

    @property
    def number_of_not_guaranteed_zero(self) -> int:
        return len(self._values)

    @property
    def number_of_parameters(self) -> int:
        return len(self._values)

    def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
        if tolerance >= self._tolerance:
            return True
        else:
            # The requested tolerance is tighter than ensured.
            # Need to use the default method.
            return super().is_cpt(tolerance)

    @property
    def params(self) -> Iterable[Tuple[int, float]]:
        return enumerate(self._values)

    @property
    def keys_with_param(self) -> Iterable[Tuple[Instance, int, float]]:
        child_size: int = self._child_size
        for param_idx, value in enumerate(self._values):
            parent: Instance = self._inv_map[param_idx // child_size]
            key: Instance = (param_idx % child_size,) + tuple(parent)
            yield key, param_idx, value

    def __getitem__(self, key: Key) -> float:
        param_idx: Optional[int] = self.param_idx(key)
        if param_idx is None:
            return 0
        else:
            return self._values[param_idx]

    def param_value(self, param_idx: int) -> float:
        return self._values[param_idx]

    def param_idx(self, key: Key) -> Optional[int]:
        instance: Instance = self.check_key(key)
        # the child state is the first element of the key, the rest are parent states
        offset: Optional[int] = self._map.get(instance[1:])
        if offset is None:
            return None
        else:
            return offset + instance[0]

    @property
    def parent_shape(self) -> Shape:
        """
        What is the shape of the parents.
        """
        return self._parent_shape

    @property
    def number_of_parent_states(self) -> int:
        """
        How many combinations of parent states.
        """
        return _multiply(self._parent_shape)

    @property
    def number_of_child_states(self) -> int:
        """
        Number of child random variable states.

        This is the same as the number of values in each conditional
        probability distribution. This is equivalent to `self.shape[0]`.

        Returns:
            the number of child states.
        """
        return self._child_size

    def get_cpd(self, parent_states: Key) -> List[float]:
        """
        Get the CPD conditioned on parent states indicated by `parent_states`.

        Args:
            parent_states: indicates the parent states.

        Returns:
            The conditioned probability distribution (all zeros if no CPD
            is stored for the parent states).

        Raises:
            KeyError: if the key is not valid for the parent shape.
        """
        parent_instance: Instance = check_key(self._parent_shape, parent_states)
        offset: Optional[int] = self._map.get(parent_instance)
        child_size: int = self._child_size
        if offset is None:
            return [0] * child_size
        else:
            return self._values[offset:offset + child_size]

    def cpds(self) -> Iterable[Tuple[Instance, Sequence[float]]]:
        """
        Iterate over (parent_states, cpd) tuples. This will exclude zero CPDs.

        Warning:
            Do not change CPDs to (or from) zero while iterating over them.

        Returns:
            an iterator over pairs (instance, cpd) where,
            instance: indicates the state of the parent random variables.
            cpd: is the conditioned probability distribution, for the parent instance.
        """
        for parent_instance, offset in self._map.items():
            cpd = self._values[offset:offset + self._child_size]
            yield parent_instance, cpd

    # Mutators

    def clear(self) -> CPTPotentialFunction:
        """
        Set all values of the potential function to zero.

        Returns:
            self
        """
        self._map = {}
        self._values = []
        self._inv_map = []
        return self

    def set_uniform(self) -> CPTPotentialFunction:
        """
        Set each CPD to a uniform distribution.

        Returns:
            self
        """
        self.clear()
        for parent_states in self.parent_instances():
            self.set_cpd_uniform(parent_states)
        return self

    def set_random(self, random: Callable[[], float], sparsity: float = 0) -> CPTPotentialFunction:
        """
        Set the values of the potential function to random CPDs.

        Args:
            random: is a stream of random numbers, assumed uniformly distributed in the interval [0, 1].
            sparsity: sets the expected proportion of probability values that are zero.

        Returns:
            self
        """
        self.clear()
        for parent_states in self.parent_instances():
            self.set_cpd_random(parent_states, random, sparsity)
        return self

    def set(self, *rows: Tuple[Key, Sequence[float]]) -> CPTPotentialFunction:
        """
        Calls self.set_cpd(parent_states, cpd) for each row (parent_states, cpd)
        in rows. Any unmentioned parent states will have zero probabilities.

        Example usage, assuming three Boolean random variables::

            pgm.Factor(x, y, z).set_cpt().set(
                # y  z    x[0] x[1]
                ((0, 0), (0.1, 0.9)),
                ((0, 1), (0.1, 0.9)),
                ((1, 0), (0.1, 0.9)),
                ((1, 1), (0.1, 0.9))
            )

        Args:
            rows: are tuples (key, cpd) used to set the potential function values.

        Raises:
            ValueError: if a CPD is not valid.

        Returns:
            self
        """
        self.clear()
        for parent_states, cpd in rows:
            self.set_cpd(parent_states, cpd)
        return self

    def set_all(self, *cpds: Optional[Sequence[float]]) -> CPTPotentialFunction:
        """
        Set all CPDs using the given `cpds` which are taken to be in order of the parent states
        with the last variable of the parent changing state most rapidly, as per parent_instances().

        If insufficient CPDs are provided then the remaining parent instantiations are taken to be
        impossible (i.e. not set and guaranteed zero).
        If too many CPDs are provided then the extras are ignored.
        Any list entry may be None, indicating 'guaranteed zero' for the associated parent states.

        Args:
            cpds: are the CPDs used to set the potential function values.

        Raises:
            ValueError: if a CPD is not valid.

        Returns:
            self
        """
        self.clear()
        for parent_states, cpd in zip(self.parent_instances(), cpds):
            self.set_cpd(parent_states, cpd)
        return self

    def set_cpd(self, parent_states: Key, cpd: Optional[Sequence[float]]) -> CPTPotentialFunction:
        """
        Set the CPD of the given parent states to the given cpd.
        If cpd is None or all zeros, then this is equivalent to clear_cpd(parent_states).

        Args:
            parent_states: indicates the CPD to set, based on the parent states.
            cpd: is a conditioned probability distribution, or None indicating `guaranteed zero`.

        Raises:
            ValueError: if the CPD is not valid.
            KeyError: if the key is not valid.

        Returns:
            self
        """
        parent_instance: Instance = check_key(self._parent_shape, parent_states)

        if cpd is None:
            self._clear_cpd(parent_instance)
            return self

        if len(cpd) != self._child_size:
            raise ValueError(f'CPD incorrect size: expected {self._child_size}, got {len(cpd)}')
        if not all(0 <= value <= 1 for value in cpd):
            raise ValueError(f'not a valid CPD: {cpd!r}')

        total_value = sum(cpd)
        # The zero test is inclusive of the tolerance, as is the test for one below.
        # In particular, with a tolerance of zero, an all-zeros CPD must clear the
        # CPD rather than be rejected as not summing to one.
        if total_value <= self._tolerance:
            self._clear_cpd(parent_instance)
            return self

        if total_value < 1 - self._tolerance or total_value > 1 + self._tolerance:
            raise ValueError(f'not a valid CPD: sum of values = {total_value}')

        offset: Optional[int] = self._map.get(parent_instance)
        child_size: int = self._child_size
        if offset is None:
            # a new CPD is appended to the end of the stored values
            offset = len(self._values)
            self._values.extend(cpd)
            self._map[parent_instance] = offset
            self._inv_map.append(parent_instance)
        else:
            # overwrite the existing CPD in place
            self._values[offset:offset + child_size] = cpd

        return self

    def clear_cpd(self, parent_states: Key) -> CPTPotentialFunction:
        """
        Set the CPD of the given parent_states to all 'guaranteed zero'.

        Args:
            parent_states: indicates the CPD to clear, based on the parent states.

        Raises:
            KeyError: if the key is not valid.

        Returns:
            self
        """
        parent_instance: Instance = check_key(self._parent_shape, parent_states)
        self._clear_cpd(parent_instance)
        return self

    def set_cpd_uniform(self, parent_states: Key) -> CPTPotentialFunction:
        """
        Set the CPD of the given parent_states to a uniform CPD.

        Args:
            parent_states: indicates the CPD to set, based on the parent states.

        Raises:
            KeyError: if the key is not valid.

        Returns:
            self
        """
        num_states = self.number_of_child_states
        cpd = [1.0 / num_states] * num_states
        return self.set_cpd(parent_states, cpd)

    def set_cpd_random(
            self,
            parent_states: Key,
            random: Callable[[], float],
            sparsity: float = 0,
    ) -> CPTPotentialFunction:
        """
        Set the CPD of the given parent_states to a random CPD.

        Args:
            parent_states: identifies the CPD being set.
            random: is a stream of random numbers, assumed uniformly distributed in the interval [0, 1].
            sparsity: sets the expected proportion of probability values that are zero.

        Returns:
            self
        """
        cpd = np.zeros(self.number_of_child_states, dtype=np.float64)
        if sparsity <= 0:
            for i in range(len(cpd)):
                # the small constant keeps every value strictly positive
                cpd[i] = 0.0000001 + random()
        else:
            for i in range(len(cpd)):
                if random() > sparsity:
                    cpd[i] = 0.0000001 + random()
        sum_value = np.sum(cpd)
        if sum_value > 0:
            cpd /= sum_value
            return self.set_cpd(parent_states, cpd)
        else:
            return self.clear_cpd(parent_states)

    def _clear_cpd(self, parent_instance: Instance) -> None:
        """
        Remove the parent instance from the parameters.

        If the CPD of the parent instance is not the last stored CPD, then
        the last stored CPD is moved into its place, keeping the stored
        values contiguous.
        """
        offset: Optional[int] = self._map.get(parent_instance)
        if offset is None:
            # nothing to do
            return

        child_size: int = self._child_size
        end_offset: int = len(self._values) - child_size
        if offset != end_offset:
            # move the last CPD into the place of the CPD being removed
            end_cpd = self._values[end_offset:]
            end_parent_instance = self._inv_map[-1]

            self._values[offset:offset + child_size] = end_cpd
            self._map[end_parent_instance] = offset
            self._inv_map[offset // child_size] = end_parent_instance

        self._map.pop(parent_instance)
        self._inv_map.pop()
        # drop the last CPD
        del self._values[end_offset:]
3267
+
3268
+
3269
def default_pgm_name(pgm: PGM) -> str:
    """
    Make the name a PGM gets when no name is provided to its constructor.

    The name is 'PGM_' followed by `id(pgm)`.

    Args:
        pgm: a PGM object.

    Returns:
        a name for the PGM if none is given at construction time.
    """
    return f'PGM_{id(pgm)}'
3280
+
3281
+
3282
def check_key(shape: Shape, key: Key) -> Instance:
    """
    Convert the key into an instance.

    Args:
        shape: the shape defining the state space.
        key: a key into the state space.

    Returns:
        An instance from the state space, as a tuple of state indexes, co-indexed with the given shape.

    Raises:
        KeyError: if the key is not valid for the given shape, i.e., it has the
            wrong number of state indexes, or a state index is out of range.
    """
    instance: Instance = _key_to_instance(key)
    # A random variable with `size` states has valid state indexes 0, ..., size - 1,
    # so the upper bound is exclusive.
    if len(instance) == len(shape) and all(
            0 <= state_idx < size
            for state_idx, size in zip(instance, shape)
    ):
        return instance
    raise KeyError(f'not a valid key for shape {shape}: {key!r}')
3302
+
3303
+
3304
def valid_key(shape: Shape, key: Key) -> bool:
    """
    Test whether `check_key` accepts the given key for the given shape.

    Args:
        shape: the shape defining the state space.
        key: a key into the state space.

    Returns:
        True only if the key is valid for the given shape.
    """
    try:
        check_key(shape, key)
    except KeyError:
        return False
    return True
3320
+
3321
+
3322
def number_of_states(*rvs: RandomVariable) -> int:
    """
    Args:
        rvs: the random variables defining the state space.

    Returns:
        the size of the state space, i.e., the product of `len(rv)` for each given random variable.
    """
    sizes = (len(rv) for rv in rvs)
    return _multiply(sizes)
3328
+
3329
+
3330
def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterable[Instance]:
    """
    Enumerate instances of the given random variables.

    Each instance is a tuple of state indexes, co-indexed with the given random variables.

    The order is the natural index order (i.e., last random variable changing most quickly).

    Args:
        rvs: the random variables to enumerate over.
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over instances, each instance is a tuple of state
        indexes, co-indexed with the given random variables.
    """
    sizes = list(map(len, rvs))
    # `_combos_ranges` is called with the opposite sense of `flip`
    combos_flip: bool = not flip
    return _combos_ranges(sizes, flip=combos_flip)
3347
+
3348
+
3349
def rv_instances_as_indicators(*rvs: RandomVariable, flip: bool = False) -> Iterable[Sequence[Indicator]]:
    """
    Enumerate instances of the given random variables.

    Each instance is a tuple of indicators, co-indexed with the given random variables.

    The order is the natural index order (i.e., last random variable changing most quickly).

    Args:
        rvs: the random variables to enumerate over.
        flip: if true, then first random variable changes most quickly.

    Returns:
        an iteration over tuples, each tuple holds random variable indicators
        co-indexed with the given random variables.
    """
    # `_combos` is called with the opposite sense of `flip`
    combos_flip: bool = not flip
    return _combos(rvs, flip=combos_flip)
3365
+
3366
+
3367
def _key_to_instance(key: Key) -> Instance:
    """
    Convert a key to an instance.

    A key that is a single int becomes a one-element tuple; any other
    key is converted using `tuple(key)`.

    Args:
        key: a key into a state space.

    Returns:
        An instance from the state space, as a tuple of state indexes.

    Assumes:
        The key is valid for the implied state space.
    """
    return (key,) if isinstance(key, int) else tuple(key)
3384
+
3385
+
3386
def _natural_key_idx(shape: Shape, key: Key) -> int:
    """
    What is the natural index of the given key, assuming the given shape.

    The natural index of an instance is defined as the index of the
    instance if all instances for the shape are enumerated as per
    `rv_instances`.

    Args:
        shape: the shape defining the state space.
        key: a key into the state space.

    Returns:
        an index as per enumerated instances in their natural order, i.e.
        last random variable changing most quickly.

    Assumes:
        The key is valid for the shape.
    """
    states: Instance = _key_to_instance(key)
    # Accumulate position by position: scale the index so far by the size
    # of the next random variable, then add its state index.
    index: int = states[0]
    for size, state_idx in zip(shape[1:], states[1:]):
        index *= size
        index += state_idx
    return index
3410
+
3411
+
3412
def _zero_space(shape: Shape) -> int:
    """
    Return the size of the zero space of the given shape.

    The zero space is the same as the shape but with one less state
    for each random variable, so its size is the product of `x - 1`
    for each `x` in the shape.

    Args:
        shape: the shape defining the state space.

    Returns:
        the size of the zero space.
    """
    reduced_sizes = (size - 1 for size in shape)
    return _multiply(reduced_sizes)
3427
+
3428
+
3429
def _normalise_potential_function(
        function: Union[DensePotentialFunction, SparsePotentialFunction],
        grouping_positions: Sequence[int],
) -> None:
    """
    Convert the potential function to a CPT with 'grouping_positions' nominating
    the parent random variables.

    I.e., for each possible key of the function with the same value at each
    grouping position, the sum of values for matching keys in the factor is scaled
    to be 1 (or 0).

    Parameter 'grouping_positions' are indices into `function.shape`. For example, the
    grouping positions of a factor with parent rvs 'conditioning_rvs', then
    grouping_positions = [i for i, rv in enumerate(factor.rvs) if rv in conditioning_rvs].

    The function is updated in place using `function.set_param_value`.

    Args:
        function: the potential function to normalise.
        grouping_positions: indices into `function.shape`.
    """
    if len(grouping_positions) != 0:
        # First pass: total the values of each group of keys.
        totals = {}
        for key, _, value in function.keys_with_param:
            group = tuple(key[pos] for pos in grouping_positions)
            totals[group] = totals.get(group, 0) + value

        # Second pass: scale each value by the total of its group.
        for key, idx, value in function.keys_with_param:
            group_total = totals[tuple(key[pos] for pos in grouping_positions)]
            if group_total > 0:
                function.set_param_value(idx, value / group_total)
        return

    # No grouping: there is a single group, holding every parameter.
    grand_total = sum(map(function.param_value, range(function.number_of_parameters)))
    if grand_total == 0 or grand_total == 1:
        # cannot be scaled, or is already normalised
        return
    for _, idx, value in function.keys_with_param:
        function.set_param_value(idx, value / grand_total)
3468
+
3469
+
3470
# Characters that may appear in a rendered string without the string needing to be quoted.
_CLEAN_CHARS: Set[str] = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-+~?.')


def _clean_str(s) -> str:
    """
    Render `s` as a string, quoting it (using `repr`) if the string is
    empty or contains any character not in _CLEAN_CHARS.
    This is used when rendering indicators.
    """
    text = str(s)
    if text and _CLEAN_CHARS.issuperset(text):
        return text
    return repr(text)