tensorcircuit-nightly 1.2.1.dev20250723__py3-none-any.whl → 1.2.1.dev20250724__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensorcircuit/__init__.py +1 -1
- tensorcircuit/basecircuit.py +30 -14
- tensorcircuit/cons.py +46 -23
- tensorcircuit/experimental.py +346 -3
- tensorcircuit/quantum.py +2 -2
- {tensorcircuit_nightly-1.2.1.dev20250723.dist-info → tensorcircuit_nightly-1.2.1.dev20250724.dist-info}/METADATA +6 -6
- {tensorcircuit_nightly-1.2.1.dev20250723.dist-info → tensorcircuit_nightly-1.2.1.dev20250724.dist-info}/RECORD +12 -12
- tests/test_miscs.py +32 -0
- tests/test_stabilizer.py +2 -2
- {tensorcircuit_nightly-1.2.1.dev20250723.dist-info → tensorcircuit_nightly-1.2.1.dev20250724.dist-info}/WHEEL +0 -0
- {tensorcircuit_nightly-1.2.1.dev20250723.dist-info → tensorcircuit_nightly-1.2.1.dev20250724.dist-info}/licenses/LICENSE +0 -0
- {tensorcircuit_nightly-1.2.1.dev20250723.dist-info → tensorcircuit_nightly-1.2.1.dev20250724.dist-info}/top_level.txt +0 -0
tensorcircuit/__init__.py
CHANGED
tensorcircuit/basecircuit.py
CHANGED
@@ -441,27 +441,17 @@ class BaseCircuit(AbstractCircuit):
 
     measure = measure_jit
 
-    def amplitude(self, l: Union[str, Tensor]) -> Tensor:
+    def amplitude_before(self, l: Union[str, Tensor]) -> List[Gate]:
         r"""
-        Returns the amplitude of the circuit given the bitstring l.
+        Returns the tensornetwork nodes for the amplitude of the circuit given the bitstring l.
         For state simulator, it computes :math:`\langle l\vert \psi\rangle`,
         for density matrix simulator, it computes :math:`Tr(\rho \vert l\rangle \langle 1\vert)`
         Note how these two are different up to a square operation.
 
-        :Example:
-
-        >>> c = tc.Circuit(2)
-        >>> c.X(0)
-        >>> c.amplitude("10")
-        array(1.+0.j, dtype=complex64)
-        >>> c.CNOT(0, 1)
-        >>> c.amplitude("11")
-        array(1.+0.j, dtype=complex64)
-
         :param l: The bitstring of 0 and 1s.
         :type l: Union[str, Tensor]
-        :return: The amplitude of the circuit.
-        :rtype: tn.Node.tensor
+        :return: The tensornetwork nodes for the amplitude of the circuit.
+        :rtype: List[Gate]
         """
         no, d_edges = self._copy()
         ms = []
@@ -502,6 +492,32 @@ class BaseCircuit(AbstractCircuit):
         no.extend(ms)
         if self.is_dm:
             no.extend(msconj)
+        return no
+
+    def amplitude(self, l: Union[str, Tensor]) -> Tensor:
+        r"""
+        Returns the amplitude of the circuit given the bitstring l.
+        For state simulator, it computes :math:`\langle l\vert \psi\rangle`,
+        for density matrix simulator, it computes :math:`Tr(\rho \vert l\rangle \langle 1\vert)`
+        Note how these two are different up to a square operation.
+
+        :Example:
+
+        >>> c = tc.Circuit(2)
+        >>> c.X(0)
+        >>> c.amplitude("10")
+        array(1.+0.j, dtype=complex64)
+        >>> c.CNOT(0, 1)
+        >>> c.amplitude("11")
+        array(1.+0.j, dtype=complex64)
+
+        :param l: The bitstring of 0 and 1s.
+        :type l: Union[str, Tensor]
+        :return: The amplitude of the circuit.
+        :rtype: tn.Node.tensor
+        """
+        no = self.amplitude_before(l)
+
         return contractor(no).tensor
 
     def probability(self) -> Tensor:
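The net effect of the refactor above: the old `amplitude` is split into `amplitude_before`, which returns the uncontracted tensor network, and a thin `amplitude` wrapper that contracts it. A minimal sketch of how the two compose, assuming this nightly wheel is installed:

```python
# Sketch of the new split (assumes tensorcircuit-nightly >= 1.2.1.dev20250724).
import tensorcircuit as tc
from tensorcircuit.cons import contractor

c = tc.Circuit(2)
c.X(0)

nodes = c.amplitude_before("10")  # List[Gate]: the uncontracted network
amp = contractor(nodes).tensor    # what c.amplitude("10") now does internally
print(amp)                        # array(1.+0.j, dtype=complex64)
```

Exposing the pre-contraction nodes mirrors the existing `expectation_before` API and is presumably what lets the new `DistributedContractor` in experimental.py (below) slice and contract such networks itself.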
tensorcircuit/cons.py
CHANGED
@@ -23,6 +23,39 @@ from .simplify import _multi_remove
 
 logger = logging.getLogger(__name__)
 
+## monkey patch
+_NODE_CREATION_COUNTER = 0
+_original_node_init = tn.Node.__init__
+
+
+@wraps(_original_node_init)
+def _patched_node_init(self: Any, *args: Any, **kwargs: Any) -> None:
+    """Patched Node.__init__ to add a stable creation ID."""
+    global _NODE_CREATION_COUNTER
+    _original_node_init(self, *args, **kwargs)
+    self._stable_id_ = _NODE_CREATION_COUNTER
+    _NODE_CREATION_COUNTER += 1
+
+
+tn.Node.__init__ = _patched_node_init
+
+
+def _get_edge_stable_key(edge: tn.Edge) -> Tuple[int, int, int, int]:
+    n1, n2 = edge.node1, edge.node2
+    id1 = getattr(n1, "_stable_id_", -1)
+    id2 = getattr(n2, "_stable_id_", -1) if n2 is not None else -2  # -2 for dangling
+
+    if id1 > id2 or (id1 == id2 and edge.axis1 > edge.axis2):
+        id1, id2, ax1, ax2 = id2, id1, edge.axis2, edge.axis1
+    else:
+        ax1, ax2 = edge.axis1, edge.axis2
+    return (id1, ax1, id2, ax2)
+
+
+def sorted_edges(edges: Iterator[tn.Edge]) -> List[tn.Edge]:
+    return sorted(edges, key=_get_edge_stable_key)
+
+
 package_name = "tensorcircuit"
 thismodule = sys.modules[__name__]
 dtypestr = "complex64"
@@ -477,39 +510,29 @@ def _identity(*args: Any, **kws: Any) -> Any:
     return args
 
 
-def _sort_tuple_list(input_list: List[Any], output_list: List[Any]) -> List[Any]:
-    sorted_elements = [(tuple(sorted(t)), i) for i, t in enumerate(input_list)]
-    sorted_elements.sort()
-    return [output_list[i] for _, i in sorted_elements]
-
-
 def _get_path_cache_friendly(
     nodes: List[tn.Node], algorithm: Any
 ) -> Tuple[List[Tuple[int, int]], List[tn.Node]]:
     nodes = list(nodes)
-    mapping_dict = {}
-    i = 0
-    for n in nodes:
-        for e in n:
-            if id(e) not in mapping_dict:
-                mapping_dict[id(e)] = get_symbol(i)
-                i += 1
-    # TODO(@refraction-ray): may be not that cache friendly, since the edge id correspondence is not that fixed?
-    input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes]
-    # placeholder = [[1e20 for _ in range(100)]]
-    # order = np.argsort(np.array(list(map(sorted, input_sets)), dtype=object))  # type: ignore
-    # nodes_new = [nodes[i] for i in order]
-    nodes_new = _sort_tuple_list(input_sets, nodes)
+
+    nodes_new = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1))
     if isinstance(algorithm, list):
         return algorithm, nodes_new
 
+    all_edges = tn.get_all_edges(nodes_new)
+    all_edges_sorted = sorted_edges(all_edges)
+    mapping_dict = {}
+    i = 0
+    for edge in all_edges_sorted:
+        if id(edge) not in mapping_dict:
+            mapping_dict[id(edge)] = get_symbol(i)
+            i += 1
+
     input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes_new]
     output_set = list(
-        [mapping_dict[id(e)] for e in tn.get_subgraph_dangling(nodes_new)]
+        [mapping_dict[id(e)] for e in sorted_edges(tn.get_subgraph_dangling(nodes_new))]
    )
-    size_dict = {
-        mapping_dict[id(edge)]: edge.dimension for edge in tn.get_all_edges(nodes_new)
-    }
+    size_dict = {mapping_dict[id(edge)]: edge.dimension for edge in all_edges_sorted}
     logger.debug("input_sets: %s" % input_sets)
     logger.debug("output_set: %s" % output_set)
     logger.debug("size_dict: %s" % size_dict)
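The monkey patch and the rewrite of `_get_path_cache_friendly` above serve one goal: deterministic, cache-friendly contraction specs. Every `tn.Node` now receives a monotonically increasing `_stable_id_` at creation, and nodes and edges are sorted by it before einsum symbols are assigned, so structurally identical networks map to identical path-cache keys (replacing the removed `_sort_tuple_list` heuristic). A small sketch of the observable behavior, assuming this nightly wheel and tensornetwork are installed:

```python
# Importing tensorcircuit applies the Node.__init__ patch from cons.py,
# so every tensornetwork Node afterwards carries a creation-order ID.
import numpy as np
import tensorcircuit as tc  # noqa: F401  (imported for the patch side effect)
import tensornetwork as tn

a = tn.Node(np.ones([2, 2]))
b = tn.Node(np.ones([2, 2]))
a[1] ^ b[0]  # connect the two nodes

print(a._stable_id_ < b._stable_id_)  # True: IDs follow creation order
```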
tensorcircuit/experimental.py
CHANGED
@@ -3,16 +3,19 @@ Experimental features
 """
 
 from functools import partial
-from typing import Any, Callable, Dict, Optional, Tuple, List, Sequence
+import logging
+from typing import Any, Callable, Dict, Optional, Tuple, List, Sequence, Union
 
 import numpy as np
 
-from .cons import backend, dtypestr, contractor, rdtypestr
+from .cons import backend, dtypestr, contractor, rdtypestr, get_tn_info
 from .gates import Gate
 
 Tensor = Any
 Circuit = Any
 
+logger = logging.getLogger(__name__)
+
 
 def adaptive_vmap(
     f: Callable[..., Any],
@@ -504,7 +507,7 @@ def evol_local(
     h_fun: Callable[..., Tensor],
     t: float,
     *args: Any,
-    **solver_kws: Any
+    **solver_kws: Any,
 ) -> Circuit:
     """
     ode evolution of time dependent Hamiltonian on circuit of given indices
@@ -626,3 +629,343 @@ def jax_jitted_function_load(filename: str) -> Callable[..., Any]:
 
 
 jax_func_load = jax_jitted_function_load
+
+
+PADDING_VALUE = -1
+jaxlib: Any
+ctg: Any
+
+
+class DistributedContractor:
+    """
+    A distributed tensor network contractor that parallelizes computations across multiple devices.
+
+    This class uses cotengra to find optimal contraction paths and distributes the computational
+    load across multiple devices (e.g., GPUs) for efficient tensor network calculations.
+    Particularly useful for large-scale quantum circuit simulations and variational quantum algorithms.
+
+    Example:
+        >>> def nodes_fn(params):
+        ...     c = tc.Circuit(4)
+        ...     c.rx(0, theta=params[0])
+        ...     return c.expectation_before([tc.gates.z(), [0]], reuse=False)
+        >>> dc = DistributedContractor(nodes_fn, params)
+        >>> value, grad = dc.value_and_grad(params)
+
+    :param nodes_fn: Function that takes parameters and returns a list of tensor network nodes
+    :type nodes_fn: Callable[[Tensor], List[Gate]]
+    :param params: Initial parameters used to determine the tensor network structure
+    :type params: Tensor
+    :param cotengra_options: Configuration options passed to the cotengra optimizer. Defaults to None
+    :type cotengra_options: Optional[Dict[str, Any]], optional
+    :param devices: List of devices to use. If None, uses all available local devices
+    :type devices: Optional[List[Any]], optional
+    """
+
+    def __init__(
+        self,
+        nodes_fn: Callable[[Tensor], List[Gate]],
+        params: Tensor,
+        cotengra_options: Optional[Dict[str, Any]] = None,
+        devices: Optional[List[Any]] = None,
+    ) -> None:
+        global jaxlib
+        global ctg
+
+        logger.info("Initializing DistributedContractor...")
+        import cotengra as ctg
+        import jax as jaxlib
+
+        self.nodes_fn = nodes_fn
+        if devices is None:
+            self.num_devices = jaxlib.local_device_count()
+            self.devices = jaxlib.local_devices()
+            # TODO(@refraction-ray): multi host support
+        else:
+            self.devices = devices
+            self.num_devices = len(devices)
+
+        if self.num_devices <= 1:
+            logger.info("DistributedContractor is running on a single device.")
+
+        self._params_template = params
+        self._backend = "jax"
+        self._compiled_v_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
+        self._compiled_vg_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
+
+        logger.info("Running cotengra pathfinder... (This may take a while)")
+        nodes = self.nodes_fn(self._params_template)
+        tn_info, _ = get_tn_info(nodes)
+        default_cotengra_options = {
+            "slicing_reconf_opts": {"target_size": 2**28},
+            "max_repeats": 128,
+            "progbar": True,
+            "minimize": "write",
+            "parallel": "auto",
+        }
+        if cotengra_options:
+            default_cotengra_options = cotengra_options
+
+        opt = ctg.ReusableHyperOptimizer(**default_cotengra_options)
+        self.tree = opt.search(*tn_info)
+        actual_num_slices = self.tree.nslices
+
+        print("\n--- Contraction Path Info ---")
+        stats = self.tree.contract_stats()
+        print(f"Path found with {actual_num_slices} slices.")
+        print(
+            f"Arithmetic Intensity (higher is better): {self.tree.arithmetic_intensity():.2f}"
+        )
+        print("flops (TFlops):", stats["flops"] / 2**40 / self.num_devices)
+        print("write (GB):", stats["write"] / 2**27 / actual_num_slices)
+        print("size (GB):", stats["size"] / 2**27)
+        print("-----------------------------\n")
+
+        slices_per_device = int(np.ceil(actual_num_slices / self.num_devices))
+        padded_size = slices_per_device * self.num_devices
+        slice_indices = np.arange(actual_num_slices)
+        padded_slice_indices = np.full(padded_size, PADDING_VALUE, dtype=np.int32)
+        padded_slice_indices[:actual_num_slices] = slice_indices
+        self.batched_slice_indices = backend.convert_to_tensor(
+            padded_slice_indices.reshape(self.num_devices, slices_per_device)
+        )
+        print(
+            f"Distributing across {self.num_devices} devices. Each device will sequentially process "
+            f"up to {slices_per_device} slices."
+        )
+
+        self._compiled_vg_fn = None
+        self._compiled_v_fn = None
+
+        logger.info("Initialization complete.")
+
+    def _get_single_slice_contraction_fn(
+        self, op: Optional[Callable[[Tensor], Tensor]] = None
+    ) -> Callable[[Any, Tensor, int], Tensor]:
+        if op is None:
+            op = backend.sum
+
+        def single_slice_contraction(
+            tree: ctg.ContractionTree, params: Tensor, slice_idx: int
+        ) -> Tensor:
+            nodes = self.nodes_fn(params)
+            _, standardized_nodes = get_tn_info(nodes)
+            input_arrays = [node.tensor for node in standardized_nodes]
+            sliced_arrays = tree.slice_arrays(input_arrays, slice_idx)
+            result = tree.contract_core(sliced_arrays, backend=self._backend)
+            return op(result)
+
+        return single_slice_contraction
+
+    def _get_device_sum_vg_fn(
+        self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Callable[[Any, Tensor, Tensor], Tuple[Tensor, Tensor]]:
+        post_processing = lambda x: backend.real(backend.sum(x))
+        if op is None:
+            op = post_processing
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        # to ensure the output is real so that can be differentiated
+        single_slice_vg_fn = jaxlib.value_and_grad(base_fn, argnums=1)
+
+        if output_dtype is None:
+            output_dtype = rdtypestr
+
+        def device_sum_fn(
+            tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
+        ) -> Tuple[Tensor, Tensor]:
+            def scan_body(
+                carry: Tuple[Tensor, Tensor], slice_idx: Tensor
+            ) -> Tuple[Tuple[Tensor, Tensor], None]:
+                acc_value, acc_grads = carry
+
+                def compute_and_add() -> Tuple[Tensor, Tensor]:
+                    value_slice, grads_slice = single_slice_vg_fn(
+                        tree, params, slice_idx
+                    )
+                    new_value = acc_value + value_slice
+                    new_grads = jaxlib.tree_util.tree_map(
+                        jaxlib.numpy.add, acc_grads, grads_slice
+                    )
+                    return new_value, new_grads
+
+                def do_nothing() -> Tuple[Tensor, Tensor]:
+                    return acc_value, acc_grads
+
+                return (
+                    jaxlib.lax.cond(
+                        slice_idx == PADDING_VALUE, do_nothing, compute_and_add
+                    ),
+                    None,
+                )
+
+            initial_carry = (
+                backend.cast(backend.convert_to_tensor(0.0), dtype=output_dtype),
+                jaxlib.tree_util.tree_map(lambda x: jaxlib.numpy.zeros_like(x), params),
+            )
+            (final_value, final_grads), _ = jaxlib.lax.scan(
+                scan_body, initial_carry, slice_indices_for_device
+            )
+            return final_value, final_grads
+
+        return device_sum_fn
+
+    def _get_device_sum_v_fn(
+        self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        if output_dtype is None:
+            output_dtype = dtypestr
+
+        def device_sum_fn(
+            tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
+        ) -> Tensor:
+            def scan_body(
+                carry_value: Tensor, slice_idx: Tensor
+            ) -> Tuple[Tensor, None]:
+                def compute_and_add() -> Tensor:
+                    return carry_value + base_fn(tree, params, slice_idx)
+
+                return (
+                    jaxlib.lax.cond(
+                        slice_idx == PADDING_VALUE, lambda: carry_value, compute_and_add
+                    ),
+                    None,
+                )
+
+            initial_carry = backend.cast(
+                backend.convert_to_tensor(0.0), dtype=output_dtype
+            )
+            final_value, _ = jaxlib.lax.scan(
+                scan_body, initial_carry, slice_indices_for_device
+            )
+            return final_value
+
+        return device_sum_fn
+
+    def _get_or_compile_fn(
+        self,
+        cache: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ],
+        fn_getter: Callable[..., Any],
+        op: Optional[Callable[[Tensor], Tensor]],
+        output_dtype: Optional[str],
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        """
+        Gets a compiled pmap-ed function from cache or compiles and caches it.
+
+        The cache key is a tuple of (op, output_dtype). Caution on lambda function!
+
+        Returns:
+            The compiled, pmap-ed JAX function.
+        """
+        cache_key = (op, output_dtype)
+        if cache_key not in cache:
+            device_fn = fn_getter(op=op, output_dtype=output_dtype)
+            compiled_fn = jaxlib.pmap(
+                device_fn,
+                in_axes=(
+                    None,
+                    None,
+                    0,
+                ),  # tree: broadcast, params: broadcast, indices: map
+                static_broadcasted_argnums=(0,),  # arg 0 (tree) is a static argument
+                devices=self.devices,
+            )
+            cache[cache_key] = compiled_fn  # type: ignore
+        return cache[cache_key]  # type: ignore
+
+    def value_and_grad(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Calculates the value and gradient, compiling the pmap function if needed for the first call.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param aggregate: Whether to aggregate (sum) the results across devices, defaults to True
+        :type aggregate: bool, optional
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to `backend.real`)
+            op is a cache key, so dont directly pass lambda function for op
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `rdtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_vg_fn = self._get_or_compile_fn(
+            cache=self._compiled_vg_fns,
+            fn_getter=self._get_device_sum_vg_fn,
+            op=op,
+            output_dtype=output_dtype,
+        )
+
+        device_values, device_grads = compiled_vg_fn(
+            self.tree, params, self.batched_slice_indices
+        )
+
+        if aggregate:
+            total_value = backend.sum(device_values)
+            total_grad = jaxlib.tree_util.tree_map(
+                lambda x: backend.sum(x, axis=0), device_grads
+            )
+            return total_value, total_grad
+        return device_values, device_grads
+
+    def value(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        """
+        Calculates the value, compiling the pmap function for the first call.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param aggregate: Whether to aggregate (sum) the results across devices, defaults to True
+        :type aggregate: bool, optional
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to identity)
+            op is a cache key, so dont directly pass lambda function for op
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `dtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_v_fn = self._get_or_compile_fn(
+            cache=self._compiled_v_fns,
+            fn_getter=self._get_device_sum_v_fn,
+            op=op,
+            output_dtype=output_dtype,
+        )
+
+        device_values = compiled_v_fn(self.tree, params, self.batched_slice_indices)
+
+        if aggregate:
+            return backend.sum(device_values)
+        return device_values
+
+    def grad(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        _, grad = self.value_and_grad(
+            params, aggregate=aggregate, op=op, output_dtype=output_dtype
+        )
+        return grad
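The new test added in tests/test_miscs.py (further down) exercises this class end to end and doubles as a usage recipe; a condensed version, assuming the jax backend is active and `cotengra` is installed:

```python
# Condensed from the new test_distrubuted_contractor test below; assumes
# jax and cotengra are available (a single CPU device also works).
import numpy as np
import tensorcircuit as tc
from tensorcircuit import experimental

tc.set_backend("jax")

def nodes_fn(params):
    c = tc.Circuit(4)
    c.rx(range(4), theta=params["x"])
    c.cnot([0, 1, 2], [1, 2, 3])
    c.ry(range(4), theta=params["y"])
    return c.expectation_before([tc.gates.z(), [-1]], reuse=False)

params = {"x": np.ones([4]), "y": 0.3 * np.ones([4])}
dc = experimental.DistributedContractor(
    nodes_fn,
    params,
    {"slicing_reconf_opts": {"target_size": 2**3}, "max_repeats": 8,
     "minimize": "write", "parallel": False},
)
value, grad = dc.value_and_grad(params)  # summed over device shards by default
```

Note that `(op, output_dtype)` forms the compilation cache key, so the docstrings' warning applies: pass named functions rather than fresh lambdas for `op`, or every call recompiles.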
tensorcircuit/quantum.py
CHANGED
@@ -71,7 +71,7 @@ def _reachable(nodes: List[AbstractNode]) -> List[AbstractNode]:
             if n not in seen_nodes and n not in node_que[i + 1 :]:
                 node_que.append(n)
         i += 1
-    return seen_nodes
+    return sorted(seen_nodes, key=lambda node: getattr(node, "_stable_id_", -1))
 
 
 def reachable(
@@ -1164,7 +1164,7 @@ def tn2qop(tn_mpo: Any) -> QuOperator:
     nwires = len(tn_mpo)
     mpo = []
     for i in range(nwires):
-        mpo.append(Node(tn_mpo[i]))
+        mpo.append(Node(tn_mpo[i], name=f"mpo_{i}"))
 
     for i in range(nwires - 1):
         connect(mpo[i][1], mpo[i + 1][0])
{tensorcircuit_nightly-1.2.1.dev20250723.dist-info → tensorcircuit_nightly-1.2.1.dev20250724.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tensorcircuit-nightly
-Version: 1.2.1.dev20250723
+Version: 1.2.1.dev20250724
 Summary: nightly release for tensorcircuit
 Home-page: https://github.com/refraction-ray/tensorcircuit-dev
 Author: TensorCircuit Authors
@@ -60,17 +60,17 @@ Dynamic: summary
 
 <p align="center"> English | <a href="README_cn.md"> 简体中文 </a></p>
 
-TensorCircuit-NG is the next-generation open-source high-performance quantum software framework, built upon tensornetwork engines, supporting for automatic differentiation, just-in-time compiling, hardware acceleration, and
+TensorCircuit-NG is the next-generation open-source high-performance quantum software framework, built upon tensornetwork engines, supporting for automatic differentiation, just-in-time compiling, hardware acceleration, vectorized parallelism and distributed training, providing unified infrastructures and interfaces for quantum programming. It can compose quantum circuits, neural networks and tensor networks seamlessly with high simulation efficiency and flexibility.
 
-TensorCircuit-NG is built on top of modern machine learning frameworks: Jax, TensorFlow, and PyTorch. It is specifically suitable for large-scale simulations of quantum-classical hybrid paradigm and variational quantum algorithms in ideal, noisy, Clifford, approximate and
+TensorCircuit-NG is built on top of modern machine learning frameworks: Jax, TensorFlow, and PyTorch. It is specifically suitable for large-scale simulations of quantum-classical hybrid paradigm and variational quantum algorithms in ideal, noisy, Clifford, approximate, analog and fermionic cases. It also supports quantum hardware access and provides CPU/GPU/QPU hybrid deployment solutions.
 
-TensorCircuit-NG is the actively maintained official version and a [fully compatible](https://tensorcircuit-ng.readthedocs.io/en/latest/faq.html#what-is-the-relation-between-tensorcircuit-and-tensorcircuit-ng) successor to TensorCircuit with more new features (stabilizer circuit
+TensorCircuit-NG is the actively maintained official version and a [fully compatible](https://tensorcircuit-ng.readthedocs.io/en/latest/faq.html#what-is-the-relation-between-tensorcircuit-and-tensorcircuit-ng) successor to TensorCircuit with more new features (stabilizer circuit, multi-card distributed simulation, etc.) and bug fixes (support latest `numpy>2` and `qiskit>1`).
 
 ## Getting Started
 
 Please begin with [Quick Start](/docs/source/quickstart.rst) in the [full documentation](https://tensorcircuit-ng.readthedocs.io/).
 
-For more information on software usage, sota algorithm implementation and engineer paradigm demonstration, please refer to 80+ [example scripts](/examples) and 30+ [tutorial notebooks](https://tensorcircuit-ng.readthedocs.io/en/latest/#tutorials). API docstrings and test cases in [tests](/tests) are also informative. One can also refer to tensorcircuit-ng [deepwiki](https://deepwiki.com/tensorcircuit/tensorcircuit-ng)
+For more information on software usage, sota algorithm implementation and engineer paradigm demonstration, please refer to 80+ [example scripts](/examples) and 30+ [tutorial notebooks](https://tensorcircuit-ng.readthedocs.io/en/latest/#tutorials). API docstrings and test cases in [tests](/tests) are also informative. One can also refer to tensorcircuit-ng [deepwiki](https://deepwiki.com/tensorcircuit/tensorcircuit-ng) by LLM.
 
 For beginners, please refer to [quantum computing lectures with TC-NG](https://github.com/sxzgroup/qc_lecture) to learn both quantum computing basics and representative usage of TensorCircuit-NG.
 
@@ -170,7 +170,7 @@ The package is written in pure Python and can be obtained via pip as:
 pip install tensorcircuit-ng
 ```
 
-We recommend you install this package with tensorflow also installed as:
+We recommend you install this package with tensorflow or jax also installed as:
 
 ```python
 pip install "tensorcircuit-ng[tensorflow]"
{tensorcircuit_nightly-1.2.1.dev20250723.dist-info → tensorcircuit_nightly-1.2.1.dev20250724.dist-info}/RECORD
CHANGED
@@ -1,20 +1,20 @@
-tensorcircuit/__init__.py,sha256=
+tensorcircuit/__init__.py,sha256=crGf_yKgxwxVS0dPKaNTaY7srmWi_K0oqt9Ixlj6CZ0,2032
 tensorcircuit/about.py,sha256=DazTswU2nAwOmASTaDII3L04PVtaQ7oiWPty5YMI3Wk,5267
 tensorcircuit/abstractcircuit.py,sha256=0osacPqq7B1EJki-cI1aLYoVRmjFaG9q3XevWMs7SsA,44125
 tensorcircuit/asciiart.py,sha256=neY1OWFwtoW5cHPNwkQHgRPktDniQvdlP9QKHkk52fM,8236
-tensorcircuit/basecircuit.py,sha256=
+tensorcircuit/basecircuit.py,sha256=ipCg3J55sgkciUZ2qCZqpVqE00YIWRlACu509nktg3I,37203
 tensorcircuit/channels.py,sha256=CFQxWI-JmkIxexslCBdjp_RSxUbHs6eAJv4LvlXXXCY,28637
 tensorcircuit/circuit.py,sha256=jC1Bb9A06pt6XX7muC-Q72BR9HS6n0Ft6aMjOGcz9iM,36428
-tensorcircuit/cons.py,sha256=
+tensorcircuit/cons.py,sha256=0fE9UY02TNI3FyQWGyGCKuYpwkMlV-a1cMbTZveFYmk,31125
 tensorcircuit/densitymatrix.py,sha256=VqMBnWCxO5-OsOp6LOdc5RS2AzmB3U4-w40Vn_lqygo,14865
-tensorcircuit/experimental.py,sha256=
+tensorcircuit/experimental.py,sha256=RW97ncitCfO1QJLAUbKBvm2Tsc0hzKhqkC65ShA9-Q0,34456
 tensorcircuit/fgs.py,sha256=eIi38DnQBGxY4itxqzGVbi8cAjB3vCYAX87xcJVJmoo,40846
 tensorcircuit/gates.py,sha256=x-wA7adVpP7o0AQLt_xYUScFKj8tU_wUOV2mR1GyrPc,29322
 tensorcircuit/keras.py,sha256=5OF4dfhEeS8sRYglpqYtQsWPeqp7uK0i7-P-6RRJ7zQ,10126
 tensorcircuit/mps_base.py,sha256=UZ-v8vsr_rAsKrfun8prVgbXJ-qsdqKy2DZIHpq3sxo,15400
 tensorcircuit/mpscircuit.py,sha256=Jv4nsRyOhQxSHpDUJpb9OS6A5E3bTJoIHYGzwgs7NYU,34591
 tensorcircuit/noisemodel.py,sha256=vzxpoYEZbHVC4a6g7_Jk4dxsHi4wvhpRFwud8b616Qo,11878
-tensorcircuit/quantum.py,sha256=
+tensorcircuit/quantum.py,sha256=LNkIv5cJ2KG6puC18zTuXi-5cojW1Tnz-N-WjZ0Qu5Q,90217
 tensorcircuit/shadows.py,sha256=6XmWNubbuaxFNvZVWu-RXd0lN9Jkk-xwong_K8o8_KE,17014
 tensorcircuit/simplify.py,sha256=O11G3UYiVAc30GOfwXXmhLXwGZrQ8OVwLTMQMZp_XBc,9414
 tensorcircuit/stabilizercircuit.py,sha256=4gDeTgko04j4dwt7NdJvl8NhqmB8JH75nZjdbLU3Aw0,15178
@@ -84,7 +84,7 @@ tensorcircuit/templates/dataset.py,sha256=ldPvCUlwjHU_S98E2ISQp34KqJzJPpPHmDIKJ4
 tensorcircuit/templates/graphs.py,sha256=cPYrxjoem0xZ-Is9dZKAvEzWZL_FejfIRiCEOTA4qd4,3935
 tensorcircuit/templates/lattice.py,sha256=F35ebANk0DSmSHLR0-Q_hUbcznyCmZjb4fKmvCMywmA,58575
 tensorcircuit/templates/measurements.py,sha256=pzc5Aa9S416Ilg4aOY77Z6ZhUlYcXnAkQNQFTuHjFFs,10943
-tensorcircuit_nightly-1.2.1.
+tensorcircuit_nightly-1.2.1.dev20250724.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
 tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/conftest.py,sha256=J9nHlLE3Zspz1rMyzadEuBWhaS5I4Q9sq0lnWybcdIA,1457
 tests/test_backends.py,sha256=rClxb2gyAoGeXd_ZYVSAJ0zEvJ7z_2btAeFM_Iy_wwY,33925
@@ -100,7 +100,7 @@ tests/test_gates.py,sha256=rAIV2QFpFsA5bT1QivTSkhdarvwu5t0N3IOz4SEDrzg,4593
 tests/test_interfaces.py,sha256=iJPmes8S8HkA9_PGjsu4Ike-vCXYyS1EMgnNKKXDNaU,16938
 tests/test_keras.py,sha256=U453jukavmx0RMeTSDEgPzrNdHNEfK1CW0CqO3XCNKo,4841
 tests/test_lattice.py,sha256=_ptDVK3EhS-X5fCQWiP8sHk3azdyGFuwqg6KMkBTkDE,65789
-tests/test_miscs.py,sha256=
+tests/test_miscs.py,sha256=Wo2fZ-Co4-iPm7n3F9NTxnXuabWi_J6uvrOr0GIMqvY,9175
 tests/test_mpscircuit.py,sha256=mDXX8oQeFeHr_PdZvwqyDs_tVcVAqLmCERqlTAU7590,10552
 tests/test_noisemodel.py,sha256=UYoMtCjwDaB-CCn5kLosofz-qTMiY4KGAFBjVtqqLPE,5637
 tests/test_qaoa.py,sha256=hEcC_XVmKBGt9XgUGtbTO8eQQK4mjorgTIrfqZCeQls,2616
@@ -110,11 +110,11 @@ tests/test_quantum_attr.py,sha256=Zl6WbkbnTWVp6FL2rR21qBGsLoheoIEZXqWZKxfpDRs,12
 tests/test_results.py,sha256=8cQO0ShkBc4_pB-fi9s35WJbuZl5ex5y1oElSV-GlRo,11882
 tests/test_shadows.py,sha256=1T3kJesVJ5XfZrSncL80xdq-taGCSnTDF3eL15UlavY,5160
 tests/test_simplify.py,sha256=35tbOu1QANsPvY1buLwNhqPnMkBOsnBtHn82qaukmgI,1175
-tests/test_stabilizer.py,sha256=
+tests/test_stabilizer.py,sha256=HdVRbEshg02CaNsqni_nRYY7KL5vhRBp9k1KGzOSE9I,5252
 tests/test_templates.py,sha256=Xm9otFFaaBWG9TZpgJ-nNh9MBfRipTzFWL8fBOnie2k,7192
 tests/test_torchnn.py,sha256=CHLTfWkF7Ses5_XnGFN_uv_JddfgenFEFzaDtSH8XYU,2848
 tests/test_van.py,sha256=kAWz860ivlb5zAJuYpzuBe27qccT-Yf0jatf5uXtTo4,3163
-tensorcircuit_nightly-1.2.1.
-tensorcircuit_nightly-1.2.1.
-tensorcircuit_nightly-1.2.1.
-tensorcircuit_nightly-1.2.1.
+tensorcircuit_nightly-1.2.1.dev20250724.dist-info/METADATA,sha256=zl0gksspBZ2l7gAYmvRyfOwclQEjtFkFy2k9cC46Fr8,34831
+tensorcircuit_nightly-1.2.1.dev20250724.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tensorcircuit_nightly-1.2.1.dev20250724.dist-info/top_level.txt,sha256=O_Iqeh2x02lasEYMI9iyPNNNtMzcpg5qvwMOkZQ7n4A,20
+tensorcircuit_nightly-1.2.1.dev20250724.dist-info/RECORD,,
tests/test_miscs.py
CHANGED
@@ -280,3 +280,35 @@ def test_jax_function_load(jaxb, tmp_path):
         os.path.join(tmp_path, "temp.bin")
     )
     np.testing.assert_allclose(f_load(K.ones([3])), 0.5403, atol=1e-4)
+
+
+def test_distrubuted_contractor(jaxb):
+    def nodes_fn(params):
+        c = tc.Circuit(4)
+        c.rx(range(4), theta=params["x"])
+        c.cnot([0, 1, 2], [1, 2, 3])
+        c.ry(range(4), theta=params["y"])
+        return c.expectation_before([tc.gates.z(), [-1]], reuse=False)
+
+    params = {"x": np.ones([4]), "y": 0.3 * np.ones([4])}
+    dc = experimental.DistributedContractor(
+        nodes_fn,
+        params,
+        {
+            "slicing_reconf_opts": {"target_size": 2**3},
+            "max_repeats": 8,
+            "minimize": "write",
+            "parallel": False,
+        },
+    )
+    value, grad = dc.value_and_grad(params)
+    assert grad["y"].shape == (4,)
+
+    def baseline(params):
+        c = tc.Circuit(4)
+        c.rx(range(4), theta=params["x"])
+        c.cnot([0, 1, 2], [1, 2, 3])
+        c.ry(range(4), theta=params["y"])
+        return c.expectation_ps(z=[-1])
+
+    np.testing.assert_allclose(value, baseline(params), atol=1e-6)
tests/test_stabilizer.py
CHANGED
@@ -178,13 +178,13 @@ def test_circuit_inputs():
 
 def test_depolarize():
     r = []
-    for _ in range(
+    for _ in range(40):
         c = tc.StabilizerCircuit(2)
         c.h(0)
         c.depolarizing(0, 1, p=0.2)
         c.h(0)
         r.append(c.expectation_ps(z=[0]))
-    assert 4 < np.sum(r) <
+    assert 4 < np.sum(r) < 39
 
 
 def test_tableau_inputs():