driftguard-ai-sdk 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- driftguard_ai_sdk-1.0.0/PKG-INFO +39 -0
- driftguard_ai_sdk-1.0.0/README.md +197 -0
- driftguard_ai_sdk-1.0.0/driftguard/__init__.py +16 -0
- driftguard_ai_sdk-1.0.0/driftguard/alert.py +89 -0
- driftguard_ai_sdk-1.0.0/driftguard/callback_runner.py +342 -0
- driftguard_ai_sdk-1.0.0/driftguard/config.py +55 -0
- driftguard_ai_sdk-1.0.0/driftguard/drift_detector.py +308 -0
- driftguard_ai_sdk-1.0.0/driftguard/tracker.py +611 -0
- driftguard_ai_sdk-1.0.0/driftguard/validation.py +84 -0
- driftguard_ai_sdk-1.0.0/driftguard_ai_sdk.egg-info/PKG-INFO +39 -0
- driftguard_ai_sdk-1.0.0/driftguard_ai_sdk.egg-info/SOURCES.txt +28 -0
- driftguard_ai_sdk-1.0.0/driftguard_ai_sdk.egg-info/dependency_links.txt +1 -0
- driftguard_ai_sdk-1.0.0/driftguard_ai_sdk.egg-info/requires.txt +34 -0
- driftguard_ai_sdk-1.0.0/driftguard_ai_sdk.egg-info/top_level.txt +1 -0
- driftguard_ai_sdk-1.0.0/setup.cfg +4 -0
- driftguard_ai_sdk-1.0.0/setup.py +58 -0
- driftguard_ai_sdk-1.0.0/tests/test_artifact_storage.py +122 -0
- driftguard_ai_sdk-1.0.0/tests/test_audit_log.py +96 -0
- driftguard_ai_sdk-1.0.0/tests/test_auth.py +81 -0
- driftguard_ai_sdk-1.0.0/tests/test_canary_router.py +96 -0
- driftguard_ai_sdk-1.0.0/tests/test_champion_challenger.py +162 -0
- driftguard_ai_sdk-1.0.0/tests/test_critical_fixes.py +409 -0
- driftguard_ai_sdk-1.0.0/tests/test_drift_detector.py +122 -0
- driftguard_ai_sdk-1.0.0/tests/test_model_registry.py +115 -0
- driftguard_ai_sdk-1.0.0/tests/test_ownership.py +104 -0
- driftguard_ai_sdk-1.0.0/tests/test_projects.py +73 -0
- driftguard_ai_sdk-1.0.0/tests/test_retrain_pipeline.py +100 -0
- driftguard_ai_sdk-1.0.0/tests/test_retrainer_callback.py +169 -0
- driftguard_ai_sdk-1.0.0/tests/test_rollback_persistence.py +126 -0
- driftguard_ai_sdk-1.0.0/tests/test_threshold_engine.py +174 -0
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: driftguard-ai-sdk
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: DriftGuard — Autonomous Model Health Platform
|
|
5
|
+
Author: DriftGuard Team
|
|
6
|
+
Requires-Python: >=3.9
|
|
7
|
+
Requires-Dist: numpy>=1.24
|
|
8
|
+
Requires-Dist: httpx>=0.24
|
|
9
|
+
Requires-Dist: python-dotenv>=1.0
|
|
10
|
+
Requires-Dist: river>=0.21.2
|
|
11
|
+
Requires-Dist: scikit-learn>=1.3
|
|
12
|
+
Provides-Extra: server
|
|
13
|
+
Requires-Dist: fastapi==0.111.0; extra == "server"
|
|
14
|
+
Requires-Dist: uvicorn==0.28.1; extra == "server"
|
|
15
|
+
Requires-Dist: pydantic==1.10.13; extra == "server"
|
|
16
|
+
Requires-Dist: prometheus-client==0.20.0; extra == "server"
|
|
17
|
+
Requires-Dist: pandas==2.2.2; extra == "server"
|
|
18
|
+
Requires-Dist: redis==5.0.4; extra == "server"
|
|
19
|
+
Requires-Dist: psycopg2-binary==2.9.9; extra == "server"
|
|
20
|
+
Requires-Dist: sqlalchemy==2.0.30; extra == "server"
|
|
21
|
+
Provides-Extra: validation
|
|
22
|
+
Requires-Dist: great-expectations==0.18.15; extra == "validation"
|
|
23
|
+
Requires-Dist: sqlalchemy==1.4.41; extra == "validation"
|
|
24
|
+
Provides-Extra: evidently
|
|
25
|
+
Requires-Dist: evidently==0.4.30; extra == "evidently"
|
|
26
|
+
Provides-Extra: serving
|
|
27
|
+
Requires-Dist: bentoml==1.2.0; extra == "serving"
|
|
28
|
+
Requires-Dist: ray[serve]==2.10.0; extra == "serving"
|
|
29
|
+
Provides-Extra: pipeline
|
|
30
|
+
Requires-Dist: prefect==2.19.0; extra == "pipeline"
|
|
31
|
+
Requires-Dist: zenml==0.57.0; extra == "pipeline"
|
|
32
|
+
Provides-Extra: test
|
|
33
|
+
Requires-Dist: pytest==8.2.0; extra == "test"
|
|
34
|
+
Requires-Dist: pytest-asyncio==0.23.6; extra == "test"
|
|
35
|
+
Dynamic: author
|
|
36
|
+
Dynamic: provides-extra
|
|
37
|
+
Dynamic: requires-dist
|
|
38
|
+
Dynamic: requires-python
|
|
39
|
+
Dynamic: summary
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
# 🛡️ DriftGuard — Autonomous Model Health Platform
|
|
2
|
+
|
|
3
|
+
DriftGuard is a production-grade, self-healing MLOps platform designed to detect data drift, concept drift, and model degradation in real time, automatically trigger validated retraining pipelines, and progressively roll out champion models via canary routers.
|
|
4
|
+
|
|
5
|
+
---
|
|
6
|
+
|
|
7
|
+
## 🏗️ Architecture Design
|
|
8
|
+
|
|
9
|
+
```
|
|
10
|
+
+---------------------------------------+
|
|
11
|
+
| Client Application |
|
|
12
|
+
+-------------------+-------------------+
|
|
13
|
+
|
|
|
14
|
+
(Predict Telemetry)
|
|
15
|
+
v
|
|
16
|
+
+-------------------+-------------------+
|
|
17
|
+
| DriftGuard SDK |
|
|
18
|
+
| - Wrapper pattern intercept |
|
|
19
|
+
| - River ADWIN concept drift checks |
|
|
20
|
+
+-------------------+-------------------+
|
|
21
|
+
|
|
|
22
|
+
(HTTP Telemetry)
|
|
23
|
+
v
|
|
24
|
+
+-------------------+-------------------+
|
|
25
|
+
| DriftGuard FastAPI Core API | <---+ NextJS Dashboard (:3000)
|
|
26
|
+
| - /register, /predict, /drift | <---+ Grafana (:3001)
|
|
27
|
+
| - Prom metrics /metrics (:8000) |
|
|
28
|
+
+-------------------+-------------------+
|
|
29
|
+
|
|
|
30
|
+
(SLA Drift Breach Trigger)
|
|
31
|
+
v
|
|
32
|
+
+-------------------+-------------------+
|
|
33
|
+
| Prefect Orchestration Server |
|
|
34
|
+
| - drift_detection_flow (:4200) |
|
|
35
|
+
+-------------------+-------------------+
|
|
36
|
+
|
|
|
37
|
+
(Runs steps)
|
|
38
|
+
v
|
|
39
|
+
+-------------------+-------------------+
|
|
40
|
+
| ZenML Step Training Pipelines |
|
|
41
|
+
| - Step 1: Great Expectations Validate|
|
|
42
|
+
| - Step 2: Feast Feature Store Check |
|
|
43
|
+
| - Step 3: Train & Track (MLflow/W&B) |
|
|
44
|
+
| - Step 4: Validate (>1% boost check) |
|
|
45
|
+
| - Step 5: Canary Progressive Deploy |
|
|
46
|
+
| - Step 6: Immutable JSON Ledger & PDF|
|
|
47
|
+
+-------------------+-------------------+
|
|
48
|
+
|
|
|
49
|
+
(Progressive Split Promotes)
|
|
50
|
+
v
|
|
51
|
+
+-------------------+-------------------+
|
|
52
|
+
| BentoML & Ray Serve Fleet |
|
|
53
|
+
| - canary_router: 10%->100% |
|
|
54
|
+
| - SLA Monitoring & Rollbacks |
|
|
55
|
+
+---------------------------------------+
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
---
|
|
59
|
+
|
|
60
|
+
## ⚡ Prerequisites
|
|
61
|
+
|
|
62
|
+
To run and configure DriftGuard, ensure the following are installed:
|
|
63
|
+
- **Python 3.11** for the full platform stack (Ray and BentoML have incomplete 3.12 support); the SDK package itself declares `Requires-Python: >=3.9`.
|
|
64
|
+
- **Docker & Docker Compose** (for multi-service orchestration).
|
|
65
|
+
- **kubectl & Helm** (optional, for Kubernetes deployments).
|
|
66
|
+
- **HashiCorp Terraform** (optional, for AWS cloud provisioning).
|
|
67
|
+
|
|
68
|
+
---
|
|
69
|
+
|
|
70
|
+
## 🚀 Quick Start in 5 Lines
|
|
71
|
+
|
|
72
|
+
Wrap any scikit-learn, PyTorch, or HuggingFace model with DriftGuard SDK to track predictions, compute concept drift, and initiate auto-healing:
|
|
73
|
+
|
|
74
|
+
```python
|
|
75
|
+
from driftguard import DriftGuard
|
|
76
|
+
|
|
77
|
+
# 1. Initialize DriftGuard
|
|
78
|
+
dg = DriftGuard(model_id="fraud-detector-v1", api_url="http://localhost:8000", drift_threshold=0.15, auto_retrain=True)
|
|
79
|
+
|
|
80
|
+
# 2. Wrap model seamlessly
|
|
81
|
+
model = dg.wrap(trained_sklearn_model)
|
|
82
|
+
|
|
83
|
+
# 3. Predict normally - DriftGuard tracks inputs, outputs, and triggers retrain on drift!
|
|
84
|
+
prediction = model.predict(features)
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
---
|
|
88
|
+
|
|
89
|
+
## 📦 Installation & Setup
|
|
90
|
+
|
|
91
|
+
### 1. Local Package Installation
|
|
92
|
+
Clone the repository and install the DriftGuard package locally:
|
|
93
|
+
```bash
|
|
94
|
+
git clone https://github.com/your-repo/DriftGuard.git
|
|
95
|
+
cd DriftGuard
|
|
96
|
+
pip install -e .
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
To install the validation pipeline dependencies (Great Expectations + SQLAlchemy 1.4 pin) separately:
|
|
100
|
+
```bash
|
|
101
|
+
pip install -e ".[validation]"
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
### 2. Launch Platform Services
|
|
105
|
+
Spin up the entire 8-service DriftGuard stack (FastAPI, NextJS dashboard, MLflow, Prefect, Postgres, Redis, Prometheus, Grafana) instantly:
|
|
106
|
+
```bash
|
|
107
|
+
docker-compose -f infra/docker-compose.yml up --build -d
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
---
|
|
111
|
+
|
|
112
|
+
## ⚙️ SDK Configuration Parameters
|
|
113
|
+
|
|
114
|
+
The `DriftGuard` class accepts the following parameters:
|
|
115
|
+
|
|
116
|
+
| Parameter | Type | Default | Description |
|
|
117
|
+
|---|---|---|---|
|
|
118
|
+
| `model_id` | `str` | *Required* | Unique name identifier of the model. |
|
|
119
|
+
| `api_url` | `str` | `http://localhost:8000` | Gateway endpoint of the DriftGuard API. |
|
|
120
|
+
| `drift_threshold` | `float` | `0.15` | Limit before concept drift alert and retraining triggers. |
|
|
121
|
+
| `auto_retrain` | `bool` | `True` | Automatically triggers retraining flow on threshold breach. |
|
|
122
|
+
|
|
123
|
+
---
|
|
124
|
+
|
|
125
|
+
## 📡 API Gateway Documentation
|
|
126
|
+
|
|
127
|
+
DriftGuard Core API runs on port `8000`. Key REST endpoints include:
|
|
128
|
+
|
|
129
|
+
### `POST /register`
|
|
130
|
+
Registers a model for monitoring.
|
|
131
|
+
- **Request Body:**
|
|
132
|
+
```json
|
|
133
|
+
{
|
|
134
|
+
"model_id": "fraud-detector-v1",
|
|
135
|
+
"drift_threshold": 0.15,
|
|
136
|
+
"reference_data_path": "./data/baseline.parquet",
|
|
137
|
+
"features": ["amount", "location_score", "velocity"]
|
|
138
|
+
}
|
|
139
|
+
```
|
|
140
|
+
- **Response:** `{"status": "registered", "model_id": "fraud-detector-v1"}`
|
|
141
|
+
|
|
142
|
+
### `POST /predict/{model_id}`
|
|
143
|
+
SDK telemetry gateway recording prediction details and updating gauges.
|
|
144
|
+
- **Request Body:**
|
|
145
|
+
```json
|
|
146
|
+
{
|
|
147
|
+
"features": [1.2, 0.4, 9.8],
|
|
148
|
+
"prediction": [1.0],
|
|
149
|
+
"drift_score": 0.08
|
|
150
|
+
}
|
|
151
|
+
```
|
|
152
|
+
|
|
153
|
+
### `GET /drift/{model_id}`
|
|
154
|
+
Fetches the last 100 historical drift scores for rendering in Recharts charts.
|
|
155
|
+
|
|
156
|
+
### `POST /retrain/{model_id}`
|
|
157
|
+
Manually triggers the background retraining pipeline flow.
|
|
158
|
+
|
|
159
|
+
### `GET /metrics`
|
|
160
|
+
Exposes system health gauges for Prometheus scrapers in OpenMetrics format.
|
|
161
|
+
|
|
162
|
+
---
|
|
163
|
+
|
|
164
|
+
## 🖥️ Dashboards & Observability Portals
|
|
165
|
+
|
|
166
|
+
Once the docker services are online:
|
|
167
|
+
1. **NextJS UI Dashboard:** Navigate to [http://localhost:3000](http://localhost:3000) to review the model list, drift histories, vertical retraining timelines, and searchable audit logs.
|
|
168
|
+
2. **MLflow Registry:** Access [http://localhost:5000](http://localhost:5000) to review run parameters, artifacts (confusion matrix plots), and staging/production champions.
|
|
169
|
+
3. **Prefect Dashboard:** Access [http://localhost:4200](http://localhost:4200) to inspect flow execution history.
|
|
170
|
+
4. **Grafana Dashboards:** Open [http://localhost:3001](http://localhost:3001) (User: `admin` | Pass: `admin`) to view pre-provisioned telemetry panels covering predictions, drift rates, accuracy levels, and latency quantiles.
|
|
171
|
+
|
|
172
|
+
---
|
|
173
|
+
|
|
174
|
+
## 🧪 Running Unit Tests
|
|
175
|
+
|
|
176
|
+
Run all unit tests verifying ADWIN concept detectors, Great Expectations validators, canary routing splits, emergency rollbacks, and cryptographic audit log chains:
|
|
177
|
+
```bash
|
|
178
|
+
pytest tests/ -v
|
|
179
|
+
```
|
|
180
|
+
|
|
181
|
+
---
|
|
182
|
+
|
|
183
|
+
## ☁️ Deploying to AWS Cloud (Terraform)
|
|
184
|
+
|
|
185
|
+
Deploy DriftGuard core infrastructure to Amazon Web Services:
|
|
186
|
+
```bash
|
|
187
|
+
cd infra/terraform
|
|
188
|
+
terraform init
|
|
189
|
+
terraform plan
|
|
190
|
+
terraform apply -var="db_password=SecurePasswordPass22!"
|
|
191
|
+
```
|
|
192
|
+
This provisions:
|
|
193
|
+
- Amazon EKS cluster for progressive Kubernetes canary Rollouts.
|
|
194
|
+
- Amazon RDS PostgreSQL database for MLflow, Prefect, and platform metadata.
|
|
195
|
+
- Amazon ElastiCache Redis for real-time online access to Feast features.
|
|
196
|
+
- Amazon S3 bucket for artifacts.
|
|
197
|
+
- Amazon ECR for Docker images.
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DriftGuard — public package entry point.
|
|
3
|
+
|
|
4
|
+
Users import from here:
|
|
5
|
+
from driftguard import DriftGuard
|
|
6
|
+
"""
|
|
7
|
+
from .tracker import DriftGuard, DriftGuardModelWrapper
|
|
8
|
+
from .callback_runner import RetrainerCallbackRunner
|
|
9
|
+
from .config import settings
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"DriftGuard",
|
|
13
|
+
"DriftGuardModelWrapper",
|
|
14
|
+
"RetrainerCallbackRunner",
|
|
15
|
+
"settings",
|
|
16
|
+
]
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DriftGuard Alerting & Notification Client.
|
|
3
|
+
Handles distribution of system alert updates to console logs and webhook endpoints like Slack.
|
|
4
|
+
"""
|
|
5
|
+
import httpx
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Dict, Any, Optional
|
|
8
|
+
|
|
9
|
+
from driftguard.config import settings
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger("DriftGuard.Alert")
|
|
12
|
+
|
|
13
|
+
def send_alert(
    event_type: str,
    message: str,
    details: Optional[Dict[str, Any]] = None,
    webhook_url: Optional[str] = None
) -> bool:
    """
    Log an alert and, when a webhook is configured, post it to Slack.

    The alert is always written to the module logger first. Events of type
    'drift_detected', 'validation_failed' and 'rollback' are logged at ERROR
    level; everything else is logged at INFO level.

    Args:
        event_type: Category of the alert (e.g., 'drift_detected', 'retrain_triggered', 'validation_failed', 'model_promoted').
        message: Readable summary message string.
        details: Optional mapping of extra properties describing the event.
        webhook_url: Optional override of ``settings.SLACK_WEBHOOK_URL``.

    Returns:
        True if no webhook is configured (log-only mode) or the webhook
        answered with a 2xx status; False if the webhook answered with any
        other status or the request raised.
    """
    # Slack Block Kit rejects the whole message when a text object is too
    # long: header text is capped at 150 characters, section text at 3000.
    header_limit = 150
    section_limit = 3000

    def _clip(text: str, limit: int) -> str:
        """Truncate *text* to *limit* characters, marking the cut with an ellipsis."""
        return text if len(text) <= limit else text[: limit - 1] + "…"

    payload_details = details or {}
    url = webhook_url or settings.SLACK_WEBHOOK_URL

    # 1. Standard Logger Print
    log_msg = f"[ALERT - {event_type.upper()}] {message} | Details: {payload_details}"
    if event_type in ("drift_detected", "validation_failed", "rollback"):
        logger.error(log_msg)
    else:
        logger.info(log_msg)

    # 2. Slack Webhook Dispatch
    # A missing URL or the placeholder "mock_webhook_url" means log-only mode.
    if not url or "mock_webhook_url" in url:
        logger.debug("No active Slack webhook configured. Skipping network alert.")
        return True

    try:
        title = f"🛡️ DriftGuard Alert: {event_type.replace('_', ' ').title()}"
        blocks = [
            {
                "type": "header",
                "text": {
                    "type": "plain_text",
                    "text": _clip(title, header_limit),
                    "emoji": True,
                },
            },
            {
                "type": "section",
                "text": {
                    "type": "mrkdwn",
                    "text": _clip(f"*Message:*\n{message}", section_limit),
                },
            },
        ]

        if payload_details:
            details_str = "\n".join(f"• *{k}:* {v}" for k, v in payload_details.items())
            blocks.append({
                "type": "section",
                "text": {
                    "type": "mrkdwn",
                    "text": _clip(f"*Metadata:*\n{details_str}", section_limit),
                },
            })

        slack_payload = {"blocks": blocks}

        with httpx.Client(timeout=3.0) as client:
            resp = client.post(url, json=slack_payload)

        # Any 2xx answer means the endpoint accepted the alert.
        if 200 <= resp.status_code < 300:
            logger.info("Successfully posted alert to Slack channel.")
            return True

        logger.error(f"Slack webhook endpoint returned HTTP {resp.status_code}: {resp.text}")
        return False

    except Exception as e:
        # Alerting is best-effort: a failed notification must never break the caller.
        logger.error(f"Failed to transmit Slack webhook: {e}")
        return False
|
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DriftGuard Retrainer Callback Runner.
|
|
3
|
+
|
|
4
|
+
Executes user-registered ``@dg.retrainer`` callbacks entirely inside the
|
|
5
|
+
SDK process — where the user's data, credentials, and environment live.
|
|
6
|
+
|
|
7
|
+
Flow
|
|
8
|
+
----
|
|
9
|
+
1. Notify API: POST /retrain/{model_id} with source="sdk_callback"
|
|
10
|
+
→ Server creates a DBRetrainingEvent record and returns event_id.
|
|
11
|
+
→ Server does NOT spawn its own background pipeline.
|
|
12
|
+
2. Invoke the user's retrainer function → get challenger model.
|
|
13
|
+
3. Validate challenger vs champion using dg.set_validation_data().
|
|
14
|
+
4. Report outcome: POST /retrain/{model_id}/complete.
|
|
15
|
+
→ Server updates model version, accuracy, audit log, Prometheus, Slack.
|
|
16
|
+
5. Update tracker._champion_model to the new champion.
|
|
17
|
+
6. Reset tracker.retraining_triggered so future drift events can fire.
|
|
18
|
+
|
|
19
|
+
NOTE: Production telemetry stored in dg_predictions is never touched.
|
|
20
|
+
Training data comes exclusively from the user's registered callback.
|
|
21
|
+
"""
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import logging
|
|
25
|
+
from typing import TYPE_CHECKING, Any, Optional, Tuple
|
|
26
|
+
|
|
27
|
+
import httpx
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from driftguard.tracker import DriftGuard
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger("DriftGuard.CallbackRunner")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class RetrainerCallbackRunner:
    """
    Orchestrates the full local retraining pipeline for a registered callback.

    Parameters
    ----------
    tracker:
        The ``DriftGuard`` instance that owns the callback. The runner reads
        its ``api_url``, ``model_id``, ``api_key``, ``project_id``,
        ``_retrainer_fn``, ``_champion_model``, ``_validation_features`` and
        ``_validation_labels`` attributes, and writes ``_champion_model`` and
        ``retraining_triggered``.
    """

    def __init__(self, tracker: "DriftGuard") -> None:
        self.tracker = tracker
        self.api_url = tracker.api_url
        self.model_id = tracker.model_id

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def run(self, drift_score: float) -> bool:
        """
        Execute the full callback-based retraining pipeline.

        Any exception raised by a pipeline step is logged, reported to the
        API as a failed run, and converted into a ``False`` return value.
        ``tracker.retraining_triggered`` is reset to ``False`` on every exit
        path.

        Parameters
        ----------
        drift_score:
            Drift score that triggered this run; forwarded to the API.

        Returns
        -------
        bool
            ``True`` if the challenger was promoted; ``False`` otherwise.
        """
        # Defined before the ``try`` so the failure handler can always use them.
        event_id: Optional[int] = None
        current_version = "1.0.0"

        try:
            logger.info(
                f"[{self.model_id}] Callback retraining pipeline started "
                f"(drift_score={drift_score:.4f})"
            )

            # Step 1 — Record event on the server (no background task spawned)
            current_version = self._get_current_version()
            event_id = self._notify_retrain_start(drift_score)

            # Step 2 — Invoke user-registered callback
            challenger_model = self._invoke_callback()

            # Step 3 — Validate challenger against champion
            validation_passed, champ_score, chall_score = self._validate(challenger_model)
            logger.info(
                f"[{self.model_id}] Validation results: champion={champ_score}, "
                f"challenger={chall_score}, passed={validation_passed}"
            )

            if not validation_passed:
                reason = (
                    f"Challenger accuracy {chall_score:.4f} did not beat "
                    f"champion {champ_score:.4f} by ≥1%."
                )
                logger.warning(f"[{self.model_id}] {reason}")
                self._report_failure(event_id=event_id, reason=reason, chall_score=chall_score)
                return False

            # Step 4 — Promote challenger
            new_version = self._bump_version(current_version)
            logger.debug(f"[{self.model_id}] Promotion stage started, new version = {new_version}")

            # Persist challenger model before promotion (best-effort)
            self._persist_challenger(challenger_model, new_version)

            self._report_success(
                event_id=event_id,
                new_version=new_version,
                new_accuracy=chall_score,
                old_accuracy=champ_score,
            )

            # Step 5 — Update local champion reference so next comparison is correct
            self.tracker._champion_model = challenger_model

            logger.info(
                f"[{self.model_id}] Challenger promoted: "
                f"{current_version} → {new_version} "
                f"(accuracy {champ_score:.4f} → {chall_score:.4f})"
            )
            return True

        except Exception as exc:
            # exc_info=True already records the traceback through logging.
            logger.error(
                f"[{self.model_id}] Callback retraining pipeline failed: {exc}",
                exc_info=True,
            )
            self._report_failure(event_id=event_id, reason=str(exc), chall_score=0.0)
            return False

        finally:
            # Always reset so future drift events can trigger a new run
            self.tracker.retraining_triggered = False

    # ------------------------------------------------------------------
    # Step implementations
    # ------------------------------------------------------------------

    def _invoke_callback(self) -> Any:
        """
        Call the user's ``@dg.retrainer`` function and return the trained model.

        Raises
        ------
        RuntimeError
            If no callback is registered, or the callback itself raises.
        ValueError
            If the callback returns ``None``.
        TypeError
            If the returned object has no ``predict`` attribute and is not
            callable.
        """
        fn = self.tracker._retrainer_fn
        if fn is None:
            raise RuntimeError(
                f"No retrainer callback is registered for model '{self.model_id}'."
            )

        # Callables such as functools.partial objects have no __name__.
        fn_name = getattr(fn, "__name__", repr(fn))
        logger.info(f"[{self.model_id}] Invoking retrainer callback: {fn_name}()")
        try:
            model = fn()
        except Exception as exc:
            raise RuntimeError(
                f"Retrainer callback '{fn_name}' raised an exception: {exc}"
            ) from exc

        if model is None:
            raise ValueError(
                f"Retrainer callback '{fn_name}' returned None. "
                "The function must return a trained model object."
            )
        if not (hasattr(model, "predict") or callable(model)):
            raise TypeError(
                f"Retrainer callback '{fn_name}' returned an invalid model type '{type(model).__name__}'. "
                "The model must have a 'predict' method or be callable."
            )
        return model

    def _validate(self, challenger_model: Any) -> Tuple[bool, float, float]:
        """
        Compare challenger against the current champion.

        If the tracker holds no champion model, the challenger is accepted
        directly as the first known-good version and the placeholder scores
        ``(0.0, 1.0)`` are returned.

        Returns
        -------
        (validation_passed, champion_score, challenger_score)

        Raises
        ------
        ValueError
            If a champion exists but validation features or labels are
            missing on the tracker.
        """
        champion_model = self.tracker._champion_model

        if champion_model is None:
            logger.info(
                f"[{self.model_id}] No champion registered via dg.set_champion(). "
                "Promoting challenger as first champion."
            )
            return True, 0.0, 1.0

        val_features = self.tracker._validation_features
        val_labels = self.tracker._validation_labels

        if val_features is None or val_labels is None:
            raise ValueError(
                f"Validation data is missing for model '{self.model_id}'. "
                "Validation datasets are required when retraining triggers."
            )

        from driftguard.validation import validate_challenger_vs_champion

        return validate_challenger_vs_champion(
            champion_model=champion_model,
            challenger_model=challenger_model,
            val_features=val_features,
            val_labels=val_labels,
            threshold_pct=0.01,  # challenger must beat champion by ≥1%
        )

    def _persist_challenger(self, challenger_model: Any, new_version: str) -> None:
        """
        Best-effort dump of the challenger to the local artifact directory.

        Nothing is written when the tracker has no ``project_id``. Failures
        are logged as warnings and never abort the promotion.
        """
        if not self.tracker.project_id:
            return
        try:
            import os

            import joblib

            from driftguard.config import settings as _settings

            dir_path = os.path.join(
                _settings.ARTIFACT_ROOT,
                str(self.tracker.project_id),
                self.model_id,
            )
            os.makedirs(dir_path, exist_ok=True)
            file_path = os.path.join(dir_path, f"version_{new_version}.pkl")
            joblib.dump(challenger_model, file_path)
            logger.info(
                f"[{self.model_id}] Persisted challenger model before promotion to {file_path}"
            )
        except Exception as e:
            logger.warning(f"[{self.model_id}] Failed to persist challenger model: {e}")

    # ------------------------------------------------------------------
    # API notification helpers
    # ------------------------------------------------------------------

    def _auth_headers(self) -> dict:
        """Return the ``X-API-Key`` header when the tracker has an API key, else ``{}``."""
        api_key = self.tracker.api_key
        return {"X-API-Key": api_key} if api_key else {}

    def _get_current_version(self) -> str:
        """
        Fetch the model's current version string from the API.

        Falls back to ``"1.0.0"`` when the request fails, the status is not
        200, or the response carries no usable ``version`` value.
        """
        try:
            with httpx.Client(timeout=5.0) as client:
                resp = client.get(
                    f"{self.api_url}/models/{self.model_id}",
                    headers=self._auth_headers(),
                )
            if resp.status_code == 200:
                # ``or`` also covers an explicit null version in the payload.
                return resp.json().get("version") or "1.0.0"
        except Exception as exc:
            logger.debug(f"[{self.model_id}] Could not fetch current version: {exc}")
        return "1.0.0"

    def _bump_version(self, current_version: str) -> str:
        """
        Increment the last dot-separated segment of a version string.

        Returns ``"1.0.1"`` when the last segment is not an integer or the
        input is not a string.
        """
        try:
            parts = current_version.split(".")
            parts[-1] = str(int(parts[-1]) + 1)
            return ".".join(parts)
        except (AttributeError, TypeError, ValueError):
            return "1.0.1"

    def _notify_retrain_start(self, drift_score: float) -> Optional[int]:
        """
        POST to ``/retrain/{model_id}`` with ``source="sdk_callback"``.

        Returns the ``event_id`` from the JSON response on HTTP 200, or
        ``None`` when the request fails or answers with another status.
        """
        try:
            with httpx.Client(timeout=5.0) as client:
                resp = client.post(
                    f"{self.api_url}/retrain/{self.model_id}",
                    json={
                        "drift_score": drift_score,
                        "triggered_by": "automatic",
                        "source": "sdk_callback",
                    },
                    headers=self._auth_headers(),
                )
            if resp.status_code == 200:
                event_id = resp.json().get("event_id")
                logger.debug(
                    f"[{self.model_id}] Retrain event created, event_id={event_id}"
                )
                return event_id
            logger.warning(
                f"[{self.model_id}] /retrain returned HTTP {resp.status_code}"
            )
        except Exception as exc:
            logger.warning(
                f"[{self.model_id}] Could not notify API of retrain start: {exc}"
            )
        return None

    def _post_completion(self, payload: dict, outcome: str) -> None:
        """
        POST *payload* to ``/retrain/{model_id}/complete`` (best-effort).

        Transport errors and error statuses are logged as warnings; nothing
        is raised. *outcome* ("success" or "failure") only labels the logs.
        """
        try:
            with httpx.Client(timeout=10.0) as client:
                resp = client.post(
                    f"{self.api_url}/retrain/{self.model_id}/complete",
                    json=payload,
                    headers=self._auth_headers(),
                )
            if resp.status_code >= 400:
                logger.warning(
                    f"[{self.model_id}] /retrain/complete returned HTTP "
                    f"{resp.status_code} while reporting retrain {outcome}"
                )
        except Exception as exc:
            logger.warning(
                f"[{self.model_id}] Could not report retrain {outcome} to API: {exc}"
            )

    def _report_success(
        self,
        event_id: Optional[int],
        new_version: str,
        new_accuracy: float,
        old_accuracy: float,
    ) -> None:
        """POST callback pipeline results to ``/retrain/{model_id}/complete``."""
        self._post_completion(
            {
                "event_id": event_id,
                "validation_passed": True,
                "new_version": new_version,
                "new_accuracy": new_accuracy,
                "old_accuracy": old_accuracy,
                "error": None,
            },
            outcome="success",
        )

    def _report_failure(
        self,
        event_id: Optional[int],
        reason: str,
        chall_score: float,
    ) -> None:
        """Report a failed callback pipeline to ``/retrain/{model_id}/complete``."""
        self._post_completion(
            {
                "event_id": event_id,
                "validation_passed": False,
                "new_version": None,
                "new_accuracy": chall_score,
                "old_accuracy": None,
                "error": reason,
            },
            outcome="failure",
        )
|