churnkit 0.75.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +647 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +1165 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +961 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +1690 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +679 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +3305 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +1463 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +1430 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +854 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +1639 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +1890 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +1457 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +1624 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +780 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +979 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +572 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +1179 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +1418 -0
- churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +151 -0
- churnkit-0.75.0a1.dist-info/METADATA +229 -0
- churnkit-0.75.0a1.dist-info/RECORD +302 -0
- churnkit-0.75.0a1.dist-info/WHEEL +4 -0
- churnkit-0.75.0a1.dist-info/entry_points.txt +2 -0
- churnkit-0.75.0a1.dist-info/licenses/LICENSE +202 -0
- customer_retention/__init__.py +37 -0
- customer_retention/analysis/__init__.py +0 -0
- customer_retention/analysis/auto_explorer/__init__.py +62 -0
- customer_retention/analysis/auto_explorer/exploration_manager.py +470 -0
- customer_retention/analysis/auto_explorer/explorer.py +258 -0
- customer_retention/analysis/auto_explorer/findings.py +291 -0
- customer_retention/analysis/auto_explorer/layered_recommendations.py +485 -0
- customer_retention/analysis/auto_explorer/recommendation_builder.py +148 -0
- customer_retention/analysis/auto_explorer/recommendations.py +418 -0
- customer_retention/analysis/business/__init__.py +26 -0
- customer_retention/analysis/business/ab_test_designer.py +144 -0
- customer_retention/analysis/business/fairness_analyzer.py +166 -0
- customer_retention/analysis/business/intervention_matcher.py +121 -0
- customer_retention/analysis/business/report_generator.py +222 -0
- customer_retention/analysis/business/risk_profile.py +199 -0
- customer_retention/analysis/business/roi_analyzer.py +139 -0
- customer_retention/analysis/diagnostics/__init__.py +20 -0
- customer_retention/analysis/diagnostics/calibration_analyzer.py +133 -0
- customer_retention/analysis/diagnostics/cv_analyzer.py +144 -0
- customer_retention/analysis/diagnostics/error_analyzer.py +107 -0
- customer_retention/analysis/diagnostics/leakage_detector.py +394 -0
- customer_retention/analysis/diagnostics/noise_tester.py +140 -0
- customer_retention/analysis/diagnostics/overfitting_analyzer.py +190 -0
- customer_retention/analysis/diagnostics/segment_analyzer.py +122 -0
- customer_retention/analysis/discovery/__init__.py +8 -0
- customer_retention/analysis/discovery/config_generator.py +49 -0
- customer_retention/analysis/discovery/discovery_flow.py +19 -0
- customer_retention/analysis/discovery/type_inferencer.py +147 -0
- customer_retention/analysis/interpretability/__init__.py +13 -0
- customer_retention/analysis/interpretability/cohort_analyzer.py +185 -0
- customer_retention/analysis/interpretability/counterfactual.py +175 -0
- customer_retention/analysis/interpretability/individual_explainer.py +141 -0
- customer_retention/analysis/interpretability/pdp_generator.py +103 -0
- customer_retention/analysis/interpretability/shap_explainer.py +106 -0
- customer_retention/analysis/jupyter_save_hook.py +28 -0
- customer_retention/analysis/notebook_html_exporter.py +136 -0
- customer_retention/analysis/notebook_progress.py +60 -0
- customer_retention/analysis/plotly_preprocessor.py +154 -0
- customer_retention/analysis/recommendations/__init__.py +54 -0
- customer_retention/analysis/recommendations/base.py +158 -0
- customer_retention/analysis/recommendations/cleaning/__init__.py +11 -0
- customer_retention/analysis/recommendations/cleaning/consistency.py +107 -0
- customer_retention/analysis/recommendations/cleaning/deduplicate.py +94 -0
- customer_retention/analysis/recommendations/cleaning/impute.py +67 -0
- customer_retention/analysis/recommendations/cleaning/outlier.py +71 -0
- customer_retention/analysis/recommendations/datetime/__init__.py +3 -0
- customer_retention/analysis/recommendations/datetime/extract.py +149 -0
- customer_retention/analysis/recommendations/encoding/__init__.py +3 -0
- customer_retention/analysis/recommendations/encoding/categorical.py +114 -0
- customer_retention/analysis/recommendations/pipeline.py +74 -0
- customer_retention/analysis/recommendations/registry.py +76 -0
- customer_retention/analysis/recommendations/selection/__init__.py +3 -0
- customer_retention/analysis/recommendations/selection/drop_column.py +56 -0
- customer_retention/analysis/recommendations/transform/__init__.py +4 -0
- customer_retention/analysis/recommendations/transform/power.py +94 -0
- customer_retention/analysis/recommendations/transform/scale.py +112 -0
- customer_retention/analysis/visualization/__init__.py +15 -0
- customer_retention/analysis/visualization/chart_builder.py +2619 -0
- customer_retention/analysis/visualization/console.py +122 -0
- customer_retention/analysis/visualization/display.py +171 -0
- customer_retention/analysis/visualization/number_formatter.py +36 -0
- customer_retention/artifacts/__init__.py +3 -0
- customer_retention/artifacts/fit_artifact_registry.py +146 -0
- customer_retention/cli.py +93 -0
- customer_retention/core/__init__.py +0 -0
- customer_retention/core/compat/__init__.py +193 -0
- customer_retention/core/compat/detection.py +99 -0
- customer_retention/core/compat/ops.py +48 -0
- customer_retention/core/compat/pandas_backend.py +57 -0
- customer_retention/core/compat/spark_backend.py +75 -0
- customer_retention/core/components/__init__.py +11 -0
- customer_retention/core/components/base.py +79 -0
- customer_retention/core/components/components/__init__.py +13 -0
- customer_retention/core/components/components/deployer.py +26 -0
- customer_retention/core/components/components/explainer.py +26 -0
- customer_retention/core/components/components/feature_eng.py +33 -0
- customer_retention/core/components/components/ingester.py +34 -0
- customer_retention/core/components/components/profiler.py +34 -0
- customer_retention/core/components/components/trainer.py +38 -0
- customer_retention/core/components/components/transformer.py +36 -0
- customer_retention/core/components/components/validator.py +37 -0
- customer_retention/core/components/enums.py +33 -0
- customer_retention/core/components/orchestrator.py +94 -0
- customer_retention/core/components/registry.py +59 -0
- customer_retention/core/config/__init__.py +39 -0
- customer_retention/core/config/column_config.py +95 -0
- customer_retention/core/config/experiments.py +71 -0
- customer_retention/core/config/pipeline_config.py +117 -0
- customer_retention/core/config/source_config.py +83 -0
- customer_retention/core/utils/__init__.py +28 -0
- customer_retention/core/utils/leakage.py +85 -0
- customer_retention/core/utils/severity.py +53 -0
- customer_retention/core/utils/statistics.py +90 -0
- customer_retention/generators/__init__.py +0 -0
- customer_retention/generators/notebook_generator/__init__.py +167 -0
- customer_retention/generators/notebook_generator/base.py +55 -0
- customer_retention/generators/notebook_generator/cell_builder.py +49 -0
- customer_retention/generators/notebook_generator/config.py +47 -0
- customer_retention/generators/notebook_generator/databricks_generator.py +48 -0
- customer_retention/generators/notebook_generator/local_generator.py +48 -0
- customer_retention/generators/notebook_generator/project_init.py +174 -0
- customer_retention/generators/notebook_generator/runner.py +150 -0
- customer_retention/generators/notebook_generator/script_generator.py +110 -0
- customer_retention/generators/notebook_generator/stages/__init__.py +19 -0
- customer_retention/generators/notebook_generator/stages/base_stage.py +86 -0
- customer_retention/generators/notebook_generator/stages/s01_ingestion.py +100 -0
- customer_retention/generators/notebook_generator/stages/s02_profiling.py +95 -0
- customer_retention/generators/notebook_generator/stages/s03_cleaning.py +180 -0
- customer_retention/generators/notebook_generator/stages/s04_transformation.py +165 -0
- customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py +115 -0
- customer_retention/generators/notebook_generator/stages/s06_feature_selection.py +97 -0
- customer_retention/generators/notebook_generator/stages/s07_model_training.py +176 -0
- customer_retention/generators/notebook_generator/stages/s08_deployment.py +81 -0
- customer_retention/generators/notebook_generator/stages/s09_monitoring.py +112 -0
- customer_retention/generators/notebook_generator/stages/s10_batch_inference.py +642 -0
- customer_retention/generators/notebook_generator/stages/s11_feature_store.py +348 -0
- customer_retention/generators/orchestration/__init__.py +23 -0
- customer_retention/generators/orchestration/code_generator.py +196 -0
- customer_retention/generators/orchestration/context.py +147 -0
- customer_retention/generators/orchestration/data_materializer.py +188 -0
- customer_retention/generators/orchestration/databricks_exporter.py +411 -0
- customer_retention/generators/orchestration/doc_generator.py +311 -0
- customer_retention/generators/pipeline_generator/__init__.py +26 -0
- customer_retention/generators/pipeline_generator/findings_parser.py +727 -0
- customer_retention/generators/pipeline_generator/generator.py +142 -0
- customer_retention/generators/pipeline_generator/models.py +166 -0
- customer_retention/generators/pipeline_generator/renderer.py +2125 -0
- customer_retention/generators/spec_generator/__init__.py +37 -0
- customer_retention/generators/spec_generator/databricks_generator.py +433 -0
- customer_retention/generators/spec_generator/generic_generator.py +373 -0
- customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +685 -0
- customer_retention/generators/spec_generator/pipeline_spec.py +298 -0
- customer_retention/integrations/__init__.py +0 -0
- customer_retention/integrations/adapters/__init__.py +13 -0
- customer_retention/integrations/adapters/base.py +10 -0
- customer_retention/integrations/adapters/factory.py +25 -0
- customer_retention/integrations/adapters/feature_store/__init__.py +6 -0
- customer_retention/integrations/adapters/feature_store/base.py +57 -0
- customer_retention/integrations/adapters/feature_store/databricks.py +94 -0
- customer_retention/integrations/adapters/feature_store/feast_adapter.py +97 -0
- customer_retention/integrations/adapters/feature_store/local.py +75 -0
- customer_retention/integrations/adapters/mlflow/__init__.py +6 -0
- customer_retention/integrations/adapters/mlflow/base.py +32 -0
- customer_retention/integrations/adapters/mlflow/databricks.py +54 -0
- customer_retention/integrations/adapters/mlflow/experiment_tracker.py +161 -0
- customer_retention/integrations/adapters/mlflow/local.py +50 -0
- customer_retention/integrations/adapters/storage/__init__.py +5 -0
- customer_retention/integrations/adapters/storage/base.py +33 -0
- customer_retention/integrations/adapters/storage/databricks.py +76 -0
- customer_retention/integrations/adapters/storage/local.py +59 -0
- customer_retention/integrations/feature_store/__init__.py +47 -0
- customer_retention/integrations/feature_store/definitions.py +215 -0
- customer_retention/integrations/feature_store/manager.py +744 -0
- customer_retention/integrations/feature_store/registry.py +412 -0
- customer_retention/integrations/iteration/__init__.py +28 -0
- customer_retention/integrations/iteration/context.py +212 -0
- customer_retention/integrations/iteration/feedback_collector.py +184 -0
- customer_retention/integrations/iteration/orchestrator.py +168 -0
- customer_retention/integrations/iteration/recommendation_tracker.py +341 -0
- customer_retention/integrations/iteration/signals.py +212 -0
- customer_retention/integrations/llm_context/__init__.py +4 -0
- customer_retention/integrations/llm_context/context_builder.py +201 -0
- customer_retention/integrations/llm_context/prompts.py +100 -0
- customer_retention/integrations/streaming/__init__.py +103 -0
- customer_retention/integrations/streaming/batch_integration.py +149 -0
- customer_retention/integrations/streaming/early_warning_model.py +227 -0
- customer_retention/integrations/streaming/event_schema.py +214 -0
- customer_retention/integrations/streaming/online_store_writer.py +249 -0
- customer_retention/integrations/streaming/realtime_scorer.py +261 -0
- customer_retention/integrations/streaming/trigger_engine.py +293 -0
- customer_retention/integrations/streaming/window_aggregator.py +393 -0
- customer_retention/stages/__init__.py +0 -0
- customer_retention/stages/cleaning/__init__.py +9 -0
- customer_retention/stages/cleaning/base.py +28 -0
- customer_retention/stages/cleaning/missing_handler.py +160 -0
- customer_retention/stages/cleaning/outlier_handler.py +204 -0
- customer_retention/stages/deployment/__init__.py +28 -0
- customer_retention/stages/deployment/batch_scorer.py +106 -0
- customer_retention/stages/deployment/champion_challenger.py +299 -0
- customer_retention/stages/deployment/model_registry.py +182 -0
- customer_retention/stages/deployment/retraining_trigger.py +245 -0
- customer_retention/stages/features/__init__.py +73 -0
- customer_retention/stages/features/behavioral_features.py +266 -0
- customer_retention/stages/features/customer_segmentation.py +505 -0
- customer_retention/stages/features/feature_definitions.py +265 -0
- customer_retention/stages/features/feature_engineer.py +551 -0
- customer_retention/stages/features/feature_manifest.py +340 -0
- customer_retention/stages/features/feature_selector.py +239 -0
- customer_retention/stages/features/interaction_features.py +160 -0
- customer_retention/stages/features/temporal_features.py +243 -0
- customer_retention/stages/ingestion/__init__.py +9 -0
- customer_retention/stages/ingestion/load_result.py +32 -0
- customer_retention/stages/ingestion/loaders.py +195 -0
- customer_retention/stages/ingestion/source_registry.py +130 -0
- customer_retention/stages/modeling/__init__.py +31 -0
- customer_retention/stages/modeling/baseline_trainer.py +139 -0
- customer_retention/stages/modeling/cross_validator.py +125 -0
- customer_retention/stages/modeling/data_splitter.py +205 -0
- customer_retention/stages/modeling/feature_scaler.py +99 -0
- customer_retention/stages/modeling/hyperparameter_tuner.py +107 -0
- customer_retention/stages/modeling/imbalance_handler.py +282 -0
- customer_retention/stages/modeling/mlflow_logger.py +95 -0
- customer_retention/stages/modeling/model_comparator.py +149 -0
- customer_retention/stages/modeling/model_evaluator.py +138 -0
- customer_retention/stages/modeling/threshold_optimizer.py +131 -0
- customer_retention/stages/monitoring/__init__.py +37 -0
- customer_retention/stages/monitoring/alert_manager.py +328 -0
- customer_retention/stages/monitoring/drift_detector.py +201 -0
- customer_retention/stages/monitoring/performance_monitor.py +242 -0
- customer_retention/stages/preprocessing/__init__.py +5 -0
- customer_retention/stages/preprocessing/transformer_manager.py +284 -0
- customer_retention/stages/profiling/__init__.py +256 -0
- customer_retention/stages/profiling/categorical_distribution.py +269 -0
- customer_retention/stages/profiling/categorical_target_analyzer.py +274 -0
- customer_retention/stages/profiling/column_profiler.py +527 -0
- customer_retention/stages/profiling/distribution_analysis.py +483 -0
- customer_retention/stages/profiling/drift_detector.py +310 -0
- customer_retention/stages/profiling/feature_capacity.py +507 -0
- customer_retention/stages/profiling/pattern_analysis_config.py +513 -0
- customer_retention/stages/profiling/profile_result.py +212 -0
- customer_retention/stages/profiling/quality_checks.py +1632 -0
- customer_retention/stages/profiling/relationship_detector.py +256 -0
- customer_retention/stages/profiling/relationship_recommender.py +454 -0
- customer_retention/stages/profiling/report_generator.py +520 -0
- customer_retention/stages/profiling/scd_analyzer.py +151 -0
- customer_retention/stages/profiling/segment_analyzer.py +632 -0
- customer_retention/stages/profiling/segment_aware_outlier.py +265 -0
- customer_retention/stages/profiling/target_level_analyzer.py +217 -0
- customer_retention/stages/profiling/temporal_analyzer.py +388 -0
- customer_retention/stages/profiling/temporal_coverage.py +488 -0
- customer_retention/stages/profiling/temporal_feature_analyzer.py +692 -0
- customer_retention/stages/profiling/temporal_feature_engineer.py +703 -0
- customer_retention/stages/profiling/temporal_pattern_analyzer.py +636 -0
- customer_retention/stages/profiling/temporal_quality_checks.py +278 -0
- customer_retention/stages/profiling/temporal_target_analyzer.py +241 -0
- customer_retention/stages/profiling/text_embedder.py +87 -0
- customer_retention/stages/profiling/text_processor.py +115 -0
- customer_retention/stages/profiling/text_reducer.py +60 -0
- customer_retention/stages/profiling/time_series_profiler.py +303 -0
- customer_retention/stages/profiling/time_window_aggregator.py +376 -0
- customer_retention/stages/profiling/type_detector.py +382 -0
- customer_retention/stages/profiling/window_recommendation.py +288 -0
- customer_retention/stages/temporal/__init__.py +166 -0
- customer_retention/stages/temporal/access_guard.py +180 -0
- customer_retention/stages/temporal/cutoff_analyzer.py +235 -0
- customer_retention/stages/temporal/data_preparer.py +178 -0
- customer_retention/stages/temporal/point_in_time_join.py +134 -0
- customer_retention/stages/temporal/point_in_time_registry.py +148 -0
- customer_retention/stages/temporal/scenario_detector.py +163 -0
- customer_retention/stages/temporal/snapshot_manager.py +259 -0
- customer_retention/stages/temporal/synthetic_coordinator.py +66 -0
- customer_retention/stages/temporal/timestamp_discovery.py +531 -0
- customer_retention/stages/temporal/timestamp_manager.py +255 -0
- customer_retention/stages/transformation/__init__.py +13 -0
- customer_retention/stages/transformation/binary_handler.py +85 -0
- customer_retention/stages/transformation/categorical_encoder.py +245 -0
- customer_retention/stages/transformation/datetime_transformer.py +97 -0
- customer_retention/stages/transformation/numeric_transformer.py +181 -0
- customer_retention/stages/transformation/pipeline.py +257 -0
- customer_retention/stages/validation/__init__.py +60 -0
- customer_retention/stages/validation/adversarial_scoring_validator.py +205 -0
- customer_retention/stages/validation/business_sense_gate.py +173 -0
- customer_retention/stages/validation/data_quality_gate.py +235 -0
- customer_retention/stages/validation/data_validators.py +511 -0
- customer_retention/stages/validation/feature_quality_gate.py +183 -0
- customer_retention/stages/validation/gates.py +117 -0
- customer_retention/stages/validation/leakage_gate.py +352 -0
- customer_retention/stages/validation/model_validity_gate.py +213 -0
- customer_retention/stages/validation/pipeline_validation_runner.py +264 -0
- customer_retention/stages/validation/quality_scorer.py +544 -0
- customer_retention/stages/validation/rule_generator.py +57 -0
- customer_retention/stages/validation/scoring_pipeline_validator.py +446 -0
- customer_retention/stages/validation/timeseries_detector.py +769 -0
- customer_retention/transforms/__init__.py +47 -0
- customer_retention/transforms/artifact_store.py +50 -0
- customer_retention/transforms/executor.py +157 -0
- customer_retention/transforms/fitted.py +92 -0
- customer_retention/transforms/ops.py +148 -0
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
"""Timestamp management for leakage-safe ML pipelines.
|
|
2
|
+
|
|
3
|
+
This module provides the core timestamp handling infrastructure for ensuring
|
|
4
|
+
point-in-time (PIT) correctness in ML training pipelines. It supports multiple
|
|
5
|
+
strategies for managing timestamps depending on data availability.
|
|
6
|
+
|
|
7
|
+
Key concepts:
|
|
8
|
+
- feature_timestamp: When features were observed
|
|
9
|
+
- label_timestamp: When the label became known
|
|
10
|
+
- label_available_flag: Whether the label can be used for training
|
|
11
|
+
|
|
12
|
+
Example:
|
|
13
|
+
>>> from customer_retention.stages.temporal import TimestampManager, TimestampConfig, TimestampStrategy
|
|
14
|
+
>>> config = TimestampConfig(
|
|
15
|
+
... strategy=TimestampStrategy.PRODUCTION,
|
|
16
|
+
... feature_timestamp_column="last_activity_date",
|
|
17
|
+
... label_timestamp_column="churn_date"
|
|
18
|
+
... )
|
|
19
|
+
>>> manager = TimestampManager(config)
|
|
20
|
+
>>> df_with_timestamps = manager.ensure_timestamps(df)
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from dataclasses import dataclass
|
|
24
|
+
from datetime import datetime, timedelta
|
|
25
|
+
from enum import Enum
|
|
26
|
+
from typing import Any, Optional
|
|
27
|
+
|
|
28
|
+
import numpy as np
|
|
29
|
+
import pandas as pd
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TimestampStrategy(Enum):
    """Available approaches for supplying timestamp columns.

    Members:
        PRODUCTION: Read timestamps from explicit columns already present
            in the data.
        SYNTHETIC_RANDOM: Assign random timestamps drawn from a date range.
        SYNTHETIC_INDEX: Assign timestamps that advance one day per row index.
        SYNTHETIC_FIXED: Assign a single fixed timestamp to every row.
        DERIVED: Compute timestamps from other columns (e.g. a tenure column).
    """

    PRODUCTION = "production"
    SYNTHETIC_RANDOM = "synthetic_random"
    SYNTHETIC_INDEX = "synthetic_index"
    SYNTHETIC_FIXED = "synthetic_fixed"
    DERIVED = "derived"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class TimestampConfig:
    """Settings that control how timestamp columns are produced.

    Attributes:
        strategy: Which TimestampStrategy to apply.
        feature_timestamp_column: Source column holding feature timestamps
            (used by the PRODUCTION strategy).
        label_timestamp_column: Source column holding label timestamps
            (used by the PRODUCTION strategy).
        observation_window_days: Gap in days between when features are
            observed and when the label becomes available.
        synthetic_base_date: Anchor date for synthetic timestamp generation.
        synthetic_range_days: Span of days sampled by SYNTHETIC_RANDOM.
        derive_label_from_feature: When True, compute label_timestamp by
            offsetting feature_timestamp by the observation window.
        derivation_config: Formula/source-column settings for the DERIVED
            strategy (keys: "feature_derivation", "label_derivation").
    """

    strategy: TimestampStrategy
    feature_timestamp_column: Optional[str] = None
    label_timestamp_column: Optional[str] = None
    observation_window_days: int = 90
    synthetic_base_date: str = "2024-01-01"
    synthetic_range_days: int = 365
    derive_label_from_feature: bool = False
    derivation_config: Optional[dict[str, Any]] = None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class TimestampManager:
|
|
75
|
+
"""Manages timestamp columns to ensure point-in-time correctness.
|
|
76
|
+
|
|
77
|
+
The TimestampManager ensures that all data has proper feature_timestamp,
|
|
78
|
+
label_timestamp, and label_available_flag columns, regardless of whether
|
|
79
|
+
the source data has explicit timestamps or needs synthetic ones.
|
|
80
|
+
|
|
81
|
+
Example:
|
|
82
|
+
>>> config = TimestampConfig(strategy=TimestampStrategy.SYNTHETIC_FIXED)
|
|
83
|
+
>>> manager = TimestampManager(config)
|
|
84
|
+
>>> df = manager.ensure_timestamps(df)
|
|
85
|
+
>>> assert "feature_timestamp" in df.columns
|
|
86
|
+
>>> assert "label_timestamp" in df.columns
|
|
87
|
+
>>> assert "label_available_flag" in df.columns
|
|
88
|
+
"""
|
|
89
|
+
|
|
90
|
+
def __init__(self, config: TimestampConfig):
|
|
91
|
+
"""Initialize the TimestampManager.
|
|
92
|
+
|
|
93
|
+
Args:
|
|
94
|
+
config: Configuration specifying the timestamp strategy
|
|
95
|
+
"""
|
|
96
|
+
self.config = config
|
|
97
|
+
|
|
98
|
+
def ensure_timestamps(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
99
|
+
"""Add or validate timestamp columns based on the configured strategy.
|
|
100
|
+
|
|
101
|
+
This is the main entry point for timestamp handling. It adds feature_timestamp,
|
|
102
|
+
label_timestamp, and label_available_flag columns to the DataFrame.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
df: Input DataFrame
|
|
106
|
+
|
|
107
|
+
Returns:
|
|
108
|
+
DataFrame with timestamp columns added
|
|
109
|
+
|
|
110
|
+
Raises:
|
|
111
|
+
ValueError: If production strategy is used but required columns are missing
|
|
112
|
+
"""
|
|
113
|
+
if self.config.strategy == TimestampStrategy.PRODUCTION:
|
|
114
|
+
return self._validate_production_timestamps(df)
|
|
115
|
+
elif self.config.strategy == TimestampStrategy.DERIVED:
|
|
116
|
+
return self._derive_timestamps(df)
|
|
117
|
+
return self._add_synthetic_timestamps(df)
|
|
118
|
+
|
|
119
|
+
def _validate_production_timestamps(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
120
|
+
required = [self.config.feature_timestamp_column, self.config.label_timestamp_column]
|
|
121
|
+
missing = [col for col in required if col and col not in df.columns]
|
|
122
|
+
if missing:
|
|
123
|
+
raise ValueError(f"Missing required timestamp columns: {missing}")
|
|
124
|
+
|
|
125
|
+
df = df.copy()
|
|
126
|
+
if self.config.feature_timestamp_column:
|
|
127
|
+
df["feature_timestamp"] = self._parse_datetime_column(
|
|
128
|
+
df[self.config.feature_timestamp_column], self.config.feature_timestamp_column
|
|
129
|
+
)
|
|
130
|
+
if self.config.label_timestamp_column:
|
|
131
|
+
df["label_timestamp"] = self._parse_datetime_column(
|
|
132
|
+
df[self.config.label_timestamp_column], self.config.label_timestamp_column
|
|
133
|
+
)
|
|
134
|
+
elif self.config.derive_label_from_feature:
|
|
135
|
+
window = timedelta(days=self.config.observation_window_days)
|
|
136
|
+
df["label_timestamp"] = df["feature_timestamp"] + window
|
|
137
|
+
now = datetime.now()
|
|
138
|
+
has_event = df["label_timestamp"].notna() & (df["label_timestamp"] <= now)
|
|
139
|
+
observation_complete = (
|
|
140
|
+
df["feature_timestamp"].notna()
|
|
141
|
+
& (df["feature_timestamp"] + pd.Timedelta(days=self.config.observation_window_days) <= now)
|
|
142
|
+
)
|
|
143
|
+
df["label_available_flag"] = has_event | observation_complete
|
|
144
|
+
return df
|
|
145
|
+
|
|
146
|
+
def _parse_datetime_column(self, series: pd.Series, col_name: str) -> pd.Series:
|
|
147
|
+
if pd.api.types.is_datetime64_any_dtype(series):
|
|
148
|
+
return series
|
|
149
|
+
parsed = pd.to_datetime(series, format="mixed", errors="coerce")
|
|
150
|
+
invalid_count = parsed.isna().sum() - series.isna().sum()
|
|
151
|
+
if invalid_count > 0:
|
|
152
|
+
import warnings
|
|
153
|
+
warnings.warn(f"Column '{col_name}': {invalid_count} invalid dates coerced to NaT")
|
|
154
|
+
return parsed
|
|
155
|
+
|
|
156
|
+
def _derive_timestamps(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
157
|
+
if not self.config.derivation_config:
|
|
158
|
+
raise ValueError("derivation_config required for DERIVED strategy")
|
|
159
|
+
|
|
160
|
+
df = df.copy()
|
|
161
|
+
config = self.config.derivation_config
|
|
162
|
+
|
|
163
|
+
if "feature_derivation" in config:
|
|
164
|
+
df = self._apply_derivation(df, config["feature_derivation"], "feature_timestamp")
|
|
165
|
+
if "label_derivation" in config:
|
|
166
|
+
df = self._apply_derivation(df, config["label_derivation"], "label_timestamp")
|
|
167
|
+
elif "feature_timestamp" in df.columns:
|
|
168
|
+
window = timedelta(days=self.config.observation_window_days)
|
|
169
|
+
df["label_timestamp"] = df["feature_timestamp"] + window
|
|
170
|
+
|
|
171
|
+
df["label_available_flag"] = True
|
|
172
|
+
return df
|
|
173
|
+
|
|
174
|
+
def _apply_derivation(self, df: pd.DataFrame, derivation: dict, target_col: str) -> pd.DataFrame:
|
|
175
|
+
sources = derivation.get("sources", [])
|
|
176
|
+
formula = derivation.get("formula", "")
|
|
177
|
+
|
|
178
|
+
if not sources or not formula:
|
|
179
|
+
return df
|
|
180
|
+
|
|
181
|
+
if "tenure" in formula.lower() and len(sources) >= 1:
|
|
182
|
+
tenure_col = sources[0]
|
|
183
|
+
if tenure_col in df.columns:
|
|
184
|
+
reference_date = datetime.now()
|
|
185
|
+
df[target_col] = reference_date - pd.to_timedelta(df[tenure_col] * 30, unit="D")
|
|
186
|
+
return df
|
|
187
|
+
|
|
188
|
+
def _add_synthetic_timestamps(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
189
|
+
df = df.copy()
|
|
190
|
+
base = pd.to_datetime(self.config.synthetic_base_date)
|
|
191
|
+
window = timedelta(days=self.config.observation_window_days)
|
|
192
|
+
|
|
193
|
+
if self.config.strategy == TimestampStrategy.SYNTHETIC_FIXED:
|
|
194
|
+
df["feature_timestamp"] = base
|
|
195
|
+
df["label_timestamp"] = base + window
|
|
196
|
+
elif self.config.strategy == TimestampStrategy.SYNTHETIC_INDEX:
|
|
197
|
+
df["feature_timestamp"] = base + pd.to_timedelta(range(len(df)), unit="D")
|
|
198
|
+
df["label_timestamp"] = df["feature_timestamp"] + window
|
|
199
|
+
elif self.config.strategy == TimestampStrategy.SYNTHETIC_RANDOM:
|
|
200
|
+
np.random.seed(42)
|
|
201
|
+
days = np.random.randint(0, self.config.synthetic_range_days, len(df))
|
|
202
|
+
df["feature_timestamp"] = base + pd.to_timedelta(days, unit="D")
|
|
203
|
+
df["label_timestamp"] = df["feature_timestamp"] + window
|
|
204
|
+
|
|
205
|
+
df["label_available_flag"] = True
|
|
206
|
+
return df
|
|
207
|
+
|
|
208
|
+
def validate_point_in_time(self, df: pd.DataFrame) -> bool:
|
|
209
|
+
"""Validate that timestamps maintain point-in-time correctness.
|
|
210
|
+
|
|
211
|
+
Ensures that feature_timestamp is always <= label_timestamp for all rows,
|
|
212
|
+
which is required to prevent data leakage during training.
|
|
213
|
+
|
|
214
|
+
Args:
|
|
215
|
+
df: DataFrame with timestamp columns
|
|
216
|
+
|
|
217
|
+
Returns:
|
|
218
|
+
True if validation passes
|
|
219
|
+
|
|
220
|
+
Raises:
|
|
221
|
+
ValueError: If timestamp columns are missing or violations are found
|
|
222
|
+
"""
|
|
223
|
+
if "feature_timestamp" not in df.columns or "label_timestamp" not in df.columns:
|
|
224
|
+
raise ValueError("Missing timestamp columns for point-in-time validation")
|
|
225
|
+
|
|
226
|
+
violations = df[df["feature_timestamp"] > df["label_timestamp"]]
|
|
227
|
+
if len(violations) > 0:
|
|
228
|
+
raise ValueError(
|
|
229
|
+
f"Point-in-time violation: {len(violations)} rows have "
|
|
230
|
+
f"feature_timestamp > label_timestamp"
|
|
231
|
+
)
|
|
232
|
+
return True
|
|
233
|
+
|
|
234
|
+
def get_timestamp_summary(self, df: pd.DataFrame) -> dict[str, Any]:
    """Summarize the timestamp columns of *df*.

    Args:
        df: DataFrame that may contain feature/label timestamp columns and
            a label_available_flag column.

    Returns:
        Dict with the configured strategy name plus, for each timestamp
        column present, its min, max, and null fraction; and the mean of
        label_available_flag when that column exists.
    """
    stats: dict[str, Any] = {"strategy": self.config.strategy.value}

    for name in ("feature_timestamp", "label_timestamp"):
        if name not in df.columns:
            continue
        column = df[name]
        stats[f"{name}_min"] = column.min()
        stats[f"{name}_max"] = column.max()
        stats[f"{name}_null_pct"] = column.isna().mean()

    if "label_available_flag" in df.columns:
        stats["label_available_pct"] = df["label_available_flag"].mean()

    return stats
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Public API of the transformation subpackage.

Re-exports the per-dtype transformers (binary, categorical, datetime,
numeric) and the pipeline orchestration classes so callers can import
everything from this package root.
"""
from .binary_handler import BinaryHandler, BinaryTransformResult
from .categorical_encoder import CategoricalEncoder, CategoricalEncodeResult, EncodingStrategy
from .datetime_transformer import DatetimeTransformer, DatetimeTransformResult
from .numeric_transformer import NumericTransformer, NumericTransformResult, PowerTransform, ScalingStrategy
from .pipeline import PipelineResult, TransformationManifest, TransformationPipeline

# Explicit public surface; one line per source module, mirroring the imports above.
__all__ = [
    "NumericTransformer", "ScalingStrategy", "PowerTransform", "NumericTransformResult",
    "CategoricalEncoder", "EncodingStrategy", "CategoricalEncodeResult",
    "DatetimeTransformer", "DatetimeTransformResult",
    "BinaryHandler", "BinaryTransformResult",
    "TransformationPipeline", "TransformationManifest", "PipelineResult"
]
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Any, Optional
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from customer_retention.core.compat import Series
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class BinaryTransformResult:
    """Outcome of converting a two-valued column into 0/1 integer codes."""
    series: Series  # encoded values: 0/1 with NaN preserved where input was null
    mapping: dict = field(default_factory=dict)  # original value -> 0/1 code
    original_values: list = field(default_factory=list)  # distinct non-null inputs seen at fit time
    positive_class: Any = None  # the value mapped to 1, or None when undetermined
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BinaryHandler:
    """Fit/transform converter that maps a two-valued column onto 0/1 codes.

    The positive class can be pinned via the constructor; otherwise it is
    inferred from common truthy/falsy spellings, with heuristic fallbacks
    for unrecognized value pairs.
    """

    TRUE_VALUES = {1, 1.0, True, "1", "yes", "Yes", "YES", "true", "True", "TRUE", "y", "Y"}
    FALSE_VALUES = {0, 0.0, False, "0", "no", "No", "NO", "false", "False", "FALSE", "n", "N"}

    def __init__(self, positive_class: Optional[Any] = None):
        # Caller may pin which raw value becomes 1; otherwise fit() infers it.
        self.positive_class = positive_class
        self._mapping: Optional[dict] = None
        self._original_values: Optional[list] = None
        self._positive: Any = None
        self._is_fitted = False

    def fit(self, series: Series) -> "BinaryHandler":
        """Learn the value -> 0/1 mapping from the non-null entries of *series*."""
        observed = series.dropna().unique().tolist()
        self._original_values = observed

        if self.positive_class is None:
            self._mapping, self._positive = self._infer_mapping(observed)
        else:
            self._positive = self.positive_class
            self._mapping = {
                value: (1 if value == self.positive_class else 0) for value in observed
            }

        self._is_fitted = True
        return self

    def transform(self, series: Series) -> BinaryTransformResult:
        """Apply the fitted mapping; fit() must have been called first."""
        if not self._is_fitted:
            raise ValueError("Handler not fitted. Call fit() or fit_transform() first.")
        return self._apply_transform(series)

    def fit_transform(self, series: Series) -> BinaryTransformResult:
        """Fit on *series* and transform it in one call."""
        return self.fit(series)._apply_transform(series)

    def _infer_mapping(self, unique_vals: list) -> tuple[dict, Any]:
        """Best-effort derivation of the value -> 0/1 mapping and positive class."""
        if len(unique_vals) == 1:
            only = unique_vals[0]
            looks_true = only in self.TRUE_VALUES or str(only).lower() in {"yes", "y", "true", "1"}
            # A lone truthy spelling maps to 1; anything else maps to 0 with
            # no positive class.
            return ({only: 1}, only) if looks_true else ({only: 0}, None)

        mapping: dict = {}
        positive = None

        for value in unique_vals:
            comparable = str(value).lower() if isinstance(value, str) else value
            if value in self.TRUE_VALUES or comparable in {"yes", "y", "true", "1", "active"}:
                mapping[value] = 1
                positive = value  # last truthy match wins, as before
            elif value in self.FALSE_VALUES or comparable in {"no", "n", "false", "0", "inactive"}:
                mapping[value] = 0

        # Accept the spelled-out mapping only when it covered every value and
        # produced a positive class.
        if positive is not None and len(mapping) == len(unique_vals):
            return mapping, positive

        if len(unique_vals) == 2:
            # Unrecognized pair: case-insensitive sort decides which side is 1.
            low, high = sorted(unique_vals, key=lambda v: (str(v).lower(), v))
            return {low: 0, high: 1}, high

        # More than two (or zero) values: positional codes as a last resort.
        fallback = {value: idx for idx, value in enumerate(unique_vals)}
        return fallback, unique_vals[-1] if unique_vals else None

    def _apply_transform(self, series: Series) -> BinaryTransformResult:
        encoded = series.map(self._mapping)
        # Re-mask positions that were null in the input; map() already leaves
        # unmapped values as NaN.
        encoded = encoded.where(series.notna(), np.nan)

        return BinaryTransformResult(
            series=encoded,
            mapping=self._mapping or {},
            original_values=self._original_values or [],
            positive_class=self._positive,
        )
|
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from customer_retention.core.compat import DataFrame, Series, pd
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class EncodingStrategy(str, Enum):
    """How a categorical column is converted into model-ready numbers."""
    ONE_HOT = "one_hot"      # one indicator column per category
    LABEL = "label"          # integer codes in sorted category order
    ORDINAL = "ordinal"      # integer codes in caller-supplied order
    CYCLICAL = "cyclical"    # sin/cos pair for periodic categories
    TARGET = "target"        # smoothed per-category target mean
    FREQUENCY = "frequency"  # relative frequency of each category
    BINARY = "binary"        # declared but not handled by CategoricalEncoder (passes through)
    HASH = "hash"            # declared but not handled by CategoricalEncoder (passes through)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class CategoricalEncodeResult:
    """Output container for a categorical encoding operation."""
    series: Optional[Series] = None  # set by single-column strategies (label, ordinal, frequency, target)
    df: Optional[DataFrame] = None  # set by multi-column strategies (one-hot, cyclical)
    strategy: EncodingStrategy = EncodingStrategy.LABEL  # strategy that produced this result
    columns_created: list = field(default_factory=list)  # names of new columns in df, if any
    mapping: dict = field(default_factory=dict)  # category -> encoded value (or frequencies/means)
    dropped_categories: list = field(default_factory=list)  # categories excluded from the output
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CategoricalEncoder:
    """Encode a categorical Series using a configurable strategy.

    Follows the sklearn-style contract: ``fit`` learns category statistics,
    ``transform`` applies them, ``fit_transform`` does both.

    Args:
        strategy: Which encoding to apply (see EncodingStrategy).
        drop_first: One-hot only — drop the first (sorted) category to avoid
            the dummy-variable trap.
        handle_unknown: "error" raises on categories unseen at fit time
            (ordinal); "ignore" zeroes them out (one-hot).
        categories: Explicit category order; required for ORDINAL.
        period: Cycle length; required for CYCLICAL (e.g. 7 for weekdays).
        smoothing: Pseudo-count shrinking TARGET means toward the global mean.
        min_frequency: One-hot only — drop categories rarer than this count.

    Note:
        BINARY and HASH are accepted but not implemented; for those,
        transform returns the input series unchanged.
    """

    def __init__(
        self,
        strategy: EncodingStrategy = EncodingStrategy.LABEL,
        drop_first: bool = True,
        handle_unknown: str = "error",
        categories: Optional[list] = None,
        period: Optional[int] = None,
        smoothing: float = 1.0,
        min_frequency: Optional[int] = None
    ):
        self.strategy = strategy
        self.drop_first = drop_first
        self.handle_unknown = handle_unknown
        self.categories = categories
        self.period = period
        self.smoothing = smoothing
        self.min_frequency = min_frequency
        # Learned state; which fields are populated depends on the strategy.
        self._mapping: Optional[dict] = None
        self._categories: Optional[list] = None
        self._target_means: Optional[dict] = None
        self._global_mean: Optional[float] = None
        self._frequencies: Optional[dict] = None
        self._cyclical_mapping: Optional[dict] = None
        self._is_fitted = False

    def fit(self, series: Series, target: Optional[Series] = None) -> "CategoricalEncoder":
        """Learn encoding statistics from the non-null values of *series*.

        Args:
            series: Column to learn categories from.
            target: Supervised signal; required only for TARGET encoding.

        Returns:
            self, to allow fluent chaining.
        """
        clean = series.dropna()

        if self.strategy == EncodingStrategy.ONE_HOT:
            self._fit_one_hot(clean)
        elif self.strategy == EncodingStrategy.LABEL:
            self._fit_label(clean)
        elif self.strategy == EncodingStrategy.ORDINAL:
            self._fit_ordinal(clean)
        elif self.strategy == EncodingStrategy.CYCLICAL:
            self._fit_cyclical(clean)
        elif self.strategy == EncodingStrategy.FREQUENCY:
            self._fit_frequency(clean)
        elif self.strategy == EncodingStrategy.TARGET:
            self._fit_target(clean, target)

        self._is_fitted = True
        return self

    def transform(self, series: Series, target: Optional[Series] = None) -> CategoricalEncodeResult:
        """Apply the fitted encoding to *series*.

        Raises:
            ValueError: If fit() has not been called.
        """
        if not self._is_fitted:
            raise ValueError("Encoder not fitted. Call fit() or fit_transform() first.")
        return self._apply_encoding(series, target)

    def fit_transform(self, series: Series, target: Optional[Series] = None) -> CategoricalEncodeResult:
        """Fit on *series* then transform it in one call."""
        self.fit(series, target)
        return self._apply_encoding(series, target)

    def _fit_one_hot(self, clean: Series):
        """Record the (optionally frequency-filtered) sorted category set."""
        categories = clean.unique().tolist()
        if self.min_frequency is not None:
            value_counts = clean.value_counts()
            categories = [c for c in categories if value_counts.get(c, 0) >= self.min_frequency]
        self._categories = sorted(categories)
        self._mapping = {cat: i for i, cat in enumerate(self._categories)}

    def _fit_label(self, clean: Series):
        """Assign integer codes in sorted category order."""
        categories = sorted(clean.unique().tolist())
        self._mapping = {cat: i for i, cat in enumerate(categories)}

    def _fit_ordinal(self, clean: Series):
        """Assign integer codes following the caller-supplied category order."""
        if self.categories is None:
            raise ValueError("Ordinal encoding requires categories parameter")
        self._mapping = {cat: i for i, cat in enumerate(self.categories)}

    def _fit_cyclical(self, clean: Series):
        """Prepare a string -> cycle-position mapping for periodic categories."""
        if self.period is None:
            raise ValueError("Cyclical encoding requires period parameter")
        # Only object (string) columns need a mapping; numeric columns are
        # coerced directly at encode time.
        if clean.dtype == object:
            unique_values = sorted(clean.unique().tolist())
            # Auto-detect day of week names
            day_names = {
                'monday': 0, 'tuesday': 1, 'wednesday': 2, 'thursday': 3,
                'friday': 4, 'saturday': 5, 'sunday': 6,
                'mon': 0, 'tue': 1, 'wed': 2, 'thu': 3, 'fri': 4, 'sat': 5, 'sun': 6
            }
            month_names = {
                'january': 0, 'february': 1, 'march': 2, 'april': 3, 'may': 4, 'june': 5,
                'july': 6, 'august': 7, 'september': 8, 'october': 9, 'november': 10, 'december': 11,
                'jan': 0, 'feb': 1, 'mar': 2, 'apr': 3, 'jun': 5, 'jul': 6, 'aug': 7,
                'sep': 8, 'oct': 9, 'nov': 10, 'dec': 11
            }
            # Try to auto-detect mapping from common patterns; otherwise fall
            # back to sorted-order indices.
            sample_lower = [str(v).lower() for v in unique_values]
            if all(s in day_names for s in sample_lower):
                self._cyclical_mapping = {v: day_names[str(v).lower()] for v in unique_values}
            elif all(s in month_names for s in sample_lower):
                self._cyclical_mapping = {v: month_names[str(v).lower()] for v in unique_values}
            else:
                self._cyclical_mapping = {v: i for i, v in enumerate(unique_values)}
        else:
            self._cyclical_mapping = None

    def _fit_frequency(self, clean: Series):
        """Store each category's relative frequency among non-null values."""
        total = len(clean)
        value_counts = clean.value_counts()
        self._frequencies = {cat: count / total for cat, count in value_counts.items()}

    def _fit_target(self, clean: Series, target: Optional[Series]):
        """Store smoothed per-category target means.

        Uses additive smoothing: categories with few rows shrink toward the
        global mean, with ``smoothing`` acting as a pseudo-count.
        """
        if target is None:
            raise ValueError("Target encoding requires target parameter")

        self._global_mean = target.mean()
        self._target_means = {}

        for cat in clean.unique():
            mask = clean == cat
            cat_target = target[mask]
            n = len(cat_target)
            # (Removed a dead `cat_target.mean()` statement whose result was
            # discarded.)
            smoothed = (cat_target.sum() + self.smoothing * self._global_mean) / (n + self.smoothing)
            self._target_means[cat] = smoothed

    def _apply_encoding(self, series: Series, target: Optional[Series] = None) -> CategoricalEncodeResult:
        """Dispatch to the strategy-specific encoder."""
        if self.strategy == EncodingStrategy.ONE_HOT:
            return self._encode_one_hot(series)
        elif self.strategy == EncodingStrategy.LABEL:
            return self._encode_label(series)
        elif self.strategy == EncodingStrategy.ORDINAL:
            return self._encode_ordinal(series)
        elif self.strategy == EncodingStrategy.CYCLICAL:
            return self._encode_cyclical(series)
        elif self.strategy == EncodingStrategy.FREQUENCY:
            return self._encode_frequency(series)
        elif self.strategy == EncodingStrategy.TARGET:
            return self._encode_target(series)

        # BINARY / HASH: not implemented — pass the input through unchanged.
        return CategoricalEncodeResult(series=series, strategy=self.strategy)

    def _encode_one_hot(self, series: Series) -> CategoricalEncodeResult:
        """Emit one 0/1 indicator column per (kept) category."""
        categories = self._categories if self._categories else sorted(series.dropna().unique().tolist())
        if self.drop_first and len(categories) > 0:
            # Drop the first sorted category to avoid perfect collinearity.
            categories = categories[1:]

        cols = {}
        col_names = []
        for cat in categories:
            col_name = f"{series.name or 'col'}_{cat}"
            cols[col_name] = (series == cat).astype(int)
            col_names.append(col_name)

        if self.handle_unknown == "ignore":
            # Zero out rows whose value was never seen at fit time. The known
            # set and mask are loop-invariant, so compute them once (was
            # rebuilt per column).
            known_cats = set(self._categories) if self._categories else set()
            unknown_mask = ~series.isin(known_cats) & series.notna()
            for col in cols:
                cols[col] = cols[col].where(~unknown_mask, 0)

        df = DataFrame(cols)

        return CategoricalEncodeResult(
            df=df, strategy=self.strategy,
            columns_created=col_names, mapping=self._mapping or {}
        )

    def _encode_label(self, series: Series) -> CategoricalEncodeResult:
        """Map categories to their integer codes.

        NOTE(review): unlike ordinal, unknown categories are not checked here
        and silently become NaN via map() — confirm that is intended.
        """
        result = series.map(self._mapping)
        return CategoricalEncodeResult(
            series=result, strategy=self.strategy, mapping=self._mapping or {}
        )

    def _encode_ordinal(self, series: Series) -> CategoricalEncodeResult:
        """Map categories to caller-ordered codes, policing unknown values."""
        unknown = series[series.notna() & ~series.isin(self._mapping.keys())]
        if len(unknown) > 0 and self.handle_unknown == "error":
            raise ValueError(f"Found unknown categories: {unknown.unique().tolist()}")

        result = series.map(self._mapping)
        return CategoricalEncodeResult(
            series=result, strategy=self.strategy, mapping=self._mapping or {}
        )

    def _encode_cyclical(self, series: Series) -> CategoricalEncodeResult:
        """Project values onto the unit circle as a sin/cos feature pair."""
        # _cyclical_mapping is always initialized in __init__, so the previous
        # hasattr() guards were redundant.
        if self._cyclical_mapping is not None:
            numeric = series.map(self._cyclical_mapping)
        else:
            numeric = pd.to_numeric(series, errors='coerce')

        sin_vals = np.sin(2 * np.pi * numeric / self.period)
        cos_vals = np.cos(2 * np.pi * numeric / self.period)

        col_name = series.name or "col"
        sin_col = f"{col_name}_sin"
        cos_col = f"{col_name}_cos"

        df = DataFrame({sin_col: sin_vals, cos_col: cos_vals})

        return CategoricalEncodeResult(
            df=df, strategy=self.strategy, columns_created=[sin_col, cos_col],
            # Normalize None to {} for consistency with the other encoders.
            mapping=self._cyclical_mapping or {}
        )

    def _encode_frequency(self, series: Series) -> CategoricalEncodeResult:
        """Replace categories with their fit-time relative frequencies."""
        result = series.map(self._frequencies)
        return CategoricalEncodeResult(
            series=result, strategy=self.strategy, mapping=self._frequencies or {}
        )

    def _encode_target(self, series: Series) -> CategoricalEncodeResult:
        """Replace categories with their smoothed target means."""
        result = series.map(self._target_means)
        # Unseen categories fall back to the global mean; genuine nulls in the
        # input stay NaN. (Removed a dead `result.isna() & series.notna()`
        # statement whose result was discarded.)
        result = result.fillna(self._global_mean)
        result = result.where(series.notna(), np.nan)

        return CategoricalEncodeResult(
            series=result, strategy=self.strategy, mapping=self._target_means or {}
        )
|