ins-pricing 0.4.5__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +48 -22
- ins_pricing/__init__.py +142 -90
- ins_pricing/cli/BayesOpt_entry.py +58 -46
- ins_pricing/cli/BayesOpt_incremental.py +77 -110
- ins_pricing/cli/Explain_Run.py +42 -23
- ins_pricing/cli/Explain_entry.py +551 -577
- ins_pricing/cli/Pricing_Run.py +42 -23
- ins_pricing/cli/bayesopt_entry_runner.py +51 -16
- ins_pricing/cli/utils/bootstrap.py +23 -0
- ins_pricing/cli/utils/cli_common.py +256 -256
- ins_pricing/cli/utils/cli_config.py +379 -360
- ins_pricing/cli/utils/import_resolver.py +375 -358
- ins_pricing/cli/utils/notebook_utils.py +256 -242
- ins_pricing/cli/watchdog_run.py +216 -198
- ins_pricing/frontend/__init__.py +10 -10
- ins_pricing/frontend/app.py +132 -61
- ins_pricing/frontend/config_builder.py +33 -0
- ins_pricing/frontend/example_config.json +11 -0
- ins_pricing/frontend/example_workflows.py +1 -1
- ins_pricing/frontend/runner.py +340 -388
- ins_pricing/governance/__init__.py +20 -20
- ins_pricing/governance/release.py +159 -159
- ins_pricing/modelling/README.md +1 -1
- ins_pricing/modelling/__init__.py +147 -92
- ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
- ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
- ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
- ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
- ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
- ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
- ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
- ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
- ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
- ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
- ins_pricing/modelling/explain/__init__.py +55 -55
- ins_pricing/modelling/explain/metrics.py +27 -174
- ins_pricing/modelling/explain/permutation.py +237 -237
- ins_pricing/modelling/plotting/__init__.py +40 -36
- ins_pricing/modelling/plotting/compat.py +228 -0
- ins_pricing/modelling/plotting/curves.py +572 -572
- ins_pricing/modelling/plotting/diagnostics.py +163 -163
- ins_pricing/modelling/plotting/geo.py +362 -362
- ins_pricing/modelling/plotting/importance.py +121 -121
- ins_pricing/pricing/__init__.py +27 -27
- ins_pricing/pricing/factors.py +67 -56
- ins_pricing/production/__init__.py +35 -25
- ins_pricing/production/{predict.py → inference.py} +140 -57
- ins_pricing/production/monitoring.py +8 -21
- ins_pricing/reporting/__init__.py +11 -11
- ins_pricing/setup.py +1 -1
- ins_pricing/tests/production/test_inference.py +90 -0
- ins_pricing/utils/__init__.py +112 -78
- ins_pricing/utils/device.py +258 -237
- ins_pricing/utils/features.py +53 -0
- ins_pricing/utils/io.py +72 -0
- ins_pricing/utils/logging.py +34 -1
- ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
- ins_pricing/utils/metrics.py +158 -24
- ins_pricing/utils/numerics.py +76 -0
- ins_pricing/utils/paths.py +9 -1
- ins_pricing/utils/profiling.py +8 -4
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
- ins_pricing-0.5.1.dist-info/RECORD +132 -0
- ins_pricing/modelling/core/BayesOpt.py +0 -146
- ins_pricing/modelling/core/__init__.py +0 -1
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
- ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
- ins_pricing/modelling/core/bayesopt/utils.py +0 -105
- ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
- ins_pricing/tests/production/test_predict.py +0 -233
- ins_pricing-0.4.5.dist-info/RECORD +0 -130
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
|
@@ -1,438 +1,442 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from datetime import timedelta
|
|
4
|
-
import gc
|
|
5
|
-
import os
|
|
6
|
-
from pathlib import Path
|
|
7
|
-
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
|
8
|
-
|
|
9
|
-
import joblib
|
|
10
|
-
import numpy as np
|
|
11
|
-
import optuna
|
|
12
|
-
import pandas as pd
|
|
13
|
-
import torch
|
|
14
|
-
try: # pragma: no cover
|
|
15
|
-
import torch.distributed as dist # type: ignore
|
|
16
|
-
except Exception: # pragma: no cover
|
|
17
|
-
dist = None # type: ignore
|
|
18
|
-
from sklearn.model_selection import (
|
|
19
|
-
GroupKFold,
|
|
20
|
-
GroupShuffleSplit,
|
|
21
|
-
KFold,
|
|
22
|
-
ShuffleSplit,
|
|
23
|
-
TimeSeriesSplit,
|
|
24
|
-
)
|
|
25
|
-
from sklearn.preprocessing import StandardScaler
|
|
26
|
-
|
|
27
|
-
from
|
|
28
|
-
from
|
|
29
|
-
from ins_pricing.utils import get_logger, GPUMemoryManager, DeviceManager
|
|
30
|
-
from ins_pricing.utils.torch_compat import torch_load
|
|
31
|
-
|
|
32
|
-
# Module-level logger
|
|
33
|
-
_logger = get_logger("ins_pricing.trainer")
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
def
|
|
89
|
-
"""
|
|
90
|
-
return self._strategy
|
|
91
|
-
|
|
92
|
-
def
|
|
93
|
-
"""Check if using a
|
|
94
|
-
return self._strategy in self.
|
|
95
|
-
|
|
96
|
-
def
|
|
97
|
-
"""
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
return
|
|
113
|
-
|
|
114
|
-
def
|
|
115
|
-
"""Get
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
self
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
n_splits
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
self.
|
|
304
|
-
self.
|
|
305
|
-
self.
|
|
306
|
-
self.
|
|
307
|
-
self.
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
"
|
|
332
|
-
"
|
|
333
|
-
"
|
|
334
|
-
"
|
|
335
|
-
"
|
|
336
|
-
"
|
|
337
|
-
"
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
flush=True
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
@property
|
|
415
|
-
def
|
|
416
|
-
return self.ctx.
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
return
|
|
421
|
-
|
|
422
|
-
def
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import timedelta
|
|
4
|
+
import gc
|
|
5
|
+
import os
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
|
8
|
+
|
|
9
|
+
import joblib
|
|
10
|
+
import numpy as np
|
|
11
|
+
import optuna
|
|
12
|
+
import pandas as pd
|
|
13
|
+
import torch
|
|
14
|
+
try: # pragma: no cover
|
|
15
|
+
import torch.distributed as dist # type: ignore
|
|
16
|
+
except Exception: # pragma: no cover
|
|
17
|
+
dist = None # type: ignore
|
|
18
|
+
from sklearn.model_selection import (
|
|
19
|
+
GroupKFold,
|
|
20
|
+
GroupShuffleSplit,
|
|
21
|
+
KFold,
|
|
22
|
+
ShuffleSplit,
|
|
23
|
+
TimeSeriesSplit,
|
|
24
|
+
)
|
|
25
|
+
from sklearn.preprocessing import StandardScaler
|
|
26
|
+
|
|
27
|
+
from ins_pricing.modelling.bayesopt.config_preprocess import BayesOptConfig, OutputManager
|
|
28
|
+
from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
|
|
29
|
+
from ins_pricing.utils import EPS, ensure_parent_dir, get_logger, GPUMemoryManager, DeviceManager, log_print
|
|
30
|
+
from ins_pricing.utils.torch_compat import torch_load
|
|
31
|
+
|
|
32
|
+
# Module-level logger
|
|
33
|
+
_logger = get_logger("ins_pricing.trainer")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _log(*parts, **options) -> None:
    """Emit a print-style message through the module-level trainer logger.

    Positional ``parts`` and keyword ``options`` (e.g. ``flush=True``) are
    forwarded unchanged to ``log_print`` together with ``_logger``.
    """
    log_print(_logger, *parts, **options)
|
|
38
|
+
|
|
39
|
+
class _OrderSplitter:
|
|
40
|
+
def __init__(self, splitter, order: np.ndarray) -> None:
|
|
41
|
+
self._splitter = splitter
|
|
42
|
+
self._order = np.asarray(order)
|
|
43
|
+
|
|
44
|
+
def split(self, X, y=None, groups=None):
|
|
45
|
+
order = self._order
|
|
46
|
+
X_ord = X.iloc[order] if hasattr(X, "iloc") else X[order]
|
|
47
|
+
for tr_idx, val_idx in self._splitter.split(X_ord, y=y, groups=groups):
|
|
48
|
+
yield order[tr_idx], order[val_idx]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# =============================================================================
|
|
52
|
+
# CV Strategy Resolution Helper
|
|
53
|
+
# =============================================================================
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class CVStrategyResolver:
    """Helper class to resolve cross-validation splitting strategies.

    This encapsulates the logic for determining how to split data based on the
    configured strategy (random, time, group). It provides methods to:
    - Get time-ordered indices for a dataset
    - Get group values for a dataset
    - Create appropriate sklearn splitters

    Any strategy string that is not one of the time or group aliases below is
    treated as random splitting.
    """

    TIME_STRATEGIES = {"time", "timeseries", "temporal"}
    GROUP_STRATEGIES = {"group", "grouped"}

    def __init__(self, config, train_data: pd.DataFrame, rand_seed: Optional[int] = None):
        """Initialize the resolver.

        Args:
            config: BayesOptConfig with cv_strategy, cv_time_col, cv_group_col, etc.
            train_data: The training DataFrame (needed for column access)
            rand_seed: Random seed for reproducible splits
        """
        self.config = config
        self.train_data = train_data
        self.rand_seed = rand_seed
        self._strategy = self._normalize_strategy()

    def _normalize_strategy(self) -> str:
        """Normalize the strategy string to stripped lowercase (default "random")."""
        raw = str(getattr(self.config, "cv_strategy", "random") or "random")
        return raw.strip().lower()

    @property
    def strategy(self) -> str:
        """Return the normalized CV strategy."""
        return self._strategy

    def is_time_strategy(self) -> bool:
        """Check if using a time-based splitting strategy."""
        return self._strategy in self.TIME_STRATEGIES

    def is_group_strategy(self) -> bool:
        """Check if using a group-based splitting strategy."""
        return self._strategy in self.GROUP_STRATEGIES

    def get_time_col(self) -> str:
        """Get and validate the time column.

        Raises:
            ValueError: If time column is not configured
            KeyError: If time column not found in train_data
        """
        time_col = getattr(self.config, "cv_time_col", None)
        if not time_col:
            raise ValueError("cv_time_col is required for time cv_strategy.")
        if time_col not in self.train_data.columns:
            raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
        return time_col

    def get_time_ascending(self) -> bool:
        """Get the time ordering preference (defaults to ascending)."""
        return bool(getattr(self.config, "cv_time_ascending", True))

    def get_group_col(self) -> str:
        """Get and validate the group column.

        Raises:
            ValueError: If group column is not configured
            KeyError: If group column not found in train_data
        """
        group_col = getattr(self.config, "cv_group_col", None)
        if not group_col:
            raise ValueError("cv_group_col is required for group cv_strategy.")
        if group_col not in self.train_data.columns:
            raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
        return group_col

    def get_time_ordered_indices(self, X_all: pd.DataFrame) -> np.ndarray:
        """Get indices ordered by time for the given dataset.

        Args:
            X_all: DataFrame to get indices for (must have index compatible with train_data)

        Returns:
            Array of positional indices into X_all, ordered by time. Rows whose
            index label is absent from X_all are dropped.
        """
        time_col = self.get_time_col()
        ascending = self.get_time_ascending()
        # Stable sort: rows sharing a timestamp keep their original relative
        # order, so fold boundaries are deterministic for tied times.
        order_index = (
            self.train_data[time_col]
            .sort_values(ascending=ascending, kind="mergesort")
            .index
        )
        order_index = order_index[order_index.isin(X_all.index)]
        order = X_all.index.get_indexer(order_index)
        return order[order >= 0]

    def get_groups(self, X_all: pd.DataFrame) -> pd.Series:
        """Get group labels for the given dataset.

        Args:
            X_all: DataFrame to get groups for

        Returns:
            Series of group labels aligned with X_all
        """
        group_col = self.get_group_col()
        return self.train_data.reindex(X_all.index)[group_col]

    @staticmethod
    def _clamp_time_splits(requested: int, n_samples: int) -> int:
        """Return the usable TimeSeriesSplit fold count, or 0 if none is possible.

        TimeSeriesSplit needs ``n_splits + 1 <= n_samples`` and this resolver
        requires at least 2 folds, so fewer than 3 samples yields 0.
        """
        n_splits = min(int(requested), n_samples - 1)
        return n_splits if n_splits >= 2 else 0

    def create_train_val_splitter(
        self,
        X_all: pd.DataFrame,
        val_ratio: float,
    ) -> Tuple[Optional[Tuple[np.ndarray, np.ndarray]], Optional[pd.Series]]:
        """Create a single train/val split based on strategy.

        Args:
            X_all: DataFrame to split
            val_ratio: Fraction of data for validation

        Returns:
            Tuple of ((train_idx, val_idx), groups) where groups is None for non-group strategies

        Raises:
            ValueError: For the time strategy, if val_ratio leaves either side empty.
        """
        if self.is_time_strategy():
            order = self.get_time_ordered_indices(X_all)
            # Earliest (1 - val_ratio) share trains, the most recent share validates.
            cutoff = int(len(order) * (1.0 - val_ratio))
            if cutoff <= 0 or cutoff >= len(order):
                raise ValueError(f"val_ratio={val_ratio} leaves no data for train/val split.")
            return (order[:cutoff], order[cutoff:]), None

        if self.is_group_strategy():
            groups = self.get_groups(X_all)
            splitter = GroupShuffleSplit(
                n_splits=1, test_size=val_ratio, random_state=self.rand_seed
            )
            train_idx, val_idx = next(splitter.split(X_all, groups=groups))
            return (train_idx, val_idx), groups

        # Random strategy
        splitter = ShuffleSplit(
            n_splits=1, test_size=val_ratio, random_state=self.rand_seed
        )
        train_idx, val_idx = next(splitter.split(X_all))
        return (train_idx, val_idx), None

    def create_cv_splitter(
        self,
        X_all: pd.DataFrame,
        y_all: Optional[pd.Series],
        n_splits: int,
        val_ratio: float,
    ) -> Tuple[Iterable[Tuple[np.ndarray, np.ndarray]], int]:
        """Create a cross-validation splitter based on strategy.

        Args:
            X_all: DataFrame to split
            y_all: Target series (used by some splitters)
            n_splits: Number of CV folds (at least 2 is requested)
            val_ratio: Validation ratio (for ShuffleSplit)

        Returns:
            Tuple of (split_iterator, actual_n_splits); (empty iterator, 0)
            when the data cannot support at least 2 folds.
        """
        n_splits = max(2, int(n_splits))

        if self.is_group_strategy():
            groups = self.get_groups(X_all)
            n_groups = int(groups.nunique(dropna=False))
            if n_groups < 2:
                return iter([]), 0
            n_splits = min(n_splits, n_groups)
            splitter = GroupKFold(n_splits=n_splits)
            return splitter.split(X_all, y_all, groups=groups), n_splits

        if self.is_time_strategy():
            order = self.get_time_ordered_indices(X_all)
            n_splits = self._clamp_time_splits(n_splits, len(order))
            if n_splits < 2:
                return iter([]), 0
            splitter = TimeSeriesSplit(n_splits=n_splits)
            return _OrderSplitter(splitter, order).split(X_all), n_splits

        # Random strategy
        if len(X_all) < n_splits:
            n_splits = len(X_all)
        if n_splits < 2:
            return iter([]), 0
        splitter = ShuffleSplit(
            n_splits=n_splits, test_size=val_ratio, random_state=self.rand_seed
        )
        return splitter.split(X_all), n_splits

    def create_kfold_splitter(
        self,
        X_all: pd.DataFrame,
        k: int,
    ) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
        """Create a K-fold splitter for ensemble training.

        Args:
            X_all: DataFrame to split
            k: Number of folds (at least 2 is requested)

        Returns:
            Tuple of (split_iterator, actual_k) or (None, 0) if not enough data
        """
        k = max(2, int(k))
        n_samples = len(X_all)
        if n_samples < 2:
            return None, 0

        if self.is_group_strategy():
            groups = self.get_groups(X_all)
            n_groups = int(groups.nunique(dropna=False))
            if n_groups < 2:
                return None, 0
            k = min(k, n_groups)
            splitter = GroupKFold(n_splits=k)
            return splitter.split(X_all, y=None, groups=groups), k

        if self.is_time_strategy():
            order = self.get_time_ordered_indices(X_all)
            k = self._clamp_time_splits(k, len(order))
            if k < 2:
                return None, 0
            splitter = TimeSeriesSplit(n_splits=k)
            return _OrderSplitter(splitter, order).split(X_all), k

        # Random strategy with KFold
        k = min(k, n_samples)
        if k < 2:
            return None, 0
        splitter = KFold(n_splits=k, shuffle=True, random_state=self.rand_seed)
        return splitter.split(X_all), k
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
# =============================================================================
|
|
297
|
+
# Trainer system
|
|
298
|
+
# =============================================================================
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
class TrainerBase:
|
|
302
|
+
def __init__(self, context: "BayesOptModel", label: str, model_name_prefix: str) -> None:
|
|
303
|
+
self.ctx = context
|
|
304
|
+
self.label = label
|
|
305
|
+
self.model_name_prefix = model_name_prefix
|
|
306
|
+
self.model = None
|
|
307
|
+
self.best_params: Optional[Dict[str, Any]] = None
|
|
308
|
+
self.best_trial = None
|
|
309
|
+
self.study_name: Optional[str] = None
|
|
310
|
+
self.enable_distributed_optuna: bool = False
|
|
311
|
+
self._distributed_forced_params: Optional[Dict[str, Any]] = None
|
|
312
|
+
|
|
313
|
+
def _apply_dataloader_overrides(self, model: Any) -> Any:
|
|
314
|
+
"""Apply dataloader-related overrides from config to a model."""
|
|
315
|
+
cfg = getattr(self.ctx, "config", None)
|
|
316
|
+
if cfg is None:
|
|
317
|
+
return model
|
|
318
|
+
workers = getattr(cfg, "dataloader_workers", None)
|
|
319
|
+
if workers is not None:
|
|
320
|
+
model.dataloader_workers = int(workers)
|
|
321
|
+
profile = getattr(cfg, "resource_profile", None)
|
|
322
|
+
if profile:
|
|
323
|
+
model.resource_profile = str(profile)
|
|
324
|
+
return model
|
|
325
|
+
|
|
326
|
+
def _export_preprocess_artifacts(self) -> Dict[str, Any]:
|
|
327
|
+
dummy_columns: List[str] = []
|
|
328
|
+
if getattr(self.ctx, "train_oht_data", None) is not None:
|
|
329
|
+
dummy_columns = list(self.ctx.train_oht_data.columns)
|
|
330
|
+
return {
|
|
331
|
+
"factor_nmes": list(getattr(self.ctx, "factor_nmes", []) or []),
|
|
332
|
+
"cate_list": list(getattr(self.ctx, "cate_list", []) or []),
|
|
333
|
+
"num_features": list(getattr(self.ctx, "num_features", []) or []),
|
|
334
|
+
"var_nmes": list(getattr(self.ctx, "var_nmes", []) or []),
|
|
335
|
+
"cat_categories": dict(getattr(self.ctx, "cat_categories_for_shap", {}) or {}),
|
|
336
|
+
"dummy_columns": dummy_columns,
|
|
337
|
+
"numeric_scalers": dict(getattr(self.ctx, "numeric_scalers", {}) or {}),
|
|
338
|
+
"weight_nme": str(getattr(self.ctx, "weight_nme", "")),
|
|
339
|
+
"resp_nme": str(getattr(self.ctx, "resp_nme", "")),
|
|
340
|
+
"binary_resp_nme": getattr(self.ctx, "binary_resp_nme", None),
|
|
341
|
+
"drop_first": True,
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
def _dist_barrier(self, reason: str) -> None:
|
|
345
|
+
"""DDP barrier wrapper used by distributed Optuna.
|
|
346
|
+
|
|
347
|
+
To debug "trial finished but next trial never starts" hangs, set these
|
|
348
|
+
environment variables (either in shell or config.json `env`):
|
|
349
|
+
- `BAYESOPT_DDP_BARRIER_DEBUG=1` to print barrier enter/exit per-rank
|
|
350
|
+
- `BAYESOPT_DDP_BARRIER_TIMEOUT=300` to fail fast instead of waiting forever
|
|
351
|
+
- `TORCH_DISTRIBUTED_DEBUG=DETAIL` and `NCCL_DEBUG=INFO` for PyTorch/NCCL logs
|
|
352
|
+
"""
|
|
353
|
+
if dist is None:
|
|
354
|
+
return
|
|
355
|
+
try:
|
|
356
|
+
if not getattr(dist, "is_available", lambda: False)():
|
|
357
|
+
return
|
|
358
|
+
if not dist.is_initialized():
|
|
359
|
+
return
|
|
360
|
+
except Exception:
|
|
361
|
+
return
|
|
362
|
+
|
|
363
|
+
timeout_seconds = int(os.environ.get("BAYESOPT_DDP_BARRIER_TIMEOUT", "1800"))
|
|
364
|
+
debug_barrier = os.environ.get("BAYESOPT_DDP_BARRIER_DEBUG", "").strip() in {"1", "true", "TRUE", "yes", "YES"}
|
|
365
|
+
rank = None
|
|
366
|
+
world = None
|
|
367
|
+
if debug_barrier:
|
|
368
|
+
try:
|
|
369
|
+
rank = dist.get_rank()
|
|
370
|
+
world = dist.get_world_size()
|
|
371
|
+
_log(f"[DDP][{self.label}] entering barrier({reason}) rank={rank}/{world}", flush=True)
|
|
372
|
+
except Exception:
|
|
373
|
+
debug_barrier = False
|
|
374
|
+
try:
|
|
375
|
+
timeout = timedelta(seconds=timeout_seconds)
|
|
376
|
+
backend = None
|
|
377
|
+
try:
|
|
378
|
+
backend = dist.get_backend()
|
|
379
|
+
except Exception:
|
|
380
|
+
backend = None
|
|
381
|
+
|
|
382
|
+
# `monitored_barrier` is only implemented for GLOO; using it under NCCL
|
|
383
|
+
# will raise and can itself trigger a secondary hang. Prefer an async
|
|
384
|
+
# barrier with timeout for NCCL.
|
|
385
|
+
monitored = getattr(dist, "monitored_barrier", None)
|
|
386
|
+
if backend == "gloo" and callable(monitored):
|
|
387
|
+
monitored(timeout=timeout)
|
|
388
|
+
else:
|
|
389
|
+
work = None
|
|
390
|
+
try:
|
|
391
|
+
work = dist.barrier(async_op=True)
|
|
392
|
+
except TypeError:
|
|
393
|
+
work = None
|
|
394
|
+
if work is not None:
|
|
395
|
+
wait = getattr(work, "wait", None)
|
|
396
|
+
if callable(wait):
|
|
397
|
+
try:
|
|
398
|
+
wait(timeout=timeout)
|
|
399
|
+
except TypeError:
|
|
400
|
+
wait()
|
|
401
|
+
else:
|
|
402
|
+
dist.barrier()
|
|
403
|
+
else:
|
|
404
|
+
dist.barrier()
|
|
405
|
+
if debug_barrier:
|
|
406
|
+
_log(f"[DDP][{self.label}] exit barrier({reason}) rank={rank}/{world}", flush=True)
|
|
407
|
+
except Exception as exc:
|
|
408
|
+
_log(
|
|
409
|
+
f"[DDP][{self.label}] barrier failed during {reason}: {exc}",
|
|
410
|
+
flush=True,
|
|
411
|
+
)
|
|
412
|
+
raise
|
|
413
|
+
|
|
414
|
+
@property
|
|
415
|
+
def config(self) -> BayesOptConfig:
|
|
416
|
+
return self.ctx.config
|
|
417
|
+
|
|
418
|
+
@property
|
|
419
|
+
def output(self) -> OutputManager:
|
|
420
|
+
return self.ctx.output_manager
|
|
421
|
+
|
|
422
|
+
def _get_model_filename(self) -> str:
|
|
423
|
+
ext = 'pkl' if self.label in ['Xgboost', 'GLM'] else 'pth'
|
|
424
|
+
return f'01_{self.ctx.model_nme}_{self.model_name_prefix}.{ext}'
|
|
425
|
+
|
|
426
|
+
def _resolve_optuna_storage_url(self) -> Optional[str]:
|
|
427
|
+
storage = getattr(self.config, "optuna_storage", None)
|
|
428
|
+
if not storage:
|
|
429
|
+
return None
|
|
430
|
+
storage_str = str(storage).strip()
|
|
431
|
+
if not storage_str:
|
|
432
|
+
return None
|
|
433
|
+
if "://" in storage_str or storage_str == ":memory:":
|
|
434
|
+
return storage_str
|
|
435
|
+
path = Path(storage_str)
|
|
436
|
+
path = path.resolve()
|
|
437
|
+
ensure_parent_dir(str(path))
|
|
438
|
+
return f"sqlite:///{path.as_posix()}"
|
|
439
|
+
|
|
436
440
|
def _resolve_optuna_study_name(self) -> str:
|
|
437
441
|
prefix = getattr(self.config, "optuna_study_prefix",
|
|
438
442
|
None) or "bayesopt"
|
|
@@ -440,869 +444,877 @@ class TrainerBase:
|
|
|
440
444
|
safe = "".join([c if c.isalnum() or c in "._-" else "_" for c in raw])
|
|
441
445
|
return safe.lower()
|
|
442
446
|
|
|
443
|
-
def
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
self.
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
447
|
+
def _optuna_cleanup_sync(self) -> bool:
|
|
448
|
+
return bool(getattr(self.config, "optuna_cleanup_synchronize", False))
|
|
449
|
+
|
|
450
|
+
def tune(self, max_evals: int, objective_fn=None) -> None:
    """Run a single-process Optuna study that minimizes ``objective_fn``.

    Args:
        max_evals: Target number of finished trials (clamped to at least 1).
            Trials already recorded in the study (complete, pruned or failed)
            count toward this budget. With persistent storage, a resumed run
            therefore only executes the remainder.
        objective_fn: Callable taking an ``optuna.trial.Trial`` and returning
            the loss to minimize. Defaults to ``self.cross_val``.

    When ``_should_use_distributed_optuna()`` is true, the work is delegated
    to ``_distributed_tune`` instead. Otherwise this method:

    * sets ``self.study_name``, ``self.best_params`` and ``self.best_trial``;
    * writes the best parameters to
      ``<model_nme>_bestparams_<label>.csv`` (via ``self.output.result_path``);
    * refreshes that same file after every trial through a best-effort
      checkpoint callback.

    Raises:
        ValueError: Optuna raises this from ``study.best_params`` when no
            trial in the study has completed (for example, all were pruned).
    """
    if objective_fn is None:
        # If subclass doesn't provide objective_fn, default to cross_val.
        objective_fn = self.cross_val

    if self._should_use_distributed_optuna():
        self._distributed_tune(max_evals, objective_fn)
        return

    total_trials = max(1, int(max_evals))
    progress_counter = {"count": 0}

    def write_best_params(params) -> None:
        # Single place that decides where best params are persisted.
        params_path = self.output.result_path(
            f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
        )
        pd.DataFrame(params, index=[0]).to_csv(params_path, index=False)

    def objective_wrapper(trial: optuna.trial.Trial) -> float:
        should_log = DistributedUtils.is_main_process()
        if should_log:
            current_idx = progress_counter["count"] + 1
            _log(
                f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
                f"(trial_id={trial.number})."
            )
        # A live ``optuna.trial.Trial`` has no ``state`` attribute, so the
        # outcome is tracked locally for the "finished" log line.
        status = "FAIL"
        try:
            result = objective_fn(trial)
            status = "OK"
        except optuna.TrialPruned:
            status = "PRUNED"
            raise
        except RuntimeError as exc:
            if "out of memory" in str(exc).lower():
                status = "PRUNED"
                _log(
                    f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
                )
                self._clean_gpu(synchronize=True)
                raise optuna.TrialPruned() from exc
            raise
        finally:
            self._clean_gpu(synchronize=self._optuna_cleanup_sync())
            if should_log:
                progress_counter["count"] += 1
                _log(
                    f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
                    f"(status={status})."
                )
        return result

    storage_url = self._resolve_optuna_storage_url()
    study_name = self._resolve_optuna_study_name()
    study_kwargs: Dict[str, Any] = {
        "direction": "minimize",
        "sampler": optuna.samplers.TPESampler(seed=self.ctx.rand_seed),
    }
    if storage_url:
        # Persistent storage: reopen the named study so interrupted runs resume.
        study_kwargs.update(
            storage=storage_url,
            study_name=study_name,
            load_if_exists=True,
        )

    study = optuna.create_study(**study_kwargs)
    self.study_name = getattr(study, "study_name", None)

    def checkpoint_callback(check_study: optuna.study.Study, _trial) -> None:
        # Persist best_params after each trial to allow safe resume.
        try:
            best = check_study.best_trial
        except Exception:
            # Optuna raises until some trial has completed; nothing to save yet.
            return
        best_params = getattr(best, "params", None)
        if not best_params:
            return
        try:
            write_best_params(best_params)
        except Exception as exc:
            # Best-effort: a failed checkpoint must not abort the study.
            _log(
                f"[Optuna][{self.label}] Warning: failed to checkpoint best params: {exc}"
            )

    finished_states = (
        optuna.trial.TrialState.COMPLETE,
        optuna.trial.TrialState.PRUNED,
        optuna.trial.TrialState.FAIL,
    )
    completed = len(study.get_trials(states=finished_states))
    progress_counter["count"] = completed
    remaining = max(0, total_trials - completed)
    if remaining > 0:
        study.optimize(
            objective_wrapper,
            n_trials=remaining,
            callbacks=[checkpoint_callback],
        )
    self.best_params = study.best_params
    self.best_trial = study.best_trial

    # Save best params to CSV for reproducibility.
    write_best_params(self.best_params)
|
|
549
|
+
|
|
550
|
+
def train(self) -> None:
    """Fit the model. Concrete trainer subclasses must override this."""
    raise NotImplementedError()
|
|
552
|
+
|
|
553
|
+
def _unwrap_module(self, module: torch.nn.Module) -> torch.nn.Module:
|
|
554
|
+
"""Unwrap DDP or DataParallel wrapper to get the base module."""
|
|
555
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
556
|
+
if isinstance(module, (DDP, torch.nn.DataParallel)):
|
|
557
|
+
return module.module
|
|
558
|
+
return module
|
|
559
|
+
|
|
560
|
+
def save(self) -> None:
    """Persist ``self.model`` to ``self.output.model_path(self._get_model_filename())``.

    * Xgboost/GLM: written with ``joblib.dump`` as
      ``{"model": ..., "preprocess_artifacts": ...}``.
    * PyTorch models: written with ``torch.save`` as a dict that always holds
      ``preprocess_artifacts``, plus content that depends on the model:

      - has a ``resnet`` attribute: the unwrapped network's ``state_dict``
        (CPU tensors) and ``best_params``;
      - has an ``ft`` attribute: ``state_dict`` (CPU tensors),
        ``best_params`` and a ``model_config`` dict that ``load`` uses to
        rebuild an ``FTTransformerSklearn``;
      - otherwise: the whole model object, after moving it to CPU.

    When there is no model, a warning is logged and nothing is written.
    """
    if self.model is None:
        _log(f"[save] Warning: No model to save for {self.label}")
        return

    path = self.output.model_path(self._get_model_filename())
    if self.label in ['Xgboost', 'GLM']:
        payload = {
            "model": self.model,
            "preprocess_artifacts": self._export_preprocess_artifacts(),
        }
        joblib.dump(payload, path)
        return

    def _cpu_state_dict(module):
        # Copy tensors to CPU instead of calling ``module.to("cpu")``. That
        # call moves the live network in place, which would leave the
        # trained (possibly DDP-wrapped) model on the wrong device after
        # saving. Mutating the fresh dict returned by state_dict() keeps its
        # ``_metadata`` and does not touch the module's parameters.
        state = self._unwrap_module(module).state_dict()
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.cpu()
        return state

    def _as_list(value):
        return [] if value is None else list(value)

    def _plain(value):
        # Store numpy arrays as builtin lists so the payload holds no numpy objects.
        return value.tolist() if hasattr(value, "tolist") else value

    # PyTorch models: save state_dict without DDP/DataParallel wrappers
    # to ensure cross-platform compatibility.
    payload = {
        "preprocess_artifacts": self._export_preprocess_artifacts(),
    }
    if hasattr(self.model, 'resnet'):  # ResNetSklearn model
        payload["state_dict"] = _cpu_state_dict(self.model.resnet)
        payload["best_params"] = dict(self.best_params or {})
    elif hasattr(self.model, 'ft'):  # FTTransformerSklearn model
        model = self.model
        payload["state_dict"] = _cpu_state_dict(model.ft)
        payload["best_params"] = dict(self.best_params or {})
        categories = getattr(model, "cat_categories", None)
        if categories is None:
            categories = {}
        # Save model configuration for reconstruction in ``load``.
        payload["model_config"] = {
            "model_nme": getattr(model, "model_nme", ""),
            "num_cols": _as_list(getattr(model, "num_cols", None)),
            "cat_cols": _as_list(getattr(model, "cat_cols", None)),
            "d_model": getattr(model, "d_model", 64),
            "n_heads": getattr(model, "n_heads", 8),
            "n_layers": getattr(model, "n_layers", 4),
            "dropout": getattr(model, "dropout", 0.1),
            "task_type": getattr(model, "task_type", "regression"),
            "loss_name": getattr(model, "loss_name", None),
            "tw_power": getattr(model, "tw_power", 1.5),
            "num_geo": getattr(model, "num_geo", 0),
            "num_numeric_tokens": getattr(model, "num_numeric_tokens", None),
            "cat_cardinalities": getattr(model, "cat_cardinalities", None),
            "cat_categories": {k: list(v) for k, v in categories.items()},
            "_num_mean": _plain(getattr(model, "_num_mean", None)),
            "_num_std": _plain(getattr(model, "_num_std", None)),
        }
    else:
        # Generic PyTorch model fallback: pickle the whole object.
        # NOTE(review): this moves the live model to CPU in place.
        if hasattr(self.model, 'to'):
            self.model.to("cpu")
        payload["model"] = self.model
    torch.save(payload, path)
|
|
621
|
+
|
|
622
|
+
def load(self) -> None:
    """Restore ``self.model`` from the artifact written by ``save``.

    * Xgboost/GLM: read with ``joblib.load``. Both the wrapped
      ``{"model": ...}`` payload and a bare model object are accepted.
    * ResNet/ResNetClassifier: nothing is done here; subclasses rebuild the
      skeleton themselves.
    * Other labels: read with ``torch_load`` onto CPU. Three formats are
      accepted:

      - the current format (``state_dict`` + ``model_config``), from which
        an ``FTTransformerSklearn`` is rebuilt and its weights loaded;
      - a legacy dict holding a full ``model`` object;
      - a bare model object (oldest format).

    Restored PyTorch models are passed to ``_move_to_device``. A missing file
    or an unrecognized dict only logs a warning and leaves ``self.model``
    untouched.
    """
    path = self.output.model_path(self._get_model_filename())
    if not os.path.exists(path):
        _log(f"[load] Warning: Model file not found: {path}")
        return

    # NOTE(review): joblib.load and torch_load(weights_only=False) unpickle
    # arbitrary objects; only load artifacts from a trusted output directory.
    if self.label in ('Xgboost', 'GLM'):
        loaded = joblib.load(path)
        # Current payloads wrap the estimator; older files hold it directly.
        is_payload = isinstance(loaded, dict) and "model" in loaded
        self.model = loaded.get("model") if is_payload else loaded
        return

    if self.label in ('ResNet', 'ResNetClassifier'):
        # ResNet requires reconstructing the skeleton; handled by subclass.
        return

    loaded = torch_load(path, map_location='cpu', weights_only=False)
    if not isinstance(loaded, dict):
        # Very old format: direct model object.
        if loaded is not None:
            self._move_to_device(loaded)
            self.model = loaded
        return

    if "state_dict" in loaded and "model_config" in loaded:
        # New format: state_dict + model_config.
        state_dict = loaded.get("state_dict")
        model_config = loaded.get("model_config") or {}
        self.best_params = loaded.get("best_params", {})

        # Import FTTransformerSklearn for reconstruction.
        from ins_pricing.modelling.bayesopt.models import FTTransformerSklearn

        model = FTTransformerSklearn(
            model_nme=model_config.get("model_nme", ""),
            num_cols=model_config.get("num_cols", []),
            cat_cols=model_config.get("cat_cols", []),
            d_model=model_config.get("d_model", 64),
            n_heads=model_config.get("n_heads", 8),
            n_layers=model_config.get("n_layers", 4),
            dropout=model_config.get("dropout", 0.1),
            task_type=model_config.get("task_type", "regression"),
            loss_name=model_config.get("loss_name", None),
            tweedie_power=model_config.get("tw_power", 1.5),
            num_numeric_tokens=model_config.get("num_numeric_tokens"),
            use_data_parallel=False,
            use_ddp=False,
        )
        # Restore internal state captured by ``save``.
        model.num_geo = model_config.get("num_geo", 0)
        model.cat_cardinalities = model_config.get("cat_cardinalities")
        categories = model_config.get("cat_categories") or {}
        model.cat_categories = {k: pd.Index(v) for k, v in categories.items()}
        for attr in ("_num_mean", "_num_std"):
            stats = model_config.get(attr)
            if stats is not None:
                setattr(model, attr, np.array(stats, dtype=np.float32))

        # Rebuild the network core directly from the saved config and load
        # the weights. FTTransformerCore is constructed from
        # cat_cardinalities, so without it the weights cannot be restored.
        if model.cat_cardinalities is not None:
            from ins_pricing.modelling.bayesopt.models.model_ft_components import FTTransformerCore

            model.ft = FTTransformerCore(
                num_numeric=len(model.num_cols),
                cat_cardinalities=model.cat_cardinalities,
                d_model=model.d_model,
                n_heads=model.n_heads,
                n_layers=model.n_layers,
                dropout=model.dropout,
                task_type=model.task_type,
                num_geo=model.num_geo,
                num_numeric_tokens=model.num_numeric_tokens,
            )
            model.ft.load_state_dict(state_dict)
        else:
            _log(
                f"[load] Warning: 'cat_cardinalities' missing in {path}; "
                "FT weights were not restored."
            )

        self._move_to_device(model)
        self.model = model
    elif "model" in loaded:
        # Legacy format: full model object.
        loaded_model = loaded.get("model")
        if loaded_model is not None:
            self._move_to_device(loaded_model)
            self.model = loaded_model
    else:
        _log(f"[load] Warning: Unknown model format in {path}")
|
|
711
|
+
|
|
712
|
+
def _move_to_device(self, model_obj):
    """Hand ``model_obj`` to the shared ``DeviceManager.move_to_device`` helper."""
    mover = DeviceManager.move_to_device
    mover(model_obj)
|
|
715
|
+
|
|
716
|
+
def _should_use_distributed_optuna(self) -> bool:
|
|
717
|
+
if not self.enable_distributed_optuna:
|
|
718
|
+
return False
|
|
719
|
+
rank_env = os.environ.get("RANK")
|
|
720
|
+
world_env = os.environ.get("WORLD_SIZE")
|
|
721
|
+
local_env = os.environ.get("LOCAL_RANK")
|
|
722
|
+
if rank_env is None or world_env is None or local_env is None:
|
|
723
|
+
return False
|
|
724
|
+
try:
|
|
725
|
+
world_size = int(world_env)
|
|
726
|
+
except Exception:
|
|
727
|
+
return False
|
|
728
|
+
return world_size > 1
|
|
729
|
+
|
|
730
|
+
def _distributed_is_main(self) -> bool:
    """Return whether this process is the main one, per ``DistributedUtils``."""
    is_main = DistributedUtils.is_main_process()
    return is_main
|
|
732
|
+
|
|
733
|
+
def _distributed_send_command(self, payload: Dict[str, Any]) -> None:
|
|
734
|
+
if not self._should_use_distributed_optuna() or not self._distributed_is_main():
|
|
735
|
+
return
|
|
736
|
+
if dist is None:
|
|
737
|
+
return
|
|
738
|
+
DistributedUtils.setup_ddp()
|
|
739
|
+
if not dist.is_initialized():
|
|
740
|
+
return
|
|
741
|
+
message = [payload]
|
|
742
|
+
dist.broadcast_object_list(message, src=0)
|
|
743
|
+
|
|
744
|
+
def _distributed_prepare_trial(self, params: Dict[str, Any]) -> None:
|
|
745
|
+
if not self._should_use_distributed_optuna():
|
|
746
|
+
return
|
|
747
|
+
if not self._distributed_is_main():
|
|
748
|
+
return
|
|
749
|
+
if dist is None:
|
|
750
|
+
return
|
|
751
|
+
self._distributed_send_command({"type": "RUN", "params": params})
|
|
752
|
+
if not dist.is_initialized():
|
|
753
|
+
return
|
|
754
|
+
# STEP 2 (DDP/Optuna): make sure all ranks start the trial together.
|
|
755
|
+
self._dist_barrier("prepare_trial")
|
|
756
|
+
|
|
757
|
+
def _distributed_worker_loop(self, objective_fn: Callable[[Optional[optuna.trial.Trial]], float]) -> None:
    """Serve trial commands broadcast from rank 0 until a ``STOP`` arrives.

    Each received object that is a dict is dispatched on its ``"type"``:

    * ``"RUN"``: store ``params`` (or ``{}``) in
      ``self._distributed_forced_params``, pass the ``worker_start``
      barrier, and call ``objective_fn(None)``. ``TrialPruned`` is ignored
      and other exceptions are logged. GPU cleanup and the ``worker_end``
      barrier always follow.
    * ``"STOP"``: adopt ``best_params`` when present, then return.

    Non-dict messages and other command types are ignored. The loop exits
    immediately, with a log line, when ``torch.distributed`` is unavailable
    or fails to initialize.
    """
    if dist is None:
        _log(
            f"[Optuna][Worker][{self.label}] torch.distributed unavailable. Worker exit.",
            flush=True,
        )
        return
    DistributedUtils.setup_ddp()
    if not dist.is_initialized():
        _log(
            f"[Optuna][Worker][{self.label}] DDP init failed. Worker exit.",
            flush=True,
        )
        return

    while True:
        inbox = [None]
        dist.broadcast_object_list(inbox, src=0)
        command = inbox[0]
        if not isinstance(command, dict):
            continue
        kind = command.get("type")
        if kind == "STOP":
            final_params = command.get("best_params")
            if final_params is not None:
                self.best_params = final_params
            return
        if kind != "RUN":
            continue

        self._distributed_forced_params = command.get("params") or {}
        # STEP 2 (DDP/Optuna): align worker with rank0 before running objective_fn.
        self._dist_barrier("worker_start")
        try:
            objective_fn(None)
        except optuna.TrialPruned:
            pass
        except Exception as exc:
            _log(
                f"[Optuna][Worker][{self.label}] Exception: {exc}", flush=True)
        finally:
            self._clean_gpu(synchronize=self._optuna_cleanup_sync())
            # STEP 2 (DDP/Optuna): align worker with rank0 after objective_fn returns/raises.
            self._dist_barrier("worker_end")
|
|
792
|
-
|
|
793
|
-
def _distributed_tune(self, max_evals: int, objective_fn: Callable[[optuna.trial.Trial], float]) -> None:
|
|
794
|
-
if dist is None:
|
|
795
|
-
|
|
796
|
-
f"[Optuna][{self.label}] torch.distributed unavailable. Fallback to single-process.",
|
|
797
|
-
flush=True,
|
|
798
|
-
)
|
|
799
|
-
prev = self.enable_distributed_optuna
|
|
800
|
-
self.enable_distributed_optuna = False
|
|
801
|
-
try:
|
|
802
|
-
self.tune(max_evals, objective_fn)
|
|
803
|
-
finally:
|
|
804
|
-
self.enable_distributed_optuna = prev
|
|
805
|
-
return
|
|
806
|
-
DistributedUtils.setup_ddp()
|
|
807
|
-
if not dist.is_initialized():
|
|
808
|
-
rank_env = os.environ.get("RANK", "0")
|
|
809
|
-
if str(rank_env) != "0":
|
|
810
|
-
|
|
811
|
-
f"[Optuna][{self.label}] DDP init failed on worker. Skip.",
|
|
812
|
-
flush=True,
|
|
813
|
-
)
|
|
814
|
-
return
|
|
815
|
-
|
|
816
|
-
f"[Optuna][{self.label}] DDP init failed. Fallback to single-process.",
|
|
817
|
-
flush=True,
|
|
818
|
-
)
|
|
819
|
-
prev = self.enable_distributed_optuna
|
|
820
|
-
self.enable_distributed_optuna = False
|
|
821
|
-
try:
|
|
822
|
-
self.tune(max_evals, objective_fn)
|
|
823
|
-
finally:
|
|
824
|
-
self.enable_distributed_optuna = prev
|
|
825
|
-
return
|
|
826
|
-
if not self._distributed_is_main():
|
|
827
|
-
self._distributed_worker_loop(objective_fn)
|
|
828
|
-
return
|
|
829
|
-
|
|
830
|
-
total_trials = max(1, int(max_evals))
|
|
831
|
-
progress_counter = {"count": 0}
|
|
832
|
-
|
|
833
|
-
def objective_wrapper(trial: optuna.trial.Trial) -> float:
|
|
834
|
-
should_log = True
|
|
835
|
-
if should_log:
|
|
836
|
-
current_idx = progress_counter["count"] + 1
|
|
837
|
-
|
|
838
|
-
f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
|
|
839
|
-
f"(trial_id={trial.number})."
|
|
840
|
-
)
|
|
841
|
-
try:
|
|
842
|
-
result = objective_fn(trial)
|
|
799
|
+
|
|
800
|
+
def _distributed_tune(self, max_evals: int, objective_fn: Callable[[optuna.trial.Trial], float]) -> None:
|
|
801
|
+
if dist is None:
|
|
802
|
+
_log(
|
|
803
|
+
f"[Optuna][{self.label}] torch.distributed unavailable. Fallback to single-process.",
|
|
804
|
+
flush=True,
|
|
805
|
+
)
|
|
806
|
+
prev = self.enable_distributed_optuna
|
|
807
|
+
self.enable_distributed_optuna = False
|
|
808
|
+
try:
|
|
809
|
+
self.tune(max_evals, objective_fn)
|
|
810
|
+
finally:
|
|
811
|
+
self.enable_distributed_optuna = prev
|
|
812
|
+
return
|
|
813
|
+
DistributedUtils.setup_ddp()
|
|
814
|
+
if not dist.is_initialized():
|
|
815
|
+
rank_env = os.environ.get("RANK", "0")
|
|
816
|
+
if str(rank_env) != "0":
|
|
817
|
+
_log(
|
|
818
|
+
f"[Optuna][{self.label}] DDP init failed on worker. Skip.",
|
|
819
|
+
flush=True,
|
|
820
|
+
)
|
|
821
|
+
return
|
|
822
|
+
_log(
|
|
823
|
+
f"[Optuna][{self.label}] DDP init failed. Fallback to single-process.",
|
|
824
|
+
flush=True,
|
|
825
|
+
)
|
|
826
|
+
prev = self.enable_distributed_optuna
|
|
827
|
+
self.enable_distributed_optuna = False
|
|
828
|
+
try:
|
|
829
|
+
self.tune(max_evals, objective_fn)
|
|
830
|
+
finally:
|
|
831
|
+
self.enable_distributed_optuna = prev
|
|
832
|
+
return
|
|
833
|
+
if not self._distributed_is_main():
|
|
834
|
+
self._distributed_worker_loop(objective_fn)
|
|
835
|
+
return
|
|
836
|
+
|
|
837
|
+
total_trials = max(1, int(max_evals))
|
|
838
|
+
progress_counter = {"count": 0}
|
|
839
|
+
|
|
840
|
+
def objective_wrapper(trial: optuna.trial.Trial) -> float:
|
|
841
|
+
should_log = True
|
|
842
|
+
if should_log:
|
|
843
|
+
current_idx = progress_counter["count"] + 1
|
|
844
|
+
_log(
|
|
845
|
+
f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
|
|
846
|
+
f"(trial_id={trial.number})."
|
|
847
|
+
)
|
|
848
|
+
try:
|
|
849
|
+
result = objective_fn(trial)
|
|
843
850
|
except RuntimeError as exc:
|
|
844
851
|
if "out of memory" in str(exc).lower():
|
|
845
|
-
|
|
852
|
+
_log(
|
|
846
853
|
f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
|
|
847
854
|
)
|
|
848
|
-
self._clean_gpu()
|
|
855
|
+
self._clean_gpu(synchronize=True)
|
|
849
856
|
raise optuna.TrialPruned() from exc
|
|
850
857
|
raise
|
|
851
858
|
finally:
|
|
852
|
-
self._clean_gpu()
|
|
853
|
-
if should_log:
|
|
854
|
-
progress_counter["count"] = progress_counter["count"] + 1
|
|
855
|
-
trial_state = getattr(trial, "state", None)
|
|
856
|
-
state_repr = getattr(trial_state, "name", "OK")
|
|
857
|
-
|
|
858
|
-
f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
|
|
859
|
-
f"(status={state_repr})."
|
|
860
|
-
)
|
|
861
|
-
# STEP 2 (DDP/Optuna): a trial-end sync point; debug with BAYESOPT_DDP_BARRIER_DEBUG=1.
|
|
862
|
-
self._dist_barrier("trial_end")
|
|
863
|
-
return result
|
|
864
|
-
|
|
865
|
-
storage_url = self._resolve_optuna_storage_url()
|
|
866
|
-
study_name = self._resolve_optuna_study_name()
|
|
867
|
-
study_kwargs: Dict[str, Any] = {
|
|
868
|
-
"direction": "minimize",
|
|
869
|
-
"sampler": optuna.samplers.TPESampler(seed=self.ctx.rand_seed),
|
|
870
|
-
}
|
|
871
|
-
if storage_url:
|
|
872
|
-
study_kwargs.update(
|
|
873
|
-
storage=storage_url,
|
|
874
|
-
study_name=study_name,
|
|
875
|
-
load_if_exists=True,
|
|
876
|
-
)
|
|
877
|
-
study = optuna.create_study(**study_kwargs)
|
|
878
|
-
self.study_name = getattr(study, "study_name", None)
|
|
879
|
-
|
|
880
|
-
def checkpoint_callback(check_study: optuna.study.Study, _trial) -> None:
|
|
881
|
-
try:
|
|
882
|
-
best = getattr(check_study, "best_trial", None)
|
|
883
|
-
if best is None:
|
|
884
|
-
return
|
|
885
|
-
best_params = getattr(best, "params", None)
|
|
886
|
-
if not best_params:
|
|
887
|
-
return
|
|
888
|
-
params_path = self.output.result_path(
|
|
889
|
-
f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
|
|
890
|
-
)
|
|
891
|
-
pd.DataFrame(best_params, index=[0]).to_csv(
|
|
892
|
-
params_path, index=False)
|
|
893
|
-
except Exception:
|
|
894
|
-
return
|
|
895
|
-
|
|
896
|
-
completed_states = (
|
|
897
|
-
optuna.trial.TrialState.COMPLETE,
|
|
898
|
-
optuna.trial.TrialState.PRUNED,
|
|
899
|
-
optuna.trial.TrialState.FAIL,
|
|
900
|
-
)
|
|
901
|
-
completed = len(study.get_trials(states=completed_states))
|
|
902
|
-
progress_counter["count"] = completed
|
|
903
|
-
remaining = max(0, total_trials - completed)
|
|
904
|
-
try:
|
|
905
|
-
if remaining > 0:
|
|
906
|
-
study.optimize(
|
|
907
|
-
objective_wrapper,
|
|
908
|
-
n_trials=remaining,
|
|
909
|
-
callbacks=[checkpoint_callback],
|
|
910
|
-
)
|
|
911
|
-
self.best_params = study.best_params
|
|
912
|
-
self.best_trial = study.best_trial
|
|
913
|
-
params_path = self.output.result_path(
|
|
914
|
-
f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
|
|
915
|
-
)
|
|
916
|
-
pd.DataFrame(self.best_params, index=[0]).to_csv(
|
|
917
|
-
params_path, index=False)
|
|
918
|
-
finally:
|
|
919
|
-
self._distributed_send_command(
|
|
920
|
-
{"type": "STOP", "best_params": self.best_params})
|
|
921
|
-
|
|
922
|
-
def _clean_gpu(
|
|
923
|
-
"""Clean up GPU memory using shared GPUMemoryManager."""
|
|
924
|
-
GPUMemoryManager.clean()
|
|
925
|
-
|
|
926
|
-
def _standardize_fold(self,
|
|
927
|
-
X_train: pd.DataFrame,
|
|
928
|
-
X_val: pd.DataFrame,
|
|
929
|
-
columns: Optional[List[str]] = None
|
|
930
|
-
) -> Tuple[pd.DataFrame, pd.DataFrame, StandardScaler]:
|
|
931
|
-
"""Fit StandardScaler on the training fold and transform train/val features.
|
|
932
|
-
|
|
933
|
-
Args:
|
|
934
|
-
X_train: training features.
|
|
935
|
-
X_val: validation features.
|
|
936
|
-
columns: columns to scale (default: all).
|
|
937
|
-
|
|
938
|
-
Returns:
|
|
939
|
-
Scaled train/val features and the fitted scaler.
|
|
940
|
-
"""
|
|
941
|
-
scaler = StandardScaler()
|
|
942
|
-
cols = list(columns) if columns else list(X_train.columns)
|
|
943
|
-
X_train_scaled = X_train.copy(deep=True)
|
|
944
|
-
X_val_scaled = X_val.copy(deep=True)
|
|
945
|
-
if cols:
|
|
946
|
-
scaler.fit(X_train_scaled[cols])
|
|
947
|
-
X_train_scaled[cols] = scaler.transform(X_train_scaled[cols])
|
|
948
|
-
X_val_scaled[cols] = scaler.transform(X_val_scaled[cols])
|
|
949
|
-
return X_train_scaled, X_val_scaled, scaler
|
|
950
|
-
|
|
951
|
-
def _resolve_train_val_indices(
|
|
952
|
-
self,
|
|
953
|
-
X_all: pd.DataFrame,
|
|
954
|
-
*,
|
|
955
|
-
allow_default: bool = False,
|
|
956
|
-
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
|
957
|
-
"""Resolve train/validation split indices based on configured CV strategy.
|
|
958
|
-
|
|
959
|
-
Args:
|
|
960
|
-
X_all: DataFrame to split
|
|
961
|
-
allow_default: If True, use default val_ratio when config is invalid
|
|
962
|
-
|
|
963
|
-
Returns:
|
|
964
|
-
Tuple of (train_indices, val_indices) or None if not enough data
|
|
965
|
-
"""
|
|
966
|
-
val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
|
|
967
|
-
if not (0.0 < val_ratio < 1.0):
|
|
968
|
-
if not allow_default:
|
|
969
|
-
return None
|
|
970
|
-
val_ratio = 0.25
|
|
971
|
-
if len(X_all) < 10:
|
|
972
|
-
return None
|
|
973
|
-
|
|
974
|
-
resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
|
|
975
|
-
(train_idx, val_idx), _ = resolver.create_train_val_splitter(X_all, val_ratio)
|
|
976
|
-
return train_idx, val_idx
|
|
977
|
-
|
|
978
|
-
def _resolve_time_sample_indices(
|
|
979
|
-
self,
|
|
980
|
-
X_all: pd.DataFrame,
|
|
981
|
-
sample_limit: int,
|
|
982
|
-
) -> Optional[pd.Index]:
|
|
983
|
-
"""Get the most recent indices for time-based sampling.
|
|
984
|
-
|
|
985
|
-
For time-based CV strategies, returns the last `sample_limit` indices
|
|
986
|
-
ordered by time. For other strategies, returns None.
|
|
987
|
-
|
|
988
|
-
Args:
|
|
989
|
-
X_all: DataFrame to sample from
|
|
990
|
-
sample_limit: Maximum number of samples to return
|
|
991
|
-
|
|
992
|
-
Returns:
|
|
993
|
-
Index of sampled rows, or None if not using time-based strategy
|
|
994
|
-
"""
|
|
995
|
-
if sample_limit <= 0:
|
|
996
|
-
return None
|
|
997
|
-
|
|
998
|
-
resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
|
|
999
|
-
if not resolver.is_time_strategy():
|
|
1000
|
-
return None
|
|
1001
|
-
|
|
1002
|
-
order = resolver.get_time_ordered_indices(X_all)
|
|
1003
|
-
if len(order) == 0:
|
|
1004
|
-
return None
|
|
1005
|
-
|
|
1006
|
-
# Get the last sample_limit indices (most recent in time)
|
|
1007
|
-
if len(order) > sample_limit:
|
|
1008
|
-
order = order[-sample_limit:]
|
|
1009
|
-
|
|
1010
|
-
return X_all.index[order]
|
|
1011
|
-
|
|
1012
|
-
def _resolve_ensemble_splits(
|
|
859
|
+
self._clean_gpu(synchronize=self._optuna_cleanup_sync())
|
|
860
|
+
if should_log:
|
|
861
|
+
progress_counter["count"] = progress_counter["count"] + 1
|
|
862
|
+
trial_state = getattr(trial, "state", None)
|
|
863
|
+
state_repr = getattr(trial_state, "name", "OK")
|
|
864
|
+
_log(
|
|
865
|
+
f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
|
|
866
|
+
f"(status={state_repr})."
|
|
867
|
+
)
|
|
868
|
+
# STEP 2 (DDP/Optuna): a trial-end sync point; debug with BAYESOPT_DDP_BARRIER_DEBUG=1.
|
|
869
|
+
self._dist_barrier("trial_end")
|
|
870
|
+
return result
|
|
871
|
+
|
|
872
|
+
storage_url = self._resolve_optuna_storage_url()
|
|
873
|
+
study_name = self._resolve_optuna_study_name()
|
|
874
|
+
study_kwargs: Dict[str, Any] = {
|
|
875
|
+
"direction": "minimize",
|
|
876
|
+
"sampler": optuna.samplers.TPESampler(seed=self.ctx.rand_seed),
|
|
877
|
+
}
|
|
878
|
+
if storage_url:
|
|
879
|
+
study_kwargs.update(
|
|
880
|
+
storage=storage_url,
|
|
881
|
+
study_name=study_name,
|
|
882
|
+
load_if_exists=True,
|
|
883
|
+
)
|
|
884
|
+
study = optuna.create_study(**study_kwargs)
|
|
885
|
+
self.study_name = getattr(study, "study_name", None)
|
|
886
|
+
|
|
887
|
+
def checkpoint_callback(check_study: optuna.study.Study, _trial) -> None:
|
|
888
|
+
try:
|
|
889
|
+
best = getattr(check_study, "best_trial", None)
|
|
890
|
+
if best is None:
|
|
891
|
+
return
|
|
892
|
+
best_params = getattr(best, "params", None)
|
|
893
|
+
if not best_params:
|
|
894
|
+
return
|
|
895
|
+
params_path = self.output.result_path(
|
|
896
|
+
f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
|
|
897
|
+
)
|
|
898
|
+
pd.DataFrame(best_params, index=[0]).to_csv(
|
|
899
|
+
params_path, index=False)
|
|
900
|
+
except Exception:
|
|
901
|
+
return
|
|
902
|
+
|
|
903
|
+
completed_states = (
|
|
904
|
+
optuna.trial.TrialState.COMPLETE,
|
|
905
|
+
optuna.trial.TrialState.PRUNED,
|
|
906
|
+
optuna.trial.TrialState.FAIL,
|
|
907
|
+
)
|
|
908
|
+
completed = len(study.get_trials(states=completed_states))
|
|
909
|
+
progress_counter["count"] = completed
|
|
910
|
+
remaining = max(0, total_trials - completed)
|
|
911
|
+
try:
|
|
912
|
+
if remaining > 0:
|
|
913
|
+
study.optimize(
|
|
914
|
+
objective_wrapper,
|
|
915
|
+
n_trials=remaining,
|
|
916
|
+
callbacks=[checkpoint_callback],
|
|
917
|
+
)
|
|
918
|
+
self.best_params = study.best_params
|
|
919
|
+
self.best_trial = study.best_trial
|
|
920
|
+
params_path = self.output.result_path(
|
|
921
|
+
f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
|
|
922
|
+
)
|
|
923
|
+
pd.DataFrame(self.best_params, index=[0]).to_csv(
|
|
924
|
+
params_path, index=False)
|
|
925
|
+
finally:
|
|
926
|
+
self._distributed_send_command(
|
|
927
|
+
{"type": "STOP", "best_params": self.best_params})
|
|
928
|
+
|
|
929
|
+
def _clean_gpu(
|
|
1013
930
|
self,
|
|
1014
|
-
X_all: pd.DataFrame,
|
|
1015
931
|
*,
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
] =
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
X_all
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
if
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
)
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
self.ctx.
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
|
|
1220
|
-
)
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
|
|
1267
|
-
|
|
1268
|
-
|
|
1269
|
-
|
|
1270
|
-
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
932
|
+
synchronize: bool = True,
|
|
933
|
+
empty_cache: bool = True,
|
|
934
|
+
) -> None:
|
|
935
|
+
"""Clean up GPU memory using shared GPUMemoryManager."""
|
|
936
|
+
GPUMemoryManager.clean(synchronize=synchronize, empty_cache=empty_cache)
|
|
937
|
+
|
|
938
|
+
def _standardize_fold(self,
|
|
939
|
+
X_train: pd.DataFrame,
|
|
940
|
+
X_val: pd.DataFrame,
|
|
941
|
+
columns: Optional[List[str]] = None
|
|
942
|
+
) -> Tuple[pd.DataFrame, pd.DataFrame, StandardScaler]:
|
|
943
|
+
"""Fit StandardScaler on the training fold and transform train/val features.
|
|
944
|
+
|
|
945
|
+
Args:
|
|
946
|
+
X_train: training features.
|
|
947
|
+
X_val: validation features.
|
|
948
|
+
columns: columns to scale (default: all).
|
|
949
|
+
|
|
950
|
+
Returns:
|
|
951
|
+
Scaled train/val features and the fitted scaler.
|
|
952
|
+
"""
|
|
953
|
+
scaler = StandardScaler()
|
|
954
|
+
cols = list(columns) if columns else list(X_train.columns)
|
|
955
|
+
X_train_scaled = X_train.copy(deep=True)
|
|
956
|
+
X_val_scaled = X_val.copy(deep=True)
|
|
957
|
+
if cols:
|
|
958
|
+
scaler.fit(X_train_scaled[cols])
|
|
959
|
+
X_train_scaled[cols] = scaler.transform(X_train_scaled[cols])
|
|
960
|
+
X_val_scaled[cols] = scaler.transform(X_val_scaled[cols])
|
|
961
|
+
return X_train_scaled, X_val_scaled, scaler
|
|
962
|
+
|
|
963
|
+
def _resolve_train_val_indices(
|
|
964
|
+
self,
|
|
965
|
+
X_all: pd.DataFrame,
|
|
966
|
+
*,
|
|
967
|
+
allow_default: bool = False,
|
|
968
|
+
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
|
969
|
+
"""Resolve train/validation split indices based on configured CV strategy.
|
|
970
|
+
|
|
971
|
+
Args:
|
|
972
|
+
X_all: DataFrame to split
|
|
973
|
+
allow_default: If True, use default val_ratio when config is invalid
|
|
974
|
+
|
|
975
|
+
Returns:
|
|
976
|
+
Tuple of (train_indices, val_indices) or None if not enough data
|
|
977
|
+
"""
|
|
978
|
+
val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
|
|
979
|
+
if not (0.0 < val_ratio < 1.0):
|
|
980
|
+
if not allow_default:
|
|
981
|
+
return None
|
|
982
|
+
val_ratio = 0.25
|
|
983
|
+
if len(X_all) < 10:
|
|
984
|
+
return None
|
|
985
|
+
|
|
986
|
+
resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
|
|
987
|
+
(train_idx, val_idx), _ = resolver.create_train_val_splitter(X_all, val_ratio)
|
|
988
|
+
return train_idx, val_idx
|
|
989
|
+
|
|
990
|
+
def _resolve_time_sample_indices(
|
|
991
|
+
self,
|
|
992
|
+
X_all: pd.DataFrame,
|
|
993
|
+
sample_limit: int,
|
|
994
|
+
) -> Optional[pd.Index]:
|
|
995
|
+
"""Get the most recent indices for time-based sampling.
|
|
996
|
+
|
|
997
|
+
For time-based CV strategies, returns the last `sample_limit` indices
|
|
998
|
+
ordered by time. For other strategies, returns None.
|
|
999
|
+
|
|
1000
|
+
Args:
|
|
1001
|
+
X_all: DataFrame to sample from
|
|
1002
|
+
sample_limit: Maximum number of samples to return
|
|
1003
|
+
|
|
1004
|
+
Returns:
|
|
1005
|
+
Index of sampled rows, or None if not using time-based strategy
|
|
1006
|
+
"""
|
|
1007
|
+
if sample_limit <= 0:
|
|
1008
|
+
return None
|
|
1009
|
+
|
|
1010
|
+
resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
|
|
1011
|
+
if not resolver.is_time_strategy():
|
|
1012
|
+
return None
|
|
1013
|
+
|
|
1014
|
+
order = resolver.get_time_ordered_indices(X_all)
|
|
1015
|
+
if len(order) == 0:
|
|
1016
|
+
return None
|
|
1017
|
+
|
|
1018
|
+
# Get the last sample_limit indices (most recent in time)
|
|
1019
|
+
if len(order) > sample_limit:
|
|
1020
|
+
order = order[-sample_limit:]
|
|
1021
|
+
|
|
1022
|
+
return X_all.index[order]
|
|
1023
|
+
|
|
1024
|
+
def _resolve_ensemble_splits(
|
|
1025
|
+
self,
|
|
1026
|
+
X_all: pd.DataFrame,
|
|
1027
|
+
*,
|
|
1028
|
+
k: int,
|
|
1029
|
+
) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
|
|
1030
|
+
"""Resolve K-fold splits for ensemble training based on configured CV strategy.
|
|
1031
|
+
|
|
1032
|
+
Args:
|
|
1033
|
+
X_all: DataFrame to split
|
|
1034
|
+
k: Number of folds requested
|
|
1035
|
+
|
|
1036
|
+
Returns:
|
|
1037
|
+
Tuple of (split_iterator, actual_k) or (None, 0) if not enough data
|
|
1038
|
+
"""
|
|
1039
|
+
resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
|
|
1040
|
+
return resolver.create_kfold_splitter(X_all, k)
|
|
1041
|
+
|
|
1042
|
+
def cross_val_generic(
|
|
1043
|
+
self,
|
|
1044
|
+
trial: optuna.trial.Trial,
|
|
1045
|
+
hyperparameter_space: Dict[str, Callable[[optuna.trial.Trial], Any]],
|
|
1046
|
+
data_provider: Callable[[], Tuple[pd.DataFrame, pd.Series, Optional[pd.Series]]],
|
|
1047
|
+
model_builder: Callable[[Dict[str, Any]], Any],
|
|
1048
|
+
metric_fn: Callable[[pd.Series, np.ndarray, Optional[pd.Series]], float],
|
|
1049
|
+
sample_limit: Optional[int] = None,
|
|
1050
|
+
preprocess_fn: Optional[Callable[[
|
|
1051
|
+
pd.DataFrame, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]] = None,
|
|
1052
|
+
fit_predict_fn: Optional[
|
|
1053
|
+
Callable[[Any, pd.DataFrame, pd.Series, Optional[pd.Series],
|
|
1054
|
+
pd.DataFrame, pd.Series, Optional[pd.Series],
|
|
1055
|
+
optuna.trial.Trial], np.ndarray]
|
|
1056
|
+
] = None,
|
|
1057
|
+
cleanup_fn: Optional[Callable[[Any], None]] = None,
|
|
1058
|
+
splitter: Optional[Iterable[Tuple[np.ndarray, np.ndarray]]] = None) -> float:
|
|
1059
|
+
"""Generic holdout/CV helper to reuse tuning workflows.
|
|
1060
|
+
|
|
1061
|
+
Args:
|
|
1062
|
+
trial: current Optuna trial.
|
|
1063
|
+
hyperparameter_space: sampler dict keyed by parameter name.
|
|
1064
|
+
data_provider: callback returning (X, y, sample_weight).
|
|
1065
|
+
model_builder: callback to build a model per fold.
|
|
1066
|
+
metric_fn: loss/score function taking y_true, y_pred, weight.
|
|
1067
|
+
sample_limit: optional sample cap; random sample if exceeded.
|
|
1068
|
+
preprocess_fn: optional per-fold preprocessing (X_train, X_val).
|
|
1069
|
+
fit_predict_fn: optional custom fit/predict logic for validation.
|
|
1070
|
+
cleanup_fn: optional cleanup callback per fold.
|
|
1071
|
+
splitter: optional (train_idx, val_idx) iterator; defaults to cv_strategy config.
|
|
1072
|
+
|
|
1073
|
+
Returns:
|
|
1074
|
+
Mean validation metric across folds.
|
|
1075
|
+
"""
|
|
1076
|
+
params: Optional[Dict[str, Any]] = None
|
|
1077
|
+
if self._distributed_forced_params is not None:
|
|
1078
|
+
params = self._distributed_forced_params
|
|
1079
|
+
self._distributed_forced_params = None
|
|
1080
|
+
else:
|
|
1081
|
+
if trial is None:
|
|
1082
|
+
raise RuntimeError(
|
|
1083
|
+
"Missing Optuna trial for parameter sampling.")
|
|
1084
|
+
params = {name: sampler(trial)
|
|
1085
|
+
for name, sampler in hyperparameter_space.items()}
|
|
1086
|
+
if self._should_use_distributed_optuna():
|
|
1087
|
+
self._distributed_prepare_trial(params)
|
|
1088
|
+
X_all, y_all, w_all = data_provider()
|
|
1089
|
+
cfg_limit = getattr(self.ctx.config, "bo_sample_limit", None)
|
|
1090
|
+
if cfg_limit is not None:
|
|
1091
|
+
cfg_limit = int(cfg_limit)
|
|
1092
|
+
if cfg_limit > 0:
|
|
1093
|
+
sample_limit = cfg_limit if sample_limit is None else min(sample_limit, cfg_limit)
|
|
1094
|
+
if sample_limit is not None and len(X_all) > sample_limit:
|
|
1095
|
+
sampled_idx = self._resolve_time_sample_indices(X_all, int(sample_limit))
|
|
1096
|
+
if sampled_idx is None:
|
|
1097
|
+
sampled_idx = X_all.sample(
|
|
1098
|
+
n=sample_limit,
|
|
1099
|
+
random_state=self.ctx.rand_seed
|
|
1100
|
+
).index
|
|
1101
|
+
X_all = X_all.loc[sampled_idx]
|
|
1102
|
+
y_all = y_all.loc[sampled_idx]
|
|
1103
|
+
w_all = w_all.loc[sampled_idx] if w_all is not None else None
|
|
1104
|
+
|
|
1105
|
+
if splitter is None:
|
|
1106
|
+
val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
|
|
1107
|
+
if not (0.0 < val_ratio < 1.0):
|
|
1108
|
+
val_ratio = 0.25
|
|
1109
|
+
cv_splits = getattr(self.ctx.config, "cv_splits", None)
|
|
1110
|
+
if cv_splits is None:
|
|
1111
|
+
cv_splits = max(2, int(round(1 / val_ratio)))
|
|
1112
|
+
cv_splits = max(2, int(cv_splits))
|
|
1113
|
+
|
|
1114
|
+
resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
|
|
1115
|
+
split_iter, actual_splits = resolver.create_cv_splitter(X_all, y_all, cv_splits, val_ratio)
|
|
1116
|
+
if actual_splits < 2:
|
|
1117
|
+
raise ValueError("Not enough samples for cross-validation.")
|
|
1118
|
+
else:
|
|
1119
|
+
if hasattr(splitter, "split"):
|
|
1120
|
+
split_iter = splitter.split(X_all, y_all, groups=None)
|
|
1121
|
+
else:
|
|
1122
|
+
split_iter = splitter
|
|
1123
|
+
|
|
1124
|
+
losses: List[float] = []
|
|
1125
|
+
for fold_idx, (train_idx, val_idx) in enumerate(split_iter):
|
|
1126
|
+
X_train = X_all.iloc[train_idx]
|
|
1127
|
+
y_train = y_all.iloc[train_idx]
|
|
1128
|
+
X_val = X_all.iloc[val_idx]
|
|
1129
|
+
y_val = y_all.iloc[val_idx]
|
|
1130
|
+
w_train = w_all.iloc[train_idx] if w_all is not None else None
|
|
1131
|
+
w_val = w_all.iloc[val_idx] if w_all is not None else None
|
|
1132
|
+
|
|
1133
|
+
if preprocess_fn:
|
|
1134
|
+
X_train, X_val = preprocess_fn(X_train, X_val)
|
|
1135
|
+
|
|
1136
|
+
model = model_builder(params)
|
|
1137
|
+
try:
|
|
1138
|
+
if fit_predict_fn:
|
|
1139
|
+
# Avoid duplicate Optuna step reports across folds.
|
|
1140
|
+
trial_for_fold = trial if fold_idx == 0 else None
|
|
1141
|
+
y_pred = fit_predict_fn(
|
|
1142
|
+
model, X_train, y_train, w_train,
|
|
1143
|
+
X_val, y_val, w_val, trial_for_fold
|
|
1144
|
+
)
|
|
1145
|
+
else:
|
|
1146
|
+
fit_kwargs = {}
|
|
1147
|
+
if w_train is not None:
|
|
1148
|
+
fit_kwargs["sample_weight"] = w_train
|
|
1149
|
+
model.fit(X_train, y_train, **fit_kwargs)
|
|
1150
|
+
y_pred = model.predict(X_val)
|
|
1151
|
+
losses.append(metric_fn(y_val, y_pred, w_val))
|
|
1152
|
+
finally:
|
|
1153
|
+
if cleanup_fn:
|
|
1154
|
+
cleanup_fn(model)
|
|
1155
|
+
self._clean_gpu()
|
|
1156
|
+
|
|
1157
|
+
return float(np.mean(losses))
|
|
1158
|
+
|
|
1159
|
+
# Prediction + caching logic.
|
|
1160
|
+
def _predict_and_cache(self,
|
|
1161
|
+
model,
|
|
1162
|
+
pred_prefix: str,
|
|
1163
|
+
use_oht: bool = False,
|
|
1164
|
+
design_fn=None,
|
|
1165
|
+
predict_kwargs_train: Optional[Dict[str, Any]] = None,
|
|
1166
|
+
predict_kwargs_test: Optional[Dict[str, Any]] = None,
|
|
1167
|
+
predict_fn: Optional[Callable[..., Any]] = None) -> None:
|
|
1168
|
+
if design_fn:
|
|
1169
|
+
X_train = design_fn(train=True)
|
|
1170
|
+
X_test = design_fn(train=False)
|
|
1171
|
+
elif use_oht:
|
|
1172
|
+
X_train = self.ctx.train_oht_scl_data[self.ctx.var_nmes]
|
|
1173
|
+
X_test = self.ctx.test_oht_scl_data[self.ctx.var_nmes]
|
|
1174
|
+
else:
|
|
1175
|
+
X_train = self.ctx.train_data[self.ctx.factor_nmes]
|
|
1176
|
+
X_test = self.ctx.test_data[self.ctx.factor_nmes]
|
|
1177
|
+
|
|
1178
|
+
predictor = predict_fn or model.predict
|
|
1179
|
+
preds_train = predictor(X_train, **(predict_kwargs_train or {}))
|
|
1180
|
+
preds_test = predictor(X_test, **(predict_kwargs_test or {}))
|
|
1181
|
+
preds_train = np.asarray(preds_train)
|
|
1182
|
+
preds_test = np.asarray(preds_test)
|
|
1183
|
+
|
|
1184
|
+
if preds_train.ndim <= 1 or (preds_train.ndim == 2 and preds_train.shape[1] == 1):
|
|
1185
|
+
col_name = f'pred_{pred_prefix}'
|
|
1186
|
+
self.ctx.train_data[col_name] = preds_train.reshape(-1)
|
|
1187
|
+
self.ctx.test_data[col_name] = preds_test.reshape(-1)
|
|
1188
|
+
self.ctx.train_data[f'w_{col_name}'] = (
|
|
1189
|
+
self.ctx.train_data[col_name] *
|
|
1190
|
+
self.ctx.train_data[self.ctx.weight_nme]
|
|
1191
|
+
)
|
|
1192
|
+
self.ctx.test_data[f'w_{col_name}'] = (
|
|
1193
|
+
self.ctx.test_data[col_name] *
|
|
1194
|
+
self.ctx.test_data[self.ctx.weight_nme]
|
|
1195
|
+
)
|
|
1196
|
+
self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
|
|
1197
|
+
return
|
|
1198
|
+
|
|
1199
|
+
# Vector outputs (e.g., embeddings) are expanded into pred_<prefix>_0.. columns.
|
|
1200
|
+
if preds_train.ndim != 2:
|
|
1201
|
+
raise ValueError(
|
|
1202
|
+
f"Unexpected prediction shape for '{pred_prefix}': {preds_train.shape}")
|
|
1203
|
+
if preds_test.ndim != 2 or preds_test.shape[1] != preds_train.shape[1]:
|
|
1204
|
+
raise ValueError(
|
|
1205
|
+
f"Train/test prediction dims mismatch for '{pred_prefix}': "
|
|
1206
|
+
f"{preds_train.shape} vs {preds_test.shape}")
|
|
1207
|
+
for j in range(preds_train.shape[1]):
|
|
1208
|
+
col_name = f'pred_{pred_prefix}_{j}'
|
|
1209
|
+
self.ctx.train_data[col_name] = preds_train[:, j]
|
|
1210
|
+
self.ctx.test_data[col_name] = preds_test[:, j]
|
|
1211
|
+
self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
|
|
1212
|
+
|
|
1213
|
+
def _cache_predictions(self,
|
|
1214
|
+
pred_prefix: str,
|
|
1215
|
+
preds_train,
|
|
1216
|
+
preds_test) -> None:
|
|
1217
|
+
preds_train = np.asarray(preds_train)
|
|
1218
|
+
preds_test = np.asarray(preds_test)
|
|
1219
|
+
if preds_train.ndim <= 1 or (preds_train.ndim == 2 and preds_train.shape[1] == 1):
|
|
1220
|
+
if preds_test.ndim > 1:
|
|
1221
|
+
preds_test = preds_test.reshape(-1)
|
|
1222
|
+
col_name = f'pred_{pred_prefix}'
|
|
1223
|
+
self.ctx.train_data[col_name] = preds_train.reshape(-1)
|
|
1224
|
+
self.ctx.test_data[col_name] = preds_test.reshape(-1)
|
|
1225
|
+
self.ctx.train_data[f'w_{col_name}'] = (
|
|
1226
|
+
self.ctx.train_data[col_name] *
|
|
1227
|
+
self.ctx.train_data[self.ctx.weight_nme]
|
|
1228
|
+
)
|
|
1229
|
+
self.ctx.test_data[f'w_{col_name}'] = (
|
|
1230
|
+
self.ctx.test_data[col_name] *
|
|
1231
|
+
self.ctx.test_data[self.ctx.weight_nme]
|
|
1232
|
+
)
|
|
1233
|
+
self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
|
|
1234
|
+
return
|
|
1235
|
+
|
|
1236
|
+
if preds_train.ndim != 2:
|
|
1237
|
+
raise ValueError(
|
|
1238
|
+
f"Unexpected prediction shape for '{pred_prefix}': {preds_train.shape}")
|
|
1239
|
+
if preds_test.ndim != 2 or preds_test.shape[1] != preds_train.shape[1]:
|
|
1240
|
+
raise ValueError(
|
|
1241
|
+
f"Train/test prediction dims mismatch for '{pred_prefix}': "
|
|
1242
|
+
f"{preds_train.shape} vs {preds_test.shape}")
|
|
1243
|
+
for j in range(preds_train.shape[1]):
|
|
1244
|
+
col_name = f'pred_{pred_prefix}_{j}'
|
|
1245
|
+
self.ctx.train_data[col_name] = preds_train[:, j]
|
|
1246
|
+
self.ctx.test_data[col_name] = preds_test[:, j]
|
|
1247
|
+
self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
|
|
1248
|
+
|
|
1249
|
+
def _maybe_cache_predictions(self, pred_prefix: str, preds_train, preds_test) -> None:
|
|
1250
|
+
cfg = getattr(self.ctx, "config", None)
|
|
1251
|
+
if cfg is None or not bool(getattr(cfg, "cache_predictions", False)):
|
|
1252
|
+
return
|
|
1253
|
+
fmt = str(getattr(cfg, "prediction_cache_format", "parquet") or "parquet").lower()
|
|
1254
|
+
cache_dir = getattr(cfg, "prediction_cache_dir", None)
|
|
1255
|
+
if cache_dir:
|
|
1256
|
+
target_dir = Path(str(cache_dir))
|
|
1257
|
+
if not target_dir.is_absolute():
|
|
1258
|
+
target_dir = Path(self.output.result_dir) / target_dir
|
|
1259
|
+
else:
|
|
1260
|
+
target_dir = Path(self.output.result_dir) / "predictions"
|
|
1261
|
+
target_dir.mkdir(parents=True, exist_ok=True)
|
|
1262
|
+
|
|
1263
|
+
def _build_frame(preds, split_label: str) -> pd.DataFrame:
|
|
1264
|
+
arr = np.asarray(preds)
|
|
1265
|
+
if arr.ndim <= 1:
|
|
1266
|
+
return pd.DataFrame({f"pred_{pred_prefix}": arr.reshape(-1)})
|
|
1267
|
+
cols = [f"pred_{pred_prefix}_{i}" for i in range(arr.shape[1])]
|
|
1268
|
+
return pd.DataFrame(arr, columns=cols)
|
|
1269
|
+
|
|
1270
|
+
for split_label, preds in [("train", preds_train), ("test", preds_test)]:
|
|
1271
|
+
frame = _build_frame(preds, split_label)
|
|
1272
|
+
filename = f"{self.ctx.model_nme}_{pred_prefix}_{split_label}.{ 'csv' if fmt == 'csv' else 'parquet' }"
|
|
1273
|
+
path = target_dir / filename
|
|
1274
|
+
try:
|
|
1275
|
+
if fmt == "csv":
|
|
1276
|
+
frame.to_csv(path, index=False)
|
|
1277
|
+
else:
|
|
1278
|
+
frame.to_parquet(path, index=False)
|
|
1279
|
+
except Exception:
|
|
1280
|
+
pass
|
|
1281
|
+
|
|
1282
|
+
def _resolve_best_epoch(self,
|
|
1283
|
+
history: Optional[Dict[str, List[float]]],
|
|
1284
|
+
default_epochs: int) -> int:
|
|
1285
|
+
if not history:
|
|
1286
|
+
return max(1, int(default_epochs))
|
|
1287
|
+
vals = history.get("val") or []
|
|
1288
|
+
if not vals:
|
|
1289
|
+
return max(1, int(default_epochs))
|
|
1290
|
+
best_idx = int(np.nanargmin(vals))
|
|
1291
|
+
return max(1, best_idx + 1)
|
|
1292
|
+
|
|
1293
|
+
def _fit_predict_cache(self,
|
|
1294
|
+
model,
|
|
1295
|
+
X_train,
|
|
1296
|
+
y_train,
|
|
1297
|
+
sample_weight,
|
|
1298
|
+
pred_prefix: str,
|
|
1299
|
+
use_oht: bool = False,
|
|
1300
|
+
design_fn=None,
|
|
1301
|
+
fit_kwargs: Optional[Dict[str, Any]] = None,
|
|
1302
|
+
sample_weight_arg: Optional[str] = 'sample_weight',
|
|
1303
|
+
predict_kwargs_train: Optional[Dict[str, Any]] = None,
|
|
1304
|
+
predict_kwargs_test: Optional[Dict[str, Any]] = None,
|
|
1305
|
+
predict_fn: Optional[Callable[..., Any]] = None,
|
|
1306
|
+
record_label: bool = True) -> None:
|
|
1307
|
+
fit_kwargs = fit_kwargs.copy() if fit_kwargs else {}
|
|
1308
|
+
if sample_weight is not None and sample_weight_arg:
|
|
1309
|
+
fit_kwargs.setdefault(sample_weight_arg, sample_weight)
|
|
1310
|
+
model.fit(X_train, y_train, **fit_kwargs)
|
|
1311
|
+
if record_label:
|
|
1312
|
+
self.ctx.model_label.append(self.label)
|
|
1313
|
+
self._predict_and_cache(
|
|
1314
|
+
model,
|
|
1315
|
+
pred_prefix,
|
|
1316
|
+
use_oht=use_oht,
|
|
1317
|
+
design_fn=design_fn,
|
|
1318
|
+
predict_kwargs_train=predict_kwargs_train,
|
|
1319
|
+
predict_kwargs_test=predict_kwargs_test,
|
|
1320
|
+
predict_fn=predict_fn)
|