churnkit-0.75.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +647 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +1165 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +961 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +1690 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +679 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +3305 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +1463 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +1430 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +854 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +1639 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +1890 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +1457 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +1624 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +780 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +979 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +572 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +1179 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +1418 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +151 -0
- churnkit-0.75.0a1.dist-info/METADATA +229 -0
- churnkit-0.75.0a1.dist-info/RECORD +302 -0
- churnkit-0.75.0a1.dist-info/WHEEL +4 -0
- churnkit-0.75.0a1.dist-info/entry_points.txt +2 -0
- churnkit-0.75.0a1.dist-info/licenses/LICENSE +202 -0
- customer_retention/__init__.py +37 -0
- customer_retention/analysis/__init__.py +0 -0
- customer_retention/analysis/auto_explorer/__init__.py +62 -0
- customer_retention/analysis/auto_explorer/exploration_manager.py +470 -0
- customer_retention/analysis/auto_explorer/explorer.py +258 -0
- customer_retention/analysis/auto_explorer/findings.py +291 -0
- customer_retention/analysis/auto_explorer/layered_recommendations.py +485 -0
- customer_retention/analysis/auto_explorer/recommendation_builder.py +148 -0
- customer_retention/analysis/auto_explorer/recommendations.py +418 -0
- customer_retention/analysis/business/__init__.py +26 -0
- customer_retention/analysis/business/ab_test_designer.py +144 -0
- customer_retention/analysis/business/fairness_analyzer.py +166 -0
- customer_retention/analysis/business/intervention_matcher.py +121 -0
- customer_retention/analysis/business/report_generator.py +222 -0
- customer_retention/analysis/business/risk_profile.py +199 -0
- customer_retention/analysis/business/roi_analyzer.py +139 -0
- customer_retention/analysis/diagnostics/__init__.py +20 -0
- customer_retention/analysis/diagnostics/calibration_analyzer.py +133 -0
- customer_retention/analysis/diagnostics/cv_analyzer.py +144 -0
- customer_retention/analysis/diagnostics/error_analyzer.py +107 -0
- customer_retention/analysis/diagnostics/leakage_detector.py +394 -0
- customer_retention/analysis/diagnostics/noise_tester.py +140 -0
- customer_retention/analysis/diagnostics/overfitting_analyzer.py +190 -0
- customer_retention/analysis/diagnostics/segment_analyzer.py +122 -0
- customer_retention/analysis/discovery/__init__.py +8 -0
- customer_retention/analysis/discovery/config_generator.py +49 -0
- customer_retention/analysis/discovery/discovery_flow.py +19 -0
- customer_retention/analysis/discovery/type_inferencer.py +147 -0
- customer_retention/analysis/interpretability/__init__.py +13 -0
- customer_retention/analysis/interpretability/cohort_analyzer.py +185 -0
- customer_retention/analysis/interpretability/counterfactual.py +175 -0
- customer_retention/analysis/interpretability/individual_explainer.py +141 -0
- customer_retention/analysis/interpretability/pdp_generator.py +103 -0
- customer_retention/analysis/interpretability/shap_explainer.py +106 -0
- customer_retention/analysis/jupyter_save_hook.py +28 -0
- customer_retention/analysis/notebook_html_exporter.py +136 -0
- customer_retention/analysis/notebook_progress.py +60 -0
- customer_retention/analysis/plotly_preprocessor.py +154 -0
- customer_retention/analysis/recommendations/__init__.py +54 -0
- customer_retention/analysis/recommendations/base.py +158 -0
- customer_retention/analysis/recommendations/cleaning/__init__.py +11 -0
- customer_retention/analysis/recommendations/cleaning/consistency.py +107 -0
- customer_retention/analysis/recommendations/cleaning/deduplicate.py +94 -0
- customer_retention/analysis/recommendations/cleaning/impute.py +67 -0
- customer_retention/analysis/recommendations/cleaning/outlier.py +71 -0
- customer_retention/analysis/recommendations/datetime/__init__.py +3 -0
- customer_retention/analysis/recommendations/datetime/extract.py +149 -0
- customer_retention/analysis/recommendations/encoding/__init__.py +3 -0
- customer_retention/analysis/recommendations/encoding/categorical.py +114 -0
- customer_retention/analysis/recommendations/pipeline.py +74 -0
- customer_retention/analysis/recommendations/registry.py +76 -0
- customer_retention/analysis/recommendations/selection/__init__.py +3 -0
- customer_retention/analysis/recommendations/selection/drop_column.py +56 -0
- customer_retention/analysis/recommendations/transform/__init__.py +4 -0
- customer_retention/analysis/recommendations/transform/power.py +94 -0
- customer_retention/analysis/recommendations/transform/scale.py +112 -0
- customer_retention/analysis/visualization/__init__.py +15 -0
- customer_retention/analysis/visualization/chart_builder.py +2619 -0
- customer_retention/analysis/visualization/console.py +122 -0
- customer_retention/analysis/visualization/display.py +171 -0
- customer_retention/analysis/visualization/number_formatter.py +36 -0
- customer_retention/artifacts/__init__.py +3 -0
- customer_retention/artifacts/fit_artifact_registry.py +146 -0
- customer_retention/cli.py +93 -0
- customer_retention/core/__init__.py +0 -0
- customer_retention/core/compat/__init__.py +193 -0
- customer_retention/core/compat/detection.py +99 -0
- customer_retention/core/compat/ops.py +48 -0
- customer_retention/core/compat/pandas_backend.py +57 -0
- customer_retention/core/compat/spark_backend.py +75 -0
- customer_retention/core/components/__init__.py +11 -0
- customer_retention/core/components/base.py +79 -0
- customer_retention/core/components/components/__init__.py +13 -0
- customer_retention/core/components/components/deployer.py +26 -0
- customer_retention/core/components/components/explainer.py +26 -0
- customer_retention/core/components/components/feature_eng.py +33 -0
- customer_retention/core/components/components/ingester.py +34 -0
- customer_retention/core/components/components/profiler.py +34 -0
- customer_retention/core/components/components/trainer.py +38 -0
- customer_retention/core/components/components/transformer.py +36 -0
- customer_retention/core/components/components/validator.py +37 -0
- customer_retention/core/components/enums.py +33 -0
- customer_retention/core/components/orchestrator.py +94 -0
- customer_retention/core/components/registry.py +59 -0
- customer_retention/core/config/__init__.py +39 -0
- customer_retention/core/config/column_config.py +95 -0
- customer_retention/core/config/experiments.py +71 -0
- customer_retention/core/config/pipeline_config.py +117 -0
- customer_retention/core/config/source_config.py +83 -0
- customer_retention/core/utils/__init__.py +28 -0
- customer_retention/core/utils/leakage.py +85 -0
- customer_retention/core/utils/severity.py +53 -0
- customer_retention/core/utils/statistics.py +90 -0
- customer_retention/generators/__init__.py +0 -0
- customer_retention/generators/notebook_generator/__init__.py +167 -0
- customer_retention/generators/notebook_generator/base.py +55 -0
- customer_retention/generators/notebook_generator/cell_builder.py +49 -0
- customer_retention/generators/notebook_generator/config.py +47 -0
- customer_retention/generators/notebook_generator/databricks_generator.py +48 -0
- customer_retention/generators/notebook_generator/local_generator.py +48 -0
- customer_retention/generators/notebook_generator/project_init.py +174 -0
- customer_retention/generators/notebook_generator/runner.py +150 -0
- customer_retention/generators/notebook_generator/script_generator.py +110 -0
- customer_retention/generators/notebook_generator/stages/__init__.py +19 -0
- customer_retention/generators/notebook_generator/stages/base_stage.py +86 -0
- customer_retention/generators/notebook_generator/stages/s01_ingestion.py +100 -0
- customer_retention/generators/notebook_generator/stages/s02_profiling.py +95 -0
- customer_retention/generators/notebook_generator/stages/s03_cleaning.py +180 -0
- customer_retention/generators/notebook_generator/stages/s04_transformation.py +165 -0
- customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py +115 -0
- customer_retention/generators/notebook_generator/stages/s06_feature_selection.py +97 -0
- customer_retention/generators/notebook_generator/stages/s07_model_training.py +176 -0
- customer_retention/generators/notebook_generator/stages/s08_deployment.py +81 -0
- customer_retention/generators/notebook_generator/stages/s09_monitoring.py +112 -0
- customer_retention/generators/notebook_generator/stages/s10_batch_inference.py +642 -0
- customer_retention/generators/notebook_generator/stages/s11_feature_store.py +348 -0
- customer_retention/generators/orchestration/__init__.py +23 -0
- customer_retention/generators/orchestration/code_generator.py +196 -0
- customer_retention/generators/orchestration/context.py +147 -0
- customer_retention/generators/orchestration/data_materializer.py +188 -0
- customer_retention/generators/orchestration/databricks_exporter.py +411 -0
- customer_retention/generators/orchestration/doc_generator.py +311 -0
- customer_retention/generators/pipeline_generator/__init__.py +26 -0
- customer_retention/generators/pipeline_generator/findings_parser.py +727 -0
- customer_retention/generators/pipeline_generator/generator.py +142 -0
- customer_retention/generators/pipeline_generator/models.py +166 -0
- customer_retention/generators/pipeline_generator/renderer.py +2125 -0
- customer_retention/generators/spec_generator/__init__.py +37 -0
- customer_retention/generators/spec_generator/databricks_generator.py +433 -0
- customer_retention/generators/spec_generator/generic_generator.py +373 -0
- customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +685 -0
- customer_retention/generators/spec_generator/pipeline_spec.py +298 -0
- customer_retention/integrations/__init__.py +0 -0
- customer_retention/integrations/adapters/__init__.py +13 -0
- customer_retention/integrations/adapters/base.py +10 -0
- customer_retention/integrations/adapters/factory.py +25 -0
- customer_retention/integrations/adapters/feature_store/__init__.py +6 -0
- customer_retention/integrations/adapters/feature_store/base.py +57 -0
- customer_retention/integrations/adapters/feature_store/databricks.py +94 -0
- customer_retention/integrations/adapters/feature_store/feast_adapter.py +97 -0
- customer_retention/integrations/adapters/feature_store/local.py +75 -0
- customer_retention/integrations/adapters/mlflow/__init__.py +6 -0
- customer_retention/integrations/adapters/mlflow/base.py +32 -0
- customer_retention/integrations/adapters/mlflow/databricks.py +54 -0
- customer_retention/integrations/adapters/mlflow/experiment_tracker.py +161 -0
- customer_retention/integrations/adapters/mlflow/local.py +50 -0
- customer_retention/integrations/adapters/storage/__init__.py +5 -0
- customer_retention/integrations/adapters/storage/base.py +33 -0
- customer_retention/integrations/adapters/storage/databricks.py +76 -0
- customer_retention/integrations/adapters/storage/local.py +59 -0
- customer_retention/integrations/feature_store/__init__.py +47 -0
- customer_retention/integrations/feature_store/definitions.py +215 -0
- customer_retention/integrations/feature_store/manager.py +744 -0
- customer_retention/integrations/feature_store/registry.py +412 -0
- customer_retention/integrations/iteration/__init__.py +28 -0
- customer_retention/integrations/iteration/context.py +212 -0
- customer_retention/integrations/iteration/feedback_collector.py +184 -0
- customer_retention/integrations/iteration/orchestrator.py +168 -0
- customer_retention/integrations/iteration/recommendation_tracker.py +341 -0
- customer_retention/integrations/iteration/signals.py +212 -0
- customer_retention/integrations/llm_context/__init__.py +4 -0
- customer_retention/integrations/llm_context/context_builder.py +201 -0
- customer_retention/integrations/llm_context/prompts.py +100 -0
- customer_retention/integrations/streaming/__init__.py +103 -0
- customer_retention/integrations/streaming/batch_integration.py +149 -0
- customer_retention/integrations/streaming/early_warning_model.py +227 -0
- customer_retention/integrations/streaming/event_schema.py +214 -0
- customer_retention/integrations/streaming/online_store_writer.py +249 -0
- customer_retention/integrations/streaming/realtime_scorer.py +261 -0
- customer_retention/integrations/streaming/trigger_engine.py +293 -0
- customer_retention/integrations/streaming/window_aggregator.py +393 -0
- customer_retention/stages/__init__.py +0 -0
- customer_retention/stages/cleaning/__init__.py +9 -0
- customer_retention/stages/cleaning/base.py +28 -0
- customer_retention/stages/cleaning/missing_handler.py +160 -0
- customer_retention/stages/cleaning/outlier_handler.py +204 -0
- customer_retention/stages/deployment/__init__.py +28 -0
- customer_retention/stages/deployment/batch_scorer.py +106 -0
- customer_retention/stages/deployment/champion_challenger.py +299 -0
- customer_retention/stages/deployment/model_registry.py +182 -0
- customer_retention/stages/deployment/retraining_trigger.py +245 -0
- customer_retention/stages/features/__init__.py +73 -0
- customer_retention/stages/features/behavioral_features.py +266 -0
- customer_retention/stages/features/customer_segmentation.py +505 -0
- customer_retention/stages/features/feature_definitions.py +265 -0
- customer_retention/stages/features/feature_engineer.py +551 -0
- customer_retention/stages/features/feature_manifest.py +340 -0
- customer_retention/stages/features/feature_selector.py +239 -0
- customer_retention/stages/features/interaction_features.py +160 -0
- customer_retention/stages/features/temporal_features.py +243 -0
- customer_retention/stages/ingestion/__init__.py +9 -0
- customer_retention/stages/ingestion/load_result.py +32 -0
- customer_retention/stages/ingestion/loaders.py +195 -0
- customer_retention/stages/ingestion/source_registry.py +130 -0
- customer_retention/stages/modeling/__init__.py +31 -0
- customer_retention/stages/modeling/baseline_trainer.py +139 -0
- customer_retention/stages/modeling/cross_validator.py +125 -0
- customer_retention/stages/modeling/data_splitter.py +205 -0
- customer_retention/stages/modeling/feature_scaler.py +99 -0
- customer_retention/stages/modeling/hyperparameter_tuner.py +107 -0
- customer_retention/stages/modeling/imbalance_handler.py +282 -0
- customer_retention/stages/modeling/mlflow_logger.py +95 -0
- customer_retention/stages/modeling/model_comparator.py +149 -0
- customer_retention/stages/modeling/model_evaluator.py +138 -0
- customer_retention/stages/modeling/threshold_optimizer.py +131 -0
- customer_retention/stages/monitoring/__init__.py +37 -0
- customer_retention/stages/monitoring/alert_manager.py +328 -0
- customer_retention/stages/monitoring/drift_detector.py +201 -0
- customer_retention/stages/monitoring/performance_monitor.py +242 -0
- customer_retention/stages/preprocessing/__init__.py +5 -0
- customer_retention/stages/preprocessing/transformer_manager.py +284 -0
- customer_retention/stages/profiling/__init__.py +256 -0
- customer_retention/stages/profiling/categorical_distribution.py +269 -0
- customer_retention/stages/profiling/categorical_target_analyzer.py +274 -0
- customer_retention/stages/profiling/column_profiler.py +527 -0
- customer_retention/stages/profiling/distribution_analysis.py +483 -0
- customer_retention/stages/profiling/drift_detector.py +310 -0
- customer_retention/stages/profiling/feature_capacity.py +507 -0
- customer_retention/stages/profiling/pattern_analysis_config.py +513 -0
- customer_retention/stages/profiling/profile_result.py +212 -0
- customer_retention/stages/profiling/quality_checks.py +1632 -0
- customer_retention/stages/profiling/relationship_detector.py +256 -0
- customer_retention/stages/profiling/relationship_recommender.py +454 -0
- customer_retention/stages/profiling/report_generator.py +520 -0
- customer_retention/stages/profiling/scd_analyzer.py +151 -0
- customer_retention/stages/profiling/segment_analyzer.py +632 -0
- customer_retention/stages/profiling/segment_aware_outlier.py +265 -0
- customer_retention/stages/profiling/target_level_analyzer.py +217 -0
- customer_retention/stages/profiling/temporal_analyzer.py +388 -0
- customer_retention/stages/profiling/temporal_coverage.py +488 -0
- customer_retention/stages/profiling/temporal_feature_analyzer.py +692 -0
- customer_retention/stages/profiling/temporal_feature_engineer.py +703 -0
- customer_retention/stages/profiling/temporal_pattern_analyzer.py +636 -0
- customer_retention/stages/profiling/temporal_quality_checks.py +278 -0
- customer_retention/stages/profiling/temporal_target_analyzer.py +241 -0
- customer_retention/stages/profiling/text_embedder.py +87 -0
- customer_retention/stages/profiling/text_processor.py +115 -0
- customer_retention/stages/profiling/text_reducer.py +60 -0
- customer_retention/stages/profiling/time_series_profiler.py +303 -0
- customer_retention/stages/profiling/time_window_aggregator.py +376 -0
- customer_retention/stages/profiling/type_detector.py +382 -0
- customer_retention/stages/profiling/window_recommendation.py +288 -0
- customer_retention/stages/temporal/__init__.py +166 -0
- customer_retention/stages/temporal/access_guard.py +180 -0
- customer_retention/stages/temporal/cutoff_analyzer.py +235 -0
- customer_retention/stages/temporal/data_preparer.py +178 -0
- customer_retention/stages/temporal/point_in_time_join.py +134 -0
- customer_retention/stages/temporal/point_in_time_registry.py +148 -0
- customer_retention/stages/temporal/scenario_detector.py +163 -0
- customer_retention/stages/temporal/snapshot_manager.py +259 -0
- customer_retention/stages/temporal/synthetic_coordinator.py +66 -0
- customer_retention/stages/temporal/timestamp_discovery.py +531 -0
- customer_retention/stages/temporal/timestamp_manager.py +255 -0
- customer_retention/stages/transformation/__init__.py +13 -0
- customer_retention/stages/transformation/binary_handler.py +85 -0
- customer_retention/stages/transformation/categorical_encoder.py +245 -0
- customer_retention/stages/transformation/datetime_transformer.py +97 -0
- customer_retention/stages/transformation/numeric_transformer.py +181 -0
- customer_retention/stages/transformation/pipeline.py +257 -0
- customer_retention/stages/validation/__init__.py +60 -0
- customer_retention/stages/validation/adversarial_scoring_validator.py +205 -0
- customer_retention/stages/validation/business_sense_gate.py +173 -0
- customer_retention/stages/validation/data_quality_gate.py +235 -0
- customer_retention/stages/validation/data_validators.py +511 -0
- customer_retention/stages/validation/feature_quality_gate.py +183 -0
- customer_retention/stages/validation/gates.py +117 -0
- customer_retention/stages/validation/leakage_gate.py +352 -0
- customer_retention/stages/validation/model_validity_gate.py +213 -0
- customer_retention/stages/validation/pipeline_validation_runner.py +264 -0
- customer_retention/stages/validation/quality_scorer.py +544 -0
- customer_retention/stages/validation/rule_generator.py +57 -0
- customer_retention/stages/validation/scoring_pipeline_validator.py +446 -0
- customer_retention/stages/validation/timeseries_detector.py +769 -0
- customer_retention/transforms/__init__.py +47 -0
- customer_retention/transforms/artifact_store.py +50 -0
- customer_retention/transforms/executor.py +157 -0
- customer_retention/transforms/fitted.py +92 -0
- customer_retention/transforms/ops.py +148 -0

customer_retention/analysis/interpretability/cohort_analyzer.py
@@ -0,0 +1,185 @@
"""Cohort-level interpretability analysis."""

from dataclasses import dataclass, field
from typing import Any, Dict, List

import numpy as np
import shap

from customer_retention.core.compat import DataFrame, Series, pd


@dataclass
class CohortInsight:
    cohort_name: str
    cohort_size: int
    cohort_percentage: float
    churn_rate: float
    top_features: List[Dict[str, float]]
    key_differentiators: List[str] = field(default_factory=list)
    recommended_strategy: str = ""


@dataclass
class CohortComparison:
    cohort_a: str
    cohort_b: str
    feature_differences: Dict[str, float]
    churn_rate_difference: float
    key_differences: List[str] = field(default_factory=list)


@dataclass
class CohortAnalysisResult:
    cohort_insights: List[CohortInsight]
    key_differences: List[str]
    overall_summary: str = ""


class CohortAnalyzer:
    def __init__(self, model: Any, background_data: DataFrame, max_samples: int = 100):
        self.model = model
        self.background_data = background_data.head(max_samples)
        self._explainer = self._create_explainer()

    def _create_explainer(self) -> shap.Explainer:
        model_type = type(self.model).__name__
        if model_type in ["RandomForestClassifier", "GradientBoostingClassifier"]:
            return shap.TreeExplainer(self.model)
        return shap.KernelExplainer(self.model.predict_proba, self.background_data)

    def analyze(self, X: DataFrame, y: Series, cohorts: Series) -> CohortAnalysisResult:
        unique_cohorts = cohorts.unique()
        insights = []
        all_features_by_cohort = {}
        for cohort in unique_cohorts:
            mask = cohorts == cohort
            cohort_X = X[mask]
            cohort_y = y[mask]
            churn_rate = float(1 - cohort_y.mean())
            top_features = self._get_cohort_feature_importance(cohort_X)
            all_features_by_cohort[cohort] = top_features
            strategy = self._generate_strategy(cohort, churn_rate, top_features)
            insights.append(CohortInsight(
                cohort_name=cohort,
                cohort_size=len(cohort_X),
                cohort_percentage=len(cohort_X) / len(X),
                churn_rate=churn_rate,
                top_features=top_features,
                recommended_strategy=strategy
            ))
        key_differences = self._identify_key_differences(all_features_by_cohort, insights)
        for insight in insights:
            insight.key_differentiators = self._get_differentiators(insight.cohort_name, all_features_by_cohort)
        return CohortAnalysisResult(
            cohort_insights=insights,
            key_differences=key_differences
        )

    def _get_cohort_feature_importance(self, cohort_X: DataFrame) -> List[Dict[str, float]]:
        if len(cohort_X) == 0:
            return []
        sample = cohort_X.head(min(50, len(cohort_X)))
        shap_values = self._extract_shap_values(sample)
        mean_abs_shap = np.abs(shap_values).mean(axis=0)
        sorted_indices = np.argsort(mean_abs_shap)[::-1][:5]
        result = []
        for idx in sorted_indices:
            importance_val = mean_abs_shap[idx]
            if hasattr(importance_val, '__len__') and len(importance_val) == 1:
                importance_val = importance_val[0]
            result.append({"feature": cohort_X.columns[idx], "importance": float(importance_val)})
        return result

    def _extract_shap_values(self, X: DataFrame) -> np.ndarray:
        shap_values = self._explainer.shap_values(X)
        if hasattr(shap_values, 'values'):
            shap_values = shap_values.values
        if isinstance(shap_values, list):
            shap_values = shap_values[1]
        if len(shap_values.shape) == 3:
            shap_values = shap_values[:, :, 1]
        return shap_values

    def _generate_strategy(self, cohort: str, churn_rate: float,
                           top_features: List[Dict[str, float]]) -> str:
        if churn_rate > 0.5:
            priority = "urgent intervention"
        elif churn_rate > 0.3:
            priority = "proactive engagement"
        else:
            priority = "standard nurturing"
        top_feature = top_features[0]["feature"] if top_features else "engagement"
        return f"Focus on {top_feature} with {priority} for {cohort} cohort"

    def _identify_key_differences(self, features_by_cohort: Dict[str, List[Dict[str, float]]],
                                  insights: List[CohortInsight]) -> List[str]:
        differences = []
        churn_rates = {i.cohort_name: i.churn_rate for i in insights}
        if churn_rates:
            max_cohort = max(churn_rates, key=churn_rates.get)
            min_cohort = min(churn_rates, key=churn_rates.get)
            diff = churn_rates[max_cohort] - churn_rates[min_cohort]
            differences.append(f"{max_cohort} has {diff:.1%} higher churn than {min_cohort}")
        for cohort, features in features_by_cohort.items():
            if features:
                top = features[0]["feature"]
                differences.append(f"{cohort}: top driver is {top}")
        return differences

    def _get_differentiators(self, cohort: str,
                             features_by_cohort: Dict[str, List[Dict[str, float]]]) -> List[str]:
        cohort_features = features_by_cohort.get(cohort, [])
        cohort_top = set(f["feature"] for f in cohort_features[:3])
        other_tops = set()
        for other, features in features_by_cohort.items():
            if other != cohort:
                other_tops.update(f["feature"] for f in features[:3])
        unique = cohort_top - other_tops
        return [f"{cohort} uniquely driven by {f}" for f in unique]

    def compare_cohorts(self, X: DataFrame, y: Series, cohorts: Series,
                        cohort_a: str, cohort_b: str) -> CohortComparison:
        mask_a = cohorts == cohort_a
        mask_b = cohorts == cohort_b
        churn_a = 1 - y[mask_a].mean()
        churn_b = 1 - y[mask_b].mean()
        feature_diffs = {}
        for col in X.columns:
            mean_a = X.loc[mask_a, col].mean()
            mean_b = X.loc[mask_b, col].mean()
            feature_diffs[col] = float(mean_a - mean_b)
        key_diffs = []
        sorted_diffs = sorted(feature_diffs.items(), key=lambda x: abs(x[1]), reverse=True)
        for feature, diff in sorted_diffs[:3]:
            direction = "higher" if diff > 0 else "lower"
            key_diffs.append(f"{cohort_a} has {direction} {feature} than {cohort_b}")
        return CohortComparison(
            cohort_a=cohort_a,
            cohort_b=cohort_b,
            feature_differences=feature_diffs,
            churn_rate_difference=float(churn_a - churn_b),
            key_differences=key_diffs
        )

    @staticmethod
    def create_tenure_cohorts(tenure: Series,
                              bins: List[float] = None) -> Series:
        bins = bins or [0, 90, 365, float("inf")]
        labels = ["New", "Established", "Mature"]
        return pd.cut(tenure, bins=bins, labels=labels)

    @staticmethod
    def create_value_cohorts(value: Series,
                             quantiles: List[float] = None) -> Series:
        quantiles = quantiles or [0.33, 0.66]
        q1, q2 = value.quantile(quantiles[0]), value.quantile(quantiles[1])
        return pd.cut(value, bins=[-float("inf"), q1, q2, float("inf")],
                      labels=["Low", "Medium", "High"])

    @staticmethod
    def create_activity_cohorts(activity: Series,
                                thresholds: List[float] = None) -> Series:
        thresholds = thresholds or [5, 15]
        return pd.cut(activity, bins=[-float("inf"), thresholds[0], thresholds[1], float("inf")],
                      labels=["Dormant", "Moderate", "Active"])
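
For orientation, a minimal usage sketch (not shipped in the wheel): it builds a synthetic frame, fits a scikit-learn RandomForestClassifier, and runs the cohort analysis. The column names, the synthetic data, and the assumption that y == 1 means retained (so that the module's churn_rate = 1 - mean(y) reads as a churn rate) are all illustrative.

    import numpy as np
    import pandas as pd
    from sklearn.ensemble import RandomForestClassifier

    from customer_retention.analysis.interpretability.cohort_analyzer import CohortAnalyzer

    # Synthetic stand-in data; real callers would pass their own feature frame.
    rng = np.random.default_rng(0)
    X = pd.DataFrame({
        "tenure_days": rng.integers(1, 1000, 500).astype(float),
        "monthly_spend": rng.gamma(2.0, 50.0, 500),
        "support_tickets": rng.poisson(2.0, 500).astype(float),
    })
    y = pd.Series((rng.random(500) < 0.7).astype(int))  # assumed: 1 = retained
    model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

    analyzer = CohortAnalyzer(model, background_data=X)  # RF takes the TreeExplainer path
    cohorts = CohortAnalyzer.create_tenure_cohorts(X["tenure_days"])  # New / Established / Mature
    result = analyzer.analyze(X, y, cohorts)
    for insight in result.cohort_insights:
        print(f"{insight.cohort_name}: churn={insight.churn_rate:.1%} -> {insight.recommended_strategy}")
    print(result.key_differences)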

customer_retention/analysis/interpretability/counterfactual.py
@@ -0,0 +1,175 @@
"""Counterfactual explanation generation."""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np

from customer_retention.core.compat import DataFrame, Series


@dataclass
class CounterfactualChange:
    feature_name: str
    original_value: float
    new_value: float
    change_magnitude: float


@dataclass
class Counterfactual:
    original_prediction: float
    counterfactual_prediction: float
    changes: List[CounterfactualChange]
    feasibility_score: float
    business_interpretation: str


class CounterfactualGenerator:
    def __init__(self, model: Any, reference_data: DataFrame,
                 actionable_features: Optional[List[str]] = None,
                 constraints: Optional[Dict[str, Dict[str, float]]] = None):
        self.model = model
        self.reference_data = reference_data
        self.actionable_features = actionable_features or list(reference_data.columns)
        self.constraints = constraints or {}
        self._feature_bounds = self._calculate_bounds()

    def _calculate_bounds(self) -> Dict[str, Dict[str, float]]:
        bounds = {}
        for col in self.reference_data.columns:
            bounds[col] = {
                "min": float(self.reference_data[col].min()),
                "max": float(self.reference_data[col].max()),
                "mean": float(self.reference_data[col].mean()),
                "std": float(self.reference_data[col].std())
            }
        return bounds

    def generate(self, instance: Series, target_class: int = 0,
                 max_iterations: int = 100) -> Counterfactual:
        instance_df = instance.to_frame().T
        original_pred = float(self.model.predict_proba(instance_df)[0, 1])
        best_cf = instance.copy()
        best_pred = original_pred
        best_changes = []
        target_pred = 0.3 if target_class == 0 else 0.7
        for _ in range(max_iterations):
            candidate = self._perturb_instance(instance, best_cf)
            candidate_df = candidate.to_frame().T
            pred = float(self.model.predict_proba(candidate_df)[0, 1])
            improved = (target_class == 0 and pred < best_pred) or (target_class == 1 and pred > best_pred)
            if improved:
                best_cf = candidate
                best_pred = pred
                best_changes = self._compute_changes(instance, best_cf)
            if (target_class == 0 and best_pred < target_pred) or (target_class == 1 and best_pred > target_pred):
                break
        feasibility = self._calculate_feasibility(instance, best_cf)
        interpretation = self._generate_interpretation(best_changes, original_pred, best_pred)
        return Counterfactual(
            original_prediction=original_pred,
            counterfactual_prediction=best_pred,
            changes=best_changes,
            feasibility_score=feasibility,
            business_interpretation=interpretation
        )

    def _perturb_instance(self, original: Series, current: Series) -> Series:
        candidate = current.copy()
        feature = np.random.choice(self.actionable_features)
        bounds = self._get_feature_bounds(feature)
        current_val = candidate[feature]
        step = (bounds["max"] - bounds["min"]) * 0.1
        direction = np.random.choice([-1, 1])
        new_val = current_val + direction * step * np.random.uniform(0.5, 1.5)
        new_val = np.clip(new_val, bounds["min"], bounds["max"])
        candidate[feature] = new_val
        return candidate

    def _get_feature_bounds(self, feature: str) -> Dict[str, float]:
        if feature in self.constraints:
            constraint = self.constraints[feature]
            return {
                "min": constraint.get("min", self._feature_bounds[feature]["min"]),
                "max": constraint.get("max", self._feature_bounds[feature]["max"])
            }
        return self._feature_bounds[feature]

    def _compute_changes(self, original: Series, counterfactual: Series) -> List[CounterfactualChange]:
        changes = []
        for feature in self.actionable_features:
            if abs(original[feature] - counterfactual[feature]) > 1e-6:
                changes.append(CounterfactualChange(
                    feature_name=feature,
                    original_value=float(original[feature]),
                    new_value=float(counterfactual[feature]),
                    change_magnitude=float(abs(original[feature] - counterfactual[feature]))
                ))
        return changes

    def _calculate_feasibility(self, original: Series, counterfactual: Series) -> float:
        total_change = 0
        max_change = 0
        for feature in self.actionable_features:
            bounds = self._feature_bounds[feature]
            range_size = bounds["max"] - bounds["min"]
            if range_size > 0:
                normalized_change = abs(original[feature] - counterfactual[feature]) / range_size
                total_change += normalized_change
            max_change += 1
        if max_change == 0:
            return 1.0
        feasibility = 1 - (total_change / max_change)
        return max(0.0, min(1.0, feasibility))

    def _generate_interpretation(self, changes: List[CounterfactualChange],
                                 original_pred: float, new_pred: float) -> str:
        if not changes:
            return "No changes needed to achieve target prediction."
        change_strs = []
        for c in changes[:3]:
            direction = "increase" if c.new_value > c.original_value else "decrease"
            change_strs.append(f"{direction} {c.feature_name} from {c.original_value:.2f} to {c.new_value:.2f}")
        changes_text = ", ".join(change_strs)
        return f"To reduce churn risk from {original_pred:.1%} to {new_pred:.1%}: {changes_text}"

    def generate_diverse(self, instance: Series, n: int = 3) -> List[Counterfactual]:
        counterfactuals = []
        used_features = set()
        for _ in range(n):
            available = [f for f in self.actionable_features if f not in used_features]
            if not available:
                available = self.actionable_features
            temp_generator = CounterfactualGenerator(
                self.model, self.reference_data,
                actionable_features=available,
                constraints=self.constraints
            )
            cf = temp_generator.generate(instance)
            counterfactuals.append(cf)
            for change in cf.changes:
                used_features.add(change.feature_name)
        return counterfactuals

    def generate_prototype(self, instance: Series, prototype_data: DataFrame) -> Counterfactual:
        instance_df = instance.to_frame().T
        original_pred = float(self.model.predict_proba(instance_df)[0, 1])
        prototype = prototype_data.mean()
        best_cf = instance.copy()
        for feature in self.actionable_features:
            bounds = self._get_feature_bounds(feature)
            target_val = np.clip(prototype[feature], bounds["min"], bounds["max"])
            best_cf[feature] = instance[feature] + 0.5 * (target_val - instance[feature])
        cf_df = best_cf.to_frame().T
        new_pred = float(self.model.predict_proba(cf_df)[0, 1])
        changes = self._compute_changes(instance, best_cf)
        feasibility = self._calculate_feasibility(instance, best_cf)
        interpretation = self._generate_interpretation(changes, original_pred, new_pred)
        return Counterfactual(
            original_prediction=original_pred,
            counterfactual_prediction=new_pred,
            changes=changes,
            feasibility_score=feasibility,
            business_interpretation=interpretation
        )
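
A comparable sketch for CounterfactualGenerator (again not part of the wheel): which features count as actionable, and the non-negativity constraint on support_tickets, are illustrative assumptions; the synthetic setup mirrors the one above.

    import numpy as np
    import pandas as pd
    from sklearn.ensemble import RandomForestClassifier

    from customer_retention.analysis.interpretability.counterfactual import CounterfactualGenerator

    rng = np.random.default_rng(0)
    X = pd.DataFrame({
        "tenure_days": rng.integers(1, 1000, 500).astype(float),
        "monthly_spend": rng.gamma(2.0, 50.0, 500),
        "support_tickets": rng.poisson(2.0, 500).astype(float),
    })
    y = pd.Series((rng.random(500) < 0.7).astype(int))
    model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

    generator = CounterfactualGenerator(
        model,
        reference_data=X,
        actionable_features=["monthly_spend", "support_tickets"],  # assumed actionable
        constraints={"support_tickets": {"min": 0.0}},             # cannot go negative
    )
    cf = generator.generate(X.iloc[0], target_class=0, max_iterations=200)
    print(f"{cf.original_prediction:.1%} -> {cf.counterfactual_prediction:.1%} "
          f"(feasibility {cf.feasibility_score:.2f})")
    print(cf.business_interpretation)
    for alt in generator.generate_diverse(X.iloc[0], n=2):  # spreads changes across features
        print(alt.business_interpretation)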

customer_retention/analysis/interpretability/individual_explainer.py
@@ -0,0 +1,141 @@
"""Individual customer explanation."""

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional

import numpy as np
import shap
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler

from customer_retention.core.compat import DataFrame, Series


class Confidence(Enum):
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"


@dataclass
class RiskContribution:
    feature_name: str
    contribution: float
    current_value: float
    direction: str


@dataclass
class IndividualExplanation:
    customer_id: Optional[str]
    churn_probability: float
    base_value: float
    shap_values: np.ndarray
    top_positive_factors: List[RiskContribution]
    top_negative_factors: List[RiskContribution]
    confidence: Confidence
    feature_names: List[str] = field(default_factory=list)


class IndividualExplainer:
    def __init__(self, model: Any, background_data: DataFrame, max_samples: int = 100):
        self.model = model
        self.background_data = background_data.head(max_samples)
        self.feature_names = list(background_data.columns)
        self._explainer = self._create_explainer()

    def _create_explainer(self) -> shap.Explainer:
        model_type = type(self.model).__name__
        if model_type in ["RandomForestClassifier", "GradientBoostingClassifier"]:
            return shap.TreeExplainer(self.model)
        if model_type in ["LogisticRegression", "LinearRegression"]:
            return shap.LinearExplainer(self.model, self.background_data)
        return shap.KernelExplainer(self.model.predict_proba, self.background_data)

    def explain(self, instance: Series, customer_id: Optional[str] = None,
                top_n: int = 3) -> IndividualExplanation:
        instance_df = instance.to_frame().T
        shap_values = self._extract_shap_values(instance_df)
        churn_prob = float(self.model.predict_proba(instance_df)[0, 1])
        expected_value = self._get_expected_value()
        positive_factors = self._extract_factors(instance, shap_values, top_n, positive=True)
        negative_factors = self._extract_factors(instance, shap_values, top_n, positive=False)
        confidence = self._assess_confidence(churn_prob)
        return IndividualExplanation(
            customer_id=customer_id,
            churn_probability=churn_prob,
            base_value=float(expected_value),
            shap_values=shap_values,
            top_positive_factors=positive_factors,
            top_negative_factors=negative_factors,
            confidence=confidence,
            feature_names=self.feature_names
        )

    def _extract_shap_values(self, X: DataFrame) -> np.ndarray:
        shap_values = self._explainer.shap_values(X)
        if hasattr(shap_values, 'values'):
            shap_values = shap_values.values
        if isinstance(shap_values, list):
            shap_values = shap_values[1]
        if len(shap_values.shape) == 3:
            shap_values = shap_values[:, :, 1]
        return shap_values.flatten()

    def _get_expected_value(self) -> float:
        expected_value = self._explainer.expected_value
        if hasattr(expected_value, '__len__'):
            if len(expected_value) > 1:
                return float(expected_value[1])
            return float(expected_value[0])
        return float(expected_value)

    def _extract_factors(self, instance: Series, shap_values: np.ndarray,
                         top_n: int, positive: bool) -> List[RiskContribution]:
        if positive:
            indices = np.argsort(shap_values)[::-1]
            values = [(i, shap_values[i]) for i in indices if shap_values[i] > 0]
        else:
            indices = np.argsort(shap_values)
            values = [(i, shap_values[i]) for i in indices if shap_values[i] < 0]
        factors = []
        for idx, contrib in values[:top_n]:
            feature_name = self.feature_names[idx]
            factors.append(RiskContribution(
                feature_name=feature_name,
                contribution=float(contrib),
                current_value=float(instance[feature_name]),
                direction="increases risk" if contrib > 0 else "decreases risk"
            ))
        return factors

    def _assess_confidence(self, probability: float) -> Confidence:
        if probability < 0.2 or probability > 0.8:
            return Confidence.HIGH
        if 0.4 < probability < 0.6:
            return Confidence.LOW
        return Confidence.MEDIUM

    def find_similar_customers(self, instance: Series, X: DataFrame,
                               y: Series, k: int = 5) -> List[Dict]:
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)
        instance_scaled = scaler.transform(instance.to_frame().T)
        knn = NearestNeighbors(n_neighbors=k + 1, metric="euclidean")
        knn.fit(X_scaled)
        distances, indices = knn.kneighbors(instance_scaled)
        similar = []
        for dist, idx in zip(distances[0][1:], indices[0][1:]):
            similar.append({
                "index": int(idx),
                "distance": float(dist),
                "outcome": int(y.iloc[idx]),
                "features": X.iloc[idx].to_dict()
            })
        return similar

    def explain_batch(self, X: DataFrame,
                      customer_ids: Optional[List[str]] = None) -> List[IndividualExplanation]:
        customer_ids = customer_ids or [None] * len(X)
        return [self.explain(X.iloc[i], customer_ids[i]) for i in range(len(X))]
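
A sketch for IndividualExplainer along the same lines: per-customer SHAP attribution plus the standardized nearest-neighbour lookup. The customer_id and k values are arbitrary placeholders.

    import numpy as np
    import pandas as pd
    from sklearn.ensemble import RandomForestClassifier

    from customer_retention.analysis.interpretability.individual_explainer import IndividualExplainer

    rng = np.random.default_rng(0)
    X = pd.DataFrame({
        "tenure_days": rng.integers(1, 1000, 500).astype(float),
        "monthly_spend": rng.gamma(2.0, 50.0, 500),
        "support_tickets": rng.poisson(2.0, 500).astype(float),
    })
    y = pd.Series((rng.random(500) < 0.7).astype(int))
    model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

    explainer = IndividualExplainer(model, background_data=X)
    explanation = explainer.explain(X.iloc[0], customer_id="cust-0001")
    print(f"churn={explanation.churn_probability:.1%} "
          f"base={explanation.base_value:.3f} confidence={explanation.confidence.value}")
    for factor in explanation.top_positive_factors:
        print(f"  {factor.feature_name}={factor.current_value:.2f} {factor.direction} "
              f"({factor.contribution:+.3f})")
    # Nearest neighbours in standardized feature space, with their observed outcomes.
    for neighbour in explainer.find_similar_customers(X.iloc[0], X, y, k=3):
        print(neighbour["index"], round(neighbour["distance"], 2), neighbour["outcome"])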

customer_retention/analysis/interpretability/pdp_generator.py
@@ -0,0 +1,103 @@
"""Partial Dependence Plot generation."""

from dataclasses import dataclass
from typing import Any, List, Optional

import numpy as np
from sklearn.inspection import partial_dependence

from customer_retention.core.compat import DataFrame


@dataclass
class PDPResult:
    feature_name: str
    grid_values: np.ndarray
    pdp_values: np.ndarray
    feature_min: float
    feature_max: float
    average_prediction: float
    ice_values: Optional[List[np.ndarray]] = None


@dataclass
class InteractionResult:
    feature1_name: str
    feature2_name: str
    grid1_values: np.ndarray
    grid2_values: np.ndarray
    pdp_matrix: np.ndarray


class PDPGenerator:
    def __init__(self, model: Any):
        self.model = model

    def generate(self, X: DataFrame, feature: str, grid_resolution: int = 50,
                 include_ice: bool = False, ice_lines: int = 100) -> PDPResult:
        feature_idx = list(X.columns).index(feature)
        pd_result = partial_dependence(
            self.model, X, [feature_idx], kind="average", grid_resolution=grid_resolution
        )
        grid_values = pd_result["grid_values"][0]
        pdp_values = pd_result["average"][0]
        ice_values = None
        if include_ice:
            ice_values = self._calculate_ice(X, feature, grid_values, ice_lines)
        return PDPResult(
            feature_name=feature,
            grid_values=grid_values,
            pdp_values=pdp_values,
            feature_min=float(X[feature].min()),
            feature_max=float(X[feature].max()),
            average_prediction=float(np.mean(pdp_values)),
            ice_values=ice_values
        )

    def _calculate_ice(self, X: DataFrame, feature: str,
                       grid_values: np.ndarray, n_samples: int) -> List[np.ndarray]:
        sample_indices = np.random.choice(len(X), min(n_samples, len(X)), replace=False)
        ice_lines = []
        for idx in sample_indices:
            X_temp = X.iloc[[idx]].copy()
            predictions = []
            for val in grid_values:
                X_temp[feature] = val
                pred = self.model.predict_proba(X_temp)[0, 1]
                predictions.append(pred)
            ice_lines.append(np.array(predictions))
        return ice_lines

    def generate_multiple(self, X: DataFrame, features: List[str],
                          grid_resolution: int = 50) -> List[PDPResult]:
        return [self.generate(X, feature, grid_resolution) for feature in features]

    def generate_top_features(self, X: DataFrame, n_features: int = 5,
                              grid_resolution: int = 50) -> List[PDPResult]:
        importances = {}
        for feature in X.columns:
            X_shuffled = X.copy()
            X_shuffled[feature] = np.random.permutation(X_shuffled[feature].values)
            original_pred = self.model.predict_proba(X)[:, 1].mean()
            shuffled_pred = self.model.predict_proba(X_shuffled)[:, 1].mean()
            importances[feature] = abs(original_pred - shuffled_pred)
        top_features = sorted(importances.keys(), key=lambda f: importances[f], reverse=True)[:n_features]
        return self.generate_multiple(X, top_features, grid_resolution)

    def generate_interaction(self, X: DataFrame, feature1: str, feature2: str,
                             grid_resolution: int = 20) -> InteractionResult:
        feature1_idx = list(X.columns).index(feature1)
        feature2_idx = list(X.columns).index(feature2)
        pd_result = partial_dependence(
            self.model, X, [(feature1_idx, feature2_idx)], kind="average", grid_resolution=grid_resolution
        )
        grid1 = pd_result["grid_values"][0]
        grid2 = pd_result["grid_values"][1]
        pdp_matrix = pd_result["average"][0]
        return InteractionResult(
            feature1_name=feature1,
            feature2_name=feature2,
            grid1_values=grid1,
            grid2_values=grid2,
            pdp_matrix=pdp_matrix
        )
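
Last, a sketch for PDPGenerator. Note the module reads pd_result["grid_values"], which assumes a scikit-learn version where partial_dependence returns that key (1.3 or later); the feature pair chosen below is illustrative.

    import numpy as np
    import pandas as pd
    from sklearn.ensemble import RandomForestClassifier

    from customer_retention.analysis.interpretability.pdp_generator import PDPGenerator

    rng = np.random.default_rng(0)
    X = pd.DataFrame({
        "tenure_days": rng.integers(1, 1000, 500).astype(float),
        "monthly_spend": rng.gamma(2.0, 50.0, 500),
        "support_tickets": rng.poisson(2.0, 500).astype(float),
    })
    y = pd.Series((rng.random(500) < 0.7).astype(int))
    model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

    pdp = PDPGenerator(model)
    # One-way PDP with 20 ICE curves overlaid.
    result = pdp.generate(X, "monthly_spend", grid_resolution=25, include_ice=True, ice_lines=20)
    print(result.feature_name, f"avg={result.average_prediction:.3f}", len(result.ice_values))

    # Two-way interaction: pdp_matrix holds averaged predictions over the
    # cross product of the two feature grids.
    interaction = pdp.generate_interaction(X, "tenure_days", "monthly_spend", grid_resolution=10)
    print(interaction.pdp_matrix.shape)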