churnkit 0.75.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +647 -0
  2. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +1165 -0
  3. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +961 -0
  4. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +1690 -0
  5. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +679 -0
  6. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +3305 -0
  7. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +1463 -0
  8. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +1430 -0
  9. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +854 -0
  10. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +1639 -0
  11. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +1890 -0
  12. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +1457 -0
  13. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +1624 -0
  14. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +780 -0
  15. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +979 -0
  16. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +572 -0
  17. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +1179 -0
  18. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +1418 -0
  19. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +151 -0
  20. churnkit-0.75.0a1.dist-info/METADATA +229 -0
  21. churnkit-0.75.0a1.dist-info/RECORD +302 -0
  22. churnkit-0.75.0a1.dist-info/WHEEL +4 -0
  23. churnkit-0.75.0a1.dist-info/entry_points.txt +2 -0
  24. churnkit-0.75.0a1.dist-info/licenses/LICENSE +202 -0
  25. customer_retention/__init__.py +37 -0
  26. customer_retention/analysis/__init__.py +0 -0
  27. customer_retention/analysis/auto_explorer/__init__.py +62 -0
  28. customer_retention/analysis/auto_explorer/exploration_manager.py +470 -0
  29. customer_retention/analysis/auto_explorer/explorer.py +258 -0
  30. customer_retention/analysis/auto_explorer/findings.py +291 -0
  31. customer_retention/analysis/auto_explorer/layered_recommendations.py +485 -0
  32. customer_retention/analysis/auto_explorer/recommendation_builder.py +148 -0
  33. customer_retention/analysis/auto_explorer/recommendations.py +418 -0
  34. customer_retention/analysis/business/__init__.py +26 -0
  35. customer_retention/analysis/business/ab_test_designer.py +144 -0
  36. customer_retention/analysis/business/fairness_analyzer.py +166 -0
  37. customer_retention/analysis/business/intervention_matcher.py +121 -0
  38. customer_retention/analysis/business/report_generator.py +222 -0
  39. customer_retention/analysis/business/risk_profile.py +199 -0
  40. customer_retention/analysis/business/roi_analyzer.py +139 -0
  41. customer_retention/analysis/diagnostics/__init__.py +20 -0
  42. customer_retention/analysis/diagnostics/calibration_analyzer.py +133 -0
  43. customer_retention/analysis/diagnostics/cv_analyzer.py +144 -0
  44. customer_retention/analysis/diagnostics/error_analyzer.py +107 -0
  45. customer_retention/analysis/diagnostics/leakage_detector.py +394 -0
  46. customer_retention/analysis/diagnostics/noise_tester.py +140 -0
  47. customer_retention/analysis/diagnostics/overfitting_analyzer.py +190 -0
  48. customer_retention/analysis/diagnostics/segment_analyzer.py +122 -0
  49. customer_retention/analysis/discovery/__init__.py +8 -0
  50. customer_retention/analysis/discovery/config_generator.py +49 -0
  51. customer_retention/analysis/discovery/discovery_flow.py +19 -0
  52. customer_retention/analysis/discovery/type_inferencer.py +147 -0
  53. customer_retention/analysis/interpretability/__init__.py +13 -0
  54. customer_retention/analysis/interpretability/cohort_analyzer.py +185 -0
  55. customer_retention/analysis/interpretability/counterfactual.py +175 -0
  56. customer_retention/analysis/interpretability/individual_explainer.py +141 -0
  57. customer_retention/analysis/interpretability/pdp_generator.py +103 -0
  58. customer_retention/analysis/interpretability/shap_explainer.py +106 -0
  59. customer_retention/analysis/jupyter_save_hook.py +28 -0
  60. customer_retention/analysis/notebook_html_exporter.py +136 -0
  61. customer_retention/analysis/notebook_progress.py +60 -0
  62. customer_retention/analysis/plotly_preprocessor.py +154 -0
  63. customer_retention/analysis/recommendations/__init__.py +54 -0
  64. customer_retention/analysis/recommendations/base.py +158 -0
  65. customer_retention/analysis/recommendations/cleaning/__init__.py +11 -0
  66. customer_retention/analysis/recommendations/cleaning/consistency.py +107 -0
  67. customer_retention/analysis/recommendations/cleaning/deduplicate.py +94 -0
  68. customer_retention/analysis/recommendations/cleaning/impute.py +67 -0
  69. customer_retention/analysis/recommendations/cleaning/outlier.py +71 -0
  70. customer_retention/analysis/recommendations/datetime/__init__.py +3 -0
  71. customer_retention/analysis/recommendations/datetime/extract.py +149 -0
  72. customer_retention/analysis/recommendations/encoding/__init__.py +3 -0
  73. customer_retention/analysis/recommendations/encoding/categorical.py +114 -0
  74. customer_retention/analysis/recommendations/pipeline.py +74 -0
  75. customer_retention/analysis/recommendations/registry.py +76 -0
  76. customer_retention/analysis/recommendations/selection/__init__.py +3 -0
  77. customer_retention/analysis/recommendations/selection/drop_column.py +56 -0
  78. customer_retention/analysis/recommendations/transform/__init__.py +4 -0
  79. customer_retention/analysis/recommendations/transform/power.py +94 -0
  80. customer_retention/analysis/recommendations/transform/scale.py +112 -0
  81. customer_retention/analysis/visualization/__init__.py +15 -0
  82. customer_retention/analysis/visualization/chart_builder.py +2619 -0
  83. customer_retention/analysis/visualization/console.py +122 -0
  84. customer_retention/analysis/visualization/display.py +171 -0
  85. customer_retention/analysis/visualization/number_formatter.py +36 -0
  86. customer_retention/artifacts/__init__.py +3 -0
  87. customer_retention/artifacts/fit_artifact_registry.py +146 -0
  88. customer_retention/cli.py +93 -0
  89. customer_retention/core/__init__.py +0 -0
  90. customer_retention/core/compat/__init__.py +193 -0
  91. customer_retention/core/compat/detection.py +99 -0
  92. customer_retention/core/compat/ops.py +48 -0
  93. customer_retention/core/compat/pandas_backend.py +57 -0
  94. customer_retention/core/compat/spark_backend.py +75 -0
  95. customer_retention/core/components/__init__.py +11 -0
  96. customer_retention/core/components/base.py +79 -0
  97. customer_retention/core/components/components/__init__.py +13 -0
  98. customer_retention/core/components/components/deployer.py +26 -0
  99. customer_retention/core/components/components/explainer.py +26 -0
  100. customer_retention/core/components/components/feature_eng.py +33 -0
  101. customer_retention/core/components/components/ingester.py +34 -0
  102. customer_retention/core/components/components/profiler.py +34 -0
  103. customer_retention/core/components/components/trainer.py +38 -0
  104. customer_retention/core/components/components/transformer.py +36 -0
  105. customer_retention/core/components/components/validator.py +37 -0
  106. customer_retention/core/components/enums.py +33 -0
  107. customer_retention/core/components/orchestrator.py +94 -0
  108. customer_retention/core/components/registry.py +59 -0
  109. customer_retention/core/config/__init__.py +39 -0
  110. customer_retention/core/config/column_config.py +95 -0
  111. customer_retention/core/config/experiments.py +71 -0
  112. customer_retention/core/config/pipeline_config.py +117 -0
  113. customer_retention/core/config/source_config.py +83 -0
  114. customer_retention/core/utils/__init__.py +28 -0
  115. customer_retention/core/utils/leakage.py +85 -0
  116. customer_retention/core/utils/severity.py +53 -0
  117. customer_retention/core/utils/statistics.py +90 -0
  118. customer_retention/generators/__init__.py +0 -0
  119. customer_retention/generators/notebook_generator/__init__.py +167 -0
  120. customer_retention/generators/notebook_generator/base.py +55 -0
  121. customer_retention/generators/notebook_generator/cell_builder.py +49 -0
  122. customer_retention/generators/notebook_generator/config.py +47 -0
  123. customer_retention/generators/notebook_generator/databricks_generator.py +48 -0
  124. customer_retention/generators/notebook_generator/local_generator.py +48 -0
  125. customer_retention/generators/notebook_generator/project_init.py +174 -0
  126. customer_retention/generators/notebook_generator/runner.py +150 -0
  127. customer_retention/generators/notebook_generator/script_generator.py +110 -0
  128. customer_retention/generators/notebook_generator/stages/__init__.py +19 -0
  129. customer_retention/generators/notebook_generator/stages/base_stage.py +86 -0
  130. customer_retention/generators/notebook_generator/stages/s01_ingestion.py +100 -0
  131. customer_retention/generators/notebook_generator/stages/s02_profiling.py +95 -0
  132. customer_retention/generators/notebook_generator/stages/s03_cleaning.py +180 -0
  133. customer_retention/generators/notebook_generator/stages/s04_transformation.py +165 -0
  134. customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py +115 -0
  135. customer_retention/generators/notebook_generator/stages/s06_feature_selection.py +97 -0
  136. customer_retention/generators/notebook_generator/stages/s07_model_training.py +176 -0
  137. customer_retention/generators/notebook_generator/stages/s08_deployment.py +81 -0
  138. customer_retention/generators/notebook_generator/stages/s09_monitoring.py +112 -0
  139. customer_retention/generators/notebook_generator/stages/s10_batch_inference.py +642 -0
  140. customer_retention/generators/notebook_generator/stages/s11_feature_store.py +348 -0
  141. customer_retention/generators/orchestration/__init__.py +23 -0
  142. customer_retention/generators/orchestration/code_generator.py +196 -0
  143. customer_retention/generators/orchestration/context.py +147 -0
  144. customer_retention/generators/orchestration/data_materializer.py +188 -0
  145. customer_retention/generators/orchestration/databricks_exporter.py +411 -0
  146. customer_retention/generators/orchestration/doc_generator.py +311 -0
  147. customer_retention/generators/pipeline_generator/__init__.py +26 -0
  148. customer_retention/generators/pipeline_generator/findings_parser.py +727 -0
  149. customer_retention/generators/pipeline_generator/generator.py +142 -0
  150. customer_retention/generators/pipeline_generator/models.py +166 -0
  151. customer_retention/generators/pipeline_generator/renderer.py +2125 -0
  152. customer_retention/generators/spec_generator/__init__.py +37 -0
  153. customer_retention/generators/spec_generator/databricks_generator.py +433 -0
  154. customer_retention/generators/spec_generator/generic_generator.py +373 -0
  155. customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +685 -0
  156. customer_retention/generators/spec_generator/pipeline_spec.py +298 -0
  157. customer_retention/integrations/__init__.py +0 -0
  158. customer_retention/integrations/adapters/__init__.py +13 -0
  159. customer_retention/integrations/adapters/base.py +10 -0
  160. customer_retention/integrations/adapters/factory.py +25 -0
  161. customer_retention/integrations/adapters/feature_store/__init__.py +6 -0
  162. customer_retention/integrations/adapters/feature_store/base.py +57 -0
  163. customer_retention/integrations/adapters/feature_store/databricks.py +94 -0
  164. customer_retention/integrations/adapters/feature_store/feast_adapter.py +97 -0
  165. customer_retention/integrations/adapters/feature_store/local.py +75 -0
  166. customer_retention/integrations/adapters/mlflow/__init__.py +6 -0
  167. customer_retention/integrations/adapters/mlflow/base.py +32 -0
  168. customer_retention/integrations/adapters/mlflow/databricks.py +54 -0
  169. customer_retention/integrations/adapters/mlflow/experiment_tracker.py +161 -0
  170. customer_retention/integrations/adapters/mlflow/local.py +50 -0
  171. customer_retention/integrations/adapters/storage/__init__.py +5 -0
  172. customer_retention/integrations/adapters/storage/base.py +33 -0
  173. customer_retention/integrations/adapters/storage/databricks.py +76 -0
  174. customer_retention/integrations/adapters/storage/local.py +59 -0
  175. customer_retention/integrations/feature_store/__init__.py +47 -0
  176. customer_retention/integrations/feature_store/definitions.py +215 -0
  177. customer_retention/integrations/feature_store/manager.py +744 -0
  178. customer_retention/integrations/feature_store/registry.py +412 -0
  179. customer_retention/integrations/iteration/__init__.py +28 -0
  180. customer_retention/integrations/iteration/context.py +212 -0
  181. customer_retention/integrations/iteration/feedback_collector.py +184 -0
  182. customer_retention/integrations/iteration/orchestrator.py +168 -0
  183. customer_retention/integrations/iteration/recommendation_tracker.py +341 -0
  184. customer_retention/integrations/iteration/signals.py +212 -0
  185. customer_retention/integrations/llm_context/__init__.py +4 -0
  186. customer_retention/integrations/llm_context/context_builder.py +201 -0
  187. customer_retention/integrations/llm_context/prompts.py +100 -0
  188. customer_retention/integrations/streaming/__init__.py +103 -0
  189. customer_retention/integrations/streaming/batch_integration.py +149 -0
  190. customer_retention/integrations/streaming/early_warning_model.py +227 -0
  191. customer_retention/integrations/streaming/event_schema.py +214 -0
  192. customer_retention/integrations/streaming/online_store_writer.py +249 -0
  193. customer_retention/integrations/streaming/realtime_scorer.py +261 -0
  194. customer_retention/integrations/streaming/trigger_engine.py +293 -0
  195. customer_retention/integrations/streaming/window_aggregator.py +393 -0
  196. customer_retention/stages/__init__.py +0 -0
  197. customer_retention/stages/cleaning/__init__.py +9 -0
  198. customer_retention/stages/cleaning/base.py +28 -0
  199. customer_retention/stages/cleaning/missing_handler.py +160 -0
  200. customer_retention/stages/cleaning/outlier_handler.py +204 -0
  201. customer_retention/stages/deployment/__init__.py +28 -0
  202. customer_retention/stages/deployment/batch_scorer.py +106 -0
  203. customer_retention/stages/deployment/champion_challenger.py +299 -0
  204. customer_retention/stages/deployment/model_registry.py +182 -0
  205. customer_retention/stages/deployment/retraining_trigger.py +245 -0
  206. customer_retention/stages/features/__init__.py +73 -0
  207. customer_retention/stages/features/behavioral_features.py +266 -0
  208. customer_retention/stages/features/customer_segmentation.py +505 -0
  209. customer_retention/stages/features/feature_definitions.py +265 -0
  210. customer_retention/stages/features/feature_engineer.py +551 -0
  211. customer_retention/stages/features/feature_manifest.py +340 -0
  212. customer_retention/stages/features/feature_selector.py +239 -0
  213. customer_retention/stages/features/interaction_features.py +160 -0
  214. customer_retention/stages/features/temporal_features.py +243 -0
  215. customer_retention/stages/ingestion/__init__.py +9 -0
  216. customer_retention/stages/ingestion/load_result.py +32 -0
  217. customer_retention/stages/ingestion/loaders.py +195 -0
  218. customer_retention/stages/ingestion/source_registry.py +130 -0
  219. customer_retention/stages/modeling/__init__.py +31 -0
  220. customer_retention/stages/modeling/baseline_trainer.py +139 -0
  221. customer_retention/stages/modeling/cross_validator.py +125 -0
  222. customer_retention/stages/modeling/data_splitter.py +205 -0
  223. customer_retention/stages/modeling/feature_scaler.py +99 -0
  224. customer_retention/stages/modeling/hyperparameter_tuner.py +107 -0
  225. customer_retention/stages/modeling/imbalance_handler.py +282 -0
  226. customer_retention/stages/modeling/mlflow_logger.py +95 -0
  227. customer_retention/stages/modeling/model_comparator.py +149 -0
  228. customer_retention/stages/modeling/model_evaluator.py +138 -0
  229. customer_retention/stages/modeling/threshold_optimizer.py +131 -0
  230. customer_retention/stages/monitoring/__init__.py +37 -0
  231. customer_retention/stages/monitoring/alert_manager.py +328 -0
  232. customer_retention/stages/monitoring/drift_detector.py +201 -0
  233. customer_retention/stages/monitoring/performance_monitor.py +242 -0
  234. customer_retention/stages/preprocessing/__init__.py +5 -0
  235. customer_retention/stages/preprocessing/transformer_manager.py +284 -0
  236. customer_retention/stages/profiling/__init__.py +256 -0
  237. customer_retention/stages/profiling/categorical_distribution.py +269 -0
  238. customer_retention/stages/profiling/categorical_target_analyzer.py +274 -0
  239. customer_retention/stages/profiling/column_profiler.py +527 -0
  240. customer_retention/stages/profiling/distribution_analysis.py +483 -0
  241. customer_retention/stages/profiling/drift_detector.py +310 -0
  242. customer_retention/stages/profiling/feature_capacity.py +507 -0
  243. customer_retention/stages/profiling/pattern_analysis_config.py +513 -0
  244. customer_retention/stages/profiling/profile_result.py +212 -0
  245. customer_retention/stages/profiling/quality_checks.py +1632 -0
  246. customer_retention/stages/profiling/relationship_detector.py +256 -0
  247. customer_retention/stages/profiling/relationship_recommender.py +454 -0
  248. customer_retention/stages/profiling/report_generator.py +520 -0
  249. customer_retention/stages/profiling/scd_analyzer.py +151 -0
  250. customer_retention/stages/profiling/segment_analyzer.py +632 -0
  251. customer_retention/stages/profiling/segment_aware_outlier.py +265 -0
  252. customer_retention/stages/profiling/target_level_analyzer.py +217 -0
  253. customer_retention/stages/profiling/temporal_analyzer.py +388 -0
  254. customer_retention/stages/profiling/temporal_coverage.py +488 -0
  255. customer_retention/stages/profiling/temporal_feature_analyzer.py +692 -0
  256. customer_retention/stages/profiling/temporal_feature_engineer.py +703 -0
  257. customer_retention/stages/profiling/temporal_pattern_analyzer.py +636 -0
  258. customer_retention/stages/profiling/temporal_quality_checks.py +278 -0
  259. customer_retention/stages/profiling/temporal_target_analyzer.py +241 -0
  260. customer_retention/stages/profiling/text_embedder.py +87 -0
  261. customer_retention/stages/profiling/text_processor.py +115 -0
  262. customer_retention/stages/profiling/text_reducer.py +60 -0
  263. customer_retention/stages/profiling/time_series_profiler.py +303 -0
  264. customer_retention/stages/profiling/time_window_aggregator.py +376 -0
  265. customer_retention/stages/profiling/type_detector.py +382 -0
  266. customer_retention/stages/profiling/window_recommendation.py +288 -0
  267. customer_retention/stages/temporal/__init__.py +166 -0
  268. customer_retention/stages/temporal/access_guard.py +180 -0
  269. customer_retention/stages/temporal/cutoff_analyzer.py +235 -0
  270. customer_retention/stages/temporal/data_preparer.py +178 -0
  271. customer_retention/stages/temporal/point_in_time_join.py +134 -0
  272. customer_retention/stages/temporal/point_in_time_registry.py +148 -0
  273. customer_retention/stages/temporal/scenario_detector.py +163 -0
  274. customer_retention/stages/temporal/snapshot_manager.py +259 -0
  275. customer_retention/stages/temporal/synthetic_coordinator.py +66 -0
  276. customer_retention/stages/temporal/timestamp_discovery.py +531 -0
  277. customer_retention/stages/temporal/timestamp_manager.py +255 -0
  278. customer_retention/stages/transformation/__init__.py +13 -0
  279. customer_retention/stages/transformation/binary_handler.py +85 -0
  280. customer_retention/stages/transformation/categorical_encoder.py +245 -0
  281. customer_retention/stages/transformation/datetime_transformer.py +97 -0
  282. customer_retention/stages/transformation/numeric_transformer.py +181 -0
  283. customer_retention/stages/transformation/pipeline.py +257 -0
  284. customer_retention/stages/validation/__init__.py +60 -0
  285. customer_retention/stages/validation/adversarial_scoring_validator.py +205 -0
  286. customer_retention/stages/validation/business_sense_gate.py +173 -0
  287. customer_retention/stages/validation/data_quality_gate.py +235 -0
  288. customer_retention/stages/validation/data_validators.py +511 -0
  289. customer_retention/stages/validation/feature_quality_gate.py +183 -0
  290. customer_retention/stages/validation/gates.py +117 -0
  291. customer_retention/stages/validation/leakage_gate.py +352 -0
  292. customer_retention/stages/validation/model_validity_gate.py +213 -0
  293. customer_retention/stages/validation/pipeline_validation_runner.py +264 -0
  294. customer_retention/stages/validation/quality_scorer.py +544 -0
  295. customer_retention/stages/validation/rule_generator.py +57 -0
  296. customer_retention/stages/validation/scoring_pipeline_validator.py +446 -0
  297. customer_retention/stages/validation/timeseries_detector.py +769 -0
  298. customer_retention/transforms/__init__.py +47 -0
  299. customer_retention/transforms/artifact_store.py +50 -0
  300. customer_retention/transforms/executor.py +157 -0
  301. customer_retention/transforms/fitted.py +92 -0
  302. customer_retention/transforms/ops.py +148 -0
@@ -0,0 +1,418 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, Dict, List
3
+
4
+ from customer_retention.core.config.column_config import ColumnType
5
+
6
+ from .findings import ExplorationFindings
7
+
8
+
9
@dataclass
class TargetRecommendation:
    """Recommendation of which column to use as the modeling target."""

    # Name of the recommended target column ("" when no candidate is found).
    column_name: str
    # Heuristic confidence score, capped to the range [0, 1].
    confidence: float
    # Human-readable explanation of why this column was selected.
    rationale: str
    # Other plausible target columns, ordered best-first.
    alternatives: List[str] = field(default_factory=list)
    # One of "binary", "multiclass", or "unknown".
    target_type: str = "binary"
16
+
17
+
18
@dataclass
class FeatureRecommendation:
    """Suggestion for a derived feature to engineer from an existing column."""

    # Column the new feature is derived from.
    source_column: str
    # Name proposed for the derived feature (e.g. "<col>_log").
    feature_name: str
    # Category tag for the feature (e.g. "temporal", "numeric", "categorical").
    feature_type: str
    # Human-readable description of the derivation.
    description: str
    # One of "low", "medium", "high".
    priority: str = "medium"
    # Pointer to the transformer API that implements the derivation.
    implementation_hint: str = ""
26
+
27
+
28
@dataclass
class CleaningRecommendation:
    """Recommendation for remediating a data-quality issue in one column."""

    # Column the issue was detected in.
    column_name: str
    # Machine-readable issue tag (e.g. "missing_values", "outliers").
    issue_type: str
    # One of "low", "medium", "high".
    severity: str
    # Machine-readable remediation strategy key (e.g. "impute_median").
    strategy: str
    # Short human-readable summary of the issue.
    description: str
    # Number of rows affected by the issue.
    affected_rows: int = 0
    # Display-friendly label for the strategy.
    strategy_label: str = ""
    # Explanation of the downstream impact if left unaddressed.
    problem_impact: str = ""
    # Ordered, actionable remediation steps for the analyst.
    action_steps: List[str] = field(default_factory=list)
39
+
40
+
41
@dataclass
class TransformRecommendation:
    """Recommendation to apply a transformation to a column."""

    # Column the transform applies to.
    column_name: str
    # Machine-readable transform tag (consumer-defined vocabulary).
    transform_type: str
    # Human-readable justification for the transform.
    reason: str
    # Transform-specific keyword parameters.
    parameters: Dict[str, Any] = field(default_factory=dict)
    # One of "low", "medium", "high".
    priority: str = "medium"
48
+
49
+
50
class RecommendationEngine:
    """Derives target, feature, and cleaning recommendations from ExplorationFindings."""

    # Substrings that suggest a column name refers to the prediction target.
    TARGET_PATTERNS = ["target", "label", "churn", "churned", "outcome", "class", "flag"]
    # |skewness| above this triggers a log-transform feature suggestion.
    SKEWNESS_THRESHOLD = 1.0
    # Outlier percentage above this triggers an outlier-cleaning recommendation.
    OUTLIER_THRESHOLD = 5.0
    # Null percentage above this yields a medium-severity cleaning recommendation.
    NULL_WARNING_THRESHOLD = 5.0
    # Null percentage above this yields a high-severity cleaning recommendation.
    NULL_CRITICAL_THRESHOLD = 20.0

    def __init__(self, min_confidence: float = 0.7):
        # Minimum confidence for recommendations; stored here but not
        # referenced by the methods visible in this portion of the file.
        self.min_confidence = min_confidence
59
+
60
+ def recommend_target(self, findings: ExplorationFindings) -> TargetRecommendation:
61
+ if findings.target_column:
62
+ target_finding = findings.columns.get(findings.target_column)
63
+ return TargetRecommendation(
64
+ column_name=findings.target_column,
65
+ confidence=target_finding.confidence if target_finding else 0.9,
66
+ rationale=f"Target already detected as {findings.target_type}",
67
+ alternatives=self._find_alternative_targets(findings),
68
+ target_type=findings.target_type or "binary"
69
+ )
70
+ return self._infer_target(findings)
71
+
72
+ def _infer_target(self, findings: ExplorationFindings) -> TargetRecommendation:
73
+ candidates = []
74
+ for name, col in findings.columns.items():
75
+ if col.inferred_type == ColumnType.IDENTIFIER:
76
+ continue
77
+ score = 0.0
78
+ rationale_parts = []
79
+ if col.inferred_type == ColumnType.BINARY:
80
+ score += 0.4
81
+ rationale_parts.append("Binary column")
82
+ if col.inferred_type == ColumnType.TARGET:
83
+ score += 0.5
84
+ rationale_parts.append("Detected as target type")
85
+ name_lower = name.lower()
86
+ for pattern in self.TARGET_PATTERNS:
87
+ if pattern in name_lower:
88
+ score += 0.3
89
+ rationale_parts.append(f"Name contains '{pattern}'")
90
+ break
91
+ distinct = col.universal_metrics.get("distinct_count", 0)
92
+ if 2 <= distinct <= 10:
93
+ score += 0.2
94
+ rationale_parts.append(f"Few distinct values ({distinct})")
95
+ if score > 0:
96
+ candidates.append((name, score, rationale_parts, col))
97
+ if not candidates:
98
+ return TargetRecommendation(
99
+ column_name="",
100
+ confidence=0.0,
101
+ rationale="No suitable target column found",
102
+ alternatives=[],
103
+ target_type="unknown"
104
+ )
105
+ candidates.sort(key=lambda x: x[1], reverse=True)
106
+ best = candidates[0]
107
+ target_type = "binary" if best[3].universal_metrics.get("distinct_count", 0) == 2 else "multiclass"
108
+ return TargetRecommendation(
109
+ column_name=best[0],
110
+ confidence=min(best[1], 1.0),
111
+ rationale="; ".join(best[2]),
112
+ alternatives=[c[0] for c in candidates[1:4]],
113
+ target_type=target_type
114
+ )
115
+
116
+ def _find_alternative_targets(self, findings: ExplorationFindings) -> List[str]:
117
+ alternatives = []
118
+ for name, col in findings.columns.items():
119
+ if name == findings.target_column:
120
+ continue
121
+ if col.inferred_type in [ColumnType.BINARY, ColumnType.TARGET]:
122
+ alternatives.append(name)
123
+ elif any(p in name.lower() for p in self.TARGET_PATTERNS):
124
+ alternatives.append(name)
125
+ return alternatives[:3]
126
+
127
+ def recommend_features(self, findings: ExplorationFindings) -> List[FeatureRecommendation]:
128
+ recommendations = []
129
+ for name, col in findings.columns.items():
130
+ if col.inferred_type == ColumnType.IDENTIFIER:
131
+ continue
132
+ if col.inferred_type == ColumnType.TARGET:
133
+ continue
134
+ recommendations.extend(self._feature_recs_for_column(name, col))
135
+ return recommendations
136
+
137
+ def _feature_recs_for_column(self, name: str, col) -> List[FeatureRecommendation]:
138
+ recs = []
139
+ if col.inferred_type == ColumnType.DATETIME:
140
+ recs.extend([
141
+ FeatureRecommendation(
142
+ source_column=name,
143
+ feature_name=f"{name}_year",
144
+ feature_type="temporal",
145
+ description=f"Extract year from {name}",
146
+ priority="medium",
147
+ implementation_hint="DatetimeTransformer.extract_year()"
148
+ ),
149
+ FeatureRecommendation(
150
+ source_column=name,
151
+ feature_name=f"{name}_month",
152
+ feature_type="temporal",
153
+ description=f"Extract month from {name}",
154
+ priority="medium",
155
+ implementation_hint="DatetimeTransformer.extract_month()"
156
+ ),
157
+ FeatureRecommendation(
158
+ source_column=name,
159
+ feature_name=f"{name}_dayofweek",
160
+ feature_type="temporal",
161
+ description=f"Extract day of week from {name}",
162
+ priority="medium",
163
+ implementation_hint="DatetimeTransformer.extract_dayofweek()"
164
+ ),
165
+ FeatureRecommendation(
166
+ source_column=name,
167
+ feature_name=f"days_since_{name}",
168
+ feature_type="datetime",
169
+ description=f"Days since {name} until today",
170
+ priority="high",
171
+ implementation_hint="DatetimeTransformer.days_since()"
172
+ )
173
+ ])
174
+ elif col.inferred_type in [ColumnType.NUMERIC_CONTINUOUS, ColumnType.NUMERIC_DISCRETE]:
175
+ recs.append(FeatureRecommendation(
176
+ source_column=name,
177
+ feature_name=f"{name}_binned",
178
+ feature_type="numeric",
179
+ description=f"Binned version of {name}",
180
+ priority="low",
181
+ implementation_hint="NumericTransformer.bin()"
182
+ ))
183
+ if col.type_metrics.get("skewness", 0) and abs(col.type_metrics.get("skewness", 0)) > self.SKEWNESS_THRESHOLD:
184
+ recs.append(FeatureRecommendation(
185
+ source_column=name,
186
+ feature_name=f"{name}_log",
187
+ feature_type="numeric",
188
+ description=f"Log transform of {name} (high skewness)",
189
+ priority="high",
190
+ implementation_hint="NumericTransformer.log_transform()"
191
+ ))
192
+ elif col.inferred_type in [ColumnType.CATEGORICAL_NOMINAL, ColumnType.CATEGORICAL_ORDINAL]:
193
+ cardinality = col.type_metrics.get("cardinality", 0)
194
+ if cardinality <= 10:
195
+ recs.append(FeatureRecommendation(
196
+ source_column=name,
197
+ feature_name=f"{name}_encoded",
198
+ feature_type="categorical",
199
+ description=f"One-hot encoded {name}",
200
+ priority="high",
201
+ implementation_hint="CategoricalEncoder.one_hot()"
202
+ ))
203
+ else:
204
+ recs.append(FeatureRecommendation(
205
+ source_column=name,
206
+ feature_name=f"{name}_target_encoded",
207
+ feature_type="categorical",
208
+ description=f"Target encoded {name}",
209
+ priority="medium",
210
+ implementation_hint="CategoricalEncoder.target_encode()"
211
+ ))
212
+ elif col.inferred_type == ColumnType.CATEGORICAL_CYCLICAL:
213
+ recs.append(FeatureRecommendation(
214
+ source_column=name,
215
+ feature_name=f"{name}_sin_cos",
216
+ feature_type="cyclical",
217
+ description=f"Cyclical encoding (sin/cos) for {name}",
218
+ priority="high",
219
+ implementation_hint="CategoricalEncoder.cyclical_encode()"
220
+ ))
221
+ return recs
222
+
223
+ def recommend_cleaning(self, findings: ExplorationFindings) -> List[CleaningRecommendation]:
224
+ recommendations = []
225
+ for name, col in findings.columns.items():
226
+ null_pct = col.universal_metrics.get("null_percentage", 0)
227
+ null_count = col.universal_metrics.get("null_count", 0)
228
+ if null_pct > self.NULL_CRITICAL_THRESHOLD:
229
+ recommendations.append(CleaningRecommendation(
230
+ column_name=name,
231
+ issue_type="missing_values",
232
+ severity="high",
233
+ strategy="drop_column_or_impute_indicator",
234
+ description=f"{null_pct:.1f}% missing values (critical)",
235
+ affected_rows=null_count,
236
+ strategy_label="Drop Column or Create Missing Indicator",
237
+ problem_impact="Models will fail or lose significant data. High missingness often indicates systematic data collection issues.",
238
+ action_steps=[
239
+ "Investigate why so much data is missing (data collection issue?)",
240
+ "If pattern-based: create binary indicator column for 'is_missing'",
241
+ "If random: consider dropping column if not critical",
242
+ "If critical: use advanced imputation (KNN, iterative)"
243
+ ]
244
+ ))
245
+ elif null_pct > self.NULL_WARNING_THRESHOLD:
246
+ is_numeric = col.inferred_type in [ColumnType.NUMERIC_CONTINUOUS, ColumnType.NUMERIC_DISCRETE]
247
+ strategy = "impute_median" if is_numeric else "impute_mode"
248
+ strategy_label = "Impute with Median" if is_numeric else "Impute with Mode"
249
+ recommendations.append(CleaningRecommendation(
250
+ column_name=name,
251
+ issue_type="missing_values",
252
+ severity="medium",
253
+ strategy=strategy,
254
+ description=f"{null_pct:.1f}% missing values",
255
+ affected_rows=null_count,
256
+ strategy_label=strategy_label,
257
+ problem_impact="May introduce bias if missing values are not random (MAR/MNAR). Model performance degradation possible.",
258
+ action_steps=[
259
+ "Check if missingness correlates with other columns (MAR pattern)",
260
+ f"{'Use median (robust to outliers)' if is_numeric else 'Use mode (most frequent value)'}",
261
+ "Consider creating additional 'is_missing' indicator feature",
262
+ "Validate imputation doesn't distort distributions"
263
+ ]
264
+ ))
265
+ elif null_count > 0:
266
+ is_numeric = col.inferred_type in [ColumnType.NUMERIC_CONTINUOUS, ColumnType.NUMERIC_DISCRETE]
267
+ strategy = "impute_median" if is_numeric else "impute_mode"
268
+ strategy_label = "Impute with Median" if is_numeric else "Impute with Mode"
269
+ recommendations.append(CleaningRecommendation(
270
+ column_name=name,
271
+ issue_type="null_values",
272
+ severity="low",
273
+ strategy=strategy,
274
+ description=f"{null_count} null values ({null_pct:.1f}%)",
275
+ affected_rows=null_count,
276
+ strategy_label=strategy_label,
277
+ problem_impact="Minor impact. Some models (XGBoost, LightGBM) handle nulls natively. Others will fail.",
278
+ action_steps=[
279
+ f"{'Impute with median for robustness' if is_numeric else 'Impute with most frequent value'}",
280
+ "Alternatively: drop rows if very few affected",
281
+ "For tree-based models: can leave as-is"
282
+ ]
283
+ ))
284
+ outlier_pct = col.type_metrics.get("outlier_percentage", 0)
285
+ if outlier_pct > self.OUTLIER_THRESHOLD:
286
+ recommendations.append(CleaningRecommendation(
287
+ column_name=name,
288
+ issue_type="outliers",
289
+ severity="medium",
290
+ strategy="clip_or_winsorize",
291
+ description=f"{outlier_pct:.1f}% outliers detected",
292
+ affected_rows=int(outlier_pct * findings.row_count / 100),
293
+ strategy_label="Clip to Bounds or Winsorize",
294
+ problem_impact="Outliers skew mean/std calculations, affect scaling, and can dominate model training. May cause unstable predictions.",
295
+ action_steps=[
296
+ "First verify if outliers are valid (high-value customers) or errors",
297
+ "If errors: remove or cap at reasonable bounds",
298
+ "If valid: clip to 1st/99th percentile (Winsorization)",
299
+ "Consider log transform if right-skewed",
300
+ "Use RobustScaler instead of StandardScaler"
301
+ ]
302
+ ))
303
+ return recommendations
304
+
305
+ def recommend_transformations(self, findings: ExplorationFindings) -> List[TransformRecommendation]:
306
+ recommendations = []
307
+ for name, col in findings.columns.items():
308
+ if col.inferred_type == ColumnType.IDENTIFIER:
309
+ continue
310
+ if col.inferred_type == ColumnType.TARGET:
311
+ continue
312
+ recommendations.extend(self._transform_recs_for_column(name, col))
313
+ return recommendations
314
+
315
+ def _transform_recs_for_column(self, name: str, col) -> List[TransformRecommendation]:
316
+ recs = []
317
+ if col.inferred_type in [ColumnType.NUMERIC_CONTINUOUS, ColumnType.NUMERIC_DISCRETE]:
318
+ skewness = col.type_metrics.get("skewness", 0)
319
+ if skewness and abs(skewness) > self.SKEWNESS_THRESHOLD:
320
+ recs.append(TransformRecommendation(
321
+ column_name=name,
322
+ transform_type="log_transform",
323
+ reason=f"High skewness ({skewness:.2f})",
324
+ parameters={"base": "natural"},
325
+ priority="high"
326
+ ))
327
+ outlier_pct = col.type_metrics.get("outlier_percentage", 0)
328
+ if outlier_pct > self.OUTLIER_THRESHOLD:
329
+ recs.append(TransformRecommendation(
330
+ column_name=name,
331
+ transform_type="robust_scaling",
332
+ reason=f"High outlier percentage ({outlier_pct:.1f}%)",
333
+ parameters={"method": "robust_scaler"},
334
+ priority="high"
335
+ ))
336
+ else:
337
+ recs.append(TransformRecommendation(
338
+ column_name=name,
339
+ transform_type="standard_scaling",
340
+ reason="Standard normalization for numeric column",
341
+ parameters={"method": "standard_scaler"},
342
+ priority="medium"
343
+ ))
344
+ elif col.inferred_type in [ColumnType.CATEGORICAL_NOMINAL, ColumnType.CATEGORICAL_ORDINAL]:
345
+ cardinality = col.type_metrics.get("cardinality", 0)
346
+ if cardinality <= 5:
347
+ recs.append(TransformRecommendation(
348
+ column_name=name,
349
+ transform_type="one_hot_encoding",
350
+ reason=f"Low cardinality ({cardinality})",
351
+ parameters={"drop_first": True},
352
+ priority="high"
353
+ ))
354
+ elif cardinality <= 20:
355
+ recs.append(TransformRecommendation(
356
+ column_name=name,
357
+ transform_type="target_encoding",
358
+ reason=f"Medium cardinality ({cardinality})",
359
+ parameters={"smoothing": 1.0},
360
+ priority="medium"
361
+ ))
362
+ else:
363
+ recs.append(TransformRecommendation(
364
+ column_name=name,
365
+ transform_type="hashing_encoding",
366
+ reason=f"High cardinality ({cardinality})",
367
+ parameters={"n_components": 8},
368
+ priority="medium"
369
+ ))
370
+ elif col.inferred_type == ColumnType.DATETIME:
371
+ recs.append(TransformRecommendation(
372
+ column_name=name,
373
+ transform_type="datetime_extraction",
374
+ reason="Extract temporal features from datetime",
375
+ parameters={"features": ["year", "month", "day", "dayofweek"]},
376
+ priority="high"
377
+ ))
378
+ elif col.inferred_type == ColumnType.BINARY:
379
+ recs.append(TransformRecommendation(
380
+ column_name=name,
381
+ transform_type="binary_encoding",
382
+ reason="Ensure binary column is 0/1",
383
+ parameters={"true_value": 1, "false_value": 0},
384
+ priority="low"
385
+ ))
386
+ return recs
387
+
388
+ def generate_summary(self, findings: ExplorationFindings) -> Dict[str, Any]:
389
+ return {
390
+ "target": self.recommend_target(findings),
391
+ "features": self.recommend_features(findings),
392
+ "cleaning": self.recommend_cleaning(findings),
393
+ "transformations": self.recommend_transformations(findings)
394
+ }
395
+
396
+ def to_markdown(self, findings: ExplorationFindings) -> str:
397
+ summary = self.generate_summary(findings)
398
+ lines = ["# Recommendations Report", ""]
399
+ lines.append("## Target Column")
400
+ target = summary["target"]
401
+ lines.append(f"**Recommended:** {target.column_name}")
402
+ lines.append(f"**Confidence:** {target.confidence:.0%}")
403
+ lines.append(f"**Rationale:** {target.rationale}")
404
+ if target.alternatives:
405
+ lines.append(f"**Alternatives:** {', '.join(target.alternatives)}")
406
+ lines.append("")
407
+ lines.append("## Feature Engineering Recommendations")
408
+ for rec in summary["features"][:10]:
409
+ lines.append(f"- **{rec.feature_name}** ({rec.priority}): {rec.description}")
410
+ lines.append("")
411
+ lines.append("## Data Cleaning Recommendations")
412
+ for rec in summary["cleaning"]:
413
+ lines.append(f"- **{rec.column_name}** [{rec.severity}]: {rec.description} → {rec.strategy}")
414
+ lines.append("")
415
+ lines.append("## Transformation Recommendations")
416
+ for rec in summary["transformations"][:10]:
417
+ lines.append(f"- **{rec.column_name}**: {rec.transform_type} ({rec.reason})")
418
+ return "\n".join(lines)
@@ -0,0 +1,26 @@
1
"""Public API for the retention-action package.

Re-exports the main entry points of each submodule: risk profiling,
intervention matching, ROI analysis, fairness analysis, report generation
and A/B test design. Two names are defined in more than one submodule and
are therefore re-exported under disambiguating aliases:
``MatcherRiskSegment`` (intervention_matcher.RiskSegment) and
``RiskIntervention`` (risk_profile.Intervention).
"""
from .ab_test_designer import ABTestDesign, ABTestDesigner, MeasurementPlan, SampleSizeResult
from .fairness_analyzer import FairnessAnalyzer, FairnessMetric, FairnessResult, GroupMetrics
from .intervention_matcher import Intervention, InterventionCatalog, InterventionMatcher, InterventionRecommendation
from .intervention_matcher import RiskSegment as MatcherRiskSegment
from .report_generator import (
    CampaignList,
    CustomerServiceReport,
    ExecutiveDashboard,
    GovernanceReport,
    ProductInsights,
    ReportGenerator,
)
from .risk_profile import CustomerRiskProfile, RiskFactor, RiskProfiler, RiskSegment, Urgency
from .risk_profile import Intervention as RiskIntervention
from .roi_analyzer import InterventionROI, OptimizationResult, ROIAnalyzer, ROIResult

__all__ = [
    "RiskProfiler", "CustomerRiskProfile", "RiskFactor", "RiskSegment", "Urgency",
    "InterventionMatcher", "InterventionCatalog", "Intervention", "InterventionRecommendation",
    "ROIAnalyzer", "ROIResult", "InterventionROI", "OptimizationResult",
    "FairnessAnalyzer", "FairnessResult", "FairnessMetric", "GroupMetrics",
    "ReportGenerator", "ExecutiveDashboard", "CampaignList", "CustomerServiceReport",
    "ProductInsights", "GovernanceReport",
    "ABTestDesigner", "ABTestDesign", "SampleSizeResult", "MeasurementPlan",
    "MatcherRiskSegment", "RiskIntervention",  # Aliases for disambiguation
]
@@ -0,0 +1,144 @@
1
+ """A/B test design for retention interventions."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from datetime import datetime, timedelta
5
+ from typing import List, Optional
6
+
7
+ import numpy as np
8
+ from scipy import stats
9
+
10
+ from customer_retention.core.compat import DataFrame, concat
11
+
12
+
13
@dataclass
class SampleSizeResult:
    """Result of ABTestDesigner.calculate_sample_size."""
    # Customers required in each arm (control and treatment sized equally).
    sample_size_per_group: int
    # Computed as sample_size_per_group * 2, i.e. assumes a two-arm test.
    total_sample_size: int
    # Expected outcome rate in the control group (proportion in [0, 1]).
    baseline_rate: float
    # Smallest absolute effect (rate reduction) the test is powered to detect.
    min_detectable_effect: float
    # Two-sided significance level used in the calculation.
    alpha: float
    # Statistical power (1 - beta) used in the calculation.
    power: float
22
+
23
@dataclass
class MeasurementPlan:
    """Metrics and events to track while an A/B test runs."""
    # Headline metric the test is powered for (e.g. "churn_rate").
    primary_metric: str
    # Additional metrics monitored alongside the primary one.
    secondary_metrics: List[str]
    # Lifecycle events to log; ABTestDesigner.design_test fills in defaults.
    tracking_events: List[str] = field(default_factory=list)
28
+
29
+
30
@dataclass
class ABTestDesign:
    """Complete design for a retention A/B test (see ABTestDesigner.design_test)."""
    test_name: str
    # Label of the control arm.
    control_name: str
    # Labels of the treatment arms (one or more).
    treatment_groups: List[str]
    # Required sample size per arm from the power analysis.
    recommended_sample_size: int
    # recommended_sample_size * (1 + number of treatment arms).
    total_required: int
    # Size of the customer pool supplied at design time.
    available_customers: int
    # True when available_customers >= total_required.
    feasible: bool
    duration_days: int
    # datetime.now() + duration_days at design time (naive local time).
    expected_completion_date: datetime
    # Column used for stratified assignment, if any.
    stratification_variable: Optional[str] = None
    measurement_plan: Optional[MeasurementPlan] = None
43
+
44
+
45
+ class ABTestDesigner:
46
+ def calculate_sample_size(self, baseline_rate: float = 0.25,
47
+ min_detectable_effect: float = 0.05,
48
+ alpha: float = 0.05, power: float = 0.80) -> SampleSizeResult:
49
+ z_alpha = stats.norm.ppf(1 - alpha / 2)
50
+ z_beta = stats.norm.ppf(power)
51
+ p1 = baseline_rate
52
+ p2 = baseline_rate - min_detectable_effect
53
+ p_pooled = (p1 + p2) / 2
54
+ numerator = (z_alpha * np.sqrt(2 * p_pooled * (1 - p_pooled)) +
55
+ z_beta * np.sqrt(p1 * (1 - p1) + p2 * (1 - p2))) ** 2
56
+ denominator = (p1 - p2) ** 2
57
+ n = int(np.ceil(numerator / denominator))
58
+ return SampleSizeResult(
59
+ sample_size_per_group=n,
60
+ total_sample_size=n * 2,
61
+ baseline_rate=baseline_rate,
62
+ min_detectable_effect=min_detectable_effect,
63
+ alpha=alpha,
64
+ power=power
65
+ )
66
+
67
+ def calculate_power(self, sample_size_per_group: int, baseline_rate: float = 0.25,
68
+ effect_size: float = 0.05, alpha: float = 0.05) -> float:
69
+ z_alpha = stats.norm.ppf(1 - alpha / 2)
70
+ p1 = baseline_rate
71
+ p2 = baseline_rate - effect_size
72
+ (p1 + p2) / 2
73
+ se = np.sqrt(p1 * (1 - p1) / sample_size_per_group + p2 * (1 - p2) / sample_size_per_group)
74
+ z = abs(p1 - p2) / se
75
+ power = stats.norm.cdf(z - z_alpha) + stats.norm.cdf(-z - z_alpha)
76
+ return float(np.clip(power, 0, 1))
77
+
78
+ def design_test(self, test_name: str, customer_pool: DataFrame,
79
+ control_name: str, treatment_names: List[str],
80
+ baseline_rate: float = 0.25, min_detectable_effect: float = 0.05,
81
+ alpha: float = 0.05, power: float = 0.80,
82
+ stratify_by: Optional[str] = None, duration_days: int = 30,
83
+ primary_metric: str = "churn_rate",
84
+ secondary_metrics: Optional[List[str]] = None) -> ABTestDesign:
85
+ sample_result = self.calculate_sample_size(
86
+ baseline_rate=baseline_rate,
87
+ min_detectable_effect=min_detectable_effect,
88
+ alpha=alpha,
89
+ power=power
90
+ )
91
+ n_groups = 1 + len(treatment_names)
92
+ total_required = sample_result.sample_size_per_group * n_groups
93
+ available = len(customer_pool)
94
+ feasible = available >= total_required
95
+ measurement_plan = MeasurementPlan(
96
+ primary_metric=primary_metric,
97
+ secondary_metrics=secondary_metrics or [],
98
+ tracking_events=["assignment", "intervention_delivered", "outcome_measured"]
99
+ )
100
+ return ABTestDesign(
101
+ test_name=test_name,
102
+ control_name=control_name,
103
+ treatment_groups=treatment_names,
104
+ recommended_sample_size=sample_result.sample_size_per_group,
105
+ total_required=total_required,
106
+ available_customers=available,
107
+ feasible=feasible,
108
+ duration_days=duration_days,
109
+ expected_completion_date=datetime.now() + timedelta(days=duration_days),
110
+ stratification_variable=stratify_by,
111
+ measurement_plan=measurement_plan
112
+ )
113
+
114
+ def generate_assignments(self, customer_pool: DataFrame, groups: List[str],
115
+ sample_size_per_group: int,
116
+ stratify_by: Optional[str] = None) -> DataFrame:
117
+ total_needed = sample_size_per_group * len(groups)
118
+ if len(customer_pool) < total_needed:
119
+ sample = customer_pool.copy()
120
+ else:
121
+ sample = customer_pool.sample(n=total_needed, random_state=42)
122
+ if stratify_by and stratify_by in sample.columns:
123
+ assignments = []
124
+ for stratum in sample[stratify_by].unique():
125
+ stratum_data = sample[sample[stratify_by] == stratum]
126
+ n_per_group = len(stratum_data) // len(groups)
127
+ shuffled = stratum_data.sample(frac=1, random_state=42)
128
+ for i, group in enumerate(groups):
129
+ start = i * n_per_group
130
+ end = start + n_per_group if i < len(groups) - 1 else len(shuffled)
131
+ group_data = shuffled.iloc[start:end].copy()
132
+ group_data["group"] = group
133
+ assignments.append(group_data)
134
+ return concat(assignments, ignore_index=True)
135
+ else:
136
+ shuffled = sample.sample(frac=1, random_state=42).reset_index(drop=True)
137
+ assignments = []
138
+ for i, group in enumerate(groups):
139
+ start = i * sample_size_per_group
140
+ end = start + sample_size_per_group
141
+ group_data = shuffled.iloc[start:end].copy()
142
+ group_data["group"] = group
143
+ assignments.append(group_data)
144
+ return concat(assignments, ignore_index=True)