compiled-knowledge 4.0.0a17__cp312-cp312-win_amd64.whl → 4.0.0a18__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (23)
  1. ck/circuit/__init__.py +7 -0
  2. ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
  3. ck/circuit/_circuit_cy.pxd +33 -0
  4. ck/circuit/_circuit_cy.pyx +84 -110
  5. ck/circuit/_circuit_py.py +2 -2
  6. ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
  7. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +6 -5
  8. ck/pgm_compiler/support/circuit_table/__init__.py +7 -0
  9. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
  10. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +44 -37
  11. ck/pgm_compiler/support/circuit_table/_circuit_table_cy_cpp_verion.pyx +601 -0
  12. ck/pgm_compiler/support/circuit_table/_circuit_table_cy_minimal_version.pyx +311 -0
  13. ck/pgm_compiler/support/circuit_table/_circuit_table_cy_v4.0.0a17.pyx +325 -0
  14. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +76 -41
  15. ck/pgm_compiler/support/named_compiler_maker.py +12 -2
  16. ck/utils/iter_extras.py +8 -1
  17. ck_demos/utils/compare.py +5 -1
  18. {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a18.dist-info}/METADATA +1 -1
  19. {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a18.dist-info}/RECORD +22 -19
  20. {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a18.dist-info}/WHEEL +1 -1
  21. ck/circuit_compiler/cython_vm_compiler/_compiler.c +0 -16946
  22. {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a18.dist-info}/licenses/LICENSE.txt +0 -0
  23. {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a18.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,311 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Sequence, Tuple, Dict, Iterable, Set, Iterator
4
+
5
+ from ck.circuit import ADD, MUL
6
+ from ck.utils.map_list import MapList
7
+
8
+ from ck.circuit._circuit_cy cimport Circuit, CircuitNode
9
+
10
# C-level copies of the circuit op codes, passed to `Circuit.op` by the
# operations below (ADD for summing out, MUL for products).
cdef int c_ADD = ADD
cdef int c_MUL = MUL


# A table instance: a tuple of ints, one per random variable, positioned to
# match the owning table's `rv_idxs` (presumably each int is the state index of
# that random variable — confirm against callers).
TableInstance = Tuple[int, ...]
15
+
16
+
17
class CircuitTable:
    """
    A circuit table manages a set of CircuitNodes, where each node corresponds
    to an instance for a set of (zero or more) random variables.

    Operations on circuit tables typically add circuit nodes to the circuit. It will
    heuristically avoid adding unnecessary nodes (e.g. addition of zero, multiplication
    by zero or one.) However, it may be that interim circuit nodes are created that
    end up not being used. Consider calling `Circuit.remove_unreachable_op_nodes` after
    completing all circuit table operations.

    It is generally expected that no CircuitTable row will be created with a constant
    zero node. These are assumed to be optimised out already.
    """

    def __init__(
            self,
            circuit: Circuit,
            rv_idxs: Sequence[int],
            rows: Iterable[Tuple[TableInstance, CircuitNode]] = (),
    ):
        """
        Args:
            circuit: the circuit whose nodes are being managed by this table.
            rv_idxs: indexes of random variables.
            rows: optional rows to add to the table.

        Assumes:
            * rv_idxs contains no duplicates.
            * all row instances conform to the indexed random variables.
            * all row circuit nodes belong to the given circuit.
        """
        self._circuit: Circuit = circuit
        self._rv_idxs: Tuple[int, ...] = tuple(rv_idxs)
        # Later rows with an equal instance replace earlier ones (dict semantics).
        self._rows: Dict[TableInstance, CircuitNode] = dict(rows)

    @property
    def circuit(self) -> Circuit:
        """The circuit whose nodes are managed by this table."""
        return self._circuit

    @property
    def rv_idxs(self) -> Tuple[int, ...]:
        """The random variable indexes; each instance is ordered to match."""
        return self._rv_idxs

    def __len__(self) -> int:
        """Return the number of rows in the table."""
        return len(self._rows)

    def get(self, key, default=None):
        """Return the node for instance `key`, or `default` if there is no such row."""
        return self._rows.get(key, default)

    def keys(self) -> Iterable[TableInstance]:
        """Return a view of the row instances."""
        return self._rows.keys()

    def values(self) -> Iterable[CircuitNode]:
        """Return a view of the row circuit nodes."""
        return self._rows.values()

    def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
        """Return a view of the (instance, node) rows."""
        return self._rows.items()

    def __getitem__(self, key):
        """Return the node for instance `key`; raises KeyError if there is no such row."""
        return self._rows[key]

    def __setitem__(self, key, value):
        """Add or replace the node for instance `key`."""
        self._rows[key] = value

    def top(self) -> CircuitNode:
        """
        Get the circuit top value.

        Raises:
            RuntimeError if there is more than one row in the table.

        Returns:
            A single circuit node: the circuit's zero node if the table has
            no rows, otherwise the node of the only row.
        """
        if len(self._rows) == 0:
            return self._circuit.zero
        elif len(self._rows) == 1:
            return next(iter(self._rows.values()))
        else:
            raise RuntimeError('cannot get top node from a table with more than 1 row')
98
+
99
+
100
+ # ==================================================================================
101
+ # Circuit Table Operations
102
+ # ==================================================================================
103
+
104
+
105
def sum_out(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
    """
    Return a circuit table that results from summing out
    the given random variables of this circuit table.

    Normally this will return a new table. However, if rv_idxs is empty,
    then the given table is returned unmodified. If every random variable
    of the table is summed out, the result is computed by `sum_out_all`.

    Rows that agree on the remaining random variables are added together
    with one ADD node. A group holding only one row reuses that row's node
    directly, since adding a single term needs no ADD node (this matches the
    single-row handling in `sum_out_all`).

    Raises:
        ValueError if rv_idxs contains duplicates.
        ValueError if rv_idxs is not a subset of table.rv_idxs.
    """
    rv_idxs: Sequence[int] = tuple(rv_idxs)

    if len(rv_idxs) == 0:
        # nothing to do
        return table

    rv_idxs_set: Set[int] = set(rv_idxs)
    if len(rv_idxs_set) != len(rv_idxs):
        raise ValueError('rv_idxs contains duplicates')
    if not rv_idxs_set.issubset(table.rv_idxs):
        raise ValueError('rv_idxs is not a subset of table.rv_idxs')

    # index_map[i] is the position in table.rv_idxs of the i-th remaining
    # random variable; remaining variables keep their original order.
    index_map = tuple(
        position
        for position, rv_index in enumerate(table.rv_idxs)
        if rv_index not in rv_idxs_set
    )
    if len(index_map) == 0:
        # Special case: summing out all random variables
        return sum_out_all(table)
    remaining_rv_idxs = tuple(table.rv_idxs[i] for i in index_map)

    # Two passes: first group nodes by their remaining instance, then sum each
    # group. (A one-pass running sum seems to perform worse.)
    groups: MapList[TableInstance, CircuitNode] = MapList()
    for instance, node in table.items():
        groups.append(tuple(instance[i] for i in index_map), node)

    circuit: Circuit = table.circuit
    result_rows = []
    for group, to_add in groups.items():
        to_add = tuple(to_add)
        if len(to_add) == 1:
            # A single term: no ADD node needed.
            summed = to_add[0]
        else:
            summed = circuit.op(c_ADD, to_add)
        result_rows.append((group, summed))
    return CircuitTable(circuit, remaining_rv_idxs, result_rows)
173
+
174
+
175
+
176
def sum_out_all(table: CircuitTable) -> CircuitTable:
    """
    Return a circuit table that results from summing out
    all random variables of this circuit table.

    The result has no random variables and at most one row. It has no rows
    when the input table is empty or when the summed node is a constant zero,
    so that no row ever holds a constant zero node.
    """
    circuit: Circuit = table.circuit
    num_rows: int = len(table)
    if num_rows == 0:
        return CircuitTable(circuit, ())

    if num_rows == 1:
        # A single row needs no ADD node.
        node: CircuitNode = next(iter(table.values()))
    else:
        node = circuit.op(c_ADD, tuple(table.values()))

    # Apply the zero guard on both paths: never create a constant-zero row.
    if node.is_zero:
        return CircuitTable(circuit, ())
    return CircuitTable(circuit, (), [((), node)])
193
+
194
+
195
def project(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
    """
    Keep only the given random variables of `table`, summing out all others.

    Equivalent to `sum_out(table, set(table.rv_idxs) - set(rv_idxs))`; any
    entries of rv_idxs that are not in table.rv_idxs are ignored.
    """
    return sum_out(table, set(table.rv_idxs).difference(rv_idxs))
203
+
204
+
205
def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
    """
    Return a circuit table that results from the product of the two given tables.

    The table with fewer rows (y on a tie) is indexed and the other is scanned.
    If that smaller table has no random variables it is treated as a scalar:
    a single constant-one row returns the other table unchanged, and no rows
    gives an empty table over the other table's random variables. Otherwise,
    a new circuit table is constructed and returned.

    Where either factor of a row product is a constant one node, the other
    node is reused instead of creating a MUL node.

    Raises:
        ValueError if x and y do not refer to the same circuit.
    """
    circuit: Circuit = x.circuit
    if y.circuit is not circuit:
        raise ValueError('circuit tables must refer to the same circuit')

    # Keep the smaller table in 'y' so the index built over it stays small.
    if len(y) > len(x):
        x, y = y, x

    x_rv_idxs: Tuple[int, ...] = x.rv_idxs
    y_rv_idxs: Tuple[int, ...] = y.rv_idxs

    # Scalar y (no random variables): empty => empty result, one => identity.
    if not y_rv_idxs:
        if len(y) == 0:
            return CircuitTable(circuit, x_rv_idxs)
        if len(y) == 1 and y.top().is_one:
            return x

    # Split y's random variables into those shared with x and those only in y.
    y_only_set: Set[int] = set(y_rv_idxs)
    shared_set: Set[int] = set(x_rv_idxs)
    shared_set.intersection_update(y_only_set)
    y_only_set.difference_update(shared_set)

    if not shared_set:
        # Nothing to join on: a plain cross product.
        return _product_no_common_rvs(x, y)

    y_only: Tuple[int, ...] = tuple(y_only_set)
    shared: Tuple[int, ...] = tuple(shared_set)

    # Positions of the shared / y-only random variables within each source instance.
    shared_in_x = tuple(map(x_rv_idxs.index, shared))
    shared_in_y = tuple(map(y_rv_idxs.index, shared))
    y_only_in_y = tuple(map(y_rv_idxs.index, y_only))

    # Group y rows by their values on the shared random variables.
    y_index: MapList[TableInstance, Tuple[TableInstance, CircuitNode]] = MapList()
    for y_instance, y_node in y.items():
        shared_key = tuple(y_instance[i] for i in shared_in_y)
        y_only_part = tuple(y_instance[i] for i in y_only_in_y)
        y_index.append(shared_key, (y_only_part, y_node))

    # Join each x row with its matching y rows.
    result_rows: Dict[TableInstance, CircuitNode] = {}
    for x_instance, x_node in x.items():
        shared_key = tuple(x_instance[i] for i in shared_in_x)
        x_is_one = x_node.is_one
        for y_only_part, y_node in y_index.get(shared_key, ()):
            if x_is_one:
                result_rows[x_instance + y_only_part] = y_node
            elif y_node.is_one:
                result_rows[x_instance + y_only_part] = x_node
            else:
                result_rows[x_instance + y_only_part] = circuit.op(c_MUL, (x_node, y_node))

    return CircuitTable(circuit, x_rv_idxs + y_only, result_rows)
280
+
281
+
282
def _product_no_common_rvs(x: CircuitTable, y: CircuitTable) -> CircuitTable:
    """
    Return the product of x and y, where x and y have no common random variables.

    This is an optimisation of the more general product algorithm, as no index
    needs to be constructed based on the common random variables: every x row
    is paired with every y row.

    Multiplication by a constant one node is optimised out: the other factor's
    node is used directly.

    Assumes:
        * There are no common random variables between x and y.
        * x and y are for the same circuit.
    """
    circuit: Circuit = x.circuit

    result_rv_idxs: Tuple[int, ...] = x.rv_idxs + y.rv_idxs

    def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
        for x_instance, x_node in x.items():
            if x_node.is_one:
                # 1 * y == y
                for y_instance, y_node in y.items():
                    yield x_instance + y_instance, y_node
            else:
                for y_instance, y_node in y.items():
                    if y_node.is_one:
                        # x * 1 == x: reuse the x node (not the one node).
                        yield x_instance + y_instance, x_node
                    else:
                        yield x_instance + y_instance, circuit.op(c_MUL, (x_node, y_node))

    return CircuitTable(circuit, result_rv_idxs, _result_rows())
@@ -0,0 +1,325 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Sequence, Tuple, Iterable, Iterator
4
+
5
+ from ck.circuit import CircuitNode, Circuit, OpNode, MUL
6
+
7
# A table instance: a tuple of ints, one per random variable, positioned to
# match the owning table's `rv_idxs` (presumably each int is the state index of
# that random variable — confirm against callers).
TableInstance = Tuple[int, ...]
8
+
9
+
10
cdef class CircuitTable:
    """
    A circuit table manages a set of CircuitNodes, where each node corresponds
    to an instance for a set of (zero or more) random variables.

    Operations on circuit tables typically add circuit nodes to the circuit. It will
    heuristically avoid adding unnecessary nodes (e.g. addition of zero, multiplication
    by zero or one.) However, it may be that interim circuit nodes are created that
    end up not being used. Consider calling `Circuit.remove_unreachable_op_nodes` after
    completing all circuit table operations.

    It is generally expected that no CircuitTable row will be created with a constant
    zero node. These are assumed to be optimised out already.
    """

    # The circuit whose nodes are managed by this table.
    cdef public object circuit
    # Indexes of the random variables; each row instance is ordered to match.
    cdef public tuple[int, ...] rv_idxs
    # Map from row instance to the circuit node for that instance.
    cdef public dict[tuple[int, ...], CircuitNode] rows

    def __init__(
            self,
            circuit: Circuit,
            rv_idxs: Sequence[int],
            rows: Iterable[Tuple[TableInstance, CircuitNode]] = (),
    ):
        """
        Args:
            circuit: the circuit whose nodes are being managed by this table.
            rv_idxs: indexes of random variables.
            rows: optional rows to add to the table.

        Assumes:
            * rv_idxs contains no duplicates.
            * all row instances conform to the indexed random variables.
            * all row circuit nodes belong to the given circuit.
        """
        self.circuit = circuit
        self.rv_idxs = tuple(rv_idxs)
        self.rows = dict(rows)

    def __len__(self) -> int:
        """Return the number of rows in the table."""
        return len(self.rows)

    def get(self, key, default=None):
        """Return the node for instance `key`, or `default` if there is no such row."""
        return self.rows.get(key, default)

    def __getitem__(self, key):
        """Return the node for instance `key`; raises KeyError if there is no such row."""
        return self.rows[key]

    def __setitem__(self, key, value):
        """Add or replace the node for instance `key`."""
        self.rows[key] = value

    cpdef object top(self):  # -> CircuitNode:
        # Get the circuit top value.
        #
        # Raises:
        #     RuntimeError if there is more than one row in the table.
        #
        # Returns:
        #     A single circuit node: the circuit's zero node if the table has
        #     no rows, otherwise the node of the only row.
        cdef int number_of_rows = len(self.rows)
        if number_of_rows == 0:
            return self.circuit.zero
        elif number_of_rows == 1:
            return next(iter(self.rows.values()))
        else:
            raise RuntimeError('cannot get top node from a table with more than 1 row')
77
+
78
+
79
+ # ==================================================================================
80
+ # Circuit Table Operations
81
+ # ==================================================================================
82
+
83
cpdef object sum_out(object table: CircuitTable, object rv_idxs: Iterable[int]):  # -> CircuitTable:
    # Sum out the given random variables of a circuit table.
    #
    # If rv_idxs is empty, the given table itself is returned. If every random
    # variable of the table is summed out, this delegates to sum_out_all.
    # Otherwise a new table over the remaining random variables is returned;
    # groups whose sum (via circuit.optimised_add) is a constant zero are omitted.
    #
    # Raises:
    #     ValueError if rv_idxs contains duplicates.
    #     ValueError if rv_idxs is not a subset of table.rv_idxs.
    cdef tuple[int, ...] rv_idxs_seq = tuple(rv_idxs)
    if not rv_idxs_seq:
        # Nothing to sum out.
        return table

    cdef set[int] rv_idxs_set = set(rv_idxs_seq)
    if len(rv_idxs_set) != len(rv_idxs_seq):
        raise ValueError('rv_idxs contains duplicates')
    if not rv_idxs_set.issubset(table.rv_idxs):
        raise ValueError('rv_idxs is not a subset of table.rv_idxs')

    # The random variables that survive, in table order.
    cdef int rv_index
    cdef list[int] remaining_rv_idxs = []
    for rv_index in table.rv_idxs:
        if rv_index in rv_idxs_set:
            continue
        remaining_rv_idxs.append(rv_index)

    if not remaining_rv_idxs:
        # Everything is summed out.
        return sum_out_all(table)

    # Position of each surviving random variable within a row instance.
    cdef list[int] index_map = []
    for rv_index in remaining_rv_idxs:
        index_map.append(_find(table.rv_idxs, rv_index))

    # Group the row nodes by their projected (surviving) instance.
    cdef dict[tuple[int, ...], list[object]] groups = {}
    cdef list[int] key_parts
    cdef tuple[int, ...] key
    cdef tuple[int, ...] instance
    cdef object node
    cdef int i
    for instance, node in table.rows.items():
        key_parts = []
        for i in index_map:
            key_parts.append(instance[i])
        key = tuple(key_parts)
        groups.setdefault(key, []).append(node)

    # Sum each group into the result, skipping zero sums.
    cdef object circuit = table.circuit
    cdef object result = CircuitTable(circuit, remaining_rv_idxs)
    cdef dict[tuple[int, ...], object] result_rows = result.rows
    cdef object summed
    for key, to_add in groups.items():
        summed = circuit.optimised_add(to_add)
        if not summed.is_zero:
            result_rows[key] = summed

    return result
149
+
150
+
151
cpdef object sum_out_all(object table: CircuitTable):  # -> CircuitTable:
    # Sum out all random variables of a circuit table.
    #
    # The result has no random variables and at most one row. It has no rows
    # when the input table is empty or when the summed node is a constant zero,
    # so that no row ever holds a constant zero node (the zero guard applies to
    # both the single-row and multi-row paths).
    circuit: Circuit = table.circuit
    num_rows: int = len(table)
    if num_rows == 0:
        return CircuitTable(circuit, ())

    cdef object node
    if num_rows == 1:
        # A single row needs no addition.
        node = next(iter(table.rows.values()))
    else:
        node = circuit.optimised_add(table.rows.values())

    if node.is_zero:
        return CircuitTable(circuit, ())
    return CircuitTable(circuit, (), [((), node)])
166
+
167
+
168
cpdef object project(object table: CircuitTable, object rv_idxs: Iterable[int]):  # -> CircuitTable:
    # Keep only the given random variables of `table`, summing out all others,
    # i.e. `sum_out(table, set(table.rv_idxs) - set(rv_idxs))`. Entries of
    # rv_idxs that are not in table.rv_idxs are ignored.
    cdef set[int] to_sum_out = set(table.rv_idxs).difference(rv_idxs)
    return sum_out(table, to_sum_out)
174
+
175
+
176
cpdef object product(x: CircuitTable, y: CircuitTable):  # -> CircuitTable:
    # Return a circuit table that results from the product of the two given tables.
    #
    # The table with fewer rows (y on a tie) is indexed and the other is scanned.
    # If that smaller table has no random variables it is treated as a scalar:
    # a single constant-one row returns the other table unchanged, and no rows
    # gives an empty table over the other table's random variables. Otherwise,
    # a new circuit table is constructed and returned.
    #
    # Raises:
    #     ValueError if x and y do not refer to the same circuit.
    cdef int i
    cdef object circuit = x.circuit
    if y.circuit is not circuit:
        raise ValueError('circuit tables must refer to the same circuit')

    # Keep the smaller table in 'y' so the index built over it stays small.
    if len(y) > len(x):
        x, y = y, x

    # Scalar y (no random variables): empty => empty result, one => identity.
    if len(y.rv_idxs) == 0:
        if len(y) == 0:
            return CircuitTable(circuit, x.rv_idxs)
        if len(y) == 1 and y.top().is_one:
            return x

    # Split y's random variables into those common with x (co) and y-only (yo).
    cdef set[int] yo_rv_idxs_set = set(y.rv_idxs)
    cdef set[int] co_rv_idxs_set = set(x.rv_idxs)
    co_rv_idxs_set.intersection_update(yo_rv_idxs_set)
    yo_rv_idxs_set.difference_update(co_rv_idxs_set)

    if len(co_rv_idxs_set) == 0:
        # Nothing to join on: a plain cross product.
        return _product_no_common_rvs(x, y)

    cdef tuple[int, ...] yo_rv_idxs = tuple(yo_rv_idxs_set)
    cdef tuple[int, ...] co_rv_idxs = tuple(co_rv_idxs_set)

    # Positions of the co / yo random variables within the x and y instances.
    cdef list[int] co_from_x_map = []
    cdef list[int] co_from_y_map = []
    cdef list[int] yo_from_y_map = []
    for rv_index in co_rv_idxs:
        co_from_x_map.append(_find(x.rv_idxs, rv_index))
        co_from_y_map.append(_find(y.rv_idxs, rv_index))
    for rv_index in yo_rv_idxs:
        yo_from_y_map.append(_find(y.rv_idxs, rv_index))

    cdef list[int] co
    cdef list[int] yo
    cdef object matches
    cdef object x_is_one
    cdef tuple[int, ...] co_tuple
    cdef tuple[int, ...] yo_tuple

    cdef object table = CircuitTable(circuit, x.rv_idxs + yo_rv_idxs)
    cdef dict[tuple[int, ...], object] rows = table.rows

    # Group y rows by their values on the common random variables.
    cdef dict[tuple[int, ...], list[tuple[tuple[int, ...], object]]] y_index = {}
    for y_instance, y_node in y.rows.items():
        co = []
        for i in co_from_y_map:
            co.append(y_instance[i])
        yo = []
        for i in yo_from_y_map:
            yo.append(y_instance[i])
        co_tuple = tuple(co)
        yo_tuple = tuple(yo)
        y_index.setdefault(co_tuple, []).append((yo_tuple, y_node))

    # Join each x row with its matching y rows; a constant-one x reuses y's node.
    for x_instance, x_node in x.rows.items():
        co = []
        for i in co_from_x_map:
            co.append(x_instance[i])
        co_tuple = tuple(co)
        x_is_one = x_node.is_one
        matches = y_index.get(co_tuple)
        if matches is None:
            continue
        for yo_tuple, y_node in matches:
            if x_is_one:
                rows[x_instance + yo_tuple] = y_node
            else:
                rows[x_instance + yo_tuple] = _optimised_mul(circuit, x_node, y_node)

    return table
277
+
278
+
279
cdef int _find(tuple[int, ...] xs, int x) except -1:
    # Return the position of the first occurrence of x in xs.
    #
    # Raises RuntimeError if x is not in xs. The explicit `except -1` clause
    # (-1 is never a valid position) makes the error propagate to the caller;
    # without it, some Cython versions print and discard exceptions raised in
    # cdef functions with a C return type, returning a bogus value instead.
    cdef int i
    for i in range(len(xs)):
        if xs[i] == x:
            return i
    # Very unexpected: callers only search for values taken from xs's table.
    raise RuntimeError('not found')
286
+
287
+
288
cdef object _product_no_common_rvs(x: CircuitTable, y: CircuitTable):  # -> CircuitTable:
    # Multiply two circuit tables that share no random variables.
    #
    # Every x row is paired with every y row, so no index over common random
    # variables needs to be constructed. A constant-one x node reuses the y
    # node directly; other pairs go through _optimised_mul.
    #
    # Assumes:
    #     * There are no common random variables between x and y.
    #     * x and y are for the same circuit.
    cdef object circuit = x.circuit
    cdef object table = CircuitTable(circuit, x.rv_idxs + y.rv_idxs)
    cdef object result_rows = table.rows
    cdef tuple[int, ...] instance
    cdef object x_is_one

    for x_instance, x_node in x.rows.items():
        x_is_one = x_node.is_one
        for y_instance, y_node in y.rows.items():
            instance = x_instance + y_instance
            if x_is_one:
                result_rows[instance] = y_node
            else:
                result_rows[instance] = _optimised_mul(circuit, x_node, y_node)

    return table
314
+
315
+
316
cdef object _optimised_mul(object circuit: Circuit, object x: CircuitNode, object y: CircuitNode):  # -> CircuitNode
    # Multiply two circuit nodes, avoiding a new MUL node where possible.
    #
    # A zero operand is returned as-is (zero annihilates); if one operand is a
    # constant one, the other operand is returned. Checks run in the order
    # x zero, y zero, x one, y one; only otherwise is circuit.mul(x, y) called.
    if x.is_zero:
        return x
    if y.is_zero or x.is_one:
        return y
    if y.is_one:
        return x
    return circuit.mul(x, y)