compiled-knowledge 4.0.0a20__cp312-cp312-macosx_10_13_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37525 -0
- ck/circuit/_circuit_cy.cpython-312-darwin.so +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-darwin.so +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from itertools import islice
|
|
5
|
+
from typing import Iterator, Optional, FrozenSet
|
|
6
|
+
|
|
7
|
+
from ck.circuit import CircuitNode
|
|
8
|
+
from ck.pgm_circuit import PGMCircuit
|
|
9
|
+
from ck.pgm_compiler.support.circuit_table import CircuitTable, product, sum_out
|
|
10
|
+
from ck.pgm_compiler.support.factor_tables import make_factor_tables, FactorTables
|
|
11
|
+
from ck.pgm_compiler.support.join_tree import *
|
|
12
|
+
|
|
13
|
+
# Negative infinity: the starting "best score" when searching product trees,
# so that the first tree with a finite score is always accepted.
_NEG_INF = float('-inf')
|
|
14
|
+
|
|
15
|
+
# Default value for the `limit_product_tree_search` parameters in this module,
# i.e. the default cap on how many candidate product trees are examined.
DEFAULT_PRODUCT_SEARCH_LIMIT: int = 1000
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def compile_pgm(
        pgm: PGM,
        const_parameters: bool = True,
        *,
        algorithm: JoinTreeAlgorithm = MIN_FILL_THEN_DEGREE,
        limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
        pre_prune_factor_tables: bool = False,
) -> PGMCircuit:
    """
    Compile a PGM into an arithmetic circuit by factor elimination.

    A join tree is obtained for the PGM using `algorithm`, and that join tree
    is then turned into a circuit by `join_tree_to_circuit`.

    Within each join tree node the factors must be multiplied together. The
    candidate binary trees for forming that product are searched, examining
    at most `limit_product_tree_search` of them (the minimum permitted is 1).

    Conforms to the `PGMCompiler` protocol.

    Args:
        pgm: The PGM to compile.
        const_parameters: If true, the potential function parameters will be circuit
            constants, otherwise they will be circuit variables.
        algorithm: the algorithm used to get a join tree from the PGM.
        limit_product_tree_search: limit on the number of product trees to consider.
        pre_prune_factor_tables: if true, then heuristics will be used to remove
            any provably zero row.

    Returns:
        a PGMCircuit object.

    Raises:
        ValueError if `limit_product_tree_search` is not > 0.
    """
    return join_tree_to_circuit(
        algorithm(pgm),
        const_parameters=const_parameters,
        limit_product_tree_search=limit_product_tree_search,
        pre_prune_factor_tables=pre_prune_factor_tables,
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def compile_pgm_best_jointree(
        pgm: PGM,
        const_parameters: bool = True,
        *,
        limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
        pre_prune_factor_tables: bool = False,
) -> PGMCircuit:
    """
    Compile a PGM using the best join tree found by several elimination heuristics.

    Each heuristic produces a cluster sequence; the one whose maximum weighted
    cluster size is smallest is converted to a join tree and compiled.
    When several heuristics tie, the earliest one in the list is used.

    Conforms to the `PGMCompiler` protocol.

    Args:
        pgm: The PGM to compile.
        const_parameters: If true, the potential function parameters will be circuit
            constants, otherwise they will be circuit variables.
        limit_product_tree_search: limit on the number of product trees to consider.
        pre_prune_factor_tables: if true, then heuristics will be used to remove
            any provably zero row.

    Returns:
        a PGMCircuit object.

    Raises:
        ValueError if `limit_product_tree_search` is not > 0.
    """
    # The elimination heuristics to try, in order of preference on ties.
    candidate_algorithms: Sequence[ClusterAlgorithm] = (
        min_degree,
        min_fill,
        min_degree_then_fill,
        min_fill_then_degree,
        min_weighted_degree,
        min_weighted_fill,
        min_traditional_weighted_fill,
    )
    rv_log_sizes: Sequence[float] = pgm.rv_log_sizes

    # `min` keeps the first candidate among equals, and the generator runs each
    # heuristic (and sizes its result) one at a time, in list order.
    smallest_clusters: Clusters = min(
        (make_clusters(pgm) for make_clusters in candidate_algorithms),
        key=lambda candidate: candidate.max_cluster_weighted_size(rv_log_sizes),
    )

    return join_tree_to_circuit(
        clusters_to_join_tree(smallest_clusters),
        const_parameters,
        limit_product_tree_search,
        pre_prune_factor_tables,
    )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def join_tree_to_circuit(
        join_tree: JoinTree,
        const_parameters: bool = True,
        limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
        pre_prune_factor_tables: bool = False,
) -> PGMCircuit:
    """
    Build a PGMCircuit from a join tree.

    Args:
        join_tree: a join tree for a PGM.
        const_parameters: If true, the potential function parameters will be circuit
            constants, otherwise they will be circuit variables.
        limit_product_tree_search: limit on the number of product trees to consider.
        pre_prune_factor_tables: if true, then heuristics will be used to remove
            any provably zero row.

    Returns:
        an arithmetic circuit and slot map, as a PGMCircuit object.

    Raises:
        ValueError if `limit_product_tree_search` is not > 0.
    """
    if limit_product_tree_search <= 0:
        raise ValueError('limit_product_tree_search must be > 0')

    model: PGM = join_tree.pgm

    # One circuit table per PGM factor, with indicators already multiplied in.
    tables: FactorTables = make_factor_tables(
        pgm=model,
        const_parameters=const_parameters,
        multiply_indicators=True,
        pre_prune_factor_tables=pre_prune_factor_tables,
    )

    # Collapse the whole join tree to one table, then take its top circuit node.
    root_table: CircuitTable = _circuit_tables_from_join_tree(
        tables,
        join_tree,
        limit_product_tree_search,
    )
    circuit_top: CircuitNode = root_table.top()

    # Drop op nodes that do not contribute to the top node.
    circuit_top.circuit.remove_unreachable_op_nodes(circuit_top)

    return PGMCircuit(
        rvs=tuple(model.rvs),
        conditions=(),
        circuit_top=circuit_top,
        number_of_indicators=tables.number_of_indicators,
        number_of_parameters=tables.number_of_parameters,
        slot_map=tables.slot_map,
        parameter_values=tables.parameter_values,
    )
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def _circuit_tables_from_join_tree(
        factor_tables: FactorTables,
        join_tree: JoinTree,
        limit_product_tree_search: int,
) -> CircuitTable:
    """
    Construct a single circuit table for a join tree node (and everything below it).

    Algorithm synopsis:
    1) Collect a CircuitTable for each factor allocated to this join tree node, and
       one for each child of the node (by recursively calling this function).
    2) Arrange the collected circuit tables as the leaves of a binary tree.
    3) Perform the table products and sum-outs implied by that binary tree,
       leaving one circuit table projected onto the node's separator.

    Args:
        factor_tables: supplies the circuit and the circuit table of each PGM factor.
        join_tree: the join tree node to process.
        limit_product_tree_search: the maximum number of binary product trees to score.

    Returns:
        the circuit table for this join tree node.
    """
    # Tables for the PGM factors allocated to this node come first,
    # followed by the tables produced by this node's children.
    own_tables: List[CircuitTable] = [
        factor_tables.get_table(factor)
        for factor in join_tree.factors
    ]
    child_tables: List[CircuitTable] = [
        _circuit_tables_from_join_tree(factor_tables, child, limit_product_tree_search)
        for child in join_tree.children
    ]
    factors: List[CircuitTable] = own_tables + child_tables

    # Special case: there is nothing to multiply.
    if not factors:
        circuit = factor_tables.circuit
        if len(join_tree.separator) != 0:
            # table zero (no rows) over the separator rvs
            return CircuitTable(circuit, tuple(join_tree.separator), ())
        # table one (a single row, no rvs)
        return CircuitTable(circuit, (), (((), circuit.one),))

    # The usual join tree approach simply forms the product of all the tables in
    # `factors`. The tree width is not affected by the order of the products,
    # however some orders lead to fewer arithmetic operations.
    #
    # Different ways of combining the factors are analysed here, keeping the
    # highest scoring one. Trees that sum out rvs early score more highly.
    # The number of possible trees can grow as O(len(factors)!), which is why
    # only the first `limit_product_tree_search` trees are examined.
    rv_log_sizes: Sequence[float] = join_tree.pgm.rv_log_sizes
    candidates = islice(
        _iterate_trees(factors, join_tree.separator),
        limit_product_tree_search,
    )
    winner = None
    winning_score = _NEG_INF
    for candidate in candidates:
        candidate_score = candidate.score(rv_log_sizes)
        if candidate_score > winning_score:
            winner, winning_score = candidate, candidate_score

    # The tree knows how to form its products and perform its sum-outs.
    return winner.get_table()
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class _Product(ABC):
|
|
231
|
+
"""
|
|
232
|
+
A node in a binary product tree.
|
|
233
|
+
|
|
234
|
+
A node is either a _ProductLeaf, holding a single CircuitTable,
|
|
235
|
+
or is a _ProductInterior, which has exactly two children.
|
|
236
|
+
"""
|
|
237
|
+
|
|
238
|
+
def __init__(self, available: Set[int]):
|
|
239
|
+
"""
|
|
240
|
+
Construct a node in a binary product tree.
|
|
241
|
+
|
|
242
|
+
Args:
|
|
243
|
+
available: the rvs that are available (prior to sum-out)
|
|
244
|
+
after the product is formed.
|
|
245
|
+
"""
|
|
246
|
+
self.available: Set[int] = available
|
|
247
|
+
self.sum_out: Set[int] = set()
|
|
248
|
+
|
|
249
|
+
@abstractmethod
|
|
250
|
+
def set_sum_out(self, need: Set[int]) -> None:
|
|
251
|
+
"""
|
|
252
|
+
Set the self.sum_out, based on what rvs are needed.
|
|
253
|
+
|
|
254
|
+
Args:
|
|
255
|
+
need: what rvs are require to be supplied by this node
|
|
256
|
+
after the product is formed. This will be a subset
|
|
257
|
+
of `self.available`.
|
|
258
|
+
"""
|
|
259
|
+
...
|
|
260
|
+
|
|
261
|
+
@abstractmethod
|
|
262
|
+
def score(self, rv_log_sizes: Sequence[float]) -> float:
|
|
263
|
+
"""
|
|
264
|
+
Heuristically score a tree (assuming set_sum_out has been called).
|
|
265
|
+
"""
|
|
266
|
+
...
|
|
267
|
+
|
|
268
|
+
@abstractmethod
|
|
269
|
+
def get_table(self) -> CircuitTable:
|
|
270
|
+
"""
|
|
271
|
+
Returns:
|
|
272
|
+
The circuit table (after products and sum-outs) implied
|
|
273
|
+
by this node.
|
|
274
|
+
"""
|
|
275
|
+
...
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class _ProductLeaf(_Product):
    """
    A leaf of a binary product tree, holding a single CircuitTable.

    This is deliberately a plain class rather than a `@dataclass`. The class
    declares no dataclass fields and writes its own `__init__`, so the decorator
    would only generate a field-less `__eq__` (making every leaf compare equal
    to every other leaf) and, as a consequence, make leaves unhashable.
    """

    def __init__(self, table: CircuitTable):
        """
        Args:
            table: the circuit table held by this leaf. The rvs of the table
                (`table.rv_idxs`) are the rvs this leaf makes available.
        """
        super().__init__(set(table.rv_idxs))
        self.table: CircuitTable = table

    def set_sum_out(self, need: Set[int]) -> None:
        """
        Mark every available rv that is not in `need` to be summed out.
        """
        self.sum_out = self.available.difference(need)

    def score(self, rv_log_sizes: Sequence[float]) -> float:
        """
        Returns:
            the total of `rv_log_sizes` over the rvs summed out at this leaf.
        """
        return sum(rv_log_sizes[i] for i in self.sum_out)

    def get_table(self) -> CircuitTable:
        """
        Returns:
            the held table with the rvs of `self.sum_out` summed out.
        """
        return sum_out(self.table, self.sum_out)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
class _ProductInterior(_Product):
    """
    An interior node of a binary product tree: the product of two child nodes.

    This is deliberately a plain class rather than a `@dataclass`. The class
    declares no dataclass fields and writes its own `__init__`, so the decorator
    would only generate a field-less `__eq__` (making every interior node compare
    equal to every other one) and, as a consequence, make nodes unhashable.
    """

    def __init__(self, x: _Product, y: _Product):
        """
        Args:
            x: one operand of the product.
            y: the other operand of the product.
        """
        # The product can supply every rv that either operand can supply.
        super().__init__(x.available.union(y.available))
        self.x: _Product = x
        self.y: _Product = y

    def set_sum_out(self, need: Set[int]) -> None:
        """
        Set `sum_out` for this node and, recursively, for both children.

        A child must supply the rvs needed from this node, plus the rvs it
        shares with its sibling (those are required to form the product).
        Whatever the children supply that is not in `need` is summed out here.
        """
        x = self.x
        y = self.y
        x_y_common: Set[int] = x.available.intersection(y.available)
        x_need: Set[int] = x.available.intersection(chain(need, x_y_common))
        y_need: Set[int] = y.available.intersection(chain(need, x_y_common))
        self.x.set_sum_out(x_need)
        self.y.set_sum_out(y_need)
        self.sum_out = x_need.union(y_need).difference(need)

    def score(self, rv_log_sizes: Sequence[float]) -> float:
        """
        Returns:
            the total of `rv_log_sizes` over the rvs summed out at this node,
            plus twice the combined score of the children, so that sum-outs
            performed deeper in the tree (i.e. earlier) are weighted more heavily.
        """
        x_score = self.x.score(rv_log_sizes)
        y_score = self.y.score(rv_log_sizes)
        return sum(rv_log_sizes[i] for i in self.sum_out) + (x_score + y_score) * 2

    def get_table(self) -> CircuitTable:
        """
        Returns:
            the product of the children's tables with the rvs of
            `self.sum_out` summed out.
        """
        return sum_out(product(self.x.get_table(), self.y.get_table()), self.sum_out)
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def _iterate_trees(factors: List[CircuitTable], separator: Set[int]) -> Iterator[_Product]:
    """
    Generate the binary trees that form the product of the given factors.

    Every tree yielded has had `set_sum_out(separator)` called on its root,
    so it is ready to be scored or turned into a table.

    Args:
        factors: The list of factors to be in the product.
        separator: What rvs the resulting product needs to be projected onto.

    Returns:
        An iterator over binary product trees.

    Assumes:
        There is at least one factor.
    """
    # Each factor table becomes one leaf; the same leaves are used by every tree.
    leaf_nodes: List[_Product] = list(map(_ProductLeaf, factors))
    for product_tree in _iterate_trees_r(leaf_nodes):
        # Sum-outs can only be decided once the tree is complete.
        product_tree.set_sum_out(separator)
        yield product_tree
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def _iterate_trees_r(factors: List[_Product]) -> Iterator[_Product]:
    """
    Recursive support function for _iterate_trees.

    This will form the products, but will not set the `sum_out` field,
    as that can only be done once a tree is fully formed.

    Args:
        factors: The list of factors to be in the product.

    Returns:
        An iterator over binary product trees.

    Assumes:
        There is at least one factor.
    """

    # Use heuristics to reduce the number of arithmetic operations.
    # If the rvs of one factor are a subset of another factor's rvs, form their
    # product, preferring to product factors with small numbers of rvs.

    # Sort factors by number of rvs (in increasing order).
    #
    # The key must be the *size* of the rv set. Sorting on the frozenset itself
    # would compare with `<`, which for sets means "is a proper subset of" and
    # is only a partial order, so the result would not be ordered by size and
    # the forward-looking merge below could miss subsets.
    sorted_factors: List[Tuple[FrozenSet[int], Optional[_Product]]] = sorted(
        (
            (frozenset(factor.available), factor)
            for factor in factors
        ),
        key=lambda _x: len(_x[0])
    )

    # Product any factor whose rvs are a subset of another factor's rvs.
    # A subset is never larger than its superset, so after sorting by size
    # it is enough to look only at the factors that come later in the list.
    i: int
    j: int
    for i, (rvs_idxs, factor) in enumerate(sorted_factors):
        for j, (other_rvs_idxs, other_factor) in enumerate(sorted_factors[i + 1:], start=i + 1):
            if other_rvs_idxs.issuperset(rvs_idxs):
                # Fold factor i into factor j; the rvs of j are unchanged by this.
                sorted_factors[j] = (other_rvs_idxs, _ProductInterior(other_factor, factor))
                sorted_factors[i] = (rvs_idxs, None)
                break
    factors = [factor for _, factor in sorted_factors if factor is not None]

    if len(factors) == 1:
        yield factors[0]
    elif len(factors) == 2:
        yield _ProductInterior(*factors)
    else:
        # Try every unordered pair as the next product, then recurse on the rest.
        for i in range(len(factors)):
            for j in range(i):
                remaining: List[_Product] = factors.copy()
                # Pop the higher index first so that index j is still valid.
                x = remaining.pop(i)
                y = remaining.pop(j)
                remaining.append(_ProductInterior(x, y))
                yield from _iterate_trees_r(remaining)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
from ck.pgm import PGM
|
|
4
|
+
from ck.pgm_circuit import PGMCircuit
|
|
5
|
+
from ck.pgm_compiler import variable_elimination, factor_elimination, recursive_conditioning, ace
|
|
6
|
+
from .pgm_compiler import PGMCompiler
|
|
7
|
+
from .support.named_compiler_maker import get_compiler_algorithm as _get_compiler_algorithm, \
|
|
8
|
+
get_compiler as _get_compiler
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class NamedPGMCompiler(Enum):
    """
    A standard collection of named compiler functions.

    The `value` of each enum member is a tuple whose first element is a
    compiler function (PGM -> PGMCircuit). The function is wrapped in a tuple
    because a bare function assigned in an Enum body would not be kept as an
    ordinary member value.

    Each member is itself callable, conforming to the `PGMCompiler` protocol,
    and delegates to its compiler function.
    """
    # @formatter:off

    VE_MIN_DEGREE                    = _get_compiler_algorithm(variable_elimination, 'MIN_DEGREE')
    VE_MIN_DEGREE_THEN_FILL          = _get_compiler_algorithm(variable_elimination, 'MIN_DEGREE_THEN_FILL')
    VE_MIN_FILL                      = _get_compiler_algorithm(variable_elimination, 'MIN_FILL')
    VE_MIN_FILL_THEN_DEGREE          = _get_compiler_algorithm(variable_elimination, 'MIN_FILL_THEN_DEGREE')
    VE_MIN_WEIGHTED_DEGREE           = _get_compiler_algorithm(variable_elimination, 'MIN_WEIGHTED_DEGREE')
    VE_MIN_WEIGHTED_FILL             = _get_compiler_algorithm(variable_elimination, 'MIN_WEIGHTED_FILL')
    VE_MIN_TRADITIONAL_WEIGHTED_FILL = _get_compiler_algorithm(variable_elimination, 'MIN_TRADITIONAL_WEIGHTED_FILL')

    FE_MIN_DEGREE                    = _get_compiler_algorithm(factor_elimination, 'MIN_DEGREE')
    FE_MIN_DEGREE_THEN_FILL          = _get_compiler_algorithm(factor_elimination, 'MIN_DEGREE_THEN_FILL')
    FE_MIN_FILL                      = _get_compiler_algorithm(factor_elimination, 'MIN_FILL')
    FE_MIN_FILL_THEN_DEGREE          = _get_compiler_algorithm(factor_elimination, 'MIN_FILL_THEN_DEGREE')
    FE_MIN_WEIGHTED_DEGREE           = _get_compiler_algorithm(factor_elimination, 'MIN_WEIGHTED_DEGREE')
    FE_MIN_WEIGHTED_FILL             = _get_compiler_algorithm(factor_elimination, 'MIN_WEIGHTED_FILL')
    FE_MIN_TRADITIONAL_WEIGHTED_FILL = _get_compiler_algorithm(factor_elimination, 'MIN_TRADITIONAL_WEIGHTED_FILL')
    FE_BEST_JOINTREE                 = factor_elimination.compile_pgm_best_jointree,

    RC_MIN_DEGREE                    = _get_compiler_algorithm(recursive_conditioning, 'MIN_DEGREE')
    RC_MIN_DEGREE_THEN_FILL          = _get_compiler_algorithm(recursive_conditioning, 'MIN_DEGREE_THEN_FILL')
    RC_MIN_FILL                      = _get_compiler_algorithm(recursive_conditioning, 'MIN_FILL')
    RC_MIN_FILL_THEN_DEGREE          = _get_compiler_algorithm(recursive_conditioning, 'MIN_FILL_THEN_DEGREE')
    RC_MIN_WEIGHTED_DEGREE           = _get_compiler_algorithm(recursive_conditioning, 'MIN_WEIGHTED_DEGREE')
    RC_MIN_WEIGHTED_FILL             = _get_compiler_algorithm(recursive_conditioning, 'MIN_WEIGHTED_FILL')
    RC_MIN_TRADITIONAL_WEIGHTED_FILL = _get_compiler_algorithm(recursive_conditioning, 'MIN_TRADITIONAL_WEIGHTED_FILL')

    ACE                              = _get_compiler(ace)

    # @formatter:on

    def __call__(self, pgm: PGM, const_parameters: bool = True) -> PGMCircuit:
        """
        Compile `pgm` with the compiler function of this enum member.

        This implements the `PGMCompiler` protocol.

        Args:
            pgm: The PGM to compile.
            const_parameters: passed to the compiler function by keyword.

        Returns:
            whatever the compiler function returns (a PGMCircuit).
        """
        compile_function: PGMCompiler = self.compiler
        return compile_function(pgm, const_parameters=const_parameters)

    @property
    def compiler(self) -> PGMCompiler:
        """
        The compiler function of this member, i.e. the first element of `value`.
        """
        wrapped = self.value
        return wrapped[0]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# The named compiler designated as the default: factor elimination using
# `compile_pgm_best_jointree`.
DEFAULT_PGM_COMPILER: NamedPGMCompiler = NamedPGMCompiler.FE_BEST_JOINTREE
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from typing import Protocol
|
|
2
|
+
|
|
3
|
+
from ck.pgm import PGM
|
|
4
|
+
from ck.pgm_circuit import PGMCircuit
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PGMCompiler(Protocol):
    """
    Structural type for PGM compilers.

    Any callable with the signature of `__call__` below is a PGM compiler.
    """

    def __call__(self, pgm: PGM, *, const_parameters: bool = True) -> PGMCircuit:
        """
        Compile a PGM to an arithmetic circuit.

        Args:
            pgm: The PGM to compile.
            const_parameters: If true, the potential function parameters will be circuit
                constants, otherwise they will be circuit variables.

        Returns:
            a PGMCircuit which provides an arithmetic circuit to represent the PGM.
        """
        ...
|