ins-pricing 0.2.8-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. ins_pricing/CHANGELOG.md +93 -0
  2. ins_pricing/README.md +11 -0
  3. ins_pricing/cli/bayesopt_entry_runner.py +626 -499
  4. ins_pricing/cli/utils/evaluation_context.py +320 -0
  5. ins_pricing/cli/utils/import_resolver.py +350 -0
  6. ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +449 -0
  7. ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +406 -0
  8. ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +247 -0
  9. ins_pricing/modelling/core/bayesopt/config_components.py +351 -0
  10. ins_pricing/modelling/core/bayesopt/config_preprocess.py +3 -4
  11. ins_pricing/modelling/core/bayesopt/core.py +153 -94
  12. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +118 -31
  13. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +294 -139
  14. ins_pricing/modelling/core/bayesopt/utils/__init__.py +86 -0
  15. ins_pricing/modelling/core/bayesopt/utils/constants.py +183 -0
  16. ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +186 -0
  17. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +126 -0
  18. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +540 -0
  19. ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +587 -0
  20. ins_pricing/modelling/core/bayesopt/utils.py +98 -1495
  21. ins_pricing/modelling/core/bayesopt/utils_backup.py +1503 -0
  22. ins_pricing/setup.py +1 -1
  23. ins_pricing-0.3.0.dist-info/METADATA +162 -0
  24. {ins_pricing-0.2.8.dist-info → ins_pricing-0.3.0.dist-info}/RECORD +26 -13
  25. ins_pricing-0.2.8.dist-info/METADATA +0 -51
  26. {ins_pricing-0.2.8.dist-info → ins_pricing-0.3.0.dist-info}/WHEEL +0 -0
  27. {ins_pricing-0.2.8.dist-info → ins_pricing-0.3.0.dist-info}/top_level.txt +0 -0
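
The new utils/ modules above appear to come from splitting the monolithic ins_pricing/modelling/core/bayesopt/utils.py (note its -1495 lines and the utils_backup.py copy). As a hypothetical orientation sketch, not an excerpt from the package, downstream code would resolve the relocated helpers from the new submodule paths:

# Hypothetical imports against the 0.3.0 layout; the module paths are taken from the file
# list above and the symbol names from the torch_trainer_mixin.py diff shown below.
from ins_pricing.modelling.core.bayesopt.utils.constants import EPS, compute_batch_size, tweedie_loss
from ins_pricing.modelling.core.bayesopt.utils.distributed_utils import DistributedUtils
from ins_pricing.modelling.core.bayesopt.utils.torch_trainer_mixin import TorchTrainerMixin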
ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py (new file)
@@ -0,0 +1,587 @@
+ """PyTorch training mixin with resource management and training loops.
+
+ This module provides the TorchTrainerMixin class which is used by
+ PyTorch-based trainers (ResNet, FT, GNN) for:
+ - Resource profiling and memory management
+ - Batch size computation and optimization
+ - DataLoader creation with DDP support
+ - Generic training and validation loops with AMP
+ - Early stopping and loss curve plotting
+ """
+
+ from __future__ import annotations
+
+ import copy
+ import ctypes
+ import gc
+ import math
+ import os
+ import time
+ from contextlib import nullcontext
+ from typing import Any, Callable, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import optuna
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.distributed as dist
+ from torch.cuda.amp import autocast, GradScaler
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.utils.data import DataLoader, DistributedSampler
+
+ # Try to import plotting functions
+ try:
+     import matplotlib
+     if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
+         matplotlib.use("Agg")
+     import matplotlib.pyplot as plt
+     _MPL_IMPORT_ERROR: Optional[BaseException] = None
+ except Exception as exc:
+     matplotlib = None
+     plt = None
+     _MPL_IMPORT_ERROR = exc
+
+ try:
+     from ....plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
+ except Exception:
+     try:
+         from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
+     except Exception:
+         plot_loss_curve_common = None
+
+ # Import from other utils modules
+ from .constants import EPS, compute_batch_size, tweedie_loss, ensure_parent_dir
+ from .distributed_utils import DistributedUtils
+
+
+ def _plot_skip(label: str) -> None:
+     """Print message when plot is skipped due to missing matplotlib."""
+     if _MPL_IMPORT_ERROR is not None:
+         print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
+     else:
+         print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
+
+
+ class TorchTrainerMixin:
+     """Shared helpers for PyTorch tabular trainers.
+
+     Provides resource profiling, memory management, batch size optimization,
+     and standardized training loops with mixed precision and DDP support.
+
+     This mixin is used by ResNetTrainer, FTTrainer, and GNNTrainer.
+     """
+
+     def _device_type(self) -> str:
+         """Get device type (cpu/cuda/mps)."""
+         return getattr(self, "device", torch.device("cpu")).type
+
+     def _resolve_resource_profile(self) -> str:
+         """Determine resource usage profile.
+
+         Returns:
+             One of: 'throughput', 'memory_saving', or 'auto'
+         """
+         profile = getattr(self, "resource_profile", None)
+         if not profile:
+             profile = os.environ.get("BAYESOPT_RESOURCE_PROFILE", "auto")
+         profile = str(profile).strip().lower()
+         if profile in {"cpu", "mps", "cuda"}:
+             profile = "auto"
+         if profile not in {"auto", "throughput", "memory_saving"}:
+             profile = "auto"
+         if profile == "auto" and self._device_type() == "cuda":
+             profile = "throughput"
+         return profile
+
+     def _log_resource_summary_once(self, profile: str) -> None:
+         """Log resource configuration summary once."""
+         if getattr(self, "_resource_summary_logged", False):
+             return
+         if dist.is_initialized() and not DistributedUtils.is_main_process():
+             return
+         self._resource_summary_logged = True
+         device = getattr(self, "device", torch.device("cpu"))
+         device_type = self._device_type()
+         cpu_count = os.cpu_count() or 1
+         cuda_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
+         mps_available = bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available())
+         ddp_enabled = bool(getattr(self, "is_ddp_enabled", False))
+         data_parallel = bool(getattr(self, "use_data_parallel", False))
+         print(
+             f">>> Resource summary: device={device}, device_type={device_type}, "
+             f"cpu_count={cpu_count}, cuda_count={cuda_count}, mps={mps_available}, "
+             f"ddp={ddp_enabled}, data_parallel={data_parallel}, profile={profile}"
+         )
+
+     def _available_system_memory(self) -> Optional[int]:
+         """Get available system RAM in bytes."""
+         if os.name == "nt":
+             class _MemStatus(ctypes.Structure):
+                 _fields_ = [
+                     ("dwLength", ctypes.c_ulong),
+                     ("dwMemoryLoad", ctypes.c_ulong),
+                     ("ullTotalPhys", ctypes.c_ulonglong),
+                     ("ullAvailPhys", ctypes.c_ulonglong),
+                     ("ullTotalPageFile", ctypes.c_ulonglong),
+                     ("ullAvailPageFile", ctypes.c_ulonglong),
+                     ("ullTotalVirtual", ctypes.c_ulonglong),
+                     ("ullAvailVirtual", ctypes.c_ulonglong),
+                     ("sullAvailExtendedVirtual", ctypes.c_ulonglong),
+                 ]
+             status = _MemStatus()
+             status.dwLength = ctypes.sizeof(_MemStatus)
+             if ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status)):
+                 return int(status.ullAvailPhys)
+             return None
+         try:
+             pages = os.sysconf("SC_AVPHYS_PAGES")
+             page_size = os.sysconf("SC_PAGE_SIZE")
+             return int(pages * page_size)
+         except Exception:
+             return None
+
+     def _available_cuda_memory(self) -> Optional[int]:
+         """Get available CUDA memory in bytes."""
+         if not torch.cuda.is_available():
+             return None
+         try:
+             free_mem, _total_mem = torch.cuda.mem_get_info()
+         except Exception:
+             return None
+         return int(free_mem)
+
+     def _estimate_sample_bytes(self, dataset) -> Optional[int]:
+         """Estimate memory per sample in bytes."""
+         try:
+             if len(dataset) == 0:
+                 return None
+             sample = dataset[0]
+         except Exception:
+             return None
+
+         def _bytes(obj) -> int:
+             if obj is None:
+                 return 0
+             if torch.is_tensor(obj):
+                 return int(obj.element_size() * obj.nelement())
+             if isinstance(obj, np.ndarray):
+                 return int(obj.nbytes)
+             if isinstance(obj, (list, tuple)):
+                 return int(sum(_bytes(item) for item in obj))
+             if isinstance(obj, dict):
+                 return int(sum(_bytes(item) for item in obj.values()))
+             return 0
+
+         sample_bytes = _bytes(sample)
+         return int(sample_bytes) if sample_bytes > 0 else None
+
+     def _cap_batch_size_by_memory(self, dataset, batch_size: int, profile: str) -> int:
+         """Cap batch size based on available memory."""
+         if batch_size <= 1:
+             return batch_size
+         sample_bytes = self._estimate_sample_bytes(dataset)
+         if sample_bytes is None:
+             return batch_size
+         device_type = self._device_type()
+         if device_type == "cuda":
+             available = self._available_cuda_memory()
+             if available is None:
+                 return batch_size
+             if profile == "throughput":
+                 budget_ratio = 0.8
+                 overhead = 8.0
+             elif profile == "memory_saving":
+                 budget_ratio = 0.5
+                 overhead = 14.0
+             else:
+                 budget_ratio = 0.6
+                 overhead = 12.0
+         else:
+             available = self._available_system_memory()
+             if available is None:
+                 return batch_size
+             if profile == "throughput":
+                 budget_ratio = 0.4
+                 overhead = 1.8
+             elif profile == "memory_saving":
+                 budget_ratio = 0.25
+                 overhead = 3.0
+             else:
+                 budget_ratio = 0.3
+                 overhead = 2.6
+         budget = int(available * budget_ratio)
+         per_sample = int(sample_bytes * overhead)
+         if per_sample <= 0:
+             return batch_size
+         max_batch = max(1, int(budget // per_sample))
+         if max_batch < batch_size:
+             print(
+                 f">>> Memory cap: batch_size {batch_size} -> {max_batch} "
+                 f"(per_sample~{sample_bytes}B, budget~{budget // (1024**2)}MB)"
+             )
+         return min(batch_size, max_batch)
+
+     def _resolve_num_workers(self, max_workers: int, profile: Optional[str] = None) -> int:
+         """Determine number of DataLoader workers."""
+         if os.name == 'nt':
+             return 0
+         if getattr(self, "is_ddp_enabled", False):
+             return 0
+         profile = profile or self._resolve_resource_profile()
+         if profile == "memory_saving":
+             return 0
+         worker_cap = min(int(max_workers), os.cpu_count() or 1)
+         if self._device_type() == "mps":
+             worker_cap = min(worker_cap, 2)
+         return worker_cap
+
+     def _build_dataloader(self,
+                           dataset,
+                           N: int,
+                           base_bs_gpu: tuple,
+                           base_bs_cpu: tuple,
+                           min_bs: int = 64,
+                           target_effective_cuda: int = 1024,
+                           target_effective_cpu: int = 512,
+                           large_threshold: int = 200_000,
+                           mid_threshold: int = 50_000):
+         """Build DataLoader with adaptive batch size and worker configuration.
+
+         Returns:
+             Tuple of (dataloader, accum_steps)
+         """
+         profile = self._resolve_resource_profile()
+         self._log_resource_summary_once(profile)
+         batch_size = compute_batch_size(
+             data_size=len(dataset),
+             learning_rate=self.learning_rate,
+             batch_num=self.batch_num,
+             minimum=min_bs
+         )
+         gpu_large, gpu_mid, gpu_small = base_bs_gpu
+         cpu_mid, cpu_small = base_bs_cpu
+
+         if self._device_type() == 'cuda':
+             device_count = torch.cuda.device_count()
+             if getattr(self, "is_ddp_enabled", False):
+                 device_count = 1
+             if device_count > 1:
+                 min_bs = min_bs * device_count
+                 print(
+                     f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")
+
+             if N > large_threshold:
+                 base_bs = gpu_large * device_count
+             elif N > mid_threshold:
+                 base_bs = gpu_mid * device_count
+             else:
+                 base_bs = gpu_small * device_count
+         else:
+             base_bs = cpu_mid if N > mid_threshold else cpu_small
+
+         batch_size = compute_batch_size(
+             data_size=len(dataset),
+             learning_rate=self.learning_rate,
+             batch_num=self.batch_num,
+             minimum=min_bs
+         )
+         batch_size = min(batch_size, base_bs, N)
+         batch_size = self._cap_batch_size_by_memory(
+             dataset, batch_size, profile)
+
+         target_effective_bs = target_effective_cuda if self._device_type(
+         ) == 'cuda' else target_effective_cpu
+         if getattr(self, "is_ddp_enabled", False):
+             world_size = max(1, DistributedUtils.world_size())
+             target_effective_bs = max(1, target_effective_bs // world_size)
+
+         world_size = getattr(self, "world_size", 1) if getattr(
+             self, "is_ddp_enabled", False) else 1
+         samples_per_rank = math.ceil(
+             N / max(1, world_size)) if world_size > 1 else N
+         steps_per_epoch = max(
+             1, math.ceil(samples_per_rank / max(1, batch_size)))
+         desired_accum = max(1, target_effective_bs // max(1, batch_size))
+         accum_steps = max(1, min(desired_accum, steps_per_epoch))
+
+         workers = self._resolve_num_workers(8, profile=profile)
+         prefetch_factor = None
+         if workers > 0:
+             prefetch_factor = 4 if profile == "throughput" else 2
+         persistent = workers > 0 and profile != "memory_saving"
+         print(
+             f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, "
+             f"Workers={workers}, Prefetch={prefetch_factor or 'off'}, Profile={profile}")
+         sampler = None
+         if dist.is_initialized():
+             sampler = DistributedSampler(dataset, shuffle=True)
+             shuffle = False
+         else:
+             shuffle = True
+
+         dataloader = DataLoader(
+             dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             sampler=sampler,
+             num_workers=workers,
+             pin_memory=(self._device_type() == 'cuda'),
+             persistent_workers=persistent,
+             **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
+         )
+         return dataloader, accum_steps
+
+     def _build_val_dataloader(self, dataset, train_dataloader, accum_steps):
+         """Build validation DataLoader."""
+         profile = self._resolve_resource_profile()
+         val_bs = accum_steps * train_dataloader.batch_size
+         val_workers = self._resolve_num_workers(4, profile=profile)
+         prefetch_factor = None
+         if val_workers > 0:
+             prefetch_factor = 2
+         return DataLoader(
+             dataset,
+             batch_size=val_bs,
+             shuffle=False,
+             num_workers=val_workers,
+             pin_memory=(self._device_type() == 'cuda'),
+             persistent_workers=(val_workers > 0 and profile != "memory_saving"),
+             **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
+         )
+
+     def _compute_losses(self, y_pred, y_true, apply_softplus: bool = False):
+         """Compute per-sample losses based on task type."""
+         task = getattr(self, "task_type", "regression")
+         if task == 'classification':
+             loss_fn = nn.BCEWithLogitsLoss(reduction='none')
+             return loss_fn(y_pred, y_true).view(-1)
+         if apply_softplus:
+             y_pred = F.softplus(y_pred)
+         y_pred = torch.clamp(y_pred, min=1e-6)
+         power = getattr(self, "tw_power", 1.5)
+         return tweedie_loss(y_pred, y_true, p=power).view(-1)
+
+     def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
+         """Compute weighted loss."""
+         losses = self._compute_losses(
+             y_pred, y_true, apply_softplus=apply_softplus)
+         weighted_loss = (losses * weights.view(-1)).sum() / \
+             torch.clamp(weights.sum(), min=EPS)
+         return weighted_loss
+
+     def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model,
+                            ignore_keys: Optional[List[str]] = None):
+         """Update early stopping state."""
+         if val_loss < best_loss:
+             ignore_keys = ignore_keys or []
+             base_module = model.module if hasattr(model, "module") else model
+             state_dict = {
+                 k: (v.clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v))
+                 for k, v in base_module.state_dict().items()
+                 if not any(k.startswith(ignore_key) for ignore_key in ignore_keys)
+             }
+             return val_loss, state_dict, 0, False
+         patience_counter += 1
+         should_stop = best_state is not None and patience_counter >= getattr(
+             self, "patience", 0)
+         return best_loss, best_state, patience_counter, should_stop
+
+     def _train_model(self,
+                      model,
+                      dataloader,
+                      accum_steps,
+                      optimizer,
+                      scaler,
+                      forward_fn,
+                      val_forward_fn=None,
+                      apply_softplus: bool = False,
+                      clip_fn=None,
+                      trial: Optional[optuna.trial.Trial] = None,
+                      loss_curve_path: Optional[str] = None):
+         """Generic training loop with AMP, DDP, and early stopping support.
+
+         Returns:
+             Tuple of (best_state_dict, history)
+         """
+         device_type = self._device_type()
+         best_loss = float('inf')
+         best_state = None
+         patience_counter = 0
+         stop_training = False
+         train_history: List[float] = []
+         val_history: List[float] = []
+
+         is_ddp_model = isinstance(model, DDP)
+
+         for epoch in range(1, getattr(self, "epochs", 1) + 1):
+             epoch_start_ts = time.time()
+             val_weighted_loss = None
+             if hasattr(self, 'dataloader_sampler') and self.dataloader_sampler is not None:
+                 self.dataloader_sampler.set_epoch(epoch)
+
+             model.train()
+             optimizer.zero_grad()
+
+             epoch_loss_sum = None
+             epoch_weight_sum = None
+             for step, batch in enumerate(dataloader):
+                 is_update_step = ((step + 1) % accum_steps == 0) or \
+                     ((step + 1) == len(dataloader))
+                 sync_cm = model.no_sync if (
+                     is_ddp_model and not is_update_step) else nullcontext
+
+                 with sync_cm():
+                     with autocast(enabled=(device_type == 'cuda')):
+                         y_pred, y_true, w = forward_fn(batch)
+                         weighted_loss = self._compute_weighted_loss(
+                             y_pred, y_true, w, apply_softplus=apply_softplus)
+                         loss_for_backward = weighted_loss / accum_steps
+
+                     batch_weight = torch.clamp(
+                         w.detach().sum(), min=EPS).to(dtype=torch.float32)
+                     loss_val = weighted_loss.detach().to(dtype=torch.float32)
+                     if epoch_loss_sum is None:
+                         epoch_loss_sum = torch.zeros(
+                             (), device=batch_weight.device, dtype=torch.float32)
+                         epoch_weight_sum = torch.zeros(
+                             (), device=batch_weight.device, dtype=torch.float32)
+                     epoch_loss_sum = epoch_loss_sum + loss_val * batch_weight
+                     epoch_weight_sum = epoch_weight_sum + batch_weight
+                     scaler.scale(loss_for_backward).backward()
+
+                 if is_update_step:
+                     if clip_fn is not None:
+                         clip_fn()
+                     scaler.step(optimizer)
+                     scaler.update()
+                     optimizer.zero_grad()
+
+             if epoch_loss_sum is None or epoch_weight_sum is None:
+                 train_epoch_loss = 0.0
+             else:
+                 train_epoch_loss = (
+                     epoch_loss_sum / torch.clamp(epoch_weight_sum, min=EPS)
+                 ).item()
+             train_history.append(float(train_epoch_loss))
+
+             if val_forward_fn is not None:
+                 should_compute_val = (not dist.is_initialized()
+                                       or DistributedUtils.is_main_process())
+                 val_device = getattr(self, "device", torch.device("cpu"))
+                 if not isinstance(val_device, torch.device):
+                     val_device = torch.device(val_device)
+                 loss_tensor_device = val_device if device_type == 'cuda' else torch.device(
+                     "cpu")
+                 val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
+
+                 if should_compute_val:
+                     model.eval()
+                     with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
+                         val_result = val_forward_fn()
+                         if isinstance(val_result, tuple) and len(val_result) == 3:
+                             y_val_pred, y_val_true, w_val = val_result
+                             val_weighted_loss = self._compute_weighted_loss(
+                                 y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
+                         else:
+                             val_weighted_loss = val_result
+                     val_loss_tensor[0] = float(val_weighted_loss)
+
+                 if dist.is_initialized():
+                     dist.broadcast(val_loss_tensor, src=0)
+                 val_weighted_loss = float(val_loss_tensor.item())
+
+                 val_history.append(val_weighted_loss)
+
+                 best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
+                     val_weighted_loss, best_loss, best_state, patience_counter, model)
+
+                 prune_flag = False
+                 is_main_rank = DistributedUtils.is_main_process()
+                 if trial is not None and is_main_rank:
+                     trial.report(val_weighted_loss, epoch)
+                     prune_flag = trial.should_prune()
+
+                 if dist.is_initialized():
+                     prune_device = getattr(self, "device", torch.device("cpu"))
+                     if not isinstance(prune_device, torch.device):
+                         prune_device = torch.device(prune_device)
+                     prune_tensor = torch.zeros(1, device=prune_device)
+                     if is_main_rank:
+                         prune_tensor.fill_(1 if prune_flag else 0)
+                     dist.broadcast(prune_tensor, src=0)
+                     prune_flag = bool(prune_tensor.item())
+
+                 if prune_flag:
+                     raise optuna.TrialPruned()
+
+             if stop_training:
+                 break
+
+             should_log_epoch = (not dist.is_initialized()
+                                 or DistributedUtils.is_main_process())
+             if should_log_epoch:
+                 elapsed = int(time.time() - epoch_start_ts)
+                 if val_weighted_loss is None:
+                     print(
+                         f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
+                         f"train_loss={float(train_epoch_loss):.6f} elapsed={elapsed}s",
+                         flush=True,
+                     )
+                 else:
+                     print(
+                         f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
+                         f"train_loss={float(train_epoch_loss):.6f} "
+                         f"val_loss={float(val_weighted_loss):.6f} elapsed={elapsed}s",
+                         flush=True,
+                     )
+
+             if epoch % 10 == 0:
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+                 gc.collect()
+
+         history = {"train": train_history, "val": val_history}
+         self._plot_loss_curve(history, loss_curve_path)
+         return best_state, history
+
+     def _plot_loss_curve(self, history: Dict[str, List[float]], save_path: Optional[str]) -> None:
+         """Plot training and validation loss curves."""
+         if not save_path:
+             return
+         if dist.is_initialized() and not DistributedUtils.is_main_process():
+             return
+         train_hist = history.get("train", []) if history else []
+         val_hist = history.get("val", []) if history else []
+         if not train_hist and not val_hist:
+             return
+         if plot_loss_curve_common is not None:
+             plot_loss_curve_common(
+                 history=history,
+                 title="Loss vs. Epoch",
+                 save_path=save_path,
+                 show=False,
+             )
+         else:
+             if plt is None:
+                 _plot_skip("loss curve")
+                 return
+             ensure_parent_dir(save_path)
+             epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
+             fig = plt.figure(figsize=(8, 4))
+             ax = fig.add_subplot(111)
+             if train_hist:
+                 ax.plot(range(1, len(train_hist) + 1), train_hist,
+                         label='Train Loss', color='tab:blue')
+             if val_hist:
+                 ax.plot(range(1, len(val_hist) + 1), val_hist,
+                         label='Validation Loss', color='tab:orange')
+             ax.set_xlabel('Epoch')
+             ax.set_ylabel('Weighted Loss')
+             ax.set_title('Loss vs. Epoch')
+             ax.grid(True, linestyle='--', alpha=0.3)
+             ax.legend()
+             plt.tight_layout()
+             plt.savefig(save_path, dpi=300)
+             plt.close(fig)
+             print(f"[Training] Loss curve saved to {save_path}")