ins_pricing-0.4.5-py3-none-any.whl → ins_pricing-0.5.1-py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- ins_pricing/README.md +48 -22
- ins_pricing/__init__.py +142 -90
- ins_pricing/cli/BayesOpt_entry.py +58 -46
- ins_pricing/cli/BayesOpt_incremental.py +77 -110
- ins_pricing/cli/Explain_Run.py +42 -23
- ins_pricing/cli/Explain_entry.py +551 -577
- ins_pricing/cli/Pricing_Run.py +42 -23
- ins_pricing/cli/bayesopt_entry_runner.py +51 -16
- ins_pricing/cli/utils/bootstrap.py +23 -0
- ins_pricing/cli/utils/cli_common.py +256 -256
- ins_pricing/cli/utils/cli_config.py +379 -360
- ins_pricing/cli/utils/import_resolver.py +375 -358
- ins_pricing/cli/utils/notebook_utils.py +256 -242
- ins_pricing/cli/watchdog_run.py +216 -198
- ins_pricing/frontend/__init__.py +10 -10
- ins_pricing/frontend/app.py +132 -61
- ins_pricing/frontend/config_builder.py +33 -0
- ins_pricing/frontend/example_config.json +11 -0
- ins_pricing/frontend/example_workflows.py +1 -1
- ins_pricing/frontend/runner.py +340 -388
- ins_pricing/governance/__init__.py +20 -20
- ins_pricing/governance/release.py +159 -159
- ins_pricing/modelling/README.md +1 -1
- ins_pricing/modelling/__init__.py +147 -92
- ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
- ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
- ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
- ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
- ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
- ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
- ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
- ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
- ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
- ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
- ins_pricing/modelling/explain/__init__.py +55 -55
- ins_pricing/modelling/explain/metrics.py +27 -174
- ins_pricing/modelling/explain/permutation.py +237 -237
- ins_pricing/modelling/plotting/__init__.py +40 -36
- ins_pricing/modelling/plotting/compat.py +228 -0
- ins_pricing/modelling/plotting/curves.py +572 -572
- ins_pricing/modelling/plotting/diagnostics.py +163 -163
- ins_pricing/modelling/plotting/geo.py +362 -362
- ins_pricing/modelling/plotting/importance.py +121 -121
- ins_pricing/pricing/__init__.py +27 -27
- ins_pricing/pricing/factors.py +67 -56
- ins_pricing/production/__init__.py +35 -25
- ins_pricing/production/{predict.py → inference.py} +140 -57
- ins_pricing/production/monitoring.py +8 -21
- ins_pricing/reporting/__init__.py +11 -11
- ins_pricing/setup.py +1 -1
- ins_pricing/tests/production/test_inference.py +90 -0
- ins_pricing/utils/__init__.py +112 -78
- ins_pricing/utils/device.py +258 -237
- ins_pricing/utils/features.py +53 -0
- ins_pricing/utils/io.py +72 -0
- ins_pricing/utils/logging.py +34 -1
- ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
- ins_pricing/utils/metrics.py +158 -24
- ins_pricing/utils/numerics.py +76 -0
- ins_pricing/utils/paths.py +9 -1
- ins_pricing/utils/profiling.py +8 -4
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
- ins_pricing-0.5.1.dist-info/RECORD +132 -0
- ins_pricing/modelling/core/BayesOpt.py +0 -146
- ins_pricing/modelling/core/__init__.py +0 -1
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
- ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
- ins_pricing/modelling/core/bayesopt/utils.py +0 -105
- ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
- ins_pricing/tests/production/test_predict.py +0 -233
- ins_pricing-0.4.5.dist-info/RECORD +0 -130
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
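The rename entries above flatten `modelling/core/bayesopt` into `modelling/bayesopt`, move `losses.py` up to `ins_pricing/utils/`, and rename `production/predict.py` to `production/inference.py`. A minimal sketch of the import updates this implies for downstream code; whether each public name is re-exported unchanged at its new path is an assumption, since only the file moves are visible here:

```python
# 0.4.5 layout (old paths, per the rename list above):
#   from ins_pricing.modelling.core.bayesopt import core
#   from ins_pricing.production import predict
# 0.5.1 layout (new paths) -- assumes the same module names survive the move:
from ins_pricing.modelling.bayesopt import core
from ins_pricing.production import inference
```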
|
@@ -1,1503 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import csv
|
|
4
|
-
import ctypes
|
|
5
|
-
import copy
|
|
6
|
-
import gc
|
|
7
|
-
import json
|
|
8
|
-
import math
|
|
9
|
-
import os
|
|
10
|
-
import random
|
|
11
|
-
import time
|
|
12
|
-
from contextlib import nullcontext
|
|
13
|
-
from datetime import timedelta
|
|
14
|
-
from pathlib import Path
|
|
15
|
-
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
|
16
|
-
|
|
17
|
-
try: # matplotlib is optional; avoid hard import failures in headless/minimal envs
|
|
18
|
-
import matplotlib
|
|
19
|
-
if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
|
|
20
|
-
matplotlib.use("Agg")
|
|
21
|
-
import matplotlib.pyplot as plt
|
|
22
|
-
_MPL_IMPORT_ERROR: Optional[BaseException] = None
|
|
23
|
-
except Exception as exc: # pragma: no cover - optional dependency
|
|
24
|
-
matplotlib = None # type: ignore[assignment]
|
|
25
|
-
plt = None # type: ignore[assignment]
|
|
26
|
-
_MPL_IMPORT_ERROR = exc
|
|
27
|
-
import numpy as np
|
|
28
|
-
import optuna
|
|
29
|
-
import pandas as pd
|
|
30
|
-
import torch
|
|
31
|
-
import torch.distributed as dist
|
|
32
|
-
import torch.nn as nn
|
|
33
|
-
import torch.nn.functional as F
|
|
34
|
-
from torch.cuda.amp import autocast, GradScaler
|
|
35
|
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
36
|
-
from torch.nn.utils import clip_grad_norm_
|
|
37
|
-
from torch.utils.data import DataLoader, DistributedSampler
|
|
38
|
-
|
|
39
|
-
# Optional: unify plotting with shared plotting package
|
|
40
|
-
try:
|
|
41
|
-
from ...plotting import curves as plot_curves_common
|
|
42
|
-
from ...plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
|
|
43
|
-
except Exception: # pragma: no cover
|
|
44
|
-
try:
|
|
45
|
-
from ins_pricing.plotting import curves as plot_curves_common
|
|
46
|
-
from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
|
|
47
|
-
except Exception: # pragma: no cover
|
|
48
|
-
plot_curves_common = None
|
|
49
|
-
plot_loss_curve_common = None
|
|
50
|
-
# Limit CUDA allocator split size to reduce fragmentation and OOM risk.
|
|
51
|
-
# Override via PYTORCH_CUDA_ALLOC_CONF if needed.
|
|
52
|
-
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:256")
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
# Constants and utility helpers
|
|
56
|
-
# =============================================================================
|
|
57
|
-
torch.backends.cudnn.benchmark = True
|
|
58
|
-
EPS = 1e-8
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
def _plot_skip(label: str) -> None:
|
|
62
|
-
if _MPL_IMPORT_ERROR is not None:
|
|
63
|
-
print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
|
|
64
|
-
else:
|
|
65
|
-
print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
def set_global_seed(seed: int) -> None:
|
|
69
|
-
random.seed(seed)
|
|
70
|
-
np.random.seed(seed)
|
|
71
|
-
torch.manual_seed(seed)
|
|
72
|
-
if torch.cuda.is_available():
|
|
73
|
-
torch.cuda.manual_seed_all(seed)
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
class IOUtils:
|
|
77
|
-
# File and path utilities.
|
|
78
|
-
|
|
79
|
-
@staticmethod
|
|
80
|
-
def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
|
|
81
|
-
with open(file_path, mode='r', encoding='utf-8') as file:
|
|
82
|
-
reader = csv.DictReader(file)
|
|
83
|
-
return [
|
|
84
|
-
dict(filter(lambda item: item[0] != '', row.items()))
|
|
85
|
-
for row in reader
|
|
86
|
-
]
|
|
87
|
-
|
|
88
|
-
@staticmethod
|
|
89
|
-
def _sanitize_params_dict(params: Dict[str, Any]) -> Dict[str, Any]:
|
|
90
|
-
# Filter index-like columns such as "Unnamed: 0" from pandas I/O.
|
|
91
|
-
return {
|
|
92
|
-
k: v
|
|
93
|
-
for k, v in (params or {}).items()
|
|
94
|
-
if k and not str(k).startswith("Unnamed")
|
|
95
|
-
}
|
|
96
|
-
|
|
97
|
-
@staticmethod
|
|
98
|
-
def load_params_file(path: str) -> Dict[str, Any]:
|
|
99
|
-
"""Load parameter dict from JSON/CSV/TSV files.
|
|
100
|
-
|
|
101
|
-
- JSON: accept dict or {"best_params": {...}} wrapper
|
|
102
|
-
- CSV/TSV: read the first row as params
|
|
103
|
-
"""
|
|
104
|
-
file_path = Path(path).expanduser().resolve()
|
|
105
|
-
if not file_path.exists():
|
|
106
|
-
raise FileNotFoundError(f"params file not found: {file_path}")
|
|
107
|
-
suffix = file_path.suffix.lower()
|
|
108
|
-
if suffix == ".json":
|
|
109
|
-
payload = json.loads(file_path.read_text(
|
|
110
|
-
encoding="utf-8", errors="replace"))
|
|
111
|
-
if isinstance(payload, dict) and "best_params" in payload:
|
|
112
|
-
payload = payload.get("best_params") or {}
|
|
113
|
-
if not isinstance(payload, dict):
|
|
114
|
-
raise ValueError(
|
|
115
|
-
f"Invalid JSON params file (expect dict): {file_path}")
|
|
116
|
-
return IOUtils._sanitize_params_dict(dict(payload))
|
|
117
|
-
if suffix in (".csv", ".tsv"):
|
|
118
|
-
df = pd.read_csv(file_path, sep="\t" if suffix == ".tsv" else ",")
|
|
119
|
-
if df.empty:
|
|
120
|
-
raise ValueError(f"Empty params file: {file_path}")
|
|
121
|
-
params = df.iloc[0].to_dict()
|
|
122
|
-
return IOUtils._sanitize_params_dict(params)
|
|
123
|
-
raise ValueError(
|
|
124
|
-
f"Unsupported params file type '{suffix}': {file_path}")
|
|
125
|
-
|
|
126
|
-
@staticmethod
|
|
127
|
-
def ensure_parent_dir(file_path: str) -> None:
|
|
128
|
-
# Create parent directories when missing.
|
|
129
|
-
directory = os.path.dirname(file_path)
|
|
130
|
-
if directory:
|
|
131
|
-
os.makedirs(directory, exist_ok=True)
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
class TrainingUtils:
|
|
135
|
-
# Small helpers used during training.
|
|
136
|
-
|
|
137
|
-
@staticmethod
|
|
138
|
-
def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
|
|
139
|
-
estimated = int((learning_rate / 1e-4) ** 0.5 *
|
|
140
|
-
(data_size / max(batch_num, 1)))
|
|
141
|
-
return max(1, min(data_size, max(minimum, estimated)))
|
|
142
|
-
|
|
143
|
-
@staticmethod
|
|
144
|
-
def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
|
|
145
|
-
# Clamp predictions to positive values for stability.
|
|
146
|
-
pred_clamped = torch.clamp(pred, min=eps)
|
|
147
|
-
if p == 1:
|
|
148
|
-
term1 = target * torch.log(target / pred_clamped + eps) # Poisson
|
|
149
|
-
term2 = -target + pred_clamped
|
|
150
|
-
term3 = 0
|
|
151
|
-
elif p == 0:
|
|
152
|
-
term1 = 0.5 * torch.pow(target - pred_clamped, 2) # Gaussian
|
|
153
|
-
term2 = 0
|
|
154
|
-
term3 = 0
|
|
155
|
-
elif p == 2:
|
|
156
|
-
term1 = torch.log(pred_clamped / target + eps) # Gamma
|
|
157
|
-
term2 = -target / pred_clamped + 1
|
|
158
|
-
term3 = 0
|
|
159
|
-
else:
|
|
160
|
-
term1 = torch.pow(target, 2 - p) / ((1 - p) * (2 - p))
|
|
161
|
-
term2 = target * torch.pow(pred_clamped, 1 - p) / (1 - p)
|
|
162
|
-
term3 = torch.pow(pred_clamped, 2 - p) / (2 - p)
|
|
163
|
-
return torch.nan_to_num( # Tweedie negative log-likelihood (constant omitted)
|
|
164
|
-
2 * (term1 - term2 + term3),
|
|
165
|
-
nan=eps,
|
|
166
|
-
posinf=max_clip,
|
|
167
|
-
neginf=-max_clip
|
|
168
|
-
)
|
|
169
|
-
|
|
170
|
-
@staticmethod
|
|
171
|
-
def free_cuda() -> None:
|
|
172
|
-
print(">>> Moving all models to CPU...")
|
|
173
|
-
for obj in gc.get_objects():
|
|
174
|
-
try:
|
|
175
|
-
if hasattr(obj, "to") and callable(obj.to):
|
|
176
|
-
obj.to("cpu")
|
|
177
|
-
except Exception:
|
|
178
|
-
pass
|
|
179
|
-
|
|
180
|
-
print(">>> Releasing tensor/optimizer/DataLoader references...")
|
|
181
|
-
gc.collect()
|
|
182
|
-
|
|
183
|
-
print(">>> Clearing CUDA cache...")
|
|
184
|
-
if torch.cuda.is_available():
|
|
185
|
-
torch.cuda.empty_cache()
|
|
186
|
-
torch.cuda.synchronize()
|
|
187
|
-
print(">>> CUDA memory released.")
|
|
188
|
-
else:
|
|
189
|
-
print(">>> CUDA not available; cleanup skipped.")
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
class DistributedUtils:
|
|
193
|
-
_cached_state: Optional[tuple] = None
|
|
194
|
-
|
|
195
|
-
@staticmethod
|
|
196
|
-
def setup_ddp():
|
|
197
|
-
"""Initialize the DDP process group for distributed training."""
|
|
198
|
-
if dist.is_initialized():
|
|
199
|
-
if DistributedUtils._cached_state is None:
|
|
200
|
-
rank = dist.get_rank()
|
|
201
|
-
world_size = dist.get_world_size()
|
|
202
|
-
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
|
203
|
-
DistributedUtils._cached_state = (
|
|
204
|
-
True,
|
|
205
|
-
local_rank,
|
|
206
|
-
rank,
|
|
207
|
-
world_size,
|
|
208
|
-
)
|
|
209
|
-
return DistributedUtils._cached_state
|
|
210
|
-
|
|
211
|
-
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
|
212
|
-
rank = int(os.environ["RANK"])
|
|
213
|
-
world_size = int(os.environ["WORLD_SIZE"])
|
|
214
|
-
local_rank = int(os.environ["LOCAL_RANK"])
|
|
215
|
-
|
|
216
|
-
if os.name == "nt" and torch.cuda.is_available() and world_size > 1:
|
|
217
|
-
print(
|
|
218
|
-
">>> DDP Setup Disabled: Windows CUDA DDP is not supported. "
|
|
219
|
-
"Falling back to single process."
|
|
220
|
-
)
|
|
221
|
-
return False, 0, 0, 1
|
|
222
|
-
|
|
223
|
-
if torch.cuda.is_available():
|
|
224
|
-
torch.cuda.set_device(local_rank)
|
|
225
|
-
|
|
226
|
-
timeout_seconds = int(os.environ.get(
|
|
227
|
-
"BAYESOPT_DDP_TIMEOUT_SECONDS", "1800"))
|
|
228
|
-
timeout = timedelta(seconds=max(1, timeout_seconds))
|
|
229
|
-
backend = "gloo"
|
|
230
|
-
if torch.cuda.is_available() and os.name != "nt":
|
|
231
|
-
try:
|
|
232
|
-
if getattr(dist, "is_nccl_available", lambda: False)():
|
|
233
|
-
backend = "nccl"
|
|
234
|
-
except Exception:
|
|
235
|
-
backend = "gloo"
|
|
236
|
-
|
|
237
|
-
dist.init_process_group(
|
|
238
|
-
backend=backend, init_method="env://", timeout=timeout)
|
|
239
|
-
print(
|
|
240
|
-
f">>> DDP Initialized ({backend}, timeout={timeout_seconds}s): "
|
|
241
|
-
f"Rank {rank}/{world_size}, Local Rank {local_rank}"
|
|
242
|
-
)
|
|
243
|
-
DistributedUtils._cached_state = (
|
|
244
|
-
True,
|
|
245
|
-
local_rank,
|
|
246
|
-
rank,
|
|
247
|
-
world_size,
|
|
248
|
-
)
|
|
249
|
-
return DistributedUtils._cached_state
|
|
250
|
-
else:
|
|
251
|
-
print(
|
|
252
|
-
f">>> DDP Setup Failed: RANK or WORLD_SIZE not found in env. Keys found: {list(os.environ.keys())}"
|
|
253
|
-
)
|
|
254
|
-
print(
|
|
255
|
-
">>> Hint: launch with torchrun --nproc_per_node=<N> <script.py>"
|
|
256
|
-
)
|
|
257
|
-
return False, 0, 0, 1
|
|
258
|
-
|
|
259
|
-
@staticmethod
|
|
260
|
-
def cleanup_ddp():
|
|
261
|
-
"""Destroy the DDP process group and clear cached state."""
|
|
262
|
-
if dist.is_initialized():
|
|
263
|
-
dist.destroy_process_group()
|
|
264
|
-
DistributedUtils._cached_state = None
|
|
265
|
-
|
|
266
|
-
@staticmethod
|
|
267
|
-
def is_main_process():
|
|
268
|
-
return not dist.is_initialized() or dist.get_rank() == 0
|
|
269
|
-
|
|
270
|
-
@staticmethod
|
|
271
|
-
def world_size() -> int:
|
|
272
|
-
return dist.get_world_size() if dist.is_initialized() else 1
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
class PlotUtils:
|
|
276
|
-
# Plot helpers shared across models.
|
|
277
|
-
|
|
278
|
-
@staticmethod
|
|
279
|
-
def split_data(data: pd.DataFrame, col_nme: str, wgt_nme: str, n_bins: int = 10) -> pd.DataFrame:
|
|
280
|
-
data_sorted = data.sort_values(by=col_nme, ascending=True).copy()
|
|
281
|
-
data_sorted['cum_weight'] = data_sorted[wgt_nme].cumsum()
|
|
282
|
-
w_sum = data_sorted[wgt_nme].sum()
|
|
283
|
-
if w_sum <= EPS:
|
|
284
|
-
data_sorted['bins'] = 0
|
|
285
|
-
else:
|
|
286
|
-
data_sorted['bins'] = np.floor(
|
|
287
|
-
data_sorted['cum_weight'] * float(n_bins) / w_sum
|
|
288
|
-
)
|
|
289
|
-
data_sorted.loc[(data_sorted['bins'] == n_bins),
|
|
290
|
-
'bins'] = n_bins - 1
|
|
291
|
-
return data_sorted.groupby(['bins'], observed=True).sum(numeric_only=True)
|
|
292
|
-
|
|
293
|
-
@staticmethod
|
|
294
|
-
def plot_lift_ax(ax, plot_data, title, pred_label='Predicted', act_label='Actual', weight_label='Earned Exposure'):
|
|
295
|
-
ax.plot(plot_data.index, plot_data['act_v'],
|
|
296
|
-
label=act_label, color='red')
|
|
297
|
-
ax.plot(plot_data.index, plot_data['exp_v'],
|
|
298
|
-
label=pred_label, color='blue')
|
|
299
|
-
ax.set_title(title, fontsize=8)
|
|
300
|
-
ax.set_xticks(plot_data.index)
|
|
301
|
-
ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
|
|
302
|
-
ax.tick_params(axis='y', labelsize=6)
|
|
303
|
-
ax.legend(loc='upper left', fontsize=5, frameon=False)
|
|
304
|
-
ax.margins(0.05)
|
|
305
|
-
ax2 = ax.twinx()
|
|
306
|
-
ax2.bar(plot_data.index, plot_data['weight'],
|
|
307
|
-
alpha=0.5, color='seagreen',
|
|
308
|
-
label=weight_label)
|
|
309
|
-
ax2.tick_params(axis='y', labelsize=6)
|
|
310
|
-
ax2.legend(loc='upper right', fontsize=5, frameon=False)
|
|
311
|
-
|
|
312
|
-
@staticmethod
|
|
313
|
-
def plot_dlift_ax(ax, plot_data, title, label1, label2, act_label='Actual', weight_label='Earned Exposure'):
|
|
314
|
-
ax.plot(plot_data.index, plot_data['act_v'],
|
|
315
|
-
label=act_label, color='red')
|
|
316
|
-
ax.plot(plot_data.index, plot_data['exp_v1'],
|
|
317
|
-
label=label1, color='blue')
|
|
318
|
-
ax.plot(plot_data.index, plot_data['exp_v2'],
|
|
319
|
-
label=label2, color='black')
|
|
320
|
-
ax.set_title(title, fontsize=8)
|
|
321
|
-
ax.set_xticks(plot_data.index)
|
|
322
|
-
ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
|
|
323
|
-
ax.set_xlabel(f'{label1} / {label2}', fontsize=6)
|
|
324
|
-
ax.tick_params(axis='y', labelsize=6)
|
|
325
|
-
ax.legend(loc='upper left', fontsize=5, frameon=False)
|
|
326
|
-
ax.margins(0.1)
|
|
327
|
-
ax2 = ax.twinx()
|
|
328
|
-
ax2.bar(plot_data.index, plot_data['weight'],
|
|
329
|
-
alpha=0.5, color='seagreen',
|
|
330
|
-
label=weight_label)
|
|
331
|
-
ax2.tick_params(axis='y', labelsize=6)
|
|
332
|
-
ax2.legend(loc='upper right', fontsize=5, frameon=False)
|
|
333
|
-
|
|
334
|
-
@staticmethod
|
|
335
|
-
def plot_lift_list(pred_model, w_pred_list, w_act_list,
|
|
336
|
-
weight_list, tgt_nme, n_bins: int = 10,
|
|
337
|
-
fig_nme: str = 'Lift Chart'):
|
|
338
|
-
if plot_curves_common is not None:
|
|
339
|
-
save_path = os.path.join(
|
|
340
|
-
os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
|
|
341
|
-
plot_curves_common.plot_lift_curve(
|
|
342
|
-
pred_model,
|
|
343
|
-
w_act_list,
|
|
344
|
-
weight_list,
|
|
345
|
-
n_bins=n_bins,
|
|
346
|
-
title=f'Lift Chart of {tgt_nme}',
|
|
347
|
-
pred_label='Predicted',
|
|
348
|
-
act_label='Actual',
|
|
349
|
-
weight_label='Earned Exposure',
|
|
350
|
-
pred_weighted=False,
|
|
351
|
-
actual_weighted=True,
|
|
352
|
-
save_path=save_path,
|
|
353
|
-
show=False,
|
|
354
|
-
)
|
|
355
|
-
return
|
|
356
|
-
if plt is None:
|
|
357
|
-
_plot_skip("lift plot")
|
|
358
|
-
return
|
|
359
|
-
lift_data = pd.DataFrame({
|
|
360
|
-
'pred': pred_model,
|
|
361
|
-
'w_pred': w_pred_list,
|
|
362
|
-
'act': w_act_list,
|
|
363
|
-
'weight': weight_list
|
|
364
|
-
})
|
|
365
|
-
plot_data = PlotUtils.split_data(lift_data, 'pred', 'weight', n_bins)
|
|
366
|
-
plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
|
|
367
|
-
plot_data['act_v'] = plot_data['act'] / plot_data['weight']
|
|
368
|
-
plot_data.reset_index(inplace=True)
|
|
369
|
-
|
|
370
|
-
fig = plt.figure(figsize=(7, 5))
|
|
371
|
-
ax = fig.add_subplot(111)
|
|
372
|
-
PlotUtils.plot_lift_ax(ax, plot_data, f'Lift Chart of {tgt_nme}')
|
|
373
|
-
plt.subplots_adjust(wspace=0.3)
|
|
374
|
-
|
|
375
|
-
save_path = os.path.join(
|
|
376
|
-
os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
|
|
377
|
-
IOUtils.ensure_parent_dir(save_path)
|
|
378
|
-
plt.savefig(save_path, dpi=300)
|
|
379
|
-
plt.close(fig)
|
|
380
|
-
|
|
381
|
-
@staticmethod
|
|
382
|
-
def plot_dlift_list(pred_model_1, pred_model_2,
|
|
383
|
-
model_nme_1, model_nme_2,
|
|
384
|
-
tgt_nme,
|
|
385
|
-
w_list, w_act_list, n_bins: int = 10,
|
|
386
|
-
fig_nme: str = 'Double Lift Chart'):
|
|
387
|
-
if plot_curves_common is not None:
|
|
388
|
-
save_path = os.path.join(
|
|
389
|
-
os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
|
|
390
|
-
plot_curves_common.plot_double_lift_curve(
|
|
391
|
-
pred_model_1,
|
|
392
|
-
pred_model_2,
|
|
393
|
-
w_act_list,
|
|
394
|
-
w_list,
|
|
395
|
-
n_bins=n_bins,
|
|
396
|
-
title=f'Double Lift Chart of {tgt_nme}',
|
|
397
|
-
label1=model_nme_1,
|
|
398
|
-
label2=model_nme_2,
|
|
399
|
-
pred1_weighted=False,
|
|
400
|
-
pred2_weighted=False,
|
|
401
|
-
actual_weighted=True,
|
|
402
|
-
save_path=save_path,
|
|
403
|
-
show=False,
|
|
404
|
-
)
|
|
405
|
-
return
|
|
406
|
-
if plt is None:
|
|
407
|
-
_plot_skip("double lift plot")
|
|
408
|
-
return
|
|
409
|
-
lift_data = pd.DataFrame({
|
|
410
|
-
'pred1': pred_model_1,
|
|
411
|
-
'pred2': pred_model_2,
|
|
412
|
-
'act': w_act_list,
|
|
413
|
-
'weight': w_list
|
|
414
|
-
})
|
|
415
|
-
lift_data['diff_ly'] = lift_data['pred1'] / lift_data['pred2']
|
|
416
|
-
lift_data['w_pred1'] = lift_data['pred1'] * lift_data['weight']
|
|
417
|
-
lift_data['w_pred2'] = lift_data['pred2'] * lift_data['weight']
|
|
418
|
-
plot_data = PlotUtils.split_data(
|
|
419
|
-
lift_data, 'diff_ly', 'weight', n_bins)
|
|
420
|
-
plot_data['exp_v1'] = plot_data['w_pred1'] / plot_data['act']
|
|
421
|
-
plot_data['exp_v2'] = plot_data['w_pred2'] / plot_data['act']
|
|
422
|
-
plot_data['act_v'] = plot_data['act']/plot_data['act']
|
|
423
|
-
plot_data.reset_index(inplace=True)
|
|
424
|
-
|
|
425
|
-
fig = plt.figure(figsize=(7, 5))
|
|
426
|
-
ax = fig.add_subplot(111)
|
|
427
|
-
PlotUtils.plot_dlift_ax(
|
|
428
|
-
ax, plot_data, f'Double Lift Chart of {tgt_nme}', model_nme_1, model_nme_2)
|
|
429
|
-
plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)
|
|
430
|
-
|
|
431
|
-
save_path = os.path.join(
|
|
432
|
-
os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
|
|
433
|
-
IOUtils.ensure_parent_dir(save_path)
|
|
434
|
-
plt.savefig(save_path, dpi=300)
|
|
435
|
-
plt.close(fig)
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
def infer_factor_and_cate_list(train_df: pd.DataFrame,
|
|
439
|
-
test_df: pd.DataFrame,
|
|
440
|
-
resp_nme: str,
|
|
441
|
-
weight_nme: str,
|
|
442
|
-
binary_resp_nme: Optional[str] = None,
|
|
443
|
-
factor_nmes: Optional[List[str]] = None,
|
|
444
|
-
cate_list: Optional[List[str]] = None,
|
|
445
|
-
infer_categorical_max_unique: int = 50,
|
|
446
|
-
infer_categorical_max_ratio: float = 0.05) -> Tuple[List[str], List[str]]:
|
|
447
|
-
"""Infer factor_nmes/cate_list when feature names are not provided.
|
|
448
|
-
|
|
449
|
-
Rules:
|
|
450
|
-
- factor_nmes: start from shared train/test columns, exclude target/weight/(optional binary target).
|
|
451
|
-
- cate_list: object/category/bool plus low-cardinality integer columns.
|
|
452
|
-
- Always intersect with shared train/test columns to avoid mismatches.
|
|
453
|
-
"""
|
|
454
|
-
excluded = {resp_nme, weight_nme}
|
|
455
|
-
if binary_resp_nme:
|
|
456
|
-
excluded.add(binary_resp_nme)
|
|
457
|
-
|
|
458
|
-
common_cols = [c for c in train_df.columns if c in test_df.columns]
|
|
459
|
-
if factor_nmes is None:
|
|
460
|
-
factors = [c for c in common_cols if c not in excluded]
|
|
461
|
-
else:
|
|
462
|
-
factors = [
|
|
463
|
-
c for c in factor_nmes if c in common_cols and c not in excluded]
|
|
464
|
-
|
|
465
|
-
if cate_list is not None:
|
|
466
|
-
cats = [c for c in cate_list if c in factors]
|
|
467
|
-
return factors, cats
|
|
468
|
-
|
|
469
|
-
n_rows = max(1, len(train_df))
|
|
470
|
-
cats: List[str] = []
|
|
471
|
-
for col in factors:
|
|
472
|
-
s = train_df[col]
|
|
473
|
-
if pd.api.types.is_bool_dtype(s) or pd.api.types.is_object_dtype(s) or isinstance(s.dtype, pd.CategoricalDtype):
|
|
474
|
-
cats.append(col)
|
|
475
|
-
continue
|
|
476
|
-
if pd.api.types.is_integer_dtype(s):
|
|
477
|
-
nunique = int(s.nunique(dropna=True))
|
|
478
|
-
if nunique <= infer_categorical_max_unique or (nunique / n_rows) <= infer_categorical_max_ratio:
|
|
479
|
-
cats.append(col)
|
|
480
|
-
return factors, cats
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
# Backward-compatible functional wrappers
|
|
484
|
-
def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
|
|
485
|
-
return IOUtils.csv_to_dict(file_path)
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
def ensure_parent_dir(file_path: str) -> None:
|
|
489
|
-
IOUtils.ensure_parent_dir(file_path)
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
|
|
493
|
-
return TrainingUtils.compute_batch_size(data_size, learning_rate, batch_num, minimum)
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
# Tweedie deviance loss for PyTorch.
|
|
497
|
-
# Reference: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-poisson-gamma-and-tweedie-deviances
|
|
498
|
-
def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
|
|
499
|
-
return TrainingUtils.tweedie_loss(pred, target, p=p, eps=eps, max_clip=max_clip)
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
# CUDA memory release helper.
|
|
503
|
-
def free_cuda():
|
|
504
|
-
TrainingUtils.free_cuda()
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
class TorchTrainerMixin:
|
|
508
|
-
# Shared helpers for Torch tabular trainers.
|
|
509
|
-
|
|
510
|
-
def _device_type(self) -> str:
|
|
511
|
-
return getattr(self, "device", torch.device("cpu")).type
|
|
512
|
-
|
|
513
|
-
def _resolve_resource_profile(self) -> str:
|
|
514
|
-
profile = getattr(self, "resource_profile", None)
|
|
515
|
-
if not profile:
|
|
516
|
-
profile = os.environ.get("BAYESOPT_RESOURCE_PROFILE", "auto")
|
|
517
|
-
profile = str(profile).strip().lower()
|
|
518
|
-
if profile in {"cpu", "mps", "cuda"}:
|
|
519
|
-
profile = "auto"
|
|
520
|
-
if profile not in {"auto", "throughput", "memory_saving"}:
|
|
521
|
-
profile = "auto"
|
|
522
|
-
if profile == "auto" and self._device_type() == "cuda":
|
|
523
|
-
profile = "throughput"
|
|
524
|
-
return profile
|
|
525
|
-
|
|
526
|
-
def _log_resource_summary_once(self, profile: str) -> None:
|
|
527
|
-
if getattr(self, "_resource_summary_logged", False):
|
|
528
|
-
return
|
|
529
|
-
if dist.is_initialized() and not DistributedUtils.is_main_process():
|
|
530
|
-
return
|
|
531
|
-
self._resource_summary_logged = True
|
|
532
|
-
device = getattr(self, "device", torch.device("cpu"))
|
|
533
|
-
device_type = self._device_type()
|
|
534
|
-
cpu_count = os.cpu_count() or 1
|
|
535
|
-
cuda_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
|
|
536
|
-
mps_available = bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available())
|
|
537
|
-
ddp_enabled = bool(getattr(self, "is_ddp_enabled", False))
|
|
538
|
-
data_parallel = bool(getattr(self, "use_data_parallel", False))
|
|
539
|
-
print(
|
|
540
|
-
f">>> Resource summary: device={device}, device_type={device_type}, "
|
|
541
|
-
f"cpu_count={cpu_count}, cuda_count={cuda_count}, mps={mps_available}, "
|
|
542
|
-
f"ddp={ddp_enabled}, data_parallel={data_parallel}, profile={profile}"
|
|
543
|
-
)
|
|
544
|
-
|
|
545
|
-
def _available_system_memory(self) -> Optional[int]:
|
|
546
|
-
if os.name == "nt":
|
|
547
|
-
class _MemStatus(ctypes.Structure):
|
|
548
|
-
_fields_ = [
|
|
549
|
-
("dwLength", ctypes.c_ulong),
|
|
550
|
-
("dwMemoryLoad", ctypes.c_ulong),
|
|
551
|
-
("ullTotalPhys", ctypes.c_ulonglong),
|
|
552
|
-
("ullAvailPhys", ctypes.c_ulonglong),
|
|
553
|
-
("ullTotalPageFile", ctypes.c_ulonglong),
|
|
554
|
-
("ullAvailPageFile", ctypes.c_ulonglong),
|
|
555
|
-
("ullTotalVirtual", ctypes.c_ulonglong),
|
|
556
|
-
("ullAvailVirtual", ctypes.c_ulonglong),
|
|
557
|
-
("sullAvailExtendedVirtual", ctypes.c_ulonglong),
|
|
558
|
-
]
|
|
559
|
-
status = _MemStatus()
|
|
560
|
-
status.dwLength = ctypes.sizeof(_MemStatus)
|
|
561
|
-
if ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status)):
|
|
562
|
-
return int(status.ullAvailPhys)
|
|
563
|
-
return None
|
|
564
|
-
try:
|
|
565
|
-
pages = os.sysconf("SC_AVPHYS_PAGES")
|
|
566
|
-
page_size = os.sysconf("SC_PAGE_SIZE")
|
|
567
|
-
return int(pages * page_size)
|
|
568
|
-
except Exception:
|
|
569
|
-
return None
|
|
570
|
-
|
|
571
|
-
def _available_cuda_memory(self) -> Optional[int]:
|
|
572
|
-
if not torch.cuda.is_available():
|
|
573
|
-
return None
|
|
574
|
-
try:
|
|
575
|
-
free_mem, _total_mem = torch.cuda.mem_get_info()
|
|
576
|
-
except Exception:
|
|
577
|
-
return None
|
|
578
|
-
return int(free_mem)
|
|
579
|
-
|
|
580
|
-
def _estimate_sample_bytes(self, dataset) -> Optional[int]:
|
|
581
|
-
try:
|
|
582
|
-
if len(dataset) == 0:
|
|
583
|
-
return None
|
|
584
|
-
sample = dataset[0]
|
|
585
|
-
except Exception:
|
|
586
|
-
return None
|
|
587
|
-
|
|
588
|
-
def _bytes(obj) -> int:
|
|
589
|
-
if obj is None:
|
|
590
|
-
return 0
|
|
591
|
-
if torch.is_tensor(obj):
|
|
592
|
-
return int(obj.element_size() * obj.nelement())
|
|
593
|
-
if isinstance(obj, np.ndarray):
|
|
594
|
-
return int(obj.nbytes)
|
|
595
|
-
if isinstance(obj, (list, tuple)):
|
|
596
|
-
return int(sum(_bytes(item) for item in obj))
|
|
597
|
-
if isinstance(obj, dict):
|
|
598
|
-
return int(sum(_bytes(item) for item in obj.values()))
|
|
599
|
-
return 0
|
|
600
|
-
|
|
601
|
-
sample_bytes = _bytes(sample)
|
|
602
|
-
return int(sample_bytes) if sample_bytes > 0 else None
|
|
603
|
-
|
|
604
|
-
def _cap_batch_size_by_memory(self, dataset, batch_size: int, profile: str) -> int:
|
|
605
|
-
if batch_size <= 1:
|
|
606
|
-
return batch_size
|
|
607
|
-
sample_bytes = self._estimate_sample_bytes(dataset)
|
|
608
|
-
if sample_bytes is None:
|
|
609
|
-
return batch_size
|
|
610
|
-
device_type = self._device_type()
|
|
611
|
-
if device_type == "cuda":
|
|
612
|
-
available = self._available_cuda_memory()
|
|
613
|
-
if available is None:
|
|
614
|
-
return batch_size
|
|
615
|
-
if profile == "throughput":
|
|
616
|
-
budget_ratio = 0.8
|
|
617
|
-
overhead = 8.0
|
|
618
|
-
elif profile == "memory_saving":
|
|
619
|
-
budget_ratio = 0.5
|
|
620
|
-
overhead = 14.0
|
|
621
|
-
else:
|
|
622
|
-
budget_ratio = 0.6
|
|
623
|
-
overhead = 12.0
|
|
624
|
-
else:
|
|
625
|
-
available = self._available_system_memory()
|
|
626
|
-
if available is None:
|
|
627
|
-
return batch_size
|
|
628
|
-
if profile == "throughput":
|
|
629
|
-
budget_ratio = 0.4
|
|
630
|
-
overhead = 1.8
|
|
631
|
-
elif profile == "memory_saving":
|
|
632
|
-
budget_ratio = 0.25
|
|
633
|
-
overhead = 3.0
|
|
634
|
-
else:
|
|
635
|
-
budget_ratio = 0.3
|
|
636
|
-
overhead = 2.6
|
|
637
|
-
budget = int(available * budget_ratio)
|
|
638
|
-
per_sample = int(sample_bytes * overhead)
|
|
639
|
-
if per_sample <= 0:
|
|
640
|
-
return batch_size
|
|
641
|
-
max_batch = max(1, int(budget // per_sample))
|
|
642
|
-
if max_batch < batch_size:
|
|
643
|
-
print(
|
|
644
|
-
f">>> Memory cap: batch_size {batch_size} -> {max_batch} "
|
|
645
|
-
f"(per_sample~{sample_bytes}B, budget~{budget // (1024**2)}MB)"
|
|
646
|
-
)
|
|
647
|
-
return min(batch_size, max_batch)
|
|
648
|
-
|
|
649
|
-
def _resolve_num_workers(self, max_workers: int, profile: Optional[str] = None) -> int:
|
|
650
|
-
if os.name == 'nt':
|
|
651
|
-
return 0
|
|
652
|
-
if getattr(self, "is_ddp_enabled", False):
|
|
653
|
-
return 0
|
|
654
|
-
profile = profile or self._resolve_resource_profile()
|
|
655
|
-
if profile == "memory_saving":
|
|
656
|
-
return 0
|
|
657
|
-
worker_cap = min(int(max_workers), os.cpu_count() or 1)
|
|
658
|
-
if self._device_type() == "mps":
|
|
659
|
-
worker_cap = min(worker_cap, 2)
|
|
660
|
-
return worker_cap
|
|
661
|
-
|
|
662
|
-
def _build_dataloader(self,
|
|
663
|
-
dataset,
|
|
664
|
-
N: int,
|
|
665
|
-
base_bs_gpu: tuple,
|
|
666
|
-
base_bs_cpu: tuple,
|
|
667
|
-
min_bs: int = 64,
|
|
668
|
-
target_effective_cuda: int = 1024,
|
|
669
|
-
target_effective_cpu: int = 512,
|
|
670
|
-
large_threshold: int = 200_000,
|
|
671
|
-
mid_threshold: int = 50_000):
|
|
672
|
-
profile = self._resolve_resource_profile()
|
|
673
|
-
self._log_resource_summary_once(profile)
|
|
674
|
-
batch_size = TrainingUtils.compute_batch_size(
|
|
675
|
-
data_size=len(dataset),
|
|
676
|
-
learning_rate=self.learning_rate,
|
|
677
|
-
batch_num=self.batch_num,
|
|
678
|
-
minimum=min_bs
|
|
679
|
-
)
|
|
680
|
-
gpu_large, gpu_mid, gpu_small = base_bs_gpu
|
|
681
|
-
cpu_mid, cpu_small = base_bs_cpu
|
|
682
|
-
|
|
683
|
-
if self._device_type() == 'cuda':
|
|
684
|
-
device_count = torch.cuda.device_count()
|
|
685
|
-
if getattr(self, "is_ddp_enabled", False):
|
|
686
|
-
device_count = 1
|
|
687
|
-
# In multi-GPU, increase min batch size so each GPU gets enough data.
|
|
688
|
-
if device_count > 1:
|
|
689
|
-
min_bs = min_bs * device_count
|
|
690
|
-
print(
|
|
691
|
-
f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")
|
|
692
|
-
|
|
693
|
-
if N > large_threshold:
|
|
694
|
-
base_bs = gpu_large * device_count
|
|
695
|
-
elif N > mid_threshold:
|
|
696
|
-
base_bs = gpu_mid * device_count
|
|
697
|
-
else:
|
|
698
|
-
base_bs = gpu_small * device_count
|
|
699
|
-
else:
|
|
700
|
-
base_bs = cpu_mid if N > mid_threshold else cpu_small
|
|
701
|
-
|
|
702
|
-
# Recompute batch_size to respect the adjusted min_bs.
|
|
703
|
-
batch_size = TrainingUtils.compute_batch_size(
|
|
704
|
-
data_size=len(dataset),
|
|
705
|
-
learning_rate=self.learning_rate,
|
|
706
|
-
batch_num=self.batch_num,
|
|
707
|
-
minimum=min_bs
|
|
708
|
-
)
|
|
709
|
-
batch_size = min(batch_size, base_bs, N)
|
|
710
|
-
batch_size = self._cap_batch_size_by_memory(
|
|
711
|
-
dataset, batch_size, profile)
|
|
712
|
-
|
|
713
|
-
target_effective_bs = target_effective_cuda if self._device_type(
|
|
714
|
-
) == 'cuda' else target_effective_cpu
|
|
715
|
-
if getattr(self, "is_ddp_enabled", False):
|
|
716
|
-
world_size = max(1, DistributedUtils.world_size())
|
|
717
|
-
target_effective_bs = max(1, target_effective_bs // world_size)
|
|
718
|
-
|
|
719
|
-
world_size = getattr(self, "world_size", 1) if getattr(
|
|
720
|
-
self, "is_ddp_enabled", False) else 1
|
|
721
|
-
samples_per_rank = math.ceil(
|
|
722
|
-
N / max(1, world_size)) if world_size > 1 else N
|
|
723
|
-
steps_per_epoch = max(
|
|
724
|
-
1, math.ceil(samples_per_rank / max(1, batch_size)))
|
|
725
|
-
# Limit gradient accumulation to avoid scaling beyond actual batches.
|
|
726
|
-
desired_accum = max(1, target_effective_bs // max(1, batch_size))
|
|
727
|
-
accum_steps = max(1, min(desired_accum, steps_per_epoch))
|
|
728
|
-
|
|
729
|
-
# Linux (posix) uses fork; Windows (nt) uses spawn with higher overhead.
|
|
730
|
-
workers = self._resolve_num_workers(8, profile=profile)
|
|
731
|
-
prefetch_factor = None
|
|
732
|
-
if workers > 0:
|
|
733
|
-
prefetch_factor = 4 if profile == "throughput" else 2
|
|
734
|
-
persistent = workers > 0 and profile != "memory_saving"
|
|
735
|
-
print(
|
|
736
|
-
f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, "
|
|
737
|
-
f"Workers={workers}, Prefetch={prefetch_factor or 'off'}, Profile={profile}")
|
|
738
|
-
sampler = None
|
|
739
|
-
if dist.is_initialized():
|
|
740
|
-
sampler = DistributedSampler(dataset, shuffle=True)
|
|
741
|
-
shuffle = False # DistributedSampler handles shuffling.
|
|
742
|
-
else:
|
|
743
|
-
shuffle = True
|
|
744
|
-
|
|
745
|
-
dataloader = DataLoader(
|
|
746
|
-
dataset,
|
|
747
|
-
batch_size=batch_size,
|
|
748
|
-
shuffle=shuffle,
|
|
749
|
-
sampler=sampler,
|
|
750
|
-
num_workers=workers,
|
|
751
|
-
pin_memory=(self._device_type() == 'cuda'),
|
|
752
|
-
persistent_workers=persistent,
|
|
753
|
-
**({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
|
|
754
|
-
)
|
|
755
|
-
return dataloader, accum_steps
|
|
756
|
-
|
|
757
|
-
def _build_val_dataloader(self, dataset, train_dataloader, accum_steps):
|
|
758
|
-
profile = self._resolve_resource_profile()
|
|
759
|
-
val_bs = accum_steps * train_dataloader.batch_size
|
|
760
|
-
val_workers = self._resolve_num_workers(4, profile=profile)
|
|
761
|
-
prefetch_factor = None
|
|
762
|
-
if val_workers > 0:
|
|
763
|
-
prefetch_factor = 2
|
|
764
|
-
return DataLoader(
|
|
765
|
-
dataset,
|
|
766
|
-
batch_size=val_bs,
|
|
767
|
-
shuffle=False,
|
|
768
|
-
num_workers=val_workers,
|
|
769
|
-
pin_memory=(self._device_type() == 'cuda'),
|
|
770
|
-
persistent_workers=(val_workers > 0 and profile != "memory_saving"),
|
|
771
|
-
**({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
|
|
772
|
-
)
|
|
773
|
-
|
|
774
|
-
def _compute_losses(self, y_pred, y_true, apply_softplus: bool = False):
|
|
775
|
-
task = getattr(self, "task_type", "regression")
|
|
776
|
-
if task == 'classification':
|
|
777
|
-
loss_fn = nn.BCEWithLogitsLoss(reduction='none')
|
|
778
|
-
return loss_fn(y_pred, y_true).view(-1)
|
|
779
|
-
if apply_softplus:
|
|
780
|
-
y_pred = F.softplus(y_pred)
|
|
781
|
-
y_pred = torch.clamp(y_pred, min=1e-6)
|
|
782
|
-
power = getattr(self, "tw_power", 1.5)
|
|
783
|
-
return tweedie_loss(y_pred, y_true, p=power).view(-1)
|
|
784
|
-
|
|
785
|
-
def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
|
|
786
|
-
losses = self._compute_losses(
|
|
787
|
-
y_pred, y_true, apply_softplus=apply_softplus)
|
|
788
|
-
weighted_loss = (losses * weights.view(-1)).sum() / \
|
|
789
|
-
torch.clamp(weights.sum(), min=EPS)
|
|
790
|
-
return weighted_loss
|
|
791
|
-
|
|
792
|
-
def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model,
|
|
793
|
-
ignore_keys: Optional[List[str]] = None):
|
|
794
|
-
if val_loss < best_loss:
|
|
795
|
-
ignore_keys = ignore_keys or []
|
|
796
|
-
# Unwrap DDP module to avoid module. prefix in state_dict keys
|
|
797
|
-
base_module = model.module if hasattr(model, "module") else model
|
|
798
|
-
state_dict = {
|
|
799
|
-
k: (v.clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v))
|
|
800
|
-
for k, v in base_module.state_dict().items()
|
|
801
|
-
if not any(k.startswith(ignore_key) for ignore_key in ignore_keys)
|
|
802
|
-
}
|
|
803
|
-
return val_loss, state_dict, 0, False
|
|
804
|
-
patience_counter += 1
|
|
805
|
-
should_stop = best_state is not None and patience_counter >= getattr(
|
|
806
|
-
self, "patience", 0)
|
|
807
|
-
return best_loss, best_state, patience_counter, should_stop
|
|
808
|
-
|
|
809
|
-
def _train_model(self,
|
|
810
|
-
model,
|
|
811
|
-
dataloader,
|
|
812
|
-
accum_steps,
|
|
813
|
-
optimizer,
|
|
814
|
-
scaler,
|
|
815
|
-
forward_fn,
|
|
816
|
-
val_forward_fn=None,
|
|
817
|
-
apply_softplus: bool = False,
|
|
818
|
-
clip_fn=None,
|
|
819
|
-
trial: Optional[optuna.trial.Trial] = None,
|
|
820
|
-
loss_curve_path: Optional[str] = None):
|
|
821
|
-
device_type = self._device_type()
|
|
822
|
-
best_loss = float('inf')
|
|
823
|
-
best_state = None
|
|
824
|
-
patience_counter = 0
|
|
825
|
-
stop_training = False
|
|
826
|
-
train_history: List[float] = []
|
|
827
|
-
val_history: List[float] = []
|
|
828
|
-
|
|
829
|
-
is_ddp_model = isinstance(model, DDP)
|
|
830
|
-
|
|
831
|
-
for epoch in range(1, getattr(self, "epochs", 1) + 1):
|
|
832
|
-
epoch_start_ts = time.time()
|
|
833
|
-
val_weighted_loss = None
|
|
834
|
-
if hasattr(self, 'dataloader_sampler') and self.dataloader_sampler is not None:
|
|
835
|
-
self.dataloader_sampler.set_epoch(epoch)
|
|
836
|
-
|
|
837
|
-
model.train()
|
|
838
|
-
optimizer.zero_grad()
|
|
839
|
-
|
|
840
|
-
epoch_loss_sum = None
|
|
841
|
-
epoch_weight_sum = None
|
|
842
|
-
for step, batch in enumerate(dataloader):
|
|
843
|
-
is_update_step = ((step + 1) % accum_steps == 0) or \
|
|
844
|
-
((step + 1) == len(dataloader))
|
|
845
|
-
sync_cm = model.no_sync if (
|
|
846
|
-
is_ddp_model and not is_update_step) else nullcontext
|
|
847
|
-
|
|
848
|
-
with sync_cm():
|
|
849
|
-
with autocast(enabled=(device_type == 'cuda')):
|
|
850
|
-
y_pred, y_true, w = forward_fn(batch)
|
|
851
|
-
weighted_loss = self._compute_weighted_loss(
|
|
852
|
-
y_pred, y_true, w, apply_softplus=apply_softplus)
|
|
853
|
-
loss_for_backward = weighted_loss / accum_steps
|
|
854
|
-
|
|
855
|
-
batch_weight = torch.clamp(
|
|
856
|
-
w.detach().sum(), min=EPS).to(dtype=torch.float32)
|
|
857
|
-
loss_val = weighted_loss.detach().to(dtype=torch.float32)
|
|
858
|
-
if epoch_loss_sum is None:
|
|
859
|
-
epoch_loss_sum = torch.zeros(
|
|
860
|
-
(), device=batch_weight.device, dtype=torch.float32)
|
|
861
|
-
epoch_weight_sum = torch.zeros(
|
|
862
|
-
(), device=batch_weight.device, dtype=torch.float32)
|
|
863
|
-
epoch_loss_sum = epoch_loss_sum + loss_val * batch_weight
|
|
864
|
-
epoch_weight_sum = epoch_weight_sum + batch_weight
|
|
865
|
-
scaler.scale(loss_for_backward).backward()
|
|
866
|
-
|
|
867
|
-
if is_update_step:
|
|
868
|
-
if clip_fn is not None:
|
|
869
|
-
clip_fn()
|
|
870
|
-
scaler.step(optimizer)
|
|
871
|
-
scaler.update()
|
|
872
|
-
optimizer.zero_grad()
|
|
873
|
-
|
|
874
|
-
if epoch_loss_sum is None or epoch_weight_sum is None:
|
|
875
|
-
train_epoch_loss = 0.0
|
|
876
|
-
else:
|
|
877
|
-
train_epoch_loss = (
|
|
878
|
-
epoch_loss_sum / torch.clamp(epoch_weight_sum, min=EPS)
|
|
879
|
-
).item()
|
|
880
|
-
train_history.append(float(train_epoch_loss))
|
|
881
|
-
|
|
882
|
-
if val_forward_fn is not None:
|
|
883
|
-
should_compute_val = (not dist.is_initialized()
|
|
884
|
-
or DistributedUtils.is_main_process())
|
|
885
|
-
val_device = getattr(self, "device", torch.device("cpu"))
|
|
886
|
-
if not isinstance(val_device, torch.device):
|
|
887
|
-
val_device = torch.device(val_device)
|
|
888
|
-
loss_tensor_device = val_device if device_type == 'cuda' else torch.device(
|
|
889
|
-
"cpu")
|
|
890
|
-
val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
|
|
891
|
-
|
|
892
|
-
if should_compute_val:
|
|
893
|
-
model.eval()
|
|
894
|
-
with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
|
|
895
|
-
val_result = val_forward_fn()
|
|
896
|
-
if isinstance(val_result, tuple) and len(val_result) == 3:
|
|
897
|
-
y_val_pred, y_val_true, w_val = val_result
|
|
898
|
-
val_weighted_loss = self._compute_weighted_loss(
|
|
899
|
-
y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
|
|
900
|
-
else:
|
|
901
|
-
val_weighted_loss = val_result
|
|
902
|
-
val_loss_tensor[0] = float(val_weighted_loss)
|
|
903
|
-
|
|
904
|
-
if dist.is_initialized():
|
|
905
|
-
dist.broadcast(val_loss_tensor, src=0)
|
|
906
|
-
val_weighted_loss = float(val_loss_tensor.item())
|
|
907
|
-
|
|
908
|
-
val_history.append(val_weighted_loss)
|
|
909
|
-
|
|
910
|
-
best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
|
|
911
|
-
val_weighted_loss, best_loss, best_state, patience_counter, model)
|
|
912
|
-
|
|
913
|
-
prune_flag = False
|
|
914
|
-
is_main_rank = DistributedUtils.is_main_process()
|
|
915
|
-
# Only main process reports to Optuna to avoid duplicate reports
|
|
916
|
-
if trial is not None and is_main_rank:
|
|
917
|
-
trial.report(val_weighted_loss, epoch)
|
|
918
|
-
prune_flag = trial.should_prune()
|
|
919
|
-
|
|
920
|
-
if dist.is_initialized():
|
|
921
|
-
prune_device = getattr(self, "device", torch.device("cpu"))
|
|
922
|
-
if not isinstance(prune_device, torch.device):
|
|
923
|
-
prune_device = torch.device(prune_device)
|
|
924
|
-
prune_tensor = torch.zeros(1, device=prune_device)
|
|
925
|
-
if is_main_rank:
|
|
926
|
-
prune_tensor.fill_(1 if prune_flag else 0)
|
|
927
|
-
dist.broadcast(prune_tensor, src=0)
|
|
928
|
-
prune_flag = bool(prune_tensor.item())
|
|
929
|
-
|
|
930
|
-
if prune_flag:
|
|
931
|
-
raise optuna.TrialPruned()
|
|
932
|
-
|
|
933
|
-
if stop_training:
|
|
934
|
-
break
|
|
935
|
-
|
|
936
|
-
should_log_epoch = (not dist.is_initialized()
|
|
937
|
-
or DistributedUtils.is_main_process())
|
|
938
|
-
if should_log_epoch:
|
|
939
|
-
elapsed = int(time.time() - epoch_start_ts)
|
|
940
|
-
if val_weighted_loss is None:
|
|
941
|
-
print(
|
|
942
|
-
f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
|
|
943
|
-
f"train_loss={float(train_epoch_loss):.6f} elapsed={elapsed}s",
|
|
944
|
-
flush=True,
|
|
945
|
-
)
|
|
946
|
-
else:
|
|
947
|
-
print(
|
|
948
|
-
f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
|
|
949
|
-
f"train_loss={float(train_epoch_loss):.6f} "
|
|
950
|
-
f"val_loss={float(val_weighted_loss):.6f} elapsed={elapsed}s",
|
|
951
|
-
flush=True,
|
|
952
|
-
)
|
|
953
|
-
|
|
954
|
-
# Periodic memory cleanup to prevent accumulation (every 10 epochs)
|
|
955
|
-
if epoch % 10 == 0:
|
|
956
|
-
if torch.cuda.is_available():
|
|
957
|
-
torch.cuda.empty_cache()
|
|
958
|
-
gc.collect()
|
|
959
|
-
|
|
960
|
-
history = {"train": train_history, "val": val_history}
|
|
961
|
-
self._plot_loss_curve(history, loss_curve_path)
|
|
962
|
-
return best_state, history
|
|
963
|
-
|
|
964
|
-
def _plot_loss_curve(self, history: Dict[str, List[float]], save_path: Optional[str]) -> None:
|
|
965
|
-
if not save_path:
|
|
966
|
-
return
|
|
967
|
-
if dist.is_initialized() and not DistributedUtils.is_main_process():
|
|
968
|
-
return
|
|
969
|
-
train_hist = history.get("train", []) if history else []
|
|
970
|
-
val_hist = history.get("val", []) if history else []
|
|
971
|
-
if not train_hist and not val_hist:
|
|
972
|
-
return
|
|
973
|
-
if plot_loss_curve_common is not None:
|
|
974
|
-
plot_loss_curve_common(
|
|
975
|
-
history=history,
|
|
976
|
-
title="Loss vs. Epoch",
|
|
977
|
-
save_path=save_path,
|
|
978
|
-
show=False,
|
|
979
|
-
)
|
|
980
|
-
else:
|
|
981
|
-
if plt is None:
|
|
982
|
-
_plot_skip("loss curve")
|
|
983
|
-
return
|
|
984
|
-
ensure_parent_dir(save_path)
|
|
985
|
-
epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
|
|
986
|
-
fig = plt.figure(figsize=(8, 4))
|
|
987
|
-
ax = fig.add_subplot(111)
|
|
988
|
-
if train_hist:
|
|
989
|
-
ax.plot(range(1, len(train_hist) + 1), train_hist,
|
|
990
|
-
label='Train Loss', color='tab:blue')
|
|
991
|
-
if val_hist:
|
|
992
|
-
ax.plot(range(1, len(val_hist) + 1), val_hist,
|
|
993
|
-
label='Validation Loss', color='tab:orange')
|
|
994
|
-
ax.set_xlabel('Epoch')
|
|
995
|
-
ax.set_ylabel('Weighted Loss')
|
|
996
|
-
ax.set_title('Loss vs. Epoch')
|
|
997
|
-
ax.grid(True, linestyle='--', alpha=0.3)
|
|
998
|
-
ax.legend()
|
|
999
|
-
plt.tight_layout()
|
|
1000
|
-
plt.savefig(save_path, dpi=300)
|
|
1001
|
-
plt.close(fig)
|
|
1002
|
-
print(f"[Training] Loss curve saved to {save_path}")
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
# =============================================================================
|
|
1006
|
-
# Plotting helpers
|
|
1007
|
-
# =============================================================================
|
|
1008
|
-
|
|
1009
|
-
def split_data(data, col_nme, wgt_nme, n_bins=10):
|
|
1010
|
-
return PlotUtils.split_data(data, col_nme, wgt_nme, n_bins)
|
|
1011
|
-
|
|
1012
|
-
# Lift curve plotting wrapper
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
def plot_lift_list(pred_model, w_pred_list, w_act_list,
|
|
1016
|
-
weight_list, tgt_nme, n_bins=10,
|
|
1017
|
-
fig_nme='Lift Chart'):
|
|
1018
|
-
return PlotUtils.plot_lift_list(pred_model, w_pred_list, w_act_list,
|
|
1019
|
-
weight_list, tgt_nme, n_bins, fig_nme)
|
|
1020
|
-
|
|
1021
|
-
# Double lift curve plotting wrapper
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
def plot_dlift_list(pred_model_1, pred_model_2,
|
|
1025
|
-
model_nme_1, model_nme_2,
|
|
1026
|
-
tgt_nme,
|
|
1027
|
-
w_list, w_act_list, n_bins=10,
|
|
1028
|
-
fig_nme='Double Lift Chart'):
|
|
1029
|
-
return PlotUtils.plot_dlift_list(pred_model_1, pred_model_2,
|
|
1030
|
-
model_nme_1, model_nme_2,
|
|
1031
|
-
tgt_nme, w_list, w_act_list,
|
|
1032
|
-
n_bins, fig_nme)
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
# =============================================================================
|
|
1036
|
-
# Logging System
|
|
1037
|
-
# =============================================================================
|
|
1038
|
-
import logging
|
|
1039
|
-
from functools import lru_cache
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
@lru_cache(maxsize=1)
|
|
1043
|
-
def _get_package_logger() -> logging.Logger:
|
|
1044
|
-
"""Get or create the package-level logger with consistent formatting."""
|
|
1045
|
-
logger = logging.getLogger("ins_pricing")
|
|
1046
|
-
if not logger.handlers:
|
|
1047
|
-
handler = logging.StreamHandler()
|
|
1048
|
-
formatter = logging.Formatter(
|
|
1049
|
-
"[%(levelname)s][%(name)s] %(message)s"
|
|
1050
|
-
)
|
|
1051
|
-
handler.setFormatter(formatter)
|
|
1052
|
-
logger.addHandler(handler)
|
|
1053
|
-
# Default to INFO, can be changed via environment variable
|
|
1054
|
-
level = os.environ.get("INS_PRICING_LOG_LEVEL", "INFO").upper()
|
|
1055
|
-
logger.setLevel(getattr(logging, level, logging.INFO))
|
|
1056
|
-
return logger
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
def get_logger(name: str = "ins_pricing") -> logging.Logger:
|
|
1060
|
-
"""Get a logger with the given name, inheriting package-level settings.
|
|
1061
|
-
|
|
1062
|
-
Args:
|
|
1063
|
-
name: Logger name, typically module name like 'ins_pricing.trainer'
|
|
1064
|
-
|
|
1065
|
-
Returns:
|
|
1066
|
-
Configured logger instance
|
|
1067
|
-
|
|
1068
|
-
Example:
|
|
1069
|
-
>>> logger = get_logger("ins_pricing.trainer.ft")
|
|
1070
|
-
>>> logger.info("Training started")
|
|
1071
|
-
"""
|
|
1072
|
-
# Ensure package logger is initialized
|
|
1073
|
-
_get_package_logger()
|
|
1074
|
-
return logging.getLogger(name)
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
# =============================================================================
|
|
1078
|
-
# Metric Computation Factory
|
|
1079
|
-
# =============================================================================
|
|
1080
|
-
from sklearn.metrics import log_loss, mean_tweedie_deviance
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
class MetricFactory:
|
|
1084
|
-
"""Factory for computing evaluation metrics consistently across all trainers.
|
|
1085
|
-
|
|
1086
|
-
This class centralizes metric computation logic that was previously duplicated
|
|
1087
|
-
across FTTrainer, ResNetTrainer, GNNTrainer, XGBTrainer, and GLMTrainer.
|
|
1088
|
-
|
|
1089
|
-
Example:
|
|
1090
|
-
>>> factory = MetricFactory(task_type='regression', tweedie_power=1.5)
|
|
1091
|
-
>>> score = factory.compute(y_true, y_pred, sample_weight)
|
|
1092
|
-
"""
|
|
1093
|
-
|
|
1094
|
-
def __init__(
|
|
1095
|
-
self,
|
|
1096
|
-
task_type: str = "regression",
|
|
1097
|
-
tweedie_power: float = 1.5,
|
|
1098
|
-
clip_min: float = 1e-8,
|
|
1099
|
-
clip_max: float = 1 - 1e-8,
|
|
1100
|
-
):
|
|
1101
|
-
"""Initialize the metric factory.
|
|
1102
|
-
|
|
1103
|
-
Args:
|
|
1104
|
-
task_type: Either 'regression' or 'classification'
|
|
1105
|
-
tweedie_power: Power parameter for Tweedie deviance (1.0-2.0)
|
|
1106
|
-
clip_min: Minimum value for clipping predictions
|
|
1107
|
-
clip_max: Maximum value for clipping predictions (for classification)
|
|
1108
|
-
"""
|
|
1109
|
-
self.task_type = task_type
|
|
1110
|
-
self.tweedie_power = tweedie_power
|
|
1111
|
-
self.clip_min = clip_min
|
|
1112
|
-
self.clip_max = clip_max
|
|
1113
|
-
|
|
1114
|
-
def compute(
|
|
1115
|
-
self,
|
|
1116
|
-
y_true: np.ndarray,
|
|
1117
|
-
y_pred: np.ndarray,
|
|
1118
|
-
sample_weight: Optional[np.ndarray] = None,
|
|
1119
|
-
) -> float:
|
|
1120
|
-
"""Compute the appropriate metric based on task type.
|
|
1121
|
-
|
|
1122
|
-
Args:
|
|
1123
|
-
y_true: Ground truth values
|
|
1124
|
-
y_pred: Predicted values
|
|
1125
|
-
sample_weight: Optional sample weights
|
|
1126
|
-
|
|
1127
|
-
Returns:
|
|
1128
|
-
Computed metric value (lower is better)
|
|
1129
|
-
"""
|
|
1130
|
-
y_pred = np.asarray(y_pred)
|
|
1131
|
-
y_true = np.asarray(y_true)
|
|
1132
|
-
|
|
1133
|
-
if self.task_type == "classification":
|
|
1134
|
-
y_pred_clipped = np.clip(y_pred, self.clip_min, self.clip_max)
|
|
1135
|
-
return float(log_loss(y_true, y_pred_clipped, sample_weight=sample_weight))
|
|
1136
|
-
|
|
1137
|
-
# Regression: use Tweedie deviance
|
|
1138
|
-
y_pred_safe = np.maximum(y_pred, self.clip_min)
|
|
1139
|
-
return float(mean_tweedie_deviance(
|
|
1140
|
-
y_true,
|
|
1141
|
-
y_pred_safe,
|
|
1142
|
-
sample_weight=sample_weight,
|
|
1143
|
-
power=self.tweedie_power,
|
|
1144
|
-
))
|
|
1145
|
-
|
|
1146
|
-
def update_power(self, power: float) -> None:
|
|
1147
|
-
"""Update the Tweedie power parameter.
|
|
1148
|
-
|
|
1149
|
-
Args:
|
|
1150
|
-
power: New power value (1.0-2.0)
|
|
1151
|
-
"""
|
|
1152
|
-
self.tweedie_power = power
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
-# =============================================================================
-# GPU Memory Manager
-# =============================================================================
-from contextlib import contextmanager
-
-
-class GPUMemoryManager:
-    """Context manager for GPU memory management and cleanup.
-
-    This class consolidates GPU memory cleanup logic that was previously
-    scattered across multiple trainer files.
-
-    Example:
-        >>> with GPUMemoryManager.cleanup_context():
-        ...     model.train()
-        ...     # Memory cleaned up after exiting context
-
-        >>> # Or use directly:
-        >>> GPUMemoryManager.clean()
-    """
-
-    _logger = get_logger("ins_pricing.gpu")
-
-    @classmethod
-    def clean(cls, verbose: bool = False) -> None:
-        """Clean up GPU memory.
-
-        Args:
-            verbose: If True, print cleanup details
-        """
-        gc.collect()
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-            if verbose:
-                cls._logger.debug("CUDA cache cleared and synchronized")
-
-        # Optional: Force IPC collect for multi-process scenarios
-        if os.environ.get("BAYESOPT_CUDA_IPC_COLLECT", "0") == "1":
-            if torch.cuda.is_available():
-                try:
-                    torch.cuda.ipc_collect()
-                    if verbose:
-                        cls._logger.debug("CUDA IPC collect performed")
-                except Exception:
-                    pass
-
-    @classmethod
-    @contextmanager
-    def cleanup_context(cls, verbose: bool = False):
-        """Context manager that cleans GPU memory on exit.
-
-        Args:
-            verbose: If True, print cleanup details
-
-        Yields:
-            None
-        """
-        try:
-            yield
-        finally:
-            cls.clean(verbose=verbose)
-
-    @classmethod
-    def move_model_to_cpu(cls, model: nn.Module) -> nn.Module:
-        """Move a model to CPU and clean GPU memory.
-
-        Args:
-            model: PyTorch model to move
-
-        Returns:
-            Model on CPU
-        """
-        if model is not None:
-            model.to("cpu")
-            cls.clean()
-        return model
-
-    @classmethod
-    def get_memory_info(cls) -> Dict[str, int]:
-        """Get current GPU memory usage information.
-
-        Returns:
-            Dictionary with memory info (allocated, reserved, free)
-        """
-        if not torch.cuda.is_available():
-            return {"available": False}
-
-        try:
-            allocated = torch.cuda.memory_allocated()
-            reserved = torch.cuda.memory_reserved()
-            free, total = torch.cuda.mem_get_info()
-            return {
-                "available": True,
-                "allocated_mb": allocated // (1024 * 1024),
-                "reserved_mb": reserved // (1024 * 1024),
-                "free_mb": free // (1024 * 1024),
-                "total_mb": total // (1024 * 1024),
-            }
-        except Exception:
-            return {"available": False}
-
-
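`cleanup_context` relies on `try`/`finally`, so the cache is released even when training raises; `gc.collect()` runs first because `torch.cuda.empty_cache()` can only return blocks whose tensors are no longer referenced. A self-contained sketch of the same pattern, with illustrative function names that are not part of the package API:

```python
# Minimal sketch of the cleanup-on-exit pattern GPUMemoryManager wraps.
import gc
from contextlib import contextmanager

import torch


def free_gpu_memory() -> None:
    """Release cached CUDA blocks back to the driver after a GC pass."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


@contextmanager
def gpu_cleanup():
    """Run a block, then free GPU memory even if the block raised."""
    try:
        yield
    finally:
        free_gpu_memory()


# Usage: memory is reclaimed whether the body succeeds or raises.
with gpu_cleanup():
    x = torch.randn(8, 8)
    if torch.cuda.is_available():
        x = x.cuda()
```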
-# =============================================================================
-# Device Manager
-# =============================================================================
-
-class DeviceManager:
-    """Unified device management for model and tensor placement.
-
-    This class consolidates device detection and model movement logic
-    that was previously duplicated across trainer_base.py and predict.py.
-
-    Example:
-        >>> device = DeviceManager.get_best_device()
-        >>> DeviceManager.move_to_device(model)
-    """
-
-    _logger = get_logger("ins_pricing.device")
-    _cached_device: Optional[torch.device] = None
-
-    @classmethod
-    def get_best_device(cls, prefer_cuda: bool = True) -> torch.device:
-        """Get the best available device.
-
-        Args:
-            prefer_cuda: If True, prefer CUDA over MPS
-
-        Returns:
-            Best available torch.device
-        """
-        if cls._cached_device is not None:
-            return cls._cached_device
-
-        if prefer_cuda and torch.cuda.is_available():
-            cls._cached_device = torch.device("cuda")
-        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-            cls._cached_device = torch.device("mps")
-        else:
-            cls._cached_device = torch.device("cpu")
-
-        cls._logger.debug(f"Selected device: {cls._cached_device}")
-        return cls._cached_device
-
-    @classmethod
-    def move_to_device(cls, model_obj, device: Optional[torch.device] = None) -> None:
-        """Move a model object to the specified device.
-
-        Handles sklearn-style wrappers that have .ft, .resnet, or .gnn attributes.
-
-        Args:
-            model_obj: Model object to move (may be sklearn wrapper)
-            device: Target device (defaults to best available)
-        """
-        if model_obj is None:
-            return
-
-        device = device or cls.get_best_device()
-
-        # Update device attribute if present
-        if hasattr(model_obj, "device"):
-            model_obj.device = device
-
-        # Move the main model
-        if hasattr(model_obj, "to"):
-            model_obj.to(device)
-
-        # Move nested submodules (sklearn wrappers)
-        for attr_name in ("ft", "resnet", "gnn"):
-            submodule = getattr(model_obj, attr_name, None)
-            if submodule is not None and hasattr(submodule, "to"):
-                submodule.to(device)
-
-    @classmethod
-    def unwrap_module(cls, module: nn.Module) -> nn.Module:
-        """Unwrap DDP or DataParallel wrapper to get the base module.
-
-        Args:
-            module: Potentially wrapped PyTorch module
-
-        Returns:
-            Unwrapped base module
-        """
-        if isinstance(module, (DDP, nn.DataParallel)):
-            return module.module
-        return module
-
-    @classmethod
-    def reset_cache(cls) -> None:
-        """Reset cached device selection."""
-        cls._cached_device = None
-
-
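`get_best_device` encodes a CUDA → MPS → CPU preference ladder and caches the answer, while `unwrap_module` strips the `DistributedDataParallel`/`DataParallel` wrapper so downstream code sees the bare module. A condensed sketch of both, using only public torch APIs (the helper names are illustrative, not the package's):

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def best_device(prefer_cuda: bool = True) -> torch.device:
    """Pick CUDA when present, otherwise Apple MPS, otherwise CPU."""
    if prefer_cuda and torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def unwrap(module: nn.Module) -> nn.Module:
    """Return the bare module behind a DDP or DataParallel wrapper."""
    return module.module if isinstance(module, (DDP, nn.DataParallel)) else module


model = nn.Linear(4, 1).to(best_device())
print(next(model.parameters()).device)
```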
-# =============================================================================
-# Cross-Validation Strategy Resolver
-# =============================================================================
-from sklearn.model_selection import KFold, GroupKFold, TimeSeriesSplit, StratifiedKFold
-
-
-class CVStrategyResolver:
-    """Resolver for cross-validation splitting strategies.
-
-    This class consolidates CV strategy resolution logic that was previously
-    duplicated across trainer_base.py and trainer_ft.py.
-
-    Supported strategies:
-    - 'random': Standard KFold
-    - 'stratified': StratifiedKFold (for classification)
-    - 'group': GroupKFold (requires group column)
-    - 'time': TimeSeriesSplit (requires time column)
-
-    Example:
-        >>> resolver = CVStrategyResolver(
-        ...     strategy='group',
-        ...     n_splits=5,
-        ...     group_col='policy_id',
-        ...     data=train_df,
-        ... )
-        >>> splitter, groups = resolver.get_splitter()
-        >>> for train_idx, val_idx in splitter.split(X, y, groups):
-        ...     pass
-    """
-
-    VALID_STRATEGIES = {"random", "stratified", "group", "grouped", "time", "timeseries", "temporal"}
-
-    def __init__(
-        self,
-        strategy: str = "random",
-        n_splits: int = 5,
-        shuffle: bool = True,
-        random_state: Optional[int] = None,
-        group_col: Optional[str] = None,
-        time_col: Optional[str] = None,
-        time_ascending: bool = True,
-        data: Optional[pd.DataFrame] = None,
-    ):
-        """Initialize the CV strategy resolver.
-
-        Args:
-            strategy: CV strategy name
-            n_splits: Number of CV folds
-            shuffle: Whether to shuffle for random/stratified
-            random_state: Random seed for reproducibility
-            group_col: Column name for group-based splitting
-            time_col: Column name for time-based splitting
-            time_ascending: Sort order for time-based splitting
-            data: DataFrame containing group/time columns
-        """
-        self.strategy = strategy.strip().lower()
-        self.n_splits = max(2, int(n_splits))
-        self.shuffle = shuffle
-        self.random_state = random_state
-        self.group_col = group_col
-        self.time_col = time_col
-        self.time_ascending = time_ascending
-        self.data = data
-
-        if self.strategy not in self.VALID_STRATEGIES:
-            raise ValueError(
-                f"Invalid strategy '{strategy}'. "
-                f"Valid options: {sorted(self.VALID_STRATEGIES)}"
-            )
-
-    def get_splitter(self) -> Tuple[Any, Optional[pd.Series]]:
-        """Get the appropriate splitter and groups.
-
-        Returns:
-            Tuple of (splitter, groups) where groups may be None
-
-        Raises:
-            ValueError: If required columns are missing
-        """
-        if self.strategy in {"group", "grouped"}:
-            return self._get_group_splitter()
-        elif self.strategy in {"time", "timeseries", "temporal"}:
-            return self._get_time_splitter()
-        elif self.strategy == "stratified":
-            return self._get_stratified_splitter()
-        else:
-            return self._get_random_splitter()
-
-    def _get_random_splitter(self) -> Tuple[KFold, None]:
-        """Get a random KFold splitter."""
-        splitter = KFold(
-            n_splits=self.n_splits,
-            shuffle=self.shuffle,
-            random_state=self.random_state if self.shuffle else None,
-        )
-        return splitter, None
-
-    def _get_stratified_splitter(self) -> Tuple[StratifiedKFold, None]:
-        """Get a stratified KFold splitter."""
-        splitter = StratifiedKFold(
-            n_splits=self.n_splits,
-            shuffle=self.shuffle,
-            random_state=self.random_state if self.shuffle else None,
-        )
-        return splitter, None
-
-    def _get_group_splitter(self) -> Tuple[GroupKFold, pd.Series]:
-        """Get a group-based KFold splitter."""
-        if not self.group_col:
-            raise ValueError("group_col is required for group strategy")
-        if self.data is None:
-            raise ValueError("data DataFrame is required for group strategy")
-        if self.group_col not in self.data.columns:
-            raise KeyError(f"group_col '{self.group_col}' not found in data")
-
-        groups = self.data[self.group_col]
-        splitter = GroupKFold(n_splits=self.n_splits)
-        return splitter, groups
-
-    def _get_time_splitter(self) -> Tuple[Any, None]:
-        """Get a time-series splitter."""
-        if not self.time_col:
-            raise ValueError("time_col is required for time strategy")
-        if self.data is None:
-            raise ValueError("data DataFrame is required for time strategy")
-        if self.time_col not in self.data.columns:
-            raise KeyError(f"time_col '{self.time_col}' not found in data")
-
-        splitter = TimeSeriesSplit(n_splits=self.n_splits)
-
-        # Create an ordered wrapper that sorts by time column
-        order_index = self.data[self.time_col].sort_values(
-            ascending=self.time_ascending
-        ).index
-        order = self.data.index.get_indexer(order_index)
-
-        return _OrderedSplitter(splitter, order), None
-
-
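In pricing data the same policy often contributes several rows (one per exposure year, say), so a plain random split can leak a policy across train and validation; the `group`/`grouped` strategies avoid that by keeping every group on one side of each fold. A toy check with fabricated data (the `policy_id` column echoes the docstring example above); the `_OrderedSplitter` used by the time strategy is defined next:

```python
# Toy demonstration that GroupKFold never splits one policy across folds.
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold

df = pd.DataFrame({
    "policy_id": ["A", "A", "B", "B", "C", "C"],
    "x": np.arange(6.0),
})
groups = df["policy_id"]

for train_idx, val_idx in GroupKFold(n_splits=3).split(df, groups=groups):
    # No policy appears on both sides of any fold.
    assert not set(groups.iloc[train_idx]) & set(groups.iloc[val_idx])
    print(sorted(set(groups.iloc[val_idx])), "held out together")
```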
-class _OrderedSplitter:
-    """Wrapper for splitters that need to respect a specific ordering."""
-
-    def __init__(self, base_splitter, order: np.ndarray):
-        self.base_splitter = base_splitter
-        self.order = order
-
-    def split(self, X, y=None, groups=None):
-        """Split with ordering applied."""
-        n = len(X)
-        X_ordered = np.arange(n)[self.order]
-        for train_idx, val_idx in self.base_splitter.split(X_ordered):
-            yield self.order[train_idx], self.order[val_idx]
-
-    def get_n_splits(self, X=None, y=None, groups=None):
-        return self.base_splitter.get_n_splits()