compiled-knowledge 4.0.0a18__cp312-cp312-win_amd64.whl → 4.0.0a20__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/__init__.py +0 -3
- ck/circuit/_circuit_cy.c +37523 -0
- ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
- ck/circuit/_circuit_cy.pxd +3 -4
- ck/circuit/_circuit_cy.pyx +80 -79
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19824 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +188 -75
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +29 -4
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win_amd64.pyd +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/{circuit_analyser.py → circuit_analyser/_circuit_analyser_py.py} +14 -2
- ck/pgm_compiler/ace/__init__.py +1 -1
- ck/pgm_compiler/support/circuit_table/__init__.py +1 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
- ck_demos/ace/demo_ace.py +5 -0
- {compiled_knowledge-4.0.0a18.dist-info → compiled_knowledge-4.0.0a20.dist-info}/METADATA +1 -1
- {compiled_knowledge-4.0.0a18.dist-info → compiled_knowledge-4.0.0a20.dist-info}/RECORD +24 -20
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy_cpp_verion.pyx +0 -601
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy_minimal_version.pyx +0 -311
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy_v4.0.0a17.pyx +0 -325
- {compiled_knowledge-4.0.0a18.dist-info → compiled_knowledge-4.0.0a20.dist-info}/WHEEL +0 -0
- {compiled_knowledge-4.0.0a18.dist-info → compiled_knowledge-4.0.0a20.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a18.dist-info → compiled_knowledge-4.0.0a20.dist-info}/top_level.txt +0 -0
|
@@ -1,601 +0,0 @@
|
|
|
1
|
-
# distutils: language = c++
|
|
2
|
-
|
|
3
|
-
from typing import Sequence, Tuple, Iterable, Optional, TypeAlias, Set
|
|
4
|
-
|
|
5
|
-
from ck.circuit import ADD, MUL
|
|
6
|
-
|
|
7
|
-
# C-level copies of the circuit operator codes imported from ck.circuit.
# They are passed to `Circuit.op` by _optimised_add and _optimised_mul.
cdef int c_ADD = ADD
cdef int c_MUL = MUL
|
|
9
|
-
|
|
10
|
-
from ck.circuit._circuit_cy cimport Circuit, CircuitNode
|
|
11
|
-
|
|
12
|
-
from libcpp.vector cimport vector
|
|
13
|
-
from libcpp.unordered_set cimport unordered_set
|
|
14
|
-
from libcpp.unordered_map cimport unordered_map
|
|
15
|
-
from libcpp.pair cimport pair
|
|
16
|
-
|
|
17
|
-
from cpython.ref cimport PyObject, Py_INCREF, Py_DECREF
|
|
18
|
-
from cython.operator cimport dereference as deref, postincrement as incr
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
# Declare the vector[int] hash helper from _vector_hash.h (namespace ck).
# NOTE(review): presumably the header also makes std::hash usable for vector[int],
# which unordered_map[Key, ...] below requires — confirm in _vector_hash.h.
cdef extern from "_vector_hash.h" namespace "ck":
    cdef size_t hash_std_vector_int(vector[int])

# A row key: the state indexes of a table instance, one entry per table rv.
ctypedef vector[int] Key
ctypedef vector[int].iterator KeyIterator

# General-purpose int vector (e.g. rv indexes and index maps).
ctypedef vector[int] IntVec
ctypedef vector[int].iterator IntVecIterator

# General-purpose int set (e.g. sets of rv indexes).
ctypedef unordered_set[int] IntSet
ctypedef unordered_set[int].iterator IntSetIterator

# Iterator/item types for the C++ map wrapped by KeyMap (defined below).
ctypedef unordered_map[Key, PyObject *].iterator KeyMapIterator
ctypedef pair[Key, PyObject *] KeyMapItem

# Iterator type for the C++ vector wrapped by NodeVec (defined below).
ctypedef vector[PyObject *].iterator NodeVecIterator

# Python-level representation of a table instance (row key).
TableInstance: TypeAlias = Sequence[int]
|
|
42
|
-
|
|
43
|
-
cdef class KeyMap:
    """
    A map from a Key (hashable STL vector) to a Python CircuitNode object.

    Values are stored as raw `PyObject *` pointers. The map holds one strong
    reference to each stored node. The reference is taken in `put`. It is
    released when the entry is replaced by a later `put`, in `clear`, or when
    the map is deallocated.
    """

    cdef unordered_map[Key, PyObject *] _items

    cdef void put(self, Key key, CircuitNode node):
        # Store `node` at `key`, replacing any existing entry.
        cdef KeyMapIterator it = self._items.find(key)
        cdef PyObject * old_ptr
        # Take the new reference before releasing the old one, so that
        # re-putting the same node object at the same key is safe.
        Py_INCREF(node)
        if it == self._items.end():
            self._items[key] = <PyObject *> node
        else:
            # The replaced node was INCREF'd when it was stored; release that
            # reference, otherwise every overwrite (e.g. via `sum`) leaks it.
            old_ptr = deref(it).second
            self._items[key] = <PyObject *> node
            Py_DECREF(<object> old_ptr)

    cdef Optional[CircuitNode] get(self, Key key):
        # Return the node stored at `key`, or None if `key` is absent.
        cdef CircuitNode node
        cdef KeyMapIterator it = self._items.find(key)
        if it == self._items.end():
            return None
        else:
            node = <CircuitNode> deref(it).second
            return node

    cdef void sum(self, Key key, CircuitNode node):
        # Add the given node to the node at `key` (via _optimised_add).
        # If `key` is absent, `node` is simply stored.
        cdef CircuitNode existing_node
        cdef KeyMapIterator it = self._items.find(key)
        if it == self._items.end():
            self.put(key, node)
        else:
            existing_node = <CircuitNode> deref(it).second
            self.put(key, _optimised_add(existing_node, node))

    cdef size_t size(self):
        # Number of entries in the map.
        return self._items.size()

    cdef KeyMapIterator begin(self):
        # C++ iterator to the first entry (entries are unordered).
        return self._items.begin()

    cdef KeyMapIterator end(self):
        # C++ past-the-end iterator.
        return self._items.end()

    def instances(self) -> Iterable[TableInstance]:
        """
        Get the keys as Python objects of type TableInstance.
        """
        cdef KeyMapItem item
        for item in self._items:
            yield _key_to_instance(item.first)

    def values(self) -> Iterable[CircuitNode]:
        """
        Get the values as Python CircuitNode objects.
        """
        cdef KeyMapItem item
        for item in self._items:
            yield <CircuitNode> item.second

    def instance_values(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
        """
        Get the entries as (TableInstance, CircuitNode) pairs.
        """
        cdef KeyMapItem item
        for item in self._items:
            yield _key_to_instance(item.first), <CircuitNode> item.second

    cdef void clear(self):
        # Release the reference held for every stored node, then empty the map.
        # (Typed loop variable; avoids shadowing the cimported `pair` type.)
        cdef KeyMapItem item
        for item in self._items:
            Py_DECREF(<object> item.second)
        self._items.clear()

    def __dealloc__(self):
        self.clear()
|
|
113
|
-
|
|
114
|
-
cdef class NodeVec:
    # A growable vector of CircuitNode objects stored as raw `PyObject *`.
    # One strong reference is held per element: taken in `push_back`,
    # released in `clear` (which also runs on deallocation).

    cdef vector[PyObject *] _items

    cdef void push_back(self, CircuitNode node):
        # Append `node`, taking a reference on behalf of the vector.
        Py_INCREF(node)
        self._items.push_back(<PyObject *> node)

    cdef CircuitNode at(self, int i):
        # Bounds-checked element access (std::vector::at).
        cdef PyObject * raw = self._items.at(i)
        cdef CircuitNode result = <CircuitNode> raw
        return result

    cdef NodeVecIterator begin(self):
        # C++ iterator to the first element.
        return self._items.begin()

    cdef NodeVecIterator end(self):
        # C++ past-the-end iterator.
        return self._items.end()

    cdef void clear(self):
        # Drop the reference held for each element, then empty the vector.
        cdef NodeVecIterator cursor = self._items.begin()
        cdef NodeVecIterator stop = self._items.end()
        while cursor != stop:
            Py_DECREF(<object> deref(cursor))
            incr(cursor)
        self._items.clear()

    def __dealloc__(self):
        self.clear()
|
|
138
|
-
|
|
139
|
-
cdef class CircuitTable:
    """
    A circuit table manages a set of CircuitNodes. Each node corresponds to
    an instance of a set of (zero or more) random variables.

    Operations on circuit tables typically add circuit nodes to the circuit.
    They heuristically avoid adding unnecessary nodes (e.g. addition of zero,
    or multiplication by zero or one). However, interim circuit nodes may be
    created that end up unused. Consider calling
    `Circuit.remove_unreachable_op_nodes` after completing all circuit table
    operations.

    It is generally expected that no CircuitTable row will be created with a
    constant zero node; these are assumed to be optimised out already. This
    expectation is not enforced by the CircuitTable class.
    """
    cdef public Circuit circuit
    cdef public tuple[int, ...] rv_idxs
    cdef IntVec vec_rv_idxs  # C++ copy of rv_idxs, used by the table operations
    cdef KeyMap rows         # instance key -> circuit node

    def __init__(
            self,
            circuit: Circuit,
            rv_idxs: Sequence[int],
            rows: Iterable[Tuple[TableInstance, CircuitNode]] = (),
    ):
        """
        Args:
            circuit: the circuit whose nodes are being managed by this table.
            rv_idxs: indexes of random variables.
            rows: optional (instance, node) rows to add to the table.

        Assumes:
            * rv_idxs contains no duplicates.
            * all row instances conform to the indexed random variables.
            * all row circuit nodes belong to the given circuit.
        """
        self.circuit = circuit
        self.rv_idxs = tuple(rv_idxs)

        self.vec_rv_idxs = IntVec()
        for rv_idx in self.rv_idxs:
            self.vec_rv_idxs.push_back(rv_idx)

        self.rows = KeyMap()
        row_instance: TableInstance
        row_node: CircuitNode
        for row_instance, row_node in rows:
            self.rows.put(_instance_to_key(row_instance), row_node)

    cdef void add_row(self, tuple[int, ...] instance, CircuitNode node):
        # Set the row for `instance` (given as a Python tuple) to `node`.
        self.rows.put(_instance_to_key(instance), node)

    cdef void put(self, Key key, CircuitNode value):
        # Set the row for `key` (already in C++ Key form) to `value`.
        self.rows.put(key, value)

    def __len__(self) -> int:
        return self.rows.size()

    def get(self, instance: TableInstance, default=None):
        found: Optional[CircuitNode] = self.rows.get(_instance_to_key(instance))
        return default if found is None else found

    def __getitem__(self, instance: TableInstance) -> CircuitNode:
        found: Optional[CircuitNode] = self.rows.get(_instance_to_key(instance))
        if found is not None:
            return found
        raise KeyError('instance not found: ' + str(instance))

    def __setitem__(self, instance: TableInstance, value: CircuitNode):
        self.put(_instance_to_key(instance), value)

    def keys(self) -> Iterable[TableInstance]:
        return self.rows.instances()

    def values(self) -> Iterable[CircuitNode]:
        return self.rows.values()

    def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
        return self.rows.instance_values()

    cpdef CircuitNode top(self):
        # Return the single node of this table.
        #
        # An empty table yields the circuit's zero node.
        #
        # Raises:
        #     RuntimeError if there is more than one row in the table.
        cdef KeyMapIterator first
        cdef CircuitNode only_node
        cdef size_t row_count = self.rows.size()
        if row_count > 1:
            raise RuntimeError('cannot get top node from a table with more that 1 row')
        if row_count == 0:
            return self.circuit.zero
        first = self.rows.begin()
        only_node = <CircuitNode> deref(first).second
        return only_node
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
# ==================================================================================
|
|
245
|
-
# Circuit Table Operations
|
|
246
|
-
# ==================================================================================
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
def sum_out(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
    """
    Return a circuit table that results from summing out
    the given random variables of this circuit table.

    Normally this will return a new table. However, if rv_idxs is empty,
    then the given table is returned unmodified.

    Raises:
        ValueError if rv_idxs is not a subset of table.rv_idxs.
        ValueError if rv_idxs contains duplicates.
    """
    to_sum_out = tuple(rv_idxs)
    to_sum_out_set = set(to_sum_out)
    # _sum_out itself silently ignores unknown rvs and collapses duplicates,
    # so the documented contract is enforced here at the public boundary.
    if len(to_sum_out_set) != len(to_sum_out):
        raise ValueError('rv_idxs contains duplicates')
    if not to_sum_out_set.issubset(table.rv_idxs):
        raise ValueError('rv_idxs is not a subset of table.rv_idxs')
    return _sum_out(table, to_sum_out)
|
|
262
|
-
|
|
263
|
-
def project(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
    """
    Project the given table onto the given random variables.

    Every random variable of the table that is not in `rv_idxs` is summed out,
    i.e. this is equivalent to `sum_out(table, table.rv_idxs - rv_idxs)`.
    Entries of `rv_idxs` that are not in the table are ignored.
    """
    keep = set(rv_idxs)
    return _sum_out(table, tuple(rv for rv in table.rv_idxs if rv not in keep))
|
|
271
|
-
|
|
272
|
-
def sum_out_all(table: CircuitTable) -> CircuitTable:
    """
    Sum out every random variable of the given table.

    The result has no random variables. It has either no rows (when the
    table is empty or its sum is zero) or a single row holding the sum.
    """
    result: CircuitTable = _sum_out_all(table)
    return result
|
|
278
|
-
|
|
279
|
-
def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
    """
    Return the circuit table that is the product of the two given tables.

    Shortcut: suppose one table has no random variables and a single row
    whose node is the constant one, and the other table has at least as many
    rows. Then the other table itself is returned. Otherwise, a new circuit
    table is constructed and returned.

    Raises:
        ValueError if x and y do not refer to the same circuit.
    """
    result: CircuitTable = _product(x, y)
    return result
|
|
287
|
-
|
|
288
|
-
cdef CircuitTable _product(CircuitTable x, CircuitTable y):
    # Return the product of circuit tables x and y (see `product`).
    #
    # The result's random variables are those of the larger table (by row
    # count) followed by the rvs found only in the smaller table.
    #
    # Raises:
    #     ValueError if x and y do not refer to the same circuit.
    cdef int i
    cdef int rv_index
    cdef Circuit circuit = x.circuit
    if y.circuit is not circuit:
        raise ValueError('circuit tables must refer to the same circuit')

    # Make the smaller table (by row count) 'y', and the other 'x'.
    # This minimises the size of the index built over 'y'.
    if x.rows.size() < y.rows.size():
        x, y = y, x

    # Special case: y has no random variables and is the constant 1 or 0.
    if y.vec_rv_idxs.size() == 0:
        if y.rows.size() == 1 and y.top().is_one:
            return x
        elif y.rows.size() == 0:
            return CircuitTable(circuit, x.rv_idxs)

    # Set operations on rv indexes. After these operations:
    # * co_rv_idxs_set is the set of rv indexes common (co) to x and y,
    # * yo_rv_idxs_set is the set of rv indexes in y only (yo), and not in x.
    cdef IntSet yo_rv_idxs_set = IntSet()
    cdef IntSet co_rv_idxs_set = IntSet()
    yo_rv_idxs_set.insert(y.vec_rv_idxs.begin(), y.vec_rv_idxs.end())
    for i in x.vec_rv_idxs:
        if yo_rv_idxs_set.find(i) != yo_rv_idxs_set.end():
            co_rv_idxs_set.insert(i)
    for i in co_rv_idxs_set:
        yo_rv_idxs_set.erase(i)

    if co_rv_idxs_set.size() == 0:
        # Special case: no common random variables.
        return _product_no_common_rvs(x, y)

    # Convert random variable index sets to sequences.
    cdef IntVec yo_rv_idxs = IntVec(yo_rv_idxs_set.begin(), yo_rv_idxs_set.end())  # y only random variables
    cdef IntVec co_rv_idxs = IntVec(co_rv_idxs_set.begin(), co_rv_idxs_set.end())  # common random variables

    # Positions, within x or y keys, of each common / y-only rv. These are
    # used to pull the needed states out of the source keys.
    cdef IntVec co_from_x_map = IntVec()
    cdef IntVec co_from_y_map = IntVec()
    cdef IntVec yo_from_y_map = IntVec()
    for rv_index in co_rv_idxs:
        co_from_x_map.push_back(_find(x.vec_rv_idxs, rv_index))
        co_from_y_map.push_back(_find(y.vec_rv_idxs, rv_index))
    for rv_index in yo_rv_idxs:
        yo_from_y_map.push_back(_find(y.vec_rv_idxs, rv_index))

    # Index the y rows by their common-rv states: co -> [(yo, y_node_ptr), ...].
    # Node pointers are borrowed; y.rows keeps them alive for this call.
    cdef unordered_map[Key, vector[KeyMapItem]] y_index
    cdef unordered_map[Key, vector[KeyMapItem]].iterator y_index_find
    cdef IntVec co = IntVec()
    cdef IntVec yo = IntVec()
    cdef Key y_key
    cdef KeyMapItem item
    cdef KeyMapIterator y_it = y.rows.begin()
    cdef KeyMapIterator y_end = y.rows.end()
    while y_it != y_end:
        y_key = deref(y_it).first

        # Split y_key into the common part (co) and the remaining part (yo).
        co.clear()
        yo.clear()
        for i in co_from_y_map:
            co.push_back(y_key[i])
        for i in yo_from_y_map:
            yo.push_back(y_key[i])

        # operator[] default-constructs an empty bucket on first use of `co`.
        y_index[co].push_back(KeyMapItem(yo, deref(y_it).second))
        incr(y_it)

    cdef CircuitTable table = CircuitTable(circuit, x.rv_idxs + tuple(yo_rv_idxs))
    cdef KeyMap rows = table.rows

    # For each x row, combine it with every y row sharing its common states.
    # The result key is x_key followed by the y-only states.
    cdef KeyMapIterator x_it = x.rows.begin()
    cdef KeyMapIterator x_end = x.rows.end()
    cdef Key x_key
    cdef Key key
    cdef CircuitNode x_node, y_node
    while x_it != x_end:
        x_key = deref(x_it).first
        x_node = <CircuitNode> deref(x_it).second
        incr(x_it)

        # Extract the common part (co) of x_key.
        co.clear()
        for i in co_from_x_map:
            co.push_back(x_key[i])

        y_index_find = y_index.find(co)
        if y_index_find == y_index.end():
            # No matching y rows; continue to the next x row.
            continue

        for item in deref(y_index_find).second:
            y_node = <CircuitNode> item.second
            key = x_key
            key.insert(key.end(), item.first.begin(), item.first.end())  # append yo to x_key
            # _optimised_mul returns y_node itself when x_node is one, so no
            # separate branch is needed for that case.
            rows.put(key, _optimised_mul(x_node, y_node))

    return table
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
cdef CircuitTable _sum_out(CircuitTable table, tuple[int, ...] rv_idxs):
    # Sum out the random variables `rv_idxs` from `table`.
    #
    # Entries of rv_idxs not in the table are ignored; duplicates are
    # harmless. If rv_idxs is empty, `table` itself is returned. If every
    # table rv is summed out, this defers to _sum_out_all.
    cdef IntSet removed
    cdef IntVec kept_rvs        # table rvs that survive, in table order
    cdef IntVec kept_positions  # position of each surviving rv in table keys
    cdef int position = 0
    cdef int pos
    cdef int rv
    cdef CircuitTable result
    cdef KeyMap sums
    cdef KeyMapIterator row_it
    cdef KeyMapIterator row_end
    cdef Key source_key
    cdef Key target_key
    cdef CircuitNode row_node

    for py_rv in rv_idxs:
        removed.insert(py_rv)
    if removed.empty():
        return table

    for rv in table.vec_rv_idxs:
        if removed.find(rv) == removed.end():
            kept_rvs.push_back(rv)
            kept_positions.push_back(position)
        position += 1

    if kept_rvs.empty():
        return _sum_out_all(table)

    # Project every row key onto the kept rvs and accumulate the nodes of
    # rows that project to the same key.
    result = CircuitTable(table.circuit, kept_rvs)
    sums = result.rows
    row_it = table.rows.begin()
    row_end = table.rows.end()
    while row_it != row_end:
        source_key = deref(row_it).first
        row_node = <CircuitNode> deref(row_it).second
        target_key.clear()
        for pos in kept_positions:
            target_key.push_back(source_key.at(pos))
        sums.sum(target_key, row_node)
        incr(row_it)

    return result
|
|
468
|
-
|
|
469
|
-
cdef CircuitTable _sum_out_all(CircuitTable table):
    # Sum every row of `table` into one node, returned as a table with no
    # random variables. The result has no rows if `table` is empty or if the
    # sum is the zero node; otherwise it has the single row () -> sum.
    cdef Circuit circuit = table.circuit
    cdef KeyMapIterator cursor = table.rows.begin()
    cdef KeyMapIterator stop = table.rows.end()
    cdef CircuitNode total

    if cursor == stop:
        return CircuitTable(circuit, ())

    # Seed with the first row, then fold in the rest with _optimised_add.
    total = <CircuitNode> deref(cursor).second
    incr(cursor)
    while cursor != stop:
        total = _optimised_add(total, <CircuitNode> deref(cursor).second)
        incr(cursor)

    if total.is_zero:
        return CircuitTable(circuit, ())
    return CircuitTable(circuit, (), [((), total)])
|
|
495
|
-
|
|
496
|
-
cdef int _find(IntVec xs, int x):
    # Linear search: position of the first occurrence of x in xs, or -1.
    cdef int pos = 0
    cdef int count = <int> xs.size()
    while pos < count:
        if xs[pos] == x:
            return pos
        pos += 1
    return -1
|
|
503
|
-
|
|
504
|
-
cdef CircuitTable _product_no_common_rvs(CircuitTable x, CircuitTable y):
    # Cartesian product of x and y, for tables sharing no random variables.
    #
    # Cheaper than the general product since no index over common rvs is
    # needed. Pairs where either node is zero are skipped; when the x node is
    # one, the y node is used directly.
    #
    # Assumes:
    #     * there are no common random variables between x and y,
    #     * x and y are for the same circuit.
    cdef CircuitTable result = CircuitTable(x.circuit, x.rv_idxs + y.rv_idxs)
    cdef KeyMapIterator outer = x.rows.begin()
    cdef KeyMapIterator outer_end = x.rows.end()
    cdef KeyMapIterator inner
    cdef KeyMapIterator inner_end = y.rows.end()
    cdef Key left_key
    cdef CircuitNode left
    cdef CircuitNode right

    while outer != outer_end:
        left_key = deref(outer).first
        left = <CircuitNode> deref(outer).second
        incr(outer)
        if left.is_zero:
            continue

        inner = y.rows.begin()
        while inner != inner_end:
            right = <CircuitNode> deref(inner).second
            if not right.is_zero:
                result.rows.put(
                    _join_keys(left_key, deref(inner).first),
                    right if left.is_one else _optimised_mul(left, right),
                )
            incr(inner)

    return result
|
|
562
|
-
|
|
563
|
-
cdef Key _instance_to_key(object instance: Iterable[int]):
    # Copy the state indexes of a Python instance into a new C++ Key.
    cdef Key result = Key()
    cdef object state
    for state in instance:
        result.push_back(state)
    return result
|
|
568
|
-
|
|
569
|
-
cdef tuple[int, ...] _key_to_instance(Key key):
    # Convert a C++ Key into the equivalent Python tuple of ints.
    cdef list states = []
    cdef int state
    for state in key:
        states.append(state)
    return tuple(states)
|
|
576
|
-
|
|
577
|
-
cdef Key _join_keys(Key x, Key y):
    # Return a new Key holding the elements of x followed by those of y.
    cdef Key joined = x
    joined.insert(joined.end(), y.begin(), y.end())
    return joined
|
|
584
|
-
|
|
585
|
-
cdef CircuitNode _optimised_add(CircuitNode x, CircuitNode y):
    # x + y, avoiding a new ADD node when either operand is zero.
    # (If both are zero, y is returned.)
    if x.is_zero:
        return y
    elif y.is_zero:
        return x
    else:
        return x.circuit.op(c_ADD, (x, y))
|
|
591
|
-
|
|
592
|
-
cdef CircuitNode _optimised_mul(CircuitNode x, CircuitNode y):
    # x * y, avoiding a new MUL node when an operand is zero or one.
    # A zero operand is returned as-is (x takes precedence); a one operand
    # yields the other operand.
    if x.is_zero:
        return x
    if y.is_zero or x.is_one:
        return y
    if y.is_one:
        return x
    return x.circuit.op(c_MUL, (x, y))
|