compiled-knowledge 4.0.0a20__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37525 -0
- ck/circuit/_circuit_cy.cpython-312-darwin.so +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-darwin.so +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from itertools import chain
|
|
5
|
+
from typing import List, Set, Callable, Sequence, Tuple
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from ck.pgm import PGM, Factor
|
|
10
|
+
from ck.pgm_compiler.support.clusters import Clusters, min_degree, min_fill, \
|
|
11
|
+
min_degree_then_fill, min_fill_then_degree, min_weighted_degree, min_weighted_fill, min_traditional_weighted_fill, \
|
|
12
|
+
ClusterAlgorithm
|
|
13
|
+
from ck.utils.np_extras import NDArrayFloat64
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class JoinTree:
    """
    A node of a join-tree.

    The structure is recursive: the whole join-tree is represented by its
    root node, and every node holds the JoinTree objects of its children.
    """

    # The PGM that this join tree is for.
    pgm: PGM

    # Indexes of the random variables in this join tree node.
    cluster: Set[int]

    # Child nodes of this node in the join tree.
    children: List[JoinTree]

    # Factors of the PGM that are allocated to this join tree node.
    factors: List[Factor]

    # Indexes of the random variables that are in both this cluster and the
    # parent's cluster (empty if this is the root of the spanning tree).
    separator: Set[int]

    def max_cluster_size(self) -> int:
        """
        Returns:
            the largest `len(cluster)` found in this node or any of its
            descendants.
        """
        largest: int = len(self.cluster)
        for child in self.children:
            child_largest: int = child.max_cluster_size()
            if child_largest > largest:
                largest = child_largest
        return largest

    def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
        """
        Calculate the maximum weighted cluster size of this node and its descendants.

        The weighted size of a cluster is the sum of `rv_log_sizes[rv_idx]`
        over the random variable indexes in the cluster.

        Args:
            rv_log_sizes: is an array of random variable sizes, such that
                for a random variable `rv`, `rv_log_sizes[rv.idx] = log2(len(rv))`.

        Returns:
            the largest weighted cluster size found in this node or any of
            its descendants.
        """
        heaviest: float = sum(rv_log_sizes[rv_idx] for rv_idx in self.cluster)
        for child in self.children:
            child_heaviest: float = child.max_cluster_weighted_size(rv_log_sizes)
            if child_heaviest > heaviest:
                heaviest = child_heaviest
        return heaviest

    def dump(self, *, prefix: str = '', indent: str = ' ', show_factors: bool = True) -> None:
        """
        Print a dump of the Join Tree.
        This is intended for debugging and demonstration purposes.

        Each cluster is printed as: {separator rvs} | {non-separator rvs}.
        Children are printed after their parent, with `indent` appended to the prefix.

        Args:
            prefix: optional prefix for indenting all lines.
            indent: additional prefix to use for extra indentation.
            show_factors: if true, the factors of each cluster are shown.
        """

        def quoted(rv_index: int) -> str:
            # The quoted string form of the random variable with the given index.
            return repr(str(self.pgm.rvs[rv_index]))

        separator_part: str = ' '.join(map(quoted, sorted(self.separator)))
        remainder_part: str = ' '.join(
            quoted(rv_index)
            for rv_index in sorted(self.cluster)
            if rv_index not in self.separator
        )
        if separator_part:
            # Keep a space between the separator rvs and the '|' marker.
            separator_part += ' '
        print(f'{prefix}{separator_part}| {remainder_part} (factors: {len(self.factors)})')

        if show_factors:
            for factor in self.factors:
                print(f'{prefix}factor{factor}')

        child_prefix: str = prefix + indent
        for child in self.children:
            child.dump(prefix=child_prefix, indent=indent, show_factors=show_factors)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Type for a join tree algorithm: PGM -> JoinTree.
# That is, a callable that takes a single PGM argument and returns a JoinTree.
JoinTreeAlgorithm = Callable[[PGM], JoinTree]
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _join_tree_algorithm(pgm_to_clusters: ClusterAlgorithm) -> JoinTreeAlgorithm:
|
|
95
|
+
"""
|
|
96
|
+
Helper function for creating a standard JoinTreeAlgorithm
|
|
97
|
+
from a ClusterAlgorithm.
|
|
98
|
+
|
|
99
|
+
Args:
|
|
100
|
+
pgm_to_clusters: The clusters method to use.
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
a JoinTreeAlgorithm.
|
|
104
|
+
"""
|
|
105
|
+
|
|
106
|
+
def __join_tree_algorithm(pgm: PGM) -> JoinTree:
|
|
107
|
+
clusters: Clusters = pgm_to_clusters(pgm)
|
|
108
|
+
return clusters_to_join_tree(clusters)
|
|
109
|
+
|
|
110
|
+
return __join_tree_algorithm
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# Standard JoinTreeAlgorithms.
# Each one is made by passing the cluster algorithm of the corresponding
# lower-case name (imported from `ck.pgm_compiler.support.clusters`)
# to `_join_tree_algorithm`.

MIN_DEGREE: JoinTreeAlgorithm = _join_tree_algorithm(min_degree)
MIN_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_fill)
MIN_DEGREE_THEN_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_degree_then_fill)
MIN_FILL_THEN_DEGREE: JoinTreeAlgorithm = _join_tree_algorithm(min_fill_then_degree)
MIN_WEIGHTED_DEGREE: JoinTreeAlgorithm = _join_tree_algorithm(min_weighted_degree)
MIN_WEIGHTED_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_weighted_fill)
MIN_TRADITIONAL_WEIGHTED_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_traditional_weighted_fill)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def clusters_to_join_tree(clusters: Clusters) -> JoinTree:
    """
    Construct a join tree from the given random variable clusters.

    A join tree is formed by finding a minimum spanning tree over the clusters.
    The cost between a pair of clusters is the negated number of random
    variables they have in common, plus a tie-breaking term derived from the
    log sizes of the shared random variables. The tie-breaking terms are scaled
    so that their total is less than 1, so they only decide between separators
    with the same number of random variables.

    Each factor of the PGM is allocated to the first of the smallest clusters
    that contains all the random variables of the factor. A factor that is not
    contained in any cluster is not allocated to any join tree node.

    Args:
        clusters: the clusters that resulted from graph clusters of a PGM.

    Returns:
        a JoinTree.
    """
    pgm: PGM = clusters.pgm
    cluster_sets: List[Set[int]] = clusters.clusters
    number_of_clusters = len(cluster_sets)

    # Dealing with these cases directly simplifies
    # the spanning tree algorithm implementation.
    if number_of_clusters == 0:
        return JoinTree(pgm, set(), [], [], set())
    if number_of_clusters == 1:
        return JoinTree(pgm, cluster_sets[0], [], list(pgm.factors), set())

    # Calculate inter-cluster costs for determining the minimum spanning tree.
    cost: NDArrayFloat64 = np.zeros((number_of_clusters, number_of_clusters), dtype=np.float64)

    # We will use separator state space size to break ties.
    rv_log_sizes = pgm.rv_log_sizes
    max_raw_break_cost = sum(rv_log_sizes) * 1.1  # sum of break costs must be < 1
    if max_raw_break_cost > 0:
        break_cost = [log_size / max_raw_break_cost for log_size in rv_log_sizes]
    else:
        # Every log size is zero (e.g. every rv has a single state), so there is
        # nothing to break ties with; this also avoids a division by zero.
        break_cost = [0.0 for _ in rv_log_sizes]

    for i, cluster_i in enumerate(cluster_sets):
        for j in range(i + 1, number_of_clusters):
            separator = cluster_i.intersection(cluster_sets[j])
            cost[i, j] = cost[j, i] = -len(separator) + sum(break_cost[rv_idx] for rv_idx in separator)

    # Make the spanning tree over the clusters.
    root_cluster_index: int
    children: List[List[int]]
    children, root_cluster_index = _make_spanning_tree_small_root(cost, cluster_sets)

    # Allocate each PGM factor to a cluster.
    cluster_factors: List[List[Factor]] = [[] for _ in range(number_of_clusters)]
    # Stable sort from smallest to largest cluster, so a factor goes to the
    # first of the smallest clusters that covers it.
    ordered_indexed_clusters = sorted(enumerate(cluster_sets), key=lambda idx_c: len(idx_c[1]))
    for factor in pgm.factors:
        rv_indexes = frozenset(rv.idx for rv in factor.rvs)
        for cluster_index, cluster in ordered_indexed_clusters:
            if rv_indexes.issubset(cluster):
                cluster_factors[cluster_index].append(factor)
                break

    return _form_join_tree_r(pgm, root_cluster_index, set(), children, cluster_sets, cluster_factors)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
# Positive infinity: compares greater than any finite cost.
_INF = float('inf')
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _make_spanning_tree_small_root(cost: NDArrayFloat64, clusters: List[Set[int]]) -> Tuple[List[List[int]], int]:
    """
    Construct a minimum spanning tree over the clusters, where the root is the cluster with
    the smallest number of random variables. If several clusters share the smallest size,
    the one with the lowest index is the root.

    Args:
        cost: is an N x N matrix of costs between N clusters.
        clusters: is a list of N clusters, each cluster is a set of random variable indices.

    Returns:
        (spanning_tree, root_index)

        spanning_tree: is a spanning tree represented as a list of nodes, the list is coindexed with
            the given cost matrix, each node is a list of children, each child being
            represented as an index into the list of nodes.

        root_index: is the index of the chosen root of the spanning tree.
    """
    root_custer_index: int = 0
    root_size: int = len(clusters[root_custer_index])
    for i, cluster in enumerate(clusters[1:], start=1):
        # Compare the candidate cluster (not the current root) with the best size
        # so far. Comparing the current root with its own size is never true,
        # which would leave the root stuck at index 0.
        if len(cluster) < root_size:
            root_custer_index = i
            root_size = len(cluster)

    children: List[List[int]] = _make_spanning_tree_at_root(cost, root_custer_index)
    return children, root_custer_index
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def _make_spanning_tree_arbitrary_root(cost: NDArrayFloat64) -> Tuple[List[List[int]], int]:
    """
    Construct a minimum spanning tree over the clusters, without choosing a particular root.
    The root used is always the cluster at index 0.

    Args:
        cost: is an N x N matrix of costs between N clusters.

    Returns:
        (spanning_tree, root_index)

        spanning_tree: is a spanning tree represented as a list of nodes, the list is coindexed with
            the given cost matrix, each node is a list of children, each child being
            represented as an index into the list of nodes.

        root_index: is the index of the chosen root of the spanning tree.
    """
    first_cluster: int = 0
    return _make_spanning_tree_at_root(cost, first_cluster), first_cluster
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _make_spanning_tree_at_root(
|
|
234
|
+
cost: NDArrayFloat64,
|
|
235
|
+
root_custer_index: int,
|
|
236
|
+
) -> List[List[int]]:
|
|
237
|
+
"""
|
|
238
|
+
Construct a minimum spanning tree over the clusters, starting at the given root.
|
|
239
|
+
|
|
240
|
+
Args:
|
|
241
|
+
cost: and nxn matrix where n is the number of clusters and cost[i, j]
|
|
242
|
+
gives the cost between clusters i and j.
|
|
243
|
+
root_custer_index: a nominated root cluster to be the root of the tree.
|
|
244
|
+
|
|
245
|
+
Returns:
|
|
246
|
+
a spanning tree represented as a list of nodes, the list is coindexed with
|
|
247
|
+
the given cost matrix, each node is a list of children, each child being
|
|
248
|
+
represented as an index into the list of nodes. The root node is the
|
|
249
|
+
index `root_custer_index` as passed to this function.
|
|
250
|
+
"""
|
|
251
|
+
number_of_clusters: int = cost.shape[0]
|
|
252
|
+
|
|
253
|
+
# clusters left to process.
|
|
254
|
+
remaining: List[int] = list(range(number_of_clusters))
|
|
255
|
+
|
|
256
|
+
# clusters that have been processed.
|
|
257
|
+
included: List[int] = []
|
|
258
|
+
|
|
259
|
+
def remove_remaining(_remaining_index: int) -> None:
|
|
260
|
+
# Remove the `remaining` element at the given index location.
|
|
261
|
+
remaining[_remaining_index] = remaining[-1]
|
|
262
|
+
remaining.pop()
|
|
263
|
+
|
|
264
|
+
# Move root from `remaining` to `included`
|
|
265
|
+
included.append(root_custer_index)
|
|
266
|
+
remove_remaining(root_custer_index) # assumes remaining[root_custer_index] = root_custer_index
|
|
267
|
+
|
|
268
|
+
# Data structure to collect the results.
|
|
269
|
+
children: List[List[int]] = [[] for _ in range(number_of_clusters)]
|
|
270
|
+
|
|
271
|
+
while True:
|
|
272
|
+
min_i: int = 0
|
|
273
|
+
min_j: int = 0
|
|
274
|
+
min_j_pos: int = 0
|
|
275
|
+
min_c: float = _INF
|
|
276
|
+
for i in included:
|
|
277
|
+
for j_pos, j in enumerate(remaining):
|
|
278
|
+
c: float = cost.item(i, j)
|
|
279
|
+
if c < min_c:
|
|
280
|
+
min_c = c
|
|
281
|
+
min_i = i
|
|
282
|
+
min_j = j
|
|
283
|
+
min_j_pos = j_pos
|
|
284
|
+
|
|
285
|
+
# Record the child and move remaining_idx from 'remaining' to 'included'.
|
|
286
|
+
children[min_i].append(min_j)
|
|
287
|
+
if len(remaining) == 1:
|
|
288
|
+
# That was the last one.
|
|
289
|
+
return children
|
|
290
|
+
|
|
291
|
+
# Update `remaining` and `included`
|
|
292
|
+
remove_remaining(min_j_pos)
|
|
293
|
+
included.append(min_j)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _form_join_tree_r(
        pgm: PGM,
        cluster_index: int,
        parent_cluster: Set[int],
        children: Sequence[List[int]],
        clusters: Sequence[Set[int]],
        cluster_factors: List[List[Factor]],
) -> JoinTree:
    """
    Recursively build a JoinTree from the spanning tree `children`.

    The node for `cluster_index` is assembled from the matching entries of
    `clusters` and `cluster_factors`; its child nodes are built by recursing
    over `children[cluster_index]`, and its separator is the intersection of
    its cluster with `parent_cluster`.

    Args:
        pgm: the source PGM for the join tree.
        cluster_index: index for the node we are processing (current root). This
            indexes into `children`, `clusters`, and `cluster_factors`.
        parent_cluster: set of random variable indices in the parent cluster.
        children: list of spanning tree nodes, as per `_make_spanning_tree_at_root` result.
        clusters: list of clusters, each cluster is a set of random variable indices.
        cluster_factors: assignment of factors to clusters.

    Returns:
        the JoinTree node for `cluster_index`, including all of its descendants.
    """
    this_cluster: Set[int] = clusters[cluster_index]
    this_factors: List[Factor] = cluster_factors[cluster_index]

    # This node's cluster plays the role of the parent cluster for each child.
    child_nodes: List[JoinTree] = []
    for child_index in children[cluster_index]:
        child_nodes.append(
            _form_join_tree_r(pgm, child_index, this_cluster, children, clusters, cluster_factors)
        )

    return JoinTree(
        pgm,
        this_cluster,
        child_nodes,
        this_factors,
        parent_cluster.intersection(this_cluster),
    )
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from types import ModuleType
|
|
2
|
+
from typing import Tuple
|
|
3
|
+
|
|
4
|
+
from ck.pgm import PGM
|
|
5
|
+
from ck.pgm_circuit import PGMCircuit
|
|
6
|
+
from ck.pgm_compiler import PGMCompiler
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_compiler(module: ModuleType, **kwargs) -> Tuple[PGMCompiler]:
    """
    Helper function to create a named PGM compiler.

    The created compiler calls `module.compile_pgm`, forwarding the PGM, the
    `const_parameters` flag, and the keyword arguments captured here.

    Args:
        module: module containing `compile_pgm` function.
        kwargs: are additional keyword arguments to `compile_pgm`.

    Returns:
        a singleton tuple containing PGMCompiler function.
    """

    def compiler(pgm: PGM, const_parameters: bool = True) -> PGMCircuit:
        """Conforms to the `PGMCompiler` protocol."""
        # `compile_pgm` is looked up on the module at call time, not when the compiler is made.
        return module.compile_pgm(pgm, const_parameters=const_parameters, **kwargs)

    singleton: Tuple[PGMCompiler] = (compiler,)
    return singleton
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_compiler_algorithm(module: ModuleType, algorithm: str, **kwargs) -> Tuple[PGMCompiler]:
    """
    Helper function to create a named PGM compiler, with a named algorithm argument.

    The attribute of `module` called `algorithm` is looked up and passed to
    `get_compiler` as the `algorithm` keyword argument.

    Args:
        module: module containing `compile_pgm` function.
        algorithm: name of the algorithm, to pass as keyword argument to `compile_pgm`.
            The algorithm should be declared in the module.
        kwargs: are additional keyword arguments to `compile_pgm`.

    Returns:
        a singleton tuple containing PGMCompiler function.
    """
    selected_algorithm = getattr(module, algorithm)
    return get_compiler(module, algorithm=selected_algorithm, **kwargs)
|
|
42
|
+
|
|
43
|
+
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import List, Sequence
|
|
4
|
+
|
|
5
|
+
from ck.circuit import CircuitNode
|
|
6
|
+
from ck.pgm import PGM
|
|
7
|
+
from ck.pgm_circuit import PGMCircuit
|
|
8
|
+
from ck.pgm_compiler.support import clusters
|
|
9
|
+
from ck.pgm_compiler.support.circuit_table import CircuitTable, product, sum_out
|
|
10
|
+
from ck.pgm_compiler.support.clusters import ClusterAlgorithm
|
|
11
|
+
from ck.pgm_compiler.support.factor_tables import make_factor_tables, FactorTables
|
|
12
|
+
|
|
13
|
+
# Standard cluster algorithms.
# Each name refers to the function of the corresponding lower-case name in
# `ck.pgm_compiler.support.clusters`. `MIN_FILL_THEN_DEGREE` is the default
# value of the `algorithm` argument of `compile_pgm`.
MIN_DEGREE: ClusterAlgorithm = clusters.min_degree
MIN_FILL: ClusterAlgorithm = clusters.min_fill
MIN_DEGREE_THEN_FILL: ClusterAlgorithm = clusters.min_degree_then_fill
MIN_FILL_THEN_DEGREE: ClusterAlgorithm = clusters.min_fill_then_degree
MIN_WEIGHTED_DEGREE: ClusterAlgorithm = clusters.min_weighted_degree
MIN_WEIGHTED_FILL: ClusterAlgorithm = clusters.min_weighted_fill
MIN_TRADITIONAL_WEIGHTED_FILL: ClusterAlgorithm = clusters.min_traditional_weighted_fill
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def compile_pgm(
        pgm: PGM,
        const_parameters: bool = True,
        *,
        algorithm: ClusterAlgorithm = MIN_FILL_THEN_DEGREE,
        pre_prune_factor_tables: bool = False,
) -> PGMCircuit:
    """
    Compile the PGM to an arithmetic circuit, using variable elimination.

    Conforms to the `PGMCompiler` protocol.

    Random variables are eliminated one at a time, in the order given by
    `algorithm`. To eliminate a random variable, all the tables that mention
    it are multiplied together (always multiplying the two smallest tables
    first) and the random variable is then summed out of the product.

    Args:
        pgm: The PGM to compile.
        const_parameters: If true, the potential function parameters will be circuit
            constants, otherwise they will be circuit variables.
        algorithm: algorithm to get an elimination order.
        pre_prune_factor_tables: if true, then heuristics will be used to remove any provably zero row.

    Returns:
        a PGMCircuit object.
    """
    factor_tables: FactorTables = make_factor_tables(
        pgm=pgm,
        const_parameters=const_parameters,
        multiply_indicators=True,
        pre_prune_factor_tables=pre_prune_factor_tables,
    )

    elimination_order: Sequence[int] = algorithm(pgm).eliminated

    # Eliminate rvs from the factor tables according to the elimination order.
    live_tables: List[CircuitTable] = list(factor_tables.tables)
    for rv_idx in elimination_order:
        # Partition the live tables by whether or not they mention the rv.
        mentioning: List[CircuitTable] = []
        untouched: List[CircuitTable] = []
        for table in live_tables:
            (mentioning if rv_idx in table.rv_idxs else untouched).append(table)

        if mentioning:
            while len(mentioning) > 1:
                # Sort largest first, so the two smallest tables are popped from the end.
                mentioning.sort(key=len, reverse=True)
                smallest = mentioning.pop()
                next_smallest = mentioning.pop()
                mentioning.append(product(smallest, next_smallest))
            untouched.append(sum_out(mentioning[0], (rv_idx,)))

        live_tables = untouched

    # All rvs are now eliminated - all tables should have a single top.
    tops: List[CircuitNode] = [table.top() for table in live_tables]
    top: CircuitNode = factor_tables.circuit.optimised_mul(tops)
    top.circuit.remove_unreachable_op_nodes(top)

    return PGMCircuit(
        rvs=tuple(pgm.rvs),
        conditions=(),
        circuit_top=top,
        number_of_indicators=factor_tables.number_of_indicators,
        number_of_parameters=factor_tables.number_of_parameters,
        slot_map=factor_tables.slot_map,
        parameter_values=factor_tables.parameter_values,
    )
|
|
File without changes
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from typing import Sequence, Iterable, Tuple, Dict, List
|
|
2
|
+
|
|
3
|
+
from ck.pgm import RandomVariable, Indicator, Instance
|
|
4
|
+
from ck.probability.probability_space import ProbabilitySpace, Condition, check_condition
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class EmpiricalProbabilitySpace(ProbabilitySpace):
    def __init__(self, rvs: Sequence[RandomVariable], samples: Iterable[Instance]):
        """
        Enable probabilistic queries over a sample from a sample space.
        Note that this is not necessarily an efficient approach to calculating probabilities and statistics.

        The samples are copied into a list when the object is constructed.

        Assumes:
            len(sample) == len(rvs), for each sample in samples.
            0 <= sample[i] < len(rvs[i]), for each sample in samples, for i in range(len(rvs)).

        Args:
            rvs: The random variables.
            samples: instances (state indexes) that are samples from the given rvs.
        """
        self._rvs: Sequence[RandomVariable] = tuple(rvs)
        self._samples: List[Instance] = list(samples)

        # Maps the `idx` of each random variable to its position within a sample.
        self._rv_idx_to_sample_idx: Dict[int, int] = {}
        for position, rv in enumerate(self._rvs):
            self._rv_idx_to_sample_idx[rv.idx] = position

    @property
    def rvs(self) -> Sequence[RandomVariable]:
        return self._rvs

    def wmc(self, *condition: Condition) -> float:
        """
        Count the samples that are consistent with the given condition.

        A sample is consistent if, for every random variable mentioned by the
        condition, the sample's state for that random variable is one of the
        states given by the condition's indicators for that random variable.
        """
        indicators: Tuple[Indicator, ...] = check_condition(condition)

        # For each sample position, collect the states that the condition permits.
        permitted = [set() for _ in self._rvs]
        for ind in indicators:
            permitted[self._rv_idx_to_sample_idx[ind.rv_idx]].add(ind.state_idx)

        # Convert permitted states to excluded states. A random variable that
        # the condition does not mention has no excluded states.
        excluded = [
            set(range(len(rv))).difference(states) if len(states) > 0 else states
            for rv, states in zip(self._rvs, permitted)
        ]

        count = 0
        for instance in self._samples:
            if not any((state in banned) for state, banned in zip(instance, excluded)):
                count += 1
        return count

    @property
    def z(self) -> float:
        # The total number of samples.
        return len(self._samples)
|
|
50
|
+
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from typing import Sequence, Iterable, Tuple, Dict, List
|
|
2
|
+
|
|
3
|
+
from ck.pgm import RandomVariable, Indicator, Instance, PGM
|
|
4
|
+
from ck.probability.probability_space import ProbabilitySpace, Condition, check_condition
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PGMProbabilitySpace(ProbabilitySpace):
    def __init__(self, pgm: PGM):
        """
        Enable probabilistic queries directly on a PGM.
        Note that this is not necessarily an efficient approach to calculating probabilities and statistics.

        Args:
            pgm: The PGM to query.
        """
        self._pgm = pgm

        # Cached value of `z`; None until `z` is first requested.
        self._z = None

    @property
    def rvs(self) -> Sequence[RandomVariable]:
        return self._pgm.rvs

    def wmc(self, *condition: Condition) -> float:
        # Delegate to the PGM, passing the checked condition's indicators.
        indicators: Tuple[Indicator, ...] = check_condition(condition)
        return self._pgm.value_product_indicators(*indicators)

    @property
    def z(self) -> float:
        # Computed on first access by calling the PGM with no indicators, then reused.
        cached = self._z
        if cached is None:
            cached = self._z = self._pgm.value_product_indicators()
        return cached
|
|
32
|
+
|