compiled-knowledge 4.0.0a16__cp312-cp312-win_amd64.whl → 4.0.0a18__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/__init__.py +9 -2
- ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
- ck/circuit/_circuit_cy.pxd +33 -0
- ck/circuit/{circuit.pyx → _circuit_cy.pyx} +115 -133
- ck/circuit/{circuit_py.py → _circuit_py.py} +16 -8
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +88 -60
- ck/circuit_compiler/named_circuit_compilers.py +1 -1
- ck/pgm_compiler/factor_elimination.py +23 -13
- ck/pgm_compiler/support/circuit_table/__init__.py +9 -2
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy_cpp_verion.pyx +601 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy_minimal_version.pyx +311 -0
- ck/pgm_compiler/support/circuit_table/{circuit_table.pyx → _circuit_table_cy_v4.0.0a17.pyx} +9 -9
- ck/pgm_compiler/support/circuit_table/{circuit_table_py.py → _circuit_table_py.py} +80 -45
- ck/pgm_compiler/support/clusters.py +16 -4
- ck/pgm_compiler/support/factor_tables.py +1 -1
- ck/pgm_compiler/support/join_tree.py +67 -10
- ck/pgm_compiler/support/named_compiler_maker.py +12 -2
- ck/pgm_compiler/variable_elimination.py +2 -0
- ck/utils/iter_extras.py +8 -1
- ck_demos/pgm_compiler/demo_compiler_dump.py +10 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/utils/compare.py +5 -1
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/METADATA +1 -1
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/RECORD +30 -29
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/WHEEL +1 -1
- ck/circuit/circuit.c +0 -38861
- ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
- ck/circuit/circuit_node.pyx +0 -138
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +0 -17373
- ck/pgm_compiler/support/circuit_table/circuit_table.c +0 -16042
- ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/top_level.txt +0 -0
|
@@ -1,17 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from
|
|
4
|
-
from typing import Sequence, Dict, List, Tuple, Set, Optional, Iterator
|
|
3
|
+
from typing import Dict, Tuple, Sequence
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
import ctypes as ct
|
|
8
7
|
|
|
9
8
|
from ck import circuit
|
|
10
|
-
from ck.circuit import
|
|
9
|
+
from ck.circuit import OpNode, VarNode, CircuitNode
|
|
11
10
|
from ck.circuit_compiler.support.circuit_analyser import CircuitAnalysis, analyze_circuit
|
|
12
|
-
from ck.
|
|
13
|
-
from ck.
|
|
14
|
-
from ck.utils.np_extras import DType, NDArrayNumeric, NDArray, DTypeNumeric
|
|
11
|
+
from ck.program.raw_program import RawProgramFunction
|
|
12
|
+
from ck.utils.np_extras import NDArrayNumeric, DTypeNumeric
|
|
15
13
|
|
|
16
14
|
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
|
|
17
15
|
|
|
@@ -63,21 +61,21 @@ def make_function(
|
|
|
63
61
|
|
|
64
62
|
|
|
65
63
|
# VM instructions
|
|
66
|
-
ADD = circuit.ADD
|
|
67
|
-
MUL = circuit.MUL
|
|
68
|
-
|
|
64
|
+
cdef int ADD = circuit.ADD
|
|
65
|
+
cdef int MUL = circuit.MUL
|
|
66
|
+
cdef int COPY = max(ADD, MUL) + 1
|
|
69
67
|
|
|
70
68
|
# VM arrays
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
69
|
+
cdef int VARS = 0
|
|
70
|
+
cdef int TMPS = 1
|
|
71
|
+
cdef int CONSTS = 2
|
|
72
|
+
cdef int RESULT = 3
|
|
75
73
|
|
|
76
74
|
|
|
77
|
-
|
|
78
|
-
analysis: CircuitAnalysis,
|
|
79
|
-
dtype: DTypeNumeric,
|
|
80
|
-
) -> Tuple[Instructions, NDArrayNumeric]:
|
|
75
|
+
cdef tuple[Instructions, cnp.ndarray] _make_instructions_from_analysis(
|
|
76
|
+
object analysis: CircuitAnalysis,
|
|
77
|
+
object dtype: DTypeNumeric,
|
|
78
|
+
): # -> Tuple[Instructions, NDArrayNumeric]:
|
|
81
79
|
if dtype != np.float64:
|
|
82
80
|
raise RuntimeError(f'only DType {np.float64} currently supported')
|
|
83
81
|
|
|
@@ -91,7 +89,7 @@ def _make_instructions_from_analysis(
|
|
|
91
89
|
np_consts[i] = node.value
|
|
92
90
|
|
|
93
91
|
# Where to get input values for each possible node.
|
|
94
|
-
|
|
92
|
+
cdef dict[int, ElementID] node_to_element = {}
|
|
95
93
|
# const nodes
|
|
96
94
|
for node_id, const_idx in node_to_const_idx.items():
|
|
97
95
|
node_to_element[node_id] = ElementID(CONSTS, const_idx)
|
|
@@ -110,21 +108,16 @@ def _make_instructions_from_analysis(
|
|
|
110
108
|
# Build instructions
|
|
111
109
|
instructions: Instructions = Instructions()
|
|
112
110
|
|
|
113
|
-
op_node
|
|
111
|
+
cdef object op_node
|
|
114
112
|
for op_node in analysis.op_nodes:
|
|
115
|
-
|
|
116
|
-
args: list[ElementID] = [
|
|
117
|
-
node_to_element[id(arg)]
|
|
118
|
-
for arg in op_node.args
|
|
119
|
-
]
|
|
120
|
-
instructions.append(op_node.symbol, args, dest)
|
|
113
|
+
instructions.append_op(op_node.symbol, op_node, node_to_element)
|
|
121
114
|
|
|
122
115
|
# Add any copy operations, i.e., result nodes that are not op nodes
|
|
123
116
|
for i, node in enumerate(analysis.result_nodes):
|
|
124
117
|
if not isinstance(node, OpNode):
|
|
125
118
|
dest: ElementID = ElementID(RESULT, i)
|
|
126
|
-
|
|
127
|
-
instructions.
|
|
119
|
+
src: ElementID = node_to_element[id(node)]
|
|
120
|
+
instructions.append_copy(src, dest)
|
|
128
121
|
|
|
129
122
|
return instructions, np_consts
|
|
130
123
|
|
|
@@ -136,39 +129,70 @@ cdef struct ElementID:
|
|
|
136
129
|
|
|
137
130
|
cdef struct Instruction:
|
|
138
131
|
int symbol # ADD, MUL, COPY
|
|
139
|
-
|
|
132
|
+
Py_ssize_t num_args
|
|
140
133
|
ElementID* args
|
|
141
134
|
ElementID dest
|
|
142
135
|
|
|
143
136
|
|
|
144
137
|
cdef class Instructions:
|
|
145
138
|
cdef Instruction* instructions
|
|
139
|
+
cdef int allocated
|
|
146
140
|
cdef int num_instructions
|
|
147
141
|
|
|
148
|
-
def __init__(self):
|
|
149
|
-
self.instructions = <Instruction*> PyMem_Malloc(0)
|
|
142
|
+
def __init__(self) -> None:
|
|
150
143
|
self.num_instructions = 0
|
|
144
|
+
self.allocated = 64
|
|
145
|
+
self.instructions = <Instruction*> PyMem_Malloc(self.allocated * sizeof(Instruction))
|
|
146
|
+
|
|
147
|
+
cdef void append_copy(
|
|
148
|
+
self,
|
|
149
|
+
ElementID src,
|
|
150
|
+
ElementID dest,
|
|
151
|
+
) except*:
|
|
152
|
+
c_args = <ElementID*> PyMem_Malloc(sizeof(ElementID))
|
|
153
|
+
if not c_args:
|
|
154
|
+
raise MemoryError()
|
|
151
155
|
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
156
|
+
c_args[0] = src
|
|
157
|
+
self._append(COPY, 1, c_args, dest)
|
|
158
|
+
|
|
159
|
+
cdef void append_op(self, int symbol, object op_node: OpNode, dict[int, ElementID] node_to_element) except*:
|
|
160
|
+
args = op_node.args
|
|
161
|
+
cdef Py_ssize_t num_args = len(args)
|
|
155
162
|
|
|
156
|
-
|
|
157
|
-
|
|
163
|
+
# Create the instruction arguments array
|
|
164
|
+
c_args = <ElementID*> PyMem_Malloc(num_args * sizeof(ElementID))
|
|
158
165
|
if not c_args:
|
|
159
166
|
raise MemoryError()
|
|
160
167
|
|
|
161
|
-
|
|
162
|
-
|
|
168
|
+
cdef Py_ssize_t i = num_args
|
|
169
|
+
while i > 0:
|
|
170
|
+
i -= 1
|
|
171
|
+
c_args[i] = node_to_element[id(args[i])]
|
|
172
|
+
|
|
173
|
+
dest: ElementID = node_to_element[id(op_node)]
|
|
174
|
+
|
|
175
|
+
self._append(symbol, num_args, c_args, dest)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
cdef void _append(self, int symbol, Py_ssize_t num_args, ElementID* c_args, ElementID dest) except *:
|
|
179
|
+
cdef int i
|
|
163
180
|
|
|
164
181
|
cdef int num_instructions = self.num_instructions
|
|
165
|
-
self.instructions = <Instruction*> PyMem_Realloc(
|
|
166
|
-
self.instructions,
|
|
167
|
-
sizeof(Instruction) * (num_instructions + 1)
|
|
168
|
-
)
|
|
169
|
-
if not self.instructions:
|
|
170
|
-
raise MemoryError()
|
|
171
182
|
|
|
183
|
+
# Ensure sufficient instruction memory
|
|
184
|
+
cdef int allocated = self.allocated
|
|
185
|
+
if num_instructions == allocated:
|
|
186
|
+
allocated *= 2
|
|
187
|
+
self.instructions = <Instruction*> PyMem_Realloc(
|
|
188
|
+
self.instructions,
|
|
189
|
+
allocated * sizeof(Instruction),
|
|
190
|
+
)
|
|
191
|
+
if not self.instructions:
|
|
192
|
+
raise MemoryError()
|
|
193
|
+
self.allocated = allocated
|
|
194
|
+
|
|
195
|
+
# Add the instruction
|
|
172
196
|
self.instructions[num_instructions] = Instruction(
|
|
173
197
|
symbol,
|
|
174
198
|
num_args,
|
|
@@ -177,7 +201,7 @@ cdef class Instructions:
|
|
|
177
201
|
)
|
|
178
202
|
self.num_instructions = num_instructions + 1
|
|
179
203
|
|
|
180
|
-
def __dealloc__(self):
|
|
204
|
+
def __dealloc__(self) -> None:
|
|
181
205
|
cdef Instruction* instructions = self.instructions
|
|
182
206
|
if instructions:
|
|
183
207
|
for i in range(self.num_instructions):
|
|
@@ -194,15 +218,16 @@ cdef void cvm_float64(
|
|
|
194
218
|
double* consts,
|
|
195
219
|
double* result,
|
|
196
220
|
Instructions instructions,
|
|
197
|
-
)
|
|
198
|
-
# Core virtual machine.
|
|
221
|
+
) except *:
|
|
222
|
+
# Core virtual machine (for dtype float64).
|
|
199
223
|
|
|
200
|
-
cdef int
|
|
224
|
+
cdef int symbol
|
|
225
|
+
cdef Py_ssize_t i
|
|
201
226
|
cdef double accumulator
|
|
202
227
|
cdef ElementID* args
|
|
203
|
-
cdef ElementID
|
|
228
|
+
cdef ElementID elem
|
|
204
229
|
|
|
205
|
-
#
|
|
230
|
+
# Index the four arrays by constants VARS, TMPS, CONSTS, and RESULT
|
|
206
231
|
cdef (double*) arrays[4]
|
|
207
232
|
arrays[VARS] = vars_in
|
|
208
233
|
arrays[TMPS] = tmps
|
|
@@ -210,29 +235,32 @@ cdef void cvm_float64(
|
|
|
210
235
|
arrays[RESULT] = result
|
|
211
236
|
|
|
212
237
|
cdef Instruction* instruction_ptr = instructions.instructions
|
|
213
|
-
|
|
238
|
+
cdef int num_instructions = instructions.num_instructions
|
|
214
239
|
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
240
|
+
while num_instructions > 0:
|
|
241
|
+
num_instructions -= 1
|
|
242
|
+
|
|
243
|
+
symbol = instruction_ptr.symbol
|
|
244
|
+
args = instruction_ptr.args
|
|
218
245
|
|
|
219
246
|
elem = args[0]
|
|
220
247
|
accumulator = arrays[elem.array][elem.index]
|
|
221
248
|
|
|
222
249
|
if symbol == ADD:
|
|
223
|
-
|
|
250
|
+
i = instruction_ptr.num_args
|
|
251
|
+
while i > 1:
|
|
252
|
+
i -= 1
|
|
224
253
|
elem = args[i]
|
|
225
254
|
accumulator += arrays[elem.array][elem.index]
|
|
226
255
|
elif symbol == MUL:
|
|
227
|
-
|
|
256
|
+
i = instruction_ptr.num_args
|
|
257
|
+
while i > 1:
|
|
258
|
+
i -= 1
|
|
228
259
|
elem = args[i]
|
|
229
260
|
accumulator *= arrays[elem.array][elem.index]
|
|
230
|
-
|
|
231
|
-
pass
|
|
232
|
-
else:
|
|
233
|
-
raise RuntimeError('symbol not understood: ' + str(symbol))
|
|
261
|
+
# else symbol == COPY, nothing to do
|
|
234
262
|
|
|
235
|
-
elem = instruction_ptr
|
|
263
|
+
elem = instruction_ptr.dest
|
|
236
264
|
arrays[elem.array][elem.index] = accumulator
|
|
237
265
|
|
|
238
266
|
# Advance the instruction pointer
|
|
@@ -149,7 +149,7 @@ def join_tree_to_circuit(
|
|
|
149
149
|
limit_product_tree_search,
|
|
150
150
|
)
|
|
151
151
|
top: CircuitNode = top_table.top()
|
|
152
|
-
|
|
152
|
+
top.circuit.remove_unreachable_op_nodes(top)
|
|
153
153
|
|
|
154
154
|
return PGMCircuit(
|
|
155
155
|
rvs=tuple(pgm.rvs),
|
|
@@ -169,27 +169,37 @@ def _circuit_tables_from_join_tree(
|
|
|
169
169
|
) -> CircuitTable:
|
|
170
170
|
"""
|
|
171
171
|
This is a basic algorithm for constructing a circuit table from a join tree.
|
|
172
|
+
Algorithm synopsis:
|
|
173
|
+
1) Get a CircuitTable for each factor allocated to this join tree node, and
|
|
174
|
+
for each child of the join tree node (recursive call to _circuit_tables_from_join_tree).
|
|
175
|
+
2) Form a binary tree of the collected circuit tables.
|
|
176
|
+
3) Perform table products and sum-outs for each node in the binary tree, which should
|
|
177
|
+
leave a single circuit table with a single row.
|
|
172
178
|
"""
|
|
173
|
-
#
|
|
174
|
-
factors: List[CircuitTable] =
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
179
|
+
# Get all the factors to combine.
|
|
180
|
+
factors: List[CircuitTable] = list(
|
|
181
|
+
chain(
|
|
182
|
+
(
|
|
183
|
+
# The PGM factors allocated to this join tree node
|
|
184
|
+
factor_tables.get_table(factor)
|
|
185
|
+
for factor in join_tree.factors
|
|
186
|
+
),
|
|
187
|
+
(
|
|
188
|
+
# The children of this join tree node
|
|
189
|
+
_circuit_tables_from_join_tree(factor_tables, child, limit_product_tree_search)
|
|
190
|
+
for child in join_tree.children
|
|
191
|
+
),
|
|
192
|
+
)
|
|
183
193
|
)
|
|
184
194
|
|
|
185
195
|
# The usual join tree approach just forms the product of all the tables in `factors`.
|
|
186
196
|
# The tree width is not affected by the order of products, however some orders
|
|
187
197
|
# lead to smaller numbers of arithmetic operations.
|
|
188
198
|
#
|
|
189
|
-
# If `
|
|
199
|
+
# If `limit_product_tree_search > 1`, then heuristics are used to
|
|
190
200
|
# reduce the number of arithmetic operations.
|
|
191
201
|
|
|
192
|
-
# Deal with the special case:
|
|
202
|
+
# Deal with the special case: zero factors
|
|
193
203
|
if len(factors) == 0:
|
|
194
204
|
circuit = factor_tables.circuit
|
|
195
205
|
if len(join_tree.separator) == 0:
|
|
@@ -1,5 +1,12 @@
|
|
|
1
|
-
#
|
|
2
|
-
|
|
1
|
+
# Two implementations of the `circuit_table` module are provided
|
|
2
|
+
# for developer R&D purposes. One is pure Python and the other is Cython.
|
|
3
|
+
# Which implementation is used can be selected here.
|
|
4
|
+
# A similar selection can be made for the `circuit` module.
|
|
5
|
+
# Note that if the Cython implementation is chosen for `circuit_table` then
|
|
6
|
+
# the Cython implementation must be chosen for `circuit`.
|
|
7
|
+
|
|
8
|
+
# from ._circuit_table_py import (
|
|
9
|
+
from ._circuit_table_cy import (
|
|
3
10
|
CircuitTable,
|
|
4
11
|
TableInstance,
|
|
5
12
|
sum_out,
|
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Sequence, Tuple, Iterable
|
|
4
|
+
|
|
5
|
+
from ck.circuit import MUL, ADD
|
|
6
|
+
|
|
7
|
+
from ck.circuit._circuit_cy cimport Circuit, CircuitNode
|
|
8
|
+
|
|
9
|
+
# C int copy of the ck.circuit ADD symbol, passed as the symbol argument to
# Circuit.op when summing rows in the table operations of this module.
cdef int c_ADD = ADD
|
|
10
|
+
# C int copy of the ck.circuit MUL symbol, passed as the symbol argument to
# Circuit.op when multiplying rows in the table operations of this module.
cdef int c_MUL = MUL
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Type alias for a circuit table row key: a tuple of ints, one per random
# variable, positionally aligned with the owning table's `rv_idxs`.
TableInstance = Tuple[int, ...]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
cdef class CircuitTable:
    """
    A circuit table manages a set of CircuitNodes, where each node corresponds
    to an instance for a set of (zero or more) random variables.

    Rows are held in a dict mapping each instance (a tuple of ints, positionally
    aligned with `rv_idxs`) to its circuit node.

    Operations on circuit tables typically add circuit nodes to the circuit. It will
    heuristically avoid adding unnecessary nodes (e.g. addition of zero, multiplication
    by zero or one.) However, it may be that interim circuit nodes are created that
    end up not being used. Consider calling `Circuit.remove_unreachable_op_nodes` after
    completing all circuit table operations.

    It is generally expected that no CircuitTable row will be created with a constant
    zero node. These are assumed to be optimised out already.
    """

    cdef public Circuit circuit  # the circuit whose nodes this table manages
    cdef public tuple[int, ...] rv_idxs  # random variable indexes, in instance order
    cdef dict[tuple[int, ...], CircuitNode] rows  # instance -> circuit node

    def __init__(
            self,
            circuit: Circuit,
            rv_idxs: Sequence[int],
            rows: Iterable[Tuple[TableInstance, CircuitNode]] = (),
    ):
        """
        Args:
            circuit: the circuit whose nodes are being managed by this table.
            rv_idxs: indexes of random variables.
            rows: optional (instance, node) pairs to add to the table.

        Assumes:
            * rv_idxs contains no duplicates.
            * all row instances conform to the indexed random variables.
            * all row circuit nodes belong to the given circuit.
        """
        self.circuit = circuit
        self.rv_idxs = tuple(rv_idxs)
        self.rows = dict(rows)

    def __len__(self) -> int:
        """Return the number of rows in the table."""
        return len(self.rows)

    def get(self, key, default=None):
        """Return the node of the row with instance `key`, or `default` if there is none."""
        return self.rows.get(key, default)

    def keys(self) -> Iterable[Tuple[int, ...]]:
        """Return a view of the row instances."""
        return self.rows.keys()

    def values(self) -> Iterable[CircuitNode]:
        """Return a view of the row circuit nodes."""
        return self.rows.values()

    def __getitem__(self, key):
        """Return the node of the row with instance `key` (KeyError if absent)."""
        return self.rows[key]

    def __setitem__(self, key, value):
        """Set, or replace, the node of the row with instance `key`."""
        self.rows[key] = value

    cpdef CircuitNode top(self):
        """
        Get the circuit top value.

        Returns:
            the node of the single row, or the circuit's zero node if the
            table has no rows.

        Raises:
            RuntimeError if there is more than one row in the table.
        """
        cdef Py_ssize_t number_of_rows = len(self.rows)
        if number_of_rows == 0:
            return self.circuit.zero
        elif number_of_rows == 1:
            return next(iter(self.rows.values()))
        else:
            raise RuntimeError('cannot get top node from a table with more than 1 row')
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# ==================================================================================
|
|
92
|
+
# Circuit Table Operations
|
|
93
|
+
# ==================================================================================
|
|
94
|
+
|
|
95
|
+
cpdef CircuitTable sum_out(CircuitTable table, object rv_idxs: Iterable[int]):
    # Sum out the given random variables of `table`.
    #
    # Rows that agree on every remaining random variable are combined with a
    # single ADD node. A resulting sum that is the constant zero node is
    # dropped from the new table.
    #
    # Returns:
    #     `table` itself (unmodified) if rv_idxs is empty;
    #     `sum_out_all(table)` if rv_idxs covers every random variable of table;
    #     otherwise a new table whose random variables keep their original order.
    #
    # Raises:
    #     ValueError if rv_idxs contains duplicates.
    #     ValueError if rv_idxs is not a subset of table.rv_idxs.
    cdef tuple to_remove = tuple(rv_idxs)
    if len(to_remove) == 0:
        return table

    cdef set to_remove_set = set(to_remove)
    if len(to_remove_set) != len(to_remove):
        raise ValueError('rv_idxs contains duplicates')
    if not to_remove_set.issubset(table.rv_idxs):
        raise ValueError('rv_idxs is not a subset of table.rv_idxs')

    cdef list kept_rv_idxs = [
        rv_index for rv_index in table.rv_idxs if rv_index not in to_remove_set
    ]
    if len(kept_rv_idxs) == 0:
        # Every random variable is being summed out.
        return sum_out_all(table)

    # Position within table.rv_idxs (and so within each row instance) of every
    # kept random variable.
    cdef list kept_positions = [_find(table.rv_idxs, rv_index) for rv_index in kept_rv_idxs]

    # Group the row nodes by their projection onto the kept random variables.
    cdef dict groups = {}
    for instance, node in table.rows.items():
        key = tuple([instance[pos] for pos in kept_positions])
        groups.setdefault(key, []).append(node)

    cdef Circuit circuit = table.circuit
    cdef CircuitTable result = CircuitTable(circuit, kept_rv_idxs)
    cdef dict result_rows = result.rows
    cdef CircuitNode total
    for key, group_nodes in groups.items():
        total = circuit.op(c_ADD, tuple(group_nodes))
        if not total.is_zero:
            result_rows[key] = total

    return result
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
cpdef CircuitTable sum_out_all(CircuitTable table):
    # Sum out every random variable of `table`.
    #
    # The result has no random variables. It has one row keyed by `()`, holding
    # the sum of all the row nodes (or the only row node when there is just
    # one). It is empty if `table` is empty or the sum is the constant zero node.
    cdef Circuit circuit = table.circuit
    cdef CircuitNode total
    cdef Py_ssize_t num_rows = len(table)

    if num_rows == 0:
        return CircuitTable(circuit, ())

    if num_rows == 1:
        # A single row needs no ADD node.
        total = next(iter(table.rows.values()))
    else:
        total = circuit.op(c_ADD, tuple(table.rows.values()))
        if total.is_zero:
            return CircuitTable(circuit, ())

    return CircuitTable(circuit, (), [((), total)])
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
cpdef CircuitTable project(CircuitTable table: CircuitTable, object rv_idxs: Iterable[int]):
    # Restrict `table` to the given random variables by summing out every one
    # of its random variables that is not listed in `rv_idxs`.
    #
    # Entries of `rv_idxs` that are not random variables of `table` are ignored.
    # This delegates to `sum_out`, so `table` itself is returned when nothing
    # needs to be summed out.
    cdef set to_sum_out = set(table.rv_idxs).difference(rv_idxs)
    return sum_out(table, to_sum_out)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
cpdef CircuitTable product(CircuitTable x, CircuitTable y):
    # Return a circuit table that results from the product of the two given tables.
    #
    # The table with more rows plays the role of 'x' and the other of 'y'. The
    # result's random variables are x's followed by those found only in y.
    #
    # If y has no random variables and a single row holding the constant one
    # node, x is returned as is. Otherwise, a new circuit table is constructed
    # and returned.
    #
    # Raises:
    #     ValueError if x and y refer to different circuits.
    cdef Circuit circuit = x.circuit
    if y.circuit is not circuit:
        raise ValueError('circuit tables must refer to the same circuit')

    # Keep the table with fewer rows in 'y', so the index built over it stays small.
    if len(y) > len(x):
        x, y = y, x

    # Shortcuts when y has no random variables: y is either the constant one
    # (nothing to multiply by) or empty (so is the product).
    if len(y.rv_idxs) == 0:
        if len(y) == 1 and y.top().is_one:
            return x
        elif len(y) == 0:
            return CircuitTable(circuit, x.rv_idxs)

    # Split the random variables into common (in x and y) and y-only sets.
    # Note: the result's random variable order follows the iteration order of
    # yo_rv_idxs_set.
    cdef set[int] yo_rv_idxs_set = set(y.rv_idxs)
    cdef set[int] co_rv_idxs_set = set(x.rv_idxs)
    co_rv_idxs_set.intersection_update(yo_rv_idxs_set)
    yo_rv_idxs_set.difference_update(co_rv_idxs_set)

    if len(co_rv_idxs_set) == 0:
        # Nothing shared, so no index is needed.
        return _product_no_common_rvs(x, y)

    cdef tuple co_rv_idxs = tuple(co_rv_idxs_set)
    cdef tuple yo_rv_idxs = tuple(yo_rv_idxs_set)

    # Positions within the source instances (x or y) of each common or y-only
    # random variable, used to pull the needed values out of each instance.
    cdef list co_from_x_map = [_find(x.rv_idxs, rv_index) for rv_index in co_rv_idxs]
    cdef list co_from_y_map = [_find(y.rv_idxs, rv_index) for rv_index in co_rv_idxs]
    cdef list yo_from_y_map = [_find(y.rv_idxs, rv_index) for rv_index in yo_rv_idxs]

    cdef CircuitTable result = CircuitTable(circuit, x.rv_idxs + yo_rv_idxs)
    cdef dict rows = result.rows

    # Index the y rows by common-variable key: key -> [(y-only values, node), ...].
    cdef dict y_index = {}
    for y_instance, y_node in y.rows.items():
        co_key = tuple([y_instance[i] for i in co_from_y_map])
        yo_part = tuple([y_instance[i] for i in yo_from_y_map])
        y_index.setdefault(co_key, []).append((yo_part, y_node))

    # Join each x row with the y rows that share its common-variable key.
    # Multiplying by a constant one node reuses the other node instead of
    # adding a MUL node.
    for x_instance, x_node in x.rows.items():
        x_is_one = x_node.is_one
        matches = y_index.get(tuple([x_instance[i] for i in co_from_x_map]))
        if matches is None:
            continue
        for yo_part, y_node in matches:
            if x_is_one:
                rows[x_instance + yo_part] = y_node
            elif y_node.is_one:
                rows[x_instance + yo_part] = x_node
            else:
                rows[x_instance + yo_part] = circuit.op(c_MUL, (x_node, y_node))

    return result
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
cdef int _find(tuple[int, ...] xs, int x) except -1:
    # Return the position of the first occurrence of `x` within `xs`.
    #
    # Callers only search for random variable indexes taken from the searched
    # tuple (or from sets built from it), so a miss signals an internal
    # inconsistency.
    #
    # Raises:
    #     RuntimeError if `x` is not in `xs`.
    #
    # The explicit `except -1` guarantees that the RuntimeError reaches the
    # caller. Under Cython < 3.0, a cdef function with a C return type and no
    # exception specification prints and discards the exception, which would
    # silently yield index 0.
    cdef int i
    for i in range(len(xs)):
        if xs[i] == x:
            return i
    # Very unexpected
    raise RuntimeError('not found')
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
cdef CircuitTable _product_no_common_rvs(CircuitTable x, CircuitTable y):
    # Return the product of x and y, where x and y share no random variables.
    #
    # Every x row is paired with every y row. The result instance is the x
    # instance followed by the y instance. This avoids the common-variable
    # index that the general `product` algorithm needs to construct.
    #
    # When either node of a pair is the constant one node, the other node is
    # reused directly rather than adding a MUL node.
    #
    # Assumes:
    # * There are no common random variables between x and y.
    # * x and y are for the same circuit.
    cdef Circuit circuit = x.circuit
    cdef CircuitTable result = CircuitTable(circuit, x.rv_idxs + y.rv_idxs)
    cdef dict result_rows = result.rows

    for x_instance, x_node in x.rows.items():
        x_is_one = x_node.is_one
        for y_instance, y_node in y.rows.items():
            if x_is_one:
                chosen = y_node
            elif y_node.is_one:
                chosen = x_node
            else:
                chosen = circuit.op(c_MUL, (x_node, y_node))
            result_rows[x_instance + y_instance] = chosen

    return result
|
|
332
|
+
|