driftwatch-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,145 @@
+ """Categorical feature drift detectors."""
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ from scipy import stats
+
+ from driftwatch.detectors.base import BaseDetector, DetectionResult
+
+ if TYPE_CHECKING:
+     import pandas as pd
+
+
+ class ChiSquaredDetector(BaseDetector):
+     """
+     Chi-Squared test for categorical drift detection.
+
+     Tests whether the frequency distribution of categories
+     has changed between reference and production data.
+
+     Args:
+         threshold: P-value threshold below which drift is detected.
+             Default is 0.05 (95% confidence).
+
+     Example:
+         >>> detector = ChiSquaredDetector(threshold=0.05)
+         >>> result = detector.detect(reference_series, production_series)
+     """
+
+     def __init__(self, threshold: float = 0.05) -> None:
+         super().__init__(threshold=threshold, name="chi_squared")
+
+     def detect(
+         self,
+         reference: pd.Series,
+         production: pd.Series,
+     ) -> DetectionResult:
+         """
+         Perform Chi-Squared test on category frequencies.
+
+         Returns:
+             DetectionResult with chi-squared statistic and p-value
+         """
+         self._validate_inputs(reference, production)
+
+         # Get all categories from both datasets
+         all_categories = set(reference.dropna().unique()) | set(
+             production.dropna().unique()
+         )
+
+         # Count frequencies
+         ref_counts = reference.value_counts()
+         prod_counts = production.value_counts()
+
+         # Align to same categories
+         ref_freq = np.array([ref_counts.get(cat, 0) for cat in all_categories])
+         prod_freq = np.array([prod_counts.get(cat, 0) for cat in all_categories])
+
+         # Handle edge case of zero frequencies
+         if ref_freq.sum() == 0 or prod_freq.sum() == 0:
+             return DetectionResult(
+                 has_drift=True,
+                 score=float("inf"),
+                 method=self.name,
+                 threshold=self.threshold,
+                 p_value=0.0,
+             )
+
+         # Calculate expected frequencies based on reference proportions
+         ref_proportions = ref_freq / ref_freq.sum()
+         expected = ref_proportions * prod_freq.sum()
+
+         # Add small epsilon to avoid division by zero
+         expected = np.maximum(expected, 1e-10)
+
+         # Chi-squared statistic
+         statistic, p_value = stats.chisquare(prod_freq, f_exp=expected)
+
+         return DetectionResult(
+             has_drift=p_value < self.threshold,
+             score=float(statistic),
+             method=self.name,
+             threshold=self.threshold,
+             p_value=float(p_value),
+         )
+
+
+ class FrequencyPSIDetector(BaseDetector):
+     """
+     PSI-based detector for categorical features.
+
+     Calculates PSI using category frequency distributions
+     instead of numerical buckets.
+
+     Args:
+         threshold: PSI value above which drift is detected.
+             Default is 0.2.
+     """
+
+     def __init__(self, threshold: float = 0.2) -> None:
+         super().__init__(threshold=threshold, name="frequency_psi")
+
+     def detect(
+         self,
+         reference: pd.Series,
+         production: pd.Series,
+     ) -> DetectionResult:
+         """
+         Calculate PSI on category frequencies.
+
+         Returns:
+             DetectionResult with PSI score
+         """
+         self._validate_inputs(reference, production)
+
+         # Get normalized frequencies
+         ref_freq = reference.value_counts(normalize=True)
+         prod_freq = production.value_counts(normalize=True)
+
+         # Get all categories
+         all_categories = set(ref_freq.index) | set(prod_freq.index)
+
+         # Calculate PSI
+         eps = 1e-10
+         psi = 0.0
+
+         for cat in all_categories:
+             ref_pct = ref_freq.get(cat, eps)
+             prod_pct = prod_freq.get(cat, eps)
+
+             # Clip to avoid log(0)
+             ref_pct = max(ref_pct, eps)
+             prod_pct = max(prod_pct, eps)
+
+             psi += (prod_pct - ref_pct) * np.log(prod_pct / ref_pct)
+
+         return DetectionResult(
+             has_drift=psi >= self.threshold,
+             score=float(psi),
+             method=self.name,
+             threshold=self.threshold,
+             p_value=None,
+         )
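The hunk above is the categorical detectors module (imported elsewhere in this release as `driftwatch.detectors.categorical`): a frequency-based Chi-Squared test plus a PSI variant computed over category proportions. A minimal usage sketch, assuming the `BaseDetector`/`DetectionResult` plumbing in `driftwatch.detectors.base` (not part of this diff) accepts two pandas Series as the calls above imply:

```python
import pandas as pd

from driftwatch.detectors.categorical import ChiSquaredDetector, FrequencyPSIDetector

# Reference data dominated by "A"; production shifted toward "B"
reference = pd.Series(["A"] * 80 + ["B"] * 15 + ["C"] * 5)
production = pd.Series(["A"] * 40 + ["B"] * 50 + ["C"] * 10)

chi2 = ChiSquaredDetector(threshold=0.05).detect(reference, production)
psi = FrequencyPSIDetector(threshold=0.2).detect(reference, production)

# score is the chi-squared statistic / PSI value; has_drift applies the threshold
print(chi2.method, chi2.has_drift, round(chi2.score, 2), chi2.p_value)
print(psi.method, psi.has_drift, round(psi.score, 3))
```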
@@ -0,0 +1,198 @@
+ """Numerical feature drift detectors."""
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ from scipy import stats
+
+ from driftwatch.detectors.base import BaseDetector, DetectionResult
+
+ if TYPE_CHECKING:
+     import pandas as pd
+
+
+ class KSDetector(BaseDetector):
+     """
+     Kolmogorov-Smirnov test for numerical drift detection.
+
+     The KS test measures the maximum distance between the cumulative
+     distribution functions of two samples.
+
+     Args:
+         threshold: P-value threshold below which drift is detected.
+             Default is 0.05 (95% confidence).
+
+     Example:
+         >>> detector = KSDetector(threshold=0.05)
+         >>> result = detector.detect(reference_series, production_series)
+         >>> print(f"Drift detected: {result.has_drift}")
+     """
+
+     def __init__(self, threshold: float = 0.05) -> None:
+         super().__init__(threshold=threshold, name="ks_test")
+
+     def detect(
+         self,
+         reference: pd.Series,
+         production: pd.Series,
+     ) -> DetectionResult:
+         """
+         Perform KS test between reference and production distributions.
+
+         Returns:
+             DetectionResult with KS statistic as score and p-value
+         """
+         self._validate_inputs(reference, production)
+
+         statistic, p_value = stats.ks_2samp(
+             reference.dropna(),
+             production.dropna(),
+         )
+
+         return DetectionResult(
+             has_drift=p_value < self.threshold,
+             score=float(statistic),
+             method=self.name,
+             threshold=self.threshold,
+             p_value=float(p_value),
+         )
+
+
+ class PSIDetector(BaseDetector):
+     """
+     Population Stability Index (PSI) for numerical drift detection.
+
+     PSI measures the shift in distribution between two populations.
+     Commonly used thresholds:
+     - PSI < 0.1: No significant change
+     - 0.1 <= PSI < 0.2: Minor shift
+     - PSI >= 0.2: Significant shift (drift)
+
+     Args:
+         threshold: PSI value above which drift is detected.
+             Default is 0.2.
+         buckets: Number of buckets for binning. Default is 10.
+
+     Example:
+         >>> detector = PSIDetector(threshold=0.2, buckets=10)
+         >>> result = detector.detect(reference_series, production_series)
+     """
+
+     def __init__(self, threshold: float = 0.2, buckets: int = 10) -> None:
+         super().__init__(threshold=threshold, name="psi")
+         self.buckets = buckets
+
+     def detect(
+         self,
+         reference: pd.Series,
+         production: pd.Series,
+     ) -> DetectionResult:
+         """
+         Calculate PSI between reference and production distributions.
+
+         Returns:
+             DetectionResult with PSI score
+         """
+         self._validate_inputs(reference, production)
+
+         psi_value = self._calculate_psi(
+             np.asarray(reference.dropna().values),
+             np.asarray(production.dropna().values),
+         )
+
+         return DetectionResult(
+             has_drift=psi_value >= self.threshold,
+             score=float(psi_value),
+             method=self.name,
+             threshold=self.threshold,
+             p_value=None,
+         )
+
+     def _calculate_psi(
+         self,
+         reference: np.ndarray,
+         production: np.ndarray,
+     ) -> float:
+         """
+         Calculate PSI using percentile-based buckets.
+
+         The reference distribution defines the bucket boundaries,
+         and we compare the distribution of production data across
+         these same buckets.
+         """
+         # Create buckets based on reference quantiles
+         breakpoints = np.percentile(
+             reference,
+             np.linspace(0, 100, self.buckets + 1),
+         )
+         # Ensure unique breakpoints
+         breakpoints = np.unique(breakpoints)
+
+         if len(breakpoints) < 2:
+             # Not enough variation, return 0
+             return 0.0
+
+         # Calculate distribution in each bucket
+         ref_counts = np.histogram(reference, bins=breakpoints)[0]
+         prod_counts = np.histogram(production, bins=breakpoints)[0]
+
+         # Convert to percentages, avoiding division by zero
+         ref_pct = ref_counts / len(reference)
+         prod_pct = prod_counts / len(production)
+
+         # Add small epsilon to avoid log(0)
+         eps = 1e-10
+         ref_pct = np.clip(ref_pct, eps, 1)
+         prod_pct = np.clip(prod_pct, eps, 1)
+
+         # Calculate PSI
+         psi: float = float(np.sum((prod_pct - ref_pct) * np.log(prod_pct / ref_pct)))
+
+         return float(psi)
+
+
+ class WassersteinDetector(BaseDetector):
+     """
+     Wasserstein distance (Earth Mover's Distance) for drift detection.
+
+     Measures the minimum "work" required to transform one distribution
+     into another. More sensitive to subtle distributional changes.
+
+     Args:
+         threshold: Distance above which drift is detected.
+     """
+
+     def __init__(self, threshold: float = 0.1) -> None:
+         super().__init__(threshold=threshold, name="wasserstein")
+
+     def detect(
+         self,
+         reference: pd.Series,
+         production: pd.Series,
+     ) -> DetectionResult:
+         """
+         Calculate Wasserstein distance between distributions.
+
+         Note: Values are normalized by the reference standard deviation
+         to make the threshold more interpretable.
+         """
+         self._validate_inputs(reference, production)
+
+         ref_clean = reference.dropna().values
+         prod_clean = production.dropna().values
+
+         distance = stats.wasserstein_distance(ref_clean, prod_clean)
+
+         # Normalize by reference std for interpretability
+         ref_std = np.std(ref_clean)
+         normalized_distance = distance / ref_std if ref_std > 0 else distance
+
+         return DetectionResult(
+             has_drift=normalized_distance >= self.threshold,
+             score=float(normalized_distance),
+             method=self.name,
+             threshold=self.threshold,
+             p_value=None,
+         )
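The numerical module (imported as `driftwatch.detectors.numerical` in the registry hunk below) offers three methods: a two-sample KS test, PSI over the reference's quantile buckets summing `(prod_pct - ref_pct) * ln(prod_pct / ref_pct)`, and a Wasserstein distance normalized by the reference standard deviation. A short sketch comparing them on a shifted distribution, again assuming the `driftwatch.detectors.base` plumbing outside this diff:

```python
import numpy as np
import pandas as pd

from driftwatch.detectors.numerical import KSDetector, PSIDetector, WassersteinDetector

rng = np.random.default_rng(42)
reference = pd.Series(rng.normal(loc=0.0, scale=1.0, size=5_000))
# Production distribution shifted by half a standard deviation
production = pd.Series(rng.normal(loc=0.5, scale=1.0, size=5_000))

detectors = (
    KSDetector(threshold=0.05),
    PSIDetector(threshold=0.2, buckets=10),
    WassersteinDetector(threshold=0.1),
)
for detector in detectors:
    result = detector.detect(reference, production)
    print(f"{result.method}: drift={result.has_drift} score={result.score:.3f}")
```

Note the different threshold semantics: the KS detector flags drift when the p-value falls *below* the threshold, while PSI and Wasserstein flag drift when the score rises *above* it.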
@@ -0,0 +1,71 @@
+ """Detector registry for automatic selection based on dtype."""
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any
+
+ import numpy as np
+
+ from driftwatch.detectors.categorical import ChiSquaredDetector
+ from driftwatch.detectors.numerical import KSDetector, PSIDetector, WassersteinDetector
+
+ if TYPE_CHECKING:
+     from driftwatch.detectors.base import BaseDetector
+
+
+ def get_detector(dtype: np.dtype[Any], thresholds: dict[str, float]) -> BaseDetector:
+     """
+     Get appropriate detector based on data type.
+
+     Args:
+         dtype: NumPy dtype of the feature
+         thresholds: Dictionary of threshold values
+
+     Returns:
+         Appropriate detector instance
+
+     Note:
+         - Numerical types use PSI by default
+         - Categorical/object types use Chi-Squared
+     """
+     if np.issubdtype(dtype, np.number):
+         # Use PSI for numerical features by default
+         return PSIDetector(threshold=thresholds.get("psi", 0.2))
+     else:
+         # Use Chi-Squared for categorical features
+         return ChiSquaredDetector(threshold=thresholds.get("chi2_pvalue", 0.05))
+
+
+ def get_detector_by_name(
+     name: str,
+     thresholds: dict[str, float],
+ ) -> BaseDetector:
+     """
+     Get detector by explicit name.
+
+     Args:
+         name: Detector name ("ks", "psi", "wasserstein", "chi2")
+         thresholds: Dictionary of threshold values
+
+     Returns:
+         Requested detector instance
+
+     Raises:
+         ValueError: If detector name is unknown
+     """
+     detectors = {
+         "ks": lambda: KSDetector(threshold=thresholds.get("ks_pvalue", 0.05)),
+         "psi": lambda: PSIDetector(threshold=thresholds.get("psi", 0.2)),
+         "wasserstein": lambda: WassersteinDetector(
+             threshold=thresholds.get("wasserstein", 0.1)
+         ),
+         "chi2": lambda: ChiSquaredDetector(
+             threshold=thresholds.get("chi2_pvalue", 0.05)
+         ),
+     }
+
+     if name not in detectors:
+         available = ", ".join(detectors.keys())
+         raise ValueError(f"Unknown detector '{name}'. Available: {available}")
+
+     return detectors[name]()
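The registry picks a detector from the column dtype (PSI for numeric dtypes, Chi-Squared otherwise), with `get_detector_by_name` as the explicit override. A per-column selection sketch over a DataFrame; the column names and data are illustrative, and it assumes the detector's `name` attribute is set by the base class as the constructor calls suggest:

```python
import numpy as np
import pandas as pd

from driftwatch.detectors.registry import get_detector, get_detector_by_name

thresholds = {"psi": 0.2, "chi2_pvalue": 0.05, "ks_pvalue": 0.05}

rng = np.random.default_rng(0)
reference = pd.DataFrame({
    "amount": rng.normal(50, 10, size=1_000),
    "plan": rng.choice(["free", "pro"], size=1_000, p=[0.8, 0.2]),
})
production = pd.DataFrame({
    "amount": rng.normal(65, 10, size=1_000),
    "plan": rng.choice(["free", "pro"], size=1_000, p=[0.4, 0.6]),
})

for column in reference.columns:
    detector = get_detector(reference[column].dtype, thresholds)
    result = detector.detect(reference[column], production[column])
    print(f"{column}: {result.method} -> drift={result.has_drift}")

# Explicit selection by name; unknown names raise ValueError
ks = get_detector_by_name("ks", thresholds)
```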
@@ -0,0 +1,5 @@
+ """DriftWatch integrations for external services."""
+
+ from driftwatch.integrations.fastapi import DriftMiddleware, add_drift_routes
+
+ __all__ = ["DriftMiddleware", "add_drift_routes"]
@@ -0,0 +1,211 @@
+ """Alerting integrations for DriftWatch.
+
+ Provides alerting mechanisms (Slack, Email, etc.) for drift detection.
+ """
+
+ from __future__ import annotations
+
+ import time
+ from datetime import datetime, timezone
+ from typing import TYPE_CHECKING, Any
+
+ import httpx
+
+ if TYPE_CHECKING:
+     from driftwatch.core.report import DriftReport
+
+
+ class SlackAlerter:
+     """
+     Send drift alerts to Slack via webhook.
+
+     Formats drift reports as Slack Block Kit messages with feature-level
+     details and supports alert throttling to avoid spam.
+
+     Args:
+         webhook_url: Slack webhook URL (https://hooks.slack.com/...)
+         throttle_minutes: Minimum minutes between alerts (default: 60)
+         mention_user: Optional Slack user ID to mention (@U123ABC)
+         channel_override: Optional channel to post to (overrides webhook default)
+
+     Example:
+         ```python
+         from driftwatch.integrations.alerting import SlackAlerter
+
+         alerter = SlackAlerter(
+             webhook_url="https://hooks.slack.com/services/...",
+             throttle_minutes=60
+         )
+
+         if report.has_drift():
+             alerter.send(report)
+         ```
+     """
+
+     def __init__(
+         self,
+         webhook_url: str,
+         throttle_minutes: int = 60,
+         mention_user: str | None = None,
+         channel_override: str | None = None,
+     ) -> None:
+         self.webhook_url = webhook_url
+         self.throttle_seconds = throttle_minutes * 60
+         self.mention_user = mention_user
+         self.channel_override = channel_override
+         self._last_alert_time: float = 0.0
+
+     def send(
+         self,
+         report: DriftReport,
+         force: bool = False,
+         custom_message: str | None = None,
+     ) -> bool:
+         """
+         Send drift report to Slack.
+
+         Args:
+             report: DriftReport to send
+             force: Skip throttling check
+             custom_message: Optional custom message prefix
+
+         Returns:
+             True if alert was sent, False if throttled
+
+         Raises:
+             httpx.HTTPError: If webhook request fails
+         """
+         # Check throttling
+         if not force and self._is_throttled():
+             return False
+
+         # Build Slack message
+         blocks = self._build_blocks(report, custom_message)
+         payload: dict[str, Any] = {"blocks": blocks}
+
+         if self.channel_override:
+             payload["channel"] = self.channel_override
+
+         # Send to Slack
+         response = httpx.post(
+             self.webhook_url, json=payload, timeout=10.0, follow_redirects=True
+         )
+         response.raise_for_status()
+
+         # Update throttle timestamp
+         self._last_alert_time = time.time()
+
+         return True
+
+     def _is_throttled(self) -> bool:
+         """Check if alert should be throttled."""
+         if self._last_alert_time == 0.0:
+             return False
+
+         elapsed = time.time() - self._last_alert_time
+         return elapsed < self.throttle_seconds
+
+     def _build_blocks(
+         self, report: DriftReport, custom_message: str | None = None
+     ) -> list[dict[str, Any]]:
+         """Build Slack Block Kit message."""
+         blocks: list[dict[str, Any]] = []
+
+         # Status emoji and color
+         emoji = {"OK": "✅", "WARNING": "⚠️", "CRITICAL": "🚨"}.get(
+             report.status.value, "📊"
+         )
+
+         # Header
+         header_text = f"{emoji} *Drift Detected - DriftWatch*"
+         if custom_message:
+             header_text = f"{custom_message}\n{header_text}"
+         if self.mention_user:
+             header_text = f"<@{self.mention_user}> {header_text}"
+
+         blocks.append(
+             {"type": "header", "text": {"type": "plain_text", "text": header_text}}
+         )
+
+         # Summary section
+         summary_fields = [
+             {"type": "mrkdwn", "text": f"*Status:*\n{report.status.value}"},
+             {
+                 "type": "mrkdwn",
+                 "text": f"*Drift Ratio:*\n{report.drift_ratio():.1%}",
+             },
+             {
+                 "type": "mrkdwn",
+                 "text": f"*Affected Features:*\n{len(report.drifted_features())}/{len(report.feature_results)}",
+             },
+             {
+                 "type": "mrkdwn",
+                 "text": f"*Timestamp:*\n{self._format_timestamp(report.timestamp)}",
+             },
+         ]
+
+         blocks.append({"type": "section", "fields": summary_fields})
+
+         # Divider
+         blocks.append({"type": "divider"})
+
+         # Feature details (only drifted features)
+         if report.drifted_features():
+             blocks.append(
+                 {
+                     "type": "section",
+                     "text": {
+                         "type": "mrkdwn",
+                         "text": "*Drifted Features:*",
+                     },
+                 }
+             )
+
+             feature_details = []
+             for result in report.feature_results:
+                 if result.has_drift:
+                     detail = f"• `{result.feature_name}`: {result.method.upper()}={result.score:.4f} (threshold={result.threshold:.4f})"
+                     feature_details.append(detail)
+
+             blocks.append(
+                 {
+                     "type": "section",
+                     "text": {
+                         "type": "mrkdwn",
+                         "text": "\n".join(feature_details),
+                     },
+                 }
+             )
+
+         # Context footer
+         context_text = "DriftWatch Monitor"
+         if report.model_version:
+             context_text += f" • Model: {report.model_version}"
+
+         blocks.append(
+             {
+                 "type": "context",
+                 "elements": [{"type": "mrkdwn", "text": context_text}],
+             }
+         )
+
+         return blocks
+
+     def _format_timestamp(self, timestamp: datetime) -> str:
+         """Format timestamp for Slack message."""
+         if timestamp.tzinfo is None:
+             timestamp = timestamp.replace(tzinfo=timezone.utc)
+
+         return timestamp.strftime("%Y-%m-%d %H:%M:%S UTC")
+
+     def get_next_alert_time(self) -> datetime | None:
+         """Get the earliest time the next alert can be sent."""
+         if self._last_alert_time == 0.0:
+             return None
+
+         next_time = self._last_alert_time + self.throttle_seconds
+         return datetime.fromtimestamp(next_time, tz=timezone.utc)
+
+     def reset_throttle(self) -> None:
+         """Reset throttle timer (allows immediate next alert)."""
+         self._last_alert_time = 0.0
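Throttling in `SlackAlerter` is purely in-process state: `send` returns False when called again inside the throttle window unless `force=True`, `get_next_alert_time` reports when a non-forced alert becomes possible, and `reset_throttle` clears the timer. A sketch of that behavior, assuming `report` is a `DriftReport` produced elsewhere in the package (that module and the webhook URL here are placeholders, not part of this diff):

```python
from driftwatch.integrations.alerting import SlackAlerter

alerter = SlackAlerter(
    webhook_url="https://hooks.slack.com/services/T000/B000/XXXX",  # placeholder URL
    throttle_minutes=60,
    mention_user="U123ABC",
)

# `report` is a DriftReport obtained from a DriftWatch run (not shown in this diff)
sent = alerter.send(report)            # True on first delivery
sent_again = alerter.send(report)      # False: still inside the 60-minute window
alerter.send(report, force=True)       # bypasses the throttle check entirely

print(alerter.get_next_alert_time())   # earliest UTC time a non-forced alert goes out
alerter.reset_throttle()               # next send() is allowed immediately
```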