compiled-knowledge 4.0.0a5__cp313-cp313-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +13 -0
- ck/circuit/circuit.c +38749 -0
- ck/circuit/circuit.cpython-313-darwin.so +0 -0
- ck/circuit/circuit_py.py +807 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +17373 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +96 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser.py +81 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53674 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +288 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3494 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +75 -0
- ck/pgm_circuit/program_with_slotmap.py +234 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +252 -0
- ck/pgm_compiler/factor_elimination.py +383 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +226 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +9 -0
- ck/pgm_compiler/support/circuit_table/circuit_table.c +16042 -0
- ck/pgm_compiler/support/circuit_table/circuit_table.cpython-313-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/circuit_table_py.py +269 -0
- ck/pgm_compiler/support/clusters.py +556 -0
- ck/pgm_compiler/support/factor_tables.py +398 -0
- ck/pgm_compiler/support/join_tree.py +275 -0
- ck/pgm_compiler/support/named_compiler_maker.py +33 -0
- ck/pgm_compiler/variable_elimination.py +89 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +47 -0
- ck/probability/probability_space.py +568 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +129 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +61 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +66 -0
- ck/sampling/wmc_direct_sampler.py +169 -0
- ck/sampling/wmc_gibbs_sampler.py +147 -0
- ck/sampling/wmc_metropolis_sampler.py +159 -0
- ck/sampling/wmc_rejection_sampler.py +113 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +153 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +44 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +50 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +50 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +88 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a5.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a5.dist-info/RECORD +167 -0
- compiled_knowledge-4.0.0a5.dist-info/WHEEL +5 -0
- compiled_knowledge-4.0.0a5.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a5.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Dict, Sequence, Tuple, List, Iterator, Set, Iterable, Optional, Callable
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from ck.circuit import Circuit, VarNode, CircuitNode
|
|
9
|
+
from ck.pgm import PGM, ParamId, Factor, PotentialFunction, RandomVariable, ZeroPotentialFunction
|
|
10
|
+
from ck.pgm_circuit.slot_map import SlotMap, SlotKey
|
|
11
|
+
from ck.pgm_compiler.support.circuit_table import CircuitTable, TableInstance
|
|
12
|
+
from ck.utils.iter_extras import pairs
|
|
13
|
+
from ck.utils.map_list import MapList
|
|
14
|
+
from ck.utils.np_extras import NDArray, NDArrayFloat64
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class FactorTables:
    """
    The result of `make_factor_tables`: a host circuit, the slot map from PGM
    indicators / parameter ids to circuit variables, and one circuit table per
    PGM factor.
    """

    # The circuit that hosts every node of every table.
    circuit: Circuit

    # How many indicator circuit variables there are (they occupy the first slots).
    number_of_indicators: int

    # How many parameter circuit variables there are (non-const, in-use parameters only).
    number_of_parameters: int

    # Maps an Indicator or ParamId object to the index of its circuit variable.
    slot_map: SlotMap

    # The circuit tables, one per PGM factor, positioned by `factor.idx`.
    tables: Sequence[CircuitTable]

    # PGM values of the parameter circuit variables: for a non-const, in-use
    # parameter `param_id`, its PGM value was
    # `parameter_values[slot_map[param_id] - number_of_indicators]`.
    parameter_values: NDArray

    def get_table(self, factor: Factor) -> CircuitTable:
        """
        Return the circuit table for the given PGM factor, located by `factor.idx`.
        """
        position: int = factor.idx
        return self.tables[position]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def make_factor_tables(
        pgm: PGM,
        const_parameters: bool,
        multiply_indicators: bool,
        pre_prune_factor_tables: bool,
) -> FactorTables:
    """
    Consistently and efficiently create circuit tables for factors of a PGM.

    Creates:
        * a circuit,
        * a circuit variable for each indicator of the PGM,
        * a circuit variable for each non-constant, in-use potential function parameter,
        * a circuit table for each factor of the PGM.

    The parameters of each potential function will be converted either to circuit
    constants (if `const_parameters` is true) or to circuit variables (if
    `const_parameters` is false). When parameters are constants, zero-valued rows
    are left out of the tables.

    If `multiply_indicators` is true, the indicators of each random variable will be
    multiplied into the table of one factor containing that random variable (the
    factor with the fewest rows).

    A slot map will be created that maps PGM indicators and parameter ids to circuit var indices.
    Specifically, a circuit var will be added for each indicator, in the order they appear
    in `pgm.indicators`. Circuit vars for parameter ids will be added after those for
    indicators, and only if `const_parameters` is false.

    Args:
        pgm: The PGM with the random variables, factors, and potential functions.
        const_parameters: if true, then potential function parameters will be circuit constants,
            otherwise they will be circuit variables, with entries in the returned slot map.
        multiply_indicators: if true then indicator variables will be multiplied into an
            acceptable factor.
        pre_prune_factor_tables: if true, then heuristics will be used to remove any provably
            zero row.

    Returns:
        FactorTables, holding a slot_map and a circuit table for each PGM factor.
    """

    # Create circuit and initialise the slot map with indicator variables.
    circuit = Circuit()
    slot_map: Dict[SlotKey, int] = {
        indicator: circuit.new_var().idx
        for indicator in pgm.indicators
    }

    # Get the circuit table rows for each potential function.
    # functions_rows[id(function)] = rows for the function
    functions_rows: Dict[int, _FunctionRows]
    if const_parameters:
        functions_rows = {
            id(function): _rows_for_function_const(function, circuit)
            for function in pgm.functions
        }
    else:
        functions_rows = {
            id(function): _rows_for_function_var(function, circuit, slot_map)
            for function in pgm.functions
        }

    # Link factors to function rows. Factors sharing a function share its rows;
    # use_count records how many factors do so (needed for copy-on-prune).
    # factor_rows[id(factor)] = rows for the factor
    factor_rows: Dict[int, _FactorRows] = {}
    for factor in pgm.factors:
        rows: _FunctionRows = functions_rows[id(factor.function)]
        rows.use_count += 1
        factor_rows[id(factor)] = _FactorRows(factor, rows)

    # Check to see if any factor rows can be pre-pruned.
    if pre_prune_factor_tables:
        _pre_prune_factor_tables(list(factor_rows.values()))

    # Allocate random variables to factors (for multiplying in indicators).
    factors_mul_rvs: MapList[int, RandomVariable]
    if multiply_indicators:
        def _factor_size(_factor: Factor) -> int:
            return len(factor_rows[id(_factor)])

        factors_mul_rvs = _assign_rvs_to_factors(pgm, _factor_size)
    else:
        factors_mul_rvs = MapList()  # no assignment of rvs to factors.

    # Make a circuit table for each factor, in `pgm.factors` order
    # (so `FactorTables.get_table` can look a table up by `factor.idx`).
    tables: List[CircuitTable] = [
        _make_factor_table(factor, circuit, slot_map, factor_rows[id(factor)], factors_mul_rvs)
        for factor in pgm.factors
    ]

    # Extract the parameter values (if they are circuit vars).
    # Parameter slots come after the indicator slots, hence the offset.
    number_of_indicators: int = pgm.number_of_indicators
    number_of_parameters: int = len(slot_map) - number_of_indicators
    parameter_values: NDArrayFloat64 = np.zeros(number_of_parameters, dtype=np.float64)
    if not const_parameters:
        for function in pgm.functions:
            for param_index, value in function.params:
                param_id: ParamId = function.param_id(param_index)
                slot: Optional[int] = slot_map.get(param_id)
                if slot is not None:
                    # Only parameters that received a circuit var (in-use) are recorded.
                    parameter_values[slot - number_of_indicators] = value

    return FactorTables(
        circuit=circuit,
        number_of_indicators=number_of_indicators,
        number_of_parameters=number_of_parameters,
        slot_map=slot_map,
        tables=tables,
        parameter_values=parameter_values,
    )
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _assign_rvs_to_factors(
        pgm: PGM,
        factor_size: Callable[[Factor], int],
) -> MapList[int, RandomVariable]:
    """
    Choose, for every random variable, the factor into which its indicators will
    be multiplied: the factor containing it with the smallest `factor_size`
    (the first such factor, in `pgm.factors` order, on ties).

    Random variables that appear in no factor are not assigned.

    Args:
        pgm: the PGM whose random variables and factors are considered.
        factor_size: the size measure used to rank candidate factors.

    Returns:
        a map from `id(factor)` to the list of random variables assigned to that factor.
    """
    # Index the factors by the random variables they mention (rv index -> factors).
    factors_by_rv: MapList[int, Factor] = MapList()
    for factor in pgm.factors:
        for rv in factor.rvs:
            factors_by_rv.append(rv.idx, factor)

    assignment: MapList[int, RandomVariable] = MapList()
    for rv_index, rv in enumerate(pgm.rvs):
        containing: Sequence[Factor] = factors_by_rv.get(rv_index, ())
        if len(containing) == 0:
            continue  # the rv is in no factor, so there is nowhere to put it
        chosen: Factor = min(containing, key=factor_size)
        assignment.append(id(chosen), rv)

    return assignment
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class _FunctionRows:
    """
    The rows of a potential function: a map from table instance to circuit node.

    One object may be shared by every factor using the same function;
    `use_count` records how many factors currently reference it, which lets
    `_FactorRows.prune` decide whether it may replace the rows directly or must
    take a private copy.
    """

    def __init__(self, rows: Dict[TableInstance, CircuitNode], use_count: int = 0):
        # Number of factors currently referencing these rows.
        self.use_count: int = use_count
        # instance (one value per factor random variable) -> circuit node for that row.
        self.rows: Dict[TableInstance, CircuitNode] = rows
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class _FactorRows:
    """
    The rows of one PGM factor, viewed through a (possibly shared) `_FunctionRows`.
    """

    def __init__(self, factor: Factor, rows: _FunctionRows):
        self.rows: _FunctionRows = rows
        # Indexes of the factor's random variables, in factor order; element i of a
        # row instance belongs to the random variable with index rv_indexes[i].
        self.rv_indexes: Tuple[int, ...] = tuple(rv.idx for rv in factor.rvs)

    def __len__(self) -> int:
        """Number of rows currently held for this factor."""
        return len(self.rows.rows)

    def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
        """The (instance, node) rows of this factor."""
        return self.rows.rows.items()

    def prune(self, extra_keys: Set[TableInstance]) -> None:
        """
        Remove the given keys from the factor's function rows.

        If the function rows are shared with other factors (use_count > 1), this
        factor detaches onto its own pruned copy, leaving the shared rows untouched
        for the other factors. Otherwise the function rows' dict is replaced by the
        pruned dict.

        Args:
            extra_keys: the instances to drop; an empty set leaves everything as is.
        """
        if not extra_keys:
            return

        shared: _FunctionRows = self.rows
        kept: Dict[TableInstance, CircuitNode] = {
            instance: node
            for instance, node in shared.rows.items()
            if instance not in extra_keys
        }
        if shared.use_count <= 1:
            shared.rows = kept
        else:
            shared.use_count -= 1
            self.rows = _FunctionRows(kept, 1)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class _FactorPair:
    """
    Two factors whose rows are pruned against each other: a row of one factor
    is removed if its values on the random variables common to both factors do
    not match the values of any row of the other factor.
    """

    def __init__(self, x: _FactorRows, y: _FactorRows):
        self.x: _FactorRows = x
        self.y: _FactorRows = y

        x_rvs: Set[int] = set(x.rv_indexes)

        # Every random variable mentioned by x or y.
        self.all_rv_indexes: Set[int] = x_rvs.union(y.rv_indexes)

        # Random variables shared by x and y (one fixed order, used for both maps below).
        self.co_rv_indexes: Tuple[int, ...] = tuple(x_rvs.intersection(y.rv_indexes))

        # For each shared rv, its position within an x instance and within a y
        # instance; used to project instances onto the shared rvs.
        self.co_from_x_map = tuple(x.rv_indexes.index(rv) for rv in self.co_rv_indexes)
        self.co_from_y_map = tuple(y.rv_indexes.index(rv) for rv in self.co_rv_indexes)

    def prune(self) -> None:
        """
        Prune any rows from x and y that cannot join to each other.
        """
        x_positions = self.co_from_x_map
        y_positions = self.co_from_y_map
        x_rows = self.x.rows.rows
        y_rows = self.y.rows.rows

        def _project(instance: TableInstance, positions: Tuple[int, ...]) -> TableInstance:
            # Values of `instance` on the shared random variables.
            return tuple(instance[p] for p in positions)

        x_projections: Set[TableInstance] = {_project(inst, x_positions) for inst in x_rows}
        y_projections: Set[TableInstance] = {_project(inst, y_positions) for inst in y_rows}

        # Rows of x with no partner in y, and rows of y with no partner in x.
        x_unmatched: Set[TableInstance] = {
            inst for inst in x_rows if _project(inst, x_positions) not in y_projections
        }
        y_unmatched: Set[TableInstance] = {
            inst for inst in y_rows if _project(inst, y_positions) not in x_projections
        }

        self.x.prune(x_unmatched)
        self.y.prune(y_unmatched)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _pre_prune_factor_tables(factor_rows: Sequence[_FactorRows]) -> None:
    """
    It may be possible to reduce the size of a table for a factor.

    If two factors contain a common random variable then at some point their product
    will be formed, which may eliminate rows. This method identifies and removes
    such rows, repeating until no pair of factors can prune the other any further.

    Args:
        factor_rows: the rows of every factor; pruned in place.
    """
    # Only pairs of factors that share at least one random variable constrain each other.
    pairs_to_check: List[_FactorPair] = [
        _FactorPair(f1, f2)
        for f1, f2 in pairs(factor_rows)
        if not set(f1.rv_indexes).isdisjoint(f2.rv_indexes)
    ]

    pairs_done: List[_FactorPair] = []

    while len(pairs_to_check) > 0:
        pair = pairs_to_check.pop()
        x = pair.x
        y = pair.y

        x_size = len(x)
        y_size = len(y)
        pair.prune()

        # See if any pairs need re-checking: a factor that lost rows may now
        # allow further pruning in pairs that were previously settled.
        rvs_affected: Set[int] = set()
        if x_size != len(x):
            rvs_affected.update(x.rv_indexes)
        if y_size != len(y):
            rvs_affected.update(y.rv_indexes)
        if len(rvs_affected) > 0:
            next_pairs_done: List[_FactorPair] = []
            # A distinct loop name, so `pair` still refers to the pair just pruned.
            for done_pair in pairs_done:
                if rvs_affected.isdisjoint(done_pair.all_rv_indexes):
                    next_pairs_done.append(done_pair)
                else:
                    pairs_to_check.append(done_pair)
            pairs_done = next_pairs_done

        # Mark the current pair as done.
        pairs_done.append(pair)
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def _make_factor_table(
        factor: Factor,
        circuit: Circuit,
        slot_map: Dict[SlotKey, int],
        rows: _FactorRows,
        factors_mul_rvs: MapList[int, RandomVariable],
) -> CircuitTable:
    """
    Build the CircuitTable for a single PGM factor.

    When random variables have been assigned to this factor (an entry keyed by
    `id(factor)` in `factors_mul_rvs`), each row's node is multiplied by the
    indicator circuit variables selected by that row's instance. Otherwise the
    factor's rows are used unchanged.

    Args:
        factor: the PGM factor.
        circuit: the circuit hosting the table's nodes.
        slot_map: maps each indicator to its circuit variable index.
        rows: the (instance, node) rows of the factor.
        factors_mul_rvs: random variables to multiply in, keyed by `id(factor)`.

    Returns:
        a CircuitTable over the indexes of the factor's random variables.
    """
    mul_rvs: Sequence[RandomVariable] = factors_mul_rvs.get(id(factor), ())
    rv_indexes: Tuple[int, ...] = tuple(rv.idx for rv in factor.rvs)

    if len(mul_rvs) == 0:
        # Nothing to multiply in: the table is just the factor's rows.
        return CircuitTable(circuit, rv_indexes, rows.items())

    # For each rv to multiply in: its position within an instance of this factor...
    positions: Tuple[int, ...] = tuple(rv_indexes.index(rv.idx) for rv in mul_rvs)
    # ...and its indicator circuit variables, in `rv.indicators` order, so that the
    # instance value (presumably the state index) selects the indicator.
    indicator_vars: Tuple[Tuple[CircuitNode, ...], ...] = tuple(
        tuple(circuit.vars[slot_map[indicator]] for indicator in rv.indicators)
        for rv in mul_rvs
    )
    selectors = tuple(zip(positions, indicator_vars))

    def _rows_with_indicators() -> Iterator[Tuple[TableInstance, CircuitNode]]:
        for instance, node in rows.items():
            terms: List[CircuitNode] = [
                vars_by_value[instance[position]]
                for position, vars_by_value in selectors
            ]
            if not node.is_one():
                # A constant-one row node contributes nothing to the product.
                terms.append(node)

            if len(terms) == 0:
                yield instance, circuit.one
            elif len(terms) == 1:
                yield instance, terms[0]
            else:
                yield instance, circuit.optimised_mul(tuple(terms))

    return CircuitTable(circuit, rv_indexes, _rows_with_indicators())
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def _rows_for_function_const(
        function: PotentialFunction,
        circuit: Circuit,
) -> _FunctionRows:
    """
    Build the rows (instance -> node) of a potential function, where each node
    is a circuit constant holding the parameter value.

    Entries with a zero value are left out, so an all-zero function
    (ZeroPotentialFunction) yields no rows at all.

    Args:
        function: the potential function to convert.
        circuit: the circuit in which to create the constants.

    Returns:
        a _FunctionRows (use_count 0) with one row per non-zero entry of
        `function.keys_with_param`.
    """
    rows: Dict[TableInstance, CircuitNode] = {}
    if not isinstance(function, ZeroPotentialFunction):
        for instance, _param_index, value in function.keys_with_param:
            if value == 0:
                continue
            rows[tuple(instance)] = circuit.const(value)
    return _FunctionRows(rows)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def _rows_for_function_var(
        function: PotentialFunction,
        circuit: Circuit,
        slot_map: Dict[SlotKey, int],
) -> _FunctionRows:
    """
    Build the rows (instance -> node) of a potential function, where each node
    is a freshly created circuit variable for the entry's parameter.

    Every new variable is registered in `slot_map` under its ParamId. A parameter
    id must not already have a circuit variable (checked with an assert).

    Args:
        function: the potential function to convert.
        circuit: the circuit in which to create the parameter variables.
        slot_map: updated in place with ParamId -> circuit var index.

    Returns:
        a _FunctionRows (use_count 0) with one row per entry of `function.keys_with_param`.
    """
    rows: Dict[TableInstance, CircuitNode] = {}
    for instance, param_index, _value in function.keys_with_param:
        key: TableInstance = tuple(instance)
        param_id: ParamId = function.param_id(param_index)
        assert param_id not in slot_map, 'parameter should not already have a circuit var'
        var: VarNode = circuit.new_var()
        slot_map[param_id] = var.idx
        rows[key] = var
    return _FunctionRows(rows)
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from itertools import chain
|
|
5
|
+
from typing import List, Set, Callable, Sequence, Tuple
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from ck.pgm import PGM, Factor
|
|
10
|
+
from ck.pgm_compiler.support.clusters import Clusters, min_degree, min_fill, \
|
|
11
|
+
min_degree_then_fill, min_fill_then_degree, min_weighted_degree, min_weighted_fill, min_traditional_weighted_fill, \
|
|
12
|
+
ClusterAlgorithm
|
|
13
|
+
from ck.utils.np_extras import NDArrayFloat64
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class JoinTree:
    """
    A node of a join tree over a PGM; each node holds a cluster of random
    variables, the factors allocated to it, and its child nodes.
    """

    # The PGM that this join tree is for.
    pgm: PGM

    # Indexes of random variables in this join tree node.
    cluster: Set[int]

    # Child nodes in the join tree.
    children: List[JoinTree]

    # Factors of the PGM allocated to this join tree node.
    factors: List[Factor]

    # Indexes of random variables that are in both this cluster and the parent's cluster.
    # (Empty if this is the root of the spanning tree.)
    separator: Set[int]

    def max_cluster_size(self) -> int:
        """
        Returns:
            the maximum `len(cluster)` over this node and all its descendants.
        """
        sizes: List[int] = [len(self.cluster)]
        sizes.extend(child.max_cluster_size() for child in self.children)
        return max(sizes)

    def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
        """
        Args:
            rv_log_sizes: a weight per random variable index (presumably the log2
                of the random variable's number of states).

        Returns:
            the maximum, over this node and all its descendants, of the sum of
            `rv_log_sizes[i]` for each random variable index i in the node's cluster.
        """
        own: float = sum(rv_log_sizes[i] for i in self.cluster)
        candidates = [own]
        candidates.extend(child.max_cluster_weighted_size(rv_log_sizes) for child in self.children)
        return max(candidates)

    def dump(self, *, prefix: str = '', indent: str = ' ', show_factors: bool = True) -> None:
        """
        Print a dump of the Join Tree.
        This is intended for debugging and demonstration purposes.

        Each cluster is printed as: {separator rvs} | {non-separator rvs}.

        Args:
            prefix: optional prefix for indenting all lines.
            indent: additional prefix to use for extra indentation.
            show_factors: if true, the factors of each cluster are shown.
        """
        rvs = self.pgm.rvs
        separator_names = [repr(str(rvs[i])) for i in sorted(self.separator)]
        other_names = [repr(str(rvs[i])) for i in sorted(self.cluster) if i not in self.separator]

        # Each separator name is followed by a single space (nothing if no separator).
        sep_str = ''.join(name + ' ' for name in separator_names)
        rest_str = ' '.join(other_names)
        print(f'{prefix}{sep_str}| {rest_str} (factors: {len(self.factors)})')

        if show_factors:
            for factor in self.factors:
                print(f'{prefix}factor{factor}')

        child_prefix = prefix + indent
        for child in self.children:
            child.dump(prefix=child_prefix, indent=indent, show_factors=show_factors)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# Type for a join tree algorithm: a callable mapping a PGM to (the root node of)
# a JoinTree for that PGM.
JoinTreeAlgorithm = Callable[[PGM], JoinTree]
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _join_tree_algorithm(pgm_to_clusters: ClusterAlgorithm) -> JoinTreeAlgorithm:
    """
    Wrap a ClusterAlgorithm as a standard JoinTreeAlgorithm.

    The returned callable clusters the PGM with `pgm_to_clusters` and then builds
    the join tree from those clusters with `clusters_to_join_tree`.

    Args:
        pgm_to_clusters: The clusters method to use.

    Returns:
        a JoinTreeAlgorithm.
    """

    def __join_tree_algorithm(pgm: PGM) -> JoinTree:
        return clusters_to_join_tree(pgm_to_clusters(pgm))

    return __join_tree_algorithm
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Standard JoinTreeAlgorithms, one per clustering heuristic imported from
# ck.pgm_compiler.support.clusters. Each one clusters the PGM with its heuristic
# and then builds the join tree with `clusters_to_join_tree`.

MIN_DEGREE: JoinTreeAlgorithm = _join_tree_algorithm(min_degree)
MIN_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_fill)
MIN_DEGREE_THEN_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_degree_then_fill)
MIN_FILL_THEN_DEGREE: JoinTreeAlgorithm = _join_tree_algorithm(min_fill_then_degree)
MIN_WEIGHTED_DEGREE: JoinTreeAlgorithm = _join_tree_algorithm(min_weighted_degree)
MIN_WEIGHTED_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_weighted_fill)
MIN_TRADITIONAL_WEIGHTED_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_traditional_weighted_fill)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def clusters_to_join_tree(clusters: Clusters) -> JoinTree:
    """
    Construct a join tree from the given random variable clusters of a PGM.

    The join tree is formed by finding a minimum spanning tree over the clusters.
    The cost between a pair of clusters is the negated number of random variables
    they share (their separator), plus a tie-break term proportional to the
    separator's summed `pgm.rv_log_sizes`. The tie-break terms total less than 1,
    so a larger separator is always preferred. Among separators of equal size,
    the one with the smaller summed log size is preferred.

    Each PGM factor is then allocated to the smallest cluster containing all of
    its random variables.

    Args:
        clusters: the clusters that resulted from graph clustering of a PGM
            (`clusters.pgm`).

    Returns:
        the root node of the join tree.
    """
    pgm: PGM = clusters.pgm
    cluster_sets: List[Set[int]] = clusters.clusters
    number_of_clusters = len(cluster_sets)

    # Dealing with these cases directly simplifies
    # the spanning tree algorithm implementation.
    if number_of_clusters == 0:
        return JoinTree(pgm, set(), [], [], set())
    elif number_of_clusters == 1:
        return JoinTree(pgm, cluster_sets[0], [], list(pgm.factors), set())

    # Calculate inter-cluster costs for determining the minimum spanning tree.
    cost: NDArrayFloat64 = np.zeros((number_of_clusters, number_of_clusters), dtype=np.float64)

    # We will use separator state space size to break ties.
    # The scaling makes the sum of all break costs < 1.
    rv_log_sizes = pgm.rv_log_sizes
    max_raw_break_cost = sum(rv_log_sizes) * 1.1
    if max_raw_break_cost > 0:
        break_cost = [rv_cost / max_raw_break_cost for rv_cost in rv_log_sizes]
    else:
        # Every log size is zero (presumably each random variable has a single
        # state), so there is nothing to break ties with; avoid dividing by zero.
        break_cost = [0.0 for _ in rv_log_sizes]

    for i in range(number_of_clusters):
        cluster_i = cluster_sets[i]
        for j in range(i + 1, number_of_clusters):
            separator = cluster_i.intersection(cluster_sets[j])
            cost[i, j] = cost[j, i] = -len(separator) + sum(break_cost[rv_idx] for rv_idx in separator)

    # Make the spanning tree over the clusters.
    root_cluster_index: int
    children: List[List[int]]
    children, root_cluster_index = _make_spanning_tree_small_root(cost, cluster_sets)

    # Allocate each PGM factor to the smallest cluster that contains all of its
    # random variables (the sort is stable, so ties go to the lower cluster index).
    cluster_factors: List[List[Factor]] = [[] for _ in range(number_of_clusters)]
    ordered_indexed_clusters = sorted(enumerate(cluster_sets), key=lambda idx_c: len(idx_c[1]))
    for factor in pgm.factors:
        rv_indexes = frozenset(rv.idx for rv in factor.rvs)
        for cluster_index, cluster in ordered_indexed_clusters:
            if rv_indexes.issubset(cluster):
                cluster_factors[cluster_index].append(factor)
                break
        # NOTE(review): a factor contained in no cluster is left unallocated here.

    return _form_join_tree_r(pgm, root_cluster_index, set(), children, cluster_sets, cluster_factors)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
# Positive infinity (presumably a sentinel "no edge / unreachable" cost for the
# spanning tree construction below - TODO confirm).
_INF = float('inf')
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _make_spanning_tree_small_root(cost: NDArrayFloat64, clusters: List[Set[int]]) -> Tuple[List[List[int]], int]:
    """
    Construct a minimum spanning tree over the clusters, where the root is the cluster with
    the smallest number of random variables.

    Args:
        cost: an n x n matrix where n is the number of clusters and cost[i, j]
            gives the cost between clusters i and j.
        clusters: the clusters, each a set of random variable indexes; clusters[i]
            corresponds to row/column i of `cost`.

    Returns:
        (children, root_cluster_index) where children[i] lists the child cluster
        indexes of cluster i in the spanning tree and root_cluster_index is the
        index of the chosen root. When several clusters share the smallest size,
        the one with the lowest index is chosen.
    """
    # Choose the smallest cluster as the root. `min` returns the first minimal
    # item, so ties resolve to the lowest cluster index.
    # (Previously the comparison read the current root's size instead of the
    # candidate's, so it never changed the root from cluster 0.)
    root_cluster_index: int = min(range(len(clusters)), key=lambda i: len(clusters[i]))

    children: List[List[int]] = _make_spanning_tree_at_root(cost, root_cluster_index)
    return children, root_cluster_index
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _make_spanning_tree_arbitrary_root(cost: NDArrayFloat64) -> Tuple[List[List[int]], int]:
    """
    Build a minimum spanning tree over the clusters, rooted at cluster 0.

    Args:
        cost: the n x n inter-cluster cost matrix.

    Returns:
        A pair (children, root_index): children[i] lists the child clusters of
        cluster i, and root_index is always 0.
    """
    root = 0  # any cluster would do; the first one is the simplest choice
    return _make_spanning_tree_at_root(cost, root), root
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _make_spanning_tree_at_root(
        cost: NDArrayFloat64,
        root_custer_index: int,
) -> List[List[int]]:
    """
    Construct a minimum spanning tree over the clusters, starting at the given root.

    Starting from the root, this repeatedly attaches the cheapest edge joining an
    already-included cluster to a not-yet-included one (Prim's algorithm).

    Args:
        cost: an n x n matrix where n is the number of clusters and cost[i, j]
            gives the cost between clusters i and j.
        root_custer_index: a nominated root cluster to be the root of the tree;
            must be in the range [0, n).

    Returns:
        children, a list of length n where children[i] lists the indexes of the
        child clusters of cluster i in the spanning tree. With zero clusters the
        result is empty; with one cluster it is a single empty child list.

    Raises:
        ValueError: if there is at least one cluster and root_custer_index is
            not in the range [0, n).
    """
    number_of_clusters: int = cost.shape[0]

    if number_of_clusters == 0:
        return []
    if not 0 <= root_custer_index < number_of_clusters:
        raise ValueError(
            f'root cluster index {root_custer_index} out of range for {number_of_clusters} clusters'
        )

    # Data structure to collect the results.
    children: List[List[int]] = [[] for _ in range(number_of_clusters)]

    # A single cluster has no edges. Without this guard the search loop below
    # would record a bogus self-edge and then pop from an empty list.
    if number_of_clusters == 1:
        return children

    # clusters left to process.
    remaining: List[int] = list(range(number_of_clusters))

    # clusters that have been processed.
    included: List[int] = []

    def remove_remaining(_remaining_index: int) -> None:
        # Remove the `remaining` element at the given index location in O(1) by
        # overwriting it with the last element (the order of `remaining` is not kept).
        remaining[_remaining_index] = remaining[-1]
        remaining.pop()

    # Move root from `remaining` to `included`
    included.append(root_custer_index)
    remove_remaining(root_custer_index)  # valid because remaining[root_custer_index] == root_custer_index here

    while True:
        # Default to a genuine included -> remaining edge. If no cost compares less
        # than +inf (e.g. all are inf or NaN), the tree is still well-formed rather
        # than getting a 0 -> 0 self-edge.
        min_i: int = included[0]
        min_j: int = remaining[0]
        min_j_pos: int = 0
        min_c: float = _INF
        for i in included:
            for j_pos, j in enumerate(remaining):
                c: float = cost.item(i, j)
                if c < min_c:
                    min_c = c
                    min_i = i
                    min_j = j
                    min_j_pos = j_pos

        # Record the child and move it from 'remaining' to 'included'.
        children[min_i].append(min_j)
        if len(remaining) == 1:
            # That was the last one.
            return children

        # Update `remaining` and `included`
        remove_remaining(min_j_pos)
        included.append(min_j)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def _form_join_tree_r(
        pgm: PGM,
        cluster_index: int,
        parent_cluster: Set[int],
        children: Sequence[List[int]],
        clusters: Sequence[Set[int]],
        cluster_factors: List[List[Factor]],
) -> JoinTree:
    """
    Recursively build the join tree data structure.

    Args:
        pgm: the PGM the join tree is built for.
        cluster_index: index of the cluster at the root of this (sub)tree.
        parent_cluster: the parent node's cluster (the caller passes an empty set
            for the root), used to compute this node's separator.
        children: children[i] lists the child cluster indexes of cluster i.
        clusters: clusters[i] is the set of random variable indexes of cluster i.
        cluster_factors: cluster_factors[i] lists the PGM factors allocated to cluster i.

    Returns:
        the JoinTree node for `cluster_index`, with all of its descendants built.
        The node's separator is the intersection of its cluster with `parent_cluster`.

    Note:
        There is one recursive call per child, so the recursion depth equals the
        depth of the tree below `cluster_index`. Very deep trees may hit Python's
        recursion limit.
    """
    cluster: Set[int] = clusters[cluster_index]
    factors: List[Factor] = cluster_factors[cluster_index]

    # Use a distinct name for the built sub-trees. Rebinding `children` would
    # shadow the child-index table that is still being read while the subtrees
    # are built.
    child_trees: List[JoinTree] = [
        _form_join_tree_r(pgm, child_index, cluster, children, clusters, cluster_factors)
        for child_index in children[cluster_index]
    ]

    separator: Set[int] = parent_cluster.intersection(cluster)
    return JoinTree(
        pgm,
        cluster,
        child_trees,
        factors,
        separator,
    )
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from types import ModuleType
|
|
2
|
+
from typing import Tuple
|
|
3
|
+
|
|
4
|
+
from ck.pgm import PGM
|
|
5
|
+
from ck.pgm_circuit import PGMCircuit
|
|
6
|
+
from ck.pgm_compiler import PGMCompiler
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_compiler(module: ModuleType, **kwargs) -> Tuple[PGMCompiler]:
    """
    Wrap a module's `compile_pgm` function as a named PGM compiler.

    Args:
        module: a module that defines a `compile_pgm` function.
        **kwargs: extra keyword arguments passed through to `module.compile_pgm`
            on every call of the returned compiler.

    Returns:
        a one-element tuple whose only item is a function conforming to the
        `PGMCompiler` protocol.
    """

    def compiler(pgm: PGM, const_parameters: bool = True) -> PGMCircuit:
        """Conforms to the `PGMCompiler` protocol."""
        # `module.compile_pgm` is looked up at call time, not when wrapping.
        return module.compile_pgm(pgm, const_parameters=const_parameters, **kwargs)

    wrapped: Tuple[PGMCompiler] = (compiler,)
    return wrapped
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_compiler_algorithm(module: ModuleType, algorithm: str, **kwargs) -> Tuple[PGMCompiler]:
    """
    Helper function to create a named PGM compiler, with a named algorithm argument.

    Args:
        module: module containing a `compile_pgm` function and an attribute
            named by `algorithm`.
        algorithm: name of an attribute of `module`; that attribute's value is
            passed to `compile_pgm` as its `algorithm` keyword argument.
        **kwargs: further keyword arguments forwarded to `module.compile_pgm`.

    Returns:
        a singleton tuple containing a PGMCompiler function.

    Raises:
        AttributeError: if `module` has no attribute named `algorithm`.
    """
    # The algorithm is resolved eagerly, so a bad name fails here rather than at
    # first use. Extra kwargs belong to compile_pgm: passing them to getattr, as
    # before, raised TypeError and dropped them.
    return get_compiler(module, algorithm=getattr(module, algorithm), **kwargs)
|
|
32
|
+
|
|
33
|
+
|