compiled-knowledge 4.0.0a6__cp313-cp313-macosx_10_13_universal2.whl → 4.0.0a8__cp313-cp313-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/circuit.cpython-313-darwin.so +0 -0
- ck/circuit/circuit.pyx +773 -0
- ck/circuit/circuit_node.pyx +138 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +239 -0
- ck/pgm_compiler/support/circuit_table/circuit_table.cpython-313-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/circuit_table.pyx +325 -0
- {compiled_knowledge-4.0.0a6.dist-info → compiled_knowledge-4.0.0a8.dist-info}/METADATA +1 -1
- {compiled_knowledge-4.0.0a6.dist-info → compiled_knowledge-4.0.0a8.dist-info}/RECORD +12 -8
- {compiled_knowledge-4.0.0a6.dist-info → compiled_knowledge-4.0.0a8.dist-info}/WHEEL +0 -0
- {compiled_knowledge-4.0.0a6.dist-info → compiled_knowledge-4.0.0a8.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a6.dist-info → compiled_knowledge-4.0.0a8.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Tuple
|
|
4
|
+
|
|
5
|
+
# Python type for the values held by ConstNode objects (also accepted when
# binding a VarNode to a constant via its `const` property).
ConstValue = float | int | bool
|
|
7
|
+
|
|
8
|
+
cdef class CircuitNode:
    """
    Base class for the nodes of an arithmetic circuit.

    A node is an op node, a var node, or a const node.

    An op node applies an arithmetic operation (for example, mul or add) to zero
    or more argument nodes.

    A var node carries an integer index, `idx`, counting from zero in creation
    order. A var node may temporarily be set to a constant, which can be useful
    when optimising a compiled circuit.

    The `+` and `*` operators delegate to the owning circuit's `add` and `mul`.
    """
    cdef public object circuit

    def __init__(self, circuit):
        self.circuit = circuit

    cpdef int is_zero(self) except*:
        # A generic node is not known to be the constant zero; subclasses override.
        return 0

    cpdef int is_one(self) except*:
        # A generic node is not known to be the constant one; subclasses override.
        return 0

    def __add__(self, other: CircuitNode | ConstValue):
        owner = self.circuit
        return owner.add(self, other)

    def __mul__(self, other: CircuitNode | ConstValue):
        owner = self.circuit
        return owner.mul(self, other)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
cdef class ConstNode(CircuitNode):
    """
    A node in a circuit representing a constant value.
    """
    # The docstring must be the first statement of the class body to become
    # `ConstNode.__doc__`, so it precedes this attribute declaration.
    cdef public object value

    def __init__(self, circuit, value: ConstValue):
        super().__init__(circuit)
        self.value: ConstValue = value

    cpdef int is_zero(self) except*:
        # Identity test against `self.circuit.zero`: some other const node whose
        # value happens to be 0 does not count.
        return self is self.circuit.zero

    cpdef int is_one(self) except*:
        # Identity test against `self.circuit.one`: some other const node whose
        # value happens to be 1 does not count.
        return self is self.circuit.one

    def __str__(self) -> str:
        return 'const(' + str(self.value) + ')'

    def __lt__(self, other) -> bool:
        # Const nodes are ordered by value; never less than a non-const node.
        if isinstance(other, ConstNode):
            return self.value < other.value
        else:
            return False
|
|
66
|
+
|
|
67
|
+
cdef class VarNode(CircuitNode):
    """
    A node in a circuit representing an input variable.

    `idx` is the variable's index.

    Assigning to `const` binds the variable to a constant node, obtained via
    `circuit.const(value)`. Assigning None unbinds it. While the variable is
    bound, `is_zero` and `is_one` report on the bound constant.
    """
    cdef public int idx
    cdef object _const

    def __init__(self, circuit, idx: int):
        super().__init__(circuit)
        self.idx = idx
        self._const = None

    cpdef int is_zero(self) except*:
        # Zero only while bound to a constant that is itself zero.
        if self._const is None:
            return False
        return self._const.is_zero()

    cpdef int is_one(self) except*:
        # One only while bound to a constant that is itself one.
        if self._const is None:
            return False
        return self._const.is_one()

    cpdef int is_const(self) except*:
        # Whether the variable is currently bound to a constant.
        return not (self._const is None)

    @property
    def const(self) -> Optional[ConstNode]:
        # The bound constant node, or None when unbound.
        return self._const

    @const.setter
    def const(self, value: ConstValue | ConstNode | None) -> None:
        self._const = None if value is None else self.circuit.const(value)

    def __lt__(self, other) -> bool:
        # Var nodes are ordered by index; never less than a non-var node.
        return isinstance(other, VarNode) and self.idx < other.idx

    def __str__(self) -> str:
        text = 'var[' + str(self.idx) + ']'
        if self._const is not None:
            text += ' = ' + str(self._const.value)
        return text
|
|
110
|
+
|
|
111
|
+
cdef class OpNode(CircuitNode):
    """
    A node in a circuit representing an arithmetic operation.

    `symbol` names the operation and `args` holds its argument nodes
    (zero or more).
    """
    cdef public tuple[object, ...] args
    cdef public str symbol

    def __init__(self, object circuit, symbol: str, tuple[object, ...] args: Tuple[CircuitNode, ...]):
        # `args` may hold any number of nodes, hence `Tuple[CircuitNode, ...]`
        # (a bare `Tuple[CircuitNode]` would denote a 1-tuple).
        super().__init__(circuit)
        self.args = tuple(args)
        self.symbol = str(symbol)

    def __str__(self) -> str:
        # Symbol and argument count separated by a backslash, e.g. mul\3.
        return self.symbol + '\\' + str(len(self.args))
|
|
125
|
+
|
|
126
|
+
cdef class MulNode(OpNode):
    """
    An op node that multiplies its argument nodes (symbol 'mul').
    """
    def __init__(self, object circuit, tuple[object, ...] args: Tuple[CircuitNode, ...]):
        OpNode.__init__(self, circuit, 'mul', args)
|
|
132
|
+
|
|
133
|
+
cdef class AddNode(OpNode):
    """
    An op node that adds its argument nodes (symbol 'add').
    """
    def __init__(self, object circuit, tuple[object, ...] args: Tuple[CircuitNode, ...]):
        OpNode.__init__(self, circuit, 'add', args)
|
|
Binary file
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pickletools import long1
|
|
4
|
+
from typing import Sequence, Dict, List, Tuple, Set, Optional, Iterator
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import ctypes as ct
|
|
8
|
+
|
|
9
|
+
from ck import circuit
|
|
10
|
+
from ck.circuit import CircuitNode, ConstNode, VarNode, OpNode, ADD, Circuit
|
|
11
|
+
from ck.circuit_compiler.support.circuit_analyser import CircuitAnalysis, analyze_circuit
|
|
12
|
+
from ck.circuit_compiler.support.input_vars import infer_input_vars, InputVars
|
|
13
|
+
from ck.program.raw_program import RawProgram, RawProgramFunction
|
|
14
|
+
from ck.utils.np_extras import DType, NDArrayNumeric, NDArray, DTypeNumeric
|
|
15
|
+
|
|
16
|
+
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
|
|
17
|
+
|
|
18
|
+
cimport numpy as cnp
|
|
19
|
+
cimport cython
|
|
20
|
+
|
|
21
|
+
# Initialise the numpy C API, as needed when using `cimport numpy`.
cnp.import_array()

# float64 dtype at the Python level, and the matching C type at the Cython level.
DTYPE_FLOAT64 = np.float64
ctypedef cnp.float64_t DTYPE_FLOAT64_t
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def make_function(
        var_nodes: Sequence[VarNode],
        result_nodes: Sequence[CircuitNode],
        dtype: DTypeNumeric,
) -> Tuple[RawProgramFunction, int]:
    """
    Make a RawProgram function that interprets the circuit with a Cython VM.

    Args:
        var_nodes: the input variable nodes of the circuit.
        result_nodes: the nodes whose values the function computes.
        dtype: numpy dtype of the program arrays. Only numpy.float64 is
            supported; other dtypes raise RuntimeError.

    Returns:
        (function, number_of_tmps). `function(vars_in, tmps, result)` takes
        three ctypes pointers. `number_of_tmps` is the number of temporary
        slots the function uses.
    """
    analysis: CircuitAnalysis = analyze_circuit(var_nodes, result_nodes)
    cdef Instructions program
    consts_array: NDArrayNumeric
    program, consts_array = _make_instructions_from_analysis(analysis, dtype)

    # The pointer object returned by `ndarray.ctypes.data_as` keeps a reference
    # to the array. Capturing it in the closure below therefore keeps the
    # constants alive for as long as the function exists.
    element_type = np.ctypeslib.as_ctypes_type(dtype)
    consts_ptr = consts_array.ctypes.data_as(ct.POINTER(element_type))

    def function(vars_in: ct.POINTER, tmps: ct.POINTER, result: ct.POINTER) -> None:
        # RawProgramFunction signature: (vars_in, tmps, result) -> None.
        # Turn each ctypes pointer into a raw address, then run the VM on it.
        cdef size_t vars_address = ct.cast(vars_in, ct.c_void_p).value
        cdef size_t tmps_address = ct.cast(tmps, ct.c_void_p).value
        cdef size_t consts_address = ct.cast(consts_ptr, ct.c_void_p).value
        cdef size_t result_address = ct.cast(result, ct.c_void_p).value
        cvm_float64(
            <double*> vars_address,
            <double*> tmps_address,
            <double*> consts_address,
            <double*> result_address,
            program,
        )

    return function, len(analysis.op_to_tmp)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# VM instructions: opcodes stored in Instruction.symbol.
# ADD and MUL take their values from the circuit module. COPY is one more than
# the larger of the two, so it is distinct from both.
ADD = circuit.ADD
MUL = circuit.MUL
COPY: int = max(ADD, MUL) + 1

# VM arrays: values of ElementID.array. Each selects which of the four arrays
# passed to the VM an element lives in.
VARS: int = 0
TMPS: int = 1
CONSTS: int = 2
RESULT: int = 3
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _make_instructions_from_analysis(
        analysis: CircuitAnalysis,
        dtype: DTypeNumeric,
) -> Tuple[Instructions, NDArrayNumeric]:
    """
    Compile a circuit analysis into VM instructions and a constants array.

    Returns:
        (instructions, consts). `consts` is a numpy array holding the values of
        `analysis.const_nodes`, in order.

    Raises:
        RuntimeError if `dtype` is not numpy.float64.
    """
    if dtype != np.float64:
        raise RuntimeError(f'only DType {np.float64} currently supported')

    cdef ElementID target
    cdef list sources

    # Position of each const node (keyed by object id) in the constants array.
    const_position: Dict[int, int] = {}
    for position, const_node in enumerate(analysis.const_nodes):
        const_position[id(const_node)] = position
    np_consts: NDArrayNumeric = np.zeros(len(const_position), dtype=dtype)
    for position, const_node in enumerate(analysis.const_nodes):
        np_consts[position] = const_node.value

    # For every node that can be read or written, the VM element holding its
    # value. Later assignments win, so an op node that is also a result is
    # written straight into the RESULT array rather than a tmp.
    location: Dict[int, ElementID] = {}
    for node_id, const_slot in const_position.items():
        location[node_id] = ElementID(CONSTS, const_slot)
    for var_slot, var_node in enumerate(analysis.var_nodes):
        if var_node.is_const():
            # A var bound to a constant reads from that constant's slot.
            location[id(var_node)] = location[id(var_node.const)]
        else:
            location[id(var_node)] = ElementID(VARS, var_slot)
    for node_id, tmp_slot in analysis.op_to_tmp.items():
        location[node_id] = ElementID(TMPS, tmp_slot)
    for node_id, result_slot in analysis.op_to_result.items():
        location[node_id] = ElementID(RESULT, result_slot)

    instructions: Instructions = Instructions()

    # One instruction per op node, in analysis order.
    op_node: OpNode
    for op_node in analysis.op_nodes:
        target = location[id(op_node)]
        sources = [location[id(arg)] for arg in op_node.args]
        instructions.append(op_node.symbol, sources, target)

    # Results that are not op nodes (vars, consts) are copied into place.
    for out_slot, out_node in enumerate(analysis.result_nodes):
        if isinstance(out_node, OpNode):
            continue
        target = ElementID(RESULT, out_slot)
        sources = [location[id(out_node)]]
        instructions.append(COPY, sources, target)

    return instructions, np_consts
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
cdef struct ElementID:
    # Identifies one value slot of the VM.
    #   array: which array the slot is in (VARS, TMPS, CONSTS or RESULT).
    #   index: the position within that array.
    # Field order matters: instances are built positionally, e.g. ElementID(TMPS, i).
    int array  # VARS, TMPS, CONSTS, RESULT
    int index  # index into the array
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
cdef struct Instruction:
    # One VM instruction: combine the `num_args` elements in `args` according
    # to `symbol`, and store the value at `dest`.
    # `args` points to a PyMem-allocated array that is freed by
    # Instructions.__dealloc__.
    # Field order matters: instances are built positionally in Instructions.append.
    int symbol  # ADD, MUL, COPY
    int num_args
    ElementID* args
    ElementID dest
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
cdef class Instructions:
    """
    A growable array of VM instructions, read directly by the VM.

    The instruction array, and each instruction's `args` array, are owned by
    this object and released in `__dealloc__`.
    """
    cdef Instruction* instructions
    cdef int num_instructions
    cdef int capacity  # number of Instruction slots currently allocated

    def __cinit__(self):
        # __cinit__ always runs, so __dealloc__ always sees a consistent
        # (possibly NULL) buffer.
        self.instructions = NULL
        self.num_instructions = 0
        self.capacity = 0

    def append(self, int symbol, list[ElementID] args, ElementID dest) -> None:
        """
        Append an instruction.

        Args:
            symbol: the opcode (ADD, MUL or COPY).
            args: the elements to read; copied into a newly allocated C array.
            dest: the element to write.

        Raises:
            MemoryError if an allocation fails. The object is left unchanged
            and nothing is leaked.
        """
        cdef int num_args = len(args)
        cdef int i
        cdef int new_capacity
        cdef Instruction* grown
        cdef ElementID* c_args

        # Ensure there is room for one more instruction. Growing geometrically
        # keeps appends amortised O(1). On failure PyMem_Realloc leaves the old
        # block valid, so only overwrite the pointer on success (otherwise the
        # old block would leak).
        if self.num_instructions == self.capacity:
            new_capacity = 16 if self.capacity == 0 else 2 * self.capacity
            grown = <Instruction*> PyMem_Realloc(
                self.instructions,
                sizeof(Instruction) * new_capacity
            )
            if not grown:
                raise MemoryError()
            self.instructions = grown
            self.capacity = new_capacity

        # PyMem_Malloc(0) returns a non-NULL pointer, so zero-argument
        # instructions take the same path.
        c_args = <ElementID*> PyMem_Malloc(num_args * sizeof(ElementID))
        if not c_args:
            raise MemoryError()
        try:
            for i in range(num_args):
                c_args[i] = args[i]
        except BaseException:
            # Converting an argument failed; don't leak the args array.
            PyMem_Free(c_args)
            raise

        self.instructions[self.num_instructions] = Instruction(
            symbol,
            num_args,
            c_args,
            dest,
        )
        self.num_instructions += 1

    def __dealloc__(self):
        cdef int i
        cdef Instruction* instructions = self.instructions
        if instructions:
            for i in range(self.num_instructions):
                PyMem_Free(instructions[i].args)
            PyMem_Free(instructions)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@cython.boundscheck(False)  # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cdef void cvm_float64(
        double* vars_in,
        double* tmps,
        double* consts,
        double* result,
        Instructions instructions,
) except *:
    # Core virtual machine: execute each instruction in order.
    #
    # Each instruction reads its argument elements and combines them according
    # to its symbol (ADD: sum, MUL: product, COPY: the single argument). It then
    # writes the value to its destination element.
    #
    # Raises:
    #     RuntimeError if an instruction has an unknown symbol. The explicit
    #     `except *` makes this propagate to the caller.

    cdef int i, step, num_args, symbol
    cdef double accumulator
    cdef ElementID* args
    cdef ElementID elem

    # The opcodes are Python-level module globals. Convert them to C ints once
    # per call, rather than comparing Python objects for every instruction.
    cdef int add_symbol = ADD
    cdef int mul_symbol = MUL
    cdef int copy_symbol = COPY

    # index the four arrays by constants VARS, TMPS, CONSTS, and RESULT
    cdef (double*) arrays[4]
    arrays[VARS] = vars_in
    arrays[TMPS] = tmps
    arrays[CONSTS] = consts
    arrays[RESULT] = result

    cdef Instruction* instruction_ptr = instructions.instructions
    for step in range(instructions.num_instructions):

        symbol = instruction_ptr[0].symbol
        args = instruction_ptr[0].args
        num_args = instruction_ptr[0].num_args

        # Seed the accumulator with the first argument. An op node may have
        # zero arguments; its args array then has no elements and must not be
        # read. Use the operation's identity instead (1 for MUL, otherwise 0).
        if num_args > 0:
            elem = args[0]
            accumulator = arrays[elem.array][elem.index]
        elif symbol == mul_symbol:
            accumulator = 1.0
        else:
            accumulator = 0.0

        if symbol == add_symbol:
            for i in range(1, num_args):
                elem = args[i]
                accumulator += arrays[elem.array][elem.index]
        elif symbol == mul_symbol:
            for i in range(1, num_args):
                elem = args[i]
                accumulator *= arrays[elem.array][elem.index]
        elif symbol == copy_symbol:
            pass
        else:
            raise RuntimeError('symbol not understood: ' + str(symbol))

        elem = instruction_ptr[0].dest
        arrays[elem.array][elem.index] = accumulator

        # Advance the instruction pointer
        instruction_ptr = &(instruction_ptr[1])
|
|
Binary file
|
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Sequence, Tuple, Iterable, Iterator
|
|
4
|
+
|
|
5
|
+
from ck.circuit import CircuitNode, Circuit, OpNode, MUL
|
|
6
|
+
|
|
7
|
+
# A row key of a CircuitTable: a tuple of ints, one per random variable,
# positionally aligned with the owning table's rv_idxs.
TableInstance = Tuple[int, ...]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
cdef class CircuitTable:
    """
    A circuit table manages a set of CircuitNodes, where each node corresponds
    to an instance for a set of (zero or more) random variables.

    Operations on circuit tables typically add circuit nodes to the circuit. They
    heuristically avoid adding unnecessary nodes (e.g. addition of zero, or
    multiplication by zero or one). However, interim circuit nodes may be created
    that end up not being used. Consider calling
    `Circuit.remove_unreachable_op_nodes` after completing all circuit table
    operations.

    It is generally expected that no CircuitTable row will be created with a
    constant zero node. Such rows are assumed to have been optimised out already.
    """

    cdef public object circuit
    cdef public tuple[int, ...] rv_idxs
    cdef public dict[tuple[int, ...], CircuitNode] rows

    def __init__(
            self,
            circuit: Circuit,
            rv_idxs: Sequence[int],
            rows: Iterable[Tuple[TableInstance, CircuitNode]] = (),
    ):
        """
        Args:
            circuit: the circuit whose nodes are being managed by this table.
            rv_idxs: indexes of random variables.
            rows: optional rows to add to the table.

        Assumes:
            * rv_idxs contains no duplicates.
            * all row instances conform to the indexed random variables.
            * all row circuit nodes belong to the given circuit.
        """
        self.circuit = circuit
        self.rv_idxs = tuple(rv_idxs)
        self.rows = dict(rows)

    def __len__(self) -> int:
        """Number of rows in the table."""
        return len(self.rows)

    def get(self, key, default=None):
        """Return the node for row instance `key`, or `default` if there is no such row."""
        return self.rows.get(key, default)

    def __getitem__(self, key):
        return self.rows[key]

    def __setitem__(self, key, value):
        self.rows[key] = value

    cpdef object top(self):  # -> CircuitNode:
        """
        Get the circuit top value.

        Returns:
            `self.circuit.zero` if the table has no rows, otherwise the node of
            its single row.

        Raises:
            RuntimeError if there is more than one row in the table.
        """
        cdef int number_of_rows = len(self.rows)
        if number_of_rows == 0:
            return self.circuit.zero
        elif number_of_rows == 1:
            return next(iter(self.rows.values()))
        else:
            raise RuntimeError('cannot get top node from a table with more than 1 row')
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# ==================================================================================
|
|
80
|
+
# Circuit Table Operations
|
|
81
|
+
# ==================================================================================
|
|
82
|
+
|
|
83
|
+
cpdef object sum_out(object table: CircuitTable, object rv_idxs: Iterable[int]):  # -> CircuitTable:
    # Sum out the given random variables from a circuit table.
    #
    # Rows whose instances agree on the remaining random variables are combined
    # with `circuit.optimised_add`. Combined rows that are zero are dropped.
    #
    # Returns a new table, except when rv_idxs is empty: then the given table
    # is returned unmodified. If every random variable of the table is summed
    # out, the work is delegated to `sum_out_all`.
    #
    # Raises:
    #     ValueError if rv_idxs contains duplicates.
    #     ValueError if rv_idxs is not a subset of table.rv_idxs.
    cdef tuple[int, ...] dropped = tuple(rv_idxs)
    if len(dropped) == 0:
        return table  # nothing to sum out

    cdef set[int] dropped_set = set(dropped)
    if len(dropped_set) != len(dropped):
        raise ValueError('rv_idxs contains duplicates')
    if not dropped_set.issubset(table.rv_idxs):
        raise ValueError('rv_idxs is not a subset of table.rv_idxs')

    # Random variables that survive, in their original table order.
    cdef int rv
    cdef list[int] kept = []
    for rv in table.rv_idxs:
        if rv not in dropped_set:
            kept.append(rv)

    if len(kept) == 0:
        return sum_out_all(table)

    # positions[j] is where kept[j] sits in table.rv_idxs, and so in each instance.
    cdef list[int] positions = []
    for rv in kept:
        positions.append(_find(table.rv_idxs, rv))

    # Bucket the row nodes by their instance restricted to the kept variables.
    cdef dict[tuple[int, ...], list[object]] buckets = {}
    cdef tuple[int, ...] instance
    cdef tuple[int, ...] key
    cdef object node
    cdef object bucket
    for instance, node in table.rows.items():
        key = tuple([instance[p] for p in positions])
        bucket = buckets.get(key)
        if bucket is None:
            buckets[key] = [node]
        else:
            bucket.append(node)

    cdef object circuit = table.circuit
    cdef object new_table = CircuitTable(circuit, kept)
    cdef dict[tuple[int, ...], object] new_rows = new_table.rows
    cdef object total
    for key, nodes in buckets.items():
        total = circuit.optimised_add(nodes)
        if not total.is_zero():
            new_rows[key] = total

    return new_table
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
cpdef object sum_out_all(object table: CircuitTable):  # -> CircuitTable:
    # Sum out every random variable of the table.
    #
    # The result has no random variables. It has either no rows (the sum is
    # zero, or the table was empty) or a single row `()` holding the sum.
    cdef object circuit = table.circuit
    cdef object node
    cdef Py_ssize_t num_rows = len(table)

    if num_rows == 0:
        return CircuitTable(circuit, ())

    if num_rows == 1:
        # A single row is its own sum; no new node is needed.
        node = next(iter(table.rows.values()))
    else:
        node = circuit.optimised_add(table.rows.values())
        if node.is_zero():
            return CircuitTable(circuit, ())

    return CircuitTable(circuit, (), [((), node)])
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
cpdef object project(object table: CircuitTable, object rv_idxs: Iterable[int]):  # -> CircuitTable:
    # Keep only the given random variables by summing out all the others.
    # Equivalent to `sum_out(table, set(table.rv_idxs) - set(rv_idxs))`.
    # Entries of rv_idxs that are not in the table are ignored.
    cdef set[int] to_sum_out = set(table.rv_idxs).difference(rv_idxs)
    return sum_out(table, to_sum_out)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
cpdef object product(x: CircuitTable, y: CircuitTable):  # -> CircuitTable:
    # Return the circuit table that is the product of the two given tables.
    #
    # Rows are joined on their common random variables. Each joined row's node
    # is the product of the two row nodes (see `_optimised_mul`).
    #
    # After ordering the tables so that y has no more rows than x, two shortcuts
    # apply when y has no random variables:
    #     * if y is the single-row "one" table, x is returned unchanged;
    #     * if y has no rows, an empty table over x's random variables is returned.
    # Otherwise a new circuit table is constructed and returned.
    #
    # Raises:
    #     ValueError if x and y do not refer to the same circuit.
    cdef object circuit = x.circuit
    if y.circuit is not circuit:
        raise ValueError('circuit tables must refer to the same circuit')

    # Make y the table with fewer rows, since y is the one that gets indexed.
    if len(y) > len(x):
        x, y = y, x

    if len(y.rv_idxs) == 0:
        if len(y) == 0:
            return CircuitTable(circuit, x.rv_idxs)
        if len(y) == 1 and y.top().is_one():
            return x

    # Split y's random variables into those shared with x (common) and the
    # rest (y only). The set operations are kept in exactly this form because
    # the iteration order of the y-only set fixes the result's rv order.
    cdef set[int] y_only_set = set(y.rv_idxs)
    cdef set[int] common_set = set(x.rv_idxs)
    common_set.intersection_update(y_only_set)
    y_only_set.difference_update(common_set)

    if len(common_set) == 0:
        # Nothing to join on: a plain cross product.
        return _product_no_common_rvs(x, y)

    cdef tuple[int, ...] y_only_rv_idxs = tuple(y_only_set)
    cdef tuple[int, ...] common_rv_idxs = tuple(common_set)

    # Positions of the common / y-only random variables within x and y instances.
    cdef list[int] common_in_x = [_find(x.rv_idxs, rv) for rv in common_rv_idxs]
    cdef list[int] common_in_y = [_find(y.rv_idxs, rv) for rv in common_rv_idxs]
    cdef list[int] y_only_in_y = [_find(y.rv_idxs, rv) for rv in y_only_rv_idxs]

    cdef object table = CircuitTable(circuit, x.rv_idxs + y_only_rv_idxs)
    cdef dict[tuple[int, ...], object] rows = table.rows

    # Index y's rows by the values of the common variables:
    #     common key -> [(y-only values, y node), ...]
    cdef dict y_index = {}
    cdef object matches
    for y_instance, y_node in y.rows.items():
        common_key = tuple([y_instance[p] for p in common_in_y])
        y_only_key = tuple([y_instance[p] for p in y_only_in_y])
        matches = y_index.get(common_key)
        if matches is None:
            y_index[common_key] = [(y_only_key, y_node)]
        else:
            matches.append((y_only_key, y_node))

    # Join each x row with the y rows that agree on the common variables.
    for x_instance, x_node in x.rows.items():
        common_key = tuple([x_instance[p] for p in common_in_x])
        x_is_one = x_node.is_one()
        matches = y_index.get(common_key)
        if matches is None:
            continue
        for y_only_key, y_node in matches:
            if x_is_one:
                # Multiplying by one: reuse the y node directly.
                rows[x_instance + y_only_key] = y_node
            else:
                rows[x_instance + y_only_key] = _optimised_mul(circuit, x_node, y_node)

    return table
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
cdef int _find(tuple[int, ...] xs, int x) except -1:
    # Return the position of `x` in `xs`.
    #
    # Raises:
    #     RuntimeError if `x` is not in `xs`. Callers only search for values
    #     known to be present, so this signals an internal error.
    #
    # A valid position is never negative. The explicit `except -1` therefore
    # guarantees the RuntimeError reaches the caller, rather than being printed
    # and ignored with a bogus position of 0 returned (the behaviour of a
    # `cdef int` function without an exception specification on older Cython).
    cdef int i
    for i in range(len(xs)):
        if xs[i] == x:
            return i
    # Very unexpected
    raise RuntimeError('not found')
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
cdef object _product_no_common_rvs(x: CircuitTable, y: CircuitTable):  # -> CircuitTable:
    # Return the product of x and y when they share no random variables.
    #
    # This specialises the general product: with nothing to join on, every x
    # row is paired with every y row, so no index on common variables is needed.
    # When an x row's node is one, the y node is reused directly.
    #
    # Assumes:
    #     * x and y have no random variables in common.
    #     * x and y refer to the same circuit.
    cdef object circuit = x.circuit
    cdef object table = CircuitTable(circuit, x.rv_idxs + y.rv_idxs)
    cdef dict rows = table.rows
    cdef tuple[int, ...] instance

    for x_instance, x_node in x.rows.items():
        x_is_one = x_node.is_one()
        for y_instance, y_node in y.rows.items():
            instance = x_instance + y_instance
            rows[instance] = y_node if x_is_one else _optimised_mul(circuit, x_node, y_node)

    return table
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
cdef object _optimised_mul(object circuit: Circuit, object x: CircuitNode, object y: CircuitNode):  # -> CircuitNode
    # Multiply two circuit nodes, avoiding a new node where possible:
    #     * a zero factor is returned as-is (x is tested before y);
    #     * a one factor yields the other factor (x is tested before y);
    #     * otherwise the result is `circuit.mul(x, y)`.
    cdef object product_node
    if x.is_zero():
        product_node = x
    elif y.is_zero():
        product_node = y
    elif x.is_one():
        product_node = y
    elif y.is_one():
        product_node = x
    else:
        product_node = circuit.mul(x, y)
    return product_node
|