qoro-divi 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,263 @@
1
+ # SPDX-FileCopyrightText: 2025-2026 Qoro Quantum Ltd <divi@qoroquantum.de>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from typing import Any
6
+ from warnings import warn
7
+
8
+ import numpy as np
9
+ import pennylane as qml
10
+ import sympy as sp
11
+ from qiskit import QuantumCircuit
12
+
13
+ from divi.circuits import CircuitBundle, MetaCircuit
14
+ from divi.qprog._hamiltonians import _clean_hamiltonian
15
+ from divi.qprog.variational_quantum_algorithm import VariationalQuantumAlgorithm
16
+
17
+
18
class CustomVQA(VariationalQuantumAlgorithm):
    """Custom variational algorithm for a parameterized QuantumScript.

    This implementation wraps a PennyLane QuantumScript (or converts a Qiskit
    QuantumCircuit into one) and optimizes its trainable parameters to minimize
    a single expectation-value measurement. Qiskit measurements are converted
    into the observable sum_i Z_i over the measured wires (all wires when the
    circuit has no measurements). Parameters are bound to sympy symbols to
    enable QASM substitution and reuse of MetaCircuit templates during
    optimization.

    Attributes:
        qscript (qml.tape.QuantumScript): The parameterized QuantumScript.
        param_shape (tuple[int, ...]): Shape of a single parameter set.
        n_qubits (int): Number of qubits in the script.
        n_layers (int): Layer count (fixed to 1 for this wrapper).
        cost_hamiltonian (qml.operation.Operator): Observable being minimized.
        loss_constant (float): Constant term extracted from the observable.
        optimizer (Optimizer): Classical optimizer for parameter updates.
        max_iterations (int): Maximum number of optimization iterations.
        current_iteration (int): Current optimization iteration.
    """

    def __init__(
        self,
        qscript: qml.tape.QuantumScript | QuantumCircuit,
        *,
        param_shape: tuple[int, ...] | int | None = None,
        max_iterations: int = 10,
        **kwargs,
    ) -> None:
        """Initialize a CustomVQA instance.

        Args:
            qscript (qml.tape.QuantumScript | QuantumCircuit): A parameterized QuantumScript with a
                single expectation-value measurement, or a Qiskit QuantumCircuit with
                computational basis measurements.
            param_shape (tuple[int, ...] | int | None): Shape of a single parameter
                set. If None, uses a flat shape inferred from trainable parameters.
            max_iterations (int): Maximum number of optimization iterations.
            **kwargs: Additional keyword arguments passed to the parent class, including
                backend and optimizer.

        Raises:
            TypeError: If qscript is not a supported PennyLane QuantumScript or Qiskit QuantumCircuit.
            ValueError: If the script has an invalid measurement, no trainable
                parameters, a ``param_shape`` inconsistent with the trainable
                parameters, or (for Qiskit input) a parameter count that differs
                from the converted script's trainable parameter count.
        """
        super().__init__(**kwargs)

        # Capture the Qiskit parameter names before conversion; they are used
        # as the names of the sympy symbols created below.
        self._qiskit_param_names = (
            [param.name for param in qscript.parameters]
            if isinstance(qscript, QuantumCircuit)
            else None
        )
        self.qscript = self._coerce_to_quantum_script(qscript)

        if len(self.qscript.measurements) != 1:
            raise ValueError(
                "QuantumScript must contain exactly one measurement for optimization."
            )

        measurement = self.qscript.measurements[0]
        if not hasattr(measurement, "obs") or measurement.obs is None:
            raise ValueError(
                "QuantumScript must contain a single expectation-value measurement."
            )

        # Split the observable into its operator part and a constant offset.
        self._cost_hamiltonian, self.loss_constant = _clean_hamiltonian(measurement.obs)
        if (
            isinstance(self._cost_hamiltonian, qml.Hamiltonian)
            and not self._cost_hamiltonian.operands
        ):
            raise ValueError("Hamiltonian contains only constant terms.")

        self.n_qubits = self.qscript.num_wires
        self.n_layers = 1
        self.max_iterations = max_iterations
        self.current_iteration = 0

        # When the script reports no trainable subset, treat every parameter as
        # trainable. `get_parameters()` defaults to trainable-only, which would
        # be empty here too, so request all parameters explicitly; these are
        # also the absolute indices `bind_new_parameters` expects.
        trainable_param_indices = (
            list(self.qscript.trainable_params)
            if self.qscript.trainable_params
            else list(
                range(len(self.qscript.get_parameters(trainable_only=False)))
            )
        )
        if not trainable_param_indices:
            raise ValueError("QuantumScript does not contain any trainable parameters.")

        self._param_shape = self._resolve_param_shape(
            param_shape, len(trainable_param_indices)
        )
        self._n_params = int(np.prod(self._param_shape))

        # Qiskit parameter names become the symbol names, so there must be one
        # per script parameter being replaced; otherwise the reshape below would
        # fail with an opaque numpy error.
        if (
            self._qiskit_param_names is not None
            and len(self._qiskit_param_names) != self._n_params
        ):
            raise ValueError(
                f"The Qiskit circuit has {len(self._qiskit_param_names)} parameter(s) "
                f"but the converted QuantumScript has {self._n_params} trainable "
                "parameter(s). This can happen when the circuit contains "
                "fixed-angle gates or reuses a parameter in several gates."
            )

        self._trainable_param_indices = trainable_param_indices
        # NOTE(review): Qiskit orders `parameters` by name, while the script's
        # parameters follow gate order, so a symbol's name may not match the
        # gate it ends up bound to — verify if names matter downstream.
        self._param_symbols = (
            np.array(
                [sp.Symbol(name) for name in self._qiskit_param_names], dtype=object
            ).reshape(self._param_shape)
            if self._qiskit_param_names is not None
            else sp.symarray("p", self._param_shape)
        )

        # Symbol-bound copy used as the MetaCircuit template; `self.qscript`
        # keeps the original parameter values.
        flat_symbols = self._param_symbols.flatten().tolist()
        self._qscript = self.qscript.bind_new_parameters(
            flat_symbols, self._trainable_param_indices
        )

    @property
    def cost_hamiltonian(self) -> qml.operation.Operator:
        """The cost Hamiltonian for the QuantumScript optimization."""
        return self._cost_hamiltonian

    @property
    def param_shape(self) -> tuple[int, ...]:
        """Shape of a single parameter set."""
        return self._param_shape

    def _resolve_param_shape(
        self, param_shape: tuple[int, ...] | int | None, n_params: int
    ) -> tuple[int, ...]:
        """Validate and normalize the parameter shape.

        Args:
            param_shape (tuple[int, ...] | int | None): User-provided parameter shape.
                A bare (Python or numpy) integer is treated as a 1-D shape.
            n_params (int): Number of trainable parameters in the script.

        Returns:
            tuple[int, ...]: Normalized parameter shape.

        Raises:
            ValueError: If the shape is invalid or does not match n_params.
        """
        if param_shape is None:
            return (n_params,)

        if isinstance(param_shape, (int, np.integer)):
            param_shape = (int(param_shape),)

        if any(dim <= 0 for dim in param_shape):
            raise ValueError(
                f"param_shape entries must be positive, got {param_shape}."
            )

        if int(np.prod(param_shape)) != n_params:
            raise ValueError(
                f"param_shape does not match the number of trainable parameters. "
                f"Expected product {n_params}, got {int(np.prod(param_shape))}."
            )

        return tuple(param_shape)

    def _coerce_to_quantum_script(
        self,
        qscript: qml.tape.QuantumScript | QuantumCircuit,
    ) -> qml.tape.QuantumScript:
        """Convert supported inputs into a PennyLane QuantumScript.

        A QuantumScript is returned unchanged. A Qiskit QuantumCircuit has its
        ``measure`` instructions stripped, is converted with
        ``qml.from_qiskit``, and gets a single ``expval`` of the sum of PauliZ
        on the measured wires (all wires, with a warning, if none are measured).

        Args:
            qscript (qml.tape.QuantumScript | QuantumCircuit): Input QuantumScript
                or Qiskit QuantumCircuit.

        Returns:
            qml.tape.QuantumScript: The converted QuantumScript.

        Raises:
            TypeError: If the input type is unsupported.
        """
        if isinstance(qscript, qml.tape.QuantumScript):
            return qscript

        if isinstance(qscript, QuantumCircuit):
            # Map Qiskit bit objects to integer wire indices once. The rebuilt
            # circuit below has fresh qubits, so instructions must be re-targeted
            # by index rather than by the original (register-bound) bit objects.
            qubit_index = {qubit: idx for idx, qubit in enumerate(qscript.qubits)}

            measured_wires = sorted(
                {
                    qubit_index[qubit]
                    for instruction in qscript.data
                    if instruction.operation.name == "measure"
                    for qubit in instruction.qubits
                }
            )
            if not measured_wires:
                warn(
                    "Provided QuantumCircuit has no measurement operations. "
                    "Defaulting to measuring all wires with PauliZ.",
                    UserWarning,
                )
                measured_wires = list(range(len(qscript.qubits)))

            obs = (
                qml.Z(measured_wires[0])
                if len(measured_wires) == 1
                else qml.sum(*(qml.Z(wire) for wire in measured_wires))
            )
            # Remove measurements before conversion to avoid MidMeasureMP issues
            qc_no_measure = QuantumCircuit(qscript.num_qubits)
            for instruction in qscript.data:
                if instruction.operation.name != "measure":
                    qc_no_measure.append(
                        instruction.operation,
                        [qubit_index[qubit] for qubit in instruction.qubits],
                        instruction.clbits,
                    )
            qfunc = qml.from_qiskit(qc_no_measure)
            # Placeholder values; the real parameters are replaced by sympy
            # symbols in __init__.
            qiskit_params = [
                qml.numpy.array(0.0, requires_grad=True) for _ in qscript.parameters
            ]

            def qfunc_with_measurement(*params):
                qfunc(*params)
                return qml.expval(obs)

            return qml.tape.make_qscript(qfunc_with_measurement)(*qiskit_params)

        raise TypeError(
            "qscript must be a PennyLane QuantumScript or a Qiskit QuantumCircuit."
        )

    def _create_meta_circuits_dict(self) -> dict[str, MetaCircuit]:
        """Create the meta-circuit dictionary for CustomVQA.

        Returns:
            dict[str, MetaCircuit]: Dictionary containing the cost circuit template,
            built from the symbol-bound script and its flattened symbols.
        """
        return {
            "cost_circuit": self._meta_circuit_factory(
                self._qscript, symbols=self._param_symbols.flatten()
            )
        }

    def _generate_circuits(self) -> list[CircuitBundle]:
        """Generate circuits for the current parameter sets.

        Returns:
            list[CircuitBundle]: One bundle per row of ``self._curr_params``,
            tagged with that row's index as ``param_idx``.
        """
        return [
            self.meta_circuits["cost_circuit"].initialize_circuit_from_params(
                params_group, param_idx=p
            )
            for p, params_group in enumerate(self._curr_params)
        ]

    def _perform_final_computation(self, **kwargs) -> None:
        """No-op by default for custom QuantumScript optimization."""
        pass

    def _save_subclass_state(self) -> dict[str, Any]:
        """Save subclass-specific state for checkpointing (none for CustomVQA)."""
        return {}

    def _load_subclass_state(self, state: dict[str, Any]) -> None:
        """Load subclass-specific state from a checkpoint (nothing to restore)."""
        pass
@@ -0,0 +1,262 @@
1
+ # SPDX-FileCopyrightText: 2026 Qoro Quantum Ltd <divi@qoroquantum.de>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from warnings import warn
6
+
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ import pennylane as qml
10
+ import sympy as sp
11
+
12
+ from divi.circuits import MetaCircuit
13
+ from divi.qprog.typing import QUBOProblemTypes, qubo_to_matrix
14
+
15
+ from ._vqe import VQE
16
+
17
# Pre-computed 8-bit popcount table for O(1) lookups
_POPCOUNT_TABLE_8BIT = np.array([bin(i).count("1") for i in range(256)], dtype=np.uint8)


def _fast_popcount_parity(
    arr_input: npt.ArrayLike,
) -> npt.NDArray[np.unsignedinteger]:
    """
    Vectorized calculation of (popcount % 2) for an array of integers.

    Each value is reinterpreted as 8 bytes of a uint64, the bytes are looked up
    in an 8-bit popcount table, and the per-value totals are reduced mod 2.

    Args:
        arr_input: Integer array (or array-like) of any shape, including 0-d and
            non-contiguous views.

    Returns:
        Array of 0/1 parities with the same shape as ``arr_input``. Its dtype is
        the unsigned integer that numpy's ``sum`` promotes uint8 to.
    """
    shape = np.shape(arr_input)
    # A C-contiguous uint64 copy is required: viewing as uint8 needs a
    # contiguous last axis, which transposed/strided inputs do not have.
    # (0-d inputs become shape (1,), which the reshape below undoes.)
    arr_u64 = np.ascontiguousarray(arr_input, dtype=np.uint64)

    # View as bytes to use the 8-bit lookup table; byte order is irrelevant
    # for a popcount.
    arr_bytes = arr_u64.view(np.uint8).reshape(shape + (8,))

    # Lookup and sum bits per value.
    total_bits = _POPCOUNT_TABLE_8BIT[arr_bytes].sum(axis=-1)

    # Parity (0 or 1)
    return total_bits % 2
37
+
38
+
39
def _aggregate_param_group(
    param_group: list[tuple[str, dict[str, int]]],
    merge_counts_fn,
) -> tuple[list[str], npt.NDArray[np.float64], float]:
    """Merge one parameter set's shot histograms into flat arrays.

    Args:
        param_group: ``(tag, counts)`` pairs handed unchanged to ``merge_counts_fn``.
        merge_counts_fn: Callable that merges ``param_group`` into one mapping
            from bitstring to shot count.

    Returns:
        The merged bitstrings, their counts as a float array in the same order,
        and the total number of shots as a Python float.
    """
    merged = merge_counts_fn(param_group)
    bitstrings = [*merged]
    shot_counts = np.fromiter(merged.values(), dtype=float, count=len(merged))
    return bitstrings, shot_counts, float(shot_counts.sum())
49
+
50
+
51
def _decode_parities(
    state_strings: list[str], variable_masks_u64: npt.NDArray[np.uint64]
) -> npt.NDArray[np.uint8]:
    """Return the parity of every (variable mask, measured state) pair.

    Each bitstring is parsed as a base-2 integer. Row ``i`` of the result holds
    the parity of ``variable_masks_u64[i] & state`` for every state, with
    columns in the order of ``state_strings``.
    """
    state_ints = np.fromiter(
        (int(bits, 2) for bits in state_strings),
        dtype=np.uint64,
        count=len(state_strings),
    )
    masked_states = np.bitwise_and(variable_masks_u64[:, None], state_ints[None, :])
    return _fast_popcount_parity(masked_states)
58
+
59
+
60
+ def _compute_soft_energy(
61
+ parities: npt.NDArray[np.uint8],
62
+ probs: npt.NDArray[np.float64],
63
+ alpha: float,
64
+ qubo_matrix: npt.NDArray[np.float64] | np.ndarray,
65
+ ) -> float:
66
+ """Compute the relaxed (soft) QUBO energy from parity expectations."""
67
+ mean_parities = parities.dot(probs)
68
+ z_expectations = 1.0 - (2.0 * mean_parities)
69
+ x_soft = 0.5 * (1.0 + np.tanh(alpha * z_expectations))
70
+ Qx = qubo_matrix @ x_soft
71
+ return float(np.dot(x_soft, Qx))
72
+
73
+
74
+ def _compute_hard_cvar_energy(
75
+ parities: npt.NDArray[np.uint8],
76
+ counts: npt.NDArray[np.float64],
77
+ total_shots: float,
78
+ qubo_matrix: npt.NDArray[np.float64] | np.ndarray,
79
+ alpha_cvar: float = 0.25,
80
+ ) -> float:
81
+ """Compute CVaR energy from sampled hard assignments."""
82
+ x_vals = 1.0 - parities.astype(float)
83
+ Qx = qubo_matrix @ x_vals
84
+ energies = np.einsum("ij,ij->j", x_vals, Qx)
85
+
86
+ sorted_indices = np.argsort(energies)
87
+ sorted_energies = energies[sorted_indices]
88
+ sorted_counts = counts[sorted_indices]
89
+
90
+ cutoff_count = int(np.ceil(alpha_cvar * total_shots))
91
+ accumulated_counts = np.cumsum(sorted_counts)
92
+ limit_idx = np.searchsorted(accumulated_counts, cutoff_count)
93
+
94
+ cvar_energy = 0.0
95
+ count_sum = 0
96
+ if limit_idx > 0:
97
+ cvar_energy += np.sum(sorted_energies[:limit_idx] * sorted_counts[:limit_idx])
98
+ count_sum += np.sum(sorted_counts[:limit_idx])
99
+
100
+ remaining = cutoff_count - count_sum
101
+ cvar_energy += sorted_energies[limit_idx] * remaining
102
+ return float(cvar_energy / cutoff_count)
103
+
104
+
105
class PCE(VQE):
    """
    Generalized Pauli Correlation Encoding (PCE) VQE.

    Encodes an N-variable QUBO into O(log2(N)) qubits (by default the minimum,
    ceil(log2(N + 1))) by mapping each variable to a parity (Pauli-Z
    correlation) of the measured bitstring. Variable ``i`` (0-based) is read
    from the bits selected by the mask ``i + 1``, and its binary value is
    ``x_i = 1 - parity``. The algorithm uses the measurement distribution to
    estimate these parities, applies a smooth relaxation when ``alpha < 5.0``,
    and evaluates the classical QUBO objective ``E = x.T @ Q @ x``. For
    ``alpha >= 5.0`` it switches to a discrete objective (CVaR over sampled
    energies) for harder convergence.
    """

    def __init__(
        self,
        qubo_matrix: QUBOProblemTypes,
        n_qubits: int | None = None,
        alpha: float = 2.0,
        **kwargs,
    ):
        """
        Args:
            qubo_matrix (QUBOProblemTypes): The N x N matrix to minimize. Accepts
                a dense array, sparse matrix, list, or BinaryQuadraticModel.
            n_qubits (int | None): Optional override. Must be >= ceil(log2(N + 1)),
                because the N variables use the nonzero masks 1..N. Larger values
                increase circuit size without adding representational power.
            alpha (float): Scaling factor for the tanh() activation. Higher = harder
                binary constraints, Lower = smoother gradient. Values >= 5.0 select
                the hard CVaR objective instead of the relaxed one.
            **kwargs: Forwarded to ``VQE``. ``qem_protocol`` is not supported.

        Raises:
            ValueError: If ``qem_protocol`` is given or ``n_qubits`` is too small.
        """

        self.qubo_matrix = qubo_to_matrix(qubo_matrix)
        self.n_vars = self.qubo_matrix.shape[0]
        self.alpha = alpha
        # The objective mode is fixed here; changing `alpha` later does not
        # switch between the soft and CVaR losses.
        self._use_soft_objective = self.alpha < 5.0
        self._final_vector: npt.NDArray[np.integer] | None = None

        if kwargs.get("qem_protocol") is not None:
            raise ValueError("PCE does not currently support qem_protocol.")

        # Calculate required qubits (Logarithmic Scaling): the smallest n with
        # 2**n - 1 >= N, i.e. ceil(log2(N + 1)). bit_length is exact and avoids
        # floating-point log2 rounding.
        min_qubits = int(self.n_vars).bit_length()
        if n_qubits is not None and n_qubits < min_qubits:
            raise ValueError(
                "n_qubits must be >= ceil(log2(N + 1)) to represent all variables. "
                f"Got n_qubits={n_qubits}, minimum={min_qubits}."
            )
        if n_qubits is not None and n_qubits > min_qubits:
            warn(
                "n_qubits exceeds the minimum required; extra qubits increase circuit "
                "size and can add noise without representing more variables.",
                UserWarning,
            )
        self.n_qubits = n_qubits if n_qubits is not None else min_qubits

        # Pre-compute U64 masks for the fast broadcasting step later; mask 0 is
        # skipped because it would give every variable a constant parity.
        self._variable_masks_u64 = np.arange(1, self.n_vars + 1, dtype=np.uint64)

        # Placeholder Hamiltonian required by VQE; we care about the measurement
        # probability distribution, and Z-basis measurements provide it.
        placeholder_hamiltonian = qml.Hamiltonian(
            [1.0] * self.n_qubits, [qml.PauliZ(i) for i in range(self.n_qubits)]
        )
        super().__init__(hamiltonian=placeholder_hamiltonian, **kwargs)

    def _create_meta_circuits_dict(self) -> dict[str, MetaCircuit]:
        """Create the cost and measurement meta circuits from the ansatz.

        The weight symbols form an ``(n_layers, n_params)`` array. When the
        ansatz has zero parameters per layer this array is simply empty, and
        both circuits are built with no symbols. ``meas_circuit`` measures
        ``qml.probs()`` with wire-based grouping.
        """
        n_params = self.ansatz.n_params_per_layer(
            self.n_qubits, n_electrons=self.n_electrons
        )

        weights_syms = sp.symarray("w", (self.n_layers, n_params))

        ops = self.ansatz.build(
            weights_syms,
            n_qubits=self.n_qubits,
            n_layers=self.n_layers,
            n_electrons=self.n_electrons,
        )

        return {
            "cost_circuit": self._meta_circuit_factory(
                qml.tape.QuantumScript(
                    ops=ops, measurements=[qml.expval(self._cost_hamiltonian)]
                ),
                symbols=weights_syms.flatten(),
            ),
            "meas_circuit": self._meta_circuit_factory(
                qml.tape.QuantumScript(ops=ops, measurements=[qml.probs()]),
                symbols=weights_syms.flatten(),
                grouping_strategy="wires",
            ),
        }

    def _post_process_results(
        self, results: dict[str, dict[str, int]]
    ) -> dict[int, float]:
        """
        Calculate the loss for every parameter set.

        If ``_use_soft_objective`` (alpha < 5.0 at construction), computes the
        'Soft Energy' (Relaxed VQE) for smooth gradients; otherwise computes the
        'Hard CVaR Energy' for final convergence. When
        ``_is_compute_probabilities`` is set, defers to the parent implementation
        instead.

        Returns:
            dict[int, float]: Loss keyed by parameter-set index.
        """

        # Return raw probabilities if requested (skip processing)
        if getattr(self, "_is_compute_probabilities", False):
            return super()._post_process_results(results)

        losses = {}

        for p_idx, qem_groups in self._group_results(results).items():
            # PCE ignores QEM ids; aggregate all shots for this parameter set.
            param_group = [
                ("0", shots)
                for shots_list in qem_groups.values()
                for shots in shots_list
            ]

            state_strings, counts, total_shots = _aggregate_param_group(
                param_group, self._merge_param_group_counts
            )

            parities = _decode_parities(state_strings, self._variable_masks_u64)
            if self._use_soft_objective:
                probs = counts / total_shots
                losses[p_idx] = _compute_soft_energy(
                    parities, probs, self.alpha, self.qubo_matrix
                )
            else:
                losses[p_idx] = _compute_hard_cvar_energy(
                    parities, counts, total_shots, self.qubo_matrix
                )

        return losses

    def _perform_final_computation(self, **kwargs) -> None:
        """Compute the final eigenstate and decode it into a PCE vector.

        Sets ``_final_vector`` to ``None`` when the parent computation produced
        no eigenstate.
        """
        super()._perform_final_computation(**kwargs)

        if self._eigenstate is None:
            self._final_vector = None
            return

        # NOTE(review): assumes `_eigenstate` uses the same bit order as the
        # measured bitstrings decoded in `_post_process_results`; confirm in VQE.
        best_bitstring = "".join(str(x) for x in self._eigenstate)
        parities = _decode_parities(
            [best_bitstring], self._variable_masks_u64
        ).flatten()
        # Even parity decodes to x_i = 1, matching x = 1 - parity in the losses.
        self._final_vector = 1 - parities

    @property
    def solution(self) -> npt.NDArray[np.integer]:
        """
        Returns the final optimized vector (hard binary 0/1) based on the best parameters found.
        You must run .run() before calling this.

        Raises:
            RuntimeError: If no final vector has been computed.
        """
        if self._final_vector is None:
            raise RuntimeError("Run the VQE optimization first.")

        return self._final_vector
@@ -1,8 +1,9 @@
1
- # SPDX-FileCopyrightText: 2025 Qoro Quantum Ltd <divi@qoroquantum.de>
1
+ # SPDX-FileCopyrightText: 2025-2026 Qoro Quantum Ltd <divi@qoroquantum.de>
2
2
  #
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
5
  import logging
6
+ from collections.abc import Callable
6
7
  from enum import Enum
7
8
  from typing import Any, Literal, get_args
8
9
  from warnings import warn
@@ -13,8 +14,6 @@ import networkx as nx
13
14
  import numpy as np
14
15
  import pennylane as qml
15
16
  import pennylane.qaoa as pqaoa
16
- import rustworkx as rx
17
- import scipy.sparse as sps
18
17
  import sympy as sp
19
18
 
20
19
  from divi.circuits import CircuitBundle, MetaCircuit
@@ -22,13 +21,11 @@ from divi.qprog._hamiltonians import (
22
21
  _clean_hamiltonian,
23
22
  convert_qubo_matrix_to_pennylane_ising,
24
23
  )
24
+ from divi.qprog.typing import GraphProblemTypes, QUBOProblemTypes, qubo_to_matrix
25
25
  from divi.qprog.variational_quantum_algorithm import VariationalQuantumAlgorithm
26
26
 
27
27
  logger = logging.getLogger(__name__)
28
28
 
29
- GraphProblemTypes = nx.Graph | rx.PyGraph
30
- QUBOProblemTypes = list | np.ndarray | sps.spmatrix | dimod.BinaryQuadraticModel
31
-
32
29
 
33
30
  def _extract_loss_constant(
34
31
  problem_metadata: dict, constant_from_hamiltonian: float
@@ -163,17 +160,7 @@ def _resolve_circuit_layers(
163
160
 
164
161
  return *getattr(pqaoa, graph_problem.pl_string)(*params), resolved_initial_state
165
162
  else:
166
- # Convert BinaryQuadraticModel to matrix if needed
167
- if isinstance(problem, dimod.BinaryQuadraticModel):
168
- # Manual conversion from BQM to matrix (replacing deprecated to_numpy_matrix)
169
- variables = list(problem.variables)
170
- var_to_idx = {v: i for i, v in enumerate(variables)}
171
- qubo_matrix = np.diag([problem.linear.get(v, 0) for v in variables])
172
- for (u, v), coeff in problem.quadratic.items():
173
- i, j = var_to_idx[u], var_to_idx[v]
174
- qubo_matrix[i, j] = qubo_matrix[j, i] = coeff
175
- else:
176
- qubo_matrix = problem
163
+ qubo_matrix = qubo_to_matrix(problem)
177
164
 
178
165
  cost_hamiltonian, constant = convert_qubo_matrix_to_pennylane_ising(qubo_matrix)
179
166
 
@@ -225,6 +212,7 @@ class QAOA(VariationalQuantumAlgorithm):
225
212
  n_layers: int = 1,
226
213
  initial_state: _SUPPORTED_INITIAL_STATES_LITERAL = "Recommended",
227
214
  max_iterations: int = 10,
215
+ decode_solution_fn: Callable[[str], Any] | None = None,
228
216
  **kwargs,
229
217
  ):
230
218
  """Initialize the QAOA problem.
@@ -237,15 +225,21 @@ class QAOA(VariationalQuantumAlgorithm):
237
225
  n_layers (int): Number of QAOA layers. Defaults to 1.
238
226
  initial_state (_SUPPORTED_INITIAL_STATES_LITERAL): The initial state of the circuit. Defaults to "Recommended".
239
227
  max_iterations (int): Maximum number of optimization iterations. Defaults to 10.
228
+ decode_solution_fn (callable[[str], Any] | None): Optional decoder for bitstrings.
229
+ If not provided, a default decoder is selected based on problem type.
240
230
  **kwargs: Additional keyword arguments passed to the parent class, including `optimizer`.
241
231
  """
242
- super().__init__(**kwargs)
243
-
244
232
  self.graph_problem = graph_problem
245
233
 
246
- # Validate and process problem
234
+ # Validate and process problem (needed to determine decode function)
235
+ # This sets n_qubits which is needed before parent init
247
236
  self.problem = self._validate_and_set_problem(problem, graph_problem)
248
237
 
238
+ if decode_solution_fn is not None:
239
+ kwargs["decode_solution_fn"] = decode_solution_fn
240
+
241
+ super().__init__(**kwargs)
242
+
249
243
  # Validate initial state
250
244
  if initial_state not in get_args(_SUPPORTED_INITIAL_STATES_LITERAL):
251
245
  raise ValueError(
@@ -286,6 +280,22 @@ class QAOA(VariationalQuantumAlgorithm):
286
280
  # Extract wire labels from the cost Hamiltonian to ensure consistency
287
281
  self._circuit_wires = tuple(self._cost_hamiltonian.wires)
288
282
 
283
+ # Set up decode function based on problem type if user didn't provide one
284
+ if decode_solution_fn is None:
285
+ if isinstance(self.problem, QUBOProblemTypes):
286
+ # For QUBO: convert bitstring to numpy array of int32
287
+ self._decode_solution_fn = lambda bitstring: np.fromiter(
288
+ bitstring, dtype=np.int32
289
+ )
290
+ elif isinstance(self.problem, GraphProblemTypes):
291
+ # For Graph: map bitstring positions to graph node labels
292
+ circuit_wires = self._circuit_wires # Capture for closure
293
+ self._decode_solution_fn = lambda bitstring: [
294
+ circuit_wires[idx]
295
+ for idx, bit in enumerate(bitstring)
296
+ if bit == "1" and idx < len(circuit_wires)
297
+ ]
298
+
289
299
  def _save_subclass_state(self) -> dict[str, Any]:
290
300
  """Save QAOA-specific runtime state."""
291
301
  return {
@@ -497,7 +507,7 @@ class QAOA(VariationalQuantumAlgorithm):
497
507
 
498
508
  return [
499
509
  self.meta_circuits[circuit_type].initialize_circuit_from_params(
500
- params_group, tag_prefix=f"{p}"
510
+ params_group, param_idx=p
501
511
  )
502
512
  for p, params_group in enumerate(self._curr_params)
503
513
  ]
@@ -508,9 +518,11 @@ class QAOA(VariationalQuantumAlgorithm):
508
518
  This method performs the following steps:
509
519
  1. Executes measurement circuits with the best parameters (those that achieved the lowest loss).
510
520
  2. Retrieves the bitstring representing the best solution, correcting for endianness.
511
- 3. Depending on the problem type:
512
- - For QUBO problems, stores the solution as a NumPy array of bits.
513
- - For graph problems, stores the solution as a list of node indices corresponding to '1's in the bitstring.
521
+ 3. Uses the `decode_solution_fn` (configured in constructor based on problem type) to decode
522
+ the bitstring into the appropriate format:
523
+ - For QUBO problems: NumPy array of bits (int32).
524
+ - For graph problems: List of node indices corresponding to '1's in the bitstring.
525
+ 4. Stores the decoded solution in the appropriate attribute.
514
526
 
515
527
  Returns:
516
528
  tuple[int, float]: A tuple containing:
@@ -529,19 +541,14 @@ class QAOA(VariationalQuantumAlgorithm):
529
541
  best_measurement_probs, key=best_measurement_probs.get
530
542
  )
531
543
 
532
- if isinstance(self.problem, QUBOProblemTypes):
533
- self._solution_bitstring[:] = np.fromiter(
534
- best_solution_bitstring, dtype=np.int32
535
- )
544
+ # Use decode function to get the decoded solution
545
+ decoded_solution = self._decode_solution_fn(best_solution_bitstring)
536
546
 
537
- if isinstance(self.problem, GraphProblemTypes):
538
- # Map bitstring positions to actual graph node labels
539
- # Bitstring is already endianness-corrected, so positions map directly to circuit_wires
540
- self._solution_nodes[:] = [
541
- self._circuit_wires[idx]
542
- for idx, bit in enumerate(best_solution_bitstring)
543
- if bit == "1" and idx < len(self._circuit_wires)
544
- ]
547
+ # Store in appropriate attribute based on problem type
548
+ if isinstance(self.problem, QUBOProblemTypes):
549
+ self._solution_bitstring[:] = decoded_solution
550
+ elif isinstance(self.problem, GraphProblemTypes):
551
+ self._solution_nodes[:] = decoded_solution
545
552
 
546
553
  self.reporter.info(message="🏁 Computed Final Solution! 🏁")
547
554