qoro-divi 0.3.4__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



divi/qprog/batch.py CHANGED
@@ -5,11 +5,9 @@
  import atexit
  import traceback
  from abc import ABC, abstractmethod
- from concurrent.futures import ProcessPoolExecutor, as_completed
- from multiprocessing import Event, Manager
- from multiprocessing.synchronize import Event as EventClass
+ from concurrent.futures import Future, ThreadPoolExecutor, as_completed
  from queue import Empty, Queue
- from threading import Lock, Thread
+ from threading import Event, Lock, Thread
  from typing import Any
  from warnings import warn
 
@@ -21,11 +19,11 @@ from divi.qprog.quantum_program import QuantumProgram
  from divi.reporting import disable_logging, make_progress_bar
 
 
- def queue_listener(
+ def _queue_listener(
      queue: Queue,
      progress_bar: Progress,
      pb_task_map: dict[QuantumProgram, TaskID],
-     done_event: EventClass,
+     done_event: Event,
      is_jupyter: bool,
      lock: Lock,
  ):
@@ -60,6 +58,7 @@ def queue_listener(
          update_args["refresh"] = is_jupyter
 
          progress_bar.update(task_id, **update_args)
+         queue.task_done()
 
 
  def _default_task_function(program: QuantumProgram):
@@ -90,7 +89,7 @@ class ProgramBatch(ABC):
          self.backend = backend
          self._executor = None
          self._task_fn = _default_task_function
-         self.programs = {}
+         self._programs = {}
 
          self._total_circuit_count = 0
          self._total_run_time = 0.0
@@ -103,28 +102,67 @@
 
      @property
      def total_circuit_count(self):
+         """
+         Get the total number of circuits executed across all programs in the batch.
+
+         Returns:
+             int: Cumulative count of circuits submitted by all programs.
+         """
          return self._total_circuit_count
 
      @property
      def total_run_time(self):
+         """
+         Get the total runtime across all programs in the batch.
+
+         Returns:
+             float: Cumulative execution time in seconds across all programs.
+         """
          return self._total_run_time
 
+     @property
+     def programs(self) -> dict:
+         """
+         Get a copy of the programs dictionary.
+
+         Returns:
+             dict: Copy of the programs dictionary mapping program IDs to
+                 QuantumProgram instances. Modifications to this dict will not
+                 affect the internal state.
+         """
+         return self._programs.copy()
+
+     @programs.setter
+     def programs(self, value: dict):
+         """Set the programs dictionary."""
+         self._programs = value
+
      @abstractmethod
      def create_programs(self):
-         if len(self.programs) > 0:
+         if len(self._programs) > 0:
              raise RuntimeError(
                  "Some programs already exist. "
                  "Clear the program dictionary before creating new ones by using batch.reset()."
              )
 
-         self._manager = Manager()
-         self._queue = self._manager.Queue()
+         self._queue = Queue()
 
          if hasattr(self, "max_iterations"):
              self._done_event = Event()
 
      def reset(self):
-         self.programs.clear()
+         """
+         Reset the batch to its initial state.
+
+         Clears all programs, stops any running executors, terminates listener threads,
+         and stops progress bars. This allows the batch to be reused for a new set of
+         programs.
+
+         Note:
+             Any running programs will be forcefully stopped. Results from incomplete
+             programs will be lost.
+         """
+         self._programs.clear()
 
          # Stop any active executor
          if self._executor is not None:
@@ -143,12 +181,6 @@ class ProgramBatch(ABC):
                  warn("Listener thread did not terminate within timeout.")
              self._listener_thread = None
 
-         # Shut down the manager process, which handles the queue cleanup.
-         if hasattr(self, "_manager") and self._manager is not None:
-             self._manager.shutdown()
-             self._manager = None
-             self._queue = None
-
          # Stop the progress bar if it's still active
          if getattr(self, "_progress_bar", None) is not None:
              try:
@@ -168,8 +200,21 @@ class ProgramBatch(ABC):
          )
          self.reset()
 
-     def add_program_to_executor(self, program):
-         self.futures.append(self._executor.submit(self._task_fn, program))
+     def _add_program_to_executor(self, program: QuantumProgram) -> Future:
+         """
+         Add a quantum program to the thread pool executor for execution.
+
+         Sets up the program with cancellation support and progress tracking, then
+         submits it for execution in a separate thread.
+
+         Args:
+             program (QuantumProgram): The quantum program to execute.
+
+         Returns:
+             Future: A Future object representing the program's execution.
+         """
+         if hasattr(program, "_set_cancellation_event"):
+             program._set_cancellation_event(self._cancellation_event)
 
          if self._progress_bar is not None:
              with self._pb_lock:
@@ -178,17 +223,39 @@ class ProgramBatch(ABC):
                      job_name=f"Job {program.job_id}",
                      total=self.max_iterations,
                      completed=0,
-                     poll_attempt=0,
                      message="",
-                     final_status="",
                      mode=("simulation" if self._is_local else "network"),
                  )
 
+         return self._executor.submit(self._task_fn, program)
+
      def run(self, blocking: bool = False):
+         """
+         Execute all programs in the batch.
+
+         Starts all quantum programs in parallel using a thread pool. Can run in
+         blocking or non-blocking mode.
+
+         Args:
+             blocking (bool, optional): If True, waits for all programs to complete
+                 before returning. If False, returns immediately and programs run in
+                 the background. Defaults to False.
+
+         Returns:
+             ProgramBatch: Returns self for method chaining.
+
+         Raises:
+             RuntimeError: If a batch is already running or if no programs have been
+                 created.
+
+         Note:
+             In non-blocking mode, call `join()` later to wait for completion and
+             collect results.
+         """
          if self._executor is not None:
              raise RuntimeError("A batch is already being run.")
 
-         if len(self.programs) == 0:
+         if len(self._programs) == 0:
              raise RuntimeError("No programs to run.")
 
          self._progress_bar = (
@@ -197,15 +264,17 @@ class ProgramBatch(ABC):
              else None
          )
 
-         self._executor = ProcessPoolExecutor()
+         self._executor = ThreadPoolExecutor()
+         self._cancellation_event = Event()
          self.futures = []
+         self._future_to_program = {}
          self._pb_task_map = {}
          self._pb_lock = Lock()
 
          if self._progress_bar is not None:
              self._progress_bar.start()
              self._listener_thread = Thread(
-                 target=queue_listener,
+                 target=_queue_listener,
                  args=(
                      self._queue,
                      self._progress_bar,
@@ -218,8 +287,10 @@ class ProgramBatch(ABC):
              )
              self._listener_thread.start()
 
-         for program in self.programs.values():
-             self.add_program_to_executor(program)
+         for program in self._programs.values():
+             future = self._add_program_to_executor(program)
+             self.futures.append(future)
+             self._future_to_program[future] = program
 
          if not blocking:
              # Arm safety net
@@ -229,60 +300,175 @@ class ProgramBatch(ABC):
 
          return self
 
-     def check_all_done(self):
+     def check_all_done(self) -> bool:
+         """
+         Check if all programs in the batch have completed execution.
+
+         Returns:
+             bool: True if all programs are finished (successfully or with errors),
+                 False if any are still running.
+         """
          return all(future.done() for future in self.futures)
 
+     def _collect_completed_results(self, completed_futures: list):
+         """
+         Collects results from any futures that have completed successfully.
+         Appends (circuit_count, run_time) tuples to the completed_futures list.
+
+         Args:
+             completed_futures: List to append results to
+         """
+         for future in self.futures:
+             if future.done() and not future.cancelled():
+                 try:
+                     completed_futures.append(future.result())
+                 except Exception:
+                     pass  # Skip failed futures
+
+     def _handle_cancellation(self):
+         """
+         Handles cancellation gracefully, providing accurate feedback by checking
+         the result of future.cancel().
+         """
+         self._cancellation_event.set()
+
+         successfully_cancelled = []
+         unstoppable_futures = []
+
+         # --- Phase 1: Attempt to cancel all non-finished tasks ---
+         for future, program in self._future_to_program.items():
+             if future.done():
+                 continue
+
+             task_id = self._pb_task_map.get(program.job_id)
+             if self._progress_bar and task_id is not None:
+                 cancel_result = future.cancel()
+                 if cancel_result:
+                     # The task was pending and was successfully cancelled.
+                     successfully_cancelled.append(program)
+                 else:
+                     # The task is already running and cannot be stopped.
+                     unstoppable_futures.append(future)
+                     self._progress_bar.update(
+                         task_id,
+                         message="Finishing... ⏳",
+                         refresh=self._is_jupyter,
+                     )
+
+         # --- Phase 2: Immediately mark the successfully cancelled tasks ---
+         for program in successfully_cancelled:
+             task_id = self._pb_task_map.get(program.job_id)
+             if self._progress_bar and task_id is not None:
+                 self._progress_bar.update(
+                     task_id,
+                     final_status="Cancelled",
+                     message="Cancelled by user",
+                     refresh=self._is_jupyter,
+                 )
+
+         # --- Phase 3: Wait for the unstoppable tasks to finish ---
+         if unstoppable_futures:
+             for future in as_completed(unstoppable_futures):
+                 program = self._future_to_program[future]
+                 task_id = self._pb_task_map.get(program.job_id)
+                 if self._progress_bar and task_id is not None:
+                     self._progress_bar.update(
+                         task_id,
+                         final_status="Aborted",
+                         message="Completed during cancellation",
+                         refresh=self._is_jupyter,
+                     )
+
      def join(self):
+         """
+         Wait for all programs in the batch to complete and collect results.
+
+         Blocks until all programs finish execution, aggregating their circuit counts
+         and run times. Handles keyboard interrupts gracefully by attempting to cancel
+         remaining programs.
+
+         Returns:
+             bool or None: Returns False if interrupted by KeyboardInterrupt, None otherwise.
+
+         Raises:
+             RuntimeError: If any program fails with an exception, after cancelling
+                 remaining programs.
+
+         Note:
+             This method should be called after `run(blocking=False)` to wait for
+             completion. It's automatically called when using `run(blocking=True)`.
+         """
          if self._executor is None:
              return
 
-         exceptions = []
+         completed_futures = []
          try:
-             # Ensure all futures are completed and handle exceptions.
+             # The as_completed iterator will yield futures as they finish.
+             # If a task fails, future.result() will raise the exception immediately.
              for future in as_completed(self.futures):
-                 try:
-                     future.result() # Raises an exception if the task failed.
-                 except Exception as e:
-                     exceptions.append(e)
+                 completed_futures.append(future.result())
+
+         except KeyboardInterrupt:
+
+             if self._progress_bar is not None:
+                 self._progress_bar.console.print(
+                     "[bold yellow]Shutdown signal received, waiting for programs to finish current iteration...[/bold yellow]"
+                 )
+             self._handle_cancellation()
+
+             # Collect results from any futures that completed before/during cancellation
+             self._collect_completed_results(completed_futures)
+
+             return False
+
+         except Exception as e:
+             # A task has failed. Print the error and cancel the rest.
+             print(f"A task failed with an exception. Cancelling remaining tasks...")
+             traceback.print_exception(type(e), e, e.__traceback__)
+
+             # Collect results from any futures that completed before the failure
+             self._collect_completed_results(completed_futures)
+
+             # Cancel all other futures that have not yet completed.
+             for f in self.futures:
+                 f.cancel()
+
+             # Re-raise a new error to indicate the batch failed.
+             raise RuntimeError("Batch execution failed and was cancelled.") from e
 
          finally:
-             self._executor.shutdown(wait=True, cancel_futures=False)
+             # Aggregate results from completed futures
+             if completed_futures:
+                 self._total_circuit_count += sum(
+                     result[0] for result in completed_futures
+                 )
+                 self._total_run_time += sum(result[1] for result in completed_futures)
+             self.futures.clear()
+
+             self._executor.shutdown(wait=False)
              self._executor = None
 
              if self._progress_bar is not None:
+                 self._queue.join()
                  self._done_event.set()
                  self._listener_thread.join()
-
-         if exceptions:
-             for i, exc in enumerate(exceptions, 1):
-                 print(f"Task {i} failed with exception:")
-                 traceback.print_exception(type(exc), exc, exc.__traceback__)
-             raise RuntimeError("One or more tasks failed. Check logs for details.")
-
-         if self._progress_bar is not None:
-             self._progress_bar.stop()
-
-         self._total_circuit_count += sum(future.result()[0] for future in self.futures)
-         self._total_run_time += sum(future.result()[1] for future in self.futures)
-         self.futures.clear()
+                 self._progress_bar.stop()
 
          # After successful cleanup, try to unregister the hook.
-         # This will only succeed if it was a non-blocking run.
          try:
              atexit.unregister(self._atexit_cleanup_hook)
          except TypeError:
-             # This is expected for blocking runs where the hook was never registered.
              pass
 
      @abstractmethod
      def aggregate_results(self):
-         if len(self.programs) == 0:
+         if len(self._programs) == 0:
              raise RuntimeError("No programs to aggregate. Run create_programs() first.")
 
          if self._executor is not None:
              self.join()
 
-         if any(len(program.losses) == 0 for program in self.programs.values()):
+         if any(len(program.losses) == 0 for program in self._programs.values()):
              raise RuntimeError(
                  "Some/All programs have empty losses. Did you call run()?"
              )
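For orientation, the docstrings above describe the run/join lifecycle that changes in this release (threads instead of processes, plus cooperative cancellation). A minimal usage sketch follows; `MyBatch`, `my_backend`, and the constructor arguments are hypothetical stand-ins, since `ProgramBatch` is abstract and the concrete subclasses are not part of this diff:

    batch = MyBatch(backend=my_backend)      # hypothetical concrete subclass of ProgramBatch
    batch.create_programs()                  # abstract hook; populates the internal program dict

    snapshot = batch.programs                # new in 0.3.5: returns a copy, so edits here
                                             # do not touch the batch's internal state

    batch.run(blocking=False)                # submits every program to a ThreadPoolExecutor
    # ... do other work while the batch runs ...
    if batch.join() is False:                # per the docstring, False means a KeyboardInterrupt
        print("Batch was cancelled")         # cancelled the run
    else:
        results = batch.aggregate_results()  # joins first if the executor is still active
        print(batch.total_circuit_count, batch.total_run_time)
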
@@ -0,0 +1,9 @@
+ # SPDX-FileCopyrightText: 2025 Qoro Quantum Ltd <divi@qoroquantum.de>
+ #
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ class _CancelledError(Exception):
+     """Internal exception to signal a task to stop due to cancellation."""
+
+     pass
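The cancellation pieces in this release fit together roughly as follows: `run()` creates a `threading.Event`, `_add_program_to_executor()` hands it to any program that exposes `_set_cancellation_event` (guarded by `hasattr`), and `_handle_cancellation()` sets it on Ctrl-C. Below is a sketch of a program that cooperates with that mechanism; `CooperativeProgram`, its `run()` loop, and the locally defined `_CancelledError` mirror are illustrative assumptions, since the real `QuantumProgram` behaviour and the new module's import path are not shown in this diff:

    from threading import Event


    class _CancelledError(Exception):
        """Local mirror of the exception added in 0.3.5 (see the new module above)."""


    class CooperativeProgram:
        """Illustrative stand-in for a program that honours the batch's cancellation event."""

        def __init__(self):
            self._cancellation_event = None

        def _set_cancellation_event(self, event: Event):
            # ProgramBatch._add_program_to_executor calls this hook before submitting the task.
            self._cancellation_event = event

        def run(self, max_iterations: int = 10):
            for i in range(max_iterations):
                if self._cancellation_event is not None and self._cancellation_event.is_set():
                    # Stop cleanly between iterations once the batch signals cancellation.
                    raise _CancelledError(f"cancelled after {i} iterations")
                # ... execute one optimisation iteration here ...

Because a running future cannot be cancelled through `Future.cancel()`, this event-checking pattern is what lets an in-flight program stop at the next iteration boundary rather than running to completion.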