churnkit-0.75.0a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +647 -0
  2. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +1165 -0
  3. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +961 -0
  4. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +1690 -0
  5. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +679 -0
  6. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +3305 -0
  7. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +1463 -0
  8. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +1430 -0
  9. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +854 -0
  10. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +1639 -0
  11. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +1890 -0
  12. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +1457 -0
  13. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +1624 -0
  14. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +780 -0
  15. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +979 -0
  16. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +572 -0
  17. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +1179 -0
  18. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +1418 -0
  19. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +151 -0
  20. churnkit-0.75.0a1.dist-info/METADATA +229 -0
  21. churnkit-0.75.0a1.dist-info/RECORD +302 -0
  22. churnkit-0.75.0a1.dist-info/WHEEL +4 -0
  23. churnkit-0.75.0a1.dist-info/entry_points.txt +2 -0
  24. churnkit-0.75.0a1.dist-info/licenses/LICENSE +202 -0
  25. customer_retention/__init__.py +37 -0
  26. customer_retention/analysis/__init__.py +0 -0
  27. customer_retention/analysis/auto_explorer/__init__.py +62 -0
  28. customer_retention/analysis/auto_explorer/exploration_manager.py +470 -0
  29. customer_retention/analysis/auto_explorer/explorer.py +258 -0
  30. customer_retention/analysis/auto_explorer/findings.py +291 -0
  31. customer_retention/analysis/auto_explorer/layered_recommendations.py +485 -0
  32. customer_retention/analysis/auto_explorer/recommendation_builder.py +148 -0
  33. customer_retention/analysis/auto_explorer/recommendations.py +418 -0
  34. customer_retention/analysis/business/__init__.py +26 -0
  35. customer_retention/analysis/business/ab_test_designer.py +144 -0
  36. customer_retention/analysis/business/fairness_analyzer.py +166 -0
  37. customer_retention/analysis/business/intervention_matcher.py +121 -0
  38. customer_retention/analysis/business/report_generator.py +222 -0
  39. customer_retention/analysis/business/risk_profile.py +199 -0
  40. customer_retention/analysis/business/roi_analyzer.py +139 -0
  41. customer_retention/analysis/diagnostics/__init__.py +20 -0
  42. customer_retention/analysis/diagnostics/calibration_analyzer.py +133 -0
  43. customer_retention/analysis/diagnostics/cv_analyzer.py +144 -0
  44. customer_retention/analysis/diagnostics/error_analyzer.py +107 -0
  45. customer_retention/analysis/diagnostics/leakage_detector.py +394 -0
  46. customer_retention/analysis/diagnostics/noise_tester.py +140 -0
  47. customer_retention/analysis/diagnostics/overfitting_analyzer.py +190 -0
  48. customer_retention/analysis/diagnostics/segment_analyzer.py +122 -0
  49. customer_retention/analysis/discovery/__init__.py +8 -0
  50. customer_retention/analysis/discovery/config_generator.py +49 -0
  51. customer_retention/analysis/discovery/discovery_flow.py +19 -0
  52. customer_retention/analysis/discovery/type_inferencer.py +147 -0
  53. customer_retention/analysis/interpretability/__init__.py +13 -0
  54. customer_retention/analysis/interpretability/cohort_analyzer.py +185 -0
  55. customer_retention/analysis/interpretability/counterfactual.py +175 -0
  56. customer_retention/analysis/interpretability/individual_explainer.py +141 -0
  57. customer_retention/analysis/interpretability/pdp_generator.py +103 -0
  58. customer_retention/analysis/interpretability/shap_explainer.py +106 -0
  59. customer_retention/analysis/jupyter_save_hook.py +28 -0
  60. customer_retention/analysis/notebook_html_exporter.py +136 -0
  61. customer_retention/analysis/notebook_progress.py +60 -0
  62. customer_retention/analysis/plotly_preprocessor.py +154 -0
  63. customer_retention/analysis/recommendations/__init__.py +54 -0
  64. customer_retention/analysis/recommendations/base.py +158 -0
  65. customer_retention/analysis/recommendations/cleaning/__init__.py +11 -0
  66. customer_retention/analysis/recommendations/cleaning/consistency.py +107 -0
  67. customer_retention/analysis/recommendations/cleaning/deduplicate.py +94 -0
  68. customer_retention/analysis/recommendations/cleaning/impute.py +67 -0
  69. customer_retention/analysis/recommendations/cleaning/outlier.py +71 -0
  70. customer_retention/analysis/recommendations/datetime/__init__.py +3 -0
  71. customer_retention/analysis/recommendations/datetime/extract.py +149 -0
  72. customer_retention/analysis/recommendations/encoding/__init__.py +3 -0
  73. customer_retention/analysis/recommendations/encoding/categorical.py +114 -0
  74. customer_retention/analysis/recommendations/pipeline.py +74 -0
  75. customer_retention/analysis/recommendations/registry.py +76 -0
  76. customer_retention/analysis/recommendations/selection/__init__.py +3 -0
  77. customer_retention/analysis/recommendations/selection/drop_column.py +56 -0
  78. customer_retention/analysis/recommendations/transform/__init__.py +4 -0
  79. customer_retention/analysis/recommendations/transform/power.py +94 -0
  80. customer_retention/analysis/recommendations/transform/scale.py +112 -0
  81. customer_retention/analysis/visualization/__init__.py +15 -0
  82. customer_retention/analysis/visualization/chart_builder.py +2619 -0
  83. customer_retention/analysis/visualization/console.py +122 -0
  84. customer_retention/analysis/visualization/display.py +171 -0
  85. customer_retention/analysis/visualization/number_formatter.py +36 -0
  86. customer_retention/artifacts/__init__.py +3 -0
  87. customer_retention/artifacts/fit_artifact_registry.py +146 -0
  88. customer_retention/cli.py +93 -0
  89. customer_retention/core/__init__.py +0 -0
  90. customer_retention/core/compat/__init__.py +193 -0
  91. customer_retention/core/compat/detection.py +99 -0
  92. customer_retention/core/compat/ops.py +48 -0
  93. customer_retention/core/compat/pandas_backend.py +57 -0
  94. customer_retention/core/compat/spark_backend.py +75 -0
  95. customer_retention/core/components/__init__.py +11 -0
  96. customer_retention/core/components/base.py +79 -0
  97. customer_retention/core/components/components/__init__.py +13 -0
  98. customer_retention/core/components/components/deployer.py +26 -0
  99. customer_retention/core/components/components/explainer.py +26 -0
  100. customer_retention/core/components/components/feature_eng.py +33 -0
  101. customer_retention/core/components/components/ingester.py +34 -0
  102. customer_retention/core/components/components/profiler.py +34 -0
  103. customer_retention/core/components/components/trainer.py +38 -0
  104. customer_retention/core/components/components/transformer.py +36 -0
  105. customer_retention/core/components/components/validator.py +37 -0
  106. customer_retention/core/components/enums.py +33 -0
  107. customer_retention/core/components/orchestrator.py +94 -0
  108. customer_retention/core/components/registry.py +59 -0
  109. customer_retention/core/config/__init__.py +39 -0
  110. customer_retention/core/config/column_config.py +95 -0
  111. customer_retention/core/config/experiments.py +71 -0
  112. customer_retention/core/config/pipeline_config.py +117 -0
  113. customer_retention/core/config/source_config.py +83 -0
  114. customer_retention/core/utils/__init__.py +28 -0
  115. customer_retention/core/utils/leakage.py +85 -0
  116. customer_retention/core/utils/severity.py +53 -0
  117. customer_retention/core/utils/statistics.py +90 -0
  118. customer_retention/generators/__init__.py +0 -0
  119. customer_retention/generators/notebook_generator/__init__.py +167 -0
  120. customer_retention/generators/notebook_generator/base.py +55 -0
  121. customer_retention/generators/notebook_generator/cell_builder.py +49 -0
  122. customer_retention/generators/notebook_generator/config.py +47 -0
  123. customer_retention/generators/notebook_generator/databricks_generator.py +48 -0
  124. customer_retention/generators/notebook_generator/local_generator.py +48 -0
  125. customer_retention/generators/notebook_generator/project_init.py +174 -0
  126. customer_retention/generators/notebook_generator/runner.py +150 -0
  127. customer_retention/generators/notebook_generator/script_generator.py +110 -0
  128. customer_retention/generators/notebook_generator/stages/__init__.py +19 -0
  129. customer_retention/generators/notebook_generator/stages/base_stage.py +86 -0
  130. customer_retention/generators/notebook_generator/stages/s01_ingestion.py +100 -0
  131. customer_retention/generators/notebook_generator/stages/s02_profiling.py +95 -0
  132. customer_retention/generators/notebook_generator/stages/s03_cleaning.py +180 -0
  133. customer_retention/generators/notebook_generator/stages/s04_transformation.py +165 -0
  134. customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py +115 -0
  135. customer_retention/generators/notebook_generator/stages/s06_feature_selection.py +97 -0
  136. customer_retention/generators/notebook_generator/stages/s07_model_training.py +176 -0
  137. customer_retention/generators/notebook_generator/stages/s08_deployment.py +81 -0
  138. customer_retention/generators/notebook_generator/stages/s09_monitoring.py +112 -0
  139. customer_retention/generators/notebook_generator/stages/s10_batch_inference.py +642 -0
  140. customer_retention/generators/notebook_generator/stages/s11_feature_store.py +348 -0
  141. customer_retention/generators/orchestration/__init__.py +23 -0
  142. customer_retention/generators/orchestration/code_generator.py +196 -0
  143. customer_retention/generators/orchestration/context.py +147 -0
  144. customer_retention/generators/orchestration/data_materializer.py +188 -0
  145. customer_retention/generators/orchestration/databricks_exporter.py +411 -0
  146. customer_retention/generators/orchestration/doc_generator.py +311 -0
  147. customer_retention/generators/pipeline_generator/__init__.py +26 -0
  148. customer_retention/generators/pipeline_generator/findings_parser.py +727 -0
  149. customer_retention/generators/pipeline_generator/generator.py +142 -0
  150. customer_retention/generators/pipeline_generator/models.py +166 -0
  151. customer_retention/generators/pipeline_generator/renderer.py +2125 -0
  152. customer_retention/generators/spec_generator/__init__.py +37 -0
  153. customer_retention/generators/spec_generator/databricks_generator.py +433 -0
  154. customer_retention/generators/spec_generator/generic_generator.py +373 -0
  155. customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +685 -0
  156. customer_retention/generators/spec_generator/pipeline_spec.py +298 -0
  157. customer_retention/integrations/__init__.py +0 -0
  158. customer_retention/integrations/adapters/__init__.py +13 -0
  159. customer_retention/integrations/adapters/base.py +10 -0
  160. customer_retention/integrations/adapters/factory.py +25 -0
  161. customer_retention/integrations/adapters/feature_store/__init__.py +6 -0
  162. customer_retention/integrations/adapters/feature_store/base.py +57 -0
  163. customer_retention/integrations/adapters/feature_store/databricks.py +94 -0
  164. customer_retention/integrations/adapters/feature_store/feast_adapter.py +97 -0
  165. customer_retention/integrations/adapters/feature_store/local.py +75 -0
  166. customer_retention/integrations/adapters/mlflow/__init__.py +6 -0
  167. customer_retention/integrations/adapters/mlflow/base.py +32 -0
  168. customer_retention/integrations/adapters/mlflow/databricks.py +54 -0
  169. customer_retention/integrations/adapters/mlflow/experiment_tracker.py +161 -0
  170. customer_retention/integrations/adapters/mlflow/local.py +50 -0
  171. customer_retention/integrations/adapters/storage/__init__.py +5 -0
  172. customer_retention/integrations/adapters/storage/base.py +33 -0
  173. customer_retention/integrations/adapters/storage/databricks.py +76 -0
  174. customer_retention/integrations/adapters/storage/local.py +59 -0
  175. customer_retention/integrations/feature_store/__init__.py +47 -0
  176. customer_retention/integrations/feature_store/definitions.py +215 -0
  177. customer_retention/integrations/feature_store/manager.py +744 -0
  178. customer_retention/integrations/feature_store/registry.py +412 -0
  179. customer_retention/integrations/iteration/__init__.py +28 -0
  180. customer_retention/integrations/iteration/context.py +212 -0
  181. customer_retention/integrations/iteration/feedback_collector.py +184 -0
  182. customer_retention/integrations/iteration/orchestrator.py +168 -0
  183. customer_retention/integrations/iteration/recommendation_tracker.py +341 -0
  184. customer_retention/integrations/iteration/signals.py +212 -0
  185. customer_retention/integrations/llm_context/__init__.py +4 -0
  186. customer_retention/integrations/llm_context/context_builder.py +201 -0
  187. customer_retention/integrations/llm_context/prompts.py +100 -0
  188. customer_retention/integrations/streaming/__init__.py +103 -0
  189. customer_retention/integrations/streaming/batch_integration.py +149 -0
  190. customer_retention/integrations/streaming/early_warning_model.py +227 -0
  191. customer_retention/integrations/streaming/event_schema.py +214 -0
  192. customer_retention/integrations/streaming/online_store_writer.py +249 -0
  193. customer_retention/integrations/streaming/realtime_scorer.py +261 -0
  194. customer_retention/integrations/streaming/trigger_engine.py +293 -0
  195. customer_retention/integrations/streaming/window_aggregator.py +393 -0
  196. customer_retention/stages/__init__.py +0 -0
  197. customer_retention/stages/cleaning/__init__.py +9 -0
  198. customer_retention/stages/cleaning/base.py +28 -0
  199. customer_retention/stages/cleaning/missing_handler.py +160 -0
  200. customer_retention/stages/cleaning/outlier_handler.py +204 -0
  201. customer_retention/stages/deployment/__init__.py +28 -0
  202. customer_retention/stages/deployment/batch_scorer.py +106 -0
  203. customer_retention/stages/deployment/champion_challenger.py +299 -0
  204. customer_retention/stages/deployment/model_registry.py +182 -0
  205. customer_retention/stages/deployment/retraining_trigger.py +245 -0
  206. customer_retention/stages/features/__init__.py +73 -0
  207. customer_retention/stages/features/behavioral_features.py +266 -0
  208. customer_retention/stages/features/customer_segmentation.py +505 -0
  209. customer_retention/stages/features/feature_definitions.py +265 -0
  210. customer_retention/stages/features/feature_engineer.py +551 -0
  211. customer_retention/stages/features/feature_manifest.py +340 -0
  212. customer_retention/stages/features/feature_selector.py +239 -0
  213. customer_retention/stages/features/interaction_features.py +160 -0
  214. customer_retention/stages/features/temporal_features.py +243 -0
  215. customer_retention/stages/ingestion/__init__.py +9 -0
  216. customer_retention/stages/ingestion/load_result.py +32 -0
  217. customer_retention/stages/ingestion/loaders.py +195 -0
  218. customer_retention/stages/ingestion/source_registry.py +130 -0
  219. customer_retention/stages/modeling/__init__.py +31 -0
  220. customer_retention/stages/modeling/baseline_trainer.py +139 -0
  221. customer_retention/stages/modeling/cross_validator.py +125 -0
  222. customer_retention/stages/modeling/data_splitter.py +205 -0
  223. customer_retention/stages/modeling/feature_scaler.py +99 -0
  224. customer_retention/stages/modeling/hyperparameter_tuner.py +107 -0
  225. customer_retention/stages/modeling/imbalance_handler.py +282 -0
  226. customer_retention/stages/modeling/mlflow_logger.py +95 -0
  227. customer_retention/stages/modeling/model_comparator.py +149 -0
  228. customer_retention/stages/modeling/model_evaluator.py +138 -0
  229. customer_retention/stages/modeling/threshold_optimizer.py +131 -0
  230. customer_retention/stages/monitoring/__init__.py +37 -0
  231. customer_retention/stages/monitoring/alert_manager.py +328 -0
  232. customer_retention/stages/monitoring/drift_detector.py +201 -0
  233. customer_retention/stages/monitoring/performance_monitor.py +242 -0
  234. customer_retention/stages/preprocessing/__init__.py +5 -0
  235. customer_retention/stages/preprocessing/transformer_manager.py +284 -0
  236. customer_retention/stages/profiling/__init__.py +256 -0
  237. customer_retention/stages/profiling/categorical_distribution.py +269 -0
  238. customer_retention/stages/profiling/categorical_target_analyzer.py +274 -0
  239. customer_retention/stages/profiling/column_profiler.py +527 -0
  240. customer_retention/stages/profiling/distribution_analysis.py +483 -0
  241. customer_retention/stages/profiling/drift_detector.py +310 -0
  242. customer_retention/stages/profiling/feature_capacity.py +507 -0
  243. customer_retention/stages/profiling/pattern_analysis_config.py +513 -0
  244. customer_retention/stages/profiling/profile_result.py +212 -0
  245. customer_retention/stages/profiling/quality_checks.py +1632 -0
  246. customer_retention/stages/profiling/relationship_detector.py +256 -0
  247. customer_retention/stages/profiling/relationship_recommender.py +454 -0
  248. customer_retention/stages/profiling/report_generator.py +520 -0
  249. customer_retention/stages/profiling/scd_analyzer.py +151 -0
  250. customer_retention/stages/profiling/segment_analyzer.py +632 -0
  251. customer_retention/stages/profiling/segment_aware_outlier.py +265 -0
  252. customer_retention/stages/profiling/target_level_analyzer.py +217 -0
  253. customer_retention/stages/profiling/temporal_analyzer.py +388 -0
  254. customer_retention/stages/profiling/temporal_coverage.py +488 -0
  255. customer_retention/stages/profiling/temporal_feature_analyzer.py +692 -0
  256. customer_retention/stages/profiling/temporal_feature_engineer.py +703 -0
  257. customer_retention/stages/profiling/temporal_pattern_analyzer.py +636 -0
  258. customer_retention/stages/profiling/temporal_quality_checks.py +278 -0
  259. customer_retention/stages/profiling/temporal_target_analyzer.py +241 -0
  260. customer_retention/stages/profiling/text_embedder.py +87 -0
  261. customer_retention/stages/profiling/text_processor.py +115 -0
  262. customer_retention/stages/profiling/text_reducer.py +60 -0
  263. customer_retention/stages/profiling/time_series_profiler.py +303 -0
  264. customer_retention/stages/profiling/time_window_aggregator.py +376 -0
  265. customer_retention/stages/profiling/type_detector.py +382 -0
  266. customer_retention/stages/profiling/window_recommendation.py +288 -0
  267. customer_retention/stages/temporal/__init__.py +166 -0
  268. customer_retention/stages/temporal/access_guard.py +180 -0
  269. customer_retention/stages/temporal/cutoff_analyzer.py +235 -0
  270. customer_retention/stages/temporal/data_preparer.py +178 -0
  271. customer_retention/stages/temporal/point_in_time_join.py +134 -0
  272. customer_retention/stages/temporal/point_in_time_registry.py +148 -0
  273. customer_retention/stages/temporal/scenario_detector.py +163 -0
  274. customer_retention/stages/temporal/snapshot_manager.py +259 -0
  275. customer_retention/stages/temporal/synthetic_coordinator.py +66 -0
  276. customer_retention/stages/temporal/timestamp_discovery.py +531 -0
  277. customer_retention/stages/temporal/timestamp_manager.py +255 -0
  278. customer_retention/stages/transformation/__init__.py +13 -0
  279. customer_retention/stages/transformation/binary_handler.py +85 -0
  280. customer_retention/stages/transformation/categorical_encoder.py +245 -0
  281. customer_retention/stages/transformation/datetime_transformer.py +97 -0
  282. customer_retention/stages/transformation/numeric_transformer.py +181 -0
  283. customer_retention/stages/transformation/pipeline.py +257 -0
  284. customer_retention/stages/validation/__init__.py +60 -0
  285. customer_retention/stages/validation/adversarial_scoring_validator.py +205 -0
  286. customer_retention/stages/validation/business_sense_gate.py +173 -0
  287. customer_retention/stages/validation/data_quality_gate.py +235 -0
  288. customer_retention/stages/validation/data_validators.py +511 -0
  289. customer_retention/stages/validation/feature_quality_gate.py +183 -0
  290. customer_retention/stages/validation/gates.py +117 -0
  291. customer_retention/stages/validation/leakage_gate.py +352 -0
  292. customer_retention/stages/validation/model_validity_gate.py +213 -0
  293. customer_retention/stages/validation/pipeline_validation_runner.py +264 -0
  294. customer_retention/stages/validation/quality_scorer.py +544 -0
  295. customer_retention/stages/validation/rule_generator.py +57 -0
  296. customer_retention/stages/validation/scoring_pipeline_validator.py +446 -0
  297. customer_retention/stages/validation/timeseries_detector.py +769 -0
  298. customer_retention/transforms/__init__.py +47 -0
  299. customer_retention/transforms/artifact_store.py +50 -0
  300. customer_retention/transforms/executor.py +157 -0
  301. customer_retention/transforms/fitted.py +92 -0
  302. customer_retention/transforms/ops.py +148 -0
@@ -0,0 +1,130 @@
+ import json
+ from datetime import datetime
+ from typing import Optional
+
+ from pydantic import BaseModel
+
+ from customer_retention.core.config.source_config import DataSourceConfig
+
+ from .load_result import LoadResult
+
+
+ class LoadHistoryEntry(BaseModel):
+     timestamp: str
+     row_count: int
+     duration_seconds: float
+     success: bool
+     warnings: list[str] = []
+     errors: list[str] = []
+
+
+ class SourceRegistration(BaseModel):
+     source_config: DataSourceConfig
+     registered_at: str
+     registered_by: str
+     last_loaded_at: Optional[str] = None
+     last_row_count: Optional[int] = None
+     last_load_duration: Optional[float] = None
+     load_history: list[LoadHistoryEntry] = []
+
+     def update_from_load(self, load_result: LoadResult) -> None:
+         entry = LoadHistoryEntry(
+             timestamp=datetime.now().isoformat(),
+             row_count=load_result.row_count,
+             duration_seconds=load_result.duration_seconds,
+             success=load_result.success,
+             warnings=load_result.warnings,
+             errors=load_result.errors
+         )
+
+         self.load_history.append(entry)
+         if len(self.load_history) > 100:
+             self.load_history = self.load_history[-100:]
+
+         self.last_loaded_at = entry.timestamp
+         self.last_row_count = load_result.row_count
+         self.last_load_duration = load_result.duration_seconds
+
+
+ class DataSourceRegistry:
+     def __init__(self):
+         self._sources: dict[str, SourceRegistration] = {}
+
+     def register(self, config: DataSourceConfig, registered_by: str = "system",
+                  overwrite: bool = False) -> None:
+         if config.name in self._sources and not overwrite:
+             raise ValueError(f"Source '{config.name}' already registered. Use overwrite=True to replace.")
+
+         self._sources[config.name] = SourceRegistration(
+             source_config=config,
+             registered_at=datetime.now().isoformat(),
+             registered_by=registered_by
+         )
+
+     def get(self, name: str) -> Optional[SourceRegistration]:
+         return self._sources.get(name)
+
+     def list_sources(self) -> list[str]:
+         return list(self._sources.keys())
+
+     def record_load(self, source_name: str, load_result: LoadResult) -> None:
+         registration = self.get(source_name)
+         if not registration:
+             raise ValueError(f"Source '{source_name}' not found in registry")
+         registration.update_from_load(load_result)
+
+     def get_load_stats(self, source_name: str) -> dict:
+         registration = self.get(source_name)
+         if not registration:
+             raise ValueError(f"Source '{source_name}' not found in registry")
+
+         total_loads = len(registration.load_history)
+         successful_loads = sum(1 for entry in registration.load_history if entry.success)
+         failed_loads = total_loads - successful_loads
+
+         return {
+             "source_name": source_name,
+             "total_loads": total_loads,
+             "successful_loads": successful_loads,
+             "failed_loads": failed_loads,
+             "last_loaded_at": registration.last_loaded_at,
+             "last_row_count": registration.last_row_count,
+             "last_load_duration": registration.last_load_duration
+         }
+
+     def save_to_file(self, path: str) -> None:
+         data = {name: reg.model_dump() for name, reg in self._sources.items()}
+         with open(path, 'w') as f:
+             json.dump(data, f, indent=2)
+
+     def load_from_file(self, path: str) -> None:
+         with open(path, 'r') as f:
+             data = json.load(f)
+         self._sources = {
+             name: SourceRegistration(**reg_data)
+             for name, reg_data in data.items()
+         }
+
+     def validate_source(self, config: DataSourceConfig) -> list[str]:
+         errors = []
+
+         if not config.name:
+             errors.append("Source name is required")
+         if not config.primary_key:
+             errors.append("Primary key is required")
+
+         duplicate_columns = self.find_duplicate_column_names(config)
+         if duplicate_columns:
+             errors.append(f"Duplicate column names found: {', '.join(duplicate_columns)}")
+
+         return errors
+
+     def find_duplicate_column_names(self, config: DataSourceConfig) -> list[str]:
+         column_names = [c.name for c in config.columns]
+         seen = set()
+         duplicates = set()
+         for name in column_names:
+             if name in seen:
+                 duplicates.add(name)
+             seen.add(name)
+         return list(duplicates)
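
A minimal usage sketch for the DataSourceRegistry above. The SimpleNamespace stand-in and its field values are illustrative assumptions: validate_source and find_duplicate_column_names only read the name, primary_key, and columns attributes, so a real DataSourceConfig (defined in source_config.py, not shown in this hunk) would be used the same way.

    from types import SimpleNamespace
    from customer_retention.stages.ingestion.source_registry import DataSourceRegistry

    # Stand-in carrying only the attributes the registry reads (hypothetical values).
    config = SimpleNamespace(
        name="crm_events",
        primary_key="customer_id",
        columns=[SimpleNamespace(name="customer_id"), SimpleNamespace(name="plan"), SimpleNamespace(name="plan")],
    )

    registry = DataSourceRegistry()
    print(registry.find_duplicate_column_names(config))  # ['plan']
    print(registry.validate_source(config))              # ['Duplicate column names found: plan']
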
@@ -0,0 +1,31 @@
+ from .baseline_trainer import BaselineTrainer, ModelType, TrainedModel, TrainingConfig
+ from .cross_validator import CrossValidator, CVResult, CVStrategy
+ from .data_splitter import DataSplitter, SplitConfig, SplitResult, SplitStrategy
+ from .feature_scaler import FeatureScaler, ScalerType, ScalingResult
+ from .hyperparameter_tuner import HyperparameterTuner, SearchStrategy, TuningResult
+ from .imbalance_handler import (
+     ClassWeightMethod,
+     ImbalanceHandler,
+     ImbalanceRecommendation,
+     ImbalanceRecommender,
+     ImbalanceResult,
+     ImbalanceStrategy,
+ )
+ from .mlflow_logger import ExperimentConfig, MLflowLogger
+ from .model_comparator import ComparisonResult, ModelComparator, ModelMetrics
+ from .model_evaluator import EvaluationResult, ModelEvaluator
+ from .threshold_optimizer import OptimizationObjective, ThresholdOptimizer, ThresholdResult
+
+ __all__ = [
+     "DataSplitter", "SplitStrategy", "SplitResult", "SplitConfig",
+     "ImbalanceHandler", "ImbalanceStrategy", "ClassWeightMethod", "ImbalanceResult",
+     "ImbalanceRecommender", "ImbalanceRecommendation",
+     "BaselineTrainer", "ModelType", "TrainingConfig", "TrainedModel",
+     "ModelEvaluator", "EvaluationResult",
+     "CrossValidator", "CVStrategy", "CVResult",
+     "HyperparameterTuner", "SearchStrategy", "TuningResult",
+     "ThresholdOptimizer", "OptimizationObjective", "ThresholdResult",
+     "ModelComparator", "ComparisonResult", "ModelMetrics",
+     "FeatureScaler", "ScalerType", "ScalingResult",
+     "MLflowLogger", "ExperimentConfig",
+ ]
@@ -0,0 +1,139 @@
+ """Baseline model training for customer retention prediction."""
+
+ import time
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional
+
+ from customer_retention.core.compat import DataFrame, Series
+ from customer_retention.core.components.enums import ModelType
+
+
+ @dataclass
+ class TrainingConfig:
+     random_state: int = 42
+     verbose: bool = False
+     n_jobs: int = -1
+
+
+ @dataclass
+ class TrainedModel:
+     model: Any
+     model_type: ModelType
+     hyperparameters: Dict[str, Any]
+     training_time: float
+     feature_names: List[str]
+     class_weight: Optional[Any] = None
+
+
+ class BaselineTrainer:
+     DEFAULT_PARAMS = {
+         ModelType.LOGISTIC_REGRESSION: {
+             "C": 1.0,
+             "solver": "lbfgs",
+             "max_iter": 1000,
+         },
+         ModelType.RANDOM_FOREST: {
+             "n_estimators": 100,
+             "max_depth": 10,
+             "min_samples_split": 5,
+             "min_samples_leaf": 2,
+             "n_jobs": -1,
+         },
+         ModelType.XGBOOST: {
+             "n_estimators": 100,
+             "max_depth": 6,
+             "learning_rate": 0.1,
+             "subsample": 0.8,
+             "colsample_bytree": 0.8,
+             "eval_metric": "logloss",
+         },
+         ModelType.LIGHTGBM: {
+             "n_estimators": 100,
+             "max_depth": 6,
+             "learning_rate": 0.1,
+             "num_leaves": 31,
+         },
+     }
+
+     def __init__(
+         self,
+         model_type: ModelType,
+         model_params: Optional[Dict[str, Any]] = None,
+         class_weight: Optional[Any] = None,
+         random_state: int = 42,
+         verbose: bool = False,
+     ):
+         self.model_type = model_type
+         self.model_params = model_params or {}
+         self.class_weight = class_weight
+         self.random_state = random_state
+         self.verbose = verbose
+
+     def fit(
+         self,
+         X: DataFrame,
+         y: Series,
+         X_val: Optional[DataFrame] = None,
+         y_val: Optional[Series] = None,
+     ) -> TrainedModel:
+         start_time = time.time()
+         params = self._build_params()
+         model = self._create_model(params)
+
+         if self.model_type == ModelType.XGBOOST and X_val is not None:
+             early_stopping = params.pop("early_stopping_rounds", None)
+             if early_stopping:
+                 model.set_params(early_stopping_rounds=early_stopping)
+                 model.fit(X, y, eval_set=[(X_val, y_val)], verbose=self.verbose)
+             else:
+                 model.fit(X, y)
+         else:
+             model.fit(X, y)
+
+         training_time = time.time() - start_time
+
+         return TrainedModel(
+             model=model,
+             model_type=self.model_type,
+             hyperparameters=self._get_final_params(model),
+             training_time=training_time,
+             feature_names=list(X.columns),
+             class_weight=self.class_weight,
+         )
+
+     def _build_params(self) -> Dict[str, Any]:
+         defaults = self.DEFAULT_PARAMS.get(self.model_type, {}).copy()
+         defaults.update(self.model_params)
+         defaults["random_state"] = self.random_state
+         return defaults
+
+     def _create_model(self, params: Dict[str, Any]):
+         if self.model_type == ModelType.LOGISTIC_REGRESSION:
+             from sklearn.linear_model import LogisticRegression
+             if self.class_weight:
+                 params["class_weight"] = self.class_weight
+             return LogisticRegression(**params)
+
+         if self.model_type == ModelType.RANDOM_FOREST:
+             from sklearn.ensemble import RandomForestClassifier
+             if self.class_weight:
+                 params["class_weight"] = self.class_weight
+             return RandomForestClassifier(**params)
+
+         if self.model_type == ModelType.XGBOOST:
+             from xgboost import XGBClassifier
+             params.pop("class_weight", None)
+             return XGBClassifier(**params, verbosity=0 if not self.verbose else 1)
+
+         if self.model_type == ModelType.LIGHTGBM:
+             from lightgbm import LGBMClassifier
+             if self.class_weight:
+                 params["class_weight"] = self.class_weight
+             return LGBMClassifier(**params, verbosity=-1 if not self.verbose else 1)
+
+         raise ValueError(f"Unsupported model type: {self.model_type}")
+
+     def _get_final_params(self, model) -> Dict[str, Any]:
+         if hasattr(model, "get_params"):
+             return model.get_params()
+         return self.model_params
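
A minimal sketch of training a baseline with the class above, assuming the core.compat DataFrame/Series are pandas-backed and scikit-learn is installed; the synthetic columns and their names are illustrative assumptions only.

    import numpy as np
    import pandas as pd

    from customer_retention.stages.modeling import BaselineTrainer, ModelType

    # Tiny synthetic frame purely for illustration.
    rng = np.random.default_rng(0)
    X = pd.DataFrame({"tenure": rng.integers(1, 60, 200), "monthly_spend": rng.normal(70.0, 20.0, 200)})
    y = pd.Series(rng.integers(0, 2, 200), name="churned")

    trainer = BaselineTrainer(model_type=ModelType.LOGISTIC_REGRESSION, class_weight="balanced")
    trained = trainer.fit(X, y)
    print(trained.model_type, round(trained.training_time, 3), trained.hyperparameters["C"])
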
@@ -0,0 +1,125 @@
+ """Cross-validation strategies for model evaluation."""
+
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Any, Dict, List, Optional
+
+ import numpy as np
+ from sklearn.model_selection import GroupKFold, RepeatedStratifiedKFold, StratifiedKFold, cross_val_score
+
+ from customer_retention.core.compat import DataFrame, Series
+
+
+ class CVStrategy(Enum):
+     STRATIFIED_KFOLD = "stratified_kfold"
+     REPEATED_STRATIFIED = "repeated_stratified"
+     TIME_SERIES = "time_series"
+     GROUP_KFOLD = "group_kfold"
+
+
+ @dataclass
+ class CVResult:
+     cv_scores: np.ndarray
+     cv_mean: float
+     cv_std: float
+     fold_details: List[Dict[str, Any]]
+     scoring: str
+     is_stable: bool
+
+
+ class CrossValidator:
+     def __init__(
+         self,
+         strategy: CVStrategy = CVStrategy.STRATIFIED_KFOLD,
+         n_splits: int = 5,
+         n_repeats: int = 1,
+         shuffle: bool = True,
+         random_state: int = 42,
+         scoring: str = "average_precision",
+         stability_threshold: float = 0.10,
+     ):
+         self.strategy = strategy
+         self.n_splits = n_splits
+         self.n_repeats = n_repeats
+         self.shuffle = shuffle
+         self.random_state = random_state
+         self.scoring = scoring
+         self.stability_threshold = stability_threshold
+
+     def run(
+         self,
+         model,
+         X: DataFrame,
+         y: Series,
+         groups: Optional[Series] = None,
+     ) -> CVResult:
+         cv_splitter = self._create_cv_splitter(groups)
+         fold_details = []
+
+         if self.strategy == CVStrategy.GROUP_KFOLD:
+             scores = cross_val_score(model, X, y, cv=cv_splitter, scoring=self.scoring, groups=groups)
+             fold_details = self._collect_fold_details_with_groups(X, y, groups, cv_splitter)
+         else:
+             scores = cross_val_score(model, X, y, cv=cv_splitter, scoring=self.scoring)
+             fold_details = self._collect_fold_details(X, y, cv_splitter)
+
+         cv_mean = np.mean(scores)
+         cv_std = np.std(scores)
+         is_stable = bool(cv_std <= self.stability_threshold)
+
+         return CVResult(
+             cv_scores=scores,
+             cv_mean=cv_mean,
+             cv_std=cv_std,
+             fold_details=fold_details,
+             scoring=self.scoring,
+             is_stable=is_stable,
+         )
+
+     def _create_cv_splitter(self, groups: Optional[Series] = None):
+         if self.strategy == CVStrategy.STRATIFIED_KFOLD:
+             return StratifiedKFold(n_splits=self.n_splits, shuffle=self.shuffle, random_state=self.random_state)
+
+         if self.strategy == CVStrategy.REPEATED_STRATIFIED:
+             return RepeatedStratifiedKFold(n_splits=self.n_splits, n_repeats=self.n_repeats, random_state=self.random_state)
+
+         if self.strategy == CVStrategy.GROUP_KFOLD:
+             return GroupKFold(n_splits=self.n_splits)
+
+         if self.strategy == CVStrategy.TIME_SERIES:
+             from sklearn.model_selection import TimeSeriesSplit
+             return TimeSeriesSplit(n_splits=self.n_splits)
+
+         return StratifiedKFold(n_splits=self.n_splits, shuffle=self.shuffle, random_state=self.random_state)
+
+     def _collect_fold_details(self, X: DataFrame, y: Series, cv_splitter) -> List[Dict[str, Any]]:
+         fold_details = []
+         for fold_idx, (train_idx, test_idx) in enumerate(cv_splitter.split(X, y)):
+             y_train = y.iloc[train_idx]
+             fold_details.append({
+                 "fold": fold_idx + 1,
+                 "train_size": len(train_idx),
+                 "test_size": len(test_idx),
+                 "train_class_ratio": y_train.mean(),
+                 "score": None,
+             })
+         return fold_details
+
+     def _collect_fold_details_with_groups(
+         self,
+         X: DataFrame,
+         y: Series,
+         groups: Series,
+         cv_splitter,
+     ) -> List[Dict[str, Any]]:
+         fold_details = []
+         for fold_idx, (train_idx, test_idx) in enumerate(cv_splitter.split(X, y, groups)):
+             y_train = y.iloc[train_idx]
+             fold_details.append({
+                 "fold": fold_idx + 1,
+                 "train_size": len(train_idx),
+                 "test_size": len(test_idx),
+                 "train_class_ratio": y_train.mean(),
+                 "score": None,
+             })
+         return fold_details
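
A minimal sketch of running the validator above with the default stratified strategy; the synthetic frame and the LogisticRegression estimator are assumptions chosen for illustration, not part of the package.

    import numpy as np
    import pandas as pd
    from sklearn.linear_model import LogisticRegression

    from customer_retention.stages.modeling import CrossValidator, CVStrategy

    rng = np.random.default_rng(1)
    X = pd.DataFrame({"f1": rng.normal(size=300), "f2": rng.normal(size=300)})
    y = pd.Series(rng.integers(0, 2, 300))

    cv = CrossValidator(strategy=CVStrategy.STRATIFIED_KFOLD, n_splits=5, scoring="average_precision")
    result = cv.run(LogisticRegression(max_iter=1000), X, y)
    print(f"{result.cv_mean:.3f} +/- {result.cv_std:.3f} (stable={result.is_stable})")
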
@@ -0,0 +1,205 @@
+ import warnings
+ from dataclasses import dataclass, field
+ from enum import Enum
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+ from sklearn.model_selection import GroupShuffleSplit, train_test_split
+
+ from customer_retention.core.compat import DataFrame, Series
+
+ if TYPE_CHECKING:
+     from customer_retention.analysis.auto_explorer.findings import FeatureAvailabilityMetadata
+
+
+ class SplitStrategy(Enum):
+     RANDOM_STRATIFIED = "random_stratified"
+     TEMPORAL = "temporal"
+     GROUP = "group"
+     CUSTOM = "custom"
+
+
+ @dataclass
+ class SplitConfig:
+     test_size: float = 0.11
+     validation_size: float = 0.10
+     stratify: bool = True
+     random_state: int = 42
+     temporal_column: Optional[str] = None
+     group_column: Optional[str] = None
+
+
+ @dataclass
+ class SplitResult:
+     X_train: DataFrame
+     X_test: DataFrame
+     y_train: Series
+     y_test: Series
+     X_val: Optional[DataFrame] = None
+     y_val: Optional[Series] = None
+     split_info: Dict[str, Any] = field(default_factory=dict)
+
+
+ @dataclass
+ class SplitWarning:
+     column: str
+     issue: str
+     severity: str
+     recommendation: str
+
+     def to_dict(self) -> Dict[str, str]:
+         return {"column": self.column, "issue": self.issue, "severity": self.severity, "recommendation": self.recommendation}
+
+
+ class DataSplitter:
+     def __init__(self, target_column: str, strategy: SplitStrategy = SplitStrategy.RANDOM_STRATIFIED, test_size: float = 0.11, validation_size: float = 0.10, stratify: bool = True, random_state: int = 42, temporal_column: Optional[str] = None, group_column: Optional[str] = None, exclude_columns: Optional[List[str]] = None, include_validation: bool = False):
+         self.target_column = target_column
+         self.strategy = strategy
+         self.test_size = test_size
+         self.validation_size = validation_size
+         self.stratify = stratify
+         self.random_state = random_state
+         self.temporal_column = temporal_column
+         self.group_column = group_column
+         self.exclude_columns = exclude_columns or []
+         self.include_validation = include_validation
+
+     def split(self, df: DataFrame, feature_availability: Optional["FeatureAvailabilityMetadata"] = None) -> SplitResult:
+         self._validate_minority_samples(df)
+         availability_warnings = self.validate_feature_availability(df, feature_availability)
+
+         if self.strategy == SplitStrategy.TEMPORAL:
+             result = self._temporal_split(df)
+         elif self.strategy == SplitStrategy.GROUP:
+             result = self._group_split(df)
+         else:
+             result = self._stratified_split(df)
+
+         if availability_warnings:
+             result.split_info["availability_warnings"] = [w.to_dict() for w in availability_warnings]
+         return result
+
+     def validate_feature_availability(self, df: DataFrame, availability: Optional["FeatureAvailabilityMetadata"]) -> List[SplitWarning]:
+         if availability is None:
+             return []
+         if self.strategy != SplitStrategy.TEMPORAL:
+             return []
+         warnings_list: List[SplitWarning] = []
+         for col in availability.new_tracking:
+             if col in df.columns:
+                 feat_info = availability.features.get(col)
+                 first_date = feat_info.first_valid_date if feat_info else "unknown"
+                 warnings_list.append(SplitWarning(
+                     column=col, issue="new_tracking", severity="warning",
+                     recommendation=f"Feature '{col}' only available from {first_date}. Training data before this date will have missing values.",
+                 ))
+         for col in availability.retired_tracking:
+             if col in df.columns:
+                 feat_info = availability.features.get(col)
+                 last_date = feat_info.last_valid_date if feat_info else "unknown"
+                 warnings_list.append(SplitWarning(
+                     column=col, issue="retired", severity="warning",
+                     recommendation=f"Feature '{col}' retired at {last_date}. Test data after this date will have missing values.",
+                 ))
+         for col in availability.partial_window:
+             if col in df.columns:
+                 feat_info = availability.features.get(col)
+                 first_date = feat_info.first_valid_date if feat_info else "unknown"
+                 last_date = feat_info.last_valid_date if feat_info else "unknown"
+                 warnings_list.append(SplitWarning(
+                     column=col, issue="partial_window", severity="warning",
+                     recommendation=f"Feature '{col}' only available {first_date} to {last_date}. Both train and test may have gaps.",
+                 ))
+         return warnings_list
+
+     def _stratified_split(self, df: DataFrame) -> SplitResult:
+         X, y = self._prepare_features_target(df)
+         stratify_col = y if self.stratify else None
+
+         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=self.test_size, random_state=self.random_state, stratify=stratify_col)
+
+         X_val, y_val = None, None
+         if self.include_validation:
+             val_ratio = self.validation_size / (1 - self.test_size)
+             stratify_train = y_train if self.stratify else None
+             X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=val_ratio, random_state=self.random_state, stratify=stratify_train)
+
+         return SplitResult(
+             X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test,
+             X_val=X_val, y_val=y_val,
+             split_info=self._build_split_info(X_train, X_test, X_val)
+         )
+
+     def _temporal_split(self, df: DataFrame) -> SplitResult:
+         df_sorted = df.sort_values(self.temporal_column).reset_index(drop=True)
+         split_idx = int(len(df_sorted) * (1 - self.test_size))
+
+         train_df = df_sorted.iloc[:split_idx]
+         test_df = df_sorted.iloc[split_idx:]
+
+         X_train, y_train = self._prepare_features_target(train_df)
+         X_test, y_test = self._prepare_features_target(test_df)
+
+         X_val, y_val = None, None
+         if self.include_validation:
+             val_split = int(len(X_train) * (1 - self.validation_size / (1 - self.test_size)))
+             X_val, y_val = X_train.iloc[val_split:], y_train.iloc[val_split:]
+             X_train, y_train = X_train.iloc[:val_split], y_train.iloc[:val_split]
+
+         return SplitResult(
+             X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test,
+             X_val=X_val, y_val=y_val,
+             split_info=self._build_split_info(X_train, X_test, X_val)
+         )
+
+     def _group_split(self, df: DataFrame) -> SplitResult:
+         X, y = self._prepare_features_target(df)
+         groups = df[self.group_column]
+
+         gss = GroupShuffleSplit(n_splits=1, test_size=self.test_size, random_state=self.random_state)
+         train_idx, test_idx = next(gss.split(X, y, groups))
+
+         X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
+         y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
+
+         X_val, y_val = None, None
+         if self.include_validation:
+             val_ratio = self.validation_size / (1 - self.test_size)
+             train_groups = groups.iloc[train_idx]
+             gss_val = GroupShuffleSplit(n_splits=1, test_size=val_ratio, random_state=self.random_state)
+             train_idx2, val_idx2 = next(gss_val.split(X_train, y_train, train_groups))
+             X_val, y_val = X_train.iloc[val_idx2], y_train.iloc[val_idx2]
+             X_train, y_train = X_train.iloc[train_idx2], y_train.iloc[train_idx2]
+
+         return SplitResult(
+             X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test,
+             X_val=X_val, y_val=y_val,
+             split_info=self._build_split_info(X_train, X_test, X_val)
+         )
+
+     def _prepare_features_target(self, df: DataFrame) -> tuple[DataFrame, Series]:
+         exclude = [self.target_column] + self.exclude_columns
+         feature_cols = [c for c in df.columns if c not in exclude]
+         return df[feature_cols], df[self.target_column]
+
+     def _validate_minority_samples(self, df: DataFrame):
+         class_counts = df[self.target_column].value_counts()
+         minority_count = class_counts.min()
+         expected_minority_test = minority_count * self.test_size
+
+         if expected_minority_test < 50:
+             warnings.warn(
+                 f"Insufficient minority samples: expected ~{expected_minority_test:.0f} in test set. "
+                 "Consider using a smaller test_size or collecting more data.",
+                 UserWarning
+             )
+
+     def _build_split_info(self, X_train, X_test, X_val) -> Dict[str, Any]:
+         info = {
+             "train_size": len(X_train),
+             "test_size": len(X_test),
+             "strategy": self.strategy.value,
+             "random_state": self.random_state,
+         }
+         if X_val is not None:
+             info["validation_size"] = len(X_val)
+         return info
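
A minimal sketch of a temporal split with the class above; the column names and the 1,000-row synthetic frame are illustrative assumptions (kept large enough that the minority-sample warning does not fire).

    import numpy as np
    import pandas as pd

    from customer_retention.stages.modeling import DataSplitter, SplitStrategy

    rng = np.random.default_rng(2)
    df = pd.DataFrame({
        "signup_date": pd.date_range("2023-01-01", periods=1000, freq="D"),
        "monthly_spend": rng.normal(70.0, 20.0, 1000),
        "churned": rng.integers(0, 2, 1000),
    })

    splitter = DataSplitter(target_column="churned", strategy=SplitStrategy.TEMPORAL,
                            temporal_column="signup_date", test_size=0.2,
                            exclude_columns=["signup_date"])
    result = splitter.split(df)
    print(result.split_info)  # {'train_size': 800, 'test_size': 200, 'strategy': 'temporal', 'random_state': 42}
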