compiled-knowledge 4.0.0a20__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (178)
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37525 -0
  4. ck/circuit/_circuit_cy.cpython-312-darwin.so +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-darwin.so +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-darwin.so +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-darwin.so +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
@@ -0,0 +1,332 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from itertools import chain
5
+ from typing import List, Set, Callable, Sequence, Tuple
6
+
7
+ import numpy as np
8
+
9
+ from ck.pgm import PGM, Factor
10
+ from ck.pgm_compiler.support.clusters import Clusters, min_degree, min_fill, \
11
+ min_degree_then_fill, min_fill_then_degree, min_weighted_degree, min_weighted_fill, min_traditional_weighted_fill, \
12
+ ClusterAlgorithm
13
+ from ck.utils.np_extras import NDArrayFloat64
14
+
15
+
16
@dataclass
class JoinTree:
    """
    One node of a join tree. The complete tree is a recursive structure:
    each node holds its child nodes in `children`.

    A node records the PGM it belongs to, the random variables of its cluster,
    the PGM factors allocated to it, and the separator it shares with its parent.
    """

    # The PGM that this join tree is for.
    pgm: PGM

    # Indexes of the random variables in this join tree node's cluster.
    cluster: Set[int]

    # Child nodes in the join tree.
    children: List[JoinTree]

    # Factors of the PGM allocated to this join tree node.
    factors: List[Factor]

    # Indexes of random variables that are in both this cluster and the parent's cluster
    # (empty if this node is the root of the spanning tree).
    separator: Set[int]

    def max_cluster_size(self) -> int:
        """
        Returns:
            the largest `len(cluster)` found at this node or any descendant node.
        """
        return max([len(self.cluster), *(child.max_cluster_size() for child in self.children)])

    def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
        """
        Calculate the largest weighted cluster size at this node or any descendant node.

        The weighted size of a cluster is the sum of `rv_log_sizes[rv_idx]` over the
        cluster's random variable indexes, i.e., log2 of the cluster's state space size.

        Args:
            rv_log_sizes: random variable log sizes, such that for a random
                variable `rv`, `rv_log_sizes[rv.idx] = log2(len(rv))`.

        Returns:
            the maximum weighted cluster size over this node and all its descendants.
        """
        best: float = sum(rv_log_sizes[rv_idx] for rv_idx in self.cluster)
        for child in self.children:
            best = max(best, child.max_cluster_weighted_size(rv_log_sizes))
        return best

    def dump(self, *, prefix: str = '', indent: str = ' ', show_factors: bool = True) -> None:
        """
        Print this join tree, one line per node, with children indented beneath
        their parent. Intended for debugging and demonstration.

        A node is printed as: {separator rvs} | {non-separator rvs} (factors: count).

        Args:
            prefix: text written at the start of every line printed for this node.
            indent: extra text appended to the prefix for each level of nesting.
            show_factors: if true, each factor of a node is printed after the node line.
        """
        rvs = self.pgm.rvs

        def quoted(rv_idx: int) -> str:
            return repr(str(rvs[rv_idx]))

        separator_part = ' '.join(quoted(i) for i in sorted(self.separator))
        remainder_part = ' '.join(quoted(i) for i in sorted(self.cluster) if i not in self.separator)
        if separator_part:
            separator_part += ' '
        print(f'{prefix}{separator_part}| {remainder_part} (factors: {len(self.factors)})')

        if show_factors:
            for factor in self.factors:
                print(f'{prefix}factor{factor}')

        child_prefix = prefix + indent
        for child in self.children:
            child.dump(prefix=child_prefix, indent=indent, show_factors=show_factors)
88
+
89
+
90
# Type for a join tree algorithm: a function mapping a PGM to a JoinTree
# (the root node of the join tree built for that PGM).
JoinTreeAlgorithm = Callable[[PGM], JoinTree]
92
+
93
+
94
def _join_tree_algorithm(pgm_to_clusters: ClusterAlgorithm) -> JoinTreeAlgorithm:
    """
    Wrap a ClusterAlgorithm so that it becomes a JoinTreeAlgorithm.

    The returned function clusters a PGM using `pgm_to_clusters`, then
    turns the resulting clusters into a join tree with `clusters_to_join_tree`.

    Args:
        pgm_to_clusters: the cluster algorithm to apply to a PGM.

    Returns:
        a JoinTreeAlgorithm, i.e., a function from PGM to JoinTree.
    """

    def __join_tree_algorithm(pgm: PGM) -> JoinTree:
        return clusters_to_join_tree(pgm_to_clusters(pgm))

    return __join_tree_algorithm
111
+
112
+
113
# Standard JoinTreeAlgorithms: one per cluster algorithm imported from
# `ck.pgm_compiler.support.clusters`, each wrapped by `_join_tree_algorithm`.
MIN_DEGREE: JoinTreeAlgorithm = _join_tree_algorithm(min_degree)
MIN_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_fill)
MIN_DEGREE_THEN_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_degree_then_fill)
MIN_FILL_THEN_DEGREE: JoinTreeAlgorithm = _join_tree_algorithm(min_fill_then_degree)
MIN_WEIGHTED_DEGREE: JoinTreeAlgorithm = _join_tree_algorithm(min_weighted_degree)
MIN_WEIGHTED_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_weighted_fill)
MIN_TRADITIONAL_WEIGHTED_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_traditional_weighted_fill)
122
+
123
+
124
def clusters_to_join_tree(clusters: Clusters) -> JoinTree:
    """
    Construct a join tree from the given random variable clusters.

    A join tree is formed by finding a minimum spanning tree over the clusters,
    where the cost between a pair of clusters is the negated number of random
    variables they have in common. That is, the spanning tree maximises separator
    sizes. Ties are broken by a small penalty proportional to the separator's log
    state space size, so separators with smaller state spaces are preferred.

    Each PGM factor is allocated to the smallest cluster that contains all of
    the factor's random variables.

    Args:
        clusters: the clusters that resulted from graph clustering of a PGM.

    Returns:
        the root node of the resulting JoinTree.
    """
    pgm: PGM = clusters.pgm
    cluster_sets: List[Set[int]] = clusters.clusters
    number_of_clusters = len(cluster_sets)

    # Dealing with these cases directly simplifies
    # the spanning tree algorithm implementation.
    if number_of_clusters == 0:
        return JoinTree(pgm, set(), [], [], set())
    elif number_of_clusters == 1:
        return JoinTree(pgm, cluster_sets[0], [], list(pgm.factors), set())

    # Tie-break costs: each random variable gets a share proportional to its log
    # state space size. The shares are scaled so that their sum over all random
    # variables is < 1, so a tie-break can never outweigh a difference of one
    # random variable in separator size.
    max_raw_break_cost = sum(pgm.rv_log_sizes) * 1.1
    if max_raw_break_cost > 0:
        break_cost = [log_size / max_raw_break_cost for log_size in pgm.rv_log_sizes]
    else:
        # Every random variable has a single state (log size 0), so there is
        # nothing to break ties with. This also avoids dividing by zero.
        break_cost = [0.0 for _ in pgm.rv_log_sizes]

    # Calculate inter-cluster costs for determining the minimum spanning tree.
    cost: NDArrayFloat64 = np.zeros((number_of_clusters, number_of_clusters), dtype=np.float64)
    for i, cluster_i in enumerate(cluster_sets):
        for j in range(i + 1, number_of_clusters):
            separator = cluster_i.intersection(cluster_sets[j])
            cost[i, j] = cost[j, i] = -len(separator) + sum(break_cost[rv_idx] for rv_idx in separator)

    # Make the spanning tree over the clusters.
    children: List[List[int]]
    root_cluster_index: int
    children, root_cluster_index = _make_spanning_tree_small_root(cost, cluster_sets)

    # Allocate each PGM factor to the smallest cluster containing all its random variables.
    # NOTE(review): a factor whose random variables are not all contained in some cluster
    # is not allocated to any node. Presumably the clustering guarantees this cannot happen.
    cluster_factors: List[List[Factor]] = [[] for _ in range(number_of_clusters)]
    ordered_indexed_clusters = sorted(enumerate(cluster_sets), key=lambda idx_c: len(idx_c[1]))
    for factor in pgm.factors:
        rv_indexes = frozenset(rv.idx for rv in factor.rvs)
        for cluster_index, cluster in ordered_indexed_clusters:
            if rv_indexes.issubset(cluster):
                cluster_factors[cluster_index].append(factor)
                break

    return _form_join_tree_r(pgm, root_cluster_index, set(), children, cluster_sets, cluster_factors)
178
+
179
+
180
# Positive infinity; serves as the initial "no candidate yet" cost when searching for a minimum.
_INF = float('inf')
181
+
182
+
183
def _make_spanning_tree_small_root(cost: NDArrayFloat64, clusters: List[Set[int]]) -> Tuple[List[List[int]], int]:
    """
    Construct a minimum spanning tree over the clusters, where the root is the cluster with
    the smallest number of random variables (the first such cluster if there are ties).

    Args:
        cost: is an N x N matrix of costs between N clusters.
        clusters: is a list of N clusters, each cluster is a set of random variable indices.

    Returns:
        (spanning_tree, root_index)

        spanning_tree: is a spanning tree represented as a list of nodes, the list is coindexed with
        the given cost matrix, each node is a list of children, each child being
        represented as an index into the list of nodes.

        root_index: is the index of the chosen root of the spanning tree.
    """
    # Find the first cluster of minimum size.
    root_cluster_index: int = 0
    root_size: int = len(clusters[root_cluster_index])
    for i, cluster in enumerate(clusters[1:], start=1):
        # Compare the candidate cluster (not the current root) against the best size so far.
        if len(cluster) < root_size:
            root_cluster_index = i
            root_size = len(cluster)

    children: List[List[int]] = _make_spanning_tree_at_root(cost, root_cluster_index)
    return children, root_cluster_index
210
+
211
+
212
def _make_spanning_tree_arbitrary_root(cost: NDArrayFloat64) -> Tuple[List[List[int]], int]:
    """
    Construct a minimum spanning tree over the clusters, rooted at an arbitrary
    cluster (always cluster 0).

    Args:
        cost: is an N x N matrix of costs between N clusters.

    Returns:
        (spanning_tree, root_index)

        spanning_tree: is a spanning tree represented as a list of nodes, the list is coindexed with
        the given cost matrix, each node is a list of children, each child being
        represented as an index into the list of nodes.

        root_index: is the index of the chosen root of the spanning tree (0).
    """
    # Cluster 0 serves as the arbitrary root.
    root_index = 0
    return _make_spanning_tree_at_root(cost, root_index), root_index
231
+
232
+
233
def _make_spanning_tree_at_root(
        cost: NDArrayFloat64,
        root_custer_index: int,
) -> List[List[int]]:
    """
    Construct a minimum spanning tree over the clusters, starting at the given root.

    This is Prim's algorithm: starting from the root, repeatedly attach the
    cheapest edge joining a cluster already in the tree to a cluster not yet in
    the tree. On ties, the first candidate encountered wins.

    Args:
        cost: an n x n matrix where n is the number of clusters and cost[i, j]
            gives the cost between clusters i and j.
        root_custer_index: a nominated root cluster to be the root of the tree.

    Returns:
        a spanning tree represented as a list of nodes, the list is coindexed with
        the given cost matrix, each node is a list of children, each child being
        represented as an index into the list of nodes. The root node is the
        index `root_custer_index` as passed to this function. With a single
        cluster, the result is `[[]]`.
    """
    number_of_clusters: int = cost.shape[0]

    # Clusters not yet in the tree. Their order affects tie-breaking.
    remaining: List[int] = list(range(number_of_clusters))

    # Clusters already in the tree, in the order they were added.
    included: List[int] = []

    def remove_remaining(_remaining_index: int) -> None:
        # Remove the `remaining` element at the given index location in O(1)
        # by moving the last element into its place (does not preserve order).
        remaining[_remaining_index] = remaining[-1]
        remaining.pop()

    # Move root from `remaining` to `included`.
    # This relies on remaining[i] == i, which holds since nothing has been removed yet.
    included.append(root_custer_index)
    remove_remaining(root_custer_index)

    # Data structure to collect the results.
    children: List[List[int]] = [[] for _ in range(number_of_clusters)]

    # Loop until every cluster is in the tree. If there was only one cluster,
    # `remaining` is already empty and no edges are added.
    while remaining:
        min_i: int = 0
        min_j: int = 0
        min_j_pos: int = 0
        min_c: float = _INF
        for i in included:
            for j_pos, j in enumerate(remaining):
                c: float = cost.item(i, j)
                if c < min_c:
                    min_c = c
                    min_i = i
                    min_j = j
                    min_j_pos = j_pos

        # Record the child, then move it from `remaining` to `included`.
        children[min_i].append(min_j)
        remove_remaining(min_j_pos)
        included.append(min_j)

    return children
294
+
295
+
296
def _form_join_tree_r(
        pgm: PGM,
        cluster_index: int,
        parent_cluster: Set[int],
        children: Sequence[List[int]],
        clusters: Sequence[Set[int]],
        cluster_factors: List[List[Factor]],
) -> JoinTree:
    """
    Recursively build a JoinTree from the spanning tree `children`.
    This function merely pulls the corresponding components from the
    arguments to make a JoinTree object, doing this recursively
    for the children.

    Args:
        pgm: the source PGM for the join tree.
        cluster_index: index for the node we are processing (current root). This
            indexes into `children`, `clusters`, and `cluster_factors`.
        parent_cluster: set of random variable indices in the parent cluster
            (empty for the root).
        children: list of spanning tree nodes, as per `_make_spanning_tree_at_root` result.
        clusters: list of clusters, each cluster is a set of random variable indices.
        cluster_factors: assignment of factors to clusters.

    Returns:
        the JoinTree node for `cluster_index`, including its whole subtree.
    """
    cluster: Set[int] = clusters[cluster_index]
    factors: List[Factor] = cluster_factors[cluster_index]
    # A distinct name is used so the `children` spanning-tree parameter is not shadowed.
    child_trees: List[JoinTree] = [
        _form_join_tree_r(pgm, child_index, cluster, children, clusters, cluster_factors)
        for child_index in children[cluster_index]
    ]
    separator: Set[int] = parent_cluster.intersection(cluster)
    return JoinTree(
        pgm,
        cluster,
        child_trees,
        factors,
        separator,
    )
@@ -0,0 +1,43 @@
1
+ from types import ModuleType
2
+ from typing import Tuple
3
+
4
+ from ck.pgm import PGM
5
+ from ck.pgm_circuit import PGMCircuit
6
+ from ck.pgm_compiler import PGMCompiler
7
+
8
+
9
def get_compiler(module: ModuleType, **kwargs) -> Tuple[PGMCompiler]:
    """
    Helper function to create a named PGM compiler.

    The created compiler delegates to `module.compile_pgm`. It forwards
    `const_parameters` plus the extra keyword arguments captured here. The
    `compile_pgm` attribute is looked up on every call, not when the compiler
    is created.

    Args:
        module: module containing a `compile_pgm` function.
        kwargs: additional keyword arguments passed to `compile_pgm` on every call.

    Returns:
        a singleton tuple holding the PGMCompiler function.
    """

    def compiler(pgm: PGM, const_parameters: bool = True) -> PGMCircuit:
        """Conforms to the `PGMCompiler` protocol."""
        compile_pgm = module.compile_pgm
        circuit: PGMCircuit = compile_pgm(pgm, const_parameters=const_parameters, **kwargs)
        return circuit

    return (compiler,)
26
+
27
+
28
def get_compiler_algorithm(module: ModuleType, algorithm: str, **kwargs) -> Tuple[PGMCompiler]:
    """
    Helper function to create a named PGM compiler that uses a named algorithm.

    The algorithm object is looked up by name on `module` when this function
    is called. It is then passed as the `algorithm` keyword argument to
    `compile_pgm`.

    Args:
        module: module containing the `compile_pgm` function.
        algorithm: name of the algorithm attribute declared in `module`.
        kwargs: additional keyword arguments to `compile_pgm`.

    Returns:
        a singleton tuple holding the PGMCompiler function.
    """
    algorithm_object = getattr(module, algorithm)
    return get_compiler(module, algorithm=algorithm_object, **kwargs)
42
+
43
+
@@ -0,0 +1,91 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Sequence
4
+
5
+ from ck.circuit import CircuitNode
6
+ from ck.pgm import PGM
7
+ from ck.pgm_circuit import PGMCircuit
8
+ from ck.pgm_compiler.support import clusters
9
+ from ck.pgm_compiler.support.circuit_table import CircuitTable, product, sum_out
10
+ from ck.pgm_compiler.support.clusters import ClusterAlgorithm
11
+ from ck.pgm_compiler.support.factor_tables import make_factor_tables, FactorTables
12
+
13
# Standard cluster algorithms, re-exported from `ck.pgm_compiler.support.clusters`.
# Any of these may be passed as the `algorithm` argument of `compile_pgm` below.
MIN_DEGREE: ClusterAlgorithm = clusters.min_degree
MIN_FILL: ClusterAlgorithm = clusters.min_fill
MIN_DEGREE_THEN_FILL: ClusterAlgorithm = clusters.min_degree_then_fill
MIN_FILL_THEN_DEGREE: ClusterAlgorithm = clusters.min_fill_then_degree
MIN_WEIGHTED_DEGREE: ClusterAlgorithm = clusters.min_weighted_degree
MIN_WEIGHTED_FILL: ClusterAlgorithm = clusters.min_weighted_fill
MIN_TRADITIONAL_WEIGHTED_FILL: ClusterAlgorithm = clusters.min_traditional_weighted_fill
21
+
22
+
23
def compile_pgm(
        pgm: PGM,
        const_parameters: bool = True,
        *,
        algorithm: ClusterAlgorithm = MIN_FILL_THEN_DEGREE,
        pre_prune_factor_tables: bool = False,
) -> PGMCircuit:
    """
    Compile the PGM to an arithmetic circuit, using variable elimination.

    Conforms to the `PGMCompiler` protocol.

    Random variables are eliminated in the order given by `algorithm(pgm).eliminated`.
    To eliminate a random variable:
    - every current table that mentions it is combined by repeatedly multiplying
      the two smallest such tables;
    - the variable is then summed out of the single resulting table.

    Args:
        pgm: The PGM to compile.
        const_parameters: If true, the potential function parameters will be circuit
            constants, otherwise they will be circuit variables.
        algorithm: algorithm to get an elimination order.
        pre_prune_factor_tables: if true, then heuristics will be used to remove any provably zero row.

    Returns:
        a PGMCircuit object.
    """
    factor_tables: FactorTables = make_factor_tables(
        pgm=pgm,
        const_parameters=const_parameters,
        multiply_indicators=True,
        pre_prune_factor_tables=pre_prune_factor_tables,
    )

    elimination_order: Sequence[int] = algorithm(pgm).eliminated

    tables: List[CircuitTable] = list(factor_tables.tables)
    for rv_idx in elimination_order:
        # Split the tables, preserving order, by whether they mention the random variable.
        mentioning: List[CircuitTable] = []
        untouched: List[CircuitTable] = []
        for table in tables:
            (mentioning if rv_idx in table.rv_idxs else untouched).append(table)

        if mentioning:
            # Multiply the two smallest tables together until a single table remains.
            while len(mentioning) > 1:
                # Largest first (stable), so pop() yields the smallest tables.
                mentioning.sort(key=len, reverse=True)
                smallest = mentioning.pop()
                next_smallest = mentioning.pop()
                mentioning.append(product(smallest, next_smallest))
            untouched.append(sum_out(mentioning[0], (rv_idx,)))

        tables = untouched

    # All rvs are now eliminated - all tables should have a single top.
    top: CircuitNode = factor_tables.circuit.optimised_mul([table.top() for table in tables])
    top.circuit.remove_unreachable_op_nodes(top)

    return PGMCircuit(
        rvs=tuple(pgm.rvs),
        conditions=(),
        circuit_top=top,
        number_of_indicators=factor_tables.number_of_indicators,
        number_of_parameters=factor_tables.number_of_parameters,
        slot_map=factor_tables.slot_map,
        parameter_values=factor_tables.parameter_values,
    )
File without changes
@@ -0,0 +1,50 @@
1
+ from typing import Sequence, Iterable, Tuple, Dict, List
2
+
3
+ from ck.pgm import RandomVariable, Indicator, Instance
4
+ from ck.probability.probability_space import ProbabilitySpace, Condition, check_condition
5
+
6
+
7
class EmpiricalProbabilitySpace(ProbabilitySpace):
    """
    A ProbabilitySpace defined by a collection of samples.

    The weight of a condition is the number of samples satisfying it, and
    `z` is the total number of samples.
    """

    def __init__(self, rvs: Sequence[RandomVariable], samples: Iterable[Instance]):
        """
        Enable probabilistic queries over a sample from a sample space.
        Note that this is not necessarily an efficient approach to calculating
        probabilities and statistics.

        Assumes:
            len(sample) == len(rvs), for each sample in samples.
            0 <= sample[i] < len(rvs[i]), for each sample in samples, for i in range(len(rvs)).

        Args:
            rvs: The random variables.
            samples: instances (state indexes) that are samples from the given rvs.
        """
        self._rvs: Sequence[RandomVariable] = tuple(rvs)
        self._samples: List[Instance] = list(samples)
        # Maps a random variable's `idx` to its position within each sample.
        self._rv_idx_to_sample_idx: Dict[int, int] = {
            rv.idx: position for position, rv in enumerate(self._rvs)
        }

    @property
    def rvs(self) -> Sequence[RandomVariable]:
        return self._rvs

    def wmc(self, *condition: Condition) -> float:
        condition: Tuple[Indicator, ...] = check_condition(condition)

        # Per sample position: the states permitted by the condition (empty set = unconstrained).
        permitted = [set() for _ in self._rvs]
        for indicator in condition:
            permitted[self._rv_idx_to_sample_idx[indicator.rv_idx]].add(indicator.state_idx)

        # Per sample position: the states the condition rules out (empty set = unconstrained).
        excluded = [
            set(range(len(rv))).difference(states) if len(states) > 0 else states
            for rv, states in zip(self._rvs, permitted)
        ]

        def satisfied(instance: Instance) -> bool:
            return all(state not in bad_states for state, bad_states in zip(instance, excluded))

        return sum(1 for sample in self._samples if satisfied(sample))

    @property
    def z(self) -> float:
        return len(self._samples)
50
+
@@ -0,0 +1,32 @@
1
+ from typing import Sequence, Iterable, Tuple, Dict, List
2
+
3
+ from ck.pgm import RandomVariable, Indicator, Instance, PGM
4
+ from ck.probability.probability_space import ProbabilitySpace, Condition, check_condition
5
+
6
+
7
class PGMProbabilitySpace(ProbabilitySpace):
    """
    A ProbabilitySpace whose weights are computed directly by a PGM,
    via `PGM.value_product_indicators`.
    """

    def __init__(self, pgm: PGM):
        """
        Enable probabilistic queries directly on a PGM.
        Note that this is not necessarily an efficient approach to calculating
        probabilities and statistics.

        Args:
            pgm: The PGM to query.
        """
        self._pgm = pgm
        # Cached value of `z`; None until first requested.
        self._z = None

    @property
    def rvs(self) -> Sequence[RandomVariable]:
        return self._pgm.rvs

    def wmc(self, *condition: Condition) -> float:
        indicators: Tuple[Indicator, ...] = check_condition(condition)
        return self._pgm.value_product_indicators(*indicators)

    @property
    def z(self) -> float:
        z = self._z
        if z is None:
            # Computed lazily, once, with no conditioning indicators.
            z = self._z = self._pgm.value_product_indicators()
        return z
32
+