driftwatch 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- driftwatch/__init__.py +22 -0
- driftwatch/cli/__init__.py +5 -0
- driftwatch/cli/main.py +274 -0
- driftwatch/core/__init__.py +6 -0
- driftwatch/core/monitor.py +153 -0
- driftwatch/core/report.py +162 -0
- driftwatch/detectors/__init__.py +6 -0
- driftwatch/detectors/base.py +67 -0
- driftwatch/detectors/categorical.py +145 -0
- driftwatch/detectors/numerical.py +198 -0
- driftwatch/detectors/registry.py +71 -0
- driftwatch/integrations/__init__.py +5 -0
- driftwatch/integrations/alerting.py +211 -0
- driftwatch/integrations/fastapi.py +297 -0
- driftwatch/py.typed +1 -0
- driftwatch/simulation/__init__.py +1 -0
- driftwatch-0.2.0.dist-info/METADATA +144 -0
- driftwatch-0.2.0.dist-info/RECORD +22 -0
- driftwatch-0.2.0.dist-info/WHEEL +5 -0
- driftwatch-0.2.0.dist-info/entry_points.txt +2 -0
- driftwatch-0.2.0.dist-info/licenses/LICENSE +21 -0
- driftwatch-0.2.0.dist-info/top_level.txt +1 -0
driftwatch/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DriftWatch - Lightweight ML drift monitoring, built for real-world pipelines.
|
|
3
|
+
|
|
4
|
+
DriftWatch is an open-source Python library for detecting data drift and
|
|
5
|
+
model drift in machine learning systems.
|
|
6
|
+
|
|
7
|
+
Basic Usage:
|
|
8
|
+
>>> from driftwatch import Monitor
|
|
9
|
+
>>> monitor = Monitor(reference_data=train_df, features=["age", "income"])
|
|
10
|
+
>>> report = monitor.check(production_df)
|
|
11
|
+
>>> print(report.summary())
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from driftwatch.core.monitor import Monitor
|
|
15
|
+
from driftwatch.core.report import DriftReport
|
|
16
|
+
|
|
17
|
+
# Package version string, also re-exported through ``__all__``.
__version__ = "0.2.0"

# Public API of the top-level ``driftwatch`` package.
__all__ = ["DriftReport", "Monitor", "__version__"]
|
driftwatch/cli/main.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
"""Main CLI application for DriftWatch.
|
|
2
|
+
|
|
3
|
+
Provides commands for drift checking and reporting.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from pathlib import Path # noqa: TC003
|
|
10
|
+
from typing import Annotated
|
|
11
|
+
|
|
12
|
+
import pandas as pd
|
|
13
|
+
import typer
|
|
14
|
+
from rich.console import Console
|
|
15
|
+
from rich.table import Table
|
|
16
|
+
|
|
17
|
+
from driftwatch import Monitor
|
|
18
|
+
from driftwatch.core.report import DriftReport, DriftStatus
|
|
19
|
+
|
|
20
|
+
# Root Typer application; the ``check`` and ``report`` commands register on it.
app = typer.Typer(
    add_completion=False,
    help="🔍 DriftWatch - ML Model Drift Detection",
    name="driftwatch",
)

# Shared Rich console used by every command and display helper in this module.
console = Console()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def load_dataframe(path: Path) -> pd.DataFrame:
    """Read a tabular dataset from disk, choosing the reader by file suffix.

    The suffix is compared case-insensitively: ``.csv`` is read with
    ``pandas.read_csv``; ``.parquet`` and ``.pq`` with ``pandas.read_parquet``.

    Args:
        path: Location of the dataset file.

    Returns:
        The file contents as a pandas DataFrame.

    Raises:
        typer.BadParameter: If the suffix is none of the supported ones.
    """
    extension = path.suffix.lower()

    # Suffix -> pandas reader lookup; anything missing here is unsupported.
    readers = {
        ".csv": pd.read_csv,
        ".parquet": pd.read_parquet,
        ".pq": pd.read_parquet,
    }
    reader = readers.get(extension)
    if reader is None:
        raise typer.BadParameter(
            f"Unsupported file format: {path.suffix}. Use .csv or .parquet"
        )
    return reader(path)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@app.command()
def check(
    ref: Annotated[
        Path,
        typer.Option(
            "--ref",
            "-r",
            help="Path to reference dataset (CSV or Parquet)",
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ],
    prod: Annotated[
        Path,
        typer.Option(
            "--prod",
            "-p",
            help="Path to production dataset (CSV or Parquet)",
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ],
    threshold_psi: Annotated[
        float, typer.Option("--threshold-psi", help="PSI threshold (default: 0.2)")
    ] = 0.2,
    threshold_ks: Annotated[
        float,
        typer.Option("--threshold-ks", help="KS p-value threshold (default: 0.05)"),
    ] = 0.05,
    threshold_chi2: Annotated[
        float,
        typer.Option(
            "--threshold-chi2", help="Chi-squared p-value threshold (default: 0.05)"
        ),
    ] = 0.05,
    output: Annotated[
        Path | None,
        typer.Option("--output", "-o", help="Save report to JSON file"),
    ] = None,
) -> None:
    """Check for drift between reference and production datasets.

    Example:
        driftwatch check --ref train.csv --prod prod.csv
        driftwatch check -r train.parquet -p prod.parquet --threshold-psi 0.15
    """
    # NOTE: the docstring above is surfaced by Typer as the command's help
    # text, so its wording is part of the CLI's user-facing output.
    console.print("[bold blue]🔍 DriftWatch - Drift Detection[/bold blue]\n")

    # Read both datasets, echoing their size as we go.
    console.print(f"Loading reference data from [cyan]{ref}[/cyan]...")
    reference_frame = load_dataframe(ref)
    console.print(
        f"✓ Loaded {len(reference_frame):,} samples with {len(reference_frame.columns)} features\n"
    )

    console.print(f"Loading production data from [cyan]{prod}[/cyan]...")
    production_frame = load_dataframe(prod)
    console.print(
        f"✓ Loaded {len(production_frame):,} samples with {len(production_frame.columns)} features\n"
    )

    # Build the monitor with the thresholds supplied on the command line.
    console.print("Initializing monitor...")
    cli_thresholds = {
        "psi": threshold_psi,
        "ks_pvalue": threshold_ks,
        "chi2_pvalue": threshold_chi2,
    }
    monitor = Monitor(reference_data=reference_frame, thresholds=cli_thresholds)

    # Compare production against reference and render the outcome.
    console.print("Running drift detection...\n")
    drift_report = monitor.check(production_frame)
    _display_report(drift_report)

    # Optionally persist the report as JSON.
    if output is not None:
        serialized = json.dumps(drift_report.to_dict(), indent=2)
        output.write_text(serialized, encoding="utf-8")
        console.print(f"\n✓ Report saved to [cyan]{output}[/cyan]")

    # Process exit code mirrors severity: CRITICAL -> 2, WARNING -> 1, else 0.
    exit_codes = {DriftStatus.CRITICAL: 2, DriftStatus.WARNING: 1}
    raise typer.Exit(code=exit_codes.get(drift_report.status, 0))
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@app.command()
def report(
    input_file: Annotated[
        Path,
        typer.Argument(
            help="Path to drift report JSON file",
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ],
    format: Annotated[
        str,
        typer.Option(
            "--format",
            "-f",
            help="Output format (table or json)",
        ),
    ] = "table",
    output: Annotated[
        Path | None,
        typer.Option("--output", "-o", help="Save output to file"),
    ] = None,
) -> None:
    """Display a drift report from a JSON file.

    Example:
        driftwatch report drift_report.json
        driftwatch report drift_report.json --format json
        driftwatch report drift_report.json --format table --output report.txt
    """
    # NOTE: the docstring above is surfaced by Typer as the command's help
    # text, so it is kept short and user-oriented. Behaviour notes:
    #   * ``input_file`` is parsed as UTF-8 JSON (json.JSONDecodeError
    #     propagates if the file is not valid JSON).
    #   * ``format`` is matched case-insensitively; an unrecognised value
    #     prints a warning and falls back to the table view.
    #   * ``output`` is only honoured for the JSON format; for the table
    #     format a "not yet implemented" warning is printed instead.
    data = json.loads(input_file.read_text(encoding="utf-8"))

    # Normalise so "--format JSON" selects JSON instead of silently
    # falling through to the table branch.
    requested_format = format.strip().lower()
    if requested_format not in ("table", "json"):
        # markup=False: the user-supplied value must not be parsed as Rich markup.
        console.print(
            f"Warning: Unknown format '{format}'. Falling back to table.",
            style="yellow",
            markup=False,
        )
        requested_format = "table"

    if requested_format == "json":
        output_str = json.dumps(data, indent=2)
        if output:
            output.write_text(output_str, encoding="utf-8")
            console.print(f"✓ Report saved to [cyan]{output}[/cyan]")
        else:
            # Emit the JSON verbatim: report strings may contain "[...]" or
            # ":name:" sequences that Rich would otherwise treat as markup or
            # emoji codes, and hard-wrapping at console width would split
            # long lines of the document.
            console.print(output_str, markup=False, emoji=False, soft_wrap=True)
        return

    # Table format: the report is rendered straight from the parsed dict
    # (no DriftReport reconstruction is attempted).
    _display_dict_report(data)
    if output:
        # Capturing Rich output into a file is not supported yet.
        console.print(
            "[yellow]Warning: Table output to file not yet implemented. Use --format json[/yellow]"
        )
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _display_report(report: DriftReport) -> None:
    """Print a DriftReport to the shared console.

    Output consists of a colour-coded status line, a drift summary (count and
    ratio of drifted features, or a "no drift" message) and a table with one
    row per entry in ``report.feature_results``.

    Args:
        report: The drift report to render.
    """
    # Colour per overall status; anything unmapped renders in white.
    palette = {
        DriftStatus.OK: "green",
        DriftStatus.WARNING: "yellow",
        DriftStatus.CRITICAL: "red",
    }
    color = palette.get(report.status, "white")
    console.print(f"[bold {color}]Status: {report.status.value}[/bold {color}]")

    if not report.has_drift():
        console.print("[green]No drift detected ✓[/green]\n")
    else:
        drifted_count = len(report.drifted_features())
        total_count = len(report.feature_results)
        console.print(
            f"[{color}]Drift Detected: {drifted_count}/{total_count} features[/{color}]"
        )
        console.print(f"[{color}]Drift Ratio: {report.drift_ratio():.1%}[/{color}]\n")

    # Per-feature breakdown.
    console.print("[bold]Feature Analysis:[/bold]\n")

    table = Table(show_header=True, header_style="bold cyan")
    column_specs = (
        ("Feature", {"style": "cyan"}),
        ("Method", {}),
        ("Score", {"justify": "right"}),
        ("Threshold", {"justify": "right"}),
        ("Status", {}),
    )
    for header, options in column_specs:
        table.add_column(header, **options)

    for item in report.feature_results:
        if item.has_drift:
            label, tint = "⚠️ DRIFT", "red"
        else:
            label, tint = "✓ OK", "green"

        table.add_row(
            item.feature_name,
            item.method,
            f"{item.score:.4f}",
            f"{item.threshold:.4f}",
            f"[{tint}]{label}[/{tint}]",
        )

    console.print(table)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def _display_dict_report(data: dict) -> None:
    """Print a drift report that is held as a plain dictionary.

    Prints the ``status`` entry (``UNKNOWN`` when absent) and, when
    ``feature_results`` is present and non-empty, a table with one row per
    result. Each result must provide ``feature_name``, ``method``, ``score``
    and ``threshold``; ``has_drift`` defaults to ``False`` when missing.

    Args:
        data: Parsed report dictionary.
    """
    console.print(f"[bold]Status:[/bold] {data.get('status', 'UNKNOWN')}")

    # Nothing more to show without per-feature results.
    if not data.get("feature_results"):
        return

    table = Table(show_header=True, header_style="bold cyan")
    column_specs = (
        ("Feature", {"style": "cyan"}),
        ("Method", {}),
        ("Score", {"justify": "right"}),
        ("Threshold", {"justify": "right"}),
        ("Status", {}),
    )
    for header, options in column_specs:
        table.add_column(header, **options)

    for entry in data["feature_results"]:
        if entry.get("has_drift", False):
            label, tint = "⚠️ DRIFT", "red"
        else:
            label, tint = "✓ OK", "green"

        table.add_row(
            entry["feature_name"],
            entry["method"],
            f"{entry['score']:.4f}",
            f"{entry['threshold']:.4f}",
            f"[{tint}]{label}[/{tint}]",
        )

    console.print(table)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
if __name__ == "__main__":
|
|
274
|
+
app()
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Monitor class for detecting drift between reference and production data.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import TYPE_CHECKING, Any, ClassVar
|
|
8
|
+
|
|
9
|
+
from driftwatch.core.report import DriftReport, FeatureDriftResult
|
|
10
|
+
from driftwatch.detectors import get_detector
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
import pandas as pd
|
|
14
|
+
|
|
15
|
+
from driftwatch.detectors.base import BaseDetector
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Monitor:
    """
    Main class for monitoring data and model drift.

    The Monitor compares production data against a reference dataset
    (typically training data) to detect distribution shifts. One detector is
    created per monitored feature, chosen by ``get_detector`` from the dtype
    of the feature's reference column.

    Args:
        reference_data: Reference DataFrame (training data). Must not be empty.
        features: List of feature columns to monitor.
            If None (or empty), all columns are monitored. The list is copied,
            so later ``add_feature`` / ``remove_feature`` calls never modify
            the caller's object.
        model: Optional ML model for prediction drift detection. It is only
            stored on the instance; no method of this class reads it.
        thresholds: Dictionary of threshold values for drift detection,
            merged over ``DEFAULT_THRESHOLDS``.
            Supported keys: "psi", "ks_pvalue", "wasserstein", "chi2_pvalue"

    Raises:
        ValueError: If the reference data is empty or a requested feature is
            not one of its columns.

    Example:
        >>> monitor = Monitor(
        ...     reference_data=train_df,
        ...     features=["age", "income", "category"],
        ...     thresholds={"psi": 0.2, "ks_pvalue": 0.05}
        ... )
        >>> report = monitor.check(production_df)
        >>> print(report.has_drift())
    """

    DEFAULT_THRESHOLDS: ClassVar[dict[str, float]] = {
        "psi": 0.2,
        "ks_pvalue": 0.05,
        "wasserstein": 0.1,
        "chi2_pvalue": 0.05,
    }

    def __init__(
        self,
        reference_data: pd.DataFrame,
        features: list[str] | None = None,
        model: Any | None = None,
        thresholds: dict[str, float] | None = None,
    ) -> None:
        self._validate_reference_data(reference_data)

        self.reference_data = reference_data
        self.features = self._resolve_features(reference_data, features)
        self.model = model
        # User-supplied thresholds override the class defaults key by key.
        self.thresholds = {**self.DEFAULT_THRESHOLDS, **(thresholds or {})}

        self._detectors: dict[str, BaseDetector] = {}
        self._setup_detectors()

    @staticmethod
    def _resolve_features(
        reference_data: pd.DataFrame,
        features: list[str] | None,
    ) -> list[str]:
        """Return a private list of features to monitor.

        A copy of ``features`` is returned so the monitor never aliases (and
        later mutates) the caller's list. None or an empty selection means
        "all reference columns".
        """
        requested = list(features) if features is not None else []
        return requested or list(reference_data.columns)

    def _validate_reference_data(self, data: pd.DataFrame) -> None:
        """Validate reference data is not empty."""
        if data.empty:
            raise ValueError("Reference data cannot be empty")

    def _setup_detectors(self) -> None:
        """Initialize detectors for each feature based on dtype.

        Raises:
            ValueError: If a monitored feature is not a reference column.
        """
        for feature in self.features:
            if feature not in self.reference_data.columns:
                raise ValueError(f"Feature '{feature}' not found in reference data")

            dtype = self.reference_data[feature].dtype
            self._detectors[feature] = get_detector(dtype, self.thresholds)

    def check(self, production_data: pd.DataFrame) -> DriftReport:
        """
        Check for drift between reference and production data.

        Args:
            production_data: Production DataFrame to compare

        Returns:
            DriftReport containing per-feature and aggregate drift results

        Raises:
            ValueError: If production data is empty or missing features
        """
        self._validate_production_data(production_data)

        feature_results: list[FeatureDriftResult] = []

        for feature in self.features:
            detector = self._detectors[feature]
            result = detector.detect(
                self.reference_data[feature],
                production_data[feature],
            )

            # Attach the feature name to the detector's per-series result.
            feature_results.append(
                FeatureDriftResult(
                    feature_name=feature,
                    has_drift=result.has_drift,
                    score=result.score,
                    method=result.method,
                    threshold=result.threshold,
                    p_value=result.p_value,
                )
            )

        return DriftReport(
            feature_results=feature_results,
            reference_size=len(self.reference_data),
            production_size=len(production_data),
        )

    def _validate_production_data(self, data: pd.DataFrame) -> None:
        """Validate production data is non-empty and has required features.

        Raises:
            ValueError: If ``data`` is empty or lacks a monitored feature.
        """
        if data.empty:
            raise ValueError("Production data cannot be empty")

        missing = set(self.features) - set(data.columns)
        if missing:
            # Sort (by string form, so mixed-type labels cannot fail) to keep
            # the message deterministic; raw set order depends on hashing.
            missing_sorted = sorted(missing, key=str)
            raise ValueError(f"Missing features in production data: {missing_sorted}")

    def add_feature(self, feature: str) -> None:
        """Add a feature to monitor.

        Does nothing when the feature is already monitored.

        Raises:
            ValueError: If the feature is not a reference column.
        """
        if feature in self.features:
            return

        if feature not in self.reference_data.columns:
            raise ValueError(f"Feature '{feature}' not found in reference data")

        self.features.append(feature)
        dtype = self.reference_data[feature].dtype
        self._detectors[feature] = get_detector(dtype, self.thresholds)

    def remove_feature(self, feature: str) -> None:
        """Remove a feature from monitoring (no-op if it is not monitored)."""
        if feature in self.features:
            self.features.remove(feature)
            del self._detectors[feature]

    @property
    def monitored_features(self) -> list[str]:
        """Return a copy of the list of monitored features."""
        return self.features.copy()
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DriftReport class for structured drift detection results.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from datetime import datetime, timezone
|
|
10
|
+
from enum import Enum
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class DriftStatus(str, Enum):
    """Severity levels for a report's overall drift status.

    Members are also ``str`` instances, so they compare equal to their
    plain-string values.
    """

    # No drift.
    OK = "OK"
    # Some drift.
    WARNING = "WARNING"
    # Widespread drift.
    CRITICAL = "CRITICAL"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class FeatureDriftResult:
    """Outcome of a drift test on one feature."""

    # Name of the feature the result refers to.
    feature_name: str
    # Whether drift was flagged for this feature.
    has_drift: bool
    # Score reported by the detection method.
    score: float
    # Name of the detection method.
    method: str
    # Threshold the result was produced with.
    threshold: float
    # Optional p-value; None when not provided.
    p_value: float | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return the result's fields as a plain dictionary, in field order."""
        field_names = (
            "feature_name",
            "has_drift",
            "score",
            "method",
            "threshold",
            "p_value",
        )
        return {name: getattr(self, name) for name in field_names}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class DriftReport:
    """
    Aggregated outcome of a drift check.

    Holds the per-feature results and derives aggregate figures
    (drift ratio, overall status) from them on demand.

    Attributes:
        feature_results: List of per-feature drift results
        reference_size: Number of samples in reference data
        production_size: Number of samples in production data
        timestamp: When the report was created (defaults to the current
            UTC time)
        model_version: Optional model version identifier
    """

    feature_results: list[FeatureDriftResult]
    reference_size: int
    production_size: int
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    model_version: str | None = None

    def has_drift(self) -> bool:
        """Return True when at least one feature result is flagged as drifted."""
        for result in self.feature_results:
            if result.has_drift:
                return True
        return False

    def drifted_features(self) -> list[str]:
        """Return the names of the drifted features, in result order."""
        names = []
        for result in self.feature_results:
            if result.has_drift:
                names.append(result.feature_name)
        return names

    def drift_ratio(self) -> float:
        """Return drifted features divided by total features (0.0 if none)."""
        total = len(self.feature_results)
        if total == 0:
            return 0.0
        return len(self.drifted_features()) / total

    @property
    def status(self) -> DriftStatus:
        """
        Overall drift status derived from the drift ratio.

        - OK: No drift detected
        - WARNING: <50% of features have drift
        - CRITICAL: >=50% of features have drift
        """
        ratio = self.drift_ratio()
        if ratio == 0:
            return DriftStatus.OK
        return DriftStatus.WARNING if ratio < 0.5 else DriftStatus.CRITICAL

    def feature_drift(self, feature_name: str) -> FeatureDriftResult | None:
        """Return the first result for ``feature_name``, or None if absent."""
        matches = (r for r in self.feature_results if r.feature_name == feature_name)
        return next(matches, None)

    def summary(self) -> str:
        """
        Build a multi-line, human-readable summary of the report.

        Returns:
            Formatted string summary
        """
        rule = "=" * 50
        drifted = [r for r in self.feature_results if r.has_drift]

        lines = [
            rule,
            "DRIFT REPORT",
            rule,
            f"Status: {self.status.value}",
            f"Timestamp: {self.timestamp.isoformat()}",
            f"Reference samples: {self.reference_size:,}",
            f"Production samples: {self.production_size:,}",
            "",
            f"Features analyzed: {len(self.feature_results)}",
            f"Features with drift: {len(drifted)}",
            f"Drift ratio: {self.drift_ratio():.1%}",
            "",
        ]

        # One bullet per drifted feature, only when there is any drift.
        if drifted:
            lines.append("Drifted features:")
            lines.extend(
                f" - {result.feature_name}: "
                f"{result.method}={result.score:.4f} "
                f"(threshold={result.threshold})"
                for result in drifted
            )

        lines.append(rule)
        return "\n".join(lines)

    def to_dict(self) -> dict[str, Any]:
        """Return the report, including derived aggregates, as a dictionary."""
        serialized_results = [result.to_dict() for result in self.feature_results]
        return {
            "status": self.status.value,
            "timestamp": self.timestamp.isoformat(),
            "reference_size": self.reference_size,
            "production_size": self.production_size,
            "model_version": self.model_version,
            "has_drift": self.has_drift(),
            "drift_ratio": self.drift_ratio(),
            "drifted_features": self.drifted_features(),
            "feature_results": serialized_results,
        }

    def to_json(self, indent: int = 2) -> str:
        """Serialize the report to JSON; unsupported values go through ``str``."""
        payload = self.to_dict()
        return json.dumps(payload, indent=indent, default=str)

    def __repr__(self) -> str:
        status_value = self.status.value
        total = len(self.feature_results)
        drifted = len(self.drifted_features())
        return (
            f"DriftReport(status={status_value}, "
            f"features={total}, "
            f"drifted={drifted})"
        )
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Base class for drift detectors."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
import pandas as pd
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class DetectionResult:
    """Value object returned by a detector's ``detect`` call."""

    # Whether the detector flagged drift.
    has_drift: bool
    # Score reported by the detector.
    score: float
    # Name of the detection method.
    method: str
    # Threshold associated with the result.
    threshold: float
    # Optional p-value; None when not provided.
    p_value: float | None = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class BaseDetector(ABC):
    """
    Common interface for drift detectors.

    Concrete detectors subclass this and provide ``detect``; the class
    cannot be instantiated until that method is implemented.

    Args:
        threshold: Threshold value for determining drift
        name: Human-readable name for the detector
    """

    def __init__(self, threshold: float, name: str) -> None:
        self.name = name
        self.threshold = threshold

    @abstractmethod
    def detect(
        self,
        reference: pd.Series,
        production: pd.Series,
    ) -> DetectionResult:
        """
        Compare a production series against its reference series.

        Args:
            reference: Reference data series
            production: Production data series

        Returns:
            DetectionResult with drift status and metrics
        """
        ...

    def _validate_inputs(
        self,
        reference: pd.Series,
        production: pd.Series,
    ) -> None:
        """Raise ValueError if either series is empty (reference is checked first)."""
        checks = (
            (reference, "Reference series cannot be empty"),
            (production, "Production series cannot be empty"),
        )
        for series, message in checks:
            if series.empty:
                raise ValueError(message)
|