churnkit 0.75.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +647 -0
  2. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +1165 -0
  3. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +961 -0
  4. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +1690 -0
  5. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +679 -0
  6. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +3305 -0
  7. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +1463 -0
  8. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +1430 -0
  9. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +854 -0
  10. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +1639 -0
  11. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +1890 -0
  12. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +1457 -0
  13. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +1624 -0
  14. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +780 -0
  15. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +979 -0
  16. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +572 -0
  17. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +1179 -0
  18. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +1418 -0
  19. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +151 -0
  20. churnkit-0.75.0a1.dist-info/METADATA +229 -0
  21. churnkit-0.75.0a1.dist-info/RECORD +302 -0
  22. churnkit-0.75.0a1.dist-info/WHEEL +4 -0
  23. churnkit-0.75.0a1.dist-info/entry_points.txt +2 -0
  24. churnkit-0.75.0a1.dist-info/licenses/LICENSE +202 -0
  25. customer_retention/__init__.py +37 -0
  26. customer_retention/analysis/__init__.py +0 -0
  27. customer_retention/analysis/auto_explorer/__init__.py +62 -0
  28. customer_retention/analysis/auto_explorer/exploration_manager.py +470 -0
  29. customer_retention/analysis/auto_explorer/explorer.py +258 -0
  30. customer_retention/analysis/auto_explorer/findings.py +291 -0
  31. customer_retention/analysis/auto_explorer/layered_recommendations.py +485 -0
  32. customer_retention/analysis/auto_explorer/recommendation_builder.py +148 -0
  33. customer_retention/analysis/auto_explorer/recommendations.py +418 -0
  34. customer_retention/analysis/business/__init__.py +26 -0
  35. customer_retention/analysis/business/ab_test_designer.py +144 -0
  36. customer_retention/analysis/business/fairness_analyzer.py +166 -0
  37. customer_retention/analysis/business/intervention_matcher.py +121 -0
  38. customer_retention/analysis/business/report_generator.py +222 -0
  39. customer_retention/analysis/business/risk_profile.py +199 -0
  40. customer_retention/analysis/business/roi_analyzer.py +139 -0
  41. customer_retention/analysis/diagnostics/__init__.py +20 -0
  42. customer_retention/analysis/diagnostics/calibration_analyzer.py +133 -0
  43. customer_retention/analysis/diagnostics/cv_analyzer.py +144 -0
  44. customer_retention/analysis/diagnostics/error_analyzer.py +107 -0
  45. customer_retention/analysis/diagnostics/leakage_detector.py +394 -0
  46. customer_retention/analysis/diagnostics/noise_tester.py +140 -0
  47. customer_retention/analysis/diagnostics/overfitting_analyzer.py +190 -0
  48. customer_retention/analysis/diagnostics/segment_analyzer.py +122 -0
  49. customer_retention/analysis/discovery/__init__.py +8 -0
  50. customer_retention/analysis/discovery/config_generator.py +49 -0
  51. customer_retention/analysis/discovery/discovery_flow.py +19 -0
  52. customer_retention/analysis/discovery/type_inferencer.py +147 -0
  53. customer_retention/analysis/interpretability/__init__.py +13 -0
  54. customer_retention/analysis/interpretability/cohort_analyzer.py +185 -0
  55. customer_retention/analysis/interpretability/counterfactual.py +175 -0
  56. customer_retention/analysis/interpretability/individual_explainer.py +141 -0
  57. customer_retention/analysis/interpretability/pdp_generator.py +103 -0
  58. customer_retention/analysis/interpretability/shap_explainer.py +106 -0
  59. customer_retention/analysis/jupyter_save_hook.py +28 -0
  60. customer_retention/analysis/notebook_html_exporter.py +136 -0
  61. customer_retention/analysis/notebook_progress.py +60 -0
  62. customer_retention/analysis/plotly_preprocessor.py +154 -0
  63. customer_retention/analysis/recommendations/__init__.py +54 -0
  64. customer_retention/analysis/recommendations/base.py +158 -0
  65. customer_retention/analysis/recommendations/cleaning/__init__.py +11 -0
  66. customer_retention/analysis/recommendations/cleaning/consistency.py +107 -0
  67. customer_retention/analysis/recommendations/cleaning/deduplicate.py +94 -0
  68. customer_retention/analysis/recommendations/cleaning/impute.py +67 -0
  69. customer_retention/analysis/recommendations/cleaning/outlier.py +71 -0
  70. customer_retention/analysis/recommendations/datetime/__init__.py +3 -0
  71. customer_retention/analysis/recommendations/datetime/extract.py +149 -0
  72. customer_retention/analysis/recommendations/encoding/__init__.py +3 -0
  73. customer_retention/analysis/recommendations/encoding/categorical.py +114 -0
  74. customer_retention/analysis/recommendations/pipeline.py +74 -0
  75. customer_retention/analysis/recommendations/registry.py +76 -0
  76. customer_retention/analysis/recommendations/selection/__init__.py +3 -0
  77. customer_retention/analysis/recommendations/selection/drop_column.py +56 -0
  78. customer_retention/analysis/recommendations/transform/__init__.py +4 -0
  79. customer_retention/analysis/recommendations/transform/power.py +94 -0
  80. customer_retention/analysis/recommendations/transform/scale.py +112 -0
  81. customer_retention/analysis/visualization/__init__.py +15 -0
  82. customer_retention/analysis/visualization/chart_builder.py +2619 -0
  83. customer_retention/analysis/visualization/console.py +122 -0
  84. customer_retention/analysis/visualization/display.py +171 -0
  85. customer_retention/analysis/visualization/number_formatter.py +36 -0
  86. customer_retention/artifacts/__init__.py +3 -0
  87. customer_retention/artifacts/fit_artifact_registry.py +146 -0
  88. customer_retention/cli.py +93 -0
  89. customer_retention/core/__init__.py +0 -0
  90. customer_retention/core/compat/__init__.py +193 -0
  91. customer_retention/core/compat/detection.py +99 -0
  92. customer_retention/core/compat/ops.py +48 -0
  93. customer_retention/core/compat/pandas_backend.py +57 -0
  94. customer_retention/core/compat/spark_backend.py +75 -0
  95. customer_retention/core/components/__init__.py +11 -0
  96. customer_retention/core/components/base.py +79 -0
  97. customer_retention/core/components/components/__init__.py +13 -0
  98. customer_retention/core/components/components/deployer.py +26 -0
  99. customer_retention/core/components/components/explainer.py +26 -0
  100. customer_retention/core/components/components/feature_eng.py +33 -0
  101. customer_retention/core/components/components/ingester.py +34 -0
  102. customer_retention/core/components/components/profiler.py +34 -0
  103. customer_retention/core/components/components/trainer.py +38 -0
  104. customer_retention/core/components/components/transformer.py +36 -0
  105. customer_retention/core/components/components/validator.py +37 -0
  106. customer_retention/core/components/enums.py +33 -0
  107. customer_retention/core/components/orchestrator.py +94 -0
  108. customer_retention/core/components/registry.py +59 -0
  109. customer_retention/core/config/__init__.py +39 -0
  110. customer_retention/core/config/column_config.py +95 -0
  111. customer_retention/core/config/experiments.py +71 -0
  112. customer_retention/core/config/pipeline_config.py +117 -0
  113. customer_retention/core/config/source_config.py +83 -0
  114. customer_retention/core/utils/__init__.py +28 -0
  115. customer_retention/core/utils/leakage.py +85 -0
  116. customer_retention/core/utils/severity.py +53 -0
  117. customer_retention/core/utils/statistics.py +90 -0
  118. customer_retention/generators/__init__.py +0 -0
  119. customer_retention/generators/notebook_generator/__init__.py +167 -0
  120. customer_retention/generators/notebook_generator/base.py +55 -0
  121. customer_retention/generators/notebook_generator/cell_builder.py +49 -0
  122. customer_retention/generators/notebook_generator/config.py +47 -0
  123. customer_retention/generators/notebook_generator/databricks_generator.py +48 -0
  124. customer_retention/generators/notebook_generator/local_generator.py +48 -0
  125. customer_retention/generators/notebook_generator/project_init.py +174 -0
  126. customer_retention/generators/notebook_generator/runner.py +150 -0
  127. customer_retention/generators/notebook_generator/script_generator.py +110 -0
  128. customer_retention/generators/notebook_generator/stages/__init__.py +19 -0
  129. customer_retention/generators/notebook_generator/stages/base_stage.py +86 -0
  130. customer_retention/generators/notebook_generator/stages/s01_ingestion.py +100 -0
  131. customer_retention/generators/notebook_generator/stages/s02_profiling.py +95 -0
  132. customer_retention/generators/notebook_generator/stages/s03_cleaning.py +180 -0
  133. customer_retention/generators/notebook_generator/stages/s04_transformation.py +165 -0
  134. customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py +115 -0
  135. customer_retention/generators/notebook_generator/stages/s06_feature_selection.py +97 -0
  136. customer_retention/generators/notebook_generator/stages/s07_model_training.py +176 -0
  137. customer_retention/generators/notebook_generator/stages/s08_deployment.py +81 -0
  138. customer_retention/generators/notebook_generator/stages/s09_monitoring.py +112 -0
  139. customer_retention/generators/notebook_generator/stages/s10_batch_inference.py +642 -0
  140. customer_retention/generators/notebook_generator/stages/s11_feature_store.py +348 -0
  141. customer_retention/generators/orchestration/__init__.py +23 -0
  142. customer_retention/generators/orchestration/code_generator.py +196 -0
  143. customer_retention/generators/orchestration/context.py +147 -0
  144. customer_retention/generators/orchestration/data_materializer.py +188 -0
  145. customer_retention/generators/orchestration/databricks_exporter.py +411 -0
  146. customer_retention/generators/orchestration/doc_generator.py +311 -0
  147. customer_retention/generators/pipeline_generator/__init__.py +26 -0
  148. customer_retention/generators/pipeline_generator/findings_parser.py +727 -0
  149. customer_retention/generators/pipeline_generator/generator.py +142 -0
  150. customer_retention/generators/pipeline_generator/models.py +166 -0
  151. customer_retention/generators/pipeline_generator/renderer.py +2125 -0
  152. customer_retention/generators/spec_generator/__init__.py +37 -0
  153. customer_retention/generators/spec_generator/databricks_generator.py +433 -0
  154. customer_retention/generators/spec_generator/generic_generator.py +373 -0
  155. customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +685 -0
  156. customer_retention/generators/spec_generator/pipeline_spec.py +298 -0
  157. customer_retention/integrations/__init__.py +0 -0
  158. customer_retention/integrations/adapters/__init__.py +13 -0
  159. customer_retention/integrations/adapters/base.py +10 -0
  160. customer_retention/integrations/adapters/factory.py +25 -0
  161. customer_retention/integrations/adapters/feature_store/__init__.py +6 -0
  162. customer_retention/integrations/adapters/feature_store/base.py +57 -0
  163. customer_retention/integrations/adapters/feature_store/databricks.py +94 -0
  164. customer_retention/integrations/adapters/feature_store/feast_adapter.py +97 -0
  165. customer_retention/integrations/adapters/feature_store/local.py +75 -0
  166. customer_retention/integrations/adapters/mlflow/__init__.py +6 -0
  167. customer_retention/integrations/adapters/mlflow/base.py +32 -0
  168. customer_retention/integrations/adapters/mlflow/databricks.py +54 -0
  169. customer_retention/integrations/adapters/mlflow/experiment_tracker.py +161 -0
  170. customer_retention/integrations/adapters/mlflow/local.py +50 -0
  171. customer_retention/integrations/adapters/storage/__init__.py +5 -0
  172. customer_retention/integrations/adapters/storage/base.py +33 -0
  173. customer_retention/integrations/adapters/storage/databricks.py +76 -0
  174. customer_retention/integrations/adapters/storage/local.py +59 -0
  175. customer_retention/integrations/feature_store/__init__.py +47 -0
  176. customer_retention/integrations/feature_store/definitions.py +215 -0
  177. customer_retention/integrations/feature_store/manager.py +744 -0
  178. customer_retention/integrations/feature_store/registry.py +412 -0
  179. customer_retention/integrations/iteration/__init__.py +28 -0
  180. customer_retention/integrations/iteration/context.py +212 -0
  181. customer_retention/integrations/iteration/feedback_collector.py +184 -0
  182. customer_retention/integrations/iteration/orchestrator.py +168 -0
  183. customer_retention/integrations/iteration/recommendation_tracker.py +341 -0
  184. customer_retention/integrations/iteration/signals.py +212 -0
  185. customer_retention/integrations/llm_context/__init__.py +4 -0
  186. customer_retention/integrations/llm_context/context_builder.py +201 -0
  187. customer_retention/integrations/llm_context/prompts.py +100 -0
  188. customer_retention/integrations/streaming/__init__.py +103 -0
  189. customer_retention/integrations/streaming/batch_integration.py +149 -0
  190. customer_retention/integrations/streaming/early_warning_model.py +227 -0
  191. customer_retention/integrations/streaming/event_schema.py +214 -0
  192. customer_retention/integrations/streaming/online_store_writer.py +249 -0
  193. customer_retention/integrations/streaming/realtime_scorer.py +261 -0
  194. customer_retention/integrations/streaming/trigger_engine.py +293 -0
  195. customer_retention/integrations/streaming/window_aggregator.py +393 -0
  196. customer_retention/stages/__init__.py +0 -0
  197. customer_retention/stages/cleaning/__init__.py +9 -0
  198. customer_retention/stages/cleaning/base.py +28 -0
  199. customer_retention/stages/cleaning/missing_handler.py +160 -0
  200. customer_retention/stages/cleaning/outlier_handler.py +204 -0
  201. customer_retention/stages/deployment/__init__.py +28 -0
  202. customer_retention/stages/deployment/batch_scorer.py +106 -0
  203. customer_retention/stages/deployment/champion_challenger.py +299 -0
  204. customer_retention/stages/deployment/model_registry.py +182 -0
  205. customer_retention/stages/deployment/retraining_trigger.py +245 -0
  206. customer_retention/stages/features/__init__.py +73 -0
  207. customer_retention/stages/features/behavioral_features.py +266 -0
  208. customer_retention/stages/features/customer_segmentation.py +505 -0
  209. customer_retention/stages/features/feature_definitions.py +265 -0
  210. customer_retention/stages/features/feature_engineer.py +551 -0
  211. customer_retention/stages/features/feature_manifest.py +340 -0
  212. customer_retention/stages/features/feature_selector.py +239 -0
  213. customer_retention/stages/features/interaction_features.py +160 -0
  214. customer_retention/stages/features/temporal_features.py +243 -0
  215. customer_retention/stages/ingestion/__init__.py +9 -0
  216. customer_retention/stages/ingestion/load_result.py +32 -0
  217. customer_retention/stages/ingestion/loaders.py +195 -0
  218. customer_retention/stages/ingestion/source_registry.py +130 -0
  219. customer_retention/stages/modeling/__init__.py +31 -0
  220. customer_retention/stages/modeling/baseline_trainer.py +139 -0
  221. customer_retention/stages/modeling/cross_validator.py +125 -0
  222. customer_retention/stages/modeling/data_splitter.py +205 -0
  223. customer_retention/stages/modeling/feature_scaler.py +99 -0
  224. customer_retention/stages/modeling/hyperparameter_tuner.py +107 -0
  225. customer_retention/stages/modeling/imbalance_handler.py +282 -0
  226. customer_retention/stages/modeling/mlflow_logger.py +95 -0
  227. customer_retention/stages/modeling/model_comparator.py +149 -0
  228. customer_retention/stages/modeling/model_evaluator.py +138 -0
  229. customer_retention/stages/modeling/threshold_optimizer.py +131 -0
  230. customer_retention/stages/monitoring/__init__.py +37 -0
  231. customer_retention/stages/monitoring/alert_manager.py +328 -0
  232. customer_retention/stages/monitoring/drift_detector.py +201 -0
  233. customer_retention/stages/monitoring/performance_monitor.py +242 -0
  234. customer_retention/stages/preprocessing/__init__.py +5 -0
  235. customer_retention/stages/preprocessing/transformer_manager.py +284 -0
  236. customer_retention/stages/profiling/__init__.py +256 -0
  237. customer_retention/stages/profiling/categorical_distribution.py +269 -0
  238. customer_retention/stages/profiling/categorical_target_analyzer.py +274 -0
  239. customer_retention/stages/profiling/column_profiler.py +527 -0
  240. customer_retention/stages/profiling/distribution_analysis.py +483 -0
  241. customer_retention/stages/profiling/drift_detector.py +310 -0
  242. customer_retention/stages/profiling/feature_capacity.py +507 -0
  243. customer_retention/stages/profiling/pattern_analysis_config.py +513 -0
  244. customer_retention/stages/profiling/profile_result.py +212 -0
  245. customer_retention/stages/profiling/quality_checks.py +1632 -0
  246. customer_retention/stages/profiling/relationship_detector.py +256 -0
  247. customer_retention/stages/profiling/relationship_recommender.py +454 -0
  248. customer_retention/stages/profiling/report_generator.py +520 -0
  249. customer_retention/stages/profiling/scd_analyzer.py +151 -0
  250. customer_retention/stages/profiling/segment_analyzer.py +632 -0
  251. customer_retention/stages/profiling/segment_aware_outlier.py +265 -0
  252. customer_retention/stages/profiling/target_level_analyzer.py +217 -0
  253. customer_retention/stages/profiling/temporal_analyzer.py +388 -0
  254. customer_retention/stages/profiling/temporal_coverage.py +488 -0
  255. customer_retention/stages/profiling/temporal_feature_analyzer.py +692 -0
  256. customer_retention/stages/profiling/temporal_feature_engineer.py +703 -0
  257. customer_retention/stages/profiling/temporal_pattern_analyzer.py +636 -0
  258. customer_retention/stages/profiling/temporal_quality_checks.py +278 -0
  259. customer_retention/stages/profiling/temporal_target_analyzer.py +241 -0
  260. customer_retention/stages/profiling/text_embedder.py +87 -0
  261. customer_retention/stages/profiling/text_processor.py +115 -0
  262. customer_retention/stages/profiling/text_reducer.py +60 -0
  263. customer_retention/stages/profiling/time_series_profiler.py +303 -0
  264. customer_retention/stages/profiling/time_window_aggregator.py +376 -0
  265. customer_retention/stages/profiling/type_detector.py +382 -0
  266. customer_retention/stages/profiling/window_recommendation.py +288 -0
  267. customer_retention/stages/temporal/__init__.py +166 -0
  268. customer_retention/stages/temporal/access_guard.py +180 -0
  269. customer_retention/stages/temporal/cutoff_analyzer.py +235 -0
  270. customer_retention/stages/temporal/data_preparer.py +178 -0
  271. customer_retention/stages/temporal/point_in_time_join.py +134 -0
  272. customer_retention/stages/temporal/point_in_time_registry.py +148 -0
  273. customer_retention/stages/temporal/scenario_detector.py +163 -0
  274. customer_retention/stages/temporal/snapshot_manager.py +259 -0
  275. customer_retention/stages/temporal/synthetic_coordinator.py +66 -0
  276. customer_retention/stages/temporal/timestamp_discovery.py +531 -0
  277. customer_retention/stages/temporal/timestamp_manager.py +255 -0
  278. customer_retention/stages/transformation/__init__.py +13 -0
  279. customer_retention/stages/transformation/binary_handler.py +85 -0
  280. customer_retention/stages/transformation/categorical_encoder.py +245 -0
  281. customer_retention/stages/transformation/datetime_transformer.py +97 -0
  282. customer_retention/stages/transformation/numeric_transformer.py +181 -0
  283. customer_retention/stages/transformation/pipeline.py +257 -0
  284. customer_retention/stages/validation/__init__.py +60 -0
  285. customer_retention/stages/validation/adversarial_scoring_validator.py +205 -0
  286. customer_retention/stages/validation/business_sense_gate.py +173 -0
  287. customer_retention/stages/validation/data_quality_gate.py +235 -0
  288. customer_retention/stages/validation/data_validators.py +511 -0
  289. customer_retention/stages/validation/feature_quality_gate.py +183 -0
  290. customer_retention/stages/validation/gates.py +117 -0
  291. customer_retention/stages/validation/leakage_gate.py +352 -0
  292. customer_retention/stages/validation/model_validity_gate.py +213 -0
  293. customer_retention/stages/validation/pipeline_validation_runner.py +264 -0
  294. customer_retention/stages/validation/quality_scorer.py +544 -0
  295. customer_retention/stages/validation/rule_generator.py +57 -0
  296. customer_retention/stages/validation/scoring_pipeline_validator.py +446 -0
  297. customer_retention/stages/validation/timeseries_detector.py +769 -0
  298. customer_retention/transforms/__init__.py +47 -0
  299. customer_retention/transforms/artifact_store.py +50 -0
  300. customer_retention/transforms/executor.py +157 -0
  301. customer_retention/transforms/fitted.py +92 -0
  302. customer_retention/transforms/ops.py +148 -0
customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py
@@ -0,0 +1,115 @@
+from typing import List
+
+import nbformat
+
+from ..base import NotebookStage
+from .base_stage import StageGenerator
+
+
+class FeatureEngineeringStage(StageGenerator):
+    @property
+    def stage(self) -> NotebookStage:
+        return NotebookStage.FEATURE_ENGINEERING
+
+    @property
+    def title(self) -> str:
+        return "05 - Feature Engineering"
+
+    @property
+    def description(self) -> str:
+        return "Create derived features, interactions, and aggregations."
+
+    def generate_local_cells(self) -> List[nbformat.NotebookNode]:
+        return self.header_cells() + [
+            self.cb.section("Imports"),
+            self.cb.from_imports_cell({
+                "customer_retention.stages.features": ["FeatureEngineer", "FeatureEngineerConfig"],
+                "customer_retention.stages.features.temporal_features": ["TemporalFeatureGenerator", "ReferenceDateSource"],
+                "customer_retention.stages.temporal": ["PointInTimeJoiner", "SnapshotManager"],
+                "pathlib": ["Path"],
+                "pandas": ["pd"],
+                "numpy": ["np"],
+            }),
+            self.cb.section("Load Latest Training Snapshot"),
+            self.cb.code('''snapshot_manager = SnapshotManager(Path("./experiments/data"))
+latest_snapshot = snapshot_manager.get_latest_snapshot()
+if latest_snapshot:
+    df, metadata = snapshot_manager.load_snapshot(latest_snapshot)
+    print(f"Loaded snapshot: {latest_snapshot}")
+    print(f"Rows: {len(df)}, Features: {len(df.columns)}")
+else:
+    from customer_retention.integrations.adapters.factory import get_delta
+    storage = get_delta(force_local=True)
+    df = storage.read("./experiments/data/silver/customers_transformed")
+    print(f"No snapshot found, loaded transformed data: {df.shape}")'''),
+            self.cb.section("Point-in-Time Feature Engineering"),
+            self.cb.markdown('''**Important**: All temporal features are calculated relative to `feature_timestamp` to prevent data leakage.'''),
+            self.cb.code('''if "feature_timestamp" in df.columns:
+    temporal_gen = TemporalFeatureGenerator(
+        reference_date_source=ReferenceDateSource.FEATURE_TIMESTAMP,
+        created_column="signup_date" if "signup_date" in df.columns else None,
+        last_order_column="last_activity" if "last_activity" in df.columns else None,
+    )
+    df = temporal_gen.fit_transform(df)
+    print(f"Created temporal features: {temporal_gen.generated_features}")
+else:
+    print("Warning: No feature_timestamp column found. Using current date (may cause leakage).")
+    if "signup_date" in df.columns:
+        df["tenure_days"] = (pd.Timestamp.now() - pd.to_datetime(df["signup_date"])).dt.days'''),
+            self.cb.section("Validate Point-in-Time Correctness"),
+            self.cb.code('''if "feature_timestamp" in df.columns:
+    pit_report = PointInTimeJoiner.validate_temporal_integrity(df)
+    if pit_report["valid"]:
+        print("Point-in-time validation PASSED")
+    else:
+        print("Point-in-time validation FAILED:")
+        for issue in pit_report["issues"]:
+            print(f" - {issue['type']}: {issue['message']}")'''),
+            self.cb.section("Create Interaction Features"),
+            self.cb.code('''numeric_cols = [c for c in df.select_dtypes(include=[np.number]).columns
+                if c not in ["target", "entity_id"]]
+if len(numeric_cols) >= 2:
+    for i, col1 in enumerate(numeric_cols[:3]):
+        for col2 in numeric_cols[i+1:4]:
+            df[f"{col1}_x_{col2}"] = df[col1] * df[col2]
+    print("Created interaction features")'''),
+            self.cb.section("Create Ratio Features"),
+            self.cb.code('''if "total_spend" in df.columns and "num_transactions" in df.columns:
+    df["avg_transaction_value"] = df["total_spend"] / (df["num_transactions"] + 1)
+    print("Created avg_transaction_value feature")'''),
+            self.cb.section("Save to Gold Layer"),
+            self.cb.code('''from customer_retention.integrations.adapters.factory import get_delta
+storage = get_delta(force_local=True)
+storage.write(df, "./experiments/data/gold/customers_features")
+print(f"Gold layer saved: {df.shape}")'''),
+        ]
+
+    def generate_databricks_cells(self) -> List[nbformat.NotebookNode]:
+        catalog = self.config.feature_store.catalog
+        schema = self.config.feature_store.schema
+        return self.header_cells() + [
+            self.cb.section("Load Transformed Data"),
+            self.cb.code(f'''df = spark.table("{catalog}.{schema}.silver_transformed")'''),
+            self.cb.section("Create Derived Features"),
+            self.cb.code('''from pyspark.sql.functions import datediff, current_date, col
+
+if "signup_date" in df.columns:
+    df = df.withColumn("tenure_days", datediff(current_date(), col("signup_date")))
+    print("Created tenure_days feature")
+
+if "last_activity" in df.columns:
+    df = df.withColumn("recency_days", datediff(current_date(), col("last_activity")))
+    print("Created recency_days feature")'''),
+            self.cb.section("Create Interaction Features"),
+            self.cb.code('''numeric_cols = [f.name for f in df.schema.fields if str(f.dataType) in ["IntegerType()", "DoubleType()", "FloatType()"]]
+if len(numeric_cols) >= 2:
+    df = df.withColumn(f"{numeric_cols[0]}_x_{numeric_cols[1]}", col(numeric_cols[0]) * col(numeric_cols[1]))
+    print("Created interaction features")'''),
+            self.cb.section("Create Ratio Features"),
+            self.cb.code('''if "total_spend" in df.columns and "num_transactions" in df.columns:
+    df = df.withColumn("avg_transaction_value", col("total_spend") / (col("num_transactions") + 1))
+    print("Created avg_transaction_value feature")'''),
+            self.cb.section("Save to Gold Table"),
+            self.cb.code(f'''df.write.format("delta").mode("overwrite").saveAsTable("{catalog}.{schema}.gold_customers")
+print("Gold table created")'''),
+        ]
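
The markdown cell this stage emits makes the key point: every temporal feature is computed against the row's own `feature_timestamp`, not against the wall clock. A minimal standalone sketch of that idea in plain pandas (no churnkit APIs; the column names are illustrative, borrowed from the generated cells above):

import pandas as pd

# Toy customer table: each row carries the timestamp at which its
# features are allowed to be "known" (the point-in-time cutoff).
df = pd.DataFrame({
    "entity_id": [1, 2],
    "signup_date": pd.to_datetime(["2023-01-01", "2023-06-15"]),
    "last_activity": pd.to_datetime(["2023-11-20", "2023-12-01"]),
    "feature_timestamp": pd.to_datetime(["2024-01-01", "2024-01-01"]),
})

# Leakage-safe: tenure and recency are measured at the cutoff, so a row
# always yields the same feature values no matter when the notebook re-runs.
df["tenure_days"] = (df["feature_timestamp"] - df["signup_date"]).dt.days
df["recency_days"] = (df["feature_timestamp"] - df["last_activity"]).dt.days

# Leakage-prone fallback (what the stage warns about): anchoring to the
# wall clock makes feature values drift with every re-run.
df["tenure_days_unsafe"] = (pd.Timestamp.now() - df["signup_date"]).dt.days
print(df[["entity_id", "tenure_days", "recency_days"]])
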
customer_retention/generators/notebook_generator/stages/s06_feature_selection.py
@@ -0,0 +1,97 @@
+from typing import List
+
+import nbformat
+
+from ..base import NotebookStage
+from .base_stage import StageGenerator
+
+
+class FeatureSelectionStage(StageGenerator):
+    @property
+    def stage(self) -> NotebookStage:
+        return NotebookStage.FEATURE_SELECTION
+
+    @property
+    def title(self) -> str:
+        return "06 - Feature Selection"
+
+    @property
+    def description(self) -> str:
+        return "Select best features using variance, correlation, and importance filters."
+
+    def generate_local_cells(self) -> List[nbformat.NotebookNode]:
+        target = self.get_target_column()
+        var_thresh = self.config.variance_threshold
+        corr_thresh = self.config.correlation_threshold
+        return self.header_cells() + [
+            self.cb.section("Imports"),
+            self.cb.from_imports_cell({
+                "customer_retention.stages.features": ["FeatureSelector"],
+                "pandas": ["pd"],
+                "numpy": ["np"],
+            }),
+            self.cb.section("Load Gold Data"),
+            self.cb.code('''df = pd.read_parquet("./experiments/data/gold/customers_features.parquet")
+print(f"Input shape: {df.shape}")'''),
+            self.cb.section("Identify Feature Columns"),
+            self.cb.code(f'''target_col = "{target}"
+id_cols = {self.get_identifier_columns()}
+feature_cols = [c for c in df.columns if c not in id_cols + [target_col]]
+X = df[feature_cols]
+y = df[target_col] if target_col in df.columns else None
+print(f"Feature columns: {{len(feature_cols)}}")'''),
+            self.cb.section("Variance Filter"),
+            self.cb.code(f'''variance_threshold = {var_thresh}
+variances = X.var()
+low_variance = variances[variances < variance_threshold].index.tolist()
+print(f"Low variance features ({{len(low_variance)}}): {{low_variance[:5]}}")
+X = X.drop(columns=low_variance)'''),
+            self.cb.section("Correlation Filter"),
+            self.cb.code(f'''correlation_threshold = {corr_thresh}
+corr_matrix = X.corr().abs()
+upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
+high_corr = [c for c in upper.columns if any(upper[c] > correlation_threshold)]
+print(f"High correlation features ({{len(high_corr)}}): {{high_corr[:5]}}")
+X = X.drop(columns=high_corr)'''),
+            self.cb.section("Save Selected Features"),
+            self.cb.code('''selected_df = df[[*id_cols, *X.columns, target_col]].dropna(subset=[target_col])
+selected_df.to_parquet("./experiments/data/gold/customers_selected.parquet", index=False)
+print(f"Selected {len(X.columns)} features, saved {len(selected_df)} rows")'''),
+        ]
+
+    def generate_databricks_cells(self) -> List[nbformat.NotebookNode]:
+        catalog = self.config.feature_store.catalog
+        schema = self.config.feature_store.schema
+        target = self.get_target_column()
+        return self.header_cells() + [
+            self.cb.section("Load Gold Data"),
+            self.cb.code(f'''df = spark.table("{catalog}.{schema}.gold_customers")'''),
+            self.cb.section("Compute Feature Correlations"),
+            self.cb.code(f'''from pyspark.ml.stat import Correlation
+from pyspark.ml.feature import VectorAssembler
+
+target_col = "{target}"
+numeric_cols = [f.name for f in df.schema.fields if str(f.dataType) in ["IntegerType()", "DoubleType()", "FloatType()"] and f.name != target_col]
+
+assembler = VectorAssembler(inputCols=numeric_cols, outputCol="features", handleInvalid="skip")
+df_vec = assembler.transform(df)
+corr_matrix = Correlation.corr(df_vec, "features").head()[0].toArray()
+print(f"Correlation matrix shape: {{corr_matrix.shape}}")'''),
+            self.cb.section("Remove Highly Correlated Features"),
+            self.cb.code(f'''import numpy as np
+
+correlation_threshold = {self.config.correlation_threshold}
+to_drop = set()
+for i in range(len(corr_matrix)):
+    for j in range(i+1, len(corr_matrix)):
+        if abs(corr_matrix[i,j]) > correlation_threshold:
+            to_drop.add(numeric_cols[j])
+
+selected_cols = [c for c in numeric_cols if c not in to_drop]
+print(f"Dropped {{len(to_drop)}} highly correlated features, keeping {{len(selected_cols)}}")'''),
+            self.cb.section("Save Selected Features"),
+            self.cb.code(f'''final_cols = {self.get_identifier_columns()} + selected_cols + [target_col]
+df_selected = df.select(final_cols)
+df_selected.write.format("delta").mode("overwrite").saveAsTable("{catalog}.{schema}.gold_selected")
+print("Selected features saved")'''),
+        ]
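
The two filters this stage generates reduce to a few lines of pandas. A self-contained sketch on synthetic data (the thresholds here are illustrative; churnkit reads them from `config.variance_threshold` and `config.correlation_threshold`):

import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
X = pd.DataFrame({
    "spend": rng.normal(100, 30, 500),
    "visits": rng.poisson(5, 500).astype(float),
    "constant_flag": np.zeros(500),  # zero variance -> dropped
})
X["spend_usd"] = X["spend"] + rng.normal(0, 0.1, 500)  # ~perfectly correlated -> dropped

# Variance filter: drop features that barely vary.
low_var = X.var()[X.var() < 0.01].index.tolist()
X = X.drop(columns=low_var)

# Correlation filter: scan the upper triangle so each correlated pair is
# seen once, and drop the later column of the pair.
corr = X.corr().abs()
upper = corr.where(np.triu(np.ones(corr.shape), k=1).astype(bool))
high_corr = [c for c in upper.columns if (upper[c] > 0.95).any()]
X = X.drop(columns=high_corr)

print(f"Dropped {low_var + high_corr}, kept {list(X.columns)}")
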
customer_retention/generators/notebook_generator/stages/s07_model_training.py
@@ -0,0 +1,176 @@
+from typing import List
+
+import nbformat
+
+from ..base import NotebookStage
+from .base_stage import StageGenerator
+
+
+class ModelTrainingStage(StageGenerator):
+    @property
+    def stage(self) -> NotebookStage:
+        return NotebookStage.MODEL_TRAINING
+
+    @property
+    def title(self) -> str:
+        return "07 - Model Training"
+
+    @property
+    def description(self) -> str:
+        return "Train baseline models with MLflow experiment tracking."
+
+    def generate_local_cells(self) -> List[nbformat.NotebookNode]:
+        target = self.get_target_column()
+        test_size = self.config.test_size
+        exp_name = self.config.mlflow.experiment_name
+        tracking_uri = self.config.mlflow.tracking_uri
+        return self.header_cells() + [
+            self.cb.section("Imports"),
+            self.cb.from_imports_cell({
+                "customer_retention.stages.modeling": ["BaselineTrainer", "ModelEvaluator", "DataSplitter"],
+                "customer_retention.integrations.adapters": ["get_mlflow"],
+                "customer_retention.analysis.visualization": ["ChartBuilder"],
+                "customer_retention.stages.temporal": ["SnapshotManager"],
+                "customer_retention.analysis.diagnostics": ["LeakageDetector"],
+                "pathlib": ["Path"],
+                "pandas": ["pd"],
+            }),
+            self.cb.section("Load Training Snapshot"),
+            self.cb.markdown('''**Important**: We load from a versioned snapshot to ensure reproducibility and prevent data leakage.'''),
+            self.cb.code('''snapshot_manager = SnapshotManager(Path("./experiments/data"))
+latest_snapshot = snapshot_manager.get_latest_snapshot()
+
+if latest_snapshot:
+    df, snapshot_metadata = snapshot_manager.load_snapshot(latest_snapshot)
+    print(f"Loaded snapshot: {latest_snapshot}")
+    print(f"Snapshot cutoff date: {snapshot_metadata.cutoff_date}")
+    print(f"Data hash: {snapshot_metadata.data_hash}")
+    print(f"Rows: {snapshot_metadata.row_count}")
+else:
+    from customer_retention.integrations.adapters.factory import get_delta
+    storage = get_delta(force_local=True)
+    df = storage.read("./experiments/data/gold/customers_selected")
+    snapshot_metadata = None
+    print(f"Warning: No snapshot found, loading from gold layer: {df.shape}")'''),
+            self.cb.section("Prepare Train/Test Split"),
+            self.cb.code(f'''target_col = "target" if "target" in df.columns else "{target}"
+id_cols = ["entity_id"] if "entity_id" in df.columns else {self.get_identifier_columns()}
+temporal_cols = ["feature_timestamp", "label_timestamp", "label_available_flag"]
+exclude_cols = id_cols + [target_col] + temporal_cols
+
+feature_cols = [c for c in df.columns if c not in exclude_cols]
+print(f"Using {{len(feature_cols)}} features (excluded: {{exclude_cols}})")
+
+X = df[feature_cols]
+y = df[target_col]
+
+splitter = DataSplitter(test_size={test_size}, stratify=True, random_state=42)
+X_train, X_test, y_train, y_test = splitter.split(X, y)
+print(f"Train: {{len(X_train)}}, Test: {{len(X_test)}}")'''),
+            self.cb.section("Run Leakage Detection"),
+            self.cb.code('''detector = LeakageDetector()
+leakage_result = detector.run_all_checks(X_train, y_train)
+
+if not leakage_result.passed:
+    print("WARNING: Leakage detected!")
+    for issue in leakage_result.critical_issues:
+        print(f"  CRITICAL: {issue.feature} - {issue.recommendation}")
+else:
+    print("Leakage check PASSED")'''),
+            self.cb.section("Setup MLflow Tracking"),
+            self.cb.code(f'''mlflow_adapter = get_mlflow(tracking_uri="{tracking_uri}", force_local=True)
+experiment_name = "{exp_name}"
+print(f"MLflow tracking URI: {tracking_uri}")
+
+snapshot_params = {{}}
+if snapshot_metadata:
+    snapshot_params = {{
+        "snapshot_id": snapshot_metadata.snapshot_id,
+        "snapshot_version": snapshot_metadata.version,
+        "snapshot_cutoff": str(snapshot_metadata.cutoff_date),
+        "snapshot_hash": snapshot_metadata.data_hash,
+    }}'''),
+            self.cb.section("Train Baseline Models"),
+            self.cb.code('''from sklearn.linear_model import LogisticRegression
+from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
+
+models = {
+    "logistic_regression": LogisticRegression(class_weight="balanced", max_iter=1000),
+    "random_forest": RandomForestClassifier(class_weight="balanced", n_estimators=100, random_state=42),
+    "gradient_boosting": GradientBoostingClassifier(n_estimators=100, random_state=42),
+}
+
+results = {}
+for name, model in models.items():
+    mlflow_adapter.start_run(experiment_name, run_name=name)
+    model.fit(X_train, y_train)
+    y_pred = model.predict(X_test)
+    y_prob = model.predict_proba(X_test)[:, 1] if hasattr(model, "predict_proba") else y_pred
+
+    evaluator = ModelEvaluator()
+    metrics = evaluator.evaluate(y_test, y_pred, y_prob)
+    results[name] = {"model": model, "metrics": metrics, "y_pred": y_pred, "y_prob": y_prob}
+
+    all_params = {**model.get_params(), **snapshot_params}
+    mlflow_adapter.log_params(all_params)
+    mlflow_adapter.log_metrics(metrics)
+    mlflow_adapter.log_model(model, "model")
+    mlflow_adapter.end_run()
+    print(f"{name}: AUC={metrics.get('roc_auc', 0):.4f}, F1={metrics.get('f1', 0):.4f}")'''),
+            self.cb.section("Compare Models"),
+            self.cb.code('''charts = ChartBuilder()
+fig = charts.model_comparison_grid(results, y_test)
+fig.show()'''),
+            self.cb.section("Save Best Model"),
+            self.cb.code('''best_model_name = max(results, key=lambda k: results[k]["metrics"].get("roc_auc", 0))
+best_model = results[best_model_name]["model"]
+import joblib
+joblib.dump(best_model, "./experiments/data/models/best_model.joblib")
+print(f"Best model: {best_model_name}")'''),
+        ]
+
+    def generate_databricks_cells(self) -> List[nbformat.NotebookNode]:
+        catalog = self.config.feature_store.catalog
+        schema = self.config.feature_store.schema
+        target = self.get_target_column()
+        exp_name = self.config.mlflow.experiment_name
+        model_name = self.config.mlflow.model_name
+        return self.header_cells() + [
+            self.cb.section("Load Selected Features"),
+            self.cb.code(f'''df = spark.table("{catalog}.{schema}.gold_selected")'''),
+            self.cb.section("Prepare Features Vector"),
+            self.cb.code(f'''from pyspark.ml.feature import VectorAssembler
+
+target_col = "{target}"
+feature_cols = [c for c in df.columns if c not in {self.get_identifier_columns()} + [target_col]]
+
+assembler = VectorAssembler(inputCols=feature_cols, outputCol="features", handleInvalid="skip")
+df_ml = assembler.transform(df).select("features", target_col)
+train_df, test_df = df_ml.randomSplit([0.8, 0.2], seed=42)
+print(f"Train: {{train_df.count()}}, Test: {{test_df.count()}}")'''),
+            self.cb.section("Setup MLflow"),
+            self.cb.code(f'''import mlflow
+mlflow.set_experiment(f"/Users/{{spark.conf.get('spark.databricks.notebook.username', 'default')}}/{exp_name}")'''),
+            self.cb.section("Train Gradient Boosted Trees"),
+            self.cb.code(f'''from pyspark.ml.classification import GBTClassifier
+from pyspark.ml.evaluation import BinaryClassificationEvaluator
+
+with mlflow.start_run(run_name="gbt_baseline"):
+    gbt = GBTClassifier(featuresCol="features", labelCol="{target}", maxIter=100)
+    model = gbt.fit(train_df)
+
+    predictions = model.transform(test_df)
+    evaluator = BinaryClassificationEvaluator(labelCol="{target}", metricName="areaUnderROC")
+    auc = evaluator.evaluate(predictions)
+
+    mlflow.log_param("maxIter", 100)
+    mlflow.log_metric("auc_roc", auc)
+    mlflow.spark.log_model(model, "model")
+
+    run_id = mlflow.active_run().info.run_id
+    print(f"AUC: {{auc:.4f}}, Run ID: {{run_id}}")'''),
+            self.cb.section("Register Model"),
+            self.cb.code(f'''model_uri = f"runs:/{{run_id}}/model"
+mlflow.register_model(model_uri, "{catalog}.{schema}.{model_name}")
+print(f"Model registered: {catalog}.{schema}.{model_name}")'''),
+        ]
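
Stripped of the MLflow adapter and churnkit's ModelEvaluator, the local training loop this stage generates is a standard baseline bake-off. A runnable sketch on synthetic data (plain scikit-learn; the adapter logging and evaluator are replaced here by direct metric calls):

from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, roc_auc_score
from sklearn.model_selection import train_test_split

# Imbalanced synthetic churn-like data: ~80% negative class.
X, y = make_classification(n_samples=2000, n_features=20, weights=[0.8], random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)

models = {
    "logistic_regression": LogisticRegression(class_weight="balanced", max_iter=1000),
    "random_forest": RandomForestClassifier(class_weight="balanced", n_estimators=100, random_state=42),
    "gradient_boosting": GradientBoostingClassifier(n_estimators=100, random_state=42),
}

results = {}
for name, model in models.items():
    model.fit(X_train, y_train)
    y_prob = model.predict_proba(X_test)[:, 1]
    y_pred = (y_prob >= 0.5).astype(int)
    results[name] = {"roc_auc": roc_auc_score(y_test, y_prob), "f1": f1_score(y_test, y_pred)}
    print(f"{name}: AUC={results[name]['roc_auc']:.4f}, F1={results[name]['f1']:.4f}")

# Same selection rule as the generated "Save Best Model" cell.
best = max(results, key=lambda k: results[k]["roc_auc"])
print(f"Best model: {best}")
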
customer_retention/generators/notebook_generator/stages/s08_deployment.py
@@ -0,0 +1,81 @@
+from typing import List
+
+import nbformat
+
+from ..base import NotebookStage
+from .base_stage import StageGenerator
+
+
+class DeploymentStage(StageGenerator):
+    @property
+    def stage(self) -> NotebookStage:
+        return NotebookStage.DEPLOYMENT
+
+    @property
+    def title(self) -> str:
+        return "08 - Model Deployment"
+
+    @property
+    def description(self) -> str:
+        return "Register model to registry and promote to production."
+
+    def generate_local_cells(self) -> List[nbformat.NotebookNode]:
+        tracking_uri = self.config.mlflow.tracking_uri
+        model_name = self.config.mlflow.model_name
+        return self.header_cells() + [
+            self.cb.section("Imports"),
+            self.cb.from_imports_cell({
+                "customer_retention.stages.deployment": ["ModelRegistry", "ModelStage"],
+                "customer_retention.integrations.adapters": ["get_mlflow"],
+            }),
+            self.cb.section("Initialize Registry"),
+            self.cb.code(f'''mlflow_adapter = get_mlflow(tracking_uri="{tracking_uri}", force_local=True)
+registry = ModelRegistry(tracking_uri="{tracking_uri}")
+model_name = "{model_name}"'''),
+            self.cb.section("List Model Versions"),
+            self.cb.code('''versions = registry.list_versions(model_name)
+for v in versions:
+    print(f"Version {v.version}: Stage={v.current_stage}, Run={v.run_id}")'''),
+            self.cb.section("Validate for Promotion"),
+            self.cb.code('''latest_version = max(versions, key=lambda v: int(v.version)).version if versions else "1"
+validation = registry.validate_for_promotion(
+    model_name=model_name,
+    version=latest_version,
+    required_metrics={"roc_auc": 0.6},
+)
+print(f"Validation passed: {validation.is_valid}")
+if not validation.is_valid:
+    print(f"Errors: {validation.errors}")'''),
+            self.cb.section("Promote to Production"),
+            self.cb.code('''if validation.is_valid:
+    registry.transition_stage(model_name, latest_version, ModelStage.PRODUCTION)
+    print(f"Model {model_name} v{latest_version} promoted to Production")
+else:
+    print("Model not promoted due to validation failure")'''),
+        ]
+
+    def generate_databricks_cells(self) -> List[nbformat.NotebookNode]:
+        catalog = self.config.feature_store.catalog
+        schema = self.config.feature_store.schema
+        model_name = self.config.mlflow.model_name
+        return self.header_cells() + [
+            self.cb.section("Initialize MLflow Client"),
+            self.cb.code('''import mlflow
+from mlflow.tracking import MlflowClient
+
+client = MlflowClient()'''),
+            self.cb.section("Get Model Versions"),
+            self.cb.code(f'''model_full_name = "{catalog}.{schema}.{model_name}"
+versions = client.search_model_versions(f"name='{{model_full_name}}'")
+for v in versions:
+    print(f"Version {{v.version}}: Status={{v.status}}")'''),
+            self.cb.section("Get Latest Version"),
+            self.cb.code('''latest = max(versions, key=lambda v: int(v.version))
+print(f"Latest version: {latest.version}")'''),
+            self.cb.section("Set Production Alias"),
+            self.cb.code('''client.set_registered_model_alias(model_full_name, "production", latest.version)
+print(f"Model {model_full_name} v{latest.version} aliased as 'production'")'''),
+            self.cb.section("Verify Production Model"),
+            self.cb.code('''prod_version = client.get_model_version_by_alias(model_full_name, "production")
+print(f"Production model version: {prod_version.version}")'''),
+        ]
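
In the local variant the promotion gate lives in `registry.validate_for_promotion`; the Databricks cells above promote unconditionally. The same metric-threshold check can be sketched against stock MLflow with `MlflowClient` (the model name and threshold below are illustrative; `auc_roc` matches the metric key the training stage logs, and the alias API requires MLflow 2.x):

from mlflow.tracking import MlflowClient

def promote_if_good_enough(name: str, min_auc: float = 0.6) -> None:
    """Alias the newest version as 'production' only if its run logged an adequate AUC."""
    client = MlflowClient()
    versions = client.search_model_versions(f"name='{name}'")
    if not versions:
        raise ValueError(f"No versions registered for {name}")
    latest = max(versions, key=lambda v: int(v.version))

    # Look up the metric on the training run that produced this version.
    auc = client.get_run(latest.run_id).data.metrics.get("auc_roc", 0.0)
    if auc >= min_auc:
        client.set_registered_model_alias(name, "production", latest.version)
        print(f"{name} v{latest.version} promoted (AUC={auc:.4f})")
    else:
        print(f"{name} v{latest.version} rejected (AUC={auc:.4f} < {min_auc})")

promote_if_good_enough("main.churn.churn_model")  # hypothetical catalog.schema.model name
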
customer_retention/generators/notebook_generator/stages/s09_monitoring.py
@@ -0,0 +1,112 @@
+from typing import List
+
+import nbformat
+
+from ..base import NotebookStage
+from .base_stage import StageGenerator
+
+
+class MonitoringStage(StageGenerator):
+    @property
+    def stage(self) -> NotebookStage:
+        return NotebookStage.MONITORING
+
+    @property
+    def title(self) -> str:
+        return "09 - Model Monitoring"
+
+    @property
+    def description(self) -> str:
+        return "Track model performance, detect drift, and set up alerts."
+
+    def generate_local_cells(self) -> List[nbformat.NotebookNode]:
+        return self.header_cells() + [
+            self.cb.section("Imports"),
+            self.cb.from_imports_cell({
+                "customer_retention.stages.monitoring": ["PerformanceMonitor", "DriftDetector"],
+                "customer_retention.analysis.visualization": ["ChartBuilder"],
+                "pandas": ["pd"],
+                "joblib": ["joblib"],
+            }),
+            self.cb.section("Load Production Model and Test Data"),
+            self.cb.code('''model = joblib.load("./experiments/data/models/best_model.joblib")
+df_test = pd.read_parquet("./experiments/data/gold/customers_selected.parquet").sample(n=1000, random_state=42)'''),
+            self.cb.section("Generate Predictions"),
+            self.cb.code(f'''target_col = "{self.get_target_column()}"
+id_cols = {self.get_identifier_columns()}
+feature_cols = [c for c in df_test.columns if c not in id_cols + [target_col]]
+
+X_test = df_test[feature_cols]
+y_test = df_test[target_col]
+y_prob = model.predict_proba(X_test)[:, 1]
+y_pred = (y_prob >= 0.5).astype(int)'''),
+            self.cb.section("Calculate Performance Metrics"),
+            self.cb.code('''from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
+
+current_metrics = {
+    "roc_auc": roc_auc_score(y_test, y_prob),
+    "precision": precision_score(y_test, y_pred),
+    "recall": recall_score(y_test, y_pred),
+    "f1": f1_score(y_test, y_pred),
+}
+for name, value in current_metrics.items():
+    print(f"{name}: {value:.4f}")'''),
+            self.cb.section("Compare to Baseline"),
+            self.cb.code('''baseline_metrics = {"roc_auc": 0.75, "precision": 0.60, "recall": 0.70, "f1": 0.65}
+monitor = PerformanceMonitor(baseline_metrics)
+result = monitor.evaluate(current_metrics)
+print(f"Status: {result.status}")
+for metric, change in result.changes.items():
+    print(f"  {metric}: {change:+.2%}")'''),
+            self.cb.section("Detect Feature Drift"),
+            self.cb.code('''df_reference = pd.read_parquet("./experiments/data/gold/customers_features.parquet").sample(n=1000, random_state=0)
+drift_detector = DriftDetector()
+for col in feature_cols[:5]:
+    result = drift_detector.detect(df_reference[col], df_test[col])
+    if result.has_drift:
+        print(f"DRIFT detected in {col}: PSI={result.psi:.4f}")'''),
+        ]
+
+    def generate_databricks_cells(self) -> List[nbformat.NotebookNode]:
+        catalog = self.config.feature_store.catalog
+        schema = self.config.feature_store.schema
+        model_name = self.config.mlflow.model_name
+        target = self.get_target_column()
+        return self.header_cells() + [
+            self.cb.section("Load Model and Data"),
+            self.cb.code(f'''import mlflow
+
+model = mlflow.pyfunc.load_model(f"models:/{catalog}.{schema}.{model_name}@production")
+df_test = spark.table("{catalog}.{schema}.gold_selected").sample(0.1)'''),
+            self.cb.section("Generate Predictions"),
+            self.cb.code(f'''from pyspark.sql.functions import pandas_udf
+import pandas as pd
+
+feature_cols = [c for c in df_test.columns if c not in {self.get_identifier_columns()} + ["{target}"]]
+
+@pandas_udf("double")
+def predict_udf(*cols):
+    df = pd.concat(cols, axis=1)
+    df.columns = feature_cols
+    return pd.Series(model.predict(df))
+
+df_predictions = df_test.withColumn("prediction", predict_udf(*[df_test[c] for c in feature_cols]))
+display(df_predictions.limit(10))'''),
+            self.cb.section("Calculate Metrics"),
+            self.cb.code(f'''from pyspark.ml.evaluation import BinaryClassificationEvaluator
+
+evaluator = BinaryClassificationEvaluator(labelCol="{target}", rawPredictionCol="prediction")
+auc = evaluator.evaluate(df_predictions)
+print(f"Current AUC: {{auc:.4f}}")'''),
+            self.cb.section("Check for Drift"),
+            self.cb.code(f'''df_reference = spark.table("{catalog}.{schema}.gold_customers").sample(0.1)
+
+for col in feature_cols[:5]:
+    ref_stats = df_reference.select(col).describe().collect()
+    cur_stats = df_test.select(col).describe().collect()
+    ref_mean = float(ref_stats[1][1]) if ref_stats[1][1] else 0
+    cur_mean = float(cur_stats[1][1]) if cur_stats[1][1] else 0
+    drift_pct = abs(ref_mean - cur_mean) / (ref_mean + 1e-10) * 100
+    if drift_pct > 10:
+        print(f"DRIFT in {{col}}: {{drift_pct:.1f}}% mean shift")'''),
+        ]
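
The local drift cell reports `result.psi`, but the `DriftDetector` implementation is not part of this hunk. The Population Stability Index it reports is conventionally computed as below; a self-contained sketch, where the 10-bin default and the 0.2 alert threshold are the usual rules of thumb, not values taken from the package:

import numpy as np

def psi(reference: np.ndarray, current: np.ndarray, bins: int = 10) -> float:
    """Population Stability Index between two numeric samples.

    Bin edges come from reference-distribution quantiles; both samples are
    clipped into the reference range so every value lands in a bin, and a
    small epsilon keeps empty bins from producing log(0).
    """
    edges = np.quantile(reference, np.linspace(0, 1, bins + 1))
    ref_frac = np.histogram(np.clip(reference, edges[0], edges[-1]), bins=edges)[0] / len(reference)
    cur_frac = np.histogram(np.clip(current, edges[0], edges[-1]), bins=edges)[0] / len(current)
    eps = 1e-6
    ref_frac, cur_frac = ref_frac + eps, cur_frac + eps
    return float(np.sum((cur_frac - ref_frac) * np.log(cur_frac / ref_frac)))

rng = np.random.default_rng(0)
ref = rng.normal(0, 1, 5000)
cur = rng.normal(0.5, 1, 5000)  # shifted mean -> drift
score = psi(ref, cur)
print(f"PSI={score:.4f} -> {'DRIFT' if score > 0.2 else 'stable'}")
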