compiled-knowledge 4.0.0a16__cp312-cp312-win_amd64.whl → 4.0.0a17__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/__init__.py +2 -2
- ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
- ck/circuit/{circuit.pyx → _circuit_cy.pyx} +65 -57
- ck/circuit/{circuit_py.py → _circuit_py.py} +14 -6
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +1603 -2030
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +85 -58
- ck/circuit_compiler/named_circuit_compilers.py +1 -1
- ck/pgm_compiler/factor_elimination.py +23 -13
- ck/pgm_compiler/support/circuit_table/__init__.py +2 -2
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
- ck/pgm_compiler/support/circuit_table/{circuit_table.pyx → _circuit_table_cy.pyx} +9 -9
- ck/pgm_compiler/support/circuit_table/{circuit_table_py.py → _circuit_table_py.py} +5 -5
- ck/pgm_compiler/support/clusters.py +16 -4
- ck/pgm_compiler/support/factor_tables.py +1 -1
- ck/pgm_compiler/support/join_tree.py +67 -10
- ck/pgm_compiler/variable_elimination.py +2 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +10 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a17.dist-info}/METADATA +1 -1
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a17.dist-info}/RECORD +24 -26
- ck/circuit/circuit.c +0 -38861
- ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
- ck/circuit/circuit_node.pyx +0 -138
- ck/pgm_compiler/support/circuit_table/circuit_table.c +0 -16042
- ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a17.dist-info}/WHEEL +0 -0
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a17.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a17.dist-info}/top_level.txt +0 -0
|
Binary file
|
|
@@ -1,17 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from
|
|
4
|
-
from typing import Sequence, Dict, List, Tuple, Set, Optional, Iterator
|
|
3
|
+
from typing import Dict, Tuple, Sequence
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
import ctypes as ct
|
|
8
7
|
|
|
9
8
|
from ck import circuit
|
|
10
|
-
from ck.circuit import
|
|
9
|
+
from ck.circuit import OpNode, VarNode, CircuitNode
|
|
11
10
|
from ck.circuit_compiler.support.circuit_analyser import CircuitAnalysis, analyze_circuit
|
|
12
|
-
from ck.
|
|
13
|
-
from ck.
|
|
14
|
-
from ck.utils.np_extras import DType, NDArrayNumeric, NDArray, DTypeNumeric
|
|
11
|
+
from ck.program.raw_program import RawProgramFunction
|
|
12
|
+
from ck.utils.np_extras import NDArrayNumeric, DTypeNumeric
|
|
15
13
|
|
|
16
14
|
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
|
|
17
15
|
|
|
@@ -63,21 +61,21 @@ def make_function(
|
|
|
63
61
|
|
|
64
62
|
|
|
65
63
|
# VM instructions
|
|
66
|
-
ADD = circuit.ADD
|
|
67
|
-
MUL = circuit.MUL
|
|
68
|
-
|
|
64
|
+
cdef int ADD = circuit.ADD
|
|
65
|
+
cdef int MUL = circuit.MUL
|
|
66
|
+
cdef int COPY = max(ADD, MUL) + 1
|
|
69
67
|
|
|
70
68
|
# VM arrays
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
69
|
+
cdef int VARS = 0
|
|
70
|
+
cdef int TMPS = 1
|
|
71
|
+
cdef int CONSTS = 2
|
|
72
|
+
cdef int RESULT = 3
|
|
75
73
|
|
|
76
74
|
|
|
77
|
-
|
|
78
|
-
analysis: CircuitAnalysis,
|
|
79
|
-
dtype: DTypeNumeric,
|
|
80
|
-
) -> Tuple[Instructions, NDArrayNumeric]:
|
|
75
|
+
cdef tuple[Instructions, cnp.ndarray] _make_instructions_from_analysis(
|
|
76
|
+
object analysis: CircuitAnalysis,
|
|
77
|
+
object dtype: DTypeNumeric,
|
|
78
|
+
): # -> Tuple[Instructions, NDArrayNumeric]:
|
|
81
79
|
if dtype != np.float64:
|
|
82
80
|
raise RuntimeError(f'only DType {np.float64} currently supported')
|
|
83
81
|
|
|
@@ -91,7 +89,7 @@ def _make_instructions_from_analysis(
|
|
|
91
89
|
np_consts[i] = node.value
|
|
92
90
|
|
|
93
91
|
# Where to get input values for each possible node.
|
|
94
|
-
|
|
92
|
+
cdef dict[int, ElementID] node_to_element = {}
|
|
95
93
|
# const nodes
|
|
96
94
|
for node_id, const_idx in node_to_const_idx.items():
|
|
97
95
|
node_to_element[node_id] = ElementID(CONSTS, const_idx)
|
|
@@ -110,21 +108,16 @@ def _make_instructions_from_analysis(
|
|
|
110
108
|
# Build instructions
|
|
111
109
|
instructions: Instructions = Instructions()
|
|
112
110
|
|
|
113
|
-
op_node
|
|
111
|
+
cdef object op_node
|
|
114
112
|
for op_node in analysis.op_nodes:
|
|
115
|
-
|
|
116
|
-
args: list[ElementID] = [
|
|
117
|
-
node_to_element[id(arg)]
|
|
118
|
-
for arg in op_node.args
|
|
119
|
-
]
|
|
120
|
-
instructions.append(op_node.symbol, args, dest)
|
|
113
|
+
instructions.append_op(op_node.symbol, op_node, node_to_element)
|
|
121
114
|
|
|
122
115
|
# Add any copy operations, i.e., result nodes that are not op nodes
|
|
123
116
|
for i, node in enumerate(analysis.result_nodes):
|
|
124
117
|
if not isinstance(node, OpNode):
|
|
125
118
|
dest: ElementID = ElementID(RESULT, i)
|
|
126
|
-
|
|
127
|
-
instructions.
|
|
119
|
+
src: ElementID = node_to_element[id(node)]
|
|
120
|
+
instructions.append_copy(src, dest)
|
|
128
121
|
|
|
129
122
|
return instructions, np_consts
|
|
130
123
|
|
|
@@ -143,32 +136,63 @@ cdef struct Instruction:
|
|
|
143
136
|
|
|
144
137
|
cdef class Instructions:
|
|
145
138
|
cdef Instruction* instructions
|
|
139
|
+
cdef int allocated
|
|
146
140
|
cdef int num_instructions
|
|
147
141
|
|
|
148
|
-
def __init__(self):
|
|
149
|
-
self.instructions = <Instruction*> PyMem_Malloc(0)
|
|
142
|
+
def __init__(self) -> None:
|
|
150
143
|
self.num_instructions = 0
|
|
144
|
+
self.allocated = 64
|
|
145
|
+
self.instructions = <Instruction*> PyMem_Malloc(self.allocated * sizeof(Instruction))
|
|
146
|
+
|
|
147
|
+
cdef void append_copy(
|
|
148
|
+
self,
|
|
149
|
+
ElementID src,
|
|
150
|
+
ElementID dest,
|
|
151
|
+
) except*:
|
|
152
|
+
c_args = <ElementID*> PyMem_Malloc(sizeof(ElementID))
|
|
153
|
+
if not c_args:
|
|
154
|
+
raise MemoryError()
|
|
151
155
|
|
|
152
|
-
|
|
156
|
+
c_args[0] = src
|
|
157
|
+
self._append(COPY, 1, c_args, dest)
|
|
158
|
+
|
|
159
|
+
cdef void append_op(self, int symbol, object op_node: OpNode, dict[int, ElementID] node_to_element) except*:
|
|
160
|
+
args = op_node.args
|
|
153
161
|
cdef int num_args = len(args)
|
|
154
|
-
cdef int i
|
|
155
162
|
|
|
156
|
-
|
|
157
|
-
|
|
163
|
+
# Create the instruction arguments array
|
|
164
|
+
c_args = <ElementID*> PyMem_Malloc(num_args * sizeof(ElementID))
|
|
158
165
|
if not c_args:
|
|
159
166
|
raise MemoryError()
|
|
160
167
|
|
|
161
|
-
|
|
162
|
-
|
|
168
|
+
cdef int i = num_args
|
|
169
|
+
while i > 0:
|
|
170
|
+
i -= 1
|
|
171
|
+
c_args[i] = node_to_element[id(args[i])]
|
|
172
|
+
|
|
173
|
+
dest: ElementID = node_to_element[id(op_node)]
|
|
174
|
+
|
|
175
|
+
self._append(symbol, num_args, c_args, dest)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
cdef void _append(self, int symbol, int num_args, ElementID* c_args, ElementID dest) except *:
|
|
179
|
+
cdef int i
|
|
163
180
|
|
|
164
181
|
cdef int num_instructions = self.num_instructions
|
|
165
|
-
self.instructions = <Instruction*> PyMem_Realloc(
|
|
166
|
-
self.instructions,
|
|
167
|
-
sizeof(Instruction) * (num_instructions + 1)
|
|
168
|
-
)
|
|
169
|
-
if not self.instructions:
|
|
170
|
-
raise MemoryError()
|
|
171
182
|
|
|
183
|
+
# Ensure sufficient instruction memory
|
|
184
|
+
cdef int allocated = self.allocated
|
|
185
|
+
if num_instructions == allocated:
|
|
186
|
+
allocated *= 2
|
|
187
|
+
self.instructions = <Instruction*> PyMem_Realloc(
|
|
188
|
+
self.instructions,
|
|
189
|
+
allocated * sizeof(Instruction),
|
|
190
|
+
)
|
|
191
|
+
if not self.instructions:
|
|
192
|
+
raise MemoryError()
|
|
193
|
+
self.allocated = allocated
|
|
194
|
+
|
|
195
|
+
# Add the instruction
|
|
172
196
|
self.instructions[num_instructions] = Instruction(
|
|
173
197
|
symbol,
|
|
174
198
|
num_args,
|
|
@@ -177,7 +201,7 @@ cdef class Instructions:
|
|
|
177
201
|
)
|
|
178
202
|
self.num_instructions = num_instructions + 1
|
|
179
203
|
|
|
180
|
-
def __dealloc__(self):
|
|
204
|
+
def __dealloc__(self) -> None:
|
|
181
205
|
cdef Instruction* instructions = self.instructions
|
|
182
206
|
if instructions:
|
|
183
207
|
for i in range(self.num_instructions):
|
|
@@ -194,15 +218,15 @@ cdef void cvm_float64(
|
|
|
194
218
|
double* consts,
|
|
195
219
|
double* result,
|
|
196
220
|
Instructions instructions,
|
|
197
|
-
)
|
|
198
|
-
# Core virtual machine.
|
|
221
|
+
) except *:
|
|
222
|
+
# Core virtual machine (for dtype float64).
|
|
199
223
|
|
|
200
|
-
cdef int i,
|
|
224
|
+
cdef int i, symbol
|
|
201
225
|
cdef double accumulator
|
|
202
226
|
cdef ElementID* args
|
|
203
|
-
cdef ElementID
|
|
227
|
+
cdef ElementID elem
|
|
204
228
|
|
|
205
|
-
#
|
|
229
|
+
# Index the four arrays by constants VARS, TMPS, CONSTS, and RESULT
|
|
206
230
|
cdef (double*) arrays[4]
|
|
207
231
|
arrays[VARS] = vars_in
|
|
208
232
|
arrays[TMPS] = tmps
|
|
@@ -210,29 +234,32 @@ cdef void cvm_float64(
|
|
|
210
234
|
arrays[RESULT] = result
|
|
211
235
|
|
|
212
236
|
cdef Instruction* instruction_ptr = instructions.instructions
|
|
213
|
-
|
|
237
|
+
cdef int num_instructions = instructions.num_instructions
|
|
214
238
|
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
239
|
+
while num_instructions > 0:
|
|
240
|
+
num_instructions -= 1
|
|
241
|
+
|
|
242
|
+
symbol = instruction_ptr.symbol
|
|
243
|
+
args = instruction_ptr.args
|
|
218
244
|
|
|
219
245
|
elem = args[0]
|
|
220
246
|
accumulator = arrays[elem.array][elem.index]
|
|
221
247
|
|
|
222
248
|
if symbol == ADD:
|
|
223
|
-
|
|
249
|
+
i = instruction_ptr.num_args
|
|
250
|
+
while i > 1:
|
|
251
|
+
i -= 1
|
|
224
252
|
elem = args[i]
|
|
225
253
|
accumulator += arrays[elem.array][elem.index]
|
|
226
254
|
elif symbol == MUL:
|
|
227
|
-
|
|
255
|
+
i = instruction_ptr.num_args
|
|
256
|
+
while i > 1:
|
|
257
|
+
i -= 1
|
|
228
258
|
elem = args[i]
|
|
229
259
|
accumulator *= arrays[elem.array][elem.index]
|
|
230
|
-
|
|
231
|
-
pass
|
|
232
|
-
else:
|
|
233
|
-
raise RuntimeError('symbol not understood: ' + str(symbol))
|
|
260
|
+
# else symbol == COPY, nothing to do
|
|
234
261
|
|
|
235
|
-
elem = instruction_ptr
|
|
262
|
+
elem = instruction_ptr.dest
|
|
236
263
|
arrays[elem.array][elem.index] = accumulator
|
|
237
264
|
|
|
238
265
|
# Advance the instruction pointer
|
|
@@ -149,7 +149,7 @@ def join_tree_to_circuit(
|
|
|
149
149
|
limit_product_tree_search,
|
|
150
150
|
)
|
|
151
151
|
top: CircuitNode = top_table.top()
|
|
152
|
-
|
|
152
|
+
top.circuit.remove_unreachable_op_nodes(top)
|
|
153
153
|
|
|
154
154
|
return PGMCircuit(
|
|
155
155
|
rvs=tuple(pgm.rvs),
|
|
@@ -169,27 +169,37 @@ def _circuit_tables_from_join_tree(
|
|
|
169
169
|
) -> CircuitTable:
|
|
170
170
|
"""
|
|
171
171
|
This is a basic algorithm for constructing a circuit table from a join tree.
|
|
172
|
+
Algorithm synopsis:
|
|
173
|
+
1) Get a CircuitTable for each factor allocated to this join tree node, and
|
|
174
|
+
for each child of the join tree node (recursive call to _circuit_tables_from_join_tree).
|
|
175
|
+
2) Form a binary tree of the collected circuit tables.
|
|
176
|
+
3) Perform table products and sum-outs for each node in the binary tree, which should
|
|
177
|
+
leave a single circuit table with a single row.
|
|
172
178
|
"""
|
|
173
|
-
#
|
|
174
|
-
factors: List[CircuitTable] =
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
179
|
+
# Get all the factors to combine.
|
|
180
|
+
factors: List[CircuitTable] = list(
|
|
181
|
+
chain(
|
|
182
|
+
(
|
|
183
|
+
# The PGM factors allocated to this join tree node
|
|
184
|
+
factor_tables.get_table(factor)
|
|
185
|
+
for factor in join_tree.factors
|
|
186
|
+
),
|
|
187
|
+
(
|
|
188
|
+
# The children of this join tree node
|
|
189
|
+
_circuit_tables_from_join_tree(factor_tables, child, limit_product_tree_search)
|
|
190
|
+
for child in join_tree.children
|
|
191
|
+
),
|
|
192
|
+
)
|
|
183
193
|
)
|
|
184
194
|
|
|
185
195
|
# The usual join tree approach just forms the product of all the tables in `factors`.
|
|
186
196
|
# The tree width is not affected by the order of products, however some orders
|
|
187
197
|
# lead to smaller numbers of arithmetic operations.
|
|
188
198
|
#
|
|
189
|
-
# If `
|
|
199
|
+
# If `limit_product_tree_search > 1`, then heuristics are used
|
|
190
200
|
# to reduce the number of arithmetic operations.
|
|
191
201
|
|
|
192
|
-
# Deal with the special case:
|
|
202
|
+
# Deal with the special case: zero factors
|
|
193
203
|
if len(factors) == 0:
|
|
194
204
|
circuit = factor_tables.circuit
|
|
195
205
|
if len(join_tree.separator) == 0:
|
|
@@ -142,7 +142,7 @@ cpdef object sum_out(object table: CircuitTable, object rv_idxs: Iterable[int]):
|
|
|
142
142
|
|
|
143
143
|
for group_instance_tuple, to_add in groups.items():
|
|
144
144
|
node = circuit.optimised_add(to_add)
|
|
145
|
-
if not node.is_zero
|
|
145
|
+
if not node.is_zero:
|
|
146
146
|
rows[group_instance_tuple] = node
|
|
147
147
|
|
|
148
148
|
return new_table
|
|
@@ -159,7 +159,7 @@ cpdef object sum_out_all(object table: CircuitTable): # -> CircuitTable:
|
|
|
159
159
|
node = next(iter(table.rows.values()))
|
|
160
160
|
else:
|
|
161
161
|
node: CircuitNode = circuit.optimised_add(table.rows.values())
|
|
162
|
-
if node.is_zero
|
|
162
|
+
if node.is_zero:
|
|
163
163
|
return CircuitTable(circuit, ())
|
|
164
164
|
|
|
165
165
|
return CircuitTable(circuit, (), [((), node)])
|
|
@@ -190,7 +190,7 @@ cpdef object product(x: CircuitTable, y: CircuitTable): # -> CircuitTable:
|
|
|
190
190
|
|
|
191
191
|
# Special case: y == 0 or 1, and has no random variables.
|
|
192
192
|
if len(y.rv_idxs) == 0:
|
|
193
|
-
if len(y) == 1 and y.top().is_one
|
|
193
|
+
if len(y) == 1 and y.top().is_one:
|
|
194
194
|
return x
|
|
195
195
|
elif len(y) == 0:
|
|
196
196
|
return CircuitTable(circuit, x.rv_idxs)
|
|
@@ -259,7 +259,7 @@ cpdef object product(x: CircuitTable, y: CircuitTable): # -> CircuitTable:
|
|
|
259
259
|
co.append(x_instance[i])
|
|
260
260
|
co_tuple = tuple(co)
|
|
261
261
|
|
|
262
|
-
if x_node.is_one
|
|
262
|
+
if x_node.is_one:
|
|
263
263
|
# Multiplying by one.
|
|
264
264
|
# Iterate over matching y rows.
|
|
265
265
|
got = y_index.get(co_tuple)
|
|
@@ -301,7 +301,7 @@ cdef object _product_no_common_rvs(x: CircuitTable, y: CircuitTable): # -> Circ
|
|
|
301
301
|
cdef tuple[int, ...] instance
|
|
302
302
|
|
|
303
303
|
for x_instance, x_node in x.rows.items():
|
|
304
|
-
if x_node.is_one
|
|
304
|
+
if x_node.is_one:
|
|
305
305
|
for y_instance, y_node in y.rows.items():
|
|
306
306
|
instance = x_instance + y_instance
|
|
307
307
|
table.rows[instance] = y_node
|
|
@@ -314,12 +314,12 @@ cdef object _product_no_common_rvs(x: CircuitTable, y: CircuitTable): # -> Circ
|
|
|
314
314
|
|
|
315
315
|
|
|
316
316
|
cdef object _optimised_mul(object circuit: Circuit, object x: CircuitNode, object y: CircuitNode): # -> CircuitNode
|
|
317
|
-
if x.is_zero
|
|
317
|
+
if x.is_zero:
|
|
318
318
|
return x
|
|
319
|
-
if y.is_zero
|
|
319
|
+
if y.is_zero:
|
|
320
320
|
return y
|
|
321
|
-
if x.is_one
|
|
321
|
+
if x.is_one:
|
|
322
322
|
return y
|
|
323
|
-
if y.is_one
|
|
323
|
+
if y.is_one:
|
|
324
324
|
return x
|
|
325
325
|
return circuit.mul(x, y)
|
|
@@ -129,7 +129,7 @@ def sum_out(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
|
|
|
129
129
|
def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
|
|
130
130
|
for group, to_add in groups.items():
|
|
131
131
|
_node: CircuitNode = circuit.optimised_add(to_add)
|
|
132
|
-
if not _node.is_zero
|
|
132
|
+
if not _node.is_zero:
|
|
133
133
|
yield group, _node
|
|
134
134
|
|
|
135
135
|
return CircuitTable(circuit, remaining_rv_idxs, _result_rows())
|
|
@@ -148,7 +148,7 @@ def sum_out_all(table: CircuitTable) -> CircuitTable:
|
|
|
148
148
|
node = next(iter(table.rows.values()))
|
|
149
149
|
else:
|
|
150
150
|
node: CircuitNode = circuit.optimised_add(table.rows.values())
|
|
151
|
-
if node.is_zero
|
|
151
|
+
if node.is_zero:
|
|
152
152
|
return CircuitTable(circuit, ())
|
|
153
153
|
|
|
154
154
|
return CircuitTable(circuit, (), [((), node)])
|
|
@@ -185,7 +185,7 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
|
|
|
185
185
|
|
|
186
186
|
# Special case: y == 0 or 1, and has no random variables.
|
|
187
187
|
if y_rv_idxs == ():
|
|
188
|
-
if len(y) == 1 and y.top().is_one
|
|
188
|
+
if len(y) == 1 and y.top().is_one:
|
|
189
189
|
return x
|
|
190
190
|
elif len(y) == 0:
|
|
191
191
|
return CircuitTable(circuit, x_rv_idxs)
|
|
@@ -225,7 +225,7 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
|
|
|
225
225
|
# Rows with constant node values of one are optimised out.
|
|
226
226
|
for _x_instance, _x_node in x.rows.items():
|
|
227
227
|
_co = tuple(_x_instance[i] for i in co_from_x_map)
|
|
228
|
-
if _x_node.is_one
|
|
228
|
+
if _x_node.is_one:
|
|
229
229
|
# Multiplying by one.
|
|
230
230
|
# Iterate over matching y rows.
|
|
231
231
|
for _yo, _y_node in y_index.get(_co, ()):
|
|
@@ -257,7 +257,7 @@ def _product_no_common_rvs(x: CircuitTable, y: CircuitTable) -> CircuitTable:
|
|
|
257
257
|
|
|
258
258
|
def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
|
|
259
259
|
for x_instance, x_node in x.rows.items():
|
|
260
|
-
if x_node.is_one
|
|
260
|
+
if x_node.is_one:
|
|
261
261
|
for y_instance, y_node in y.rows.items():
|
|
262
262
|
instance = x_instance + y_instance
|
|
263
263
|
yield instance, y_node
|
|
@@ -180,11 +180,11 @@ def optimal_prefix(clusters: Clusters) -> None:
|
|
|
180
180
|
|
|
181
181
|
class Clusters:
|
|
182
182
|
"""
|
|
183
|
-
|
|
184
|
-
to
|
|
183
|
+
A Clusters object holds the state of a connection graph while
|
|
184
|
+
eliminating variables to construct clusters for a PGM graph.
|
|
185
185
|
|
|
186
|
-
The
|
|
187
|
-
or be completed
|
|
186
|
+
The Clusters object can either be "in-progress" where `len(Clusters.uneliminated) > 0`,
|
|
187
|
+
or be "completed" where `len(Clusters.uneliminated) == 0`.
|
|
188
188
|
|
|
189
189
|
See Adnan Darwiche, 2009, Modeling and Reasoning with Bayesian Networks, p164.
|
|
190
190
|
"""
|
|
@@ -229,6 +229,9 @@ class Clusters:
|
|
|
229
229
|
@property
|
|
230
230
|
def eliminated(self) -> List[int]:
|
|
231
231
|
"""
|
|
232
|
+
Get the list of eliminated random variables (as random variable
|
|
233
|
+
indices, in elimination order).
|
|
234
|
+
|
|
232
235
|
Assumes:
|
|
233
236
|
* The returned list will not be modified by the caller.
|
|
234
237
|
|
|
@@ -240,6 +243,8 @@ class Clusters:
|
|
|
240
243
|
@property
|
|
241
244
|
def uneliminated(self) -> Set[int]:
|
|
242
245
|
"""
|
|
246
|
+
Get the set of uneliminated random variables (as random variable indices).
|
|
247
|
+
|
|
243
248
|
Assumes:
|
|
244
249
|
* The returned set will not be modified by the caller.
|
|
245
250
|
|
|
@@ -285,6 +290,8 @@ class Clusters:
|
|
|
285
290
|
|
|
286
291
|
def max_cluster_size(self) -> int:
|
|
287
292
|
"""
|
|
293
|
+
Calculate the maximum cluster size over all clusters.
|
|
294
|
+
|
|
288
295
|
Returns:
|
|
289
296
|
the maximum `len(cluster)` over all clusters.
|
|
290
297
|
"""
|
|
@@ -292,6 +299,11 @@ class Clusters:
|
|
|
292
299
|
|
|
293
300
|
def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
|
|
294
301
|
"""
|
|
302
|
+
Calculate the maximum cluster weighted size over all clusters.
|
|
303
|
+
|
|
304
|
+
Args:
|
|
305
|
+
rv_log_sizes: is an array of random variable log2 sizes, such that
|
|
306
|
+
for a random variable `rv`, `rv_log_sizes[rv.idx] = log2(len(rv))`.
|
|
295
307
|
Returns:
|
|
296
308
|
the maximum `sum(rv_log_sizes[rv_idx] for rv_idx in cluster)` over all clusters.
|
|
297
309
|
"""
|
|
@@ -348,7 +348,7 @@ def _make_factor_table(
|
|
|
348
348
|
mul_vars[instance[inst_index]]
|
|
349
349
|
for inst_index, mul_vars in zip(inst_to_mul, mul_rvs_vars)
|
|
350
350
|
)
|
|
351
|
-
if not node.is_one
|
|
351
|
+
if not node.is_one:
|
|
352
352
|
to_mul += (node,)
|
|
353
353
|
if len(to_mul) == 0:
|
|
354
354
|
yield instance, circuit.one
|
|
@@ -15,6 +15,11 @@ from ck.utils.np_extras import NDArrayFloat64
|
|
|
15
15
|
|
|
16
16
|
@dataclass
|
|
17
17
|
class JoinTree:
|
|
18
|
+
"""
|
|
19
|
+
This is a recursive data structure representing a join-tree.
|
|
20
|
+
Each node in the join-tree is represented by a JoinTree object.
|
|
21
|
+
"""
|
|
22
|
+
|
|
18
23
|
# The PGM that this join tree is for.
|
|
19
24
|
pgm: PGM
|
|
20
25
|
|
|
@@ -40,6 +45,12 @@ class JoinTree:
|
|
|
40
45
|
|
|
41
46
|
def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
|
|
42
47
|
"""
|
|
48
|
+
Calculate the maximum cluster weighted size for this cluster and its children.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
rv_log_sizes: is an array of random variable log2 sizes, such that
|
|
52
|
+
for a random variable `rv`, `rv_log_sizes[rv.idx] = log2(len(rv))`.
|
|
53
|
+
|
|
43
54
|
Returns:
|
|
44
55
|
the maximum `log2` over self and all children, recursively.
|
|
45
56
|
"""
|
|
@@ -82,8 +93,8 @@ JoinTreeAlgorithm = Callable[[PGM], JoinTree]
|
|
|
82
93
|
|
|
83
94
|
def _join_tree_algorithm(pgm_to_clusters: ClusterAlgorithm) -> JoinTreeAlgorithm:
|
|
84
95
|
"""
|
|
85
|
-
Helper function for creating a standard JoinTreeAlgorithm
|
|
86
|
-
a ClusterAlgorithm.
|
|
96
|
+
Helper function for creating a standard JoinTreeAlgorithm
|
|
97
|
+
from a ClusterAlgorithm.
|
|
87
98
|
|
|
88
99
|
Args:
|
|
89
100
|
pgm_to_clusters: The clusters method to use.
|
|
@@ -112,14 +123,17 @@ MIN_TRADITIONAL_WEIGHTED_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_trad
|
|
|
112
123
|
|
|
113
124
|
def clusters_to_join_tree(clusters: Clusters) -> JoinTree:
|
|
114
125
|
"""
|
|
115
|
-
Construct a join tree
|
|
126
|
+
Construct a join tree from the given random variable clusters.
|
|
116
127
|
|
|
117
128
|
A join tree is formed by finding a minimum spanning tree over the clusters
|
|
118
|
-
where the cost between a pair of
|
|
119
|
-
|
|
129
|
+
where the cost between a pair of clusters is the number of random variables
|
|
130
|
+
in common (using separator state space size to break ties).
|
|
120
131
|
|
|
121
132
|
Args:
|
|
122
|
-
clusters: the clusters that resulted from graph clusters of
|
|
133
|
+
clusters: the clusters that resulted from graph clusters of a PGM.
|
|
134
|
+
|
|
135
|
+
Returns:
|
|
136
|
+
a JoinTree.
|
|
123
137
|
"""
|
|
124
138
|
pgm: PGM = clusters.pgm
|
|
125
139
|
cluster_sets: List[Set[int]] = clusters.clusters
|
|
@@ -170,6 +184,19 @@ def _make_spanning_tree_small_root(cost: NDArrayFloat64, clusters: List[Set[int]
|
|
|
170
184
|
"""
|
|
171
185
|
Construct a minimum spanning tree over the clusters, where the root is the cluster with
|
|
172
186
|
the smallest number of random variables.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
cost: is an N x N matrix of costs between N clusters.
|
|
190
|
+
clusters: is a list of N clusters, each cluster is a set of random variable indices.
|
|
191
|
+
|
|
192
|
+
Returns:
|
|
193
|
+
(spanning_tree, root_index)
|
|
194
|
+
|
|
195
|
+
spanning_tree: is a spanning tree represented as a list of nodes, the list is coindexed with
|
|
196
|
+
the given cost matrix, each node is a list of children, each child being
|
|
197
|
+
represented as an index into the list of nodes.
|
|
198
|
+
|
|
199
|
+
root_index: is the index of the chosen root of the spanning tree.
|
|
173
200
|
"""
|
|
174
201
|
root_custer_index: int = 0
|
|
175
202
|
root_size: int = len(clusters[root_custer_index])
|
|
@@ -185,10 +212,22 @@ def _make_spanning_tree_small_root(cost: NDArrayFloat64, clusters: List[Set[int]
|
|
|
185
212
|
def _make_spanning_tree_arbitrary_root(cost: NDArrayFloat64) -> Tuple[List[List[int]], int]:
|
|
186
213
|
"""
|
|
187
214
|
Construct a minimum spanning tree over the clusters, starting at an arbitrary root.
|
|
215
|
+
|
|
216
|
+
Args:
|
|
217
|
+
cost: is an N x N matrix of costs between N clusters.
|
|
218
|
+
|
|
219
|
+
Returns:
|
|
220
|
+
(spanning_tree, root_index)
|
|
221
|
+
|
|
222
|
+
spanning_tree: is a spanning tree represented as a list of nodes, the list is coindexed with
|
|
223
|
+
the given cost matrix, each node is a list of children, each child being
|
|
224
|
+
represented as an index into the list of nodes.
|
|
225
|
+
|
|
226
|
+
root_index: is the index of the chosen root of the spanning tree.
|
|
188
227
|
"""
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
return
|
|
228
|
+
root_index: int = 0
|
|
229
|
+
spanning_tree: List[List[int]] = _make_spanning_tree_at_root(cost, root_index)
|
|
230
|
+
return spanning_tree, root_index
|
|
192
231
|
|
|
193
232
|
|
|
194
233
|
def _make_spanning_tree_at_root(
|
|
@@ -202,6 +241,12 @@ def _make_spanning_tree_at_root(
|
|
|
202
241
|
cost: an n x n matrix where n is the number of clusters and cost[i, j]
|
|
203
242
|
gives the cost between clusters i and j.
|
|
204
243
|
root_custer_index: a nominated root cluster to be the root of the tree.
|
|
244
|
+
|
|
245
|
+
Returns:
|
|
246
|
+
a spanning tree represented as a list of nodes, the list is coindexed with
|
|
247
|
+
the given cost matrix, each node is a list of children, each child being
|
|
248
|
+
represented as an index into the list of nodes. The root node is the
|
|
249
|
+
index `root_custer_index` as passed to this function.
|
|
205
250
|
"""
|
|
206
251
|
number_of_clusters: int = cost.shape[0]
|
|
207
252
|
|
|
@@ -257,7 +302,19 @@ def _form_join_tree_r(
|
|
|
257
302
|
cluster_factors: List[List[Factor]],
|
|
258
303
|
) -> JoinTree:
|
|
259
304
|
"""
|
|
260
|
-
Recursively build the
|
|
305
|
+
Recursively build a JoinTree from the spanning tree `children`.
|
|
306
|
+
This function merely pulls the corresponding components from the
|
|
307
|
+
arguments to make a JoinTree object, doing this recursively
|
|
308
|
+
for the children.
|
|
309
|
+
|
|
310
|
+
Args:
|
|
311
|
+
pgm: the source PGM for the join tree.
|
|
312
|
+
cluster_index: index for the node we are processing (current root). This
|
|
313
|
+
indexes into `children`, `clusters`, and `cluster_factors`.
|
|
314
|
+
parent_cluster: set of random variable indices in the parent cluster.
|
|
315
|
+
children: list of spanning tree nodes, as per `_make_spanning_tree_at_root` result.
|
|
316
|
+
clusters: list of clusters, each cluster is a set of random variable indices.
|
|
317
|
+
cluster_factors: assignment of factors to clusters.
|
|
261
318
|
"""
|
|
262
319
|
cluster: Set[int] = clusters[cluster_index]
|
|
263
320
|
factors: List[Factor] = cluster_factors[cluster_index]
|
|
@@ -51,6 +51,8 @@ def compile_pgm(
|
|
|
51
51
|
|
|
52
52
|
elimination_order: Sequence[int] = algorithm(pgm).eliminated
|
|
53
53
|
|
|
54
|
+
# Eliminate rvs from the factor tables according to the
|
|
55
|
+
# elimination order.
|
|
54
56
|
cur_tables: List[CircuitTable] = list(factor_tables.tables)
|
|
55
57
|
for rv_idx in elimination_order:
|
|
56
58
|
next_tables: List[CircuitTable] = []
|
|
@@ -9,6 +9,16 @@ from ck.pgm_compiler.support.join_tree import JoinTree, clusters_to_join_tree
|
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
def main() -> None:
|
|
12
|
+
"""
|
|
13
|
+
This demo shows the full compilation chain for factor elimination.
|
|
14
|
+
|
|
15
|
+
Process:
|
|
16
|
+
Rain example -> PGM
|
|
17
|
+
min_degree -> Clusters
|
|
18
|
+
clusters_to_join_tree -> JoinTree
|
|
19
|
+
join_tree_to_circuit -> PGMCircuit
|
|
20
|
+
default circuit compiler -> WMCProgram
|
|
21
|
+
"""
|
|
12
22
|
pgm: PGM = example.Rain()
|
|
13
23
|
|
|
14
24
|
print(f'PGM {pgm.name!r}')
|