ins-pricing 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/docs/LOSS_FUNCTIONS.md +78 -0
- ins_pricing/frontend/QUICKSTART.md +152 -0
- ins_pricing/frontend/README.md +388 -0
- ins_pricing/frontend/__init__.py +10 -0
- ins_pricing/frontend/app.py +903 -0
- ins_pricing/frontend/config_builder.py +352 -0
- ins_pricing/frontend/example_config.json +36 -0
- ins_pricing/frontend/example_workflows.py +979 -0
- ins_pricing/frontend/ft_workflow.py +316 -0
- ins_pricing/frontend/runner.py +388 -0
- ins_pricing/production/predict.py +693 -664
- ins_pricing/setup.py +1 -1
- {ins_pricing-0.3.4.dist-info → ins_pricing-0.4.0.dist-info}/METADATA +1 -1
- {ins_pricing-0.3.4.dist-info → ins_pricing-0.4.0.dist-info}/RECORD +16 -6
- {ins_pricing-0.3.4.dist-info → ins_pricing-0.4.0.dist-info}/WHEEL +1 -1
- {ins_pricing-0.3.4.dist-info → ins_pricing-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,979 @@
+"""
+Example workflows implemented in Python so the frontend can run
+the same tasks as the example notebooks.
+"""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Iterable, List, Optional, Sequence, Tuple
+
+import pandas as pd
+
+from ins_pricing.cli.utils.cli_common import split_train_test
+from ins_pricing.modelling.plotting import (
+    PlotStyle,
+    plot_double_lift_curve,
+    plot_lift_curve,
+    plot_oneway,
+)
+from ins_pricing.modelling.plotting.common import finalize_figure, plt
+from ins_pricing.production.predict import load_predictor_from_config
+
+
+def _parse_csv_list(value: str) -> List[str]:
+    if not value:
+        return []
+    return [x.strip() for x in value.split(",") if x.strip()]
+
+
+def _dedupe_list(values: Iterable[str]) -> List[str]:
+    seen = set()
+    out: List[str] = []
+    for item in values or []:
+        if item in seen:
+            continue
+        seen.add(item)
+        out.append(item)
+    return out
+
+
+def _drop_duplicate_columns(df: pd.DataFrame, label: str) -> pd.DataFrame:
+    if df.columns.duplicated().any():
+        dupes = [str(x) for x in df.columns[df.columns.duplicated()]]
+        print(f"[Warn] {label}: dropping duplicate columns: {sorted(set(dupes))}")
+        return df.loc[:, ~df.columns.duplicated()].copy()
+    return df
+
+
+def _resolve_output_dir(cfg_obj: dict, cfg_file_path: Path) -> str:
+    output_dir = cfg_obj.get("output_dir", "./Results")
+    return str((cfg_file_path.parent / output_dir).resolve())
+
+
+def _resolve_plot_style(cfg_obj: dict) -> str:
+    return str(cfg_obj.get("plot_path_style", "nested") or "nested").strip().lower()
+
+
+def _resolve_plot_path(output_root: str, plot_style: str, subdir: str, filename: str) -> str:
+    plot_root = Path(output_root) / "plot"
+    if plot_style in {"flat", "root"}:
+        return str((plot_root / filename).resolve())
+    if subdir:
+        return str((plot_root / subdir / filename).resolve())
+    return str((plot_root / filename).resolve())
+
+
+def _safe_tag(value: str) -> str:
+    return (
+        value.strip()
+        .replace(" ", "_")
+        .replace("/", "_")
+        .replace("\\", "_")
+        .replace(":", "_")
+    )
+
+
+def _resolve_data_path(cfg: dict, cfg_path: Path, model_name: str) -> Path:
+    data_dir = cfg.get("data_dir", ".")
+    data_format = cfg.get("data_format", "csv")
+    data_path_template = cfg.get("data_path_template", "{model_name}.{ext}")
+    filename = data_path_template.format(model_name=model_name, ext=data_format)
+    return (cfg_path.parent / data_dir / filename).resolve()
+
+
+def _infer_categorical_features(
+    df: pd.DataFrame,
+    feature_list: Sequence[str],
+    *,
+    max_unique: int = 50,
+    max_ratio: float = 0.05,
+) -> List[str]:
+    categorical: List[str] = []
+    n_rows = max(1, len(df))
+    for feature in feature_list:
+        if feature not in df.columns:
+            continue
+        nunique = int(df[feature].nunique(dropna=True))
+        ratio = nunique / float(n_rows)
+        if nunique <= max_unique or ratio <= max_ratio:
+            categorical.append(feature)
+    return categorical
+
+
+def run_pre_oneway(
+    *,
+    data_path: str,
+    model_name: str,
+    target_col: str,
+    weight_col: str,
+    feature_list: str,
+    categorical_features: str,
+    n_bins: int = 10,
+    holdout_ratio: Optional[float] = 0.25,
+    rand_seed: int = 13,
+    output_dir: Optional[str] = None,
+) -> str:
+    data_path = str(data_path or "").strip()
+    if not data_path:
+        raise ValueError("data_path is required.")
+    model_name = str(model_name or "").strip()
+    if not model_name:
+        raise ValueError("model_name is required.")
+
+    raw_path = Path(data_path).resolve()
+    if not raw_path.exists():
+        raise FileNotFoundError(f"Data file not found: {raw_path}")
+
+    raw = pd.read_csv(raw_path, low_memory=False)
+    raw = _drop_duplicate_columns(raw, "raw").reset_index(drop=True)
+    raw.fillna(0, inplace=True)
+
+    features = _dedupe_list(_parse_csv_list(feature_list))
+    cats = _dedupe_list(_parse_csv_list(categorical_features))
+
+    if not features:
+        raise ValueError("feature_list is empty.")
+
+    missing = [f for f in features if f not in raw.columns]
+    if missing:
+        print(f"[Warn] Missing features removed: {missing}")
+        features = [f for f in features if f in raw.columns]
+        cats = [f for f in cats if f in raw.columns]
+
+    if not cats:
+        cats = _infer_categorical_features(raw, features)
+
+    out_dir = (
+        Path(output_dir).resolve()
+        if output_dir
+        else raw_path.parent / "Results" / "plot" / model_name / "oneway" / "pre"
+    )
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    if holdout_ratio is not None and float(holdout_ratio) > 0:
+        train_df, _ = split_train_test(
+            raw,
+            holdout_ratio=float(holdout_ratio),
+            strategy="random",
+            rand_seed=int(rand_seed),
+            reset_index_mode="none",
+            ratio_label="holdout_ratio",
+        )
+        df = train_df.reset_index(drop=True).copy()
+    else:
+        df = raw.copy()
+
+    print(f"Generating oneway plots for {len(features)} features...")
+    saved = 0
+    for i, feature in enumerate(features, 1):
+        is_categorical = feature in cats
+        try:
+            save_path = out_dir / f"{feature}.png"
+            plot_oneway(
+                df,
+                feature=feature,
+                weight_col=weight_col,
+                target_col=target_col,
+                n_bins=int(n_bins),
+                is_categorical=is_categorical,
+                save_path=str(save_path),
+                show=False,
+            )
+            if save_path.exists():
+                saved += 1
+            if i % 5 == 0 or i == len(features):
+                print(f" [{i}/{len(features)}] {feature}")
+        except Exception as exc:
+            print(f" [Warn] {feature} failed: {exc}")
+
+    print(f"Complete. Saved {saved}/{len(features)} plots to: {out_dir}")
+    return f"Saved {saved} plots to {out_dir}"
+
+
+def run_plot_direct(
+    *,
+    cfg_path: str,
+    xgb_cfg_path: str,
+    resn_cfg_path: str,
+) -> str:
+    cfg_path = Path(cfg_path).resolve()
+    xgb_cfg_path = Path(xgb_cfg_path).resolve()
+    resn_cfg_path = Path(resn_cfg_path).resolve()
+
+    cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
+    xgb_cfg = json.loads(xgb_cfg_path.read_text(encoding="utf-8"))
+    resn_cfg = json.loads(resn_cfg_path.read_text(encoding="utf-8"))
+
+    model_name = f"{cfg['model_list'][0]}_{cfg['model_categories'][0]}"
+
+    raw_data_dir = (cfg_path.parent / cfg["data_dir"]).resolve()
+    raw_path = raw_data_dir / f"{model_name}.csv"
+    raw = pd.read_csv(raw_path)
+    raw = _drop_duplicate_columns(raw, "raw").reset_index(drop=True)
+    raw.fillna(0, inplace=True)
+
+    holdout_ratio = cfg.get("holdout_ratio", cfg.get("prop_test", 0.25))
+    split_strategy = cfg.get("split_strategy", "random")
+    split_group_col = cfg.get("split_group_col")
+    split_time_col = cfg.get("split_time_col")
+    split_time_ascending = cfg.get("split_time_ascending", True)
+    rand_seed = cfg.get("rand_seed", 13)
+
+    train_raw, test_raw = split_train_test(
+        raw,
+        holdout_ratio=holdout_ratio,
+        strategy=split_strategy,
+        group_col=split_group_col,
+        time_col=split_time_col,
+        time_ascending=split_time_ascending,
+        rand_seed=rand_seed,
+        reset_index_mode="none",
+        ratio_label="holdout_ratio",
+    )
+    train_raw = _drop_duplicate_columns(train_raw, "train_raw")
+    test_raw = _drop_duplicate_columns(test_raw, "test_raw")
+
+    train_df = train_raw.copy()
+    test_df = test_raw.copy()
+
+    feature_list = _dedupe_list(cfg.get("feature_list") or [])
+    categorical_features = _dedupe_list(cfg.get("categorical_features") or [])
+
+    output_dir_map = {
+        "xgb": _resolve_output_dir(xgb_cfg, xgb_cfg_path),
+        "resn": _resolve_output_dir(resn_cfg, resn_cfg_path),
+    }
+    plot_path_style_map = {
+        "xgb": _resolve_plot_style(xgb_cfg),
+        "resn": _resolve_plot_style(resn_cfg),
+    }
+
+    def _get_plot_config(model_key: str) -> Tuple[str, str]:
+        return (
+            output_dir_map.get(model_key, _resolve_output_dir(cfg, cfg_path)),
+            plot_path_style_map.get(model_key, _resolve_plot_style(cfg)),
+        )
+
+    def _load_predictor(cfg_path: Path, model_key: str):
+        return load_predictor_from_config(cfg_path, model_key, model_name=model_name)
+
+    model_cfg_map = {"xgb": xgb_cfg_path, "resn": resn_cfg_path}
+    model_keys = cfg.get("model_keys") or ["xgb", "resn"]
+    model_keys = [key for key in model_keys if key in model_cfg_map]
+    if not model_keys:
+        raise ValueError("No valid model keys found in plot config.")
+
+    default_model_labels = {"xgb": "Xgboost", "resn": "ResNet"}
+
+    def _model_label(model_key: str) -> str:
+        labels = cfg.get("model_labels") or {}
+        return str(labels.get(model_key, default_model_labels.get(model_key, model_key)))
+
+    predictors = {key: _load_predictor(model_cfg_map[key], key) for key in model_keys}
+
+    pred_train = {}
+    pred_test = {}
+    for key, predictor in predictors.items():
+        pred_train[key] = predictor.predict(train_df).reshape(-1)
+        pred_test[key] = predictor.predict(test_df).reshape(-1)
+        if len(pred_train[key]) != len(train_df):
+            raise ValueError(f"Train prediction length mismatch for {key}")
+        if len(pred_test[key]) != len(test_df):
+            raise ValueError(f"Test prediction length mismatch for {key}")
+
+    plot_train = train_raw.copy()
+    plot_test = test_raw.copy()
+    for key in model_keys:
+        plot_train[f"pred_{key}"] = pred_train[key]
+        plot_test[f"pred_{key}"] = pred_test[key]
+
+    weight_col = cfg["weight"]
+    target_col = cfg["target"]
+
+    if weight_col not in plot_train.columns:
+        plot_train[weight_col] = 1.0
+    if weight_col not in plot_test.columns:
+        plot_test[weight_col] = 1.0
+    if target_col in plot_train.columns:
+        plot_train["w_act"] = plot_train[target_col] * plot_train[weight_col]
+    if target_col in plot_test.columns:
+        plot_test["w_act"] = plot_test[target_col] * plot_test[weight_col]
+
+    if "w_act" not in plot_train.columns or plot_train["w_act"].isna().all():
+        print("[Plot] Missing target values in train split; skip plots.")
+        return "Skipped plotting due to missing target values."
+
+    n_bins = cfg.get("plot", {}).get("n_bins", 10)
+    oneway_features = feature_list
+    oneway_categorical = set(categorical_features)
+
+    for pred_key in model_keys:
+        pred_label = _model_label(pred_key)
+        pred_col = f"pred_{pred_key}"
+        pred_tag = _safe_tag(pred_label or pred_col)
+        output_root, plot_style = _get_plot_config(pred_key)
+        for feature in oneway_features:
+            if feature not in plot_train.columns:
+                continue
+            save_path = _resolve_plot_path(
+                output_root,
+                plot_style,
+                f"{model_name}/oneway/post",
+                f"00_{model_name}_{feature}_oneway_{pred_tag}.png",
+            )
+            plot_oneway(
+                plot_train,
+                feature=feature,
+                weight_col=weight_col,
+                target_col="w_act",
+                pred_col=pred_col,
+                pred_label=pred_label,
+                n_bins=n_bins,
+                is_categorical=feature in oneway_categorical,
+                save_path=save_path,
+                show=False,
+            )
+
+    datasets = []
+    if "w_act" in plot_train.columns and not plot_train["w_act"].isna().all():
+        datasets.append(("Train Data", plot_train))
+    if "w_act" in plot_test.columns and not plot_test["w_act"].isna().all():
+        datasets.append(("Test Data", plot_test))
+
+    def _plot_lift_for_model(pred_key: str, pred_label: str) -> None:
+        if not datasets:
+            return
+        output_root, plot_style = _get_plot_config(pred_key)
+        style = PlotStyle()
+        fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
+        if len(datasets) == 1:
+            axes = [axes]
+        for ax, (title, data) in zip(axes, datasets):
+            pred_col = f"pred_{pred_key}"
+            if pred_col not in data.columns:
+                continue
+            plot_lift_curve(
+                data[pred_col].values,
+                data["w_act"].values,
+                data[weight_col].values,
+                n_bins=n_bins,
+                title=f"Lift Chart on {title}",
+                pred_label="Predicted",
+                act_label="Actual",
+                weight_label="Earned Exposure",
+                pred_weighted=False,
+                actual_weighted=True,
+                ax=ax,
+                show=False,
+                style=style,
+            )
+        plt.subplots_adjust(wspace=0.3)
+        filename = f"01_{model_name}_{_safe_tag(pred_label)}_lift.png"
+        save_path = _resolve_plot_path(
+            output_root,
+            plot_style,
+            f"{model_name}/lift",
+            filename,
+        )
+        finalize_figure(fig, save_path=save_path, show=False, style=style)
+
+    for pred_key in model_keys:
+        _plot_lift_for_model(pred_key, _model_label(pred_key))
+
+    if (
+        all(k in model_keys for k in ["xgb", "resn"])
+        and all(f"pred_{k}" in plot_train.columns for k in ["xgb", "resn"])
+        and datasets
+    ):
+        style = PlotStyle()
+        fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
+        if len(datasets) == 1:
+            axes = [axes]
+        for ax, (title, data) in zip(axes, datasets):
+            plot_double_lift_curve(
+                data["pred_xgb"].values,
+                data["pred_resn"].values,
+                data["w_act"].values,
+                data[weight_col].values,
+                n_bins=n_bins,
+                title=f"Double Lift Chart on {title}",
+                label1="Xgboost",
+                label2="ResNet",
+                pred1_weighted=False,
+                pred2_weighted=False,
+                actual_weighted=True,
+                ax=ax,
+                show=False,
+                style=style,
+            )
+        plt.subplots_adjust(wspace=0.3)
+        save_path = _resolve_plot_path(
+            _resolve_output_dir(cfg, cfg_path),
+            _resolve_plot_style(cfg),
+            "",
+            f"02_{model_name}_dlift_xgb_vs_resn.png",
+        )
+        finalize_figure(fig, save_path=save_path, show=False, style=style)
+
+    print("Plots saved under:")
+    for key in model_keys:
+        output_root, _ = _get_plot_config(key)
+        print(f" - {key}: {output_root}/plot/{model_name}")
+    return "Plotting complete."
+
+
+def run_plot_embed(
+    *,
+    cfg_path: str,
+    xgb_cfg_path: str,
+    resn_cfg_path: str,
+    ft_cfg_path: str,
+    use_runtime_ft_embedding: bool = False,
+) -> str:
+    cfg_path = Path(cfg_path).resolve()
+    xgb_cfg_path = Path(xgb_cfg_path).resolve()
+    resn_cfg_path = Path(resn_cfg_path).resolve()
+    ft_cfg_path = Path(ft_cfg_path).resolve()
+
+    cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
+    xgb_cfg = json.loads(xgb_cfg_path.read_text(encoding="utf-8"))
+    resn_cfg = json.loads(resn_cfg_path.read_text(encoding="utf-8"))
+    ft_cfg = json.loads(ft_cfg_path.read_text(encoding="utf-8"))
+
+    model_name = f"{cfg['model_list'][0]}_{cfg['model_categories'][0]}"
+
+    raw_data_dir = (ft_cfg_path.parent / ft_cfg["data_dir"]).resolve()
+    raw_path = raw_data_dir / f"{model_name}.csv"
+    raw = pd.read_csv(raw_path)
+    raw = _drop_duplicate_columns(raw, "raw").reset_index(drop=True)
+    raw.fillna(0, inplace=True)
+
+    ft_output_dir = (ft_cfg_path.parent / ft_cfg["output_dir"]).resolve()
+    ft_prefix = ft_cfg.get("ft_feature_prefix", "ft_emb")
+    raw_feature_list = _dedupe_list(ft_cfg.get("feature_list") or [])
+    raw_categorical_features = _dedupe_list(ft_cfg.get("categorical_features") or [])
+
+    if ft_cfg.get("geo_feature_nmes"):
+        raise ValueError("FT inference with geo tokens is not supported in this workflow.")
+
+    holdout_ratio = cfg.get("holdout_ratio", cfg.get("prop_test", 0.25))
+    split_strategy = cfg.get("split_strategy", "random")
+    split_group_col = cfg.get("split_group_col")
+    split_time_col = cfg.get("split_time_col")
+    split_time_ascending = cfg.get("split_time_ascending", True)
+    rand_seed = cfg.get("rand_seed", 13)
+
+    train_raw, test_raw = split_train_test(
+        raw,
+        holdout_ratio=holdout_ratio,
+        strategy=split_strategy,
+        group_col=split_group_col,
+        time_col=split_time_col,
+        time_ascending=split_time_ascending,
+        rand_seed=rand_seed,
+        reset_index_mode="none",
+        ratio_label="holdout_ratio",
+    )
+    train_raw = _drop_duplicate_columns(train_raw, "train_raw")
+    test_raw = _drop_duplicate_columns(test_raw, "test_raw")
+
+    if use_runtime_ft_embedding:
+        import torch
+
+        ft_model_path = ft_output_dir / "model" / f"01_{model_name}_FTTransformer.pth"
+        ft_payload = torch.load(ft_model_path, map_location="cpu")
+        ft_model = ft_payload["model"] if isinstance(ft_payload, dict) and "model" in ft_payload else ft_payload
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if hasattr(ft_model, "device"):
+            ft_model.device = device
+        if hasattr(ft_model, "to"):
+            ft_model.to(device)
+        if hasattr(ft_model, "ft"):
+            ft_model.ft.to(device)
+
+        emb_train = ft_model.predict(train_raw, return_embedding=True)
+        emb_cols = [f"pred_{ft_prefix}_{i}" for i in range(emb_train.shape[1])]
+        train_df = train_raw.copy()
+        train_df[emb_cols] = emb_train
+
+        emb_test = ft_model.predict(test_raw, return_embedding=True)
+        test_df = test_raw.copy()
+        test_df[emb_cols] = emb_test
+    else:
+        embed_data_dir = (cfg_path.parent / cfg["data_dir"]).resolve()
+        embed_path = embed_data_dir / f"{model_name}.csv"
+        embed_df = pd.read_csv(embed_path)
+        embed_df = _drop_duplicate_columns(embed_df, "embed").reset_index(drop=True)
+        embed_df.fillna(0, inplace=True)
+        if len(embed_df) != len(raw):
+            raise ValueError(
+                f"Row count mismatch: raw={len(raw)}, embed={len(embed_df)}. "
+                "Cannot align predictions to raw features."
+            )
+        train_df = embed_df.loc[train_raw.index].copy()
+        test_df = embed_df.loc[test_raw.index].copy()
+
+    feature_list = _dedupe_list(cfg.get("feature_list") or [])
+    categorical_features = _dedupe_list(cfg.get("categorical_features") or [])
+
+    output_dir_map = {
+        "xgb": _resolve_output_dir(xgb_cfg, xgb_cfg_path),
+        "resn": _resolve_output_dir(resn_cfg, resn_cfg_path),
+    }
+    plot_path_style_map = {
+        "xgb": _resolve_plot_style(xgb_cfg),
+        "resn": _resolve_plot_style(resn_cfg),
+    }
+
+    def _get_plot_config(model_key: str) -> Tuple[str, str]:
+        return (
+            output_dir_map.get(model_key, _resolve_output_dir(cfg, cfg_path)),
+            plot_path_style_map.get(model_key, _resolve_plot_style(cfg)),
+        )
+
+    def _load_predictor(cfg_path: Path, model_key: str):
+        return load_predictor_from_config(cfg_path, model_key, model_name=model_name)
+
+    model_cfg_map = {"xgb": xgb_cfg_path, "resn": resn_cfg_path}
+    model_keys = cfg.get("model_keys") or ["xgb", "resn"]
+    model_keys = [key for key in model_keys if key in model_cfg_map]
+    if not model_keys:
+        raise ValueError("No valid model keys found in plot config.")
+
+    default_model_labels = {"xgb": "Xgboost", "resn": "ResNet"}
+
+    def _model_label(model_key: str) -> str:
+        labels = cfg.get("model_labels") or {}
+        return str(labels.get(model_key, default_model_labels.get(model_key, model_key)))
+
+    predictors = {key: _load_predictor(model_cfg_map[key], key) for key in model_keys}
+    pred_train = {}
+    pred_test = {}
+    for key, predictor in predictors.items():
+        pred_train[key] = predictor.predict(train_df).reshape(-1)
+        pred_test[key] = predictor.predict(test_df).reshape(-1)
+        if len(pred_train[key]) != len(train_df):
+            raise ValueError(f"Train prediction length mismatch for {key}")
+        if len(pred_test[key]) != len(test_df):
+            raise ValueError(f"Test prediction length mismatch for {key}")
+
+    plot_train = train_raw.copy()
+    plot_test = test_raw.copy()
+    for key in model_keys:
+        plot_train[f"pred_{key}"] = pred_train[key]
+        plot_test[f"pred_{key}"] = pred_test[key]
+
+    weight_col = cfg["weight"]
+    target_col = cfg["target"]
+    if weight_col not in plot_train.columns:
+        plot_train[weight_col] = 1.0
+    if weight_col not in plot_test.columns:
+        plot_test[weight_col] = 1.0
+    if target_col in plot_train.columns:
+        plot_train["w_act"] = plot_train[target_col] * plot_train[weight_col]
+    if target_col in plot_test.columns:
+        plot_test["w_act"] = plot_test[target_col] * plot_test[weight_col]
+
+    if "w_act" not in plot_train.columns or plot_train["w_act"].isna().all():
+        print("[Plot] Missing target values in train split; skip plots.")
+        return "Skipped plotting due to missing target values."
+
+    n_bins = cfg.get("plot", {}).get("n_bins", 10)
+    oneway_features = raw_feature_list or feature_list
+    oneway_categorical = set(raw_categorical_features or categorical_features)
+
+    for pred_key in model_keys:
+        pred_label = _model_label(pred_key)
+        pred_col = f"pred_{pred_key}"
+        pred_tag = _safe_tag(pred_label or pred_col)
+        output_root, plot_style = _get_plot_config(pred_key)
+        for feature in oneway_features:
+            if feature not in plot_train.columns:
+                continue
+            save_path = _resolve_plot_path(
+                output_root,
+                plot_style,
+                f"{model_name}/oneway/post",
+                f"00_{model_name}_{feature}_oneway_{pred_tag}.png",
+            )
+            plot_oneway(
+                plot_train,
+                feature=feature,
+                weight_col=weight_col,
+                target_col="w_act",
+                pred_col=pred_col,
+                pred_label=pred_label,
+                n_bins=n_bins,
+                is_categorical=feature in oneway_categorical,
+                save_path=save_path,
+                show=False,
+            )
+
+    datasets = []
+    if "w_act" in plot_train.columns and not plot_train["w_act"].isna().all():
+        datasets.append(("Train Data", plot_train))
+    if "w_act" in plot_test.columns and not plot_test["w_act"].isna().all():
+        datasets.append(("Test Data", plot_test))
+
+    def _plot_lift_for_model(pred_key: str, pred_label: str) -> None:
+        if not datasets:
+            return
+        output_root, plot_style = _get_plot_config(pred_key)
+        style = PlotStyle()
+        fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
+        if len(datasets) == 1:
+            axes = [axes]
+        for ax, (title, data) in zip(axes, datasets):
+            pred_col = f"pred_{pred_key}"
+            if pred_col not in data.columns:
+                continue
+            plot_lift_curve(
+                data[pred_col].values,
+                data["w_act"].values,
+                data[weight_col].values,
+                n_bins=n_bins,
+                title=f"Lift Chart on {title}",
+                pred_label="Predicted",
+                act_label="Actual",
+                weight_label="Earned Exposure",
+                pred_weighted=False,
+                actual_weighted=True,
+                ax=ax,
+                show=False,
+                style=style,
+            )
+        plt.subplots_adjust(wspace=0.3)
+        filename = f"01_{model_name}_{_safe_tag(pred_label)}_lift.png"
+        save_path = _resolve_plot_path(
+            output_root,
+            plot_style,
+            f"{model_name}/lift",
+            filename,
+        )
+        finalize_figure(fig, save_path=save_path, show=False, style=style)
+
+    for pred_key in model_keys:
+        _plot_lift_for_model(pred_key, _model_label(pred_key))
+
+    if (
+        all(k in model_keys for k in ["xgb", "resn"])
+        and all(f"pred_{k}" in plot_train.columns for k in ["xgb", "resn"])
+        and datasets
+    ):
+        style = PlotStyle()
+        fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
+        if len(datasets) == 1:
+            axes = [axes]
+        for ax, (title, data) in zip(axes, datasets):
+            plot_double_lift_curve(
+                data["pred_xgb"].values,
+                data["pred_resn"].values,
+                data["w_act"].values,
+                data[weight_col].values,
+                n_bins=n_bins,
+                title=f"Double Lift Chart on {title}",
+                label1="Xgboost",
+                label2="ResNet",
+                pred1_weighted=False,
+                pred2_weighted=False,
+                actual_weighted=True,
+                ax=ax,
+                show=False,
+                style=style,
+            )
+        plt.subplots_adjust(wspace=0.3)
+        save_path = _resolve_plot_path(
+            _resolve_output_dir(cfg, cfg_path),
+            _resolve_plot_style(cfg),
+            "",
+            f"02_{model_name}_dlift_xgb_vs_resn.png",
+        )
+        finalize_figure(fig, save_path=save_path, show=False, style=style)
+
+    print("Plots saved under:")
+    for key in model_keys:
+        output_root, _ = _get_plot_config(key)
+        print(f" - {key}: {output_root}/plot/{model_name}")
+    return "Plotting complete."
+
+
+def run_predict_ft_embed(
+    *,
+    ft_cfg_path: str,
+    xgb_cfg_path: Optional[str],
+    resn_cfg_path: Optional[str],
+    input_path: str,
+    output_path: str,
+    model_name: Optional[str],
+    model_keys: str,
+) -> str:
+    ft_cfg_path = Path(ft_cfg_path).resolve()
+    xgb_cfg_path = Path(xgb_cfg_path).resolve() if xgb_cfg_path else None
+    resn_cfg_path = Path(resn_cfg_path).resolve() if resn_cfg_path else None
+    input_path = Path(input_path).resolve()
+    output_path = Path(output_path).resolve()
+
+    if not input_path.exists():
+        raise FileNotFoundError(f"Input data not found: {input_path}")
+
+    keys = [k.strip() for k in model_keys.split(",") if k.strip()]
+    if not keys:
+        raise ValueError("model_keys is empty.")
+
+    ft_cfg = json.loads(ft_cfg_path.read_text(encoding="utf-8"))
+    xgb_cfg = json.loads(xgb_cfg_path.read_text(encoding="utf-8")) if xgb_cfg_path else None
+    resn_cfg = json.loads(resn_cfg_path.read_text(encoding="utf-8")) if resn_cfg_path else None
+
+    if model_name is None:
+        model_list = list(ft_cfg.get("model_list") or [])
+        model_categories = list(ft_cfg.get("model_categories") or [])
+        if len(model_list) != 1 or len(model_categories) != 1:
+            raise ValueError("Set model_name when multiple models exist.")
+        model_name = f"{model_list[0]}_{model_categories[0]}"
+
+    ft_output_dir = (ft_cfg_path.parent / ft_cfg["output_dir"]).resolve()
+    xgb_output_dir = (xgb_cfg_path.parent / xgb_cfg["output_dir"]).resolve() if xgb_cfg else None
+    ft_prefix = ft_cfg.get("ft_feature_prefix", "ft_emb")
+    xgb_task_type = str(xgb_cfg.get("task_type", "regression")) if xgb_cfg else None
+
+    if ft_cfg.get("geo_feature_nmes"):
+        raise ValueError("FT with geo tokens is not supported in this workflow.")
+
+    import torch
+    import joblib
+
+    print("Loading FT model...")
+    ft_model_path = ft_output_dir / "model" / f"01_{model_name}_FTTransformer.pth"
+    ft_payload = torch.load(ft_model_path, map_location="cpu")
+    ft_model = ft_payload["model"] if isinstance(ft_payload, dict) and "model" in ft_payload else ft_payload
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if hasattr(ft_model, "device"):
+        ft_model.device = device
+    if hasattr(ft_model, "to"):
+        ft_model.to(device)
+    if hasattr(ft_model, "ft"):
+        ft_model.ft.to(device)
+
+    df_new = pd.read_csv(input_path)
+    emb = ft_model.predict(df_new, return_embedding=True)
+    emb_cols = [f"pred_{ft_prefix}_{i}" for i in range(emb.shape[1])]
+    df_with_emb = df_new.copy()
+    df_with_emb[emb_cols] = emb
+    result = df_with_emb.copy()
+
+    if "xgb" in keys:
+        if not xgb_cfg or not xgb_output_dir:
+            raise ValueError("xgb model selected but xgb_cfg_path is missing.")
+        xgb_model_path = xgb_output_dir / "model" / f"01_{model_name}_Xgboost.pkl"
+        xgb_payload = joblib.load(xgb_model_path)
+        if isinstance(xgb_payload, dict) and "model" in xgb_payload:
+            xgb_model = xgb_payload["model"]
+            feature_list = xgb_payload.get("preprocess_artifacts", {}).get("factor_nmes")
+        else:
+            xgb_model = xgb_payload
+            feature_list = None
+        if not feature_list:
+            feature_list = xgb_cfg.get("feature_list") or []
+        if not feature_list:
+            raise ValueError("Feature list missing for XGB model.")
+
+        X = df_with_emb[feature_list]
+        if xgb_task_type == "classification" and hasattr(xgb_model, "predict_proba"):
+            pred = xgb_model.predict_proba(X)[:, 1]
+        else:
+            pred = xgb_model.predict(X)
+        result["pred_xgb"] = pred
+
+    if "resn" in keys:
+        if not resn_cfg_path:
+            raise ValueError("resn model selected but resn_cfg_path is missing.")
+        resn_predictor = load_predictor_from_config(
+            resn_cfg_path, "resn", model_name=model_name
+        )
+        pred_resn = resn_predictor.predict(df_with_emb)
+        result["pred_resn"] = pred_resn
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    result.to_csv(output_path, index=False)
+    print(f"Saved predictions to: {output_path}")
+    return str(output_path)
+
+
+def run_compare_ft_embed(
+    *,
+    direct_cfg_path: str,
+    ft_cfg_path: str,
+    ft_embed_cfg_path: str,
+    model_key: str,
+    label_direct: str,
+    label_ft: str,
+    use_runtime_ft_embedding: bool = False,
+    n_bins_override: Optional[int] = 10,
+) -> str:
+    direct_cfg_path = Path(direct_cfg_path).resolve()
+    ft_cfg_path = Path(ft_cfg_path).resolve()
+    ft_embed_cfg_path = Path(ft_embed_cfg_path).resolve()
+
+    direct_cfg = json.loads(direct_cfg_path.read_text(encoding="utf-8"))
+    ft_embed_cfg = json.loads(ft_embed_cfg_path.read_text(encoding="utf-8"))
+    ft_cfg = json.loads(ft_cfg_path.read_text(encoding="utf-8"))
+
+    model_name = f"{direct_cfg['model_list'][0]}_{direct_cfg['model_categories'][0]}"
+
+    raw_path = _resolve_data_path(direct_cfg, direct_cfg_path, model_name)
+    raw = pd.read_csv(raw_path)
+    raw = _drop_duplicate_columns(raw, "raw").reset_index(drop=True)
+    raw.fillna(0, inplace=True)
+
+    holdout_ratio = direct_cfg.get("holdout_ratio", direct_cfg.get("prop_test", 0.25))
+    split_strategy = direct_cfg.get("split_strategy", "random")
+    split_group_col = direct_cfg.get("split_group_col")
+    split_time_col = direct_cfg.get("split_time_col")
+    split_time_ascending = direct_cfg.get("split_time_ascending", True)
+    rand_seed = direct_cfg.get("rand_seed", 13)
+
+    train_raw, test_raw = split_train_test(
+        raw,
+        holdout_ratio=holdout_ratio,
+        strategy=split_strategy,
+        group_col=split_group_col,
+        time_col=split_time_col,
+        time_ascending=split_time_ascending,
+        rand_seed=rand_seed,
+        reset_index_mode="none",
+        ratio_label="holdout_ratio",
+    )
+    train_raw = _drop_duplicate_columns(train_raw, "train_raw")
+    test_raw = _drop_duplicate_columns(test_raw, "test_raw")
+
+    ft_output_dir = (ft_cfg_path.parent / ft_cfg["output_dir"]).resolve()
+    ft_prefix = ft_cfg.get("ft_feature_prefix", "ft_emb")
+
+    if use_runtime_ft_embedding:
+        import torch
+
+        ft_model_path = ft_output_dir / "model" / f"01_{model_name}_FTTransformer.pth"
+        ft_payload = torch.load(ft_model_path, map_location="cpu")
+        ft_model = ft_payload["model"] if isinstance(ft_payload, dict) and "model" in ft_payload else ft_payload
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if hasattr(ft_model, "device"):
+            ft_model.device = device
+        if hasattr(ft_model, "to"):
+            ft_model.to(device)
+        if hasattr(ft_model, "ft"):
+            ft_model.ft.to(device)
+
+        emb_train = ft_model.predict(train_raw, return_embedding=True)
+        emb_cols = [f"pred_{ft_prefix}_{i}" for i in range(emb_train.shape[1])]
+        train_df = train_raw.copy()
+        train_df[emb_cols] = emb_train
+
+        emb_test = ft_model.predict(test_raw, return_embedding=True)
+        test_df = test_raw.copy()
+        test_df[emb_cols] = emb_test
+    else:
+        embed_path = _resolve_data_path(ft_embed_cfg, ft_embed_cfg_path, model_name)
+        embed_df = pd.read_csv(embed_path)
+        embed_df = _drop_duplicate_columns(embed_df, "embed").reset_index(drop=True)
+        embed_df.fillna(0, inplace=True)
+        if len(embed_df) != len(raw):
+            raise ValueError(
+                f"Row count mismatch: raw={len(raw)}, embed={len(embed_df)}. "
+                "Cannot align predictions to raw features."
+            )
+        train_df = embed_df.loc[train_raw.index].copy()
+        test_df = embed_df.loc[test_raw.index].copy()
+
+    direct_predictor = load_predictor_from_config(
+        direct_cfg_path, model_key, model_name=model_name
+    )
+    ft_predictor = load_predictor_from_config(
+        ft_embed_cfg_path, model_key, model_name=model_name
+    )
+
+    pred_direct_train = direct_predictor.predict(train_raw).reshape(-1)
+    pred_direct_test = direct_predictor.predict(test_raw).reshape(-1)
+    pred_ft_train = ft_predictor.predict(train_df).reshape(-1)
+    pred_ft_test = ft_predictor.predict(test_df).reshape(-1)
+
+    if len(pred_direct_train) != len(train_raw):
+        raise ValueError("Train prediction length mismatch for direct model.")
+    if len(pred_direct_test) != len(test_raw):
+        raise ValueError("Test prediction length mismatch for direct model.")
+    if len(pred_ft_train) != len(train_df):
+        raise ValueError("Train prediction length mismatch for FT-embed model.")
+    if len(pred_ft_test) != len(test_df):
+        raise ValueError("Test prediction length mismatch for FT-embed model.")
+
+    plot_train = train_raw.copy()
+    plot_test = test_raw.copy()
+    plot_train["pred_direct"] = pred_direct_train
+    plot_train["pred_ft"] = pred_ft_train
+    plot_test["pred_direct"] = pred_direct_test
+    plot_test["pred_ft"] = pred_ft_test
+
+    weight_col = direct_cfg["weight"]
+    target_col = direct_cfg["target"]
+    if weight_col not in plot_train.columns:
+        plot_train[weight_col] = 1.0
+    if weight_col not in plot_test.columns:
+        plot_test[weight_col] = 1.0
+    if target_col in plot_train.columns:
+        plot_train["w_act"] = plot_train[target_col] * plot_train[weight_col]
+    if target_col in plot_test.columns:
+        plot_test["w_act"] = plot_test[target_col] * plot_test[weight_col]
+
+    if "w_act" not in plot_train.columns or plot_train["w_act"].isna().all():
+        print("[Plot] Missing target values in train split; skip plots.")
+        return "Skipped plotting due to missing target values."
+
+    n_bins = n_bins_override or direct_cfg.get("plot", {}).get("n_bins", 10)
+    datasets = []
+    if not plot_train["w_act"].isna().all():
+        datasets.append(("Train Data", plot_train))
+    if not plot_test["w_act"].isna().all():
+        datasets.append(("Test Data", plot_test))
+
+    style = PlotStyle()
+    fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
+    if len(datasets) == 1:
+        axes = [axes]
+    for ax, (title, data) in zip(axes, datasets):
+        plot_double_lift_curve(
+            data["pred_direct"].values,
+            data["pred_ft"].values,
+            data["w_act"].values,
+            data[weight_col].values,
+            n_bins=n_bins,
+            title=f"Double Lift Chart on {title}",
+            label1=label_direct,
+            label2=label_ft,
+            pred1_weighted=False,
+            pred2_weighted=False,
+            actual_weighted=True,
+            ax=ax,
+            show=False,
+            style=style,
+        )
+    plt.subplots_adjust(wspace=0.3)
+
+    output_root = _resolve_output_dir(direct_cfg, direct_cfg_path)
+    plot_style = _resolve_plot_style(direct_cfg)
+    filename = (
+        f"01_{model_name}_dlift_"
+        f"{_safe_tag(label_direct)}_vs_{_safe_tag(label_ft)}.png"
+    )
+    save_path = _resolve_plot_path(
+        output_root,
+        plot_style,
+        f"{model_name}/double_lift",
+        filename,
+    )
+    finalize_figure(fig, save_path=save_path, show=False, style=style)
+    print(f"Double lift saved to: {save_path}")
+    return str(save_path)
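
For orientation, a minimal sketch of how the new example_workflows module can be driven from Python once the 0.4.0 wheel is installed. The call below mirrors the run_pre_oneway signature added in this release; the CSV path and column names are illustrative placeholders, not values shipped with the package:

    from ins_pricing.frontend.example_workflows import run_pre_oneway

    # Pre-modelling oneway plots; all arguments are keyword-only and the
    # feature lists are comma-separated strings, as in the frontend forms.
    summary = run_pre_oneway(
        data_path="data/motor_freq.csv",      # hypothetical input CSV
        model_name="motor_freq",              # hypothetical model name
        target_col="claim_count",
        weight_col="exposure",
        feature_list="driver_age,vehicle_age,region",
        categorical_features="region",
        n_bins=10,
        holdout_ratio=0.25,
    )
    print(summary)  # e.g. "Saved 3 plots to .../Results/plot/motor_freq/oneway/pre"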