qadence 1.11.1__py3-none-any.whl → 1.11.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qadence/backend.py +33 -10
- qadence/backends/agpsr_utils.py +96 -0
- qadence/backends/api.py +8 -1
- qadence/backends/horqrux/backend.py +24 -10
- qadence/backends/horqrux/config.py +17 -1
- qadence/backends/horqrux/convert_ops.py +20 -97
- qadence/backends/jax_utils.py +5 -2
- qadence/backends/{gpsr.py → parameter_shift_rules.py} +48 -30
- qadence/backends/pulser/backend.py +16 -9
- qadence/backends/pulser/config.py +18 -0
- qadence/backends/pyqtorch/backend.py +25 -11
- qadence/backends/pyqtorch/config.py +18 -0
- qadence/blocks/embedding.py +10 -1
- qadence/blocks/primitive.py +2 -3
- qadence/blocks/utils.py +33 -24
- qadence/engines/differentiable_backend.py +7 -1
- qadence/engines/jax/differentiable_backend.py +7 -1
- qadence/engines/torch/differentiable_backend.py +12 -9
- qadence/engines/torch/differentiable_expectation.py +12 -11
- qadence/extensions.py +0 -10
- qadence/ml_tools/__init__.py +2 -0
- qadence/ml_tools/callbacks/callbackmanager.py +4 -2
- qadence/ml_tools/constructors.py +264 -4
- qadence/ml_tools/qcnn_model.py +158 -0
- qadence/model.py +113 -8
- qadence/parameters.py +2 -0
- qadence/serialization.py +1 -1
- qadence/transpile/__init__.py +3 -2
- qadence/transpile/block.py +58 -5
- qadence/types.py +2 -4
- qadence/utils.py +39 -8
- {qadence-1.11.1.dist-info → qadence-1.11.3.dist-info}/METADATA +22 -11
- {qadence-1.11.1.dist-info → qadence-1.11.3.dist-info}/RECORD +35 -33
- qadence-1.11.3.dist-info/licenses/LICENSE +13 -0
- qadence-1.11.1.dist-info/licenses/LICENSE +0 -202
- {qadence-1.11.1.dist-info → qadence-1.11.3.dist-info}/WHEEL +0 -0
    
        qadence/backend.py
    CHANGED
    
    | @@ -25,7 +25,14 @@ from qadence.measurements import Measurements | |
| 25 25 | 
             
            from qadence.mitigations import Mitigations
         | 
| 26 26 | 
             
            from qadence.noise import NoiseHandler
         | 
| 27 27 | 
             
            from qadence.parameters import stringify
         | 
| 28 | 
            -
            from qadence.types import  | 
| 28 | 
            +
            from qadence.types import (
         | 
| 29 | 
            +
                ArrayLike,
         | 
| 30 | 
            +
                BackendName,
         | 
| 31 | 
            +
                DiffMode,
         | 
| 32 | 
            +
                Endianness,
         | 
| 33 | 
            +
                Engine,
         | 
| 34 | 
            +
                ParamDictType,
         | 
| 35 | 
            +
            )
         | 
| 29 36 |  | 
| 30 37 | 
             
            logger = getLogger(__name__)
         | 
| 31 38 |  | 
| @@ -54,11 +61,18 @@ class BackendConfiguration: | |
| 54 61 | 
             
                    conf_msg = ""
         | 
| 55 62 | 
             
                    for _field in fields(self):
         | 
| 56 63 | 
             
                        if not _field.name.startswith("_"):
         | 
| 57 | 
            -
                            conf_msg += (
         | 
| 58 | 
            -
                                f"Name: {_field.name} - Type: {_field.type} - Default value: {_field.default}\n"
         | 
| 59 | 
            -
                            )
         | 
| 64 | 
            +
                            conf_msg += f"Name: {_field.name} - Type: {_field.type} - Current value: {getattr(self, _field.name)} - Default value: {_field.default}\n"
         | 
| 60 65 | 
             
                    return conf_msg
         | 
| 61 66 |  | 
| 67 | 
            +
                def change_config(self, new_config: dict) -> None:
         | 
| 68 | 
            +
                    """Change configuration with the input."""
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    for key, value in new_config.items():
         | 
| 71 | 
            +
                        if hasattr(self, key):
         | 
| 72 | 
            +
                            setattr(self, key, value)
         | 
| 73 | 
            +
                        else:
         | 
| 74 | 
            +
                            raise ValueError(f"Warning: '{key}' is not a valid configuration attribute.")
         | 
| 75 | 
            +
             | 
| 62 76 | 
             
                @classmethod
         | 
| 63 77 | 
             
                def from_dict(cls, values: dict) -> BackendConfiguration:
         | 
| 64 78 | 
             
                    field_names = {field.name for field in fields(cls)}
         | 
| @@ -208,7 +222,7 @@ class Backend(ABC): | |
| 208 222 | 
             
                    if observable is not None:
         | 
| 209 223 | 
             
                        observable = observable if isinstance(observable, list) else [observable]
         | 
| 210 224 | 
             
                        conv_obs = []
         | 
| 211 | 
            -
                         | 
| 225 | 
            +
                        obs_embedding_fns = []
         | 
| 212 226 |  | 
| 213 227 | 
             
                        for obs in observable:
         | 
| 214 228 | 
             
                            obs = check_observable(obs)
         | 
| @@ -217,13 +231,18 @@ class Backend(ABC): | |
| 217 231 | 
             
                                c_obs.abstract, self.config._use_gate_params, self.engine
         | 
| 218 232 | 
             
                            )
         | 
| 219 233 | 
             
                            params.update(obs_params)
         | 
| 220 | 
            -
                             | 
| 234 | 
            +
                            obs_embedding_fns.append(obs_embedding_fn)
         | 
| 221 235 | 
             
                            conv_obs.append(c_obs)
         | 
| 222 236 |  | 
| 223 237 | 
             
                        def embedding_fn_dict(a: dict, b: dict) -> dict:
         | 
| 224 | 
            -
                             | 
| 225 | 
            -
             | 
| 226 | 
            -
                                 | 
| 238 | 
            +
                            if "circuit" in b or "observables" in b:
         | 
| 239 | 
            +
                                embedding_dict = {"circuit": circ_embedding_fn(a, b), "observables": dict()}
         | 
| 240 | 
            +
                                for obs_embedding_fn in obs_embedding_fns:
         | 
| 241 | 
            +
                                    embedding_dict["observables"].update(obs_embedding_fn(a, b))
         | 
| 242 | 
            +
                            else:
         | 
| 243 | 
            +
                                embedding_dict = circ_embedding_fn(a, b)
         | 
| 244 | 
            +
                                for obs_embedding_fn in obs_embedding_fns:
         | 
| 245 | 
            +
                                    embedding_dict.update(obs_embedding_fn(a, b))
         | 
| 227 246 | 
             
                            return embedding_dict
         | 
| 228 247 |  | 
| 229 248 | 
             
                        return Converted(conv_circ, conv_obs, embedding_fn_dict, params)
         | 
| @@ -309,7 +328,11 @@ class Backend(ABC): | |
| 309 328 | 
             
                    raise NotImplementedError
         | 
| 310 329 |  | 
| 311 330 | 
             
                @abstractmethod
         | 
| 312 | 
            -
                def assign_parameters( | 
| 331 | 
            +
                def assign_parameters(
         | 
| 332 | 
            +
                    self,
         | 
| 333 | 
            +
                    circuit: ConvertedCircuit,
         | 
| 334 | 
            +
                    param_values: dict[str, Tensor] | dict[str, dict[str, Tensor]],
         | 
| 335 | 
            +
                ) -> Any:
         | 
| 313 336 | 
             
                    raise NotImplementedError
         | 
| 314 337 |  | 
| 315 338 | 
             
                @staticmethod
         | 
| @@ -0,0 +1,96 @@ | |
| 1 | 
            +
            from __future__ import annotations
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from functools import lru_cache
         | 
| 4 | 
            +
            from typing import Callable
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            from scipy.optimize import minimize
         | 
| 9 | 
            +
            from torch import Tensor
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def variance(shifts: Tensor, spectral_gaps: Tensor) -> Tensor:
         | 
| 13 | 
            +
                """Calculate the exact variance of deirivative estimation using aGPSR.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                Args:
         | 
| 16 | 
            +
                    shifts (Tensor): shifts to apply for each spectral gap
         | 
| 17 | 
            +
                    spectral_gaps (Tensor): tensor containing spectral gap values
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                Returns:
         | 
| 20 | 
            +
                    Tensor: variance tensor
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                # calculate inverse of M (see: https://arxiv.org/pdf/2108.01218.pdf on p. 4 for definitions)
         | 
| 24 | 
            +
                M = 4 * torch.sin(torch.outer(torch.as_tensor(shifts), spectral_gaps) / 2)
         | 
| 25 | 
            +
                try:
         | 
| 26 | 
            +
                    # calculate the variance of derivative estimation by solving a linear equation system
         | 
| 27 | 
            +
                    a = torch.linalg.solve(M, spectral_gaps.reshape(-1, 1))
         | 
| 28 | 
            +
                    var = 2 * torch.matmul(a.T, a)
         | 
| 29 | 
            +
                except RuntimeError:  # matrix M is singulkar
         | 
| 30 | 
            +
                    # fallback method of variance calculation using inverse matrix
         | 
| 31 | 
            +
                    M_inv = torch.linalg.pinv(M)
         | 
| 32 | 
            +
                    a = torch.matmul(spectral_gaps.reshape(1, -1), M_inv)
         | 
| 33 | 
            +
                    var = 2 * torch.matmul(a, a.T)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                return var
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            @lru_cache
         | 
| 39 | 
            +
            def calculate_optimal_shifts(
         | 
| 40 | 
            +
                n_eqs: int,
         | 
| 41 | 
            +
                spectral_gaps: Tensor,
         | 
| 42 | 
            +
                lb: float,
         | 
| 43 | 
            +
                ub: float,
         | 
| 44 | 
            +
            ) -> Tensor:
         | 
| 45 | 
            +
                """Calculates optimal shift values for GPSR algorithm.
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                Args:
         | 
| 48 | 
            +
                    n_eqs (int): number of equations in linear equation system for derivative estimation
         | 
| 49 | 
            +
                    spectral_gaps (Tensor): tensor containing spectral gap values
         | 
| 50 | 
            +
                    lb (float): lower bound of optimal shift value search interval
         | 
| 51 | 
            +
                    ub (float): upper bound of optimal shift value search interval
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                Returns:
         | 
| 54 | 
            +
                    Tensor: optimal shift values
         | 
| 55 | 
            +
                """
         | 
| 56 | 
            +
                if not (lb and ub):
         | 
| 57 | 
            +
                    raise ValueError("Both lower and upper bounds of optimization interval must be given.")
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                constraints = []
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                # specify solution bound constraints
         | 
| 62 | 
            +
                for i in range(n_eqs):
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    def fn_lb(x, i=i):  # type: ignore [no-untyped-def]
         | 
| 65 | 
            +
                        return x[i] - lb
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    def fn_ub(x, i=i):  # type: ignore [no-untyped-def]
         | 
| 68 | 
            +
                        return ub - x[i]
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    constraints.append({"type": "ineq", "fun": fn_lb})
         | 
| 71 | 
            +
                    constraints.append({"type": "ineq", "fun": fn_ub})
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                # specify constraints for solutions to be unique
         | 
| 74 | 
            +
                for i in range(n_eqs - 1, 0, -1):
         | 
| 75 | 
            +
                    for j in range(i):
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                        def fn(x, i=i, j=j):  # type: ignore [no-untyped-def]
         | 
| 78 | 
            +
                            return np.abs(x[i] - x[j]) - 0.02
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                        constraints.append({"type": "ineq", "fun": fn})
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                init_guess = torch.linspace(lb, ub, n_eqs)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                def minimize_variance(
         | 
| 85 | 
            +
                    var_fn: Callable[[Tensor, Tensor], Tensor],
         | 
| 86 | 
            +
                ) -> Tensor:
         | 
| 87 | 
            +
                    res = minimize(
         | 
| 88 | 
            +
                        fun=var_fn,
         | 
| 89 | 
            +
                        x0=init_guess,
         | 
| 90 | 
            +
                        args=(spectral_gaps,),
         | 
| 91 | 
            +
                        method="COBYLA",
         | 
| 92 | 
            +
                        constraints=constraints,
         | 
| 93 | 
            +
                    )
         | 
| 94 | 
            +
                    return torch.as_tensor(res.x)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                return minimize_variance(variance)
         | 
    
        qadence/backends/api.py
    CHANGED
    
    | @@ -50,7 +50,14 @@ def backend_factory( | |
| 50 50 | 
             
                    # Wrap the quantum Backend in a DifferentiableBackend if a diff_mode is passed.
         | 
| 51 51 | 
             
                    if diff_mode is not None:
         | 
| 52 52 | 
             
                        diff_backend_cls = import_engine(backend_inst.engine)
         | 
| 53 | 
            -
                         | 
| 53 | 
            +
                        psr_args = {
         | 
| 54 | 
            +
                            "n_eqs": configuration.n_eqs,  # type: ignore [attr-defined]
         | 
| 55 | 
            +
                            "shift_prefac": configuration.shift_prefac,  # type: ignore [attr-defined]
         | 
| 56 | 
            +
                            "gap_step": configuration.gap_step,  # type: ignore [attr-defined]
         | 
| 57 | 
            +
                            "lb": configuration.lb,  # type: ignore [attr-defined]
         | 
| 58 | 
            +
                            "ub": configuration.ub,  # type: ignore [attr-defined]
         | 
| 59 | 
            +
                        }
         | 
| 60 | 
            +
                        backend_inst = diff_backend_cls(backend=backend_inst, diff_mode=DiffMode(diff_mode), **psr_args)  # type: ignore[operator]
         | 
| 54 61 | 
             
                    return backend_inst
         | 
| 55 62 | 
             
                except (BackendNotFoundError, EngineNotFoundError, ConfigNotFoundError) as e:
         | 
| 56 63 | 
             
                    logger.error(e.msg)
         | 
| @@ -7,7 +7,7 @@ from typing import Any | |
| 7 7 |  | 
| 8 8 | 
             
            import jax
         | 
| 9 9 | 
             
            import jax.numpy as jnp
         | 
| 10 | 
            -
            from horqrux.utils import zero_state
         | 
| 10 | 
            +
            from horqrux.utils.operator_utils import zero_state
         | 
| 11 11 | 
             
            from jax.typing import ArrayLike
         | 
| 12 12 |  | 
| 13 13 | 
             
            from qadence.backend import Backend as BackendInterface
         | 
| @@ -27,7 +27,8 @@ from qadence.types import BackendName, Endianness, Engine, ParamDictType | |
| 27 27 | 
             
            from qadence.utils import int_to_basis
         | 
| 28 28 |  | 
| 29 29 | 
             
            from .config import Configuration, default_passes
         | 
| 30 | 
            -
            from .convert_ops import  | 
| 30 | 
            +
            from .convert_ops import convert_block, convert_observable
         | 
| 31 | 
            +
            from horqrux.circuit import QuantumCircuit as HorqruxCircuit
         | 
| 31 32 |  | 
| 32 33 | 
             
            logger = getLogger(__name__)
         | 
| 33 34 |  | 
| @@ -58,7 +59,7 @@ class Backend(BackendInterface): | |
| 58 59 | 
             
                        circuit = transpile(*passes)(circuit)
         | 
| 59 60 | 
             
                    ops = convert_block(circuit.block, n_qubits=circuit.n_qubits, config=self.config)
         | 
| 60 61 | 
             
                    return ConvertedCircuit(
         | 
| 61 | 
            -
                        native=HorqruxCircuit(ops), abstract=circuit, original=original_circ
         | 
| 62 | 
            +
                        native=HorqruxCircuit(circuit.n_qubits, ops), abstract=circuit, original=original_circ
         | 
| 62 63 | 
             
                    )
         | 
| 63 64 |  | 
| 64 65 | 
             
                def observable(self, observable: AbstractBlock, n_qubits: int) -> ConvertedObservable:
         | 
| @@ -97,7 +98,7 @@ class Backend(BackendInterface): | |
| 97 98 | 
             
                        state = zero_state(n_qubits)
         | 
| 98 99 | 
             
                    else:
         | 
| 99 100 | 
             
                        state = horqify(state) if horqify_state else state
         | 
| 100 | 
            -
                    state = circuit.native | 
| 101 | 
            +
                    state = circuit.native(state, param_values)
         | 
| 101 102 | 
             
                    if endianness != self.native_endianness:
         | 
| 102 103 | 
             
                        state = jnp.reshape(state, (1, 2**n_qubits))  # batch_size is always 1
         | 
| 103 104 | 
             
                        ls = list(range(2**n_qubits))
         | 
| @@ -133,19 +134,32 @@ class Backend(BackendInterface): | |
| 133 134 | 
             
                    Returns:
         | 
| 134 135 | 
             
                        A jax.Array of shape (batch_size, n_observables)
         | 
| 135 136 | 
             
                    """
         | 
| 136 | 
            -
                     | 
| 137 | 
            -
                     | 
| 137 | 
            +
                    observables = observable if isinstance(observable, list) else [observable]
         | 
| 138 | 
            +
                    if "observables" in param_values or "circuit" in param_values:
         | 
| 139 | 
            +
                        raise NotImplementedError("The Horqrux backend does not support separated parameters.")
         | 
| 140 | 
            +
                    else:
         | 
| 141 | 
            +
                        merged_params = param_values
         | 
| 142 | 
            +
                        batch_size = max([arr.size for arr in param_values.values()])  # type: ignore[union-attr]
         | 
| 138 143 | 
             
                    n_obs = len(observable)
         | 
| 139 144 |  | 
| 140 145 | 
             
                    def _expectation(params: ParamDictType) -> ArrayLike:
         | 
| 146 | 
            +
                        param_circuits = params["circuit"] if "circuit" in params else params
         | 
| 147 | 
            +
                        param_observables = params["observables"] if "observables" in params else params
         | 
| 141 148 | 
             
                        out_state = self.run(
         | 
| 142 | 
            -
                            circuit, | 
| 149 | 
            +
                            circuit,
         | 
| 150 | 
            +
                            param_circuits,
         | 
| 151 | 
            +
                            state,
         | 
| 152 | 
            +
                            endianness,
         | 
| 153 | 
            +
                            horqify_state=True,
         | 
| 154 | 
            +
                            unhorqify_state=False,
         | 
| 155 | 
            +
                        )
         | 
| 156 | 
            +
                        return jnp.array(
         | 
| 157 | 
            +
                            [observable.native(out_state, param_observables) for observable in observables]
         | 
| 143 158 | 
             
                        )
         | 
| 144 | 
            -
                        return jnp.array([o.native.forward(out_state, params) for o in observable])
         | 
| 145 159 |  | 
| 146 160 | 
             
                    if batch_size > 1:  # We vmap for batch_size > 1
         | 
| 147 | 
            -
                        expvals = jax.vmap(_expectation, in_axes=({k: 0 for k in  | 
| 148 | 
            -
                            uniform_batchsize( | 
| 161 | 
            +
                        expvals = jax.vmap(_expectation, in_axes=({k: 0 for k in merged_params.keys()},))(
         | 
| 162 | 
            +
                            uniform_batchsize(merged_params)
         | 
| 149 163 | 
             
                        )
         | 
| 150 164 | 
             
                    else:
         | 
| 151 165 | 
             
                        expvals = _expectation(param_values)
         | 
| @@ -32,4 +32,20 @@ def default_passes(config: Configuration) -> list[Callable]: | |
| 32 32 |  | 
| 33 33 | 
             
            @dataclass
         | 
| 34 34 | 
             
            class Configuration(BackendConfiguration):
         | 
| 35 | 
            -
                 | 
| 35 | 
            +
                n_eqs: int | None = None
         | 
| 36 | 
            +
                """Number of equations to use in aGPSR calculations."""
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                shift_prefac: float = 0.5
         | 
| 39 | 
            +
                """Prefactor governing the magnitude of parameter shift values.
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                Select smaller value if spectral gaps are large.
         | 
| 42 | 
            +
                """
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                gap_step: float = 1.0
         | 
| 45 | 
            +
                """Step between generated pseudo-gaps when using aGPSR algorithm."""
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                lb: float | None = None
         | 
| 48 | 
            +
                """Lower bound of optimal shift value search interval."""
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                ub: float | None = None
         | 
| 51 | 
            +
                """Upper bound of optimal shift value search interval."""
         | 
| @@ -8,11 +8,15 @@ from typing import Any, Callable, Dict | |
| 8 8 |  | 
| 9 9 | 
             
            import jax.numpy as jnp
         | 
| 10 10 | 
             
            from horqrux.analog import _HamiltonianEvolution as NativeHorqHEvo
         | 
| 11 | 
            -
            from horqrux.apply import  | 
| 12 | 
            -
            from horqrux.parametric import RX, RY, RZ
         | 
| 13 | 
            -
            from horqrux.primitive import NOT, SWAP, H, I, X, Y, Z
         | 
| 14 | 
            -
            from horqrux.primitive import Primitive as Gate
         | 
| 15 | 
            -
            from horqrux. | 
| 11 | 
            +
            from horqrux.apply import apply_gates
         | 
| 12 | 
            +
            from horqrux.primitives.parametric import RX, RY, RZ
         | 
| 13 | 
            +
            from horqrux.primitives.primitive import NOT, SWAP, H, I, X, Y, Z
         | 
| 14 | 
            +
            from horqrux.primitives.primitive import Primitive as Gate
         | 
| 15 | 
            +
            from horqrux.composite.sequence import OpSequence as HorqruxSequence
         | 
| 16 | 
            +
            from horqrux.composite.compose import Scale as HorqScaleGate
         | 
| 17 | 
            +
            from horqrux.composite.compose import Add as HorqAddGate
         | 
| 18 | 
            +
            from horqrux.composite.compose import Observable as HorqruxObservable
         | 
| 19 | 
            +
            from horqrux.utils.operator_utils import ControlQubits, TargetQubits
         | 
| 16 20 | 
             
            from jax import Array
         | 
| 17 21 | 
             
            from jax.scipy.linalg import expm
         | 
| 18 22 | 
             
            from jax.tree_util import register_pytree_node_class
         | 
| @@ -36,6 +40,7 @@ from qadence.operations import ( | |
| 36 40 | 
             
                MCRY,
         | 
| 37 41 | 
             
                MCRZ,
         | 
| 38 42 | 
             
                MCZ,
         | 
| 43 | 
            +
                CZ,
         | 
| 39 44 | 
             
            )
         | 
| 40 45 | 
             
            from qadence.operations import SWAP as QDSWAP
         | 
| 41 46 | 
             
            from qadence.types import OpName, ParamDictType
         | 
| @@ -53,6 +58,7 @@ ops_map: Dict[str, Callable] = { | |
| 53 58 | 
             
                OpName.CRX: RX,
         | 
| 54 59 | 
             
                OpName.CRY: RY,
         | 
| 55 60 | 
             
                OpName.CRZ: RZ,
         | 
| 61 | 
            +
                OpName.CZ: Z,
         | 
| 56 62 | 
             
                OpName.CNOT: NOT,
         | 
| 57 63 | 
             
                OpName.I: I,
         | 
| 58 64 | 
             
                OpName.SWAP: SWAP,
         | 
| @@ -61,37 +67,6 @@ ops_map: Dict[str, Callable] = { | |
| 61 67 | 
             
            supported_gates = list(set(list(ops_map.keys())))
         | 
| 62 68 |  | 
| 63 69 |  | 
| 64 | 
            -
            @register_pytree_node_class
         | 
| 65 | 
            -
            @dataclass
         | 
| 66 | 
            -
            class HorqruxCircuit:
         | 
| 67 | 
            -
                operators: list[Gate] = field(default_factory=list)
         | 
| 68 | 
            -
             | 
| 69 | 
            -
                def tree_flatten(self) -> tuple[tuple[list[Any]], tuple[()]]:
         | 
| 70 | 
            -
                    children = (self.operators,)
         | 
| 71 | 
            -
                    aux_data = ()
         | 
| 72 | 
            -
                    return (children, aux_data)
         | 
| 73 | 
            -
             | 
| 74 | 
            -
                @classmethod
         | 
| 75 | 
            -
                def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
         | 
| 76 | 
            -
                    return cls(*children, *aux_data)
         | 
| 77 | 
            -
             | 
| 78 | 
            -
                def _forward(self, state: Array, values: ParamDictType) -> Array:
         | 
| 79 | 
            -
                    return reduce(lambda state, gate: gate.forward(state, values), self.operators, state)
         | 
| 80 | 
            -
             | 
| 81 | 
            -
                def forward(self, state: Array, values: ParamDictType) -> Array:
         | 
| 82 | 
            -
                    return self._forward(state, values)
         | 
| 83 | 
            -
             | 
| 84 | 
            -
             | 
| 85 | 
            -
            @register_pytree_node_class
         | 
| 86 | 
            -
            @dataclass
         | 
| 87 | 
            -
            class HorqruxObservable(HorqruxCircuit):
         | 
| 88 | 
            -
                def __init__(self, operators: list[Gate]):
         | 
| 89 | 
            -
                    super().__init__(operators=operators)
         | 
| 90 | 
            -
             | 
| 91 | 
            -
                def forward(self, state: Array, values: ParamDictType) -> Array:
         | 
| 92 | 
            -
                    return jnp.real(inner(state, self._forward(state, values)))
         | 
| 93 | 
            -
             | 
| 94 | 
            -
             | 
| 95 70 | 
             
            def convert_observable(
         | 
| 96 71 | 
             
                block: AbstractBlock, n_qubits: int, config: Configuration
         | 
| 97 72 | 
             
            ) -> HorqruxObservable:
         | 
| @@ -109,7 +84,7 @@ def convert_block( | |
| 109 84 | 
             
                ops = []
         | 
| 110 85 | 
             
                if isinstance(block, CompositeBlock):
         | 
| 111 86 | 
             
                    ops = list(flatten(*(convert_block(b, n_qubits, config) for b in block.blocks)))
         | 
| 112 | 
            -
                    ops = [HorqAddGate(ops)] if isinstance(block, AddBlock) else [ | 
| 87 | 
            +
                    ops = [HorqAddGate(ops)] if isinstance(block, AddBlock) else [HorqruxSequence(ops)]
         | 
| 113 88 | 
             
                elif isinstance(block, ScaleBlock):
         | 
| 114 89 | 
             
                    op = convert_block(block.block, n_qubits, config=config)[0]
         | 
| 115 90 | 
             
                    param_name = config.get_param_name(block)[0]
         | 
| @@ -118,7 +93,7 @@ def convert_block( | |
| 118 93 | 
             
                    native_op_fn = ops_map[block.name]
         | 
| 119 94 | 
             
                    target, control = (
         | 
| 120 95 | 
             
                        (block.qubit_support[1], block.qubit_support[0])
         | 
| 121 | 
            -
                        if isinstance(block, (CNOT, CRX, CRY, CRZ, QDSWAP))
         | 
| 96 | 
            +
                        if isinstance(block, (CZ, CNOT, CRX, CRY, CRZ, QDSWAP))
         | 
| 122 97 | 
             
                        else (block.qubit_support[0], (None,))
         | 
| 123 98 | 
             
                    )
         | 
| 124 99 | 
             
                    native_gate: Gate
         | 
| @@ -133,7 +108,7 @@ def convert_block( | |
| 133 108 | 
             
                            native_gate = native_op_fn(block.qubit_support[::-1])
         | 
| 134 109 | 
             
                        else:
         | 
| 135 110 | 
             
                            native_gate = native_op_fn(target=target, control=control)
         | 
| 136 | 
            -
                    ops = [ | 
| 111 | 
            +
                    ops = [native_gate]
         | 
| 137 112 |  | 
| 138 113 | 
             
                elif isinstance(block, (MCRX, MCRY, MCRZ, MCZ)):
         | 
| 139 114 | 
             
                    block_name = block.name[2:] if block.name.startswith("M") else block.name
         | 
| @@ -146,7 +121,7 @@ def convert_block( | |
| 146 121 | 
             
                        native_gate = native_op_fn(param=param, target=target, control=control)
         | 
| 147 122 | 
             
                    else:
         | 
| 148 123 | 
             
                        native_gate = native_op_fn(target, control)
         | 
| 149 | 
            -
                    ops = [ | 
| 124 | 
            +
                    ops = [native_gate]
         | 
| 150 125 | 
             
                elif isinstance(block, TimeEvolutionBlock):
         | 
| 151 126 | 
             
                    ops = [HorqHamiltonianEvolution(block, config)]
         | 
| 152 127 | 
             
                else:
         | 
| @@ -155,58 +130,6 @@ def convert_block( | |
| 155 130 | 
             
                return ops
         | 
| 156 131 |  | 
| 157 132 |  | 
| 158 | 
            -
            @register_pytree_node_class
         | 
| 159 | 
            -
            class HorqAddGate(HorqruxCircuit):
         | 
| 160 | 
            -
                def __init__(self, operations: list[Gate]):
         | 
| 161 | 
            -
                    self.operators = operations
         | 
| 162 | 
            -
                    self.name = "Add"
         | 
| 163 | 
            -
             | 
| 164 | 
            -
                def forward(self, state: Array, values: ParamDictType = {}) -> Array:
         | 
| 165 | 
            -
                    return reduce(add, (gate.forward(state, values) for gate in self.operators))
         | 
| 166 | 
            -
             | 
| 167 | 
            -
                def __repr__(self) -> str:
         | 
| 168 | 
            -
                    return self.name + f"({self.operators})"
         | 
| 169 | 
            -
             | 
| 170 | 
            -
             | 
| 171 | 
            -
            @register_pytree_node_class
         | 
| 172 | 
            -
            @dataclass
         | 
| 173 | 
            -
            class HorqOperation:
         | 
| 174 | 
            -
                def __init__(self, native_gate: Gate):
         | 
| 175 | 
            -
                    self.native_gate = native_gate
         | 
| 176 | 
            -
             | 
| 177 | 
            -
                def forward(self, state: Array, values: ParamDictType) -> Array:
         | 
| 178 | 
            -
                    return apply_gate(state, self.native_gate, values)
         | 
| 179 | 
            -
             | 
| 180 | 
            -
                def tree_flatten(self) -> tuple[tuple[Gate], tuple[()]]:
         | 
| 181 | 
            -
                    children = (self.native_gate,)
         | 
| 182 | 
            -
                    aux_data = ()
         | 
| 183 | 
            -
                    return (children, aux_data)
         | 
| 184 | 
            -
             | 
| 185 | 
            -
                @classmethod
         | 
| 186 | 
            -
                def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
         | 
| 187 | 
            -
                    return cls(*children, *aux_data)
         | 
| 188 | 
            -
             | 
| 189 | 
            -
             | 
| 190 | 
            -
            @register_pytree_node_class
         | 
| 191 | 
            -
            @dataclass
         | 
| 192 | 
            -
            class HorqScaleGate:
         | 
| 193 | 
            -
                def __init__(self, gate: HorqOperation, parameter_name: str):
         | 
| 194 | 
            -
                    self.gate = gate
         | 
| 195 | 
            -
                    self.parameter: str = parameter_name
         | 
| 196 | 
            -
             | 
| 197 | 
            -
                def forward(self, state: Array, values: ParamDictType) -> Array:
         | 
| 198 | 
            -
                    return jnp.array(values[self.parameter]) * self.gate.forward(state, values)
         | 
| 199 | 
            -
             | 
| 200 | 
            -
                def tree_flatten(self) -> tuple[tuple[HorqOperation], tuple[str]]:
         | 
| 201 | 
            -
                    children = (self.gate,)
         | 
| 202 | 
            -
                    aux_data = (self.parameter,)
         | 
| 203 | 
            -
                    return (children, aux_data)
         | 
| 204 | 
            -
             | 
| 205 | 
            -
                @classmethod
         | 
| 206 | 
            -
                def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
         | 
| 207 | 
            -
                    return cls(*children, *aux_data)
         | 
| 208 | 
            -
             | 
| 209 | 
            -
             | 
| 210 133 | 
             
            @register_pytree_node_class
         | 
| 211 134 | 
             
            @dataclass
         | 
| 212 135 | 
             
            class HorqHamiltonianEvolution(NativeHorqHEvo):
         | 
| @@ -216,7 +139,7 @@ class HorqHamiltonianEvolution(NativeHorqHEvo): | |
| 216 139 | 
             
                    config: Configuration,
         | 
| 217 140 | 
             
                ):
         | 
| 218 141 | 
             
                    super().__init__("I", block.qubit_support, (None,))
         | 
| 219 | 
            -
                    self. | 
| 142 | 
            +
                    self._qubit_support = block.qubit_support
         | 
| 220 143 | 
             
                    self.param_names = config.get_param_name(block)
         | 
| 221 144 | 
             
                    self.block = block
         | 
| 222 145 | 
             
                    self.hmat: Array
         | 
| @@ -224,7 +147,7 @@ class HorqHamiltonianEvolution(NativeHorqHEvo): | |
| 224 147 | 
             
                    if isinstance(block.generator, AbstractBlock) and not block.generator.is_parametric:
         | 
| 225 148 | 
             
                        hmat = block_to_jax(
         | 
| 226 149 | 
             
                            block.generator,
         | 
| 227 | 
            -
                            qubit_support=self. | 
| 150 | 
            +
                            qubit_support=self._qubit_support,
         | 
| 228 151 | 
             
                            use_full_support=False,
         | 
| 229 152 | 
             
                        )
         | 
| 230 153 | 
             
                        self.hmat = hmat
         | 
| @@ -236,7 +159,7 @@ class HorqHamiltonianEvolution(NativeHorqHEvo): | |
| 236 159 | 
             
                            hmat = block_to_jax(
         | 
| 237 160 | 
             
                                block.generator,  # type: ignore[arg-type]
         | 
| 238 161 | 
             
                                values=values,
         | 
| 239 | 
            -
                                qubit_support=self. | 
| 162 | 
            +
                                qubit_support=self._qubit_support,
         | 
| 240 163 | 
             
                                use_full_support=False,
         | 
| 241 164 | 
             
                            )
         | 
| 242 165 | 
             
                            return hmat
         | 
| @@ -249,12 +172,12 @@ class HorqHamiltonianEvolution(NativeHorqHEvo): | |
| 249 172 | 
             
                    """The evolved operator given current parameter values for generator and time evolution."""
         | 
| 250 173 | 
             
                    return expm(self._hamiltonian(self, values) * (-1j * self._time_evolution(values)))
         | 
| 251 174 |  | 
| 252 | 
            -
                def  | 
| 175 | 
            +
                def __call__(
         | 
| 253 176 | 
             
                    self,
         | 
| 254 177 | 
             
                    state: Array,
         | 
| 255 178 | 
             
                    values: dict[str, Array],
         | 
| 256 179 | 
             
                ) -> Array:
         | 
| 257 | 
            -
                    return  | 
| 180 | 
            +
                    return apply_gates(state, self, values)
         | 
| 258 181 |  | 
| 259 182 | 
             
                def tree_flatten(self) -> tuple[tuple[NativeHorqHEvo], tuple]:
         | 
| 260 183 | 
             
                    children = (self,)
         | 
    
        qadence/backends/jax_utils.py
    CHANGED
    
    | @@ -19,6 +19,7 @@ from qadence.blocks import ( | |
| 19 19 | 
             
            )
         | 
| 20 20 | 
             
            from qadence.blocks.block_to_tensor import _gate_parameters
         | 
| 21 21 | 
             
            from qadence.types import Endianness, ParamDictType
         | 
| 22 | 
            +
            from qadence.utils import merge_separate_params
         | 
| 22 23 |  | 
| 23 24 |  | 
| 24 25 | 
             
            def jarr_to_tensor(arr: Array, dtype: Any = cdouble) -> Tensor:
         | 
| @@ -52,9 +53,11 @@ def horqify(state: Array) -> Array: | |
| 52 53 |  | 
| 53 54 |  | 
| 54 55 | 
             
            def uniform_batchsize(param_values: ParamDictType) -> ParamDictType:
         | 
| 55 | 
            -
                 | 
| 56 | 
            +
                if "observables" in param_values or "circuit" in param_values:
         | 
| 57 | 
            +
                    param_values = merge_separate_params(param_values)
         | 
| 58 | 
            +
                max_batch_size = max(p.size for p in param_values.values())  # type: ignore[union-attr]
         | 
| 56 59 | 
             
                batched_values = {
         | 
| 57 | 
            -
                    k: (v if v.size == max_batch_size else v.repeat(max_batch_size))
         | 
| 60 | 
            +
                    k: (v if v.size == max_batch_size else v.repeat(max_batch_size))  # type: ignore[union-attr]
         | 
| 58 61 | 
             
                    for k, v in param_values.items()
         | 
| 59 62 | 
             
                }
         | 
| 60 63 | 
             
                return batched_values
         | 
| @@ -8,32 +8,45 @@ from torch import Tensor | |
| 8 8 |  | 
| 9 9 | 
             
            from qadence.types import PI
         | 
| 10 10 | 
             
            from qadence.utils import _round_complex
         | 
| 11 | 
            +
            from qadence.backends.agpsr_utils import calculate_optimal_shifts
         | 
| 11 12 |  | 
| 12 13 |  | 
| 13 | 
            -
            def general_psr( | 
| 14 | 
            +
            def general_psr(
         | 
| 15 | 
            +
                spectrum: Tensor,
         | 
| 16 | 
            +
                n_eqs: int | None = None,
         | 
| 17 | 
            +
                shift_prefac: float | None = 0.5,
         | 
| 18 | 
            +
                gap_step: float = 1.0,
         | 
| 19 | 
            +
                lb: float | None = None,
         | 
| 20 | 
            +
                ub: float | None = None,
         | 
| 21 | 
            +
            ) -> Callable:
         | 
| 14 22 | 
             
                """Define whether single_gap_psr or multi_gap_psr is used.
         | 
| 15 23 |  | 
| 16 24 | 
             
                Args:
         | 
| 17 25 | 
             
                    spectrum (Tensor): Spectrum of the operation we apply PSR onto.
         | 
| 18 | 
            -
                    n_eqs (int | None | 
| 19 | 
            -
                        If provided,  | 
| 20 | 
            -
                    shift_prefac (float | 
| 26 | 
            +
                    n_eqs (int | None): Number of equations. Defaults to None.
         | 
| 27 | 
            +
                        If provided, aGPSR algorithm is effectively used.
         | 
| 28 | 
            +
                    shift_prefac (float | None): prefactor governing the magnitude of parameter shift values -
         | 
| 29 | 
            +
                        select smaller value if spectral gaps are large. Defaults to 0.5.
         | 
| 30 | 
            +
                    gap_step (float): Step between generated pseudo-gaps when using aGPSR algorithm. Defaults to 1.0.
         | 
| 31 | 
            +
                    lb (float | None): Lower bound of optimal shift value search interval. Defaults to None.
         | 
| 32 | 
            +
                    ub (float | None): Upper bound of optimal shift value search interval. Defaults to None.
         | 
| 21 33 |  | 
| 22 34 | 
             
                Returns:
         | 
| 23 | 
            -
                    Callable: single_gap_psr or multi_gap_psr function for
         | 
| 24 | 
            -
                        concerned operation.
         | 
| 35 | 
            +
                    Callable: single_gap_psr or multi_gap_psr function for concerned operation.
         | 
| 25 36 | 
             
                """
         | 
| 37 | 
            +
             | 
| 26 38 | 
             
                diffs = _round_complex(spectrum - spectrum.reshape(-1, 1))
         | 
| 27 | 
            -
                 | 
| 39 | 
            +
                orig_unique_spectral_gaps = torch.unique(torch.abs(torch.tril(diffs)))
         | 
| 28 40 |  | 
| 29 41 | 
             
                # We have to filter out zeros
         | 
| 30 | 
            -
                 | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
                     | 
| 34 | 
            -
                     | 
| 35 | 
            -
                 | 
| 36 | 
            -
             | 
| 42 | 
            +
                orig_unique_spectral_gaps = orig_unique_spectral_gaps[orig_unique_spectral_gaps > 0]
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                if n_eqs is None:  # GPSR case
         | 
| 45 | 
            +
                    n_eqs = len(orig_unique_spectral_gaps)
         | 
| 46 | 
            +
                    sorted_unique_spectral_gaps = orig_unique_spectral_gaps
         | 
| 47 | 
            +
                else:  # aGPSR case
         | 
| 48 | 
            +
                    sorted_unique_spectral_gaps = torch.arange(0, n_eqs) * gap_step
         | 
| 49 | 
            +
                    sorted_unique_spectral_gaps[0] = 0.001
         | 
| 37 50 |  | 
| 38 51 | 
             
                if n_eqs == 1:
         | 
| 39 52 | 
             
                    return partial(
         | 
| @@ -46,6 +59,8 @@ def general_psr(spectrum: Tensor, n_eqs: int | None = None, shift_prefac: float | |
| 46 59 | 
             
                        multi_gap_psr,
         | 
| 47 60 | 
             
                        spectral_gaps=sorted_unique_spectral_gaps,
         | 
| 48 61 | 
             
                        shift_prefac=shift_prefac,
         | 
| 62 | 
            +
                        lb=lb,
         | 
| 63 | 
            +
                        ub=ub,
         | 
| 49 64 | 
             
                    )
         | 
| 50 65 |  | 
| 51 66 |  | 
| @@ -60,8 +75,7 @@ def single_gap_psr( | |
| 60 75 |  | 
| 61 76 | 
             
                Args:
         | 
| 62 77 | 
             
                    expectation_fn (Callable[[dict[str, Tensor]], Tensor]): backend-dependent function
         | 
| 63 | 
            -
             | 
| 64 | 
            -
             | 
| 78 | 
            +
                        to calculate expectation value
         | 
| 65 79 | 
             
                    param_dict (dict[str, Tensor]): dict storing parameters of parameterized blocks
         | 
| 66 80 | 
             
                    param_name (str): name of parameter with respect to that differentiation is performed
         | 
| 67 81 |  | 
| @@ -93,19 +107,22 @@ def multi_gap_psr( | |
| 93 107 | 
             
                param_dict: dict[str, Tensor],
         | 
| 94 108 | 
             
                param_name: str,
         | 
| 95 109 | 
             
                spectral_gaps: Tensor,
         | 
| 96 | 
            -
                shift_prefac: float = 0.5,
         | 
| 110 | 
            +
                shift_prefac: float | None = 0.5,
         | 
| 111 | 
            +
                lb: float | None = None,
         | 
| 112 | 
            +
                ub: float | None = None,
         | 
| 97 113 | 
             
            ) -> Tensor:
         | 
| 98 114 | 
             
                """Implements multi-gap multi-qubit GPSR rule.
         | 
| 99 115 |  | 
| 100 116 | 
             
                Args:
         | 
| 101 117 | 
             
                    expectation_fn (Callable[[dict[str, Tensor]], Tensor]): backend-dependent function
         | 
| 102 | 
            -
             | 
| 103 | 
            -
             | 
| 118 | 
            +
                        to calculate expectation value
         | 
| 104 119 | 
             
                    param_dict (dict[str, Tensor]): dict storing parameters values of parameterized blocks
         | 
| 105 120 | 
             
                    param_name (str): name of parameter with respect to that differentiation is performed
         | 
| 106 | 
            -
             | 
| 121 | 
            +
                        spectral_gaps (Tensor): tensor containing spectral gap values
         | 
| 107 122 | 
             
                    shift_prefac (float): prefactor governing the magnitude of parameter shift values -
         | 
| 108 | 
            -
             | 
| 123 | 
            +
                        select smaller value if spectral gaps are large
         | 
| 124 | 
            +
                    lb (float): lower bound of optimal shift value search interval
         | 
| 125 | 
            +
                    ub (float): upper bound of optimal shift value search interval
         | 
| 109 126 |  | 
| 110 127 | 
             
                Returns:
         | 
| 111 128 | 
             
                    Tensor: tensor containing derivative values
         | 
| @@ -113,10 +130,15 @@ def multi_gap_psr( | |
| 113 130 | 
             
                n_eqs = len(spectral_gaps)
         | 
| 114 131 | 
             
                batch_size = max(t.size(0) for t in param_dict.values())
         | 
| 115 132 |  | 
| 116 | 
            -
                # get shift values
         | 
| 117 | 
            -
                 | 
| 118 | 
            -
                     | 
| 119 | 
            -
             | 
| 133 | 
            +
                # get shift values - values minimize the variance of expectation
         | 
| 134 | 
            +
                if shift_prefac is not None:
         | 
| 135 | 
            +
                    # Set shift values manually by breaking the symmetry of sampling range
         | 
| 136 | 
            +
                    # around PI/2 to reduce the possibility that M is singular
         | 
| 137 | 
            +
                    shifts = shift_prefac * torch.linspace(PI / 2 - PI / 4, PI / 2 + PI / 5, n_eqs)
         | 
| 138 | 
            +
                else:
         | 
| 139 | 
            +
                    # calculate optimal shift values
         | 
| 140 | 
            +
                    shifts = calculate_optimal_shifts(n_eqs, spectral_gaps, lb, ub)
         | 
| 141 | 
            +
             | 
| 120 142 | 
             
                device = torch.device("cpu")
         | 
| 121 143 | 
             
                try:
         | 
| 122 144 | 
             
                    device = [v.device for v in param_dict.values()][0]
         | 
| @@ -127,7 +149,7 @@ def multi_gap_psr( | |
| 127 149 | 
             
                # calculate F vector and M matrix
         | 
| 128 150 | 
             
                # (see: https://arxiv.org/pdf/2108.01218.pdf on p. 4 for definitions)
         | 
| 129 151 | 
             
                F = []
         | 
| 130 | 
            -
                M = torch. | 
| 152 | 
            +
                M = 4 * torch.sin(torch.outer(shifts, spectral_gaps) / 2).to(device=device)
         | 
| 131 153 | 
             
                n_obs = 1
         | 
| 132 154 | 
             
                for i in range(n_eqs):
         | 
| 133 155 | 
             
                    # + shift
         | 
| @@ -142,10 +164,6 @@ def multi_gap_psr( | |
| 142 164 |  | 
| 143 165 | 
             
                    F.append((f_plus - f_minus))
         | 
| 144 166 |  | 
| 145 | 
            -
                    # calculate M matrix
         | 
| 146 | 
            -
                    for j in range(n_eqs):
         | 
| 147 | 
            -
                        M[i, j] = 4 * torch.sin(shifts[i] * spectral_gaps[j] / 2)
         | 
| 148 | 
            -
             | 
| 149 167 | 
             
                # get number of observables from expectation value tensor
         | 
| 150 168 | 
             
                if f_plus.numel() > 1:
         | 
| 151 169 | 
             
                    batch_size = F[0].shape[0]
         |