churnkit 0.75.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +647 -0
  2. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +1165 -0
  3. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +961 -0
  4. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +1690 -0
  5. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +679 -0
  6. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +3305 -0
  7. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +1463 -0
  8. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +1430 -0
  9. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +854 -0
  10. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +1639 -0
  11. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +1890 -0
  12. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +1457 -0
  13. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +1624 -0
  14. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +780 -0
  15. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +979 -0
  16. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +572 -0
  17. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +1179 -0
  18. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +1418 -0
  19. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +151 -0
  20. churnkit-0.75.0a1.dist-info/METADATA +229 -0
  21. churnkit-0.75.0a1.dist-info/RECORD +302 -0
  22. churnkit-0.75.0a1.dist-info/WHEEL +4 -0
  23. churnkit-0.75.0a1.dist-info/entry_points.txt +2 -0
  24. churnkit-0.75.0a1.dist-info/licenses/LICENSE +202 -0
  25. customer_retention/__init__.py +37 -0
  26. customer_retention/analysis/__init__.py +0 -0
  27. customer_retention/analysis/auto_explorer/__init__.py +62 -0
  28. customer_retention/analysis/auto_explorer/exploration_manager.py +470 -0
  29. customer_retention/analysis/auto_explorer/explorer.py +258 -0
  30. customer_retention/analysis/auto_explorer/findings.py +291 -0
  31. customer_retention/analysis/auto_explorer/layered_recommendations.py +485 -0
  32. customer_retention/analysis/auto_explorer/recommendation_builder.py +148 -0
  33. customer_retention/analysis/auto_explorer/recommendations.py +418 -0
  34. customer_retention/analysis/business/__init__.py +26 -0
  35. customer_retention/analysis/business/ab_test_designer.py +144 -0
  36. customer_retention/analysis/business/fairness_analyzer.py +166 -0
  37. customer_retention/analysis/business/intervention_matcher.py +121 -0
  38. customer_retention/analysis/business/report_generator.py +222 -0
  39. customer_retention/analysis/business/risk_profile.py +199 -0
  40. customer_retention/analysis/business/roi_analyzer.py +139 -0
  41. customer_retention/analysis/diagnostics/__init__.py +20 -0
  42. customer_retention/analysis/diagnostics/calibration_analyzer.py +133 -0
  43. customer_retention/analysis/diagnostics/cv_analyzer.py +144 -0
  44. customer_retention/analysis/diagnostics/error_analyzer.py +107 -0
  45. customer_retention/analysis/diagnostics/leakage_detector.py +394 -0
  46. customer_retention/analysis/diagnostics/noise_tester.py +140 -0
  47. customer_retention/analysis/diagnostics/overfitting_analyzer.py +190 -0
  48. customer_retention/analysis/diagnostics/segment_analyzer.py +122 -0
  49. customer_retention/analysis/discovery/__init__.py +8 -0
  50. customer_retention/analysis/discovery/config_generator.py +49 -0
  51. customer_retention/analysis/discovery/discovery_flow.py +19 -0
  52. customer_retention/analysis/discovery/type_inferencer.py +147 -0
  53. customer_retention/analysis/interpretability/__init__.py +13 -0
  54. customer_retention/analysis/interpretability/cohort_analyzer.py +185 -0
  55. customer_retention/analysis/interpretability/counterfactual.py +175 -0
  56. customer_retention/analysis/interpretability/individual_explainer.py +141 -0
  57. customer_retention/analysis/interpretability/pdp_generator.py +103 -0
  58. customer_retention/analysis/interpretability/shap_explainer.py +106 -0
  59. customer_retention/analysis/jupyter_save_hook.py +28 -0
  60. customer_retention/analysis/notebook_html_exporter.py +136 -0
  61. customer_retention/analysis/notebook_progress.py +60 -0
  62. customer_retention/analysis/plotly_preprocessor.py +154 -0
  63. customer_retention/analysis/recommendations/__init__.py +54 -0
  64. customer_retention/analysis/recommendations/base.py +158 -0
  65. customer_retention/analysis/recommendations/cleaning/__init__.py +11 -0
  66. customer_retention/analysis/recommendations/cleaning/consistency.py +107 -0
  67. customer_retention/analysis/recommendations/cleaning/deduplicate.py +94 -0
  68. customer_retention/analysis/recommendations/cleaning/impute.py +67 -0
  69. customer_retention/analysis/recommendations/cleaning/outlier.py +71 -0
  70. customer_retention/analysis/recommendations/datetime/__init__.py +3 -0
  71. customer_retention/analysis/recommendations/datetime/extract.py +149 -0
  72. customer_retention/analysis/recommendations/encoding/__init__.py +3 -0
  73. customer_retention/analysis/recommendations/encoding/categorical.py +114 -0
  74. customer_retention/analysis/recommendations/pipeline.py +74 -0
  75. customer_retention/analysis/recommendations/registry.py +76 -0
  76. customer_retention/analysis/recommendations/selection/__init__.py +3 -0
  77. customer_retention/analysis/recommendations/selection/drop_column.py +56 -0
  78. customer_retention/analysis/recommendations/transform/__init__.py +4 -0
  79. customer_retention/analysis/recommendations/transform/power.py +94 -0
  80. customer_retention/analysis/recommendations/transform/scale.py +112 -0
  81. customer_retention/analysis/visualization/__init__.py +15 -0
  82. customer_retention/analysis/visualization/chart_builder.py +2619 -0
  83. customer_retention/analysis/visualization/console.py +122 -0
  84. customer_retention/analysis/visualization/display.py +171 -0
  85. customer_retention/analysis/visualization/number_formatter.py +36 -0
  86. customer_retention/artifacts/__init__.py +3 -0
  87. customer_retention/artifacts/fit_artifact_registry.py +146 -0
  88. customer_retention/cli.py +93 -0
  89. customer_retention/core/__init__.py +0 -0
  90. customer_retention/core/compat/__init__.py +193 -0
  91. customer_retention/core/compat/detection.py +99 -0
  92. customer_retention/core/compat/ops.py +48 -0
  93. customer_retention/core/compat/pandas_backend.py +57 -0
  94. customer_retention/core/compat/spark_backend.py +75 -0
  95. customer_retention/core/components/__init__.py +11 -0
  96. customer_retention/core/components/base.py +79 -0
  97. customer_retention/core/components/components/__init__.py +13 -0
  98. customer_retention/core/components/components/deployer.py +26 -0
  99. customer_retention/core/components/components/explainer.py +26 -0
  100. customer_retention/core/components/components/feature_eng.py +33 -0
  101. customer_retention/core/components/components/ingester.py +34 -0
  102. customer_retention/core/components/components/profiler.py +34 -0
  103. customer_retention/core/components/components/trainer.py +38 -0
  104. customer_retention/core/components/components/transformer.py +36 -0
  105. customer_retention/core/components/components/validator.py +37 -0
  106. customer_retention/core/components/enums.py +33 -0
  107. customer_retention/core/components/orchestrator.py +94 -0
  108. customer_retention/core/components/registry.py +59 -0
  109. customer_retention/core/config/__init__.py +39 -0
  110. customer_retention/core/config/column_config.py +95 -0
  111. customer_retention/core/config/experiments.py +71 -0
  112. customer_retention/core/config/pipeline_config.py +117 -0
  113. customer_retention/core/config/source_config.py +83 -0
  114. customer_retention/core/utils/__init__.py +28 -0
  115. customer_retention/core/utils/leakage.py +85 -0
  116. customer_retention/core/utils/severity.py +53 -0
  117. customer_retention/core/utils/statistics.py +90 -0
  118. customer_retention/generators/__init__.py +0 -0
  119. customer_retention/generators/notebook_generator/__init__.py +167 -0
  120. customer_retention/generators/notebook_generator/base.py +55 -0
  121. customer_retention/generators/notebook_generator/cell_builder.py +49 -0
  122. customer_retention/generators/notebook_generator/config.py +47 -0
  123. customer_retention/generators/notebook_generator/databricks_generator.py +48 -0
  124. customer_retention/generators/notebook_generator/local_generator.py +48 -0
  125. customer_retention/generators/notebook_generator/project_init.py +174 -0
  126. customer_retention/generators/notebook_generator/runner.py +150 -0
  127. customer_retention/generators/notebook_generator/script_generator.py +110 -0
  128. customer_retention/generators/notebook_generator/stages/__init__.py +19 -0
  129. customer_retention/generators/notebook_generator/stages/base_stage.py +86 -0
  130. customer_retention/generators/notebook_generator/stages/s01_ingestion.py +100 -0
  131. customer_retention/generators/notebook_generator/stages/s02_profiling.py +95 -0
  132. customer_retention/generators/notebook_generator/stages/s03_cleaning.py +180 -0
  133. customer_retention/generators/notebook_generator/stages/s04_transformation.py +165 -0
  134. customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py +115 -0
  135. customer_retention/generators/notebook_generator/stages/s06_feature_selection.py +97 -0
  136. customer_retention/generators/notebook_generator/stages/s07_model_training.py +176 -0
  137. customer_retention/generators/notebook_generator/stages/s08_deployment.py +81 -0
  138. customer_retention/generators/notebook_generator/stages/s09_monitoring.py +112 -0
  139. customer_retention/generators/notebook_generator/stages/s10_batch_inference.py +642 -0
  140. customer_retention/generators/notebook_generator/stages/s11_feature_store.py +348 -0
  141. customer_retention/generators/orchestration/__init__.py +23 -0
  142. customer_retention/generators/orchestration/code_generator.py +196 -0
  143. customer_retention/generators/orchestration/context.py +147 -0
  144. customer_retention/generators/orchestration/data_materializer.py +188 -0
  145. customer_retention/generators/orchestration/databricks_exporter.py +411 -0
  146. customer_retention/generators/orchestration/doc_generator.py +311 -0
  147. customer_retention/generators/pipeline_generator/__init__.py +26 -0
  148. customer_retention/generators/pipeline_generator/findings_parser.py +727 -0
  149. customer_retention/generators/pipeline_generator/generator.py +142 -0
  150. customer_retention/generators/pipeline_generator/models.py +166 -0
  151. customer_retention/generators/pipeline_generator/renderer.py +2125 -0
  152. customer_retention/generators/spec_generator/__init__.py +37 -0
  153. customer_retention/generators/spec_generator/databricks_generator.py +433 -0
  154. customer_retention/generators/spec_generator/generic_generator.py +373 -0
  155. customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +685 -0
  156. customer_retention/generators/spec_generator/pipeline_spec.py +298 -0
  157. customer_retention/integrations/__init__.py +0 -0
  158. customer_retention/integrations/adapters/__init__.py +13 -0
  159. customer_retention/integrations/adapters/base.py +10 -0
  160. customer_retention/integrations/adapters/factory.py +25 -0
  161. customer_retention/integrations/adapters/feature_store/__init__.py +6 -0
  162. customer_retention/integrations/adapters/feature_store/base.py +57 -0
  163. customer_retention/integrations/adapters/feature_store/databricks.py +94 -0
  164. customer_retention/integrations/adapters/feature_store/feast_adapter.py +97 -0
  165. customer_retention/integrations/adapters/feature_store/local.py +75 -0
  166. customer_retention/integrations/adapters/mlflow/__init__.py +6 -0
  167. customer_retention/integrations/adapters/mlflow/base.py +32 -0
  168. customer_retention/integrations/adapters/mlflow/databricks.py +54 -0
  169. customer_retention/integrations/adapters/mlflow/experiment_tracker.py +161 -0
  170. customer_retention/integrations/adapters/mlflow/local.py +50 -0
  171. customer_retention/integrations/adapters/storage/__init__.py +5 -0
  172. customer_retention/integrations/adapters/storage/base.py +33 -0
  173. customer_retention/integrations/adapters/storage/databricks.py +76 -0
  174. customer_retention/integrations/adapters/storage/local.py +59 -0
  175. customer_retention/integrations/feature_store/__init__.py +47 -0
  176. customer_retention/integrations/feature_store/definitions.py +215 -0
  177. customer_retention/integrations/feature_store/manager.py +744 -0
  178. customer_retention/integrations/feature_store/registry.py +412 -0
  179. customer_retention/integrations/iteration/__init__.py +28 -0
  180. customer_retention/integrations/iteration/context.py +212 -0
  181. customer_retention/integrations/iteration/feedback_collector.py +184 -0
  182. customer_retention/integrations/iteration/orchestrator.py +168 -0
  183. customer_retention/integrations/iteration/recommendation_tracker.py +341 -0
  184. customer_retention/integrations/iteration/signals.py +212 -0
  185. customer_retention/integrations/llm_context/__init__.py +4 -0
  186. customer_retention/integrations/llm_context/context_builder.py +201 -0
  187. customer_retention/integrations/llm_context/prompts.py +100 -0
  188. customer_retention/integrations/streaming/__init__.py +103 -0
  189. customer_retention/integrations/streaming/batch_integration.py +149 -0
  190. customer_retention/integrations/streaming/early_warning_model.py +227 -0
  191. customer_retention/integrations/streaming/event_schema.py +214 -0
  192. customer_retention/integrations/streaming/online_store_writer.py +249 -0
  193. customer_retention/integrations/streaming/realtime_scorer.py +261 -0
  194. customer_retention/integrations/streaming/trigger_engine.py +293 -0
  195. customer_retention/integrations/streaming/window_aggregator.py +393 -0
  196. customer_retention/stages/__init__.py +0 -0
  197. customer_retention/stages/cleaning/__init__.py +9 -0
  198. customer_retention/stages/cleaning/base.py +28 -0
  199. customer_retention/stages/cleaning/missing_handler.py +160 -0
  200. customer_retention/stages/cleaning/outlier_handler.py +204 -0
  201. customer_retention/stages/deployment/__init__.py +28 -0
  202. customer_retention/stages/deployment/batch_scorer.py +106 -0
  203. customer_retention/stages/deployment/champion_challenger.py +299 -0
  204. customer_retention/stages/deployment/model_registry.py +182 -0
  205. customer_retention/stages/deployment/retraining_trigger.py +245 -0
  206. customer_retention/stages/features/__init__.py +73 -0
  207. customer_retention/stages/features/behavioral_features.py +266 -0
  208. customer_retention/stages/features/customer_segmentation.py +505 -0
  209. customer_retention/stages/features/feature_definitions.py +265 -0
  210. customer_retention/stages/features/feature_engineer.py +551 -0
  211. customer_retention/stages/features/feature_manifest.py +340 -0
  212. customer_retention/stages/features/feature_selector.py +239 -0
  213. customer_retention/stages/features/interaction_features.py +160 -0
  214. customer_retention/stages/features/temporal_features.py +243 -0
  215. customer_retention/stages/ingestion/__init__.py +9 -0
  216. customer_retention/stages/ingestion/load_result.py +32 -0
  217. customer_retention/stages/ingestion/loaders.py +195 -0
  218. customer_retention/stages/ingestion/source_registry.py +130 -0
  219. customer_retention/stages/modeling/__init__.py +31 -0
  220. customer_retention/stages/modeling/baseline_trainer.py +139 -0
  221. customer_retention/stages/modeling/cross_validator.py +125 -0
  222. customer_retention/stages/modeling/data_splitter.py +205 -0
  223. customer_retention/stages/modeling/feature_scaler.py +99 -0
  224. customer_retention/stages/modeling/hyperparameter_tuner.py +107 -0
  225. customer_retention/stages/modeling/imbalance_handler.py +282 -0
  226. customer_retention/stages/modeling/mlflow_logger.py +95 -0
  227. customer_retention/stages/modeling/model_comparator.py +149 -0
  228. customer_retention/stages/modeling/model_evaluator.py +138 -0
  229. customer_retention/stages/modeling/threshold_optimizer.py +131 -0
  230. customer_retention/stages/monitoring/__init__.py +37 -0
  231. customer_retention/stages/monitoring/alert_manager.py +328 -0
  232. customer_retention/stages/monitoring/drift_detector.py +201 -0
  233. customer_retention/stages/monitoring/performance_monitor.py +242 -0
  234. customer_retention/stages/preprocessing/__init__.py +5 -0
  235. customer_retention/stages/preprocessing/transformer_manager.py +284 -0
  236. customer_retention/stages/profiling/__init__.py +256 -0
  237. customer_retention/stages/profiling/categorical_distribution.py +269 -0
  238. customer_retention/stages/profiling/categorical_target_analyzer.py +274 -0
  239. customer_retention/stages/profiling/column_profiler.py +527 -0
  240. customer_retention/stages/profiling/distribution_analysis.py +483 -0
  241. customer_retention/stages/profiling/drift_detector.py +310 -0
  242. customer_retention/stages/profiling/feature_capacity.py +507 -0
  243. customer_retention/stages/profiling/pattern_analysis_config.py +513 -0
  244. customer_retention/stages/profiling/profile_result.py +212 -0
  245. customer_retention/stages/profiling/quality_checks.py +1632 -0
  246. customer_retention/stages/profiling/relationship_detector.py +256 -0
  247. customer_retention/stages/profiling/relationship_recommender.py +454 -0
  248. customer_retention/stages/profiling/report_generator.py +520 -0
  249. customer_retention/stages/profiling/scd_analyzer.py +151 -0
  250. customer_retention/stages/profiling/segment_analyzer.py +632 -0
  251. customer_retention/stages/profiling/segment_aware_outlier.py +265 -0
  252. customer_retention/stages/profiling/target_level_analyzer.py +217 -0
  253. customer_retention/stages/profiling/temporal_analyzer.py +388 -0
  254. customer_retention/stages/profiling/temporal_coverage.py +488 -0
  255. customer_retention/stages/profiling/temporal_feature_analyzer.py +692 -0
  256. customer_retention/stages/profiling/temporal_feature_engineer.py +703 -0
  257. customer_retention/stages/profiling/temporal_pattern_analyzer.py +636 -0
  258. customer_retention/stages/profiling/temporal_quality_checks.py +278 -0
  259. customer_retention/stages/profiling/temporal_target_analyzer.py +241 -0
  260. customer_retention/stages/profiling/text_embedder.py +87 -0
  261. customer_retention/stages/profiling/text_processor.py +115 -0
  262. customer_retention/stages/profiling/text_reducer.py +60 -0
  263. customer_retention/stages/profiling/time_series_profiler.py +303 -0
  264. customer_retention/stages/profiling/time_window_aggregator.py +376 -0
  265. customer_retention/stages/profiling/type_detector.py +382 -0
  266. customer_retention/stages/profiling/window_recommendation.py +288 -0
  267. customer_retention/stages/temporal/__init__.py +166 -0
  268. customer_retention/stages/temporal/access_guard.py +180 -0
  269. customer_retention/stages/temporal/cutoff_analyzer.py +235 -0
  270. customer_retention/stages/temporal/data_preparer.py +178 -0
  271. customer_retention/stages/temporal/point_in_time_join.py +134 -0
  272. customer_retention/stages/temporal/point_in_time_registry.py +148 -0
  273. customer_retention/stages/temporal/scenario_detector.py +163 -0
  274. customer_retention/stages/temporal/snapshot_manager.py +259 -0
  275. customer_retention/stages/temporal/synthetic_coordinator.py +66 -0
  276. customer_retention/stages/temporal/timestamp_discovery.py +531 -0
  277. customer_retention/stages/temporal/timestamp_manager.py +255 -0
  278. customer_retention/stages/transformation/__init__.py +13 -0
  279. customer_retention/stages/transformation/binary_handler.py +85 -0
  280. customer_retention/stages/transformation/categorical_encoder.py +245 -0
  281. customer_retention/stages/transformation/datetime_transformer.py +97 -0
  282. customer_retention/stages/transformation/numeric_transformer.py +181 -0
  283. customer_retention/stages/transformation/pipeline.py +257 -0
  284. customer_retention/stages/validation/__init__.py +60 -0
  285. customer_retention/stages/validation/adversarial_scoring_validator.py +205 -0
  286. customer_retention/stages/validation/business_sense_gate.py +173 -0
  287. customer_retention/stages/validation/data_quality_gate.py +235 -0
  288. customer_retention/stages/validation/data_validators.py +511 -0
  289. customer_retention/stages/validation/feature_quality_gate.py +183 -0
  290. customer_retention/stages/validation/gates.py +117 -0
  291. customer_retention/stages/validation/leakage_gate.py +352 -0
  292. customer_retention/stages/validation/model_validity_gate.py +213 -0
  293. customer_retention/stages/validation/pipeline_validation_runner.py +264 -0
  294. customer_retention/stages/validation/quality_scorer.py +544 -0
  295. customer_retention/stages/validation/rule_generator.py +57 -0
  296. customer_retention/stages/validation/scoring_pipeline_validator.py +446 -0
  297. customer_retention/stages/validation/timeseries_detector.py +769 -0
  298. customer_retention/transforms/__init__.py +47 -0
  299. customer_retention/transforms/artifact_store.py +50 -0
  300. customer_retention/transforms/executor.py +157 -0
  301. customer_retention/transforms/fitted.py +92 -0
  302. customer_retention/transforms/ops.py +148 -0
@@ -0,0 +1,185 @@
1
+ """Cohort-level interpretability analysis."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, List
5
+
6
+ import numpy as np
7
+ import shap
8
+
9
+ from customer_retention.core.compat import DataFrame, Series, pd
10
+
11
+
12
+ @dataclass
13
+ class CohortInsight:
14
+ cohort_name: str
15
+ cohort_size: int
16
+ cohort_percentage: float
17
+ churn_rate: float
18
+ top_features: List[Dict[str, float]]
19
+ key_differentiators: List[str] = field(default_factory=list)
20
+ recommended_strategy: str = ""
21
+
22
+
23
+ @dataclass
24
+ class CohortComparison:
25
+ cohort_a: str
26
+ cohort_b: str
27
+ feature_differences: Dict[str, float]
28
+ churn_rate_difference: float
29
+ key_differences: List[str] = field(default_factory=list)
30
+
31
+
32
+ @dataclass
33
+ class CohortAnalysisResult:
34
+ cohort_insights: List[CohortInsight]
35
+ key_differences: List[str]
36
+ overall_summary: str = ""
37
+
38
+
39
+ class CohortAnalyzer:
40
+ def __init__(self, model: Any, background_data: DataFrame, max_samples: int = 100):
41
+ self.model = model
42
+ self.background_data = background_data.head(max_samples)
43
+ self._explainer = self._create_explainer()
44
+
45
+ def _create_explainer(self) -> shap.Explainer:
46
+ model_type = type(self.model).__name__
47
+ if model_type in ["RandomForestClassifier", "GradientBoostingClassifier"]:
48
+ return shap.TreeExplainer(self.model)
49
+ return shap.KernelExplainer(self.model.predict_proba, self.background_data)
50
+
51
+ def analyze(self, X: DataFrame, y: Series, cohorts: Series) -> CohortAnalysisResult:
52
+ unique_cohorts = cohorts.unique()
53
+ insights = []
54
+ all_features_by_cohort = {}
55
+ for cohort in unique_cohorts:
56
+ mask = cohorts == cohort
57
+ cohort_X = X[mask]
58
+ cohort_y = y[mask]
59
+ churn_rate = float(1 - cohort_y.mean())
60
+ top_features = self._get_cohort_feature_importance(cohort_X)
61
+ all_features_by_cohort[cohort] = top_features
62
+ strategy = self._generate_strategy(cohort, churn_rate, top_features)
63
+ insights.append(CohortInsight(
64
+ cohort_name=cohort,
65
+ cohort_size=len(cohort_X),
66
+ cohort_percentage=len(cohort_X) / len(X),
67
+ churn_rate=churn_rate,
68
+ top_features=top_features,
69
+ recommended_strategy=strategy
70
+ ))
71
+ key_differences = self._identify_key_differences(all_features_by_cohort, insights)
72
+ for insight in insights:
73
+ insight.key_differentiators = self._get_differentiators(insight.cohort_name, all_features_by_cohort)
74
+ return CohortAnalysisResult(
75
+ cohort_insights=insights,
76
+ key_differences=key_differences
77
+ )
78
+
79
+ def _get_cohort_feature_importance(self, cohort_X: DataFrame) -> List[Dict[str, float]]:
80
+ if len(cohort_X) == 0:
81
+ return []
82
+ sample = cohort_X.head(min(50, len(cohort_X)))
83
+ shap_values = self._extract_shap_values(sample)
84
+ mean_abs_shap = np.abs(shap_values).mean(axis=0)
85
+ sorted_indices = np.argsort(mean_abs_shap)[::-1][:5]
86
+ result = []
87
+ for idx in sorted_indices:
88
+ importance_val = mean_abs_shap[idx]
89
+ if hasattr(importance_val, '__len__') and len(importance_val) == 1:
90
+ importance_val = importance_val[0]
91
+ result.append({"feature": cohort_X.columns[idx], "importance": float(importance_val)})
92
+ return result
93
+
94
+ def _extract_shap_values(self, X: DataFrame) -> np.ndarray:
95
+ shap_values = self._explainer.shap_values(X)
96
+ if hasattr(shap_values, 'values'):
97
+ shap_values = shap_values.values
98
+ if isinstance(shap_values, list):
99
+ shap_values = shap_values[1]
100
+ if len(shap_values.shape) == 3:
101
+ shap_values = shap_values[:, :, 1]
102
+ return shap_values
103
+
104
+ def _generate_strategy(self, cohort: str, churn_rate: float,
105
+ top_features: List[Dict[str, float]]) -> str:
106
+ if churn_rate > 0.5:
107
+ priority = "urgent intervention"
108
+ elif churn_rate > 0.3:
109
+ priority = "proactive engagement"
110
+ else:
111
+ priority = "standard nurturing"
112
+ top_feature = top_features[0]["feature"] if top_features else "engagement"
113
+ return f"Focus on {top_feature} with {priority} for {cohort} cohort"
114
+
115
+ def _identify_key_differences(self, features_by_cohort: Dict[str, List[Dict[str, float]]],
116
+ insights: List[CohortInsight]) -> List[str]:
117
+ differences = []
118
+ churn_rates = {i.cohort_name: i.churn_rate for i in insights}
119
+ if churn_rates:
120
+ max_cohort = max(churn_rates, key=churn_rates.get)
121
+ min_cohort = min(churn_rates, key=churn_rates.get)
122
+ diff = churn_rates[max_cohort] - churn_rates[min_cohort]
123
+ differences.append(f"{max_cohort} has {diff:.1%} higher churn than {min_cohort}")
124
+ for cohort, features in features_by_cohort.items():
125
+ if features:
126
+ top = features[0]["feature"]
127
+ differences.append(f"{cohort}: top driver is {top}")
128
+ return differences
129
+
130
+ def _get_differentiators(self, cohort: str,
131
+ features_by_cohort: Dict[str, List[Dict[str, float]]]) -> List[str]:
132
+ cohort_features = features_by_cohort.get(cohort, [])
133
+ cohort_top = set(f["feature"] for f in cohort_features[:3])
134
+ other_tops = set()
135
+ for other, features in features_by_cohort.items():
136
+ if other != cohort:
137
+ other_tops.update(f["feature"] for f in features[:3])
138
+ unique = cohort_top - other_tops
139
+ return [f"{cohort} uniquely driven by {f}" for f in unique]
140
+
141
+ def compare_cohorts(self, X: DataFrame, y: Series, cohorts: Series,
142
+ cohort_a: str, cohort_b: str) -> CohortComparison:
143
+ mask_a = cohorts == cohort_a
144
+ mask_b = cohorts == cohort_b
145
+ churn_a = 1 - y[mask_a].mean()
146
+ churn_b = 1 - y[mask_b].mean()
147
+ feature_diffs = {}
148
+ for col in X.columns:
149
+ mean_a = X.loc[mask_a, col].mean()
150
+ mean_b = X.loc[mask_b, col].mean()
151
+ feature_diffs[col] = float(mean_a - mean_b)
152
+ key_diffs = []
153
+ sorted_diffs = sorted(feature_diffs.items(), key=lambda x: abs(x[1]), reverse=True)
154
+ for feature, diff in sorted_diffs[:3]:
155
+ direction = "higher" if diff > 0 else "lower"
156
+ key_diffs.append(f"{cohort_a} has {direction} {feature} than {cohort_b}")
157
+ return CohortComparison(
158
+ cohort_a=cohort_a,
159
+ cohort_b=cohort_b,
160
+ feature_differences=feature_diffs,
161
+ churn_rate_difference=float(churn_a - churn_b),
162
+ key_differences=key_diffs
163
+ )
164
+
165
+ @staticmethod
166
+ def create_tenure_cohorts(tenure: Series,
167
+ bins: List[float] = None) -> Series:
168
+ bins = bins or [0, 90, 365, float("inf")]
169
+ labels = ["New", "Established", "Mature"]
170
+ return pd.cut(tenure, bins=bins, labels=labels)
171
+
172
+ @staticmethod
173
+ def create_value_cohorts(value: Series,
174
+ quantiles: List[float] = None) -> Series:
175
+ quantiles = quantiles or [0.33, 0.66]
176
+ q1, q2 = value.quantile(quantiles[0]), value.quantile(quantiles[1])
177
+ return pd.cut(value, bins=[-float("inf"), q1, q2, float("inf")],
178
+ labels=["Low", "Medium", "High"])
179
+
180
+ @staticmethod
181
+ def create_activity_cohorts(activity: Series,
182
+ thresholds: List[float] = None) -> Series:
183
+ thresholds = thresholds or [5, 15]
184
+ return pd.cut(activity, bins=[-float("inf"), thresholds[0], thresholds[1], float("inf")],
185
+ labels=["Dormant", "Moderate", "Active"])
@@ -0,0 +1,175 @@
1
+ """Counterfactual explanation generation."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import numpy as np
7
+
8
+ from customer_retention.core.compat import DataFrame, Series
9
+
10
+
11
@dataclass
class CounterfactualChange:
    """A single feature edit that moves an instance toward the target class."""
    feature_name: str
    original_value: float  # value before perturbation
    new_value: float  # value after perturbation (clipped to feature bounds)
    change_magnitude: float  # abs(new_value - original_value)
17
+
18
+
19
@dataclass
class Counterfactual:
    """Result of a counterfactual search for one customer instance."""
    original_prediction: float  # model P(class 1) for the unmodified instance
    counterfactual_prediction: float  # P(class 1) after applying `changes`
    changes: List[CounterfactualChange]  # feature edits, one per perturbed feature
    feasibility_score: float  # 1 - mean range-normalized change, in [0, 1]
    business_interpretation: str  # human-readable summary of the top changes
26
+
27
+
28
class CounterfactualGenerator:
    """Random-search counterfactual generator for a binary classifier.

    Starting from one customer row, the search repeatedly nudges a single
    actionable feature at a time and greedily keeps any candidate whose
    class-1 probability (``model.predict_proba``) moves toward the requested
    target class.
    """

    def __init__(self, model: Any, reference_data: DataFrame,
                 actionable_features: Optional[List[str]] = None,
                 constraints: Optional[Dict[str, Dict[str, float]]] = None):
        """Store the model and derive per-feature bounds from reference data.

        Args:
            model: fitted classifier exposing ``predict_proba``.
            reference_data: rows used to derive min/max/mean/std per column.
            actionable_features: columns the search may change (default: all).
            constraints: optional per-feature ``{"min": ..., "max": ...}``
                overrides applied on top of the observed bounds.
        """
        self.model = model
        self.reference_data = reference_data
        self.actionable_features = actionable_features or list(reference_data.columns)
        self.constraints = constraints or {}
        self._feature_bounds = self._calculate_bounds()

    def _calculate_bounds(self) -> Dict[str, Dict[str, float]]:
        """Observed min/max/mean/std for every reference-data column."""
        bounds = {}
        for col in self.reference_data.columns:
            bounds[col] = {
                "min": float(self.reference_data[col].min()),
                "max": float(self.reference_data[col].max()),
                "mean": float(self.reference_data[col].mean()),
                "std": float(self.reference_data[col].std())
            }
        return bounds

    def generate(self, instance: Series, target_class: int = 0,
                 max_iterations: int = 100) -> Counterfactual:
        """Greedy random search for a counterfactual of ``instance``.

        target_class 0 tries to push P(class 1) below 0.3; target_class 1
        tries to push it above 0.7. The loop stops early once that
        threshold is crossed, otherwise runs ``max_iterations`` perturbations.

        NOTE: draws from the global ``np.random`` state, so results are
        nondeterministic unless the caller seeds it.
        """
        instance_df = instance.to_frame().T
        original_pred = float(self.model.predict_proba(instance_df)[0, 1])
        best_cf = instance.copy()
        best_pred = original_pred
        best_changes = []
        # Early-stop threshold for the greedy search (see docstring).
        target_pred = 0.3 if target_class == 0 else 0.7
        for _ in range(max_iterations):
            candidate = self._perturb_instance(instance, best_cf)
            candidate_df = candidate.to_frame().T
            pred = float(self.model.predict_proba(candidate_df)[0, 1])
            # Accept any move in the right direction, however small.
            improved = (target_class == 0 and pred < best_pred) or (target_class == 1 and pred > best_pred)
            if improved:
                best_cf = candidate
                best_pred = pred
                best_changes = self._compute_changes(instance, best_cf)
            if (target_class == 0 and best_pred < target_pred) or (target_class == 1 and best_pred > target_pred):
                break
        feasibility = self._calculate_feasibility(instance, best_cf)
        interpretation = self._generate_interpretation(best_changes, original_pred, best_pred)
        return Counterfactual(
            original_prediction=original_pred,
            counterfactual_prediction=best_pred,
            changes=best_changes,
            feasibility_score=feasibility,
            business_interpretation=interpretation
        )

    def _perturb_instance(self, original: Series, current: Series) -> Series:
        """Return a copy of ``current`` with one random actionable feature nudged.

        The step is ~10% of the feature's allowed range, scaled by
        U(0.5, 1.5) in a random direction, then clipped to the bounds.
        NOTE(review): ``original`` is accepted but never read here.
        """
        candidate = current.copy()
        feature = np.random.choice(self.actionable_features)
        bounds = self._get_feature_bounds(feature)
        current_val = candidate[feature]
        step = (bounds["max"] - bounds["min"]) * 0.1
        direction = np.random.choice([-1, 1])
        new_val = current_val + direction * step * np.random.uniform(0.5, 1.5)
        new_val = np.clip(new_val, bounds["min"], bounds["max"])
        candidate[feature] = new_val
        return candidate

    def _get_feature_bounds(self, feature: str) -> Dict[str, float]:
        """Effective bounds: caller-supplied constraints override observed ones."""
        if feature in self.constraints:
            constraint = self.constraints[feature]
            return {
                "min": constraint.get("min", self._feature_bounds[feature]["min"]),
                "max": constraint.get("max", self._feature_bounds[feature]["max"])
            }
        return self._feature_bounds[feature]

    def _compute_changes(self, original: Series, counterfactual: Series) -> List[CounterfactualChange]:
        """Diff the two rows, recording features that moved by more than 1e-6."""
        changes = []
        for feature in self.actionable_features:
            if abs(original[feature] - counterfactual[feature]) > 1e-6:
                changes.append(CounterfactualChange(
                    feature_name=feature,
                    original_value=float(original[feature]),
                    new_value=float(counterfactual[feature]),
                    change_magnitude=float(abs(original[feature] - counterfactual[feature]))
                ))
        return changes

    def _calculate_feasibility(self, original: Series, counterfactual: Series) -> float:
        """One minus the mean range-normalized change, clamped to [0, 1].

        Features with zero observed range are skipped; if no feature has a
        usable range the counterfactual is trivially feasible (1.0).
        """
        total_change = 0
        max_change = 0
        for feature in self.actionable_features:
            bounds = self._feature_bounds[feature]
            range_size = bounds["max"] - bounds["min"]
            if range_size > 0:
                normalized_change = abs(original[feature] - counterfactual[feature]) / range_size
                total_change += normalized_change
                max_change += 1
        if max_change == 0:
            return 1.0
        feasibility = 1 - (total_change / max_change)
        return max(0.0, min(1.0, feasibility))

    def _generate_interpretation(self, changes: List[CounterfactualChange],
                                 original_pred: float, new_pred: float) -> str:
        """One-sentence summary covering at most the first three changes.

        NOTE(review): the wording is hard-coded as "reduce churn risk" even
        when the search targeted class 1 — confirm the intended copy.
        """
        if not changes:
            return "No changes needed to achieve target prediction."
        change_strs = []
        for c in changes[:3]:
            direction = "increase" if c.new_value > c.original_value else "decrease"
            change_strs.append(f"{direction} {c.feature_name} from {c.original_value:.2f} to {c.new_value:.2f}")
        changes_text = ", ".join(change_strs)
        return f"To reduce churn risk from {original_pred:.1%} to {new_pred:.1%}: {changes_text}"

    def generate_diverse(self, instance: Series, n: int = 3) -> List[Counterfactual]:
        """Generate ``n`` counterfactuals that favor previously unused features.

        Each round restricts the search to features no earlier counterfactual
        touched; once every actionable feature has been used, the full set is
        reopened.
        """
        counterfactuals = []
        used_features = set()
        for _ in range(n):
            available = [f for f in self.actionable_features if f not in used_features]
            if not available:
                available = self.actionable_features
            temp_generator = CounterfactualGenerator(
                self.model, self.reference_data,
                actionable_features=available,
                constraints=self.constraints
            )
            cf = temp_generator.generate(instance)
            counterfactuals.append(cf)
            for change in cf.changes:
                used_features.add(change.feature_name)
        return counterfactuals

    def generate_prototype(self, instance: Series, prototype_data: DataFrame) -> Counterfactual:
        """Move each actionable feature halfway toward the prototype mean.

        The prototype is the column-wise mean of ``prototype_data``
        (presumably a sample of desirable/retained customers — confirm with
        callers), clipped to the per-feature bounds before interpolation.
        """
        instance_df = instance.to_frame().T
        original_pred = float(self.model.predict_proba(instance_df)[0, 1])
        prototype = prototype_data.mean()
        best_cf = instance.copy()
        for feature in self.actionable_features:
            bounds = self._get_feature_bounds(feature)
            target_val = np.clip(prototype[feature], bounds["min"], bounds["max"])
            # 50% interpolation toward the (clipped) prototype value.
            best_cf[feature] = instance[feature] + 0.5 * (target_val - instance[feature])
        cf_df = best_cf.to_frame().T
        new_pred = float(self.model.predict_proba(cf_df)[0, 1])
        changes = self._compute_changes(instance, best_cf)
        feasibility = self._calculate_feasibility(instance, best_cf)
        interpretation = self._generate_interpretation(changes, original_pred, new_pred)
        return Counterfactual(
            original_prediction=original_pred,
            counterfactual_prediction=new_pred,
            changes=changes,
            feasibility_score=feasibility,
            business_interpretation=interpretation
        )
@@ -0,0 +1,141 @@
1
+ """Individual customer explanation."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from enum import Enum
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ import numpy as np
8
+ import shap
9
+ from sklearn.neighbors import NearestNeighbors
10
+ from sklearn.preprocessing import StandardScaler
11
+
12
+ from customer_retention.core.compat import DataFrame, Series
13
+
14
+
15
class Confidence(Enum):
    """Qualitative confidence band derived from the predicted probability
    (assigned by IndividualExplainer._assess_confidence)."""
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
19
+
20
+
21
@dataclass
class RiskContribution:
    """One feature's SHAP contribution to a single prediction."""
    feature_name: str
    contribution: float  # signed SHAP value; positive pushes toward class 1 (churn)
    current_value: float  # the customer's actual value for this feature
    direction: str  # "increases risk" or "decreases risk"
27
+
28
+
29
@dataclass
class IndividualExplanation:
    """SHAP-based explanation of one customer's churn prediction."""
    customer_id: Optional[str]  # caller-supplied identifier, if any
    churn_probability: float  # model P(class 1)
    base_value: float  # explainer expected value (class-1 entry when per-class)
    shap_values: np.ndarray  # flat per-feature SHAP vector
    top_positive_factors: List[RiskContribution]  # strongest risk-increasing features
    top_negative_factors: List[RiskContribution]  # strongest risk-decreasing features
    confidence: Confidence  # band derived from the predicted probability
    feature_names: List[str] = field(default_factory=list)
39
+
40
+
41
class IndividualExplainer:
    """SHAP-based explainer for individual churn predictions.

    Builds an appropriate SHAP explainer for the model and exposes
    per-customer explanations, nearest-neighbor lookups, and batch mode.
    """

    def __init__(self, model: Any, background_data: DataFrame, max_samples: int = 100):
        """Args:
            model: fitted classifier exposing ``predict_proba``.
            background_data: reference rows for the SHAP explainer; only the
                first ``max_samples`` rows are kept (caps KernelExplainer cost).
        """
        self.model = model
        self.background_data = background_data.head(max_samples)
        self.feature_names = list(background_data.columns)
        self._explainer = self._create_explainer()

    def _create_explainer(self) -> shap.Explainer:
        """Pick a SHAP explainer by the model's class name.

        Known tree ensembles get TreeExplainer, known linear models get
        LinearExplainer; anything else falls back to the model-agnostic
        (and slowest) KernelExplainer over ``predict_proba``.
        """
        model_type = type(self.model).__name__
        if model_type in ["RandomForestClassifier", "GradientBoostingClassifier"]:
            return shap.TreeExplainer(self.model)
        if model_type in ["LogisticRegression", "LinearRegression"]:
            return shap.LinearExplainer(self.model, self.background_data)
        return shap.KernelExplainer(self.model.predict_proba, self.background_data)

    def explain(self, instance: Series, customer_id: Optional[str] = None,
                top_n: int = 3) -> IndividualExplanation:
        """Explain one customer row.

        Returns an IndividualExplanation with the churn probability, SHAP
        base value, the full SHAP vector, and the ``top_n`` strongest
        factors in each direction.
        """
        instance_df = instance.to_frame().T
        shap_values = self._extract_shap_values(instance_df)
        churn_prob = float(self.model.predict_proba(instance_df)[0, 1])
        expected_value = self._get_expected_value()
        positive_factors = self._extract_factors(instance, shap_values, top_n, positive=True)
        negative_factors = self._extract_factors(instance, shap_values, top_n, positive=False)
        confidence = self._assess_confidence(churn_prob)
        return IndividualExplanation(
            customer_id=customer_id,
            churn_probability=churn_prob,
            base_value=float(expected_value),
            shap_values=shap_values,
            top_positive_factors=positive_factors,
            top_negative_factors=negative_factors,
            confidence=confidence,
            feature_names=self.feature_names
        )

    def _extract_shap_values(self, X: DataFrame) -> np.ndarray:
        """Normalize SHAP output to a flat per-feature vector.

        Handles Explanation objects (``.values``), the per-class list some
        explainers return (class-1 entry taken), and 3-D
        (samples, features, classes) arrays (class-1 slice taken).
        """
        shap_values = self._explainer.shap_values(X)
        if hasattr(shap_values, 'values'):
            shap_values = shap_values.values
        if isinstance(shap_values, list):
            shap_values = shap_values[1]
        if len(shap_values.shape) == 3:
            shap_values = shap_values[:, :, 1]
        return shap_values.flatten()

    def _get_expected_value(self) -> float:
        """Scalar SHAP base value; picks the class-1 entry when per-class."""
        expected_value = self._explainer.expected_value
        if hasattr(expected_value, '__len__'):
            if len(expected_value) > 1:
                return float(expected_value[1])
            return float(expected_value[0])
        return float(expected_value)

    def _extract_factors(self, instance: Series, shap_values: np.ndarray,
                         top_n: int, positive: bool) -> List[RiskContribution]:
        """Top-n contributions of the requested sign, ordered by magnitude."""
        if positive:
            indices = np.argsort(shap_values)[::-1]
            values = [(i, shap_values[i]) for i in indices if shap_values[i] > 0]
        else:
            indices = np.argsort(shap_values)
            values = [(i, shap_values[i]) for i in indices if shap_values[i] < 0]
        factors = []
        for idx, contrib in values[:top_n]:
            feature_name = self.feature_names[idx]
            factors.append(RiskContribution(
                feature_name=feature_name,
                contribution=float(contrib),
                current_value=float(instance[feature_name]),
                direction="increases risk" if contrib > 0 else "decreases risk"
            ))
        return factors

    def _assess_confidence(self, probability: float) -> Confidence:
        """Band the probability: <0.2 or >0.8 high, 0.4-0.6 low, else medium."""
        if probability < 0.2 or probability > 0.8:
            return Confidence.HIGH
        if 0.4 < probability < 0.6:
            return Confidence.LOW
        return Confidence.MEDIUM

    def find_similar_customers(self, instance: Series, X: DataFrame,
                               y: Series, k: int = 5) -> List[Dict]:
        """k nearest neighbors of ``instance`` in standardized feature space.

        Fits k+1 neighbors and drops the closest match — this assumes the
        query row itself is present in X; NOTE(review): if it is not, a
        genuine neighbor is discarded. Confirm with callers.
        """
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)
        instance_scaled = scaler.transform(instance.to_frame().T)
        knn = NearestNeighbors(n_neighbors=k + 1, metric="euclidean")
        knn.fit(X_scaled)
        distances, indices = knn.kneighbors(instance_scaled)
        similar = []
        for dist, idx in zip(distances[0][1:], indices[0][1:]):
            similar.append({
                "index": int(idx),
                "distance": float(dist),
                "outcome": int(y.iloc[idx]),
                "features": X.iloc[idx].to_dict()
            })
        return similar

    def explain_batch(self, X: DataFrame,
                      customer_ids: Optional[List[str]] = None) -> List[IndividualExplanation]:
        """Explain each row of X; customer ids default to None per row."""
        customer_ids = customer_ids or [None] * len(X)
        return [self.explain(X.iloc[i], customer_ids[i]) for i in range(len(X))]
@@ -0,0 +1,103 @@
1
+ """Partial Dependence Plot generation."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, List, Optional
5
+
6
+ import numpy as np
7
+ from sklearn.inspection import partial_dependence
8
+
9
+ from customer_retention.core.compat import DataFrame
10
+
11
+
12
@dataclass
class PDPResult:
    """Partial-dependence curve for a single feature."""
    feature_name: str
    grid_values: np.ndarray  # evaluation grid over the feature's range
    pdp_values: np.ndarray  # average model response at each grid point
    feature_min: float  # observed minimum of the feature in the input data
    feature_max: float  # observed maximum of the feature in the input data
    average_prediction: float  # mean of pdp_values
    ice_values: Optional[List[np.ndarray]] = None  # per-sample ICE curves when requested
21
+
22
+
23
@dataclass
class InteractionResult:
    """Two-way partial dependence over a pair of features."""
    feature1_name: str
    feature2_name: str
    grid1_values: np.ndarray  # grid over feature 1
    grid2_values: np.ndarray  # grid over feature 2
    pdp_matrix: np.ndarray  # averaged response over the grid1 x grid2 mesh
30
+
31
+
32
class PDPGenerator:
    """Partial-dependence (PDP), ICE, and two-way interaction curves.

    Thin wrapper around ``sklearn.inspection.partial_dependence`` for a
    fitted binary classifier exposing ``predict_proba``.
    """

    def __init__(self, model: Any):
        """Args:
            model: fitted estimator accepted by sklearn's partial_dependence.
        """
        self.model = model

    def generate(self, X: DataFrame, feature: str, grid_resolution: int = 50,
                 include_ice: bool = False, ice_lines: int = 100) -> PDPResult:
        """Compute the one-way PDP for ``feature`` over X.

        Args:
            X: feature matrix the curve is evaluated against.
            feature: column name to vary.
            grid_resolution: number of grid points along the feature.
            include_ice: also compute per-sample ICE curves.
            ice_lines: max number of sampled rows for ICE.
        """
        feature_idx = list(X.columns).index(feature)
        pd_result = partial_dependence(
            self.model, X, [feature_idx], kind="average", grid_resolution=grid_resolution
        )
        grid_values = pd_result["grid_values"][0]
        pdp_values = pd_result["average"][0]
        ice_values = None
        if include_ice:
            ice_values = self._calculate_ice(X, feature, grid_values, ice_lines)
        return PDPResult(
            feature_name=feature,
            grid_values=grid_values,
            pdp_values=pdp_values,
            feature_min=float(X[feature].min()),
            feature_max=float(X[feature].max()),
            average_prediction=float(np.mean(pdp_values)),
            ice_values=ice_values
        )

    def _calculate_ice(self, X: DataFrame, feature: str,
                       grid_values: np.ndarray, n_samples: int) -> List[np.ndarray]:
        """Individual conditional expectation curves for a random row sample.

        NOTE: samples via the global np.random state; seed externally for
        reproducibility.
        """
        sample_indices = np.random.choice(len(X), min(n_samples, len(X)), replace=False)
        ice_lines = []
        for idx in sample_indices:
            X_temp = X.iloc[[idx]].copy()
            predictions = []
            for val in grid_values:
                # Sweep the single row across the grid, recording P(class 1).
                X_temp[feature] = val
                pred = self.model.predict_proba(X_temp)[0, 1]
                predictions.append(pred)
            ice_lines.append(np.array(predictions))
        return ice_lines

    def generate_multiple(self, X: DataFrame, features: List[str],
                          grid_resolution: int = 50) -> List[PDPResult]:
        """One PDPResult per requested feature, in the given order."""
        return [self.generate(X, feature, grid_resolution) for feature in features]

    def generate_top_features(self, X: DataFrame, n_features: int = 5,
                              grid_resolution: int = 50) -> List[PDPResult]:
        """PDPs for the features ranked most important by a single-shuffle
        permutation test (|shift of mean predicted probability|).

        NOTE: uses one unseeded permutation per feature, so the ranking is
        nondeterministic unless np.random is seeded by the caller.
        """
        # Fix: the baseline prediction is loop-invariant; the original
        # recomputed it inside the loop, costing one redundant full
        # predict_proba pass per column.
        original_pred = self.model.predict_proba(X)[:, 1].mean()
        importances = {}
        for feature in X.columns:
            X_shuffled = X.copy()
            X_shuffled[feature] = np.random.permutation(X_shuffled[feature].values)
            shuffled_pred = self.model.predict_proba(X_shuffled)[:, 1].mean()
            importances[feature] = abs(original_pred - shuffled_pred)
        top_features = sorted(importances.keys(), key=lambda f: importances[f], reverse=True)[:n_features]
        return self.generate_multiple(X, top_features, grid_resolution)

    def generate_interaction(self, X: DataFrame, feature1: str, feature2: str,
                             grid_resolution: int = 20) -> InteractionResult:
        """Two-way partial dependence surface for a feature pair."""
        feature1_idx = list(X.columns).index(feature1)
        feature2_idx = list(X.columns).index(feature2)
        pd_result = partial_dependence(
            self.model, X, [(feature1_idx, feature2_idx)], kind="average", grid_resolution=grid_resolution
        )
        grid1 = pd_result["grid_values"][0]
        grid2 = pd_result["grid_values"][1]
        pdp_matrix = pd_result["average"][0]
        return InteractionResult(
            feature1_name=feature1,
            feature2_name=feature2,
            grid1_values=grid1,
            grid2_values=grid2,
            pdp_matrix=pdp_matrix
        )