ins-pricing 0.4.5-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +52 -50
  4. ins_pricing/cli/BayesOpt_incremental.py +39 -105
  5. ins_pricing/cli/Explain_Run.py +31 -23
  6. ins_pricing/cli/Explain_entry.py +532 -579
  7. ins_pricing/cli/Pricing_Run.py +31 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +11 -9
  9. ins_pricing/cli/utils/cli_common.py +256 -256
  10. ins_pricing/cli/utils/cli_config.py +375 -375
  11. ins_pricing/cli/utils/import_resolver.py +382 -365
  12. ins_pricing/cli/utils/notebook_utils.py +340 -340
  13. ins_pricing/cli/watchdog_run.py +209 -201
  14. ins_pricing/frontend/__init__.py +10 -10
  15. ins_pricing/frontend/example_workflows.py +1 -1
  16. ins_pricing/governance/__init__.py +20 -20
  17. ins_pricing/governance/release.py +159 -159
  18. ins_pricing/modelling/__init__.py +147 -92
  19. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +2 -2
  20. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  21. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -562
  22. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -964
  23. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  24. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
  29. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
  36. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  37. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  38. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  39. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  40. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
  42. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  43. ins_pricing/modelling/explain/__init__.py +55 -55
  44. ins_pricing/modelling/explain/metrics.py +27 -174
  45. ins_pricing/modelling/explain/permutation.py +237 -237
  46. ins_pricing/modelling/plotting/__init__.py +40 -36
  47. ins_pricing/modelling/plotting/compat.py +228 -0
  48. ins_pricing/modelling/plotting/curves.py +572 -572
  49. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  50. ins_pricing/modelling/plotting/geo.py +362 -362
  51. ins_pricing/modelling/plotting/importance.py +121 -121
  52. ins_pricing/pricing/__init__.py +27 -27
  53. ins_pricing/production/__init__.py +35 -25
  54. ins_pricing/production/{predict.py → inference.py} +140 -57
  55. ins_pricing/production/monitoring.py +8 -21
  56. ins_pricing/reporting/__init__.py +11 -11
  57. ins_pricing/setup.py +1 -1
  58. ins_pricing/tests/production/test_inference.py +90 -0
  59. ins_pricing/utils/__init__.py +116 -83
  60. ins_pricing/utils/device.py +255 -255
  61. ins_pricing/utils/features.py +53 -0
  62. ins_pricing/utils/io.py +72 -0
  63. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  64. ins_pricing/utils/metrics.py +158 -24
  65. ins_pricing/utils/numerics.py +76 -0
  66. ins_pricing/utils/paths.py +9 -1
  67. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +182 -182
  68. ins_pricing-0.5.0.dist-info/RECORD +131 -0
  69. ins_pricing/modelling/core/BayesOpt.py +0 -146
  70. ins_pricing/modelling/core/__init__.py +0 -1
  71. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  72. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  73. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  74. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  75. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  76. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  77. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  78. ins_pricing/tests/production/test_predict.py +0 -233
  79. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  80. /ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
  81. /ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
  82. /ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
  83. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
  84. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
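
The largest structural change in 0.5.0 is the package flattening: `ins_pricing.modelling.core.bayesopt` becomes `ins_pricing.modelling.bayesopt`, several shared helpers (losses, metrics, io) move under `ins_pricing.utils`, and `production/predict.py` becomes `production/inference.py`. As a rough illustration of what this means for downstream imports (a sketch based only on the renames listed above; this diff does not show whether 0.5.0 ships compatibility aliases for the old paths):

# 0.4.5 layout (old paths, removed in this release):
#   from ins_pricing.modelling.core.bayesopt.trainers.trainer_base import TrainerBase
#   from ins_pricing.modelling.core.bayesopt.utils.losses import regression_loss
# 0.5.0 layout (new paths, as they appear in the trainer_resn.py hunk below):
from ins_pricing.modelling.bayesopt.trainers.trainer_base import TrainerBase
from ins_pricing.modelling.bayesopt.models import ResNetSklearn
from ins_pricing.utils.losses import regression_loss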
ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py
@@ -1,283 +1,283 @@
- from __future__ import annotations
-
- import os
- from typing import Any, Dict, List, Optional, Tuple
-
- import numpy as np
- import optuna
- import torch
- from sklearn.metrics import log_loss
-
- from .trainer_base import TrainerBase
- from ..models import ResNetSklearn
- from ..utils.losses import regression_loss
-
- class ResNetTrainer(TrainerBase):
-     def __init__(self, context: "BayesOptModel") -> None:
-         if context.task_type == 'classification':
-             super().__init__(context, 'ResNetClassifier', 'ResNet')
-         else:
-             super().__init__(context, 'ResNet', 'ResNet')
-         self.model: Optional[ResNetSklearn] = None
-         self.enable_distributed_optuna = bool(context.config.use_resn_ddp)
-
-     def _resolve_input_dim(self) -> int:
-         data = getattr(self.ctx, "train_oht_scl_data", None)
-         if data is not None and getattr(self.ctx, "var_nmes", None):
-             return int(data[self.ctx.var_nmes].shape[1])
-         return int(len(self.ctx.var_nmes or []))
-
-     def _build_model(self, params: Optional[Dict[str, Any]] = None) -> ResNetSklearn:
-         params = params or {}
-         loss_name = getattr(self.ctx, "loss_name", "tweedie")
-         power = params.get("tw_power")
-         if self.ctx.task_type == "regression":
-             base_tw = self.ctx.default_tweedie_power()
-             if loss_name == "tweedie":
-                 power = base_tw if power is None else float(power)
-             elif loss_name in ("poisson", "gamma"):
-                 power = base_tw
-             else:
-                 power = None
-         resn_weight_decay = float(
-             params.get(
-                 "weight_decay",
-                 getattr(self.ctx.config, "resn_weight_decay", 1e-4),
-             )
-         )
-         model = ResNetSklearn(
-             model_nme=self.ctx.model_nme,
-             input_dim=self._resolve_input_dim(),
-             hidden_dim=int(params.get("hidden_dim", 64)),
-             block_num=int(params.get("block_num", 2)),
-             task_type=self.ctx.task_type,
-             epochs=self.ctx.epochs,
-             tweedie_power=power,
-             learning_rate=float(params.get("learning_rate", 0.01)),
-             patience=int(params.get("patience", 10)),
-             use_layernorm=True,
-             dropout=float(params.get("dropout", 0.1)),
-             residual_scale=float(params.get("residual_scale", 0.1)),
-             stochastic_depth=float(params.get("stochastic_depth", 0.0)),
-             weight_decay=resn_weight_decay,
-             use_data_parallel=self.ctx.config.use_resn_data_parallel,
-             use_ddp=self.ctx.config.use_resn_ddp,
-             loss_name=loss_name
-         )
-         return self._apply_dataloader_overrides(model)
-
-     # ========= Cross-validation (for BayesOpt) =========
-     def cross_val(self, trial: optuna.trial.Trial) -> float:
-         # ResNet CV focuses on memory control:
-         # - Create a ResNetSklearn per fold and release it immediately after.
-         # - Move model to CPU, delete, and call gc/empty_cache after each fold.
-         # - Optionally sample part of training data during BayesOpt to reduce memory.
-
-         base_tw_power = self.ctx.default_tweedie_power()
-         loss_name = getattr(self.ctx, "loss_name", "tweedie")
-
-         def data_provider():
-             data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
-             assert data is not None, "Preprocessed training data is missing."
-             return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
-
-         metric_ctx: Dict[str, Any] = {}
-
-         def model_builder(params):
-             if loss_name == "tweedie":
-                 power = params.get("tw_power", base_tw_power)
-             elif loss_name in ("poisson", "gamma"):
-                 power = base_tw_power
-             else:
-                 power = None
-             metric_ctx["tw_power"] = power
-             params_local = dict(params)
-             if power is not None:
-                 params_local["tw_power"] = power
-             return self._build_model(params_local)
-
-         def preprocess_fn(X_train, X_val):
-             X_train_s, X_val_s, _ = self._standardize_fold(
-                 X_train, X_val, self.ctx.num_features)
-             return X_train_s, X_val_s
-
-         def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
-             model.fit(
-                 X_train, y_train, w_train,
-                 X_val, y_val, w_val,
-                 trial=trial_obj
-             )
-             return model.predict(X_val)
-
-         def metric_fn(y_true, y_pred, weight):
-             if self.ctx.task_type == 'regression':
-                 return regression_loss(
-                     y_true,
-                     y_pred,
-                     weight,
-                     loss_name=loss_name,
-                     tweedie_power=metric_ctx.get("tw_power", base_tw_power),
-                 )
-             return log_loss(y_true, y_pred, sample_weight=weight)
-
-         sample_cap = data_provider()[0]
-         max_rows_for_resnet_bo = min(100000, int(len(sample_cap)/5))
-
-         return self.cross_val_generic(
-             trial=trial,
-             hyperparameter_space={
-                 "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-6, 1e-2, log=True),
-                 "hidden_dim": lambda t: t.suggest_int('hidden_dim', 8, 32, step=2),
-                 "block_num": lambda t: t.suggest_int('block_num', 2, 10),
-                 "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3, step=0.05),
-                 "residual_scale": lambda t: t.suggest_float('residual_scale', 0.05, 0.3, step=0.05),
-                 "patience": lambda t: t.suggest_int('patience', 3, 12),
-                 "stochastic_depth": lambda t: t.suggest_float('stochastic_depth', 0.0, 0.2, step=0.05),
-                 **({"tw_power": lambda t: t.suggest_float('tw_power', 1.0, 2.0)} if self.ctx.task_type == 'regression' and loss_name == 'tweedie' else {})
-             },
-             data_provider=data_provider,
-             model_builder=model_builder,
-             metric_fn=metric_fn,
-             sample_limit=max_rows_for_resnet_bo if len(
-                 sample_cap) > max_rows_for_resnet_bo > 0 else None,
-             preprocess_fn=preprocess_fn,
-             fit_predict_fn=fit_predict,
-             cleanup_fn=lambda m: getattr(
-                 getattr(m, "resnet", None), "to", lambda *_args, **_kwargs: None)("cpu")
-         )
-
-     # ========= Train final ResNet with best hyperparameters =========
-     def train(self) -> None:
-         if not self.best_params:
-             raise RuntimeError("Run tune() first to obtain best ResNet parameters.")
-
-         params = dict(self.best_params)
-         use_refit = bool(getattr(self.ctx.config, "final_refit", True))
-         data = self.ctx.train_oht_scl_data
-         if data is None:
-             raise RuntimeError("Missing standardized data for ResNet training.")
-         X_all = data[self.ctx.var_nmes]
-         y_all = data[self.ctx.resp_nme]
-         w_all = data[self.ctx.weight_nme]
-
-         refit_epochs = None
-         split = self._resolve_train_val_indices(X_all)
-         if use_refit and split is not None:
-             train_idx, val_idx = split
-             tmp_model = self._build_model(params)
-             tmp_model.fit(
-                 X_all.iloc[train_idx],
-                 y_all.iloc[train_idx],
-                 w_all.iloc[train_idx],
-                 X_all.iloc[val_idx],
-                 y_all.iloc[val_idx],
-                 w_all.iloc[val_idx],
-                 trial=None,
-             )
-             refit_epochs = self._resolve_best_epoch(
-                 getattr(tmp_model, "training_history", None),
-                 default_epochs=int(self.ctx.epochs),
-             )
-             getattr(getattr(tmp_model, "resnet", None), "to",
-                     lambda *_args, **_kwargs: None)("cpu")
-             self._clean_gpu()
-
-         self.model = self._build_model(params)
-         if refit_epochs is not None:
-             self.model.epochs = int(refit_epochs)
-         self.best_params = params
-         loss_plot_path = self.output.plot_path(
-             f'{self.ctx.model_nme}/loss/loss_{self.ctx.model_nme}_{self.model_name_prefix}.png')
-         self.model.loss_curve_path = loss_plot_path
-
-         self._fit_predict_cache(
-             self.model,
-             X_all,
-             y_all,
-             sample_weight=w_all,
-             pred_prefix='resn',
-             use_oht=True,
-             sample_weight_arg='w_train'
-         )
-
-         # Convenience wrapper for external callers.
-         self.ctx.resn_best = self.model
-
-     def ensemble_predict(self, k: int) -> None:
-         if not self.best_params:
-             raise RuntimeError("Run tune() first to obtain best ResNet parameters.")
-         data = self.ctx.train_oht_scl_data
-         test_data = self.ctx.test_oht_scl_data
-         if data is None or test_data is None:
-             raise RuntimeError("Missing standardized data for ResNet ensemble.")
-         X_all = data[self.ctx.var_nmes]
-         y_all = data[self.ctx.resp_nme]
-         w_all = data[self.ctx.weight_nme]
-         X_test = test_data[self.ctx.var_nmes]
-
-         k = max(2, int(k))
-         n_samples = len(X_all)
-         split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
-         if split_iter is None:
-             print(
-                 f"[ResNet Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
-                 flush=True,
-             )
-             return
-         preds_train_sum = np.zeros(n_samples, dtype=np.float64)
-         preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
-
-         split_count = 0
-         for train_idx, val_idx in split_iter:
-             model = self._build_model(self.best_params)
-             model.fit(
-                 X_all.iloc[train_idx],
-                 y_all.iloc[train_idx],
-                 w_all.iloc[train_idx],
-                 X_all.iloc[val_idx],
-                 y_all.iloc[val_idx],
-                 w_all.iloc[val_idx],
-                 trial=None,
-             )
-             pred_train = model.predict(X_all)
-             pred_test = model.predict(X_test)
-             preds_train_sum += np.asarray(pred_train, dtype=np.float64)
-             preds_test_sum += np.asarray(pred_test, dtype=np.float64)
-             getattr(getattr(model, "resnet", None), "to",
-                     lambda *_args, **_kwargs: None)("cpu")
-             self._clean_gpu()
-             split_count += 1
-
-         if split_count < 1:
-             print(
-                 f"[ResNet Ensemble] no CV splits generated; skip ensemble.",
-                 flush=True,
-             )
-             return
-         preds_train = preds_train_sum / float(split_count)
-         preds_test = preds_test_sum / float(split_count)
-         self._cache_predictions("resn", preds_train, preds_test)
-
-     # ========= Save / Load =========
-     # ResNet is saved as state_dict and needs a custom load path.
-     # Save logic is implemented in TrainerBase (checks .resnet attribute).
-
-     def load(self) -> None:
-         # Load ResNet weights to the current device to match context.
-         path = self.output.model_path(self._get_model_filename())
-         if os.path.exists(path):
-             payload = torch.load(path, map_location='cpu')
-             if isinstance(payload, dict) and "state_dict" in payload:
-                 state_dict = payload.get("state_dict")
-                 params = payload.get("best_params") or self.best_params
-             else:
-                 state_dict = payload
-                 params = self.best_params
-             resn_loaded = self._build_model(params)
-             resn_loaded.resnet.load_state_dict(state_dict)
-
-             self._move_to_device(resn_loaded)
-             self.model = resn_loaded
-             self.ctx.resn_best = self.model
-         else:
-             print(f"[ResNetTrainer.load] Model file not found: {path}")
+ from __future__ import annotations
+
+ import os
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import optuna
+ import torch
+ from sklearn.metrics import log_loss
+
+ from ins_pricing.modelling.bayesopt.trainers.trainer_base import TrainerBase
+ from ins_pricing.modelling.bayesopt.models import ResNetSklearn
+ from ins_pricing.utils.losses import regression_loss
+
+ class ResNetTrainer(TrainerBase):
+     def __init__(self, context: "BayesOptModel") -> None:
+         if context.task_type == 'classification':
+             super().__init__(context, 'ResNetClassifier', 'ResNet')
+         else:
+             super().__init__(context, 'ResNet', 'ResNet')
+         self.model: Optional[ResNetSklearn] = None
+         self.enable_distributed_optuna = bool(context.config.use_resn_ddp)
+
+     def _resolve_input_dim(self) -> int:
+         data = getattr(self.ctx, "train_oht_scl_data", None)
+         if data is not None and getattr(self.ctx, "var_nmes", None):
+             return int(data[self.ctx.var_nmes].shape[1])
+         return int(len(self.ctx.var_nmes or []))
+
+     def _build_model(self, params: Optional[Dict[str, Any]] = None) -> ResNetSklearn:
+         params = params or {}
+         loss_name = getattr(self.ctx, "loss_name", "tweedie")
+         power = params.get("tw_power")
+         if self.ctx.task_type == "regression":
+             base_tw = self.ctx.default_tweedie_power()
+             if loss_name == "tweedie":
+                 power = base_tw if power is None else float(power)
+             elif loss_name in ("poisson", "gamma"):
+                 power = base_tw
+             else:
+                 power = None
+         resn_weight_decay = float(
+             params.get(
+                 "weight_decay",
+                 getattr(self.ctx.config, "resn_weight_decay", 1e-4),
+             )
+         )
+         model = ResNetSklearn(
+             model_nme=self.ctx.model_nme,
+             input_dim=self._resolve_input_dim(),
+             hidden_dim=int(params.get("hidden_dim", 64)),
+             block_num=int(params.get("block_num", 2)),
+             task_type=self.ctx.task_type,
+             epochs=self.ctx.epochs,
+             tweedie_power=power,
+             learning_rate=float(params.get("learning_rate", 0.01)),
+             patience=int(params.get("patience", 10)),
+             use_layernorm=True,
+             dropout=float(params.get("dropout", 0.1)),
+             residual_scale=float(params.get("residual_scale", 0.1)),
+             stochastic_depth=float(params.get("stochastic_depth", 0.0)),
+             weight_decay=resn_weight_decay,
+             use_data_parallel=self.ctx.config.use_resn_data_parallel,
+             use_ddp=self.ctx.config.use_resn_ddp,
+             loss_name=loss_name
+         )
+         return self._apply_dataloader_overrides(model)
+
+     # ========= Cross-validation (for BayesOpt) =========
+     def cross_val(self, trial: optuna.trial.Trial) -> float:
+         # ResNet CV focuses on memory control:
+         # - Create a ResNetSklearn per fold and release it immediately after.
+         # - Move model to CPU, delete, and call gc/empty_cache after each fold.
+         # - Optionally sample part of training data during BayesOpt to reduce memory.
+
+         base_tw_power = self.ctx.default_tweedie_power()
+         loss_name = getattr(self.ctx, "loss_name", "tweedie")
+
+         def data_provider():
+             data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
+             assert data is not None, "Preprocessed training data is missing."
+             return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+         metric_ctx: Dict[str, Any] = {}
+
+         def model_builder(params):
+             if loss_name == "tweedie":
+                 power = params.get("tw_power", base_tw_power)
+             elif loss_name in ("poisson", "gamma"):
+                 power = base_tw_power
+             else:
+                 power = None
+             metric_ctx["tw_power"] = power
+             params_local = dict(params)
+             if power is not None:
+                 params_local["tw_power"] = power
+             return self._build_model(params_local)
+
+         def preprocess_fn(X_train, X_val):
+             X_train_s, X_val_s, _ = self._standardize_fold(
+                 X_train, X_val, self.ctx.num_features)
+             return X_train_s, X_val_s
+
+         def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
+             model.fit(
+                 X_train, y_train, w_train,
+                 X_val, y_val, w_val,
+                 trial=trial_obj
+             )
+             return model.predict(X_val)
+
+         def metric_fn(y_true, y_pred, weight):
+             if self.ctx.task_type == 'regression':
+                 return regression_loss(
+                     y_true,
+                     y_pred,
+                     weight,
+                     loss_name=loss_name,
+                     tweedie_power=metric_ctx.get("tw_power", base_tw_power),
+                 )
+             return log_loss(y_true, y_pred, sample_weight=weight)
+
+         sample_cap = data_provider()[0]
+         max_rows_for_resnet_bo = min(100000, int(len(sample_cap)/5))
+
+         return self.cross_val_generic(
+             trial=trial,
+             hyperparameter_space={
+                 "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-6, 1e-2, log=True),
+                 "hidden_dim": lambda t: t.suggest_int('hidden_dim', 8, 32, step=2),
+                 "block_num": lambda t: t.suggest_int('block_num', 2, 10),
+                 "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3, step=0.05),
+                 "residual_scale": lambda t: t.suggest_float('residual_scale', 0.05, 0.3, step=0.05),
+                 "patience": lambda t: t.suggest_int('patience', 3, 12),
+                 "stochastic_depth": lambda t: t.suggest_float('stochastic_depth', 0.0, 0.2, step=0.05),
+                 **({"tw_power": lambda t: t.suggest_float('tw_power', 1.0, 2.0)} if self.ctx.task_type == 'regression' and loss_name == 'tweedie' else {})
+             },
+             data_provider=data_provider,
+             model_builder=model_builder,
+             metric_fn=metric_fn,
+             sample_limit=max_rows_for_resnet_bo if len(
+                 sample_cap) > max_rows_for_resnet_bo > 0 else None,
+             preprocess_fn=preprocess_fn,
+             fit_predict_fn=fit_predict,
+             cleanup_fn=lambda m: getattr(
+                 getattr(m, "resnet", None), "to", lambda *_args, **_kwargs: None)("cpu")
+         )
+
+     # ========= Train final ResNet with best hyperparameters =========
+     def train(self) -> None:
+         if not self.best_params:
+             raise RuntimeError("Run tune() first to obtain best ResNet parameters.")
+
+         params = dict(self.best_params)
+         use_refit = bool(getattr(self.ctx.config, "final_refit", True))
+         data = self.ctx.train_oht_scl_data
+         if data is None:
+             raise RuntimeError("Missing standardized data for ResNet training.")
+         X_all = data[self.ctx.var_nmes]
+         y_all = data[self.ctx.resp_nme]
+         w_all = data[self.ctx.weight_nme]
+
+         refit_epochs = None
+         split = self._resolve_train_val_indices(X_all)
+         if use_refit and split is not None:
+             train_idx, val_idx = split
+             tmp_model = self._build_model(params)
+             tmp_model.fit(
+                 X_all.iloc[train_idx],
+                 y_all.iloc[train_idx],
+                 w_all.iloc[train_idx],
+                 X_all.iloc[val_idx],
+                 y_all.iloc[val_idx],
+                 w_all.iloc[val_idx],
+                 trial=None,
+             )
+             refit_epochs = self._resolve_best_epoch(
+                 getattr(tmp_model, "training_history", None),
+                 default_epochs=int(self.ctx.epochs),
+             )
+             getattr(getattr(tmp_model, "resnet", None), "to",
+                     lambda *_args, **_kwargs: None)("cpu")
+             self._clean_gpu()
+
+         self.model = self._build_model(params)
+         if refit_epochs is not None:
+             self.model.epochs = int(refit_epochs)
+         self.best_params = params
+         loss_plot_path = self.output.plot_path(
+             f'{self.ctx.model_nme}/loss/loss_{self.ctx.model_nme}_{self.model_name_prefix}.png')
+         self.model.loss_curve_path = loss_plot_path
+
+         self._fit_predict_cache(
+             self.model,
+             X_all,
+             y_all,
+             sample_weight=w_all,
+             pred_prefix='resn',
+             use_oht=True,
+             sample_weight_arg='w_train'
+         )
+
+         # Convenience wrapper for external callers.
+         self.ctx.resn_best = self.model
+
+     def ensemble_predict(self, k: int) -> None:
+         if not self.best_params:
+             raise RuntimeError("Run tune() first to obtain best ResNet parameters.")
+         data = self.ctx.train_oht_scl_data
+         test_data = self.ctx.test_oht_scl_data
+         if data is None or test_data is None:
+             raise RuntimeError("Missing standardized data for ResNet ensemble.")
+         X_all = data[self.ctx.var_nmes]
+         y_all = data[self.ctx.resp_nme]
+         w_all = data[self.ctx.weight_nme]
+         X_test = test_data[self.ctx.var_nmes]
+
+         k = max(2, int(k))
+         n_samples = len(X_all)
+         split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
+         if split_iter is None:
+             print(
+                 f"[ResNet Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
+                 flush=True,
+             )
+             return
+         preds_train_sum = np.zeros(n_samples, dtype=np.float64)
+         preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
+
+         split_count = 0
+         for train_idx, val_idx in split_iter:
+             model = self._build_model(self.best_params)
+             model.fit(
+                 X_all.iloc[train_idx],
+                 y_all.iloc[train_idx],
+                 w_all.iloc[train_idx],
+                 X_all.iloc[val_idx],
+                 y_all.iloc[val_idx],
+                 w_all.iloc[val_idx],
+                 trial=None,
+             )
+             pred_train = model.predict(X_all)
+             pred_test = model.predict(X_test)
+             preds_train_sum += np.asarray(pred_train, dtype=np.float64)
+             preds_test_sum += np.asarray(pred_test, dtype=np.float64)
+             getattr(getattr(model, "resnet", None), "to",
+                     lambda *_args, **_kwargs: None)("cpu")
+             self._clean_gpu()
+             split_count += 1
+
+         if split_count < 1:
+             print(
+                 f"[ResNet Ensemble] no CV splits generated; skip ensemble.",
+                 flush=True,
+             )
+             return
+         preds_train = preds_train_sum / float(split_count)
+         preds_test = preds_test_sum / float(split_count)
+         self._cache_predictions("resn", preds_train, preds_test)
+
+     # ========= Save / Load =========
+     # ResNet is saved as state_dict and needs a custom load path.
+     # Save logic is implemented in TrainerBase (checks .resnet attribute).
+
+     def load(self) -> None:
+         # Load ResNet weights to the current device to match context.
+         path = self.output.model_path(self._get_model_filename())
+         if os.path.exists(path):
+             payload = torch.load(path, map_location='cpu')
+             if isinstance(payload, dict) and "state_dict" in payload:
+                 state_dict = payload.get("state_dict")
+                 params = payload.get("best_params") or self.best_params
+             else:
+                 state_dict = payload
+                 params = self.best_params
+             resn_loaded = self._build_model(params)
+             resn_loaded.resnet.load_state_dict(state_dict)
+
+             self._move_to_device(resn_loaded)
+             self.model = resn_loaded
+             self.ctx.resn_best = self.model
+         else:
+             print(f"[ResNetTrainer.load] Model file not found: {path}")