ins-pricing 0.4.2-py3-none-any.whl → 0.4.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/modelling/core/bayesopt/config_preprocess.py +7 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +13 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +830 -804
- ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +1 -1
- ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +2 -1
- ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +8 -0
- ins_pricing/setup.py +1 -1
- {ins_pricing-0.4.2.dist-info → ins_pricing-0.4.4.dist-info}/METADATA +1 -1
- {ins_pricing-0.4.2.dist-info → ins_pricing-0.4.4.dist-info}/RECORD +11 -11
- {ins_pricing-0.4.2.dist-info → ins_pricing-0.4.4.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.2.dist-info → ins_pricing-0.4.4.dist-info}/top_level.txt +0 -0
ins_pricing/modelling/core/bayesopt/config_preprocess.py

@@ -195,6 +195,7 @@ class BayesOptConfig:
     cache_predictions: bool = False
     prediction_cache_dir: Optional[str] = None
     prediction_cache_format: str = "parquet"
+    dataloader_workers: Optional[int] = None

     def __post_init__(self) -> None:
         """Validate configuration after initialization."""
@@ -210,6 +211,12 @@ class BayesOptConfig:
             errors.append(
                 f"task_type must be one of {valid_task_types}, got '{self.task_type}'"
             )
+        if self.dataloader_workers is not None:
+            try:
+                if int(self.dataloader_workers) < 0:
+                    errors.append("dataloader_workers must be >= 0 when provided.")
+            except (TypeError, ValueError):
+                errors.append("dataloader_workers must be an integer when provided.")
         # Validate loss_name
         try:
             normalized_loss = normalize_loss_name(self.loss_name, self.task_type)
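Taken together, the two hunks above add one new option: `dataloader_workers`, an optional non-negative integer validated in `__post_init__` that lets callers pin the number of DataLoader worker processes. The snippet below is a minimal, self-contained sketch of the same validation pattern; `ConfigSketch` and the `ValueError` it raises are illustrative assumptions (the real `BayesOptConfig` has many more fields, and this hunk does not show how the collected `errors` list is ultimately reported).

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ConfigSketch:
    """Illustrative stand-in for BayesOptConfig (only the new field is shown)."""

    dataloader_workers: Optional[int] = None

    def __post_init__(self) -> None:
        errors: List[str] = []
        # Same checks as the added hunk: optional, but must coerce to an int >= 0.
        if self.dataloader_workers is not None:
            try:
                if int(self.dataloader_workers) < 0:
                    errors.append("dataloader_workers must be >= 0 when provided.")
            except (TypeError, ValueError):
                errors.append("dataloader_workers must be an integer when provided.")
        if errors:  # assumption: the real class may surface errors differently
            raise ValueError("; ".join(errors))


ConfigSketch(dataloader_workers=4)   # accepted: pin DataLoader workers to 4
ConfigSketch()                       # accepted: None keeps the existing default
try:
    ConfigSketch(dataloader_workers=-1)
except ValueError as exc:
    print(exc)                       # dataloader_workers must be >= 0 when provided.
```

Leaving the field at `None` changes nothing, which is why the new checks only run when a value is provided.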
ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py

@@ -306,6 +306,19 @@ class TrainerBase:
         self.enable_distributed_optuna: bool = False
         self._distributed_forced_params: Optional[Dict[str, Any]] = None

+    def _apply_dataloader_overrides(self, model: Any) -> Any:
+        """Apply dataloader-related overrides from config to a model."""
+        cfg = getattr(self.ctx, "config", None)
+        if cfg is None:
+            return model
+        workers = getattr(cfg, "dataloader_workers", None)
+        if workers is not None:
+            model.dataloader_workers = int(workers)
+        profile = getattr(cfg, "resource_profile", None)
+        if profile:
+            model.resource_profile = str(profile)
+        return model
+
     def _export_preprocess_artifacts(self) -> Dict[str, Any]:
         dummy_columns: List[str] = []
         if getattr(self.ctx, "train_oht_data", None) is not None: