churnkit-0.75.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +647 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +1165 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +961 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +1690 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +679 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +3305 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +1463 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +1430 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +854 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +1639 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +1890 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +1457 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +1624 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +780 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +979 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +572 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +1179 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +1418 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +151 -0
- churnkit-0.75.0a1.dist-info/METADATA +229 -0
- churnkit-0.75.0a1.dist-info/RECORD +302 -0
- churnkit-0.75.0a1.dist-info/WHEEL +4 -0
- churnkit-0.75.0a1.dist-info/entry_points.txt +2 -0
- churnkit-0.75.0a1.dist-info/licenses/LICENSE +202 -0
- customer_retention/__init__.py +37 -0
- customer_retention/analysis/__init__.py +0 -0
- customer_retention/analysis/auto_explorer/__init__.py +62 -0
- customer_retention/analysis/auto_explorer/exploration_manager.py +470 -0
- customer_retention/analysis/auto_explorer/explorer.py +258 -0
- customer_retention/analysis/auto_explorer/findings.py +291 -0
- customer_retention/analysis/auto_explorer/layered_recommendations.py +485 -0
- customer_retention/analysis/auto_explorer/recommendation_builder.py +148 -0
- customer_retention/analysis/auto_explorer/recommendations.py +418 -0
- customer_retention/analysis/business/__init__.py +26 -0
- customer_retention/analysis/business/ab_test_designer.py +144 -0
- customer_retention/analysis/business/fairness_analyzer.py +166 -0
- customer_retention/analysis/business/intervention_matcher.py +121 -0
- customer_retention/analysis/business/report_generator.py +222 -0
- customer_retention/analysis/business/risk_profile.py +199 -0
- customer_retention/analysis/business/roi_analyzer.py +139 -0
- customer_retention/analysis/diagnostics/__init__.py +20 -0
- customer_retention/analysis/diagnostics/calibration_analyzer.py +133 -0
- customer_retention/analysis/diagnostics/cv_analyzer.py +144 -0
- customer_retention/analysis/diagnostics/error_analyzer.py +107 -0
- customer_retention/analysis/diagnostics/leakage_detector.py +394 -0
- customer_retention/analysis/diagnostics/noise_tester.py +140 -0
- customer_retention/analysis/diagnostics/overfitting_analyzer.py +190 -0
- customer_retention/analysis/diagnostics/segment_analyzer.py +122 -0
- customer_retention/analysis/discovery/__init__.py +8 -0
- customer_retention/analysis/discovery/config_generator.py +49 -0
- customer_retention/analysis/discovery/discovery_flow.py +19 -0
- customer_retention/analysis/discovery/type_inferencer.py +147 -0
- customer_retention/analysis/interpretability/__init__.py +13 -0
- customer_retention/analysis/interpretability/cohort_analyzer.py +185 -0
- customer_retention/analysis/interpretability/counterfactual.py +175 -0
- customer_retention/analysis/interpretability/individual_explainer.py +141 -0
- customer_retention/analysis/interpretability/pdp_generator.py +103 -0
- customer_retention/analysis/interpretability/shap_explainer.py +106 -0
- customer_retention/analysis/jupyter_save_hook.py +28 -0
- customer_retention/analysis/notebook_html_exporter.py +136 -0
- customer_retention/analysis/notebook_progress.py +60 -0
- customer_retention/analysis/plotly_preprocessor.py +154 -0
- customer_retention/analysis/recommendations/__init__.py +54 -0
- customer_retention/analysis/recommendations/base.py +158 -0
- customer_retention/analysis/recommendations/cleaning/__init__.py +11 -0
- customer_retention/analysis/recommendations/cleaning/consistency.py +107 -0
- customer_retention/analysis/recommendations/cleaning/deduplicate.py +94 -0
- customer_retention/analysis/recommendations/cleaning/impute.py +67 -0
- customer_retention/analysis/recommendations/cleaning/outlier.py +71 -0
- customer_retention/analysis/recommendations/datetime/__init__.py +3 -0
- customer_retention/analysis/recommendations/datetime/extract.py +149 -0
- customer_retention/analysis/recommendations/encoding/__init__.py +3 -0
- customer_retention/analysis/recommendations/encoding/categorical.py +114 -0
- customer_retention/analysis/recommendations/pipeline.py +74 -0
- customer_retention/analysis/recommendations/registry.py +76 -0
- customer_retention/analysis/recommendations/selection/__init__.py +3 -0
- customer_retention/analysis/recommendations/selection/drop_column.py +56 -0
- customer_retention/analysis/recommendations/transform/__init__.py +4 -0
- customer_retention/analysis/recommendations/transform/power.py +94 -0
- customer_retention/analysis/recommendations/transform/scale.py +112 -0
- customer_retention/analysis/visualization/__init__.py +15 -0
- customer_retention/analysis/visualization/chart_builder.py +2619 -0
- customer_retention/analysis/visualization/console.py +122 -0
- customer_retention/analysis/visualization/display.py +171 -0
- customer_retention/analysis/visualization/number_formatter.py +36 -0
- customer_retention/artifacts/__init__.py +3 -0
- customer_retention/artifacts/fit_artifact_registry.py +146 -0
- customer_retention/cli.py +93 -0
- customer_retention/core/__init__.py +0 -0
- customer_retention/core/compat/__init__.py +193 -0
- customer_retention/core/compat/detection.py +99 -0
- customer_retention/core/compat/ops.py +48 -0
- customer_retention/core/compat/pandas_backend.py +57 -0
- customer_retention/core/compat/spark_backend.py +75 -0
- customer_retention/core/components/__init__.py +11 -0
- customer_retention/core/components/base.py +79 -0
- customer_retention/core/components/components/__init__.py +13 -0
- customer_retention/core/components/components/deployer.py +26 -0
- customer_retention/core/components/components/explainer.py +26 -0
- customer_retention/core/components/components/feature_eng.py +33 -0
- customer_retention/core/components/components/ingester.py +34 -0
- customer_retention/core/components/components/profiler.py +34 -0
- customer_retention/core/components/components/trainer.py +38 -0
- customer_retention/core/components/components/transformer.py +36 -0
- customer_retention/core/components/components/validator.py +37 -0
- customer_retention/core/components/enums.py +33 -0
- customer_retention/core/components/orchestrator.py +94 -0
- customer_retention/core/components/registry.py +59 -0
- customer_retention/core/config/__init__.py +39 -0
- customer_retention/core/config/column_config.py +95 -0
- customer_retention/core/config/experiments.py +71 -0
- customer_retention/core/config/pipeline_config.py +117 -0
- customer_retention/core/config/source_config.py +83 -0
- customer_retention/core/utils/__init__.py +28 -0
- customer_retention/core/utils/leakage.py +85 -0
- customer_retention/core/utils/severity.py +53 -0
- customer_retention/core/utils/statistics.py +90 -0
- customer_retention/generators/__init__.py +0 -0
- customer_retention/generators/notebook_generator/__init__.py +167 -0
- customer_retention/generators/notebook_generator/base.py +55 -0
- customer_retention/generators/notebook_generator/cell_builder.py +49 -0
- customer_retention/generators/notebook_generator/config.py +47 -0
- customer_retention/generators/notebook_generator/databricks_generator.py +48 -0
- customer_retention/generators/notebook_generator/local_generator.py +48 -0
- customer_retention/generators/notebook_generator/project_init.py +174 -0
- customer_retention/generators/notebook_generator/runner.py +150 -0
- customer_retention/generators/notebook_generator/script_generator.py +110 -0
- customer_retention/generators/notebook_generator/stages/__init__.py +19 -0
- customer_retention/generators/notebook_generator/stages/base_stage.py +86 -0
- customer_retention/generators/notebook_generator/stages/s01_ingestion.py +100 -0
- customer_retention/generators/notebook_generator/stages/s02_profiling.py +95 -0
- customer_retention/generators/notebook_generator/stages/s03_cleaning.py +180 -0
- customer_retention/generators/notebook_generator/stages/s04_transformation.py +165 -0
- customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py +115 -0
- customer_retention/generators/notebook_generator/stages/s06_feature_selection.py +97 -0
- customer_retention/generators/notebook_generator/stages/s07_model_training.py +176 -0
- customer_retention/generators/notebook_generator/stages/s08_deployment.py +81 -0
- customer_retention/generators/notebook_generator/stages/s09_monitoring.py +112 -0
- customer_retention/generators/notebook_generator/stages/s10_batch_inference.py +642 -0
- customer_retention/generators/notebook_generator/stages/s11_feature_store.py +348 -0
- customer_retention/generators/orchestration/__init__.py +23 -0
- customer_retention/generators/orchestration/code_generator.py +196 -0
- customer_retention/generators/orchestration/context.py +147 -0
- customer_retention/generators/orchestration/data_materializer.py +188 -0
- customer_retention/generators/orchestration/databricks_exporter.py +411 -0
- customer_retention/generators/orchestration/doc_generator.py +311 -0
- customer_retention/generators/pipeline_generator/__init__.py +26 -0
- customer_retention/generators/pipeline_generator/findings_parser.py +727 -0
- customer_retention/generators/pipeline_generator/generator.py +142 -0
- customer_retention/generators/pipeline_generator/models.py +166 -0
- customer_retention/generators/pipeline_generator/renderer.py +2125 -0
- customer_retention/generators/spec_generator/__init__.py +37 -0
- customer_retention/generators/spec_generator/databricks_generator.py +433 -0
- customer_retention/generators/spec_generator/generic_generator.py +373 -0
- customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +685 -0
- customer_retention/generators/spec_generator/pipeline_spec.py +298 -0
- customer_retention/integrations/__init__.py +0 -0
- customer_retention/integrations/adapters/__init__.py +13 -0
- customer_retention/integrations/adapters/base.py +10 -0
- customer_retention/integrations/adapters/factory.py +25 -0
- customer_retention/integrations/adapters/feature_store/__init__.py +6 -0
- customer_retention/integrations/adapters/feature_store/base.py +57 -0
- customer_retention/integrations/adapters/feature_store/databricks.py +94 -0
- customer_retention/integrations/adapters/feature_store/feast_adapter.py +97 -0
- customer_retention/integrations/adapters/feature_store/local.py +75 -0
- customer_retention/integrations/adapters/mlflow/__init__.py +6 -0
- customer_retention/integrations/adapters/mlflow/base.py +32 -0
- customer_retention/integrations/adapters/mlflow/databricks.py +54 -0
- customer_retention/integrations/adapters/mlflow/experiment_tracker.py +161 -0
- customer_retention/integrations/adapters/mlflow/local.py +50 -0
- customer_retention/integrations/adapters/storage/__init__.py +5 -0
- customer_retention/integrations/adapters/storage/base.py +33 -0
- customer_retention/integrations/adapters/storage/databricks.py +76 -0
- customer_retention/integrations/adapters/storage/local.py +59 -0
- customer_retention/integrations/feature_store/__init__.py +47 -0
- customer_retention/integrations/feature_store/definitions.py +215 -0
- customer_retention/integrations/feature_store/manager.py +744 -0
- customer_retention/integrations/feature_store/registry.py +412 -0
- customer_retention/integrations/iteration/__init__.py +28 -0
- customer_retention/integrations/iteration/context.py +212 -0
- customer_retention/integrations/iteration/feedback_collector.py +184 -0
- customer_retention/integrations/iteration/orchestrator.py +168 -0
- customer_retention/integrations/iteration/recommendation_tracker.py +341 -0
- customer_retention/integrations/iteration/signals.py +212 -0
- customer_retention/integrations/llm_context/__init__.py +4 -0
- customer_retention/integrations/llm_context/context_builder.py +201 -0
- customer_retention/integrations/llm_context/prompts.py +100 -0
- customer_retention/integrations/streaming/__init__.py +103 -0
- customer_retention/integrations/streaming/batch_integration.py +149 -0
- customer_retention/integrations/streaming/early_warning_model.py +227 -0
- customer_retention/integrations/streaming/event_schema.py +214 -0
- customer_retention/integrations/streaming/online_store_writer.py +249 -0
- customer_retention/integrations/streaming/realtime_scorer.py +261 -0
- customer_retention/integrations/streaming/trigger_engine.py +293 -0
- customer_retention/integrations/streaming/window_aggregator.py +393 -0
- customer_retention/stages/__init__.py +0 -0
- customer_retention/stages/cleaning/__init__.py +9 -0
- customer_retention/stages/cleaning/base.py +28 -0
- customer_retention/stages/cleaning/missing_handler.py +160 -0
- customer_retention/stages/cleaning/outlier_handler.py +204 -0
- customer_retention/stages/deployment/__init__.py +28 -0
- customer_retention/stages/deployment/batch_scorer.py +106 -0
- customer_retention/stages/deployment/champion_challenger.py +299 -0
- customer_retention/stages/deployment/model_registry.py +182 -0
- customer_retention/stages/deployment/retraining_trigger.py +245 -0
- customer_retention/stages/features/__init__.py +73 -0
- customer_retention/stages/features/behavioral_features.py +266 -0
- customer_retention/stages/features/customer_segmentation.py +505 -0
- customer_retention/stages/features/feature_definitions.py +265 -0
- customer_retention/stages/features/feature_engineer.py +551 -0
- customer_retention/stages/features/feature_manifest.py +340 -0
- customer_retention/stages/features/feature_selector.py +239 -0
- customer_retention/stages/features/interaction_features.py +160 -0
- customer_retention/stages/features/temporal_features.py +243 -0
- customer_retention/stages/ingestion/__init__.py +9 -0
- customer_retention/stages/ingestion/load_result.py +32 -0
- customer_retention/stages/ingestion/loaders.py +195 -0
- customer_retention/stages/ingestion/source_registry.py +130 -0
- customer_retention/stages/modeling/__init__.py +31 -0
- customer_retention/stages/modeling/baseline_trainer.py +139 -0
- customer_retention/stages/modeling/cross_validator.py +125 -0
- customer_retention/stages/modeling/data_splitter.py +205 -0
- customer_retention/stages/modeling/feature_scaler.py +99 -0
- customer_retention/stages/modeling/hyperparameter_tuner.py +107 -0
- customer_retention/stages/modeling/imbalance_handler.py +282 -0
- customer_retention/stages/modeling/mlflow_logger.py +95 -0
- customer_retention/stages/modeling/model_comparator.py +149 -0
- customer_retention/stages/modeling/model_evaluator.py +138 -0
- customer_retention/stages/modeling/threshold_optimizer.py +131 -0
- customer_retention/stages/monitoring/__init__.py +37 -0
- customer_retention/stages/monitoring/alert_manager.py +328 -0
- customer_retention/stages/monitoring/drift_detector.py +201 -0
- customer_retention/stages/monitoring/performance_monitor.py +242 -0
- customer_retention/stages/preprocessing/__init__.py +5 -0
- customer_retention/stages/preprocessing/transformer_manager.py +284 -0
- customer_retention/stages/profiling/__init__.py +256 -0
- customer_retention/stages/profiling/categorical_distribution.py +269 -0
- customer_retention/stages/profiling/categorical_target_analyzer.py +274 -0
- customer_retention/stages/profiling/column_profiler.py +527 -0
- customer_retention/stages/profiling/distribution_analysis.py +483 -0
- customer_retention/stages/profiling/drift_detector.py +310 -0
- customer_retention/stages/profiling/feature_capacity.py +507 -0
- customer_retention/stages/profiling/pattern_analysis_config.py +513 -0
- customer_retention/stages/profiling/profile_result.py +212 -0
- customer_retention/stages/profiling/quality_checks.py +1632 -0
- customer_retention/stages/profiling/relationship_detector.py +256 -0
- customer_retention/stages/profiling/relationship_recommender.py +454 -0
- customer_retention/stages/profiling/report_generator.py +520 -0
- customer_retention/stages/profiling/scd_analyzer.py +151 -0
- customer_retention/stages/profiling/segment_analyzer.py +632 -0
- customer_retention/stages/profiling/segment_aware_outlier.py +265 -0
- customer_retention/stages/profiling/target_level_analyzer.py +217 -0
- customer_retention/stages/profiling/temporal_analyzer.py +388 -0
- customer_retention/stages/profiling/temporal_coverage.py +488 -0
- customer_retention/stages/profiling/temporal_feature_analyzer.py +692 -0
- customer_retention/stages/profiling/temporal_feature_engineer.py +703 -0
- customer_retention/stages/profiling/temporal_pattern_analyzer.py +636 -0
- customer_retention/stages/profiling/temporal_quality_checks.py +278 -0
- customer_retention/stages/profiling/temporal_target_analyzer.py +241 -0
- customer_retention/stages/profiling/text_embedder.py +87 -0
- customer_retention/stages/profiling/text_processor.py +115 -0
- customer_retention/stages/profiling/text_reducer.py +60 -0
- customer_retention/stages/profiling/time_series_profiler.py +303 -0
- customer_retention/stages/profiling/time_window_aggregator.py +376 -0
- customer_retention/stages/profiling/type_detector.py +382 -0
- customer_retention/stages/profiling/window_recommendation.py +288 -0
- customer_retention/stages/temporal/__init__.py +166 -0
- customer_retention/stages/temporal/access_guard.py +180 -0
- customer_retention/stages/temporal/cutoff_analyzer.py +235 -0
- customer_retention/stages/temporal/data_preparer.py +178 -0
- customer_retention/stages/temporal/point_in_time_join.py +134 -0
- customer_retention/stages/temporal/point_in_time_registry.py +148 -0
- customer_retention/stages/temporal/scenario_detector.py +163 -0
- customer_retention/stages/temporal/snapshot_manager.py +259 -0
- customer_retention/stages/temporal/synthetic_coordinator.py +66 -0
- customer_retention/stages/temporal/timestamp_discovery.py +531 -0
- customer_retention/stages/temporal/timestamp_manager.py +255 -0
- customer_retention/stages/transformation/__init__.py +13 -0
- customer_retention/stages/transformation/binary_handler.py +85 -0
- customer_retention/stages/transformation/categorical_encoder.py +245 -0
- customer_retention/stages/transformation/datetime_transformer.py +97 -0
- customer_retention/stages/transformation/numeric_transformer.py +181 -0
- customer_retention/stages/transformation/pipeline.py +257 -0
- customer_retention/stages/validation/__init__.py +60 -0
- customer_retention/stages/validation/adversarial_scoring_validator.py +205 -0
- customer_retention/stages/validation/business_sense_gate.py +173 -0
- customer_retention/stages/validation/data_quality_gate.py +235 -0
- customer_retention/stages/validation/data_validators.py +511 -0
- customer_retention/stages/validation/feature_quality_gate.py +183 -0
- customer_retention/stages/validation/gates.py +117 -0
- customer_retention/stages/validation/leakage_gate.py +352 -0
- customer_retention/stages/validation/model_validity_gate.py +213 -0
- customer_retention/stages/validation/pipeline_validation_runner.py +264 -0
- customer_retention/stages/validation/quality_scorer.py +544 -0
- customer_retention/stages/validation/rule_generator.py +57 -0
- customer_retention/stages/validation/scoring_pipeline_validator.py +446 -0
- customer_retention/stages/validation/timeseries_detector.py +769 -0
- customer_retention/transforms/__init__.py +47 -0
- customer_retention/transforms/artifact_store.py +50 -0
- customer_retention/transforms/executor.py +157 -0
- customer_retention/transforms/fitted.py +92 -0
- customer_retention/transforms/ops.py +148 -0
customer_retention/generators/notebook_generator/stages/s10_batch_inference.py
@@ -0,0 +1,642 @@
+"""Batch inference stage for notebook generation.
+
+This stage generates notebooks that perform batch scoring using the feature store
+with point-in-time correctness for feature retrieval.
+"""
+
+from typing import List
+
+import nbformat
+
+from ..base import NotebookStage
+from .base_stage import StageGenerator
+
+
+class BatchInferenceStage(StageGenerator):
+    @property
+    def stage(self) -> NotebookStage:
+        return NotebookStage.BATCH_INFERENCE
+
+    @property
+    def title(self) -> str:
+        return "10 - Batch Inference with Point-in-Time Features"
+
+    @property
+    def description(self) -> str:
+        return """Score customers in batch using the production model with point-in-time correct feature retrieval.
+
+**Key Concepts:**
+- **Point-in-Time (PIT) Correctness**: Features are retrieved as they existed at inference time
+- **Inference Timestamp**: The moment when predictions are made, ensuring no future data leakage
+- **Feature Store Integration**: Uses Feast (local) or Databricks Feature Store for consistent feature retrieval
+
+This notebook:
+1. Sets the inference timestamp (point-in-time for prediction)
+2. Retrieves features from the feature store with PIT correctness
+3. Scores customers using the production model
+4. Generates a dashboard showing predictions with the inference timestamp
+"""
+
+    def generate_local_cells(self) -> List[nbformat.NotebookNode]:
+        """Generate cells for local Feast-based batch inference."""
+        threshold = self.config.threshold
+        return self.header_cells() + [
+            self.cb.section("1. Setup and Imports"),
+            self.cb.code('''import pandas as pd
+import numpy as np
+from pathlib import Path
+from datetime import datetime
+import joblib
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+
+from customer_retention.integrations.feature_store import FeatureStoreManager, FeatureRegistry
+from customer_retention.stages.temporal import SnapshotManager
+
+print("Batch inference imports loaded")'''),
+
+            self.cb.section("2. Set Inference Point-in-Time"),
+            self.cb.markdown('''**Critical**: The inference timestamp determines what point in time we use to retrieve features.
+This ensures we only use data that was available at the time of prediction (no future leakage).
+
+- For **real-time inference**: Use `datetime.now()`
+- For **historical backtesting**: Use a specific past timestamp
+- For **scheduled batch jobs**: Use the job execution timestamp'''),
+            self.cb.code('''# INFERENCE_TIMESTAMP: The point-in-time for this batch prediction
+# This is the "as-of" time for feature retrieval
+INFERENCE_TIMESTAMP = datetime.now()
+
+# Alternative: Use a specific historical timestamp for backtesting
+# INFERENCE_TIMESTAMP = datetime(2024, 1, 15, 0, 0, 0)
+
+print("=" * 70)
+print(f"INFERENCE POINT-IN-TIME: {INFERENCE_TIMESTAMP}")
+print("=" * 70)
+print("All features will be retrieved as they existed at this timestamp.")
+print("This ensures no future data leakage in predictions.")'''),
+
+            self.cb.section("3. Load Production Model"),
+            self.cb.code('''# Load the production model
+model_path = Path("./experiments/data/models/best_model.joblib")
+if not model_path.exists():
+    raise FileNotFoundError(f"Model not found at {model_path}. Run training first.")
+
+model = joblib.load(model_path)
+print(f"Model loaded: {type(model).__name__}")
+
+# Load feature registry to know which features to retrieve
+registry_path = Path("./experiments/feature_store/feature_registry.json")
+if registry_path.exists():
+    registry = FeatureRegistry.load(registry_path)
+    print(f"Feature registry loaded: {len(registry)} features")
+else:
+    print("Warning: No feature registry found. Using model feature names.")
+    registry = None'''),
+
+            self.cb.section("4. Initialize Feature Store Manager"),
+            self.cb.code('''# Create feature store manager (Feast backend for local)
+manager = FeatureStoreManager.create(
+    backend="feast",
+    repo_path="./experiments/feature_store/feature_repo",
+    output_path="./experiments/data",
+)
+
+print("Feature store initialized")
+print(f"Available tables: {manager.list_tables()}")'''),
+
+            self.cb.section("5. Load Customers to Score"),
+            self.cb.code(f'''# Load the entities (customers) to score
+# These are customers we want to make predictions for
+
+# Option 1: Load from a specific file
+from customer_retention.integrations.adapters.factory import get_delta
+storage = get_delta(force_local=True)
+
+customers_delta = Path("./experiments/data/gold/customers_to_score")
+customers_parquet = Path("./experiments/data/gold/customers_to_score.parquet")
+if customers_delta.is_dir() and (customers_delta / "_delta_log").is_dir():
+    df_customers = storage.read(str(customers_delta))
+elif customers_parquet.exists():
+    df_customers = pd.read_parquet(customers_parquet)
+else:
+    # Option 2: Fall back to the gold feature table
+    gold_delta = Path("./experiments/data/gold/customers_features")
+    gold_parquet = Path("./experiments/data/gold/customers_features.parquet")
+    if gold_delta.is_dir() and (gold_delta / "_delta_log").is_dir():
+        df_customers = storage.read(str(gold_delta))
+    elif gold_parquet.exists():
+        df_customers = pd.read_parquet(gold_parquet)
+    else:
+        # Option 3: Fall back to latest snapshot
+        snapshot_manager = SnapshotManager(Path("./experiments/data"))
+        latest = snapshot_manager.get_latest_snapshot()
+        if latest:
+            df_customers, _ = snapshot_manager.load_snapshot(latest)
+        else:
+            raise FileNotFoundError("No customer data found")
+
+# Ensure entity_id column exists
+id_cols = {self.get_identifier_columns()}
+entity_col = id_cols[0] if id_cols else "customer_id"
+if entity_col not in df_customers.columns and "entity_id" in df_customers.columns:
+    entity_col = "entity_id"
+
+print(f"Loaded {{len(df_customers):,}} customers to score")
+print(f"Entity column: {{entity_col}}")'''),
+
+            self.cb.section("6. Retrieve Features with Point-in-Time Correctness"),
+            self.cb.markdown('''The feature store retrieves features as they existed at the **inference timestamp**.
+This is crucial for:
+- **Training-serving consistency**: Same features used in training and inference
+- **No future leakage**: Only data available at prediction time is used
+- **Reproducibility**: Same timestamp always gives same features'''),
+            self.cb.code('''# Create entity DataFrame with inference timestamp
+# All customers get the same inference timestamp for this batch
+entity_df = df_customers[[entity_col]].copy()
+entity_df = entity_df.rename(columns={entity_col: "entity_id"})
+entity_df["event_timestamp"] = INFERENCE_TIMESTAMP
+
+print(f"Retrieving features for {len(entity_df):,} entities")
+print(f"Point-in-Time: {INFERENCE_TIMESTAMP}")
+
+# Get features from feature store with PIT correctness
+if registry:
+    feature_names = registry.list_features()
+else:
+    # Fall back to model feature names if available
+    feature_names = getattr(model, 'feature_names_in_', None)
+    if feature_names is None:
+        raise ValueError("Cannot determine feature names. Please provide a feature registry.")
+    feature_names = list(feature_names)
+
+# Retrieve point-in-time correct features
+inference_df = manager.get_inference_features(
+    entity_df=entity_df,
+    registry=registry,
+    feature_names=feature_names,
+    table_name="customer_features",
+    timestamp_column="event_timestamp",
+)
+
+print(f"Retrieved {len(inference_df.columns)} features for {len(inference_df):,} customers")
+print(f"Feature retrieval timestamp: {INFERENCE_TIMESTAMP}")'''),
+
+            self.cb.section("7. Generate Predictions"),
+            self.cb.code(f'''# Prepare features for prediction
+# Remove non-feature columns
+meta_cols = ["entity_id", "event_timestamp"]
+feature_cols = [c for c in inference_df.columns if c not in meta_cols]
+
+X = inference_df[feature_cols]
+
+# Handle any missing values from feature retrieval
+missing_pct = X.isnull().sum().sum() / (len(X) * len(X.columns)) * 100
+if missing_pct > 0:
+    print(f"Warning: {{missing_pct:.2f}}% missing values in features")
+    X = X.fillna(X.median())
+
+# Generate predictions
+threshold = {threshold}
+y_prob = model.predict_proba(X)[:, 1]
+y_pred = (y_prob >= threshold).astype(int)
+
+# Add predictions to results
+results_df = inference_df[["entity_id"]].copy()
+results_df["churn_probability"] = y_prob
+results_df["churn_prediction"] = y_pred
+results_df["risk_tier"] = pd.cut(
+    y_prob,
+    bins=[0, 0.3, 0.6, 1.0],
+    labels=["Low", "Medium", "High"]
+)
+
+# Add inference metadata
+results_df["inference_timestamp"] = INFERENCE_TIMESTAMP
+results_df["model_version"] = str(model_path)
+
+print(f"Predictions generated for {{len(results_df):,}} customers")
+print(f"Threshold: {{threshold}}")'''),
+
+            self.cb.section("8. Prediction Summary Dashboard"),
+            self.cb.markdown('''This dashboard shows the batch scoring results with the **point-in-time** used for inference prominently displayed.'''),
+            self.cb.code('''# Create summary statistics
+total_customers = len(results_df)
+predicted_churners = results_df["churn_prediction"].sum()
+churn_rate = predicted_churners / total_customers * 100
+avg_probability = results_df["churn_probability"].mean()
+
+risk_distribution = results_df["risk_tier"].value_counts()
+
+print("=" * 70)
+print("BATCH INFERENCE RESULTS DASHBOARD")
+print("=" * 70)
+print()
+print(f"INFERENCE POINT-IN-TIME: {INFERENCE_TIMESTAMP.strftime('%Y-%m-%d %H:%M:%S UTC')}")
+print()
+print("SUMMARY STATISTICS:")
+print(f"  Total Customers Scored: {total_customers:,}")
+print(f"  Predicted Churners: {predicted_churners:,} ({churn_rate:.1f}%)")
+print(f"  Average Churn Probability: {avg_probability:.3f}")
+print()
+print("RISK DISTRIBUTION:")
+for tier in ["High", "Medium", "Low"]:
+    count = risk_distribution.get(tier, 0)
+    pct = count / total_customers * 100
+    print(f"  {tier}: {count:,} ({pct:.1f}%)")
+print()
+print("=" * 70)'''),
+
+            self.cb.section("9. Interactive Results Dashboard"),
+            self.cb.code('''# Create interactive dashboard with Plotly
+fig = make_subplots(
+    rows=2, cols=2,
+    subplot_titles=[
+        f"Risk Distribution (PIT: {INFERENCE_TIMESTAMP.strftime('%Y-%m-%d %H:%M')})",
+        "Churn Probability Distribution",
+        "Risk by Probability Range",
+        "Inference Metadata"
+    ],
+    specs=[
+        [{"type": "pie"}, {"type": "histogram"}],
+        [{"type": "bar"}, {"type": "table"}]
+    ]
+)
+
+# Risk tier pie chart
+colors = {"High": "#e74c3c", "Medium": "#f39c12", "Low": "#27ae60"}
+fig.add_trace(
+    go.Pie(
+        labels=risk_distribution.index.tolist(),
+        values=risk_distribution.values.tolist(),
+        marker_colors=[colors.get(tier, "#95a5a6") for tier in risk_distribution.index],
+        textinfo="label+percent",
+        hole=0.4
+    ),
+    row=1, col=1
+)
+
+# Probability histogram
+fig.add_trace(
+    go.Histogram(
+        x=results_df["churn_probability"],
+        nbinsx=50,
+        marker_color="#3498db",
+        name="Probability"
+    ),
+    row=1, col=2
+)
+
+# Risk by probability range bar chart
+prob_bins = pd.cut(results_df["churn_probability"], bins=[0, 0.2, 0.4, 0.6, 0.8, 1.0])
+prob_counts = prob_bins.value_counts().sort_index()
+fig.add_trace(
+    go.Bar(
+        x=[str(b) for b in prob_counts.index],
+        y=prob_counts.values,
+        marker_color="#9b59b6",
+        name="Count"
+    ),
+    row=2, col=1
+)
+
+# Metadata table
+metadata = [
+    ["Metric", "Value"],
+    ["Inference Point-in-Time", INFERENCE_TIMESTAMP.strftime('%Y-%m-%d %H:%M:%S')],
+    ["Total Customers", f"{total_customers:,}"],
+    ["Predicted Churners", f"{predicted_churners:,}"],
+    ["Churn Rate", f"{churn_rate:.1f}%"],
+    ["Model", type(model).__name__],
+    ["Threshold", f"{threshold}"],
+]
+fig.add_trace(
+    go.Table(
+        header=dict(values=["Metric", "Value"], fill_color="#2c3e50", font=dict(color="white")),
+        cells=dict(values=[[row[0] for row in metadata[1:]], [row[1] for row in metadata[1:]]],
+                   fill_color="#ecf0f1")
+    ),
+    row=2, col=2
+)
+
+fig.update_layout(
+    title=dict(
+        text=f"<b>Batch Inference Dashboard</b><br><sub>Point-in-Time: {INFERENCE_TIMESTAMP.strftime('%Y-%m-%d %H:%M:%S UTC')}</sub>",
+        font=dict(size=20)
+    ),
+    height=700,
+    showlegend=False,
+    template="plotly_white"
+)
+
+fig.show()'''),
+
+            self.cb.section("10. Save Predictions with Metadata"),
+            self.cb.code('''# Save predictions with full metadata
+output_dir = Path("./experiments/data/predictions")
+output_dir.mkdir(parents=True, exist_ok=True)
+
+# Create timestamped path for audit trail
+timestamp_str = INFERENCE_TIMESTAMP.strftime('%Y%m%d_%H%M%S')
+output_file = output_dir / f"batch_predictions_{timestamp_str}"
+
+# Save with all metadata via the Delta storage adapter
+storage.write(results_df, str(output_file))
+
+# Also save a "latest" version for downstream consumption
+latest_file = output_dir / "batch_predictions_latest"
+storage.write(results_df, str(latest_file))
+
+print("Predictions saved:")
+print(f"  Timestamped: {output_file}")
+print(f"  Latest: {latest_file}")
+print()
+print(f"Inference Point-in-Time: {INFERENCE_TIMESTAMP}")
+print(f"Records: {len(results_df):,}")
+
+# Save inference metadata as JSON for audit
+import json
+metadata_file = output_dir / f"inference_metadata_{timestamp_str}.json"
+metadata = {
+    "inference_timestamp": INFERENCE_TIMESTAMP.isoformat(),
+    "total_customers": int(total_customers),
+    "predicted_churners": int(predicted_churners),
+    "churn_rate_pct": float(churn_rate),
+    "avg_probability": float(avg_probability),
+    "model_path": str(model_path),
+    "threshold": threshold,
+    "risk_distribution": {str(k): int(v) for k, v in risk_distribution.items()},
+}
+with open(metadata_file, "w") as f:
+    json.dump(metadata, f, indent=2)
+print(f"  Metadata: {metadata_file}")'''),
+
+            self.cb.section("11. Summary"),
+            self.cb.code('''print("=" * 70)
+print("BATCH INFERENCE COMPLETE")
+print("=" * 70)
+print()
+print(f"Point-in-Time Used: {INFERENCE_TIMESTAMP}")
+print(f"Customers Scored: {total_customers:,}")
+print(f"High Risk: {risk_distribution.get('High', 0):,}")
+print(f"Medium Risk: {risk_distribution.get('Medium', 0):,}")
+print(f"Low Risk: {risk_distribution.get('Low', 0):,}")
+print()
+print("Next steps:")
+print("1. Review high-risk customers for intervention")
+print("2. Schedule next batch inference run")
+print("3. Monitor model performance over time")'''),
+        ]
+
+    def generate_databricks_cells(self) -> List[nbformat.NotebookNode]:
+        """Generate cells for Databricks Feature Store batch inference."""
+        catalog = self.config.feature_store.catalog
+        schema = self.config.feature_store.schema
+        model_name = self.config.mlflow.model_name
+        self.get_target_column()
+        threshold = self.config.threshold
+
+        return self.header_cells() + [
+            self.cb.section("1. Setup and Imports"),
+            self.cb.code(f'''from databricks.feature_engineering import FeatureEngineeringClient, FeatureLookup
+from pyspark.sql.functions import col, lit, current_timestamp, when, count, sum as spark_sum, mean
+from pyspark.sql.types import TimestampType
+from datetime import datetime
+import mlflow
+
+fe = FeatureEngineeringClient()
+CATALOG = "{catalog}"
+SCHEMA = "{schema}"
+FEATURE_TABLE = f"{{CATALOG}}.{{SCHEMA}}.customer_features"
+MODEL_URI = f"models:/{{CATALOG}}.{{SCHEMA}}.{model_name}@production"
+
+print(f"Feature Store: {{FEATURE_TABLE}}")
+print(f"Model: {{MODEL_URI}}")'''),
+
+            self.cb.section("2. Set Inference Point-in-Time"),
+            self.cb.markdown('''**Critical**: The inference timestamp is the point-in-time for feature retrieval.
+The Databricks Feature Store uses `timestamp_lookup_key` to ensure PIT correctness.'''),
+            self.cb.code('''# INFERENCE_TIMESTAMP: The point-in-time for this batch prediction
+# For production batch jobs, use the job execution timestamp
+INFERENCE_TIMESTAMP = datetime.now()
+
+# Alternative: Use a specific historical timestamp for backtesting
+# INFERENCE_TIMESTAMP = datetime(2024, 1, 15, 0, 0, 0)
+
+print("=" * 70)
+print(f"INFERENCE POINT-IN-TIME: {INFERENCE_TIMESTAMP}")
+print("=" * 70)
+print("Features will be retrieved as they existed at this timestamp.")'''),
+
+            self.cb.section("3. Load Customers to Score"),
+            self.cb.code(f'''# Load customers to score from the gold layer
+df_customers = spark.table("{catalog}.{schema}.gold_customers")
+
+# Select only entity IDs - features will come from the feature store
+entity_df = df_customers.select("entity_id")
+
+# Add the inference timestamp for point-in-time lookup
+entity_df = entity_df.withColumn(
+    "inference_timestamp",
+    lit(INFERENCE_TIMESTAMP).cast(TimestampType())
+)
+
+print(f"Customers to score: {{entity_df.count():,}}")
+print(f"Inference Point-in-Time: {{INFERENCE_TIMESTAMP}}")
+entity_df.show(5)'''),
+
+            self.cb.section("4. Define Feature Lookups with Point-in-Time"),
+            self.cb.markdown('''The `timestamp_lookup_key` parameter ensures that features are retrieved
+as they existed at the specified inference timestamp - no future data leakage.'''),
+            self.cb.code('''# Define feature lookups with PIT correctness
+# The timestamp_lookup_key ensures features are retrieved as of inference_timestamp
+feature_lookups = [
+    FeatureLookup(
+        table_name=FEATURE_TABLE,
+        lookup_key=["entity_id"],
+        timestamp_lookup_key="inference_timestamp",  # PIT lookup
+    )
+]
+
+print("Feature lookups configured with Point-in-Time correctness")
+print(f"  Feature Table: {FEATURE_TABLE}")
+print("  Lookup Key: entity_id")
+print("  Timestamp Key: inference_timestamp")'''),
+
+            self.cb.section("5. Score with Feature Store (PIT-Correct)"),
+            self.cb.markdown('''Use `fe.score_batch()` to automatically retrieve features with PIT correctness
+and apply the model. This ensures training-serving consistency.'''),
+            self.cb.code('''# Score using the feature store with automatic PIT feature retrieval
+# This is the recommended approach for production inference
+try:
+    # Method 1: Use fe.score_batch for automatic feature lookup
+    predictions = fe.score_batch(
+        df=entity_df,
+        model_uri=MODEL_URI,
+        result_type="double",
+    )
+    print("Scored using fe.score_batch with automatic feature lookup")
+except Exception as e:
+    print(f"fe.score_batch not available: {e}")
+    print("Falling back to manual feature retrieval...")
+
+    # Method 2: Manual feature retrieval with PIT join
+    training_set = fe.create_training_set(
+        df=entity_df,
+        feature_lookups=feature_lookups,
+        label=None,
+    )
+    inference_df = training_set.load_df()
+
+    # Load model and score
+    model = mlflow.pyfunc.load_model(MODEL_URI)
+
+    # Convert to pandas for scoring
+    pdf = inference_df.toPandas()
+    feature_cols = [c for c in pdf.columns if c not in ["entity_id", "inference_timestamp"]]
+
+    predictions_array = model.predict(pdf[feature_cols])
+    pdf["prediction"] = predictions_array
+
+    predictions = spark.createDataFrame(pdf)
+    print("Scored using manual feature retrieval with PIT join")'''),
+
+            self.cb.section("6. Apply Threshold and Risk Tiers"),
+            self.cb.code(f'''threshold = {threshold}
+
+# Add prediction columns and risk tiers
+df_scored = (predictions
+    .withColumn("churn_probability", col("prediction"))
+    .withColumn("churn_prediction", when(col("prediction") >= threshold, 1).otherwise(0))
+    .withColumn("risk_tier",
+        when(col("prediction") >= 0.6, "High")
+        .when(col("prediction") >= 0.3, "Medium")
+        .otherwise("Low")
+    )
+    .withColumn("inference_point_in_time", lit(INFERENCE_TIMESTAMP).cast(TimestampType()))
+    .withColumn("model_uri", lit(MODEL_URI))
+)
+
+print(f"Applied threshold: {{threshold}}")
+print("Added risk tiers: High (>=0.6), Medium (>=0.3), Low (<0.3)")'''),
+
+            self.cb.section("7. Batch Inference Results Dashboard"),
+            self.cb.markdown('''Display the batch scoring results with the **point-in-time** prominently shown.'''),
+            self.cb.code('''# Calculate summary statistics
+summary = df_scored.agg(
+    count("*").alias("total_customers"),
+    spark_sum("churn_prediction").alias("predicted_churners"),
+    mean("churn_probability").alias("avg_probability")
+).collect()[0]
+
+total = summary["total_customers"]
+churners = summary["predicted_churners"]
+avg_prob = summary["avg_probability"]
+
+# Risk distribution
+risk_dist = df_scored.groupBy("risk_tier").count().collect()
+risk_dict = {row["risk_tier"]: row["count"] for row in risk_dist}
+
+print("=" * 70)
+print("BATCH INFERENCE RESULTS DASHBOARD")
+print("=" * 70)
+print()
+print(f"INFERENCE POINT-IN-TIME: {INFERENCE_TIMESTAMP.strftime('%Y-%m-%d %H:%M:%S UTC')}")
+print()
+print("SUMMARY STATISTICS:")
+print(f"  Total Customers Scored: {total:,}")
+print(f"  Predicted Churners: {churners:,} ({churners/total*100:.1f}%)")
+print(f"  Average Churn Probability: {avg_prob:.3f}")
+print()
+print("RISK DISTRIBUTION:")
+print(f"  High Risk: {risk_dict.get('High', 0):,}")
+print(f"  Medium Risk: {risk_dict.get('Medium', 0):,}")
+print(f"  Low Risk: {risk_dict.get('Low', 0):,}")
+print()
+print("=" * 70)'''),
+
+            self.cb.section("8. Interactive Dashboard Display"),
+            self.cb.code('''# Display risk distribution
+print(f"\\nRisk Distribution (PIT: {INFERENCE_TIMESTAMP.strftime('%Y-%m-%d %H:%M')}):")
+display(df_scored.groupBy("risk_tier").count().orderBy("risk_tier"))
+
+# Display sample predictions with inference metadata
+print("\\nSample Predictions (showing inference_point_in_time):")
+display(
+    df_scored.select(
+        "entity_id",
+        "churn_probability",
+        "risk_tier",
+        "inference_point_in_time"
+    ).limit(10)
+)
+
+# Display probability distribution
+print("\\nProbability Distribution:")
+display(df_scored.select("churn_probability").summary())'''),
+
+            self.cb.section("9. Save Predictions with Metadata"),
+            self.cb.code(f'''# Save predictions with full audit trail
+# Include inference_point_in_time for reproducibility
+
+output_cols = [
+    "entity_id",
+    "churn_probability",
+    "churn_prediction",
+    "risk_tier",
+    "inference_point_in_time",  # Critical for audit
+    "model_uri"
+]
+
+# Save predictions to a Delta table
+df_scored.select(output_cols).write \\
+    .format("delta") \\
+    .mode("overwrite") \\
+    .option("overwriteSchema", "true") \\
+    .saveAsTable("{catalog}.{schema}.predictions")
+
+print("Predictions saved to {catalog}.{schema}.predictions")
+print(f"Inference Point-in-Time: {{INFERENCE_TIMESTAMP}}")
+print(f"Records: {{df_scored.count():,}}")'''),
+
+            self.cb.section("10. Create Predictions Audit Log"),
+            self.cb.code(f'''from pyspark.sql.functions import current_timestamp as spark_current_timestamp
+
+# Create or append to audit log
+audit_record = spark.createDataFrame([{{
+    "inference_id": f"batch_{{INFERENCE_TIMESTAMP.strftime('%Y%m%d_%H%M%S')}}",
+    "inference_timestamp": INFERENCE_TIMESTAMP,
+    "total_customers": total,
+    "predicted_churners": int(churners),
+    "avg_probability": float(avg_prob),
+    "model_uri": MODEL_URI,
+    "threshold": {threshold},
+    "created_at": datetime.now(),
+}}])
+
+# Append to audit log
+audit_record.write \\
+    .format("delta") \\
+    .mode("append") \\
+    .saveAsTable("{catalog}.{schema}.inference_audit_log")
+
+print("Audit log updated: {catalog}.{schema}.inference_audit_log")'''),
+
+            self.cb.section("11. Summary"),
+            self.cb.code('''print("=" * 70)
+print("BATCH INFERENCE COMPLETE")
+print("=" * 70)
+print()
+print(f"Point-in-Time Used: {INFERENCE_TIMESTAMP}")
+print(f"Customers Scored: {total:,}")
+print(f"High Risk: {risk_dict.get('High', 0):,}")
+print(f"Medium Risk: {risk_dict.get('Medium', 0):,}")
+print(f"Low Risk: {risk_dict.get('Low', 0):,}")
+print()
+print("The inference_point_in_time column in the predictions table")
+print("records exactly when features were retrieved, ensuring")
+print("full auditability and reproducibility.")
+print()
+print("Next steps:")
+print("1. Review high-risk customers for intervention")
+print("2. Set up scheduled inference jobs")
+print("3. Monitor prediction drift over time")'''),
+        ]
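
Note on point-in-time retrieval: the generated notebooks above delegate PIT correctness to the feature store (Feast locally, `timestamp_lookup_key` on Databricks). The underlying "as-of" join is easy to sketch in plain pandas. The following is a minimal illustration only, with hypothetical entities and a made-up `logins_30d` feature; it is not churnkit's API.

import pandas as pd

# Hypothetical feature history: one row per (entity, feature update time)
features = pd.DataFrame({
    "entity_id": [1, 1, 2],
    "event_timestamp": pd.to_datetime(["2024-01-01", "2024-01-10", "2024-01-05"]),
    "logins_30d": [3, 7, 2],
})

# Entities to score, all stamped with the batch inference timestamp
entities = pd.DataFrame({
    "entity_id": [1, 2],
    "event_timestamp": pd.to_datetime(["2024-01-08", "2024-01-08"]),
})

# As-of join: for each entity, take the most recent feature row whose
# timestamp is <= the inference timestamp (the default direction="backward").
# Both frames must be sorted by the "on" key.
pit = pd.merge_asof(
    entities.sort_values("event_timestamp"),
    features.sort_values("event_timestamp"),
    on="event_timestamp",
    by="entity_id",
)

# Entity 1 gets logins_30d=3 (the 2024-01-10 value lies after the inference
# timestamp and is excluded); entity 2 gets logins_30d=2.
print(pit)

A production feature store applies the same semantics at scale and, critically, uses the same retrieval path for training and serving, which is what the notebooks mean by training-serving consistency.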