qoro-divi 0.2.0b1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. divi/__init__.py +1 -2
  2. divi/backends/__init__.py +10 -0
  3. divi/backends/_backend_properties_conversion.py +227 -0
  4. divi/backends/_circuit_runner.py +70 -0
  5. divi/backends/_execution_result.py +70 -0
  6. divi/backends/_parallel_simulator.py +486 -0
  7. divi/backends/_qoro_service.py +663 -0
  8. divi/backends/_qpu_system.py +101 -0
  9. divi/backends/_results_processing.py +133 -0
  10. divi/circuits/__init__.py +13 -0
  11. divi/{exp/cirq → circuits/_cirq}/__init__.py +1 -2
  12. divi/circuits/_cirq/_parser.py +110 -0
  13. divi/circuits/_cirq/_qasm_export.py +78 -0
  14. divi/circuits/_core.py +391 -0
  15. divi/{qasm.py → circuits/_qasm_conversion.py} +73 -14
  16. divi/circuits/_qasm_validation.py +694 -0
  17. divi/qprog/__init__.py +27 -8
  18. divi/qprog/_expectation.py +181 -0
  19. divi/qprog/_hamiltonians.py +281 -0
  20. divi/qprog/algorithms/__init__.py +16 -0
  21. divi/qprog/algorithms/_ansatze.py +368 -0
  22. divi/qprog/algorithms/_custom_vqa.py +263 -0
  23. divi/qprog/algorithms/_pce.py +262 -0
  24. divi/qprog/algorithms/_qaoa.py +579 -0
  25. divi/qprog/algorithms/_vqe.py +262 -0
  26. divi/qprog/batch.py +387 -74
  27. divi/qprog/checkpointing.py +556 -0
  28. divi/qprog/exceptions.py +9 -0
  29. divi/qprog/optimizers.py +1014 -43
  30. divi/qprog/quantum_program.py +243 -412
  31. divi/qprog/typing.py +62 -0
  32. divi/qprog/variational_quantum_algorithm.py +1208 -0
  33. divi/qprog/workflows/__init__.py +10 -0
  34. divi/qprog/{_graph_partitioning.py → workflows/_graph_partitioning.py} +139 -95
  35. divi/qprog/workflows/_qubo_partitioning.py +221 -0
  36. divi/qprog/workflows/_vqe_sweep.py +560 -0
  37. divi/reporting/__init__.py +7 -0
  38. divi/reporting/_pbar.py +127 -0
  39. divi/reporting/_qlogger.py +68 -0
  40. divi/reporting/_reporter.py +155 -0
  41. {qoro_divi-0.2.0b1.dist-info → qoro_divi-0.6.0.dist-info}/METADATA +43 -15
  42. qoro_divi-0.6.0.dist-info/RECORD +47 -0
  43. {qoro_divi-0.2.0b1.dist-info → qoro_divi-0.6.0.dist-info}/WHEEL +1 -1
  44. qoro_divi-0.6.0.dist-info/licenses/LICENSES/.license-header +3 -0
  45. divi/_pbar.py +0 -73
  46. divi/circuits.py +0 -139
  47. divi/exp/cirq/_lexer.py +0 -126
  48. divi/exp/cirq/_parser.py +0 -889
  49. divi/exp/cirq/_qasm_export.py +0 -37
  50. divi/exp/cirq/_qasm_import.py +0 -35
  51. divi/exp/cirq/exception.py +0 -21
  52. divi/exp/scipy/_cobyla.py +0 -342
  53. divi/exp/scipy/pyprima/LICENCE.txt +0 -28
  54. divi/exp/scipy/pyprima/__init__.py +0 -263
  55. divi/exp/scipy/pyprima/cobyla/__init__.py +0 -0
  56. divi/exp/scipy/pyprima/cobyla/cobyla.py +0 -599
  57. divi/exp/scipy/pyprima/cobyla/cobylb.py +0 -849
  58. divi/exp/scipy/pyprima/cobyla/geometry.py +0 -240
  59. divi/exp/scipy/pyprima/cobyla/initialize.py +0 -269
  60. divi/exp/scipy/pyprima/cobyla/trustregion.py +0 -540
  61. divi/exp/scipy/pyprima/cobyla/update.py +0 -331
  62. divi/exp/scipy/pyprima/common/__init__.py +0 -0
  63. divi/exp/scipy/pyprima/common/_bounds.py +0 -41
  64. divi/exp/scipy/pyprima/common/_linear_constraints.py +0 -46
  65. divi/exp/scipy/pyprima/common/_nonlinear_constraints.py +0 -64
  66. divi/exp/scipy/pyprima/common/_project.py +0 -224
  67. divi/exp/scipy/pyprima/common/checkbreak.py +0 -107
  68. divi/exp/scipy/pyprima/common/consts.py +0 -48
  69. divi/exp/scipy/pyprima/common/evaluate.py +0 -101
  70. divi/exp/scipy/pyprima/common/history.py +0 -39
  71. divi/exp/scipy/pyprima/common/infos.py +0 -30
  72. divi/exp/scipy/pyprima/common/linalg.py +0 -452
  73. divi/exp/scipy/pyprima/common/message.py +0 -336
  74. divi/exp/scipy/pyprima/common/powalg.py +0 -131
  75. divi/exp/scipy/pyprima/common/preproc.py +0 -393
  76. divi/exp/scipy/pyprima/common/present.py +0 -5
  77. divi/exp/scipy/pyprima/common/ratio.py +0 -56
  78. divi/exp/scipy/pyprima/common/redrho.py +0 -49
  79. divi/exp/scipy/pyprima/common/selectx.py +0 -346
  80. divi/interfaces.py +0 -25
  81. divi/parallel_simulator.py +0 -258
  82. divi/qlogger.py +0 -119
  83. divi/qoro_service.py +0 -343
  84. divi/qprog/_mlae.py +0 -182
  85. divi/qprog/_qaoa.py +0 -440
  86. divi/qprog/_vqe.py +0 -275
  87. divi/qprog/_vqe_sweep.py +0 -144
  88. divi/utils.py +0 -116
  89. qoro_divi-0.2.0b1.dist-info/RECORD +0 -58
  90. /divi/{qem.py → circuits/qem.py} +0 -0
  91. {qoro_divi-0.2.0b1.dist-info → qoro_divi-0.6.0.dist-info/licenses}/LICENSE +0 -0
  92. {qoro_divi-0.2.0b1.dist-info → qoro_divi-0.6.0.dist-info/licenses}/LICENSES/Apache-2.0.txt +0 -0
@@ -0,0 +1,556 @@
1
+ # SPDX-FileCopyrightText: 2025 Qoro Quantum Ltd <divi@qoroquantum.de>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Checkpointing utilities for variational quantum algorithms."""
6
+
7
+ import json
8
+ from dataclasses import dataclass
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ # Constants for checkpoint file and directory naming
14
+ PROGRAM_STATE_FILE = "program_state.json"
15
+ OPTIMIZER_STATE_FILE = "optimizer_state.json"
16
+ SUBDIR_PREFIX = "checkpoint_"
17
+
18
+ # Maximum reasonable iteration number (prevents parsing errors from corrupted names)
19
+ _MAX_ITERATION_NUMBER = 1_000_000
20
+
21
+
22
def _get_checkpoint_subdir_name(iteration: int) -> str:
    """Build the subdirectory name used to store a given iteration's checkpoint.

    Args:
        iteration (int): Iteration number.

    Returns:
        str: Subdirectory name, e.g. "checkpoint_001" (zero-padded to 3 digits).
    """
    padded = format(iteration, "03d")
    return f"{SUBDIR_PREFIX}{padded}"
32
+
33
+
34
def _extract_iteration_from_subdir(subdir_name: str) -> int | None:
    """Parse the iteration number out of a checkpoint subdirectory name.

    Args:
        subdir_name (str): Subdirectory name (e.g., "checkpoint_001").

    Returns:
        int | None: The iteration number when the name is well-formed and
            within the accepted range, otherwise None.
    """
    if not subdir_name.startswith(SUBDIR_PREFIX):
        return None
    digits = subdir_name[len(SUBDIR_PREFIX) :]
    if not digits.isdigit():
        return None
    parsed = int(digits)
    # Reject absurd values that would indicate a corrupted directory name.
    return parsed if 0 <= parsed <= _MAX_ITERATION_NUMBER else None
53
+
54
+
55
def _ensure_checkpoint_dir(checkpoint_dir: Path) -> Path:
    """Create the checkpoint directory (including parents) if it is missing.

    Args:
        checkpoint_dir (Path): Directory path.

    Returns:
        Path: The same checkpoint directory path, guaranteed to exist.
    """
    # Idempotent: exist_ok avoids a race with a concurrent creator.
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    return checkpoint_dir
66
+
67
+
68
def _get_checkpoint_subdir_path(main_dir: Path, iteration: int) -> Path:
    """Compute where the checkpoint for a given iteration lives.

    Args:
        main_dir (Path): Main checkpoint directory.
        iteration (int): Iteration number.

    Returns:
        Path: Path to the checkpoint subdirectory.
    """
    return main_dir / _get_checkpoint_subdir_name(iteration)
80
+
81
+
82
def _find_latest_checkpoint_subdir(main_dir: Path) -> Path:
    """Find the latest checkpoint subdirectory by iteration number.

    Args:
        main_dir (Path): Main checkpoint directory.

    Returns:
        Path: Path to the checkpoint subdirectory with the highest iteration.

    Raises:
        CheckpointNotFoundError: If no checkpoint subdirectories are found.
    """
    checkpoint_dirs = [
        d
        for d in main_dir.iterdir()
        if d.is_dir() and _extract_iteration_from_subdir(d.name) is not None
    ]
    if not checkpoint_dirs:
        # Attach the sibling directories so callers can render a helpful hint.
        # (A truncated "first 5" summary string used to be built here but was
        # never used in the message — removed as dead code.)
        available_dirs = [d.name for d in main_dir.iterdir() if d.is_dir()]
        raise CheckpointNotFoundError(
            f"No checkpoint subdirectories found in {main_dir}",
            main_dir=main_dir,
            available_directories=available_dirs,
        )
    # max() avoids sorting the whole list just to take its last element.
    return max(
        checkpoint_dirs, key=lambda d: _extract_iteration_from_subdir(d.name) or -1
    )
112
+
113
+
114
def resolve_checkpoint_path(
    main_dir: Path | str, subdirectory: str | None = None
) -> Path:
    """Resolve the path to a checkpoint subdirectory.

    Args:
        main_dir (Path | str): Main checkpoint directory.
        subdirectory (str | None): Specific checkpoint subdirectory to load
            (e.g., "checkpoint_001"). If None, loads the latest checkpoint
            based on iteration number.

    Returns:
        Path: Path to the checkpoint subdirectory.

    Raises:
        CheckpointNotFoundError: If the main directory or specified
            subdirectory does not exist.
    """
    main_path = Path(main_dir)
    if not main_path.exists():
        raise CheckpointNotFoundError(
            f"Checkpoint directory not found: {main_path}",
            main_dir=main_path,
        )

    if subdirectory is None:
        # No explicit choice: fall back to the newest checkpoint on disk.
        return _find_latest_checkpoint_subdir(main_path)

    checkpoint_path = main_path / subdirectory
    if not checkpoint_path.exists():
        # List the checkpoints that do exist to make the error actionable.
        available = [
            d.name
            for d in main_path.iterdir()
            if d.is_dir() and d.name.startswith(SUBDIR_PREFIX)
        ]
        raise CheckpointNotFoundError(
            f"Checkpoint subdirectory not found: {checkpoint_path}",
            main_dir=main_path,
            available_directories=available,
        )
    return checkpoint_path
158
+
159
+
160
class CheckpointError(Exception):
    """Common base class for all checkpoint-related errors."""
164
+
165
+
166
class CheckpointNotFoundError(CheckpointError):
    """Raised when a checkpoint directory or file is not found.

    Attributes:
        main_dir: Directory that was searched, if known.
        available_directories: Names of directories found instead (may be empty).
    """

    def __init__(
        self,
        message: str,
        main_dir: Path | None = None,
        available_directories: list[str] | None = None,
    ):
        super().__init__(message)
        self.main_dir = main_dir
        # Normalize falsy input to a list so callers can iterate unconditionally.
        self.available_directories = (
            available_directories if available_directories else []
        )
178
+
179
+
180
class CheckpointCorruptedError(CheckpointError):
    """Raised when a checkpoint file is corrupted or invalid.

    Attributes:
        file_path: Offending checkpoint file, if known.
        details: Extra diagnostic text (e.g. the JSON decode error).
    """

    def __init__(
        self, message: str, file_path: Path | None = None, details: str | None = None
    ):
        super().__init__(message)
        self.file_path = file_path
        self.details = details
189
+
190
+
191
def _atomic_write(path: Path, content: str) -> None:
    """Atomically replace *path* with *content* via a temp file and rename.

    An interrupted write therefore never leaves a half-written target file.

    Args:
        path (Path): Target file path.
        content (str): Content to write.

    Raises:
        OSError: If the file cannot be written.
    """
    # The temp file must live next to the target: rename is only atomic
    # within the same filesystem.
    scratch = path.with_suffix(path.suffix + ".tmp")
    try:
        with open(scratch, "w") as handle:
            handle.write(content)
        # Atomic rename on POSIX systems.
        scratch.replace(path)
    except Exception as exc:
        # Best-effort cleanup of the scratch file before surfacing the error.
        if scratch.exists():
            scratch.unlink()
        raise OSError(f"Failed to write checkpoint file {path}: {exc}") from exc
215
+
216
+
217
def _validate_checkpoint_json(
    path: Path, required_fields: list[str] | None = None
) -> dict[str, Any]:
    """Load a checkpoint JSON file and check it for basic integrity.

    Args:
        path (Path): Path to the JSON file.
        required_fields (list[str] | None): Required top-level fields.

    Returns:
        dict[str, Any]: Parsed JSON data.

    Raises:
        CheckpointNotFoundError: If the file does not exist.
        CheckpointCorruptedError: If the file is invalid JSON or missing
            required fields.
    """
    if not path.exists():
        raise CheckpointNotFoundError(
            f"Checkpoint file not found: {path}",
            main_dir=path.parent,
        )

    try:
        with open(path, "r") as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        raise CheckpointCorruptedError(
            f"Checkpoint file is not valid JSON: {path}",
            file_path=path,
            details=f"JSON decode error: {e}",
        ) from e
    except Exception as e:
        # Covers OS-level read failures (permissions, I/O errors, ...).
        raise CheckpointCorruptedError(
            f"Failed to read checkpoint file: {path}",
            file_path=path,
            details=str(e),
        ) from e

    missing_fields = [name for name in (required_fields or []) if name not in data]
    if missing_fields:
        raise CheckpointCorruptedError(
            f"Checkpoint file is missing required fields: {path}",
            file_path=path,
            details=f"Missing fields: {', '.join(missing_fields)}",
        )

    return data
265
+
266
+
267
def _load_and_validate_pydantic_model(
    path: Path,
    model_class: type,
    required_fields: list[str] | None = None,
    error_context: str | None = None,
) -> Any:
    """Load and validate a checkpoint JSON file with a Pydantic model.

    This function combines JSON validation, conversion to string, and Pydantic
    model validation into a single operation.

    Args:
        path (Path): Path to the JSON file.
        model_class (type): Pydantic model class to validate against.
        required_fields (list[str] | None): List of required top-level JSON fields.
        error_context (str | None): Additional context for error messages
            (e.g., "Program state" or "Pymoo optimizer state").

    Returns:
        Any: Validated Pydantic model instance.

    Raises:
        CheckpointNotFoundError: If the file does not exist.
        CheckpointCorruptedError: If the file is invalid JSON, missing
            required fields, or fails Pydantic validation.
    """
    # CheckpointNotFoundError and CheckpointCorruptedError raised by the JSON
    # validation step already carry full context (file path, parent dir,
    # available_directories), so they are allowed to propagate unchanged.
    # Previously the not-found case was caught and re-raised as a fresh
    # exception, which dropped those attributes and broke exception chaining.
    json_data_dict = _validate_checkpoint_json(path, required_fields=required_fields)

    # Convert dict back to a JSON string, which is what Pydantic's
    # model_validate_json expects.
    json_data = json.dumps(json_data_dict)

    try:
        return model_class.model_validate_json(json_data)
    except Exception as e:
        context = f"{error_context} " if error_context else ""
        raise CheckpointCorruptedError(
            f"Failed to validate {context}checkpoint state: {path}",
            file_path=path,
            details=str(e),
        ) from e
315
+
316
+
317
+ @dataclass(frozen=True)
318
+ class CheckpointConfig:
319
+ """Configuration for checkpointing during optimization.
320
+
321
+ Attributes:
322
+ checkpoint_dir (Path | None): Directory path for saving checkpoints.
323
+ - If None: No checkpointing.
324
+ - If Path: Uses that directory.
325
+ checkpoint_interval (int | None): Save checkpoint every N iterations.
326
+ If None, saves every iteration (if checkpoint_dir is set).
327
+ """
328
+
329
+ checkpoint_dir: Path | None = None
330
+ checkpoint_interval: int | None = None
331
+
332
+ @classmethod
333
+ def with_timestamped_dir(
334
+ cls, checkpoint_interval: int | None = None
335
+ ) -> "CheckpointConfig":
336
+ """Create CheckpointConfig with auto-generated directory name.
337
+
338
+ Args:
339
+ checkpoint_interval (int | None): Save checkpoint every N iterations.
340
+ If None, saves every iteration (default).
341
+
342
+ Returns:
343
+ CheckpointConfig: A new CheckpointConfig with auto-generated directory.
344
+ """
345
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
346
+ generated_dir = Path(f"checkpoint_{timestamp}")
347
+ return cls(
348
+ checkpoint_dir=generated_dir, checkpoint_interval=checkpoint_interval
349
+ )
350
+
351
+ def _should_checkpoint(self, iteration: int) -> bool:
352
+ """Determine if a checkpoint should be saved at the given iteration.
353
+
354
+ Args:
355
+ iteration (int): Current iteration number.
356
+
357
+ Returns:
358
+ bool: True if checkpointing is enabled and should occur at this iteration.
359
+ """
360
+ if self.checkpoint_dir is None:
361
+ return False
362
+
363
+ if self.checkpoint_interval is None:
364
+ return True
365
+
366
+ return iteration % self.checkpoint_interval == 0
367
+
368
+
369
@dataclass(frozen=True)
class CheckpointInfo:
    """Immutable summary of a single on-disk checkpoint."""

    # Path to the checkpoint subdirectory.
    path: Path
    # Iteration number of this checkpoint.
    iteration: int
    # Modification time of the checkpoint directory.
    timestamp: datetime
    # Total size of the checkpoint in bytes.
    size_bytes: int
    # Whether the checkpoint is valid (has required files).
    is_valid: bool
386
+
387
+
388
def _calculate_checkpoint_size(checkpoint_path: Path) -> int:
    """Sum the sizes of every file under a checkpoint directory.

    Args:
        checkpoint_path (Path): Path to checkpoint subdirectory.

    Returns:
        int: Total size in bytes (0 when the directory does not exist).
    """
    if not checkpoint_path.exists():
        return 0
    return sum(
        entry.stat().st_size
        for entry in checkpoint_path.rglob("*")
        if entry.is_file()
    )
403
+
404
+
405
def _is_checkpoint_valid(checkpoint_path: Path) -> bool:
    """Check if a checkpoint directory contains the required state files.

    Args:
        checkpoint_path (Path): Path to checkpoint subdirectory.

    Returns:
        bool: True if both the program state and optimizer state files exist.
    """
    required = (PROGRAM_STATE_FILE, OPTIMIZER_STATE_FILE)
    return all((checkpoint_path / name).exists() for name in required)
417
+
418
+
419
def get_checkpoint_info(checkpoint_path: Path) -> CheckpointInfo:
    """Collect metadata about a single checkpoint directory.

    Args:
        checkpoint_path (Path): Path to the checkpoint subdirectory.

    Returns:
        CheckpointInfo: Information about the checkpoint.

    Raises:
        CheckpointNotFoundError: If the checkpoint directory does not exist
            or is not a directory.
        ValueError: If the directory name does not follow the
            "checkpoint_XXX" naming scheme.
    """
    if not checkpoint_path.exists():
        raise CheckpointNotFoundError(
            f"Checkpoint directory not found: {checkpoint_path}",
            main_dir=checkpoint_path.parent,
        )
    if not checkpoint_path.is_dir():
        raise CheckpointNotFoundError(
            f"Checkpoint path is not a directory: {checkpoint_path}",
            main_dir=checkpoint_path.parent,
        )

    iteration = _extract_iteration_from_subdir(checkpoint_path.name)
    if iteration is None:
        raise ValueError(
            f"Invalid checkpoint directory name: {checkpoint_path.name}. "
            f"Expected format: {SUBDIR_PREFIX}XXX"
        )

    # Directory mtime stands in for the checkpoint's creation time.
    modified_at = datetime.fromtimestamp(checkpoint_path.stat().st_mtime)

    return CheckpointInfo(
        path=checkpoint_path,
        iteration=iteration,
        timestamp=modified_at,
        size_bytes=_calculate_checkpoint_size(checkpoint_path),
        is_valid=_is_checkpoint_valid(checkpoint_path),
    )
467
+
468
+
469
def list_checkpoints(main_dir: Path) -> list[CheckpointInfo]:
    """List all checkpoints in a main checkpoint directory.

    Args:
        main_dir (Path): Main checkpoint directory.

    Returns:
        list[CheckpointInfo]: Checkpoint information, sorted by iteration number.

    Raises:
        CheckpointNotFoundError: If the main directory does not exist or is
            not a directory.
    """
    if not main_dir.exists():
        raise CheckpointNotFoundError(
            f"Checkpoint directory not found: {main_dir}",
            main_dir=main_dir,
        )
    if not main_dir.is_dir():
        raise CheckpointNotFoundError(
            f"Path is not a directory: {main_dir}",
            main_dir=main_dir,
        )

    found: list[CheckpointInfo] = []
    for entry in main_dir.iterdir():
        # Only directories following the checkpoint naming scheme count.
        if not entry.is_dir():
            continue
        if _extract_iteration_from_subdir(entry.name) is None:
            continue
        try:
            found.append(get_checkpoint_info(entry))
        except (CheckpointNotFoundError, ValueError):
            # Skip entries that vanished mid-scan or are malformed.
            continue

    return sorted(found, key=lambda info: info.iteration)
512
+
513
+
514
def get_latest_checkpoint(main_dir: Path) -> Path | None:
    """Get the path to the latest checkpoint, if any.

    Args:
        main_dir (Path): Main checkpoint directory.

    Returns:
        Path | None: Path to the latest checkpoint, or None if no checkpoints
            exist (including when the directory itself is missing).
    """
    try:
        latest = _find_latest_checkpoint_subdir(main_dir)
    except CheckpointNotFoundError:
        return None
    return latest
527
+
528
+
529
def cleanup_old_checkpoints(main_dir: Path, keep_last_n: int) -> None:
    """Remove old checkpoints, keeping only the most recent N.

    Args:
        main_dir (Path): Main checkpoint directory.
        keep_last_n (int): Number of most recent checkpoints to keep.

    Raises:
        ValueError: If keep_last_n is less than 1.
        CheckpointNotFoundError: If the main directory does not exist.
    """
    # Hoisted out of the deletion loop: the original re-ran `import shutil`
    # on every iteration, paying a sys.modules lookup per deleted checkpoint
    # for no benefit.
    import shutil

    if keep_last_n < 1:
        raise ValueError("keep_last_n must be at least 1")

    checkpoints = list_checkpoints(main_dir)
    if len(checkpoints) <= keep_last_n:
        return

    # Newest first; everything past the first keep_last_n entries is stale.
    checkpoints.sort(key=lambda x: x.iteration, reverse=True)
    for checkpoint_info in checkpoints[keep_last_n:]:
        # Remove the checkpoint directory and all of its contents.
        shutil.rmtree(checkpoint_info.path)
@@ -0,0 +1,9 @@
1
+ # SPDX-FileCopyrightText: 2025 Qoro Quantum Ltd <divi@qoroquantum.de>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+
6
class _CancelledError(Exception):
    """Internal exception used to signal a task to stop after cancellation."""