churnkit 0.75.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +647 -0
  2. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +1165 -0
  3. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +961 -0
  4. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +1690 -0
  5. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +679 -0
  6. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +3305 -0
  7. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +1463 -0
  8. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +1430 -0
  9. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +854 -0
  10. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +1639 -0
  11. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +1890 -0
  12. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +1457 -0
  13. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +1624 -0
  14. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +780 -0
  15. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +979 -0
  16. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +572 -0
  17. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +1179 -0
  18. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +1418 -0
  19. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +151 -0
  20. churnkit-0.75.0a1.dist-info/METADATA +229 -0
  21. churnkit-0.75.0a1.dist-info/RECORD +302 -0
  22. churnkit-0.75.0a1.dist-info/WHEEL +4 -0
  23. churnkit-0.75.0a1.dist-info/entry_points.txt +2 -0
  24. churnkit-0.75.0a1.dist-info/licenses/LICENSE +202 -0
  25. customer_retention/__init__.py +37 -0
  26. customer_retention/analysis/__init__.py +0 -0
  27. customer_retention/analysis/auto_explorer/__init__.py +62 -0
  28. customer_retention/analysis/auto_explorer/exploration_manager.py +470 -0
  29. customer_retention/analysis/auto_explorer/explorer.py +258 -0
  30. customer_retention/analysis/auto_explorer/findings.py +291 -0
  31. customer_retention/analysis/auto_explorer/layered_recommendations.py +485 -0
  32. customer_retention/analysis/auto_explorer/recommendation_builder.py +148 -0
  33. customer_retention/analysis/auto_explorer/recommendations.py +418 -0
  34. customer_retention/analysis/business/__init__.py +26 -0
  35. customer_retention/analysis/business/ab_test_designer.py +144 -0
  36. customer_retention/analysis/business/fairness_analyzer.py +166 -0
  37. customer_retention/analysis/business/intervention_matcher.py +121 -0
  38. customer_retention/analysis/business/report_generator.py +222 -0
  39. customer_retention/analysis/business/risk_profile.py +199 -0
  40. customer_retention/analysis/business/roi_analyzer.py +139 -0
  41. customer_retention/analysis/diagnostics/__init__.py +20 -0
  42. customer_retention/analysis/diagnostics/calibration_analyzer.py +133 -0
  43. customer_retention/analysis/diagnostics/cv_analyzer.py +144 -0
  44. customer_retention/analysis/diagnostics/error_analyzer.py +107 -0
  45. customer_retention/analysis/diagnostics/leakage_detector.py +394 -0
  46. customer_retention/analysis/diagnostics/noise_tester.py +140 -0
  47. customer_retention/analysis/diagnostics/overfitting_analyzer.py +190 -0
  48. customer_retention/analysis/diagnostics/segment_analyzer.py +122 -0
  49. customer_retention/analysis/discovery/__init__.py +8 -0
  50. customer_retention/analysis/discovery/config_generator.py +49 -0
  51. customer_retention/analysis/discovery/discovery_flow.py +19 -0
  52. customer_retention/analysis/discovery/type_inferencer.py +147 -0
  53. customer_retention/analysis/interpretability/__init__.py +13 -0
  54. customer_retention/analysis/interpretability/cohort_analyzer.py +185 -0
  55. customer_retention/analysis/interpretability/counterfactual.py +175 -0
  56. customer_retention/analysis/interpretability/individual_explainer.py +141 -0
  57. customer_retention/analysis/interpretability/pdp_generator.py +103 -0
  58. customer_retention/analysis/interpretability/shap_explainer.py +106 -0
  59. customer_retention/analysis/jupyter_save_hook.py +28 -0
  60. customer_retention/analysis/notebook_html_exporter.py +136 -0
  61. customer_retention/analysis/notebook_progress.py +60 -0
  62. customer_retention/analysis/plotly_preprocessor.py +154 -0
  63. customer_retention/analysis/recommendations/__init__.py +54 -0
  64. customer_retention/analysis/recommendations/base.py +158 -0
  65. customer_retention/analysis/recommendations/cleaning/__init__.py +11 -0
  66. customer_retention/analysis/recommendations/cleaning/consistency.py +107 -0
  67. customer_retention/analysis/recommendations/cleaning/deduplicate.py +94 -0
  68. customer_retention/analysis/recommendations/cleaning/impute.py +67 -0
  69. customer_retention/analysis/recommendations/cleaning/outlier.py +71 -0
  70. customer_retention/analysis/recommendations/datetime/__init__.py +3 -0
  71. customer_retention/analysis/recommendations/datetime/extract.py +149 -0
  72. customer_retention/analysis/recommendations/encoding/__init__.py +3 -0
  73. customer_retention/analysis/recommendations/encoding/categorical.py +114 -0
  74. customer_retention/analysis/recommendations/pipeline.py +74 -0
  75. customer_retention/analysis/recommendations/registry.py +76 -0
  76. customer_retention/analysis/recommendations/selection/__init__.py +3 -0
  77. customer_retention/analysis/recommendations/selection/drop_column.py +56 -0
  78. customer_retention/analysis/recommendations/transform/__init__.py +4 -0
  79. customer_retention/analysis/recommendations/transform/power.py +94 -0
  80. customer_retention/analysis/recommendations/transform/scale.py +112 -0
  81. customer_retention/analysis/visualization/__init__.py +15 -0
  82. customer_retention/analysis/visualization/chart_builder.py +2619 -0
  83. customer_retention/analysis/visualization/console.py +122 -0
  84. customer_retention/analysis/visualization/display.py +171 -0
  85. customer_retention/analysis/visualization/number_formatter.py +36 -0
  86. customer_retention/artifacts/__init__.py +3 -0
  87. customer_retention/artifacts/fit_artifact_registry.py +146 -0
  88. customer_retention/cli.py +93 -0
  89. customer_retention/core/__init__.py +0 -0
  90. customer_retention/core/compat/__init__.py +193 -0
  91. customer_retention/core/compat/detection.py +99 -0
  92. customer_retention/core/compat/ops.py +48 -0
  93. customer_retention/core/compat/pandas_backend.py +57 -0
  94. customer_retention/core/compat/spark_backend.py +75 -0
  95. customer_retention/core/components/__init__.py +11 -0
  96. customer_retention/core/components/base.py +79 -0
  97. customer_retention/core/components/components/__init__.py +13 -0
  98. customer_retention/core/components/components/deployer.py +26 -0
  99. customer_retention/core/components/components/explainer.py +26 -0
  100. customer_retention/core/components/components/feature_eng.py +33 -0
  101. customer_retention/core/components/components/ingester.py +34 -0
  102. customer_retention/core/components/components/profiler.py +34 -0
  103. customer_retention/core/components/components/trainer.py +38 -0
  104. customer_retention/core/components/components/transformer.py +36 -0
  105. customer_retention/core/components/components/validator.py +37 -0
  106. customer_retention/core/components/enums.py +33 -0
  107. customer_retention/core/components/orchestrator.py +94 -0
  108. customer_retention/core/components/registry.py +59 -0
  109. customer_retention/core/config/__init__.py +39 -0
  110. customer_retention/core/config/column_config.py +95 -0
  111. customer_retention/core/config/experiments.py +71 -0
  112. customer_retention/core/config/pipeline_config.py +117 -0
  113. customer_retention/core/config/source_config.py +83 -0
  114. customer_retention/core/utils/__init__.py +28 -0
  115. customer_retention/core/utils/leakage.py +85 -0
  116. customer_retention/core/utils/severity.py +53 -0
  117. customer_retention/core/utils/statistics.py +90 -0
  118. customer_retention/generators/__init__.py +0 -0
  119. customer_retention/generators/notebook_generator/__init__.py +167 -0
  120. customer_retention/generators/notebook_generator/base.py +55 -0
  121. customer_retention/generators/notebook_generator/cell_builder.py +49 -0
  122. customer_retention/generators/notebook_generator/config.py +47 -0
  123. customer_retention/generators/notebook_generator/databricks_generator.py +48 -0
  124. customer_retention/generators/notebook_generator/local_generator.py +48 -0
  125. customer_retention/generators/notebook_generator/project_init.py +174 -0
  126. customer_retention/generators/notebook_generator/runner.py +150 -0
  127. customer_retention/generators/notebook_generator/script_generator.py +110 -0
  128. customer_retention/generators/notebook_generator/stages/__init__.py +19 -0
  129. customer_retention/generators/notebook_generator/stages/base_stage.py +86 -0
  130. customer_retention/generators/notebook_generator/stages/s01_ingestion.py +100 -0
  131. customer_retention/generators/notebook_generator/stages/s02_profiling.py +95 -0
  132. customer_retention/generators/notebook_generator/stages/s03_cleaning.py +180 -0
  133. customer_retention/generators/notebook_generator/stages/s04_transformation.py +165 -0
  134. customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py +115 -0
  135. customer_retention/generators/notebook_generator/stages/s06_feature_selection.py +97 -0
  136. customer_retention/generators/notebook_generator/stages/s07_model_training.py +176 -0
  137. customer_retention/generators/notebook_generator/stages/s08_deployment.py +81 -0
  138. customer_retention/generators/notebook_generator/stages/s09_monitoring.py +112 -0
  139. customer_retention/generators/notebook_generator/stages/s10_batch_inference.py +642 -0
  140. customer_retention/generators/notebook_generator/stages/s11_feature_store.py +348 -0
  141. customer_retention/generators/orchestration/__init__.py +23 -0
  142. customer_retention/generators/orchestration/code_generator.py +196 -0
  143. customer_retention/generators/orchestration/context.py +147 -0
  144. customer_retention/generators/orchestration/data_materializer.py +188 -0
  145. customer_retention/generators/orchestration/databricks_exporter.py +411 -0
  146. customer_retention/generators/orchestration/doc_generator.py +311 -0
  147. customer_retention/generators/pipeline_generator/__init__.py +26 -0
  148. customer_retention/generators/pipeline_generator/findings_parser.py +727 -0
  149. customer_retention/generators/pipeline_generator/generator.py +142 -0
  150. customer_retention/generators/pipeline_generator/models.py +166 -0
  151. customer_retention/generators/pipeline_generator/renderer.py +2125 -0
  152. customer_retention/generators/spec_generator/__init__.py +37 -0
  153. customer_retention/generators/spec_generator/databricks_generator.py +433 -0
  154. customer_retention/generators/spec_generator/generic_generator.py +373 -0
  155. customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +685 -0
  156. customer_retention/generators/spec_generator/pipeline_spec.py +298 -0
  157. customer_retention/integrations/__init__.py +0 -0
  158. customer_retention/integrations/adapters/__init__.py +13 -0
  159. customer_retention/integrations/adapters/base.py +10 -0
  160. customer_retention/integrations/adapters/factory.py +25 -0
  161. customer_retention/integrations/adapters/feature_store/__init__.py +6 -0
  162. customer_retention/integrations/adapters/feature_store/base.py +57 -0
  163. customer_retention/integrations/adapters/feature_store/databricks.py +94 -0
  164. customer_retention/integrations/adapters/feature_store/feast_adapter.py +97 -0
  165. customer_retention/integrations/adapters/feature_store/local.py +75 -0
  166. customer_retention/integrations/adapters/mlflow/__init__.py +6 -0
  167. customer_retention/integrations/adapters/mlflow/base.py +32 -0
  168. customer_retention/integrations/adapters/mlflow/databricks.py +54 -0
  169. customer_retention/integrations/adapters/mlflow/experiment_tracker.py +161 -0
  170. customer_retention/integrations/adapters/mlflow/local.py +50 -0
  171. customer_retention/integrations/adapters/storage/__init__.py +5 -0
  172. customer_retention/integrations/adapters/storage/base.py +33 -0
  173. customer_retention/integrations/adapters/storage/databricks.py +76 -0
  174. customer_retention/integrations/adapters/storage/local.py +59 -0
  175. customer_retention/integrations/feature_store/__init__.py +47 -0
  176. customer_retention/integrations/feature_store/definitions.py +215 -0
  177. customer_retention/integrations/feature_store/manager.py +744 -0
  178. customer_retention/integrations/feature_store/registry.py +412 -0
  179. customer_retention/integrations/iteration/__init__.py +28 -0
  180. customer_retention/integrations/iteration/context.py +212 -0
  181. customer_retention/integrations/iteration/feedback_collector.py +184 -0
  182. customer_retention/integrations/iteration/orchestrator.py +168 -0
  183. customer_retention/integrations/iteration/recommendation_tracker.py +341 -0
  184. customer_retention/integrations/iteration/signals.py +212 -0
  185. customer_retention/integrations/llm_context/__init__.py +4 -0
  186. customer_retention/integrations/llm_context/context_builder.py +201 -0
  187. customer_retention/integrations/llm_context/prompts.py +100 -0
  188. customer_retention/integrations/streaming/__init__.py +103 -0
  189. customer_retention/integrations/streaming/batch_integration.py +149 -0
  190. customer_retention/integrations/streaming/early_warning_model.py +227 -0
  191. customer_retention/integrations/streaming/event_schema.py +214 -0
  192. customer_retention/integrations/streaming/online_store_writer.py +249 -0
  193. customer_retention/integrations/streaming/realtime_scorer.py +261 -0
  194. customer_retention/integrations/streaming/trigger_engine.py +293 -0
  195. customer_retention/integrations/streaming/window_aggregator.py +393 -0
  196. customer_retention/stages/__init__.py +0 -0
  197. customer_retention/stages/cleaning/__init__.py +9 -0
  198. customer_retention/stages/cleaning/base.py +28 -0
  199. customer_retention/stages/cleaning/missing_handler.py +160 -0
  200. customer_retention/stages/cleaning/outlier_handler.py +204 -0
  201. customer_retention/stages/deployment/__init__.py +28 -0
  202. customer_retention/stages/deployment/batch_scorer.py +106 -0
  203. customer_retention/stages/deployment/champion_challenger.py +299 -0
  204. customer_retention/stages/deployment/model_registry.py +182 -0
  205. customer_retention/stages/deployment/retraining_trigger.py +245 -0
  206. customer_retention/stages/features/__init__.py +73 -0
  207. customer_retention/stages/features/behavioral_features.py +266 -0
  208. customer_retention/stages/features/customer_segmentation.py +505 -0
  209. customer_retention/stages/features/feature_definitions.py +265 -0
  210. customer_retention/stages/features/feature_engineer.py +551 -0
  211. customer_retention/stages/features/feature_manifest.py +340 -0
  212. customer_retention/stages/features/feature_selector.py +239 -0
  213. customer_retention/stages/features/interaction_features.py +160 -0
  214. customer_retention/stages/features/temporal_features.py +243 -0
  215. customer_retention/stages/ingestion/__init__.py +9 -0
  216. customer_retention/stages/ingestion/load_result.py +32 -0
  217. customer_retention/stages/ingestion/loaders.py +195 -0
  218. customer_retention/stages/ingestion/source_registry.py +130 -0
  219. customer_retention/stages/modeling/__init__.py +31 -0
  220. customer_retention/stages/modeling/baseline_trainer.py +139 -0
  221. customer_retention/stages/modeling/cross_validator.py +125 -0
  222. customer_retention/stages/modeling/data_splitter.py +205 -0
  223. customer_retention/stages/modeling/feature_scaler.py +99 -0
  224. customer_retention/stages/modeling/hyperparameter_tuner.py +107 -0
  225. customer_retention/stages/modeling/imbalance_handler.py +282 -0
  226. customer_retention/stages/modeling/mlflow_logger.py +95 -0
  227. customer_retention/stages/modeling/model_comparator.py +149 -0
  228. customer_retention/stages/modeling/model_evaluator.py +138 -0
  229. customer_retention/stages/modeling/threshold_optimizer.py +131 -0
  230. customer_retention/stages/monitoring/__init__.py +37 -0
  231. customer_retention/stages/monitoring/alert_manager.py +328 -0
  232. customer_retention/stages/monitoring/drift_detector.py +201 -0
  233. customer_retention/stages/monitoring/performance_monitor.py +242 -0
  234. customer_retention/stages/preprocessing/__init__.py +5 -0
  235. customer_retention/stages/preprocessing/transformer_manager.py +284 -0
  236. customer_retention/stages/profiling/__init__.py +256 -0
  237. customer_retention/stages/profiling/categorical_distribution.py +269 -0
  238. customer_retention/stages/profiling/categorical_target_analyzer.py +274 -0
  239. customer_retention/stages/profiling/column_profiler.py +527 -0
  240. customer_retention/stages/profiling/distribution_analysis.py +483 -0
  241. customer_retention/stages/profiling/drift_detector.py +310 -0
  242. customer_retention/stages/profiling/feature_capacity.py +507 -0
  243. customer_retention/stages/profiling/pattern_analysis_config.py +513 -0
  244. customer_retention/stages/profiling/profile_result.py +212 -0
  245. customer_retention/stages/profiling/quality_checks.py +1632 -0
  246. customer_retention/stages/profiling/relationship_detector.py +256 -0
  247. customer_retention/stages/profiling/relationship_recommender.py +454 -0
  248. customer_retention/stages/profiling/report_generator.py +520 -0
  249. customer_retention/stages/profiling/scd_analyzer.py +151 -0
  250. customer_retention/stages/profiling/segment_analyzer.py +632 -0
  251. customer_retention/stages/profiling/segment_aware_outlier.py +265 -0
  252. customer_retention/stages/profiling/target_level_analyzer.py +217 -0
  253. customer_retention/stages/profiling/temporal_analyzer.py +388 -0
  254. customer_retention/stages/profiling/temporal_coverage.py +488 -0
  255. customer_retention/stages/profiling/temporal_feature_analyzer.py +692 -0
  256. customer_retention/stages/profiling/temporal_feature_engineer.py +703 -0
  257. customer_retention/stages/profiling/temporal_pattern_analyzer.py +636 -0
  258. customer_retention/stages/profiling/temporal_quality_checks.py +278 -0
  259. customer_retention/stages/profiling/temporal_target_analyzer.py +241 -0
  260. customer_retention/stages/profiling/text_embedder.py +87 -0
  261. customer_retention/stages/profiling/text_processor.py +115 -0
  262. customer_retention/stages/profiling/text_reducer.py +60 -0
  263. customer_retention/stages/profiling/time_series_profiler.py +303 -0
  264. customer_retention/stages/profiling/time_window_aggregator.py +376 -0
  265. customer_retention/stages/profiling/type_detector.py +382 -0
  266. customer_retention/stages/profiling/window_recommendation.py +288 -0
  267. customer_retention/stages/temporal/__init__.py +166 -0
  268. customer_retention/stages/temporal/access_guard.py +180 -0
  269. customer_retention/stages/temporal/cutoff_analyzer.py +235 -0
  270. customer_retention/stages/temporal/data_preparer.py +178 -0
  271. customer_retention/stages/temporal/point_in_time_join.py +134 -0
  272. customer_retention/stages/temporal/point_in_time_registry.py +148 -0
  273. customer_retention/stages/temporal/scenario_detector.py +163 -0
  274. customer_retention/stages/temporal/snapshot_manager.py +259 -0
  275. customer_retention/stages/temporal/synthetic_coordinator.py +66 -0
  276. customer_retention/stages/temporal/timestamp_discovery.py +531 -0
  277. customer_retention/stages/temporal/timestamp_manager.py +255 -0
  278. customer_retention/stages/transformation/__init__.py +13 -0
  279. customer_retention/stages/transformation/binary_handler.py +85 -0
  280. customer_retention/stages/transformation/categorical_encoder.py +245 -0
  281. customer_retention/stages/transformation/datetime_transformer.py +97 -0
  282. customer_retention/stages/transformation/numeric_transformer.py +181 -0
  283. customer_retention/stages/transformation/pipeline.py +257 -0
  284. customer_retention/stages/validation/__init__.py +60 -0
  285. customer_retention/stages/validation/adversarial_scoring_validator.py +205 -0
  286. customer_retention/stages/validation/business_sense_gate.py +173 -0
  287. customer_retention/stages/validation/data_quality_gate.py +235 -0
  288. customer_retention/stages/validation/data_validators.py +511 -0
  289. customer_retention/stages/validation/feature_quality_gate.py +183 -0
  290. customer_retention/stages/validation/gates.py +117 -0
  291. customer_retention/stages/validation/leakage_gate.py +352 -0
  292. customer_retention/stages/validation/model_validity_gate.py +213 -0
  293. customer_retention/stages/validation/pipeline_validation_runner.py +264 -0
  294. customer_retention/stages/validation/quality_scorer.py +544 -0
  295. customer_retention/stages/validation/rule_generator.py +57 -0
  296. customer_retention/stages/validation/scoring_pipeline_validator.py +446 -0
  297. customer_retention/stages/validation/timeseries_detector.py +769 -0
  298. customer_retention/transforms/__init__.py +47 -0
  299. customer_retention/transforms/artifact_store.py +50 -0
  300. customer_retention/transforms/executor.py +157 -0
  301. customer_retention/transforms/fitted.py +92 -0
  302. customer_retention/transforms/ops.py +148 -0
customer_retention/analysis/diagnostics/error_analyzer.py
@@ -0,0 +1,107 @@
+ """Prediction error analysis probes."""
+
+ from dataclasses import dataclass, field
+ from typing import Dict, List
+
+ import numpy as np
+
+ from customer_retention.core.compat import DataFrame, Series
+
+
+ @dataclass
+ class ErrorPattern:
+     error_type: str
+     feature: str
+     pattern: str
+     count: int
+
+
+ @dataclass
+ class ErrorAnalysisResult:
+     total_errors: int
+     error_rate: float
+     fp_count: int
+     fn_count: int
+     false_positives: DataFrame
+     false_negatives: DataFrame
+     high_confidence_fp: DataFrame
+     high_confidence_fn: DataFrame
+     fp_confidence_dist: Dict[str, int] = field(default_factory=dict)
+     fn_confidence_dist: Dict[str, int] = field(default_factory=dict)
+     error_patterns: List[ErrorPattern] = field(default_factory=list)
+     hypotheses: List[str] = field(default_factory=list)
+
+
+ class ErrorAnalyzer:
+     HIGH_CONFIDENCE_FP_THRESHOLD = 0.8
+     HIGH_CONFIDENCE_FN_THRESHOLD = 0.2
+
+     def analyze_errors(self, model, X: DataFrame, y: Series, threshold: float = 0.5) -> ErrorAnalysisResult:
+         y_pred = model.predict(X)
+         y_proba = model.predict_proba(X)[:, 1] if hasattr(model, "predict_proba") else y_pred.astype(float)
+         fp_mask = (y_pred == 1) & (y == 0)
+         fn_mask = (y_pred == 0) & (y == 1)
+         false_positives = X[fp_mask].copy()
+         false_negatives = X[fn_mask].copy()
+         false_positives["probability"] = y_proba[fp_mask]
+         false_negatives["probability"] = y_proba[fn_mask]
+         high_conf_fp = false_positives[false_positives["probability"] > self.HIGH_CONFIDENCE_FP_THRESHOLD]
+         high_conf_fn = false_negatives[false_negatives["probability"] < self.HIGH_CONFIDENCE_FN_THRESHOLD]
+         fp_confidence_dist = self._compute_confidence_dist(false_positives["probability"].values if len(false_positives) > 0 else np.array([]))
+         fn_confidence_dist = self._compute_confidence_dist(false_negatives["probability"].values if len(false_negatives) > 0 else np.array([]))
+         error_patterns = self._find_patterns(X, y, y_pred, fp_mask, fn_mask)
+         hypotheses = self._generate_hypotheses(false_positives, false_negatives, high_conf_fp, high_conf_fn)
+         total_errors = fp_mask.sum() + fn_mask.sum()
+         error_rate = total_errors / len(y) if len(y) > 0 else 0.0
+         return ErrorAnalysisResult(
+             total_errors=total_errors,
+             error_rate=error_rate,
+             fp_count=fp_mask.sum(),
+             fn_count=fn_mask.sum(),
+             false_positives=false_positives,
+             false_negatives=false_negatives,
+             high_confidence_fp=high_conf_fp,
+             high_confidence_fn=high_conf_fn,
+             fp_confidence_dist=fp_confidence_dist,
+             fn_confidence_dist=fn_confidence_dist,
+             error_patterns=error_patterns,
+             hypotheses=hypotheses,
+         )
+
+     def _compute_confidence_dist(self, proba: np.ndarray) -> Dict[str, int]:
+         if len(proba) == 0:
+             return {"low": 0, "medium": 0, "high": 0}
+         return {
+             "low": int((proba < 0.4).sum()),
+             "medium": int(((proba >= 0.4) & (proba < 0.7)).sum()),
+             "high": int((proba >= 0.7).sum()),
+         }
+
+     def _find_patterns(self, X: DataFrame, y: Series, y_pred, fp_mask, fn_mask) -> List[ErrorPattern]:
+         patterns = []
+         for col in X.columns:
+             if X[col].dtype in [np.float64, np.int64, np.float32, np.int32]:
+                 fp_mean = X.loc[fp_mask, col].mean() if fp_mask.sum() > 0 else 0
+                 correct_mean = X.loc[~fp_mask & ~fn_mask, col].mean()
+                 if abs(fp_mean - correct_mean) > X[col].std() * 0.5:
+                     patterns.append(ErrorPattern(
+                         error_type="FP",
+                         feature=col,
+                         pattern=f"FPs have {'higher' if fp_mean > correct_mean else 'lower'} {col}",
+                         count=fp_mask.sum(),
+                     ))
+         return patterns
+
+     def _generate_hypotheses(self, fps, fns, high_fp, high_fn) -> List[str]:
+         hypotheses = []
+         if len(high_fp) > 0:
+             hypotheses.append(f"Model is overconfident on {len(high_fp)} false positives. Review these cases.")
+         if len(high_fn) > 0:
+             hypotheses.append(f"Model is overconfident on {len(high_fn)} false negatives. These are high-risk misses.")
+         if len(fps) > len(fns) * 2:
+             hypotheses.append("Model is biased toward positive predictions. Consider raising threshold.")
+         if len(fns) > len(fps) * 2:
+             hypotheses.append("Model is biased toward negative predictions. Consider lowering threshold.")
+         if not hypotheses:
+             hypotheses.append("Error distribution appears balanced. Focus on feature engineering to reduce errors.")
+         return hypotheses
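A minimal usage sketch for this probe, assuming a fitted scikit-learn-style classifier; `model`, `X_test`, and `y_test` below are illustrative names, not part of the package:

from customer_retention.analysis.diagnostics.error_analyzer import ErrorAnalyzer

# Hypothetical inputs: any classifier exposing predict() (and ideally
# predict_proba()), a pandas feature frame, and a binary label series.
analyzer = ErrorAnalyzer()
result = analyzer.analyze_errors(model, X_test, y_test)
print(f"{result.total_errors} errors ({result.error_rate:.1%} of rows)")
print(f"High-confidence false positives: {len(result.high_confidence_fp)}")
for hypothesis in result.hypotheses:
    print("-", hypothesis)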
customer_retention/analysis/diagnostics/leakage_detector.py
@@ -0,0 +1,394 @@
+ """Leakage detection probes for model validation."""
+
+ import re
+ from dataclasses import dataclass, field
+ from typing import List, Optional, Set, Tuple
+
+ import numpy as np
+ from sklearn.linear_model import LogisticRegression
+ from sklearn.metrics import roc_auc_score
+ from sklearn.model_selection import StratifiedKFold, cross_val_predict
+
+ from customer_retention.core.compat import DataFrame, Series, pd
+ from customer_retention.core.components.enums import Severity
+ from customer_retention.core.utils.leakage import TEMPORAL_METADATA_COLUMNS, calculate_class_overlap
+
+
+ @dataclass
+ class LeakageCheck:
+     check_id: str
+     feature: str
+     severity: Severity
+     recommendation: str
+     correlation: float = 0.0
+     overlap_pct: float = 100.0
+     auc: float = 0.5
+
+
+ @dataclass
+ class LeakageResult:
+     passed: bool
+     checks: List[LeakageCheck] = field(default_factory=list)
+     critical_issues: List[LeakageCheck] = field(default_factory=list)
+     recommendations: List[str] = field(default_factory=list)
+
+
+ class LeakageDetector:
+     TEMPORAL_PATTERNS = re.compile(r"(days|since|tenure|recency|last|ago|date|time)", re.IGNORECASE)
+     DOMAIN_TARGET_PATTERNS = re.compile(
+         r"(churn|reten|cancel|unsubscribe|attrit|lapse|defect|convert|active|inactive|"
+         r"leave|stay|renew|expir|terminat|close|deactivat)",
+         re.IGNORECASE,
+     )
+     CROSS_ENTITY_PATTERNS = re.compile(
+         r"(global|population|all_user|cross_entity|market_avg|cohort_avg|"
+         r"overall_mean|overall_std|benchmark|percentile_rank)",
+         re.IGNORECASE,
+     )
+     CORRELATION_CRITICAL, CORRELATION_HIGH, CORRELATION_MEDIUM = 0.90, 0.70, 0.50
+     SEPARATION_CRITICAL, SEPARATION_HIGH, SEPARATION_MEDIUM = 0.0, 1.0, 5.0
+     AUC_CRITICAL, AUC_HIGH = 0.90, 0.80
+     CV_FOLDS = 5
+     NUMERIC_DTYPES = (np.float64, np.int64, np.float32, np.int32)
+
+     def __init__(self, feature_timestamp_column: str = "feature_timestamp", label_timestamp_column: str = "label_timestamp"):
+         self.feature_timestamp_column = feature_timestamp_column
+         self.label_timestamp_column = label_timestamp_column
+         self._excluded_columns: Set[str] = set(TEMPORAL_METADATA_COLUMNS)
+
+     def _get_analyzable_columns(self, X: DataFrame) -> List[str]:
+         return [c for c in X.columns if c not in self._excluded_columns]
+
+     def _get_numeric_columns(self, X: DataFrame) -> List[str]:
+         return [c for c in self._get_analyzable_columns(X) if X[c].dtype in self.NUMERIC_DTYPES]
+
+     def _safe_correlation(self, X: DataFrame, col: str, y: Series) -> float:
+         corr = abs(X[col].corr(y))
+         return 0.0 if np.isnan(corr) else corr
+
+     def _build_result(self, checks: List[LeakageCheck]) -> LeakageResult:
+         critical = [c for c in checks if c.severity == Severity.CRITICAL]
+         return LeakageResult(passed=len(critical) == 0, checks=checks, critical_issues=critical)
+
+     def check_correlations(self, X: DataFrame, y: Series) -> LeakageResult:
+         checks = []
+         for col in self._get_numeric_columns(X):
+             corr = self._safe_correlation(X, col, y)
+             severity, check_id = self._classify_correlation(corr)
+             if severity != Severity.INFO:
+                 checks.append(LeakageCheck(
+                     check_id=check_id, feature=col, severity=severity,
+                     recommendation=self._correlation_recommendation(col, corr), correlation=corr,
+                 ))
+         return self._build_result(checks)
+
+     def _classify_correlation(self, corr: float) -> Tuple[Severity, str]:
+         if corr > self.CORRELATION_CRITICAL:
+             return Severity.CRITICAL, "LD001"
+         if corr > self.CORRELATION_HIGH:
+             return Severity.HIGH, "LD002"
+         if corr > self.CORRELATION_MEDIUM:
+             return Severity.MEDIUM, "LD003"
+         return Severity.INFO, "LD000"
+
+     def _correlation_recommendation(self, feature: str, corr: float) -> str:
+         if corr > self.CORRELATION_CRITICAL:
+             return f"REMOVE {feature}: correlation {corr:.2f} indicates likely data leakage"
+         if corr > self.CORRELATION_HIGH:
+             return f"INVESTIGATE {feature}: correlation {corr:.2f} is suspiciously high"
+         return f"MONITOR {feature}: elevated correlation {corr:.2f}"
+
+     def check_separation(self, X: DataFrame, y: Series) -> LeakageResult:
+         checks = []
+         for col in self._get_numeric_columns(X):
+             overlap_pct = calculate_class_overlap(X[col], y)
+             severity, check_id = self._classify_separation(overlap_pct)
+             checks.append(LeakageCheck(
+                 check_id=check_id, feature=col, severity=severity,
+                 recommendation=self._separation_recommendation(col, overlap_pct), overlap_pct=overlap_pct,
+             ))
+         return self._build_result(checks)
+
+     def _classify_separation(self, overlap_pct: float) -> Tuple[Severity, str]:
+         if overlap_pct <= self.SEPARATION_CRITICAL:
+             return Severity.CRITICAL, "LD010"
+         if overlap_pct < self.SEPARATION_HIGH:
+             return Severity.HIGH, "LD011"
+         if overlap_pct < self.SEPARATION_MEDIUM:
+             return Severity.MEDIUM, "LD012"
+         return Severity.INFO, "LD000"
+
+     def _separation_recommendation(self, feature: str, overlap_pct: float) -> str:
+         if overlap_pct <= self.SEPARATION_CRITICAL:
+             return f"REMOVE {feature}: perfect class separation indicates leakage"
+         if overlap_pct < self.SEPARATION_HIGH:
+             return f"REMOVE {feature}: near-perfect separation ({overlap_pct:.1f}% overlap)"
+         if overlap_pct < self.SEPARATION_MEDIUM:
+             return f"INVESTIGATE {feature}: high separation ({overlap_pct:.1f}% overlap)"
+         return f"OK: {feature} has normal class overlap"
+
+     def check_temporal_logic(self, X: DataFrame, y: Series) -> LeakageResult:
+         checks = []
+         for col in self._get_numeric_columns(X):
+             if not self.TEMPORAL_PATTERNS.search(col):
+                 continue
+             corr = self._safe_correlation(X, col, y)
+             if corr > self.CORRELATION_HIGH:
+                 checks.append(LeakageCheck(
+                     check_id="LD022", feature=col, severity=Severity.HIGH,
+                     recommendation=f"REVIEW temporal feature {col}: high correlation ({corr:.2f}) may indicate future data",
+                     correlation=corr,
+                 ))
+             elif corr > self.CORRELATION_MEDIUM:
+                 checks.append(LeakageCheck(
+                     check_id="LD022", feature=col, severity=Severity.MEDIUM,
+                     recommendation=f"CHECK temporal feature {col}: verify reference date logic", correlation=corr,
+                 ))
+         return self._build_result(checks)
+
+     def check_single_feature_auc(self, X: DataFrame, y: Series) -> LeakageResult:
+         checks = []
+         for col in self._get_numeric_columns(X):
+             auc = self._compute_single_feature_auc(X[col], y)
+             is_temporal = bool(self.TEMPORAL_PATTERNS.search(col))
+             severity, check_id = self._classify_auc(auc, is_temporal=is_temporal)
+             if severity != Severity.INFO:
+                 checks.append(LeakageCheck(
+                     check_id=check_id, feature=col, severity=severity,
+                     recommendation=self._auc_recommendation(col, auc, is_temporal=is_temporal), auc=auc,
+                 ))
+         return self._build_result(checks)
+
+     def _compute_single_feature_auc(self, feature: Series, y: Series) -> float:
+         try:
+             X_single = feature.values.reshape(-1, 1)
+             mask = ~np.isnan(X_single.flatten())
+             X_clean, y_clean = X_single[mask], y.values[mask]
+             if len(np.unique(y_clean)) < 2 or min(np.bincount(y_clean.astype(int))) < self.CV_FOLDS:
+                 return 0.5
+             model = LogisticRegression(max_iter=200, solver="lbfgs", random_state=42)
+             cv = StratifiedKFold(n_splits=self.CV_FOLDS, shuffle=True, random_state=42)
+             proba = cross_val_predict(model, X_clean, y_clean, cv=cv, method="predict_proba")
+             return roc_auc_score(y_clean, proba[:, 1])
+         except Exception:
+             return 0.5
+
+     def _classify_auc(self, auc: float, *, is_temporal: bool = False) -> Tuple[Severity, str]:
+         if is_temporal:
+             if auc > self.AUC_CRITICAL:
+                 return Severity.HIGH, "LD031"
+             return Severity.INFO, "LD000"
+         if auc > self.AUC_CRITICAL:
+             return Severity.CRITICAL, "LD030"
+         if auc > self.AUC_HIGH:
+             return Severity.HIGH, "LD031"
+         return Severity.INFO, "LD000"
+
+     def _auc_recommendation(self, feature: str, auc: float, *, is_temporal: bool = False) -> str:
+         if auc > self.AUC_CRITICAL:
+             if is_temporal:
+                 return f"REVIEW {feature}: temporal feature AUC {auc:.2f} is high but expected for recency/tenure features"
+             return f"REMOVE {feature}: single-feature AUC {auc:.2f} indicates leakage"
+         if auc > self.AUC_HIGH:
+             return f"INVESTIGATE {feature}: single-feature AUC {auc:.2f} is very high"
+         return f"OK: {feature} has normal predictive power"
+
+     def check_point_in_time(self, df: DataFrame) -> LeakageResult:
+         checks = []
+         feature_ts = self._parse_timestamp(df, self.feature_timestamp_column)
+         if feature_ts is None:
+             return self._build_result([])
+
+         self._check_label_timestamp_violation(df, feature_ts, checks)
+         self._check_datetime_column_violations(df, feature_ts, checks)
+         return self._build_result(checks)
+
+     def _parse_timestamp(self, df: DataFrame, col: str) -> Optional[Series]:
+         if col not in df.columns:
+             return None
+         try:
+             return pd.to_datetime(df[col], errors="coerce", format="mixed")
+         except Exception:
+             return None
+
+     def _check_label_timestamp_violation(self, df: DataFrame, feature_ts: Series, checks: List[LeakageCheck]) -> None:
+         label_ts = self._parse_timestamp(df, self.label_timestamp_column)
+         if label_ts is None:
+             return
+         violations = (feature_ts > label_ts).sum()
+         if violations > 0:
+             checks.append(LeakageCheck(
+                 check_id="LD040", feature=self.feature_timestamp_column, severity=Severity.CRITICAL,
+                 recommendation=f"FIX: {violations} rows have feature_timestamp > label_timestamp",
+             ))
+
+     def _check_datetime_column_violations(self, df: DataFrame, feature_ts: Series, checks: List[LeakageCheck]) -> None:
+         skip_cols = {self.feature_timestamp_column, self.label_timestamp_column}
+         for col in df.select_dtypes(include=["datetime64"]).columns:
+             if col in skip_cols:
+                 continue
+             try:
+                 col_ts = pd.to_datetime(df[col], errors="coerce", format="mixed")
+                 violations = (col_ts > feature_ts).sum()
+                 if violations > 0:
+                     pct = violations / len(df) * 100
+                     checks.append(LeakageCheck(
+                         check_id="LD041", feature=col,
+                         severity=Severity.CRITICAL if pct > 5 else Severity.HIGH,
+                         recommendation=f"INVESTIGATE {col}: {violations} rows ({pct:.1f}%) have values after feature_timestamp",
+                     ))
+             except Exception:
+                 continue
+
+     def check_uniform_timestamps(self, df: DataFrame, timestamp_column: str = "event_timestamp") -> LeakageResult:
+         checks = []
+         if timestamp_column not in df.columns:
+             return self._build_result([])
+
+         try:
+             timestamps = pd.to_datetime(df[timestamp_column], errors="coerce").dropna()
+             if len(timestamps) < 2:
+                 return self._build_result([])
+
+             if timestamps.nunique() == 1:
+                 checks.append(LeakageCheck(
+                     check_id="LD050", feature=timestamp_column, severity=Severity.HIGH,
+                     recommendation=(
+                         f"INVESTIGATE {timestamp_column}: All {len(timestamps)} timestamps are identical. "
+                         "This suggests datetime.now() was used instead of actual aggregation reference dates."
+                     ),
+                 ))
+             elif (timestamps.max() - timestamps.min()).total_seconds() < 60:
+                 time_span = (timestamps.max() - timestamps.min()).total_seconds()
+                 checks.append(LeakageCheck(
+                     check_id="LD050", feature=timestamp_column, severity=Severity.MEDIUM,
+                     recommendation=(
+                         f"REVIEW {timestamp_column}: Timestamps span only {time_span:.1f} seconds across "
+                         f"{len(timestamps)} records. Verify timestamps reflect actual observation dates."
+                     ),
+                 ))
+         except Exception:
+             pass
+         return self._build_result(checks)
+
+     def check_target_in_features(self, X: DataFrame, y: Series, target_name: str = "target") -> LeakageResult:
+         checks = []
+         self._check_target_column_direct(X, target_name, checks)
+         self._check_target_derived_names(X, target_name, checks)
+         self._check_perfect_correlation(X, y, target_name, checks)
+         return self._build_result(checks)
+
+     def _check_target_column_direct(self, X: DataFrame, target_name: str, checks: List[LeakageCheck]) -> None:
+         if target_name in X.columns:
+             checks.append(LeakageCheck(
+                 check_id="LD052", feature=target_name, severity=Severity.CRITICAL,
+                 recommendation=f"REMOVE {target_name}: Target column found in feature matrix. Direct data leakage.",
+                 correlation=1.0,
+             ))
+
+     def _check_target_derived_names(self, X: DataFrame, target_name: str, checks: List[LeakageCheck]) -> None:
+         patterns = [f"{target_name}_", f"_{target_name}"]
+         for col in X.columns:
+             col_lower = col.lower()
+             if any(p.lower() in col_lower for p in patterns):
+                 checks.append(LeakageCheck(
+                     check_id="LD052", feature=col, severity=Severity.CRITICAL,
+                     recommendation=f"REMOVE {col}: Column name suggests derivation from target '{target_name}'.",
+                 ))
+
+     def _check_perfect_correlation(self, X: DataFrame, y: Series, target_name: str, checks: List[LeakageCheck]) -> None:
+         already_flagged = {target_name} | {c.feature for c in checks}
+         for col in self._get_numeric_columns(X):
+             if col in already_flagged:
+                 continue
+             try:
+                 corr = abs(X[col].corr(y))
+                 if not np.isnan(corr) and corr > 0.99:
+                     checks.append(LeakageCheck(
+                         check_id="LD052", feature=col, severity=Severity.CRITICAL,
+                         recommendation=f"REMOVE {col}: Perfect correlation ({corr:.4f}) indicates leakage.",
+                         correlation=corr,
+                     ))
+             except Exception:
+                 pass
+
+     def check_cross_entity_leakage(self, X: DataFrame, y: Series, entity_column: str, timestamp_column: str) -> LeakageResult:
+         checks = []
+         for col in self._get_numeric_columns(X):
+             if not self.CROSS_ENTITY_PATTERNS.search(col):
+                 continue
+             corr = self._safe_correlation(X, col, y)
+             severity = Severity.HIGH if corr > self.CORRELATION_MEDIUM else Severity.MEDIUM
+             checks.append(LeakageCheck(
+                 check_id="LD060", feature=col, severity=severity,
+                 recommendation=(
+                     f"REVIEW {col}: Cross-entity aggregation pattern detected. Correlation: {corr:.2f}. "
+                     "Verify this feature doesn't use future data from other entities."
+                 ),
+                 correlation=corr,
+             ))
+         return self._build_result(checks)
+
+     def check_temporal_split(self, train_timestamps: Series, test_timestamps: Series, timestamp_column: str = "timestamp") -> LeakageResult:
+         checks = []
+         try:
+             train_ts = pd.to_datetime(train_timestamps, errors="coerce").dropna()
+             test_ts = pd.to_datetime(test_timestamps, errors="coerce").dropna()
+             if len(train_ts) == 0 or len(test_ts) == 0:
+                 return self._build_result([])
+
+             train_max, test_min = train_ts.max(), test_ts.min()
+             if train_max >= test_min:
+                 overlap_count = (train_ts >= test_min).sum()
+                 overlap_pct = overlap_count / len(train_ts) * 100
+                 checks.append(LeakageCheck(
+                     check_id="LD061", feature=timestamp_column, severity=Severity.CRITICAL,
+                     recommendation=(
+                         f"FIX temporal split: Train max ({train_max}) >= Test min ({test_min}). "
+                         f"{overlap_count} train rows ({overlap_pct:.1f}%) overlap with test period."
+                     ),
+                 ))
+         except Exception:
+             pass
+         return self._build_result(checks)
+
+     def check_domain_target_patterns(self, X: DataFrame, y: Series) -> LeakageResult:
+         checks = []
+         for col in self._get_numeric_columns(X):
+             if not self.DOMAIN_TARGET_PATTERNS.search(col):
+                 continue
+             corr = self._safe_correlation(X, col, y)
+             severity, recommendation = self._classify_domain_pattern(col, corr)
+             checks.append(LeakageCheck(
+                 check_id="LD053", feature=col, severity=severity,
+                 recommendation=recommendation, correlation=corr,
+             ))
+         return self._build_result(checks)
+
+     def _classify_domain_pattern(self, col: str, corr: float) -> Tuple[Severity, str]:
+         if corr > self.CORRELATION_HIGH:
+             return Severity.CRITICAL, f"REMOVE {col}: Domain pattern with high correlation ({corr:.2f}) confirms likely leakage."
+         if corr > self.CORRELATION_MEDIUM:
+             return Severity.HIGH, f"INVESTIGATE {col}: Domain pattern with correlation ({corr:.2f}) warrants review."
+         return Severity.MEDIUM, f"REVIEW {col}: Contains churn/retention terminology. Low correlation ({corr:.2f}) suggests safe."
+
+     def run_all_checks(self, X: DataFrame, y: Series, include_pit: bool = True) -> LeakageResult:
+         all_checks = (
+             self.check_correlations(X, y).checks
+             + self.check_separation(X, y).checks
+             + self.check_temporal_logic(X, y).checks
+             + self.check_single_feature_auc(X, y).checks
+         )
+
+         if include_pit:
+             df_with_y = X.copy()
+             df_with_y["_target"] = y
+             all_checks.extend(self.check_point_in_time(df_with_y).checks)
+             all_checks.extend(self.check_uniform_timestamps(df_with_y, timestamp_column=self.feature_timestamp_column).checks)
+
+         all_checks.extend(self.check_target_in_features(X, y).checks)
+         all_checks.extend(self.check_domain_target_patterns(X, y).checks)
+
+         critical = [c for c in all_checks if c.severity == Severity.CRITICAL]
+         recommendations = list({c.recommendation for c in all_checks if c.severity in [Severity.CRITICAL, Severity.HIGH]})
+         return LeakageResult(passed=len(critical) == 0, checks=all_checks, critical_issues=critical, recommendations=recommendations)
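A sketch of running the full leakage suite, under the same assumptions (`X` is a pandas feature frame and `y` a binary target; both names are illustrative). Timestamp column names default to "feature_timestamp" and "label_timestamp" per __init__ above:

from customer_retention.analysis.diagnostics.leakage_detector import LeakageDetector

detector = LeakageDetector()
# include_pit=True also runs the point-in-time and uniform-timestamp checks
# on a copy of X with the target appended as "_target".
result = detector.run_all_checks(X, y, include_pit=True)
if not result.passed:
    for check in result.critical_issues:
        print(check.check_id, check.feature, check.recommendation)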
customer_retention/analysis/diagnostics/noise_tester.py
@@ -0,0 +1,140 @@
+ """Noise robustness testing probes."""
+
+ from dataclasses import dataclass, field
+ from typing import Dict, List
+
+ import numpy as np
+ from sklearn.metrics import roc_auc_score
+
+ from customer_retention.core.compat import DataFrame, Series
+ from customer_retention.core.components.enums import Severity
+
+
+ @dataclass
+ class NoiseCheck:
+     check_id: str
+     metric: str
+     severity: Severity
+     recommendation: str
+     value: float = 0.0
+
+
+ @dataclass
+ class NoiseResult:
+     passed: bool
+     checks: List[NoiseCheck] = field(default_factory=list)
+     degradation_curve: List[Dict[str, float]] = field(default_factory=list)
+     robustness_score: float = 1.0
+     feature_importance: Dict[str, float] = field(default_factory=dict)
+
+
+ class NoiseTester:
+     NOISE_LEVELS = {"low": 0.01, "medium": 0.05, "high": 0.10, "extreme": 0.20}
+     DROPOUT_LEVELS = {"low": 0.05, "medium": 0.10, "high": 0.20, "extreme": 0.30}
+     DEGRADATION_LOW_THRESHOLD = 0.10
+     DEGRADATION_MEDIUM_THRESHOLD = 0.20
+
+     def test_gaussian_noise(self, model, X: DataFrame, y: Series) -> NoiseResult:
+         baseline_score = self._get_score(model, X, y)
+         degradation_curve = []
+         checks = []
+         for level_name, noise_factor in self.NOISE_LEVELS.items():
+             X_noisy = X.copy()
+             for col in X.columns:
+                 if X[col].dtype in [np.float64, np.int64, np.float32, np.int32]:
+                     std = X[col].std()
+                     X_noisy[col] = X[col] + np.random.randn(len(X)) * std * noise_factor
+             noisy_score = self._get_score(model, X_noisy, y)
+             degradation = (baseline_score - noisy_score) / baseline_score if baseline_score > 0 else 0
+             degradation_curve.append({
+                 "noise_level": level_name,
+                 "noise_factor": noise_factor,
+                 "score": noisy_score,
+                 "degradation": degradation,
+             })
+             if level_name == "low" and degradation > self.DEGRADATION_LOW_THRESHOLD:
+                 checks.append(NoiseCheck(
+                     check_id="NR001",
+                     metric="degradation_low_noise",
+                     severity=Severity.HIGH,
+                     recommendation=f"HIGH: Model fragile to low noise ({degradation:.1%} degradation). Consider regularization.",
+                     value=degradation,
+                 ))
+             if level_name == "medium" and degradation > self.DEGRADATION_MEDIUM_THRESHOLD:
+                 checks.append(NoiseCheck(
+                     check_id="NR002",
+                     metric="degradation_medium_noise",
+                     severity=Severity.MEDIUM,
+                     recommendation=f"MEDIUM: Model moderately fragile ({degradation:.1%} at medium noise).",
+                     value=degradation,
+                 ))
+         robustness_score = self._compute_robustness(degradation_curve)
+         critical = [c for c in checks if c.severity == Severity.CRITICAL]
+         return NoiseResult(passed=len(critical) == 0, checks=checks, degradation_curve=degradation_curve, robustness_score=robustness_score)
+
+     def test_feature_dropout(self, model, X: DataFrame, y: Series) -> NoiseResult:
+         baseline_score = self._get_score(model, X, y)
+         degradation_curve = []
+         feature_importance = {}
+         for col in X.columns:
+             X_dropped = X.copy()
+             X_dropped[col] = 0
+             dropped_score = self._get_score(model, X_dropped, y)
+             importance = (baseline_score - dropped_score) / baseline_score if baseline_score > 0 else 0
+             feature_importance[col] = importance
+         checks = []
+         max_importance = max(feature_importance.values()) if feature_importance else 0
+         if max_importance > 0.5:
+             dominant_feature = max(feature_importance, key=feature_importance.get)
+             checks.append(NoiseCheck(
+                 check_id="NR003",
+                 metric="single_feature_dependency",
+                 severity=Severity.HIGH,
+                 recommendation=f"HIGH: Feature '{dominant_feature}' causes {max_importance:.1%} degradation when dropped. Model too dependent.",
+                 value=max_importance,
+             ))
+         for level_name, dropout_rate in self.DROPOUT_LEVELS.items():
+             X_dropout = X.copy()
+             n_drop = int(len(X.columns) * dropout_rate)
+             cols_to_drop = np.random.choice(X.columns, min(n_drop, len(X.columns)), replace=False)
+             for col in cols_to_drop:
+                 X_dropout[col] = 0
+             dropout_score = self._get_score(model, X_dropout, y)
+             degradation = (baseline_score - dropout_score) / baseline_score if baseline_score > 0 else 0
+             degradation_curve.append({
+                 "dropout_level": level_name,
+                 "dropout_rate": dropout_rate,
+                 "score": dropout_score,
+                 "degradation": degradation,
+             })
+         robustness_score = self._compute_robustness(degradation_curve)
+         critical = [c for c in checks if c.severity == Severity.CRITICAL]
+         return NoiseResult(passed=len(critical) == 0, checks=checks, degradation_curve=degradation_curve, robustness_score=robustness_score, feature_importance=feature_importance)
+
+     def _get_score(self, model, X: DataFrame, y: Series) -> float:
+         try:
+             y_proba = model.predict_proba(X)[:, 1]
+             return roc_auc_score(y, y_proba)
+         except Exception:
+             return 0.5
+
+     def _compute_robustness(self, degradation_curve: List[Dict]) -> float:
+         if not degradation_curve:
+             return 1.0
+         degradations = [d.get("degradation", 0) for d in degradation_curve]
+         return max(0, 1 - np.mean(degradations))
+
+     def run_all(self, model, X: DataFrame, y: Series) -> NoiseResult:
+         gaussian_result = self.test_gaussian_noise(model, X, y)
+         dropout_result = self.test_feature_dropout(model, X, y)
+         all_checks = gaussian_result.checks + dropout_result.checks
+         all_degradation = gaussian_result.degradation_curve + dropout_result.degradation_curve
+         avg_robustness = (gaussian_result.robustness_score + dropout_result.robustness_score) / 2
+         critical = [c for c in all_checks if c.severity == Severity.CRITICAL]
+         return NoiseResult(
+             passed=len(critical) == 0,
+             checks=all_checks,
+             degradation_curve=all_degradation,
+             robustness_score=avg_robustness,
+             feature_importance=dropout_result.feature_importance,
+         )
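A usage sketch for the robustness probes, again with illustrative names (`model`, `X_test`, `y_test`). Note that _get_score silently falls back to 0.5 when the model lacks predict_proba(), so results are only meaningful for probabilistic classifiers:

from customer_retention.analysis.diagnostics.noise_tester import NoiseTester

tester = NoiseTester()
# Runs Gaussian noise injection and feature dropout, then averages the
# two robustness scores.
result = tester.run_all(model, X_test, y_test)
print(f"Robustness score: {result.robustness_score:.2f}")
for point in result.degradation_curve:
    print(point)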