driftwatch 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
driftwatch/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ """
2
+ DriftWatch - Lightweight ML drift monitoring, built for real-world pipelines.
3
+
4
+ DriftWatch is an open-source Python library for detecting data drift and
5
+ model drift in machine learning systems.
6
+
7
+ Basic Usage:
8
+ >>> from driftwatch import Monitor
9
+ >>> monitor = Monitor(reference_data=train_df, features=["age", "income"])
10
+ >>> report = monitor.check(production_df)
11
+ >>> print(report.summary())
12
+ """
13
+
14
+ from driftwatch.core.monitor import Monitor
15
+ from driftwatch.core.report import DriftReport
16
+
17
+ __version__ = "0.2.0"
18
+ __all__ = [
19
+ "DriftReport",
20
+ "Monitor",
21
+ "__version__",
22
+ ]
@@ -0,0 +1,5 @@
1
+ """Command-line interface for DriftWatch."""
2
+
3
+ from driftwatch.cli.main import app
4
+
5
+ __all__ = ["app"]
driftwatch/cli/main.py ADDED
@@ -0,0 +1,274 @@
1
+ """Main CLI application for DriftWatch.
2
+
3
+ Provides commands for drift checking and reporting.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ from pathlib import Path # noqa: TC003
10
+ from typing import Annotated
11
+
12
+ import pandas as pd
13
+ import typer
14
+ from rich.console import Console
15
+ from rich.table import Table
16
+
17
+ from driftwatch import Monitor
18
+ from driftwatch.core.report import DriftReport, DriftStatus
19
+
20
# Typer application object. The functions decorated with ``@app.command()``
# further down this module are registered on it as sub-commands.
app = typer.Typer(
    name="driftwatch",
    help="🔍 DriftWatch - ML Model Drift Detection",
    add_completion=False,
)
# Shared Rich console: every command and display helper in this module
# prints through this single instance.
console = Console()
26
+
27
+
28
def load_dataframe(path: Path) -> pd.DataFrame:
    """Read a tabular dataset, choosing the pandas reader from the file suffix.

    The suffix is compared case-insensitively, so ``data.CSV`` is read the
    same way as ``data.csv``.

    Args:
        path: Location of a ``.csv``, ``.parquet`` or ``.pq`` file.

    Returns:
        The file contents as a pandas DataFrame.

    Raises:
        typer.BadParameter: If the suffix is none of the supported ones.
    """
    readers = {
        ".csv": pd.read_csv,
        ".parquet": pd.read_parquet,
        ".pq": pd.read_parquet,
    }
    reader = readers.get(path.suffix.lower())
    if reader is None:
        # The message echoes the suffix exactly as the user typed it.
        raise typer.BadParameter(
            f"Unsupported file format: {path.suffix}. Use .csv or .parquet"
        )
    return reader(path)
48
+
49
+
50
@app.command()
def check(
    ref: Annotated[
        Path,
        typer.Option(
            "--ref",
            "-r",
            help="Path to reference dataset (CSV or Parquet)",
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ],
    prod: Annotated[
        Path,
        typer.Option(
            "--prod",
            "-p",
            help="Path to production dataset (CSV or Parquet)",
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ],
    threshold_psi: Annotated[
        float, typer.Option("--threshold-psi", help="PSI threshold (default: 0.2)")
    ] = 0.2,
    threshold_ks: Annotated[
        float,
        typer.Option("--threshold-ks", help="KS p-value threshold (default: 0.05)"),
    ] = 0.05,
    threshold_chi2: Annotated[
        float,
        typer.Option(
            "--threshold-chi2", help="Chi-squared p-value threshold (default: 0.05)"
        ),
    ] = 0.05,
    output: Annotated[
        Path | None,
        typer.Option("--output", "-o", help="Save report to JSON file"),
    ] = None,
) -> None:
    """Check for drift between reference and production datasets.

    Example:
        driftwatch check --ref train.csv --prod prod.csv
        driftwatch check -r train.parquet -p prod.parquet --threshold-psi 0.15
    """
    # The docstring above is left verbatim: Typer surfaces it as the command's
    # help text, so it is user-facing output rather than internal commentary.
    console.print("[bold blue]🔍 DriftWatch - Drift Detection[/bold blue]\n")

    # Read the reference dataset and echo its size.
    console.print(f"Loading reference data from [cyan]{ref}[/cyan]...")
    reference_frame = load_dataframe(ref)
    ref_rows, ref_cols = len(reference_frame), len(reference_frame.columns)
    console.print(f"✓ Loaded {ref_rows:,} samples with {ref_cols} features\n")

    # Same for the production dataset.
    console.print(f"Loading production data from [cyan]{prod}[/cyan]...")
    production_frame = load_dataframe(prod)
    prod_rows, prod_cols = len(production_frame), len(production_frame.columns)
    console.print(f"✓ Loaded {prod_rows:,} samples with {prod_cols} features\n")

    # Map the CLI options onto the threshold keys the Monitor understands.
    console.print("Initializing monitor...")
    cli_thresholds = {
        "psi": threshold_psi,
        "ks_pvalue": threshold_ks,
        "chi2_pvalue": threshold_chi2,
    }
    monitor = Monitor(reference_data=reference_frame, thresholds=cli_thresholds)

    console.print("Running drift detection...\n")
    drift_report = monitor.check(production_frame)

    _display_report(drift_report)

    # Optionally persist the report as pretty-printed JSON.
    if output:
        serialized = json.dumps(drift_report.to_dict(), indent=2)
        output.write_text(serialized, encoding="utf-8")
        console.print(f"\n✓ Report saved to [cyan]{output}[/cyan]")

    # The process exit code mirrors the overall status so scripts can branch
    # on it: CRITICAL -> 2, WARNING -> 1, anything else -> 0.
    exit_codes = {DriftStatus.CRITICAL: 2, DriftStatus.WARNING: 1}
    raise typer.Exit(code=exit_codes.get(drift_report.status, 0))
145
+
146
+
147
@app.command()
def report(
    input_file: Annotated[
        Path,
        typer.Argument(
            help="Path to drift report JSON file",
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ],
    format: Annotated[
        str,
        typer.Option(
            "--format",
            "-f",
            help="Output format (table or json)",
        ),
    ] = "table",
    output: Annotated[
        Path | None,
        typer.Option("--output", "-o", help="Save output to file"),
    ] = None,
) -> None:
    """Display a drift report from a JSON file.

    Example:
        driftwatch report drift_report.json
        driftwatch report drift_report.json --format json
        driftwatch report drift_report.json --format table --output report.txt
    """
    # The docstring above is left verbatim: Typer surfaces it as the command's
    # help text, so it is user-facing output rather than internal commentary.
    #
    # Raises typer.BadParameter when the input is not decodable JSON, or when
    # the table view is requested for a document that is not a JSON object.

    # Turn unreadable input into a CLI error (the same style load_dataframe
    # uses) instead of letting a raw traceback reach the user.
    try:
        data = json.loads(input_file.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        raise typer.BadParameter(
            f"{input_file} is not a valid JSON report: {exc}"
        ) from exc

    if format == "json":
        output_str = json.dumps(data, indent=2)
        if output is not None:
            output.write_text(output_str, encoding="utf-8")
            console.print(f"✓ Report saved to [cyan]{output}[/cyan]")
        else:
            # Emit the JSON text untouched: markup parsing would swallow
            # bracketed fragments such as "[log]" inside strings, and word
            # wrapping would insert line breaks into long values.
            console.print(output_str, markup=False, soft_wrap=True)
        return

    # Any value other than "json" falls through to the table view.
    if not isinstance(data, dict):
        # The table renderer reads keys with .get(), so it needs a JSON object.
        raise typer.BadParameter(
            f"{input_file} does not contain a JSON object, cannot render a table"
        )

    _display_dict_report(data)
    if output is not None:
        # Capturing the rendered table into a file is not implemented yet.
        console.print(
            "[yellow]Warning: Table output to file not yet implemented. Use --format json[/yellow]"
        )
199
+
200
+
201
def _display_report(report: DriftReport) -> None:
    """Print a DriftReport to the console: status banner, then a feature table.

    The banner is coloured by overall status (green / yellow / red, white for
    anything unrecognised). The table has one row per entry in
    ``report.feature_results``.
    """
    palette = {
        DriftStatus.OK: "green",
        DriftStatus.WARNING: "yellow",
        DriftStatus.CRITICAL: "red",
    }
    color = palette.get(report.status, "white")
    console.print(f"[bold {color}]Status: {report.status.value}[/bold {color}]")

    if not report.has_drift():
        console.print("[green]No drift detected ✓[/green]\n")
    else:
        drifted_count = len(report.drifted_features())
        total_count = len(report.feature_results)
        console.print(
            f"[{color}]Drift Detected: {drifted_count}/{total_count} features[/{color}]"
        )
        console.print(f"[{color}]Drift Ratio: {report.drift_ratio():.1%}[/{color}]\n")

    console.print("[bold]Feature Analysis:[/bold]\n")

    # Column headers paired with their per-column display options.
    column_specs = (
        ("Feature", {"style": "cyan"}),
        ("Method", {}),
        ("Score", {"justify": "right"}),
        ("Threshold", {"justify": "right"}),
        ("Status", {}),
    )
    table = Table(show_header=True, header_style="bold cyan")
    for header, options in column_specs:
        table.add_column(header, **options)

    for item in report.feature_results:
        if item.has_drift:
            label, tint = "⚠️ DRIFT", "red"
        else:
            label, tint = "✓ OK", "green"

        table.add_row(
            item.feature_name,
            item.method,
            f"{item.score:.4f}",
            f"{item.threshold:.4f}",
            f"[{tint}]{label}[/{tint}]",
        )

    console.print(table)
243
+
244
+
245
def _display_dict_report(data: dict) -> None:
    """Print a report that was loaded back from its dictionary (JSON) form.

    Shows the ``status`` entry (``UNKNOWN`` when absent). If
    ``feature_results`` is present and non-empty, it also renders one table
    row per entry. Each entry must provide ``feature_name``, ``method``,
    ``score`` and ``threshold``; a missing ``has_drift`` is treated as False.
    """
    console.print(f"[bold]Status:[/bold] {data.get('status', 'UNKNOWN')}")

    entries = data.get("feature_results")
    if not entries:
        # Nothing per-feature to show.
        return

    table = Table(show_header=True, header_style="bold cyan")
    column_specs = (
        ("Feature", {"style": "cyan"}),
        ("Method", {}),
        ("Score", {"justify": "right"}),
        ("Threshold", {"justify": "right"}),
        ("Status", {}),
    )
    for header, options in column_specs:
        table.add_column(header, **options)

    for entry in entries:
        if entry.get("has_drift", False):
            label, tint = "⚠️ DRIFT", "red"
        else:
            label, tint = "✓ OK", "green"

        table.add_row(
            entry["feature_name"],
            entry["method"],
            f"{entry['score']:.4f}",
            f"{entry['threshold']:.4f}",
            f"[{tint}]{label}[/{tint}]",
        )

    console.print(table)
271
+
272
+
273
# Run the Typer application when this module is executed as a script.
if __name__ == "__main__":
    app()
@@ -0,0 +1,6 @@
1
+ """Core module containing the main Monitor and DriftReport classes."""
2
+
3
+ from driftwatch.core.monitor import Monitor
4
+ from driftwatch.core.report import DriftReport
5
+
6
+ __all__ = ["DriftReport", "Monitor"]
@@ -0,0 +1,153 @@
1
+ """
2
+ Monitor class for detecting drift between reference and production data.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import TYPE_CHECKING, Any, ClassVar
8
+
9
+ from driftwatch.core.report import DriftReport, FeatureDriftResult
10
+ from driftwatch.detectors import get_detector
11
+
12
+ if TYPE_CHECKING:
13
+ import pandas as pd
14
+
15
+ from driftwatch.detectors.base import BaseDetector
16
+
17
+
18
class Monitor:
    """
    Main class for monitoring data and model drift.

    The Monitor compares production data against a reference dataset
    (typically training data) to detect distribution shifts. One detector is
    created per monitored feature, chosen from the dtype of that feature's
    column in the reference data.

    Args:
        reference_data: Reference DataFrame (training data). Must not be empty.
        features: Feature columns to monitor. If None (or empty), all columns
            of the reference data are monitored. The sequence is copied, so
            later ``add_feature`` / ``remove_feature`` calls never modify the
            caller's object.
        model: Optional ML model. It is stored on the instance; no method of
            this class reads it.
        thresholds: Dictionary of threshold values for drift detection,
            merged over ``DEFAULT_THRESHOLDS``.
            Supported keys: "psi", "ks_pvalue", "wasserstein", "chi2_pvalue"

    Raises:
        ValueError: If the reference data is empty or a requested feature is
            not one of its columns.

    Example:
        >>> monitor = Monitor(
        ...     reference_data=train_df,
        ...     features=["age", "income", "category"],
        ...     thresholds={"psi": 0.2, "ks_pvalue": 0.05}
        ... )
        >>> report = monitor.check(production_df)
        >>> print(report.has_drift())
    """

    DEFAULT_THRESHOLDS: ClassVar[dict[str, float]] = {
        "psi": 0.2,
        "ks_pvalue": 0.05,
        "wasserstein": 0.1,
        "chi2_pvalue": 0.05,
    }

    def __init__(
        self,
        reference_data: pd.DataFrame,
        features: list[str] | None = None,
        model: Any | None = None,
        thresholds: dict[str, float] | None = None,
    ) -> None:
        self._validate_reference_data(reference_data)

        self.reference_data = reference_data
        # Take a private copy: self.features is mutated in place by
        # add_feature/remove_feature and must not alias the caller's list.
        requested = list(features) if features is not None else []
        self.features = requested or list(reference_data.columns)
        self.model = model
        # Caller-supplied thresholds win over the class defaults.
        self.thresholds = {**self.DEFAULT_THRESHOLDS, **(thresholds or {})}

        self._detectors: dict[str, BaseDetector] = {}
        self._setup_detectors()

    def _validate_reference_data(self, data: pd.DataFrame) -> None:
        """Validate reference data is not empty."""
        if data.empty:
            raise ValueError("Reference data cannot be empty")

    def _build_detector(self, feature: str) -> BaseDetector:
        """Return a detector for one reference column, validating it exists."""
        if feature not in self.reference_data.columns:
            raise ValueError(f"Feature '{feature}' not found in reference data")

        dtype = self.reference_data[feature].dtype
        return get_detector(dtype, self.thresholds)

    def _setup_detectors(self) -> None:
        """Initialize detectors for each feature based on dtype."""
        for feature in self.features:
            self._detectors[feature] = self._build_detector(feature)

    def check(self, production_data: pd.DataFrame) -> DriftReport:
        """
        Check for drift between reference and production data.

        Args:
            production_data: Production DataFrame to compare

        Returns:
            DriftReport containing per-feature and aggregate drift results

        Raises:
            ValueError: If production data is empty or missing features
        """
        self._validate_production_data(production_data)

        feature_results: list[FeatureDriftResult] = []

        for feature in self.features:
            detector = self._detectors[feature]
            result = detector.detect(
                self.reference_data[feature],
                production_data[feature],
            )

            # Re-wrap the detector output with the feature name attached.
            feature_results.append(
                FeatureDriftResult(
                    feature_name=feature,
                    has_drift=result.has_drift,
                    score=result.score,
                    method=result.method,
                    threshold=result.threshold,
                    p_value=result.p_value,
                )
            )

        return DriftReport(
            feature_results=feature_results,
            reference_size=len(self.reference_data),
            production_size=len(production_data),
        )

    def _validate_production_data(self, data: pd.DataFrame) -> None:
        """Validate production data is non-empty and has every monitored feature."""
        if data.empty:
            raise ValueError("Production data cannot be empty")

        missing = set(self.features) - set(data.columns)
        if missing:
            raise ValueError(f"Missing features in production data: {missing}")

    def add_feature(self, feature: str) -> None:
        """Add a feature to monitor.

        Adding an already-monitored feature is a no-op.

        Raises:
            ValueError: If the feature is not a column of the reference data.
        """
        if feature in self.features:
            return

        # Build the detector first so a failure leaves the monitor untouched
        # (no feature registered without a matching detector).
        detector = self._build_detector(feature)
        self.features.append(feature)
        self._detectors[feature] = detector

    def remove_feature(self, feature: str) -> None:
        """Remove a feature from monitoring; unknown features are ignored."""
        if feature in self.features:
            self.features.remove(feature)
            self._detectors.pop(feature, None)

    @property
    def monitored_features(self) -> list[str]:
        """Return a copy of the list of monitored features."""
        return self.features.copy()
@@ -0,0 +1,162 @@
1
+ """
2
+ DriftReport class for structured drift detection results.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ from dataclasses import dataclass, field
9
+ from datetime import datetime, timezone
10
+ from enum import Enum
11
+ from typing import Any
12
+
13
+
14
class DriftStatus(str, Enum):
    """Overall severity of a drift report.

    Members are also ``str`` instances, so each compares equal to its own
    value (``DriftStatus.OK == "OK"``).
    """

    # Members are listed from least to most severe.
    OK = "OK"
    WARNING = "WARNING"
    CRITICAL = "CRITICAL"
20
+
21
+
22
@dataclass
class FeatureDriftResult:
    """Result of drift detection for a single feature."""

    # Name of the feature (column) the result refers to.
    feature_name: str
    # Whether drift was flagged for this feature.
    has_drift: bool
    # Score reported for the feature.
    score: float
    # Name of the detection method that produced the score.
    method: str
    # Threshold reported alongside the score.
    threshold: float
    # Optional p-value; None when not provided.
    p_value: float | None = None

    @staticmethod
    def _as_builtin(value: Any) -> Any:
        """Unwrap scalar wrappers exposing ``.item()`` (e.g. NumPy scalars).

        Values without a callable ``item`` attribute, including plain Python
        numbers, bools and None, are returned unchanged.
        """
        item = getattr(value, "item", None)
        return item() if callable(item) else value

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary of plain Python values.

        Numeric and boolean fields are unwrapped to builtins so the result
        can be passed to ``json.dumps`` even when the stored values are NumPy
        scalars such as ``numpy.bool_``.
        """
        unwrap = self._as_builtin
        return {
            "feature_name": self.feature_name,
            "has_drift": unwrap(self.has_drift),
            "score": unwrap(self.score),
            "method": self.method,
            "threshold": unwrap(self.threshold),
            "p_value": unwrap(self.p_value),
        }
43
+
44
+
45
@dataclass
class DriftReport:
    """
    Aggregated outcome of one drift check.

    Holds the per-feature results plus the sizes of the two datasets, and
    derives the aggregate figures (drift ratio, overall status) on demand.

    Attributes:
        feature_results: List of per-feature drift results
        reference_size: Number of samples in reference data
        production_size: Number of samples in production data
        timestamp: When the report was created (UTC, timezone-aware by default)
        model_version: Optional model version identifier
    """

    feature_results: list[FeatureDriftResult]
    reference_size: int
    production_size: int
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    model_version: str | None = None

    def has_drift(self) -> bool:
        """Return True when at least one feature is flagged as drifted."""
        for entry in self.feature_results:
            if entry.has_drift:
                return True
        return False

    def drifted_features(self) -> list[str]:
        """Return the names of the drifted features, in result order."""
        names = []
        for entry in self.feature_results:
            if entry.has_drift:
                names.append(entry.feature_name)
        return names

    def drift_ratio(self) -> float:
        """Return drifted / total features; 0.0 when there are no results."""
        total = len(self.feature_results)
        if total == 0:
            return 0.0
        return len(self.drifted_features()) / total

    @property
    def status(self) -> DriftStatus:
        """
        Overall status derived from the drift ratio.

        - OK: No drift detected
        - WARNING: <50% of features have drift
        - CRITICAL: >=50% of features have drift
        """
        ratio = self.drift_ratio()
        if ratio == 0:
            return DriftStatus.OK
        if ratio < 0.5:
            return DriftStatus.WARNING
        return DriftStatus.CRITICAL

    def feature_drift(self, feature_name: str) -> FeatureDriftResult | None:
        """Return the first result for ``feature_name``, or None if absent."""
        matches = (r for r in self.feature_results if r.feature_name == feature_name)
        return next(matches, None)

    def summary(self) -> str:
        """
        Build a plain-text, multi-line summary of the report.

        Returns:
            Formatted string summary
        """
        rule = "=" * 50
        drifted = [r for r in self.feature_results if r.has_drift]

        lines = [
            rule,
            "DRIFT REPORT",
            rule,
            f"Status: {self.status.value}",
            f"Timestamp: {self.timestamp.isoformat()}",
            f"Reference samples: {self.reference_size:,}",
            f"Production samples: {self.production_size:,}",
            "",
            f"Features analyzed: {len(self.feature_results)}",
            f"Features with drift: {len(drifted)}",
            f"Drift ratio: {self.drift_ratio():.1%}",
            "",
        ]

        # Per-feature detail is listed only when something actually drifted.
        if drifted:
            lines.append("Drifted features:")
            lines.extend(
                f" - {r.feature_name}: {r.method}={r.score:.4f} (threshold={r.threshold})"
                for r in drifted
            )

        lines.append(rule)
        return "\n".join(lines)

    def to_dict(self) -> dict[str, Any]:
        """Return the report, including derived figures, as a dictionary."""
        per_feature = [entry.to_dict() for entry in self.feature_results]
        return {
            "status": self.status.value,
            "timestamp": self.timestamp.isoformat(),
            "reference_size": self.reference_size,
            "production_size": self.production_size,
            "model_version": self.model_version,
            "has_drift": self.has_drift(),
            "drift_ratio": self.drift_ratio(),
            "drifted_features": self.drifted_features(),
            "feature_results": per_feature,
        }

    def to_json(self, indent: int = 2) -> str:
        """Serialize ``to_dict()`` to JSON text.

        Values the JSON encoder cannot handle natively are converted with
        ``str`` (``default=str``).
        """
        payload = self.to_dict()
        return json.dumps(payload, indent=indent, default=str)

    def __repr__(self) -> str:
        total = len(self.feature_results)
        drifted = len(self.drifted_features())
        return (
            f"DriftReport(status={self.status.value}, "
            f"features={total}, "
            f"drifted={drifted})"
        )
@@ -0,0 +1,6 @@
1
+ """Drift detectors module."""
2
+
3
+ from driftwatch.detectors.base import BaseDetector, DetectionResult
4
+ from driftwatch.detectors.registry import get_detector
5
+
6
+ __all__ = ["BaseDetector", "DetectionResult", "get_detector"]
@@ -0,0 +1,67 @@
1
+ """Base class for drift detectors."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING
8
+
9
+ if TYPE_CHECKING:
10
+ import pandas as pd
11
+
12
+
13
@dataclass
class DetectionResult:
    """Outcome of a single detector run on one pair of series."""

    # Whether the detector flagged drift.
    has_drift: bool
    # Score produced by the detector.
    score: float
    # Name of the detection method.
    method: str
    # Threshold reported with the score.
    threshold: float
    # Optional p-value; None when not provided.
    p_value: float | None = None
22
+
23
+
24
class BaseDetector(ABC):
    """
    Common interface for all drift detectors.

    Concrete detectors subclass this and provide `detect`; the class cannot
    be instantiated until that method is implemented.

    Args:
        threshold: Threshold value for determining drift
        name: Human-readable name for the detector
    """

    def __init__(self, threshold: float, name: str) -> None:
        self.name = name
        self.threshold = threshold

    @abstractmethod
    def detect(
        self,
        reference: pd.Series,
        production: pd.Series,
    ) -> DetectionResult:
        """
        Compare a production series against its reference series.

        Args:
            reference: Reference data series
            production: Production data series

        Returns:
            DetectionResult with drift status and metrics
        """
        ...

    def _validate_inputs(
        self,
        reference: pd.Series,
        production: pd.Series,
    ) -> None:
        """Raise ValueError if either series is empty (reference is checked first)."""
        checks = (
            (reference, "Reference series cannot be empty"),
            (production, "Production series cannot be empty"),
        )
        for series, message in checks:
            if series.empty:
                raise ValueError(message)