compiled-knowledge 4.0.0a16__cp312-cp312-win_amd64.whl → 4.0.0a18__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/__init__.py +9 -2
- ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
- ck/circuit/_circuit_cy.pxd +33 -0
- ck/circuit/{circuit.pyx → _circuit_cy.pyx} +115 -133
- ck/circuit/{circuit_py.py → _circuit_py.py} +16 -8
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +88 -60
- ck/circuit_compiler/named_circuit_compilers.py +1 -1
- ck/pgm_compiler/factor_elimination.py +23 -13
- ck/pgm_compiler/support/circuit_table/__init__.py +9 -2
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy_cpp_verion.pyx +601 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy_minimal_version.pyx +311 -0
- ck/pgm_compiler/support/circuit_table/{circuit_table.pyx → _circuit_table_cy_v4.0.0a17.pyx} +9 -9
- ck/pgm_compiler/support/circuit_table/{circuit_table_py.py → _circuit_table_py.py} +80 -45
- ck/pgm_compiler/support/clusters.py +16 -4
- ck/pgm_compiler/support/factor_tables.py +1 -1
- ck/pgm_compiler/support/join_tree.py +67 -10
- ck/pgm_compiler/support/named_compiler_maker.py +12 -2
- ck/pgm_compiler/variable_elimination.py +2 -0
- ck/utils/iter_extras.py +8 -1
- ck_demos/pgm_compiler/demo_compiler_dump.py +10 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/utils/compare.py +5 -1
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/METADATA +1 -1
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/RECORD +30 -29
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/WHEEL +1 -1
- ck/circuit/circuit.c +0 -38861
- ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
- ck/circuit/circuit_node.pyx +0 -138
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +0 -17373
- ck/pgm_compiler/support/circuit_table/circuit_table.c +0 -16042
- ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Sequence, Tuple, Dict, Iterable, Set, Iterator
|
|
4
|
+
|
|
5
|
+
from ck.circuit import ADD, MUL
|
|
6
|
+
from ck.utils.map_list import MapList
|
|
7
|
+
|
|
8
|
+
from ck.circuit._circuit_cy cimport Circuit, CircuitNode
|
|
9
|
+
|
|
10
|
+
cdef int c_ADD = ADD
|
|
11
|
+
cdef int c_MUL = MUL
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
TableInstance = Tuple[int, ...]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CircuitTable:
    """
    A circuit table manages a set of CircuitNodes, where each node corresponds
    to an instance for a set of (zero or more) random variables.

    Operations on circuit tables typically add circuit nodes to the circuit. It will
    heuristically avoid adding unnecessary nodes (e.g. addition of zero, multiplication
    by zero or one.) However, it may be that interim circuit nodes are created that
    end up not being used. Consider calling `Circuit.remove_unreachable_op_nodes` after
    completing all circuit table operations.

    It is generally expected that no CircuitTable row will be created with a constant
    zero node. These are assumed to be optimised out already.
    """

    def __init__(
            self,
            circuit: Circuit,
            rv_idxs: Sequence[int],
            rows: Iterable[Tuple[TableInstance, CircuitNode]] = (),
    ):
        """
        Args:
            circuit: the circuit whose nodes are being managed by this table.
            rv_idxs: indexes of random variables.
            rows: optional rows to add to the table, as (instance, node) pairs.

        Assumes:
            * rv_idxs contains no duplicates.
            * all row instances conform to the indexed random variables.
            * all row circuit nodes belong to the given circuit.
        """
        self._circuit: Circuit = circuit
        # Stored as a tuple so the column order is fixed and hashable.
        self._rv_idxs: Tuple[int, ...] = tuple(rv_idxs)
        # Maps an instance (tuple of state indexes) to its circuit node.
        self._rows: Dict[TableInstance, CircuitNode] = dict(rows)

    @property
    def circuit(self) -> Circuit:
        """The circuit whose nodes are managed by this table."""
        return self._circuit

    @property
    def rv_idxs(self) -> Tuple[int, ...]:
        """The random variable indexes of this table, in column order."""
        return self._rv_idxs

    def __len__(self) -> int:
        """Return the number of rows in the table."""
        return len(self._rows)

    def get(self, key, default=None):
        """Return the node for instance `key`, or `default` if there is no such row."""
        return self._rows.get(key, default)

    def keys(self) -> Iterable[TableInstance]:
        """Return a view of the row instances."""
        return self._rows.keys()

    def values(self) -> Iterable[CircuitNode]:
        """Return a view of the row circuit nodes."""
        return self._rows.values()

    def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
        """Return a view of the (instance, node) rows."""
        return self._rows.items()

    def __getitem__(self, key):
        """Return the node for instance `key`; raises KeyError if there is no such row."""
        return self._rows[key]

    def __setitem__(self, key, value):
        """Set (or replace) the node for instance `key`."""
        self._rows[key] = value

    def top(self) -> CircuitNode:
        """
        Get the circuit top value.

        Raises:
            RuntimeError if there is more than one row in the table.

        Returns:
            A single circuit node: the only row's node, or the circuit's
            `zero` node if the table has no rows.
        """
        num_rows = len(self._rows)
        if num_rows == 0:
            # An empty table is treated as the constant zero.
            return self._circuit.zero
        if num_rows == 1:
            return next(iter(self._rows.values()))
        raise RuntimeError('cannot get top node from a table with more than 1 row')
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
# ==================================================================================
|
|
101
|
+
# Circuit Table Operations
|
|
102
|
+
# ==================================================================================
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def sum_out(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
    """
    Return a circuit table that results from summing out
    the given random variables of the given circuit table.

    Normally this will return a new table. However, if rv_idxs is empty,
    then the given table is returned unmodified.

    Args:
        table: the circuit table to sum over.
        rv_idxs: indexes of the random variables to sum out.

    Raises:
        ValueError if rv_idxs is not a subset of table.rv_idxs.
        ValueError if rv_idxs contains duplicates.

    Assumes:
        * table.rv_idxs contains no duplicates.
    """
    to_sum_out: Tuple[int, ...] = tuple(rv_idxs)

    if len(to_sum_out) == 0:
        # nothing to do
        return table

    to_sum_out_set: Set[int] = set(to_sum_out)
    if len(to_sum_out_set) != len(to_sum_out):
        raise ValueError('rv_idxs contains duplicates')
    if not to_sum_out_set.issubset(table.rv_idxs):
        raise ValueError('rv_idxs is not a subset of table.rv_idxs')

    # (position in table.rv_idxs, rv index) for each random variable that is kept.
    kept = [
        (position, rv_index)
        for position, rv_index in enumerate(table.rv_idxs)
        if rv_index not in to_sum_out_set
    ]
    if len(kept) == 0:
        # Special case: summing out all random variables
        return sum_out_all(table)

    # index_map[i] is the location in table.rv_idxs for remaining_rv_idxs[i]
    index_map = tuple(position for position, _ in kept)
    remaining_rv_idxs = tuple(rv_index for _, rv_index in kept)

    # Two-pass approach: first collect the nodes of each group, then add each
    # group once. A one-pass variant that accumulated pairwise additions was
    # found to perform worse.
    groups: MapList[TableInstance, CircuitNode] = MapList()
    for instance, node in table.items():
        group_instance = tuple(instance[i] for i in index_map)
        groups.append(group_instance, node)

    circuit: Circuit = table.circuit

    def _group_sum(to_add) -> CircuitNode:
        # A group of one needs no addition node; use the node directly
        # (the same shortcut sum_out_all takes for a single-row table).
        nodes = tuple(to_add)
        if len(nodes) == 1:
            return nodes[0]
        return circuit.op(c_ADD, nodes)

    return CircuitTable(
        circuit,
        remaining_rv_idxs,
        (
            (group, _group_sum(to_add))
            for group, to_add in groups.items()
        )
    )
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def sum_out_all(table: CircuitTable) -> CircuitTable:
    """
    Sum out every random variable of the given circuit table.

    Returns:
        A circuit table with no random variables. It has no rows if the
        given table is empty (or the computed sum is a zero node); otherwise
        it has a single row, keyed by the empty instance, holding the sum.
    """
    circuit: Circuit = table.circuit
    row_count: int = len(table)

    if row_count == 0:
        return CircuitTable(circuit, ())

    if row_count == 1:
        # A single row needs no addition node; reuse its node as the total.
        (total,) = table.values()
    else:
        total = circuit.op(c_ADD, tuple(table.values()))
        if total.is_zero:
            # A zero total is represented by an empty table.
            return CircuitTable(circuit, ())

    return CircuitTable(circuit, (), [((), total)])
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def project(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
    """
    Project the table onto the given random variables.

    This sums out every random variable of the table that is not
    in `rv_idxs`, i.e. it calls `sum_out(table, to_sum_out)`, where
    `to_sum_out = table.rv_idxs - rv_idxs`.
    """
    not_kept: Set[int] = set(table.rv_idxs).difference(rv_idxs)
    return sum_out(table, not_kept)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
    """
    Multiply two circuit tables, returning the resulting circuit table.

    If x or y have a single row with value 1, then the other table is returned. Otherwise,
    a new circuit table will be constructed and returned.

    Raises:
        ValueError if the two tables do not refer to the same circuit.
    """
    circuit: Circuit = x.circuit
    if y.circuit is not circuit:
        raise ValueError('circuit tables must refer to the same circuit')

    # Ensure 'y' is the smaller table, so the index built over 'y' is minimal.
    if len(x) < len(y):
        x, y = y, x

    x_rv_idxs: Tuple[int, ...] = x.rv_idxs
    y_rv_idxs: Tuple[int, ...] = y.rv_idxs

    # Special case: y has no random variables and is the constant 1 or 0.
    if y_rv_idxs == ():
        if len(y) == 1 and y.top().is_one:
            return x
        elif len(y) == 0:
            return CircuitTable(circuit, x_rv_idxs)

    # Partition the rv indexes of y into:
    #  * common: rv indexes in both x and y,
    #  * y_only: rv indexes in y but not in x.
    y_only_set: Set[int] = set(y_rv_idxs)
    common_set: Set[int] = set(x_rv_idxs)
    common_set.intersection_update(y_only_set)
    y_only_set.difference_update(common_set)

    if len(common_set) == 0:
        # Special case: no common random variables.
        return _product_no_common_rvs(x, y)

    # Fix an ordering for each set of rv indexes.
    yo_rv_idxs: Tuple[int, ...] = tuple(y_only_set)
    co_rv_idxs: Tuple[int, ...] = tuple(common_set)

    # Positions within the source instances (of x or y) from which to pull
    # the values of the common and y-only random variables.
    co_from_x_map = tuple(x.rv_idxs.index(rv_index) for rv_index in co_rv_idxs)
    co_from_y_map = tuple(y.rv_idxs.index(rv_index) for rv_index in co_rv_idxs)
    yo_from_y_map = tuple(y.rv_idxs.index(rv_index) for rv_index in yo_rv_idxs)

    # Index the rows of y by the values of the common random variables.
    y_index: MapList[TableInstance, Tuple[TableInstance, CircuitNode]] = MapList()
    for y_instance, y_node in y.items():
        common_key = tuple(y_instance[i] for i in co_from_y_map)
        y_only_part = tuple(y_instance[i] for i in yo_from_y_map)
        y_index.append(common_key, (y_only_part, y_node))

    def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
        # For each x row, pair it with every y row that agrees on the common
        # random variables. Multiplications by a constant one are skipped.
        for _x_instance, _x_node in x.items():
            _key = tuple(_x_instance[i] for i in co_from_x_map)
            _x_is_one = _x_node.is_one
            for _yo, _y_node in y_index.get(_key, ()):
                _instance = _x_instance + _yo
                if _x_is_one:
                    # 1 * y == y
                    yield _instance, _y_node
                elif _y_node.is_one:
                    # x * 1 == x
                    yield _instance, _x_node
                else:
                    yield _instance, circuit.op(c_MUL, (_x_node, _y_node))

    return CircuitTable(circuit, x_rv_idxs + yo_rv_idxs, _result_rows())
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def _product_no_common_rvs(x: CircuitTable, y: CircuitTable) -> CircuitTable:
    """
    Return the product of x and y, where x and y have no common random variables.

    This is an optimisation of the more general product algorithm, as no index
    needs to be constructed based on the common random variables. Every row
    of x is paired with every row of y.

    Rows with constant node values of one are optimised out, i.e. no
    multiplication node is created when either operand is a constant one.

    Assumes:
        * There are no common random variables between x and y.
        * x and y are for the same circuit.
    """
    circuit: Circuit = x.circuit

    result_rv_idxs: Tuple[int, ...] = x.rv_idxs + y.rv_idxs

    def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
        for x_instance, x_node in x.items():
            x_is_one = x_node.is_one
            for y_instance, y_node in y.items():
                instance = x_instance + y_instance
                if x_is_one:
                    # 1 * y == y
                    yield instance, y_node
                elif y_node.is_one:
                    # x * 1 == x. The result must be the x node, not the
                    # constant-one y node, otherwise x's value is lost.
                    yield instance, x_node
                else:
                    yield instance, circuit.op(c_MUL, (x_node, y_node))

    return CircuitTable(circuit, result_rv_idxs, _result_rows())
|
|
@@ -142,7 +142,7 @@ cpdef object sum_out(object table: CircuitTable, object rv_idxs: Iterable[int]):
|
|
|
142
142
|
|
|
143
143
|
for group_instance_tuple, to_add in groups.items():
|
|
144
144
|
node = circuit.optimised_add(to_add)
|
|
145
|
-
if not node.is_zero
|
|
145
|
+
if not node.is_zero:
|
|
146
146
|
rows[group_instance_tuple] = node
|
|
147
147
|
|
|
148
148
|
return new_table
|
|
@@ -159,7 +159,7 @@ cpdef object sum_out_all(object table: CircuitTable): # -> CircuitTable:
|
|
|
159
159
|
node = next(iter(table.rows.values()))
|
|
160
160
|
else:
|
|
161
161
|
node: CircuitNode = circuit.optimised_add(table.rows.values())
|
|
162
|
-
if node.is_zero
|
|
162
|
+
if node.is_zero:
|
|
163
163
|
return CircuitTable(circuit, ())
|
|
164
164
|
|
|
165
165
|
return CircuitTable(circuit, (), [((), node)])
|
|
@@ -190,7 +190,7 @@ cpdef object product(x: CircuitTable, y: CircuitTable): # -> CircuitTable:
|
|
|
190
190
|
|
|
191
191
|
# Special case: y == 0 or 1, and has no random variables.
|
|
192
192
|
if len(y.rv_idxs) == 0:
|
|
193
|
-
if len(y) == 1 and y.top().is_one
|
|
193
|
+
if len(y) == 1 and y.top().is_one:
|
|
194
194
|
return x
|
|
195
195
|
elif len(y) == 0:
|
|
196
196
|
return CircuitTable(circuit, x.rv_idxs)
|
|
@@ -259,7 +259,7 @@ cpdef object product(x: CircuitTable, y: CircuitTable): # -> CircuitTable:
|
|
|
259
259
|
co.append(x_instance[i])
|
|
260
260
|
co_tuple = tuple(co)
|
|
261
261
|
|
|
262
|
-
if x_node.is_one
|
|
262
|
+
if x_node.is_one:
|
|
263
263
|
# Multiplying by one.
|
|
264
264
|
# Iterate over matching y rows.
|
|
265
265
|
got = y_index.get(co_tuple)
|
|
@@ -301,7 +301,7 @@ cdef object _product_no_common_rvs(x: CircuitTable, y: CircuitTable): # -> Circ
|
|
|
301
301
|
cdef tuple[int, ...] instance
|
|
302
302
|
|
|
303
303
|
for x_instance, x_node in x.rows.items():
|
|
304
|
-
if x_node.is_one
|
|
304
|
+
if x_node.is_one:
|
|
305
305
|
for y_instance, y_node in y.rows.items():
|
|
306
306
|
instance = x_instance + y_instance
|
|
307
307
|
table.rows[instance] = y_node
|
|
@@ -314,12 +314,12 @@ cdef object _product_no_common_rvs(x: CircuitTable, y: CircuitTable): # -> Circ
|
|
|
314
314
|
|
|
315
315
|
|
|
316
316
|
cdef object _optimised_mul(object circuit: Circuit, object x: CircuitNode, object y: CircuitNode): # -> CircuitNode
|
|
317
|
-
if x.is_zero
|
|
317
|
+
if x.is_zero:
|
|
318
318
|
return x
|
|
319
|
-
if y.is_zero
|
|
319
|
+
if y.is_zero:
|
|
320
320
|
return y
|
|
321
|
-
if x.is_one
|
|
321
|
+
if x.is_one:
|
|
322
322
|
return y
|
|
323
|
-
if y.is_one
|
|
323
|
+
if y.is_one:
|
|
324
324
|
return x
|
|
325
325
|
return circuit.mul(x, y)
|
|
@@ -40,21 +40,38 @@ class CircuitTable:
|
|
|
40
40
|
* all row instances conform to the indexed random variables.
|
|
41
41
|
* all row circuit nodes belong to the given circuit.
|
|
42
42
|
"""
|
|
43
|
-
self.
|
|
44
|
-
self.
|
|
45
|
-
self.
|
|
43
|
+
self._circuit: Circuit = circuit
|
|
44
|
+
self._rv_idxs: Tuple[int, ...] = tuple(rv_idxs)
|
|
45
|
+
self._rows: Dict[TableInstance, CircuitNode] = dict(rows)
|
|
46
|
+
|
|
47
|
+
@property
|
|
48
|
+
def circuit(self) -> Circuit:
|
|
49
|
+
return self._circuit
|
|
50
|
+
|
|
51
|
+
@property
|
|
52
|
+
def rv_idxs(self) -> Tuple[int, ...]:
|
|
53
|
+
return self._rv_idxs
|
|
46
54
|
|
|
47
55
|
def __len__(self) -> int:
|
|
48
|
-
return len(self.
|
|
56
|
+
return len(self._rows)
|
|
49
57
|
|
|
50
58
|
def get(self, key, default=None):
|
|
51
|
-
return self.
|
|
59
|
+
return self._rows.get(key, default)
|
|
60
|
+
|
|
61
|
+
def keys(self) -> Iterable[TableInstance]:
|
|
62
|
+
return self._rows.keys()
|
|
63
|
+
|
|
64
|
+
def values(self) -> Iterable[CircuitNode]:
|
|
65
|
+
return self._rows.values()
|
|
66
|
+
|
|
67
|
+
def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
|
|
68
|
+
return self._rows.items()
|
|
52
69
|
|
|
53
70
|
def __getitem__(self, key):
|
|
54
|
-
return self.
|
|
71
|
+
return self._rows[key]
|
|
55
72
|
|
|
56
73
|
def __setitem__(self, key, value):
|
|
57
|
-
self.
|
|
74
|
+
self._rows[key] = value
|
|
58
75
|
|
|
59
76
|
def top(self) -> CircuitNode:
|
|
60
77
|
"""
|
|
@@ -66,10 +83,10 @@ class CircuitTable:
|
|
|
66
83
|
Returns:
|
|
67
84
|
A single circuit node.
|
|
68
85
|
"""
|
|
69
|
-
if len(self.
|
|
70
|
-
return self.
|
|
71
|
-
elif len(self.
|
|
72
|
-
return next(iter(self.
|
|
86
|
+
if len(self._rows) == 0:
|
|
87
|
+
return self._circuit.zero
|
|
88
|
+
elif len(self._rows) == 1:
|
|
89
|
+
return next(iter(self._rows.values()))
|
|
73
90
|
else:
|
|
74
91
|
raise RuntimeError('cannot get top node from a table with more that 1 row')
|
|
75
92
|
|
|
@@ -119,20 +136,34 @@ def sum_out(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
|
|
|
119
136
|
for remaining_rv_index in remaining_rv_idxs
|
|
120
137
|
)
|
|
121
138
|
|
|
139
|
+
# This is a one-pass version to sum the groups. The
|
|
140
|
+
# two-pass version (below) seems to have better performance.
|
|
141
|
+
#
|
|
142
|
+
# circuit: Circuit = table.circuit
|
|
143
|
+
# result = CircuitTable(circuit, remaining_rv_idxs)
|
|
144
|
+
# result_rows: Dict[TableInstance, CircuitNode] = result._rows
|
|
145
|
+
# for instance, node in table.items():
|
|
146
|
+
# group_instance = tuple(instance[i] for i in index_map)
|
|
147
|
+
# prev_sum = result_rows.get(group_instance)
|
|
148
|
+
# if prev_sum is None:
|
|
149
|
+
# result_rows[group_instance] = node
|
|
150
|
+
# else:
|
|
151
|
+
# result_rows[group_instance] = circuit.add(prev_sum, node)
|
|
152
|
+
# return result
|
|
153
|
+
|
|
122
154
|
groups: MapList[TableInstance, CircuitNode] = MapList()
|
|
123
|
-
for instance, node in table.
|
|
155
|
+
for instance, node in table.items():
|
|
124
156
|
group_instance = tuple(instance[i] for i in index_map)
|
|
125
157
|
groups.append(group_instance, node)
|
|
126
|
-
|
|
127
158
|
circuit: Circuit = table.circuit
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
159
|
+
return CircuitTable(
|
|
160
|
+
circuit,
|
|
161
|
+
remaining_rv_idxs,
|
|
162
|
+
(
|
|
163
|
+
(group, circuit.add(to_add))
|
|
164
|
+
for group, to_add in groups.items()
|
|
165
|
+
)
|
|
166
|
+
)
|
|
136
167
|
|
|
137
168
|
|
|
138
169
|
def sum_out_all(table: CircuitTable) -> CircuitTable:
|
|
@@ -145,10 +176,10 @@ def sum_out_all(table: CircuitTable) -> CircuitTable:
|
|
|
145
176
|
if num_rows == 0:
|
|
146
177
|
return CircuitTable(circuit, ())
|
|
147
178
|
elif num_rows == 1:
|
|
148
|
-
node = next(iter(table.
|
|
179
|
+
node = next(iter(table.values()))
|
|
149
180
|
else:
|
|
150
|
-
node: CircuitNode = circuit.optimised_add(table.
|
|
151
|
-
if node.is_zero
|
|
181
|
+
node: CircuitNode = circuit.optimised_add(table.values())
|
|
182
|
+
if node.is_zero:
|
|
152
183
|
return CircuitTable(circuit, ())
|
|
153
184
|
|
|
154
185
|
return CircuitTable(circuit, (), [((), node)])
|
|
@@ -168,7 +199,7 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
|
|
|
168
199
|
"""
|
|
169
200
|
Return a circuit table that results from the product of the two given tables.
|
|
170
201
|
|
|
171
|
-
If x or y
|
|
202
|
+
If x or y have a single row with value 1, then the other table is returned. Otherwise,
|
|
172
203
|
a new circuit table will be constructed and returned.
|
|
173
204
|
"""
|
|
174
205
|
circuit: Circuit = x.circuit
|
|
@@ -185,7 +216,7 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
|
|
|
185
216
|
|
|
186
217
|
# Special case: y == 0 or 1, and has no random variables.
|
|
187
218
|
if y_rv_idxs == ():
|
|
188
|
-
if len(y) == 1 and y.top().is_one
|
|
219
|
+
if len(y) == 1 and y.top().is_one:
|
|
189
220
|
return x
|
|
190
221
|
elif len(y) == 0:
|
|
191
222
|
return CircuitTable(circuit, x_rv_idxs)
|
|
@@ -193,18 +224,18 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
|
|
|
193
224
|
# Set operations on rv indexes. After these operations:
|
|
194
225
|
# * co_rv_idxs is the set of rv indexes common (co) to x and y,
|
|
195
226
|
# * yo_rv_idxs is the set of rv indexes in y only (yo), and not in x.
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
227
|
+
yo_rv_idxs_set: Set[int] = set(y_rv_idxs)
|
|
228
|
+
co_rv_idxs_set: Set[int] = set(x_rv_idxs)
|
|
229
|
+
co_rv_idxs_set.intersection_update(yo_rv_idxs_set)
|
|
230
|
+
yo_rv_idxs_set.difference_update(co_rv_idxs_set)
|
|
200
231
|
|
|
201
|
-
if len(
|
|
232
|
+
if len(co_rv_idxs_set) == 0:
|
|
202
233
|
# Special case: no common random variables.
|
|
203
234
|
return _product_no_common_rvs(x, y)
|
|
204
235
|
|
|
205
236
|
# Convert random variable index sets to sequences
|
|
206
|
-
yo_rv_idxs: Tuple[int, ...] = tuple(
|
|
207
|
-
co_rv_idxs: Tuple[int, ...] = tuple(
|
|
237
|
+
yo_rv_idxs: Tuple[int, ...] = tuple(yo_rv_idxs_set) # y only random variables
|
|
238
|
+
co_rv_idxs: Tuple[int, ...] = tuple(co_rv_idxs_set) # common random variables
|
|
208
239
|
|
|
209
240
|
# Cache mappings from result Instance to index into source Instance (x or y).
|
|
210
241
|
# This will be used in indexing and product loops to pull our needed values
|
|
@@ -215,7 +246,7 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
|
|
|
215
246
|
|
|
216
247
|
# Index the y rows by common-only key (y is the smaller of the two tables).
|
|
217
248
|
y_index: MapList[TableInstance, Tuple[TableInstance, CircuitNode]] = MapList()
|
|
218
|
-
for y_instance, y_node in y.
|
|
249
|
+
for y_instance, y_node in y.items():
|
|
219
250
|
co = tuple(y_instance[i] for i in co_from_y_map)
|
|
220
251
|
yo = tuple(y_instance[i] for i in yo_from_y_map)
|
|
221
252
|
y_index.append(co, (yo, y_node))
|
|
@@ -223,9 +254,9 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
|
|
|
223
254
|
def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
|
|
224
255
|
# Iterate over x rows, yielding (instance, value).
|
|
225
256
|
# Rows with constant node values of one are optimised out.
|
|
226
|
-
for _x_instance, _x_node in x.
|
|
257
|
+
for _x_instance, _x_node in x.items():
|
|
227
258
|
_co = tuple(_x_instance[i] for i in co_from_x_map)
|
|
228
|
-
if _x_node.is_one
|
|
259
|
+
if _x_node.is_one:
|
|
229
260
|
# Multiplying by one.
|
|
230
261
|
# Iterate over matching y rows.
|
|
231
262
|
for _yo, _y_node in y_index.get(_co, ()):
|
|
@@ -233,7 +264,10 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
|
|
|
233
264
|
else:
|
|
234
265
|
# Iterate over matching y rows.
|
|
235
266
|
for _yo, _y_node in y_index.get(_co, ()):
|
|
236
|
-
|
|
267
|
+
if _y_node.is_one:
|
|
268
|
+
yield _x_instance + _yo, _x_node
|
|
269
|
+
else:
|
|
270
|
+
yield _x_instance + _yo, circuit.mul(_x_node, _y_node)
|
|
237
271
|
|
|
238
272
|
return CircuitTable(circuit, x_rv_idxs + yo_rv_idxs, _result_rows())
|
|
239
273
|
|
|
@@ -256,14 +290,15 @@ def _product_no_common_rvs(x: CircuitTable, y: CircuitTable) -> CircuitTable:
|
|
|
256
290
|
result_rv_idxs: Tuple[int, ...] = x.rv_idxs + y.rv_idxs
|
|
257
291
|
|
|
258
292
|
def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
|
|
259
|
-
for x_instance, x_node in x.
|
|
260
|
-
if x_node.is_one
|
|
261
|
-
for y_instance, y_node in y.
|
|
262
|
-
|
|
263
|
-
yield instance, y_node
|
|
293
|
+
for x_instance, x_node in x.items():
|
|
294
|
+
if x_node.is_one:
|
|
295
|
+
for y_instance, y_node in y.items():
|
|
296
|
+
yield x_instance + y_instance, y_node
|
|
264
297
|
else:
|
|
265
|
-
for y_instance, y_node in y.
|
|
266
|
-
|
|
267
|
-
|
|
298
|
+
for y_instance, y_node in y.items():
|
|
299
|
+
if y_node.is_one:
|
|
300
|
+
yield x_instance + y_instance, y_node
|
|
301
|
+
else:
|
|
302
|
+
yield x_instance + y_instance, circuit.mul(x_node, y_node)
|
|
268
303
|
|
|
269
304
|
return CircuitTable(circuit, result_rv_idxs, _result_rows())
|
|
@@ -180,11 +180,11 @@ def optimal_prefix(clusters: Clusters) -> None:
|
|
|
180
180
|
|
|
181
181
|
class Clusters:
|
|
182
182
|
"""
|
|
183
|
-
|
|
184
|
-
to
|
|
183
|
+
A Clusters object holds the state of a connection graph while
|
|
184
|
+
eliminating variables to construct clusters for a PGM graph.
|
|
185
185
|
|
|
186
|
-
The
|
|
187
|
-
or be completed
|
|
186
|
+
The Clusters object can either be "in-progress" where `len(Clusters.uneliminated) > 0`,
|
|
187
|
+
or be "completed" where `len(Clusters.uneliminated) == 0`.
|
|
188
188
|
|
|
189
189
|
See Adnan Darwiche, 2009, Modeling and Reasoning with Bayesian Networks, p164.
|
|
190
190
|
"""
|
|
@@ -229,6 +229,9 @@ class Clusters:
|
|
|
229
229
|
@property
|
|
230
230
|
def eliminated(self) -> List[int]:
|
|
231
231
|
"""
|
|
232
|
+
Get the list of eliminated random variables (as random variable
|
|
233
|
+
indices, in elimination order).
|
|
234
|
+
|
|
232
235
|
Assumes:
|
|
233
236
|
* The returned list will not be modified by the caller.
|
|
234
237
|
|
|
@@ -240,6 +243,8 @@ class Clusters:
|
|
|
240
243
|
@property
|
|
241
244
|
def uneliminated(self) -> Set[int]:
|
|
242
245
|
"""
|
|
246
|
+
Get the set of uneliminated random variables (as random variable indices).
|
|
247
|
+
|
|
243
248
|
Assumes:
|
|
244
249
|
* The returned set will not be modified by the caller.
|
|
245
250
|
|
|
@@ -285,6 +290,8 @@ class Clusters:
|
|
|
285
290
|
|
|
286
291
|
def max_cluster_size(self) -> int:
|
|
287
292
|
"""
|
|
293
|
+
Calculate the maximum cluster size over all clusters.
|
|
294
|
+
|
|
288
295
|
Returns:
|
|
289
296
|
the maximum `len(cluster)` over all clusters.
|
|
290
297
|
"""
|
|
@@ -292,6 +299,11 @@ class Clusters:
|
|
|
292
299
|
|
|
293
300
|
def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
|
|
294
301
|
"""
|
|
302
|
+
Calculate the maximum cluster weighted size over all clusters.
|
|
303
|
+
|
|
304
|
+
Args:
|
|
305
|
+
rv_log_sizes: is an array of random variable sizes, such that
|
|
306
|
+
for a random variable `rv`, `rv_log_sizes[rv.idx] = log2(len(rv))`.
|
|
295
307
|
Returns:
|
|
296
308
|
the maximum `sum(rv_log_sizes[rv_idx] for rv_idx in cluster)` over all clusters.
|
|
297
309
|
"""
|
|
@@ -348,7 +348,7 @@ def _make_factor_table(
|
|
|
348
348
|
mul_vars[instance[inst_index]]
|
|
349
349
|
for inst_index, mul_vars in zip(inst_to_mul, mul_rvs_vars)
|
|
350
350
|
)
|
|
351
|
-
if not node.is_one
|
|
351
|
+
if not node.is_one:
|
|
352
352
|
to_mul += (node,)
|
|
353
353
|
if len(to_mul) == 0:
|
|
354
354
|
yield instance, circuit.one
|