ins-pricing 0.1.6 (ins_pricing-0.1.6-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. ins_pricing/README.md +60 -0
  2. ins_pricing/__init__.py +102 -0
  3. ins_pricing/governance/README.md +18 -0
  4. ins_pricing/governance/__init__.py +20 -0
  5. ins_pricing/governance/approval.py +93 -0
  6. ins_pricing/governance/audit.py +37 -0
  7. ins_pricing/governance/registry.py +99 -0
  8. ins_pricing/governance/release.py +159 -0
  9. ins_pricing/modelling/BayesOpt.py +146 -0
  10. ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
  11. ins_pricing/modelling/BayesOpt_entry.py +575 -0
  12. ins_pricing/modelling/BayesOpt_incremental.py +731 -0
  13. ins_pricing/modelling/Explain_Run.py +36 -0
  14. ins_pricing/modelling/Explain_entry.py +539 -0
  15. ins_pricing/modelling/Pricing_Run.py +36 -0
  16. ins_pricing/modelling/README.md +33 -0
  17. ins_pricing/modelling/__init__.py +44 -0
  18. ins_pricing/modelling/bayesopt/__init__.py +98 -0
  19. ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
  20. ins_pricing/modelling/bayesopt/core.py +1476 -0
  21. ins_pricing/modelling/bayesopt/models.py +2196 -0
  22. ins_pricing/modelling/bayesopt/trainers.py +2446 -0
  23. ins_pricing/modelling/bayesopt/utils.py +1021 -0
  24. ins_pricing/modelling/cli_common.py +136 -0
  25. ins_pricing/modelling/explain/__init__.py +55 -0
  26. ins_pricing/modelling/explain/gradients.py +334 -0
  27. ins_pricing/modelling/explain/metrics.py +176 -0
  28. ins_pricing/modelling/explain/permutation.py +155 -0
  29. ins_pricing/modelling/explain/shap_utils.py +146 -0
  30. ins_pricing/modelling/notebook_utils.py +284 -0
  31. ins_pricing/modelling/plotting/__init__.py +45 -0
  32. ins_pricing/modelling/plotting/common.py +63 -0
  33. ins_pricing/modelling/plotting/curves.py +572 -0
  34. ins_pricing/modelling/plotting/diagnostics.py +139 -0
  35. ins_pricing/modelling/plotting/geo.py +362 -0
  36. ins_pricing/modelling/plotting/importance.py +121 -0
  37. ins_pricing/modelling/run_logging.py +133 -0
  38. ins_pricing/modelling/tests/conftest.py +8 -0
  39. ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
  40. ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
  41. ins_pricing/modelling/tests/test_explain.py +56 -0
  42. ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
  43. ins_pricing/modelling/tests/test_graph_cache.py +33 -0
  44. ins_pricing/modelling/tests/test_plotting.py +63 -0
  45. ins_pricing/modelling/tests/test_plotting_library.py +150 -0
  46. ins_pricing/modelling/tests/test_preprocessor.py +48 -0
  47. ins_pricing/modelling/watchdog_run.py +211 -0
  48. ins_pricing/pricing/README.md +44 -0
  49. ins_pricing/pricing/__init__.py +27 -0
  50. ins_pricing/pricing/calibration.py +39 -0
  51. ins_pricing/pricing/data_quality.py +117 -0
  52. ins_pricing/pricing/exposure.py +85 -0
  53. ins_pricing/pricing/factors.py +91 -0
  54. ins_pricing/pricing/monitoring.py +99 -0
  55. ins_pricing/pricing/rate_table.py +78 -0
  56. ins_pricing/production/__init__.py +21 -0
  57. ins_pricing/production/drift.py +30 -0
  58. ins_pricing/production/monitoring.py +143 -0
  59. ins_pricing/production/scoring.py +40 -0
  60. ins_pricing/reporting/README.md +20 -0
  61. ins_pricing/reporting/__init__.py +11 -0
  62. ins_pricing/reporting/report_builder.py +72 -0
  63. ins_pricing/reporting/scheduler.py +45 -0
  64. ins_pricing/setup.py +41 -0
  65. ins_pricing v2/__init__.py +23 -0
  66. ins_pricing v2/governance/__init__.py +20 -0
  67. ins_pricing v2/governance/approval.py +93 -0
  68. ins_pricing v2/governance/audit.py +37 -0
  69. ins_pricing v2/governance/registry.py +99 -0
  70. ins_pricing v2/governance/release.py +159 -0
  71. ins_pricing v2/modelling/Explain_Run.py +36 -0
  72. ins_pricing v2/modelling/Pricing_Run.py +36 -0
  73. ins_pricing v2/modelling/__init__.py +151 -0
  74. ins_pricing v2/modelling/cli_common.py +141 -0
  75. ins_pricing v2/modelling/config.py +249 -0
  76. ins_pricing v2/modelling/config_preprocess.py +254 -0
  77. ins_pricing v2/modelling/core.py +741 -0
  78. ins_pricing v2/modelling/data_container.py +42 -0
  79. ins_pricing v2/modelling/explain/__init__.py +55 -0
  80. ins_pricing v2/modelling/explain/gradients.py +334 -0
  81. ins_pricing v2/modelling/explain/metrics.py +176 -0
  82. ins_pricing v2/modelling/explain/permutation.py +155 -0
  83. ins_pricing v2/modelling/explain/shap_utils.py +146 -0
  84. ins_pricing v2/modelling/features.py +215 -0
  85. ins_pricing v2/modelling/model_manager.py +148 -0
  86. ins_pricing v2/modelling/model_plotting.py +463 -0
  87. ins_pricing v2/modelling/models.py +2203 -0
  88. ins_pricing v2/modelling/notebook_utils.py +294 -0
  89. ins_pricing v2/modelling/plotting/__init__.py +45 -0
  90. ins_pricing v2/modelling/plotting/common.py +63 -0
  91. ins_pricing v2/modelling/plotting/curves.py +572 -0
  92. ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
  93. ins_pricing v2/modelling/plotting/geo.py +362 -0
  94. ins_pricing v2/modelling/plotting/importance.py +121 -0
  95. ins_pricing v2/modelling/run_logging.py +133 -0
  96. ins_pricing v2/modelling/tests/conftest.py +8 -0
  97. ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
  98. ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
  99. ins_pricing v2/modelling/tests/test_explain.py +56 -0
  100. ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
  101. ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
  102. ins_pricing v2/modelling/tests/test_plotting.py +63 -0
  103. ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
  104. ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
  105. ins_pricing v2/modelling/trainers.py +2447 -0
  106. ins_pricing v2/modelling/utils.py +1020 -0
  107. ins_pricing v2/modelling/watchdog_run.py +211 -0
  108. ins_pricing v2/pricing/__init__.py +27 -0
  109. ins_pricing v2/pricing/calibration.py +39 -0
  110. ins_pricing v2/pricing/data_quality.py +117 -0
  111. ins_pricing v2/pricing/exposure.py +85 -0
  112. ins_pricing v2/pricing/factors.py +91 -0
  113. ins_pricing v2/pricing/monitoring.py +99 -0
  114. ins_pricing v2/pricing/rate_table.py +78 -0
  115. ins_pricing v2/production/__init__.py +21 -0
  116. ins_pricing v2/production/drift.py +30 -0
  117. ins_pricing v2/production/monitoring.py +143 -0
  118. ins_pricing v2/production/scoring.py +40 -0
  119. ins_pricing v2/reporting/__init__.py +11 -0
  120. ins_pricing v2/reporting/report_builder.py +72 -0
  121. ins_pricing v2/reporting/scheduler.py +45 -0
  122. ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
  123. ins_pricing v2/scripts/Explain_entry.py +545 -0
  124. ins_pricing v2/scripts/__init__.py +1 -0
  125. ins_pricing v2/scripts/train.py +568 -0
  126. ins_pricing v2/setup.py +55 -0
  127. ins_pricing v2/smoke_test.py +28 -0
  128. ins_pricing-0.1.6.dist-info/METADATA +78 -0
  129. ins_pricing-0.1.6.dist-info/RECORD +169 -0
  130. ins_pricing-0.1.6.dist-info/WHEEL +5 -0
  131. ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
  132. user_packages/__init__.py +105 -0
  133. user_packages legacy/BayesOpt.py +5659 -0
  134. user_packages legacy/BayesOpt_entry.py +513 -0
  135. user_packages legacy/BayesOpt_incremental.py +685 -0
  136. user_packages legacy/Pricing_Run.py +36 -0
  137. user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
  138. user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
  139. user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
  140. user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
  141. user_packages legacy/Try/BayesOpt legacy.py +3280 -0
  142. user_packages legacy/Try/BayesOpt.py +838 -0
  143. user_packages legacy/Try/BayesOptAll.py +1569 -0
  144. user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
  145. user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
  146. user_packages legacy/Try/BayesOptSearch.py +830 -0
  147. user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
  148. user_packages legacy/Try/BayesOptV1.py +1911 -0
  149. user_packages legacy/Try/BayesOptV10.py +2973 -0
  150. user_packages legacy/Try/BayesOptV11.py +3001 -0
  151. user_packages legacy/Try/BayesOptV12.py +3001 -0
  152. user_packages legacy/Try/BayesOptV2.py +2065 -0
  153. user_packages legacy/Try/BayesOptV3.py +2209 -0
  154. user_packages legacy/Try/BayesOptV4.py +2342 -0
  155. user_packages legacy/Try/BayesOptV5.py +2372 -0
  156. user_packages legacy/Try/BayesOptV6.py +2759 -0
  157. user_packages legacy/Try/BayesOptV7.py +2832 -0
  158. user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
  159. user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
  160. user_packages legacy/Try/BayesOptV9.py +2927 -0
  161. user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
  162. user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
  163. user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
  164. user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
  165. user_packages legacy/Try/xgbbayesopt.py +523 -0
  166. user_packages legacy/__init__.py +19 -0
  167. user_packages legacy/cli_common.py +124 -0
  168. user_packages legacy/notebook_utils.py +228 -0
  169. user_packages legacy/watchdog_run.py +202 -0
ins_pricing/modelling/bayesopt/utils.py
@@ -0,0 +1,1021 @@
+from __future__ import annotations
+
+import csv
+import ctypes
+import copy
+import gc
+import json
+import math
+import os
+import random
+import time
+from contextlib import nullcontext
+from datetime import timedelta
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+
+try:  # matplotlib is optional; avoid hard import failures in headless/minimal envs
+    import matplotlib
+    if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
+        matplotlib.use("Agg")
+    import matplotlib.pyplot as plt
+    _MPL_IMPORT_ERROR: Optional[BaseException] = None
+except Exception as exc:  # pragma: no cover - optional dependency
+    matplotlib = None  # type: ignore[assignment]
+    plt = None  # type: ignore[assignment]
+    _MPL_IMPORT_ERROR = exc
+import numpy as np
+import optuna
+import pandas as pd
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import autocast, GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.data import DataLoader, DistributedSampler
+
+# Optional: unify plotting with shared plotting package
+try:
+    from ..plotting import curves as plot_curves_common
+    from ..plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
+except Exception:  # pragma: no cover
+    try:
+        from ins_pricing.plotting import curves as plot_curves_common
+        from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
+    except Exception:  # pragma: no cover
+        plot_curves_common = None
+        plot_loss_curve_common = None
+# Limit CUDA allocator split size to reduce fragmentation and OOM risk.
+# Override via PYTORCH_CUDA_ALLOC_CONF if needed.
+os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:256")
+
+
+# Constants and utility helpers
+# =============================================================================
+torch.backends.cudnn.benchmark = True
+EPS = 1e-8
+
+
+def _plot_skip(label: str) -> None:
+    if _MPL_IMPORT_ERROR is not None:
+        print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
+    else:
+        print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
+
+
+def set_global_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+
+class IOUtils:
+    # File and path utilities.
+
+    @staticmethod
+    def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
+        with open(file_path, mode='r', encoding='utf-8') as file:
+            reader = csv.DictReader(file)
+            return [
+                dict(filter(lambda item: item[0] != '', row.items()))
+                for row in reader
+            ]
+
+    @staticmethod
+    def _sanitize_params_dict(params: Dict[str, Any]) -> Dict[str, Any]:
+        # Filter index-like columns such as "Unnamed: 0" from pandas I/O.
+        return {
+            k: v
+            for k, v in (params or {}).items()
+            if k and not str(k).startswith("Unnamed")
+        }
+
+    @staticmethod
+    def load_params_file(path: str) -> Dict[str, Any]:
+        """Load parameter dict from JSON/CSV/TSV files.
+
+        - JSON: accept dict or {"best_params": {...}} wrapper
+        - CSV/TSV: read the first row as params
+        """
+        file_path = Path(path).expanduser().resolve()
+        if not file_path.exists():
+            raise FileNotFoundError(f"params file not found: {file_path}")
+        suffix = file_path.suffix.lower()
+        if suffix == ".json":
+            payload = json.loads(file_path.read_text(
+                encoding="utf-8", errors="replace"))
+            if isinstance(payload, dict) and "best_params" in payload:
+                payload = payload.get("best_params") or {}
+            if not isinstance(payload, dict):
+                raise ValueError(
+                    f"Invalid JSON params file (expect dict): {file_path}")
+            return IOUtils._sanitize_params_dict(dict(payload))
+        if suffix in (".csv", ".tsv"):
+            df = pd.read_csv(file_path, sep="\t" if suffix == ".tsv" else ",")
+            if df.empty:
+                raise ValueError(f"Empty params file: {file_path}")
+            params = df.iloc[0].to_dict()
+            return IOUtils._sanitize_params_dict(params)
+        raise ValueError(
+            f"Unsupported params file type '{suffix}': {file_path}")
+
+    @staticmethod
+    def ensure_parent_dir(file_path: str) -> None:
+        # Create parent directories when missing.
+        directory = os.path.dirname(file_path)
+        if directory:
+            os.makedirs(directory, exist_ok=True)
+
+
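A minimal usage sketch for IOUtils.load_params_file (the temp path is illustrative; the import path follows the file listing above). A JSON payload may be a flat dict or an Optuna-style {"best_params": {...}} wrapper, and index-like "Unnamed: *" keys left over from pandas round-trips are dropped:

    import json, os, tempfile
    from ins_pricing.modelling.bayesopt.utils import IOUtils

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "best.json")
        with open(path, "w", encoding="utf-8") as fh:
            json.dump({"best_params": {"lr": 0.001, "Unnamed: 0": 7}}, fh)
        print(IOUtils.load_params_file(path))  # {'lr': 0.001}
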
+class TrainingUtils:
+    # Small helpers used during training.
+
+    @staticmethod
+    def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
+        estimated = int((learning_rate / 1e-4) ** 0.5 *
+                        (data_size / max(batch_num, 1)))
+        return max(1, min(data_size, max(minimum, estimated)))
+
+    @staticmethod
+    def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
+        # Clamp predictions to positive values for stability.
+        pred_clamped = torch.clamp(pred, min=eps)
+        if p == 1:
+            term1 = target * torch.log(target / pred_clamped + eps)  # Poisson
+            term2 = target - pred_clamped
+            term3 = 0
+        elif p == 0:
+            term1 = 0.5 * torch.pow(target - pred_clamped, 2)  # Gaussian
+            term2 = 0
+            term3 = 0
+        elif p == 2:
+            term1 = torch.log(pred_clamped / target + eps)  # Gamma
+            term2 = -target / pred_clamped + 1
+            term3 = 0
+        else:
+            term1 = torch.pow(target, 2 - p) / ((1 - p) * (2 - p))
+            term2 = target * torch.pow(pred_clamped, 1 - p) / (1 - p)
+            term3 = torch.pow(pred_clamped, 2 - p) / (2 - p)
+        return torch.nan_to_num(  # Tweedie deviance (likelihood constant omitted)
+            2 * (term1 - term2 + term3),
+            nan=eps,
+            posinf=max_clip,
+            neginf=-max_clip
+        )
+
+    @staticmethod
+    def free_cuda() -> None:
+        print(">>> Moving all models to CPU...")
+        for obj in gc.get_objects():
+            try:
+                if hasattr(obj, "to") and callable(obj.to):
+                    obj.to("cpu")
+            except Exception:
+                pass
+
+        print(">>> Releasing tensor/optimizer/DataLoader references...")
+        gc.collect()
+
+        print(">>> Clearing CUDA cache...")
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+            print(">>> CUDA memory released.")
+        else:
+            print(">>> CUDA not available; cleanup skipped.")
+
+
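A quick numeric sanity check of TrainingUtils.tweedie_loss against the closed-form Poisson (p=1) and Gamma (p=2) deviances; the tensor values are hand-picked and the import path follows the file listing above:

    import math
    import torch
    from ins_pricing.modelling.bayesopt.utils import TrainingUtils

    y, mu = torch.tensor([2.0]), torch.tensor([1.0])
    print(TrainingUtils.tweedie_loss(mu, y, p=1).item())  # ~0.7726
    print(2 * (2.0 * math.log(2.0) - 2.0 + 1.0))          # same closed form

    y, mu = torch.tensor([1.0]), torch.tensor([2.0])
    print(TrainingUtils.tweedie_loss(mu, y, p=2).item())  # ~0.3863
    print(2 * (math.log(2.0) + 0.5 - 1.0))                # same closed form
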
+class DistributedUtils:
+    _cached_state: Optional[tuple] = None
+
+    @staticmethod
+    def setup_ddp():
+        """Initialize the DDP process group for distributed training."""
+        if dist.is_initialized():
+            if DistributedUtils._cached_state is None:
+                rank = dist.get_rank()
+                world_size = dist.get_world_size()
+                local_rank = int(os.environ.get("LOCAL_RANK", 0))
+                DistributedUtils._cached_state = (
+                    True,
+                    local_rank,
+                    rank,
+                    world_size,
+                )
+            return DistributedUtils._cached_state
+
+        if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+            rank = int(os.environ["RANK"])
+            world_size = int(os.environ["WORLD_SIZE"])
+            local_rank = int(os.environ["LOCAL_RANK"])
+
+            if os.name == "nt" and torch.cuda.is_available() and world_size > 1:
+                print(
+                    ">>> DDP Setup Disabled: Windows CUDA DDP is not supported. "
+                    "Falling back to single process."
+                )
+                return False, 0, 0, 1
+
+            if torch.cuda.is_available():
+                torch.cuda.set_device(local_rank)
+
+            timeout_seconds = int(os.environ.get(
+                "BAYESOPT_DDP_TIMEOUT_SECONDS", "1800"))
+            timeout = timedelta(seconds=max(1, timeout_seconds))
+            backend = "gloo"
+            if torch.cuda.is_available() and os.name != "nt":
+                try:
+                    if getattr(dist, "is_nccl_available", lambda: False)():
+                        backend = "nccl"
+                except Exception:
+                    backend = "gloo"
+
+            dist.init_process_group(
+                backend=backend, init_method="env://", timeout=timeout)
+            print(
+                f">>> DDP Initialized ({backend}, timeout={timeout_seconds}s): "
+                f"Rank {rank}/{world_size}, Local Rank {local_rank}"
+            )
+            DistributedUtils._cached_state = (
+                True,
+                local_rank,
+                rank,
+                world_size,
+            )
+            return DistributedUtils._cached_state
+        else:
+            print(
+                f">>> DDP Setup Failed: RANK or WORLD_SIZE not found in env. Keys found: {list(os.environ.keys())}"
+            )
+            print(
+                ">>> Hint: launch with torchrun --nproc_per_node=<N> <script.py>"
+            )
+            return False, 0, 0, 1
+
+    @staticmethod
+    def cleanup_ddp():
+        """Destroy the DDP process group and clear cached state."""
+        if dist.is_initialized():
+            dist.destroy_process_group()
+        DistributedUtils._cached_state = None
+
+    @staticmethod
+    def is_main_process():
+        return not dist.is_initialized() or dist.get_rank() == 0
+
+    @staticmethod
+    def world_size() -> int:
+        return dist.get_world_size() if dist.is_initialized() else 1
+
+
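A minimal launch sketch for DistributedUtils (the script name is hypothetical): under torchrun, RANK/WORLD_SIZE/LOCAL_RANK are present and setup_ddp() joins the process group; a plain `python` run falls through to (False, 0, 0, 1):

    # torchrun --nproc_per_node=2 train_script.py
    from ins_pricing.modelling.bayesopt.utils import DistributedUtils

    is_ddp, local_rank, rank, world_size = DistributedUtils.setup_ddp()
    try:
        if DistributedUtils.is_main_process():
            print(f"training across {DistributedUtils.world_size()} process(es)")
        # build the model here and wrap it in DDP when is_ddp is True
    finally:
        DistributedUtils.cleanup_ddp()
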
+class PlotUtils:
+    # Plot helpers shared across models.
+
+    @staticmethod
+    def split_data(data: pd.DataFrame, col_nme: str, wgt_nme: str, n_bins: int = 10) -> pd.DataFrame:
+        data_sorted = data.sort_values(by=col_nme, ascending=True).copy()
+        data_sorted['cum_weight'] = data_sorted[wgt_nme].cumsum()
+        w_sum = data_sorted[wgt_nme].sum()
+        if w_sum <= EPS:
+            data_sorted.loc[:, 'bins'] = 0
+        else:
+            data_sorted.loc[:, 'bins'] = np.floor(
+                data_sorted['cum_weight'] * float(n_bins) / w_sum
+            )
+        data_sorted.loc[(data_sorted['bins'] == n_bins),
+                        'bins'] = n_bins - 1
+        return data_sorted.groupby(['bins'], observed=True).sum(numeric_only=True)
+
+    @staticmethod
+    def plot_lift_ax(ax, plot_data, title, pred_label='Predicted', act_label='Actual', weight_label='Earned Exposure'):
+        ax.plot(plot_data.index, plot_data['act_v'],
+                label=act_label, color='red')
+        ax.plot(plot_data.index, plot_data['exp_v'],
+                label=pred_label, color='blue')
+        ax.set_title(title, fontsize=8)
+        ax.set_xticks(plot_data.index)
+        ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
+        ax.tick_params(axis='y', labelsize=6)
+        ax.legend(loc='upper left', fontsize=5, frameon=False)
+        ax.margins(0.05)
+        ax2 = ax.twinx()
+        ax2.bar(plot_data.index, plot_data['weight'],
+                alpha=0.5, color='seagreen',
+                label=weight_label)
+        ax2.tick_params(axis='y', labelsize=6)
+        ax2.legend(loc='upper right', fontsize=5, frameon=False)
+
+    @staticmethod
+    def plot_dlift_ax(ax, plot_data, title, label1, label2, act_label='Actual', weight_label='Earned Exposure'):
+        ax.plot(plot_data.index, plot_data['act_v'],
+                label=act_label, color='red')
+        ax.plot(plot_data.index, plot_data['exp_v1'],
+                label=label1, color='blue')
+        ax.plot(plot_data.index, plot_data['exp_v2'],
+                label=label2, color='black')
+        ax.set_title(title, fontsize=8)
+        ax.set_xticks(plot_data.index)
+        ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
+        ax.set_xlabel(f'{label1} / {label2}', fontsize=6)
+        ax.tick_params(axis='y', labelsize=6)
+        ax.legend(loc='upper left', fontsize=5, frameon=False)
+        ax.margins(0.1)
+        ax2 = ax.twinx()
+        ax2.bar(plot_data.index, plot_data['weight'],
+                alpha=0.5, color='seagreen',
+                label=weight_label)
+        ax2.tick_params(axis='y', labelsize=6)
+        ax2.legend(loc='upper right', fontsize=5, frameon=False)
+
+    @staticmethod
+    def plot_lift_list(pred_model, w_pred_list, w_act_list,
+                       weight_list, tgt_nme, n_bins: int = 10,
+                       fig_nme: str = 'Lift Chart'):
+        if plot_curves_common is not None:
+            save_path = os.path.join(
+                os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
+            plot_curves_common.plot_lift_curve(
+                pred_model,
+                w_act_list,
+                weight_list,
+                n_bins=n_bins,
+                title=f'Lift Chart of {tgt_nme}',
+                pred_label='Predicted',
+                act_label='Actual',
+                weight_label='Earned Exposure',
+                pred_weighted=False,
+                actual_weighted=True,
+                save_path=save_path,
+                show=False,
+            )
+            return
+        if plt is None:
+            _plot_skip("lift plot")
+            return
+        lift_data = pd.DataFrame()
+        lift_data.loc[:, 'pred'] = pred_model
+        lift_data.loc[:, 'w_pred'] = w_pred_list
+        lift_data.loc[:, 'act'] = w_act_list
+        lift_data.loc[:, 'weight'] = weight_list
+        plot_data = PlotUtils.split_data(lift_data, 'pred', 'weight', n_bins)
+        plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
+        plot_data['act_v'] = plot_data['act'] / plot_data['weight']
+        plot_data.reset_index(inplace=True)
+
+        fig = plt.figure(figsize=(7, 5))
+        ax = fig.add_subplot(111)
+        PlotUtils.plot_lift_ax(ax, plot_data, f'Lift Chart of {tgt_nme}')
+        plt.subplots_adjust(wspace=0.3)
+
+        save_path = os.path.join(
+            os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
+        IOUtils.ensure_parent_dir(save_path)
+        plt.savefig(save_path, dpi=300)
+        plt.close(fig)
+
+    @staticmethod
+    def plot_dlift_list(pred_model_1, pred_model_2,
+                        model_nme_1, model_nme_2,
+                        tgt_nme,
+                        w_list, w_act_list, n_bins: int = 10,
+                        fig_nme: str = 'Double Lift Chart'):
+        if plot_curves_common is not None:
+            save_path = os.path.join(
+                os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
+            plot_curves_common.plot_double_lift_curve(
+                pred_model_1,
+                pred_model_2,
+                w_act_list,
+                w_list,
+                n_bins=n_bins,
+                title=f'Double Lift Chart of {tgt_nme}',
+                label1=model_nme_1,
+                label2=model_nme_2,
+                pred1_weighted=False,
+                pred2_weighted=False,
+                actual_weighted=True,
+                save_path=save_path,
+                show=False,
+            )
+            return
+        if plt is None:
+            _plot_skip("double lift plot")
+            return
+        lift_data = pd.DataFrame()
+        lift_data.loc[:, 'pred1'] = pred_model_1
+        lift_data.loc[:, 'pred2'] = pred_model_2
+        lift_data.loc[:, 'diff_ly'] = lift_data['pred1'] / lift_data['pred2']
+        lift_data.loc[:, 'act'] = w_act_list
+        lift_data.loc[:, 'weight'] = w_list
+        lift_data.loc[:, 'w_pred1'] = lift_data['pred1'] * lift_data['weight']
+        lift_data.loc[:, 'w_pred2'] = lift_data['pred2'] * lift_data['weight']
+        plot_data = PlotUtils.split_data(
+            lift_data, 'diff_ly', 'weight', n_bins)
+        plot_data['exp_v1'] = plot_data['w_pred1'] / plot_data['act']
+        plot_data['exp_v2'] = plot_data['w_pred2'] / plot_data['act']
+        plot_data['act_v'] = plot_data['act'] / plot_data['act']
+        plot_data.reset_index(inplace=True)
+
+        fig = plt.figure(figsize=(7, 5))
+        ax = fig.add_subplot(111)
+        PlotUtils.plot_dlift_ax(
+            ax, plot_data, f'Double Lift Chart of {tgt_nme}', model_nme_1, model_nme_2)
+        plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)
+
+        save_path = os.path.join(
+            os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
+        IOUtils.ensure_parent_dir(save_path)
+        plt.savefig(save_path, dpi=300)
+        plt.close(fig)
+
+
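How PlotUtils.split_data bins rows: sort by the score column, cumulate the weight column, then assign bin = floor(cum_weight * n_bins / total_weight), folding the top edge back into bin n_bins - 1. A self-contained sketch (import path per the listing above):

    import numpy as np
    import pandas as pd
    from ins_pricing.modelling.bayesopt.utils import PlotUtils

    df = pd.DataFrame({
        "pred": np.linspace(0.1, 1.0, 10),  # score used for ordering
        "weight": np.ones(10),              # one exposure unit per row
    })
    binned = PlotUtils.split_data(df, "pred", "weight", n_bins=5)
    print(binned["weight"].tolist())  # [1.0, 2.0, 2.0, 2.0, 3.0]: near-equal exposure bins
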
+def infer_factor_and_cate_list(train_df: pd.DataFrame,
+                               test_df: pd.DataFrame,
+                               resp_nme: str,
+                               weight_nme: str,
+                               binary_resp_nme: Optional[str] = None,
+                               factor_nmes: Optional[List[str]] = None,
+                               cate_list: Optional[List[str]] = None,
+                               infer_categorical_max_unique: int = 50,
+                               infer_categorical_max_ratio: float = 0.05) -> Tuple[List[str], List[str]]:
+    """Infer factor_nmes/cate_list when feature names are not provided.
+
+    Rules:
+    - factor_nmes: start from shared train/test columns, exclude target/weight/(optional binary target).
+    - cate_list: object/category/bool plus low-cardinality integer columns.
+    - Always intersect with shared train/test columns to avoid mismatches.
+    """
+    excluded = {resp_nme, weight_nme}
+    if binary_resp_nme:
+        excluded.add(binary_resp_nme)
+
+    common_cols = [c for c in train_df.columns if c in test_df.columns]
+    if factor_nmes is None:
+        factors = [c for c in common_cols if c not in excluded]
+    else:
+        factors = [
+            c for c in factor_nmes if c in common_cols and c not in excluded]
+
+    if cate_list is not None:
+        cats = [c for c in cate_list if c in factors]
+        return factors, cats
+
+    n_rows = max(1, len(train_df))
+    cats: List[str] = []
+    for col in factors:
+        s = train_df[col]
+        if pd.api.types.is_bool_dtype(s) or pd.api.types.is_object_dtype(s) or isinstance(s.dtype, pd.CategoricalDtype):
+            cats.append(col)
+            continue
+        if pd.api.types.is_integer_dtype(s):
+            nunique = int(s.nunique(dropna=True))
+            if nunique <= infer_categorical_max_unique or (nunique / n_rows) <= infer_categorical_max_ratio:
+                cats.append(col)
+    return factors, cats
+
+
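A sketch of the inference rules on a toy frame: object/bool columns are always categorical, and integer columns qualify when their cardinality is small in absolute terms (infer_categorical_max_unique) or relative to row count (infer_categorical_max_ratio):

    import pandas as pd
    from ins_pricing.modelling.bayesopt.utils import infer_factor_and_cate_list

    train = pd.DataFrame({
        "loss": [0.0, 120.5, 3.1],       # response
        "expo": [1.0, 0.5, 1.0],         # weight
        "region": ["N", "S", "N"],       # object -> categorical
        "veh_age": [1, 2, 1],            # low-cardinality int -> categorical
        "premium": [100.0, 80.0, 95.0],  # float -> numeric factor
    })
    factors, cats = infer_factor_and_cate_list(train, train.copy(), "loss", "expo")
    print(factors)  # ['region', 'veh_age', 'premium']
    print(cats)     # ['region', 'veh_age']
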
+# Backward-compatible functional wrappers
+def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
+    return IOUtils.csv_to_dict(file_path)
+
+
+def ensure_parent_dir(file_path: str) -> None:
+    IOUtils.ensure_parent_dir(file_path)
+
+
+def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
+    return TrainingUtils.compute_batch_size(data_size, learning_rate, batch_num, minimum)
+
+
+# Tweedie deviance loss for PyTorch.
+# Reference: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-poisson-gamma-and-tweedie-deviances
+def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
+    return TrainingUtils.tweedie_loss(pred, target, p=p, eps=eps, max_clip=max_clip)
+
+
+# CUDA memory release helper.
+def free_cuda():
+    TrainingUtils.free_cuda()
+
+
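The batch-size heuristic scales with sqrt(learning_rate / 1e-4) and with data_size / batch_num, then clamps to [minimum, data_size]. A worked example through the functional wrapper:

    from ins_pricing.modelling.bayesopt.utils import compute_batch_size

    # sqrt(4e-4 / 1e-4) = 2, and 100_000 / 50 = 2_000 -> estimated 4_000
    print(compute_batch_size(100_000, 4e-4, 50, 64))  # 4000
    # Tiny estimates are floored at `minimum`: estimated 20 -> clamped to 64
    print(compute_batch_size(1_000, 1e-4, 50, 64))    # 64
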
+class TorchTrainerMixin:
+    # Shared helpers for Torch tabular trainers.
+
+    def _device_type(self) -> str:
+        return getattr(self, "device", torch.device("cpu")).type
+
+    def _resolve_resource_profile(self) -> str:
+        profile = getattr(self, "resource_profile", None)
+        if not profile:
+            profile = os.environ.get("BAYESOPT_RESOURCE_PROFILE", "auto")
+        profile = str(profile).strip().lower()
+        if profile in {"cpu", "mps", "cuda"}:
+            profile = "auto"
+        if profile not in {"auto", "throughput", "memory_saving"}:
+            profile = "auto"
+        if profile == "auto" and self._device_type() == "cuda":
+            profile = "throughput"
+        return profile
+
+    def _log_resource_summary_once(self, profile: str) -> None:
+        if getattr(self, "_resource_summary_logged", False):
+            return
+        if dist.is_initialized() and not DistributedUtils.is_main_process():
+            return
+        self._resource_summary_logged = True
+        device = getattr(self, "device", torch.device("cpu"))
+        device_type = self._device_type()
+        cpu_count = os.cpu_count() or 1
+        cuda_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
+        mps_available = bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available())
+        ddp_enabled = bool(getattr(self, "is_ddp_enabled", False))
+        data_parallel = bool(getattr(self, "use_data_parallel", False))
+        print(
+            f">>> Resource summary: device={device}, device_type={device_type}, "
+            f"cpu_count={cpu_count}, cuda_count={cuda_count}, mps={mps_available}, "
+            f"ddp={ddp_enabled}, data_parallel={data_parallel}, profile={profile}"
+        )
+
+    def _available_system_memory(self) -> Optional[int]:
+        if os.name == "nt":
+            class _MemStatus(ctypes.Structure):
+                _fields_ = [
+                    ("dwLength", ctypes.c_ulong),
+                    ("dwMemoryLoad", ctypes.c_ulong),
+                    ("ullTotalPhys", ctypes.c_ulonglong),
+                    ("ullAvailPhys", ctypes.c_ulonglong),
+                    ("ullTotalPageFile", ctypes.c_ulonglong),
+                    ("ullAvailPageFile", ctypes.c_ulonglong),
+                    ("ullTotalVirtual", ctypes.c_ulonglong),
+                    ("ullAvailVirtual", ctypes.c_ulonglong),
+                    ("sullAvailExtendedVirtual", ctypes.c_ulonglong),
+                ]
+            status = _MemStatus()
+            status.dwLength = ctypes.sizeof(_MemStatus)
+            if ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status)):
+                return int(status.ullAvailPhys)
+            return None
+        try:
+            pages = os.sysconf("SC_AVPHYS_PAGES")
+            page_size = os.sysconf("SC_PAGE_SIZE")
+            return int(pages * page_size)
+        except Exception:
+            return None
+
+    def _available_cuda_memory(self) -> Optional[int]:
+        if not torch.cuda.is_available():
+            return None
+        try:
+            free_mem, _total_mem = torch.cuda.mem_get_info()
+        except Exception:
+            return None
+        return int(free_mem)
+
+    def _estimate_sample_bytes(self, dataset) -> Optional[int]:
+        try:
+            if len(dataset) == 0:
+                return None
+            sample = dataset[0]
+        except Exception:
+            return None
+
+        def _bytes(obj) -> int:
+            if obj is None:
+                return 0
+            if torch.is_tensor(obj):
+                return int(obj.element_size() * obj.nelement())
+            if isinstance(obj, np.ndarray):
+                return int(obj.nbytes)
+            if isinstance(obj, (list, tuple)):
+                return int(sum(_bytes(item) for item in obj))
+            if isinstance(obj, dict):
+                return int(sum(_bytes(item) for item in obj.values()))
+            return 0
+
+        sample_bytes = _bytes(sample)
+        return int(sample_bytes) if sample_bytes > 0 else None
+
+    def _cap_batch_size_by_memory(self, dataset, batch_size: int, profile: str) -> int:
+        if batch_size <= 1:
+            return batch_size
+        sample_bytes = self._estimate_sample_bytes(dataset)
+        if sample_bytes is None:
+            return batch_size
+        device_type = self._device_type()
+        if device_type == "cuda":
+            available = self._available_cuda_memory()
+            if available is None:
+                return batch_size
+            if profile == "throughput":
+                budget_ratio = 0.8
+                overhead = 8.0
+            elif profile == "memory_saving":
+                budget_ratio = 0.5
+                overhead = 14.0
+            else:
+                budget_ratio = 0.6
+                overhead = 12.0
+        else:
+            available = self._available_system_memory()
+            if available is None:
+                return batch_size
+            if profile == "throughput":
+                budget_ratio = 0.4
+                overhead = 1.8
+            elif profile == "memory_saving":
+                budget_ratio = 0.25
+                overhead = 3.0
+            else:
+                budget_ratio = 0.3
+                overhead = 2.6
+        budget = int(available * budget_ratio)
+        per_sample = int(sample_bytes * overhead)
+        if per_sample <= 0:
+            return batch_size
+        max_batch = max(1, int(budget // per_sample))
+        if max_batch < batch_size:
+            print(
+                f">>> Memory cap: batch_size {batch_size} -> {max_batch} "
+                f"(per_sample~{sample_bytes}B, budget~{budget // (1024**2)}MB)"
+            )
+        return min(batch_size, max_batch)
+
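The cap arithmetic in _cap_batch_size_by_memory, worked by hand for the CUDA "throughput" profile ("auto" resolves to "throughput" on CUDA; the numbers below are illustrative): with 8 GiB free and a 4 KiB sample, budget = 0.8 * free memory and per-sample cost = 8.0 * sample bytes, so the ceiling only bites for very large requested batches:

    free_bytes = 8 * 1024**3              # reported free CUDA memory
    sample_bytes = 4 * 1024               # estimated bytes per sample
    budget = int(free_bytes * 0.8)        # throughput budget_ratio
    per_sample = int(sample_bytes * 8.0)  # throughput overhead multiplier
    print(budget // per_sample)           # 209715 -> batch sizes above this get clamped
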
+    def _resolve_num_workers(self, max_workers: int, profile: Optional[str] = None) -> int:
+        if os.name == 'nt':
+            return 0
+        if getattr(self, "is_ddp_enabled", False):
+            return 0
+        profile = profile or self._resolve_resource_profile()
+        if profile == "memory_saving":
+            return 0
+        worker_cap = min(int(max_workers), os.cpu_count() or 1)
+        if self._device_type() == "mps":
+            worker_cap = min(worker_cap, 2)
+        return worker_cap
+
+    def _build_dataloader(self,
+                          dataset,
+                          N: int,
+                          base_bs_gpu: tuple,
+                          base_bs_cpu: tuple,
+                          min_bs: int = 64,
+                          target_effective_cuda: int = 1024,
+                          target_effective_cpu: int = 512,
+                          large_threshold: int = 200_000,
+                          mid_threshold: int = 50_000):
+        profile = self._resolve_resource_profile()
+        self._log_resource_summary_once(profile)
+        batch_size = TrainingUtils.compute_batch_size(
+            data_size=len(dataset),
+            learning_rate=self.learning_rate,
+            batch_num=self.batch_num,
+            minimum=min_bs
+        )
+        gpu_large, gpu_mid, gpu_small = base_bs_gpu
+        cpu_mid, cpu_small = base_bs_cpu
+
+        if self._device_type() == 'cuda':
+            device_count = torch.cuda.device_count()
+            if getattr(self, "is_ddp_enabled", False):
+                device_count = 1
+            # In multi-GPU, increase min batch size so each GPU gets enough data.
+            if device_count > 1:
+                min_bs = min_bs * device_count
+                print(
+                    f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")
+
+            if N > large_threshold:
+                base_bs = gpu_large * device_count
+            elif N > mid_threshold:
+                base_bs = gpu_mid * device_count
+            else:
+                base_bs = gpu_small * device_count
+        else:
+            base_bs = cpu_mid if N > mid_threshold else cpu_small
+
+        # Recompute batch_size to respect the adjusted min_bs.
+        batch_size = TrainingUtils.compute_batch_size(
+            data_size=len(dataset),
+            learning_rate=self.learning_rate,
+            batch_num=self.batch_num,
+            minimum=min_bs
+        )
+        batch_size = min(batch_size, base_bs, N)
+        batch_size = self._cap_batch_size_by_memory(
+            dataset, batch_size, profile)
+
+        target_effective_bs = target_effective_cuda if self._device_type(
+        ) == 'cuda' else target_effective_cpu
+        if getattr(self, "is_ddp_enabled", False):
+            world_size = max(1, DistributedUtils.world_size())
+            target_effective_bs = max(1, target_effective_bs // world_size)
+
+        world_size = getattr(self, "world_size", 1) if getattr(
+            self, "is_ddp_enabled", False) else 1
+        samples_per_rank = math.ceil(
+            N / max(1, world_size)) if world_size > 1 else N
+        steps_per_epoch = max(
+            1, math.ceil(samples_per_rank / max(1, batch_size)))
+        # Limit gradient accumulation to avoid scaling beyond actual batches.
+        desired_accum = max(1, target_effective_bs // max(1, batch_size))
+        accum_steps = max(1, min(desired_accum, steps_per_epoch))
+
+        # Linux (posix) uses fork; Windows (nt) uses spawn with higher overhead.
+        workers = self._resolve_num_workers(8, profile=profile)
+        prefetch_factor = None
+        if workers > 0:
+            prefetch_factor = 4 if profile == "throughput" else 2
+        persistent = workers > 0 and profile != "memory_saving"
+        print(
+            f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, "
+            f"Workers={workers}, Prefetch={prefetch_factor or 'off'}, Profile={profile}")
+        sampler = None
+        if dist.is_initialized():
+            sampler = DistributedSampler(dataset, shuffle=True)
+            shuffle = False  # DistributedSampler handles shuffling.
+        else:
+            shuffle = True
+
+        dataloader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            sampler=sampler,
+            num_workers=workers,
+            pin_memory=(self._device_type() == 'cuda'),
+            persistent_workers=persistent,
+            **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
+        )
+        return dataloader, accum_steps
+
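How _build_dataloader sizes gradient accumulation: it targets an effective batch (1024 on CUDA, 512 on CPU, divided by world size under DDP) and never accumulates past one epoch's worth of steps. A worked single-process CPU example mirroring the arithmetic above:

    batch_size = 128                  # after the min/base/memory clamps
    target_effective = 512            # target_effective_cpu, world size 1
    steps_per_epoch = 40              # ceil(5_120 samples / 128)
    desired_accum = max(1, target_effective // max(1, batch_size))  # 4
    accum_steps = max(1, min(desired_accum, steps_per_epoch))       # 4
    print(accum_steps, batch_size * accum_steps)  # 4 512 -> effective batch of 512
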
+    def _build_val_dataloader(self, dataset, train_dataloader, accum_steps):
+        profile = self._resolve_resource_profile()
+        val_bs = accum_steps * train_dataloader.batch_size
+        val_workers = self._resolve_num_workers(4, profile=profile)
+        prefetch_factor = None
+        if val_workers > 0:
+            prefetch_factor = 2
+        return DataLoader(
+            dataset,
+            batch_size=val_bs,
+            shuffle=False,
+            num_workers=val_workers,
+            pin_memory=(self._device_type() == 'cuda'),
+            persistent_workers=(val_workers > 0 and profile != "memory_saving"),
+            **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
+        )
+
+    def _compute_losses(self, y_pred, y_true, apply_softplus: bool = False):
+        task = getattr(self, "task_type", "regression")
+        if task == 'classification':
+            loss_fn = nn.BCEWithLogitsLoss(reduction='none')
+            return loss_fn(y_pred, y_true).view(-1)
+        if apply_softplus:
+            y_pred = F.softplus(y_pred)
+            y_pred = torch.clamp(y_pred, min=1e-6)
+        power = getattr(self, "tw_power", 1.5)
+        return tweedie_loss(y_pred, y_true, p=power).view(-1)
+
+    def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
+        losses = self._compute_losses(
+            y_pred, y_true, apply_softplus=apply_softplus)
+        weighted_loss = (losses * weights.view(-1)).sum() / \
+            torch.clamp(weights.sum(), min=EPS)
+        return weighted_loss
+
+    def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model,
+                           ignore_keys: Optional[List[str]] = None):
+        if val_loss < best_loss:
+            ignore_keys = ignore_keys or []
+            state_dict = {
+                k: (v.clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v))
+                for k, v in model.state_dict().items()
+                if not any(k.startswith(ignore_key) for ignore_key in ignore_keys)
+            }
+            return val_loss, state_dict, 0, False
+        patience_counter += 1
+        should_stop = best_state is not None and patience_counter >= getattr(
+            self, "patience", 0)
+        return best_loss, best_state, patience_counter, should_stop
+
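The _early_stop_update contract: on improvement it returns (new_best, state snapshot, 0, False); otherwise the counter grows, and should_stop flips once a snapshot exists and the trainer's patience attribute (default 0) is exhausted. A standalone toy mirroring that logic (not calling the mixin, which needs a trainer instance):

    best_loss, best_state, patience, stop = float("inf"), None, 0, False
    for val_loss in [0.9, 0.8, 0.85, 0.87]:  # patience threshold of 2 in this toy
        if val_loss < best_loss:
            best_loss, best_state, patience, stop = val_loss, "snapshot", 0, False
        else:
            patience += 1
            stop = best_state is not None and patience >= 2
    print(best_loss, patience, stop)  # 0.8 2 True
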
+    def _train_model(self,
+                     model,
+                     dataloader,
+                     accum_steps,
+                     optimizer,
+                     scaler,
+                     forward_fn,
+                     val_forward_fn=None,
+                     apply_softplus: bool = False,
+                     clip_fn=None,
+                     trial: Optional[optuna.trial.Trial] = None,
+                     loss_curve_path: Optional[str] = None):
+        device_type = self._device_type()
+        best_loss = float('inf')
+        best_state = None
+        patience_counter = 0
+        stop_training = False
+        train_history: List[float] = []
+        val_history: List[float] = []
+
+        is_ddp_model = isinstance(model, DDP)
+
+        for epoch in range(1, getattr(self, "epochs", 1) + 1):
+            epoch_start_ts = time.time()
+            val_weighted_loss = None
+            if hasattr(self, 'dataloader_sampler') and self.dataloader_sampler is not None:
+                self.dataloader_sampler.set_epoch(epoch)
+
+            model.train()
+            optimizer.zero_grad()
+
+            epoch_loss_sum = None
+            epoch_weight_sum = None
+            for step, batch in enumerate(dataloader):
+                is_update_step = ((step + 1) % accum_steps == 0) or \
+                    ((step + 1) == len(dataloader))
+                sync_cm = model.no_sync if (
+                    is_ddp_model and not is_update_step) else nullcontext
+
+                with sync_cm():
+                    with autocast(enabled=(device_type == 'cuda')):
+                        y_pred, y_true, w = forward_fn(batch)
+                        weighted_loss = self._compute_weighted_loss(
+                            y_pred, y_true, w, apply_softplus=apply_softplus)
+                        loss_for_backward = weighted_loss / accum_steps
+
+                    batch_weight = torch.clamp(
+                        w.detach().sum(), min=EPS).to(dtype=torch.float32)
+                    loss_val = weighted_loss.detach().to(dtype=torch.float32)
+                    if epoch_loss_sum is None:
+                        epoch_loss_sum = torch.zeros(
+                            (), device=batch_weight.device, dtype=torch.float32)
+                        epoch_weight_sum = torch.zeros(
+                            (), device=batch_weight.device, dtype=torch.float32)
+                    epoch_loss_sum = epoch_loss_sum + loss_val * batch_weight
+                    epoch_weight_sum = epoch_weight_sum + batch_weight
+                    scaler.scale(loss_for_backward).backward()
+
+                if is_update_step:
+                    if clip_fn is not None:
+                        clip_fn()
+                    scaler.step(optimizer)
+                    scaler.update()
+                    optimizer.zero_grad()
+
+            if epoch_loss_sum is None or epoch_weight_sum is None:
+                train_epoch_loss = 0.0
+            else:
+                train_epoch_loss = (
+                    epoch_loss_sum / torch.clamp(epoch_weight_sum, min=EPS)
+                ).item()
+            train_history.append(float(train_epoch_loss))
+
+            if val_forward_fn is not None:
+                should_compute_val = (not dist.is_initialized()
+                                      or DistributedUtils.is_main_process())
+                val_device = getattr(self, "device", torch.device("cpu"))
+                if not isinstance(val_device, torch.device):
+                    val_device = torch.device(val_device)
+                loss_tensor_device = val_device if device_type == 'cuda' else torch.device(
+                    "cpu")
+                val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
+
+                if should_compute_val:
+                    model.eval()
+                    with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
+                        val_result = val_forward_fn()
+                        if isinstance(val_result, tuple) and len(val_result) == 3:
+                            y_val_pred, y_val_true, w_val = val_result
+                            val_weighted_loss = self._compute_weighted_loss(
+                                y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
+                        else:
+                            val_weighted_loss = val_result
+                    val_loss_tensor[0] = float(val_weighted_loss)
+
+                if dist.is_initialized():
+                    dist.broadcast(val_loss_tensor, src=0)
+                val_weighted_loss = float(val_loss_tensor.item())
+
+                val_history.append(val_weighted_loss)
+
+                best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
+                    val_weighted_loss, best_loss, best_state, patience_counter, model)
+
+                prune_flag = False
+                is_main_rank = DistributedUtils.is_main_process()
+                if trial is not None and (not dist.is_initialized() or is_main_rank):
+                    trial.report(val_weighted_loss, epoch)
+                    prune_flag = trial.should_prune()
+
+                if dist.is_initialized():
+                    prune_device = getattr(self, "device", torch.device("cpu"))
+                    if not isinstance(prune_device, torch.device):
+                        prune_device = torch.device(prune_device)
+                    prune_tensor = torch.zeros(1, device=prune_device)
+                    if is_main_rank:
+                        prune_tensor.fill_(1 if prune_flag else 0)
+                    dist.broadcast(prune_tensor, src=0)
+                    prune_flag = bool(prune_tensor.item())
+
+                if prune_flag:
+                    raise optuna.TrialPruned()
+
+                if stop_training:
+                    break
+
+            should_log_epoch = (not dist.is_initialized()
+                                or DistributedUtils.is_main_process())
+            if should_log_epoch:
+                elapsed = int(time.time() - epoch_start_ts)
+                if val_weighted_loss is None:
+                    print(
+                        f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
+                        f"train_loss={float(train_epoch_loss):.6f} elapsed={elapsed}s",
+                        flush=True,
+                    )
+                else:
+                    print(
+                        f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
+                        f"train_loss={float(train_epoch_loss):.6f} "
+                        f"val_loss={float(val_weighted_loss):.6f} elapsed={elapsed}s",
+                        flush=True,
+                    )
+
+        history = {"train": train_history, "val": val_history}
+        self._plot_loss_curve(history, loss_curve_path)
+        return best_state, history
+
+    def _plot_loss_curve(self, history: Dict[str, List[float]], save_path: Optional[str]) -> None:
+        if not save_path:
+            return
+        if dist.is_initialized() and not DistributedUtils.is_main_process():
+            return
+        train_hist = history.get("train", []) if history else []
+        val_hist = history.get("val", []) if history else []
+        if not train_hist and not val_hist:
+            return
+        if plot_loss_curve_common is not None:
+            plot_loss_curve_common(
+                history=history,
+                title="Loss vs. Epoch",
+                save_path=save_path,
+                show=False,
+            )
+        else:
+            if plt is None:
+                _plot_skip("loss curve")
+                return
+            ensure_parent_dir(save_path)
+            epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
+            fig = plt.figure(figsize=(8, 4))
+            ax = fig.add_subplot(111)
+            if train_hist:
+                ax.plot(range(1, len(train_hist) + 1), train_hist,
+                        label='Train Loss', color='tab:blue')
+            if val_hist:
+                ax.plot(range(1, len(val_hist) + 1), val_hist,
+                        label='Validation Loss', color='tab:orange')
+            ax.set_xlabel('Epoch')
+            ax.set_ylabel('Weighted Loss')
+            ax.set_title('Loss vs. Epoch')
+            ax.grid(True, linestyle='--', alpha=0.3)
+            ax.legend()
+            plt.tight_layout()
+            plt.savefig(save_path, dpi=300)
+            plt.close(fig)
+        print(f"[Training] Loss curve saved to {save_path}")
+
+
+# =============================================================================
+# Plotting helpers
+# =============================================================================
+
+def split_data(data, col_nme, wgt_nme, n_bins=10):
+    return PlotUtils.split_data(data, col_nme, wgt_nme, n_bins)
+
+# Lift curve plotting wrapper
+
+
+def plot_lift_list(pred_model, w_pred_list, w_act_list,
+                   weight_list, tgt_nme, n_bins=10,
+                   fig_nme='Lift Chart'):
+    return PlotUtils.plot_lift_list(pred_model, w_pred_list, w_act_list,
+                                    weight_list, tgt_nme, n_bins, fig_nme)
+
+# Double lift curve plotting wrapper
+
+
+def plot_dlift_list(pred_model_1, pred_model_2,
+                    model_nme_1, model_nme_2,
+                    tgt_nme,
+                    w_list, w_act_list, n_bins=10,
+                    fig_nme='Double Lift Chart'):
+    return PlotUtils.plot_dlift_list(pred_model_1, pred_model_2,
+                                     model_nme_1, model_nme_2,
+                                     tgt_nme, w_list, w_act_list,
+                                     n_bins, fig_nme)
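
Finally, an end-to-end sketch of the module-level lift-chart wrapper on synthetic data (the chart lands under ./plot/05_<target>_<fig_nme>.png; when the shared plotting package is absent the inline matplotlib fallback is used):

    import numpy as np
    from ins_pricing.modelling.bayesopt.utils import plot_lift_list

    rng = np.random.default_rng(0)
    pred = rng.gamma(2.0, 50.0, size=1_000)           # model predictions
    weight = np.ones(1_000)                           # earned exposure
    actual = pred * rng.lognormal(0.0, 0.3, 1_000)    # noisy "actuals"
    plot_lift_list(pred, pred * weight, actual * weight, weight,
                   tgt_nme="sev", n_bins=10, fig_nme="Lift Chart")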