compiled-knowledge 4.0.0a16__cp312-cp312-win_amd64.whl → 4.0.0a17__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (29):
  1. ck/circuit/__init__.py +2 -2
  2. ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
  3. ck/circuit/{circuit.pyx → _circuit_cy.pyx} +65 -57
  4. ck/circuit/{circuit_py.py → _circuit_py.py} +14 -6
  5. ck/circuit_compiler/cython_vm_compiler/_compiler.c +1603 -2030
  6. ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
  7. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +85 -58
  8. ck/circuit_compiler/named_circuit_compilers.py +1 -1
  9. ck/pgm_compiler/factor_elimination.py +23 -13
  10. ck/pgm_compiler/support/circuit_table/__init__.py +2 -2
  11. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
  12. ck/pgm_compiler/support/circuit_table/{circuit_table.pyx → _circuit_table_cy.pyx} +9 -9
  13. ck/pgm_compiler/support/circuit_table/{circuit_table_py.py → _circuit_table_py.py} +5 -5
  14. ck/pgm_compiler/support/clusters.py +16 -4
  15. ck/pgm_compiler/support/factor_tables.py +1 -1
  16. ck/pgm_compiler/support/join_tree.py +67 -10
  17. ck/pgm_compiler/variable_elimination.py +2 -0
  18. ck_demos/pgm_compiler/demo_compiler_dump.py +10 -0
  19. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  20. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a17.dist-info}/METADATA +1 -1
  21. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a17.dist-info}/RECORD +24 -26
  22. ck/circuit/circuit.c +0 -38861
  23. ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
  24. ck/circuit/circuit_node.pyx +0 -138
  25. ck/pgm_compiler/support/circuit_table/circuit_table.c +0 -16042
  26. ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
  27. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a17.dist-info}/WHEEL +0 -0
  28. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a17.dist-info}/licenses/LICENSE.txt +0 -0
  29. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a17.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
- from pickletools import long1
4
- from typing import Sequence, Dict, List, Tuple, Set, Optional, Iterator
3
+ from typing import Dict, Tuple, Sequence
5
4
 
6
5
  import numpy as np
7
6
  import ctypes as ct
8
7
 
9
8
  from ck import circuit
10
- from ck.circuit import CircuitNode, ConstNode, VarNode, OpNode, ADD, Circuit
9
+ from ck.circuit import OpNode, VarNode, CircuitNode
11
10
  from ck.circuit_compiler.support.circuit_analyser import CircuitAnalysis, analyze_circuit
12
- from ck.circuit_compiler.support.input_vars import infer_input_vars, InputVars
13
- from ck.program.raw_program import RawProgram, RawProgramFunction
14
- from ck.utils.np_extras import DType, NDArrayNumeric, NDArray, DTypeNumeric
11
+ from ck.program.raw_program import RawProgramFunction
12
+ from ck.utils.np_extras import NDArrayNumeric, DTypeNumeric
15
13
 
16
14
  from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
17
15
 
@@ -63,21 +61,21 @@ def make_function(
63
61
 
64
62
 
65
63
  # VM instructions
66
- ADD = circuit.ADD
67
- MUL = circuit.MUL
68
- COPY: int = max(ADD, MUL) + 1
64
+ cdef int ADD = circuit.ADD
65
+ cdef int MUL = circuit.MUL
66
+ cdef int COPY = max(ADD, MUL) + 1
69
67
 
70
68
  # VM arrays
71
- VARS: int = 0
72
- TMPS: int = 1
73
- CONSTS: int = 2
74
- RESULT: int = 3
69
+ cdef int VARS = 0
70
+ cdef int TMPS = 1
71
+ cdef int CONSTS = 2
72
+ cdef int RESULT = 3
75
73
 
76
74
 
77
- def _make_instructions_from_analysis(
78
- analysis: CircuitAnalysis,
79
- dtype: DTypeNumeric,
80
- ) -> Tuple[Instructions, NDArrayNumeric]:
75
+ cdef tuple[Instructions, cnp.ndarray] _make_instructions_from_analysis(
76
+ object analysis: CircuitAnalysis,
77
+ object dtype: DTypeNumeric,
78
+ ): # -> Tuple[Instructions, NDArrayNumeric]:
81
79
  if dtype != np.float64:
82
80
  raise RuntimeError(f'only DType {np.float64} currently supported')
83
81
 
@@ -91,7 +89,7 @@ def _make_instructions_from_analysis(
91
89
  np_consts[i] = node.value
92
90
 
93
91
  # Where to get input values for each possible node.
94
- node_to_element: Dict[int, ElementID] = {}
92
+ cdef dict[int, ElementID] node_to_element = {}
95
93
  # const nodes
96
94
  for node_id, const_idx in node_to_const_idx.items():
97
95
  node_to_element[node_id] = ElementID(CONSTS, const_idx)
@@ -110,21 +108,16 @@ def _make_instructions_from_analysis(
110
108
  # Build instructions
111
109
  instructions: Instructions = Instructions()
112
110
 
113
- op_node: OpNode
111
+ cdef object op_node
114
112
  for op_node in analysis.op_nodes:
115
- dest: ElementID = node_to_element[id(op_node)]
116
- args: list[ElementID] = [
117
- node_to_element[id(arg)]
118
- for arg in op_node.args
119
- ]
120
- instructions.append(op_node.symbol, args, dest)
113
+ instructions.append_op(op_node.symbol, op_node, node_to_element)
121
114
 
122
115
  # Add any copy operations, i.e., result nodes that are not op nodes
123
116
  for i, node in enumerate(analysis.result_nodes):
124
117
  if not isinstance(node, OpNode):
125
118
  dest: ElementID = ElementID(RESULT, i)
126
- args: list[ElementID] = [node_to_element[id(node)]]
127
- instructions.append(COPY, args, dest)
119
+ src: ElementID = node_to_element[id(node)]
120
+ instructions.append_copy(src, dest)
128
121
 
129
122
  return instructions, np_consts
130
123
 
@@ -143,32 +136,63 @@ cdef struct Instruction:
143
136
 
144
137
  cdef class Instructions:
145
138
  cdef Instruction* instructions
139
+ cdef int allocated
146
140
  cdef int num_instructions
147
141
 
148
- def __init__(self):
149
- self.instructions = <Instruction*> PyMem_Malloc(0)
142
+ def __init__(self) -> None:
150
143
  self.num_instructions = 0
144
+ self.allocated = 64
145
+ self.instructions = <Instruction*> PyMem_Malloc(self.allocated * sizeof(Instruction))
146
+
147
+ cdef void append_copy(
148
+ self,
149
+ ElementID src,
150
+ ElementID dest,
151
+ ) except*:
152
+ c_args = <ElementID*> PyMem_Malloc(sizeof(ElementID))
153
+ if not c_args:
154
+ raise MemoryError()
151
155
 
152
- def append(self, int symbol, list[ElementID] args, ElementID dest) -> None:
156
+ c_args[0] = src
157
+ self._append(COPY, 1, c_args, dest)
158
+
159
+ cdef void append_op(self, int symbol, object op_node: OpNode, dict[int, ElementID] node_to_element) except*:
160
+ args = op_node.args
153
161
  cdef int num_args = len(args)
154
- cdef int i
155
162
 
156
- c_args = <ElementID*> PyMem_Malloc(
157
- num_args * sizeof(ElementID))
163
+ # Create the instruction arguments array
164
+ c_args = <ElementID*> PyMem_Malloc(num_args * sizeof(ElementID))
158
165
  if not c_args:
159
166
  raise MemoryError()
160
167
 
161
- for i in range(num_args):
162
- c_args[i] = args[i]
168
+ cdef int i = num_args
169
+ while i > 0:
170
+ i -= 1
171
+ c_args[i] = node_to_element[id(args[i])]
172
+
173
+ dest: ElementID = node_to_element[id(op_node)]
174
+
175
+ self._append(symbol, num_args, c_args, dest)
176
+
177
+
178
+ cdef void _append(self, int symbol, int num_args, ElementID* c_args, ElementID dest) except *:
179
+ cdef int i
163
180
 
164
181
  cdef int num_instructions = self.num_instructions
165
- self.instructions = <Instruction*> PyMem_Realloc(
166
- self.instructions,
167
- sizeof(Instruction) * (num_instructions + 1)
168
- )
169
- if not self.instructions:
170
- raise MemoryError()
171
182
 
183
+ # Ensure sufficient instruction memory
184
+ cdef int allocated = self.allocated
185
+ if num_instructions == allocated:
186
+ allocated *= 2
187
+ self.instructions = <Instruction*> PyMem_Realloc(
188
+ self.instructions,
189
+ allocated * sizeof(Instruction),
190
+ )
191
+ if not self.instructions:
192
+ raise MemoryError()
193
+ self.allocated = allocated
194
+
195
+ # Add the instruction
172
196
  self.instructions[num_instructions] = Instruction(
173
197
  symbol,
174
198
  num_args,
@@ -177,7 +201,7 @@ cdef class Instructions:
177
201
  )
178
202
  self.num_instructions = num_instructions + 1
179
203
 
180
- def __dealloc__(self):
204
+ def __dealloc__(self) -> None:
181
205
  cdef Instruction* instructions = self.instructions
182
206
  if instructions:
183
207
  for i in range(self.num_instructions):
@@ -194,15 +218,15 @@ cdef void cvm_float64(
194
218
  double* consts,
195
219
  double* result,
196
220
  Instructions instructions,
197
- ):
198
- # Core virtual machine.
221
+ ) except *:
222
+ # Core virtual machine (for dtype float64).
199
223
 
200
- cdef int i, num_args, symbol
224
+ cdef int i, symbol
201
225
  cdef double accumulator
202
226
  cdef ElementID* args
203
- cdef ElementID elem
227
+ cdef ElementID elem
204
228
 
205
- # index the four arrays by constants VARS, TMPS, CONSTS, and RESULT
229
+ # Index the four arrays by constants VARS, TMPS, CONSTS, and RESULT
206
230
  cdef (double*) arrays[4]
207
231
  arrays[VARS] = vars_in
208
232
  arrays[TMPS] = tmps
@@ -210,29 +234,32 @@ cdef void cvm_float64(
210
234
  arrays[RESULT] = result
211
235
 
212
236
  cdef Instruction* instruction_ptr = instructions.instructions
213
- for _ in range(instructions.num_instructions):
237
+ cdef int num_instructions = instructions.num_instructions
214
238
 
215
- symbol = instruction_ptr[0].symbol
216
- args = instruction_ptr[0].args
217
- num_args = instruction_ptr[0].num_args
239
+ while num_instructions > 0:
240
+ num_instructions -= 1
241
+
242
+ symbol = instruction_ptr.symbol
243
+ args = instruction_ptr.args
218
244
 
219
245
  elem = args[0]
220
246
  accumulator = arrays[elem.array][elem.index]
221
247
 
222
248
  if symbol == ADD:
223
- for i in range(1, num_args):
249
+ i = instruction_ptr.num_args
250
+ while i > 1:
251
+ i -= 1
224
252
  elem = args[i]
225
253
  accumulator += arrays[elem.array][elem.index]
226
254
  elif symbol == MUL:
227
- for i in range(1, num_args):
255
+ i = instruction_ptr.num_args
256
+ while i > 1:
257
+ i -= 1
228
258
  elem = args[i]
229
259
  accumulator *= arrays[elem.array][elem.index]
230
- elif symbol == COPY:
231
- pass
232
- else:
233
- raise RuntimeError('symbol not understood: ' + str(symbol))
260
+ # else symbol == COPY, nothing to do
234
261
 
235
- elem = instruction_ptr[0].dest
262
+ elem = instruction_ptr.dest
236
263
  arrays[elem.array][elem.index] = accumulator
237
264
 
238
265
  # Advance the instruction pointer
@@ -54,4 +54,4 @@ class NamedCircuitCompiler(Enum):
54
54
  return self.value[0]
55
55
 
56
56
 
57
- DEFAULT_CIRCUIT_COMPILER: NamedCircuitCompiler = NamedCircuitCompiler.LLVM_VM
57
+ DEFAULT_CIRCUIT_COMPILER: NamedCircuitCompiler = NamedCircuitCompiler.CYTHON_VM
@@ -149,7 +149,7 @@ def join_tree_to_circuit(
149
149
  limit_product_tree_search,
150
150
  )
151
151
  top: CircuitNode = top_table.top()
152
- top_table.circuit.remove_unreachable_op_nodes(top)
152
+ top.circuit.remove_unreachable_op_nodes(top)
153
153
 
154
154
  return PGMCircuit(
155
155
  rvs=tuple(pgm.rvs),
@@ -169,27 +169,37 @@ def _circuit_tables_from_join_tree(
169
169
  ) -> CircuitTable:
170
170
  """
171
171
  This is a basic algorithm for constructing a circuit table from a join tree.
172
+ Algorithm synopsis:
173
+ 1) Get a CircuitTable for each factor allocated to this join tree node, and
174
+ for each child of the join tree node (recursive call to _circuit_tables_from_join_tree).
175
+ 2) Form a binary tree of the collected circuit tables.
176
+ 3) Perform table products and sum-outs for each node in the binary tree, which should
177
+ leave a single circuit table with a single row.
172
178
  """
173
- # The PGM factors allocated to this join tree node
174
- factors: List[CircuitTable] = [
175
- factor_tables.get_table(factor)
176
- for factor in join_tree.factors
177
- ]
178
-
179
- # The children of this join tree node
180
- factors.extend(
181
- _circuit_tables_from_join_tree(factor_tables, child, limit_product_tree_search)
182
- for child in join_tree.children
179
+ # Get all the factors to combine.
180
+ factors: List[CircuitTable] = list(
181
+ chain(
182
+ (
183
+ # The PGM factors allocated to this join tree node
184
+ factor_tables.get_table(factor)
185
+ for factor in join_tree.factors
186
+ ),
187
+ (
188
+ # The children of this join tree node
189
+ _circuit_tables_from_join_tree(factor_tables, child, limit_product_tree_search)
190
+ for child in join_tree.children
191
+ ),
192
+ )
183
193
  )
184
194
 
185
195
  # The usual join tree approach just forms the product all the tables in `factors`.
186
196
  # The tree width is not affected by the order of products, however some orders
187
197
  # lead to smaller numbers of arithmetic operations.
188
198
  #
189
- # If `options.optimise_products` is true, then heuristics are used
199
+ # If `limit_product_tree_search > 1`, then heuristics are used
190
200
  # reduce the number of arithmetic operations.
191
201
 
192
- # Deal with the special case: no factors
202
+ # Deal with the special case: zero factors
193
203
  if len(factors) == 0:
194
204
  circuit = factor_tables.circuit
195
205
  if len(join_tree.separator) == 0:
@@ -1,5 +1,5 @@
1
- # from .circuit_table_py import (
2
- from .circuit_table import (
1
+ # from ._circuit_table_py import (
2
+ from ._circuit_table_cy import (
3
3
  CircuitTable,
4
4
  TableInstance,
5
5
  sum_out,
@@ -142,7 +142,7 @@ cpdef object sum_out(object table: CircuitTable, object rv_idxs: Iterable[int]):
142
142
 
143
143
  for group_instance_tuple, to_add in groups.items():
144
144
  node = circuit.optimised_add(to_add)
145
- if not node.is_zero():
145
+ if not node.is_zero:
146
146
  rows[group_instance_tuple] = node
147
147
 
148
148
  return new_table
@@ -159,7 +159,7 @@ cpdef object sum_out_all(object table: CircuitTable): # -> CircuitTable:
159
159
  node = next(iter(table.rows.values()))
160
160
  else:
161
161
  node: CircuitNode = circuit.optimised_add(table.rows.values())
162
- if node.is_zero():
162
+ if node.is_zero:
163
163
  return CircuitTable(circuit, ())
164
164
 
165
165
  return CircuitTable(circuit, (), [((), node)])
@@ -190,7 +190,7 @@ cpdef object product(x: CircuitTable, y: CircuitTable): # -> CircuitTable:
190
190
 
191
191
  # Special case: y == 0 or 1, and has no random variables.
192
192
  if len(y.rv_idxs) == 0:
193
- if len(y) == 1 and y.top().is_one():
193
+ if len(y) == 1 and y.top().is_one:
194
194
  return x
195
195
  elif len(y) == 0:
196
196
  return CircuitTable(circuit, x.rv_idxs)
@@ -259,7 +259,7 @@ cpdef object product(x: CircuitTable, y: CircuitTable): # -> CircuitTable:
259
259
  co.append(x_instance[i])
260
260
  co_tuple = tuple(co)
261
261
 
262
- if x_node.is_one():
262
+ if x_node.is_one:
263
263
  # Multiplying by one.
264
264
  # Iterate over matching y rows.
265
265
  got = y_index.get(co_tuple)
@@ -301,7 +301,7 @@ cdef object _product_no_common_rvs(x: CircuitTable, y: CircuitTable): # -> Circ
301
301
  cdef tuple[int, ...] instance
302
302
 
303
303
  for x_instance, x_node in x.rows.items():
304
- if x_node.is_one():
304
+ if x_node.is_one:
305
305
  for y_instance, y_node in y.rows.items():
306
306
  instance = x_instance + y_instance
307
307
  table.rows[instance] = y_node
@@ -314,12 +314,12 @@ cdef object _product_no_common_rvs(x: CircuitTable, y: CircuitTable): # -> Circ
314
314
 
315
315
 
316
316
  cdef object _optimised_mul(object circuit: Circuit, object x: CircuitNode, object y: CircuitNode): # -> CircuitNode
317
- if x.is_zero():
317
+ if x.is_zero:
318
318
  return x
319
- if y.is_zero():
319
+ if y.is_zero:
320
320
  return y
321
- if x.is_one():
321
+ if x.is_one:
322
322
  return y
323
- if y.is_one():
323
+ if y.is_one:
324
324
  return x
325
325
  return circuit.mul(x, y)
@@ -129,7 +129,7 @@ def sum_out(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
129
129
  def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
130
130
  for group, to_add in groups.items():
131
131
  _node: CircuitNode = circuit.optimised_add(to_add)
132
- if not _node.is_zero():
132
+ if not _node.is_zero:
133
133
  yield group, _node
134
134
 
135
135
  return CircuitTable(circuit, remaining_rv_idxs, _result_rows())
@@ -148,7 +148,7 @@ def sum_out_all(table: CircuitTable) -> CircuitTable:
148
148
  node = next(iter(table.rows.values()))
149
149
  else:
150
150
  node: CircuitNode = circuit.optimised_add(table.rows.values())
151
- if node.is_zero():
151
+ if node.is_zero:
152
152
  return CircuitTable(circuit, ())
153
153
 
154
154
  return CircuitTable(circuit, (), [((), node)])
@@ -185,7 +185,7 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
185
185
 
186
186
  # Special case: y == 0 or 1, and has no random variables.
187
187
  if y_rv_idxs == ():
188
- if len(y) == 1 and y.top().is_one():
188
+ if len(y) == 1 and y.top().is_one:
189
189
  return x
190
190
  elif len(y) == 0:
191
191
  return CircuitTable(circuit, x_rv_idxs)
@@ -225,7 +225,7 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
225
225
  # Rows with constant node values of one are optimised out.
226
226
  for _x_instance, _x_node in x.rows.items():
227
227
  _co = tuple(_x_instance[i] for i in co_from_x_map)
228
- if _x_node.is_one():
228
+ if _x_node.is_one:
229
229
  # Multiplying by one.
230
230
  # Iterate over matching y rows.
231
231
  for _yo, _y_node in y_index.get(_co, ()):
@@ -257,7 +257,7 @@ def _product_no_common_rvs(x: CircuitTable, y: CircuitTable) -> CircuitTable:
257
257
 
258
258
  def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
259
259
  for x_instance, x_node in x.rows.items():
260
- if x_node.is_one():
260
+ if x_node.is_one:
261
261
  for y_instance, y_node in y.rows.items():
262
262
  instance = x_instance + y_instance
263
263
  yield instance, y_node
@@ -180,11 +180,11 @@ def optimal_prefix(clusters: Clusters) -> None:
180
180
 
181
181
  class Clusters:
182
182
  """
183
- Holds the state of a connection graph while eliminating variables
184
- to identify clusters a PGM graph.
183
+ A Clusters object holds the state of a connection graph while
184
+ eliminating variables to construct clusters for a PGM graph.
185
185
 
186
- The clusters can either be in-progress, `len(Clusters.uneliminated) > 0`,
187
- or be completed, `len(Clusters.uneliminated) == 0`.
186
+ The Clusters object can either be "in-progress" where `len(Clusters.uneliminated) > 0`,
187
+ or be "completed" where `len(Clusters.uneliminated) == 0`.
188
188
 
189
189
  See Adnan Darwiche, 2009, Modeling and Reasoning with Bayesian Networks, p164.
190
190
  """
@@ -229,6 +229,9 @@ class Clusters:
229
229
  @property
230
230
  def eliminated(self) -> List[int]:
231
231
  """
232
+ Get the list of eliminated random variables (as random variable
233
+ indices, in elimination order).
234
+
232
235
  Assumes:
233
236
  * The returned list will not be modified by the caller.
234
237
 
@@ -240,6 +243,8 @@ class Clusters:
240
243
  @property
241
244
  def uneliminated(self) -> Set[int]:
242
245
  """
246
+ Get the set of uneliminated random variables (as random variable indices).
247
+
243
248
  Assumes:
244
249
  * The returned set will not be modified by the caller.
245
250
 
@@ -285,6 +290,8 @@ class Clusters:
285
290
 
286
291
  def max_cluster_size(self) -> int:
287
292
  """
293
+ Calculate the maximum cluster size over all clusters.
294
+
288
295
  Returns:
289
296
  the maximum `len(cluster)` over all clusters.
290
297
  """
@@ -292,6 +299,11 @@ class Clusters:
292
299
 
293
300
  def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
294
301
  """
302
+ Calculate the maximum cluster weighted size over all clusters.
303
+
304
+ Args:
305
+ rv_log_sizes: is an array of random variable sizes, such that
306
+ for a random variable `rv`, `rv_log_sizes[rv.idx] = log2(len(rv))`.
295
307
  Returns:
296
308
  the maximum `sum(rv_log_sizes[rv_idx] for rv_idx in cluster)` over all clusters.
297
309
  """
@@ -348,7 +348,7 @@ def _make_factor_table(
348
348
  mul_vars[instance[inst_index]]
349
349
  for inst_index, mul_vars in zip(inst_to_mul, mul_rvs_vars)
350
350
  )
351
- if not node.is_one():
351
+ if not node.is_one:
352
352
  to_mul += (node,)
353
353
  if len(to_mul) == 0:
354
354
  yield instance, circuit.one
@@ -15,6 +15,11 @@ from ck.utils.np_extras import NDArrayFloat64
15
15
 
16
16
  @dataclass
17
17
  class JoinTree:
18
+ """
19
+ This is a recursive data structure representing a join-tree.
20
+ Each node in the join-tree is represented by a JoinTree object.
21
+ """
22
+
18
23
  # The PGM that this join tree is for.
19
24
  pgm: PGM
20
25
 
@@ -40,6 +45,12 @@ class JoinTree:
40
45
 
41
46
  def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
42
47
  """
48
+ Calculate the maximum cluster weighted size for this cluster and its children.
49
+
50
+ Args:
51
+ rv_log_sizes: is an array of random variable sizes, such that
52
+ for a random variable `rv`, `rv_log_sizes[rv.idx] = log2(len(rv))`.
53
+
43
54
  Returns:
44
55
  the maximum `log2` over self and all children, recursively.
45
56
  """
@@ -82,8 +93,8 @@ JoinTreeAlgorithm = Callable[[PGM], JoinTree]
82
93
 
83
94
  def _join_tree_algorithm(pgm_to_clusters: ClusterAlgorithm) -> JoinTreeAlgorithm:
84
95
  """
85
- Helper function for creating a standard JoinTreeAlgorithm from
86
- a ClusterAlgorithm.
96
+ Helper function for creating a standard JoinTreeAlgorithm
97
+ from a ClusterAlgorithm.
87
98
 
88
99
  Args:
89
100
  pgm_to_clusters: The clusters method to use.
@@ -112,14 +123,17 @@ MIN_TRADITIONAL_WEIGHTED_FILL: JoinTreeAlgorithm = _join_tree_algorithm(min_trad
112
123
 
113
124
  def clusters_to_join_tree(clusters: Clusters) -> JoinTree:
114
125
  """
115
- Construct a join tree maker for the given PGM and random variable clusters.
126
+ Construct a join tree from the given random variable clusters.
116
127
 
117
128
  A join tree is formed by finding a minimum spanning tree over the clusters
118
- where the cost between a pair of cluster is defined according to
119
- `separator_cost_counts` and `costing`.
129
+ where the cost between a pair of clusters is the number of random variables
130
+ in common (using separator state space size to break ties).
120
131
 
121
132
  Args:
122
- clusters: the clusters that resulted from graph clusters of the given PGM.
133
+ clusters: the clusters that resulted from graph clusters of a PGM.
134
+
135
+ Returns:
136
+ a JoinTree.
123
137
  """
124
138
  pgm: PGM = clusters.pgm
125
139
  cluster_sets: List[Set[int]] = clusters.clusters
@@ -170,6 +184,19 @@ def _make_spanning_tree_small_root(cost: NDArrayFloat64, clusters: List[Set[int]
170
184
  """
171
185
  Construct a minimum spanning tree over the clusters, where the root is the cluster with
172
186
  the smallest number of random variable.
187
+
188
+ Args:
189
+ cost: is an N x N matrix of costs between N clusters.
190
+ clusters: is a list of N clusters, each cluster is a set of random variable indices.
191
+
192
+ Returns:
193
+ (spanning_tree, root_index)
194
+
195
+ spanning_tree: is a spanning tree represented as a list of nodes, the list is coindexed with
196
+ the given cost matrix, each node is a list of children, each child being
197
+ represented as an index into the list of nodes.
198
+
199
+ root_index: is the index the chosen root of the spanning tree.
173
200
  """
174
201
  root_custer_index: int = 0
175
202
  root_size: int = len(clusters[root_custer_index])
@@ -185,10 +212,22 @@ def _make_spanning_tree_small_root(cost: NDArrayFloat64, clusters: List[Set[int]
185
212
  def _make_spanning_tree_arbitrary_root(cost: NDArrayFloat64) -> Tuple[List[List[int]], int]:
186
213
  """
187
214
  Construct a minimum spanning tree over the clusters, starting at an arbitrary root.
215
+
216
+ Args:
217
+ cost: is an N x N matrix of costs between N clusters.
218
+
219
+ Returns:
220
+ (spanning_tree, root_index)
221
+
222
+ spanning_tree: is a spanning tree represented as a list of nodes, the list is coindexed with
223
+ the given cost matrix, each node is a list of children, each child being
224
+ represented as an index into the list of nodes.
225
+
226
+ root_index: is the index the chosen root of the spanning tree.
188
227
  """
189
- root_custer_index: int = 0
190
- children: List[List[int]] = _make_spanning_tree_at_root(cost, root_custer_index)
191
- return children, root_custer_index
228
+ root_index: int = 0
229
+ spanning_tree: List[List[int]] = _make_spanning_tree_at_root(cost, root_index)
230
+ return spanning_tree, root_index
192
231
 
193
232
 
194
233
  def _make_spanning_tree_at_root(
@@ -202,6 +241,12 @@ def _make_spanning_tree_at_root(
202
241
  cost: and nxn matrix where n is the number of clusters and cost[i, j]
203
242
  gives the cost between clusters i and j.
204
243
  root_custer_index: a nominated root cluster to be the root of the tree.
244
+
245
+ Returns:
246
+ a spanning tree represented as a list of nodes, the list is coindexed with
247
+ the given cost matrix, each node is a list of children, each child being
248
+ represented as an index into the list of nodes. The root node is the
249
+ index `root_custer_index` as passed to this function.
205
250
  """
206
251
  number_of_clusters: int = cost.shape[0]
207
252
 
@@ -257,7 +302,19 @@ def _form_join_tree_r(
257
302
  cluster_factors: List[List[Factor]],
258
303
  ) -> JoinTree:
259
304
  """
260
- Recursively build the join tree data structure.
305
+ Recursively build a JoinTree from the spanning tree `children`.
306
+ This function merely pull the corresponding component from the
307
+ arguments to make a JoinTree object, doing this recursively
308
+ for the children.
309
+
310
+ Args:
311
+ pgm: the source PGM for the join tree.
312
+ cluster_index: index for the node we are processing (current root). This
313
+ indexes into `children`, `clusters`, and `cluster_factors`.
314
+ parent_cluster: set of random variable indices in the parent cluster.
315
+ children: list of spanning tree nodes, as per `_make_spanning_tree_at_root` result.
316
+ clusters: list of clusters, each cluster is a set of random variable indices.
317
+ cluster_factors: assignment of factors to clusters.
261
318
  """
262
319
  cluster: Set[int] = clusters[cluster_index]
263
320
  factors: List[Factor] = cluster_factors[cluster_index]
@@ -51,6 +51,8 @@ def compile_pgm(
51
51
 
52
52
  elimination_order: Sequence[int] = algorithm(pgm).eliminated
53
53
 
54
+ # Eliminate rvs from the factor tables according to the
55
+ # elimination order.
54
56
  cur_tables: List[CircuitTable] = list(factor_tables.tables)
55
57
  for rv_idx in elimination_order:
56
58
  next_tables: List[CircuitTable] = []
@@ -9,6 +9,16 @@ from ck.pgm_compiler.support.join_tree import JoinTree, clusters_to_join_tree
9
9
 
10
10
 
11
11
  def main() -> None:
12
+ """
13
+ This demo shows the full compilation chain for factor elimination.
14
+
15
+ Process:
16
+ Rain example -> PGM
17
+ min_degree -> Clusters
18
+ clusters_to_join_tree -> JoinTree
19
+ join_tree_to_circuit -> PGMCircuit
20
+ default circuit compiler -> WMCProgram
21
+ """
12
22
  pgm: PGM = example.Rain()
13
23
 
14
24
  print(f'PGM {pgm.name!r}')