compiled-knowledge 4.0.0a20__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37523 -0
- ck/circuit/_circuit_cy.cp313-win_amd64.pyd +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19824 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp313-win_amd64.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp313-win_amd64.pyd +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp313-win_amd64.pyd +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Sequence, Tuple, Iterable
|
|
4
|
+
|
|
5
|
+
from ck.circuit import MUL, ADD
|
|
6
|
+
|
|
7
|
+
from ck.circuit._circuit_cy cimport Circuit, CircuitNode
|
|
8
|
+
|
|
9
|
+
# A table instance holds one int per random variable of a table,
# in the order of that table's rv_idxs.
TableInstance = Tuple[int, ...]

# C-level copies of the circuit operation codes, passed to `Circuit.op`
# by the table operations below.
cdef int c_ADD = ADD
cdef int c_MUL = MUL
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
cdef class CircuitTable:
    """
    A circuit table manages a set of CircuitNodes, where each node corresponds
    to an instance for a set of (zero or more) random variables.

    Operations on circuit tables typically add circuit nodes to the circuit. It will
    heuristically avoid adding unnecessary nodes (e.g. addition of zero, multiplication
    by zero or one.) However, it may be that interim circuit nodes are created that
    end up not being used. Consider calling `Circuit.remove_unreachable_op_nodes` after
    completing all circuit table operations.

    It is generally expected that no CircuitTable row will be created with a constant
    zero node. These are assumed to be optimised out already.
    """

    cdef public Circuit circuit
    cdef public tuple[int, ...] rv_idxs
    # Maps each table instance (one int per entry of rv_idxs) to its circuit node.
    # Not public: Python code reaches the rows through the mapping methods below.
    cdef dict[tuple[int, ...], CircuitNode] rows

    def __init__(
            self,
            circuit: Circuit,
            rv_idxs: Sequence[int],
            rows: Iterable[Tuple[TableInstance, CircuitNode]] = (),
    ):
        """
        Args:
            circuit: the circuit whose nodes are being managed by this table.
            rv_idxs: indexes of random variables.
            rows: optional rows to add to the table.

        Assumes:
            * rv_idxs contains no duplicates.
            * all row instances conform to the indexed random variables.
            * all row circuit nodes belong to the given circuit.
        """
        self.circuit = circuit
        self.rv_idxs = tuple(rv_idxs)
        self.rows = dict(rows)

    def __len__(self) -> int:
        return len(self.rows)

    def get(self, key, default=None):
        # Node for instance `key`, or `default` if there is no such row.
        return self.rows.get(key, default)

    def keys(self) -> Iterable[TableInstance]:
        # The table instances (row keys).
        return self.rows.keys()

    def values(self) -> Iterable[CircuitNode]:
        # The circuit nodes (row values).
        return self.rows.values()

    def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
        # The (instance, node) rows. Mirrors the pure-Python CircuitTable.items();
        # since `rows` is not visible from Python, this is the only direct way
        # for Python callers to iterate instance/node pairs together.
        return self.rows.items()

    def __getitem__(self, key):
        return self.rows[key]

    def __setitem__(self, key, value):
        self.rows[key] = value

    cpdef CircuitNode top(self):
        # Get the circuit top value.
        #
        # Raises:
        #     RuntimeError if there is more than one row in the table.
        #
        # Returns:
        #     The circuit's zero node if the table has no rows, otherwise
        #     the node of the table's single row.
        cdef int number_of_rows = len(self.rows)
        if number_of_rows == 0:
            return self.circuit.zero
        elif number_of_rows == 1:
            return next(iter(self.rows.values()))
        else:
            raise RuntimeError('cannot get top node from a table with more that 1 row')
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# ==================================================================================
|
|
92
|
+
# Circuit Table Operations
|
|
93
|
+
# ==================================================================================
|
|
94
|
+
|
|
95
|
+
cpdef CircuitTable sum_out(CircuitTable table, object rv_idxs: Iterable[int]):
    # Return a circuit table that results from summing out
    # the given random variables of this circuit table.
    #
    # Normally this will return a new table. However, if rv_idxs is empty,
    # then the given table is returned unmodified.
    #
    # Rows that agree on the remaining random variables are grouped and their
    # nodes added. A group holding a single row reuses that row's node directly,
    # rather than adding a one-argument ADD node to the circuit.
    #
    # Raises:
    #     ValueError if rv_idxs is not a subset of table.rv_idxs.
    #     ValueError if rv_idxs contains duplicates.
    cdef tuple[int, ...] rv_idxs_seq = tuple(rv_idxs)

    if len(rv_idxs_seq) == 0:
        # nothing to do
        return table

    cdef set[int] rv_idxs_set = set(rv_idxs_seq)
    if len(rv_idxs_set) != len(rv_idxs_seq):
        raise ValueError('rv_idxs contains duplicates')
    if not rv_idxs_set.issubset(table.rv_idxs):
        raise ValueError('rv_idxs is not a subset of table.rv_idxs')

    cdef int rv_index
    cdef list[int] remaining_rv_idxs = []
    for rv_index in table.rv_idxs:
        if rv_index not in rv_idxs_set:
            remaining_rv_idxs.append(rv_index)

    if len(remaining_rv_idxs) == 0:
        # Special case: summing out all random variables
        return sum_out_all(table)

    # index_map[i] is the location in table.rv_idxs for remaining_rv_idxs[i]
    cdef list[int] index_map = []
    for rv_index in remaining_rv_idxs:
        index_map.append(_find(table.rv_idxs, rv_index))

    # Group the row nodes by their instance restricted to the remaining rvs.
    cdef dict[tuple[int, ...], list[CircuitNode]] groups = {}
    cdef object got
    cdef list[int] group_instance
    cdef tuple[int, ...] group_instance_tuple
    cdef int i
    cdef CircuitNode node
    cdef tuple[int, ...] instance
    for instance, node in table.rows.items():
        group_instance = []
        for i in index_map:
            group_instance.append(instance[i])
        group_instance_tuple = tuple(group_instance)
        got = groups.get(group_instance_tuple)
        if got is None:
            groups[group_instance_tuple] = [node]
        else:
            got.append(node)

    cdef Circuit circuit = table.circuit
    cdef CircuitTable new_table = CircuitTable(circuit, remaining_rv_idxs)
    cdef dict[tuple[int, ...], CircuitNode] rows = new_table.rows
    cdef list to_add

    for group_instance_tuple, to_add in groups.items():
        if len(to_add) == 1:
            # The sum of a single node is that node; no new circuit node needed.
            node = to_add[0]
        else:
            node = circuit.op(c_ADD, tuple(to_add))
        if not node.is_zero:
            rows[group_instance_tuple] = node

    return new_table
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
cpdef CircuitTable sum_out_all(CircuitTable table):
    # Return a circuit table that results from summing out
    # all random variables of this circuit table.
    #
    # The result has no random variables and at most one row, keyed by the
    # empty instance (). It has no rows if `table` has no rows or if the sum
    # is a constant zero node, consistent with tables holding no zero rows.
    cdef Circuit circuit = table.circuit
    cdef CircuitNode node
    cdef int num_rows = len(table.rows)
    if num_rows == 0:
        return CircuitTable(circuit, ())
    elif num_rows == 1:
        # A single row needs no addition node.
        node = next(iter(table.rows.values()))
    else:
        node = circuit.op(c_ADD, tuple(table.rows.values()))

    if node.is_zero:
        return CircuitTable(circuit, ())
    return CircuitTable(circuit, (), [((), node)])
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
cpdef CircuitTable project(CircuitTable table: CircuitTable, object rv_idxs: Iterable[int]):
    # Keep only the given random variables of the table, summing out all others.
    #
    # Equivalent to `sum_out(table, set(table.rv_idxs) - set(rv_idxs))`;
    # entries of rv_idxs that the table does not have are ignored.
    cdef set[int] keep = set(rv_idxs)
    cdef set[int] to_sum_out = {rv for rv in table.rv_idxs if rv not in keep}
    return sum_out(table, to_sum_out)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
cpdef CircuitTable product(CircuitTable x, CircuitTable y):
    # Return a circuit table that results from the product of the two given tables.
    #
    # If the table with fewer rows has no random variables and a single row
    # whose node is the constant one, then the other table is returned
    # unmodified. Otherwise, a new circuit table will be constructed and returned.
    #
    # Raises:
    #     ValueError if x and y do not refer to the same circuit.
    cdef int i
    cdef Circuit circuit = x.circuit
    if y.circuit is not circuit:
        raise ValueError('circuit tables must refer to the same circuit')

    # Make the smaller table 'y', and the other 'x'.
    # This is to minimise the index size on 'y'.
    if len(x) < len(y):
        x, y = y, x

    # Special case: y has no random variables, so it is the constant
    # one (a single one-valued row) or zero (no rows).
    if len(y.rv_idxs) == 0:
        if len(y) == 1 and y.top().is_one:
            return x
        elif len(y) == 0:
            return CircuitTable(circuit, x.rv_idxs)

    # Set operations on rv indexes. After these operations:
    # * co_rv_idxs_set is the set of rv indexes common (co) to x and y,
    # * yo_rv_idxs_set is the set of rv indexes in y only (yo), and not in x.
    cdef set[int] yo_rv_idxs_set = set(y.rv_idxs)
    cdef set[int] co_rv_idxs_set = set(x.rv_idxs)
    co_rv_idxs_set.intersection_update(yo_rv_idxs_set)
    yo_rv_idxs_set.difference_update(co_rv_idxs_set)

    if len(co_rv_idxs_set) == 0:
        # Special case: no common random variables.
        return _product_no_common_rvs(x, y)

    # Convert random variable index sets to sequences
    cdef tuple[int, ...] yo_rv_idxs = tuple(yo_rv_idxs_set)  # y only random variables
    cdef tuple[int, ...] co_rv_idxs = tuple(co_rv_idxs_set)  # common random variables

    # Cache mappings from result Instance to index into source Instance (x or y).
    # This will be used in indexing and product loops to pull out needed values
    # from the source instances.
    cdef list[int] co_from_x_map = []
    cdef list[int] co_from_y_map = []
    cdef list[int] yo_from_y_map = []
    for rv_index in co_rv_idxs:
        co_from_x_map.append(_find(x.rv_idxs, rv_index))
        co_from_y_map.append(_find(y.rv_idxs, rv_index))
    for rv_index in yo_rv_idxs:
        yo_from_y_map.append(_find(y.rv_idxs, rv_index))

    cdef list[int] co
    cdef list[int] yo
    cdef object got
    cdef tuple[int, ...] co_tuple
    cdef tuple[int, ...] yo_tuple

    cdef CircuitTable table = CircuitTable(circuit, x.rv_idxs + yo_rv_idxs)
    cdef dict[tuple[int, ...], CircuitNode] rows = table.rows

    # Index the y rows by common-only key (y is the smaller of the two tables).
    cdef dict[tuple[int, ...], list[tuple[tuple[int, ...], CircuitNode]]] y_index = {}
    for y_instance, y_node in y.rows.items():
        co = []
        yo = []
        for i in co_from_y_map:
            co.append(y_instance[i])
        for i in yo_from_y_map:
            yo.append(y_instance[i])
        co_tuple = tuple(co)
        yo_tuple = tuple(yo)
        got = y_index.get(co_tuple)
        if got is None:
            y_index[co_tuple] = [(yo_tuple, y_node)]
        else:
            got.append((yo_tuple, y_node))

    # Iterate over x rows, inserting (instance, value) for each matching y row.
    # Multiplications by a constant one node are optimised out.
    for x_instance, x_node in x.rows.items():
        co = []
        for i in co_from_x_map:
            co.append(x_instance[i])
        co_tuple = tuple(co)

        got = y_index.get(co_tuple)
        if got is None:
            # No y row agrees with this x row on the common random variables.
            continue

        if x_node.is_one:
            # Multiplying by one: the y node is the product.
            for yo_tuple, y_node in got:
                rows[x_instance + yo_tuple] = y_node
        else:
            for yo_tuple, y_node in got:
                if y_node.is_one:
                    rows[x_instance + yo_tuple] = x_node
                else:
                    rows[x_instance + yo_tuple] = circuit.op(c_MUL, (x_node, y_node))

    return table
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
cdef int _find(tuple[int, ...] xs, int x) except -1:
    # Return the position of x within xs.
    #
    # Declared `except -1` (a value no valid position can take) so the
    # RuntimeError below is always propagated to the caller rather than
    # being reported and swallowed, independent of the Cython version's
    # default exception semantics for C-returning cdef functions.
    #
    # Raises:
    #     RuntimeError if x is not in xs. Callers only look up rv indexes
    #     drawn from xs itself, so this should never happen.
    cdef int i
    for i in range(len(xs)):
        if xs[i] == x:
            return i
    # Very unexpected
    raise RuntimeError('not found')
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
cdef CircuitTable _product_no_common_rvs(CircuitTable x, CircuitTable y):
    # Return the product of x and y, where x and y share no random variables.
    #
    # Every x row pairs with every y row (a cross product), so unlike the
    # general product algorithm no index on common random variables needs to
    # be constructed. Multiplication by a constant one node is skipped: the
    # other operand's node is used directly.
    #
    # Assumes:
    #     * There are no common random variables between x and y.
    #     * x and y are for the same circuit.
    cdef Circuit circuit = x.circuit
    cdef CircuitTable result = CircuitTable(circuit, x.rv_idxs + y.rv_idxs)
    cdef dict[tuple[int, ...], CircuitNode] result_rows = result.rows
    cdef tuple[int, ...] instance
    cdef bint x_is_one

    for x_instance, x_node in x.rows.items():
        x_is_one = x_node.is_one
        for y_instance, y_node in y.rows.items():
            instance = x_instance + y_instance
            if x_is_one:
                result_rows[instance] = y_node
            elif y_node.is_one:
                result_rows[instance] = x_node
            else:
                result_rows[instance] = circuit.op(c_MUL, (x_node, y_node))

    return result
|
|
332
|
+
|
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Sequence, Tuple, Dict, Iterable, Set, Iterator
|
|
4
|
+
|
|
5
|
+
from ck.circuit import CircuitNode, Circuit
|
|
6
|
+
from ck.utils.map_list import MapList
|
|
7
|
+
|
|
8
|
+
# A table instance: one int per random variable of a table, in rv_idxs order.
TableInstance = Tuple[int, ...]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CircuitTable:
    """
    Associates circuit nodes with instances of a set of (zero or more) random variables.

    Each row maps a table instance (one int per random variable, in rv_idxs
    order) to the CircuitNode for that instance.

    Operations on circuit tables typically add circuit nodes to the circuit.
    They heuristically avoid adding unnecessary nodes (e.g. addition of zero,
    multiplication by zero or one). Even so, interim circuit nodes may be
    created that end up unused, so consider calling
    `Circuit.remove_unreachable_op_nodes` once all circuit table operations
    are complete.

    By convention, no row is created with a constant zero node; such rows
    are assumed to have been optimised out already.
    """

    def __init__(
            self,
            circuit: Circuit,
            rv_idxs: Sequence[int],
            rows: Iterable[Tuple[TableInstance, CircuitNode]] = (),
    ):
        """
        Args:
            circuit: the circuit owning the nodes held in this table.
            rv_idxs: indexes of the table's random variables (copied to a tuple).
            rows: optional initial (instance, node) pairs (copied into a dict).

        Assumes:
            * rv_idxs contains no duplicates.
            * every row instance conforms to the indexed random variables.
            * every row node belongs to the given circuit.
        """
        self._circuit: Circuit = circuit
        self._rv_idxs: Tuple[int, ...] = tuple(rv_idxs)
        self._rows: Dict[TableInstance, CircuitNode] = dict(rows)

    @property
    def circuit(self) -> Circuit:
        """The circuit that owns this table's nodes."""
        return self._circuit

    @property
    def rv_idxs(self) -> Tuple[int, ...]:
        """Indexes of the table's random variables, in instance order."""
        return self._rv_idxs

    def __len__(self) -> int:
        return len(self._rows)

    def get(self, key, default=None):
        """Return the node for instance `key`, or `default` if there is no such row."""
        return self._rows.get(key, default)

    def keys(self) -> Iterable[TableInstance]:
        """The table instances (row keys)."""
        return self._rows.keys()

    def values(self) -> Iterable[CircuitNode]:
        """The circuit nodes (row values)."""
        return self._rows.values()

    def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
        """The (instance, node) rows."""
        return self._rows.items()

    def __getitem__(self, key):
        return self._rows[key]

    def __setitem__(self, key, value):
        self._rows[key] = value

    def top(self) -> CircuitNode:
        """
        Get the circuit top value.

        Returns:
            the circuit's zero node for an empty table, otherwise the node
            of the table's single row.

        Raises:
            RuntimeError if there is more than one row in the table.
        """
        rows = self._rows
        if len(rows) > 1:
            raise RuntimeError('cannot get top node from a table with more that 1 row')
        if rows:
            return next(iter(rows.values()))
        return self._circuit.zero
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# ==================================================================================
|
|
95
|
+
# Circuit Table Operations
|
|
96
|
+
# ==================================================================================
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def sum_out(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
    """
    Sum out the given random variables of a circuit table.

    Rows of `table` that agree on the remaining random variables are grouped,
    and each group's nodes are added with `circuit.add`, giving one row per group.

    Args:
        table: the circuit table to sum over.
        rv_idxs: indexes of the random variables to sum out.

    Returns:
        A new table over the remaining random variables (kept in their
        original order). If rv_idxs is empty, `table` itself is returned
        unmodified; if no random variables remain, the result of
        `sum_out_all(table)` is returned.

    Raises:
        ValueError if rv_idxs is not a subset of table.rv_idxs.
        ValueError if rv_idxs contains duplicates.
    """
    to_remove: Sequence[int] = tuple(rv_idxs)
    if len(to_remove) == 0:
        return table  # nothing to do

    to_remove_set: Set[int] = set(to_remove)
    if len(to_remove_set) != len(to_remove):
        raise ValueError('rv_idxs contains duplicates')
    source_rv_idxs = table.rv_idxs
    if not to_remove_set.issubset(source_rv_idxs):
        raise ValueError('rv_idxs is not a subset of table.rv_idxs')

    kept_rv_idxs = tuple(rv for rv in source_rv_idxs if rv not in to_remove_set)
    if not kept_rv_idxs:
        # Special case: summing out all random variables.
        return sum_out_all(table)

    # positions[i] is where kept_rv_idxs[i] sits within table.rv_idxs.
    positions = tuple(map(source_rv_idxs.index, kept_rv_idxs))

    # Gather the nodes for each instance of the kept variables, then add each
    # group once. (A one-pass variant keeping a running circuit.add per row
    # was tried; this two-pass grouping appeared to perform better.)
    groups: MapList[TableInstance, CircuitNode] = MapList()
    for instance, node in table.items():
        groups.append(tuple(instance[pos] for pos in positions), node)

    circuit: Circuit = table.circuit
    summed = [(group, circuit.add(nodes)) for group, nodes in groups.items()]
    return CircuitTable(circuit, kept_rv_idxs, summed)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def sum_out_all(table: CircuitTable) -> CircuitTable:
    """
    Return a circuit table that results from summing out
    all random variables of this circuit table.

    The result has no random variables and at most one row, keyed by the
    empty instance `()`. It has no rows if `table` has no rows or if the sum
    is a constant zero node, consistent with the convention that circuit
    tables hold no constant-zero rows.
    """
    circuit: Circuit = table.circuit
    num_rows: int = len(table)
    if num_rows == 0:
        return CircuitTable(circuit, ())
    if num_rows == 1:
        # A single row needs no addition node.
        node: CircuitNode = next(iter(table.values()))
    else:
        node = circuit.optimised_add(table.values())

    # Applied to both branches so a zero sum never becomes a row.
    if node.is_zero:
        return CircuitTable(circuit, ())
    return CircuitTable(circuit, (), [((), node)])
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def project(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
    """
    Keep only the given random variables of a table, summing out all others.

    Equivalent to `sum_out(table, set(table.rv_idxs) - set(rv_idxs))`;
    entries of rv_idxs that the table does not have are ignored.
    """
    keep: Set[int] = set(rv_idxs)
    to_sum_out: Set[int] = {rv for rv in table.rv_idxs if rv not in keep}
    return sum_out(table, to_sum_out)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
    """
    Return a circuit table that results from the product of the two given tables.

    The result's random variables are those of the table with more rows (x on
    a tie), followed by the remaining random variables of the other table.
    A result row exists for each pair of rows that agree on the common
    random variables.

    If the table with fewer rows has no random variables and a single row
    whose node is the constant one, then the other table is returned
    unmodified. Otherwise, a new circuit table will be constructed and returned.

    Raises:
        ValueError if x and y do not refer to the same circuit.
    """
    circuit: Circuit = x.circuit
    if y.circuit is not circuit:
        raise ValueError('circuit tables must refer to the same circuit')

    # Make the smaller table 'y', and the other 'x'.
    # This is to minimise the index size on 'y'.
    if len(x) < len(y):
        x, y = y, x

    x_rv_idxs: Tuple[int, ...] = x.rv_idxs
    y_rv_idxs: Tuple[int, ...] = y.rv_idxs

    # Special case: y has no random variables, so it is the constant
    # one (a single one-valued row) or zero (no rows).
    if y_rv_idxs == ():
        if len(y) == 1 and y.top().is_one:
            return x
        elif len(y) == 0:
            return CircuitTable(circuit, x_rv_idxs)

    # Set operations on rv indexes. After these operations:
    # * co_rv_idxs_set is the set of rv indexes common (co) to x and y,
    # * yo_rv_idxs_set is the set of rv indexes in y only (yo), and not in x.
    yo_rv_idxs_set: Set[int] = set(y_rv_idxs)
    co_rv_idxs_set: Set[int] = set(x_rv_idxs)
    co_rv_idxs_set.intersection_update(yo_rv_idxs_set)
    yo_rv_idxs_set.difference_update(co_rv_idxs_set)

    if len(co_rv_idxs_set) == 0:
        # Special case: no common random variables.
        return _product_no_common_rvs(x, y)

    # Convert random variable index sets to sequences
    yo_rv_idxs: Tuple[int, ...] = tuple(yo_rv_idxs_set)  # y only random variables
    co_rv_idxs: Tuple[int, ...] = tuple(co_rv_idxs_set)  # common random variables

    # Cache mappings from result Instance to index into source Instance (x or y).
    # This will be used in indexing and product loops to pull out needed values
    # from the source instances.
    co_from_x_map = tuple(x_rv_idxs.index(rv_index) for rv_index in co_rv_idxs)
    co_from_y_map = tuple(y_rv_idxs.index(rv_index) for rv_index in co_rv_idxs)
    yo_from_y_map = tuple(y_rv_idxs.index(rv_index) for rv_index in yo_rv_idxs)

    # Index the y rows by common-only key (y is the smaller of the two tables).
    y_index: MapList[TableInstance, Tuple[TableInstance, CircuitNode]] = MapList()
    for y_instance, y_node in y.items():
        co = tuple(y_instance[i] for i in co_from_y_map)
        yo = tuple(y_instance[i] for i in yo_from_y_map)
        y_index.append(co, (yo, y_node))

    def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
        # Iterate over x rows, yielding (instance, value) for each matching y row.
        # Multiplications by a constant one node are optimised out.
        for _x_instance, _x_node in x.items():
            _matches = y_index.get(tuple(_x_instance[i] for i in co_from_x_map), ())
            if _x_node.is_one:
                # Multiplying by one: the y node is the product.
                for _yo, _y_node in _matches:
                    yield _x_instance + _yo, _y_node
            else:
                for _yo, _y_node in _matches:
                    if _y_node.is_one:
                        yield _x_instance + _yo, _x_node
                    else:
                        yield _x_instance + _yo, circuit.mul(_x_node, _y_node)

    return CircuitTable(circuit, x_rv_idxs + yo_rv_idxs, _result_rows())
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def _product_no_common_rvs(x: CircuitTable, y: CircuitTable) -> CircuitTable:
    """
    Return the product of x and y, where x and y have no common random variables.

    This is an optimisation of the more general product algorithm, as no index
    needs to be constructed based on the common random variables: every x row
    is paired with every y row.

    Multiplications by a constant one node are optimised out: the other
    operand's node is used directly.

    Assumes:
        * There are no common random variables between x and y.
        * x and y are for the same circuit.
    """
    circuit: Circuit = x.circuit

    result_rv_idxs: Tuple[int, ...] = x.rv_idxs + y.rv_idxs

    def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
        for x_instance, x_node in x.items():
            if x_node.is_one:
                # 1 * y == y
                for y_instance, y_node in y.items():
                    yield x_instance + y_instance, y_node
            else:
                for y_instance, y_node in y.items():
                    if y_node.is_one:
                        # x * 1 == x: the product is x_node, not the one node.
                        yield x_instance + y_instance, x_node
                    else:
                        yield x_instance + y_instance, circuit.mul(x_node, y_node)

    return CircuitTable(circuit, result_rv_idxs, _result_rows())
|