compiled-knowledge 4.0.0a16__cp312-cp312-win_amd64.whl → 4.0.0a18__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/__init__.py +9 -2
- ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
- ck/circuit/_circuit_cy.pxd +33 -0
- ck/circuit/{circuit.pyx → _circuit_cy.pyx} +115 -133
- ck/circuit/{circuit_py.py → _circuit_py.py} +16 -8
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +88 -60
- ck/circuit_compiler/named_circuit_compilers.py +1 -1
- ck/pgm_compiler/factor_elimination.py +23 -13
- ck/pgm_compiler/support/circuit_table/__init__.py +9 -2
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +332 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy_cpp_verion.pyx +601 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy_minimal_version.pyx +311 -0
- ck/pgm_compiler/support/circuit_table/{circuit_table.pyx → _circuit_table_cy_v4.0.0a17.pyx} +9 -9
- ck/pgm_compiler/support/circuit_table/{circuit_table_py.py → _circuit_table_py.py} +80 -45
- ck/pgm_compiler/support/clusters.py +16 -4
- ck/pgm_compiler/support/factor_tables.py +1 -1
- ck/pgm_compiler/support/join_tree.py +67 -10
- ck/pgm_compiler/support/named_compiler_maker.py +12 -2
- ck/pgm_compiler/variable_elimination.py +2 -0
- ck/utils/iter_extras.py +8 -1
- ck_demos/pgm_compiler/demo_compiler_dump.py +10 -0
- ck_demos/pgm_compiler/time_fe_compiler.py +93 -0
- ck_demos/utils/compare.py +5 -1
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/METADATA +1 -1
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/RECORD +30 -29
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/WHEEL +1 -1
- ck/circuit/circuit.c +0 -38861
- ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
- ck/circuit/circuit_node.pyx +0 -138
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +0 -17373
- ck/pgm_compiler/support/circuit_table/circuit_table.c +0 -16042
- ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a16.dist-info → compiled_knowledge-4.0.0a18.dist-info}/top_level.txt +0 -0
ck/circuit/__init__.py
CHANGED
|
@@ -1,5 +1,12 @@
|
|
|
1
|
-
#
|
|
2
|
-
|
|
1
|
+
# There are two implementations of the `circuit` module are provided
|
|
2
|
+
# for developer R&D purposes. One is pure Python and the other is Cython.
|
|
3
|
+
# Which implementation is used can be selected here.
|
|
4
|
+
# A similar selection can be made for the `circuit_table` module.
|
|
5
|
+
# Note that if the Cython implementation is chosen for `circuit_table` then
|
|
6
|
+
# the Cython implementation must be chosen for `circuit`.
|
|
7
|
+
|
|
8
|
+
# from ._circuit_py import (
|
|
9
|
+
from ._circuit_cy import (
|
|
3
10
|
Circuit,
|
|
4
11
|
CircuitNode,
|
|
5
12
|
VarNode,
|
|
Binary file
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
cdef class Circuit:
|
|
2
|
+
cdef public list[VarNode] vars
|
|
3
|
+
cdef public list[OpNode] ops
|
|
4
|
+
cdef public object zero
|
|
5
|
+
cdef public object one
|
|
6
|
+
cdef dict[object, ConstNode] _const_map
|
|
7
|
+
cdef object __derivatives
|
|
8
|
+
|
|
9
|
+
cdef OpNode op(self, int symbol, tuple[CircuitNode, ...] nodes)
|
|
10
|
+
cdef void _remove_unreachable_op_nodes(self, list[CircuitNode] nodes)
|
|
11
|
+
cdef list[OpNode] _reachable_op_nodes(self, list[CircuitNode] nodes)
|
|
12
|
+
cdef list[CircuitNode] _check_nodes(self, object nodes)
|
|
13
|
+
cdef void __check_nodes(self, object nodes, list[CircuitNode] result)
|
|
14
|
+
cdef object _derivatives(self, CircuitNode f)
|
|
15
|
+
cdef object _derivatives(self, CircuitNode f)
|
|
16
|
+
|
|
17
|
+
cdef class CircuitNode:
|
|
18
|
+
cdef public Circuit circuit
|
|
19
|
+
cdef public bint is_zero
|
|
20
|
+
cdef public bint is_one
|
|
21
|
+
|
|
22
|
+
cdef class ConstNode(CircuitNode):
|
|
23
|
+
cdef public object value
|
|
24
|
+
|
|
25
|
+
cdef class VarNode(CircuitNode):
|
|
26
|
+
cdef public int idx
|
|
27
|
+
cdef object _const
|
|
28
|
+
|
|
29
|
+
cpdef int is_const(self) except*
|
|
30
|
+
|
|
31
|
+
cdef class OpNode(CircuitNode):
|
|
32
|
+
cdef public tuple[object, ...] args
|
|
33
|
+
cdef public int symbol
|
|
@@ -4,7 +4,7 @@ For more documentation on this module, refer to the Jupyter notebook docs/6_circ
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
6
|
from itertools import chain
|
|
7
|
-
from typing import Dict, Tuple, Optional, Iterable, Sequence, List, overload
|
|
7
|
+
from typing import Dict, Tuple, Optional, Iterable, Sequence, List, overload
|
|
8
8
|
|
|
9
9
|
# Type for values of ConstNode objects
|
|
10
10
|
ConstValue = float | int | bool
|
|
@@ -15,6 +15,8 @@ Args = CircuitNode | ConstValue | Iterable[CircuitNode | ConstValue]
|
|
|
15
15
|
ADD: int = 0
|
|
16
16
|
MUL: int = 1
|
|
17
17
|
|
|
18
|
+
cdef int c_ADD = ADD
|
|
19
|
+
cdef int c_MUL = MUL
|
|
18
20
|
|
|
19
21
|
cdef class Circuit:
|
|
20
22
|
"""
|
|
@@ -29,13 +31,6 @@ cdef class Circuit:
|
|
|
29
31
|
`VarNode` may be temporarily be set to a constant value.
|
|
30
32
|
"""
|
|
31
33
|
|
|
32
|
-
cdef public list[VarNode] vars
|
|
33
|
-
cdef public list[OpNode] ops
|
|
34
|
-
cdef public object zero
|
|
35
|
-
cdef public object one
|
|
36
|
-
cdef dict[Any, ConstNode] _const_map
|
|
37
|
-
cdef object __derivatives
|
|
38
|
-
|
|
39
34
|
def __init__(self, zero: ConstValue = 0, one: ConstValue = 1):
|
|
40
35
|
"""
|
|
41
36
|
Construct a new, empty circuit.
|
|
@@ -48,8 +43,11 @@ cdef class Circuit:
|
|
|
48
43
|
self.ops: List[OpNode] = []
|
|
49
44
|
self._const_map: Dict[ConstValue, ConstNode] = {}
|
|
50
45
|
self.__derivatives: Optional[_DerivativeHelper] = None # cache for partial derivatives calculations.
|
|
51
|
-
self.zero: ConstNode = self
|
|
52
|
-
self.one: ConstNode = self
|
|
46
|
+
self.zero: ConstNode = ConstNode(self, zero, is_zero=True)
|
|
47
|
+
self.one: ConstNode = ConstNode(self, one, is_one=True)
|
|
48
|
+
|
|
49
|
+
self._const_map[zero] = self.zero
|
|
50
|
+
self._const_map[one] = self.one
|
|
53
51
|
|
|
54
52
|
@property
|
|
55
53
|
def number_of_vars(self) -> int:
|
|
@@ -124,72 +122,57 @@ cdef class Circuit:
|
|
|
124
122
|
self._const_map[value] = node
|
|
125
123
|
return node
|
|
126
124
|
|
|
127
|
-
cdef object _op(self, int symbol, tuple[CircuitNode, ...] nodes):
|
|
128
|
-
cdef object node = OpNode(self, symbol, nodes)
|
|
129
|
-
self.ops.append(node)
|
|
130
|
-
return node
|
|
131
|
-
|
|
132
125
|
def add(self, *nodes: Args) -> OpNode:
|
|
133
126
|
"""
|
|
134
127
|
Create and return a new 'addition' node, applied to the given arguments.
|
|
135
128
|
"""
|
|
136
|
-
cdef list[
|
|
137
|
-
return self.
|
|
129
|
+
cdef list[CircuitNode] args = self._check_nodes(nodes)
|
|
130
|
+
return self.op(c_ADD, tuple(args))
|
|
138
131
|
|
|
139
132
|
def mul(self, *nodes: Args) -> OpNode:
|
|
140
133
|
"""
|
|
141
134
|
Create and return a new 'multiplication' node, applied to the given arguments.
|
|
142
135
|
"""
|
|
143
|
-
cdef list[
|
|
144
|
-
return self.
|
|
136
|
+
cdef list[CircuitNode] args = self._check_nodes(nodes)
|
|
137
|
+
return self.op(c_MUL, tuple(args))
|
|
145
138
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
cdef
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
raise RuntimeError('node does not belong to this circuit')
|
|
159
|
-
if not n.is_zero():
|
|
160
|
-
to_add.append(n)
|
|
161
|
-
cdef int len_to_add = len(to_add)
|
|
162
|
-
if len_to_add == 0:
|
|
139
|
+
def optimised_add(self, *args: Args) -> CircuitNode:
|
|
140
|
+
"""
|
|
141
|
+
Optimised circuit node addition.
|
|
142
|
+
|
|
143
|
+
Performs the following optimisations:
|
|
144
|
+
* addition to zero is avoided: add(x, 0) = x,
|
|
145
|
+
* singleton addition is avoided: add(x) = x,
|
|
146
|
+
* empty addition is avoided: add() = 0,
|
|
147
|
+
"""
|
|
148
|
+
cdef tuple[CircuitNode] to_add = tuple(n for n in self._check_nodes(args) if not n.is_zero)
|
|
149
|
+
cdef size_t num_to_add = len(to_add)
|
|
150
|
+
if num_to_add == 0:
|
|
163
151
|
return self.zero
|
|
164
|
-
|
|
152
|
+
if num_to_add == 1:
|
|
165
153
|
return to_add[0]
|
|
166
|
-
|
|
167
|
-
return self._op(ADD, tuple(to_add))
|
|
154
|
+
return self.op(c_ADD, to_add)
|
|
168
155
|
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
for n in
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
to_mul.append(n)
|
|
186
|
-
cdef int len_to_mul = len(to_mul)
|
|
187
|
-
if len_to_mul == 0:
|
|
156
|
+
def optimised_mul(self, *args: Args) -> CircuitNode:
|
|
157
|
+
"""
|
|
158
|
+
Optimised circuit node multiplication.
|
|
159
|
+
|
|
160
|
+
Performs the following optimisations:
|
|
161
|
+
* multiplication by zero is avoided: mul(x, 0) = 0,
|
|
162
|
+
* multiplication by one is avoided: mul(x, 1) = x,
|
|
163
|
+
* singleton multiplication is avoided: mul(x) = x,
|
|
164
|
+
* empty multiplication is avoided: mul() = 1,
|
|
165
|
+
"""
|
|
166
|
+
cdef tuple[CircuitNode] to_mul = tuple(n for n in self._check_nodes(args) if not n.is_one)
|
|
167
|
+
if any(n.is_zero for n in to_mul):
|
|
168
|
+
return self.zero
|
|
169
|
+
cdef size_t num_to_mul = len(to_mul)
|
|
170
|
+
|
|
171
|
+
if num_to_mul == 0:
|
|
188
172
|
return self.one
|
|
189
|
-
|
|
173
|
+
if num_to_mul == 1:
|
|
190
174
|
return to_mul[0]
|
|
191
|
-
|
|
192
|
-
return self._op(MUL, tuple(to_mul))
|
|
175
|
+
return self.op(c_MUL, to_mul)
|
|
193
176
|
|
|
194
177
|
def cartesian_product(self, xs: Sequence[CircuitNode], ys: Sequence[CircuitNode]) -> List[CircuitNode]:
|
|
195
178
|
"""
|
|
@@ -206,7 +189,7 @@ cdef class Circuit:
|
|
|
206
189
|
xs: List[CircuitNode] = self._check_nodes(xs)
|
|
207
190
|
ys: List[CircuitNode] = self._check_nodes(ys)
|
|
208
191
|
return [
|
|
209
|
-
self.optimised_mul(
|
|
192
|
+
self.optimised_mul(x, y)
|
|
210
193
|
for x in xs
|
|
211
194
|
for y in ys
|
|
212
195
|
]
|
|
@@ -300,19 +283,8 @@ cdef class Circuit:
|
|
|
300
283
|
Args:
|
|
301
284
|
*nodes: may be either a node or a list of nodes.
|
|
302
285
|
"""
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
_reachable_op_nodes_seen_r(node, seen)
|
|
306
|
-
|
|
307
|
-
if len(seen) < len(self.ops):
|
|
308
|
-
# Invalidate unreadable op nodes
|
|
309
|
-
for op_node in self.ops:
|
|
310
|
-
if id(op_node) not in seen:
|
|
311
|
-
op_node.circuit = None
|
|
312
|
-
op_node.args = ()
|
|
313
|
-
|
|
314
|
-
# Keep only reachable op nodes, in the same order as `self.ops`.
|
|
315
|
-
self.ops = [op_node for op_node in self.ops if id(op_node) in seen]
|
|
286
|
+
nodes = self._check_nodes(nodes)
|
|
287
|
+
self._remove_unreachable_op_nodes(nodes)
|
|
316
288
|
|
|
317
289
|
def reachable_op_nodes(self, *nodes: Args) -> List[OpNode]:
|
|
318
290
|
"""
|
|
@@ -328,11 +300,8 @@ cdef class Circuit:
|
|
|
328
300
|
Returned nodes are not repeated.
|
|
329
301
|
The result is ordered such that if result[i] is referenced by result[j] then i < j.
|
|
330
302
|
"""
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
for node in self._check_nodes(nodes):
|
|
334
|
-
_reachable_op_nodes_r(node, seen, result)
|
|
335
|
-
return result
|
|
303
|
+
nodes = self._check_nodes(nodes)
|
|
304
|
+
return self._reachable_op_nodes(nodes)
|
|
336
305
|
|
|
337
306
|
def dump(
|
|
338
307
|
self,
|
|
@@ -397,7 +366,41 @@ cdef class Circuit:
|
|
|
397
366
|
args_str = ' '.join(node_name[id(arg)] for arg in op.args)
|
|
398
367
|
print(f'{next_prefix}{op_name}: {args_str}')
|
|
399
368
|
|
|
400
|
-
cdef
|
|
369
|
+
cdef OpNode op(self, int symbol, tuple[CircuitNode, ...] nodes):
|
|
370
|
+
cdef OpNode node = OpNode(self, symbol, nodes)
|
|
371
|
+
self.ops.append(node)
|
|
372
|
+
return node
|
|
373
|
+
|
|
374
|
+
cdef list[OpNode] _reachable_op_nodes(self, list[CircuitNode] nodes):
|
|
375
|
+
# Set of object ids for all reachable op nodes
|
|
376
|
+
cdef set[int] seen = set()
|
|
377
|
+
|
|
378
|
+
cdef list[OpNode] result = []
|
|
379
|
+
|
|
380
|
+
cdef CircuitNode node
|
|
381
|
+
for node in nodes:
|
|
382
|
+
_reachable_op_nodes_r(node, seen, result)
|
|
383
|
+
return result
|
|
384
|
+
|
|
385
|
+
cdef void _remove_unreachable_op_nodes(self, list[CircuitNode] nodes):
|
|
386
|
+
# Set of object ids for all reachable op nodes
|
|
387
|
+
cdef set[int] seen = set()
|
|
388
|
+
|
|
389
|
+
cdef CircuitNode node
|
|
390
|
+
for node in nodes:
|
|
391
|
+
_reachable_op_nodes_seen_r(node, seen)
|
|
392
|
+
|
|
393
|
+
if len(seen) < len(self.ops):
|
|
394
|
+
# Invalidate unreadable op nodes
|
|
395
|
+
for op_node in self.ops:
|
|
396
|
+
if id(op_node) not in seen:
|
|
397
|
+
op_node.circuit = None
|
|
398
|
+
op_node.args = ()
|
|
399
|
+
|
|
400
|
+
# Keep only reachable op nodes, in the same order as `self.ops`.
|
|
401
|
+
self.ops = [op_node for op_node in self.ops if id(op_node) in seen]
|
|
402
|
+
|
|
403
|
+
cdef list[CircuitNode] _check_nodes(self, object nodes: Iterable[Args]): # -> Sequence[CircuitNode]:
|
|
401
404
|
# Convert the given circuit nodes to a tuple, flattening nested iterables as needed.
|
|
402
405
|
#
|
|
403
406
|
# Args:
|
|
@@ -405,11 +408,11 @@ cdef class Circuit:
|
|
|
405
408
|
#
|
|
406
409
|
# Raises:
|
|
407
410
|
# RuntimeError: if any node does not belong to this circuit.
|
|
408
|
-
cdef list[
|
|
411
|
+
cdef list[CircuitNode] result = []
|
|
409
412
|
self.__check_nodes(nodes, result)
|
|
410
413
|
return result
|
|
411
414
|
|
|
412
|
-
cdef __check_nodes(self, nodes: Iterable[Args], list[
|
|
415
|
+
cdef void __check_nodes(self, object nodes: Iterable[Args], list[CircuitNode] result):
|
|
413
416
|
# Convert the given circuit nodes to a tuple, flattening nested iterables as needed.
|
|
414
417
|
#
|
|
415
418
|
# Args:
|
|
@@ -428,7 +431,7 @@ cdef class Circuit:
|
|
|
428
431
|
else:
|
|
429
432
|
self.__check_nodes(node, result)
|
|
430
433
|
|
|
431
|
-
cdef object _derivatives(self,
|
|
434
|
+
cdef object _derivatives(self, CircuitNode f):
|
|
432
435
|
# Get a _DerivativeHelper for `f`.
|
|
433
436
|
# Checking the derivative cache.
|
|
434
437
|
derivatives: Optional[_DerivativeHelper] = self.__derivatives
|
|
@@ -437,7 +440,6 @@ cdef class Circuit:
|
|
|
437
440
|
self.__derivatives = derivatives
|
|
438
441
|
return derivatives
|
|
439
442
|
|
|
440
|
-
|
|
441
443
|
cdef class CircuitNode:
|
|
442
444
|
"""
|
|
443
445
|
A node in an arithmetic circuit.
|
|
@@ -452,16 +454,11 @@ cdef class CircuitNode:
|
|
|
452
454
|
A var node may be temporarily set to be a constant node, which may
|
|
453
455
|
be useful for optimising a compiled circuit.
|
|
454
456
|
"""
|
|
455
|
-
cdef public object circuit
|
|
456
457
|
|
|
457
|
-
def __init__(self, circuit):
|
|
458
|
+
def __init__(self, circuit: Circuit, is_zero: bool, is_one: bool):
|
|
458
459
|
self.circuit = circuit
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
return False
|
|
462
|
-
|
|
463
|
-
cpdef int is_one(self) except*:
|
|
464
|
-
return False
|
|
460
|
+
self.is_zero = is_zero
|
|
461
|
+
self.is_one = is_one
|
|
465
462
|
|
|
466
463
|
def __add__(self, other: CircuitNode | ConstValue):
|
|
467
464
|
return self.circuit.add(self, other)
|
|
@@ -469,24 +466,14 @@ cdef class CircuitNode:
|
|
|
469
466
|
def __mul__(self, other: CircuitNode | ConstValue):
|
|
470
467
|
return self.circuit.mul(self, other)
|
|
471
468
|
|
|
472
|
-
|
|
473
469
|
cdef class ConstNode(CircuitNode):
|
|
474
|
-
cdef public object value
|
|
475
|
-
|
|
476
470
|
"""
|
|
477
471
|
A node in a circuit representing a constant value.
|
|
478
472
|
"""
|
|
479
|
-
def __init__(self, circuit, value: ConstValue):
|
|
480
|
-
super().__init__(circuit)
|
|
481
|
-
self.value: ConstValue = value
|
|
482
|
-
|
|
483
|
-
cpdef int is_zero(self) except*:
|
|
484
|
-
# noinspection PyProtectedMember
|
|
485
|
-
return self is self.circuit.zero
|
|
486
473
|
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
474
|
+
def __init__(self, circuit, value: ConstValue, is_zero: bool = False, is_one: bool = False):
|
|
475
|
+
super().__init__(circuit, is_zero, is_one)
|
|
476
|
+
self.value: ConstValue = value
|
|
490
477
|
|
|
491
478
|
def __str__(self) -> str:
|
|
492
479
|
return 'const(' + str(self.value) + ')'
|
|
@@ -497,25 +484,16 @@ cdef class ConstNode(CircuitNode):
|
|
|
497
484
|
else:
|
|
498
485
|
return False
|
|
499
486
|
|
|
500
|
-
|
|
501
487
|
cdef class VarNode(CircuitNode):
|
|
502
488
|
"""
|
|
503
489
|
A node in a circuit representing an input variable.
|
|
504
490
|
"""
|
|
505
|
-
cdef public int idx
|
|
506
|
-
cdef object _const
|
|
507
491
|
|
|
508
492
|
def __init__(self, circuit, idx: int):
|
|
509
|
-
super().__init__(circuit)
|
|
493
|
+
super().__init__(circuit, False, False)
|
|
510
494
|
self.idx = idx
|
|
511
495
|
self._const = None
|
|
512
496
|
|
|
513
|
-
cpdef int is_zero(self) except*:
|
|
514
|
-
return self._const is not None and self._const.is_zero()
|
|
515
|
-
|
|
516
|
-
cpdef int is_one(self) except*:
|
|
517
|
-
return self._const is not None and self._const.is_one()
|
|
518
|
-
|
|
519
497
|
cpdef int is_const(self) except*:
|
|
520
498
|
return self._const is not None
|
|
521
499
|
|
|
@@ -527,8 +505,13 @@ cdef class VarNode(CircuitNode):
|
|
|
527
505
|
def const(self, value: ConstValue | ConstNode | None) -> None:
|
|
528
506
|
if value is None:
|
|
529
507
|
self._const = None
|
|
508
|
+
self.is_zero = False
|
|
509
|
+
self.is_one = False
|
|
530
510
|
else:
|
|
531
|
-
|
|
511
|
+
const_node: ConstNode = self.circuit.const(value)
|
|
512
|
+
self._const = const_node
|
|
513
|
+
self.is_zero = const_node.is_zero
|
|
514
|
+
self.is_one = const_node.is_one
|
|
532
515
|
|
|
533
516
|
def __lt__(self, other) -> bool:
|
|
534
517
|
if isinstance(other, VarNode):
|
|
@@ -546,13 +529,11 @@ cdef class OpNode(CircuitNode):
|
|
|
546
529
|
"""
|
|
547
530
|
A node in a circuit representing an arithmetic operation.
|
|
548
531
|
"""
|
|
549
|
-
cdef public tuple[object, ...] args
|
|
550
|
-
cdef public int symbol
|
|
551
532
|
|
|
552
533
|
def __init__(self, object circuit, symbol: int, tuple[object, ...] args: Tuple[CircuitNode]):
|
|
553
|
-
super().__init__(circuit)
|
|
534
|
+
super().__init__(circuit, False, False)
|
|
554
535
|
self.args = tuple(args)
|
|
555
|
-
self.symbol = int
|
|
536
|
+
self.symbol = <int> symbol
|
|
556
537
|
|
|
557
538
|
def __str__(self) -> str:
|
|
558
539
|
return f'{self.op_str()}\\{len(self.args)}'
|
|
@@ -561,9 +542,9 @@ cdef class OpNode(CircuitNode):
|
|
|
561
542
|
"""
|
|
562
543
|
Returns the op node operation as a string.
|
|
563
544
|
"""
|
|
564
|
-
if self.symbol ==
|
|
545
|
+
if self.symbol == c_MUL:
|
|
565
546
|
return 'mul'
|
|
566
|
-
elif self.symbol ==
|
|
547
|
+
elif self.symbol == c_ADD:
|
|
567
548
|
return 'add'
|
|
568
549
|
else:
|
|
569
550
|
return '?' + str(self.symbol)
|
|
@@ -661,7 +642,7 @@ class _DerivativeHelper:
|
|
|
661
642
|
elif d is self.one:
|
|
662
643
|
d_node.derivative_self_mul = node
|
|
663
644
|
else:
|
|
664
|
-
d_node.derivative_self_mul = self.circuit.optimised_mul(
|
|
645
|
+
d_node.derivative_self_mul = self.circuit.optimised_mul(d, node)
|
|
665
646
|
|
|
666
647
|
return d_node.derivative_self_mul
|
|
667
648
|
|
|
@@ -673,13 +654,13 @@ class _DerivativeHelper:
|
|
|
673
654
|
to_add: Sequence[CircuitNode] = tuple(
|
|
674
655
|
value
|
|
675
656
|
for value in (self._derivative_prod(prods) for prods in d_node.sum_prod)
|
|
676
|
-
if not value.is_zero
|
|
657
|
+
if not value.is_zero
|
|
677
658
|
)
|
|
678
659
|
# we can release the temporary memory at this DNode now
|
|
679
660
|
d_node.sum_prod = None
|
|
680
661
|
|
|
681
662
|
# Construct the addition operation
|
|
682
|
-
d_node.derivative = self.circuit.optimised_add(to_add)
|
|
663
|
+
d_node.derivative = self.circuit.optimised_add(*to_add)
|
|
683
664
|
|
|
684
665
|
return d_node.derivative
|
|
685
666
|
|
|
@@ -700,7 +681,7 @@ class _DerivativeHelper:
|
|
|
700
681
|
to_mul.append(arg)
|
|
701
682
|
|
|
702
683
|
# Construct the multiplication operation
|
|
703
|
-
return self.circuit.optimised_mul(to_mul)
|
|
684
|
+
return self.circuit.optimised_mul(*to_mul)
|
|
704
685
|
|
|
705
686
|
def _mk_derivative_r(self, d_node: _DNode) -> None:
|
|
706
687
|
"""
|
|
@@ -712,11 +693,11 @@ class _DerivativeHelper:
|
|
|
712
693
|
node: CircuitNode = d_node.node
|
|
713
694
|
|
|
714
695
|
if isinstance(node, OpNode):
|
|
715
|
-
if node.symbol ==
|
|
696
|
+
if node.symbol == c_ADD:
|
|
716
697
|
for arg in node.args:
|
|
717
698
|
child_d_node = self._add(arg, d_node, [])
|
|
718
699
|
self._mk_derivative_r(child_d_node)
|
|
719
|
-
elif node.symbol ==
|
|
700
|
+
elif node.symbol == c_MUL:
|
|
720
701
|
for arg in node.args:
|
|
721
702
|
prod = [arg2 for arg2 in node.args if arg is not arg2]
|
|
722
703
|
child_d_node = self._add(arg, d_node, prod)
|
|
@@ -748,6 +729,8 @@ class _DerivativeHelper:
|
|
|
748
729
|
|
|
749
730
|
def _get(self, node: CircuitNode) -> _DNode:
|
|
750
731
|
"""
|
|
732
|
+
Helper for derivatives.
|
|
733
|
+
|
|
751
734
|
Get the DNode for the given circuit node.
|
|
752
735
|
If no DNode exist for it yet, then one will be constructed.
|
|
753
736
|
"""
|
|
@@ -759,21 +742,20 @@ class _DerivativeHelper:
|
|
|
759
742
|
return d_node
|
|
760
743
|
|
|
761
744
|
|
|
762
|
-
cdef void _reachable_op_nodes_r(
|
|
745
|
+
cdef void _reachable_op_nodes_r(CircuitNode node, set[int] seen, list[OpNode] result):
|
|
763
746
|
# Recursive helper for `reachable_op_nodes`. Performs a depth-first search.
|
|
764
747
|
#
|
|
765
748
|
# Args:
|
|
766
749
|
# node: the current node being checked.
|
|
767
750
|
# seen: keep track of seen op node ids (to avoid returning multiple of the same node).
|
|
768
|
-
# result: a list where the nodes are added
|
|
751
|
+
# result: a list where the seen nodes are added.
|
|
769
752
|
if isinstance(node, OpNode) and id(node) not in seen:
|
|
770
753
|
seen.add(id(node))
|
|
771
754
|
for arg in node.args:
|
|
772
755
|
_reachable_op_nodes_r(arg, seen, result)
|
|
773
756
|
result.append(node)
|
|
774
757
|
|
|
775
|
-
|
|
776
|
-
cdef void _reachable_op_nodes_seen_r(object node: CircuitNode, set seen: Set[int]):
|
|
758
|
+
cdef void _reachable_op_nodes_seen_r(CircuitNode node, set[int] seen):
|
|
777
759
|
# Recursive helper for `remove_unreachable_op_nodes`. Performs a depth-first search.
|
|
778
760
|
#
|
|
779
761
|
# Args:
|
|
@@ -181,7 +181,7 @@ class Circuit:
|
|
|
181
181
|
* singleton addition is avoided: add(x) = x,
|
|
182
182
|
* empty addition is avoided: add() = 0,
|
|
183
183
|
"""
|
|
184
|
-
to_add = tuple(n for n in self._check_nodes(args) if not n.is_zero
|
|
184
|
+
to_add = tuple(n for n in self._check_nodes(args) if not n.is_zero)
|
|
185
185
|
match len(to_add):
|
|
186
186
|
case 0:
|
|
187
187
|
return self.zero
|
|
@@ -200,8 +200,8 @@ class Circuit:
|
|
|
200
200
|
* singleton multiplication is avoided: mul(x) = x,
|
|
201
201
|
* empty multiplication is avoided: mul() = 1,
|
|
202
202
|
"""
|
|
203
|
-
to_mul = tuple(n for n in self._check_nodes(args) if not n.is_one
|
|
204
|
-
if any(n.is_zero
|
|
203
|
+
to_mul = tuple(n for n in self._check_nodes(args) if not n.is_one)
|
|
204
|
+
if any(n.is_zero for n in to_mul):
|
|
205
205
|
return self.zero
|
|
206
206
|
match len(to_mul):
|
|
207
207
|
case 0:
|
|
@@ -485,12 +485,14 @@ class CircuitNode:
|
|
|
485
485
|
def __init__(self, circuit: Circuit):
|
|
486
486
|
self.circuit = circuit
|
|
487
487
|
|
|
488
|
+
@property
|
|
488
489
|
def is_zero(self) -> bool:
|
|
489
490
|
"""
|
|
490
491
|
Does this node represent the constant zero.
|
|
491
492
|
"""
|
|
492
493
|
return False
|
|
493
494
|
|
|
495
|
+
@property
|
|
494
496
|
def is_one(self) -> bool:
|
|
495
497
|
"""
|
|
496
498
|
Does this node represent the constant one.
|
|
@@ -522,10 +524,12 @@ class ConstNode(CircuitNode):
|
|
|
522
524
|
def value(self) -> ConstValue:
|
|
523
525
|
return self._value
|
|
524
526
|
|
|
527
|
+
@property
|
|
525
528
|
def is_zero(self) -> bool:
|
|
526
529
|
# noinspection PyProtectedMember
|
|
527
530
|
return self is self.circuit.zero
|
|
528
531
|
|
|
532
|
+
@property
|
|
529
533
|
def is_one(self) -> bool:
|
|
530
534
|
# noinspection PyProtectedMember
|
|
531
535
|
return self is self.circuit.one
|
|
@@ -569,11 +573,13 @@ class VarNode(CircuitNode):
|
|
|
569
573
|
else:
|
|
570
574
|
self._const = self.circuit.const(value)
|
|
571
575
|
|
|
576
|
+
@property
|
|
572
577
|
def is_zero(self) -> bool:
|
|
573
|
-
return self._const is not None and self._const.is_zero
|
|
578
|
+
return self._const is not None and self._const.is_zero
|
|
574
579
|
|
|
580
|
+
@property
|
|
575
581
|
def is_one(self) -> bool:
|
|
576
|
-
return self._const is not None and self._const.is_one
|
|
582
|
+
return self._const is not None and self._const.is_one
|
|
577
583
|
|
|
578
584
|
def __lt__(self, other) -> bool:
|
|
579
585
|
if isinstance(other, VarNode):
|
|
@@ -707,13 +713,15 @@ class _DerivativeHelper:
|
|
|
707
713
|
to_add: Sequence[CircuitNode] = tuple(
|
|
708
714
|
value
|
|
709
715
|
for value in (self._derivative_prod(prods) for prods in d_node.sum_prod)
|
|
710
|
-
if not value.is_zero
|
|
716
|
+
if not value.is_zero
|
|
711
717
|
)
|
|
712
718
|
# We can release the temporary memory at this DNode now
|
|
719
|
+
# Warning disabled as we will never use this field again - doing so would be an error.
|
|
720
|
+
# noinspection PyTypeChecker
|
|
713
721
|
d_node.sum_prod = None
|
|
714
722
|
|
|
715
723
|
# Construct the addition operation
|
|
716
|
-
d_node.derivative = self.circuit.optimised_add(to_add)
|
|
724
|
+
d_node.derivative = self.circuit.optimised_add(*to_add)
|
|
717
725
|
|
|
718
726
|
return d_node.derivative
|
|
719
727
|
|
|
@@ -734,7 +742,7 @@ class _DerivativeHelper:
|
|
|
734
742
|
to_mul.append(arg)
|
|
735
743
|
|
|
736
744
|
# Construct the multiplication operation
|
|
737
|
-
return self.circuit.optimised_mul(to_mul)
|
|
745
|
+
return self.circuit.optimised_mul(*to_mul)
|
|
738
746
|
|
|
739
747
|
def _mk_derivative_r(self, d_node: _DNode) -> None:
|
|
740
748
|
"""
|
|
Binary file
|