compiled-knowledge 4.0.0a5__cp313-cp313-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (167) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +13 -0
  3. ck/circuit/circuit.c +38749 -0
  4. ck/circuit/circuit.cpython-313-darwin.so +0 -0
  5. ck/circuit/circuit_py.py +807 -0
  6. ck/circuit/tmp_const.py +74 -0
  7. ck/circuit_compiler/__init__.py +2 -0
  8. ck/circuit_compiler/circuit_compiler.py +26 -0
  9. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  10. ck/circuit_compiler/cython_vm_compiler/_compiler.c +17373 -0
  11. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
  12. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +96 -0
  13. ck/circuit_compiler/interpret_compiler.py +223 -0
  14. ck/circuit_compiler/llvm_compiler.py +388 -0
  15. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  16. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  17. ck/circuit_compiler/support/__init__.py +0 -0
  18. ck/circuit_compiler/support/circuit_analyser.py +81 -0
  19. ck/circuit_compiler/support/input_vars.py +148 -0
  20. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  21. ck/example/__init__.py +53 -0
  22. ck/example/alarm.py +366 -0
  23. ck/example/asia.py +28 -0
  24. ck/example/binary_clique.py +32 -0
  25. ck/example/bow_tie.py +33 -0
  26. ck/example/cancer.py +37 -0
  27. ck/example/chain.py +38 -0
  28. ck/example/child.py +199 -0
  29. ck/example/clique.py +33 -0
  30. ck/example/cnf_pgm.py +39 -0
  31. ck/example/diamond_square.py +68 -0
  32. ck/example/earthquake.py +36 -0
  33. ck/example/empty.py +10 -0
  34. ck/example/hailfinder.py +539 -0
  35. ck/example/hepar2.py +628 -0
  36. ck/example/insurance.py +504 -0
  37. ck/example/loop.py +40 -0
  38. ck/example/mildew.py +38161 -0
  39. ck/example/munin.py +22982 -0
  40. ck/example/pathfinder.py +53674 -0
  41. ck/example/rain.py +39 -0
  42. ck/example/rectangle.py +161 -0
  43. ck/example/run.py +30 -0
  44. ck/example/sachs.py +129 -0
  45. ck/example/sprinkler.py +30 -0
  46. ck/example/star.py +44 -0
  47. ck/example/stress.py +64 -0
  48. ck/example/student.py +43 -0
  49. ck/example/survey.py +46 -0
  50. ck/example/triangle_square.py +54 -0
  51. ck/example/truss.py +49 -0
  52. ck/in_out/__init__.py +3 -0
  53. ck/in_out/parse_ace_lmap.py +216 -0
  54. ck/in_out/parse_ace_nnf.py +288 -0
  55. ck/in_out/parse_net.py +480 -0
  56. ck/in_out/parser_utils.py +185 -0
  57. ck/in_out/pgm_pickle.py +42 -0
  58. ck/in_out/pgm_python.py +268 -0
  59. ck/in_out/render_bugs.py +111 -0
  60. ck/in_out/render_net.py +177 -0
  61. ck/in_out/render_pomegranate.py +184 -0
  62. ck/pgm.py +3494 -0
  63. ck/pgm_circuit/__init__.py +1 -0
  64. ck/pgm_circuit/marginals_program.py +352 -0
  65. ck/pgm_circuit/mpe_program.py +237 -0
  66. ck/pgm_circuit/pgm_circuit.py +75 -0
  67. ck/pgm_circuit/program_with_slotmap.py +234 -0
  68. ck/pgm_circuit/slot_map.py +35 -0
  69. ck/pgm_circuit/support/__init__.py +0 -0
  70. ck/pgm_circuit/support/compile_circuit.py +83 -0
  71. ck/pgm_circuit/target_marginals_program.py +103 -0
  72. ck/pgm_circuit/wmc_program.py +323 -0
  73. ck/pgm_compiler/__init__.py +2 -0
  74. ck/pgm_compiler/ace/__init__.py +1 -0
  75. ck/pgm_compiler/ace/ace.py +252 -0
  76. ck/pgm_compiler/factor_elimination.py +383 -0
  77. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  78. ck/pgm_compiler/pgm_compiler.py +19 -0
  79. ck/pgm_compiler/recursive_conditioning.py +226 -0
  80. ck/pgm_compiler/support/__init__.py +0 -0
  81. ck/pgm_compiler/support/circuit_table/__init__.py +9 -0
  82. ck/pgm_compiler/support/circuit_table/circuit_table.c +16042 -0
  83. ck/pgm_compiler/support/circuit_table/circuit_table.cpython-313-darwin.so +0 -0
  84. ck/pgm_compiler/support/circuit_table/circuit_table_py.py +269 -0
  85. ck/pgm_compiler/support/clusters.py +556 -0
  86. ck/pgm_compiler/support/factor_tables.py +398 -0
  87. ck/pgm_compiler/support/join_tree.py +275 -0
  88. ck/pgm_compiler/support/named_compiler_maker.py +33 -0
  89. ck/pgm_compiler/variable_elimination.py +89 -0
  90. ck/probability/__init__.py +0 -0
  91. ck/probability/empirical_probability_space.py +47 -0
  92. ck/probability/probability_space.py +568 -0
  93. ck/program/__init__.py +3 -0
  94. ck/program/program.py +129 -0
  95. ck/program/program_buffer.py +180 -0
  96. ck/program/raw_program.py +61 -0
  97. ck/sampling/__init__.py +0 -0
  98. ck/sampling/forward_sampler.py +211 -0
  99. ck/sampling/marginals_direct_sampler.py +113 -0
  100. ck/sampling/sampler.py +62 -0
  101. ck/sampling/sampler_support.py +232 -0
  102. ck/sampling/uniform_sampler.py +66 -0
  103. ck/sampling/wmc_direct_sampler.py +169 -0
  104. ck/sampling/wmc_gibbs_sampler.py +147 -0
  105. ck/sampling/wmc_metropolis_sampler.py +159 -0
  106. ck/sampling/wmc_rejection_sampler.py +113 -0
  107. ck/utils/__init__.py +0 -0
  108. ck/utils/iter_extras.py +153 -0
  109. ck/utils/map_list.py +128 -0
  110. ck/utils/map_set.py +128 -0
  111. ck/utils/np_extras.py +51 -0
  112. ck/utils/random_extras.py +64 -0
  113. ck/utils/tmp_dir.py +94 -0
  114. ck_demos/__init__.py +0 -0
  115. ck_demos/ace/__init__.py +0 -0
  116. ck_demos/ace/copy_ace_to_ck.py +15 -0
  117. ck_demos/ace/demo_ace.py +44 -0
  118. ck_demos/all_demos.py +88 -0
  119. ck_demos/circuit/__init__.py +0 -0
  120. ck_demos/circuit/demo_circuit_dump.py +22 -0
  121. ck_demos/circuit/demo_derivatives.py +43 -0
  122. ck_demos/circuit_compiler/__init__.py +0 -0
  123. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  124. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  125. ck_demos/pgm/__init__.py +0 -0
  126. ck_demos/pgm/demo_pgm_dump.py +18 -0
  127. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  128. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  129. ck_demos/pgm/show_examples.py +25 -0
  130. ck_demos/pgm_compiler/__init__.py +0 -0
  131. ck_demos/pgm_compiler/compare_pgm_compilers.py +50 -0
  132. ck_demos/pgm_compiler/demo_compiler_dump.py +50 -0
  133. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  134. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  135. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  136. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  137. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  138. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  139. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  140. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  141. ck_demos/pgm_inference/__init__.py +0 -0
  142. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  143. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  144. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  145. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  146. ck_demos/programs/__init__.py +0 -0
  147. ck_demos/programs/demo_program_buffer.py +24 -0
  148. ck_demos/programs/demo_program_multi.py +24 -0
  149. ck_demos/programs/demo_program_none.py +19 -0
  150. ck_demos/programs/demo_program_single.py +23 -0
  151. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  152. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  153. ck_demos/sampling/__init__.py +0 -0
  154. ck_demos/sampling/check_sampler.py +71 -0
  155. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  156. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  157. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  158. ck_demos/utils/__init__.py +0 -0
  159. ck_demos/utils/compare.py +88 -0
  160. ck_demos/utils/convert_network.py +45 -0
  161. ck_demos/utils/sample_model.py +216 -0
  162. ck_demos/utils/stop_watch.py +384 -0
  163. compiled_knowledge-4.0.0a5.dist-info/METADATA +50 -0
  164. compiled_knowledge-4.0.0a5.dist-info/RECORD +167 -0
  165. compiled_knowledge-4.0.0a5.dist-info/WHEEL +5 -0
  166. compiled_knowledge-4.0.0a5.dist-info/licenses/LICENSE.txt +21 -0
  167. compiled_knowledge-4.0.0a5.dist-info/top_level.txt +2 -0
@@ -0,0 +1,398 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Sequence, Tuple, List, Iterator, Set, Iterable, Optional, Callable
5
+
6
+ import numpy as np
7
+
8
+ from ck.circuit import Circuit, VarNode, CircuitNode
9
+ from ck.pgm import PGM, ParamId, Factor, PotentialFunction, RandomVariable, ZeroPotentialFunction
10
+ from ck.pgm_circuit.slot_map import SlotMap, SlotKey
11
+ from ck.pgm_compiler.support.circuit_table import CircuitTable, TableInstance
12
+ from ck.utils.iter_extras import pairs
13
+ from ck.utils.map_list import MapList
14
+ from ck.utils.np_extras import NDArray, NDArrayFloat64
15
+
16
+
17
@dataclass
class FactorTables:
    """
    The result of `make_factor_tables`: a host circuit, one circuit table per
    PGM factor, and the slot map relating PGM indicators and parameter ids to
    circuit variable indexes.
    """
    # The host circuit of all the circuit tables.
    circuit: Circuit
    # How many circuit variables are indicator variables.
    number_of_indicators: int
    # How many circuit variables are parameter variables (i.e., non-const, in-use parameters).
    number_of_parameters: int
    # Maps each Indicator or ParamId object to the index of its circuit var.
    slot_map: SlotMap
    # tables[i] is the CircuitTable for the PGM factor whose `idx` is i.
    tables: Sequence[CircuitTable]

    # For a non-const, in-use parameter with id `param_id`, the PGM value of that
    # parameter was `self.parameter_values[self.slot_map[param_id] - self.number_of_indicators]`.
    parameter_values: NDArray

    def get_table(self, factor: Factor) -> CircuitTable:
        """
        Returns:
            the circuit table for the given factor, looked up by `factor.idx`.
        """
        factor_index: int = factor.idx
        return self.tables[factor_index]
31
+
32
+
33
def make_factor_tables(
        pgm: PGM,
        const_parameters: bool,
        multiply_indicators: bool,
        pre_prune_factor_tables: bool,
) -> FactorTables:
    """
    Consistently and efficiently create circuit tables for the factors of a PGM.

    Creates:
        * a circuit,
        * a circuit variable for each indicator of the PGM,
        * a circuit variable for each non-constant, in-use potential function parameter,
        * a circuit table for each factor of the PGM.

    The parameters of each potential function will be converted either to
    circuit constants (if `const_parameters` is true) or to circuit variables
    (if `const_parameters` is false).

    If `multiply_indicators` is true, the indicators of each random variable are
    multiplied into the table of exactly one factor containing that random
    variable (a factor with the fewest rows).

    A slot map will be created that maps PGM indicators and parameter ids to circuit
    var indices. Specifically, a circuit var will be added for each indicator, in the
    order they appear in `pgm.indicators`. Circuit vars for parameter ids will be added
    after those for indicators, and only if `const_parameters` is false.

    Args:
        pgm: The PGM with the random variables, factors, and potential functions.
        const_parameters: if true, then potential function parameters will be circuit constants,
            otherwise they will be circuit variables, with entries in the returned slot map.
        multiply_indicators: if true then indicator variables will be multiplied into an
            acceptable factor.
        pre_prune_factor_tables: if true, then heuristics will be used to remove any
            provably zero row.

    Returns:
        FactorTables, holding a slot_map and a circuit table for each PGM factor.
    """

    # Create circuit and initialise the slot map with indicator variables
    circuit = Circuit()
    slot_map: Dict[SlotKey, int] = {
        indicator: circuit.new_var().idx
        for indicator in pgm.indicators
    }

    # Get the circuit table rows for each potential function
    # functions_rows[id(function)] = rows for the function
    functions_rows: Dict[int, _FunctionRows]
    if const_parameters:
        functions_rows = {
            id(function): _rows_for_function_const(function, circuit)
            for function in pgm.functions
        }
    else:
        functions_rows = {
            id(function): _rows_for_function_var(function, circuit, slot_map)
            for function in pgm.functions
        }

    # Link factors to function rows. Factors sharing a potential function share
    # its rows object; `use_count` lets pruning copy the rows before modifying them.
    # factor_rows[id(factor)] = rows for the factor
    factor_rows: Dict[int, _FactorRows] = {}
    for factor in pgm.factors:
        rows: _FunctionRows = functions_rows[id(factor.function)]
        rows.use_count += 1
        factor_rows[id(factor)] = _FactorRows(factor, rows)

    # Check to see if any factor rows can be pre-pruned.
    if pre_prune_factor_tables:
        _pre_prune_factor_tables(list(factor_rows.values()))

    # Allocate random variables to factors (for multiplying in their indicators).
    factors_mul_rvs: MapList[int, RandomVariable]
    if multiply_indicators:
        def _factor_size(_factor: Factor) -> int:
            # Size of a factor = number of rows (after any pre-pruning).
            return len(factor_rows[id(_factor)])

        factors_mul_rvs = _assign_rvs_to_factors(pgm, _factor_size)
    else:
        factors_mul_rvs = MapList()  # no assignment of rvs to factors.

    # Make a circuit table for each factor. `tables[factor.idx]` is the circuit table for `factor`.
    tables: List[CircuitTable] = [
        _make_factor_table(factor, circuit, slot_map, factor_rows[id(factor)], factors_mul_rvs)
        for factor in pgm.factors
    ]

    # Extract the parameter values (if they are circuit vars).
    number_of_indicators: int = pgm.number_of_indicators
    number_of_parameters: int = len(slot_map) - number_of_indicators
    parameter_values: NDArrayFloat64 = np.zeros(number_of_parameters, dtype=np.float64)
    if not const_parameters:
        for function in pgm.functions:
            for param_index, value in function.params:
                param_id: ParamId = function.param_id(param_index)
                # Parameters not referenced by any row have no circuit var.
                slot: Optional[int] = slot_map.get(param_id)
                if slot is not None:
                    parameter_values[slot - number_of_indicators] = value

    return FactorTables(
        circuit=circuit,
        number_of_indicators=number_of_indicators,
        number_of_parameters=number_of_parameters,
        slot_map=slot_map,
        tables=tables,
        parameter_values=parameter_values,
    )
141
+
142
+
143
def _assign_rvs_to_factors(
        pgm: PGM,
        factor_size: Callable[[Factor], int],
) -> MapList[int, RandomVariable]:
    """
    Choose, for each random variable, the factor its indicators are multiplied into.

    Each random variable is assigned to a factor containing it with the smallest
    `factor_size`. On ties, the earliest such factor in `pgm.factors` order wins.
    Random variables that appear in no factor are not assigned.

    Args:
        pgm: the PGM whose random variables and factors are considered.
        factor_size: gives the size of a factor.

    Returns:
        a map from `id(factor)` to the list of random variables assigned to that factor.
    """
    # Group the factors by the index of each random variable they mention.
    factors_by_rv: MapList[int, Factor] = MapList()
    for factor in pgm.factors:
        for rv in factor.rvs:
            factors_by_rv.append(rv.idx, factor)

    assignment: MapList[int, RandomVariable] = MapList()
    for rv_index, rv in enumerate(pgm.rvs):
        candidates: Sequence[Factor] = factors_by_rv.get(rv_index, ())
        if not candidates:
            continue
        smallest: Factor = min(candidates, key=factor_size)
        assignment.append(id(smallest), rv)
    return assignment
171
+
172
+
173
class _FunctionRows:
    """
    The rows of a potential function, as a map from table instance to circuit
    node, together with the number of factors currently sharing this object.
    """

    def __init__(self, rows: Dict[TableInstance, CircuitNode], use_count: int = 0):
        """
        Args:
            rows: map from table instance to the circuit node for that row.
            use_count: number of factors sharing this object (used by
                `_FactorRows.prune` to decide whether to copy before pruning).
        """
        self.use_count: int = use_count
        self.rows: Dict[TableInstance, CircuitNode] = rows
177
+
178
+
179
class _FactorRows:
    """
    The rows of one factor. The underlying `_FunctionRows` may be shared with
    other factors of the same potential function; `prune` copies the rows
    first when they are shared, so other factors are unaffected.
    """

    def __init__(self, factor: Factor, rows: _FunctionRows):
        self.rows: _FunctionRows = rows
        # rv_indexes[k] is the index of the random variable at position k of a table instance.
        self.rv_indexes: Tuple[int, ...] = tuple(rv.idx for rv in factor.rvs)

    def __len__(self) -> int:
        """The number of rows currently held for this factor."""
        return len(self.rows.rows)

    def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
        """The (instance, node) pairs of this factor's rows."""
        return self.rows.rows.items()

    def prune(self, extra_keys: Set[TableInstance]) -> None:
        """
        Remove the given keys from this factor's rows.

        If the rows are shared with other factors (use_count > 1), this factor
        detaches and takes a private pruned copy. Otherwise the rows are
        replaced in place.

        Args:
            extra_keys: table instances to remove; keys that are not present are ignored.
        """
        if not extra_keys:
            return
        kept: Dict[TableInstance, CircuitNode] = {
            key: node
            for key, node in self.rows.rows.items()
            if key not in extra_keys
        }
        is_shared: bool = self.rows.use_count > 1
        if is_shared:
            # Copy-on-write: stop sharing and own a fresh rows object.
            self.rows.use_count -= 1
            self.rows = _FunctionRows(kept, 1)
        else:
            self.rows.rows = kept
205
+
206
+
207
class _FactorPair:
    """
    A pair of factors (x, y) whose rows can be pruned against each other.

    A row of one factor that agrees with no row of the other on their common
    random variables cannot join with the other factor, so it can be removed.
    """

    def __init__(self, x: _FactorRows, y: _FactorRows):
        self.x: _FactorRows = x
        self.y: _FactorRows = y

        x_rv_set = set(x.rv_indexes)

        # Every random variable used by x or y.
        self.all_rv_indexes: Set[int] = x_rv_set.union(y.rv_indexes)

        # The random variables common to x and y. This one tuple drives both
        # position maps below, so the projections of x rows and y rows line up.
        self.co_rv_indexes: Tuple[int, ...] = tuple(x_rv_set.intersection(y.rv_indexes))

        # co_from_x_map[k] (resp. co_from_y_map[k]) is the position of
        # co_rv_indexes[k] within an instance of x (resp. y). These let the
        # prune loops pull out the needed values from the source instances.
        self.co_from_x_map = tuple(map(x.rv_indexes.index, self.co_rv_indexes))
        self.co_from_y_map = tuple(map(y.rv_indexes.index, self.co_rv_indexes))

    def prune(self) -> None:
        """
        Remove the rows of x and of y that cannot join with any row of the other.
        """
        x_rows = self.x.rows.rows
        y_rows = self.y.rows.rows

        # Project every instance onto the common random variables.
        x_projection = {
            key: tuple(key[pos] for pos in self.co_from_x_map)
            for key in x_rows
        }
        y_projection = {
            key: tuple(key[pos] for pos in self.co_from_y_map)
            for key in y_rows
        }
        x_co_values: Set[TableInstance] = set(x_projection.values())
        y_co_values: Set[TableInstance] = set(y_projection.values())

        # Rows with no joining partner on the other side.
        x_unjoinable: Set[TableInstance] = {
            key for key, co_value in x_projection.items() if co_value not in y_co_values
        }
        y_unjoinable: Set[TableInstance] = {
            key for key, co_value in y_projection.items() if co_value not in x_co_values
        }

        self.x.prune(x_unjoinable)
        self.y.prune(y_unjoinable)
262
+
263
+
264
def _pre_prune_factor_tables(factor_rows: Sequence[_FactorRows]) -> None:
    """
    Reduce the size of factor tables where possible.

    If two factors contain a common random variable then at some point their
    product will be formed, and any row of one factor that agrees with no row
    of the other (on the common random variables) cannot survive that product.
    This function identifies and removes such rows. Whenever a factor loses
    rows, previously checked pairs that share its random variables are checked
    again, until no more rows can be removed.

    Args:
        factor_rows: the rows of each factor; pruned in place.
    """
    # Only factors with at least one common random variable can prune each other.
    pairs_to_check: List[_FactorPair] = [
        _FactorPair(f1, f2)
        for f1, f2 in pairs(factor_rows)
        if not set(f1.rv_indexes).isdisjoint(f2.rv_indexes)
    ]

    # Pairs that were consistent as of their last check.
    pairs_done: List[_FactorPair] = []

    while len(pairs_to_check) > 0:
        pair = pairs_to_check.pop()
        x = pair.x
        y = pair.y

        x_size = len(x)
        y_size = len(y)
        pair.prune()

        # See if any pairs need re-checking: a done pair may have become
        # prunable if it involves a random variable of a factor that lost rows.
        rvs_affected: Set[int] = set()
        if x_size != len(x):
            rvs_affected.update(x.rv_indexes)
        if y_size != len(y):
            rvs_affected.update(y.rv_indexes)
        if len(rvs_affected) > 0:
            still_done: List[_FactorPair] = []
            # Use a distinct loop name so `pair` still refers to the current pair below.
            for done_pair in pairs_done:
                if rvs_affected.isdisjoint(done_pair.all_rv_indexes):
                    still_done.append(done_pair)
                else:
                    pairs_to_check.append(done_pair)
            pairs_done = still_done

        # Mark the current pair as done.
        pairs_done.append(pair)
306
+
307
+
308
def _make_factor_table(
        factor: Factor,
        circuit: Circuit,
        slot_map: Dict[SlotKey, int],
        rows: _FactorRows,
        factors_mul_rvs: MapList[int, RandomVariable],
) -> CircuitTable:
    """
    Build the circuit table for one factor.

    Each table row holds the factor's row node multiplied by the indicator
    circuit variables of any random variables assigned to this factor in
    `factors_mul_rvs`. The indicator chosen is the one for the state that the
    row's instance gives to that random variable.

    Args:
        factor: the PGM factor.
        circuit: the host circuit.
        slot_map: maps PGM indicators to circuit var indexes.
        rows: the (instance, node) rows of the factor.
        factors_mul_rvs: map from `id(factor)` to the random variables whose
            indicators are to be multiplied into that factor's table.

    Returns:
        a CircuitTable over the factor's random variables.
    """
    mul_rvs: Sequence[RandomVariable] = factors_mul_rvs.get(id(factor), ())
    rv_indexes: Sequence[int] = tuple(rv.idx for rv in factor.rvs)

    if len(mul_rvs) > 0:
        # positions[k]: where mul_rvs[k]'s state appears in a table instance.
        positions: Tuple[int, ...] = tuple(rv_indexes.index(rv.idx) for rv in mul_rvs)

        # indicator_vars[k][s]: circuit variable of the indicator for state s of mul_rvs[k].
        indicator_vars: Tuple[Tuple[CircuitNode, ...], ...] = tuple(
            tuple(circuit.vars[slot_map[indicator]] for indicator in rv.indicators)
            for rv in mul_rvs
        )

        def _multiplied_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
            for instance, node in rows.items():
                terms: List[CircuitNode] = [
                    vars_by_state[instance[position]]
                    for position, vars_by_state in zip(positions, indicator_vars)
                ]
                # A row node equal to one contributes nothing to the product.
                if not node.is_one():
                    terms.append(node)
                if len(terms) == 0:
                    product = circuit.one
                elif len(terms) == 1:
                    product = terms[0]
                else:
                    product = circuit.optimised_mul(tuple(terms))
                yield instance, product

        return CircuitTable(circuit, rv_indexes, _multiplied_rows())

    # Trivial case: no indicators to multiply in, so the rows are used as they are.
    return CircuitTable(circuit, rv_indexes, rows.items())
353
+
354
+
355
def _rows_for_function_const(
        function: PotentialFunction,
        circuit: Circuit,
) -> _FunctionRows:
    """
    Make the rows (instance, node) of a potential function, where each node is
    a circuit constant holding the row's value.

    Rows whose value is zero are omitted.

    Args:
        function: the potential function.
        circuit: the circuit in which to create the constants.

    Returns:
        a new _FunctionRows mapping each kept instance (as a tuple) to its constant node.
    """
    rows: Dict[TableInstance, CircuitNode] = {}
    # A ZeroPotentialFunction has no non-zero rows, so skip iterating its keys.
    if not isinstance(function, ZeroPotentialFunction):
        for instance, _param_index, value in function.keys_with_param:
            if value != 0:
                rows[tuple(instance)] = circuit.const(value)
    return _FunctionRows(rows)
373
+
374
+
375
def _rows_for_function_var(
        function: PotentialFunction,
        circuit: Circuit,
        slot_map: Dict[SlotKey, int],
) -> _FunctionRows:
    """
    Get the rows (instance, node) for the given potential function
    where each node is the circuit variable of the row's parameter.

    A circuit variable is created for a parameter id the first time it is
    seen, and recorded in `slot_map`. If several instances refer to the same
    parameter, they share that parameter's circuit variable.

    Args:
        function: the potential function.
        circuit: the circuit in which to create parameter variables.
        slot_map: map from slot key to circuit var index; updated in place with
            an entry for each parameter id referenced by the function's rows.

    Returns:
        a new _FunctionRows mapping each instance (as a tuple) to its parameter's circuit variable.
    """

    def _param_var(param_id: ParamId) -> VarNode:
        """
        Get the circuit variable for the given parameter id, creating it if needed.
        """
        slot: Optional[int] = slot_map.get(param_id)
        if slot is not None:
            # Parameter already seen (shared by an earlier instance): reuse its variable
            # rather than creating an orphan variable and overwriting the slot map.
            return circuit.vars[slot]
        node: VarNode = circuit.new_var()
        slot_map[param_id] = node.idx
        return node

    return _FunctionRows({
        tuple(instance): _param_var(function.param_id(param_index))
        for instance, param_index, _ in function.keys_with_param
    })
@@ -0,0 +1,275 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from itertools import chain
5
+ from typing import List, Set, Callable, Sequence, Tuple
6
+
7
+ import numpy as np
8
+
9
+ from ck.pgm import PGM, Factor
10
+ from ck.pgm_compiler.support.clusters import Clusters, min_degree, min_fill, \
11
+ min_degree_then_fill, min_fill_then_degree, min_weighted_degree, min_weighted_fill, min_traditional_weighted_fill, \
12
+ ClusterAlgorithm
13
+ from ck.utils.np_extras import NDArrayFloat64
14
+
15
+
16
@dataclass
class JoinTree:
    """
    A node in a join tree over the random variables of a PGM; the whole tree
    is represented by its root node.
    """

    # The PGM that this join tree is for.
    pgm: PGM

    # Indexes of the random variables in this node's cluster.
    cluster: Set[int]

    # This node's child nodes in the join tree.
    children: List[JoinTree]

    # Factors of the PGM allocated to this node.
    factors: List[Factor]

    # Indexes of random variables that are in both this cluster and the parent's
    # cluster (empty for the root of the tree).
    separator: Set[int]

    def max_cluster_size(self) -> int:
        """
        Returns:
            the largest `len(cluster)` found in this node or any of its descendants.
        """
        best: int = len(self.cluster)
        for child in self.children:
            candidate: int = child.max_cluster_size()
            if candidate > best:
                best = candidate
        return best

    def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
        """
        Args:
            rv_log_sizes: a weight per random variable, indexed by random variable
                index (presumably the log2 of its number of states).

        Returns:
            the largest cluster weight, `sum(rv_log_sizes[i] for i in cluster)`,
            found in this node or any of its descendants.
        """
        best: float = sum(rv_log_sizes[rv_idx] for rv_idx in self.cluster)
        for child in self.children:
            candidate: float = child.max_cluster_weighted_size(rv_log_sizes)
            if candidate > best:
                best = candidate
        return best

    def dump(self, *, prefix: str = '', indent: str = ' ', show_factors: bool = True) -> None:
        """
        Print this join tree, for debugging and demonstration purposes.

        Each node is printed as `{separator rvs} | {non-separator rvs} (factors: N)`.
        If requested, that line is followed by one line per factor. The node's
        children are printed after it, with `indent` added to the prefix.

        Args:
            prefix: text printed at the start of every line for this node.
            indent: additional prefix used for each deeper level of nodes.
            show_factors: if true, the factors of each node are printed.
        """
        separator_text: str = ' '.join(
            repr(str(self.pgm.rvs[rv_idx])) for rv_idx in sorted(self.separator)
        )
        other_text: str = ' '.join(
            repr(str(self.pgm.rvs[rv_idx]))
            for rv_idx in sorted(self.cluster)
            if rv_idx not in self.separator
        )
        if separator_text:
            separator_text = separator_text + ' '
        print(f'{prefix}{separator_text}| {other_text} (factors: {len(self.factors)})')
        if show_factors:
            for factor in self.factors:
                print(f'{prefix}factor{factor}')
        child_prefix: str = prefix + indent
        for child in self.children:
            child.dump(prefix=child_prefix, indent=indent, show_factors=show_factors)
77
+
78
+
79
# Type of a join tree algorithm: a function mapping a PGM to (the root of) a JoinTree.
JoinTreeAlgorithm = Callable[[PGM], JoinTree]


def _join_tree_algorithm(pgm_to_clusters: ClusterAlgorithm) -> JoinTreeAlgorithm:
    """
    Wrap a ClusterAlgorithm as a standard JoinTreeAlgorithm.

    The returned function clusters a PGM using `pgm_to_clusters`, then builds
    the join tree from those clusters with `clusters_to_join_tree`.

    Args:
        pgm_to_clusters: the clusters method to use.

    Returns:
        the corresponding JoinTreeAlgorithm.
    """

    def __join_tree_algorithm(pgm: PGM) -> JoinTree:
        return clusters_to_join_tree(pgm_to_clusters(pgm))

    return __join_tree_algorithm


# Standard JoinTreeAlgorithms: each applies `clusters_to_join_tree` to the
# clusters produced by the correspondingly named cluster algorithm.

MIN_DEGREE: JoinTreeAlgorithm = _join_tree_algorithm(min_degree)
MIN_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_fill)
MIN_DEGREE_THEN_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_degree_then_fill)
MIN_FILL_THEN_DEGREE: JoinTreeAlgorithm = _join_tree_algorithm(min_fill_then_degree)
MIN_WEIGHTED_DEGREE: JoinTreeAlgorithm = _join_tree_algorithm(min_weighted_degree)
MIN_WEIGHTED_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_weighted_fill)
MIN_TRADITIONAL_WEIGHTED_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_traditional_weighted_fill)
111
+
112
+
113
def clusters_to_join_tree(clusters: Clusters) -> JoinTree:
    """
    Construct a join tree for the given random variable clusters of a PGM.

    The join tree is a minimum spanning tree over the clusters (built by
    `_make_spanning_tree_small_root`). The cost between a pair of clusters is
    minus the size of their separator (the random variables they share), plus
    a tie-break term that is less than one. The tie-break term favours
    separators with a smaller total log state space.

    Each factor of the PGM is allocated to the smallest cluster that contains
    all of its random variables.

    Args:
        clusters: the clusters that resulted from graph clustering of a PGM.

    Returns:
        the root node of the join tree.

    Raises:
        ValueError: if the random variables of some factor are not all contained
            in a single cluster (so the factor cannot be allocated).
    """
    pgm: PGM = clusters.pgm
    cluster_sets: List[Set[int]] = clusters.clusters
    number_of_clusters = len(cluster_sets)

    # Dealing with these cases directly simplifies
    # the spanning tree algorithm implementation.
    if number_of_clusters == 0:
        return JoinTree(pgm, set(), [], [], set())
    elif number_of_clusters == 1:
        return JoinTree(pgm, cluster_sets[0], [], list(pgm.factors), set())

    # Break costs: separator state space size is used to break ties. Each rv's
    # break cost is scaled so the break costs of any separator sum to < 1, hence
    # can never outweigh a difference of one in separator size.
    rv_log_sizes = pgm.rv_log_sizes
    max_raw_break_cost = sum(rv_log_sizes) * 1.1
    if max_raw_break_cost != 0:
        break_cost = [log_size / max_raw_break_cost for log_size in rv_log_sizes]
    else:
        # All log sizes are zero: nothing to break ties with (and dividing would
        # raise ZeroDivisionError).
        break_cost = [0.0 for _ in rv_log_sizes]

    # Calculate inter-cluster costs for determining the minimum spanning tree.
    cost: NDArrayFloat64 = np.zeros((number_of_clusters, number_of_clusters), dtype=np.float64)
    for i in range(number_of_clusters):
        cluster_i = cluster_sets[i]
        for j in range(i + 1, number_of_clusters):
            separator = cluster_i.intersection(cluster_sets[j])
            cost[i, j] = cost[j, i] = -len(separator) + sum(break_cost[rv_idx] for rv_idx in separator)

    # Make the spanning tree over the clusters
    root_cluster_index: int
    children: List[List[int]]
    children, root_cluster_index = _make_spanning_tree_small_root(cost, cluster_sets)

    # Allocate each PGM factor to the smallest cluster containing it
    # (`sorted` is stable, so equal-sized clusters keep their original order).
    cluster_factors: List[List[Factor]] = [[] for _ in range(number_of_clusters)]
    ordered_indexed_clusters = sorted(enumerate(cluster_sets), key=lambda idx_c: len(idx_c[1]))
    for factor in pgm.factors:
        rv_indexes = frozenset(rv.idx for rv in factor.rvs)
        for cluster_index, cluster in ordered_indexed_clusters:
            if rv_indexes.issubset(cluster):
                cluster_factors[cluster_index].append(factor)
                break
        else:
            # Silently dropping the factor would change what the join tree represents.
            raise ValueError(f'factor {factor} is not contained in any cluster')

    return _form_join_tree_r(pgm, root_cluster_index, set(), children, cluster_sets, cluster_factors)
164
+
165
+
166
_INF = float('inf')


def _make_spanning_tree_small_root(cost: NDArrayFloat64, clusters: List[Set[int]]) -> Tuple[List[List[int]], int]:
    """
    Construct a minimum spanning tree over the clusters, where the root is the
    cluster with the smallest number of random variables (the lowest-index
    such cluster on ties).

    Args:
        cost: an n x n matrix where n is the number of clusters and cost[i, j]
            gives the cost between clusters i and j.
        clusters: the clusters, each a set of random variable indexes.

    Returns:
        (children, root_cluster_index) where `children[i]` lists the child
        clusters of cluster i in the spanning tree.
    """
    # Compare each candidate cluster's size against the best size so far;
    # `min` returns the first index of minimal size.
    root_cluster_index: int = min(range(len(clusters)), key=lambda i: len(clusters[i]))
    children: List[List[int]] = _make_spanning_tree_at_root(cost, root_cluster_index)
    return children, root_cluster_index
183
+
184
+
185
def _make_spanning_tree_arbitrary_root(cost: NDArrayFloat64) -> Tuple[List[List[int]], int]:
    """
    Construct a minimum spanning tree over the clusters, rooted at cluster 0.

    Args:
        cost: an n x n matrix of costs between clusters.

    Returns:
        (children, root_cluster_index), where root_cluster_index is always 0.
    """
    root: int = 0
    return _make_spanning_tree_at_root(cost, root), root
192
+
193
+
194
def _make_spanning_tree_at_root(
        cost: NDArrayFloat64,
        root_custer_index: int,
) -> List[List[int]]:
    """
    Construct a minimum spanning tree over the clusters, starting at the given
    root (Prim's algorithm).

    At each step, the cheapest edge from an included cluster to a remaining
    cluster is added. Ties go to the earliest included cluster, then to the
    earliest position in the list of remaining clusters.

    Args:
        cost: an n x n matrix where n is the number of clusters and cost[i, j]
            gives the cost between clusters i and j.
        root_custer_index: a nominated root cluster to be the root of the tree.

    Returns:
        a list `children` of length n where `children[i]` lists the clusters
        that are children of cluster i. With 0 or 1 clusters there are no edges.
    """
    number_of_clusters: int = cost.shape[0]

    # Data structure to collect the results.
    children: List[List[int]] = [[] for _ in range(number_of_clusters)]
    if number_of_clusters <= 1:
        # No edges to add; also avoids indexing into an empty `remaining` list.
        return children

    # clusters left to process.
    remaining: List[int] = list(range(number_of_clusters))

    # clusters that have been processed, in the order they were added.
    included: List[int] = []

    def remove_remaining(_remaining_index: int) -> None:
        # Remove the `remaining` element at the given index location by moving
        # the last element into its place (order is not preserved).
        remaining[_remaining_index] = remaining[-1]
        remaining.pop()

    # Move root from `remaining` to `included`.
    included.append(root_custer_index)
    remove_remaining(root_custer_index)  # valid because remaining[i] == i at this point

    while remaining:
        # Cheapest edge (i, j) with i included and j remaining. `min` keeps the
        # first minimal candidate in iteration order, i.e. a strict `<` scan.
        _, min_i, min_j_pos = min(
            (
                (cost.item(i, j), i, j_pos)
                for i in included
                for j_pos, j in enumerate(remaining)
            ),
            key=lambda candidate: candidate[0],
        )
        min_j: int = remaining[min_j_pos]

        # Record the child and move it from `remaining` to `included`.
        children[min_i].append(min_j)
        remove_remaining(min_j_pos)
        included.append(min_j)

    return children
249
+
250
+
251
def _form_join_tree_r(
        pgm: PGM,
        cluster_index: int,
        parent_cluster: Set[int],
        children: Sequence[List[int]],
        clusters: Sequence[Set[int]],
        cluster_factors: List[List[Factor]],
) -> JoinTree:
    """
    Recursively build the join tree node for the given cluster and its subtree.

    Args:
        pgm: the PGM the join tree is for.
        cluster_index: index of the cluster to become this node.
        parent_cluster: the parent node's cluster (an empty set for the root).
        children: `children[i]` lists the child cluster indexes of cluster i.
        clusters: `clusters[i]` is the set of random variable indexes of cluster i.
        cluster_factors: `cluster_factors[i]` lists the factors allocated to cluster i.

    Returns:
        the JoinTree node for `cluster_index`.
    """
    this_cluster: Set[int] = clusters[cluster_index]
    subtrees: List[JoinTree] = []
    for child_index in children[cluster_index]:
        subtrees.append(
            _form_join_tree_r(pgm, child_index, this_cluster, children, clusters, cluster_factors)
        )
    return JoinTree(
        pgm,
        this_cluster,
        subtrees,
        cluster_factors[cluster_index],
        parent_cluster.intersection(this_cluster),
    )
@@ -0,0 +1,33 @@
1
+ from types import ModuleType
2
+ from typing import Tuple
3
+
4
+ from ck.pgm import PGM
5
+ from ck.pgm_circuit import PGMCircuit
6
+ from ck.pgm_compiler import PGMCompiler
7
+
8
+
9
def get_compiler(module: ModuleType, **kwargs) -> Tuple[PGMCompiler]:
    """
    Helper function to create a named PGM compiler.

    Args:
        module: module containing a `compile_pgm` function.
        **kwargs: extra keyword arguments passed to `module.compile_pgm` on every call.

    Returns:
        a singleton tuple containing a PGMCompiler function.
    """

    def compiler(pgm: PGM, const_parameters: bool = True) -> PGMCircuit:
        """Conforms to the `PGMCompiler` protocol."""
        return module.compile_pgm(pgm, const_parameters=const_parameters, **kwargs)

    return (compiler,)


def get_compiler_algorithm(module: ModuleType, algorithm: str, **kwargs) -> Tuple[PGMCompiler]:
    """
    Helper function to create a named PGM compiler, with a named algorithm argument.

    Args:
        module: module containing a `compile_pgm` function and an attribute named `algorithm`.
        algorithm: name of the attribute of `module` to pass as the `algorithm`
            argument of `module.compile_pgm`.
        **kwargs: extra keyword arguments passed to `module.compile_pgm` on every call.

    Returns:
        a singleton tuple containing a PGMCompiler function.

    Raises:
        AttributeError: if `module` has no attribute named `algorithm`.
    """
    # `kwargs` are forwarded to `get_compiler`; `getattr` accepts no keyword arguments.
    return get_compiler(module, algorithm=getattr(module, algorithm), **kwargs)
32
+
33
+