compiled-knowledge 4.0.0a5__cp313-cp313-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (167)
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +13 -0
  3. ck/circuit/circuit.c +38749 -0
  4. ck/circuit/circuit.cpython-313-darwin.so +0 -0
  5. ck/circuit/circuit_py.py +807 -0
  6. ck/circuit/tmp_const.py +74 -0
  7. ck/circuit_compiler/__init__.py +2 -0
  8. ck/circuit_compiler/circuit_compiler.py +26 -0
  9. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  10. ck/circuit_compiler/cython_vm_compiler/_compiler.c +17373 -0
  11. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
  12. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +96 -0
  13. ck/circuit_compiler/interpret_compiler.py +223 -0
  14. ck/circuit_compiler/llvm_compiler.py +388 -0
  15. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  16. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  17. ck/circuit_compiler/support/__init__.py +0 -0
  18. ck/circuit_compiler/support/circuit_analyser.py +81 -0
  19. ck/circuit_compiler/support/input_vars.py +148 -0
  20. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  21. ck/example/__init__.py +53 -0
  22. ck/example/alarm.py +366 -0
  23. ck/example/asia.py +28 -0
  24. ck/example/binary_clique.py +32 -0
  25. ck/example/bow_tie.py +33 -0
  26. ck/example/cancer.py +37 -0
  27. ck/example/chain.py +38 -0
  28. ck/example/child.py +199 -0
  29. ck/example/clique.py +33 -0
  30. ck/example/cnf_pgm.py +39 -0
  31. ck/example/diamond_square.py +68 -0
  32. ck/example/earthquake.py +36 -0
  33. ck/example/empty.py +10 -0
  34. ck/example/hailfinder.py +539 -0
  35. ck/example/hepar2.py +628 -0
  36. ck/example/insurance.py +504 -0
  37. ck/example/loop.py +40 -0
  38. ck/example/mildew.py +38161 -0
  39. ck/example/munin.py +22982 -0
  40. ck/example/pathfinder.py +53674 -0
  41. ck/example/rain.py +39 -0
  42. ck/example/rectangle.py +161 -0
  43. ck/example/run.py +30 -0
  44. ck/example/sachs.py +129 -0
  45. ck/example/sprinkler.py +30 -0
  46. ck/example/star.py +44 -0
  47. ck/example/stress.py +64 -0
  48. ck/example/student.py +43 -0
  49. ck/example/survey.py +46 -0
  50. ck/example/triangle_square.py +54 -0
  51. ck/example/truss.py +49 -0
  52. ck/in_out/__init__.py +3 -0
  53. ck/in_out/parse_ace_lmap.py +216 -0
  54. ck/in_out/parse_ace_nnf.py +288 -0
  55. ck/in_out/parse_net.py +480 -0
  56. ck/in_out/parser_utils.py +185 -0
  57. ck/in_out/pgm_pickle.py +42 -0
  58. ck/in_out/pgm_python.py +268 -0
  59. ck/in_out/render_bugs.py +111 -0
  60. ck/in_out/render_net.py +177 -0
  61. ck/in_out/render_pomegranate.py +184 -0
  62. ck/pgm.py +3494 -0
  63. ck/pgm_circuit/__init__.py +1 -0
  64. ck/pgm_circuit/marginals_program.py +352 -0
  65. ck/pgm_circuit/mpe_program.py +237 -0
  66. ck/pgm_circuit/pgm_circuit.py +75 -0
  67. ck/pgm_circuit/program_with_slotmap.py +234 -0
  68. ck/pgm_circuit/slot_map.py +35 -0
  69. ck/pgm_circuit/support/__init__.py +0 -0
  70. ck/pgm_circuit/support/compile_circuit.py +83 -0
  71. ck/pgm_circuit/target_marginals_program.py +103 -0
  72. ck/pgm_circuit/wmc_program.py +323 -0
  73. ck/pgm_compiler/__init__.py +2 -0
  74. ck/pgm_compiler/ace/__init__.py +1 -0
  75. ck/pgm_compiler/ace/ace.py +252 -0
  76. ck/pgm_compiler/factor_elimination.py +383 -0
  77. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  78. ck/pgm_compiler/pgm_compiler.py +19 -0
  79. ck/pgm_compiler/recursive_conditioning.py +226 -0
  80. ck/pgm_compiler/support/__init__.py +0 -0
  81. ck/pgm_compiler/support/circuit_table/__init__.py +9 -0
  82. ck/pgm_compiler/support/circuit_table/circuit_table.c +16042 -0
  83. ck/pgm_compiler/support/circuit_table/circuit_table.cpython-313-darwin.so +0 -0
  84. ck/pgm_compiler/support/circuit_table/circuit_table_py.py +269 -0
  85. ck/pgm_compiler/support/clusters.py +556 -0
  86. ck/pgm_compiler/support/factor_tables.py +398 -0
  87. ck/pgm_compiler/support/join_tree.py +275 -0
  88. ck/pgm_compiler/support/named_compiler_maker.py +33 -0
  89. ck/pgm_compiler/variable_elimination.py +89 -0
  90. ck/probability/__init__.py +0 -0
  91. ck/probability/empirical_probability_space.py +47 -0
  92. ck/probability/probability_space.py +568 -0
  93. ck/program/__init__.py +3 -0
  94. ck/program/program.py +129 -0
  95. ck/program/program_buffer.py +180 -0
  96. ck/program/raw_program.py +61 -0
  97. ck/sampling/__init__.py +0 -0
  98. ck/sampling/forward_sampler.py +211 -0
  99. ck/sampling/marginals_direct_sampler.py +113 -0
  100. ck/sampling/sampler.py +62 -0
  101. ck/sampling/sampler_support.py +232 -0
  102. ck/sampling/uniform_sampler.py +66 -0
  103. ck/sampling/wmc_direct_sampler.py +169 -0
  104. ck/sampling/wmc_gibbs_sampler.py +147 -0
  105. ck/sampling/wmc_metropolis_sampler.py +159 -0
  106. ck/sampling/wmc_rejection_sampler.py +113 -0
  107. ck/utils/__init__.py +0 -0
  108. ck/utils/iter_extras.py +153 -0
  109. ck/utils/map_list.py +128 -0
  110. ck/utils/map_set.py +128 -0
  111. ck/utils/np_extras.py +51 -0
  112. ck/utils/random_extras.py +64 -0
  113. ck/utils/tmp_dir.py +94 -0
  114. ck_demos/__init__.py +0 -0
  115. ck_demos/ace/__init__.py +0 -0
  116. ck_demos/ace/copy_ace_to_ck.py +15 -0
  117. ck_demos/ace/demo_ace.py +44 -0
  118. ck_demos/all_demos.py +88 -0
  119. ck_demos/circuit/__init__.py +0 -0
  120. ck_demos/circuit/demo_circuit_dump.py +22 -0
  121. ck_demos/circuit/demo_derivatives.py +43 -0
  122. ck_demos/circuit_compiler/__init__.py +0 -0
  123. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  124. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  125. ck_demos/pgm/__init__.py +0 -0
  126. ck_demos/pgm/demo_pgm_dump.py +18 -0
  127. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  128. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  129. ck_demos/pgm/show_examples.py +25 -0
  130. ck_demos/pgm_compiler/__init__.py +0 -0
  131. ck_demos/pgm_compiler/compare_pgm_compilers.py +50 -0
  132. ck_demos/pgm_compiler/demo_compiler_dump.py +50 -0
  133. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  134. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  135. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  136. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  137. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  138. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  139. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  140. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  141. ck_demos/pgm_inference/__init__.py +0 -0
  142. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  143. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  144. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  145. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  146. ck_demos/programs/__init__.py +0 -0
  147. ck_demos/programs/demo_program_buffer.py +24 -0
  148. ck_demos/programs/demo_program_multi.py +24 -0
  149. ck_demos/programs/demo_program_none.py +19 -0
  150. ck_demos/programs/demo_program_single.py +23 -0
  151. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  152. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  153. ck_demos/sampling/__init__.py +0 -0
  154. ck_demos/sampling/check_sampler.py +71 -0
  155. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  156. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  157. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  158. ck_demos/utils/__init__.py +0 -0
  159. ck_demos/utils/compare.py +88 -0
  160. ck_demos/utils/convert_network.py +45 -0
  161. ck_demos/utils/sample_model.py +216 -0
  162. ck_demos/utils/stop_watch.py +384 -0
  163. compiled_knowledge-4.0.0a5.dist-info/METADATA +50 -0
  164. compiled_knowledge-4.0.0a5.dist-info/RECORD +167 -0
  165. compiled_knowledge-4.0.0a5.dist-info/WHEEL +5 -0
  166. compiled_knowledge-4.0.0a5.dist-info/licenses/LICENSE.txt +21 -0
  167. compiled_knowledge-4.0.0a5.dist-info/top_level.txt +2 -0
@@ -0,0 +1,383 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from itertools import islice
5
+ from typing import Iterator, Optional, FrozenSet
6
+
7
+ from ck.circuit import CircuitNode
8
+ from ck.pgm_circuit import PGMCircuit
9
+ from ck.pgm_compiler.support.circuit_table import CircuitTable, product, sum_out
10
+ from ck.pgm_compiler.support.factor_tables import make_factor_tables, FactorTables
11
+ from ck.pgm_compiler.support.join_tree import *
12
+
13
# Minus infinity: the starting "best score" when searching for the highest-scoring product tree.
_NEG_INF = -float('inf')
14
+
15
+
16
def compile_pgm(
        pgm: PGM,
        const_parameters: bool = True,
        *,
        algorithm: JoinTreeAlgorithm = MIN_FILL_THEN_DEGREE,
        limit_product_tree_search: int = 1000,
        pre_prune_factor_tables: bool = True,
) -> PGMCircuit:
    """
    Compile the PGM to an arithmetic circuit, using factor elimination.

    When forming the product of factors within a join tree node,
    this method searches all practical binary trees for forming products,
    up to the given limit, `limit_product_tree_search`. The minimum is 1.

    Conforms to the `PGMCompiler` protocol.

    Args:
        pgm: The PGM to compile.
        const_parameters: If true, the potential function parameters will be circuit
            constants, otherwise they will be circuit variables.
        algorithm: algorithm to get a join tree.
        limit_product_tree_search: limit on number of product trees to consider.
        pre_prune_factor_tables: if true, then heuristics will be used to remove any provably zero row.

    Returns:
        a PGMCircuit object.

    Raises:
        ValueError if `limit_product_tree_search` is not > 0.
    """
    # Fail fast: reject a bad limit before running the join tree algorithm.
    if limit_product_tree_search <= 0:
        raise ValueError('limit_product_tree_search must be > 0')

    join_tree: JoinTree = algorithm(pgm)
    return join_tree_to_circuit(
        join_tree,
        const_parameters,
        limit_product_tree_search,
        pre_prune_factor_tables,
    )
54
+
55
+
56
def compile_pgm_best_jointree(
        pgm: PGM,
        const_parameters: bool = True,
        *,
        limit_product_tree_search: int = 1000,
        pre_prune_factor_tables: bool = True,
) -> PGMCircuit:
    """
    Try multiple elimination heuristics, and use the join tree that has
    the smallest maximum cluster size (weighted by the rv log sizes).

    Conforms to the `PGMCompiler` protocol.

    Args:
        pgm: The PGM to compile.
        const_parameters: If true, the potential function parameters will be circuit
            constants, otherwise they will be circuit variables.
        limit_product_tree_search: limit on number of product trees to consider.
        pre_prune_factor_tables: if true, then heuristics will be used to remove any provably zero row.

    Returns:
        a PGMCircuit object.

    Raises:
        ValueError if `limit_product_tree_search` is not > 0.
    """
    # Fail fast: reject a bad limit before running every elimination heuristic.
    if limit_product_tree_search <= 0:
        raise ValueError('limit_product_tree_search must be > 0')

    # Candidate heuristics. On equal size, the earliest in this list is kept.
    algorithms: Sequence[ClusterAlgorithm] = [
        min_degree,
        min_fill,
        min_degree_then_fill,
        min_fill_then_degree,
        min_weighted_degree,
        min_weighted_fill,
        min_traditional_weighted_fill,
    ]
    rv_log_sizes: Sequence[float] = pgm.rv_log_sizes
    best_clusters: Clusters = algorithms[0](pgm)
    best_size = best_clusters.max_cluster_weighted_size(rv_log_sizes)
    for algorithm in algorithms[1:]:
        candidate: Clusters = algorithm(pgm)
        size = candidate.max_cluster_weighted_size(rv_log_sizes)
        if size < best_size:
            best_size = size
            best_clusters = candidate

    join_tree: JoinTree = clusters_to_join_tree(best_clusters)
    return join_tree_to_circuit(
        join_tree,
        const_parameters,
        limit_product_tree_search,
        pre_prune_factor_tables,
    )
109
+
110
+
111
def join_tree_to_circuit(
        join_tree: JoinTree,
        const_parameters: bool = True,
        limit_product_tree_search: int = 1000,
        pre_prune_factor_tables: bool = True,
) -> PGMCircuit:
    """
    Build a PGMCircuit (arithmetic circuit plus slot map) from a join tree.

    Args:
        join_tree: a join tree for a PGM.
        const_parameters: if true, potential function parameters become circuit
            constants; otherwise they become circuit variables.
        limit_product_tree_search: maximum number of product trees examined per join tree node.
        pre_prune_factor_tables: if true, heuristics are used to drop any provably zero row.

    Returns:
        a PGMCircuit object.

    Raises:
        ValueError if `limit_product_tree_search` is not > 0.
    """
    if limit_product_tree_search <= 0:
        raise ValueError('limit_product_tree_search must be > 0')

    pgm: PGM = join_tree.pgm
    tables: FactorTables = make_factor_tables(
        pgm=pgm,
        const_parameters=const_parameters,
        multiply_indicators=True,
        pre_prune_factor_tables=pre_prune_factor_tables,
    )

    root_table: CircuitTable = _circuit_tables_from_join_tree(tables, join_tree, limit_product_tree_search)
    circuit_top: CircuitNode = root_table.top()
    # Prune operation nodes that (judging by the method name) are unreachable from the top node.
    root_table.circuit.remove_unreachable_op_nodes(circuit_top)

    return PGMCircuit(
        rvs=tuple(pgm.rvs),
        conditions=(),
        circuit_top=circuit_top,
        number_of_indicators=tables.number_of_indicators,
        number_of_parameters=tables.number_of_parameters,
        slot_map=tables.slot_map,
        parameter_values=tables.parameter_values,
    )
161
+
162
+
163
def _circuit_tables_from_join_tree(
        factor_tables: FactorTables,
        join_tree: JoinTree,
        limit_product_tree_search: int,
) -> CircuitTable:
    """
    Recursively build the circuit table for a join tree node.

    The tables of the PGM factors allocated to the node are combined with the
    tables produced by the node's children. Among the first
    `limit_product_tree_search` binary product trees, the highest scoring one
    forms the product and the sum-outs.

    Args:
        factor_tables: source of the circuit table for each PGM factor.
        join_tree: the join tree node to process.
        limit_product_tree_search: maximum number of product trees to score.

    Returns:
        the CircuitTable for this join tree node.
    """
    # Tables of the PGM factors allocated to this node, then one table per child node.
    tables: List[CircuitTable] = [factor_tables.get_table(factor) for factor in join_tree.factors]
    for child in join_tree.children:
        tables.append(_circuit_tables_from_join_tree(factor_tables, child, limit_product_tree_search))

    # Special case: nothing to multiply.
    if not tables:
        circuit = factor_tables.circuit
        if len(join_tree.separator) == 0:
            # The table "one": a single empty row holding the constant one.
            return CircuitTable(circuit, (), (((), circuit.one),))
        # The table "zero": over the separator rvs, with no rows.
        return CircuitTable(circuit, tuple(join_tree.separator), ())

    # The order of products does not change the tree width, but it does change the
    # number of arithmetic operations. Enumerating product trees is potentially
    # O(len(tables)!), so only the first `limit_product_tree_search` are scored.
    # Trees that sum out rvs early score more highly.
    rv_log_sizes: Sequence[float] = join_tree.pgm.rv_log_sizes
    candidates = islice(_iterate_trees(tables, join_tree.separator), limit_product_tree_search)
    chosen = None
    chosen_score = _NEG_INF
    for candidate in candidates:
        candidate_score = candidate.score(rv_log_sizes)
        if candidate_score > chosen_score:
            chosen, chosen_score = candidate, candidate_score

    # The chosen tree forms the products and performs the sum-outs.
    return chosen.get_table()
216
+
217
+
218
class _Product(ABC):
    """
    Abstract node of a binary product tree.

    Concrete nodes are either `_ProductLeaf`, wrapping one CircuitTable,
    or `_ProductInterior`, which has exactly two child nodes.
    """

    def __init__(self, available: Set[int]):
        """
        Args:
            available: indices of the rvs present once this node's product
                is formed (before any sum-out is applied).
        """
        self.available: Set[int] = available
        # rvs to be summed out of this node's product; assigned by `set_sum_out`.
        self.sum_out: Set[int] = set()

    @abstractmethod
    def set_sum_out(self, need: Set[int]) -> None:
        """
        Decide `self.sum_out`, given which rvs must be kept.

        Args:
            need: the rvs this node must still supply after its product
                is formed; expected to be a subset of `self.available`.
        """

    @abstractmethod
    def score(self, rv_log_sizes: Sequence[float]) -> float:
        """
        Heuristic score for the tree rooted at this node (higher is preferred).
        Only meaningful once `set_sum_out` has been called.
        """

    @abstractmethod
    def get_table(self) -> CircuitTable:
        """
        Returns:
            the CircuitTable implied by this node, after its products and sum-outs.
        """
264
+
265
+
266
class _ProductLeaf(_Product):
    """
    A leaf of a binary product tree, wrapping a single CircuitTable.

    Intentionally not decorated with `@dataclass`: the class declares no
    dataclass fields and defines its own `__init__`, so the decorator would
    only add a field-less `__eq__` (making every leaf equal to every other
    leaf) and set `__hash__` to None.
    """

    def __init__(self, table: CircuitTable):
        super().__init__(set(table.rv_idxs))
        self.table: CircuitTable = table

    def set_sum_out(self, need: Set[int]) -> None:
        # Every rv of the table that is not needed above this leaf is summed out here.
        self.sum_out = self.available.difference(need)

    def score(self, rv_log_sizes: Sequence[float]) -> float:
        # Total log size of the rvs summed out at this leaf.
        return sum(rv_log_sizes[i] for i in self.sum_out)

    def get_table(self) -> CircuitTable:
        return sum_out(self.table, self.sum_out)
281
+
282
+
283
class _ProductInterior(_Product):
    """
    An interior node of a binary product tree: the product of two sub-trees.

    Intentionally not decorated with `@dataclass`: the class declares no
    dataclass fields and defines its own `__init__`, so the decorator would
    only add a field-less `__eq__` (making every node equal to every other
    node) and set `__hash__` to None.
    """

    def __init__(self, x: _Product, y: _Product):
        super().__init__(x.available.union(y.available))
        self.x: _Product = x
        self.y: _Product = y

    def set_sum_out(self, need: Set[int]) -> None:
        x = self.x
        y = self.y
        # An rv shared by both children must survive in each child until their
        # product is formed, in addition to whatever is needed above this node.
        x_y_common: Set[int] = x.available.intersection(y.available)
        keep: Set[int] = x_y_common.union(need)
        x_need: Set[int] = x.available.intersection(keep)
        y_need: Set[int] = y.available.intersection(keep)
        x.set_sum_out(x_need)
        y.set_sum_out(y_need)
        # rvs that reach this product but are not needed above it are summed out here.
        self.sum_out = x_need.union(y_need).difference(need)

    def score(self, rv_log_sizes: Sequence[float]) -> float:
        x_score = self.x.score(rv_log_sizes)
        y_score = self.y.score(rv_log_sizes)
        # Sum-outs lower in the tree are weighted double at each level, favouring
        # trees that sum out rvs early.
        return sum(rv_log_sizes[i] for i in self.sum_out) + (x_score + y_score) * 2

    def get_table(self) -> CircuitTable:
        return sum_out(product(self.x.get_table(), self.y.get_table()), self.sum_out)
308
+
309
+
310
def _iterate_trees(factors: List[CircuitTable], separator: Set[int]) -> Iterator[_Product]:
    """
    Generate binary product trees, each forming the product of all `factors`.

    Each tree is yielded only after `set_sum_out(separator)` has been applied
    to it, so its sum-outs reflect projecting the product onto `separator`.

    Args:
        factors: the tables to multiply (at least one is assumed).
        separator: the rvs the resulting product is projected onto.

    Returns:
        an iterator over fully formed binary product trees.
    """
    for tree in _iterate_trees_r(list(map(_ProductLeaf, factors))):
        tree.set_sum_out(separator)
        yield tree
328
+
329
+
330
def _iterate_trees_r(factors: List[_Product]) -> Iterator[_Product]:
    """
    Recursive support function for _iterate_trees.

    This forms the products, but does not set the `sum_out` fields, as that
    can only be done once a tree is fully formed.

    Args:
        factors: The list of factors to be in the product.

    Returns:
        An iterator over binary product trees.

    Assumes:
        There is at least one factor.
    """
    # Heuristic to reduce the number of arithmetic operations: if the rvs of one
    # factor are a subset of another factor's rvs, multiply them straight away,
    # preferring to absorb factors with few rvs.

    # Sort factors by number of rvs (increasing). The key must be the set size:
    # frozenset `<` is the (partial) proper-subset order, which does not yield
    # a size ordering.
    sorted_factors: List[Tuple[FrozenSet[int], Optional[_Product]]] = sorted(
        ((frozenset(factor.available), factor) for factor in factors),
        key=lambda entry: len(entry[0]),
    )

    # Absorb each factor into the first later factor whose rvs are a superset of its rvs.
    for i, (rv_idxs, factor) in enumerate(sorted_factors):
        for j in range(i + 1, len(sorted_factors)):
            other_rv_idxs, other_factor = sorted_factors[j]
            if other_rv_idxs.issuperset(rv_idxs):
                sorted_factors[j] = (other_rv_idxs, _ProductInterior(other_factor, factor))
                sorted_factors[i] = (rv_idxs, None)
                break
    factors = [factor for _, factor in sorted_factors if factor is not None]

    if len(factors) == 1:
        yield factors[0]
    elif len(factors) == 2:
        yield _ProductInterior(*factors)
    else:
        # Try every pair of factors as the next product, then recurse on the rest.
        for i in range(len(factors)):
            for j in range(i):
                remaining: List[_Product] = factors.copy()
                x = remaining.pop(i)
                y = remaining.pop(j)  # j < i, so popping i first leaves index j intact
                remaining.append(_ProductInterior(x, y))
                yield from _iterate_trees_r(remaining)
@@ -0,0 +1,63 @@
1
+ from enum import Enum
2
+
3
+ from ck.pgm import PGM
4
+ from ck.pgm_circuit import PGMCircuit
5
+ from ck.pgm_compiler import variable_elimination, factor_elimination, recursive_conditioning, ace
6
+ from .pgm_compiler import PGMCompiler
7
+ from .support.named_compiler_maker import get_compiler_algorithm as _get_compiler_algorithm, \
8
+ get_compiler as _get_compiler
9
+
10
+
11
class NamedPGMCompiler(Enum):
    """
    Standard, named PGM compilers.

    Each member's `value` is a tuple whose first element is a compiler function
    (PGM -> PGMCircuit). The tuple wrapper keeps Python from erasing the type of
    the member, which can otherwise cause problems. Members are themselves
    callable, conforming to the PGMCompiler protocol by delegating to the
    wrapped compiler function.
    """
    # @formatter:off

    VE_MIN_DEGREE: PGMCompiler = _get_compiler_algorithm(variable_elimination, 'MIN_DEGREE')
    VE_MIN_DEGREE_THEN_FILL: PGMCompiler = _get_compiler_algorithm(variable_elimination, 'MIN_DEGREE_THEN_FILL')
    VE_MIN_FILL: PGMCompiler = _get_compiler_algorithm(variable_elimination, 'MIN_FILL')
    VE_MIN_FILL_THEN_DEGREE: PGMCompiler = _get_compiler_algorithm(variable_elimination, 'MIN_FILL_THEN_DEGREE')
    VE_MIN_WEIGHTED_DEGREE: PGMCompiler = _get_compiler_algorithm(variable_elimination, 'MIN_WEIGHTED_DEGREE')
    VE_MIN_WEIGHTED_FILL: PGMCompiler = _get_compiler_algorithm(variable_elimination, 'MIN_WEIGHTED_FILL')
    VE_MIN_TRADITIONAL_WEIGHTED_FILL: PGMCompiler = _get_compiler_algorithm(variable_elimination, 'MIN_TRADITIONAL_WEIGHTED_FILL')

    FE_MIN_DEGREE: PGMCompiler = _get_compiler_algorithm(factor_elimination, 'MIN_DEGREE')
    FE_MIN_DEGREE_THEN_FILL: PGMCompiler = _get_compiler_algorithm(factor_elimination, 'MIN_DEGREE_THEN_FILL')
    FE_MIN_FILL: PGMCompiler = _get_compiler_algorithm(factor_elimination, 'MIN_FILL')
    FE_MIN_FILL_THEN_DEGREE: PGMCompiler = _get_compiler_algorithm(factor_elimination, 'MIN_FILL_THEN_DEGREE')
    FE_MIN_WEIGHTED_DEGREE: PGMCompiler = _get_compiler_algorithm(factor_elimination, 'MIN_WEIGHTED_DEGREE')
    FE_MIN_WEIGHTED_FILL: PGMCompiler = _get_compiler_algorithm(factor_elimination, 'MIN_WEIGHTED_FILL')
    FE_MIN_TRADITIONAL_WEIGHTED_FILL: PGMCompiler = _get_compiler_algorithm(factor_elimination, 'MIN_TRADITIONAL_WEIGHTED_FILL')
    FE_BEST_JOINTREE: PGMCompiler = factor_elimination.compile_pgm_best_jointree,

    RC_MIN_DEGREE: PGMCompiler = _get_compiler_algorithm(recursive_conditioning, 'MIN_DEGREE')
    RC_MIN_DEGREE_THEN_FILL: PGMCompiler = _get_compiler_algorithm(recursive_conditioning, 'MIN_DEGREE_THEN_FILL')
    RC_MIN_FILL: PGMCompiler = _get_compiler_algorithm(recursive_conditioning, 'MIN_FILL')
    RC_MIN_FILL_THEN_DEGREE: PGMCompiler = _get_compiler_algorithm(recursive_conditioning, 'MIN_FILL_THEN_DEGREE')
    RC_MIN_WEIGHTED_DEGREE: PGMCompiler = _get_compiler_algorithm(recursive_conditioning, 'MIN_WEIGHTED_DEGREE')
    RC_MIN_WEIGHTED_FILL: PGMCompiler = _get_compiler_algorithm(recursive_conditioning, 'MIN_WEIGHTED_FILL')
    RC_MIN_TRADITIONAL_WEIGHTED_FILL: PGMCompiler = _get_compiler_algorithm(recursive_conditioning, 'MIN_TRADITIONAL_WEIGHTED_FILL')

    ACE: PGMCompiler = _get_compiler(ace)

    # @formatter:on

    def __call__(self, pgm: PGM, const_parameters: bool = True) -> PGMCircuit:
        """
        Compile `pgm` with this member's compiler function.

        Implements the `PGMCompiler` protocol.
        """
        compile_fn: PGMCompiler = self.compiler
        return compile_fn(pgm, const_parameters=const_parameters)

    @property
    def compiler(self) -> PGMCompiler:
        """The compiler function wrapped by this member (first element of its value tuple)."""
        wrapped = self.value
        return wrapped[0]
61
+
62
+
63
# Default PGM compiler: factor elimination using the best join tree among several
# elimination heuristics (presumably used by callers that do not pick a compiler).
DEFAULT_PGM_COMPILER: NamedPGMCompiler = NamedPGMCompiler['FE_BEST_JOINTREE']
@@ -0,0 +1,19 @@
1
+ from typing import Protocol
2
+
3
+ from ck.pgm import PGM
4
+ from ck.pgm_circuit import PGMCircuit
5
+
6
+
7
class PGMCompiler(Protocol):
    """
    Structural type for PGM compilers: any callable matching the `__call__`
    signature below.
    """

    def __call__(self, pgm: PGM, *, const_parameters: bool = True) -> PGMCircuit:
        """
        Compile a PGM into an arithmetic circuit.

        Args:
            pgm: the PGM to compile.
            const_parameters: if true, potential function parameters become circuit
                constants; otherwise they become circuit variables.

        Returns:
            a PGMCircuit providing an arithmetic circuit that represents the PGM.
        """
@@ -0,0 +1,226 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass
5
+ from typing import Iterable, Dict, Optional, List, Sequence, Tuple, Set
6
+
7
+ from ck.circuit import Circuit, CircuitNode
8
+ from ck.pgm import PGM
9
+ from ck.pgm_circuit import PGMCircuit
10
+ from ck.pgm_compiler.support import clusters
11
+ from ck.pgm_compiler.support.circuit_table import CircuitTable
12
+ from ck.pgm_compiler.support.clusters import ClusterAlgorithm
13
+ from ck.pgm_compiler.support.factor_tables import make_factor_tables, FactorTables
14
+ from ck.utils.iter_extras import combos
15
+
16
# Standard cluster algorithms, re-exported from `clusters`; any of them may be
# passed as the `algorithm` argument of `compile_pgm` to get an elimination order.

# Degree based.
MIN_DEGREE: ClusterAlgorithm = clusters.min_degree
MIN_DEGREE_THEN_FILL: ClusterAlgorithm = clusters.min_degree_then_fill
MIN_WEIGHTED_DEGREE: ClusterAlgorithm = clusters.min_weighted_degree

# Fill based.
MIN_FILL: ClusterAlgorithm = clusters.min_fill
MIN_FILL_THEN_DEGREE: ClusterAlgorithm = clusters.min_fill_then_degree
MIN_WEIGHTED_FILL: ClusterAlgorithm = clusters.min_weighted_fill
MIN_TRADITIONAL_WEIGHTED_FILL: ClusterAlgorithm = clusters.min_traditional_weighted_fill
24
+
25
+
26
def compile_pgm(
        pgm: PGM,
        const_parameters: bool = True,
        *,
        algorithm: ClusterAlgorithm = MIN_FILL_THEN_DEGREE,
        pre_prune_factor_tables: bool = True,
) -> PGMCircuit:
    """
    Compile the PGM to an arithmetic circuit by recursive conditioning.

    Conforms to the `PGMCompiler` protocol.

    Args:
        pgm: the PGM to compile.
        const_parameters: if true, potential function parameters become circuit
            constants; otherwise they become circuit variables.
        algorithm: cluster algorithm used to obtain an elimination order.
        pre_prune_factor_tables: if true, heuristics are used to drop any provably zero row.

    Returns:
        a PGMCircuit object.
    """
    # The elimination order shapes the decomposition tree.
    order: Sequence[int] = algorithm(pgm).eliminated
    tables: FactorTables = make_factor_tables(
        pgm=pgm,
        const_parameters=const_parameters,
        multiply_indicators=True,
        pre_prune_factor_tables=pre_prune_factor_tables,
    )
    dtree: _DTree = _make_dtree(order, tables)

    # Start unconditioned: every rv may take any of its states.
    all_states: List[Sequence[int]] = [tuple(range(len(rv))) for rv in pgm.rvs]
    top: CircuitNode = dtree.make_circuit(all_states, tables.circuit)
    top.circuit.remove_unreachable_op_nodes(top)

    return PGMCircuit(
        rvs=tuple(pgm.rvs),
        conditions=(),
        circuit_top=top,
        number_of_indicators=tables.number_of_indicators,
        number_of_parameters=tables.number_of_parameters,
        slot_map=tables.slot_map,
        parameter_values=tables.parameter_values,
    )
70
+
71
+
72
def _make_dtree(elimination_order: Sequence[int], factor_tables: FactorTables) -> _DTree:
    """
    Build a binary decomposition tree (d-tree) over the factor tables.

    Following the elimination order, all current trees mentioning the eliminated
    rv are joined into one, pairing the two shallowest trees first. Any trees
    that remain separate are then joined into a single root. Finally, cutsets
    and contexts are computed from the root.

    Args:
        elimination_order: rv indices, in elimination order.
        factor_tables: the factor tables that become the d-tree leaves.

    Returns:
        the root of the d-tree, with `update_cutset` already applied.
    """
    if len(factor_tables.tables) == 0:
        # No factors: the result is the empty product, i.e., the table "one"
        # (a single empty row holding the constant one), not a row-less zero table.
        circuit = factor_tables.circuit
        return _DTreeLeaf(CircuitTable(circuit, (), (((), circuit.one),)))

    # Populate `trees` with all the leaves.
    trees: List[_DTree] = [_DTreeLeaf(table) for table in factor_tables.tables]

    # Join trees by eliminated random variable.
    for rv_index in elimination_order:
        next_trees: List[_DTree] = []
        to_join: List[_DTree] = []
        for tree in trees:
            if rv_index in tree.vars:
                to_join.append(tree)
            else:
                next_trees.append(tree)
        if to_join:
            # Reduce to a single tree, always joining the two shallowest.
            while len(to_join) > 1:
                to_join.sort(key=lambda t: -t.depth())
                x = to_join.pop()
                y = to_join.pop()
                to_join.append(_DTreeInterior(x, y))
            # Keep the (possibly joined) tree, including when only one tree mentioned the rv.
            next_trees.append(to_join[0])
        trees = next_trees

    # Make sure there is only one tree.
    while len(trees) > 1:
        x = trees.pop(0)
        y = trees.pop(0)
        trees.append(_DTreeInterior(x, y))

    root = trees[0]
    root.update_cutset()
    return root
107
+
108
+
109
class _DTree(ABC):
    """
    Abstract node of a binary decomposition tree (d-tree).

    Concrete nodes are either `_DTreeLeaf`, holding one CircuitTable,
    or `_DTreeInterior`, which has exactly two children.
    """

    def __init__(self, vars_idxs: Set[int]):
        # Indices of all rvs mentioned within this sub-tree.
        self.vars: Set[int] = vars_idxs
        # rvs conditioned on at this node; assigned by `update_cutset`.
        self.cutset: Sequence[int] = ()
        # rvs of this sub-tree already in an ancestor's cutset; assigned by `update_cutset`.
        self.context: Sequence[int] = ()

    @abstractmethod
    def update_cutset(self, acutset: Iterable[int] = ()) -> None:
        """
        Once the d-tree is built, call this on the root so that every
        node's fields are properly set.
        """

    @abstractmethod
    def make_circuit(self, states: List[Sequence[int]], circuit: Circuit) -> CircuitNode:
        """
        Once the d-tree is built and its cutsets updated, construct
        a circuit using recursive conditioning.
        """

    @abstractmethod
    def depth(self) -> int:
        """
        Depth of this sub-tree.
        """
143
+
144
+
145
class _DTreeLeaf(_DTree):
    """
    A d-tree leaf, holding a single CircuitTable.

    Intentionally not decorated with `@dataclass`: the class declares no
    dataclass fields and defines its own `__init__`, so the decorator would
    only add a field-less `__eq__` (making every leaf equal to every other
    leaf) and set `__hash__` to None.
    """

    def __init__(self, table: CircuitTable):
        super().__init__(set(table.rv_idxs))
        self.table: CircuitTable = table

    def update_cutset(self, acutset: Iterable[int] = ()) -> None:
        # A leaf conditions on nothing itself; its cutset and context stay empty.
        pass

    def make_circuit(self, states: List[Sequence[int]], circuit: Circuit) -> CircuitNode:
        """
        Sum the table entries for every combination of the allowed states
        (given by `states`) of the table's rvs.
        """
        table = self.table
        key_states: List[Sequence[int]] = [states[rv_idx] for rv_idx in table.rv_idxs]
        # `table.get` returns None for keys without a row; those are skipped.
        to_sum: List[CircuitNode] = [
            node
            for node in (table.get(key) for key in combos(key_states))
            if node is not None
        ]
        return circuit.optimised_add(to_sum)

    def depth(self) -> int:
        return 1
172
+
173
+
174
class _DTreeInterior(_DTree):
    """
    A d-tree interior node, joining exactly two sub-trees.

    Intentionally not decorated with `@dataclass`: the class declares no
    dataclass fields and defines its own `__init__`, so the decorator would
    only add a field-less `__eq__` (making every node equal to every other
    node) and set `__hash__` to None.
    """

    def __init__(self, x: _DTree, y: _DTree):
        super().__init__(x.vars.union(y.vars))
        self.x: _DTree = x
        self.y: _DTree = y
        # Memoised `make_circuit` results, keyed by the states of the context rvs.
        self.cache: Dict[Tuple[int, ...], CircuitNode] = {}

    def update_cutset(self, acutset: Iterable[int] = ()) -> None:
        # Materialise once, as `acutset` is consulted more than once below.
        acutset = set(acutset)
        cutset: Set[int] = self.x.vars.intersection(self.y.vars).difference(acutset)
        self.cutset = tuple(cutset)
        self.context = tuple(self.vars.intersection(acutset))

        next_acutset = cutset.union(acutset)
        self.x.update_cutset(next_acutset)
        self.y.update_cutset(next_acutset)

    def make_circuit(self, states: List[Sequence[int]], circuit: Circuit) -> CircuitNode:
        """
        Construct the circuit for this sub-tree by recursive conditioning.

        For each combination of allowed states of the cutset rvs, the two
        children's circuits are multiplied; the products are then summed.
        Results are cached by the (single) states of the context rvs.
        """
        assert all(len(states[rv_idx]) == 1 for rv_idx in self.context), 'consistency check'
        context_key: Tuple[int, ...] = tuple(states[rv_idx][0] for rv_idx in self.context)

        cached: Optional[CircuitNode] = self.cache.get(context_key)
        if cached is not None:
            return cached

        cutset = self.cutset
        key_states: List[Sequence[int]] = [states[rv_idx] for rv_idx in cutset]
        to_sum: List[CircuitNode] = []
        for key in combos(key_states):
            # Condition on this instantiation of the cutset rvs.
            next_states = states.copy()
            for s, i in zip(key, cutset):
                next_states[i] = (s,)

            x_node = self.x.make_circuit(next_states, circuit)
            y_node = self.y.make_circuit(next_states, circuit)
            to_sum.append(circuit.optimised_mul((x_node, y_node)))

        result = circuit.optimised_add(to_sum)
        self.cache[context_key] = result
        return result

    def depth(self) -> int:
        # One level above the deeper child (previously missing the +1, which made
        # every tree report depth 1 and defeated the "join shallowest" heuristic).
        return 1 + max(self.x.depth(), self.y.depth())
File without changes
@@ -0,0 +1,9 @@
1
+ # from .circuit_table_py import (
2
+ from .circuit_table import (
3
+ CircuitTable,
4
+ TableInstance,
5
+ sum_out,
6
+ sum_out_all,
7
+ project,
8
+ product,
9
+ )