quantumflow-sdk 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,602 @@
1
+ """
2
+ Base Pipeline with Checkpointing Support.
3
+
4
+ Abstract base class for all quantum data pipelines with:
5
+ - Configurable checkpointing (time/steps based)
6
+ - State serialization and recovery
7
+ - Integration with anomaly detection
8
+ - Auto-rollback support
9
+
10
+ Example:
11
+ class MyPipeline(BasePipeline):
12
+ def execute_step(self, step: int, state: PipelineState) -> PipelineState:
13
+ # Custom step logic
14
+ return state
15
+
16
+ def get_state_for_checkpoint(self, state: PipelineState) -> dict:
17
+ return state.to_dict()
18
+ """
19
+
20
+ from abc import ABC, abstractmethod
21
+ from dataclasses import dataclass, field
22
+ from typing import Any, Dict, List, Optional, Callable
23
+ from datetime import datetime
24
+ import time
25
+ import uuid
26
+ import logging
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
@dataclass
class PipelineConfig:
    """Configuration for a pipeline.

    Groups checkpointing, anomaly-detection, execution, and temporal-memory
    settings. Every field has a default, so ``PipelineConfig()`` is a valid
    configuration.
    """

    # Checkpointing
    checkpoint_interval_steps: int = 10      # checkpoint at least every N steps
    checkpoint_interval_seconds: int = 300   # ...or every 5 minutes, whichever first
    max_checkpoints: int = 5                 # older checkpoints get pruned
    enable_quantum_compression: bool = False

    # Anomaly detection
    enable_anomaly_detection: bool = True
    auto_rollback_on_critical: bool = True
    gradient_explosion_threshold: float = 100.0
    gradient_vanishing_threshold: float = 1e-7

    # Execution
    backend: str = "simulator"
    max_retries: int = 3
    retry_delay_seconds: float = 1.0

    # Temporal memory
    enable_temporal_memory: bool = True
    memory_window_size: int = 100

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to a dictionary with one key per dataclass field.

        Uses dataclasses.asdict so newly added fields are serialized
        automatically instead of having to be mirrored in a hand-written
        mapping (the previous implementation listed every field by hand,
        which silently drops any field added later).
        """
        # Local import: the module-level import only brings in dataclass/field.
        from dataclasses import asdict

        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PipelineConfig":
        """Create config from a dictionary, ignoring unknown keys.

        Args:
            data: Mapping of field names to values; extra keys are dropped.

        Returns:
            A PipelineConfig with matching fields set, defaults elsewhere.
        """
        return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})
78
+
79
+
80
@dataclass
class PipelineState:
    """Current state of a pipeline execution."""

    # Core state
    step: int = 0
    data: Dict[str, Any] = field(default_factory=dict)

    # Metrics
    metrics: Dict[str, float] = field(default_factory=dict)
    gradient_history: List[float] = field(default_factory=list)

    # Algorithm-specific state
    weights: Optional[List[float]] = None
    parameters: Optional[Dict[str, Any]] = None

    # History
    step_history: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize state to a dictionary.

        NOTE(review): step_history is deliberately left out of the
        serialized form — presumably to keep checkpoints small; confirm
        before relying on full round-trips through from_dict().
        """
        serialized = ("step", "data", "metrics", "gradient_history", "weights", "parameters")
        return {name: getattr(self, name) for name in serialized}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PipelineState":
        """Deserialize state from a dictionary produced by to_dict().

        Missing keys fall back to the same defaults as a fresh instance.
        """
        fallbacks: Dict[str, Any] = {
            "step": 0,
            "data": {},
            "metrics": {},
            "gradient_history": [],
            "weights": None,
            "parameters": None,
        }
        kwargs = {name: data.get(name, default) for name, default in fallbacks.items()}
        return cls(**kwargs)

    def update_metrics(self, **kwargs):
        """Merge the given keyword arguments into the metrics mapping."""
        for name, value in kwargs.items():
            self.metrics[name] = value

    def add_gradient(self, gradient: float):
        """Append a gradient value to the end of the history."""
        self.gradient_history += [gradient]
129
+
130
+
131
@dataclass
class PipelineResult:
    """Result of pipeline execution."""

    pipeline_id: str
    status: str  # one of: completed, failed, rolled_back
    final_state: PipelineState
    total_steps: int
    total_time_ms: float
    checkpoints_created: int
    anomalies_detected: int
    rollbacks_performed: int
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to a dictionary.

        final_state is serialized via its own to_dict(); all other fields
        are copied as-is. Key order matches the field declaration order.
        """
        ordered_keys = (
            "pipeline_id",
            "status",
            "final_state",
            "total_steps",
            "total_time_ms",
            "checkpoints_created",
            "anomalies_detected",
            "rollbacks_performed",
            "error",
        )
        return {
            key: (self.final_state.to_dict() if key == "final_state" else getattr(self, key))
            for key in ordered_keys
        }
158
+
159
+
160
class BasePipeline(ABC):
    """
    Abstract base class for quantum data pipelines.

    Provides:
    - Checkpointing at configurable intervals (step- and time-based)
    - State serialization/deserialization
    - Hooks for anomaly detection integration
    - Auto-rollback on critical failures, bounded by ``config.max_retries``
    """

    def __init__(
        self,
        name: str,
        config: Optional[PipelineConfig] = None,
        pipeline_id: Optional[str] = None,
    ):
        """
        Initialize pipeline.

        Args:
            name: Human-readable pipeline name
            config: Pipeline configuration (defaults to ``PipelineConfig()``)
            pipeline_id: Optional existing pipeline ID for resumption
        """
        self.name = name
        self.config = config or PipelineConfig()
        self.pipeline_id = pipeline_id or str(uuid.uuid4())

        # State
        self._state = PipelineState()
        self._is_running = False
        self._is_paused = False

        # Checkpointing
        self._checkpoint_manager = None
        self._last_checkpoint_time = 0.0
        self._last_checkpoint_step = 0
        self._checkpoints_created = 0

        # Anomaly detection
        self._anomaly_detector = None
        self._anomalies_detected = 0

        # Temporal memory
        self._temporal_memory = None
        self._run_id = str(uuid.uuid4())

        # Rollback tracking
        self._rollbacks_performed = 0

        # Callbacks
        self._on_step_complete: List[Callable] = []
        self._on_checkpoint: List[Callable] = []
        self._on_anomaly: List[Callable] = []
        self._on_rollback: List[Callable] = []

    @property
    @abstractmethod
    def pipeline_type(self) -> str:
        """Return the pipeline type identifier."""
        pass

    @abstractmethod
    def execute_step(self, step: int, state: PipelineState) -> PipelineState:
        """
        Execute a single pipeline step.

        Args:
            step: Current step number
            state: Current pipeline state

        Returns:
            Updated pipeline state
        """
        pass

    @abstractmethod
    def get_state_for_checkpoint(self, state: PipelineState) -> Dict[str, Any]:
        """
        Get state data to save in checkpoint.

        Override to include algorithm-specific state.

        Args:
            state: Current pipeline state

        Returns:
            Serializable state dictionary
        """
        pass

    @abstractmethod
    def restore_state_from_checkpoint(self, checkpoint_data: Dict[str, Any]) -> PipelineState:
        """
        Restore state from checkpoint data.

        Args:
            checkpoint_data: Saved checkpoint data

        Returns:
            Restored pipeline state
        """
        pass

    def initialize(self) -> PipelineState:
        """
        Initialize pipeline state before execution.

        Override for custom initialization.

        Returns:
            Initial pipeline state
        """
        return PipelineState()

    def finalize(self, state: PipelineState) -> PipelineState:
        """
        Finalize pipeline after execution.

        Override for custom finalization.

        Args:
            state: Final pipeline state

        Returns:
            Finalized state
        """
        return state

    def should_stop(self, state: PipelineState) -> bool:
        """
        Check if pipeline should stop execution.

        Override for custom stopping conditions.

        Args:
            state: Current pipeline state

        Returns:
            True if pipeline should stop
        """
        return False

    def set_checkpoint_manager(self, manager: "CheckpointManager"):
        """Set checkpoint manager for persistence."""
        self._checkpoint_manager = manager

    def set_anomaly_detector(self, detector: "AnomalyDetector"):
        """Set anomaly detector."""
        self._anomaly_detector = detector

    def set_temporal_memory(self, memory: "TemporalMemoryStore"):
        """Set temporal memory store."""
        self._temporal_memory = memory

    def on_step_complete(self, callback: Callable[[int, PipelineState], None]):
        """Register step completion callback."""
        self._on_step_complete.append(callback)

    def on_checkpoint(self, callback: Callable[[int, Dict[str, Any]], None]):
        """Register checkpoint callback."""
        self._on_checkpoint.append(callback)

    def on_anomaly(self, callback: Callable[[str, Dict[str, Any]], None]):
        """Register anomaly callback."""
        self._on_anomaly.append(callback)

    def on_rollback(self, callback: Callable[[int, int], None]):
        """Register rollback callback (from_step, to_step)."""
        self._on_rollback.append(callback)

    def _should_checkpoint(self, step: int) -> bool:
        """Check if a checkpoint should be created (step- or time-interval)."""
        # Step-based
        if step - self._last_checkpoint_step >= self.config.checkpoint_interval_steps:
            return True

        # Time-based
        current_time = time.time()
        if current_time - self._last_checkpoint_time >= self.config.checkpoint_interval_seconds:
            return True

        return False

    def _create_checkpoint(self, step: int, state: PipelineState):
        """Create a checkpoint at the current step (no-op without a manager)."""
        if not self._checkpoint_manager:
            return

        checkpoint_data = self.get_state_for_checkpoint(state)

        self._checkpoint_manager.save(
            pipeline_id=self.pipeline_id,
            step=step,
            state_data=checkpoint_data,
            metrics=state.metrics.copy(),
            use_quantum_compression=self.config.enable_quantum_compression,
        )

        self._last_checkpoint_step = step
        self._last_checkpoint_time = time.time()
        self._checkpoints_created += 1

        # Prune old checkpoints so at most max_checkpoints are retained
        self._checkpoint_manager.prune(self.pipeline_id, self.config.max_checkpoints)

        # Notify callbacks
        for callback in self._on_checkpoint:
            callback(step, checkpoint_data)

        logger.info(f"Checkpoint created at step {step}")

    def _check_anomalies(self, state: PipelineState) -> Optional[Dict[str, Any]]:
        """Check for anomalies in the current state.

        Returns the anomaly as a dict, or None when detection is disabled,
        no detector is set, or nothing anomalous was found.
        """
        if not self.config.enable_anomaly_detection:
            return None

        if not self._anomaly_detector:
            return None

        result = self._anomaly_detector.detect(
            state=state,
            step=state.step,
            gradient_threshold=self.config.gradient_explosion_threshold,
            vanishing_threshold=self.config.gradient_vanishing_threshold,
        )

        if result and result.is_anomaly:
            self._anomalies_detected += 1

            # Notify callbacks
            for callback in self._on_anomaly:
                callback(result.anomaly_type, result.to_dict())

            logger.warning(f"Anomaly detected: {result.anomaly_type} at step {state.step}")

            return result.to_dict()

        return None

    def _rollback_to_checkpoint(self, checkpoint_step: Optional[int] = None) -> Optional[PipelineState]:
        """
        Rollback to a previous checkpoint.

        Args:
            checkpoint_step: Specific step to rollback to (None = most recent valid)

        Returns:
            Restored state or None if rollback failed
        """
        if not self._checkpoint_manager:
            logger.error("Cannot rollback: no checkpoint manager")
            return None

        # Find checkpoint
        if checkpoint_step is not None:
            checkpoint = self._checkpoint_manager.load(self.pipeline_id, checkpoint_step)
        else:
            checkpoint = self._checkpoint_manager.get_latest_valid(self.pipeline_id)

        if not checkpoint:
            logger.error("No valid checkpoint found for rollback")
            return None

        # Restore state
        restored_state = self.restore_state_from_checkpoint(checkpoint["state_data"])

        from_step = self._state.step
        to_step = checkpoint["step_number"]

        self._state = restored_state
        self._rollbacks_performed += 1

        # Notify callbacks
        for callback in self._on_rollback:
            callback(from_step, to_step)

        logger.info(f"Rolled back from step {from_step} to step {to_step}")

        return restored_state

    def _store_temporal_state(self, state: PipelineState):
        """Store a compact state vector in temporal memory (if enabled/set)."""
        if not self.config.enable_temporal_memory:
            return

        if not self._temporal_memory:
            return

        # Create state vector from metrics and gradients
        state_vector = list(state.metrics.values())
        if state.gradient_history:
            state_vector.extend(state.gradient_history[-10:])  # Last 10 gradients

        if state_vector:
            self._temporal_memory.store(
                pipeline_id=self.pipeline_id,
                run_id=self._run_id,
                sequence_number=state.step,
                state_vector=state_vector,
                metadata={"step": state.step, "metrics": state.metrics},
            )

    def run(self, total_steps: int, initial_state: Optional[PipelineState] = None) -> PipelineResult:
        """
        Execute the pipeline.

        The loop honors stop() (checked on every iteration) and
        pause()/resume(). On a step failure, rollback is attempted at most
        ``config.max_retries`` consecutive times, waiting
        ``config.retry_delay_seconds`` between attempts, before the error
        propagates.

        Args:
            total_steps: Total number of steps to execute
            initial_state: Optional initial state (for resumption)

        Returns:
            Pipeline execution result
        """
        start_time = time.time()
        self._is_running = True
        error = None
        status = "completed"
        consecutive_failures = 0  # reset after every successful step

        try:
            # Initialize
            self._state = initial_state or self.initialize()
            self._last_checkpoint_time = time.time()
            self._last_checkpoint_step = self._state.step

            logger.info(f"Starting pipeline {self.name} from step {self._state.step}")

            # Execute steps.
            # BUGFIX: also check _is_running so stop() actually terminates
            # the loop (previously stop() flipped the flag but run() never
            # looked at it).
            while self._is_running and self._state.step < total_steps:
                if self._is_paused:
                    time.sleep(0.1)
                    continue

                step = self._state.step

                # Execute step
                try:
                    self._state = self.execute_step(step, self._state)
                    self._state.step = step + 1
                    consecutive_failures = 0
                except Exception as e:
                    logger.error(f"Step {step} failed: {e}")
                    consecutive_failures += 1

                    # BUGFIX: bound rollback attempts with config.max_retries
                    # (previously declared but unused) so a deterministically
                    # failing step cannot cycle rollback/retry forever.
                    if (
                        self.config.auto_rollback_on_critical
                        and consecutive_failures <= self.config.max_retries
                    ):
                        time.sleep(self.config.retry_delay_seconds)
                        restored = self._rollback_to_checkpoint()
                        if restored:
                            continue

                    raise

                # Notify step completion
                for callback in self._on_step_complete:
                    callback(step, self._state)

                # Check for anomalies
                anomaly = self._check_anomalies(self._state)
                if anomaly and anomaly.get("severity") == "critical":
                    if self.config.auto_rollback_on_critical:
                        restored = self._rollback_to_checkpoint()
                        if restored:
                            continue

                    status = "failed"
                    error = f"Critical anomaly and rollback failed: {anomaly}"
                    break

                # Store temporal state
                self._store_temporal_state(self._state)

                # Create checkpoint if needed
                if self._should_checkpoint(self._state.step):
                    self._create_checkpoint(self._state.step, self._state)

                # Check stopping condition
                if self.should_stop(self._state):
                    logger.info(f"Pipeline stopped early at step {self._state.step}")
                    break

            # Finalize
            self._state = self.finalize(self._state)

            # Final checkpoint
            if self._checkpoint_manager:
                self._create_checkpoint(self._state.step, self._state)

        except Exception as e:
            status = "failed"
            error = str(e)
            logger.error(f"Pipeline failed: {e}")

            # Attempt rollback on failure
            if self.config.auto_rollback_on_critical:
                restored = self._rollback_to_checkpoint()
                if restored:
                    status = "rolled_back"

        finally:
            self._is_running = False

        end_time = time.time()
        total_time_ms = (end_time - start_time) * 1000

        return PipelineResult(
            pipeline_id=self.pipeline_id,
            status=status,
            final_state=self._state,
            total_steps=self._state.step,
            total_time_ms=total_time_ms,
            checkpoints_created=self._checkpoints_created,
            anomalies_detected=self._anomalies_detected,
            rollbacks_performed=self._rollbacks_performed,
            error=error,
        )

    def pause(self):
        """Pause pipeline execution."""
        self._is_paused = True
        logger.info(f"Pipeline {self.name} paused")

    def resume(self):
        """Resume pipeline execution."""
        self._is_paused = False
        logger.info(f"Pipeline {self.name} resumed")

    def stop(self):
        """Stop pipeline execution (run() exits at the next loop check)."""
        self._is_running = False
        logger.info(f"Pipeline {self.name} stopped")

    @property
    def state(self) -> PipelineState:
        """Get current pipeline state."""
        return self._state

    @property
    def is_running(self) -> bool:
        """Check if pipeline is running."""
        return self._is_running

    @property
    def is_paused(self) -> bool:
        """Check if pipeline is paused."""
        return self._is_paused