churnkit 0.75.1a3__py3-none-any.whl → 0.76.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {churnkit-0.75.1a3.dist-info → churnkit-0.76.0a1.dist-info}/METADATA +5 -2
  2. {churnkit-0.75.1a3.dist-info → churnkit-0.76.0a1.dist-info}/RECORD +41 -40
  3. customer_retention/__init__.py +11 -1
  4. customer_retention/core/compat/__init__.py +3 -0
  5. customer_retention/core/config/__init__.py +43 -8
  6. customer_retention/core/config/experiments.py +20 -0
  7. customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +222 -149
  8. customer_retention/integrations/adapters/factory.py +8 -5
  9. customer_retention/integrations/adapters/feature_store/base.py +1 -0
  10. customer_retention/integrations/adapters/feature_store/databricks.py +58 -10
  11. customer_retention/integrations/adapters/mlflow/base.py +8 -0
  12. customer_retention/integrations/adapters/mlflow/databricks.py +15 -2
  13. customer_retention/integrations/adapters/mlflow/local.py +7 -0
  14. customer_retention/integrations/databricks_init.py +141 -0
  15. customer_retention/stages/profiling/temporal_feature_analyzer.py +3 -3
  16. customer_retention/stages/profiling/temporal_feature_engineer.py +2 -2
  17. customer_retention/stages/profiling/temporal_pattern_analyzer.py +4 -3
  18. customer_retention/stages/profiling/time_series_profiler.py +5 -4
  19. customer_retention/stages/profiling/time_window_aggregator.py +3 -2
  20. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +0 -0
  21. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +0 -0
  22. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +0 -0
  23. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +0 -0
  24. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +0 -0
  25. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +0 -0
  26. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +0 -0
  27. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +0 -0
  28. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +0 -0
  29. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +0 -0
  30. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +0 -0
  31. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +0 -0
  32. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +0 -0
  33. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +0 -0
  34. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +0 -0
  35. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +0 -0
  36. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +0 -0
  37. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +0 -0
  38. {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +0 -0
  39. {churnkit-0.75.1a3.dist-info → churnkit-0.76.0a1.dist-info}/WHEEL +0 -0
  40. {churnkit-0.75.1a3.dist-info → churnkit-0.76.0a1.dist-info}/entry_points.txt +0 -0
  41. {churnkit-0.75.1a3.dist-info → churnkit-0.76.0a1.dist-info}/licenses/LICENSE +0 -0
customer_retention/generators/spec_generator/mlflow_pipeline_generator.py
@@ -84,6 +84,9 @@ class MLflowConfig:
  log_feature_importance: bool = True
  nested_runs: bool = True
  model_name: Optional[str] = None
+ databricks: bool = False
+ catalog: str = "main"
+ schema: str = "default"


  class MLflowPipelineGenerator:
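The three new MLflowConfig fields switch the generator into Databricks mode: databricks toggles Databricks-aware code generation, while catalog and schema name the Unity Catalog location used when the best model is registered. A minimal usage sketch follows; the import path and keyword-argument construction are assumptions inferred from this diff, not documented API.

    # Hypothetical sketch only: field names come from the diff above; the import
    # path and constructor style are assumptions.
    from customer_retention.generators.spec_generator.mlflow_pipeline_generator import MLflowConfig

    config = MLflowConfig(
        databricks=True,           # emit Databricks-aware setup and registration code
        catalog="main",            # Unity Catalog catalog for model registration
        schema="default",          # Unity Catalog schema for model registration
        model_name="churn_model",  # needed for the generated registration step
    )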
@@ -106,14 +109,16 @@ class MLflowPipelineGenerator:
  if self.mlflow_config.log_data_quality:
  sections.append(self._generate_data_quality_logging())

- sections.extend([
- self.generate_cleaning_functions(findings),
- self.generate_transform_functions(findings),
- self.generate_feature_engineering(findings),
- self.generate_model_training(findings),
- self.generate_monitoring(findings),
- self._generate_main(findings),
- ])
+ sections.extend(
+ [
+ self.generate_cleaning_functions(findings),
+ self.generate_transform_functions(findings),
+ self.generate_feature_engineering(findings),
+ self.generate_model_training(findings),
+ self.generate_monitoring(findings),
+ self._generate_main(findings),
+ ]
+ )
  return "\n\n".join(sections)

  def _generate_docstring(self, findings: ExplorationFindings) -> str:
@@ -122,13 +127,13 @@ MLflow-tracked ML Pipeline
  Generated from exploration findings

  Source: {findings.source_path}
- Target: {findings.target_column or 'Not specified'}
+ Target: {findings.target_column or "Not specified"}
  Rows: {findings.row_count:,}
  Features: {findings.column_count}
  """'''

  def _generate_imports(self) -> str:
- return """import pandas as pd
+ base_imports = """import pandas as pd
  import numpy as np
  from datetime import datetime
  from typing import Dict, List, Tuple, Any
@@ -146,8 +151,20 @@ from sklearn.metrics import (
  accuracy_score, precision_score, recall_score, f1_score,
  roc_auc_score, classification_report, confusion_matrix
  )"""
+ if self.mlflow_config.databricks:
+ base_imports += "\nfrom mlflow.tracking import MlflowClient"
+ return base_imports

  def _generate_mlflow_setup(self) -> str:
+ if self.mlflow_config.databricks:
+ return f'''
+ EXPERIMENT_NAME = "{self.mlflow_config.experiment_name}"
+
+
+ def setup_mlflow():
+ """Initialize MLflow tracking (Databricks auto-configures tracking URI)."""
+ mlflow.set_experiment(EXPERIMENT_NAME)
+ return mlflow.get_experiment_by_name(EXPERIMENT_NAME)'''
  return f'''
  MLFLOW_TRACKING_URI = "{self.mlflow_config.tracking_uri}"
  EXPERIMENT_NAME = "{self.mlflow_config.experiment_name}"
@@ -194,11 +211,13 @@ def log_data_quality_metrics(df: pd.DataFrame, prefix: str = "data"):
  for action in actions:
  code_lines.extend(self._action_to_cleaning_code(col_name, action))

- code_lines.extend([
- "",
- " mlflow.log_params({f'cleaned_{k}': v for k, v in cleaning_stats.items()})",
- " return df",
- ])
+ code_lines.extend(
+ [
+ "",
+ " mlflow.log_params({f'cleaned_{k}': v for k, v in cleaning_stats.items()})",
+ " return df",
+ ]
+ )

  return "\n".join(code_lines)

@@ -224,65 +243,73 @@ def log_data_quality_metrics(df: pd.DataFrame, prefix: str = "data"):

  if action.action_type == "impute":
  if action.strategy == "median":
- lines.extend([
- f" # Impute {col_name} with median",
- f" if df['{col_name}'].isna().any():",
- f" median_val = df['{col_name}'].median()",
- f" cleaning_stats['{col_name}_imputed'] = df['{col_name}'].isna().sum()",
- f" df['{col_name}'] = df['{col_name}'].fillna(median_val)",
- "",
- ])
+ lines.extend(
+ [
+ f" # Impute {col_name} with median",
+ f" if df['{col_name}'].isna().any():",
+ f" median_val = df['{col_name}'].median()",
+ f" cleaning_stats['{col_name}_imputed'] = df['{col_name}'].isna().sum()",
+ f" df['{col_name}'] = df['{col_name}'].fillna(median_val)",
+ "",
+ ]
+ )
  elif action.strategy == "mode":
- lines.extend([
- f" # Impute {col_name} with mode",
- f" if df['{col_name}'].isna().any():",
- f" mode_val = df['{col_name}'].mode().iloc[0] if not df['{col_name}'].mode().empty else None",
- " if mode_val is not None:",
- f" cleaning_stats['{col_name}_imputed'] = df['{col_name}'].isna().sum()",
- f" df['{col_name}'] = df['{col_name}'].fillna(mode_val)",
- "",
- ])
+ lines.extend(
+ [
+ f" # Impute {col_name} with mode",
+ f" if df['{col_name}'].isna().any():",
+ f" mode_val = df['{col_name}'].mode().iloc[0] if not df['{col_name}'].mode().empty else None",
+ " if mode_val is not None:",
+ f" cleaning_stats['{col_name}_imputed'] = df['{col_name}'].isna().sum()",
+ f" df['{col_name}'] = df['{col_name}'].fillna(mode_val)",
+ "",
+ ]
+ )
  elif action.strategy == "constant":
  fill_value = action.params.get("fill_value", 0)
- lines.extend([
- f" # Impute {col_name} with constant",
- f" if df['{col_name}'].isna().any():",
- f" cleaning_stats['{col_name}_imputed'] = df['{col_name}'].isna().sum()",
- f" df['{col_name}'] = df['{col_name}'].fillna({repr(fill_value)})",
- "",
- ])
+ lines.extend(
+ [
+ f" # Impute {col_name} with constant",
+ f" if df['{col_name}'].isna().any():",
+ f" cleaning_stats['{col_name}_imputed'] = df['{col_name}'].isna().sum()",
+ f" df['{col_name}'] = df['{col_name}'].fillna({repr(fill_value)})",
+ "",
+ ]
+ )

  elif action.action_type == "cap_outliers":
  percentile = action.params.get("percentile", 99)
- lines.extend([
- f" # Cap outliers in {col_name} at {percentile}th percentile",
- f" lower = df['{col_name}'].quantile({(100 - percentile) / 100})",
- f" upper = df['{col_name}'].quantile({percentile / 100})",
- f" outliers = ((df['{col_name}'] < lower) | (df['{col_name}'] > upper)).sum()",
- f" cleaning_stats['{col_name}_outliers_capped'] = outliers",
- f" df['{col_name}'] = df['{col_name}'].clip(lower, upper)",
- "",
- ])
+ lines.extend(
+ [
+ f" # Cap outliers in {col_name} at {percentile}th percentile",
+ f" lower = df['{col_name}'].quantile({(100 - percentile) / 100})",
+ f" upper = df['{col_name}'].quantile({percentile / 100})",
+ f" outliers = ((df['{col_name}'] < lower) | (df['{col_name}'] > upper)).sum()",
+ f" cleaning_stats['{col_name}_outliers_capped'] = outliers",
+ f" df['{col_name}'] = df['{col_name}'].clip(lower, upper)",
+ "",
+ ]
+ )

  elif action.action_type == "drop_rare":
  threshold = action.params.get("threshold_percent", 5)
- lines.extend([
- f" # Drop rare categories in {col_name} (< {threshold}%)",
- f" value_counts = df['{col_name}'].value_counts(normalize=True)",
- f" rare_values = value_counts[value_counts < {threshold / 100}].index",
- " if len(rare_values) > 0:",
- f" cleaning_stats['{col_name}_rare_dropped'] = len(rare_values)",
- f" df.loc[df['{col_name}'].isin(rare_values), '{col_name}'] = df['{col_name}'].mode().iloc[0]",
- "",
- ])
+ lines.extend(
+ [
+ f" # Drop rare categories in {col_name} (< {threshold}%)",
+ f" value_counts = df['{col_name}'].value_counts(normalize=True)",
+ f" rare_values = value_counts[value_counts < {threshold / 100}].index",
+ " if len(rare_values) > 0:",
+ f" cleaning_stats['{col_name}_rare_dropped'] = len(rare_values)",
+ f" df.loc[df['{col_name}'].isin(rare_values), '{col_name}'] = df['{col_name}'].mode().iloc[0]",
+ "",
+ ]
+ )

  return lines

  def generate_transform_functions(self, findings: ExplorationFindings) -> str:
- self._get_columns_by_type(findings,
- [ColumnType.NUMERIC_CONTINUOUS, ColumnType.NUMERIC_DISCRETE])
- self._get_columns_by_type(findings,
- [ColumnType.CATEGORICAL_NOMINAL, ColumnType.CATEGORICAL_ORDINAL])
+ self._get_columns_by_type(findings, [ColumnType.NUMERIC_CONTINUOUS, ColumnType.NUMERIC_DISCRETE])
+ self._get_columns_by_type(findings, [ColumnType.CATEGORICAL_NOMINAL, ColumnType.CATEGORICAL_ORDINAL])

  transform_actions = self._build_transform_actions(findings)

@@ -295,79 +322,102 @@ def log_data_quality_metrics(df: pd.DataFrame, prefix: str = "data"):
  ]

  # Log transform for skewed columns
- log_cols = [col for col, actions in transform_actions.items()
- if any(a.method == "log1p" for a in actions)]
+ log_cols = [col for col, actions in transform_actions.items() if any(a.method == "log1p" for a in actions)]
  if log_cols:
  for col in log_cols:
- code_lines.extend([
- f" # Log transform {col} (recommended for skewness)",
- f" df['{col}_log'] = np.log1p(df['{col}'].clip(lower=0))",
- f" transformers['{col}_log_transform'] = True",
- "",
- ])
+ code_lines.extend(
+ [
+ f" # Log transform {col} (recommended for skewness)",
+ f" df['{col}_log'] = np.log1p(df['{col}'].clip(lower=0))",
+ f" transformers['{col}_log_transform'] = True",
+ "",
+ ]
+ )

  # Standard scaling
- scale_standard = [col for col, actions in transform_actions.items()
- if any(a.action_type == "scale" and a.method == "standard" for a in actions)]
+ scale_standard = [
+ col
+ for col, actions in transform_actions.items()
+ if any(a.action_type == "scale" and a.method == "standard" for a in actions)
+ ]
  if scale_standard:
- code_lines.extend([
- " # Standard scaling",
- f" standard_cols = {scale_standard}",
- " if standard_cols:",
- " scaler = StandardScaler()",
- " df[standard_cols] = scaler.fit_transform(df[standard_cols])",
- " transformers['standard_scaler'] = {'columns': standard_cols}",
- "",
- ])
+ code_lines.extend(
+ [
+ " # Standard scaling",
+ f" standard_cols = {scale_standard}",
+ " if standard_cols:",
+ " scaler = StandardScaler()",
+ " df[standard_cols] = scaler.fit_transform(df[standard_cols])",
+ " transformers['standard_scaler'] = {'columns': standard_cols}",
+ "",
+ ]
+ )

  # MinMax scaling
- scale_minmax = [col for col, actions in transform_actions.items()
- if any(a.action_type == "scale" and a.method == "minmax" for a in actions)]
+ scale_minmax = [
+ col
+ for col, actions in transform_actions.items()
+ if any(a.action_type == "scale" and a.method == "minmax" for a in actions)
+ ]
  if scale_minmax:
- code_lines.extend([
- " # MinMax scaling",
- f" minmax_cols = {scale_minmax}",
- " if minmax_cols:",
- " minmax_scaler = MinMaxScaler()",
- " df[minmax_cols] = minmax_scaler.fit_transform(df[minmax_cols])",
- " transformers['minmax_scaler'] = {'columns': minmax_cols}",
- "",
- ])
+ code_lines.extend(
+ [
+ " # MinMax scaling",
+ f" minmax_cols = {scale_minmax}",
+ " if minmax_cols:",
+ " minmax_scaler = MinMaxScaler()",
+ " df[minmax_cols] = minmax_scaler.fit_transform(df[minmax_cols])",
+ " transformers['minmax_scaler'] = {'columns': minmax_cols}",
+ "",
+ ]
+ )

  # One-hot encoding
- onehot_cols = [col for col, actions in transform_actions.items()
- if any(a.action_type == "encode" and a.method == "onehot" for a in actions)]
+ onehot_cols = [
+ col
+ for col, actions in transform_actions.items()
+ if any(a.action_type == "encode" and a.method == "onehot" for a in actions)
+ ]
  if onehot_cols:
- code_lines.extend([
- " # One-hot encoding",
- f" onehot_cols = {onehot_cols}",
- " for col in onehot_cols:",
- " dummies = pd.get_dummies(df[col], prefix=col, drop_first=True)",
- " df = pd.concat([df.drop(columns=[col]), dummies], axis=1)",
- " transformers[f'{col}_onehot'] = list(dummies.columns)",
- "",
- ])
+ code_lines.extend(
+ [
+ " # One-hot encoding",
+ f" onehot_cols = {onehot_cols}",
+ " for col in onehot_cols:",
+ " dummies = pd.get_dummies(df[col], prefix=col, drop_first=True)",
+ " df = pd.concat([df.drop(columns=[col]), dummies], axis=1)",
+ " transformers[f'{col}_onehot'] = list(dummies.columns)",
+ "",
+ ]
+ )

  # Label encoding
- label_cols = [col for col, actions in transform_actions.items()
- if any(a.action_type == "encode" and a.method == "label" for a in actions)]
+ label_cols = [
+ col
+ for col, actions in transform_actions.items()
+ if any(a.action_type == "encode" and a.method == "label" for a in actions)
+ ]
  if label_cols:
- code_lines.extend([
- " # Label encoding",
- f" label_cols = {label_cols}",
- " label_encoders = {{}}",
- " for col in label_cols:",
- " le = LabelEncoder()",
- " df[col] = le.fit_transform(df[col].astype(str))",
- " label_encoders[col] = le",
- " transformers['label_encoders'] = label_encoders",
- "",
- ])
+ code_lines.extend(
+ [
+ " # Label encoding",
+ f" label_cols = {label_cols}",
+ " label_encoders = {{}}",
+ " for col in label_cols:",
+ " le = LabelEncoder()",
+ " df[col] = le.fit_transform(df[col].astype(str))",
+ " label_encoders[col] = le",
+ " transformers['label_encoders'] = label_encoders",
+ "",
+ ]
+ )

- code_lines.extend([
- " mlflow.log_params({f'transform_{k}': str(v)[:250] for k, v in transformers.items()})",
- " return df, transformers",
- ])
+ code_lines.extend(
+ [
+ " mlflow.log_params({f'transform_{k}': str(v)[:250] for k, v in transformers.items()})",
+ " return df, transformers",
+ ]
+ )

  return "\n".join(code_lines)

@@ -409,12 +459,14 @@ def log_data_quality_metrics(df: pd.DataFrame, prefix: str = "data"):
  if not extract_types:
  extract_types = ["month", "dayofweek", "days_since"]

- code_lines.extend([
- f" # Datetime features from {col_name}",
- f" if '{col_name}' in df.columns:",
- f" df['{col_name}'] = safe_to_datetime(df['{col_name}'], errors='coerce')",
- "",
- ])
+ code_lines.extend(
+ [
+ f" # Datetime features from {col_name}",
+ f" if '{col_name}' in df.columns:",
+ f" df['{col_name}'] = safe_to_datetime(df['{col_name}'], errors='coerce')",
+ "",
+ ]
+ )

  for ext_type in extract_types:
  if ext_type == "month":
@@ -433,19 +485,23 @@ def log_data_quality_metrics(df: pd.DataFrame, prefix: str = "data"):
  code_lines.append(f" df['{col_name}_year'] = df['{col_name}'].dt.year")
  code_lines.append(f" new_features.append('{col_name}_year')")
  elif ext_type == "days_since":
- code_lines.extend([
- f" reference_date = df['{col_name}'].max()",
- f" df['{col_name}_days_since'] = (reference_date - df['{col_name}']).dt.days",
- f" new_features.append('{col_name}_days_since')",
- ])
+ code_lines.extend(
+ [
+ f" reference_date = df['{col_name}'].max()",
+ f" df['{col_name}_days_since'] = (reference_date - df['{col_name}']).dt.days",
+ f" new_features.append('{col_name}_days_since')",
+ ]
+ )

  code_lines.append("")

- code_lines.extend([
- " if new_features:",
- " mlflow.log_param('engineered_features', new_features)",
- " return df",
- ])
+ code_lines.extend(
+ [
+ " if new_features:",
+ " mlflow.log_param('engineered_features', new_features)",
+ " return df",
+ ]
+ )

  return "\n".join(code_lines)

@@ -455,7 +511,7 @@ def log_data_quality_metrics(df: pd.DataFrame, prefix: str = "data"):
  datetime_cols = findings.datetime_columns or []
  exclude_cols = set(identifier_cols + datetime_cols + [target])

- return f'''
+ main_body = f'''
  def train_model(
  df: pd.DataFrame,
  target_column: str = "{target}",
@@ -544,7 +600,7 @@ def train_model(
  # Log everything
  mlflow.log_params(model.get_params())
  mlflow.log_metrics({{**val_metrics, **test_metrics, **cv_metrics}})
- mlflow.sklearn.log_model(model, f"model_{{name}}")
+ mlflow.sklearn.log_model(model, name=f"model_{{name}}")

  results[name] = {{
  "model": model,
@@ -558,9 +614,29 @@ def train_model(
  best_model = name

  mlflow.log_param("best_model", best_model)
- mlflow.log_metric("best_val_roc_auc", best_auc)
+ mlflow.log_metric("best_val_roc_auc", best_auc)'''
+
+ if self.mlflow_config.databricks and self.mlflow_config.model_name:
+ reg_name = f"{self.mlflow_config.catalog}.{self.mlflow_config.schema}.{self.mlflow_config.model_name}"
+ main_body += f'''

- return {{"results": results, "best_model": best_model}}'''
+ # Register best model in Unity Catalog and set alias
+ if best_model:
+ best_run = results[best_model]
+ model_info = mlflow.sklearn.log_model(
+ best_run["model"],
+ name="best_model",
+ registered_model_name="{reg_name}",
+ )
+ client = MlflowClient()
+ latest_version = client.get_latest_versions("{reg_name}")[0].version
+ client.set_registered_model_alias("{reg_name}", "champion", latest_version)'''
+
+ main_body += """
+
+ return {"results": results, "best_model": best_model}"""
+
+ return main_body

  def generate_monitoring(self, findings: ExplorationFindings) -> str:
  return '''
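The new registration block writes the best model to a three-part Unity Catalog name built from catalog.schema.model_name and pins a "champion" alias on the newest version. For reference, a standalone sketch of the same flow using plain MLflow calls; the run id and model name are placeholders, and mlflow.set_registry_uri("databricks-uc") is only needed where the registry does not already default to Unity Catalog.

    # Hedged sketch, not the package's generated code: register and alias a model in Unity Catalog.
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_registry_uri("databricks-uc")   # route the model registry to Unity Catalog
    reg_name = "main.default.churn_model"      # catalog.schema.model_name
    run_id = "<run-id-from-training>"          # placeholder

    version = mlflow.register_model(f"runs:/{run_id}/best_model", reg_name)

    client = MlflowClient()
    client.set_registered_model_alias(reg_name, "champion", version.version)

    # Consumers can then load by alias instead of a hard-coded version number:
    champion = mlflow.sklearn.load_model(f"models:/{reg_name}@champion")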
@@ -607,16 +683,16 @@ def main():
  if self.mlflow_config.log_data_quality:
  main_body += "\n log_data_quality_metrics(df, prefix='raw')"

- main_body += '''
+ main_body += """

  # Clean data
  print("Cleaning data...")
- df = clean_data(df)'''
+ df = clean_data(df)"""

  if self.mlflow_config.log_data_quality:
  main_body += "\n log_data_quality_metrics(df, prefix='cleaned')"

- main_body += '''
+ main_body += """

  # Apply transformations
  print("Applying transformations...")
@@ -624,12 +700,12 @@ def main():

  # Engineer features
  print("Engineering features...")
- df = engineer_features(df)'''
+ df = engineer_features(df)"""

  if self.mlflow_config.log_data_quality:
  main_body += "\n log_data_quality_metrics(df, prefix='final')"

- main_body += '''
+ main_body += """

  # Train models
  print("Training models...")
@@ -642,7 +718,7 @@ def main():


  if __name__ == "__main__":
- main()'''
+ main()"""

  return main_body

@@ -651,10 +727,7 @@ if __name__ == "__main__":
  findings: ExplorationFindings,
  col_types: List[ColumnType],
  ) -> List[str]:
- return [
- name for name, col in findings.columns.items()
- if col.inferred_type in col_types
- ]
+ return [name for name, col in findings.columns.items() if col.inferred_type in col_types]

  def generate_all(self, findings: ExplorationFindings) -> Dict[str, str]:
  return {
customer_retention/integrations/adapters/factory.py
@@ -1,4 +1,5 @@
  from customer_retention.core.compat.detection import is_spark_available
+ from customer_retention.core.config.experiments import get_catalog, get_schema

  from .feature_store import DatabricksFeatureStore, FeatureStoreAdapter, LocalFeatureStore
  from .mlflow import DatabricksMLflow, LocalMLflow, MLflowAdapter
@@ -11,15 +12,17 @@ def get_delta(force_local: bool = False) -> DeltaStorage:
  return DatabricksDelta()


- def get_feature_store(base_path: str = "./feature_store", catalog: str = "main",
- schema: str = "default", force_local: bool = False) -> FeatureStoreAdapter:
+ def get_feature_store(
+ base_path: str = "./feature_store", catalog: str | None = None, schema: str | None = None, force_local: bool = False
+ ) -> FeatureStoreAdapter:
  if force_local or not is_spark_available():
  return LocalFeatureStore(base_path=base_path)
- return DatabricksFeatureStore(catalog=catalog, schema=schema)
+ return DatabricksFeatureStore(catalog=catalog or get_catalog(), schema=schema or get_schema())


- def get_mlflow(tracking_uri: str = "./mlruns", registry_uri: str = "databricks-uc",
- force_local: bool = False) -> MLflowAdapter:
+ def get_mlflow(
+ tracking_uri: str = "./mlruns", registry_uri: str = "databricks-uc", force_local: bool = False
+ ) -> MLflowAdapter:
  if force_local or not is_spark_available():
  return LocalMLflow(tracking_uri=tracking_uri)
  return DatabricksMLflow(registry_uri=registry_uri)
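After this change the factory falls back to workspace-level defaults: calling get_feature_store() with catalog/schema left as None on a Spark-enabled runtime resolves them through get_catalog()/get_schema() from customer_retention.core.config.experiments, while force_local still short-circuits to the local adapters. A usage sketch under those assumptions:

    # Sketch of the factory behaviour implied by this diff.
    from customer_retention.integrations.adapters.factory import get_feature_store, get_mlflow

    # Local development: skip Spark/Databricks detection entirely.
    fs_local = get_feature_store(base_path="./feature_store", force_local=True)
    ml_local = get_mlflow(tracking_uri="./mlruns", force_local=True)

    # On Databricks: defaults come from get_catalog()/get_schema() unless overridden.
    fs_default = get_feature_store()                               # DatabricksFeatureStore(get_catalog(), get_schema())
    fs_custom = get_feature_store(catalog="prod", schema="churn")  # explicit values still win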
customer_retention/integrations/adapters/feature_store/base.py
@@ -17,6 +17,7 @@ class FeatureViewConfig:
  tags: Dict[str, str] = field(default_factory=dict)
  cutoff_date: Optional[datetime] = None
  data_hash: Optional[str] = None
+ timeseries_column: Optional[str] = None


  class FeatureStoreAdapter(ABC):
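FeatureViewConfig gains an optional timeseries_column, likely the timestamp key handed to the Databricks feature store when creating time-series feature tables (the databricks.py adapter changes listed above are consistent with that). A construction sketch, assuming the fields not shown in this hunk have defaults:

    # Hypothetical sketch: only fields visible in this diff are set; the import
    # path follows the file list above.
    from datetime import datetime
    from customer_retention.integrations.adapters.feature_store.base import FeatureViewConfig

    view = FeatureViewConfig(
        tags={"team": "retention"},
        cutoff_date=datetime(2024, 1, 1),
        timeseries_column="event_ts",  # new in 0.76.0a1: timestamp column for time-series feature tables
    )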