qadence 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. qadence/__init__.py +1 -0
  2. qadence/analog/__init__.py +4 -2
  3. qadence/analog/addressing.py +167 -0
  4. qadence/analog/constants.py +59 -0
  5. qadence/analog/device.py +82 -0
  6. qadence/analog/hamiltonian_terms.py +101 -0
  7. qadence/analog/parse_analog.py +120 -0
  8. qadence/backend.py +42 -12
  9. qadence/backends/__init__.py +1 -2
  10. qadence/backends/api.py +27 -9
  11. qadence/backends/braket/backend.py +3 -2
  12. qadence/backends/horqrux/__init__.py +5 -0
  13. qadence/backends/horqrux/backend.py +216 -0
  14. qadence/backends/horqrux/config.py +26 -0
  15. qadence/backends/horqrux/convert_ops.py +273 -0
  16. qadence/backends/jax_utils.py +45 -0
  17. qadence/backends/pulser/__init__.py +0 -1
  18. qadence/backends/pulser/backend.py +31 -15
  19. qadence/backends/pulser/config.py +19 -10
  20. qadence/backends/pulser/devices.py +57 -63
  21. qadence/backends/pulser/pulses.py +70 -12
  22. qadence/backends/pyqtorch/backend.py +4 -4
  23. qadence/backends/pyqtorch/config.py +18 -12
  24. qadence/backends/pyqtorch/convert_ops.py +15 -7
  25. qadence/backends/utils.py +5 -9
  26. qadence/blocks/abstract.py +5 -1
  27. qadence/blocks/analog.py +18 -9
  28. qadence/blocks/block_to_tensor.py +11 -0
  29. qadence/blocks/embedding.py +46 -24
  30. qadence/blocks/primitive.py +81 -9
  31. qadence/blocks/utils.py +20 -1
  32. qadence/circuit.py +3 -9
  33. qadence/constructors/__init__.py +4 -0
  34. qadence/constructors/feature_maps.py +84 -60
  35. qadence/constructors/hamiltonians.py +27 -98
  36. qadence/constructors/rydberg_feature_maps.py +113 -0
  37. qadence/divergences.py +12 -0
  38. qadence/engines/__init__.py +0 -0
  39. qadence/engines/differentiable_backend.py +152 -0
  40. qadence/engines/jax/__init__.py +8 -0
  41. qadence/engines/jax/differentiable_backend.py +73 -0
  42. qadence/engines/jax/differentiable_expectation.py +94 -0
  43. qadence/engines/torch/__init__.py +4 -0
  44. qadence/engines/torch/differentiable_backend.py +85 -0
  45. qadence/extensions.py +21 -9
  46. qadence/finitediff.py +47 -0
  47. qadence/mitigations/readout.py +92 -25
  48. qadence/ml_tools/models.py +10 -3
  49. qadence/models/qnn.py +88 -23
  50. qadence/models/quantum_model.py +13 -2
  51. qadence/operations.py +55 -70
  52. qadence/parameters.py +24 -13
  53. qadence/register.py +91 -43
  54. qadence/transpile/__init__.py +1 -0
  55. qadence/transpile/apply_fn.py +40 -0
  56. qadence/types.py +32 -2
  57. qadence/utils.py +35 -0
  58. {qadence-1.1.1.dist-info → qadence-1.2.1.dist-info}/METADATA +22 -3
  59. {qadence-1.1.1.dist-info → qadence-1.2.1.dist-info}/RECORD +62 -44
  60. {qadence-1.1.1.dist-info → qadence-1.2.1.dist-info}/WHEEL +1 -1
  61. qadence/analog/interaction.py +0 -198
  62. qadence/analog/utils.py +0 -132
  63. /qadence/{backends/pytorch_wrapper.py → engines/torch/differentiable_expectation.py} +0 -0
  64. {qadence-1.1.1.dist-info → qadence-1.2.1.dist-info}/licenses/LICENSE +0 -0
qadence/engines/jax/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from __future__ import annotations
+
+ from jax import config
+
+ from .differentiable_backend import DifferentiableBackend
+ from .differentiable_expectation import DifferentiableExpectation
+
+ config.update("jax_enable_x64", True)
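Note: the `jax_enable_x64` update above is needed because JAX defaults to single precision, which is usually too coarse for statevector simulation. A minimal sketch of its effect, independent of qadence:

    # Illustration only: JAX downcasts to float32 unless x64 is enabled.
    from jax import config

    config.update("jax_enable_x64", True)  # must run before arrays are created

    import jax.numpy as jnp

    print(jnp.array(1.0).dtype)  # float64 with the flag, float32 without it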
qadence/engines/jax/differentiable_backend.py ADDED
@@ -0,0 +1,73 @@
+ from __future__ import annotations
+
+ from qadence.backend import Backend, ConvertedCircuit, ConvertedObservable
+ from qadence.engines.differentiable_backend import (
+     DifferentiableBackend as DifferentiableBackendInterface,
+ )
+ from qadence.engines.jax.differentiable_expectation import DifferentiableExpectation
+ from qadence.measurements import Measurements
+ from qadence.mitigations import Mitigations
+ from qadence.noise import Noise
+ from qadence.types import ArrayLike, DiffMode, Endianness, Engine, ParamDictType
+
+
+ class DifferentiableBackend(DifferentiableBackendInterface):
+     """A class which wraps a QuantumBackend with the automatic differentiation engine JAX.
+
+     Arguments:
+         backend: An instance of the QuantumBackend type to perform execution.
+         diff_mode: A differentiation mode supported by the engine.
+         **psr_args: Arguments that will be passed on to `DifferentiableExpectation`.
+     """
+
+     def __init__(
+         self,
+         backend: Backend,
+         diff_mode: DiffMode = DiffMode.AD,
+         **psr_args: int | float | None,
+     ) -> None:
+         super().__init__(backend=backend, engine=Engine.JAX, diff_mode=diff_mode)
+         self.psr_args = psr_args
+
+     def expectation(
+         self,
+         circuit: ConvertedCircuit,
+         observable: list[ConvertedObservable] | ConvertedObservable,
+         param_values: ParamDictType = {},
+         state: ArrayLike | None = None,
+         measurement: Measurements | None = None,
+         noise: Noise | None = None,
+         mitigation: Mitigations | None = None,
+         endianness: Endianness = Endianness.BIG,
+     ) -> ArrayLike:
+         """Compute the expectation value of the `circuit` with the given `observable`.
+
+         Arguments:
+             circuit: A converted circuit as returned by `backend.circuit`.
+             observable: A converted observable as returned by `backend.observable`.
+             param_values: _**Already embedded**_ parameters of the circuit. See
+                 [`embedding`][qadence.blocks.embedding.embedding] for more info.
+             state: Initial state.
+             measurement: Optional measurement protocol. If None, use
+                 exact expectation value with a statevector simulator.
+             noise: A noise model to use.
+             mitigation: The error mitigation to use.
+             endianness: Endianness of the resulting bit strings.
+         """
+         observable = observable if isinstance(observable, list) else [observable]
+
+         if self.diff_mode == DiffMode.AD:
+             expectation = self.backend.expectation(circuit, observable, param_values, state)
+         else:
+             expectation = DifferentiableExpectation(
+                 backend=self.backend,
+                 circuit=circuit,
+                 observable=observable,
+                 param_values=param_values,
+                 state=state,
+                 measurement=measurement,
+                 noise=noise,
+                 mitigation=mitigation,
+                 endianness=endianness,
+             ).psr()
+         return expectation
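For orientation, a hypothetical usage sketch of this engine from user code. The `backend_factory` helper and the fields of the converted circuit (`circuit`, `observable`, `embedding_fn`, `params`) are assumed from the file list above and from qadence's public documentation, not verified against this release:

    # Hypothetical usage sketch; the names noted above are assumptions.
    from qadence import QuantumCircuit, RX, Z
    from qadence.backends.api import backend_factory  # assumed factory helper
    from qadence.parameters import FeatureParameter
    from qadence.types import BackendName, DiffMode

    theta = FeatureParameter("theta")
    circuit = QuantumCircuit(1, RX(0, theta))

    # Requesting the JAX-native horqrux backend should return the wrapper above.
    backend = backend_factory(BackendName.HORQRUX, diff_mode=DiffMode.GPSR)
    conv = backend.convert(circuit, Z(0))
    values = conv.embedding_fn(conv.params, {"theta": 0.5})
    print(backend.expectation(conv.circuit, conv.observable, values))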
qadence/engines/jax/differentiable_expectation.py ADDED
@@ -0,0 +1,94 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any, Tuple
+
+ import jax.numpy as jnp
+ from jax import Array, custom_vjp, vmap
+
+ from qadence.backend import Backend as QuantumBackend
+ from qadence.backend import ConvertedCircuit, ConvertedObservable
+ from qadence.backends.jax_utils import (
+     tensor_to_jnp,
+ )
+ from qadence.blocks.utils import uuid_to_eigen
+ from qadence.measurements import Measurements
+ from qadence.mitigations import Mitigations
+ from qadence.noise import Noise
+ from qadence.types import Endianness, Engine, ParamDictType
+
+
+ def compute_single_gap(eigen_vals: Array, default_val: float = 2.0) -> Array:
+     eigen_vals = eigen_vals.reshape(1, 2)
+     gaps = jnp.abs(jnp.tril(eigen_vals.T - eigen_vals))
+     return jnp.unique(jnp.where(gaps > 0.0, gaps, default_val), size=1)
+
+
+ @dataclass
+ class DifferentiableExpectation:
+     """A handler for differentiating expectation estimation using various engines."""
+
+     backend: QuantumBackend
+     circuit: ConvertedCircuit
+     observable: list[ConvertedObservable] | ConvertedObservable
+     param_values: ParamDictType
+     state: Array | None = None
+     measurement: Measurements | None = None
+     noise: Noise | None = None
+     mitigation: Mitigations | None = None
+     endianness: Endianness = Endianness.BIG
+     engine: Engine = Engine.JAX
+
+     def psr(self) -> Any:
+         n_obs = len(self.observable)
+
+         def expectation_fn(state: Array, values: ParamDictType, psr_params: ParamDictType) -> Array:
+             return self.backend.expectation(
+                 circuit=self.circuit, observable=self.observable, param_values=values, state=state
+             )
+
+         @custom_vjp
+         def expectation(state: Array, values: ParamDictType, psr_params: ParamDictType) -> Array:
+             return expectation_fn(state, values, psr_params)
+
+         uuid_to_eigs = {
+             k: tensor_to_jnp(v) for k, v in uuid_to_eigen(self.circuit.abstract.block).items()
+         }
+         self.psr_params = {
+             k: self.param_values[k] for k in uuid_to_eigs.keys()
+         }  # Subset of params on which to perform PSR.
+
+         def expectation_fwd(state: Array, values: ParamDictType, psr_params: ParamDictType) -> Any:
+             return expectation_fn(state, values, psr_params), (
+                 state,
+                 values,
+                 psr_params,
+             )
+
+         def expectation_bwd(res: Tuple[Array, ParamDictType, ParamDictType], tangent: Array) -> Any:
+             state, values, psr_params = res
+             grads = {}
+             # FIXME Hardcoding the single spectral_gap to 2.
+             spectral_gap = 2.0
+             shift = jnp.pi / 2
+
+             def shift_circ(param_name: str, values: dict) -> Array:
+                 shifted_values = values.copy()
+                 shiftvals = jnp.array(
+                     [shifted_values[param_name] + shift, shifted_values[param_name] - shift]
+                 )
+
+                 def _expectation(val: Array) -> Array:
+                     shifted_values[param_name] = val
+                     return expectation(state, shifted_values, psr_params)
+
+                 return vmap(_expectation, in_axes=(0,))(shiftvals)
+
+             for param_name, _ in psr_params.items():
+                 f_plus, f_min = shift_circ(param_name, values)
+                 grad = spectral_gap * (f_plus - f_min) / (4.0 * jnp.sin(spectral_gap * shift / 2.0))
+                 grads[param_name] = jnp.sum(tangent * grad, axis=1) if n_obs > 1 else tangent * grad
+             return None, None, grads
+
+         expectation.defvjp(expectation_fwd, expectation_bwd)
+         return expectation(self.state, self.param_values, self.psr_params)
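The backward rule above is the standard two-point parameter-shift formula: with spectral_gap = 2 and shift = π/2 it reduces to (f(θ + π/2) − f(θ − π/2)) / 2. A self-contained toy version of the same `custom_vjp` pattern, using f(θ) = cos(θ) (the ⟨Z⟩ value after an RX rotation) as a stand-in for the backend call:

    # Minimal parameter-shift sketch mirroring the backward pass above.
    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", True)

    @jax.custom_vjp
    def expectation(theta):
        return jnp.cos(theta)  # stand-in for backend.expectation

    def expectation_fwd(theta):
        return expectation(theta), theta

    def expectation_bwd(theta, tangent):
        spectral_gap, shift = 2.0, jnp.pi / 2
        f_plus, f_min = expectation(theta + shift), expectation(theta - shift)
        grad = spectral_gap * (f_plus - f_min) / (4.0 * jnp.sin(spectral_gap * shift / 2.0))
        return (tangent * grad,)

    expectation.defvjp(expectation_fwd, expectation_bwd)
    print(jax.grad(expectation)(0.3))  # ~ -0.2955, i.e. -sin(0.3)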
qadence/engines/torch/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from __future__ import annotations
+
+ from .differentiable_backend import DifferentiableBackend
+ from .differentiable_expectation import DifferentiableExpectation
qadence/engines/torch/differentiable_backend.py ADDED
@@ -0,0 +1,85 @@
+ from __future__ import annotations
+
+ from functools import partial
+
+ from qadence.backend import Backend as QuantumBackend
+ from qadence.backend import ConvertedCircuit, ConvertedObservable
+ from qadence.engines.differentiable_backend import (
+     DifferentiableBackend as DifferentiableBackendInterface,
+ )
+ from qadence.engines.torch.differentiable_expectation import DifferentiableExpectation
+ from qadence.extensions import get_gpsr_fns
+ from qadence.measurements import Measurements
+ from qadence.mitigations import Mitigations
+ from qadence.noise import Noise
+ from qadence.types import ArrayLike, DiffMode, Endianness, Engine, ParamDictType
+
+
+ class DifferentiableBackend(DifferentiableBackendInterface):
+     """A class which wraps a QuantumBackend with the automatic differentiation engine TORCH.
+
+     Arguments:
+         backend: An instance of the QuantumBackend type to perform execution.
+         diff_mode: A differentiation mode supported by the engine.
+         **psr_args: Arguments that will be passed on to `DifferentiableExpectation`.
+     """
+
+     def __init__(
+         self,
+         backend: QuantumBackend,
+         diff_mode: DiffMode = DiffMode.AD,
+         **psr_args: int | float | None,
+     ) -> None:
+         super().__init__(backend=backend, engine=Engine.TORCH, diff_mode=diff_mode)
+         self.psr_args = psr_args
+
+     def expectation(
+         self,
+         circuit: ConvertedCircuit,
+         observable: list[ConvertedObservable] | ConvertedObservable,
+         param_values: ParamDictType = {},
+         state: ArrayLike | None = None,
+         measurement: Measurements | None = None,
+         noise: Noise | None = None,
+         mitigation: Mitigations | None = None,
+         endianness: Endianness = Endianness.BIG,
+     ) -> ArrayLike:
+         """Compute the expectation value of the `circuit` with the given `observable`.
+
+         Arguments:
+             circuit: A converted circuit as returned by `backend.circuit`.
+             observable: A converted observable as returned by `backend.observable`.
+             param_values: _**Already embedded**_ parameters of the circuit. See
+                 [`embedding`][qadence.blocks.embedding.embedding] for more info.
+             state: Initial state.
+             measurement: Optional measurement protocol. If None, use
+                 exact expectation value with a statevector simulator.
+             noise: A noise model to use.
+             mitigation: The error mitigation to use.
+             endianness: Endianness of the resulting bit strings.
+         """
+         observable = observable if isinstance(observable, list) else [observable]
+         differentiable_expectation = DifferentiableExpectation(
+             backend=self.backend,
+             circuit=circuit,
+             observable=observable,
+             param_values=param_values,
+             state=state,
+             measurement=measurement,
+             noise=noise,
+             mitigation=mitigation,
+             endianness=endianness,
+         )
+
+         if self.diff_mode == DiffMode.AD:
+             expectation = differentiable_expectation.ad
+         elif self.diff_mode == DiffMode.ADJOINT:
+             expectation = differentiable_expectation.adjoint
+         else:
+             try:
+                 fns = get_gpsr_fns()
+                 psr_fn = fns[self.diff_mode]
+             except KeyError:
+                 raise ValueError(f"{self.diff_mode} differentiation mode is not supported")
+             expectation = partial(differentiable_expectation.psr, psr_fn=psr_fn, **self.psr_args)
+         return expectation()
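A hypothetical end-to-end sketch of how this wrapper is usually reached from user code; the `QuantumModel` constructor arguments are assumed from qadence's documentation, not verified against this release:

    # Hypothetical usage sketch; constructor arguments are assumptions.
    from qadence import QuantumCircuit, QuantumModel, RX, Z
    from qadence.parameters import VariationalParameter
    from qadence.types import DiffMode

    theta = VariationalParameter("theta")
    model = QuantumModel(
        QuantumCircuit(1, RX(0, theta)),
        observable=Z(0),
        diff_mode=DiffMode.ADJOINT,  # dispatches to differentiable_expectation.adjoint
    )
    model.expectation({}).squeeze().backward()  # torch autograd uses the wrapped rule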
qadence/extensions.py CHANGED
@@ -2,24 +2,34 @@ from __future__ import annotations

  import importlib
  from string import Template
- from typing import TypeVar

  from qadence.backend import Backend
- from qadence.blocks import (
-     AbstractBlock,
- )
+ from qadence.blocks.abstract import TAbstractBlock
  from qadence.logger import get_logger
- from qadence.types import BackendName, DiffMode
-
- TAbstractBlock = TypeVar("TAbstractBlock", bound=AbstractBlock)
+ from qadence.types import BackendName, DiffMode, Engine

  backends_namespace = Template("qadence.backends.$name")

  logger = get_logger(__name__)


+ def _available_engines() -> dict:
+     """Returns a dictionary of currently installed, native qadence engines."""
+     res = {}
+     for engine in Engine.list():
+         module_path = f"qadence.engines.{engine}.differentiable_backend"
+         try:
+             module = importlib.import_module(module_path)
+             DifferentiableBackendCls = getattr(module, "DifferentiableBackend")
+             res[engine] = DifferentiableBackendCls
+         except (ImportError, ModuleNotFoundError):
+             pass
+     logger.info(f"Found engines: {res.keys()}")
+     return res
+
+
  def _available_backends() -> dict:
-     """Fallback function for native Qadence available backends if extensions is not present."""
+     """Returns a dictionary of currently installed, native qadence backends."""
      res = {}
      for backend in BackendName.list():
          module_path = f"qadence.backends.{backend}.backend"
@@ -29,11 +39,12 @@ def _available_backends() -> dict:
              res[backend] = BackendCls
          except (ImportError, ModuleNotFoundError):
              pass
+     logger.info(f"Found backends: {res.keys()}")
      return res


  def _supported_gates(name: BackendName | str) -> list[TAbstractBlock]:
-     """Fallback function for native Qadence backend supported gates if extensions is not present."""
+     """Returns a list of supported gates for the queried backend 'name'."""
      from qadence import operations

      name = str(BackendName(name).name.lower())
@@ -107,6 +118,7 @@ try:
      set_backend_config = getattr(module, "set_backend_config")
  except ModuleNotFoundError:
      available_backends = _available_backends
+     available_engines = _available_engines
      supported_gates = _supported_gates
      get_gpsr_fns = _gpsr_fns
      set_backend_config = _set_backend_config
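Both `_available_engines` and `_available_backends` follow the same optional-import pattern: attempt to import each candidate module and silently skip the ones whose dependencies are missing. A standalone sketch of that pattern:

    # Standalone sketch of the optional-import discovery used above.
    import importlib


    def discover(module_paths: list[str], attr: str) -> dict:
        """Import each candidate module and keep those that resolve."""
        found = {}
        for path in module_paths:
            try:
                found[path] = getattr(importlib.import_module(path), attr)
            except (ImportError, ModuleNotFoundError):
                pass  # optional dependency not installed; skip silently
        return found


    engines = discover(
        [
            "qadence.engines.torch.differentiable_backend",
            "qadence.engines.jax.differentiable_backend",
        ],
        "DifferentiableBackend",
    )
    print(engines.keys())  # only the engines whose dependencies are installed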
qadence/finitediff.py ADDED
@@ -0,0 +1,47 @@
+ from __future__ import annotations
+
+ from typing import Callable
+
+ import torch
+ from torch import Tensor
+
+
+ def finitediff(
+     f: Callable,
+     x: Tensor,
+     derivative_indices: tuple[int, ...],
+     eps: float | None = None,
+ ) -> Tensor:
+     """
+     Arguments:
+
+         f: Function to differentiate.
+         x: Input of shape `(batch_size, input_size)`.
+         derivative_indices: Which *inputs* to differentiate (i.e. which variables x[:, i]).
+         eps: Finite difference spacing (uses `torch.finfo(x.dtype).eps ** (1 / (2 + order))`
+             as a default).
+     """
+
+     if eps is None:
+         order = len(derivative_indices)
+         eps = torch.finfo(x.dtype).eps ** (1 / (2 + order))
+
+     # compute derivative direction vector(s)
+     eps = torch.as_tensor(eps, dtype=x.dtype)
+     _eps = 1 / eps  # type: ignore[operator]
+     ev = torch.zeros_like(x)
+     i = derivative_indices[0]
+     ev[:, i] += eps
+
+     # recursive finite differencing for orders higher than 3 and for mixed derivatives
+     if len(derivative_indices) > 3 or len(set(derivative_indices)) > 1:
+         di = derivative_indices[1:]
+         return (finitediff(f, x + ev, di) - finitediff(f, x - ev, di)) * _eps / 2
+     elif len(derivative_indices) == 3:
+         return (f(x + 2 * ev) - 2 * f(x + ev) + 2 * f(x - ev) - f(x - 2 * ev)) * _eps**3 / 2
+     elif len(derivative_indices) == 2:
+         return (f(x + ev) + f(x - ev) - 2 * f(x)) * _eps**2
+     elif len(derivative_indices) == 1:
+         return (f(x + ev) - f(x - ev)) * _eps / 2
+     else:
+         raise ValueError("If you see this error there is a bug in the `finitediff` function.")
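A quick sanity check of the central-difference rules, assuming `finitediff` is importable from the new module path shown in the file list:

    # Usage sketch: central differences of f(x) = x**3 at x = 2.
    import torch

    from qadence.finitediff import finitediff  # path taken from the file list above

    x = torch.tensor([[2.0]])
    f = lambda t: t[:, 0:1] ** 3

    print(finitediff(f, x, (0,)))     # ~12.0, matches f'(x) = 3 * x**2
    print(finitediff(f, x, (0, 0)))   # ~12.0, matches f''(x) = 6 * x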
qadence/mitigations/readout.py CHANGED
@@ -6,17 +6,50 @@ from functools import reduce
  import numpy as np
  import numpy.typing as npt
  import torch
+ from numpy.linalg import inv, matrix_rank, pinv
  from scipy.linalg import norm
  from scipy.optimize import LinearConstraint, minimize

  from qadence.mitigations.protocols import Mitigations
  from qadence.noise.protocols import Noise
+ from qadence.types import ReadOutOptimization


  def corrected_probas(p_corr: npt.NDArray, T: npt.NDArray, p_raw: npt.NDArray) -> np.double:
      return norm(T @ p_corr.T - p_raw.T, ord=2) ** 2


+ def mle_solve(p_raw: npt.NDArray) -> npt.NDArray:
+     """
+     Compute the MLE probability vector.
+
+     Algorithmic details can be found in https://arxiv.org/pdf/1106.5458.pdf, page 3.
+     """
+     # Sort p_raw by value while keeping track of indices.
+     index_sort = p_raw.argsort()
+     p_sort = p_raw[index_sort]
+     neg_sum = 0
+     breakpoint = len(p_sort) - 1
+
+     for i in range(len(p_sort)):
+         # If neg_sum cannot be distributed among the remaining probabilities,
+         # continue to accumulate it.
+         if p_sort[i] + neg_sum / (len(p_sort) - i) < 0:
+             neg_sum += p_sort[i]
+             p_sort[i] = 0
+         # Otherwise set the breakpoint to the current index.
+         else:
+             breakpoint = i
+             break
+     # Number of entries over which neg_sum can be distributed (includes the breakpoint).
+     size = len(p_sort) - breakpoint
+     p_sort[breakpoint:] += neg_sum / size
+
+     re_index_sort = index_sort.argsort()
+     p_corr = p_sort[re_index_sort]
+
+     return p_corr
+
+
  def renormalize_counts(corrected_counts: npt.NDArray, n_shots: int) -> npt.NDArray:
      """Renormalize counts rounding discrepancies."""
      total_counts = sum(corrected_counts)
@@ -25,51 +58,85 @@ def renormalize_counts(corrected_counts: npt.NDArray, n_shots: int) -> npt.NDArray:
      corrected_counts -= counts_diff
      corrected_counts = np.where(corrected_counts < 0, 0, corrected_counts)
      sum_corrected_counts = sum(corrected_counts)
-     if sum_corrected_counts < n_shots:
-         renormalization_factor = max(sum_corrected_counts, n_shots) / min(
-             sum_corrected_counts, n_shots
-         )
-     else:
-         renormalization_factor = min(sum_corrected_counts, n_shots) / max(
-             sum_corrected_counts, n_shots
-         )
+
+     renormalization_factor = n_shots / sum_corrected_counts
      corrected_counts = np.rint(corrected_counts * renormalization_factor).astype(int)
      return corrected_counts


+ def matrix_inv(K: npt.NDArray) -> npt.NDArray:
+     return inv(K) if matrix_rank(K) == K.shape[0] else pinv(K)
+
+
  def mitigation_minimization(
-     noise: Noise, mitigation: Mitigations, samples: list[Counter]
+     noise: Noise,
+     mitigation: Mitigations,
+     samples: list[Counter],
  ) -> list[Counter]:
      """Minimize a correction matrix subject to stochasticity constraints.

      See Equation (5) in https://arxiv.org/pdf/2001.09980.pdf.
+     See page 3 in https://arxiv.org/pdf/1106.5458.pdf for the MLE implementation.
+
+     Args:
+         noise: Specifies the confusion matrix and default error probability.
+         mitigation: Selects additional mitigation options based on the noise choice.
+             For readout, two optimization options are available:
+             1. constrained 2. mle. Default: mle.
+         samples: List of samples to be mitigated.
+
+     Returns:
+         Mitigated counts computed by the algorithm.
      """
      noise_matrices = noise.options.get("noise_matrix", noise.options["confusion_matrices"])
+     optimization_type = mitigation.options.get("optimization_type", ReadOutOptimization.MLE)
      n_qubits = len(list(samples[0].keys())[0])
      n_shots = sum(samples[0].values())
-     # Build the whole T matrix.
-     T_matrix = reduce(torch.kron, noise_matrices).detach().numpy()
      corrected_counters: list[Counter] = []
+
+     if optimization_type == ReadOutOptimization.CONSTRAINED:
+         # Build the whole T matrix.
+         T_matrix = reduce(torch.kron, noise_matrices).detach().numpy()
+
+     if optimization_type == ReadOutOptimization.MLE:
+         # Check if the matrix is singular and use the appropriate inverse.
+         noise_matrices_inv = list(map(matrix_inv, noise_matrices.numpy()))
+         T_inv = reduce(np.kron, noise_matrices_inv)
+
      for sample in samples:
          bitstring_length = 2**n_qubits
          # List of bitstrings in lexicographical order.
          ordered_bitstrings = [f"{i:0{n_qubits}b}" for i in range(bitstring_length)]
          # Array of raw probabilities.
          p_raw = np.array([sample[bs] for bs in ordered_bitstrings]) / n_shots
-         # Initial random guess in [0,1].
-         p_corr0 = np.random.rand(bitstring_length)
-         # Stochasticity constraints.
-         normality_constraint = LinearConstraint(
-             np.ones(bitstring_length).astype(int), lb=1.0, ub=1.0
-         )
-         positivity_constraint = LinearConstraint(
-             np.eye(bitstring_length).astype(int), lb=0.0, ub=1.0
-         )
-         constraints = [normality_constraint, positivity_constraint]
-         # Minimize the corrected probabilities.
-         res = minimize(corrected_probas, p_corr0, args=(T_matrix, p_raw), constraints=constraints)
-         # breakpoint()
-         corrected_counts = np.rint(res.x * n_shots).astype(int)
+
+         if optimization_type == ReadOutOptimization.CONSTRAINED:
+             # Initial random guess in [0,1].
+             p_corr0 = np.random.rand(bitstring_length)
+             # Stochasticity constraints.
+             normality_constraint = LinearConstraint(
+                 np.ones(bitstring_length).astype(int), lb=1.0, ub=1.0
+             )
+             positivity_constraint = LinearConstraint(
+                 np.eye(bitstring_length).astype(int), lb=0.0, ub=1.0
+             )
+             constraints = [normality_constraint, positivity_constraint]
+             # Minimize the corrected probabilities.
+             res = minimize(
+                 corrected_probas, p_corr0, args=(T_matrix, p_raw), constraints=constraints
+             )
+             p_corr = res.x
+
+         elif optimization_type == ReadOutOptimization.MLE:
+             # Compute the corrected probabilities via matrix inversion and run MLE.
+             p_corr = mle_solve(T_inv @ p_raw)
+         else:
+             raise NotImplementedError(
+                 f"Requested method {optimization_type} does not match supported protocols."
+             )
+
+         corrected_counts = np.rint(p_corr * n_shots).astype(int)
+
          # Renormalize if total counts differs from n_shots.
          corrected_counts = renormalize_counts(corrected_counts=corrected_counts, n_shots=n_shots)
          # At this point, the count should be off by at most 2, added or subtracted to/from the
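A worked example of the MLE branch, assuming `mle_solve` above is in scope: inverting the confusion matrix can yield quasi-probabilities with negative entries, which `mle_solve` clips and redistributes (arXiv:1106.5458):

    # Worked example of mle_solve on a quasi-probability vector.
    import numpy as np

    p_raw = np.array([-0.1, 0.2, 0.9])  # e.g. the output of T_inv @ p_raw
    p_corr = mle_solve(p_raw)
    print(p_corr)        # [0.   0.15 0.85]: -0.1 is clipped, the deficit spread over the rest
    print(p_corr.sum())  # 1.0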
qadence/ml_tools/models.py CHANGED
@@ -178,9 +178,16 @@ class TransformedModule(torch.nn.Module):
          if isinstance(self.model, (QuantumModel, QNN)):
              if not isinstance(x, dict):
                  x = self._format_to_dict(x)
-             return {
-                 key: self._input_scaling * (val + self._input_shifting) for key, val in x.items()
-             }
+             if self.in_features == 1:
+                 return {
+                     key: self._input_scaling * (val + self._input_shifting)
+                     for key, val in x.items()
+                 }
+             else:
+                 return {
+                     key: self._input_scaling[idx] * (val + self._input_shifting[idx])
+                     for idx, (key, val) in enumerate(x.items())
+                 }

          else:
              assert isinstance(self.model, torch.nn.Module) and isinstance(x, Tensor)
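The new branch applies one scale/shift pair per input feature instead of broadcasting a scalar; a toy illustration of the indexing (shapes assumed), which relies on the dict preserving feature order just as the code above does:

    # Toy illustration of per-feature scaling/shifting (assumed shapes).
    import torch

    scaling = torch.tensor([2.0, 0.5])    # one factor per input feature
    shifting = torch.tensor([0.1, -0.1])  # one offset per input feature
    x = {"phi": torch.tensor([1.0]), "psi": torch.tensor([4.0])}

    out = {k: scaling[i] * (v + shifting[i]) for i, (k, v) in enumerate(x.items())}
    print(out)  # {'phi': tensor([2.2000]), 'psi': tensor([1.9500])}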