qoro-divi 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of qoro-divi has been flagged as potentially problematic; see the registry's page for this release for more details.

@@ -4,8 +4,7 @@
4
4
 
5
5
  import heapq
6
6
  import string
7
- from collections.abc import Callable, Sequence, Set
8
- from concurrent.futures import ProcessPoolExecutor
7
+ from collections.abc import Callable, Sequence
9
8
  from dataclasses import dataclass
10
9
  from functools import partial
11
10
  from typing import Literal
@@ -21,7 +20,7 @@ from pymetis import part_graph
21
20
  from sklearn.cluster import SpectralClustering
22
21
 
23
22
  from divi.backends import CircuitRunner
24
- from divi.qprog import QAOA, ProgramBatch, QuantumProgram
23
+ from divi.qprog import QAOA, ProgramBatch
25
24
  from divi.qprog.algorithms._qaoa import (
26
25
  _SUPPORTED_INITIAL_STATES_LITERAL,
27
26
  GraphProblem,
@@ -41,6 +40,43 @@ _MAXIMUM_AVAILABLE_QUBITS = 30
41
40
 
42
41
  @dataclass(frozen=True, eq=True)
43
42
  class PartitioningConfig:
43
+ """Configuration for graph partitioning algorithms.
44
+
45
+ This class defines the parameters and constraints for partitioning large graphs
46
+ into smaller subgraphs for quantum algorithm execution. It supports multiple
47
+ partitioning algorithms and allows specification of size constraints.
48
+
49
+ Attributes:
50
+ max_n_nodes_per_cluster: Maximum number of nodes allowed in each cluster.
51
+ If None, no upper limit is enforced. Must be a positive integer.
52
+ minimum_n_clusters: Minimum number of clusters to create. If None, no
53
+ lower limit is enforced. Must be a positive integer.
54
+ partitioning_algorithm: Algorithm to use for partitioning. Options are:
55
+ - "spectral": Spectral partitioning using Fiedler vector (default)
56
+ - "metis": METIS graph partitioning library
57
+ - "kernighan_lin": Kernighan-Lin algorithm
58
+
59
+ Note:
60
+ At least one of `max_n_nodes_per_cluster` or `minimum_n_clusters` must be
61
+ specified. Both constraints cannot be None.
62
+
63
+ Examples:
64
+ >>> # Partition into clusters of at most 10 nodes
65
+ >>> config = PartitioningConfig(max_n_nodes_per_cluster=10)
66
+
67
+ >>> # Create at least 5 clusters using METIS
68
+ >>> config = PartitioningConfig(
69
+ ... minimum_n_clusters=5,
70
+ ... partitioning_algorithm="metis"
71
+ ... )
72
+
73
+ >>> # Both constraints: clusters of max 8 nodes, min 3 clusters
74
+ >>> config = PartitioningConfig(
75
+ ... max_n_nodes_per_cluster=8,
76
+ ... minimum_n_clusters=3
77
+ ... )
78
+ """
79
+
44
80
  max_n_nodes_per_cluster: int | None = None
45
81
  minimum_n_clusters: int | None = None
46
82
  partitioning_algorithm: Literal["spectral", "metis", "kernighan_lin"] = "spectral"
@@ -335,7 +371,7 @@ def _node_partition_graph(
335
371
 
336
372
  def linear_aggregation(
337
373
  curr_solution: Sequence[Literal[0] | Literal[1]],
338
- subproblem_solution: Set[int],
374
+ subproblem_solution: set[int],
339
375
  subproblem_reverse_index_map: dict[int, int],
340
376
  ):
341
377
  """Linearly combines a subproblem's solution into the main solution vector.
@@ -365,7 +401,7 @@ def linear_aggregation(
365
401
 
366
402
  def dominance_aggregation(
367
403
  curr_solution: Sequence[Literal[0] | Literal[1]],
368
- subproblem_solution: Set[int],
404
+ subproblem_solution: set[int],
369
405
  subproblem_reverse_index_map: dict[int, int],
370
406
  ):
371
407
  for node in subproblem_solution:
@@ -387,14 +423,6 @@ def dominance_aggregation(
387
423
  return curr_solution
388
424
 
389
425
 
390
- def _run_and_compute_solution(program: QuantumProgram):
391
- program.run()
392
-
393
- final_sol_circuit_count, final_sol_run_time = program.compute_final_solution()
394
-
395
- return final_sol_circuit_count, final_sol_run_time
396
-
397
-
398
426
  class GraphPartitioningQAOA(ProgramBatch):
399
427
  def __init__(
400
428
  self,
@@ -443,8 +471,6 @@ class GraphPartitioningQAOA(ProgramBatch):
443
471
  self.solution = None
444
472
  self.aggregate_fn = aggregate_fn
445
473
 
446
- self._task_fn = _run_and_compute_solution
447
-
448
474
  self._constructor = partial(
449
475
  QAOA,
450
476
  initial_state=initial_state,
@@ -499,33 +525,12 @@ class GraphPartitioningQAOA(ProgramBatch):
499
525
  self.reverse_index_maps[prog_id] = {v: k for k, v in index_map.items()}
500
526
 
501
527
  _subgraph = nx.relabel_nodes(subgraph, index_map)
502
- self.programs[prog_id] = self._constructor(
528
+ self._programs[prog_id] = self._constructor(
503
529
  job_id=prog_id,
504
530
  problem=_subgraph,
505
- losses=self._manager.list(),
506
- probs=self._manager.dict(),
507
- final_params=self._manager.list(),
508
- solution_nodes=self._manager.list(),
509
531
  progress_queue=self._queue,
510
532
  )
511
533
 
512
- def compute_final_solutions(self):
513
- if self._executor is not None:
514
- self.join()
515
-
516
- if self._executor is not None:
517
- raise RuntimeError("A batch is already being run.")
518
-
519
- if len(self.programs) == 0:
520
- raise RuntimeError("No programs to run.")
521
-
522
- self._executor = ProcessPoolExecutor()
523
-
524
- self.futures = [
525
- self._executor.submit(program.compute_final_solution)
526
- for program in self.programs.values()
527
- ]
528
-
529
534
  def aggregate_results(self):
530
535
  """
531
536
  Aggregates the results from all QAOA subprograms to form a global solution.
@@ -544,7 +549,7 @@ class GraphPartitioningQAOA(ProgramBatch):
544
549
  """
545
550
  super().aggregate_results()
546
551
 
547
- if any(len(program.probs) == 0 for program in self.programs.values()):
552
+ if any(len(program.best_probs) == 0 for program in self.programs.values()):
548
553
  raise RuntimeError(
549
554
  "Not all final probabilities computed yet. Please call `run()` first."
550
555
  )
@@ -49,14 +49,6 @@ def _sanitize_problem_input(qubo: T) -> tuple[T, BinaryQuadraticModel]:
49
49
  raise ValueError(f"Got an unsupported QUBO input format: {type(qubo)}")
50
50
 
51
51
 
52
- def _run_and_compute_solution(program: QuantumProgram):
53
- program.run()
54
-
55
- final_sol_circuit_count, final_sol_run_time = program.compute_final_solution()
56
-
57
- return final_sol_circuit_count, final_sol_run_time
58
-
59
-
60
52
  class QUBOPartitioningQAOA(ProgramBatch):
61
53
  def __init__(
62
54
  self,
@@ -94,10 +86,10 @@ class QUBOPartitioningQAOA(ProgramBatch):
94
86
  self._partitioning = hybrid.Unwind(decomposer)
95
87
  self._aggregating = hybrid.Reduce(hybrid.Lambda(_merge_substates)) | composer
96
88
 
97
- self._task_fn = _run_and_compute_solution
98
-
99
89
  self.max_iterations = max_iterations
100
90
 
91
+ self.trivial_program_ids = set()
92
+
101
93
  self._constructor = partial(
102
94
  QAOA,
103
95
  optimizer=optimizer if optimizer is not None else MonteCarloOptimizer(),
@@ -140,6 +132,12 @@ class QUBOPartitioningQAOA(ProgramBatch):
140
132
  del partition["problem"]
141
133
 
142
134
  prog_id = (string.ascii_uppercase[i], len(partition.subproblem))
135
+ self.prog_id_to_bqm_subproblem_states[prog_id] = partition
136
+
137
+ if partition.subproblem.num_interactions == 0:
138
+ # Skip creating a full QAOA program for this trivial case.
139
+ self.trivial_program_ids.add(prog_id)
140
+ continue
143
141
 
144
142
  ldata, (irow, icol, qdata), _ = partition.subproblem.to_numpy_vectors(
145
143
  partition.subproblem.variables
@@ -155,33 +153,52 @@ class QUBOPartitioningQAOA(ProgramBatch):
155
153
  ),
156
154
  shape=(len(ldata), len(ldata)),
157
155
  )
158
- self.prog_id_to_bqm_subproblem_states[prog_id] = partition
159
- self.programs[prog_id] = self._constructor(
156
+
157
+ self._programs[prog_id] = self._constructor(
160
158
  job_id=prog_id,
161
159
  problem=coo_mat,
162
- losses=self._manager.list(),
163
- probs=self._manager.dict(),
164
- final_params=self._manager.list(),
165
- solution_bitstring=self._manager.list(),
166
160
  progress_queue=self._queue,
167
161
  )
168
162
 
169
163
  def aggregate_results(self):
164
+ """
165
+ Aggregate results from all QUBO subproblems into a global solution.
166
+
167
+ Collects solutions from each partitioned subproblem (both QAOA-optimized and
168
+ trivial ones) and uses the hybrid framework composer to combine them into
169
+ a final solution for the original QUBO problem.
170
+
171
+ Returns:
172
+ tuple: A tuple containing:
173
+ - solution (np.ndarray): Binary solution vector for the QUBO problem.
174
+ - solution_energy (float): Energy/cost of the solution.
175
+
176
+ Raises:
177
+ RuntimeError: If programs haven't been run or if final probabilities
178
+ haven't been computed.
179
+ """
170
180
  super().aggregate_results()
171
181
 
172
- if any(len(program.probs) == 0 for program in self.programs.values()):
182
+ if any(len(program.best_probs) == 0 for program in self.programs.values()):
173
183
  raise RuntimeError(
174
184
  "Not all final probabilities computed yet. Please call `run()` first."
175
185
  )
176
186
 
177
- for prog_id, subproblem in self.programs.items():
178
- bqm_subproblem_state = self.prog_id_to_bqm_subproblem_states[prog_id]
187
+ for (
188
+ prog_id,
189
+ bqm_subproblem_state,
190
+ ) in self.prog_id_to_bqm_subproblem_states.items():
191
+
192
+ if prog_id in self.trivial_program_ids:
193
+ # Case 1: Trivial problem. Solve classically.
194
+ # The solution is any bitstring (e.g., all zeros) with energy 0.
195
+ var_to_val = {v: 0 for v in bqm_subproblem_state.subproblem.variables}
196
+ else:
197
+ subproblem = self._programs[prog_id]
198
+ var_to_val = dict(
199
+ zip(bqm_subproblem_state.subproblem.variables, subproblem.solution)
200
+ )
179
201
 
180
- curr_final_solution = subproblem.solution
181
-
182
- var_to_val = dict(
183
- zip(bqm_subproblem_state.subproblem.variables, curr_final_solution)
184
- )
185
202
  sample_set = dimod.SampleSet.from_samples(
186
203
  dimod.as_samples(var_to_val), "BINARY", 0
187
204
  )
@@ -127,6 +127,13 @@ def _zmatrix_to_cartesian(z_matrix: list[_ZMatrixEntry]) -> np.ndarray:
127
127
  if n_atoms == 0:
128
128
  return coords
129
129
 
130
+ # Validate bond lengths are positive
131
+ for i, entry in enumerate(z_matrix[1:], start=1):
132
+ if entry.bond_length is not None and entry.bond_length <= 0:
133
+ raise ValueError(
134
+ f"Bond length for atom {i} must be positive, got {entry.bond_length}"
135
+ )
136
+
130
137
  # --- First atom at origin ---
131
138
  coords[0] = np.array([0.0, 0.0, 0.0])
132
139
 
@@ -251,12 +258,14 @@ def _kabsch_align(P_in: np.ndarray, Q_in: np.ndarray, reference_atoms_idx=slice(
251
258
  H = P_centered.T @ Q_centered
252
259
  U, _, Vt = np.linalg.svd(H)
253
260
 
254
- # Reflection check
255
- d = np.sign(np.linalg.det(Vt.T @ U.T))
256
- D = np.diag([1] * (P.shape[1] - 1) + [d])
261
+ # Compute rotation matrix
262
+ R = Vt.T @ U.T
257
263
 
258
- # Optimal rotation and translation
259
- R = Vt.T @ D @ U.T
264
+ # Ensure proper rotation (det = +1) by handling reflections
265
+ if np.linalg.det(R) < 0:
266
+ # Flip the last column of Vt to ensure proper rotation
267
+ Vt[-1, :] *= -1
268
+ R = Vt.T @ U.T
260
269
  t = Qc - Pc @ R
261
270
 
262
271
  # Apply transformation
@@ -270,27 +279,28 @@ def _kabsch_align(P_in: np.ndarray, Q_in: np.ndarray, reference_atoms_idx=slice(
270
279
  @dataclass(frozen=True, eq=True)
271
280
  class MoleculeTransformer:
272
281
  """
273
- base_molecule: qml.qchem.Molecule
274
- The reference molecule used as a template for generating variants.
275
- bond_modifiers: Sequence[float]
276
- A list of values used to adjust bond lengths. The class will generate
277
- **one new molecule for each modifier** in this list. The modification
278
- mode is detected automatically:
282
+ A class for transforming molecular structures by modifying bond lengths.
283
+
284
+ This class generates variants of a base molecule by adjusting bond lengths
285
+ according to specified modifiers. The modification mode is detected automatically.
286
+
287
+ Attributes:
288
+ base_molecule (qml.qchem.Molecule): The reference molecule used as a template for generating variants.
289
+ bond_modifiers (Sequence[float]): A list of values used to adjust bond lengths. The class will generate
290
+ **one new molecule for each modifier** in this list. The modification
291
+ mode is detected automatically:
279
292
  - **Scale mode**: If all values are positive, they are used as scaling
280
293
  factors (e.g., 1.1 for a 10% increase).
281
294
  - **Delta mode**: If any value is zero or negative, all values are
282
295
  treated as additive changes to the bond length, in Ångstroms.
283
- atom_connectivity: Sequence[tuple[int, int]] | None
284
- A sequence of atom index pairs specifying the bonds in the molecule.
285
- If not provided, a chain structure will be assumed
286
- e.g.: `[(0, 1), (1, 2), (2, 3), ...]`.
287
- bonds_to_transform: Sequence[tuple[int, int]] | None
288
- A subset of `atom_connectivity` that specifies the bonds to modify.
289
- If None, all bonds will be transformed.
290
- alignment_atoms: Sequence[int] | None
291
- Indices of atoms onto which to align the orientation of the resulting
292
- variants of the molecule. Only useful for visualization and debuggin.
293
- If None, no alignment is carried out.
296
+ atom_connectivity (Sequence[tuple[int, int]] | None): A sequence of atom index pairs specifying the bonds in the molecule.
297
+ If not provided, a chain structure will be assumed
298
+ e.g.: `[(0, 1), (1, 2), (2, 3), ...]`.
299
+ bonds_to_transform (Sequence[tuple[int, int]] | None): A subset of `atom_connectivity` that specifies the bonds to modify.
300
+ If None, all bonds will be transformed.
301
+ alignment_atoms (Sequence[int] | None): Indices of atoms onto which to align the orientation of the resulting
302
+ variants of the molecule. Only useful for visualization and debugging.
303
+ If None, no alignment is carried out.
294
304
  """
295
305
 
296
306
  base_molecule: qml.qchem.Molecule
@@ -403,7 +413,7 @@ class VQEHyperparameterSweep(ProgramBatch):
403
413
 
404
414
  Parameters
405
415
  ----------
406
- ansatze: Sequence[VQEAnsatz]
416
+ ansatze: Sequence[Ansatz]
407
417
  A sequence of ansatz circuits to test.
408
418
  molecule_transformer: MoleculeTransformer
409
419
  A `MoleculeTransformer` object defining the configuration for
@@ -430,6 +440,15 @@ class VQEHyperparameterSweep(ProgramBatch):
430
440
  )
431
441
 
432
442
  def create_programs(self):
443
+ """
444
+ Create VQE programs for all combinations of ansätze and molecule variants.
445
+
446
+ Generates molecule variants using the configured MoleculeTransformer, then
447
+ creates a VQE program for each (ansatz, molecule_variant) pair.
448
+
449
+ Note:
450
+ Program IDs are tuples of (ansatz_name, bond_modifier_value).
451
+ """
433
452
  super().create_programs()
434
453
 
435
454
  self.molecule_variants = self.molecule_transformer.generate()
@@ -437,23 +456,35 @@ class VQEHyperparameterSweep(ProgramBatch):
437
456
  for ansatz, (modifier, molecule) in product(
438
457
  self.ansatze, self.molecule_variants.items()
439
458
  ):
440
- _job_id = (ansatz, modifier)
441
- self.programs[_job_id] = self._constructor(
459
+ _job_id = (ansatz.name, modifier)
460
+ self._programs[_job_id] = self._constructor(
442
461
  job_id=_job_id,
443
462
  molecule=molecule,
444
463
  ansatz=ansatz,
445
- losses=self._manager.list(),
446
- final_params=self._manager.list(),
447
464
  progress_queue=self._queue,
448
465
  )
449
466
 
450
467
  def aggregate_results(self):
468
+ """
469
+ Find the best ansatz and bond configuration from all VQE runs.
470
+
471
+ Compares the final energies across all ansatz/molecule combinations
472
+ and returns the configuration that achieved the lowest ground state energy.
473
+
474
+ Returns:
475
+ tuple: A tuple containing:
476
+ - best_config (tuple): (ansatz_name, bond_modifier) of the best result.
477
+ - best_energy (float): The lowest energy achieved.
478
+
479
+ Raises:
480
+ RuntimeError: If programs haven't been run or have empty losses.
481
+ """
451
482
  super().aggregate_results()
452
483
 
453
- all_energies = {key: prog.losses[-1] for key, prog in self.programs.items()}
484
+ all_energies = {key: prog.best_loss for key, prog in self.programs.items()}
454
485
 
455
- smallest_key = min(all_energies, key=lambda k: min(all_energies[k].values()))
456
- smallest_value = min(all_energies[smallest_key].values())
486
+ smallest_key = min(all_energies, key=lambda k: all_energies[k])
487
+ smallest_value = all_energies[smallest_key]
457
488
 
458
489
  return smallest_key, smallest_value
459
490
 
@@ -483,18 +514,17 @@ class VQEHyperparameterSweep(ProgramBatch):
483
514
  # Plot each ansatz's results as a separate series for clarity
484
515
  for ansatz in unique_ansatze:
485
516
  modifiers = []
486
- min_energies = []
517
+ energies = []
487
518
  for modifier in self.molecule_transformer.bond_modifiers:
488
- program_key = (ansatz, modifier)
489
- if program_key in self.programs:
519
+ program_key = (ansatz.name, modifier)
520
+ if program_key in self._programs:
490
521
  modifiers.append(modifier)
491
- curr_energies = self.programs[program_key].losses[-1]
492
- min_energies.append(min(curr_energies.values()))
522
+ energies.append(self._programs[program_key].best_loss)
493
523
 
494
524
  # Use the new .name property for the label and the color_map
495
525
  plt.scatter(
496
526
  modifiers,
497
- min_energies,
527
+ energies,
498
528
  color=color_map[ansatz],
499
529
  label=ansatz.name,
500
530
  )
@@ -503,9 +533,7 @@ class VQEHyperparameterSweep(ProgramBatch):
503
533
  for ansatz in unique_ansatze:
504
534
  energies = []
505
535
  for modifier in self.molecule_transformer.bond_modifiers:
506
- energies.append(
507
- min(self.programs[(ansatz, modifier)].losses[-1].values())
508
- )
536
+ energies.append(self._programs[(ansatz.name, modifier)].best_loss)
509
537
 
510
538
  plt.plot(
511
539
  self.molecule_transformer.bond_modifiers,
divi/reporting/_pbar.py CHANGED
@@ -9,11 +9,26 @@ from rich.progress import (
9
9
  ProgressColumn,
10
10
  SpinnerColumn,
11
11
  TextColumn,
12
+ TimeElapsedColumn,
12
13
  )
13
14
  from rich.text import Text
14
15
 
15
16
 
17
+ class _UnfinishedTaskWrapper:
18
+ """Wrapper that forces a task to appear unfinished for spinner animation."""
19
+
20
+ def __init__(self, task):
21
+ self._task = task
22
+
23
+ def __getattr__(self, name):
24
+ if name == "finished":
25
+ return False
26
+ return getattr(self._task, name)
27
+
28
+
16
29
  class ConditionalSpinnerColumn(ProgressColumn):
30
+ _FINAL_STATUSES = ("Success", "Failed", "Cancelled", "Aborted")
31
+
17
32
  def __init__(self):
18
33
  super().__init__()
19
34
  self.spinner = SpinnerColumn("point")
@@ -21,10 +36,11 @@ class ConditionalSpinnerColumn(ProgressColumn):
21
36
  def render(self, task):
22
37
  status = task.fields.get("final_status")
23
38
 
24
- if status in ("Success", "Failed"):
39
+ if status in self._FINAL_STATUSES:
25
40
  return Text("")
26
41
 
27
- return self.spinner.render(task)
42
+ # Force the task to appear unfinished for spinner animation
43
+ return self.spinner.render(_UnfinishedTaskWrapper(task))
28
44
 
29
45
 
30
46
  class PhaseStatusColumn(ProgressColumn):
@@ -38,29 +54,55 @@ class PhaseStatusColumn(ProgressColumn):
38
54
  return Text("• Success! ✅", style="bold green")
39
55
  elif final_status == "Failed":
40
56
  return Text("• Failed! ❌", style="bold red")
57
+ elif final_status == "Cancelled":
58
+ return Text("• Cancelled ⏹️", style="bold yellow")
59
+ elif final_status == "Aborted":
60
+ return Text("• Aborted ⚠️", style="dim magenta")
41
61
 
42
62
  message = task.fields.get("message")
43
63
 
44
- poll_attempt = task.fields.get("poll_attempt")
64
+ poll_attempt = task.fields.get("poll_attempt", 0)
45
65
  polling_str = ""
46
- service_job_id = ""
47
- if poll_attempt > 0:
48
- max_retries = task.fields.get("max_retries")
49
- service_job_id = task.fields.get("service_job_id").split("-")[0]
66
+ service_job_id = task.fields.get("service_job_id")
67
+
68
+ if service_job_id:
69
+ split_job_id = service_job_id.split("-")[0]
50
70
  job_status = task.fields.get("job_status")
51
- polling_str = f" [Job {service_job_id} is {job_status}. Polling attempt {poll_attempt} / {max_retries}]"
71
+
72
+ if job_status == "COMPLETED":
73
+ polling_str = f" [Job {split_job_id} is complete.]"
74
+ elif poll_attempt > 0:
75
+ max_retries = task.fields.get("max_retries")
76
+ polling_str = f" [Job {split_job_id} is {job_status}. Polling attempt {poll_attempt} / {max_retries}]"
52
77
 
53
78
  final_text = Text(f"[{message}]{polling_str}")
54
- final_text.highlight_words([service_job_id], "blue")
79
+ if service_job_id:
80
+ final_text.highlight_words([split_job_id], "blue")
55
81
 
56
82
  return final_text
57
83
 
58
84
 
59
85
  def make_progress_bar(is_jupyter: bool = False) -> Progress:
86
+ """
87
+ Create a customized Rich progress bar for tracking quantum program execution.
88
+
89
+ Builds a progress bar with custom columns including job name, completion status,
90
+ elapsed time, spinner, and phase status indicators. Automatically adapts refresh
91
+ behavior for Jupyter notebook environments.
92
+
93
+ Args:
94
+ is_jupyter (bool, optional): Whether the progress bar is being displayed in
95
+ a Jupyter notebook environment. Affects refresh behavior. Defaults to False.
96
+
97
+ Returns:
98
+ Progress: A configured Rich Progress instance with custom columns for
99
+ quantum program tracking.
100
+ """
60
101
  return Progress(
61
102
  TextColumn("[bold blue]{task.fields[job_name]}"),
62
103
  BarColumn(),
63
104
  MofNCompleteColumn(),
105
+ TimeElapsedColumn(),
64
106
  ConditionalSpinnerColumn(),
65
107
  PhaseStatusColumn(),
66
108
  # For jupyter notebooks, refresh manually instead
@@ -30,6 +30,18 @@ def _is_jupyter():
30
30
  return False # IPython is not installed
31
31
 
32
32
 
33
+ class CustomFormatter(logging.Formatter):
34
+ """
35
+ A custom log formatter that removes '._reporter' from the logger name.
36
+ """
37
+
38
+ def format(self, record):
39
+ # Modify the record's name attribute in place
40
+ if record.name.endswith("._reporter"):
41
+ record.name = record.name.removesuffix("._reporter")
42
+ return super().format(record)
43
+
44
+
33
45
  class OverwriteStreamHandler(logging.StreamHandler):
34
46
  def __init__(self, stream=None):
35
47
  super().__init__(stream)
@@ -100,9 +112,24 @@ class OverwriteStreamHandler(logging.StreamHandler):
100
112
 
101
113
 
102
114
  def enable_logging(level=logging.INFO):
115
+ """
116
+ Enable logging for the divi package with custom formatting.
117
+
118
+ Sets up a custom logger with an OverwriteStreamHandler that supports
119
+ message overwriting (for progress updates) and removes the '._reporter'
120
+ suffix from logger names.
121
+
122
+ Args:
123
+ level (int, optional): Logging level to set (e.g., logging.INFO,
124
+ logging.DEBUG). Defaults to logging.INFO.
125
+
126
+ Note:
127
+ This function clears any existing handlers and sets up a new handler
128
+ with custom formatting.
129
+ """
103
130
  root_logger = logging.getLogger(__name__.split(".")[0])
104
131
 
105
- formatter = logging.Formatter(
132
+ formatter = CustomFormatter(
106
133
  "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
107
134
  datefmt="%Y-%m-%d %H:%M:%S",
108
135
  )
@@ -116,6 +143,13 @@ def enable_logging(level=logging.INFO):
116
143
 
117
144
 
118
145
  def disable_logging():
146
+ """
147
+ Disable all logging for the divi package.
148
+
149
+ Removes all handlers and sets the logging level to above CRITICAL,
150
+ effectively suppressing all log messages. This is useful when using
151
+ progress bars that provide visual feedback.
152
+ """
119
153
  root_logger = logging.getLogger(__name__.split(".")[0])
120
154
  root_logger.handlers.clear()
121
155
  root_logger.setLevel(logging.CRITICAL + 1)
@@ -13,28 +13,22 @@ class ProgressReporter(ABC):
13
13
  """An abstract base class for reporting progress of a quantum program."""
14
14
 
15
15
  @abstractmethod
16
- def update(self, **kwargs):
16
+ def update(self, **kwargs) -> None:
17
17
  """Provides a progress update."""
18
18
  pass
19
19
 
20
20
  @abstractmethod
21
- def info(self, message: str, **kwargs):
22
- """
23
- Provides a simple informational message.
24
- No changes to progress or state.
25
- """
21
+ def info(self, message: str, **kwargs) -> None:
22
+ """Provides a simple informational message."""
26
23
  pass
27
24
 
28
25
 
29
26
  class QueueProgressReporter(ProgressReporter):
30
27
  """Reports progress by putting structured dictionaries onto a Queue."""
31
28
 
32
- def __init__(
33
- self, job_id: str, progress_queue: Queue, has_final_computation: bool = False
34
- ):
29
+ def __init__(self, job_id: str, progress_queue: Queue):
35
30
  self._job_id = job_id
36
31
  self._queue = progress_queue
37
- self.has_final_computation = has_final_computation
38
32
 
39
33
  def update(self, **kwargs):
40
34
  payload = {"job_id": self._job_id, "progress": 1}
@@ -43,20 +37,19 @@ class QueueProgressReporter(ProgressReporter):
43
37
  def info(self, message: str, **kwargs):
44
38
  payload = {"job_id": self._job_id, "progress": 0, "message": message}
45
39
 
46
- # Determine if this message indicates the job is truly finished.
47
- is_final_step = "Computed Final Solution" in message or (
48
- "Finished Optimization" in message and not self.has_final_computation
49
- )
50
-
51
- if is_final_step:
40
+ if "Finished successfully!" in message:
52
41
  payload["final_status"] = "Success"
53
- elif "poll_attempt" in kwargs:
42
+
43
+ if "poll_attempt" in kwargs:
54
44
  # For polling, remove the message key so the last message persists.
55
45
  del payload["message"]
56
46
  payload["poll_attempt"] = kwargs["poll_attempt"]
57
47
  payload["max_retries"] = kwargs["max_retries"]
58
48
  payload["service_job_id"] = kwargs["service_job_id"]
59
49
  payload["job_status"] = kwargs["job_status"]
50
+ else:
51
+ # For any other message, explicitly reset the polling attempt counter.
52
+ payload["poll_attempt"] = 0
60
53
 
61
54
  self._queue.put(payload)
62
55
 
@@ -82,9 +75,7 @@ class LoggingProgressReporter(ProgressReporter):
82
75
  return
83
76
 
84
77
  if "iteration" in kwargs:
85
- logger.info(
86
- f"Running Iteration #{kwargs['iteration'] + 1} circuits: {message}\r"
87
- )
78
+ logger.info(f"Iteration #{kwargs['iteration'] + 1}: {message}\r")
88
79
  return
89
80
 
90
81
  logger.info(message)