ins-pricing 0.4.5-py3-none-any.whl → 0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +58 -46
  4. ins_pricing/cli/BayesOpt_incremental.py +77 -110
  5. ins_pricing/cli/Explain_Run.py +42 -23
  6. ins_pricing/cli/Explain_entry.py +551 -577
  7. ins_pricing/cli/Pricing_Run.py +42 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +51 -16
  9. ins_pricing/cli/utils/bootstrap.py +23 -0
  10. ins_pricing/cli/utils/cli_common.py +256 -256
  11. ins_pricing/cli/utils/cli_config.py +379 -360
  12. ins_pricing/cli/utils/import_resolver.py +375 -358
  13. ins_pricing/cli/utils/notebook_utils.py +256 -242
  14. ins_pricing/cli/watchdog_run.py +216 -198
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/app.py +132 -61
  17. ins_pricing/frontend/config_builder.py +33 -0
  18. ins_pricing/frontend/example_config.json +11 -0
  19. ins_pricing/frontend/example_workflows.py +1 -1
  20. ins_pricing/frontend/runner.py +340 -388
  21. ins_pricing/governance/__init__.py +20 -20
  22. ins_pricing/governance/release.py +159 -159
  23. ins_pricing/modelling/README.md +1 -1
  24. ins_pricing/modelling/__init__.py +147 -92
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
  37. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
  39. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
  40. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
  42. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
  43. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
  44. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  45. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  46. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
  47. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  48. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  49. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  50. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
  51. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  52. ins_pricing/modelling/explain/__init__.py +55 -55
  53. ins_pricing/modelling/explain/metrics.py +27 -174
  54. ins_pricing/modelling/explain/permutation.py +237 -237
  55. ins_pricing/modelling/plotting/__init__.py +40 -36
  56. ins_pricing/modelling/plotting/compat.py +228 -0
  57. ins_pricing/modelling/plotting/curves.py +572 -572
  58. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  59. ins_pricing/modelling/plotting/geo.py +362 -362
  60. ins_pricing/modelling/plotting/importance.py +121 -121
  61. ins_pricing/pricing/__init__.py +27 -27
  62. ins_pricing/pricing/factors.py +67 -56
  63. ins_pricing/production/__init__.py +35 -25
  64. ins_pricing/production/{predict.py → inference.py} +140 -57
  65. ins_pricing/production/monitoring.py +8 -21
  66. ins_pricing/reporting/__init__.py +11 -11
  67. ins_pricing/setup.py +1 -1
  68. ins_pricing/tests/production/test_inference.py +90 -0
  69. ins_pricing/utils/__init__.py +112 -78
  70. ins_pricing/utils/device.py +258 -237
  71. ins_pricing/utils/features.py +53 -0
  72. ins_pricing/utils/io.py +72 -0
  73. ins_pricing/utils/logging.py +34 -1
  74. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  75. ins_pricing/utils/metrics.py +158 -24
  76. ins_pricing/utils/numerics.py +76 -0
  77. ins_pricing/utils/paths.py +9 -1
  78. ins_pricing/utils/profiling.py +8 -4
  79. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
  80. ins_pricing-0.5.1.dist-info/RECORD +132 -0
  81. ins_pricing/modelling/core/BayesOpt.py +0 -146
  82. ins_pricing/modelling/core/__init__.py +0 -1
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  92. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  93. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
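The dominant change in this release is a package flattening: everything under ins_pricing/modelling/core/bayesopt moves up to ins_pricing/modelling/bayesopt, production/predict.py becomes production/inference.py (with the test module renamed to match), and the old core/bayesopt utils modules, including the utils_backup.py shown below, are deleted in favour of per-topic modules under ins_pricing/modelling/bayesopt/utils/ and shared helpers in ins_pricing/utils/. Downstream imports written against 0.4.5 will need updating; a hypothetical compatibility shim, assuming the public names survived the move (verify against the 0.5.1 sources):

try:
    from ins_pricing.modelling.bayesopt import core as bayesopt_core  # 0.5.1 layout
except ImportError:
    from ins_pricing.modelling.core.bayesopt import core as bayesopt_core  # 0.4.5 layout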
ins_pricing/modelling/core/bayesopt/utils_backup.py
@@ -1,1503 +0,0 @@
- from __future__ import annotations
-
- import csv
- import ctypes
- import copy
- import gc
- import json
- import math
- import os
- import random
- import time
- from contextlib import nullcontext
- from datetime import timedelta
- from pathlib import Path
- from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
-
- try: # matplotlib is optional; avoid hard import failures in headless/minimal envs
-     import matplotlib
-     if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
-         matplotlib.use("Agg")
-     import matplotlib.pyplot as plt
-     _MPL_IMPORT_ERROR: Optional[BaseException] = None
- except Exception as exc: # pragma: no cover - optional dependency
-     matplotlib = None # type: ignore[assignment]
-     plt = None # type: ignore[assignment]
-     _MPL_IMPORT_ERROR = exc
- import numpy as np
- import optuna
- import pandas as pd
- import torch
- import torch.distributed as dist
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.cuda.amp import autocast, GradScaler
- from torch.nn.parallel import DistributedDataParallel as DDP
- from torch.nn.utils import clip_grad_norm_
- from torch.utils.data import DataLoader, DistributedSampler
-
- # Optional: unify plotting with shared plotting package
- try:
-     from ...plotting import curves as plot_curves_common
-     from ...plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
- except Exception: # pragma: no cover
-     try:
-         from ins_pricing.plotting import curves as plot_curves_common
-         from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
-     except Exception: # pragma: no cover
-         plot_curves_common = None
-         plot_loss_curve_common = None
- # Limit CUDA allocator split size to reduce fragmentation and OOM risk.
- # Override via PYTORCH_CUDA_ALLOC_CONF if needed.
- os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:256")
-
-
- # Constants and utility helpers
- # =============================================================================
- torch.backends.cudnn.benchmark = True
- EPS = 1e-8
-
-
- def _plot_skip(label: str) -> None:
-     if _MPL_IMPORT_ERROR is not None:
-         print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
-     else:
-         print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
-
-
- def set_global_seed(seed: int) -> None:
-     random.seed(seed)
-     np.random.seed(seed)
-     torch.manual_seed(seed)
-     if torch.cuda.is_available():
-         torch.cuda.manual_seed_all(seed)
-
-
- class IOUtils:
-     # File and path utilities.
-
-     @staticmethod
-     def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
-         with open(file_path, mode='r', encoding='utf-8') as file:
-             reader = csv.DictReader(file)
-             return [
-                 dict(filter(lambda item: item[0] != '', row.items()))
-                 for row in reader
-             ]
-
-     @staticmethod
-     def _sanitize_params_dict(params: Dict[str, Any]) -> Dict[str, Any]:
-         # Filter index-like columns such as "Unnamed: 0" from pandas I/O.
-         return {
-             k: v
-             for k, v in (params or {}).items()
-             if k and not str(k).startswith("Unnamed")
-         }
-
-     @staticmethod
-     def load_params_file(path: str) -> Dict[str, Any]:
-         """Load parameter dict from JSON/CSV/TSV files.
-
-         - JSON: accept dict or {"best_params": {...}} wrapper
-         - CSV/TSV: read the first row as params
-         """
-         file_path = Path(path).expanduser().resolve()
-         if not file_path.exists():
-             raise FileNotFoundError(f"params file not found: {file_path}")
-         suffix = file_path.suffix.lower()
-         if suffix == ".json":
-             payload = json.loads(file_path.read_text(
-                 encoding="utf-8", errors="replace"))
-             if isinstance(payload, dict) and "best_params" in payload:
-                 payload = payload.get("best_params") or {}
-             if not isinstance(payload, dict):
-                 raise ValueError(
-                     f"Invalid JSON params file (expect dict): {file_path}")
-             return IOUtils._sanitize_params_dict(dict(payload))
-         if suffix in (".csv", ".tsv"):
-             df = pd.read_csv(file_path, sep="\t" if suffix == ".tsv" else ",")
-             if df.empty:
-                 raise ValueError(f"Empty params file: {file_path}")
-             params = df.iloc[0].to_dict()
-             return IOUtils._sanitize_params_dict(params)
-         raise ValueError(
-             f"Unsupported params file type '{suffix}': {file_path}")
-
-     @staticmethod
-     def ensure_parent_dir(file_path: str) -> None:
-         # Create parent directories when missing.
-         directory = os.path.dirname(file_path)
-         if directory:
-             os.makedirs(directory, exist_ok=True)
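A minimal usage sketch for the loader above (file contents illustrative):

# params.json containing {"best_params": {"learning_rate": 0.001, "depth": 6}}
# and a one-row params.csv with columns learning_rate,depth both normalize to:
params = IOUtils.load_params_file("params.json")
# -> {"learning_rate": 0.001, "depth": 6}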
-
-
- class TrainingUtils:
-     # Small helpers used during training.
-
-     @staticmethod
-     def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
-         estimated = int((learning_rate / 1e-4) ** 0.5 *
-                         (data_size / max(batch_num, 1)))
-         return max(1, min(data_size, max(minimum, estimated)))
-
-     @staticmethod
-     def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
-         # Clamp predictions to positive values for stability.
-         pred_clamped = torch.clamp(pred, min=eps)
-         if p == 1:
-             term1 = target * torch.log(target / pred_clamped + eps) # Poisson
-             term2 = -target + pred_clamped
-             term3 = 0
-         elif p == 0:
-             term1 = 0.5 * torch.pow(target - pred_clamped, 2) # Gaussian
-             term2 = 0
-             term3 = 0
-         elif p == 2:
-             term1 = torch.log(pred_clamped / target + eps) # Gamma
-             term2 = -target / pred_clamped + 1
-             term3 = 0
-         else:
-             term1 = torch.pow(target, 2 - p) / ((1 - p) * (2 - p))
-             term2 = target * torch.pow(pred_clamped, 1 - p) / (1 - p)
-             term3 = torch.pow(pred_clamped, 2 - p) / (2 - p)
-         return torch.nan_to_num( # Tweedie negative log-likelihood (constant omitted)
-             2 * (term1 - term2 + term3),
-             nan=eps,
-             posinf=max_clip,
-             neginf=-max_clip
-         )
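For reference, the final else branch above implements the unit Tweedie deviance for p outside {0, 1, 2}, matching the scikit-learn definition cited later in this file, with target y and clamped prediction \mu:

d_p(y, \mu) = 2\left(\frac{y^{2-p}}{(1-p)(2-p)} - \frac{y\,\mu^{1-p}}{1-p} + \frac{\mu^{2-p}}{2-p}\right)

torch.nan_to_num then bounds the result to [-max_clip, max_clip].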
-
-     @staticmethod
-     def free_cuda() -> None:
-         print(">>> Moving all models to CPU...")
-         for obj in gc.get_objects():
-             try:
-                 if hasattr(obj, "to") and callable(obj.to):
-                     obj.to("cpu")
-             except Exception:
-                 pass
-
-         print(">>> Releasing tensor/optimizer/DataLoader references...")
-         gc.collect()
-
-         print(">>> Clearing CUDA cache...")
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-             torch.cuda.synchronize()
-             print(">>> CUDA memory released.")
-         else:
-             print(">>> CUDA not available; cleanup skipped.")
-
-
- class DistributedUtils:
-     _cached_state: Optional[tuple] = None
-
-     @staticmethod
-     def setup_ddp():
-         """Initialize the DDP process group for distributed training."""
-         if dist.is_initialized():
-             if DistributedUtils._cached_state is None:
-                 rank = dist.get_rank()
-                 world_size = dist.get_world_size()
-                 local_rank = int(os.environ.get("LOCAL_RANK", 0))
-                 DistributedUtils._cached_state = (
-                     True,
-                     local_rank,
-                     rank,
-                     world_size,
-                 )
-             return DistributedUtils._cached_state
-
-         if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
-             rank = int(os.environ["RANK"])
-             world_size = int(os.environ["WORLD_SIZE"])
-             local_rank = int(os.environ["LOCAL_RANK"])
-
-             if os.name == "nt" and torch.cuda.is_available() and world_size > 1:
-                 print(
-                     ">>> DDP Setup Disabled: Windows CUDA DDP is not supported. "
-                     "Falling back to single process."
-                 )
-                 return False, 0, 0, 1
-
-             if torch.cuda.is_available():
-                 torch.cuda.set_device(local_rank)
-
-             timeout_seconds = int(os.environ.get(
-                 "BAYESOPT_DDP_TIMEOUT_SECONDS", "1800"))
-             timeout = timedelta(seconds=max(1, timeout_seconds))
-             backend = "gloo"
-             if torch.cuda.is_available() and os.name != "nt":
-                 try:
-                     if getattr(dist, "is_nccl_available", lambda: False)():
-                         backend = "nccl"
-                 except Exception:
-                     backend = "gloo"
-
-             dist.init_process_group(
-                 backend=backend, init_method="env://", timeout=timeout)
-             print(
-                 f">>> DDP Initialized ({backend}, timeout={timeout_seconds}s): "
-                 f"Rank {rank}/{world_size}, Local Rank {local_rank}"
-             )
-             DistributedUtils._cached_state = (
-                 True,
-                 local_rank,
-                 rank,
-                 world_size,
-             )
-             return DistributedUtils._cached_state
-         else:
-             print(
-                 f">>> DDP Setup Failed: RANK or WORLD_SIZE not found in env. Keys found: {list(os.environ.keys())}"
-             )
-             print(
-                 ">>> Hint: launch with torchrun --nproc_per_node=<N> <script.py>"
-             )
-             return False, 0, 0, 1
-
-     @staticmethod
-     def cleanup_ddp():
-         """Destroy the DDP process group and clear cached state."""
-         if dist.is_initialized():
-             dist.destroy_process_group()
-         DistributedUtils._cached_state = None
-
-     @staticmethod
-     def is_main_process():
-         return not dist.is_initialized() or dist.get_rank() == 0
-
-     @staticmethod
-     def world_size() -> int:
-         return dist.get_world_size() if dist.is_initialized() else 1
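A minimal sketch of how these helpers compose in a training entry point (the script name and training body are placeholders; torchrun supplies the RANK, WORLD_SIZE and LOCAL_RANK environment variables checked above):

# launched as: torchrun --nproc_per_node=<N> train.py
is_ddp, local_rank, rank, world_size = DistributedUtils.setup_ddp()
try:
    ...  # build the model, wrap it in DDP when is_ddp, run the epochs
    if DistributedUtils.is_main_process():
        ...  # checkpoint and report metrics from rank 0 only
finally:
    DistributedUtils.cleanup_ddp()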
-
-
- class PlotUtils:
-     # Plot helpers shared across models.
-
-     @staticmethod
-     def split_data(data: pd.DataFrame, col_nme: str, wgt_nme: str, n_bins: int = 10) -> pd.DataFrame:
-         data_sorted = data.sort_values(by=col_nme, ascending=True).copy()
-         data_sorted['cum_weight'] = data_sorted[wgt_nme].cumsum()
-         w_sum = data_sorted[wgt_nme].sum()
-         if w_sum <= EPS:
-             data_sorted['bins'] = 0
-         else:
-             data_sorted['bins'] = np.floor(
-                 data_sorted['cum_weight'] * float(n_bins) / w_sum
-             )
-             data_sorted.loc[(data_sorted['bins'] == n_bins),
-                             'bins'] = n_bins - 1
-         return data_sorted.groupby(['bins'], observed=True).sum(numeric_only=True)
-
-     @staticmethod
-     def plot_lift_ax(ax, plot_data, title, pred_label='Predicted', act_label='Actual', weight_label='Earned Exposure'):
-         ax.plot(plot_data.index, plot_data['act_v'],
-                 label=act_label, color='red')
-         ax.plot(plot_data.index, plot_data['exp_v'],
-                 label=pred_label, color='blue')
-         ax.set_title(title, fontsize=8)
-         ax.set_xticks(plot_data.index)
-         ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
-         ax.tick_params(axis='y', labelsize=6)
-         ax.legend(loc='upper left', fontsize=5, frameon=False)
-         ax.margins(0.05)
-         ax2 = ax.twinx()
-         ax2.bar(plot_data.index, plot_data['weight'],
-                 alpha=0.5, color='seagreen',
-                 label=weight_label)
-         ax2.tick_params(axis='y', labelsize=6)
-         ax2.legend(loc='upper right', fontsize=5, frameon=False)
-
-     @staticmethod
-     def plot_dlift_ax(ax, plot_data, title, label1, label2, act_label='Actual', weight_label='Earned Exposure'):
-         ax.plot(plot_data.index, plot_data['act_v'],
-                 label=act_label, color='red')
-         ax.plot(plot_data.index, plot_data['exp_v1'],
-                 label=label1, color='blue')
-         ax.plot(plot_data.index, plot_data['exp_v2'],
-                 label=label2, color='black')
-         ax.set_title(title, fontsize=8)
-         ax.set_xticks(plot_data.index)
-         ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
-         ax.set_xlabel(f'{label1} / {label2}', fontsize=6)
-         ax.tick_params(axis='y', labelsize=6)
-         ax.legend(loc='upper left', fontsize=5, frameon=False)
-         ax.margins(0.1)
-         ax2 = ax.twinx()
-         ax2.bar(plot_data.index, plot_data['weight'],
-                 alpha=0.5, color='seagreen',
-                 label=weight_label)
-         ax2.tick_params(axis='y', labelsize=6)
-         ax2.legend(loc='upper right', fontsize=5, frameon=False)
-
-     @staticmethod
-     def plot_lift_list(pred_model, w_pred_list, w_act_list,
-                        weight_list, tgt_nme, n_bins: int = 10,
-                        fig_nme: str = 'Lift Chart'):
-         if plot_curves_common is not None:
-             save_path = os.path.join(
-                 os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
-             plot_curves_common.plot_lift_curve(
-                 pred_model,
-                 w_act_list,
-                 weight_list,
-                 n_bins=n_bins,
-                 title=f'Lift Chart of {tgt_nme}',
-                 pred_label='Predicted',
-                 act_label='Actual',
-                 weight_label='Earned Exposure',
-                 pred_weighted=False,
-                 actual_weighted=True,
-                 save_path=save_path,
-                 show=False,
-             )
-             return
-         if plt is None:
-             _plot_skip("lift plot")
-             return
-         lift_data = pd.DataFrame({
-             'pred': pred_model,
-             'w_pred': w_pred_list,
-             'act': w_act_list,
-             'weight': weight_list
-         })
-         plot_data = PlotUtils.split_data(lift_data, 'pred', 'weight', n_bins)
-         plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
-         plot_data['act_v'] = plot_data['act'] / plot_data['weight']
-         plot_data.reset_index(inplace=True)
-
-         fig = plt.figure(figsize=(7, 5))
-         ax = fig.add_subplot(111)
-         PlotUtils.plot_lift_ax(ax, plot_data, f'Lift Chart of {tgt_nme}')
-         plt.subplots_adjust(wspace=0.3)
-
-         save_path = os.path.join(
-             os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
-         IOUtils.ensure_parent_dir(save_path)
-         plt.savefig(save_path, dpi=300)
-         plt.close(fig)
-
-     @staticmethod
-     def plot_dlift_list(pred_model_1, pred_model_2,
-                         model_nme_1, model_nme_2,
-                         tgt_nme,
-                         w_list, w_act_list, n_bins: int = 10,
-                         fig_nme: str = 'Double Lift Chart'):
-         if plot_curves_common is not None:
-             save_path = os.path.join(
-                 os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
-             plot_curves_common.plot_double_lift_curve(
-                 pred_model_1,
-                 pred_model_2,
-                 w_act_list,
-                 w_list,
-                 n_bins=n_bins,
-                 title=f'Double Lift Chart of {tgt_nme}',
-                 label1=model_nme_1,
-                 label2=model_nme_2,
-                 pred1_weighted=False,
-                 pred2_weighted=False,
-                 actual_weighted=True,
-                 save_path=save_path,
-                 show=False,
-             )
-             return
-         if plt is None:
-             _plot_skip("double lift plot")
-             return
-         lift_data = pd.DataFrame({
-             'pred1': pred_model_1,
-             'pred2': pred_model_2,
-             'act': w_act_list,
-             'weight': w_list
-         })
-         lift_data['diff_ly'] = lift_data['pred1'] / lift_data['pred2']
-         lift_data['w_pred1'] = lift_data['pred1'] * lift_data['weight']
-         lift_data['w_pred2'] = lift_data['pred2'] * lift_data['weight']
-         plot_data = PlotUtils.split_data(
-             lift_data, 'diff_ly', 'weight', n_bins)
-         plot_data['exp_v1'] = plot_data['w_pred1'] / plot_data['act']
-         plot_data['exp_v2'] = plot_data['w_pred2'] / plot_data['act']
-         plot_data['act_v'] = plot_data['act'] / plot_data['act']
-         plot_data.reset_index(inplace=True)
-
-         fig = plt.figure(figsize=(7, 5))
-         ax = fig.add_subplot(111)
-         PlotUtils.plot_dlift_ax(
-             ax, plot_data, f'Double Lift Chart of {tgt_nme}', model_nme_1, model_nme_2)
-         plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)
-
-         save_path = os.path.join(
-             os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
-         IOUtils.ensure_parent_dir(save_path)
-         plt.savefig(save_path, dpi=300)
-         plt.close(fig)
-
-
- def infer_factor_and_cate_list(train_df: pd.DataFrame,
-                                test_df: pd.DataFrame,
-                                resp_nme: str,
-                                weight_nme: str,
-                                binary_resp_nme: Optional[str] = None,
-                                factor_nmes: Optional[List[str]] = None,
-                                cate_list: Optional[List[str]] = None,
-                                infer_categorical_max_unique: int = 50,
-                                infer_categorical_max_ratio: float = 0.05) -> Tuple[List[str], List[str]]:
-     """Infer factor_nmes/cate_list when feature names are not provided.
-
-     Rules:
-     - factor_nmes: start from shared train/test columns, exclude target/weight/(optional binary target).
-     - cate_list: object/category/bool plus low-cardinality integer columns.
-     - Always intersect with shared train/test columns to avoid mismatches.
-     """
-     excluded = {resp_nme, weight_nme}
-     if binary_resp_nme:
-         excluded.add(binary_resp_nme)
-
-     common_cols = [c for c in train_df.columns if c in test_df.columns]
-     if factor_nmes is None:
-         factors = [c for c in common_cols if c not in excluded]
-     else:
-         factors = [
-             c for c in factor_nmes if c in common_cols and c not in excluded]
-
-     if cate_list is not None:
-         cats = [c for c in cate_list if c in factors]
-         return factors, cats
-
-     n_rows = max(1, len(train_df))
-     cats: List[str] = []
-     for col in factors:
-         s = train_df[col]
-         if pd.api.types.is_bool_dtype(s) or pd.api.types.is_object_dtype(s) or isinstance(s.dtype, pd.CategoricalDtype):
-             cats.append(col)
-             continue
-         if pd.api.types.is_integer_dtype(s):
-             nunique = int(s.nunique(dropna=True))
-             if nunique <= infer_categorical_max_unique or (nunique / n_rows) <= infer_categorical_max_ratio:
-                 cats.append(col)
-     return factors, cats
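A toy example of the inference rules above (values illustrative): object columns are always categorical, and low-cardinality integer columns are promoted as well.

train = pd.DataFrame({
    "age": [25, 40, 33],        # integer, 3 unique values <= 50 -> categorical
    "region": ["N", "S", "N"],  # object dtype -> categorical
    "loss": [0.0, 120.0, 15.5],
    "exposure": [1.0, 0.5, 1.0],
})
factors, cats = infer_factor_and_cate_list(
    train, train.copy(), resp_nme="loss", weight_nme="exposure")
# factors == ["age", "region"]; cats == ["age", "region"]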
-
-
- # Backward-compatible functional wrappers
- def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
-     return IOUtils.csv_to_dict(file_path)
-
-
- def ensure_parent_dir(file_path: str) -> None:
-     IOUtils.ensure_parent_dir(file_path)
-
-
- def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
-     return TrainingUtils.compute_batch_size(data_size, learning_rate, batch_num, minimum)
-
-
- # Tweedie deviance loss for PyTorch.
- # Reference: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-poisson-gamma-and-tweedie-deviances
- def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
-     return TrainingUtils.tweedie_loss(pred, target, p=p, eps=eps, max_clip=max_clip)
-
-
- # CUDA memory release helper.
- def free_cuda():
-     TrainingUtils.free_cuda()
-
-
- class TorchTrainerMixin:
-     # Shared helpers for Torch tabular trainers.
-
-     def _device_type(self) -> str:
-         return getattr(self, "device", torch.device("cpu")).type
-
-     def _resolve_resource_profile(self) -> str:
-         profile = getattr(self, "resource_profile", None)
-         if not profile:
-             profile = os.environ.get("BAYESOPT_RESOURCE_PROFILE", "auto")
-         profile = str(profile).strip().lower()
-         if profile in {"cpu", "mps", "cuda"}:
-             profile = "auto"
-         if profile not in {"auto", "throughput", "memory_saving"}:
-             profile = "auto"
-         if profile == "auto" and self._device_type() == "cuda":
-             profile = "throughput"
-         return profile
-
-     def _log_resource_summary_once(self, profile: str) -> None:
-         if getattr(self, "_resource_summary_logged", False):
-             return
-         if dist.is_initialized() and not DistributedUtils.is_main_process():
-             return
-         self._resource_summary_logged = True
-         device = getattr(self, "device", torch.device("cpu"))
-         device_type = self._device_type()
-         cpu_count = os.cpu_count() or 1
-         cuda_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
-         mps_available = bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available())
-         ddp_enabled = bool(getattr(self, "is_ddp_enabled", False))
-         data_parallel = bool(getattr(self, "use_data_parallel", False))
-         print(
-             f">>> Resource summary: device={device}, device_type={device_type}, "
-             f"cpu_count={cpu_count}, cuda_count={cuda_count}, mps={mps_available}, "
-             f"ddp={ddp_enabled}, data_parallel={data_parallel}, profile={profile}"
-         )
-
-     def _available_system_memory(self) -> Optional[int]:
-         if os.name == "nt":
-             class _MemStatus(ctypes.Structure):
-                 _fields_ = [
-                     ("dwLength", ctypes.c_ulong),
-                     ("dwMemoryLoad", ctypes.c_ulong),
-                     ("ullTotalPhys", ctypes.c_ulonglong),
-                     ("ullAvailPhys", ctypes.c_ulonglong),
-                     ("ullTotalPageFile", ctypes.c_ulonglong),
-                     ("ullAvailPageFile", ctypes.c_ulonglong),
-                     ("ullTotalVirtual", ctypes.c_ulonglong),
-                     ("ullAvailVirtual", ctypes.c_ulonglong),
-                     ("sullAvailExtendedVirtual", ctypes.c_ulonglong),
-                 ]
-             status = _MemStatus()
-             status.dwLength = ctypes.sizeof(_MemStatus)
-             if ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status)):
-                 return int(status.ullAvailPhys)
-             return None
-         try:
-             pages = os.sysconf("SC_AVPHYS_PAGES")
-             page_size = os.sysconf("SC_PAGE_SIZE")
-             return int(pages * page_size)
-         except Exception:
-             return None
-
-     def _available_cuda_memory(self) -> Optional[int]:
-         if not torch.cuda.is_available():
-             return None
-         try:
-             free_mem, _total_mem = torch.cuda.mem_get_info()
-         except Exception:
-             return None
-         return int(free_mem)
-
-     def _estimate_sample_bytes(self, dataset) -> Optional[int]:
-         try:
-             if len(dataset) == 0:
-                 return None
-             sample = dataset[0]
-         except Exception:
-             return None
-
-         def _bytes(obj) -> int:
-             if obj is None:
-                 return 0
-             if torch.is_tensor(obj):
-                 return int(obj.element_size() * obj.nelement())
-             if isinstance(obj, np.ndarray):
-                 return int(obj.nbytes)
-             if isinstance(obj, (list, tuple)):
-                 return int(sum(_bytes(item) for item in obj))
-             if isinstance(obj, dict):
-                 return int(sum(_bytes(item) for item in obj.values()))
-             return 0
-
-         sample_bytes = _bytes(sample)
-         return int(sample_bytes) if sample_bytes > 0 else None
-
-     def _cap_batch_size_by_memory(self, dataset, batch_size: int, profile: str) -> int:
-         if batch_size <= 1:
-             return batch_size
-         sample_bytes = self._estimate_sample_bytes(dataset)
-         if sample_bytes is None:
-             return batch_size
-         device_type = self._device_type()
-         if device_type == "cuda":
-             available = self._available_cuda_memory()
-             if available is None:
-                 return batch_size
-             if profile == "throughput":
-                 budget_ratio = 0.8
-                 overhead = 8.0
-             elif profile == "memory_saving":
-                 budget_ratio = 0.5
-                 overhead = 14.0
-             else:
-                 budget_ratio = 0.6
-                 overhead = 12.0
-         else:
-             available = self._available_system_memory()
-             if available is None:
-                 return batch_size
-             if profile == "throughput":
-                 budget_ratio = 0.4
-                 overhead = 1.8
-             elif profile == "memory_saving":
-                 budget_ratio = 0.25
-                 overhead = 3.0
-             else:
-                 budget_ratio = 0.3
-                 overhead = 2.6
-         budget = int(available * budget_ratio)
-         per_sample = int(sample_bytes * overhead)
-         if per_sample <= 0:
-             return batch_size
-         max_batch = max(1, int(budget // per_sample))
-         if max_batch < batch_size:
-             print(
-                 f">>> Memory cap: batch_size {batch_size} -> {max_batch} "
-                 f"(per_sample~{sample_bytes}B, budget~{budget // (1024**2)}MB)"
-             )
-         return min(batch_size, max_batch)
-
-     def _resolve_num_workers(self, max_workers: int, profile: Optional[str] = None) -> int:
-         if os.name == 'nt':
-             return 0
-         if getattr(self, "is_ddp_enabled", False):
-             return 0
-         profile = profile or self._resolve_resource_profile()
-         if profile == "memory_saving":
-             return 0
-         worker_cap = min(int(max_workers), os.cpu_count() or 1)
-         if self._device_type() == "mps":
-             worker_cap = min(worker_cap, 2)
-         return worker_cap
-
-     def _build_dataloader(self,
-                           dataset,
-                           N: int,
-                           base_bs_gpu: tuple,
-                           base_bs_cpu: tuple,
-                           min_bs: int = 64,
-                           target_effective_cuda: int = 1024,
-                           target_effective_cpu: int = 512,
-                           large_threshold: int = 200_000,
-                           mid_threshold: int = 50_000):
-         profile = self._resolve_resource_profile()
-         self._log_resource_summary_once(profile)
-         batch_size = TrainingUtils.compute_batch_size(
-             data_size=len(dataset),
-             learning_rate=self.learning_rate,
-             batch_num=self.batch_num,
-             minimum=min_bs
-         )
-         gpu_large, gpu_mid, gpu_small = base_bs_gpu
-         cpu_mid, cpu_small = base_bs_cpu
-
-         if self._device_type() == 'cuda':
-             device_count = torch.cuda.device_count()
-             if getattr(self, "is_ddp_enabled", False):
-                 device_count = 1
-             # In multi-GPU, increase min batch size so each GPU gets enough data.
-             if device_count > 1:
-                 min_bs = min_bs * device_count
-                 print(
-                     f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")
-
-             if N > large_threshold:
-                 base_bs = gpu_large * device_count
-             elif N > mid_threshold:
-                 base_bs = gpu_mid * device_count
-             else:
-                 base_bs = gpu_small * device_count
-         else:
-             base_bs = cpu_mid if N > mid_threshold else cpu_small
-
-         # Recompute batch_size to respect the adjusted min_bs.
-         batch_size = TrainingUtils.compute_batch_size(
-             data_size=len(dataset),
-             learning_rate=self.learning_rate,
-             batch_num=self.batch_num,
-             minimum=min_bs
-         )
-         batch_size = min(batch_size, base_bs, N)
-         batch_size = self._cap_batch_size_by_memory(
-             dataset, batch_size, profile)
-
-         target_effective_bs = target_effective_cuda if self._device_type(
-         ) == 'cuda' else target_effective_cpu
-         if getattr(self, "is_ddp_enabled", False):
-             world_size = max(1, DistributedUtils.world_size())
-             target_effective_bs = max(1, target_effective_bs // world_size)
-
-         world_size = getattr(self, "world_size", 1) if getattr(
-             self, "is_ddp_enabled", False) else 1
-         samples_per_rank = math.ceil(
-             N / max(1, world_size)) if world_size > 1 else N
-         steps_per_epoch = max(
-             1, math.ceil(samples_per_rank / max(1, batch_size)))
-         # Limit gradient accumulation to avoid scaling beyond actual batches.
-         desired_accum = max(1, target_effective_bs // max(1, batch_size))
-         accum_steps = max(1, min(desired_accum, steps_per_epoch))
-
-         # Linux (posix) uses fork; Windows (nt) uses spawn with higher overhead.
-         workers = self._resolve_num_workers(8, profile=profile)
-         prefetch_factor = None
-         if workers > 0:
-             prefetch_factor = 4 if profile == "throughput" else 2
-         persistent = workers > 0 and profile != "memory_saving"
-         print(
-             f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, "
-             f"Workers={workers}, Prefetch={prefetch_factor or 'off'}, Profile={profile}")
-         sampler = None
-         if dist.is_initialized():
-             sampler = DistributedSampler(dataset, shuffle=True)
-             shuffle = False # DistributedSampler handles shuffling.
-         else:
-             shuffle = True
-
-         dataloader = DataLoader(
-             dataset,
-             batch_size=batch_size,
-             shuffle=shuffle,
-             sampler=sampler,
-             num_workers=workers,
-             pin_memory=(self._device_type() == 'cuda'),
-             persistent_workers=persistent,
-             **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
-         )
-         return dataloader, accum_steps
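A worked example of the accumulation arithmetic above, assuming a single CUDA device, N = 300_000, the default target_effective_cuda = 1024, and batch_size capped to 256 by the memory check:

# steps_per_epoch = ceil(300_000 / 256) = 1172
# desired_accum   = 1024 // 256         = 4
# accum_steps     = min(4, 1172)        = 4
# Gradients from four micro-batches are accumulated per optimizer step,
# for an effective batch of 4 * 256 = 1024 samples.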
-
-     def _build_val_dataloader(self, dataset, train_dataloader, accum_steps):
-         profile = self._resolve_resource_profile()
-         val_bs = accum_steps * train_dataloader.batch_size
-         val_workers = self._resolve_num_workers(4, profile=profile)
-         prefetch_factor = None
-         if val_workers > 0:
-             prefetch_factor = 2
-         return DataLoader(
-             dataset,
-             batch_size=val_bs,
-             shuffle=False,
-             num_workers=val_workers,
-             pin_memory=(self._device_type() == 'cuda'),
-             persistent_workers=(val_workers > 0 and profile != "memory_saving"),
-             **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
-         )
-
-     def _compute_losses(self, y_pred, y_true, apply_softplus: bool = False):
-         task = getattr(self, "task_type", "regression")
-         if task == 'classification':
-             loss_fn = nn.BCEWithLogitsLoss(reduction='none')
-             return loss_fn(y_pred, y_true).view(-1)
-         if apply_softplus:
-             y_pred = F.softplus(y_pred)
-         y_pred = torch.clamp(y_pred, min=1e-6)
-         power = getattr(self, "tw_power", 1.5)
-         return tweedie_loss(y_pred, y_true, p=power).view(-1)
-
-     def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
-         losses = self._compute_losses(
-             y_pred, y_true, apply_softplus=apply_softplus)
-         weighted_loss = (losses * weights.view(-1)).sum() / \
-             torch.clamp(weights.sum(), min=EPS)
-         return weighted_loss
-
-     def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model,
-                            ignore_keys: Optional[List[str]] = None):
-         if val_loss < best_loss:
-             ignore_keys = ignore_keys or []
-             # Unwrap DDP module to avoid module. prefix in state_dict keys
-             base_module = model.module if hasattr(model, "module") else model
-             state_dict = {
-                 k: (v.clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v))
-                 for k, v in base_module.state_dict().items()
-                 if not any(k.startswith(ignore_key) for ignore_key in ignore_keys)
-             }
-             return val_loss, state_dict, 0, False
-         patience_counter += 1
-         should_stop = best_state is not None and patience_counter >= getattr(
-             self, "patience", 0)
-         return best_loss, best_state, patience_counter, should_stop
-
-     def _train_model(self,
-                      model,
-                      dataloader,
-                      accum_steps,
-                      optimizer,
-                      scaler,
-                      forward_fn,
-                      val_forward_fn=None,
-                      apply_softplus: bool = False,
-                      clip_fn=None,
-                      trial: Optional[optuna.trial.Trial] = None,
-                      loss_curve_path: Optional[str] = None):
-         device_type = self._device_type()
-         best_loss = float('inf')
-         best_state = None
-         patience_counter = 0
-         stop_training = False
-         train_history: List[float] = []
-         val_history: List[float] = []
-
-         is_ddp_model = isinstance(model, DDP)
-
-         for epoch in range(1, getattr(self, "epochs", 1) + 1):
-             epoch_start_ts = time.time()
-             val_weighted_loss = None
-             if hasattr(self, 'dataloader_sampler') and self.dataloader_sampler is not None:
-                 self.dataloader_sampler.set_epoch(epoch)
-
-             model.train()
-             optimizer.zero_grad()
-
-             epoch_loss_sum = None
-             epoch_weight_sum = None
-             for step, batch in enumerate(dataloader):
-                 is_update_step = ((step + 1) % accum_steps == 0) or \
-                     ((step + 1) == len(dataloader))
-                 sync_cm = model.no_sync if (
-                     is_ddp_model and not is_update_step) else nullcontext
-
-                 with sync_cm():
-                     with autocast(enabled=(device_type == 'cuda')):
-                         y_pred, y_true, w = forward_fn(batch)
-                         weighted_loss = self._compute_weighted_loss(
-                             y_pred, y_true, w, apply_softplus=apply_softplus)
-                         loss_for_backward = weighted_loss / accum_steps
-
-                     batch_weight = torch.clamp(
-                         w.detach().sum(), min=EPS).to(dtype=torch.float32)
-                     loss_val = weighted_loss.detach().to(dtype=torch.float32)
-                     if epoch_loss_sum is None:
-                         epoch_loss_sum = torch.zeros(
-                             (), device=batch_weight.device, dtype=torch.float32)
-                         epoch_weight_sum = torch.zeros(
-                             (), device=batch_weight.device, dtype=torch.float32)
-                     epoch_loss_sum = epoch_loss_sum + loss_val * batch_weight
-                     epoch_weight_sum = epoch_weight_sum + batch_weight
-                     scaler.scale(loss_for_backward).backward()
-
-                 if is_update_step:
-                     if clip_fn is not None:
-                         clip_fn()
-                     scaler.step(optimizer)
-                     scaler.update()
-                     optimizer.zero_grad()
-
-             if epoch_loss_sum is None or epoch_weight_sum is None:
-                 train_epoch_loss = 0.0
-             else:
-                 train_epoch_loss = (
-                     epoch_loss_sum / torch.clamp(epoch_weight_sum, min=EPS)
-                 ).item()
-             train_history.append(float(train_epoch_loss))
-
-             if val_forward_fn is not None:
-                 should_compute_val = (not dist.is_initialized()
-                                       or DistributedUtils.is_main_process())
-                 val_device = getattr(self, "device", torch.device("cpu"))
-                 if not isinstance(val_device, torch.device):
-                     val_device = torch.device(val_device)
-                 loss_tensor_device = val_device if device_type == 'cuda' else torch.device(
-                     "cpu")
-                 val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
-
-                 if should_compute_val:
-                     model.eval()
-                     with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
-                         val_result = val_forward_fn()
-                         if isinstance(val_result, tuple) and len(val_result) == 3:
-                             y_val_pred, y_val_true, w_val = val_result
-                             val_weighted_loss = self._compute_weighted_loss(
-                                 y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
-                         else:
-                             val_weighted_loss = val_result
-                     val_loss_tensor[0] = float(val_weighted_loss)
-
-                 if dist.is_initialized():
-                     dist.broadcast(val_loss_tensor, src=0)
-                 val_weighted_loss = float(val_loss_tensor.item())
-
-                 val_history.append(val_weighted_loss)
-
-                 best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
-                     val_weighted_loss, best_loss, best_state, patience_counter, model)
-
-                 prune_flag = False
-                 is_main_rank = DistributedUtils.is_main_process()
-                 # Only main process reports to Optuna to avoid duplicate reports
-                 if trial is not None and is_main_rank:
-                     trial.report(val_weighted_loss, epoch)
-                     prune_flag = trial.should_prune()
-
-                 if dist.is_initialized():
-                     prune_device = getattr(self, "device", torch.device("cpu"))
-                     if not isinstance(prune_device, torch.device):
-                         prune_device = torch.device(prune_device)
-                     prune_tensor = torch.zeros(1, device=prune_device)
-                     if is_main_rank:
-                         prune_tensor.fill_(1 if prune_flag else 0)
-                     dist.broadcast(prune_tensor, src=0)
-                     prune_flag = bool(prune_tensor.item())
-
-                 if prune_flag:
-                     raise optuna.TrialPruned()
-
-                 if stop_training:
-                     break
-
-             should_log_epoch = (not dist.is_initialized()
-                                 or DistributedUtils.is_main_process())
-             if should_log_epoch:
-                 elapsed = int(time.time() - epoch_start_ts)
-                 if val_weighted_loss is None:
-                     print(
-                         f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
-                         f"train_loss={float(train_epoch_loss):.6f} elapsed={elapsed}s",
-                         flush=True,
-                     )
-                 else:
-                     print(
-                         f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
-                         f"train_loss={float(train_epoch_loss):.6f} "
-                         f"val_loss={float(val_weighted_loss):.6f} elapsed={elapsed}s",
-                         flush=True,
-                     )
-
-             # Periodic memory cleanup to prevent accumulation (every 10 epochs)
-             if epoch % 10 == 0:
-                 if torch.cuda.is_available():
-                     torch.cuda.empty_cache()
-                 gc.collect()
-
-         history = {"train": train_history, "val": val_history}
-         self._plot_loss_curve(history, loss_curve_path)
-         return best_state, history
-
-     def _plot_loss_curve(self, history: Dict[str, List[float]], save_path: Optional[str]) -> None:
-         if not save_path:
-             return
-         if dist.is_initialized() and not DistributedUtils.is_main_process():
-             return
-         train_hist = history.get("train", []) if history else []
-         val_hist = history.get("val", []) if history else []
-         if not train_hist and not val_hist:
-             return
-         if plot_loss_curve_common is not None:
-             plot_loss_curve_common(
-                 history=history,
-                 title="Loss vs. Epoch",
-                 save_path=save_path,
-                 show=False,
-             )
-         else:
-             if plt is None:
-                 _plot_skip("loss curve")
-                 return
-             ensure_parent_dir(save_path)
-             epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
-             fig = plt.figure(figsize=(8, 4))
-             ax = fig.add_subplot(111)
-             if train_hist:
-                 ax.plot(range(1, len(train_hist) + 1), train_hist,
-                         label='Train Loss', color='tab:blue')
-             if val_hist:
-                 ax.plot(range(1, len(val_hist) + 1), val_hist,
-                         label='Validation Loss', color='tab:orange')
-             ax.set_xlabel('Epoch')
-             ax.set_ylabel('Weighted Loss')
-             ax.set_title('Loss vs. Epoch')
-             ax.grid(True, linestyle='--', alpha=0.3)
-             ax.legend()
-             plt.tight_layout()
-             plt.savefig(save_path, dpi=300)
-             plt.close(fig)
-             print(f"[Training] Loss curve saved to {save_path}")
-
-
- # =============================================================================
- # Plotting helpers
- # =============================================================================
-
- def split_data(data, col_nme, wgt_nme, n_bins=10):
-     return PlotUtils.split_data(data, col_nme, wgt_nme, n_bins)
-
- # Lift curve plotting wrapper
-
-
- def plot_lift_list(pred_model, w_pred_list, w_act_list,
-                    weight_list, tgt_nme, n_bins=10,
-                    fig_nme='Lift Chart'):
-     return PlotUtils.plot_lift_list(pred_model, w_pred_list, w_act_list,
-                                     weight_list, tgt_nme, n_bins, fig_nme)
-
- # Double lift curve plotting wrapper
-
-
- def plot_dlift_list(pred_model_1, pred_model_2,
-                     model_nme_1, model_nme_2,
-                     tgt_nme,
-                     w_list, w_act_list, n_bins=10,
-                     fig_nme='Double Lift Chart'):
-     return PlotUtils.plot_dlift_list(pred_model_1, pred_model_2,
-                                      model_nme_1, model_nme_2,
-                                      tgt_nme, w_list, w_act_list,
-                                      n_bins, fig_nme)
-
-
- # =============================================================================
- # Logging System
- # =============================================================================
- import logging
- from functools import lru_cache
-
-
- @lru_cache(maxsize=1)
- def _get_package_logger() -> logging.Logger:
-     """Get or create the package-level logger with consistent formatting."""
-     logger = logging.getLogger("ins_pricing")
-     if not logger.handlers:
-         handler = logging.StreamHandler()
-         formatter = logging.Formatter(
-             "[%(levelname)s][%(name)s] %(message)s"
-         )
-         handler.setFormatter(formatter)
-         logger.addHandler(handler)
-         # Default to INFO, can be changed via environment variable
-         level = os.environ.get("INS_PRICING_LOG_LEVEL", "INFO").upper()
-         logger.setLevel(getattr(logging, level, logging.INFO))
-     return logger
-
-
- def get_logger(name: str = "ins_pricing") -> logging.Logger:
-     """Get a logger with the given name, inheriting package-level settings.
-
-     Args:
-         name: Logger name, typically module name like 'ins_pricing.trainer'
-
-     Returns:
-         Configured logger instance
-
-     Example:
-         >>> logger = get_logger("ins_pricing.trainer.ft")
-         >>> logger.info("Training started")
-     """
-     # Ensure package logger is initialized
-     _get_package_logger()
-     return logging.getLogger(name)
-
-
- # =============================================================================
- # Metric Computation Factory
- # =============================================================================
- from sklearn.metrics import log_loss, mean_tweedie_deviance
-
-
- class MetricFactory:
-     """Factory for computing evaluation metrics consistently across all trainers.
-
-     This class centralizes metric computation logic that was previously duplicated
-     across FTTrainer, ResNetTrainer, GNNTrainer, XGBTrainer, and GLMTrainer.
-
-     Example:
-         >>> factory = MetricFactory(task_type='regression', tweedie_power=1.5)
-         >>> score = factory.compute(y_true, y_pred, sample_weight)
-     """
-
-     def __init__(
-         self,
-         task_type: str = "regression",
-         tweedie_power: float = 1.5,
-         clip_min: float = 1e-8,
-         clip_max: float = 1 - 1e-8,
-     ):
-         """Initialize the metric factory.
-
-         Args:
-             task_type: Either 'regression' or 'classification'
-             tweedie_power: Power parameter for Tweedie deviance (1.0-2.0)
-             clip_min: Minimum value for clipping predictions
-             clip_max: Maximum value for clipping predictions (for classification)
-         """
-         self.task_type = task_type
-         self.tweedie_power = tweedie_power
-         self.clip_min = clip_min
-         self.clip_max = clip_max
-
-     def compute(
-         self,
-         y_true: np.ndarray,
-         y_pred: np.ndarray,
-         sample_weight: Optional[np.ndarray] = None,
-     ) -> float:
-         """Compute the appropriate metric based on task type.
-
-         Args:
-             y_true: Ground truth values
-             y_pred: Predicted values
-             sample_weight: Optional sample weights
-
-         Returns:
-             Computed metric value (lower is better)
-         """
-         y_pred = np.asarray(y_pred)
-         y_true = np.asarray(y_true)
-
-         if self.task_type == "classification":
-             y_pred_clipped = np.clip(y_pred, self.clip_min, self.clip_max)
-             return float(log_loss(y_true, y_pred_clipped, sample_weight=sample_weight))
-
-         # Regression: use Tweedie deviance
-         y_pred_safe = np.maximum(y_pred, self.clip_min)
-         return float(mean_tweedie_deviance(
-             y_true,
-             y_pred_safe,
-             sample_weight=sample_weight,
-             power=self.tweedie_power,
-         ))
-
-     def update_power(self, power: float) -> None:
-         """Update the Tweedie power parameter.
-
-         Args:
-             power: New power value (1.0-2.0)
-         """
-         self.tweedie_power = power
-
-
- # =============================================================================
- # GPU Memory Manager
- # =============================================================================
- from contextlib import contextmanager
-
-
- class GPUMemoryManager:
-     """Context manager for GPU memory management and cleanup.
-
-     This class consolidates GPU memory cleanup logic that was previously
-     scattered across multiple trainer files.
-
-     Example:
-         >>> with GPUMemoryManager.cleanup_context():
-         ...     model.train()
-         ...     # Memory cleaned up after exiting context
-
-         >>> # Or use directly:
-         >>> GPUMemoryManager.clean()
-     """
-
-     _logger = get_logger("ins_pricing.gpu")
-
-     @classmethod
-     def clean(cls, verbose: bool = False) -> None:
-         """Clean up GPU memory.
-
-         Args:
-             verbose: If True, print cleanup details
-         """
-         gc.collect()
-
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-             torch.cuda.synchronize()
-             if verbose:
-                 cls._logger.debug("CUDA cache cleared and synchronized")
-
-         # Optional: Force IPC collect for multi-process scenarios
-         if os.environ.get("BAYESOPT_CUDA_IPC_COLLECT", "0") == "1":
-             if torch.cuda.is_available():
-                 try:
-                     torch.cuda.ipc_collect()
-                     if verbose:
-                         cls._logger.debug("CUDA IPC collect performed")
-                 except Exception:
-                     pass
-
-     @classmethod
-     @contextmanager
-     def cleanup_context(cls, verbose: bool = False):
-         """Context manager that cleans GPU memory on exit.
-
-         Args:
-             verbose: If True, print cleanup details
-
-         Yields:
-             None
-         """
-         try:
-             yield
-         finally:
-             cls.clean(verbose=verbose)
-
-     @classmethod
-     def move_model_to_cpu(cls, model: nn.Module) -> nn.Module:
-         """Move a model to CPU and clean GPU memory.
-
-         Args:
-             model: PyTorch model to move
-
-         Returns:
-             Model on CPU
-         """
-         if model is not None:
-             model.to("cpu")
-         cls.clean()
-         return model
-
-     @classmethod
-     def get_memory_info(cls) -> Dict[str, int]:
-         """Get current GPU memory usage information.
-
-         Returns:
-             Dictionary with memory info (allocated, reserved, free)
-         """
-         if not torch.cuda.is_available():
-             return {"available": False}
-
-         try:
-             allocated = torch.cuda.memory_allocated()
-             reserved = torch.cuda.memory_reserved()
-             free, total = torch.cuda.mem_get_info()
-             return {
-                 "available": True,
-                 "allocated_mb": allocated // (1024 * 1024),
-                 "reserved_mb": reserved // (1024 * 1024),
-                 "free_mb": free // (1024 * 1024),
-                 "total_mb": total // (1024 * 1024),
-             }
-         except Exception:
-             return {"available": False}
-
-
- # =============================================================================
- # Device Manager
- # =============================================================================
-
- class DeviceManager:
-     """Unified device management for model and tensor placement.
-
-     This class consolidates device detection and model movement logic
-     that was previously duplicated across trainer_base.py and predict.py.
-
-     Example:
-         >>> device = DeviceManager.get_best_device()
-         >>> model = DeviceManager.move_to_device(model)
-     """
-
-     _logger = get_logger("ins_pricing.device")
-     _cached_device: Optional[torch.device] = None
-
-     @classmethod
-     def get_best_device(cls, prefer_cuda: bool = True) -> torch.device:
-         """Get the best available device.
-
-         Args:
-             prefer_cuda: If True, prefer CUDA over MPS
-
-         Returns:
-             Best available torch.device
-         """
-         if cls._cached_device is not None:
-             return cls._cached_device
-
-         if prefer_cuda and torch.cuda.is_available():
-             cls._cached_device = torch.device("cuda")
-         elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-             cls._cached_device = torch.device("mps")
-         else:
-             cls._cached_device = torch.device("cpu")
-
-         cls._logger.debug(f"Selected device: {cls._cached_device}")
-         return cls._cached_device
-
-     @classmethod
-     def move_to_device(cls, model_obj, device: Optional[torch.device] = None) -> None:
-         """Move a model object to the specified device.
-
-         Handles sklearn-style wrappers that have .ft, .resnet, or .gnn attributes.
-
-         Args:
-             model_obj: Model object to move (may be sklearn wrapper)
-             device: Target device (defaults to best available)
-         """
-         if model_obj is None:
-             return
-
-         device = device or cls.get_best_device()
-
-         # Update device attribute if present
-         if hasattr(model_obj, "device"):
-             model_obj.device = device
-
-         # Move the main model
-         if hasattr(model_obj, "to"):
-             model_obj.to(device)
-
-         # Move nested submodules (sklearn wrappers)
-         for attr_name in ("ft", "resnet", "gnn"):
-             submodule = getattr(model_obj, attr_name, None)
-             if submodule is not None and hasattr(submodule, "to"):
-                 submodule.to(device)
-
-     @classmethod
-     def unwrap_module(cls, module: nn.Module) -> nn.Module:
-         """Unwrap DDP or DataParallel wrapper to get the base module.
-
-         Args:
-             module: Potentially wrapped PyTorch module
-
-         Returns:
-             Unwrapped base module
-         """
-         if isinstance(module, (DDP, nn.DataParallel)):
-             return module.module
-         return module
-
-     @classmethod
-     def reset_cache(cls) -> None:
-         """Reset cached device selection."""
-         cls._cached_device = None
-
-
- # =============================================================================
- # Cross-Validation Strategy Resolver
- # =============================================================================
- from sklearn.model_selection import KFold, GroupKFold, TimeSeriesSplit, StratifiedKFold
-
-
- class CVStrategyResolver:
-     """Resolver for cross-validation splitting strategies.
-
-     This class consolidates CV strategy resolution logic that was previously
-     duplicated across trainer_base.py and trainer_ft.py.
-
-     Supported strategies:
-     - 'random': Standard KFold
-     - 'stratified': StratifiedKFold (for classification)
-     - 'group': GroupKFold (requires group column)
-     - 'time': TimeSeriesSplit (requires time column)
-
-     Example:
-         >>> resolver = CVStrategyResolver(
-         ...     strategy='group',
-         ...     n_splits=5,
-         ...     group_col='policy_id',
-         ...     data=train_df,
-         ... )
-         >>> splitter, groups = resolver.get_splitter()
-         >>> for train_idx, val_idx in splitter.split(X, y, groups):
-         ...     pass
-     """
-
-     VALID_STRATEGIES = {"random", "stratified", "group", "grouped", "time", "timeseries", "temporal"}
-
-     def __init__(
-         self,
-         strategy: str = "random",
-         n_splits: int = 5,
-         shuffle: bool = True,
-         random_state: Optional[int] = None,
-         group_col: Optional[str] = None,
-         time_col: Optional[str] = None,
-         time_ascending: bool = True,
-         data: Optional[pd.DataFrame] = None,
-     ):
-         """Initialize the CV strategy resolver.
-
-         Args:
-             strategy: CV strategy name
-             n_splits: Number of CV folds
-             shuffle: Whether to shuffle for random/stratified
-             random_state: Random seed for reproducibility
-             group_col: Column name for group-based splitting
-             time_col: Column name for time-based splitting
-             time_ascending: Sort order for time-based splitting
-             data: DataFrame containing group/time columns
-         """
-         self.strategy = strategy.strip().lower()
-         self.n_splits = max(2, int(n_splits))
-         self.shuffle = shuffle
-         self.random_state = random_state
-         self.group_col = group_col
-         self.time_col = time_col
-         self.time_ascending = time_ascending
-         self.data = data
-
-         if self.strategy not in self.VALID_STRATEGIES:
-             raise ValueError(
-                 f"Invalid strategy '{strategy}'. "
-                 f"Valid options: {sorted(self.VALID_STRATEGIES)}"
-             )
-
-     def get_splitter(self) -> Tuple[Any, Optional[pd.Series]]:
-         """Get the appropriate splitter and groups.
-
-         Returns:
-             Tuple of (splitter, groups) where groups may be None
-
-         Raises:
-             ValueError: If required columns are missing
-         """
-         if self.strategy in {"group", "grouped"}:
-             return self._get_group_splitter()
-         elif self.strategy in {"time", "timeseries", "temporal"}:
-             return self._get_time_splitter()
-         elif self.strategy == "stratified":
-             return self._get_stratified_splitter()
-         else:
-             return self._get_random_splitter()
-
-     def _get_random_splitter(self) -> Tuple[KFold, None]:
-         """Get a random KFold splitter."""
-         splitter = KFold(
-             n_splits=self.n_splits,
-             shuffle=self.shuffle,
-             random_state=self.random_state if self.shuffle else None,
-         )
-         return splitter, None
-
-     def _get_stratified_splitter(self) -> Tuple[StratifiedKFold, None]:
-         """Get a stratified KFold splitter."""
-         splitter = StratifiedKFold(
-             n_splits=self.n_splits,
-             shuffle=self.shuffle,
-             random_state=self.random_state if self.shuffle else None,
-         )
-         return splitter, None
-
-     def _get_group_splitter(self) -> Tuple[GroupKFold, pd.Series]:
-         """Get a group-based KFold splitter."""
-         if not self.group_col:
-             raise ValueError("group_col is required for group strategy")
-         if self.data is None:
-             raise ValueError("data DataFrame is required for group strategy")
-         if self.group_col not in self.data.columns:
-             raise KeyError(f"group_col '{self.group_col}' not found in data")
-
-         groups = self.data[self.group_col]
-         splitter = GroupKFold(n_splits=self.n_splits)
-         return splitter, groups
-
-     def _get_time_splitter(self) -> Tuple[Any, None]:
-         """Get a time-series splitter."""
-         if not self.time_col:
-             raise ValueError("time_col is required for time strategy")
-         if self.data is None:
-             raise ValueError("data DataFrame is required for time strategy")
-         if self.time_col not in self.data.columns:
-             raise KeyError(f"time_col '{self.time_col}' not found in data")
-
-         splitter = TimeSeriesSplit(n_splits=self.n_splits)
-
-         # Create an ordered wrapper that sorts by time column
-         order_index = self.data[self.time_col].sort_values(
-             ascending=self.time_ascending
-         ).index
-         order = self.data.index.get_indexer(order_index)
-
-         return _OrderedSplitter(splitter, order), None
-
-
- class _OrderedSplitter:
-     """Wrapper for splitters that need to respect a specific ordering."""
-
-     def __init__(self, base_splitter, order: np.ndarray):
-         self.base_splitter = base_splitter
-         self.order = order
-
-     def split(self, X, y=None, groups=None):
-         """Split with ordering applied."""
-         n = len(X)
-         X_ordered = np.arange(n)[self.order]
-         for train_idx, val_idx in self.base_splitter.split(X_ordered):
-             yield self.order[train_idx], self.order[val_idx]
-
-     def get_n_splits(self, X=None, y=None, groups=None):
-         return self.base_splitter.get_n_splits()
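A usage sketch for the time strategy, the one case that routes through _OrderedSplitter (train_df and the uw_month column are illustrative):

resolver = CVStrategyResolver(
    strategy="time", n_splits=3, time_col="uw_month", data=train_df)
splitter, groups = resolver.get_splitter()  # groups is None for time splits
for train_idx, val_idx in splitter.split(train_df):
    # positional indices, reordered so each validation fold comes later in
    # uw_month than the rows it was trained on
    ...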