driftguard-ai-sdk 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. driftguard_ai_sdk-1.0.0/PKG-INFO +39 -0
  2. driftguard_ai_sdk-1.0.0/README.md +197 -0
  3. driftguard_ai_sdk-1.0.0/driftguard/__init__.py +16 -0
  4. driftguard_ai_sdk-1.0.0/driftguard/alert.py +89 -0
  5. driftguard_ai_sdk-1.0.0/driftguard/callback_runner.py +342 -0
  6. driftguard_ai_sdk-1.0.0/driftguard/config.py +55 -0
  7. driftguard_ai_sdk-1.0.0/driftguard/drift_detector.py +308 -0
  8. driftguard_ai_sdk-1.0.0/driftguard/tracker.py +611 -0
  9. driftguard_ai_sdk-1.0.0/driftguard/validation.py +84 -0
  10. driftguard_ai_sdk-1.0.0/driftguard_ai_sdk.egg-info/PKG-INFO +39 -0
  11. driftguard_ai_sdk-1.0.0/driftguard_ai_sdk.egg-info/SOURCES.txt +28 -0
  12. driftguard_ai_sdk-1.0.0/driftguard_ai_sdk.egg-info/dependency_links.txt +1 -0
  13. driftguard_ai_sdk-1.0.0/driftguard_ai_sdk.egg-info/requires.txt +34 -0
  14. driftguard_ai_sdk-1.0.0/driftguard_ai_sdk.egg-info/top_level.txt +1 -0
  15. driftguard_ai_sdk-1.0.0/setup.cfg +4 -0
  16. driftguard_ai_sdk-1.0.0/setup.py +58 -0
  17. driftguard_ai_sdk-1.0.0/tests/test_artifact_storage.py +122 -0
  18. driftguard_ai_sdk-1.0.0/tests/test_audit_log.py +96 -0
  19. driftguard_ai_sdk-1.0.0/tests/test_auth.py +81 -0
  20. driftguard_ai_sdk-1.0.0/tests/test_canary_router.py +96 -0
  21. driftguard_ai_sdk-1.0.0/tests/test_champion_challenger.py +162 -0
  22. driftguard_ai_sdk-1.0.0/tests/test_critical_fixes.py +409 -0
  23. driftguard_ai_sdk-1.0.0/tests/test_drift_detector.py +122 -0
  24. driftguard_ai_sdk-1.0.0/tests/test_model_registry.py +115 -0
  25. driftguard_ai_sdk-1.0.0/tests/test_ownership.py +104 -0
  26. driftguard_ai_sdk-1.0.0/tests/test_projects.py +73 -0
  27. driftguard_ai_sdk-1.0.0/tests/test_retrain_pipeline.py +100 -0
  28. driftguard_ai_sdk-1.0.0/tests/test_retrainer_callback.py +169 -0
  29. driftguard_ai_sdk-1.0.0/tests/test_rollback_persistence.py +126 -0
  30. driftguard_ai_sdk-1.0.0/tests/test_threshold_engine.py +174 -0
@@ -0,0 +1,39 @@
1
+ Metadata-Version: 2.4
2
+ Name: driftguard-ai-sdk
3
+ Version: 1.0.0
4
+ Summary: DriftGuard — Autonomous Model Health Platform
5
+ Author: DriftGuard Team
6
+ Requires-Python: >=3.9
7
+ Requires-Dist: numpy>=1.24
8
+ Requires-Dist: httpx>=0.24
9
+ Requires-Dist: python-dotenv>=1.0
10
+ Requires-Dist: river>=0.21.2
11
+ Requires-Dist: scikit-learn>=1.3
12
+ Provides-Extra: server
13
+ Requires-Dist: fastapi==0.111.0; extra == "server"
14
+ Requires-Dist: uvicorn==0.28.1; extra == "server"
15
+ Requires-Dist: pydantic==1.10.13; extra == "server"
16
+ Requires-Dist: prometheus-client==0.20.0; extra == "server"
17
+ Requires-Dist: pandas==2.2.2; extra == "server"
18
+ Requires-Dist: redis==5.0.4; extra == "server"
19
+ Requires-Dist: psycopg2-binary==2.9.9; extra == "server"
20
+ Requires-Dist: sqlalchemy==2.0.30; extra == "server"
21
+ Provides-Extra: validation
22
+ Requires-Dist: great-expectations==0.18.15; extra == "validation"
23
+ Requires-Dist: sqlalchemy==1.4.41; extra == "validation"
24
+ Provides-Extra: evidently
25
+ Requires-Dist: evidently==0.4.30; extra == "evidently"
26
+ Provides-Extra: serving
27
+ Requires-Dist: bentoml==1.2.0; extra == "serving"
28
+ Requires-Dist: ray[serve]==2.10.0; extra == "serving"
29
+ Provides-Extra: pipeline
30
+ Requires-Dist: prefect==2.19.0; extra == "pipeline"
31
+ Requires-Dist: zenml==0.57.0; extra == "pipeline"
32
+ Provides-Extra: test
33
+ Requires-Dist: pytest==8.2.0; extra == "test"
34
+ Requires-Dist: pytest-asyncio==0.23.6; extra == "test"
35
+ Dynamic: author
36
+ Dynamic: provides-extra
37
+ Dynamic: requires-dist
38
+ Dynamic: requires-python
39
+ Dynamic: summary
@@ -0,0 +1,197 @@
1
+ # 🛡️ DriftGuard — Autonomous Model Health Platform
2
+
3
+ DriftGuard is a production-grade, self-healing MLOps platform designed to detect data drift, concept drift, and model degradation in real time, automatically trigger validated retraining pipelines, and progressively roll out champion models via canary routers.
4
+
5
+ ---
6
+
7
+ ## 🏗️ Architecture Design
8
+
9
+ ```
10
+ +---------------------------------------+
11
+ | Client Application |
12
+ +-------------------+-------------------+
13
+ |
14
+ (Predict Telemetry)
15
+ v
16
+ +-------------------+-------------------+
17
+ | DriftGuard SDK |
18
+ | - Wrapper pattern intercept |
19
+ | - River ADWIN concept drift checks |
20
+ +-------------------+-------------------+
21
+ |
22
+ (HTTP Telemetry)
23
+ v
24
+ +-------------------+-------------------+
25
+ | DriftGuard FastAPI Core API | <---+ NextJS Dashboard (:3000)
26
+ | - /register, /predict, /drift | <---+ Grafana (:3001)
27
+ | - Prom metrics /metrics (:8000) |
28
+ +-------------------+-------------------+
29
+ |
30
+ (SLA Drift Breach Trigger)
31
+ v
32
+ +-------------------+-------------------+
33
+ | Prefect Orchestration Server |
34
+ | - drift_detection_flow (:4200) |
35
+ +-------------------+-------------------+
36
+ |
37
+ (Runs steps)
38
+ v
39
+ +-------------------+-------------------+
40
+ | ZenML Step Training Pipelines |
41
+ | - Step 1: Great Expectations Validate|
42
+ | - Step 2: Feast Feature Store Check |
43
+ | - Step 3: Train & Track (MLflow/W&B) |
44
+ | - Step 4: Validate (>1% boost check) |
45
+ | - Step 5: Canary Progressive Deploy |
46
+ | - Step 6: Immutable JSON Ledger & PDF|
47
+ +-------------------+-------------------+
48
+ |
49
+ (Progressive Split Promotes)
50
+ v
51
+ +-------------------+-------------------+
52
+ | BentoML & Ray Serve Fleet |
53
+ | - canary_router: 10%->100% |
54
+ | - SLA Monitoring & Rollbacks |
55
+ +---------------------------------------+
56
+ ```
57
+
58
+ ---
59
+
60
+ ## ⚡ Prerequisites
61
+
62
+ To run and configure DriftGuard, ensure the following are installed:
63
+ - **Python 3.11** only (Ray and BentoML have incomplete 3.12 support).
64
+ - **Docker & Docker Compose** (for multi-service orchestration).
65
+ - **kubectl & Helm** (optional, for Kubernetes deployments).
66
+ - **HashiCorp Terraform** (optional, for AWS cloud provisioning).
67
+
68
+ ---
69
+
70
+ ## 🚀 Quick Start in 5 Lines
71
+
72
+ Wrap any scikit-learn, PyTorch, or HuggingFace model with DriftGuard SDK to track predictions, compute concept drift, and initiate auto-healing:
73
+
74
+ ```python
75
+ from driftguard import DriftGuard
76
+
77
+ # 1. Initialize DriftGuard
78
+ dg = DriftGuard(model_id="fraud-detector-v1", api_url="http://localhost:8000", drift_threshold=0.15, auto_retrain=True)
79
+
80
+ # 2. Wrap model seamlessly
81
+ model = dg.wrap(trained_sklearn_model)
82
+
83
+ # 3. Predict normally - DriftGuard tracks inputs, outputs, and triggers retrain on drift!
84
+ prediction = model.predict(features)
85
+ ```
86
+
87
+ ---
88
+
89
+ ## 📦 Installation & Setup
90
+
91
+ ### 1. Local Package Installation
92
+ Clone the repository and install the DriftGuard package locally:
93
+ ```bash
94
+ git clone https://github.com/your-repo/DriftGuard.git
95
+ cd DriftGuard
96
+ pip install -e .
97
+ ```
98
+
99
+ To install the validation pipeline dependencies (Great Expectations + SQLAlchemy 1.4 pin) separately:
100
+ ```bash
101
+ pip install -e ".[validation]"
102
+ ```
103
+
104
+ ### 2. Launch Platform Services
105
+ Spin up the entire 8-service DriftGuard stack (FastAPI, NextJS dashboard, MLflow, Prefect, Postgres, Redis, Prometheus, Grafana) instantly:
106
+ ```bash
107
+ docker-compose -f infra/docker-compose.yml up --build -d
108
+ ```
109
+
110
+ ---
111
+
112
+ ## ⚙️ SDK Configuration Parameters
113
+
114
+ The `DriftGuard` class accepts the following parameters:
115
+
116
+ | Parameter | Type | Default | Description |
117
+ |---|---|---|---|
118
+ | `model_id` | `str` | *Required* | Unique name identifier of the model. |
119
+ | `api_url` | `str` | `http://localhost:8000` | Gateway endpoint of the DriftGuard API. |
120
+ | `drift_threshold` | `float` | `0.15` | Limit before concept drift alert and retraining triggers. |
121
+ | `auto_retrain` | `bool` | `True` | Automatically triggers retraining flow on threshold breach. |
122
+
123
+ ---
124
+
125
+ ## 📡 API Gateway Documentation
126
+
127
+ DriftGuard Core API runs on port `8000`. Key REST endpoints include:
128
+
129
+ ### `POST /register`
130
+ Registers a model for monitoring.
131
+ - **Request Body:**
132
+ ```json
133
+ {
134
+ "model_id": "fraud-detector-v1",
135
+ "drift_threshold": 0.15,
136
+ "reference_data_path": "./data/baseline.parquet",
137
+ "features": ["amount", "location_score", "velocity"]
138
+ }
139
+ ```
140
+ - **Response:** `{"status": "registered", "model_id": "fraud-detector-v1"}`
141
+
142
+ ### `POST /predict/{model_id}`
143
+ SDK telemetry gateway recording prediction details and updating gauges.
144
+ - **Request Body:**
145
+ ```json
146
+ {
147
+ "features": [1.2, 0.4, 9.8],
148
+ "prediction": [1.0],
149
+ "drift_score": 0.08
150
+ }
151
+ ```
152
+
153
+ ### `GET /drift/{model_id}`
154
+ Fetches the last 100 historical drift scores for rendering in Recharts charts.
155
+
156
+ ### `POST /retrain/{model_id}`
157
+ Manually triggers the background retraining pipeline flow.
158
+
159
+ ### `GET /metrics`
160
+ Exposes system health gauges for Prometheus scrapers in OpenMetrics format.
161
+
162
+ ---
163
+
164
+ ## 🖥️ Dashboards & Observability Portals
165
+
166
+ Once the docker services are online:
167
+ 1. **NextJS UI Dashboard:** Navigate to [http://localhost:3000](http://localhost:3000) to review models list, drift histories, vertical retraining timelines, and searchable audit logs.
168
+ 2. **MLflow Registry:** Access [http://localhost:5000](http://localhost:5000) to review run parameters, artifacts (confusion matrix plots), and staging/production champions.
169
+ 3. **Prefect Dashboard:** Access [http://localhost:4200](http://localhost:4200) to inspect flow execution history.
170
+ 4. **Grafana Dashboards:** Open [http://localhost:3001](http://localhost:3001) (User: `admin` | Pass: `admin`) to view pre-provisioned telemetry panels scraping predictions, drift rates, accuracy levels, and latency quantiles.
171
+
172
+ ---
173
+
174
+ ## 🧪 Running Unit Tests
175
+
176
+ Run all unit tests verifying ADWIN concept detectors, Great Expectations validators, canary routing splits, emergency rollbacks, and cryptographic audit log chains:
177
+ ```bash
178
+ pytest tests/ -v
179
+ ```
180
+
181
+ ---
182
+
183
+ ## ☁️ Deploying to AWS Cloud (Terraform)
184
+
185
+ Deploy DriftGuard core infrastructure to Amazon Web Services:
186
+ ```bash
187
+ cd infra/terraform
188
+ terraform init
189
+ terraform plan
190
+ terraform apply -var="db_password=SecurePasswordPass22!"
191
+ ```
192
+ This provisions:
193
+ - Amazon EKS cluster for progressive Kubernetes canary Rollouts.
194
+ - Amazon RDS PostgreSQL database for MLflow, Prefect, and platform metadata.
195
+ - Amazon ElastiCache Redis for online real-time Feast features access.
196
+ - Amazon S3 bucket for artifacts.
197
+ - Amazon ECR for Docker images.
@@ -0,0 +1,16 @@
1
+ """
2
+ DriftGuard — public package entry point.
3
+
4
+ Users import from here:
5
+ from driftguard import DriftGuard
6
+ """
7
+ from .tracker import DriftGuard, DriftGuardModelWrapper
8
+ from .callback_runner import RetrainerCallbackRunner
9
+ from .config import settings
10
+
11
+ __all__ = [
12
+ "DriftGuard",
13
+ "DriftGuardModelWrapper",
14
+ "RetrainerCallbackRunner",
15
+ "settings",
16
+ ]
@@ -0,0 +1,89 @@
1
+ """
2
+ DriftGuard Alerting & Notification Client.
3
+ Handles distribution of system alert updates to console logs and webhook endpoints like Slack.
4
+ """
5
+ import httpx
6
+ import logging
7
+ from typing import Dict, Any, Optional
8
+
9
+ from driftguard.config import settings
10
+
11
+ logger = logging.getLogger("DriftGuard.Alert")
12
+
13
# Slack Block Kit rejects text objects that exceed these lengths, which would
# cause the whole alert to be dropped with an HTTP 400 response.
_SLACK_HEADER_TEXT_LIMIT = 150
_SLACK_SECTION_TEXT_LIMIT = 3000


def _clip_for_slack(text: str, limit: int) -> str:
    """Return ``text`` cut to at most ``limit`` characters, ending in an ellipsis when clipped."""
    if len(text) <= limit:
        return text
    return text[: limit - 1] + "…"


def send_alert(
    event_type: str,
    message: str,
    details: Optional[Dict[str, Any]] = None,
    webhook_url: Optional[str] = None
) -> bool:
    """
    Distributes alert notification updates to platform targets.

    The alert is always written to the module logger first. It is then posted
    to a Slack webhook when one is configured. Text sent to Slack is clipped to
    the Block Kit length limits so that an over-long message (for example a
    stringified exception) is delivered truncated instead of being rejected.

    Args:
        event_type: Category of the alert (e.g., 'drift_detected', 'retrain_triggered', 'validation_failed', 'model_promoted').
        message: Readable summary message string.
        details: Optional JSON data properties containing details about the event.
        webhook_url: Optional override of settings Slack Webhook URL.

    Returns:
        True if successfully sent (or logged when offline), False when the
        webhook answers with a status other than 200/201 or the request fails.
        This function never raises for delivery problems.
    """
    payload_details = details or {}
    url = webhook_url or settings.SLACK_WEBHOOK_URL

    # 1. Standard Logger Print
    log_msg = f"[ALERT - {event_type.upper()}] {message} | Details: {payload_details}"
    if event_type in ("drift_detected", "validation_failed", "rollback"):
        logger.error(log_msg)
    else:
        logger.info(log_msg)

    # 2. Slack Webhook Dispatch
    # A missing URL or the placeholder value means "offline": logging alone counts as success.
    if not url or "mock_webhook_url" in url:
        logger.debug("No active Slack webhook configured. Skipping network alert.")
        return True

    try:
        header_text = _clip_for_slack(
            f"🛡️ DriftGuard Alert: {event_type.replace('_', ' ').title()}",
            _SLACK_HEADER_TEXT_LIMIT,
        )
        message_text = _clip_for_slack(
            f"*Message:*\n{message}",
            _SLACK_SECTION_TEXT_LIMIT,
        )
        blocks = [
            {
                "type": "header",
                "text": {
                    "type": "plain_text",
                    "text": header_text,
                    "emoji": True
                }
            },
            {
                "type": "section",
                "text": {
                    "type": "mrkdwn",
                    "text": message_text
                }
            }
        ]

        if payload_details:
            details_str = "\n".join(f"• *{k}:* {v}" for k, v in payload_details.items())
            blocks.append({
                "type": "section",
                "text": {
                    "type": "mrkdwn",
                    "text": _clip_for_slack(
                        f"*Metadata:*\n{details_str}",
                        _SLACK_SECTION_TEXT_LIMIT,
                    )
                }
            })

        slack_payload = {"blocks": blocks}

        with httpx.Client(timeout=3.0) as client:
            resp = client.post(url, json=slack_payload)

        if resp.status_code in (200, 201):
            logger.info("Successfully posted alert to Slack channel.")
            return True

        logger.error(f"Slack webhook endpoint returned HTTP {resp.status_code}: {resp.text}")
        return False

    except Exception as e:
        # Deliberately broad: alerting is best-effort and must never break the caller.
        logger.error(f"Failed to transmit Slack webhook: {e}")
        return False
@@ -0,0 +1,342 @@
1
+ """
2
+ DriftGuard Retrainer Callback Runner.
3
+
4
+ Executes user-registered ``@dg.retrainer`` callbacks entirely inside the
5
+ SDK process — where the user's data, credentials, and environment live.
6
+
7
+ Flow
8
+ ----
9
+ 1. Notify API: POST /retrain/{model_id} with source="sdk_callback"
10
+ → Server creates a DBRetrainingEvent record and returns event_id.
11
+ → Server does NOT spawn its own background pipeline.
12
+ 2. Invoke the user's retrainer function → get challenger model.
13
+ 3. Validate challenger vs champion using dg.set_validation_data().
14
+ 4. Report outcome: POST /retrain/{model_id}/complete.
15
+ → Server updates model version, accuracy, audit log, Prometheus, Slack.
16
+ 5. Update tracker._champion_model to the new champion.
17
+ 6. Reset tracker.retraining_triggered so future drift events can fire.
18
+
19
+ NOTE: Production telemetry stored in dg_predictions is never touched.
20
+ Training data comes exclusively from the user's registered callback.
21
+ """
22
+ from __future__ import annotations
23
+
24
+ import logging
25
+ from typing import TYPE_CHECKING, Any, Optional, Tuple
26
+
27
+ import httpx
28
+
29
+ if TYPE_CHECKING:
30
+ from driftguard.tracker import DriftGuard
31
+
32
+ logger = logging.getLogger("DriftGuard.CallbackRunner")
33
+
34
+
35
class RetrainerCallbackRunner:
    """
    Orchestrates the full local retraining pipeline for a registered callback.

    Parameters
    ----------
    tracker:
        The ``DriftGuard`` instance that owns the callback.
    """

    def __init__(self, tracker: "DriftGuard") -> None:
        self.tracker = tracker
        self.api_url = tracker.api_url
        self.model_id = tracker.model_id

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def run(self, drift_score: float) -> bool:
        """
        Execute the full callback-based retraining pipeline.

        Any exception raised by a pipeline step is logged, reported to the API
        as a failed run, and converted into a ``False`` return value.
        ``tracker.retraining_triggered`` is reset to ``False`` on every exit path.

        Parameters
        ----------
        drift_score:
            The drift score that triggered this run; it is sent to the API
            with the retrain-start notification.

        Returns
        -------
        bool
            ``True`` if the challenger was promoted; ``False`` otherwise.
        """
        logger.info(
            f"[{self.model_id}] Callback retraining pipeline started "
            f"(drift_score={drift_score:.4f})"
        )

        # Step 1 — Record event on the server (no background task spawned)
        current_version = self._get_current_version()
        event_id = self._notify_retrain_start(drift_score)

        try:
            # Step 2 — Invoke user-registered callback
            challenger_model = self._invoke_callback()

            # Step 3 — Validate challenger against champion
            validation_passed, champ_score, chall_score = self._validate(challenger_model)
            logger.info(
                f"[{self.model_id}] Validation results: champion={champ_score}, "
                f"challenger={chall_score}, passed={validation_passed}"
            )

            if not validation_passed:
                reason = (
                    f"Challenger accuracy {chall_score:.4f} did not beat "
                    f"champion {champ_score:.4f} by ≥1%."
                )
                logger.warning(f"[{self.model_id}] {reason}")
                self._report_failure(event_id=event_id, reason=reason, chall_score=chall_score)
                return False

            # Step 4 — Promote challenger
            new_version = self._bump_version(current_version)
            logger.debug(f"[{self.model_id}] Promotion stage started, new version = {new_version}")

            # Persist challenger model before promotion (best-effort)
            self._persist_challenger(challenger_model, new_version)

            self._report_success(
                event_id=event_id,
                new_version=new_version,
                new_accuracy=chall_score,
                old_accuracy=champ_score,
            )

            # Step 5 — Update local champion reference so next comparison is correct
            self.tracker._champion_model = challenger_model

            logger.info(
                f"[{self.model_id}] Challenger promoted: "
                f"{current_version} → {new_version} "
                f"(accuracy {champ_score:.4f} → {chall_score:.4f})"
            )
            return True

        except Exception as exc:
            # exc_info=True already records the traceback through logging,
            # so nothing is printed to stdout/stderr directly.
            logger.error(
                f"[{self.model_id}] Callback retraining pipeline failed: {exc}",
                exc_info=True,
            )
            self._report_failure(event_id=event_id, reason=str(exc), chall_score=0.0)
            return False

        finally:
            # Always reset so future drift events can trigger a new run
            self.tracker.retraining_triggered = False

    # ------------------------------------------------------------------
    # Step implementations
    # ------------------------------------------------------------------

    def _invoke_callback(self) -> Any:
        """
        Call the user's ``@dg.retrainer`` function and return the trained model.

        Raises
        ------
        RuntimeError
            If no callback is registered, or if the callback itself raises
            (the original exception is chained as the cause).
        ValueError
            If the callback returns ``None``.
        TypeError
            If the returned object neither has a ``predict`` attribute nor is callable.
        """
        fn = self.tracker._retrainer_fn
        if fn is None:
            raise RuntimeError(
                f"No retrainer callback is registered for model '{self.model_id}'."
            )

        # functools.partial objects and callable instances have no __name__.
        fn_name = getattr(fn, "__name__", repr(fn))
        logger.info(f"[{self.model_id}] Invoking retrainer callback: {fn_name}()")
        try:
            model = fn()
        except Exception as exc:
            raise RuntimeError(
                f"Retrainer callback '{fn_name}' raised an exception: {exc}"
            ) from exc

        if model is None:
            raise ValueError(
                f"Retrainer callback '{fn_name}' returned None. "
                "The function must return a trained model object."
            )
        if not (hasattr(model, "predict") or callable(model)):
            raise TypeError(
                f"Retrainer callback '{fn_name}' returned an invalid model type '{type(model).__name__}'. "
                "The model must have a 'predict' method or be callable."
            )
        return model

    def _persist_challenger(self, challenger_model: Any, new_version: str) -> None:
        """
        Best-effort dump of the challenger to ``version_<new_version>.pkl``.

        The file is written under ``ARTIFACT_ROOT/<project_id>/<model_id>/``.
        Nothing is written when the tracker has no ``project_id``. Failures are
        logged as warnings and never abort the promotion.
        """
        if not self.tracker.project_id:
            return
        try:
            import joblib
            import os
            from driftguard.config import settings as _settings

            dir_path = os.path.join(
                _settings.ARTIFACT_ROOT,
                str(self.tracker.project_id),
                self.model_id,
            )
            os.makedirs(dir_path, exist_ok=True)
            file_path = os.path.join(dir_path, f"version_{new_version}.pkl")
            joblib.dump(challenger_model, file_path)
            logger.info(f"[{self.model_id}] Persisted challenger model before promotion to {file_path}")
        except Exception as e:
            logger.warning(f"[{self.model_id}] Failed to persist challenger model: {e}")

    def _validate(self, challenger_model: Any) -> Tuple[bool, float, float]:
        """
        Compare challenger against the current champion.

        If the tracker holds no champion model, the challenger is accepted
        directly as the first champion and the placeholder scores
        ``(0.0, 1.0)`` are returned; no accuracy is measured in that case.

        Otherwise validation features and labels must be present on the
        tracker, and the comparison is delegated to
        ``validate_challenger_vs_champion`` with a 1% threshold.

        Returns
        -------
        (validation_passed, champion_score, challenger_score)

        Raises
        ------
        ValueError
            If a champion exists but validation features or labels are missing.
        """
        champion_model = self.tracker._champion_model

        if champion_model is None:
            logger.info(
                f"[{self.model_id}] No champion registered via dg.set_champion(). "
                "Promoting challenger as first champion."
            )
            return True, 0.0, 1.0

        val_features = self.tracker._validation_features
        val_labels = self.tracker._validation_labels

        if val_features is None or val_labels is None:
            raise ValueError(
                f"Validation data is missing for model '{self.model_id}'. "
                "Validation datasets are required when retraining triggers."
            )

        from driftguard.validation import validate_challenger_vs_champion

        return validate_challenger_vs_champion(
            champion_model=champion_model,
            challenger_model=challenger_model,
            val_features=val_features,
            val_labels=val_labels,
            threshold_pct=0.01,  # challenger must beat champion by ≥1%
        )

    # ------------------------------------------------------------------
    # API notification helpers
    # ------------------------------------------------------------------

    def _auth_headers(self) -> dict:
        """Return the ``X-API-Key`` header when the tracker has an API key, else no headers."""
        api_key = self.tracker.api_key
        return {"X-API-Key": api_key} if api_key else {}

    def _get_current_version(self) -> str:
        """
        Fetch the model's current version string from the API.

        Falls back to ``"1.0.0"`` when the request fails, the status is not
        200, or the response carries no usable ``version`` value.
        """
        try:
            with httpx.Client(timeout=5.0) as client:
                resp = client.get(
                    f"{self.api_url}/models/{self.model_id}",
                    headers=self._auth_headers(),
                )
            if resp.status_code == 200:
                # `or` also covers an explicit null/empty version in the payload.
                return resp.json().get("version") or "1.0.0"
        except Exception as exc:
            logger.debug(f"[{self.model_id}] Could not fetch current version: {exc}")
        return "1.0.0"

    def _bump_version(self, current_version: str) -> str:
        """
        Increment the last dot-separated segment of a version string.

        Returns ``"1.0.1"`` when the input is not a string or its last segment
        is not an integer.
        """
        try:
            parts = current_version.split(".")
            parts[-1] = str(int(parts[-1]) + 1)
        except (AttributeError, TypeError, ValueError):
            return "1.0.1"
        return ".".join(parts)

    def _notify_retrain_start(self, drift_score: float) -> Optional[int]:
        """
        POST to ``/retrain/{model_id}`` with ``source="sdk_callback"``.

        Returns the ``event_id`` from the JSON response on HTTP 200. Returns
        ``None`` on any other status or on a request error; both are logged
        as warnings and never raised.
        """
        try:
            with httpx.Client(timeout=5.0) as client:
                resp = client.post(
                    f"{self.api_url}/retrain/{self.model_id}",
                    json={
                        "drift_score": drift_score,
                        "triggered_by": "automatic",
                        "source": "sdk_callback",
                    },
                    headers=self._auth_headers(),
                )
            if resp.status_code == 200:
                event_id = resp.json().get("event_id")
                logger.debug(
                    f"[{self.model_id}] Retrain event created, event_id={event_id}"
                )
                return event_id
            logger.warning(
                f"[{self.model_id}] /retrain returned HTTP {resp.status_code}"
            )
        except Exception as exc:
            logger.warning(
                f"[{self.model_id}] Could not notify API of retrain start: {exc}"
            )
        return None

    def _post_completion(self, payload: dict, outcome: str) -> None:
        """
        POST ``payload`` to ``/retrain/{model_id}/complete``.

        ``outcome`` ("success" or "failure") is used only in log messages.
        Request errors and 4xx/5xx responses are logged as warnings so a
        rejected report is visible; nothing is raised.
        """
        try:
            with httpx.Client(timeout=10.0) as client:
                resp = client.post(
                    f"{self.api_url}/retrain/{self.model_id}/complete",
                    json=payload,
                    headers=self._auth_headers(),
                )
            if resp.status_code >= 400:
                logger.warning(
                    f"[{self.model_id}] /retrain/complete returned HTTP "
                    f"{resp.status_code} while reporting retrain {outcome}"
                )
        except Exception as exc:
            logger.warning(
                f"[{self.model_id}] Could not report retrain {outcome} to API: {exc}"
            )

    def _report_success(
        self,
        event_id: Optional[int],
        new_version: str,
        new_accuracy: float,
        old_accuracy: float,
    ) -> None:
        """POST callback pipeline results to ``/retrain/{model_id}/complete``."""
        self._post_completion(
            {
                "event_id": event_id,
                "validation_passed": True,
                "new_version": new_version,
                "new_accuracy": new_accuracy,
                "old_accuracy": old_accuracy,
                "error": None,
            },
            outcome="success",
        )

    def _report_failure(
        self,
        event_id: Optional[int],
        reason: str,
        chall_score: float,
    ) -> None:
        """Report a failed callback pipeline to ``/retrain/{model_id}/complete``."""
        self._post_completion(
            {
                "event_id": event_id,
                "validation_passed": False,
                "new_version": None,
                "new_accuracy": chall_score,
                "old_accuracy": None,
                "error": reason,
            },
            outcome="failure",
        )