compiled-knowledge 4.0.0a20__cp312-cp312-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37523 -0
- ck/circuit/_circuit_cy.cp312-win32.pyd +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19824 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win32.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win32.pyd +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win32.pyd +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,568 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Graph analysis to identify clusters using elimination heuristics.
|
|
3
|
+
"""
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import Set, Iterable, Callable, Iterator, Tuple, List, overload, Sequence
|
|
7
|
+
|
|
8
|
+
from ck.pgm import PGM
|
|
9
|
+
|
|
10
|
+
# A VEObjective is a variable elimination objective function.
|
|
11
|
+
# An objective function is a function from a random variable index (int)
|
|
12
|
+
# to an objective value (float or int). This is used to select
|
|
13
|
+
# a random variable to eliminate in `ve_greedy_min`.
|
|
14
|
+
VEObjective = Callable[[int], int | float]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def ve_fixed(clusters: Clusters, order: Iterable[int]) -> None:
    """
    Eliminate random variables from `clusters` in exactly the order supplied.

    Args:
        clusters: a clusters object with uneliminated random variables.
        order: random variable indexes, in the order they are to be eliminated.

    Assumes:
        * Every rv index in `order` is also in `clusters.uneliminated`.
        * `order` contains no duplicates.
    """
    for next_rv_index in order:
        clusters.eliminate(next_rv_index)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def ve_greedy_min(
        clusters: Clusters,
        objective: VEObjective | Tuple[VEObjective, ...],
        use_twig_prefix: bool = True,
        use_optimal_prefix: bool = False,
) -> None:
    """
    Greedy variable elimination: repeatedly eliminate the uneliminated random
    variable with the smallest objective value, until none remain.

    An objective is a function from a random variable index (int) to an
    objective value. If a tuple of objectives is given, the values are compared
    as a tuple, so later objectives only break ties in earlier ones.
    When several random variables share the minimum, the first one met while
    iterating `clusters.uneliminated` is eliminated.

    Args:
        clusters: a clusters object with uneliminated random variables.
        objective: the objective function (or a tuple of objective functions)
            to minimise each iteration.
        use_twig_prefix: if true, then `twig_prefix` is applied at the start of
            every iteration, before selecting with the objective function.
        use_optimal_prefix: if true, then `optimal_prefix` is applied at the start
            of every iteration (after any twig prefix), before selecting with the
            objective function.
    """
    # This is the set object handed out by `clusters`; the loop conditions below
    # rely on it shrinking as `clusters.eliminate` is called.
    remaining: Set[int] = clusters.uneliminated

    if isinstance(objective, tuple):
        def score(_rv_index: int) -> Tuple[float | int, ...]:
            return tuple(f(_rv_index) for f in objective)
    else:
        score = objective

    while len(remaining) > 1:

        if use_twig_prefix:
            twig_prefix(clusters)
            if len(remaining) <= 1:
                break

        if use_optimal_prefix:
            optimal_prefix(clusters)
            if len(remaining) <= 1:
                break

        # `min` keeps the first minimal item it encounters.
        clusters.eliminate(min(remaining, key=score))

    if len(remaining) > 0:
        # Exactly one rv is left; no choice to make.
        clusters.eliminate(next(iter(remaining)))
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def twig_prefix(clusters: Clusters) -> None:
    """
    Eliminate all random variables with degree zero or one, repeating
    until no uneliminated random variable has degree zero or one.

    Eliminating a degree-one random variable lowers the degree of its
    neighbour, which may then itself reach degree one or zero; such random
    variables are picked up by a later round. (For example, on a three node
    path both end nodes are eliminated in the first round and the middle
    node, now of degree zero, in the second.)

    Args:
        clusters: a clusters object; if it has no uneliminated random
            variables then nothing happens.
    """
    # `clusters.degree` is only queried while something is uneliminated.
    while len(clusters.uneliminated) > 0:
        # Snapshot the candidates first: eliminating mutates `clusters.uneliminated`.
        twigs: List[int] = [
            _rv_index
            for _rv_index in clusters.uneliminated
            if clusters.degree(_rv_index) <= 1
        ]
        if len(twigs) == 0:
            break
        for rv_index in twigs:
            clusters.eliminate(rv_index)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def optimal_prefix(clusters: Clusters) -> None:
    """
    Eliminate random variables selected by the optimal-prefix reduction rules
    (simplicial, almost simplicial, buddy and cube), repeating until a pass
    eliminates nothing.

    See Adnan Darwiche, 2009, Modeling and Reasoning with Bayesian Networks, p207.

    Args:
        clusters: a clusters object; if it has no uneliminated random
            variables then nothing happens.
    """

    def _get_lower_bound() -> int:
        # The value used as `low` by the rules below: the largest current degree
        # over the uneliminated rvs, minus one, floored at zero.
        # NOTE(review): the rules treat this as a lower bound on tree width;
        # confirm that (max degree - 1) is the intended bound.
        max_degree: int = max(
            (len(clusters.connections(_rv_index)) for _rv_index in clusters.uneliminated),
            default=0,
        )
        return max(max_degree - 1, 0)

    prev_number_uneliminated: int = len(clusters.uneliminated) + 1

    # Keep making passes for as long as the previous pass eliminated something.
    while prev_number_uneliminated > len(clusters.uneliminated):
        prev_number_uneliminated = len(clusters.uneliminated)

        # ---- Stage 1: simplicial and almost simplicial rules ----
        low: int = _get_lower_bound()
        to_eliminate: Set[int] = set()
        for rv_index in clusters.uneliminated:
            fill: int = clusters.fill(rv_index)
            if fill == 0:
                # simplicial rule: no fill edges
                to_eliminate.add(rv_index)
            elif fill == 1 and clusters.degree(rv_index) <= low:
                # almost simplicial rule: one fill edge and degree <= low
                to_eliminate.add(rv_index)

        # Eliminations are deferred so `clusters.uneliminated` is not mutated
        # while it is being iterated.
        for rv_index in to_eliminate:
            clusters.eliminate(rv_index)

        # ---- Stage 2: buddy and cube rules (only applied when low >= 3) ----
        low = _get_lower_bound()
        if low >= 3:
            to_eliminate = set()
            for rv_index_i in clusters.uneliminated:
                if clusters.degree(rv_index_i) != 3:
                    continue
                i_neighbours: Set[int] = clusters.connections(rv_index_i)

                # buddy rule: two joined nodes with degree 3 and the same neighbours
                for rv_index_j in i_neighbours:
                    if clusters.degree(rv_index_j) == 3:
                        j_neighbours: Set[int] = clusters.connections(rv_index_j)
                        if i_neighbours.difference([rv_index_j]) == j_neighbours.difference([rv_index_i]):
                            to_eliminate.add(rv_index_i)
                            to_eliminate.add(rv_index_j)

                # cube rule: i, a, b, c form a cube corner, i.e. a, b, c are the
                # (degree 3) neighbours of i, and each pair of them shares exactly
                # one further neighbour, with the three shared neighbours distinct.
                if all(clusters.degree(rv_index) == 3 for rv_index in i_neighbours):
                    a, b, c = tuple(i_neighbours)
                    # `rv_index_i` is adjacent to a, b and c, so it is always in
                    # every pairwise intersection; exclude it before counting.
                    skip = {rv_index_i}
                    ab = clusters.connections(a).intersection(clusters.connections(b)) - skip
                    ac = clusters.connections(a).intersection(clusters.connections(c)) - skip
                    bc = clusters.connections(b).intersection(clusters.connections(c)) - skip
                    if (
                            len(ab) == 1 and len(ac) == 1 and len(bc) == 1
                            and len(ab | ac | bc) == 3
                    ):
                        to_eliminate.add(rv_index_i)
                        to_eliminate.add(a)
                        to_eliminate.add(b)
                        to_eliminate.add(c)

            # Perform eliminations
            for rv_index in to_eliminate:
                clusters.eliminate(rv_index)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class Clusters:
    """
    A Clusters object holds the state of a connection graph while
    eliminating variables to construct clusters for a PGM graph.

    The Clusters object can either be "in-progress" where `len(Clusters.uneliminated) > 0`,
    or be "completed" where `len(Clusters.uneliminated) == 0`.

    While in-progress, entry `i` of the internal connections list is the set of
    random variable indexes currently connected to random variable `i`. Once
    completed, the same list holds the resulting clusters.

    See Adnan Darwiche, 2009, Modeling and Reasoning with Bayesian Networks, p164.
    """

    def __init__(self, pgm: PGM, maximal_clusters_only: bool = True):
        """
        Args:
            pgm: source PGM defining initial connection graph.
            maximal_clusters_only: if true, then any subsumed cluster will be incorporated
                into its subsuming cluster (once all random variables are eliminated).
        """
        self._pgm: PGM = pgm
        self._uneliminated: Set[int] = {rv.idx for rv in pgm.rvs}
        self._eliminated: List[int] = []
        self._rv_log_sizes: Sequence[float] = pgm.rv_log_sizes
        self._maximal_clusters_only = maximal_clusters_only

        # Create a connection set for each random variable.
        # The connection set keeps track of what _other_ random variables it is connected to (via factors).
        # I.e., the connections define an interaction graph.
        connections: List[Set[int]] = [set() for _ in range(pgm.number_of_rvs)]
        for factor in pgm.factors:
            rv_indexes = [rv.idx for rv in factor.rvs]
            for index in rv_indexes:
                connections[index].update(rv_indexes)
        # The update above also connected each rv to itself; undo that.
        for index, rv_connections in enumerate(connections):
            rv_connections.discard(index)
        self._connections = connections

        # Deal with the case of an empty PGM.
        if len(self._uneliminated) == 0:
            self._finish_elimination()

    @property
    def pgm(self) -> PGM:
        """
        Returns:
            the PGM that these clusters refer to.
        """
        return self._pgm

    @property
    def eliminated(self) -> List[int]:
        """
        Get the list of eliminated random variables (as random variable
        indices, in elimination order).

        Assumes:
            * The returned list will not be modified by the caller.

        Returns:
            the indexes of eliminated random variables, in elimination order.
        """
        return self._eliminated

    @property
    def uneliminated(self) -> Set[int]:
        """
        Get the set of uneliminated random variables (as random variable indices).

        The returned object is the internal set, so it shrinks as
        `eliminate` is called.

        Assumes:
            * The returned set will not be modified by the caller.

        Returns:
            the set of random variable indexes that are yet to be eliminated.
        """
        return self._uneliminated

    def connections(self, rv_index: int) -> Set[int]:
        """
        Get the current graph connections of a random variable.

        Args:
            rv_index: The index of the random variable being queried.

        Returns:
            the set of random variable indexes that are connected to the
            given indexed random variable.

        Assumes:
            * Not all random variables are eliminated.
            * `rv_index` is in `self.uneliminated`.
            * The returned set will not be modified by the caller.
        """
        assert len(self._uneliminated) > 0, 'only makes sense while eliminating'
        return self._connections[rv_index]

    @property
    def clusters(self) -> List[Set[int]]:
        """
        Get the clusters that are a result of eliminating all random variables.
        This only makes sense once all random variables are eliminated.

        Assumes:
            * All random variables are eliminated.
            * The returned list and sets will not be modified by the caller.

        Returns:
            list of all clusters, each cluster is a set of random variable indexes.
        """
        assert len(self._uneliminated) == 0, 'only makes sense when completed eliminating'
        return self._connections

    def max_cluster_size(self) -> int:
        """
        Calculate the maximum cluster size over all clusters.

        Returns:
            the maximum `len(cluster)` over all clusters,
            or 0 if there are no clusters (e.g., an empty PGM).
        """
        return max((len(cluster) for cluster in self.clusters), default=0)

    def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
        """
        Calculate the maximum cluster weighted size over all clusters.

        Args:
            rv_log_sizes: is an array of random variable sizes, such that
                for a random variable `rv`, `rv_log_sizes[rv.idx] = log2(len(rv))`.

        Returns:
            the maximum `sum(rv_log_sizes[rv_idx] for rv_idx in cluster)` over all clusters,
            or 0.0 if there are no clusters (e.g., an empty PGM).
        """
        return max(
            (
                sum(rv_log_sizes[rv_idx] for rv_idx in cluster)
                for cluster in self.clusters
            ),
            default=0.0,
        )

    def eliminate(self, rv_index: int) -> None:
        """
        Perform one step of variable elimination.

        The eliminated random variable is removed from the connection sets of
        its current neighbours, and those neighbours are all connected to each
        other (adding fill edges). The eliminated random variable keeps its own
        connection set, which becomes its cluster when elimination finishes.

        Args:
            rv_index: the index of the random variable to eliminate.

        Raises:
            KeyError: if `rv_index` is not in `self.uneliminated`.

        Assumes:
            `rv_index` is in `self.uneliminated`.
        """

        # record that the rv is eliminated now
        self._uneliminated.remove(rv_index)  # may raise a KeyError.
        self._eliminated.append(rv_index)

        # Get all rvs connected to the rv being eliminated.
        # For every rv mentioned, connect to all the others.
        # This adds fill edges to connections.
        mentioned_rvs: Set[int] = self._connections[rv_index]
        for mentioned_index in mentioned_rvs:
            rv_connections = self._connections[mentioned_index]
            rv_connections.update(mentioned_rvs)
            rv_connections.discard(mentioned_index)
            rv_connections.discard(rv_index)

        if len(self._uneliminated) == 0:
            self._finish_elimination()

    def degree(self, rv_index: int) -> int:
        """
        What is the degree of the random variable with the given index
        given the current state of eliminations.
        Mathematically equivalent to `len(self.connections(rv_index))`.
        """
        assert len(self._uneliminated) > 0, 'only makes sense while eliminating'
        return len(self._connections[rv_index])

    def fill(self, rv_index: int) -> int:
        """
        What number of new fill edges are created if eliminating the random variable with
        the given index given the current state of eliminations.
        """
        assert len(self._uneliminated) > 0, 'only makes sense while eliminating'
        return self._fill_count(
            rv_index,
            self._add_one,
            self._identity,
        )

    def weighted_degree(self, rv_index: int) -> float:
        """
        What is the sum of the log sizes (as per the PGM's `rv_log_sizes`) of the
        random variables currently connected to the random variable with
        the given index given the current state of eliminations.
        """
        assert len(self._uneliminated) > 0, 'only makes sense while eliminating'
        rv_connections: Set[int] = self._connections[rv_index]
        return sum(self._rv_log_sizes[other] for other in rv_connections)

    def weighted_fill(self, rv_index: int) -> float:
        """
        What is the total weight of fill edges that are created if eliminating
        the random variable with the given index given the current state of eliminations.

        The weight of a fill edge is the sum of the log sizes of its two
        random variables; the total is then halved.
        """
        assert len(self._uneliminated) > 0, 'only makes sense while eliminating'
        return self._fill_count(
            rv_index,
            self._add_sum_log2_states,
            self._divide_2,
        )

    def traditional_weighted_fill(self, rv_index: int) -> float:
        """
        What is the total traditional weight of fill edges that are created if eliminating
        the random variable with the given index given the current state of eliminations.

        The traditional weight of a fill edge is the product of the log sizes of
        its two random variables; the total is then divided by four.
        """
        assert len(self._uneliminated) > 0, 'only makes sense while eliminating'
        return self._fill_count(
            rv_index,
            self._add_mul_log2_states,
            self._divide_4,
        )

    def _finish_elimination(self) -> None:
        """
        All rvs are now eliminated. Do any finishing processes.

        Each connection set is turned into a cluster by adding its own random
        variable. If `maximal_clusters_only` was requested, clusters found to be
        subsumed are then removed from the list.
        """
        # add each rv to its own cluster
        for rv_index, cluster in enumerate(self._connections):
            cluster.add(rv_index)

        if self._maximal_clusters_only:
            # Remove subsumed clusters.
            # NOTE(review): only a cluster that is a superset of a cluster at a
            # later list position is merged here; a cluster contained in a
            # later-positioned cluster is kept. Confirm this is intended.
            delete_sentinel: Set[int] = set()  # recognised by identity, not by value
            number_of_clusters = len(self._connections)
            for i in range(number_of_clusters):
                cluster_i = self._connections[i]
                for j in range(i + 1, number_of_clusters):
                    cluster_j = self._connections[j]
                    if cluster_i.issuperset(cluster_j):
                        # The cluster_j is a subset of cluster_i.
                        # We move cluster i to position j to preserve correct cluster order.
                        self._connections[j] = cluster_i
                        self._connections[i] = delete_sentinel
                        break
            # Remove clusters marked for deletion
            self._connections = [
                connection
                for connection in self._connections
                if connection is not delete_sentinel
            ]

    def dump(self, *, prefix: str = '', indent: str = ' ') -> None:
        """
        Print a dump of the Clusters.
        This is intended for debugging and demonstration purposes.

        If elimination is incomplete, the uneliminated and eliminated random
        variables and the current connections are printed; otherwise the
        elimination order and the clusters are printed.

        Args:
            prefix: optional prefix for indenting all lines.
            indent: additional prefix to use for extra indentation.
        """
        pgm = self._pgm

        def _rv_name(_rv_idx: int) -> str:
            return repr(str(pgm.rvs[_rv_idx]))

        if len(self._uneliminated) > 0:
            print(f'{prefix}Clustering incomplete.')
            print(f'{prefix}Uneliminated: ', ', '.join(_rv_name(rv_idx) for rv_idx in self._uneliminated))
            print(f'{prefix}Eliminated: ', ', '.join(_rv_name(rv_idx) for rv_idx in self._eliminated))
            print(f'{prefix}Connections:')
            for i, connections in enumerate(self._connections):
                print(f'{prefix}{indent}rv {i}:', ', '.join(_rv_name(rv_idx) for rv_idx in sorted(connections)))
            return

        print(f'{prefix}Elimination order:')
        for rv_idx in self.eliminated:
            print(f'{prefix}{indent}{_rv_name(rv_idx)}')
        print(f'{prefix}Clusters:')
        for i, cluster in enumerate(self.clusters):
            print(f'{prefix}{indent}cluster {i}:', ', '.join(_rv_name(rv_idx) for rv_idx in sorted(cluster)))

    # The overloads use the same parameter names as the implementation, with the
    # narrower (int) signature first so that it is not shadowed by the float one.
    @overload
    def _fill_count(
            self,
            rv_index: int,
            fill_value: Callable[[int, int], int],
            result: Callable[[int], int],
    ) -> int:
        ...

    @overload
    def _fill_count(
            self,
            rv_index: int,
            fill_value: Callable[[int, int], float],
            result: Callable[[float], float],
    ) -> float:
        ...

    def _fill_count(
            self,
            rv_index: int,
            fill_value: Callable[[int, int], int | float],
            result: Callable[[int | float], int | float],
    ) -> int | float:
        """
        Supporting function to calculate the "fill" of a random variable.

        Every unordered pair of random variables connected to `rv_index` that
        are not connected to each other is a fill edge; `fill_value` is summed
        over those pairs and `result` is applied to the sum.

        Args:
            rv_index: the index of the rv to compute the fill.
            fill_value: compute the fill value of two indexed random variables.
            result: compute the result value as a function of the sum of fill values.

        Returns:
            `result(total)`, where `total` is the sum of `fill_value(rv1, rv2)`
            over all fill edges `(rv1, rv2)` (zero if there are none).
        """
        fill_sum = 0
        # A tuple gives a fixed order, so each unordered pair is visited once.
        connections: Tuple[int, ...] = tuple(self._connections[rv_index])
        for i, rv1 in enumerate(connections):
            test_connections: Set[int] = self._connections[rv1]
            for rv2 in connections[i + 1:]:
                if rv2 not in test_connections:
                    fill_sum += fill_value(rv1, rv2)
        return result(fill_sum)

    # ==============================================================
    #  The following are functions to supply to `self._fill_count`.
    # ==============================================================

    @staticmethod
    def _add_one(_1: int, _2: int) -> int:
        # Every fill edge counts as one.
        return 1

    def _add_sum_log2_states(self, rv1: int, rv2: int) -> float:
        # Fill edge weight: sum of the log sizes of the two rvs.
        return self._rv_log_sizes[rv1] + self._rv_log_sizes[rv2]

    def _add_mul_log2_states(self, rv1: int, rv2: int) -> float:
        # Fill edge weight: product of the log sizes of the two rvs.
        return self._rv_log_sizes[rv1] * self._rv_log_sizes[rv2]

    @staticmethod
    def _identity(result: int) -> int:
        return result

    @staticmethod
    def _divide_2(result: float) -> float:
        return result / 2.0

    @staticmethod
    def _divide_4(result: float) -> float:
        return result / 4.0
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
# standard greedy algorithms
|
|
525
|
+
|
|
526
|
+
ClusterAlgorithm = Callable[[PGM], Clusters]
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def min_degree(pgm: PGM) -> Clusters:
    """
    Cluster a PGM by greedy variable elimination, each step
    minimising `Clusters.degree`.

    Args:
        pgm: the PGM to cluster.

    Returns:
        the Clusters object after `ve_greedy_min` has eliminated its random variables.
    """
    result = Clusters(pgm)
    ve_greedy_min(result, result.degree)
    return result
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
def min_fill(pgm: PGM) -> Clusters:
    """
    Cluster a PGM by greedy variable elimination, each step
    minimising `Clusters.fill`.

    Args:
        pgm: the PGM to cluster.

    Returns:
        the Clusters object after `ve_greedy_min` has eliminated its random variables.
    """
    result = Clusters(pgm)
    ve_greedy_min(result, result.fill)
    return result
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def min_degree_then_fill(pgm: PGM) -> Clusters:
    """
    Cluster a PGM by greedy variable elimination, each step
    minimising `Clusters.degree`, with ties broken by `Clusters.fill`.

    Args:
        pgm: the PGM to cluster.

    Returns:
        the Clusters object after `ve_greedy_min` has eliminated its random variables.
    """
    result = Clusters(pgm)
    objectives = (result.degree, result.fill)
    ve_greedy_min(result, objectives)
    return result
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
def min_fill_then_degree(pgm: PGM) -> Clusters:
    """
    Cluster a PGM by greedy variable elimination, each step
    minimising `Clusters.fill`, with ties broken by `Clusters.degree`.

    Args:
        pgm: the PGM to cluster.

    Returns:
        the Clusters object after `ve_greedy_min` has eliminated its random variables.
    """
    result = Clusters(pgm)
    objectives = (result.fill, result.degree)
    ve_greedy_min(result, objectives)
    return result
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
def min_weighted_degree(pgm: PGM) -> Clusters:
    """
    Cluster a PGM by greedy variable elimination, each step
    minimising `Clusters.weighted_degree`.

    Args:
        pgm: the PGM to cluster.

    Returns:
        the Clusters object after `ve_greedy_min` has eliminated its random variables.
    """
    result = Clusters(pgm)
    ve_greedy_min(result, result.weighted_degree)
    return result
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def min_weighted_fill(pgm: PGM) -> Clusters:
    """
    Cluster a PGM by greedy variable elimination, each step
    minimising `Clusters.weighted_fill`.

    Args:
        pgm: the PGM to cluster.

    Returns:
        the Clusters object after `ve_greedy_min` has eliminated its random variables.
    """
    result = Clusters(pgm)
    ve_greedy_min(result, result.weighted_fill)
    return result
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
def min_traditional_weighted_fill(pgm: PGM) -> Clusters:
    """
    Cluster a PGM by greedy variable elimination, each step
    minimising `Clusters.traditional_weighted_fill`.

    Args:
        pgm: the PGM to cluster.

    Returns:
        the Clusters object after `ve_greedy_min` has eliminated its random variables.
    """
    result = Clusters(pgm)
    ve_greedy_min(result, result.traditional_weighted_fill)
    return result
|