quantumflow-sdk 0.3.0-py3-none-any.whl → 0.4.0-py3-none-any.whl
This diff represents the content of publicly released package versions as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two published versions.
- api/main.py +28 -1
- api/models.py +41 -0
- api/routes/algorithm_routes.py +134 -0
- api/routes/chat_routes.py +565 -0
- api/routes/pipeline_routes.py +578 -0
- db/models.py +357 -0
- quantumflow/algorithms/machine_learning/__init__.py +14 -2
- quantumflow/algorithms/machine_learning/vqe.py +355 -3
- quantumflow/core/__init__.py +10 -1
- quantumflow/core/quantum_compressor.py +379 -1
- quantumflow/integrations/domain_agents.py +617 -0
- quantumflow/pipeline/__init__.py +29 -0
- quantumflow/pipeline/anomaly_detector.py +521 -0
- quantumflow/pipeline/base_pipeline.py +602 -0
- quantumflow/pipeline/checkpoint_manager.py +587 -0
- quantumflow/pipeline/finance/__init__.py +5 -0
- quantumflow/pipeline/finance/portfolio_optimization.py +595 -0
- quantumflow/pipeline/healthcare/__init__.py +5 -0
- quantumflow/pipeline/healthcare/protein_folding.py +994 -0
- quantumflow/pipeline/temporal_memory.py +577 -0
- {quantumflow_sdk-0.3.0.dist-info → quantumflow_sdk-0.4.0.dist-info}/METADATA +3 -3
- {quantumflow_sdk-0.3.0.dist-info → quantumflow_sdk-0.4.0.dist-info}/RECORD +25 -13
- {quantumflow_sdk-0.3.0.dist-info → quantumflow_sdk-0.4.0.dist-info}/WHEEL +0 -0
- {quantumflow_sdk-0.3.0.dist-info → quantumflow_sdk-0.4.0.dist-info}/entry_points.txt +0 -0
- {quantumflow_sdk-0.3.0.dist-info → quantumflow_sdk-0.4.0.dist-info}/top_level.txt +0 -0
quantumflow/pipeline/checkpoint_manager.py
@@ -0,0 +1,587 @@
"""
Checkpoint Manager for Pipeline State Persistence.

Provides:
- Save/load checkpoints with SHA-256 integrity verification
- Optional quantum compression for large state vectors
- Automatic pruning of old checkpoints
- Database persistence

Example:
    manager = CheckpointManager(backend="simulator")

    # Save checkpoint
    manager.save(
        pipeline_id="...",
        step=100,
        state_data={"weights": [...], "energy": -1.5},
        metrics={"loss": 0.01},
    )

    # Load checkpoint
    checkpoint = manager.load(pipeline_id, step=100)

    # Get latest valid checkpoint
    latest = manager.get_latest_valid(pipeline_id)
"""

import hashlib
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
import uuid

logger = logging.getLogger(__name__)


class CheckpointManager:
    """
    Manages checkpoint persistence for pipelines.

    Supports:
    - State serialization with SHA-256 integrity hashes
    - Optional quantum compression for large states
    - Automatic pruning to keep N most recent
    - In-memory cache for fast access
    """

    def __init__(
        self,
        backend: str = "simulator",
        use_database: bool = True,
        cache_size: int = 10,
    ):
        """
        Initialize checkpoint manager.

        Args:
            backend: Quantum backend for compression
            use_database: Whether to persist to database
            cache_size: Number of checkpoints to cache in memory
        """
        self.backend = backend
        self.use_database = use_database
        self.cache_size = cache_size

        # In-memory storage (for non-DB mode or caching)
        self._checkpoints: Dict[str, Dict[int, Dict[str, Any]]] = {}
        self._cache: Dict[str, Dict[int, Dict[str, Any]]] = {}

        # Quantum compressor (lazy loaded)
        self._compressor = None

    def _get_compressor(self):
        """Get or create quantum compressor."""
        if self._compressor is None:
            try:
                from quantumflow.core.quantum_compressor import QuantumCompressor
                self._compressor = QuantumCompressor(backend=self.backend)
            except ImportError:
                logger.warning("QuantumCompressor not available, compression disabled")
        return self._compressor

    def _compute_hash(self, data: Dict[str, Any]) -> str:
        """Compute SHA-256 hash of state data."""
        serialized = json.dumps(data, sort_keys=True, default=str)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def _serialize_state(self, state_data: Dict[str, Any]) -> str:
        """Serialize state data to JSON string."""
        return json.dumps(state_data, default=self._json_serializer)

    def _deserialize_state(self, state_json: str) -> Dict[str, Any]:
        """Deserialize JSON string to state data."""
        return json.loads(state_json)

    @staticmethod
    def _json_serializer(obj):
        """Custom JSON serializer for complex types."""
        if hasattr(obj, "tolist"):  # numpy arrays
            return obj.tolist()
        if hasattr(obj, "isoformat"):  # datetime
            return obj.isoformat()
        if hasattr(obj, "__dict__"):
            return obj.__dict__
        return str(obj)

    def _compress_state(self, state_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Quantum-compress state data.

        Args:
            state_data: State to compress

        Returns:
            Compressed state or None if compression not available
        """
        compressor = self._get_compressor()
        if not compressor:
            return None

        try:
            # Extract numeric data for compression
            vectors = []
            vector_keys = []

            for key, value in state_data.items():
                if isinstance(value, (list, tuple)) and all(
                    isinstance(v, (int, float)) for v in value
                ):
                    vectors.extend(value)
                    vector_keys.append((key, len(value)))

            if not vectors:
                return None

            # Compress
            compressed = compressor.compress(vectors)

            return {
                "amplitudes": compressed.amplitudes.tolist() if hasattr(compressed.amplitudes, "tolist") else list(compressed.amplitudes),
                "n_qubits": compressed.n_qubits,
                "vector_keys": vector_keys,
                "compression_ratio": compressed.compression_percentage / 100,
            }

        except Exception as e:
            logger.warning(f"Compression failed: {e}")
            return None

    def _decompress_state(
        self, compressed: Dict[str, Any], original_keys: List[tuple]
    ) -> Dict[str, Any]:
        """
        Decompress quantum-compressed state.

        Args:
            compressed: Compressed state data
            original_keys: List of (key, length) tuples

        Returns:
            Decompressed state data
        """
        compressor = self._get_compressor()
        if not compressor:
            return {}

        try:
            from quantumflow.core.quantum_compressor import CompressedState

            cs = CompressedState(
                amplitudes=compressed["amplitudes"],
                n_qubits=compressed["n_qubits"],
                original_length=sum(length for _, length in original_keys),
                compression_percentage=compressed.get("compression_ratio", 0) * 100,
                input_token_count=0,
            )

            decompressed = compressor.decompress(cs)

            # Reconstruct dictionary
            result = {}
            offset = 0
            for key, length in original_keys:
                result[key] = decompressed[offset : offset + length]
                offset += length

            return result

        except Exception as e:
            logger.warning(f"Decompression failed: {e}")
            return {}

    def save(
        self,
        pipeline_id: str,
        step: int,
        state_data: Dict[str, Any],
        metrics: Optional[Dict[str, float]] = None,
        name: Optional[str] = None,
        use_quantum_compression: bool = False,
    ) -> str:
        """
        Save a checkpoint.

        Args:
            pipeline_id: Pipeline identifier
            step: Step number
            state_data: State data to save
            metrics: Optional metrics at checkpoint
            name: Optional checkpoint name
            use_quantum_compression: Whether to use quantum compression

        Returns:
            Checkpoint ID
        """
        checkpoint_id = str(uuid.uuid4())

        # Compute integrity hash
        state_hash = self._compute_hash(state_data)

        # Serialize state
        state_json = self._serialize_state(state_data)
        state_size = len(state_json.encode())

        # Quantum compression (optional)
        compressed_state = None
        compression_ratio = None
        if use_quantum_compression:
            compressed = self._compress_state(state_data)
            if compressed:
                compressed_state = compressed
                compression_ratio = compressed.get("compression_ratio")

        checkpoint = {
            "id": checkpoint_id,
            "pipeline_id": pipeline_id,
            "step_number": step,
            "checkpoint_name": name,
            "state_data": state_data,
            "state_hash": state_hash,
            "compressed_state": compressed_state,
            "compression_ratio": compression_ratio,
            "metrics": metrics or {},
            "is_valid": True,
            "validation_error": None,
            "created_at": datetime.utcnow().isoformat(),
            "state_size_bytes": state_size,
        }

        # Store in memory
        if pipeline_id not in self._checkpoints:
            self._checkpoints[pipeline_id] = {}
        self._checkpoints[pipeline_id][step] = checkpoint

        # Update cache
        if pipeline_id not in self._cache:
            self._cache[pipeline_id] = {}
        self._cache[pipeline_id][step] = checkpoint

        # Database persistence
        if self.use_database:
            self._save_to_database(checkpoint)

        logger.debug(f"Checkpoint saved: pipeline={pipeline_id}, step={step}")

        return checkpoint_id

    def _save_to_database(self, checkpoint: Dict[str, Any]):
        """Save checkpoint to database."""
        try:
            from db.database import get_session
            from db.models import Checkpoint as CheckpointModel

            with get_session() as session:
                db_checkpoint = CheckpointModel(
                    id=uuid.UUID(checkpoint["id"]),
                    pipeline_id=uuid.UUID(checkpoint["pipeline_id"]),
                    step_number=checkpoint["step_number"],
                    checkpoint_name=checkpoint["checkpoint_name"],
                    state_data=checkpoint["state_data"],
                    state_hash=checkpoint["state_hash"],
                    compressed_state=checkpoint["compressed_state"],
                    compression_ratio=checkpoint["compression_ratio"],
                    metrics=checkpoint["metrics"],
                    is_valid=checkpoint["is_valid"],
                    state_size_bytes=checkpoint["state_size_bytes"],
                )
                session.add(db_checkpoint)
                session.commit()

        except Exception as e:
            logger.warning(f"Database save failed, using in-memory only: {e}")

    def load(
        self, pipeline_id: str, step: int, verify_integrity: bool = True
    ) -> Optional[Dict[str, Any]]:
        """
        Load a checkpoint.

        Args:
            pipeline_id: Pipeline identifier
            step: Step number
            verify_integrity: Whether to verify hash integrity

        Returns:
            Checkpoint data or None if not found
        """
        # Check cache first
        if pipeline_id in self._cache and step in self._cache[pipeline_id]:
            checkpoint = self._cache[pipeline_id][step]
            if verify_integrity and not self._verify_integrity(checkpoint):
                return None
            return checkpoint

        # Check in-memory storage
        if pipeline_id in self._checkpoints and step in self._checkpoints[pipeline_id]:
            checkpoint = self._checkpoints[pipeline_id][step]
            if verify_integrity and not self._verify_integrity(checkpoint):
                return None
            return checkpoint

        # Load from database
        if self.use_database:
            checkpoint = self._load_from_database(pipeline_id, step)
            if checkpoint:
                if verify_integrity and not self._verify_integrity(checkpoint):
                    return None

                # Update cache
                if pipeline_id not in self._cache:
                    self._cache[pipeline_id] = {}
                self._cache[pipeline_id][step] = checkpoint

                return checkpoint

        return None

    def _load_from_database(
        self, pipeline_id: str, step: int
    ) -> Optional[Dict[str, Any]]:
        """Load checkpoint from database."""
        try:
            from db.database import get_session
            from db.models import Checkpoint as CheckpointModel

            with get_session() as session:
                checkpoint = (
                    session.query(CheckpointModel)
                    .filter(
                        CheckpointModel.pipeline_id == uuid.UUID(pipeline_id),
                        CheckpointModel.step_number == step,
                    )
                    .first()
                )

                if checkpoint:
                    return {
                        "id": str(checkpoint.id),
                        "pipeline_id": str(checkpoint.pipeline_id),
                        "step_number": checkpoint.step_number,
                        "checkpoint_name": checkpoint.checkpoint_name,
                        "state_data": checkpoint.state_data,
                        "state_hash": checkpoint.state_hash,
                        "compressed_state": checkpoint.compressed_state,
                        "compression_ratio": checkpoint.compression_ratio,
                        "metrics": checkpoint.metrics,
                        "is_valid": checkpoint.is_valid,
                        "validation_error": checkpoint.validation_error,
                        "created_at": checkpoint.created_at.isoformat(),
                        "state_size_bytes": checkpoint.state_size_bytes,
                    }

        except Exception as e:
            logger.warning(f"Database load failed: {e}")

        return None

    def _verify_integrity(self, checkpoint: Dict[str, Any]) -> bool:
        """Verify checkpoint integrity using hash."""
        computed_hash = self._compute_hash(checkpoint["state_data"])
        is_valid = computed_hash == checkpoint["state_hash"]

        if not is_valid:
            logger.warning(
                f"Checkpoint integrity check failed: step={checkpoint['step_number']}"
            )

        return is_valid

    def get_latest_valid(self, pipeline_id: str) -> Optional[Dict[str, Any]]:
        """
        Get the most recent valid checkpoint.

        Args:
            pipeline_id: Pipeline identifier

        Returns:
            Latest valid checkpoint or None
        """
        # Get all checkpoints for pipeline
        checkpoints = self.list_checkpoints(pipeline_id)

        if not checkpoints:
            return None

        # Sort by step descending
        sorted_checkpoints = sorted(checkpoints, key=lambda c: c["step_number"], reverse=True)

        # Find first valid one
        for cp in sorted_checkpoints:
            if cp["is_valid"] and self._verify_integrity(cp):
                return cp

        return None

    def list_checkpoints(
        self, pipeline_id: str, limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        List all checkpoints for a pipeline.

        Args:
            pipeline_id: Pipeline identifier
            limit: Maximum number to return

        Returns:
            List of checkpoints sorted by step descending
        """
        checkpoints = []

        # From in-memory
        if pipeline_id in self._checkpoints:
            checkpoints = list(self._checkpoints[pipeline_id].values())

        # From database if empty
        if not checkpoints and self.use_database:
            checkpoints = self._list_from_database(pipeline_id, limit)

        # Sort by step descending
        checkpoints.sort(key=lambda c: c["step_number"], reverse=True)

        if limit:
            checkpoints = checkpoints[:limit]

        return checkpoints

    def _list_from_database(
        self, pipeline_id: str, limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """List checkpoints from database."""
        try:
            from db.database import get_session
            from db.models import Checkpoint as CheckpointModel

            with get_session() as session:
                query = (
                    session.query(CheckpointModel)
                    .filter(CheckpointModel.pipeline_id == uuid.UUID(pipeline_id))
                    .order_by(CheckpointModel.step_number.desc())
                )

                if limit:
                    query = query.limit(limit)

                return [
                    {
                        "id": str(cp.id),
                        "pipeline_id": str(cp.pipeline_id),
                        "step_number": cp.step_number,
                        "checkpoint_name": cp.checkpoint_name,
                        "state_data": cp.state_data,
                        "state_hash": cp.state_hash,
                        "compressed_state": cp.compressed_state,
                        "compression_ratio": cp.compression_ratio,
                        "metrics": cp.metrics,
                        "is_valid": cp.is_valid,
                        "validation_error": cp.validation_error,
                        "created_at": cp.created_at.isoformat(),
                        "state_size_bytes": cp.state_size_bytes,
                    }
                    for cp in query.all()
                ]

        except Exception as e:
            logger.warning(f"Database list failed: {e}")
            return []

    def prune(self, pipeline_id: str, keep_count: int):
        """
        Prune old checkpoints, keeping only the most recent N.

        Args:
            pipeline_id: Pipeline identifier
            keep_count: Number of checkpoints to keep
        """
        checkpoints = self.list_checkpoints(pipeline_id)

        if len(checkpoints) <= keep_count:
            return

        # Checkpoints to delete (oldest ones)
        to_delete = checkpoints[keep_count:]

        for cp in to_delete:
            step = cp["step_number"]

            # Remove from in-memory
            if pipeline_id in self._checkpoints and step in self._checkpoints[pipeline_id]:
                del self._checkpoints[pipeline_id][step]

            # Remove from cache
            if pipeline_id in self._cache and step in self._cache[pipeline_id]:
                del self._cache[pipeline_id][step]

            # Remove from database
            if self.use_database:
                self._delete_from_database(pipeline_id, step)

        logger.debug(f"Pruned {len(to_delete)} old checkpoints for pipeline {pipeline_id}")

    def _delete_from_database(self, pipeline_id: str, step: int):
        """Delete checkpoint from database."""
        try:
            from db.database import get_session
            from db.models import Checkpoint as CheckpointModel

            with get_session() as session:
                session.query(CheckpointModel).filter(
                    CheckpointModel.pipeline_id == uuid.UUID(pipeline_id),
                    CheckpointModel.step_number == step,
                ).delete()
                session.commit()

        except Exception as e:
            logger.warning(f"Database delete failed: {e}")

    def invalidate(self, pipeline_id: str, step: int, reason: str):
        """
        Mark a checkpoint as invalid.

        Args:
            pipeline_id: Pipeline identifier
            step: Step number
            reason: Reason for invalidation
        """
        checkpoint = self.load(pipeline_id, step, verify_integrity=False)

        if checkpoint:
            checkpoint["is_valid"] = False
            checkpoint["validation_error"] = reason

            # Update in-memory
            if pipeline_id in self._checkpoints and step in self._checkpoints[pipeline_id]:
                self._checkpoints[pipeline_id][step] = checkpoint

            # Update database
            if self.use_database:
                self._update_validity_in_database(pipeline_id, step, False, reason)

            logger.info(f"Checkpoint invalidated: step={step}, reason={reason}")

    def _update_validity_in_database(
        self, pipeline_id: str, step: int, is_valid: bool, error: Optional[str]
    ):
        """Update checkpoint validity in database."""
        try:
            from db.database import get_session
            from db.models import Checkpoint as CheckpointModel

            with get_session() as session:
                session.query(CheckpointModel).filter(
                    CheckpointModel.pipeline_id == uuid.UUID(pipeline_id),
                    CheckpointModel.step_number == step,
                ).update({"is_valid": is_valid, "validation_error": error})
                session.commit()

        except Exception as e:
            logger.warning(f"Database validity update failed: {e}")

    def clear_cache(self, pipeline_id: Optional[str] = None):
        """Clear checkpoint cache."""
        if pipeline_id:
            if pipeline_id in self._cache:
                del self._cache[pipeline_id]
        else:
            self._cache.clear()
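
Illustrative usage (not part of the released file): a minimal sketch exercising the CheckpointManager API added in this version, run entirely in memory with use_database=False so the optional db.database / db.models dependencies are never touched. The import path is assumed from the file location quantumflow/pipeline/checkpoint_manager.py.

# Hypothetical usage sketch; assumes the package is importable as `quantumflow`.
from quantumflow.pipeline.checkpoint_manager import CheckpointManager

# In-memory only: no database session is opened.
manager = CheckpointManager(backend="simulator", use_database=False)

pipeline_id = "demo-pipeline"

# Save a few checkpoints; each call returns a UUID string.
for step in (10, 20, 30):
    manager.save(
        pipeline_id=pipeline_id,
        step=step,
        state_data={"weights": [0.1 * step, 0.2 * step], "energy": -1.5},
        metrics={"loss": 1.0 / step},
    )

# Load with integrity verification (SHA-256 over the state data).
checkpoint = manager.load(pipeline_id, step=20)
assert checkpoint is not None and checkpoint["is_valid"]

# Tampering with the stored state makes the hash check fail on the next load.
checkpoint["state_data"]["energy"] = 0.0
assert manager.load(pipeline_id, step=20) is None

# The most recent checkpoint that still passes verification is step 30.
latest = manager.get_latest_valid(pipeline_id)
assert latest["step_number"] == 30

# Keep only the two most recent checkpoints; step 10 is pruned.
manager.prune(pipeline_id, keep_count=2)
assert len(manager.list_checkpoints(pipeline_id)) == 2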