ins-pricing 0.1.11-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. ins_pricing/README.md +9 -6
  2. ins_pricing/__init__.py +3 -11
  3. ins_pricing/cli/BayesOpt_entry.py +24 -0
  4. ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
  5. ins_pricing/cli/Explain_Run.py +25 -0
  6. ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
  7. ins_pricing/cli/Pricing_Run.py +25 -0
  8. ins_pricing/cli/__init__.py +1 -0
  9. ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
  10. ins_pricing/cli/utils/__init__.py +1 -0
  11. ins_pricing/cli/utils/cli_common.py +320 -0
  12. ins_pricing/cli/utils/cli_config.py +375 -0
  13. ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
  14. {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
  15. ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
  16. ins_pricing/docs/modelling/README.md +34 -0
  17. ins_pricing/modelling/__init__.py +57 -6
  18. ins_pricing/modelling/core/__init__.py +1 -0
  19. ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
  20. ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
  21. ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
  22. ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
  23. ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
  24. ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
  25. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
  26. ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
  27. ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
  28. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
  29. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
  30. ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
  31. ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
  32. ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
  33. ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
  34. ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
  35. ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
  36. ins_pricing/modelling/core/evaluation.py +115 -0
  37. ins_pricing/production/__init__.py +4 -0
  38. ins_pricing/production/preprocess.py +71 -0
  39. ins_pricing/setup.py +10 -5
  40. {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
  41. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
  42. ins_pricing-0.2.0.dist-info/RECORD +125 -0
  43. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
  44. ins_pricing/modelling/BayesOpt_entry.py +0 -633
  45. ins_pricing/modelling/Explain_Run.py +0 -36
  46. ins_pricing/modelling/Pricing_Run.py +0 -36
  47. ins_pricing/modelling/README.md +0 -33
  48. ins_pricing/modelling/bayesopt/models.py +0 -2196
  49. ins_pricing/modelling/bayesopt/trainers.py +0 -2446
  50. ins_pricing/modelling/cli_common.py +0 -136
  51. ins_pricing/modelling/tests/test_plotting.py +0 -63
  52. ins_pricing/modelling/watchdog_run.py +0 -211
  53. ins_pricing-0.1.11.dist-info/RECORD +0 -169
  54. ins_pricing_gemini/__init__.py +0 -23
  55. ins_pricing_gemini/governance/__init__.py +0 -20
  56. ins_pricing_gemini/governance/approval.py +0 -93
  57. ins_pricing_gemini/governance/audit.py +0 -37
  58. ins_pricing_gemini/governance/registry.py +0 -99
  59. ins_pricing_gemini/governance/release.py +0 -159
  60. ins_pricing_gemini/modelling/Explain_Run.py +0 -36
  61. ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
  62. ins_pricing_gemini/modelling/__init__.py +0 -151
  63. ins_pricing_gemini/modelling/cli_common.py +0 -141
  64. ins_pricing_gemini/modelling/config.py +0 -249
  65. ins_pricing_gemini/modelling/config_preprocess.py +0 -254
  66. ins_pricing_gemini/modelling/core.py +0 -741
  67. ins_pricing_gemini/modelling/data_container.py +0 -42
  68. ins_pricing_gemini/modelling/explain/__init__.py +0 -55
  69. ins_pricing_gemini/modelling/explain/gradients.py +0 -334
  70. ins_pricing_gemini/modelling/explain/metrics.py +0 -176
  71. ins_pricing_gemini/modelling/explain/permutation.py +0 -155
  72. ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
  73. ins_pricing_gemini/modelling/features.py +0 -215
  74. ins_pricing_gemini/modelling/model_manager.py +0 -148
  75. ins_pricing_gemini/modelling/model_plotting.py +0 -463
  76. ins_pricing_gemini/modelling/models.py +0 -2203
  77. ins_pricing_gemini/modelling/notebook_utils.py +0 -294
  78. ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
  79. ins_pricing_gemini/modelling/plotting/common.py +0 -63
  80. ins_pricing_gemini/modelling/plotting/curves.py +0 -572
  81. ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
  82. ins_pricing_gemini/modelling/plotting/geo.py +0 -362
  83. ins_pricing_gemini/modelling/plotting/importance.py +0 -121
  84. ins_pricing_gemini/modelling/run_logging.py +0 -133
  85. ins_pricing_gemini/modelling/tests/conftest.py +0 -8
  86. ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
  87. ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
  88. ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
  89. ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
  90. ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
  91. ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
  92. ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
  93. ins_pricing_gemini/modelling/trainers.py +0 -2447
  94. ins_pricing_gemini/modelling/utils.py +0 -1020
  95. ins_pricing_gemini/pricing/__init__.py +0 -27
  96. ins_pricing_gemini/pricing/calibration.py +0 -39
  97. ins_pricing_gemini/pricing/data_quality.py +0 -117
  98. ins_pricing_gemini/pricing/exposure.py +0 -85
  99. ins_pricing_gemini/pricing/factors.py +0 -91
  100. ins_pricing_gemini/pricing/monitoring.py +0 -99
  101. ins_pricing_gemini/pricing/rate_table.py +0 -78
  102. ins_pricing_gemini/production/__init__.py +0 -21
  103. ins_pricing_gemini/production/drift.py +0 -30
  104. ins_pricing_gemini/production/monitoring.py +0 -143
  105. ins_pricing_gemini/production/scoring.py +0 -40
  106. ins_pricing_gemini/reporting/__init__.py +0 -11
  107. ins_pricing_gemini/reporting/report_builder.py +0 -72
  108. ins_pricing_gemini/reporting/scheduler.py +0 -45
  109. ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
  110. ins_pricing_gemini/scripts/Explain_entry.py +0 -545
  111. ins_pricing_gemini/scripts/__init__.py +0 -1
  112. ins_pricing_gemini/scripts/train.py +0 -568
  113. ins_pricing_gemini/setup.py +0 -55
  114. ins_pricing_gemini/smoke_test.py +0 -28
  115. /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
  116. /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
  117. /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
  118. /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
  119. /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
  120. /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
  121. /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
  122. /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
  123. /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
  124. /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
  125. /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
  126. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
ins_pricing/modelling/core/bayesopt/models/model_resn.py
@@ -0,0 +1,435 @@
+ from __future__ import annotations
+
+ from typing import Dict, List, Optional
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ from torch.cuda.amp import GradScaler
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.nn.utils import clip_grad_norm_
+ from torch.utils.data import TensorDataset
+
+ from ..utils import DistributedUtils, EPS, TorchTrainerMixin
+
+
+ # =============================================================================
+ # ResNet model and sklearn-style wrapper
+ # =============================================================================
+
+ # ResNet model definition
+ # Residual block: two linear layers + ReLU + residual connection
+ # ResBlock inherits nn.Module
+ class ResBlock(nn.Module):
+     def __init__(self, dim: int, dropout: float = 0.1,
+                  use_layernorm: bool = False, residual_scale: float = 0.1,
+                  stochastic_depth: float = 0.0
+                  ):
+         super().__init__()
+         self.use_layernorm = use_layernorm
+
+         if use_layernorm:
+             Norm = nn.LayerNorm  # Normalize the last dimension
+         else:
+             def Norm(d): return nn.BatchNorm1d(d)  # Keep a switch to try BN
+
+         self.norm1 = Norm(dim)
+         self.fc1 = nn.Linear(dim, dim, bias=True)
+         self.act = nn.ReLU(inplace=True)
+         self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
+         # Enable post-second-layer norm if needed: self.norm2 = Norm(dim)
+         self.fc2 = nn.Linear(dim, dim, bias=True)
+
+         # Residual scaling to stabilize early training
+         self.res_scale = nn.Parameter(
+             torch.tensor(residual_scale, dtype=torch.float32)
+         )
+         self.stochastic_depth = max(0.0, float(stochastic_depth))
+
+     def _drop_path(self, x: torch.Tensor) -> torch.Tensor:
+         if self.stochastic_depth <= 0.0 or not self.training:
+             return x
+         keep_prob = 1.0 - self.stochastic_depth
+         if keep_prob <= 0.0:
+             return torch.zeros_like(x)
+         shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+         random_tensor = keep_prob + torch.rand(
+             shape, dtype=x.dtype, device=x.device)
+         binary_tensor = torch.floor(random_tensor)
+         return x * binary_tensor / keep_prob
+
+     def forward(self, x):
+         # Pre-activation structure
+         out = self.norm1(x)
+         out = self.fc1(out)
+         out = self.act(out)
+         out = self.dropout(out)
+         # If a second norm is enabled: out = self.norm2(out)
+         out = self.fc2(out)
+         # Apply residual scaling then add
+         out = self.res_scale * out
+         out = self._drop_path(out)
+         return x + out
+
+ # ResNetSequential defines the full network
+
+
+ class ResNetSequential(nn.Module):
+     # Input shape: (batch, input_dim)
+     # Network: FC + norm + ReLU, stack residual blocks, output Softplus
+
+     def __init__(self, input_dim: int, hidden_dim: int = 64, block_num: int = 2,
+                  use_layernorm: bool = True, dropout: float = 0.1,
+                  residual_scale: float = 0.1, stochastic_depth: float = 0.0,
+                  task_type: str = 'regression'):
+         super(ResNetSequential, self).__init__()
+
+         self.net = nn.Sequential()
+         self.net.add_module('fc1', nn.Linear(input_dim, hidden_dim))
+
+         # Optional explicit normalization after the first layer:
+         # For LayerNorm:
+         # self.net.add_module('norm1', nn.LayerNorm(hidden_dim))
+         # Or BatchNorm:
+         # self.net.add_module('norm1', nn.BatchNorm1d(hidden_dim))
+
+         # If desired, insert ReLU before residual blocks:
+         # self.net.add_module('relu1', nn.ReLU(inplace=True))
+
+         # Residual blocks
+         drop_path_rate = max(0.0, float(stochastic_depth))
+         for i in range(block_num):
+             if block_num > 1:
+                 block_drop = drop_path_rate * (i / (block_num - 1))
+             else:
+                 block_drop = drop_path_rate
+             self.net.add_module(
+                 f'ResBlk_{i+1}',
+                 ResBlock(
+                     hidden_dim,
+                     dropout=dropout,
+                     use_layernorm=use_layernorm,
+                     residual_scale=residual_scale,
+                     stochastic_depth=block_drop)
+             )
+
+         self.net.add_module('fc_out', nn.Linear(hidden_dim, 1))
+
+         if task_type == 'classification':
+             self.net.add_module('softplus', nn.Identity())
+         else:
+             self.net.add_module('softplus', nn.Softplus())
+
+     def forward(self, x):
+         if self.training and not hasattr(self, '_printed_device'):
+             print(f">>> ResNetSequential executing on device: {x.device}")
+             self._printed_device = True
+         return self.net(x)
+
+ # Define the ResNet sklearn-style wrapper.
+
+
+ class ResNetSklearn(TorchTrainerMixin, nn.Module):
+     def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
+                  block_num: int = 2, batch_num: int = 100, epochs: int = 100,
+                  task_type: str = 'regression',
+                  tweedie_power: float = 1.5, learning_rate: float = 0.01, patience: int = 10,
+                  use_layernorm: bool = True, dropout: float = 0.1,
+                  residual_scale: float = 0.1,
+                  stochastic_depth: float = 0.0,
+                  weight_decay: float = 1e-4,
+                  use_data_parallel: bool = True,
+                  use_ddp: bool = False):
+         super(ResNetSklearn, self).__init__()
+
+         self.use_ddp = use_ddp
+         self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
+             False, 0, 0, 1)
+
+         if self.use_ddp:
+             self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()
+
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+         self.block_num = block_num
+         self.batch_num = batch_num
+         self.epochs = epochs
+         self.task_type = task_type
+         self.model_nme = model_nme
+         self.learning_rate = learning_rate
+         self.weight_decay = weight_decay
+         self.patience = patience
+         self.use_layernorm = use_layernorm
+         self.dropout = dropout
+         self.residual_scale = residual_scale
+         self.stochastic_depth = max(0.0, float(stochastic_depth))
+         self.loss_curve_path: Optional[str] = None
+         self.training_history: Dict[str, List[float]] = {
+             "train": [], "val": []}
+         self.use_data_parallel = bool(use_data_parallel)
+
+         # Device selection: cuda > mps > cpu
+         if self.is_ddp_enabled:
+             self.device = torch.device(f'cuda:{self.local_rank}')
+         elif torch.cuda.is_available():
+             self.device = torch.device('cuda')
+         elif torch.backends.mps.is_available():
+             self.device = torch.device('mps')
+         else:
+             self.device = torch.device('cpu')
+
+         # Tweedie power (unused for classification)
+         if self.task_type == 'classification':
+             self.tw_power = None
+         elif 'f' in self.model_nme:
+             self.tw_power = 1
+         elif 's' in self.model_nme:
+             self.tw_power = 2
+         else:
+             self.tw_power = tweedie_power
+
+         # Build network (construct on CPU first)
+         core = ResNetSequential(
+             self.input_dim,
+             self.hidden_dim,
+             self.block_num,
+             use_layernorm=self.use_layernorm,
+             dropout=self.dropout,
+             residual_scale=self.residual_scale,
+             stochastic_depth=self.stochastic_depth,
+             task_type=self.task_type
+         )
+
+         # ===== Multi-GPU: DataParallel vs DistributedDataParallel =====
+         if self.is_ddp_enabled:
+             core = core.to(self.device)
+             core = DDP(core, device_ids=[
+                 self.local_rank], output_device=self.local_rank)
+             self.use_data_parallel = False
+         elif use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
+             if self.use_ddp and not self.is_ddp_enabled:
+                 print(
+                     ">>> DDP requested but not initialized; falling back to DataParallel.")
+             core = nn.DataParallel(core, device_ids=list(
+                 range(torch.cuda.device_count())))
+             # DataParallel scatters inputs, but the primary device remains cuda:0.
+             self.device = torch.device('cuda')
+             self.use_data_parallel = True
+         else:
+             self.use_data_parallel = False
+
+         self.resnet = core.to(self.device)
+
+     # ================ Internal helpers ================
+     @staticmethod
+     def _validate_vector(arr, name: str, n_rows: int) -> None:
+         if arr is None:
+             return
+         if isinstance(arr, pd.DataFrame):
+             if arr.shape[1] != 1:
+                 raise ValueError(f"{name} must be 1d (single column).")
+             length = len(arr)
+         else:
+             arr_np = np.asarray(arr)
+             if arr_np.ndim == 0:
+                 raise ValueError(f"{name} must be 1d.")
+             if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
+                 raise ValueError(f"{name} must be 1d or Nx1.")
+             length = arr_np.shape[0]
+         if length != n_rows:
+             raise ValueError(
+                 f"{name} length {length} does not match X length {n_rows}."
+             )
+
+     def _validate_inputs(self, X, y, w, label: str) -> None:
+         if X is None:
+             raise ValueError(f"{label} X cannot be None.")
+         n_rows = len(X)
+         if y is None:
+             raise ValueError(f"{label} y cannot be None.")
+         self._validate_vector(y, f"{label} y", n_rows)
+         self._validate_vector(w, f"{label} w", n_rows)
+
+     def _build_train_val_tensors(self, X_train, y_train, w_train, X_val, y_val, w_val):
+         self._validate_inputs(X_train, y_train, w_train, "train")
+         if X_val is not None or y_val is not None or w_val is not None:
+             if X_val is None or y_val is None:
+                 raise ValueError("validation X and y must both be provided.")
+             self._validate_inputs(X_val, y_val, w_val, "val")
+
+         def _to_numpy(arr):
+             if hasattr(arr, "to_numpy"):
+                 return arr.to_numpy(dtype=np.float32, copy=False)
+             return np.asarray(arr, dtype=np.float32)
+
+         X_tensor = torch.as_tensor(_to_numpy(X_train))
+         y_tensor = torch.as_tensor(_to_numpy(y_train)).view(-1, 1)
+         w_tensor = (
+             torch.as_tensor(_to_numpy(w_train)).view(-1, 1)
+             if w_train is not None else torch.ones_like(y_tensor)
+         )
+
+         has_val = X_val is not None and y_val is not None
+         if has_val:
+             X_val_tensor = torch.as_tensor(_to_numpy(X_val))
+             y_val_tensor = torch.as_tensor(_to_numpy(y_val)).view(-1, 1)
+             w_val_tensor = (
+                 torch.as_tensor(_to_numpy(w_val)).view(-1, 1)
+                 if w_val is not None else torch.ones_like(y_val_tensor)
+             )
+         else:
+             X_val_tensor = y_val_tensor = w_val_tensor = None
+         return X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val
+
+     def forward(self, x):
+         # Handle SHAP NumPy input.
+         if isinstance(x, np.ndarray):
+             x_tensor = torch.as_tensor(x, dtype=torch.float32)
+         else:
+             x_tensor = x
+
+         x_tensor = x_tensor.to(self.device)
+         y_pred = self.resnet(x_tensor)
+         return y_pred
+
+     # ---------------- Training ----------------
+
+     def fit(self, X_train, y_train, w_train=None,
+             X_val=None, y_val=None, w_val=None, trial=None):
+
+         X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val = \
+             self._build_train_val_tensors(
+                 X_train, y_train, w_train, X_val, y_val, w_val)
+
+         dataset = TensorDataset(X_tensor, y_tensor, w_tensor)
+         dataloader, accum_steps = self._build_dataloader(
+             dataset,
+             N=X_tensor.shape[0],
+             base_bs_gpu=(2048, 1024, 512),
+             base_bs_cpu=(256, 128),
+             min_bs=64,
+             target_effective_cuda=2048,
+             target_effective_cpu=1024
+         )
+
+         # Set sampler epoch at the start of each epoch to keep shuffling deterministic.
+         if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
+             self.dataloader_sampler = dataloader.sampler
+         else:
+             self.dataloader_sampler = None
+
+         # === 4. Optimizer and AMP ===
+         self.optimizer = torch.optim.Adam(
+             self.resnet.parameters(),
+             lr=self.learning_rate,
+             weight_decay=float(self.weight_decay),
+         )
+         self.scaler = GradScaler(enabled=(self.device.type == 'cuda'))
+
+         X_val_dev = y_val_dev = w_val_dev = None
+         val_dataloader = None
+         if has_val:
+             # Build validation DataLoader.
+             val_dataset = TensorDataset(
+                 X_val_tensor, y_val_tensor, w_val_tensor)
+             # No backward pass in validation; batch size can be larger for throughput.
+             val_dataloader = self._build_val_dataloader(
+                 val_dataset, dataloader, accum_steps)
+             # Validation usually does not need a DDP sampler because we validate on the main process
+             # or aggregate results. For simplicity, keep validation on a single GPU or the main process.
+
+         is_data_parallel = isinstance(self.resnet, nn.DataParallel)
+
+         def forward_fn(batch):
+             X_batch, y_batch, w_batch = batch
+
+             if not is_data_parallel:
+                 X_batch = X_batch.to(self.device, non_blocking=True)
+             # Keep targets and weights on the main device for loss computation.
+             y_batch = y_batch.to(self.device, non_blocking=True)
+             w_batch = w_batch.to(self.device, non_blocking=True)
+
+             y_pred = self.resnet(X_batch)
+             return y_pred, y_batch, w_batch
+
+         def val_forward_fn():
+             total_loss = 0.0
+             total_weight = 0.0
+             for batch in val_dataloader:
+                 X_b, y_b, w_b = batch
+                 if not is_data_parallel:
+                     X_b = X_b.to(self.device, non_blocking=True)
+                 y_b = y_b.to(self.device, non_blocking=True)
+                 w_b = w_b.to(self.device, non_blocking=True)
+
+                 y_pred = self.resnet(X_b)
+
+                 # Manually compute weighted loss for accurate aggregation.
+                 losses = self._compute_losses(
+                     y_pred, y_b, apply_softplus=False)
+
+                 batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
+                 batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()
+
+                 total_loss += batch_weighted_loss_sum.item()
+                 total_weight += batch_weight_sum.item()
+
+             return total_loss / max(total_weight, EPS)
+
+         clip_fn = None
+         if self.device.type == 'cuda':
+             def clip_fn(): return (self.scaler.unscale_(self.optimizer),
+                                    clip_grad_norm_(self.resnet.parameters(), max_norm=1.0))
+
+         # Under DDP, only the main process prints logs and saves models.
+         if self.is_ddp_enabled and not DistributedUtils.is_main_process():
+             # Non-main processes skip validation callback logging (handled inside _train_model).
+             pass
+
+         best_state, history = self._train_model(
+             self.resnet,
+             dataloader,
+             accum_steps,
+             self.optimizer,
+             self.scaler,
+             forward_fn,
+             val_forward_fn if has_val else None,
+             apply_softplus=False,
+             clip_fn=clip_fn,
+             trial=trial,
+             loss_curve_path=getattr(self, "loss_curve_path", None)
+         )
+
+         if has_val and best_state is not None:
+             self.resnet.load_state_dict(best_state)
+         self.training_history = history
+
+     # ---------------- Prediction ----------------
+
+     def predict(self, X_test):
+         self.resnet.eval()
+         if isinstance(X_test, pd.DataFrame):
+             X_np = X_test.to_numpy(dtype=np.float32, copy=False)
+         else:
+             X_np = np.asarray(X_test, dtype=np.float32)
+
+         inference_cm = getattr(torch, "inference_mode", torch.no_grad)
+         with inference_cm():
+             y_pred = self(X_np).cpu().numpy()
+
+         if self.task_type == 'classification':
+             y_pred = 1 / (1 + np.exp(-y_pred))  # Sigmoid converts logits to probabilities.
+         else:
+             y_pred = np.clip(y_pred, 1e-6, None)
+         return y_pred.flatten()
+
+     # ---------------- Set Params ----------------
+
+     def set_params(self, params):
+         for key, value in params.items():
+             if hasattr(self, key):
+                 setattr(self, key, value)
+             else:
+                 raise ValueError(f"Parameter {key} not found in model.")
+         return self
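
For orientation, here is a minimal, hypothetical smoke test of the ResNetSklearn wrapper added above. The constructor, fit, and predict signatures are taken from the diff; the import path is only assumed to mirror the new file layout listed above, and the synthetic data, feature count, and hyperparameter values are illustrative rather than taken from the package.

# Hypothetical usage sketch; import path assumed from the file layout in this diff.
import numpy as np
from ins_pricing.modelling.core.bayesopt.models.model_resn import ResNetSklearn

rng = np.random.default_rng(0)
X = rng.normal(size=(512, 12)).astype(np.float32)                 # 12 illustrative features
y = rng.gamma(shape=2.0, scale=0.5, size=512).astype(np.float32)  # strictly positive target
w = np.ones(512, dtype=np.float32)                                # unit exposure weights

model = ResNetSklearn(
    model_nme="demo_s",       # 's' in the name selects tw_power = 2 (severity-style Tweedie)
    input_dim=X.shape[1],
    hidden_dim=32,
    block_num=2,
    epochs=5,                 # keep the smoke test short
    use_data_parallel=False,  # single-device run
)
model.fit(X[:400], y[:400], w[:400],
          X_val=X[400:], y_val=y[400:], w_val=w[400:])
preds = model.predict(X[400:])   # flattened, clipped to be non-negative
print(preds[:5])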
ins_pricing/modelling/core/bayesopt/trainers/__init__.py
@@ -0,0 +1,19 @@
+ """Trainer implementations split by model type."""
+ from __future__ import annotations
+
+ from .trainer_base import TrainerBase
+ from .trainer_ft import FTTrainer
+ from .trainer_glm import GLMTrainer
+ from .trainer_gnn import GNNTrainer
+ from .trainer_resn import ResNetTrainer
+ from .trainer_xgb import XGBTrainer, _xgb_cuda_available
+
+ __all__ = [
+     "TrainerBase",
+     "FTTrainer",
+     "GLMTrainer",
+     "GNNTrainer",
+     "ResNetTrainer",
+     "XGBTrainer",
+     "_xgb_cuda_available",
+ ]
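
This trainers subpackage replaces the flat ins_pricing/modelling/bayesopt/trainers.py module removed in this release (see the file list above). Based on the __all__ exported here, downstream imports would presumably move along the lines sketched below; the old import path is an assumption for contrast, since the 0.1.11 export names are not visible in this diff.

# 0.1.11 (assumed): trainers lived in a single flat module.
# from ins_pricing.modelling.bayesopt.trainers import ResNetTrainer, XGBTrainer

# 0.2.0: one module per model type, re-exported by the __init__.py shown above.
from ins_pricing.modelling.core.bayesopt.trainers import (
    TrainerBase,
    GLMTrainer,
    ResNetTrainer,
    XGBTrainer,
)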