compiled-knowledge 4.0.0a17__cp312-cp312-win_amd64.whl → 4.0.0a19__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/__init__.py +4 -0
- ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
- ck/circuit/_circuit_cy.pxd +32 -0
- ck/circuit/_circuit_cy.pyx +157 -182
- ck/circuit/_circuit_py.py +2 -2
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +193 -79
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +29 -4
- ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win_amd64.pyd +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
- ck/circuit_compiler/support/{circuit_analyser.py → circuit_analyser/_circuit_analyser_py.py} +14 -2
- ck/pgm_compiler/ace/__init__.py +1 -1
- ck/pgm_compiler/support/circuit_table/__init__.py +8 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +44 -37
- ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +76 -41
- ck/pgm_compiler/support/named_compiler_maker.py +12 -2
- ck/utils/iter_extras.py +8 -1
- ck_demos/ace/demo_ace.py +5 -0
- ck_demos/utils/compare.py +5 -1
- {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a19.dist-info}/METADATA +1 -1
- {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a19.dist-info}/RECORD +26 -23
- {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a19.dist-info}/WHEEL +1 -1
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +0 -16946
- {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a19.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a19.dist-info}/top_level.txt +0 -0
ck/circuit/__init__.py
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
1
|
+
# Two implementations of the `circuit` module are provided
|
|
2
|
+
# for developer R&D purposes. One is pure Python and the other is Cython.
|
|
3
|
+
# Which implementation is used can be selected here.
|
|
4
|
+
|
|
1
5
|
# from ._circuit_py import (
|
|
2
6
|
from ._circuit_cy import (
|
|
3
7
|
Circuit,
|
|
Binary file
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
cdef class Circuit:
|
|
2
|
+
cdef public list[VarNode] vars
|
|
3
|
+
cdef public list[OpNode] ops
|
|
4
|
+
cdef public ConstNode zero
|
|
5
|
+
cdef public ConstNode one
|
|
6
|
+
cdef dict[object, ConstNode] _const_map
|
|
7
|
+
cdef object __derivatives
|
|
8
|
+
|
|
9
|
+
cdef OpNode op(self, int symbol, tuple[CircuitNode, ...] nodes)
|
|
10
|
+
cdef list[OpNode] find_reachable_op_nodes(self, list[CircuitNode] nodes)
|
|
11
|
+
cdef void _remove_unreachable_op_nodes(self, list[CircuitNode] nodes)
|
|
12
|
+
cdef list[CircuitNode] _check_nodes(self, object nodes)
|
|
13
|
+
cdef void __check_nodes(self, object nodes, list[CircuitNode] result)
|
|
14
|
+
cdef object _derivatives(self, CircuitNode f)
|
|
15
|
+
|
|
16
|
+
cdef class CircuitNode:
|
|
17
|
+
cdef public Circuit circuit
|
|
18
|
+
cdef public bint is_zero
|
|
19
|
+
cdef public bint is_one
|
|
20
|
+
|
|
21
|
+
cdef class ConstNode(CircuitNode):
|
|
22
|
+
cdef public object value
|
|
23
|
+
|
|
24
|
+
cdef class VarNode(CircuitNode):
|
|
25
|
+
cdef public int idx
|
|
26
|
+
cdef object _const
|
|
27
|
+
|
|
28
|
+
cpdef int is_const(self) except*
|
|
29
|
+
|
|
30
|
+
cdef class OpNode(CircuitNode):
|
|
31
|
+
cdef public tuple[object, ...] args
|
|
32
|
+
cdef public int symbol
|
ck/circuit/_circuit_cy.pyx
CHANGED
|
@@ -4,7 +4,7 @@ For more documentation on this module, refer to the Jupyter notebook docs/6_circ
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
6
|
from itertools import chain
|
|
7
|
-
from typing import Dict,
|
|
7
|
+
from typing import Dict, Optional, Iterable, Sequence, List, overload
|
|
8
8
|
|
|
9
9
|
# Type for values of ConstNode objects
|
|
10
10
|
ConstValue = float | int | bool
|
|
@@ -15,6 +15,8 @@ Args = CircuitNode | ConstValue | Iterable[CircuitNode | ConstValue]
|
|
|
15
15
|
ADD: int = 0
|
|
16
16
|
MUL: int = 1
|
|
17
17
|
|
|
18
|
+
cdef int c_ADD = ADD
|
|
19
|
+
cdef int c_MUL = MUL
|
|
18
20
|
|
|
19
21
|
cdef class Circuit:
|
|
20
22
|
"""
|
|
@@ -29,13 +31,6 @@ cdef class Circuit:
|
|
|
29
31
|
`VarNode` may temporarily be set to a constant value.
|
|
30
32
|
"""
|
|
31
33
|
|
|
32
|
-
cdef public list[VarNode] vars
|
|
33
|
-
cdef public list[OpNode] ops
|
|
34
|
-
cdef public object zero
|
|
35
|
-
cdef public object one
|
|
36
|
-
cdef dict[Any, ConstNode] _const_map
|
|
37
|
-
cdef object __derivatives
|
|
38
|
-
|
|
39
34
|
def __init__(self, zero: ConstValue = 0, one: ConstValue = 1):
|
|
40
35
|
"""
|
|
41
36
|
Construct a new, empty circuit.
|
|
@@ -46,13 +41,11 @@ cdef class Circuit:
|
|
|
46
41
|
"""
|
|
47
42
|
self.vars: List[VarNode] = []
|
|
48
43
|
self.ops: List[OpNode] = []
|
|
49
|
-
self._const_map: Dict[ConstValue, ConstNode] = {}
|
|
50
|
-
self.__derivatives: Optional[_DerivativeHelper] = None # cache for partial derivatives calculations.
|
|
51
44
|
self.zero: ConstNode = ConstNode(self, zero, is_zero=True)
|
|
52
45
|
self.one: ConstNode = ConstNode(self, one, is_one=True)
|
|
53
46
|
|
|
54
|
-
self._const_map[
|
|
55
|
-
self.
|
|
47
|
+
self._const_map: Dict[ConstValue, ConstNode] = {zero: self.zero, one: self.one}
|
|
48
|
+
self.__derivatives: Optional[_DerivativeHelper] = None # cache for partial derivatives calculations.
|
|
56
49
|
|
|
57
50
|
@property
|
|
58
51
|
def number_of_vars(self) -> int:
|
|
@@ -127,72 +120,57 @@ cdef class Circuit:
|
|
|
127
120
|
self._const_map[value] = node
|
|
128
121
|
return node
|
|
129
122
|
|
|
130
|
-
cdef OpNode _op(self, int symbol, tuple[CircuitNode, ...] nodes):
|
|
131
|
-
cdef OpNode node = OpNode(self, symbol, nodes)
|
|
132
|
-
self.ops.append(node)
|
|
133
|
-
return node
|
|
134
|
-
|
|
135
123
|
def add(self, *nodes: Args) -> OpNode:
|
|
136
124
|
"""
|
|
137
125
|
Create and return a new 'addition' node, applied to the given arguments.
|
|
138
126
|
"""
|
|
139
127
|
cdef list[CircuitNode] args = self._check_nodes(nodes)
|
|
140
|
-
return self.
|
|
128
|
+
return self.op(c_ADD, tuple(args))
|
|
141
129
|
|
|
142
130
|
def mul(self, *nodes: Args) -> OpNode:
|
|
143
131
|
"""
|
|
144
132
|
Create and return a new 'multiplication' node, applied to the given arguments.
|
|
145
133
|
"""
|
|
146
134
|
cdef list[CircuitNode] args = self._check_nodes(nodes)
|
|
147
|
-
return self.
|
|
135
|
+
return self.op(c_MUL, tuple(args))
|
|
148
136
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
cdef CircuitNode n
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
raise RuntimeError('node does not belong to this circuit')
|
|
162
|
-
if not n.is_zero:
|
|
163
|
-
to_add.append(n)
|
|
164
|
-
cdef int len_to_add = len(to_add)
|
|
165
|
-
if len_to_add == 0:
|
|
137
|
+
def optimised_add(self, *args: Args) -> CircuitNode:
|
|
138
|
+
"""
|
|
139
|
+
Optimised circuit node addition.
|
|
140
|
+
|
|
141
|
+
Performs the following optimisations:
|
|
142
|
+
* addition to zero is avoided: add(x, 0) = x,
|
|
143
|
+
* singleton addition is avoided: add(x) = x,
|
|
144
|
+
* empty addition is avoided: add() = 0,
|
|
145
|
+
"""
|
|
146
|
+
cdef tuple[CircuitNode, ...] to_add = tuple(n for n in self._check_nodes(args) if not n.is_zero)
|
|
147
|
+
cdef size_t num_to_add = len(to_add)
|
|
148
|
+
if num_to_add == 0:
|
|
166
149
|
return self.zero
|
|
167
|
-
|
|
150
|
+
if num_to_add == 1:
|
|
168
151
|
return to_add[0]
|
|
169
|
-
|
|
170
|
-
return self._op(ADD, tuple(to_add))
|
|
152
|
+
return self.op(c_ADD, to_add)
|
|
171
153
|
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
for n in
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
to_mul.append(n)
|
|
189
|
-
cdef int len_to_mul = len(to_mul)
|
|
190
|
-
if len_to_mul == 0:
|
|
154
|
+
def optimised_mul(self, *args: Args) -> CircuitNode:
|
|
155
|
+
"""
|
|
156
|
+
Optimised circuit node multiplication.
|
|
157
|
+
|
|
158
|
+
Performs the following optimisations:
|
|
159
|
+
* multiplication by zero is avoided: mul(x, 0) = 0,
|
|
160
|
+
* multiplication by one is avoided: mul(x, 1) = x,
|
|
161
|
+
* singleton multiplication is avoided: mul(x) = x,
|
|
162
|
+
* empty multiplication is avoided: mul() = 1,
|
|
163
|
+
"""
|
|
164
|
+
cdef tuple[CircuitNode, ...] to_mul = tuple(n for n in self._check_nodes(args) if not n.is_one)
|
|
165
|
+
if any(n.is_zero for n in to_mul):
|
|
166
|
+
return self.zero
|
|
167
|
+
cdef Py_ssize_t num_to_mul = len(to_mul)
|
|
168
|
+
|
|
169
|
+
if num_to_mul == 0:
|
|
191
170
|
return self.one
|
|
192
|
-
|
|
171
|
+
if num_to_mul == 1:
|
|
193
172
|
return to_mul[0]
|
|
194
|
-
|
|
195
|
-
return self._op(MUL, tuple(to_mul))
|
|
173
|
+
return self.op(c_MUL, to_mul)
|
|
196
174
|
|
|
197
175
|
def cartesian_product(self, xs: Sequence[CircuitNode], ys: Sequence[CircuitNode]) -> List[CircuitNode]:
|
|
198
176
|
"""
|
|
@@ -206,12 +184,12 @@ cdef class Circuit:
|
|
|
206
184
|
a list of 'mul' nodes, one for each combination of xs and ys. The results are in the order
|
|
207
185
|
given by `[mul(x, y) for x in xs for y in ys]`.
|
|
208
186
|
"""
|
|
209
|
-
|
|
210
|
-
|
|
187
|
+
cdef list[CircuitNode] xs_list = self._check_nodes(xs)
|
|
188
|
+
cdef list[CircuitNode] ys_list = self._check_nodes(ys)
|
|
211
189
|
return [
|
|
212
|
-
self.optimised_mul(
|
|
213
|
-
for x in
|
|
214
|
-
for y in
|
|
190
|
+
self.optimised_mul(x, y)
|
|
191
|
+
for x in xs_list
|
|
192
|
+
for y in ys_list
|
|
215
193
|
]
|
|
216
194
|
|
|
217
195
|
@overload
|
|
@@ -262,10 +240,10 @@ cdef class Circuit:
|
|
|
262
240
|
If `args` is a single CircuitNode, then a single CircuitNode will be returned, otherwise
|
|
263
241
|
a list of CircuitNode is returned.
|
|
264
242
|
"""
|
|
265
|
-
single_result
|
|
243
|
+
cdef bint single_result = isinstance(args, CircuitNode)
|
|
266
244
|
|
|
267
|
-
|
|
268
|
-
if len(
|
|
245
|
+
cdef list[CircuitNode] args_list = self._check_nodes([args])
|
|
246
|
+
if len(args_list) == 0:
|
|
269
247
|
# Trivial case
|
|
270
248
|
return []
|
|
271
249
|
|
|
@@ -274,12 +252,12 @@ cdef class Circuit:
|
|
|
274
252
|
if self_multiply:
|
|
275
253
|
result = [
|
|
276
254
|
derivatives.derivative_self_mul(arg)
|
|
277
|
-
for arg in
|
|
255
|
+
for arg in args_list
|
|
278
256
|
]
|
|
279
257
|
else:
|
|
280
258
|
result = [
|
|
281
259
|
derivatives.derivative(arg)
|
|
282
|
-
for arg in
|
|
260
|
+
for arg in args_list
|
|
283
261
|
]
|
|
284
262
|
|
|
285
263
|
if single_result:
|
|
@@ -303,26 +281,8 @@ cdef class Circuit:
|
|
|
303
281
|
Args:
|
|
304
282
|
*nodes: may be either a node or a list of nodes.
|
|
305
283
|
"""
|
|
306
|
-
|
|
307
|
-
self._remove_unreachable_op_nodes(
|
|
308
|
-
|
|
309
|
-
cdef void _remove_unreachable_op_nodes(self, list[CircuitNode] nodes):
|
|
310
|
-
# Set of object ids for all reachable op nodes
|
|
311
|
-
cdef set[int] seen = set()
|
|
312
|
-
|
|
313
|
-
cdef CircuitNode node
|
|
314
|
-
for node in nodes:
|
|
315
|
-
_reachable_op_nodes_seen_r(node, seen)
|
|
316
|
-
|
|
317
|
-
if len(seen) < len(self.ops):
|
|
318
|
-
# Invalidate unreadable op nodes
|
|
319
|
-
for op_node in self.ops:
|
|
320
|
-
if id(op_node) not in seen:
|
|
321
|
-
op_node.circuit = None
|
|
322
|
-
op_node.args = ()
|
|
323
|
-
|
|
324
|
-
# Keep only reachable op nodes, in the same order as `self.ops`.
|
|
325
|
-
self.ops = [op_node for op_node in self.ops if id(op_node) in seen]
|
|
284
|
+
cdef list[CircuitNode] node_list = self._check_nodes(nodes)
|
|
285
|
+
self._remove_unreachable_op_nodes(node_list)
|
|
326
286
|
|
|
327
287
|
def reachable_op_nodes(self, *nodes: Args) -> List[OpNode]:
|
|
328
288
|
"""
|
|
@@ -338,19 +298,8 @@ cdef class Circuit:
|
|
|
338
298
|
Returned nodes are not repeated.
|
|
339
299
|
The result is ordered such that if result[i] is referenced by result[j] then i < j.
|
|
340
300
|
"""
|
|
341
|
-
|
|
342
|
-
return self.
|
|
343
|
-
|
|
344
|
-
cdef list[OpNode] _reachable_op_nodes(self, list[CircuitNode] nodes):
|
|
345
|
-
# Set of object ids for all reachable op nodes
|
|
346
|
-
cdef set[int] seen = set()
|
|
347
|
-
|
|
348
|
-
cdef list[OpNode] result = []
|
|
349
|
-
|
|
350
|
-
cdef CircuitNode node
|
|
351
|
-
for node in nodes:
|
|
352
|
-
_reachable_op_nodes_r(node, seen, result)
|
|
353
|
-
return result
|
|
301
|
+
cdef list[CircuitNode] node_list = self._check_nodes(nodes)
|
|
302
|
+
return self.find_reachable_op_nodes(node_list)
|
|
354
303
|
|
|
355
304
|
def dump(
|
|
356
305
|
self,
|
|
@@ -415,8 +364,42 @@ cdef class Circuit:
|
|
|
415
364
|
args_str = ' '.join(node_name[id(arg)] for arg in op.args)
|
|
416
365
|
print(f'{next_prefix}{op_name}: {args_str}')
|
|
417
366
|
|
|
367
|
+
cdef OpNode op(self, int symbol, tuple[CircuitNode, ...] nodes):
|
|
368
|
+
cdef OpNode node = OpNode(self, symbol, nodes)
|
|
369
|
+
self.ops.append(node)
|
|
370
|
+
return node
|
|
371
|
+
|
|
372
|
+
cdef list[OpNode] find_reachable_op_nodes(self, list[CircuitNode] nodes):
|
|
373
|
+
# Set of object ids for all reachable op nodes
|
|
374
|
+
cdef set[int] seen = set()
|
|
375
|
+
|
|
376
|
+
cdef list[OpNode] result = []
|
|
377
|
+
|
|
378
|
+
cdef CircuitNode node
|
|
379
|
+
for node in nodes:
|
|
380
|
+
find_reachable_op_nodes_r(node, seen, result)
|
|
381
|
+
return result
|
|
382
|
+
|
|
383
|
+
cdef void _remove_unreachable_op_nodes(self, list[CircuitNode] nodes):
|
|
384
|
+
# Set of object ids for all reachable op nodes
|
|
385
|
+
cdef set[int] seen = set()
|
|
386
|
+
|
|
387
|
+
cdef CircuitNode node
|
|
388
|
+
for node in nodes:
|
|
389
|
+
find_reachable_op_nodes_seen_r(node, seen)
|
|
390
|
+
|
|
391
|
+
if len(seen) < len(self.ops):
|
|
392
|
+
# Invalidate unreachable op nodes
|
|
393
|
+
for op_node in self.ops:
|
|
394
|
+
if id(op_node) not in seen:
|
|
395
|
+
op_node.circuit = None
|
|
396
|
+
op_node.args = ()
|
|
397
|
+
|
|
398
|
+
# Keep only reachable op nodes, in the same order as `self.ops`.
|
|
399
|
+
self.ops = [op_node for op_node in self.ops if id(op_node) in seen]
|
|
400
|
+
|
|
418
401
|
cdef list[CircuitNode] _check_nodes(self, object nodes: Iterable[Args]): # -> Sequence[CircuitNode]:
|
|
419
|
-
# Convert the given circuit nodes to a
|
|
402
|
+
# Convert the given circuit nodes to a list, flattening nested iterables as needed.
|
|
420
403
|
#
|
|
421
404
|
# Args:
|
|
422
405
|
# nodes: some circuit nodes or constant values.
|
|
@@ -427,8 +410,8 @@ cdef class Circuit:
|
|
|
427
410
|
self.__check_nodes(nodes, result)
|
|
428
411
|
return result
|
|
429
412
|
|
|
430
|
-
cdef void __check_nodes(self, nodes: Iterable[Args], list[CircuitNode] result):
|
|
431
|
-
# Convert the given circuit nodes to a
|
|
413
|
+
cdef void __check_nodes(self, object nodes: Iterable[Args], list[CircuitNode] result):
|
|
414
|
+
# Convert the given circuit nodes to a list, flattening nested iterables as needed.
|
|
432
415
|
#
|
|
433
416
|
# Args:
|
|
434
417
|
# nodes: some circuit nodes or constant values.
|
|
@@ -455,7 +438,6 @@ cdef class Circuit:
|
|
|
455
438
|
self.__derivatives = derivatives
|
|
456
439
|
return derivatives
|
|
457
440
|
|
|
458
|
-
|
|
459
441
|
cdef class CircuitNode:
|
|
460
442
|
"""
|
|
461
443
|
A node in an arithmetic circuit.
|
|
@@ -470,9 +452,6 @@ cdef class CircuitNode:
|
|
|
470
452
|
A var node may be temporarily set to be a constant node, which may
|
|
471
453
|
be useful for optimising a compiled circuit.
|
|
472
454
|
"""
|
|
473
|
-
cdef public Circuit circuit
|
|
474
|
-
cdef public bint is_zero
|
|
475
|
-
cdef public bint is_one
|
|
476
455
|
|
|
477
456
|
def __init__(self, circuit: Circuit, is_zero: bool, is_one: bool):
|
|
478
457
|
self.circuit = circuit
|
|
@@ -485,11 +464,10 @@ cdef class CircuitNode:
|
|
|
485
464
|
def __mul__(self, other: CircuitNode | ConstValue):
|
|
486
465
|
return self.circuit.mul(self, other)
|
|
487
466
|
|
|
488
|
-
|
|
489
467
|
cdef class ConstNode(CircuitNode):
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
468
|
+
"""
|
|
469
|
+
A node in a circuit representing a constant value.
|
|
470
|
+
"""
|
|
493
471
|
|
|
494
472
|
def __init__(self, circuit, value: ConstValue, is_zero: bool = False, is_one: bool = False):
|
|
495
473
|
super().__init__(circuit, is_zero, is_one)
|
|
@@ -504,13 +482,10 @@ cdef class ConstNode(CircuitNode):
|
|
|
504
482
|
else:
|
|
505
483
|
return False
|
|
506
484
|
|
|
507
|
-
|
|
508
485
|
cdef class VarNode(CircuitNode):
|
|
509
486
|
"""
|
|
510
487
|
A node in a circuit representing an input variable.
|
|
511
488
|
"""
|
|
512
|
-
cdef public int idx
|
|
513
|
-
cdef object _const
|
|
514
489
|
|
|
515
490
|
def __init__(self, circuit, idx: int):
|
|
516
491
|
super().__init__(circuit, False, False)
|
|
@@ -552,13 +527,11 @@ cdef class OpNode(CircuitNode):
|
|
|
552
527
|
"""
|
|
553
528
|
A node in a circuit representing an arithmetic operation.
|
|
554
529
|
"""
|
|
555
|
-
cdef public tuple[object, ...] args
|
|
556
|
-
cdef public int symbol
|
|
557
530
|
|
|
558
|
-
def __init__(self,
|
|
531
|
+
def __init__(self, Circuit circuit, int symbol, tuple[CircuitNode, ...] args):
|
|
559
532
|
super().__init__(circuit, False, False)
|
|
560
533
|
self.args = tuple(args)
|
|
561
|
-
self.symbol = int
|
|
534
|
+
self.symbol = <int> symbol
|
|
562
535
|
|
|
563
536
|
def __str__(self) -> str:
|
|
564
537
|
return f'{self.op_str()}\\{len(self.args)}'
|
|
@@ -567,9 +540,9 @@ cdef class OpNode(CircuitNode):
|
|
|
567
540
|
"""
|
|
568
541
|
Returns the op node operation as a string.
|
|
569
542
|
"""
|
|
570
|
-
if self.symbol ==
|
|
543
|
+
if self.symbol == c_MUL:
|
|
571
544
|
return 'mul'
|
|
572
|
-
elif self.symbol ==
|
|
545
|
+
elif self.symbol == c_ADD:
|
|
573
546
|
return 'add'
|
|
574
547
|
else:
|
|
575
548
|
return '?' + str(self.symbol)
|
|
@@ -579,11 +552,11 @@ cdef class _DNode:
|
|
|
579
552
|
A data structure supporting derivative calculations.
|
|
580
553
|
A DNode holds all information needed to calculate the partial derivative at `node`.
|
|
581
554
|
"""
|
|
582
|
-
cdef
|
|
583
|
-
cdef
|
|
584
|
-
cdef
|
|
585
|
-
cdef
|
|
586
|
-
cdef
|
|
555
|
+
cdef CircuitNode node
|
|
556
|
+
cdef object derivative
|
|
557
|
+
cdef object derivative_self_mul
|
|
558
|
+
cdef list[_DNodeProduct] sum_prod
|
|
559
|
+
cdef bint processed
|
|
587
560
|
|
|
588
561
|
def __init__(
|
|
589
562
|
self,
|
|
@@ -615,8 +588,8 @@ cdef class _DNodeProduct:
|
|
|
615
588
|
|
|
616
589
|
This represents a product of `parent` and `prod`.
|
|
617
590
|
"""
|
|
618
|
-
cdef
|
|
619
|
-
cdef
|
|
591
|
+
cdef _DNode parent
|
|
592
|
+
cdef list[CircuitNode] prod
|
|
620
593
|
|
|
621
594
|
def __init__(self, parent: _DNode, prod: List[CircuitNode]):
|
|
622
595
|
self.parent = parent
|
|
@@ -629,33 +602,40 @@ cdef class _DNodeProduct:
|
|
|
629
602
|
return 'DNodeProduct(' + str(self.parent) + ', ' + str(self.prod) + ')'
|
|
630
603
|
|
|
631
604
|
|
|
632
|
-
class _DerivativeHelper:
|
|
605
|
+
cdef class _DerivativeHelper:
|
|
633
606
|
"""
|
|
634
607
|
A data structure to support efficient calculation of partial derivatives
|
|
635
608
|
with respect to some function node `f`.
|
|
636
609
|
"""
|
|
637
610
|
|
|
611
|
+
cdef CircuitNode f
|
|
612
|
+
cdef CircuitNode zero
|
|
613
|
+
cdef CircuitNode one
|
|
614
|
+
cdef Circuit circuit
|
|
615
|
+
cdef dict[int, _DNode] d_nodes
|
|
616
|
+
|
|
638
617
|
def __init__(self, f: CircuitNode):
|
|
639
618
|
"""
|
|
640
619
|
Prepare to calculate partial derivatives with respect to `f`.
|
|
641
620
|
"""
|
|
642
|
-
self.f
|
|
643
|
-
self.circuit
|
|
644
|
-
self.d_nodes
|
|
621
|
+
self.f = f
|
|
622
|
+
self.circuit = f.circuit
|
|
623
|
+
self.d_nodes = {} # map id(CircuitNode) to its DNode
|
|
645
624
|
self.zero = self.circuit.zero
|
|
646
625
|
self.one = self.circuit.one
|
|
647
|
-
|
|
626
|
+
|
|
627
|
+
cdef _DNode top_d_node = _DNode(f, self.one)
|
|
648
628
|
self.d_nodes[id(f)] = top_d_node
|
|
649
629
|
self._mk_derivative_r(top_d_node)
|
|
650
630
|
|
|
651
|
-
|
|
631
|
+
cdef CircuitNode derivative(self, CircuitNode node):
|
|
652
632
|
d_node: Optional[_DNode] = self.d_nodes.get(id(node))
|
|
653
633
|
if d_node is None:
|
|
654
634
|
return self.zero
|
|
655
635
|
else:
|
|
656
636
|
return self._derivative(d_node)
|
|
657
637
|
|
|
658
|
-
|
|
638
|
+
cdef CircuitNode derivative_self_mul(self, CircuitNode node):
|
|
659
639
|
d_node: Optional[_DNode] = self.d_nodes.get(id(node))
|
|
660
640
|
if d_node is None:
|
|
661
641
|
return self.zero
|
|
@@ -667,11 +647,11 @@ class _DerivativeHelper:
|
|
|
667
647
|
elif d is self.one:
|
|
668
648
|
d_node.derivative_self_mul = node
|
|
669
649
|
else:
|
|
670
|
-
d_node.derivative_self_mul = self.circuit.optimised_mul(
|
|
650
|
+
d_node.derivative_self_mul = self.circuit.optimised_mul(d, node)
|
|
671
651
|
|
|
672
652
|
return d_node.derivative_self_mul
|
|
673
653
|
|
|
674
|
-
|
|
654
|
+
cdef CircuitNode _derivative(self, _DNode d_node):
|
|
675
655
|
if d_node.derivative is not None:
|
|
676
656
|
return d_node.derivative
|
|
677
657
|
|
|
@@ -685,14 +665,13 @@ class _DerivativeHelper:
|
|
|
685
665
|
d_node.sum_prod = None
|
|
686
666
|
|
|
687
667
|
# Construct the addition operation
|
|
688
|
-
d_node.derivative = self.circuit.optimised_add(to_add)
|
|
668
|
+
d_node.derivative = self.circuit.optimised_add(*to_add)
|
|
689
669
|
|
|
690
670
|
return d_node.derivative
|
|
691
671
|
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
"""
|
|
672
|
+
cdef CircuitNode _derivative_prod(self, _DNodeProduct prods):
|
|
673
|
+
# Support `_derivative` by constructing the derivative product for the given _DNodeProduct.
|
|
674
|
+
|
|
696
675
|
# Get the derivative of the parent node.
|
|
697
676
|
parent: CircuitNode = self._derivative(prods.parent)
|
|
698
677
|
|
|
@@ -706,59 +685,56 @@ class _DerivativeHelper:
|
|
|
706
685
|
to_mul.append(arg)
|
|
707
686
|
|
|
708
687
|
# Construct the multiplication operation
|
|
709
|
-
return self.circuit.optimised_mul(to_mul)
|
|
688
|
+
return self.circuit.optimised_mul(*to_mul)
|
|
689
|
+
|
|
690
|
+
cdef void _mk_derivative_r(self, _DNode d_node):
|
|
691
|
+
# Construct a DNode for each argument of the given DNode.
|
|
710
692
|
|
|
711
|
-
def _mk_derivative_r(self, d_node: _DNode) -> None:
|
|
712
|
-
"""
|
|
713
|
-
Construct a DNode for each argument of the given DNode.
|
|
714
|
-
"""
|
|
715
693
|
if d_node.processed:
|
|
716
694
|
return
|
|
717
695
|
d_node.processed = True
|
|
718
696
|
node: CircuitNode = d_node.node
|
|
719
697
|
|
|
720
698
|
if isinstance(node, OpNode):
|
|
721
|
-
if node.symbol ==
|
|
699
|
+
if node.symbol == c_ADD:
|
|
722
700
|
for arg in node.args:
|
|
723
701
|
child_d_node = self._add(arg, d_node, [])
|
|
724
702
|
self._mk_derivative_r(child_d_node)
|
|
725
|
-
elif node.symbol ==
|
|
703
|
+
elif node.symbol == c_MUL:
|
|
726
704
|
for arg in node.args:
|
|
727
705
|
prod = [arg2 for arg2 in node.args if arg is not arg2]
|
|
728
706
|
child_d_node = self._add(arg, d_node, prod)
|
|
729
707
|
self._mk_derivative_r(child_d_node)
|
|
730
708
|
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
709
|
+
cdef _DNode _add(self, CircuitNode node, _DNode parent, list[CircuitNode] prod):
|
|
710
|
+
# Support for `_mk_derivative_r`.
|
|
711
|
+
#
|
|
712
|
+
# Add a _DNodeProduct(parent, prod) to the DNode for the given circuit node.
|
|
713
|
+
#
|
|
714
|
+
# If the DNode for `node` does not yet exist, one will be created.
|
|
715
|
+
#
|
|
716
|
+
# The given circuit node may have multiple parents (i.e., a shared sub-expression). Therefore,
|
|
717
|
+
# this method may be called multiple times for a given node. Each time a new _DNodeProduct will be added.
|
|
718
|
+
#
|
|
719
|
+
# Args:
|
|
720
|
+
# node: the CircuitNode that the returned DNode is for.
|
|
721
|
+
# parent: the DNode of the parent node, i.e., `node` is an argument to the parent node.
|
|
722
|
+
# prod: other circuit nodes that need to be multiplied with the parent derivative when
|
|
723
|
+
# constructing a derivative for `node`.
|
|
724
|
+
#
|
|
725
|
+
# Returns:
|
|
726
|
+
# the DNode for `node`.
|
|
747
727
|
|
|
748
|
-
Returns:
|
|
749
|
-
the DNode for `node`.
|
|
750
|
-
"""
|
|
751
728
|
child_d_node: _DNode = self._get(node)
|
|
752
729
|
child_d_node.sum_prod.append(_DNodeProduct(parent, prod))
|
|
753
730
|
return child_d_node
|
|
754
731
|
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
732
|
+
cdef _DNode _get(self, CircuitNode node):
|
|
733
|
+
# Helper for derivatives.
|
|
734
|
+
#
|
|
735
|
+
# Get the DNode for the given circuit node.
|
|
736
|
+
# If no DNode exists for it yet, then one will be constructed.
|
|
758
737
|
|
|
759
|
-
Get the DNode for the given circuit node.
|
|
760
|
-
If no DNode exist for it yet, then one will be constructed.
|
|
761
|
-
"""
|
|
762
738
|
node_id: int = id(node)
|
|
763
739
|
d_node: Optional[_DNode] = self.d_nodes.get(node_id)
|
|
764
740
|
if d_node is None:
|
|
@@ -767,7 +743,7 @@ class _DerivativeHelper:
|
|
|
767
743
|
return d_node
|
|
768
744
|
|
|
769
745
|
|
|
770
|
-
cdef void
|
|
746
|
+
cdef void find_reachable_op_nodes_r(CircuitNode node, set[int] seen, list[OpNode] result):
|
|
771
747
|
# Recursive helper for `reachable_op_nodes`. Performs a depth-first search.
|
|
772
748
|
#
|
|
773
749
|
# Args:
|
|
@@ -777,11 +753,10 @@ cdef void _reachable_op_nodes_r(CircuitNode node, set[int] seen, list[OpNode] re
|
|
|
777
753
|
if isinstance(node, OpNode) and id(node) not in seen:
|
|
778
754
|
seen.add(id(node))
|
|
779
755
|
for arg in node.args:
|
|
780
|
-
|
|
756
|
+
find_reachable_op_nodes_r(arg, seen, result)
|
|
781
757
|
result.append(node)
|
|
782
758
|
|
|
783
|
-
|
|
784
|
-
cdef void _reachable_op_nodes_seen_r(CircuitNode node, set[int] seen):
|
|
759
|
+
cdef void find_reachable_op_nodes_seen_r(CircuitNode node, set[int] seen):
|
|
785
760
|
# Recursive helper for `remove_unreachable_op_nodes`. Performs a depth-first search.
|
|
786
761
|
#
|
|
787
762
|
# Args:
|
|
@@ -790,4 +765,4 @@ cdef void _reachable_op_nodes_seen_r(CircuitNode node, set[int] seen):
|
|
|
790
765
|
if isinstance(node, OpNode) and id(node) not in seen:
|
|
791
766
|
seen.add(id(node))
|
|
792
767
|
for arg in node.args:
|
|
793
|
-
|
|
768
|
+
find_reachable_op_nodes_seen_r(arg, seen)
|
ck/circuit/_circuit_py.py
CHANGED
|
@@ -721,7 +721,7 @@ class _DerivativeHelper:
|
|
|
721
721
|
d_node.sum_prod = None
|
|
722
722
|
|
|
723
723
|
# Construct the addition operation
|
|
724
|
-
d_node.derivative = self.circuit.optimised_add(to_add)
|
|
724
|
+
d_node.derivative = self.circuit.optimised_add(*to_add)
|
|
725
725
|
|
|
726
726
|
return d_node.derivative
|
|
727
727
|
|
|
@@ -742,7 +742,7 @@ class _DerivativeHelper:
|
|
|
742
742
|
to_mul.append(arg)
|
|
743
743
|
|
|
744
744
|
# Construct the multiplication operation
|
|
745
|
-
return self.circuit.optimised_mul(to_mul)
|
|
745
|
+
return self.circuit.optimised_mul(*to_mul)
|
|
746
746
|
|
|
747
747
|
def _mk_derivative_r(self, d_node: _DNode) -> None:
|
|
748
748
|
"""
|
|
Binary file
|