ins-pricing 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -195,6 +195,7 @@ class BayesOptConfig:
195
195
  cache_predictions: bool = False
196
196
  prediction_cache_dir: Optional[str] = None
197
197
  prediction_cache_format: str = "parquet"
198
+ dataloader_workers: Optional[int] = None
198
199
 
199
200
  def __post_init__(self) -> None:
200
201
  """Validate configuration after initialization."""
@@ -210,6 +211,12 @@ class BayesOptConfig:
210
211
  errors.append(
211
212
  f"task_type must be one of {valid_task_types}, got '{self.task_type}'"
212
213
  )
214
+ if self.dataloader_workers is not None:
215
+ try:
216
+ if int(self.dataloader_workers) < 0:
217
+ errors.append("dataloader_workers must be >= 0 when provided.")
218
+ except (TypeError, ValueError):
219
+ errors.append("dataloader_workers must be an integer when provided.")
213
220
  # Validate loss_name
214
221
  try:
215
222
  normalized_loss = normalize_loss_name(self.loss_name, self.task_type)
@@ -306,6 +306,19 @@ class TrainerBase:
306
306
  self.enable_distributed_optuna: bool = False
307
307
  self._distributed_forced_params: Optional[Dict[str, Any]] = None
308
308
 
309
+ def _apply_dataloader_overrides(self, model: Any) -> Any:
310
+ """Apply dataloader-related overrides from config to a model."""
311
+ cfg = getattr(self.ctx, "config", None)
312
+ if cfg is None:
313
+ return model
314
+ workers = getattr(cfg, "dataloader_workers", None)
315
+ if workers is not None:
316
+ model.dataloader_workers = int(workers)
317
+ profile = getattr(cfg, "resource_profile", None)
318
+ if profile:
319
+ model.resource_profile = str(profile)
320
+ return model
321
+
309
322
  def _export_preprocess_artifacts(self) -> Dict[str, Any]:
310
323
  dummy_columns: List[str] = []
311
324
  if getattr(self.ctx, "train_oht_data", None) is not None: