ins-pricing 0.4.4-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (96)
  1. ins_pricing/README.md +74 -56
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +52 -50
  4. ins_pricing/cli/BayesOpt_incremental.py +832 -898
  5. ins_pricing/cli/Explain_Run.py +31 -23
  6. ins_pricing/cli/Explain_entry.py +532 -579
  7. ins_pricing/cli/Pricing_Run.py +31 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +1440 -1438
  9. ins_pricing/cli/utils/cli_common.py +256 -256
  10. ins_pricing/cli/utils/cli_config.py +375 -375
  11. ins_pricing/cli/utils/import_resolver.py +382 -365
  12. ins_pricing/cli/utils/notebook_utils.py +340 -340
  13. ins_pricing/cli/watchdog_run.py +209 -201
  14. ins_pricing/frontend/README.md +573 -419
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/config_builder.py +1 -0
  17. ins_pricing/frontend/example_workflows.py +1 -1
  18. ins_pricing/governance/__init__.py +20 -20
  19. ins_pricing/governance/release.py +159 -159
  20. ins_pricing/modelling/README.md +67 -0
  21. ins_pricing/modelling/__init__.py +147 -92
  22. ins_pricing/modelling/bayesopt/README.md +59 -0
  23. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  24. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -550
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -962
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
  32. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
  37. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
  39. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  40. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  41. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  42. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  43. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  44. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
  45. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  46. ins_pricing/modelling/explain/__init__.py +55 -55
  47. ins_pricing/modelling/explain/metrics.py +27 -174
  48. ins_pricing/modelling/explain/permutation.py +237 -237
  49. ins_pricing/modelling/plotting/__init__.py +40 -36
  50. ins_pricing/modelling/plotting/compat.py +228 -0
  51. ins_pricing/modelling/plotting/curves.py +572 -572
  52. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  53. ins_pricing/modelling/plotting/geo.py +362 -362
  54. ins_pricing/modelling/plotting/importance.py +121 -121
  55. ins_pricing/pricing/__init__.py +27 -27
  56. ins_pricing/production/__init__.py +35 -25
  57. ins_pricing/production/{predict.py → inference.py} +140 -57
  58. ins_pricing/production/monitoring.py +8 -21
  59. ins_pricing/reporting/__init__.py +11 -11
  60. ins_pricing/setup.py +1 -1
  61. ins_pricing/tests/production/test_inference.py +90 -0
  62. ins_pricing/utils/__init__.py +116 -83
  63. ins_pricing/utils/device.py +255 -255
  64. ins_pricing/utils/features.py +53 -0
  65. ins_pricing/utils/io.py +72 -0
  66. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  67. ins_pricing/utils/metrics.py +158 -24
  68. ins_pricing/utils/numerics.py +76 -0
  69. ins_pricing/utils/paths.py +9 -1
  70. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +55 -35
  71. ins_pricing-0.5.0.dist-info/RECORD +131 -0
  72. ins_pricing/CHANGELOG.md +0 -272
  73. ins_pricing/RELEASE_NOTES_0.2.8.md +0 -344
  74. ins_pricing/docs/LOSS_FUNCTIONS.md +0 -78
  75. ins_pricing/docs/modelling/BayesOpt_USAGE.md +0 -945
  76. ins_pricing/docs/modelling/README.md +0 -34
  77. ins_pricing/frontend/QUICKSTART.md +0 -152
  78. ins_pricing/modelling/core/BayesOpt.py +0 -146
  79. ins_pricing/modelling/core/__init__.py +0 -1
  80. ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +0 -449
  81. ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +0 -406
  82. ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +0 -247
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.4.dist-info/RECORD +0 -137
  92. /ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
  93. /ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
  94. /ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
  95. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
  96. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
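The bulk of the change set is the move of modelling/core/bayesopt to modelling/bayesopt (entries 23-48), the rename of production/predict.py to production/inference.py (entry 57), and the promotion of loss helpers into ins_pricing/utils (entry 66). The single hunk reproduced below is entry 44, torch_trainer_mixin.py. As a minimal sketch of what these renames imply for downstream imports: the paths come from the list above, but whether each symbol is re-exported exactly there is an assumption.

# Import paths implied by the renames above (sketch, not taken from package docs).
# Old 0.4.4 layout:
#   from ins_pricing.modelling.core.bayesopt.utils.torch_trainer_mixin import TorchTrainerMixin
#   from ins_pricing.production import predict
# New 0.5.0 layout:
from ins_pricing.modelling.bayesopt.utils.torch_trainer_mixin import TorchTrainerMixin
from ins_pricing.production import inference              # was production/predict.py
from ins_pricing.utils.losses import normalize_loss_name  # loss helpers now live under utils/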
@@ -1,623 +1,623 @@
1
- """PyTorch training mixin with resource management and training loops.
2
-
3
- This module provides the TorchTrainerMixin class which is used by
4
- PyTorch-based trainers (ResNet, FT, GNN) for:
5
- - Resource profiling and memory management
6
- - Batch size computation and optimization
7
- - DataLoader creation with DDP support
8
- - Generic training and validation loops with AMP
9
- - Early stopping and loss curve plotting
10
- """
11
-
12
- from __future__ import annotations
13
-
14
- import copy
15
- import ctypes
16
- import gc
17
- import math
18
- import os
19
- import time
20
- from contextlib import nullcontext
21
- from typing import Any, Callable, Dict, List, Optional, Tuple
22
-
23
- import numpy as np
24
- import optuna
25
- import torch
26
- import torch.nn as nn
27
- import torch.nn.functional as F
28
- import torch.distributed as dist
29
- from torch.cuda.amp import autocast, GradScaler
30
- from torch.nn.parallel import DistributedDataParallel as DDP
31
- from torch.utils.data import DataLoader, DistributedSampler
32
-
33
- # Try to import plotting functions
34
- try:
35
- import matplotlib
36
- if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
37
- matplotlib.use("Agg")
38
- import matplotlib.pyplot as plt
39
- _MPL_IMPORT_ERROR: Optional[BaseException] = None
40
- except Exception as exc:
41
- matplotlib = None
42
- plt = None
43
- _MPL_IMPORT_ERROR = exc
44
-
45
- try:
46
- from ....plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
47
- except Exception:
48
- try:
49
- from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
50
- except Exception:
51
- plot_loss_curve_common = None
52
-
53
- # Import from other utils modules
54
- from .constants import EPS, compute_batch_size, tweedie_loss, ensure_parent_dir
55
- from .losses import (
56
- infer_loss_name_from_model_name,
57
- loss_requires_positive,
58
- normalize_loss_name,
59
- resolve_tweedie_power,
60
- )
61
- from .distributed_utils import DistributedUtils
62
-
63
-
64
- def _plot_skip(label: str) -> None:
65
- """Print message when plot is skipped due to missing matplotlib."""
66
- if _MPL_IMPORT_ERROR is not None:
67
- print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
68
- else:
69
- print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
70
-
71
-
72
- class TorchTrainerMixin:
73
- """Shared helpers for PyTorch tabular trainers.
74
-
75
- Provides resource profiling, memory management, batch size optimization,
76
- and standardized training loops with mixed precision and DDP support.
77
-
78
- This mixin is used by ResNetTrainer, FTTrainer, and GNNTrainer.
79
- """
80
-
81
- def _device_type(self) -> str:
82
- """Get device type (cpu/cuda/mps)."""
83
- return getattr(self, "device", torch.device("cpu")).type
84
-
85
- def _resolve_resource_profile(self) -> str:
86
- """Determine resource usage profile.
87
-
88
- Returns:
89
- One of: 'throughput', 'memory_saving', or 'auto'
90
- """
91
- profile = getattr(self, "resource_profile", None)
92
- if not profile:
93
- profile = os.environ.get("BAYESOPT_RESOURCE_PROFILE", "auto")
94
- profile = str(profile).strip().lower()
95
- if profile in {"cpu", "mps", "cuda"}:
96
- profile = "auto"
97
- if profile not in {"auto", "throughput", "memory_saving"}:
98
- profile = "auto"
99
- if profile == "auto" and self._device_type() == "cuda":
100
- profile = "throughput"
101
- return profile
102
-
103
- def _log_resource_summary_once(self, profile: str) -> None:
104
- """Log resource configuration summary once."""
105
- if getattr(self, "_resource_summary_logged", False):
106
- return
107
- if dist.is_initialized() and not DistributedUtils.is_main_process():
108
- return
109
- self._resource_summary_logged = True
110
- device = getattr(self, "device", torch.device("cpu"))
111
- device_type = self._device_type()
112
- cpu_count = os.cpu_count() or 1
113
- cuda_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
114
- mps_available = bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available())
115
- ddp_enabled = bool(getattr(self, "is_ddp_enabled", False))
116
- data_parallel = bool(getattr(self, "use_data_parallel", False))
117
- print(
118
- f">>> Resource summary: device={device}, device_type={device_type}, "
119
- f"cpu_count={cpu_count}, cuda_count={cuda_count}, mps={mps_available}, "
120
- f"ddp={ddp_enabled}, data_parallel={data_parallel}, profile={profile}"
121
- )
122
-
123
- def _available_system_memory(self) -> Optional[int]:
124
- """Get available system RAM in bytes."""
125
- if os.name == "nt":
126
- class _MemStatus(ctypes.Structure):
127
- _fields_ = [
128
- ("dwLength", ctypes.c_ulong),
129
- ("dwMemoryLoad", ctypes.c_ulong),
130
- ("ullTotalPhys", ctypes.c_ulonglong),
131
- ("ullAvailPhys", ctypes.c_ulonglong),
132
- ("ullTotalPageFile", ctypes.c_ulonglong),
133
- ("ullAvailPageFile", ctypes.c_ulonglong),
134
- ("ullTotalVirtual", ctypes.c_ulonglong),
135
- ("ullAvailVirtual", ctypes.c_ulonglong),
136
- ("sullAvailExtendedVirtual", ctypes.c_ulonglong),
137
- ]
138
- status = _MemStatus()
139
- status.dwLength = ctypes.sizeof(_MemStatus)
140
- if ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status)):
141
- return int(status.ullAvailPhys)
142
- return None
143
- try:
144
- pages = os.sysconf("SC_AVPHYS_PAGES")
145
- page_size = os.sysconf("SC_PAGE_SIZE")
146
- return int(pages * page_size)
147
- except Exception:
148
- return None
149
-
150
- def _available_cuda_memory(self) -> Optional[int]:
151
- """Get available CUDA memory in bytes."""
152
- if not torch.cuda.is_available():
153
- return None
154
- try:
155
- free_mem, _total_mem = torch.cuda.mem_get_info()
156
- except Exception:
157
- return None
158
- return int(free_mem)
159
-
160
- def _estimate_sample_bytes(self, dataset) -> Optional[int]:
161
- """Estimate memory per sample in bytes."""
162
- try:
163
- if len(dataset) == 0:
164
- return None
165
- sample = dataset[0]
166
- except Exception:
167
- return None
168
-
169
- def _bytes(obj) -> int:
170
- if obj is None:
171
- return 0
172
- if torch.is_tensor(obj):
173
- return int(obj.element_size() * obj.nelement())
174
- if isinstance(obj, np.ndarray):
175
- return int(obj.nbytes)
176
- if isinstance(obj, (list, tuple)):
177
- return int(sum(_bytes(item) for item in obj))
178
- if isinstance(obj, dict):
179
- return int(sum(_bytes(item) for item in obj.values()))
180
- return 0
181
-
182
- sample_bytes = _bytes(sample)
183
- return int(sample_bytes) if sample_bytes > 0 else None
184
-
185
- def _cap_batch_size_by_memory(self, dataset, batch_size: int, profile: str) -> int:
186
- """Cap batch size based on available memory."""
187
- if batch_size <= 1:
188
- return batch_size
189
- sample_bytes = self._estimate_sample_bytes(dataset)
190
- if sample_bytes is None:
191
- return batch_size
192
- device_type = self._device_type()
193
- if device_type == "cuda":
194
- available = self._available_cuda_memory()
195
- if available is None:
196
- return batch_size
197
- if profile == "throughput":
198
- budget_ratio = 0.8
199
- overhead = 8.0
200
- elif profile == "memory_saving":
201
- budget_ratio = 0.5
202
- overhead = 14.0
203
- else:
204
- budget_ratio = 0.6
205
- overhead = 12.0
206
- else:
207
- available = self._available_system_memory()
208
- if available is None:
209
- return batch_size
210
- if profile == "throughput":
211
- budget_ratio = 0.4
212
- overhead = 1.8
213
- elif profile == "memory_saving":
214
- budget_ratio = 0.25
215
- overhead = 3.0
216
- else:
217
- budget_ratio = 0.3
218
- overhead = 2.6
219
- budget = int(available * budget_ratio)
220
- per_sample = int(sample_bytes * overhead)
221
- if per_sample <= 0:
222
- return batch_size
223
- max_batch = max(1, int(budget // per_sample))
224
- if max_batch < batch_size:
225
- print(
226
- f">>> Memory cap: batch_size {batch_size} -> {max_batch} "
227
- f"(per_sample~{sample_bytes}B, budget~{budget // (1024**2)}MB)"
228
- )
229
- return min(batch_size, max_batch)
230
-
231
- def _resolve_num_workers(self, max_workers: int, profile: Optional[str] = None) -> int:
232
- """Determine number of DataLoader workers."""
233
- if os.name == 'nt':
234
- return 0
235
- override = getattr(self, "dataloader_workers", None)
236
- if override is None:
237
- override = os.environ.get("BAYESOPT_DATALOADER_WORKERS")
238
- if override is not None:
239
- try:
240
- return max(0, int(override))
241
- except (TypeError, ValueError):
242
- pass
243
- if getattr(self, "is_ddp_enabled", False):
244
- return 0
245
- profile = profile or self._resolve_resource_profile()
246
- if profile == "memory_saving":
247
- return 0
248
- worker_cap = min(int(max_workers), os.cpu_count() or 1)
249
- if self._device_type() == "mps":
250
- worker_cap = min(worker_cap, 2)
251
- return worker_cap
252
-
253
- def _build_dataloader(self,
254
- dataset,
255
- N: int,
256
- base_bs_gpu: tuple,
257
- base_bs_cpu: tuple,
258
- min_bs: int = 64,
259
- target_effective_cuda: int = 1024,
260
- target_effective_cpu: int = 512,
261
- large_threshold: int = 200_000,
262
- mid_threshold: int = 50_000):
263
- """Build DataLoader with adaptive batch size and worker configuration.
264
-
265
- Returns:
266
- Tuple of (dataloader, accum_steps)
267
- """
268
- profile = self._resolve_resource_profile()
269
- self._log_resource_summary_once(profile)
270
- batch_size = compute_batch_size(
271
- data_size=len(dataset),
272
- learning_rate=self.learning_rate,
273
- batch_num=self.batch_num,
274
- minimum=min_bs
275
- )
276
- gpu_large, gpu_mid, gpu_small = base_bs_gpu
277
- cpu_mid, cpu_small = base_bs_cpu
278
-
279
- if self._device_type() == 'cuda':
280
- # Only scale batch size by GPU count when DDP is enabled.
281
- # In single-process (non-DDP) mode, large multi-GPU nodes can
282
- # still OOM on RAM/VRAM if we scale by device_count.
283
- device_count = 1
284
- if getattr(self, "is_ddp_enabled", False):
285
- device_count = torch.cuda.device_count()
286
- if device_count > 1:
287
- min_bs = min_bs * device_count
288
- print(
289
- f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")
290
-
291
- if N > large_threshold:
292
- base_bs = gpu_large * device_count
293
- elif N > mid_threshold:
294
- base_bs = gpu_mid * device_count
295
- else:
296
- base_bs = gpu_small * device_count
297
- else:
298
- base_bs = cpu_mid if N > mid_threshold else cpu_small
299
-
300
- batch_size = compute_batch_size(
301
- data_size=len(dataset),
302
- learning_rate=self.learning_rate,
303
- batch_num=self.batch_num,
304
- minimum=min_bs
305
- )
306
- batch_size = min(batch_size, base_bs, N)
307
- batch_size = self._cap_batch_size_by_memory(
308
- dataset, batch_size, profile)
309
-
310
- target_effective_bs = target_effective_cuda if self._device_type(
311
- ) == 'cuda' else target_effective_cpu
312
- if getattr(self, "is_ddp_enabled", False):
313
- world_size = max(1, DistributedUtils.world_size())
314
- target_effective_bs = max(1, target_effective_bs // world_size)
315
-
316
- world_size = getattr(self, "world_size", 1) if getattr(
317
- self, "is_ddp_enabled", False) else 1
318
- samples_per_rank = math.ceil(
319
- N / max(1, world_size)) if world_size > 1 else N
320
- steps_per_epoch = max(
321
- 1, math.ceil(samples_per_rank / max(1, batch_size)))
322
- desired_accum = max(1, target_effective_bs // max(1, batch_size))
323
- accum_steps = max(1, min(desired_accum, steps_per_epoch))
324
-
325
- workers = self._resolve_num_workers(8, profile=profile)
326
- prefetch_factor = None
327
- if workers > 0:
328
- prefetch_factor = 4 if profile == "throughput" else 2
329
- persistent = workers > 0 and profile != "memory_saving"
330
- print(
331
- f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, "
332
- f"Workers={workers}, Prefetch={prefetch_factor or 'off'}, Profile={profile}")
333
- sampler = None
334
- use_distributed_sampler = bool(
335
- dist.is_initialized() and getattr(self, "is_ddp_enabled", False)
336
- )
337
- if use_distributed_sampler:
338
- sampler = DistributedSampler(dataset, shuffle=True)
339
- shuffle = False
340
- else:
341
- shuffle = True
342
-
343
- dataloader = DataLoader(
344
- dataset,
345
- batch_size=batch_size,
346
- shuffle=shuffle,
347
- sampler=sampler,
348
- num_workers=workers,
349
- pin_memory=(self._device_type() == 'cuda'),
350
- persistent_workers=persistent,
351
- **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
352
- )
353
- return dataloader, accum_steps
354
-
355
- def _build_val_dataloader(self, dataset, train_dataloader, accum_steps):
356
- """Build validation DataLoader."""
357
- profile = self._resolve_resource_profile()
358
- val_bs = accum_steps * train_dataloader.batch_size
359
- val_workers = self._resolve_num_workers(4, profile=profile)
360
- prefetch_factor = None
361
- if val_workers > 0:
362
- prefetch_factor = 2
363
- return DataLoader(
364
- dataset,
365
- batch_size=val_bs,
366
- shuffle=False,
367
- num_workers=val_workers,
368
- pin_memory=(self._device_type() == 'cuda'),
369
- persistent_workers=(val_workers > 0 and profile != "memory_saving"),
370
- **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
371
- )
372
-
373
- def _compute_losses(self, y_pred, y_true, apply_softplus: bool = False):
374
- """Compute per-sample losses based on task type."""
375
- task = getattr(self, "task_type", "regression")
376
- if task == 'classification':
377
- loss_fn = nn.BCEWithLogitsLoss(reduction='none')
378
- return loss_fn(y_pred, y_true).view(-1)
379
- loss_name = normalize_loss_name(
380
- getattr(self, "loss_name", None), task_type="regression"
381
- )
382
- if loss_name == "auto":
383
- loss_name = infer_loss_name_from_model_name(getattr(self, "model_nme", ""))
384
- if apply_softplus:
385
- y_pred = F.softplus(y_pred)
386
- if loss_requires_positive(loss_name):
387
- y_pred = torch.clamp(y_pred, min=1e-6)
388
- power = resolve_tweedie_power(
389
- loss_name, default=float(getattr(self, "tw_power", 1.5) or 1.5)
390
- )
391
- if power is None:
392
- power = float(getattr(self, "tw_power", 1.5) or 1.5)
393
- return tweedie_loss(y_pred, y_true, p=power).view(-1)
394
- if loss_name == "mse":
395
- return (y_pred - y_true).pow(2).view(-1)
396
- if loss_name == "mae":
397
- return (y_pred - y_true).abs().view(-1)
398
- raise ValueError(f"Unsupported loss_name '{loss_name}' for regression.")
399
-
400
- def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
401
- """Compute weighted loss."""
402
- losses = self._compute_losses(
403
- y_pred, y_true, apply_softplus=apply_softplus)
404
- weighted_loss = (losses * weights.view(-1)).sum() / \
405
- torch.clamp(weights.sum(), min=EPS)
406
- return weighted_loss
407
-
408
- def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model,
409
- ignore_keys: Optional[List[str]] = None):
410
- """Update early stopping state."""
411
- if val_loss < best_loss:
412
- ignore_keys = ignore_keys or []
413
- base_module = model.module if hasattr(model, "module") else model
414
- state_dict = {
415
- k: (v.clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v))
416
- for k, v in base_module.state_dict().items()
417
- if not any(k.startswith(ignore_key) for ignore_key in ignore_keys)
418
- }
419
- return val_loss, state_dict, 0, False
420
- patience_counter += 1
421
- should_stop = best_state is not None and patience_counter >= getattr(
422
- self, "patience", 0)
423
- return best_loss, best_state, patience_counter, should_stop
424
-
425
- def _train_model(self,
426
- model,
427
- dataloader,
428
- accum_steps,
429
- optimizer,
430
- scaler,
431
- forward_fn,
432
- val_forward_fn=None,
433
- apply_softplus: bool = False,
434
- clip_fn=None,
435
- trial: Optional[optuna.trial.Trial] = None,
436
- loss_curve_path: Optional[str] = None):
437
- """Generic training loop with AMP, DDP, and early stopping support.
438
-
439
- Returns:
440
- Tuple of (best_state_dict, history)
441
- """
442
- device_type = self._device_type()
443
- best_loss = float('inf')
444
- best_state = None
445
- patience_counter = 0
446
- stop_training = False
447
- train_history: List[float] = []
448
- val_history: List[float] = []
449
-
450
- is_ddp_model = isinstance(model, DDP)
451
- use_collectives = dist.is_initialized() and is_ddp_model
452
-
453
- for epoch in range(1, getattr(self, "epochs", 1) + 1):
454
- epoch_start_ts = time.time()
455
- val_weighted_loss = None
456
- if hasattr(self, 'dataloader_sampler') and self.dataloader_sampler is not None:
457
- self.dataloader_sampler.set_epoch(epoch)
458
-
459
- model.train()
460
- optimizer.zero_grad()
461
-
462
- epoch_loss_sum = None
463
- epoch_weight_sum = None
464
- for step, batch in enumerate(dataloader):
465
- is_update_step = ((step + 1) % accum_steps == 0) or \
466
- ((step + 1) == len(dataloader))
467
- sync_cm = model.no_sync if (
468
- is_ddp_model and not is_update_step) else nullcontext
469
-
470
- with sync_cm():
471
- with autocast(enabled=(device_type == 'cuda')):
472
- y_pred, y_true, w = forward_fn(batch)
473
- weighted_loss = self._compute_weighted_loss(
474
- y_pred, y_true, w, apply_softplus=apply_softplus)
475
- loss_for_backward = weighted_loss / accum_steps
476
-
477
- batch_weight = torch.clamp(
478
- w.detach().sum(), min=EPS).to(dtype=torch.float32)
479
- loss_val = weighted_loss.detach().to(dtype=torch.float32)
480
- if epoch_loss_sum is None:
481
- epoch_loss_sum = torch.zeros(
482
- (), device=batch_weight.device, dtype=torch.float32)
483
- epoch_weight_sum = torch.zeros(
484
- (), device=batch_weight.device, dtype=torch.float32)
485
- epoch_loss_sum = epoch_loss_sum + loss_val * batch_weight
486
- epoch_weight_sum = epoch_weight_sum + batch_weight
487
- scaler.scale(loss_for_backward).backward()
488
-
489
- if is_update_step:
490
- if clip_fn is not None:
491
- clip_fn()
492
- scaler.step(optimizer)
493
- scaler.update()
494
- optimizer.zero_grad()
495
-
496
- if epoch_loss_sum is None or epoch_weight_sum is None:
497
- train_epoch_loss = 0.0
498
- else:
499
- train_epoch_loss = (
500
- epoch_loss_sum / torch.clamp(epoch_weight_sum, min=EPS)
501
- ).item()
502
- train_history.append(float(train_epoch_loss))
503
-
504
- if val_forward_fn is not None:
505
- should_compute_val = (not dist.is_initialized()
506
- or DistributedUtils.is_main_process())
507
- val_device = getattr(self, "device", torch.device("cpu"))
508
- if not isinstance(val_device, torch.device):
509
- val_device = torch.device(val_device)
510
- loss_tensor_device = val_device if device_type == 'cuda' else torch.device(
511
- "cpu")
512
- val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
513
-
514
- if should_compute_val:
515
- model.eval()
516
- with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
517
- val_result = val_forward_fn()
518
- if isinstance(val_result, tuple) and len(val_result) == 3:
519
- y_val_pred, y_val_true, w_val = val_result
520
- val_weighted_loss = self._compute_weighted_loss(
521
- y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
522
- else:
523
- val_weighted_loss = val_result
524
- val_loss_tensor[0] = float(val_weighted_loss)
525
-
526
- if use_collectives:
527
- dist.broadcast(val_loss_tensor, src=0)
528
- val_weighted_loss = float(val_loss_tensor.item())
529
-
530
- val_history.append(val_weighted_loss)
531
-
532
- best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
533
- val_weighted_loss, best_loss, best_state, patience_counter, model)
534
-
535
- prune_flag = False
536
- is_main_rank = DistributedUtils.is_main_process()
537
- if trial is not None and is_main_rank:
538
- trial.report(val_weighted_loss, epoch)
539
- prune_flag = trial.should_prune()
540
-
541
- if use_collectives:
542
- prune_device = getattr(self, "device", torch.device("cpu"))
543
- if not isinstance(prune_device, torch.device):
544
- prune_device = torch.device(prune_device)
545
- prune_tensor = torch.zeros(1, device=prune_device)
546
- if is_main_rank:
547
- prune_tensor.fill_(1 if prune_flag else 0)
548
- dist.broadcast(prune_tensor, src=0)
549
- prune_flag = bool(prune_tensor.item())
550
-
551
- if prune_flag:
552
- raise optuna.TrialPruned()
553
-
554
- if stop_training:
555
- break
556
-
557
- should_log_epoch = (not dist.is_initialized()
558
- or DistributedUtils.is_main_process())
559
- if should_log_epoch:
560
- elapsed = int(time.time() - epoch_start_ts)
561
- if val_weighted_loss is None:
562
- print(
563
- f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
564
- f"train_loss={float(train_epoch_loss):.6f} elapsed={elapsed}s",
565
- flush=True,
566
- )
567
- else:
568
- print(
569
- f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
570
- f"train_loss={float(train_epoch_loss):.6f} "
571
- f"val_loss={float(val_weighted_loss):.6f} elapsed={elapsed}s",
572
- flush=True,
573
- )
574
-
575
- if epoch % 10 == 0:
576
- if torch.cuda.is_available():
577
- torch.cuda.empty_cache()
578
- gc.collect()
579
-
580
- history = {"train": train_history, "val": val_history}
581
- self._plot_loss_curve(history, loss_curve_path)
582
- return best_state, history
583
-
584
- def _plot_loss_curve(self, history: Dict[str, List[float]], save_path: Optional[str]) -> None:
585
- """Plot training and validation loss curves."""
586
- if not save_path:
587
- return
588
- if dist.is_initialized() and not DistributedUtils.is_main_process():
589
- return
590
- train_hist = history.get("train", []) if history else []
591
- val_hist = history.get("val", []) if history else []
592
- if not train_hist and not val_hist:
593
- return
594
- if plot_loss_curve_common is not None:
595
- plot_loss_curve_common(
596
- history=history,
597
- title="Loss vs. Epoch",
598
- save_path=save_path,
599
- show=False,
600
- )
601
- else:
602
- if plt is None:
603
- _plot_skip("loss curve")
604
- return
605
- ensure_parent_dir(save_path)
606
- epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
607
- fig = plt.figure(figsize=(8, 4))
608
- ax = fig.add_subplot(111)
609
- if train_hist:
610
- ax.plot(range(1, len(train_hist) + 1), train_hist,
611
- label='Train Loss', color='tab:blue')
612
- if val_hist:
613
- ax.plot(range(1, len(val_hist) + 1), val_hist,
614
- label='Validation Loss', color='tab:orange')
615
- ax.set_xlabel('Epoch')
616
- ax.set_ylabel('Weighted Loss')
617
- ax.set_title('Loss vs. Epoch')
618
- ax.grid(True, linestyle='--', alpha=0.3)
619
- ax.legend()
620
- plt.tight_layout()
621
- plt.savefig(save_path, dpi=300)
622
- plt.close(fig)
623
- print(f"[Training] Loss curve saved to {save_path}")
1
+ """PyTorch training mixin with resource management and training loops.
2
+
3
+ This module provides the TorchTrainerMixin class which is used by
4
+ PyTorch-based trainers (ResNet, FT, GNN) for:
5
+ - Resource profiling and memory management
6
+ - Batch size computation and optimization
7
+ - DataLoader creation with DDP support
8
+ - Generic training and validation loops with AMP
9
+ - Early stopping and loss curve plotting
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import copy
15
+ import ctypes
16
+ import gc
17
+ import math
18
+ import os
19
+ import time
20
+ from contextlib import nullcontext
21
+ from typing import Dict, List, Optional
22
+
23
+ import numpy as np
24
+ import optuna
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ import torch.distributed as dist
29
+ from torch.cuda.amp import autocast
30
+ from torch.nn.parallel import DistributedDataParallel as DDP
31
+ from torch.utils.data import DataLoader, DistributedSampler
32
+
33
+ # Try to import plotting functions
34
+ try:
35
+ import matplotlib
36
+ if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
37
+ matplotlib.use("Agg")
38
+ import matplotlib.pyplot as plt
39
+ _MPL_IMPORT_ERROR: Optional[BaseException] = None
40
+ except Exception as exc:
41
+ matplotlib = None
42
+ plt = None
43
+ _MPL_IMPORT_ERROR = exc
44
+
45
+ try:
46
+ from ins_pricing.modelling.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
47
+ except Exception:
48
+ try:
49
+ from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
50
+ except Exception:
51
+ plot_loss_curve_common = None
52
+
53
+ # Import from other utils modules
54
+ from ins_pricing.utils import EPS, compute_batch_size, tweedie_loss, ensure_parent_dir
55
+ from ins_pricing.utils.losses import (
56
+ infer_loss_name_from_model_name,
57
+ loss_requires_positive,
58
+ normalize_loss_name,
59
+ resolve_tweedie_power,
60
+ )
61
+ from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
62
+
63
+
64
+ def _plot_skip(label: str) -> None:
65
+ """Print message when plot is skipped due to missing matplotlib."""
66
+ if _MPL_IMPORT_ERROR is not None:
67
+ print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
68
+ else:
69
+ print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
70
+
71
+
72
+ class TorchTrainerMixin:
73
+ """Shared helpers for PyTorch tabular trainers.
74
+
75
+ Provides resource profiling, memory management, batch size optimization,
76
+ and standardized training loops with mixed precision and DDP support.
77
+
78
+ This mixin is used by ResNetTrainer, FTTrainer, and GNNTrainer.
79
+ """
80
+
81
+ def _resolve_device(self) -> torch.device:
82
+ """Resolve device to a torch.device instance."""
83
+ device = getattr(self, "device", None)
84
+ if device is None:
85
+ return torch.device("cpu")
86
+ return device if isinstance(device, torch.device) else torch.device(device)
87
+
88
+ def _device_type(self) -> str:
89
+ """Get device type (cpu/cuda/mps)."""
90
+ return self._resolve_device().type
91
+
92
+ def _resolve_resource_profile(self) -> str:
93
+ """Determine resource usage profile.
94
+
95
+ Returns:
96
+ One of: 'throughput', 'memory_saving', or 'auto'
97
+ """
98
+ profile = getattr(self, "resource_profile", None)
99
+ if not profile:
100
+ profile = os.environ.get("BAYESOPT_RESOURCE_PROFILE", "auto")
101
+ profile = str(profile).strip().lower()
102
+ if profile in {"cpu", "mps", "cuda"}:
103
+ profile = "auto"
104
+ if profile not in {"auto", "throughput", "memory_saving"}:
105
+ profile = "auto"
106
+ if profile == "auto" and self._device_type() == "cuda":
107
+ profile = "throughput"
108
+ return profile
109
+
110
+ def _log_resource_summary_once(self, profile: str) -> None:
111
+ """Log resource configuration summary once."""
112
+ if getattr(self, "_resource_summary_logged", False):
113
+ return
114
+ if dist.is_initialized() and not DistributedUtils.is_main_process():
115
+ return
116
+ self._resource_summary_logged = True
117
+ device = self._resolve_device()
118
+ device_type = self._device_type()
119
+ cpu_count = os.cpu_count() or 1
120
+ cuda_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
121
+ mps_available = bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available())
122
+ ddp_enabled = bool(getattr(self, "is_ddp_enabled", False))
123
+ data_parallel = bool(getattr(self, "use_data_parallel", False))
124
+ print(
125
+ f">>> Resource summary: device={device}, device_type={device_type}, "
126
+ f"cpu_count={cpu_count}, cuda_count={cuda_count}, mps={mps_available}, "
127
+ f"ddp={ddp_enabled}, data_parallel={data_parallel}, profile={profile}"
128
+ )
129
+
130
+ def _available_system_memory(self) -> Optional[int]:
131
+ """Get available system RAM in bytes."""
132
+ if os.name == "nt":
133
+ class _MemStatus(ctypes.Structure):
134
+ _fields_ = [
135
+ ("dwLength", ctypes.c_ulong),
136
+ ("dwMemoryLoad", ctypes.c_ulong),
137
+ ("ullTotalPhys", ctypes.c_ulonglong),
138
+ ("ullAvailPhys", ctypes.c_ulonglong),
139
+ ("ullTotalPageFile", ctypes.c_ulonglong),
140
+ ("ullAvailPageFile", ctypes.c_ulonglong),
141
+ ("ullTotalVirtual", ctypes.c_ulonglong),
142
+ ("ullAvailVirtual", ctypes.c_ulonglong),
143
+ ("sullAvailExtendedVirtual", ctypes.c_ulonglong),
144
+ ]
145
+ status = _MemStatus()
146
+ status.dwLength = ctypes.sizeof(_MemStatus)
147
+ if ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status)):
148
+ return int(status.ullAvailPhys)
149
+ return None
150
+ try:
151
+ pages = os.sysconf("SC_AVPHYS_PAGES")
152
+ page_size = os.sysconf("SC_PAGE_SIZE")
153
+ return int(pages * page_size)
154
+ except Exception:
155
+ return None
156
+
157
+ def _available_cuda_memory(self) -> Optional[int]:
158
+ """Get available CUDA memory in bytes."""
159
+ if not torch.cuda.is_available():
160
+ return None
161
+ try:
162
+ free_mem, _total_mem = torch.cuda.mem_get_info()
163
+ except Exception:
164
+ return None
165
+ return int(free_mem)
166
+
167
+ def _estimate_sample_bytes(self, dataset) -> Optional[int]:
168
+ """Estimate memory per sample in bytes."""
169
+ try:
170
+ if len(dataset) == 0:
171
+ return None
172
+ sample = dataset[0]
173
+ except Exception:
174
+ return None
175
+
176
+ def _bytes(obj) -> int:
177
+ if obj is None:
178
+ return 0
179
+ if torch.is_tensor(obj):
180
+ return int(obj.element_size() * obj.nelement())
181
+ if isinstance(obj, np.ndarray):
182
+ return int(obj.nbytes)
183
+ if isinstance(obj, (list, tuple)):
184
+ return int(sum(_bytes(item) for item in obj))
185
+ if isinstance(obj, dict):
186
+ return int(sum(_bytes(item) for item in obj.values()))
187
+ return 0
188
+
189
+ sample_bytes = _bytes(sample)
190
+ return int(sample_bytes) if sample_bytes > 0 else None
191
+
192
+ def _cap_batch_size_by_memory(self, dataset, batch_size: int, profile: str) -> int:
193
+ """Cap batch size based on available memory."""
194
+ if batch_size <= 1:
195
+ return batch_size
196
+ sample_bytes = self._estimate_sample_bytes(dataset)
197
+ if sample_bytes is None:
198
+ return batch_size
199
+ device_type = self._device_type()
200
+ if device_type == "cuda":
201
+ available = self._available_cuda_memory()
202
+ if available is None:
203
+ return batch_size
204
+ if profile == "throughput":
205
+ budget_ratio = 0.8
206
+ overhead = 8.0
207
+ elif profile == "memory_saving":
208
+ budget_ratio = 0.5
209
+ overhead = 14.0
210
+ else:
211
+ budget_ratio = 0.6
212
+ overhead = 12.0
213
+ else:
214
+ available = self._available_system_memory()
215
+ if available is None:
216
+ return batch_size
217
+ if profile == "throughput":
218
+ budget_ratio = 0.4
219
+ overhead = 1.8
220
+ elif profile == "memory_saving":
221
+ budget_ratio = 0.25
222
+ overhead = 3.0
223
+ else:
224
+ budget_ratio = 0.3
225
+ overhead = 2.6
226
+ budget = int(available * budget_ratio)
227
+ per_sample = int(sample_bytes * overhead)
228
+ if per_sample <= 0:
229
+ return batch_size
230
+ max_batch = max(1, int(budget // per_sample))
231
+ if max_batch < batch_size:
232
+ print(
233
+ f">>> Memory cap: batch_size {batch_size} -> {max_batch} "
234
+ f"(per_sample~{sample_bytes}B, budget~{budget // (1024**2)}MB)"
235
+ )
236
+ return min(batch_size, max_batch)
237
+
238
+ def _resolve_num_workers(self, max_workers: int, profile: Optional[str] = None) -> int:
239
+ """Determine number of DataLoader workers."""
240
+ if os.name == 'nt':
241
+ return 0
242
+ override = getattr(self, "dataloader_workers", None)
243
+ if override is None:
244
+ override = os.environ.get("BAYESOPT_DATALOADER_WORKERS")
245
+ if override is not None:
246
+ try:
247
+ return max(0, int(override))
248
+ except (TypeError, ValueError):
249
+ pass
250
+ if getattr(self, "is_ddp_enabled", False):
251
+ return 0
252
+ profile = profile or self._resolve_resource_profile()
253
+ if profile == "memory_saving":
254
+ return 0
255
+ worker_cap = min(int(max_workers), os.cpu_count() or 1)
256
+ if self._device_type() == "mps":
257
+ worker_cap = min(worker_cap, 2)
258
+ return worker_cap
259
+
260
+ def _build_dataloader(self,
261
+ dataset,
262
+ N: int,
263
+ base_bs_gpu: tuple,
264
+ base_bs_cpu: tuple,
265
+ min_bs: int = 64,
266
+ target_effective_cuda: int = 1024,
267
+ target_effective_cpu: int = 512,
268
+ large_threshold: int = 200_000,
269
+ mid_threshold: int = 50_000):
270
+ """Build DataLoader with adaptive batch size and worker configuration.
271
+
272
+ Returns:
273
+ Tuple of (dataloader, accum_steps)
274
+ """
275
+ profile = self._resolve_resource_profile()
276
+ self._log_resource_summary_once(profile)
277
+ data_size = int(N) if N is not None else len(dataset)
278
+ gpu_large, gpu_mid, gpu_small = base_bs_gpu
279
+ cpu_mid, cpu_small = base_bs_cpu
280
+
281
+ device_type = self._device_type()
282
+ is_ddp = bool(getattr(self, "is_ddp_enabled", False))
283
+ if device_type == 'cuda':
284
+ # Only scale batch size by GPU count when DDP is enabled.
285
+ # In single-process (non-DDP) mode, large multi-GPU nodes can
286
+ # still OOM on RAM/VRAM if we scale by device_count.
287
+ device_count = 1
288
+ if is_ddp:
289
+ device_count = torch.cuda.device_count()
290
+ if device_count > 1:
291
+ min_bs = min_bs * device_count
292
+ print(
293
+ f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")
294
+
295
+ if data_size > large_threshold:
296
+ base_bs = gpu_large * device_count
297
+ elif data_size > mid_threshold:
298
+ base_bs = gpu_mid * device_count
299
+ else:
300
+ base_bs = gpu_small * device_count
301
+ else:
302
+ base_bs = cpu_mid if data_size > mid_threshold else cpu_small
303
+
304
+ batch_size = compute_batch_size(
305
+ data_size=data_size,
306
+ learning_rate=self.learning_rate,
307
+ batch_num=self.batch_num,
308
+ minimum=min_bs
309
+ )
310
+ batch_size = min(batch_size, base_bs, data_size)
311
+ batch_size = self._cap_batch_size_by_memory(
312
+ dataset, batch_size, profile)
313
+
314
+ target_effective_bs = target_effective_cuda if device_type == 'cuda' else target_effective_cpu
315
+ world_size = 1
316
+ if is_ddp:
317
+ world_size = getattr(self, "world_size", None)
318
+ world_size = max(1, world_size or DistributedUtils.world_size())
319
+ target_effective_bs = max(1, target_effective_bs // world_size)
320
+ samples_per_rank = math.ceil(
321
+ data_size / max(1, world_size)) if world_size > 1 else data_size
322
+ steps_per_epoch = max(
323
+ 1, math.ceil(samples_per_rank / max(1, batch_size)))
324
+ desired_accum = max(1, target_effective_bs // max(1, batch_size))
325
+ accum_steps = max(1, min(desired_accum, steps_per_epoch))
326
+
327
+ workers = self._resolve_num_workers(8, profile=profile)
328
+ prefetch_factor = None
329
+ if workers > 0:
330
+ prefetch_factor = 4 if profile == "throughput" else 2
331
+ persistent = workers > 0 and profile != "memory_saving"
332
+ print(
333
+ f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, "
334
+ f"Workers={workers}, Prefetch={prefetch_factor or 'off'}, Profile={profile}")
335
+ sampler = None
336
+ use_distributed_sampler = bool(
337
+ dist.is_initialized() and getattr(self, "is_ddp_enabled", False)
338
+ )
339
+ if use_distributed_sampler:
340
+ sampler = DistributedSampler(dataset, shuffle=True)
341
+ shuffle = False
342
+ else:
343
+ shuffle = True
344
+
345
+ dataloader = DataLoader(
346
+ dataset,
347
+ batch_size=batch_size,
348
+ shuffle=shuffle,
349
+ sampler=sampler,
350
+ num_workers=workers,
351
+ pin_memory=(device_type == 'cuda'),
352
+ persistent_workers=persistent,
353
+ **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
354
+ )
355
+ self.dataloader_sampler = sampler
356
+ return dataloader, accum_steps
357
+
358
+ def _build_val_dataloader(self, dataset, train_dataloader, accum_steps):
359
+ """Build validation DataLoader."""
360
+ profile = self._resolve_resource_profile()
361
+ val_bs = accum_steps * train_dataloader.batch_size
362
+ val_workers = self._resolve_num_workers(4, profile=profile)
363
+ prefetch_factor = None
364
+ if val_workers > 0:
365
+ prefetch_factor = 2
366
+ return DataLoader(
367
+ dataset,
368
+ batch_size=val_bs,
369
+ shuffle=False,
370
+ num_workers=val_workers,
371
+ pin_memory=(self._device_type() == 'cuda'),
372
+ persistent_workers=(val_workers > 0 and profile != "memory_saving"),
373
+ **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
374
+ )
375
+
376
+ def _compute_losses(self, y_pred, y_true, apply_softplus: bool = False):
377
+ """Compute per-sample losses based on task type."""
378
+ task = getattr(self, "task_type", "regression")
379
+ if task == 'classification':
380
+ loss_fn = nn.BCEWithLogitsLoss(reduction='none')
381
+ return loss_fn(y_pred, y_true).view(-1)
382
+ loss_name = normalize_loss_name(
383
+ getattr(self, "loss_name", None), task_type="regression"
384
+ )
385
+ if loss_name == "auto":
386
+ model_name = getattr(self, "model_name", None) or getattr(self, "model_nme", "")
387
+ loss_name = infer_loss_name_from_model_name(model_name)
388
+ if apply_softplus:
389
+ y_pred = F.softplus(y_pred)
390
+ if loss_requires_positive(loss_name):
391
+ y_pred = torch.clamp(y_pred, min=1e-6)
392
+ power = resolve_tweedie_power(
393
+ loss_name, default=float(getattr(self, "tw_power", 1.5) or 1.5)
394
+ )
395
+ if power is None:
396
+ power = float(getattr(self, "tw_power", 1.5) or 1.5)
397
+ return tweedie_loss(y_pred, y_true, p=power).view(-1)
398
+ if loss_name == "mse":
399
+ return (y_pred - y_true).pow(2).view(-1)
400
+ if loss_name == "mae":
401
+ return (y_pred - y_true).abs().view(-1)
402
+ raise ValueError(f"Unsupported loss_name '{loss_name}' for regression.")
403
+
404
+ def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
405
+ """Compute weighted loss."""
406
+ losses = self._compute_losses(
407
+ y_pred, y_true, apply_softplus=apply_softplus)
408
+ weighted_loss = (losses * weights.view(-1)).sum() / \
409
+ torch.clamp(weights.sum(), min=EPS)
410
+ return weighted_loss
411
+
412
+ def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model,
413
+ ignore_keys: Optional[List[str]] = None):
414
+ """Update early stopping state."""
415
+ if val_loss < best_loss:
416
+ ignore_keys = ignore_keys or []
417
+ base_module = model.module if hasattr(model, "module") else model
418
+ state_dict = {
419
+ k: (v.clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v))
420
+ for k, v in base_module.state_dict().items()
421
+ if not any(k.startswith(ignore_key) for ignore_key in ignore_keys)
422
+ }
423
+ return val_loss, state_dict, 0, False
424
+ patience_counter += 1
425
+ should_stop = best_state is not None and patience_counter >= getattr(
426
+ self, "patience", 0)
427
+ return best_loss, best_state, patience_counter, should_stop
428
+
429
+ def _train_model(self,
430
+ model,
431
+ dataloader,
432
+ accum_steps,
433
+ optimizer,
434
+ scaler,
435
+ forward_fn,
436
+ val_forward_fn=None,
437
+ apply_softplus: bool = False,
438
+ clip_fn=None,
439
+ trial: Optional[optuna.trial.Trial] = None,
440
+ loss_curve_path: Optional[str] = None):
441
+ """Generic training loop with AMP, DDP, and early stopping support.
442
+
443
+ Returns:
444
+ Tuple of (best_state_dict, history)
445
+ """
446
+ device_type = self._device_type()
447
+ best_loss = float('inf')
448
+ best_state = None
449
+ patience_counter = 0
450
+ stop_training = False
451
+ train_history: List[float] = []
452
+ val_history: List[float] = []
453
+
454
+ is_ddp_model = isinstance(model, DDP)
455
+ use_collectives = dist.is_initialized() and is_ddp_model
456
+
457
+ for epoch in range(1, getattr(self, "epochs", 1) + 1):
458
+ epoch_start_ts = time.time()
459
+ val_weighted_loss = None
460
+ if hasattr(self, 'dataloader_sampler') and self.dataloader_sampler is not None:
461
+ self.dataloader_sampler.set_epoch(epoch)
462
+
463
+ model.train()
464
+ optimizer.zero_grad()
465
+
466
+ epoch_loss_sum = None
467
+ epoch_weight_sum = None
468
+ for step, batch in enumerate(dataloader):
469
+ is_update_step = ((step + 1) % accum_steps == 0) or \
470
+ ((step + 1) == len(dataloader))
471
+ sync_cm = model.no_sync if (
472
+ is_ddp_model and not is_update_step) else nullcontext
473
+
474
+ with sync_cm():
475
+ with autocast(enabled=(device_type == 'cuda')):
476
+ y_pred, y_true, w = forward_fn(batch)
477
+ weighted_loss = self._compute_weighted_loss(
478
+ y_pred, y_true, w, apply_softplus=apply_softplus)
479
+ loss_for_backward = weighted_loss / accum_steps
480
+
481
+ batch_weight = torch.clamp(
482
+ w.detach().sum(), min=EPS).to(dtype=torch.float32)
483
+ loss_val = weighted_loss.detach().to(dtype=torch.float32)
484
+ if epoch_loss_sum is None:
485
+ epoch_loss_sum = torch.zeros(
486
+ (), device=batch_weight.device, dtype=torch.float32)
487
+ epoch_weight_sum = torch.zeros(
488
+ (), device=batch_weight.device, dtype=torch.float32)
489
+ epoch_loss_sum = epoch_loss_sum + loss_val * batch_weight
490
+ epoch_weight_sum = epoch_weight_sum + batch_weight
491
+ scaler.scale(loss_for_backward).backward()
492
+
493
+ if is_update_step:
494
+ if clip_fn is not None:
495
+ clip_fn()
496
+ scaler.step(optimizer)
497
+ scaler.update()
498
+ optimizer.zero_grad()
499
+
500
+ if epoch_loss_sum is None or epoch_weight_sum is None:
501
+ train_epoch_loss = 0.0
502
+ else:
503
+ train_epoch_loss = (
504
+ epoch_loss_sum / torch.clamp(epoch_weight_sum, min=EPS)
505
+ ).item()
506
+ train_history.append(float(train_epoch_loss))
507
+
508
+ if val_forward_fn is not None:
509
+ should_compute_val = (not dist.is_initialized()
510
+ or DistributedUtils.is_main_process())
511
+ val_device = self._resolve_device()
512
+ loss_tensor_device = val_device if device_type == 'cuda' else torch.device(
513
+ "cpu")
514
+ val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
515
+
516
+ if should_compute_val:
517
+ model.eval()
518
+ with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
519
+ val_result = val_forward_fn()
520
+ if isinstance(val_result, tuple) and len(val_result) == 3:
521
+ y_val_pred, y_val_true, w_val = val_result
522
+ val_weighted_loss = self._compute_weighted_loss(
523
+ y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
524
+ else:
525
+ val_weighted_loss = val_result
526
+ val_loss_tensor[0] = float(val_weighted_loss)
527
+
528
+ if use_collectives:
529
+ dist.broadcast(val_loss_tensor, src=0)
530
+ val_weighted_loss = float(val_loss_tensor.item())
531
+
532
+ val_history.append(val_weighted_loss)
533
+
534
+ best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
535
+ val_weighted_loss, best_loss, best_state, patience_counter, model)
536
+
537
+ prune_flag = False
538
+ is_main_rank = DistributedUtils.is_main_process()
539
+ if trial is not None and is_main_rank:
540
+ trial.report(val_weighted_loss, epoch)
541
+ prune_flag = trial.should_prune()
542
+
543
+ if use_collectives:
544
+ prune_device = self._resolve_device()
545
+ prune_tensor = torch.zeros(1, device=prune_device)
546
+ if is_main_rank:
547
+ prune_tensor.fill_(1 if prune_flag else 0)
548
+ dist.broadcast(prune_tensor, src=0)
549
+ prune_flag = bool(prune_tensor.item())
550
+
551
+ if prune_flag:
552
+ raise optuna.TrialPruned()
553
+
554
+ if stop_training:
555
+ break
556
+
557
+ should_log_epoch = (not dist.is_initialized()
558
+ or DistributedUtils.is_main_process())
559
+ if should_log_epoch:
560
+ elapsed = int(time.time() - epoch_start_ts)
561
+ if val_weighted_loss is None:
562
+ print(
563
+ f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
564
+ f"train_loss={float(train_epoch_loss):.6f} elapsed={elapsed}s",
565
+ flush=True,
566
+ )
567
+ else:
568
+ print(
569
+ f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
570
+ f"train_loss={float(train_epoch_loss):.6f} "
571
+ f"val_loss={float(val_weighted_loss):.6f} elapsed={elapsed}s",
572
+ flush=True,
573
+ )
574
+
575
+ if epoch % 10 == 0:
576
+ if torch.cuda.is_available():
577
+ torch.cuda.empty_cache()
578
+ gc.collect()
579
+
580
+ history = {"train": train_history, "val": val_history}
581
+ self._plot_loss_curve(history, loss_curve_path)
582
+ return best_state, history
583
+
584
+ def _plot_loss_curve(self, history: Dict[str, List[float]], save_path: Optional[str]) -> None:
585
+ """Plot training and validation loss curves."""
586
+ if not save_path:
587
+ return
588
+ if dist.is_initialized() and not DistributedUtils.is_main_process():
589
+ return
590
+ train_hist = history.get("train", []) if history else []
591
+ val_hist = history.get("val", []) if history else []
592
+ if not train_hist and not val_hist:
593
+ return
594
+ if plot_loss_curve_common is not None:
595
+ plot_loss_curve_common(
596
+ history=history,
597
+ title="Loss vs. Epoch",
598
+ save_path=save_path,
599
+ show=False,
600
+ )
601
+ else:
602
+ if plt is None:
603
+ _plot_skip("loss curve")
604
+ return
605
+ ensure_parent_dir(save_path)
606
+ epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
607
+ fig = plt.figure(figsize=(8, 4))
608
+ ax = fig.add_subplot(111)
609
+ if train_hist:
610
+ ax.plot(range(1, len(train_hist) + 1), train_hist,
611
+ label='Train Loss', color='tab:blue')
612
+ if val_hist:
613
+ ax.plot(range(1, len(val_hist) + 1), val_hist,
614
+ label='Validation Loss', color='tab:orange')
615
+ ax.set_xlabel('Epoch')
616
+ ax.set_ylabel('Weighted Loss')
617
+ ax.set_title('Loss vs. Epoch')
618
+ ax.grid(True, linestyle='--', alpha=0.3)
619
+ ax.legend()
620
+ plt.tight_layout()
621
+ plt.savefig(save_path, dpi=300)
622
+ plt.close(fig)
623
+ print(f"[Training] Loss curve saved to {save_path}")