compiled-knowledge 4.0.0a20__cp312-cp312-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +17 -0
- ck/circuit/_circuit_cy.c +37523 -0
- ck/circuit/_circuit_cy.cp312-win32.pyd +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +768 -0
- ck/circuit/_circuit_py.py +836 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +19824 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win32.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +380 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +121 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +10618 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win32.pyd +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_py.py +93 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53747 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +322 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3475 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +79 -0
- ck/pgm_circuit/program_with_slotmap.py +236 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +299 -0
- ck/pgm_compiler/factor_elimination.py +395 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +231 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +17 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +16396 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win32.pyd +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +304 -0
- ck/pgm_compiler/support/clusters.py +568 -0
- ck/pgm_compiler/support/factor_tables.py +406 -0
- ck/pgm_compiler/support/join_tree.py +332 -0
- ck/pgm_compiler/support/named_compiler_maker.py +43 -0
- ck/pgm_compiler/variable_elimination.py +91 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +50 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +622 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +137 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +67 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +72 -0
- ck/sampling/wmc_direct_sampler.py +171 -0
- ck/sampling/wmc_gibbs_sampler.py +153 -0
- ck/sampling/wmc_metropolis_sampler.py +165 -0
- ck/sampling/wmc_rejection_sampler.py +115 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +163 -0
- ck/utils/local_config.py +270 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +49 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +63 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +60 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +120 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a20.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a20.dist-info/RECORD +178 -0
- compiled_knowledge-4.0.0a20.dist-info/WHEEL +5 -0
- compiled_knowledge-4.0.0a20.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a20.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,836 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This is a pure Python implementation of Circuits (for testing and development)
|
|
3
|
+
|
|
4
|
+
For more documentation on this module, refer to the Jupyter notebook docs/6_circuits_and_programs.ipynb.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from itertools import chain
|
|
11
|
+
from typing import Dict, Tuple, Optional, Iterable, Sequence, List, overload, Iterator, Set
|
|
12
|
+
|
|
13
|
+
# Type for values of ConstNode objects
|
|
14
|
+
ConstValue = float | int | bool
|
|
15
|
+
|
|
16
|
+
# Symbols for op nodes
|
|
17
|
+
ADD: int = 0
|
|
18
|
+
MUL: int = 1
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Circuit:
|
|
22
|
+
"""
|
|
23
|
+
An arithmetic circuit defines an arithmetic function from input variables (`VarNode` objects)
|
|
24
|
+
and constant values (`ConstNode` objects) to one or more result values. Computation is defined
|
|
25
|
+
over a mathematical ring, with two operations: addition and multiplication (represented
|
|
26
|
+
by `OpNode` objects).
|
|
27
|
+
|
|
28
|
+
An arithmetic circuit needs to be compiled to a program to execute the function.
|
|
29
|
+
|
|
30
|
+
All nodes belong to a circuit. All nodes are immutable, with the exception that a
|
|
31
|
+
`VarNode` may be temporarily be set to a constant value.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
def __init__(self, zero: ConstValue = 0, one: ConstValue = 1):
|
|
35
|
+
"""
|
|
36
|
+
Construct a new, empty circuit.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
zero: The constant value for zero. mul(x, zero) = zero, add(x, zero) = x.
|
|
40
|
+
one: The constant value for one. mul(x, one) = x.
|
|
41
|
+
"""
|
|
42
|
+
self._vars: List[VarNode] = []
|
|
43
|
+
self._ops: List[OpNode] = []
|
|
44
|
+
self._const_map: Dict[ConstValue, ConstNode] = {}
|
|
45
|
+
self.__derivatives: Optional[_DerivativeHelper] = None # cache for partial derivatives calculations.
|
|
46
|
+
self._zero: ConstNode = self.const(zero)
|
|
47
|
+
self._one: ConstNode = self.const(one)
|
|
48
|
+
|
|
49
|
+
@property
|
|
50
|
+
def number_of_vars(self) -> int:
|
|
51
|
+
"""
|
|
52
|
+
Returns:
|
|
53
|
+
the number of "var" nodes.
|
|
54
|
+
"""
|
|
55
|
+
return len(self._vars)
|
|
56
|
+
|
|
57
|
+
@property
|
|
58
|
+
def number_of_consts(self) -> int:
|
|
59
|
+
"""
|
|
60
|
+
Returns:
|
|
61
|
+
the number of "const" nodes.
|
|
62
|
+
"""
|
|
63
|
+
return len(self._const_map)
|
|
64
|
+
|
|
65
|
+
@property
|
|
66
|
+
def number_of_op_nodes(self) -> int:
|
|
67
|
+
"""
|
|
68
|
+
Returns:
|
|
69
|
+
the number of "op" nodes.
|
|
70
|
+
"""
|
|
71
|
+
return len(self._ops)
|
|
72
|
+
|
|
73
|
+
@property
|
|
74
|
+
def number_of_arcs(self) -> int:
|
|
75
|
+
"""
|
|
76
|
+
Returns:
|
|
77
|
+
the number of arcs in the circuit, i.e., the sum of the
|
|
78
|
+
number of arguments for all op nodes.
|
|
79
|
+
"""
|
|
80
|
+
return sum(len(op.args) for op in self._ops)
|
|
81
|
+
|
|
82
|
+
@property
|
|
83
|
+
def number_of_operations(self):
|
|
84
|
+
"""
|
|
85
|
+
How many op nodes are in the circuit.
|
|
86
|
+
This is number_of_arcs - number_of_op_nodes.
|
|
87
|
+
"""
|
|
88
|
+
return self.number_of_arcs - self.number_of_op_nodes
|
|
89
|
+
|
|
90
|
+
@property
|
|
91
|
+
def vars(self) -> Sequence[VarNode]:
|
|
92
|
+
"""
|
|
93
|
+
Returns:
|
|
94
|
+
the var nodes, in index order.
|
|
95
|
+
|
|
96
|
+
Ensures:
|
|
97
|
+
`self.vars[i].idx == i`.
|
|
98
|
+
"""
|
|
99
|
+
return self._vars
|
|
100
|
+
|
|
101
|
+
@property
|
|
102
|
+
def ops(self) -> Sequence[OpNode]:
|
|
103
|
+
"""
|
|
104
|
+
Returns:
|
|
105
|
+
the op nodes, in the order they were added to this circuit.
|
|
106
|
+
"""
|
|
107
|
+
return self._ops
|
|
108
|
+
|
|
109
|
+
@property
|
|
110
|
+
def zero(self) -> ConstNode:
|
|
111
|
+
"""
|
|
112
|
+
Get the constant representing zero.
|
|
113
|
+
"""
|
|
114
|
+
return self._zero
|
|
115
|
+
|
|
116
|
+
@property
|
|
117
|
+
def one(self) -> ConstNode:
|
|
118
|
+
"""
|
|
119
|
+
Get the constant representing one.
|
|
120
|
+
"""
|
|
121
|
+
return self._one
|
|
122
|
+
|
|
123
|
+
def new_var(self) -> VarNode:
|
|
124
|
+
"""
|
|
125
|
+
Create and return a new variable node.
|
|
126
|
+
"""
|
|
127
|
+
node = VarNode(self, len(self._vars))
|
|
128
|
+
self._vars.append(node)
|
|
129
|
+
return node
|
|
130
|
+
|
|
131
|
+
def new_vars(self, num_of_vars: int) -> Sequence[VarNode]:
|
|
132
|
+
"""
|
|
133
|
+
Create and return multiple variable nodes.
|
|
134
|
+
"""
|
|
135
|
+
offset = self.number_of_vars
|
|
136
|
+
new_vars = tuple(VarNode(self, i) for i in range(offset, offset + num_of_vars))
|
|
137
|
+
self._vars.extend(new_vars)
|
|
138
|
+
return new_vars
|
|
139
|
+
|
|
140
|
+
def const(self, value: ConstValue | ConstNode) -> ConstNode:
|
|
141
|
+
"""
|
|
142
|
+
Return a const node for the given value.
|
|
143
|
+
If a const node for that value already exists, then it will be returned,
|
|
144
|
+
otherwise a new const node will be created.
|
|
145
|
+
"""
|
|
146
|
+
if isinstance(value, ConstNode):
|
|
147
|
+
value = value.value
|
|
148
|
+
|
|
149
|
+
node = self._const_map.get(value)
|
|
150
|
+
if node is None:
|
|
151
|
+
node = ConstNode(self, value)
|
|
152
|
+
self._const_map[value] = node
|
|
153
|
+
return node
|
|
154
|
+
|
|
155
|
+
def _op(self, symbol: int, args: Tuple[CircuitNode, ...]) -> OpNode:
|
|
156
|
+
"""
|
|
157
|
+
Create and return a new op node, applied to the given arguments.
|
|
158
|
+
"""
|
|
159
|
+
node = OpNode(self, symbol, args)
|
|
160
|
+
self._ops.append(node)
|
|
161
|
+
return node
|
|
162
|
+
|
|
163
|
+
def add(self, *args: Args) -> OpNode:
|
|
164
|
+
"""
|
|
165
|
+
Create and return a new 'addition' node, applied to the given arguments.
|
|
166
|
+
"""
|
|
167
|
+
return self._op(ADD, self._check_nodes(args))
|
|
168
|
+
|
|
169
|
+
def mul(self, *args: Args) -> OpNode:
|
|
170
|
+
"""
|
|
171
|
+
Create and return a new 'multiplication' node, applied to the given arguments.
|
|
172
|
+
"""
|
|
173
|
+
return self._op(MUL, self._check_nodes(args))
|
|
174
|
+
|
|
175
|
+
def optimised_add(self, *args: Args) -> CircuitNode:
|
|
176
|
+
"""
|
|
177
|
+
Optimised circuit node addition.
|
|
178
|
+
|
|
179
|
+
Performs the following optimisations:
|
|
180
|
+
* addition to zero is avoided: add(x, 0) = x,
|
|
181
|
+
* singleton addition is avoided: add(x) = x,
|
|
182
|
+
* empty addition is avoided: add() = 0,
|
|
183
|
+
"""
|
|
184
|
+
to_add = tuple(n for n in self._check_nodes(args) if not n.is_zero)
|
|
185
|
+
match len(to_add):
|
|
186
|
+
case 0:
|
|
187
|
+
return self.zero
|
|
188
|
+
case 1:
|
|
189
|
+
return to_add[0]
|
|
190
|
+
case _:
|
|
191
|
+
return self._op(ADD, to_add)
|
|
192
|
+
|
|
193
|
+
def optimised_mul(self, *args: Args) -> CircuitNode:
|
|
194
|
+
"""
|
|
195
|
+
Optimised circuit node multiplication.
|
|
196
|
+
|
|
197
|
+
Performs the following optimisations:
|
|
198
|
+
* multiplication by zero is avoided: mul(x, 0) = 0,
|
|
199
|
+
* multiplication by one is avoided: mul(x, 1) = x,
|
|
200
|
+
* singleton multiplication is avoided: mul(x) = x,
|
|
201
|
+
* empty multiplication is avoided: mul() = 1,
|
|
202
|
+
"""
|
|
203
|
+
to_mul = tuple(n for n in self._check_nodes(args) if not n.is_one)
|
|
204
|
+
if any(n.is_zero for n in to_mul):
|
|
205
|
+
return self.zero
|
|
206
|
+
match len(to_mul):
|
|
207
|
+
case 0:
|
|
208
|
+
return self.one
|
|
209
|
+
case 1:
|
|
210
|
+
return to_mul[0]
|
|
211
|
+
case _:
|
|
212
|
+
return self._op(MUL, to_mul)
|
|
213
|
+
|
|
214
|
+
def cartesian_product(self, xs: Sequence[CircuitNode], ys: Sequence[CircuitNode]) -> List[CircuitNode]:
|
|
215
|
+
"""
|
|
216
|
+
Add multiply operations, one for each possible combination of x from xs and y from ys.
|
|
217
|
+
|
|
218
|
+
Args:
|
|
219
|
+
xs: first list of circuit nodes, may be either a Node object or a list of Nodes.
|
|
220
|
+
ys: second list of circuit nodes, may be either a Node object or a list of Nodes.
|
|
221
|
+
|
|
222
|
+
Returns:
|
|
223
|
+
a list of 'mul' nodes, one for each combination of xs and ys. The results are in the order
|
|
224
|
+
given by `[mul(x, y) for x in xs for y in ys]`.
|
|
225
|
+
"""
|
|
226
|
+
xs: Sequence[CircuitNode] = self._check_nodes(xs)
|
|
227
|
+
ys: Sequence[CircuitNode] = self._check_nodes(ys)
|
|
228
|
+
return [
|
|
229
|
+
self.optimised_mul(x, y)
|
|
230
|
+
for x in xs
|
|
231
|
+
for y in ys
|
|
232
|
+
]
|
|
233
|
+
|
|
234
|
+
@overload
|
|
235
|
+
def partial_derivatives(
|
|
236
|
+
self,
|
|
237
|
+
f: CircuitNode,
|
|
238
|
+
args: Sequence[CircuitNode],
|
|
239
|
+
*,
|
|
240
|
+
self_multiply: bool = False,
|
|
241
|
+
) -> List[CircuitNode]:
|
|
242
|
+
...
|
|
243
|
+
|
|
244
|
+
@overload
|
|
245
|
+
def partial_derivatives(
|
|
246
|
+
self,
|
|
247
|
+
f: CircuitNode,
|
|
248
|
+
args: CircuitNode,
|
|
249
|
+
*,
|
|
250
|
+
self_multiply: bool = False,
|
|
251
|
+
) -> CircuitNode:
|
|
252
|
+
...
|
|
253
|
+
|
|
254
|
+
def partial_derivatives(
|
|
255
|
+
self,
|
|
256
|
+
f: CircuitNode,
|
|
257
|
+
args,
|
|
258
|
+
*,
|
|
259
|
+
self_multiply: bool = False,
|
|
260
|
+
):
|
|
261
|
+
"""
|
|
262
|
+
Add to the circuit the operations required to calculate the partial derivative of f
|
|
263
|
+
with respect to each of the given nodes. If self_multiple is True, then this is
|
|
264
|
+
equivalent to calculating the marginal probability at each var that represents
|
|
265
|
+
an indicator.
|
|
266
|
+
|
|
267
|
+
This method will cache partial derivative calculations for `f` so that subsequent calls
|
|
268
|
+
to this method with the same `f` will not cause duplicated calculations to be added to
|
|
269
|
+
the circuit. To release this cache, call `self.release_derivatives_cache()`.
|
|
270
|
+
|
|
271
|
+
Args:
|
|
272
|
+
f: is the circuit node that defines the function to differentiate.
|
|
273
|
+
args: nodes that are the arguments to f (typically VarNode objects).
|
|
274
|
+
The value may be either a CircuitNode object or a list of CircuitNode objects.
|
|
275
|
+
self_multiply: if true then each partial derivative df/dx will be multiplied by x.
|
|
276
|
+
|
|
277
|
+
Returns:
|
|
278
|
+
the results nodes for the partial derivatives, co-indexed with the given arg nodes.
|
|
279
|
+
If `args` is a single CircuitNode, then a single CircuitNode will be returned, otherwise
|
|
280
|
+
a list of CircuitNode is returned.
|
|
281
|
+
"""
|
|
282
|
+
single_result: bool = isinstance(args, CircuitNode)
|
|
283
|
+
|
|
284
|
+
args: Sequence[CircuitNode] = self._check_nodes([args])
|
|
285
|
+
if len(args) == 0:
|
|
286
|
+
# Trivial case
|
|
287
|
+
return []
|
|
288
|
+
|
|
289
|
+
derivatives: _DerivativeHelper = self._derivatives(f)
|
|
290
|
+
result: List[CircuitNode]
|
|
291
|
+
if self_multiply:
|
|
292
|
+
result = [
|
|
293
|
+
derivatives.derivative_self_mul(arg)
|
|
294
|
+
for arg in args
|
|
295
|
+
]
|
|
296
|
+
else:
|
|
297
|
+
result = [
|
|
298
|
+
derivatives.derivative(arg)
|
|
299
|
+
for arg in args
|
|
300
|
+
]
|
|
301
|
+
|
|
302
|
+
if single_result:
|
|
303
|
+
return result[0]
|
|
304
|
+
else:
|
|
305
|
+
return result
|
|
306
|
+
|
|
307
|
+
def remove_derivatives_cache(self) -> None:
|
|
308
|
+
"""
|
|
309
|
+
Release the memory held for partial derivative calculations, as per `partial_derivatives`.
|
|
310
|
+
"""
|
|
311
|
+
self.__derivatives = None
|
|
312
|
+
|
|
313
|
+
def remove_unreachable_op_nodes(self, *nodes: Args) -> None:
|
|
314
|
+
"""
|
|
315
|
+
Find all op nodes reachable from the given nodes, via op arguments.
|
|
316
|
+
(using `self.reachable_op_nodes`). Remove all other op nodes from this circuit.
|
|
317
|
+
|
|
318
|
+
If any external object holds a reference to a removed node, that node will be unusable.
|
|
319
|
+
|
|
320
|
+
Args:
|
|
321
|
+
*nodes: may be either a node or a list of nodes.
|
|
322
|
+
"""
|
|
323
|
+
seen: Set[int] = set() # set of object ids for all reachable op nodes.
|
|
324
|
+
for node in self._check_nodes(nodes):
|
|
325
|
+
_reachable_op_nodes_seen_r(node, seen)
|
|
326
|
+
|
|
327
|
+
if len(seen) < len(self._ops):
|
|
328
|
+
# Invalidate unreadable op nodes
|
|
329
|
+
for op_node in self._ops:
|
|
330
|
+
if id(op_node) not in seen:
|
|
331
|
+
op_node.circuit = None
|
|
332
|
+
op_node.args = ()
|
|
333
|
+
|
|
334
|
+
# Keep only reachable op nodes
|
|
335
|
+
self._ops = tuple(op_node for op_node in self._ops if id(op_node) in seen)
|
|
336
|
+
|
|
337
|
+
def reachable_op_nodes(self, *nodes: Args) -> List[OpNode]:
|
|
338
|
+
"""
|
|
339
|
+
Iterate over all op nodes reachable from the given nodes, via op arguments.
|
|
340
|
+
|
|
341
|
+
Args:
|
|
342
|
+
*nodes: may be either a node or a list of nodes.
|
|
343
|
+
|
|
344
|
+
Returns:
|
|
345
|
+
An list of all op nodes reachable from the given nodes.
|
|
346
|
+
|
|
347
|
+
Ensures:
|
|
348
|
+
Returned nodes are not repeated.
|
|
349
|
+
The result is ordered such that if result[i] is referenced by result[j] then i < j.
|
|
350
|
+
"""
|
|
351
|
+
seen: Set[int] = set() # set of object ids for all reachable op nodes.
|
|
352
|
+
return [
|
|
353
|
+
reachable
|
|
354
|
+
for node in self._check_nodes(nodes)
|
|
355
|
+
for reachable in _reachable_op_nodes_r(node, seen)
|
|
356
|
+
]
|
|
357
|
+
|
|
358
|
+
def dump(
|
|
359
|
+
self,
|
|
360
|
+
*,
|
|
361
|
+
prefix: str = '',
|
|
362
|
+
indent: str = ' ',
|
|
363
|
+
var_names: Optional[List[str]] = None,
|
|
364
|
+
include_consts: bool = False,
|
|
365
|
+
) -> None:
|
|
366
|
+
"""
|
|
367
|
+
Print a dump of the Circuit.
|
|
368
|
+
This is intended for debugging and demonstration purposes.
|
|
369
|
+
|
|
370
|
+
Args:
|
|
371
|
+
prefix: optional prefix for indenting all lines.
|
|
372
|
+
indent: additional prefix to use for extra indentation.
|
|
373
|
+
var_names: optional variable names to show.
|
|
374
|
+
include_consts: if true, then constant values are dumped.
|
|
375
|
+
"""
|
|
376
|
+
|
|
377
|
+
next_prefix: str = prefix + indent
|
|
378
|
+
|
|
379
|
+
node_name: Dict[int, str] = {}
|
|
380
|
+
|
|
381
|
+
print(f'{prefix}number of vars: {self.number_of_vars:,}')
|
|
382
|
+
print(f'{prefix}number of const nodes: {self.number_of_consts:,}')
|
|
383
|
+
print(f'{prefix}number of op nodes: {self.number_of_op_nodes:,}')
|
|
384
|
+
print(f'{prefix}number of operations: {self.number_of_operations:,}')
|
|
385
|
+
print(f'{prefix}number of arcs: {self.number_of_arcs:,}')
|
|
386
|
+
|
|
387
|
+
print(f'{prefix}var nodes: {self.number_of_vars}')
|
|
388
|
+
for var in self.vars:
|
|
389
|
+
node_name[id(var)] = f'var[{var.idx}]'
|
|
390
|
+
var_name: str = '' if var_names is None or var.idx >= len(var_names) else var_names[var.idx]
|
|
391
|
+
if var_name != '':
|
|
392
|
+
if var.is_const():
|
|
393
|
+
print(f'{next_prefix}var[{var.idx}]: {var_name}, {var.const.value}')
|
|
394
|
+
else:
|
|
395
|
+
print(f'{next_prefix}var[{var.idx}]: {var_name}')
|
|
396
|
+
elif var.is_const():
|
|
397
|
+
print(f'{next_prefix}var[{var.idx}]: {var.const.value}')
|
|
398
|
+
|
|
399
|
+
if include_consts:
|
|
400
|
+
print(f'{prefix}const nodes: {self.number_of_consts}')
|
|
401
|
+
for const in self._const_map.values():
|
|
402
|
+
print(f'{next_prefix}{const.value!r}')
|
|
403
|
+
|
|
404
|
+
# Add const nodes to the node_name dict
|
|
405
|
+
for const in self._const_map.values():
|
|
406
|
+
node_name[id(const)] = repr(const.value)
|
|
407
|
+
|
|
408
|
+
# Add op nodes to the node_name dict
|
|
409
|
+
for i, op in enumerate(self.ops):
|
|
410
|
+
node_name[id(op)] = f'{op.op_str()}<{i}>'
|
|
411
|
+
|
|
412
|
+
print(
|
|
413
|
+
f'{prefix}op nodes: {self.number_of_op_nodes} '
|
|
414
|
+
f'(arcs: {self.number_of_arcs}, ops: {self.number_of_operations})'
|
|
415
|
+
)
|
|
416
|
+
for op in reversed(self.ops):
|
|
417
|
+
op_name = node_name[id(op)]
|
|
418
|
+
args_str = ' '.join(node_name[id(arg)] for arg in op.args)
|
|
419
|
+
print(f'{next_prefix}{op_name}: {args_str}')
|
|
420
|
+
|
|
421
|
+
def _check_nodes(self, nodes: Iterable[Args]) -> Tuple[CircuitNode, ...]:
|
|
422
|
+
"""
|
|
423
|
+
Convert the given circuit nodes to a tuple, flattening nested iterables as needed.
|
|
424
|
+
|
|
425
|
+
Args:
|
|
426
|
+
nodes: some circuit nodes of constant values.
|
|
427
|
+
|
|
428
|
+
Raises:
|
|
429
|
+
RuntimeError: if any node does not belong to this circuit.
|
|
430
|
+
"""
|
|
431
|
+
return tuple(self.__check_nodes(nodes))
|
|
432
|
+
|
|
433
|
+
def __check_nodes(self, nodes: Iterable[Args]) -> Iterable[CircuitNode]:
|
|
434
|
+
"""
|
|
435
|
+
Convert the given circuit nodes to a tuple, flattening nested iterables as needed.
|
|
436
|
+
|
|
437
|
+
Args:
|
|
438
|
+
nodes: some circuit nodes of constant values.
|
|
439
|
+
|
|
440
|
+
Raises:
|
|
441
|
+
RuntimeError: if any node does not belong to this circuit.
|
|
442
|
+
"""
|
|
443
|
+
|
|
444
|
+
for node in nodes:
|
|
445
|
+
if isinstance(node, CircuitNode):
|
|
446
|
+
if node.circuit is not self:
|
|
447
|
+
raise RuntimeError('node does not belong to this circuit')
|
|
448
|
+
else:
|
|
449
|
+
yield node
|
|
450
|
+
elif isinstance(node, ConstValue):
|
|
451
|
+
yield self.const(node)
|
|
452
|
+
else:
|
|
453
|
+
# Assume it's iterable
|
|
454
|
+
for n in self._check_nodes(node):
|
|
455
|
+
yield n
|
|
456
|
+
|
|
457
|
+
def _derivatives(self, f: CircuitNode) -> _DerivativeHelper:
|
|
458
|
+
"""
|
|
459
|
+
Get a _DerivativeHelper for `f`.
|
|
460
|
+
Checking the derivative cache.
|
|
461
|
+
"""
|
|
462
|
+
derivatives: Optional[_DerivativeHelper] = self.__derivatives
|
|
463
|
+
if derivatives is None or derivatives.f is not f:
|
|
464
|
+
derivatives = _DerivativeHelper(f)
|
|
465
|
+
self.__derivatives = derivatives
|
|
466
|
+
return derivatives
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
class CircuitNode:
    """
    A node in an arithmetic circuit.
    Each node is either an op, var, or const node.

    Each op node is either a mul or an add node. Each op
    node has zero or more arguments. Each argument is another node.

    Every var node has an index, `idx`, which is an integer counting from zero, and denotes
    its creation order.

    A var node may be temporarily set to be a constant node, which may
    be useful for optimising a compiled circuit.

    Nodes support `+` and `*` (including with a constant value on the left, e.g. `2 * x`),
    which create new op nodes via `Circuit.add` / `Circuit.mul`.
    """
    __slots__ = ('circuit',)

    def __init__(self, circuit: Circuit):
        self.circuit = circuit

    @property
    def is_zero(self) -> bool:
        """
        Does this node represent the constant zero.
        The base implementation returns False; subclasses that can represent a constant override it.
        """
        return False

    @property
    def is_one(self) -> bool:
        """
        Does this node represent the constant one.
        The base implementation returns False; subclasses that can represent a constant override it.
        """
        return False

    def __add__(self, other: CircuitNode | ConstValue) -> OpNode:
        """Create a new add node for `self + other`."""
        return self.circuit.add(self, other)

    def __radd__(self, other: CircuitNode | ConstValue) -> OpNode:
        """Create a new add node for `other + self` (e.g., `1 + node`, or `sum(nodes)`)."""
        # Argument order is preserved so the resulting op node reads left-to-right.
        return self.circuit.add(other, self)

    def __mul__(self, other: CircuitNode | ConstValue) -> OpNode:
        """Create a new mul node for `self * other`."""
        return self.circuit.mul(self, other)

    def __rmul__(self, other: CircuitNode | ConstValue) -> OpNode:
        """Create a new mul node for `other * self` (e.g., `2 * node`)."""
        return self.circuit.mul(other, self)
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
# Flexible argument type accepted by the circuit-building methods: a single node,
# a single constant value, or an iterable of nodes and/or constant values.
Args = (
    CircuitNode
    | ConstValue
    | Iterable[CircuitNode | ConstValue]
)
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
class ConstNode(CircuitNode):
    """
    A circuit node holding a fixed (constant) value.
    """
    __slots__ = ('_value',)

    def __init__(self, circuit: Circuit, value: ConstValue):
        super().__init__(circuit)
        self._value: ConstValue = value

    @property
    def value(self) -> ConstValue:
        """The constant value held by this node."""
        return self._value

    @property
    def is_zero(self) -> bool:
        """True only if this is the very node its circuit uses for zero."""
        circuit_zero = self.circuit.zero
        return circuit_zero is self

    @property
    def is_one(self) -> bool:
        """True only if this is the very node its circuit uses for one."""
        circuit_one = self.circuit.one
        return circuit_one is self

    def __str__(self) -> str:
        return f'const({self._value})'

    def __lt__(self, other) -> bool:
        # Only const nodes are ordered (by value); anything else compares as not-less.
        return isinstance(other, ConstNode) and self._value < other._value
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
class VarNode(CircuitNode):
    """
    A circuit node for an input variable.

    Each var node carries its creation index `idx`. It may optionally be bound to a
    const node of its circuit (via the `const` property), in which case `is_zero` /
    `is_one` reflect that constant.
    """
    __slots__ = ('_idx', '_const')

    def __init__(self, circuit: Circuit, idx: int):
        super().__init__(circuit)
        self._idx: int = idx
        self._const: Optional[ConstNode] = None

    @property
    def idx(self) -> int:
        """The index of this var within its circuit."""
        return self._idx

    def is_const(self) -> bool:
        """True if this var is currently bound to a constant."""
        return not (self._const is None)

    @property
    def const(self) -> Optional[ConstNode]:
        """The const node this var is bound to, or None if unbound."""
        return self._const

    @const.setter
    def const(self, value: ConstValue | ConstNode | None) -> None:
        # Values go through `Circuit.const`, so the stored node always belongs to
        # this var's circuit; None unbinds the var.
        self._const = None if value is None else self.circuit.const(value)

    @property
    def is_zero(self) -> bool:
        bound = self._const
        return False if bound is None else bound.is_zero

    @property
    def is_one(self) -> bool:
        bound = self._const
        return False if bound is None else bound.is_one

    def __lt__(self, other) -> bool:
        # Only var nodes are ordered (by index); anything else compares as not-less.
        return isinstance(other, VarNode) and self._idx < other.idx

    def __str__(self) -> str:
        label = f'var[{self.idx}]'
        bound = self._const
        return label if bound is None else f'{label} = {bound.value}'
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
class OpNode(CircuitNode):
    """
    A circuit node applying an arithmetic operation, identified by `symbol`,
    to a tuple of argument nodes, `args`.
    """
    __slots__ = ('args', 'symbol')

    def __init__(self, circuit: Circuit, symbol: int, args: Tuple[CircuitNode, ...]):
        super().__init__(circuit)
        self.symbol: int = symbol
        self.args: Tuple[CircuitNode, ...] = args

    def __str__(self) -> str:
        arity = len(self.args)
        return f'{self.op_str()}\\{arity}'

    def op_str(self) -> str:
        """
        Returns:
            'mul' or 'add' for the known symbols, otherwise '?' followed by the symbol.
        """
        symbol = self.symbol
        if symbol == MUL:
            return 'mul'
        if symbol == ADD:
            return 'add'
        return '?' + str(symbol)
|
|
621
|
+
|
|
622
|
+
|
|
623
|
+
@dataclass
class _DNode:
    """
    Book-keeping record used when building partial derivatives.

    Holds what is needed to form the partial derivative at `node`.
    """
    node: CircuitNode
    derivative: Optional[CircuitNode]
    derivative_self_mul: Optional[CircuitNode] = None
    sum_prod: List[_DNodeProduct] = field(default_factory=list)
    processed: bool = False

    def __str__(self) -> str:
        """
        Short description for debugging; derivative nodes are shown only as '...' or None.
        """
        def elide(value: Optional[CircuitNode]) -> str:
            return 'None' if value is None else '...'

        parts = (
            str(self.node),
            elide(self.derivative),
            elide(self.derivative_self_mul),
            str(len(self.sum_prod)),
            str(self.processed),
        )
        return 'DNode(' + ', '.join(parts) + ')'
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
@dataclass
class _DNodeProduct:
    """
    A data structure supporting derivative calculations.

    This represents a product of the derivative at `parent` and the circuit
    nodes listed in `prod`.
    """
    parent: _DNode  # DNode of the parent whose derivative is a factor
    prod: List[CircuitNode]  # the remaining factors of the product

    def __str__(self) -> str:
        """
        for debugging
        """
        return 'DNodeProduct({}, {})'.format(self.parent, self.prod)
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
class _DerivativeHelper:
    """
    A data structure to support efficient calculation of partial derivatives
    of some function node `f` with respect to the nodes below it.

    On construction, every node reachable from `f` through op-node arguments is
    given a `_DNode`. Each `_DNode` records the (parent, co-factors) products that
    contribute to its derivative, following the chain rule. Derivative circuit
    nodes are then constructed lazily, and cached, by `derivative` and
    `derivative_self_mul`.

    NOTE(review): `_mk_derivative_r` and `_derivative` are recursive, so very deep
    circuits may exceed Python's recursion limit.
    """

    def __init__(self, f: CircuitNode):
        """
        Prepare to calculate partial derivatives of `f`.

        Args:
            f: the function node. Its derivative with respect to itself is the
                circuit's one node.
        """
        self.f: CircuitNode = f
        self.circuit: Circuit = f.circuit
        self.d_nodes: Dict[int, _DNode] = {}  # map id(CircuitNode) to its DNode
        self.zero = self.circuit.zero
        self.one = self.circuit.one
        top_d_node: _DNode = _DNode(f, self.one)
        self.d_nodes[id(f)] = top_d_node
        self._mk_derivative_r(top_d_node)

    def derivative(self, node: CircuitNode) -> CircuitNode:
        """
        Get the partial derivative of `f` with respect to `node`.

        Returns:
            a circuit node for the derivative, or the circuit's zero node if
            `node` was not reached from `f`.
        """
        d_node: Optional[_DNode] = self.d_nodes.get(id(node))
        if d_node is None:
            return self.zero
        else:
            return self._derivative(d_node)

    def derivative_self_mul(self, node: CircuitNode) -> CircuitNode:
        """
        Get the partial derivative of `f` with respect to `node`, multiplied by `node`.

        The result is cached on the node's DNode. A zero derivative gives zero,
        and a derivative of one gives `node` itself, without building a multiplication.

        Returns:
            a circuit node, or the circuit's zero node if `node` was not reached from `f`.
        """
        d_node: Optional[_DNode] = self.d_nodes.get(id(node))
        if d_node is None:
            return self.zero

        if d_node.derivative_self_mul is None:
            d: CircuitNode = self._derivative(d_node)
            if d is self.zero:
                d_node.derivative_self_mul = self.zero
            elif d is self.one:
                d_node.derivative_self_mul = node
            else:
                d_node.derivative_self_mul = self.circuit.optimised_mul(d, node)

        return d_node.derivative_self_mul

    def _derivative(self, d_node: _DNode) -> CircuitNode:
        """
        Construct (or return the cached) derivative for the given DNode.

        The derivative is the sum of its non-zero `_DNodeProduct` terms. Once it
        has been built, the DNode's `sum_prod` list is released.
        """
        if d_node.derivative is not None:
            return d_node.derivative

        # Get the list of circuit nodes that must be added together.
        to_add: Sequence[CircuitNode] = tuple(
            value
            for value in (self._derivative_prod(prods) for prods in d_node.sum_prod)
            if not value.is_zero
        )
        # We can release the temporary memory at this DNode now
        # Warning disabled as we will never use this field again - doing so would be an error.
        # noinspection PyTypeChecker
        d_node.sum_prod = None

        # Construct the addition operation
        d_node.derivative = self.circuit.optimised_add(*to_add)

        return d_node.derivative

    def _derivative_prod(self, prods: _DNodeProduct) -> CircuitNode:
        """
        Support `_derivative` by constructing the derivative product for the given _DNodeProduct.

        Returns:
            the product of the parent's derivative and the nodes in `prods.prod`.
            Any factor that is the zero node makes the result zero, and factors
            that are the one node are omitted.
        """
        # Get the derivative of the parent node.
        parent: CircuitNode = self._derivative(prods.parent)

        # Multiply the parent derivative with all other nodes recorded at prod.
        to_mul: List[CircuitNode] = []
        for arg in chain((parent,), prods.prod):
            if arg is self.zero:
                # Multiplication by zero is zero
                return self.zero
            if arg is not self.one:
                to_mul.append(arg)

        # Construct the multiplication operation
        return self.circuit.optimised_mul(*to_mul)

    def _mk_derivative_r(self, d_node: _DNode) -> None:
        """
        Construct a DNode for each argument of the given DNode's circuit node, recursively.

        Each DNode is expanded at most once (guarded by `processed`). However, a
        new `_DNodeProduct` is recorded for every parent-to-argument edge.

        Raises:
            RuntimeError: if an op node has a symbol other than ADD or MUL.
        """
        if d_node.processed:
            return
        d_node.processed = True
        node: CircuitNode = d_node.node

        if isinstance(node, OpNode):
            if node.symbol == ADD:
                for arg in node.args:
                    child_d_node = self._add(arg, d_node, [])
                    self._mk_derivative_r(child_d_node)
            elif node.symbol == MUL:
                args = node.args
                for i, arg in enumerate(args):
                    # The co-factors are all *other positions*, not all other
                    # distinct nodes. A repeated argument (e.g., x * x) must keep
                    # its other occurrences: d(x * x)/dx = x + x, not 1 + 1.
                    prod = [*args[:i], *args[i + 1:]]
                    child_d_node = self._add(arg, d_node, prod)
                    self._mk_derivative_r(child_d_node)
            else:
                raise RuntimeError(f'unknown op node symbol: {node.symbol!r}')

    def _add(self, node: CircuitNode, parent: _DNode, prod: List[CircuitNode]) -> _DNode:
        """
        Support for `_mk_derivative_r`.

        Add a _DNodeProduct(parent, prod) to the DNode for the given circuit node.

        If the DNode for `node` does not yet exist, one will be created.

        The given circuit node may have multiple parents (i.e., a shared sub-expression). Therefore,
        this method may be called multiple times for a given node. Each time a new _DNodeProduct will be added.

        Args:
            node: the CircuitNode that the returned DNode is for.
            parent: the DNode of the parent node, i.e., `node` is an argument to the parent node.
            prod: other circuit nodes that need to be multiplied with the parent derivative when
                constructing a derivative for `node`.

        Returns:
            the DNode for `node`.
        """
        child_d_node: _DNode = self._get(node)
        child_d_node.sum_prod.append(_DNodeProduct(parent, prod))
        return child_d_node

    def _get(self, node: CircuitNode) -> _DNode:
        """
        Get the DNode for the given circuit node.

        If no DNode exists for it yet, then one will be constructed (with no
        derivative yet) and registered in `d_nodes`, keyed by `id(node)`.
        """
        node_id: int = id(node)
        d_node: Optional[_DNode] = self.d_nodes.get(node_id)
        if d_node is None:
            d_node = _DNode(node, None)
            self.d_nodes[node_id] = d_node
        return d_node
|
|
804
|
+
|
|
805
|
+
|
|
806
|
+
def _reachable_op_nodes_r(node: CircuitNode, seen: Set[int]) -> Iterator[OpNode]:
    """
    Helper for `reachable_op_nodes`. Performs a depth-first search.

    Op nodes are yielded in post-order: each op node is yielded after the unseen
    op nodes among its arguments. The search uses an explicit stack instead of
    nested recursive generators. As a result, deep circuits do not hit Python's
    recursion limit, and each yield is O(1) rather than O(depth). The `_r` suffix
    is kept so existing callers are unaffected.

    Args:
        node: the current node being checked.
        seen: keep track of seen op node ids (to avoid returning multiple of the same node).
            This set is updated in place as the iterator is consumed.

    Returns:
        An iterator over all op nodes reachable from the given node. Op nodes
        already in `seen`, and anything reachable only through them, are skipped.
    """
    if not isinstance(node, OpNode) or id(node) in seen:
        return
    seen.add(id(node))
    # Each entry is an op node together with an iterator over its not-yet-visited arguments.
    stack: List[Tuple[OpNode, Iterator[CircuitNode]]] = [(node, iter(node.args))]
    while stack:
        current, remaining_args = stack[-1]
        for arg in remaining_args:
            if isinstance(arg, OpNode) and id(arg) not in seen:
                # Descend into the first unseen op-node argument. The rest of
                # `remaining_args` is resumed once that sub-search completes.
                seen.add(id(arg))
                stack.append((arg, iter(arg.args)))
                break
        else:
            # All arguments of `current` are done, so emit it (post-order).
            stack.pop()
            yield current
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
def _reachable_op_nodes_seen_r(node: CircuitNode, seen: Set[int]) -> None:
    """
    Helper for `remove_unreachable_op_nodes`. Performs a depth-first search.

    An explicit stack is used instead of recursion, so deep circuits do not hit
    Python's recursion limit. The `_r` suffix is kept so existing callers are
    unaffected.

    Args:
        node: the current node being checked.
        seen: set of seen op node ids. On return it also contains the ids of all
            op nodes reachable from `node`, except those reachable only through
            op nodes that were already in `seen`.
    """
    stack: List[CircuitNode] = [node]
    while stack:
        current = stack.pop()
        if isinstance(current, OpNode) and id(current) not in seen:
            seen.add(id(current))
            stack.extend(current.args)
|