ins-pricing 0.2.9-py3-none-any.whl → 0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. ins_pricing/CHANGELOG.md +93 -0
  2. ins_pricing/README.md +11 -0
  3. ins_pricing/cli/Explain_entry.py +50 -48
  4. ins_pricing/cli/bayesopt_entry_runner.py +699 -569
  5. ins_pricing/cli/utils/evaluation_context.py +320 -0
  6. ins_pricing/cli/utils/import_resolver.py +350 -0
  7. ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +449 -0
  8. ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +406 -0
  9. ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +247 -0
  10. ins_pricing/modelling/core/bayesopt/config_components.py +351 -0
  11. ins_pricing/modelling/core/bayesopt/config_preprocess.py +3 -4
  12. ins_pricing/modelling/core/bayesopt/core.py +153 -94
  13. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +122 -34
  14. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +298 -142
  15. ins_pricing/modelling/core/bayesopt/utils/__init__.py +86 -0
  16. ins_pricing/modelling/core/bayesopt/utils/constants.py +183 -0
  17. ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +186 -0
  18. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +126 -0
  19. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +540 -0
  20. ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +591 -0
  21. ins_pricing/modelling/core/bayesopt/utils.py +98 -1496
  22. ins_pricing/modelling/core/bayesopt/utils_backup.py +1503 -0
  23. ins_pricing/setup.py +1 -1
  24. {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/METADATA +14 -1
  25. {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/RECORD +27 -14
  26. {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/WHEEL +0 -0
  27. {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/top_level.txt +0 -0
ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py (new file)
@@ -0,0 +1,591 @@
+ """PyTorch training mixin with resource management and training loops.
+
+ This module provides the TorchTrainerMixin class which is used by
+ PyTorch-based trainers (ResNet, FT, GNN) for:
+ - Resource profiling and memory management
+ - Batch size computation and optimization
+ - DataLoader creation with DDP support
+ - Generic training and validation loops with AMP
+ - Early stopping and loss curve plotting
+ """
+
+ from __future__ import annotations
+
+ import copy
+ import ctypes
+ import gc
+ import math
+ import os
+ import time
+ from contextlib import nullcontext
+ from typing import Any, Callable, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import optuna
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.distributed as dist
+ from torch.cuda.amp import autocast, GradScaler
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.utils.data import DataLoader, DistributedSampler
+
+ # Try to import plotting functions
+ try:
+     import matplotlib
+     if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
+         matplotlib.use("Agg")
+     import matplotlib.pyplot as plt
+     _MPL_IMPORT_ERROR: Optional[BaseException] = None
+ except Exception as exc:
+     matplotlib = None
+     plt = None
+     _MPL_IMPORT_ERROR = exc
+
+ try:
+     from ....plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
+ except Exception:
+     try:
+         from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
+     except Exception:
+         plot_loss_curve_common = None
+
+ # Import from other utils modules
+ from .constants import EPS, compute_batch_size, tweedie_loss, ensure_parent_dir
+ from .distributed_utils import DistributedUtils
+
+
+ def _plot_skip(label: str) -> None:
+     """Print message when plot is skipped due to missing matplotlib."""
+     if _MPL_IMPORT_ERROR is not None:
+         print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
+     else:
+         print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
+
+
+ class TorchTrainerMixin:
+     """Shared helpers for PyTorch tabular trainers.
+
+     Provides resource profiling, memory management, batch size optimization,
+     and standardized training loops with mixed precision and DDP support.
+
+     This mixin is used by ResNetTrainer, FTTrainer, and GNNTrainer.
+     """
+
+     def _device_type(self) -> str:
+         """Get device type (cpu/cuda/mps)."""
+         return getattr(self, "device", torch.device("cpu")).type
+
+     def _resolve_resource_profile(self) -> str:
+         """Determine resource usage profile.
+
+         Returns:
+             One of: 'throughput', 'memory_saving', or 'auto'
+         """
+         profile = getattr(self, "resource_profile", None)
+         if not profile:
+             profile = os.environ.get("BAYESOPT_RESOURCE_PROFILE", "auto")
+         profile = str(profile).strip().lower()
+         if profile in {"cpu", "mps", "cuda"}:
+             profile = "auto"
+         if profile not in {"auto", "throughput", "memory_saving"}:
+             profile = "auto"
+         if profile == "auto" and self._device_type() == "cuda":
+             profile = "throughput"
+         return profile
+
+     def _log_resource_summary_once(self, profile: str) -> None:
+         """Log resource configuration summary once."""
+         if getattr(self, "_resource_summary_logged", False):
+             return
+         if dist.is_initialized() and not DistributedUtils.is_main_process():
+             return
+         self._resource_summary_logged = True
+         device = getattr(self, "device", torch.device("cpu"))
+         device_type = self._device_type()
+         cpu_count = os.cpu_count() or 1
+         cuda_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
+         mps_available = bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available())
+         ddp_enabled = bool(getattr(self, "is_ddp_enabled", False))
+         data_parallel = bool(getattr(self, "use_data_parallel", False))
+         print(
+             f">>> Resource summary: device={device}, device_type={device_type}, "
+             f"cpu_count={cpu_count}, cuda_count={cuda_count}, mps={mps_available}, "
+             f"ddp={ddp_enabled}, data_parallel={data_parallel}, profile={profile}"
+         )
+
+     def _available_system_memory(self) -> Optional[int]:
+         """Get available system RAM in bytes."""
+         if os.name == "nt":
+             class _MemStatus(ctypes.Structure):
+                 _fields_ = [
+                     ("dwLength", ctypes.c_ulong),
+                     ("dwMemoryLoad", ctypes.c_ulong),
+                     ("ullTotalPhys", ctypes.c_ulonglong),
+                     ("ullAvailPhys", ctypes.c_ulonglong),
+                     ("ullTotalPageFile", ctypes.c_ulonglong),
+                     ("ullAvailPageFile", ctypes.c_ulonglong),
+                     ("ullTotalVirtual", ctypes.c_ulonglong),
+                     ("ullAvailVirtual", ctypes.c_ulonglong),
+                     ("sullAvailExtendedVirtual", ctypes.c_ulonglong),
+                 ]
+             status = _MemStatus()
+             status.dwLength = ctypes.sizeof(_MemStatus)
+             if ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status)):
+                 return int(status.ullAvailPhys)
+             return None
+         try:
+             pages = os.sysconf("SC_AVPHYS_PAGES")
+             page_size = os.sysconf("SC_PAGE_SIZE")
+             return int(pages * page_size)
+         except Exception:
+             return None
+
+     def _available_cuda_memory(self) -> Optional[int]:
+         """Get available CUDA memory in bytes."""
+         if not torch.cuda.is_available():
+             return None
+         try:
+             free_mem, _total_mem = torch.cuda.mem_get_info()
+         except Exception:
+             return None
+         return int(free_mem)
+
+     def _estimate_sample_bytes(self, dataset) -> Optional[int]:
+         """Estimate memory per sample in bytes."""
+         try:
+             if len(dataset) == 0:
+                 return None
+             sample = dataset[0]
+         except Exception:
+             return None
+
+         def _bytes(obj) -> int:
+             if obj is None:
+                 return 0
+             if torch.is_tensor(obj):
+                 return int(obj.element_size() * obj.nelement())
+             if isinstance(obj, np.ndarray):
+                 return int(obj.nbytes)
+             if isinstance(obj, (list, tuple)):
+                 return int(sum(_bytes(item) for item in obj))
+             if isinstance(obj, dict):
+                 return int(sum(_bytes(item) for item in obj.values()))
+             return 0
+
+         sample_bytes = _bytes(sample)
+         return int(sample_bytes) if sample_bytes > 0 else None
+
+     def _cap_batch_size_by_memory(self, dataset, batch_size: int, profile: str) -> int:
+         """Cap batch size based on available memory."""
+         if batch_size <= 1:
+             return batch_size
+         sample_bytes = self._estimate_sample_bytes(dataset)
+         if sample_bytes is None:
+             return batch_size
+         device_type = self._device_type()
+         if device_type == "cuda":
+             available = self._available_cuda_memory()
+             if available is None:
+                 return batch_size
+             if profile == "throughput":
+                 budget_ratio = 0.8
+                 overhead = 8.0
+             elif profile == "memory_saving":
+                 budget_ratio = 0.5
+                 overhead = 14.0
+             else:
+                 budget_ratio = 0.6
+                 overhead = 12.0
+         else:
+             available = self._available_system_memory()
+             if available is None:
+                 return batch_size
+             if profile == "throughput":
+                 budget_ratio = 0.4
+                 overhead = 1.8
+             elif profile == "memory_saving":
+                 budget_ratio = 0.25
+                 overhead = 3.0
+             else:
+                 budget_ratio = 0.3
+                 overhead = 2.6
+         budget = int(available * budget_ratio)
+         per_sample = int(sample_bytes * overhead)
+         if per_sample <= 0:
+             return batch_size
+         max_batch = max(1, int(budget // per_sample))
+         if max_batch < batch_size:
+             print(
+                 f">>> Memory cap: batch_size {batch_size} -> {max_batch} "
+                 f"(per_sample~{sample_bytes}B, budget~{budget // (1024**2)}MB)"
+             )
+         return min(batch_size, max_batch)
+
+     def _resolve_num_workers(self, max_workers: int, profile: Optional[str] = None) -> int:
+         """Determine number of DataLoader workers."""
+         if os.name == 'nt':
+             return 0
+         if getattr(self, "is_ddp_enabled", False):
+             return 0
+         profile = profile or self._resolve_resource_profile()
+         if profile == "memory_saving":
+             return 0
+         worker_cap = min(int(max_workers), os.cpu_count() or 1)
+         if self._device_type() == "mps":
+             worker_cap = min(worker_cap, 2)
+         return worker_cap
+
+     def _build_dataloader(self,
+                           dataset,
+                           N: int,
+                           base_bs_gpu: tuple,
+                           base_bs_cpu: tuple,
+                           min_bs: int = 64,
+                           target_effective_cuda: int = 1024,
+                           target_effective_cpu: int = 512,
+                           large_threshold: int = 200_000,
+                           mid_threshold: int = 50_000):
+         """Build DataLoader with adaptive batch size and worker configuration.
+
+         Returns:
+             Tuple of (dataloader, accum_steps)
+         """
+         profile = self._resolve_resource_profile()
+         self._log_resource_summary_once(profile)
+         batch_size = compute_batch_size(
+             data_size=len(dataset),
+             learning_rate=self.learning_rate,
+             batch_num=self.batch_num,
+             minimum=min_bs
+         )
+         gpu_large, gpu_mid, gpu_small = base_bs_gpu
+         cpu_mid, cpu_small = base_bs_cpu
+
+         if self._device_type() == 'cuda':
+             device_count = torch.cuda.device_count()
+             if getattr(self, "is_ddp_enabled", False):
+                 device_count = 1
+             if device_count > 1:
+                 min_bs = min_bs * device_count
+                 print(
+                     f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")
+
+             if N > large_threshold:
+                 base_bs = gpu_large * device_count
+             elif N > mid_threshold:
+                 base_bs = gpu_mid * device_count
+             else:
+                 base_bs = gpu_small * device_count
+         else:
+             base_bs = cpu_mid if N > mid_threshold else cpu_small
+
+         batch_size = compute_batch_size(
+             data_size=len(dataset),
+             learning_rate=self.learning_rate,
+             batch_num=self.batch_num,
+             minimum=min_bs
+         )
+         batch_size = min(batch_size, base_bs, N)
+         batch_size = self._cap_batch_size_by_memory(
+             dataset, batch_size, profile)
+
+         target_effective_bs = target_effective_cuda if self._device_type(
+         ) == 'cuda' else target_effective_cpu
+         if getattr(self, "is_ddp_enabled", False):
+             world_size = max(1, DistributedUtils.world_size())
+             target_effective_bs = max(1, target_effective_bs // world_size)
+
+         world_size = getattr(self, "world_size", 1) if getattr(
+             self, "is_ddp_enabled", False) else 1
+         samples_per_rank = math.ceil(
+             N / max(1, world_size)) if world_size > 1 else N
+         steps_per_epoch = max(
+             1, math.ceil(samples_per_rank / max(1, batch_size)))
+         desired_accum = max(1, target_effective_bs // max(1, batch_size))
+         accum_steps = max(1, min(desired_accum, steps_per_epoch))
+
+         workers = self._resolve_num_workers(8, profile=profile)
+         prefetch_factor = None
+         if workers > 0:
+             prefetch_factor = 4 if profile == "throughput" else 2
+         persistent = workers > 0 and profile != "memory_saving"
+         print(
+             f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, "
+             f"Workers={workers}, Prefetch={prefetch_factor or 'off'}, Profile={profile}")
+         sampler = None
+         use_distributed_sampler = bool(
+             dist.is_initialized() and getattr(self, "is_ddp_enabled", False)
+         )
+         if use_distributed_sampler:
+             sampler = DistributedSampler(dataset, shuffle=True)
+             shuffle = False
+         else:
+             shuffle = True
+
+         dataloader = DataLoader(
+             dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             sampler=sampler,
+             num_workers=workers,
+             pin_memory=(self._device_type() == 'cuda'),
+             persistent_workers=persistent,
+             **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
+         )
+         return dataloader, accum_steps
+
+     def _build_val_dataloader(self, dataset, train_dataloader, accum_steps):
+         """Build validation DataLoader."""
+         profile = self._resolve_resource_profile()
+         val_bs = accum_steps * train_dataloader.batch_size
+         val_workers = self._resolve_num_workers(4, profile=profile)
+         prefetch_factor = None
+         if val_workers > 0:
+             prefetch_factor = 2
+         return DataLoader(
+             dataset,
+             batch_size=val_bs,
+             shuffle=False,
+             num_workers=val_workers,
+             pin_memory=(self._device_type() == 'cuda'),
+             persistent_workers=(val_workers > 0 and profile != "memory_saving"),
+             **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
+         )
+
+     def _compute_losses(self, y_pred, y_true, apply_softplus: bool = False):
+         """Compute per-sample losses based on task type."""
+         task = getattr(self, "task_type", "regression")
+         if task == 'classification':
+             loss_fn = nn.BCEWithLogitsLoss(reduction='none')
+             return loss_fn(y_pred, y_true).view(-1)
+         if apply_softplus:
+             y_pred = F.softplus(y_pred)
+         y_pred = torch.clamp(y_pred, min=1e-6)
+         power = getattr(self, "tw_power", 1.5)
+         return tweedie_loss(y_pred, y_true, p=power).view(-1)
+
+     def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
+         """Compute weighted loss."""
+         losses = self._compute_losses(
+             y_pred, y_true, apply_softplus=apply_softplus)
+         weighted_loss = (losses * weights.view(-1)).sum() / \
+             torch.clamp(weights.sum(), min=EPS)
+         return weighted_loss
+
+     def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model,
+                            ignore_keys: Optional[List[str]] = None):
+         """Update early stopping state."""
+         if val_loss < best_loss:
+             ignore_keys = ignore_keys or []
+             base_module = model.module if hasattr(model, "module") else model
+             state_dict = {
+                 k: (v.clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v))
+                 for k, v in base_module.state_dict().items()
+                 if not any(k.startswith(ignore_key) for ignore_key in ignore_keys)
+             }
+             return val_loss, state_dict, 0, False
+         patience_counter += 1
+         should_stop = best_state is not None and patience_counter >= getattr(
+             self, "patience", 0)
+         return best_loss, best_state, patience_counter, should_stop
+
+     def _train_model(self,
+                      model,
+                      dataloader,
+                      accum_steps,
+                      optimizer,
+                      scaler,
+                      forward_fn,
+                      val_forward_fn=None,
+                      apply_softplus: bool = False,
+                      clip_fn=None,
+                      trial: Optional[optuna.trial.Trial] = None,
+                      loss_curve_path: Optional[str] = None):
+         """Generic training loop with AMP, DDP, and early stopping support.
+
+         Returns:
+             Tuple of (best_state_dict, history)
+         """
+         device_type = self._device_type()
+         best_loss = float('inf')
+         best_state = None
+         patience_counter = 0
+         stop_training = False
+         train_history: List[float] = []
+         val_history: List[float] = []
+
+         is_ddp_model = isinstance(model, DDP)
+         use_collectives = dist.is_initialized() and is_ddp_model
+
+         for epoch in range(1, getattr(self, "epochs", 1) + 1):
+             epoch_start_ts = time.time()
+             val_weighted_loss = None
+             if hasattr(self, 'dataloader_sampler') and self.dataloader_sampler is not None:
+                 self.dataloader_sampler.set_epoch(epoch)
+
+             model.train()
+             optimizer.zero_grad()
+
+             epoch_loss_sum = None
+             epoch_weight_sum = None
+             for step, batch in enumerate(dataloader):
+                 is_update_step = ((step + 1) % accum_steps == 0) or \
+                     ((step + 1) == len(dataloader))
+                 sync_cm = model.no_sync if (
+                     is_ddp_model and not is_update_step) else nullcontext
+
+                 with sync_cm():
+                     with autocast(enabled=(device_type == 'cuda')):
+                         y_pred, y_true, w = forward_fn(batch)
+                         weighted_loss = self._compute_weighted_loss(
+                             y_pred, y_true, w, apply_softplus=apply_softplus)
+                         loss_for_backward = weighted_loss / accum_steps
+
+                     batch_weight = torch.clamp(
+                         w.detach().sum(), min=EPS).to(dtype=torch.float32)
+                     loss_val = weighted_loss.detach().to(dtype=torch.float32)
+                     if epoch_loss_sum is None:
+                         epoch_loss_sum = torch.zeros(
+                             (), device=batch_weight.device, dtype=torch.float32)
+                         epoch_weight_sum = torch.zeros(
+                             (), device=batch_weight.device, dtype=torch.float32)
+                     epoch_loss_sum = epoch_loss_sum + loss_val * batch_weight
+                     epoch_weight_sum = epoch_weight_sum + batch_weight
+                     scaler.scale(loss_for_backward).backward()
+
+                 if is_update_step:
+                     if clip_fn is not None:
+                         clip_fn()
+                     scaler.step(optimizer)
+                     scaler.update()
+                     optimizer.zero_grad()
+
+             if epoch_loss_sum is None or epoch_weight_sum is None:
+                 train_epoch_loss = 0.0
+             else:
+                 train_epoch_loss = (
+                     epoch_loss_sum / torch.clamp(epoch_weight_sum, min=EPS)
+                 ).item()
+             train_history.append(float(train_epoch_loss))
+
+             if val_forward_fn is not None:
+                 should_compute_val = (not dist.is_initialized()
+                                       or DistributedUtils.is_main_process())
+                 val_device = getattr(self, "device", torch.device("cpu"))
+                 if not isinstance(val_device, torch.device):
+                     val_device = torch.device(val_device)
+                 loss_tensor_device = val_device if device_type == 'cuda' else torch.device(
+                     "cpu")
+                 val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
+
+                 if should_compute_val:
+                     model.eval()
+                     with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
+                         val_result = val_forward_fn()
+                         if isinstance(val_result, tuple) and len(val_result) == 3:
+                             y_val_pred, y_val_true, w_val = val_result
+                             val_weighted_loss = self._compute_weighted_loss(
+                                 y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
+                         else:
+                             val_weighted_loss = val_result
+                     val_loss_tensor[0] = float(val_weighted_loss)
+
+                 if use_collectives:
+                     dist.broadcast(val_loss_tensor, src=0)
+                     val_weighted_loss = float(val_loss_tensor.item())
+
+                 val_history.append(val_weighted_loss)
+
+                 best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
+                     val_weighted_loss, best_loss, best_state, patience_counter, model)
+
+                 prune_flag = False
+                 is_main_rank = DistributedUtils.is_main_process()
+                 if trial is not None and is_main_rank:
+                     trial.report(val_weighted_loss, epoch)
+                     prune_flag = trial.should_prune()
+
+                 if use_collectives:
+                     prune_device = getattr(self, "device", torch.device("cpu"))
+                     if not isinstance(prune_device, torch.device):
+                         prune_device = torch.device(prune_device)
+                     prune_tensor = torch.zeros(1, device=prune_device)
+                     if is_main_rank:
+                         prune_tensor.fill_(1 if prune_flag else 0)
+                     dist.broadcast(prune_tensor, src=0)
+                     prune_flag = bool(prune_tensor.item())
+
+                 if prune_flag:
+                     raise optuna.TrialPruned()
+
+                 if stop_training:
+                     break
+
+             should_log_epoch = (not dist.is_initialized()
+                                 or DistributedUtils.is_main_process())
+             if should_log_epoch:
+                 elapsed = int(time.time() - epoch_start_ts)
+                 if val_weighted_loss is None:
+                     print(
+                         f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
+                         f"train_loss={float(train_epoch_loss):.6f} elapsed={elapsed}s",
+                         flush=True,
+                     )
+                 else:
+                     print(
+                         f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
+                         f"train_loss={float(train_epoch_loss):.6f} "
+                         f"val_loss={float(val_weighted_loss):.6f} elapsed={elapsed}s",
+                         flush=True,
+                     )
+
+             if epoch % 10 == 0:
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+                 gc.collect()
+
+         history = {"train": train_history, "val": val_history}
+         self._plot_loss_curve(history, loss_curve_path)
+         return best_state, history
+
+     def _plot_loss_curve(self, history: Dict[str, List[float]], save_path: Optional[str]) -> None:
+         """Plot training and validation loss curves."""
+         if not save_path:
+             return
+         if dist.is_initialized() and not DistributedUtils.is_main_process():
+             return
+         train_hist = history.get("train", []) if history else []
+         val_hist = history.get("val", []) if history else []
+         if not train_hist and not val_hist:
+             return
+         if plot_loss_curve_common is not None:
+             plot_loss_curve_common(
+                 history=history,
+                 title="Loss vs. Epoch",
+                 save_path=save_path,
+                 show=False,
+             )
+         else:
+             if plt is None:
+                 _plot_skip("loss curve")
+                 return
+             ensure_parent_dir(save_path)
+             epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
+             fig = plt.figure(figsize=(8, 4))
+             ax = fig.add_subplot(111)
+             if train_hist:
+                 ax.plot(range(1, len(train_hist) + 1), train_hist,
+                         label='Train Loss', color='tab:blue')
+             if val_hist:
+                 ax.plot(range(1, len(val_hist) + 1), val_hist,
+                         label='Validation Loss', color='tab:orange')
+             ax.set_xlabel('Epoch')
+             ax.set_ylabel('Weighted Loss')
+             ax.set_title('Loss vs. Epoch')
+             ax.grid(True, linestyle='--', alpha=0.3)
+             ax.legend()
+             plt.tight_layout()
+             plt.savefig(save_path, dpi=300)
+             plt.close(fig)
+             print(f"[Training] Loss curve saved to {save_path}")