churnkit 0.75.1a3__py3-none-any.whl → 0.76.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {churnkit-0.75.1a3.dist-info → churnkit-0.76.0a1.dist-info}/METADATA +5 -2
- {churnkit-0.75.1a3.dist-info → churnkit-0.76.0a1.dist-info}/RECORD +41 -40
- customer_retention/__init__.py +11 -1
- customer_retention/core/compat/__init__.py +3 -0
- customer_retention/core/config/__init__.py +43 -8
- customer_retention/core/config/experiments.py +20 -0
- customer_retention/generators/spec_generator/mlflow_pipeline_generator.py +222 -149
- customer_retention/integrations/adapters/factory.py +8 -5
- customer_retention/integrations/adapters/feature_store/base.py +1 -0
- customer_retention/integrations/adapters/feature_store/databricks.py +58 -10
- customer_retention/integrations/adapters/mlflow/base.py +8 -0
- customer_retention/integrations/adapters/mlflow/databricks.py +15 -2
- customer_retention/integrations/adapters/mlflow/local.py +7 -0
- customer_retention/integrations/databricks_init.py +141 -0
- customer_retention/stages/profiling/temporal_feature_analyzer.py +3 -3
- customer_retention/stages/profiling/temporal_feature_engineer.py +2 -2
- customer_retention/stages/profiling/temporal_pattern_analyzer.py +4 -3
- customer_retention/stages/profiling/time_series_profiler.py +5 -4
- customer_retention/stages/profiling/time_window_aggregator.py +3 -2
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/00_start_here.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/01_data_discovery.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/01a_a_temporal_text_deep_dive.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/01a_temporal_deep_dive.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/01b_temporal_quality.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/01c_temporal_patterns.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/01d_event_aggregation.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/02_column_deep_dive.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/02a_text_columns_deep_dive.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/03_quality_assessment.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/04_relationship_analysis.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/05_multi_dataset.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/06_feature_opportunities.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/07_modeling_readiness.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/08_baseline_experiments.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/09_business_alignment.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/10_spec_generation.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/11_scoring_validation.ipynb +0 -0
- {churnkit-0.75.1a3.data → churnkit-0.76.0a1.data}/data/share/churnkit/exploration_notebooks/12_view_documentation.ipynb +0 -0
- {churnkit-0.75.1a3.dist-info → churnkit-0.76.0a1.dist-info}/WHEEL +0 -0
- {churnkit-0.75.1a3.dist-info → churnkit-0.76.0a1.dist-info}/entry_points.txt +0 -0
- {churnkit-0.75.1a3.dist-info → churnkit-0.76.0a1.dist-info}/licenses/LICENSE +0 -0
customer_retention/generators/spec_generator/mlflow_pipeline_generator.py

@@ -84,6 +84,9 @@ class MLflowConfig:
     log_feature_importance: bool = True
     nested_runs: bool = True
     model_name: Optional[str] = None
+    databricks: bool = False
+    catalog: str = "main"
+    schema: str = "default"


 class MLflowPipelineGenerator:
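The three new MLflowConfig fields gate the Databricks-specific code paths introduced later in this diff (the MlflowClient import, the tracking setup without an explicit URI, and Unity Catalog registration). A rough sketch of how they might be set; the import path, the constructor keyword, and the experiment_name field are assumptions inferred from this diff rather than documented API:

    # Hypothetical configuration; only databricks/catalog/schema/model_name are confirmed by the diff.
    from customer_retention.generators.spec_generator.mlflow_pipeline_generator import (
        MLflowConfig,
        MLflowPipelineGenerator,
    )

    config = MLflowConfig(
        experiment_name="/Shared/churn_pipeline",  # assumed field; appears in the generated setup code
        databricks=True,                           # new in 0.76.0a1, default False
        catalog="main",                            # Unity Catalog catalog, default "main"
        schema="default",                          # Unity Catalog schema, default "default"
        model_name="churn_model",                  # enables registration of the best model (see below)
    )
    generator = MLflowPipelineGenerator(mlflow_config=config)  # assumed constructor keyword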
@@ -106,14 +109,16 @@ class MLflowPipelineGenerator:
         if self.mlflow_config.log_data_quality:
             sections.append(self._generate_data_quality_logging())

-        sections.extend(
-            ...
+        sections.extend(
+            [
+                self.generate_cleaning_functions(findings),
+                self.generate_transform_functions(findings),
+                self.generate_feature_engineering(findings),
+                self.generate_model_training(findings),
+                self.generate_monitoring(findings),
+                self._generate_main(findings),
+            ]
+        )
         return "\n\n".join(sections)

     def _generate_docstring(self, findings: ExplorationFindings) -> str:
@@ -122,13 +127,13 @@ MLflow-tracked ML Pipeline
 Generated from exploration findings

 Source: {findings.source_path}
-Target: {findings.target_column or ...
+Target: {findings.target_column or "Not specified"}
 Rows: {findings.row_count:,}
 Features: {findings.column_count}
 """'''

     def _generate_imports(self) -> str:
-        ...
+        base_imports = """import pandas as pd
 import numpy as np
 from datetime import datetime
 from typing import Dict, List, Tuple, Any
@@ -146,8 +151,20 @@ from sklearn.metrics import (
     accuracy_score, precision_score, recall_score, f1_score,
     roc_auc_score, classification_report, confusion_matrix
 )"""
+        if self.mlflow_config.databricks:
+            base_imports += "\nfrom mlflow.tracking import MlflowClient"
+        return base_imports

     def _generate_mlflow_setup(self) -> str:
+        if self.mlflow_config.databricks:
+            return f'''
+EXPERIMENT_NAME = "{self.mlflow_config.experiment_name}"
+
+
+def setup_mlflow():
+    """Initialize MLflow tracking (Databricks auto-configures tracking URI)."""
+    mlflow.set_experiment(EXPERIMENT_NAME)
+    return mlflow.get_experiment_by_name(EXPERIMENT_NAME)'''
         return f'''
 MLFLOW_TRACKING_URI = "{self.mlflow_config.tracking_uri}"
 EXPERIMENT_NAME = "{self.mlflow_config.experiment_name}"
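When databricks is enabled, the generated setup block no longer hard-codes a tracking URI and only selects the experiment, relying on the Databricks runtime to configure tracking. Illustrative output of _generate_mlflow_setup() for a made-up experiment name (the "/Shared/..." path is a placeholder, not taken from the package):

    import mlflow  # provided by the generated script's import block

    EXPERIMENT_NAME = "/Shared/churn_pipeline"


    def setup_mlflow():
        """Initialize MLflow tracking (Databricks auto-configures tracking URI)."""
        mlflow.set_experiment(EXPERIMENT_NAME)
        return mlflow.get_experiment_by_name(EXPERIMENT_NAME)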
@@ -194,11 +211,13 @@ def log_data_quality_metrics(df: pd.DataFrame, prefix: str = "data"):
             for action in actions:
                 code_lines.extend(self._action_to_cleaning_code(col_name, action))

-        code_lines.extend(
-            ...
+        code_lines.extend(
+            [
+                "",
+                " mlflow.log_params({f'cleaned_{k}': v for k, v in cleaning_stats.items()})",
+                " return df",
+            ]
+        )

         return "\n".join(code_lines)

@@ -224,65 +243,73 @@ def log_data_quality_metrics(df: pd.DataFrame, prefix: str = "data"):

         if action.action_type == "impute":
             if action.strategy == "median":
-                lines.extend(
-                    ...
+                lines.extend(
+                    [
+                        f" # Impute {col_name} with median",
+                        f" if df['{col_name}'].isna().any():",
+                        f" median_val = df['{col_name}'].median()",
+                        f" cleaning_stats['{col_name}_imputed'] = df['{col_name}'].isna().sum()",
+                        f" df['{col_name}'] = df['{col_name}'].fillna(median_val)",
+                        "",
+                    ]
+                )
             elif action.strategy == "mode":
-                lines.extend(
-                    ...
+                lines.extend(
+                    [
+                        f" # Impute {col_name} with mode",
+                        f" if df['{col_name}'].isna().any():",
+                        f" mode_val = df['{col_name}'].mode().iloc[0] if not df['{col_name}'].mode().empty else None",
+                        " if mode_val is not None:",
+                        f" cleaning_stats['{col_name}_imputed'] = df['{col_name}'].isna().sum()",
+                        f" df['{col_name}'] = df['{col_name}'].fillna(mode_val)",
+                        "",
+                    ]
+                )
             elif action.strategy == "constant":
                 fill_value = action.params.get("fill_value", 0)
-                lines.extend(
-                    ...
+                lines.extend(
+                    [
+                        f" # Impute {col_name} with constant",
+                        f" if df['{col_name}'].isna().any():",
+                        f" cleaning_stats['{col_name}_imputed'] = df['{col_name}'].isna().sum()",
+                        f" df['{col_name}'] = df['{col_name}'].fillna({repr(fill_value)})",
+                        "",
+                    ]
+                )

         elif action.action_type == "cap_outliers":
             percentile = action.params.get("percentile", 99)
-            lines.extend(
-                ...
+            lines.extend(
+                [
+                    f" # Cap outliers in {col_name} at {percentile}th percentile",
+                    f" lower = df['{col_name}'].quantile({(100 - percentile) / 100})",
+                    f" upper = df['{col_name}'].quantile({percentile / 100})",
+                    f" outliers = ((df['{col_name}'] < lower) | (df['{col_name}'] > upper)).sum()",
+                    f" cleaning_stats['{col_name}_outliers_capped'] = outliers",
+                    f" df['{col_name}'] = df['{col_name}'].clip(lower, upper)",
+                    "",
+                ]
+            )

         elif action.action_type == "drop_rare":
             threshold = action.params.get("threshold_percent", 5)
-            lines.extend(
-                ...
+            lines.extend(
+                [
+                    f" # Drop rare categories in {col_name} (< {threshold}%)",
+                    f" value_counts = df['{col_name}'].value_counts(normalize=True)",
+                    f" rare_values = value_counts[value_counts < {threshold / 100}].index",
+                    " if len(rare_values) > 0:",
+                    f" cleaning_stats['{col_name}_rare_dropped'] = len(rare_values)",
+                    f" df.loc[df['{col_name}'].isin(rare_values), '{col_name}'] = df['{col_name}'].mode().iloc[0]",
+                    "",
+                ]
+            )

         return lines

     def generate_transform_functions(self, findings: ExplorationFindings) -> str:
-        self._get_columns_by_type(findings,
-            ...
-        self._get_columns_by_type(findings,
-            [ColumnType.CATEGORICAL_NOMINAL, ColumnType.CATEGORICAL_ORDINAL])
+        self._get_columns_by_type(findings, [ColumnType.NUMERIC_CONTINUOUS, ColumnType.NUMERIC_DISCRETE])
+        self._get_columns_by_type(findings, [ColumnType.CATEGORICAL_NOMINAL, ColumnType.CATEGORICAL_ORDINAL])

         transform_actions = self._build_transform_actions(findings)

@@ -295,79 +322,102 @@ def log_data_quality_metrics(df: pd.DataFrame, prefix: str = "data"):
         ]

         # Log transform for skewed columns
-        log_cols = [col for col, actions in transform_actions.items()
-                    if any(a.method == "log1p" for a in actions)]
+        log_cols = [col for col, actions in transform_actions.items() if any(a.method == "log1p" for a in actions)]
         if log_cols:
             for col in log_cols:
-                code_lines.extend(
-                    ...
+                code_lines.extend(
+                    [
+                        f" # Log transform {col} (recommended for skewness)",
+                        f" df['{col}_log'] = np.log1p(df['{col}'].clip(lower=0))",
+                        f" transformers['{col}_log_transform'] = True",
+                        "",
+                    ]
+                )

         # Standard scaling
-        scale_standard = [
-            ...
+        scale_standard = [
+            col
+            for col, actions in transform_actions.items()
+            if any(a.action_type == "scale" and a.method == "standard" for a in actions)
+        ]
         if scale_standard:
-            code_lines.extend(
-                ...
+            code_lines.extend(
+                [
+                    " # Standard scaling",
+                    f" standard_cols = {scale_standard}",
+                    " if standard_cols:",
+                    " scaler = StandardScaler()",
+                    " df[standard_cols] = scaler.fit_transform(df[standard_cols])",
+                    " transformers['standard_scaler'] = {'columns': standard_cols}",
+                    "",
+                ]
+            )

         # MinMax scaling
-        scale_minmax = [
-            ...
+        scale_minmax = [
+            col
+            for col, actions in transform_actions.items()
+            if any(a.action_type == "scale" and a.method == "minmax" for a in actions)
+        ]
         if scale_minmax:
-            code_lines.extend(
-                ...
+            code_lines.extend(
+                [
+                    " # MinMax scaling",
+                    f" minmax_cols = {scale_minmax}",
+                    " if minmax_cols:",
+                    " minmax_scaler = MinMaxScaler()",
+                    " df[minmax_cols] = minmax_scaler.fit_transform(df[minmax_cols])",
+                    " transformers['minmax_scaler'] = {'columns': minmax_cols}",
+                    "",
+                ]
+            )

         # One-hot encoding
-        onehot_cols = [
-            ...
+        onehot_cols = [
+            col
+            for col, actions in transform_actions.items()
+            if any(a.action_type == "encode" and a.method == "onehot" for a in actions)
+        ]
         if onehot_cols:
-            code_lines.extend(
-                ...
+            code_lines.extend(
+                [
+                    " # One-hot encoding",
+                    f" onehot_cols = {onehot_cols}",
+                    " for col in onehot_cols:",
+                    " dummies = pd.get_dummies(df[col], prefix=col, drop_first=True)",
+                    " df = pd.concat([df.drop(columns=[col]), dummies], axis=1)",
+                    " transformers[f'{col}_onehot'] = list(dummies.columns)",
+                    "",
+                ]
+            )

         # Label encoding
-        label_cols = [
-            ...
+        label_cols = [
+            col
+            for col, actions in transform_actions.items()
+            if any(a.action_type == "encode" and a.method == "label" for a in actions)
+        ]
         if label_cols:
-            code_lines.extend(
-                ...
+            code_lines.extend(
+                [
+                    " # Label encoding",
+                    f" label_cols = {label_cols}",
+                    " label_encoders = {{}}",
+                    " for col in label_cols:",
+                    " le = LabelEncoder()",
+                    " df[col] = le.fit_transform(df[col].astype(str))",
+                    " label_encoders[col] = le",
+                    " transformers['label_encoders'] = label_encoders",
+                    "",
+                ]
+            )

-        code_lines.extend(
-            ...
+        code_lines.extend(
+            [
+                " mlflow.log_params({f'transform_{k}': str(v)[:250] for k, v in transformers.items()})",
+                " return df, transformers",
+            ]
+        )

         return "\n".join(code_lines)

@@ -409,12 +459,14 @@ def log_data_quality_metrics(df: pd.DataFrame, prefix: str = "data"):
             if not extract_types:
                 extract_types = ["month", "dayofweek", "days_since"]

-            code_lines.extend(
-                ...
+            code_lines.extend(
+                [
+                    f" # Datetime features from {col_name}",
+                    f" if '{col_name}' in df.columns:",
+                    f" df['{col_name}'] = safe_to_datetime(df['{col_name}'], errors='coerce')",
+                    "",
+                ]
+            )

             for ext_type in extract_types:
                 if ext_type == "month":
@@ -433,19 +485,23 @@ def log_data_quality_metrics(df: pd.DataFrame, prefix: str = "data"):
                     code_lines.append(f" df['{col_name}_year'] = df['{col_name}'].dt.year")
                     code_lines.append(f" new_features.append('{col_name}_year')")
                 elif ext_type == "days_since":
-                    code_lines.extend(
-                        ...
+                    code_lines.extend(
+                        [
+                            f" reference_date = df['{col_name}'].max()",
+                            f" df['{col_name}_days_since'] = (reference_date - df['{col_name}']).dt.days",
+                            f" new_features.append('{col_name}_days_since')",
+                        ]
+                    )

             code_lines.append("")

-        code_lines.extend(
-            ...
+        code_lines.extend(
+            [
+                " if new_features:",
+                " mlflow.log_param('engineered_features', new_features)",
+                " return df",
+            ]
+        )

         return "\n".join(code_lines)

@@ -455,7 +511,7 @@ def log_data_quality_metrics(df: pd.DataFrame, prefix: str = "data"):
         datetime_cols = findings.datetime_columns or []
         exclude_cols = set(identifier_cols + datetime_cols + [target])

-        ...
+        main_body = f'''
 def train_model(
     df: pd.DataFrame,
     target_column: str = "{target}",
@@ -544,7 +600,7 @@ def train_model(
         # Log everything
         mlflow.log_params(model.get_params())
         mlflow.log_metrics({{**val_metrics, **test_metrics, **cv_metrics}})
-        mlflow.sklearn.log_model(model, f"model_{{name}}")
+        mlflow.sklearn.log_model(model, name=f"model_{{name}}")

         results[name] = {{
             "model": model,
@@ -558,9 +614,29 @@
             best_model = name

     mlflow.log_param("best_model", best_model)
-    mlflow.log_metric("best_val_roc_auc", best_auc)
+    mlflow.log_metric("best_val_roc_auc", best_auc)'''
+
+        if self.mlflow_config.databricks and self.mlflow_config.model_name:
+            reg_name = f"{self.mlflow_config.catalog}.{self.mlflow_config.schema}.{self.mlflow_config.model_name}"
+            main_body += f'''

-    ...
+    # Register best model in Unity Catalog and set alias
+    if best_model:
+        best_run = results[best_model]
+        model_info = mlflow.sklearn.log_model(
+            best_run["model"],
+            name="best_model",
+            registered_model_name="{reg_name}",
+        )
+        client = MlflowClient()
+        latest_version = client.get_latest_versions("{reg_name}")[0].version
+        client.set_registered_model_alias("{reg_name}", "champion", latest_version)'''
+
+        main_body += """
+
+    return {"results": results, "best_model": best_model}"""
+
+        return main_body

     def generate_monitoring(self, findings: ExplorationFindings) -> str:
         return '''
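The registration path writes the best model to Unity Catalog under <catalog>.<schema>.<model_name> and points the "champion" alias at the latest registered version. A hedged sketch of how a scoring job could consume that alias afterwards; the three-part name is a placeholder, and this assumes MLflow 2.x with the Unity Catalog registry:

    import mlflow

    mlflow.set_registry_uri("databricks-uc")
    champion = mlflow.pyfunc.load_model("models:/main.default.churn_model@champion")
    # champion.predict(features_df)  # features_df: a DataFrame with the engineered feature columns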
@@ -607,16 +683,16 @@ def main():
         if self.mlflow_config.log_data_quality:
             main_body += "\n log_data_quality_metrics(df, prefix='raw')"

-        main_body += ...
+        main_body += """

    # Clean data
    print("Cleaning data...")
-    df = clean_data(df)
+    df = clean_data(df)"""

         if self.mlflow_config.log_data_quality:
             main_body += "\n log_data_quality_metrics(df, prefix='cleaned')"

-        main_body += ...
+        main_body += """

    # Apply transformations
    print("Applying transformations...")
@@ -624,12 +700,12 @@ def main():

    # Engineer features
    print("Engineering features...")
-    df = engineer_features(df)
+    df = engineer_features(df)"""

         if self.mlflow_config.log_data_quality:
             main_body += "\n log_data_quality_metrics(df, prefix='final')"

-        main_body += ...
+        main_body += """

    # Train models
    print("Training models...")
@@ -642,7 +718,7 @@ def main():


 if __name__ == "__main__":
-    main()
+    main()"""

         return main_body

@@ -651,10 +727,7 @@ if __name__ == "__main__":
         findings: ExplorationFindings,
         col_types: List[ColumnType],
     ) -> List[str]:
-        return [
-            name for name, col in findings.columns.items()
-            if col.inferred_type in col_types
-        ]
+        return [name for name, col in findings.columns.items() if col.inferred_type in col_types]

     def generate_all(self, findings: ExplorationFindings) -> Dict[str, str]:
         return {
customer_retention/integrations/adapters/factory.py

@@ -1,4 +1,5 @@
 from customer_retention.core.compat.detection import is_spark_available
+from customer_retention.core.config.experiments import get_catalog, get_schema

 from .feature_store import DatabricksFeatureStore, FeatureStoreAdapter, LocalFeatureStore
 from .mlflow import DatabricksMLflow, LocalMLflow, MLflowAdapter
@@ -11,15 +12,17 @@ def get_delta(force_local: bool = False) -> DeltaStorage:
     return DatabricksDelta()


-def get_feature_store(
-    ...
+def get_feature_store(
+    base_path: str = "./feature_store", catalog: str | None = None, schema: str | None = None, force_local: bool = False
+) -> FeatureStoreAdapter:
     if force_local or not is_spark_available():
         return LocalFeatureStore(base_path=base_path)
-    return DatabricksFeatureStore(catalog=catalog, schema=schema)
+    return DatabricksFeatureStore(catalog=catalog or get_catalog(), schema=schema or get_schema())


-def get_mlflow(
-    ...
+def get_mlflow(
+    tracking_uri: str = "./mlruns", registry_uri: str = "databricks-uc", force_local: bool = False
+) -> MLflowAdapter:
     if force_local or not is_spark_available():
         return LocalMLflow(tracking_uri=tracking_uri)
     return DatabricksMLflow(registry_uri=registry_uri)
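With this change, omitting catalog/schema from get_feature_store() defers to the experiment-level defaults from customer_retention.core.config.experiments, while explicit arguments still take precedence. A short usage sketch based on the signatures above (the local/Databricks split follows the is_spark_available() guard):

    from customer_retention.integrations.adapters.factory import get_feature_store, get_mlflow

    # Without Spark (or with force_local=True) the local adapters are returned.
    fs = get_feature_store(force_local=True)   # LocalFeatureStore(base_path="./feature_store")
    tracker = get_mlflow(force_local=True)     # LocalMLflow(tracking_uri="./mlruns")

    # On Databricks, unspecified catalog/schema now resolve via get_catalog()/get_schema();
    # passing them explicitly overrides the configured defaults.
    # fs = get_feature_store()
    # fs = get_feature_store(catalog="main", schema="default")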