ins-pricing 0.4.5__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +58 -46
  4. ins_pricing/cli/BayesOpt_incremental.py +77 -110
  5. ins_pricing/cli/Explain_Run.py +42 -23
  6. ins_pricing/cli/Explain_entry.py +551 -577
  7. ins_pricing/cli/Pricing_Run.py +42 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +51 -16
  9. ins_pricing/cli/utils/bootstrap.py +23 -0
  10. ins_pricing/cli/utils/cli_common.py +256 -256
  11. ins_pricing/cli/utils/cli_config.py +379 -360
  12. ins_pricing/cli/utils/import_resolver.py +375 -358
  13. ins_pricing/cli/utils/notebook_utils.py +256 -242
  14. ins_pricing/cli/watchdog_run.py +216 -198
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/app.py +132 -61
  17. ins_pricing/frontend/config_builder.py +33 -0
  18. ins_pricing/frontend/example_config.json +11 -0
  19. ins_pricing/frontend/example_workflows.py +1 -1
  20. ins_pricing/frontend/runner.py +340 -388
  21. ins_pricing/governance/__init__.py +20 -20
  22. ins_pricing/governance/release.py +159 -159
  23. ins_pricing/modelling/README.md +1 -1
  24. ins_pricing/modelling/__init__.py +147 -92
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
  37. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
  39. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
  40. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
  42. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
  43. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
  44. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  45. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  46. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
  47. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  48. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  49. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  50. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
  51. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  52. ins_pricing/modelling/explain/__init__.py +55 -55
  53. ins_pricing/modelling/explain/metrics.py +27 -174
  54. ins_pricing/modelling/explain/permutation.py +237 -237
  55. ins_pricing/modelling/plotting/__init__.py +40 -36
  56. ins_pricing/modelling/plotting/compat.py +228 -0
  57. ins_pricing/modelling/plotting/curves.py +572 -572
  58. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  59. ins_pricing/modelling/plotting/geo.py +362 -362
  60. ins_pricing/modelling/plotting/importance.py +121 -121
  61. ins_pricing/pricing/__init__.py +27 -27
  62. ins_pricing/pricing/factors.py +67 -56
  63. ins_pricing/production/__init__.py +35 -25
  64. ins_pricing/production/{predict.py → inference.py} +140 -57
  65. ins_pricing/production/monitoring.py +8 -21
  66. ins_pricing/reporting/__init__.py +11 -11
  67. ins_pricing/setup.py +1 -1
  68. ins_pricing/tests/production/test_inference.py +90 -0
  69. ins_pricing/utils/__init__.py +112 -78
  70. ins_pricing/utils/device.py +258 -237
  71. ins_pricing/utils/features.py +53 -0
  72. ins_pricing/utils/io.py +72 -0
  73. ins_pricing/utils/logging.py +34 -1
  74. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  75. ins_pricing/utils/metrics.py +158 -24
  76. ins_pricing/utils/numerics.py +76 -0
  77. ins_pricing/utils/paths.py +9 -1
  78. ins_pricing/utils/profiling.py +8 -4
  79. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
  80. ins_pricing-0.5.1.dist-info/RECORD +132 -0
  81. ins_pricing/modelling/core/BayesOpt.py +0 -146
  82. ins_pricing/modelling/core/__init__.py +0 -1
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  92. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  93. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
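The bulk of the renames above flatten `ins_pricing/modelling/core/bayesopt/` into `ins_pricing/modelling/bayesopt/` and promote shared helpers such as the loss functions into `ins_pricing/utils/`; the same move shows up in the import rewrite in the `trainer_resn.py` hunk below. A minimal sketch of the corresponding downstream import update, assuming calling code imported these names directly from the old module paths (the package's public re-exports are not shown in this diff):

```python
# 0.4.5 layout (old paths, per the rename entries above):
# from ins_pricing.modelling.core.bayesopt.trainers.trainer_resn import ResNetTrainer
# from ins_pricing.modelling.core.bayesopt.utils.losses import regression_loss

# 0.5.1 layout (new paths):
from ins_pricing.modelling.bayesopt.trainers.trainer_resn import ResNetTrainer
from ins_pricing.utils.losses import regression_loss
```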
ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py
@@ -1,17 +1,24 @@
- from __future__ import annotations
-
- import os
- from typing import Any, Dict, List, Optional, Tuple
-
- import numpy as np
- import optuna
- import torch
- from sklearn.metrics import log_loss
-
- from .trainer_base import TrainerBase
- from ..models import ResNetSklearn
- from ..utils.losses import regression_loss
-
+ from __future__ import annotations
+
+ import os
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import optuna
+ import torch
+ from sklearn.metrics import log_loss
+
+ from ins_pricing.modelling.bayesopt.trainers.trainer_base import TrainerBase
+ from ins_pricing.modelling.bayesopt.models import ResNetSklearn
+ from ins_pricing.utils.losses import regression_loss
+ from ins_pricing.utils import get_logger, log_print
+
+ _logger = get_logger("ins_pricing.trainer.resn")
+
+
+ def _log(*args, **kwargs) -> None:
+ log_print(_logger, *args, **kwargs)
+
  class ResNetTrainer(TrainerBase):
  def __init__(self, context: "BayesOptModel") -> None:
  if context.task_type == 'classification':
@@ -21,263 +28,268 @@ class ResNetTrainer(TrainerBase):
  self.model: Optional[ResNetSklearn] = None
  self.enable_distributed_optuna = bool(context.config.use_resn_ddp)

- def _resolve_input_dim(self) -> int:
- data = getattr(self.ctx, "train_oht_scl_data", None)
- if data is not None and getattr(self.ctx, "var_nmes", None):
- return int(data[self.ctx.var_nmes].shape[1])
- return int(len(self.ctx.var_nmes or []))
-
- def _build_model(self, params: Optional[Dict[str, Any]] = None) -> ResNetSklearn:
- params = params or {}
- loss_name = getattr(self.ctx, "loss_name", "tweedie")
- power = params.get("tw_power")
- if self.ctx.task_type == "regression":
- base_tw = self.ctx.default_tweedie_power()
- if loss_name == "tweedie":
- power = base_tw if power is None else float(power)
- elif loss_name in ("poisson", "gamma"):
- power = base_tw
- else:
- power = None
- resn_weight_decay = float(
- params.get(
- "weight_decay",
- getattr(self.ctx.config, "resn_weight_decay", 1e-4),
- )
- )
- model = ResNetSklearn(
- model_nme=self.ctx.model_nme,
- input_dim=self._resolve_input_dim(),
- hidden_dim=int(params.get("hidden_dim", 64)),
- block_num=int(params.get("block_num", 2)),
- task_type=self.ctx.task_type,
- epochs=self.ctx.epochs,
- tweedie_power=power,
- learning_rate=float(params.get("learning_rate", 0.01)),
- patience=int(params.get("patience", 10)),
- use_layernorm=True,
- dropout=float(params.get("dropout", 0.1)),
- residual_scale=float(params.get("residual_scale", 0.1)),
- stochastic_depth=float(params.get("stochastic_depth", 0.0)),
- weight_decay=resn_weight_decay,
- use_data_parallel=self.ctx.config.use_resn_data_parallel,
- use_ddp=self.ctx.config.use_resn_ddp,
- loss_name=loss_name
- )
- return self._apply_dataloader_overrides(model)
-
- # ========= Cross-validation (for BayesOpt) =========
- def cross_val(self, trial: optuna.trial.Trial) -> float:
- # ResNet CV focuses on memory control:
- # - Create a ResNetSklearn per fold and release it immediately after.
- # - Move model to CPU, delete, and call gc/empty_cache after each fold.
- # - Optionally sample part of training data during BayesOpt to reduce memory.
-
- base_tw_power = self.ctx.default_tweedie_power()
- loss_name = getattr(self.ctx, "loss_name", "tweedie")
-
- def data_provider():
- data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
- assert data is not None, "Preprocessed training data is missing."
- return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
-
- metric_ctx: Dict[str, Any] = {}
-
- def model_builder(params):
- if loss_name == "tweedie":
- power = params.get("tw_power", base_tw_power)
- elif loss_name in ("poisson", "gamma"):
- power = base_tw_power
- else:
- power = None
- metric_ctx["tw_power"] = power
- params_local = dict(params)
- if power is not None:
- params_local["tw_power"] = power
- return self._build_model(params_local)
-
- def preprocess_fn(X_train, X_val):
- X_train_s, X_val_s, _ = self._standardize_fold(
- X_train, X_val, self.ctx.num_features)
- return X_train_s, X_val_s
-
- def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
- model.fit(
- X_train, y_train, w_train,
- X_val, y_val, w_val,
- trial=trial_obj
- )
- return model.predict(X_val)
-
- def metric_fn(y_true, y_pred, weight):
- if self.ctx.task_type == 'regression':
- return regression_loss(
- y_true,
- y_pred,
- weight,
- loss_name=loss_name,
- tweedie_power=metric_ctx.get("tw_power", base_tw_power),
- )
- return log_loss(y_true, y_pred, sample_weight=weight)
-
- sample_cap = data_provider()[0]
- max_rows_for_resnet_bo = min(100000, int(len(sample_cap)/5))
-
- return self.cross_val_generic(
- trial=trial,
- hyperparameter_space={
- "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-6, 1e-2, log=True),
- "hidden_dim": lambda t: t.suggest_int('hidden_dim', 8, 32, step=2),
- "block_num": lambda t: t.suggest_int('block_num', 2, 10),
- "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3, step=0.05),
- "residual_scale": lambda t: t.suggest_float('residual_scale', 0.05, 0.3, step=0.05),
- "patience": lambda t: t.suggest_int('patience', 3, 12),
- "stochastic_depth": lambda t: t.suggest_float('stochastic_depth', 0.0, 0.2, step=0.05),
- **({"tw_power": lambda t: t.suggest_float('tw_power', 1.0, 2.0)} if self.ctx.task_type == 'regression' and loss_name == 'tweedie' else {})
- },
- data_provider=data_provider,
- model_builder=model_builder,
- metric_fn=metric_fn,
- sample_limit=max_rows_for_resnet_bo if len(
- sample_cap) > max_rows_for_resnet_bo > 0 else None,
- preprocess_fn=preprocess_fn,
- fit_predict_fn=fit_predict,
- cleanup_fn=lambda m: getattr(
- getattr(m, "resnet", None), "to", lambda *_args, **_kwargs: None)("cpu")
- )
-
- # ========= Train final ResNet with best hyperparameters =========
- def train(self) -> None:
- if not self.best_params:
- raise RuntimeError("Run tune() first to obtain best ResNet parameters.")
-
- params = dict(self.best_params)
- use_refit = bool(getattr(self.ctx.config, "final_refit", True))
- data = self.ctx.train_oht_scl_data
- if data is None:
- raise RuntimeError("Missing standardized data for ResNet training.")
- X_all = data[self.ctx.var_nmes]
- y_all = data[self.ctx.resp_nme]
- w_all = data[self.ctx.weight_nme]
-
- refit_epochs = None
- split = self._resolve_train_val_indices(X_all)
- if use_refit and split is not None:
- train_idx, val_idx = split
- tmp_model = self._build_model(params)
- tmp_model.fit(
- X_all.iloc[train_idx],
- y_all.iloc[train_idx],
- w_all.iloc[train_idx],
- X_all.iloc[val_idx],
- y_all.iloc[val_idx],
- w_all.iloc[val_idx],
- trial=None,
- )
+ def _maybe_cleanup_gpu(self, model: Optional[ResNetSklearn]) -> None:
+ if not bool(getattr(self.ctx.config, "resn_cleanup_per_fold", False)):
+ return
+ if model is not None:
+ getattr(getattr(model, "resnet", None), "to",
+ lambda *_args, **_kwargs: None)("cpu")
+ synchronize = bool(getattr(self.ctx.config, "resn_cleanup_synchronize", False))
+ self._clean_gpu(synchronize=synchronize)
+
+ def _resolve_input_dim(self) -> int:
+ data = getattr(self.ctx, "train_oht_scl_data", None)
+ if data is not None and getattr(self.ctx, "var_nmes", None):
+ return int(data[self.ctx.var_nmes].shape[1])
+ return int(len(self.ctx.var_nmes or []))
+
+ def _build_model(self, params: Optional[Dict[str, Any]] = None) -> ResNetSklearn:
+ params = params or {}
+ loss_name = getattr(self.ctx, "loss_name", "tweedie")
+ power = params.get("tw_power")
+ if self.ctx.task_type == "regression":
+ base_tw = self.ctx.default_tweedie_power()
+ if loss_name == "tweedie":
+ power = base_tw if power is None else float(power)
+ elif loss_name in ("poisson", "gamma"):
+ power = base_tw
+ else:
+ power = None
+ resn_weight_decay = float(
+ params.get(
+ "weight_decay",
+ getattr(self.ctx.config, "resn_weight_decay", 1e-4),
+ )
+ )
+ model = ResNetSklearn(
+ model_nme=self.ctx.model_nme,
+ input_dim=self._resolve_input_dim(),
+ hidden_dim=int(params.get("hidden_dim", 64)),
+ block_num=int(params.get("block_num", 2)),
+ task_type=self.ctx.task_type,
+ epochs=self.ctx.epochs,
+ tweedie_power=power,
+ learning_rate=float(params.get("learning_rate", 0.01)),
+ patience=int(params.get("patience", 10)),
+ use_layernorm=True,
+ dropout=float(params.get("dropout", 0.1)),
+ residual_scale=float(params.get("residual_scale", 0.1)),
+ stochastic_depth=float(params.get("stochastic_depth", 0.0)),
+ weight_decay=resn_weight_decay,
+ use_data_parallel=self.ctx.config.use_resn_data_parallel,
+ use_ddp=self.ctx.config.use_resn_ddp,
+ loss_name=loss_name
+ )
+ return self._apply_dataloader_overrides(model)
+
+ # ========= Cross-validation (for BayesOpt) =========
+ def cross_val(self, trial: optuna.trial.Trial) -> float:
+ # ResNet CV focuses on memory control:
+ # - Create a ResNetSklearn per fold and release it immediately after.
+ # - Move model to CPU, delete, and call gc/empty_cache after each fold.
+ # - Optionally sample part of training data during BayesOpt to reduce memory.
+
+ base_tw_power = self.ctx.default_tweedie_power()
+ loss_name = getattr(self.ctx, "loss_name", "tweedie")
+
+ def data_provider():
+ data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
+ assert data is not None, "Preprocessed training data is missing."
+ return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+ metric_ctx: Dict[str, Any] = {}
+
+ def model_builder(params):
+ if loss_name == "tweedie":
+ power = params.get("tw_power", base_tw_power)
+ elif loss_name in ("poisson", "gamma"):
+ power = base_tw_power
+ else:
+ power = None
+ metric_ctx["tw_power"] = power
+ params_local = dict(params)
+ if power is not None:
+ params_local["tw_power"] = power
+ return self._build_model(params_local)
+
+ def preprocess_fn(X_train, X_val):
+ X_train_s, X_val_s, _ = self._standardize_fold(
+ X_train, X_val, self.ctx.num_features)
+ return X_train_s, X_val_s
+
+ def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
+ model.fit(
+ X_train, y_train, w_train,
+ X_val, y_val, w_val,
+ trial=trial_obj
+ )
+ return model.predict(X_val)
+
+ def metric_fn(y_true, y_pred, weight):
+ if self.ctx.task_type == 'regression':
+ return regression_loss(
+ y_true,
+ y_pred,
+ weight,
+ loss_name=loss_name,
+ tweedie_power=metric_ctx.get("tw_power", base_tw_power),
+ )
+ return log_loss(y_true, y_pred, sample_weight=weight)
+
+ sample_cap = data_provider()[0]
+ max_rows_for_resnet_bo = min(100000, int(len(sample_cap)/5))
+
+ return self.cross_val_generic(
+ trial=trial,
+ hyperparameter_space={
+ "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-6, 1e-2, log=True),
+ "hidden_dim": lambda t: t.suggest_int('hidden_dim', 8, 32, step=2),
+ "block_num": lambda t: t.suggest_int('block_num', 2, 10),
+ "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3, step=0.05),
+ "residual_scale": lambda t: t.suggest_float('residual_scale', 0.05, 0.3, step=0.05),
+ "patience": lambda t: t.suggest_int('patience', 3, 12),
+ "stochastic_depth": lambda t: t.suggest_float('stochastic_depth', 0.0, 0.2, step=0.05),
+ **({"tw_power": lambda t: t.suggest_float('tw_power', 1.0, 2.0)} if self.ctx.task_type == 'regression' and loss_name == 'tweedie' else {})
+ },
+ data_provider=data_provider,
+ model_builder=model_builder,
+ metric_fn=metric_fn,
+ sample_limit=max_rows_for_resnet_bo if len(
+ sample_cap) > max_rows_for_resnet_bo > 0 else None,
+ preprocess_fn=preprocess_fn,
+ fit_predict_fn=fit_predict,
+ cleanup_fn=lambda m: getattr(
+ getattr(m, "resnet", None), "to", lambda *_args, **_kwargs: None)("cpu")
+ )
+
+ # ========= Train final ResNet with best hyperparameters =========
+ def train(self) -> None:
+ if not self.best_params:
+ raise RuntimeError("Run tune() first to obtain best ResNet parameters.")
+
+ params = dict(self.best_params)
+ use_refit = bool(getattr(self.ctx.config, "final_refit", True))
+ data = self.ctx.train_oht_scl_data
+ if data is None:
+ raise RuntimeError("Missing standardized data for ResNet training.")
+ X_all = data[self.ctx.var_nmes]
+ y_all = data[self.ctx.resp_nme]
+ w_all = data[self.ctx.weight_nme]
+
+ refit_epochs = None
+ split = self._resolve_train_val_indices(X_all)
+ if use_refit and split is not None:
+ train_idx, val_idx = split
+ tmp_model = self._build_model(params)
+ tmp_model.fit(
+ X_all.iloc[train_idx],
+ y_all.iloc[train_idx],
+ w_all.iloc[train_idx],
+ X_all.iloc[val_idx],
+ y_all.iloc[val_idx],
+ w_all.iloc[val_idx],
+ trial=None,
+ )
  refit_epochs = self._resolve_best_epoch(
  getattr(tmp_model, "training_history", None),
  default_epochs=int(self.ctx.epochs),
  )
- getattr(getattr(tmp_model, "resnet", None), "to",
- lambda *_args, **_kwargs: None)("cpu")
- self._clean_gpu()
-
- self.model = self._build_model(params)
- if refit_epochs is not None:
- self.model.epochs = int(refit_epochs)
- self.best_params = params
- loss_plot_path = self.output.plot_path(
- f'{self.ctx.model_nme}/loss/loss_{self.ctx.model_nme}_{self.model_name_prefix}.png')
- self.model.loss_curve_path = loss_plot_path
-
- self._fit_predict_cache(
- self.model,
- X_all,
- y_all,
- sample_weight=w_all,
- pred_prefix='resn',
- use_oht=True,
- sample_weight_arg='w_train'
- )
-
- # Convenience wrapper for external callers.
- self.ctx.resn_best = self.model
-
- def ensemble_predict(self, k: int) -> None:
- if not self.best_params:
- raise RuntimeError("Run tune() first to obtain best ResNet parameters.")
- data = self.ctx.train_oht_scl_data
- test_data = self.ctx.test_oht_scl_data
- if data is None or test_data is None:
- raise RuntimeError("Missing standardized data for ResNet ensemble.")
- X_all = data[self.ctx.var_nmes]
- y_all = data[self.ctx.resp_nme]
- w_all = data[self.ctx.weight_nme]
- X_test = test_data[self.ctx.var_nmes]
-
- k = max(2, int(k))
- n_samples = len(X_all)
- split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
- if split_iter is None:
- print(
- f"[ResNet Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
- flush=True,
- )
- return
- preds_train_sum = np.zeros(n_samples, dtype=np.float64)
- preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
-
- split_count = 0
- for train_idx, val_idx in split_iter:
- model = self._build_model(self.best_params)
- model.fit(
- X_all.iloc[train_idx],
- y_all.iloc[train_idx],
- w_all.iloc[train_idx],
- X_all.iloc[val_idx],
- y_all.iloc[val_idx],
- w_all.iloc[val_idx],
- trial=None,
- )
- pred_train = model.predict(X_all)
- pred_test = model.predict(X_test)
+ self._maybe_cleanup_gpu(tmp_model)
+
+ self.model = self._build_model(params)
+ if refit_epochs is not None:
+ self.model.epochs = int(refit_epochs)
+ self.best_params = params
+ loss_plot_path = self.output.plot_path(
+ f'{self.ctx.model_nme}/loss/loss_{self.ctx.model_nme}_{self.model_name_prefix}.png')
+ self.model.loss_curve_path = loss_plot_path
+
+ self._fit_predict_cache(
+ self.model,
+ X_all,
+ y_all,
+ sample_weight=w_all,
+ pred_prefix='resn',
+ use_oht=True,
+ sample_weight_arg='w_train'
+ )
+
+ # Convenience wrapper for external callers.
+ self.ctx.resn_best = self.model
+
+ def ensemble_predict(self, k: int) -> None:
+ if not self.best_params:
+ raise RuntimeError("Run tune() first to obtain best ResNet parameters.")
+ data = self.ctx.train_oht_scl_data
+ test_data = self.ctx.test_oht_scl_data
+ if data is None or test_data is None:
+ raise RuntimeError("Missing standardized data for ResNet ensemble.")
+ X_all = data[self.ctx.var_nmes]
+ y_all = data[self.ctx.resp_nme]
+ w_all = data[self.ctx.weight_nme]
+ X_test = test_data[self.ctx.var_nmes]
+
+ k = max(2, int(k))
+ n_samples = len(X_all)
+ split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
+ if split_iter is None:
+ _log(
+ f"[ResNet Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
+ flush=True,
+ )
+ return
+ preds_train_sum = np.zeros(n_samples, dtype=np.float64)
+ preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
+
+ split_count = 0
+ for train_idx, val_idx in split_iter:
+ model = self._build_model(self.best_params)
+ model.fit(
+ X_all.iloc[train_idx],
+ y_all.iloc[train_idx],
+ w_all.iloc[train_idx],
+ X_all.iloc[val_idx],
+ y_all.iloc[val_idx],
+ w_all.iloc[val_idx],
+ trial=None,
+ )
+ pred_train = model.predict(X_all)
+ pred_test = model.predict(X_test)
  preds_train_sum += np.asarray(pred_train, dtype=np.float64)
  preds_test_sum += np.asarray(pred_test, dtype=np.float64)
- getattr(getattr(model, "resnet", None), "to",
- lambda *_args, **_kwargs: None)("cpu")
- self._clean_gpu()
+ self._maybe_cleanup_gpu(model)
  split_count += 1
-
- if split_count < 1:
- print(
- f"[ResNet Ensemble] no CV splits generated; skip ensemble.",
- flush=True,
- )
- return
- preds_train = preds_train_sum / float(split_count)
- preds_test = preds_test_sum / float(split_count)
- self._cache_predictions("resn", preds_train, preds_test)
-
- # ========= Save / Load =========
- # ResNet is saved as state_dict and needs a custom load path.
- # Save logic is implemented in TrainerBase (checks .resnet attribute).
-
- def load(self) -> None:
- # Load ResNet weights to the current device to match context.
- path = self.output.model_path(self._get_model_filename())
- if os.path.exists(path):
- payload = torch.load(path, map_location='cpu')
- if isinstance(payload, dict) and "state_dict" in payload:
- state_dict = payload.get("state_dict")
- params = payload.get("best_params") or self.best_params
- else:
- state_dict = payload
- params = self.best_params
- resn_loaded = self._build_model(params)
- resn_loaded.resnet.load_state_dict(state_dict)
-
- self._move_to_device(resn_loaded)
- self.model = resn_loaded
- self.ctx.resn_best = self.model
- else:
- print(f"[ResNetTrainer.load] Model file not found: {path}")
+
+ if split_count < 1:
+ _log(
+ f"[ResNet Ensemble] no CV splits generated; skip ensemble.",
+ flush=True,
+ )
+ return
+ preds_train = preds_train_sum / float(split_count)
+ preds_test = preds_test_sum / float(split_count)
+ self._cache_predictions("resn", preds_train, preds_test)
+
+ # ========= Save / Load =========
+ # ResNet is saved as state_dict and needs a custom load path.
+ # Save logic is implemented in TrainerBase (checks .resnet attribute).
+
+ def load(self) -> None:
+ # Load ResNet weights to the current device to match context.
+ path = self.output.model_path(self._get_model_filename())
+ if os.path.exists(path):
+ payload = torch.load(path, map_location='cpu')
+ if isinstance(payload, dict) and "state_dict" in payload:
+ state_dict = payload.get("state_dict")
+ params = payload.get("best_params") or self.best_params
+ else:
+ state_dict = payload
+ params = self.best_params
+ resn_loaded = self._build_model(params)
+ resn_loaded.resnet.load_state_dict(state_dict)
+
+ self._move_to_device(resn_loaded)
+ self.model = resn_loaded
+ self.ctx.resn_best = self.model
+ else:
+ _log(f"[ResNetTrainer.load] Model file not found: {path}")
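Besides consolidating the per-fold GPU cleanup into `_maybe_cleanup_gpu` (which now runs only when the `resn_cleanup_per_fold` config flag is enabled, optionally synchronizing via `resn_cleanup_synchronize`), the hunk routes the trainer's former bare `print` diagnostics through a module-level logger via `log_print`. The shim's implementation is not part of this hunk; a minimal sketch of the idea, assuming `log_print` accepts `print`-style arguments and forwards the joined message to the logger:

```python
import logging

def log_print(logger: logging.Logger, *args, sep: str = " ", **_ignored) -> None:
    # Hypothetical stand-in for ins_pricing.utils.log_print: join positional
    # arguments the way print() would and emit them at INFO level.
    # print-only keywords such as flush= are accepted and ignored.
    logger.info(sep.join(str(a) for a in args))

logging.basicConfig(level=logging.INFO)
log_print(logging.getLogger("ins_pricing.trainer.resn"),
          "[ResNet Ensemble] no CV splits generated; skip ensemble.", flush=True)
```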