churnkit 0.75.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +647 -0
  2. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +1165 -0
  3. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +961 -0
  4. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +1690 -0
  5. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +679 -0
  6. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +3305 -0
  7. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +1463 -0
  8. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +1430 -0
  9. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +854 -0
  10. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +1639 -0
  11. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +1890 -0
  12. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +1457 -0
  13. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +1624 -0
  14. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +780 -0
  15. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +979 -0
  16. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +572 -0
  17. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +1179 -0
  18. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +1418 -0
  19. churnkit-0.75.0a1.data/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +151 -0
  20. churnkit-0.75.0a1.dist-info/METADATA +229 -0
  21. churnkit-0.75.0a1.dist-info/RECORD +302 -0
  22. churnkit-0.75.0a1.dist-info/WHEEL +4 -0
  23. churnkit-0.75.0a1.dist-info/entry_points.txt +2 -0
  24. churnkit-0.75.0a1.dist-info/licenses/LICENSE +202 -0
  25. customer_retention/__init__.py +37 -0
  26. customer_retention/analysis/__init__.py +0 -0
  27. customer_retention/analysis/auto_explorer/__init__.py +62 -0
  28. customer_retention/analysis/auto_explorer/exploration_manager.py +470 -0
  29. customer_retention/analysis/auto_explorer/explorer.py +258 -0
  30. customer_retention/analysis/auto_explorer/findings.py +291 -0
  31. customer_retention/analysis/auto_explorer/layered_recommendations.py +485 -0
  32. customer_retention/analysis/auto_explorer/recommendation_builder.py +148 -0
  33. customer_retention/analysis/auto_explorer/recommendations.py +418 -0
  34. customer_retention/analysis/business/__init__.py +26 -0
  35. customer_retention/analysis/business/ab_test_designer.py +144 -0
  36. customer_retention/analysis/business/fairness_analyzer.py +166 -0
  37. customer_retention/analysis/business/intervention_matcher.py +121 -0
  38. customer_retention/analysis/business/report_generator.py +222 -0
  39. customer_retention/analysis/business/risk_profile.py +199 -0
  40. customer_retention/analysis/business/roi_analyzer.py +139 -0
  41. customer_retention/analysis/diagnostics/__init__.py +20 -0
  42. customer_retention/analysis/diagnostics/calibration_analyzer.py +133 -0
  43. customer_retention/analysis/diagnostics/cv_analyzer.py +144 -0
  44. customer_retention/analysis/diagnostics/error_analyzer.py +107 -0
  45. customer_retention/analysis/diagnostics/leakage_detector.py +394 -0
  46. customer_retention/analysis/diagnostics/noise_tester.py +140 -0
  47. customer_retention/analysis/diagnostics/overfitting_analyzer.py +190 -0
  48. customer_retention/analysis/diagnostics/segment_analyzer.py +122 -0
  49. customer_retention/analysis/discovery/__init__.py +8 -0
  50. customer_retention/analysis/discovery/config_generator.py +49 -0
  51. customer_retention/analysis/discovery/discovery_flow.py +19 -0
  52. customer_retention/analysis/discovery/type_inferencer.py +147 -0
  53. customer_retention/analysis/interpretability/__init__.py +13 -0
  54. customer_retention/analysis/interpretability/cohort_analyzer.py +185 -0
  55. customer_retention/analysis/interpretability/counterfactual.py +175 -0
  56. customer_retention/analysis/interpretability/individual_explainer.py +141 -0
  57. customer_retention/analysis/interpretability/pdp_generator.py +103 -0
  58. customer_retention/analysis/interpretability/shap_explainer.py +106 -0
  59. customer_retention/analysis/jupyter_save_hook.py +28 -0
  60. customer_retention/analysis/notebook_html_exporter.py +136 -0
  61. customer_retention/analysis/notebook_progress.py +60 -0
  62. customer_retention/analysis/plotly_preprocessor.py +154 -0
  63. customer_retention/analysis/recommendations/__init__.py +54 -0
  64. customer_retention/analysis/recommendations/base.py +158 -0
  65. customer_retention/analysis/recommendations/cleaning/__init__.py +11 -0
  66. customer_retention/analysis/recommendations/cleaning/consistency.py +107 -0
  67. customer_retention/analysis/recommendations/cleaning/deduplicate.py +94 -0
  68. customer_retention/analysis/recommendations/cleaning/impute.py +67 -0
  69. customer_retention/analysis/recommendations/cleaning/outlier.py +71 -0
  70. customer_retention/analysis/recommendations/datetime/__init__.py +3 -0
  71. customer_retention/analysis/recommendations/datetime/extract.py +149 -0
  72. customer_retention/analysis/recommendations/encoding/__init__.py +3 -0
  73. customer_retention/analysis/recommendations/encoding/categorical.py +114 -0
  74. customer_retention/analysis/recommendations/pipeline.py +74 -0
  75. customer_retention/analysis/recommendations/registry.py +76 -0
  76. customer_retention/analysis/recommendations/selection/__init__.py +3 -0
  77. customer_retention/analysis/recommendations/selection/drop_column.py +56 -0
  78. customer_retention/analysis/recommendations/transform/__init__.py +4 -0
  79. customer_retention/analysis/recommendations/transform/power.py +94 -0
  80. customer_retention/analysis/recommendations/transform/scale.py +112 -0
  81. customer_retention/analysis/visualization/__init__.py +15 -0
  82. customer_retention/analysis/visualization/chart_builder.py +2619 -0
  83. customer_retention/analysis/visualization/console.py +122 -0
  84. customer_retention/analysis/visualization/display.py +171 -0
  85. customer_retention/analysis/visualization/number_formatter.py +36 -0
  86. customer_retention/artifacts/__init__.py +3 -0
  87. customer_retention/artifacts/fit_artifact_registry.py +146 -0
  88. customer_retention/cli.py +93 -0
  89. customer_retention/core/__init__.py +0 -0
  90. customer_retention/core/compat/__init__.py +193 -0
  91. customer_retention/core/compat/detection.py +99 -0
  92. customer_retention/core/compat/ops.py +48 -0
  93. customer_retention/core/compat/pandas_backend.py +57 -0
  94. customer_retention/core/compat/spark_backend.py +75 -0
  95. customer_retention/core/components/__init__.py +11 -0
  96. customer_retention/core/components/base.py +79 -0
  97. customer_retention/core/components/components/__init__.py +13 -0
  98. customer_retention/core/components/components/deployer.py +26 -0
  99. customer_retention/core/components/components/explainer.py +26 -0
  100. customer_retention/core/components/components/feature_eng.py +33 -0
  101. customer_retention/core/components/components/ingester.py +34 -0
  102. customer_retention/core/components/components/profiler.py +34 -0
  103. customer_retention/core/components/components/trainer.py +38 -0
  104. customer_retention/core/components/components/transformer.py +36 -0
  105. customer_retention/core/components/components/validator.py +37 -0
  106. customer_retention/core/components/enums.py +33 -0
  107. customer_retention/core/components/orchestrator.py +94 -0
  108. customer_retention/core/components/registry.py +59 -0
  109. customer_retention/core/config/__init__.py +39 -0
  110. customer_retention/core/config/column_config.py +95 -0
  111. customer_retention/core/config/experiments.py +71 -0
  112. customer_retention/core/config/pipeline_config.py +117 -0
  113. customer_retention/core/config/source_config.py +83 -0
  114. customer_retention/core/utils/__init__.py +28 -0
  115. customer_retention/core/utils/leakage.py +85 -0
  116. customer_retention/core/utils/severity.py +53 -0
  117. customer_retention/core/utils/statistics.py +90 -0
  118. customer_retention/generators/__init__.py +0 -0
  119. customer_retention/generators/notebook_generator/__init__.py +167 -0
  120. customer_retention/generators/notebook_generator/base.py +55 -0
  121. customer_retention/generators/notebook_generator/cell_builder.py +49 -0
  122. customer_retention/generators/notebook_generator/config.py +47 -0
  123. customer_retention/generators/notebook_generator/databricks_generator.py +48 -0
  124. customer_retention/generators/notebook_generator/local_generator.py +48 -0
  125. customer_retention/generators/notebook_generator/project_init.py +174 -0
  126. customer_retention/generators/notebook_generator/runner.py +150 -0
  127. customer_retention/generators/notebook_generator/script_generator.py +110 -0
  128. customer_retention/generators/notebook_generator/stages/__init__.py +19 -0
  129. customer_retention/generators/notebook_generator/stages/base_stage.py +86 -0
  130. customer_retention/generators/notebook_generator/stages/s01_ingestion.py +100 -0
  131. customer_retention/generators/notebook_generator/stages/s02_profiling.py +95 -0
  132. customer_retention/generators/notebook_generator/stages/s03_cleaning.py +180 -0
  133. customer_retention/generators/notebook_generator/stages/s04_transformation.py +165 -0
  134. customer_retention/generators/notebook_generator/stages/s05_feature_engineering.py +115 -0
  135. customer_retention/generators/notebook_generator/stages/s06_feature_selection.py +97 -0
  136. customer_retention/generators/notebook_generator/stages/s07_model_training.py +176 -0
  137. customer_retention/generators/notebook_generator/stages/s08_deployment.py +81 -0
  138. customer_retention/generators/notebook_generator/stages/s09_monitoring.py +112 -0
  139. customer_retention/generators/notebook_generator/stages/s10_batch_inference.py +642 -0
  140. customer_retention/generators/notebook_generator/stages/s11_feature_store.py +348 -0
  141. customer_retention/generators/orchestration/__init__.py +23 -0
  142. customer_retention/generators/orchestration/code_generator.py +196 -0
  143. customer_retention/generators/orchestration/context.py +147 -0
  144. customer_retention/generators/orchestration/data_materializer.py +188 -0
  145. customer_retention/generators/orchestration/databricks_exporter.py +411 -0
  146. customer_retention/generators/orchestration/doc_generator.py +311 -0
  147. customer_retention/generators/pipeline_generator/__init__.py +26 -0
  148. customer_retention/generators/pipeline_generator/findings_parser.py +727 -0
  149. customer_retention/generators/pipeline_generator/generator.py +142 -0
  150. customer_retention/generators/pipeline_generator/models.py +166 -0
  151. customer_retention/generators/pipeline_generator/renderer.py +2125 -0
  152. customer_retention/generators/spec_generator/__init__.py +37 -0
  153. customer_retention/generators/spec_generator/databricks_generator.py +433 -0
  154. customer_retention/generators/spec_generator/generic_generator.py +373 -0
  155. customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +685 -0
  156. customer_retention/generators/spec_generator/pipeline_spec.py +298 -0
  157. customer_retention/integrations/__init__.py +0 -0
  158. customer_retention/integrations/adapters/__init__.py +13 -0
  159. customer_retention/integrations/adapters/base.py +10 -0
  160. customer_retention/integrations/adapters/factory.py +25 -0
  161. customer_retention/integrations/adapters/feature_store/__init__.py +6 -0
  162. customer_retention/integrations/adapters/feature_store/base.py +57 -0
  163. customer_retention/integrations/adapters/feature_store/databricks.py +94 -0
  164. customer_retention/integrations/adapters/feature_store/feast_adapter.py +97 -0
  165. customer_retention/integrations/adapters/feature_store/local.py +75 -0
  166. customer_retention/integrations/adapters/mlflow/__init__.py +6 -0
  167. customer_retention/integrations/adapters/mlflow/base.py +32 -0
  168. customer_retention/integrations/adapters/mlflow/databricks.py +54 -0
  169. customer_retention/integrations/adapters/mlflow/experiment_tracker.py +161 -0
  170. customer_retention/integrations/adapters/mlflow/local.py +50 -0
  171. customer_retention/integrations/adapters/storage/__init__.py +5 -0
  172. customer_retention/integrations/adapters/storage/base.py +33 -0
  173. customer_retention/integrations/adapters/storage/databricks.py +76 -0
  174. customer_retention/integrations/adapters/storage/local.py +59 -0
  175. customer_retention/integrations/feature_store/__init__.py +47 -0
  176. customer_retention/integrations/feature_store/definitions.py +215 -0
  177. customer_retention/integrations/feature_store/manager.py +744 -0
  178. customer_retention/integrations/feature_store/registry.py +412 -0
  179. customer_retention/integrations/iteration/__init__.py +28 -0
  180. customer_retention/integrations/iteration/context.py +212 -0
  181. customer_retention/integrations/iteration/feedback_collector.py +184 -0
  182. customer_retention/integrations/iteration/orchestrator.py +168 -0
  183. customer_retention/integrations/iteration/recommendation_tracker.py +341 -0
  184. customer_retention/integrations/iteration/signals.py +212 -0
  185. customer_retention/integrations/llm_context/__init__.py +4 -0
  186. customer_retention/integrations/llm_context/context_builder.py +201 -0
  187. customer_retention/integrations/llm_context/prompts.py +100 -0
  188. customer_retention/integrations/streaming/__init__.py +103 -0
  189. customer_retention/integrations/streaming/batch_integration.py +149 -0
  190. customer_retention/integrations/streaming/early_warning_model.py +227 -0
  191. customer_retention/integrations/streaming/event_schema.py +214 -0
  192. customer_retention/integrations/streaming/online_store_writer.py +249 -0
  193. customer_retention/integrations/streaming/realtime_scorer.py +261 -0
  194. customer_retention/integrations/streaming/trigger_engine.py +293 -0
  195. customer_retention/integrations/streaming/window_aggregator.py +393 -0
  196. customer_retention/stages/__init__.py +0 -0
  197. customer_retention/stages/cleaning/__init__.py +9 -0
  198. customer_retention/stages/cleaning/base.py +28 -0
  199. customer_retention/stages/cleaning/missing_handler.py +160 -0
  200. customer_retention/stages/cleaning/outlier_handler.py +204 -0
  201. customer_retention/stages/deployment/__init__.py +28 -0
  202. customer_retention/stages/deployment/batch_scorer.py +106 -0
  203. customer_retention/stages/deployment/champion_challenger.py +299 -0
  204. customer_retention/stages/deployment/model_registry.py +182 -0
  205. customer_retention/stages/deployment/retraining_trigger.py +245 -0
  206. customer_retention/stages/features/__init__.py +73 -0
  207. customer_retention/stages/features/behavioral_features.py +266 -0
  208. customer_retention/stages/features/customer_segmentation.py +505 -0
  209. customer_retention/stages/features/feature_definitions.py +265 -0
  210. customer_retention/stages/features/feature_engineer.py +551 -0
  211. customer_retention/stages/features/feature_manifest.py +340 -0
  212. customer_retention/stages/features/feature_selector.py +239 -0
  213. customer_retention/stages/features/interaction_features.py +160 -0
  214. customer_retention/stages/features/temporal_features.py +243 -0
  215. customer_retention/stages/ingestion/__init__.py +9 -0
  216. customer_retention/stages/ingestion/load_result.py +32 -0
  217. customer_retention/stages/ingestion/loaders.py +195 -0
  218. customer_retention/stages/ingestion/source_registry.py +130 -0
  219. customer_retention/stages/modeling/__init__.py +31 -0
  220. customer_retention/stages/modeling/baseline_trainer.py +139 -0
  221. customer_retention/stages/modeling/cross_validator.py +125 -0
  222. customer_retention/stages/modeling/data_splitter.py +205 -0
  223. customer_retention/stages/modeling/feature_scaler.py +99 -0
  224. customer_retention/stages/modeling/hyperparameter_tuner.py +107 -0
  225. customer_retention/stages/modeling/imbalance_handler.py +282 -0
  226. customer_retention/stages/modeling/mlflow_logger.py +95 -0
  227. customer_retention/stages/modeling/model_comparator.py +149 -0
  228. customer_retention/stages/modeling/model_evaluator.py +138 -0
  229. customer_retention/stages/modeling/threshold_optimizer.py +131 -0
  230. customer_retention/stages/monitoring/__init__.py +37 -0
  231. customer_retention/stages/monitoring/alert_manager.py +328 -0
  232. customer_retention/stages/monitoring/drift_detector.py +201 -0
  233. customer_retention/stages/monitoring/performance_monitor.py +242 -0
  234. customer_retention/stages/preprocessing/__init__.py +5 -0
  235. customer_retention/stages/preprocessing/transformer_manager.py +284 -0
  236. customer_retention/stages/profiling/__init__.py +256 -0
  237. customer_retention/stages/profiling/categorical_distribution.py +269 -0
  238. customer_retention/stages/profiling/categorical_target_analyzer.py +274 -0
  239. customer_retention/stages/profiling/column_profiler.py +527 -0
  240. customer_retention/stages/profiling/distribution_analysis.py +483 -0
  241. customer_retention/stages/profiling/drift_detector.py +310 -0
  242. customer_retention/stages/profiling/feature_capacity.py +507 -0
  243. customer_retention/stages/profiling/pattern_analysis_config.py +513 -0
  244. customer_retention/stages/profiling/profile_result.py +212 -0
  245. customer_retention/stages/profiling/quality_checks.py +1632 -0
  246. customer_retention/stages/profiling/relationship_detector.py +256 -0
  247. customer_retention/stages/profiling/relationship_recommender.py +454 -0
  248. customer_retention/stages/profiling/report_generator.py +520 -0
  249. customer_retention/stages/profiling/scd_analyzer.py +151 -0
  250. customer_retention/stages/profiling/segment_analyzer.py +632 -0
  251. customer_retention/stages/profiling/segment_aware_outlier.py +265 -0
  252. customer_retention/stages/profiling/target_level_analyzer.py +217 -0
  253. customer_retention/stages/profiling/temporal_analyzer.py +388 -0
  254. customer_retention/stages/profiling/temporal_coverage.py +488 -0
  255. customer_retention/stages/profiling/temporal_feature_analyzer.py +692 -0
  256. customer_retention/stages/profiling/temporal_feature_engineer.py +703 -0
  257. customer_retention/stages/profiling/temporal_pattern_analyzer.py +636 -0
  258. customer_retention/stages/profiling/temporal_quality_checks.py +278 -0
  259. customer_retention/stages/profiling/temporal_target_analyzer.py +241 -0
  260. customer_retention/stages/profiling/text_embedder.py +87 -0
  261. customer_retention/stages/profiling/text_processor.py +115 -0
  262. customer_retention/stages/profiling/text_reducer.py +60 -0
  263. customer_retention/stages/profiling/time_series_profiler.py +303 -0
  264. customer_retention/stages/profiling/time_window_aggregator.py +376 -0
  265. customer_retention/stages/profiling/type_detector.py +382 -0
  266. customer_retention/stages/profiling/window_recommendation.py +288 -0
  267. customer_retention/stages/temporal/__init__.py +166 -0
  268. customer_retention/stages/temporal/access_guard.py +180 -0
  269. customer_retention/stages/temporal/cutoff_analyzer.py +235 -0
  270. customer_retention/stages/temporal/data_preparer.py +178 -0
  271. customer_retention/stages/temporal/point_in_time_join.py +134 -0
  272. customer_retention/stages/temporal/point_in_time_registry.py +148 -0
  273. customer_retention/stages/temporal/scenario_detector.py +163 -0
  274. customer_retention/stages/temporal/snapshot_manager.py +259 -0
  275. customer_retention/stages/temporal/synthetic_coordinator.py +66 -0
  276. customer_retention/stages/temporal/timestamp_discovery.py +531 -0
  277. customer_retention/stages/temporal/timestamp_manager.py +255 -0
  278. customer_retention/stages/transformation/__init__.py +13 -0
  279. customer_retention/stages/transformation/binary_handler.py +85 -0
  280. customer_retention/stages/transformation/categorical_encoder.py +245 -0
  281. customer_retention/stages/transformation/datetime_transformer.py +97 -0
  282. customer_retention/stages/transformation/numeric_transformer.py +181 -0
  283. customer_retention/stages/transformation/pipeline.py +257 -0
  284. customer_retention/stages/validation/__init__.py +60 -0
  285. customer_retention/stages/validation/adversarial_scoring_validator.py +205 -0
  286. customer_retention/stages/validation/business_sense_gate.py +173 -0
  287. customer_retention/stages/validation/data_quality_gate.py +235 -0
  288. customer_retention/stages/validation/data_validators.py +511 -0
  289. customer_retention/stages/validation/feature_quality_gate.py +183 -0
  290. customer_retention/stages/validation/gates.py +117 -0
  291. customer_retention/stages/validation/leakage_gate.py +352 -0
  292. customer_retention/stages/validation/model_validity_gate.py +213 -0
  293. customer_retention/stages/validation/pipeline_validation_runner.py +264 -0
  294. customer_retention/stages/validation/quality_scorer.py +544 -0
  295. customer_retention/stages/validation/rule_generator.py +57 -0
  296. customer_retention/stages/validation/scoring_pipeline_validator.py +446 -0
  297. customer_retention/stages/validation/timeseries_detector.py +769 -0
  298. customer_retention/transforms/__init__.py +47 -0
  299. customer_retention/transforms/artifact_store.py +50 -0
  300. customer_retention/transforms/executor.py +157 -0
  301. customer_retention/transforms/fitted.py +92 -0
  302. customer_retention/transforms/ops.py +148 -0
@@ -0,0 +1,278 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional
3
+
4
+ from customer_retention.core.compat import DataFrame, pd
5
+ from customer_retention.core.components.enums import Severity
6
+
7
+
8
@dataclass
class TemporalQualityResult:
    """Outcome of one temporal data-quality check.

    The leading fields describe the check and its verdict; the trailing
    counters are check-specific and only populated by the check they belong
    to (e.g. ``duplicate_count`` by the duplicate-event check TQ001).
    """
    check_id: str  # stable identifier, e.g. "TQ001"
    check_name: str  # human-readable check name
    passed: bool  # whether the check passed
    severity: Severity  # severity assigned to this result
    message: str  # one-line summary of the outcome
    details: dict = field(default_factory=dict)  # check-specific evidence (examples, thresholds, ...)
    recommendation: Optional[str] = None  # suggested remediation, if any
    duplicate_count: int = 0  # TQ001: surplus (entity, timestamp) duplicate events
    gap_count: int = 0  # TQ002: number of gaps exceeding the threshold
    max_gap_days: float = 0  # TQ002: largest observed gap, in days
    future_count: int = 0  # TQ003: events dated after the reference date
    ambiguous_count: int = 0  # TQ004: events with no defined within-entity order
22
+
23
+
24
class TemporalQualityCheck:
    """Base class for the TQ-series temporal quality checks.

    Subclasses supply a stable id, display name and default severity, and
    implement :meth:`run` to examine a dataframe.
    """

    def __init__(self, check_id: str, check_name: str, severity: Severity):
        self.severity = severity
        self.check_name = check_name
        self.check_id = check_id

    def run(self, df: DataFrame) -> TemporalQualityResult:
        """Execute the check against ``df``. Must be overridden."""
        raise NotImplementedError
32
+
33
+
34
class DuplicateEventCheck(TemporalQualityCheck):
    """TQ001: flags rows sharing both entity id and timestamp."""

    def __init__(self, entity_column: str, time_column: str):
        super().__init__("TQ001", "Duplicate Events", Severity.MEDIUM)
        self.entity_column = entity_column
        self.time_column = time_column

    def run(self, df: DataFrame) -> TemporalQualityResult:
        """Count surplus events per (entity, timestamp) group."""
        if len(df) == 0:
            return self._pass_result("No data to check")

        key_cols = [self.entity_column, self.time_column]
        dup_mask = df.duplicated(subset=key_cols, keep=False)
        dup_rows = df[dup_mask]
        # Surplus events = all duplicated rows minus one "keeper" per group.
        duplicate_count = dup_mask.sum() - dup_rows.groupby(key_cols).ngroups

        if duplicate_count <= 0:
            return self._pass_result("No duplicate events found")

        examples = dup_rows.head(10)[key_cols].to_dict('records')
        return TemporalQualityResult(
            check_id=self.check_id, check_name=self.check_name, passed=False, severity=self.severity,
            message=f"Found {duplicate_count} duplicate events (same entity + timestamp)",
            details={"duplicate_examples": examples, "affected_entities": dup_rows[self.entity_column].nunique()},
            recommendation="Review duplicates - may need deduplication logic", duplicate_count=duplicate_count)

    def _pass_result(self, message: str) -> TemporalQualityResult:
        """Build a passing TQ001 result carrying ``message``."""
        return TemporalQualityResult(
            check_id=self.check_id, check_name=self.check_name, passed=True,
            severity=Severity.INFO, message=message, duplicate_count=0)
61
+
62
+
63
class TemporalGapCheck(TemporalQualityCheck):
    """TQ002: detects timeline gaps larger than a multiple of the expected cadence."""

    # Rough day-equivalents for pandas-style frequency aliases.
    FREQ_TO_DAYS = {"D": 1, "W": 7, "M": 30, "Q": 90, "Y": 365, "H": 1/24, "T": 1/1440, "min": 1/1440}

    def __init__(self, time_column: str, expected_frequency: str = "D", max_gap_multiple: float = 3.0):
        super().__init__("TQ002", "Temporal Gaps", Severity.MEDIUM)
        self.time_column = time_column
        self.expected_frequency = expected_frequency
        self.max_gap_multiple = max_gap_multiple

    def run(self, df: DataFrame) -> TemporalQualityResult:
        """Compare consecutive timestamp deltas against the gap threshold."""
        if len(df) < 2:
            return self._pass_result("Insufficient data to check gaps")

        ordered_times = pd.to_datetime(df.sort_values(self.time_column)[self.time_column])
        gaps_in_days = ordered_times.diff().dropna().dt.total_seconds() / 86400
        # Unknown frequency aliases fall back to a daily cadence.
        threshold_days = self.FREQ_TO_DAYS.get(self.expected_frequency, 1) * self.max_gap_multiple

        offenders = gaps_in_days[gaps_in_days > threshold_days]
        max_gap = float(gaps_in_days.max()) if len(gaps_in_days) > 0 else 0

        if len(offenders) == 0:
            return TemporalQualityResult(
                check_id=self.check_id, check_name=self.check_name, passed=True, severity=Severity.INFO,
                message="No significant temporal gaps detected", gap_count=0, max_gap_days=max_gap)

        return TemporalQualityResult(
            check_id=self.check_id, check_name=self.check_name, passed=False, severity=self.severity,
            message=f"Found {len(offenders)} gaps exceeding {threshold_days:.1f} days",
            details={"threshold_days": threshold_days, "expected_frequency": self.expected_frequency,
                     "gap_locations": offenders.index.tolist()[:10]},
            recommendation="Investigate data collection gaps or missing data",
            gap_count=len(offenders), max_gap_days=max_gap)

    def _pass_result(self, message: str) -> TemporalQualityResult:
        """Build a passing TQ002 result carrying ``message``."""
        return TemporalQualityResult(
            check_id=self.check_id, check_name=self.check_name, passed=True,
            severity=Severity.INFO, message=message, gap_count=0, max_gap_days=0)
101
+
102
+
103
class FutureDateCheck(TemporalQualityCheck):
    """TQ003: flags events timestamped after a reference date."""

    def __init__(self, time_column: str, reference_date: Optional[pd.Timestamp] = None):
        super().__init__("TQ003", "Future Dates", Severity.HIGH)
        self.time_column = time_column
        # Default to "now" when no explicit reference date is supplied.
        self.reference_date = reference_date or pd.Timestamp.now()

    def run(self, df: DataFrame) -> TemporalQualityResult:
        """Count events strictly later than the reference date."""
        if len(df) == 0:
            return self._pass_result("No data to check")

        parsed = pd.to_datetime(df[self.time_column])
        is_future = parsed > self.reference_date
        future_count = is_future.sum()

        if not future_count:
            return self._pass_result("No future dates detected")

        sample_dates = [str(d) for d in parsed[is_future].head(10).tolist()]
        return TemporalQualityResult(
            check_id=self.check_id, check_name=self.check_name, passed=False, severity=self.severity,
            message=f"Found {future_count} events with future dates",
            details={"reference_date": str(self.reference_date),
                     "future_date_examples": sample_dates},
            recommendation="Review data entry or timestamp handling", future_count=future_count)

    def _pass_result(self, message: str) -> TemporalQualityResult:
        """Build a passing TQ003 result carrying ``message``."""
        return TemporalQualityResult(
            check_id=self.check_id, check_name=self.check_name, passed=True,
            severity=Severity.INFO, message=message, future_count=0)
131
+
132
+
133
class EventOrderCheck(TemporalQualityCheck):
    """TQ004: reports events whose within-entity order is ambiguous (shared timestamp)."""

    def __init__(self, entity_column: str, time_column: str):
        super().__init__("TQ004", "Event Ordering", Severity.LOW)
        self.entity_column = entity_column
        self.time_column = time_column

    def run(self, df: DataFrame) -> TemporalQualityResult:
        """Count events beyond the first in each (entity, timestamp) collision group."""
        if len(df) < 2:
            return self._pass_result("Insufficient data to check ordering")

        with_time = df.assign(_parsed_time=pd.to_datetime(df[self.time_column]))
        group_sizes = with_time.groupby([self.entity_column, "_parsed_time"]).size()
        colliding = group_sizes[group_sizes > 1]
        # Events past the first in each colliding group have no defined order.
        ambiguous_count = colliding.sum() - len(colliding)

        if ambiguous_count <= 0:
            return self._pass_result("Event ordering is unambiguous")

        # Intentionally still "passed": ambiguity is informational, not a failure.
        return TemporalQualityResult(
            check_id=self.check_id, check_name=self.check_name, passed=True, severity=Severity.LOW,
            message=f"{ambiguous_count} events have ambiguous ordering (same timestamp)",
            details={"collision_groups": len(colliding), "total_ambiguous_events": int(colliding.sum())},
            recommendation="Consider adding sequence numbers for same-timestamp events",
            ambiguous_count=ambiguous_count)

    def _pass_result(self, message: str) -> TemporalQualityResult:
        """Build a passing TQ004 result carrying ``message``."""
        return TemporalQualityResult(
            check_id=self.check_id, check_name=self.check_name, passed=True,
            severity=Severity.INFO, message=message, ambiguous_count=0)
162
+
163
+
164
@dataclass
class TemporalQualityScore:
    """Aggregate 0-100 temporal quality score with a letter grade.

    Produced by TemporalQualityReporter.get_score(); ``check_scores`` holds
    the per-check breakdown dicts, ``passed``/``total`` the check tally.
    """
    score: float
    grade: str
    check_scores: list
    passed: int
    total: int

    # Un-annotated class attributes are not dataclass fields.
    _GRADE_EMOJI = {"A": "🏆", "B": "✅", "C": "⚠️", "D": "❌"}
    _GRADE_MESSAGES = {"A": "Excellent - ready for feature engineering", "B": "Good - minor issues, proceed with caution",
                       "C": "Fair - address issues before proceeding", "D": "Poor - significant investigation needed"}

    @property
    def grade_emoji(self) -> str:
        """Emoji for the grade; empty string for unknown grades."""
        return self._GRADE_EMOJI.get(self.grade, "")

    @property
    def grade_message(self) -> str:
        """Human-readable interpretation of the grade; empty for unknown grades."""
        return self._GRADE_MESSAGES.get(self.grade, "")
180
+
181
+
182
class TemporalQualityReporter:
    """Scores and pretty-prints a batch of TemporalQualityResult objects.

    Each check contributes up to 25 points (score * 0.25) toward a 0-100
    quality score, so the weighting assumes exactly four checks; with fewer
    results the maximum attainable score drops accordingly.
    """

    # Per-check ML impact notes and a one-line suggested fix, keyed by check id.
    ML_IMPACTS = {
        "TQ001": {"impacts": [("Event counts", "Inflated metrics"), ("Aggregations", "Skewed"), ("Sequences", "Artificial patterns")],
                  "fix": "df.drop_duplicates(subset=[entity, time], keep='first')"},
        "TQ002": {"impacts": [("Rolling features", "Low during gaps"), ("Recency", "Inflated"), ("Seasonality", "Distorted")],
                  "fix": "Document gaps; add df['has_gap'] indicator"},
        "TQ003": {"impacts": [("Data leakage", "Future in training"), ("Time splits", "Broken"), ("Recency", "Negative values")],
                  "fix": "df = df[df[time_col] <= reference_date]"},
        "TQ004": {"impacts": [("Sequences", "Undefined order"), ("State tracking", "Ambiguous"), ("Lags", "Unclear")],
                  "fix": "Add sequence: df['seq'] = df.groupby(entity).cumcount()"}
    }

    def __init__(self, results: list, total_rows: int):
        """Store check results and row count, then compute scores eagerly.

        results: list of TemporalQualityResult; for to_dict() the expected
        order is [TQ001, TQ002, TQ003, TQ004] (it indexes positionally).
        total_rows: row count used to express issue counts as percentages.
        """
        self.results = results
        self.total_rows = total_rows
        self._calculate_scores()

    def _calculate_scores(self):
        """Derive per-check scores, the overall quality score, grade and pass tally."""
        self.check_scores = []
        for r in self.results:
            # Each result populates only its own counter, so the or-chain
            # picks the first (and only) nonzero one; 0 if all are zero.
            issue_count = r.duplicate_count or r.gap_count or r.future_count or r.ambiguous_count or 0
            score = self._score_from_issues(issue_count, self.total_rows)
            pct = (issue_count / self.total_rows * 100) if self.total_rows > 0 else 0
            self.check_scores.append({
                "check_id": r.check_id, "name": r.check_name, "result": r,
                "issues": issue_count, "pct": pct, "score": score, "contribution": score * 0.25})
        self.quality_score = sum(c["contribution"] for c in self.check_scores)
        self.grade = "A" if self.quality_score >= 90 else "B" if self.quality_score >= 75 else "C" if self.quality_score >= 60 else "D"
        self.passed = sum(1 for r in self.results if r.passed)

    def _score_from_issues(self, issues: int, total: int) -> float:
        """Map an issue count to a 0-100 score via a piecewise penalty curve.

        Penalties steepen as the affected percentage grows; the bands are
        deliberately discontinuous at their boundaries.
        """
        if total == 0 or issues == 0:
            return 100.0
        pct = (issues / total) * 100
        if pct < 0.1:
            return 99.0
        if pct < 1.0:
            return 95.0 - (pct * 5)
        if pct < 5.0:
            return 90.0 - (pct * 4)
        if pct < 20.0:
            return 70.0 - (pct * 2)
        return max(0, 30.0 - pct)

    def get_score(self) -> TemporalQualityScore:
        """Package the computed score, grade and breakdown as a TemporalQualityScore."""
        return TemporalQualityScore(
            score=self.quality_score, grade=self.grade,
            check_scores=self.check_scores, passed=self.passed, total=len(self.results))

    def print_results(self):
        """Print a per-check report with severity icons and ML-impact guidance."""
        severity_icons = {Severity.HIGH: "🔴", Severity.MEDIUM: "🟠", Severity.LOW: "🟡", Severity.INFO: "🔵"}
        print("=" * 70 + "\nTEMPORAL QUALITY CHECK RESULTS\n" + "=" * 70)
        print(f"\n📋 Summary: {self.passed}/{len(self.results)} checks passed\n")

        for c in self.check_scores:
            r = c["result"]
            print(f"{'✅' if r.passed else '❌'} [{r.check_id}] {r.check_name}")
            print(f" {severity_icons.get(r.severity, '⚪')} Severity: {r.severity.value} | {r.message}")

            # Prefer the canned ML-impact writeup when issues exist; otherwise
            # fall back to the result's own recommendation, if any.
            if c["issues"] > 0 and r.check_id in self.ML_IMPACTS:
                impact = self.ML_IMPACTS[r.check_id]
                print(f"\n 📊 Impact ({c['issues']:,} issues = {c['pct']:.2f}%):")
                for area, problem in impact["impacts"]:
                    print(f" • {area}: {problem}")
                print(f" 🛠️ Fix: {impact['fix']}")
            elif r.recommendation:
                print(f" 💡 {r.recommendation}")
            print()

    def print_score(self, bar_width: int = 40):
        """Print the overall score with an ASCII bar plus per-check contribution bars.

        bar_width: character width of the total-score bar (per-check bars are 20 wide).
        """
        grade_emoji = {"A": "🏆", "B": "✅", "C": "⚠️", "D": "❌"}[self.grade]
        print("\n" + "=" * 70)
        print(f"QUALITY SCORE: {self.quality_score:.0f}/100 {grade_emoji} Grade {self.grade}\n" + "=" * 70)

        filled = int((self.quality_score / 100) * bar_width)
        print(f"\n Total: [{'█' * filled}{'░' * (bar_width - filled)}] {self.quality_score:.0f}%\n")

        for c in self.check_scores:
            # Each check's contribution is out of 25, rendered on a 20-char bar.
            filled = int((c["contribution"] / 25) * 20)
            bar = f"[{'█' * filled}{'░' * (20 - filled)}] {c['contribution']:.1f}/25"
            status = "✓" if c["issues"] == 0 else "△" if c["pct"] < 1 else "✗"
            issues_str = f"{c['issues']:,} issues" if c["issues"] > 0 else "no issues"
            print(f" {status} {c['name']:<18} {bar} ({issues_str})")

        grade_messages = {"A": "Excellent - ready for feature engineering", "B": "Good - minor issues, proceed with caution",
                          "C": "Fair - address issues before proceeding", "D": "Poor - significant investigation needed"}
        print(f"\n Grade {self.grade}: {grade_messages[self.grade]}")

    def to_dict(self) -> dict:
        """Serialize score, grade, tallies and issue counters to a plain dict.

        NOTE: indexes self.results positionally, assuming the order
        [TQ001 duplicates, TQ002 gaps, TQ003 future dates, TQ004 ordering];
        missing trailing results default to 0.
        """
        return {
            "temporal_quality_score": self.quality_score, "temporal_quality_grade": self.grade,
            "checks_passed": self.passed, "checks_total": len(self.results),
            "issues": {
                "duplicate_events": self.results[0].duplicate_count if len(self.results) > 0 else 0,
                "temporal_gaps": self.results[1].gap_count if len(self.results) > 1 else 0,
                "future_dates": self.results[2].future_count if len(self.results) > 2 else 0,
                "ambiguous_ordering": self.results[3].ambiguous_count if len(self.results) > 3 else 0}}
@@ -0,0 +1,241 @@
1
+ """Temporal feature analysis with respect to a binary target."""
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from customer_retention.core.compat import DataFrame, to_pandas
9
+
10
+
11
@dataclass
class TemporalTargetResult:
    """Results from temporal-target analysis.

    Each per-period DataFrame carries a ``lift`` column computed relative to
    ``overall_rate``. All DataFrames are empty (and dates are NaT) when the
    analysis had no valid input rows.
    """
    datetime_col: str      # name of the analyzed datetime column
    target_col: str        # name of the binary target column
    min_date: pd.Timestamp  # earliest parsed date; NaT when no valid dates
    max_date: pd.Timestamp  # latest parsed date; NaT when no valid dates
    n_valid_dates: int     # rows with a parseable date and non-null target
    overall_rate: float    # mean of the target over the valid rows

    # Yearly analysis
    yearly_stats: pd.DataFrame  # columns: period (year), retained_count, count, retention_rate, lift
    yearly_trend: str  # 'improving', 'declining', 'stable'

    # Monthly analysis (seasonality)
    monthly_stats: pd.DataFrame  # columns: month, retained_count, count, retention_rate, lift, month_name
    best_month: Optional[str]   # month name with the highest retention rate; None when no monthly stats
    worst_month: Optional[str]  # month name with the lowest retention rate; None when no monthly stats
    seasonal_spread: float  # difference between best and worst monthly retention rates

    # Day of week analysis
    dow_stats: pd.DataFrame  # columns: day_of_week, retained_count, count, retention_rate, lift, day_name

    # Quarterly analysis
    quarterly_stats: pd.DataFrame  # columns: period (quarter), retained_count, count, retention_rate, lift
36
+
37
+
38
class TemporalTargetAnalyzer:
    """Analyzes relationship between datetime features and binary target.

    Computes retention rates by:
    - Year (cohort analysis)
    - Month (seasonality)
    - Day of week (weekly patterns)
    - Quarter

    Yearly, monthly and quarterly tables drop periods with fewer than
    ``min_samples_per_period`` rows; day-of-week stats are never filtered.
    """

    MONTH_NAMES = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                   'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
    DOW_NAMES = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']

    # Absolute yearly slope (rate units per year) beyond which the trend is
    # classified as improving/declining instead of stable.
    TREND_SLOPE_THRESHOLD = 0.02

    def __init__(self, min_samples_per_period: int = 10):
        self.min_samples_per_period = min_samples_per_period

    def analyze(
        self,
        df: DataFrame,
        datetime_col: str,
        target_col: str
    ) -> TemporalTargetResult:
        """Analyze relationship between datetime feature and binary target.

        Returns an empty result (NaT dates, empty tables) when the input is
        empty, a column is missing, or no date parses successfully.
        """
        df = to_pandas(df)

        if len(df) == 0 or datetime_col not in df.columns or target_col not in df.columns:
            return self._empty_result(datetime_col, target_col)

        # Parse dates; rows with unparseable dates or missing targets are dropped.
        df_clean = df[[datetime_col, target_col]].copy()
        df_clean[datetime_col] = pd.to_datetime(df_clean[datetime_col], errors='coerce')
        df_clean = df_clean.dropna()

        if len(df_clean) == 0:
            return self._empty_result(datetime_col, target_col)

        # Baseline rate used for the per-period lift columns.
        overall_rate = df_clean[target_col].mean()

        # Extract temporal components once; all groupbys below reuse them.
        df_clean['_year'] = df_clean[datetime_col].dt.year
        df_clean['_month'] = df_clean[datetime_col].dt.month
        df_clean['_quarter'] = df_clean[datetime_col].dt.quarter
        df_clean['_dow'] = df_clean[datetime_col].dt.dayofweek

        # Calculate stats by time period
        yearly_stats = self._calculate_period_stats(df_clean, '_year', target_col, overall_rate)
        monthly_stats = self._calculate_monthly_stats(df_clean, target_col, overall_rate)
        quarterly_stats = self._calculate_period_stats(df_clean, '_quarter', target_col, overall_rate)
        dow_stats = self._calculate_dow_stats(df_clean, target_col, overall_rate)

        yearly_trend = self._determine_yearly_trend(yearly_stats)
        best_month, worst_month, seasonal_spread = self._find_seasonal_extremes(monthly_stats)

        return TemporalTargetResult(
            datetime_col=datetime_col,
            target_col=target_col,
            min_date=df_clean[datetime_col].min(),
            max_date=df_clean[datetime_col].max(),
            n_valid_dates=len(df_clean),
            overall_rate=overall_rate,
            yearly_stats=yearly_stats,
            yearly_trend=yearly_trend,
            monthly_stats=monthly_stats,
            best_month=best_month,
            worst_month=worst_month,
            seasonal_spread=seasonal_spread,
            dow_stats=dow_stats,
            quarterly_stats=quarterly_stats
        )

    def _calculate_period_stats(
        self,
        df: pd.DataFrame,
        period_col: str,
        target_col: str,
        overall_rate: float
    ) -> pd.DataFrame:
        """Retention stats per value of `period_col`.

        Returns columns: period, retained_count, count, retention_rate, lift,
        sorted by period. Periods with fewer than ``min_samples_per_period``
        rows are dropped.
        """
        stats = df.groupby(period_col)[target_col].agg(['sum', 'count', 'mean']).reset_index()
        stats.columns = ['period', 'retained_count', 'count', 'retention_rate']
        # Lift > 1 means the period beats the overall rate; guard the zero
        # baseline to avoid division by zero.
        stats['lift'] = stats['retention_rate'] / overall_rate if overall_rate > 0 else 0

        # Filter small samples
        stats = stats[stats['count'] >= self.min_samples_per_period]

        return stats.sort_values('period').reset_index(drop=True)

    def _calculate_monthly_stats(
        self,
        df: pd.DataFrame,
        target_col: str,
        overall_rate: float
    ) -> pd.DataFrame:
        """Monthly retention stats with a human-readable `month_name` column.

        Months with fewer than ``min_samples_per_period`` rows are dropped.
        """
        stats = df.groupby('_month')[target_col].agg(['sum', 'count', 'mean']).reset_index()
        stats.columns = ['month', 'retained_count', 'count', 'retention_rate']
        stats['lift'] = stats['retention_rate'] / overall_rate if overall_rate > 0 else 0
        stats['month_name'] = stats['month'].apply(
            lambda x: self.MONTH_NAMES[int(x) - 1] if 1 <= x <= 12 else 'Unknown'
        )

        # Filter small samples
        stats = stats[stats['count'] >= self.min_samples_per_period]

        return stats.sort_values('month').reset_index(drop=True)

    def _calculate_dow_stats(
        self,
        df: pd.DataFrame,
        target_col: str,
        overall_rate: float
    ) -> pd.DataFrame:
        """Day-of-week retention stats with a `day_name` column.

        Unlike the other period tables, no minimum-sample filter is applied:
        there are at most seven buckets.
        """
        stats = df.groupby('_dow')[target_col].agg(['sum', 'count', 'mean']).reset_index()
        stats.columns = ['day_of_week', 'retained_count', 'count', 'retention_rate']
        stats['lift'] = stats['retention_rate'] / overall_rate if overall_rate > 0 else 0
        stats['day_name'] = stats['day_of_week'].apply(
            lambda x: self.DOW_NAMES[int(x)] if 0 <= x <= 6 else 'Unknown'
        )

        return stats.sort_values('day_of_week').reset_index(drop=True)

    def _determine_yearly_trend(self, yearly_stats: pd.DataFrame) -> str:
        """Classify the year-over-year retention trend.

        Fits a least-squares line through the yearly retention rates
        (in table order, i.e. sorted by year) and labels slopes steeper than
        ``TREND_SLOPE_THRESHOLD`` per year as 'improving'/'declining'.
        """
        if len(yearly_stats) < 2:
            return 'stable'

        rates = yearly_stats['retention_rate'].values

        # Simple linear regression over year index.
        slope = np.polyfit(range(len(rates)), rates, 1)[0]

        if slope > self.TREND_SLOPE_THRESHOLD:
            return 'improving'
        if slope < -self.TREND_SLOPE_THRESHOLD:
            return 'declining'
        return 'stable'

    def _find_seasonal_extremes(
        self,
        monthly_stats: pd.DataFrame
    ) -> tuple:
        """Return (best_month_name, worst_month_name, spread).

        `spread` is the retention-rate gap between the best and worst
        months; (None, None, 0.0) when no monthly stats survived filtering.
        """
        if len(monthly_stats) == 0:
            return None, None, 0.0

        best_idx = monthly_stats['retention_rate'].idxmax()
        worst_idx = monthly_stats['retention_rate'].idxmin()

        best_month = monthly_stats.loc[best_idx, 'month_name']
        worst_month = monthly_stats.loc[worst_idx, 'month_name']
        spread = monthly_stats.loc[best_idx, 'retention_rate'] - monthly_stats.loc[worst_idx, 'retention_rate']

        return best_month, worst_month, float(spread)

    def _empty_result(self, datetime_col: str, target_col: str) -> TemporalTargetResult:
        """Return a neutral result for empty/invalid input edge cases."""
        empty_df = pd.DataFrame()

        return TemporalTargetResult(
            datetime_col=datetime_col,
            target_col=target_col,
            min_date=pd.NaT,
            max_date=pd.NaT,
            n_valid_dates=0,
            overall_rate=0.0,
            yearly_stats=empty_df,
            yearly_trend='stable',
            monthly_stats=empty_df,
            best_month=None,
            worst_month=None,
            seasonal_spread=0.0,
            dow_stats=empty_df,
            quarterly_stats=empty_df
        )

    def analyze_multiple(
        self,
        df: DataFrame,
        datetime_cols: List[str],
        target_col: str
    ) -> pd.DataFrame:
        """Analyze several datetime columns and return a one-row-per-column
        summary (feature, n_valid, yearly_trend, best/worst month, spread)."""
        results = []
        for col in datetime_cols:
            result = self.analyze(df, col, target_col)
            results.append({
                'feature': col,
                'n_valid': result.n_valid_dates,
                'yearly_trend': result.yearly_trend,
                'best_month': result.best_month,
                'worst_month': result.worst_month,
                'seasonal_spread': result.seasonal_spread
            })

        return pd.DataFrame(results)
@@ -0,0 +1,87 @@
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ import numpy as np
4
+
5
+ from customer_retention.core.compat import DataFrame
6
+
7
+ EMBEDDING_MODELS: Dict[str, Dict[str, Any]] = {
8
+ "minilm": {
9
+ "model_name": "all-MiniLM-L6-v2",
10
+ "embedding_dim": 384,
11
+ "size_mb": 90,
12
+ "description": "Fast, lightweight model. Good for CPU and quick experimentation.",
13
+ "gpu_recommended": False,
14
+ },
15
+ "qwen3-0.6b": {
16
+ "model_name": "Qwen/Qwen3-Embedding-0.6B",
17
+ "embedding_dim": 1024,
18
+ "size_mb": 1200,
19
+ "description": "Higher quality embeddings, multilingual. Requires GPU for reasonable speed.",
20
+ "gpu_recommended": True,
21
+ },
22
+ "qwen3-4b": {
23
+ "model_name": "Qwen/Qwen3-Embedding-4B",
24
+ "embedding_dim": 2560,
25
+ "size_mb": 8000,
26
+ "description": "High quality, large model. Requires significant GPU memory (16GB+).",
27
+ "gpu_recommended": True,
28
+ },
29
+ "qwen3-8b": {
30
+ "model_name": "Qwen/Qwen3-Embedding-8B",
31
+ "embedding_dim": 4096,
32
+ "size_mb": 16000,
33
+ "description": "Highest quality, very large model. Requires 32GB+ GPU memory.",
34
+ "gpu_recommended": True,
35
+ },
36
+ }
37
+
38
+
39
+ def get_model_info(preset: str) -> Dict[str, Any]:
40
+ if preset not in EMBEDDING_MODELS:
41
+ raise ValueError(f"Unknown preset: {preset}. Available: {list(EMBEDDING_MODELS.keys())}")
42
+ return EMBEDDING_MODELS[preset].copy()
43
+
44
+
45
+ def list_available_models() -> List[str]:
46
+ return list(EMBEDDING_MODELS.keys())
47
+
48
+
49
class TextEmbedder:
    """Lazy wrapper around a ``sentence_transformers.SentenceTransformer``.

    Construction is cheap: the underlying model (and the
    ``sentence_transformers`` import itself) is deferred until the ``model``
    property is first accessed.
    """

    DEFAULT_MODEL = "all-MiniLM-L6-v2"

    def __init__(self, model_name: str = DEFAULT_MODEL):
        self.model_name = model_name
        self._model = None  # populated lazily by the `model` property

    @classmethod
    def from_preset(cls, preset: str) -> "TextEmbedder":
        """Alternate constructor resolving a preset key via ``get_model_info``.

        Raises:
            ValueError: if *preset* is not a known preset name.
        """
        # Delegate to get_model_info so preset validation (and its error
        # message) is defined in exactly one place instead of duplicated here.
        return cls(model_name=get_model_info(preset)["model_name"])

    @property
    def model(self):
        """The loaded SentenceTransformer, instantiated on first access."""
        if self._model is None:
            from sentence_transformers import SentenceTransformer
            self._model = SentenceTransformer(self.model_name)
        return self._model

    @property
    def embedding_dim(self) -> int:
        """Dimensionality of the embedding vectors (triggers model load)."""
        return self.model.get_sentence_embedding_dimension()

    def embed(self, texts: List[Optional[str]], batch_size: int = 32,
              show_progress: bool = False) -> np.ndarray:
        """Embed a list of texts into a 2-D array of shape (len(texts), dim).

        None, non-string and whitespace-only entries are embedded as the
        empty string, so the output row count always matches ``len(texts)``.
        """
        clean_texts = [self._clean_text(t) for t in texts]
        return self.model.encode(clean_texts, batch_size=batch_size,
                                 show_progress_bar=show_progress)

    def embed_column(self, df: DataFrame, column: str, batch_size: int = 32) -> np.ndarray:
        """Embed one text column of *df*; missing values become empty strings."""
        texts = df[column].fillna("").astype(str).tolist()
        return self.embed(texts, batch_size=batch_size)

    def _clean_text(self, text: Optional[str]) -> str:
        """Map None / non-str / whitespace-only input to ""; otherwise return
        the text unchanged (no stripping of surrounding whitespace)."""
        if not isinstance(text, str) or not text.strip():
            return ""
        return text