compiled-knowledge 4.0.0a16__cp312-cp312-win_amd64.whl → 4.0.0a18__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (36) hide show
  1. ck/circuit/__init__.py +9 -2
  2. ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
  3. ck/circuit/_circuit_cy.pxd +33 -0
  4. ck/circuit/{circuit.pyx → _circuit_cy.pyx} +115 -133
  5. ck/circuit/{circuit_py.py → _circuit_py.py} +16 -8
  6. ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
  7. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +88 -60
  8. ck/circuit_compiler/named_circuit_compilers.py +1 -1
  9. ck/pgm_compiler/factor_elimination.py +23 -13
  10. ck/pgm_compiler/support/circuit_table/__init__.py +9 -2
  11. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
  12. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  13. ck/pgm_compiler/support/circuit_table/_circuit_table_cy_cpp_verion.pyx +601 -0
  14. ck/pgm_compiler/support/circuit_table/_circuit_table_cy_minimal_version.pyx +311 -0
  15. ck/pgm_compiler/support/circuit_table/{circuit_table.pyx → _circuit_table_cy_v4.0.0a17.pyx} +9 -9
  16. ck/pgm_compiler/support/circuit_table/{circuit_table_py.py → _circuit_table_py.py} +80 -45
  17. ck/pgm_compiler/support/clusters.py +16 -4
  18. ck/pgm_compiler/support/factor_tables.py +1 -1
  19. ck/pgm_compiler/support/join_tree.py +67 -10
  20. ck/pgm_compiler/support/named_compiler_maker.py +12 -2
  21. ck/pgm_compiler/variable_elimination.py +2 -0
  22. ck/utils/iter_extras.py +8 -1
  23. ck_demos/pgm_compiler/demo_compiler_dump.py +10 -0
  24. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  25. ck_demos/utils/compare.py +5 -1
  26. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/METADATA +1 -1
  27. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/RECORD +30 -29
  28. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/WHEEL +1 -1
  29. ck/circuit/circuit.c +0 -38861
  30. ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
  31. ck/circuit/circuit_node.pyx +0 -138
  32. ck/circuit_compiler/cython_vm_compiler/_compiler.c +0 -17373
  33. ck/pgm_compiler/support/circuit_table/circuit_table.c +0 -16042
  34. ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
  35. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/licenses/LICENSE.txt +0 -0
  36. {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,311 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Sequence, Tuple, Dict, Iterable, Set, Iterator
4
+
5
+ from ck.circuit import ADD, MUL
6
+ from ck.utils.map_list import MapList
7
+
8
+ from ck.circuit._circuit_cy cimport Circuit, CircuitNode
9
+
10
+ cdef int c_ADD = ADD
11
+ cdef int c_MUL = MUL
12
+
13
+
14
+ TableInstance = Tuple[int, ...]
15
+
16
+
17
class CircuitTable:
    """
    A circuit table manages a set of CircuitNodes, where each node corresponds
    to an instance for a set of (zero or more) random variables.

    Operations on circuit tables typically add circuit nodes to the circuit. It will
    heuristically avoid adding unnecessary nodes (e.g. addition of zero, multiplication
    by zero or one.) However, it may be that interim circuit nodes are created that
    end up not being used. Consider calling `Circuit.remove_unreachable_op_nodes` after
    completing all circuit table operations.

    It is generally expected that no CircuitTable row will be created with a constant
    zero node. These are assumed to be optimised out already.
    """

    def __init__(
            self,
            circuit: Circuit,
            rv_idxs: Sequence[int],
            rows: Iterable[Tuple[TableInstance, CircuitNode]] = (),
    ):
        """
        Args:
            circuit: the circuit whose nodes are being managed by this table.
            rv_idxs: indexes of random variables.
            rows: optional rows to add to the table, as (instance, node) pairs.

        Assumes:
            * rv_idxs contains no duplicates.
            * all row instances conform to the indexed random variables.
            * all row circuit nodes belong to the given circuit.
        """
        self._circuit: Circuit = circuit
        self._rv_idxs: Tuple[int, ...] = tuple(rv_idxs)
        # `rows` may be a one-shot iterator (e.g. a generator); it is consumed here.
        self._rows: Dict[TableInstance, CircuitNode] = dict(rows)

    @property
    def circuit(self) -> Circuit:
        """The circuit passed to the constructor."""
        return self._circuit

    @property
    def rv_idxs(self) -> Tuple[int, ...]:
        """The random variable indexes of this table, as a tuple."""
        return self._rv_idxs

    def __len__(self) -> int:
        """Number of rows in the table."""
        return len(self._rows)

    def get(self, key, default=None):
        """Return the node for instance `key`, or `default` if there is no such row."""
        return self._rows.get(key, default)

    def keys(self) -> Iterable[TableInstance]:
        """View of the row instances."""
        return self._rows.keys()

    def values(self) -> Iterable[CircuitNode]:
        """View of the row circuit nodes."""
        return self._rows.values()

    def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
        """View of the (instance, node) rows."""
        return self._rows.items()

    def __getitem__(self, key):
        """Return the node for instance `key`; raises KeyError if absent."""
        return self._rows[key]

    def __setitem__(self, key, value):
        """Set (or replace) the node for instance `key`."""
        self._rows[key] = value

    def top(self) -> CircuitNode:
        """
        Get the circuit top value.

        Raises:
            RuntimeError if there is more than one row in the table.

        Returns:
            A single circuit node: the circuit's `zero` node if the table
            has no rows, otherwise the node of the only row.
        """
        num_rows: int = len(self._rows)
        if num_rows == 0:
            return self._circuit.zero
        if num_rows == 1:
            return next(iter(self._rows.values()))
        raise RuntimeError('cannot get top node from a table with more than 1 row')
98
+
99
+
100
+ # ==================================================================================
101
+ # Circuit Table Operations
102
+ # ==================================================================================
103
+
104
+
105
def sum_out(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
    """
    Sum the given random variables out of a circuit table.

    A new table over the remaining random variables is normally returned.
    If `rv_idxs` is empty, the given table itself is returned unmodified.

    Raises:
        ValueError if rv_idxs is not a subset of table.rv_idxs.
        ValueError if rv_idxs contains duplicates.
    """
    to_sum: Tuple[int, ...] = tuple(rv_idxs)
    if not to_sum:
        # Nothing to sum out.
        return table

    to_sum_set: Set[int] = set(to_sum)
    if len(to_sum) != len(to_sum_set):
        raise ValueError('rv_idxs contains duplicates')
    if not to_sum_set.issubset(table.rv_idxs):
        raise ValueError('rv_idxs is not a subset of table.rv_idxs')

    kept_rv_idxs = tuple(idx for idx in table.rv_idxs if idx not in to_sum_set)
    if len(kept_rv_idxs) == 0:
        # Every random variable is being summed out.
        return sum_out_all(table)

    # kept_positions[i] is where kept_rv_idxs[i] sits within table.rv_idxs.
    kept_positions = tuple(table.rv_idxs.index(idx) for idx in kept_rv_idxs)

    # Two passes: first gather the nodes of each group, then add each group.
    # A one-pass variant (keeping a running sum per group via `circuit.add`
    # while iterating the rows) was tried; this two-pass version seems to
    # have better performance.
    groups: MapList[TableInstance, CircuitNode] = MapList()
    for instance, node in table.items():
        groups.append(tuple(instance[pos] for pos in kept_positions), node)

    circuit: Circuit = table.circuit

    def _summed_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
        # One addition node per group of collapsed rows.
        for group, to_add in groups.items():
            yield group, circuit.op(c_ADD, tuple(to_add))

    return CircuitTable(circuit, kept_rv_idxs, _summed_rows())
173
+
174
+
175
+
176
def sum_out_all(table: CircuitTable) -> CircuitTable:
    """
    Sum out every random variable of a circuit table.

    The result is a table with no random variables: it is empty when the
    given table is empty (or the sum is a zero node), otherwise it holds a
    single row, keyed by the empty instance `()`.
    """
    circuit: Circuit = table.circuit
    row_count: int = len(table)

    if row_count == 0:
        return CircuitTable(circuit, ())

    if row_count == 1:
        # A lone row is its own sum; no addition node is needed.
        total = next(iter(table.values()))
    else:
        total = circuit.op(c_ADD, tuple(table.values()))
        if total.is_zero:
            # Zero rows are not stored in a circuit table.
            return CircuitTable(circuit, ())

    return CircuitTable(circuit, (), [((), total)])
193
+
194
+
195
def project(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
    """
    Project a circuit table onto the given random variables.

    This is `sum_out(table, to_sum_out)`, where `to_sum_out` holds every
    index of `table.rv_idxs` that does not appear in `rv_idxs`.
    """
    return sum_out(table, set(table.rv_idxs).difference(rv_idxs))
203
+
204
+
205
def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
    """
    Return a circuit table that is the product of the two given tables.

    If the table with fewer rows has no random variables and a single row
    whose node is the constant one, then the other table is returned as is.
    Otherwise, a new circuit table is constructed and returned.

    Raises:
        ValueError if x and y do not refer to the same circuit.
    """
    circuit: Circuit = x.circuit
    if circuit is not y.circuit:
        raise ValueError('circuit tables must refer to the same circuit')

    # Keep the table with fewer rows as 'y'; it is the one that gets indexed,
    # so this keeps the index small.
    if len(y) > len(x):
        x, y = y, x

    x_rv_idxs: Tuple[int, ...] = x.rv_idxs
    y_rv_idxs: Tuple[int, ...] = y.rv_idxs

    # Shortcut when y has no random variables and is the constant 1 or 0.
    if y_rv_idxs == ():
        if len(y) == 1 and y.top().is_one:
            return x
        elif len(y) == 0:
            return CircuitTable(circuit, x_rv_idxs)

    # Partition the random variables of y:
    #  * common_set: indexes appearing in both x and y,
    #  * y_only_set: indexes appearing in y but not in x.
    y_only_set: Set[int] = set(y_rv_idxs)
    common_set: Set[int] = set(x_rv_idxs)
    common_set.intersection_update(y_only_set)
    y_only_set.difference_update(common_set)

    if not common_set:
        # Nothing shared, so no index on y is required.
        return _product_no_common_rvs(x, y)

    # Fix an order for each group of random variables.
    yo_rv_idxs: Tuple[int, ...] = tuple(y_only_set)
    co_rv_idxs: Tuple[int, ...] = tuple(common_set)

    # Positions, within the x and y instances, of the values needed to
    # build the common key and the y-only part of a result instance.
    x_common_pos = tuple(x_rv_idxs.index(rv_index) for rv_index in co_rv_idxs)
    y_common_pos = tuple(y_rv_idxs.index(rv_index) for rv_index in co_rv_idxs)
    y_only_pos = tuple(y_rv_idxs.index(rv_index) for rv_index in yo_rv_idxs)

    # Group the rows of y by their values on the common random variables.
    y_index: MapList[TableInstance, Tuple[TableInstance, CircuitNode]] = MapList()
    for y_instance, y_node in y.items():
        common_key = tuple(y_instance[pos] for pos in y_common_pos)
        y_only_part = tuple(y_instance[pos] for pos in y_only_pos)
        y_index.append(common_key, (y_only_part, y_node))

    def _joined_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
        # For each x row, pair it with every y row agreeing on the common
        # random variables. A multiplication node is only created when
        # neither factor is the constant one.
        for x_instance, x_node in x.items():
            key = tuple(x_instance[pos] for pos in x_common_pos)
            x_is_one = x_node.is_one
            for y_part, matched_node in y_index.get(key, ()):
                if x_is_one:
                    node = matched_node
                elif matched_node.is_one:
                    node = x_node
                else:
                    node = circuit.op(c_MUL, (x_node, matched_node))
                yield x_instance + y_part, node

    return CircuitTable(circuit, x_rv_idxs + yo_rv_idxs, _joined_rows())
280
+
281
+
282
def _product_no_common_rvs(x: CircuitTable, y: CircuitTable) -> CircuitTable:
    """
    Return the product of x and y, where x and y have no common random variables.

    This is an optimisation of the more general product algorithm, as no index
    needs to be constructed based on the common random variables.

    Every row of x is paired with every row of y. When either node of a pair
    is the constant one, the other node is used directly rather than creating
    a multiplication node.

    Assumes:
        * There are no common random variables between x and y.
        * x and y are for the same circuit.
    """
    circuit: Circuit = x.circuit

    result_rv_idxs: Tuple[int, ...] = x.rv_idxs + y.rv_idxs

    def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
        for x_instance, x_node in x.items():
            if x_node.is_one:
                # 1 * y == y
                for y_instance, y_node in y.items():
                    yield x_instance + y_instance, y_node
            else:
                for y_instance, y_node in y.items():
                    if y_node.is_one:
                        # x * 1 == x. The result must be x_node; yielding
                        # y_node here would silently replace x by one.
                        yield x_instance + y_instance, x_node
                    else:
                        yield x_instance + y_instance, circuit.op(c_MUL, (x_node, y_node))

    return CircuitTable(circuit, result_rv_idxs, _result_rows())
@@ -142,7 +142,7 @@ cpdef object sum_out(object table: CircuitTable, object rv_idxs: Iterable[int]):
142
142
 
143
143
  for group_instance_tuple, to_add in groups.items():
144
144
  node = circuit.optimised_add(to_add)
145
- if not node.is_zero():
145
+ if not node.is_zero:
146
146
  rows[group_instance_tuple] = node
147
147
 
148
148
  return new_table
@@ -159,7 +159,7 @@ cpdef object sum_out_all(object table: CircuitTable): # -> CircuitTable:
159
159
  node = next(iter(table.rows.values()))
160
160
  else:
161
161
  node: CircuitNode = circuit.optimised_add(table.rows.values())
162
- if node.is_zero():
162
+ if node.is_zero:
163
163
  return CircuitTable(circuit, ())
164
164
 
165
165
  return CircuitTable(circuit, (), [((), node)])
@@ -190,7 +190,7 @@ cpdef object product(x: CircuitTable, y: CircuitTable): # -> CircuitTable:
190
190
 
191
191
  # Special case: y == 0 or 1, and has no random variables.
192
192
  if len(y.rv_idxs) == 0:
193
- if len(y) == 1 and y.top().is_one():
193
+ if len(y) == 1 and y.top().is_one:
194
194
  return x
195
195
  elif len(y) == 0:
196
196
  return CircuitTable(circuit, x.rv_idxs)
@@ -259,7 +259,7 @@ cpdef object product(x: CircuitTable, y: CircuitTable): # -> CircuitTable:
259
259
  co.append(x_instance[i])
260
260
  co_tuple = tuple(co)
261
261
 
262
- if x_node.is_one():
262
+ if x_node.is_one:
263
263
  # Multiplying by one.
264
264
  # Iterate over matching y rows.
265
265
  got = y_index.get(co_tuple)
@@ -301,7 +301,7 @@ cdef object _product_no_common_rvs(x: CircuitTable, y: CircuitTable): # -> Circ
301
301
  cdef tuple[int, ...] instance
302
302
 
303
303
  for x_instance, x_node in x.rows.items():
304
- if x_node.is_one():
304
+ if x_node.is_one:
305
305
  for y_instance, y_node in y.rows.items():
306
306
  instance = x_instance + y_instance
307
307
  table.rows[instance] = y_node
@@ -314,12 +314,12 @@ cdef object _product_no_common_rvs(x: CircuitTable, y: CircuitTable): # -> Circ
314
314
 
315
315
 
316
316
  cdef object _optimised_mul(object circuit: Circuit, object x: CircuitNode, object y: CircuitNode): # -> CircuitNode
317
- if x.is_zero():
317
+ if x.is_zero:
318
318
  return x
319
- if y.is_zero():
319
+ if y.is_zero:
320
320
  return y
321
- if x.is_one():
321
+ if x.is_one:
322
322
  return y
323
- if y.is_one():
323
+ if y.is_one:
324
324
  return x
325
325
  return circuit.mul(x, y)
@@ -40,21 +40,38 @@ class CircuitTable:
40
40
  * all row instances conform to the indexed random variables.
41
41
  * all row circuit nodes belong to the given circuit.
42
42
  """
43
- self.circuit: Circuit = circuit
44
- self.rv_idxs: Tuple[int, ...] = tuple(rv_idxs)
45
- self.rows: Dict[TableInstance, CircuitNode] = dict(rows)
43
+ self._circuit: Circuit = circuit
44
+ self._rv_idxs: Tuple[int, ...] = tuple(rv_idxs)
45
+ self._rows: Dict[TableInstance, CircuitNode] = dict(rows)
46
+
47
+ @property
48
+ def circuit(self) -> Circuit:
49
+ return self._circuit
50
+
51
+ @property
52
+ def rv_idxs(self) -> Tuple[int, ...]:
53
+ return self._rv_idxs
46
54
 
47
55
  def __len__(self) -> int:
48
- return len(self.rows)
56
+ return len(self._rows)
49
57
 
50
58
  def get(self, key, default=None):
51
- return self.rows.get(key, default)
59
+ return self._rows.get(key, default)
60
+
61
+ def keys(self) -> Iterable[TableInstance]:
62
+ return self._rows.keys()
63
+
64
+ def values(self) -> Iterable[CircuitNode]:
65
+ return self._rows.values()
66
+
67
+ def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
68
+ return self._rows.items()
52
69
 
53
70
  def __getitem__(self, key):
54
- return self.rows[key]
71
+ return self._rows[key]
55
72
 
56
73
  def __setitem__(self, key, value):
57
- self.rows[key] = value
74
+ self._rows[key] = value
58
75
 
59
76
  def top(self) -> CircuitNode:
60
77
  """
@@ -66,10 +83,10 @@ class CircuitTable:
66
83
  Returns:
67
84
  A single circuit node.
68
85
  """
69
- if len(self.rows) == 0:
70
- return self.circuit.zero
71
- elif len(self.rows) == 1:
72
- return next(iter(self.rows.values()))
86
+ if len(self._rows) == 0:
87
+ return self._circuit.zero
88
+ elif len(self._rows) == 1:
89
+ return next(iter(self._rows.values()))
73
90
  else:
74
91
  raise RuntimeError('cannot get top node from a table with more that 1 row')
75
92
 
@@ -119,20 +136,34 @@ def sum_out(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
119
136
  for remaining_rv_index in remaining_rv_idxs
120
137
  )
121
138
 
139
+ # This is a one-pass version to sum the groups. The
140
+ # two-pass version (below) seems to have better performance.
141
+ #
142
+ # circuit: Circuit = table.circuit
143
+ # result = CircuitTable(circuit, remaining_rv_idxs)
144
+ # result_rows: Dict[TableInstance, CircuitNode] = result._rows
145
+ # for instance, node in table.items():
146
+ # group_instance = tuple(instance[i] for i in index_map)
147
+ # prev_sum = result_rows.get(group_instance)
148
+ # if prev_sum is None:
149
+ # result_rows[group_instance] = node
150
+ # else:
151
+ # result_rows[group_instance] = circuit.add(prev_sum, node)
152
+ # return result
153
+
122
154
  groups: MapList[TableInstance, CircuitNode] = MapList()
123
- for instance, node in table.rows.items():
155
+ for instance, node in table.items():
124
156
  group_instance = tuple(instance[i] for i in index_map)
125
157
  groups.append(group_instance, node)
126
-
127
158
  circuit: Circuit = table.circuit
128
-
129
- def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
130
- for group, to_add in groups.items():
131
- _node: CircuitNode = circuit.optimised_add(to_add)
132
- if not _node.is_zero():
133
- yield group, _node
134
-
135
- return CircuitTable(circuit, remaining_rv_idxs, _result_rows())
159
+ return CircuitTable(
160
+ circuit,
161
+ remaining_rv_idxs,
162
+ (
163
+ (group, circuit.add(to_add))
164
+ for group, to_add in groups.items()
165
+ )
166
+ )
136
167
 
137
168
 
138
169
  def sum_out_all(table: CircuitTable) -> CircuitTable:
@@ -145,10 +176,10 @@ def sum_out_all(table: CircuitTable) -> CircuitTable:
145
176
  if num_rows == 0:
146
177
  return CircuitTable(circuit, ())
147
178
  elif num_rows == 1:
148
- node = next(iter(table.rows.values()))
179
+ node = next(iter(table.values()))
149
180
  else:
150
- node: CircuitNode = circuit.optimised_add(table.rows.values())
151
- if node.is_zero():
181
+ node: CircuitNode = circuit.optimised_add(table.values())
182
+ if node.is_zero:
152
183
  return CircuitTable(circuit, ())
153
184
 
154
185
  return CircuitTable(circuit, (), [((), node)])
@@ -168,7 +199,7 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
168
199
  """
169
200
  Return a circuit table that results from the product of the two given tables.
170
201
 
171
- If x or y equals `one_table`, then the other table is returned. Otherwise,
202
+ If x or y have a single row with value 1, then the other table is returned. Otherwise,
172
203
  a new circuit table will be constructed and returned.
173
204
  """
174
205
  circuit: Circuit = x.circuit
@@ -185,7 +216,7 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
185
216
 
186
217
  # Special case: y == 0 or 1, and has no random variables.
187
218
  if y_rv_idxs == ():
188
- if len(y) == 1 and y.top().is_one():
219
+ if len(y) == 1 and y.top().is_one:
189
220
  return x
190
221
  elif len(y) == 0:
191
222
  return CircuitTable(circuit, x_rv_idxs)
@@ -193,18 +224,18 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
193
224
  # Set operations on rv indexes. After these operations:
194
225
  # * co_rv_idxs is the set of rv indexes common (co) to x and y,
195
226
  # * yo_rv_idxs is the set of rv indexes in y only (yo), and not in x.
196
- yo_rv_idxs: Set[int] = set(y_rv_idxs)
197
- co_rv_idxs: Set[int] = set(x_rv_idxs)
198
- co_rv_idxs.intersection_update(yo_rv_idxs)
199
- yo_rv_idxs.difference_update(co_rv_idxs)
227
+ yo_rv_idxs_set: Set[int] = set(y_rv_idxs)
228
+ co_rv_idxs_set: Set[int] = set(x_rv_idxs)
229
+ co_rv_idxs_set.intersection_update(yo_rv_idxs_set)
230
+ yo_rv_idxs_set.difference_update(co_rv_idxs_set)
200
231
 
201
- if len(co_rv_idxs) == 0:
232
+ if len(co_rv_idxs_set) == 0:
202
233
  # Special case: no common random variables.
203
234
  return _product_no_common_rvs(x, y)
204
235
 
205
236
  # Convert random variable index sets to sequences
206
- yo_rv_idxs: Tuple[int, ...] = tuple(yo_rv_idxs) # y only random variables
207
- co_rv_idxs: Tuple[int, ...] = tuple(co_rv_idxs) # common random variables
237
+ yo_rv_idxs: Tuple[int, ...] = tuple(yo_rv_idxs_set) # y only random variables
238
+ co_rv_idxs: Tuple[int, ...] = tuple(co_rv_idxs_set) # common random variables
208
239
 
209
240
  # Cache mappings from result Instance to index into source Instance (x or y).
210
241
  # This will be used in indexing and product loops to pull out needed values
@@ -215,7 +246,7 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
215
246
 
216
247
  # Index the y rows by common-only key (y is the smaller of the two tables).
217
248
  y_index: MapList[TableInstance, Tuple[TableInstance, CircuitNode]] = MapList()
218
- for y_instance, y_node in y.rows.items():
249
+ for y_instance, y_node in y.items():
219
250
  co = tuple(y_instance[i] for i in co_from_y_map)
220
251
  yo = tuple(y_instance[i] for i in yo_from_y_map)
221
252
  y_index.append(co, (yo, y_node))
@@ -223,9 +254,9 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
223
254
  def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
224
255
  # Iterate over x rows, yielding (instance, value).
225
256
  # Rows with constant node values of one are optimised out.
226
- for _x_instance, _x_node in x.rows.items():
257
+ for _x_instance, _x_node in x.items():
227
258
  _co = tuple(_x_instance[i] for i in co_from_x_map)
228
- if _x_node.is_one():
259
+ if _x_node.is_one:
229
260
  # Multiplying by one.
230
261
  # Iterate over matching y rows.
231
262
  for _yo, _y_node in y_index.get(_co, ()):
@@ -233,7 +264,10 @@ def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
233
264
  else:
234
265
  # Iterate over matching y rows.
235
266
  for _yo, _y_node in y_index.get(_co, ()):
236
- yield _x_instance + _yo, circuit.optimised_mul((_x_node, _y_node))
267
+ if _y_node.is_one:
268
+ yield _x_instance + _yo, _x_node
269
+ else:
270
+ yield _x_instance + _yo, circuit.mul(_x_node, _y_node)
237
271
 
238
272
  return CircuitTable(circuit, x_rv_idxs + yo_rv_idxs, _result_rows())
239
273
 
@@ -256,14 +290,15 @@ def _product_no_common_rvs(x: CircuitTable, y: CircuitTable) -> CircuitTable:
256
290
  result_rv_idxs: Tuple[int, ...] = x.rv_idxs + y.rv_idxs
257
291
 
258
292
  def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
259
- for x_instance, x_node in x.rows.items():
260
- if x_node.is_one():
261
- for y_instance, y_node in y.rows.items():
262
- instance = x_instance + y_instance
263
- yield instance, y_node
293
+ for x_instance, x_node in x.items():
294
+ if x_node.is_one:
295
+ for y_instance, y_node in y.items():
296
+ yield x_instance + y_instance, y_node
264
297
  else:
265
- for y_instance, y_node in y.rows.items():
266
- instance = x_instance + y_instance
267
- yield instance, circuit.optimised_mul((x_node, y_node))
298
+ for y_instance, y_node in y.items():
299
+ if y_node.is_one:
300
+ yield x_instance + y_instance, y_node
301
+ else:
302
+ yield x_instance + y_instance, circuit.mul(x_node, y_node)
268
303
 
269
304
  return CircuitTable(circuit, result_rv_idxs, _result_rows())
@@ -180,11 +180,11 @@ def optimal_prefix(clusters: Clusters) -> None:
180
180
 
181
181
  class Clusters:
182
182
  """
183
- Holds the state of a connection graph while eliminating variables
184
- to identify clusters a PGM graph.
183
+ A Clusters object holds the state of a connection graph while
184
+ eliminating variables to construct clusters for a PGM graph.
185
185
 
186
- The clusters can either be in-progress, `len(Clusters.uneliminated) > 0`,
187
- or be completed, `len(Clusters.uneliminated) == 0`.
186
+ The Clusters object can either be "in-progress" where `len(Clusters.uneliminated) > 0`,
187
+ or be "completed" where `len(Clusters.uneliminated) == 0`.
188
188
 
189
189
  See Adnan Darwiche, 2009, Modeling and Reasoning with Bayesian Networks, p164.
190
190
  """
@@ -229,6 +229,9 @@ class Clusters:
229
229
  @property
230
230
  def eliminated(self) -> List[int]:
231
231
  """
232
+ Get the list of eliminated random variables (as random variable
233
+ indices, in elimination order).
234
+
232
235
  Assumes:
233
236
  * The returned list will not be modified by the caller.
234
237
 
@@ -240,6 +243,8 @@ class Clusters:
240
243
  @property
241
244
  def uneliminated(self) -> Set[int]:
242
245
  """
246
+ Get the set of uneliminated random variables (as random variable indices).
247
+
243
248
  Assumes:
244
249
  * The returned set will not be modified by the caller.
245
250
 
@@ -285,6 +290,8 @@ class Clusters:
285
290
 
286
291
  def max_cluster_size(self) -> int:
287
292
  """
293
+ Calculate the maximum cluster size over all clusters.
294
+
288
295
  Returns:
289
296
  the maximum `len(cluster)` over all clusters.
290
297
  """
@@ -292,6 +299,11 @@ class Clusters:
292
299
 
293
300
  def max_cluster_weighted_size(self, rv_log_sizes: Sequence[float]) -> float:
294
301
  """
302
+ Calculate the maximum cluster weighted size over all clusters.
303
+
304
+ Args:
305
+ rv_log_sizes: an array of random variable log sizes, such that
306
+ for a random variable `rv`, `rv_log_sizes[rv.idx] = log2(len(rv))`.
295
307
  Returns:
296
308
  the maximum `sum(rv_log_sizes[rv_idx] for rv_idx in cluster)` over all clusters.
297
309
  """
@@ -348,7 +348,7 @@ def _make_factor_table(
348
348
  mul_vars[instance[inst_index]]
349
349
  for inst_index, mul_vars in zip(inst_to_mul, mul_rvs_vars)
350
350
  )
351
- if not node.is_one():
351
+ if not node.is_one:
352
352
  to_mul += (node,)
353
353
  if len(to_mul) == 0:
354
354
  yield instance, circuit.one