truthound-dashboard 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff shows the content of publicly released versions of the package as published to their respective public registries. It is provided for informational purposes only.
- truthound_dashboard/api/deps.py +28 -0
- truthound_dashboard/api/drift.py +1 -0
- truthound_dashboard/api/mask.py +164 -0
- truthound_dashboard/api/profile.py +11 -3
- truthound_dashboard/api/router.py +22 -0
- truthound_dashboard/api/scan.py +168 -0
- truthound_dashboard/api/schemas.py +13 -4
- truthound_dashboard/api/validations.py +33 -1
- truthound_dashboard/api/validators.py +85 -0
- truthound_dashboard/core/__init__.py +8 -0
- truthound_dashboard/core/phase5/activity.py +1 -1
- truthound_dashboard/core/services.py +457 -7
- truthound_dashboard/core/truthound_adapter.py +441 -26
- truthound_dashboard/db/__init__.py +6 -0
- truthound_dashboard/db/models.py +250 -1
- truthound_dashboard/schemas/__init__.py +52 -1
- truthound_dashboard/schemas/collaboration.py +1 -1
- truthound_dashboard/schemas/drift.py +118 -3
- truthound_dashboard/schemas/mask.py +209 -0
- truthound_dashboard/schemas/profile.py +45 -2
- truthound_dashboard/schemas/scan.py +312 -0
- truthound_dashboard/schemas/schema.py +30 -2
- truthound_dashboard/schemas/validation.py +60 -3
- truthound_dashboard/schemas/validators/__init__.py +59 -0
- truthound_dashboard/schemas/validators/aggregate_validators.py +238 -0
- truthound_dashboard/schemas/validators/anomaly_validators.py +723 -0
- truthound_dashboard/schemas/validators/base.py +263 -0
- truthound_dashboard/schemas/validators/completeness_validators.py +269 -0
- truthound_dashboard/schemas/validators/cross_table_validators.py +375 -0
- truthound_dashboard/schemas/validators/datetime_validators.py +253 -0
- truthound_dashboard/schemas/validators/distribution_validators.py +422 -0
- truthound_dashboard/schemas/validators/drift_validators.py +615 -0
- truthound_dashboard/schemas/validators/geospatial_validators.py +486 -0
- truthound_dashboard/schemas/validators/multi_column_validators.py +706 -0
- truthound_dashboard/schemas/validators/privacy_validators.py +531 -0
- truthound_dashboard/schemas/validators/query_validators.py +510 -0
- truthound_dashboard/schemas/validators/registry.py +318 -0
- truthound_dashboard/schemas/validators/schema_validators.py +408 -0
- truthound_dashboard/schemas/validators/string_validators.py +396 -0
- truthound_dashboard/schemas/validators/table_validators.py +412 -0
- truthound_dashboard/schemas/validators/uniqueness_validators.py +355 -0
- truthound_dashboard/schemas/validators.py +59 -0
- truthound_dashboard/static/assets/{index-BqXVFyqj.js → index-BCA8H1hO.js} +95 -95
- truthound_dashboard/static/assets/index-BNsSQ2fN.css +1 -0
- truthound_dashboard/static/assets/unmerged_dictionaries-CsJWCRx9.js +1 -0
- truthound_dashboard/static/index.html +2 -2
- {truthound_dashboard-1.2.1.dist-info → truthound_dashboard-1.3.0.dist-info}/METADATA +46 -11
- {truthound_dashboard-1.2.1.dist-info → truthound_dashboard-1.3.0.dist-info}/RECORD +51 -27
- truthound_dashboard/static/assets/index-o8qHVDte.css +0 -1
- truthound_dashboard/static/assets/unmerged_dictionaries-n_T3wZTf.js +0 -1
- {truthound_dashboard-1.2.1.dist-info → truthound_dashboard-1.3.0.dist-info}/WHEEL +0 -0
- {truthound_dashboard-1.2.1.dist-info → truthound_dashboard-1.3.0.dist-info}/entry_points.txt +0 -0
- {truthound_dashboard-1.2.1.dist-info → truthound_dashboard-1.3.0.dist-info}/licenses/LICENSE +0 -0
truthound_dashboard/db/models.py
CHANGED
@@ -598,6 +598,253 @@ class DriftComparison(Base, UUIDMixin, TimestampMixin):
         return []


+class MaskingStrategy(str, Enum):
+    """Masking strategy enum."""
+
+    REDACT = "redact"
+    HASH = "hash"
+    FAKE = "fake"
+
+
+class DataMask(Base, UUIDMixin):
+    """Data masking operation model.
+
+    Stores results from th.mask() data masking operations.
+    Supports three strategies: redact (asterisks), hash (SHA256), fake (realistic data).
+
+    Attributes:
+        id: Unique identifier (UUID).
+        source_id: Reference to parent Source.
+        status: Current status (pending, running, success, failed, error).
+        strategy: Masking strategy used (redact, hash, fake).
+        output_path: Path to the masked output file.
+        columns_masked: List of columns that were masked.
+        row_count: Number of rows processed.
+        column_count: Number of columns in the data.
+        auto_detected: Whether PII columns were auto-detected.
+        result_json: Full mask result as JSON.
+        duration_ms: Operation duration in milliseconds.
+    """
+
+    __tablename__ = "data_masks"
+
+    # Composite index for efficient history queries (source + time ordering)
+    __table_args__ = (
+        Index("idx_data_masks_source_created", "source_id", "created_at"),
+    )
+
+    source_id: Mapped[str] = mapped_column(
+        String(36),
+        ForeignKey("sources.id", ondelete="CASCADE"),
+        nullable=False,
+        index=True,
+    )
+
+    # Status tracking
+    status: Mapped[str] = mapped_column(
+        String(20),
+        nullable=False,
+        default="pending",
+        index=True,
+    )
+
+    # Masking configuration
+    strategy: Mapped[str] = mapped_column(
+        String(20),
+        nullable=False,
+        default=MaskingStrategy.REDACT.value,
+        index=True,
+    )
+    output_path: Mapped[str | None] = mapped_column(Text, nullable=True)
+    columns_masked: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
+    auto_detected: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
+
+    # Data statistics
+    row_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    column_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
+
+    # Full result and timing
+    result_json: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
+    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
+    duration_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
+
+    # Timestamps
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime, default=datetime.utcnow, nullable=False
+    )
+    started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+
+    # Relationships
+    source: Mapped[Source] = relationship(
+        "Source",
+        backref="data_masks",
+    )
+
+    @property
+    def is_complete(self) -> bool:
+        """Check if masking operation has completed."""
+        return self.status in ("success", "failed", "error")
+
+    @property
+    def masked_column_count(self) -> int:
+        """Get number of columns that were masked."""
+        return len(self.columns_masked) if self.columns_masked else 0
+
+    def mark_started(self) -> None:
+        """Mark operation as started."""
+        self.status = "running"
+        self.started_at = datetime.utcnow()
+
+    def mark_completed(
+        self,
+        result: dict[str, Any],
+    ) -> None:
+        """Mark operation as completed with results."""
+        self.status = "success"
+        self.result_json = result
+        self.completed_at = datetime.utcnow()
+
+        if self.started_at:
+            delta = self.completed_at - self.started_at
+            self.duration_ms = int(delta.total_seconds() * 1000)
+
+    def mark_error(self, message: str) -> None:
+        """Mark operation as errored."""
+        self.status = "error"
+        self.error_message = message
+        self.completed_at = datetime.utcnow()
+
+        if self.started_at:
+            delta = self.completed_at - self.started_at
+            self.duration_ms = int(delta.total_seconds() * 1000)
+
+
+class PIIScan(Base, UUIDMixin):
+    """PII scan result model.
+
+    Stores results from th.scan() PII detection runs.
+
+    Attributes:
+        id: Unique identifier (UUID).
+        source_id: Reference to parent Source.
+        status: Current status (pending, running, success, failed, error).
+        total_columns_scanned: Total columns that were scanned.
+        columns_with_pii: Number of columns containing PII.
+        total_findings: Total number of PII findings.
+        has_violations: Whether any regulation violations were found.
+        total_violations: Number of regulation violations.
+        min_confidence: Confidence threshold used for this scan.
+        regulations_checked: List of regulations checked.
+        result_json: Full scan result as JSON.
+        duration_ms: Scan duration in milliseconds.
+    """
+
+    __tablename__ = "pii_scans"
+
+    # Composite index for efficient history queries (source + time ordering)
+    __table_args__ = (
+        Index("idx_pii_scans_source_created", "source_id", "created_at"),
+    )
+
+    source_id: Mapped[str] = mapped_column(
+        String(36),
+        ForeignKey("sources.id", ondelete="CASCADE"),
+        nullable=False,
+        index=True,
+    )
+
+    # Status tracking
+    status: Mapped[str] = mapped_column(
+        String(20),
+        nullable=False,
+        default="pending",
+        index=True,
+    )
+
+    # Scan summary
+    total_columns_scanned: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    columns_with_pii: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    total_findings: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    has_violations: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
+    total_violations: Mapped[int | None] = mapped_column(Integer, nullable=True)
+
+    # Data statistics
+    row_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    column_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
+
+    # Configuration used
+    min_confidence: Mapped[float | None] = mapped_column(Float, nullable=True)
+    regulations_checked: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
+
+    # Full result and timing
+    result_json: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
+    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
+    duration_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
+
+    # Timestamps
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime, default=datetime.utcnow, nullable=False
+    )
+    started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+
+    # Relationships
+    source: Mapped[Source] = relationship(
+        "Source",
+        backref="pii_scans",
+    )
+
+    @property
+    def findings(self) -> list[dict[str, Any]]:
+        """Get list of PII findings from result JSON."""
+        if self.result_json and "findings" in self.result_json:
+            return self.result_json["findings"]
+        return []
+
+    @property
+    def violations(self) -> list[dict[str, Any]]:
+        """Get list of regulation violations from result JSON."""
+        if self.result_json and "violations" in self.result_json:
+            return self.result_json["violations"]
+        return []
+
+    @property
+    def is_complete(self) -> bool:
+        """Check if scan has completed (success, failed, or error)."""
+        return self.status in ("success", "failed", "error")
+
+    def mark_started(self) -> None:
+        """Mark scan as started."""
+        self.status = "running"
+        self.started_at = datetime.utcnow()
+
+    def mark_completed(
+        self,
+        has_violations: bool,
+        result: dict[str, Any],
+    ) -> None:
+        """Mark scan as completed with results."""
+        self.status = "success" if not has_violations else "failed"
+        self.has_violations = has_violations
+        self.result_json = result
+        self.completed_at = datetime.utcnow()
+
+        if self.started_at:
+            delta = self.completed_at - self.started_at
+            self.duration_ms = int(delta.total_seconds() * 1000)
+
+    def mark_error(self, message: str) -> None:
+        """Mark scan as errored."""
+        self.status = "error"
+        self.error_message = message
+        self.completed_at = datetime.utcnow()
+
+        if self.started_at:
+            delta = self.completed_at - self.started_at
+            self.duration_ms = int(delta.total_seconds() * 1000)
+
+
 class AppSettings(Base):
     """Application settings model.

@@ -1411,7 +1658,9 @@ class Activity(Base, UUIDMixin):
     action: Mapped[str] = mapped_column(String(30), nullable=False, index=True)
     actor_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
     description: Mapped[str | None] = mapped_column(Text, nullable=True)
-
+    activity_metadata: Mapped[dict[str, Any] | None] = mapped_column(
+        "metadata", JSON, nullable=True
+    )
     created_at: Mapped[datetime] = mapped_column(
         DateTime,
         default=datetime.utcnow,
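Note (annotation, not part of the diff): the new DataMask and PIIScan models share the same status lifecycle (mark_started / mark_completed / mark_error, with duration_ms derived from the timestamps). Below is a minimal sketch of how a caller might drive that lifecycle; the open SQLAlchemy session and the result dict (e.g. produced via the truthound adapter) are assumptions for illustration, not code from this package.

# Sketch only: exercises the lifecycle helpers shown in the diff above.
# `session` (a SQLAlchemy Session) and `result` are assumed inputs.
from typing import Any

from truthound_dashboard.db.models import DataMask, MaskingStrategy


def record_mask_run(session, source_id: str, result: dict[str, Any] | None,
                    error: str | None = None) -> DataMask:
    mask = DataMask(source_id=source_id, strategy=MaskingStrategy.HASH.value)
    session.add(mask)
    mask.mark_started()                      # status="running", started_at set
    if error is None and result is not None:
        mask.mark_completed(result=result)   # status="success", result_json stored, duration_ms computed
    else:
        mask.mark_error(error or "unknown error")
    session.commit()
    return mask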
truthound_dashboard/schemas/__init__.py
CHANGED

@@ -20,12 +20,18 @@ from .base import (
 )
 from .drift import (
     ColumnDriftResult,
+    CorrectionMethod,
+    CorrectionMethodLiteral,
+    DEFAULT_THRESHOLDS,
     DriftCompareRequest,
     DriftComparisonListItem,
     DriftComparisonListResponse,
     DriftComparisonResponse,
+    DriftMethod,
+    DriftMethodLiteral,
     DriftResult,
     DriftSourceSummary,
+    get_default_threshold,
 )
 from .history import (
     FailureFrequencyItem,
@@ -35,7 +41,7 @@ from .history import (
     RecentValidation,
     TrendDataPoint,
 )
-from .profile import ColumnProfile, ProfileResponse
+from .profile import ColumnProfile, ProfileRequest, ProfileResponse
 from .rule import (
     RuleBase,
     RuleCreate,
@@ -114,6 +120,26 @@ from .collaboration import (
     CommentUpdate,
     ResourceType,
 )
+from .mask import (
+    MaskingStrategy,
+    MaskingStrategyLiteral,
+    MaskListItem,
+    MaskListResponse,
+    MaskRequest,
+    MaskResponse,
+    MaskStatus,
+    MaskSummary,
+)
+from .scan import (
+    PIIFinding,
+    PIIScanListItem,
+    PIIScanListResponse,
+    PIIScanRequest,
+    PIIScanResponse,
+    PIIScanSummary,
+    Regulation,
+    RegulationViolation,
+)
 from .schema import (
     ColumnSchema,
     SchemaLearnRequest,
@@ -182,8 +208,27 @@ __all__ = [
     "SchemaResponse",
     "SchemaSummary",
     # Profile
+    "ProfileRequest",
     "ColumnProfile",
     "ProfileResponse",
+    # Data Masking
+    "MaskingStrategy",
+    "MaskingStrategyLiteral",
+    "MaskStatus",
+    "MaskRequest",
+    "MaskSummary",
+    "MaskResponse",
+    "MaskListItem",
+    "MaskListResponse",
+    # PII Scan
+    "Regulation",
+    "PIIScanRequest",
+    "PIIFinding",
+    "RegulationViolation",
+    "PIIScanSummary",
+    "PIIScanResponse",
+    "PIIScanListItem",
+    "PIIScanListResponse",
     # History
     "TrendDataPoint",
     "FailureFrequencyItem",
@@ -192,6 +237,12 @@ __all__ = [
     "HistoryResponse",
     "HistoryQueryParams",
     # Drift
+    "DriftMethod",
+    "DriftMethodLiteral",
+    "CorrectionMethod",
+    "CorrectionMethodLiteral",
+    "DEFAULT_THRESHOLDS",
+    "get_default_threshold",
     "DriftCompareRequest",
     "ColumnDriftResult",
     "DriftResult",
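Note (annotation, not part of the diff): with these re-exports, the new drift helpers and the mask/scan schemas become importable directly from truthound_dashboard.schemas. A small sketch, assuming version 1.3.0 is installed:

# Sketch: the names below are exactly those added to __all__ in the diff above.
from truthound_dashboard.schemas import (
    DEFAULT_THRESHOLDS,
    DriftMethod,
    MaskRequest,        # defined in truthound_dashboard/schemas/mask.py
    PIIScanRequest,     # defined in truthound_dashboard/schemas/scan.py
    get_default_threshold,
)

# Per-method defaults come straight from DEFAULT_THRESHOLDS in schemas/drift.py.
assert get_default_threshold(DriftMethod.PSI) == DEFAULT_THRESHOLDS[DriftMethod.PSI]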
truthound_dashboard/schemas/collaboration.py
CHANGED

@@ -142,7 +142,7 @@ class ActivityResponse(BaseSchema, IDMixin):
             action=ActivityAction(activity.action),
             actor_id=activity.actor_id,
             description=activity.description,
-            metadata=activity.
+            metadata=activity.activity_metadata,
             created_at=activity.created_at,
         )

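Note (annotation, not part of the diff): the attribute is named activity_metadata because SQLAlchemy reserves metadata on declarative classes, while mapped_column("metadata", ...) keeps the database column name unchanged; ActivityResponse simply reads the renamed attribute. A standalone sketch of the same pattern (not this package's actual Base):

# Standalone sketch: mapping a Python attribute onto a column literally named "metadata".
from typing import Any

from sqlalchemy import JSON
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Activity(Base):
    __tablename__ = "activities"

    id: Mapped[int] = mapped_column(primary_key=True)
    # "metadata" clashes with Declarative's reserved Base.metadata attribute,
    # so the attribute is renamed; the first positional argument keeps the
    # database column name as "metadata".
    activity_metadata: Mapped[dict[str, Any] | None] = mapped_column(
        "metadata", JSON, nullable=True
    )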
truthound_dashboard/schemas/drift.py
CHANGED

@@ -1,10 +1,26 @@
 """Drift detection schemas.

 Schemas for drift comparison request/response.
+
+Drift Methods (from truthound):
+- ks: Kolmogorov-Smirnov test (continuous distributions)
+- psi: Population Stability Index (any distribution, industry standard)
+- chi2: Chi-Square test (categorical data)
+- js: Jensen-Shannon divergence (probability distributions)
+- kl: Kullback-Leibler divergence (distribution difference)
+- wasserstein: Wasserstein/Earth Mover's Distance (distribution transport)
+- cvm: Cramér-von Mises test (more sensitive to tails than KS)
+- anderson: Anderson-Darling test (weighted for tail sensitivity)
+
+Multiple Testing Correction:
+- bonferroni: Conservative, independent tests
+- holm: Sequential adjustment, less conservative
+- bh: Benjamini-Hochberg (FDR control, default for multiple columns)
 """

 from __future__ import annotations

+from enum import Enum
 from typing import Any, Literal

 from pydantic import BaseModel, Field
@@ -12,6 +28,87 @@ from pydantic import BaseModel, Field
 from .base import IDMixin, TimestampMixin


+class DriftMethod(str, Enum):
+    """Drift detection methods supported by truthound.
+
+    Each method has different characteristics and use cases:
+    - auto: Smart selection based on data type (numeric → PSI, categorical → chi2)
+    - ks: Kolmogorov-Smirnov test - best for continuous distributions
+    - psi: Population Stability Index - industry standard, any distribution
+    - chi2: Chi-Square test - best for categorical data
+    - js: Jensen-Shannon divergence - symmetric, bounded (0-1)
+    - kl: Kullback-Leibler divergence - information loss measure
+    - wasserstein: Earth Mover's Distance - metric, meaningful for non-overlapping
+    - cvm: Cramér-von Mises - more sensitive to tail differences than KS
+    - anderson: Anderson-Darling - weighted for tail sensitivity
+    """
+
+    AUTO = "auto"
+    KS = "ks"
+    PSI = "psi"
+    CHI2 = "chi2"
+    JS = "js"
+    KL = "kl"
+    WASSERSTEIN = "wasserstein"
+    CVM = "cvm"
+    ANDERSON = "anderson"
+
+
+class CorrectionMethod(str, Enum):
+    """Multiple testing correction methods.
+
+    When comparing multiple columns, correction adjusts p-values to control
+    false discovery rate:
+    - none: No correction (use with caution)
+    - bonferroni: Conservative, suitable for independent tests
+    - holm: Sequential adjustment, less conservative than Bonferroni
+    - bh: Benjamini-Hochberg (FDR control), default for multiple columns
+    """
+
+    NONE = "none"
+    BONFERRONI = "bonferroni"
+    HOLM = "holm"
+    BH = "bh"
+
+
+# Default thresholds for each detection method
+DEFAULT_THRESHOLDS: dict[DriftMethod, float] = {
+    DriftMethod.AUTO: 0.05,
+    DriftMethod.KS: 0.05,
+    DriftMethod.PSI: 0.1,
+    DriftMethod.CHI2: 0.05,
+    DriftMethod.JS: 0.1,
+    DriftMethod.KL: 0.1,
+    DriftMethod.WASSERSTEIN: 0.1,  # Scale-dependent, adjust based on data
+    DriftMethod.CVM: 0.05,
+    DriftMethod.ANDERSON: 0.05,
+}
+
+
+def get_default_threshold(method: DriftMethod | str) -> float:
+    """Get default threshold for a drift detection method.
+
+    Args:
+        method: Drift detection method
+
+    Returns:
+        Default threshold value for the method
+    """
+    if isinstance(method, str):
+        try:
+            method = DriftMethod(method)
+        except ValueError:
+            return 0.05  # Fallback default
+    return DEFAULT_THRESHOLDS.get(method, 0.05)
+
+
+# Type alias for method values (for Literal type hints)
+DriftMethodLiteral = Literal[
+    "auto", "ks", "psi", "chi2", "js", "kl", "wasserstein", "cvm", "anderson"
+]
+CorrectionMethodLiteral = Literal["none", "bonferroni", "holm", "bh"]
+
+
 class DriftCompareRequest(BaseModel):
     """Request body for drift comparison."""

@@ -20,10 +117,28 @@ class DriftCompareRequest(BaseModel):
     columns: list[str] | None = Field(
         None, description="Columns to compare (None = all)"
     )
-    method:
-        "auto",
+    method: DriftMethodLiteral = Field(
+        "auto",
+        description=(
+            "Drift detection method: "
+            "auto (smart selection), ks (Kolmogorov-Smirnov), psi (Population Stability Index), "
+            "chi2 (Chi-Square), js (Jensen-Shannon), kl (Kullback-Leibler), "
+            "wasserstein (Earth Mover's), cvm (Cramér-von Mises), anderson (Anderson-Darling)"
+        ),
+    )
+    threshold: float | None = Field(
+        None,
+        ge=0,
+        le=1,
+        description="Custom threshold (default varies by method: KS/chi2/cvm/anderson=0.05, PSI/JS/KL/wasserstein=0.1)",
+    )
+    correction: CorrectionMethodLiteral | None = Field(
+        None,
+        description=(
+            "Multiple testing correction: none, bonferroni (conservative), "
+            "holm (sequential), bh (Benjamini-Hochberg FDR, default for multiple columns)"
+        ),
     )
-    threshold: float | None = Field(None, ge=0, le=1, description="Custom threshold")
     sample_size: int | None = Field(
         None, ge=100, description="Sample size for large datasets"
    )
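Note (annotation, not part of the diff): get_default_threshold accepts either a DriftMethod member or its string value and falls back to 0.05 for unknown strings, which is how a caller can resolve the per-method default when DriftCompareRequest.threshold is left as None. A short sketch, assuming version 1.3.0 is installed:

# Sketch: default-threshold resolution exactly as defined in schemas/drift.py above.
from truthound_dashboard.schemas.drift import CorrectionMethod, DriftMethod, get_default_threshold

assert get_default_threshold(DriftMethod.KS) == 0.05    # p-value based tests default to 0.05
assert get_default_threshold("psi") == 0.1              # string values are coerced to DriftMethod
assert get_default_threshold("not-a-method") == 0.05    # unknown values use the fallback
assert CorrectionMethod.BH.value == "bh"                # FDR correction, default for multiple columns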