compiled-knowledge 4.0.0a20__cp312-cp312-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (178)
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37525 -0
  4. ck/circuit/_circuit_cy.cpython-312-darwin.so +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-darwin.so +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-darwin.so +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-darwin.so +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
@@ -0,0 +1,406 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Sequence, Tuple, List, Iterator, Set, Iterable, Optional, Callable
5
+
6
+ import numpy as np
7
+
8
+ from ck.circuit import Circuit, VarNode, CircuitNode
9
+ from ck.pgm import PGM, ParamId, Factor, PotentialFunction, RandomVariable, ZeroPotentialFunction
10
+ from ck.pgm_circuit.slot_map import SlotMap, SlotKey
11
+ from ck.pgm_compiler.support.circuit_table import CircuitTable, TableInstance
12
+ from ck.utils.iter_extras import pairs
13
+ from ck.utils.map_list import MapList
14
+ from ck.utils.np_extras import NDArray, NDArrayFloat64
15
+
16
+
17
@dataclass
class FactorTables:
    """
    The result of `make_factor_tables`: a host circuit together with a circuit
    table for every factor of a PGM.

    For a non-const, in-use parameter with id `param_id`, the PGM value of that
    parameter was `parameter_values[slot_map[param_id] - number_of_indicators]`.
    """

    # The host circuit.
    circuit: Circuit

    # Number of indicator circuit variables.
    number_of_indicators: int

    # Number of parameter circuit variables (i.e., non-const, in-use parameters).
    number_of_parameters: int

    # Map from Indicator or ParamId object to a circuit var index.
    slot_map: SlotMap

    # One CircuitTable for each PGM factor, looked up by `factor.idx`.
    tables: Sequence[CircuitTable]

    # Values of the parameter circuit variables (see the class docstring for indexing).
    parameter_values: NDArray

    def get_table(self, factor: Factor) -> CircuitTable:
        """
        Return the circuit table built for the given factor.
        """
        factor_index = factor.idx
        return self.tables[factor_index]
31
+
32
+
33
def make_factor_tables(
        pgm: PGM,
        const_parameters: bool,
        multiply_indicators: bool,
        pre_prune_factor_tables: bool,
) -> FactorTables:
    """
    Consistently and efficiently create circuit tables for the factors of a PGM.

    Creates:
    * a circuit,
    * a circuit variable for each indicator of the PGM,
    * a circuit variable for each non-constant, in-use potential function parameter,
    * a circuit table for each factor of the PGM.

    Each potential function parameter is converted either to a circuit
    constant (if `const_parameters` is true) or to a circuit variable
    (if `const_parameters` is false).

    Indicators of random variables are multiplied into factor circuit tables
    if `multiply_indicators` is true.

    A slot map is created that maps PGM indicators and parameter ids to circuit
    var indices. A circuit var is added for each indicator, in the order the
    indicators appear in `pgm.indicators`. Circuit vars for parameter ids are
    added after those for indicators, and only if `const_parameters` is false.

    Args:
        pgm: The PGM with the random variables, factors, and potential functions.
        const_parameters: if true, then potential function parameters will be circuit constants,
            otherwise they will be circuit variables, with entries in the returned slot map.
        multiply_indicators: if true then indicator variables will be multiplied into an acceptable
            factor.
        pre_prune_factor_tables: if true, then heuristics will be used to remove any provably zero row.

    Returns:
        FactorTables, holding the circuit, the slot map, the parameter values,
        and a circuit table for each PGM factor.
    """
    circuit = Circuit()

    # Indicator circuit vars are created first, one per indicator, in `pgm.indicators` order.
    slot_map: Dict[SlotKey, int] = {}
    for indicator in pgm.indicators:
        slot_map[indicator] = circuit.new_var().idx

    def _build_rows(function: PotentialFunction) -> _FunctionRows:
        # Parameter circuit vars (if any) are created here, after the indicator vars.
        if const_parameters:
            return _rows_for_function_const(function, circuit)
        return _rows_for_function_var(function, circuit, slot_map)

    # Rows of each potential function, keyed by id(function).
    functions_rows: Dict[int, _FunctionRows] = {
        id(function): _build_rows(function)
        for function in pgm.functions
    }

    # Give every factor a view onto its function's rows, counting how many
    # factors share each function's rows (needed so pruning can copy shared rows).
    factor_rows: Dict[int, _FactorRows] = {}
    for factor in pgm.factors:
        shared_rows: _FunctionRows = functions_rows[id(factor.function)]
        shared_rows.use_count += 1
        factor_rows[id(factor)] = _FactorRows(factor, shared_rows)

    if pre_prune_factor_tables:
        _pre_prune_factor_tables(list(factor_rows.values()))

    # Decide which random variables have their indicators multiplied into which factor.
    factors_mul_rvs: MapList[int, RandomVariable] = (
        _assign_rvs_to_factors(pgm, lambda f: len(factor_rows[id(f)]))
        if multiply_indicators
        else MapList()  # no random variables are multiplied into any factor
    )

    # `tables[factor.idx]` is the circuit table for `factor`.
    tables: List[CircuitTable] = [
        _make_factor_table(factor, circuit, slot_map, factor_rows[id(factor)], factors_mul_rvs)
        for factor in pgm.factors
    ]

    # Every slot beyond the indicators belongs to a parameter circuit var.
    number_of_indicators: int = pgm.number_of_indicators
    number_of_parameters: int = len(slot_map) - number_of_indicators
    parameter_values: NDArrayFloat64 = np.zeros(number_of_parameters, dtype=np.float64)
    if not const_parameters:
        for function in pgm.functions:
            for param_index, value in function.params:
                slot: Optional[int] = slot_map.get(function.param_id(param_index))
                if slot is None:
                    continue
                parameter_values[slot - number_of_indicators] = value

    return FactorTables(
        circuit=circuit,
        number_of_indicators=number_of_indicators,
        number_of_parameters=number_of_parameters,
        slot_map=slot_map,
        tables=tables,
        parameter_values=parameter_values,
    )
141
+
142
+
143
def _assign_rvs_to_factors(
        pgm: PGM,
        factor_size: Callable[[Factor], int],
) -> MapList[int, RandomVariable]:
    """
    Assign each random variable to the smallest factor containing it.

    Size is measured by `factor_size`. When several candidate factors have the
    same minimal size, the first in `pgm.factors` order wins (the behavior of
    `min`). A random variable that appears in no factor is not assigned.

    Args:
        pgm: the PGM whose random variables and factors are considered.
        factor_size: a function giving the size of a factor.

    Returns:
        a map from id(factor) to the list of random variables assigned to that factor.
    """
    factors = pgm.factors
    rvs = pgm.rvs

    # Invert factor scopes: rv.idx -> factors mentioning that random variable.
    factors_by_rv: MapList[int, Factor] = MapList()
    for factor in factors:
        for rv in factor.rvs:
            factors_by_rv.append(rv.idx, factor)

    assignment: MapList[int, RandomVariable] = MapList()
    for position, rv in enumerate(rvs):
        candidates: Sequence[Factor] = factors_by_rv.get(position, ())
        if not candidates:
            continue
        smallest = min(candidates, key=factor_size)
        assignment.append(id(smallest), rv)

    return assignment
171
+
172
+
173
class _FunctionRows:
    """
    The circuit table rows of one potential function, possibly shared by
    several factors.

    Attributes:
        rows: map from table instance to its circuit node.
        use_count: number of factors currently referring to this object.
    """

    def __init__(self, rows: Dict[TableInstance, CircuitNode], use_count: int = 0):
        self.use_count: int = use_count
        self.rows: Dict[TableInstance, CircuitNode] = rows
177
+
178
+
179
class _FactorRows:
    """
    A factor's view onto the rows of its potential function.

    Several factors may refer to the same `_FunctionRows` object. `prune`
    gives this factor its own filtered copy when the rows are shared, so the
    other factors are not affected.
    """

    def __init__(self, factor: Factor, rows: _FunctionRows):
        self.rows: _FunctionRows = rows
        # Random variable indexes of the factor, in the factor's own order.
        indexes = [rv.idx for rv in factor.rvs]
        self.rv_indexes: Tuple[int, ...] = tuple(indexes)

    def __len__(self) -> int:
        """Number of rows currently held for the factor."""
        function_rows = self.rows.rows
        return len(function_rows)

    def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
        """The (instance, node) pairs of the current rows."""
        function_rows = self.rows.rows
        return function_rows.items()

    def prune(self, extra_keys: Set[TableInstance]) -> None:
        """
        Remove the given keys from the factor's function rows.

        If the rows are shared with other factors (use_count > 1), this factor
        detaches onto a new `_FunctionRows` holding the filtered rows, and the
        shared object's use count is decremented; the shared rows dict is left
        untouched. Otherwise the shared object's `rows` dict is replaced by the
        filtered one.
        """
        if not extra_keys:
            return
        shared = self.rows
        kept: Dict[TableInstance, CircuitNode] = {
            instance: node
            for instance, node in shared.rows.items()
            if instance not in extra_keys
        }
        if shared.use_count <= 1:
            shared.rows = kept
        else:
            shared.use_count -= 1
            self.rows = _FunctionRows(kept, 1)
205
+
206
+
207
class _FactorPair:
    """
    Two factors (as `_FactorRows`) checked against each other for rows that
    cannot join.

    Attributes:
        x, y: the two factors' rows.
        all_rv_indexes: indexes of random variables used by x or y.
        co_rv_indexes: indexes of random variables common to x and y,
            in the order they appear in x.
        co_from_x_map: position, within an x instance, of each of co_rv_indexes.
        co_from_y_map: position, within a y instance, of each of co_rv_indexes.
    """

    def __init__(self, x: _FactorRows, y: _FactorRows):
        self.x: _FactorRows = x
        self.y: _FactorRows = y

        x_set = set(x.rv_indexes)
        y_set = set(y.rv_indexes)

        # Identify all random variables used by x and y.
        self.all_rv_indexes: Set[int] = x_set | y_set

        # Identify common random variables between x and y, keeping x's order so the
        # projection layout is deterministic rather than depending on set iteration order.
        self.co_rv_indexes: Tuple[int, ...] = tuple(
            rv_index for rv_index in x.rv_indexes if rv_index in y_set
        )

        # Cache the positions within x and y instances of each common random variable,
        # used to project instances onto the common random variables.
        self.co_from_x_map: Tuple[int, ...] = tuple(
            x.rv_indexes.index(rv_index) for rv_index in self.co_rv_indexes
        )
        self.co_from_y_map: Tuple[int, ...] = tuple(
            y.rv_indexes.index(rv_index) for rv_index in self.co_rv_indexes
        )

    def prune(self) -> None:
        """
        Prune any rows from x and y that cannot join to each other.

        A row of x is kept only if its values for the common random variables
        match those of at least one row of y, and vice versa. Both sets of keys
        to remove are computed before either side is pruned.
        """
        co_from_x_map = self.co_from_x_map
        co_from_y_map = self.co_from_y_map

        # Project every instance onto the common random variables, once per instance.
        x_projections: Dict[TableInstance, TableInstance] = {
            instance: tuple(instance[i] for i in co_from_x_map)
            for instance in self.x.rows.rows
        }
        y_projections: Dict[TableInstance, TableInstance] = {
            instance: tuple(instance[i] for i in co_from_y_map)
            for instance in self.y.rows.rows
        }

        x_co_set: Set[TableInstance] = set(x_projections.values())
        y_co_set: Set[TableInstance] = set(y_projections.values())

        # Keys in x that will not join to y, and keys in y that will not join to x.
        x_extra_keys: Set[TableInstance] = {
            instance for instance, co in x_projections.items() if co not in y_co_set
        }
        y_extra_keys: Set[TableInstance] = {
            instance for instance, co in y_projections.items() if co not in x_co_set
        }

        self.x.prune(x_extra_keys)
        self.y.prune(y_extra_keys)
262
+
263
+
264
def _pre_prune_factor_tables(factor_rows: Sequence[_FactorRows]) -> None:
    """
    It may be possible to reduce the size of a table for a factor.

    If two factors contain a common random variable then at some point their
    product will be formed, which may eliminate rows: a row of one factor whose
    values for the common random variables match no row of the other factor
    cannot survive the join. This function identifies and removes such rows.

    Each overlapping pair is checked once. An earlier version kept a list of
    processed pairs and, whenever pruning shrank a factor, re-queued every
    processed pair sharing a random variable with that factor. It was removed
    because it was computationally expensive and provided no practical benefit.

    Args:
        factor_rows: the rows of every factor; pruned in place via `_FactorRows.prune`.
    """
    # Find all pairs of factors that have at least one common random variable.
    # (Previously this compared f1 with itself, so every pair with a non-empty
    # scope was selected, including pairs with no common random variable.)
    pairs_to_check: List[_FactorPair] = [
        _FactorPair(f1, f2)
        for f1, f2 in pairs(factor_rows)
        if not set(f1.rv_indexes).isdisjoint(f2.rv_indexes)
    ]

    for pair in pairs_to_check:
        pair.prune()
314
+
315
+
316
def _make_factor_table(
        factor: Factor,
        circuit: Circuit,
        slot_map: Dict[SlotKey, int],
        rows: _FactorRows,
        factors_mul_rvs: MapList[int, RandomVariable],
) -> CircuitTable:
    """
    Build the circuit table for one factor.

    Each row's node is the factor's row node multiplied by the indicator
    circuit variables (selected by the row's states) of the random variables
    assigned to this factor in `factors_mul_rvs`. With no assigned random
    variables, the factor's rows are used as they are.

    Args:
        factor: the PGM factor.
        circuit: the host circuit.
        slot_map: map from indicator to circuit var index.
        rows: the (possibly pruned) rows of the factor.
        factors_mul_rvs: map from id(factor) to random variables to multiply in.

    Returns:
        the circuit table over the factor's random variables.
    """
    mul_rvs: Sequence[RandomVariable] = factors_mul_rvs.get(id(factor), ())
    rv_indexes: Sequence[int] = tuple(rv.idx for rv in factor.rvs)

    if not mul_rvs:
        # Nothing to multiply in - the factor rows are the table rows.
        return CircuitTable(circuit, rv_indexes, rows.items())

    # positions[i] is where, within a factor instance, the state of mul_rvs[i] sits.
    positions: Sequence[int] = tuple(rv_indexes.index(rv.idx) for rv in mul_rvs)

    # indicator_vars[i][s] is the indicator circuit variable for state s of mul_rvs[i].
    indicator_vars: Sequence[Sequence[CircuitNode]] = tuple(
        tuple(circuit.vars[slot_map[indicator]] for indicator in rv.indicators)
        for rv in mul_rvs
    )

    def _multiplied_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
        for instance, node in rows.items():
            terms: Tuple[CircuitNode, ...] = tuple(
                states[instance[position]]
                for position, states in zip(positions, indicator_vars)
            )
            # A node that is the constant one contributes nothing to the product.
            if not node.is_one:
                terms = terms + (node,)

            term_count = len(terms)
            if term_count == 0:
                product = circuit.one
            elif term_count == 1:
                product = terms[0]
            else:
                product = circuit.optimised_mul(terms)
            yield instance, product

    return CircuitTable(circuit, rv_indexes, _multiplied_rows())
361
+
362
+
363
def _rows_for_function_const(
        function: PotentialFunction,
        circuit: Circuit,
) -> _FunctionRows:
    """
    Build the rows (instance -> node) of a potential function, where each node
    is a circuit constant holding the row's value. Rows whose value equals zero
    are omitted.
    """
    if isinstance(function, ZeroPotentialFunction):
        # Shortcut: return no rows without iterating the function's keys.
        return _FunctionRows({})

    rows: Dict[TableInstance, CircuitNode] = {}
    for instance, _param_index, value in function.keys_with_param:
        if value == 0:
            continue
        rows[tuple(instance)] = circuit.const(value)
    return _FunctionRows(rows)
381
+
382
+
383
def _rows_for_function_var(
        function: PotentialFunction,
        circuit: Circuit,
        slot_map: Dict[SlotKey, int],
) -> _FunctionRows:
    """
    Build the rows (instance -> node) of a potential function, where each node
    is a new circuit variable created for the row's parameter. Each new
    variable's index is recorded in `slot_map` under its parameter id.
    """
    rows: Dict[TableInstance, CircuitNode] = {}
    for instance, param_index, _value in function.keys_with_param:
        param_id: ParamId = function.param_id(param_index)
        # NOTE(review): assumes no two rows of a function share a parameter id;
        # this assert fires otherwise - confirm for every PotentialFunction kind.
        assert param_id not in slot_map, 'parameter should not already have a circuit var'
        var: VarNode = circuit.new_var()
        slot_map[param_id] = var.idx
        rows[tuple(instance)] = var
    return _FunctionRows(rows)