compiled-knowledge 4.0.0a20__cp312-cp312-musllinux_1_2_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (178) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37520 -0
  4. ck/circuit/_circuit_cy.cpython-312-x86_64-linux-musl.so +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19821 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-x86_64-linux-musl.so +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10615 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-x86_64-linux-musl.so +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16393 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-x86_64-linux-musl.so +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
@@ -0,0 +1,568 @@
1
+ """
2
+ Graph analysis to identify clusters using elimination heuristics.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from typing import Set, Iterable, Callable, Iterator, Tuple, List, overload, Sequence
7
+
8
+ from ck.pgm import PGM
9
+
10
+ # A VEObjective is a variable elimination objective function.
11
+ # An objective function is a function from a random variable index (int)
12
+ # to an objective value (float or int). This is used to select
13
+ # a random variable to eliminate in `ve_greedy_min`.
14
+ VEObjective = Callable[[int], int | float]  # NOTE: `int | float` here is an ordinary runtime expression (not a lazy annotation), so importing this module needs Python 3.10+.
15
+
16
+
17
def ve_fixed(clusters: Clusters, order: Iterable[int]) -> None:
    """
    Eliminate random variables from `clusters` in exactly the order given.

    Args:
        clusters: a clusters object with uneliminated random variables.
        order: random variable indexes, in the order they are to be eliminated.

    Assumes:
        * Every rv index in `order` is also in `clusters.uneliminated`.
        * `order` contains no duplicates.
    """
    eliminate = clusters.eliminate
    for next_rv_index in order:
        eliminate(next_rv_index)
31
+
32
+
33
def ve_greedy_min(
        clusters: Clusters,
        objective: VEObjective | Tuple[VEObjective, ...],
        use_twig_prefix: bool = True,
        use_optimal_prefix: bool = False,
) -> None:
    """
    Greedy variable elimination: repeatedly eliminate the uneliminated random
    variable with the smallest objective value until none remain.

    An objective is a function from a random variable index to a value
    (int or float). If a tuple of objectives is given, candidates are compared
    on the tuple of their values, so later objectives only break ties
    left by earlier ones.

    Args:
        clusters: a clusters object with uneliminated random variables.
        objective: the objective function (or a tuple of objective functions)
            to minimise at each step.
        use_twig_prefix: if true, `twig_prefix` is applied before each
            objective-based selection.
        use_optimal_prefix: if true, `optimal_prefix` is applied before each
            objective-based selection (after `twig_prefix`, when both are enabled).
    """
    if isinstance(objective, tuple):
        objective_parts = objective

        def score(candidate: int) -> Tuple[float | int, ...]:
            return tuple(part(candidate) for part in objective_parts)
    else:
        score = objective

    # This is the set owned by `clusters`; it shrinks as variables are eliminated.
    remaining: Set[int] = clusters.uneliminated

    while len(remaining) > 1:
        if use_twig_prefix:
            twig_prefix(clusters)
            if len(remaining) <= 1:
                break

        if use_optimal_prefix:
            optimal_prefix(clusters)
            if len(remaining) <= 1:
                break

        # `min` keeps the first minimal candidate in iteration order.
        clusters.eliminate(min(remaining, key=score))

    if remaining:
        # Exactly one variable is left: no choice to make.
        clusters.eliminate(next(iter(remaining)))
87
+
88
+
89
def twig_prefix(clusters: Clusters) -> None:
    """
    Eliminate random variables of degree zero or one until none are left.

    Eliminating a degree-one variable lowers the degree of its neighbour, which
    may then itself have degree zero or one. The selection is therefore repeated
    until a pass finds no candidate, so a variable that becomes isolated during
    this call is also eliminated.

    Args:
        clusters: a clusters object; it is modified in place via `clusters.eliminate`.
    """
    while len(clusters.uneliminated) > 0:
        # Snapshot the candidates first: eliminating mutates `clusters.uneliminated`.
        eliminating: List[int] = [
            _rv_index
            for _rv_index in clusters.uneliminated
            if clusters.degree(_rv_index) <= 1
        ]
        if len(eliminating) == 0:
            break
        for rv_index in eliminating:
            clusters.eliminate(rv_index)
110
+
111
+
112
def optimal_prefix(clusters: Clusters) -> None:
    """
    Repeatedly eliminate random variables selected by the simplicial,
    almost simplicial, buddy and cube rules, until a full pass eliminates nothing.

    See Adnan Darwiche, 2009, Modeling and Reasoning with Bayesian Networks, p207.

    Args:
        clusters: a clusters object; it is modified in place via `clusters.eliminate`.
    """

    def _get_lower_bound() -> int:
        # The value used as the tree width lower bound for the current clusters:
        # (largest connection-set size among uneliminated rvs) - 1, floored at 0.
        # NOTE(review): confirm that max degree - 1 is the intended bound.
        return max(
            max(
                (len(clusters.connections(_rv_index)) for _rv_index in clusters.uneliminated),
                default=0
            ) - 1,
            0
        )

    # Start one above the real count so the loop body runs at least once.
    prev_number_uneliminated: int = len(clusters.uneliminated) + 1

    while prev_number_uneliminated > len(clusters.uneliminated):
        prev_number_uneliminated = len(clusters.uneliminated)
        low: int = _get_lower_bound()
        to_eliminate: Set[int] = set()
        for rv_index in clusters.uneliminated:
            fill: int = clusters.fill(rv_index)
            if fill == 0:
                # simplicial rule: no fill edges
                to_eliminate.add(rv_index)
            elif fill == 1 and clusters.degree(rv_index) <= low:
                # almost simplicial rule: one fill edge and degree <= low
                to_eliminate.add(rv_index)

        # Perform eliminations (after the scan, as eliminating mutates `clusters.uneliminated`).
        for rv_index in to_eliminate:
            clusters.eliminate(rv_index)

        low = _get_lower_bound()
        if low >= 3:
            to_eliminate = set()
            for rv_index_i in clusters.uneliminated:
                if clusters.degree(rv_index_i) != 3:
                    continue
                i_neighbours: Set[int] = clusters.connections(rv_index_i)

                # buddy rule: two joined nodes with degree 3 and the same neighbours
                for rv_index_j in i_neighbours:
                    if clusters.degree(rv_index_j) == 3:
                        j_neighbours: Set[int] = clusters.connections(rv_index_j)
                        if i_neighbours.difference([rv_index_j]) == j_neighbours.difference([rv_index_i]):
                            to_eliminate.add(rv_index_i)
                            to_eliminate.add(rv_index_j)

                # cube rule: i, a, b, c form a cube, i.e. a, b and c all have degree 3
                # and each pair of them shares exactly one neighbour other than i.
                if len(i_neighbours) == 3:
                    if all(clusters.degree(rv_index) == 3 for rv_index in i_neighbours):
                        a, b, c = tuple(i_neighbours)
                        # a, b and c are all connected to i, so i is in every pairwise
                        # intersection; exclude it so that only the shared corner is counted.
                        hub: Set[int] = {rv_index_i}
                        ab = (clusters.connections(a) & clusters.connections(b)) - hub
                        ac = (clusters.connections(a) & clusters.connections(c)) - hub
                        bc = (clusters.connections(b) & clusters.connections(c)) - hub
                        if len(ab) == 1 and len(ac) == 1 and len(bc) == 1:
                            to_eliminate.add(rv_index_i)
                            to_eliminate.add(a)
                            to_eliminate.add(b)
                            to_eliminate.add(c)

            # Perform eliminations.
            # NOTE(review): the set gives no particular order among i, a, b, c;
            # confirm whether the cube rule needs a, b, c eliminated before i.
            for rv_index in to_eliminate:
                clusters.eliminate(rv_index)
179
+
180
+
181
class Clusters:
    """
    A Clusters object holds the state of a connection graph while
    eliminating variables to construct clusters for a PGM graph.

    The Clusters object can either be "in-progress" where `len(Clusters.uneliminated) > 0`,
    or be "completed" where `len(Clusters.uneliminated) == 0`.

    While in-progress, entry `i` of the internal connections list is the set of
    random variables currently connected to random variable `i`. Once completed,
    the same list holds the resulting clusters.

    See Adnan Darwiche, 2009, Modeling and Reasoning with Bayesian Networks, p164.
    """

    def __init__(self, pgm: PGM, maximal_clusters_only: bool = True):
        """
        Args:
            pgm: source PGM defining initial connection graph.
            maximal_clusters_only: if true, then any subsumed cluster will be incorporated
                into its subsuming cluster (once all random variables are eliminated).
        """
        self._pgm: PGM = pgm
        self._uneliminated: Set[int] = {rv.idx for rv in pgm.rvs}
        self._eliminated: List[int] = []
        self._rv_log_sizes: Sequence[float] = pgm.rv_log_sizes
        self._maximal_clusters_only = maximal_clusters_only

        # Create a connection set for each random variable.
        # The connection set keeps track of what _other_ random variables it is connected to (via factors).
        # I.e., the connections define an interaction graph.
        connections: List[Set[int]] = [set() for _ in range(pgm.number_of_rvs)]
        for factor in pgm.factors:
            rv_indexes = [rv.idx for rv in factor.rvs]
            for index in rv_indexes:
                connections[index].update(rv_indexes)
        for index, rv_connections in enumerate(connections):
            # A random variable is not connected to itself.
            rv_connections.discard(index)
        self._connections: List[Set[int]] = connections

        # Deal with the case of an empty PGM.
        if len(self._uneliminated) == 0:
            self._finish_elimination()

    @property
    def pgm(self) -> PGM:
        """
        Returns:
            the PGM that these clusters refer to.
        """
        return self._pgm

    @property
    def eliminated(self) -> List[int]:
        """
        Get the list of eliminated random variables (as random variable
        indices, in elimination order).

        Assumes:
            * The returned list will not be modified by the caller.

        Returns:
            the indexes of eliminated random variables, in elimination order.
        """
        return self._eliminated

    @property
    def uneliminated(self) -> Set[int]:
        """
        Get the set of uneliminated random variables (as random variable indices).

        The returned set is the live internal set: it shrinks as
        `eliminate` is called.

        Assumes:
            * The returned set will not be modified by the caller.

        Returns:
            the set of random variable indexes that are yet to be eliminated.
        """
        return self._uneliminated

    def connections(self, rv_index: int) -> Set[int]:
        """
        Get the current graph connections of a random variable.

        Args:
            rv_index: The index of the random variable being queried.

        Returns:
            the set of random variable indexes that are connected to the
            given indexed random variable.

        Assumes:
            * Not all random variables are eliminated.
            * `rv_index` is in `self.uneliminated`.
            * The returned set will not be modified by the caller.
        """
        assert len(self._uneliminated) > 0, 'only makes sense while eliminating'
        return self._connections[rv_index]

    @property
    def clusters(self) -> List[Set[int]]:
        """
        Get the clusters that are a result of eliminating all random variables.
        This only makes sense once all random variables are eliminated.

        Assumes:
            * All random variables are eliminated.
            * The returned list and sets will not be modified by the caller.

        Returns:
            list of all clusters, each cluster is a set of random variable indexes.
        """
        assert len(self._uneliminated) == 0, 'only makes sense when completed eliminating'
        return self._connections

    def max_cluster_size(self) -> int:
        """
        Calculate the maximum cluster size over all clusters.

        Returns:
            the maximum `len(cluster)` over all clusters,
            or 0 if there are no clusters (empty PGM).
        """
        return max((len(cluster) for cluster in self.clusters), default=0)

    def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
        """
        Calculate the maximum cluster weighted size over all clusters.

        Args:
            rv_log_sizes: is an array of random variable sizes, such that
                for a random variable `rv`, `rv_log_sizes[rv.idx] = log2(len(rv))`.
        Returns:
            the maximum `sum(rv_log_sizes[rv_idx] for rv_idx in cluster)` over all clusters,
            or 0.0 if there are no clusters (empty PGM).
        """
        return max(
            (
                sum(rv_log_sizes[rv_idx] for rv_idx in cluster)
                for cluster in self.clusters
            ),
            default=0.0,
        )

    def eliminate(self, rv_index: int) -> None:
        """
        Perform one step of variable elimination.

        The random variable is recorded as eliminated, and every random variable
        it is currently connected to becomes connected to all the others
        (adding fill edges) and disconnected from the eliminated one. The
        connection set of the eliminated random variable itself is left as is.
        When the last random variable is eliminated, the clusters are finalised.

        Args:
            rv_index: index of the random variable to eliminate.

        Assumes:
            `rv_index` is in `self.uneliminated`.

        Raises:
            KeyError: if `rv_index` is not in `self.uneliminated`.
        """

        # record that the rv is eliminated now
        self._uneliminated.remove(rv_index)  # may raise a KeyError.
        self._eliminated.append(rv_index)

        # Get all rvs connected to the rv being eliminated.
        # For every rv mentioned, connect to all the others.
        # This adds fill edges to connections.
        mentioned_rvs: Set[int] = self._connections[rv_index]
        for mentioned_index in mentioned_rvs:
            rv_connections = self._connections[mentioned_index]
            rv_connections.update(mentioned_rvs)
            rv_connections.discard(mentioned_index)
            rv_connections.discard(rv_index)

        if len(self._uneliminated) == 0:
            self._finish_elimination()

    def degree(self, rv_index: int) -> int:
        """
        What is the degree of the random variable with the given index
        given the current state of eliminations.
        Mathematically equivalent to `len(self.connections(rv_index))`.
        """
        assert len(self._uneliminated) > 0, 'only makes sense while eliminating'
        return len(self._connections[rv_index])

    def fill(self, rv_index: int) -> int:
        """
        What number of new fill edges are created if eliminating the random variable with
        the given index given the current state of eliminations.
        """
        assert len(self._uneliminated) > 0, 'only makes sense while eliminating'
        return self._fill_count(
            rv_index,
            self._add_one,
            self._identity,
        )

    def weighted_degree(self, rv_index: int) -> float:
        """
        What is the weighted degree of the random variable with the given index
        given the current state of eliminations, i.e., the sum of the log sizes
        (as given by the PGM's `rv_log_sizes`) of the random variables it is connected to.
        """
        assert len(self._uneliminated) > 0, 'only makes sense while eliminating'
        rv_connections: Set[int] = self._connections[rv_index]
        return sum(self._rv_log_sizes[other] for other in rv_connections)

    def weighted_fill(self, rv_index: int) -> float:
        """
        What is the total weight of fill edges that are created if eliminating
        the random variable with the given index given the current state of eliminations.

        Each fill edge is weighted by the sum of the log sizes of its two
        random variables, and the total is divided by 2.
        """
        assert len(self._uneliminated) > 0, 'only makes sense while eliminating'
        return self._fill_count(
            rv_index,
            self._add_sum_log2_states,
            self._divide_2,
        )

    def traditional_weighted_fill(self, rv_index: int) -> float:
        """
        What is the total traditional weight of fill edges that are created if eliminating
        the random variable with the given index given the current state of eliminations.

        Each fill edge is weighted by the product of the log sizes of its two
        random variables, and the total is divided by 4.
        """
        assert len(self._uneliminated) > 0, 'only makes sense while eliminating'
        return self._fill_count(
            rv_index,
            self._add_mul_log2_states,
            self._divide_4,
        )

    def _finish_elimination(self) -> None:
        """
        All rvs are now eliminated. Do any finishing processes.

        Each connection set becomes a cluster by adding its own random variable.
        If `maximal_clusters_only` was requested, a cluster that is a superset of
        a later cluster takes over that later position, and its own position is removed.
        """
        # add each rv to its own cluster
        for rv_index, cluster in enumerate(self._connections):
            cluster.add(rv_index)

        if self._maximal_clusters_only:
            # Removed subsumed clusters
            # The sentinel is recognised by identity (`is`), not by equality.
            delete_sentinel: Set[int] = set()
            number_of_clusters = len(self._connections)
            for i in range(number_of_clusters):
                cluster_i = self._connections[i]
                for j in range(i + 1, number_of_clusters):
                    cluster_j = self._connections[j]
                    if cluster_i.issuperset(cluster_j):
                        # The cluster_j is a subset of cluster_i.
                        # We move cluster i to position j to preserve correct cluster order.
                        self._connections[j] = cluster_i
                        self._connections[i] = delete_sentinel
                        break
            # Remove clusters marked for deletion
            self._connections = list(filter((lambda connection: connection is not delete_sentinel), self._connections))

    def dump(self, *, prefix: str = '', indent: str = ' ') -> None:
        """
        Print a dump of the Clusters.
        This is intended for debugging and demonstration purposes.

        If elimination is incomplete, the uneliminated and eliminated random
        variables and the current connections are printed. Otherwise, the
        elimination order and the clusters are printed.

        Args:
            prefix: optional prefix for indenting all lines.
            indent: additional prefix to use for extra indentation.
        """

        def _rv_name(_rv_idx: int) -> str:
            return repr(str(pgm.rvs[_rv_idx]))

        pgm = self._pgm

        if len(self._uneliminated) > 0:
            print(f'{prefix}Clustering incomplete.')
            print(f'{prefix}Uneliminated: ', ', '.join(_rv_name(rv_idx) for rv_idx in self._uneliminated))
            print(f'{prefix}Eliminated: ', ', '.join(_rv_name(rv_idx) for rv_idx in self._eliminated))
            print(f'{prefix}Connections:')
            for i, connections in enumerate(self._connections):
                print(f'{prefix}{indent}rv {i}:', ', '.join(_rv_name(rv_idx) for rv_idx in sorted(connections)))
            return

        print(f'{prefix}Elimination order:')
        for rv_idx in self.eliminated:
            print(f'{prefix}{indent}{_rv_name(rv_idx)}')
        print(f'{prefix}Clusters:')
        for i, cluster in enumerate(self.clusters):
            print(f'{prefix}{indent}cluster {i}:', ', '.join(_rv_name(rv_idx) for rv_idx in sorted(cluster)))

    @overload
    def _fill_count(
            self,
            rv_index: int,
            fill_value: Callable[[int, int], float],
            result: Callable[[float], float],
    ) -> float:
        ...

    @overload
    def _fill_count(
            self,
            rv_index: int,
            fill_value: Callable[[int, int], int],
            result: Callable[[int], int],
    ) -> int:
        ...

    def _fill_count(
            self,
            rv_index: int,
            fill_value: Callable[[int, int], int | float],
            result: Callable[[int | float], int | float]):
        """
        Supporting function to calculate the "fill" of a random variable.

        Every unordered pair of random variables connected to `rv_index` that
        are not connected to each other is a fill edge; the fill values of all
        fill edges are summed, and `result` is applied to the sum.

        Args:
            rv_index: the index of the rv to compute the fill.
            fill_value: compute the fill value of two indexed random variables.
            result: compute the result value as a function of the sum of fill values.

        Returns:
            `result(total)`, where `total` is the sum of `fill_value(rv1, rv2)`
            over all fill edges `(rv1, rv2)` (0 if there are none).
        """
        fill_sum = 0
        connections: Tuple[int, ...] = tuple(self._connections[rv_index])
        for i, rv1 in enumerate(connections):
            test_connections: Set[int] = self._connections[rv1]
            # Only look at later entries so that each pair is visited once.
            for rv2 in connections[i + 1:]:
                if rv2 not in test_connections:
                    fill_sum += fill_value(rv1, rv2)
        return result(fill_sum)

    # ==============================================================
    #  The following are functions to supply to `self._fill_count`.
    # ==============================================================

    @staticmethod
    def _add_one(_1: int, _2: int) -> int:
        return 1

    def _add_sum_log2_states(self, rv1: int, rv2: int) -> float:
        return self._rv_log_sizes[rv1] + self._rv_log_sizes[rv2]

    def _add_mul_log2_states(self, rv1: int, rv2: int) -> float:
        return self._rv_log_sizes[rv1] * self._rv_log_sizes[rv2]

    @staticmethod
    def _identity(result: int) -> int:
        return result

    @staticmethod
    def _divide_2(result: float) -> float:
        return result / 2.0

    @staticmethod
    def _divide_4(result: float) -> float:
        return result / 4.0
522
+
523
+
524
+ # standard greedy algorithms
525
+
526
+ ClusterAlgorithm = Callable[[PGM], Clusters]  # Type of a function that takes a PGM and returns a Clusters object, e.g. `min_degree`.
527
+
528
+
529
def min_degree(pgm: PGM) -> Clusters:
    """
    Build clusters for `pgm` using `ve_greedy_min` (default options),
    minimising `Clusters.degree` at each step.

    Args:
        pgm: the PGM to construct clusters for.

    Returns:
        the Clusters object after running the greedy elimination.
    """
    result = Clusters(pgm)
    ve_greedy_min(result, objective=result.degree)
    return result
533
+
534
+
535
def min_fill(pgm: PGM) -> Clusters:
    """
    Build clusters for `pgm` using `ve_greedy_min` (default options),
    minimising `Clusters.fill` at each step.

    Args:
        pgm: the PGM to construct clusters for.

    Returns:
        the Clusters object after running the greedy elimination.
    """
    result = Clusters(pgm)
    ve_greedy_min(result, objective=result.fill)
    return result
539
+
540
+
541
def min_degree_then_fill(pgm: PGM) -> Clusters:
    """
    Build clusters for `pgm` using `ve_greedy_min` (default options),
    minimising `Clusters.degree` at each step, with ties broken by `Clusters.fill`.

    Args:
        pgm: the PGM to construct clusters for.

    Returns:
        the Clusters object after running the greedy elimination.
    """
    result = Clusters(pgm)
    tie_broken_objective = (result.degree, result.fill)
    ve_greedy_min(result, objective=tie_broken_objective)
    return result
545
+
546
+
547
def min_fill_then_degree(pgm: PGM) -> Clusters:
    """
    Build clusters for `pgm` using `ve_greedy_min` (default options),
    minimising `Clusters.fill` at each step, with ties broken by `Clusters.degree`.

    Args:
        pgm: the PGM to construct clusters for.

    Returns:
        the Clusters object after running the greedy elimination.
    """
    result = Clusters(pgm)
    tie_broken_objective = (result.fill, result.degree)
    ve_greedy_min(result, objective=tie_broken_objective)
    return result
551
+
552
+
553
def min_weighted_degree(pgm: PGM) -> Clusters:
    """
    Build clusters for `pgm` using `ve_greedy_min` (default options),
    minimising `Clusters.weighted_degree` at each step.

    Args:
        pgm: the PGM to construct clusters for.

    Returns:
        the Clusters object after running the greedy elimination.
    """
    result = Clusters(pgm)
    ve_greedy_min(result, objective=result.weighted_degree)
    return result
557
+
558
+
559
def min_weighted_fill(pgm: PGM) -> Clusters:
    """
    Build clusters for `pgm` using `ve_greedy_min` (default options),
    minimising `Clusters.weighted_fill` at each step.

    Args:
        pgm: the PGM to construct clusters for.

    Returns:
        the Clusters object after running the greedy elimination.
    """
    result = Clusters(pgm)
    ve_greedy_min(result, objective=result.weighted_fill)
    return result
563
+
564
+
565
def min_traditional_weighted_fill(pgm: PGM) -> Clusters:
    """
    Build clusters for `pgm` using `ve_greedy_min` (default options),
    minimising `Clusters.traditional_weighted_fill` at each step.

    Args:
        pgm: the PGM to construct clusters for.

    Returns:
        the Clusters object after running the greedy elimination.
    """
    result = Clusters(pgm)
    ve_greedy_min(result, objective=result.traditional_weighted_fill)
    return result