ins_pricing-0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. ins_pricing/README.md +60 -0
  2. ins_pricing/__init__.py +102 -0
  3. ins_pricing/governance/README.md +18 -0
  4. ins_pricing/governance/__init__.py +20 -0
  5. ins_pricing/governance/approval.py +93 -0
  6. ins_pricing/governance/audit.py +37 -0
  7. ins_pricing/governance/registry.py +99 -0
  8. ins_pricing/governance/release.py +159 -0
  9. ins_pricing/modelling/BayesOpt.py +146 -0
  10. ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
  11. ins_pricing/modelling/BayesOpt_entry.py +575 -0
  12. ins_pricing/modelling/BayesOpt_incremental.py +731 -0
  13. ins_pricing/modelling/Explain_Run.py +36 -0
  14. ins_pricing/modelling/Explain_entry.py +539 -0
  15. ins_pricing/modelling/Pricing_Run.py +36 -0
  16. ins_pricing/modelling/README.md +33 -0
  17. ins_pricing/modelling/__init__.py +44 -0
  18. ins_pricing/modelling/bayesopt/__init__.py +98 -0
  19. ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
  20. ins_pricing/modelling/bayesopt/core.py +1476 -0
  21. ins_pricing/modelling/bayesopt/models.py +2196 -0
  22. ins_pricing/modelling/bayesopt/trainers.py +2446 -0
  23. ins_pricing/modelling/bayesopt/utils.py +1021 -0
  24. ins_pricing/modelling/cli_common.py +136 -0
  25. ins_pricing/modelling/explain/__init__.py +55 -0
  26. ins_pricing/modelling/explain/gradients.py +334 -0
  27. ins_pricing/modelling/explain/metrics.py +176 -0
  28. ins_pricing/modelling/explain/permutation.py +155 -0
  29. ins_pricing/modelling/explain/shap_utils.py +146 -0
  30. ins_pricing/modelling/notebook_utils.py +284 -0
  31. ins_pricing/modelling/plotting/__init__.py +45 -0
  32. ins_pricing/modelling/plotting/common.py +63 -0
  33. ins_pricing/modelling/plotting/curves.py +572 -0
  34. ins_pricing/modelling/plotting/diagnostics.py +139 -0
  35. ins_pricing/modelling/plotting/geo.py +362 -0
  36. ins_pricing/modelling/plotting/importance.py +121 -0
  37. ins_pricing/modelling/run_logging.py +133 -0
  38. ins_pricing/modelling/tests/conftest.py +8 -0
  39. ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
  40. ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
  41. ins_pricing/modelling/tests/test_explain.py +56 -0
  42. ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
  43. ins_pricing/modelling/tests/test_graph_cache.py +33 -0
  44. ins_pricing/modelling/tests/test_plotting.py +63 -0
  45. ins_pricing/modelling/tests/test_plotting_library.py +150 -0
  46. ins_pricing/modelling/tests/test_preprocessor.py +48 -0
  47. ins_pricing/modelling/watchdog_run.py +211 -0
  48. ins_pricing/pricing/README.md +44 -0
  49. ins_pricing/pricing/__init__.py +27 -0
  50. ins_pricing/pricing/calibration.py +39 -0
  51. ins_pricing/pricing/data_quality.py +117 -0
  52. ins_pricing/pricing/exposure.py +85 -0
  53. ins_pricing/pricing/factors.py +91 -0
  54. ins_pricing/pricing/monitoring.py +99 -0
  55. ins_pricing/pricing/rate_table.py +78 -0
  56. ins_pricing/production/__init__.py +21 -0
  57. ins_pricing/production/drift.py +30 -0
  58. ins_pricing/production/monitoring.py +143 -0
  59. ins_pricing/production/scoring.py +40 -0
  60. ins_pricing/reporting/README.md +20 -0
  61. ins_pricing/reporting/__init__.py +11 -0
  62. ins_pricing/reporting/report_builder.py +72 -0
  63. ins_pricing/reporting/scheduler.py +45 -0
  64. ins_pricing/setup.py +41 -0
  65. ins_pricing v2/__init__.py +23 -0
  66. ins_pricing v2/governance/__init__.py +20 -0
  67. ins_pricing v2/governance/approval.py +93 -0
  68. ins_pricing v2/governance/audit.py +37 -0
  69. ins_pricing v2/governance/registry.py +99 -0
  70. ins_pricing v2/governance/release.py +159 -0
  71. ins_pricing v2/modelling/Explain_Run.py +36 -0
  72. ins_pricing v2/modelling/Pricing_Run.py +36 -0
  73. ins_pricing v2/modelling/__init__.py +151 -0
  74. ins_pricing v2/modelling/cli_common.py +141 -0
  75. ins_pricing v2/modelling/config.py +249 -0
  76. ins_pricing v2/modelling/config_preprocess.py +254 -0
  77. ins_pricing v2/modelling/core.py +741 -0
  78. ins_pricing v2/modelling/data_container.py +42 -0
  79. ins_pricing v2/modelling/explain/__init__.py +55 -0
  80. ins_pricing v2/modelling/explain/gradients.py +334 -0
  81. ins_pricing v2/modelling/explain/metrics.py +176 -0
  82. ins_pricing v2/modelling/explain/permutation.py +155 -0
  83. ins_pricing v2/modelling/explain/shap_utils.py +146 -0
  84. ins_pricing v2/modelling/features.py +215 -0
  85. ins_pricing v2/modelling/model_manager.py +148 -0
  86. ins_pricing v2/modelling/model_plotting.py +463 -0
  87. ins_pricing v2/modelling/models.py +2203 -0
  88. ins_pricing v2/modelling/notebook_utils.py +294 -0
  89. ins_pricing v2/modelling/plotting/__init__.py +45 -0
  90. ins_pricing v2/modelling/plotting/common.py +63 -0
  91. ins_pricing v2/modelling/plotting/curves.py +572 -0
  92. ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
  93. ins_pricing v2/modelling/plotting/geo.py +362 -0
  94. ins_pricing v2/modelling/plotting/importance.py +121 -0
  95. ins_pricing v2/modelling/run_logging.py +133 -0
  96. ins_pricing v2/modelling/tests/conftest.py +8 -0
  97. ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
  98. ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
  99. ins_pricing v2/modelling/tests/test_explain.py +56 -0
  100. ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
  101. ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
  102. ins_pricing v2/modelling/tests/test_plotting.py +63 -0
  103. ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
  104. ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
  105. ins_pricing v2/modelling/trainers.py +2447 -0
  106. ins_pricing v2/modelling/utils.py +1020 -0
  107. ins_pricing v2/modelling/watchdog_run.py +211 -0
  108. ins_pricing v2/pricing/__init__.py +27 -0
  109. ins_pricing v2/pricing/calibration.py +39 -0
  110. ins_pricing v2/pricing/data_quality.py +117 -0
  111. ins_pricing v2/pricing/exposure.py +85 -0
  112. ins_pricing v2/pricing/factors.py +91 -0
  113. ins_pricing v2/pricing/monitoring.py +99 -0
  114. ins_pricing v2/pricing/rate_table.py +78 -0
  115. ins_pricing v2/production/__init__.py +21 -0
  116. ins_pricing v2/production/drift.py +30 -0
  117. ins_pricing v2/production/monitoring.py +143 -0
  118. ins_pricing v2/production/scoring.py +40 -0
  119. ins_pricing v2/reporting/__init__.py +11 -0
  120. ins_pricing v2/reporting/report_builder.py +72 -0
  121. ins_pricing v2/reporting/scheduler.py +45 -0
  122. ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
  123. ins_pricing v2/scripts/Explain_entry.py +545 -0
  124. ins_pricing v2/scripts/__init__.py +1 -0
  125. ins_pricing v2/scripts/train.py +568 -0
  126. ins_pricing v2/setup.py +55 -0
  127. ins_pricing v2/smoke_test.py +28 -0
  128. ins_pricing-0.1.6.dist-info/METADATA +78 -0
  129. ins_pricing-0.1.6.dist-info/RECORD +169 -0
  130. ins_pricing-0.1.6.dist-info/WHEEL +5 -0
  131. ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
  132. user_packages/__init__.py +105 -0
  133. user_packages legacy/BayesOpt.py +5659 -0
  134. user_packages legacy/BayesOpt_entry.py +513 -0
  135. user_packages legacy/BayesOpt_incremental.py +685 -0
  136. user_packages legacy/Pricing_Run.py +36 -0
  137. user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
  138. user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
  139. user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
  140. user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
  141. user_packages legacy/Try/BayesOpt legacy.py +3280 -0
  142. user_packages legacy/Try/BayesOpt.py +838 -0
  143. user_packages legacy/Try/BayesOptAll.py +1569 -0
  144. user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
  145. user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
  146. user_packages legacy/Try/BayesOptSearch.py +830 -0
  147. user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
  148. user_packages legacy/Try/BayesOptV1.py +1911 -0
  149. user_packages legacy/Try/BayesOptV10.py +2973 -0
  150. user_packages legacy/Try/BayesOptV11.py +3001 -0
  151. user_packages legacy/Try/BayesOptV12.py +3001 -0
  152. user_packages legacy/Try/BayesOptV2.py +2065 -0
  153. user_packages legacy/Try/BayesOptV3.py +2209 -0
  154. user_packages legacy/Try/BayesOptV4.py +2342 -0
  155. user_packages legacy/Try/BayesOptV5.py +2372 -0
  156. user_packages legacy/Try/BayesOptV6.py +2759 -0
  157. user_packages legacy/Try/BayesOptV7.py +2832 -0
  158. user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
  159. user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
  160. user_packages legacy/Try/BayesOptV9.py +2927 -0
  161. user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
  162. user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
  163. user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
  164. user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
  165. user_packages legacy/Try/xgbbayesopt.py +523 -0
  166. user_packages legacy/__init__.py +19 -0
  167. user_packages legacy/cli_common.py +124 -0
  168. user_packages legacy/notebook_utils.py +228 -0
  169. user_packages legacy/watchdog_run.py +202 -0
ins_pricing v2/modelling/utils.py
@@ -0,0 +1,1020 @@
+ from __future__ import annotations
+
+ import csv
+ import ctypes
+ import copy
+ import gc
+ import json
+ import math
+ import os
+ import random
+ import time
+ from contextlib import nullcontext
+ from datetime import timedelta
+ from pathlib import Path
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+
+ try:  # matplotlib is optional; avoid hard import failures in headless/minimal envs
+     import matplotlib
+     if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
+         matplotlib.use("Agg")
+     import matplotlib.pyplot as plt
+     _MPL_IMPORT_ERROR: Optional[BaseException] = None
+ except Exception as exc:  # pragma: no cover - optional dependency
+     matplotlib = None  # type: ignore[assignment]
+     plt = None  # type: ignore[assignment]
+     _MPL_IMPORT_ERROR = exc
+ import numpy as np
+ import optuna
+ import pandas as pd
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.cuda.amp import autocast, GradScaler
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.nn.utils import clip_grad_norm_
+ from torch.utils.data import DataLoader, DistributedSampler
+
+ # Optional: unify plotting with shared plotting package
+ try:
+     from .plotting import curves as plot_curves_common
+     from .plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
+ except Exception:  # pragma: no cover
+     try:
+         from ins_pricing.plotting import curves as plot_curves_common
+         from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
+     except Exception:  # pragma: no cover
+         plot_curves_common = None
+         plot_loss_curve_common = None
+ # Limit CUDA allocator split size to reduce fragmentation and OOM risk.
+ # Override via PYTORCH_CUDA_ALLOC_CONF if needed.
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:256")
+
+
+ # =============================================================================
+ # Constants and utility helpers
+ # =============================================================================
+ torch.backends.cudnn.benchmark = True
+ EPS = 1e-8
+
+
+ def _plot_skip(label: str) -> None:
+     if _MPL_IMPORT_ERROR is not None:
+         print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
+     else:
+         print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
+
+
+ def set_global_seed(seed: int) -> None:
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+
+ class IOUtils:
+     # File and path utilities.
+
+     @staticmethod
+     def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
+         with open(file_path, mode='r', encoding='utf-8') as file:
+             reader = csv.DictReader(file)
+             return [
+                 dict(filter(lambda item: item[0] != '', row.items()))
+                 for row in reader
+             ]
+
+     @staticmethod
+     def _sanitize_params_dict(params: Dict[str, Any]) -> Dict[str, Any]:
+         # Filter index-like columns such as "Unnamed: 0" from pandas I/O.
+         return {
+             k: v
+             for k, v in (params or {}).items()
+             if k and not str(k).startswith("Unnamed")
+         }
+
+     @staticmethod
+     def load_params_file(path: str) -> Dict[str, Any]:
+         """Load parameter dict from JSON/CSV/TSV files.
+
+         - JSON: accept dict or {"best_params": {...}} wrapper
+         - CSV/TSV: read the first row as params
+         """
+         file_path = Path(path).expanduser().resolve()
+         if not file_path.exists():
+             raise FileNotFoundError(f"params file not found: {file_path}")
+         suffix = file_path.suffix.lower()
+         if suffix == ".json":
+             payload = json.loads(file_path.read_text(
+                 encoding="utf-8", errors="replace"))
+             if isinstance(payload, dict) and "best_params" in payload:
+                 payload = payload.get("best_params") or {}
+             if not isinstance(payload, dict):
+                 raise ValueError(
+                     f"Invalid JSON params file (expect dict): {file_path}")
+             return IOUtils._sanitize_params_dict(dict(payload))
+         if suffix in (".csv", ".tsv"):
+             df = pd.read_csv(file_path, sep="\t" if suffix == ".tsv" else ",")
+             if df.empty:
+                 raise ValueError(f"Empty params file: {file_path}")
+             params = df.iloc[0].to_dict()
+             return IOUtils._sanitize_params_dict(params)
+         raise ValueError(
+             f"Unsupported params file type '{suffix}': {file_path}")
+
+     @staticmethod
+     def ensure_parent_dir(file_path: str) -> None:
+         # Create parent directories when missing.
+         directory = os.path.dirname(file_path)
+         if directory:
+             os.makedirs(directory, exist_ok=True)
+
+
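For a concrete sense of load_params_file, a minimal round-trip sketch (file name and values here are hypothetical, not from the package):

    # Hypothetical usage of IOUtils.load_params_file: the {"best_params": {...}}
    # wrapper is unwrapped and "Unnamed*" index columns are dropped.
    import json
    with open("best.json", "w", encoding="utf-8") as fh:
        json.dump({"best_params": {"lr": 0.001, "Unnamed: 0": 7}}, fh)
    params = IOUtils.load_params_file("best.json")
    assert params == {"lr": 0.001}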
+ class TrainingUtils:
+     # Small helpers used during training.
+
+     @staticmethod
+     def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
+         estimated = int((learning_rate / 1e-4) ** 0.5 *
+                         (data_size / max(batch_num, 1)))
+         return max(1, min(data_size, max(minimum, estimated)))
+
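As a worked example of this heuristic (inputs assumed): with learning_rate=4e-4 the scale factor is (4e-4 / 1e-4) ** 0.5 = 2, so 100,000 rows split into 50 batches give an estimate of 2 * 2000 = 4000, then clipped to [minimum, data_size]:

    # Hypothetical inputs for TrainingUtils.compute_batch_size.
    bs = TrainingUtils.compute_batch_size(
        data_size=100_000, learning_rate=4e-4, batch_num=50, minimum=64)
    assert bs == 4000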
+     @staticmethod
+     def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
+         # Clamp predictions to positive values for stability.
+         pred_clamped = torch.clamp(pred, min=eps)
+         if p == 1:
+             term1 = target * torch.log(target / pred_clamped + eps)  # Poisson
+             term2 = target - pred_clamped
+             term3 = 0
+         elif p == 0:
+             term1 = 0.5 * torch.pow(target - pred_clamped, 2)  # Gaussian
+             term2 = 0
+             term3 = 0
+         elif p == 2:
+             term1 = torch.log(pred_clamped / target + eps)  # Gamma
+             term2 = -target / pred_clamped + 1
+             term3 = 0
+         else:
+             term1 = torch.pow(target, 2 - p) / ((1 - p) * (2 - p))
+             term2 = target * torch.pow(pred_clamped, 1 - p) / (1 - p)
+             term3 = torch.pow(pred_clamped, 2 - p) / (2 - p)
+         return torch.nan_to_num(  # Tweedie negative log-likelihood (constant omitted)
+             2 * (term1 - term2 + term3),
+             nan=eps,
+             posinf=max_clip,
+             neginf=-max_clip
+         )
+
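A quick sanity check of the deviance (tensors made up for this sketch): it should be ~0 when predictions equal targets, e.g. for the default p=1.5:

    # Hypothetical check: the deviance vanishes at pred == target.
    import torch
    y = torch.tensor([1.0, 2.0, 3.0])
    d = TrainingUtils.tweedie_loss(y.clone(), y, p=1.5)
    assert torch.allclose(d, torch.zeros_like(d), atol=1e-5)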
+     @staticmethod
+     def free_cuda() -> None:
+         print(">>> Moving all models to CPU...")
+         for obj in gc.get_objects():
+             try:
+                 if hasattr(obj, "to") and callable(obj.to):
+                     obj.to("cpu")
+             except Exception:
+                 pass
+
+         print(">>> Releasing tensor/optimizer/DataLoader references...")
+         gc.collect()
+
+         print(">>> Clearing CUDA cache...")
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.synchronize()
+             print(">>> CUDA memory released.")
+         else:
+             print(">>> CUDA not available; cleanup skipped.")
+
+
+ class DistributedUtils:
+     _cached_state: Optional[tuple] = None
+
+     @staticmethod
+     def setup_ddp():
+         """Initialize the DDP process group for distributed training."""
+         if dist.is_initialized():
+             if DistributedUtils._cached_state is None:
+                 rank = dist.get_rank()
+                 world_size = dist.get_world_size()
+                 local_rank = int(os.environ.get("LOCAL_RANK", 0))
+                 DistributedUtils._cached_state = (
+                     True,
+                     local_rank,
+                     rank,
+                     world_size,
+                 )
+             return DistributedUtils._cached_state
+
+         if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+             rank = int(os.environ["RANK"])
+             world_size = int(os.environ["WORLD_SIZE"])
+             local_rank = int(os.environ["LOCAL_RANK"])
+
+             if os.name == "nt" and torch.cuda.is_available() and world_size > 1:
+                 print(
+                     ">>> DDP Setup Disabled: Windows CUDA DDP is not supported. "
+                     "Falling back to single process."
+                 )
+                 return False, 0, 0, 1
+
+             if torch.cuda.is_available():
+                 torch.cuda.set_device(local_rank)
+
+             # Default to gloo for CPU/Mac compatibility; upgrade to nccl when valid CUDA is present.
+             backend = "gloo"
+             if torch.cuda.is_available() and os.name != "nt":
+                 try:
+                     if getattr(dist, "is_nccl_available", lambda: False)():
+                         backend = "nccl"
+                 except Exception:
+                     backend = "gloo"
+
+             # Set timeout
+             timeout_seconds = int(os.environ.get(
+                 "BAYESOPT_DDP_TIMEOUT_SECONDS", "1800"))
+             timeout = timedelta(seconds=max(1, timeout_seconds))
+
+             dist.init_process_group(
+                 backend=backend, init_method="env://", timeout=timeout)
+             print(
+                 f">>> DDP Initialized ({backend}, timeout={timeout_seconds}s): "
+                 f"Rank {rank}/{world_size}, Local Rank {local_rank}"
+             )
+             DistributedUtils._cached_state = (
+                 True,
+                 local_rank,
+                 rank,
+                 world_size,
+             )
+             return DistributedUtils._cached_state
+
+         # Not a distributed run.
+         return False, 0, 0, 1
+
+     @staticmethod
+     def cleanup_ddp():
+         """Destroy the DDP process group and clear cached state."""
+         if dist.is_initialized():
+             dist.destroy_process_group()
+         DistributedUtils._cached_state = None
+
+     @staticmethod
+     def is_main_process():
+         return not dist.is_initialized() or dist.get_rank() == 0
+
+     @staticmethod
+     def world_size() -> int:
+         return dist.get_world_size() if dist.is_initialized() else 1
+
+
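setup_ddp keys off the RANK/WORLD_SIZE/LOCAL_RANK variables that torchrun exports; a minimal launch sketch (script name hypothetical):

    # Launched as: torchrun --nproc_per_node=2 train_script.py
    enabled, local_rank, rank, world_size = DistributedUtils.setup_ddp()
    try:
        pass  # build the model here, wrapping it in DDP when enabled is True
    finally:
        DistributedUtils.cleanup_ddp()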
+ class PlotUtils:
+     # Plot helpers shared across models.
+
+     @staticmethod
+     def split_data(data: pd.DataFrame, col_nme: str, wgt_nme: str, n_bins: int = 10) -> pd.DataFrame:
+         data_sorted = data.sort_values(by=col_nme, ascending=True).copy()
+         data_sorted['cum_weight'] = data_sorted[wgt_nme].cumsum()
+         w_sum = data_sorted[wgt_nme].sum()
+         if w_sum <= EPS:
+             data_sorted.loc[:, 'bins'] = 0
+         else:
+             data_sorted.loc[:, 'bins'] = np.floor(
+                 data_sorted['cum_weight'] * float(n_bins) / w_sum
+             )
+             data_sorted.loc[(data_sorted['bins'] == n_bins),
+                             'bins'] = n_bins - 1
+         return data_sorted.groupby(['bins'], observed=True).sum(numeric_only=True)
+
+     @staticmethod
+     def plot_lift_ax(ax, plot_data, title, pred_label='Predicted', act_label='Actual', weight_label='Earned Exposure'):
+         ax.plot(plot_data.index, plot_data['act_v'],
+                 label=act_label, color='red')
+         ax.plot(plot_data.index, plot_data['exp_v'],
+                 label=pred_label, color='blue')
+         ax.set_title(title, fontsize=8)
+         ax.set_xticks(plot_data.index)
+         ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
+         ax.tick_params(axis='y', labelsize=6)
+         ax.legend(loc='upper left', fontsize=5, frameon=False)
+         ax.margins(0.05)
+         ax2 = ax.twinx()
+         ax2.bar(plot_data.index, plot_data['weight'],
+                 alpha=0.5, color='seagreen',
+                 label=weight_label)
+         ax2.tick_params(axis='y', labelsize=6)
+         ax2.legend(loc='upper right', fontsize=5, frameon=False)
+
+     @staticmethod
+     def plot_dlift_ax(ax, plot_data, title, label1, label2, act_label='Actual', weight_label='Earned Exposure'):
+         ax.plot(plot_data.index, plot_data['act_v'],
+                 label=act_label, color='red')
+         ax.plot(plot_data.index, plot_data['exp_v1'],
+                 label=label1, color='blue')
+         ax.plot(plot_data.index, plot_data['exp_v2'],
+                 label=label2, color='black')
+         ax.set_title(title, fontsize=8)
+         ax.set_xticks(plot_data.index)
+         ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
+         ax.set_xlabel(f'{label1} / {label2}', fontsize=6)
+         ax.tick_params(axis='y', labelsize=6)
+         ax.legend(loc='upper left', fontsize=5, frameon=False)
+         ax.margins(0.1)
+         ax2 = ax.twinx()
+         ax2.bar(plot_data.index, plot_data['weight'],
+                 alpha=0.5, color='seagreen',
+                 label=weight_label)
+         ax2.tick_params(axis='y', labelsize=6)
+         ax2.legend(loc='upper right', fontsize=5, frameon=False)
+
+     @staticmethod
+     def plot_lift_list(pred_model, w_pred_list, w_act_list,
+                        weight_list, tgt_nme, n_bins: int = 10,
+                        fig_nme: str = 'Lift Chart'):
+         if plot_curves_common is not None:
+             save_path = os.path.join(
+                 os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
+             plot_curves_common.plot_lift_curve(
+                 pred_model,
+                 w_act_list,
+                 weight_list,
+                 n_bins=n_bins,
+                 title=f'Lift Chart of {tgt_nme}',
+                 pred_label='Predicted',
+                 act_label='Actual',
+                 weight_label='Earned Exposure',
+                 pred_weighted=False,
+                 actual_weighted=True,
+                 save_path=save_path,
+                 show=False,
+             )
+             return
+         if plt is None:
+             _plot_skip("lift plot")
+             return
+         lift_data = pd.DataFrame()
+         lift_data.loc[:, 'pred'] = pred_model
+         lift_data.loc[:, 'w_pred'] = w_pred_list
+         lift_data.loc[:, 'act'] = w_act_list
+         lift_data.loc[:, 'weight'] = weight_list
+         plot_data = PlotUtils.split_data(lift_data, 'pred', 'weight', n_bins)
+         plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
+         plot_data['act_v'] = plot_data['act'] / plot_data['weight']
+         plot_data.reset_index(inplace=True)
+
+         fig = plt.figure(figsize=(7, 5))
+         ax = fig.add_subplot(111)
+         PlotUtils.plot_lift_ax(ax, plot_data, f'Lift Chart of {tgt_nme}')
+         plt.subplots_adjust(wspace=0.3)
+
+         save_path = os.path.join(
+             os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
+         IOUtils.ensure_parent_dir(save_path)
+         plt.savefig(save_path, dpi=300)
+         plt.close(fig)
+
+     @staticmethod
+     def plot_dlift_list(pred_model_1, pred_model_2,
+                         model_nme_1, model_nme_2,
+                         tgt_nme,
+                         w_list, w_act_list, n_bins: int = 10,
+                         fig_nme: str = 'Double Lift Chart'):
+         if plot_curves_common is not None:
+             save_path = os.path.join(
+                 os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
+             plot_curves_common.plot_double_lift_curve(
+                 pred_model_1,
+                 pred_model_2,
+                 w_act_list,
+                 w_list,
+                 n_bins=n_bins,
+                 title=f'Double Lift Chart of {tgt_nme}',
+                 label1=model_nme_1,
+                 label2=model_nme_2,
+                 pred1_weighted=False,
+                 pred2_weighted=False,
+                 actual_weighted=True,
+                 save_path=save_path,
+                 show=False,
+             )
+             return
+         if plt is None:
+             _plot_skip("double lift plot")
+             return
+         lift_data = pd.DataFrame()
+         lift_data.loc[:, 'pred1'] = pred_model_1
+         lift_data.loc[:, 'pred2'] = pred_model_2
+         lift_data.loc[:, 'diff_ly'] = lift_data['pred1'] / lift_data['pred2']
+         lift_data.loc[:, 'act'] = w_act_list
+         lift_data.loc[:, 'weight'] = w_list
+         lift_data.loc[:, 'w_pred1'] = lift_data['pred1'] * lift_data['weight']
+         lift_data.loc[:, 'w_pred2'] = lift_data['pred2'] * lift_data['weight']
+         plot_data = PlotUtils.split_data(
+             lift_data, 'diff_ly', 'weight', n_bins)
+         plot_data['exp_v1'] = plot_data['w_pred1'] / plot_data['act']
+         plot_data['exp_v2'] = plot_data['w_pred2'] / plot_data['act']
+         plot_data['act_v'] = plot_data['act'] / plot_data['act']
+         plot_data.reset_index(inplace=True)
+
+         fig = plt.figure(figsize=(7, 5))
+         ax = fig.add_subplot(111)
+         PlotUtils.plot_dlift_ax(
+             ax, plot_data, f'Double Lift Chart of {tgt_nme}', model_nme_1, model_nme_2)
+         plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)
+
+         save_path = os.path.join(
+             os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
+         IOUtils.ensure_parent_dir(save_path)
+         plt.savefig(save_path, dpi=300)
+         plt.close(fig)
+
+
+ def infer_factor_and_cate_list(train_df: pd.DataFrame,
+                                test_df: pd.DataFrame,
+                                resp_nme: str,
+                                weight_nme: str,
+                                binary_resp_nme: Optional[str] = None,
+                                factor_nmes: Optional[List[str]] = None,
+                                cate_list: Optional[List[str]] = None,
+                                infer_categorical_max_unique: int = 50,
+                                infer_categorical_max_ratio: float = 0.05) -> Tuple[List[str], List[str]]:
+     """Infer factor_nmes/cate_list when feature names are not provided.
+
+     Rules:
+     - factor_nmes: start from shared train/test columns, exclude target/weight/(optional binary target).
+     - cate_list: object/category/bool plus low-cardinality integer columns.
+     - Always intersect with shared train/test columns to avoid mismatches.
+     """
+     excluded = {resp_nme, weight_nme}
+     if binary_resp_nme:
+         excluded.add(binary_resp_nme)
+
+     common_cols = [c for c in train_df.columns if c in test_df.columns]
+     if factor_nmes is None:
+         factors = [c for c in common_cols if c not in excluded]
+     else:
+         factors = [
+             c for c in factor_nmes if c in common_cols and c not in excluded]
+
+     if cate_list is not None:
+         cats = [c for c in cate_list if c in factors]
+         return factors, cats
+
+     n_rows = max(1, len(train_df))
+     cats: List[str] = []
+     for col in factors:
+         s = train_df[col]
+         if pd.api.types.is_bool_dtype(s) or pd.api.types.is_object_dtype(s) or isinstance(s.dtype, pd.CategoricalDtype):
+             cats.append(col)
+             continue
+         if pd.api.types.is_integer_dtype(s):
+             nunique = int(s.nunique(dropna=True))
+             if nunique <= infer_categorical_max_unique or (nunique / n_rows) <= infer_categorical_max_ratio:
+                 cats.append(col)
+     return factors, cats
+
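A toy illustration of the inference rules (frames invented for this sketch):

    # Hypothetical frames: "loss"/"exposure" are excluded as target/weight;
    # "region" (object) and "age" (low-cardinality integer) are inferred categorical.
    import pandas as pd
    train = pd.DataFrame({"region": ["N", "S"], "age": [30, 40],
                          "loss": [0.0, 1.2], "exposure": [1.0, 0.5]})
    factors, cats = infer_factor_and_cate_list(
        train, train.copy(), resp_nme="loss", weight_nme="exposure")
    assert factors == ["region", "age"] and cats == ["region", "age"]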
+
+ # Backward-compatible functional wrappers
+ def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
+     return IOUtils.csv_to_dict(file_path)
+
+
+ def ensure_parent_dir(file_path: str) -> None:
+     IOUtils.ensure_parent_dir(file_path)
+
+
+ def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
+     return TrainingUtils.compute_batch_size(data_size, learning_rate, batch_num, minimum)
+
+
+ # Tweedie deviance loss for PyTorch.
+ # Reference: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-poisson-gamma-and-tweedie-deviances
+ def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
+     return TrainingUtils.tweedie_loss(pred, target, p=p, eps=eps, max_clip=max_clip)
+
+
+ # CUDA memory release helper.
+ def free_cuda():
+     TrainingUtils.free_cuda()
+
+
+ class TorchTrainerMixin:
+     # Shared helpers for Torch tabular trainers.
+
+     def _device_type(self) -> str:
+         return getattr(self, "device", torch.device("cpu")).type
+
+     def _resolve_resource_profile(self) -> str:
+         profile = getattr(self, "resource_profile", None)
+         if not profile:
+             profile = os.environ.get("BAYESOPT_RESOURCE_PROFILE", "auto")
+         profile = str(profile).strip().lower()
+         if profile in {"cpu", "mps", "cuda"}:
+             profile = "auto"
+         if profile not in {"auto", "throughput", "memory_saving"}:
+             profile = "auto"
+         if profile == "auto" and self._device_type() == "cuda":
+             profile = "throughput"
+         return profile
+
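The profile can also be pinned through the environment variable read above, e.g.:

    # Force the memory-saving profile; unrecognized values (including device
    # names like "cuda") fall back to "auto".
    import os
    os.environ["BAYESOPT_RESOURCE_PROFILE"] = "memory_saving"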
+     def _log_resource_summary_once(self, profile: str) -> None:
+         if getattr(self, "_resource_summary_logged", False):
+             return
+         if dist.is_initialized() and not DistributedUtils.is_main_process():
+             return
+         self._resource_summary_logged = True
+         device = getattr(self, "device", torch.device("cpu"))
+         device_type = self._device_type()
+         cpu_count = os.cpu_count() or 1
+         cuda_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
+         mps_available = bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available())
+         ddp_enabled = bool(getattr(self, "is_ddp_enabled", False))
+         data_parallel = bool(getattr(self, "use_data_parallel", False))
+         print(
+             f">>> Resource summary: device={device}, device_type={device_type}, "
+             f"cpu_count={cpu_count}, cuda_count={cuda_count}, mps={mps_available}, "
+             f"ddp={ddp_enabled}, data_parallel={data_parallel}, profile={profile}"
+         )
+
+     def _available_system_memory(self) -> Optional[int]:
+         if os.name == "nt":
+             class _MemStatus(ctypes.Structure):
+                 _fields_ = [
+                     ("dwLength", ctypes.c_ulong),
+                     ("dwMemoryLoad", ctypes.c_ulong),
+                     ("ullTotalPhys", ctypes.c_ulonglong),
+                     ("ullAvailPhys", ctypes.c_ulonglong),
+                     ("ullTotalPageFile", ctypes.c_ulonglong),
+                     ("ullAvailPageFile", ctypes.c_ulonglong),
+                     ("ullTotalVirtual", ctypes.c_ulonglong),
+                     ("ullAvailVirtual", ctypes.c_ulonglong),
+                     ("sullAvailExtendedVirtual", ctypes.c_ulonglong),
+                 ]
+             status = _MemStatus()
+             status.dwLength = ctypes.sizeof(_MemStatus)
+             if ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status)):
+                 return int(status.ullAvailPhys)
+             return None
+         try:
+             pages = os.sysconf("SC_AVPHYS_PAGES")
+             page_size = os.sysconf("SC_PAGE_SIZE")
+             return int(pages * page_size)
+         except Exception:
+             return None
+
+     def _available_cuda_memory(self) -> Optional[int]:
+         if not torch.cuda.is_available():
+             return None
+         try:
+             free_mem, _total_mem = torch.cuda.mem_get_info()
+         except Exception:
+             return None
+         return int(free_mem)
+
+     def _estimate_sample_bytes(self, dataset) -> Optional[int]:
+         try:
+             if len(dataset) == 0:
+                 return None
+             sample = dataset[0]
+         except Exception:
+             return None
+
+         def _bytes(obj) -> int:
+             if obj is None:
+                 return 0
+             if torch.is_tensor(obj):
+                 return int(obj.element_size() * obj.nelement())
+             if isinstance(obj, np.ndarray):
+                 return int(obj.nbytes)
+             if isinstance(obj, (list, tuple)):
+                 return int(sum(_bytes(item) for item in obj))
+             if isinstance(obj, dict):
+                 return int(sum(_bytes(item) for item in obj.values()))
+             return 0
+
+         sample_bytes = _bytes(sample)
+         return int(sample_bytes) if sample_bytes > 0 else None
+
+     def _cap_batch_size_by_memory(self, dataset, batch_size: int, profile: str) -> int:
+         if batch_size <= 1:
+             return batch_size
+         sample_bytes = self._estimate_sample_bytes(dataset)
+         if sample_bytes is None:
+             return batch_size
+         device_type = self._device_type()
+         if device_type == "cuda":
+             available = self._available_cuda_memory()
+             if available is None:
+                 return batch_size
+             if profile == "throughput":
+                 budget_ratio = 0.8
+                 overhead = 8.0
+             elif profile == "memory_saving":
+                 budget_ratio = 0.5
+                 overhead = 14.0
+             else:
+                 budget_ratio = 0.6
+                 overhead = 12.0
+         else:
+             available = self._available_system_memory()
+             if available is None:
+                 return batch_size
+             if profile == "throughput":
+                 budget_ratio = 0.4
+                 overhead = 1.8
+             elif profile == "memory_saving":
+                 budget_ratio = 0.25
+                 overhead = 3.0
+             else:
+                 budget_ratio = 0.3
+                 overhead = 2.6
+         budget = int(available * budget_ratio)
+         per_sample = int(sample_bytes * overhead)
+         if per_sample <= 0:
+             return batch_size
+         max_batch = max(1, int(budget // per_sample))
+         if max_batch < batch_size:
+             print(
+                 f">>> Memory cap: batch_size {batch_size} -> {max_batch} "
+                 f"(per_sample~{sample_bytes}B, budget~{budget // (1024**2)}MB)"
+             )
+         return min(batch_size, max_batch)
+
+     def _resolve_num_workers(self, max_workers: int, profile: Optional[str] = None) -> int:
+         if os.name == 'nt':
+             return 0
+         if getattr(self, "is_ddp_enabled", False):
+             return 0
+         profile = profile or self._resolve_resource_profile()
+         if profile == "memory_saving":
+             return 0
+         worker_cap = min(int(max_workers), os.cpu_count() or 1)
+         if self._device_type() == "mps":
+             worker_cap = min(worker_cap, 2)
+         return worker_cap
+
+     def _build_dataloader(self,
+                           dataset,
+                           N: int,
+                           base_bs_gpu: tuple,
+                           base_bs_cpu: tuple,
+                           min_bs: int = 64,
+                           target_effective_cuda: int = 1024,
+                           target_effective_cpu: int = 512,
+                           large_threshold: int = 200_000,
+                           mid_threshold: int = 50_000):
+         profile = self._resolve_resource_profile()
+         self._log_resource_summary_once(profile)
+         batch_size = TrainingUtils.compute_batch_size(
+             data_size=len(dataset),
+             learning_rate=self.learning_rate,
+             batch_num=self.batch_num,
+             minimum=min_bs
+         )
+         gpu_large, gpu_mid, gpu_small = base_bs_gpu
+         cpu_mid, cpu_small = base_bs_cpu
+
+         if self._device_type() == 'cuda':
+             device_count = torch.cuda.device_count()
+             if getattr(self, "is_ddp_enabled", False):
+                 device_count = 1
+             # In multi-GPU setups, raise the minimum batch size so each GPU gets enough data.
+             if device_count > 1:
+                 min_bs = min_bs * device_count
+                 print(
+                     f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")
+
+             if N > large_threshold:
+                 base_bs = gpu_large * device_count
+             elif N > mid_threshold:
+                 base_bs = gpu_mid * device_count
+             else:
+                 base_bs = gpu_small * device_count
+         else:
+             base_bs = cpu_mid if N > mid_threshold else cpu_small
+
+         # Recompute batch_size to respect the adjusted min_bs.
+         batch_size = TrainingUtils.compute_batch_size(
+             data_size=len(dataset),
+             learning_rate=self.learning_rate,
+             batch_num=self.batch_num,
+             minimum=min_bs
+         )
+         batch_size = min(batch_size, base_bs, N)
+         batch_size = self._cap_batch_size_by_memory(
+             dataset, batch_size, profile)
+
+         target_effective_bs = (target_effective_cuda
+                                if self._device_type() == 'cuda'
+                                else target_effective_cpu)
+         if getattr(self, "is_ddp_enabled", False):
+             world_size = max(1, DistributedUtils.world_size())
+             target_effective_bs = max(1, target_effective_bs // world_size)
+
+         world_size = getattr(self, "world_size", 1) if getattr(
+             self, "is_ddp_enabled", False) else 1
+         samples_per_rank = math.ceil(
+             N / max(1, world_size)) if world_size > 1 else N
+         steps_per_epoch = max(
+             1, math.ceil(samples_per_rank / max(1, batch_size)))
+         # Limit gradient accumulation to avoid scaling beyond actual batches.
+         desired_accum = max(1, target_effective_bs // max(1, batch_size))
+         accum_steps = max(1, min(desired_accum, steps_per_epoch))
+
+         # Linux (posix) uses fork; Windows (nt) uses spawn with higher overhead.
+         workers = self._resolve_num_workers(8, profile=profile)
+         prefetch_factor = None
+         if workers > 0:
+             prefetch_factor = 4 if profile == "throughput" else 2
+         persistent = workers > 0 and profile != "memory_saving"
+         print(
+             f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, "
+             f"Workers={workers}, Prefetch={prefetch_factor or 'off'}, Profile={profile}")
+         sampler = None
+         if dist.is_initialized():
+             sampler = DistributedSampler(dataset, shuffle=True)
+             shuffle = False  # DistributedSampler handles shuffling.
+         else:
+             shuffle = True
+
+         dataloader = DataLoader(
+             dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             sampler=sampler,
+             num_workers=workers,
+             pin_memory=(self._device_type() == 'cuda'),
+             persistent_workers=persistent,
+             **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
+         )
+         return dataloader, accum_steps
+
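To make the accumulation arithmetic concrete (sizes assumed): with batch_size=128 and the default CUDA target of 1024, the loader accumulates 8 micro-batches per optimizer step, capped by the steps actually available per epoch:

    # Hypothetical sizes, mirroring the formulas in _build_dataloader.
    import math
    batch_size, target_effective_bs, n = 128, 1024, 10_000
    steps_per_epoch = max(1, math.ceil(n / batch_size))        # 79
    desired_accum = max(1, target_effective_bs // batch_size)  # 8
    accum_steps = max(1, min(desired_accum, steps_per_epoch))  # 8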
+     def _build_val_dataloader(self, dataset, train_dataloader, accum_steps):
+         profile = self._resolve_resource_profile()
+         val_bs = accum_steps * train_dataloader.batch_size
+         val_workers = self._resolve_num_workers(4, profile=profile)
+         prefetch_factor = None
+         if val_workers > 0:
+             prefetch_factor = 2
+         return DataLoader(
+             dataset,
+             batch_size=val_bs,
+             shuffle=False,
+             num_workers=val_workers,
+             pin_memory=(self._device_type() == 'cuda'),
+             persistent_workers=(val_workers > 0 and profile != "memory_saving"),
+             **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
+         )
+
+     def _compute_losses(self, y_pred, y_true, apply_softplus: bool = False):
+         task = getattr(self, "task_type", "regression")
+         if task == 'classification':
+             loss_fn = nn.BCEWithLogitsLoss(reduction='none')
+             return loss_fn(y_pred, y_true).view(-1)
+         if apply_softplus:
+             y_pred = F.softplus(y_pred)
+             y_pred = torch.clamp(y_pred, min=1e-6)
+         power = getattr(self, "tw_power", 1.5)
+         return tweedie_loss(y_pred, y_true, p=power).view(-1)
+
+     def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
+         losses = self._compute_losses(
+             y_pred, y_true, apply_softplus=apply_softplus)
+         weighted_loss = (losses * weights.view(-1)).sum() / \
+             torch.clamp(weights.sum(), min=EPS)
+         return weighted_loss
+
+     def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model,
+                            ignore_keys: Optional[List[str]] = None):
+         if val_loss < best_loss:
+             ignore_keys = ignore_keys or []
+             state_dict = {
+                 k: (v.clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v))
+                 for k, v in model.state_dict().items()
+                 if not any(k.startswith(ignore_key) for ignore_key in ignore_keys)
+             }
+             return val_loss, state_dict, 0, False
+         patience_counter += 1
+         should_stop = best_state is not None and patience_counter >= getattr(
+             self, "patience", 0)
+         return best_loss, best_state, patience_counter, should_stop
+
+     def _train_model(self,
+                      model,
+                      dataloader,
+                      accum_steps,
+                      optimizer,
+                      scaler,
+                      forward_fn,
+                      val_forward_fn=None,
+                      apply_softplus: bool = False,
+                      clip_fn=None,
+                      trial: Optional[optuna.trial.Trial] = None,
+                      loss_curve_path: Optional[str] = None):
+         device_type = self._device_type()
+         best_loss = float('inf')
+         best_state = None
+         patience_counter = 0
+         stop_training = False
+         train_history: List[float] = []
+         val_history: List[float] = []
+
+         is_ddp_model = isinstance(model, DDP)
+
+         for epoch in range(1, getattr(self, "epochs", 1) + 1):
+             epoch_start_ts = time.time()
+             val_weighted_loss = None
+             if hasattr(self, 'dataloader_sampler') and self.dataloader_sampler is not None:
+                 self.dataloader_sampler.set_epoch(epoch)
+
+             model.train()
+             optimizer.zero_grad()
+
+             epoch_loss_sum = None
+             epoch_weight_sum = None
+             for step, batch in enumerate(dataloader):
+                 is_update_step = ((step + 1) % accum_steps == 0) or \
+                     ((step + 1) == len(dataloader))
+                 sync_cm = model.no_sync if (
+                     is_ddp_model and not is_update_step) else nullcontext
+
+                 with sync_cm():
+                     with autocast(enabled=(device_type == 'cuda')):
+                         y_pred, y_true, w = forward_fn(batch)
+                         weighted_loss = self._compute_weighted_loss(
+                             y_pred, y_true, w, apply_softplus=apply_softplus)
+                         loss_for_backward = weighted_loss / accum_steps
+
+                     batch_weight = torch.clamp(
+                         w.detach().sum(), min=EPS).to(dtype=torch.float32)
+                     loss_val = weighted_loss.detach().to(dtype=torch.float32)
+                     if epoch_loss_sum is None:
+                         epoch_loss_sum = torch.zeros(
+                             (), device=batch_weight.device, dtype=torch.float32)
+                         epoch_weight_sum = torch.zeros(
+                             (), device=batch_weight.device, dtype=torch.float32)
+                     epoch_loss_sum = epoch_loss_sum + loss_val * batch_weight
+                     epoch_weight_sum = epoch_weight_sum + batch_weight
+                     scaler.scale(loss_for_backward).backward()
+
+                 if is_update_step:
+                     if clip_fn is not None:
+                         clip_fn()
+                     scaler.step(optimizer)
+                     scaler.update()
+                     optimizer.zero_grad()
+
+             if epoch_loss_sum is None or epoch_weight_sum is None:
+                 train_epoch_loss = 0.0
+             else:
+                 train_epoch_loss = (
+                     epoch_loss_sum / torch.clamp(epoch_weight_sum, min=EPS)
+                 ).item()
+             train_history.append(float(train_epoch_loss))
+
+             if val_forward_fn is not None:
+                 should_compute_val = (not dist.is_initialized()
+                                       or DistributedUtils.is_main_process())
+                 val_device = getattr(self, "device", torch.device("cpu"))
+                 if not isinstance(val_device, torch.device):
+                     val_device = torch.device(val_device)
+                 loss_tensor_device = val_device if device_type == 'cuda' else torch.device("cpu")
+                 val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
+
+                 if should_compute_val:
+                     model.eval()
+                     with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
+                         val_result = val_forward_fn()
+                         if isinstance(val_result, tuple) and len(val_result) == 3:
+                             y_val_pred, y_val_true, w_val = val_result
+                             val_weighted_loss = self._compute_weighted_loss(
+                                 y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
+                         else:
+                             val_weighted_loss = val_result
+                     val_loss_tensor[0] = float(val_weighted_loss)
+
+                 if dist.is_initialized():
+                     dist.broadcast(val_loss_tensor, src=0)
+                 val_weighted_loss = float(val_loss_tensor.item())
+
+                 val_history.append(val_weighted_loss)
+
+                 best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
+                     val_weighted_loss, best_loss, best_state, patience_counter, model)
+
+                 prune_flag = False
+                 is_main_rank = DistributedUtils.is_main_process()
+                 if trial is not None and (not dist.is_initialized() or is_main_rank):
+                     trial.report(val_weighted_loss, epoch)
+                     prune_flag = trial.should_prune()
+
+                 if dist.is_initialized():
+                     prune_device = getattr(self, "device", torch.device("cpu"))
+                     if not isinstance(prune_device, torch.device):
+                         prune_device = torch.device(prune_device)
+                     prune_tensor = torch.zeros(1, device=prune_device)
+                     if is_main_rank:
+                         prune_tensor.fill_(1 if prune_flag else 0)
+                     dist.broadcast(prune_tensor, src=0)
+                     prune_flag = bool(prune_tensor.item())
+
+                 if prune_flag:
+                     raise optuna.TrialPruned()
+
+                 if stop_training:
+                     break
+
+             should_log_epoch = (not dist.is_initialized()
+                                 or DistributedUtils.is_main_process())
+             if should_log_epoch:
+                 elapsed = int(time.time() - epoch_start_ts)
+                 if val_weighted_loss is None:
+                     print(
+                         f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
+                         f"train_loss={float(train_epoch_loss):.6f} elapsed={elapsed}s",
+                         flush=True,
+                     )
+                 else:
+                     print(
+                         f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
+                         f"train_loss={float(train_epoch_loss):.6f} "
+                         f"val_loss={float(val_weighted_loss):.6f} elapsed={elapsed}s",
+                         flush=True,
+                     )
+
+         history = {"train": train_history, "val": val_history}
+         self._plot_loss_curve(history, loss_curve_path)
+         return best_state, history
+
+     def _plot_loss_curve(self, history: Dict[str, List[float]], save_path: Optional[str]) -> None:
+         if not save_path:
+             return
+         if dist.is_initialized() and not DistributedUtils.is_main_process():
+             return
+         train_hist = history.get("train", []) if history else []
+         val_hist = history.get("val", []) if history else []
+         if not train_hist and not val_hist:
+             return
+         if plot_loss_curve_common is not None:
+             plot_loss_curve_common(
+                 history=history,
+                 title="Loss vs. Epoch",
+                 save_path=save_path,
+                 show=False,
+             )
+         else:
+             if plt is None:
+                 _plot_skip("loss curve")
+                 return
+             ensure_parent_dir(save_path)
+             fig = plt.figure(figsize=(8, 4))
+             ax = fig.add_subplot(111)
+             if train_hist:
+                 ax.plot(range(1, len(train_hist) + 1), train_hist,
+                         label='Train Loss', color='tab:blue')
+             if val_hist:
+                 ax.plot(range(1, len(val_hist) + 1), val_hist,
+                         label='Validation Loss', color='tab:orange')
+             ax.set_xlabel('Epoch')
+             ax.set_ylabel('Weighted Loss')
+             ax.set_title('Loss vs. Epoch')
+             ax.grid(True, linestyle='--', alpha=0.3)
+             ax.legend()
+             plt.tight_layout()
+             plt.savefig(save_path, dpi=300)
+             plt.close(fig)
+         print(f"[Training] Loss curve saved to {save_path}")
+
+
+ # =============================================================================
+ # Plotting helpers
+ # =============================================================================
+
+ def split_data(data, col_nme, wgt_nme, n_bins=10):
+     return PlotUtils.split_data(data, col_nme, wgt_nme, n_bins)
+
+
+ # Lift curve plotting wrapper
+ def plot_lift_list(pred_model, w_pred_list, w_act_list,
+                    weight_list, tgt_nme, n_bins=10,
+                    fig_nme='Lift Chart'):
+     return PlotUtils.plot_lift_list(pred_model, w_pred_list, w_act_list,
+                                     weight_list, tgt_nme, n_bins, fig_nme)
+
+
+ # Double lift curve plotting wrapper
+ def plot_dlift_list(pred_model_1, pred_model_2,
+                     model_nme_1, model_nme_2,
+                     tgt_nme,
+                     w_list, w_act_list, n_bins=10,
+                     fig_nme='Double Lift Chart'):
+     return PlotUtils.plot_dlift_list(pred_model_1, pred_model_2,
+                                      model_nme_1, model_nme_2,
+                                      tgt_nme, w_list, w_act_list,
+                                      n_bins, fig_nme)