ins-pricing 0.4.5-py3-none-any.whl → 0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +58 -46
  4. ins_pricing/cli/BayesOpt_incremental.py +77 -110
  5. ins_pricing/cli/Explain_Run.py +42 -23
  6. ins_pricing/cli/Explain_entry.py +551 -577
  7. ins_pricing/cli/Pricing_Run.py +42 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +51 -16
  9. ins_pricing/cli/utils/bootstrap.py +23 -0
  10. ins_pricing/cli/utils/cli_common.py +256 -256
  11. ins_pricing/cli/utils/cli_config.py +379 -360
  12. ins_pricing/cli/utils/import_resolver.py +375 -358
  13. ins_pricing/cli/utils/notebook_utils.py +256 -242
  14. ins_pricing/cli/watchdog_run.py +216 -198
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/app.py +132 -61
  17. ins_pricing/frontend/config_builder.py +33 -0
  18. ins_pricing/frontend/example_config.json +11 -0
  19. ins_pricing/frontend/example_workflows.py +1 -1
  20. ins_pricing/frontend/runner.py +340 -388
  21. ins_pricing/governance/__init__.py +20 -20
  22. ins_pricing/governance/release.py +159 -159
  23. ins_pricing/modelling/README.md +1 -1
  24. ins_pricing/modelling/__init__.py +147 -92
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
  37. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
  39. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
  40. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
  42. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
  43. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
  44. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  45. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  46. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
  47. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  48. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  49. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  50. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
  51. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  52. ins_pricing/modelling/explain/__init__.py +55 -55
  53. ins_pricing/modelling/explain/metrics.py +27 -174
  54. ins_pricing/modelling/explain/permutation.py +237 -237
  55. ins_pricing/modelling/plotting/__init__.py +40 -36
  56. ins_pricing/modelling/plotting/compat.py +228 -0
  57. ins_pricing/modelling/plotting/curves.py +572 -572
  58. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  59. ins_pricing/modelling/plotting/geo.py +362 -362
  60. ins_pricing/modelling/plotting/importance.py +121 -121
  61. ins_pricing/pricing/__init__.py +27 -27
  62. ins_pricing/pricing/factors.py +67 -56
  63. ins_pricing/production/__init__.py +35 -25
  64. ins_pricing/production/{predict.py → inference.py} +140 -57
  65. ins_pricing/production/monitoring.py +8 -21
  66. ins_pricing/reporting/__init__.py +11 -11
  67. ins_pricing/setup.py +1 -1
  68. ins_pricing/tests/production/test_inference.py +90 -0
  69. ins_pricing/utils/__init__.py +112 -78
  70. ins_pricing/utils/device.py +258 -237
  71. ins_pricing/utils/features.py +53 -0
  72. ins_pricing/utils/io.py +72 -0
  73. ins_pricing/utils/logging.py +34 -1
  74. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  75. ins_pricing/utils/metrics.py +158 -24
  76. ins_pricing/utils/numerics.py +76 -0
  77. ins_pricing/utils/paths.py +9 -1
  78. ins_pricing/utils/profiling.py +8 -4
  79. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
  80. ins_pricing-0.5.1.dist-info/RECORD +132 -0
  81. ins_pricing/modelling/core/BayesOpt.py +0 -146
  82. ins_pricing/modelling/core/__init__.py +0 -1
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  92. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  93. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
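The most disruptive change in 0.5.1 is the package layout: modelling/core/bayesopt is flattened to modelling/bayesopt, production/predict.py becomes production/inference.py, and the loss helpers move up to ins_pricing/utils/losses.py, while the old modelling/core modules are deleted. A minimal migration sketch for downstream imports, assuming no compatibility shims are kept for the removed paths (none appear in this listing):

# 0.4.5 layout (removed in 0.5.1)
import ins_pricing.modelling.core.bayesopt.core as bayesopt_core
import ins_pricing.production.predict as production_predict

# 0.5.1 layout
import ins_pricing.modelling.bayesopt.core as bayesopt_core
import ins_pricing.production.inference as production_inference
from ins_pricing.utils.losses import normalize_loss_name  # moved from modelling/core/bayesopt/utils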
@@ -1,623 +1,636 @@
1
- """PyTorch training mixin with resource management and training loops.
2
-
3
- This module provides the TorchTrainerMixin class which is used by
4
- PyTorch-based trainers (ResNet, FT, GNN) for:
5
- - Resource profiling and memory management
6
- - Batch size computation and optimization
7
- - DataLoader creation with DDP support
8
- - Generic training and validation loops with AMP
9
- - Early stopping and loss curve plotting
10
- """
11
-
12
- from __future__ import annotations
13
-
14
- import copy
15
- import ctypes
16
- import gc
17
- import math
18
- import os
19
- import time
20
- from contextlib import nullcontext
21
- from typing import Any, Callable, Dict, List, Optional, Tuple
22
-
23
- import numpy as np
24
- import optuna
25
- import torch
26
- import torch.nn as nn
27
- import torch.nn.functional as F
28
- import torch.distributed as dist
29
- from torch.cuda.amp import autocast, GradScaler
30
- from torch.nn.parallel import DistributedDataParallel as DDP
31
- from torch.utils.data import DataLoader, DistributedSampler
32
-
33
- # Try to import plotting functions
34
- try:
35
- import matplotlib
36
- if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
37
- matplotlib.use("Agg")
38
- import matplotlib.pyplot as plt
39
- _MPL_IMPORT_ERROR: Optional[BaseException] = None
40
- except Exception as exc:
41
- matplotlib = None
42
- plt = None
43
- _MPL_IMPORT_ERROR = exc
44
-
45
- try:
46
- from ....plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
47
- except Exception:
48
- try:
49
- from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
50
- except Exception:
51
- plot_loss_curve_common = None
52
-
53
- # Import from other utils modules
54
- from .constants import EPS, compute_batch_size, tweedie_loss, ensure_parent_dir
55
- from .losses import (
56
- infer_loss_name_from_model_name,
57
- loss_requires_positive,
58
- normalize_loss_name,
59
- resolve_tweedie_power,
60
- )
61
- from .distributed_utils import DistributedUtils
62
-
63
-
64
- def _plot_skip(label: str) -> None:
65
- """Print message when plot is skipped due to missing matplotlib."""
66
- if _MPL_IMPORT_ERROR is not None:
67
- print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
68
- else:
69
- print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
70
-
71
-
72
- class TorchTrainerMixin:
73
- """Shared helpers for PyTorch tabular trainers.
74
-
75
- Provides resource profiling, memory management, batch size optimization,
76
- and standardized training loops with mixed precision and DDP support.
77
-
78
- This mixin is used by ResNetTrainer, FTTrainer, and GNNTrainer.
79
- """
80
-
81
- def _device_type(self) -> str:
82
- """Get device type (cpu/cuda/mps)."""
83
- return getattr(self, "device", torch.device("cpu")).type
84
-
85
- def _resolve_resource_profile(self) -> str:
86
- """Determine resource usage profile.
87
-
88
- Returns:
89
- One of: 'throughput', 'memory_saving', or 'auto'
90
- """
91
- profile = getattr(self, "resource_profile", None)
92
- if not profile:
93
- profile = os.environ.get("BAYESOPT_RESOURCE_PROFILE", "auto")
94
- profile = str(profile).strip().lower()
95
- if profile in {"cpu", "mps", "cuda"}:
96
- profile = "auto"
97
- if profile not in {"auto", "throughput", "memory_saving"}:
98
- profile = "auto"
99
- if profile == "auto" and self._device_type() == "cuda":
100
- profile = "throughput"
101
- return profile
102
-
103
- def _log_resource_summary_once(self, profile: str) -> None:
104
- """Log resource configuration summary once."""
105
- if getattr(self, "_resource_summary_logged", False):
106
- return
107
- if dist.is_initialized() and not DistributedUtils.is_main_process():
108
- return
109
- self._resource_summary_logged = True
110
- device = getattr(self, "device", torch.device("cpu"))
111
- device_type = self._device_type()
112
- cpu_count = os.cpu_count() or 1
113
- cuda_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
114
- mps_available = bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available())
115
- ddp_enabled = bool(getattr(self, "is_ddp_enabled", False))
116
- data_parallel = bool(getattr(self, "use_data_parallel", False))
117
- print(
118
- f">>> Resource summary: device={device}, device_type={device_type}, "
119
- f"cpu_count={cpu_count}, cuda_count={cuda_count}, mps={mps_available}, "
120
- f"ddp={ddp_enabled}, data_parallel={data_parallel}, profile={profile}"
121
- )
122
-
123
- def _available_system_memory(self) -> Optional[int]:
124
- """Get available system RAM in bytes."""
125
- if os.name == "nt":
126
- class _MemStatus(ctypes.Structure):
127
- _fields_ = [
128
- ("dwLength", ctypes.c_ulong),
129
- ("dwMemoryLoad", ctypes.c_ulong),
130
- ("ullTotalPhys", ctypes.c_ulonglong),
131
- ("ullAvailPhys", ctypes.c_ulonglong),
132
- ("ullTotalPageFile", ctypes.c_ulonglong),
133
- ("ullAvailPageFile", ctypes.c_ulonglong),
134
- ("ullTotalVirtual", ctypes.c_ulonglong),
135
- ("ullAvailVirtual", ctypes.c_ulonglong),
136
- ("sullAvailExtendedVirtual", ctypes.c_ulonglong),
137
- ]
138
- status = _MemStatus()
139
- status.dwLength = ctypes.sizeof(_MemStatus)
140
- if ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status)):
141
- return int(status.ullAvailPhys)
142
- return None
143
- try:
144
- pages = os.sysconf("SC_AVPHYS_PAGES")
145
- page_size = os.sysconf("SC_PAGE_SIZE")
146
- return int(pages * page_size)
147
- except Exception:
148
- return None
149
-
150
- def _available_cuda_memory(self) -> Optional[int]:
151
- """Get available CUDA memory in bytes."""
152
- if not torch.cuda.is_available():
153
- return None
154
- try:
155
- free_mem, _total_mem = torch.cuda.mem_get_info()
156
- except Exception:
157
- return None
158
- return int(free_mem)
159
-
160
- def _estimate_sample_bytes(self, dataset) -> Optional[int]:
161
- """Estimate memory per sample in bytes."""
162
- try:
163
- if len(dataset) == 0:
164
- return None
165
- sample = dataset[0]
166
- except Exception:
167
- return None
168
-
169
- def _bytes(obj) -> int:
170
- if obj is None:
171
- return 0
172
- if torch.is_tensor(obj):
173
- return int(obj.element_size() * obj.nelement())
174
- if isinstance(obj, np.ndarray):
175
- return int(obj.nbytes)
176
- if isinstance(obj, (list, tuple)):
177
- return int(sum(_bytes(item) for item in obj))
178
- if isinstance(obj, dict):
179
- return int(sum(_bytes(item) for item in obj.values()))
180
- return 0
181
-
182
- sample_bytes = _bytes(sample)
183
- return int(sample_bytes) if sample_bytes > 0 else None
184
-
185
- def _cap_batch_size_by_memory(self, dataset, batch_size: int, profile: str) -> int:
186
- """Cap batch size based on available memory."""
187
- if batch_size <= 1:
188
- return batch_size
189
- sample_bytes = self._estimate_sample_bytes(dataset)
190
- if sample_bytes is None:
191
- return batch_size
192
- device_type = self._device_type()
193
- if device_type == "cuda":
194
- available = self._available_cuda_memory()
195
- if available is None:
196
- return batch_size
197
- if profile == "throughput":
198
- budget_ratio = 0.8
199
- overhead = 8.0
200
- elif profile == "memory_saving":
201
- budget_ratio = 0.5
202
- overhead = 14.0
203
- else:
204
- budget_ratio = 0.6
205
- overhead = 12.0
206
- else:
207
- available = self._available_system_memory()
208
- if available is None:
209
- return batch_size
210
- if profile == "throughput":
211
- budget_ratio = 0.4
212
- overhead = 1.8
213
- elif profile == "memory_saving":
214
- budget_ratio = 0.25
215
- overhead = 3.0
216
- else:
217
- budget_ratio = 0.3
218
- overhead = 2.6
219
- budget = int(available * budget_ratio)
220
- per_sample = int(sample_bytes * overhead)
221
- if per_sample <= 0:
222
- return batch_size
223
- max_batch = max(1, int(budget // per_sample))
224
- if max_batch < batch_size:
225
- print(
226
- f">>> Memory cap: batch_size {batch_size} -> {max_batch} "
227
- f"(per_sample~{sample_bytes}B, budget~{budget // (1024**2)}MB)"
228
- )
229
- return min(batch_size, max_batch)
230
-
231
- def _resolve_num_workers(self, max_workers: int, profile: Optional[str] = None) -> int:
232
- """Determine number of DataLoader workers."""
233
- if os.name == 'nt':
234
- return 0
235
- override = getattr(self, "dataloader_workers", None)
236
- if override is None:
237
- override = os.environ.get("BAYESOPT_DATALOADER_WORKERS")
238
- if override is not None:
239
- try:
240
- return max(0, int(override))
241
- except (TypeError, ValueError):
242
- pass
243
- if getattr(self, "is_ddp_enabled", False):
244
- return 0
245
- profile = profile or self._resolve_resource_profile()
246
- if profile == "memory_saving":
247
- return 0
248
- worker_cap = min(int(max_workers), os.cpu_count() or 1)
249
- if self._device_type() == "mps":
250
- worker_cap = min(worker_cap, 2)
251
- return worker_cap
252
-
253
- def _build_dataloader(self,
254
- dataset,
255
- N: int,
256
- base_bs_gpu: tuple,
257
- base_bs_cpu: tuple,
258
- min_bs: int = 64,
259
- target_effective_cuda: int = 1024,
260
- target_effective_cpu: int = 512,
261
- large_threshold: int = 200_000,
262
- mid_threshold: int = 50_000):
263
- """Build DataLoader with adaptive batch size and worker configuration.
264
-
265
- Returns:
266
- Tuple of (dataloader, accum_steps)
267
- """
268
- profile = self._resolve_resource_profile()
269
- self._log_resource_summary_once(profile)
270
- batch_size = compute_batch_size(
271
- data_size=len(dataset),
272
- learning_rate=self.learning_rate,
273
- batch_num=self.batch_num,
274
- minimum=min_bs
275
- )
276
- gpu_large, gpu_mid, gpu_small = base_bs_gpu
277
- cpu_mid, cpu_small = base_bs_cpu
278
-
279
- if self._device_type() == 'cuda':
280
- # Only scale batch size by GPU count when DDP is enabled.
281
- # In single-process (non-DDP) mode, large multi-GPU nodes can
282
- # still OOM on RAM/VRAM if we scale by device_count.
283
- device_count = 1
284
- if getattr(self, "is_ddp_enabled", False):
285
- device_count = torch.cuda.device_count()
286
- if device_count > 1:
287
- min_bs = min_bs * device_count
288
- print(
289
- f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")
290
-
291
- if N > large_threshold:
292
- base_bs = gpu_large * device_count
293
- elif N > mid_threshold:
294
- base_bs = gpu_mid * device_count
295
- else:
296
- base_bs = gpu_small * device_count
297
- else:
298
- base_bs = cpu_mid if N > mid_threshold else cpu_small
299
-
300
- batch_size = compute_batch_size(
301
- data_size=len(dataset),
302
- learning_rate=self.learning_rate,
303
- batch_num=self.batch_num,
304
- minimum=min_bs
305
- )
306
- batch_size = min(batch_size, base_bs, N)
307
- batch_size = self._cap_batch_size_by_memory(
308
- dataset, batch_size, profile)
309
-
310
- target_effective_bs = target_effective_cuda if self._device_type(
311
- ) == 'cuda' else target_effective_cpu
312
- if getattr(self, "is_ddp_enabled", False):
313
- world_size = max(1, DistributedUtils.world_size())
314
- target_effective_bs = max(1, target_effective_bs // world_size)
315
-
316
- world_size = getattr(self, "world_size", 1) if getattr(
317
- self, "is_ddp_enabled", False) else 1
318
- samples_per_rank = math.ceil(
319
- N / max(1, world_size)) if world_size > 1 else N
320
- steps_per_epoch = max(
321
- 1, math.ceil(samples_per_rank / max(1, batch_size)))
322
- desired_accum = max(1, target_effective_bs // max(1, batch_size))
323
- accum_steps = max(1, min(desired_accum, steps_per_epoch))
324
-
325
- workers = self._resolve_num_workers(8, profile=profile)
326
- prefetch_factor = None
327
- if workers > 0:
328
- prefetch_factor = 4 if profile == "throughput" else 2
329
- persistent = workers > 0 and profile != "memory_saving"
330
- print(
331
- f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, "
332
- f"Workers={workers}, Prefetch={prefetch_factor or 'off'}, Profile={profile}")
333
- sampler = None
334
- use_distributed_sampler = bool(
335
- dist.is_initialized() and getattr(self, "is_ddp_enabled", False)
336
- )
337
- if use_distributed_sampler:
338
- sampler = DistributedSampler(dataset, shuffle=True)
339
- shuffle = False
340
- else:
341
- shuffle = True
342
-
343
- dataloader = DataLoader(
344
- dataset,
345
- batch_size=batch_size,
346
- shuffle=shuffle,
347
- sampler=sampler,
348
- num_workers=workers,
349
- pin_memory=(self._device_type() == 'cuda'),
350
- persistent_workers=persistent,
351
- **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
352
- )
353
- return dataloader, accum_steps
354
-
355
- def _build_val_dataloader(self, dataset, train_dataloader, accum_steps):
356
- """Build validation DataLoader."""
357
- profile = self._resolve_resource_profile()
358
- val_bs = accum_steps * train_dataloader.batch_size
359
- val_workers = self._resolve_num_workers(4, profile=profile)
360
- prefetch_factor = None
361
- if val_workers > 0:
362
- prefetch_factor = 2
363
- return DataLoader(
364
- dataset,
365
- batch_size=val_bs,
366
- shuffle=False,
367
- num_workers=val_workers,
368
- pin_memory=(self._device_type() == 'cuda'),
369
- persistent_workers=(val_workers > 0 and profile != "memory_saving"),
370
- **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
371
- )
372
-
373
- def _compute_losses(self, y_pred, y_true, apply_softplus: bool = False):
374
- """Compute per-sample losses based on task type."""
375
- task = getattr(self, "task_type", "regression")
376
- if task == 'classification':
377
- loss_fn = nn.BCEWithLogitsLoss(reduction='none')
378
- return loss_fn(y_pred, y_true).view(-1)
379
- loss_name = normalize_loss_name(
380
- getattr(self, "loss_name", None), task_type="regression"
381
- )
382
- if loss_name == "auto":
383
- loss_name = infer_loss_name_from_model_name(getattr(self, "model_nme", ""))
384
- if apply_softplus:
385
- y_pred = F.softplus(y_pred)
386
- if loss_requires_positive(loss_name):
387
- y_pred = torch.clamp(y_pred, min=1e-6)
388
- power = resolve_tweedie_power(
389
- loss_name, default=float(getattr(self, "tw_power", 1.5) or 1.5)
390
- )
391
- if power is None:
392
- power = float(getattr(self, "tw_power", 1.5) or 1.5)
393
- return tweedie_loss(y_pred, y_true, p=power).view(-1)
394
- if loss_name == "mse":
395
- return (y_pred - y_true).pow(2).view(-1)
396
- if loss_name == "mae":
397
- return (y_pred - y_true).abs().view(-1)
398
- raise ValueError(f"Unsupported loss_name '{loss_name}' for regression.")
399
-
400
- def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
401
- """Compute weighted loss."""
402
- losses = self._compute_losses(
403
- y_pred, y_true, apply_softplus=apply_softplus)
404
- weighted_loss = (losses * weights.view(-1)).sum() / \
405
- torch.clamp(weights.sum(), min=EPS)
406
- return weighted_loss
407
-
408
- def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model,
409
- ignore_keys: Optional[List[str]] = None):
410
- """Update early stopping state."""
411
- if val_loss < best_loss:
412
- ignore_keys = ignore_keys or []
413
- base_module = model.module if hasattr(model, "module") else model
414
- state_dict = {
415
- k: (v.clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v))
416
- for k, v in base_module.state_dict().items()
417
- if not any(k.startswith(ignore_key) for ignore_key in ignore_keys)
418
- }
419
- return val_loss, state_dict, 0, False
420
- patience_counter += 1
421
- should_stop = best_state is not None and patience_counter >= getattr(
422
- self, "patience", 0)
423
- return best_loss, best_state, patience_counter, should_stop
424
-
425
- def _train_model(self,
426
- model,
427
- dataloader,
428
- accum_steps,
429
- optimizer,
430
- scaler,
431
- forward_fn,
432
- val_forward_fn=None,
433
- apply_softplus: bool = False,
434
- clip_fn=None,
435
- trial: Optional[optuna.trial.Trial] = None,
436
- loss_curve_path: Optional[str] = None):
437
- """Generic training loop with AMP, DDP, and early stopping support.
438
-
439
- Returns:
440
- Tuple of (best_state_dict, history)
441
- """
442
- device_type = self._device_type()
443
- best_loss = float('inf')
444
- best_state = None
445
- patience_counter = 0
446
- stop_training = False
447
- train_history: List[float] = []
448
- val_history: List[float] = []
449
-
450
- is_ddp_model = isinstance(model, DDP)
451
- use_collectives = dist.is_initialized() and is_ddp_model
452
-
453
- for epoch in range(1, getattr(self, "epochs", 1) + 1):
454
- epoch_start_ts = time.time()
455
- val_weighted_loss = None
456
- if hasattr(self, 'dataloader_sampler') and self.dataloader_sampler is not None:
457
- self.dataloader_sampler.set_epoch(epoch)
458
-
459
- model.train()
460
- optimizer.zero_grad()
461
-
462
- epoch_loss_sum = None
463
- epoch_weight_sum = None
464
- for step, batch in enumerate(dataloader):
465
- is_update_step = ((step + 1) % accum_steps == 0) or \
466
- ((step + 1) == len(dataloader))
467
- sync_cm = model.no_sync if (
468
- is_ddp_model and not is_update_step) else nullcontext
469
-
470
- with sync_cm():
471
- with autocast(enabled=(device_type == 'cuda')):
472
- y_pred, y_true, w = forward_fn(batch)
473
- weighted_loss = self._compute_weighted_loss(
474
- y_pred, y_true, w, apply_softplus=apply_softplus)
475
- loss_for_backward = weighted_loss / accum_steps
476
-
477
- batch_weight = torch.clamp(
478
- w.detach().sum(), min=EPS).to(dtype=torch.float32)
479
- loss_val = weighted_loss.detach().to(dtype=torch.float32)
480
- if epoch_loss_sum is None:
481
- epoch_loss_sum = torch.zeros(
482
- (), device=batch_weight.device, dtype=torch.float32)
483
- epoch_weight_sum = torch.zeros(
484
- (), device=batch_weight.device, dtype=torch.float32)
485
- epoch_loss_sum = epoch_loss_sum + loss_val * batch_weight
486
- epoch_weight_sum = epoch_weight_sum + batch_weight
487
- scaler.scale(loss_for_backward).backward()
488
-
489
- if is_update_step:
490
- if clip_fn is not None:
491
- clip_fn()
492
- scaler.step(optimizer)
493
- scaler.update()
494
- optimizer.zero_grad()
495
-
496
- if epoch_loss_sum is None or epoch_weight_sum is None:
497
- train_epoch_loss = 0.0
498
- else:
499
- train_epoch_loss = (
500
- epoch_loss_sum / torch.clamp(epoch_weight_sum, min=EPS)
501
- ).item()
502
- train_history.append(float(train_epoch_loss))
503
-
504
- if val_forward_fn is not None:
505
- should_compute_val = (not dist.is_initialized()
506
- or DistributedUtils.is_main_process())
507
- val_device = getattr(self, "device", torch.device("cpu"))
508
- if not isinstance(val_device, torch.device):
509
- val_device = torch.device(val_device)
510
- loss_tensor_device = val_device if device_type == 'cuda' else torch.device(
511
- "cpu")
512
- val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
513
-
514
- if should_compute_val:
515
- model.eval()
516
- with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
517
- val_result = val_forward_fn()
518
- if isinstance(val_result, tuple) and len(val_result) == 3:
519
- y_val_pred, y_val_true, w_val = val_result
520
- val_weighted_loss = self._compute_weighted_loss(
521
- y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
522
- else:
523
- val_weighted_loss = val_result
524
- val_loss_tensor[0] = float(val_weighted_loss)
525
-
526
- if use_collectives:
527
- dist.broadcast(val_loss_tensor, src=0)
528
- val_weighted_loss = float(val_loss_tensor.item())
529
-
530
- val_history.append(val_weighted_loss)
531
-
532
- best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
533
- val_weighted_loss, best_loss, best_state, patience_counter, model)
534
-
535
- prune_flag = False
536
- is_main_rank = DistributedUtils.is_main_process()
537
- if trial is not None and is_main_rank:
538
- trial.report(val_weighted_loss, epoch)
539
- prune_flag = trial.should_prune()
540
-
541
- if use_collectives:
542
- prune_device = getattr(self, "device", torch.device("cpu"))
543
- if not isinstance(prune_device, torch.device):
544
- prune_device = torch.device(prune_device)
545
- prune_tensor = torch.zeros(1, device=prune_device)
546
- if is_main_rank:
547
- prune_tensor.fill_(1 if prune_flag else 0)
548
- dist.broadcast(prune_tensor, src=0)
549
- prune_flag = bool(prune_tensor.item())
550
-
551
- if prune_flag:
552
- raise optuna.TrialPruned()
553
-
554
- if stop_training:
555
- break
556
-
557
- should_log_epoch = (not dist.is_initialized()
558
- or DistributedUtils.is_main_process())
559
- if should_log_epoch:
560
- elapsed = int(time.time() - epoch_start_ts)
561
- if val_weighted_loss is None:
562
- print(
563
- f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
564
- f"train_loss={float(train_epoch_loss):.6f} elapsed={elapsed}s",
565
- flush=True,
566
- )
567
- else:
568
- print(
569
- f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
570
- f"train_loss={float(train_epoch_loss):.6f} "
571
- f"val_loss={float(val_weighted_loss):.6f} elapsed={elapsed}s",
572
- flush=True,
573
- )
574
-
575
- if epoch % 10 == 0:
576
- if torch.cuda.is_available():
577
- torch.cuda.empty_cache()
578
- gc.collect()
579
-
580
- history = {"train": train_history, "val": val_history}
581
- self._plot_loss_curve(history, loss_curve_path)
582
- return best_state, history
583
-
584
- def _plot_loss_curve(self, history: Dict[str, List[float]], save_path: Optional[str]) -> None:
585
- """Plot training and validation loss curves."""
586
- if not save_path:
587
- return
588
- if dist.is_initialized() and not DistributedUtils.is_main_process():
589
- return
590
- train_hist = history.get("train", []) if history else []
591
- val_hist = history.get("val", []) if history else []
592
- if not train_hist and not val_hist:
593
- return
594
- if plot_loss_curve_common is not None:
595
- plot_loss_curve_common(
596
- history=history,
597
- title="Loss vs. Epoch",
598
- save_path=save_path,
599
- show=False,
600
- )
601
- else:
602
- if plt is None:
603
- _plot_skip("loss curve")
604
- return
605
- ensure_parent_dir(save_path)
606
- epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
607
- fig = plt.figure(figsize=(8, 4))
608
- ax = fig.add_subplot(111)
609
- if train_hist:
610
- ax.plot(range(1, len(train_hist) + 1), train_hist,
611
- label='Train Loss', color='tab:blue')
612
- if val_hist:
613
- ax.plot(range(1, len(val_hist) + 1), val_hist,
614
- label='Validation Loss', color='tab:orange')
615
- ax.set_xlabel('Epoch')
616
- ax.set_ylabel('Weighted Loss')
617
- ax.set_title('Loss vs. Epoch')
618
- ax.grid(True, linestyle='--', alpha=0.3)
619
- ax.legend()
620
- plt.tight_layout()
621
- plt.savefig(save_path, dpi=300)
622
- plt.close(fig)
623
- print(f"[Training] Loss curve saved to {save_path}")
1
+ """PyTorch training mixin with resource management and training loops.
2
+
3
+ This module provides the TorchTrainerMixin class which is used by
4
+ PyTorch-based trainers (ResNet, FT, GNN) for:
5
+ - Resource profiling and memory management
6
+ - Batch size computation and optimization
7
+ - DataLoader creation with DDP support
8
+ - Generic training and validation loops with AMP
9
+ - Early stopping and loss curve plotting
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import copy
15
+ import ctypes
16
+ import gc
17
+ import math
18
+ import os
19
+ import time
20
+ from contextlib import nullcontext
21
+ from typing import Dict, List, Optional
22
+
23
+ import numpy as np
24
+ import optuna
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ import torch.distributed as dist
29
+ from torch.cuda.amp import autocast
30
+ from torch.nn.parallel import DistributedDataParallel as DDP
31
+ from torch.utils.data import DataLoader, DistributedSampler
32
+
33
+ # Try to import plotting functions
34
+ try:
35
+ import matplotlib
36
+ if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
37
+ matplotlib.use("Agg")
38
+ import matplotlib.pyplot as plt
39
+ _MPL_IMPORT_ERROR: Optional[BaseException] = None
40
+ except Exception as exc:
41
+ matplotlib = None
42
+ plt = None
43
+ _MPL_IMPORT_ERROR = exc
44
+
45
+ try:
46
+ from ins_pricing.modelling.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
47
+ except Exception:
48
+ try:
49
+ from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
50
+ except Exception:
51
+ plot_loss_curve_common = None
52
+
53
+ # Import from other utils modules
54
+ from ins_pricing.utils import (
55
+ EPS,
56
+ compute_batch_size,
57
+ tweedie_loss,
58
+ ensure_parent_dir,
59
+ get_logger,
60
+ log_print,
61
+ )
62
+ from ins_pricing.utils.losses import (
63
+ infer_loss_name_from_model_name,
64
+ loss_requires_positive,
65
+ normalize_loss_name,
66
+ resolve_tweedie_power,
67
+ )
68
+ from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
69
+
70
+ _logger = get_logger("ins_pricing.modelling.bayesopt.utils.torch_trainer_mixin")
71
+
72
+
73
+ def _log(*args, **kwargs) -> None:
74
+ log_print(_logger, *args, **kwargs)
75
+
76
+
77
+ def _plot_skip(label: str) -> None:
78
+ """Print message when plot is skipped due to missing matplotlib."""
79
+ if _MPL_IMPORT_ERROR is not None:
80
+ _log(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
81
+ else:
82
+ _log(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
83
+
84
+
85
+ class TorchTrainerMixin:
86
+ """Shared helpers for PyTorch tabular trainers.
87
+
88
+ Provides resource profiling, memory management, batch size optimization,
89
+ and standardized training loops with mixed precision and DDP support.
90
+
91
+ This mixin is used by ResNetTrainer, FTTrainer, and GNNTrainer.
92
+ """
93
+
94
+ def _resolve_device(self) -> torch.device:
95
+ """Resolve device to a torch.device instance."""
96
+ device = getattr(self, "device", None)
97
+ if device is None:
98
+ return torch.device("cpu")
99
+ return device if isinstance(device, torch.device) else torch.device(device)
100
+
101
+ def _device_type(self) -> str:
102
+ """Get device type (cpu/cuda/mps)."""
103
+ return self._resolve_device().type
104
+
105
+ def _resolve_resource_profile(self) -> str:
106
+ """Determine resource usage profile.
107
+
108
+ Returns:
109
+ One of: 'throughput', 'memory_saving', or 'auto'
110
+ """
111
+ profile = getattr(self, "resource_profile", None)
112
+ if not profile:
113
+ profile = os.environ.get("BAYESOPT_RESOURCE_PROFILE", "auto")
114
+ profile = str(profile).strip().lower()
115
+ if profile in {"cpu", "mps", "cuda"}:
116
+ profile = "auto"
117
+ if profile not in {"auto", "throughput", "memory_saving"}:
118
+ profile = "auto"
119
+ if profile == "auto" and self._device_type() == "cuda":
120
+ profile = "throughput"
121
+ return profile
122
+
123
+ def _log_resource_summary_once(self, profile: str) -> None:
124
+ """Log resource configuration summary once."""
125
+ if getattr(self, "_resource_summary_logged", False):
126
+ return
127
+ if dist.is_initialized() and not DistributedUtils.is_main_process():
128
+ return
129
+ self._resource_summary_logged = True
130
+ device = self._resolve_device()
131
+ device_type = self._device_type()
132
+ cpu_count = os.cpu_count() or 1
133
+ cuda_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
134
+ mps_available = bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available())
135
+ ddp_enabled = bool(getattr(self, "is_ddp_enabled", False))
136
+ data_parallel = bool(getattr(self, "use_data_parallel", False))
137
+ _log(
138
+ f">>> Resource summary: device={device}, device_type={device_type}, "
139
+ f"cpu_count={cpu_count}, cuda_count={cuda_count}, mps={mps_available}, "
140
+ f"ddp={ddp_enabled}, data_parallel={data_parallel}, profile={profile}"
141
+ )
142
+
143
+ def _available_system_memory(self) -> Optional[int]:
144
+ """Get available system RAM in bytes."""
145
+ if os.name == "nt":
146
+ class _MemStatus(ctypes.Structure):
147
+ _fields_ = [
148
+ ("dwLength", ctypes.c_ulong),
149
+ ("dwMemoryLoad", ctypes.c_ulong),
150
+ ("ullTotalPhys", ctypes.c_ulonglong),
151
+ ("ullAvailPhys", ctypes.c_ulonglong),
152
+ ("ullTotalPageFile", ctypes.c_ulonglong),
153
+ ("ullAvailPageFile", ctypes.c_ulonglong),
154
+ ("ullTotalVirtual", ctypes.c_ulonglong),
155
+ ("ullAvailVirtual", ctypes.c_ulonglong),
156
+ ("sullAvailExtendedVirtual", ctypes.c_ulonglong),
157
+ ]
158
+ status = _MemStatus()
159
+ status.dwLength = ctypes.sizeof(_MemStatus)
160
+ if ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status)):
161
+ return int(status.ullAvailPhys)
162
+ return None
163
+ try:
164
+ pages = os.sysconf("SC_AVPHYS_PAGES")
165
+ page_size = os.sysconf("SC_PAGE_SIZE")
166
+ return int(pages * page_size)
167
+ except Exception:
168
+ return None
169
+
170
+ def _available_cuda_memory(self) -> Optional[int]:
171
+ """Get available CUDA memory in bytes."""
172
+ if not torch.cuda.is_available():
173
+ return None
174
+ try:
175
+ free_mem, _total_mem = torch.cuda.mem_get_info()
176
+ except Exception:
177
+ return None
178
+ return int(free_mem)
179
+
180
+ def _estimate_sample_bytes(self, dataset) -> Optional[int]:
181
+ """Estimate memory per sample in bytes."""
182
+ try:
183
+ if len(dataset) == 0:
184
+ return None
185
+ sample = dataset[0]
186
+ except Exception:
187
+ return None
188
+
189
+ def _bytes(obj) -> int:
190
+ if obj is None:
191
+ return 0
192
+ if torch.is_tensor(obj):
193
+ return int(obj.element_size() * obj.nelement())
194
+ if isinstance(obj, np.ndarray):
195
+ return int(obj.nbytes)
196
+ if isinstance(obj, (list, tuple)):
197
+ return int(sum(_bytes(item) for item in obj))
198
+ if isinstance(obj, dict):
199
+ return int(sum(_bytes(item) for item in obj.values()))
200
+ return 0
201
+
202
+ sample_bytes = _bytes(sample)
203
+ return int(sample_bytes) if sample_bytes > 0 else None
204
+
205
+ def _cap_batch_size_by_memory(self, dataset, batch_size: int, profile: str) -> int:
206
+ """Cap batch size based on available memory."""
207
+ if batch_size <= 1:
208
+ return batch_size
209
+ sample_bytes = self._estimate_sample_bytes(dataset)
210
+ if sample_bytes is None:
211
+ return batch_size
212
+ device_type = self._device_type()
213
+ if device_type == "cuda":
214
+ available = self._available_cuda_memory()
215
+ if available is None:
216
+ return batch_size
217
+ if profile == "throughput":
218
+ budget_ratio = 0.8
219
+ overhead = 8.0
220
+ elif profile == "memory_saving":
221
+ budget_ratio = 0.5
222
+ overhead = 14.0
223
+ else:
224
+ budget_ratio = 0.6
225
+ overhead = 12.0
226
+ else:
227
+ available = self._available_system_memory()
228
+ if available is None:
229
+ return batch_size
230
+ if profile == "throughput":
231
+ budget_ratio = 0.4
232
+ overhead = 1.8
233
+ elif profile == "memory_saving":
234
+ budget_ratio = 0.25
235
+ overhead = 3.0
236
+ else:
237
+ budget_ratio = 0.3
238
+ overhead = 2.6
239
+ budget = int(available * budget_ratio)
240
+ per_sample = int(sample_bytes * overhead)
241
+ if per_sample <= 0:
242
+ return batch_size
243
+ max_batch = max(1, int(budget // per_sample))
244
+ if max_batch < batch_size:
245
+ _log(
246
+ f">>> Memory cap: batch_size {batch_size} -> {max_batch} "
247
+ f"(per_sample~{sample_bytes}B, budget~{budget // (1024**2)}MB)"
248
+ )
249
+ return min(batch_size, max_batch)
250
+
251
+ def _resolve_num_workers(self, max_workers: int, profile: Optional[str] = None) -> int:
252
+ """Determine number of DataLoader workers."""
253
+ if os.name == 'nt':
254
+ return 0
255
+ override = getattr(self, "dataloader_workers", None)
256
+ if override is None:
257
+ override = os.environ.get("BAYESOPT_DATALOADER_WORKERS")
258
+ if override is not None:
259
+ try:
260
+ return max(0, int(override))
261
+ except (TypeError, ValueError):
262
+ pass
263
+ if getattr(self, "is_ddp_enabled", False):
264
+ return 0
265
+ profile = profile or self._resolve_resource_profile()
266
+ if profile == "memory_saving":
267
+ return 0
268
+ worker_cap = min(int(max_workers), os.cpu_count() or 1)
269
+ if self._device_type() == "mps":
270
+ worker_cap = min(worker_cap, 2)
271
+ return worker_cap
272
+
273
+ def _build_dataloader(self,
274
+ dataset,
275
+ N: int,
276
+ base_bs_gpu: tuple,
277
+ base_bs_cpu: tuple,
278
+ min_bs: int = 64,
279
+ target_effective_cuda: int = 1024,
280
+ target_effective_cpu: int = 512,
281
+ large_threshold: int = 200_000,
282
+ mid_threshold: int = 50_000):
283
+ """Build DataLoader with adaptive batch size and worker configuration.
284
+
285
+ Returns:
286
+ Tuple of (dataloader, accum_steps)
287
+ """
288
+ profile = self._resolve_resource_profile()
289
+ self._log_resource_summary_once(profile)
290
+ data_size = int(N) if N is not None else len(dataset)
291
+ gpu_large, gpu_mid, gpu_small = base_bs_gpu
292
+ cpu_mid, cpu_small = base_bs_cpu
293
+
294
+ device_type = self._device_type()
295
+ is_ddp = bool(getattr(self, "is_ddp_enabled", False))
296
+ if device_type == 'cuda':
297
+ # Only scale batch size by GPU count when DDP is enabled.
298
+ # In single-process (non-DDP) mode, large multi-GPU nodes can
299
+ # still OOM on RAM/VRAM if we scale by device_count.
300
+ device_count = 1
301
+ if is_ddp:
302
+ device_count = torch.cuda.device_count()
303
+ if device_count > 1:
304
+ min_bs = min_bs * device_count
305
+ _log(
306
+ f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")
307
+
308
+ if data_size > large_threshold:
309
+ base_bs = gpu_large * device_count
310
+ elif data_size > mid_threshold:
311
+ base_bs = gpu_mid * device_count
312
+ else:
313
+ base_bs = gpu_small * device_count
314
+ else:
315
+ base_bs = cpu_mid if data_size > mid_threshold else cpu_small
316
+
317
+ batch_size = compute_batch_size(
318
+ data_size=data_size,
319
+ learning_rate=self.learning_rate,
320
+ batch_num=self.batch_num,
321
+ minimum=min_bs
322
+ )
323
+ batch_size = min(batch_size, base_bs, data_size)
324
+ batch_size = self._cap_batch_size_by_memory(
325
+ dataset, batch_size, profile)
326
+
327
+ target_effective_bs = target_effective_cuda if device_type == 'cuda' else target_effective_cpu
328
+ world_size = 1
329
+ if is_ddp:
330
+ world_size = getattr(self, "world_size", None)
331
+ world_size = max(1, world_size or DistributedUtils.world_size())
332
+ target_effective_bs = max(1, target_effective_bs // world_size)
333
+ samples_per_rank = math.ceil(
334
+ data_size / max(1, world_size)) if world_size > 1 else data_size
335
+ steps_per_epoch = max(
336
+ 1, math.ceil(samples_per_rank / max(1, batch_size)))
337
+ desired_accum = max(1, target_effective_bs // max(1, batch_size))
338
+ accum_steps = max(1, min(desired_accum, steps_per_epoch))
339
+
340
+ workers = self._resolve_num_workers(8, profile=profile)
341
+ prefetch_factor = None
342
+ if workers > 0:
343
+ prefetch_factor = 4 if profile == "throughput" else 2
344
+ persistent = workers > 0 and profile != "memory_saving"
345
+ _log(
346
+ f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, "
347
+ f"Workers={workers}, Prefetch={prefetch_factor or 'off'}, Profile={profile}")
348
+ sampler = None
349
+ use_distributed_sampler = bool(
350
+ dist.is_initialized() and getattr(self, "is_ddp_enabled", False)
351
+ )
352
+ if use_distributed_sampler:
353
+ sampler = DistributedSampler(dataset, shuffle=True)
354
+ shuffle = False
355
+ else:
356
+ shuffle = True
357
+
358
+ dataloader = DataLoader(
359
+ dataset,
360
+ batch_size=batch_size,
361
+ shuffle=shuffle,
362
+ sampler=sampler,
363
+ num_workers=workers,
364
+ pin_memory=(device_type == 'cuda'),
365
+ persistent_workers=persistent,
366
+ **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
367
+ )
368
+ self.dataloader_sampler = sampler
369
+ return dataloader, accum_steps
370
+
371
+ def _build_val_dataloader(self, dataset, train_dataloader, accum_steps):
372
+ """Build validation DataLoader."""
373
+ profile = self._resolve_resource_profile()
374
+ val_bs = accum_steps * train_dataloader.batch_size
375
+ val_workers = self._resolve_num_workers(4, profile=profile)
376
+ prefetch_factor = None
377
+ if val_workers > 0:
378
+ prefetch_factor = 2
379
+ return DataLoader(
380
+ dataset,
381
+ batch_size=val_bs,
382
+ shuffle=False,
383
+ num_workers=val_workers,
384
+ pin_memory=(self._device_type() == 'cuda'),
385
+ persistent_workers=(val_workers > 0 and profile != "memory_saving"),
386
+ **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
387
+ )
388
+
389
+ def _compute_losses(self, y_pred, y_true, apply_softplus: bool = False):
390
+ """Compute per-sample losses based on task type."""
391
+ task = getattr(self, "task_type", "regression")
392
+ if task == 'classification':
393
+ loss_fn = nn.BCEWithLogitsLoss(reduction='none')
394
+ return loss_fn(y_pred, y_true).view(-1)
395
+ loss_name = normalize_loss_name(
396
+ getattr(self, "loss_name", None), task_type="regression"
397
+ )
398
+ if loss_name == "auto":
399
+ model_name = getattr(self, "model_name", None) or getattr(self, "model_nme", "")
400
+ loss_name = infer_loss_name_from_model_name(model_name)
401
+ if apply_softplus:
402
+ y_pred = F.softplus(y_pred)
403
+ if loss_requires_positive(loss_name):
404
+ y_pred = torch.clamp(y_pred, min=1e-6)
405
+ power = resolve_tweedie_power(
406
+ loss_name, default=float(getattr(self, "tw_power", 1.5) or 1.5)
407
+ )
408
+ if power is None:
409
+ power = float(getattr(self, "tw_power", 1.5) or 1.5)
410
+ return tweedie_loss(y_pred, y_true, p=power).view(-1)
411
+ if loss_name == "mse":
412
+ return (y_pred - y_true).pow(2).view(-1)
413
+ if loss_name == "mae":
414
+ return (y_pred - y_true).abs().view(-1)
415
+ raise ValueError(f"Unsupported loss_name '{loss_name}' for regression.")
416
+
417
+ def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
418
+ """Compute weighted loss."""
419
+ losses = self._compute_losses(
420
+ y_pred, y_true, apply_softplus=apply_softplus)
421
+ weighted_loss = (losses * weights.view(-1)).sum() / \
422
+ torch.clamp(weights.sum(), min=EPS)
423
+ return weighted_loss
424
+
425
+ def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model,
426
+ ignore_keys: Optional[List[str]] = None):
427
+ """Update early stopping state."""
428
+ if val_loss < best_loss:
429
+ ignore_keys = ignore_keys or []
430
+ base_module = model.module if hasattr(model, "module") else model
431
+ state_dict = {
432
+ k: (v.clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v))
433
+ for k, v in base_module.state_dict().items()
434
+ if not any(k.startswith(ignore_key) for ignore_key in ignore_keys)
435
+ }
436
+ return val_loss, state_dict, 0, False
437
+ patience_counter += 1
438
+ should_stop = best_state is not None and patience_counter >= getattr(
439
+ self, "patience", 0)
440
+ return best_loss, best_state, patience_counter, should_stop
441
+
442
+ def _train_model(self,
443
+ model,
444
+ dataloader,
445
+ accum_steps,
446
+ optimizer,
447
+ scaler,
448
+ forward_fn,
449
+ val_forward_fn=None,
450
+ apply_softplus: bool = False,
451
+ clip_fn=None,
452
+ trial: Optional[optuna.trial.Trial] = None,
453
+ loss_curve_path: Optional[str] = None):
454
+ """Generic training loop with AMP, DDP, and early stopping support.
455
+
456
+ Returns:
457
+ Tuple of (best_state_dict, history)
458
+ """
459
+ device_type = self._device_type()
460
+ best_loss = float('inf')
461
+ best_state = None
462
+ patience_counter = 0
463
+ stop_training = False
464
+ train_history: List[float] = []
465
+ val_history: List[float] = []
466
+
467
+ is_ddp_model = isinstance(model, DDP)
468
+ use_collectives = dist.is_initialized() and is_ddp_model
469
+
470
+ for epoch in range(1, getattr(self, "epochs", 1) + 1):
471
+ epoch_start_ts = time.time()
472
+ val_weighted_loss = None
473
+ if hasattr(self, 'dataloader_sampler') and self.dataloader_sampler is not None:
474
+ self.dataloader_sampler.set_epoch(epoch)
475
+
476
+ model.train()
477
+ optimizer.zero_grad()
478
+
479
+ epoch_loss_sum = None
480
+ epoch_weight_sum = None
481
+ for step, batch in enumerate(dataloader):
482
+ is_update_step = ((step + 1) % accum_steps == 0) or \
483
+ ((step + 1) == len(dataloader))
484
+ sync_cm = model.no_sync if (
485
+ is_ddp_model and not is_update_step) else nullcontext
486
+
487
+ with sync_cm():
488
+ with autocast(enabled=(device_type == 'cuda')):
489
+ y_pred, y_true, w = forward_fn(batch)
490
+ weighted_loss = self._compute_weighted_loss(
491
+ y_pred, y_true, w, apply_softplus=apply_softplus)
492
+ loss_for_backward = weighted_loss / accum_steps
493
+
494
+ batch_weight = torch.clamp(
495
+ w.detach().sum(), min=EPS).to(dtype=torch.float32)
496
+ loss_val = weighted_loss.detach().to(dtype=torch.float32)
497
+ if epoch_loss_sum is None:
498
+ epoch_loss_sum = torch.zeros(
499
+ (), device=batch_weight.device, dtype=torch.float32)
500
+ epoch_weight_sum = torch.zeros(
501
+ (), device=batch_weight.device, dtype=torch.float32)
502
+ epoch_loss_sum = epoch_loss_sum + loss_val * batch_weight
503
+ epoch_weight_sum = epoch_weight_sum + batch_weight
504
+ scaler.scale(loss_for_backward).backward()
505
+
506
+ if is_update_step:
507
+ if clip_fn is not None:
508
+ clip_fn()
509
+ scaler.step(optimizer)
510
+ scaler.update()
511
+ optimizer.zero_grad()
512
+
513
+ if epoch_loss_sum is None or epoch_weight_sum is None:
514
+ train_epoch_loss = 0.0
515
+ else:
516
+ train_epoch_loss = (
517
+ epoch_loss_sum / torch.clamp(epoch_weight_sum, min=EPS)
518
+ ).item()
519
+ train_history.append(float(train_epoch_loss))
520
+
521
+ if val_forward_fn is not None:
522
+ should_compute_val = (not dist.is_initialized()
523
+ or DistributedUtils.is_main_process())
524
+ val_device = self._resolve_device()
525
+ loss_tensor_device = val_device if device_type == 'cuda' else torch.device(
526
+ "cpu")
527
+ val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
528
+
529
+ if should_compute_val:
530
+ model.eval()
531
+ with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
532
+ val_result = val_forward_fn()
533
+ if isinstance(val_result, tuple) and len(val_result) == 3:
534
+ y_val_pred, y_val_true, w_val = val_result
535
+ val_weighted_loss = self._compute_weighted_loss(
536
+ y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
537
+ else:
538
+ val_weighted_loss = val_result
539
+ val_loss_tensor[0] = float(val_weighted_loss)
540
+
541
+ if use_collectives:
542
+ dist.broadcast(val_loss_tensor, src=0)
543
+ val_weighted_loss = float(val_loss_tensor.item())
544
+
545
+ val_history.append(val_weighted_loss)
546
+
547
+ best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
548
+ val_weighted_loss, best_loss, best_state, patience_counter, model)
549
+
550
+ prune_flag = False
551
+ is_main_rank = DistributedUtils.is_main_process()
552
+ if trial is not None and is_main_rank:
553
+ trial.report(val_weighted_loss, epoch)
554
+ prune_flag = trial.should_prune()
555
+
556
+ if use_collectives:
557
+ prune_device = self._resolve_device()
558
+ prune_tensor = torch.zeros(1, device=prune_device)
559
+ if is_main_rank:
560
+ prune_tensor.fill_(1 if prune_flag else 0)
561
+ dist.broadcast(prune_tensor, src=0)
562
+ prune_flag = bool(prune_tensor.item())
563
+
564
+ if prune_flag:
565
+ raise optuna.TrialPruned()
566
+
567
+ if stop_training:
568
+ break
569
+
570
+ should_log_epoch = (not dist.is_initialized()
571
+ or DistributedUtils.is_main_process())
572
+ if should_log_epoch:
573
+ elapsed = int(time.time() - epoch_start_ts)
574
+ if val_weighted_loss is None:
575
+ _log(
576
+ f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
577
+ f"train_loss={float(train_epoch_loss):.6f} elapsed={elapsed}s",
578
+ flush=True,
579
+ )
580
+ else:
581
+ _log(
582
+ f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
583
+ f"train_loss={float(train_epoch_loss):.6f} "
584
+ f"val_loss={float(val_weighted_loss):.6f} elapsed={elapsed}s",
585
+ flush=True,
586
+ )
587
+
588
+ if epoch % 10 == 0:
589
+ if torch.cuda.is_available():
590
+ torch.cuda.empty_cache()
591
+ gc.collect()
592
+
593
+ history = {"train": train_history, "val": val_history}
594
+ self._plot_loss_curve(history, loss_curve_path)
595
+ return best_state, history
596
+
597
+ def _plot_loss_curve(self, history: Dict[str, List[float]], save_path: Optional[str]) -> None:
598
+ """Plot training and validation loss curves."""
599
+ if not save_path:
600
+ return
601
+ if dist.is_initialized() and not DistributedUtils.is_main_process():
602
+ return
603
+ train_hist = history.get("train", []) if history else []
604
+ val_hist = history.get("val", []) if history else []
605
+ if not train_hist and not val_hist:
606
+ return
607
+ if plot_loss_curve_common is not None:
608
+ plot_loss_curve_common(
609
+ history=history,
610
+ title="Loss vs. Epoch",
611
+ save_path=save_path,
612
+ show=False,
613
+ )
614
+ else:
615
+ if plt is None:
616
+ _plot_skip("loss curve")
617
+ return
618
+ ensure_parent_dir(save_path)
619
+ epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
620
+ fig = plt.figure(figsize=(8, 4))
621
+ ax = fig.add_subplot(111)
622
+ if train_hist:
623
+ ax.plot(range(1, len(train_hist) + 1), train_hist,
624
+ label='Train Loss', color='tab:blue')
625
+ if val_hist:
626
+ ax.plot(range(1, len(val_hist) + 1), val_hist,
627
+ label='Validation Loss', color='tab:orange')
628
+ ax.set_xlabel('Epoch')
629
+ ax.set_ylabel('Weighted Loss')
630
+ ax.set_title('Loss vs. Epoch')
631
+ ax.grid(True, linestyle='--', alpha=0.3)
632
+ ax.legend()
633
+ plt.tight_layout()
634
+ plt.savefig(save_path, dpi=300)
635
+ plt.close(fig)
636
+ _log(f"[Training] Loss curve saved to {save_path}")