churnkit 0.75.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +647 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +1165 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +961 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +1690 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +679 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +3305 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +1463 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +1430 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +854 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +1639 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +1890 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +1457 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +1624 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +780 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +979 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +572 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +1179 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +1418 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +151 -0
- churnkit-0.75.0a1.dist-info/METADATA +229 -0
- churnkit-0.75.0a1.dist-info/RECORD +302 -0
- churnkit-0.75.0a1.dist-info/WHEEL +4 -0
- churnkit-0.75.0a1.dist-info/entry_points.txt +2 -0
- churnkit-0.75.0a1.dist-info/licenses/LICENSE +202 -0
- customer_retention/__init__.py +37 -0
- customer_retention/analysis/__init__.py +0 -0
- customer_retention/analysis/auto_explorer/__init__.py +62 -0
- customer_retention/analysis/auto_explorer/exploration_manager.py +470 -0
- customer_retention/analysis/auto_explorer/explorer.py +258 -0
- customer_retention/analysis/auto_explorer/findings.py +291 -0
- customer_retention/analysis/auto_explorer/layered_recommendations.py +485 -0
- customer_retention/analysis/auto_explorer/recommendation_builder.py +148 -0
- customer_retention/analysis/auto_explorer/recommendations.py +418 -0
- customer_retention/analysis/business/__init__.py +26 -0
- customer_retention/analysis/business/ab_test_designer.py +144 -0
- customer_retention/analysis/business/fairness_analyzer.py +166 -0
- customer_retention/analysis/business/intervention_matcher.py +121 -0
- customer_retention/analysis/business/report_generator.py +222 -0
- customer_retention/analysis/business/risk_profile.py +199 -0
- customer_retention/analysis/business/roi_analyzer.py +139 -0
- customer_retention/analysis/diagnostics/__init__.py +20 -0
- customer_retention/analysis/diagnostics/calibration_analyzer.py +133 -0
- customer_retention/analysis/diagnostics/cv_analyzer.py +144 -0
- customer_retention/analysis/diagnostics/error_analyzer.py +107 -0
- customer_retention/analysis/diagnostics/leakage_detector.py +394 -0
- customer_retention/analysis/diagnostics/noise_tester.py +140 -0
- customer_retention/analysis/diagnostics/overfitting_analyzer.py +190 -0
- customer_retention/analysis/diagnostics/segment_analyzer.py +122 -0
- customer_retention/analysis/discovery/__init__.py +8 -0
- customer_retention/analysis/discovery/config_generator.py +49 -0
- customer_retention/analysis/discovery/discovery_flow.py +19 -0
- customer_retention/analysis/discovery/type_inferencer.py +147 -0
- customer_retention/analysis/interpretability/__init__.py +13 -0
- customer_retention/analysis/interpretability/cohort_analyzer.py +185 -0
- customer_retention/analysis/interpretability/counterfactual.py +175 -0
- customer_retention/analysis/interpretability/individual_explainer.py +141 -0
- customer_retention/analysis/interpretability/pdp_generator.py +103 -0
- customer_retention/analysis/interpretability/shap_explainer.py +106 -0
- customer_retention/analysis/jupyter_save_hook.py +28 -0
- customer_retention/analysis/notebook_html_exporter.py +136 -0
- customer_retention/analysis/notebook_progress.py +60 -0
- customer_retention/analysis/plotly_preprocessor.py +154 -0
- customer_retention/analysis/recommendations/__init__.py +54 -0
- customer_retention/analysis/recommendations/base.py +158 -0
- customer_retention/analysis/recommendations/cleaning/__init__.py +11 -0
- customer_retention/analysis/recommendations/cleaning/consistency.py +107 -0
- customer_retention/analysis/recommendations/cleaning/deduplicate.py +94 -0
- customer_retention/analysis/recommendations/cleaning/impute.py +67 -0
- customer_retention/analysis/recommendations/cleaning/outlier.py +71 -0
- customer_retention/analysis/recommendations/datetime/__init__.py +3 -0
- customer_retention/analysis/recommendations/datetime/extract.py +149 -0
- customer_retention/analysis/recommendations/encoding/__init__.py +3 -0
- customer_retention/analysis/recommendations/encoding/categorical.py +114 -0
- customer_retention/analysis/recommendations/pipeline.py +74 -0
- customer_retention/analysis/recommendations/registry.py +76 -0
- customer_retention/analysis/recommendations/selection/__init__.py +3 -0
- customer_retention/analysis/recommendations/selection/drop_column.py +56 -0
- customer_retention/analysis/recommendations/transform/__init__.py +4 -0
- customer_retention/analysis/recommendations/transform/power.py +94 -0
- customer_retention/analysis/recommendations/transform/scale.py +112 -0
- customer_retention/analysis/visualization/__init__.py +15 -0
- customer_retention/analysis/visualization/chart_builder.py +2619 -0
- customer_retention/analysis/visualization/console.py +122 -0
- customer_retention/analysis/visualization/display.py +171 -0
- customer_retention/analysis/visualization/number_formatter.py +36 -0
- customer_retention/artifacts/__init__.py +3 -0
- customer_retention/artifacts/fit_artifact_registry.py +146 -0
- customer_retention/cli.py +93 -0
- customer_retention/core/__init__.py +0 -0
- customer_retention/core/compat/__init__.py +193 -0
- customer_retention/core/compat/detection.py +99 -0
- customer_retention/core/compat/ops.py +48 -0
- customer_retention/core/compat/pandas_backend.py +57 -0
- customer_retention/core/compat/spark_backend.py +75 -0
- customer_retention/core/components/__init__.py +11 -0
- customer_retention/core/components/base.py +79 -0
- customer_retention/core/components/components/__init__.py +13 -0
- customer_retention/core/components/components/deployer.py +26 -0
- customer_retention/core/components/components/explainer.py +26 -0
- customer_retention/core/components/components/feature_eng.py +33 -0
- customer_retention/core/components/components/ingester.py +34 -0
- customer_retention/core/components/components/profiler.py +34 -0
- customer_retention/core/components/components/trainer.py +38 -0
- customer_retention/core/components/components/transformer.py +36 -0
- customer_retention/core/components/components/validator.py +37 -0
- customer_retention/core/components/enums.py +33 -0
- customer_retention/core/components/orchestrator.py +94 -0
- customer_retention/core/components/registry.py +59 -0
- customer_retention/core/config/__init__.py +39 -0
- customer_retention/core/config/column_config.py +95 -0
- customer_retention/core/config/experiments.py +71 -0
- customer_retention/core/config/pipeline_config.py +117 -0
- customer_retention/core/config/source_config.py +83 -0
- customer_retention/core/utils/__init__.py +28 -0
- customer_retention/core/utils/leakage.py +85 -0
- customer_retention/core/utils/severity.py +53 -0
- customer_retention/core/utils/statistics.py +90 -0
- customer_retention/generators/__init__.py +0 -0
- customer_retention/generators/notebook_generator/__init__.py +167 -0
- customer_retention/generators/notebook_generator/base.py +55 -0
- customer_retention/generators/notebook_generator/cell_builder.py +49 -0
- customer_retention/generators/notebook_generator/config.py +47 -0
- customer_retention/generators/notebook_generator/databricks_generator.py +48 -0
- customer_retention/generators/notebook_generator/local_generator.py +48 -0
- customer_retention/generators/notebook_generator/project_init.py +174 -0
- customer_retention/generators/notebook_generator/runner.py +150 -0
- customer_retention/generators/notebook_generator/script_generator.py +110 -0
- customer_retention/generators/notebook_generator/stages/__init__.py +19 -0
- customer_retention/generators/notebook_generator/stages/base_stage.py +86 -0
- customer_retention/generators/notebook_generator/stages/s01_ingestion.py +100 -0
- customer_retention/generators/notebook_generator/stages/s02_profiling.py +95 -0
- customer_retention/generators/notebook_generator/stages/s03_cleaning.py +180 -0
- customer_retention/generators/notebook_generator/stages/s04_transformation.py +165 -0
- customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py +115 -0
- customer_retention/generators/notebook_generator/stages/s06_feature_selection.py +97 -0
- customer_retention/generators/notebook_generator/stages/s07_model_training.py +176 -0
- customer_retention/generators/notebook_generator/stages/s08_deployment.py +81 -0
- customer_retention/generators/notebook_generator/stages/s09_monitoring.py +112 -0
- customer_retention/generators/notebook_generator/stages/s10_batch_inference.py +642 -0
- customer_retention/generators/notebook_generator/stages/s11_feature_store.py +348 -0
- customer_retention/generators/orchestration/__init__.py +23 -0
- customer_retention/generators/orchestration/code_generator.py +196 -0
- customer_retention/generators/orchestration/context.py +147 -0
- customer_retention/generators/orchestration/data_materializer.py +188 -0
- customer_retention/generators/orchestration/databricks_exporter.py +411 -0
- customer_retention/generators/orchestration/doc_generator.py +311 -0
- customer_retention/generators/pipeline_generator/__init__.py +26 -0
- customer_retention/generators/pipeline_generator/findings_parser.py +727 -0
- customer_retention/generators/pipeline_generator/generator.py +142 -0
- customer_retention/generators/pipeline_generator/models.py +166 -0
- customer_retention/generators/pipeline_generator/renderer.py +2125 -0
- customer_retention/generators/spec_generator/__init__.py +37 -0
- customer_retention/generators/spec_generator/databricks_generator.py +433 -0
- customer_retention/generators/spec_generator/generic_generator.py +373 -0
- customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +685 -0
- customer_retention/generators/spec_generator/pipeline_spec.py +298 -0
- customer_retention/integrations/__init__.py +0 -0
- customer_retention/integrations/adapters/__init__.py +13 -0
- customer_retention/integrations/adapters/base.py +10 -0
- customer_retention/integrations/adapters/factory.py +25 -0
- customer_retention/integrations/adapters/feature_store/__init__.py +6 -0
- customer_retention/integrations/adapters/feature_store/base.py +57 -0
- customer_retention/integrations/adapters/feature_store/databricks.py +94 -0
- customer_retention/integrations/adapters/feature_store/feast_adapter.py +97 -0
- customer_retention/integrations/adapters/feature_store/local.py +75 -0
- customer_retention/integrations/adapters/mlflow/__init__.py +6 -0
- customer_retention/integrations/adapters/mlflow/base.py +32 -0
- customer_retention/integrations/adapters/mlflow/databricks.py +54 -0
- customer_retention/integrations/adapters/mlflow/experiment_tracker.py +161 -0
- customer_retention/integrations/adapters/mlflow/local.py +50 -0
- customer_retention/integrations/adapters/storage/__init__.py +5 -0
- customer_retention/integrations/adapters/storage/base.py +33 -0
- customer_retention/integrations/adapters/storage/databricks.py +76 -0
- customer_retention/integrations/adapters/storage/local.py +59 -0
- customer_retention/integrations/feature_store/__init__.py +47 -0
- customer_retention/integrations/feature_store/definitions.py +215 -0
- customer_retention/integrations/feature_store/manager.py +744 -0
- customer_retention/integrations/feature_store/registry.py +412 -0
- customer_retention/integrations/iteration/__init__.py +28 -0
- customer_retention/integrations/iteration/context.py +212 -0
- customer_retention/integrations/iteration/feedback_collector.py +184 -0
- customer_retention/integrations/iteration/orchestrator.py +168 -0
- customer_retention/integrations/iteration/recommendation_tracker.py +341 -0
- customer_retention/integrations/iteration/signals.py +212 -0
- customer_retention/integrations/llm_context/__init__.py +4 -0
- customer_retention/integrations/llm_context/context_builder.py +201 -0
- customer_retention/integrations/llm_context/prompts.py +100 -0
- customer_retention/integrations/streaming/__init__.py +103 -0
- customer_retention/integrations/streaming/batch_integration.py +149 -0
- customer_retention/integrations/streaming/early_warning_model.py +227 -0
- customer_retention/integrations/streaming/event_schema.py +214 -0
- customer_retention/integrations/streaming/online_store_writer.py +249 -0
- customer_retention/integrations/streaming/realtime_scorer.py +261 -0
- customer_retention/integrations/streaming/trigger_engine.py +293 -0
- customer_retention/integrations/streaming/window_aggregator.py +393 -0
- customer_retention/stages/__init__.py +0 -0
- customer_retention/stages/cleaning/__init__.py +9 -0
- customer_retention/stages/cleaning/base.py +28 -0
- customer_retention/stages/cleaning/missing_handler.py +160 -0
- customer_retention/stages/cleaning/outlier_handler.py +204 -0
- customer_retention/stages/deployment/__init__.py +28 -0
- customer_retention/stages/deployment/batch_scorer.py +106 -0
- customer_retention/stages/deployment/champion_challenger.py +299 -0
- customer_retention/stages/deployment/model_registry.py +182 -0
- customer_retention/stages/deployment/retraining_trigger.py +245 -0
- customer_retention/stages/features/__init__.py +73 -0
- customer_retention/stages/features/behavioral_features.py +266 -0
- customer_retention/stages/features/customer_segmentation.py +505 -0
- customer_retention/stages/features/feature_definitions.py +265 -0
- customer_retention/stages/features/feature_engineer.py +551 -0
- customer_retention/stages/features/feature_manifest.py +340 -0
- customer_retention/stages/features/feature_selector.py +239 -0
- customer_retention/stages/features/interaction_features.py +160 -0
- customer_retention/stages/features/temporal_features.py +243 -0
- customer_retention/stages/ingestion/__init__.py +9 -0
- customer_retention/stages/ingestion/load_result.py +32 -0
- customer_retention/stages/ingestion/loaders.py +195 -0
- customer_retention/stages/ingestion/source_registry.py +130 -0
- customer_retention/stages/modeling/__init__.py +31 -0
- customer_retention/stages/modeling/baseline_trainer.py +139 -0
- customer_retention/stages/modeling/cross_validator.py +125 -0
- customer_retention/stages/modeling/data_splitter.py +205 -0
- customer_retention/stages/modeling/feature_scaler.py +99 -0
- customer_retention/stages/modeling/hyperparameter_tuner.py +107 -0
- customer_retention/stages/modeling/imbalance_handler.py +282 -0
- customer_retention/stages/modeling/mlflow_logger.py +95 -0
- customer_retention/stages/modeling/model_comparator.py +149 -0
- customer_retention/stages/modeling/model_evaluator.py +138 -0
- customer_retention/stages/modeling/threshold_optimizer.py +131 -0
- customer_retention/stages/monitoring/__init__.py +37 -0
- customer_retention/stages/monitoring/alert_manager.py +328 -0
- customer_retention/stages/monitoring/drift_detector.py +201 -0
- customer_retention/stages/monitoring/performance_monitor.py +242 -0
- customer_retention/stages/preprocessing/__init__.py +5 -0
- customer_retention/stages/preprocessing/transformer_manager.py +284 -0
- customer_retention/stages/profiling/__init__.py +256 -0
- customer_retention/stages/profiling/categorical_distribution.py +269 -0
- customer_retention/stages/profiling/categorical_target_analyzer.py +274 -0
- customer_retention/stages/profiling/column_profiler.py +527 -0
- customer_retention/stages/profiling/distribution_analysis.py +483 -0
- customer_retention/stages/profiling/drift_detector.py +310 -0
- customer_retention/stages/profiling/feature_capacity.py +507 -0
- customer_retention/stages/profiling/pattern_analysis_config.py +513 -0
- customer_retention/stages/profiling/profile_result.py +212 -0
- customer_retention/stages/profiling/quality_checks.py +1632 -0
- customer_retention/stages/profiling/relationship_detector.py +256 -0
- customer_retention/stages/profiling/relationship_recommender.py +454 -0
- customer_retention/stages/profiling/report_generator.py +520 -0
- customer_retention/stages/profiling/scd_analyzer.py +151 -0
- customer_retention/stages/profiling/segment_analyzer.py +632 -0
- customer_retention/stages/profiling/segment_aware_outlier.py +265 -0
- customer_retention/stages/profiling/target_level_analyzer.py +217 -0
- customer_retention/stages/profiling/temporal_analyzer.py +388 -0
- customer_retention/stages/profiling/temporal_coverage.py +488 -0
- customer_retention/stages/profiling/temporal_feature_analyzer.py +692 -0
- customer_retention/stages/profiling/temporal_feature_engineer.py +703 -0
- customer_retention/stages/profiling/temporal_pattern_analyzer.py +636 -0
- customer_retention/stages/profiling/temporal_quality_checks.py +278 -0
- customer_retention/stages/profiling/temporal_target_analyzer.py +241 -0
- customer_retention/stages/profiling/text_embedder.py +87 -0
- customer_retention/stages/profiling/text_processor.py +115 -0
- customer_retention/stages/profiling/text_reducer.py +60 -0
- customer_retention/stages/profiling/time_series_profiler.py +303 -0
- customer_retention/stages/profiling/time_window_aggregator.py +376 -0
- customer_retention/stages/profiling/type_detector.py +382 -0
- customer_retention/stages/profiling/window_recommendation.py +288 -0
- customer_retention/stages/temporal/__init__.py +166 -0
- customer_retention/stages/temporal/access_guard.py +180 -0
- customer_retention/stages/temporal/cutoff_analyzer.py +235 -0
- customer_retention/stages/temporal/data_preparer.py +178 -0
- customer_retention/stages/temporal/point_in_time_join.py +134 -0
- customer_retention/stages/temporal/point_in_time_registry.py +148 -0
- customer_retention/stages/temporal/scenario_detector.py +163 -0
- customer_retention/stages/temporal/snapshot_manager.py +259 -0
- customer_retention/stages/temporal/synthetic_coordinator.py +66 -0
- customer_retention/stages/temporal/timestamp_discovery.py +531 -0
- customer_retention/stages/temporal/timestamp_manager.py +255 -0
- customer_retention/stages/transformation/__init__.py +13 -0
- customer_retention/stages/transformation/binary_handler.py +85 -0
- customer_retention/stages/transformation/categorical_encoder.py +245 -0
- customer_retention/stages/transformation/datetime_transformer.py +97 -0
- customer_retention/stages/transformation/numeric_transformer.py +181 -0
- customer_retention/stages/transformation/pipeline.py +257 -0
- customer_retention/stages/validation/__init__.py +60 -0
- customer_retention/stages/validation/adversarial_scoring_validator.py +205 -0
- customer_retention/stages/validation/business_sense_gate.py +173 -0
- customer_retention/stages/validation/data_quality_gate.py +235 -0
- customer_retention/stages/validation/data_validators.py +511 -0
- customer_retention/stages/validation/feature_quality_gate.py +183 -0
- customer_retention/stages/validation/gates.py +117 -0
- customer_retention/stages/validation/leakage_gate.py +352 -0
- customer_retention/stages/validation/model_validity_gate.py +213 -0
- customer_retention/stages/validation/pipeline_validation_runner.py +264 -0
- customer_retention/stages/validation/quality_scorer.py +544 -0
- customer_retention/stages/validation/rule_generator.py +57 -0
- customer_retention/stages/validation/scoring_pipeline_validator.py +446 -0
- customer_retention/stages/validation/timeseries_detector.py +769 -0
- customer_retention/transforms/__init__.py +47 -0
- customer_retention/transforms/artifact_store.py +50 -0
- customer_retention/transforms/executor.py +157 -0
- customer_retention/transforms/fitted.py +92 -0
- customer_retention/transforms/ops.py +148 -0
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""Overfitting analysis probes for model validation."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from sklearn.model_selection import learning_curve
|
|
8
|
+
|
|
9
|
+
from customer_retention.core.compat import DataFrame, Series
|
|
10
|
+
from customer_retention.core.components.enums import Severity
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class OverfittingCheck:
    """One overfitting finding produced by :class:`OverfittingAnalyzer`.

    The ``check_id`` values ("OF001".."OF013") identify which probe fired;
    severity drives whether the enclosing result is considered failed.
    """
    check_id: str  # probe identifier, e.g. "OF001" (train/test gap) or "OF012" (max_depth)
    metric: str  # name of the metric or model parameter examined
    severity: Severity  # how serious the finding is (CRITICAL/HIGH/MEDIUM/INFO)
    recommendation: str  # human-readable remediation advice
    train_value: float = 0.0  # training-set metric value, or the raw parameter/ratio value
    test_value: float = 0.0  # test-set metric value, when applicable
    gap: float = 0.0  # train_value - test_value, when applicable
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class OverfittingResult:
    """Aggregate outcome of one or more overfitting probes."""
    passed: bool  # False only when at least one CRITICAL check fired
    checks: List[OverfittingCheck] = field(default_factory=list)  # individual findings
    recommendations: List[str] = field(default_factory=list)  # advice collected from CRITICAL/HIGH checks
    learning_curve: List[Dict[str, float]] = field(default_factory=list)  # rows of train_size/train_score/val_score
    diagnosis: Optional[str] = None  # textual reading of the learning curve, if one was computed
    sample_to_feature_ratio: float = 0.0  # n_samples / n_features, set by analyze_complexity
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class OverfittingAnalyzer:
    """Probes a fitted model and its data for overfitting risk.

    Four independent probes are offered (train/test metric gap, learning
    curve, sample-to-feature ratio, model hyper-parameter complexity) plus
    :meth:`run_all`, which combines the first three.  Each probe returns an
    :class:`OverfittingResult`; ``passed`` is False only when a CRITICAL
    check fired.
    """

    # Train/test metric-gap thresholds (absolute difference in metric value).
    GAP_CRITICAL = 0.15
    GAP_HIGH = 0.10
    GAP_MEDIUM = 0.05
    # Sample-to-feature ratio thresholds (samples per feature).
    RATIO_CRITICAL = 10
    RATIO_HIGH = 50
    # Hyper-parameter limits above which complexity checks fire.
    DEPTH_HIGH = 15
    ESTIMATORS_HIGH = 500

    def analyze_train_test_gap(self, train_metrics: Dict[str, float], test_metrics: Dict[str, float]) -> OverfittingResult:
        """Compare train vs test values for every metric present in both dicts.

        A positive gap (train better than test) is graded against the GAP_*
        thresholds.  Metrics only present in one dict are silently skipped.
        Returns a result that fails only on a CRITICAL gap; recommendations
        are collected for CRITICAL and HIGH gaps.
        """
        checks = []
        for metric in train_metrics:
            if metric in test_metrics:
                train_val = train_metrics[metric]
                test_val = test_metrics[metric]
                # Positive gap = training outperforms test, i.e. overfitting signal.
                gap = train_val - test_val
                severity, check_id = self._classify_gap(gap)
                checks.append(OverfittingCheck(
                    check_id=check_id,
                    metric=metric,
                    severity=severity,
                    recommendation=self._gap_recommendation(metric, gap),
                    train_value=train_val,
                    test_value=test_val,
                    gap=gap,
                ))
        critical = [c for c in checks if c.severity == Severity.CRITICAL]
        recommendations = [c.recommendation for c in checks if c.severity in [Severity.CRITICAL, Severity.HIGH]]
        return OverfittingResult(passed=len(critical) == 0, checks=checks, recommendations=recommendations)

    def _classify_gap(self, gap: float) -> tuple:
        """Map a train-test gap to a (Severity, check_id) pair; OF004 means OK."""
        if gap > self.GAP_CRITICAL:
            return Severity.CRITICAL, "OF001"
        if gap > self.GAP_HIGH:
            return Severity.HIGH, "OF002"
        if gap > self.GAP_MEDIUM:
            return Severity.MEDIUM, "OF003"
        return Severity.INFO, "OF004"

    def _gap_recommendation(self, metric: str, gap: float) -> str:
        """Build the advice string for a gap; thresholds mirror _classify_gap."""
        if gap > self.GAP_CRITICAL:
            return f"CRITICAL: {metric} gap {gap:.1%} indicates severe overfitting. Reduce model complexity, add regularization."
        if gap > self.GAP_HIGH:
            return f"HIGH: {metric} gap {gap:.1%} indicates moderate overfitting. Consider feature selection or regularization."
        if gap > self.GAP_MEDIUM:
            return f"MEDIUM: {metric} gap {gap:.1%} shows mild overfitting. Monitor closely."
        return f"OK: {metric} gap {gap:.1%} shows good generalization."

    def analyze_learning_curve(self, model, X: DataFrame, y: Series, cv: int = 5) -> OverfittingResult:
        """Fit a 5-point ROC-AUC learning curve and attach a textual diagnosis.

        Best-effort: any failure (e.g. model incompatible with cloning, too
        few samples for `cv` folds) yields an empty curve and a fallback
        diagnosis rather than raising.  ``passed`` is always True — this probe
        informs, it does not gate.
        """
        try:
            # Evaluate at 20%..100% of the training data in 5 equal steps.
            train_sizes = np.linspace(0.2, 1.0, 5)
            train_sizes_abs, train_scores, val_scores = learning_curve(
                # NOTE(review): random_state only matters to sklearn's
                # learning_curve when shuffle=True, which is not set here —
                # confirm whether shuffling was intended.
                model, X, y, train_sizes=train_sizes, cv=cv, scoring="roc_auc", random_state=42
            )
            curve_data = []
            for i, size in enumerate(train_sizes_abs):
                # Average across the cv folds for each training-set size.
                curve_data.append({
                    "train_size": int(size),
                    "train_score": float(np.mean(train_scores[i])),
                    "val_score": float(np.mean(val_scores[i])),
                })
            diagnosis = self._diagnose_learning_curve(curve_data)
            return OverfittingResult(passed=True, learning_curve=curve_data, diagnosis=diagnosis)
        except Exception:
            # Deliberate best-effort swallow: the caller still gets a usable result.
            return OverfittingResult(passed=True, learning_curve=[], diagnosis="Unable to generate learning curve")

    def _diagnose_learning_curve(self, curve_data: List[Dict[str, float]]) -> str:
        """Interpret the final/first points of the curve into one of five verdicts.

        Checks are ordered: good fit, overfitting, underfitting, "more data
        may help", plateau — the first matching rule wins.
        """
        if not curve_data:
            return "Insufficient data for diagnosis"
        last = curve_data[-1]
        first = curve_data[0]
        train_score = last["train_score"]
        val_score = last["val_score"]
        gap = train_score - val_score
        # How much validation improved from the smallest to the largest training size.
        val_improvement = last["val_score"] - first["val_score"]
        if gap < 0.05 and val_score > 0.7:
            return "Good fit: Both curves converged at high performance"
        if gap > 0.15:
            return "Overfitting: High train score but low validation. Reduce complexity."
        if val_score < 0.6 and train_score < 0.7:
            return "Underfitting: Both scores low. Increase model complexity or add features."
        if val_improvement > 0.05:
            return "More data may help: Validation still improving with more samples."
        return "Validation plateau: More data unlikely to help significantly."

    def analyze_complexity(self, X: DataFrame, y: Series) -> OverfittingResult:
        """Grade the sample-to-feature ratio of X against RATIO_* thresholds.

        NOTE(review): ``y`` is accepted but never read; presumably kept for
        signature symmetry with the other probes — confirm before removing.
        """
        n_samples, n_features = X.shape
        # max(..., 1) guards against a zero-column frame.
        ratio = n_samples / max(n_features, 1)
        checks = []
        severity, check_id = self._classify_ratio(ratio)
        if severity != Severity.INFO:
            checks.append(OverfittingCheck(
                check_id=check_id,
                metric="sample_to_feature_ratio",
                severity=severity,
                recommendation=self._ratio_recommendation(ratio, n_samples, n_features),
                train_value=ratio,
            ))
        critical = [c for c in checks if c.severity == Severity.CRITICAL]
        recommendations = [c.recommendation for c in checks if c.severity in [Severity.CRITICAL, Severity.HIGH]]
        return OverfittingResult(passed=len(critical) == 0, checks=checks, recommendations=recommendations, sample_to_feature_ratio=ratio)

    def _classify_ratio(self, ratio: float) -> tuple:
        """Map a samples-per-feature ratio to (Severity, check_id); OF000 means OK."""
        if ratio < self.RATIO_CRITICAL:
            return Severity.CRITICAL, "OF010"
        if ratio < self.RATIO_HIGH:
            return Severity.HIGH, "OF011"
        return Severity.INFO, "OF000"

    def _ratio_recommendation(self, ratio: float, n_samples: int, n_features: int) -> str:
        """Build the advice string for a ratio; thresholds mirror _classify_ratio."""
        if ratio < self.RATIO_CRITICAL:
            # Aim for at least 10 samples per feature.
            suggested_features = n_samples // 10
            return f"CRITICAL: Ratio {ratio:.1f}:1 is too low. Reduce to {suggested_features} features or get more data."
        if ratio < self.RATIO_HIGH:
            return f"HIGH: Ratio {ratio:.1f}:1 is concerning. Use L1 regularization and monitor closely."
        return f"OK: Ratio {ratio:.1f}:1 is adequate."

    def analyze_model_complexity(self, model_params: Dict[str, Any]) -> OverfittingResult:
        """Flag risky hyper-parameters in a params dict (tree depth, ensemble size).

        Only ``max_depth`` (truthy and > DEPTH_HIGH) and ``n_estimators``
        (> ESTIMATORS_HIGH with no "regularization" key present) are inspected;
        neither check is CRITICAL, so this probe currently always passes.
        """
        checks = []
        # Truthiness check also skips max_depth=None (unbounded trees are not flagged here).
        if "max_depth" in model_params and model_params["max_depth"]:
            depth = model_params["max_depth"]
            if depth > self.DEPTH_HIGH:
                checks.append(OverfittingCheck(
                    check_id="OF012",
                    metric="max_depth",
                    severity=Severity.HIGH,
                    recommendation=f"HIGH: max_depth={depth} may cause overfitting. Consider depth <= 10.",
                    train_value=depth,
                ))
        if "n_estimators" in model_params:
            n_est = model_params["n_estimators"]
            if n_est > self.ESTIMATORS_HIGH and "regularization" not in model_params:
                checks.append(OverfittingCheck(
                    check_id="OF013",
                    metric="n_estimators",
                    severity=Severity.MEDIUM,
                    recommendation=f"MEDIUM: n_estimators={n_est} without regularization may cause overfitting.",
                    train_value=n_est,
                ))
        critical = [c for c in checks if c.severity == Severity.CRITICAL]
        return OverfittingResult(passed=len(critical) == 0, checks=checks)

    def run_all(self, model, X: DataFrame, y: Series, train_metrics: Dict[str, float], test_metrics: Dict[str, float]) -> OverfittingResult:
        """Run the gap, data-complexity and learning-curve probes and merge results.

        Checks/recommendations come from the gap and complexity probes;
        the learning-curve probe contributes only its curve and diagnosis
        (it never fails).  Recommendations are de-duplicated via set(), so
        their order is not stable.
        """
        gap_result = self.analyze_train_test_gap(train_metrics, test_metrics)
        complexity_result = self.analyze_complexity(X, y)
        learning_result = self.analyze_learning_curve(model, X, y)
        all_checks = gap_result.checks + complexity_result.checks
        all_recommendations = gap_result.recommendations + complexity_result.recommendations
        critical = [c for c in all_checks if c.severity == Severity.CRITICAL]
        return OverfittingResult(
            passed=len(critical) == 0,
            checks=all_checks,
            recommendations=list(set(all_recommendations)),
            learning_curve=learning_result.learning_curve,
            diagnosis=learning_result.diagnosis,
            sample_to_feature_ratio=complexity_result.sample_to_feature_ratio,
        )
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
"""Segment performance analysis probes."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Dict, List
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from sklearn.metrics import average_precision_score, precision_score, recall_score, roc_auc_score
|
|
8
|
+
|
|
9
|
+
from customer_retention.core.compat import DataFrame, Series, pd
|
|
10
|
+
from customer_retention.core.components.enums import Severity
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class SegmentCheck:
    """One per-segment finding produced by :class:`SegmentPerformanceAnalyzer`."""
    check_id: str  # probe identifier, e.g. "SG001" (underperformance) or "SG003" (small segment)
    segment: str  # segment label the finding applies to
    severity: Severity  # how serious the finding is
    recommendation: str  # human-readable remediation advice
    metric: str = ""  # metric the finding is based on, e.g. "pr_auc" or "size"
    value: float = 0.0  # the observed metric value (e.g. gap or segment share)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class SegmentResult:
    """Aggregate outcome of a segment-performance analysis."""
    passed: bool  # overall verdict of the analysis
    checks: List[SegmentCheck] = field(default_factory=list)  # individual per-segment findings
    segment_metrics: Dict[str, Dict[str, float]] = field(default_factory=dict)  # metrics keyed by segment label
    recommendations: List[str] = field(default_factory=list)  # collected remediation advice
    recommendation: str = ""  # single summary recommendation, when one applies
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SegmentPerformanceAnalyzer:
    """Evaluate model performance per customer segment and flag gaps.

    Each segment's metrics are compared against the global metrics; findings
    are emitted as SegmentCheck entries (SG001 underperformance, SG002 low
    recall, SG003 small segment).
    """

    # Maximum tolerated PR-AUC gap between global and segment performance.
    UNDERPERFORMANCE_THRESHOLD = 0.20
    # Recall below this is flagged as a blind spot for the segment.
    LOW_RECALL_THRESHOLD = 0.20
    # Segments smaller than this fraction of all rows yield unreliable metrics.
    SMALL_SEGMENT_THRESHOLD = 0.05

    def define_segments(self, X: DataFrame, segment_column: str, segment_type: str = "quantile") -> Series:
        """Label each row with a segment derived from `segment_column`.

        Returns a single "all" segment when the column is absent or
        `segment_type` is unrecognized.
        """
        if segment_column not in X.columns:
            return pd.Series(["all"] * len(X))
        values = X[segment_column]
        if segment_type == "tenure":
            # Tenure buckets (presumably in days — confirm with callers):
            # 0-90 new, 90-365 established, 365+ mature.
            return pd.cut(values, bins=[0, 90, 365, np.inf], labels=["new", "established", "mature"])
        if segment_type == "quantile":
            # NOTE(review): with duplicates="drop", pandas raises ValueError if
            # fewer than 3 bins survive while 3 labels are supplied — confirm
            # inputs are well-spread or handle that upstream.
            return pd.qcut(values, q=3, labels=["low", "medium", "high"], duplicates="drop")
        # Consistency fix: use pd.Series here as in the missing-column branch
        # (the original called the bare `Series` compat alias).
        return pd.Series(["all"] * len(X))

    def analyze_performance(self, model, X: DataFrame, y: Series, segments: Series) -> SegmentResult:
        """Compute per-segment metrics and return checks plus recommendations.

        Segments with fewer than 10 rows are skipped entirely.
        """
        checks = []
        segment_metrics = {}
        global_metrics = self._compute_metrics(model, X, y)
        unique_segments = segments.unique()
        for seg in unique_segments:
            mask = segments == seg
            if mask.sum() < 10:
                continue
            X_seg = X[mask]
            y_seg = y[mask]
            seg_size_pct = mask.sum() / len(y)
            metrics = self._compute_metrics(model, X_seg, y_seg)
            segment_metrics[str(seg)] = metrics
            if seg_size_pct < self.SMALL_SEGMENT_THRESHOLD:
                checks.append(SegmentCheck(
                    check_id="SG003",
                    segment=str(seg),
                    severity=Severity.MEDIUM,
                    recommendation=f"MEDIUM: Segment '{seg}' is small ({seg_size_pct:.1%}). Results may be unreliable.",
                    metric="size",
                    value=seg_size_pct,
                ))
            if "pr_auc" in metrics and "pr_auc" in global_metrics:
                gap = global_metrics["pr_auc"] - metrics["pr_auc"]
                if gap > self.UNDERPERFORMANCE_THRESHOLD:
                    checks.append(SegmentCheck(
                        check_id="SG001",
                        segment=str(seg),
                        severity=Severity.HIGH,
                        recommendation=f"HIGH: Segment '{seg}' underperforms by {gap:.1%}. Consider segment-specific model.",
                        metric="pr_auc",
                        value=metrics["pr_auc"],
                    ))
            if "recall" in metrics and metrics["recall"] < self.LOW_RECALL_THRESHOLD:
                checks.append(SegmentCheck(
                    check_id="SG002",
                    segment=str(seg),
                    severity=Severity.HIGH,
                    recommendation=f"HIGH: Segment '{seg}' has low recall ({metrics['recall']:.1%}). Adjust threshold or add features.",
                    metric="recall",
                    value=metrics["recall"],
                ))
        critical = [c for c in checks if c.severity == Severity.CRITICAL]
        recommendations = [c.recommendation for c in checks if c.severity in [Severity.CRITICAL, Severity.HIGH]]
        recommendation = self._global_recommendation(checks, unique_segments)
        return SegmentResult(
            passed=len(critical) == 0,
            checks=checks,
            segment_metrics=segment_metrics,
            recommendations=recommendations,
            recommendation=recommendation,
        )

    def _compute_metrics(self, model, X: DataFrame, y: Series) -> Dict[str, float]:
        """Score `model` on (X, y); return {} if prediction or scoring fails."""
        try:
            y_pred = model.predict(X)
            # Fall back to hard predictions when the model lacks predict_proba.
            y_proba = model.predict_proba(X)[:, 1] if hasattr(model, "predict_proba") else y_pred
            return {
                "precision": precision_score(y, y_pred, zero_division=0),
                "recall": recall_score(y, y_pred, zero_division=0),
                # AUC metrics are undefined on a single-class segment; report 0.5.
                "roc_auc": roc_auc_score(y, y_proba) if len(np.unique(y)) > 1 else 0.5,
                "pr_auc": average_precision_score(y, y_proba) if len(np.unique(y)) > 1 else 0.5,
                "churn_rate": y.mean(),
                "sample_size": len(y),
            }
        except Exception:
            # Best-effort scoring: an unscorable segment yields no metrics
            # (callers guard with `"pr_auc" in metrics` etc.).
            return {}

    def _global_recommendation(self, checks: List[SegmentCheck], segments) -> str:
        """Summarize HIGH-severity findings into one modeling suggestion."""
        high_issues = [c for c in checks if c.severity == Severity.HIGH]
        if not high_issues:
            return "No significant segment gaps. Continue with global model."
        if len(high_issues) == 1:
            return "One segment underperforms. Consider adding segment as feature."
        return "Multiple segments underperform. Consider segment-specific models."
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from .config_generator import ConfigGenerator
|
|
2
|
+
from .discovery_flow import discover_and_configure
|
|
3
|
+
from .type_inferencer import ColumnInference, InferenceConfidence, InferenceResult, TypeInferencer
|
|
4
|
+
|
|
5
|
+
# Public API of the discovery subpackage.
__all__ = [
    "TypeInferencer", "InferenceResult", "ColumnInference", "InferenceConfidence",
    "ConfigGenerator", "discover_and_configure"
]
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from customer_retention.core.config.column_config import ColumnConfig
|
|
5
|
+
from customer_retention.core.config.pipeline_config import (
|
|
6
|
+
BronzeConfig,
|
|
7
|
+
GoldConfig,
|
|
8
|
+
ModelingConfig,
|
|
9
|
+
PipelineConfig,
|
|
10
|
+
SilverConfig,
|
|
11
|
+
)
|
|
12
|
+
from customer_retention.core.config.source_config import DataSourceConfig, FileFormat, SourceType
|
|
13
|
+
|
|
14
|
+
from .type_inferencer import InferenceResult
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ConfigGenerator:
    """Build and persist a PipelineConfig from a type-inference result."""

    def from_inference(self, result: InferenceResult, project_name: str = "customer_retention",
                       source_path: Optional[str] = None) -> PipelineConfig:
        """Translate an InferenceResult into a complete PipelineConfig.

        The first detected identifier column becomes the primary key
        (falling back to "id"); the detected target column becomes the
        modeling target (falling back to "target").
        """
        # One ColumnConfig per inferred column (idiom: comprehension
        # instead of the original append loop).
        column_configs = [
            ColumnConfig(name=col, column_type=inf.inferred_type)
            for col, inf in result.inferences.items()
        ]
        primary_key = result.identifier_columns[0] if result.identifier_columns else "id"
        data_source = DataSourceConfig(
            name="main_source",
            source_type=SourceType.BATCH_FILE,
            primary_key=primary_key,
            path=source_path or "./data.csv",
            file_format=FileFormat.CSV,
            columns=column_configs
        )
        target_col = result.target_column or "target"
        modeling = ModelingConfig(target_column=target_col)
        bronze = BronzeConfig(dedup_keys=[primary_key])
        silver = SilverConfig(entity_key=primary_key)
        return PipelineConfig(
            project_name=project_name,
            data_sources=[data_source],
            bronze=bronze,
            silver=silver,
            gold=GoldConfig(),
            modeling=modeling
        )

    def save(self, config: PipelineConfig, path: str) -> None:
        """Serialize `config` as pretty-printed JSON at `path`."""
        # model_dump is pydantic v2; dict() is the v1 fallback.
        data = config.model_dump() if hasattr(config, "model_dump") else config.dict()
        with open(path, "w") as f:
            # default=str stringifies non-JSON-native values (enums, paths, dates).
            json.dump(data, f, indent=2, default=str)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from typing import Optional, Union
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
|
|
5
|
+
from customer_retention.core.config.pipeline_config import PipelineConfig
|
|
6
|
+
|
|
7
|
+
from .config_generator import ConfigGenerator
|
|
8
|
+
from .type_inferencer import TypeInferencer
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def discover_and_configure(source: Union[str, pd.DataFrame], project_name: str = "customer_retention",
                           target_hint: Optional[str] = None) -> PipelineConfig:
    """Infer column types from `source` and build a PipelineConfig in one step.

    `source` may be a CSV path or an in-memory DataFrame. When given,
    `target_hint` overrides whatever target column inference detected.
    """
    inference = TypeInferencer().infer(source)
    if target_hint:
        inference.target_column = target_hint
    return ConfigGenerator().from_inference(inference, project_name=project_name)
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Dict, List, Optional, Union
|
|
4
|
+
|
|
5
|
+
import pandas as pd
|
|
6
|
+
|
|
7
|
+
from customer_retention.core.compat import ops
|
|
8
|
+
from customer_retention.core.config.column_config import ColumnType
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class InferenceConfidence(str, Enum):
    """How confident the inferencer is in a column-type assignment."""

    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class ColumnInference:
    """The inferred type for one column, with supporting evidence."""

    column_name: str
    inferred_type: ColumnType
    confidence: InferenceConfidence
    evidence: List[str]  # human-readable reasons for the chosen type
    # Other plausible types (not populated by the visible inference paths).
    alternatives: List[ColumnType] = field(default_factory=list)
    suggested_encoding: Optional[str] = None  # e.g. "ordinal", "onehot", "target"
    suggested_scaling: Optional[str] = None  # e.g. "standard"
    suggested_missing_strategy: Optional[str] = None  # e.g. "median", "mode"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class InferenceResult:
    """Full inference output for a dataset: per-column types plus roles."""

    inferences: Dict[str, ColumnInference]  # column name -> inference
    target_column: Optional[str] = None  # first column inferred as TARGET, if any
    identifier_columns: List[str] = field(default_factory=list)
    datetime_columns: List[str] = field(default_factory=list)
    # NOTE(review): never populated by TypeInferencer.infer as visible here.
    warnings: List[str] = field(default_factory=list)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class TypeInferencer:
    """Heuristically assign a ColumnType to every column of a dataset."""

    # Name fragments suggesting the modeling target. Single-character
    # patterns ("y") are matched only as the whole column name — see
    # _name_matches.
    TARGET_PATTERNS = ["target", "label", "churn", "retained", "outcome", "class", "y"]
    # Name fragments suggesting an identifier column.
    ID_PATTERNS = ["id", "key", "code", "identifier", "index"]

    def __init__(self):
        # NOTE(review): retained for backward compatibility; nothing in this
        # class appends to it (per-column evidence is built locally).
        self.evidence: List[str] = []

    def infer(self, source: Union[str, pd.DataFrame]) -> InferenceResult:
        """Infer column types for a CSV path or an in-memory DataFrame."""
        df = ops.read_csv(source) if isinstance(source, str) else source
        inferences = {}
        target_column = None
        identifier_columns = []
        datetime_columns = []
        for col in df.columns:
            inference = self._infer_column(df[col], col)
            inferences[col] = inference
            if inference.inferred_type == ColumnType.TARGET:
                target_column = col
            elif inference.inferred_type == ColumnType.IDENTIFIER:
                identifier_columns.append(col)
            elif inference.inferred_type == ColumnType.DATETIME:
                datetime_columns.append(col)
        return InferenceResult(
            inferences=inferences,
            target_column=target_column,
            identifier_columns=identifier_columns,
            datetime_columns=datetime_columns
        )

    def _infer_column(self, series: pd.Series, column_name: str) -> ColumnInference:
        """Detection cascade: identifier -> target -> datetime -> binary -> numeric/categorical."""
        evidence = []
        col_lower = column_name.lower()
        if self._is_identifier(series, col_lower, evidence):
            return ColumnInference(column_name, ColumnType.IDENTIFIER, InferenceConfidence.HIGH, evidence)
        if self._is_target(series, col_lower, evidence):
            return ColumnInference(column_name, ColumnType.TARGET, InferenceConfidence.HIGH, evidence)
        if self._is_datetime(series, evidence):
            return ColumnInference(column_name, ColumnType.DATETIME, InferenceConfidence.HIGH, evidence)
        if self._is_binary(series, evidence):
            return ColumnInference(column_name, ColumnType.BINARY, InferenceConfidence.HIGH, evidence)
        if pd.api.types.is_numeric_dtype(series):
            return self._infer_numeric(series, column_name, evidence)
        return self._infer_categorical(series, column_name, evidence)

    @staticmethod
    def _name_matches(col_lower: str, patterns: List[str]) -> bool:
        """Match a column name against name patterns.

        Multi-character patterns match as substrings; single-character
        patterns must match the whole name. Fixes the original pure
        substring test, under which the pattern "y" matched ANY column
        name containing a 'y' (e.g. "country", "city", "payment").
        """
        return any(col_lower == p or (len(p) > 1 and p in col_lower) for p in patterns)

    def _is_identifier(self, series: pd.Series, col_lower: str, evidence: List[str]) -> bool:
        """True for id-named all-unique columns, or any all-unique integer column."""
        if self._name_matches(col_lower, self.ID_PATTERNS):
            if series.nunique() == len(series):
                evidence.append("unique values, id pattern in name")
                return True
        if series.nunique() == len(series) and pd.api.types.is_integer_dtype(series):
            evidence.append("unique integer values")
            return True
        return False

    def _is_target(self, series: pd.Series, col_lower: str, evidence: List[str]) -> bool:
        """True for target-named columns with at most 10 distinct values."""
        if self._name_matches(col_lower, self.TARGET_PATTERNS):
            if series.nunique() <= 10:
                evidence.append(f"target pattern in name, {series.nunique()} distinct values")
                return True
        return False

    def _is_datetime(self, series: pd.Series, evidence: List[str]) -> bool:
        """True for datetime dtypes or object columns whose head parses as dates."""
        if pd.api.types.is_datetime64_any_dtype(series):
            evidence.append("datetime dtype")
            return True
        if series.dtype == object:
            try:
                # Probe only the first 100 non-null values for speed.
                # format='mixed' requires pandas >= 2.0.
                pd.to_datetime(series.dropna().head(100), format='mixed')
                evidence.append("parseable as datetime")
                return True
            except (ValueError, TypeError):
                pass
        return False

    def _is_binary(self, series: pd.Series, evidence: List[str]) -> bool:
        """True when the column has exactly two distinct non-null values."""
        unique = series.dropna().unique()
        if len(unique) == 2:
            evidence.append("exactly 2 unique values")
            return True
        return False

    def _infer_numeric(self, series: pd.Series, column_name: str, evidence: List[str]) -> ColumnInference:
        """Classify a numeric column as discrete (<= 20 distinct) or continuous."""
        nunique = series.nunique()
        if nunique <= 20:
            evidence.append(f"numeric with {nunique} unique values (discrete)")
            return ColumnInference(column_name, ColumnType.NUMERIC_DISCRETE, InferenceConfidence.HIGH, evidence,
                                   suggested_encoding="ordinal", suggested_missing_strategy="median")
        evidence.append(f"numeric with {nunique} unique values (continuous)")
        return ColumnInference(column_name, ColumnType.NUMERIC_CONTINUOUS, InferenceConfidence.HIGH, evidence,
                               suggested_scaling="standard", suggested_missing_strategy="median")

    def _infer_categorical(self, series: pd.Series, column_name: str, evidence: List[str]) -> ColumnInference:
        """Classify a non-numeric column by cardinality; pick an encoding."""
        nunique = series.nunique()
        if nunique <= 10:
            evidence.append(f"categorical with {nunique} categories (low cardinality)")
            return ColumnInference(column_name, ColumnType.CATEGORICAL_NOMINAL, InferenceConfidence.HIGH, evidence,
                                   suggested_encoding="onehot", suggested_missing_strategy="mode")
        # High cardinality: one-hot would explode; suggest target encoding.
        evidence.append(f"categorical with {nunique} categories (high cardinality)")
        return ColumnInference(column_name, ColumnType.CATEGORICAL_NOMINAL, InferenceConfidence.MEDIUM, evidence,
                               suggested_encoding="target", suggested_missing_strategy="mode")

    def show_report(self, result: InferenceResult) -> None:
        """Print a human-readable summary of an inference result."""
        print(f"Target column: {result.target_column}")
        print(f"Identifier columns: {result.identifier_columns}")
        print(f"Datetime columns: {result.datetime_columns}")
        for col, inf in result.inferences.items():
            print(f"  {col}: {inf.inferred_type.value} ({inf.confidence.value}) - {', '.join(inf.evidence)}")
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from .cohort_analyzer import CohortAnalysisResult, CohortAnalyzer, CohortComparison, CohortInsight
|
|
2
|
+
from .counterfactual import Counterfactual, CounterfactualChange, CounterfactualGenerator
|
|
3
|
+
from .individual_explainer import Confidence, IndividualExplainer, IndividualExplanation, RiskContribution
|
|
4
|
+
from .pdp_generator import InteractionResult, PDPGenerator, PDPResult
|
|
5
|
+
from .shap_explainer import FeatureImportance, GlobalExplanation, ShapExplainer
|
|
6
|
+
|
|
7
|
+
# Public API of the explainability subpackage.
__all__ = [
    "ShapExplainer", "GlobalExplanation", "FeatureImportance",
    "PDPGenerator", "PDPResult", "InteractionResult",
    "CohortAnalyzer", "CohortInsight", "CohortComparison", "CohortAnalysisResult",
    "IndividualExplainer", "IndividualExplanation", "RiskContribution", "Confidence",
    "CounterfactualGenerator", "Counterfactual", "CounterfactualChange",
]
|