compiled-knowledge 4.0.0a20__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37525 -0
- ck/circuit/_circuit_cy.cpython-313-darwin.so +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19826 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10620 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16398 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-313-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +6 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,406 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Dict, Sequence, Tuple, List, Iterator, Set, Iterable, Optional, Callable
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from ck.circuit import Circuit, VarNode, CircuitNode
|
|
9
|
+
from ck.pgm import PGM, ParamId, Factor, PotentialFunction, RandomVariable, ZeroPotentialFunction
|
|
10
|
+
from ck.pgm_circuit.slot_map import SlotMap, SlotKey
|
|
11
|
+
from ck.pgm_compiler.support.circuit_table import CircuitTable, TableInstance
|
|
12
|
+
from ck.utils.iter_extras import pairs
|
|
13
|
+
from ck.utils.map_list import MapList
|
|
14
|
+
from ck.utils.np_extras import NDArray, NDArrayFloat64
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class FactorTables:
    """
    The circuit tables for the factors of a PGM, as returned by `make_factor_tables`,
    together with the circuit and slot map they share.
    """

    # The circuit that hosts every table.
    circuit: Circuit

    # How many circuit vars are indicator variables.
    number_of_indicators: int

    # How many circuit vars are parameter variables (i.e., non-const, in-use parameters).
    number_of_parameters: int

    # Maps each Indicator or ParamId object to a circuit var index.
    slot_map: SlotMap

    # One CircuitTable for each PGM factor, looked up by factor index.
    tables: Sequence[CircuitTable]

    # For a non-const, in-use parameter with id `param_id`, the PGM value of that
    # parameter was `self.parameter_values[self.slot_map[param_id] - self.number_of_indicators]`.
    parameter_values: NDArray

    def get_table(self, factor: Factor) -> CircuitTable:
        """
        Look up the circuit table that was built for the given factor.

        Args:
            factor: a factor of the PGM the tables were made from.

        Returns:
            the entry of `self.tables` at position `factor.idx`.
        """
        table_position: int = factor.idx
        return self.tables[table_position]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def make_factor_tables(
        pgm: PGM,
        const_parameters: bool,
        multiply_indicators: bool,
        pre_prune_factor_tables: bool,
) -> FactorTables:
    """
    Consistently and efficiently create circuit tables for the factors of a PGM.

    Creates:
        * a circuit,
        * a circuit variable for each indicator of the PGM,
        * a circuit variable for each non-constant, in-use potential function parameter,
        * a circuit table for each factor of the PGM.

    The parameters of each potential function will be converted either to
    circuit constants (if `const_parameters` is true) or to circuit variables
    (if `const_parameters` is false). When parameters are converted to constants,
    zero-valued parameters produce no table row.

    If `multiply_indicators` is true, each random variable is assigned to the
    factor with the fewest rows among those containing it. The indicator
    variables of that random variable are then multiplied into the rows of
    that factor's table.

    A slot map will be created that maps PGM indicators and parameter ids to circuit var indices.
    Specifically, a circuit var will be added for each indicator, in the order they appear
    in `pgm.indicators`. Circuit vars for parameter ids will be added after those for
    indicators, and only if `const_parameters` is false.

    Args:
        pgm: The PGM with the random variables, factors, and potential functions.
        const_parameters: if true, then potential function parameters will be circuit constants,
            otherwise they will be circuit variables, with entries in the returned slot map.
        multiply_indicators: if true then indicator variables will be multiplied into an acceptable
            factor.
        pre_prune_factor_tables: if true, then heuristics will be used to remove any provably zero row.

    Returns:
        FactorTables, holding a slot_map and a circuit table for each PGM factor.
    """

    # Create circuit and initialise the slot map with indicator variables.
    circuit = Circuit()
    slot_map: Dict[SlotKey, int] = {
        indicator: circuit.new_var().idx
        for indicator in pgm.indicators
    }

    # Get the circuit table rows for each potential function.
    # functions_rows[id(function)] = rows for the function
    functions_rows: Dict[int, _FunctionRows]
    if const_parameters:
        functions_rows = {
            id(function): _rows_for_function_const(function, circuit)
            for function in pgm.functions
        }
    else:
        functions_rows = {
            id(function): _rows_for_function_var(function, circuit, slot_map)
            for function in pgm.functions
        }

    # Link factors to function rows.
    # Count how many factors share each function's rows so that pruning one
    # factor's rows does not affect other factors that use the same function.
    # factor_rows[id(factor)] = rows for the factor
    factor_rows: Dict[int, _FactorRows] = {}
    for factor in pgm.factors:
        rows: _FunctionRows = functions_rows[id(factor.function)]
        rows.use_count += 1
        factor_rows[id(factor)] = _FactorRows(factor, rows)

    # Check to see if any factor rows can be pre-pruned.
    if pre_prune_factor_tables:
        _pre_prune_factor_tables(list(factor_rows.values()))

    # Allocate random variables to factors (for multiplying in indicators).
    factors_mul_rvs: MapList[int, RandomVariable]
    if multiply_indicators:
        def _factor_size(_factor: Factor) -> int:
            # Size is the (possibly pre-pruned) number of rows of the factor.
            return len(factor_rows[id(_factor)])

        factors_mul_rvs = _assign_rvs_to_factors(pgm, _factor_size)
    else:
        factors_mul_rvs = MapList()  # no assignment of rvs to factors.

    # Make a circuit table for each factor. `tables[factor.idx]` is the circuit table for `factor`.
    tables: List[CircuitTable] = [
        _make_factor_table(factor, circuit, slot_map, factor_rows[id(factor)], factors_mul_rvs)
        for factor in pgm.factors
    ]

    # Extract the parameter values (if they are circuit vars).
    number_of_indicators: int = pgm.number_of_indicators
    number_of_parameters: int = len(slot_map) - number_of_indicators
    parameter_values: NDArrayFloat64 = np.zeros(number_of_parameters, dtype=np.float64)
    if not const_parameters:
        for function in pgm.functions:
            for param_index, value in function.params:
                param_id: ParamId = function.param_id(param_index)
                slot: Optional[int] = slot_map.get(param_id)
                if slot is not None:
                    parameter_values[slot - number_of_indicators] = value

    return FactorTables(
        circuit=circuit,
        number_of_indicators=number_of_indicators,
        number_of_parameters=number_of_parameters,
        slot_map=slot_map,
        tables=tables,
        parameter_values=parameter_values,
    )
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _assign_rvs_to_factors(
        pgm: PGM,
        factor_size: Callable[[Factor], int],
) -> MapList[int, RandomVariable]:
    """
    Choose, for each random variable, one factor into which its indicators will be multiplied.

    A random variable is assigned to the factor with the minimum `factor_size`
    among the factors that contain it. Ties go to the first such factor in
    `pgm.factors` order. A random variable that appears in no factor is not
    assigned at all.

    Args:
        pgm: the PGM whose random variables and factors are considered.
        factor_size: the size measure used to choose between candidate factors.

    Returns:
        a map from factor id (`id(factor)`) to the list of random variables assigned to that factor.
    """
    all_factors = pgm.factors
    all_rvs = pgm.rvs

    # containing[rv index] = factors that mention that rv, in `pgm.factors` order.
    containing: MapList[int, Factor] = MapList()
    for factor in all_factors:
        for member_rv in factor.rvs:
            containing.append(member_rv.idx, factor)

    # Pick the smallest containing factor for each rv.
    assignment: MapList[int, RandomVariable] = MapList()
    for rv_index, rv in enumerate(all_rvs):
        candidates: Sequence[Factor] = containing.get(rv_index, ())
        if not candidates:
            continue
        chosen: Factor = min(candidates, key=factor_size)
        assignment.append(id(chosen), rv)

    return assignment
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class _FunctionRows:
    """
    Mutable holder for the rows (instance -> circuit node) of one potential function.

    `use_count` records how many factors currently refer to this holder.
    It starts at the given value (default 0).
    """

    def __init__(self, rows: Dict[TableInstance, CircuitNode], use_count: int = 0):
        self.use_count: int = use_count
        self.rows: Dict[TableInstance, CircuitNode] = rows
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class _FactorRows:
    """
    The rows of a single factor.

    Holds a (possibly shared) `_FunctionRows` and the indexes of the factor's
    random variables, in the order given by `factor.rvs`.
    """

    def __init__(self, factor: Factor, rows: _FunctionRows):
        self.rows: _FunctionRows = rows
        self.rv_indexes: Tuple[int, ...] = tuple([member.idx for member in factor.rvs])

    def __len__(self) -> int:
        """Number of rows currently held for this factor."""
        function_rows = self.rows
        return len(function_rows.rows)

    def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
        """The (instance, node) rows currently held for this factor."""
        function_rows = self.rows
        return function_rows.rows.items()

    def prune(self, extra_keys: Set[TableInstance]) -> None:
        """
        Remove the given keys from the factor's function rows.

        If other factors also use the current `_FunctionRows` (use_count > 1),
        this factor gets its own private copy, and the shared holder is left
        untouched apart from its use count. Otherwise the holder is updated in place.
        """
        if not extra_keys:
            return

        shared = self.rows
        kept: Dict[TableInstance, CircuitNode] = {}
        for instance, node in shared.rows.items():
            if instance not in extra_keys:
                kept[instance] = node

        if shared.use_count <= 1:
            shared.rows = kept
        else:
            shared.use_count -= 1
            self.rows = _FunctionRows(kept, 1)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class _FactorPair:
    """
    A pair of factors, together with cached index mappings that project each
    factor's row instances onto the random variables the two factors share.
    """

    def __init__(self, x: _FactorRows, y: _FactorRows):
        self.x: _FactorRows = x
        self.y: _FactorRows = y

        y_set: Set[int] = set(y.rv_indexes)

        # Identify all random variables used by x and y.
        self.all_rv_indexes: Set[int] = y_set.union(x.rv_indexes)

        # Identify common random variables between x and y.
        # Keep them in a stable order: the order they appear in x (duplicates removed).
        self.co_rv_indexes: Tuple[int, ...] = tuple(
            dict.fromkeys(rv_index for rv_index in x.rv_indexes if rv_index in y_set)
        )

        # Cache mappings from each common rv to its position in a source instance (x or y).
        # They are used to project row instances onto the common random variables.
        self.co_from_x_map: Tuple[int, ...] = tuple(
            x.rv_indexes.index(rv_index) for rv_index in self.co_rv_indexes
        )
        self.co_from_y_map: Tuple[int, ...] = tuple(
            y.rv_indexes.index(rv_index) for rv_index in self.co_rv_indexes
        )

    def prune(self) -> None:
        """
        Prune any rows from x and y that cannot join to each other.

        A row of x is removed if its projection onto the common random variables
        matches no row of y, and vice versa. Both sets of removals are computed
        before either factor is pruned.
        """
        co_from_x_map = self.co_from_x_map
        co_from_y_map = self.co_from_y_map

        # Project every row instance onto the common random variables, once per row.
        x_projected: Dict[TableInstance, TableInstance] = {
            instance: tuple(instance[i] for i in co_from_x_map)
            for instance in self.x.rows.rows
        }
        y_projected: Dict[TableInstance, TableInstance] = {
            instance: tuple(instance[i] for i in co_from_y_map)
            for instance in self.y.rows.rows
        }

        x_co_set: Set[TableInstance] = set(x_projected.values())
        y_co_set: Set[TableInstance] = set(y_projected.values())

        # Keys in x that will not join to y.
        x_extra_keys: Set[TableInstance] = {
            instance for instance, key in x_projected.items() if key not in y_co_set
        }

        # Keys in y that will not join to x.
        y_extra_keys: Set[TableInstance] = {
            instance for instance, key in y_projected.items() if key not in x_co_set
        }

        self.x.prune(x_extra_keys)
        self.y.prune(y_extra_keys)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _pre_prune_factor_tables(factor_rows: Sequence[_FactorRows]) -> None:
    """
    It may be possible to reduce the size of a table for a factor.

    If two factors contain a common random variable then at some point their product
    will be formed, which may eliminate rows. This method identifies and removes
    such rows. The given factor rows are pruned in place.

    Args:
        factor_rows: the rows of every factor of the PGM.
    """
    # Find all pairs of factors that have at least one common random variable.
    # (Pairs with no common random variable are skipped: projecting onto an empty
    # set of common variables cannot distinguish rows.)
    pairs_to_check: List[_FactorPair] = [
        _FactorPair(f1, f2)
        for f1, f2 in pairs(factor_rows)
        if not set(f1.rv_indexes).isdisjoint(f2.rv_indexes)
    ]

    # Simple version: each pair is pruned exactly once.
    #
    # An earlier version re-queued already-processed pairs whenever a later pruning
    # shrank one of their factors. That was computationally expensive and provided
    # no practical benefit, so it was dropped.
    for pair in pairs_to_check:
        pair.prune()
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def _make_factor_table(
        factor: Factor,
        circuit: Circuit,
        slot_map: Dict[SlotKey, int],
        rows: _FactorRows,
        factors_mul_rvs: MapList[int, RandomVariable],
) -> CircuitTable:
    """
    Build the circuit table for one factor from its rows.

    If `factors_mul_rvs` lists random variables for this factor (keyed by `id(factor)`),
    each row's node is multiplied by, for each listed random variable, the indicator
    circuit variable selected by the state value in the row's instance.

    Args:
        factor: the PGM factor being converted.
        circuit: the circuit that owns all nodes.
        slot_map: maps each indicator to its circuit var index.
        rows: the (instance, node) rows for this factor.
        factors_mul_rvs: random variables whose indicators are multiplied in, keyed by factor id.

    Returns:
        a CircuitTable over the factor's random variables.
    """
    assigned_rvs: Sequence[RandomVariable] = factors_mul_rvs.get(id(factor), ())
    rv_idx_tuple: Sequence[int] = tuple(member.idx for member in factor.rvs)

    if not assigned_rvs:
        # Nothing to multiply in - the rows are used as they are.
        return CircuitTable(circuit, rv_idx_tuple, rows.items())

    # positions[k] = position within a row instance that holds the state of assigned_rvs[k].
    positions: Sequence[int] = tuple(
        rv_idx_tuple.index(assigned.idx)
        for assigned in assigned_rvs
    )

    # indicator_vars[k][s] = indicator circuit variable of assigned_rvs[k] for state value s.
    indicator_vars: Sequence[Sequence[CircuitNode]] = tuple(
        tuple(circuit.vars[slot_map[indicator]] for indicator in assigned.indicators)
        for assigned in assigned_rvs
    )

    def _multiplied_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
        for instance, node in rows.items():
            terms: Tuple[CircuitNode, ...] = tuple(
                state_vars[instance[position]]
                for position, state_vars in zip(positions, indicator_vars)
            )
            if not node.is_one:
                # A node that is one contributes nothing to the product.
                terms += (node,)

            term_count = len(terms)
            if term_count == 0:
                product = circuit.one
            elif term_count == 1:
                product = terms[0]
            else:
                product = circuit.optimised_mul(terms)
            yield instance, product

    return CircuitTable(circuit, rv_idx_tuple, _multiplied_rows())
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def _rows_for_function_const(
        function: PotentialFunction,
        circuit: Circuit,
) -> _FunctionRows:
    """
    Build the rows of `function`, with each parameter value as a circuit constant.

    Every entry of `function.keys_with_param` whose value is non-zero becomes a row
    mapping `tuple(instance)` to `circuit.const(value)`. Zero values are excluded.
    """
    if isinstance(function, ZeroPotentialFunction):
        # Shortcut: no rows for a ZeroPotentialFunction (presumably all its values are zero).
        return _FunctionRows({})

    const_rows: Dict[TableInstance, CircuitNode] = {}
    for instance, _param_index, value in function.keys_with_param:
        if value != 0:
            const_rows[tuple(instance)] = circuit.const(value)
    return _FunctionRows(const_rows)
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def _rows_for_function_var(
        function: PotentialFunction,
        circuit: Circuit,
        slot_map: Dict[SlotKey, int],
) -> _FunctionRows:
    """
    Build the rows of `function`, with each parameter as a circuit variable.

    For every entry of `function.keys_with_param`, a fresh circuit variable is
    created for its parameter id. The variable's index is recorded in `slot_map`,
    and the row maps `tuple(instance)` to that variable. Zero values are NOT
    excluded (unlike the constant case).
    """
    var_rows: Dict[TableInstance, CircuitNode] = {}
    for instance, param_index, _value in function.keys_with_param:
        param_id: ParamId = function.param_id(param_index)
        # Each parameter id is expected to be seen only once.
        assert param_id not in slot_map, 'parameter should not already have a circuit var'
        param_var: VarNode = circuit.new_var()
        slot_map[param_id] = param_var.idx
        var_rows[tuple(instance)] = param_var
    return _FunctionRows(var_rows)
|