compiled-knowledge 4.0.0a20__cp312-cp312-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (178) hide show
  1. ck/__init__.py +0 -0
  2. ck/circuit/__init__.py +17 -0
  3. ck/circuit/_circuit_cy.c +37523 -0
  4. ck/circuit/_circuit_cy.cp312-win32.pyd +0 -0
  5. ck/circuit/_circuit_cy.pxd +32 -0
  6. ck/circuit/_circuit_cy.pyx +768 -0
  7. ck/circuit/_circuit_py.py +836 -0
  8. ck/circuit/tmp_const.py +74 -0
  9. ck/circuit_compiler/__init__.py +2 -0
  10. ck/circuit_compiler/circuit_compiler.py +26 -0
  11. ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
  12. ck/circuit_compiler/cython_vm_compiler/_compiler.c +19824 -0
  13. ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win32.pyd +0 -0
  14. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
  15. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
  16. ck/circuit_compiler/interpret_compiler.py +223 -0
  17. ck/circuit_compiler/llvm_compiler.py +388 -0
  18. ck/circuit_compiler/llvm_vm_compiler.py +546 -0
  19. ck/circuit_compiler/named_circuit_compilers.py +57 -0
  20. ck/circuit_compiler/support/__init__.py +0 -0
  21. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  22. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
  23. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win32.pyd +0 -0
  24. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  25. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
  26. ck/circuit_compiler/support/input_vars.py +148 -0
  27. ck/circuit_compiler/support/llvm_ir_function.py +234 -0
  28. ck/example/__init__.py +53 -0
  29. ck/example/alarm.py +366 -0
  30. ck/example/asia.py +28 -0
  31. ck/example/binary_clique.py +32 -0
  32. ck/example/bow_tie.py +33 -0
  33. ck/example/cancer.py +37 -0
  34. ck/example/chain.py +38 -0
  35. ck/example/child.py +199 -0
  36. ck/example/clique.py +33 -0
  37. ck/example/cnf_pgm.py +39 -0
  38. ck/example/diamond_square.py +68 -0
  39. ck/example/earthquake.py +36 -0
  40. ck/example/empty.py +10 -0
  41. ck/example/hailfinder.py +539 -0
  42. ck/example/hepar2.py +628 -0
  43. ck/example/insurance.py +504 -0
  44. ck/example/loop.py +40 -0
  45. ck/example/mildew.py +38161 -0
  46. ck/example/munin.py +22982 -0
  47. ck/example/pathfinder.py +53747 -0
  48. ck/example/rain.py +39 -0
  49. ck/example/rectangle.py +161 -0
  50. ck/example/run.py +30 -0
  51. ck/example/sachs.py +129 -0
  52. ck/example/sprinkler.py +30 -0
  53. ck/example/star.py +44 -0
  54. ck/example/stress.py +64 -0
  55. ck/example/student.py +43 -0
  56. ck/example/survey.py +46 -0
  57. ck/example/triangle_square.py +54 -0
  58. ck/example/truss.py +49 -0
  59. ck/in_out/__init__.py +3 -0
  60. ck/in_out/parse_ace_lmap.py +216 -0
  61. ck/in_out/parse_ace_nnf.py +322 -0
  62. ck/in_out/parse_net.py +480 -0
  63. ck/in_out/parser_utils.py +185 -0
  64. ck/in_out/pgm_pickle.py +42 -0
  65. ck/in_out/pgm_python.py +268 -0
  66. ck/in_out/render_bugs.py +111 -0
  67. ck/in_out/render_net.py +177 -0
  68. ck/in_out/render_pomegranate.py +184 -0
  69. ck/pgm.py +3475 -0
  70. ck/pgm_circuit/__init__.py +1 -0
  71. ck/pgm_circuit/marginals_program.py +352 -0
  72. ck/pgm_circuit/mpe_program.py +237 -0
  73. ck/pgm_circuit/pgm_circuit.py +79 -0
  74. ck/pgm_circuit/program_with_slotmap.py +236 -0
  75. ck/pgm_circuit/slot_map.py +35 -0
  76. ck/pgm_circuit/support/__init__.py +0 -0
  77. ck/pgm_circuit/support/compile_circuit.py +83 -0
  78. ck/pgm_circuit/target_marginals_program.py +103 -0
  79. ck/pgm_circuit/wmc_program.py +323 -0
  80. ck/pgm_compiler/__init__.py +2 -0
  81. ck/pgm_compiler/ace/__init__.py +1 -0
  82. ck/pgm_compiler/ace/ace.py +299 -0
  83. ck/pgm_compiler/factor_elimination.py +395 -0
  84. ck/pgm_compiler/named_pgm_compilers.py +63 -0
  85. ck/pgm_compiler/pgm_compiler.py +19 -0
  86. ck/pgm_compiler/recursive_conditioning.py +231 -0
  87. ck/pgm_compiler/support/__init__.py +0 -0
  88. ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
  89. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
  90. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win32.pyd +0 -0
  91. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
  92. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
  93. ck/pgm_compiler/support/clusters.py +568 -0
  94. ck/pgm_compiler/support/factor_tables.py +406 -0
  95. ck/pgm_compiler/support/join_tree.py +332 -0
  96. ck/pgm_compiler/support/named_compiler_maker.py +43 -0
  97. ck/pgm_compiler/variable_elimination.py +91 -0
  98. ck/probability/__init__.py +0 -0
  99. ck/probability/empirical_probability_space.py +50 -0
  100. ck/probability/pgm_probability_space.py +32 -0
  101. ck/probability/probability_space.py +622 -0
  102. ck/program/__init__.py +3 -0
  103. ck/program/program.py +137 -0
  104. ck/program/program_buffer.py +180 -0
  105. ck/program/raw_program.py +67 -0
  106. ck/sampling/__init__.py +0 -0
  107. ck/sampling/forward_sampler.py +211 -0
  108. ck/sampling/marginals_direct_sampler.py +113 -0
  109. ck/sampling/sampler.py +62 -0
  110. ck/sampling/sampler_support.py +232 -0
  111. ck/sampling/uniform_sampler.py +72 -0
  112. ck/sampling/wmc_direct_sampler.py +171 -0
  113. ck/sampling/wmc_gibbs_sampler.py +153 -0
  114. ck/sampling/wmc_metropolis_sampler.py +165 -0
  115. ck/sampling/wmc_rejection_sampler.py +115 -0
  116. ck/utils/__init__.py +0 -0
  117. ck/utils/iter_extras.py +163 -0
  118. ck/utils/local_config.py +270 -0
  119. ck/utils/map_list.py +128 -0
  120. ck/utils/map_set.py +128 -0
  121. ck/utils/np_extras.py +51 -0
  122. ck/utils/random_extras.py +64 -0
  123. ck/utils/tmp_dir.py +94 -0
  124. ck_demos/__init__.py +0 -0
  125. ck_demos/ace/__init__.py +0 -0
  126. ck_demos/ace/copy_ace_to_ck.py +15 -0
  127. ck_demos/ace/demo_ace.py +49 -0
  128. ck_demos/all_demos.py +88 -0
  129. ck_demos/circuit/__init__.py +0 -0
  130. ck_demos/circuit/demo_circuit_dump.py +22 -0
  131. ck_demos/circuit/demo_derivatives.py +43 -0
  132. ck_demos/circuit_compiler/__init__.py +0 -0
  133. ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
  134. ck_demos/circuit_compiler/show_llvm_program.py +26 -0
  135. ck_demos/pgm/__init__.py +0 -0
  136. ck_demos/pgm/demo_pgm_dump.py +18 -0
  137. ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
  138. ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
  139. ck_demos/pgm/show_examples.py +25 -0
  140. ck_demos/pgm_compiler/__init__.py +0 -0
  141. ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
  142. ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
  143. ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
  144. ck_demos/pgm_compiler/demo_join_tree.py +25 -0
  145. ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
  146. ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
  147. ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
  148. ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
  149. ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
  150. ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
  151. ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
  152. ck_demos/pgm_inference/__init__.py +0 -0
  153. ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
  154. ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
  155. ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
  156. ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
  157. ck_demos/programs/__init__.py +0 -0
  158. ck_demos/programs/demo_program_buffer.py +24 -0
  159. ck_demos/programs/demo_program_multi.py +24 -0
  160. ck_demos/programs/demo_program_none.py +19 -0
  161. ck_demos/programs/demo_program_single.py +23 -0
  162. ck_demos/programs/demo_raw_program_interpreted.py +21 -0
  163. ck_demos/programs/demo_raw_program_llvm.py +21 -0
  164. ck_demos/sampling/__init__.py +0 -0
  165. ck_demos/sampling/check_sampler.py +71 -0
  166. ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
  167. ck_demos/sampling/demo_uniform_sampler.py +38 -0
  168. ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
  169. ck_demos/utils/__init__.py +0 -0
  170. ck_demos/utils/compare.py +120 -0
  171. ck_demos/utils/convert_network.py +45 -0
  172. ck_demos/utils/sample_model.py +216 -0
  173. ck_demos/utils/stop_watch.py +384 -0
  174. compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
  175. compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
  176. compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
  177. compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
  178. compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
@@ -0,0 +1,768 @@
1
+ """
2
+ For more documentation on this module, refer to the Jupyter notebook docs/6_circuits_and_programs.ipynb.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from itertools import chain
7
+ from typing import Dict, Optional, Iterable, Sequence, List, overload
8
+
# Type for values of ConstNode objects.
# Also used at runtime in `isinstance` checks (see `Circuit.__check_nodes`).
ConstValue = float | int | bool

# A type representing a flexible representation of multiple CircuitNode objects:
# a node, a constant value, or an iterable of those. Nested iterables are
# flattened recursively by `Circuit._check_nodes`.
Args = CircuitNode | ConstValue | Iterable[CircuitNode | ConstValue]

# Op-node symbols; `OpNode.symbol` holds one of these values.
ADD: int = 0
MUL: int = 1

# C-level copies of the op symbols, used for comparisons within this module.
cdef int c_ADD = ADD
cdef int c_MUL = MUL
20
+
cdef class Circuit:
    """
    An arithmetic circuit defines an arithmetic function from input variables (`VarNode` objects)
    and constant values (`ConstNode` objects) to one or more result values. Computation is defined
    over a mathematical ring, with two operations: addition and multiplication (represented
    by `OpNode` objects).

    An arithmetic circuit needs to be compiled to a program to execute the function.

    All nodes belong to a circuit. All nodes are immutable, with the exception that a
    `VarNode` may temporarily be set to a constant value.
    """

    def __init__(self, zero: ConstValue = 0, one: ConstValue = 1):
        """
        Construct a new, empty circuit.

        Args:
            zero: The constant value for zero. mul(x, zero) = zero, add(x, zero) = x.
            one: The constant value for one. mul(x, one) = x.
        """
        self.vars: List[VarNode] = []
        self.ops: List[OpNode] = []
        self.zero: ConstNode = ConstNode(self, zero, is_zero=True)
        self.one: ConstNode = ConstNode(self, one, is_one=True)

        # Interning table: each distinct constant value maps to a single ConstNode.
        self._const_map: Dict[ConstValue, ConstNode] = {zero: self.zero, one: self.one}
        self.__derivatives: Optional[_DerivativeHelper] = None  # cache for partial derivatives calculations.

    @property
    def number_of_vars(self) -> int:
        """
        Returns:
            the number of "var" nodes.
        """
        return len(self.vars)

    @property
    def number_of_consts(self) -> int:
        """
        Returns:
            the number of "const" nodes.
        """
        return len(self._const_map)

    @property
    def number_of_op_nodes(self) -> int:
        """
        Returns:
            the number of "op" nodes.
        """
        return len(self.ops)

    @property
    def number_of_arcs(self) -> int:
        """
        Returns:
            the number of arcs in the circuit, i.e., the sum of the
            number of arguments for all op nodes.
        """
        return sum(len(op.args) for op in self.ops)

    @property
    def number_of_operations(self) -> int:
        """
        How many arithmetic operations are needed to calculate the circuit.
        This is number_of_arcs - number_of_op_nodes.
        """
        return self.number_of_arcs - self.number_of_op_nodes

    def new_var(self) -> VarNode:
        """
        Create and return a new variable node.
        """
        node = VarNode(self, len(self.vars))
        self.vars.append(node)
        return node

    def new_vars(self, num_of_vars: int) -> Sequence[VarNode]:
        """
        Create and return multiple variable nodes.
        """
        offset = self.number_of_vars
        new_vars = tuple(VarNode(self, i) for i in range(offset, offset + num_of_vars))
        self.vars.extend(new_vars)
        return new_vars

    def const(self, value: ConstValue | ConstNode) -> ConstNode:
        """
        Return a const node for the given value.
        If a const node for that value already exists, then it will be returned,
        otherwise a new const node will be created.
        """
        if isinstance(value, ConstNode):
            value = value.value

        node = self._const_map.get(value)
        if node is None:
            node = ConstNode(self, value)
            self._const_map[value] = node
        return node

    def add(self, *nodes: Args) -> OpNode:
        """
        Create and return a new 'addition' node, applied to the given arguments.
        """
        cdef list[CircuitNode] args = self._check_nodes(nodes)
        return self.op(c_ADD, tuple(args))

    def mul(self, *nodes: Args) -> OpNode:
        """
        Create and return a new 'multiplication' node, applied to the given arguments.
        """
        cdef list[CircuitNode] args = self._check_nodes(nodes)
        return self.op(c_MUL, tuple(args))

    def optimised_add(self, *args: Args) -> CircuitNode:
        """
        Optimised circuit node addition.

        Performs the following optimisations:
        * addition to zero is avoided: add(x, 0) = x,
        * singleton addition is avoided: add(x) = x,
        * empty addition is avoided: add() = 0,
        """
        cdef tuple[CircuitNode, ...] to_add = tuple(n for n in self._check_nodes(args) if not n.is_zero)
        cdef Py_ssize_t num_to_add = len(to_add)
        if num_to_add == 0:
            return self.zero
        if num_to_add == 1:
            return to_add[0]
        return self.op(c_ADD, to_add)

    def optimised_mul(self, *args: Args) -> CircuitNode:
        """
        Optimised circuit node multiplication.

        Performs the following optimisations:
        * multiplication by zero is avoided: mul(x, 0) = 0,
        * multiplication by one is avoided: mul(x, 1) = x,
        * singleton multiplication is avoided: mul(x) = x,
        * empty multiplication is avoided: mul() = 1,
        """
        cdef tuple[CircuitNode, ...] to_mul = tuple(n for n in self._check_nodes(args) if not n.is_one)
        if any(n.is_zero for n in to_mul):
            return self.zero
        cdef Py_ssize_t num_to_mul = len(to_mul)

        if num_to_mul == 0:
            return self.one
        if num_to_mul == 1:
            return to_mul[0]
        return self.op(c_MUL, to_mul)

    def cartesian_product(self, xs: Sequence[CircuitNode], ys: Sequence[CircuitNode]) -> List[CircuitNode]:
        """
        Add multiply operations, one for each possible combination of x from xs and y from ys.

        Args:
            xs: first list of circuit nodes, may be either a Node object or a list of Nodes.
            ys: second list of circuit nodes, may be either a Node object or a list of Nodes.

        Returns:
            a list of 'mul' nodes, one for each combination of xs and ys. The results are in the order
            given by `[mul(x, y) for x in xs for y in ys]`.
        """
        # Wrap each argument in a list so that a single node (or constant) is accepted,
        # as documented; `_check_nodes` flattens any nested iterables.
        cdef list[CircuitNode] xs_list = self._check_nodes([xs])
        cdef list[CircuitNode] ys_list = self._check_nodes([ys])
        return [
            self.optimised_mul(x, y)
            for x in xs_list
            for y in ys_list
        ]

    @overload
    def partial_derivatives(
            self,
            f: CircuitNode,
            args: Sequence[CircuitNode],
            *,
            self_multiply: bool = False,
    ) -> List[CircuitNode]:
        ...

    @overload
    def partial_derivatives(
            self,
            f: CircuitNode,
            args: CircuitNode,
            *,
            self_multiply: bool = False,
    ) -> CircuitNode:
        ...

    def partial_derivatives(
            self,
            f: CircuitNode,
            args,
            *,
            self_multiply: bool = False,
    ):
        """
        Add to the circuit the operations required to calculate the partial derivative of f
        with respect to each of the given nodes. If self_multiply is True, then this is
        equivalent to calculating the marginal probability at each var that represents
        an indicator.

        This method will cache partial derivative calculations for `f` so that subsequent calls
        to this method with the same `f` will not cause duplicated calculations to be added to
        the circuit. To release this cache, call `self.remove_derivatives_cache()`.

        Args:
            f: is the circuit node that defines the function to differentiate.
            args: nodes that are the arguments to f (typically VarNode objects).
                The value may be either a CircuitNode object or a list of CircuitNode objects.
            self_multiply: if true then each partial derivative df/dx will be multiplied by x.

        Returns:
            the results nodes for the partial derivatives, co-indexed with the given arg nodes.
            If `args` is a single CircuitNode, then a single CircuitNode will be returned, otherwise
            a list of CircuitNode is returned.

        Raises:
            RuntimeError: if `f` or any of `args` does not belong to this circuit.
        """
        # Without this check, derivative nodes would be silently built in another circuit
        # (or fail obscurely for a removed node, whose circuit is None).
        if f.circuit is not self:
            raise RuntimeError('node does not belong to this circuit')

        cdef bint single_result = isinstance(args, CircuitNode)

        cdef list[CircuitNode] args_list = self._check_nodes([args])
        if len(args_list) == 0:
            # Trivial case
            return []

        derivatives: _DerivativeHelper = self._derivatives(f)
        result: List[CircuitNode]
        if self_multiply:
            result = [
                derivatives.derivative_self_mul(arg)
                for arg in args_list
            ]
        else:
            result = [
                derivatives.derivative(arg)
                for arg in args_list
            ]

        if single_result:
            return result[0]
        else:
            return result

    def remove_derivatives_cache(self) -> None:
        """
        Release the memory held for partial derivative calculations, as per `partial_derivatives`.
        """
        self.__derivatives = None

    def remove_unreachable_op_nodes(self, *nodes: Args) -> None:
        """
        Find all op nodes reachable from the given nodes, via op arguments.
        (using `self.reachable_op_nodes`). Remove all other op nodes from this circuit.

        If any external object holds a reference to a removed node, that node will be unusable.
        If any op node is removed, the partial derivatives cache is also released.

        Args:
            *nodes: may be either a node or a list of nodes.
        """
        cdef list[CircuitNode] node_list = self._check_nodes(nodes)
        self._remove_unreachable_op_nodes(node_list)

    def reachable_op_nodes(self, *nodes: Args) -> List[OpNode]:
        """
        Find all op nodes reachable from the given nodes, via op arguments.

        Args:
            *nodes: may be either a node or a list of nodes.

        Returns:
            A list of all op nodes reachable from the given nodes.

        Ensures:
            Returned nodes are not repeated.
            The result is ordered such that if result[i] is referenced by result[j] then i < j.
        """
        cdef list[CircuitNode] node_list = self._check_nodes(nodes)
        return self.find_reachable_op_nodes(node_list)

    def dump(
            self,
            *,
            prefix: str = '',
            indent: str = ' ',
            var_names: Optional[List[str]] = None,
            include_consts: bool = False,
    ) -> None:
        """
        Print a dump of the Circuit.
        This is intended for debugging and demonstration purposes.

        Args:
            prefix: optional prefix for indenting all lines.
            indent: additional prefix to use for extra indentation.
            var_names: optional variable names to show.
            include_consts: if true, then constant values are dumped.
        """

        next_prefix: str = prefix + indent

        node_name: Dict[int, str] = {}

        print(f'{prefix}number of vars: {self.number_of_vars:,}')
        print(f'{prefix}number of const nodes: {self.number_of_consts:,}')
        print(f'{prefix}number of op nodes: {self.number_of_op_nodes:,}')
        print(f'{prefix}number of operations: {self.number_of_operations:,}')
        print(f'{prefix}number of arcs: {self.number_of_arcs:,}')

        print(f'{prefix}var nodes: {self.number_of_vars}')
        for var in self.vars:
            node_name[id(var)] = f'var[{var.idx}]'
            var_name: str = '' if var_names is None or var.idx >= len(var_names) else var_names[var.idx]
            if var_name != '':
                if var.is_const():
                    print(f'{next_prefix}var[{var.idx}]: {var_name}, {var.const.value}')
                else:
                    print(f'{next_prefix}var[{var.idx}]: {var_name}')
            elif var.is_const():
                print(f'{next_prefix}var[{var.idx}]: {var.const.value}')

        if include_consts:
            print(f'{prefix}const nodes: {self.number_of_consts}')
            for const in self._const_map.values():
                print(f'{next_prefix}{const.value!r}')

        # Add const nodes to the node_name dict
        for const in self._const_map.values():
            node_name[id(const)] = repr(const.value)

        # Add op nodes to the node_name dict
        for i, op in enumerate(self.ops):
            node_name[id(op)] = f'{op.op_str()}<{i}>'

        print(
            f'{prefix}op nodes: {self.number_of_op_nodes} '
            f'(arcs: {self.number_of_arcs}, ops: {self.number_of_operations})'
        )
        for op in reversed(self.ops):
            op_name = node_name[id(op)]
            args_str = ' '.join(node_name[id(arg)] for arg in op.args)
            print(f'{next_prefix}{op_name}: {args_str}')

    cdef OpNode op(self, int symbol, tuple[CircuitNode, ...] nodes):
        # Create a new op node with the given symbol and arguments, and register it.
        cdef OpNode node = OpNode(self, symbol, nodes)
        self.ops.append(node)
        return node

    cdef list[OpNode] find_reachable_op_nodes(self, list[CircuitNode] nodes):
        # Set of object ids for all reachable op nodes
        cdef set[int] seen = set()

        cdef list[OpNode] result = []

        cdef CircuitNode node
        for node in nodes:
            find_reachable_op_nodes_r(node, seen, result)
        return result

    cdef void _remove_unreachable_op_nodes(self, list[CircuitNode] nodes):
        # Set of object ids for all reachable op nodes
        cdef set[int] seen = set()

        cdef CircuitNode node
        for node in nodes:
            find_reachable_op_nodes_seen_r(node, seen)

        if len(seen) < len(self.ops):
            # The derivatives cache may hold (and would later return) op nodes that are
            # about to be invalidated, so release it.
            self.__derivatives = None

            # Invalidate unreachable op nodes
            for op_node in self.ops:
                if id(op_node) not in seen:
                    op_node.circuit = None
                    op_node.args = ()

            # Keep only reachable op nodes, in the same order as `self.ops`.
            self.ops = [op_node for op_node in self.ops if id(op_node) in seen]

    cdef list[CircuitNode] _check_nodes(self, object nodes: Iterable[Args]):  # -> Sequence[CircuitNode]:
        # Convert the given circuit nodes to a list, flattening nested iterables as needed.
        #
        # Args:
        #     nodes: some circuit nodes or constant values (possibly nested in iterables).
        #
        # Raises:
        #     RuntimeError: if any node does not belong to this circuit.
        #     TypeError: if a string is given where a node or constant is expected.
        cdef list[CircuitNode] result = []
        self.__check_nodes(nodes, result)
        return result

    cdef void __check_nodes(self, object nodes: Iterable[Args], list[CircuitNode] result):
        # Recursive worker for `_check_nodes`: append the nodes found in `nodes` to `result`.
        #
        # Args:
        #     nodes: some circuit nodes or constant values (possibly nested in iterables).
        #     result: the list that checked nodes are appended to.
        #
        # Raises:
        #     RuntimeError: if any node does not belong to this circuit.
        #     TypeError: if a string is given where a node or constant is expected.
        for node in nodes:
            if isinstance(node, CircuitNode):
                if node.circuit is not self:
                    raise RuntimeError('node does not belong to this circuit')
                else:
                    result.append(node)
            elif isinstance(node, ConstValue):
                result.append(self.const(node))
            elif isinstance(node, str):
                # Each character of a string is itself an iterable string, so
                # recursing would never terminate.
                raise TypeError('cannot interpret a string as a circuit node: ' + repr(node))
            else:
                # Treat anything else as an iterable of nodes / constants.
                self.__check_nodes(node, result)

    cdef object _derivatives(self, CircuitNode f):
        # Get a _DerivativeHelper for `f`.
        # Reuses the cached helper when it was built for the same `f`, otherwise replaces it.
        derivatives: Optional[_DerivativeHelper] = self.__derivatives
        if derivatives is None or derivatives.f is not f:
            derivatives = _DerivativeHelper(f)
            self.__derivatives = derivatives
        return derivatives
440
+
cdef class CircuitNode:
    """
    A node in an arithmetic circuit.
    Each node is either an op, var, or const node.

    Each op node is either a mul or an add node. Each op
    node has zero or more arguments. Each argument is another node.

    Every var node has an index, `idx`, which is an integer counting from zero, and denotes
    its creation order.

    A var node may be temporarily set to be a constant node, which may
    be useful for optimising a compiled circuit.
    """

    def __init__(self, circuit: Circuit, is_zero: bool, is_one: bool):
        """
        Args:
            circuit: the circuit that owns this node.
            is_zero: whether this node is the circuit's zero (`Circuit.zero` is created with True).
            is_one: whether this node is the circuit's one (`Circuit.one` is created with True).
        """
        self.circuit = circuit
        self.is_zero = is_zero
        self.is_one = is_one

    def __add__(self, other: CircuitNode | ConstValue):
        """
        Equivalent to `self.circuit.add(self, other)`; creates a new 'add' op node.
        """
        return self.circuit.add(self, other)

    def __mul__(self, other: CircuitNode | ConstValue):
        """
        Equivalent to `self.circuit.mul(self, other)`; creates a new 'mul' op node.
        """
        return self.circuit.mul(self, other)
466
+
cdef class ConstNode(CircuitNode):
    """
    A circuit node holding a fixed constant value.
    """

    def __init__(self, circuit, value: ConstValue, is_zero: bool = False, is_one: bool = False):
        super().__init__(circuit, is_zero, is_one)
        self.value: ConstValue = value

    def __str__(self) -> str:
        return f'const({self.value!s})'

    def __lt__(self, other) -> bool:
        # Const nodes are only ordered relative to other const nodes, by value;
        # against anything else the comparison is False.
        return isinstance(other, ConstNode) and self.value < other.value
484
+
cdef class VarNode(CircuitNode):
    """
    A circuit node standing for an input variable.
    The variable may be pinned to a constant through the `const` property.
    """

    def __init__(self, circuit, idx: int):
        super().__init__(circuit, False, False)
        self.idx = idx
        self._const = None

    cpdef int is_const(self) except*:
        # 1 when the variable is currently pinned to a constant, else 0.
        return 0 if self._const is None else 1

    @property
    def const(self) -> Optional[ConstNode]:
        """
        The ConstNode this variable is pinned to, or None when it is free.
        """
        return self._const

    @const.setter
    def const(self, value: ConstValue | ConstNode | None) -> None:
        # A non-None value is resolved to the circuit's (shared) ConstNode, and the
        # zero/one flags mirror that node; None unpins the variable.
        if value is not None:
            pinned: ConstNode = self.circuit.const(value)
            self._const = pinned
            self.is_zero = pinned.is_zero
            self.is_one = pinned.is_one
            return
        self._const = None
        self.is_zero = False
        self.is_one = False

    def __lt__(self, other) -> bool:
        # Var nodes are ordered by index, and only relative to other var nodes.
        return isinstance(other, VarNode) and self.idx < other.idx

    def __str__(self) -> str:
        text = f'var[{self.idx!s}]'
        if self._const is not None:
            text += f' = {self._const.value!s}'
        return text
525
+
cdef class OpNode(CircuitNode):
    """
    A circuit node applying an arithmetic operation (add or mul) to its arguments.
    """

    def __init__(self, Circuit circuit, int symbol, tuple[CircuitNode, ...] args):
        super().__init__(circuit, False, False)
        self.args = tuple(args)
        self.symbol = symbol

    def __str__(self) -> str:
        # e.g. 'mul\3' for a multiplication of three arguments.
        return self.op_str() + '\\' + str(len(self.args))

    def op_str(self) -> str:
        """
        Return the operation name: 'add', 'mul', or '?' followed by the symbol if unrecognised.
        """
        symbol = self.symbol
        if symbol == c_ADD:
            return 'add'
        if symbol == c_MUL:
            return 'mul'
        return '?' + str(symbol)
549
+
cdef class _DNode:
    """
    A data structure supporting derivative calculations.
    A DNode holds all information needed to calculate the partial derivative at `node`.
    """
    cdef CircuitNode node
    cdef object derivative  # memoised derivative node, or None until computed
    cdef object derivative_self_mul  # memoised derivative * node, or None until computed
    cdef list[_DNodeProduct] sum_prod  # set to None once the derivative has been computed
    cdef bint processed

    def __init__(
            self,
            node: CircuitNode,
            derivative: Optional[CircuitNode],
    ):
        self.node = node
        self.derivative = derivative
        self.derivative_self_mul = None
        self.sum_prod = []
        self.processed = False

    def __str__(self) -> str:
        """
        for debugging
        """
        dots: str = '...'
        # `sum_prod` is released (set to None) after the derivative is computed.
        sum_prod_len = None if self.sum_prod is None else len(self.sum_prod)
        return (
            'DNode(' + str(self.node) + ', '
            + str(None if self.derivative is None else dots) + ', '
            + str(None if self.derivative_self_mul is None else dots) + ', '
            + str(sum_prod_len) + ', '
            + str(self.processed)
            + ')'
        )
584
+
cdef class _DNodeProduct:
    """
    A data structure supporting derivative calculations.

    Represents the product of the derivative at `parent` with every node in `prod`.
    """
    cdef _DNode parent
    cdef list[CircuitNode] prod

    def __init__(self, parent: _DNode, prod: List[CircuitNode]):
        self.parent = parent
        self.prod = prod

    def __str__(self) -> str:
        """
        for debugging
        """
        return f'DNodeProduct({self.parent!s}, {self.prod!s})'
603
+
604
+
cdef class _DerivativeHelper:
    """
    A data structure to support efficient calculation of partial derivatives
    with respect to some function node `f`.

    On construction, a `_DNode` is created for every node reachable from `f` via op
    arguments, recording how that node contributes to each of its parents. Derivative
    op nodes are only added to the circuit when `derivative` or `derivative_self_mul`
    is called, and are then memoised on the `_DNode`.
    """

    cdef CircuitNode f
    cdef CircuitNode zero
    cdef CircuitNode one
    cdef Circuit circuit
    cdef dict[int, _DNode] d_nodes

    def __init__(self, f: CircuitNode):
        """
        Prepare to calculate partial derivatives with respect to `f`.
        """
        self.f = f
        self.circuit = f.circuit
        self.d_nodes = {}  # map id(CircuitNode) to its DNode
        self.zero = self.circuit.zero
        self.one = self.circuit.one

        # df/df = 1
        cdef _DNode top_d_node = _DNode(f, self.one)
        self.d_nodes[id(f)] = top_d_node
        self._mk_derivative_r(top_d_node)

    cdef CircuitNode derivative(self, CircuitNode node):
        # Return df/d(node); zero if `node` is not reachable from `f`.
        d_node: Optional[_DNode] = self.d_nodes.get(id(node))
        if d_node is None:
            return self.zero
        else:
            return self._derivative(d_node)

    cdef CircuitNode derivative_self_mul(self, CircuitNode node):
        # Return df/d(node) * node (memoised); zero if `node` is not reachable from `f`.
        d_node: Optional[_DNode] = self.d_nodes.get(id(node))
        if d_node is None:
            return self.zero

        if d_node.derivative_self_mul is None:
            d: CircuitNode = self._derivative(d_node)
            if d is self.zero:
                d_node.derivative_self_mul = self.zero
            elif d is self.one:
                d_node.derivative_self_mul = node
            else:
                d_node.derivative_self_mul = self.circuit.optimised_mul(d, node)

        return d_node.derivative_self_mul

    cdef CircuitNode _derivative(self, _DNode d_node):
        # Return (constructing and memoising if needed) the derivative at `d_node`:
        # the sum over all recorded parent products.
        if d_node.derivative is not None:
            return d_node.derivative

        # Get the list of circuit nodes that must be added together.
        to_add: Sequence[CircuitNode] = tuple(
            value
            for value in (self._derivative_prod(prods) for prods in d_node.sum_prod)
            if not value.is_zero
        )
        # we can release the temporary memory at this DNode now
        d_node.sum_prod = None

        # Construct the addition operation
        d_node.derivative = self.circuit.optimised_add(*to_add)

        return d_node.derivative

    cdef CircuitNode _derivative_prod(self, _DNodeProduct prods):
        # Support `_derivative` by constructing the derivative product for the given _DNodeProduct.

        # Get the derivative of the parent node.
        parent: CircuitNode = self._derivative(prods.parent)

        # Multiply the parent derivative with all other nodes recorded at prod.
        to_mul: List[CircuitNode] = []
        for arg in chain((parent,), prods.prod):
            if arg is self.zero:
                # Multiplication by zero is zero
                return self.zero
            if arg is not self.one:
                to_mul.append(arg)

        # Construct the multiplication operation
        return self.circuit.optimised_mul(*to_mul)

    cdef void _mk_derivative_r(self, _DNode d_node):
        # Construct a DNode for each argument of the given DNode, recursively.
        #
        # For an add node, each argument contributes the parent derivative alone.
        # For a mul node, each argument contributes the parent derivative times the
        # *other* arguments. Other arguments are identified by position (not identity)
        # so that a repeated argument, e.g. mul(x, x), gets derivative 2x rather than 2.

        if d_node.processed:
            return
        d_node.processed = True
        node: CircuitNode = d_node.node

        if isinstance(node, OpNode):
            if node.symbol == c_ADD:
                for arg in node.args:
                    child_d_node = self._add(arg, d_node, [])
                    self._mk_derivative_r(child_d_node)
            elif node.symbol == c_MUL:
                args = node.args
                for i, arg in enumerate(args):
                    prod = [arg2 for j, arg2 in enumerate(args) if j != i]
                    child_d_node = self._add(arg, d_node, prod)
                    self._mk_derivative_r(child_d_node)

    cdef _DNode _add(self, CircuitNode node, _DNode parent, list[CircuitNode] prod):
        # Support for `_mk_derivative_r`.
        #
        # Add a _DNodeProduct(parent, prod) to the DNode for the given circuit node.
        #
        # If the DNode for `node` does not yet exist, one will be created.
        #
        # The given circuit node may have multiple parents (i.e., a shared sub-expression). Therefore,
        # this method may be called multiple times for a given node. Each time a new _DNodeProduct will be added.
        #
        # Args:
        #     node: the CircuitNode that the returned DNode is for.
        #     parent: the DNode of the parent node, i.e., `node` is an argument to the parent node.
        #     prod: other circuit nodes that need to be multiplied with the parent derivative when
        #         constructing a derivative for `node`.
        #
        # Returns:
        #     the DNode for `node`.

        child_d_node: _DNode = self._get(node)
        child_d_node.sum_prod.append(_DNodeProduct(parent, prod))
        return child_d_node

    cdef _DNode _get(self, CircuitNode node):
        # Helper for derivatives.
        #
        # Get the DNode for the given circuit node.
        # If no DNode exist for it yet, then one will be constructed.

        node_id: int = id(node)
        d_node: Optional[_DNode] = self.d_nodes.get(node_id)
        if d_node is None:
            d_node = _DNode(node, None)
            self.d_nodes[node_id] = d_node
        return d_node
744
+
745
+
cdef void find_reachable_op_nodes_r(CircuitNode node, set[int] seen, list[OpNode] result):
    # Depth-first, post-order walk used by `reachable_op_nodes`.
    #
    # Args:
    #     node: the node currently being visited.
    #     seen: ids of op nodes already visited (so no node is returned twice).
    #     result: receives each op node only after all of its arguments.
    if not isinstance(node, OpNode):
        return
    node_id = id(node)
    if node_id in seen:
        return
    seen.add(node_id)
    for child in node.args:
        find_reachable_op_nodes_r(child, seen, result)
    result.append(node)
758
+
cdef void find_reachable_op_nodes_seen_r(CircuitNode node, set[int] seen):
    # Depth-first walk used by `remove_unreachable_op_nodes`; it only records
    # the ids of reachable op nodes.
    #
    # Args:
    #     node: the node currently being visited.
    #     seen: accumulates the ids of visited op nodes.
    if not isinstance(node, OpNode):
        return
    node_id = id(node)
    if node_id in seen:
        return
    seen.add(node_id)
    for child in node.args:
        find_reachable_op_nodes_seen_r(child, seen)