ins-pricing 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/modelling/core/bayesopt/config_preprocess.py +7 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +13 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +12 -7
- ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +1 -1
- ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +2 -1
- ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +8 -0
- ins_pricing/setup.py +1 -1
- {ins_pricing-0.4.2.dist-info → ins_pricing-0.4.3.dist-info}/METADATA +1 -1
- {ins_pricing-0.4.2.dist-info → ins_pricing-0.4.3.dist-info}/RECORD +11 -11
- {ins_pricing-0.4.2.dist-info → ins_pricing-0.4.3.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.2.dist-info → ins_pricing-0.4.3.dist-info}/top_level.txt +0 -0
|
@@ -195,6 +195,7 @@ class BayesOptConfig:
|
|
|
195
195
|
cache_predictions: bool = False
|
|
196
196
|
prediction_cache_dir: Optional[str] = None
|
|
197
197
|
prediction_cache_format: str = "parquet"
|
|
198
|
+
dataloader_workers: Optional[int] = None
|
|
198
199
|
|
|
199
200
|
def __post_init__(self) -> None:
|
|
200
201
|
"""Validate configuration after initialization."""
|
|
@@ -210,6 +211,12 @@ class BayesOptConfig:
|
|
|
210
211
|
errors.append(
|
|
211
212
|
f"task_type must be one of {valid_task_types}, got '{self.task_type}'"
|
|
212
213
|
)
|
|
214
|
+
if self.dataloader_workers is not None:
|
|
215
|
+
try:
|
|
216
|
+
if int(self.dataloader_workers) < 0:
|
|
217
|
+
errors.append("dataloader_workers must be >= 0 when provided.")
|
|
218
|
+
except (TypeError, ValueError):
|
|
219
|
+
errors.append("dataloader_workers must be an integer when provided.")
|
|
213
220
|
# Validate loss_name
|
|
214
221
|
try:
|
|
215
222
|
normalized_loss = normalize_loss_name(self.loss_name, self.task_type)
|
|
@@ -306,6 +306,19 @@ class TrainerBase:
|
|
|
306
306
|
self.enable_distributed_optuna: bool = False
|
|
307
307
|
self._distributed_forced_params: Optional[Dict[str, Any]] = None
|
|
308
308
|
|
|
309
|
+
def _apply_dataloader_overrides(self, model: Any) -> Any:
|
|
310
|
+
"""Apply dataloader-related overrides from config to a model."""
|
|
311
|
+
cfg = getattr(self.ctx, "config", None)
|
|
312
|
+
if cfg is None:
|
|
313
|
+
return model
|
|
314
|
+
workers = getattr(cfg, "dataloader_workers", None)
|
|
315
|
+
if workers is not None:
|
|
316
|
+
model.dataloader_workers = int(workers)
|
|
317
|
+
profile = getattr(cfg, "resource_profile", None)
|
|
318
|
+
if profile:
|
|
319
|
+
model.resource_profile = str(profile)
|
|
320
|
+
return model
|
|
321
|
+
|
|
309
322
|
def _export_preprocess_artifacts(self) -> Dict[str, Any]:
|
|
310
323
|
dummy_columns: List[str] = []
|
|
311
324
|
if getattr(self.ctx, "train_oht_data", None) is not None:
|
|
@@ -163,6 +163,7 @@ class FTTrainer(TrainerBase):
|
|
|
163
163
|
num_numeric_tokens=num_numeric_tokens,
|
|
164
164
|
loss_name=loss_name,
|
|
165
165
|
)
|
|
166
|
+
model = self._apply_dataloader_overrides(model)
|
|
166
167
|
model.set_params(model_params)
|
|
167
168
|
try:
|
|
168
169
|
return float(model.fit_unsupervised(
|
|
@@ -248,7 +249,7 @@ class FTTrainer(TrainerBase):
|
|
|
248
249
|
requested_heads=params.get("n_heads")
|
|
249
250
|
)
|
|
250
251
|
|
|
251
|
-
|
|
252
|
+
model = FTTransformerSklearn(
|
|
252
253
|
model_nme=self.ctx.model_nme,
|
|
253
254
|
num_cols=self.ctx.num_features,
|
|
254
255
|
cat_cols=self.ctx.cate_list,
|
|
@@ -266,7 +267,10 @@ class FTTrainer(TrainerBase):
|
|
|
266
267
|
use_ddp=self.ctx.config.use_ft_ddp,
|
|
267
268
|
num_numeric_tokens=num_numeric_tokens,
|
|
268
269
|
loss_name=loss_name,
|
|
269
|
-
)
|
|
270
|
+
)
|
|
271
|
+
model = self._apply_dataloader_overrides(model)
|
|
272
|
+
model.set_params({"_geo_params": geo_params_local} if geo_enabled else {})
|
|
273
|
+
return model
|
|
270
274
|
|
|
271
275
|
def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
|
|
272
276
|
geo_train = geo_val = None
|
|
@@ -353,6 +357,7 @@ class FTTrainer(TrainerBase):
|
|
|
353
357
|
weight_decay=float(resolved_params.get("weight_decay", 0.0)),
|
|
354
358
|
loss_name=loss_name,
|
|
355
359
|
)
|
|
360
|
+
tmp_model = self._apply_dataloader_overrides(tmp_model)
|
|
356
361
|
tmp_model.set_params(resolved_params)
|
|
357
362
|
geo_train_full = self.ctx.train_geo_tokens
|
|
358
363
|
geo_train = None if geo_train_full is None else geo_train_full.iloc[train_idx]
|
|
@@ -387,6 +392,7 @@ class FTTrainer(TrainerBase):
|
|
|
387
392
|
weight_decay=float(resolved_params.get("weight_decay", 0.0)),
|
|
388
393
|
loss_name=loss_name,
|
|
389
394
|
)
|
|
395
|
+
self.model = self._apply_dataloader_overrides(self.model)
|
|
390
396
|
if refit_epochs is not None:
|
|
391
397
|
self.model.epochs = int(refit_epochs)
|
|
392
398
|
self.model.set_params(resolved_params)
|
|
@@ -460,6 +466,7 @@ class FTTrainer(TrainerBase):
|
|
|
460
466
|
weight_decay=float(resolved_params.get("weight_decay", 0.0)),
|
|
461
467
|
loss_name=loss_name,
|
|
462
468
|
)
|
|
469
|
+
model = self._apply_dataloader_overrides(model)
|
|
463
470
|
model.set_params(resolved_params)
|
|
464
471
|
|
|
465
472
|
geo_train = geo_val = None
|
|
@@ -565,6 +572,7 @@ class FTTrainer(TrainerBase):
|
|
|
565
572
|
num_numeric_tokens=self._resolve_numeric_tokens(),
|
|
566
573
|
loss_name=loss_name,
|
|
567
574
|
)
|
|
575
|
+
model = self._apply_dataloader_overrides(model)
|
|
568
576
|
adaptive_heads, heads_adjusted = self._resolve_adaptive_heads(
|
|
569
577
|
d_model=resolved_params.get("d_model", model.d_model),
|
|
570
578
|
requested_heads=resolved_params.get("n_heads"),
|
|
@@ -728,6 +736,7 @@ class FTTrainer(TrainerBase):
|
|
|
728
736
|
num_numeric_tokens=self._resolve_numeric_tokens(),
|
|
729
737
|
loss_name=loss_name,
|
|
730
738
|
)
|
|
739
|
+
self.model = self._apply_dataloader_overrides(self.model)
|
|
731
740
|
resolved_params = dict(params or {})
|
|
732
741
|
# Reuse supervised tuning structure params unless explicitly overridden.
|
|
733
742
|
if not resolved_params and self.best_params:
|
|
@@ -797,8 +806,4 @@ class FTTrainer(TrainerBase):
|
|
|
797
806
|
self.model,
|
|
798
807
|
pred_prefix=pred_prefix,
|
|
799
808
|
predict_kwargs_train=predict_kwargs_train,
|
|
800
|
-
|
|
801
|
-
)
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
# =============================================================================
|
|
809
|
+
|
|
@@ -59,7 +59,7 @@ class GNNTrainer(TrainerBase):
|
|
|
59
59
|
self.ctx.config.gnn_knn_gpu_mem_overhead),
|
|
60
60
|
loss_name=loss_name,
|
|
61
61
|
)
|
|
62
|
-
return model
|
|
62
|
+
return self._apply_dataloader_overrides(model)
|
|
63
63
|
|
|
64
64
|
def cross_val(self, trial: optuna.trial.Trial) -> float:
|
|
65
65
|
base_tw_power = self.ctx.default_tweedie_power()
|
|
@@ -45,7 +45,7 @@ class ResNetTrainer(TrainerBase):
|
|
|
45
45
|
getattr(self.ctx.config, "resn_weight_decay", 1e-4),
|
|
46
46
|
)
|
|
47
47
|
)
|
|
48
|
-
|
|
48
|
+
model = ResNetSklearn(
|
|
49
49
|
model_nme=self.ctx.model_nme,
|
|
50
50
|
input_dim=self._resolve_input_dim(),
|
|
51
51
|
hidden_dim=int(params.get("hidden_dim", 64)),
|
|
@@ -64,6 +64,7 @@ class ResNetTrainer(TrainerBase):
|
|
|
64
64
|
use_ddp=self.ctx.config.use_resn_ddp,
|
|
65
65
|
loss_name=loss_name
|
|
66
66
|
)
|
|
67
|
+
return self._apply_dataloader_overrides(model)
|
|
67
68
|
|
|
68
69
|
# ========= Cross-validation (for BayesOpt) =========
|
|
69
70
|
def cross_val(self, trial: optuna.trial.Trial) -> float:
|
|
@@ -232,6 +232,14 @@ class TorchTrainerMixin:
|
|
|
232
232
|
"""Determine number of DataLoader workers."""
|
|
233
233
|
if os.name == 'nt':
|
|
234
234
|
return 0
|
|
235
|
+
override = getattr(self, "dataloader_workers", None)
|
|
236
|
+
if override is None:
|
|
237
|
+
override = os.environ.get("BAYESOPT_DATALOADER_WORKERS")
|
|
238
|
+
if override is not None:
|
|
239
|
+
try:
|
|
240
|
+
return max(0, int(override))
|
|
241
|
+
except (TypeError, ValueError):
|
|
242
|
+
pass
|
|
235
243
|
if getattr(self, "is_ddp_enabled", False):
|
|
236
244
|
return 0
|
|
237
245
|
profile = profile or self._resolve_resource_profile()
|
ins_pricing/setup.py
CHANGED
|
@@ -3,7 +3,7 @@ ins_pricing/README.md,sha256=W4V2xtzM6pyQzwJPvWP7cNn-We9rxM8xrxRlBVQwoY8,3399
|
|
|
3
3
|
ins_pricing/RELEASE_NOTES_0.2.8.md,sha256=KIJzk1jbZbZPKjwnkPSDHO_2Ipv3SP3CzCNDdf07jI0,9331
|
|
4
4
|
ins_pricing/__init__.py,sha256=46j1wCdLVrgrofeBwKl-3NXTxzjbTv-w3KjW-dyKGiY,2622
|
|
5
5
|
ins_pricing/exceptions.py,sha256=5fZavPV4zNJ7wPC75L215KkHXX9pRrfDAYZOdSKJMGo,4778
|
|
6
|
-
ins_pricing/setup.py,sha256=
|
|
6
|
+
ins_pricing/setup.py,sha256=Jyq-oIi6qUEbE7_hXVWWbNXvF_mVP_aNQvOjf10DSu8,1702
|
|
7
7
|
ins_pricing/cli/BayesOpt_entry.py,sha256=6UBVxu36O3bXn1WC-BBi-l_W9_MqEoHmDGnwwDKNo5Q,1594
|
|
8
8
|
ins_pricing/cli/BayesOpt_incremental.py,sha256=_Klr5vvNoq_TbgwrH_T3f0a6cHmA9iVJMViiji6ahJY,35927
|
|
9
9
|
ins_pricing/cli/Explain_Run.py,sha256=gEPQjqHiXyXlCTKjUzwSvbAn5_h74ABgb_sEGs-YHVE,664
|
|
@@ -46,7 +46,7 @@ ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md,sha256=B8ZEzaL
|
|
|
46
46
|
ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md,sha256=hJZKXe9-bBGJVN_5c5l8nHQ1X7NK4BbeE-uXQoH0rAM,7479
|
|
47
47
|
ins_pricing/modelling/core/bayesopt/__init__.py,sha256=nj6IA0r7D5U5-hYyiwXmcp_bEtoU-hRJ_prdtRmLMg0,2070
|
|
48
48
|
ins_pricing/modelling/core/bayesopt/config_components.py,sha256=OjRyM1EuSXL9_3THD1nGLRsioJs7lO_ZKVZDkUA3LX8,12156
|
|
49
|
-
ins_pricing/modelling/core/bayesopt/config_preprocess.py,sha256=
|
|
49
|
+
ins_pricing/modelling/core/bayesopt/config_preprocess.py,sha256=vjxhDuJJm-bYyfphWnsZP_O3Tgtx22WGo80myLCB4cw,21647
|
|
50
50
|
ins_pricing/modelling/core/bayesopt/core.py,sha256=1m4pCrPP3iYIfU6QX3j6Eczjwz3-cD4ySzv9bll3PGg,44474
|
|
51
51
|
ins_pricing/modelling/core/bayesopt/model_explain_mixin.py,sha256=jCk1zPpwgwBBCndaq-A0_cQnc4RHueh2p5cAuE9ArTo,11620
|
|
52
52
|
ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py,sha256=lD0rUvWV4eWatmTzMrmAUm2Flj8uAOa3R9S2JyYV94k,21807
|
|
@@ -58,11 +58,11 @@ ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py,sha256=jk9pm7IzVL
|
|
|
58
58
|
ins_pricing/modelling/core/bayesopt/models/model_gnn.py,sha256=blCTgML-fMkHDerzwoJZPw2XnEvuwVR_U5t0YWE1lZI,32901
|
|
59
59
|
ins_pricing/modelling/core/bayesopt/models/model_resn.py,sha256=Pddu0q04Sz8RwKqjP0fv4xXWd6KobwMsD47sCDBbB-Y,17581
|
|
60
60
|
ins_pricing/modelling/core/bayesopt/trainers/__init__.py,sha256=ODYKjT-v4IDxu4ohGLCXY8r1-pMME9LAaNx6pmj5_38,481
|
|
61
|
-
ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py,sha256=
|
|
62
|
-
ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py,sha256=
|
|
61
|
+
ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py,sha256=DOam1HLsslNQlhQt88j6LWCZekioaNzw2PbKzLwtszY,55687
|
|
62
|
+
ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py,sha256=DkBHF50RK70W2LlLWKbg-NEnE2LOn1cI9d-1u8zkOD0,35816
|
|
63
63
|
ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py,sha256=gMhx9IX9nz-rsf-zi9UYMtViBPD1nmQ5r8XVPGU21Ys,7912
|
|
64
|
-
ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py,sha256=
|
|
65
|
-
ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py,sha256=
|
|
64
|
+
ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py,sha256=vJbQTm-3ByBguZYz4gvYhXeWbUa0L7z7mxaxoamGsic,14259
|
|
65
|
+
ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py,sha256=MaxCvmyybE72H38akOA_rNW3Rx9Mxb229ZRJaqIzWTA,11855
|
|
66
66
|
ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py,sha256=NlEqH0wxe5frwxXNTeymWX5_qC3_rIzF3QjDZz4RBMg,13752
|
|
67
67
|
ins_pricing/modelling/core/bayesopt/utils/__init__.py,sha256=dbf4DrWOH4rABOuaZdBF7drYOBH5prjvM0TexT6DYyg,1911
|
|
68
68
|
ins_pricing/modelling/core/bayesopt/utils/constants.py,sha256=0ihYxGlJ8tIElYvkhIDe5FfJShegvu29WZ_Xvfqa0iE,5790
|
|
@@ -70,7 +70,7 @@ ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py,sha256=cu01dHyYE5
|
|
|
70
70
|
ins_pricing/modelling/core/bayesopt/utils/io_utils.py,sha256=vXDlAc_taCG2joxnC6wu0jVYA76UhRbX9OT_5z_im-E,3857
|
|
71
71
|
ins_pricing/modelling/core/bayesopt/utils/losses.py,sha256=yn3ggeM1NRkCzcTt_Nef_EvpD6Pb_jGs49bj-VV4uWU,3894
|
|
72
72
|
ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py,sha256=kfQZnGE8FvGfl7WsTFShGGIA_sQhp5Th9mrwUXphiNQ,21200
|
|
73
|
-
ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py,sha256=
|
|
73
|
+
ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py,sha256=qxSAP4vQC-JTHTZ7SDPZx7ZKyJKLwcAOP0IWr4GA_l4,25881
|
|
74
74
|
ins_pricing/modelling/explain/__init__.py,sha256=CPoGzGu8TTO3FOXjxoXC13VkuIDCf3YTH6L3BqJq3Ok,1171
|
|
75
75
|
ins_pricing/modelling/explain/gradients.py,sha256=9TqCws_p49nFxVMcjVxe4KCZ7frezeL0uV_LCdoM5yo,11088
|
|
76
76
|
ins_pricing/modelling/explain/metrics.py,sha256=K_xOY7ZrHWhbJ79RNB7eXN3VXeTe8vq68ZLH2BlZufA,5389
|
|
@@ -131,7 +131,7 @@ ins_pricing/utils/paths.py,sha256=o_tBiclFvBci4cYg9WANwKPxrMcglEdOjDP-EZgGjdQ,87
|
|
|
131
131
|
ins_pricing/utils/profiling.py,sha256=kmbykHLcYywlZxAf_aVU8HXID3zOvUcBoO5Q58AijhA,11132
|
|
132
132
|
ins_pricing/utils/torch_compat.py,sha256=UrRsqx2qboDG8WE0OmxNOi08ojwE-dCxTQh0N2s3Rgw,2441
|
|
133
133
|
ins_pricing/utils/validation.py,sha256=4Tw9VUJPk0N-WO3YUqZP-xXRl1Xpubkm0vi3WzzZrv4,13348
|
|
134
|
-
ins_pricing-0.4.
|
|
135
|
-
ins_pricing-0.4.
|
|
136
|
-
ins_pricing-0.4.
|
|
137
|
-
ins_pricing-0.4.
|
|
134
|
+
ins_pricing-0.4.3.dist-info/METADATA,sha256=33V_YLmiYurQvj4716hMRASiIVW4MgLeDstga8N6w8g,6263
|
|
135
|
+
ins_pricing-0.4.3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
136
|
+
ins_pricing-0.4.3.dist-info/top_level.txt,sha256=haZuNQpHKNBEPZx3NjLnHp8pV3I_J9QG8-HyJn00FA0,12
|
|
137
|
+
ins_pricing-0.4.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|