compiled-knowledge 4.0.0a6__cp313-cp313-macosx_10_13_universal2.whl → 4.0.0a8__cp313-cp313-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

@@ -0,0 +1,138 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Tuple
4
+
5
# Python type for the values of ConstNode objects.
# This is a PEP 604 union (`X | Y`), evaluated when the module is imported.
ConstValue = float | int | bool
7
+
8
cdef class CircuitNode:
    """
    Base class for the nodes of an arithmetic circuit.

    Every node is one of three kinds: an op node, a var node or a const node.

    An op node performs an arithmetic operation (for example mul or add)
    over zero or more argument nodes, each of which is itself a node.

    A var node has an integer index, `idx`, counting up from zero in the
    order the var nodes were created. A var node may be temporarily bound
    to a constant, which can be useful when optimising a compiled circuit.

    The `+` and `*` operators delegate to the owning circuit's `add` and
    `mul` methods.
    """
    cdef public object circuit

    def __init__(self, circuit):
        # The circuit this node belongs to.
        self.circuit = circuit

    cpdef int is_zero(self) except*:
        # A generic node is never known to be the constant zero.
        return False

    cpdef int is_one(self) except*:
        # A generic node is never known to be the constant one.
        return False

    def __add__(self, other: CircuitNode | ConstValue):
        owner = self.circuit
        return owner.add(self, other)

    def __mul__(self, other: CircuitNode | ConstValue):
        owner = self.circuit
        return owner.mul(self, other)
38
+
39
+
40
cdef class ConstNode(CircuitNode):
    """
    A node in a circuit representing a constant value.

    Zero/one status is decided by identity with the circuit's `zero` and
    `one` nodes, not by comparing `value`.
    """
    # The docstring above must be the first statement of the class body to
    # become the class docstring, so the attribute declaration follows it.
    cdef public object value

    def __init__(self, circuit, value: ConstValue):
        super().__init__(circuit)
        self.value: ConstValue = value

    cpdef int is_zero(self) except*:
        # True only for the circuit's own zero node (identity comparison).
        # NOTE(review): presumably the circuit keeps one canonical zero node - confirm.
        return self is self.circuit.zero

    cpdef int is_one(self) except*:
        # True only for the circuit's own one node (identity comparison).
        return self is self.circuit.one

    def __str__(self) -> str:
        return 'const(' + str(self.value) + ')'

    def __lt__(self, other) -> bool:
        # Const nodes are ordered by value; comparing with anything else gives False.
        if isinstance(other, ConstNode):
            return self.value < other.value
        else:
            return False
66
+
67
cdef class VarNode(CircuitNode):
    """
    An input-variable node of a circuit.

    `idx` is the variable's index. A var node can be bound to a constant
    through the `const` property. While it is bound, `is_const()` is true
    and `is_zero()` / `is_one()` report on the bound constant.
    """
    cdef public int idx
    cdef object _const

    def __init__(self, circuit, idx: int):
        super().__init__(circuit)
        self.idx = idx
        self._const = None  # unbound until `const` is assigned

    cpdef int is_zero(self) except*:
        if self._const is None:
            return False
        return self._const.is_zero()

    cpdef int is_one(self) except*:
        if self._const is None:
            return False
        return self._const.is_one()

    cpdef int is_const(self) except*:
        return not (self._const is None)

    @property
    def const(self) -> Optional[ConstNode]:
        # The bound const node, or None when unbound.
        return self._const

    @const.setter
    def const(self, value: ConstValue | ConstNode | None) -> None:
        # Assigning None unbinds the variable. Any other value is passed
        # through `circuit.const` to obtain the node to bind.
        self._const = None if value is None else self.circuit.const(value)

    def __lt__(self, other) -> bool:
        # Var nodes are ordered by index; comparing with anything else gives False.
        if not isinstance(other, VarNode):
            return False
        return self.idx < other.idx

    def __str__(self) -> str:
        text = 'var[' + str(self.idx) + ']'
        if self._const is not None:
            text += ' = ' + str(self._const.value)
        return text
110
+
111
cdef class OpNode(CircuitNode):
    """
    A node in a circuit representing an arithmetic operation.

    `symbol` names the operation and `args` holds its argument nodes
    (zero or more).
    """
    cdef public tuple[object, ...] args
    cdef public str symbol

    def __init__(self, object circuit, symbol: str, tuple[object, ...] args: Tuple[CircuitNode, ...]):
        super().__init__(circuit)
        # tuple() turns a tuple subclass into a plain tuple and rejects None.
        self.args = tuple(args)
        self.symbol = str(symbol)

    def __str__(self) -> str:
        # The symbol and argument count, e.g. 'mul\2' for a two-argument mul.
        return self.symbol + '\\' + str(len(self.args))
125
+
126
cdef class MulNode(OpNode):
    """
    An op node for multiplication of its argument nodes.
    """
    def __init__(self, object circuit, tuple[object, ...] args: Tuple[CircuitNode, ...]):
        # The operation symbol is fixed to 'mul'.
        OpNode.__init__(self, circuit, 'mul', args)
132
+
133
cdef class AddNode(OpNode):
    """
    An op node for addition of its argument nodes.
    """
    def __init__(self, circuit, tuple[object, ...] args: Tuple[CircuitNode, ...]):
        # The operation symbol is fixed to 'add'.
        OpNode.__init__(self, circuit, 'add', args)
@@ -0,0 +1,239 @@
1
+ from __future__ import annotations
2
+
3
+ from pickletools import long1
4
+ from typing import Sequence, Dict, List, Tuple, Set, Optional, Iterator
5
+
6
+ import numpy as np
7
+ import ctypes as ct
8
+
9
+ from ck import circuit
10
+ from ck.circuit import CircuitNode, ConstNode, VarNode, OpNode, ADD, Circuit
11
+ from ck.circuit_compiler.support.circuit_analyser import CircuitAnalysis, analyze_circuit
12
+ from ck.circuit_compiler.support.input_vars import infer_input_vars, InputVars
13
+ from ck.program.raw_program import RawProgram, RawProgramFunction
14
+ from ck.utils.np_extras import DType, NDArrayNumeric, NDArray, DTypeNumeric
15
+
16
+ from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
17
+
18
+ cimport numpy as cnp
19
+ cimport cython
20
+
21
# Initialise the NumPy C API for this module; required before using cimported numpy functionality.
cnp.import_array()

# The float64 dtype and its C element type. float64 is the only dtype
# accepted by `_make_instructions_from_analysis`.
# NOTE(review): neither name is referenced elsewhere in the visible code.
DTYPE_FLOAT64 = np.float64
ctypedef cnp.float64_t DTYPE_FLOAT64_t
25
+
26
+
27
+
28
def make_function(
        var_nodes: Sequence[VarNode],
        result_nodes: Sequence[CircuitNode],
        dtype: DTypeNumeric,
) -> Tuple[RawProgramFunction, int]:
    """
    Make a RawProgram function that interprets the circuit.

    The circuit is analysed and compiled to VM instructions once, here.
    Each call of the returned function runs those instructions with
    `cvm_float64`.

    Args:
        var_nodes: the circuit's input variable nodes.
        result_nodes: the nodes whose values the function writes to its result array.
        dtype: the numeric dtype of the program arrays. Only numpy float64 is
            accepted; `_make_instructions_from_analysis` raises otherwise.

    Returns:
        (function, number_of_tmps)

    Raises:
        RuntimeError: if dtype is not float64.
    """

    analysis: CircuitAnalysis = analyze_circuit(var_nodes, result_nodes)
    cdef Instructions instructions
    np_consts: NDArrayNumeric
    instructions, np_consts = _make_instructions_from_analysis(analysis, dtype)

    # Pointer to the const values. The pointer returned by `data_as` keeps a
    # reference to `np_consts`. Because the closure below captures
    # `c_np_consts`, the const array lives as long as the returned function.
    ptr_type = ct.POINTER(np.ctypeslib.as_ctypes_type(dtype))
    c_np_consts = np_consts.ctypes.data_as(ptr_type)

    # RawProgramFunction = Callable[[ct.POINTER, ct.POINTER, ct.POINTER], None]
    def function(vars_in: ct.POINTER, tmps: ct.POINTER, result: ct.POINTER) -> None:
        # Turn each ctypes pointer into a raw address, then cast it to a C
        # double pointer. The cast to double* is valid because only float64 is accepted.
        # NOTE(review): a NULL ctypes pointer gives `.value == None`, which
        # cannot be stored in size_t. Confirm callers never pass NULL.
        cdef size_t vars_in_addr = ct.cast(vars_in, ct.c_void_p).value
        cdef size_t tmps_addr = ct.cast(tmps, ct.c_void_p).value
        cdef size_t consts_addr = ct.cast(c_np_consts, ct.c_void_p).value
        cdef size_t result_addr = ct.cast(result, ct.c_void_p).value
        cvm_float64(
            <double*> vars_in_addr,
            <double*> tmps_addr,
            <double*> consts_addr,
            <double*> result_addr,
            instructions,
        )

    return function, len(analysis.op_to_tmp)
63
+
64
+
65
# VM instructions (op codes stored in `Instruction.symbol`).
# ADD and MUL take the circuit's op codes. Note that ADD rebinds the name
# imported from ck.circuit above. COPY is a VM-only op code, made larger than
# both so it cannot collide with either; a COPY instruction moves its single
# argument to its destination.
ADD = circuit.ADD
MUL = circuit.MUL
COPY: int = max(ADD, MUL) + 1

# VM arrays: values of `ElementID.array`, also used as indexes into the
# `arrays` table in cvm_float64.
VARS: int = 0
TMPS: int = 1
CONSTS: int = 2
RESULT: int = 3
75
+
76
+
77
def _make_instructions_from_analysis(
        analysis: CircuitAnalysis,
        dtype: DTypeNumeric,
) -> Tuple[Instructions, NDArrayNumeric]:
    """
    Compile an analysed circuit into VM instructions and a const array.

    Every relevant node is mapped to an `ElementID`, which locates its value
    in one of the VM arrays (VARS, TMPS, CONSTS, RESULT). Each op node then
    becomes one instruction, and each result node that is not an op node
    becomes a COPY instruction.

    Args:
        analysis: the circuit analysis to compile.
        dtype: the numeric dtype; only numpy float64 is supported.

    Returns:
        (instructions, np_consts), where np_consts[i] holds the value of
        analysis.const_nodes[i].

    Raises:
        RuntimeError: if dtype is not float64.
    """
    if dtype != np.float64:
        raise RuntimeError(f'only DType {np.float64} currently supported')

    # Store const values in a numpy array
    # NOTE(review): assumes analysis.const_nodes has no repeated node. A repeat
    # would make np_consts shorter than the fill loop below expects.
    node_to_const_idx: Dict[int, int] = {
        id(node): i
        for i, node in enumerate(analysis.const_nodes)
    }
    np_consts: NDArrayNumeric = np.zeros(len(node_to_const_idx), dtype=dtype)
    for i, node in enumerate(analysis.const_nodes):
        np_consts[i] = node.value

    # Where to get input values for each possible node, keyed by id(node).
    # The ElementID structs are coerced to Python objects while stored here.
    # Instructions.append converts them back to C structs.
    node_to_element: Dict[int, ElementID] = {}
    # const nodes
    for node_id, const_idx in node_to_const_idx.items():
        node_to_element[node_id] = ElementID(CONSTS, const_idx)
    # var nodes: a var bound to a constant reads that constant's slot instead
    # of its input slot (assumes the bound const is in analysis.const_nodes).
    for i, var_node in enumerate(analysis.var_nodes):
        if var_node.is_const():
            node_to_element[id(var_node)] = node_to_element[id(var_node.const)]
        else:
            node_to_element[id(var_node)] = ElementID(VARS, i)
    # op nodes: tmps first, then results, so an op node present in both maps
    # ends up being written to the RESULT array.
    for node_id, tmp_idx in analysis.op_to_tmp.items():
        node_to_element[node_id] = ElementID(TMPS, tmp_idx)
    for node_id, result_idx in analysis.op_to_result.items():
        node_to_element[node_id] = ElementID(RESULT, result_idx)

    # Build instructions
    instructions: Instructions = Instructions()

    op_node: OpNode
    for op_node in analysis.op_nodes:
        # The VM executes instructions in this order.
        # NOTE(review): presumably analysis.op_nodes lists arguments before
        # their users - confirm in analyze_circuit.
        dest: ElementID = node_to_element[id(op_node)]
        args: list[ElementID] = [
            node_to_element[id(arg)]
            for arg in op_node.args
        ]
        instructions.append(op_node.symbol, args, dest)

    # Add any copy operations, i.e., result nodes that are not op nodes
    for i, node in enumerate(analysis.result_nodes):
        if not isinstance(node, OpNode):
            dest: ElementID = ElementID(RESULT, i)
            args: list[ElementID] = [node_to_element[id(node)]]
            instructions.append(COPY, args, dest)

    return instructions, np_consts
130
+
131
+
132
cdef struct ElementID:
    # Locates one value in VM memory: an array selector plus an index into that array.
    int array  # VARS, TMPS, CONSTS, RESULT
    int index  # index into the array
135
+
136
+
137
cdef struct Instruction:
    # One VM operation: combine the `num_args` elements at `args` according
    # to `symbol` and write the value to `dest`.
    int symbol  # ADD, MUL, COPY
    int num_args  # number of entries in `args`
    ElementID* args  # C array, freed by the owning Instructions object
    ElementID dest  # where the computed value is written
142
+
143
+
144
cdef class Instructions:
    """
    An append-only sequence of VM `Instruction` structs held in C memory.

    Each instruction owns a separately allocated C array of its argument
    `ElementID`s. All memory comes from PyMem_Malloc/PyMem_Realloc and is
    released in `__dealloc__`.
    """
    cdef Instruction* instructions
    cdef int num_instructions

    def __cinit__(self):
        # __cinit__ runs exactly once, before __dealloc__ can run, so these
        # fields are always valid, and re-running __init__ cannot leak. The
        # array is allocated lazily: PyMem_Realloc(NULL, n) acts as PyMem_Malloc(n).
        self.instructions = NULL
        self.num_instructions = 0

    def append(self, int symbol, list[ElementID] args, ElementID dest) -> None:
        """
        Append one instruction.

        Args:
            symbol: the op code (ADD, MUL or COPY).
            args: the argument elements; they are copied into C memory.
            dest: where the instruction's value is written.

        Raises:
            MemoryError: if memory cannot be allocated. On any error the
                existing instructions are left intact and nothing is leaked.
        """
        cdef int num_args = len(args)
        cdef int i
        cdef ElementID* c_args
        cdef Instruction* grown

        # PyMem_Malloc(0) returns a non-NULL pointer, so zero-argument
        # instructions do not trigger the MemoryError below.
        c_args = <ElementID*> PyMem_Malloc(num_args * sizeof(ElementID))
        if not c_args:
            raise MemoryError()
        try:
            for i in range(num_args):
                c_args[i] = args[i]
        except BaseException:
            # e.g. an element that cannot be converted to an ElementID
            PyMem_Free(c_args)
            raise

        # Realloc into a temporary. On failure PyMem_Realloc returns NULL but
        # leaves the old block valid. Assigning NULL to self.instructions would
        # leak that block and every instruction's argument array.
        grown = <Instruction*> PyMem_Realloc(
            self.instructions,
            sizeof(Instruction) * (self.num_instructions + 1)
        )
        if not grown:
            PyMem_Free(c_args)
            raise MemoryError()
        self.instructions = grown

        self.instructions[self.num_instructions] = Instruction(
            symbol,
            num_args,
            c_args,
            dest
        )
        self.num_instructions += 1

    def __dealloc__(self):
        cdef int i
        if self.instructions:
            for i in range(self.num_instructions):
                PyMem_Free(self.instructions[i].args)
            PyMem_Free(self.instructions)
            self.instructions = NULL
            self.num_instructions = 0
186
+
187
+
188
+
189
@cython.boundscheck(False)  # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cdef void cvm_float64(
        double* vars_in,
        double* tmps,
        double* consts,
        double* result,
        Instructions instructions,
):
    # Core virtual machine.
    #
    # Executes `instructions` in order. Each instruction reads its argument
    # elements from the four arrays and combines them according to its
    # symbol: ADD sums them, MUL multiplies them, COPY takes the single
    # argument. The value is written to the destination element.
    #
    # Raises:
    #     RuntimeError if an instruction's symbol is not understood, or a
    #     COPY instruction has no argument.

    cdef int i, k, num_args, symbol
    cdef double accumulator
    cdef ElementID* args
    cdef ElementID elem

    # Op codes converted to C ints once. The per-instruction dispatch below
    # then compares C ints instead of Python objects.
    cdef int add_symbol = ADD
    cdef int mul_symbol = MUL
    cdef int copy_symbol = COPY

    # index the four arrays by constants VARS, TMPS, CONSTS, and RESULT
    cdef (double*) arrays[4]
    arrays[VARS] = vars_in
    arrays[TMPS] = tmps
    arrays[CONSTS] = consts
    arrays[RESULT] = result

    cdef Instruction* instruction_ptr = instructions.instructions
    for k in range(instructions.num_instructions):

        symbol = instruction_ptr[0].symbol
        args = instruction_ptr[0].args
        num_args = instruction_ptr[0].num_args

        if num_args == 0:
            # Op nodes may have zero arguments. args[0] is then outside the
            # argument array and must not be read; use the operation's identity.
            if symbol == add_symbol:
                accumulator = 0.0  # empty sum
            elif symbol == mul_symbol:
                accumulator = 1.0  # empty product
            elif symbol == copy_symbol:
                raise RuntimeError('copy instruction has no argument')
            else:
                raise RuntimeError('symbol not understood: ' + str(symbol))
        else:
            elem = args[0]
            accumulator = arrays[elem.array][elem.index]

            if symbol == add_symbol:
                for i in range(1, num_args):
                    elem = args[i]
                    accumulator += arrays[elem.array][elem.index]
            elif symbol == mul_symbol:
                for i in range(1, num_args):
                    elem = args[i]
                    accumulator *= arrays[elem.array][elem.index]
            elif symbol == copy_symbol:
                pass
            else:
                raise RuntimeError('symbol not understood: ' + str(symbol))

        elem = instruction_ptr[0].dest
        arrays[elem.array][elem.index] = accumulator

        # Advance the instruction pointer
        instruction_ptr += 1
@@ -0,0 +1,325 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Sequence, Tuple, Iterable, Iterator
4
+
5
+ from ck.circuit import CircuitNode, Circuit, OpNode, MUL
6
+
7
# A row key of a circuit table: one int per random variable, positionally
# aligned with the table's `rv_idxs`.
TableInstance = Tuple[int, ...]
8
+
9
+
10
cdef class CircuitTable:
    """
    A circuit table manages a set of CircuitNodes, where each node corresponds
    to an instance for a set of (zero or more) random variables.

    Operations on circuit tables typically add circuit nodes to the circuit. They
    heuristically avoid adding unnecessary nodes (e.g. addition of zero, multiplication
    by zero or one). However, interim circuit nodes may be created that
    end up not being used. Consider calling `Circuit.remove_unreachable_op_nodes` after
    completing all circuit table operations.

    It is generally expected that no CircuitTable row will be created with a constant
    zero node. These are assumed to be optimised out already.
    """

    cdef public object circuit
    cdef public tuple[int, ...] rv_idxs
    cdef public dict[tuple[int, ...], CircuitNode] rows

    def __init__(
            self,
            circuit: Circuit,
            rv_idxs: Sequence[int],
            rows: Iterable[Tuple[TableInstance, CircuitNode]] = (),
    ):
        """
        Args:
            circuit: the circuit whose nodes are being managed by this table.
            rv_idxs: indexes of random variables.
            rows: optional rows to add to the table, as (instance, node) pairs.

        Assumes:
            * rv_idxs contains no duplicates.
            * all row instances conform to the indexed random variables.
            * all row circuit nodes belong to the given circuit.
        """
        self.circuit = circuit
        self.rv_idxs = tuple(rv_idxs)
        self.rows = dict(rows)

    def __len__(self) -> int:
        # The number of rows.
        return len(self.rows)

    def get(self, key, default=None):
        # The node for row instance `key`, or `default` if there is no such row.
        return self.rows.get(key, default)

    def __getitem__(self, key):
        return self.rows[key]

    def __setitem__(self, key, value):
        self.rows[key] = value

    cpdef object top(self):  # -> CircuitNode:
        # Get the circuit top value.
        #
        # An empty table gives `circuit.zero`; a single-row table gives the
        # node of its only row.
        #
        # Raises:
        #     RuntimeError if there is more than one row in the table.
        #
        # Returns:
        #     A single circuit node.
        cdef int number_of_rows = len(self.rows)
        if number_of_rows == 0:
            return self.circuit.zero
        elif number_of_rows == 1:
            return next(iter(self.rows.values()))
        else:
            raise RuntimeError('cannot get top node from a table with more than 1 row')
77
+
78
+
79
+ # ==================================================================================
80
+ # Circuit Table Operations
81
+ # ==================================================================================
82
+
83
cpdef object sum_out(object table: CircuitTable, object rv_idxs: Iterable[int]):  # -> CircuitTable:
    # Sum out the given random variables of a circuit table.
    #
    # Rows whose instances agree on every remaining random variable are
    # merged into one row whose node is the sum of their nodes. A merged row
    # whose sum is zero is dropped. If rv_idxs is empty, `table` itself is
    # returned unmodified; otherwise the result is a new table.
    #
    # Raises:
    #     ValueError if rv_idxs contains duplicates.
    #     ValueError if rv_idxs is not a subset of table.rv_idxs.
    cdef tuple[int, ...] rv_idxs_seq = tuple(rv_idxs)
    if not rv_idxs_seq:
        # Nothing to sum out.
        return table

    cdef set[int] rv_idxs_set = set(rv_idxs_seq)
    if len(rv_idxs_set) != len(rv_idxs_seq):
        raise ValueError('rv_idxs contains duplicates')
    if not rv_idxs_set.issubset(table.rv_idxs):
        raise ValueError('rv_idxs is not a subset of table.rv_idxs')

    # The surviving random variables, kept in table order.
    cdef int rv_index
    cdef list[int] remaining_rv_idxs = []
    for rv_index in table.rv_idxs:
        if rv_index not in rv_idxs_set:
            remaining_rv_idxs.append(rv_index)

    if not remaining_rv_idxs:
        # Every random variable is summed out.
        return sum_out_all(table)

    # Position, within a row instance, of each surviving random variable.
    cdef list[int] index_map = [_find(table.rv_idxs, rv) for rv in remaining_rv_idxs]

    # Bucket row nodes by their instance restricted to the surviving variables.
    cdef dict[tuple[int, ...], list[object]] groups = {}
    cdef tuple[int, ...] instance
    cdef object node
    for instance, node in table.rows.items():
        groups.setdefault(tuple([instance[pos] for pos in index_map]), []).append(node)

    # One output row per bucket, holding the sum of the bucket's nodes.
    cdef object circuit = table.circuit
    cdef object new_table = CircuitTable(circuit, remaining_rv_idxs)
    cdef dict[tuple[int, ...], object] rows = new_table.rows
    cdef tuple[int, ...] key
    for key, to_add in groups.items():
        node = circuit.optimised_add(to_add)
        if not node.is_zero():
            rows[key] = node

    return new_table
149
+
150
+
151
cpdef object sum_out_all(object table: CircuitTable):  # -> CircuitTable:
    # Sum out every random variable of a circuit table.
    #
    # The result is a new table with no random variables. It has one row,
    # keyed by (), holding the sum of all row nodes. It has no rows when the
    # input is empty or the sum is zero.
    circuit: Circuit = table.circuit
    num_rows: int = len(table)
    if num_rows == 0:
        return CircuitTable(circuit, ())

    node: CircuitNode
    if num_rows == 1:
        # A single row needs no addition.
        node = next(iter(table.rows.values()))
    else:
        node = circuit.optimised_add(table.rows.values())

    if node.is_zero():
        return CircuitTable(circuit, ())
    return CircuitTable(circuit, (), [((), node)])
166
+
167
+
168
cpdef object project(object table: CircuitTable, object rv_idxs: Iterable[int]):  # -> CircuitTable:
    # Keep only the given random variables, by summing out every other
    # random variable of the table: `sum_out(table, table.rv_idxs - rv_idxs)`.
    cdef set[int] to_sum_out = set(table.rv_idxs).difference(rv_idxs)
    return sum_out(table, to_sum_out)
174
+
175
+
176
cpdef object product(x: CircuitTable, y: CircuitTable):  # -> CircuitTable:
    # Multiply two circuit tables.
    #
    # Each result row is the product of an x row and a y row that agree on
    # the shared random variables. The result's random variables are those
    # of the table with more rows, followed by those found only in the other.
    #
    # If the table with fewer rows is the unit table (no random variables,
    # one row whose node is one), the other table is returned as is.
    # Otherwise a new table is built.
    #
    # Raises:
    #     ValueError if x and y do not refer to the same circuit.
    cdef object circuit = x.circuit
    if y.circuit is not circuit:
        raise ValueError('circuit tables must refer to the same circuit')

    # Let y be the table with fewer rows, since y is the one indexed below.
    if len(y) > len(x):
        x, y = y, x

    # A y without random variables is a constant table: unit or empty.
    if len(y.rv_idxs) == 0:
        if len(y) == 1 and y.top().is_one():
            return x
        elif len(y) == 0:
            return CircuitTable(circuit, x.rv_idxs)

    # Partition the random variables: co = common to x and y, yo = in y only.
    cdef set[int] yo_rv_idxs_set = set(y.rv_idxs)
    cdef set[int] co_rv_idxs_set = set(x.rv_idxs)
    co_rv_idxs_set.intersection_update(yo_rv_idxs_set)
    yo_rv_idxs_set.difference_update(co_rv_idxs_set)

    if len(co_rv_idxs_set) == 0:
        # Nothing shared, so no index is needed.
        return _product_no_common_rvs(x, y)

    # Fix an order for each partition; this also fixes the result's columns.
    cdef tuple[int, ...] yo_rv_idxs = tuple(yo_rv_idxs_set)
    cdef tuple[int, ...] co_rv_idxs = tuple(co_rv_idxs_set)

    # Positions of the co / yo random variables within x and y instances.
    cdef list[int] co_from_x_map = [_find(x.rv_idxs, rv) for rv in co_rv_idxs]
    cdef list[int] co_from_y_map = [_find(y.rv_idxs, rv) for rv in co_rv_idxs]
    cdef list[int] yo_from_y_map = [_find(y.rv_idxs, rv) for rv in yo_rv_idxs]

    cdef object got
    cdef tuple[int, ...] co_tuple
    cdef tuple[int, ...] yo_tuple

    cdef object table = CircuitTable(circuit, x.rv_idxs + yo_rv_idxs)
    cdef dict[tuple[int, ...], object] rows = table.rows

    # Index y's rows by their values on the common random variables.
    cdef dict[tuple[int, ...], list[tuple[tuple[int, ...], object]]] y_index = {}
    for y_instance, y_node in y.rows.items():
        co_tuple = tuple([y_instance[pos] for pos in co_from_y_map])
        yo_tuple = tuple([y_instance[pos] for pos in yo_from_y_map])
        y_index.setdefault(co_tuple, []).append((yo_tuple, y_node))

    # Join each x row with the y rows that agree with it on the common variables.
    for x_instance, x_node in x.rows.items():
        co_tuple = tuple([x_instance[pos] for pos in co_from_x_map])
        x_is_one = x_node.is_one()
        got = y_index.get(co_tuple)
        if got is None:
            continue
        for yo_tuple, y_node in got:
            # Multiplying by one needs no new circuit node.
            rows[x_instance + yo_tuple] = y_node if x_is_one else _optimised_mul(circuit, x_node, y_node)

    return table
277
+
278
+
279
cdef int _find(tuple[int, ...] xs, int x):
    # Return the position of the first occurrence of `x` in `xs`.
    #
    # Raises:
    #     RuntimeError if `x` is absent. Callers only search for random
    #     variables taken from the same table, so this is unexpected.
    cdef Py_ssize_t n = len(xs)
    cdef int pos = 0
    while pos < n:
        if xs[pos] == x:
            return pos
        pos += 1
    raise RuntimeError('not found')
286
+
287
+
288
cdef object _product_no_common_rvs(x: CircuitTable, y: CircuitTable):  # -> CircuitTable:
    # Product of two circuit tables that share no random variables.
    #
    # Every x row is paired with every y row, so no index on common random
    # variables needs to be constructed. When an x row's node is one, the y
    # node is used directly instead of building a product node.
    #
    # Assumes:
    #     * There are no common random variables between x and y.
    #     * x and y are for the same circuit.
    cdef object circuit = x.circuit
    cdef object table = CircuitTable(circuit, x.rv_idxs + y.rv_idxs)
    cdef dict[tuple[int, ...], object] rows = table.rows
    cdef tuple[int, ...] instance

    for x_instance, x_node in x.rows.items():
        x_is_one = x_node.is_one()
        for y_instance, y_node in y.rows.items():
            instance = x_instance + y_instance
            rows[instance] = y_node if x_is_one else _optimised_mul(circuit, x_node, y_node)

    return table
314
+
315
+
316
cdef object _optimised_mul(object circuit: Circuit, object x: CircuitNode, object y: CircuitNode):  # -> CircuitNode
    # Return a node for x * y, avoiding a new mul node where possible:
    #   * a zero operand is returned as the product (x is checked before y);
    #   * otherwise, if either operand is one, the other operand is returned;
    #   * otherwise a new mul node is added to the circuit.
    if x.is_zero():
        return x
    elif y.is_zero():
        return y
    elif x.is_one():
        return y
    elif y.is_one():
        return x
    else:
        return circuit.mul(x, y)