tensorcircuit-nightly 1.0.2.dev20250108__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorcircuit-nightly might be problematic. Click here for more details.

Files changed (76)
  1. tensorcircuit/__init__.py +18 -2
  2. tensorcircuit/about.py +46 -0
  3. tensorcircuit/abstractcircuit.py +4 -0
  4. tensorcircuit/analogcircuit.py +413 -0
  5. tensorcircuit/applications/layers.py +1 -1
  6. tensorcircuit/applications/van.py +1 -1
  7. tensorcircuit/backends/abstract_backend.py +320 -7
  8. tensorcircuit/backends/cupy_backend.py +3 -1
  9. tensorcircuit/backends/jax_backend.py +102 -4
  10. tensorcircuit/backends/jax_ops.py +110 -1
  11. tensorcircuit/backends/numpy_backend.py +49 -3
  12. tensorcircuit/backends/pytorch_backend.py +92 -3
  13. tensorcircuit/backends/tensorflow_backend.py +102 -3
  14. tensorcircuit/basecircuit.py +157 -98
  15. tensorcircuit/circuit.py +115 -57
  16. tensorcircuit/cloud/local.py +1 -1
  17. tensorcircuit/cloud/quafu_provider.py +1 -1
  18. tensorcircuit/cloud/tencent.py +1 -1
  19. tensorcircuit/compiler/simple_compiler.py +2 -2
  20. tensorcircuit/cons.py +142 -21
  21. tensorcircuit/densitymatrix.py +43 -14
  22. tensorcircuit/experimental.py +387 -129
  23. tensorcircuit/fgs.py +282 -81
  24. tensorcircuit/gates.py +66 -22
  25. tensorcircuit/interfaces/__init__.py +1 -3
  26. tensorcircuit/interfaces/jax.py +189 -0
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +868 -152
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +147 -20
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +479 -0
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +73 -23
  45. tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.0.2.dev20250108.dist-info/RECORD +0 -115
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1031
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -365
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -429
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -277
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -526
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -347
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_templates.py +0 -218
  74. tests/test_torchnn.py +0 -99
  75. tests/test_van.py +0 -102
  76. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info/licenses}/LICENSE +0 -0
tensorcircuit/gates.py CHANGED
@@ -34,6 +34,12 @@ one_state = np.array([0.0, 1.0], dtype=npdtype)
34
34
  plus_state = 1.0 / np.sqrt(2) * (zero_state + one_state)
35
35
  minus_state = 1.0 / np.sqrt(2) * (zero_state - one_state)
36
36
 
37
+ # Common elements as np.ndarray objects
38
+ _i00 = np.array([[1.0, 0.0], [0.0, 0.0]])
39
+ _i01 = np.array([[0.0, 1.0], [0.0, 0.0]])
40
+ _i10 = np.array([[0.0, 0.0], [1.0, 0.0]])
41
+ _i11 = np.array([[0.0, 0.0], [0.0, 1.0]])
42
+
37
43
  # Common single qubit gates as np.ndarray objects
38
44
  _h_matrix = 1 / np.sqrt(2) * np.array([[1.0, 1.0], [1.0, -1.0]])
39
45
  _i_matrix = np.array([[1.0, 0.0], [0.0, 1.0]])
@@ -229,7 +235,7 @@ def num_to_tensor(*num: Union[float, Tensor], dtype: Optional[str] = None) -> An
229
235
  # TODO(@YHPeter): fix __doc__ for same function with different names
230
236
 
231
237
  l = []
232
- if not dtype:
238
+ if dtype is None:
233
239
  dtype = dtypestr
234
240
  for n in num:
235
241
  if not backend.is_tensor(n):
@@ -245,7 +251,7 @@ array_to_tensor = num_to_tensor
245
251
 
246
252
 
247
253
  def gate_wrapper(m: Tensor, n: Optional[str] = None) -> Gate:
248
- if not n:
254
+ if n is None:
249
255
  n = "unknowngate"
250
256
  m = m.astype(npdtype)
251
257
  return Gate(deepcopy(m), name=n)
@@ -255,7 +261,7 @@ class GateF:
255
261
  def __init__(
256
262
  self, m: Tensor, n: Optional[str] = None, ctrl: Optional[List[int]] = None
257
263
  ):
258
- if not n:
264
+ if n is None:
259
265
  n = "unknowngate"
260
266
  self.m = m
261
267
  self.n = n
@@ -310,7 +316,7 @@ class GateF:
310
316
 
311
317
  return Gate(cu, name="c" + self.n)
312
318
 
313
- if not self.ctrl:
319
+ if self.ctrl is None:
314
320
  ctrl = [1]
315
321
  else:
316
322
  ctrl = [1] + self.ctrl
@@ -330,7 +336,7 @@ class GateF:
330
336
  # TODO(@refraction-ray): ctrl convention to be finally determined
331
337
  return Gate(ocu, name="o" + self.n)
332
338
 
333
- if not self.ctrl:
339
+ if self.ctrl is None:
334
340
  ctrl = [0]
335
341
  else:
336
342
  ctrl = [0] + self.ctrl
@@ -349,7 +355,7 @@ class GateVF(GateF):
349
355
  n: Optional[str] = None,
350
356
  ctrl: Optional[List[int]] = None,
351
357
  ):
352
- if not n:
358
+ if n is None:
353
359
  n = "unknowngate"
354
360
  self.f = f
355
361
  self.n = n
@@ -483,7 +489,7 @@ def phase_gate(theta: float = 0) -> Gate:
483
489
  :rtype: Gate
484
490
  """
485
491
  theta = array_to_tensor(theta)
486
- i00, i11 = array_to_tensor(np.array([[1, 0], [0, 0]]), np.array([[0, 0], [0, 1]]))
492
+ i00, i11 = array_to_tensor(_i00, _i11)
487
493
  unitary = i00 + backend.exp(1.0j * theta) * i11
488
494
  return Gate(unitary)
489
495
 
@@ -512,7 +518,7 @@ def get_u_parameter(m: Tensor) -> Tuple[float, float, float]:
512
518
  return theta, phi, lbd
513
519
 
514
520
 
515
- def u_gate(theta: float = 0, phi: float = 0, lbd: float = 0) -> Gate:
521
+ def u_gate(theta: float = 0.0, phi: float = 0.0, lbd: float = 0.0) -> Gate:
516
522
  r"""
517
523
  IBMQ U gate following the converntion of OpenQASM3.0.
518
524
  See `OpenQASM doc <https://openqasm.com/language/gates.html#built-in-gates>`_
@@ -533,12 +539,7 @@ def u_gate(theta: float = 0, phi: float = 0, lbd: float = 0) -> Gate:
533
539
  :rtype: Gate
534
540
  """
535
541
  theta, phi, lbd = array_to_tensor(theta, phi, lbd)
536
- i00, i01, i10, i11 = array_to_tensor(
537
- np.array([[1, 0], [0, 0]]),
538
- np.array([[0, 1], [0, 0]]),
539
- np.array([[0, 0], [1, 0]]),
540
- np.array([[0, 0], [0, 1]]),
541
- )
542
+ i00, i01, i10, i11 = array_to_tensor(_i00, _i01, _i10, _i11)
542
543
  unitary = (
543
544
  backend.cos(theta / 2) * i00
544
545
  - backend.exp(1.0j * lbd) * backend.sin(theta / 2) * i01
@@ -548,7 +549,7 @@ def u_gate(theta: float = 0, phi: float = 0, lbd: float = 0) -> Gate:
548
549
  return Gate(unitary)
549
550
 
550
551
 
551
- def r_gate(theta: float = 0, alpha: float = 0, phi: float = 0) -> Gate:
552
+ def r_gate(theta: float = 0.0, alpha: float = 0.0, phi: float = 0.0) -> Gate:
552
553
  r"""
553
554
  General single qubit rotation gate
554
555
 
@@ -582,7 +583,7 @@ def r_gate(theta: float = 0, alpha: float = 0, phi: float = 0) -> Gate:
582
583
  # r = r_gate
583
584
 
584
585
 
585
- def rx_gate(theta: float = 0) -> Gate:
586
+ def rx_gate(theta: float = 0.0) -> Gate:
586
587
  r"""
587
588
  Rotation gate along :math:`x` axis.
588
589
 
@@ -603,7 +604,7 @@ def rx_gate(theta: float = 0) -> Gate:
603
604
  # rx = rx_gate
604
605
 
605
606
 
606
- def ry_gate(theta: float = 0) -> Gate:
607
+ def ry_gate(theta: float = 0.0) -> Gate:
607
608
  r"""
608
609
  Rotation gate along :math:`y` axis.
609
610
 
@@ -624,7 +625,7 @@ def ry_gate(theta: float = 0) -> Gate:
624
625
  # ry = ry_gate
625
626
 
626
627
 
627
- def rz_gate(theta: float = 0) -> Gate:
628
+ def rz_gate(theta: float = 0.0) -> Gate:
628
629
  r"""
629
630
  Rotation gate along :math:`z` axis.
630
631
 
@@ -645,7 +646,7 @@ def rz_gate(theta: float = 0) -> Gate:
645
646
  # rz = rz_gate
646
647
 
647
648
 
648
- def rgate_theoretical(theta: float = 0, alpha: float = 0, phi: float = 0) -> Gate:
649
+ def rgate_theoretical(theta: float = 0.0, alpha: float = 0.0, phi: float = 0.0) -> Gate:
649
650
  r"""
650
651
  Rotation gate implemented by matrix exponential. The output is the same as `rgate`.
651
652
 
@@ -723,7 +724,7 @@ def iswap_gate(theta: float = 1.0) -> Gate:
723
724
  # iswap = iswap_gate
724
725
 
725
726
 
726
- def cr_gate(theta: float = 0, alpha: float = 0, phi: float = 0) -> Gate:
727
+ def cr_gate(theta: float = 0.0, alpha: float = 0.0, phi: float = 0.0) -> Gate:
727
728
  r"""
728
729
  Controlled rotation gate. When the control qubit is 1, `rgate` is applied to the target qubit.
729
730
 
@@ -775,7 +776,7 @@ def random_two_qubit_gate() -> Gate:
775
776
  return Gate(deepcopy(unitary), name="R2Q")
776
777
 
777
778
 
778
- def any_gate(unitary: Tensor, name: str = "any") -> Gate:
779
+ def any_gate(unitary: Tensor, name: str = "any", dim: Optional[int] = None) -> Gate:
779
780
  """
780
781
  Note one should provide the gate with properly reshaped.
781
782
 
@@ -783,6 +784,8 @@ def any_gate(unitary: Tensor, name: str = "any") -> Gate:
783
784
  :type unitary: Tensor
784
785
  :param name: The name of the gate.
785
786
  :type name: str
787
+ :param dim: The dimension of the gate.
788
+ :type dim: int
786
789
  :return: the resulted gate
787
790
  :rtype: Gate
788
791
  """
@@ -791,7 +794,10 @@ def any_gate(unitary: Tensor, name: str = "any") -> Gate:
791
794
  unitary.tensor = backend.cast(unitary.tensor, dtypestr)
792
795
  return unitary
793
796
  unitary = backend.cast(unitary, dtypestr)
794
- unitary = backend.reshape2(unitary)
797
+ if dim is None or dim == 2:
798
+ unitary = backend.reshape2(unitary)
799
+ else:
800
+ unitary = backend.reshaped(unitary, dim)
795
801
  # nleg = int(np.log2(backend.sizen(unitary)))
796
802
  # unitary = backend.reshape(unitary, [2 for _ in range(nleg)])
797
803
  return Gate(unitary, name=name)
@@ -864,6 +870,43 @@ def exponential_gate_unity(
864
870
  return Gate(mat, name="exp1-" + name)
865
871
 
866
872
 
873
+ def su4_gate(theta: Tensor, name: str = "su(4)") -> Gate:
874
+ r"""
875
+ Two-qubit general SU(4) gate.
876
+
877
+ :param theta: the angle tensor (15 components) of the gate.
878
+ :type theta: Tensor
879
+ :param name: the name of the gate.
880
+ :type name: str
881
+ :return: a gate object.
882
+ :rtype: Gate
883
+ """
884
+ theta = num_to_tensor(theta)
885
+ pauli_ops = array_to_tensor(
886
+ _ix_matrix,
887
+ _iy_matrix,
888
+ _iz_matrix,
889
+ _xi_matrix,
890
+ _xx_matrix,
891
+ _xy_matrix,
892
+ _xz_matrix,
893
+ _yi_matrix,
894
+ _yx_matrix,
895
+ _yy_matrix,
896
+ _yz_matrix,
897
+ _zi_matrix,
898
+ _zx_matrix,
899
+ _zy_matrix,
900
+ _zz_matrix,
901
+ )
902
+ generator = backend.sum(
903
+ backend.stack([theta[i] * pauli_ops[i] for i in range(15)]), axis=0
904
+ )
905
+ mat = backend.expm(-1j * generator)
906
+ mat = backend.reshape2(mat)
907
+ return Gate(mat, name=name)
908
+
909
+
867
910
  exp1_gate = exponential_gate_unity
868
911
  # exp1 = exponential_gate_unity
869
912
  rzz_gate = partial(exp1_gate, unitary=_zz_matrix, half=True)
@@ -968,6 +1011,7 @@ def meta_vgate() -> None:
968
1011
  "rzz",
969
1012
  "rxx",
970
1013
  "ryy",
1014
+ "su4",
971
1015
  ]:
972
1016
  for funcname in [f, f + "gate"]:
973
1017
  setattr(thismodule, funcname, GateVF(getattr(thismodule, f + "_gate"), f))
tensorcircuit/interfaces/__init__.py CHANGED
@@ -14,6 +14,4 @@ from .numpy import numpy_interface, np_interface
14
14
  from .scipy import scipy_interface, scipy_optimize_interface
15
15
  from .torch import torch_interface, pytorch_interface, torch_interface_kws
16
16
  from .tensorflow import tensorflow_interface, tf_interface
17
-
18
-
19
- # TODO(@refraction-ray): jax interface using puer_callback and custom_vjp
17
+ from .jax import jax_interface
tensorcircuit/interfaces/jax.py ADDED
@@ -0,0 +1,189 @@
1
+ """
2
+ Interface wraps quantum function as a jax function
3
+ """
4
+
5
+ from typing import Any, Callable, Tuple, Optional, Union, Sequence
6
+ from functools import wraps, partial
7
+
8
+ from ..cons import backend
9
+ from .tensortrans import general_args_to_backend
10
+
11
+ Tensor = Any
12
+
13
+
14
+ def jax_wrapper(
15
+ fun: Callable[..., Any],
16
+ enable_dlpack: bool = False,
17
+ output_shape: Optional[
18
+ Union[Tuple[int, ...], Tuple[int, ...], Sequence[Tuple[int, ...]]]
19
+ ] = None,
20
+ output_dtype: Optional[Union[Any, Sequence[Any]]] = None,
21
+ ) -> Callable[..., Any]:
22
+ import jax
23
+
24
+ @wraps(fun)
25
+ def fun_jax(*x: Any) -> Any:
26
+ def wrapped_fun(*args: Any) -> Any:
27
+ args = general_args_to_backend(args, enable_dlpack=enable_dlpack)
28
+ y = fun(*args)
29
+ y = general_args_to_backend(
30
+ y, target_backend="jax", enable_dlpack=enable_dlpack
31
+ )
32
+ return y
33
+
34
+ # Use provided shape and dtype if available, otherwise run test
35
+ if output_shape is not None and output_dtype is not None:
36
+ if isinstance(output_shape, Sequence) and not isinstance(
37
+ output_shape[0], int
38
+ ):
39
+ # Multiple outputs case
40
+ out_shape = tuple(
41
+ jax.ShapeDtypeStruct(s, d)
42
+ for s, d in zip(output_shape, output_dtype)
43
+ )
44
+ else:
45
+ # Single output case
46
+ out_shape = jax.ShapeDtypeStruct(output_shape, output_dtype) # type: ignore
47
+ else:
48
+ # Get expected output shape by running function once
49
+ test_out = wrapped_fun(*x)
50
+ if isinstance(test_out, tuple):
51
+ # Multiple outputs case
52
+ out_shape = tuple(
53
+ jax.ShapeDtypeStruct(
54
+ t.shape if hasattr(t, "shape") else (),
55
+ t.dtype if hasattr(t, "dtype") else x[0].dtype,
56
+ )
57
+ for t in test_out
58
+ )
59
+ else:
60
+ # Single output case
61
+ out_shape = jax.ShapeDtypeStruct( # type: ignore
62
+ test_out.shape if hasattr(test_out, "shape") else (),
63
+ test_out.dtype if hasattr(test_out, "dtype") else x[0].dtype,
64
+ )
65
+
66
+ # Use pure_callback with correct output shape
67
+ result = jax.pure_callback(wrapped_fun, out_shape, *x)
68
+ return result
69
+
70
+ return fun_jax
71
+
72
+
73
+ def jax_interface(
74
+ fun: Callable[..., Any],
75
+ jit: bool = False,
76
+ enable_dlpack: bool = False,
77
+ output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]] = None,
78
+ output_dtype: Optional[Any] = None,
79
+ ) -> Callable[..., Any]:
80
+ """
81
+ Wrap a function on different ML backend with a jax interface.
82
+
83
+ :Example:
84
+
85
+ .. code-block:: python
86
+
87
+ tc.set_backend("tensorflow")
88
+
89
+ def f(params):
90
+ c = tc.Circuit(1)
91
+ c.rx(0, theta=params[0])
92
+ c.ry(0, theta=params[1])
93
+ return tc.backend.real(c.expectation([tc.gates.z(), [0]]))
94
+
95
+ f = tc.interfaces.jax_interface(f, jit=True)
96
+
97
+ params = jnp.ones(2)
98
+ value, grad = jax.value_and_grad(f)(params)
99
+
100
+ :param fun: The quantum function with tensor in and tensor out
101
+ :type fun: Callable[..., Any]
102
+ :param jit: whether to jit ``fun``, defaults to False
103
+ :type jit: bool, optional
104
+ :param enable_dlpack: whether transform tensor backend via dlpack, defaults to False
105
+ :type enable_dlpack: bool, optional
106
+ :param output_shape: Optional shape of the function output, defaults to None
107
+ :type output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]], optional
108
+ :param output_dtype: Optional dtype of the function output, defaults to None
109
+ :type output_dtype: Optional[Any], optional
110
+ :return: The same quantum function but now with jax array in and jax array out
111
+ while AD is also supported
112
+ :rtype: Callable[..., Any]
113
+ """
114
+ jax_fun = create_jax_function(
115
+ fun,
116
+ enable_dlpack=enable_dlpack,
117
+ jit=jit,
118
+ output_shape=output_shape,
119
+ output_dtype=output_dtype,
120
+ )
121
+ return jax_fun
122
+
123
+
124
+ def create_jax_function(
125
+ fun: Callable[..., Any],
126
+ enable_dlpack: bool = False,
127
+ jit: bool = False,
128
+ output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]] = None,
129
+ output_dtype: Optional[Any] = None,
130
+ ) -> Callable[..., Any]:
131
+ import jax
132
+ from jax import custom_vjp
133
+
134
+ if jit:
135
+ fun = backend.jit(fun)
136
+
137
+ wrapped = jax_wrapper(
138
+ fun,
139
+ enable_dlpack=enable_dlpack,
140
+ output_shape=output_shape,
141
+ output_dtype=output_dtype,
142
+ )
143
+
144
+ @custom_vjp
145
+ def f(*x: Any) -> Any:
146
+ return wrapped(*x)
147
+
148
+ def f_fwd(*x: Any) -> Tuple[Any, Tuple[Any, ...]]:
149
+ y = wrapped(*x)
150
+ return y, x
151
+
152
+ def f_bwd(res: Tuple[Any, ...], g: Any) -> Tuple[Any, ...]:
153
+ x = res
154
+
155
+ if len(x) == 1:
156
+ x = x[0]
157
+
158
+ vjp_fun = partial(backend.vjp, fun)
159
+ if jit:
160
+ vjp_fun = backend.jit(vjp_fun) # type: ignore
161
+
162
+ def vjp_wrapped(args: Any) -> Any:
163
+ args = general_args_to_backend(args, enable_dlpack=enable_dlpack)
164
+ gb = general_args_to_backend(g, enable_dlpack=enable_dlpack)
165
+ r = vjp_fun(args, gb)[1]
166
+ r = general_args_to_backend(
167
+ r, target_backend="jax", enable_dlpack=enable_dlpack
168
+ )
169
+ return r
170
+
171
+ # Handle gradient shape for both single input and tuple inputs
172
+ if isinstance(x, tuple):
173
+ # Create a tuple of ShapeDtypeStruct for each input
174
+ grad_shape = tuple(jax.ShapeDtypeStruct(xi.shape, xi.dtype) for xi in x)
175
+ else:
176
+ grad_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
177
+
178
+ dx = jax.pure_callback(
179
+ vjp_wrapped,
180
+ grad_shape,
181
+ x,
182
+ )
183
+
184
+ if not isinstance(dx, tuple):
185
+ dx = (dx,)
186
+ return dx # type: ignore
187
+
188
+ f.defvjp(f_fwd, f_bwd)
189
+ return f
tensorcircuit/keras.py CHANGED
@@ -24,7 +24,7 @@ class QuantumLayer(Layer): # type: ignore
24
24
  initializer: Union[Text, Sequence[Text]] = "glorot_uniform",
25
25
  constraint: Optional[Union[Text, Sequence[Text]]] = None,
26
26
  regularizer: Optional[Union[Text, Sequence[Text]]] = None,
27
- **kwargs: Any
27
+ **kwargs: Any,
28
28
  ) -> None:
29
29
  """
30
30
  `QuantumLayer` wraps the quantum function `f` as a `keras.Layer`
@@ -103,7 +103,7 @@ class QuantumLayer(Layer): # type: ignore
103
103
  inputs: tf.Tensor,
104
104
  training: Optional[bool] = None,
105
105
  mask: Optional[tf.Tensor] = None,
106
- **kwargs: Any
106
+ **kwargs: Any,
107
107
  ) -> tf.Tensor:
108
108
  # input_shape = list(inputs.shape)
109
109
  # inputs = tf.reshape(inputs, (-1, input_shape[-1]))
@@ -154,7 +154,7 @@ class HardwareLayer(QuantumLayer):
154
154
  inputs: tf.Tensor,
155
155
  training: Optional[bool] = None,
156
156
  mask: Optional[tf.Tensor] = None,
157
- **kwargs: Any
157
+ **kwargs: Any,
158
158
  ) -> tf.Tensor:
159
159
  if inputs is None: # not possible
160
160
  result = self.f(*self.pqc_weights, **kwargs)