qoro-divi 0.2.0b1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- divi/__init__.py +1 -2
- divi/backends/__init__.py +9 -0
- divi/backends/_circuit_runner.py +70 -0
- divi/backends/_execution_result.py +70 -0
- divi/backends/_parallel_simulator.py +486 -0
- divi/backends/_qoro_service.py +663 -0
- divi/backends/_qpu_system.py +101 -0
- divi/backends/_results_processing.py +133 -0
- divi/circuits/__init__.py +8 -0
- divi/{exp/cirq → circuits/_cirq}/__init__.py +1 -2
- divi/circuits/_cirq/_parser.py +110 -0
- divi/circuits/_cirq/_qasm_export.py +78 -0
- divi/circuits/_core.py +369 -0
- divi/{qasm.py → circuits/_qasm_conversion.py} +73 -14
- divi/circuits/_qasm_validation.py +694 -0
- divi/qprog/__init__.py +24 -6
- divi/qprog/_expectation.py +181 -0
- divi/qprog/_hamiltonians.py +281 -0
- divi/qprog/algorithms/__init__.py +14 -0
- divi/qprog/algorithms/_ansatze.py +356 -0
- divi/qprog/algorithms/_qaoa.py +572 -0
- divi/qprog/algorithms/_vqe.py +249 -0
- divi/qprog/batch.py +383 -73
- divi/qprog/checkpointing.py +556 -0
- divi/qprog/exceptions.py +9 -0
- divi/qprog/optimizers.py +1014 -43
- divi/qprog/quantum_program.py +231 -413
- divi/qprog/variational_quantum_algorithm.py +995 -0
- divi/qprog/workflows/__init__.py +10 -0
- divi/qprog/{_graph_partitioning.py → workflows/_graph_partitioning.py} +139 -95
- divi/qprog/workflows/_qubo_partitioning.py +220 -0
- divi/qprog/workflows/_vqe_sweep.py +560 -0
- divi/reporting/__init__.py +7 -0
- divi/reporting/_pbar.py +127 -0
- divi/reporting/_qlogger.py +68 -0
- divi/reporting/_reporter.py +133 -0
- {qoro_divi-0.2.0b1.dist-info → qoro_divi-0.5.0.dist-info}/METADATA +43 -15
- qoro_divi-0.5.0.dist-info/RECORD +43 -0
- {qoro_divi-0.2.0b1.dist-info → qoro_divi-0.5.0.dist-info}/WHEEL +1 -1
- qoro_divi-0.5.0.dist-info/licenses/LICENSES/.license-header +3 -0
- divi/_pbar.py +0 -73
- divi/circuits.py +0 -139
- divi/exp/cirq/_lexer.py +0 -126
- divi/exp/cirq/_parser.py +0 -889
- divi/exp/cirq/_qasm_export.py +0 -37
- divi/exp/cirq/_qasm_import.py +0 -35
- divi/exp/cirq/exception.py +0 -21
- divi/exp/scipy/_cobyla.py +0 -342
- divi/exp/scipy/pyprima/LICENCE.txt +0 -28
- divi/exp/scipy/pyprima/__init__.py +0 -263
- divi/exp/scipy/pyprima/cobyla/__init__.py +0 -0
- divi/exp/scipy/pyprima/cobyla/cobyla.py +0 -599
- divi/exp/scipy/pyprima/cobyla/cobylb.py +0 -849
- divi/exp/scipy/pyprima/cobyla/geometry.py +0 -240
- divi/exp/scipy/pyprima/cobyla/initialize.py +0 -269
- divi/exp/scipy/pyprima/cobyla/trustregion.py +0 -540
- divi/exp/scipy/pyprima/cobyla/update.py +0 -331
- divi/exp/scipy/pyprima/common/__init__.py +0 -0
- divi/exp/scipy/pyprima/common/_bounds.py +0 -41
- divi/exp/scipy/pyprima/common/_linear_constraints.py +0 -46
- divi/exp/scipy/pyprima/common/_nonlinear_constraints.py +0 -64
- divi/exp/scipy/pyprima/common/_project.py +0 -224
- divi/exp/scipy/pyprima/common/checkbreak.py +0 -107
- divi/exp/scipy/pyprima/common/consts.py +0 -48
- divi/exp/scipy/pyprima/common/evaluate.py +0 -101
- divi/exp/scipy/pyprima/common/history.py +0 -39
- divi/exp/scipy/pyprima/common/infos.py +0 -30
- divi/exp/scipy/pyprima/common/linalg.py +0 -452
- divi/exp/scipy/pyprima/common/message.py +0 -336
- divi/exp/scipy/pyprima/common/powalg.py +0 -131
- divi/exp/scipy/pyprima/common/preproc.py +0 -393
- divi/exp/scipy/pyprima/common/present.py +0 -5
- divi/exp/scipy/pyprima/common/ratio.py +0 -56
- divi/exp/scipy/pyprima/common/redrho.py +0 -49
- divi/exp/scipy/pyprima/common/selectx.py +0 -346
- divi/interfaces.py +0 -25
- divi/parallel_simulator.py +0 -258
- divi/qlogger.py +0 -119
- divi/qoro_service.py +0 -343
- divi/qprog/_mlae.py +0 -182
- divi/qprog/_qaoa.py +0 -440
- divi/qprog/_vqe.py +0 -275
- divi/qprog/_vqe_sweep.py +0 -144
- divi/utils.py +0 -116
- qoro_divi-0.2.0b1.dist-info/RECORD +0 -58
- /divi/{qem.py → circuits/qem.py} +0 -0
- {qoro_divi-0.2.0b1.dist-info → qoro_divi-0.5.0.dist-info/licenses}/LICENSE +0 -0
- {qoro_divi-0.2.0b1.dist-info → qoro_divi-0.5.0.dist-info/licenses}/LICENSES/Apache-2.0.txt +0 -0
|
@@ -0,0 +1,556 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2025 Qoro Quantum Ltd <divi@qoroquantum.de>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""Checkpointing utilities for variational quantum algorithms."""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
# Constants for checkpoint file and directory naming
|
|
14
|
+
PROGRAM_STATE_FILE = "program_state.json"
|
|
15
|
+
OPTIMIZER_STATE_FILE = "optimizer_state.json"
|
|
16
|
+
SUBDIR_PREFIX = "checkpoint_"
|
|
17
|
+
|
|
18
|
+
# Maximum reasonable iteration number (prevents parsing errors from corrupted names)
|
|
19
|
+
_MAX_ITERATION_NUMBER = 1_000_000
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _get_checkpoint_subdir_name(iteration: int) -> str:
    """Build the checkpoint subdirectory name for *iteration*.

    Args:
        iteration (int): Iteration number.

    Returns:
        str: Subdirectory name, e.g. ``"checkpoint_001"`` (zero-padded to
        at least three digits).
    """
    padded = format(iteration, "03d")
    return SUBDIR_PREFIX + padded
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _extract_iteration_from_subdir(subdir_name: str) -> int | None:
    """Parse the iteration number out of a checkpoint subdirectory name.

    Args:
        subdir_name (str): Subdirectory name (e.g., "checkpoint_001").

    Returns:
        int | None: The iteration number when the name follows the
        checkpoint naming scheme and is within a sane range, else None.
    """
    if not subdir_name.startswith(SUBDIR_PREFIX):
        return None
    digits = subdir_name[len(SUBDIR_PREFIX) :]
    if not digits.isdigit():
        return None
    parsed = int(digits)
    # Reject absurd values that would indicate a corrupted directory name.
    return parsed if 0 <= parsed <= _MAX_ITERATION_NUMBER else None
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _ensure_checkpoint_dir(checkpoint_dir: Path) -> Path:
|
|
56
|
+
"""Ensure checkpoint directory exists.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
checkpoint_dir (Path): Directory path.
|
|
60
|
+
|
|
61
|
+
Returns:
|
|
62
|
+
Path: The checkpoint directory path.
|
|
63
|
+
"""
|
|
64
|
+
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
|
65
|
+
return checkpoint_dir
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _get_checkpoint_subdir_path(main_dir: Path, iteration: int) -> Path:
    """Resolve the subdirectory path for the checkpoint at *iteration*.

    Args:
        main_dir (Path): Main checkpoint directory.
        iteration (int): Iteration number.

    Returns:
        Path: ``main_dir`` joined with the conventional subdirectory name.
    """
    return main_dir / _get_checkpoint_subdir_name(iteration)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _find_latest_checkpoint_subdir(main_dir: Path) -> Path:
    """Find the latest checkpoint subdirectory by iteration number.

    Args:
        main_dir (Path): Main checkpoint directory.

    Returns:
        Path: Path to the latest checkpoint subdirectory.

    Raises:
        CheckpointNotFoundError: If no checkpoint subdirectories are found.
    """
    checkpoint_dirs = [
        d
        for d in main_dir.iterdir()
        if d.is_dir() and _extract_iteration_from_subdir(d.name) is not None
    ]
    if not checkpoint_dirs:
        # The exception carries the full directory listing for callers; a
        # truncated "first 5" summary string was previously built here but
        # never used anywhere, so it has been removed.
        available_dirs = [d.name for d in main_dir.iterdir() if d.is_dir()]
        raise CheckpointNotFoundError(
            f"No checkpoint subdirectories found in {main_dir}",
            main_dir=main_dir,
            available_directories=available_dirs,
        )
    # Stable sort: among (theoretical) duplicate iteration numbers the last
    # entry in directory order wins, matching the original behavior.
    checkpoint_dirs.sort(key=lambda d: _extract_iteration_from_subdir(d.name) or -1)
    return checkpoint_dirs[-1]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def resolve_checkpoint_path(
    main_dir: Path | str, subdirectory: str | None = None
) -> Path:
    """Resolve which checkpoint subdirectory should be loaded.

    Args:
        main_dir (Path | str): Main checkpoint directory.
        subdirectory (str | None): Explicit checkpoint subdirectory name
            (e.g., "checkpoint_001"). When None, the latest checkpoint by
            iteration number is selected.

    Returns:
        Path: Path to the checkpoint subdirectory.

    Raises:
        CheckpointNotFoundError: If the main directory or specified
            subdirectory does not exist.
    """
    root = Path(main_dir)
    if not root.exists():
        raise CheckpointNotFoundError(
            f"Checkpoint directory not found: {root}",
            main_dir=root,
        )

    # No explicit subdirectory: fall back to the newest one on disk.
    if subdirectory is None:
        return _find_latest_checkpoint_subdir(root)

    target = root / subdirectory
    if not target.exists():
        # List the checkpoints that do exist so the error is actionable.
        present = [
            entry.name
            for entry in root.iterdir()
            if entry.is_dir() and entry.name.startswith(SUBDIR_PREFIX)
        ]
        raise CheckpointNotFoundError(
            f"Checkpoint subdirectory not found: {target}",
            main_dir=root,
            available_directories=present,
        )

    return target
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class CheckpointError(Exception):
    """Root of the checkpoint exception hierarchy."""
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class CheckpointNotFoundError(CheckpointError):
    """Raised when a checkpoint directory or file cannot be located."""

    def __init__(
        self,
        message: str,
        main_dir: Path | None = None,
        available_directories: list[str] | None = None,
    ):
        super().__init__(message)
        # Keep the offending directory and its listing so callers can build
        # richer error reports.
        self.main_dir = main_dir
        self.available_directories = available_directories or []
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class CheckpointCorruptedError(CheckpointError):
    """Raised when a checkpoint file exists but cannot be parsed or validated."""

    def __init__(
        self, message: str, file_path: Path | None = None, details: str | None = None
    ):
        super().__init__(message)
        # Retain the file and a human-readable cause for diagnostics.
        self.file_path = file_path
        self.details = details
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _atomic_write(path: Path, content: str) -> None:
|
|
192
|
+
"""Write content to a file atomically using a temporary file and rename.
|
|
193
|
+
|
|
194
|
+
This ensures that if the write is interrupted, the original file is not corrupted.
|
|
195
|
+
|
|
196
|
+
Args:
|
|
197
|
+
path (Path): Target file path.
|
|
198
|
+
content (str): Content to write.
|
|
199
|
+
|
|
200
|
+
Raises:
|
|
201
|
+
OSError: If the file cannot be written.
|
|
202
|
+
"""
|
|
203
|
+
# Create temporary file in the same directory to ensure atomic rename works
|
|
204
|
+
temp_file = path.with_suffix(path.suffix + ".tmp")
|
|
205
|
+
try:
|
|
206
|
+
with open(temp_file, "w") as f:
|
|
207
|
+
f.write(content)
|
|
208
|
+
# Atomic rename on POSIX systems
|
|
209
|
+
temp_file.replace(path)
|
|
210
|
+
except Exception as e:
|
|
211
|
+
# Clean up temp file if it exists
|
|
212
|
+
if temp_file.exists():
|
|
213
|
+
temp_file.unlink()
|
|
214
|
+
raise OSError(f"Failed to write checkpoint file {path}: {e}") from e
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _validate_checkpoint_json(
|
|
218
|
+
path: Path, required_fields: list[str] | None = None
|
|
219
|
+
) -> dict[str, Any]:
|
|
220
|
+
"""Validate that a checkpoint JSON file exists and is valid.
|
|
221
|
+
|
|
222
|
+
Args:
|
|
223
|
+
path (Path): Path to the JSON file.
|
|
224
|
+
required_fields (list[str] | None): List of required top-level fields.
|
|
225
|
+
|
|
226
|
+
Returns:
|
|
227
|
+
dict[str, Any]: Parsed JSON data.
|
|
228
|
+
|
|
229
|
+
Raises:
|
|
230
|
+
CheckpointNotFoundError: If the file does not exist.
|
|
231
|
+
CheckpointCorruptedError: If the file is invalid JSON or missing required fields.
|
|
232
|
+
"""
|
|
233
|
+
if not path.exists():
|
|
234
|
+
raise CheckpointNotFoundError(
|
|
235
|
+
f"Checkpoint file not found: {path}",
|
|
236
|
+
main_dir=path.parent,
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
try:
|
|
240
|
+
with open(path, "r") as f:
|
|
241
|
+
data = json.load(f)
|
|
242
|
+
except json.JSONDecodeError as e:
|
|
243
|
+
raise CheckpointCorruptedError(
|
|
244
|
+
f"Checkpoint file is not valid JSON: {path}",
|
|
245
|
+
file_path=path,
|
|
246
|
+
details=f"JSON decode error: {e}",
|
|
247
|
+
) from e
|
|
248
|
+
except Exception as e:
|
|
249
|
+
raise CheckpointCorruptedError(
|
|
250
|
+
f"Failed to read checkpoint file: {path}",
|
|
251
|
+
file_path=path,
|
|
252
|
+
details=str(e),
|
|
253
|
+
) from e
|
|
254
|
+
|
|
255
|
+
if required_fields:
|
|
256
|
+
missing_fields = [field for field in required_fields if field not in data]
|
|
257
|
+
if missing_fields:
|
|
258
|
+
raise CheckpointCorruptedError(
|
|
259
|
+
f"Checkpoint file is missing required fields: {path}",
|
|
260
|
+
file_path=path,
|
|
261
|
+
details=f"Missing fields: {', '.join(missing_fields)}",
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
return data
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _load_and_validate_pydantic_model(
    path: Path,
    model_class: type,
    required_fields: list[str] | None = None,
    error_context: str | None = None,
) -> Any:
    """Load and validate a checkpoint JSON file with a Pydantic model.

    This function combines JSON validation, conversion to string, and Pydantic
    model validation into a single operation.

    Args:
        path (Path): Path to the JSON file.
        model_class (type): Pydantic model class to validate against.
        required_fields (list[str] | None): List of required top-level JSON fields.
        error_context (str | None): Additional context for error messages
            (e.g., "Program state" or "Pymoo optimizer state").

    Returns:
        Any: Validated Pydantic model instance.

    Raises:
        CheckpointNotFoundError: If the file does not exist.
        CheckpointCorruptedError: If the file is invalid JSON, missing required
            fields, or fails Pydantic validation.
    """
    # Both checkpoint error types raised by the JSON check already carry the
    # path and full context, so let them propagate unchanged. (Previously a
    # CheckpointNotFoundError was caught and re-raised as a fresh, identical
    # exception, which discarded the original traceback and produced a
    # spurious "during handling" chain.)
    json_data_dict = _validate_checkpoint_json(path, required_fields=required_fields)
    # Convert dict back to a JSON string, which model_validate_json expects.
    json_data = json.dumps(json_data_dict)

    try:
        return model_class.model_validate_json(json_data)
    except Exception as e:
        context = f"{error_context} " if error_context else ""
        raise CheckpointCorruptedError(
            f"Failed to validate {context}checkpoint state: {path}",
            file_path=path,
            details=str(e),
        ) from e
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
@dataclass(frozen=True)
|
|
318
|
+
class CheckpointConfig:
|
|
319
|
+
"""Configuration for checkpointing during optimization.
|
|
320
|
+
|
|
321
|
+
Attributes:
|
|
322
|
+
checkpoint_dir (Path | None): Directory path for saving checkpoints.
|
|
323
|
+
- If None: No checkpointing.
|
|
324
|
+
- If Path: Uses that directory.
|
|
325
|
+
checkpoint_interval (int | None): Save checkpoint every N iterations.
|
|
326
|
+
If None, saves every iteration (if checkpoint_dir is set).
|
|
327
|
+
"""
|
|
328
|
+
|
|
329
|
+
checkpoint_dir: Path | None = None
|
|
330
|
+
checkpoint_interval: int | None = None
|
|
331
|
+
|
|
332
|
+
@classmethod
|
|
333
|
+
def with_timestamped_dir(
|
|
334
|
+
cls, checkpoint_interval: int | None = None
|
|
335
|
+
) -> "CheckpointConfig":
|
|
336
|
+
"""Create CheckpointConfig with auto-generated directory name.
|
|
337
|
+
|
|
338
|
+
Args:
|
|
339
|
+
checkpoint_interval (int | None): Save checkpoint every N iterations.
|
|
340
|
+
If None, saves every iteration (default).
|
|
341
|
+
|
|
342
|
+
Returns:
|
|
343
|
+
CheckpointConfig: A new CheckpointConfig with auto-generated directory.
|
|
344
|
+
"""
|
|
345
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
346
|
+
generated_dir = Path(f"checkpoint_{timestamp}")
|
|
347
|
+
return cls(
|
|
348
|
+
checkpoint_dir=generated_dir, checkpoint_interval=checkpoint_interval
|
|
349
|
+
)
|
|
350
|
+
|
|
351
|
+
def _should_checkpoint(self, iteration: int) -> bool:
|
|
352
|
+
"""Determine if a checkpoint should be saved at the given iteration.
|
|
353
|
+
|
|
354
|
+
Args:
|
|
355
|
+
iteration (int): Current iteration number.
|
|
356
|
+
|
|
357
|
+
Returns:
|
|
358
|
+
bool: True if checkpointing is enabled and should occur at this iteration.
|
|
359
|
+
"""
|
|
360
|
+
if self.checkpoint_dir is None:
|
|
361
|
+
return False
|
|
362
|
+
|
|
363
|
+
if self.checkpoint_interval is None:
|
|
364
|
+
return True
|
|
365
|
+
|
|
366
|
+
return iteration % self.checkpoint_interval == 0
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
@dataclass(frozen=True)
class CheckpointInfo:
    """Immutable summary of a single on-disk checkpoint.

    Attributes:
        path (Path): Path to the checkpoint subdirectory.
        iteration (int): Iteration number of this checkpoint.
        timestamp (datetime): Modification time of the checkpoint directory.
        size_bytes (int): Total size of the checkpoint in bytes.
        is_valid (bool): Whether the checkpoint has the required files.
    """

    path: Path
    iteration: int
    timestamp: datetime
    size_bytes: int
    is_valid: bool
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def _calculate_checkpoint_size(checkpoint_path: Path) -> int:
|
|
389
|
+
"""Calculate total size of a checkpoint directory in bytes.
|
|
390
|
+
|
|
391
|
+
Args:
|
|
392
|
+
checkpoint_path (Path): Path to checkpoint subdirectory.
|
|
393
|
+
|
|
394
|
+
Returns:
|
|
395
|
+
int: Total size in bytes.
|
|
396
|
+
"""
|
|
397
|
+
total_size = 0
|
|
398
|
+
if checkpoint_path.exists():
|
|
399
|
+
for file_path in checkpoint_path.rglob("*"):
|
|
400
|
+
if file_path.is_file():
|
|
401
|
+
total_size += file_path.stat().st_size
|
|
402
|
+
return total_size
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
def _is_checkpoint_valid(checkpoint_path: Path) -> bool:
    """Tell whether a checkpoint directory holds both required state files.

    Args:
        checkpoint_path (Path): Path to checkpoint subdirectory.

    Returns:
        bool: True if both the program and optimizer state files exist.
    """
    required = (PROGRAM_STATE_FILE, OPTIMIZER_STATE_FILE)
    return all((checkpoint_path / name).exists() for name in required)
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def get_checkpoint_info(checkpoint_path: Path) -> CheckpointInfo:
    """Collect metadata about one checkpoint subdirectory.

    Args:
        checkpoint_path (Path): Path to the checkpoint subdirectory.

    Returns:
        CheckpointInfo: Information about the checkpoint.

    Raises:
        CheckpointNotFoundError: If the checkpoint directory does not exist
            or is not a directory.
        ValueError: If the directory name does not follow the checkpoint
            naming scheme.
    """
    if not checkpoint_path.exists():
        raise CheckpointNotFoundError(
            f"Checkpoint directory not found: {checkpoint_path}",
            main_dir=checkpoint_path.parent,
        )

    if not checkpoint_path.is_dir():
        raise CheckpointNotFoundError(
            f"Checkpoint path is not a directory: {checkpoint_path}",
            main_dir=checkpoint_path.parent,
        )

    iteration = _extract_iteration_from_subdir(checkpoint_path.name)
    if iteration is None:
        raise ValueError(
            f"Invalid checkpoint directory name: {checkpoint_path.name}. "
            f"Expected format: {SUBDIR_PREFIX}XXX"
        )

    # The directory's mtime doubles as the checkpoint's timestamp.
    return CheckpointInfo(
        path=checkpoint_path,
        iteration=iteration,
        timestamp=datetime.fromtimestamp(checkpoint_path.stat().st_mtime),
        size_bytes=_calculate_checkpoint_size(checkpoint_path),
        is_valid=_is_checkpoint_valid(checkpoint_path),
    )
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def list_checkpoints(main_dir: Path) -> list[CheckpointInfo]:
    """Enumerate every checkpoint under *main_dir*, ordered by iteration.

    Args:
        main_dir (Path): Main checkpoint directory.

    Returns:
        list[CheckpointInfo]: Checkpoint information, sorted by iteration
        number.

    Raises:
        CheckpointNotFoundError: If the main directory does not exist or is
            not a directory.
    """
    if not main_dir.exists():
        raise CheckpointNotFoundError(
            f"Checkpoint directory not found: {main_dir}",
            main_dir=main_dir,
        )

    if not main_dir.is_dir():
        raise CheckpointNotFoundError(
            f"Path is not a directory: {main_dir}",
            main_dir=main_dir,
        )

    found: list[CheckpointInfo] = []
    for entry in main_dir.iterdir():
        # Only directories following the checkpoint naming scheme count.
        if not entry.is_dir() or _extract_iteration_from_subdir(entry.name) is None:
            continue
        try:
            found.append(get_checkpoint_info(entry))
        except (CheckpointNotFoundError, ValueError):
            # Skip entries that vanished or turned out to be malformed.
            continue

    return sorted(found, key=lambda info: info.iteration)
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def get_latest_checkpoint(main_dir: Path) -> Path | None:
    """Return the newest checkpoint path, or None when none exist.

    Args:
        main_dir (Path): Main checkpoint directory.

    Returns:
        Path | None: Path to the latest checkpoint, or None if no
        checkpoints exist.
    """
    try:
        latest = _find_latest_checkpoint_subdir(main_dir)
    except CheckpointNotFoundError:
        return None
    return latest
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def cleanup_old_checkpoints(main_dir: Path, keep_last_n: int) -> None:
    """Remove old checkpoints, keeping only the most recent N.

    Args:
        main_dir (Path): Main checkpoint directory.
        keep_last_n (int): Number of most recent checkpoints to keep.

    Raises:
        ValueError: If keep_last_n is less than 1.
        CheckpointNotFoundError: If the main directory does not exist.
    """
    # Import once at function scope; the original re-ran the import statement
    # inside the deletion loop on every iteration.
    import shutil

    if keep_last_n < 1:
        raise ValueError("keep_last_n must be at least 1")

    checkpoints = list_checkpoints(main_dir)

    if len(checkpoints) <= keep_last_n:
        return

    # Sort newest-first; everything past the first keep_last_n entries goes.
    checkpoints.sort(key=lambda x: x.iteration, reverse=True)
    for stale in checkpoints[keep_last_n:]:
        # Remove the checkpoint directory and all its contents.
        shutil.rmtree(stale.path)
|
divi/qprog/exceptions.py
ADDED