churnkit 0.75.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +647 -0
  2. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +1165 -0
  3. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +961 -0
  4. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +1690 -0
  5. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +679 -0
  6. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +3305 -0
  7. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +1463 -0
  8. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +1430 -0
  9. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +854 -0
  10. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +1639 -0
  11. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +1890 -0
  12. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +1457 -0
  13. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +1624 -0
  14. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +780 -0
  15. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +979 -0
  16. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +572 -0
  17. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +1179 -0
  18. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +1418 -0
  19. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +151 -0
  20. churnkit-0.75.0a1.dist-info/METADATA +229 -0
  21. churnkit-0.75.0a1.dist-info/RECORD +302 -0
  22. churnkit-0.75.0a1.dist-info/WHEEL +4 -0
  23. churnkit-0.75.0a1.dist-info/entry_points.txt +2 -0
  24. churnkit-0.75.0a1.dist-info/licenses/LICENSE +202 -0
  25. customer_retention/__init__.py +37 -0
  26. customer_retention/analysis/__init__.py +0 -0
  27. customer_retention/analysis/auto_explorer/__init__.py +62 -0
  28. customer_retention/analysis/auto_explorer/exploration_manager.py +470 -0
  29. customer_retention/analysis/auto_explorer/explorer.py +258 -0
  30. customer_retention/analysis/auto_explorer/findings.py +291 -0
  31. customer_retention/analysis/auto_explorer/layered_recommendations.py +485 -0
  32. customer_retention/analysis/auto_explorer/recommendation_builder.py +148 -0
  33. customer_retention/analysis/auto_explorer/recommendations.py +418 -0
  34. customer_retention/analysis/business/__init__.py +26 -0
  35. customer_retention/analysis/business/ab_test_designer.py +144 -0
  36. customer_retention/analysis/business/fairness_analyzer.py +166 -0
  37. customer_retention/analysis/business/intervention_matcher.py +121 -0
  38. customer_retention/analysis/business/report_generator.py +222 -0
  39. customer_retention/analysis/business/risk_profile.py +199 -0
  40. customer_retention/analysis/business/roi_analyzer.py +139 -0
  41. customer_retention/analysis/diagnostics/__init__.py +20 -0
  42. customer_retention/analysis/diagnostics/calibration_analyzer.py +133 -0
  43. customer_retention/analysis/diagnostics/cv_analyzer.py +144 -0
  44. customer_retention/analysis/diagnostics/error_analyzer.py +107 -0
  45. customer_retention/analysis/diagnostics/leakage_detector.py +394 -0
  46. customer_retention/analysis/diagnostics/noise_tester.py +140 -0
  47. customer_retention/analysis/diagnostics/overfitting_analyzer.py +190 -0
  48. customer_retention/analysis/diagnostics/segment_analyzer.py +122 -0
  49. customer_retention/analysis/discovery/__init__.py +8 -0
  50. customer_retention/analysis/discovery/config_generator.py +49 -0
  51. customer_retention/analysis/discovery/discovery_flow.py +19 -0
  52. customer_retention/analysis/discovery/type_inferencer.py +147 -0
  53. customer_retention/analysis/interpretability/__init__.py +13 -0
  54. customer_retention/analysis/interpretability/cohort_analyzer.py +185 -0
  55. customer_retention/analysis/interpretability/counterfactual.py +175 -0
  56. customer_retention/analysis/interpretability/individual_explainer.py +141 -0
  57. customer_retention/analysis/interpretability/pdp_generator.py +103 -0
  58. customer_retention/analysis/interpretability/shap_explainer.py +106 -0
  59. customer_retention/analysis/jupyter_save_hook.py +28 -0
  60. customer_retention/analysis/notebook_html_exporter.py +136 -0
  61. customer_retention/analysis/notebook_progress.py +60 -0
  62. customer_retention/analysis/plotly_preprocessor.py +154 -0
  63. customer_retention/analysis/recommendations/__init__.py +54 -0
  64. customer_retention/analysis/recommendations/base.py +158 -0
  65. customer_retention/analysis/recommendations/cleaning/__init__.py +11 -0
  66. customer_retention/analysis/recommendations/cleaning/consistency.py +107 -0
  67. customer_retention/analysis/recommendations/cleaning/deduplicate.py +94 -0
  68. customer_retention/analysis/recommendations/cleaning/impute.py +67 -0
  69. customer_retention/analysis/recommendations/cleaning/outlier.py +71 -0
  70. customer_retention/analysis/recommendations/datetime/__init__.py +3 -0
  71. customer_retention/analysis/recommendations/datetime/extract.py +149 -0
  72. customer_retention/analysis/recommendations/encoding/__init__.py +3 -0
  73. customer_retention/analysis/recommendations/encoding/categorical.py +114 -0
  74. customer_retention/analysis/recommendations/pipeline.py +74 -0
  75. customer_retention/analysis/recommendations/registry.py +76 -0
  76. customer_retention/analysis/recommendations/selection/__init__.py +3 -0
  77. customer_retention/analysis/recommendations/selection/drop_column.py +56 -0
  78. customer_retention/analysis/recommendations/transform/__init__.py +4 -0
  79. customer_retention/analysis/recommendations/transform/power.py +94 -0
  80. customer_retention/analysis/recommendations/transform/scale.py +112 -0
  81. customer_retention/analysis/visualization/__init__.py +15 -0
  82. customer_retention/analysis/visualization/chart_builder.py +2619 -0
  83. customer_retention/analysis/visualization/console.py +122 -0
  84. customer_retention/analysis/visualization/display.py +171 -0
  85. customer_retention/analysis/visualization/number_formatter.py +36 -0
  86. customer_retention/artifacts/__init__.py +3 -0
  87. customer_retention/artifacts/fit_artifact_registry.py +146 -0
  88. customer_retention/cli.py +93 -0
  89. customer_retention/core/__init__.py +0 -0
  90. customer_retention/core/compat/__init__.py +193 -0
  91. customer_retention/core/compat/detection.py +99 -0
  92. customer_retention/core/compat/ops.py +48 -0
  93. customer_retention/core/compat/pandas_backend.py +57 -0
  94. customer_retention/core/compat/spark_backend.py +75 -0
  95. customer_retention/core/components/__init__.py +11 -0
  96. customer_retention/core/components/base.py +79 -0
  97. customer_retention/core/components/components/__init__.py +13 -0
  98. customer_retention/core/components/components/deployer.py +26 -0
  99. customer_retention/core/components/components/explainer.py +26 -0
  100. customer_retention/core/components/components/feature_eng.py +33 -0
  101. customer_retention/core/components/components/ingester.py +34 -0
  102. customer_retention/core/components/components/profiler.py +34 -0
  103. customer_retention/core/components/components/trainer.py +38 -0
  104. customer_retention/core/components/components/transformer.py +36 -0
  105. customer_retention/core/components/components/validator.py +37 -0
  106. customer_retention/core/components/enums.py +33 -0
  107. customer_retention/core/components/orchestrator.py +94 -0
  108. customer_retention/core/components/registry.py +59 -0
  109. customer_retention/core/config/__init__.py +39 -0
  110. customer_retention/core/config/column_config.py +95 -0
  111. customer_retention/core/config/experiments.py +71 -0
  112. customer_retention/core/config/pipeline_config.py +117 -0
  113. customer_retention/core/config/source_config.py +83 -0
  114. customer_retention/core/utils/__init__.py +28 -0
  115. customer_retention/core/utils/leakage.py +85 -0
  116. customer_retention/core/utils/severity.py +53 -0
  117. customer_retention/core/utils/statistics.py +90 -0
  118. customer_retention/generators/__init__.py +0 -0
  119. customer_retention/generators/notebook_generator/__init__.py +167 -0
  120. customer_retention/generators/notebook_generator/base.py +55 -0
  121. customer_retention/generators/notebook_generator/cell_builder.py +49 -0
  122. customer_retention/generators/notebook_generator/config.py +47 -0
  123. customer_retention/generators/notebook_generator/databricks_generator.py +48 -0
  124. customer_retention/generators/notebook_generator/local_generator.py +48 -0
  125. customer_retention/generators/notebook_generator/project_init.py +174 -0
  126. customer_retention/generators/notebook_generator/runner.py +150 -0
  127. customer_retention/generators/notebook_generator/script_generator.py +110 -0
  128. customer_retention/generators/notebook_generator/stages/__init__.py +19 -0
  129. customer_retention/generators/notebook_generator/stages/base_stage.py +86 -0
  130. customer_retention/generators/notebook_generator/stages/s01_ingestion.py +100 -0
  131. customer_retention/generators/notebook_generator/stages/s02_profiling.py +95 -0
  132. customer_retention/generators/notebook_generator/stages/s03_cleaning.py +180 -0
  133. customer_retention/generators/notebook_generator/stages/s04_transformation.py +165 -0
  134. customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py +115 -0
  135. customer_retention/generators/notebook_generator/stages/s06_feature_selection.py +97 -0
  136. customer_retention/generators/notebook_generator/stages/s07_model_training.py +176 -0
  137. customer_retention/generators/notebook_generator/stages/s08_deployment.py +81 -0
  138. customer_retention/generators/notebook_generator/stages/s09_monitoring.py +112 -0
  139. customer_retention/generators/notebook_generator/stages/s10_batch_inference.py +642 -0
  140. customer_retention/generators/notebook_generator/stages/s11_feature_store.py +348 -0
  141. customer_retention/generators/orchestration/__init__.py +23 -0
  142. customer_retention/generators/orchestration/code_generator.py +196 -0
  143. customer_retention/generators/orchestration/context.py +147 -0
  144. customer_retention/generators/orchestration/data_materializer.py +188 -0
  145. customer_retention/generators/orchestration/databricks_exporter.py +411 -0
  146. customer_retention/generators/orchestration/doc_generator.py +311 -0
  147. customer_retention/generators/pipeline_generator/__init__.py +26 -0
  148. customer_retention/generators/pipeline_generator/findings_parser.py +727 -0
  149. customer_retention/generators/pipeline_generator/generator.py +142 -0
  150. customer_retention/generators/pipeline_generator/models.py +166 -0
  151. customer_retention/generators/pipeline_generator/renderer.py +2125 -0
  152. customer_retention/generators/spec_generator/__init__.py +37 -0
  153. customer_retention/generators/spec_generator/databricks_generator.py +433 -0
  154. customer_retention/generators/spec_generator/generic_generator.py +373 -0
  155. customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +685 -0
  156. customer_retention/generators/spec_generator/pipeline_spec.py +298 -0
  157. customer_retention/integrations/__init__.py +0 -0
  158. customer_retention/integrations/adapters/__init__.py +13 -0
  159. customer_retention/integrations/adapters/base.py +10 -0
  160. customer_retention/integrations/adapters/factory.py +25 -0
  161. customer_retention/integrations/adapters/feature_store/__init__.py +6 -0
  162. customer_retention/integrations/adapters/feature_store/base.py +57 -0
  163. customer_retention/integrations/adapters/feature_store/databricks.py +94 -0
  164. customer_retention/integrations/adapters/feature_store/feast_adapter.py +97 -0
  165. customer_retention/integrations/adapters/feature_store/local.py +75 -0
  166. customer_retention/integrations/adapters/mlflow/__init__.py +6 -0
  167. customer_retention/integrations/adapters/mlflow/base.py +32 -0
  168. customer_retention/integrations/adapters/mlflow/databricks.py +54 -0
  169. customer_retention/integrations/adapters/mlflow/experiment_tracker.py +161 -0
  170. customer_retention/integrations/adapters/mlflow/local.py +50 -0
  171. customer_retention/integrations/adapters/storage/__init__.py +5 -0
  172. customer_retention/integrations/adapters/storage/base.py +33 -0
  173. customer_retention/integrations/adapters/storage/databricks.py +76 -0
  174. customer_retention/integrations/adapters/storage/local.py +59 -0
  175. customer_retention/integrations/feature_store/__init__.py +47 -0
  176. customer_retention/integrations/feature_store/definitions.py +215 -0
  177. customer_retention/integrations/feature_store/manager.py +744 -0
  178. customer_retention/integrations/feature_store/registry.py +412 -0
  179. customer_retention/integrations/iteration/__init__.py +28 -0
  180. customer_retention/integrations/iteration/context.py +212 -0
  181. customer_retention/integrations/iteration/feedback_collector.py +184 -0
  182. customer_retention/integrations/iteration/orchestrator.py +168 -0
  183. customer_retention/integrations/iteration/recommendation_tracker.py +341 -0
  184. customer_retention/integrations/iteration/signals.py +212 -0
  185. customer_retention/integrations/llm_context/__init__.py +4 -0
  186. customer_retention/integrations/llm_context/context_builder.py +201 -0
  187. customer_retention/integrations/llm_context/prompts.py +100 -0
  188. customer_retention/integrations/streaming/__init__.py +103 -0
  189. customer_retention/integrations/streaming/batch_integration.py +149 -0
  190. customer_retention/integrations/streaming/early_warning_model.py +227 -0
  191. customer_retention/integrations/streaming/event_schema.py +214 -0
  192. customer_retention/integrations/streaming/online_store_writer.py +249 -0
  193. customer_retention/integrations/streaming/realtime_scorer.py +261 -0
  194. customer_retention/integrations/streaming/trigger_engine.py +293 -0
  195. customer_retention/integrations/streaming/window_aggregator.py +393 -0
  196. customer_retention/stages/__init__.py +0 -0
  197. customer_retention/stages/cleaning/__init__.py +9 -0
  198. customer_retention/stages/cleaning/base.py +28 -0
  199. customer_retention/stages/cleaning/missing_handler.py +160 -0
  200. customer_retention/stages/cleaning/outlier_handler.py +204 -0
  201. customer_retention/stages/deployment/__init__.py +28 -0
  202. customer_retention/stages/deployment/batch_scorer.py +106 -0
  203. customer_retention/stages/deployment/champion_challenger.py +299 -0
  204. customer_retention/stages/deployment/model_registry.py +182 -0
  205. customer_retention/stages/deployment/retraining_trigger.py +245 -0
  206. customer_retention/stages/features/__init__.py +73 -0
  207. customer_retention/stages/features/behavioral_features.py +266 -0
  208. customer_retention/stages/features/customer_segmentation.py +505 -0
  209. customer_retention/stages/features/feature_definitions.py +265 -0
  210. customer_retention/stages/features/feature_engineer.py +551 -0
  211. customer_retention/stages/features/feature_manifest.py +340 -0
  212. customer_retention/stages/features/feature_selector.py +239 -0
  213. customer_retention/stages/features/interaction_features.py +160 -0
  214. customer_retention/stages/features/temporal_features.py +243 -0
  215. customer_retention/stages/ingestion/__init__.py +9 -0
  216. customer_retention/stages/ingestion/load_result.py +32 -0
  217. customer_retention/stages/ingestion/loaders.py +195 -0
  218. customer_retention/stages/ingestion/source_registry.py +130 -0
  219. customer_retention/stages/modeling/__init__.py +31 -0
  220. customer_retention/stages/modeling/baseline_trainer.py +139 -0
  221. customer_retention/stages/modeling/cross_validator.py +125 -0
  222. customer_retention/stages/modeling/data_splitter.py +205 -0
  223. customer_retention/stages/modeling/feature_scaler.py +99 -0
  224. customer_retention/stages/modeling/hyperparameter_tuner.py +107 -0
  225. customer_retention/stages/modeling/imbalance_handler.py +282 -0
  226. customer_retention/stages/modeling/mlflow_logger.py +95 -0
  227. customer_retention/stages/modeling/model_comparator.py +149 -0
  228. customer_retention/stages/modeling/model_evaluator.py +138 -0
  229. customer_retention/stages/modeling/threshold_optimizer.py +131 -0
  230. customer_retention/stages/monitoring/__init__.py +37 -0
  231. customer_retention/stages/monitoring/alert_manager.py +328 -0
  232. customer_retention/stages/monitoring/drift_detector.py +201 -0
  233. customer_retention/stages/monitoring/performance_monitor.py +242 -0
  234. customer_retention/stages/preprocessing/__init__.py +5 -0
  235. customer_retention/stages/preprocessing/transformer_manager.py +284 -0
  236. customer_retention/stages/profiling/__init__.py +256 -0
  237. customer_retention/stages/profiling/categorical_distribution.py +269 -0
  238. customer_retention/stages/profiling/categorical_target_analyzer.py +274 -0
  239. customer_retention/stages/profiling/column_profiler.py +527 -0
  240. customer_retention/stages/profiling/distribution_analysis.py +483 -0
  241. customer_retention/stages/profiling/drift_detector.py +310 -0
  242. customer_retention/stages/profiling/feature_capacity.py +507 -0
  243. customer_retention/stages/profiling/pattern_analysis_config.py +513 -0
  244. customer_retention/stages/profiling/profile_result.py +212 -0
  245. customer_retention/stages/profiling/quality_checks.py +1632 -0
  246. customer_retention/stages/profiling/relationship_detector.py +256 -0
  247. customer_retention/stages/profiling/relationship_recommender.py +454 -0
  248. customer_retention/stages/profiling/report_generator.py +520 -0
  249. customer_retention/stages/profiling/scd_analyzer.py +151 -0
  250. customer_retention/stages/profiling/segment_analyzer.py +632 -0
  251. customer_retention/stages/profiling/segment_aware_outlier.py +265 -0
  252. customer_retention/stages/profiling/target_level_analyzer.py +217 -0
  253. customer_retention/stages/profiling/temporal_analyzer.py +388 -0
  254. customer_retention/stages/profiling/temporal_coverage.py +488 -0
  255. customer_retention/stages/profiling/temporal_feature_analyzer.py +692 -0
  256. customer_retention/stages/profiling/temporal_feature_engineer.py +703 -0
  257. customer_retention/stages/profiling/temporal_pattern_analyzer.py +636 -0
  258. customer_retention/stages/profiling/temporal_quality_checks.py +278 -0
  259. customer_retention/stages/profiling/temporal_target_analyzer.py +241 -0
  260. customer_retention/stages/profiling/text_embedder.py +87 -0
  261. customer_retention/stages/profiling/text_processor.py +115 -0
  262. customer_retention/stages/profiling/text_reducer.py +60 -0
  263. customer_retention/stages/profiling/time_series_profiler.py +303 -0
  264. customer_retention/stages/profiling/time_window_aggregator.py +376 -0
  265. customer_retention/stages/profiling/type_detector.py +382 -0
  266. customer_retention/stages/profiling/window_recommendation.py +288 -0
  267. customer_retention/stages/temporal/__init__.py +166 -0
  268. customer_retention/stages/temporal/access_guard.py +180 -0
  269. customer_retention/stages/temporal/cutoff_analyzer.py +235 -0
  270. customer_retention/stages/temporal/data_preparer.py +178 -0
  271. customer_retention/stages/temporal/point_in_time_join.py +134 -0
  272. customer_retention/stages/temporal/point_in_time_registry.py +148 -0
  273. customer_retention/stages/temporal/scenario_detector.py +163 -0
  274. customer_retention/stages/temporal/snapshot_manager.py +259 -0
  275. customer_retention/stages/temporal/synthetic_coordinator.py +66 -0
  276. customer_retention/stages/temporal/timestamp_discovery.py +531 -0
  277. customer_retention/stages/temporal/timestamp_manager.py +255 -0
  278. customer_retention/stages/transformation/__init__.py +13 -0
  279. customer_retention/stages/transformation/binary_handler.py +85 -0
  280. customer_retention/stages/transformation/categorical_encoder.py +245 -0
  281. customer_retention/stages/transformation/datetime_transformer.py +97 -0
  282. customer_retention/stages/transformation/numeric_transformer.py +181 -0
  283. customer_retention/stages/transformation/pipeline.py +257 -0
  284. customer_retention/stages/validation/__init__.py +60 -0
  285. customer_retention/stages/validation/adversarial_scoring_validator.py +205 -0
  286. customer_retention/stages/validation/business_sense_gate.py +173 -0
  287. customer_retention/stages/validation/data_quality_gate.py +235 -0
  288. customer_retention/stages/validation/data_validators.py +511 -0
  289. customer_retention/stages/validation/feature_quality_gate.py +183 -0
  290. customer_retention/stages/validation/gates.py +117 -0
  291. customer_retention/stages/validation/leakage_gate.py +352 -0
  292. customer_retention/stages/validation/model_validity_gate.py +213 -0
  293. customer_retention/stages/validation/pipeline_validation_runner.py +264 -0
  294. customer_retention/stages/validation/quality_scorer.py +544 -0
  295. customer_retention/stages/validation/rule_generator.py +57 -0
  296. customer_retention/stages/validation/scoring_pipeline_validator.py +446 -0
  297. customer_retention/stages/validation/timeseries_detector.py +769 -0
  298. customer_retention/transforms/__init__.py +47 -0
  299. customer_retention/transforms/artifact_store.py +50 -0
  300. customer_retention/transforms/executor.py +157 -0
  301. customer_retention/transforms/fitted.py +92 -0
  302. customer_retention/transforms/ops.py +148 -0
@@ -0,0 +1,99 @@
1
+ """Feature scaling for model training."""
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Any, Dict, Optional
6
+
7
+ from sklearn.preprocessing import MinMaxScaler, RobustScaler, StandardScaler
8
+
9
+ from customer_retention.core.compat import DataFrame
10
+
11
+
12
class ScalerType(Enum):
    """Feature scaling strategies selectable in ``FeatureScaler``."""

    # scikit-learn StandardScaler
    STANDARD = "standard"
    # scikit-learn RobustScaler
    ROBUST = "robust"
    # scikit-learn MinMaxScaler
    MINMAX = "minmax"
    # no scaling: frames are passed through unchanged
    NONE = "none"
17
+
18
+
19
@dataclass()
class ScalingResult:
    """Outcome of ``FeatureScaler.fit_transform``."""

    # fitted scikit-learn scaler, or None when not saved / scaling disabled
    scaler: Optional[Any]
    # scaled training features
    X_train_scaled: DataFrame
    # scaled test features
    X_test_scaled: DataFrame
    # fitted scaler statistics converted to plain lists
    scaling_params: Dict[str, Any]
25
+
26
+
27
class FeatureScaler:
    """Scale feature frames with a scikit-learn scaler fit on the training split.

    The scaler is fit on ``X_train`` only, then reused for ``X_test`` and for
    any later :meth:`transform` call.

    Args:
        scaler_type: Which scaler to build (``STANDARD``, ``ROBUST``, ``MINMAX``).
            ``NONE`` disables scaling and passes frames through unchanged.
        fit_on_train_only: Stored on the instance. Fitting currently always
            uses the training split only, regardless of this flag.
        save_scaler: When False, :meth:`fit_transform` returns ``scaler=None``
            in its result. The fitted scaler is still kept internally for
            :meth:`transform`.
    """

    # Pairs of (fitted attribute on the scikit-learn scaler, key in scaling_params).
    # The first five keep the historical key order. "min" (MinMaxScaler's offset)
    # is appended so the MinMax transform ``X * scale + min`` can be reproduced.
    _FITTED_PARAM_ATTRS = (
        ("mean_", "mean"),
        ("scale_", "scale"),
        ("center_", "center"),
        ("data_min_", "data_min"),
        ("data_max_", "data_max"),
        ("min_", "min"),
    )

    def __init__(
        self,
        scaler_type: ScalerType = ScalerType.ROBUST,
        fit_on_train_only: bool = True,
        save_scaler: bool = True,
    ):
        self.scaler_type = scaler_type
        self.fit_on_train_only = fit_on_train_only
        self.save_scaler = save_scaler
        self._scaler = None
        self._feature_names = None

    def fit_transform(
        self,
        X_train: DataFrame,
        X_test: DataFrame,
    ) -> ScalingResult:
        """Fit the scaler on ``X_train`` and scale both ``X_train`` and ``X_test``.

        Returns:
            ScalingResult with:
                - the scaled frames, keeping the training column names and each
                  input's own index;
                - the fitted scaler, or None when ``save_scaler`` is False or
                  scaling is disabled;
                - the fitted parameters from :meth:`_extract_params`.
        """
        self._feature_names = list(X_train.columns)

        if self.scaler_type == ScalerType.NONE:
            # Forget any scaler left over from an earlier fit, so that transform()
            # is a pass-through just like the frames returned here.
            self._scaler = None
            return ScalingResult(
                scaler=None,
                X_train_scaled=X_train,
                X_test_scaled=X_test,
                scaling_params={},
            )

        self._scaler = self._create_scaler()
        train_values = self._scaler.fit_transform(X_train)
        # The test split is only transformed, never fitted, to avoid leakage.
        test_values = self._scaler.transform(X_test)

        return ScalingResult(
            scaler=self._scaler if self.save_scaler else None,
            X_train_scaled=DataFrame(train_values, columns=self._feature_names, index=X_train.index),
            X_test_scaled=DataFrame(test_values, columns=self._feature_names, index=X_test.index),
            scaling_params=self._extract_params(),
        )

    def transform(self, X: DataFrame) -> DataFrame:
        """Scale ``X`` with the fitted scaler.

        Returns ``X`` unchanged when no scaler is fitted, i.e. before
        :meth:`fit_transform` or after a ``NONE`` fit. Otherwise returns a new
        frame labelled with the training feature names and ``X``'s index.
        """
        if self._scaler is None:
            return X
        X_scaled = self._scaler.transform(X)
        return DataFrame(X_scaled, columns=self._feature_names, index=X.index)

    def _create_scaler(self):
        """Instantiate the scikit-learn scaler for ``scaler_type``.

        Returns None for ``NONE``.
        """
        if self.scaler_type == ScalerType.STANDARD:
            return StandardScaler()
        if self.scaler_type == ScalerType.ROBUST:
            return RobustScaler()
        if self.scaler_type == ScalerType.MINMAX:
            return MinMaxScaler()
        return None

    def _extract_params(self) -> Dict[str, Any]:
        """Collect the fitted scaler's statistics as plain Python lists.

        Only attributes present on the scaler and not None are included.
        Returns ``{}`` when no scaler is fitted.
        """
        if self._scaler is None:
            return {}

        params: Dict[str, Any] = {}
        for attr, key in self._FITTED_PARAM_ATTRS:
            value = getattr(self._scaler, attr, None)
            # scikit-learn sets some of these to None when the corresponding
            # option is disabled (e.g. StandardScaler(with_std=False).scale_),
            # so check for None before calling .tolist().
            if value is not None:
                params[key] = value.tolist()
        return params
@@ -0,0 +1,107 @@
1
+ """Hyperparameter tuning strategies for model optimization."""
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
8
+
9
+ from customer_retention.core.compat import DataFrame, Series
10
+
11
+
12
class SearchStrategy(Enum):
    """Hyperparameter search algorithms selectable in ``HyperparameterTuner``."""

    # sample candidates at random from the parameter space
    RANDOM_SEARCH = "random_search"
    # exhaustively evaluate every combination in the parameter grid
    GRID_SEARCH = "grid_search"
    # no dedicated implementation; the tuner runs a random search instead
    BAYESIAN = "bayesian"
    # successive-halving randomized search
    HALVING = "halving"
17
+
18
+
19
@dataclass()
class TuningResult:
    """Summary of a ``HyperparameterTuner.tune`` run."""

    # parameter combination with the best mean CV score
    best_params: Dict[str, Any]
    # best mean cross-validated score under `scoring`
    best_score: float
    # estimator taken from the search's best_estimator_
    best_model: Any
    # one dict per candidate: params, mean_score, std_score, rank
    cv_results: List[Dict[str, Any]]
    # scorer name used to rank candidates
    scoring: str
26
+
27
+
28
class HyperparameterTuner:
    """Tune a scikit-learn estimator with a cross-validated parameter search.

    Args:
        strategy: Search algorithm.
            - ``GRID_SEARCH`` tries every combination in ``param_space``.
            - ``RANDOM_SEARCH`` samples ``n_iter`` candidates.
            - ``HALVING`` uses successive-halving random search (``n_iter`` is
              not used).
            - ``BAYESIAN`` has no dedicated implementation and falls back to
              random search with a warning.
        param_space: Parameter grid / distributions handed to the search object.
            ``None`` becomes an empty dict.
        n_iter: Number of sampled candidates for random search.
        cv: Cross-validation setting forwarded to the search's ``cv``.
        scoring: scikit-learn scorer name used to rank candidates.
        n_jobs: Parallel jobs forwarded to the search object.
        verbose: Verbosity forwarded to the search object.
        random_state: Seed for the randomized strategies.
    """

    def __init__(
        self,
        strategy: SearchStrategy = SearchStrategy.RANDOM_SEARCH,
        param_space: Optional[Dict[str, Any]] = None,
        n_iter: int = 50,
        cv: int = 5,
        scoring: str = "average_precision",
        n_jobs: int = -1,
        verbose: int = 0,
        random_state: int = 42,
    ):
        self.strategy = strategy
        self.param_space = param_space or {}
        self.n_iter = n_iter
        self.cv = cv
        self.scoring = scoring
        self.n_jobs = n_jobs
        self.verbose = verbose
        self.random_state = random_state

    def tune(self, model, X: DataFrame, y: Series) -> TuningResult:
        """Run the configured search on ``(X, y)`` and summarise the outcome.

        Returns:
            TuningResult with:
                - the search's ``best_params_``, ``best_score_`` and
                  ``best_estimator_``;
                - one summary dict per evaluated candidate.
        """
        search = self._create_search(model)
        search.fit(X, y)

        cv_results = self._extract_cv_results(search)

        return TuningResult(
            best_params=search.best_params_,
            best_score=search.best_score_,
            best_model=search.best_estimator_,
            cv_results=cv_results,
            scoring=self.scoring,
        )

    def _create_search(self, model):
        """Build, without fitting, the scikit-learn search object for ``self.strategy``."""
        if self.strategy == SearchStrategy.GRID_SEARCH:
            return GridSearchCV(
                model,
                param_grid=self.param_space,
                cv=self.cv,
                scoring=self.scoring,
                n_jobs=self.n_jobs,
                verbose=self.verbose,
            )

        if self.strategy == SearchStrategy.HALVING:
            # The halving searches are experimental in scikit-learn. Importing
            # HalvingRandomSearchCV raises ImportError unless this enabling
            # module has been imported first.
            from sklearn.experimental import enable_halving_search_cv  # noqa: F401
            from sklearn.model_selection import HalvingRandomSearchCV

            return HalvingRandomSearchCV(
                model,
                param_distributions=self.param_space,
                cv=self.cv,
                scoring=self.scoring,
                n_jobs=self.n_jobs,
                verbose=self.verbose,
                random_state=self.random_state,
            )

        if self.strategy == SearchStrategy.BAYESIAN:
            # No Bayesian optimiser is wired in. Make the substitution visible
            # instead of silently running a random search.
            import warnings

            warnings.warn(
                "SearchStrategy.BAYESIAN is not implemented; falling back to random search.",
                UserWarning,
                stacklevel=3,
            )

        return RandomizedSearchCV(
            model,
            param_distributions=self.param_space,
            n_iter=self.n_iter,
            cv=self.cv,
            scoring=self.scoring,
            n_jobs=self.n_jobs,
            verbose=self.verbose,
            random_state=self.random_state,
        )

    def _extract_cv_results(self, search) -> List[Dict[str, Any]]:
        """Flatten ``search.cv_results_`` into one dict per candidate.

        Scores and ranks are converted from NumPy scalars to built-in
        ``float`` / ``int``, so the list can be serialised (e.g. to JSON)
        directly.
        """
        table = search.cv_results_
        return [
            {
                "params": params,
                "mean_score": float(mean_score),
                "std_score": float(std_score),
                "rank": int(rank),
            }
            for params, mean_score, std_score, rank in zip(
                table["params"],
                table["mean_test_score"],
                table["std_test_score"],
                table["rank_test_score"],
            )
        ]
@@ -0,0 +1,282 @@
1
+ """Class imbalance handling strategies for model training."""
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Dict, Optional, Union
6
+
7
+ import numpy as np
8
+
9
+ from customer_retention.core.compat import DataFrame, Series
10
+
11
+
12
class ImbalanceStrategy(Enum):
    """Techniques for handling class imbalance in a classification target.

    Member order and string values are part of the public contract
    (``.value`` is printed by ``ImbalanceRecommendation``).
    """

    CLASS_WEIGHT = "class_weight"            # reweight classes; data unchanged
    SMOTE = "smote"                          # synthetic minority oversampling
    RANDOM_OVERSAMPLE = "random_oversample"  # duplicate minority rows
    RANDOM_UNDERSAMPLE = "random_undersample"  # drop majority rows
    SMOTEENN = "smoteenn"                    # SMOTE followed by ENN cleaning
    ADASYN = "adasyn"                        # adaptive synthetic sampling
    NONE = "none"                            # leave data and weights alone
20
+
21
+
22
class ClassWeightMethod(Enum):
    """How ``ImbalanceHandler`` derives per-class weights."""

    BALANCED = "balanced"  # n_samples / (n_classes * class_count)
    CUSTOM = "custom"      # use the caller-supplied weights verbatim
    INVERSE = "inverse"    # 1 / class proportion
26
+
27
+
28
@dataclass
class ImbalanceResult:
    """Outcome of ``ImbalanceHandler.fit`` / ``fit_transform``.

    Field order is part of the positional-constructor contract.
    """

    # Data after applying the strategy; None when only fitted (no transform).
    X_resampled: Optional[DataFrame]
    y_resampled: Optional[Series]
    # Strategy that produced this result.
    strategy_used: ImbalanceStrategy
    # Label -> count in the input target.
    original_class_counts: Dict[int, int]
    # Label -> count after resampling, when data was returned.
    resampled_class_counts: Optional[Dict[int, int]] = None
    # Label -> weight, populated for the CLASS_WEIGHT strategy.
    class_weights: Optional[Dict[int, float]] = None
    # Largest class count divided by smallest, computed on the input target.
    imbalance_ratio: Optional[float] = None
37
+
38
+
39
class ImbalanceHandler:
    """Apply a single class-imbalance strategy to a training set.

    ``NONE`` and ``CLASS_WEIGHT`` never modify the data (``CLASS_WEIGHT`` also
    reports per-class weights). The resampling strategies delegate to the
    matching imbalanced-learn sampler, which is imported only when that
    strategy is actually used.
    """

    def __init__(
        self,
        strategy: ImbalanceStrategy = ImbalanceStrategy.CLASS_WEIGHT,
        weight_method: ClassWeightMethod = ClassWeightMethod.BALANCED,
        custom_weights: Optional[Dict[int, float]] = None,
        sampling_strategy: Union[str, float] = "auto",
        random_state: int = 42,
    ):
        """Store the configuration; no data is inspected here.

        Args:
            strategy: Imbalance-handling technique to apply.
            weight_method: How class weights are computed for ``CLASS_WEIGHT``.
            custom_weights: Weights returned as-is when ``weight_method`` is
                ``CUSTOM``.
            sampling_strategy: Forwarded to the imbalanced-learn sampler.
            random_state: Seed forwarded to the imbalanced-learn sampler.
        """
        self.strategy = strategy
        self.weight_method = weight_method
        self.custom_weights = custom_weights
        self.sampling_strategy = sampling_strategy
        self.random_state = random_state
        self._class_weights = None  # NOTE(review): not read anywhere in this class

    @staticmethod
    def _summarize_classes(y: Series):
        """Return ``(class_counts, imbalance_ratio)`` for ``y``.

        The ratio is the largest class count divided by the smallest, taken
        over classes that actually occur. A categorical ``y`` can report unused
        categories with a count of 0, which would otherwise divide by zero.

        Raises:
            ValueError: If ``y`` contains no (non-missing) labels.
        """
        counts = y.value_counts().to_dict()
        observed = [n for n in counts.values() if n > 0]
        if not observed:
            raise ValueError("Cannot assess class imbalance: y contains no labels.")
        return counts, max(observed) / min(observed)

    def fit(self, X: DataFrame, y: Series) -> ImbalanceResult:
        """Inspect ``y`` without resampling anything.

        ``X`` is accepted for symmetry with ``fit_transform`` but is unused.
        The returned result carries no data; ``class_weights`` is populated
        only when the strategy is ``CLASS_WEIGHT``.

        Raises:
            ValueError: If ``y`` has no labels, or ``CUSTOM`` weights are
                requested without ``custom_weights``.
        """
        original_counts, imbalance_ratio = self._summarize_classes(y)

        class_weights = None
        if self.strategy == ImbalanceStrategy.CLASS_WEIGHT:
            class_weights = self._compute_class_weights(y)

        return ImbalanceResult(
            X_resampled=None,
            y_resampled=None,
            strategy_used=self.strategy,
            original_class_counts=original_counts,
            resampled_class_counts=None,
            class_weights=class_weights,
            imbalance_ratio=imbalance_ratio,
        )

    def fit_transform(self, X: DataFrame, y: Series) -> ImbalanceResult:
        """Apply the strategy and return the (possibly resampled) data.

        ``NONE`` and ``CLASS_WEIGHT`` return ``X`` and ``y`` unchanged (the
        latter with class weights). Other strategies return the sampler output
        rebuilt as a DataFrame with ``X``'s columns and a Series.

        Raises:
            ValueError: If ``y`` has no labels, or ``CUSTOM`` weights are
                requested without ``custom_weights``.
            ImportError: If a resampling strategy is requested and
                imbalanced-learn cannot be imported.
        """
        original_counts, imbalance_ratio = self._summarize_classes(y)

        if self.strategy == ImbalanceStrategy.NONE:
            return ImbalanceResult(
                X_resampled=X,
                y_resampled=y,
                strategy_used=self.strategy,
                original_class_counts=original_counts,
                resampled_class_counts=original_counts,
                imbalance_ratio=imbalance_ratio,
            )

        if self.strategy == ImbalanceStrategy.CLASS_WEIGHT:
            return ImbalanceResult(
                X_resampled=X,
                y_resampled=y,
                strategy_used=self.strategy,
                original_class_counts=original_counts,
                resampled_class_counts=original_counts,
                class_weights=self._compute_class_weights(y),
                imbalance_ratio=imbalance_ratio,
            )

        X_res, y_res = self._resample(X, y)
        resampled_counts = Series(y_res).value_counts().to_dict()

        return ImbalanceResult(
            X_resampled=DataFrame(X_res, columns=X.columns),
            y_resampled=Series(y_res),
            strategy_used=self.strategy,
            original_class_counts=original_counts,
            resampled_class_counts=resampled_counts,
            imbalance_ratio=imbalance_ratio,
        )

    def _compute_class_weights(self, y: Series) -> Dict[int, float]:
        """Return a label -> weight mapping according to ``weight_method``.

        - ``CUSTOM``: ``custom_weights`` as given.
        - ``BALANCED``: ``n_samples / (n_classes * class_count)``.
        - ``INVERSE``: ``1 / (class_count / n_samples)``.
        - anything else: weight 1.0 for every class.

        Raises:
            ValueError: If ``weight_method`` is ``CUSTOM`` but no
                ``custom_weights`` were supplied (previously this silently
                returned None, i.e. unweighted training).
        """
        if self.weight_method == ClassWeightMethod.CUSTOM:
            if self.custom_weights is None:
                raise ValueError("weight_method=CUSTOM requires custom_weights to be provided.")
            return self.custom_weights

        labels, label_counts = np.unique(y, return_counts=True)
        # tolist() turns NumPy scalars into plain Python values so the weights
        # are JSON-serialisable (e.g. for experiment logging).
        classes = labels.tolist()
        counts = label_counts.tolist()
        n_samples = len(y)
        n_classes = len(classes)

        if self.weight_method == ClassWeightMethod.BALANCED:
            return {cls: n_samples / (n_classes * n_cls) for cls, n_cls in zip(classes, counts)}

        if self.weight_method == ClassWeightMethod.INVERSE:
            return {cls: 1.0 / (n_cls / n_samples) for cls, n_cls in zip(classes, counts)}

        return {cls: 1.0 for cls in classes}

    def _resample(self, X: DataFrame, y: Series) -> tuple:
        """Resample with the imbalanced-learn sampler matching ``self.strategy``.

        Each sampler receives ``sampling_strategy`` and ``random_state``.
        Strategies without a sampler fall back to ``(X.values, y.values)``.
        """
        if self.strategy == ImbalanceStrategy.SMOTE:
            from imblearn.over_sampling import SMOTE
            sampler = SMOTE(sampling_strategy=self.sampling_strategy, random_state=self.random_state)
            return sampler.fit_resample(X, y)

        if self.strategy == ImbalanceStrategy.RANDOM_OVERSAMPLE:
            from imblearn.over_sampling import RandomOverSampler
            sampler = RandomOverSampler(sampling_strategy=self.sampling_strategy, random_state=self.random_state)
            return sampler.fit_resample(X, y)

        if self.strategy == ImbalanceStrategy.RANDOM_UNDERSAMPLE:
            from imblearn.under_sampling import RandomUnderSampler
            sampler = RandomUnderSampler(sampling_strategy=self.sampling_strategy, random_state=self.random_state)
            return sampler.fit_resample(X, y)

        if self.strategy == ImbalanceStrategy.SMOTEENN:
            from imblearn.combine import SMOTEENN
            sampler = SMOTEENN(sampling_strategy=self.sampling_strategy, random_state=self.random_state)
            return sampler.fit_resample(X, y)

        if self.strategy == ImbalanceStrategy.ADASYN:
            from imblearn.over_sampling import ADASYN
            sampler = ADASYN(sampling_strategy=self.sampling_strategy, random_state=self.random_state)
            return sampler.fit_resample(X, y)

        return X.values, y.values
161
+
162
+
163
@dataclass
class ImbalanceRecommendation:
    """Recommendation for handling class imbalance.

    Attributes:
        severity: Severity bucket name ("low", "moderate", "high" or "severe").
        ratio: Majority-to-minority class count ratio.
        strategies: Ordered ``(ImbalanceStrategy, description)`` pairs, most
            preferred first.
        primary_strategy: The strategy highlighted when printing.
        explanation: Human-readable rationale.
    """

    severity: str  # "low", "moderate", "high", "severe"
    ratio: float
    strategies: list
    primary_strategy: ImbalanceStrategy
    explanation: str

    def print_recommendation(self):
        """Print a formatted summary of this recommendation to stdout."""
        badge = {"low": "🟢", "moderate": "🟡", "high": "🟠", "severe": "🔴"}.get(self.severity, '⚪')
        print(f"\n{badge} Class Imbalance: {self.severity.upper()} ({self.ratio:.1f}:1)")
        print(f"\n{self.explanation}")
        print("\nRecommended strategies (in order of preference):")
        for rank, entry in enumerate(self.strategies, start=1):
            strategy, desc = entry
            arrow = "→" if strategy == self.primary_strategy else " "
            print(f" {arrow} {rank}. {strategy.value}: {desc}")
180
+
181
+
182
class ImbalanceRecommender:
    """Recommends imbalance handling strategies based on data characteristics.

    The majority/minority ratio selects a severity bucket (see
    ``THRESHOLDS``); the minority-class size and total row count then refine
    which strategies are suggested.
    """

    THRESHOLDS = {"low": 3, "moderate": 10, "high": 20, "severe": float("inf")}

    # SMOTE-family samplers use k=5 nearest neighbours by default, so they
    # need at least k + 1 minority samples to run at all.
    MIN_SMOTE_MINORITY = 6

    # Row count above which discarding majority samples is considered affordable.
    LARGE_DATASET_ROWS = 100000

    STRATEGY_DESCRIPTIONS = {
        ImbalanceStrategy.CLASS_WEIGHT: "Adjust loss function weights (no data modification)",
        ImbalanceStrategy.SMOTE: "Generate synthetic minority samples using k-NN interpolation",
        ImbalanceStrategy.RANDOM_UNDERSAMPLE: "Randomly remove majority samples",
        ImbalanceStrategy.RANDOM_OVERSAMPLE: "Duplicate minority samples (risk of overfitting)",
        ImbalanceStrategy.SMOTEENN: "SMOTE + ENN cleaning (removes noisy samples)",
        ImbalanceStrategy.ADASYN: "Adaptive synthetic sampling (focuses on harder examples)",
    }

    def recommend(self, y: Series, n_samples: Optional[int] = None) -> ImbalanceRecommendation:
        """Recommend imbalance handling strategy based on class distribution.

        Args:
            y: Target labels.
            n_samples: Total row count to reason about. Defaults to ``len(y)``;
                a falsy value such as 0 also falls back to ``len(y)``.

        Returns:
            An ``ImbalanceRecommendation`` with the severity, ratio, ordered
            strategies, primary strategy and explanation.

        Raises:
            ValueError: If ``y`` contains no (non-missing) labels.
        """
        counts = y.value_counts().to_dict()
        # Ignore zero-count categories (possible for categorical y) so the
        # ratio does not divide by zero.
        observed = [n for n in counts.values() if n > 0]
        if not observed:
            raise ValueError("Cannot recommend an imbalance strategy: y contains no labels.")
        n_minority = min(observed)
        ratio = max(observed) / n_minority
        n_total = n_samples or len(y)

        severity = self._get_severity(ratio)
        strategies, primary, explanation = self._get_strategies(severity, ratio, n_minority, n_total)

        return ImbalanceRecommendation(
            severity=severity, ratio=ratio, strategies=strategies,
            primary_strategy=primary, explanation=explanation,
        )

    def _get_severity(self, ratio: float) -> str:
        """Map a majority/minority ratio to its severity bucket name."""
        for bucket in ("low", "moderate", "high"):
            if ratio < self.THRESHOLDS[bucket]:
                return bucket
        return "severe"

    def _describe(self, strategy: ImbalanceStrategy):
        """Pair ``strategy`` with its human-readable description."""
        return strategy, self.STRATEGY_DESCRIPTIONS[strategy]

    def _get_strategies(self, severity: str, ratio: float, n_minority: int, n_total: int):
        """Return ``(strategies, primary, explanation)`` for a severity bucket.

        ``strategies`` is an ordered list of ``(ImbalanceStrategy, description)``
        tuples, most preferred first. SMOTE-family strategies are only offered
        when the minority class has at least ``MIN_SMOTE_MINORITY`` samples.
        """
        smote_viable = n_minority >= self.MIN_SMOTE_MINORITY

        if severity == "low":
            explanation = f"Ratio {ratio:.1f}:1 is manageable. Class weights are usually sufficient."
            return [self._describe(ImbalanceStrategy.CLASS_WEIGHT)], ImbalanceStrategy.CLASS_WEIGHT, explanation

        if severity == "moderate":
            explanation = f"Ratio {ratio:.1f}:1 may affect model performance. Consider resampling if class weights aren't enough."
            strategies = [
                self._describe(ImbalanceStrategy.CLASS_WEIGHT),
                self._describe(ImbalanceStrategy.SMOTE),
            ]
            return strategies, ImbalanceStrategy.CLASS_WEIGHT, explanation

        if severity == "high":
            explanation = f"Ratio {ratio:.1f}:1 is significant. SMOTE recommended to create synthetic minority samples."
            if not smote_viable:
                explanation += f"\n⚠️ Only {n_minority} minority samples - SMOTE needs k=5 neighbors minimum."
                strategies = [
                    self._describe(ImbalanceStrategy.RANDOM_OVERSAMPLE),
                    self._describe(ImbalanceStrategy.CLASS_WEIGHT),
                ]
                return strategies, ImbalanceStrategy.RANDOM_OVERSAMPLE, explanation
            strategies = [
                self._describe(ImbalanceStrategy.SMOTE),
                self._describe(ImbalanceStrategy.SMOTEENN),
                self._describe(ImbalanceStrategy.CLASS_WEIGHT),
            ]
            return strategies, ImbalanceStrategy.SMOTE, explanation

        # severe
        explanation = f"Ratio {ratio:.1f}:1 is severe. Combination of techniques recommended."
        if n_total > self.LARGE_DATASET_ROWS:
            explanation += f"\nDataset is large ({n_total:,} rows) - undersampling majority is viable."
            strategies = [self._describe(ImbalanceStrategy.RANDOM_UNDERSAMPLE)]
            if smote_viable:
                strategies += [
                    self._describe(ImbalanceStrategy.SMOTE),
                    self._describe(ImbalanceStrategy.SMOTEENN),
                ]
            else:
                # SMOTE/SMOTEENN would fail with this few minority rows, so
                # offer class weights as the fallback instead.
                explanation += f"\n⚠️ Only {n_minority} minority samples - limited options."
                strategies.append(self._describe(ImbalanceStrategy.CLASS_WEIGHT))
            return strategies, ImbalanceStrategy.RANDOM_UNDERSAMPLE, explanation

        if not smote_viable:
            explanation += f"\n⚠️ Only {n_minority} minority samples - limited options."
            strategies = [
                self._describe(ImbalanceStrategy.RANDOM_OVERSAMPLE),
                self._describe(ImbalanceStrategy.CLASS_WEIGHT),
            ]
            return strategies, ImbalanceStrategy.RANDOM_OVERSAMPLE, explanation

        strategies = [
            self._describe(ImbalanceStrategy.SMOTE),
            self._describe(ImbalanceStrategy.ADASYN),
            self._describe(ImbalanceStrategy.SMOTEENN),
        ]
        return strategies, ImbalanceStrategy.SMOTE, explanation
@@ -0,0 +1,95 @@
1
+ """MLflow integration for experiment tracking."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, Optional
5
+
6
+ try:
7
+ import mlflow
8
+ import mlflow.sklearn
9
+ MLFLOW_AVAILABLE = True
10
+ except ImportError:
11
+ MLFLOW_AVAILABLE = False
12
+
13
+
14
@dataclass
class ExperimentConfig:
    """Settings that identify an MLflow experiment and run.

    Only ``experiment_name`` is required; the remaining settings default to
    None. Field order is part of the positional-constructor contract.
    """

    experiment_name: str
    run_name: Optional[str] = None           # display name for the run
    tracking_uri: Optional[str] = None       # MLflow tracking server / store URI
    artifact_location: Optional[str] = None  # where run artifacts are stored
20
+
21
+
22
class MLflowLogger:
    """Thin wrapper around the ``mlflow`` fluent API for one experiment.

    Every method is a no-op when ``mlflow`` could not be imported
    (``MLFLOW_AVAILABLE`` is False), so callers can log unconditionally.
    Used as a context manager, entering starts a run and exiting ends it,
    marking it FAILED if the ``with`` block raised.
    """

    def __init__(
        self,
        experiment_name: str,
        run_name: Optional[str] = None,
        tracking_uri: Optional[str] = None,
    ):
        """
        Args:
            experiment_name: Experiment to log into; created on first run if missing.
            run_name: Default run name, overridable per ``start_run`` call.
            tracking_uri: If given, set as the MLflow tracking URI before starting a run.
        """
        self.experiment_name = experiment_name
        self.run_name = run_name
        self.tracking_uri = tracking_uri
        self._run = None

    def __enter__(self):
        self.start_run()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Mirror mlflow.ActiveRun: a run interrupted by an exception is
        # recorded as FAILED instead of FINISHED. Returning False lets the
        # exception propagate to the caller.
        self.end_run(status="FINISHED" if exc_type is None else "FAILED")
        return False

    def start_run(self, run_name: Optional[str] = None):
        """Start a run in ``experiment_name``, creating the experiment if absent.

        Args:
            run_name: Overrides the run name given at construction.
        """
        if not MLFLOW_AVAILABLE:
            return

        if self.tracking_uri:
            mlflow.set_tracking_uri(self.tracking_uri)

        experiment = mlflow.get_experiment_by_name(self.experiment_name)
        if experiment is None:
            experiment_id = mlflow.create_experiment(self.experiment_name)
        else:
            experiment_id = experiment.experiment_id

        self._run = mlflow.start_run(
            experiment_id=experiment_id,
            run_name=run_name or self.run_name,
        )

    def end_run(self, status: str = "FINISHED"):
        """End the active MLflow run with ``status`` and forget the run handle.

        Args:
            status: Terminal run status passed to ``mlflow.end_run``
                (e.g. "FINISHED", "FAILED", "KILLED").
        """
        if MLFLOW_AVAILABLE:
            mlflow.end_run(status=status)
        self._run = None

    def log_params(self, params: Dict[str, Any]):
        """Log a batch of parameters to the active run."""
        if MLFLOW_AVAILABLE:
            mlflow.log_params(params)

    def log_metrics(self, metrics: Dict[str, float]):
        """Log a batch of metrics to the active run."""
        if MLFLOW_AVAILABLE:
            mlflow.log_metrics(metrics)

    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):
        """Upload the file at ``local_path``, optionally under ``artifact_path``."""
        if MLFLOW_AVAILABLE:
            mlflow.log_artifact(local_path, artifact_path)

    def set_tags(self, tags: Dict[str, str]):
        """Set tags on the active run."""
        if MLFLOW_AVAILABLE:
            mlflow.set_tags(tags)

    def log_dict(self, dictionary: Dict[str, Any], artifact_file: str):
        """Store ``dictionary`` as the run artifact ``artifact_file``."""
        if MLFLOW_AVAILABLE:
            mlflow.log_dict(dictionary, artifact_file)

    def log_model(self, model, artifact_path: str, registered_model_name: Optional[str] = None):
        """Log ``model`` via ``mlflow.sklearn.log_model``, optionally registering it."""
        if MLFLOW_AVAILABLE:
            mlflow.sklearn.log_model(
                model,
                artifact_path,
                registered_model_name=registered_model_name,
            )

    def log_figure(self, figure, artifact_file: str):
        """Store ``figure`` as the run artifact ``artifact_file``."""
        if MLFLOW_AVAILABLE:
            mlflow.log_figure(figure, artifact_file)
+ mlflow.log_figure(figure, artifact_file)