compiled-knowledge 4.0.0a20__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (178) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37523 -0
  4. ck/circuit/_circuit_cy.cp313-win_amd64.pyd +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19824 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cp313-win_amd64.pyd +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp313-win_amd64.pyd +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp313-win_amd64.pyd +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
@@ -0,0 +1,332 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Sequence, Tuple, Iterable
4
+
5
+ from ck.circuit import MUL, ADD
6
+
7
+ from ck.circuit._circuit_cy cimport Circuit, CircuitNode
8
+
# Typed (C int) copies of the circuit op codes, passed to `Circuit.op` below.
cdef int c_ADD = ADD
cdef int c_MUL = MUL

# Key of a CircuitTable row: one int per random variable of the table,
# positionally aligned with `CircuitTable.rv_idxs`.
TableInstance = Tuple[int, ...]
14
+
15
+
cdef class CircuitTable:
    """
    A circuit table manages a set of CircuitNodes, where each node corresponds
    to an instance for a set of (zero or more) random variables.

    Operations on circuit tables typically add circuit nodes to the circuit. It will
    heuristically avoid adding unnecessary nodes (e.g. addition of zero, multiplication
    by zero or one.) However, it may be that interim circuit nodes are created that
    end up not being used. Consider calling `Circuit.remove_unreachable_op_nodes` after
    completing all circuit table operations.

    It is generally expected that no CircuitTable row will be created with a constant
    zero node. These are assumed to be optimised out already.

    This mirrors the pure-Python implementation (`_circuit_table_py.py`) and exposes
    the same mapping-style API: `len`, `get`, `keys`, `values`, `items`, `[]`.
    """

    cdef public Circuit circuit
    cdef public tuple[int, ...] rv_idxs
    # Row storage: instance (tuple aligned with rv_idxs) -> circuit node.
    # Not public; the module-level table operations access it directly.
    cdef dict[tuple[int, ...], CircuitNode] rows

    def __init__(
            self,
            circuit: Circuit,
            rv_idxs: Sequence[int],
            rows: Iterable[Tuple[TableInstance, CircuitNode]] = (),
    ):
        """
        Args:
            circuit: the circuit whose nodes are being managed by this table.
            rv_idxs: indexes of random variables.
            rows: optional rows to add to the table.

        Assumes:
            * rv_idxs contains no duplicates.
            * all row instances conform to the indexed random variables.
            * all row circuit nodes belong to the given circuit.
        """
        self.circuit = circuit
        self.rv_idxs = tuple(rv_idxs)
        self.rows = dict(rows)

    def __len__(self) -> int:
        """Return the number of rows."""
        return len(self.rows)

    def get(self, key, default=None):
        """Return the node for instance `key`, or `default` if there is no such row."""
        return self.rows.get(key, default)

    def keys(self) -> Iterable[tuple[int, ...]]:
        """Return a view of the row instances."""
        return self.rows.keys()

    def values(self) -> Iterable[CircuitNode]:
        """Return a view of the row circuit nodes."""
        return self.rows.values()

    def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
        """Return a view of (instance, node) pairs, matching the Python implementation."""
        return self.rows.items()

    def __getitem__(self, key):
        return self.rows[key]

    def __setitem__(self, key, value):
        self.rows[key] = value

    cpdef CircuitNode top(self):
        # Get the circuit top value.
        #
        # An empty table yields the circuit's constant zero node.
        #
        # Raises:
        #     RuntimeError if there is more than one row in the table.
        #
        # Returns:
        #     A single circuit node.
        cdef int number_of_rows = len(self.rows)
        if number_of_rows == 0:
            return self.circuit.zero
        elif number_of_rows == 1:
            return next(iter(self.rows.values()))
        else:
            raise RuntimeError('cannot get top node from a table with more that 1 row')
89
+
90
+
91
+ # ==================================================================================
92
+ # Circuit Table Operations
93
+ # ==================================================================================
94
+
cpdef CircuitTable sum_out(CircuitTable table, object rv_idxs: Iterable[int]):
    # Return a circuit table that results from summing out
    # the given random variables of this circuit table.
    #
    # Normally this will return a new table. However, if rv_idxs is empty,
    # then the given table is returned unmodified.
    #
    # Rows are grouped by their values on the remaining random variables.
    # A group holding a single row reuses that row's node directly (as
    # `sum_out_all` does for a single-row table) rather than creating a
    # one-input ADD node. Groups whose sum is the constant zero node are dropped.
    #
    # Raises:
    #     ValueError if rv_idxs is not a subset of table.rv_idxs.
    #     ValueError if rv_idxs contains duplicates.
    cdef tuple[int, ...] rv_idxs_seq = tuple(rv_idxs)

    if len(rv_idxs_seq) == 0:
        # nothing to do
        return table

    cdef set[int] rv_idxs_set = set(rv_idxs_seq)
    if len(rv_idxs_set) != len(rv_idxs_seq):
        raise ValueError('rv_idxs contains duplicates')
    if not rv_idxs_set.issubset(table.rv_idxs):
        raise ValueError('rv_idxs is not a subset of table.rv_idxs')

    cdef int rv_index
    cdef list[int] remaining_rv_idxs = []
    for rv_index in table.rv_idxs:
        if rv_index not in rv_idxs_set:
            remaining_rv_idxs.append(rv_index)

    cdef int num_remaining = len(remaining_rv_idxs)
    if num_remaining == 0:
        # Special case: summing out all random variables
        return sum_out_all(table)

    # index_map[i] is the location in table.rv_idxs for remaining_rv_idxs[i]
    cdef list[int] index_map = []
    for rv_index in remaining_rv_idxs:
        index_map.append(_find(table.rv_idxs, rv_index))

    # Group the row nodes by their projection onto the remaining variables.
    cdef dict[tuple[int, ...], list[CircuitNode]] groups = {}
    cdef object got
    cdef list[int] group_instance
    cdef tuple[int, ...] group_instance_tuple
    cdef int i
    cdef CircuitNode node
    cdef tuple[int, ...] instance
    for instance, node in table.rows.items():
        group_instance = []
        for i in index_map:
            group_instance.append(instance[i])
        group_instance_tuple = tuple(group_instance)
        got = groups.get(group_instance_tuple)
        if got is None:
            groups[group_instance_tuple] = [node]
        else:
            got.append(node)

    cdef Circuit circuit = table.circuit
    cdef CircuitTable new_table = CircuitTable(circuit, remaining_rv_idxs)
    cdef dict[tuple[int, ...], CircuitNode] rows = new_table.rows

    for group_instance_tuple, to_add in groups.items():
        if len(to_add) == 1:
            # A sum of one term is the term itself; no ADD node needed.
            node = to_add[0]
        else:
            node = circuit.op(c_ADD, tuple(to_add))
        if not node.is_zero:
            rows[group_instance_tuple] = node

    return new_table
161
+
162
+
cpdef CircuitTable sum_out_all(CircuitTable table):
    # Sum out every random variable of `table`, producing a table with no
    # random variables.
    #
    # Result rows:
    #   * empty table            -> no rows;
    #   * exactly one row        -> that row's node, keyed by the empty instance;
    #   * several rows           -> one ADD node over them, keyed by the empty
    #                               instance, unless it is the constant zero
    #                               node, in which case the result has no rows.
    cdef Circuit circuit = table.circuit
    cdef int row_count = len(table)
    cdef CircuitNode total

    if row_count == 0:
        return CircuitTable(circuit, ())

    if row_count == 1:
        total = next(iter(table.rows.values()))
        return CircuitTable(circuit, (), [((), total)])

    total = circuit.op(c_ADD, tuple(table.rows.values()))
    if total.is_zero:
        return CircuitTable(circuit, ())
    return CircuitTable(circuit, (), [((), total)])
178
+
179
+
cpdef CircuitTable project(CircuitTable table: CircuitTable, object rv_idxs: Iterable[int]):
    # Keep only the random variables named in `rv_idxs`: every other random
    # variable of `table` is summed out, i.e. this is
    # `sum_out(table, set(table.rv_idxs) - set(rv_idxs))`.
    # Indexes in `rv_idxs` that the table does not have are ignored.
    cdef set[int] to_sum_out = set(table.rv_idxs).difference(rv_idxs)
    return sum_out(table, to_sum_out)
186
+
187
+
cdef tuple _pick(object instance, list positions):
    # Return the values of `instance` at `positions`, as a tuple in that order.
    cdef list picked = []
    cdef int position
    for position in positions:
        picked.append(instance[position])
    return tuple(picked)


cpdef CircuitTable product(CircuitTable x, CircuitTable y):
    # Return a circuit table that results from the product of the two given tables.
    #
    # The result is over the random variables of the larger table followed by
    # those appearing only in the smaller one. Each result row multiplies a pair
    # of rows agreeing on the common random variables; a constant-one node on
    # either side is not multiplied in (the other node is used directly).
    #
    # Shortcuts, checked on the smaller table only (y after the swap below):
    #   * no random variables and no rows: a new empty table over the other
    #     table's random variables is returned;
    #   * no random variables and a single constant-one row: the other table
    #     itself is returned.
    # Otherwise a new circuit table is constructed and returned.
    #
    # Raises:
    #     ValueError if x and y refer to different circuits.
    cdef Circuit circuit = x.circuit
    if y.circuit is not circuit:
        raise ValueError('circuit tables must refer to the same circuit')

    # The index is built on 'y', so make 'y' the table with fewer rows.
    if len(x) < len(y):
        x, y = y, x

    if len(y.rv_idxs) == 0:
        if len(y) == 0:
            return CircuitTable(circuit, x.rv_idxs)
        if len(y) == 1 and y.top().is_one:
            return x

    # Split y's random variables into those shared with x (common) and the
    # rest (y only). The set operations are kept in this exact form so the
    # tuple orderings derived from them are unchanged.
    cdef set[int] y_only = set(y.rv_idxs)
    cdef set[int] common = set(x.rv_idxs)
    common.intersection_update(y_only)
    y_only.difference_update(common)

    if len(common) == 0:
        # No shared random variables: a plain Cartesian product.
        return _product_no_common_rvs(x, y)

    cdef tuple[int, ...] yo_rv_idxs = tuple(y_only)
    cdef tuple[int, ...] co_rv_idxs = tuple(common)

    # Positions, within x / y instances, of each common and y-only variable.
    cdef list co_from_x_map = []
    cdef list co_from_y_map = []
    cdef list yo_from_y_map = []
    for rv_index in co_rv_idxs:
        co_from_x_map.append(_find(x.rv_idxs, rv_index))
        co_from_y_map.append(_find(y.rv_idxs, rv_index))
    for rv_index in yo_rv_idxs:
        yo_from_y_map.append(_find(y.rv_idxs, rv_index))

    # y rows grouped by their values on the common variables.
    cdef dict y_index = {}
    cdef object matches
    cdef tuple co_key
    for y_instance, y_node in y.rows.items():
        co_key = _pick(y_instance, co_from_y_map)
        entry = (_pick(y_instance, yo_from_y_map), y_node)
        matches = y_index.get(co_key)
        if matches is None:
            y_index[co_key] = [entry]
        else:
            matches.append(entry)

    cdef CircuitTable result = CircuitTable(circuit, x.rv_idxs + yo_rv_idxs)
    cdef dict result_rows = result.rows

    for x_instance, x_node in x.rows.items():
        matches = y_index.get(_pick(x_instance, co_from_x_map))
        if matches is None:
            continue
        x_is_one = x_node.is_one
        for yo_tuple, y_node in matches:
            if x_is_one:
                result_rows[x_instance + yo_tuple] = y_node
            elif y_node.is_one:
                result_rows[x_instance + yo_tuple] = x_node
            else:
                result_rows[x_instance + yo_tuple] = circuit.op(c_MUL, (x_node, y_node))

    return result
292
+
293
+
cdef int _find(tuple[int, ...] xs, int x) except -1:
    # Return the position of `x` within `xs`.
    #
    # `except -1` makes Cython propagate the RuntimeError below to the caller
    # regardless of Cython version: without an explicit exception clause, older
    # Cython releases print and ignore exceptions raised in cdef functions that
    # return a C type. Positions are never negative, so -1 is never a real result.
    #
    # Raises:
    #     RuntimeError if `x` does not occur in `xs` (callers only look up
    #     indexes they already know are present).
    cdef int i
    for i in range(len(xs)):
        if xs[i] == x:
            return i
    # Very unexpected
    raise RuntimeError('not found')
301
+
302
+
cdef CircuitTable _product_no_common_rvs(CircuitTable x, CircuitTable y):
    # Multiply two tables that share no random variables: every x row is
    # paired with every y row (a Cartesian product). No index on common
    # variables is needed, which is why this is separate from `product`.
    #
    # A constant-one node on either side is not multiplied in; the other
    # side's node is stored directly.
    #
    # Assumes:
    #     * x and y have no random variables in common.
    #     * x and y refer to the same circuit.
    cdef Circuit circuit = x.circuit
    cdef CircuitTable result = CircuitTable(circuit, x.rv_idxs + y.rv_idxs)
    cdef dict result_rows = result.rows
    cdef tuple[int, ...] combined

    for x_instance, x_node in x.rows.items():
        x_is_one = x_node.is_one
        for y_instance, y_node in y.rows.items():
            combined = x_instance + y_instance
            if x_is_one:
                result_rows[combined] = y_node
            elif y_node.is_one:
                result_rows[combined] = x_node
            else:
                result_rows[combined] = circuit.op(c_MUL, (x_node, y_node))

    return result
332
+
@@ -0,0 +1,304 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Sequence, Tuple, Dict, Iterable, Set, Iterator
4
+
5
+ from ck.circuit import CircuitNode, Circuit
6
+ from ck.utils.map_list import MapList
7
+
# Key of a CircuitTable row: one int per random variable of the table,
# positionally aligned with `CircuitTable.rv_idxs`.
TableInstance = Tuple[int, ...]
9
+
10
+
class CircuitTable:
    """
    A sparse mapping from instances of a fixed tuple of random variables to
    circuit nodes.

    Each key (a `TableInstance`) holds one int per random variable in `rv_idxs`,
    in the same order; each value is a `CircuitNode` of `circuit`.

    Operations on circuit tables typically add circuit nodes to the circuit.
    They heuristically avoid adding unnecessary nodes (e.g. addition of zero,
    multiplication by zero or one). However, interim circuit nodes may be
    created that end up unused; consider calling
    `Circuit.remove_unreachable_op_nodes` after completing all circuit table
    operations.

    No row is expected to hold a constant zero node; such rows are assumed to
    have been optimised out already.
    """

    def __init__(
            self,
            circuit: Circuit,
            rv_idxs: Sequence[int],
            rows: Iterable[Tuple[TableInstance, CircuitNode]] = (),
    ):
        """
        Create a table over the given random variables.

        Args:
            circuit: the circuit that owns the nodes held by this table.
            rv_idxs: indexes of the table's random variables (stored as a tuple).
            rows: optional initial (instance, node) pairs.

        Assumes:
            * rv_idxs contains no duplicates.
            * every row instance conforms to the indexed random variables.
            * every row node belongs to `circuit`.
        """
        self._circuit: Circuit = circuit
        self._rv_idxs: Tuple[int, ...] = tuple(rv_idxs)
        self._rows: Dict[TableInstance, CircuitNode] = dict(rows)

    @property
    def circuit(self) -> Circuit:
        """The circuit that owns this table's nodes."""
        return self._circuit

    @property
    def rv_idxs(self) -> Tuple[int, ...]:
        """Indexes of this table's random variables, in instance order."""
        return self._rv_idxs

    def __len__(self) -> int:
        """Number of rows in the table."""
        return len(self._rows)

    def get(self, key, default=None):
        """Node for instance `key`, or `default` when the table has no such row."""
        return self._rows.get(key, default)

    def keys(self) -> Iterable[TableInstance]:
        """View of the row instances."""
        return self._rows.keys()

    def values(self) -> Iterable[CircuitNode]:
        """View of the row nodes."""
        return self._rows.values()

    def items(self) -> Iterable[Tuple[TableInstance, CircuitNode]]:
        """View of (instance, node) pairs."""
        return self._rows.items()

    def __getitem__(self, key):
        return self._rows[key]

    def __setitem__(self, key, value):
        self._rows[key] = value

    def top(self) -> CircuitNode:
        """
        Return the table's single node.

        An empty table yields the circuit's constant zero node.

        Raises:
            RuntimeError if the table has more than one row.

        Returns:
            A single circuit node.
        """
        rows = self._rows
        if len(rows) > 1:
            raise RuntimeError('cannot get top node from a table with more that 1 row')
        if rows:
            return next(iter(rows.values()))
        return self._circuit.zero
92
+
93
+
94
+ # ==================================================================================
95
+ # Circuit Table Operations
96
+ # ==================================================================================
97
+
98
+
def sum_out(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
    """
    Return a circuit table that results from summing out
    the given random variables of this circuit table.

    Rows are grouped by their values on the remaining random variables and
    each group's nodes are added together. A group whose sum is the constant
    zero node is omitted, consistent with `sum_out_all` and with the table
    invariant that no row holds a constant zero node.

    Normally this will return a new table. However, if rv_idxs is empty,
    then the given table is returned unmodified.

    Raises:
        ValueError if rv_idxs is not a subset of table.rv_idxs.
        ValueError if rv_idxs contains duplicates.
    """
    rv_idxs: Sequence[int] = tuple(rv_idxs)

    if len(rv_idxs) == 0:
        # nothing to do
        return table

    rv_idxs_set: Set[int] = set(rv_idxs)
    if len(rv_idxs_set) != len(rv_idxs):
        raise ValueError('rv_idxs contains duplicates')
    if not rv_idxs_set.issubset(table.rv_idxs):
        raise ValueError('rv_idxs is not a subset of table.rv_idxs')

    remaining_rv_idxs: Tuple[int, ...] = tuple(
        rv_index
        for rv_index in table.rv_idxs
        if rv_index not in rv_idxs_set
    )
    if len(remaining_rv_idxs) == 0:
        # Special case: summing out all random variables
        return sum_out_all(table)

    # index_map[i] is the location in table.rv_idxs for remaining_rv_idxs[i]
    index_map: Tuple[int, ...] = tuple(
        table.rv_idxs.index(remaining_rv_index)
        for remaining_rv_index in remaining_rv_idxs
    )

    # Two passes: group nodes by their projected instance, then add each group
    # with a single `circuit.add` call. (A one-pass version accumulating
    # pairwise sums seemed to perform worse.)
    groups: MapList[TableInstance, CircuitNode] = MapList()
    for instance, node in table.items():
        group_instance = tuple(instance[i] for i in index_map)
        groups.append(group_instance, node)

    circuit: Circuit = table.circuit
    result = CircuitTable(circuit, remaining_rv_idxs)
    for group_instance, to_add in groups.items():
        group_sum: CircuitNode = circuit.add(to_add)
        if not group_sum.is_zero:
            result[group_instance] = group_sum
    return result
167
+
168
+
def sum_out_all(table: CircuitTable) -> CircuitTable:
    """
    Sum out every random variable of `table`.

    Returns:
        A table with no random variables. It has no rows if `table` is empty,
        or if `table` has several rows whose `circuit.optimised_add` is the
        constant zero node. Otherwise it has the single row `()` mapped to the
        lone input node (one-row input) or to that sum.
    """
    circuit: Circuit = table.circuit
    row_count: int = len(table)

    if row_count == 0:
        return CircuitTable(circuit, ())

    if row_count == 1:
        total: CircuitNode = next(iter(table.values()))
    else:
        total = circuit.optimised_add(table.values())
        if total.is_zero:
            return CircuitTable(circuit, ())

    return CircuitTable(circuit, (), [((), total)])
186
+
187
+
def project(table: CircuitTable, rv_idxs: Iterable[int]) -> CircuitTable:
    """
    Project `table` onto the random variables `rv_idxs`.

    Every random variable of `table` not named in `rv_idxs` is summed out,
    i.e. this is `sum_out(table, set(table.rv_idxs) - set(rv_idxs))`.
    Indexes in `rv_idxs` that the table does not have are ignored.
    """
    to_sum_out: Set[int] = set(table.rv_idxs).difference(rv_idxs)
    return sum_out(table, to_sum_out)
196
+
197
+
def product(x: CircuitTable, y: CircuitTable) -> CircuitTable:
    """
    Return a circuit table that results from the product of the two given tables.

    The result is over the random variables of the larger table followed by
    those found only in the smaller one. Each result row multiplies a pair of
    rows agreeing on the common random variables; a constant-one node on either
    side is not multiplied in (the other node is used directly).

    Shortcuts, checked on the smaller table only (y after any swap):
        * no random variables and no rows: a new empty table over the other
          table's random variables is returned;
        * no random variables and a single constant-one row: the other table
          itself is returned.
    Otherwise a new circuit table is constructed and returned.

    Raises:
        ValueError if x and y refer to different circuits.
    """
    circuit: Circuit = x.circuit
    if y.circuit is not circuit:
        raise ValueError('circuit tables must refer to the same circuit')

    # The index is built on 'y', so make 'y' the table with fewer rows.
    if len(x) < len(y):
        x, y = y, x
    x_rv_idxs: Tuple[int, ...] = x.rv_idxs
    y_rv_idxs: Tuple[int, ...] = y.rv_idxs

    if y_rv_idxs == ():
        if len(y) == 0:
            return CircuitTable(circuit, x_rv_idxs)
        if len(y) == 1 and y.top().is_one:
            return x

    # Split y's random variables into those shared with x (common) and the
    # rest (y only). The set operations are kept in this exact form so the
    # tuple orderings derived from them are unchanged.
    y_only: Set[int] = set(y_rv_idxs)
    common: Set[int] = set(x_rv_idxs)
    common.intersection_update(y_only)
    y_only.difference_update(common)

    if len(common) == 0:
        # No shared random variables: a plain Cartesian product.
        return _product_no_common_rvs(x, y)

    yo_rv_idxs: Tuple[int, ...] = tuple(y_only)
    co_rv_idxs: Tuple[int, ...] = tuple(common)

    # Positions, within x / y instances, of each common and y-only variable.
    co_from_x_map = tuple(x_rv_idxs.index(rv) for rv in co_rv_idxs)
    co_from_y_map = tuple(y_rv_idxs.index(rv) for rv in co_rv_idxs)
    yo_from_y_map = tuple(y_rv_idxs.index(rv) for rv in yo_rv_idxs)

    # y rows grouped by their values on the common variables.
    y_index: MapList[TableInstance, Tuple[TableInstance, CircuitNode]] = MapList()
    for y_instance, y_node in y.items():
        co_key = tuple(y_instance[pos] for pos in co_from_y_map)
        yo_part = tuple(y_instance[pos] for pos in yo_from_y_map)
        y_index.append(co_key, (yo_part, y_node))

    def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
        # Yield (instance, node) for every matching (x row, y row) pair.
        for x_instance, x_node in x.items():
            x_co_key = tuple(x_instance[pos] for pos in co_from_x_map)
            x_is_one = x_node.is_one
            for match_yo, match_node in y_index.get(x_co_key, ()):
                if x_is_one:
                    value = match_node
                elif match_node.is_one:
                    value = x_node
                else:
                    value = circuit.mul(x_node, match_node)
                yield x_instance + match_yo, value

    return CircuitTable(circuit, x_rv_idxs + yo_rv_idxs, _result_rows())
273
+
274
+
def _product_no_common_rvs(x: CircuitTable, y: CircuitTable) -> CircuitTable:
    """
    Return the product of x and y, where x and y have no common random variables.

    Every row of x is paired with every row of y. This is an optimisation of the
    more general product algorithm, as no index needs to be constructed based on
    the common random variables.

    Multiplication by a constant-one node is optimised out: if either node of a
    pair is one, the other node is used as the product directly.

    Assumes:
        * There are no common random variables between x and y.
        * x and y are for the same circuit.
    """
    circuit: Circuit = x.circuit

    result_rv_idxs: Tuple[int, ...] = x.rv_idxs + y.rv_idxs

    def _result_rows() -> Iterator[Tuple[TableInstance, CircuitNode]]:
        for x_instance, x_node in x.items():
            if x_node.is_one:
                # 1 * y_node == y_node
                for y_instance, y_node in y.items():
                    yield x_instance + y_instance, y_node
            else:
                for y_instance, y_node in y.items():
                    if y_node.is_one:
                        # x_node * 1 == x_node: keep the x factor.
                        yield x_instance + y_instance, x_node
                    else:
                        yield x_instance + y_instance, circuit.mul(x_node, y_node)

    return CircuitTable(circuit, result_rv_idxs, _result_rows())