compiled-knowledge 4.0.0a17__cp312-cp312-win_amd64.whl → 4.0.0a19__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (27) hide show
  1. ck/circuit/__init__.py +4 -0
  2. ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
  3. ck/circuit/_circuit_cy.pxd +32 -0
  4. ck/circuit/_circuit_cy.pyx +157 -182
  5. ck/circuit/_circuit_py.py +2 -2
  6. ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
  7. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +193 -79
  8. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +29 -4
  9. ck/circuit_compiler/support/circuit_analyser/__init__.py +13 -0
  10. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win_amd64.pyd +0 -0
  11. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.pyx +98 -0
  12. ck/circuit_compiler/support/{circuit_analyser.py → circuit_analyser/_circuit_analyser_py.py} +14 -2
  13. ck/pgm_compiler/ace/__init__.py +1 -1
  14. ck/pgm_compiler/support/circuit_table/__init__.py +8 -0
  15. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
  16. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.pyx +44 -37
  17. ck/pgm_compiler/support/circuit_table/_circuit_table_py.py +76 -41
  18. ck/pgm_compiler/support/named_compiler_maker.py +12 -2
  19. ck/utils/iter_extras.py +8 -1
  20. ck_demos/ace/demo_ace.py +5 -0
  21. ck_demos/utils/compare.py +5 -1
  22. {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a19.dist-info}/METADATA +1 -1
  23. {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a19.dist-info}/RECORD +26 -23
  24. {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a19.dist-info}/WHEEL +1 -1
  25. ck/circuit_compiler/cython_vm_compiler/_compiler.c +0 -16946
  26. {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a19.dist-info}/licenses/LICENSE.txt +0 -0
  27. {compiled_knowledge-4.0.0a17.dist-info → compiled_knowledge-4.0.0a19.dist-info}/top_level.txt +0 -0
ck/circuit/__init__.py CHANGED
@@ -1,3 +1,7 @@
1
+ # Two implementations of the `circuit` module are provided
2
+ # for developer R&D purposes. One is pure Python and the other is Cython.
3
+ # Which implementation is used can be selected here.
4
+
1
5
  # from ._circuit_py import (
2
6
  from ._circuit_cy import (
3
7
  Circuit,
Binary file
@@ -0,0 +1,32 @@
1
+ cdef class Circuit:
2
+ cdef public list[VarNode] vars
3
+ cdef public list[OpNode] ops
4
+ cdef public ConstNode zero
5
+ cdef public ConstNode one
6
+ cdef dict[object, ConstNode] _const_map
7
+ cdef object __derivatives
8
+
9
+ cdef OpNode op(self, int symbol, tuple[CircuitNode, ...] nodes)
10
+ cdef list[OpNode] find_reachable_op_nodes(self, list[CircuitNode] nodes)
11
+ cdef void _remove_unreachable_op_nodes(self, list[CircuitNode] nodes)
12
+ cdef list[CircuitNode] _check_nodes(self, object nodes)
13
+ cdef void __check_nodes(self, object nodes, list[CircuitNode] result)
14
+ cdef object _derivatives(self, CircuitNode f)
15
+
16
+ cdef class CircuitNode:
17
+ cdef public Circuit circuit
18
+ cdef public bint is_zero
19
+ cdef public bint is_one
20
+
21
+ cdef class ConstNode(CircuitNode):
22
+ cdef public object value
23
+
24
+ cdef class VarNode(CircuitNode):
25
+ cdef public int idx
26
+ cdef object _const
27
+
28
+ cpdef int is_const(self) except*
29
+
30
+ cdef class OpNode(CircuitNode):
31
+ cdef public tuple[object, ...] args
32
+ cdef public int symbol
@@ -4,7 +4,7 @@ For more documentation on this module, refer to the Jupyter notebook docs/6_circ
4
4
  from __future__ import annotations
5
5
 
6
6
  from itertools import chain
7
- from typing import Dict, Tuple, Optional, Iterable, Sequence, List, overload, Any
7
+ from typing import Dict, Optional, Iterable, Sequence, List, overload
8
8
 
9
9
  # Type for values of ConstNode objects
10
10
  ConstValue = float | int | bool
@@ -15,6 +15,8 @@ Args = CircuitNode | ConstValue | Iterable[CircuitNode | ConstValue]
15
15
  ADD: int = 0
16
16
  MUL: int = 1
17
17
 
18
+ cdef int c_ADD = ADD
19
+ cdef int c_MUL = MUL
18
20
 
19
21
  cdef class Circuit:
20
22
  """
@@ -29,13 +31,6 @@ cdef class Circuit:
29
31
  `VarNode` may temporarily be set to a constant value.
30
32
  """
31
33
 
32
- cdef public list[VarNode] vars
33
- cdef public list[OpNode] ops
34
- cdef public object zero
35
- cdef public object one
36
- cdef dict[Any, ConstNode] _const_map
37
- cdef object __derivatives
38
-
39
34
  def __init__(self, zero: ConstValue = 0, one: ConstValue = 1):
40
35
  """
41
36
  Construct a new, empty circuit.
@@ -46,13 +41,11 @@ cdef class Circuit:
46
41
  """
47
42
  self.vars: List[VarNode] = []
48
43
  self.ops: List[OpNode] = []
49
- self._const_map: Dict[ConstValue, ConstNode] = {}
50
- self.__derivatives: Optional[_DerivativeHelper] = None # cache for partial derivatives calculations.
51
44
  self.zero: ConstNode = ConstNode(self, zero, is_zero=True)
52
45
  self.one: ConstNode = ConstNode(self, one, is_one=True)
53
46
 
54
- self._const_map[zero] = self.zero
55
- self._const_map[one] = self.one
47
+ self._const_map: Dict[ConstValue, ConstNode] = {zero: self.zero, one: self.one}
48
+ self.__derivatives: Optional[_DerivativeHelper] = None # cache for partial derivatives calculations.
56
49
 
57
50
  @property
58
51
  def number_of_vars(self) -> int:
@@ -127,72 +120,57 @@ cdef class Circuit:
127
120
  self._const_map[value] = node
128
121
  return node
129
122
 
130
- cdef OpNode _op(self, int symbol, tuple[CircuitNode, ...] nodes):
131
- cdef OpNode node = OpNode(self, symbol, nodes)
132
- self.ops.append(node)
133
- return node
134
-
135
123
  def add(self, *nodes: Args) -> OpNode:
136
124
  """
137
125
  Create and return a new 'addition' node, applied to the given arguments.
138
126
  """
139
127
  cdef list[CircuitNode] args = self._check_nodes(nodes)
140
- return self._op(ADD, tuple(args))
128
+ return self.op(c_ADD, tuple(args))
141
129
 
142
130
  def mul(self, *nodes: Args) -> OpNode:
143
131
  """
144
132
  Create and return a new 'multiplication' node, applied to the given arguments.
145
133
  """
146
134
  cdef list[CircuitNode] args = self._check_nodes(nodes)
147
- return self._op(MUL, tuple(args))
135
+ return self.op(c_MUL, tuple(args))
148
136
 
149
- cpdef object optimised_add(self, object nodes: Iterable[CircuitNode]): # -> CircuitNode:
150
- # Optimised circuit node addition.
151
- #
152
- # Performs the following optimisations:
153
- # * addition to zero is avoided: add(x, 0) = x,
154
- # * singleton addition is avoided: add(x) = x,
155
- # * empty addition is avoided: add() = 0.
156
-
157
- cdef list[CircuitNode] to_add = []
158
- cdef CircuitNode n
159
- for n in nodes:
160
- if n.circuit is not self:
161
- raise RuntimeError('node does not belong to this circuit')
162
- if not n.is_zero:
163
- to_add.append(n)
164
- cdef int len_to_add = len(to_add)
165
- if len_to_add == 0:
137
+ def optimised_add(self, *args: Args) -> CircuitNode:
138
+ """
139
+ Optimised circuit node addition.
140
+
141
+ Performs the following optimisations:
142
+ * addition to zero is avoided: add(x, 0) = x,
143
+ * singleton addition is avoided: add(x) = x,
144
+ * empty addition is avoided: add() = 0,
145
+ """
146
+ cdef tuple[CircuitNode, ...] to_add = tuple(n for n in self._check_nodes(args) if not n.is_zero)
147
+ cdef size_t num_to_add = len(to_add)
148
+ if num_to_add == 0:
166
149
  return self.zero
167
- elif len_to_add == 1:
150
+ if num_to_add == 1:
168
151
  return to_add[0]
169
- else:
170
- return self._op(ADD, tuple(to_add))
152
+ return self.op(c_ADD, to_add)
171
153
 
172
- cpdef object optimised_mul(self, object nodes: Iterable[CircuitNode]): # -> CircuitNode:
173
- # Optimised circuit node multiplication.
174
- #
175
- # Performs the following optimisations:
176
- # * multiplication by zero is avoided: mul(x, 0) = 0,
177
- # * multiplication by one is avoided: mul(x, 1) = x,
178
- # * singleton multiplication is avoided: mul(x) = x,
179
- # * empty multiplication is avoided: mul() = 1.
180
- cdef list[CircuitNode] to_mul = []
181
- cdef CircuitNode n
182
- for n in nodes:
183
- if n.circuit is not self:
184
- raise RuntimeError('node does not belong to this circuit')
185
- if n.is_zero:
186
- return self.zero
187
- if not n.is_one:
188
- to_mul.append(n)
189
- cdef int len_to_mul = len(to_mul)
190
- if len_to_mul == 0:
154
+ def optimised_mul(self, *args: Args) -> CircuitNode:
155
+ """
156
+ Optimised circuit node multiplication.
157
+
158
+ Performs the following optimisations:
159
+ * multiplication by zero is avoided: mul(x, 0) = 0,
160
+ * multiplication by one is avoided: mul(x, 1) = x,
161
+ * singleton multiplication is avoided: mul(x) = x,
162
+ * empty multiplication is avoided: mul() = 1,
163
+ """
164
+ cdef tuple[CircuitNode, ...] to_mul = tuple(n for n in self._check_nodes(args) if not n.is_one)
165
+ if any(n.is_zero for n in to_mul):
166
+ return self.zero
167
+ cdef Py_ssize_t num_to_mul = len(to_mul)
168
+
169
+ if num_to_mul == 0:
191
170
  return self.one
192
- elif len_to_mul == 1:
171
+ if num_to_mul == 1:
193
172
  return to_mul[0]
194
- else:
195
- return self._op(MUL, tuple(to_mul))
173
+ return self.op(c_MUL, to_mul)
196
174
 
197
175
  def cartesian_product(self, xs: Sequence[CircuitNode], ys: Sequence[CircuitNode]) -> List[CircuitNode]:
198
176
  """
@@ -206,12 +184,12 @@ cdef class Circuit:
206
184
  a list of 'mul' nodes, one for each combination of xs and ys. The results are in the order
207
185
  given by `[mul(x, y) for x in xs for y in ys]`.
208
186
  """
209
- xs: List[CircuitNode] = self._check_nodes(xs)
210
- ys: List[CircuitNode] = self._check_nodes(ys)
187
+ cdef list[CircuitNode] xs_list = self._check_nodes(xs)
188
+ cdef list[CircuitNode] ys_list = self._check_nodes(ys)
211
189
  return [
212
- self.optimised_mul((x, y))
213
- for x in xs
214
- for y in ys
190
+ self.optimised_mul(x, y)
191
+ for x in xs_list
192
+ for y in ys_list
215
193
  ]
216
194
 
217
195
  @overload
@@ -262,10 +240,10 @@ cdef class Circuit:
262
240
  If `args` is a single CircuitNode, then a single CircuitNode will be returned, otherwise
263
241
  a list of CircuitNode is returned.
264
242
  """
265
- single_result: bool = isinstance(args, CircuitNode)
243
+ cdef bint single_result = isinstance(args, CircuitNode)
266
244
 
267
- args: List[CircuitNode] = self._check_nodes([args])
268
- if len(args) == 0:
245
+ cdef list[CircuitNode] args_list = self._check_nodes([args])
246
+ if len(args_list) == 0:
269
247
  # Trivial case
270
248
  return []
271
249
 
@@ -274,12 +252,12 @@ cdef class Circuit:
274
252
  if self_multiply:
275
253
  result = [
276
254
  derivatives.derivative_self_mul(arg)
277
- for arg in args
255
+ for arg in args_list
278
256
  ]
279
257
  else:
280
258
  result = [
281
259
  derivatives.derivative(arg)
282
- for arg in args
260
+ for arg in args_list
283
261
  ]
284
262
 
285
263
  if single_result:
@@ -303,26 +281,8 @@ cdef class Circuit:
303
281
  Args:
304
282
  *nodes: may be either a node or a list of nodes.
305
283
  """
306
- nodes = self._check_nodes(nodes)
307
- self._remove_unreachable_op_nodes(nodes)
308
-
309
- cdef void _remove_unreachable_op_nodes(self, list[CircuitNode] nodes):
310
- # Set of object ids for all reachable op nodes
311
- cdef set[int] seen = set()
312
-
313
- cdef CircuitNode node
314
- for node in nodes:
315
- _reachable_op_nodes_seen_r(node, seen)
316
-
317
- if len(seen) < len(self.ops):
318
- # Invalidate unreadable op nodes
319
- for op_node in self.ops:
320
- if id(op_node) not in seen:
321
- op_node.circuit = None
322
- op_node.args = ()
323
-
324
- # Keep only reachable op nodes, in the same order as `self.ops`.
325
- self.ops = [op_node for op_node in self.ops if id(op_node) in seen]
284
+ cdef list[CircuitNode] node_list = self._check_nodes(nodes)
285
+ self._remove_unreachable_op_nodes(node_list)
326
286
 
327
287
  def reachable_op_nodes(self, *nodes: Args) -> List[OpNode]:
328
288
  """
@@ -338,19 +298,8 @@ cdef class Circuit:
338
298
  Returned nodes are not repeated.
339
299
  The result is ordered such that if result[i] is referenced by result[j] then i < j.
340
300
  """
341
- nodes = self._check_nodes(nodes)
342
- return self._reachable_op_nodes(nodes)
343
-
344
- cdef list[OpNode] _reachable_op_nodes(self, list[CircuitNode] nodes):
345
- # Set of object ids for all reachable op nodes
346
- cdef set[int] seen = set()
347
-
348
- cdef list[OpNode] result = []
349
-
350
- cdef CircuitNode node
351
- for node in nodes:
352
- _reachable_op_nodes_r(node, seen, result)
353
- return result
301
+ cdef list[CircuitNode] node_list = self._check_nodes(nodes)
302
+ return self.find_reachable_op_nodes(node_list)
354
303
 
355
304
  def dump(
356
305
  self,
@@ -415,8 +364,42 @@ cdef class Circuit:
415
364
  args_str = ' '.join(node_name[id(arg)] for arg in op.args)
416
365
  print(f'{next_prefix}{op_name}: {args_str}')
417
366
 
367
+ cdef OpNode op(self, int symbol, tuple[CircuitNode, ...] nodes):
368
+ cdef OpNode node = OpNode(self, symbol, nodes)
369
+ self.ops.append(node)
370
+ return node
371
+
372
+ cdef list[OpNode] find_reachable_op_nodes(self, list[CircuitNode] nodes):
373
+ # Set of object ids for all reachable op nodes
374
+ cdef set[int] seen = set()
375
+
376
+ cdef list[OpNode] result = []
377
+
378
+ cdef CircuitNode node
379
+ for node in nodes:
380
+ find_reachable_op_nodes_r(node, seen, result)
381
+ return result
382
+
383
+ cdef void _remove_unreachable_op_nodes(self, list[CircuitNode] nodes):
384
+ # Set of object ids for all reachable op nodes
385
+ cdef set[int] seen = set()
386
+
387
+ cdef CircuitNode node
388
+ for node in nodes:
389
+ find_reachable_op_nodes_seen_r(node, seen)
390
+
391
+ if len(seen) < len(self.ops):
392
+ # Invalidate unreachable op nodes
393
+ for op_node in self.ops:
394
+ if id(op_node) not in seen:
395
+ op_node.circuit = None
396
+ op_node.args = ()
397
+
398
+ # Keep only reachable op nodes, in the same order as `self.ops`.
399
+ self.ops = [op_node for op_node in self.ops if id(op_node) in seen]
400
+
418
401
  cdef list[CircuitNode] _check_nodes(self, object nodes: Iterable[Args]): # -> Sequence[CircuitNode]:
419
- # Convert the given circuit nodes to a tuple, flattening nested iterables as needed.
402
+ # Convert the given circuit nodes to a list, flattening nested iterables as needed.
420
403
  #
421
404
  # Args:
422
405
  # nodes: some circuit nodes or constant values.
@@ -427,8 +410,8 @@ cdef class Circuit:
427
410
  self.__check_nodes(nodes, result)
428
411
  return result
429
412
 
430
- cdef void __check_nodes(self, nodes: Iterable[Args], list[CircuitNode] result):
431
- # Convert the given circuit nodes to a tuple, flattening nested iterables as needed.
413
+ cdef void __check_nodes(self, object nodes: Iterable[Args], list[CircuitNode] result):
414
+ # Convert the given circuit nodes to a list, flattening nested iterables as needed.
432
415
  #
433
416
  # Args:
434
417
  # nodes: some circuit nodes of constant values.
@@ -455,7 +438,6 @@ cdef class Circuit:
455
438
  self.__derivatives = derivatives
456
439
  return derivatives
457
440
 
458
-
459
441
  cdef class CircuitNode:
460
442
  """
461
443
  A node in an arithmetic circuit.
@@ -470,9 +452,6 @@ cdef class CircuitNode:
470
452
  A var node may be temporarily set to be a constant node, which may
471
453
  be useful for optimising a compiled circuit.
472
454
  """
473
- cdef public Circuit circuit
474
- cdef public bint is_zero
475
- cdef public bint is_one
476
455
 
477
456
  def __init__(self, circuit: Circuit, is_zero: bool, is_one: bool):
478
457
  self.circuit = circuit
@@ -485,11 +464,10 @@ cdef class CircuitNode:
485
464
  def __mul__(self, other: CircuitNode | ConstValue):
486
465
  return self.circuit.mul(self, other)
487
466
 
488
-
489
467
  cdef class ConstNode(CircuitNode):
490
- # A node in a circuit representing a constant value.
491
-
492
- cdef public object value
468
+ """
469
+ A node in a circuit representing a constant value.
470
+ """
493
471
 
494
472
  def __init__(self, circuit, value: ConstValue, is_zero: bool = False, is_one: bool = False):
495
473
  super().__init__(circuit, is_zero, is_one)
@@ -504,13 +482,10 @@ cdef class ConstNode(CircuitNode):
504
482
  else:
505
483
  return False
506
484
 
507
-
508
485
  cdef class VarNode(CircuitNode):
509
486
  """
510
487
  A node in a circuit representing an input variable.
511
488
  """
512
- cdef public int idx
513
- cdef object _const
514
489
 
515
490
  def __init__(self, circuit, idx: int):
516
491
  super().__init__(circuit, False, False)
@@ -552,13 +527,11 @@ cdef class OpNode(CircuitNode):
552
527
  """
553
528
  A node in a circuit representing an arithmetic operation.
554
529
  """
555
- cdef public tuple[object, ...] args
556
- cdef public int symbol
557
530
 
558
- def __init__(self, object circuit, symbol: int, tuple[object, ...] args: Tuple[CircuitNode]):
531
+ def __init__(self, Circuit circuit, int symbol, tuple[CircuitNode, ...] args):
559
532
  super().__init__(circuit, False, False)
560
533
  self.args = tuple(args)
561
- self.symbol = int(symbol)
534
+ self.symbol = <int> symbol
562
535
 
563
536
  def __str__(self) -> str:
564
537
  return f'{self.op_str()}\\{len(self.args)}'
@@ -567,9 +540,9 @@ cdef class OpNode(CircuitNode):
567
540
  """
568
541
  Returns the op node operation as a string.
569
542
  """
570
- if self.symbol == MUL:
543
+ if self.symbol == c_MUL:
571
544
  return 'mul'
572
- elif self.symbol == ADD:
545
+ elif self.symbol == c_ADD:
573
546
  return 'add'
574
547
  else:
575
548
  return '?' + str(self.symbol)
@@ -579,11 +552,11 @@ cdef class _DNode:
579
552
  A data structure supporting derivative calculations.
580
553
  A DNode holds all information needed to calculate the partial derivative at `node`.
581
554
  """
582
- cdef public object node
583
- cdef public object derivative
584
- cdef public object derivative_self_mul
585
- cdef public list sum_prod
586
- cdef public bint processed
555
+ cdef CircuitNode node
556
+ cdef object derivative
557
+ cdef object derivative_self_mul
558
+ cdef list[_DNodeProduct] sum_prod
559
+ cdef bint processed
587
560
 
588
561
  def __init__(
589
562
  self,
@@ -615,8 +588,8 @@ cdef class _DNodeProduct:
615
588
 
616
589
  This represents a product of `parent` and `prod`.
617
590
  """
618
- cdef public object parent
619
- cdef public list prod
591
+ cdef _DNode parent
592
+ cdef list[CircuitNode] prod
620
593
 
621
594
  def __init__(self, parent: _DNode, prod: List[CircuitNode]):
622
595
  self.parent = parent
@@ -629,33 +602,40 @@ cdef class _DNodeProduct:
629
602
  return 'DNodeProduct(' + str(self.parent) + ', ' + str(self.prod) + ')'
630
603
 
631
604
 
632
- class _DerivativeHelper:
605
+ cdef class _DerivativeHelper:
633
606
  """
634
607
  A data structure to support efficient calculation of partial derivatives
635
608
  with respect to some function node `f`.
636
609
  """
637
610
 
611
+ cdef CircuitNode f
612
+ cdef CircuitNode zero
613
+ cdef CircuitNode one
614
+ cdef Circuit circuit
615
+ cdef dict[int, _DNode] d_nodes
616
+
638
617
  def __init__(self, f: CircuitNode):
639
618
  """
640
619
  Prepare to calculate partial derivatives with respect to `f`.
641
620
  """
642
- self.f: CircuitNode = f
643
- self.circuit: Circuit = f.circuit
644
- self.d_nodes: Dict[int, _DNode] = {} # map id(CircuitNode) to its DNode
621
+ self.f = f
622
+ self.circuit = f.circuit
623
+ self.d_nodes = {} # map id(CircuitNode) to its DNode
645
624
  self.zero = self.circuit.zero
646
625
  self.one = self.circuit.one
647
- top_d_node: _DNode = _DNode(f, self.one)
626
+
627
+ cdef _DNode top_d_node = _DNode(f, self.one)
648
628
  self.d_nodes[id(f)] = top_d_node
649
629
  self._mk_derivative_r(top_d_node)
650
630
 
651
- def derivative(self, node: CircuitNode) -> CircuitNode:
631
+ cdef CircuitNode derivative(self, CircuitNode node):
652
632
  d_node: Optional[_DNode] = self.d_nodes.get(id(node))
653
633
  if d_node is None:
654
634
  return self.zero
655
635
  else:
656
636
  return self._derivative(d_node)
657
637
 
658
- def derivative_self_mul(self, node: CircuitNode) -> CircuitNode:
638
+ cdef CircuitNode derivative_self_mul(self, CircuitNode node):
659
639
  d_node: Optional[_DNode] = self.d_nodes.get(id(node))
660
640
  if d_node is None:
661
641
  return self.zero
@@ -667,11 +647,11 @@ class _DerivativeHelper:
667
647
  elif d is self.one:
668
648
  d_node.derivative_self_mul = node
669
649
  else:
670
- d_node.derivative_self_mul = self.circuit.optimised_mul((d, node))
650
+ d_node.derivative_self_mul = self.circuit.optimised_mul(d, node)
671
651
 
672
652
  return d_node.derivative_self_mul
673
653
 
674
- def _derivative(self, d_node: _DNode) -> CircuitNode:
654
+ cdef CircuitNode _derivative(self, _DNode d_node):
675
655
  if d_node.derivative is not None:
676
656
  return d_node.derivative
677
657
 
@@ -685,14 +665,13 @@ class _DerivativeHelper:
685
665
  d_node.sum_prod = None
686
666
 
687
667
  # Construct the addition operation
688
- d_node.derivative = self.circuit.optimised_add(to_add)
668
+ d_node.derivative = self.circuit.optimised_add(*to_add)
689
669
 
690
670
  return d_node.derivative
691
671
 
692
- def _derivative_prod(self, prods: _DNodeProduct) -> CircuitNode:
693
- """
694
- Support `_derivative` by constructing the derivative product for the given _DNodeProduct.
695
- """
672
+ cdef CircuitNode _derivative_prod(self, _DNodeProduct prods):
673
+ # Support `_derivative` by constructing the derivative product for the given _DNodeProduct.
674
+
696
675
  # Get the derivative of the parent node.
697
676
  parent: CircuitNode = self._derivative(prods.parent)
698
677
 
@@ -706,59 +685,56 @@ class _DerivativeHelper:
706
685
  to_mul.append(arg)
707
686
 
708
687
  # Construct the multiplication operation
709
- return self.circuit.optimised_mul(to_mul)
688
+ return self.circuit.optimised_mul(*to_mul)
689
+
690
+ cdef void _mk_derivative_r(self, _DNode d_node):
691
+ # Construct a DNode for each argument of the given DNode.
710
692
 
711
- def _mk_derivative_r(self, d_node: _DNode) -> None:
712
- """
713
- Construct a DNode for each argument of the given DNode.
714
- """
715
693
  if d_node.processed:
716
694
  return
717
695
  d_node.processed = True
718
696
  node: CircuitNode = d_node.node
719
697
 
720
698
  if isinstance(node, OpNode):
721
- if node.symbol == ADD:
699
+ if node.symbol == c_ADD:
722
700
  for arg in node.args:
723
701
  child_d_node = self._add(arg, d_node, [])
724
702
  self._mk_derivative_r(child_d_node)
725
- elif node.symbol == MUL:
703
+ elif node.symbol == c_MUL:
726
704
  for arg in node.args:
727
705
  prod = [arg2 for arg2 in node.args if arg is not arg2]
728
706
  child_d_node = self._add(arg, d_node, prod)
729
707
  self._mk_derivative_r(child_d_node)
730
708
 
731
- def _add(self, node: CircuitNode, parent: _DNode, prod: List[CircuitNode]) -> _DNode:
732
- """
733
- Support for `_mk_derivative_r`.
734
-
735
- Add a _DNodeProduct(parent, negate, prod) to the DNode for the given circuit node.
736
-
737
- If the DNode for `node` does not yet exist, one will be created.
738
-
739
- The given circuit node may have multiple parents (i.e., a shared sub-expression). Therefore,
740
- this method may be called multiple times for a given node. Each time a new _DNodeProduct will be added.
741
-
742
- Args:
743
- node: the CircuitNode that the returned DNode is for.
744
- parent: the DNode of the parent node, i.e., `node` is an argument to the parent node.
745
- prod: other circuit nodes that need to be multiplied with the parent derivative when
746
- constructing a derivative for `node`.
709
+ cdef _DNode _add(self, CircuitNode node, _DNode parent, list[CircuitNode] prod):
710
+ # Support for `_mk_derivative_r`.
711
+ #
712
+ # Add a _DNodeProduct(parent, prod) to the DNode for the given circuit node.
713
+ #
714
+ # If the DNode for `node` does not yet exist, one will be created.
715
+ #
716
+ # The given circuit node may have multiple parents (i.e., a shared sub-expression). Therefore,
717
+ # this method may be called multiple times for a given node. Each time a new _DNodeProduct will be added.
718
+ #
719
+ # Args:
720
+ # node: the CircuitNode that the returned DNode is for.
721
+ # parent: the DNode of the parent node, i.e., `node` is an argument to the parent node.
722
+ # prod: other circuit nodes that need to be multiplied with the parent derivative when
723
+ # constructing a derivative for `node`.
724
+ #
725
+ # Returns:
726
+ # the DNode for `node`.
747
727
 
748
- Returns:
749
- the DNode for `node`.
750
- """
751
728
  child_d_node: _DNode = self._get(node)
752
729
  child_d_node.sum_prod.append(_DNodeProduct(parent, prod))
753
730
  return child_d_node
754
731
 
755
- def _get(self, node: CircuitNode) -> _DNode:
756
- """
757
- Helper for derivatives.
732
+ cdef _DNode _get(self, CircuitNode node):
733
+ # Helper for derivatives.
734
+ #
735
+ # Get the DNode for the given circuit node.
736
+ # If no DNode exists for it yet, then one will be constructed.
758
737
 
759
- Get the DNode for the given circuit node.
760
- If no DNode exist for it yet, then one will be constructed.
761
- """
762
738
  node_id: int = id(node)
763
739
  d_node: Optional[_DNode] = self.d_nodes.get(node_id)
764
740
  if d_node is None:
@@ -767,7 +743,7 @@ class _DerivativeHelper:
767
743
  return d_node
768
744
 
769
745
 
770
- cdef void _reachable_op_nodes_r(CircuitNode node, set[int] seen, list[OpNode] result):
746
+ cdef void find_reachable_op_nodes_r(CircuitNode node, set[int] seen, list[OpNode] result):
771
747
  # Recursive helper for `reachable_op_nodes`. Performs a depth-first search.
772
748
  #
773
749
  # Args:
@@ -777,11 +753,10 @@ cdef void _reachable_op_nodes_r(CircuitNode node, set[int] seen, list[OpNode] re
777
753
  if isinstance(node, OpNode) and id(node) not in seen:
778
754
  seen.add(id(node))
779
755
  for arg in node.args:
780
- _reachable_op_nodes_r(arg, seen, result)
756
+ find_reachable_op_nodes_r(arg, seen, result)
781
757
  result.append(node)
782
758
 
783
-
784
- cdef void _reachable_op_nodes_seen_r(CircuitNode node, set[int] seen):
759
+ cdef void find_reachable_op_nodes_seen_r(CircuitNode node, set[int] seen):
785
760
  # Recursive helper for `remove_unreachable_op_nodes`. Performs a depth-first search.
786
761
  #
787
762
  # Args:
@@ -790,4 +765,4 @@ cdef void _reachable_op_nodes_seen_r(CircuitNode node, set[int] seen):
790
765
  if isinstance(node, OpNode) and id(node) not in seen:
791
766
  seen.add(id(node))
792
767
  for arg in node.args:
793
- _reachable_op_nodes_seen_r(arg, seen)
768
+ find_reachable_op_nodes_seen_r(arg, seen)
ck/circuit/_circuit_py.py CHANGED
@@ -721,7 +721,7 @@ class _DerivativeHelper:
721
721
  d_node.sum_prod = None
722
722
 
723
723
  # Construct the addition operation
724
- d_node.derivative = self.circuit.optimised_add(to_add)
724
+ d_node.derivative = self.circuit.optimised_add(*to_add)
725
725
 
726
726
  return d_node.derivative
727
727
 
@@ -742,7 +742,7 @@ class _DerivativeHelper:
742
742
  to_mul.append(arg)
743
743
 
744
744
  # Construct the multiplication operation
745
- return self.circuit.optimised_mul(to_mul)
745
+ return self.circuit.optimised_mul(*to_mul)
746
746
 
747
747
  def _mk_derivative_r(self, d_node: _DNode) -> None:
748
748
  """