compiled-knowledge 4.0.0a16__cp312-cp312-win_amd64.whl → 4.0.0a18__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (36) hide show
  1. ck/circuit/__init__.py +9 -2
  2. ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
  3. ck/circuit/_circuit_cy.pxd +33 -0
  4. ck/circuit/{circuit.pyx → _circuit_cy.pyx} +115 -133
  5. ck/circuit/{circuit_py.py → _circuit_py.py} +16 -8
  6. ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
  7. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +88 -60
  8. ck/circuit_compiler/named_circuit_compilers.py +1 -1
  9. ck/pgm_compiler/factor_elimination.py +23 -13
  10. ck/pgm_compiler/support/circuit_table/__init__.py +9 -2
  11. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
  12. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  13. ck/pgm_compiler/support/circuit_table/_circuit_table_cy_cpp_verion.pyx +601 -0
  14. ck/pgm_compiler/support/circuit_table/_circuit_table_cy_minimal_version.pyx +311 -0
  15. ck/pgm_compiler/support/circuit_table/{circuit_table.pyx → _circuit_table_cy_v4.0.0a17.pyx} +9 -9
  16. ck/pgm_compiler/support/circuit_table/{circuit_table_py.py → _circuit_table_py.py} +80 -45
  17. ck/pgm_compiler/support/clusters.py +16 -4
  18. ck/pgm_compiler/support/factor_tables.py +1 -1
  19. ck/pgm_compiler/support/join_tree.py +67 -10
  20. ck/pgm_compiler/support/named_compiler_maker.py +12 -2
  21. ck/pgm_compiler/variable_elimination.py +2 -0
  22. ck/utils/iter_extras.py +8 -1
  23. ck_demos/pgm_compiler/demo_compiler_dump.py +10 -0
  24. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  25. ck_demos/utils/compare.py +5 -1
  26. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/METADATA +1 -1
  27. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/RECORD +30 -29
  28. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/WHEEL +1 -1
  29. ck/circuit/circuit.c +0 -38861
  30. ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
  31. ck/circuit/circuit_node.pyx +0 -138
  32. ck/circuit_compiler/cython_vm_compiler/_compiler.c +0 -17373
  33. ck/pgm_compiler/support/circuit_table/circuit_table.c +0 -16042
  34. ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
  35. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/licenses/LICENSE.txt +0 -0
  36. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
- from pickletools import long1
4
- from typing import Sequence, Dict, List, Tuple, Set, Optional, Iterator
3
+ from typing import Dict, Tuple, Sequence
5
4
 
6
5
  import numpy as np
7
6
  import ctypes as ct
8
7
 
9
8
  from ck import circuit
10
- from ck.circuit import CircuitNode, ConstNode, VarNode, OpNode, ADD, Circuit
9
+ from ck.circuit import OpNode, VarNode, CircuitNode
11
10
  from ck.circuit_compiler.support.circuit_analyser import CircuitAnalysis, analyze_circuit
12
- from ck.circuit_compiler.support.input_vars import infer_input_vars, InputVars
13
- from ck.program.raw_program import RawProgram, RawProgramFunction
14
- from ck.utils.np_extras import DType, NDArrayNumeric, NDArray, DTypeNumeric
11
+ from ck.program.raw_program import RawProgramFunction
12
+ from ck.utils.np_extras import NDArrayNumeric, DTypeNumeric
15
13
 
16
14
  from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
17
15
 
@@ -63,21 +61,21 @@ def make_function(
63
61
 
64
62
 
65
63
  # VM instructions
66
- ADD = circuit.ADD
67
- MUL = circuit.MUL
68
- COPY: int = max(ADD, MUL) + 1
64
+ cdef int ADD = circuit.ADD
65
+ cdef int MUL = circuit.MUL
66
+ cdef int COPY = max(ADD, MUL) + 1
69
67
 
70
68
  # VM arrays
71
- VARS: int = 0
72
- TMPS: int = 1
73
- CONSTS: int = 2
74
- RESULT: int = 3
69
+ cdef int VARS = 0
70
+ cdef int TMPS = 1
71
+ cdef int CONSTS = 2
72
+ cdef int RESULT = 3
75
73
 
76
74
 
77
- def _make_instructions_from_analysis(
78
- analysis: CircuitAnalysis,
79
- dtype: DTypeNumeric,
80
- ) -> Tuple[Instructions, NDArrayNumeric]:
75
+ cdef tuple[Instructions, cnp.ndarray] _make_instructions_from_analysis(
76
+ object analysis: CircuitAnalysis,
77
+ object dtype: DTypeNumeric,
78
+ ): # -> Tuple[Instructions, NDArrayNumeric]:
81
79
  if dtype != np.float64:
82
80
  raise RuntimeError(f'only DType {np.float64} currently supported')
83
81
 
@@ -91,7 +89,7 @@ def _make_instructions_from_analysis(
91
89
  np_consts[i] = node.value
92
90
 
93
91
  # Where to get input values for each possible node.
94
- node_to_element: Dict[int, ElementID] = {}
92
+ cdef dict[int, ElementID] node_to_element = {}
95
93
  # const nodes
96
94
  for node_id, const_idx in node_to_const_idx.items():
97
95
  node_to_element[node_id] = ElementID(CONSTS, const_idx)
@@ -110,21 +108,16 @@ def _make_instructions_from_analysis(
110
108
  # Build instructions
111
109
  instructions: Instructions = Instructions()
112
110
 
113
- op_node: OpNode
111
+ cdef object op_node
114
112
  for op_node in analysis.op_nodes:
115
- dest: ElementID = node_to_element[id(op_node)]
116
- args: list[ElementID] = [
117
- node_to_element[id(arg)]
118
- for arg in op_node.args
119
- ]
120
- instructions.append(op_node.symbol, args, dest)
113
+ instructions.append_op(op_node.symbol, op_node, node_to_element)
121
114
 
122
115
  # Add any copy operations, i.e., result nodes that are not op nodes
123
116
  for i, node in enumerate(analysis.result_nodes):
124
117
  if not isinstance(node, OpNode):
125
118
  dest: ElementID = ElementID(RESULT, i)
126
- args: list[ElementID] = [node_to_element[id(node)]]
127
- instructions.append(COPY, args, dest)
119
+ src: ElementID = node_to_element[id(node)]
120
+ instructions.append_copy(src, dest)
128
121
 
129
122
  return instructions, np_consts
130
123
 
@@ -136,39 +129,70 @@ cdef struct ElementID:
136
129
 
137
130
  cdef struct Instruction:
138
131
  int symbol # ADD, MUL, COPY
139
- int num_args
132
+ Py_ssize_t num_args
140
133
  ElementID* args
141
134
  ElementID dest
142
135
 
143
136
 
144
137
  cdef class Instructions:
145
138
  cdef Instruction* instructions
139
+ cdef int allocated
146
140
  cdef int num_instructions
147
141
 
148
- def __init__(self):
149
- self.instructions = <Instruction*> PyMem_Malloc(0)
142
+ def __init__(self) -> None:
150
143
  self.num_instructions = 0
144
+ self.allocated = 64
145
+ self.instructions = <Instruction*> PyMem_Malloc(self.allocated * sizeof(Instruction))
146
+
147
+ cdef void append_copy(
148
+ self,
149
+ ElementID src,
150
+ ElementID dest,
151
+ ) except*:
152
+ c_args = <ElementID*> PyMem_Malloc(sizeof(ElementID))
153
+ if not c_args:
154
+ raise MemoryError()
151
155
 
152
- def append(self, int symbol, list[ElementID] args, ElementID dest) -> None:
153
- cdef int num_args = len(args)
154
- cdef int i
156
+ c_args[0] = src
157
+ self._append(COPY, 1, c_args, dest)
158
+
159
+ cdef void append_op(self, int symbol, object op_node: OpNode, dict[int, ElementID] node_to_element) except*:
160
+ args = op_node.args
161
+ cdef Py_ssize_t num_args = len(args)
155
162
 
156
- c_args = <ElementID*> PyMem_Malloc(
157
- num_args * sizeof(ElementID))
163
+ # Create the instruction arguments array
164
+ c_args = <ElementID*> PyMem_Malloc(num_args * sizeof(ElementID))
158
165
  if not c_args:
159
166
  raise MemoryError()
160
167
 
161
- for i in range(num_args):
162
- c_args[i] = args[i]
168
+ cdef Py_ssize_t i = num_args
169
+ while i > 0:
170
+ i -= 1
171
+ c_args[i] = node_to_element[id(args[i])]
172
+
173
+ dest: ElementID = node_to_element[id(op_node)]
174
+
175
+ self._append(symbol, num_args, c_args, dest)
176
+
177
+
178
+ cdef void _append(self, int symbol, Py_ssize_t num_args, ElementID* c_args, ElementID dest) except *:
179
+ cdef int i
163
180
 
164
181
  cdef int num_instructions = self.num_instructions
165
- self.instructions = <Instruction*> PyMem_Realloc(
166
- self.instructions,
167
- sizeof(Instruction) * (num_instructions + 1)
168
- )
169
- if not self.instructions:
170
- raise MemoryError()
171
182
 
183
+ # Ensure sufficient instruction memory
184
+ cdef int allocated = self.allocated
185
+ if num_instructions == allocated:
186
+ allocated *= 2
187
+ self.instructions = <Instruction*> PyMem_Realloc(
188
+ self.instructions,
189
+ allocated * sizeof(Instruction),
190
+ )
191
+ if not self.instructions:
192
+ raise MemoryError()
193
+ self.allocated = allocated
194
+
195
+ # Add the instruction
172
196
  self.instructions[num_instructions] = Instruction(
173
197
  symbol,
174
198
  num_args,
@@ -177,7 +201,7 @@ cdef class Instructions:
177
201
  )
178
202
  self.num_instructions = num_instructions + 1
179
203
 
180
- def __dealloc__(self):
204
+ def __dealloc__(self) -> None:
181
205
  cdef Instruction* instructions = self.instructions
182
206
  if instructions:
183
207
  for i in range(self.num_instructions):
@@ -194,15 +218,16 @@ cdef void cvm_float64(
194
218
  double* consts,
195
219
  double* result,
196
220
  Instructions instructions,
197
- ):
198
- # Core virtual machine.
221
+ ) except *:
222
+ # Core virtual machine (for dtype float64).
199
223
 
200
- cdef int i, num_args, symbol
224
+ cdef int symbol
225
+ cdef Py_ssize_t i
201
226
  cdef double accumulator
202
227
  cdef ElementID* args
203
- cdef ElementID elem
228
+ cdef ElementID elem
204
229
 
205
- # index the four arrays by constants VARS, TMPS, CONSTS, and RESULT
230
+ # Index the four arrays by constants VARS, TMPS, CONSTS, and RESULT
206
231
  cdef (double*) arrays[4]
207
232
  arrays[VARS] = vars_in
208
233
  arrays[TMPS] = tmps
@@ -210,29 +235,32 @@ cdef void cvm_float64(
210
235
  arrays[RESULT] = result
211
236
 
212
237
  cdef Instruction* instruction_ptr = instructions.instructions
213
- for _ in range(instructions.num_instructions):
238
+ cdef int num_instructions = instructions.num_instructions
214
239
 
215
- symbol = instruction_ptr[0].symbol
216
- args = instruction_ptr[0].args
217
- num_args = instruction_ptr[0].num_args
240
+ while num_instructions > 0:
241
+ num_instructions -= 1
242
+
243
+ symbol = instruction_ptr.symbol
244
+ args = instruction_ptr.args
218
245
 
219
246
  elem = args[0]
220
247
  accumulator = arrays[elem.array][elem.index]
221
248
 
222
249
  if symbol == ADD:
223
- for i in range(1, num_args):
250
+ i = instruction_ptr.num_args
251
+ while i > 1:
252
+ i -= 1
224
253
  elem = args[i]
225
254
  accumulator += arrays[elem.array][elem.index]
226
255
  elif symbol == MUL:
227
- for i in range(1, num_args):
256
+ i = instruction_ptr.num_args
257
+ while i > 1:
258
+ i -= 1
228
259
  elem = args[i]
229
260
  accumulator *= arrays[elem.array][elem.index]
230
- elif symbol == COPY:
231
- pass
232
- else:
233
- raise RuntimeError('symbol not understood: ' + str(symbol))
261
+ # else symbol == COPY, nothing to do
234
262
 
235
- elem = instruction_ptr[0].dest
263
+ elem = instruction_ptr.dest
236
264
  arrays[elem.array][elem.index] = accumulator
237
265
 
238
266
  # Advance the instruction pointer
@@ -54,4 +54,4 @@ class NamedCircuitCompiler(Enum):
54
54
  return self.value[0]
55
55
 
56
56
 
57
- DEFAULT_CIRCUIT_COMPILER: NamedCircuitCompiler = NamedCircuitCompiler.LLVM_VM
57
+ DEFAULT_CIRCUIT_COMPILER: NamedCircuitCompiler = NamedCircuitCompiler.CYTHON_VM
@@ -149,7 +149,7 @@ def join_tree_to_circuit(
149
149
  limit_product_tree_search,
150
150
  )
151
151
  top: CircuitNode = top_table.top()
152
- top_table.circuit.remove_unreachable_op_nodes(top)
152
+ top.circuit.remove_unreachable_op_nodes(top)
153
153
 
154
154
  return PGMCircuit(
155
155
  rvs=tuple(pgm.rvs),
@@ -169,27 +169,37 @@ def _circuit_tables_from_join_tree(
169
169
  ) -> CircuitTable:
170
170
  """
171
171
  This is a basic algorithm for constructing a circuit table from a join tree.
172
+ Algorithm synopsis:
173
+ 1) Get a CircuitTable for each factor allocated to this join tree node, and
174
+ for each child of the join tree node (recursive call to _circuit_tables_from_join_tree).
175
+ 2) Form a binary tree of the collected circuit tables.
176
+ 3) Perform table products and sum-outs for each node in the binary tree, which should
177
+ leave a single circuit table with a single row.
172
178
  """
173
- # The PGM factors allocated to this join tree node
174
- factors: List[CircuitTable] = [
175
- factor_tables.get_table(factor)
176
- for factor in join_tree.factors
177
- ]
178
-
179
- # The children of this join tree node
180
- factors.extend(
181
- _circuit_tables_from_join_tree(factor_tables, child, limit_product_tree_search)
182
- for child in join_tree.children
179
+ # Get all the factors to combine.
180
+ factors: List[CircuitTable] = list(
181
+ chain(
182
+ (
183
+ # The PGM factors allocated to this join tree node
184
+ factor_tables.get_table(factor)
185
+ for factor in join_tree.factors
186
+ ),
187
+ (
188
+ # The children of this join tree node
189
+ _circuit_tables_from_join_tree(factor_tables, child, limit_product_tree_search)
190
+ for child in join_tree.children
191
+ ),
192
+ )
183
193
  )
184
194
 
185
195
  # The usual join tree approach just forms the product of all the tables in `factors`.
186
196
  # The tree width is not affected by the order of products, however some orders
187
197
  # lead to smaller numbers of arithmetic operations.
188
198
  #
189
- # If `options.optimise_products` is true, then heuristics are used
199
+ # If `limit_product_tree_search > 1`, then heuristics are used
190
200
  # reduce the number of arithmetic operations.
191
201
 
192
- # Deal with the special case: no factors
202
+ # Deal with the special case: zero factors
193
203
  if len(factors) == 0:
194
204
  circuit = factor_tables.circuit
195
205
  if len(join_tree.separator) == 0:
@@ -1,5 +1,12 @@
1
- # from .circuit_table_py import (
2
- from .circuit_table import (
1
+ # Two implementations of the `circuit_table` module are provided
2
+ # for developer R&D purposes. One is pure Python and the other is Cython.
3
+ # Which implementation is used can be selected here.
4
+ # A similar selection can be made for the `circuit` module.
5
+ # Note that if the Cython implementation is chosen for `circuit_table` then
6
+ # the Cython implementation must be chosen for `circuit`.
7
+
8
+ # from ._circuit_table_py import (
9
+ from ._circuit_table_cy import (
3
10
  CircuitTable,
4
11
  TableInstance,
5
12
  sum_out,
@@ -0,0 +1,332 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Sequence, Tuple, Iterable
4
+
5
+ from ck.circuit import MUL, ADD
6
+
7
+ from ck.circuit._circuit_cy cimport Circuit, CircuitNode
8
+
9
+ cdef int c_ADD = ADD
10
+ cdef int c_MUL = MUL
11
+
12
+
13
+ TableInstance = Tuple[int, ...]
14
+
15
+
16
+ cdef class CircuitTable:
17
+ """
18
+ A circuit table manages a set of CircuitNodes, where each node corresponds
19
+ to an instance for a set of (zero or more) random variables.
20
+
21
+ Operations on circuit tables typically add circuit nodes to the circuit. It will
22
+ heuristically avoid adding unnecessary nodes (e.g. addition of zero, multiplication
23
+ by zero or one.) However, it may be that interim circuit nodes are created that
24
+ end up not being used. Consider calling `Circuit.remove_unreachable_op_nodes` after
25
+ completing all circuit table operations.
26
+
27
+ It is generally expected that no CircuitTable row will be created with a constant
28
+ zero node. These are assumed to be optimised out already.
29
+ """
30
+
31
+ cdef public Circuit circuit
32
+ cdef public tuple[int, ...] rv_idxs
33
+ cdef dict[tuple[int, ...], CircuitNode] rows
34
+
35
+ def __init__(
36
+ self,
37
+ circuit: Circuit,
38
+ rv_idxs: Sequence[int],
39
+ rows: Iterable[Tuple[TableInstance, CircuitNode]] = (),
40
+ ):
41
+ """
42
+ Args:
43
+ circuit: the circuit whose nodes are being managed by this table.
44
+ rv_idxs: indexes of random variables.
45
+ rows: optional rows to add to the table.
46
+
47
+ Assumes:
48
+ * rv_idxs contains no duplicates.
49
+ * all row instances conform to the indexed random variables.
50
+ * all row circuit nodes belong to the given circuit.
51
+ """
52
+ self.circuit = circuit
53
+ self.rv_idxs = tuple(rv_idxs)
54
+ self.rows = dict(rows)
55
+
56
+ def __len__(self) -> int:
57
+ return len(self.rows)
58
+
59
+ def get(self, key, default=None):
60
+ return self.rows.get(key, default)
61
+
62
+ def keys(self) -> Iterable[CircuitNode]:
63
+ return self.rows.keys()
64
+
65
+ def values(self) -> Iterable[tuple[int, ...]]:
66
+ return self.rows.values()
67
+
68
+ def __getitem__(self, key):
69
+ return self.rows[key]
70
+
71
+ def __setitem__(self, key, value):
72
+ self.rows[key] = value
73
+
74
+ cpdef CircuitNode top(self):
75
+ # Get the circuit top value.
76
+ #
77
+ # Raises:
78
+ # RuntimeError if there is more than one row in the table.
79
+ #
80
+ # Returns:
81
+ # A single circuit node.
82
+ cdef int number_of_rows = len(self.rows)
83
+ if number_of_rows == 0:
84
+ return self.circuit.zero
85
+ elif number_of_rows == 1:
86
+ return next(iter(self.rows.values()))
87
+ else:
88
+ raise RuntimeError('cannot get top node from a table with more that 1 row')
89
+
90
+
91
+ # ==================================================================================
92
+ # Circuit Table Operations
93
+ # ==================================================================================
94
+
95
+ cpdef CircuitTable sum_out(CircuitTable table, object rv_idxs: Iterable[int]):
96
+ # Return a circuit table that results from summing out
97
+ # the given random variables of this circuit table.
98
+ #
99
+ # Normally this will return a new table. However, if rv_idxs is empty,
100
+ # then the given table is returned unmodified.
101
+ #
102
+ # Raises:
103
+ # ValueError if rv_idxs is not a subset of table.rv_idxs.
104
+ # ValueError if rv_idxs contains duplicates.
105
+ cdef tuple[int, ...] rv_idxs_seq = tuple(rv_idxs)
106
+
107
+ if len(rv_idxs_seq) == 0:
108
+ # nothing to do
109
+ return table
110
+
111
+ cdef set[int] rv_idxs_set = set(rv_idxs_seq)
112
+ if len(rv_idxs_set) != len(rv_idxs_seq):
113
+ raise ValueError('rv_idxs contains duplicates')
114
+ if not rv_idxs_set.issubset(table.rv_idxs):
115
+ raise ValueError('rv_idxs is not a subset of table.rv_idxs')
116
+
117
+ cdef int rv_index
118
+ cdef list[int] remaining_rv_idxs = []
119
+ for rv_index in table.rv_idxs:
120
+ if rv_index not in rv_idxs_set:
121
+ remaining_rv_idxs.append(rv_index)
122
+
123
+ cdef int num_remaining = len(remaining_rv_idxs)
124
+ if num_remaining == 0:
125
+ # Special case: summing out all random variables
126
+ return sum_out_all(table)
127
+
128
+ # index_map[i] is the location in table.rv_idxs for remaining_rv_idxs[i]
129
+ cdef list[int] index_map = []
130
+ for rv_index in remaining_rv_idxs:
131
+ index_map.append(_find(table.rv_idxs, rv_index))
132
+
133
+ cdef dict[tuple[int, ...], list[CircuitNode]] groups = {}
134
+ cdef object got
135
+ cdef list[int] group_instance
136
+ cdef tuple[int, ...] group_instance_tuple
137
+ cdef int i
138
+ cdef CircuitNode node
139
+ cdef tuple[int, ...] instance
140
+ for instance, node in table.rows.items():
141
+ group_instance = []
142
+ for i in index_map:
143
+ group_instance.append(instance[i])
144
+ group_instance_tuple = tuple(group_instance)
145
+ got = groups.get(group_instance_tuple)
146
+ if got is None:
147
+ groups[group_instance_tuple] = [node]
148
+ else:
149
+ got.append(node)
150
+
151
+ cdef Circuit circuit = table.circuit
152
+ cdef CircuitTable new_table = CircuitTable(circuit, remaining_rv_idxs)
153
+ cdef dict[tuple[int, ...], CircuitNode] rows = new_table.rows
154
+
155
+ for group_instance_tuple, to_add in groups.items():
156
+ node = circuit.op(c_ADD, tuple(to_add))
157
+ if not node.is_zero:
158
+ rows[group_instance_tuple] = node
159
+
160
+ return new_table
161
+
162
+
163
+ cpdef CircuitTable sum_out_all(CircuitTable table):
164
+ # Return a circuit table that results from summing out
165
+ # all random variables of this circuit table.
166
+ circuit: Circuit = table.circuit
167
+ num_rows: int = len(table)
168
+ if num_rows == 0:
169
+ return CircuitTable(circuit, ())
170
+ elif num_rows == 1:
171
+ node = next(iter(table.rows.values()))
172
+ else:
173
+ node: CircuitNode = circuit.op(c_ADD, tuple(table.rows.values()))
174
+ if node.is_zero:
175
+ return CircuitTable(circuit, ())
176
+
177
+ return CircuitTable(circuit, (), [((), node)])
178
+
179
+
180
+ cpdef CircuitTable project(CircuitTable table: CircuitTable, object rv_idxs: Iterable[int]):
181
+ # Call `sum_out(table, to_sum_out)`, where
182
+ # `to_sum_out = table.rv_idxs - rv_idxs`.
183
+ cdef set[int] to_sum_out = set(table.rv_idxs)
184
+ to_sum_out.difference_update(rv_idxs)
185
+ return sum_out(table, to_sum_out)
186
+
187
+
188
+ cpdef CircuitTable product(CircuitTable x, CircuitTable y):
189
+ # Return a circuit table that results from the product of the two given tables.
190
+ #
191
+ # If x or y equals `one_table`, then the other table is returned. Otherwise,
192
+ # a new circuit table will be constructed and returned.
193
+ cdef int i
194
+ cdef Circuit circuit = x.circuit
195
+ if y.circuit is not circuit:
196
+ raise ValueError('circuit tables must refer to the same circuit')
197
+
198
+ # Make the smaller table 'y', and the other 'x'.
199
+ # This is to minimise the index size on 'y'.
200
+ if len(x) < len(y):
201
+ x, y = y, x
202
+
203
+ # Special case: y == 0 or 1, and has no random variables.
204
+ if len(y.rv_idxs) == 0:
205
+ if len(y) == 1 and y.top().is_one:
206
+ return x
207
+ elif len(y) == 0:
208
+ return CircuitTable(circuit, x.rv_idxs)
209
+
210
+ # Set operations on rv indexes. After these operations:
211
+ # * co_rv_idxs is the set of rv indexes common (co) to x and y,
212
+ # * yo_rv_idxs is the set of rv indexes in y only (yo), and not in x.
213
+ cdef set[int] yo_rv_idxs_set = set(y.rv_idxs)
214
+ cdef set[int] co_rv_idxs_set = set(x.rv_idxs)
215
+ co_rv_idxs_set.intersection_update(yo_rv_idxs_set)
216
+ yo_rv_idxs_set.difference_update(co_rv_idxs_set)
217
+
218
+ if len(co_rv_idxs_set) == 0:
219
+ # Special case: no common random variables.
220
+ return _product_no_common_rvs(x, y)
221
+
222
+ # Convert random variable index sets to sequences
223
+ cdef tuple[int, ...] yo_rv_idxs = tuple(yo_rv_idxs_set) # y only random variables
224
+ cdef tuple[int, ...] co_rv_idxs = tuple(co_rv_idxs_set) # common random variables
225
+
226
+ # Cache mappings from result Instance to index into source Instance (x or y).
227
+ # This will be used in indexing and product loops to pull out needed values
228
+ # from the source instances.
229
+ cdef list[int] co_from_x_map = []
230
+ cdef list[int] co_from_y_map = []
231
+ cdef list[int] yo_from_y_map = []
232
+ for rv_index in co_rv_idxs:
233
+ co_from_x_map.append(_find(x.rv_idxs, rv_index))
234
+ co_from_y_map.append(_find(y.rv_idxs, rv_index))
235
+ for rv_index in yo_rv_idxs:
236
+ yo_from_y_map.append(_find(y.rv_idxs, rv_index))
237
+
238
+ cdef list[int] co
239
+ cdef list[int] yo
240
+ cdef object got
241
+ cdef tuple[int, ...] co_tuple
242
+ cdef tuple[int, ...] yo_tuple
243
+
244
+ cdef CircuitTable table = CircuitTable(circuit, x.rv_idxs + yo_rv_idxs)
245
+ cdef dict[tuple[int, ...], CircuitNode] rows = table.rows
246
+
247
+
248
+ # Index the y rows by common-only key (y is the smaller of the two tables).
249
+ cdef dict[tuple[int, ...], list[tuple[tuple[int, ...], CircuitNode]]] y_index = {}
250
+ for y_instance, y_node in y.rows.items():
251
+ co = []
252
+ yo = []
253
+ for i in co_from_y_map:
254
+ co.append(y_instance[i])
255
+ for i in yo_from_y_map:
256
+ yo.append(y_instance[i])
257
+ co_tuple = tuple(co)
258
+ yo_tuple = tuple(yo)
259
+ got = y_index.get(co_tuple)
260
+ if got is None:
261
+ y_index[co_tuple] = [(yo_tuple, y_node)]
262
+ else:
263
+ got.append((yo_tuple, y_node))
264
+
265
+
266
+ # Iterate over x rows, inserting (instance, value).
267
+ # Rows with constant node values of one are optimised out.
268
+ for x_instance, x_node in x.rows.items():
269
+ co = []
270
+ for i in co_from_x_map:
271
+ co.append(x_instance[i])
272
+ co_tuple = tuple(co)
273
+
274
+ if x_node.is_one:
275
+ # Multiplying by one.
276
+ # Iterate over matching y rows.
277
+ got = y_index.get(co_tuple)
278
+ if got is not None:
279
+ for yo_tuple, y_node in got:
280
+ rows[x_instance + yo_tuple] = y_node
281
+ else:
282
+ # Iterate over matching y rows.
283
+ got = y_index.get(co_tuple)
284
+ if got is not None:
285
+ for yo_tuple, y_node in got:
286
+ if y_node.is_one:
287
+ rows[x_instance + yo_tuple] = x_node
288
+ else:
289
+ rows[x_instance + yo_tuple] = circuit.op(c_MUL, (x_node, y_node))
290
+
291
+ return table
292
+
293
+
294
+ cdef int _find(tuple[int, ...] xs, int x):
295
+ cdef int i
296
+ for i in range(len(xs)):
297
+ if xs[i] == x:
298
+ return i
299
+ # Very unexpected
300
+ raise RuntimeError('not found')
301
+
302
+
303
+ cdef CircuitTable _product_no_common_rvs(CircuitTable x, CircuitTable y):
304
+ # Return the product of x and y, where x and y have no common random variables.
305
+ #
306
+ # This is an optimisation of the more general product algorithm, as no index needs
307
+ # to be constructed based on the common random variables.
308
+ #
309
+ # Rows with constant node values of one are optimised out.
310
+ #
311
+ # Assumes:
312
+ # * There are no common random variables between x and y.
313
+ # * x and y are for the same circuit.
314
+ cdef Circuit circuit = x.circuit
315
+ cdef CircuitTable table = CircuitTable(circuit, x.rv_idxs + y.rv_idxs)
316
+ cdef tuple[int, ...] instance
317
+
318
+ for x_instance, x_node in x.rows.items():
319
+ if x_node.is_one:
320
+ for y_instance, y_node in y.rows.items():
321
+ instance = x_instance + y_instance
322
+ table.rows[instance] = y_node
323
+ else:
324
+ for y_instance, y_node in y.rows.items():
325
+ instance = x_instance + y_instance
326
+ if y_node.is_one:
327
+ table.rows[instance] = x_node
328
+ else:
329
+ table.rows[instance] = circuit.op(c_MUL, (x_node, y_node))
330
+
331
+ return table
332
+