tensorcircuit-nightly 1.0.2.dev20250108__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry has flagged this version of tensorcircuit-nightly as potentially problematic.
- tensorcircuit/__init__.py +18 -2
- tensorcircuit/about.py +46 -0
- tensorcircuit/abstractcircuit.py +4 -0
- tensorcircuit/analogcircuit.py +413 -0
- tensorcircuit/applications/layers.py +1 -1
- tensorcircuit/applications/van.py +1 -1
- tensorcircuit/backends/abstract_backend.py +320 -7
- tensorcircuit/backends/cupy_backend.py +3 -1
- tensorcircuit/backends/jax_backend.py +102 -4
- tensorcircuit/backends/jax_ops.py +110 -1
- tensorcircuit/backends/numpy_backend.py +49 -3
- tensorcircuit/backends/pytorch_backend.py +92 -3
- tensorcircuit/backends/tensorflow_backend.py +102 -3
- tensorcircuit/basecircuit.py +157 -98
- tensorcircuit/circuit.py +115 -57
- tensorcircuit/cloud/local.py +1 -1
- tensorcircuit/cloud/quafu_provider.py +1 -1
- tensorcircuit/cloud/tencent.py +1 -1
- tensorcircuit/compiler/simple_compiler.py +2 -2
- tensorcircuit/cons.py +142 -21
- tensorcircuit/densitymatrix.py +43 -14
- tensorcircuit/experimental.py +387 -129
- tensorcircuit/fgs.py +282 -81
- tensorcircuit/gates.py +66 -22
- tensorcircuit/interfaces/__init__.py +1 -3
- tensorcircuit/interfaces/jax.py +189 -0
- tensorcircuit/keras.py +3 -3
- tensorcircuit/mpscircuit.py +154 -65
- tensorcircuit/quantum.py +868 -152
- tensorcircuit/quditcircuit.py +733 -0
- tensorcircuit/quditgates.py +618 -0
- tensorcircuit/results/counts.py +147 -20
- tensorcircuit/results/readout_mitigation.py +4 -1
- tensorcircuit/shadows.py +1 -1
- tensorcircuit/simplify.py +3 -1
- tensorcircuit/stabilizercircuit.py +479 -0
- tensorcircuit/templates/__init__.py +2 -0
- tensorcircuit/templates/blocks.py +2 -2
- tensorcircuit/templates/hamiltonians.py +174 -0
- tensorcircuit/templates/lattice.py +1789 -0
- tensorcircuit/timeevol.py +896 -0
- tensorcircuit/translation.py +10 -3
- tensorcircuit/utils.py +7 -0
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +73 -23
- tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +1 -1
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
- tensorcircuit_nightly-1.0.2.dev20250108.dist-info/RECORD +0 -115
- tests/__init__.py +0 -0
- tests/conftest.py +0 -67
- tests/test_backends.py +0 -1031
- tests/test_calibrating.py +0 -149
- tests/test_channels.py +0 -365
- tests/test_circuit.py +0 -1699
- tests/test_cloud.py +0 -219
- tests/test_compiler.py +0 -147
- tests/test_dmcircuit.py +0 -555
- tests/test_ensemble.py +0 -72
- tests/test_fgs.py +0 -310
- tests/test_gates.py +0 -156
- tests/test_interfaces.py +0 -429
- tests/test_keras.py +0 -160
- tests/test_miscs.py +0 -277
- tests/test_mpscircuit.py +0 -341
- tests/test_noisemodel.py +0 -156
- tests/test_qaoa.py +0 -86
- tests/test_qem.py +0 -152
- tests/test_quantum.py +0 -526
- tests/test_quantum_attr.py +0 -42
- tests/test_results.py +0 -347
- tests/test_shadows.py +0 -160
- tests/test_simplify.py +0 -46
- tests/test_templates.py +0 -218
- tests/test_torchnn.py +0 -99
- tests/test_van.py +0 -102
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info/licenses}/LICENSE +0 -0
tensorcircuit/gates.py
CHANGED
@@ -34,6 +34,12 @@ one_state = np.array([0.0, 1.0], dtype=npdtype)
 plus_state = 1.0 / np.sqrt(2) * (zero_state + one_state)
 minus_state = 1.0 / np.sqrt(2) * (zero_state - one_state)
 
+# Common elements as np.ndarray objects
+_i00 = np.array([[1.0, 0.0], [0.0, 0.0]])
+_i01 = np.array([[0.0, 1.0], [0.0, 0.0]])
+_i10 = np.array([[0.0, 0.0], [1.0, 0.0]])
+_i11 = np.array([[0.0, 0.0], [0.0, 1.0]])
+
 # Common single qubit gates as np.ndarray objects
 _h_matrix = 1 / np.sqrt(2) * np.array([[1.0, 1.0], [1.0, -1.0]])
 _i_matrix = np.array([[1.0, 0.0], [0.0, 1.0]])
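Defining these matrix units once at module level lets the parameterized gates below be written as linear combinations whose coefficients stay inside the backend's autodiff graph. A minimal sketch of the pattern in plain NumPy (the `phase_matrix` helper is illustrative, not part of the package):

    import numpy as np

    _i00 = np.array([[1.0, 0.0], [0.0, 0.0]])
    _i11 = np.array([[0.0, 0.0], [0.0, 1.0]])

    def phase_matrix(theta: float) -> np.ndarray:
        # same linear combination the diff uses for phase_gate:
        # i00 + exp(1j * theta) * i11
        return _i00 + np.exp(1.0j * theta) * _i11

    assert np.allclose(phase_matrix(0.0), np.eye(2))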
@@ -229,7 +235,7 @@ def num_to_tensor(*num: Union[float, Tensor], dtype: Optional[str] = None) -> Any:
     # TODO(@YHPeter): fix __doc__ for same function with different names
 
     l = []
-    if not dtype:
+    if dtype is None:
         dtype = dtypestr
     for n in num:
         if not backend.is_tensor(n):
@@ -245,7 +251,7 @@ array_to_tensor = num_to_tensor
 
 
 def gate_wrapper(m: Tensor, n: Optional[str] = None) -> Gate:
-    if not n:
+    if n is None:
         n = "unknowngate"
     m = m.astype(npdtype)
     return Gate(deepcopy(m), name=n)
@@ -255,7 +261,7 @@ class GateF:
     def __init__(
         self, m: Tensor, n: Optional[str] = None, ctrl: Optional[List[int]] = None
     ):
-        if not n:
+        if n is None:
             n = "unknowngate"
         self.m = m
         self.n = n
@@ -310,7 +316,7 @@ class GateF:
 
             return Gate(cu, name="c" + self.n)
 
-        if not self.ctrl:
+        if self.ctrl is None:
             ctrl = [1]
         else:
             ctrl = [1] + self.ctrl
@@ -330,7 +336,7 @@ class GateF:
             # TODO(@refraction-ray): ctrl convention to be finally determined
             return Gate(ocu, name="o" + self.n)
 
-        if not self.ctrl:
+        if self.ctrl is None:
             ctrl = [0]
         else:
             ctrl = [0] + self.ctrl
@@ -349,7 +355,7 @@ class GateVF(GateF):
         n: Optional[str] = None,
         ctrl: Optional[List[int]] = None,
     ):
-        if not n:
+        if n is None:
             n = "unknowngate"
         self.f = f
         self.n = n
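The repeated `if not x:` to `if x is None:` changes in this file are more than style: truthiness conflates `None` with other falsy values such as an empty list or empty string, so an explicitly supplied empty value could be silently swapped for the default. A two-line illustration:

    ctrl = []            # an explicitly empty control list
    print(not ctrl)      # True  -> the old check takes the default branch
    print(ctrl is None)  # False -> the new check keeps the user-supplied value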
@@ -483,7 +489,7 @@ def phase_gate(theta: float = 0) -> Gate:
     :rtype: Gate
     """
     theta = array_to_tensor(theta)
-    i00, i11 = array_to_tensor(np.array([[1, 0], [0, 0]]), np.array([[0, 0], [0, 1]]))
+    i00, i11 = array_to_tensor(_i00, _i11)
     unitary = i00 + backend.exp(1.0j * theta) * i11
     return Gate(unitary)
 
@@ -512,7 +518,7 @@ def get_u_parameter(m: Tensor) -> Tuple[float, float, float]:
     return theta, phi, lbd
 
 
-def u_gate(theta: float = 0, phi: float = 0, lbd: float = 0) -> Gate:
+def u_gate(theta: float = 0.0, phi: float = 0.0, lbd: float = 0.0) -> Gate:
     r"""
     IBMQ U gate following the converntion of OpenQASM3.0.
     See `OpenQASM doc <https://openqasm.com/language/gates.html#built-in-gates>`_
@@ -533,12 +539,7 @@ def u_gate(theta: float = 0, phi: float = 0, lbd: float = 0) -> Gate:
     :rtype: Gate
     """
     theta, phi, lbd = array_to_tensor(theta, phi, lbd)
-    i00, i01, i10, i11 = array_to_tensor(
-        np.array([[1, 0], [0, 0]]),
-        np.array([[0, 1], [0, 0]]),
-        np.array([[0, 0], [1, 0]]),
-        np.array([[0, 0], [0, 1]]),
-    )
+    i00, i01, i10, i11 = array_to_tensor(_i00, _i01, _i10, _i11)
     unitary = (
         backend.cos(theta / 2) * i00
         - backend.exp(1.0j * lbd) * backend.sin(theta / 2) * i01
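`u_gate` only swaps its inline arrays for the module-level matrix units, so its output should be numerically unchanged. A quick check against the OpenQASM 3.0 U matrix (a sketch, assuming the numpy backend and the default complex dtype):

    import numpy as np
    import tensorcircuit as tc

    theta, phi, lbd = 0.3, 0.5, 0.7
    m = tc.gates.u_gate(theta=theta, phi=phi, lbd=lbd).tensor
    ref = np.array(
        [
            [np.cos(theta / 2), -np.exp(1j * lbd) * np.sin(theta / 2)],
            [
                np.exp(1j * phi) * np.sin(theta / 2),
                np.exp(1j * (phi + lbd)) * np.cos(theta / 2),
            ],
        ]
    )
    np.testing.assert_allclose(m, ref, atol=1e-5)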
@@ -548,7 +549,7 @@ def u_gate(theta: float = 0, phi: float = 0, lbd: float = 0) -> Gate:
     return Gate(unitary)
 
 
-def r_gate(theta: float = 0, alpha: float = 0, phi: float = 0) -> Gate:
+def r_gate(theta: float = 0.0, alpha: float = 0.0, phi: float = 0.0) -> Gate:
     r"""
     General single qubit rotation gate
 
@@ -582,7 +583,7 @@ def r_gate(theta: float = 0, alpha: float = 0, phi: float = 0) -> Gate:
 # r = r_gate
 
 
-def rx_gate(theta: float = 0) -> Gate:
+def rx_gate(theta: float = 0.0) -> Gate:
     r"""
     Rotation gate along :math:`x` axis.
 
@@ -603,7 +604,7 @@ def rx_gate(theta: float = 0) -> Gate:
 # rx = rx_gate
 
 
-def ry_gate(theta: float = 0) -> Gate:
+def ry_gate(theta: float = 0.0) -> Gate:
     r"""
     Rotation gate along :math:`y` axis.
 
@@ -624,7 +625,7 @@ def ry_gate(theta: float = 0) -> Gate:
 # ry = ry_gate
 
 
-def rz_gate(theta: float = 0) -> Gate:
+def rz_gate(theta: float = 0.0) -> Gate:
     r"""
     Rotation gate along :math:`z` axis.
 
@@ -645,7 +646,7 @@ def rz_gate(theta: float = 0) -> Gate:
 # rz = rz_gate
 
 
-def rgate_theoretical(theta: float = 0, alpha: float = 0, phi: float = 0) -> Gate:
+def rgate_theoretical(theta: float = 0.0, alpha: float = 0.0, phi: float = 0.0) -> Gate:
     r"""
     Rotation gate implemented by matrix exponential. The output is the same as `rgate`.
 
@@ -723,7 +724,7 @@ def iswap_gate(theta: float = 1.0) -> Gate:
 # iswap = iswap_gate
 
 
-def cr_gate(theta: float = 0, alpha: float = 0, phi: float = 0) -> Gate:
+def cr_gate(theta: float = 0.0, alpha: float = 0.0, phi: float = 0.0) -> Gate:
     r"""
     Controlled rotation gate. When the control qubit is 1, `rgate` is applied to the target qubit.
 
@@ -775,7 +776,7 @@ def random_two_qubit_gate() -> Gate:
     return Gate(deepcopy(unitary), name="R2Q")
 
 
-def any_gate(unitary: Tensor, name: str = "any") -> Gate:
+def any_gate(unitary: Tensor, name: str = "any", dim: Optional[int] = None) -> Gate:
     """
     Note one should provide the gate with properly reshaped.
 
@@ -783,6 +784,8 @@ def any_gate(unitary: Tensor, name: str = "any") -> Gate:
     :type unitary: Tensor
     :param name: The name of the gate.
     :type name: str
+    :param dim: The dimension of the gate.
+    :type dim: int
     :return: the resulted gate
     :rtype: Gate
     """
@@ -791,7 +794,10 @@ def any_gate(unitary: Tensor, name: str = "any") -> Gate:
         unitary.tensor = backend.cast(unitary.tensor, dtypestr)
         return unitary
     unitary = backend.cast(unitary, dtypestr)
-    unitary = backend.reshape2(unitary)
+    if dim is None or dim == 2:
+        unitary = backend.reshape2(unitary)
+    else:
+        unitary = backend.reshaped(unitary, dim)
     # nleg = int(np.log2(backend.sizen(unitary)))
     # unitary = backend.reshape(unitary, [2 for _ in range(nleg)])
     return Gate(unitary, name=name)
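The new `dim` argument is what lets `any_gate` serve the qudit circuits added in this release (`quditcircuit.py`, `quditgates.py`): a d**n x d**n matrix is split into 2n legs of dimension d via `backend.reshaped` instead of the qubit-only `reshape2`. A hedged sketch of the intended call:

    import numpy as np
    import tensorcircuit as tc

    u = np.eye(9)  # a two-qutrit unitary given as a 9 x 9 matrix
    g = tc.gates.any_gate(u, name="two_qutrit_id", dim=3)
    print(tc.backend.shape_tuple(g.tensor))  # expected: (3, 3, 3, 3)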
@@ -864,6 +870,43 @@ def exponential_gate_unity(
     return Gate(mat, name="exp1-" + name)
 
 
+def su4_gate(theta: Tensor, name: str = "su(4)") -> Gate:
+    r"""
+    Two-qubit general SU(4) gate.
+
+    :param theta: the angle tensor (15 components) of the gate.
+    :type theta: Tensor
+    :param name: the name of the gate.
+    :type name: str
+    :return: a gate object.
+    :rtype: Gate
+    """
+    theta = num_to_tensor(theta)
+    pauli_ops = array_to_tensor(
+        _ix_matrix,
+        _iy_matrix,
+        _iz_matrix,
+        _xi_matrix,
+        _xx_matrix,
+        _xy_matrix,
+        _xz_matrix,
+        _yi_matrix,
+        _yx_matrix,
+        _yy_matrix,
+        _yz_matrix,
+        _zi_matrix,
+        _zx_matrix,
+        _zy_matrix,
+        _zz_matrix,
+    )
+    generator = backend.sum(
+        backend.stack([theta[i] * pauli_ops[i] for i in range(15)]), axis=0
+    )
+    mat = backend.expm(-1j * generator)
+    mat = backend.reshape2(mat)
+    return Gate(mat, name=name)
+
+
 exp1_gate = exponential_gate_unity
 # exp1 = exponential_gate_unity
 rzz_gate = partial(exp1_gate, unitary=_zz_matrix, half=True)
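`su4_gate` exponentiates -1j times a real linear combination of the 15 two-qubit Pauli words, so the result is unitary for any real `theta`. A quick sanity check (a sketch, assuming the numpy backend):

    import numpy as np
    import tensorcircuit as tc

    theta = np.random.uniform(size=15)  # coefficients for the 15 su(4) generators
    g = tc.gates.su4_gate(theta)
    m = np.reshape(g.tensor, (4, 4))  # undo reshape2 back to matrix form
    np.testing.assert_allclose(m @ m.conj().T, np.eye(4), atol=1e-5)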
@@ -968,6 +1011,7 @@ def meta_vgate() -> None:
         "rzz",
         "rxx",
         "ryy",
+        "su4",
     ]:
         for funcname in [f, f + "gate"]:
             setattr(thismodule, funcname, GateVF(getattr(thismodule, f + "_gate"), f))
tensorcircuit/interfaces/__init__.py
CHANGED
@@ -14,6 +14,4 @@ from .numpy import numpy_interface, np_interface
 from .scipy import scipy_interface, scipy_optimize_interface
 from .torch import torch_interface, pytorch_interface, torch_interface_kws
 from .tensorflow import tensorflow_interface, tf_interface
-
-
-# TODO(@refraction-ray): jax interface using puer_callback and custom_vjp
+from .jax import jax_interface
tensorcircuit/interfaces/jax.py
ADDED
@@ -0,0 +1,189 @@
+"""
+Interface wraps quantum function as a jax function
+"""
+
+from typing import Any, Callable, Tuple, Optional, Union, Sequence
+from functools import wraps, partial
+
+from ..cons import backend
+from .tensortrans import general_args_to_backend
+
+Tensor = Any
+
+
+def jax_wrapper(
+    fun: Callable[..., Any],
+    enable_dlpack: bool = False,
+    output_shape: Optional[
+        Union[Tuple[int, ...], Tuple[int, ...], Sequence[Tuple[int, ...]]]
+    ] = None,
+    output_dtype: Optional[Union[Any, Sequence[Any]]] = None,
+) -> Callable[..., Any]:
+    import jax
+
+    @wraps(fun)
+    def fun_jax(*x: Any) -> Any:
+        def wrapped_fun(*args: Any) -> Any:
+            args = general_args_to_backend(args, enable_dlpack=enable_dlpack)
+            y = fun(*args)
+            y = general_args_to_backend(
+                y, target_backend="jax", enable_dlpack=enable_dlpack
+            )
+            return y
+
+        # Use provided shape and dtype if available, otherwise run test
+        if output_shape is not None and output_dtype is not None:
+            if isinstance(output_shape, Sequence) and not isinstance(
+                output_shape[0], int
+            ):
+                # Multiple outputs case
+                out_shape = tuple(
+                    jax.ShapeDtypeStruct(s, d)
+                    for s, d in zip(output_shape, output_dtype)
+                )
+            else:
+                # Single output case
+                out_shape = jax.ShapeDtypeStruct(output_shape, output_dtype)  # type: ignore
+        else:
+            # Get expected output shape by running function once
+            test_out = wrapped_fun(*x)
+            if isinstance(test_out, tuple):
+                # Multiple outputs case
+                out_shape = tuple(
+                    jax.ShapeDtypeStruct(
+                        t.shape if hasattr(t, "shape") else (),
+                        t.dtype if hasattr(t, "dtype") else x[0].dtype,
+                    )
+                    for t in test_out
+                )
+            else:
+                # Single output case
+                out_shape = jax.ShapeDtypeStruct(  # type: ignore
+                    test_out.shape if hasattr(test_out, "shape") else (),
+                    test_out.dtype if hasattr(test_out, "dtype") else x[0].dtype,
+                )
+
+        # Use pure_callback with correct output shape
+        result = jax.pure_callback(wrapped_fun, out_shape, *x)
+        return result
+
+    return fun_jax
+
+
+def jax_interface(
+    fun: Callable[..., Any],
+    jit: bool = False,
+    enable_dlpack: bool = False,
+    output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]] = None,
+    output_dtype: Optional[Any] = None,
+) -> Callable[..., Any]:
+    """
+    Wrap a function on different ML backend with a jax interface.
+
+    :Example:
+
+    .. code-block:: python
+
+        tc.set_backend("tensorflow")
+
+        def f(params):
+            c = tc.Circuit(1)
+            c.rx(0, theta=params[0])
+            c.ry(0, theta=params[1])
+            return tc.backend.real(c.expectation([tc.gates.z(), [0]]))
+
+        f = tc.interfaces.jax_interface(f, jit=True)
+
+        params = jnp.ones(2)
+        value, grad = jax.value_and_grad(f)(params)
+
+    :param fun: The quantum function with tensor in and tensor out
+    :type fun: Callable[..., Any]
+    :param jit: whether to jit ``fun``, defaults to False
+    :type jit: bool, optional
+    :param enable_dlpack: whether transform tensor backend via dlpack, defaults to False
+    :type enable_dlpack: bool, optional
+    :param output_shape: Optional shape of the function output, defaults to None
+    :type output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]], optional
+    :param output_dtype: Optional dtype of the function output, defaults to None
+    :type output_dtype: Optional[Any], optional
+    :return: The same quantum function but now with jax array in and jax array out
+        while AD is also supported
+    :rtype: Callable[..., Any]
+    """
+    jax_fun = create_jax_function(
+        fun,
+        enable_dlpack=enable_dlpack,
+        jit=jit,
+        output_shape=output_shape,
+        output_dtype=output_dtype,
+    )
+    return jax_fun
+
+
+def create_jax_function(
+    fun: Callable[..., Any],
+    enable_dlpack: bool = False,
+    jit: bool = False,
+    output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]] = None,
+    output_dtype: Optional[Any] = None,
+) -> Callable[..., Any]:
+    import jax
+    from jax import custom_vjp
+
+    if jit:
+        fun = backend.jit(fun)
+
+    wrapped = jax_wrapper(
+        fun,
+        enable_dlpack=enable_dlpack,
+        output_shape=output_shape,
+        output_dtype=output_dtype,
+    )
+
+    @custom_vjp
+    def f(*x: Any) -> Any:
+        return wrapped(*x)
+
+    def f_fwd(*x: Any) -> Tuple[Any, Tuple[Any, ...]]:
+        y = wrapped(*x)
+        return y, x
+
+    def f_bwd(res: Tuple[Any, ...], g: Any) -> Tuple[Any, ...]:
+        x = res
+
+        if len(x) == 1:
+            x = x[0]
+
+        vjp_fun = partial(backend.vjp, fun)
+        if jit:
+            vjp_fun = backend.jit(vjp_fun)  # type: ignore
+
+        def vjp_wrapped(args: Any) -> Any:
+            args = general_args_to_backend(args, enable_dlpack=enable_dlpack)
+            gb = general_args_to_backend(g, enable_dlpack=enable_dlpack)
+            r = vjp_fun(args, gb)[1]
+            r = general_args_to_backend(
+                r, target_backend="jax", enable_dlpack=enable_dlpack
+            )
+            return r
+
+        # Handle gradient shape for both single input and tuple inputs
+        if isinstance(x, tuple):
+            # Create a tuple of ShapeDtypeStruct for each input
+            grad_shape = tuple(jax.ShapeDtypeStruct(xi.shape, xi.dtype) for xi in x)
+        else:
+            grad_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
+
+        dx = jax.pure_callback(
+            vjp_wrapped,
+            grad_shape,
+            x,
+        )
+
+        if not isinstance(dx, tuple):
+            dx = (dx,)
+        return dx  # type: ignore
+
+    f.defvjp(f_fwd, f_bwd)
+    return f
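Passing `output_shape` and `output_dtype` lets `jax_wrapper` build the `jax.ShapeDtypeStruct` directly and skip the probe evaluation of the wrapped function, which helps when that call is expensive or runs under `jax.jit` tracing. A sketch building on the docstring example (the `energy` function is illustrative; `output_shape=()` declares a scalar):

    import jax
    import jax.numpy as jnp
    import tensorcircuit as tc

    tc.set_backend("tensorflow")  # the wrapped function itself runs on TF

    def energy(params):
        c = tc.Circuit(2)
        c.rx(0, theta=params[0])
        c.cnot(0, 1)
        c.rz(1, theta=params[1])
        return tc.backend.real(c.expectation([tc.gates.z(), [1]]))

    # declaring the scalar float32 output avoids the warm-up call in jax_wrapper
    f = tc.interfaces.jax_interface(
        energy, jit=True, output_shape=(), output_dtype=jnp.float32
    )
    value, grad = jax.value_and_grad(f)(jnp.ones(2, dtype=jnp.float32))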
tensorcircuit/keras.py
CHANGED
@@ -24,7 +24,7 @@ class QuantumLayer(Layer): # type: ignore
         initializer: Union[Text, Sequence[Text]] = "glorot_uniform",
         constraint: Optional[Union[Text, Sequence[Text]]] = None,
         regularizer: Optional[Union[Text, Sequence[Text]]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> None:
         """
        `QuantumLayer` wraps the quantum function `f` as a `keras.Layer`
@@ -103,7 +103,7 @@ class QuantumLayer(Layer): # type: ignore
         inputs: tf.Tensor,
         training: Optional[bool] = None,
         mask: Optional[tf.Tensor] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> tf.Tensor:
         # input_shape = list(inputs.shape)
         # inputs = tf.reshape(inputs, (-1, input_shape[-1]))
@@ -154,7 +154,7 @@ class HardwareLayer(QuantumLayer):
         inputs: tf.Tensor,
         training: Optional[bool] = None,
         mask: Optional[tf.Tensor] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> tf.Tensor:
         if inputs is None: # not possible
             result = self.f(*self.pqc_weights, **kwargs)