quantumflow-sdk 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,587 @@
+ """
+ Checkpoint Manager for Pipeline State Persistence.
+
+ Provides:
+ - Save/load checkpoints with SHA-256 integrity verification
+ - Optional quantum compression for large state vectors
+ - Automatic pruning of old checkpoints
+ - Database persistence
+
+ Example:
+     manager = CheckpointManager(backend="simulator")
+
+     # Save checkpoint
+     manager.save(
+         pipeline_id="...",
+         step=100,
+         state_data={"weights": [...], "energy": -1.5},
+         metrics={"loss": 0.01},
+     )
+
+     # Load checkpoint
+     checkpoint = manager.load(pipeline_id, step=100)
+
+     # Get latest valid checkpoint
+     latest = manager.get_latest_valid(pipeline_id)
+ """
+
+ import hashlib
+ import json
+ import logging
+ import uuid
+ from datetime import datetime, timezone
+ from typing import Any, Dict, List, Optional, Tuple
+
+ logger = logging.getLogger(__name__)
+
+
+ class CheckpointManager:
+     """
+     Manages checkpoint persistence for pipelines.
+
+     Supports:
+     - State serialization with SHA-256 integrity hashes
+     - Optional quantum compression for large states
+     - Automatic pruning to keep the N most recent
+     - In-memory cache for fast access
+     """
+
+     def __init__(
+         self,
+         backend: str = "simulator",
+         use_database: bool = True,
+         cache_size: int = 10,
+     ):
+         """
+         Initialize checkpoint manager.
+
+         Args:
+             backend: Quantum backend for compression
+             use_database: Whether to persist to database
+             cache_size: Maximum number of checkpoints to cache in memory per pipeline
+         """
+         self.backend = backend
+         self.use_database = use_database
+         self.cache_size = cache_size
+
+         # In-memory storage (for non-DB mode or caching)
+         self._checkpoints: Dict[str, Dict[int, Dict[str, Any]]] = {}
+         self._cache: Dict[str, Dict[int, Dict[str, Any]]] = {}
+
+         # Quantum compressor (lazily loaded)
+         self._compressor = None
+
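With use_database=False the manager runs entirely in process memory, which is convenient for tests and short-lived runs (database failures already degrade to in-memory silently). A minimal sketch:

    # In-memory only: no DB writes, cache at most 5 checkpoints per pipeline
    manager = CheckpointManager(backend="simulator", use_database=False, cache_size=5)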
+     def _get_compressor(self):
+         """Get or create quantum compressor."""
+         if self._compressor is None:
+             try:
+                 from quantumflow.core.quantum_compressor import QuantumCompressor
+                 self._compressor = QuantumCompressor(backend=self.backend)
+             except ImportError:
+                 logger.warning("QuantumCompressor not available, compression disabled")
+         return self._compressor
+
+     def _compute_hash(self, data: Dict[str, Any]) -> str:
+         """Compute SHA-256 hash of state data."""
+         serialized = json.dumps(data, sort_keys=True, default=str)
+         return hashlib.sha256(serialized.encode()).hexdigest()
+
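Because _compute_hash serializes with sort_keys=True, the digest is canonical: logically equal dicts hash identically regardless of key insertion order. A quick sketch of the invariant:

    import hashlib
    import json

    def canonical_hash(data):
        serialized = json.dumps(data, sort_keys=True, default=str)
        return hashlib.sha256(serialized.encode()).hexdigest()

    # Key order does not change the digest
    assert canonical_hash({"a": 1, "b": 2}) == canonical_hash({"b": 2, "a": 1})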
+     def _serialize_state(self, state_data: Dict[str, Any]) -> str:
+         """Serialize state data to JSON string."""
+         return json.dumps(state_data, default=self._json_serializer)
+
+     def _deserialize_state(self, state_json: str) -> Dict[str, Any]:
+         """Deserialize JSON string to state data."""
+         return json.loads(state_json)
+
+     @staticmethod
+     def _json_serializer(obj):
+         """Custom JSON serializer for complex types."""
+         if hasattr(obj, "tolist"):  # numpy arrays
+             return obj.tolist()
+         if hasattr(obj, "isoformat"):  # datetime
+             return obj.isoformat()
+         if hasattr(obj, "__dict__"):
+             return obj.__dict__
+         return str(obj)
+
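The serializer is duck-typed: anything with a tolist() (numpy arrays) or isoformat() (datetimes) is converted to plain JSON types, and unknown objects fall back to their __dict__ or str(). A sketch of what that produces, assuming numpy is installed:

    import json
    from datetime import datetime

    import numpy as np

    state = {"weights": np.array([0.1, 0.2]), "saved_at": datetime(2024, 1, 1)}
    encoded = json.dumps(state, default=CheckpointManager._json_serializer)
    print(encoded)  # {"weights": [0.1, 0.2], "saved_at": "2024-01-01T00:00:00"}

Note the round trip is lossy in type: arrays come back from _deserialize_state as plain lists, not numpy arrays.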
+     def _compress_state(self, state_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+         """
+         Quantum-compress state data.
+
+         Args:
+             state_data: State to compress
+
+         Returns:
+             Compressed state or None if compression not available
+         """
+         compressor = self._get_compressor()
+         if not compressor:
+             return None
+
+         try:
+             # Extract numeric data for compression
+             vectors = []
+             vector_keys = []
+
+             for key, value in state_data.items():
+                 if isinstance(value, (list, tuple)) and all(
+                     isinstance(v, (int, float)) for v in value
+                 ):
+                     vectors.extend(value)
+                     vector_keys.append((key, len(value)))
+
+             if not vectors:
+                 return None
+
+             # Compress
+             compressed = compressor.compress(vectors)
+
+             return {
+                 "amplitudes": (
+                     compressed.amplitudes.tolist()
+                     if hasattr(compressed.amplitudes, "tolist")
+                     else list(compressed.amplitudes)
+                 ),
+                 "n_qubits": compressed.n_qubits,
+                 "vector_keys": vector_keys,
+                 "compression_ratio": compressed.compression_percentage / 100,
+             }
+
+         except Exception as e:
+             logger.warning(f"Compression failed: {e}")
+             return None
+
+     def _decompress_state(
+         self, compressed: Dict[str, Any], original_keys: List[Tuple[str, int]]
+     ) -> Dict[str, Any]:
+         """
+         Decompress quantum-compressed state.
+
+         Args:
+             compressed: Compressed state data
+             original_keys: List of (key, length) tuples
+
+         Returns:
+             Decompressed state data
+         """
+         compressor = self._get_compressor()
+         if not compressor:
+             return {}
+
+         try:
+             from quantumflow.core.quantum_compressor import CompressedState
+
+             cs = CompressedState(
+                 amplitudes=compressed["amplitudes"],
+                 n_qubits=compressed["n_qubits"],
+                 original_length=sum(length for _, length in original_keys),
+                 compression_percentage=compressed.get("compression_ratio", 0) * 100,
+                 input_token_count=0,
+             )
+
+             decompressed = compressor.decompress(cs)
+
+             # Reconstruct dictionary
+             result = {}
+             offset = 0
+             for key, length in original_keys:
+                 result[key] = decompressed[offset : offset + length]
+                 offset += length
+
+             return result
+
+         except Exception as e:
+             logger.warning(f"Decompression failed: {e}")
+             return {}
+
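Compression concatenates every flat numeric list in the state into one vector and records (key, length) pairs; decompression slices that vector back into the original keys. The bookkeeping on its own, without any compressor involved:

    def flatten(state):
        vector, keys = [], []
        for key, value in state.items():
            if isinstance(value, (list, tuple)) and all(isinstance(v, (int, float)) for v in value):
                vector.extend(value)
                keys.append((key, len(value)))
        return vector, keys

    def unflatten(vector, keys):
        result, offset = {}, 0
        for key, length in keys:
            result[key] = vector[offset : offset + length]
            offset += length
        return result

    state = {"weights": [0.1, 0.2, 0.3], "bias": [0.5]}
    vector, keys = flatten(state)
    assert unflatten(vector, keys) == state

This is also why only flat numeric lists are compressed: nested or non-numeric values stay in state_data untouched.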
+     def save(
+         self,
+         pipeline_id: str,
+         step: int,
+         state_data: Dict[str, Any],
+         metrics: Optional[Dict[str, float]] = None,
+         name: Optional[str] = None,
+         use_quantum_compression: bool = False,
+     ) -> str:
+         """
+         Save a checkpoint.
+
+         Args:
+             pipeline_id: Pipeline identifier
+             step: Step number
+             state_data: State data to save
+             metrics: Optional metrics at checkpoint
+             name: Optional checkpoint name
+             use_quantum_compression: Whether to use quantum compression
+
+         Returns:
+             Checkpoint ID
+         """
+         checkpoint_id = str(uuid.uuid4())
+
+         # Compute integrity hash
+         state_hash = self._compute_hash(state_data)
+
+         # Serialize state
+         state_json = self._serialize_state(state_data)
+         state_size = len(state_json.encode())
+
+         # Quantum compression (optional)
+         compressed_state = None
+         compression_ratio = None
+         if use_quantum_compression:
+             compressed = self._compress_state(state_data)
+             if compressed:
+                 compressed_state = compressed
+                 compression_ratio = compressed.get("compression_ratio")
+
+         checkpoint = {
+             "id": checkpoint_id,
+             "pipeline_id": pipeline_id,
+             "step_number": step,
+             "checkpoint_name": name,
+             "state_data": state_data,
+             "state_hash": state_hash,
+             "compressed_state": compressed_state,
+             "compression_ratio": compression_ratio,
+             "metrics": metrics or {},
+             "is_valid": True,
+             "validation_error": None,
+             "created_at": datetime.now(timezone.utc).isoformat(),
+             "state_size_bytes": state_size,
+         }
+
+         # Store in memory
+         if pipeline_id not in self._checkpoints:
+             self._checkpoints[pipeline_id] = {}
+         self._checkpoints[pipeline_id][step] = checkpoint
+
+         # Update cache, evicting the oldest step once cache_size is exceeded
+         if pipeline_id not in self._cache:
+             self._cache[pipeline_id] = {}
+         self._cache[pipeline_id][step] = checkpoint
+         while len(self._cache[pipeline_id]) > self.cache_size:
+             oldest_step = min(self._cache[pipeline_id])
+             del self._cache[pipeline_id][oldest_step]
+
+         # Database persistence
+         if self.use_database:
+             self._save_to_database(checkpoint)
+
+         logger.debug(f"Checkpoint saved: pipeline={pipeline_id}, step={step}")
+
+         return checkpoint_id
+
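A typical save call, with compression enabled for a large state (the values and checkpoint name here are illustrative):

    checkpoint_id = manager.save(
        pipeline_id=pipeline_id,
        step=250,
        state_data={"weights": [0.1] * 1024, "energy": -1.5},
        metrics={"loss": 0.008},
        name="post-warmup",
        use_quantum_compression=True,
    )

Compression is best-effort: if the compressor is unavailable or fails, the checkpoint is still saved uncompressed, so compression never blocks a save.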
+     def _save_to_database(self, checkpoint: Dict[str, Any]):
+         """Save checkpoint to database."""
+         try:
+             from db.database import get_session
+             from db.models import Checkpoint as CheckpointModel
+
+             with get_session() as session:
+                 db_checkpoint = CheckpointModel(
+                     id=uuid.UUID(checkpoint["id"]),
+                     pipeline_id=uuid.UUID(checkpoint["pipeline_id"]),
+                     step_number=checkpoint["step_number"],
+                     checkpoint_name=checkpoint["checkpoint_name"],
+                     state_data=checkpoint["state_data"],
+                     state_hash=checkpoint["state_hash"],
+                     compressed_state=checkpoint["compressed_state"],
+                     compression_ratio=checkpoint["compression_ratio"],
+                     metrics=checkpoint["metrics"],
+                     is_valid=checkpoint["is_valid"],
+                     state_size_bytes=checkpoint["state_size_bytes"],
+                 )
+                 session.add(db_checkpoint)
+                 session.commit()
+
+         except Exception as e:
+             logger.warning(f"Database save failed, using in-memory only: {e}")
+
+     def load(
+         self, pipeline_id: str, step: int, verify_integrity: bool = True
+     ) -> Optional[Dict[str, Any]]:
+         """
+         Load a checkpoint.
+
+         Args:
+             pipeline_id: Pipeline identifier
+             step: Step number
+             verify_integrity: Whether to verify hash integrity
+
+         Returns:
+             Checkpoint data or None if not found
+         """
+         # Check cache first
+         if pipeline_id in self._cache and step in self._cache[pipeline_id]:
+             checkpoint = self._cache[pipeline_id][step]
+             if verify_integrity and not self._verify_integrity(checkpoint):
+                 return None
+             return checkpoint
+
+         # Check in-memory storage
+         if pipeline_id in self._checkpoints and step in self._checkpoints[pipeline_id]:
+             checkpoint = self._checkpoints[pipeline_id][step]
+             if verify_integrity and not self._verify_integrity(checkpoint):
+                 return None
+             return checkpoint
+
+         # Load from database
+         if self.use_database:
+             checkpoint = self._load_from_database(pipeline_id, step)
+             if checkpoint:
+                 if verify_integrity and not self._verify_integrity(checkpoint):
+                     return None
+
+                 # Update cache
+                 if pipeline_id not in self._cache:
+                     self._cache[pipeline_id] = {}
+                 self._cache[pipeline_id][step] = checkpoint
+
+                 return checkpoint
+
+         return None
+
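Loads verify the stored SHA-256 by default, so a checkpoint whose state no longer matches its recorded hash comes back as None rather than as silently corrupt data:

    checkpoint = manager.load(pipeline_id, step=250)
    if checkpoint is None:
        # Missing, or failed its integrity check; fall back to the newest good one
        checkpoint = manager.get_latest_valid(pipeline_id)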
+     def _load_from_database(
+         self, pipeline_id: str, step: int
+     ) -> Optional[Dict[str, Any]]:
+         """Load checkpoint from database."""
+         try:
+             from db.database import get_session
+             from db.models import Checkpoint as CheckpointModel
+
+             with get_session() as session:
+                 checkpoint = (
+                     session.query(CheckpointModel)
+                     .filter(
+                         CheckpointModel.pipeline_id == uuid.UUID(pipeline_id),
+                         CheckpointModel.step_number == step,
+                     )
+                     .first()
+                 )
+
+                 if checkpoint:
+                     return {
+                         "id": str(checkpoint.id),
+                         "pipeline_id": str(checkpoint.pipeline_id),
+                         "step_number": checkpoint.step_number,
+                         "checkpoint_name": checkpoint.checkpoint_name,
+                         "state_data": checkpoint.state_data,
+                         "state_hash": checkpoint.state_hash,
+                         "compressed_state": checkpoint.compressed_state,
+                         "compression_ratio": checkpoint.compression_ratio,
+                         "metrics": checkpoint.metrics,
+                         "is_valid": checkpoint.is_valid,
+                         "validation_error": checkpoint.validation_error,
+                         "created_at": checkpoint.created_at.isoformat(),
+                         "state_size_bytes": checkpoint.state_size_bytes,
+                     }
+
+         except Exception as e:
+             logger.warning(f"Database load failed: {e}")
+
+         return None
+
+     def _verify_integrity(self, checkpoint: Dict[str, Any]) -> bool:
+         """Verify checkpoint integrity using hash."""
+         computed_hash = self._compute_hash(checkpoint["state_data"])
+         is_valid = computed_hash == checkpoint["state_hash"]
+
+         if not is_valid:
+             logger.warning(
+                 f"Checkpoint integrity check failed: step={checkpoint['step_number']}"
+             )
+
+         return is_valid
+
+     def get_latest_valid(self, pipeline_id: str) -> Optional[Dict[str, Any]]:
+         """
+         Get the most recent valid checkpoint.
+
+         Args:
+             pipeline_id: Pipeline identifier
+
+         Returns:
+             Latest valid checkpoint or None
+         """
+         # Get all checkpoints for pipeline
+         checkpoints = self.list_checkpoints(pipeline_id)
+
+         if not checkpoints:
+             return None
+
+         # Sort by step descending
+         sorted_checkpoints = sorted(
+             checkpoints, key=lambda c: c["step_number"], reverse=True
+         )
+
+         # Find first valid one
+         for cp in sorted_checkpoints:
+             if cp["is_valid"] and self._verify_integrity(cp):
+                 return cp
+
+         return None
+
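That makes crash recovery straightforward: resume from the newest checkpoint that is both flagged valid and hash-clean. A resume sketch, where restore_pipeline is a hypothetical caller-side hook:

    latest = manager.get_latest_valid(pipeline_id)
    if latest:
        restore_pipeline(latest["state_data"])  # hypothetical restore hook
        start_step = latest["step_number"] + 1
    else:
        start_step = 0  # no usable checkpoint; start fresh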
+     def list_checkpoints(
+         self, pipeline_id: str, limit: Optional[int] = None
+     ) -> List[Dict[str, Any]]:
+         """
+         List all checkpoints for a pipeline.
+
+         Args:
+             pipeline_id: Pipeline identifier
+             limit: Maximum number to return
+
+         Returns:
+             List of checkpoints sorted by step descending
+         """
+         checkpoints = []
+
+         # From in-memory
+         if pipeline_id in self._checkpoints:
+             checkpoints = list(self._checkpoints[pipeline_id].values())
+
+         # From database if empty
+         if not checkpoints and self.use_database:
+             checkpoints = self._list_from_database(pipeline_id, limit)
+
+         # Sort by step descending
+         checkpoints.sort(key=lambda c: c["step_number"], reverse=True)
+
+         if limit:
+             checkpoints = checkpoints[:limit]
+
+         return checkpoints
+
+     def _list_from_database(
+         self, pipeline_id: str, limit: Optional[int] = None
+     ) -> List[Dict[str, Any]]:
+         """List checkpoints from database."""
+         try:
+             from db.database import get_session
+             from db.models import Checkpoint as CheckpointModel
+
+             with get_session() as session:
+                 query = (
+                     session.query(CheckpointModel)
+                     .filter(CheckpointModel.pipeline_id == uuid.UUID(pipeline_id))
+                     .order_by(CheckpointModel.step_number.desc())
+                 )
+
+                 if limit:
+                     query = query.limit(limit)
+
+                 return [
+                     {
+                         "id": str(cp.id),
+                         "pipeline_id": str(cp.pipeline_id),
+                         "step_number": cp.step_number,
+                         "checkpoint_name": cp.checkpoint_name,
+                         "state_data": cp.state_data,
+                         "state_hash": cp.state_hash,
+                         "compressed_state": cp.compressed_state,
+                         "compression_ratio": cp.compression_ratio,
+                         "metrics": cp.metrics,
+                         "is_valid": cp.is_valid,
+                         "validation_error": cp.validation_error,
+                         "created_at": cp.created_at.isoformat(),
+                         "state_size_bytes": cp.state_size_bytes,
+                     }
+                     for cp in query.all()
+                 ]
+
+         except Exception as e:
+             logger.warning(f"Database list failed: {e}")
+             return []
+
+     def prune(self, pipeline_id: str, keep_count: int):
+         """
+         Prune old checkpoints, keeping only the most recent N.
+
+         Args:
+             pipeline_id: Pipeline identifier
+             keep_count: Number of checkpoints to keep
+         """
+         checkpoints = self.list_checkpoints(pipeline_id)
+
+         if len(checkpoints) <= keep_count:
+             return
+
+         # Checkpoints to delete (oldest ones)
+         to_delete = checkpoints[keep_count:]
+
+         for cp in to_delete:
+             step = cp["step_number"]
+
+             # Remove from in-memory
+             if pipeline_id in self._checkpoints and step in self._checkpoints[pipeline_id]:
+                 del self._checkpoints[pipeline_id][step]
+
+             # Remove from cache
+             if pipeline_id in self._cache and step in self._cache[pipeline_id]:
+                 del self._cache[pipeline_id][step]
+
+             # Remove from database
+             if self.use_database:
+                 self._delete_from_database(pipeline_id, step)
+
+         logger.debug(f"Pruned {len(to_delete)} old checkpoints for pipeline {pipeline_id}")
+
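A common retention pattern is to prune immediately after each save, so storage stays bounded no matter how long the run:

    manager.save(pipeline_id, step=step, state_data=state)
    manager.prune(pipeline_id, keep_count=5)  # retain only the 5 newest checkpoints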
+     def _delete_from_database(self, pipeline_id: str, step: int):
+         """Delete checkpoint from database."""
+         try:
+             from db.database import get_session
+             from db.models import Checkpoint as CheckpointModel
+
+             with get_session() as session:
+                 session.query(CheckpointModel).filter(
+                     CheckpointModel.pipeline_id == uuid.UUID(pipeline_id),
+                     CheckpointModel.step_number == step,
+                 ).delete()
+                 session.commit()
+
+         except Exception as e:
+             logger.warning(f"Database delete failed: {e}")
+
+     def invalidate(self, pipeline_id: str, step: int, reason: str):
+         """
+         Mark a checkpoint as invalid.
+
+         Args:
+             pipeline_id: Pipeline identifier
+             step: Step number
+             reason: Reason for invalidation
+         """
+         checkpoint = self.load(pipeline_id, step, verify_integrity=False)
+
+         if checkpoint:
+             checkpoint["is_valid"] = False
+             checkpoint["validation_error"] = reason
+
+             # Update in-memory
+             if pipeline_id in self._checkpoints and step in self._checkpoints[pipeline_id]:
+                 self._checkpoints[pipeline_id][step] = checkpoint
+
+             # Update database
+             if self.use_database:
+                 self._update_validity_in_database(pipeline_id, step, False, reason)
+
+             logger.info(f"Checkpoint invalidated: step={step}, reason={reason}")
+
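Invalidation covers checkpoints that are structurally intact (the hash still matches) but semantically bad, e.g. saved after training diverged; once marked, get_latest_valid skips them:

    manager.invalidate(pipeline_id, step=300, reason="loss diverged after step 290")
    latest = manager.get_latest_valid(pipeline_id)  # skips step 300, returns an earlier checkpoint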
+     def _update_validity_in_database(
+         self, pipeline_id: str, step: int, is_valid: bool, error: Optional[str]
+     ):
+         """Update checkpoint validity in database."""
+         try:
+             from db.database import get_session
+             from db.models import Checkpoint as CheckpointModel
+
+             with get_session() as session:
+                 session.query(CheckpointModel).filter(
+                     CheckpointModel.pipeline_id == uuid.UUID(pipeline_id),
+                     CheckpointModel.step_number == step,
+                 ).update({"is_valid": is_valid, "validation_error": error})
+                 session.commit()
+
+         except Exception as e:
+             logger.warning(f"Database validity update failed: {e}")
+
+     def clear_cache(self, pipeline_id: Optional[str] = None):
+         """Clear checkpoint cache."""
+         if pipeline_id:
+             if pipeline_id in self._cache:
+                 del self._cache[pipeline_id]
+         else:
+             self._cache.clear()
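Putting the pieces together, one plausible end-to-end lifecycle (the import path is an assumption, since the diff does not show the file name; adjust to wherever the module lives in the package):

    import uuid

    from quantumflow.pipeline.checkpoint_manager import CheckpointManager  # assumed path

    manager = CheckpointManager(use_database=False)
    pipeline_id = str(uuid.uuid4())

    for step in (0, 100, 200):
        manager.save(pipeline_id, step=step, state_data={"weights": [0.1 * step]})

    manager.prune(pipeline_id, keep_count=2)                 # drops step 0
    manager.invalidate(pipeline_id, step=200, reason="bad run")
    best = manager.get_latest_valid(pipeline_id)             # -> the step-100 checkpoint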
@@ -0,0 +1,5 @@
+ """Finance domain pipelines."""
+
+ from quantumflow.pipeline.finance.portfolio_optimization import PortfolioOptimizationPipeline
+
+ __all__ = ["PortfolioOptimizationPipeline"]