churnkit-0.75.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +647 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +1165 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +961 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +1690 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +679 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +3305 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +1463 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +1430 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +854 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +1639 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +1890 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +1457 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +1624 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +780 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +979 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +572 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +1179 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +1418 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +151 -0
- churnkit-0.75.0a1.dist-info/METADATA +229 -0
- churnkit-0.75.0a1.dist-info/RECORD +302 -0
- churnkit-0.75.0a1.dist-info/WHEEL +4 -0
- churnkit-0.75.0a1.dist-info/entry_points.txt +2 -0
- churnkit-0.75.0a1.dist-info/licenses/LICENSE +202 -0
- customer_retention/__init__.py +37 -0
- customer_retention/analysis/__init__.py +0 -0
- customer_retention/analysis/auto_explorer/__init__.py +62 -0
- customer_retention/analysis/auto_explorer/exploration_manager.py +470 -0
- customer_retention/analysis/auto_explorer/explorer.py +258 -0
- customer_retention/analysis/auto_explorer/findings.py +291 -0
- customer_retention/analysis/auto_explorer/layered_recommendations.py +485 -0
- customer_retention/analysis/auto_explorer/recommendation_builder.py +148 -0
- customer_retention/analysis/auto_explorer/recommendations.py +418 -0
- customer_retention/analysis/business/__init__.py +26 -0
- customer_retention/analysis/business/ab_test_designer.py +144 -0
- customer_retention/analysis/business/fairness_analyzer.py +166 -0
- customer_retention/analysis/business/intervention_matcher.py +121 -0
- customer_retention/analysis/business/report_generator.py +222 -0
- customer_retention/analysis/business/risk_profile.py +199 -0
- customer_retention/analysis/business/roi_analyzer.py +139 -0
- customer_retention/analysis/diagnostics/__init__.py +20 -0
- customer_retention/analysis/diagnostics/calibration_analyzer.py +133 -0
- customer_retention/analysis/diagnostics/cv_analyzer.py +144 -0
- customer_retention/analysis/diagnostics/error_analyzer.py +107 -0
- customer_retention/analysis/diagnostics/leakage_detector.py +394 -0
- customer_retention/analysis/diagnostics/noise_tester.py +140 -0
- customer_retention/analysis/diagnostics/overfitting_analyzer.py +190 -0
- customer_retention/analysis/diagnostics/segment_analyzer.py +122 -0
- customer_retention/analysis/discovery/__init__.py +8 -0
- customer_retention/analysis/discovery/config_generator.py +49 -0
- customer_retention/analysis/discovery/discovery_flow.py +19 -0
- customer_retention/analysis/discovery/type_inferencer.py +147 -0
- customer_retention/analysis/interpretability/__init__.py +13 -0
- customer_retention/analysis/interpretability/cohort_analyzer.py +185 -0
- customer_retention/analysis/interpretability/counterfactual.py +175 -0
- customer_retention/analysis/interpretability/individual_explainer.py +141 -0
- customer_retention/analysis/interpretability/pdp_generator.py +103 -0
- customer_retention/analysis/interpretability/shap_explainer.py +106 -0
- customer_retention/analysis/jupyter_save_hook.py +28 -0
- customer_retention/analysis/notebook_html_exporter.py +136 -0
- customer_retention/analysis/notebook_progress.py +60 -0
- customer_retention/analysis/plotly_preprocessor.py +154 -0
- customer_retention/analysis/recommendations/__init__.py +54 -0
- customer_retention/analysis/recommendations/base.py +158 -0
- customer_retention/analysis/recommendations/cleaning/__init__.py +11 -0
- customer_retention/analysis/recommendations/cleaning/consistency.py +107 -0
- customer_retention/analysis/recommendations/cleaning/deduplicate.py +94 -0
- customer_retention/analysis/recommendations/cleaning/impute.py +67 -0
- customer_retention/analysis/recommendations/cleaning/outlier.py +71 -0
- customer_retention/analysis/recommendations/datetime/__init__.py +3 -0
- customer_retention/analysis/recommendations/datetime/extract.py +149 -0
- customer_retention/analysis/recommendations/encoding/__init__.py +3 -0
- customer_retention/analysis/recommendations/encoding/categorical.py +114 -0
- customer_retention/analysis/recommendations/pipeline.py +74 -0
- customer_retention/analysis/recommendations/registry.py +76 -0
- customer_retention/analysis/recommendations/selection/__init__.py +3 -0
- customer_retention/analysis/recommendations/selection/drop_column.py +56 -0
- customer_retention/analysis/recommendations/transform/__init__.py +4 -0
- customer_retention/analysis/recommendations/transform/power.py +94 -0
- customer_retention/analysis/recommendations/transform/scale.py +112 -0
- customer_retention/analysis/visualization/__init__.py +15 -0
- customer_retention/analysis/visualization/chart_builder.py +2619 -0
- customer_retention/analysis/visualization/console.py +122 -0
- customer_retention/analysis/visualization/display.py +171 -0
- customer_retention/analysis/visualization/number_formatter.py +36 -0
- customer_retention/artifacts/__init__.py +3 -0
- customer_retention/artifacts/fit_artifact_registry.py +146 -0
- customer_retention/cli.py +93 -0
- customer_retention/core/__init__.py +0 -0
- customer_retention/core/compat/__init__.py +193 -0
- customer_retention/core/compat/detection.py +99 -0
- customer_retention/core/compat/ops.py +48 -0
- customer_retention/core/compat/pandas_backend.py +57 -0
- customer_retention/core/compat/spark_backend.py +75 -0
- customer_retention/core/components/__init__.py +11 -0
- customer_retention/core/components/base.py +79 -0
- customer_retention/core/components/components/__init__.py +13 -0
- customer_retention/core/components/components/deployer.py +26 -0
- customer_retention/core/components/components/explainer.py +26 -0
- customer_retention/core/components/components/feature_eng.py +33 -0
- customer_retention/core/components/components/ingester.py +34 -0
- customer_retention/core/components/components/profiler.py +34 -0
- customer_retention/core/components/components/trainer.py +38 -0
- customer_retention/core/components/components/transformer.py +36 -0
- customer_retention/core/components/components/validator.py +37 -0
- customer_retention/core/components/enums.py +33 -0
- customer_retention/core/components/orchestrator.py +94 -0
- customer_retention/core/components/registry.py +59 -0
- customer_retention/core/config/__init__.py +39 -0
- customer_retention/core/config/column_config.py +95 -0
- customer_retention/core/config/experiments.py +71 -0
- customer_retention/core/config/pipeline_config.py +117 -0
- customer_retention/core/config/source_config.py +83 -0
- customer_retention/core/utils/__init__.py +28 -0
- customer_retention/core/utils/leakage.py +85 -0
- customer_retention/core/utils/severity.py +53 -0
- customer_retention/core/utils/statistics.py +90 -0
- customer_retention/generators/__init__.py +0 -0
- customer_retention/generators/notebook_generator/__init__.py +167 -0
- customer_retention/generators/notebook_generator/base.py +55 -0
- customer_retention/generators/notebook_generator/cell_builder.py +49 -0
- customer_retention/generators/notebook_generator/config.py +47 -0
- customer_retention/generators/notebook_generator/databricks_generator.py +48 -0
- customer_retention/generators/notebook_generator/local_generator.py +48 -0
- customer_retention/generators/notebook_generator/project_init.py +174 -0
- customer_retention/generators/notebook_generator/runner.py +150 -0
- customer_retention/generators/notebook_generator/script_generator.py +110 -0
- customer_retention/generators/notebook_generator/stages/__init__.py +19 -0
- customer_retention/generators/notebook_generator/stages/base_stage.py +86 -0
- customer_retention/generators/notebook_generator/stages/s01_ingestion.py +100 -0
- customer_retention/generators/notebook_generator/stages/s02_profiling.py +95 -0
- customer_retention/generators/notebook_generator/stages/s03_cleaning.py +180 -0
- customer_retention/generators/notebook_generator/stages/s04_transformation.py +165 -0
- customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py +115 -0
- customer_retention/generators/notebook_generator/stages/s06_feature_selection.py +97 -0
- customer_retention/generators/notebook_generator/stages/s07_model_training.py +176 -0
- customer_retention/generators/notebook_generator/stages/s08_deployment.py +81 -0
- customer_retention/generators/notebook_generator/stages/s09_monitoring.py +112 -0
- customer_retention/generators/notebook_generator/stages/s10_batch_inference.py +642 -0
- customer_retention/generators/notebook_generator/stages/s11_feature_store.py +348 -0
- customer_retention/generators/orchestration/__init__.py +23 -0
- customer_retention/generators/orchestration/code_generator.py +196 -0
- customer_retention/generators/orchestration/context.py +147 -0
- customer_retention/generators/orchestration/data_materializer.py +188 -0
- customer_retention/generators/orchestration/databricks_exporter.py +411 -0
- customer_retention/generators/orchestration/doc_generator.py +311 -0
- customer_retention/generators/pipeline_generator/__init__.py +26 -0
- customer_retention/generators/pipeline_generator/findings_parser.py +727 -0
- customer_retention/generators/pipeline_generator/generator.py +142 -0
- customer_retention/generators/pipeline_generator/models.py +166 -0
- customer_retention/generators/pipeline_generator/renderer.py +2125 -0
- customer_retention/generators/spec_generator/__init__.py +37 -0
- customer_retention/generators/spec_generator/databricks_generator.py +433 -0
- customer_retention/generators/spec_generator/generic_generator.py +373 -0
- customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +685 -0
- customer_retention/generators/spec_generator/pipeline_spec.py +298 -0
- customer_retention/integrations/__init__.py +0 -0
- customer_retention/integrations/adapters/__init__.py +13 -0
- customer_retention/integrations/adapters/base.py +10 -0
- customer_retention/integrations/adapters/factory.py +25 -0
- customer_retention/integrations/adapters/feature_store/__init__.py +6 -0
- customer_retention/integrations/adapters/feature_store/base.py +57 -0
- customer_retention/integrations/adapters/feature_store/databricks.py +94 -0
- customer_retention/integrations/adapters/feature_store/feast_adapter.py +97 -0
- customer_retention/integrations/adapters/feature_store/local.py +75 -0
- customer_retention/integrations/adapters/mlflow/__init__.py +6 -0
- customer_retention/integrations/adapters/mlflow/base.py +32 -0
- customer_retention/integrations/adapters/mlflow/databricks.py +54 -0
- customer_retention/integrations/adapters/mlflow/experiment_tracker.py +161 -0
- customer_retention/integrations/adapters/mlflow/local.py +50 -0
- customer_retention/integrations/adapters/storage/__init__.py +5 -0
- customer_retention/integrations/adapters/storage/base.py +33 -0
- customer_retention/integrations/adapters/storage/databricks.py +76 -0
- customer_retention/integrations/adapters/storage/local.py +59 -0
- customer_retention/integrations/feature_store/__init__.py +47 -0
- customer_retention/integrations/feature_store/definitions.py +215 -0
- customer_retention/integrations/feature_store/manager.py +744 -0
- customer_retention/integrations/feature_store/registry.py +412 -0
- customer_retention/integrations/iteration/__init__.py +28 -0
- customer_retention/integrations/iteration/context.py +212 -0
- customer_retention/integrations/iteration/feedback_collector.py +184 -0
- customer_retention/integrations/iteration/orchestrator.py +168 -0
- customer_retention/integrations/iteration/recommendation_tracker.py +341 -0
- customer_retention/integrations/iteration/signals.py +212 -0
- customer_retention/integrations/llm_context/__init__.py +4 -0
- customer_retention/integrations/llm_context/context_builder.py +201 -0
- customer_retention/integrations/llm_context/prompts.py +100 -0
- customer_retention/integrations/streaming/__init__.py +103 -0
- customer_retention/integrations/streaming/batch_integration.py +149 -0
- customer_retention/integrations/streaming/early_warning_model.py +227 -0
- customer_retention/integrations/streaming/event_schema.py +214 -0
- customer_retention/integrations/streaming/online_store_writer.py +249 -0
- customer_retention/integrations/streaming/realtime_scorer.py +261 -0
- customer_retention/integrations/streaming/trigger_engine.py +293 -0
- customer_retention/integrations/streaming/window_aggregator.py +393 -0
- customer_retention/stages/__init__.py +0 -0
- customer_retention/stages/cleaning/__init__.py +9 -0
- customer_retention/stages/cleaning/base.py +28 -0
- customer_retention/stages/cleaning/missing_handler.py +160 -0
- customer_retention/stages/cleaning/outlier_handler.py +204 -0
- customer_retention/stages/deployment/__init__.py +28 -0
- customer_retention/stages/deployment/batch_scorer.py +106 -0
- customer_retention/stages/deployment/champion_challenger.py +299 -0
- customer_retention/stages/deployment/model_registry.py +182 -0
- customer_retention/stages/deployment/retraining_trigger.py +245 -0
- customer_retention/stages/features/__init__.py +73 -0
- customer_retention/stages/features/behavioral_features.py +266 -0
- customer_retention/stages/features/customer_segmentation.py +505 -0
- customer_retention/stages/features/feature_definitions.py +265 -0
- customer_retention/stages/features/feature_engineer.py +551 -0
- customer_retention/stages/features/feature_manifest.py +340 -0
- customer_retention/stages/features/feature_selector.py +239 -0
- customer_retention/stages/features/interaction_features.py +160 -0
- customer_retention/stages/features/temporal_features.py +243 -0
- customer_retention/stages/ingestion/__init__.py +9 -0
- customer_retention/stages/ingestion/load_result.py +32 -0
- customer_retention/stages/ingestion/loaders.py +195 -0
- customer_retention/stages/ingestion/source_registry.py +130 -0
- customer_retention/stages/modeling/__init__.py +31 -0
- customer_retention/stages/modeling/baseline_trainer.py +139 -0
- customer_retention/stages/modeling/cross_validator.py +125 -0
- customer_retention/stages/modeling/data_splitter.py +205 -0
- customer_retention/stages/modeling/feature_scaler.py +99 -0
- customer_retention/stages/modeling/hyperparameter_tuner.py +107 -0
- customer_retention/stages/modeling/imbalance_handler.py +282 -0
- customer_retention/stages/modeling/mlflow_logger.py +95 -0
- customer_retention/stages/modeling/model_comparator.py +149 -0
- customer_retention/stages/modeling/model_evaluator.py +138 -0
- customer_retention/stages/modeling/threshold_optimizer.py +131 -0
- customer_retention/stages/monitoring/__init__.py +37 -0
- customer_retention/stages/monitoring/alert_manager.py +328 -0
- customer_retention/stages/monitoring/drift_detector.py +201 -0
- customer_retention/stages/monitoring/performance_monitor.py +242 -0
- customer_retention/stages/preprocessing/__init__.py +5 -0
- customer_retention/stages/preprocessing/transformer_manager.py +284 -0
- customer_retention/stages/profiling/__init__.py +256 -0
- customer_retention/stages/profiling/categorical_distribution.py +269 -0
- customer_retention/stages/profiling/categorical_target_analyzer.py +274 -0
- customer_retention/stages/profiling/column_profiler.py +527 -0
- customer_retention/stages/profiling/distribution_analysis.py +483 -0
- customer_retention/stages/profiling/drift_detector.py +310 -0
- customer_retention/stages/profiling/feature_capacity.py +507 -0
- customer_retention/stages/profiling/pattern_analysis_config.py +513 -0
- customer_retention/stages/profiling/profile_result.py +212 -0
- customer_retention/stages/profiling/quality_checks.py +1632 -0
- customer_retention/stages/profiling/relationship_detector.py +256 -0
- customer_retention/stages/profiling/relationship_recommender.py +454 -0
- customer_retention/stages/profiling/report_generator.py +520 -0
- customer_retention/stages/profiling/scd_analyzer.py +151 -0
- customer_retention/stages/profiling/segment_analyzer.py +632 -0
- customer_retention/stages/profiling/segment_aware_outlier.py +265 -0
- customer_retention/stages/profiling/target_level_analyzer.py +217 -0
- customer_retention/stages/profiling/temporal_analyzer.py +388 -0
- customer_retention/stages/profiling/temporal_coverage.py +488 -0
- customer_retention/stages/profiling/temporal_feature_analyzer.py +692 -0
- customer_retention/stages/profiling/temporal_feature_engineer.py +703 -0
- customer_retention/stages/profiling/temporal_pattern_analyzer.py +636 -0
- customer_retention/stages/profiling/temporal_quality_checks.py +278 -0
- customer_retention/stages/profiling/temporal_target_analyzer.py +241 -0
- customer_retention/stages/profiling/text_embedder.py +87 -0
- customer_retention/stages/profiling/text_processor.py +115 -0
- customer_retention/stages/profiling/text_reducer.py +60 -0
- customer_retention/stages/profiling/time_series_profiler.py +303 -0
- customer_retention/stages/profiling/time_window_aggregator.py +376 -0
- customer_retention/stages/profiling/type_detector.py +382 -0
- customer_retention/stages/profiling/window_recommendation.py +288 -0
- customer_retention/stages/temporal/__init__.py +166 -0
- customer_retention/stages/temporal/access_guard.py +180 -0
- customer_retention/stages/temporal/cutoff_analyzer.py +235 -0
- customer_retention/stages/temporal/data_preparer.py +178 -0
- customer_retention/stages/temporal/point_in_time_join.py +134 -0
- customer_retention/stages/temporal/point_in_time_registry.py +148 -0
- customer_retention/stages/temporal/scenario_detector.py +163 -0
- customer_retention/stages/temporal/snapshot_manager.py +259 -0
- customer_retention/stages/temporal/synthetic_coordinator.py +66 -0
- customer_retention/stages/temporal/timestamp_discovery.py +531 -0
- customer_retention/stages/temporal/timestamp_manager.py +255 -0
- customer_retention/stages/transformation/__init__.py +13 -0
- customer_retention/stages/transformation/binary_handler.py +85 -0
- customer_retention/stages/transformation/categorical_encoder.py +245 -0
- customer_retention/stages/transformation/datetime_transformer.py +97 -0
- customer_retention/stages/transformation/numeric_transformer.py +181 -0
- customer_retention/stages/transformation/pipeline.py +257 -0
- customer_retention/stages/validation/__init__.py +60 -0
- customer_retention/stages/validation/adversarial_scoring_validator.py +205 -0
- customer_retention/stages/validation/business_sense_gate.py +173 -0
- customer_retention/stages/validation/data_quality_gate.py +235 -0
- customer_retention/stages/validation/data_validators.py +511 -0
- customer_retention/stages/validation/feature_quality_gate.py +183 -0
- customer_retention/stages/validation/gates.py +117 -0
- customer_retention/stages/validation/leakage_gate.py +352 -0
- customer_retention/stages/validation/model_validity_gate.py +213 -0
- customer_retention/stages/validation/pipeline_validation_runner.py +264 -0
- customer_retention/stages/validation/quality_scorer.py +544 -0
- customer_retention/stages/validation/rule_generator.py +57 -0
- customer_retention/stages/validation/scoring_pipeline_validator.py +446 -0
- customer_retention/stages/validation/timeseries_detector.py +769 -0
- customer_retention/transforms/__init__.py +47 -0
- customer_retention/transforms/artifact_store.py +50 -0
- customer_retention/transforms/executor.py +157 -0
- customer_retention/transforms/fitted.py +92 -0
- customer_retention/transforms/ops.py +148 -0
customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py
@@ -0,0 +1,115 @@
from typing import List

import nbformat

from ..base import NotebookStage
from .base_stage import StageGenerator


class FeatureEngineeringStage(StageGenerator):
    @property
    def stage(self) -> NotebookStage:
        return NotebookStage.FEATURE_ENGINEERING

    @property
    def title(self) -> str:
        return "05 - Feature Engineering"

    @property
    def description(self) -> str:
        return "Create derived features, interactions, and aggregations."

    def generate_local_cells(self) -> List[nbformat.NotebookNode]:
        return self.header_cells() + [
            self.cb.section("Imports"),
            self.cb.from_imports_cell({
                "customer_retention.stages.features": ["FeatureEngineer", "FeatureEngineerConfig"],
                "customer_retention.stages.features.temporal_features": ["TemporalFeatureGenerator", "ReferenceDateSource"],
                "customer_retention.stages.temporal": ["PointInTimeJoiner", "SnapshotManager"],
                "pathlib": ["Path"],
                "pandas": ["pd"],
                "numpy": ["np"],
            }),
            self.cb.section("Load Latest Training Snapshot"),
            self.cb.code('''snapshot_manager = SnapshotManager(Path("./experiments/data"))
latest_snapshot = snapshot_manager.get_latest_snapshot()
if latest_snapshot:
    df, metadata = snapshot_manager.load_snapshot(latest_snapshot)
    print(f"Loaded snapshot: {latest_snapshot}")
    print(f"Rows: {len(df)}, Features: {len(df.columns)}")
else:
    from customer_retention.integrations.adapters.factory import get_delta
    storage = get_delta(force_local=True)
    df = storage.read("./experiments/data/silver/customers_transformed")
    print(f"No snapshot found, loaded transformed data: {df.shape}")'''),
            self.cb.section("Point-in-Time Feature Engineering"),
            self.cb.markdown('''**Important**: All temporal features are calculated relative to `feature_timestamp` to prevent data leakage.'''),
            self.cb.code('''if "feature_timestamp" in df.columns:
    temporal_gen = TemporalFeatureGenerator(
        reference_date_source=ReferenceDateSource.FEATURE_TIMESTAMP,
        created_column="signup_date" if "signup_date" in df.columns else None,
        last_order_column="last_activity" if "last_activity" in df.columns else None,
    )
    df = temporal_gen.fit_transform(df)
    print(f"Created temporal features: {temporal_gen.generated_features}")
else:
    print("Warning: No feature_timestamp column found. Using current date (may cause leakage).")
    if "signup_date" in df.columns:
        df["tenure_days"] = (pd.Timestamp.now() - pd.to_datetime(df["signup_date"])).dt.days'''),
            self.cb.section("Validate Point-in-Time Correctness"),
            self.cb.code('''if "feature_timestamp" in df.columns:
    pit_report = PointInTimeJoiner.validate_temporal_integrity(df)
    if pit_report["valid"]:
        print("Point-in-time validation PASSED")
    else:
        print("Point-in-time validation FAILED:")
        for issue in pit_report["issues"]:
            print(f"  - {issue['type']}: {issue['message']}")'''),
            self.cb.section("Create Interaction Features"),
            self.cb.code('''numeric_cols = [c for c in df.select_dtypes(include=[np.number]).columns
                if c not in ["target", "entity_id"]]
if len(numeric_cols) >= 2:
    for i, col1 in enumerate(numeric_cols[:3]):
        for col2 in numeric_cols[i+1:4]:
            df[f"{col1}_x_{col2}"] = df[col1] * df[col2]
    print(f"Created interaction features")'''),
            self.cb.section("Create Ratio Features"),
            self.cb.code('''if "total_spend" in df.columns and "num_transactions" in df.columns:
    df["avg_transaction_value"] = df["total_spend"] / (df["num_transactions"] + 1)
    print("Created avg_transaction_value feature")'''),
            self.cb.section("Save to Gold Layer"),
            self.cb.code('''from customer_retention.integrations.adapters.factory import get_delta
storage = get_delta(force_local=True)
storage.write(df, "./experiments/data/gold/customers_features")
print(f"Gold layer saved: {df.shape}")'''),
        ]

    def generate_databricks_cells(self) -> List[nbformat.NotebookNode]:
        catalog = self.config.feature_store.catalog
        schema = self.config.feature_store.schema
        return self.header_cells() + [
            self.cb.section("Load Transformed Data"),
            self.cb.code(f'''df = spark.table("{catalog}.{schema}.silver_transformed")'''),
            self.cb.section("Create Derived Features"),
            self.cb.code('''from pyspark.sql.functions import datediff, current_date, col

if "signup_date" in df.columns:
    df = df.withColumn("tenure_days", datediff(current_date(), col("signup_date")))
    print("Created tenure_days feature")

if "last_activity" in df.columns:
    df = df.withColumn("recency_days", datediff(current_date(), col("last_activity")))
    print("Created recency_days feature")'''),
            self.cb.section("Create Interaction Features"),
            self.cb.code('''numeric_cols = [f.name for f in df.schema.fields if str(f.dataType) in ["IntegerType()", "DoubleType()", "FloatType()"]]
if len(numeric_cols) >= 2:
    df = df.withColumn(f"{numeric_cols[0]}_x_{numeric_cols[1]}", col(numeric_cols[0]) * col(numeric_cols[1]))
    print("Created interaction features")'''),
            self.cb.section("Create Ratio Features"),
            self.cb.code('''if "total_spend" in df.columns and "num_transactions" in df.columns:
    df = df.withColumn("avg_transaction_value", col("total_spend") / (col("num_transactions") + 1))
    print("Created avg_transaction_value feature")'''),
            self.cb.section("Save to Gold Table"),
            self.cb.code(f'''df.write.format("delta").mode("overwrite").saveAsTable("{catalog}.{schema}.gold_customers")
print("Gold table created")'''),
        ]
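Note that the local cells key every temporal feature off each row's feature_timestamp, while the Databricks cells fall back to current_date(). A minimal standalone sketch of the point-in-time rule the local path enforces, assuming a pandas frame with illustrative column names (not taken from the package):

# Point-in-time sketch: features are measured against each row's cutoff,
# never against the run date. Column names here are illustrative.
import pandas as pd

df = pd.DataFrame({
    "entity_id": [1, 2],
    "feature_timestamp": pd.to_datetime(["2024-01-01", "2024-06-01"]),
    "signup_date": pd.to_datetime(["2023-01-01", "2024-01-01"]),
    "last_activity": pd.to_datetime(["2023-12-20", "2024-05-28"]),
})

# Tenure and recency as of the feature cutoff, not as of "now":
df["tenure_days"] = (df["feature_timestamp"] - df["signup_date"]).dt.days
df["recency_days"] = (df["feature_timestamp"] - df["last_activity"]).dt.days
print(df[["entity_id", "tenure_days", "recency_days"]])

Computing against the cutoff rather than the run date is what keeps the generated features reproducible and leak-free when labels are only observed after the cutoff.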
customer_retention/generators/notebook_generator/stages/s06_feature_selection.py
@@ -0,0 +1,97 @@
from typing import List

import nbformat

from ..base import NotebookStage
from .base_stage import StageGenerator


class FeatureSelectionStage(StageGenerator):
    @property
    def stage(self) -> NotebookStage:
        return NotebookStage.FEATURE_SELECTION

    @property
    def title(self) -> str:
        return "06 - Feature Selection"

    @property
    def description(self) -> str:
        return "Select best features using variance, correlation, and importance filters."

    def generate_local_cells(self) -> List[nbformat.NotebookNode]:
        target = self.get_target_column()
        var_thresh = self.config.variance_threshold
        corr_thresh = self.config.correlation_threshold
        return self.header_cells() + [
            self.cb.section("Imports"),
            self.cb.from_imports_cell({
                "customer_retention.stages.features": ["FeatureSelector"],
                "pandas": ["pd"],
                "numpy": ["np"],
            }),
            self.cb.section("Load Gold Data"),
            self.cb.code('''df = pd.read_parquet("./experiments/data/gold/customers_features.parquet")
print(f"Input shape: {df.shape}")'''),
            self.cb.section("Identify Feature Columns"),
            self.cb.code(f'''target_col = "{target}"
id_cols = {self.get_identifier_columns()}
feature_cols = [c for c in df.columns if c not in id_cols + [target_col]]
X = df[feature_cols]
y = df[target_col] if target_col in df.columns else None
print(f"Feature columns: {{len(feature_cols)}}")'''),
            self.cb.section("Variance Filter"),
            self.cb.code(f'''variance_threshold = {var_thresh}
variances = X.var()
low_variance = variances[variances < variance_threshold].index.tolist()
print(f"Low variance features ({{len(low_variance)}}): {{low_variance[:5]}}")
X = X.drop(columns=low_variance)'''),
            self.cb.section("Correlation Filter"),
            self.cb.code(f'''correlation_threshold = {corr_thresh}
corr_matrix = X.corr().abs()
upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
high_corr = [c for c in upper.columns if any(upper[c] > correlation_threshold)]
print(f"High correlation features ({{len(high_corr)}}): {{high_corr[:5]}}")
X = X.drop(columns=high_corr)'''),
            self.cb.section("Save Selected Features"),
            self.cb.code('''selected_df = df[[*id_cols, *X.columns, target_col]].dropna(subset=[target_col])
selected_df.to_parquet("./experiments/data/gold/customers_selected.parquet", index=False)
print(f"Selected {len(X.columns)} features, saved {len(selected_df)} rows")'''),
        ]

    def generate_databricks_cells(self) -> List[nbformat.NotebookNode]:
        catalog = self.config.feature_store.catalog
        schema = self.config.feature_store.schema
        target = self.get_target_column()
        return self.header_cells() + [
            self.cb.section("Load Gold Data"),
            self.cb.code(f'''df = spark.table("{catalog}.{schema}.gold_customers")'''),
            self.cb.section("Compute Feature Correlations"),
            self.cb.code(f'''from pyspark.ml.stat import Correlation
from pyspark.ml.feature import VectorAssembler

target_col = "{target}"
numeric_cols = [f.name for f in df.schema.fields if str(f.dataType) in ["IntegerType()", "DoubleType()", "FloatType()"] and f.name != target_col]

assembler = VectorAssembler(inputCols=numeric_cols, outputCol="features", handleInvalid="skip")
df_vec = assembler.transform(df)
corr_matrix = Correlation.corr(df_vec, "features").head()[0].toArray()
print(f"Correlation matrix shape: {{corr_matrix.shape}}")'''),
            self.cb.section("Remove Highly Correlated Features"),
            self.cb.code(f'''import numpy as np

correlation_threshold = {self.config.correlation_threshold}
to_drop = set()
for i in range(len(corr_matrix)):
    for j in range(i+1, len(corr_matrix)):
        if abs(corr_matrix[i,j]) > correlation_threshold:
            to_drop.add(numeric_cols[j])

selected_cols = [c for c in numeric_cols if c not in to_drop]
print(f"Dropped {{len(to_drop)}} highly correlated features, keeping {{len(selected_cols)}}")'''),
            self.cb.section("Save Selected Features"),
            self.cb.code(f'''final_cols = {self.get_identifier_columns()} + selected_cols + [target_col]
df_selected = df.select(final_cols)
df_selected.write.format("delta").mode("overwrite").saveAsTable("{catalog}.{schema}.gold_selected")
print("Selected features saved")'''),
        ]
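The two filters the generated notebook applies can be read in isolation. A self-contained sketch on toy data, with illustrative thresholds standing in for self.config.variance_threshold and self.config.correlation_threshold:

# Variance filter + upper-triangle correlation filter on synthetic columns.
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
X = pd.DataFrame({
    "a": rng.normal(size=200),
    "b": rng.normal(size=200),
    "constant": np.zeros(200),   # should fail the variance filter
})
X["a_copy"] = X["a"] * 1.001     # should fail the correlation filter

# 1. Variance filter: drop near-constant columns.
low_variance = X.var()[X.var() < 0.01].index.tolist()
X = X.drop(columns=low_variance)

# 2. Correlation filter: scan the upper triangle and drop the second
#    member of each highly correlated pair (same rule as the diff).
corr = X.corr().abs()
upper = corr.where(np.triu(np.ones(corr.shape), k=1).astype(bool))
high_corr = [c for c in upper.columns if (upper[c] > 0.95).any()]
X = X.drop(columns=high_corr)

print(f"Dropped {low_variance + high_corr}, kept {list(X.columns)}")

Both the pandas upper-triangle rule and the Spark to_drop.add(numeric_cols[j]) loop keep the first member of each correlated pair, so the surviving feature set depends on column order.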
customer_retention/generators/notebook_generator/stages/s07_model_training.py
@@ -0,0 +1,176 @@
from typing import List

import nbformat

from ..base import NotebookStage
from .base_stage import StageGenerator


class ModelTrainingStage(StageGenerator):
    @property
    def stage(self) -> NotebookStage:
        return NotebookStage.MODEL_TRAINING

    @property
    def title(self) -> str:
        return "07 - Model Training"

    @property
    def description(self) -> str:
        return "Train baseline models with MLflow experiment tracking."

    def generate_local_cells(self) -> List[nbformat.NotebookNode]:
        target = self.get_target_column()
        test_size = self.config.test_size
        exp_name = self.config.mlflow.experiment_name
        tracking_uri = self.config.mlflow.tracking_uri
        return self.header_cells() + [
            self.cb.section("Imports"),
            self.cb.from_imports_cell({
                "customer_retention.stages.modeling": ["BaselineTrainer", "ModelEvaluator", "DataSplitter"],
                "customer_retention.integrations.adapters": ["get_mlflow"],
                "customer_retention.analysis.visualization": ["ChartBuilder"],
                "customer_retention.stages.temporal": ["SnapshotManager"],
                "customer_retention.analysis.diagnostics": ["LeakageDetector"],
                "pathlib": ["Path"],
                "pandas": ["pd"],
            }),
            self.cb.section("Load Training Snapshot"),
            self.cb.markdown('''**Important**: We load from a versioned snapshot to ensure reproducibility and prevent data leakage.'''),
            self.cb.code('''snapshot_manager = SnapshotManager(Path("./experiments/data"))
latest_snapshot = snapshot_manager.get_latest_snapshot()

if latest_snapshot:
    df, snapshot_metadata = snapshot_manager.load_snapshot(latest_snapshot)
    print(f"Loaded snapshot: {latest_snapshot}")
    print(f"Snapshot cutoff date: {snapshot_metadata.cutoff_date}")
    print(f"Data hash: {snapshot_metadata.data_hash}")
    print(f"Rows: {snapshot_metadata.row_count}")
else:
    from customer_retention.integrations.adapters.factory import get_delta
    storage = get_delta(force_local=True)
    df = storage.read("./experiments/data/gold/customers_selected")
    snapshot_metadata = None
    print(f"Warning: No snapshot found, loading from gold layer: {df.shape}")'''),
            self.cb.section("Prepare Train/Test Split"),
            self.cb.code(f'''target_col = "target" if "target" in df.columns else "{target}"
id_cols = ["entity_id"] if "entity_id" in df.columns else {self.get_identifier_columns()}
temporal_cols = ["feature_timestamp", "label_timestamp", "label_available_flag"]
exclude_cols = id_cols + [target_col] + temporal_cols

feature_cols = [c for c in df.columns if c not in exclude_cols]
print(f"Using {{len(feature_cols)}} features (excluded: {{exclude_cols}})")

X = df[feature_cols]
y = df[target_col]

splitter = DataSplitter(test_size={test_size}, stratify=True, random_state=42)
X_train, X_test, y_train, y_test = splitter.split(X, y)
print(f"Train: {{len(X_train)}}, Test: {{len(X_test)}}")'''),
            self.cb.section("Run Leakage Detection"),
            self.cb.code('''detector = LeakageDetector()
leakage_result = detector.run_all_checks(X_train, y_train)

if not leakage_result.passed:
    print("WARNING: Leakage detected!")
    for issue in leakage_result.critical_issues:
        print(f"  CRITICAL: {issue.feature} - {issue.recommendation}")
else:
    print("Leakage check PASSED")'''),
            self.cb.section("Setup MLflow Tracking"),
            self.cb.code(f'''mlflow_adapter = get_mlflow(tracking_uri="{tracking_uri}", force_local=True)
experiment_name = "{exp_name}"
print(f"MLflow tracking URI: {tracking_uri}")

snapshot_params = {{}}
if snapshot_metadata:
    snapshot_params = {{
        "snapshot_id": snapshot_metadata.snapshot_id,
        "snapshot_version": snapshot_metadata.version,
        "snapshot_cutoff": str(snapshot_metadata.cutoff_date),
        "snapshot_hash": snapshot_metadata.data_hash,
    }}'''),
            self.cb.section("Train Baseline Models"),
            self.cb.code('''from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier

models = {
    "logistic_regression": LogisticRegression(class_weight="balanced", max_iter=1000),
    "random_forest": RandomForestClassifier(class_weight="balanced", n_estimators=100, random_state=42),
    "gradient_boosting": GradientBoostingClassifier(n_estimators=100, random_state=42),
}

results = {}
for name, model in models.items():
    mlflow_adapter.start_run(experiment_name, run_name=name)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1] if hasattr(model, "predict_proba") else y_pred

    evaluator = ModelEvaluator()
    metrics = evaluator.evaluate(y_test, y_pred, y_prob)
    results[name] = {"model": model, "metrics": metrics, "y_pred": y_pred, "y_prob": y_prob}

    all_params = {**model.get_params(), **snapshot_params}
    mlflow_adapter.log_params(all_params)
    mlflow_adapter.log_metrics(metrics)
    mlflow_adapter.log_model(model, "model")
    mlflow_adapter.end_run()
    print(f"{name}: AUC={metrics.get('roc_auc', 0):.4f}, F1={metrics.get('f1', 0):.4f}")'''),
            self.cb.section("Compare Models"),
            self.cb.code('''charts = ChartBuilder()
fig = charts.model_comparison_grid(results, y_test)
fig.show()'''),
            self.cb.section("Save Best Model"),
            self.cb.code('''best_model_name = max(results, key=lambda k: results[k]["metrics"].get("roc_auc", 0))
best_model = results[best_model_name]["model"]
import joblib
joblib.dump(best_model, "./experiments/data/models/best_model.joblib")
print(f"Best model: {best_model_name}")'''),
        ]

    def generate_databricks_cells(self) -> List[nbformat.NotebookNode]:
        catalog = self.config.feature_store.catalog
        schema = self.config.feature_store.schema
        target = self.get_target_column()
        exp_name = self.config.mlflow.experiment_name
        model_name = self.config.mlflow.model_name
        return self.header_cells() + [
            self.cb.section("Load Selected Features"),
            self.cb.code(f'''df = spark.table("{catalog}.{schema}.gold_selected")'''),
            self.cb.section("Prepare Features Vector"),
            self.cb.code(f'''from pyspark.ml.feature import VectorAssembler

target_col = "{target}"
feature_cols = [c for c in df.columns if c not in {self.get_identifier_columns()} + [target_col]]

assembler = VectorAssembler(inputCols=feature_cols, outputCol="features", handleInvalid="skip")
df_ml = assembler.transform(df).select("features", target_col)
train_df, test_df = df_ml.randomSplit([0.8, 0.2], seed=42)
print(f"Train: {{train_df.count()}}, Test: {{test_df.count()}}")'''),
            self.cb.section("Setup MLflow"),
            self.cb.code(f'''import mlflow
mlflow.set_experiment("/Users/{{spark.conf.get('spark.databricks.notebook.username', 'default')}}/{exp_name}")'''),
            self.cb.section("Train Gradient Boosted Trees"),
            self.cb.code(f'''from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator

with mlflow.start_run(run_name="gbt_baseline"):
    gbt = GBTClassifier(featuresCol="features", labelCol="{target}", maxIter=100)
    model = gbt.fit(train_df)

    predictions = model.transform(test_df)
    evaluator = BinaryClassificationEvaluator(labelCol="{target}", metricName="areaUnderROC")
    auc = evaluator.evaluate(predictions)

    mlflow.log_param("maxIter", 100)
    mlflow.log_metric("auc_roc", auc)
    mlflow.spark.log_model(model, "model")

    run_id = mlflow.active_run().info.run_id
    print(f"AUC: {{auc:.4f}}, Run ID: {{run_id}}")'''),
            self.cb.section("Register Model"),
            self.cb.code(f'''model_uri = f"runs:/{{run_id}}/model"
mlflow.register_model(model_uri, "{catalog}.{schema}.{model_name}")
print(f"Model registered: {catalog}.{schema}.{model_name}")'''),
        ]
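These cells mix two interpolation layers: in self.cb.code(f'''...''') blocks, single braces are resolved when the notebook is generated, while doubled braces survive into the emitted cell and run when the notebook executes. A small sketch of that convention, with test_size standing in for self.config.test_size:

# Generation-time vs run-time braces in the generator's f-string cells.
test_size = 0.2  # stands in for self.config.test_size

cell_source = f'''splitter = DataSplitter(test_size={test_size}, stratify=True, random_state=42)
X_train, X_test, y_train, y_test = splitter.split(X, y)
print(f"Train: {{len(X_train)}}, Test: {{len(X_test)}}")'''

print(cell_source)
# splitter = DataSplitter(test_size=0.2, stratify=True, random_state=42)
# X_train, X_test, y_train, y_test = splitter.split(X, y)
# print(f"Train: {len(X_train)}, Test: {len(X_test)}")

One caution this convention surfaces: the "Setup MLflow" Databricks cell emits a plain string containing {spark.conf.get(...)} with no f prefix, so the braces appear to reach the experiment path literally rather than being evaluated at run time.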
customer_retention/generators/notebook_generator/stages/s08_deployment.py
@@ -0,0 +1,81 @@
from typing import List

import nbformat

from ..base import NotebookStage
from .base_stage import StageGenerator


class DeploymentStage(StageGenerator):
    @property
    def stage(self) -> NotebookStage:
        return NotebookStage.DEPLOYMENT

    @property
    def title(self) -> str:
        return "08 - Model Deployment"

    @property
    def description(self) -> str:
        return "Register model to registry and promote to production."

    def generate_local_cells(self) -> List[nbformat.NotebookNode]:
        tracking_uri = self.config.mlflow.tracking_uri
        model_name = self.config.mlflow.model_name
        return self.header_cells() + [
            self.cb.section("Imports"),
            self.cb.from_imports_cell({
                "customer_retention.stages.deployment": ["ModelRegistry", "ModelStage"],
                "customer_retention.integrations.adapters": ["get_mlflow"],
            }),
            self.cb.section("Initialize Registry"),
            self.cb.code(f'''mlflow_adapter = get_mlflow(tracking_uri="{tracking_uri}", force_local=True)
registry = ModelRegistry(tracking_uri="{tracking_uri}")
model_name = "{model_name}"'''),
            self.cb.section("List Model Versions"),
            self.cb.code('''versions = registry.list_versions(model_name)
for v in versions:
    print(f"Version {v.version}: Stage={v.current_stage}, Run={v.run_id}")'''),
            self.cb.section("Validate for Promotion"),
            self.cb.code('''latest_version = max(versions, key=lambda v: int(v.version)).version if versions else "1"
validation = registry.validate_for_promotion(
    model_name=model_name,
    version=latest_version,
    required_metrics={"roc_auc": 0.6},
)
print(f"Validation passed: {validation.is_valid}")
if not validation.is_valid:
    print(f"Errors: {validation.errors}")'''),
            self.cb.section("Promote to Production"),
            self.cb.code('''if validation.is_valid:
    registry.transition_stage(model_name, latest_version, ModelStage.PRODUCTION)
    print(f"Model {model_name} v{latest_version} promoted to Production")
else:
    print("Model not promoted due to validation failure")'''),
        ]

    def generate_databricks_cells(self) -> List[nbformat.NotebookNode]:
        catalog = self.config.feature_store.catalog
        schema = self.config.feature_store.schema
        model_name = self.config.mlflow.model_name
        return self.header_cells() + [
            self.cb.section("Initialize MLflow Client"),
            self.cb.code('''import mlflow
from mlflow.tracking import MlflowClient

client = MlflowClient()'''),
            self.cb.section("Get Model Versions"),
            self.cb.code(f'''model_full_name = "{catalog}.{schema}.{model_name}"
versions = client.search_model_versions(f"name='{{model_full_name}}'")
for v in versions:
    print(f"Version {{v.version}}: Status={{v.status}}")'''),
            self.cb.section("Get Latest Version"),
            self.cb.code('''latest = max(versions, key=lambda v: int(v.version))
print(f"Latest version: {latest.version}")'''),
            self.cb.section("Set Production Alias"),
            self.cb.code('''client.set_registered_model_alias(model_full_name, "production", latest.version)
print(f"Model {model_full_name} v{latest.version} aliased as 'production'")'''),
            self.cb.section("Verify Production Model"),
            self.cb.code('''prod_version = client.get_model_version_by_alias(model_full_name, "production")
print(f"Production model version: {prod_version.version}")'''),
        ]
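The local path promotes through the package's own ModelRegistry stage transitions, while the Databricks path uses MLflow registry aliases. For reference, a hedged sketch of the alias pattern against a plain MLflow 2.x registry; the model name is illustrative and assumed to be registered already:

# Alias-based promotion with stock MLflow (assumes "churn_model" exists).
from mlflow.tracking import MlflowClient

client = MlflowClient()
model_name = "churn_model"  # illustrative name, not from the package config

versions = client.search_model_versions(f"name='{model_name}'")
latest = max(versions, key=lambda v: int(v.version))

# Aliases are MLflow's replacement for the older Staging/Production stage
# transitions that the local ModelRegistry path models with ModelStage.
client.set_registered_model_alias(model_name, "production", latest.version)
prod = client.get_model_version_by_alias(model_name, "production")
print(f"'{model_name}' v{prod.version} now serves the 'production' alias")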
customer_retention/generators/notebook_generator/stages/s09_monitoring.py
@@ -0,0 +1,112 @@
from typing import List

import nbformat

from ..base import NotebookStage
from .base_stage import StageGenerator


class MonitoringStage(StageGenerator):
    @property
    def stage(self) -> NotebookStage:
        return NotebookStage.MONITORING

    @property
    def title(self) -> str:
        return "09 - Model Monitoring"

    @property
    def description(self) -> str:
        return "Track model performance, detect drift, and set up alerts."

    def generate_local_cells(self) -> List[nbformat.NotebookNode]:
        return self.header_cells() + [
            self.cb.section("Imports"),
            self.cb.from_imports_cell({
                "customer_retention.stages.monitoring": ["PerformanceMonitor", "DriftDetector"],
                "customer_retention.analysis.visualization": ["ChartBuilder"],
                "pandas": ["pd"],
                "joblib": ["joblib"],
            }),
            self.cb.section("Load Production Model and Test Data"),
            self.cb.code('''model = joblib.load("./experiments/data/models/best_model.joblib")
df_test = pd.read_parquet("./experiments/data/gold/customers_selected.parquet").sample(n=1000, random_state=42)'''),
            self.cb.section("Generate Predictions"),
            self.cb.code(f'''target_col = "{self.get_target_column()}"
id_cols = {self.get_identifier_columns()}
feature_cols = [c for c in df_test.columns if c not in id_cols + [target_col]]

X_test = df_test[feature_cols]
y_test = df_test[target_col]
y_prob = model.predict_proba(X_test)[:, 1]
y_pred = (y_prob >= 0.5).astype(int)'''),
            self.cb.section("Calculate Performance Metrics"),
            self.cb.code('''from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score

current_metrics = {
    "roc_auc": roc_auc_score(y_test, y_prob),
    "precision": precision_score(y_test, y_pred),
    "recall": recall_score(y_test, y_pred),
    "f1": f1_score(y_test, y_pred),
}
for name, value in current_metrics.items():
    print(f"{name}: {value:.4f}")'''),
            self.cb.section("Compare to Baseline"),
            self.cb.code('''baseline_metrics = {"roc_auc": 0.75, "precision": 0.60, "recall": 0.70, "f1": 0.65}
monitor = PerformanceMonitor(baseline_metrics)
result = monitor.evaluate(current_metrics)
print(f"Status: {result.status}")
for metric, change in result.changes.items():
    print(f"  {metric}: {change:+.2%}")'''),
            self.cb.section("Detect Feature Drift"),
            self.cb.code('''df_reference = pd.read_parquet("./experiments/data/gold/customers_features.parquet").sample(n=1000, random_state=0)
drift_detector = DriftDetector()
for col in feature_cols[:5]:
    result = drift_detector.detect(df_reference[col], df_test[col])
    if result.has_drift:
        print(f"DRIFT detected in {col}: PSI={result.psi:.4f}")'''),
        ]

    def generate_databricks_cells(self) -> List[nbformat.NotebookNode]:
        catalog = self.config.feature_store.catalog
        schema = self.config.feature_store.schema
        model_name = self.config.mlflow.model_name
        target = self.get_target_column()
        return self.header_cells() + [
            self.cb.section("Load Model and Data"),
            self.cb.code(f'''import mlflow

model = mlflow.pyfunc.load_model(f"models:/{catalog}.{schema}.{model_name}@production")
df_test = spark.table("{catalog}.{schema}.gold_selected").sample(0.1)'''),
            self.cb.section("Generate Predictions"),
            self.cb.code(f'''from pyspark.sql.functions import pandas_udf
import pandas as pd

feature_cols = [c for c in df_test.columns if c not in {self.get_identifier_columns()} + ["{target}"]]

@pandas_udf("double")
def predict_udf(*cols):
    df = pd.concat(cols, axis=1)
    df.columns = feature_cols
    return pd.Series(model.predict(df))

df_predictions = df_test.withColumn("prediction", predict_udf(*[df_test[c] for c in feature_cols]))
display(df_predictions.limit(10))'''),
            self.cb.section("Calculate Metrics"),
            self.cb.code(f'''from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator(labelCol="{target}", rawPredictionCol="prediction")
auc = evaluator.evaluate(df_predictions)
print(f"Current AUC: {{auc:.4f}}")'''),
            self.cb.section("Check for Drift"),
            self.cb.code(f'''df_reference = spark.table("{catalog}.{schema}.gold_customers").sample(0.1)

for col in feature_cols[:5]:
    ref_stats = df_reference.select(col).describe().collect()
    cur_stats = df_test.select(col).describe().collect()
    ref_mean = float(ref_stats[1][1]) if ref_stats[1][1] else 0
    cur_mean = float(cur_stats[1][1]) if cur_stats[1][1] else 0
    drift_pct = abs(ref_mean - cur_mean) / (ref_mean + 1e-10) * 100
    if drift_pct > 10:
        print(f"DRIFT in {{col}}: {{drift_pct:.1f}}% mean shift")'''),
        ]
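DriftDetector's internals are not part of this diff; its result.psi is presumably the standard population stability index. A hedged sketch of that metric, assuming quantile bins taken from the reference sample:

# PSI sketch: bin the reference column, then compare bin shares.
import numpy as np

def psi(reference: np.ndarray, current: np.ndarray, bins: int = 10) -> float:
    # Bin edges come from reference quantiles so every bin starts populated.
    edges = np.quantile(reference, np.linspace(0, 1, bins + 1))
    edges[0], edges[-1] = -np.inf, np.inf
    ref_share = np.histogram(reference, bins=edges)[0] / len(reference)
    cur_share = np.histogram(current, bins=edges)[0] / len(current)
    # Clip avoids log(0) when a bin is empty in either sample.
    ref_share = np.clip(ref_share, 1e-6, None)
    cur_share = np.clip(cur_share, 1e-6, None)
    return float(np.sum((cur_share - ref_share) * np.log(cur_share / ref_share)))

rng = np.random.default_rng(0)
ref = rng.normal(0, 1, 5000)
cur = rng.normal(0.3, 1, 5000)      # shifted mean, i.e. drifted
print(f"PSI = {psi(ref, cur):.4f}")  # PSI > 0.2 is a common drift rule of thumb

Unlike the Databricks branch's mean-shift heuristic, PSI also reacts to variance and shape changes, which is likely why the local path reports it.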